Compare mutation effects on ACE2 binding vs. sera escape at key sites

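This notebook compares how different mutations affect ACE2 binding versus sera escape at key sites. It also combines each site's binding–escape correlation with its root-mean-square binding and escape effects into an estimated effect on RBD up/down motion.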

In [1]:
# this cell is tagged as parameters for `papermill` parameterization;
# every parameter the notebook reads is declared here so the injected
# "Parameters" cell always overrides a known default
dms_csv = None  # input CSV of per-mutation DMS measurements
logoplot_subdir = None  # output directory for logo-plot SVGs
RBD_up_down_subdir = None  # output directory for RBD up/down results
min_cell_entry = None  # drop mutations with cell entry below this
min_mutations_at_site = None  # minimum mutations per site for the faceted plots
RBD_up_down_chart_html = None  # output HTML for the RBD up/down chart
RBD_up_down_csv = None  # output CSV of per-site RBD up/down scores
In [2]:
# Parameters
# -- input data
dms_csv = "results/summaries/all_adult_sera_escape.csv"

# -- filtering thresholds
min_cell_entry = -2
min_mutations_at_site = 5

# -- output locations
logoplot_subdir = "results/binding_vs_escape/logoplots"
RBD_up_down_subdir = "results/RBD_up_down"
RBD_up_down_csv = "results/RBD_up_down/RBD_up_down_sites.csv"
RBD_up_down_chart_html = "results/binding_vs_escape/RBD_up_down_chart_html.html"
In [3]:
import math
import os
import tempfile
import urllib.request

import altair as alt

import dmslogo

import matplotlib
import matplotlib.pyplot as plt

import numpy

import palettable

import pandas as pd

import polyclonal.pdb_utils

# store text in saved SVGs as text elements rather than converting it to paths
plt.rcParams['svg.fonttype'] = 'none'

# create every output directory up front, including the parent directories
# of output files, which need not live under one of the subdirectories above
output_dirs = [
    logoplot_subdir,
    RBD_up_down_subdir,
    os.path.dirname(RBD_up_down_chart_html),
    os.path.dirname(RBD_up_down_csv),
]
for output_dir in output_dirs:
    if output_dir:  # dirname is "" for a bare filename in the working directory
        os.makedirs(output_dir, exist_ok=True)
/fh/fast/bloom_j/computational_notebooks/bdadonai/2025/SARS-CoV-2_KP.3.1.1_spike_DMS/.snakemake/conda/92ba7412cf55ee5d47c61c431d1bed6f_/lib/python3.12/site-packages/dmslogo/logo.py:27: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  import pkg_resources

Read input data

In [4]:
# read per-mutation DMS measurements, harmonize column names, and keep only
# amino-acid mutations that have all three measurements and enough cell entry
column_renames = {"adult sera escape": "sera escape", "spike mediated entry": "cell entry"}
required_measurements = ["sera escape", "cell entry", "ACE2 binding"]

dms_raw = (
    pd.read_csv(dms_csv)
    .rename(columns=column_renames)
    .dropna(subset=required_measurements)
)

keep_mutation = (
    (dms_raw["cell entry"] >= min_cell_entry)
    & ~dms_raw["mutant"].isin(["*", "-"])  # exclude stop and gap
)

dms_df = (
    dms_raw[keep_mutation]
    .assign(
        mutation=lambda frame: frame["wildtype"] + frame["site"].astype(str) + frame["mutant"],
        # count mutations per site after filtering
        n_mutations_at_site=lambda frame: frame.groupby("site")["mutant"].transform("count"),
    )
    .reset_index(drop=True)
)

Calculate correlation between ACE2 binding and escape for each site

In [5]:
# per-site Pearson correlation of sera escape with ACE2 binding:
# groupby().corr() gives one small correlation matrix per site, indexed by
# (site, column), so take each site's "sera escape" row and read the
# "ACE2 binding" entry
site_corr_matrices = dms_df.groupby("site")[["sera escape", "ACE2 binding"]].corr()
correlation_df = (
    site_corr_matrices
    .xs("sera escape", level=1)["ACE2 binding"]
    .dropna()  # sites where r is undefined
    .rename("correlation")
    .reset_index()
)

# attach each site's r to its mutations; the inner merge drops sites lacking an r
dms_df = dms_df.merge(correlation_df, validate="many_to_one")
In [6]:
def _root_mean_square(values):
    """Root-mean-square of a numeric Series of mutation effects."""
    return numpy.sqrt(numpy.mean(values**2))


# root-mean-square effect per site for ACE2 binding and for sera escape
site_groups = dms_df.groupby("site")
rms_binding = site_groups["ACE2 binding"].apply(_root_mean_square).rename("RMS_binding")
rms_escape = site_groups["sera escape"].apply(_root_mean_square).rename("RMS_escape")

# add each per-site RMS back onto every mutation at that site
for per_site_rms in (rms_binding, rms_escape):
    dms_df = dms_df.merge(per_site_rms, on="site")
In [7]:
# site score R x RMS: correlation(binding, escape) x RMS(binding) x RMS(escape)
rx_rms = dms_df["correlation"] * dms_df["RMS_binding"] * dms_df["RMS_escape"]
dms_df["RxRMS_be"] = rx_rms

# keep only non-positive scores (others become 0), then flip the sign so that
# more strongly anti-correlated sites get larger positive values
dms_df["RxRMS_be_floored"] = -rx_rms.where(rx_rms <= 0, 0)

dms_df = dms_df.sort_values(by="sequential_site", ascending=True)
In [8]:
# one row per site (first mutation row of each site); the per-site columns
# such as correlation and RMS are shared by all mutations at a site
unique_sites_df = dms_df.drop_duplicates(subset="site", keep="first")
In [9]:
# RBM sites; these are compared against `sequential_site` below
# NOTE(review): confirm these numbers use `sequential_site` numbering rather
# than the reference `site` numbering
rbm_sites = {413, 442, 449, 451, 452, 471, 481, 482, 484, 489, 493, 495, 496, 497, 500}

# (label, membership test) pairs for the spike regions, checked in order;
# boundaries are in `sequential_site` numbering
_REGION_RULES = (
    ("NTD", lambda s: 13 <= s <= 301),
    ("RBD", lambda s: 327 <= s <= 523),
    ("SD1", lambda s: 523 < s <= 586),
    ("SD2", lambda s: 586 < s <= 681),
    ("S2", lambda s: 681 < s <= 1206),
)


def assign_region(seq_site):
    """Return the spike region label for a sequential site number.

    RBM sites take precedence and are labelled "RBM". Sites outside every
    listed range (including the gap between NTD and RBD) are "other".
    """
    if seq_site in rbm_sites:
        return "RBM"
    return next(
        (label for label, contains in _REGION_RULES if contains(seq_site)),
        "other",
    )


# work on an explicit copy so the new columns are not set on a slice of dms_df
unique_sites_df = unique_sites_df.copy()
unique_sites_df["region_from_site"] = unique_sites_df["sequential_site"].apply(assign_region)

# binary RBM / non-RBM flag used to color points in the chart
unique_sites_df["is_RBM"] = unique_sites_df["sequential_site"].map(
    lambda s: "RBM" if s in rbm_sites else "non-RBM"
)
In [10]:
# per-site summary columns written to the RBD up/down CSV:
# identifiers, correlation/RMS inputs, R x RMS scores, and the RBM flag
columns_to_save = (
    ["site", "region", "n_mutations_at_site"]
    + ["correlation", "RMS_binding", "RMS_escape"]
    + ["RxRMS_be", "RxRMS_be_floored", "is_RBM"]
)

site_summary = unique_sites_df.loc[:, columns_to_save]
site_summary.to_csv(RBD_up_down_csv, index=False)
In [11]:
# Interactive chart of each site's estimated effect on RBD up/down motion
# (`RxRMS_be_floored`), with a region strip on top that can be brushed to
# restrict the main panel to a range of sites.

# slider: only show sites with at least this many mutations
n_mut_slider = alt.param(
    name="nMutSlider",
    value=3,
    bind=alt.binding_range(
        name="Minimum mutations at site",
        min=2,
        max=int(unique_sites_df["n_mutations_at_site"].max()),
        step=1
    )
)

# slider: only show sites whose binding-escape correlation is <= this value
max_corr_slider = alt.param(
    value=1,
    bind=alt.binding_range(
        name="only show sites with correlation r less than this",
        min=-1,
        max=1,
        step=0.01,
    ),
)

# axis labels at every 20th sequential site; the x axis is keyed on `site`,
# so look up the site labels found at those sequential positions
tick_values = numpy.arange(
    unique_sites_df["sequential_site"].min(),
    unique_sites_df["sequential_site"].max() + 1,
    20
).tolist()
tick_sites = unique_sites_df.loc[
    unique_sites_df["sequential_site"].isin(tick_values), "site"
].tolist()

# order sites along the spike sequence in every layer
site_sort = alt.EncodingSortField(field="sequential_site", order="ascending")

# Base chart with filters
base = (
    alt.Chart(unique_sites_df)
    .add_params(n_mut_slider)
    .transform_filter(
        (alt.datum.n_mutations_at_site >= n_mut_slider)
    )
)

line = base.mark_line().encode(
    x=alt.X(
        "site:N",
        sort=site_sort,
        axis=alt.Axis(
            title="Site",
            values=tick_sites,
            labelAngle=-45
        )
    ),
    y=alt.Y(
        "RxRMS_be_floored:Q",
        title="estimated effect on RBD up/down motion",
        axis=alt.Axis(grid=False)
    ),
    color=alt.value("#607d8b"),  # fixed line color
    tooltip=[
        alt.Tooltip("site:N"),
        alt.Tooltip("sequential_site:Q"),
        alt.Tooltip("RxRMS_be_floored:Q", format=".2f"),
        alt.Tooltip("RMS_binding:Q", format=".2f"),
        alt.Tooltip("RMS_escape:Q", format=".2f"),
        "n_mutations_at_site:Q",
        "region_from_site:N"
    ]
)

points = base.mark_point(filled=True, size=60).encode(
    x=alt.X("site:N", sort=site_sort),
    y=alt.Y(
        "RxRMS_be_floored:Q",
        axis=alt.Axis(grid=False)
    ),
    color=alt.condition(
        alt.datum.is_RBM == "RBM",
        alt.value("#d1615d"),    # RBM points red
        alt.value("#607d8b")     # other points blue-grey
    ),
    tooltip=[
        "site:N",
        alt.Tooltip("correlation:Q", format=".2f"),
        alt.Tooltip("RMS_binding:Q", format=".2f"),
        alt.Tooltip("RxRMS_be_floored:Q", format=".2f"),
        alt.Tooltip("RMS_escape:Q", format=".2f"),
        "n_mutations_at_site:Q",
        "region_from_site:N",
    ],
)

# brush over the region strip selects an x range for the main panel
brush = alt.selection_interval(encodings=["x"])

# Region bar with brush
region_bar = (
    base.mark_rect(height=20)
    .encode(
        x=alt.X(
            "site:N",
            sort=site_sort,
            axis=None
        ),
        y=alt.value(0),
        color=alt.Color("region_from_site:N", title="Region")
    )
    .add_params(brush)
    .properties(
        width=1200,
        height=20
    )
)

# Main chart filtered by brush (an empty brush shows all sites)
RxRMS_chart = (
    (line + points)
    .transform_filter(brush)
    .properties(
        width=1200,
        height=200
    )
)

full_chart = (
    alt.vconcat(region_bar, RxRMS_chart)
    .resolve_scale(color="shared")
    .add_params(max_corr_slider)
    .transform_filter(alt.datum["correlation"] <= max_corr_slider)
    .properties(
        title=alt.TitleParams(
            ["Site effect on RBD up/down conformation"],
            anchor="middle",
            fontSize=16,
            dy=-5,
        )
    )
    .configure_view(stroke=None)
)

# Save chart
full_chart.save(RBD_up_down_chart_html)

# show the chart as the cell's last expression; a bare expression in the
# middle of a cell is evaluated but never rendered
full_chart

Plot sites with high inverse correlation between ACE2 binding and escape

Plot sites with a strong inverse correlation between binding and escape; the slider at the bottom controls which sites are shown:

In [12]:
# keep only sites with at least `min_mutations_at_site` mutations
enough_mutations = dms_df["n_mutations_at_site"] >= min_mutations_at_site
dms_df = dms_df[enough_mutations]
In [13]:
# first make base chart: one facet per site showing each mutation's ACE2
# binding vs sera escape, annotated with r from a linear fit

facet_size = 100  # width and height of each facet (pixels)

# slider: hide mutations whose cell entry is below this value
cell_entry_slider = alt.param(
    value=min_cell_entry,
    bind=alt.binding_range(
        name="minimum cell entry",
        min=dms_df["cell entry"].min(),
        max=0,
    ),
)

binding_escape_corr_base = (
    alt.Chart(dms_df)
    .add_params(cell_entry_slider)
    .transform_filter(alt.datum["cell entry"] >= cell_entry_slider)
)

# layer 1: one translucent point per mutation
mutation_points = (
    binding_escape_corr_base
    .mark_circle(color="black", opacity=0.3, size=60)
    .encode(
        x=alt.X("ACE2 binding", scale=alt.Scale(nice=False, padding=6)),
        y=alt.Y("sera escape", scale=alt.Scale(nice=False, padding=6)),
        tooltip=[
            "site",
            "mutation",
            alt.Tooltip("ACE2 binding", format=".2f"),
            alt.Tooltip("sera escape", format=".2f"),
            alt.Tooltip("cell entry", format=".2f"),
        ],
    )
)

# layer 2: "r = ..." label in each facet's corner; r is the square root of
# the fit's R^2, signed by the fitted slope coef[1]
abs_r = alt.expr.sqrt(alt.datum["rSquared"])
r_label = (
    binding_escape_corr_base
    .transform_regression("ACE2 binding", "sera escape", params=True)
    .transform_calculate(
        r=alt.expr.if_(alt.datum["coef"][1] > 0, abs_r, -abs_r),
        r_text="r = " + alt.expr.format(alt.datum["r"], ".2f"),
    )
    .mark_text(size=12, align="left", color="blue")
    .encode(
        text="r_text:N",
        x=alt.value(3),
        y=alt.value(facet_size - 6),
    )
)

site_facet = alt.Facet(
    "site",
    title=None,
    header=alt.Header(
        labelFontSize=14,
        labelFontStyle="italic",
        labelPadding=0,
        labelExpr="'site ' + datum.label",
    ),
)

binding_escape_corr_chart = (
    alt.layer(mutation_points, r_label)
    .properties(width=facet_size, height=facet_size)
    .facet(facet=site_facet, spacing=8, columns=8)
    .configure_axis(grid=False)
)

# now make chart filtered for strongly negative correlations
# (slider starts at r <= -0.82)
max_corr_slider = alt.param(
    value=-0.82,
    bind=alt.binding_range(
        name="only show sites with correlation r less than this",
        min=-1,
        max=1,
        step=0.01,
    ),
)

neg_corr_title = alt.TitleParams(
    "Correlation of ACE2 binding and escape filtered by extent of negative correlation",
    anchor="middle",
    fontSize=16,
    dy=-5,
)

binding_escape_neg_corr_chart = (
    binding_escape_corr_chart
    .properties(
        title=neg_corr_title,
        autosize=alt.AutoSizeParams(resize=True),
    )
    .add_params(max_corr_slider)
    .transform_filter(alt.datum["correlation"] <= max_corr_slider)
)

binding_escape_neg_corr_chart
Out[13]:

We now plot the same correlation for sites of strong escape

We manually specify some sites of strong escape:

In [14]:
# manually chosen sites where mutations cause strong escape
escape_sites = [50, 132, 200, 222, 332, 344, 357, 393, 428, 440, 458, 470, 475, 478, 505, 518, 572, 852]

# Vega's `oneOf` filter compares with strict equality, and the chart's `site`
# values come from `dms_df` as read from the CSV, where site labels may be
# strings (a later cell drops non-numeric site labels). List every site both
# as int and as str so the filter matches whichever form the data uses.
escape_sites_any_type = escape_sites + [str(site) for site in escape_sites]

binding_escape_high_escape_corr_chart = (
    binding_escape_corr_chart
    .properties(
        title=alt.TitleParams(
            "Correlation of ACE2 binding and escape for sites of strong escape",
            anchor="middle",
            fontSize=16,
            dy=-5,
        ),
        autosize=alt.AutoSizeParams(resize=True),
    )
    .transform_filter(alt.FieldOneOfPredicate("site", escape_sites_any_type))
)

binding_escape_high_escape_corr_chart
Out[14]:

Plot top escape sites in different regions

We plot both binding-escape correlation plots and logo plots for the sites with the most escaping mutations.

We stratify sites by:

  • RBD ACE2 proximal
  • RBD ACE2 distal
  • non-RBD

First get each RBD site's distance from ACE2, and then use that to separate ACE2-proximal and ACE2-distal sites:

In [15]:
ace2_proximal_cutoff = 15  # classify as ACE2 proximal if CA distance <= this

# chain A is ACE2, chain E is RBD
with tempfile.NamedTemporaryFile() as f:
    urllib.request.urlretrieve(
        "https://files.rcsb.org/download/6M0J.pdb",
        f.name,
    )
    coords_df = polyclonal.pdb_utils.extract_atom_locations(f.name, ["A", "E"], target_atom="CA")

# get closest distance for each residue in chain E (RBD) to residue in chain A (ACE2)
xyz_cols = ["x", "y", "z"]
rbd_atoms = coords_df.query("chain == 'E'")
ace2_atoms = coords_df.query("chain == 'A'")
assert len(rbd_atoms) and len(ace2_atoms), "expected CA atoms for both chains E and A"

# all pairwise CA-CA distances (rows: RBD atoms, columns: ACE2 atoms) via
# numpy broadcasting, instead of a row-wise apply over the full cross join
offsets = (
    rbd_atoms[xyz_cols].to_numpy(dtype=float)[:, numpy.newaxis, :]
    - ace2_atoms[xyz_cols].to_numpy(dtype=float)[numpy.newaxis, :, :]
)
pairwise_distances = numpy.sqrt((offsets**2).sum(axis=2))

dist_df = (
    pd.DataFrame({
        "site": rbd_atoms["site"].to_numpy(),
        "distance": pairwise_distances.min(axis=1),
    })
    # collapse to one minimum distance per site in case a site has several rows
    .groupby("site", as_index=False)
    .aggregate({"distance": "min"})
)
In [16]:
# Keep only rows where 'site' is fully numeric
dms_df = dms_df[dms_df["site"].astype(str).str.match(r"^\d+$")].copy()

# Convert 'site' to int
dms_df["site"] = dms_df["site"].astype(int)
In [17]:
def _ace2_region_labels(frame):
    """Relabel `region`: RBD sites split by ACE2 distance; all others "non-RBD".

    RBD sites with no distance (NaN) fall into "RBD ACE2 distal".
    """
    in_rbd = frame["region"] == "RBD"
    near_ace2 = frame["distance"] <= ace2_proximal_cutoff
    fallback = numpy.where(in_rbd, "RBD ACE2 distal", "non-RBD")
    return numpy.where(in_rbd & near_ace2, "RBD ACE2 proximal", fallback)


# left merge keeps mutations at sites missing from the PDB (distance NaN)
dms_df_by_region = (
    dms_df
    .merge(dist_df, how="left", validate="many_to_one")
    .assign(region=_ace2_region_labels)
)

Now plot escape and binding for the sites with the most strongly escaping mutations in each region, making both correlation plots and logo plots colored by ACE2 binding:

In [18]:
top_n = 7  # number of top-escape sites plotted per region

# for coloring by ACE2: diverging colormap that must be symmetric about zero
ace2_colormap = dmslogo.colorschemes.ValueToColorMap(
    minvalue=-1.5,
    maxvalue=1.5,
    cmap=palettable.colorbrewer.diverging.PRGn_4.mpl_colormap,
    #cmap=palettable.colorbrewer.diverging.PuOr_4.mpl_colormap,
    #cmap=palettable.lightbartlein.diverging.BlueOrange8_2.mpl_colormap,
)
assert abs(ace2_colormap.minvalue) == ace2_colormap.maxvalue, "not symmetric for diverging color scale"


def show_and_save_svg(fig, svg):
    """Display `fig` inline, save it to the SVG path `svg`, then close it."""
    display(fig)
    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")
    plt.close(fig)


# standalone ACE2 color scale bars
for orientation in ["horizontal", "vertical"]:
    fig, _ = ace2_colormap.scale_bar(
        orientation=orientation, label="ACE2 binding",
    )
    show_and_save_svg(fig, os.path.join(logoplot_subdir, f"ace2_scalebar_{orientation}.svg"))

for region, df in dms_df_by_region.groupby("region"):

    print(f"\n\nAnalyzing top sites for {region=}")

    # sites of top escape mutations: after sorting by escape, each site's first
    # row is its strongest mutation and groups keep first-appearance order,
    # so sites come out ranked by their maximum escape
    top_escape_sites = (
        df
        .sort_values("sera escape", ascending=False)
        .groupby("site", sort=False)
        .first()
        .head(top_n)
    )
    sites = top_escape_sites.index.tolist()

    # `binding_escape_corr_chart` was built from `dms_df` before `site` was
    # converted to int, so its site values may be strings; Vega's `oneOf`
    # uses strict equality, so match both the int and str forms
    site_filter = alt.FieldOneOfPredicate("site", sites + [str(site) for site in sites])

    # plot correlation for these top sites
    corr_chart = (
        binding_escape_corr_chart
        .properties(
            title=alt.TitleParams(
                f"Correlation of ACE2 binding and escape for {region} sites where mutations cause strong escape",
                anchor="middle",
                fontSize=16,
                dy=-5,
            ),
            autosize=alt.AutoSizeParams(resize=True),
        )
        .transform_filter(site_filter)
    )
    display(corr_chart)

    # logo plot: letter height is escape, letter color is ACE2 binding
    # (clipped to the colormap range before mapping to a color)
    logo_data = (
        df[df["site"].isin(sites)]
        .rename(columns={"sera escape": "escape"})
        .assign(
            wildtype_site=lambda x: x["wildtype"] + x["site"].astype(str),
            color=lambda x: (
                x["ACE2 binding"]
                .clip(lower=ace2_colormap.minvalue, upper=ace2_colormap.maxvalue)
                .map(ace2_colormap.val_to_color)
            ),
        )
    )
    fig, _ = dmslogo.draw_logo(
        data=logo_data,
        x_col="sequential_site",
        letter_col="mutant",
        color_col="color",
        xtick_col="wildtype_site",
        letter_height_col="escape",
        xlabel="",
        heightscale=1,
    )
    show_and_save_svg(fig, os.path.join(logoplot_subdir, f"{region.replace(' ', '_')}_logoplot.svg"))
No description has been provided for this image
Saving to results/binding_vs_escape/logoplots/ace2_scalebar_horizontal.svg
No description has been provided for this image
Saving to results/binding_vs_escape/logoplots/ace2_scalebar_vertical.svg

Analyzing top sites for region='RBD ACE2 distal'
No description has been provided for this image
Saving to results/binding_vs_escape/logoplots/RBD_ACE2_distal_logoplot.svg

Analyzing top sites for region='RBD ACE2 proximal'
No description has been provided for this image
Saving to results/binding_vs_escape/logoplots/RBD_ACE2_proximal_logoplot.svg

Analyzing top sites for region='non-RBD'
No description has been provided for this image
Saving to results/binding_vs_escape/logoplots/non-RBD_logoplot.svg
In [ ]: