Escape at key sites: logo plots and binding / escape correlations
Make logo plots of serum escape at key sites, and look at relationship between escape and other phenotypes like ACE2 binding.
First get input files / parameters from papermill and import Python modules:
# this cell is tagged as `parameters` for papermill parameterization
# Every value here is a `None` placeholder; actual values are expected to be
# supplied by papermill when the notebook is executed.
dms_csv = None  # CSV of averaged DMS measurements (read with `pd.read_csv`)
per_antibody_csv = None  # CSV of per-antibody escape values (read with `pd.read_csv`)
pango_consensus_seqs_json = None  # NOTE(review): not referenced in the visible part of this notebook
codon_seq = None  # FASTA file with the gene's codon sequence (read with `Bio.SeqIO.read`)
logoplot_subdir = None  # directory (created if missing) where the logo-plot SVGs are saved
# Parameters
# Concrete values for this run; they replace the `None` placeholders assigned earlier.
# NOTE(review): `pango_consensus_seqs_json` is not set here, so it remains `None`.
dms_csv = "results/summaries/summary.csv"
per_antibody_csv = "results/summaries/per_antibody_escape.csv"
codon_seq = "data/XBB_1_5_spike_codon.fa"
logoplot_subdir = "results/key_sites/logoplots"
import itertools
import os

import altair as alt
import Bio.Seq
import Bio.SeqIO
import dmslogo
import dmslogo.colorschemes
import matplotlib
import matplotlib.pyplot as plt
import numpy
import pandas as pd
# make sure the directory that will hold the logo-plot SVGs exists
os.makedirs(logoplot_subdir, exist_ok=True)

# configure how matplotlib handles fonts in saved SVG files
plt.rcParams.update({"svg.fonttype": "none"})

# lift altair's limit on the number of rows a chart may embed
_ = alt.data_transformers.disable_max_rows()
Read input data
Keep only mutations with all phenotypes measured:
# Load the averaged DMS measurements, shorten two of the column names, and keep
# only the mutations for which sera escape, cell entry, and ACE2 binding are all
# non-null.
dms_column_renames = {
    "human sera escape": "sera escape",
    "spike mediated entry": "cell entry",
}
dms_df = pd.read_csv(dms_csv).rename(columns=dms_column_renames)
dms_df = dms_df.query(
    "`sera escape`.notnull() and `cell entry`.notnull() and `ACE2 binding`.notnull()"
)
# read per antibody values, merge with averages to create escape_df
per_antibody_df = pd.read_csv(per_antibody_csv)
assert per_antibody_df["antibody_set"].nunique() == 1, "code expects 1 antibody_set"

# The two inputs must share exactly the mutation-identifier columns, because the
# merge below joins on whatever columns the frames have in common.
if (
    (intersection := set(dms_df.columns).intersection(per_antibody_df.columns))
    != {"site", "wildtype", "mutant"}
):
    raise ValueError(f"unexpected {intersection=}")

# "average" is reserved as the label for the rows taken from `dms_df`. The check
# must look at the column *values*: `"average" in series` tests the index labels
# instead, so it could never detect a clashing antibody name.
assert not per_antibody_df["antibody"].eq("average").any(), (
    "'average' is reserved and cannot be used as an antibody name"
)

# Stack the averaged escape (labeled as antibody "average") on top of the
# per-antibody escape, then attach the remaining DMS phenotypes to every row.
escape_df = (
    pd.concat(
        [
            dms_df[["site", "wildtype", "mutant", "sera escape"]]
            .rename(columns={"sera escape": "escape"})
            .assign(antibody="average"),
            per_antibody_df.drop(columns="antibody_set"),
        ],
        ignore_index=True,
    )
    # `validate` raises if a mutation appears more than once in `dms_df`
    .merge(dms_df.drop(columns="sera escape"), validate="many_to_one")
    # label such as "F456" (wildtype residue + site number)
    .assign(wildtype_site=lambda x: x["wildtype"] + x["site"].astype(str))
)
Determine key sites of strongest escape
Get the key sites with the most site escape, and plot their site escape values in an interactive chart:
# Specification of how to choose sites: for each site metric, the largest rank
# that is still kept, given separately for the per-site maximum over individual
# antibodies ("any antibody") and for the average across antibodies.
rank_cutoffs = {"any antibody": 5, "average of antibodies": 15}
key_sites_by_rank = {"total_positive_escape": rank_cutoffs}

# sites used in neuts: manually specified sites that are always kept (none here)
key_sites_manual = list()
# Summarize escape per site, both for the averages and across all individual
# antibodies:
#   1. total the escape of all mutations at each site for each antibody
#      (absolute, positive-only, and negative-only totals);
#   2. within each group ("average of antibodies" / "any antibody") keep the
#      largest per-site total over the antibodies in that group;
#   3. reshape to long form with one row per group / site / metric;
#   4. rank the sites (1 = largest value) within each group / metric.
site_id_cols = ["is_average", "site", "sequential_site"]
site_metrics = ["total_abs_escape", "total_positive_escape", "total_negative_escape"]

per_antibody_site_totals = (
    escape_df
    .assign(
        is_average=lambda x: numpy.where(
            x["antibody"] == "average", "average of antibodies", "any antibody"
        ),
    )
    .groupby(["is_average", "antibody", "site", "sequential_site"], as_index=False)
    .aggregate(
        total_abs_escape=("escape", lambda s: s.abs().sum()),
        total_positive_escape=("escape", lambda s: s.clip(lower=0).sum()),
        total_negative_escape=("escape", lambda s: s.clip(upper=0).abs().sum()),
    )
)

site_escape_df = (
    per_antibody_site_totals
    .groupby(site_id_cols, as_index=False)
    .aggregate({metric: "max" for metric in site_metrics})
    .melt(id_vars=site_id_cols, var_name="site metric", value_name="site escape")
)

# ties share the smallest rank of the tied group (`method="min"`)
site_escape_df["rank"] = (
    site_escape_df
    .groupby(["is_average", "site metric"])["site escape"]
    .rank(ascending=False, method="min")
    .astype(int)
)
# Collect the key sites: start from the manually specified ones, then add every
# site whose rank meets the cutoff for each (site metric, group) combination.
print(f"Keeping the following manually specified sites: {key_sites_manual}")
key_sites = set(key_sites_manual)

rank_criteria = (
    (metric_name, group_name, cutoff)
    for metric_name, cutoffs_by_group in key_sites_by_rank.items()
    for group_name, cutoff in cutoffs_by_group.items()
)
# the loop variables keep these exact names because the query strings refer to them
for site_metric, is_average, rank in rank_criteria:
    meeting_criterion = (
        site_escape_df
        .query("`site metric` == @site_metric")
        .query("is_average == @is_average")
        .query("rank <= @rank")
    )
    new_sites = set(meeting_criterion["site"])
    print(f"Adding sites with {site_metric} / {is_average} rank <= {rank}: {new_sites}")
    key_sites |= new_sites

print(f"Overall keeping the following {len(key_sites)} sites: {key_sites}")

# flag the chosen sites so they can be highlighted in the chart
site_escape_df["key_site"] = site_escape_df["site"].isin(key_sites)
# Interactive chart of the site escape values, with the kept sites colored.
# Hovering over a point selects its site; a dropdown picks the site metric shown.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

site_metric_dropdown = alt.binding_select(
    name="site metric",
    options=site_escape_df["site metric"].unique(),
)
site_metric_selection = alt.selection_point(
    fields=["site metric"],
    value="total_positive_escape",
    bind=site_metric_dropdown,
)

# show every column in the tooltip, formatting float columns to two decimals
site_escape_tooltips = []
for c in site_escape_df.columns:
    if site_escape_df[c].dtype == float:
        site_escape_tooltips.append(alt.Tooltip(c, format=".2f"))
    else:
        site_escape_tooltips.append(c)

site_escape_chart = (
    alt.Chart(site_escape_df)
    .mark_circle(stroke="black")
    .add_params(site_selection, site_metric_selection)
    .transform_filter(site_metric_selection)
    .encode(
        x=alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            scale=alt.Scale(nice=False, zero=False),
        ),
        y=alt.Y("site escape"),
        color=alt.Color("key_site"),
        row=alt.Row("is_average", title=None),
        tooltip=site_escape_tooltips,
        # emphasize the site currently under the mouse
        strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
        opacity=alt.condition(site_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(site_selection, alt.value(70), alt.value(30)),
    )
    .configure_axis(grid=False)
    .resolve_scale(y="independent")
    .properties(
        width=600,
        height=150,
        title="Escape at each site for average of antibodies or max for any antibody",
    )
)

site_escape_chart
Keeping the following manually specified sites: []
Adding sites with total_positive_escape / any antibody rank <= 5: {420, 357, 485, 456, 473}
Adding sites with total_positive_escape / average of antibodies rank <= 15: {352, 450, 420, 357, 421, 455, 200, 456, 234, 371, 375, 440, 473, 475, 447}
Overall keeping the following 16 sites: {450, 455, 456, 200, 473, 475, 352, 420, 357, 485, 421, 234, 371, 375, 440, 447}
Draw logo plots for key sites of strongest escape
First get key sites and assign colors by ACE2 affinity:
# Restrict to the key sites and drop the '*' and '-' mutant characters
# (presumably stop codons and deletions -- confirm against the input data).
key_sites_df = (
    escape_df
    .query("site in @key_sites")
    .query("mutant not in ['*', '-']")
)

# for coloring by ACE2: the scale runs from the lowest ACE2 binding value at the
# key sites (but no lower than -2) up to 0
ace2_colormap = dmslogo.colorschemes.ValueToColorMap(
    minvalue=max(-2, key_sites_df["ACE2 binding"].min()),
    maxvalue=0,
    cmap="YlOrBr",
)

# Add the color with `.assign`, which returns a new frame, rather than setting a
# column on the result of `.query`; the latter is chained assignment and makes
# pandas emit a SettingWithCopyWarning. Values outside the color scale are
# clipped to its ends before being converted to colors.
key_sites_df = key_sites_df.assign(
    color=lambda x: (
        x["ACE2 binding"]
        .clip(lower=ace2_colormap.minvalue, upper=ace2_colormap.maxvalue)
        .map(ace2_colormap.val_to_color)
    ),
)
# draw the ACE2 color scale bar in each orientation, show it, and save it as SVG
for orientation in ("horizontal", "vertical"):
    svg = os.path.join(logoplot_subdir, f"key_sites_ace2_scalebar_{orientation}.svg")
    fig, _ = ace2_colormap.scale_bar(orientation=orientation, label="ACE2 binding")
    display(fig)
    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")
    # close the figure once it has been saved
    plt.close(fig)
Saving to results/key_sites/logoplots/key_sites_ace2_scalebar_horizontal.svg
Saving to results/key_sites/logoplots/key_sites_ace2_scalebar_vertical.svg
Determine which mutations at the key sites are single-nucleotide accessible:
# Map every codon to the set of amino acids encoded by the codons that differ
# from it at no more than one nucleotide position. The codon's own amino acid is
# always included, because "replacing" a nucleotide by itself is not excluded.
# `Bio.Seq` is used here directly, so it should be imported explicitly rather
# than relying on `import Bio.SeqIO` loading it as a side effect.
nts = "ACGT"
codon_to_aas = {
    codon: {
        str(Bio.Seq.Seq(codon[:i] + nt + codon[i + 1:]).translate())
        for i in range(len(codon))
        for nt in nts
    }
    for codon in map("".join, itertools.product(nts, repeat=3))
}
# upper-cased codon sequence of the gene, read from the single-record FASTA file
gene = str(Bio.SeqIO.read(codon_seq, "fasta").seq).upper()

# codon at each mutation's site, taken from the gene by 1-based `sequential_site`
key_sites_df = key_sites_df.assign(
    codon=key_sites_df["sequential_site"].map(lambda r: gene[3 * (r - 1): 3 * r])
)
# amino acid encoded by that codon (checked against `wildtype` below)
key_sites_df = key_sites_df.assign(
    codon_translated=key_sites_df["codon"].map(
        lambda c: str(Bio.Seq.Seq(c).translate())
    )
)
# whether the mutant amino acid is among those listed for the codon in `codon_to_aas`
key_sites_df = key_sites_df.assign(
    single_nt_accessible=key_sites_df.apply(
        lambda r: r["mutant"] in codon_to_aas[r["codon"]],
        axis=1,
    )
)

# Stack two labeled copies of the data: only the accessible mutations, and all
# mutations. The boolean flag is replaced by a descriptive label for the plots.
accessible_subset = key_sites_df.query("single_nt_accessible").assign(
    single_nt_accessible="single-nucleotide accessible"
)
all_mutations = key_sites_df.assign(single_nt_accessible="all measured mutations")
key_sites_df = pd.concat([accessible_subset, all_mutations], ignore_index=True)

