Source code for ark.phenotyping.pixel_som_clustering

import multiprocessing
import os
from functools import partial
from shutil import move, rmtree
from typing import Any, Callable, Tuple

import feather
from alpineer import io_utils, misc_utils
from pyarrow.lib import ArrowInvalid

from ark.phenotyping import cluster_helpers, pixel_cluster_utils

multiprocessing.set_start_method('spawn', force=True)


[docs]def train_pixel_som(fovs, channels, base_dir, subset_dir='pixel_mat_subsetted', norm_vals_name='post_rowsum_chan_norm.feather', som_weights_name='pixel_som_weights.feather', xdim=10, ydim=10, lr_start=0.05, lr_end=0.01, num_passes=1, seed=42, overwrite=False): """Run the SOM training on the subsetted pixel data. Saves SOM weights to `base_dir/som_weights_name`. Args: fovs (list): The list of fovs to subset on channels (list): The list of markers to subset on base_dir (str): The path to the data directories subset_dir (str): The name of the subsetted data directory norm_vals_name (str): The name of the file to store the 99.9% normalization values som_weights_name (str): The name of the file to save the SOM weights to xdim (int): The number of x nodes to use for the SOM ydim (int): The number of y nodes to use for the SOM lr_start (float): The start learning rate for the SOM, decays to `lr_end` lr_end (float): The end learning rate for the SOM, decays from `lr_start` num_passes (int): The number of training passes to make through the dataset seed (int): The random seed to use for training the SOM overwrite (bool): If set, force retrains the SOM and overwrites the weights Returns: cluster_helpers.PixelSOMCluster: The SOM cluster object containing the pixel SOM weights """ # define the paths to the data subsetted_path = os.path.join(base_dir, subset_dir) norm_vals_path = os.path.join(base_dir, norm_vals_name) som_weights_path = os.path.join(base_dir, som_weights_name) # file path validation # NOTE: weights may or may not exist, that logic gets handled by PixelSOMCluster io_utils.validate_paths([subsetted_path, norm_vals_path]) # verify that all provided fovs exist in the folder files = io_utils.list_files(subsetted_path, substrs='.feather') misc_utils.verify_in_list(provided_fovs=fovs, subsetted_fovs=io_utils.remove_file_extensions(files)) # verify that all the provided channels exist in subsetted data sample_sub = feather.read_dataframe(os.path.join(subsetted_path, files[0])) misc_utils.verify_in_list(provided_channels=channels, subsetted_channels=sample_sub.columns.values) # define the pixel SOM cluster object pixel_pysom = cluster_helpers.PixelSOMCluster( subsetted_path, norm_vals_path, som_weights_path, fovs, channels, num_passes=num_passes, xdim=xdim, ydim=ydim, lr_start=lr_start, lr_end=lr_end, seed=seed ) # train the SOM weights # NOTE: seed has to be set in cyFlowSOM.pyx, done by passing flag in PixieSOMCluster print("Training SOM") pixel_pysom.train_som(overwrite=overwrite) return pixel_pysom
[docs]def run_pixel_som_assignment(pixel_data_path, pixel_pysom_obj, overwrite, num_parallel_pixels, fov): """Helper function to assign pixel SOM cluster labels Args: pixel_data_path (str): The path to the pixel data directory pixel_pysom_obj (ark.phenotyping.cluster_helpers.PixieConsensusCluster): The pixel SOM cluster object overwrite (bool): Whether to overwrite the pixel SOM clusters or not num_parallel_pixels (int): How many pixels to label in parallel at once for each FOV fov (str): The name of the FOV to process Returns: tuple (str, int): The name of the FOV as well as the return code """ # get the path to the fov fov_path = os.path.join(pixel_data_path, fov + '.feather') # read in the fov data with SOM labels try: fov_data = feather.read_dataframe(fov_path) # this indicates this fov file is corrupted except (ArrowInvalid, OSError, IOError): return fov, 1 # if the overwrite flag was set in cluster_pixels, drop the pixel_som_cluster column if overwrite: fov_data = fov_data.drop(columns="pixel_som_cluster", errors="ignore") # assign the SOM labels to fov_data, overwrite flag indicates if data needs normalization fov_data = pixel_pysom_obj.assign_som_clusters( fov_data, normalize_data=not overwrite, num_parallel_pixels=num_parallel_pixels ) # resave the data with the SOM cluster labels assigned temp_path = os.path.join(pixel_data_path + '_temp', fov + '.feather') feather.write_dataframe(fov_data, temp_path, compression='uncompressed') return fov, 0
[docs]def cluster_pixels(fovs, base_dir, pixel_pysom, data_dir='pixel_mat_data', multiprocess=False, batch_size=5, num_parallel_pixels=1000000, overwrite=False): """Uses trained SOM weights to assign cluster labels on full pixel data. Saves data with cluster labels to `data_dir`. Args: fovs (list): The list of fovs to subset on base_dir (str): The path to the data directory pixel_pysom (cluster_helpers.PixelSOMCluster): The SOM cluster object containing the pixel SOM weights data_dir (str): Name of the directory which contains the full preprocessed pixel data multiprocess (bool): Whether to use multiprocessing or not batch_size (int): The number of FOVs to process in parallel, ignored if `multiprocess` is `False` num_parallel_pixels (int): How many pixels to label in parallel at once for each FOV overwrite (bool): If set, force overwrite the SOM labels in all the FOVs """ # define the paths to the data data_path = os.path.join(base_dir, data_dir) # path validation io_utils.validate_paths([data_path]) # raise error if weights haven't been assigned to pixel_pysom if pixel_pysom.weights is None: raise ValueError("Using untrained pixel_pysom object, please invoke train_pixel_som first") # verify that all provided fovs exist in the folder # NOTE: remove the channel and pixel normalization files as those are not pixel data data_files = io_utils.list_files(data_path, substrs='.feather') misc_utils.verify_in_list(provided_fovs=fovs, subsetted_fovs=io_utils.remove_file_extensions(data_files)) # this will prevent reading in a corrupted sample_fov i = 0 while i < len(data_files): try: sample_fov = feather.read_dataframe(os.path.join(base_dir, data_dir, data_files[i])) if "segmentation_label" in sample_fov.columns: sample_fov.rename( columns={"segmentation_label": "label"}, inplace=True) except (ArrowInvalid, OSError, IOError): i += 1 continue break # for verification purposes, drop the metadata columns cols_to_drop = ['fov', 'row_index', 'column_index'] for col in ['label', 'pixel_som_cluster', 'pixel_meta_cluster', 'pixel_meta_cluster_rename']: if col in sample_fov.columns.values: cols_to_drop.append(col) sample_fov = sample_fov.drop( columns=cols_to_drop ) misc_utils.verify_same_elements( enforce_order=True, norm_vals_columns=pixel_pysom.norm_data.columns.values, pixel_data_columns=sample_fov.columns.values ) # ensure the SOM weights columns are valid indexes misc_utils.verify_same_elements( enforce_order=True, pixel_som_weights_columns=pixel_pysom.weights.columns.values, pixel_data_columns=sample_fov.columns.values ) # if overwrite flag set, run on all FOVs in data_dir, make sure to reset SOM clusters seen if overwrite: print('Overwrite flag set, reassigning SOM cluster labels to all FOVs') pixel_pysom.som_clusters_seen = set() os.mkdir(data_path + '_temp') fovs_list = io_utils.remove_file_extensions( io_utils.list_files(data_path, substrs='.feather') ) # otherwise, only assign SOM clusters to FOVs that don't already have them else: fovs_list = pixel_cluster_utils.find_fovs_missing_col( base_dir, data_dir, 'pixel_som_cluster' ) # make sure fovs_list only contain fovs that exist in the master fovs list specified fovs_list = list(set(fovs_list).intersection(fovs)) # if there are no FOVs left without SOM labels don't run function if len(fovs_list) == 0: print("There are no more FOVs to assign SOM labels to, skipping") return # if SOM cluster labeling is only partially complete, inform the user of restart if len(fovs_list) < len(fovs): print("Restarting SOM label assignment from fov %s, " "%d fovs left to process" % (fovs_list[0], len(fovs_list))) # define variable to keep track of number of fovs processed fovs_processed = 0 # define the partial function to iterate over fov_data_func = partial( run_pixel_som_assignment, data_path, pixel_pysom, overwrite, num_parallel_pixels ) # use the som weights to assign SOM cluster values to data in data_dir print("Mapping pixel data to SOM cluster labels") if multiprocess: with multiprocessing.get_context('spawn').Pool(batch_size) as fov_data_pool: for fov_batch in [fovs_list[i:(i + batch_size)] for i in range(0, len(fovs_list), batch_size)]: fov_statuses = fov_data_pool.map(fov_data_func, fov_batch) for fs in fov_statuses: if fs[1] == 1: print("The data for FOV %s has been corrupted, skipping" % fs[0]) fovs_processed -= 1 # update number of fovs processed fovs_processed += len(fov_batch) print("Processed %d fovs" % fovs_processed) else: for fov in fovs_list: fov_status = fov_data_func(fov) if fov_status[1] == 1: print("The data for FOV %s has been corrupted, skipping" % fov_status[0]) fovs_processed -= 1 # update number of fovs processed fovs_processed += 1 # update every 10 FOVs, or at the very end if fovs_processed % 10 == 0 or fovs_processed == len(fovs_list): print("Processed %d fovs" % fovs_processed) # remove the data directory and rename the temp directory to the data directory rmtree(data_path, onerror=_ignore_extended_attributes) move(data_path + '_temp', data_path)
def _ignore_extended_attributes(func: Callable, filename: str, exc_info: Tuple[Any, Any, Any]): """ Ignore the extended attribute files (prefixed with "._"). Read more here: https://tinyurl.com/extended-attributes Args: func (Callable): The function which raises the exception. filename (str): The file where an extended attribute file fails to remove. This originally gets passed into `func`. exc_info (OptExcInfo): The exception information returned by `sys.exec_info()`. """ is_meta_file: bool = os.path.basename(filename).startswith("._") if not (func is os.unlink and is_meta_file): raise
[docs]def generate_som_avg_files(fovs, channels, base_dir, pixel_pysom, data_dir='pixel_data_dir', pc_chan_avg_som_cluster_name='pixel_channel_avg_som_cluster.csv', num_fovs_subset=100, require_all_som_clusters=True, seed=42, overwrite=False): """Computes and saves the average channel expression across pixel SOM clusters. Args: fovs (list): The list of fovs to subset on channels (list): The list of channels to subset on base_dir (str): The path to the data directory pixel_pysom (cluster_helpers.PixelSOMCluster): The SOM cluster object containing the pixel SOM weights data_dir (str): Name of the directory which contains the full preprocessed pixel data pc_chan_avg_som_cluster_name (str): The name of the file to save the average channel expression across all SOM clusters num_fovs_subset (int): The number of FOVs to subset on for SOM cluster channel averaging require_all_som_clusters (bool): Whether to require all SOM clusters to have at least one pixel assigned seed (int): The random seed to set for subsetting FOVs overwrite (bool): If set, force overwrite the existing average channel expression file if it exists """ # define the paths to the data som_cluster_avg_path = os.path.join(base_dir, pc_chan_avg_som_cluster_name) # raise error if weights haven't been assigned to pixel_pysom if pixel_pysom.weights is None: raise ValueError("Using untrained pixel_pysom object, please invoke train_som first") # if the channel SOM average file already exists and the overwrite flag isn't set, skip if os.path.exists(som_cluster_avg_path): if not overwrite: print("Already generated SOM cluster channel average file, skipping") return print("Overwrite flag set, regenerating SOM cluster channel average file") # compute average channel expression for each pixel SOM cluster # and the number of pixels per SOM cluster print("Computing average channel expression across pixel SOM clusters") pixel_channel_avg_som_cluster = pixel_cluster_utils.compute_pixel_cluster_channel_avg( fovs, channels, base_dir, 'pixel_som_cluster', len(pixel_pysom.som_clusters_seen) if require_all_som_clusters else None, data_dir, num_fovs_subset=num_fovs_subset, seed=seed, keep_count=True ) # save pixel_channel_avg_som_cluster pixel_channel_avg_som_cluster.to_csv( som_cluster_avg_path, index=False )