Fit shifts in functional effects
Import Python modules.
We use multidms for the fitting:
In [1]:
import alignparse.utils
import dms_variants.codonvarianttable
import matplotlib.pyplot as plt
import multidms
import pandas as pd
import seaborn
/fh/fast/bloom_j/software/miniforge3/envs/dms-vep-pipeline-3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
This notebook is parameterized by papermill.
The next cell is tagged as parameters to get the passed parameters.
In [2]:
# this cell is tagged parameters for `papermill` parameterization
# Every name here is only a `None` placeholder; the notebook expects real
# values to be supplied before the cells below run.
# `params`: dict of fitting options; the keys read later in this notebook are
# "conditions", "clip_upper", "clip_lower", "reference",
# "collapse_identical_variants" and "lasso_shifts".
params = None
# `shifts_csv`: path of the CSV file the fitted shifts are written to.
shifts_csv = None
# `site_numbering_map`: path of a CSV with "sequential_site" and
# "reference_site" columns.
site_numbering_map = None
# `threads`: passed to `multidms.Data` as `nb_workers`.
threads = None
In [3]:
# Parameters
# Concrete values for this run; they replace the `None` placeholders above.
params = {
# "median_stop" is resolved later to the median functional score of
# stop-codon variants; `None` disables clipping on that bound.
"clip_lower": "median_stop",
"clip_upper": None,
"collapse_identical_variants": False,
# NOTE(review): "latent_offset" is not read anywhere in the visible cells.
"latent_offset": True,
# Penalties are a mix of strings and numbers; each one is converted with
# `float()` before it is used for fitting.
"lasso_shifts": ["1e-05", "5e-05", 0.0001, 0.0002, 0.001],
# An int here although the condition keys below are strings; the site-map
# lookup later uses `str(params["reference"])`.
"reference": 220210,
# Maps condition name -> selection name; the selection name determines which
# `results/func_scores/{selection}_func_scores.csv` file is read.
"conditions": {
"220210": "LibA-220210-293T_ACE2-1",
"220302": "LibA-220302-293T_ACE2-1",
},
}
site_numbering_map = "data/site_numbering_map.csv"
shifts_csv = (
"results/func_effect_shifts/by_comparison/LibA-date_comparison-1_shifts.csv"
)
threads = 1
Read and clip functional scores:
In [4]:
# Load the functional scores of every selection, tag each row with the
# condition it belongs to, then classify the variants (the `variant_class`
# column queried below is expected to come from `classifyVariants`).
per_condition_scores = []
for condition, selection in params["conditions"].items():
    selection_scores = pd.read_csv(
        f"results/func_scores/{selection}_func_scores.csv", na_filter=None
    )
    per_condition_scores.append(selection_scores.assign(condition=condition))
func_scores_df = dms_variants.codonvarianttable.CodonVariantTable.classifyVariants(
    pd.concat(per_condition_scores)
)

# Median score of stop-codon variants, used when a bound is "median_stop".
is_stop = func_scores_df["variant_class"] == "stop"
median_stop = func_scores_df.loc[is_stop, "func_score"].median()

# Apply the optional upper / lower clipping, one bound at a time.
for bound in ["upper", "lower"]:
    threshold = params[f"clip_{bound}"]
    if threshold is None:
        print(f"No clipping on {bound} bound of functional scores")
        continue
    if threshold == "median_stop":
        # without any stop-codon variants the median is undefined
        if pd.isnull(median_stop):
            raise ValueError(f"{median_stop=}")
        threshold = median_stop
    assert isinstance(threshold, (int, float)), threshold
    print(f"Clipping {bound} bound of functional scores to {threshold}")
    func_scores_df["func_score"] = func_scores_df["func_score"].clip(
        **{bound: threshold}
    )
No clipping on upper bound of functional scores Clipping lower bound of functional scores to -4.977
Renumber sites to sequential numbering so that the reference sites can be arbitrary strings:
In [5]:
site_numbering = pd.read_csv(site_numbering_map)

# Each numbering scheme must assign a distinct value to every row, otherwise
# the two-way mapping built below would be ambiguous.
for site_col in ["sequential_site", "reference_site"]:
    assert len(site_numbering) == site_numbering[site_col].nunique()

# Options shared by the two renumbering objects (one per direction).
shared_renumber_kwargs = {
    "number_mapping": site_numbering,
    "wt_nt_col": None,
    "allow_arbitrary_numbers": True,
}
renumber_to_sequential = alignparse.utils.MutationRenumber(
    old_num_col="reference_site",
    new_num_col="sequential_site",
    **shared_renumber_kwargs,
)
renumber_to_reference = alignparse.utils.MutationRenumber(
    old_num_col="sequential_site",
    new_num_col="reference_site",
    **shared_renumber_kwargs,
)

# Rewrite every variant's substitution string in sequential numbering; gap and
# stop characters are permitted in the mutations.
func_scores_df_sequential = func_scores_df.assign(
    aa_substitutions=lambda x: x["aa_substitutions"].apply(
        lambda muts: renumber_to_sequential.renumber_muts(
            muts, allow_gaps=True, allow_stop=True
        )
    )
)
Initialize data for multidms:
In [6]:
# Bundle the sequentially numbered functional scores into the object that the
# models below are fit to.
data = multidms.Data(
    # input variants and the condition all shifts are measured relative to
    variants_df=func_scores_df_sequential,
    reference=params["reference"],
    # data-handling options
    alphabet=multidms.AAS_WITHSTOP_WITHGAP,
    collapse_identical_variants=params["collapse_identical_variants"],
    assert_site_integrity=True,
    # execution options
    nb_workers=threads,
    verbose=False,
)
Now initialize and fit the model for each lasso penalty:
In [7]:
# Convert the penalties to floats up front: the configured list may mix
# strings (e.g. "1e-05") and numbers, so a duplicate check on the raw entries
# would not notice that "1e-05" and 1e-05 are the same penalty.
lasso_shifts = [float(lasso_shift) for lasso_shift in params["lasso_shifts"]]
assert lasso_shifts, "params['lasso_shifts'] must list at least one penalty"
assert len(lasso_shifts) == len(set(lasso_shifts)), lasso_shifts
n_lasso = len(lasso_shifts)

# One row of panels per penalty: accuracy on the left, epistasis on the right.
# `squeeze=False` keeps `ax` two-dimensional even for a single penalty, so the
# `ax[i, j]` indexing below always works.
fig, ax = plt.subplots(n_lasso, 2, figsize=[6, n_lasso * 3], squeeze=False)

# Collects one mutations data frame per penalty (concatenated in a later cell).
mutations_df = []
for i, lasso_shift in enumerate(lasso_shifts):
    print(f"Fitting model for {lasso_shift=}")
    # a fresh model for every penalty so the fits are independent
    model = multidms.Model(data)
    model.fit(lasso_shift=lasso_shift)
    mutations_df.append(
        model.get_mutations_df(phenotype_as_effect=True).assign(lasso_shift=lasso_shift)
    )
    # only the first row gets a legend on the epistasis panel
    model.plot_epistasis(ax=ax[i, 1], alpha=0.1, show=False, legend=not i)
    model.plot_pred_accuracy(ax=ax[i, 0], alpha=0.1, show=False, legend=False)
    ax[i, 1].set_title(f"Epistasis fit (lasso {lasso_shift})")
    ax[i, 0].set_title(f"Accuracy (lasso {lasso_shift})")

plt.show()
Fitting model for lasso_shift=1e-05
Fitting model for lasso_shift=5e-05
Fitting model for lasso_shift=0.0001
Fitting model for lasso_shift=0.0002
Fitting model for lasso_shift=0.001
Get the data frame of mutation-effect shifts for each lasso penalty, adding in the wildtype residues:
In [8]:
# Stack the per-penalty mutation tables and rename the model's column names.
column_renames = {
    "wts": "wildtype",
    "sites": "site",
    "muts": "mutant",
    "beta": "latent_phenotype_effect",
}
mut_shifts = pd.concat(mutations_df).rename(columns=column_renames)

# Predicted functional effects are not part of the output.
kept_columns = [c for c in mut_shifts.columns if not c.startswith("predicted")]
mut_shifts = mut_shifts[kept_columns]

# Build the wildtype rows: at every site the reference residue is listed as
# both wildtype and mutant, with zero latent effect and zero shifts.
reference_wildtypes = data.site_map[str(params["reference"])]
zero_shifts = {col: 0 for col in mut_shifts.columns if col.startswith("shift_")}
wildtype_rows = pd.concat(
    [
        pd.DataFrame(
            {
                "site": data.site_map.index,
                "wildtype": reference_wildtypes,
                "mutant": reference_wildtypes,
                "latent_phenotype_effect": 0,
                **zero_shifts,
            }
        ).assign(lasso_shift=float(lasso_shift))
        for lasso_shift in lasso_shifts
    ]
)

# Lookup from sequential site numbers back to reference site labels.
sequential_to_reference = site_numbering.set_index("sequential_site")[
    "reference_site"
].to_dict()

# Combine mutant and wildtype rows, sort while sites are still in sequential
# numbering, and only then convert the sites back to reference numbering.
mut_shifts = (
    pd.concat([mut_shifts, wildtype_rows])
    .sort_values(["lasso_shift", "site", "mutant"])
    .assign(site=lambda x: x["site"].map(sequential_to_reference))
    .reset_index(drop=True)
)

print(f"Saving shifts to {shifts_csv}")
mut_shifts.to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/by_comparison/LibA-date_comparison-1_shifts.csv
Plot distribution of shifts for all non-wildtype residues for each regularization:
In [9]:
# Columns that hold the shift estimates (names start with "shift_").
mut_shifts_cols = [
    column for column in mut_shifts.columns if column.startswith("shift_")
]

# Drop the wildtype rows and reshape to long form: one row per
# (site, mutant, lasso_shift, condition) with its shift value.
non_wildtype_shifts = mut_shifts[mut_shifts["wildtype"] != mut_shifts["mutant"]]
mut_shifts_tidy = non_wildtype_shifts.melt(
    id_vars=["site", "mutant", "lasso_shift"],
    value_vars=mut_shifts_cols,
    var_name="condition",
    value_name="shift",
)

# Grid of distributions: one column per condition, one row per lasso penalty.
_ = seaborn.displot(
    data=mut_shifts_tidy,
    x="shift",
    row="lasso_shift",
    col="condition",
    hue="condition",
    height=1.9,
    aspect=1.8,
    facet_kws={"margin_titles": True},
)