Validation neutralization assays versus polyclonal fits

Compare neutralization IC50s measured in validation assays for specific mutants with the IC50s predicted by the polyclonal model fits to the deep mutational scanning (DMS) data.

Import Python modules, silence warnings, and define color palettes:

In [1]:
import os
import pickle

import altair as alt

import pandas as pd
import numpy as np

import yaml

from scipy import stats

import warnings
# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides pandas SettingWithCopyWarning and Python
# SyntaxWarning, both of which can flag real bugs; consider a narrower filter.
warnings.simplefilter("ignore")

# Color palettes. The longer `extended_palette` and `figure_palette` start with
# the base `palette`, so the leading colors match across figures.
palette = ['#999999', '#0072B2', '#E69F00', '#F0E442', '#009E73', '#56B4E9', "#D55E00", "#CC79A7"]
extended_palette = palette + ['#9F0162']
figure_palette = extended_palette + ['#8400CD']

# An independent 12-color palette.
long_palette = [
    '#9F0162', '#009F81', '#FF5AAF', '#8400CD', '#008DF9', '#00C2F9',
    '#FFB2FD', '#A40122', '#E20134', '#FF6E3A', '#FFC33B', '#00FCCF',
]

Read the configuration and the validation assay measurements. An empty `aa_substitutions` entry denotes the unmutated (wildtype) virus:

In [2]:
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# Disable NA detection (na_filter is a boolean flag) so empty cells stay as ''
# instead of becoming NaN: an empty `aa_substitutions` entry marks the unmutated
# virus, and later cells select those rows with `aa_substitutions == ''`.
validation_ic50s = pd.read_csv(config["validation_ics"], na_filter=False)

validation_ic50s
Out[2]:
antibody virus aa_substitutions measured IC50 date
0 10-1074 BF520 0.064166 6-24
1 10-1074 BF520 N332L 40.000000 6-24
2 10-1074 BF520 S140D 6.412878 6-24
3 10-1074 TRO11 0.031210 6-24
4 10-1074 TRO11 H330E 1.817125 10-24
5 10-1074 TRO11 D325R 8.618442 10-24
6 10-1074 TRO11 T415Q 0.060501 6-24
7 3BNC117 TRO11 0.256718 6-24
8 3BNC117 TRO11 R304G 0.671586 6-24
9 3BNC117 TRO11 Y318E 0.841764 6-24
10 10-1074 BF520 Q328D 40.000000 6-24
11 10-1074 BF520 E325R 40.000000 6-24
12 10-1074 BF520 H330Y 40.000000 6-24
13 10-1074 BF520 T415Q 18.089759 6-24
14 3BNC117 TRO11 K440D 1.730303 6-24
15 3BNC117 TRO11 N462K 1.389534 6-24
16 3BNC117 TRO11 N462T 1.682281 6-24
17 3BNC117 BF520 0.035509 12-23
18 3BNC117 BF520 T198D 0.059432 12-23
19 3BNC117 BF520 G459D 0.039781 12-23
20 3BNC117 BF520 R476M 0.009145 12-23
21 3BNC117 BF520 T202P 0.047202 12-23
22 3BNC117 BF520 N463S 0.057370 12-23
23 3BNC117 BF520 Q203P 0.019258 12-23
24 3BNC117 TRO11 N279D 0.543656 07-24
25 3BNC117 TRO11 T202P 0.218511 07-24
26 3BNC117 TRO11 N295R 0.201502 07-24
27 10-1074 TRO11 N332L 40.000000 10-24
28 10-1074 TRO11 N332T 40.000000 10-24
29 10-1074 TRO11 H330Y 0.012721 10-24

Now get the IC50s predicted by the averaged polyclonal model fits for each measured variant:

In [3]:
# Directory holding the averaged polyclonal model pickles for each virus. The
# BF520 models are read from the sibling BF520 DMS repository.
MODEL_DIRS = {
    "TRO11": "results/antibody_escape/averages/",
    "BF520": "../HIV_Envelope_BF520_DMS/results/antibody_escape/averages/",
}

validation_vs_prediction = []
for virus, virus_df in validation_ic50s.groupby("virus"):
    # Fail loudly for an unconfigured virus instead of raising a NameError or
    # silently reusing the previous virus's model directory.
    if virus not in MODEL_DIRS:
        raise ValueError(f"no model directory configured for virus {virus!r}")
    virus_data_path = MODEL_DIRS[virus]
    for antibody, antibody_df in virus_df.groupby("antibody"):
        # NOTE: pickle.load can execute arbitrary code; only load trusted model files.
        with open(os.path.join(virus_data_path, f"{antibody}_polyclonal_model.pickle"), "rb") as f:
            model = pickle.load(f)
        # Add predicted IC50 columns (mean/median/std across models, n_models,
        # frac_models) for every variant in antibody_df.
        df = model.icXX(antibody_df)
        # Attach how often each mutation was seen in the DMS data; the unmutated
        # row ('') has no match and gets NaN.
        df = df.merge(
            (model.mut_escape_df
             .rename(columns={'mutation': 'aa_substitutions'})
             [['aa_substitutions', 'times_seen']]),
            how='left',
            on='aa_substitutions',
        )
        validation_vs_prediction.append(df)

validation_vs_prediction = pd.concat(validation_vs_prediction, ignore_index=True)

# NOTE: despite its name, this column is std / mean of the predicted IC50, i.e.
# the coefficient of variation across models, not a standard deviation.
validation_vs_prediction = validation_vs_prediction.assign(standard_deviation=lambda x: x['std_IC50'] / x['mean_IC50'])

validation_vs_prediction
Out[3]:
antibody virus aa_substitutions measured IC50 date mean_IC50 median_IC50 std_IC50 n_models frac_models times_seen standard_deviation
0 10-1074 BF520 0.064166 6-24 4.167799 3.002720 2.537085 4 1.000000 NaN 0.608735
1 10-1074 BF520 E325R 40.000000 6-24 43.950464 43.349033 17.381129 4 1.000000 5.375000 0.395471
2 10-1074 BF520 H330Y 40.000000 6-24 65.848214 55.532010 40.293944 4 1.000000 46.833333 0.611922
3 10-1074 BF520 N332L 40.000000 6-24 63.279655 57.053600 29.738738 4 1.000000 5.250000 0.469957
4 10-1074 BF520 Q328D 40.000000 6-24 31.155681 37.758732 20.077395 4 1.000000 3.500000 0.644422
5 10-1074 BF520 S140D 6.412878 6-24 16.460924 16.640096 4.489344 4 1.000000 6.250000 0.272727
6 10-1074 BF520 T415Q 18.089759 6-24 12.772317 11.465297 9.345341 4 1.000000 10.500000 0.731687
7 3BNC117 BF520 0.035509 12-23 2.182814 1.932129 0.617574 3 1.000000 NaN 0.282926
8 3BNC117 BF520 G459D 0.039781 12-23 3.267495 2.666869 1.680265 3 1.000000 22.666667 0.514236
9 3BNC117 BF520 N463S 0.057370 12-23 3.343259 3.192128 0.661246 3 1.000000 6.000000 0.197785
10 3BNC117 BF520 Q203P 0.019258 12-23 2.523261 2.382497 0.752534 3 1.000000 8.111111 0.298239
11 3BNC117 BF520 R476M 0.009145 12-23 1.703096 1.423270 0.582409 3 1.000000 16.333333 0.341971
12 3BNC117 BF520 T198D 0.059432 12-23 4.637385 4.393343 1.960137 3 1.000000 13.333333 0.422682
13 3BNC117 BF520 T202P 0.047202 12-23 3.260015 2.623912 1.363234 3 1.000000 23.111111 0.418168
14 10-1074 TRO11 0.031210 6-24 2.306572 2.477744 1.236410 3 1.000000 NaN 0.536038
15 10-1074 TRO11 D325R 8.618442 10-24 25.852756 25.852756 10.514158 2 0.666667 5.000000 0.406694
16 10-1074 TRO11 H330E 1.817125 10-24 17.534924 17.555177 3.363874 3 1.000000 7.111111 0.191839
17 10-1074 TRO11 H330Y 0.012721 10-24 2.702517 2.338425 1.958102 3 1.000000 4.944444 0.724548
18 10-1074 TRO11 N332L 40.000000 10-24 109.956031 72.070482 105.751998 3 1.000000 7.555556 0.961766
19 10-1074 TRO11 N332T 40.000000 10-24 125.434097 84.734672 120.491347 3 1.000000 7.805556 0.960595
20 10-1074 TRO11 T415Q 0.060501 6-24 2.969692 2.861367 1.746759 3 1.000000 5.972222 0.588195
21 3BNC117 TRO11 0.256718 6-24 3.309937 3.452367 0.727068 3 1.000000 NaN 0.219662
22 3BNC117 TRO11 K440D 1.730303 6-24 6.632871 7.087263 2.529737 3 1.000000 3.166667 0.381394
23 3BNC117 TRO11 N279D 0.543656 07-24 4.514842 4.123966 1.112567 3 1.000000 3.000000 0.246424
24 3BNC117 TRO11 N295R 0.201502 07-24 4.649580 4.265035 2.241842 3 1.000000 3.333333 0.482160
25 3BNC117 TRO11 N462K 1.389534 6-24 5.918465 7.368855 2.874088 3 1.000000 7.944444 0.485614
26 3BNC117 TRO11 N462T 1.682281 6-24 6.091183 5.372617 1.755754 3 1.000000 2.500000 0.288245
27 3BNC117 TRO11 R304G 0.671586 6-24 6.012934 6.312847 1.255966 3 1.000000 2.833333 0.208877
28 3BNC117 TRO11 T202P 0.218511 07-24 3.001493 2.526030 1.183601 3 1.000000 3.555556 0.394337
29 3BNC117 TRO11 Y318E 0.841764 6-24 6.160541 5.679514 1.869263 3 1.000000 5.166667 0.303425

For each virus and antibody, calculate the squared Pearson correlation coefficient ($R^2$) between the IC50s predicted by our models and the IC50s measured in validation assays. Only single mutants are included; the unmutated virus and any multi-mutants are excluded:

In [4]:
print("Single mutant correlations between DMS predicted and neutralization assay measured IC50s:")
for (virus, antibody), group in validation_vs_prediction.groupby(["virus", "antibody"]):
    # Keep single mutants only: drop the unmutated row ('') and any multi-mutant
    # entries, whose substitutions are space-separated.
    substitutions = group["aa_substitutions"]
    is_single_mutant = (substitutions != "") & ~substitutions.str.contains(" ")
    single_mutants = group.loc[is_single_mutant]

    print(f"{virus}, {antibody}:")
    fit = stats.linregress(
        single_mutants["median_IC50"].astype(float),
        single_mutants["measured IC50"].astype(float),
    )
    # Report R^2, the squared Pearson correlation coefficient.
    print(round(fit.rvalue**2, 3))
Single mutant correlations between DMS predicted and neutralization assay measured IC50s:
BF520, 10-1074:
0.76
BF520, 3BNC117:
0.757
TRO11, 10-1074:
0.971
TRO11, 3BNC117:
0.565

Now plot the results, using the median prediction across the replicate polyclonal fits to the different deep mutational scanning replicates. For 10-1074, a dashed line marks a measured IC50 of 40. These plots are interactive; mouse over points for details:

In [5]:
# Several mutants measure exactly 40, which is presumably the top of the assay's
# concentration range (i.e. censored values). The 10-1074 panels mark this value
# with a dashed rule. TODO confirm the assay ceiling.
MEASURED_IC50_CEILING = 40

# Single mutants plus the unmutated virus (''); multi-mutants are space-separated.
single_mutant_predictions = validation_vs_prediction[
    ~validation_vs_prediction["aa_substitutions"].str.contains(" ")
]

for (virus, antibody), group in single_mutant_predictions.groupby(["virus", "antibody"]):
    # Work on an explicit copy so the cast below is not chained assignment on a slice.
    plot_data = group.copy()
    plot_data["measured IC50"] = plot_data["measured IC50"].astype(float)
    measured = plot_data["measured IC50"]
    predicted = plot_data["median_IC50"]

    corr_chart = (
        alt.Chart(plot_data)
        .encode(
            # Log axes whose domain runs from 0.75x the smallest to 1.25x the largest value.
            x=alt.X(
                "measured IC50",
                scale=alt.Scale(type="log", nice=False,
                                domain=(measured.min() * .75, measured.max() * 1.25)),
            ),
            y=alt.Y(
                "median_IC50",
                title="predicted IC50 from DMS",
                scale=alt.Scale(type="log", nice=False,
                                domain=(predicted.min() * .75, predicted.max() * 1.25)),
            ),
            color=alt.Color("aa_substitutions",
                            title="Amino acid substitutions",
                            scale=alt.Scale(range=figure_palette)),
            # Show float columns to 3 significant figures; check the plotted
            # frame's dtypes because "measured IC50" was cast above.
            tooltip=[
                alt.Tooltip(c, format=".3g") if plot_data[c].dtype == float else c
                for c in plot_data.columns.tolist()
            ],
        )
        .mark_circle(filled=True, size=60, opacity=1)
        .properties(width=150, height=150)
    )

    chart = corr_chart
    if antibody == "10-1074":
        # Same field name as the points so both layers share a single x-axis title.
        ceiling_rule = (
            alt.Chart(pd.DataFrame({"measured IC50": [MEASURED_IC50_CEILING]}))
            .mark_rule(strokeDash=[8, 8])
            .encode(x="measured IC50")
        )
        chart = corr_chart + ceiling_rule
    chart.properties(title=f"{virus}, {antibody}").configure_axis(grid=False).display()

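Now also calculate fold changes in IC50 relative to the unmutated virus (variant IC50 / unmutated IC50), using the median prediction, and report the correlation ($R^2$) between predicted and measured fold changes: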

In [6]:
# Fold change = variant IC50 / unmutated-virus IC50 for the same antibody and
# virus, computed for both the measured and the median predicted IC50s.
predictions = validation_vs_prediction.rename(columns={"median_IC50": "predicted IC50"})
fold_changes = (
    predictions
    [["antibody",
      "virus",
      "aa_substitutions",
      "measured IC50",
      "predicted IC50",
      "times_seen",
      "n_models"]]
    .merge(
        # The unmutated virus is the row with an empty aa_substitutions entry.
        predictions
        .query("aa_substitutions == ''")
        [["antibody", "virus", "measured IC50", "predicted IC50"]],
        on=["antibody", "virus"],
        how="left",
        # Raise if any antibody/virus pair has more than one unmutated row.
        validate="many_to_one",
        suffixes=[" mutant", " unmutated"],
    )
    .assign(
        measured_fold_change=lambda x: x["measured IC50 mutant"] / x["measured IC50 unmutated"],
        predicted_fold_change=lambda x: x["predicted IC50 mutant"] / x["predicted IC50 unmutated"],
    )
)

# Legend order for the substitutions; entries absent from a panel are simply unused.
SUBSTITUTION_ORDER = [
    'wildtype TRO11',
    'wildtype BF520',
    'E325R',
    'D325R',
    'H330Y',
    'N332L',
    'T415Q',
    'S140D',
    'Q328D',
    'H330E',
    'N332T',
    'T202P',
    'T198D',
    'Q203P',
    'N276D',
    'N279D',
    'N295R',
    'R304G',
    'Y318E',
    'K440D',
    'G459D',
    'N462K',
    'N462T',
    'N463S',
]

# Single mutants plus the unmutated virus; multi-mutants are space-separated.
plot_data = fold_changes[~fold_changes["aa_substitutions"].str.contains(" ")]

for (virus, antibody), antibody_df in plot_data.groupby(["virus", "antibody"]):
    # Label the unmutated virus in the legend. Compare with `==`; `is ''` tests
    # object identity and only worked by accident of string interning.
    sub_plot_data = antibody_df.assign(
        aa_substitutions=antibody_df["aa_substitutions"].mask(
            antibody_df["aa_substitutions"] == "", f"wildtype {virus}"
        )
    )
    measured_fc = sub_plot_data["measured_fold_change"]
    predicted_fc = sub_plot_data["predicted_fold_change"]

    fold_change_chart = (
        alt.Chart(sub_plot_data)
        .encode(
            # Log axes whose domain runs from 0.75x the smallest to 1.25x the largest value.
            x=alt.X(
                "measured_fold_change",
                title="measured fold change IC50",
                scale=alt.Scale(type="log", nice=False,
                                domain=(measured_fc.min() * .75, measured_fc.max() * 1.25)),
            ),
            y=alt.Y(
                "predicted_fold_change",
                title="predicted fold change IC50",
                scale=alt.Scale(type="log", nice=False,
                                domain=(predicted_fc.min() * .75, predicted_fc.max() * 1.25)),
            ),
            color=alt.Color("aa_substitutions",
                            title="Amino acid substitutions",
                            scale=alt.Scale(range=figure_palette),
                            sort=SUBSTITUTION_ORDER),
            # Show float columns to 3 significant figures.
            tooltip=[
                alt.Tooltip(c, format=".3g") if sub_plot_data[c].dtype == float else c
                for c in sub_plot_data.columns.tolist()
            ],
        )
        .mark_circle(filled=True, size=100, opacity=1)
        .properties(width=150, height=150)
    )

    # NOTE(review): this fit includes the unmutated virus, whose fold changes are
    # 1 by construction; the earlier single-mutant IC50 correlation excludes it.
    print(f"{virus}, {antibody}:")
    fit = stats.linregress(
        antibody_df["predicted_fold_change"].astype(float),
        antibody_df["measured_fold_change"].astype(float),
    )
    print(f"Predicted fold change correlation (R^2): {round(fit.rvalue**2, 3)}")

    chart = fold_change_chart
    if antibody == "10-1074":
        # Dashed rule at the largest measured fold change. In the current data,
        # this corresponds to the mutants measured at an IC50 of 40.
        max_fold_change_rule = (
            alt.Chart(pd.DataFrame({"measured_fold_change": [measured_fc.max()]}))
            .mark_rule(strokeDash=[8, 8])
            .encode(x="measured_fold_change")
        )
        chart = fold_change_chart + max_fold_change_rule
    chart.properties(title=f"{virus}, {antibody}").configure_axis(grid=False).display()
10-1074:
Predicted fold change correlation (R^2): 0.843
3BNC117:
Predicted fold change correlation (R^2): 0.704
10-1074:
Predicted fold change correlation (R^2): 0.972
3BNC117:
Predicted fold change correlation (R^2): 0.616
In [ ]: