# Fit polyclonal model
Here we fit polyclonal models to the data.
First, import Python modules:
[1]:
import pickle
import altair as alt
import pandas as pd
import polyclonal
import yaml
[2]:
# allow more rows for Altair
# The trailing `;` suppresses display of the returned object. Assigning it to `_`
# would instead shadow IPython's output-history variable and stop `_`/`__`/`___`
# from being updated.
alt.data_transformers.disable_max_rows();
## Read input data
Get parameterized variables from papermill:
[3]:
# papermill parameters cell (tagged as `parameters`)
# These are placeholders only: papermill injects the real values in a
# `# Parameters` cell that runs right after this one.
prob_escape_csv = None  # CSV of per-variant probabilities of escape to read
n_threads = None  # NOTE(review): only printed later in the visible code, not passed to fitting
pickle_file = None  # output path where the fitted models are pickled
antibody = None  # overwritten later from the data's `antibody` column
[4]:
# Parameters
# (injected by papermill at execution time; overrides the placeholder values
# in the `parameters` cell above)
prob_escape_csv = (
    "results/prob_escape/LibA_2022-02-10a_thaw-3_REGN10933_2_prob_escape.csv"
)
pickle_file = "results/polyclonal_fits/LibA_2022-02-10a_thaw-3_REGN10933_2.pickle"
n_threads = 1
Read the probabilities of escape, and filter for only those specified to be retained (adequate counts):
[5]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")
# Only the literal string "nan" is parsed as missing (empty strings stay empty),
# and only rows flagged `retain` upstream (adequate counts) are kept.
prob_escape = pd.read_csv(
    prob_escape_csv,
    na_values="nan",
    keep_default_na=False,
).query("retain")
# Apart from `antibody_count_threshold`, no retained entry may be missing.
assert prob_escape.drop(columns="antibody_count_threshold").notna().to_numpy().all()
Reading probabilities of escape from results/prob_escape/LibA_2022-02-10a_thaw-3_REGN10933_2_prob_escape.csv
Read the rest of the configuration and input data:
[6]:
# get information from config
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

# the data must be for exactly one antibody; its name is taken from the data
antibodies_in_data = prob_escape["antibody"].unique()
assert len(antibodies_in_data) == 1, antibodies_in_data
antibody = antibodies_in_data[0]

# site numbering map, plus the reference site labels ordered by sequential site
site_numbering_map = pd.read_csv(config["site_numbering_map"])
sites_in_sequential_order = site_numbering_map.sort_values("sequential_site")
reference_sites = sites_in_sequential_order["reference_site"].tolist()

# polyclonal configuration for this antibody
with open(config["polyclonal_config"]) as polyclonal_config_file:
    polyclonal_config = yaml.safe_load(polyclonal_config_file)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# show the settings used for this run as `name=repr(value)`
for setting_name, setting_value in [
    ("antibody", antibody),
    ("n_threads", n_threads),
    ("pickle_file", pickle_file),
    ("antibody_config", antibody_config),
]:
    print(f"{setting_name}={setting_value!r}")
antibody='REGN10933'
n_threads=1
pickle_file='results/polyclonal_fits/LibA_2022-02-10a_thaw-3_REGN10933_2.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'avg_type': 'min_magnitude', 'addtl_slider_stats': {'times_seen': 3, 'functional effect': -4}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'icXX_plot_kwargs': {'avg_type': 'min_magnitude', 'addtl_slider_stats': {'times_seen': 3, 'functional effect': -4}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2, 'x': 0.9, 'icXX_col': 'IC90', 'log_fold_change_icXX_col': 'log2 fold change IC90'}, 'max_epitopes': 1, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}, 'epitope_colors': ['green']}
Read spatial distances if relevant:
[7]:
# A missing key and an explicit null in the config both mean "no spatial distances".
spatial_distances_csv = config.get("spatial_distances")
if spatial_distances_csv is None:
    print("No spatial distances")
    spatial_distances = None
else:
    print(f"Reading spatial distances from {spatial_distances_csv}")
    spatial_distances = pd.read_csv(spatial_distances_csv)
    print(f"Read spatial distances for {len(spatial_distances)} residue pairs")
Reading spatial distances from results/spatial_distances/7tov.csv
Read spatial distances for 529935 residue pairs
## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.
Number of variants per concentration:
[8]:
# number of distinct barcodes (variants) at each antibody concentration
n_variants_per_concentration = prob_escape.groupby("antibody_concentration").agg(
    n_variants=("barcode", "nunique")
)
display(n_variants_per_concentration)
| n_variants | |
|---|---|
| antibody_concentration | |
| 0.15 | 3484 |
| 1.39 | 3513 |
| 5.58 | 3526 |
Plot mean probability of escape across all variants with the indicated number of mutations. Note that this plot weights each variant the same in the means regardless of how many barcode counts it has. We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape. Also, note it uses a symlog scale for the y-axis. Mouseover points for values:
[9]:
# Variants with at least this many substitutions are pooled into one group.
max_aa_subs = 4


def _n_subs_label(n_subs):
    """String label for a substitution count; counts >= `max_aa_subs` share one label."""
    if n_subs < max_aa_subs:
        return str(n_subs)
    return f">{max_aa_subs - 1}"


# Mean censored and uncensored escape per concentration and substitution group;
# every variant is weighted equally regardless of its barcode counts.
subs_per_variant = (
    prob_escape["aa_substitutions_reference"]
    .str.split()
    .map(len)
    .clip(upper=max_aa_subs)
    .map(_n_subs_label)
)
grouped_means = (
    prob_escape.assign(n_subs=subs_per_variant)
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
)

# long format: one row per (concentration, group, censoring) combination
mean_prob_escape = grouped_means.rename(
    columns={
        "prob_escape": "censored to [0, 1]",
        "prob_escape_uncensored": "not censored",
    }
).melt(
    id_vars=["antibody_concentration", "n_subs"],
    var_name="censored",
    value_name="probability escape",
)

# float columns get 3-significant-figure tooltips; others are shown as-is
tooltips = [
    alt.Tooltip(col, format=".3g") if mean_prob_escape[col].dtype == float else col
    for col in mean_prob_escape.columns
]

mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .mark_line(point=True, size=0.5)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=tooltips,
    )
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)
mean_prob_escape_chart
[9]:
## Fit polyclonal model
First, get the fitting-related settings from the polyclonal configuration for this antibody:
[10]:
# Read each fitting setting for this antibody and echo it as `name=repr(value)`.
fit_settings = {}
for setting in ("max_epitopes", "fit_kwargs", "min_epitope_activity_to_include"):
    fit_settings[setting] = antibody_config[setting]
    print(f"{setting}={fit_settings[setting]!r}")
max_epitopes = fit_settings["max_epitopes"]
fit_kwargs = fit_settings["fit_kwargs"]
min_epitope_activity_to_include = fit_settings["min_epitope_activity_to_include"]
max_epitopes=1
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2
Fit a model to all the data, and keep adding epitopes until we either reach the maximum specified or any epitope in the new model has wildtype activity at or below `min_epitope_activity_to_include` (in which case the previous model is kept). Note that we fit using the reference-based site-numbering scheme, so results are shown with those numbers:
[11]:
# Fit models with 1, 2, ... epitopes. Each accepted fit is appended to `models`.
# Fitting stops early, keeping the previous model, once any epitope's wildtype
# activity is at or below `min_epitope_activity_to_include`. The first model is
# always accepted, so the selected model is always `models[-1]`.
if max_epitopes < 1:
    # otherwise the loop never runs and `model` below would be undefined, or stale
    # from an earlier kernel run
    raise ValueError(f"need at least one epitope to fit, but {max_epitopes=}")
models = []
for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")
    # create model; columns are renamed to the names `Polyclonal` expects
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
        spatial_distances=spatial_distances,
        # only pass epitope colors when the configuration provides them
        **(
            {"epitope_colors": antibody_config["epitope_colors"]}
            if "epitope_colors" in antibody_config
            else {}
        ),
    )
    # fit model; the optimizer's return value is not needed because the fitted
    # values are read back from `model` below
    model.fit(logfreq=200, **fit_kwargs)
    # display activities
    print("Curve fits for epitopes:")
    display(model.curve_specs_df.round(2))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )
    # stop if any epitope's activity is at/below the threshold, but only once at
    # least one model has already been accepted
    below_threshold = (
        model.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ).any()
    if models and below_threshold:
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        model = models[-1]  # revert to the previous (accepted) model
        break
    models.append(model)
