Average mutation functional effect shifts for a set of comparisons

Import Python modules. We use polyclonal for the plotting:

In [1]:
import pandas as pd

import polyclonal
import polyclonal.plot

import seaborn

This notebook is parameterized by `papermill`. The next cell is tagged `parameters` so that papermill can inject the values passed for this run.

In [2]:
# this cell is tagged parameters for `papermill` parameterization
# Every value defaults to None here; the actual values for a run are assigned in
# the parameters cell that papermill injects after this one.
# CSV read later; its `reference_site` column is renamed to `site`.
site_numbering_map_csv = None
# optional CSV with `site` and `mutant` columns merged into the plotted data (skipped when falsy)
mutation_annotations_csv = None
# output path for the averaged shifts CSV
shifts_csv = None
# output path for the interactive chart HTML
shifts_html = None
# dict read for keys "comparisons", "lasso_shift", "avg_method",
# "per_comparison_tooltips" and "plot_kwargs"
params = None
In [3]:
# Parameters
# Injected by papermill for this run; these assignments override the defaults
# set in the cell tagged `parameters`.
params = {
    # how shifts are averaged across comparisons; later asserted to be "mean" or "median"
    "avg_method": "median",
    # when True, each comparison's own shift column is added to the chart tooltips
    "per_comparison_tooltips": True,
    # keyword arguments passed through to `polyclonal.plot.lineplot_and_heatmap`
    "plot_kwargs": {
        "alphabet": [
            "R",
            "K",
            "H",
            "D",
            "E",
            "Q",
            "N",
            "S",
            "T",
            "Y",
            "W",
            "F",
            "A",
            "I",
            "L",
            "M",
            "V",
            "G",
            "P",
            "C",
            "*",
        ],
        # `times_seen` here is also read by the correlation-plot cell
        "addtl_slider_stats": {"times_seen": 2, "n_comparisons": 8},
        "heatmap_max_at_least": 0.5,
        "heatmap_min_at_least": -0.5,
        "init_floor_at_zero": False,
        "init_site_statistic": "mean",
        "site_zoom_bar_color_col": "region",
        "slider_binding_range_kwargs": {
            "times_seen": {"step": 1, "min": 1, "max": 25},
            "n_comparisons": {"step": 1},
        },
    },
    # each name is read from results/func_effect_shifts/by_comparison/{name}_shifts.csv
    "comparisons": [
        "LibA-1",
        "LibA-2",
        "LibA-3",
        "LibA-4",
        "LibB-1",
        "LibB-2",
        "LibB-3",
        "LibB-4",
    ],
    # lasso shift selected for the averaged shifts and the interactive chart
    "lasso_shift": 0.0001,
}
# no mutation annotations are merged for this run
mutation_annotations_csv = None
site_numbering_map_csv = "data/site_numbering_map.csv"
shifts_csv = "results/func_effect_shifts/averages/aDG_comparison_shifts.csv"
shifts_html = "results/func_effect_shifts/averages/aDG_comparison_shifts.html"

Read the input data:

In [4]:
comparisons = params["comparisons"]
if not comparisons:
    raise ValueError("`params['comparisons']` must list at least one comparison")

# Read each comparison's shifts, tagging rows with the comparison name and
# coercing `lasso_shift` to float so values compare consistently across files.
shifts = [
    pd.read_csv(f"results/func_effect_shifts/by_comparison/{c}_shifts.csv").assign(
        comparison=c,
        lasso_shift=lambda x: x["lasso_shift"].astype(float),
    )
    for c in comparisons
]

if mutation_annotations_csv:
    mutation_annotations = pd.read_csv(mutation_annotations_csv)

# check all shift comparisons are comparable. `Index.equals` is used rather than
# an element-wise `!=`, which raises an unhelpful "Lengths must match" error when
# the files have different numbers of columns.
for shift_df in shifts[1:]:
    if not shift_df.columns.equals(shifts[0].columns):
        raise ValueError("comparisons do not all have the same columns")
    if set(shift_df["lasso_shift"]) != set(shifts[0]["lasso_shift"]):
        raise ValueError("comparisons do not all have the same `lasso_shifts`")

shifts = pd.concat(shifts)

# add a times_seen column that is the average of all of the times_seen in all conditions
times_seen_cols = [c for c in shifts.columns if c.startswith("times_seen_")]
shifts["times_seen"] = shifts[times_seen_cols].mean(axis=1)

# get shifts in tidy format: one row per mutation x comparison x condition
shift_cols = [c for c in shifts.columns if c.startswith("shift_")]
shifts_tidy = shifts.melt(
    id_vars=[
        "comparison",
        "site",
        "wildtype",
        "mutant",
        "lasso_shift",
        "times_seen",
        "latent_phenotype_effect",
    ],
    value_vars=shift_cols,
    var_name="condition",
    value_name="shift",
)

# Average times_seen & latent_phenotype_effect across comparisons, then pivot on
# comparisons. The averaging is required because both columns are part of the
# pivot index: per-comparison values would otherwise split one mutation across
# several rows instead of giving one row with a column per comparison.
shifts_comparison_pivoted = (
    shifts_tidy.assign(
        times_seen=lambda x: x.groupby(["site", "mutant", "lasso_shift"])[
            "times_seen"
        ].transform("mean"),
        latent_phenotype_effect=lambda x: x.groupby(["site", "mutant", "lasso_shift"])[
            "latent_phenotype_effect"
        ].transform("mean"),
    )
    .pivot_table(
        index=[
            "site",
            "wildtype",
            "mutant",
            "latent_phenotype_effect",
            "times_seen",
            "lasso_shift",
            "condition",
        ],
        values="shift",
        columns="comparison",
    )
    .reset_index()
)

Plot the correlation of shifts among comparisons for each lasso shift, restricting to mutations with at least a minimum threshold `times_seen` and excluding entries where the mutant is the wildtype. In general, you might hope to find a lasso shift that yields relatively few non-zero shifts, with those non-zero shifts correlated among comparisons.

In [5]:
# Minimum `times_seen` for the correlation plots: reuse the slider default from
# `plot_kwargs` when one is given, otherwise fall back to 3.
try:
    times_seen = params["plot_kwargs"]["addtl_slider_stats"]["times_seen"]
except KeyError:
    times_seen = 3

print(f"Only plotting mutations with times_seen >= {times_seen}")

# Color points by condition only when there is more than one condition. This
# does not depend on the lasso shift, so compute it once outside the loop.
pairplot_hue = (
    None if shifts_comparison_pivoted["condition"].nunique() == 1 else "condition"
)

for lasso_shift, df in shifts_comparison_pivoted.groupby("lasso_shift"):
    grid = seaborn.pairplot(
        # filter on the same `times_seen` threshold that is printed above
        df.query("times_seen >= @times_seen").query("wildtype != mutant"),
        vars=comparisons,
        hue=pairplot_hue,
        plot_kws={"alpha": 0.3, "s": 25},
    )
    grid.fig.suptitle(f"lasso shift = {lasso_shift}")
    grid.fig.tight_layout()
Only plotting mutations with times_seen >= 2
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

Now make an interactive plot of the shifts. First, get the data to plot:

In [6]:
lasso_shift = float(params["lasso_shift"])
avg_method = params["avg_method"]

# The requested lasso shift must be one present in the data, and only
# mean/median averaging across comparisons is supported.
available_lasso_shifts = set(shifts_comparison_pivoted["lasso_shift"])
assert lasso_shift in available_lasso_shifts
assert avg_method in {"mean", "median"}, avg_method

# Site numbering: rename `reference_site` to `site` to match the shifts, and
# collect any other numbering columns whose names end in "site".
site_numbering_map = (
    pd.read_csv(site_numbering_map_csv)
    .rename(columns={"reference_site": "site"})
)
addtl_site_cols = [
    col
    for col in site_numbering_map.columns
    if col.endswith("site") and col != "site"
]

# Keep only the chosen lasso shift, attach site annotations, then average the
# per-comparison shifts and count how many comparisons have a value.
selected_shifts = shifts_comparison_pivoted[
    shifts_comparison_pivoted["lasso_shift"] == lasso_shift
].drop(columns="lasso_shift")
site_info = site_numbering_map[["site", *addtl_site_cols, "region"]]
df = selected_shifts.merge(site_info, on="site", validate="many_to_one")
df["shift"] = df[comparisons].apply(avg_method, axis=1)
df["n_comparisons"] = df[comparisons].notnull().sum(axis=1)

print(f"Saving shifts to {shifts_csv}")
site_cols_to_drop = [c for c in addtl_site_cols if c != "sequential_site"]
df.drop(columns=site_cols_to_drop).to_csv(
    shifts_csv, index=False, float_format="%.4g"
)
Saving shifts to results/func_effect_shifts/averages/aDG_comparison_shifts.csv

Set up keyword arguments for `polyclonal.plot.lineplot_and_heatmap` (https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap), filling in defaults for any that are not already specified:

In [7]:
plot_kwargs = params["plot_kwargs"]

# Optionally merge per-mutation annotations into the data to plot.
if mutation_annotations_csv:
    key_cols = {"site", "mutant"}
    if not key_cols.issubset(mutation_annotations.columns):
        raise ValueError(f"{mutation_annotations.columns=} lacks 'site', 'mutant'")
    if key_cols != set(mutation_annotations.columns) & set(df.columns):
        raise ValueError(
            f"{mutation_annotations.columns=} shares columns with {df.columns=}"
        )
    df = df.merge(
        mutation_annotations,
        on=["site", "mutant"],
        how="left",
        validate="many_to_one",
    )
    # annotations only describe real mutations, so blank them on wildtype rows
    is_mutation = df["wildtype"] != df["mutant"]
    for annotation_col in [
        c for c in mutation_annotations.columns if c not in key_cols
    ]:
        df[annotation_col] = df[annotation_col].where(is_mutation, pd.NA)

# Slider statistics: default to the pairplot threshold for `times_seen` and to
# a simple majority of comparisons for `n_comparisons`.
slider_stats = plot_kwargs.setdefault("addtl_slider_stats", {})
slider_stats.setdefault("times_seen", times_seen)
slider_stats.setdefault("n_comparisons", len(comparisons) // 2 + 1)

# Tooltips: every extra site-numbering column, without duplicates.
tooltip_stats = plot_kwargs.setdefault("addtl_tooltip_stats", [])
for site_col in addtl_site_cols:
    if site_col not in tooltip_stats:
        tooltip_stats.append(site_col)

site_differs = (df["site"] != df["sequential_site"]).any()
if site_differs and "sequential_site" not in tooltip_stats:
    tooltip_stats.append("sequential_site")

if params["per_comparison_tooltips"]:
    assert set(comparisons).issubset(df.columns)
    tooltip_stats += [c for c in comparisons if c not in tooltip_stats]

# Default alphabet: biochemically ordered characters that occur as mutants.
if "alphabet" not in plot_kwargs:
    observed_mutants = set(df["mutant"])
    plot_kwargs["alphabet"] = [
        aa
        for aa in polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP)
        if aa in observed_mutants
    ]

# Default site order follows sequential numbering.
if "sites" not in plot_kwargs:
    sites_in_order = df.sort_values("sequential_site")["site"].unique()
    plot_kwargs["sites"] = sites_in_order.tolist()
In [8]:
# Interactive line plot + heatmap of the averaged shifts, with one category per
# condition; the chart is written to HTML and then displayed as the cell output.
chart = polyclonal.plot.lineplot_and_heatmap(
    data_df=df,
    category_col="condition",
    stat_col="shift",
    **plot_kwargs,
)
print(f"Saving chart to {shifts_html}")
chart.save(shifts_html)

chart
Saving chart to results/func_effect_shifts/averages/aDG_comparison_shifts.html
Out[8]:
In [ ]: