Compare mutation effects on ACE2 binding vs sera escape at key sites¶
This notebook compares how different mutations affect ACE2 binding versus escape at key sites.
In [1]:
# this cell is tagged as parameters for `papermill` parameterization
# Every parameter the notebook reads is declared here with a `None` placeholder;
# the real values are assigned by the parameters cell that follows.
dms_csv = None
logoplot_subdir = None
RBD_up_down_subdir = None
min_cell_entry = None
min_mutations_at_site = None
RBD_up_down_chart_html = None
RBD_up_down_csv = None
In [2]:
# Parameters
# Concrete values for this run (presumably injected by papermill); each
# assignment overrides any `None` placeholder of the same name set in In [1].
dms_csv = "results/summaries/all_adult_sera_escape.csv"
logoplot_subdir = "results/binding_vs_escape/logoplots"
RBD_up_down_subdir = "results/RBD_up_down"
min_cell_entry = -2
min_mutations_at_site = 5
RBD_up_down_chart_html = "results/binding_vs_escape/RBD_up_down_chart_html.html"
RBD_up_down_csv = "results/RBD_up_down/RBD_up_down_sites.csv"
In [3]:
import math
import os
import tempfile
import urllib.request
import altair as alt
import dmslogo
import matplotlib
import matplotlib.pyplot as plt
import numpy
import palettable
import pandas as pd
import polyclonal.pdb_utils
# keep text as text elements (not outlined paths) in saved SVG figures
plt.rcParams['svg.fonttype'] = 'none'

# Create every directory the notebook writes into. The chart HTML and the
# per-site CSV are written by later cells, so their parent directories are
# created explicitly instead of relying on them coinciding with the subdirs.
output_dirs = [
    logoplot_subdir,
    RBD_up_down_subdir,
    os.path.dirname(RBD_up_down_chart_html),
    os.path.dirname(RBD_up_down_csv),
]
for output_dir in output_dirs:
    if output_dir:  # dirname is "" for a bare filename in the working directory
        os.makedirs(output_dir, exist_ok=True)
/fh/fast/bloom_j/computational_notebooks/bdadonai/2025/SARS-CoV-2_KP.3.1.1_spike_DMS/.snakemake/conda/92ba7412cf55ee5d47c61c431d1bed6f_/lib/python3.12/site-packages/dmslogo/logo.py:27: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. import pkg_resources
Read input data¶
In [4]:
# Load the per-mutation summary, shorten two column names, and keep only
# amino-acid mutations that have all three measurements and adequate cell entry.
dms_df = (
    pd.read_csv(dms_csv)
    .rename(
        columns={
            "adult sera escape": "sera escape",
            "spike mediated entry": "cell entry",
        }
    )
    .dropna(subset=["sera escape", "cell entry", "ACE2 binding"])
    .loc[lambda d: d["cell entry"] >= min_cell_entry]
    .loc[lambda d: ~d["mutant"].isin(["*", "-"])]  # exclude stop and gap
    .assign(
        mutation=lambda d: d["wildtype"] + d["site"].astype(str) + d["mutant"],
        # counted after the filters above, so only retained mutations contribute
        n_mutations_at_site=lambda d: d.groupby("site")["mutant"].transform("count"),
    )
    .reset_index(drop=True)
)
Calculate correlation between ACE2 binding and escape for each site¶
In [5]:
# Per-site correlation between sera escape and ACE2 binding.
# `groupby(...).corr()` yields one 2x2 matrix per site; take each matrix's
# "sera escape" row and read off its "ACE2 binding" entry.
correlation_df = (
    dms_df
    .groupby("site")[["sera escape", "ACE2 binding"]]
    .corr()
    .xs("sera escape", level=1)
    ["ACE2 binding"]
    .rename("correlation")
    .dropna()
    .reset_index()
)

# Attach the per-site correlation to every mutation row. The default inner
# merge drops mutations at sites whose correlation is undefined (NaN).
dms_df = dms_df.merge(correlation_df, validate="many_to_one")
In [6]:
def _root_mean_square(values):
    """Return the root-mean-square of `values`, i.e. sqrt(mean(values**2))."""
    return numpy.sqrt(numpy.mean(values**2))


# per-site root-mean-square effect of mutations on ACE2 binding
rms_binding = (
    dms_df.groupby("site")["ACE2 binding"].apply(_root_mean_square).rename("RMS_binding")
)

# per-site root-mean-square effect of mutations on sera escape
rms_escape = (
    dms_df.groupby("site")["sera escape"].apply(_root_mean_square).rename("RMS_escape")
)

# broadcast both per-site values back onto every mutation row
dms_df = dms_df.merge(rms_binding, on="site").merge(rms_escape, on="site")
In [7]:
# R x RMS: the per-site correlation scaled by the RMS magnitudes of both effects
dms_df["RxRMS_be"] = dms_df["correlation"] * dms_df["RMS_binding"] * dms_df["RMS_escape"]

# Keep only the negative (anti-correlated) part and flip its sign, so larger
# values mean a stronger inverse relationship; every other site becomes 0.
# The sign flip happens *inside* `where`: negating after zeroing would turn
# each zeroed site into -0.0, which is then written verbatim to the CSV.
dms_df["RxRMS_be_floored"] = (-dms_df["RxRMS_be"]).where(dms_df["RxRMS_be"] < 0, 0.0)

dms_df = dms_df.sort_values(by="sequential_site", ascending=True)
In [8]:
# one row per site: keep the first mutation row encountered for each site
unique_sites_df = dms_df.drop_duplicates(subset=["site"], keep="first")
In [9]:
# Define RBM sites
rbm_sites = {413, 442, 449, 451, 452, 471, 481, 482, 484, 489, 493, 495, 496, 497, 500}

# (label, lower bound, is lower bound inclusive?, inclusive upper bound);
# checked in order, first match wins
_REGION_BOUNDS = (
    ("NTD", 13, True, 301),
    ("RBD", 327, True, 523),
    ("SD1", 523, False, 586),
    ("SD2", 586, False, 681),
    ("S2", 681, False, 1206),
)


def assign_region(seq_site):
    """Return the region label for a site number.

    Sites listed in `rbm_sites` are labelled "RBM" before any range is
    checked. Otherwise the label is that of the first matching range:
    NTD 13-301, RBD 327-523, SD1 524-586, SD2 587-681, S2 682-1206.
    Anything else (below 13, 302-326, above 1206) is labelled "other".
    """
    if seq_site in rbm_sites:
        return "RBM"
    for label, lower, lower_inclusive, upper in _REGION_BOUNDS:
        above_lower = seq_site >= lower if lower_inclusive else seq_site > lower
        if above_lower and seq_site <= upper:
            return label
    return "other"
# Annotate each site with its region label and an RBM / non-RBM flag,
# both derived from the sequential site number.
unique_sites_df = unique_sites_df.assign(
    region_from_site=lambda d: d["sequential_site"].map(assign_region),
    is_RBM=lambda d: d["sequential_site"].map(
        lambda seq_site: "RBM" if seq_site in rbm_sites else "non-RBM"
    ),
)
In [10]:
# Per-site summary columns written to the CSV.
# NOTE(review): this saves the input data's "region" column, not the
# "region_from_site" label computed in this notebook -- confirm that is intended.
columns_to_save = [
    "site", "region", "n_mutations_at_site",
    "correlation", "RMS_binding", "RMS_escape",
    "RxRMS_be", "RxRMS_be_floored",
    "is_RBM",
]

# Save to CSV
unique_sites_df.loc[:, columns_to_save].to_csv(RBD_up_down_csv, index=False)
In [11]:
# Slider: minimum number of measured mutations a site needs to be shown
n_mut_binding = alt.binding_range(
    name="Minimum mutations at site",
    min=2,
    max=int(unique_sites_df["n_mutations_at_site"].max()),
    step=1,
)
n_mut_slider = alt.param(name="nMutSlider", value=3, bind=n_mut_binding)

# Slider: maximum correlation a site may have to be shown (1 shows all sites)
max_corr_binding = alt.binding_range(
    name="only show sites with correlation r less than this",
    min=-1,
    max=1,
    step=0.01,
)
max_corr_slider = alt.param(value=1, bind=max_corr_binding)

# Tick values every 20 sequential sites, starting from the first site present
first_seq_site = unique_sites_df["sequential_site"].min()
last_seq_site = unique_sites_df["sequential_site"].max()
tick_values = numpy.arange(first_seq_site, last_seq_site + 1, 20).tolist()

# Base chart shared by the layers below; applies the minimum-mutations filter
base = (
    alt.Chart(unique_sites_df)
    .add_params(n_mut_slider)
    .transform_filter(alt.datum["n_mutations_at_site"] >= n_mut_slider)
)
# `site` labels to draw on the x-axis: the sites whose sequential site number
# is one of the tick values
site_tick_labels = unique_sites_df.loc[
    unique_sites_df["sequential_site"].isin(tick_values), "site"
].tolist()

# sites are plotted as ordered categories, sorted by their sequential number
line_x = alt.X(
    "site:N",
    sort=alt.EncodingSortField(field="sequential_site", order="ascending"),
    axis=alt.Axis(title="Site", values=site_tick_labels, labelAngle=-45),
)
line_y = alt.Y(
    "RxRMS_be_floored:Q",
    title="estimated effect on RBD up/down motion",
    axis=alt.Axis(grid=False),
)
line_tooltip = [
    alt.Tooltip("site:N"),
    alt.Tooltip("sequential_site:Q"),
    alt.Tooltip("RxRMS_be_floored:Q", format=".2f"),
    alt.Tooltip("RMS_binding:Q", format=".2f"),
    alt.Tooltip("RMS_escape:Q", format=".2f"),
    "n_mutations_at_site:Q",
    "region_from_site:N",
]

line = base.mark_line().encode(
    x=line_x,
    y=line_y,
    color=alt.value("#607d8b"),  # fixed line color
    tooltip=line_tooltip,
)
# highlight RBM sites in red; every other site uses the line's blue-grey
rbm_highlight = alt.condition(
    alt.datum["is_RBM"] == "RBM",
    alt.value("#d1615d"),
    alt.value("#607d8b"),
)
points_tooltip = [
    "site:N",
    alt.Tooltip("correlation:Q", format=".2f"),
    alt.Tooltip("RMS_binding:Q", format=".2f"),
    alt.Tooltip("RxRMS_be_floored:Q", format=".2f"),
    alt.Tooltip("RMS_escape:Q", format=".2f"),
    "n_mutations_at_site:Q",
    "region_from_site:N",
]

# one filled marker per site, drawn on top of the line
points = base.mark_point(filled=True, size=60).encode(
    x=alt.X(
        "site:N",
        sort=alt.EncodingSortField(field="sequential_site", order="ascending"),
    ),
    y=alt.Y("RxRMS_be_floored:Q", axis=alt.Axis(grid=False)),
    color=rbm_highlight,
    tooltip=points_tooltip,
)
# Interval selection along x; dragging on the region bar defines the window
brush = alt.selection_interval(encodings=["x"])

# Thin strip of rectangles, one per site and colored by region, that carries
# the brush selection
region_bar_x = alt.X(
    "site:N",
    sort=alt.EncodingSortField(field="sequential_site", order="ascending"),
    axis=None,
)
region_bar = (
    base
    .mark_rect(height=20)
    .encode(
        x=region_bar_x,
        y=alt.value(0),
        color=alt.Color("region_from_site:N", title="Region"),
    )
    .add_params(brush)
    .properties(width=1200, height=20)
)
# Main chart: line and point layers, restricted to the sites inside the brush
RxRMS_chart = (
    alt.layer(line, points)
    .transform_filter(brush)
    .properties(width=1200, height=200)
)
# Stack the region bar above the main chart and add the slider that hides
# sites whose binding/escape correlation exceeds the chosen maximum.
full_chart = (
    alt.vconcat(region_bar, RxRMS_chart)
    .resolve_scale(color="shared")
    .add_params(max_corr_slider)
    .transform_filter(alt.datum["correlation"] <= max_corr_slider)
    .properties(
        title=alt.TitleParams(
            ["Site effect on RBD up/down conformation"],
            anchor="middle",
            fontSize=16,
            dy=-5,
        )
    )
    .configure_view(stroke=None)
)

# Save chart
full_chart.save(RBD_up_down_chart_html)

# Display the chart. Only the final expression of a cell is rendered, so this
# has to come after the save call; placed before it, nothing was shown.
full_chart
Plot sites with high inverse correlation between ACE2 binding and escape¶
Plot sites with high inverse correlation of binding and escape; note the slider at the bottom can control which sites are shown:
In [12]:
# keep only sites with at least `min_mutations_at_site` measured mutations
dms_df = dms_df.loc[dms_df["n_mutations_at_site"] >= min_mutations_at_site]
In [13]:
# first make base chart
facet_size = 100  # width and height given to each per-site facet

# slider: minimum cell entry a mutation needs to be plotted
cell_entry_binding = alt.binding_range(
    name="minimum cell entry",
    min=dms_df["cell entry"].min(),
    max=0,
)
cell_entry_slider = alt.param(value=min_cell_entry, bind=cell_entry_binding)

# shared base for both layers, with the cell-entry filter applied
binding_escape_corr_base = (
    alt.Chart(dms_df)
    .add_params(cell_entry_slider)
    .transform_filter(alt.datum["cell entry"] >= cell_entry_slider)
)
# Layer 1: one translucent point per mutation, ACE2 binding vs sera escape
binding_escape_scatter = (
    binding_escape_corr_base
    .mark_circle(color="black", opacity=0.3, size=60)
    .encode(
        x=alt.X("ACE2 binding", scale=alt.Scale(nice=False, padding=6)),
        y=alt.Y("sera escape", scale=alt.Scale(nice=False, padding=6)),
        tooltip=[
            "site",
            "mutation",
            alt.Tooltip("ACE2 binding", format=".2f"),
            alt.Tooltip("sera escape", format=".2f"),
            alt.Tooltip("cell entry", format=".2f"),
        ],
    )
)

# Signed r from the regression output: sqrt(rSquared), carrying the sign of
# the fitted slope (coef[1]); a slope that is not > 0 takes the negative branch.
signed_r = alt.expr.if_(
    alt.datum["coef"][1] > 0,
    alt.expr.sqrt(alt.datum["rSquared"]),
    -alt.expr.sqrt(alt.datum["rSquared"]),
)

# Layer 2: "r = ..." text label placed near the lower-left corner of the facet
binding_escape_r_label = (
    binding_escape_corr_base
    .transform_regression("ACE2 binding", "sera escape", params=True)
    .transform_calculate(
        r=signed_r,
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .mark_text(size=12, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(3),
        y=alt.value(facet_size - 6),
    )
)

# Combine both layers and draw one small facet per site, 8 facets per row
binding_escape_corr_chart = (
    alt.layer(binding_escape_scatter, binding_escape_r_label)
    .properties(width=facet_size, height=facet_size)
    .facet(
        facet=alt.Facet(
            "site",
            title=None,
            header=alt.Header(
                labelFontSize=14,
                labelFontStyle="italic",
                labelPadding=0,
                labelExpr="'site ' + datum.label",
            ),
        ),
        spacing=8,
        columns=8,
    )
    .configure_axis(grid=False)
)
# now make chart filtered for strongly negative correlations
# (this rebinds `max_corr_slider`, with an initial cutoff of -0.82)
max_corr_slider = alt.param(
    value=-0.82,
    bind=alt.binding_range(
        name="only show sites with correlation r less than this",
        min=-1,
        max=1,
        step=0.01,
    ),
)

neg_corr_title = alt.TitleParams(
    "Correlation of ACE2 binding and escape filtered by extent of negative correlation",
    anchor="middle",
    fontSize=16,
    dy=-5,
)

# same faceted chart, keeping only sites whose precomputed `correlation`
# column is at or below the slider value
binding_escape_neg_corr_chart = (
    binding_escape_corr_chart
    .properties(title=neg_corr_title, autosize=alt.AutoSizeParams(resize=True))
    .add_params(max_corr_slider)
    .transform_filter(alt.datum["correlation"] <= max_corr_slider)
)

binding_escape_neg_corr_chart
Out[13]:
We now plot the same correlation for sites of strong escape¶
We manually specify some sites of strong escape:
In [14]:
# manually chosen sites of strong escape
# NOTE(review): these are integers; the filter only matches if the chart's
# `site` values are numeric rather than strings -- confirm against the data.
escape_sites = [50, 132, 200, 222, 332, 344, 357, 393, 428, 440, 458, 470, 475, 478, 505, 518, 572, 852]

high_escape_title = alt.TitleParams(
    "Correlation of ACE2 binding and escape for sites of strong escape",
    anchor="middle",
    fontSize=16,
    dy=-5,
)

# same faceted chart, restricted to the sites listed above
binding_escape_high_escape_corr_chart = (
    binding_escape_corr_chart
    .properties(title=high_escape_title, autosize=alt.AutoSizeParams(resize=True))
    .transform_filter(alt.FieldOneOfPredicate("site", escape_sites))
)

binding_escape_high_escape_corr_chart
Out[14]:
Plot sites of top escape for mutations in different regions.¶
We plot both binding-escape correlation plots and logo plots for the sites with the most escaping mutations.
We stratify sites by:
- RBD ACE2 proximal
- RBD ACE2 distal
- non-RBD
First get the distance of each RBD site from ACE2, then use that distance to separate ACE2-proximal from ACE2-distal sites:
In [15]:
ace2_proximal_cutoff = 15  # classify as ACE2 proximal if CA distance <= this

# Download the 6M0J structure into a temporary file and extract the CA
# coordinates while that file is still open.
# chain A is ACE2, chain E is RBD
with tempfile.NamedTemporaryFile() as f:
    urllib.request.urlretrieve(
        "https://files.rcsb.org/download/6M0J.pdb",
        f.name,
    )
    coords_df = polyclonal.pdb_utils.extract_atom_locations(f.name, ["A", "E"], target_atom="CA")

# get closest distance for each residue in chain E (RBD) to residue in chain A (ACE2)
dist_df = (
    coords_df
    .query("chain == 'E'")
    [["site", "x", "y", "z"]]
    .merge(
        (
            coords_df
            .query("chain == 'A'")
            [["site", "x", "y", "z"]]
            .rename(columns={c: f"ACE2_{c}" for c in ["site", "x", "y", "z"]})
        ),
        how="cross",  # every RBD residue paired with every ACE2 residue
    )
    .assign(
        # Euclidean distance for all residue pairs at once (vectorized instead
        # of a row-wise Python `apply`); coordinates are cast to float64 so the
        # arithmetic does not depend on the dtype the PDB parser returns.
        distance=lambda x: numpy.sqrt(
            sum(
                (x[c].astype(float) - x[f"ACE2_{c}"].astype(float)) ** 2
                for c in ["x", "y", "z"]
            )
        )
    )
    .groupby("site", as_index=False)
    .aggregate({"distance": "min"})
)

# The numeric-site filter and the per-region annotation that used to be
# repeated at the end of this cell are performed by the next two cells.
In [16]:
# Keep only rows where 'site' is fully numeric, then store 'site' as int
dms_df = (
    dms_df
    .loc[lambda d: d["site"].astype(str).str.match(r"^\d+$")]
    .assign(site=lambda d: d["site"].astype(int))
)
In [17]:
# Add each site's minimum CA distance to ACE2 and relabel `region` into three
# classes. The merge is a left join, so sites without a distance get NaN; NaN
# fails the `<=` test, so an RBD site lacking a distance is labelled distal.
dms_df_by_region = (
    dms_df
    .merge(dist_df, how="left", validate="many_to_one")
    .assign(
        region=lambda x: numpy.select(
            [
                (x["region"] == "RBD") & (x["distance"] <= ace2_proximal_cutoff),
                x["region"] == "RBD",
            ],
            ["RBD ACE2 proximal", "RBD ACE2 distal"],
            default="non-RBD",
        ),
    )
)
Now plot escape and binding for the sites with the most strongly escaping mutations in each region. Make both correlation plots and logo plots colored by ACE2 binding:
In [18]:
top_n = 7  # number of top-escape sites plotted per region

# for coloring by ACE2: diverging scale, symmetric about zero
ace2_color_limit = 1.5
ace2_colormap = dmslogo.colorschemes.ValueToColorMap(
    minvalue=-ace2_color_limit,
    maxvalue=ace2_color_limit,
    cmap=palettable.colorbrewer.diverging.PRGn_4.mpl_colormap,
)
assert abs(ace2_colormap.minvalue) == ace2_colormap.maxvalue, "not symmetric for diverging color scale"
# draw, display and save the ACE2-binding color scale bar in both orientations
for orientation in ("horizontal", "vertical"):
    fig, _ = ace2_colormap.scale_bar(orientation=orientation, label="ACE2 binding")
    display(fig)

    svg = os.path.join(logoplot_subdir, f"ace2_scalebar_{orientation}.svg")
    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")

    plt.close(fig)  # release the figure once it has been written
# For each region: pick the sites carrying the strongest escape mutations,
# show their binding-vs-escape correlation facets, then draw and save a logo
# plot whose letter heights are escape and whose colors are ACE2 binding.
for region, df in dms_df_by_region.groupby("region"):
    print(f"\n\nAnalyzing top sites for {region=}")

    # get sites of top escape mutations: sort mutations by escape, keep the
    # first (highest-escape) row per site in that order, take the top `top_n`
    top_escape_sites = (
        df
        .sort_values("sera escape", ascending=False)
        .groupby("site", sort=False)
        .first()
        .head(top_n)
    )
    sites = top_escape_sites.index.tolist()

    # plot correlation for these top sites
    corr_chart = (
        binding_escape_corr_chart
        .properties(
            title=alt.TitleParams(
                f"Correlation of ACE2 binding and escape for {region} sites where mutations cause strong escape",
                anchor="middle",
                fontSize=16,
                dy=-5,
            ),
            autosize=alt.AutoSizeParams(resize=True),
        )
        .transform_filter(alt.FieldOneOfPredicate("site", sites))
    )
    display(corr_chart)

    # make logo plot
    fig, _ = dmslogo.draw_logo(
        data=df[df["site"].isin(sites)].rename(columns={"sera escape": "escape"}).assign(
            wildtype_site=lambda x: x["wildtype"] + x["site"].astype(str),
            # clip binding to the colormap's range before mapping it to a color
            color=lambda x: (
                x["ACE2 binding"]
                .clip(lower=ace2_colormap.minvalue, upper=ace2_colormap.maxvalue)
                .map(ace2_colormap.val_to_color)
            ),
        ),
        x_col="sequential_site",
        letter_col="mutant",
        color_col="color",
        xtick_col="wildtype_site",
        letter_height_col="escape",
        xlabel="",
        heightscale=1,
    )
    display(fig)

    svg = os.path.join(logoplot_subdir, f"{region.replace(' ', '_')}_logoplot.svg")
    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")

    # close the figure once it is saved (a second, redundant close was removed)
    plt.close(fig)
Saving to results/binding_vs_escape/logoplots/ace2_scalebar_horizontal.svg
Saving to results/binding_vs_escape/logoplots/ace2_scalebar_vertical.svg
Analyzing top sites for region='RBD ACE2 distal'
Saving to results/binding_vs_escape/logoplots/RBD_ACE2_distal_logoplot.svg
Analyzing top sites for region='RBD ACE2 proximal'
Saving to results/binding_vs_escape/logoplots/RBD_ACE2_proximal_logoplot.svg
Analyzing top sites for region='non-RBD'
Saving to results/binding_vs_escape/logoplots/non-RBD_logoplot.svg
In [ ]: