Fit shifts in functional effects

Import Python modules. We use multidms for the fitting:

In [1]:
import dms_variants.codonvarianttable

import matplotlib.pyplot as plt

import multidms

import pandas as pd

import seaborn

This notebook is parameterized by `papermill`. The next cell is tagged `parameters` so that `papermill` can replace the placeholder values defined there with the values passed to the notebook.

In [2]:
# this cell is tagged parameters for `papermill` parameterization
# The `None` values below are placeholders; the real values are assigned in the
# following "# Parameters" cell.
# `params`: dict of fitting options (clipping bounds, lasso penalties, reference
# condition, and the condition -> selection mapping) read throughout the notebook
params = None
# `shifts_csv`: path of the CSV file to which the fitted shifts are written
shifts_csv = None
# `threads`: passed as `nb_workers` when building the `multidms.Data` object
threads = None
In [3]:
# Parameters
# NOTE(review): this cell presumably holds the values injected by `papermill`
# for this particular run -- confirm before editing by hand.
params = {
    # lower / upper clipping bounds for functional scores: `None` means no
    # clipping, "median_stop" means clip at the median functional score of
    # stop-codon variants, otherwise a number is expected
    "clip_lower": "median_stop",
    "clip_upper": None,
    # forwarded to `multidms.Data(collapse_identical_variants=...)`
    "collapse_identical_variants": False,
    "latent_offset": True,
    # lasso penalties to fit; values may be numbers or numeric strings and are
    # converted with `float()` before fitting
    "lasso_shifts": [0, "1e-06", "5e-06", "1e-05", "5e-05", 0.0001, 0.0005, 0.001],
    # condition used as the reference in `multidms.Data`
    "reference": "293T",
    # maps condition name -> selection name; the selection name locates the
    # input file `results/func_scores/{selection}_func_scores.csv`
    "conditions": {
        "293T": "LibA-220823-293T-1",
        "human_aDG": "LibA-230722-human-1",
        "mastomys_aDG": "LibA-230722-mastomys-1",
    },
}
# output CSV for the fitted shifts
shifts_csv = "results/func_effect_shifts/by_comparison/LibA-1_shifts.csv"
# number of workers handed to `multidms.Data(nb_workers=...)`
threads = 1

Read and clip functional scores:

In [4]:
# Read one functional-score table per condition and label each row with the
# condition it came from.
per_condition_scores = []
for condition, selection in params["conditions"].items():
    selection_scores = pd.read_csv(
        f"results/func_scores/{selection}_func_scores.csv", na_filter=None
    )
    per_condition_scores.append(selection_scores.assign(condition=condition))

# Stack the conditions and classify each variant; classification supplies the
# `variant_class` column used just below.
func_scores_df = dms_variants.codonvarianttable.CodonVariantTable.classifyVariants(
    pd.concat(per_condition_scores)
)

# Median functional score of stop-codon variants, pooled over all conditions.
is_stop = func_scores_df["variant_class"] == "stop"
median_stop = func_scores_df.loc[is_stop, "func_score"].median()

# Apply the optional upper and lower clipping bounds in turn.
for bound in ["upper", "lower"]:
    clip = params[f"clip_{bound}"]
    if clip is None:
        print(f"No clipping on {bound} bound of functional scores")
        continue
    if clip == "median_stop":
        # the symbolic bound needs a defined median, i.e. at least one stop variant
        if pd.isnull(median_stop):
            raise ValueError(f"{median_stop=}")
        clip = median_stop
    assert isinstance(clip, (int, float)), clip
    print(f"Clipping {bound} bound of functional scores to {clip}")
    clipped_scores = func_scores_df["func_score"].clip(**{bound: clip})
    func_scores_df["func_score"] = clipped_scores
No clipping on upper bound of functional scores
Clipping lower bound of functional scores to -4.047499999999999

Initialize data for multidms:

In [5]:
# Options for building the multidms dataset: fixed settings plus the values
# taken from the notebook parameters (`params`, `threads`).
data_options = {
    "alphabet": multidms.AAS_WITHSTOP_WITHGAP,
    "collapse_identical_variants": params["collapse_identical_variants"],
    "letter_suffixed_sites": True,
    "verbose": True,
    "nb_workers": threads,
    "assert_site_integrity": True,
}

# Build the dataset from the clipped functional scores, relative to the
# configured reference condition.
data = multidms.Data(
    variants_df=func_scores_df,
    reference=params["reference"],
    **data_options,
)
inferring site map for 293T
inferring site map for human_aDG
inferring site map for mastomys_aDG
Asserting site integrity
INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
unknown cond wildtype at sites: [],
dropping: 0 variantswhich have mutations at those sites.
invalid non-identical-sites: [], dropping 0 variants
Converting mutations for 293T
is reference, skipping
Converting mutations for human_aDG
is reference, skipping
Converting mutations for mastomys_aDG
is reference, skipping
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Now initialize and fit the model for each lasso penalty:

In [6]:
# Penalties can arrive as a mix of numbers and numeric strings (e.g. "1e-06"),
# so convert to float *before* checking for duplicates: otherwise values such
# as 1e-05 and "1e-05" would pass the uniqueness check yet be fit twice.
lasso_shifts = [float(lasso_shift) for lasso_shift in params["lasso_shifts"]]
assert lasso_shifts, "no lasso_shifts specified"
assert len(lasso_shifts) == len(set(lasso_shifts)), f"duplicate {lasso_shifts=}"
n_lasso = len(lasso_shifts)

# One row of plots per penalty: accuracy on the left, epistasis fit on the right.
# `squeeze=False` keeps `ax` two-dimensional so that `ax[i, j]` indexing also
# works when only a single lasso penalty is specified.
fig, ax = plt.subplots(n_lasso, 2, figsize=[6, n_lasso * 3], squeeze=False)

# Collect one mutations data frame per penalty, each tagged with its penalty.
mutations_df = []
for i, lasso_shift in enumerate(lasso_shifts):
    print(f"Fitting model for {lasso_shift=}")
    # fit a fresh model for every penalty
    model = multidms.Model(data)
    model.fit(lasso_shift=lasso_shift)
    mutations_df.append(
        model.get_mutations_df(phenotype_as_effect=True).assign(lasso_shift=lasso_shift)
    )
    # only the first row (i == 0) draws a legend on the epistasis plot
    model.plot_epistasis(ax=ax[i, 1], alpha=0.1, show=False, legend=not i)
    model.plot_pred_accuracy(ax=ax[i, 0], alpha=0.1, show=False, legend=False)
    ax[i, 1].set_title(f"Epistasis fit (lasso {lasso_shift})")
    ax[i, 0].set_title(f"Accuracy (lasso {lasso_shift})")

plt.show()
Fitting model for lasso_shift=0.0
Fitting model for lasso_shift=1e-06
Fitting model for lasso_shift=5e-06
Fitting model for lasso_shift=1e-05
Fitting model for lasso_shift=5e-05
Fitting model for lasso_shift=0.0001
Fitting model for lasso_shift=0.0005
Fitting model for lasso_shift=0.001
No description has been provided for this image

Get the data frame of mutation-effect shifts for each lasso penalty, adding rows for the wildtype residue at each site (with latent effect and shifts of zero):

In [7]:
# Stack the per-penalty mutation frames and switch to the output column names.
column_renames = {
    "wts": "wildtype",
    "sites": "site",
    "muts": "mutant",
    "beta": "latent_phenotype_effect",
}
mut_shifts = pd.concat(mutations_df).rename(columns=column_renames)

# Predicted functional effects are not part of the output.
kept_columns = [c for c in mut_shifts.columns if not c.startswith("predicted")]
mut_shifts = mut_shifts[kept_columns]

# Wildtype rows: at every site the reference residue is both wildtype and
# mutant, with zero latent effect and zero shift in every `shift_*` column.
reference_residues = data.site_map[str(params["reference"])]
shift_columns = [col for col in mut_shifts.columns if col.startswith("shift_")]
wildtype_template = pd.DataFrame(
    {
        "site": data.site_map.index,
        "wildtype": reference_residues,
        "mutant": reference_residues,
        "latent_phenotype_effect": 0,
        **dict.fromkeys(shift_columns, 0),
    }
)
# one copy of the wildtype rows per lasso penalty (`assign` returns a new frame)
wildtype_rows = pd.concat(
    [
        wildtype_template.assign(lasso_shift=float(lasso_shift))
        for lasso_shift in lasso_shifts
    ]
)

# Combine mutation and wildtype rows in a deterministic order.
mut_shifts = (
    pd.concat([mut_shifts, wildtype_rows])
    .sort_values(["lasso_shift", "site", "mutant"])
    .reset_index(drop=True)
)

print(f"Saving shifts to {shifts_csv}")

mut_shifts.to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/by_comparison/LibA-1_shifts.csv

Plot distribution of shifts for all non-wildtype residues for each regularization:

In [8]:
# Columns holding the per-condition shifts.
mut_shifts_cols = [c for c in mut_shifts.columns if c.startswith("shift_")]

# Drop the wildtype rows, then go from wide (one column per condition) to tidy
# (one row per site / mutant / penalty / condition).
non_wildtype_shifts = mut_shifts.query("wildtype != mutant")
mut_shifts_tidy = pd.melt(
    non_wildtype_shifts,
    id_vars=["site", "mutant", "lasso_shift"],
    value_vars=mut_shifts_cols,
    var_name="condition",
    value_name="shift",
)

# Grid of shift distributions: a column per condition, a row per lasso penalty.
facet_options = {"margin_titles": True}
_ = seaborn.displot(
    data=mut_shifts_tidy,
    x="shift",
    row="lasso_shift",
    col="condition",
    hue="condition",
    height=1.9,
    aspect=1.8,
    facet_kws=facet_options,
)
No description has been provided for this image