Summarize results across assays¶
This notebook summarizes the results across assays.
import functools
import operator
import re
import altair as alt
import pandas as pd
import polyclonal.alphabets
from polyclonal.plot import color_gradient_hex
_ = alt.data_transformers.disable_max_rows()
Get configuration parameters¶
The next cell is tagged as parameters
for papermill
parameterization:
# Placeholder defaults for the papermill parameters. Every name is set to
# None here and is re-assigned with a real value in the "# Parameters" cell.
# CSV read below to build `site_numbering_map`
site_numbering_map_csv = None
# output files the merged overlaid / faceted charts are saved to
chart_overlaid = None
chart_faceted = None
# output CSV for the summary (heatmap) data
output_csv_file = None
# output CSV for the per-antibody escape values
per_antibody_escape_csv = None
# nested dict of filters, assay definitions and plot settings
config = None
# dict of input CSV paths, looked up by "<assay> <condition>" style keys
input_csvs = None
# Parameters
# Values for this run; this overrides the None placeholders assigned above.
config = {
# thresholds applied as `times_seen >= ...` and `frac_models >= ...` filters
"min_times_seen": 3,
"min_frac_models": 1,
# characters kept for both wildtype and mutant; ordered later with
# `polyclonal.alphabets.biochem_order_aas`
"alphabet": [
"A",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"K",
"L",
"M",
"N",
"P",
"Q",
"R",
"S",
"T",
"V",
"W",
"Y",
"-",
],
# initial values of the "floor escape at zero" radio button and of the
# "site escape statistic" dropdown
"init_floor_escape_at_zero": True,
"init_site_escape_stat": "sum",
# functional effects, keyed by display name; "condition" selects the input
# CSV, "le_filters" become `column <= value` filters, the colors and
# *_at_least values set the heatmap color scale, and "init_min_value" is the
# initial slider position
"func_effects": {
"spike mediated entry": {
"condition": "293T_high_ACE2_entry",
"effect_type": "func_effects",
"positive_color": "#009E73",
"negative_color": "#F0E442",
"max_at_least": 1,
"min_at_least": 0,
"init_min_value": -2.0,
"le_filters": {"effect_std": 1.6},
}
},
# other assays, keyed first by assay and then by display name; "stat" is the
# CSV column that is renamed to the display name, and fixed_max / fixed_min
# take precedence over the data-derived color-scale limits
"other_assays": {
"ACE2_binding": {
"ACE2 binding": {
"condition": "monomeric_ACE2",
"stat": "ACE2 binding_median",
"positive_color": "#0072B2",
"negative_color": "#D55E00",
"max_at_least": 1,
"min_at_least": 0,
"fixed_max": 2,
"fixed_min": -3,
"init_min_value": -2,
"le_filters": {"ACE2 binding_std": 3.0},
}
}
},
# antibody / serum sets, keyed by set name; "antibody_list" maps the name
# used in the `input_csvs` keys to the display name used in the plots, and
# "stat" is the CSV column that is renamed to "escape"
"antibody_escape": {
"one infection": {
"stat": "escape_median",
"positive_color": "#56B4E9",
"negative_color": "#E69F00",
"max_at_least": 1,
"min_at_least": -1,
"le_filters": {"escape_std": 1.5},
"antibody_list": {
"sera_493C_mediumACE2": "serum 493C",
"sera_498C_mediumACE2": "serum 498C",
"sera_500C_mediumACE2": "serum 500C",
"sera_501C_mediumACE2": "serum 501C",
"sera_503C_mediumACE2": "serum 503C",
"sera_505C_mediumACE2": "serum 505C",
},
},
"multiple infections": {
"stat": "escape_median",
"positive_color": "#56B4E9",
"negative_color": "#E69F00",
"max_at_least": 1,
"min_at_least": -1,
"le_filters": {"escape_std": 1.5},
"antibody_list": {
"sera_287C_mediumACE2": "serum 287C",
"sera_288C_mediumACE2": "serum 288C",
"sera_343C_mediumACE2": "serum 343C",
"sera_497C_mediumACE2": "serum 497C",
},
},
},
}
# Input files for this run. The escape, functional-effect and other-assay
# CSVs are looked up below with keys of the form "<assay> <condition>"
# (e.g. "antibody_escape <antibody>", "func_effects <condition>").
input_csvs = {
"antibody_escape sera_493C_mediumACE2": "results/antibody_escape/averages/sera_493C_mediumACE2_mut_effect.csv",
"antibody_escape sera_498C_mediumACE2": "results/antibody_escape/averages/sera_498C_mediumACE2_mut_effect.csv",
"antibody_escape sera_500C_mediumACE2": "results/antibody_escape/averages/sera_500C_mediumACE2_mut_effect.csv",
"antibody_escape sera_501C_mediumACE2": "results/antibody_escape/averages/sera_501C_mediumACE2_mut_effect.csv",
"antibody_escape sera_503C_mediumACE2": "results/antibody_escape/averages/sera_503C_mediumACE2_mut_effect.csv",
"antibody_escape sera_287C_mediumACE2": "results/antibody_escape/averages/sera_287C_mediumACE2_mut_effect.csv",
"antibody_escape sera_288C_mediumACE2": "results/antibody_escape/averages/sera_288C_mediumACE2_mut_effect.csv",
"antibody_escape sera_343C_mediumACE2": "results/antibody_escape/averages/sera_343C_mediumACE2_mut_effect.csv",
"antibody_escape sera_497C_mediumACE2": "results/antibody_escape/averages/sera_497C_mediumACE2_mut_effect.csv",
"antibody_escape sera_505C_mediumACE2": "results/antibody_escape/averages/sera_505C_mediumACE2_mut_effect.csv",
"func_effects 293T_high_ACE2_entry": "results/func_effects/averages/293T_high_ACE2_entry_func_effects.csv",
"ACE2_binding monomeric_ACE2": "results/ACE2_binding/averages/monomeric_ACE2_mut_effect.csv",
# NOTE(review): the two entries below are not looked up via `input_csvs`
# anywhere in this notebook; the site map is read from the separate
# `site_numbering_map_csv` parameter instead.
"site_numbering_map_csv": "data/site_numbering_map.csv",
"nb": "dms-vep-pipeline-3/notebooks/summary.ipynb",
}
# input: site numbering map read with `pd.read_csv` below
site_numbering_map_csv = "data/site_numbering_map.csv"
# outputs: HTML files the merged faceted / overlaid charts are saved to
chart_faceted = "results/escape_by_prior_infections/summary_faceted_nolegend.html"
chart_overlaid = "results/escape_by_prior_infections/summary_overlaid_nolegend.html"
# outputs: summary data CSV and per-antibody escape CSV
output_csv_file = "results/escape_by_prior_infections/summary.csv"
per_antibody_escape_csv = "results/escape_by_prior_infections/per_antibody_escape.csv"
Get the min_times_seen
and min_frac_models
filters:
# Pull both filter thresholds out of the config in a single unpacking step.
min_times_seen, min_frac_models = (
    config[key] for key in ("min_times_seen", "min_frac_models")
)

# Alphabet as returned by `biochem_order_aas`; used below both to filter
# wildtype / mutant characters and to sort the heatmap y-axis.
alphabet = polyclonal.alphabets.biochem_order_aas(config["alphabet"])

print(f"Using {min_times_seen=} and {min_frac_models=}")
Using min_times_seen=3 and min_frac_models=1
Read the data¶
Read the site numbering map:
# Load the map and refer to the reference numbering simply as "site".
site_numbering_map = pd.read_csv(site_numbering_map_csv)
site_numbering_map = site_numbering_map.rename(columns={"reference_site": "site"})

# Keep every column whose name ends in "site", followed by "region".
_site_columns = [c for c in site_numbering_map.columns if c.endswith("site")]
site_numbering_map = site_numbering_map[[*_site_columns, "region"]]
Read the escape data:
# Map each antibody set to a wide data frame with one row per
# (site, wildtype, mutant) and one escape column per antibody display name.
escape = {}
for antibody_set, antibody_set_d in config["antibody_escape"].items():
    name_by_antibody = antibody_set_d["antibody_list"]
    # display names become column names, so they must be unique within a set
    assert len(name_by_antibody) == len(set(name_by_antibody.values()))

    # read each antibody's CSV, label it, and call the chosen stat "escape"
    stat_column = antibody_set_d["stat"]
    escape_dfs = [
        pd.read_csv(input_csvs[f"antibody_escape {antibody}"])
        .assign(antibody=antibody_name)
        .rename(columns={stat_column: "escape"})
        for antibody, antibody_name in name_by_antibody.items()
    ]

    # optional `column <= value` filters appended to the alphabet query
    le_filters = ""
    if "le_filters" in antibody_set_d:
        le_filters = " and " + " and ".join(
            f"(`{key}` <= {val})"
            for (key, val) in antibody_set_d["le_filters"].items()
        )

    filtered_long = (
        pd.concat(escape_dfs)
        .query("frac_models >= @min_frac_models")
        .query("times_seen >= @min_times_seen")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)" + le_filters)
    )
    escape[antibody_set] = (
        filtered_long.pivot_table(
            index=["epitope", "site", "wildtype", "mutant"],
            columns="antibody",
            values="escape",
        )
        .reset_index()
        .assign(site_mutant=lambda x: x["site"].astype(str) + x["mutant"])
    )

    # the epitope column is dropped, which only makes sense with one epitope
    assert escape[antibody_set]["epitope"].nunique() == 1, "can only have 1 epitope"
    escape[antibody_set] = escape[antibody_set].drop(columns="epitope")
Write per antibody-escape to file:
# Write the per-antibody escape values in long format, one row per
# (site, wildtype, mutant, antibody) with a non-null escape value.
print(f"Writing per-antibody escape to {per_antibody_escape_csv=}")
if escape:
    # melt each wide per-set frame back to long form and tag it with its set;
    # the resulting column order is
    # site, wildtype, mutant, antibody, escape, antibody_set
    pd.concat(
        [
            antibody_df.drop(columns="site_mutant")
            .melt(
                id_vars=["site", "wildtype", "mutant"],
                var_name="antibody",
                value_name="escape",
            )
            .assign(antibody_set=antibody_set)
            for antibody_set, antibody_df in escape.items()
        ],
        ignore_index=True,
    ).query("escape.notnull()").to_csv(
        per_antibody_escape_csv, index=False, float_format="%.5g"
    )
else:
    # No antibody sets: still write a header-only file, with the same columns
    # (including "escape") and order as the non-empty case so the file schema
    # does not depend on whether any antibody sets are configured.
    pd.DataFrame(
        columns=["site", "wildtype", "mutant", "antibody", "escape", "antibody_set"]
    ).to_csv(per_antibody_escape_csv, index=False)
Writing per-antibody escape to per_antibody_escape_csv='results/escape_by_prior_infections/per_antibody_escape.csv'
Read other properties (functional effects and measurements from other assays):
def _le_filter_clause(filters):
    """Return a query suffix that ANDs a `column <= value` test per filter.

    `filters` maps column names to upper limits. Returns an empty string when
    `filters` is None or empty, so the result can always be appended to an
    existing query expression.
    """
    if not filters:
        return ""
    return " and " + " and ".join(
        f"(`{key}` <= {val})" for (key, val) in filters.items()
    )


# Map each property name (functional effect or other assay) to a data frame
# of per-mutation values.
other_props = {}

# functional effects: `frac_models` is computed here as the fraction of the
# maximum number of selections
for name, name_d in config["func_effects"].items():
    csv_file = input_csvs[f"func_effects {name_d['condition']}"]
    le_filters = _le_filter_clause(name_d.get("le_filters"))
    other_props[name] = (
        pd.read_csv(csv_file)
        .rename(columns={"effect": name})
        .assign(frac_models=lambda x: x["n_selections"] / x["n_selections"].max())
        .query("(times_seen >= @min_times_seen)" + le_filters)
        .query("frac_models >= @min_frac_models")[["site", "wildtype", "mutant", name]]
    )

# other assays: the configured "stat" column is renamed to the property name
for assay, assay_d in config["other_assays"].items():
    for name, name_d in assay_d.items():
        assert name not in other_props, f"{name} multiply defined"
        csv_file = input_csvs[f"{assay} {name_d['condition']}"]
        le_filters = _le_filter_clause(name_d.get("le_filters"))
        other_props[name] = (
            pd.read_csv(csv_file)
            .rename(columns={name_d["stat"]: name})
            .query("(times_seen >= @min_times_seen)" + le_filters)
            .query("frac_models >= @min_frac_models")[
                ["site", "wildtype", "mutant", name]
            ]
        )

assert not set(other_props).intersection(escape), "multiply defined names"

# add wildtype effects of zero
site_wts = pd.concat([*escape.values(), *other_props.values()])[
    ["site", "wildtype"]
].drop_duplicates()
# every site must have exactly one wildtype across all data sets
assert len(site_wts) == site_wts["site"].nunique()

for prop in other_props:
    # raw string so that `\w` is a regex class rather than a string escape
    if not re.fullmatch(r"[\w ]+", prop):
        raise ValueError(f"non-alphanumeric name for property: {prop}")
    other_props[prop] = (
        pd.concat(
            [
                other_props[prop],
                # one wildtype row per site with a value of zero
                site_wts.assign(
                    mutant=lambda x: x["wildtype"],
                    **{prop: 0},
                ),
            ],
            ignore_index=True,
        )
        .assign(site_mutant=lambda x: x["site"].astype(str) + x["mutant"])
        .merge(site_numbering_map, on="site", validate="many_to_one")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)")
    )
    # `site_mutant` is used as a lookup key later, so it must be unique
    assert other_props[prop]["site_mutant"].nunique() == len(other_props[prop])
Get from the config the plot parameters for each plot (essentially, this "flattens" some aspects of config
to make these easier to access below):
# Flatten the config into one name -> parameters mapping, visiting antibody
# sets first, then functional effects, then each assay under "other_assays".
plot_params = {}
_param_groups = [
    config["antibody_escape"],
    config["func_effects"],
    *config["other_assays"].values(),
]
for _group in _param_groups:
    for name, _params in _group.items():
        # names must be unique across all groups
        assert name not in plot_params
        plot_params[name] = _params
Set up selections for interactive charts¶
# chart dimensions
site_escape_width = 800  # width of site escape chart
site_escape_overlaid_height = 130  # height of overlaid site escape plots
site_escape_faceted_height = 80  # height of faceted site escape plots
cell_size = 9  # heatmap cell size

# radio button choosing whether negative escape is floored at zero
_floor_binding = alt.binding_radio(options=[True, False], name="floor escape at zero")
floor_escape_at_zero = alt.param(
    name="floor_escape_at_zero",
    value=config["init_floor_escape_at_zero"],
    bind=_floor_binding,
)

# dropdown choosing which per-site statistic is plotted
site_stats = ["mean", "sum", "max", "min"]
assert config["init_site_escape_stat"] in site_stats
_stat_binding = alt.binding_select(options=site_stats, name="site escape statistic")
site_escape_selection = alt.selection_point(
    fields=["site escape statistic"],
    bind=_stat_binding,
    value=config["init_site_escape_stat"],
)

# site selected on mouseover
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# interval brush along x, drawn as an unfilled black outline
_brush_mark = alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0)
site_brush = alt.selection_interval(
    encodings=["x"],
    mark=_brush_mark,
    empty=True,
)
# One range slider per property, setting the minimum value a mutation must
# have for that property.
other_prop_sliders = {}
for prop, prop_df in other_props.items():
    prop_min = prop_df[prop].min()
    # slider runs from the smallest observed value up to zero
    slider_binding = alt.binding_range(
        name=f"minimum mutation {prop}",
        min=prop_min,
        max=0,
    )
    # start at the configured value, but never below the slider's minimum
    other_prop_sliders[prop] = alt.param(
        value=max(prop_min, plot_params[prop]["init_min_value"]),
        name=prop.replace(" ", "_") + "_slider",
        bind=slider_binding,
    )
# region zoom bar: one rectangle per site, colored by region
_region_bar_title = alt.TitleParams(
    "site zoom bar",
    fontSize=11,
    fontWeight="bold",
    orient="top",
)
_region_color = alt.Color(
    "region",
    scale=alt.Scale(domain=site_numbering_map["region"].unique()),
)
region_bar = (
    alt.Chart(site_numbering_map)
    .mark_rect()
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=None,
        ),
        color=_region_color,
        tooltip=site_numbering_map.columns.tolist(),
    )
    .properties(
        width=site_escape_width,
        height=11,
        title=_region_bar_title,
    )
)
Antibody escape site summary plots¶
Now make site summary escape plots for each antibody set. Do this with both the sera faceted and overlaid.
# Site-level escape line plots for each antibody set, built in two layouts:
# "faceted" (one panel per antibody) and "overlaid" (all antibodies in one
# panel). Each entry also has a panel with the mean across antibodies.
site_escape_charts = {"faceted": {}, "overlaid": {}}
for antibody_set, escape_df in escape.items():
    antibodies = list(config["antibody_escape"][antibody_set]["antibody_list"].values())
    # antibody names are used as field names in the fold transform below;
    # raw string so `\w` and `\-` are regex syntax rather than string escapes
    if any(not re.fullmatch(r"[\w \-]+", a) for a in antibodies):
        raise ValueError(f"antibody names not all alphanumeric:\n{antibodies}")

    site_escape_base = (
        # merge sequential_site rather than transform_lookup so sorting works
        alt.Chart(escape_df.merge(site_numbering_map[["site", "sequential_site"]]))
        .encode(
            y=alt.Y(
                "escape:Q",
                scale=alt.Scale(nice=False, padding=4),
                axis=alt.Axis(grid=False),
            ),
            tooltip=[
                "site",
                alt.Tooltip("escape:Q", format=".2f"),
                "antibody:N",
                *[f"{c}:N" for c in site_numbering_map.columns if c != "site"],
            ],
        )
        .transform_filter(site_brush)
    )

    site_escape_lines = site_escape_base.mark_line(size=0.75)

    # points are enlarged and outlined for the site under the mouse
    site_escape_points = site_escape_base.encode(
        strokeWidth=alt.condition(site_selection, alt.value(1.5), alt.value(0)),
        size=alt.condition(site_selection, alt.value(45), alt.value(15)),
    ).mark_circle(filled=True, stroke="red")

    site_escape_lines_and_points = (
        # wide -> long: one row per antibody
        (site_escape_lines + site_escape_points).transform_fold(
            fold=antibodies, as_=["antibody", "escape_orig"]
        )
        # floor escape at zero if selected
        .transform_calculate(
            escape=alt.expr.if_(
                floor_escape_at_zero,
                alt.expr.max(alt.datum["escape_orig"], 0),
                alt.datum["escape_orig"],
            )
        )
    )

    # filter on other properties
    for prop, prop_df in other_props.items():
        # https://github.com/altair-viz/altair/issues/2600
        slider = other_prop_sliders[prop]
        site_escape_lines_and_points = site_escape_lines_and_points.transform_lookup(
            lookup="site_mutant",
            from_=alt.LookupData(
                prop_df,
                key="site_mutant",
                fields=[prop],
            ),
        ).transform_filter(alt.datum[prop] >= slider)

    # compute site statistics
    site_escape_lines_and_points = (
        site_escape_lines_and_points
        # compute site statistics from mutation statistics
        .transform_aggregate(
            **{stat: f"{stat}(escape)" for stat in site_stats},
            groupby=["site", "sequential_site", "antibody", "wildtype"],
        )
        # filter on site statistic of interest
        .transform_fold(
            fold=site_stats, as_=["site escape statistic", "escape"]
        ).transform_filter(site_escape_selection)
        # get sequential sites and regions
        .transform_lookup(
            lookup="site",
            from_=alt.LookupData(
                site_numbering_map,
                key="site",
                fields=[
                    c
                    for c in site_numbering_map.columns
                    if c not in {"site", "sequential_site"}
                ],
            ),
        )
    )

    # long set names are split over two title lines
    if len(antibody_set) < 14:
        individual_title = f"individual {antibody_set}"
        mean_title = f"mean {antibody_set}"
    else:
        individual_title = ["individual", antibody_set]
        mean_title = ["mean", antibody_set]
    if len(antibodies) == 1:
        mean_title = antibody_set

    site_escape_faceted = (
        site_escape_lines_and_points.encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
            ),
            color=alt.value("gray"),
        )
        .properties(height=site_escape_faceted_height, width=site_escape_width)
        .facet(
            facet=alt.Facet(
                "antibody:N",
                title=individual_title,
                header=alt.Header(
                    labelOrient="right",
                    labelFontSize=10,
                    labelPadding=3,
                    titleOrient="right",
                    titlePadding=3,
                ),
                sort=antibodies,
            ),
            columns=1,
            spacing=0,
        )
        .resolve_scale(y="independent")
    )

    site_escape_overlaid = site_escape_lines_and_points.encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
        ),
        opacity=alt.value(0.4),
        color=alt.value("gray" if len(antibodies) > 1 else "black"),
        detail="antibody:N",
    ).properties(
        height=site_escape_overlaid_height,
        width=site_escape_width,
        title=alt.TitleParams(
            individual_title,
            fontSize=11,
            fontWeight="bold",
            orient="right",
        ),
    )

    site_mean_escape = (
        site_escape_lines_and_points
        # average missing values as zero
        .transform_calculate(
            escape=alt.expr.if_(
                alt.expr.isValid(alt.datum["escape"]),
                alt.datum["escape"],
                0,
            ),
        )
        # take mean over sera
        .transform_aggregate(
            escape="mean(escape)",
            groupby=["wildtype", *site_numbering_map.columns],
        )
        .transform_calculate(
            antibody="'mean escape'" if len(antibodies) > 1 else f"'{antibodies[0]}'"
        )
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                axis=None,
            ),
            color=alt.value("black"),
        )
        .properties(
            title=alt.TitleParams(
                mean_title,
                fontSize=11,
                fontWeight="bold",
                orient="right",
            ),
        )
    )

    for chart_type, height, site_chart in [
        ("faceted", site_escape_faceted_height, site_escape_faceted),
        ("overlaid", site_escape_overlaid_height, site_escape_overlaid),
    ]:
        mean_panel = site_mean_escape.properties(
            height=height,
            width=site_escape_width,
        )
        # with a single antibody the per-antibody panel is omitted and only
        # the mean panel (titled with that antibody) is shown
        panels = [mean_panel, site_chart] if len(antibodies) > 1 else [mean_panel]
        # the region zoom bar is not included here: it is added once, at the
        # top, when the line plots and heatmaps are merged
        site_escape_charts[chart_type][antibody_set] = alt.vconcat(
            *panels,
            spacing=0,
        ).add_params(
            site_escape_selection,
            site_selection,
            *other_prop_sliders.values(),
            floor_escape_at_zero,
        )
Heatmaps¶
First, create a data frame that has the average escape across sera (averaging mutations missing for a serum as zero for that serum) and other properties:
# Accumulate one wide frame holding the mean escape for every antibody set
# plus every other property; `heatmap_data_cols` records the value columns.
heatmap_data = None
heatmap_data_cols = []

for antibody_set, escape_df in escape.items():
    set_escape_col = f"{antibody_set} escape"
    antibody_list = list(
        config["antibody_escape"][antibody_set]["antibody_list"].values()
    )
    # add wildtype with zero escape (the zero comes from the fillna below)
    wildtype_rows = (
        escape_df[["site", "wildtype"]]
        .drop_duplicates()
        .assign(mutant=lambda x: x["wildtype"])
    )
    df = (
        pd.concat([escape_df, wildtype_rows])
        # values missing for an antibody count as zero in the mean
        .fillna(0)
        .assign(
            escape=lambda x: x[antibody_list].mean(axis=1),
            site_mutant=lambda x: x["site"].astype(str) + x["mutant"],
        )
        .merge(site_numbering_map, on="site", validate="many_to_one")
        .drop(columns=antibody_list)
        .rename(columns={"escape": set_escape_col})
    )
    heatmap_data_cols.append(set_escape_col)
    heatmap_data = (
        df
        if heatmap_data is None
        else heatmap_data.merge(df, how="outer", validate="one_to_one")
    )

for prop, prop_df in other_props.items():
    heatmap_data_cols.append(prop)
    heatmap_data = (
        prop_df
        if heatmap_data is None
        else heatmap_data.merge(prop_df, validate="one_to_one", how="outer")
    )

# neither antibody sets nor other properties were configured
if heatmap_data is None:
    raise ValueError("no data specified for summary")
# Replace the site-map columns carried through the merges with a fresh merge
# of the site numbering map, then discard the `site_mutant` helper column.
_map_only_cols = [c for c in site_numbering_map.columns if c != "site"]
heatmap_data = heatmap_data.drop(columns=_map_only_cols)
heatmap_data = heatmap_data.merge(site_numbering_map, validate="many_to_one")
heatmap_data = heatmap_data.drop(columns="site_mutant")

# Set escape to zero on every wildtype row (rows added only by the outer
# merges would otherwise be missing there).
_not_wildtype = heatmap_data["wildtype"] != heatmap_data["mutant"]
for antibody_set in escape:
    col = f"{antibody_set} escape"
    heatmap_data[col] = heatmap_data[col].where(_not_wildtype, 0)

# Each (site, mutant) pair may appear only once.
dup_rows = heatmap_data.assign(
    _n=lambda x: x.groupby(["site", "mutant"])["wildtype"].transform("count")
).query("_n > 1")
if len(dup_rows):
    raise ValueError(f"Duplicate rows:\n{dup_rows}")

print(f"Writing summary data to {output_csv_file}")
heatmap_data.to_csv(output_csv_file, index=False, float_format="%.4g")

heatmap_data
Writing summary data to results/escape_by_prior_infections/summary.csv
site | wildtype | mutant | one infection escape | multiple infections escape | spike mediated entry | ACE2 binding | sequential_site | region | |
---|---|---|---|---|---|---|---|---|---|
0 | 2 | F | C | -0.021218 | 0.059678 | 0.10100 | 0.02151 | 2 | other |
1 | 2 | F | L | 0.067207 | -0.053920 | 0.09432 | -0.26980 | 2 | other |
2 | 2 | F | S | 0.047673 | 0.007725 | 0.05844 | -0.05642 | 2 | other |
3 | 2 | F | F | 0.000000 | 0.000000 | 0.00000 | 0.00000 | 2 | other |
4 | 3 | V | A | -0.055882 | 0.143862 | -0.04154 | -0.04977 | 3 | other |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
8345 | 1211 | K | K | 0.000000 | 0.000000 | 0.00000 | 0.00000 | 1207 | S2 |
8346 | 1212 | W | R | NaN | NaN | -2.37700 | -0.21050 | 1208 | other |
8347 | 1212 | W | W | 0.000000 | 0.000000 | 0.00000 | 0.00000 | 1208 | other |
8348 | 767 | L | L | 0.000000 | 0.000000 | 0.00000 | 0.00000 | 763 | S2 |
8349 | 767 | L | V | NaN | NaN | NaN | 1.69700 | 763 | S2 |
8350 rows × 9 columns
Make heatmaps:
# the heatmap with this name is the one that keeps its x-axis ticks and labels
last_heatmap = list(plot_params)[-1]

# calculated fields: first a floored copy of each set's escape column ...
_floored_escape_calcs = {
    f"{antibody_set} escape_floored": alt.expr.if_(
        floor_escape_at_zero,
        alt.expr.max(alt.datum[f"{antibody_set} escape"], 0),
        alt.datum[f"{antibody_set} escape"],
    )
    for antibody_set in escape
}
# ... then convert null values to NaN so they show as NaN in tooltips rather
# than as 0.0
_null_to_nan_calcs = {
    col: alt.expr.if_(
        alt.expr.isFinite(alt.datum[col]),
        alt.datum[col],
        alt.expr.NaN,
    )
    for col in heatmap_data_cols
}

_heatmap_axis_kwargs = {"labelFontSize": 9, "ticks": False}
heatmap_base = (
    alt.Chart(heatmap_data)
    .transform_calculate(**_floored_escape_calcs, **_null_to_nan_calcs)
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(**_heatmap_axis_kwargs),
        ),
        y=alt.Y(
            "mutant:N",
            title="amino acid",
            sort=alphabet,
            axis=alt.Axis(**_heatmap_axis_kwargs),
        ),
        # thicker cell outline for the site under the mouse
        strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(1)),
    )
    .properties(width=alt.Step(cell_size), height=alt.Step(cell_size))
    .add_params(*other_prop_sliders.values(), floor_escape_at_zero)
)

# mark X for wildtype
_is_wildtype_expr = alt.datum["wildtype"] == alt.datum["mutant"]
heatmap_wildtype = heatmap_base.transform_filter(_is_wildtype_expr).mark_text(
    text="x", color="black"
)

# gray background for missing values
heatmap_bg = heatmap_base.transform_impute(
    impute="_stat_dummy",
    key="mutant",
    keyvals=alphabet,
    groupby=["site"],
    value=None,
).mark_rect(color="#E0E0E0", opacity=0.8)

# tooltip fields shared by all heatmaps
_value_tooltips = [alt.Tooltip(c, format=".2f") for c in heatmap_data_cols]
_site_map_tooltips = [c for c in site_numbering_map.columns if c != "site"]
tooltips = ["site", "mutant"] + _value_tooltips + ["wildtype"] + _site_map_tooltips

# legend settings shared by all color scales
legend = alt.Legend(
    orient="left",
    offset=10,
    titleOrient="left",
    gradientLength=100,
    gradientThickness=10,
    gradientStrokeColor="black",
    gradientStrokeWidth=0.5,
)
# heatmaps for escape
escape_heatmaps = {}

# Cells shown in the escape heatmaps: a property is at or above its slider,
# or the cell is wildtype, or every property is NaN.
# NOTE(review): `functools.reduce` is called without an initial value here, so
# it raises TypeError when `other_prop_sliders` is empty - confirm at least one
# other property is always configured.
escape_heatmap_base = heatmap_base.transform_filter(
    functools.reduce(
        operator.or_,
        [(alt.datum[prop] >= slider) for prop, slider in other_prop_sliders.items()],
    )
    | (alt.datum["wildtype"] == alt.datum["mutant"])
    | functools.reduce(
        operator.and_,
        [alt.expr.isNaN(alt.datum[prop]) for prop in other_prop_sliders],
    )
)

for antibody_set in escape:
    set_params = plot_params[antibody_set]
    escape_values = heatmap_data[f"{antibody_set} escape"]
    # Use the NaN-skipping Series reductions: the builtin min()/max() applied
    # to a Series that contains NaN give a result that depends on row order
    # (a leading NaN makes them return NaN, which hides the data limit).
    data_lims = {"min": float(escape_values.min()), "max": float(escape_values.max())}
    # a configured fixed_<lim> wins; otherwise extend <lim>_at_least to the data
    domain_lims = {
        lim: (
            set_params[f"fixed_{lim}"]
            if f"fixed_{lim}" in set_params
            else lim_func(set_params[f"{lim}_at_least"], data_lims[lim])
        )
        for (lim, lim_func) in [("min", min), ("max", max)]
    }
    escape_heatmaps[antibody_set] = escape_heatmap_base.encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            title=None,
            axis=(
                alt.Axis()
                if last_heatmap == antibody_set
                else alt.Axis(ticks=False, labels=False)
            ),
        ),
        color=alt.Color(
            f"{antibody_set} escape_floored:Q",
            # long set names are split over two title lines
            title=f"{antibody_set} escape"
            if len(antibody_set) < 14
            else [antibody_set, "escape"],
            legend=legend,
            scale=alt.Scale(
                zero=True,
                nice=False,
                type="linear",
                domainMid=0,
                domainMax=domain_lims["max"],
                # lower limit is zero while escape is floored at zero
                domainMin=alt.ExprRef(
                    f"if(floor_escape_at_zero, 0, {domain_lims['min']})"
                ),
                # negative color -> white -> positive color
                range=(
                    color_gradient_hex(
                        set_params["negative_color"], "white", n=20
                    )
                    + color_gradient_hex(
                        "white", set_params["positive_color"], n=20
                    )[1:]
                ),
            ),
        ),
        tooltip=tooltips,
    ).mark_rect(stroke="black")
# heatmap for other property (eg, functional effect) filtered escape:
# non-wildtype cells where at least one property is below its slider are
# drawn in silver
_below_slider_exprs = [
    alt.datum[prop] < slider for prop, slider in other_prop_sliders.items()
]
_any_below_slider = functools.reduce(operator.or_, _below_slider_exprs)
_mutant_differs = alt.datum["wildtype"] != alt.datum["mutant"]

_filtered_color = alt.Color(
    "filtered:N",
    title="deleterious",
    scale=alt.Scale(range=["silver"]),
    legend=None,
)
escape_filtered_heatmap = (
    heatmap_base.transform_filter(_any_below_slider & _mutant_differs)
    # constant field so that all of these cells share the single color
    .transform_calculate(filtered="''")
    .mark_rect(stroke="black")
    .encode(tooltip=tooltips, color=_filtered_color)
)
# heatmaps for other properties
other_prop_heatmaps = {}
other_prop_filtered_heatmaps = {}
for prop in other_props:
    params = plot_params[prop]
    prop_values = heatmap_data[prop]
    # Use the NaN-skipping Series reductions: the builtin min()/max() applied
    # to a Series that contains NaN give a result that depends on row order
    # (a leading NaN makes them return NaN, which hides the data limit).
    data_lims = {"min": float(prop_values.min()), "max": float(prop_values.max())}
    # a configured fixed_<lim> wins; otherwise extend <lim>_at_least to the data
    domain_lims = {
        lim: (
            params[f"fixed_{lim}"]
            if f"fixed_{lim}" in params
            else lim_func(params[f"{lim}_at_least"], data_lims[lim])
        )
        for (lim, lim_func) in [("min", min), ("max", max)]
    }
    slider_name = prop.replace(" ", "_") + "_slider"  # name given when defining slider

    # Cells colored by this property: wildtype cells, plus cells where every
    # *other* property is at or above its slider or NaN. The initial value
    # True covers the case of no other properties.
    other_prop_heatmaps[prop] = (
        heatmap_base.transform_filter(
            functools.reduce(
                operator.and_,
                [
                    (
                        (alt.datum[other_prop] >= slider)
                        | alt.expr.isNaN(alt.datum[other_prop])
                    )
                    for other_prop, slider in other_prop_sliders.items()
                    if other_prop != prop
                ],
                True,
            )
            | (alt.datum["wildtype"] == alt.datum["mutant"])
        )
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=(
                    alt.Axis()
                    if last_heatmap == prop
                    else alt.Axis(ticks=False, labels=False)
                ),
            ),
            color=alt.Color(
                prop,
                legend=legend,
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    clamp=True,
                    domainMid=0,
                    domainMax=domain_lims["max"],
                    # lower limit follows this property's own slider
                    domainMin=alt.ExprRef(f"max({slider_name}, {domain_lims['min']})"),
                    # negative color -> white -> positive color
                    range=(
                        color_gradient_hex(params["negative_color"], "white", n=20)
                        + color_gradient_hex("white", params["positive_color"], n=20)[
                            1:
                        ]
                    ),
                ),
            ),
            tooltip=tooltips,
        )
        .mark_rect(stroke="black")
    )

    # Non-wildtype cells where some *other* property is below its slider are
    # drawn in silver. The initial value False covers the case of no other
    # properties.
    other_prop_filtered_heatmaps[prop] = (
        heatmap_base.transform_filter(
            functools.reduce(
                operator.or_,
                [
                    alt.datum[other_prop] < slider
                    for other_prop, slider in other_prop_sliders.items()
                    if other_prop != prop
                ],
                False,
            )
            & (alt.datum["wildtype"] != alt.datum["mutant"])
        )
        .transform_calculate(filtered="''")
        .encode(
            tooltip=tooltips,
            color=alt.Color(
                "filtered:N",
                title="deleterious",
                scale=alt.Scale(range=["silver"]),
                legend=None,
            ),
        )
        .mark_rect(stroke="black")
    )
# Each heatmap row layers, bottom to top: gray background, colored values,
# silver "filtered" cells, and the wildtype "x" marks.
_escape_rows = [
    heatmap_bg + escape_heatmap + escape_filtered_heatmap + heatmap_wildtype
    for escape_heatmap in escape_heatmaps.values()
]
_other_prop_rows = [
    heatmap_bg
    + other_prop_heatmaps[prop]
    + other_prop_filtered_heatmaps[prop]
    + heatmap_wildtype
    for prop in other_props
]

# Stack the escape rows above the other-property rows; every row gets its
# own color scale.
heatmap = alt.vconcat(*_escape_rows, *_other_prop_rows, spacing=2)
heatmap = heatmap.resolve_scale(color="independent")
heatmap = heatmap.add_params(site_selection)
Merged site escape lineplot and heatmaps¶
Create and save merged escape lineplot and heatmap:
# Build, save and display the merged chart for each layout: region zoom bar
# on top, then the site escape line plots, then the brush-filtered heatmaps.
_chart_outputs = {"overlaid": chart_overlaid, "faceted": chart_faceted}
for chart_type, chartfile in _chart_outputs.items():
    _merged_panels = [
        region_bar.add_params(site_brush),
        *site_escape_charts[chart_type].values(),
        heatmap.transform_filter(site_brush),
    ]
    merged_chart = alt.vconcat(*_merged_panels, spacing=2)
    merged_chart = merged_chart.configure_legend(orient="left", padding=0)

    print(f"Saving {chart_type} chart to {chartfile}")
    merged_chart.save(chartfile)
    display(merged_chart)
Saving overlaid chart to results/escape_by_prior_infections/summary_overlaid_nolegend.html
Saving faceted chart to results/escape_by_prior_infections/summary_faceted_nolegend.html