Average antibody escape across polyclonal models

This notebook aggregates and averages the antibody escape values from multiple polyclonal models fit to different libraries, replicates, etc.

First, import Python modules:

[1]:
import os
import pickle

import pandas as pd

import polyclonal
from polyclonal.polyclonal import PolyclonalHarmonizeError

import yaml

Get parameterized variables from papermill:

[2]:
# papermill parameters cell (tagged as `parameters`)
# Every value here is a `None` placeholder; the real values for a run are set in
# the cell that follows (the `# Parameters` cell).
antibody = None  # antibody name; also the key looked up in `polyclonal_config`
escape_avg_method = None  # passed as `default_avg_to_plot` to PolyclonalAverage
polyclonal_config = None  # path to YAML file with per-antibody plot kwargs
site_numbering_map = None  # CSV with `reference_site` and `sequential_site` columns
muteffects_csv = None  # CSV of mutation functional effects, or the string "none"
avg_pickle = None  # output path for the pickled average model
escape_plot = None  # output HTML path for the escape plot
avg_escape = None  # output CSV path for average escape values
rep_escape = None  # output CSV path for per-replicate escape values
icXX_plot = None  # output HTML path for the ICXX plot
avg_icXX = None  # output CSV path for average ICXX values
rep_icXX = None  # output CSV path for per-replicate ICXX values
selection_groups_dict = None  # selection group name -> metadata incl. `pickle_file`
[3]:
# Parameters
# Values for this run. This appears to be papermill's injected-parameters cell,
# so these override the `None` placeholders defined in the cell above.
# Each selection group maps to its metadata; its `pickle_file` is loaded below.
selection_groups_dict = {
    "LibA_2022-03-02_thaw-3_S2M11_1": {
        "date": "2022-03-02",
        "library": "LibA",
        "pickle_file": "results/polyclonal_fits/LibA_2022-03-02_thaw-3_S2M11_1.pickle",
        "replicate": 1,
        "virus_batch": "thaw-3",
    },
    "LibA_2022-03-02_thaw-3_S2M11_2": {
        "date": "2022-03-02",
        "library": "LibA",
        "pickle_file": "results/polyclonal_fits/LibA_2022-03-02_thaw-3_S2M11_2.pickle",
        "replicate": 2,
        "virus_batch": "thaw-3",
    },
}
antibody = "S2M11"
escape_avg_method = "median"
polyclonal_config = "data/polyclonal_config.yaml"
muteffects_csv = "results/muteffects_functional/muteffects_observed.csv"
site_numbering_map = "data/site_numbering_map.csv"
avg_pickle = "results/antibody_escape/S2M11.pickle"
avg_escape = "results/antibody_escape/S2M11_avg.csv"
rep_escape = "results/antibody_escape/S2M11_rep.csv"
escape_plot = "results/antibody_escape/S2M11_escape_plot_unformatted.html"
avg_icXX = "results/antibody_escape/S2M11_icXX_avg.csv"
rep_icXX = "results/antibody_escape/S2M11_icXX_rep.csv"
icXX_plot = "results/antibody_escape/S2M11_icXX_plot_unformatted.html"

Convert `selection_groups_dict` into a data frame, then load the pickled models fit with each number of epitopes for every selection group:

[4]:
models_df = pd.DataFrame.from_dict(selection_groups_dict, orient="index")
print(f"Averaging the following models for {antibody=}")
display(models_df)

# Fail early, naming only the pickle files that are actually missing.
missing_pickles = [f for f in models_df["pickle_file"] if not os.path.isfile(f)]
assert not missing_pickles, f"Missing pickle files: {missing_pickles}"


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards.

    Only use on pipeline-generated files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Convert pickle files into models. `explode` expands a pickle holding a list of
# models into one row per model, and each row records its number of epitopes.
models_df = (
    models_df.assign(model=lambda x: x["pickle_file"].map(_load_pickle))
    .explode("model")
    .drop(columns=["pickle_file"])
    .assign(n_epitopes=lambda x: x["model"].map(lambda m: len(m.epitopes)))
)
Averaging the following models for antibody='S2M11'
date library pickle_file replicate virus_batch
LibA_2022-03-02_thaw-3_S2M11_1 2022-03-02 LibA results/polyclonal_fits/LibA_2022-03-02_thaw-3... 1 thaw-3
LibA_2022-03-02_thaw-3_S2M11_2 2022-03-02 LibA results/polyclonal_fits/LibA_2022-03-02_thaw-3... 2 thaw-3

Now build the average model, starting with the maximum number of epitopes and falling back to fewer epitopes if the models cannot be harmonized at that number:

[5]:
# Try epitope counts from largest to smallest and keep the first set of models
# that harmonizes successfully.
n_epitopes = models_df["n_epitopes"].sort_values(ascending=False).unique()

for n in n_epitopes:
    try:
        print(f"Trying to harmonize models with {n} epitopes...")
        avg_model = polyclonal.PolyclonalAverage(
            models_df.query("n_epitopes == @n"),
            default_avg_to_plot=escape_avg_method,
        )
        print("Successfully harmonized models.")
        break
    except PolyclonalHarmonizeError as exc:
        print(f"Harmonization failed with this error:\n{exc}\n\n")
else:
    # Reached only if no epitope count harmonized (or there were no models).
    # Without this, later cells would raise a confusing NameError, or silently
    # reuse a stale `avg_model` left in the kernel from an earlier run.
    raise RuntimeError(
        f"Could not harmonize models for {antibody=} with any number of "
        f"epitopes tried: {n_epitopes.tolist()}"
    )
Trying to harmonize models with 1 epitopes...
Successfully harmonized models.

Look at the correlation of escape values across replicates:

[6]:
# Heatmap comparing escape values between the averaged replicate models.
escape_corr_heatmap = avg_model.mut_escape_corr_heatmap()
escape_corr_heatmap
[6]:

Plot the neutralization curves against unmutated protein (which reflect the wildtype activities, Hill coefficients, and non-neutralizable fractions):

[7]:
# Neutralization curves of the average model against the unmutated protein.
neut_curves_chart = avg_model.curves_plot()
neut_curves_chart
[7]:

Plot the escape values:

[8]:
# Site numbering table, with `reference_site` renamed to `site` for merging.
site_map = pd.read_csv(site_numbering_map)
site_map = site_map.rename(columns={"reference_site": "site"})

with open(polyclonal_config) as config_fh:
    full_config = yaml.safe_load(config_fh)
antibody_config = full_config[antibody]

# Fill in defaults for escape-plot options the config leaves unset. The dict is
# the one stored in `antibody_config`, so it is updated in place.
plot_kwargs = antibody_config["plot_kwargs"]
plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_map.columns:
    plot_kwargs["site_zoom_bar_color_col"] = "region"
plot_kwargs.setdefault("addtl_slider_stats", {}).setdefault("times_seen", 1)

# Show sequential numbering in tooltips only when it differs from reference.
if (site_map["sequential_site"] != site_map["site"]).any():
    tooltip_stats = plot_kwargs.setdefault("addtl_tooltip_stats", [])
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")

df_to_merge = [site_map]

slider_stats = plot_kwargs["addtl_slider_stats"]
if muteffects_csv == "none":
    # No functional effects to merge, so drop any slider keyed on them.
    slider_stats.pop("functional effect", None)
else:
    muteffects = pd.read_csv(muteffects_csv).rename(
        columns={"reference_site": "site", "effect": "functional effect"}
    )
    muteffects = muteffects[["site", "mutant", "functional effect"]]
    if "functional effect" not in slider_stats:
        slider_stats["functional effect"] = muteffects["functional effect"].min()
    hide_not_filter = plot_kwargs.setdefault("addtl_slider_stats_hide_not_filter", [])
    if "functional effect" not in hide_not_filter:
        hide_not_filter.append("functional effect")
    df_to_merge.append(muteffects)

escape_chart = avg_model.mut_escape_plot(df_to_merge=df_to_merge, **plot_kwargs)

print(f"Saving plot to {escape_plot=}")
escape_chart.save(escape_plot)

escape_chart
Saving plot to escape_plot='results/antibody_escape/S2M11_escape_plot_unformatted.html'
[8]:

Plot the ICXX values:

[9]:
# Fill in defaults for ICXX-plot options the config leaves unset. The dict is
# the one stored in `antibody_config`, so it is updated in place.
icXX_plot_kwargs = antibody_config["icXX_plot_kwargs"]
icXX_plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_map.columns:
    icXX_plot_kwargs["site_zoom_bar_color_col"] = "region"
icXX_plot_kwargs.setdefault("addtl_slider_stats", {}).setdefault("times_seen", 1)

# Show sequential numbering in tooltips only when it differs from reference.
if (site_map["sequential_site"] != site_map["site"]).any():
    icXX_tooltip_stats = icXX_plot_kwargs.setdefault("addtl_tooltip_stats", [])
    if "sequential_site" not in icXX_tooltip_stats:
        icXX_tooltip_stats.append("sequential_site")

icXX_slider_stats = icXX_plot_kwargs["addtl_slider_stats"]
if muteffects_csv == "none":
    # No functional effects to merge, so drop any slider keyed on them.
    icXX_slider_stats.pop("functional effect", None)
else:
    # `muteffects` is the functional-effects table read for the escape plot.
    if "functional effect" not in icXX_slider_stats:
        icXX_slider_stats["functional effect"] = muteffects["functional effect"].min()
    icXX_hide_not_filter = icXX_plot_kwargs.setdefault(
        "addtl_slider_stats_hide_not_filter", []
    )
    if "functional effect" not in icXX_hide_not_filter:
        icXX_hide_not_filter.append("functional effect")

icXX_chart = avg_model.mut_icXX_plot(df_to_merge=df_to_merge, **icXX_plot_kwargs)

print(f"Saving plot to {icXX_plot=}")
icXX_chart.save(icXX_plot)

icXX_chart
Saving plot to icXX_plot='results/antibody_escape/S2M11_icXX_plot_unformatted.html'
[9]:

Save the average model to a pickle file:

[10]:
# Persist the harmonized average model as a pickle.
print(f"Saving model to {avg_pickle=}")

with open(avg_pickle, "wb") as pickle_fh:
    pickle.dump(avg_model, pickle_fh)
Saving model to avg_pickle='results/antibody_escape/S2M11.pickle'

Save the average and per-replicate escape and ICXX values:

[11]:
# Escape values (average, then per replicate), rounded to 4 decimal places.
print(f"Saving average escape values to {avg_escape=}")
avg_escape_df = avg_model.mut_escape_df.round(4)
avg_escape_df.to_csv(avg_escape, index=False)

print(f"Saving per-replicate escape values to {rep_escape=}")
rep_escape_df = avg_model.mut_escape_df_replicates.round(4)
rep_escape_df.to_csv(rep_escape, index=False)

# Arguments for the ICXX tables, taken from the ICXX plot config. The first
# three keys must be present (KeyError otherwise). The optional ones are
# forwarded only when the config sets them.
required_icXX_keys = ("x", "icXX_col", "log_fold_change_icXX_col")
optional_icXX_keys = ("min_c", "max_c", "logbase", "check_wt_icXX")
icXX_kwargs = {key: icXX_plot_kwargs[key] for key in required_icXX_keys}
icXX_kwargs.update(
    (key, icXX_plot_kwargs[key])
    for key in optional_icXX_keys
    if key in icXX_plot_kwargs
)

print(f"Saving average ICXX values to {avg_icXX=}")
avg_icXX_df = avg_model.mut_icXX_df(**icXX_kwargs).round(4)
avg_icXX_df.to_csv(avg_icXX, index=False)

print(f"Saving per-replicate ICXX values to {rep_icXX=}")
rep_icXX_df = avg_model.mut_icXX_df_replicates(**icXX_kwargs).round(4)
rep_icXX_df.to_csv(rep_icXX, index=False)
Saving average escape values to avg_escape='results/antibody_escape/S2M11_avg.csv'
Saving per-replicate escape values to rep_escape='results/antibody_escape/S2M11_rep.csv'
Saving average ICXX values to avg_icXX='results/antibody_escape/S2M11_icXX_avg.csv'
Saving per-replicate ICXX values to rep_icXX='results/antibody_escape/S2M11_icXX_rep.csv'