Compare simple difference in functional effects across two conditions
Import Python modules. We use `polyclonal` for the plotting:
import itertools
import altair as alt
import dms_variants.utils
import pandas as pd
import polyclonal
import polyclonal.plot
This notebook is parameterized by `papermill`. The next cell is tagged as `parameters` to get the passed parameters.
# this cell is tagged parameters for `papermill` parameterization
# Every name below is a `None` placeholder; the run-specific values are
# assigned by the `# Parameters` statements that follow.
site_numbering_map_csv = None  # input CSV read into `site_numbering_map`
mutation_annotations_csv = None  # optional input CSV of annotations; not read if falsy
diffs_csv = None  # output CSV that the averaged differences are written to
chart_html = None  # output HTML file for the interactive line plot / heatmap
corr_chart_html = None  # output HTML file for the condition-vs-condition scatter plot
params = None  # dict with `condition_1`, `condition_2`, `avg_method` and `plot_kwargs`
# Parameters
# Run-specific values for the SA26-vs-SA23 entry comparison: the input and
# output paths, then the analysis / plotting configuration.
corr_chart_html = "results/func_effect_diffs/SA26_vs_SA23_entry_diffs_corr.html"
chart_html = "results/func_effect_diffs/SA26_vs_SA23_entry_diffs.html"
diffs_csv = "results/func_effect_diffs/SA26_vs_SA23_entry_diffs.csv"
site_numbering_map_csv = "data/site_numbering_map.csv"
mutation_annotations_csv = "results/mutation_annotations/mutation_annotations.csv"
params = dict(
    # the two conditions being compared, each with the selections measuring it
    condition_1=dict(
        name="SA26 entry",
        selections=["Lib1-240125-293-SA26", "Lib2-240125-293-SA26"],
    ),
    condition_2=dict(
        name="SA23 entry",
        selections=["Lib1-240125-293-SA23", "Lib2-240125-293-SA23"],
    ),
    avg_method="median",
    per_selection_tooltips=True,
    # plotting options; key order is kept because sliders are built in this order
    plot_kwargs=dict(
        alphabet=list("RKHDEQNSTYWFAILMVGPC"),
        addtl_slider_stats={
            "times_seen": 2,
            "difference_std": 2,
            "fraction_pairs_w_mutation": 1,
            "best_effect": -2,
            "SA23 entry effect": -8,
            "SA26 entry effect": 0,
            "nt changes to codon": 3,
        },
        addtl_slider_stats_hide_not_filter=[
            "best_effect",
            "SA23 entry effect",
            "SA26 entry effect",
            "nt changes to codon",
        ],
        addtl_slider_stats_as_max=["difference_std", "nt changes to codon"],
        heatmap_max_at_least=1,
        heatmap_min_at_least=-1,
        init_floor_at_zero=True,
        init_site_statistic="sum",
        site_zoom_bar_color_col="region",
        slider_binding_range_kwargs={
            "times_seen": {"step": 1, "min": 1, "max": 10},
            "nt changes to codon": {"step": 1, "min": 1, "max": 3},
        },
    ),
)
Read the input data:
# Load the site numbering map; its `reference_site` column is called `site`
# from here on.
site_numbering_map = pd.read_csv(site_numbering_map_csv)
site_numbering_map = site_numbering_map.rename(columns={"reference_site": "site"})

# Neither numbering may have a missing value.
assert not site_numbering_map[["site", "sequential_site"]].isnull().any().any()

# Every other column whose name ends in "site" is an additional site numbering.
addtl_site_cols = [
    name
    for name in site_numbering_map.columns
    if name != "site" and name.endswith("site")
]

# Sequential -> reference site lookup; equal lengths means no sequential site
# is listed twice.
site_by_sequential = site_numbering_map.set_index("sequential_site")["site"]
sequential_to_site = site_by_sequential.to_dict()
assert len(site_numbering_map) == len(sequential_to_site)

# Mutation annotations are optional.
if mutation_annotations_csv:
    mutation_annotations = pd.read_csv(mutation_annotations_csv)
# Names of the two conditions; they must differ because they are later used as
# dictionary keys and column labels.
condition_1, condition_2 = (
    params["condition_1"]["name"],
    params["condition_2"]["name"],
)
assert condition_1 != condition_2, f"{condition_1=}, {condition_2=}"

# Selections measured for each condition.
condition_1_selections = params["condition_1"]["selections"]
condition_2_selections = params["condition_2"]["selections"]

# No selection may be repeated within a condition ...
for _selections in (condition_1_selections, condition_2_selections):
    assert len(set(_selections)) == len(_selections)
# ... each condition needs at least one selection ...
assert len(condition_1_selections), params["condition_1"]
assert len(condition_2_selections), params["condition_2"]
# ... and no selection may belong to both conditions.
if set(condition_1_selections) & set(condition_2_selections):
    raise ValueError(
        f"shared selections in {condition_1_selections=} and {condition_2_selections=}"
    )
# Get the sites to keep for each selection (relevant if keeping only regions).
#
# The selections of a condition are either a list of selection names (every
# site is kept) or a dict mapping each selection name to a region. A region is
# a list whose entries are single sequential sites or inclusive `[start, end]`
# ranges of sequential sites.
#
# `selection_sites` maps selection name -> set of reference sites in BOTH
# branches, so the `site in selection_sites[...]` tests made when counting
# selections per site are constant time and do not depend on the input form.
if isinstance(condition_1_selections, list):
    assert isinstance(condition_2_selections, list)
    all_reference_sites = set(sequential_to_site.values())
    selection_sites = {
        selection: set(all_reference_sites)
        for selection in condition_1_selections + condition_2_selections
    }
else:
    assert isinstance(condition_1_selections, dict)
    assert isinstance(condition_2_selections, dict)
    selection_sites = {}
    for selection, region in itertools.chain(
        condition_1_selections.items(), condition_2_selections.items()
    ):
        region_sequential = []
        for r in region:
            if isinstance(r, int):
                region_sequential.append(r)
            else:
                assert isinstance(r, list) and all(isinstance(ri, int) for ri in r), r
                # a range is exactly `[start, end]`; anything longer would
                # otherwise have its extra entries silently ignored
                assert len(r) == 2, r
                start, end = r
                assert start <= end, r
                region_sequential.extend(range(start, end + 1))
        # raises `KeyError` for a sequential site absent from the numbering map
        selection_sites[selection] = {sequential_to_site[r] for r in region_sequential}
# get number of selections each site is kept in
# Result: {condition name: {reference site: number of that condition's
# selections that keep the site}}.
n_selections_per_site = {}
for _cond_name, _cond_selections in zip(
    (condition_1, condition_2),
    (condition_1_selections, condition_2_selections),
):
    _kept_counts = {}
    for _ref_site in site_numbering_map["site"]:
        _kept_counts[_ref_site] = sum(
            _ref_site in selection_sites[_sel] for _sel in _cond_selections
        )
    n_selections_per_site[_cond_name] = _kept_counts
# Read the functional effects of every selection. Each row is tagged with its
# selection and condition, given a `mutation` label (wildtype + site + mutant),
# and dropped unless its site is kept for that selection.
dfs = []
for c, sels in zip(
    (condition_1, condition_2),
    (condition_1_selections, condition_2_selections),
):
    for s in sels:
        sites_to_keep = selection_sites[s]
        selection_effects = pd.read_csv(
            f"results/func_effects/by_selection/{s}_func_effects.csv"
        )
        selection_effects = selection_effects.assign(
            selection=s,
            condition=c,
            # nullable integer dtype, so missing counts stay missing
            times_seen=selection_effects["times_seen"].astype("Int64"),
            mutation=(
                selection_effects["wildtype"]
                + selection_effects["site"].astype(str)
                + selection_effects["mutant"]
            ),
        )
        dfs.append(selection_effects.query("site in @sites_to_keep"))

func_effects = pd.concat(dfs, ignore_index=True)
Correlations among all selections
Compute the correlations in the mutation effects across all selections:
# We compute for several times seen values, get those:
try:
    init_times_seen = params["plot_kwargs"]["addtl_slider_stats"]["times_seen"]
except KeyError:
    print("No times seen in params, using a value of 3")
    init_times_seen = 3

# do analysis for each "times_seen"
# The thresholds are 1, the initial slider value, and twice that value.
# Duplicates are dropped (order kept): with an initial value of 1 the list
# would be [1, 1, 2], and the same rows would be concatenated twice under a
# single `min_times_seen` value.
times_seen_thresholds = list(
    dict.fromkeys([1, init_times_seen, 2 * init_times_seen])
)
func_effects_for_corr = pd.concat(
    [
        func_effects.query("times_seen >= @t", engine="python").assign(min_times_seen=t)
        for t in times_seen_thresholds
    ]
)

# Correlation of mutation effects between selections, separately for each
# threshold; `r2` is the squared correlation.
corrs = (
    dms_variants.utils.tidy_to_corr(
        df=func_effects_for_corr,
        sample_col="selection",
        label_col="mutation",
        value_col="functional_effect",
        group_cols=["min_times_seen"],
    )
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
)
# Heat map of the squared correlation for each pair of selections, with one
# facet column per `min_times_seen` label.
corr_tooltips = []
for _field in ["selection_1", "selection_2", "r2", "r"]:
    if _field in {"r2", "r"}:
        # numeric fields are shown with three significant digits
        corr_tooltips.append(alt.Tooltip(_field, format=".3g"))
    else:
        corr_tooltips.append(_field)

corr_chart = alt.Chart(corrs).encode(
    alt.X("selection_1", title=None),
    alt.Y("selection_2", title=None),
    column=alt.Column("min_times_seen", title=None),
    color=alt.Color("r2", scale=alt.Scale(zero=True)),
    tooltip=corr_tooltips,
)
corr_chart = corr_chart.mark_rect(stroke="black")
corr_chart = corr_chart.properties(
    width=alt.Step(15),
    height=alt.Step(15),
    title="Per-selection correlation in mutation functional effects",
)
corr_chart = corr_chart.configure_axis(labelLimit=500)

display(corr_chart)

# List which selections belong to which condition.
print(
    f"\nSelections for {condition_1}: {list(condition_1_selections)}\n"
    f"Selections for {condition_2}: {list(condition_2_selections)}\n"
)
Selections for SA26 entry: ['Lib1-240125-293-SA26', 'Lib2-240125-293-SA26'] Selections for SA23 entry: ['Lib1-240125-293-SA23', 'Lib2-240125-293-SA23']
Average functional effects for each condition
Average the functional effects for each condition using the specified averaging method, then print the correlation between these average functional effects at several times seen:
avg_method = params["avg_method"]
assert avg_method in {"mean", "median"}, avg_method


def _n_selections_for_row(row):
    """Number of the row's condition's selections that keep the row's site."""
    return n_selections_per_site[row.condition][row.site]


# Per condition and mutation: average the effect with `avg_method` and add up
# `times_seen` over the selections.
avg_func_effects = func_effects.groupby(
    ["condition", "site", "wildtype", "mutant", "mutation"], as_index=False
).aggregate(
    effect=pd.NamedAgg(column="functional_effect", aggfunc=avg_method),
    times_seen=pd.NamedAgg(column="times_seen", aggfunc="sum"),
)
avg_func_effects = avg_func_effects.assign(
    n_selections=avg_func_effects.apply(_n_selections_for_row, axis=1)
)
# Turn the summed `times_seen` into a per-selection average by dividing by the
# number of selections keeping the site; it is set to missing for wildtype rows.
avg_func_effects = avg_func_effects.assign(
    times_seen=(
        avg_func_effects["times_seen"] / avg_func_effects["n_selections"]
    ).where(avg_func_effects["mutant"] != avg_func_effects["wildtype"], pd.NA)
)
# Thresholds on the averaged `times_seen`: 1, the initial slider value, and
# twice that value. Duplicates are dropped (order kept): with an initial value
# of 1 the list would be [1, 1, 2], and the same rows would be concatenated
# twice under a single `min_times_seen` value.
avg_times_seen_thresholds = list(
    dict.fromkeys([1, init_times_seen, 2 * init_times_seen])
)
avg_func_effects_for_corr = pd.concat(
    [
        avg_func_effects.query("times_seen >= @t", engine="python").assign(
            min_times_seen=t
        )
        for t in avg_times_seen_thresholds
    ]
)

print("Correlation between average functional effects across conditions:")
# Show one row per threshold: the correlation between the two conditions
# (pairs of a condition with itself are dropped), rounded to three decimals.
display(
    dms_variants.utils.tidy_to_corr(
        df=avg_func_effects_for_corr,
        sample_col="condition",
        label_col="mutation",
        value_col="effect",
        group_cols=["min_times_seen"],
    )
    .assign(
        r2=lambda x: x["correlation"] ** 2,
        min_times_seen=lambda x: "min times seen " + x["min_times_seen"].astype(str),
    )
    .rename(columns={"correlation": "r"})
    # `condition_1` / `condition_2` here are columns of the correlation table
    .query("condition_1 != condition_2")
    .reset_index(drop=True)
    .groupby("min_times_seen")
    .first()
    .round(3)
)
Correlation between average functional effects across conditions:
| min_times_seen | condition_1 | condition_2 | r | r2 |
|---|---|---|---|---|
| min times seen 1 | SA26 entry | SA23 entry | 0.872 | 0.761 |
| min times seen 2 | SA26 entry | SA23 entry | 0.883 | 0.779 |
| min times seen 4 | SA26 entry | SA23 entry | 0.900 | 0.810 |
Compute pairwise differences
Compute pairwise differences in effects between all pairs of condition 1 selections versus condition 2 selections. For each comparison, we compute the times seen as the mean between the two selections being compared.
We then compute the average (using the specified average method) difference across comparisons, the mean times seen, and the fraction of comparisons in which a difference can be computed:
# compute differences for all individual pairs
# Each selection of condition 1 is paired with each selection of condition 2.
# Only mutations measured in both selections of a pair get a difference
# (condition 1 minus condition 2); `times_seen` is the mean over the pair.
_effect_cols = ["wildtype", "site", "mutant", "times_seen", "functional_effect"]
_mutation_key = ["wildtype", "site", "mutant"]

diffs_all = []
for sel1 in condition_1_selections:
    df1 = func_effects.query("selection == @sel1")[_effect_cols]
    for sel2 in condition_2_selections:
        df2 = func_effects.query("selection == @sel2")[_effect_cols]
        # `_x` columns come from the condition 1 selection, `_y` from condition 2
        paired = df1.merge(df2, on=_mutation_key, validate="1:1")
        paired["times_seen"] = (paired["times_seen_x"] + paired["times_seen_y"]) / 2
        paired["difference"] = (
            paired["functional_effect_x"] - paired["functional_effect_y"]
        )
        paired = paired[[*_mutation_key, "times_seen", "difference"]]
        diffs_all.append(paired.assign(comparison=f"{sel1} vs {sel2}"))
# compute average differences across pairs
def _population_std(values):
    """Standard deviation (ddof=0) of the non-missing values."""
    return values.dropna().std(ddof=0)


def _fraction_of_pairs(row):
    """Pairs with a difference, as a fraction of all pairs that keep the site."""
    possible_pairs = (
        n_selections_per_site[condition_1][row.site]
        * n_selections_per_site[condition_2][row.site]
    )
    return row.n_pairs_w_mutation / possible_pairs


_pair_diffs = pd.concat(diffs_all, ignore_index=True)
_diff_summary = _pair_diffs.groupby(
    ["wildtype", "site", "mutant"], as_index=False
).aggregate(
    difference=pd.NamedAgg(column="difference", aggfunc=avg_method),
    difference_std=pd.NamedAgg(column="difference", aggfunc=_population_std),
    times_seen=pd.NamedAgg(column="times_seen", aggfunc="mean"),
    n_pairs_w_mutation=pd.NamedAgg(column="difference", aggfunc="count"),
)
_diff_summary = _diff_summary.assign(
    fraction_pairs_w_mutation=_diff_summary.apply(_fraction_of_pairs, axis=1)
)
# the raw pair count was only needed to compute the fraction
diffs = _diff_summary.drop(columns="n_pairs_w_mutation")
# add other relevant stuff to data frame of differences
def _comma_join(strings):
    """Join the strings of one pivot cell with commas."""
    return ",".join(strings)


# average effects in each condition: one "<condition> effect" column per
# condition, plus `best_effect`, the larger of the two
_condition_effects = avg_func_effects.pivot_table(
    index=["site", "wildtype", "mutant"],
    values="effect",
    columns="condition",
).reset_index()
_condition_effects = _condition_effects.assign(
    best_effect=_condition_effects[[condition_1, condition_2]].max(axis=1)
)
_condition_effects = _condition_effects.rename(
    columns={c: f"{c} effect" for c in [condition_1, condition_2]}
)

# per-selection effects (times seen): one text column per selection holding
# the effect to two decimals, followed by "(times seen)" for non-wildtype rows
_effect_text = func_effects["functional_effect"].map(lambda e: f"{e:.2f}")
_times_seen_text = (" (" + func_effects["times_seen"].astype(str) + ")").where(
    func_effects["mutant"] != func_effects["wildtype"],
    "",
)
_per_selection = func_effects.assign(
    effect_times_seen=_effect_text + _times_seen_text
).pivot_table(
    index=["site", "wildtype", "mutant"],
    values="effect_times_seen",
    columns="selection",
    aggfunc=_comma_join,
)
# condition 1 selections first, then condition 2 selections
_per_selection = _per_selection[
    list(condition_1_selections) + list(condition_2_selections)
].reset_index()

# combine everything, then sort values
_merged = diffs.merge(
    _condition_effects,
    on=["wildtype", "site", "mutant"],
    validate="one_to_one",
)
_merged = _merged.merge(
    _per_selection,
    on=["wildtype", "site", "mutant"],
    validate="one_to_one",
)
diffs = _merged.sort_values(["site", "mutant"]).reset_index(drop=True)

print(f"Writing differences to {diffs_csv}")
diffs.to_csv(diffs_csv, index=False, float_format="%.4g")
Writing differences to results/func_effect_diffs/SA26_vs_SA23_entry_diffs.csv
Get correlations for comparisons, applying times seen filter:
print(f"Correlating differences for times_seen of {init_times_seen}")

# Per-pair differences that pass the times-seen filter, with a `mutation`
# label; the comparison name is carried in a column called `selection`.
diffs_all_df = pd.concat(diffs_all).query("times_seen >= @init_times_seen")
diffs_all_df = diffs_all_df.assign(
    mutation=(
        diffs_all_df["wildtype"]
        + diffs_all_df["site"].astype(str)
        + diffs_all_df["mutant"]
    )
).rename(columns={"comparison": "selection"})

# mutations x comparisons table -> Pearson correlation matrix -> one row per
# ordered pair of comparisons
_diffs_wide = diffs_all_df.pivot_table(
    index="mutation",
    columns="selection",
    values="difference",
)
_diffs_corr_matrix = _diffs_wide.corr(method="pearson").rename_axis("comparison_1")
diffs_corr_df = _diffs_corr_matrix.reset_index().melt(
    id_vars="comparison_1",
    value_name="correlation (r)",
    var_name="comparison_2",
)

# Heat map of the correlations on a fixed [-1, 1] red-blue scale.
_diffs_corr_tooltips = []
for _col in diffs_corr_df.columns:
    if _col == "correlation (r)":
        _diffs_corr_tooltips.append(alt.Tooltip(_col, format=".3g"))
    else:
        _diffs_corr_tooltips.append(_col)

