Compare Cell Entry Effects

In this notebook, we'll investigate whether the same mutation affects cell entry differently across three cell types (293T-Mxra8, 293T-TIM1, and C6/36).

Although Mxra8 serves as a receptor and TIM1 as an entry factor for CHIKV in humans, the mosquito receptor remains unknown. By identifying sites where mutations affect cell entry differently in mosquito cells (C6/36) than in human cells (293T-Mxra8 and 293T-TIM1), we may uncover sites involved in binding to the unknown mosquito receptor.

In [1]:
import itertools
import os

import altair as alt
import dmslogo.colorschemes
import numpy
import pandas as pd
import polyclonal.alphabets
import scipy.spatial.distance

# Disable Altair's default 5,000-row limit so the full dataset is embedded in the charts.
# For alternatives with very large data, see https://altair-viz.github.io/user_guide/large_datasets.html
_ = alt.data_transformers.disable_max_rows()

Get the input parameters

The notebook is designed to be parameterized by papermill. The next cell is tagged parameters, so its default values are overwritten by the values that papermill injects in the cell after it:

In [2]:
# This cell is tagged parameters, so the values defined here will be overwritten
# by the `papermill` parameterization.

# CSV with filtered data
mut_effects_csv = "../results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"
addtl_site_annotations_csv = "../data/addtl_site_annotations.csv"
mxra8_dists_csv = "../results/mxra8_distances/mxra8_dists.csv"

# cells and their names in input file
cells = {"C6/36": "C636", "293T-Mxra8": "293T_Mxra8", "293T-TIM1": "293T_TIM1"}

# floor mutation effects at this value when computing differences and for display
floor_mut_effects = -5

# output files
site_diffs_csv = "../results/compare_cell_entry/site_diffs.csv"
mut_scatter_chart = "../results/compare_cell_entry/compare_cell_entry_scatter.html"
site_zoom_chart = "../results/compare_cell_entry/compare_cell_entry_site_zoom.html"
In [3]:
# Parameters
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636", "293T-TIM1": "293T_TIM1"}
floor_mut_effects = -5
mut_effects_csv = "results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"
addtl_site_annotations_csv = "data/addtl_site_annotations.csv"
mxra8_dists_csv = "results/mxra8_distances/mxra8_dists.csv"
site_diffs_csv = "results/compare_cell_entry/site_diffs.csv"
mut_scatter_chart = "results/compare_cell_entry/compare_cell_entry_scatter.html"
site_zoom_chart = "results/compare_cell_entry/compare_cell_entry_site_zoom.html"

Read the data

For this analysis, we'll need the effects of mutations on cell entry in each cell line.

These values have already been filtered by QC metrics:

In [4]:
print(f"Reading mutation effects from {mut_effects_csv=}")
mut_effects = pd.read_csv(mut_effects_csv)

mut_effects
Reading mutation effects from mut_effects_csv='results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv'
Out[4]:
         site  wildtype  mutant  entry in 293T_Mxra8 cells  entry in C636 cells  entry in 293T_TIM1 cells  binding to mouse Mxra8  sequential_site  region
0      -1(E3)         M       I                    -7.5430              -7.5110                  -7.50300                     NaN                1      E3
1      -1(E3)         M       M                     0.0000               0.0000                   0.00000                 0.00000                1      E3
2      -1(E3)         M       T                    -7.5640              -7.5410                  -7.57800                     NaN                1      E3
3       1(6K)         A       A                     0.0000               0.0000                   0.00000                 0.00000              489      6K
4       1(6K)         A       C                     0.1782               0.0333                   0.02943                -0.03924              489      6K
...       ...       ...     ...                        ...                  ...                       ...                     ...              ...     ...
19457  99(E2)         H       S                    -7.2680              -7.1330                  -6.61400                     NaN              164      E2
19458  99(E2)         H       T                    -7.4900              -6.8320                  -6.99300                     NaN              164      E2
19459  99(E2)         H       V                    -7.5340              -7.4920                  -7.41100                     NaN              164      E2
19460  99(E2)         H       W                    -7.0080              -6.4290                  -5.61500                     NaN              164      E2
19461  99(E2)         H       Y                    -3.1700              -1.4590                  -1.22700                -0.42020              164      E2

19462 rows × 9 columns

Get the data into tidy (long) format:

In [5]:
col_to_cell = {f"entry in {label} cells": cell for (cell, label) in cells.items()}

assert set(col_to_cell).issubset(mut_effects.columns), f"{col_to_cell=}, {mut_effects.columns=}"

mut_effects_tidy = (
    mut_effects.rename(columns=col_to_cell)
    .melt(
        id_vars=["site", "sequential_site", "wildtype", "mutant", "region"],
        value_vars=col_to_cell.values(),
        var_name="cell",
        value_name="effect",
    )
    .sort_values("sequential_site")
)

mut_effects_tidy
Out[5]:
          site  sequential_site  wildtype  mutant  region        cell   effect
0       -1(E3)                1         M       I      E3  293T-Mxra8  -7.5430
1       -1(E3)                1         M       M      E3  293T-Mxra8   0.0000
2       -1(E3)                1         M       T      E3  293T-Mxra8  -7.5640
19462   -1(E3)                1         M       I      E3       C6/36  -7.5110
19463   -1(E3)                1         M       M      E3       C6/36   0.0000
...        ...              ...       ...     ...     ...         ...      ...
35513  439(E1)              988         H       R      E1       C6/36   0.3585
54975  439(E1)              988         H       R      E1   293T-TIM1   0.1848
54976  439(E1)              988         H       S      E1   293T-TIM1  -0.8648
54966  439(E1)              988         H       G      E1   293T-TIM1  -0.5203
54967  439(E1)              988         H       H      E1   293T-TIM1   0.0000

58386 rows × 7 columns
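To make the wide-to-long reshaping above concrete, here is a minimal sketch on a hypothetical two-mutation, two-cell table (the values and the names wide and tidy are illustrative, not taken from the real data). After renaming the "entry in <label> cells" columns to cell names and melting, each (mutation, cell) pair becomes one row with a cell column and an effect column.

```python
import pandas as pd

# Hypothetical toy data in the same wide layout as the input CSV:
# one row per mutation, one "entry in <label> cells" column per cell.
wide = pd.DataFrame(
    {
        "site": ["1(E3)", "1(E3)"],
        "sequential_site": [2, 2],
        "wildtype": ["S", "S"],
        "mutant": ["A", "S"],
        "region": ["E3", "E3"],
        "entry in C636 cells": [-0.5, 0.0],
        "entry in 293T_Mxra8 cells": [-1.2, 0.0],
    }
)

# Same kind of mapping as in the notebook: display name -> label used in column names.
cells = {"C6/36": "C636", "293T-Mxra8": "293T_Mxra8"}
col_to_cell = {f"entry in {label} cells": cell for cell, label in cells.items()}

tidy = (
    wide.rename(columns=col_to_cell)
    .melt(
        id_vars=["site", "sequential_site", "wildtype", "mutant", "region"],
        value_vars=list(col_to_cell.values()),
        var_name="cell",
        value_name="effect",
    )
    .sort_values("sequential_site")
)
print(tidy)
```

Two mutations measured in two cells give four rows, which is why the tidy table above has three times as many rows as the input (one per cell line).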

Scatter plots of cell entry for each cell

How does the same mutation affect entry in each cell line? For every pair of cell lines, we plot the effect of each mutation in one cell line against its effect in the other to check for global differences.

In [6]:
def plot_mut_scatter_chart(
    data,
    condition,
    value,
    groupby=("site", "mutant", "wildtype", "sequential_site"),
    color=None,
    label_suffix="",
    init_floor_value=-6,
):
    """
    Make an Altair scatter plot comparing mutant-level values for each condition.

    Parameters
    ----------
    data : pd.DataFrame
        The long-form data to plot
    condition : str
        The column containing the condition labels (e.g., the cell names 293T-Mxra8, 293T-TIM1, C6/36)
    value : str
        The column containing the values to compare between conditions
    groupby : sequence of str
        The columns to group the data on (e.g., ['site', 'mutant', 'wildtype'])
    color : str or None
        The column used to color the points and add an interactive legend
    label_suffix : str
        Label suffixed to x- and y-axis labels.
    init_floor_value : float or None
        Initial value for floor slider for values.

    Returns
    -------
    alt.HConcatChart
        The Altair chart object (pairwise scatter plots concatenated horizontally)
    """
    # copy so that appending the color column never mutates the caller's list or the default
    groupby = list(groupby)

    if 'mutant' not in groupby or 'site' not in groupby:
        raise ValueError("groupby must contain 'mutant' and 'site'")

    missing_cols = [col for col in [condition, value] + groupby if col not in data.columns]
    if missing_cols:
        raise ValueError(f"Columns are missing from the data: {missing_cols}")

    if color is not None:
        if color not in data.columns:
            raise ValueError(f"Color column '{color}' not found in data")
        groupby.append(color)
    
    conditions = data[condition].unique()

    # pivot the data
    data_wide = (
        data
        .pivot_table(index=groupby, columns=condition, values=value)
        .reset_index()
    )

    tooltips = []
    for col in groupby:
        tooltips.append(alt.Tooltip(f'{col}:N'))
    for col in conditions:
        tooltips.append(alt.Tooltip(f'{col}:Q', format=".2f"))

    brush = alt.selection_interval()
    
    mut_selection = alt.selection_point(on="mouseover", fields=groupby, empty=False)

    # use pandas min/max, which skip NaN (Python's builtin min/max do not handle NaN reliably)
    value_min = data[value].min()
    value_max = data[value].max()
    min_value_slider = alt.param(
        name="min_value_slider",
        bind=alt.binding_range(
            min=value_min,
            max=value_max,
            name="floor values at this number",
        ),
        value=(
            max(init_floor_value, value_min)
            if init_floor_value is not None
            else value_min
        ),
    )

    base = (
        alt.Chart(data_wide)
        .add_params(mut_selection, brush, min_value_slider)
        .transform_filter(brush)
    )

    scatters = []
    for condition_a, condition_b in itertools.combinations(conditions, 2):
        # Base data for the scatter plot
        scatter = base.transform_filter(
            f'isValid(datum["{condition_a}"]) && isValid(datum["{condition_b}"])'
        ).transform_calculate(
            condition_a_floored=f'max(datum["{condition_a}"], min_value_slider)',
            condition_b_floored=f'max(datum["{condition_b}"], min_value_slider)',
        ).encode(
            x=alt.X(
                "condition_a_floored:Q",
                title=condition_a + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
            y=alt.Y(
                "condition_b_floored:Q",
                title=condition_b + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
        ).properties(
            title=alt.TitleParams(f'{condition_a} vs {condition_b}', fontSize=16),
            width=250,
            height=250
        )
        # Background points to show the full range of data when brushing
        background = scatter.mark_point(
            filled=True,
            size=25,
            color='lightgray',
            opacity=0.3,
        )
        # Foreground points have tooltips and respond to brushing (and legend selection)
        if color is not None:
            selection = alt.selection_point(fields=[color], bind='legend')
            foreground = scatter.mark_point(
                filled=True,
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                color=alt.Color(color, type='nominal').scale(domain=data[color].unique()),
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(80), alt.value(40)),
                tooltip=tooltips,
            ).add_params(
                selection
            ).transform_filter(selection)
        else:
            foreground = scatter.mark_point(
                filled=True,
                color='steelblue',
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                tooltip=tooltips,
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
            )

        scatters.append((background + foreground))

    chart = alt.hconcat(*scatters).configure_axis(grid=False).configure_legend(
        titleFontSize=14, labelFontSize=14
    )

    return chart
In [7]:
mut_scatter = plot_mut_scatter_chart(
    mut_effects_tidy,
    "cell",
    "effect", 
    color="region",
    label_suffix=" cell entry",
    init_floor_value=floor_mut_effects,
)

print(f"Saving chart to {mut_scatter_chart=}")
os.makedirs(os.path.dirname(mut_scatter_chart), exist_ok=True)
mut_scatter.save(mut_scatter_chart)

mut_scatter
Saving chart to mut_scatter_chart='results/compare_cell_entry/compare_cell_entry_scatter.html'
Out[7]:
  • Mouse over points to see a tooltip with information about that mutation.
  • Click and drag over points to show only those mutations.
  • Click on a region in the legend to show only mutations in that region.
  • Use the slider to floor values at a minimum plot value.
  • Double-click on the plot or legend to reset the plot.

Colored points show the active selection; gray points show the full distribution of the data.
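The pairwise panels in plot_mut_scatter_chart come from pivoting the long-form data to one column per cell, keeping only mutations measured in both cells of a pair (the isValid filter), and flooring the values (the Vega expression max(datum[...], min_value_slider)). A minimal pandas sketch of the same steps, on hypothetical data and without any plotting:

```python
import itertools

import pandas as pd

# Hypothetical long-form data: two mutations measured in three cells.
tidy = pd.DataFrame(
    {
        "site": ["5(E2)"] * 6,
        "mutant": ["A", "A", "A", "G", "G", "G"],
        "cell": ["C6/36", "293T-Mxra8", "293T-TIM1"] * 2,
        "effect": [-0.3, -2.0, -0.1, -7.0, -6.5, -0.4],
    }
)

# Pivot to one column per cell, as plot_mut_scatter_chart does before plotting.
wide = tidy.pivot_table(index=["site", "mutant"], columns="cell", values="effect").reset_index()

floor = -5  # plays the role of the slider value / floor_mut_effects
panels = {}
for cell_a, cell_b in itertools.combinations(tidy["cell"].unique(), 2):
    # dropna() mirrors the isValid filter; clip(lower=...) mirrors max(value, slider)
    panels[(cell_a, cell_b)] = wide[[cell_a, cell_b]].dropna().clip(lower=floor)

for pair, xy in panels.items():
    print(pair)
    print(xy, "\n")
```

With three cells, itertools.combinations yields three unordered pairs, which is why the chart has three panels.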

Identify sites where mutations have different effects in each cell

Compute site differences between conditions

We use three different site-level metrics for the differences between conditions:

  • mean difference: For each site, the mean over all non-wildtype amino acids of the effect on cell entry in cell_1 minus the effect in cell_2. Before taking this mean, all cell entry effects are floored at the value specified by floor_mut_effects.
  • Jensen-Shannon divergence: At each site, each amino acid measured in both cells is assigned a "probability" proportional to exp(effect), separately for each cell, and the Jensen-Shannon divergence between the probabilities for cell_1 and cell_2 is computed.
  • difference in constraint: Using the same probabilities (proportional to exp(effect)), the effective number of amino acids at each site is computed for each cell as the exponential of the Shannon entropy, and we report this number for cell_1 minus cell_2.
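For reference, here is our reading of the three metrics as implemented in get_site_diffs in the next cell. The symbols are introduced here for illustration and do not appear in the code: f is floor_mut_effects, x_a^(c) is the effect of amino acid a at the site in cell c, M is the set of non-wildtype amino acids with effects in both cells, and A is the set of all amino acids (wildtype included) with effects in both cells. Only the mean difference uses floored effects; the probabilities use the raw effects, and logarithms are natural (SciPy's jensenshannon returns the square root of the divergence, which the code squares).

```latex
\begin{aligned}
\text{mean difference} &= \frac{1}{|M|} \sum_{a \in M}
  \left[ \max\!\left(x_a^{(1)}, f\right) - \max\!\left(x_a^{(2)}, f\right) \right] \\
p_a^{(c)} &= \frac{\exp\!\left(x_a^{(c)}\right)}{\sum_{a' \in A} \exp\!\left(x_{a'}^{(c)}\right)},
  \qquad a \in A \\
\text{Jensen-Shannon divergence} &= \tfrac{1}{2} \sum_{a \in A} p_a^{(1)} \ln \frac{p_a^{(1)}}{m_a}
  + \tfrac{1}{2} \sum_{a \in A} p_a^{(2)} \ln \frac{p_a^{(2)}}{m_a},
  \qquad m_a = \tfrac{1}{2}\left(p_a^{(1)} + p_a^{(2)}\right) \\
N_{\text{eff}}^{(c)} &= \exp\!\left(-\sum_{a \in A} p_a^{(c)} \ln p_a^{(c)}\right) \\
\text{difference in constraint} &= N_{\text{eff}}^{(1)} - N_{\text{eff}}^{(2)}
\end{aligned}
```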
In [8]:
# first get color to use for each amino-acid in scatter plot
# this also defines list of amino acids to keep
aa_color_df = (
    pd.Series(dmslogo.colorschemes.AA_FUNCTIONAL_GROUP)
    .rename_axis("mutant")
    .rename("color")
    .reset_index()
)
aas = polyclonal.alphabets.biochem_order_aas(polyclonal.alphabets.AAS)
assert set(aa_color_df["mutant"]) == set(aas)

# get mutation level data, just for amino acids
assert set(cells) == set(mut_effects_tidy["cell"])
mut_data = (
    mut_effects_tidy
    .query("mutant in @aas")
    .pivot_table(
        index=["site", "sequential_site", "wildtype", "mutant", "region"],
        columns="cell",
        values="effect",
    )
    .sort_values("sequential_site")
    .reset_index()
)
assert set(mut_data["wildtype"]).issubset(aas)

# get site difference data
def get_site_diffs(df):
    is_wildtype = df.iloc[:, 0]
    s1 = df.iloc[:, 1]
    s2 = df.iloc[:, 2]
    # simple mean difference across non-wildtype amino acids
    mean_diff = (s1.clip(lower=floor_mut_effects) - s2.clip(lower=floor_mut_effects))[~is_wildtype].mean()
    # Jensen-Shannon divergence between probabilities proportional to exp(effect)
    p1 = numpy.exp(s1[s1.notnull() & s2.notnull()])
    p2 = numpy.exp(s2[s1.notnull() & s2.notnull()])
    assert len(p1) == len(p2)
    if len(p1):
        p1 /= p1.sum()
        p2 /= p2.sum()
        jsd = scipy.spatial.distance.jensenshannon(p1, p2)**2
    else:
        jsd = 0
    # difference in n_effective
    if len(p1) == 0:
        n_eff_diff = 0
    else:
        n_eff_1 = len(aas)**(-p1 * numpy.log(p1) / numpy.log(len(aas))).sum()
        n_eff_2 = len(aas)**(-p2 * numpy.log(p2) / numpy.log(len(aas))).sum()
        n_eff_diff = n_eff_1 - n_eff_2
    return pd.Series(
        {
            "mean difference": mean_diff,
            "Jensen-Shannon divergence": jsd,
            "difference in constraint": n_eff_diff,
        }
    )
    
site_diff_metrics = [
    "difference in constraint", "mean difference", "Jensen-Shannon divergence"
]
site_diffs = []
for cell_1, cell_2 in itertools.combinations(cells, 2):
    site_diffs.append(
        mut_data
        .assign(is_wildtype=lambda x: x["mutant"] == x["wildtype"])
        .groupby(["site", "sequential_site", "region"])
        [["is_wildtype", cell_1, cell_2]]
        .apply(get_site_diffs)
        .assign(cell_1=cell_1, cell_2=cell_2)
        .sort_values("sequential_site")
        .reset_index()
    )
site_diffs = pd.concat(site_diffs, ignore_index=True)
assert set(site_diff_metrics).issubset(site_diffs.columns)

print(f"For mean difference, effects floored at {floor_mut_effects=} first.")
print(f"Saving site differences to {site_diffs_csv=}")
site_diffs.to_csv(site_diffs_csv, index=False, float_format="%.3f")
site_diffs
For mean difference, effects floored at floor_mut_effects=-5 first.
Saving site differences to site_diffs_csv='results/compare_cell_entry/site_diffs.csv'
Out[8]:
         site  sequential_site  region  mean difference  Jensen-Shannon divergence  difference in constraint      cell_1     cell_2
0      -1(E3)                1      E3         0.000000               1.033940e-07                 -0.000222  293T-Mxra8      C6/36
1       1(E3)                2      E3         0.117874               1.124724e-02                  0.850198  293T-Mxra8      C6/36
2       2(E3)                3      E3        -0.069511               2.235952e-02                  0.322409  293T-Mxra8      C6/36
3       3(E3)                4      E3        -0.031671               5.401844e-03                  0.958338  293T-Mxra8      C6/36
4       4(E3)                5      E3        -0.092400               6.172539e-03                 -0.344348  293T-Mxra8      C6/36
...       ...              ...     ...              ...                        ...                       ...         ...        ...
2959  435(E1)              984      E1         0.220798               1.962642e-02                 -0.251640       C6/36  293T-TIM1
2960  436(E1)              985      E1         0.198397               1.349157e-02                  0.014254       C6/36  293T-TIM1
2961  437(E1)              986      E1         0.328327               1.683343e-02                  0.338031       C6/36  293T-TIM1
2962  438(E1)              987      E1         0.205080               3.944771e-02                  0.387563       C6/36  293T-TIM1
2963  439(E1)              988      E1         0.209541               1.865643e-02                  0.098284       C6/36  293T-TIM1

2964 rows × 8 columns
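As a sanity check on get_site_diffs, the sketch below uses hypothetical effects at one site to confirm two identities the code relies on: squaring SciPy's jensenshannon gives the Jensen-Shannon divergence (natural log), and len(aas) ** (entropy / log(len(aas))) equals exp(entropy), so the base cancels. The value n_aas = 20 is an assumption for illustration and does not affect the result.

```python
import numpy as np
import scipy.spatial.distance

# Hypothetical entry effects at one site for the amino acids measured in both cells.
effects_1 = np.array([0.0, -1.0, -3.0, 0.5])
effects_2 = np.array([0.0, -0.5, -6.0, -1.0])

# Probabilities proportional to exp(effect), normalized within each cell.
p1 = np.exp(effects_1) / np.exp(effects_1).sum()
p2 = np.exp(effects_2) / np.exp(effects_2).sum()

# Jensen-Shannon divergence written out explicitly (natural log).
m = 0.5 * (p1 + p2)
jsd_manual = 0.5 * np.sum(p1 * np.log(p1 / m)) + 0.5 * np.sum(p2 * np.log(p2 / m))

# SciPy returns the Jensen-Shannon *distance* (square root of the divergence), hence ** 2.
jsd_scipy = scipy.spatial.distance.jensenshannon(p1, p2) ** 2

# Effective number of amino acids, written the way get_site_diffs does and as exp(entropy).
n_aas = 20  # assumed alphabet size; the base cancels out
entropy_1 = -np.sum(p1 * np.log(p1))
n_eff_as_in_notebook = n_aas ** (entropy_1 / np.log(n_aas))
n_eff_exp = np.exp(entropy_1)

print(jsd_manual, jsd_scipy)
print(n_eff_as_in_notebook, n_eff_exp)
```

The effective number lies between 1 (all probability on one amino acid) and the number of measured amino acids (uniform probabilities), so a positive difference in constraint means the site is less constrained in cell_1 than in cell_2.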

Plot sites with large differences

We make an interactive plot that includes:

  • a line plot of site differences at the top left
  • a scatter plot of mutation effects at the selected site at the top right
  • heatmaps centered on the selected site at the bottom

Click a site in the site plot to show it in the mutation-level plots, drag over the zoom bar to zoom in on a range of sites, and use the controls at the bottom to adjust other options, including which cells to compare.

In [9]:
def plot_site_comparison(
    mut_data,
    site_diffs,
    cells,
    site_diff_metrics,
    aas,
    aa_color_df,
    init_floor_effect,
    heatmap_max_at_least=2,
    heatmap_flank=12,
):
    """Plot site-level differences in entry effects between cells, with zoomable mutation-level plots."""

    # some params
    site_chart_width = 700

    assert set(mut_data["site"]) == set(site_diffs["site"])
    assert set(site_diff_metrics).issubset(site_diffs.columns)

    # Drag to zoom into sites on the x-axis colored by region
    zoom_selection = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke='black', strokeWidth=2)
    )

    # zoom bar
    zoom_bar = (
        alt.Chart(mut_data[["site", "sequential_site", "region"]])
        .mark_rect()
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title="click and drag to zoom on sites",
                axis=alt.Axis(ticks=False, labels=False, titleFontWeight="normal"),
            ),
            alt.Color("region", scale=alt.Scale(scheme="greys"), legend=None),
            tooltip=["site", "sequential_site", "region"],
        )
        .properties(width=site_chart_width, height=10)
        .add_params(zoom_selection)
    )

    # line plot
    metric_selection = alt.selection_point(
        fields=["metric"],
        name="metric_selection",
        value=site_diff_metrics[1],
        bind=alt.binding_select(
            options=site_diff_metrics,
            name="metric for site differences between cells",
        ),
    )

    cell_1_options = [c for c in cells if c in set(site_diffs["cell_1"])]
    cell_1_selection = alt.param(
        name="cell_1",
        value=cell_1_options[0],
        bind=alt.binding_select(
            options=cell_1_options,
            name="comparator cell line",
        )
    )

    cell_2_options = [c for c in cells if c in set(site_diffs["cell_2"])]
    cell_2_selection = alt.param(
        name="cell_2",
        value=cell_2_options[0],
        bind=alt.binding_select(
            options=cell_2_options,
            name="reference cell line",
        )
    )

    # default site: largest absolute value of site_diff_metrics[0] for the default pair of cells
    default_site = (
        site_diffs[
            (site_diffs["cell_1"] == cell_1_options[0])
            & (site_diffs["cell_2"] == cell_2_options[0])
        ]
        .set_index("site")
        [site_diff_metrics[0]]
        .abs()
        .sort_values(ascending=False)
        .index[0]
    )
    default_sequential_site = site_diffs.set_index("site")["sequential_site"].to_dict()[default_site]

    site_selection = alt.selection_point(
        fields=["site"], empty=False, value=default_site, on="click"
    )
    sequential_site_selection = alt.selection_point(
        fields=["sequential_site"],
        empty=False,
        value=default_sequential_site,
        on="click",
    )
    
    site_base = (
        alt.Chart(site_diffs)
        .transform_filter(zoom_selection)
        .transform_filter(
            (alt.datum["cell_1"] == cell_1_selection)
            & (alt.datum["cell_2"] == cell_2_selection)
        )
        .transform_fold(
            site_diff_metrics,
            ["metric", "difference"],
        )
        .transform_filter(metric_selection)
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=alt.Axis(labelOverlap="greedy", ticks=False),
            ),
            alt.Y(
                "difference:Q",
                title="difference at site",
                scale=alt.Scale(nice=False, padding=9),
            ),
            tooltip=[
                "site", "sequential_site", "region", alt.Tooltip("difference:Q", format=".2f")
            ],
        )
    )
    
    site_lines = site_base.mark_line(color="black", strokeWidth=1, opacity=1)

    site_points = site_base.mark_circle(filled=True, fill="black", stroke="gold", opacity=1).encode(
        strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(0)),
        size=alt.condition(site_selection, alt.value(180), alt.value(60)),
    )

    # Dynamic title for chart plot
    site_title = alt.TitleParams(
        alt.expr(
            f'"difference between mutation effects in " + {cell_1_selection.name} + " versus " + {cell_2_selection.name} + " cells"'
        ),
        subtitle="click on a site to show in the mutation-level scatter plot and heatmaps",
        anchor="middle",
    )

    site_chart = (
        (site_lines + site_points)
        .properties(width=site_chart_width, height=185, title=site_title)
        .add_params(
            metric_selection, site_selection, sequential_site_selection, cell_1_selection, cell_2_selection,
        )
    )

    # amino-acid scatter plot for a single site
    min_effect = mut_data[list(cells)].min().min()
    max_effect = mut_data[list(cells)].max().max()
    min_effect_slider = alt.param(
        name="min_effect_slider",
        bind=alt.binding_range(
            min=min_effect, max=max_effect, name="floor displayed mutation effect at",
        ),
        value=max(init_floor_effect, min_effect) if init_floor_effect is not None else min_effect,
    )
    
    mut_base = alt.Chart(mut_data).add_params(min_effect_slider)

    mutant_selection = alt.selection_point(
        fields=["mutant", "site"], on="mouseover", empty=False
    )

    mut_scatter = (
        mut_base
        .transform_filter(site_selection)
        .transform_lookup(
            lookup='mutant',
            from_=alt.LookupData(data=aa_color_df, key='mutant', fields=['color']),
        )
        .transform_calculate(
            x=f"datum[{cell_1_selection.name}]",
            y=f"datum[{cell_2_selection.name}]",
            x_floored=f'isValid(datum.x) ? max(datum.x, {min_effect_slider.name}) : datum.x',
            y_floored=f'isValid(datum.y) ? max(datum.y, {min_effect_slider.name}) : datum.y',
        )
        .encode(
            alt.X("x_floored:Q", title="comparator cell line"),
            alt.Y("y_floored:Q", title="reference cell line"),
            alt.Text("mutant:N"),
            alt.Color("color:N", scale=None),
            size=alt.condition(mutant_selection, alt.value(22), alt.value(18)),
            strokeWidth=alt.condition(mutant_selection, alt.value(1), alt.value(0)),
            fillOpacity=alt.condition(mutant_selection, alt.value(1), alt.value(0.75)),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in cells]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=600)
        .add_params(cell_1_selection, cell_2_selection, mutant_selection)
        .properties(
            title=alt.TitleParams(
                alt.expr(f'"mutation effects at site " + {site_selection.name}.site')
            ),
            width=220,
            height=220,
        )
    )

    scatter_diagonal = (
        alt.Chart()
        .mark_rule(color="gray", strokeWidth=3, strokeDash=[6, 6], opacity=0.5)
        .transform_calculate(ax_lim=min_effect_slider.name)
        .encode(
            alt.X("ax_lim:Q", scale=alt.Scale(nice=False, padding=9, zero=False)),
            alt.Y("ax_lim:Q", scale=alt.Scale(nice=False, padding=9, zero=False)),
            x2=alt.datum(max_effect),
            y2=alt.datum(max_effect),
        )
    )

    scatter_chart = scatter_diagonal + mut_scatter

    # make the heatmaps
    assert all(mut_data["sequential_site"] == mut_data["sequential_site"].astype(int))
    assert all(site_diffs["sequential_site"] == site_diffs["sequential_site"].astype(int))
    
    heatmap_base = (
        mut_base
        .transform_filter(
            # keep sites fewer than heatmap_flank positions from the selected site
            # (sequential_site is an integer, as asserted above)
            f"abs(datum.sequential_site - {sequential_site_selection.name}.sequential_site) < {heatmap_flank}"
        )
        .encode(
            alt.X("site", sort=alt.SortField("sequential_site")),
            alt.Y("mutant", sort=aas),
        )
        .properties(width=alt.Step(12), height=alt.Step(12))
    )

    # gray background for missing values
    heatmap_bg = heatmap_base.transform_impute(
        impute="_stat_dummy",
        key="mutant",
        keyvals=aas,
        groupby=["site"],
        value=None,
    ).mark_rect(color="#E0E0E0", opacity=0.8)

    # mark X for wildtype
    heatmap_wildtype = (
        heatmap_base
        .transform_filter(alt.datum["wildtype"] == alt.datum["mutant"])
        .mark_text(text="x", color="black")
    )

    # make heatmap for each cell type
    heatmaps = []
    for cell in cells:
        first_cell = (cell == list(cells)[0])
        heatmap_muts = (
            heatmap_base
            .transform_calculate(
                effect_floored=f'isValid(datum["{cell}"]) ? max(datum["{cell}"], {min_effect_slider.name}) : datum["{cell}"]'
            )
            .encode(
                alt.Y("mutant", sort=aas, title="amino acid" if first_cell else None),
                alt.Color(
                    "effect_floored:Q",
                    title="mutation effect",
                    legend=alt.Legend(
                        orient="right", titleOrient="right", gradientStrokeColor="black", gradientStrokeWidth=1
                    ),
                    scale=alt.Scale(
                        scheme="redblue",
                        nice=False,
                        domainMid=0,
                        domainMax=max(mut_data[list(cells)].max().max(), heatmap_max_at_least),
                    ),
                ),
                strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(1)),
                tooltip=["site", "sequential_site", "wildtype", "mutant"] + [alt.Tooltip(c, format=".2f") for c in cells],
            )
            .mark_rect(stroke="black", opacity=1, strokeOpacity=1)
            .properties(title=f"{cell} effect")
        )
        heatmaps.append(heatmap_bg + heatmap_muts + heatmap_wildtype)

    heatmap = alt.hconcat(*heatmaps, spacing=7)

    # assemble the final chart
    chart = (
        alt.vconcat(
            alt.hconcat(alt.vconcat(site_chart, zoom_bar, spacing=4), scatter_chart),
            heatmap,
        )
        .configure_title(fontSize=18, subtitleFontSize=16)
        .configure_axis(grid=False, labelFontSize=11, titleFontSize=16)
        .configure_legend(labelFontSize=12, titleFontSize=16)
    )

    return chart
In [10]:
site_chart = plot_site_comparison(
    mut_data, site_diffs, cells, site_diff_metrics, aas, aa_color_df, floor_mut_effects, 2, 12
)

alt.renderers.set_embed_options(
    padding={"left": 5, "right": 5, "bottom": 5, "top": 5}
)

print(f"Saving to {site_zoom_chart=}")
site_chart.save(site_zoom_chart)
site_chart
Saving to site_zoom_chart='results/compare_cell_entry/compare_cell_entry_site_zoom.html'
Out[10]:
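plot_site_comparison preselects a default site (the one with the largest absolute value of a site metric for the default pair of cells) and shows heatmaps for a window of neighboring sites. The pandas sketch below reproduces that selection logic on hypothetical data; the site labels, values, and the small flank are made up for illustration.

```python
import pandas as pd

# Hypothetical site-level differences for a single pair of cells.
site_diffs_toy = pd.DataFrame(
    {
        "site": ["10(E2)", "11(E2)", "12(E2)", "13(E2)"],
        "sequential_site": [173, 174, 175, 176],
        "difference in constraint": [0.2, -1.4, 0.9, 0.1],
    }
)

# Default site: largest absolute value of the metric, as in plot_site_comparison.
default_site = (
    site_diffs_toy.set_index("site")["difference in constraint"]
    .abs()
    .sort_values(ascending=False)
    .index[0]
)
default_seq = site_diffs_toy.set_index("site").loc[default_site, "sequential_site"]

# Heatmap window: sites whose sequential number is close to the selected site,
# mirroring the Vega filter on abs(datum.sequential_site - selected sequential_site).
flank = 1  # small hypothetical flank for the toy data
window = site_diffs_toy[(site_diffs_toy["sequential_site"] - default_seq).abs() <= flank]
print(default_site, list(window["site"]))
```

Using sequential_site rather than the region-specific site labels keeps the window contiguous across region boundaries such as E3/E2.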

Make paper figure plots

Unlike the code above, these plots use some manually hardcoded values (for example, the pairs of cells compared and the axis tick values).

First, scatter plots comparing the mean effect of mutations at each site between pairs of cells:

In [11]:
fig_site_data = (
    mut_data
    .query("wildtype != mutant")
    .groupby(["wildtype", "site", "region"], as_index=False)
    .aggregate(
        **{
            cell: pd.NamedAgg(cell, lambda s: s.clip(lower=floor_mut_effects).mean())
            for cell in cells
        }
    )
)

fig_site_selection = alt.selection_point(fields=["site"], empty=False, on="mouseover")

fig_site_scatter_base = alt.Chart(fig_site_data).add_params(fig_site_selection)
fig_site_scatter_chart = []
cell_pairs = [('293T-TIM1', '293T-Mxra8'), ('293T-TIM1', 'C6/36'), ('C6/36', '293T-Mxra8')]
for cell1, cell2 in cell_pairs:
    fig_site_scatter_chart.append(
        fig_site_scatter_base
        .encode(
            alt.X(cell1, axis=alt.Axis(values=[0, -2, -4]), scale=alt.Scale(nice=False, padding=8)),
            alt.Y(cell2, axis=alt.Axis(values=[0, -2, -4]), scale=alt.Scale(nice=False, padding=8)),
            tooltip=["site", "wildtype"],
            fill=alt.condition(fig_site_selection, alt.value("red"), alt.value("gray")),
            fillOpacity=alt.condition(fig_site_selection, alt.value(1), alt.value(0.25)),
            size=alt.condition(fig_site_selection, alt.value(90), alt.value(35)),
        )
        .mark_circle(fill="gray", fillOpacity=0.25, strokeOpacity=0.7, stroke="black", strokeWidth=0.5)
        .properties(width=117, height=117)
    )
fig_site_scatter_chart = (
    alt.vconcat(*fig_site_scatter_chart, spacing=13)
    .properties(
        title=alt.TitleParams(
            ["average effect of", "mutations at each site"],
            anchor="middle",
            dx=13
        )
    )
)

fig_site_scatter_chart
Out[11]:
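The per-site values in fig_site_data exclude the wildtype amino acid and floor each effect at floor_mut_effects before averaging. A minimal sketch of that aggregation on a hypothetical three-row site:

```python
import pandas as pd

floor_mut_effects = -5

# Hypothetical mutation-level data in the wide layout of mut_data (one column per cell).
mut_data_toy = pd.DataFrame(
    {
        "site": ["7(E1)"] * 3,
        "region": ["E1"] * 3,
        "wildtype": ["K"] * 3,
        "mutant": ["K", "A", "P"],
        "C6/36": [0.0, -1.0, -8.0],
        "293T-TIM1": [0.0, -0.5, -2.0],
    }
)

site_means = (
    mut_data_toy
    .query("wildtype != mutant")  # drop the wildtype row (its effect is 0)
    .groupby(["wildtype", "site", "region"], as_index=False)
    .aggregate(
        **{
            # floor each effect before averaging, as in fig_site_data
            cell: pd.NamedAgg(cell, lambda s: s.clip(lower=floor_mut_effects).mean())
            for cell in ["C6/36", "293T-TIM1"]
        }
    )
)
print(site_means)
```

Here the -8.0 effect in C6/36 is floored to -5 before averaging, so a single very deleterious mutation cannot dominate the site mean.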

Combine the scatter plots with a plot of the mean difference in mutation effects at each site:

In [12]:
fig_site_diffs = (
    site_diffs
    .assign(comparison=lambda x: x["cell_1"] + " minus " + x["cell_2"])
    [["site", "sequential_site", "region", "comparison", "mean difference"]]
)

comparisons = [f"{cell2} minus {cell1}" for (cell1, cell2) in cell_pairs]
assert set(comparisons) == set(fig_site_diffs["comparison"])

fig_site_width = 650

# line chart
fig_site_diffs_chart = (
    alt.Chart(fig_site_diffs)
    .add_params(fig_site_selection)
    .encode(
        alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(
                values=fig_site_diffs[["sequential_site", "site"]].drop_duplicates().sort_values("sequential_site")["site"].iloc[30::80],
                labelAngle=0,
            ),
        ),
        alt.Y("mean difference", title=None, scale=alt.Scale(nice=False, padding=4)),
        alt.Row(
            "comparison",
            title=None,
            header=alt.Header(labelFontSize=12, labelPadding=4),
            spacing=10,
            sort=comparisons,
        ),
        color=alt.condition(fig_site_selection, alt.value("red"), alt.value("black")),
        tooltip=["site", alt.Tooltip("mean difference", format=".2f", title="difference")],
    )
    .mark_bar(width=2, opacity=1, strokeWidth=0)
    .properties(height=143, width=fig_site_width)
)

# region overlay for line chart
region_chart = (
    alt.Chart(fig_site_diffs[["sequential_site", "region"]].drop_duplicates())
    .encode(
        alt.X("sequential_site:O", axis=None),
        alt.Color(
            "region",
            legend=None,
            scale=alt.Scale(range=["AliceBlue", "CadetBlue", "CadetBlue", "AliceBlue"])
        ),
    )
    .mark_rect(opacity=0.75, strokeWidth=0)
    .properties(width=fig_site_width)
)

text_df = fig_site_diffs.groupby("region", as_index=False).aggregate(x=pd.NamedAgg("sequential_site", "mean"))
text_chart = (
    alt.Chart(text_df)
    .encode(
        alt.X(
            "x:Q",
            title=None,
            scale=alt.Scale(domain=(fig_site_diffs["sequential_site"].min(), fig_site_diffs["sequential_site"].max())),
            axis=None,
        ),
        alt.Text("region"),
    )
    .mark_text(fontWeight="bold", fontSize=13)
    .properties(width=fig_site_width, height=15)
)

overlay_chart = region_chart + text_chart

fig_site_line_chart = (
    alt.vconcat(overlay_chart, fig_site_diffs_chart, spacing=0)
    .properties(
        title=alt.TitleParams("average difference in mutation effects on cell entry at each site", anchor="middle"),
    )
)

fig_site_chart = alt.hconcat(
    fig_site_scatter_chart,
    fig_site_line_chart,
    spacing=53,
    center=True,
)

fig_site_chart.configure_axis(grid=False, titleFontSize=12, titleFontWeight="normal").configure_view(stroke="black")
Out[12]:

Compare the per-site difference in mutation effects between 293T-Mxra8 and 293T-TIM1 cells (293T-Mxra8 minus 293T-TIM1) with each site's distance from Mxra8 in the structure:

In [13]:
mxra8_dists = (
    pd.read_csv(mxra8_dists_csv)
    .assign(
        site=lambda x: x["site"].astype(str) + "(" + x["region"] + ")",
        PDB=lambda x: "distance in PDB " + x["PDB"],
    )
    [["PDB", "region", "site", "distance_to_Mxra8"]]
    .merge(
        (
            fig_site_diffs
            .query("comparison == '293T-Mxra8 minus 293T-TIM1'")
            [["site", "mean difference"]]
        ),
        on="site",
        validate="m:1",
    )
)

mxra8_dists_chart = (
    alt.Chart(mxra8_dists)
    .add_params(fig_site_selection)
    .encode(
        alt.X("mean difference", title=["293T-Mxra8 minus", "293T-TIM1"], scale=alt.Scale(nice=False, padding=5)),
        alt.Y("distance_to_Mxra8", title=None, scale=alt.Scale(nice=False, padding=5)),
        alt.Row("PDB", title=None, header=alt.Header(labelFontSize=12, labelPadding=1)),
        fill=alt.condition(fig_site_selection, alt.value("red"), alt.value("gray")),
        fillOpacity=alt.condition(fig_site_selection, alt.value(1), alt.value(0.25)),
        size=alt.condition(fig_site_selection, alt.value(90), alt.value(35)),
        tooltip=["site"],
    )
    .mark_circle(fill="gray", fillOpacity=0.25, strokeOpacity=0.7, stroke="black", strokeWidth=0.5)
    .properties(
        height=126,
        width=126,
        title=alt.TitleParams(["difference in effects", "vs distance to Mxra8"], anchor="middle", dx=10),
    )
)

mxra8_dists_chart.configure_axis(
    grid=False, titleFontWeight="normal", titleFontSize=12
)
Out[13]:
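The chart shows this relationship visually. To put a number on it, one option is a rank (Spearman) correlation between `mean difference` and `distance_to_Mxra8`, computed separately for each `PDB` facet. Below is a minimal sketch. It assumes a DataFrame with the same columns as `mxra8_dists`. The helper name `spearman_by_group`, the toy values, and the PDB labels are hypothetical placeholders, not data from this analysis:

```python
import pandas as pd


def spearman_by_group(df, group_col, x_col, y_col):
    """Spearman correlation of `x_col` vs `y_col` within each level of `group_col`.

    Spearman's rho is the Pearson correlation of the ranks, so ranking first
    (ties get average ranks) avoids needing scipy.
    """
    return pd.Series(
        {
            name: g[x_col].rank().corr(g[y_col].rank())
            for name, g in df.groupby(group_col)
        },
        name="spearman_r",
    )


# hypothetical toy data with the same columns as `mxra8_dists`
toy = pd.DataFrame(
    {
        "PDB": ["distance in PDB A"] * 4 + ["distance in PDB B"] * 4,
        "mean difference": [2.0, 1.0, 0.5, 0.1] * 2,
        "distance_to_Mxra8": [3.0, 8.0, 15.0, 30.0, 30.0, 15.0, 8.0, 3.0],
    }
)

rho = spearman_by_group(toy, "PDB", "mean difference", "distance_to_Mxra8")
print(rho)
```

On the real data, the call would be `spearman_by_group(mxra8_dists, "PDB", "mean difference", "distance_to_Mxra8")`. A negative value would mean that sites closer to Mxra8 tend to have larger 293T-Mxra8 minus 293T-TIM1 differences.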

Now define a function that plots the mutation effects at a single site in one cell type against another:

In [14]:
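def plot_site_scatter(
    site, cell_1, cell_2, no_y_axis=False, ax_max=1.5, bold_letters=None, no_title=False, no_x_ticks=False
):
    """Scatter plot of mutation effects at `site` in `cell_1` (x) versus `cell_2` (y).

    Effects are floored at `floor_mut_effects`. If `ax_max` is None, the upper axis
    limit is set from the maximum effect across all sites. If `bold_letters` is given,
    those amino acids are highlighted and all other amino acids are drawn faintly.
    """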
    assert site in set(mut_data["site"]), site
    mut_df = (
        mut_data
        [mut_data["site"] == site]
        [["wildtype", "mutant", cell_1, cell_2]]
        .copy()  # explicit copy so the clipping below doesn't raise SettingWithCopyWarning
    )
    for cell in [cell_1, cell_2]:
        mut_df[cell] = mut_df[cell].clip(lower=floor_mut_effects)

    ax_min = floor_mut_effects - 0.5
    if ax_max is None:
        ax_max = mut_data[[cell_1, cell_2]].max().max() + 0.5

    if bold_letters is None:
        mut_df = mut_df.merge(aa_color_df, on="mutant", validate="one_to_one").assign(
            opacity=0.75, strokeWidth=0.4, size=14
        )
    else:
        mut_df["color"] = mut_df.apply(
            lambda r: (
                ("red" if r["mutant"] != r["wildtype"] else "black") if r["mutant"] in bold_letters else "darkblue"
            ),
            axis=1,
        )
        mut_df["opacity"] = mut_df["mutant"].map(lambda a: 1 if a in bold_letters else 0.25)
        mut_df["strokeWidth"] = mut_df["mutant"].map(lambda a: 0.5 if a in bold_letters else 0)
        mut_df["size"] = mut_df["mutant"].map(lambda a: 15 if a in bold_letters else 12)  
    
    mut_scatter = (
        alt.Chart(mut_df)
        .encode(
            alt.X(
                cell_1,
                scale=alt.Scale(domain=(ax_min, ax_max), nice=False),
                axis=alt.Axis(ticks=False, labels=False) if no_x_ticks else alt.Axis(),
            ),
            alt.Y(
                cell_2,
                title=None if no_y_axis else cell_2,
                scale=alt.Scale(domain=(ax_min, ax_max), nice=False),
                axis=alt.Axis(ticks=False, labels=False) if no_y_axis else alt.Axis()
            ),
            alt.Text("mutant"),
            alt.Color("color:N", scale=None),
            alt.FillOpacity("opacity", scale=None),
            alt.StrokeWidth("strokeWidth", scale=None),
            alt.Size("size", scale=None),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in [cell_1, cell_2]]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=700)
        .properties(
            title="" if no_title else alt.TitleParams(f"site {site}", dy=7),
            width=92,
            height=92,
        )
    )

    scatter_diagonal = (
        alt.Chart()
        .mark_rule(color="gray", strokeWidth=3, strokeDash=[6, 6], opacity=0.4)
        .encode(
            x=alt.datum(ax_min),
            y=alt.datum(ax_min),
            x2=alt.datum(ax_max),
            y2=alt.datum(ax_max),
        )
    )

    return scatter_diagonal + mut_scatter
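When `bold_letters` is given, each amino acid at the site gets a color, opacity, stroke width, and font size from a simple rule. The sketch below pulls that same rule out as a plain function so it can be checked without rendering a chart. The name `highlight_style` is hypothetical and is not part of the notebook:

```python
def highlight_style(mutant, wildtype, bold_letters):
    """Return (color, opacity, strokeWidth, size) following plot_site_scatter.

    Letters in `bold_letters` are drawn opaque: black for the wildtype and
    red for any other letter. All remaining letters are faint dark blue.
    """
    if mutant in bold_letters:
        color = "black" if mutant == wildtype else "red"
        return color, 1, 0.5, 15
    return "darkblue", 0.25, 0, 12


# hypothetical site whose wildtype is Q, highlighting Q, T, and V
bold = ["Q", "T", "V"]
for aa in ["Q", "T", "A"]:
    print(aa, highlight_style(aa, "Q", bold))
```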

Plot the mutation effects at key sites for each pair of cell types, and combine these scatter plots with the site-level charts above into one figure:

In [15]:
key_sites = ['71(E2)', '119(E2)', '120(E2)', '121(E2)', '157(E2)', '158(E2)', '272(E2)']

top_diff_sites_scatter_chart = []
no_title = False
for cell1, cell2 in cell_pairs:
    top_diff_sites_scatter_chart.append(
        alt.hconcat(
            *[
                plot_site_scatter(
                    s, cell1, cell2, no_y_axis=(s != key_sites[0]), no_title=no_title
                )
                for s in key_sites
            ],
            spacing=-2,
        )
    )
    no_title = True
    
top_diff_sites_scatter_chart = alt.vconcat(
    *top_diff_sites_scatter_chart, spacing=5
)

(
    alt.vconcat(
        fig_site_chart,
        alt.hconcat(mxra8_dists_chart, top_diff_sites_scatter_chart, spacing=30),
        spacing=9,
    )
    .configure_axis(grid=False, titleFontWeight="normal", titleFontSize=12, titlePadding=2)
    .configure_view(stroke="black")
)
Out[15]:
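Here `key_sites` is a hand-picked list. The sketch below shows one way to shortlist candidate sites from a per-site summary, though it is not necessarily how these sites were chosen. Sites are ranked by their largest absolute `mean difference` across comparisons. This assumes a long-format table with `site`, `comparison`, and `mean difference` columns, like `fig_site_diffs`. The helper name, the toy rows, and the comparison labels are hypothetical:

```python
import pandas as pd


def top_sites_by_abs_difference(site_diffs, n=3):
    """Rank sites by their largest |mean difference| across all comparisons."""
    return (
        site_diffs.assign(abs_diff=lambda x: x["mean difference"].abs())
        .groupby("site", as_index=False)["abs_diff"]
        .max()
        .sort_values("abs_diff", ascending=False)
        .head(n)
        .reset_index(drop=True)
    )


# hypothetical toy data in the same long format as `fig_site_diffs`
toy = pd.DataFrame(
    {
        "site": ["1(E2)", "2(E2)", "3(E2)", "1(E2)", "2(E2)", "3(E2)"],
        "comparison": ["A minus B"] * 3 + ["A minus C"] * 3,
        "mean difference": [0.2, -1.5, 0.4, 0.9, 0.3, -0.1],
    }
)
top = top_sites_by_abs_difference(toy, n=2)
print(top)
```

Taking the absolute value keeps sites where mutations are better tolerated in either cell type of a pair, because the sign of the difference only reflects which cell type is subtracted.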

Make a figure showing the mutations selected for experimental validation, with the wildtype amino acid in black, the mutants that were made in red, and all other amino acids in faint blue:

In [16]:
# experimentally validated sites, mapped to the amino acids to highlight (wildtype and the mutants made)
validation_sites = {
    '119(E2)': ["R", "K"],
    '120(E2)': ["K", "D"],
    '121(E2)': ["I", "E"],
    '157(E2)': ["A", "S"],
    '158(E2)': ["Q", "T", "V"],
}

validation_scatter_chart = []
for cell1, cell2 in itertools.combinations(reversed(cells), 2):
    # one shared upper axis limit for all validated sites in this pair of cells
    validation_ax_max = mut_data.query("site in @validation_sites")[[cell1, cell2]].max().max() + 0.6
    validation_scatter_chart.append(
        alt.hconcat(
            *[
                plot_site_scatter(
                    s,
                    cell1,
                    cell2,
                    no_y_axis=(s != list(validation_sites)[0]),
                    ax_max=validation_ax_max,
                    bold_letters=letters,
                )
                for s, letters in validation_sites.items()
            ],
            spacing=2,
        )
    )
validation_scatter_chart = (
    alt.vconcat(*validation_scatter_chart, spacing=15)
    .configure_axis(grid=False, titleFontWeight="normal", titleFontSize=12)
    .configure_view(stroke="black")
)
validation_scatter_chart
Out[16]:
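The rows of this figure come from `itertools.combinations(reversed(cells), 2)`. Iterating over a dict yields its keys in insertion order, and `reversed()` walks them from last to first. With the default `cells` parameter at the top of the notebook (papermill may override it), the three rows, as (x-axis, y-axis) pairs, are:

```python
import itertools

# default `cells` from the parameters cell of this notebook
cells = {"C6/36": "C636", "293T-Mxra8": "293T_Mxra8", "293T-TIM1": "293T_TIM1"}

# reversed() on a dict iterates its keys from last-inserted to first (Python 3.8+)
pairs = list(itertools.combinations(reversed(cells), 2))
for x_cell, y_cell in pairs:
    print(f"x: {x_cell:<11} y: {y_cell}")
```

Because the pairs follow the dict order, reordering `cells` changes which cell type appears on each axis.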