Test how various models score the phenotypes of BA.2.86 and its subsequent descendants
This notebook tests how well each phenotypic measure correctly assesses that BA.2.86 would have high fitness relative to equally mutated descendants of BA.2, and that observed BA.2.86 descendants would have high fitness relative to equally mutated descendants of BA.2.86.
It does this using a test similar to that in Extended Data Fig. 9 of Thadani et al.
Briefly, it calculates the actual phenotype of BA.2.86 relative to BA.2 for each of the various phenotypic measures. It then generates random sequences that have the same number of spike amino-acid mutations relative to BA.2 as BA.2.86 does, drawing those mutations from the ones observed at least a threshold number of times in GISAID. Finally, it compares the phenotype of BA.2.86 to the phenotypes of these randomly generated mutants.
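It then does the same thing for all descendant Pango clades of BA.2.86, this time comparing each clade to random sequences with the same number of spike mutations relative to BA.2.86. Because there are now multiple clades, there are multiple actual phenotypes.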
# This cell is tagged as `parameters` for `papermill` parameterization.
# Every name below is only a `None` placeholder here; the concrete values are
# assigned in the "Parameters" section that follows.
clade_phenotypes_csv = None  # file with clade phenotypes
mutation_phenotypes_csv = None  # file with mutation phenotypes
gisaid_mutation_counts_csv = None  # file with GISAID mutation counts
gisaid_min_counts = None  # draw random mutations from those with >= this many GISAID counts
nrandom = None  # number randomized sequences, for descendants it is 10x less
linear_models = None  # weights of phenotypes in linear model
# Parameters
# Concrete values that overwrite the `None` placeholders defined above.

# Mutations are only eligible for random draws if their GISAID count is at
# least this large.
gisaid_min_counts = 50

# Number of random sequences generated per clade for the BA.2.86-vs-BA.2
# comparison; the descendant-clade comparison uses `nrandom // 10`.
nrandom = 1000

# Each key names a combined phenotype; its value maps the name of a component
# phenotype to the weight that phenotype receives in the weighted sum.
linear_models = {
    "spike pseudovirus DMS (combined phenotypes)": {
        "spike pseudovirus DMS human sera escape": 38,
        "spike pseudovirus DMS ACE2 binding": 2,
        "spike pseudovirus DMS spike mediated entry": 16,
    },
    "RBD yeast-display DMS (combined phenotypes)": {
        "RBD yeast-display DMS ACE2 affinity": 19,
        "RBD yeast-display DMS RBD expression": 20,
        "RBD yeast-display DMS escape": 27,
    },
}

# Input CSV files read by the cells below.
clade_phenotypes_csv = "SARS2-spike-predictor-phenos/results/clade_phenotypes.csv"
mutation_phenotypes_csv = "SARS2-spike-predictor-phenos/results/mutation_phenotypes.csv"
gisaid_mutation_counts_csv = "data/GISAID_alignment_counts_2024-01-27.csv"
Import Python modules:
import random
import altair as alt
import numpy
import pandas as pd
# Turn off Altair's maximum-rows check; the return value is bound to `_` so it
# is not echoed as notebook output.
_ = alt.data_transformers.disable_max_rows()
Read the clade phenotypes and get the mutations of each clade relative to its parent. We only analyze BA.2.86 and descendant clades with at least one spike mutation relative to their parent:
def relative_mutations(muts, reference_muts):
    """Express the mutations in `muts` relative to `reference_muts`.

    Parameters
    ----------
    muts : str or null
        Whitespace-delimited mutations such as ``"A1B C22D"``, each written as
        a wildtype character, an integer site, and a mutant character. A null
        value (as judged by `pandas.isnull`) means no mutations.
    reference_muts : str or null
        Mutations of the reference sequence in the same format, defined
        against the same background as `muts`.

    Returns
    -------
    list
        ``(wildtype, site, mutant)`` tuples sorted by site, where ``wildtype``
        is now the reference's amino acid at that site. Mutations found only
        in the reference are returned as reversions.

    Raises
    ------
    AssertionError
        If the two inputs disagree about the wildtype at a shared site.
    """

    def parse(mut_string):
        # Split "A1B"-style tokens into (wildtype, site, mutant) tuples.
        if pd.isnull(mut_string):
            return []
        return [(tok[0], int(tok[1:-1]), tok[-1]) for tok in mut_string.split()]

    own = parse(muts)
    ref = parse(reference_muts)

    # Mutations present in both sequences are not differences, so drop them.
    common = set(own).intersection(ref)
    own_by_site = {
        site: (wt, mut) for (wt, site, mut) in own if (wt, site, mut) not in common
    }
    ref_by_site = {
        site: (wt, mut) for (wt, site, mut) in ref if (wt, site, mut) not in common
    }

    diffs = []  # collected as (site, from, to) so a plain sort orders by site
    for site, (wt, mut) in own_by_site.items():
        if site not in ref_by_site:
            diffs.append((site, wt, mut))
            continue
        ref_wt, ref_mut = ref_by_site[site]
        assert wt == ref_wt
        # Both sequences are mutated here, to different amino acids.
        diffs.append((site, ref_mut, mut))
    for site, (wt, mut) in ref_by_site.items():
        if site not in own_by_site:
            # Mutated only in the reference: a reversion to wildtype.
            diffs.append((site, mut, wt))

    diffs.sort()
    return [(wt, site, mut) for (site, wt, mut) in diffs]
# Load the clade table, keeping only the columns used below.
clade_phenotypes = pd.read_csv(clade_phenotypes_csv)[
    ["clade", "parent", "date", "spike muts from Wuhan-Hu-1", "descendant of BA.2.86"]
]

# Spike mutation strings (relative to Wuhan-Hu-1) of the two reference clades.
ba_2_86_spike_muts_from_wuhan_hu_1 = clade_phenotypes.set_index("clade").at[
    "BA.2.86", "spike muts from Wuhan-Hu-1"
]
ba_2_spike_muts_from_wuhan_hu_1 = clade_phenotypes.set_index("clade").at[
    "BA.2", "spike muts from Wuhan-Hu-1"
]

# Lookup table giving each parent clade's spike mutations, taken from the full
# clade table so parents outside the BA.2.86 lineage (e.g. BA.2) are found.
parent_muts = clade_phenotypes[["clade", "spike muts from Wuhan-Hu-1"]].rename(
    columns={
        "clade": "parent",
        "spike muts from Wuhan-Hu-1": "parent spike muts from Wuhan-Hu-1",
    }
)

# BA.2.86 and its descendants, each annotated with its parent's mutations.
ba_2_86_lineage = clade_phenotypes.query(
    "(clade == 'BA.2.86') or `descendant of BA.2.86`"
).merge(parent_muts, on="parent", validate="many_to_one", how="left")


def _muts_from_parent(row):
    """Spike mutations of the clade in `row` relative to its parent clade."""
    return relative_mutations(
        row["spike muts from Wuhan-Hu-1"],
        row["parent spike muts from Wuhan-Hu-1"],
    )


def _muts_from_ba_2_86(row):
    """Spike mutations of the clade in `row` relative to BA.2.86."""
    return relative_mutations(
        row["spike muts from Wuhan-Hu-1"],
        ba_2_86_spike_muts_from_wuhan_hu_1,
    )


clade_phenotypes = ba_2_86_lineage.assign(
    spike_muts_from_parent=ba_2_86_lineage.apply(_muts_from_parent, axis=1),
    spike_muts_from_ba_2_86=ba_2_86_lineage.apply(_muts_from_ba_2_86, axis=1),
)
# Keep only clades that differ from their parent by at least one spike
# mutation, then drop the raw mutation strings that are no longer needed.
clade_phenotypes = clade_phenotypes[
    clade_phenotypes["spike_muts_from_parent"].map(len) >= 1
].drop(columns=["spike muts from Wuhan-Hu-1", "parent spike muts from Wuhan-Hu-1"])
Get all amino acids at each site observed at least a threshold number of times in GISAID sequences as well as the "wildtype" amino acid at each site. We will randomize from these amino acids:
# Read the GISAID per-mutation counts once; the columns used here are `site`,
# `wildtype`, `mutant` and `count`.
gisaid_mutation_counts = pd.read_csv(gisaid_mutation_counts_csv)

# Sanity check: every site must be annotated with exactly one wildtype.
n_wts = gisaid_mutation_counts.groupby("site").aggregate(
    wildtypes=pd.NamedAgg("wildtype", "unique"),
    n_wildtypes=pd.NamedAgg("wildtype", "nunique"),
)
# Test each site individually. The builtin `any(series) != 1` would compare a
# single bool with 1 and so never flag a site when any count is nonzero.
if (n_wts["n_wildtypes"] != 1).any():
    raise ValueError(f"multiple wildtypes for some sites:\n{n_wts.query('n_wildtypes != 1')}")

# Wildtype amino acid at each site.
site_wts = gisaid_mutation_counts.set_index("site")["wildtype"].to_dict()

# Retain only mutations seen at least `gisaid_min_counts` times.
gisaid_mutation_counts = (
    gisaid_mutation_counts
    .assign(meets_threshold=lambda x: x["count"] >= gisaid_min_counts)
    .query("meets_threshold")
)

# Unique (wildtype, site, mutant) tuples. They are sorted because iteration
# order of a set of string-containing tuples changes between Python processes
# (string hashes are randomized), which would make the seeded random draws
# taken from this list irreproducible from run to run.
gisaid_muts = sorted(
    set(
        gisaid_mutation_counts[["wildtype", "site", "mutant"]].itertuples(index=False, name=None)
    )
)

print(f"Retained {len(gisaid_muts)} natural mutations to randomize among")
Retained 6605 natural mutations to randomize among
Now get the GISAID mutations relative to BA.2 and BA.2.86, which are the "parents" for our analyses below:
# For each reference clade, re-express every retained GISAID mutation (which
# is defined against the GISAID wildtype) as a change from that clade's own
# amino acid at the site.
gisaid_muts_relative_to = {}
for clade, clade_muts in (
    ("BA.2", ba_2_spike_muts_from_wuhan_hu_1),
    ("BA.2.86", ba_2_86_spike_muts_from_wuhan_hu_1),
):
    # site -> (wildtype, amino acid carried by the clade)
    clade_state = {}
    for mut_str in clade_muts.split():
        clade_state[int(mut_str[1:-1])] = (mut_str[0], mut_str[-1])

    rebased = gisaid_muts_relative_to[clade] = []
    for wt, r, m in gisaid_muts:
        if r not in clade_state:
            # The clade is wildtype at this site, so the mutation is unchanged.
            rebased.append((wt, r, m))
            continue
        clade_wt, clade_aa = clade_state[r]
        assert wt == clade_wt
        # A GISAID mutation to the clade's own amino acid becomes a reversion
        # to wildtype; any other mutation starts from the clade's amino acid.
        target = wt if m == clade_aa else m
        rebased.append((clade_aa, r, target))
Now get phenotype changes of each clade:
# Load the per-mutation phenotype effects. The plots below state that all
# phenotypes are calculated from mutation effects in XBB.1.5, so stop here if
# any row reports a different reference clade.
mutation_phenotypes = pd.read_csv(mutation_phenotypes_csv)
assert (mutation_phenotypes["ref_clade"] == "XBB.1.5").all()
class PhenotypeAssigner:
    """Compute additive phenotypes for sets of mutations.

    Parameters
    ----------
    mutation_phenotypes_df : pandas.DataFrame
        Should have columns `site`, `wildtype`, `mutant`, `mutation_effect`.

    Attributes
    ----------
    sites : list
        Sorted sites that have measured effects.
    wts : dict
        Wildtype amino acid at each site.
    effects : dict
        ``effects[site][amino_acid]`` is the measured effect of that amino
        acid at the site; the wildtype amino acid is assigned an effect of 0.
    """

    def __init__(self, mutation_phenotypes_df):
        df = mutation_phenotypes_df

        # Each (site, mutant) pair may be measured only once.
        assert not df.duplicated(subset=["site", "mutant"]).any()

        self.sites = sorted(set(df["site"]))

        # Each site must have exactly one wildtype.
        n_site_wt_pairs = len(df[["site", "wildtype"]].drop_duplicates())
        assert len(self.sites) == n_site_wt_pairs

        self.wts = df.set_index("site")["wildtype"].to_dict()

        self.effects = {}
        for site, site_df in df.groupby("site"):
            self.effects[site] = site_df.set_index("mutant")["mutation_effect"].to_dict()

        # Give the wildtype a zero effect so that mutations starting from, or
        # reverting to, wildtype can be scored like any other change.
        for site, wt in self.wts.items():
            site_effects = self.effects[site]
            assert wt not in site_effects
            site_effects[wt] = 0.0

    def phenotype(self, muts):
        """Returns phenotype for list of `muts` as `(wildtype, site, mutant)`.

        The phenotype is the sum over mutations of the mutant's effect minus
        the effect of the amino acid it replaces. A mutation contributes
        nothing if its site, its starting amino acid, or its mutant amino acid
        has no measured effect.
        """
        total = 0.0
        for wt, site, m in muts:
            site_effects = self.effects.get(site)
            if site_effects is None or wt not in site_effects or m not in site_effects:
                continue
            total += site_effects[m] - site_effects[wt]
        return total
def _random_mutant(seed, candidate_muts, n_muts):
    """Draw `n_muts` mutations at distinct sites from `candidate_muts`.

    The global `random` generator is re-seeded with `seed` first, so a given
    seed, candidate list and `n_muts` always yield the same mutant.
    """
    random.seed(seed)
    used_sites = set()
    drawn = []
    while len(drawn) < n_muts:
        wt, r, m = random.choice(candidate_muts)
        if r in used_sites:
            continue  # at most one mutation per site
        used_sites.add(r)
        drawn.append((wt, r, m))
    return drawn


# Each comparison: (label, clades to score, column holding their mutations,
# mutations to randomize among, number of random sequences per clade).
# NOTE(review): the second label says "relative to BA.2" although its
# mutations are taken relative to BA.2.86; the label is also used as a filter
# key when plotting, so it is reproduced unchanged here.
comparisons = [
    (
        "BA.2.86 relative to BA.2",
        clade_phenotypes.query("clade == 'BA.2.86'"),
        "spike_muts_from_parent",
        gisaid_muts_relative_to["BA.2"],
        nrandom,
    ),
    (
        "BA.2.86-descended clades with new spike mutations relative to BA.2",
        clade_phenotypes[clade_phenotypes["descendant of BA.2.86"]],
        "spike_muts_from_ba_2_86",
        gisaid_muts_relative_to["BA.2.86"],
        nrandom // 10,
    ),
]

records = []
for clade_set, df, mut_col, randomize_muts, nrand in comparisons:
    for phenotype, mut_df in mutation_phenotypes.groupby("phenotype"):
        phenos = PhenotypeAssigner(mut_df)
        for clade, muts in df[["clade", mut_col]].itertuples(index=False):
            actual_pheno = phenos.phenotype(muts)
            nmuts = len(muts)
            # The replicate index is the seed, so for a given clade the i-th
            # random mutant is the same under every phenotype. Randomized
            # values can therefore be combined element-wise across phenotypes.
            randomized_phenotypes = [
                phenos.phenotype(_random_mutant(irandom, randomize_muts, nmuts))
                for irandom in range(nrand)
            ]
            # Fraction of random mutants scoring at least as high as the clade.
            n_as_high = sum(actual_pheno <= r for r in randomized_phenotypes)
            p = n_as_high / nrand
            records.append(
                (clade_set, phenotype, clade, nmuts, actual_pheno, p, randomized_phenotypes)
            )

pheno_changes_df = pd.DataFrame(
    records,
    columns=["clade_set", "phenotype", "clade", "n_mutations", "value", "P", "randomized_values"],
)
Add linear model phenotypes:
# Append one combined phenotype per linear model: a weighted sum of its
# component phenotypes, computed for the actual values and, element-wise, for
# the randomized values.
for linear_model, pheno_weights in linear_models.items():
    value_sums = 0
    randomized_value_sums = None
    for pheno, weight in pheno_weights.items():
        # Rows are taken in stored order, so the sums below rely on every
        # phenotype listing the (clade_set, clade) combinations in the same
        # order.
        pheno_df = pheno_changes_df.query("phenotype == @pheno")
        value_sums += pheno_df["value"].to_numpy() * weight
        if randomized_value_sums is None:
            randomized_value_sums = [weight * numpy.array(v) for v in pheno_df["randomized_values"]]
        else:
            assert len(randomized_value_sums) == len(pheno_df["randomized_values"])
            randomized_value_sums = [
                s + weight * numpy.array(v)
                for (s, v) in zip(randomized_value_sums, pheno_df["randomized_values"])
            ]
    linear_model_pheno_df = (
        pheno_changes_df
        # One row per (clade_set, clade, n_mutations); the list-valued column
        # must be dropped before de-duplicating.
        .drop(columns=["phenotype", "value", "P", "randomized_values"])
        .drop_duplicates()
        .assign(
            phenotype=linear_model,
            value=value_sums,
            randomized_values=randomized_value_sums,
            # Divide by this row's own number of randomized values rather than
            # by `nrandom`: the descendant clades use `nrandom // 10` random
            # sequences, so dividing by `nrandom` would make their P-values
            # 10x too small.
            P=lambda x: x.apply(
                lambda row: (
                    sum(r >= row["value"] for r in row["randomized_values"])
                    / len(row["randomized_values"])
                ),
                axis=1,
            ),
        )
    )
    assert set(linear_model_pheno_df.columns) == set(pheno_changes_df.columns)
    pheno_changes_df = pd.concat([pheno_changes_df, linear_model_pheno_df], ignore_index=True)
Plot actual phenotype versus distribution of randomized phenotypes for BA.2.86 relative to BA.2:
def _p_label(p):
    """Format a P-value for a facet label, bounding it when it is zero."""
    if p == 0:
        return f"< {1 / nrandom}"
    return f"= {p}"


# Rows for the single BA.2.86-versus-BA.2 comparison, with a facet label that
# carries each phenotype's P-value.
is_ba_2_86 = pheno_changes_df["clade_set"] == "BA.2.86 relative to BA.2"
ba_2_86_df = pheno_changes_df[is_ba_2_86]
ba_2_86_df = ba_2_86_df.assign(
    phenotype_P=ba_2_86_df["phenotype"] + " (P " + ba_2_86_df["P"].map(_p_label) + ")"
).drop(columns=["clade_set", "clade", "n_mutations", "P"])

# hacky way to get phenotypes in desired order: sort in reverse as if every
# "(" were preceded by a "z"
pheno_order = sorted(
    pheno_changes_df["phenotype"].unique(),
    key=lambda p: p.replace("(", "z("),
    reverse=True,
)

# Encodings shared by the point and the boxplot; `y` is a constant so each
# facet has a single row.
value_x = alt.X(
    "value",
    title=None,
    axis=alt.Axis(grid=False),
    scale=alt.Scale(zero=False, nice=False, padding=8),
)
constant_y = alt.Y("y:N", title=None, axis=None)
ba_2_86_base = alt.Chart(ba_2_86_df).transform_calculate(y="1").encode(value_x, constant_y)

# Actual BA.2.86 phenotype.
ba_2_86_point = ba_2_86_base.mark_circle(size=100, opacity=1, color="blue")

# Distribution of the randomized phenotypes: flatten the list column and plot
# its entries as `value`.
ba_2_86_boxplot = (
    ba_2_86_base
    .transform_flatten(["randomized_values"])
    .transform_calculate(value=alt.datum["randomized_values"])
    .mark_boxplot(color="gray", opacity=0.7, extent="min-max")
)

# Facet labels arranged in the desired phenotype order.
labels_by_phenotype = ba_2_86_df.set_index("phenotype")
ba_2_86_facet = alt.Facet(
    "phenotype_P",
    header=alt.Header(labelOrient="bottom", labelPadding=2, labelFontSize=11),
    title=None,
    sort=[labels_by_phenotype.at[p, "phenotype_P"] for p in pheno_order],
)

ba_2_86_title = alt.TitleParams(
    "Phenotypes of BA.2.86 vs BA.2 compared to random mutants",
    subtitle=[
        "Blue point is the phenotype of BA.2.86 relative to BA.2.",
        "",
        f"Gray min-max boxplots are phenotypes of {nrandom} sequences with as",
        "many spike mutations relative to BA.2 as BA.2.86 drawn randomly",
        f"from mutations seen in at least {gisaid_min_counts} sequences in GISAID.",
        "",
        "P-values are indicated in x-axis labels for each plot row. All phenotypes",
        "are calculated from mutation effects in XBB.1.5.",
    ],
    dy=15,
    orient="bottom",
)

ba_2_86_chart = (
    (ba_2_86_boxplot + ba_2_86_point)
    .properties(width=370, height=alt.Step(20))
    .facet(ba_2_86_facet, columns=1, spacing=12)
    .resolve_scale(x="independent", y="independent")
    .properties(title=ba_2_86_title)
)

ba_2_86_chart
Plots for BA.2.86 descendant clades:
# Long-form table for the descendant clades: one "actual" and one "randomized"
# entry per randomized value of each clade/phenotype row. Exploding repeats
# each actual value once per randomized value before the melt.
descendants_df = (
    pheno_changes_df[
        pheno_changes_df["clade_set"] == "BA.2.86-descended clades with new spike mutations relative to BA.2"
    ]
    .drop(columns=["clade_set", "P", "n_mutations", "clade"])
    .explode("randomized_values")
    .rename(columns={"value": "actual", "randomized_values": "randomized"})
    .melt(
        id_vars="phenotype",
        value_vars=["actual", "randomized"],
        value_name="value",
        var_name="value_type",
    )
)

# One row of boxplots per phenotype, comparing the actual descendants with the
# random mutants; each phenotype gets its own x-axis scale.
descendants_chart = (
    alt.Chart(descendants_df)
    .encode(
        alt.X(
            "value",
            title=None,
            axis=alt.Axis(grid=False),
            scale=alt.Scale(zero=False, nice=False, padding=8),
        ),
        alt.Y("value_type", title=None, axis=alt.Axis(labelFontSize=11)),
        alt.Row(
            "phenotype",
            header=alt.Header(labelOrient="bottom", labelPadding=2, labelFontSize=11),
            title=None,
            spacing=12,
            sort=pheno_order,
        ),
        alt.Stroke("value_type", scale=alt.Scale(range=["blue", "gray"]), legend=None),
        alt.Fill("value_type", scale=alt.Scale(range=["blue", "gray"]), legend=None),
    )
    .mark_boxplot(outliers=False, median=alt.LineConfig(opacity=1, strokeWidth=2), opacity=0.5, size=10)
    .resolve_scale(x="independent")
    .properties(
        width=370,
        height=alt.Step(14),
        title=alt.TitleParams(
            "Phenotypes of BA.2.86 descendants compared to random mutants",
            subtitle=[
                "The Tukey boxplots show the distribution of phenotypes of all actual",
                # Spelling of "additional" corrected in this user-facing line.
                "BA.2.86 descendants with at least one additional spike mutation relative to ",
                "BA.2.86 (blue), and the distribution of phenotypes of randomly generated sequences ",
                "with as many spike mutations relative to BA.2.86 as the actual descendants (gray).",
                "",
                f"Random mutations are drawn from all mutations seen in at least {gisaid_min_counts} sequences in GISAID.",
                "",
                "All phenotypes are calculated from mutation effects in XBB.1.5.",
            ],
            dy=15,
            orient="bottom",
        ),
    )
)

descendants_chart