diffs_corr_chart = alt.Chart(diffs_corr_df).encode(
    alt.X("comparison_1", title=None),
    alt.Y("comparison_2", title=None),
    color=alt.Color(
        "correlation (r)", scale=alt.Scale(domain=[-1, 1], scheme="redblue")
    ),
    tooltip=_diffs_corr_tooltips,
)
diffs_corr_chart = diffs_corr_chart.mark_rect(stroke="black")
diffs_corr_chart = diffs_corr_chart.properties(
    width=alt.Step(15),
    height=alt.Step(15),
    title="Correlation in differences for mutations across comparisons",
)
diffs_corr_chart = diffs_corr_chart.configure_axis(labelLimit=500)

display(diffs_corr_chart)
Correlating differences for times_seen of 2
Make a scatter plot comparing the conditions
Make a correlation plot between the two conditions with informative tooltips and slider bars:
# compact version of diffs for tooltips
# For each condition the per-selection columns are replaced by a single
# "<condition> selections" column that joins their non-missing values.
diffs_compact = diffs
for condition, condition_selections in zip(
    (condition_1, condition_2),
    (condition_1_selections, condition_2_selections),
):
    _selection_cols = list(condition_selections)
    assert set(_selection_cols) <= set(diffs_compact.columns)

    def _join_present(row, cols=tuple(_selection_cols)):
        """Join the row's non-missing per-selection strings with ", "."""
        return ", ".join(row[col] for col in cols if not pd.isnull(row[col]))

    # `assign` returns a new frame, so `diffs` itself is left unchanged
    diffs_compact = diffs_compact.assign(
        **{f"{condition} selections": diffs_compact.apply(_join_present, axis=1)}
    ).drop(columns=_selection_cols)
# Point selection that follows the mouse, keyed on the `mutation` field; it is
# empty until a point is hovered.
mutation_selection = alt.selection_point(
    on="mouseover", fields=["mutation"], empty=False
)

# Optionally add the per-mutation annotations, matched on site and mutant.
if mutation_annotations_csv:
    if not {"site", "mutant"}.issubset(mutation_annotations.columns):
        raise ValueError(f"{mutation_annotations.columns=} lacks 'site', 'mutant'")
    # Apart from the merge keys, the annotations must not reuse any column name
    # of the frame they are merged into. The message reports that same frame
    # (`diffs_compact`), not `diffs`, whose columns differ.
    shared_cols = set(mutation_annotations.columns).intersection(diffs_compact.columns)
    if shared_cols != {"site", "mutant"}:
        raise ValueError(
            f"{mutation_annotations.columns=} shares columns with "
            f"{diffs_compact.columns=}"
        )
    diffs_compact = diffs_compact.merge(
        mutation_annotations,
        on=["site", "mutant"],
        how="left",
        validate="many_to_one",
    )
    # Annotations are blanked for wildtype rows. The mask cannot change inside
    # the loop (annotation columns cannot be `wildtype` or `mutant`), so it is
    # computed once.
    is_mutation = diffs_compact["wildtype"] != diffs_compact["mutant"]
    for col in mutation_annotations.columns:
        if col not in {"site", "mutant"}:
            diffs_compact[col] = diffs_compact[col].where(is_mutation, pd.NA)
# Data for the scatter plot: wildtype rows are dropped and each remaining row
# is identified by one `mutation` label instead of wildtype / site / mutant.
corr_diffs = diffs_compact.query("wildtype != mutant")
corr_diffs = corr_diffs.assign(
    mutation=corr_diffs["wildtype"] + corr_diffs["site"].astype(str) + corr_diffs["mutant"]
)
corr_diffs = corr_diffs.drop(columns=["wildtype", "site", "mutant"])
# Put `mutation` first; the sort is stable, so the other columns keep their order.
corr_diffs = corr_diffs[sorted(corr_diffs.columns, key=lambda name: name != "mutation")]
# Working copy of the plotting options. The copy is shallow, so the nested
# dicts and lists are still the objects held by `params`.
plot_kwargs = params["plot_kwargs"].copy()
plot_kwargs.setdefault("slider_binding_range_kwargs", {})
# NOTE(review): because `addtl_slider_stats_as_max` is defaulted to `[]` here,
# the later check that raises when `difference_std` is configured without
# `addtl_slider_stats_as_max` can no longer trigger - confirm this is intended.
plot_kwargs.setdefault("addtl_slider_stats_as_max", [])

# One slider per configured statistic. `addtl_slider_stats` is optional (other
# cells of this notebook also handle its absence), so a missing key simply
# gives no sliders instead of a `KeyError`.
sliders = {}
for stat, configured_value in plot_kwargs.get("addtl_slider_stats", {}).items():
    is_max_slider = stat in plot_kwargs["addtl_slider_stats_as_max"]
    stat_min = corr_diffs[stat].min()
    stat_max = corr_diffs[stat].max()

    if configured_value is not None:
        # An explicit value is always honored. Testing truthiness instead would
        # discard a configured 0 and fall back to the column extreme.
        value = configured_value
    elif is_max_slider or stat == "difference_std":
        # no value given for an upper bound: start at the maximum (filters nothing)
        value = stat_max
    else:
        # no value given for a lower bound: start at the minimum (filters nothing)
        value = stat_min

    # The slider range always covers both the data and the initial value;
    # per-statistic `slider_binding_range_kwargs` override these settings.
    binding_kwargs = {
        "name": f"maximum {stat}" if is_max_slider else f"minimum {stat}",
        "min": min(value, stat_min),
        "max": max(value, stat_max),
    } | plot_kwargs["slider_binding_range_kwargs"].get(stat, {})

    sliders[stat] = alt.param(
        value=value,
        bind=alt.binding_range(**binding_kwargs),
    )
# Scatter plot of the effect in condition 1 (x) against condition 2 (y), one
# point per mutation; the hovered mutation is outlined and enlarged.
_axis_scale_kwargs = dict(nice=False, zero=False, padding=5)
_scatter_tooltips = [
    alt.Tooltip(c, format=".3g") if corr_diffs[c].dtype == float else c
    for c in corr_diffs.columns
]

corr_chart = alt.Chart(corr_diffs).add_params(mutation_selection)
corr_chart = corr_chart.encode(
    alt.X(f"{condition_1} effect", scale=alt.Scale(**_axis_scale_kwargs)),
    alt.Y(f"{condition_2} effect", scale=alt.Scale(**_axis_scale_kwargs)),
    strokeWidth=alt.condition(mutation_selection, alt.value(2), alt.value(0)),
    size=alt.condition(mutation_selection, alt.value(70), alt.value(45)),
    tooltip=_scatter_tooltips,
)
corr_chart = corr_chart.mark_circle(fill="black", fillOpacity=0.35, stroke="red")
corr_chart = corr_chart.properties(width=275, height=275)
corr_chart = corr_chart.configure_axis(grid=False)

# Attach every slider as a filter: statistics listed in
# `addtl_slider_stats_as_max` keep points at or below the slider value, all
# others keep points at or above it.
for stat, slider in sliders.items():
    keep_point = (
        alt.datum[stat] <= slider
        if stat in plot_kwargs["addtl_slider_stats_as_max"]
        else alt.datum[stat] >= slider
    )
    corr_chart = corr_chart.add_params(slider).transform_filter(keep_point)

print(f"Saving to {corr_chart_html=}")
corr_chart.save(corr_chart_html)

# last expression of the cell, so the notebook renders the chart
corr_chart
Saving to corr_chart_html='results/func_effect_diffs/SA26_vs_SA23_entry_diffs_corr.html'
Make interactive chart
Set up keyword arguments to https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap if they are not already specified:
# Fill in slider defaults that the parameters did not specify.
slider_stats = plot_kwargs.setdefault("addtl_slider_stats", {})
slider_stats.setdefault("times_seen", 3)

if "difference_std" not in slider_stats:
    # Default: start at the largest observed value and treat the slider as an
    # upper bound.
    slider_stats["difference_std"] = diffs_compact["difference_std"].max()
    stats_as_max = plot_kwargs.setdefault("addtl_slider_stats_as_max", [])
    # do not list it twice if the parameters already named it as a maximum
    if "difference_std" not in stats_as_max:
        stats_as_max.append("difference_std")
elif "addtl_slider_stats_as_max" not in plot_kwargs:
    # NOTE(review): the scatter-plot code earlier in this notebook already
    # defaults `addtl_slider_stats_as_max` to `[]`, so when the cells run in
    # order this branch is not expected to trigger - confirm this is intended.
    raise ValueError(
        "You specified `difference_std` in `addtl_slider_stats` but did not add it to "
        "`addtl_slider_stats_as_max`. If you really do not want `difference_std` in "
        "`addtl_slider_stats_as_max`, then specify that list without it."
    )

slider_stats.setdefault("fraction_pairs_w_mutation", 0.5)
# Make sure the column that colors the site zoom bar is in the data, taking it
# from the site numbering map when it is not already there.
if "site_zoom_bar_color_col" in plot_kwargs:
    zoom_color_col = plot_kwargs["site_zoom_bar_color_col"]
    if (
        zoom_color_col not in diffs_compact.columns
        and zoom_color_col in site_numbering_map.columns
    ):
        diffs_compact = diffs_compact.merge(
            site_numbering_map[["site", zoom_color_col]],
            on="site",
            validate="many_to_one",
            how="left",
        )

# Tooltips always include `difference_std` and the additional site numberings.
if "addtl_tooltip_stats" not in plot_kwargs:
    plot_kwargs["addtl_tooltip_stats"] = []
for c in ["difference_std"] + addtl_site_cols:
    if c not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append(c)

# Add only the site-numbering columns that are still missing. Merging all of
# them again would duplicate any column already added above (for example when
# the zoom-bar color column is itself a site numbering), and pandas would then
# rename both copies with `_x` / `_y` suffixes.
missing_site_cols = [c for c in addtl_site_cols if c not in diffs_compact.columns]
if missing_site_cols:
    diffs_compact = diffs_compact.merge(
        site_numbering_map[["site", *missing_site_cols]],
        on="site",
        validate="many_to_one",
        how="left",
    )

# Show the sequential site in tooltips when it differs from the reference site.
if any(diffs_compact["site"] != diffs_compact["sequential_site"]):
    if "sequential_site" not in plot_kwargs["addtl_tooltip_stats"]:
        plot_kwargs["addtl_tooltip_stats"].append("sequential_site")

# Per-selection summaries for each condition.
# NOTE(review): `params` carries a `per_selection_tooltips` flag, but nothing
# visible in this notebook reads it and these tooltips are added
# unconditionally - confirm whether the flag should gate this step.
plot_kwargs["addtl_tooltip_stats"].extend(
    [
        s
        for s in [f"{condition_1} selections", f"{condition_2} selections"]
        if s not in plot_kwargs["addtl_tooltip_stats"]
    ]
)
# Default alphabet: the characters returned by `biochem_order_aas`, in the
# order it returns them, restricted to those that occur as a mutant.
if "alphabet" not in plot_kwargs:
    # build the lookup set once rather than once per alphabet character
    observed_mutants = set(diffs_compact["mutant"])
    plot_kwargs["alphabet"] = [
        a
        for a in polyclonal.alphabets.biochem_order_aas(polyclonal.AAS_WITHSTOP_WITHGAP)
        if a in observed_mutants
    ]

# Default sites: every site of the numbering map, in sequential order.
if "sites" not in plot_kwargs:
    plot_kwargs["sites"] = site_numbering_map.sort_values("sequential_site")[
        "site"
    ].tolist()
Now make the interactive heatmap:
# The plot is given a category column; a constant `_dummy` column is added for
# that purpose, so it must not clash with an existing column.
assert "_dummy" not in diffs_compact.columns
_heatmap_data = diffs_compact.assign(_dummy="dummy")

chart = polyclonal.plot.lineplot_and_heatmap(
    data_df=_heatmap_data,
    stat_col="difference",
    category_col="_dummy",
    **plot_kwargs,
)

display(chart)

print(f"\nSaving to {chart_html}")
chart.save(chart_html)
Saving to results/func_effect_diffs/SA26_vs_SA23_entry_diffs.html