Summarize results across assays¶

This notebook summarizes the results across assays (antibody escape, functional effects, and other assays such as ACE2 binding).

In [1]:
import functools
import operator
import re

import altair as alt

import pandas as pd

import polyclonal.alphabets
from polyclonal.plot import color_gradient_hex

# Altair limits by default how many rows it will embed in a chart; lift that limit
# because the data embedded below are large (the summary table has thousands of rows).
# Assigning to `_` keeps the notebook from displaying the return value.
_ = alt.data_transformers.disable_max_rows()

Get configuration parameters¶

The next cell is tagged as parameters for papermill parameterization:

In [2]:
# Default (placeholder) values for the papermill parameters; all are None here and
# are expected to be overridden by the injected "# Parameters" cell that follows.
site_numbering_map_csv = None
chart_overlaid = None
chart_faceted = None
output_csv_file = None
per_antibody_escape_csv = None
config = None
input_csvs = None
In [3]:
# Parameters
# Injected by papermill; overrides the placeholder defaults in the previous cell.
config = {
    # mutations are kept only if times_seen >= this value
    "min_times_seen": 3,
    # mutations are kept only if frac_models >= this value
    "min_frac_models": 1,
    # characters allowed as wildtype / mutant identities
    "alphabet": [
        "A",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "K",
        "L",
        "M",
        "N",
        "P",
        "Q",
        "R",
        "S",
        "T",
        "V",
        "W",
        "Y",
        "-",
    ],
    # initial state of the "floor escape at zero" chart toggle
    "init_floor_escape_at_zero": True,
    # initial site escape statistic; must be one of mean / sum / max / min
    "init_site_escape_stat": "sum",
    # functional-effect properties keyed by display name: "condition" selects the
    # input CSV, "le_filters" maps columns to maximum allowed values, and
    # "init_min_value" sets the initial slider position
    "func_effects": {
        "spike mediated entry": {
            "condition": "293T_high_ACE2_entry",
            "effect_type": "func_effects",
            "positive_color": "#009E73",
            "negative_color": "#F0E442",
            "max_at_least": 1,
            "min_at_least": 0,
            "init_min_value": -2.0,
            "le_filters": {"effect_std": 1.6},
        }
    },
    # other assays: assay -> {display name -> params}; "stat" is the column used as
    # the property value, and "fixed_max" / "fixed_min" pin the heatmap color domain
    "other_assays": {
        "ACE2_binding": {
            "ACE2 binding": {
                "condition": "monomeric_ACE2",
                "stat": "ACE2 binding_median",
                "positive_color": "#0072B2",
                "negative_color": "#D55E00",
                "max_at_least": 1,
                "min_at_least": 0,
                "fixed_max": 2,
                "fixed_min": -3,
                "init_min_value": -2,
                "le_filters": {"ACE2 binding_std": 3.0},
            }
        }
    },
    # antibody sets: "antibody_list" maps antibody IDs (used in input_csvs keys)
    # to display names; "stat" is the column used as escape
    "antibody_escape": {
        "one infection": {
            "stat": "escape_median",
            "positive_color": "#56B4E9",
            "negative_color": "#E69F00",
            "max_at_least": 1,
            "min_at_least": -1,
            "le_filters": {"escape_std": 1.5},
            "antibody_list": {
                "sera_493C_mediumACE2": "serum 493C",
                "sera_498C_mediumACE2": "serum 498C",
                "sera_500C_mediumACE2": "serum 500C",
                "sera_501C_mediumACE2": "serum 501C",
                "sera_503C_mediumACE2": "serum 503C",
                "sera_505C_mediumACE2": "serum 505C",
            },
        },
        "multiple infections": {
            "stat": "escape_median",
            "positive_color": "#56B4E9",
            "negative_color": "#E69F00",
            "max_at_least": 1,
            "min_at_least": -1,
            "le_filters": {"escape_std": 1.5},
            "antibody_list": {
                "sera_287C_mediumACE2": "serum 287C",
                "sera_288C_mediumACE2": "serum 288C",
                "sera_343C_mediumACE2": "serum 343C",
                "sera_497C_mediumACE2": "serum 497C",
            },
        },
    },
}
# Input CSVs keyed as "<assay> <antibody ID or condition>", the form looked up when
# reading data below.
# NOTE(review): the "site_numbering_map_csv" and "nb" entries are not read through
# input_csvs in this notebook.
input_csvs = {
    "antibody_escape sera_493C_mediumACE2": "results/antibody_escape/averages/sera_493C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_498C_mediumACE2": "results/antibody_escape/averages/sera_498C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_500C_mediumACE2": "results/antibody_escape/averages/sera_500C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_501C_mediumACE2": "results/antibody_escape/averages/sera_501C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_503C_mediumACE2": "results/antibody_escape/averages/sera_503C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_287C_mediumACE2": "results/antibody_escape/averages/sera_287C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_288C_mediumACE2": "results/antibody_escape/averages/sera_288C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_343C_mediumACE2": "results/antibody_escape/averages/sera_343C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_497C_mediumACE2": "results/antibody_escape/averages/sera_497C_mediumACE2_mut_effect.csv",
    "antibody_escape sera_505C_mediumACE2": "results/antibody_escape/averages/sera_505C_mediumACE2_mut_effect.csv",
    "func_effects 293T_high_ACE2_entry": "results/func_effects/averages/293T_high_ACE2_entry_func_effects.csv",
    "ACE2_binding monomeric_ACE2": "results/ACE2_binding/averages/monomeric_ACE2_mut_effect.csv",
    "site_numbering_map_csv": "data/site_numbering_map.csv",
    "nb": "dms-vep-pipeline-3/notebooks/summary.ipynb",
}
# Site numbering map input, then the output charts (HTML) and CSVs.
site_numbering_map_csv = "data/site_numbering_map.csv"
chart_faceted = "results/escape_by_prior_infections/summary_faceted_nolegend.html"
chart_overlaid = "results/escape_by_prior_infections/summary_overlaid_nolegend.html"
output_csv_file = "results/escape_by_prior_infections/summary.csv"
per_antibody_escape_csv = "results/escape_by_prior_infections/per_antibody_escape.csv"

Get the `min_times_seen` and `min_frac_models` filters, and the amino-acid alphabet (in biochemical order):

In [4]:
# Mutation-level quality filters plus the amino-acid alphabet, reordered
# biochemically, that are applied throughout the notebook.
min_times_seen, min_frac_models = config["min_times_seen"], config["min_frac_models"]
alphabet = polyclonal.alphabets.biochem_order_aas(config["alphabet"])

print(f"Using {min_times_seen=} and {min_frac_models=}")
Using min_times_seen=3 and min_frac_models=1

Read the data¶

Read the site numbering map:

In [5]:
# Site numbering map: "reference_site" becomes "site"; keep every *site column
# (e.g. site, sequential_site) plus the region annotation.
site_numbering_map = (
    pd.read_csv(site_numbering_map_csv)
    .rename(columns={"reference_site": "site"})
    .pipe(
        lambda frame: frame[
            [col for col in frame.columns if col.endswith("site")] + ["region"]
        ]
    )
)

Read the escape data:

In [6]:
# Build one wide data frame per antibody set: rows are mutations, columns are the
# per-antibody escape values (keyed by display name), plus a site_mutant key.
escape = {}
for antibody_set, set_params in config["antibody_escape"].items():
    antibody_names = set_params["antibody_list"]
    # display names must be unique because they become the pivoted columns
    assert len(antibody_names) == len(set(antibody_names.values()))

    # one frame per antibody, labeled by display name, with the configured
    # statistic renamed to "escape"
    per_antibody = [
        pd.read_csv(input_csvs[f"antibody_escape {antibody}"])
        .assign(antibody=display_name)
        .rename(columns={set_params["stat"]: "escape"})
        for antibody, display_name in antibody_names.items()
    ]

    # optional "<= max value" filters appended to the alphabet query
    le_filters = (
        " and "
        + " and ".join(
            f"(`{col}` <= {thresh})" for (col, thresh) in set_params["le_filters"].items()
        )
        if "le_filters" in set_params
        else ""
    )

    combined = (
        pd.concat(per_antibody)
        .query("frac_models >= @min_frac_models")
        .query("times_seen >= @min_times_seen")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)" + le_filters)
    )
    wide = (
        combined.pivot_table(
            index=["epitope", "site", "wildtype", "mutant"],
            columns="antibody",
            values="escape",
        )
        .reset_index()
        .assign(site_mutant=lambda x: x["site"].astype(str) + x["mutant"])
    )
    assert wide["epitope"].nunique() == 1, "can only have 1 epitope"
    escape[antibody_set] = wide.drop(columns="epitope")

Write the per-antibody escape values to file:

In [7]:
print(f"Writing per-antibody escape to {per_antibody_escape_csv=}")

# Columns of the long-format per-antibody table; melt yields the id columns, then
# "antibody" and "escape", and antibody_set is appended last.
per_antibody_escape_cols = [
    "site",
    "wildtype",
    "mutant",
    "antibody",
    "escape",
    "antibody_set",
]

if escape:
    per_antibody_escape = pd.concat(
        [
            antibody_df.drop(columns="site_mutant")
            .melt(
                id_vars=["site", "wildtype", "mutant"],
                var_name="antibody",
                value_name="escape",
            )
            .assign(antibody_set=antibody_set)
            for antibody_set, antibody_df in escape.items()
        ],
        ignore_index=True,
    ).query("escape.notnull()")
    assert list(per_antibody_escape.columns) == per_antibody_escape_cols
else:
    # no antibody sets: still write a header with the same columns (including
    # "escape") so downstream readers see a consistent schema
    per_antibody_escape = pd.DataFrame(columns=per_antibody_escape_cols)

per_antibody_escape.to_csv(per_antibody_escape_csv, index=False, float_format="%.5g")
Writing per-antibody escape to per_antibody_escape_csv='results/escape_by_prior_infections/per_antibody_escape.csv'

Read other properties (functional effects and measurements from other assays):

In [8]:
def _le_filter_query(assay_params):
    """Build the query-string suffix for an assay's optional ``le_filters``.

    ``assay_params["le_filters"]``, if present, maps column names to maximum
    allowed values. Returns ``" and (`col` <= val) and ..."`` for appending to
    another query expression, or ``""`` when there are no such filters. An
    empty mapping is treated like a missing one instead of producing a dangling
    ``" and "``.
    """
    le_filters_d = assay_params.get("le_filters") or {}
    if not le_filters_d:
        return ""
    return " and " + " and ".join(
        f"(`{key}` <= {val})" for (key, val) in le_filters_d.items()
    )


other_props = {}

# functional effects: the "effect" column is renamed to the property name
for name, name_d in config["func_effects"].items():
    csv_file = input_csvs[f"func_effects {name_d['condition']}"]
    le_filters = _le_filter_query(name_d)
    other_props[name] = (
        pd.read_csv(csv_file)
        .rename(columns={"effect": name})
        # frac_models: n_selections relative to the largest n_selections in the file
        .assign(frac_models=lambda x: x["n_selections"] / x["n_selections"].max())
        .query("(times_seen >= @min_times_seen)" + le_filters)
        .query("frac_models >= @min_frac_models")[["site", "wildtype", "mutant", name]]
    )

# other assays: the configured "stat" column is renamed to the property name
for assay, assay_d in config["other_assays"].items():
    for name, name_d in assay_d.items():
        assert name not in other_props, f"{name} multiply defined"
        csv_file = input_csvs[f"{assay} {name_d['condition']}"]
        le_filters = _le_filter_query(name_d)
        other_props[name] = (
            pd.read_csv(csv_file)
            .rename(columns={name_d["stat"]: name})
            .query("(times_seen >= @min_times_seen)" + le_filters)
            .query("frac_models >= @min_frac_models")[
                ["site", "wildtype", "mutant", name]
            ]
        )

assert not set(other_props).intersection(escape), "multiply defined names"

# add wildtype effects of zero; each site must have a single wildtype identity
# across all escape and property data
site_wts = pd.concat([*escape.values(), *other_props.values()])[
    ["site", "wildtype"]
].drop_duplicates()
assert len(site_wts) == site_wts["site"].nunique(), "inconsistent wildtype at a site"
for prop in other_props:
    # property names are later embedded in chart parameter names and expressions,
    # so restrict them to word characters and spaces (raw string: `\w` is a regex
    # escape, not a string escape)
    if not re.fullmatch(r"[\w ]+", prop):
        raise ValueError(f"non-alphanumeric name for property: {prop}")
    other_props[prop] = (
        pd.concat(
            [
                other_props[prop],
                site_wts.assign(
                    mutant=lambda x: x["wildtype"],
                    **{prop: 0},
                ),
            ],
            ignore_index=True,
        )
        .assign(site_mutant=lambda x: x["site"].astype(str) + x["mutant"])
        .merge(site_numbering_map, on="site", validate="many_to_one")
        .query("(mutant in @alphabet) and (wildtype in @alphabet)")
    )
    # each site + mutant combination must appear exactly once
    assert other_props[prop]["site_mutant"].nunique() == len(other_props[prop])

Get from the config the plot parameters for each plot (essentially, this "flattens" some aspects of config to make these easier to access below):

In [9]:
# Flatten the antibody-escape, functional-effect and other-assay config sections
# into a single mapping of plot name -> plot parameters (names must be unique).
plot_params = {}
for section in (
    config["antibody_escape"],
    config["func_effects"],
    *config["other_assays"].values(),
):
    for name, section_params in section.items():
        assert name not in plot_params
        plot_params[name] = section_params

Set up selections for interactive charts¶

In [10]:
# Layout constants (in pixels) for the site escape line plots and heatmap cells.
site_escape_width = 800  # width of site escape chart
site_escape_overlaid_height = 130  # height of overlaid site escape plots
site_escape_faceted_height = 80  # height of faceted site escape plots
cell_size = 9  # heatmap cell size

# Radio-button parameter; when true, escape values are floored at zero in the
# line plots and heatmaps (it is referenced inside expressions further below).
floor_escape_at_zero = alt.param(
    value=config["init_floor_escape_at_zero"],
    name="floor_escape_at_zero",
    bind=alt.binding_radio(options=[True, False], name="floor escape at zero"),
)

# Drop-down selecting which per-site aggregate of mutation escape is plotted.
site_stats = ["mean", "sum", "max", "min"]
assert config["init_site_escape_stat"] in site_stats
site_escape_selection = alt.selection_point(
    fields=["site escape statistic"],
    bind=alt.binding_select(
        options=site_stats,
        name="site escape statistic",
    ),
    value=config["init_site_escape_stat"],
)

# Site under the mouse; empty=False so no site is selected when not hovering.
# NOTE(review): this and the other selections are unnamed, so Altair generates
# their names; keep their creation order stable.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

# Interval brush over sites (x encoding); empty=True means every site passes the
# brush filter when nothing is brushed.
site_brush = alt.selection_interval(
    encodings=["x"],
    mark=alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0),
    empty=True,
)

# One slider per other property. Mutations with a value below the slider are
# dropped from the site escape plots and grayed in the heatmaps. Each slider spans
# from the property's minimum up to 0 and starts at the larger of that minimum and
# the configured init_min_value.
other_prop_sliders = {}
for prop, prop_df in other_props.items():
    other_prop_sliders[prop] = alt.param(
        value=max(prop_df[prop].min(), plot_params[prop]["init_min_value"]),
        name=prop.replace(" ", "_") + "_slider",
        bind=alt.binding_range(
            name=f"minimum mutation {prop}",
            min=prop_df[prop].min(),
            max=0,
        ),
    )

# region zoom bar: one rect per site, colored by region and ordered by
# sequential_site; the site brush is attached to it when the final charts are built
region_bar = (
    alt.Chart(site_numbering_map)
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=None,
        ),
        color=alt.Color(
            "region",
            scale=alt.Scale(domain=site_numbering_map["region"].unique()),
        ),
        tooltip=site_numbering_map.columns.tolist(),
    )
    .mark_rect()
    .properties(
        width=site_escape_width,
        height=11,
        title=alt.TitleParams(
            "site zoom bar",
            fontSize=11,
            fontWeight="bold",
            orient="top",
        ),
    )
)

Antibody escape site summary plots¶

Now make site summary escape plots for each antibody set, in two layouts: one with the individual sera faceted into separate panels, and one with them overlaid in a single panel.

In [11]:
# For each antibody set, build a site escape chart in two layouts ("faceted" and
# "overlaid"). Each chart is a vertical stack of the mean-escape line plot and,
# when the set has more than one antibody, the per-antibody plot. The zoom bar is
# added once when the final charts are merged, not here.
site_escape_charts = {"faceted": {}, "overlaid": {}}

for antibody_set, escape_df in escape.items():
    antibodies = list(config["antibody_escape"][antibody_set]["antibody_list"].values())

    # antibody names become fold fields and facet labels, so restrict characters
    # (raw string so `\w` and `\-` are regex escapes, not invalid string escapes)
    if any(not re.fullmatch(r"[\w \-]+", a) for a in antibodies):
        raise ValueError(f"antibody names not all alphanumeric:\n{antibodies}")

    site_escape_base = (
        # merge sequential_site rather than transform_lookup so sorting works
        alt.Chart(escape_df.merge(site_numbering_map[["site", "sequential_site"]]))
        .encode(
            y=alt.Y(
                "escape:Q",
                scale=alt.Scale(nice=False, padding=4),
                axis=alt.Axis(grid=False),
            ),
            tooltip=[
                "site",
                alt.Tooltip("escape:Q", format=".2f"),
                "antibody:N",
                *[f"{c}:N" for c in site_numbering_map.columns if c != "site"],
            ],
        )
        .transform_filter(site_brush)
    )

    site_escape_lines = site_escape_base.mark_line(size=0.75)

    # points are enlarged and outlined in red for the hovered site
    site_escape_points = site_escape_base.encode(
        strokeWidth=alt.condition(site_selection, alt.value(1.5), alt.value(0)),
        size=alt.condition(site_selection, alt.value(45), alt.value(15)),
    ).mark_circle(filled=True, stroke="red")

    # fold the per-antibody columns into long format (antibody, escape_orig)
    site_escape_lines_and_points = (
        (site_escape_lines + site_escape_points).transform_fold(
            fold=antibodies, as_=["antibody", "escape_orig"]
        )
        # floor escape at zero if selected
        .transform_calculate(
            escape=alt.expr.if_(
                floor_escape_at_zero,
                alt.expr.max(alt.datum["escape_orig"], 0),
                alt.datum["escape_orig"],
            )
        )
    )

    # filter on other properties: drop mutations below each property's slider
    for prop, prop_df in other_props.items():
        # https://github.com/altair-viz/altair/issues/2600
        slider = other_prop_sliders[prop]
        site_escape_lines_and_points = site_escape_lines_and_points.transform_lookup(
            lookup="site_mutant",
            from_=alt.LookupData(
                prop_df,
                key="site_mutant",
                fields=[prop],
            ),
        ).transform_filter(alt.datum[prop] >= slider)

    # compute site statistics
    site_escape_lines_and_points = (
        site_escape_lines_and_points
        # compute site statistics from mutation statistics
        .transform_aggregate(
            **{stat: f"{stat}(escape)" for stat in site_stats},
            groupby=["site", "sequential_site", "antibody", "wildtype"],
        )
        # filter on site statistic of interest
        .transform_fold(
            fold=site_stats, as_=["site escape statistic", "escape"]
        ).transform_filter(site_escape_selection)
        # get sequential sites and regions
        .transform_lookup(
            lookup="site",
            from_=alt.LookupData(
                site_numbering_map,
                key="site",
                fields=[
                    c
                    for c in site_numbering_map.columns
                    if c not in {"site", "sequential_site"}
                ],
            ),
        )
    )

    # long antibody-set names are split over two title lines
    if len(antibody_set) < 14:
        individual_title = f"individual {antibody_set}"
        mean_title = f"mean {antibody_set}"
    else:
        individual_title = ["individual", antibody_set]
        mean_title = ["mean", antibody_set]
    if len(antibodies) == 1:
        mean_title = antibody_set

    # one facet row per antibody
    site_escape_faceted = (
        site_escape_lines_and_points.encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
            ),
            color=alt.value("gray"),
        )
        .properties(height=site_escape_faceted_height, width=site_escape_width)
        .facet(
            facet=alt.Facet(
                "antibody:N",
                title=individual_title,
                header=alt.Header(
                    labelOrient="right",
                    labelFontSize=10,
                    labelPadding=3,
                    titleOrient="right",
                    titlePadding=3,
                ),
                sort=antibodies,
            ),
            columns=1,
            spacing=0,
        )
        .resolve_scale(y="independent")
    )

    # all antibodies in one panel, one line per antibody
    site_escape_overlaid = site_escape_lines_and_points.encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelOverlap=True, grid=False, ticks=False),
        ),
        opacity=alt.value(0.4),
        color=alt.value("gray" if len(antibodies) > 1 else "black"),
        detail="antibody:N",
    ).properties(
        height=site_escape_overlaid_height,
        width=site_escape_width,
        title=alt.TitleParams(
            individual_title,
            fontSize=11,
            fontWeight="bold",
            orient="right",
        ),
    )

    site_mean_escape = (
        site_escape_lines_and_points
        # average missing values as zero
        .transform_calculate(
            escape=alt.expr.if_(
                alt.expr.isValid(alt.datum["escape"]),
                alt.datum["escape"],
                0,
            ),
        )
        # take mean over sera
        .transform_aggregate(
            escape="mean(escape)",
            groupby=["wildtype", *site_numbering_map.columns],
        )
        .transform_calculate(
            antibody="'mean escape'" if len(antibodies) > 1 else f"'{antibodies[0]}'"
        )
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                axis=None,
            ),
            color=alt.value("black"),
        )
        .properties(
            title=alt.TitleParams(
                mean_title,
                fontSize=11,
                fontWeight="bold",
                orient="right",
            ),
        )
    )

    # Assemble the chart for each layout. (A previous version first built a
    # variant including the region bar and then immediately overwrote it; that
    # dead assignment has been removed; the region bar is added when merging.)
    for chart_type, height, site_chart in [
        ("faceted", site_escape_faceted_height, site_escape_faceted),
        ("overlaid", site_escape_overlaid_height, site_escape_overlaid),
    ]:
        site_mean_chart = site_mean_escape.properties(
            height=height,
            width=site_escape_width,
        )
        # with a single antibody the mean chart already shows that antibody
        panels = [site_mean_chart, site_chart] if len(antibodies) > 1 else [site_mean_chart]
        site_escape_charts[chart_type][antibody_set] = alt.vconcat(
            *panels,
            spacing=0,
        ).add_params(
            site_escape_selection,
            site_selection,
            *other_prop_sliders.values(),
            floor_escape_at_zero,
        )

