Average mutational effects for an assay (e.g., antibody)¶
This notebook averages selections that measure escape from neutralization in an assay. Below in the code, "antibody" is used as a generic term for the agent that neutralizes the virus.
Import Python modules.
We use polyclonal
for the averaging and plotting:
import copy
import itertools
import math
import pickle
import altair as alt
import pandas as pd
import polyclonal
import polyclonal.utils
_ = alt.data_transformers.disable_max_rows()
This notebook is parameterized by papermill
.
The next cell is tagged as parameters
to get the passed parameters.
# this cell is tagged parameters for `papermill` parameterization
# These are placeholders; `papermill` injects the actual values in the
# following cell, which overwrites every name defined here.
assay = None  # name of the assay (set to "antibody_escape" below)
site_numbering_map_csv = None  # CSV mapping sequential sites to reference sites
mutation_annotations_csv = None  # optional CSV of per-mutation annotations
prob_escape_mean_csvs = None  # per-selection mean prob_escape CSVs
pickles = None  # per-selection pickled `polyclonal` models
assay_config = None  # assay-level display / scaling configuration
avg_pickle_file = None  # output path for pickled averaged model
effect_csv = None  # output path (mutation effects CSV, per injected value)
icXX_csv = None  # output path (mutation ICXX CSV, per injected value)
effect_html = None  # output path (mutation effects HTML, per injected value)
icXX_html = None  # output path (mutation ICXX HTML, per injected value)
params = None  # dict of analysis parameters
# Parameters
# (this cell's values are injected by `papermill`)
params = {
    "icXX": 80,
    # keyword arguments controlling the interactive escape plot
    "escape_plot_kwargs": {
        # amino-acid alphabet shown in the heatmaps
        "alphabet": list("ACDEFGHIKLMNPQRSTVWY"),
        "addtl_slider_stats": {"times_seen": 2, "nt changes to codon": 3},
        "heatmap_max_at_least": 2,
        "heatmap_min_at_least": -2,
        "init_floor_at_zero": True,
        "init_site_statistic": "sum",
        "site_zoom_bar_color_col": "region",
        "addtl_slider_stats_as_max": ["nt changes to codon"],
        "addtl_slider_stats_hide_not_filter": ["nt changes to codon"],
        "slider_binding_range_kwargs": {
            "nt changes to codon": {"min": 1, "max": 3, "step": 1},
            "times_seen": {"min": 1, "max": 10, "step": 1},
        },
        "avg_type": "median",
        "per_model_tooltip": True,
    },
    # statistics from other analyses used to hide mutations in plots
    "plot_hide_stats": {
        "functional effect": {
            "csv": "results/func_effects/averages/293T_entry_func_effects.csv",
            "csv_col": "effect",
            "init": -3,
            "min_filters": {"times_seen": 2},
        }
    },
    "selections": ["Lib1-230822-ferret-sera-7", "Lib2-230822-ferret-sera-7"],
}

# per-selection input files, one entry per selection in the same order
prob_escape_mean_csvs = [
    f"results/antibody_escape/by_selection/{s}_prob_escape_mean.csv"
    for s in params["selections"]
]
pickles = [
    f"results/antibody_escape/by_selection/{s}_polyclonal_model.pickle"
    for s in params["selections"]
]

site_numbering_map_csv = "data/site_numbering_map.csv"

# assay-level display / scaling configuration
assay_config = {
    "title": "Antibody/serum escape",
    "selections": "antibody_selections",
    "averages": "avg_antibody_escape",
    "prob_escape_scale": {"type": "symlog", "constant": 0.04},
    "scale_stat": 1,
    "stat_name": "escape",
}

mutation_annotations_csv = "results/mutation_annotations/mutation_annotations.csv"
assay = "antibody_escape"

# output files for the averaged model and its per-mutation summaries
avg_pickle_file = "results/antibody_escape/averages/ferret-7_polyclonal_model.pickle"
effect_csv = "results/antibody_escape/averages/ferret-7_mut_effect.csv"
icXX_csv = "results/antibody_escape/averages/ferret-7_mut_icXX.csv"
effect_html = "results/antibody_escape/averages/ferret-7_mut_effect.html"
icXX_html = "results/antibody_escape/averages/ferret-7_mut_icXX.html"

print(f"Analyzing results for {assay=}")
Analyzing results for assay='antibody_escape'
Read the input data and parameters:
# Load each selection's fit `polyclonal` model. Use context managers so the
# file handles are closed (the previous `pickle.load(open(f, "rb"))` idiom
# leaked open handles).
models = []
for pickle_file in pickles:
    with open(pickle_file, "rb") as f:
        models.append(pickle.load(f))

# Site numbering map keyed by reference site (renamed to "site").
site_numbering_map = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)
assert site_numbering_map[["site", "sequential_site"]].notnull().all().all()

# If any model uses non-integer (reference) site labels, all models must,
# and the map's sites are cast to str so they match the models' labels.
if any(not m.sequential_integer_sites for m in models):
    assert all(not m.sequential_integer_sites for m in models)
    site_numbering_map["site"] = site_numbering_map["site"].astype(str)

sequential_to_site = site_numbering_map.set_index("sequential_site")["site"].to_dict()
assert len(sequential_to_site) == len(site_numbering_map)

selections = params["selections"]
assert len(selections) == len(set(selections))

# Get sites to keep for each selection (relevant if keeping only regions).
if isinstance(selections, list):
    # no per-selection regions specified: every selection keeps all sites
    selection_sites = {
        selection: set(sequential_to_site.values()) for selection in selections
    }
else:
    # a dict maps each selection to region entries, where each entry is a
    # single sequential site (int) or an inclusive [start, end] pair
    assert isinstance(selections, dict)
    selection_sites = {}
    for selection, region in selections.items():
        region_sequential = []
        for entry in region:
            if isinstance(entry, int):
                region_sequential.append(entry)
            else:
                assert isinstance(entry, list) and all(
                    isinstance(i, int) for i in entry
                ), entry
                assert entry[0] <= entry[1], entry
                region_sequential.extend(range(entry[0], entry[1] + 1))
        selection_sites[selection] = [
            sequential_to_site[seq_site] for seq_site in region_sequential
        ]
    selections = list(selections)

# number of selections in which each site is kept
n_selections_per_site = {
    site: sum(site in selection_sites[selection] for selection in selections)
    for site in site_numbering_map["site"]
}

# read `polyclonal` models into a data frame for `PolyclonalAverage`
assert len(selections) == len(models) == len(selection_sites)
models_df = pd.DataFrame(
    {
        "selection": selections,
        "model": models,
        "sites_to_keep": [selection_sites[s] for s in selections],
    }
)

# read prob_escape means all into one data frame
assert len(selections) == len(prob_escape_mean_csvs)
prob_escape_means = pd.concat(
    [
        pd.read_csv(csv_file).assign(selection=selection)
        for selection, csv_file in zip(selections, prob_escape_mean_csvs)
    ],
    ignore_index=True,
)

# keyword arguments for the escape plots
escape_plot_kwargs = params["escape_plot_kwargs"]
Neutralization at concentrations used for each selection¶
For each selection going into the average, plot the average fraction neutralization (probability of escape) of variants with different numbers of mutations, both for the censored values used to fit the models and the uncensored values. Note the concentrations not used in the model fits are shown fainter and in a different shape:
# x-axis title is optional and the scale defaults to log; both can be
# overridden via `assay_config`
_x_axis_kwargs = (
    {"title": assay_config["concentration_title"]}
    if "concentration_title" in assay_config
    else {}
)
_x_scale = alt.Scale(**assay_config.get("concentration_scale", {"type": "log"}))

# Line-plus-point chart of mean probability of escape versus concentration,
# faceted with censored/uncensored across columns and selections down rows.
mean_prob_escape_chart = (
    alt.Chart(prob_escape_means)
    .encode(
        x=alt.X("concentration", scale=_x_scale, **_x_axis_kwargs),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(**assay_config["prob_escape_scale"]),
        ),
        column=alt.Column(
            "censored",
            title=None,
            header=alt.Header(labelFontWeight="bold", labelFontSize=10),
        ),
        row=alt.Row(
            "selection",
            title=None,
            header=alt.Header(labelFontWeight="bold", labelFontSize=10),
        ),
        color=alt.Color("n_substitutions"),
        # format the probability-escape value in tooltips, pass other columns as-is
        tooltip=[
            c if c != "probability escape" else alt.Tooltip(c, format=".3g")
            for c in prob_escape_means.columns
        ],
        # concentrations not used in the fit get a different shape and are fainter
        shape=alt.Shape("use_in_fit", scale=alt.Scale(domain=[True, False])),
        opacity=alt.Opacity(
            "use_in_fit", scale=alt.Scale(domain=[True, False], range=[0.9, 0.3])
        ),
    )
    .mark_line(point=True, size=0.75, opacity=0.8)
    .properties(width=230, height=145)
    .configure_axis(grid=False)
    .configure_point(size=50)
)

mean_prob_escape_chart
Average mutation effects¶
First build a PolyclonalAverage
:
# Average the per-selection models; `sites_to_keep` restricts each model to
# the sites retained for its selection.
avg_model = polyclonal.PolyclonalAverage(models_df, region_col="sites_to_keep")
assert set(avg_model.sites).issubset(site_numbering_map["site"])

# persist the averaged model for downstream use
print(f"Saving the average model to {avg_pickle_file}")
with open(avg_pickle_file, "wb") as pickle_f:
    pickle.dump(avg_model, pickle_f)
Saving the average model to results/antibody_escape/averages/ferret-7_polyclonal_model.pickle
Neutralization curves against unmutated protein (which reflect the wildtype activities, Hill coefficients, and non-neutralizable fractions):
# Plot the averaged model's neutralization curves against the unmutated
# protein (reflects wildtype activities, Hill coefficients, and
# non-neutralizable fractions).
avg_model.curves_plot()
Get the mutation escape values, and apply any filters also being applied to the final average plot:
# Per-replicate (per-selection) mutation escape values from the average model.
mut_escape_df = avg_model.mut_escape_df_replicates

# add mutation annotations if any
if mutation_annotations_csv:
    mutation_annotations = pd.read_csv(mutation_annotations_csv)
    # match site dtype to the models (str if non-sequential-integer sites)
    if not avg_model.sequential_integer_sites:
        mutation_annotations["site"] = mutation_annotations["site"].astype(str)
    # annotations must have the merge keys and share ONLY those keys with
    # the escape data frame (other shared columns would collide on merge)
    if not {"site", "mutant"}.issubset(mutation_annotations.columns):
        raise ValueError(f"{mutation_annotations.columns=} lacks 'site', 'mutant'")
    if set(mutation_annotations.columns).intersection(mut_escape_df.columns) != {
        "site",
        "mutant",
    }:
        raise ValueError(
            f"{mutation_annotations.columns=} shares columns with {mut_escape_df.columns=}"
        )
    mut_escape_df = mut_escape_df.merge(
        mutation_annotations,
        on=["site", "mutant"],
        how="left",
        validate="many_to_one",
    )
else:
    mutation_annotations = None

# apply `times_seen` filter
try:
    times_seen = escape_plot_kwargs["addtl_slider_stats"]["times_seen"]
except KeyError:
    # no times_seen slider configured, so keep all mutations
    times_seen = 1
n = len(mut_escape_df)
mut_escape_df = mut_escape_df.query("times_seen >= @times_seen")
print(
    f"Filtering for `times_seen` >= {times_seen} removes "
    f"{n - len(mut_escape_df)} of {n} mutations"
)

# apply any other filters (statistics from other analyses, eg, functional
# effects, read from CSVs keyed on site and mutant)
if "plot_hide_stats" in params:
    for stat, stat_d in params["plot_hide_stats"].items():
        n = len(mut_escape_df)
        min_stat = stat_d["init"]
        df_to_merge = pd.read_csv(stat_d["csv"])
        if not avg_model.sequential_integer_sites:
            df_to_merge["site"] = df_to_merge["site"].astype(str)
        # optionally pre-filter the stat data frame (eg, by its own times_seen)
        if "min_filters" in stat_d:
            for col, min_col in stat_d["min_filters"].items():
                df_to_merge = df_to_merge.query(f"{col} >= @min_col")
        col = stat_d["csv_col"]
        assert col not in mut_escape_df.columns, f"{col=}, {mut_escape_df.columns=}"
        # left merge on the shared (site, mutant) columns; mutations absent
        # from the stat CSV get NaN and are retained by the filter below
        mut_escape_df = mut_escape_df.merge(
            df_to_merge[["site", "mutant", col]],
            how="left",
            validate="many_to_one",
        )
        mut_escape_df = mut_escape_df[
            (mut_escape_df[col] >= min_stat) | mut_escape_df[col].isnull()
        ].drop(columns=col)
        print(
            f"Filtering for {stat} >= {min_stat} removes "
            f"{n - len(mut_escape_df)} of {n} mutations"
        )
Filtering for `times_seen` >= 2 removes 3472 of 12064 mutations Filtering for functional effect >= -3 removes 431 of 8592 mutations
Plot actual correlation scatter plots between replicates:
# Only a single epitope is expected here; correlations are per selection pair.
assert mut_escape_df["epitope"].nunique() == 1

# one scatter panel per pair of selections
corr_panels = []
for sel1, sel2 in itertools.combinations(
    sorted(mut_escape_df["selection"].unique()), 2
):
    # escape values of each selection, keyed on mutation
    escape1 = mut_escape_df.query("selection == @sel1")[["mutation", "escape"]].rename(
        columns={"escape": sel1}
    )
    escape2 = mut_escape_df.query("selection == @sel2")[["mutation", "escape"]].rename(
        columns={"escape": sel2}
    )
    corr_df = escape1.merge(escape2, validate="one_to_one")
    npoints = len(corr_df)
    r = corr_df[[sel1, sel2]].corr().values[1, 0]
    panel = (
        alt.Chart(corr_df)
        .encode(
            alt.X(sel1, scale=alt.Scale(nice=False, padding=4)),
            alt.Y(sel2, scale=alt.Scale(nice=False, padding=4)),
            tooltip=[
                "mutation",
                alt.Tooltip(sel1, format=".3f"),
                alt.Tooltip(sel2, format=".3f"),
            ],
        )
        .mark_circle(color="black", size=30, opacity=0.25)
        .properties(
            width=160,
            height=160,
            title=alt.TitleParams(
                f"R = {r:.2f}, N = {npoints}", fontSize=11, fontWeight="normal", dy=2
            ),
        )
    )
    corr_panels.append(panel)

# arrange the panels into rows of at most `ncols` charts each
ncols = 4
corr_rows = [
    alt.hconcat(*corr_panels[i : i + ncols])
    for i in range(0, len(corr_panels), ncols)
]
alt.vconcat(*corr_rows).configure_axis(grid=False)
Correlation of escape across different selections:
# Pairwise Pearson correlations of mutation escape between selections,
# computed in tidy form and rendered as a heatmap.
mut_escape_corr = polyclonal.utils.tidy_to_corr(
    mut_escape_df,
    sample_col="selection",
    label_col="mutation",
    value_col="escape",
    group_cols="epitope",
    method="pearson",
).rename(columns={"correlation": "r"})

