import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import spatial_lda.visualization as sv
from alpineer import misc_utils
from ark.utils.spatial_lda_utils import make_plot_fn
def draw_boxplot(cell_data, col_name, col_split=None,
                 split_vals=None, dpi=None, save_dir=None, save_file=None):
    """Draws a boxplot for a given column, optionally with help from a split column

    Args:
        cell_data (pandas.DataFrame):
            Dataframe containing columns with Patient ID and Cell Name.
            Must contain `col_name` and, if given, `col_split`. It is never modified.
        col_name (str):
            Name of the column we wish to draw a box-and-whisker plot for
        col_split (str):
            If specified, used for additional box-and-whisker plot faceting
        split_vals (list):
            If specified, only visualize the specified values in the col_split column
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        save_dir (str):
            If specified, a directory where we will save the plot
        save_file (str):
            If save_dir specified, specify a file name you wish to save to.
            Ignored if save_dir is None

    Raises:
        ValueError:
            If split_vals is set but col_split is not
    """

    column_names = cell_data.columns.values

    # the col_name must be valid
    misc_utils.verify_in_list(col_name=col_name, column_names=column_names)

    # if col_split is not None, it must exist as a column in cell_data
    if col_split is not None:
        misc_utils.verify_in_list(col_split=col_split, column_names=column_names)

    # basic error checks if split_vals is set
    if split_vals is not None:
        # the user cannot specify split_vals without specifying col_split
        if col_split is None:
            raise ValueError("If split_vals is set, then col_split must also be set")

        # all the values in split_vals must exist in the col_split column of cell_data
        misc_utils.verify_in_list(split_vals=split_vals,
                                  column_split_values=cell_data[col_split].unique())

    # don't modify cell_data in any way
    data_to_viz = cell_data.copy(deep=True)

    # ignore values in col_split not in split_vals if split_vals is set
    if split_vals:
        data_to_viz = data_to_viz[data_to_viz[col_split].isin(split_vals)]

    if col_split:
        # if col_split, then we explicitly facet the visualization
        # labels are automatically generated in Seaborn
        sns.boxplot(x=col_split, y=col_name, data=data_to_viz)
        plt.title("Distribution of %s, faceted by %s" % (col_name, col_split))
    else:
        # otherwise, we don't facet anything. Passing the column as y yields a
        # vertical box; passing it as x with orient="v" does not, because the
        # orient hint is ignored when only a single variable is supplied.
        sns.boxplot(y=col_name, data=data_to_viz)
        plt.title("Distribution of %s" % col_name)

    # save visualization to a directory if specified
    if save_dir is not None:
        misc_utils.save_figure(save_dir, save_file, dpi=dpi)
def draw_heatmap(data, x_labels, y_labels, dpi=None, center_val=None, min_val=None, max_val=None,
                 cbar_ticks=None, colormap="vlag", row_colors=None, row_cluster=True,
                 col_colors=None, col_cluster=True, left_start=None, right_start=None,
                 w_spacing=None, h_spacing=None, save_dir=None, save_file=None):
    """Plots the z scores between all phenotypes as a clustermap.

    NaN and infinite entries are drawn as 0. The input array itself is not modified.

    Args:
        data (numpy.ndarray):
            The data array to visualize
        x_labels (list):
            List of names, one per row of data (used as the index of the plotted DataFrame)
        y_labels (list):
            List of names, one per column of data (used as the columns of the
            plotted DataFrame)
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        center_val (float):
            value at which to center the heatmap
        min_val (float):
            minimum value the heatmap should take
        max_val (float):
            maximum value the heatmap should take
        cbar_ticks (list):
            list of values containing tick labels for the heatmap colorbar
        colormap (str):
            color scheme for visualization
        row_colors (list):
            Include these values as an additional color-coded cluster bar for row values
        row_cluster (bool):
            Whether to include dendrogram clustering for the rows
        col_colors (list):
            Include these values as an additional color-coded cluster bar for column values
        col_cluster (bool):
            Whether to include dendrogram clustering for the columns
        left_start (float):
            The position to set the left edge of the figure to (from 0-1)
        right_start (float):
            The position to set the right edge of the figure to (from 0-1)
        w_spacing (float):
            The amount of spacing to put between the subplots width-wise (from 0-1)
        h_spacing (float):
            The amount of spacing to put between the subplots height-wise (from 0-1)
        save_dir (str):
            If specified, a directory where we will save the plot
        save_file (str):
            If save_dir specified, specify a file name you wish to save to.
            Ignored if save_dir is None
    """

    # Replace the NA's and inf values with 0s, working on a copy so that the
    # array owned by the caller is left untouched
    cleaned = data.copy()
    cleaned[~np.isfinite(cleaned)] = 0

    # Assign numpy values respective phenotype labels
    data_df = pd.DataFrame(cleaned, index=x_labels, columns=y_labels)
    sns.set(font_scale=.7)

    heatmap = sns.clustermap(
        data_df, cmap=colormap, center=center_val,
        vmin=min_val, vmax=max_val, row_colors=row_colors, row_cluster=row_cluster,
        col_colors=col_colors, col_cluster=col_cluster,
        cbar_kws={'ticks': cbar_ticks}
    )

    # ensure the row/column color axes don't have a label attached to them
    if row_colors is not None:
        _ = heatmap.ax_row_colors.xaxis.set_visible(False)

    if col_colors is not None:
        _ = heatmap.ax_col_colors.yaxis.set_visible(False)

    # update the figure dimensions to accommodate Jupyter widget backend
    _ = heatmap.gs.update(
        left=left_start, right=right_start, wspace=w_spacing, hspace=h_spacing
    )

    # ensure the y-axis labels are horizontal, will be misaligned if vertical
    _ = plt.setp(heatmap.ax_heatmap.get_yticklabels(), rotation=0)

    plt.tight_layout()

    if save_dir is not None:
        misc_utils.save_figure(save_dir, save_file, dpi=dpi)
def get_sorted_data(cell_data, sort_by_first, sort_by_second, is_normalized=False):
    """Cross-tabulates two columns of the cell data and orders both axes by frequency

    Each row of the result corresponds to a value of sort_by_first and each column to a
    value of sort_by_second. Rows and columns are arranged in descending order of how many
    non-null entries of the other column each value has in cell_data.

    Args:
        cell_data (pandas.DataFrame):
            Dataframe containing columns with Patient ID and Cell Name
        sort_by_first (str):
            The first attribute we will be sorting our data by (becomes the rows)
        sort_by_second (str):
            The second attribute we will be sorting our data by (becomes the columns)
        is_normalized (bool):
            Boolean specifying whether to normalize cell counts or not, default is False.
            When True, each row is normalized to sum to 1.

    Returns:
        pandas.DataFrame:
            DataFrame with rows and columns sorted by population
    """

    def _labels_by_count(group_col, count_col):
        # values of group_col, most frequent first, ranked by the count of count_col
        counts = cell_data.groupby(group_col).count()
        ranked = counts.sort_values(by=count_col, ascending=False)
        return ranked.index.values

    normalize = 'index' if is_normalized else False
    table = pd.crosstab(
        cell_data[sort_by_first], cell_data[sort_by_second], normalize=normalize
    )

    row_order = _labels_by_count(sort_by_first, sort_by_second)
    col_order = _labels_by_count(sort_by_second, sort_by_first)

    return table.reindex(row_order, axis='index').reindex(col_order, axis='columns')
def plot_barchart(data, title, x_label, y_label, color_map="jet", is_stacked=True,
                  is_legend=True, legend_loc='center left', bbox_to_anchor=(1.0, 0.5),
                  dpi=None, save_dir=None, save_file=None):
    """Draws a titled and labeled bar chart of the given data, optionally saving it

    A helper function to visualize_patient_population_distribution

    Args:
        data (pandas.DataFrame):
            The data we wish to visualize
        title (str):
            The title of the graph
        x_label (str):
            The label on the x-axis
        y_label (str):
            The label on the y-axis
        color_map (str):
            The name of the Matplotlib colormap used
        is_stacked (bool):
            Whether we want a stacked barchart or not
        is_legend (bool):
            Whether we want a legend or not
        legend_loc (str):
            If is_legend is set, specify where we want the legend to be
            Ignored if is_legend is False
        bbox_to_anchor (tuple):
            If is_legend is set, specify the bounding box of the legend
            Ignored if is_legend is False
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        save_dir (str):
            Directory to save plots, default is None
        save_file (str):
            If save_dir specified, specify a file name you wish to save to.
            Ignored if save_dir is None
    """

    # draw the bars through the pandas plotting interface
    data.plot(kind="bar", colormap=color_map, stacked=is_stacked, legend=is_legend)

    # annotate the current axes
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # reposition the legend only when one was requested
    if is_legend:
        plt.legend(loc=legend_loc, bbox_to_anchor=bbox_to_anchor)

    # nothing left to do unless a save location was given
    if save_dir is None:
        return

    misc_utils.save_figure(save_dir, save_file, dpi=dpi)
def visualize_patient_population_distribution(cell_data, patient_col_name, population_col_name,
                                              color_map="jet", show_total_count=True,
                                              show_distribution=True, show_proportion=True,
                                              dpi=None, save_dir=None):
    """Plots the distribution of the population given by total count, direct count, and proportion

    Rows of cell_data containing any missing value are dropped before plotting.
    When save_dir is given, the plots are saved as "PopulationDistribution.png",
    "TotalPopulationDistribution.png" and "PopulationProportion.png" respectively.

    Args:
        cell_data (pandas.DataFrame):
            Dataframe containing columns with Patient ID and Cell Name
        patient_col_name (str):
            Name of column containing categorical Patient data
        population_col_name (str):
            Name of column in dataframe containing Population data
        color_map (str):
            Name of MatPlotLib ColorMap used for every bar chart. Default is jet
        show_total_count (bool):
            Boolean specifying whether to show graph of total population count, default is true
        show_distribution (bool):
            Boolean specifying whether to show graph of population distribution, default is true
        show_proportion (bool):
            Boolean specifying whether to show graph of population proportion per patient,
            default is true
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        save_dir (str):
            Directory to save plots, default is None
    """

    cell_data = cell_data.dropna()

    # Plot by total count
    if show_total_count:
        population_values = cell_data[population_col_name].value_counts()
        title = "Distribution of Population in all patients"
        x_label = "Population Type"
        y_label = "Population Count"

        plot_barchart(population_values, title, x_label, y_label, color_map=color_map,
                      is_legend=False, dpi=dpi, save_dir=save_dir,
                      save_file="PopulationDistribution.png")

    # Plot by count
    if show_distribution:
        sorted_data = get_sorted_data(cell_data, patient_col_name, population_col_name)
        title = "Distribution of Population Count in Patients"

        plot_barchart(sorted_data, title, patient_col_name, population_col_name,
                      color_map=color_map, dpi=dpi, save_dir=save_dir,
                      save_file="TotalPopulationDistribution.png")

    # Plot by Proportion
    if show_proportion:
        sorted_data = get_sorted_data(cell_data, patient_col_name, population_col_name,
                                      is_normalized=True)
        title = "Distribution of Population Count Proportion in Patients"

        plot_barchart(sorted_data, title, patient_col_name, population_col_name,
                      color_map=color_map, dpi=dpi, save_dir=save_dir,
                      save_file="PopulationProportion.png")
def visualize_neighbor_cluster_metrics(neighbor_cluster_stats, metric_name,
                                       dpi=None, save_dir=None):
    """Plots a cluster metric of a neighborhood matrix against the number of clusters

    Args:
        neighbor_cluster_stats (xarray.DataArray):
            contains the desired statistic we wish to visualize, should have one
            coordinate called cluster_num labeled starting from 2
        metric_name (str):
            name of metric, used in the title, the y-axis label and the saved file name
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        save_dir (str):
            Directory to save plots, default is None
    """

    # x-axis: the cluster counts, y-axis: the metric value at each count
    cluster_counts = neighbor_cluster_stats.coords['cluster_num'].values
    metric_values = neighbor_cluster_stats.values

    plt.plot(cluster_counts, metric_values)
    plt.title(metric_name + " vs number of clusters")
    plt.xlabel("Number of clusters")
    plt.ylabel(metric_name)

    # only write to disk when a directory was provided
    if save_dir is None:
        return

    file_name = "neighborhood_" + metric_name + "_scores.png"
    misc_utils.save_figure(save_dir, file_name, dpi=dpi)
def visualize_topic_eda(data, metric="gap_stat", gap_sd=True, k=None, transpose=False, scale=0.5,
                        dpi=None, save_dir=None):
    """Visualize the exploratory metrics for spatial-LDA topics

    The "gap_stat", "inertia" and "silhouette" metrics are drawn as line plots against the
    number of clusters; "cell_counts" is drawn as a heatmap for a single value of k.

    Args:
        data (dict):
            The dictionary of exploratory metrics produced by
            :func:`~ark.spLDA.processing.compute_topic_eda`.
        metric (str):
            One of "gap_stat", "inertia", "silhouette", or "cell_counts".
        gap_sd (bool):
            If True, the standard error of the gap statistic is included in the plot.
        k (int):
            References a specific KMeans clustering with k clusters for visualizing the cell count
            heatmap.
        transpose (bool):
            Swap axes for cell_counts heatmap
        scale (float):
            Plot size scaling for cell_counts heatmap
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        save_dir (str):
            Directory to save plots, default is None

    Raises:
        ValueError:
            If metric is "cell_counts" and k is None
    """

    misc_utils.verify_in_list(
        actual=[metric],
        expected=["gap_stat", "inertia", "silhouette", "cell_counts"]
    )

    featurization = data["featurization"]

    # every entry except the featurization tag goes into the metrics table
    metrics_only = {key: val for key, val in data.items() if key != "featurization"}
    df = pd.DataFrame.from_dict(metrics_only)
    df['num_clusters'] = df.index

    # y-axis label for each metric drawn as a line against the number of clusters
    line_y_labels = {
        "gap_stat": "Gap",
        "inertia": "Inertia",
        "silhouette": "Silhouette Score",
    }

    if metric == "cell_counts":
        if k is None:
            raise ValueError("Must provide number of clusters for k value.")

        # scale each column by its sum before drawing
        cell_counts = data["cell_counts"][k]
        cell_counts = cell_counts / cell_counts.sum(axis=0)
        if transpose:
            cell_counts = cell_counts.T

        fig_size = (scale * cell_counts.shape[1], scale * cell_counts.shape[0])
        plt.subplots(figsize=fig_size)
        sns.heatmap(cell_counts, vmin=0, square=True, xticklabels=True,
                    yticklabels=True, cmap="mako")
        plt.xlabel("KMeans Cluster Label")

        # the y-axis label depends on how the cell table was featurized
        if featurization == "cluster":
            heatmap_y_label = "Cell Cluster"
        elif featurization in ("marker", "avg_marker"):
            heatmap_y_label = "Channel Marker"
        else:
            heatmap_y_label = "Cell Counts"
        plt.ylabel(heatmap_y_label)
    elif metric in line_y_labels:
        if metric == "gap_stat" and gap_sd:
            # gap statistic with error bars
            plt.plot()
            plt.errorbar(x=df["num_clusters"], y=df["gap_stat"], yerr=df["gap_sds"])
        else:
            sns.relplot(data=df, x="num_clusters", y=metric, kind="line")
        plt.xlabel("Number of Clusters")
        plt.ylabel(line_y_labels[metric])

    if save_dir is not None:
        # the cell count heatmap is specific to one k, so k is part of its file name
        clust_label = "_k_{}".format(str(k)) if metric == "cell_counts" else ""
        file_name = "topic_eda_" + metric + clust_label + ".png"
        misc_utils.save_figure(save_dir, file_name, dpi=dpi)
def visualize_fov_stats(data, metric="cellular_density", dpi=None, save_dir=None):
    """Draws a histogram of a per field of view metric across all field of views.

    Args:
        data (dict):
            The dictionary of field of view metrics produced by
            :func:`~ark.spLDA.processing.fov_density`.
        metric (str):
            One of "cellular_density", "average_area", or "total_cells". See
            documentation of :func:`~ark.spLDA.processing.fov_density` for details.
            Any other value is plotted as "total_cells".
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None
        save_dir (str):
            Directory to save plots, default is None
    """

    df = pd.DataFrame.from_dict(data)
    df['fov'] = df.index

    # x-axis label for the metrics that are matched by name
    named_x_labels = {
        "cellular_density": "FOV Cellular Density",
        "average_area": "FOV Average Cell Area",
    }

    if metric in named_x_labels:
        column, x_label = metric, named_x_labels[metric]
    else:
        # everything else falls through to the total cell count
        column, x_label = "total_cells", "FOV Total Cell Count"

    sns.histplot(data=df, x=column)
    plt.xlabel(x_label)
    plt.ylabel("Count")

    if save_dir is None:
        return

    # the file name uses the metric exactly as it was passed in
    file_name = "fov_metrics_" + metric + ".png"
    misc_utils.save_figure(save_dir, file_name, dpi=dpi)
def visualize_fov_graphs(cell_table, features, diff_mats, fovs, dpi=None, save_dir=None):
    """Visualize the adjacency graph used to define neighboring environments in each field of view.

    Args:
        cell_table (dict):
            A formatted cell table for use in spatial-LDA analysis. Specifically, this is the
            output from :func:`~ark.spLDA.processing.format_cell_table`.
        features (dict):
            A featurized cell table. Specifically, this is the output from
            :func:`~ark.spLDA.processing.featurize_cell_table`.
            Only the "train_features" entry is read.
        diff_mats (dict):
            The difference matrices produced by
            :func:`~ark.spLDA.processing.create_difference_matrices`.
            Only the "train_diff_mat" entry is read.
        fovs (list):
            A list of field of view IDs to plot.
        dpi (float):
            The resolution of the image to save, ignored if save_dir is None.
        save_dir (str):
            Directory to save plots, default is None
    """

    # build the per-sample plotting callback from the training difference matrices
    train_diff_mat = diff_mats["train_diff_mat"]
    _plot_fn = make_plot_fn(plot="adjacency", difference_matrices=train_diff_mat)

    # draw the requested field of views side by side
    sv.plot_samples_in_a_row(features["train_features"], _plot_fn, cell_table, tumor_set=fovs)

    if save_dir is None:
        return

    # the file name lists every plotted field of view, joined by underscores
    fov_suffix = "_".join(map(str, fovs))
    file_name = "adjacency_graph_fovs_" + fov_suffix + ".png"
    misc_utils.save_figure(save_dir, file_name, dpi=dpi)