Fit shifts in functional effects
Import Python modules.
We use multidms for the fitting:
In [1]:
import dms_variants.codonvarianttable
import matplotlib.pyplot as plt
import multidms
import pandas as pd
import seaborn
This notebook is parameterized by papermill.
The next cell is tagged as "parameters" so that papermill can inject the passed parameters; an illustrative papermill invocation is sketched after the injected-parameters cell below.
In [2]:
# this cell is tagged parameters for `papermill` parameterization
params = None
shifts_csv = None
threads = None
In [3]:
# Parameters
params = {
    "clip_lower": "median_stop",
    "clip_upper": None,
    "collapse_identical_variants": False,
    "latent_offset": True,
    "lasso_shifts": ["1e-05", "5e-05", 0.0001, 0.0002, 0.001],
    "reference": "high_ACE2",
    "conditions": {
        "high_ACE2": "Lib1-230614_high_ACE2",
        "medium_ACE2": "Lib1-230614_medium_ACE2",
    },
}
shifts_csv = (
    "results/func_effect_shifts/by_comparison/Lib1-230614_ACE2_expression_shifts.csv"
)
threads = 1
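For orientation, here is a minimal sketch of how papermill could inject these values when executing the notebook programmatically. The notebook filenames are placeholders (not the pipeline's actual paths); the parameters simply mirror the injected cell above.
import papermill as pm

# Placeholder input/output notebook names; parameters mirror the injected cell above.
pm.execute_notebook(
    "fit_shifts.ipynb",  # this notebook (hypothetical filename)
    "fit_shifts_run.ipynb",  # executed copy written by papermill
    parameters={
        "params": {
            "clip_lower": "median_stop",
            "clip_upper": None,
            "collapse_identical_variants": False,
            "latent_offset": True,
            "lasso_shifts": ["1e-05", "5e-05", 0.0001, 0.0002, 0.001],
            "reference": "high_ACE2",
            "conditions": {
                "high_ACE2": "Lib1-230614_high_ACE2",
                "medium_ACE2": "Lib1-230614_medium_ACE2",
            },
        },
        "shifts_csv": "results/func_effect_shifts/by_comparison/Lib1-230614_ACE2_expression_shifts.csv",
        "threads": 1,
    },
)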
Read and clip functional scores:
In [4]:
func_scores_df = pd.concat(
    [
        pd.read_csv(
            f"results/func_scores/{selection}_func_scores.csv", na_filter=None
        ).assign(condition=condition)
        for condition, selection in params["conditions"].items()
    ]
).pipe(dms_variants.codonvarianttable.CodonVariantTable.classifyVariants)

median_stop = func_scores_df.query("variant_class == 'stop'")["func_score"].median()

for bound in ["upper", "lower"]:
    clip = params[f"clip_{bound}"]
    if clip is None:
        print(f"No clipping on {bound} bound of functional scores")
    else:
        if clip == "median_stop":
            if pd.isnull(median_stop):
                raise ValueError(f"{median_stop=}")
            clip = median_stop
        assert isinstance(clip, (int, float)), clip
        print(f"Clipping {bound} bound of functional scores to {clip}")
        func_scores_df["func_score"] = func_scores_df["func_score"].clip(
            **{bound: clip}
        )
No clipping on upper bound of functional scores
Clipping lower bound of functional scores to -7.055
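As a toy illustration of what the one-sided clip above does (the scores here are made up; -7.055 is the median stop-codon score computed above, and pd is the pandas import from the first cell):
example_scores = pd.Series([-9.3, -2.1, 0.4])  # made-up functional scores
print(example_scores.clip(lower=-7.055).tolist())  # [-7.055, -2.1, 0.4]: only the lower bound applies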
Initialize data for multidms:
In [5]:
data = multidms.Data(
    variants_df=func_scores_df,
    reference=params["reference"],
    alphabet=multidms.AAS_WITHSTOP_WITHGAP,
    collapse_identical_variants=params["collapse_identical_variants"],
    letter_suffixed_sites=True,
    verbose=True,
    nb_workers=threads,
    assert_site_integrity=True,
)
inferring site map for high_ACE2
inferring site map for medium_ACE2
Asserting site integrity
INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [], dropping: 0 variants which have mutations at those sites.
invalid non-identical-sites: [], dropping 0 variants
Converting mutations for high_ACE2 is reference, skipping
Converting mutations for medium_ACE2 is reference, skipping
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
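As an optional sanity check, one could inspect the site map that multidms just inferred; this sketch assumes only the data.site_map attribute already relied on later in this notebook (indexed by site, with one column of wildtype identities per condition):
# Optional check (not part of the pipeline): inspect the inferred site map.
print(data.site_map.head())
print(f"{len(data.site_map)} sites; conditions: {list(data.site_map.columns)}")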
Now initialize the model, then fit it for each lasso penalty on the shifts:
In [6]:
lasso_shifts = params["lasso_shifts"]
assert len(lasso_shifts) == len(set(lasso_shifts))
n_lasso = len(params["lasso_shifts"])
fig, ax = plt.subplots(n_lasso, 2, figsize=[6, n_lasso * 3])
model = multidms.Model(data)
mutations_df = []
for i, lasso_shift in enumerate(lasso_shifts):
    lasso_shift = float(lasso_shift)
    print(f"Fitting model for {lasso_shift=}")
    model.fit(lasso_shift=lasso_shift)
    mutations_df.append(
        model.get_mutations_df(phenotype_as_effect=True).assign(lasso_shift=lasso_shift)
    )
    model.plot_epistasis(ax=ax[i, 1], alpha=0.1, show=False, legend=not i)
    model.plot_pred_accuracy(ax=ax[i, 0], alpha=0.1, show=False, legend=False)
    ax[i, 1].set_title(f"Epistasis fit (lasso {lasso_shift})")
    ax[i, 0].set_title(f"Accuracy (lasso {lasso_shift})")
plt.show()
Fitting model for lasso_shift=1e-05
Fitting model for lasso_shift=5e-05
Fitting model for lasso_shift=0.0001
Fitting model for lasso_shift=0.0002
Fitting model for lasso_shift=0.001
Get the data frame of mutation-effect shifts for each lasso penalty, adding rows for the wildtype identities:
In [7]:
# get shifts for mutations
mut_shifts = pd.concat(mutations_df).rename(
    columns={
        "wts": "wildtype",
        "sites": "site",
        "muts": "mutant",
        "beta": "latent_phenotype_effect",
    }
)
# we do not keep predicted functional effects
mut_shifts = mut_shifts[
    [c for c in mut_shifts.columns if not c.startswith("predicted")]
]
# add wildtypes
mut_shifts = (
    pd.concat(
        [
            mut_shifts,
            pd.concat(
                [
                    pd.DataFrame(
                        {
                            "site": data.site_map.index,
                            "wildtype": data.site_map[str(params["reference"])],
                            "mutant": data.site_map[str(params["reference"])],
                            "latent_phenotype_effect": 0,
                            **{
                                col: 0
                                for col in mut_shifts.columns
                                if col.startswith("shift_")
                            },
                        }
                    ).assign(lasso_shift=float(lasso_shift))
                    for lasso_shift in lasso_shifts
                ]
            ),
        ]
    )
    .sort_values(["lasso_shift", "site", "mutant"])
    .reset_index(drop=True)
)
print(f"Saving shifts to {shifts_csv}")
mut_shifts.to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/by_comparison/Lib1-230614_ACE2_expression_shifts.csv
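If you later want to work from the saved file rather than the in-memory frame, it can be read back and subset to a single lasso penalty; the value 0.0002 below is just one of the penalties fit above:
# Read the saved shifts back and keep one lasso penalty for downstream use.
saved_shifts = pd.read_csv(shifts_csv)
print(saved_shifts.query("lasso_shift == 0.0002").head())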
Plot the distribution of shifts for all non-wildtype mutations at each lasso regularization:
In [8]:
mut_shifts_cols = [c for c in mut_shifts.columns if c.startswith("shift_")]
mut_shifts_tidy = mut_shifts.query("wildtype != mutant").melt(
    id_vars=["site", "mutant", "lasso_shift"],
    value_vars=mut_shifts_cols,
    var_name="condition",
    value_name="shift",
)
_ = seaborn.displot(
    mut_shifts_tidy,
    x="shift",
    col="condition",
    row="lasso_shift",
    hue="condition",
    height=1.9,
    aspect=1.8,
    facet_kws={"margin_titles": True},
)
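A possible follow-up check, sketched here rather than part of the pipeline: the fraction of shifts that the lasso drives exactly to zero at each penalty, which summarizes how strongly each regularization sparsifies the shifts.
# Fraction of non-wildtype shifts that are exactly zero at each lasso penalty and condition.
zero_fraction = (
    mut_shifts_tidy.assign(is_zero=lambda d: d["shift"] == 0)
    .groupby(["lasso_shift", "condition"])["is_zero"]
    .mean()
)
print(zero_fraction)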