Fit shifts in functional effects

Import Python modules. We use multidms for the fitting:

In [1]:
import alignparse.utils

import dms_variants.codonvarianttable

import matplotlib.pyplot as plt

import multidms

import pandas as pd

import seaborn
/fh/fast/bloom_j/software/miniforge3/envs/dms-vep-pipeline-3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

This notebook is parameterized by papermill. The next cell is tagged as parameters to get the passed parameters.

In [2]:
# This cell is tagged `parameters` for `papermill` parameterization. The `None`
# values are placeholders; the cell labelled `# Parameters` that follows
# reassigns every one of these names with the values for the actual run.
params = None  # dict of options: clipping bounds, lasso penalties, reference, conditions
shifts_csv = None  # path of the CSV file the fitted shifts are written to
site_numbering_map = None  # path of the CSV with `sequential_site` / `reference_site` columns
threads = None  # handed to `multidms.Data` as `nb_workers`
In [3]:
# Parameters
# Concrete values for this run; they overwrite the `None` placeholders from the
# cell tagged `parameters`.
params = {
    # lower clipping bound: the string "median_stop" is replaced later by the
    # median functional score of variants classified as "stop"
    "clip_lower": "median_stop",
    # `None` means no clipping on the upper bound
    "clip_upper": None,
    # forwarded to `multidms.Data(collapse_identical_variants=...)`
    "collapse_identical_variants": False,
    # NOTE(review): `latent_offset` is not read by any of the visible cells --
    # confirm whether it is meant to be passed to the model.
    "latent_offset": True,
    # Mixed strings and floats; every entry is passed through `float()` before
    # it is used as a lasso penalty.
    "lasso_shifts": ["1e-05", "5e-05", 0.0001, 0.0002, 0.001],
    # Integer here, while the condition names below are strings; later code
    # uses `str(params["reference"])` when indexing the site map.
    "reference": 220210,
    # condition name -> selection name; scores for each selection are read from
    # `results/func_scores/{selection}_func_scores.csv`
    "conditions": {
        "220210": "LibA-220210-293T_ACE2-2",
        "220302": "LibA-220302-293T_ACE2-2",
    },
}
# CSV with the `sequential_site` <-> `reference_site` mapping
site_numbering_map = "data/site_numbering_map.csv"
# output CSV for the fitted shifts
shifts_csv = (
    "results/func_effect_shifts/by_comparison/LibA-date_comparison-2_shifts.csv"
)
# passed to `multidms.Data` as `nb_workers`
threads = 1

Read and clip functional scores:

In [4]:
# Read the functional scores of every selection and label each row with the
# condition that selection belongs to.
per_condition_scores = [
    pd.read_csv(
        f"results/func_scores/{selection}_func_scores.csv", na_filter=None
    ).assign(condition=condition)
    for condition, selection in params["conditions"].items()
]

# Classify the variants; the `variant_class` column is needed just below to
# pick out the stop-codon variants.
func_scores_df = dms_variants.codonvarianttable.CodonVariantTable.classifyVariants(
    pd.concat(per_condition_scores)
)

# Median score of stop-codon variants, pooled over all conditions (NaN if there
# are no such variants).
is_stop = func_scores_df["variant_class"] == "stop"
median_stop = func_scores_df.loc[is_stop, "func_score"].median()

# Apply the optional upper / lower clipping requested in `params`.
for bound in ("upper", "lower"):
    clip = params[f"clip_{bound}"]

    if clip is None:
        print(f"No clipping on {bound} bound of functional scores")
        continue

    # the special value "median_stop" stands for the median computed above
    if clip == "median_stop":
        if pd.isnull(median_stop):
            raise ValueError(f"{median_stop=}")
        clip = median_stop

    assert isinstance(clip, (int, float)), clip
    print(f"Clipping {bound} bound of functional scores to {clip}")

    # `bound` is "upper" or "lower", which are exactly the keyword names that
    # `Series.clip` accepts
    clip_kwargs = {bound: clip}
    func_scores_df["func_score"] = func_scores_df["func_score"].clip(**clip_kwargs)
No clipping on upper bound of functional scores
Clipping lower bound of functional scores to -5.025

Renumber to sequential sites to allow arbitrary strings as sites:

In [5]:
site_numbering = pd.read_csv(site_numbering_map)

# Every row must carry a distinct sequential site and a distinct reference
# site, otherwise the mapping could not be inverted.
n_sites = len(site_numbering)
assert n_sites == site_numbering["sequential_site"].nunique()
assert n_sites == site_numbering["reference_site"].nunique()

# Both renumbering directions share the mapping table and the same settings;
# only the roles of the two site columns are swapped.
shared_renumber_kwargs = {
    "number_mapping": site_numbering,
    "wt_nt_col": None,
    "allow_arbitrary_numbers": True,
}

renumber_to_sequential = alignparse.utils.MutationRenumber(
    old_num_col="reference_site",
    new_num_col="sequential_site",
    **shared_renumber_kwargs,
)

renumber_to_reference = alignparse.utils.MutationRenumber(
    old_num_col="sequential_site",
    new_num_col="reference_site",
    **shared_renumber_kwargs,
)

# Rewrite the substitutions of every variant in sequential numbering, allowing
# gap and stop characters in the mutation strings.
func_scores_df_sequential = func_scores_df.assign(
    aa_substitutions=func_scores_df["aa_substitutions"].apply(
        lambda muts: renumber_to_sequential.renumber_muts(
            muts, allow_gaps=True, allow_stop=True
        )
    )
)

Initialize data for multidms:

In [6]:
# Collect the settings for the `multidms.Data` object in one place: the
# sequentially numbered variants, the reference condition, an amino-acid
# alphabet that includes stop and gap characters, and the run options taken
# from the notebook parameters.
multidms_data_kwargs = {
    "variants_df": func_scores_df_sequential,
    "reference": params["reference"],
    "alphabet": multidms.AAS_WITHSTOP_WITHGAP,
    "collapse_identical_variants": params["collapse_identical_variants"],
    "verbose": False,
    "nb_workers": threads,
    "assert_site_integrity": True,
}
data = multidms.Data(**multidms_data_kwargs)

Now initialize and fit the model for each lasso penalty:

In [7]:
# Convert the penalties to floats *before* checking for duplicates. The raw
# parameter list can mix strings and numbers (e.g. "1e-05" next to 0.0001), so
# a check on the raw values would treat "0.0001" and 0.0001 as distinct even
# though they describe the same fit.
lasso_shifts = [float(lasso_shift) for lasso_shift in params["lasso_shifts"]]
assert lasso_shifts, "params['lasso_shifts'] must list at least one penalty"
assert len(lasso_shifts) == len(set(lasso_shifts)), lasso_shifts
n_lasso = len(lasso_shifts)

# One row of panels per penalty: accuracy on the left, epistasis on the right.
# `squeeze=False` keeps `ax` two-dimensional even when there is only a single
# penalty, so the `ax[i, j]` indexing below is always valid.
fig, ax = plt.subplots(n_lasso, 2, figsize=[6, n_lasso * 3], squeeze=False)

# per-penalty mutation tables, each tagged with the penalty that produced it
mutations_df = []
for i, lasso_shift in enumerate(lasso_shifts):
    print(f"Fitting model for {lasso_shift=}")
    # a fresh model is fit for every penalty
    model = multidms.Model(data)
    model.fit(lasso_shift=lasso_shift)
    mutations_df.append(
        model.get_mutations_df(phenotype_as_effect=True).assign(lasso_shift=lasso_shift)
    )
    # only the first row (i == 0) gets a legend
    model.plot_epistasis(ax=ax[i, 1], alpha=0.1, show=False, legend=not i)
    model.plot_pred_accuracy(ax=ax[i, 0], alpha=0.1, show=False, legend=False)
    ax[i, 1].set_title(f"Epistasis fit (lasso {lasso_shift})")
    ax[i, 0].set_title(f"Accuracy (lasso {lasso_shift})")

plt.show()
Fitting model for lasso_shift=1e-05
Fitting model for lasso_shift=5e-05
Fitting model for lasso_shift=0.0001
Fitting model for lasso_shift=0.0002
Fitting model for lasso_shift=0.001
No description has been provided for this image

Get the data frame of mutation-effect shifts for each lasso penalty, adding in wildtype entries (which are assigned zero effect and zero shift):

In [8]:
# Stack the per-penalty mutation tables and switch to the column names used in
# the output file.
output_column_names = {
    "wts": "wildtype",
    "sites": "site",
    "muts": "mutant",
    "beta": "latent_phenotype_effect",
}
mut_shifts = pd.concat(mutations_df).rename(columns=output_column_names)

# predicted functional effects are not part of the output
mut_shifts = mut_shifts.drop(
    columns=[c for c in mut_shifts.columns if c.startswith("predicted")]
)

# Build the wild-type rows: at every site the reference residue appears as
# both `wildtype` and `mutant`, with zero latent effect and zero shift in
# every `shift_*` column.
reference_residues = data.site_map[str(params["reference"])]
shift_cols = [col for col in mut_shifts.columns if col.startswith("shift_")]
wildtype_template = pd.DataFrame(
    {
        "site": data.site_map.index,
        "wildtype": reference_residues,
        "mutant": reference_residues,
        "latent_phenotype_effect": 0,
        **dict.fromkeys(shift_cols, 0),
    }
)
# one copy of the wild-type rows for each lasso penalty
wildtype_shifts = pd.concat(
    [
        wildtype_template.assign(lasso_shift=float(lasso_shift))
        for lasso_shift in lasso_shifts
    ]
)

# lookup used to convert sites back to reference numbering
sequential_to_reference = site_numbering.set_index("sequential_site")[
    "reference_site"
].to_dict()

# Combine mutant and wild-type rows, sort while the sites are still in
# sequential numbering, then translate the sites to reference numbering.
mut_shifts = (
    pd.concat([mut_shifts, wildtype_shifts])
    .sort_values(["lasso_shift", "site", "mutant"])
    .assign(site=lambda x: x["site"].map(sequential_to_reference))
    .reset_index(drop=True)
)

print(f"Saving shifts to {shifts_csv}")

mut_shifts.to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/by_comparison/LibA-date_comparison-2_shifts.csv

Plot distribution of shifts for all non-wildtype residues for each regularization:

In [9]:
# names of the per-condition shift columns
mut_shifts_cols = list(
    filter(lambda c: c.startswith("shift_"), mut_shifts.columns)
)

# Keep only rows where the mutant differs from the wild type, then reshape to
# one row per (site, mutant, penalty, condition).
is_mutated = mut_shifts["wildtype"] != mut_shifts["mutant"]
melt_kwargs = {
    "id_vars": ["site", "mutant", "lasso_shift"],
    "value_vars": mut_shifts_cols,
    "var_name": "condition",
    "value_name": "shift",
}
mut_shifts_tidy = mut_shifts[is_mutated].melt(**melt_kwargs)

# Grid of distributions: one column per condition, one row per lasso penalty.
displot_kwargs = {
    "x": "shift",
    "col": "condition",
    "row": "lasso_shift",
    "hue": "condition",
    "height": 1.9,
    "aspect": 1.8,
    "facet_kws": {"margin_titles": True},
}
_ = seaborn.displot(mut_shifts_tidy, **displot_kwargs)
No description has been provided for this image