# only group the heatmap by epitope if there are multiple epitopes
single_epitope = mut_escape_corr["epitope"].nunique() == 1
polyclonal.plot.corr_heatmap(
    corr_df=mut_escape_corr,
    corr_col="r",
    sample_cols="selection",
    group_col=None if single_epitope else "epitope",
)
Site line plots for the site values for each individual selection (model) in the average. This makes it easier to tell if one selection is an outlier before we plot the full averages below, and how correlated the selections are. Note the plot is interactive: you can mouseover points and change the site metric shown.
# Named site-level summary statistics applied to the per-mutation escape
# values at each site (string names use pandas built-in aggregations).
site_stats = {
    "sum": "sum",
    "mean": "mean",
    "max": "max",
    "min": "min",
    "sum_abs": lambda s: s.abs().sum(),
    "mean_abs": lambda s: s.abs().mean(),
}

# additional per-site columns from the numbering map (eg, sequential_site)
addtl_site_cols = [
    c for c in site_numbering_map.columns if c != "site" and c.endswith("site")
]

# Long-format data frame of site-level escape: compute all site statistics
# both with and without flooring escape at zero, then melt to one row per
# (epitope, selection, floor_at_zero, site, site statistic).
site_escape_df = (
    pd.concat(
        [
            # NOTE: the lambda is evaluated eagerly by `assign` within each
            # comprehension iteration, so it sees the current floor_at_zero
            mut_escape_df.assign(
                floor_at_zero=floor_at_zero,
                escape=lambda x: (x["escape"] * assay_config["scale_stat"]).clip(
                    lower=0 if floor_at_zero else None,
                ),
            )
            for floor_at_zero in [True, False]
        ]
    )
    .groupby(["epitope", "selection", "floor_at_zero", "site"], as_index=False)
    .aggregate(**{stat: pd.NamedAgg("escape", site_stats[stat]) for stat in site_stats})
    .melt(
        id_vars=["epitope", "selection", "floor_at_zero", "site"],
        value_vars=list(site_stats),
        var_name="site statistic",
        value_name="site_val",
    )
    .merge(
        site_numbering_map[["site", *addtl_site_cols]],
        validate="many_to_one",
    )
)

# radio selector for which site statistic is displayed
try:
    init_site_stat = params["escape_plot_kwargs"]["init_site_statistic"]
except KeyError:
    init_site_stat = "sum"
site_statistic_selection = alt.selection_point(
    fields=["site statistic"],
    bind=alt.binding_radio(
        name="site statistic",
        options=list(site_stats),
    ),
    value=init_site_stat,
)

# radio selector for whether escape is floored at zero
try:
    init_floor_at_zero = params["escape_plot_kwargs"]["init_floor_at_zero"]
except KeyError:
    init_floor_at_zero = True
floor_selection = alt.selection_point(
    fields=["floor_at_zero"],
    bind=alt.binding_radio(
        name="floor at zero",
        options=[True, False],
    ),
    value=init_floor_at_zero,
)

# highlight the moused-over site across all facets
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# base chart of site value vs site, filtered to the selected statistic /
# flooring; sites ordered by sequential numbering
per_selection_site_escape_chart_base = (
    alt.Chart(site_escape_df)
    .encode(
        x=alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelOverlap=True),
            scale=alt.Scale(nice=False, zero=False),
        ),
        y=alt.Y("site_val", title=assay),
        color="epitope",
        tooltip=[
            "site",
            alt.Tooltip("site_val", format=".2f", title=assay),
        ],
    )
    .properties(width=750, height=85)
    .add_params(site_statistic_selection, floor_selection, site_selection)
    .transform_filter(site_statistic_selection)
    .transform_filter(floor_selection)
)

# overlay thin lines with points; points enlarge and get an orange stroke
# when their site is moused over
per_selection_site_escape_chart_lines = per_selection_site_escape_chart_base.mark_line(
    size=0.75
)
per_selection_site_escape_chart_points = per_selection_site_escape_chart_base.encode(
    size=alt.condition(site_selection, alt.value(75), alt.value(30)),
    strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
).mark_circle(filled=True, stroke="orange")

# facet the overlaid chart one row per selection
per_selection_escape_chart = (
    (per_selection_site_escape_chart_lines + per_selection_site_escape_chart_points)
    .facet(
        facet=alt.Facet(
            "selection",
            title=None,
            header=alt.Header(labelPadding=0),
        ),
        columns=1,
        spacing=5,
    )
    .configure_axis(grid=False)
)

per_selection_escape_chart
Plot standard deviation in escape versus escape. This is useful to look at if you have imposed a filter on the standard deviation of the escape to filter mutations with large variation among selections.
# Scatter of the across-model standard deviation of escape versus the
# averaged escape for each mutation, with sliders to filter on times_seen
# and to flag mutations above a standard-deviation threshold.
escape_plot_kwargs = params["escape_plot_kwargs"]

# which average (eg, "mean" or "median") of the per-model escapes to plot
try:
    avg_type = escape_plot_kwargs["avg_type"]
except KeyError:
    avg_type = avg_model.default_avg_to_plot

std_col = f"{assay} std"
avg_col = f"{assay} {avg_type}"
std_df = avg_model.mut_escape_df_w_model_values.rename(
    columns={f"escape_{avg_type}": avg_col, "escape_std": std_col}
)
# replace dots in column names, which cause problems in tooltips, with underscore
std_df = std_df.rename(columns={c: c.replace(".", "_") for c in std_df.columns})

# slider filtering on minimum times_seen, initialized from the plot config
try:
    init_times_seen = escape_plot_kwargs["addtl_slider_stats"]["times_seen"]
except KeyError:
    init_times_seen = 1
times_seen_slider = alt.param(
    value=init_times_seen,
    bind=alt.binding_range(
        name="minimum times_seen",
        min=1,
        # BUG FIX: the slider range must span the times_seen values; the
        # previous code used the escape standard-deviation column
        # (`std_df[std_col]`), whose 0.9 quantile is typically below the
        # slider's min of 1, making the slider unusable.
        max=std_df["times_seen"].quantile(0.9),
    ),
)

# slider for the maximum standard deviation; initial value comes from the
# configured slider stats (either key spelling) or defaults to the max
try:
    init_std = escape_plot_kwargs["addtl_slider_stats"]["escape_std"]
except KeyError:
    try:
        init_std = escape_plot_kwargs["addtl_slider_stats"][f"{assay}_std"]
    except KeyError:
        init_std = std_df[std_col].max()
std_slider = alt.param(
    value=init_std,
    bind=alt.binding_range(
        name=f"maximum {std_col}",
        min=0,
        max=std_df[std_col].max(),
    ),
)

std_chart = (
    alt.Chart(std_df)
    .add_params(std_slider, times_seen_slider)
    .transform_filter(alt.datum["times_seen"] >= times_seen_slider)
    # flag (rather than remove) points above the std threshold so they stay
    # visible but are colored differently
    .transform_calculate(above_max_std=alt.datum[std_col] > std_slider)
    .encode(
        alt.X(std_col),
        alt.Y(avg_col),
        alt.Color(
            "above_max_std:N",
            scale=alt.Scale(domain=[False, True]),
            legend=alt.Legend(orient="bottom", symbolOpacity=1),
        ),
        # show float columns with 3 significant digits in tooltips
        tooltip=[
            alt.Tooltip(c, format=".3g") if std_df[c].dtype == float else c
            for c in std_df.columns
        ],
    )
    .mark_circle(opacity=0.2, strokeOpacity=1, stroke="black", strokeWidth=0.5)
    .resolve_scale(x="independent", y="independent")
    .properties(width=250, height=250)
    .configure_axis(grid=False)
)

std_chart