Make some paper figures¶

The next cell is tagged parameters for papermill parameterization:

In [1]:
# tagged parameters for `papermill`
In [2]:
# Parameters
params = {
    "func_scores_293T-Mxra8-E3E2-B-241028-func-1": "results/func_scores/293T-Mxra8-E3E2-B-241028-func-1_func_scores.csv",
    "func_scores_293T-Mxra8-E3E2-B-241028-func-2": "results/func_scores/293T-Mxra8-E3E2-B-241028-func-2_func_scores.csv",
    "func_scores_293T-TIM1-E3E2-B-241028-func-1": "results/func_scores/293T-TIM1-E3E2-B-241028-func-1_func_scores.csv",
    "func_scores_293T-TIM1-E3E2-B-241028-func-2": "results/func_scores/293T-TIM1-E3E2-B-241028-func-2_func_scores.csv",
    "func_scores_293T-Mxra8-6KE1-B-241028-func-1": "results/func_scores/293T-Mxra8-6KE1-B-241028-func-1_func_scores.csv",
    "func_scores_293T-Mxra8-6KE1-B-241028-func-2": "results/func_scores/293T-Mxra8-6KE1-B-241028-func-2_func_scores.csv",
    "func_scores_293T-TIM1-6KE1-B-241028-func-1": "results/func_scores/293T-TIM1-6KE1-B-241028-func-1_func_scores.csv",
    "func_scores_293T-TIM1-6KE1-B-241028-func-2": "results/func_scores/293T-TIM1-6KE1-B-241028-func-2_func_scores.csv",
    "func_scores_293T-Mxra8-E3E2-A-241113-func-1": "results/func_scores/293T-Mxra8-E3E2-A-241113-func-1_func_scores.csv",
    "func_scores_293T-Mxra8-E3E2-A-241113-func-2": "results/func_scores/293T-Mxra8-E3E2-A-241113-func-2_func_scores.csv",
    "func_scores_293T-TIM1-E3E2-A-241113-func-1": "results/func_scores/293T-TIM1-E3E2-A-241113-func-1_func_scores.csv",
    "func_scores_293T-TIM1-E3E2-A-241113-func-2": "results/func_scores/293T-TIM1-E3E2-A-241113-func-2_func_scores.csv",
    "func_scores_293T-Mxra8-6KE1-A-241113-func-1": "results/func_scores/293T-Mxra8-6KE1-A-241113-func-1_func_scores.csv",
    "func_scores_293T-Mxra8-6KE1-A-241113-func-2": "results/func_scores/293T-Mxra8-6KE1-A-241113-func-2_func_scores.csv",
    "func_scores_293T-TIM1-6KE1-A-241113-func-1": "results/func_scores/293T-TIM1-6KE1-A-241113-func-1_func_scores.csv",
    "func_scores_293T-TIM1-6KE1-A-241113-func-2": "results/func_scores/293T-TIM1-6KE1-A-241113-func-2_func_scores.csv",
    "func_scores_C636-E3E2-B-241122-func-1": "results/func_scores/C636-E3E2-B-241122-func-1_func_scores.csv",
    "func_scores_C636-E3E2-B-241122-func-2": "results/func_scores/C636-E3E2-B-241122-func-2_func_scores.csv",
    "func_scores_C636-6KE1-B-241122-func-1": "results/func_scores/C636-6KE1-B-241122-func-1_func_scores.csv",
    "func_scores_C636-6KE1-B-241122-func-2": "results/func_scores/C636-6KE1-B-241122-func-2_func_scores.csv",
    "func_scores_C636-E3E2-A-241212-func-1": "results/func_scores/C636-E3E2-A-241212-func-1_func_scores.csv",
    "func_scores_C636-E3E2-A-241212-func-2": "results/func_scores/C636-E3E2-A-241212-func-2_func_scores.csv",
    "func_scores_C636-6KE1-A-241212-func-1": "results/func_scores/C636-6KE1-A-241212-func-1_func_scores.csv",
    "func_scores_C636-6KE1-A-241212-func-2": "results/func_scores/C636-6KE1-A-241212-func-2_func_scores.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-E3E2-A-241113-func-1": "results/func_effects/by_selection/293T-Mxra8-E3E2-A-241113-func-1_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-E3E2-A-241113-func-2": "results/func_effects/by_selection/293T-Mxra8-E3E2-A-241113-func-2_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-E3E2-B-241028-func-1": "results/func_effects/by_selection/293T-Mxra8-E3E2-B-241028-func-1_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-E3E2-B-241028-func-2": "results/func_effects/by_selection/293T-Mxra8-E3E2-B-241028-func-2_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-6KE1-A-241113-func-1": "results/func_effects/by_selection/293T-Mxra8-6KE1-A-241113-func-1_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-6KE1-A-241113-func-2": "results/func_effects/by_selection/293T-Mxra8-6KE1-A-241113-func-2_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-6KE1-B-241028-func-1": "results/func_effects/by_selection/293T-Mxra8-6KE1-B-241028-func-1_func_effects.csv",
    "func_effects_293T-Mxra8_entry_293T-Mxra8-6KE1-B-241028-func-2": "results/func_effects/by_selection/293T-Mxra8-6KE1-B-241028-func-2_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-E3E2-A-241113-func-1": "results/func_effects/by_selection/293T-TIM1-E3E2-A-241113-func-1_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-E3E2-A-241113-func-2": "results/func_effects/by_selection/293T-TIM1-E3E2-A-241113-func-2_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-E3E2-B-241028-func-1": "results/func_effects/by_selection/293T-TIM1-E3E2-B-241028-func-1_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-E3E2-B-241028-func-2": "results/func_effects/by_selection/293T-TIM1-E3E2-B-241028-func-2_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-6KE1-A-241113-func-1": "results/func_effects/by_selection/293T-TIM1-6KE1-A-241113-func-1_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-6KE1-A-241113-func-2": "results/func_effects/by_selection/293T-TIM1-6KE1-A-241113-func-2_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-6KE1-B-241028-func-1": "results/func_effects/by_selection/293T-TIM1-6KE1-B-241028-func-1_func_effects.csv",
    "func_effects_293T-TIM1_entry_293T-TIM1-6KE1-B-241028-func-2": "results/func_effects/by_selection/293T-TIM1-6KE1-B-241028-func-2_func_effects.csv",
    "func_effects_C636_entry_C636-E3E2-B-241122-func-1": "results/func_effects/by_selection/C636-E3E2-B-241122-func-1_func_effects.csv",
    "func_effects_C636_entry_C636-E3E2-B-241122-func-2": "results/func_effects/by_selection/C636-E3E2-B-241122-func-2_func_effects.csv",
    "func_effects_C636_entry_C636-6KE1-B-241122-func-1": "results/func_effects/by_selection/C636-6KE1-B-241122-func-1_func_effects.csv",
    "func_effects_C636_entry_C636-6KE1-B-241122-func-2": "results/func_effects/by_selection/C636-6KE1-B-241122-func-2_func_effects.csv",
    "func_effects_C636_entry_C636-E3E2-A-241212-func-1": "results/func_effects/by_selection/C636-E3E2-A-241212-func-1_func_effects.csv",
    "func_effects_C636_entry_C636-E3E2-A-241212-func-2": "results/func_effects/by_selection/C636-E3E2-A-241212-func-2_func_effects.csv",
    "func_effects_C636_entry_C636-6KE1-A-241212-func-1": "results/func_effects/by_selection/C636-6KE1-A-241212-func-1_func_effects.csv",
    "func_effects_C636_entry_C636-6KE1-A-241212-func-2": "results/func_effects/by_selection/C636-6KE1-A-241212-func-2_func_effects.csv",
    "codon_variants": "results/variants/codon_variants.csv",
    "annotated_mut_summary": "results/annotated_summary_csvs/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding_annotated.csv",
    "annotated_site_summary": "results/annotated_summary_csvs/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding_annotated_site_means.csv",
    "mxra8_binding_effects": "results/summaries/binding_mouse_vs_human_Mxra8.csv",
    "mxra8_validation_curves": "manual_analyses/experimental_data/RVP.mutants.neutralization.by.soluble.mouse.Mxra8.csv",
    "chikv_titers": "manual_analyses/experimental_data/CHIKV_mutant_titers.csv",
    "rvp_titers": "manual_analyses/experimental_data/RVP_mutant_titers.csv",
    "mxra8_validation_svg": "results/paper_figures/mxra8_validation.svg",
}
min_times_seen = 2
cell_entry_clip_lower = -6

Python imports:

In [3]:
import copy
import itertools
import os

import altair as alt

import dms_variants.codonvarianttable

import matplotlib.cm
import matplotlib.pyplot as plt

import neutcurve

import numpy

import pandas as pd

import scipy.stats

_ = alt.data_transformers.disable_max_rows()

Plot validation assays with RVP for Mxra8 binding¶

These plot correlations of experimental data for RVP validations and DMS.

First, read the validation data for the curves, fit curves, and plot them:

In [4]:
mxra8_curves = (
    pd.read_csv(params["mxra8_validation_curves"])
    .rename(columns={"Concentration (ug/mL)": "concentration"})
    .set_index("concentration")
)

# get and sort variants
mxra8_validation_variants = [
    c for c in mxra8_curves.columns if not c.startswith("Unnamed:")
]
assert mxra8_validation_variants[0] == "unmutated"
mxra8_validation_variants = ["unmutated"] + [
    tup[1]
    for tup in sorted(((v.split("-")[0], int(v.split("-")[1][1: -1]), v[-1]), v) for v in mxra8_validation_variants[1:])
]

mxra8_curves_tidy = []
for variant in mxra8_validation_variants:
    variant_index = mxra8_curves.columns.tolist().index(variant)
    cols = mxra8_curves.columns[variant_index: variant_index + 3].tolist()
    mxra8_curves_tidy.append(
        mxra8_curves[cols].assign(variant=variant).rename(
            columns={c: f"replicate {i + 1}" for (i, c) in enumerate(cols)}
        )
    )
mxra8_curves_tidy = (
    pd.concat(mxra8_curves_tidy)
    .reset_index()
    .melt(
        id_vars=["concentration", "variant"],
        var_name="replicate",
        value_name="fraction_infectivity",
    )
    .query("concentration != 0")
    .assign(
        serum="serum",
        variant=lambda x: pd.Categorical(x["variant"], mxra8_validation_variants, ordered=True),
    )
    .sort_values("variant")
)

mxra8_curve_fits = neutcurve.CurveFits(
    mxra8_curves_tidy,
    conc_col="concentration",
    fracinf_col="fraction_infectivity",
    virus_col="variant",
    replicate_col="replicate",
    fixslope=[0.5, 2],
)

nviruses = len(mxra8_validation_variants)
variant_colors = ["black"] + [
    tuple(c) for c in matplotlib.cm.get_cmap("tab20", nviruses - 1)(numpy.arange(nviruses - 1))
]
variant_markers = [
    m for m in matplotlib.markers.MarkerStyle.markers.keys()
    if m not in {".", ",", "1", "2", "3", "4"}
][: nviruses]

_, mxra8_curve_fits_axes = mxra8_curve_fits.plotSera(
    titles=[""],
    max_viruses_per_subplot=nviruses,
    colors=variant_colors,
    markers=variant_markers,
    ylabel="fraction infectivity",
    xlabel="mouse Mxra8 (ug/ml)",
    heightscale=0.9,
)
_ = mxra8_curve_fits_axes[0][0].set_yticks([0, 0.5, 1])
/loc/scratch/41548938/ipykernel_30809/2402822684.py:53: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
/fh/fast/bloom_j/computational_notebooks/jbloom/2025/CHIKV-181-25-E-DMS/.snakemake/conda/e937d0191e9e0d4fae02985bc3f90aba_/lib/python3.12/site-packages/neutcurve/hillcurve.py:1177: RuntimeWarning: invalid value encountered in power
No description has been provided for this image

Plot IC50s versus DMS data:

In [5]:
mxra8_ic50s = (
    pd.read_csv(params["mxra8_binding_effects"])
    .assign(
        mutation=lambda x: x["region"] + "-" + x["wildtype"] + x["site"].str.split("(").str[0] + x["mutant"],
    )
    .rename(columns={"mutation": "virus", "binding to mouse Mxra8": "DMS_effect"})
    [["virus", "DMS_effect"]]
    .merge(
        mxra8_curve_fits.fitParams()[["virus", "ic50"]], on="virus", how="right", validate="1:1",
    )
    .assign(
        DMS_effect=lambda x: x["DMS_effect"].where(x["virus"] != "unmutated", 0),
        inv_ic50=lambda x: 1 / x["ic50"],
    )
)

r = mxra8_ic50s["DMS_effect"].corr(numpy.log(mxra8_ic50s["inv_ic50"]))

mxra8_validation_fig, mxra8_validation_axes = plt.subplots(
    1, 2, figsize=(7.5, 2.8), gridspec_kw={"width_ratios": [1.5, 1], "wspace": 0.4}
)

for v, x, y, c, m in zip(
    mxra8_ic50s["virus"], mxra8_ic50s["DMS_effect"], mxra8_ic50s["inv_ic50"], variant_colors, variant_markers
):
    mxra8_validation_axes[1].scatter([x], [y], color=c, marker=m, label=v)
mxra8_validation_axes[1].set_yscale("log")
mxra8_validation_axes[1].set_xlabel("DMS effect on mouse Mxra8 binding")
mxra8_validation_axes[1].set_ylabel("effect in RVP validation (1/IC50)")
mxra8_validation_axes[1].legend(
    loc="center left",
    bbox_to_anchor=(1.05, 0.5),
    borderaxespad=0.0,
    labelspacing=0.1,
)
mxra8_validation_axes[1].text(
    0.05, 0.9, f"R = {r:.2f}", transform=mxra8_validation_axes[1].transAxes, fontsize=12
)
    

# move the neutralization curves to the first subplot axis
def move_artists(ax_from, ax_to):
    """Move plot content from ax_from to ax_to."""
    frame_parts = (
        *ax_from.spines.values(), ax_from.patch,
        ax_from.xaxis, ax_from.yaxis,
        ax_from.title, ax_from._left_title, ax_from._right_title,
    )
    
    # Collect artists to move (can't modify during iteration)
    artists_to_move = [art for art in ax_from.get_children() if art not in frame_parts]
    
    # Remove from source and add to destination
    for art in artists_to_move:
        art.remove()  # This properly cleans up the figure association
        # Update transform if needed
        if art.get_transform() == ax_from.transData:
            art.set_transform(ax_to.transData)
        # Add to destination
        ax_to.add_artist(art)
    
    # copy limits / labels, etc.
    ax_to.set_xscale(ax_from.get_xscale())
    ax_to.set_yscale(ax_from.get_yscale())
    ax_to.set_xlim(ax_from.get_xlim())
    ax_to.set_ylim(ax_from.get_ylim())
    ax_to.set_xlabel(ax_from.get_xlabel())
    ax_to.set_ylabel(ax_from.get_ylabel())

move_artists(mxra8_curve_fits_axes[0][0], mxra8_validation_axes[0])
mxra8_validation_axes[0].set_xlabel("mouse Mxra8 concentration (ug/ml)")
mxra8_validation_axes[0].set_ylabel("fraction infectivity")
mxra8_validation_axes[0].set_yticks([0, 0.5, 1])

# final formatting for figure
_ = mxra8_validation_fig.suptitle(
    "alphavirus reporter virus particle (RVP) validations for mouse Mxra8 binding",
    fontsize=13,
    x=0.59,
)

print(f"Saving to {params['mxra8_validation_svg']=}")
mxra8_validation_fig.savefig(params["mxra8_validation_svg"], bbox_inches="tight")
Saving to params['mxra8_validation_svg']='results/paper_figures/mxra8_validation.svg'
No description has been provided for this image

Distribution of variant functional scores¶

We want the distribution of variant functional scores, similar to as made by this notebook but including both E3-E2 and 6K-E1 fragments and not including deletions (since those are rare in our libraries).

First read all the functional scores, ignoring those for deletions since those are rare and not reported in paper:

In [6]:
def classify_selection(sel):
    sels = {"293T-Mxra8": "293T-Mxra8", "293T-TIM1":"293T-TIM1", "C636":"C6/36"}
    assert sum(s in sel for s in sels) == 1, sel
    for s in sels:
        if s in sel:
            label = [sels[s]]
    libs = {"-A-": "library A", "-B-": "library B"}
    assert sum(l in sel for l in libs) == 1, sel
    for l in libs:
        if l in sel:
            label.append(libs[l])
    return " ".join(label)


func_scores_df = (
    pd.concat(
        [
            pd.read_csv(f).assign(selection=sel)
            for (sel, f) in params.items() if sel.startswith("func_scores_")
        ]
    )
    .assign(selection=lambda x: x["selection"].map(classify_selection))
    .pipe(dms_variants.codonvarianttable.CodonVariantTable.classifyVariants)
    .query("variant_class != 'deletion'")
)

(
    func_scores_df
    .groupby(["selection", "variant_class"])
    .aggregate(n_variants=pd.NamedAgg("barcode", "count"))
)
Out[6]:
n_variants
selection variant_class
293T-Mxra8 library A 1 nonsynonymous 142665
>1 nonsynonymous 79744
stop 5778
synonymous 4245
wildtype 29406
293T-Mxra8 library B 1 nonsynonymous 132342
>1 nonsynonymous 73868
stop 5855
synonymous 4048
wildtype 27337
293T-TIM1 library A 1 nonsynonymous 142805
>1 nonsynonymous 79859
stop 5789
synonymous 4237
wildtype 29411
293T-TIM1 library B 1 nonsynonymous 132443
>1 nonsynonymous 73929
stop 5863
synonymous 4048
wildtype 27349
C6/36 library A 1 nonsynonymous 140546
>1 nonsynonymous 78516
stop 5670
synonymous 4153
wildtype 28913
C6/36 library B 1 nonsynonymous 131133
>1 nonsynonymous 73139
stop 5797
synonymous 3999
wildtype 27056

Make the plot:

In [7]:
def ridgeplot(df):
    variant_classes = list(
        reversed(
            [
                c
                for c in [
                    "wildtype",
                    "synonymous",
                    "1 nonsynonymous",
                    ">1 nonsynonymous",
                    "deletion",
                    "stop",
                ]
                if c in set(df["variant_class"])
            ]
        )
    )

    assert set(df["variant_class"]) == set(variant_classes)

    # get smoothed distribution of functional scores
    bins = numpy.linspace(
        df["func_score"].min(),
        df["func_score"].max(),
        num=50,
    )
    smoothed_dist = pd.concat(
        [
            pd.DataFrame(
                {
                    "selection": sel,
                    "variant_class": var,
                    "func_score": bins,
                    "count": scipy.stats.gaussian_kde(df["func_score"])(bins),
                    "mean_func_score": df["func_score"].mean(),
                    "number of variants": len(df),
                }
            )
            for (sel, var), df in df.groupby(["selection", "variant_class"])
        ]
    )

    # assign y / y2 for plotting
    facet_overlap = 0.7  # maximal facet overlap
    max_count = (smoothed_dist["count"]).max()
    smoothed_dist = smoothed_dist.assign(
        y=lambda x: x["variant_class"].map(lambda v: variant_classes.index(v)),
        y2=lambda x: x["y"] + x["count"] / max_count / facet_overlap,
    )

    # ridgeline plot, based on this but using y / y2 rather than row:
    # https://altair-viz.github.io/gallery/ridgeline_plot.html
    ridgeline_chart = (
        alt.Chart(smoothed_dist)
        .encode(
            x=alt.X(
                "func_score", title="functional score for cell entry", scale=alt.Scale(nice=False)
            ),
            y=alt.Y(
                "y",
                scale=alt.Scale(nice=False),
                title=None,
                axis=alt.Axis(
                    ticks=False,
                    domain=False,
                    # set manual labels https://stackoverflow.com/a/64106056
                    values=[v + 0.5 for v in range(len(variant_classes))],
                    labelExpr=f"{str(variant_classes)}[round(datum.value - 0.5)]",
                ),
            ),
            y2=alt.Y2("y2"),
            fill=alt.Fill(
                "mean_func_score:Q",
                title="mean functional score",
                legend=alt.Legend(direction="horizontal"),
                scale=alt.Scale(scheme="yellowgreenblue"),
            ),
            facet=alt.Facet(
                "selection",
                columns=2,
                title=None,
                header=alt.Header(
                    labelFontWeight="bold",
                    labelPadding=0,
                ),
            ),
            tooltip=[
                "selection",
                "variant_class",
                alt.Tooltip(
                    "mean_func_score", format=".2f", title="mean functional score"
                ),
            ],
        )
        .mark_area(
            interpolate="monotone",
            smooth=True,
            fillOpacity=0.8,
            stroke="lightgray",
            strokeWidth=0.5,
        )
        .configure_view(stroke=None)
        .configure_axis(grid=False)
        .properties(width=180, height=22 * len(variant_classes))
    )

    ridgeline_chart = ridgeline_chart.properties(
        autosize=alt.AutoSizeParams(resize=True),
    )

    return ridgeline_chart


func_scores_chart = ridgeplot(func_scores_df)

func_scores_chart
Out[7]:

Number of mutations per variant¶

Plot number of mutations per variant as here but with better axis limits and labels, only keeping the combined libraries:

In [8]:
codon_variants = pd.read_csv(params["codon_variants"])

lib_rename = {"E_A": "library A", "E_B": "library B"}

display(
    codon_variants
    .groupby("library")
    .aggregate(n_variants=pd.NamedAgg("barcode", "count"))
)

max_muts = 4

print(f"Only keeping {lib_rename=}, and clipping at {max_muts=}")

nmuts_dist = (
    codon_variants
    .query("library in @lib_rename")
    .assign(
        library=lambda x: x["library"].map(lib_rename),
        n_muts=lambda x: x["n_aa_substitutions"].clip(upper=max_muts),
    )
    .groupby(["library", "n_muts"], as_index=False)
    .aggregate(n_variants=pd.NamedAgg("barcode", "count"))
    .assign(
        n_muts_label=lambda x: x["n_muts"].map(
            lambda n: str(n) if n < max_muts else f">{n - 1}"
        )
    )
)

nmuts_dist_chart = (
    alt.Chart(nmuts_dist)
    .encode(
        alt.X(
            "n_muts_label",
            sort=alt.SortField("n_muts"),
            title="number amino-acid mutations",
            axis=alt.Axis(labelAngle=0),
        ),
        alt.Y("n_variants", title="number of barcoded variants"),
        alt.Column(
            "library",
            title=None,
            header=alt.Header(labelFontStyle="bold", labelFontSize=11, labelPadding=0),
        ),
    )
    .mark_bar(color="black")
    .configure_axis(grid=False)
    .properties(height=150, width=150)
)

nmuts_dist_chart
n_variants
library
6KE1_A 69359
6KE1_B 62446
E3E2_A 65426
E3E2_B 62495
E_A 134757
E_B 124915
Only keeping lib_rename={'E_A': 'library A', 'E_B': 'library B'}, and clipping at max_muts=4
Out[8]:

Plot correlation among estimated functional effects on each cell¶

Plot similar to here but aggregating across libraries:

In [9]:
func_effects_by_lib = (
    pd.concat(
        [
            pd.read_csv(f).assign(
                name=name,
                cell=name.split("_")[2],
                replicate=name.split("-")[-1],
                library="library A" if "-A-" in name else "library B",
                region="E3E2" if "E3E2" in name else "6KE1",
            )
            for (name, f) in params.items() if name.startswith("func_effects_")
        ],
        ignore_index=True,
    )
    .query("wildtype != mutant")
    .query("times_seen >= @min_times_seen")
    .assign(
        mut_in_region=lambda x: x.apply(
            lambda r: (
                ("6K" in r["site"] or "E1" in r["site"]) and (r["region"] == "6KE1")
                or ("E2" in r["site"] or "E2" in r["site"]) and (r["region"] == "E3E2")
            ),
            axis=1,
        ),
    )
    .query("mut_in_region")
    .assign(
        library_replicate=lambda x: x["library"] + ", replicate " + x["replicate"],
        cell=lambda x: x["cell"].map(
            {
                "C636": "entry in C6/36 cells",
                "293T-Mxra8": "entry in 293T-Mxra8 cells",
                "293T-TIM1": "entry in 293T-TIM1 cells",
            }
        ),
    )
    [["site", "wildtype", "mutant", "functional_effect", "library_replicate", "cell"]]
)

func_effects_by_lib
Out[9]:
site wildtype mutant functional_effect library_replicate cell
1346 1(E2) S - -0.98420 library A, replicate 1 entry in 293T-Mxra8 cells
1347 1(E2) S A 0.24400 library A, replicate 1 entry in 293T-Mxra8 cells
1348 1(E2) S C -1.91600 library A, replicate 1 entry in 293T-Mxra8 cells
1349 1(E2) S D -0.17790 library A, replicate 1 entry in 293T-Mxra8 cells
1350 1(E2) S E -0.71330 library A, replicate 1 entry in 293T-Mxra8 cells
... ... ... ... ... ... ...
306391 439(E1) H Y -0.37950 library A, replicate 2 entry in C6/36 cells
306393 440(E1) * K 0.06405 library A, replicate 2 entry in C6/36 cells
306394 440(E1) * Q -2.44100 library A, replicate 2 entry in C6/36 cells
306395 440(E1) * W -0.49800 library A, replicate 2 entry in C6/36 cells
306396 440(E1) * Y -0.23450 library A, replicate 2 entry in C6/36 cells

210846 rows × 6 columns

Now make the plot:

In [10]:
for cell, cell_df in func_effects_by_lib.groupby("cell"):
    corr_panels = []
    for sel1, sel2 in itertools.combinations(sorted(cell_df["library_replicate"].unique()), 2):
        corr_df = (
            cell_df.query("library_replicate == @sel1")[["functional_effect", "site", "mutant"]]
            .rename(columns={"functional_effect": sel1})
            .merge(
                cell_df.query("library_replicate == @sel2")[["functional_effect", "site", "mutant"]].rename(
                    columns={"functional_effect": sel2}
                ),
                validate="one_to_one",
            )
            .drop(columns=["site", "mutant"])
        )
        n = len(corr_df)
        r = corr_df[[sel1, sel2]].corr().values[1, 0]
        corr_panels.append(
            alt.Chart(corr_df)
            .encode(
                alt.X(sel1, scale=alt.Scale(nice=False, padding=4)),
                alt.Y(sel2, scale=alt.Scale(nice=False, padding=4)),
            )
            .mark_circle(color="black", size=25, opacity=0.15)
            .properties(
                width=135,
                height=135,
                title=alt.TitleParams(
                    f"R = {r:.2f}, N = {n}", fontSize=11, fontWeight="normal", dy=2
                ),
            )
        )
    
    corr_chart = alt.hconcat(*corr_panels, spacing=5).configure_axis(grid=False).properties(
        title=alt.TitleParams(f"correlations among libraries and replicates for mutations effects on {cell}", anchor="middle")
    )
    display(corr_chart)

Summary of cell entry effects¶

First read the effects, apply a floor, and compute the site mean and number of effective amino acids at each site:

In [11]:
cell_entry_types = {
    "entry in 293T_Mxra8 cells": "293T-Mxra8",
    "entry in C636 cells": "C6/36",
    "entry in 293T_TIM1 cells": "293T-TIM1",
}

cell_entry_mut = (
    pd.read_csv(params["annotated_mut_summary"])
    .rename(columns=cell_entry_types)
    .melt(
        id_vars=["site", "sequential_site", "wildtype", "mutant", "region", "domain", "contacts"],
        value_vars=cell_entry_types.values(),
        var_name="cell",
        value_name="cell entry",
    )
    .assign(**{"cell entry": lambda x: x["cell entry"].clip(lower=cell_entry_clip_lower)})
)

cell_entry_site = (
    cell_entry_mut
    .assign(
        p_unnorm=lambda x: numpy.exp(x["cell entry"]),
        p=lambda x: x["p_unnorm"] / x.groupby(["cell", "site"])["p_unnorm"].transform("sum"),
        entry_no_wt=lambda x: x["cell entry"].where(x["wildtype"] != x["mutant"], pd.NA),
    )
    .groupby(
        ["cell", "site", "sequential_site", "wildtype", "region", "domain", "contacts"],
        dropna=False,
        as_index=False,
    )
    .aggregate(
        n_effective=pd.NamedAgg("p", lambda p: numpy.exp((-p * numpy.log(p)).sum())),
        site_mean_entry=pd.NamedAgg("entry_no_wt", lambda s: s.dropna().mean()),
    )
    .sort_values("sequential_site")
)

Plot various site statistics showing mutational tolerance. Plot as the mean effect of mutations and the number of effective amino acids assigning probability weights proportional to the exponential of the cell entry effect:

In [12]:
metrics = {"n_effective": "effective amino acids", "site_mean_entry": "mean cell entry effect"}

title_prefix = {
    "n_effective": "mutational tolerance at each site for entry in ", 
    "site_mean_entry": "effects of mutations at each site for entry in ",
}

width = 600

for cell, metric in itertools.product(cell_entry_site["cell"].unique(), metrics):

    df = cell_entry_site.query("cell == @cell")[["site", "sequential_site", "region", metric]]
    
    site_entry_lines = (
        alt.Chart(df)
        .encode(
            alt.X(
                "site",
                sort=alt.SortField("sequential_site"),
                axis=alt.Axis(values=cell_entry_site["site"].iloc[90::242], labelAngle=0),
            ),
            alt.Y(
                metric,
                title=metrics[metric],
                scale=alt.Scale(nice=False, padding=2),
                axis=alt.Axis(grid=False)
            ),
            alt.Color("region", scale=alt.Scale(range=["gray", "darkgreen", "darkblue", "teal"])),
            tooltip=df.columns.tolist(),
        )
        .mark_rect(opacity=1)
        .properties(height=110, width=width)
    )

    text_df = df.groupby("region", as_index=False).aggregate(x=pd.NamedAgg("sequential_site", "mean"))

    text_chart = (
        alt.Chart(text_df)
        .encode(
            alt.X(
                "x:Q",
                title=None,
                scale=alt.Scale(domain=(df["sequential_site"].min(), df["sequential_site"].max())),
                axis=None,
            ),
            alt.Text("region"),
            alt.Color("region", legend=None),
        )
        .mark_text(fontWeight="bold", fontSize=12)
        .properties(width=width, height=1)
    )

    site_entry_chart = (
        alt.vconcat(text_chart, site_entry_lines, spacing=0)
        .configure_view(stroke=None)
        .properties(title=alt.TitleParams(title_prefix[metric] + cell + " cells", anchor="middle", fontWeight="normal"))
    )

    display(site_entry_chart)

Plot showing distribution of mean-mutation effects at each site. Do this for different proteins (E3, E2, 6K, E1) and different subdomains, as well as different types of Mxra8 contact sites:

In [13]:
for domain_col, cell in itertools.product(
    ["domain", "contacts", "region"],
    cell_entry_site["cell"].unique(),
):
    df = (
        cell_entry_site
        [cell_entry_site[domain_col].notnull()]
        .query("cell == @cell")
        [["site_mean_entry", "sequential_site", domain_col]]
    )

    if domain_col == "contacts":
        df = df.query("contacts != 'no'")

    dist_chart = (
        alt.Chart(df)
        .encode(
            alt.X(
                "site_mean_entry",
                axis=alt.Axis(titleFontWeight="normal", titleFontSize=12),
                scale=alt.Scale(nice=False, padding=4),
                title="mean effect of mutations at each site",
            ),
            alt.Y(
                domain_col,
                sort=df[domain_col].unique(),
                title=None,
                axis=alt.Axis(labelFontSize=12),
            ),
        )
        .mark_boxplot(
            outliers=False,
            box=alt.MarkConfig(fill="darkgray", stroke="black"),
            median=alt.MarkConfig(color="black"),
            rule=alt.MarkConfig(color='black'),
            size=13,
        )
        .properties(
            width=170,
            height=alt.Step(17),
            title=f"entry in {cell} cells"
        )
        .configure_axis(grid=False)
    )
    
    display(dist_chart)

Validation assays of RVPs and CHIKV mutants with loss-of-function mutations for entry in a cell line¶

First plot reporter virus particle titers:

In [14]:
# color shades of blue if expected to be attenauted in Mxra8, orange if mosquitoe
viruses = {
    'unmutated': "black",
    'R119K': "#6A5ACD",
    'K120D': "#367588",
    'I121E': "#4682B4",
    'R119K-K120D': "#1E90FF",
    'R119K-I121E': "#4169E1",
    'K120D-I121E': "#545499",
    'A157S': "#FFB300",
    'Q158T': "#ED9121",
    'Q158V': "#FF7518",
    'A157S-Q158T': "#E9963A",
    'A157S-Q158V': "#CC5500",
}

# shapes by number of mutations
shapes = {}
for v in viruses:
    if v == "unmutated":
        shapes[v] = "square"
    elif "-" in v:
        shapes[v] = "diamond"
    else:
        shapes[v] = "circle"

rvp_titers = (
    pd.read_csv(params["rvp_titers"])
    .groupby(["cell", "virus"], as_index=False)
    .aggregate(
        titer=pd.NamedAgg("titer", "mean"),
        titer_sem=pd.NamedAgg("titer", "sem"),
    )
    .assign(
        virus=lambda x: x["virus"].str.replace("E2_", "").str.replace("_", "-"),
        titer_upper=lambda x: x["titer"] + x["titer_sem"],
        titer_lower=lambda x: x["titer"] - x["titer_sem"],
    )
)

assert set(rvp_titers["virus"]) == set(viruses)

cells = ["293T-TIM1", "293T-Mxra8", "Huh7", "C6/36", "Aag2"]

rvp_titers_base = (
    alt.Chart(rvp_titers)
    .encode(
        alt.Y("virus", title=None, sort=list(viruses)),
        alt.Color(
            "virus",
            scale=alt.Scale(domain=list(viruses), range=list(viruses.values())),
            legend=alt.Legend(title="E2 mutations", titleFontSize=11)
        ),
    )
    .properties(width=200, height=alt.Step(10))
)

rvp_titers_points = (
    rvp_titers_base
    .encode(
        alt.X(
            "titer",
            title="titer relative to unmutated",
            scale=alt.Scale(type="log", nice=False, padding=6),
            axis=alt.Axis(values=[1, 0.1, 0.01, 0.001], grid=False, titleFontWeight="normal", titleFontSize=11)
        ),
        alt.Shape("virus", scale=alt.Scale(domain=list(shapes), range=list(shapes.values()))),
    )
    .mark_point(filled=True, fillOpacity=1, size=60)
)

rvp_titers_errorbars = (
    rvp_titers_base
    .encode(
        alt.X("titer_upper"),
        alt.X2("titer_lower"),
    )
    .mark_rule(size=1)
)

rvp_titers_chart = (
    (rvp_titers_errorbars + rvp_titers_points)
    .facet(
        row=alt.Facet(
            "cell",
            title=None,
            header=alt.Header(orient="right", labelPadding=2, labelFontSize=11, labelFontWeight="bold"),
            sort=cells,
        ),
        spacing=3,
    )
    .resolve_scale(y="independent")
    .properties(
        title=alt.TitleParams("alphavirus RVP titers", anchor="middle", dx=-30),
    )
)

rvp_titers_chart
Out[14]:

Now plot CHIKV titers:

In [15]:
chikv_titers = (
    pd.read_csv(params["chikv_titers"])
    .groupby(["cell", "virus", "timepoint"], as_index=False)
    .aggregate(
        titer=pd.NamedAgg("titer", "mean"),
        titer_sem=pd.NamedAgg("titer", "sem"),
    )
    .assign(
        virus=lambda x: x["virus"].str.replace("E2_", "").str.replace("_", "-"),
        titer_upper=lambda x: x["titer"] + x["titer_sem"],
        # can't plot infinitely low error bars, so clip them
        titer_lower=lambda x: (x["titer"] - x["titer_sem"]).clip(lower=x["titer"].min() / 2),
    )
)

chikv_titers_base = (
    alt.Chart(chikv_titers)
    .encode(
        alt.X(
            "timepoint",
            title="hours post-infection",
            axis=alt.Axis(grid=False, titleFontWeight="normal", titleFontSize=11),
        ),
        alt.Color(
            "virus",
            scale=alt.Scale(domain=list(viruses), range=list(viruses.values())),
            legend=None,
        ),
    )
    .properties(width=200, height=134)
)

chikv_titers_points = (
    chikv_titers_base
    .encode(
        alt.Y(
            "titer",
            title="titer (RLU)",
            scale=alt.Scale(type="log", nice=False, padding=8),
            axis=alt.Axis(
                grid=False,
                titleFontWeight="normal",
                titleFontSize=11,
                format="~e",
                values=[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7],
            ),
        ),
        alt.Shape("virus", scale=alt.Scale(domain=list(shapes), range=list(shapes.values()))),
    )
    .mark_point(filled=True, fillOpacity=1, size=60)
)

chikv_titers_lines = (
    chikv_titers_base
    .encode(
        alt.Y("titer", scale=alt.Scale(type="log"), axis=alt.Axis(grid=False)),
    )
    .mark_line(size=1)
)

chikv_titers_errorbars = (
    chikv_titers_base
    .encode(
        alt.Y("titer_lower", scale=alt.Scale(type="log"), axis=alt.Axis(grid=False)),
        alt.Y2("titer_upper"),
    )
    .mark_rule(size=1)
)

chikv_titers_chart = (
    (chikv_titers_lines + chikv_titers_errorbars + chikv_titers_points)
    .facet(
        row=alt.Facet(
            "cell",
            title=None,
            header=alt.Header(orient="right", labelPadding=2, labelFontSize=11, labelFontWeight="bold"),
            sort=cells,
        ),
        spacing=3,
    )
    .resolve_scale(y="independent")
    .properties(
        title=alt.TitleParams("authentic CHIKV titers", anchor="middle", dx=20),
    )
)

chikv_titers_chart
Out[15]:
In [16]:
(
    alt.hconcat(rvp_titers_chart, chikv_titers_chart, spacing=30)
    .resolve_scale(color="independent", shape="independent")
    .configure_legend(offset=30, strokeColor="gray", padding=7)
    .configure_view(stroke="gray")
)
Out[16]:
In [ ]: