Compare Cell Entry Effects¶
In this notebook, we'll investigate whether the same mutation affects cell entry differently across three cell types (293T-Mxra8, 293T-TIM1, and C6/36).
Although Mxra8 serves as a receptor and TIM1 as an entry factor for CHIKV in humans, the mosquito receptor remains unknown. By identifying sites where mutations affect cell entry differently in mosquito cells (C6/36) than in human cells (293T-Mxra8 and 293T-TIM1), we may uncover sites involved in binding to the unknown mosquito receptor.
import itertools
import os
import altair as alt
import dmslogo.colorschemes
import numpy
import pandas as pd
import polyclonal.alphabets
import scipy.spatial.distance
# Remove Altair's limit of ~5000 rows on the data embedded in a chart -- maybe
# there are better ways? (https://altair-viz.github.io/user_guide/large_datasets.html)
# The return value is bound to `_`, presumably so the notebook cell displays nothing.
_ = alt.data_transformers.disable_max_rows()
/fh/fast/bloom_j/computational_notebooks/jbloom/2025/CHIKV-181-25-E-DMS/.snakemake/conda/e937d0191e9e0d4fae02985bc3f90aba_/lib/python3.12/site-packages/dmslogo/logo.py:27: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81. import pkg_resources
Get the input parameters¶
The notebook is designed to be parameterized by papermill.
The next cell is tagged parameters:
# This cell is tagged parameters, so the values defined here will be overwritten
# by the `papermill` parameterization.

# CSV with filtered data
mut_effects_csv = "../results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"
addtl_site_annotations_csv = "../data/addtl_site_annotations.csv"
mxra8_dists_csv = "../results/mxra8_distances/mxra8_dists.csv"

# cells and their names in input file (display name -> label used in the
# "entry in <label> cells" column names).
# NOTE: key order matters. Pairwise comparisons are generated with
# `itertools.combinations(cells, 2)` as "<cell_1> minus <cell_2>", and the
# paper-figure code asserts these match its hardcoded cell pairs
# ("293T-Mxra8 minus 293T-TIM1", "C6/36 minus 293T-TIM1",
# "293T-Mxra8 minus C6/36"), so keep 293T-Mxra8 before C6/36 before 293T-TIM1.
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636", "293T-TIM1": "293T_TIM1"}

# for calculating differences and display, floor mutation effects at this
floor_mut_effects = -5

# output files
site_diffs_csv = "../results/compare_cell_entry/site_diffs.csv"
mut_scatter_chart = "../results/compare_cell_entry/compare_cell_entry_scatter.html"
site_zoom_chart = "../results/compare_cell_entry/compare_cell_entry_site_zoom.html"
# Parameters
# These assignments re-bind the names set in the cell tagged `parameters`, so the
# values below are the ones actually used by the rest of the notebook
# (presumably injected by `papermill` -- paths here are relative to the
# directory the pipeline runs from rather than to the notebook's directory).
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636", "293T-TIM1": "293T_TIM1"}
floor_mut_effects = -5
mut_effects_csv = "results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"
addtl_site_annotations_csv = "data/addtl_site_annotations.csv"
mxra8_dists_csv = "results/mxra8_distances/mxra8_dists.csv"
site_diffs_csv = "results/compare_cell_entry/site_diffs.csv"
mut_scatter_chart = "results/compare_cell_entry/compare_cell_entry_scatter.html"
site_zoom_chart = "results/compare_cell_entry/compare_cell_entry_site_zoom.html"
Read the data¶
For this analysis, we'll need the effects of mutations on cell entry in each cell line.
These are pre-filtered (for QC metrics) values:
# Load the pre-filtered mutation effects; the table is wide, with one
# "entry in <label> cells" column per cell line.
print(f"Reading mutation effects from {mut_effects_csv=}")
mut_effects = pd.read_csv(mut_effects_csv)
# Bare expression so the notebook renders the data frame.
mut_effects
Reading mutation effects from mut_effects_csv='results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv'
| site | wildtype | mutant | entry in 293T_Mxra8 cells | entry in C636 cells | entry in 293T_TIM1 cells | binding to mouse Mxra8 | sequential_site | region | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | -1(E3) | M | I | -7.5430 | -7.5110 | -7.50300 | NaN | 1 | E3 |
| 1 | -1(E3) | M | M | 0.0000 | 0.0000 | 0.00000 | 0.00000 | 1 | E3 |
| 2 | -1(E3) | M | T | -7.5640 | -7.5410 | -7.57800 | NaN | 1 | E3 |
| 3 | 1(6K) | A | A | 0.0000 | 0.0000 | 0.00000 | 0.00000 | 489 | 6K |
| 4 | 1(6K) | A | C | 0.1782 | 0.0333 | 0.02943 | -0.03924 | 489 | 6K |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 19457 | 99(E2) | H | S | -7.2680 | -7.1330 | -6.61400 | NaN | 164 | E2 |
| 19458 | 99(E2) | H | T | -7.4900 | -6.8320 | -6.99300 | NaN | 164 | E2 |
| 19459 | 99(E2) | H | V | -7.5340 | -7.4920 | -7.41100 | NaN | 164 | E2 |
| 19460 | 99(E2) | H | W | -7.0080 | -6.4290 | -5.61500 | NaN | 164 | E2 |
| 19461 | 99(E2) | H | Y | -3.1700 | -1.4590 | -1.22700 | -0.42020 | 164 | E2 |
19462 rows × 9 columns
Get the data in tidy format:
# Map each wide "entry in <label> cells" column onto the display name of its
# cell line (the keys of `cells`), preserving the order of `cells`.
effect_columns = [f"entry in {label} cells" for label in cells.values()]
col_to_cell = dict(zip(effect_columns, cells))
assert set(col_to_cell).issubset(mut_effects.columns), f"{col_to_cell=}, {mut_effects.columns=}"

# Reshape wide -> long: one row per (mutation, cell) holding that cell's effect.
_renamed_effects = mut_effects.rename(columns=col_to_cell)
_long_effects = _renamed_effects.melt(
    id_vars=["site", "sequential_site", "wildtype", "mutant", "region"],
    value_vars=col_to_cell.values(),
    var_name="cell",
    value_name="effect",
)
mut_effects_tidy = _long_effects.sort_values("sequential_site")
mut_effects_tidy
| site | sequential_site | wildtype | mutant | region | cell | effect | |
|---|---|---|---|---|---|---|---|
| 0 | -1(E3) | 1 | M | I | E3 | 293T-Mxra8 | -7.5430 |
| 1 | -1(E3) | 1 | M | M | E3 | 293T-Mxra8 | 0.0000 |
| 2 | -1(E3) | 1 | M | T | E3 | 293T-Mxra8 | -7.5640 |
| 19462 | -1(E3) | 1 | M | I | E3 | C6/36 | -7.5110 |
| 19463 | -1(E3) | 1 | M | M | E3 | C6/36 | 0.0000 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 35513 | 439(E1) | 988 | H | R | E1 | C6/36 | 0.3585 |
| 54975 | 439(E1) | 988 | H | R | E1 | 293T-TIM1 | 0.1848 |
| 54976 | 439(E1) | 988 | H | S | E1 | 293T-TIM1 | -0.8648 |
| 54966 | 439(E1) | 988 | H | G | E1 | 293T-TIM1 | -0.5203 |
| 54967 | 439(E1) | 988 | H | H | E1 | 293T-TIM1 | 0.0000 |
58386 rows × 7 columns
Scatter plots of cell entry for each cell¶
How does the same mutation affect entry in each cell line? We'll plot the effect of each mutation between pairs of cell lines to determine if there are global differences.
def plot_mut_scatter_chart(
    data,
    condition,
    value,
    groupby=("site", "mutant", "wildtype", "sequential_site"),
    color=None,
    label_suffix="",
    init_floor_value=-6,
):
    """
    Make an Altair scatter plot comparing mutant-level values for each condition.

    One scatter panel is drawn for every pair of distinct values found in the
    `condition` column, and the panels are concatenated horizontally.

    Parameters
    ----------
    data : pd.DataFrame
        The long-form data to plot
    condition : str
        The column containing the condition labels (i.e. TIM1, MXRA8, C636)
    value : str
        The column containing the values to compare between conditions
    groupby : sequence of str
        The columns to group the data on (i.e. ['site', 'mutant', 'wildtype']).
        Must contain 'site' and 'mutant'. The sequence is copied, so neither the
        default nor a list passed by the caller is modified.
    color : str or None
        The column to color the points and add an interactive legend for
    label_suffix : str
        Label suffixed to x- and y-axis labels.
    init_floor_value : float or None
        Initial value for floor slider for values. If None, the slider starts at
        the minimum of the `value` column.

    Returns
    -------
    alt.Chart
        The Altair chart object (horizontally concatenated scatter panels)

    Raises
    ------
    ValueError
        If `groupby` lacks 'site' or 'mutant', or if a required column
        (including `color`) is missing from `data`.
    """
    # Private copy: the colour column is appended below, and doing that on a
    # shared default (or on the caller's list) would leak state between calls.
    groupby = list(groupby)

    if "mutant" not in groupby or "site" not in groupby:
        raise ValueError("groupby must contain 'mutant' and 'site'")
    missing_cols = [col for col in [condition, value] + groupby if col not in data.columns]
    if missing_cols:
        raise ValueError(f"Columns are missing from the data: {missing_cols}")
    if color is not None:
        if color not in data.columns:
            raise ValueError(f"Color column '{color}' not found in data")
        # Keep the colour column through the pivot; avoid a duplicate index level.
        if color not in groupby:
            groupby.append(color)

    conditions = data[condition].unique()

    # pivot the data: one row per group, one column per condition
    data_wide = (
        data
        .pivot_table(index=groupby, columns=condition, values=value)
        .reset_index()
    )

    tooltips = [alt.Tooltip(f"{col}:N") for col in groupby]
    tooltips += [alt.Tooltip(f"{col}:Q", format=".2f") for col in conditions]

    brush = alt.selection_interval()
    mut_selection = alt.selection_point(on="mouseover", fields=groupby, empty=False)

    # Use the NaN-skipping pandas reductions: the built-in min()/max() give
    # order-dependent results when the column contains NaN.
    value_min = float(data[value].min())
    value_max = float(data[value].max())
    min_value_slider = alt.param(
        name="min_value_slider",
        bind=alt.binding_range(
            min=value_min,
            max=value_max,
            name="floor values at this number",
        ),
        value=(
            max(init_floor_value, value_min)
            if init_floor_value is not None
            else value_min
        ),
    )

    base = (
        alt.Chart(data_wide)
        .add_params(mut_selection, brush, min_value_slider)
        .transform_filter(brush)
    )

    scatters = []
    for condition_a, condition_b in itertools.combinations(conditions, 2):
        # Base data for the scatter plot: only groups measured in both
        # conditions, with both values floored at the slider value.
        scatter = base.transform_filter(
            f'isValid(datum["{condition_a}"]) && isValid(datum["{condition_b}"])'
        ).transform_calculate(
            condition_a_floored=f'max(datum["{condition_a}"], min_value_slider)',
            condition_b_floored=f'max(datum["{condition_b}"], min_value_slider)',
        ).encode(
            x=alt.X(
                "condition_a_floored:Q",
                title=condition_a + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
            y=alt.Y(
                "condition_b_floored:Q",
                title=condition_b + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
        ).properties(
            title=alt.TitleParams(f'{condition_a} vs {condition_b}', fontSize=16),
            width=250,
            height=250
        )

        # Background points to show the full range of data when brushing
        background = scatter.mark_point(
            filled=True,
            size=25,
            color='lightgray',
            opacity=0.3,
        )

        # Foreground points have tooltips and respond to brushing (and legend selection)
        if color is not None:
            selection = alt.selection_point(fields=[color], bind='legend')
            foreground = scatter.mark_point(
                filled=True,
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                color=alt.Color(color, type='nominal').scale(domain=data[color].unique()),
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(80), alt.value(40)),
                tooltip=tooltips,
            ).add_params(
                selection
            ).transform_filter(selection)
        else:
            foreground = scatter.mark_point(
                filled=True,
                color='steelblue',
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                tooltip=tooltips,
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
            )

        scatters.append((background + foreground))

    chart = alt.hconcat(*scatters).configure_axis(grid=False).configure_legend(
        titleFontSize=14, labelFontSize=14
    )
    return chart
# Pairwise scatter plots of mutation effects between cell lines, coloured by region.
mut_scatter = plot_mut_scatter_chart(
    mut_effects_tidy,
    "cell",
    "effect",
    color="region",
    label_suffix=" cell entry",
    init_floor_value=floor_mut_effects,
)

print(f"Saving chart to {mut_scatter_chart=}")
# `os.makedirs("")` raises, so only create the directory when the output path
# actually has a directory component.
mut_scatter_dir = os.path.dirname(mut_scatter_chart)
if mut_scatter_dir:
    os.makedirs(mut_scatter_dir, exist_ok=True)
mut_scatter.save(mut_scatter_chart)
# Bare expression so the notebook renders the chart.
mut_scatter
Saving chart to mut_scatter_chart='results/compare_cell_entry/compare_cell_entry_scatter.html'
- Mouseover on points to see a tooltip with information about that mutation.
- Hold Click and Drag over points to show only those mutations.
- Click on conditions in the legend to show only that condition (`region`).
- Use the slider to floor values at some minimum plot value.
- Double Click on the plot or legend to reset the plot.
Points with color show the active selection and gray points show total distribution of the data.
Identify sites where mutations have different effects in each cell¶
Compute site differences between conditions¶
We use three different site-level metrics for the differences between conditions:
- mean difference: The mean difference in effect on cell entry for all non-wildtype amino acids at each site in cell_1 minus cell_2. We compute this mean after flooring all cell entry effects at the value specified by `floor_mut_effects`.
- Jensen-Shannon divergence: A "probability" is assigned to each amino acid at each site as proportional to `exp(effect)`, and then the Jensen-Shannon divergence is computed for the probabilities for cell_1 versus cell_2.
- difference in constraint: A "probability" is assigned to each amino acid as proportional to `exp(effect)`, and then the number of effective amino acids at each site is computed for each cell, and we report the number for cell_1 minus cell_2.
# Colour used for each amino-acid letter in the scatter plots. The letters in
# this table must be exactly the amino acids kept for the analysis (`aas`).
_aa_colors = pd.Series(dmslogo.colorschemes.AA_FUNCTIONAL_GROUP, name="color")
_aa_colors.index.name = "mutant"
aa_color_df = _aa_colors.reset_index()

aas = polyclonal.alphabets.biochem_order_aas(polyclonal.alphabets.AAS)
assert set(aa_color_df["mutant"]) == set(aas)

# Mutation-level data in wide form (one column per cell), restricted to rows
# whose mutant is one of the amino acids in `aas`.
assert set(cells) == set(mut_effects_tidy["cell"])
_aa_effects = mut_effects_tidy.query("mutant in @aas")
_wide_effects = pd.pivot_table(
    _aa_effects,
    index=["site", "sequential_site", "wildtype", "mutant", "region"],
    columns="cell",
    values="effect",
)
_wide_effects = _wide_effects.sort_values("sequential_site")
mut_data = _wide_effects.reset_index()
assert set(mut_data["wildtype"]) <= set(aas)
# get site difference data
def get_site_diffs(df):
    """Summarize how mutation effects at one site differ between two cells.

    Parameters
    ----------
    df : pd.DataFrame
        Rows are the amino acids at a single site. Columns are read by
        position: column 0 is a boolean flag marking the wildtype amino acid,
        column 1 holds the effects in the first cell and column 2 the effects
        in the second cell.

    Returns
    -------
    pd.Series
        "mean difference": mean over non-wildtype rows of (first - second)
        after flooring both at the module-level `floor_mut_effects`.
        "Jensen-Shannon divergence": squared Jensen-Shannon distance between
        the distributions proportional to exp(effect), using only rows
        measured in both cells (0 if there are none).
        "difference in constraint": effective number of amino acids (len(aas)
        raised to the entropy taken in base len(aas)) of the first
        distribution minus that of the second (0 if there are no shared rows).
    """
    wildtype_flag, effects_1, effects_2 = (df.iloc[:, i] for i in range(3))

    # simple mean difference across non-wildtype amino acids
    floored_1 = effects_1.clip(lower=floor_mut_effects)
    floored_2 = effects_2.clip(lower=floor_mut_effects)
    mean_diff = (floored_1 - floored_2)[~wildtype_flag].mean()

    # un-normalized "probabilities" for amino acids measured in both cells
    measured_in_both = effects_1.notnull() & effects_2.notnull()
    p1 = numpy.exp(effects_1[measured_in_both])
    p2 = numpy.exp(effects_2[measured_in_both])
    assert len(p1) == len(p2)

    if len(p1) == 0:
        # nothing measured in both cells: report no divergence / no difference
        jsd = 0
        n_eff_diff = 0
    else:
        p1 = p1 / p1.sum()
        p2 = p2 / p2.sum()

        # relative entropy (scipy returns the distance, i.e. the square root)
        jsd = scipy.spatial.distance.jensenshannon(p1, p2) ** 2

        # difference in the effective number of amino acids
        n_aas = len(aas)

        def n_effective(p):
            return n_aas ** (-p * numpy.log(p) / numpy.log(n_aas)).sum()

        n_eff_diff = n_effective(p1) - n_effective(p2)

    return pd.Series(
        {
            "mean difference": mean_diff,
            "Jensen-Shannon divergence": jsd,
            "difference in constraint": n_eff_diff,
        }
    )
# Names of the site-level metrics; each must be a key returned by `get_site_diffs`.
site_diff_metrics = [
    "difference in constraint", "mean difference", "Jensen-Shannon divergence"
]

# The wildtype flag does not depend on the cell pair, so compute it once.
mut_data_flagged = mut_data.assign(is_wildtype=lambda x: x["mutant"] == x["wildtype"])

# One data frame of per-site metrics for every unordered pair of cells, in the
# order produced by iterating over `cells`.
site_diffs = []
for cell_1, cell_2 in itertools.combinations(cells, 2):
    site_diffs.append(
        mut_data_flagged
        .groupby(["site", "sequential_site", "region"])
        [["is_wildtype", cell_1, cell_2]]
        .apply(get_site_diffs)
        .assign(cell_1=cell_1, cell_2=cell_2)
        .sort_values("sequential_site")
        .reset_index()
    )
site_diffs = pd.concat(site_diffs, ignore_index=True)
assert set(site_diff_metrics).issubset(site_diffs.columns)

print(f"For mean difference, effects floored at {floor_mut_effects=} first.")
print(f"Saving site differences to {site_diffs_csv=}")
# Make sure the output directory exists rather than relying on another cell
# having created it for a different output file.
site_diffs_dir = os.path.dirname(site_diffs_csv)
if site_diffs_dir:
    os.makedirs(site_diffs_dir, exist_ok=True)
site_diffs.to_csv(site_diffs_csv, index=False, float_format="%.3f")
# Bare expression so the notebook renders the data frame.
site_diffs
For mean difference, effects floored at floor_mut_effects=-5 first. Saving site differences to site_diffs_csv='results/compare_cell_entry/site_diffs.csv'
| site | sequential_site | region | mean difference | Jensen-Shannon divergence | difference in constraint | cell_1 | cell_2 | |
|---|---|---|---|---|---|---|---|---|
| 0 | -1(E3) | 1 | E3 | 0.000000 | 1.033940e-07 | -0.000222 | 293T-Mxra8 | C6/36 |
| 1 | 1(E3) | 2 | E3 | 0.117874 | 1.124724e-02 | 0.850198 | 293T-Mxra8 | C6/36 |
| 2 | 2(E3) | 3 | E3 | -0.069511 | 2.235952e-02 | 0.322409 | 293T-Mxra8 | C6/36 |
| 3 | 3(E3) | 4 | E3 | -0.031671 | 5.401844e-03 | 0.958338 | 293T-Mxra8 | C6/36 |
| 4 | 4(E3) | 5 | E3 | -0.092400 | 6.172539e-03 | -0.344348 | 293T-Mxra8 | C6/36 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2959 | 435(E1) | 984 | E1 | 0.220798 | 1.962642e-02 | -0.251640 | C6/36 | 293T-TIM1 |
| 2960 | 436(E1) | 985 | E1 | 0.198397 | 1.349157e-02 | 0.014254 | C6/36 | 293T-TIM1 |
| 2961 | 437(E1) | 986 | E1 | 0.328327 | 1.683343e-02 | 0.338031 | C6/36 | 293T-TIM1 |
| 2962 | 438(E1) | 987 | E1 | 0.205080 | 3.944771e-02 | 0.387563 | C6/36 | 293T-TIM1 |
| 2963 | 439(E1) | 988 | E1 | 0.209541 | 1.865643e-02 | 0.098284 | C6/36 | 293T-TIM1 |
2964 rows × 8 columns
Plot sites with large differences¶
We make an interactive plot that includes:
- line plot with site differences at top left
- scatter plot of mutation effects at top right
- heatmaps centered around key site at bottom
You can click sites on the site plot to show them on the mutation-level plots, zoom with the zoom bar, and use = menu at the bottom to adjust other options including which cells to compare.
def plot_site_comparison(
    mut_data,
    site_diffs,
    cells,
    site_diff_metrics,
    aas,
    aa_color_df,
    init_floor_effect,
    heatmap_max_at_least=2,
    heatmap_flank=12,
):
    """Plot (site-level) difference of entry effects between cells w mutation zooms.

    The chart combines a site-level line plot with a zoom bar, a per-site
    scatter plot of mutation effects in two selectable cells, and one heatmap
    per cell centered on the clicked site.

    Parameters
    ----------
    mut_data : pd.DataFrame
        Mutation-level data with columns "site", "sequential_site", "region",
        "wildtype", "mutant" and one effect column per entry of `cells`.
    site_diffs : pd.DataFrame
        Site-level data with columns "site", "sequential_site", "region",
        "cell_1", "cell_2" and one column per entry of `site_diff_metrics`.
        Must cover exactly the same sites as `mut_data`.
    cells : iterable of str
        Cell names (columns of `mut_data`); iteration order sets the order of
        the dropdown options and of the heatmaps.
    site_diff_metrics : list of str
        Metric columns of `site_diffs`. The second entry is the metric shown
        initially; the first entry is used to pick the initially selected site.
        NOTE(review): these two defaults use different metrics -- confirm this
        is intended.
    aas : list of str
        Amino acids, in the order used for the heatmap y-axis.
    aa_color_df : pd.DataFrame
        Columns "mutant" and "color", giving the letter colour in the scatter.
    init_floor_effect : float or None
        Initial value of the slider that floors displayed mutation effects; if
        None the slider starts at the minimum effect.
    heatmap_max_at_least : float
        Lower bound on the upper limit of the heatmap colour scale.
    heatmap_flank : int
        Number of sites (by `sequential_site`) shown on each side of the
        selected site in the heatmaps.

    Returns
    -------
    alt.VConcatChart
        The assembled, configured chart.
    """
    # some params
    site_chart_width = 700

    assert set(mut_data["site"]) == set(site_diffs["site"])
    assert set(site_diff_metrics).issubset(site_diffs.columns)

    cell_names = list(cells)

    # Drag to zoom into sites on the x-axis colored by region
    zoom_selection = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke='black', strokeWidth=2)
    )

    # zoom bar
    zoom_bar = (
        alt.Chart(mut_data[["site", "sequential_site", "region"]])
        .mark_rect()
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title="click and drag to zoom on sites",
                axis=alt.Axis(ticks=False, labels=False, titleFontWeight="normal"),
            ),
            alt.Color("region", scale=alt.Scale(scheme="greys"), legend=None),
            tooltip=["site", "sequential_site", "region"],
        )
        .properties(width=site_chart_width, height=10)
        .add_params(zoom_selection)
    )

    # line plot
    metric_selection = alt.selection_point(
        fields=["metric"],
        name="metric_selection",
        value=site_diff_metrics[1],
        bind=alt.binding_select(
            options=site_diff_metrics,
            name="metric for site differences between cells",
        ),
    )

    # Only offer cells that actually occur as cell_1 / cell_2 in `site_diffs`.
    cell_1_options = [c for c in cells if c in set(site_diffs["cell_1"])]
    cell_1_selection = alt.param(
        name="cell_1",
        value=cell_1_options[0],
        bind=alt.binding_select(
            options=cell_1_options,
            name="comparator cell line",
        )
    )
    cell_2_options = [c for c in cells if c in set(site_diffs["cell_2"])]
    cell_2_selection = alt.param(
        name="cell_2",
        value=cell_2_options[0],
        bind=alt.binding_select(
            options=cell_2_options,
            name="reference cell line",
        )
    )

    # site w biggest effect (by absolute value of the first metric) for the
    # default cell pair; this is the initially selected site
    default_site = (
        site_diffs[
            (site_diffs["cell_1"] == cell_1_options[0])
            & (site_diffs["cell_2"] == cell_2_options[0])
        ]
        .set_index("site")
        [site_diff_metrics[0]]
        .abs()
        .sort_values(ascending=False)
        .index[0]
    )
    default_sequential_site = site_diffs.set_index("site")["sequential_site"].to_dict()[default_site]

    # Two parallel click selections: one keyed on the site label (scatter plot
    # and highlighting) and one on the sequential site number (heatmap window).
    site_selection = alt.selection_point(
        fields=["site"], empty=False, value=default_site, on="click"
    )
    sequential_site_selection = alt.selection_point(
        fields=["sequential_site"],
        empty=False,
        value=default_sequential_site,
        on="click",
    )

    site_base = (
        alt.Chart(site_diffs)
        .transform_filter(zoom_selection)
        .transform_filter(
            (alt.datum["cell_1"] == cell_1_selection)
            & (alt.datum["cell_2"] == cell_2_selection)
        )
        .transform_fold(
            site_diff_metrics,
            ["metric", "difference"],
        )
        .transform_filter(metric_selection)
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=alt.Axis(labelOverlap="greedy", ticks=False),
            ),
            alt.Y(
                "difference:Q",
                title="difference at site",
                scale=alt.Scale(nice=False, padding=9),
            ),
            tooltip=[
                "site", "sequential_site", "region", alt.Tooltip("difference:Q", format=".2f")
            ],
        )
    )

    site_lines = site_base.mark_line(color="black", strokeWidth=1, opacity=1)

    site_points = site_base.mark_circle(filled=True, fill="black", stroke="gold", opacity=1).encode(
        strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(0)),
        size=alt.condition(site_selection, alt.value(180), alt.value(60)),
    )

    # Dynamic title for chart plot
    site_title = alt.TitleParams(
        alt.expr(
            f'"difference between mutation effects in " + {cell_1_selection.name} + " versus " + {cell_2_selection.name} + " cells"'
        ),
        subtitle="click on a site to show in the mutation-level scatter plot and heatmaps",
        anchor="middle",
    )

    site_chart = (
        (site_lines + site_points)
        .properties(width=site_chart_width, height=185, title=site_title)
        .add_params(
            metric_selection, site_selection, sequential_site_selection, cell_1_selection, cell_2_selection,
        )
    )

    # amino-acid scatter plot for a single site
    min_effect = mut_data[cell_names].min().min()
    max_effect = mut_data[cell_names].max().max()
    min_effect_slider = alt.param(
        name="min_effect_slider",
        bind=alt.binding_range(
            min=min_effect, max=max_effect, name="floor displayed mutation effect at",
        ),
        value=max(init_floor_effect, min_effect) if init_floor_effect is not None else min_effect,
    )

    # shared base for the mutation-level scatter plot and the heatmaps
    mut_base = alt.Chart(mut_data).add_params(min_effect_slider)

    mutant_selection = alt.selection_point(
        fields=["mutant", "site"], on="mouseover", empty=False
    )

    mut_scatter = (
        mut_base
        .transform_filter(site_selection)
        .transform_lookup(
            lookup='mutant',
            from_=alt.LookupData(data=aa_color_df, key='mutant', fields=['color']),
        )
        .transform_calculate(
            # look up the columns named by the current dropdown selections
            x=f"datum[{cell_1_selection.name}]",
            y=f"datum[{cell_2_selection.name}]",
            x_floored=f'isValid(datum.x) ? max(datum.x, {min_effect_slider.name}) : datum.x',
            y_floored=f'isValid(datum.y) ? max(datum.y, {min_effect_slider.name}) : datum.y',
        )
        .encode(
            alt.X("x_floored:Q", title="comparator cell line"),
            alt.Y("y_floored:Q", title="reference cell line"),
            alt.Text("mutant:N"),
            alt.Color("color:N", scale=None),
            size=alt.condition(mutant_selection, alt.value(22), alt.value(18)),
            strokeWidth=alt.condition(mutant_selection, alt.value(1), alt.value(0)),
            fillOpacity=alt.condition(mutant_selection, alt.value(1), alt.value(0.75)),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in cells]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=600)
        .add_params(cell_1_selection, cell_2_selection, mutant_selection)
        .properties(
            title=alt.TitleParams(
                alt.expr(f'"mutation effects at site " + {site_selection.name}.site')
            ),
            width=220,
            height=220,
        )
    )

    # dashed diagonal from the current floor value to the maximum effect
    scatter_diagonal = (
        alt.Chart()
        .mark_rule(color="gray", strokeWidth=3, strokeDash=[6, 6], opacity=0.5)
        .transform_calculate(ax_lim=min_effect_slider.name)
        .encode(
            alt.X("ax_lim:Q", scale=alt.Scale(nice=False, padding=9, zero=False)),
            alt.Y("ax_lim:Q", scale=alt.Scale(nice=False, padding=9, zero=False)),
            x2=alt.datum(max_effect),
            y2=alt.datum(max_effect),
        )
    )

    scatter_chart = scatter_diagonal + mut_scatter

    # make the heatmaps
    assert all(mut_data["sequential_site"] == mut_data["sequential_site"].astype(int))
    assert all(site_diffs["sequential_site"] == site_diffs["sequential_site"].astype(int))

    heatmap_base = (
        mut_base
        .transform_filter(
            # window of `heatmap_flank` sites on either side of the clicked site
            f"abs(datum.sequential_site - {sequential_site_selection.name}.sequential_site) <= {heatmap_flank}"
        )
        .encode(
            alt.X("site", sort=alt.SortField("sequential_site")),
            alt.Y("mutant", sort=aas),
        )
        .properties(width=alt.Step(12), height=alt.Step(12))
    )

    # gray background for missing values
    heatmap_bg = heatmap_base.transform_impute(
        impute="_stat_dummy",
        key="mutant",
        keyvals=aas,
        groupby=["site"],
        value=None,
    ).mark_rect(color="#E0E0E0", opacity=0.8)

    # mark X for wildtype
    heatmap_wildtype = (
        heatmap_base
        .transform_filter(alt.datum["wildtype"] == alt.datum["mutant"])
        .mark_text(text="x", color="black")
    )

    # upper limit of the colour scale, shared by all heatmaps
    heatmap_domain_max = max(max_effect, heatmap_max_at_least)

    # make heatmap for each cell type
    heatmaps = []
    for cell_index, cell in enumerate(cell_names):
        first_cell = cell_index == 0
        heatmap_muts = (
            heatmap_base
            .transform_calculate(
                effect_floored=f'isValid(datum["{cell}"]) ? max(datum["{cell}"], {min_effect_slider.name}) : datum["{cell}"]'
            )
            .encode(
                # only the left-most heatmap carries the y-axis title
                alt.Y("mutant", sort=aas, title="amino acid" if first_cell else None),
                alt.Color(
                    "effect_floored:Q",
                    title="mutation effect",
                    legend=alt.Legend(
                        orient="right", titleOrient="right", gradientStrokeColor="black", gradientStrokeWidth=1
                    ),
                    scale=alt.Scale(
                        scheme="redblue",
                        nice=False,
                        domainMid=0,
                        domainMax=heatmap_domain_max,
                    ),
                ),
                strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(1)),
                tooltip=["site", "sequential_site", "wildtype", "mutant"] + [alt.Tooltip(c, format=".2f") for c in cells],
            )
            .mark_rect(stroke="black", opacity=1, strokeOpacity=1)
            .properties(title=f"{cell} effect")
        )
        heatmaps.append(heatmap_bg + heatmap_muts + heatmap_wildtype)

    heatmap = alt.hconcat(*heatmaps, spacing=7)

    # assemble the final chart
    chart = (
        alt.vconcat(
            alt.hconcat(alt.vconcat(site_chart, zoom_bar, spacing=4), scatter_chart),
            heatmap,
        )
        .configure_title(fontSize=18, subtitleFontSize=16)
        .configure_axis(grid=False, labelFontSize=11, titleFontSize=16)
        .configure_legend(labelFontSize=12, titleFontSize=16)
    )

    return chart
# Interactive site-level comparison with mutation-level zooms.
site_chart = plot_site_comparison(
    mut_data,
    site_diffs,
    cells,
    site_diff_metrics,
    aas,
    aa_color_df,
    floor_mut_effects,
    heatmap_max_at_least=2,
    heatmap_flank=12,
)

# Set the embed padding before the chart is saved and rendered.
alt.renderers.set_embed_options(
    padding={"left": 5, "right": 5, "bottom": 5, "top": 5}
)

print(f"Saving to {site_zoom_chart=}")
# Make sure the output directory exists rather than relying on another cell
# having created it for a different output file.
site_zoom_dir = os.path.dirname(site_zoom_chart)
if site_zoom_dir:
    os.makedirs(site_zoom_dir, exist_ok=True)
site_chart.save(site_zoom_chart)
# Bare expression so the notebook renders the chart.
site_chart
Saving to site_zoom_chart='results/compare_cell_entry/compare_cell_entry_site_zoom.html'
Make paper figure plots¶
These plots have some manually hardcoded variables unlike the code above.
First, scatter plots of mean effect at each site in different cells:
def _floored_mean(effects):
    """Return the mean of `effects` after flooring them at `floor_mut_effects`."""
    return effects.clip(lower=floor_mut_effects).mean()


# Per-site mean effect of the non-wildtype mutations, one column per cell.
_site_mean_aggs = {cell: pd.NamedAgg(cell, _floored_mean) for cell in cells}
fig_site_data = (
    mut_data
    .query("wildtype != mutant")
    .groupby(["wildtype", "site", "region"], as_index=False)
    .aggregate(**_site_mean_aggs)
)

# Hovering over a site highlights it; the selection is shared by the figure panels.
fig_site_selection = alt.selection_point(fields=["site"], empty=False, on="mouseover")

fig_site_scatter_base = alt.Chart(fig_site_data).add_params(fig_site_selection)

# (x-axis cell, y-axis cell) for each scatter panel, top to bottom.
cell_pairs = [('293T-TIM1', '293T-Mxra8'), ('293T-TIM1', 'C6/36'), ('C6/36', '293T-Mxra8')]

fig_site_scatter_chart = []
for cell1, cell2 in cell_pairs:
    _pair_panel = (
        fig_site_scatter_base
        .mark_circle(fill="gray", fillOpacity=0.25, strokeOpacity=0.7, stroke="black", strokeWidth=0.5)
        .encode(
            alt.X(cell1, axis=alt.Axis(values=[0, -2, -4]), scale=alt.Scale(nice=False, padding=8)),
            alt.Y(cell2, axis=alt.Axis(values=[0, -2, -4]), scale=alt.Scale(nice=False, padding=8)),
            tooltip=["site", "wildtype"],
            fill=alt.condition(fig_site_selection, alt.value("red"), alt.value("gray")),
            fillOpacity=alt.condition(fig_site_selection, alt.value(1), alt.value(0.25)),
            size=alt.condition(fig_site_selection, alt.value(90), alt.value(35)),
        )
        .properties(width=117, height=117)
    )
    fig_site_scatter_chart.append(_pair_panel)

# Stack the panels vertically under a shared two-line title.
_scatter_title = alt.TitleParams(
    ["average effect of", "mutations at each site"],
    anchor="middle",
    dx=13
)
fig_site_scatter_chart = alt.vconcat(*fig_site_scatter_chart, spacing=13).properties(
    title=_scatter_title
)

fig_site_scatter_chart
Combine the scatter plot with a line plot of the mean difference at each site:
# Label every row of `site_diffs` with its "<cell_1> minus <cell_2>" comparison
# and keep only the columns needed for the figure.
fig_site_diffs = (
    site_diffs
    .assign(comparison=lambda x: x["cell_1"] + " minus " + x["cell_2"])
    [["site", "sequential_site", "region", "comparison", "mean difference"]]
)
# Facet order for the rows. Each (cell1, cell2) in `cell_pairs` is labelled
# "<cell2> minus <cell1>", and the assert checks that these labels are exactly
# the comparisons present in `site_diffs`.
comparisons = [f"{cell2} minus {cell1}" for (cell1, cell2) in cell_pairs]
assert set(comparisons) == set(fig_site_diffs["comparison"])
fig_site_width = 650
# line chart
fig_site_diffs_chart = (
    alt.Chart(fig_site_diffs)
    .add_params(fig_site_selection)
    .encode(
        alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(
                # label only every 80th site (in sequential order), starting from the 31st
                values=fig_site_diffs[["sequential_site", "site"]].drop_duplicates().sort_values("sequential_site")["site"].iloc[30::80],
                labelAngle=0,
            ),
        ),
        alt.Y("mean difference", title=None, scale=alt.Scale(nice=False, padding=4)),
        # one row of the facet per comparison, in the order of `comparisons`
        alt.Row(
            "comparison",
            title=None,
            header=alt.Header(labelFontSize=12, labelPadding=4),
            spacing=10,
            sort=comparisons,
        ),
        # the hovered site is drawn in red, all others in black
        color=alt.condition(fig_site_selection, alt.value("red"), alt.value("black")),
        tooltip=["site", alt.Tooltip("mean difference", format=".2f", title="difference")],
    )
    .mark_bar(width=2, opacity=1, strokeWidth=0)
    .properties(height=143, width=fig_site_width)
)
# region overlay for line chart
region_chart = (
    alt.Chart(fig_site_diffs[["sequential_site", "region"]].drop_duplicates())
    .encode(
        alt.X("sequential_site:O", axis=None),
        alt.Color(
            "region",
            legend=None,
            # NOTE(review): four hardcoded colours -- presumably one per region; confirm
            scale=alt.Scale(range=["AliceBlue", "CadetBlue", "CadetBlue", "AliceBlue"])
        ),
    )
    .mark_rect(opacity=0.75, strokeWidth=0)
    .properties(width=fig_site_width)
)
# region name placed at the mean sequential site of each region
text_df = fig_site_diffs.groupby("region", as_index=False).aggregate(x=pd.NamedAgg("sequential_site", "mean"))
text_chart = (
    alt.Chart(text_df)
    .encode(
        alt.X(
            "x:Q",
            title=None,
            # span the full range of sequential sites
            scale=alt.Scale(domain=(fig_site_diffs["sequential_site"].min(), fig_site_diffs["sequential_site"].max())),
            axis=None,
        ),
        alt.Text("region"),
    )
    .mark_text(fontWeight="bold", fontSize=13)
    .properties(width=fig_site_width, height=15)
)
overlay_chart = region_chart + text_chart
# region overlay on top of the faceted bar chart, under a shared title
fig_site_line_chart = (
    alt.vconcat(overlay_chart, fig_site_diffs_chart, spacing=0)
    .properties(
        title=alt.TitleParams("average difference in mutation effects on cell entry at each site", anchor="middle"),
    )
)
# scatter panels on the left, line chart on the right
fig_site_chart = alt.hconcat(
    fig_site_scatter_chart,
    fig_site_line_chart,
    spacing=53,
    center=True,
)
# Displayed for preview only: the configured result is not assigned back to `fig_site_chart`.
fig_site_chart.configure_axis(grid=False, titleFontSize=12, titleFontWeight="normal").configure_view(stroke="black")
Correlate effects on 293T-Mxra8 - 293T-TIM1 with distance from Mxra8 in structure:
# Distances to Mxra8 joined with the per-site "293T-Mxra8 minus 293T-TIM1" mean difference.
mxra8_dists = (
    pd.read_csv(mxra8_dists_csv)
    .assign(
        # build "<site>(<region>)" labels so they can be matched against `fig_site_diffs["site"]`
        site=lambda x: x["site"].astype(str) + "(" + x["region"] + ")",
        PDB=lambda x: "distance in PDB " + x["PDB"],
    )
    [["PDB", "region", "site", "distance_to_Mxra8"]]
    .merge(
        (
            fig_site_diffs
            .query("comparison == '293T-Mxra8 minus 293T-TIM1'")
            [["site", "mean difference"]]
        ),
        on="site",
        # many distance rows may share a site, but each site has one mean difference
        validate="m:1",
    )
)
# Scatter of mean difference versus distance to Mxra8, one facet row per PDB.
mxra8_dists_chart = (
    alt.Chart(mxra8_dists)
    .add_params(fig_site_selection)
    .encode(
        alt.X("mean difference", title=["293T-Mxra8 minus", "293T-TIM1"], scale=alt.Scale(nice=False, padding=5)),
        alt.Y("distance_to_Mxra8", title=None, scale=alt.Scale(nice=False, padding=5)),
        alt.Row("PDB", title=None, header=alt.Header(labelFontSize=12, labelPadding=1)),
        # the hovered site is drawn larger, opaque and in red
        fill=alt.condition(fig_site_selection, alt.value("red"), alt.value("gray")),
        fillOpacity=alt.condition(fig_site_selection, alt.value(1), alt.value(0.25)),
        size=alt.condition(fig_site_selection, alt.value(90), alt.value(35)),
        tooltip=["site"],
    )
    .mark_circle(fill="gray", fillOpacity=0.25, strokeOpacity=0.7, stroke="black", strokeWidth=0.5)
    .properties(
        height=126,
        width=126,
        title=alt.TitleParams(["difference in effects", "vs distance to Mxra8"], anchor="middle", dx=10),
    )
)
# Displayed for preview only: the configured result is not assigned back to `mxra8_dists_chart`.
mxra8_dists_chart.configure_axis(
    grid=False, titleFontWeight="normal", titleFontSize=12
)
Now make a function to plot the mutation effects in different cells at key sites:
def plot_site_scatter(
    site, cell_1, cell_2, no_y_axis=None, ax_max=1.5, bold_letters=None, no_title=False, no_x_ticks=False
):
    """Scatter plot of mutation effects at one site in two cells, drawn as letters.

    Reads the module-level `mut_data`, `floor_mut_effects` and `aa_color_df`.
    Effects are floored at `floor_mut_effects` before plotting, and a dashed
    diagonal is drawn underneath the letters.

    Parameters
    ----------
    site : str
        Site to plot; must be present in `mut_data["site"]`.
    cell_1, cell_2 : str
        Columns of `mut_data` plotted on the x- and y-axis respectively.
    no_y_axis : bool or None
        If true, hide the y-axis title, ticks and labels.
    ax_max : float or None
        Upper axis limit for both axes. If None, use the largest (unfloored)
        effect in the two cells across all sites plus 0.5.
    bold_letters : collection of str or None
        If None, letters are colored using `aa_color_df`. Otherwise these
        letters are emphasized (black if wildtype, red if mutant) and all other
        letters are drawn faintly in dark blue.
    no_title : bool
        If true, omit the "site ..." title.
    no_x_ticks : bool
        If true, hide the x-axis ticks and labels.

    Returns
    -------
    alt.LayerChart
        The diagonal layered with the letter scatter plot.
    """
    assert site in set(mut_data["site"]), site

    # Take an explicit copy of the rows for this site so the column assignments
    # below never write to a slice of `mut_data` (avoids SettingWithCopyWarning).
    mut_df = mut_data.loc[
        mut_data["site"] == site, ["wildtype", "mutant", cell_1, cell_2]
    ].copy()
    for cell in [cell_1, cell_2]:
        mut_df[cell] = mut_df[cell].clip(lower=floor_mut_effects)

    ax_min = floor_mut_effects - 0.5
    if ax_max is None:
        ax_max = mut_data[[cell_1, cell_2]].max().max() + 0.5

    if bold_letters is None:
        # default styling: functional-group colors, uniform emphasis
        mut_df = mut_df.merge(aa_color_df, on="mutant", validate="one_to_one").assign(
            opacity=0.75, strokeWidth=0.4, size=14
        )
    else:
        # emphasized letters: black for wildtype, red for mutants; the rest faint
        is_bold = mut_df["mutant"].map(lambda a: a in bold_letters)
        is_wildtype = mut_df["mutant"] == mut_df["wildtype"]
        mut_df["color"] = numpy.where(
            is_bold, numpy.where(is_wildtype, "black", "red"), "darkblue"
        )
        mut_df["opacity"] = numpy.where(is_bold, 1, 0.25)
        mut_df["strokeWidth"] = numpy.where(is_bold, 0.5, 0)
        mut_df["size"] = numpy.where(is_bold, 15, 12)

    mut_scatter = (
        alt.Chart(mut_df)
        .encode(
            alt.X(
                cell_1,
                scale=alt.Scale(domain=(ax_min, ax_max), nice=False),
                axis=alt.Axis(ticks=False, labels=False) if no_x_ticks else alt.Axis(),
            ),
            alt.Y(
                cell_2,
                title=None if no_y_axis else cell_2,
                scale=alt.Scale(domain=(ax_min, ax_max), nice=False),
                axis=alt.Axis(ticks=False, labels=False) if no_y_axis else alt.Axis()
            ),
            alt.Text("mutant"),
            # styling columns hold literal values, so no scales are applied
            alt.Color("color:N", scale=None),
            alt.FillOpacity("opacity", scale=None),
            alt.StrokeWidth("strokeWidth", scale=None),
            alt.Size("size", scale=None),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in [cell_1, cell_2]]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=700)
        .properties(
            title="" if no_title else alt.TitleParams(f"site {site}", dy=7),
            width=92,
            height=92,
        )
    )

    # dashed diagonal spanning the full axis range
    scatter_diagonal = (
        alt.Chart()
        .mark_rule(color="gray", strokeWidth=3, strokeDash=[6, 6], opacity=0.4)
        .encode(
            x=alt.datum(ax_min),
            y=alt.datum(ax_min),
            x2=alt.datum(ax_max),
            y2=alt.datum(ax_max),
        )
    )

    return scatter_diagonal + mut_scatter
Plot the mutation effects at key sites in each cell line pair, and merge into one figure:
# Sites shown in the paper figure, left to right.
key_sites = ['71(E2)', '119(E2)', '120(E2)', '121(E2)', '157(E2)', '158(E2)', '272(E2)']

# One row of per-site scatter plots for each cell pair.
top_diff_sites_scatter_chart = []
for _row_index, (cell1, cell2) in enumerate(cell_pairs):
    # only the first row carries the "site ..." titles
    no_title = _row_index > 0
    _row_panels = [
        plot_site_scatter(
            s, cell1, cell2, no_y_axis=(s != key_sites[0]), no_title=no_title
        )
        for s in key_sites
    ]
    top_diff_sites_scatter_chart.append(alt.hconcat(*_row_panels, spacing=-2))

top_diff_sites_scatter_chart = alt.vconcat(
    *top_diff_sites_scatter_chart, spacing=5
)

# Full figure: site-level panels on top; Mxra8 distances and key-site scatters below.
_figure_bottom_row = alt.hconcat(mxra8_dists_chart, top_diff_sites_scatter_chart, spacing=30)
_figure = alt.vconcat(fig_site_chart, _figure_bottom_row, spacing=9)
(
    _figure
    .configure_axis(grid=False, titleFontWeight="normal", titleFontSize=12, titlePadding=2)
    .configure_view(stroke="black")
)
Make figure showing mutations selected for validation, with wildtype in black, mutants made in red, and other letters in faint blue:
# plot the differences in the experimentally validated sites
# plot the differences in the experimentally validated sites
# (site -> letters to emphasize at that site)
validation_sites = {
    '119(E2)': ["R", "K"],
    '120(E2)': ["K", "D"],
    '121(E2)': ["I", "E"],
    '157(E2)': ["A", "S"],
    '158(E2)': ["Q", "T", "V"],
}

# the left-most site is the only one that keeps its y-axis
_first_validation_site = next(iter(validation_sites))

_validation_rows = []
for cell1, cell2 in itertools.combinations(reversed(cells), 2):
    # shared upper axis limit for this cell pair across all validation sites
    _pair_ax_max = mut_data.query("site in @validation_sites")[[cell1, cell2]].max().max() + 0.6
    _pair_panels = [
        plot_site_scatter(
            s,
            cell1,
            cell2,
            no_y_axis=(s != _first_validation_site),
            ax_max=_pair_ax_max,
            bold_letters=letters,
        )
        for s, letters in validation_sites.items()
    ]
    _validation_rows.append(alt.hconcat(*_pair_panels, spacing=2))

validation_scatter_chart = (
    alt.vconcat(*_validation_rows, spacing=15)
    .configure_axis(grid=False, titleFontWeight="normal", titleFontSize=12)
    .configure_view(stroke="black")
)

validation_scatter_chart