Escape at key sites: logo plots and binding / escape correlations

Make logo plots of serum escape at key sites, and look at the relationship between escape and other phenotypes such as ACE2 binding.

First, get the input files and parameters from papermill, and import the Python modules:

In [1]:
# this cell is tagged as `parameters` for papermill parameterization
# All defaults are `None`; the actual values come from the papermill-injected
# `# Parameters` cell that is executed right after this one.
dms_csv = None  # CSV of per-mutation DMS summary phenotypes (sera escape, cell entry, ACE2 binding, ...)
per_antibody_csv = None  # CSV of per-antibody/serum escape values for a single antibody_set
pango_consensus_seqs_json = None  # NOTE(review): not referenced in any cell shown here
codon_seq = None  # FASTA file containing exactly one record: the spike codon sequence
logoplot_subdir = None  # directory where logo-plot and scale-bar SVGs are written
In [2]:
# Parameters
# (cell injected by papermill; these values override the `None` defaults set in
# the `parameters`-tagged cell)
dms_csv = "results/summaries/summary.csv"
per_antibody_csv = "results/summaries/per_antibody_escape.csv"
codon_seq = "data/XBB_1_5_spike_codon.fa"
logoplot_subdir = "results/key_sites/logoplots"
In [3]:
import itertools
import os

import altair as alt

import Bio.Seq
import Bio.SeqIO

import dmslogo
import dmslogo.colorschemes

import matplotlib
import matplotlib.pyplot as plt

import numpy

import pandas as pd

# lift altair's default limit on how many rows a chart may embed
_ = alt.data_transformers.disable_max_rows()

# store SVG text as text (not paths) so labels remain editable
plt.rcParams.update({"svg.fonttype": "none"})

# every logo plot / scale bar below is saved into this directory
os.makedirs(name=logoplot_subdir, exist_ok=True)

Read input data

Keep only mutations that have all phenotypes (sera escape, cell entry, and ACE2 binding) measured:

In [4]:
# read averages for all DMS measurements, keeping only mutations that have
# sera escape, cell entry, and ACE2 binding all measured
dms_df = (
    pd.read_csv(dms_csv)
    .rename(
        columns={"human sera escape": "sera escape", "spike mediated entry": "cell entry"}
    )
    .query("`sera escape`.notnull() and `cell entry`.notnull() and `ACE2 binding`.notnull()")
)

# read per antibody values, merge with averages to create escape_df
per_antibody_df = pd.read_csv(per_antibody_csv)

assert per_antibody_df["antibody_set"].nunique() == 1, "code expects 1 antibody_set"

# the only columns shared by the two tables must be the mutation identifiers,
# since those are the (implicit) merge keys used below
if (
    (intersection := set(dms_df.columns).intersection(per_antibody_df.columns))
    != {"site", "wildtype", "mutant"}
):
    raise ValueError(f"unexpected {intersection=}")

# "average" is used below as the antibody label for the across-sera average, so
# no real antibody may carry that name. Note that `"average" in series` would
# test the Series *index*, not its values, so check the values explicitly.
assert not per_antibody_df["antibody"].eq("average").any(), (
    "'average' is reserved for the across-sera average and cannot be an antibody name"
)

# Long-format escape: the average ("sera escape") plus every individual antibody,
# each annotated with the other DMS phenotypes. The merge is an inner join on
# site / wildtype / mutant, so per-antibody rows for mutations dropped from
# `dms_df` above are dropped here too; `many_to_one` asserts each mutation
# appears at most once in `dms_df`.
escape_df = (
    pd.concat(
        [
            dms_df[["site", "wildtype", "mutant", "sera escape"]].rename(
                columns={"sera escape": "escape"}
            ).assign(antibody="average"),
            per_antibody_df.drop(columns="antibody_set"),
        ],
        ignore_index=True,
    )
    .merge(dms_df.drop(columns="sera escape"), validate="many_to_one")
    .assign(wildtype_site=lambda x: x["wildtype"] + x["site"].astype(str))
)

Determine key sites of strongest escape

Get the key sites with the most site escape, and plot their site escape values in an interactive chart:

In [5]:
# Rank cutoffs used to choose key sites: for each site metric, keep every site
# whose rank is at or below the cutoff, computed separately for the average of
# antibodies and for the max over any individual antibody.
key_sites_by_rank = dict(
    total_positive_escape={
        "any antibody": 5,
        "average of antibodies": 15,
    },
)
# manually chosen sites that are always kept (e.g., sites used in neuts); currently none
key_sites_manual = []


def _sum_abs_escape(escape):
    """Total magnitude of escape: sum of absolute values."""
    return escape.abs().sum()


def _sum_positive_escape(escape):
    """Sum of escape values with negative values floored at zero."""
    return escape.clip(lower=0).sum()


def _sum_negative_escape(escape):
    """Magnitude of the summed negative escape values (positives floored at zero)."""
    return escape.clip(upper=0).abs().sum()


# site totals for each antibody; the across-sera average is grouped as
# "average of antibodies" and every individual antibody as "any antibody"
per_antibody_site_escape_df = (
    escape_df
    .assign(
        is_average=lambda df: numpy.where(
            df["antibody"] == "average", "average of antibodies", "any antibody"
        )
    )
    .groupby(["is_average", "antibody", "site", "sequential_site"], as_index=False)
    .aggregate(
        total_abs_escape=pd.NamedAgg(column="escape", aggfunc=_sum_abs_escape),
        total_positive_escape=pd.NamedAgg(column="escape", aggfunc=_sum_positive_escape),
        total_negative_escape=pd.NamedAgg(column="escape", aggfunc=_sum_negative_escape),
    )
)

site_metric_cols = ["total_abs_escape", "total_positive_escape", "total_negative_escape"]

# max over antibodies within each group, reshaped to one row per (group, site, metric)
site_escape_df = (
    per_antibody_site_escape_df
    .groupby(["is_average", "site", "sequential_site"], as_index=False)
    .aggregate({col: "max" for col in site_metric_cols})
    .melt(
        id_vars=["is_average", "site", "sequential_site"],
        var_name="site metric",
        value_name="site escape",
    )
)
# rank 1 is the largest site escape within each (group, metric); ties share the best rank
site_escape_df["rank"] = (
    site_escape_df
    .groupby(["is_average", "site metric"])["site escape"]
    .rank(method="min", ascending=False)
    .astype(int)
)

# collect the key sites: manual ones plus every site passing a rank cutoff
print(f"Keeping the following manually specified sites: {key_sites_manual}")
key_sites = set(key_sites_manual)
for site_metric, cutoffs in key_sites_by_rank.items():
    is_this_metric = site_escape_df["site metric"] == site_metric
    for is_average, rank in cutoffs.items():
        passes_cutoff = (
            is_this_metric
            & (site_escape_df["is_average"] == is_average)
            & (site_escape_df["rank"] <= rank)
        )
        new_sites = set(site_escape_df.loc[passes_cutoff, "site"])
        print(f"Adding sites with {site_metric} / {is_average} rank <= {rank}: {new_sites}")
        key_sites = key_sites.union(new_sites)
print(f"Overall keeping the following {len(key_sites)} sites: {key_sites}")

site_escape_df["key_site"] = site_escape_df["site"].isin(key_sites)

# interactive chart: hovering highlights a site, a dropdown selects the site metric
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

site_metric_selection = alt.selection_point(
    fields=["site metric"],
    value="total_positive_escape",
    bind=alt.binding_select(
        name="site metric",
        options=site_escape_df["site metric"].unique(),
    ),
)

# float columns are shown with two decimals, everything else as-is
site_escape_tooltips = []
for col in site_escape_df.columns:
    is_float_col = site_escape_df[col].dtype == float
    site_escape_tooltips.append(alt.Tooltip(col, format=".2f") if is_float_col else col)

site_escape_chart = (
    alt.Chart(site_escape_df)
    .mark_circle(stroke="black")
    .add_params(site_selection, site_metric_selection)
    .transform_filter(site_metric_selection)
    .encode(
        alt.X("site", sort=alt.SortField("sequential_site"), scale=alt.Scale(nice=False, zero=False)),
        alt.Y("site escape"),
        alt.Color("key_site"),
        alt.Row("is_average", title=None),
        tooltip=site_escape_tooltips,
        strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(0)),
        opacity=alt.condition(site_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(site_selection, alt.value(70), alt.value(30)),
    )
    .configure_axis(grid=False)
    .resolve_scale(y="independent")
    .properties(
        width=600,
        height=150,
        title="Escape at each site for average of antibodies or max for any antibody",
    )
)

site_escape_chart
Keeping the following manually specified sites: []
Adding sites with total_positive_escape / any antibody rank <= 5: {420, 357, 485, 456, 473}
Adding sites with total_positive_escape / average of antibodies rank <= 15: {352, 450, 420, 357, 421, 455, 200, 456, 234, 371, 375, 440, 473, 475, 447}
Overall keeping the following 16 sites: {450, 455, 456, 200, 473, 475, 352, 420, 357, 485, 421, 234, 371, 375, 440, 447}
Out[5]:

Draw logo plots for key sites of strongest escape

First, get the mutations at the key sites and color them by ACE2 binding:

In [6]:
# mutations at the key sites, excluding the non-amino-acid `*` and `-` mutants
key_sites_df = (
    escape_df
    .query("site in @key_sites")
    .query("mutant not in ['*', '-']")
)

# ACE2 binding is colored on a scale from `ace2_color_floor` (or the lowest
# observed value, if that is higher) up to 0; values outside that range are
# clipped to the ends of the colormap
ace2_color_floor = -2
ace2_colormap = dmslogo.colorschemes.ValueToColorMap(
    minvalue=max(ace2_color_floor, key_sites_df["ACE2 binding"].min()),
    maxvalue=0,
    cmap="YlOrBr",
)

# use `assign` (a new frame) rather than setting a column in place on a
# `.query()` result, which pandas may treat as a copy of `escape_df`
key_sites_df = key_sites_df.assign(
    color=lambda x: (
        x["ACE2 binding"]
        .clip(lower=ace2_colormap.minvalue, upper=ace2_colormap.maxvalue)
        .map(ace2_colormap.val_to_color)
    )
)

# display and save the colormap scale bar in both orientations
for orientation in ["horizontal", "vertical"]:
    fig, _ = ace2_colormap.scale_bar(
        orientation=orientation, label="ACE2 binding",
    )
    display(fig)
    svg = os.path.join(logoplot_subdir, f"key_sites_ace2_scalebar_{orientation}.svg")
    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")
    plt.close(fig)
No description has been provided for this image
Saving to results/key_sites/logoplots/key_sites_ace2_scalebar_horizontal.svg
No description has been provided for this image
Saving to results/key_sites/logoplots/key_sites_ace2_scalebar_vertical.svg

Determine which mutations at the key sites are accessible by a single-nucleotide change:

In [7]:
# For every codon, the amino acids (stop = `*`) encoded by the codon itself or
# by any codon differing from it at exactly one nucleotide.
nts = "ACGT"
codon_to_aas = {
    codon: {
        str(Bio.Seq.Seq(codon[:i] + nt + codon[i + 1:]).translate())
        for i in range(len(codon))
        for nt in nts
    }
    for codon in ("".join(triplet) for triplet in itertools.product(nts, repeat=3))
}

gene = str(Bio.SeqIO.read(codon_seq, "fasta").seq).upper()


def _codon_at(sequential_site):
    """Return the codon in `gene` for the 1-based `sequential_site`.

    Raises ValueError (rather than a later, opaque KeyError) if the site runs
    past the end of `gene` or its codon contains a non-ACGT character.
    """
    start = 3 * (int(sequential_site) - 1)
    codon = gene[start: start + 3]
    if codon not in codon_to_aas:
        raise ValueError(f"no valid codon for {sequential_site=} in {codon_seq}: {codon!r}")
    return codon


key_sites_df = key_sites_df.assign(
    codon=lambda x: x["sequential_site"].map(_codon_at),
    codon_translated=lambda x: x["codon"].map(
        lambda c: str(Bio.Seq.Seq(c).translate())
    ),
    single_nt_accessible=lambda x: [
        mutant in codon_to_aas[codon]
        for mutant, codon in zip(x["mutant"], x["codon"])
    ],
)

# the wildtype identity in the DMS data must match the translated codon
mismatched = key_sites_df.query("wildtype != codon_translated")
assert mismatched.empty, (
    "wildtype does not match translation of codon_seq:\n"
    + mismatched[["site", "sequential_site", "wildtype", "codon", "codon_translated"]]
    .drop_duplicates()
    .to_string()
)

# Stack two labeled copies: the single-nt-accessible subset and all mutations.
# NOTE: this rebinds `key_sites_df`, so re-running this cell without first
# re-running the previous cell duplicates rows.
key_sites_df = pd.concat(
    [
        key_sites_df.query("single_nt_accessible").assign(
            single_nt_accessible="single-nucleotide accessible"
        ),
        key_sites_df.assign(single_nt_accessible="all measured mutations"),
    ],
    ignore_index=True,
)

Plot the average escape across sera, for all mutations and for just the single-nucleotide-accessible ones:

In [8]:
# keyword arguments passed as `draw_logo_kwargs` to `dmslogo.facet_plot` here
# and in the logo-plot cells below
draw_logo_kwargs = dict(
    letter_col="mutant",
    color_col="color",
    xtick_col="wildtype_site",
    letter_height_col="escape",
    xlabel="",
    clip_negative_heights=True,
)

# average across sera; one row of logos per single-nucleotide-accessibility label
avg_key_sites_df = key_sites_df.query("antibody == 'average'")
fig, _ = dmslogo.facet_plot(
    data=avg_key_sites_df,
    x_col="sequential_site",
    gridrow_col="single_nt_accessible",
    show_col=None,
    share_ylim_across_rows=False,
    height_per_ax=2.4,
    hspace=0.6,
    draw_logo_kwargs=draw_logo_kwargs,
)
display(fig)

svg = os.path.join(logoplot_subdir, "avg_sera_escape_at_key_sites.svg")
print(f"Saving to {svg}")
fig.savefig(svg, bbox_inches="tight")
plt.close(fig)
No description has been provided for this image
Saving to results/key_sites/logoplots/avg_sera_escape_at_key_sites.svg

Now make plots for each individual serum, for both all mutations and only the single-nucleotide-accessible mutations:

In [9]:
# one figure per mutation subset, with one logo row per individual serum
for accessibility, subset_df in key_sites_df.groupby("single_nt_accessible"):
    print(f"\nsingle_nt_accessible={accessibility!r}")
    per_serum_df = subset_df.query("antibody != 'average'")
    fig, axes = dmslogo.facet_plot(
        data=per_serum_df,
        x_col="sequential_site",
        gridrow_col="antibody",
        show_col=None,
        share_ylim_across_rows=False,
        height_per_ax=2.1,
        hspace=0.6,
        draw_logo_kwargs=draw_logo_kwargs,
    )
    display(fig)

    svg_basename = "all_sera_escape_at_key_sites_" + accessibility.replace(" ", "_") + ".svg"
    svg = os.path.join(logoplot_subdir, svg_basename)
    print(f"Saving to {svg}")
    fig.savefig(svg, bbox_inches="tight")
    plt.close(fig)
single_nt_accessible='all measured mutations'
No description has been provided for this image
Saving to results/key_sites/logoplots/all_sera_escape_at_key_sites_all_measured_mutations.svg

single_nt_accessible='single-nucleotide accessible'
No description has been provided for this image
Saving to results/key_sites/logoplots/all_sera_escape_at_key_sites_single-nucleotide_accessible.svg

Sites of neutralization assay mutations

Now make logo plots at the sites of mutations analyzed in neutralization assays, for the sera used in those assays:

In [10]:
# mutations tested in neutralization assays
muts_in_neuts = [
    "V42F",
    "Y200C",
    "N234T",
    "R357H",
    "R403K",
    "N405K",
    "D420N",
    "K444M",
    "L455F",
    "F456L",
    "Y473S",
    "T572K",
    "A852V",
]

# site number of each mutation, e.g. "V42F" -> 42
sites_in_neuts = [int(mut[1:-1]) for mut in muts_in_neuts]

# sera used in those neutralization assays
sera_in_neuts = ["serum 287C", "serum 500C", "serum 501C"]

# escape for all amino-acid mutations at those sites for those sera; the
# mutations actually tested in neuts are colored blue, all others gray
neuts_df = (
    escape_df
    .query("site in @sites_in_neuts")
    .query("mutant not in ['*', '-']")
    .query("antibody in @sera_in_neuts")
    .assign(
        mutation=lambda x: x["wildtype_site"] + x["mutant"],
        color=lambda x: numpy.where(x["mutation"].isin(muts_in_neuts), "blue", "gray"),
    )
)

# a typo in the lists above would otherwise silently drop entries from the plot
missing_muts = sorted(set(muts_in_neuts) - set(neuts_df["mutation"]))
if missing_muts:
    print(f"Warning: no escape data for these neut mutations: {missing_muts}")
missing_sera = sorted(set(sera_in_neuts) - set(neuts_df["antibody"]))
if missing_sera:
    print(f"Warning: no escape data for these sera: {missing_sera}")

# negative escape is shown here (not clipped) so effects in both directions are visible
fig, axes = dmslogo.facet_plot(
    data=neuts_df,
    x_col="sequential_site",
    show_col=None,
    gridrow_col="antibody",
    share_ylim_across_rows=False,
    hspace=0.6,
    height_per_ax=2.9,
    draw_logo_kwargs={**draw_logo_kwargs, "clip_negative_heights": False},
)

display(fig)
svg = os.path.join(logoplot_subdir, "neut_sites.svg")
print(f"Saving to {svg}")
fig.savefig(svg, bbox_inches="tight")
plt.close(fig)
No description has been provided for this image
Saving to results/key_sites/logoplots/neut_sites.svg
In [ ]: