Average antibody escape across polyclonal models

This notebook aggregates and averages the antibody escape values from multiple polyclonal models fit to different libraries, replicates, etc.

First, import Python modules:

[1]:
import os
import pickle

import pandas as pd

import polyclonal
from polyclonal.polyclonal import PolyclonalHarmonizeError

import yaml

Get parameterized variables from papermill:

[2]:
# papermill parameters cell (tagged as `parameters`)
# All defaults are None; the injected `# Parameters` cell that follows reassigns them.
antibody = None  # antibody name; also the key of its entry in `polyclonal_config`
escape_avg_method = None  # passed to PolyclonalAverage as `default_avg_to_plot`
polyclonal_config = None  # path to YAML config whose per-antibody entry holds `plot_kwargs`
site_numbering_map = None  # CSV with `reference_site` and `sequential_site` columns
muteffects_csv = None  # CSV of functional effects, or the string "none" to skip them
avg_pickle = None  # output path for the pickled average model
escape_plot = None  # output path for the saved escape plot
avg_escape = None  # output CSV of average escape values
rep_escape = None  # output CSV of per-replicate escape values
selection_groups_dict = None  # maps selection-group name -> metadata incl. `pickle_file`
[3]:
# Parameters
# Run-specific values that reassign the defaults from the parameters cell above.
# Each `selection_groups_dict` entry describes one selection group, including the
# `pickle_file` that holds its fitted polyclonal model(s).
selection_groups_dict = {
    "Lib-1_2022-03-14_thaw-3_2130-1-0114-112_1": {
        "date": "2022-03-14",
        "library": "Lib-1",
        "pickle_file": "results/polyclonal_fits/Lib-1_2022-03-14_thaw-3_2130-1-0114-112_1.pickle",
        "replicate": 1,
        "virus_batch": "thaw-3",
    },
    "Lib-2_2022-03-14_thaw-3_2130-1-0114-112_1": {
        "date": "2022-03-14",
        "library": "Lib-2",
        "pickle_file": "results/polyclonal_fits/Lib-2_2022-03-14_thaw-3_2130-1-0114-112_1.pickle",
        "replicate": 1,
        "virus_batch": "thaw-3",
    },
}
antibody = "2130-1-0114-112"
escape_avg_method = "median"
polyclonal_config = "data/polyclonal_config.yaml"
muteffects_csv = "results/muteffects_functional/muteffects_observed.csv"
site_numbering_map = "results/site_numbering/site_numbering_map.csv"
avg_pickle = "results/antibody_escape/2130-1-0114-112.pickle"
avg_escape = "results/antibody_escape/2130-1-0114-112_avg.csv"
rep_escape = "results/antibody_escape/2130-1-0114-112_rep.csv"
escape_plot = "results/antibody_escape/2130-1-0114-112_escape_plot_unformatted.html"

Convert `selection_groups_dict` into a data frame, load the pickled models for each selection group, and record the number of epitopes in each model:

[4]:
models_df = pd.DataFrame.from_dict(selection_groups_dict, orient="index")
print(f"Averaging the following models for {antibody=}")
display(models_df)


def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file when done.

    Unpickling can execute arbitrary code, so only use this on trusted pickles.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# fail fast, naming exactly which pickle files do not exist
missing_pickles = [path for path in models_df["pickle_file"] if not os.path.isfile(path)]
assert not missing_pickles, f"Missing pickle files: {missing_pickles}"

# Load each pickle into a `model` column. `explode` yields one row per model when a
# pickle holds a list of models (presumably one per number of epitopes -- confirm
# against the fitting step). Then record each model's number of epitopes.
models_df = (
    models_df.assign(model=lambda x: x["pickle_file"].map(_load_pickle))
    .explode("model")
    .drop(columns=["pickle_file"])
    .assign(n_epitopes=lambda x: x["model"].map(lambda m: len(m.epitopes)))
)
Averaging the following models for antibody='2130-1-0114-112'
date library pickle_file replicate virus_batch
Lib-1_2022-03-14_thaw-3_2130-1-0114-112_1 2022-03-14 Lib-1 results/polyclonal_fits/Lib-1_2022-03-14_thaw-... 1 thaw-3
Lib-2_2022-03-14_thaw-3_2130-1-0114-112_1 2022-03-14 Lib-2 results/polyclonal_fits/Lib-2_2022-03-14_thaw-... 1 thaw-3

Now build the average model: start with the maximum number of epitopes, and fall back to successively fewer epitopes only if the models cannot be harmonized with the larger number:

[5]:
# Distinct numbers of epitopes among the loaded models, largest first.
n_epitopes = models_df["n_epitopes"].sort_values(ascending=False).unique()

# Use the largest number of epitopes for which the models harmonize. Stop at the
# first success so that a later attempt with fewer epitopes cannot overwrite it.
avg_model = None
for n in n_epitopes:
    print(f"Trying to harmonize models with {n} epitopes...")
    try:
        avg_model = polyclonal.PolyclonalAverage(
            models_df.query("n_epitopes == @n"),
            default_avg_to_plot=escape_avg_method,
        )
    except PolyclonalHarmonizeError as exc:
        print(f"Harmonization failed with this error:\n{str(exc)}\n\n")
    else:
        print("Successfully harmonized models.")
        break

# Every attempt failed (or there were no models): stop with a clear error rather
# than letting later cells fail on an undefined or stale `avg_model`.
if avg_model is None:
    raise ValueError(
        f"Could not harmonize the models for {antibody=} with any number of "
        f"epitopes in {list(n_epitopes)}"
    )
Trying to harmonize models with 1 epitopes...
Successfully harmonized models.

Look at the correlation of escape values across replicates:

[6]:
# Heatmap of how well escape values correlate between the averaged models.
corr_heatmap = avg_model.mut_escape_corr_heatmap()
corr_heatmap
[6]:

Plot the activities:

[7]:
# Bar plot of the fitted `activity_wt` values (presumably per epitope/replicate).
activity_chart = avg_model.activity_wt_barplot()
activity_chart
[7]:

Plot the escape values:

[8]:
site_map = pd.read_csv(site_numbering_map).rename(columns={"reference_site": "site"})

with open(polyclonal_config) as f:
    antibody_config = yaml.safe_load(f)[antibody]

# Plotting options for this antibody. A missing or empty `plot_kwargs` entry
# falls back to the defaults filled in below instead of raising.
plot_kwargs = antibody_config.get("plot_kwargs") or {}

plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_map.columns:
    # color the site zoom bar by region when the site map provides one
    plot_kwargs["site_zoom_bar_color_col"] = "region"

# always offer a `times_seen` slider, defaulting to 1 unless configured
slider_stats = plot_kwargs.setdefault("addtl_slider_stats", {})
slider_stats.setdefault("times_seen", 1)

# show sequential numbering in tooltips when it differs from reference numbering
if (site_map["sequential_site"] != site_map["site"]).any():
    tooltip_stats = plot_kwargs.setdefault("addtl_tooltip_stats", [])
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")

df_to_merge = [site_map]

# `muteffects_csv` of "none" (or the unset default None) means no functional effects
if muteffects_csv not in (None, "none"):
    muteffects = pd.read_csv(muteffects_csv).rename(
        columns={"reference_site": "site", "effect": "functional effect"}
    )[["site", "mutant", "functional effect"]]
    # default the slider to the minimum observed effect (presumably so that no
    # mutation is hidden when the plot first renders)
    if "functional effect" not in slider_stats:
        slider_stats["functional effect"] = muteffects["functional effect"].min()
    hide_not_filter = plot_kwargs.setdefault("addtl_slider_stats_hide_not_filter", [])
    if "functional effect" not in hide_not_filter:
        hide_not_filter.append("functional effect")
    df_to_merge.append(muteffects)
else:
    # without functional-effect data, drop any configured slider for it
    slider_stats.pop("functional effect", None)

escape_chart = avg_model.mut_escape_plot(
    df_to_merge=df_to_merge,
    **plot_kwargs,
)

print(f"Saving plot to {escape_plot=}")
escape_chart.save(escape_plot)

escape_chart
Saving plot to escape_plot='results/antibody_escape/2130-1-0114-112_escape_plot_unformatted.html'
[8]:

Save the average model to a pickle file:

[9]:
print(f"Saving model to {avg_pickle=}")

# Write the averaged model to `avg_pickle` in binary mode.
with open(avg_pickle, mode="wb") as pickle_fh:
    pickle.dump(avg_model, file=pickle_fh)
Saving model to avg_pickle='results/antibody_escape/2130-1-0114-112.pickle'

Save the average and per-replicate escape values:

[10]:
# Average escape across replicates, rounded to 4 decimals.
print(f"Saving average escape values to {avg_escape=}")
avg_escape_rounded = avg_model.mut_escape_df.round(decimals=4)
avg_escape_rounded.to_csv(avg_escape, index=False)

# Escape values from each individual replicate model, rounded to 4 decimals.
print(f"Saving per-replicate escape values to {rep_escape=}")
rep_escape_rounded = avg_model.mut_escape_df_replicates.round(decimals=4)
rep_escape_rounded.to_csv(rep_escape, index=False)
Saving average escape values to avg_escape='results/antibody_escape/2130-1-0114-112_avg.csv'
Saving per-replicate escape values to rep_escape='results/antibody_escape/2130-1-0114-112_rep.csv'