# sanity check: the codon sequence must encode the recorded wildtype residues
assert (key_sites_df["wildtype"] == key_sites_df["codon_translated"]).all()
Plots for averages across sera, for all mutations and just single-nucleotide accessible ones:
# options handed through to the logo drawing for every logo plot in this notebook
draw_logo_kwargs = {
    "letter_height_col": "escape",
    "letter_col": "mutant",
    "color_col": "color",
    "xtick_col": "wildtype_site",
    "xlabel": "",
    "clip_negative_heights": True,
}

# escape averaged across sera, one row of logos per mutation set
avg_escape_df = key_sites_df.query("antibody == 'average'")
svg = os.path.join(logoplot_subdir, "avg_sera_escape_at_key_sites.svg")

fig, _ = dmslogo.facet_plot(
    data=avg_escape_df,
    x_col="sequential_site",
    show_col=None,
    gridrow_col="single_nt_accessible",
    share_ylim_across_rows=False,
    draw_logo_kwargs=draw_logo_kwargs,
    height_per_ax=2.4,
    hspace=0.6,
)
display(fig)

print(f"Saving to {svg}")
fig.savefig(svg, bbox_inches="tight")
plt.close(fig)
Saving to results/key_sites/logoplots/avg_sera_escape_at_key_sites.svg
Now make plots for all individual sera, both with all and only single-nucleotide accessible mutations:
# one figure per mutation set, with one row of logos for each individual serum
for single_nt_accessible, df in key_sites_df.groupby("single_nt_accessible"):
    print(f"\n{single_nt_accessible=}")

    # output file is named after the mutation set, with spaces made file-name friendly
    svg = os.path.join(
        logoplot_subdir,
        f"all_sera_escape_at_key_sites_{single_nt_accessible.replace(' ', '_')}.svg",
    )

    fig, axes = dmslogo.facet_plot(
        data=df.query("antibody != 'average'"),
        x_col="sequential_site",
        show_col=None,
        gridrow_col="antibody",
        share_ylim_across_rows=False,
        draw_logo_kwargs=draw_logo_kwargs,
        height_per_ax=2.1,
        hspace=0.6,
    )
    display(fig)

    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")
    plt.close(fig)
single_nt_accessible='all measured mutations'
Saving to results/key_sites/logoplots/all_sera_escape_at_key_sites_all_measured_mutations.svg
single_nt_accessible='single-nucleotide accessible'
Saving to results/key_sites/logoplots/all_sera_escape_at_key_sites_single-nucleotide_accessible.svg
Sites of neutralization assay mutations
Now make plots that show logo plots at sites of mutations analyzed in neutralization assays for sera in those assays:
# sera for which the logo plots at these sites are drawn
sera_in_neuts = ["serum 287C", "serum 500C", "serum 501C"]

# mutations, written as <wildtype residue><site number><mutant residue>
muts_in_neuts = [
    "V42F", "Y200C", "N234T", "R357H", "R403K",
    "N405K", "D420N", "K444M", "L455F", "F456L",
    "Y473S", "T572K", "A852V",
]

# site numbers: each label with its first and last character removed
sites_in_neuts = [int(label[1:-1]) for label in muts_in_neuts]
# escape of the selected sera at the selected sites, without the '*' and '-'
# mutant characters
neuts_df = (
    escape_df
    .query("site in @sites_in_neuts")
    .query("mutant not in ['*', '-']")
    .query("antibody in @sera_in_neuts")
)

# full mutation label, e.g. wildtype + site + mutant
neuts_df = neuts_df.assign(
    mutation=neuts_df["wildtype"] + neuts_df["site"].astype(str) + neuts_df["mutant"],
)

# highlight the listed mutations in blue; every other mutation is gray
neuts_df = neuts_df.assign(
    color=numpy.where(neuts_df["mutation"].isin(muts_in_neuts), "blue", "gray"),
)
# same logo options as before, except that negative letter heights are not clipped
neut_logo_kwargs = {**draw_logo_kwargs, "clip_negative_heights": False}
svg = os.path.join(logoplot_subdir, "neut_sites.svg")

# one row of logos per serum at the selected sites
fig, axes = dmslogo.facet_plot(
    data=neuts_df,
    x_col="sequential_site",
    show_col=None,
    gridrow_col="antibody",
    share_ylim_across_rows=False,
    draw_logo_kwargs=neut_logo_kwargs,
    height_per_ax=2.9,
    hspace=0.6,
)
display(fig)

print(f"Saving to {svg}")
fig.savefig(svg, bbox_inches="tight")
plt.close(fig)
Saving to results/key_sites/logoplots/neut_sites.svg