Fit polyclonal model

Here we fit polyclonal models to the data.

First, import Python modules:

[1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml
[2]:
# Altair refuses to embed datasets larger than 5000 rows by default; disable
# that cap since the escape-probability data plotted below are much larger.
_ = alt.data_transformers.disable_max_rows()

Read input data

Get parameterized variables from papermill:

[3]:
# papermill parameters cell (tagged as `parameters`)
# These `None` placeholders are overwritten by the parameters cell that
# papermill injects immediately after this one at execution time.
prob_escape_csv = None  # CSV of per-variant probabilities of escape
n_threads = None  # number of threads available for fitting
pickle_file = None  # output path for the pickled fit model
antibody = None  # antibody name; re-derived from the data below
[4]:
# Parameters
# (cell injected by papermill at runtime; values override the placeholders
# in the tagged `parameters` cell)
prob_escape_csv = (
    "results/prob_escape/Lib-3_2022-06-22_thaw-1_LyCoV-1404_1_prob_escape.csv"
)
pickle_file = "results/polyclonal_fits/Lib-3_2022-06-22_thaw-1_LyCoV-1404_1.pickle"
n_threads = 2

Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

[5]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

prob_escape = pd.read_csv(
    prob_escape_csv, keep_default_na=False, na_values="nan"
).query("`no-antibody_count` >= no_antibody_count_threshold")
assert prob_escape.notnull().all().all()

Reading probabilities of escape from results/prob_escape/Lib-3_2022-06-22_thaw-1_LyCoV-1404_1_prob_escape.csv

Read the rest of the configuration and input data:

[6]:
# Read the pipeline-wide configuration.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# All rows should pertain to a single antibody; take its name from the data.
antibodies = prob_escape["antibody"].unique()
assert len(antibodies) == 1, antibodies
antibody = antibodies[0]

# Site-numbering map, plus the reference sites ordered by sequential position.
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = (
    site_numbering_map.sort_values("sequential_site")["reference_site"].tolist()
)

# Antibody-specific settings for the `polyclonal` fitting.
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if antibody not in polyclonal_config:
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# Report the variables and settings in use for this fit.
print(f"{antibody=}")
print(f"{n_threads=}")
print(f"{pickle_file=}")
print(f"{antibody_config=}")
antibody='LyCoV-1404'
n_threads=2
pickle_file='results/polyclonal_fits/Lib-3_2022-06-22_thaw-1_LyCoV-1404_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 1, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}

Some summary statistics

Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

[7]:
# Count the unique barcodes (variants) retained at each antibody concentration.
n_variants_per_conc = prob_escape.groupby("antibody_concentration").aggregate(
    n_variants=pd.NamedAgg(column="barcode", aggfunc="nunique")
)
display(n_variants_per_conc)
n_variants
antibody_concentration
0.654 90244
2.616 90244
10.464 90244

Plot mean probability of escape across all variants with the indicated number of mutations. Note that this plot weights each variant the same in the means regardless of how many barcode counts it has. We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape. Also, note it uses a symlog scale for the y-axis. Mouseover points for values:

[8]:
max_aa_subs = 4  # lump variants with >= this many substitutions into one group

# Mean probability of escape at each concentration, split by number of
# amino-acid substitutions, for both censored and uncensored values.
# Each variant is weighted equally regardless of its barcode counts.
mean_prob_escape = (
    prob_escape
    .assign(
        n_subs=lambda df: (
            df["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .clip(upper=max_aa_subs)
            .map(lambda n: f">{max_aa_subs - 1}" if n == max_aa_subs else str(n))
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)
    .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

# show floats to 3 significant digits in tooltips; other columns verbatim
tooltips = [
    alt.Tooltip(col, format=".3g") if mean_prob_escape[col].dtype == float else col
    for col in mean_prob_escape.columns
]

# symlog y-scale so both near-zero and large escape probabilities are visible
mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        tooltip=tooltips,
    )
    .mark_line(point=True, size=0.5)
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart
/fh/fast/bloom_j/software/miniconda3/envs/dms-vep-pipeline/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for col_name, dtype in df.dtypes.iteritems():
[8]:

Fit polyclonal model

First, get the fitting related keyword arguments from the configuration passed by snakemake:

[9]:
# maximum number of epitopes to try fitting
max_epitopes = antibody_config["max_epitopes"]
print(f"{max_epitopes=}")

# keyword arguments (e.g., regularization weights) passed through to the
# model's `fit` method
fit_kwargs = antibody_config["fit_kwargs"]
print(f"{fit_kwargs=}")

# stop adding epitopes once a newly fit epitope's activity is at or below this
min_epitope_activity_to_include = antibody_config["min_epitope_activity_to_include"]
print(f"{min_epitope_activity_to_include=}")
max_epitopes=1
fit_kwargs={'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}
min_epitope_activity_to_include=0.2

Fit a model to all the data, and keep adding epitopes until we either reach the specified maximum or the newest epitope's activity falls to or below the configured minimum. Note that we fit using the reference-based site-numbering scheme, so results are shown with those numbers:

[10]:
# Fit models with successively more epitopes, keeping the last model whose
# epitopes all have sufficient activity. `models` records every model fit.
models = []

for n_epitopes in range(1, max_epitopes + 1):
    print(f"\nFitting model with {n_epitopes=}")

    # create model; rename columns to the names `polyclonal` expects, fitting
    # with the reference-based site numbering
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        sites=reference_sites,
    )

    # fit model; the optimization result returned by `fit` is not needed
    # NOTE(review): the `n_threads` papermill parameter is not passed to
    # `fit` here — confirm whether that is intended
    model.fit(logfreq=200, **fit_kwargs)

    # display activities
    print("Activities of epitopes:")
    display(model.activity_wt_df.round(1))
    print("Max and mean absolute-value escape at each epitope:")
    display(
        model.mut_escape_df.groupby("epitope")
        .aggregate(
            max_escape=pd.NamedAgg("escape", "max"),
            mean_abs_escape=pd.NamedAgg("escape", lambda s: s.abs().mean()),
        )
        .round(1)
    )

    # Stop if any epitope's activity is at/below the threshold, provided we
    # already fit at least one model: record the low-activity model in
    # `models` but select the previous (all-active-epitope) model.
    if len(models) and any(
        model.activity_wt_df["activity"] <= min_epitope_activity_to_include
    ):
        print(f"Stop fitting, epitope has activity <={min_epitope_activity_to_include}")
        models.append(model)
        model = models[-2]  # get previous model
        break
    else:
        models.append(model)

print(f"\nThe selected model has {len(model.epitopes)} epitopes")

Fitting model with n_epitopes=1
# First fitting site-level model.
# Starting optimization of 1200 parameters at Fri Oct  7 15:28:47 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.049735        26098        26097            0            0      0.90499
          200       10.962       1786.9       1756.6       26.909            0       3.3753
          270       14.724       1785.9       1755.6       26.947            0       3.3758
# Successfully finished at Fri Oct  7 15:29:01 2022.
# Starting optimization of 7126 parameters at Fri Oct  7 15:29:02 2022.
         step     time_sec         loss     fit_loss   reg_escape   reg_spread reg_activity
            0     0.070104       3022.5       2869.4        149.7   6.4584e-31       3.3758
          200        16.17       2497.4       2389.7       84.116       20.018       3.5061
          218       17.434       2497.3       2389.6       84.021       20.134       3.5055
# Successfully finished at Fri Oct  7 15:29:19 2022.
Activities of epitopes:
epitope activity
0 1 3.6
Max and mean absolute-value escape at each epitope:
max_escape mean_abs_escape
epitope
1 8.7 0.2

The selected model has 1 epitopes

Epitope activities:

[11]:
model.activity_wt_barplot()
/fh/fast/bloom_j/software/miniconda3/envs/dms-vep-pipeline/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for col_name, dtype in df.dtypes.iteritems():
[11]:

Plot of escape values:

[12]:
# Merge the site-numbering annotations into the plotting data; the plot uses
# reference numbering, so the reference site is the "site" column.
df_to_merge = site_numbering_map.rename(columns={"reference_site": "site"})

plot_kwargs = antibody_config["plot_kwargs"]

# default the plot title to the antibody name
plot_kwargs.setdefault("plot_title", str(antibody))

# color the site zoom bar by region when that annotation is available
if "region" in site_numbering_map:
    plot_kwargs["site_zoom_bar_color_col"] = "region"

# ensure a `times_seen` slider exists, defaulting its threshold to 1
plot_kwargs.setdefault("addtl_slider_stats", {}).setdefault("times_seen", 1)

# the functional-effect slider only applies to antibody averages; drop it here
plot_kwargs["addtl_slider_stats"].pop("functional effect", None)

# when the two numbering schemes differ, also show sequential sites in tooltips
if any(site_numbering_map["sequential_site"] != site_numbering_map["reference_site"]):
    plot_kwargs.setdefault("addtl_tooltip_stats", []).append("sequential_site")

model.mut_escape_plot(df_to_merge=df_to_merge, **plot_kwargs)
/fh/fast/bloom_j/software/miniconda3/envs/dms-vep-pipeline/lib/python3.9/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for col_name, dtype in df.dtypes.iteritems():
[12]:

Pickle and save model:

[13]:
print(f"Saving model to {pickle_file=}")
with open(pickle_file, "wb") as f:
    pickle.dump(model, f)
Saving model to pickle_file='results/polyclonal_fits/Lib-3_2022-06-22_thaw-1_LyCoV-1404_1.pickle'