Fit shifts in functional effects
Import Python modules. We use multidms for the fitting:
In [1]:
import dms_variants.codonvarianttable
import matplotlib.pyplot as plt
import multidms
import pandas as pd
import seaborn
This notebook is parameterized by papermill. The next cell is tagged as `parameters` so that papermill can inject the passed parameters.
In [2]:
# this cell is tagged parameters for `papermill` parameterization
# These are placeholders only: papermill injects the actual values in the
# "# Parameters" cell that follows, overriding them.
params = None  # dict of run settings (clipping, lasso penalties, conditions, reference)
shifts_csv = None  # path of the CSV the fitted mutation shifts are written to
threads = None  # number of workers passed to multidms.Data (nb_workers)
In [3]:
# Parameters
# Values injected by papermill for this run; they override the placeholders
# defined in the cell tagged `parameters`.
params = {
    # "median_stop" clips at the median func score of stop-codon variants
    "clip_lower": "median_stop",
    # None disables clipping on this bound
    "clip_upper": None,
    "collapse_identical_variants": False,
    # NOTE(review): not read by any code visible in this notebook — confirm use
    "latent_offset": True,
    # mix of strings and floats (as delivered); each entry is converted with
    # float() before being used as a penalty
    "lasso_shifts": ["1e-05", "5e-05", 0.0001, 0.0002, 0.001],
    # must be one of the keys of "conditions"
    "reference": "high_ACE2",
    # condition name -> selection name used to locate its func-scores CSV
    "conditions": {
        "high_ACE2": "Lib2-230614_high_ACE2",
        "medium_ACE2": "Lib2-230614_medium_ACE2",
    },
}
shifts_csv = (
    "results/func_effect_shifts/by_comparison/Lib2-230614_ACE2_expression_shifts.csv"
)
threads = 1
Read and clip functional scores:
In [4]:
# Load the functional scores of every condition into one frame, tagging each
# row with its condition, then classify the variants (presumably this adds the
# `variant_class` column queried below — confirm in dms_variants).
func_scores_df = pd.concat(
    [
        # na_filter=False keeps empty fields as "" instead of converting them
        # to NaN (e.g. variants with no substitutions)
        pd.read_csv(
            f"results/func_scores/{selection}_func_scores.csv", na_filter=False
        ).assign(condition=condition)
        for condition, selection in params["conditions"].items()
    ],
    # unique row labels; otherwise each condition restarts its index at 0
    ignore_index=True,
).pipe(dms_variants.codonvarianttable.CodonVariantTable.classifyVariants)

# Median func score of stop-codon variants, pooled over all conditions; used
# when a clip bound is given as "median_stop".
median_stop = func_scores_df.query("variant_class == 'stop'")["func_score"].median()

# Apply the optional clipping of each bound. A bound may be None (no clipping),
# "median_stop", a number, or a numeric string (the params dict can deliver
# numbers as strings, as with lasso_shifts).
for bound in ["upper", "lower"]:
    clip = params[f"clip_{bound}"]
    if clip is None:
        print(f"No clipping on {bound} bound of functional scores")
        continue
    if clip == "median_stop":
        if pd.isnull(median_stop):
            raise ValueError(f"{median_stop=}")
        clip = median_stop
    elif isinstance(clip, str):
        try:
            clip = float(clip)
        except ValueError as err:
            raise ValueError(f"invalid clip_{bound}: {clip!r}") from err
    # explicit check rather than `assert` (skipped under `python -O`); bool is
    # excluded because it is an int subclass, NaN because clip() ignores it
    if (
        isinstance(clip, bool)
        or not isinstance(clip, (int, float))
        or pd.isnull(clip)
    ):
        raise ValueError(f"invalid clip_{bound}: {clip!r}")
    print(f"Clipping {bound} bound of functional scores to {clip}")
    func_scores_df["func_score"] = func_scores_df["func_score"].clip(
        **{bound: clip}
    )
No clipping on upper bound of functional scores Clipping lower bound of functional scores to -7.1665
Initialize data for multidms:
In [5]:
# Collect the multidms.Data options in one mapping: the clipped scores, the
# reference condition and collapsing choice from the papermill parameters, and
# an alphabet that includes stop codons and gaps.
data_kwargs = dict(
    variants_df=func_scores_df,
    reference=params["reference"],
    collapse_identical_variants=params["collapse_identical_variants"],
    alphabet=multidms.AAS_WITHSTOP_WITHGAP,
    letter_suffixed_sites=True,
    assert_site_integrity=True,
    nb_workers=threads,
    verbose=True,
)
data = multidms.Data(**data_kwargs)
inferring site map for high_ACE2
inferring site map for medium_ACE2
Asserting site integrity
INFO: Pandarallel will run on 1 workers. INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [], dropping: 0 variantswhich have mutations at those sites.
invalid non-identical-sites: [], dropping 0 variants Converting mutations for high_ACE2 is reference, skipping Converting mutations for medium_ACE2 is reference, skipping
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Now initialize and fit the model for each lasso penalty:
In [6]:
# Convert the penalties to floats up front. The params may mix strings and
# floats (e.g. "1e-05" next to 0.0001), so a uniqueness check on the raw
# values could miss duplicates such as "1e-05" and 1e-05.
lasso_shifts = [float(penalty) for penalty in params["lasso_shifts"]]
if not lasso_shifts:
    raise ValueError("params['lasso_shifts'] must not be empty")
if len(lasso_shifts) != len(set(lasso_shifts)):
    raise ValueError(f"duplicate values in lasso_shifts: {params['lasso_shifts']}")
n_lasso = len(lasso_shifts)

# One row of panels per penalty: accuracy (left) and epistasis fit (right).
# squeeze=False keeps `ax` 2-D so `ax[i, j]` also works with a single penalty.
fig, ax = plt.subplots(n_lasso, 2, figsize=[6, n_lasso * 3], squeeze=False)

# A single Model instance is refit for every penalty.
# NOTE(review): if fit() continues from the current parameters, later fits are
# warm-started from the previous penalty — confirm this is intended.
model = multidms.Model(data)
mutations_df = []  # one mutation-effects frame per penalty
for i, lasso_shift in enumerate(lasso_shifts):
    print(f"Fitting model for {lasso_shift=}")
    model.fit(lasso_shift=lasso_shift)
    mutations_df.append(
        model.get_mutations_df(phenotype_as_effect=True).assign(lasso_shift=lasso_shift)
    )
    # only the first epistasis panel gets a legend
    model.plot_epistasis(ax=ax[i, 1], alpha=0.1, show=False, legend=not i)
    model.plot_pred_accuracy(ax=ax[i, 0], alpha=0.1, show=False, legend=False)
    ax[i, 1].set_title(f"Epistasis fit (lasso {lasso_shift})")
    ax[i, 0].set_title(f"Accuracy (lasso {lasso_shift})")
plt.show()
Fitting model for lasso_shift=1e-05 Fitting model for lasso_shift=5e-05 Fitting model for lasso_shift=0.0001 Fitting model for lasso_shift=0.0002 Fitting model for lasso_shift=0.001
Get the frame of mutation-effect shifts for each lasso penalty, adding rows for the wildtype residues:
In [7]:
# Stack the per-penalty mutation frames and rename the multidms columns to the
# names used in the output table.
renamed_cols = {
    "wts": "wildtype",
    "sites": "site",
    "muts": "mutant",
    "beta": "latent_phenotype_effect",
}
mut_shifts = pd.concat(mutations_df).rename(columns=renamed_cols)

# Predicted functional scores are not saved, so drop every "predicted*" column.
mut_shifts = mut_shifts.drop(
    columns=[col for col in mut_shifts.columns if col.startswith("predicted")]
)

# Template rows for the wildtype residue at each site (reference condition):
# wildtype == mutant, with zero latent effect and zero shift in every condition.
reference_wt = data.site_map[str(params["reference"])]
wildtype_rows = pd.DataFrame(
    {
        "site": data.site_map.index,
        "wildtype": reference_wt,
        "mutant": reference_wt,
        "latent_phenotype_effect": 0,
        **{col: 0 for col in mut_shifts.columns if col.startswith("shift_")},
    }
)

# Append one copy of the wildtype rows per penalty, then order the table.
mut_shifts = (
    pd.concat(
        [mut_shifts]
        + [
            wildtype_rows.assign(lasso_shift=float(penalty))
            for penalty in lasso_shifts
        ]
    )
    .sort_values(["lasso_shift", "site", "mutant"])
    .reset_index(drop=True)
)

print(f"Saving shifts to {shifts_csv}")
mut_shifts.to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/by_comparison/Lib2-230614_ACE2_expression_shifts.csv
Plot the distribution of shifts for all non-wildtype residues at each regularization strength:
In [8]:
# Reshape to long format: one row per (site, mutant, lasso_shift, condition)
# shift value, keeping only rows where the mutant differs from the wildtype.
mut_shifts_cols = [col for col in mut_shifts.columns if col.startswith("shift_")]
non_wildtype = mut_shifts.loc[mut_shifts["wildtype"] != mut_shifts["mutant"]]
mut_shifts_tidy = non_wildtype.melt(
    id_vars=["site", "mutant", "lasso_shift"],
    value_vars=mut_shifts_cols,
    var_name="condition",
    value_name="shift",
)

# Grid of histograms: conditions across columns, lasso penalties down rows.
_ = seaborn.displot(
    data=mut_shifts_tidy,
    x="shift",
    row="lasso_shift",
    col="condition",
    hue="condition",
    height=1.9,
    aspect=1.8,
    facet_kws={"margin_titles": True},
)