Heatmaps¶

First, create a data frame that has the average escape across sera (averaging mutations missing for a serum as zero for that serum) and other properties:

In [12]:
# Assemble one frame per antibody set (mean escape over its antibodies, counting
# an antibody's missing mutations as zero) and one per other property, then
# outer-merge them all on their shared mutation/site columns.
summary_frames = []
heatmap_data_cols = []
for antibody_set, escape_df in escape.items():
    members = list(config["antibody_escape"][antibody_set]["antibody_list"].values())
    set_col = f"{antibody_set} escape"
    # wildtype rows (mutant == wildtype); fillna(0) below gives them zero escape
    wildtype_rows = (
        escape_df[["site", "wildtype"]]
        .drop_duplicates()
        .assign(mutant=lambda x: x["wildtype"])
    )
    summary_frames.append(
        pd.concat([escape_df, wildtype_rows])
        .fillna(0)
        .assign(
            escape=lambda x: x[members].mean(axis=1),
            site_mutant=lambda x: x["site"].astype(str) + x["mutant"],
        )
        .merge(site_numbering_map, on="site", validate="many_to_one")
        .drop(columns=members)
        .rename(columns={"escape": set_col})
    )
    heatmap_data_cols.append(set_col)

heatmap_data_cols.extend(other_props)
summary_frames.extend(other_props.values())

if not summary_frames:
    raise ValueError("no data specified for summary")

heatmap_data = functools.reduce(
    lambda left, right: left.merge(right, how="outer", validate="one_to_one"),
    summary_frames,
)

# re-attach the site annotation columns from the numbering map
heatmap_data = (
    heatmap_data.drop(columns=[c for c in site_numbering_map.columns if c != "site"])
    .merge(site_numbering_map, validate="many_to_one")
    .drop(columns="site_mutant")
)

# wildtype entries always have zero escape
for antibody_set in escape:
    set_col = f"{antibody_set} escape"
    is_wildtype = heatmap_data["wildtype"] == heatmap_data["mutant"]
    heatmap_data[set_col] = heatmap_data[set_col].mask(is_wildtype, 0)

# each site + mutant combination must occur only once
dup_rows = heatmap_data.assign(
    _n=lambda x: x.groupby(["site", "mutant"])["wildtype"].transform("count")
).query("_n > 1")
if len(dup_rows):
    raise ValueError(f"Duplicate rows:\n{dup_rows}")

print(f"Writing summary data to {output_csv_file}")
heatmap_data.to_csv(output_csv_file, index=False, float_format="%.4g")

heatmap_data
Writing summary data to results/escape_by_prior_infections/summary.csv
Out[12]:
site wildtype mutant one infection escape multiple infections escape spike mediated entry ACE2 binding sequential_site region
0 2 F C -0.021218 0.059678 0.10100 0.02151 2 other
1 2 F L 0.067207 -0.053920 0.09432 -0.26980 2 other
2 2 F S 0.047673 0.007725 0.05844 -0.05642 2 other
3 2 F F 0.000000 0.000000 0.00000 0.00000 2 other
4 3 V A -0.055882 0.143862 -0.04154 -0.04977 3 other
... ... ... ... ... ... ... ... ... ...
8345 1211 K K 0.000000 0.000000 0.00000 0.00000 1207 S2
8346 1212 W R NaN NaN -2.37700 -0.21050 1208 other
8347 1212 W W 0.000000 0.000000 0.00000 0.00000 1208 other
8348 767 L L 0.000000 0.000000 0.00000 0.00000 763 S2
8349 767 L V NaN NaN NaN 1.69700 763 S2

8350 rows × 9 columns

Make heatmaps:

In [13]:
# Site labels are drawn only on the bottom-most heatmap. That heatmap is the last
# entry of plot_params, which lists escape sets first and then other properties,
# the same order used to stack the heatmaps below.
last_heatmap = list(plot_params)[-1]

# Shared base for every heatmap: floored escape columns, NaN for missing values,
# and site (x) by amino-acid (y) encodings.
heatmap_base = (
    alt.Chart(heatmap_data)
    .transform_calculate(
        # escape floored at zero when the floor_escape_at_zero toggle is on
        **{
            f"{antibody_set} escape_floored": alt.expr.if_(
                floor_escape_at_zero,
                alt.expr.max(alt.datum[f"{antibody_set} escape"], 0),
                alt.datum[f"{antibody_set} escape"],
            )
            for antibody_set in escape
        },
        # convert null values to NaN so they show as NaN in tooltips rather than as 0.0
        **{
            col: alt.expr.if_(
                alt.expr.isFinite(alt.datum[col]),
                alt.datum[col],
                alt.expr.NaN,
            )
            for col in heatmap_data_cols
        },
    )
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
        y=alt.Y(
            "mutant:N",
            title="amino acid",
            sort=alphabet,
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
        strokeWidth=alt.condition(site_selection, alt.value(2), alt.value(1)),
    )
    .properties(width=alt.Step(cell_size), height=alt.Step(cell_size))
    .add_params(*other_prop_sliders.values(), floor_escape_at_zero)
)

# mark X for wildtype
heatmap_wildtype = heatmap_base.transform_filter(
    alt.datum["wildtype"] == alt.datum["mutant"]
).mark_text(text="x", color="black")

# gray background for missing values: impute a placeholder field for every
# alphabet letter absent at a site, so a background rect exists for each cell
heatmap_bg = heatmap_base.transform_impute(
    impute="_stat_dummy",
    key="mutant",
    keyvals=alphabet,
    groupby=["site"],
    value=None,
).mark_rect(color="#E0E0E0", opacity=0.8)

tooltips = [
    "site",
    "mutant",
    *[alt.Tooltip(c, format=".2f") for c in heatmap_data_cols],
    "wildtype",
    *[c for c in site_numbering_map.columns if c != "site"],
]

legend = alt.Legend(
    orient="left",
    offset=10,
    titleOrient="left",
    gradientLength=100,
    gradientThickness=10,
    gradientStrokeColor="black",
    gradientStrokeWidth=0.5,
)

# heatmaps for escape
# Show a cell if any property passes its slider, if it is wildtype, or if every
# property is missing. Cells where some property is below its slider are also
# covered by escape_filtered_heatmap, which is layered on top. The explicit
# reduce initializers keep this valid when no other properties are configured
# (reduce over an empty list without an initializer raises TypeError).
escape_heatmaps = {}
escape_heatmap_base = heatmap_base.transform_filter(
    functools.reduce(
        operator.or_,
        [(alt.datum[prop] >= slider) for prop, slider in other_prop_sliders.items()],
        False,
    )
    | (alt.datum["wildtype"] == alt.datum["mutant"])
    | functools.reduce(
        operator.and_,
        [alt.expr.isNaN(alt.datum[prop]) for prop in other_prop_sliders],
        True,
    )
)
for antibody_set in escape:
    # color domain: fixed limits from the config if given, else the data extreme
    # extended to at least the configured *_at_least value
    domain_lims = {
        lim: (
            plot_params[antibody_set][f"fixed_{lim}"]
            if f"fixed_{lim}" in plot_params[antibody_set]
            else lim_func(
                plot_params[antibody_set][f"{lim}_at_least"],
                lim_func(heatmap_data[f"{antibody_set} escape"]),
            )
        )
        for (lim, lim_func) in [("min", min), ("max", max)]
    }
    escape_heatmaps[antibody_set] = escape_heatmap_base.encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            title=None,
            axis=(
                alt.Axis()
                if last_heatmap == antibody_set
                else alt.Axis(ticks=False, labels=False)
            ),
        ),
        color=alt.Color(
            f"{antibody_set} escape_floored:Q",
            title=f"{antibody_set} escape"
            if len(antibody_set) < 14
            else [antibody_set, "escape"],
            legend=legend,
            scale=alt.Scale(
                zero=True,
                nice=False,
                type="linear",
                domainMid=0,
                domainMax=domain_lims["max"],
                domainMin=alt.ExprRef(
                    f"if(floor_escape_at_zero, 0, {domain_lims['min']})"
                ),
                range=(
                    color_gradient_hex(
                        plot_params[antibody_set]["negative_color"], "white", n=20
                    )
                    + color_gradient_hex(
                        "white", plot_params[antibody_set]["positive_color"], n=20
                    )[1:]
                ),
            ),
        ),
        tooltip=tooltips,
    ).mark_rect(stroke="black")

# heatmap for other property (eg, functional effect) filtered escape: non-wildtype
# cells where any property is below its slider are drawn in silver
escape_filtered_heatmap = (
    heatmap_base.transform_filter(
        functools.reduce(
            operator.or_,
            [alt.datum[prop] < slider for prop, slider in other_prop_sliders.items()],
            False,
        )
        & (alt.datum["wildtype"] != alt.datum["mutant"])
    )
    .transform_calculate(filtered="''")
    .encode(
        tooltip=tooltips,
        color=alt.Color(
            "filtered:N",
            title="deleterious",
            scale=alt.Scale(range=["silver"]),
            legend=None,
        ),
    )
    .mark_rect(stroke="black")
)

# heatmaps for other properties; each is filtered only by the *other* properties'
# sliders, while its own slider sets the lower end of its color domain
other_prop_heatmaps = {}
other_prop_filtered_heatmaps = {}
for prop in other_props:
    params = plot_params[prop]
    domain_lims = {
        lim: (
            params[f"fixed_{lim}"]
            if f"fixed_{lim}" in params
            else lim_func(
                params[f"{lim}_at_least"],
                lim_func(heatmap_data[prop]),
            )
        )
        for (lim, lim_func) in [("min", min), ("max", max)]
    }
    slider_name = prop.replace(" ", "_") + "_slider"  # name given when defining slider
    other_prop_heatmaps[prop] = (
        heatmap_base.transform_filter(
            functools.reduce(
                operator.and_,
                [
                    (
                        (alt.datum[other_prop] >= slider)
                        | alt.expr.isNaN(alt.datum[other_prop])
                    )
                    for other_prop, slider in other_prop_sliders.items()
                    if other_prop != prop
                ],
                True,
            )
            | (alt.datum["wildtype"] == alt.datum["mutant"])
        )
        .encode(
            x=alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=(
                    alt.Axis()
                    if last_heatmap == prop
                    else alt.Axis(ticks=False, labels=False)
                ),
            ),
            color=alt.Color(
                prop,
                legend=legend,
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    clamp=True,
                    domainMid=0,
                    domainMax=domain_lims["max"],
                    domainMin=alt.ExprRef(f"max({slider_name}, {domain_lims['min']})"),
                    range=(
                        color_gradient_hex(params["negative_color"], "white", n=20)
                        + color_gradient_hex("white", params["positive_color"], n=20)[
                            1:
                        ]
                    ),
                ),
            ),
            tooltip=tooltips,
        )
        .mark_rect(stroke="black")
    )
    other_prop_filtered_heatmaps[prop] = (
        heatmap_base.transform_filter(
            functools.reduce(
                operator.or_,
                [
                    alt.datum[other_prop] < slider
                    for other_prop, slider in other_prop_sliders.items()
                    if other_prop != prop
                ],
                False,
            )
            & (alt.datum["wildtype"] != alt.datum["mutant"])
        )
        .transform_calculate(filtered="''")
        .encode(
            tooltip=tooltips,
            color=alt.Color(
                "filtered:N",
                title="deleterious",
                scale=alt.Scale(range=["silver"]),
                legend=None,
            ),
        )
        .mark_rect(stroke="black")
    )

# stack all heatmaps; layer order is background, values, filtered overlay, wildtype x
heatmap = (
    alt.vconcat(
        *[
            heatmap_bg + escape_heatmap + escape_filtered_heatmap + heatmap_wildtype
            for escape_heatmap in escape_heatmaps.values()
        ],
        *[
            heatmap_bg
            + other_prop_heatmaps[prop]
            + other_prop_filtered_heatmaps[prop]
            + heatmap_wildtype
            for prop in other_props
        ],
        spacing=2,
    )
    .resolve_scale(color="independent")
    .add_params(site_selection)
)

Merged site escape lineplot and heatmaps¶

Create and save merged escape lineplot and heatmap:

In [14]:
# For each layout, stack the brushable zoom bar, every antibody set's site escape
# chart, and the heatmaps (restricted to the brushed sites), then save and show it.
chart_paths = {"overlaid": chart_overlaid, "faceted": chart_faceted}
for chart_type, chartfile in chart_paths.items():
    panels = [
        region_bar.add_params(site_brush),
        *site_escape_charts[chart_type].values(),
        heatmap.transform_filter(site_brush),
    ]
    merged_chart = alt.vconcat(*panels, spacing=2).configure_legend(
        orient="left", padding=0
    )

    print(f"Saving {chart_type} chart to {chartfile}")
    merged_chart.save(chartfile)
    display(merged_chart)
Saving overlaid chart to results/escape_by_prior_infections/summary_overlaid_nolegend.html
Saving faceted chart to results/escape_by_prior_infections/summary_faceted_nolegend.html
In [ ]: