Average antibody escape across polyclonal models

This notebook aggregates and averages the antibody escape values computed by multiple `polyclonal` models, each fit to a different library, replicate, etc.

First, import Python modules:

[1]:
import os
import pickle

import pandas as pd

import polyclonal

import yaml

Get parameterized variables from papermill:

[2]:
# papermill parameters cell (tagged as `parameters`)
# Every default is None; the actual values for a run are assigned in the
# `# Parameters` cell that follows, which overrides these names.
antibody = None  # key into the polyclonal config YAML; also the default plot title
escape_avg_method = None  # passed as `default_avg_to_plot` to `PolyclonalAverage`
polyclonal_config = None  # path to the YAML config with per-antibody `plot_kwargs`
site_numbering_map = None  # CSV with `reference_site` and `sequential_site` columns
muteffects_csv = None  # CSV of functional effects (`reference_site`, `mutant`, `effect`), or "none"
avg_pickle = None  # output path for the pickled average model
escape_plot = None  # output path for the saved escape chart
avg_escape = None  # output CSV of averaged escape values
rep_escape = None  # output CSV of per-replicate escape values
selection_groups_dict = None  # selection name -> dict with a `pickle_file` path plus metadata
[3]:
# Parameters
# Values for this run (presumably injected by papermill); they override the
# `None` defaults assigned in the parameters cell above.
selection_groups_dict = {
    # One entry per fitted model to average. `pickle_file` is loaded into a
    # `model` column later; the other keys stay as descriptive columns.
    "Lib-2_2022-06-22_thaw-1_NTD_5-7_1": {
        "date": "2022-06-22",
        "library": "Lib-2",
        "pickle_file": "results/polyclonal_fits/Lib-2_2022-06-22_thaw-1_NTD_5-7_1.pickle",
        "replicate": 1,
        "virus_batch": "thaw-1",
    },
    "Lib-3_2022-06-22_thaw-1_NTD_5-7_1": {
        "date": "2022-06-22",
        "library": "Lib-3",
        "pickle_file": "results/polyclonal_fits/Lib-3_2022-06-22_thaw-1_NTD_5-7_1.pickle",
        "replicate": 1,
        "virus_batch": "thaw-1",
    },
}
antibody = "NTD_5-7"
escape_avg_method = "median"
polyclonal_config = "data/polyclonal_config.yaml"
muteffects_csv = "results/muteffects_functional/muteffects_observed.csv"
site_numbering_map = "results/site_numbering/site_numbering_map.csv"
avg_pickle = "results/antibody_escape/NTD_5-7.pickle"
avg_escape = "results/antibody_escape/NTD_5-7_avg.csv"
rep_escape = "results/antibody_escape/NTD_5-7_rep.csv"
escape_plot = "results/antibody_escape/NTD_5-7_escape_plot_unformatted.html"

Convert `selection_groups_dict` into a data frame and load all of the pickled models:

[4]:
# One row per selection; the dict keys become the index and the inner keys
# (date, library, pickle_file, ...) become columns.
models_df = pd.DataFrame.from_dict(selection_groups_dict, orient="index")
print(f"Averaging the following models for {antibody=}")
display(models_df)


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, always closing the file.

    Unpickling can execute arbitrary code, so only point this at pickles
    produced by this pipeline (trusted input).
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# convert pickle files into models, failing early with the offending paths
missing_pickles = [p for p in models_df["pickle_file"] if not os.path.isfile(p)]
assert not missing_pickles, f"missing model pickle files: {missing_pickles}"
models_df = models_df.assign(
    model=lambda x: x["pickle_file"].map(_load_pickle)
).drop(columns="pickle_file")
Averaging the following models for antibody='NTD_5-7'
date library pickle_file replicate virus_batch
Lib-2_2022-06-22_thaw-1_NTD_5-7_1 2022-06-22 Lib-2 results/polyclonal_fits/Lib-2_2022-06-22_thaw-... 1 thaw-1
Lib-3_2022-06-22_thaw-1_NTD_5-7_1 2022-06-22 Lib-3 results/polyclonal_fits/Lib-3_2022-06-22_thaw-... 1 thaw-1

Now build the average model:

[5]:
# Combine the per-selection models (the `model` column of `models_df`) into a
# single averaged model; `escape_avg_method` (here "median") is passed as the
# average to plot by default.
avg_model = polyclonal.PolyclonalAverage(models_df, default_avg_to_plot=escape_avg_method)

Look at the correlation of escape values across replicates:

[6]:
escape_corr_chart = avg_model.mut_escape_corr_heatmap()
escape_corr_chart
/fh/fast/bloom_j/computational_notebooks/jbloom/2022/SARS-CoV-2_Omicron_BA.1_spike_DMS_mAbs/.snakemake/conda/a73ad69c741ab6d85d86c04aa086afcd_/lib/python3.10/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for col_name, dtype in df.dtypes.iteritems():
[6]:

Plot the activities:

[7]:
activity_chart = avg_model.activity_wt_barplot()
activity_chart
/fh/fast/bloom_j/computational_notebooks/jbloom/2022/SARS-CoV-2_Omicron_BA.1_spike_DMS_mAbs/.snakemake/conda/a73ad69c741ab6d85d86c04aa086afcd_/lib/python3.10/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for col_name, dtype in df.dtypes.iteritems():
[7]:

Plot the escape values:

[8]:
site_map = pd.read_csv(site_numbering_map).rename(columns={"reference_site": "site"})

with open(polyclonal_config) as f:
    antibody_config = yaml.safe_load(f)[antibody]

# Work on copies so the loaded config is never modified in place. A missing or
# null `plot_kwargs` entry is treated as "no extra options".
plot_kwargs = dict(antibody_config.get("plot_kwargs") or {})
plot_kwargs.setdefault("plot_title", str(antibody))
if "region" in site_map.columns:
    # color the site-zoom bar by region whenever the site map provides it
    plot_kwargs["site_zoom_bar_color_col"] = "region"

# Always offer a `times_seen` slider, keeping any configured default value.
# Using `setdefault` on a copy preserves the configured slider order.
slider_stats = dict(plot_kwargs.get("addtl_slider_stats") or {})
slider_stats.setdefault("times_seen", 1)
plot_kwargs["addtl_slider_stats"] = slider_stats

# If reference numbering differs from sequential numbering at any site,
# also show the sequential site number in the tooltips.
if site_map["sequential_site"].ne(site_map["site"]).any():
    tooltip_stats = list(plot_kwargs.get("addtl_tooltip_stats") or [])
    if "sequential_site" not in tooltip_stats:
        tooltip_stats.append("sequential_site")
    plot_kwargs["addtl_tooltip_stats"] = tooltip_stats

df_to_merge = [site_map]

# `muteffects_csv` is either a CSV path or "none". The parameters-cell
# default of None is handled the same way instead of crashing in `read_csv`.
if muteffects_csv not in (None, "none"):
    muteffects = pd.read_csv(muteffects_csv).rename(
        columns={"reference_site": "site", "effect": "functional effect"}
    )[["site", "mutant", "functional effect"]]
    # Default the functional-effect slider to the observed minimum
    # (presumably so no mutations are filtered out initially).
    if "functional effect" not in slider_stats:
        slider_stats["functional effect"] = muteffects["functional effect"].min()
    df_to_merge.append(muteffects)
else:
    # there is no functional-effect column to filter on, so drop that slider
    slider_stats.pop("functional effect", None)

escape_chart = avg_model.mut_escape_plot(
    df_to_merge=df_to_merge,
    **plot_kwargs,
)

print(f"Saving plot to {escape_plot=}")
escape_chart.save(escape_plot)

escape_chart
/fh/fast/bloom_j/computational_notebooks/jbloom/2022/SARS-CoV-2_Omicron_BA.1_spike_DMS_mAbs/.snakemake/conda/a73ad69c741ab6d85d86c04aa086afcd_/lib/python3.10/site-packages/altair/utils/core.py:317: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.
  for col_name, dtype in df.dtypes.iteritems():
Saving plot to escape_plot='results/antibody_escape/NTD_5-7_escape_plot_unformatted.html'
[8]:
show line on site plot
site escape statistic
floor escape at zero

Save the average model to a pickle file:

[9]:
print(f"Saving model to {avg_pickle=}")

# serialize the averaged model to the `avg_pickle` path
with open(avg_pickle, mode="wb") as pickle_out:
    pickle.dump(avg_model, pickle_out)
Saving model to avg_pickle='results/antibody_escape/NTD_5-7.pickle'

Save the average model escape values:

[10]:
# averaged escape values, rounded to 4 decimals
print(f"Saving average escape values to {avg_escape=}")
avg_escape_df = avg_model.mut_escape_df.round(4)
avg_escape_df.to_csv(avg_escape, index=False)

# escape values for each individual replicate model, rounded the same way
print(f"Saving per-replicate escape values to {rep_escape=}")
rep_escape_df = avg_model.mut_escape_df_replicates.round(4)
rep_escape_df.to_csv(rep_escape, index=False)
Saving average escape values to avg_escape='results/antibody_escape/NTD_5-7_avg.csv'
Saving per-replicate escape values to rep_escape='results/antibody_escape/NTD_5-7_rep.csv'
/fh/fast/bloom_j/computational_notebooks/jbloom/2022/SARS-CoV-2_Omicron_BA.1_spike_DMS_mAbs/.snakemake/conda/a73ad69c741ab6d85d86c04aa086afcd_/lib/python3.10/site-packages/pandas/core/internals/blocks.py:2323: RuntimeWarning: invalid value encountered in cast
  values = values.astype(str)