Average mutation functional effect shifts for a set of comparisons
Import Python modules.
We use polyclonal for the plotting:
In [1]:
import pandas as pd
import polyclonal
import polyclonal.plot
import seaborn
This notebook is parameterized by papermill.
The next cell is tagged as parameters to get the passed parameters.
In [2]:
# this cell is tagged parameters for `papermill` parameterization
# Each value defaults to None and is overridden by the injected "# Parameters" cell.
site_numbering_map_csv = None  # CSV with `reference_site` and `sequential_site` columns
mutation_annotations_csv = None  # optional CSV of per-mutation annotations (`site`, `mutant`)
shifts_csv = None  # output CSV of the averaged shifts
shifts_html = None  # output HTML of the interactive shifts chart
params = None  # dict of averaging and plotting parameters
In [3]:
# Parameters
# (injected by papermill; overrides the defaults in the cell tagged `parameters`)
params = {
    # how shifts are averaged across comparisons: must be "mean" or "median"
    "avg_method": "median",
    # if True, each comparison's shift is added as a tooltip in the interactive chart
    "per_comparison_tooltips": True,
    # keyword arguments passed to `polyclonal.plot.lineplot_and_heatmap`
    "plot_kwargs": {
        "addtl_slider_stats": {"times_seen": 3, "nt changes to codon": 3},
        "addtl_slider_stats_as_max": ["nt changes to codon"],
        "addtl_slider_stats_hide_not_filter": ["nt changes to codon"],
        "heatmap_max_at_least": 0.5,
        "heatmap_min_at_least": -0.5,
        "init_floor_at_zero": False,
        "init_site_statistic": "mean",
        "site_zoom_bar_color_col": "region",
        "slider_binding_range_kwargs": {
            "times_seen": {"step": 1, "min": 1, "max": 25},
            "n_comparisons": {"step": 1},
            "nt changes to codon": {"step": 1, "min": 1, "max": 3},
        },
    },
    # comparisons to average: a list keeps all sites; a dict instead maps each
    # comparison to the sequential sites / [start, end] ranges to keep
    "comparisons": ["LibA-date_comparison-1", "LibA-date_comparison-2"],
    # the lasso shift whose results are averaged, saved, and plotted interactively
    "lasso_shift": 0.0001,
}
mutation_annotations_csv = "data/mutation_annotations.csv"
site_numbering_map_csv = "data/site_numbering_map.csv"
shifts_csv = "results/func_effect_shifts/averages/date_comparison_shifts.csv"
shifts_html = "results/func_effect_shifts/averages/date_comparison_shifts.html"
Read the input data:
In [4]:
# Read the site numbering map, renaming `reference_site` to `site` so it matches the
# `site` column of the per-comparison shift CSVs.
site_numbering_map = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)
# map sequential site numbers to reference (`site`) numbers; must be one-to-one
sequential_to_site = site_numbering_map.set_index("sequential_site")["site"].to_dict()
assert len(sequential_to_site) == len(site_numbering_map)

comparisons = list(params["comparisons"])
if not comparisons:
    raise ValueError("no comparisons specified in `params['comparisons']`")


def _region_to_sequential_sites(region):
    """Expand a region specification into a list of sequential site numbers.

    Each entry of `region` is either a single int site, or a two-element
    ``[start, end]`` list of ints giving an inclusive range with ``start <= end``.
    """
    region_sequential = []
    for r in region:
        if isinstance(r, int):
            region_sequential.append(r)
        else:
            # a range must be exactly [start, end]; extra or missing bounds are errors
            assert (
                isinstance(r, list)
                and len(r) == 2
                and all(isinstance(ri, int) for ri in r)
            ), r
            assert r[0] <= r[1], r
            region_sequential.extend(range(r[0], r[1] + 1))
    return region_sequential


# get sites to keep for each comparison (relevant if keeping only regions): a list of
# comparisons keeps every site, a dict maps each comparison to the region(s) to keep
if isinstance(params["comparisons"], list):
    all_sites = set(sequential_to_site.values())
    comparison_sites = {comparison: all_sites for comparison in comparisons}
else:
    assert isinstance(params["comparisons"], dict), params["comparisons"]
    comparison_sites = {
        comparison: {sequential_to_site[r] for r in _region_to_sequential_sites(region)}
        for comparison, region in params["comparisons"].items()
    }

# get number of comparisons each site is kept in
n_comparisons_per_site = {
    site: sum(site in comparison_sites[comparison] for comparison in comparisons)
    for site in site_numbering_map["site"]
}

# read each comparison's shifts, keeping only the sites retained for that comparison
shifts_by_comparison = [
    pd.read_csv(f"results/func_effect_shifts/by_comparison/{c}_shifts.csv")
    .assign(
        comparison=c,
        lasso_shift=lambda x: x["lasso_shift"].astype(float),
    )
    .query("site in @sites_to_keep")
    for (c, sites_to_keep) in comparison_sites.items()
]

if mutation_annotations_csv:
    mutation_annotations = pd.read_csv(mutation_annotations_csv)

# check all shift comparisons are comparable: identical columns (same names, same
# order) and the same set of lasso shifts. `Index.equals` is used rather than an
# elementwise `!=`, which raises an unrelated error when the column counts differ.
reference_shifts = shifts_by_comparison[0]
for shift_df in shifts_by_comparison[1:]:
    if not shift_df.columns.equals(reference_shifts.columns):
        raise ValueError("comparisons do not all have the same columns")
    if set(shift_df["lasso_shift"]) != set(reference_shifts["lasso_shift"]):
        raise ValueError("comparisons do not all have the same `lasso_shifts`")

# fresh index so the index-aligned arithmetic below never sees duplicate labels
shifts = pd.concat(shifts_by_comparison, ignore_index=True)

# add a `times_seen` column: the per-row sum of the `times_seen_*` columns divided by
# the number of comparisons in which that site is kept
times_seen_cols = [c for c in shifts.columns if c.startswith("times_seen_")]
shifts["times_seen"] = shifts[times_seen_cols].sum(axis=1) / shifts["site"].map(
    n_comparisons_per_site
)

# get shifts in tidy format: one row per (comparison, mutation, lasso shift, condition)
shift_cols = [c for c in shifts.columns if c.startswith("shift_")]
shifts_tidy = shifts.melt(
    id_vars=[
        "comparison",
        "site",
        "wildtype",
        "mutant",
        "lasso_shift",
        "times_seen",
        "latent_phenotype_effect",
    ],
    value_vars=shift_cols,
    var_name="condition",
    value_name="shift",
)

# average times_seen & latent_phenotype_effect across comparisons, pivot on comparisons
# so each comparison's shift becomes its own column
shifts_comparison_pivoted = (
    shifts_tidy.assign(
        times_seen=lambda x: x.groupby(["site", "mutant", "lasso_shift"])[
            "times_seen"
        ].transform("mean"),
        latent_phenotype_effect=lambda x: x.groupby(["site", "mutant", "lasso_shift"])[
            "latent_phenotype_effect"
        ].transform("mean"),
    )
    .pivot_table(
        index=[
            "site",
            "wildtype",
            "mutant",
            "latent_phenotype_effect",
            "times_seen",
            "lasso_shift",
            "condition",
        ],
        values="shift",
        columns="comparison",
    )
    .reset_index()
)
Plot the correlation of shifts among comparisons for each lasso shift, restricting to mutations that meet a minimum `times_seen` threshold and excluding wildtype residues.
In general, you might hope to find a lasso shift that has relatively few non-zero shifts, and those are correlated among comparisons.
In [5]:
# Use the interactive plot's default times_seen slider value as the threshold here,
# falling back to 3 when it is not given in the plot parameters.
try:
    times_seen = params["plot_kwargs"]["addtl_slider_stats"]["times_seen"]
except KeyError:
    times_seen = 3
print(f"Only plotting mutations with times_seen >= {times_seen}")

# color points by condition only when there is more than one condition
pairplot_hue = (
    None if shifts_comparison_pivoted["condition"].nunique() == 1 else "condition"
)

for lasso_shift, lasso_shift_df in shifts_comparison_pivoted.groupby("lasso_shift"):
    grid = seaborn.pairplot(
        # filter on the same threshold printed above, and skip wildtype residues
        lasso_shift_df.query("times_seen >= @times_seen and wildtype != mutant"),
        vars=comparisons,
        hue=pairplot_hue,
        plot_kws={"alpha": 0.3, "s": 25},
    )
    grid.fig.suptitle(f"lasso shift = {lasso_shift}")
    grid.fig.tight_layout()
Only plotting mutations with times_seen >= 3
Now make an interactive plot of the shifts. First, get the data to plot:
In [6]:
# Choose the lasso shift to plot and how to average shifts across comparisons.
lasso_shift = float(params["lasso_shift"])
avg_method = params["avg_method"]
assert lasso_shift in set(shifts_comparison_pivoted["lasso_shift"])
assert avg_method in {"mean", "median"}, avg_method

# extra site-numbering columns (any column other than `site` ending in "site")
addtl_site_cols = [
    col
    for col in site_numbering_map.columns
    if col.endswith("site") and col != "site"
]

# shifts at the chosen lasso shift, annotated with site-numbering columns and region
shifts_at_lasso = shifts_comparison_pivoted.query("lasso_shift == @lasso_shift").drop(
    columns="lasso_shift"
)
site_info = site_numbering_map[["site", *addtl_site_cols, "region"]]
df = shifts_at_lasso.merge(site_info, on="site", validate="many_to_one")

# average shift over comparisons, and how many comparisons have a value
per_comparison = df[comparisons]
df = df.assign(
    shift=per_comparison.apply(avg_method, axis=1),
    n_comparisons=per_comparison.notnull().sum(axis=1),
)

# the saved CSV keeps `sequential_site` but drops the other extra site columns
csv_drop_cols = [col for col in addtl_site_cols if col != "sequential_site"]
print(f"Saving shifts to {shifts_csv}")
df.drop(columns=csv_drop_cols).to_csv(shifts_csv, index=False, float_format="%.4g")
Saving shifts to results/func_effect_shifts/averages/date_comparison_shifts.csv
Set up any keyword arguments for [`polyclonal.plot.lineplot_and_heatmap`](https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap) that are not already specified:
In [7]:
# Start from the user-specified keyword arguments for
# `polyclonal.plot.lineplot_and_heatmap`; defaults below are filled in only for
# entries not already set. This dict is `params["plot_kwargs"]` itself, so the
# defaults are also written into `params`.
plot_kwargs = params["plot_kwargs"]

# merge the optional per-mutation annotations onto the data to plot
if mutation_annotations_csv:
    if not {"site", "mutant"}.issubset(mutation_annotations.columns):
        raise ValueError(f"{mutation_annotations.columns=} lacks 'site', 'mutant'")
    if set(mutation_annotations.columns).intersection(df.columns) != {"site", "mutant"}:
        raise ValueError(
            f"{mutation_annotations.columns=} shares columns with {df.columns=}"
        )
    df = df.merge(
        mutation_annotations,
        on=["site", "mutant"],
        how="left",
        validate="many_to_one",
    )
    # clear annotation values on wildtype rows (wildtype == mutant)
    for col in mutation_annotations.columns:
        if col not in {"site", "mutant"}:
            df[col] = df[col].where(df["wildtype"] != df["mutant"], pd.NA)

# slider defaults: the times_seen threshold from the correlation plots, and for
# n_comparisons more than half the max comparisons any site is kept in (capped at the
# largest n_comparisons present in the data)
addtl_slider_stats = plot_kwargs.setdefault("addtl_slider_stats", {})
addtl_slider_stats.setdefault("times_seen", times_seen)
if "n_comparisons" not in addtl_slider_stats:
    addtl_slider_stats["n_comparisons"] = min(
        max(n_comparisons_per_site.values()) // 2 + 1,
        df["n_comparisons"].max(),
    )

tooltip_stats = plot_kwargs.setdefault("addtl_tooltip_stats", [])
# `sequential_site` also ends in "site", so without this exclusion it would always be
# added here and the conditional check below would be dead code; it is only shown
# when it differs from `site` for some row
for c in addtl_site_cols:
    if c != "sequential_site" and c not in tooltip_stats:
        tooltip_stats.append(c)
if (df["site"] != df["sequential_site"]).any():
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")
if params["per_comparison_tooltips"]:
    assert set(comparisons).issubset(df.columns)
    tooltip_stats.extend([c for c in comparisons if c not in tooltip_stats])

# only amino acids present among the mutants, ordered by `biochem_order_aas`
if "alphabet" not in plot_kwargs:
    mutants_present = set(df["mutant"])
    plot_kwargs["alphabet"] = [
        a
        for a in polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP)
        if a in mutants_present
    ]

# order sites by their sequential site number
if "sites" not in plot_kwargs:
    plot_kwargs["sites"] = df.sort_values("sequential_site")["site"].unique().tolist()
In [8]:
# Build the interactive line plot + heatmap of the averaged shifts, save it to HTML,
# and display it as the cell's output.
chart = polyclonal.plot.lineplot_and_heatmap(
    **plot_kwargs,
    data_df=df,
    stat_col="shift",
    category_col="condition",
)
print(f"Saving chart to {shifts_html}")
chart.save(shifts_html)
chart
Saving chart to results/func_effect_shifts/averages/date_comparison_shifts.html
Out[8]:
In [ ]: