Fit shifts in functional effects
Import Python modules.
We use multidms for the fitting:
In [1]:
import dms_variants.codonvarianttable
import matplotlib.pyplot as plt
import multidms
import pandas as pd
import seaborn
This notebook is parameterized by papermill.
The next cell is tagged as parameters to get the passed parameters.
In [2]:
# this cell is tagged parameters for `papermill` parameterization
# Each name below is only a placeholder: the "# Parameters" cell that follows
# reassigns all three with the values for this run.
# `params`: analysis settings, indexed like a dict by the later cells
params = None
# `shifts_csv`: path the shifts data frame is written to with `to_csv`
shifts_csv = None
# `threads`: passed to `multidms.Data` as `nb_workers`
threads = None
In [3]:
# Parameters
# Values for this run; they replace the `None` placeholders assigned in the
# cell tagged `parameters`.
params = {
# "median_stop" is a sentinel: the clipping cell replaces it with the median
# `func_score` of variants whose `variant_class` is 'stop'
"clip_lower": "median_stop",
# None means this bound is left unclipped
"clip_upper": None,
# forwarded unchanged to `multidms.Data(collapse_identical_variants=...)`
"collapse_identical_variants": False,
# NOTE(review): `latent_offset` is not read by any cell visible here --
# confirm whether it should be passed to the model or can be dropped
"latent_offset": True,
# penalties arrive as a mix of numbers and strings; the fitting cell converts
# each one with float() before use
"lasso_shifts": [0, "1e-06", "5e-06", "1e-05", "5e-05", 0.0001, 0.0005, 0.001],
# condition passed to `multidms.Data` as `reference`; also the `site_map`
# column used to fill in wildtype rows
"reference": "293T",
# maps condition name -> selection name; the selection name locates the file
# results/func_scores/<selection>_func_scores.csv
"conditions": {
"293T": "LibA-220823-293T-2",
"human_aDG": "LibA-230722-human-2",
"mastomys_aDG": "LibA-230722-mastomys-2",
},
}
# output CSV for the shifts data frame
shifts_csv = "results/func_effect_shifts/by_comparison/LibA-2_shifts.csv"
# passed to `multidms.Data` as `nb_workers`
threads = 1
Read and clip functional scores:
In [4]:
# Load one functional-score table per condition, label each with its condition
# name, stack them, and classify every variant.
per_condition_scores = []
for condition_name, selection_name in params["conditions"].items():
    scores_csv = f"results/func_scores/{selection_name}_func_scores.csv"
    per_condition_scores.append(
        pd.read_csv(scores_csv, na_filter=None).assign(condition=condition_name)
    )
func_scores_df = dms_variants.codonvarianttable.CodonVariantTable.classifyVariants(
    pd.concat(per_condition_scores)
)

# Median score of the 'stop' class, taken before any clipping so it can serve
# as the value behind the "median_stop" sentinel.
is_stop = func_scores_df["variant_class"] == "stop"
median_stop = func_scores_df.loc[is_stop, "func_score"].median()

# Apply the optional upper and then lower clip to the functional scores.
for bound in ["upper", "lower"]:
    clip = params[f"clip_{bound}"]
    if clip is None:
        print(f"No clipping on {bound} bound of functional scores")
        continue
    if clip == "median_stop":
        # the sentinel is unusable if there were no stop variants to take a median of
        if pd.isnull(median_stop):
            raise ValueError(f"{median_stop=}")
        clip = median_stop
    assert isinstance(clip, (int, float)), clip
    print(f"Clipping {bound} bound of functional scores to {clip}")
    clip_kwargs = {bound: clip}
    func_scores_df["func_score"] = func_scores_df["func_score"].clip(**clip_kwargs)
No clipping on upper bound of functional scores Clipping lower bound of functional scores to -4.035
Initialize data for multidms:
In [5]:
# Gather the `multidms.Data` options in one place, then build the data object
# from the clipped functional scores.
data_kwargs = {
    "variants_df": func_scores_df,
    "reference": params["reference"],
    "alphabet": multidms.AAS_WITHSTOP_WITHGAP,
    "collapse_identical_variants": params["collapse_identical_variants"],
    "letter_suffixed_sites": True,
    "verbose": True,
    "nb_workers": threads,
    "assert_site_integrity": True,
}
data = multidms.Data(**data_kwargs)
inferring site map for 293T
inferring site map for human_aDG
inferring site map for mastomys_aDG
Asserting site integrity
INFO: Pandarallel will run on 1 workers. INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [], dropping: 0 variantswhich have mutations at those sites.
invalid non-identical-sites: [], dropping 0 variants Converting mutations for 293T is reference, skipping Converting mutations for human_aDG is reference, skipping Converting mutations for mastomys_aDG is reference, skipping
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Now initialize and fit the model for each lasso penalty:
In [6]:
# Fit one model per lasso penalty, collecting each model's mutation table and
# drawing its accuracy (left column) and epistasis (right column) plots.

# Convert the penalties to floats *before* checking uniqueness: the parameter
# list mixes numbers and strings, so e.g. "1e-06" and 1e-06 would otherwise be
# counted as distinct and the same penalty would be fit twice.
lasso_shifts = [float(penalty) for penalty in params["lasso_shifts"]]
assert lasso_shifts, "no lasso penalties specified in params['lasso_shifts']"
assert len(lasso_shifts) == len(set(lasso_shifts)), (
    f"duplicate lasso penalties: {lasso_shifts}"
)
n_lasso = len(lasso_shifts)

# squeeze=False keeps `ax` two-dimensional even when there is a single penalty,
# so the `ax[i, j]` indexing below always works.
fig, ax = plt.subplots(n_lasso, 2, figsize=[6, n_lasso * 3], squeeze=False)

mutations_df = []  # one mutation data frame per penalty, concatenated in a later cell
for i, lasso_shift in enumerate(lasso_shifts):
    print(f"Fitting model for {lasso_shift=}")
    model = multidms.Model(data)
    model.fit(lasso_shift=lasso_shift)
    mutations_df.append(
        model.get_mutations_df(phenotype_as_effect=True).assign(lasso_shift=lasso_shift)
    )
    # only the first row (i == 0) gets a legend on the epistasis panel
    model.plot_epistasis(ax=ax[i, 1], alpha=0.1, show=False, legend=not i)
    model.plot_pred_accuracy(ax=ax[i, 0], alpha=0.1, show=False, legend=False)
    ax[i, 1].set_title(f"Epistasis fit (lasso {lasso_shift})")
    ax[i, 0].set_title(f"Accuracy (lasso {lasso_shift})")

plt.show()
Fitting model for lasso_shift=0.0
Fitting model for lasso_shift=1e-06
Fitting model for lasso_shift=5e-06
Fitting model for lasso_shift=1e-05
Fitting model for lasso_shift=5e-05
Fitting model for lasso_shift=0.0001
Fitting model for lasso_shift=0.0005
Fitting model for lasso_shift=0.001
Get the data frame of mutation-effect shifts for each lasso penalty, adding in rows for the wildtype residues:
In [7]:
# Stack the per-penalty mutation tables and give the columns their output names.
column_renames = {
    "wts": "wildtype",
    "sites": "site",
    "muts": "mutant",
    "beta": "latent_phenotype_effect",
}
all_mutations = pd.concat(mutations_df).rename(columns=column_renames)

# Predicted functional effects are not kept in the output.
kept_columns = [c for c in all_mutations.columns if not c.startswith("predicted")]
mutant_rows = all_mutations[kept_columns]

# Wildtype entries: one row per site with the reference residue as both
# wildtype and mutant, and zero for the latent effect and every shift column.
shift_columns = [c for c in mutant_rows.columns if c.startswith("shift_")]
reference_residues = data.site_map[str(params["reference"])]
wildtype_template = pd.DataFrame(
    {
        "site": data.site_map.index,
        "wildtype": reference_residues,
        "mutant": reference_residues,
        "latent_phenotype_effect": 0,
        **dict.fromkeys(shift_columns, 0),
    }
)
# the same wildtype rows are repeated once for every lasso penalty
wildtype_rows = pd.concat(
    [wildtype_template.assign(lasso_shift=float(penalty)) for penalty in lasso_shifts]
)

mut_shifts = pd.concat([mutant_rows, wildtype_rows])
mut_shifts = mut_shifts.sort_values(["lasso_shift", "site", "mutant"])
mut_shifts = mut_shifts.reset_index(drop=True)

print(f"Saving shifts to {shifts_csv}")
mut_shifts.to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/by_comparison/LibA-2_shifts.csv
Plot distribution of shifts for all non-wildtype residues for each regularization:
In [8]:
# Every column holding a per-condition shift.
mut_shifts_cols = list(
    filter(lambda col: col.startswith("shift_"), mut_shifts.columns)
)

# Drop the wildtype rows (wildtype == mutant), then reshape to long form with
# one row per (site, mutant, lasso penalty, condition).
is_mutant = mut_shifts["wildtype"] != mut_shifts["mutant"]
mut_shifts_tidy = mut_shifts[is_mutant].melt(
    id_vars=["site", "mutant", "lasso_shift"],
    value_vars=mut_shifts_cols,
    var_name="condition",
    value_name="shift",
)

# Grid of shift distributions: one row per lasso penalty, one column per condition.
displot_kwargs = {
    "x": "shift",
    "col": "condition",
    "row": "lasso_shift",
    "hue": "condition",
    "height": 1.9,
    "aspect": 1.8,
    "facet_kws": {"margin_titles": True},
}
_ = seaborn.displot(mut_shifts_tidy, **displot_kwargs)