print(f"\nThe selected model has {len(model.epitopes)} epitopes")
Fitting model with n_epitopes=1
#
# Fitting site-level fixed Hill coefficient and non-neutralized frac model.
# Starting optimization of 1009 parameters at Thu Jun 1 13:15:40 2023.
step time_sec loss fit_loss reg_escape reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity reg_hill_coefficient reg_non_neutralized_frac
0 0.066509 2139.3 2129.6 0 0 0 0 0 9.6307 0 0
102 7.686 166.4 124.99 3.6253 0 2.0424 0 0 35.738 0 0
# Successfully finished at Thu Jun 1 13:15:47 2023.
#
# Fitting fixed Hill coefficient and non-neutralized frac model.
# Starting optimization of 3526 parameters at Thu Jun 1 13:15:48 2023.
step time_sec loss fit_loss reg_escape reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity reg_hill_coefficient reg_non_neutralized_frac
0 0.091272 201.49 145.8 17.917 3.2316e-32 2.0424 0 0 35.733 0 0
159 15.288 177.89 120.41 16.327 2.7087 2.6727 0 0 35.775 0 0
# Successfully finished at Thu Jun 1 13:16:03 2023.
#
# Fitting model.
# Starting optimization of 3528 parameters at Thu Jun 1 13:16:03 2023.
step time_sec loss fit_loss reg_escape reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity reg_hill_coefficient reg_non_neutralized_frac
0 0.086802 145.7 120.41 16.327 2.7087 2.6727 0 0 3.5775 0 0
200 17.891 137.58 107.82 16.341 2.8503 3.439 0 0 3.0988 4.0226 0
256 22.855 137.56 107.74 16.353 2.8614 3.4484 0 0 3.0973 4.0614 0
# Successfully finished at Thu Jun 1 13:16:26 2023.
Curve fits for epitopes:
| epitope | activity | hill_coefficient | non_neutralized_frac | |
|---|---|---|---|---|
| 0 | 1 | 3.14 | 1.4 | 0.0 |
Max and mean absolute-value escape at each epitope:
| max_escape | mean_abs_escape | |
|---|---|---|
| epitope | ||
| 1 | 4.9 | 0.1 |
The selected model has 1 epitopes
Plot the neutralization curves against unmutated protein (which reflect the wildtype activities, Hill coefficients, and non-neutralizable fractions):
[12]:
# neutralization curves of the selected model against the unmutated protein
model.curves_plot()
[12]:
Plot of escape values:
[13]:
df_to_merge = site_numbering_map.rename(columns={"reference_site": "site"})
# Build plot settings on copies instead of editing `antibody_config` in place.
# This keeps the cell safe to re-run, and stops edits leaking into other cells
# that read the same (possibly YAML-anchor-shared) nested settings.
plot_kwargs = dict(antibody_config["plot_kwargs"])
plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_numbering_map:
    plot_kwargs["site_zoom_bar_color_col"] = "region"
# slider stats: always offer `times_seen` (default 1); drop "functional effect",
# which is only used for antibody averages
slider_stats = dict(plot_kwargs.get("addtl_slider_stats") or {})
slider_stats.setdefault("times_seen", 1)
slider_stats.pop("functional effect", None)
plot_kwargs["addtl_slider_stats"] = slider_stats
# show sequential site numbers in tooltips when they differ from reference numbers
if (site_numbering_map["sequential_site"] != site_numbering_map["reference_site"]).any():
    tooltip_stats = list(plot_kwargs.get("addtl_tooltip_stats") or [])
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")
    plot_kwargs["addtl_tooltip_stats"] = tooltip_stats
plot_kwargs.pop("avg_type", None)  # if specified for average plots, don't use here
model.mut_escape_plot(df_to_merge=df_to_merge, **plot_kwargs)
[13]:
Plot of ICXX values:
[14]:
df_to_merge = site_numbering_map.rename(columns={"reference_site": "site"})
# Build plot settings on copies instead of editing `antibody_config` in place.
# This keeps the cell safe to re-run, and stops edits leaking into other cells
# that read the same (possibly YAML-anchor-shared) nested settings.
plot_kwargs = dict(antibody_config["icXX_plot_kwargs"])
plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_numbering_map:
    plot_kwargs["site_zoom_bar_color_col"] = "region"
# slider stats: always offer `times_seen` (default 1); drop "functional effect",
# which is only used for antibody averages
slider_stats = dict(plot_kwargs.get("addtl_slider_stats") or {})
slider_stats.setdefault("times_seen", 1)
slider_stats.pop("functional effect", None)
plot_kwargs["addtl_slider_stats"] = slider_stats
# show sequential site numbers in tooltips when they differ from reference numbers
if (site_numbering_map["sequential_site"] != site_numbering_map["reference_site"]).any():
    tooltip_stats = list(plot_kwargs.get("addtl_tooltip_stats") or [])
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")
    plot_kwargs["addtl_tooltip_stats"] = tooltip_stats
plot_kwargs.pop("avg_type", None)  # if specified for average plots, don't use here
model.mut_icXX_plot(df_to_merge=df_to_merge, **plot_kwargs)
[14]:
Pickle and save model:
[15]:
# Persist every accepted model; the selected model is the last entry.
print(f"Saving {len(models)} models to {pickle_file=}")
with open(pickle_file, mode="wb") as pickle_handle:
    pickle.dump(models, pickle_handle)
Saving 1 models to pickle_file='results/polyclonal_fits/LibA_2022-02-10a_thaw-3_REGN10933_2.pickle'