Source code for ark.utils.metacluster_remap_gui.metaclustergui
import warnings
import ipywidgets as widgets
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from IPython.display import display
from scipy.cluster.hierarchy import dendrogram
from scipy.stats import zscore
from .colormap_helper import distinct_cmap
from .metaclusterdata import MetaClusterData
from .throttle import throttle
from .zscore_norm import ZScoreNormalize
# Third-party ipympl triggers this deprecation warning in its backend_agg startup code,
# so it is silenced here rather than at every call site.
warnings.filterwarnings("ignore", message="nbagg.transparent is deprecated")
# Shared Output widget used by the `@DEBUG_VIEW.capture` decorators on the GUI callbacks;
# it is only displayed when MetaClusterGui.enable_debug_mode() is called.
DEBUG_VIEW = widgets.Output(layout={'border': '1px solid black'})
# Default diverging colormap used for the cluster / metacluster heatmaps.
DEFAULT_HEATMAP = sns.diverging_palette(240, 10, n=3, as_cmap=True)
class MetaClusterGui:
    """Coordinate and present the metacluster Graphical User Interface

    Attributes:
        mcd (MetaClusterData):
            State of the actual clusters at any point in time
        selected_clusters (set[int]):
            Currently selected clusters

    Args:
        metaclusterdata (MetaClusterData):
            An initialized MetaClusterData instance
        heatmapcolors (matplotlib.colors.Colormap):
            If you wish to change the default heatmap colors. The value is handed
            unchanged to ``imshow(cmap=...)``.
        width (float):
            Adjust the actual width to accommodate monitor size, resolution, zoom, etc
        debug (bool):
            Enable debug mode for the GUI. This enables a special logging window where
            output from callbacks can be printed.
        enable_throttle (bool):
            Control whether or not to throttle GUI callbacks. Disabling might be
            helpful for debugging certain race conditions.
    """

    def __init__(self, metaclusterdata, heatmapcolors=DEFAULT_HEATMAP,
                 width=17.0, debug=False, enable_throttle=True):
        self.width: float = width
        # Not annotated as `str`: the default is a colormap object, and the value is
        # simply forwarded to imshow(cmap=...).
        self.heatmapcolors = heatmapcolors
        self.mcd: MetaClusterData = metaclusterdata
        self.selected_clusters = set()

        # Widgets must exist first: make_gui() shows the figure inside self.plot_output,
        # which make_widgets() creates.
        self.make_widgets()
        self.make_gui()

        # Force the first update_gui() call to paint every plot, not just the selection.
        self._heatmaps_stale = True
        self.update_gui()
        display(self.gui)

        if debug:
            self.enable_debug_mode()

        if enable_throttle:
            # Shadow the bound method with a throttled wrapper; the initial draw above
            # has already run un-throttled.
            throttler = throttle(.3)
            self.update_gui = throttler(self.update_gui)

    def make_gui(self):
        """Create and configure all of the plots which make up the GUI

        Below is a map of the physical subplot layout of
        the Axes within the Figure.

        The abbreviation is used both for the axes
            e.g. self.ax_c
        as well as the plotted items.
            e.g. self.im_c, self.rects_cp

        Map of matplotlib Figure::

            |   | Cluster | Meta |
            ----------------------------
            |   |  cp     |  cb  |  counts of pixels, color bar
            | cd|  c      |  m   |  heatmap itself
            |   |  cs     |  ms  |  selection markers
            |   |  cl     |  ml  |  metacluster color labels

        All images are created with zero-filled placeholder data; the real values are
        pushed in by update_gui().
        """
        width_ratios = [
            int(self.mcd.cluster_count / 7),
            self.mcd.cluster_count,
            self.mcd.metacluster_count * 2,
        ]
        marker_ratio = max(self.mcd.marker_count / 20, 1)
        height_ratios = [
            6 * marker_ratio,
            self.mcd.marker_count * marker_ratio,
            marker_ratio,
            marker_ratio,
        ]

        subplots = plt.subplots(
            4, 3,
            gridspec_kw={
                'width_ratios': width_ratios,
                'height_ratios': height_ratios},
            figsize=(self.width, 6 * marker_ratio),
        )

        # Route the figure into the Output widget that sits above the toolbar.
        with self.plot_output:
            plt.show()

        (self.fig, (
            (self.ax_01, self.ax_cp, self.ax_cb),
            (self.ax_cd, self.ax_c, self.ax_m),
            (self.ax_02, self.ax_cs, self.ax_ms),
            (self.ax_03, self.ax_cl, self.ax_ml))) = subplots

        self.fig.canvas.toolbar_visible = False
        self.fig.canvas.header_visible = False
        self.fig.canvas.footer_visible = False
        self.fig.canvas.mpl_connect('pick_event', self.onpick)

        # heatmaps
        self.normalizer = ZScoreNormalize(-1, 0, 1)

        def _heatmap(ax, column_count):
            data = np.zeros((self.mcd.marker_count, column_count))
            return ax.imshow(
                data,
                norm=self.normalizer,
                cmap=self.heatmapcolors,
                aspect='auto',
                picker=True,
            )

        self.im_c = _heatmap(self.ax_c, self.mcd.cluster_count)
        self.im_m = _heatmap(self.ax_m, self.mcd.metacluster_count)

        self.ax_c.yaxis.set_tick_params(which='major', labelleft=False)
        self.ax_c.set_yticks(np.arange(self.mcd.marker_count) + 0.5)
        self.ax_c.set_xticks(np.arange(self.mcd.cluster_count) + 0.5)
        self.ax_m.set_xticks(np.arange(self.mcd.metacluster_count) + 0.5)
        self.ax_c.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_m.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_m.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        self.ax_m.yaxis.set_tick_params(which='both', right=True, labelright=True, labelsize=7)
        self.ax_m.set_yticks(np.arange(self.mcd.marker_count) + 0.5)

        # xaxis metacluster color labels
        self.ax_cl.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_ml.xaxis.set_tick_params(which='both', bottom=True, labelbottom=True)
        self.ax_cl.yaxis.set_tick_params(which='both', left=False, labelleft=True)
        self.ax_cl.set_yticks([0.5])
        self.ax_cl.set_yticklabels(["Metacluster"])
        self.ax_ml.yaxis.set_tick_params(which='both', left=False, labelleft=False)

        def _color_labels(ax, column_count):
            data = np.zeros((1, column_count))
            return ax.imshow(data, aspect='auto', picker=True, vmin=1, vmax=self.mcd.cluster_count)

        self.im_cl = _color_labels(self.ax_cl, self.mcd.cluster_count)
        self.im_ml = _color_labels(self.ax_ml, self.mcd.metacluster_count)

        # xaxis cluster selection labels
        self.ax_cs.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_cs.yaxis.set_tick_params(which='both', left=False, labelleft=True)
        self.ax_cs.set_yticks([0.5])
        self.ax_cs.set_yticklabels(["Selected"])
        self.im_cs = self.ax_cs.imshow(
            # One cell per cluster: matches selection_mask and the extent set in
            # update_gui (the placeholder was previously sized by marker_count).
            np.zeros((1, self.mcd.cluster_count)),
            cmap='Blues',
            aspect='auto',
            picker=True,
            vmin=-0.3,
            vmax=1,
        )

        # xaxis pixelcount graphs
        self.ax_cp.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_cp.yaxis.set_tick_params(which='both', left=False, labelleft=False)
        self.ax_cp.set_ylabel("Count (k)", rotation=90)
        self.ax_cp.set_xlim(0, self.mcd.cluster_count)
        self.rects_cp = self.ax_cp.bar(
            np.arange(self.mcd.cluster_count) + 0.5,
            np.zeros(self.mcd.cluster_count))
        self.labels_cp = []
        # small horizontal nudge so the rotated text looks centered over its bar
        label_alignment_fudge = 0.08
        for x in np.arange(self.mcd.cluster_count) + 0.5 + label_alignment_fudge:
            label = self.ax_cp.text(
                x=x, y=0, s="-", va='bottom',
                ha='center', rotation=90, color='black', fontsize=8)
            self.labels_cp.append(label)

        # colorbar
        self.cb = plt.colorbar(self.im_c, ax=self.ax_cb, orientation='horizontal',
                               fraction=.75, shrink=.95, aspect=15)
        self.cb.ax.xaxis.set_tick_params(which='both', labelsize=7, labelrotation=90)

        # dendrogram
        self.ddg = dendrogram(
            self.mcd.linkage_matrix,
            ax=self.ax_cd,
            orientation='left',
            labels=self.mcd.fixed_width_marker_names,
            leaf_font_size=8,
        )
        # Reorder the markers to follow the dendrogram leaves (reversed) so the
        # heatmap rows line up with the tree.
        self.mcd.set_marker_order(self.ddg['leaves'][::-1])
        self.ax_m.set_yticklabels(self.mcd.marker_names[::-1])
        self.ax_cd.figure.frameon = False
        self.ax_cd.spines["top"].set_visible(False)
        self.ax_cd.spines["left"].set_visible(False)
        self.ax_cd.spines["right"].set_visible(False)
        self.ax_cd.spines["bottom"].set_visible(False)
        self.ax_cd.xaxis.set_tick_params(which='both', bottom=False, labelbottom=False)
        self.ax_cd.yaxis.set_tick_params(which='both', pad=-2)
        self.ax_cd.tick_params(axis="y", direction="in")
        self.move_dendro_labels(self.ax_cd)

        # unused grid cells
        self.ax_01.axis('off')
        self.ax_02.axis('off')
        self.ax_03.axis('off')
        self.ax_cb.axis('off')
        self.ax_ms.axis('off')

        # space for longer labels hack
        self.ax_ml.set_xticks([0.5])
        self.ax_ml.set_xticklabels(["SpaceHolder--"], rotation=90, fontsize=8)

        # Tighten layout based on display
        self.fig.tight_layout()
        plt.subplots_adjust(hspace=.0)  # make color labels touch heatmap
        plt.subplots_adjust(wspace=.02)

    def make_widgets(self):
        """Create the physical ipywidgets that display below the GUI plot.

        Also creates ``self.plot_output`` (the Output widget the figure is shown in)
        and ``self.gui`` (the VBox combining the plot output and the toolbar).
        """
        # zscore adjuster
        self.zscore_clamp_slider = widgets.FloatSlider(
            value=3,
            min=1,
            max=10.0,
            step=0.5,
            description='Max Zscore:',
            disabled=False,
            continuous_update=True,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            tooltip='Clamp/Clip zscore to a certain max value.',
        )
        # Only react to changes of the slider value; without `names` the handler
        # would fire (and repaint every heatmap) for every trait change.
        self.zscore_clamp_slider.observe(self.update_zscore, type="change", names="value")

        # clear_selection button
        self.clear_selection_button = widgets.Button(
            description='Clear Selection',
            disabled=False,
            button_style='warning',
            tooltip='Clear currently selected clusters',
            icon='ban',
        )
        self.clear_selection_button.on_click(self.clear_selection)

        # new metacluster button
        self.new_metacluster_button = widgets.Button(
            description='New metacluster',
            disabled=False,
            button_style='success',
            tooltip='Create new metacluster from current selection',
            icon='plus',
        )
        self.new_metacluster_button.on_click(self.new_metacluster)

        # metacluster metadata
        self.current_metacluster = widgets.Dropdown(
            value=self.mcd.metaclusters.index[0],
            options=list(zip(self.mcd.metacluster_displaynames, self.mcd.metaclusters.index)),
            description='MetaCluster:',
        )
        self.current_metacluster.observe(
            self.update_current_metacluster_handler, type="change", names="value"
        )

        self.current_metacluster_displayname = widgets.Text(
            value=self.mcd.get_metacluster_displayname(self.current_metacluster.value),
            placeholder='Metacluster Displayname',
            description='Edit Name:',
            disabled=False,
        )
        self.current_metacluster_displayname.observe(
            self.update_current_metacluster_displayname,
            type="change",
            names="value"
        )

        # group widgets to look nice
        self.metacluster_info = widgets.VBox([
            self.current_metacluster,
            self.current_metacluster_displayname
        ])
        self.tools = widgets.HBox([
            self.zscore_clamp_slider,
            self.clear_selection_button,
            self.new_metacluster_button,
        ])
        self.toolbar = widgets.HBox([
            self.tools,
            self.metacluster_info
        ])
        self.toolbar.layout.justify_content = 'center'

        self.plot_output = widgets.Output()
        self.gui = widgets.VBox([self.plot_output, self.toolbar])

    def move_dendro_labels(self, ax, dendrosplit_ratio=1.8):
        """Overlay axis labels directly onto a scipy dendrogram

        Final image will use the ratio 1:dendrosplit_ratio
        for tree_region:labels_region

        Args:
            ax (matplotlib.axes.Axes):
                The axis containing the existing scipy dendrogram
            dendrosplit_ratio (float):
                How big to make the labels compared to the tree
        """
        def add_room_for_labels():
            # Extend the x-range past zero so there is an empty region for the labels.
            ax.set_axisbelow(False)
            xlim = ax.get_xlim()
            ax.set_xlim((xlim[0], -(xlim[0] * dendrosplit_ratio)))

        def stretch_dendro_leaves():
            # Every vertex sitting at x == 0 is moved to the new right-hand limit.
            # The limit does not change inside the loops, so it is read once.
            right_edge = ax.get_xlim()[1]
            for collection in ax.collections:
                for path in collection.get_paths():
                    for vertex in path.vertices:
                        if vertex[0] == 0:
                            vertex[0] = right_edge

        def get_ax_width_points(ax):
            bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted())
            return bbox.width * 72  # points = 1/72 in

        def move_ax_labels():
            # A negative pad pulls the tick labels into the axes, over the label region.
            dr = dendrosplit_ratio
            width = get_ax_width_points(ax)
            dedent = -(width * dr / (1 + dr))
            ax.yaxis.set_tick_params(which='both', pad=dedent)

        def restyle_ax_labels():
            # White outline keeps the text readable on top of the stretched leaf lines.
            for lb in ax.get_yticklabels():
                lb.set_path_effects([
                    path_effects.Stroke(linewidth=4, foreground='white'),
                    path_effects.Normal(),
                ])
                lb.set_family('monospace')
                lb.set_zorder(4)

        add_room_for_labels()
        stretch_dendro_leaves()
        move_ax_labels()
        restyle_ax_labels()

    @property
    def selection_mask(self):
        """Nested list of shape (1, cluster_count): 1 where the cluster is selected, else 0

        Clusters are listed in the order of ``self.mcd.clusters.index``.
        """
        return [[
            1 if cluster in self.selected_clusters else 0
            for cluster in self.mcd.clusters.index
        ]]

    def update_gui(self):
        """Update and redraw any updated GUI elements

        The selection strip is always refreshed. Everything else (heatmaps, colorbar
        ticks, metacluster color labels, pixel-count bars) is only recomputed when
        ``self._heatmaps_stale`` is set; the flag is cleared after a full repaint.
        """
        self.im_cs.set_data(self.selection_mask)
        self.im_cs.set_extent((0, self.mcd.cluster_count, 0, 1))

        if not self._heatmaps_stale:
            print("skipping other repaints")
            self.fig.canvas.draw()
            return

        def _preplot(df):
            # zscore each column, clip at the slider value, then transpose for plotting
            return df.apply(zscore).clip(upper=self.zscore_clamp_slider.value).T

        self.normalizer.calibrate(_preplot(self.mcd.clusters).values)

        # clusters heatmap
        self.im_c.set_data(_preplot(self.mcd.clusters))
        self.im_c.set_extent((0, self.mcd.cluster_count, 0, self.mcd.marker_count))
        self.im_c.set_clim(self.normalizer.vmin, self.normalizer.vmax)

        # metaclusters heatmap
        self.im_m.set_data(_preplot(self.mcd.metaclusters))
        self.im_m.set_extent((0, self.mcd.metacluster_count, 0, self.mcd.marker_count))
        self.im_m.set_clim(self.normalizer.vmin, self.normalizer.vmax)

        # retrieve the current value of the zscore sliders
        zscore_cap = self.zscore_clamp_slider.value

        # due to delays, a zscore_cap modulo of 1 also needs to be considered here
        # due to floating point error, allclose must be used
        if np.allclose(zscore_cap % 1, 0) or np.allclose(zscore_cap % 1, 1):
            new_ticks = np.arange(-zscore_cap, zscore_cap + 1)
        else:
            # fractional intervals are always in increments of 1/2
            new_ticks = np.arange(-zscore_cap + 0.5, zscore_cap - 0.5 + 1)
            new_ticks = np.insert(new_ticks, 0, -zscore_cap)
            new_ticks = np.append(new_ticks, zscore_cap)

        self.cb.ax.set_xticks(new_ticks)

        # xaxis metacluster color labels
        assert len(self.mcd.metaclusters.index) <= self.mcd.cluster_count, \
            "Can't support num metaclusters > cluster count"
        mc_cmap = distinct_cmap(self.mcd.cluster_count)  # metaclusters < clusters
        self.im_cl.set_data([self.mcd.clusters_with_metaclusters['metacluster']])
        self.im_cl.set_extent((0, self.mcd.cluster_count, 0, 1))
        self.im_cl.set_cmap(mc_cmap)

        self.ax_ml.set_xticks(np.arange(self.mcd.metacluster_count) + 0.5)
        self.ax_ml.set_xticklabels(self.mcd.metacluster_displaynames, rotation=90, fontsize=7)
        self.im_ml.set_data([self.mcd.metaclusters.index])
        self.im_ml.set_extent((0, self.mcd.metacluster_count, 0, 1))
        self.im_ml.set_cmap(mc_cmap)

        # xaxis pixelcount graphs
        # 1.65x headroom leaves space above the tallest bar for its rotated label
        ax_cp_ymax = max(self.mcd.cluster_pixelcounts['count']) * 1.65
        self.ax_cp.set_ylim(0, ax_cp_ymax)
        sorted_pixel_counts = self.mcd.clusters.join(self.mcd.cluster_pixelcounts)['count']
        for rect, h in zip(self.rects_cp, sorted_pixel_counts):
            rect.set_height(h)
        for label, y in zip(self.labels_cp, sorted_pixel_counts):
            text = str(y)
            label_y_spacing = ax_cp_ymax * 0.05
            label.set_y(y + label_y_spacing)
            label.set_text(text)

        self.fig.canvas.draw()
        self._heatmaps_stale = False

    def enable_debug_mode(self):
        """Display the debug output widget as part of the GUI

        This is used to route logging, output, and tracebacks that happen
        in any of the event handler callbacks.
        """
        self.fig.canvas.footer_visible = True
        DEBUG_VIEW.clear_output()
        DEBUG_VIEW.append_stdout("Debug mode started\n")
        display(DEBUG_VIEW)

    def remap_current_selection(self, metacluster):
        """Instruct the MetaClusterData to remap the selected clusters

        All selected clusters will be remapped to the metacluster id which is passed.
        The output mapping is saved afterwards, even when the selection is empty.

        Args:
            metacluster (int):
                metacluster id to map the current selection to
        """
        for cluster in self.selected_clusters:
            print('remapping', cluster, metacluster)
            self.mcd.remap(cluster, metacluster)
            # only flagged when at least one cluster was actually remapped
            self._heatmaps_stale = True
        self.mcd.save_output_mapping()

    @DEBUG_VIEW.capture(clear_output=False)
    def update_zscore(self, e):
        """Slider callback: mark the heatmaps stale and repaint. ``e`` is unused."""
        self._heatmaps_stale = True
        self.update_gui()

    @DEBUG_VIEW.capture(clear_output=False)
    def clear_selection(self, e):
        """Button callback: deselect every cluster and redraw. ``e`` is unused."""
        self.selected_clusters.clear()
        self.update_gui()

    @DEBUG_VIEW.capture(clear_output=False)
    def new_metacluster(self, e):
        """Button callback: create a metacluster and move the current selection into it.

        ``e`` is unused.
        """
        metacluster = self.mcd.new_metacluster()
        self.remap_current_selection(metacluster)
        self.update_current_metacluster(metacluster)
        self.update_gui()

    def update_current_metacluster_handler(self, t):
        """Dropdown observer: forward the new dropdown value (``t.new``)."""
        return self.update_current_metacluster(t.new)

    @DEBUG_VIEW.capture(clear_output=False)
    def update_current_metacluster(self, metacluster):
        """Point the dropdown and the name text box at ``metacluster``

        The dropdown options are rebuilt from the MetaClusterData first.

        Args:
            metacluster (int):
                metacluster id to show as the current one
        """
        self.current_metacluster.options = \
            list(zip(self.mcd.metacluster_displaynames, self.mcd.metaclusters.index))
        self.current_metacluster.value = metacluster
        self.current_metacluster_displayname.value = \
            self.mcd.get_metacluster_displayname(metacluster)

    @DEBUG_VIEW.capture(clear_output=False)
    def update_current_metacluster_displayname(self, t):
        """Text-box observer: rename the current metacluster to ``t.new`` and repaint

        The dropdown options are rebuilt so the new name shows up. The dropdown's own
        observer is detached while that happens and re-attached afterwards.
        """
        self.mcd.change_displayname(self.current_metacluster.value, t.new)

        old_current_metacluster = self.current_metacluster.value
        self.current_metacluster.unobserve(
            self.update_current_metacluster_handler, type="change", names="value"
        )
        try:
            self.current_metacluster.options = \
                list(zip(self.mcd.metacluster_displaynames, self.mcd.metaclusters.index))
            self.current_metacluster.value = old_current_metacluster
        finally:
            # Always re-attach, otherwise one failure here would leave the dropdown
            # permanently unobserved.
            self.current_metacluster.observe(
                self.update_current_metacluster_handler, type="change", names="value")

        self._heatmaps_stale = True
        self.update_gui()

    @DEBUG_VIEW.capture(clear_output=False)
    def onpick(self, e):
        """Route matplotlib pick events: button 1 selects, button 3 remaps

        Only pick events whose mouse event is a 'button_press_event' are handled.
        The GUI is redrawn afterwards.
        """
        # keep the last event around for inspection
        self.e = e
        if e.mouseevent.name != 'button_press_event':
            return

        if e.mouseevent.button == 1:
            self.onpick_select(e)
        elif e.mouseevent.button == 3:
            self.onpick_remap(e)

        self.update_gui()

    def onpick_select(self, e):
        """Update the selection for a button-1 pick

        A pick on the cluster heatmap or the selection strip toggles that one cluster.
        A pick on the metacluster heatmap, or on either metacluster color label strip,
        toggles the whole metacluster via select_metacluster().
        """
        # images are drawn with extent (0, column_count), so truncating xdata
        # gives the column index
        selected_ix = int(e.mouseevent.xdata)

        if e.artist in [self.im_c, self.im_cs]:
            selected_cluster = self.mcd.clusters.index[selected_ix]
            # Toggle selection
            if selected_cluster in self.selected_clusters:
                self.selected_clusters.remove(selected_cluster)
            else:
                self.selected_clusters.add(selected_cluster)
        elif e.artist in [self.im_m, self.im_ml]:
            self.select_metacluster(self.mcd.metaclusters.index[selected_ix])
        elif e.artist in [self.im_cl]:
            selected_cluster = self.mcd.clusters_with_metaclusters.index[selected_ix]
            metacluster = self.mcd.which_metacluster(cluster=selected_cluster)
            self.select_metacluster(metacluster)

    def select_metacluster(self, metacluster):
        """Make ``metacluster`` current and toggle selection of all its clusters

        If every cluster of the metacluster is already selected they are all
        deselected; otherwise they are all selected.

        Args:
            metacluster (int):
                metacluster id whose clusters should be toggled
        """
        self.update_current_metacluster(metacluster)
        clusters = self.mcd.cluster_in_metacluster(metacluster)
        # Toggle entire metacluster
        if all(c in self.selected_clusters for c in clusters):
            # remove whole metacluster
            self.selected_clusters.difference_update(clusters)
        else:
            # select whole metacluster
            self.selected_clusters.update(clusters)

    def onpick_remap(self, e):
        """Remap the current selection for a button-3 pick

        The target metacluster is the one picked directly, or the one that the
        picked cluster currently belongs to. Picks on an unrecognized artist are
        ignored.
        """
        selected_ix = int(e.mouseevent.xdata)
        metacluster = None

        if e.artist in [self.im_c, self.im_cs]:
            selected_cluster = self.mcd.clusters.index[selected_ix]
            metacluster = self.mcd.which_metacluster(cluster=selected_cluster)
        elif e.artist in [self.im_m, self.im_ml]:
            metacluster = self.mcd.metaclusters.index[selected_ix]
        elif e.artist in [self.im_cl]:
            selected_cluster = self.mcd.clusters_with_metaclusters.index[selected_ix]
            metacluster = self.mcd.which_metacluster(cluster=selected_cluster)

        if metacluster is None:
            # No target could be resolved; never remap the selection to None.
            return

        self.update_current_metacluster(metacluster)
        self.remap_current_selection(metacluster)