# Source code for sctop.visualization

import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Optional, List, Dict
from .utils import *
from .processing import *

def create_colorbar(data, label, colormap='rocket_r', ax=None):
    """
    Draw a standalone colorbar spanning the value range of ``data`` beside ``ax``.

    Parameters
    ----------
    data : array-like
        Values whose minimum and maximum define the colorbar limits. NaN
        entries are ignored when computing the limits.
    label : str
        Text drawn along the colorbar.
    colormap : str, default 'rocket_r'
        Name of a registered colormap ('rocket_r' is a seaborn colormap).
    ax : matplotlib.axes.Axes, optional
        Axes the colorbar is attached to; defaults to the current axes.

    Returns
    -------
    matplotlib.colors.Colormap
        The colormap used, so callers can reuse it (e.g. as a seaborn palette).
    """
    import numpy as np

    ax = ax if ax is not None else plt.gca()
    cmap = plt.get_cmap(colormap)

    # Builtin min()/max() yield NaN or a wrong bound depending on where NaNs
    # appear in the sequence; nanmin/nanmax ignore them.
    values = np.asarray(data, dtype=float)
    norm = plt.Normalize(np.nanmin(values), np.nanmax(values))

    scalarmap = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    # Kept for compatibility with matplotlib versions that required an array
    # on the mappable before a colorbar could be drawn from it.
    scalarmap.set_array([])
    # Draw on the figure that owns ``ax`` rather than pyplot's current figure.
    ax.figure.colorbar(scalarmap, label=label, ax=ax)
    return cmap
def plot_highest(projections, n=10, ax=None, color="olive", fontsize=40, **kwargs):
    """
    Plot a horizontal bar chart of the ``n`` highest-scoring rows.

    Rows are ranked by the first column of ``projections`` and passed to
    ``barh`` in ascending order, so the highest score is drawn at the top.
    The x-axis is fixed to [0, 1].

    Parameters
    ----------
    projections : pandas.DataFrame
        Scores to rank. Ranking uses only the first column, whose name becomes
        the x-axis label; the index name (or ``'Items'``) becomes the y-axis
        label. Every column of the selected rows is handed to ``barh``.
    n : int, default 10
        Maximum number of rows to show.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.
    color : str, default "olive"
        Bar colour.
    fontsize : int, default 40
        Font size for tick labels and axis labels.
    **kwargs
        Forwarded to ``DataFrame.plot.barh``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.
    """
    ax = ax if ax is not None else plt.gca()
    score_column = projections.columns[0]

    # sort_values puts NaN last by default, which would make NaN scores look
    # like the highest ones; sort them to the front instead. tail(n) replaces
    # iloc[-n:], which selects every row when n == 0.
    ranked = projections.sort_values(by=score_column, na_position='first')
    top_n = ranked.tail(n)
    top_n.plot.barh(ax=ax, color=color, legend=False, **kwargs)

    # --- Adjustments for presentation ---
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.set_xlabel(score_column, fontsize=fontsize)
    ax.set_ylabel(projections.index.name or 'Items', fontsize=fontsize)
    ax.set_xlim(0.0, 1.0)
    ax.grid(axis='x', linestyle='--', alpha=0.6)
    return ax
def plot_expression_distribution(scores, n=10, ax=None, box_color="skyblue", fontsize=30, **kwargs):
    """
    Plot boxplots of expression for the ``n`` genes with the highest median.

    Genes are ordered left to right by decreasing median. The y-axis range is
    chosen automatically; no fixed limits are applied.

    Parameters
    ----------
    scores : pandas.DataFrame
        Expression values with genes as rows (index) and observations as
        columns; each box summarises one gene's row.
    n : int, default 10
        Number of genes to show.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.
    box_color : str, default "skyblue"
        Fill colour of the boxes.
    fontsize : int, default 30
        Font size for tick labels and the y-axis label.
    **kwargs
        Forwarded to ``seaborn.boxplot``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.
    """
    ax = ax if ax is not None else plt.gca()

    medians = scores.median(axis=1).sort_values(ascending=False)
    top_genes = medians.head(n).index

    # Transpose so each gene is a column, then melt to long format
    # (one row per gene/value pair) as seaborn expects.
    long_data = scores.loc[top_genes].T.melt(var_name="", value_name="Expression")

    sns.boxplot(
        x="", y="Expression", data=long_data, ax=ax, color=box_color,
        order=top_genes.tolist(), showfliers=True, **kwargs
    )

    # Restyle the existing tick labels in place instead of re-setting them via
    # set_xticklabels, which pins a FixedFormatter on the categorical axis.
    ax.tick_params(axis='x', labelrotation=30, labelsize=fontsize)
    plt.setp(ax.get_xticklabels(), horizontalalignment='right')
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("Expression", fontsize=fontsize)
    return ax
def plot_two(projections, celltype1, celltype2, gene=None, gene_expressions=None, ax=None, **kwargs):
    """
    Scatter the projections onto two cell types against each other.

    Each point corresponds to one column of ``projections``: its x coordinate
    comes from the ``celltype1`` row and its y coordinate from the
    ``celltype2`` row. When ``gene`` is given, points are coloured by that
    gene's row in ``gene_expressions`` and a colorbar is shown instead of
    seaborn's hue legend.

    Parameters
    ----------
    projections : pandas.DataFrame
        Must contain rows labelled ``celltype1`` and ``celltype2``.
    celltype1, celltype2 : hashable
        Row labels used for the x and y axes.
    gene : hashable, optional
        Row label in ``gene_expressions`` used to colour the points.
    gene_expressions : pandas.DataFrame, optional
        Required when ``gene`` is given.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.
    **kwargs
        Forwarded to ``seaborn.scatterplot``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on, consistent with the other plotting helpers.

    Raises
    ------
    ValueError
        If ``gene`` is given without ``gene_expressions``.
    """
    ax = ax if ax is not None else plt.gca()
    x_values = projections.loc[celltype1]
    y_values = projections.loc[celltype2]

    # Explicit None check so falsy labels such as 0 are not silently ignored.
    if gene is None:
        sns.scatterplot(x=x_values, y=y_values, alpha=0.5, ax=ax, **kwargs)
        return ax

    if gene_expressions is None:
        raise ValueError("gene_expressions must be provided when gene is set")

    expression = gene_expressions.loc[gene]
    palette = create_colorbar(expression, '{} expression'.format(gene), ax=ax)
    sns.scatterplot(
        x=x_values, y=y_values, hue=expression, palette=palette,
        alpha=0.5, ax=ax, **kwargs
    )
    # The colorbar already encodes the hue, so drop seaborn's legend. The
    # legend is absent when the caller passes legend=False.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    return ax
def _draw_contribution_scatter(ax, mean_expression, mean_contribution, top_genes,
                               highlight, cell_type, sample_name, plot_kwargs):
    """Draw one sample's mean-expression vs mean-contribution scatter on ``ax``."""
    annotation_size = plot_kwargs.get('fontsize_annotations', 18)

    # Background layer: every gene in grey.
    ax.scatter(mean_expression, mean_contribution,
               color='gray', alpha=0.6, label='All Genes', s=3)

    # Top genes in blue, each labelled with its name.
    ax.scatter(mean_expression[top_genes], mean_contribution[top_genes],
               color='blue', label=f'Top {len(top_genes)} Genes', s=10)
    for gene in top_genes:
        ax.text(mean_expression[gene], mean_contribution[gene], gene,
                fontsize=annotation_size)

    # Caller-requested genes in bold red; genes absent from the data are skipped.
    for gene in highlight or ():
        if gene in mean_expression.index:
            x = mean_expression[gene]
            y = mean_contribution[gene]
            ax.scatter(x, y, color='red', s=20, zorder=10)
            ax.text(x, y, gene, fontsize=annotation_size, color='red', weight='bold')

    fontsize_labels = plot_kwargs.get('fontsize_labels', 20)
    ax.set_xlabel('Mean Gene Expression', fontsize=fontsize_labels)
    ax.set_ylabel(f'Mean Contribution to {cell_type} Score', fontsize=fontsize_labels)
    ax.set_title(f'{sample_name}', fontsize=plot_kwargs.get('fontsize_title', 30))
    ax.legend(fontsize=plot_kwargs.get('fontsize_legend', 14))
    ax.grid(True, alpha=0.3)


def plot_all_contributions(
    results: Dict[str, Dict],
    sample_names: List[str],
    output_dir: Optional[str] = None,
    highlight_genes: Optional[Dict[str, List[str]]] = None,
    dpi: int = 150,
    **plot_kwargs
) -> None:
    """
    Create and save gene-contribution scatter plots for all cell types and samples.

    For each cell type in ``results``, the directory
    ``<output_dir>/gene_contributions_<cell_type>`` is created. For each
    requested sample, ``<sample_name>.png`` is written into it. Each plot shows,
    per gene, the mean expression (x) against the mean contribution to the
    cell-type score (y), both averaged across columns.

    Parameters
    ----------
    results : dict
        Results from analyze_sample_contributions, keyed by cell type. Each
        value must provide ``'expressions'``, ``'contributions'`` and
        ``'top_genes'`` mappings keyed by sample name.
    sample_names : list
        Samples to plot. Samples missing from a cell type's ``'expressions'``
        are skipped with a printed warning.
    output_dir : str, optional
        Base directory for saving plots. If None, uses the current directory.
    highlight_genes : dict, optional
        Mapping cell_type -> genes to draw in bold red. Genes absent from the
        data are ignored.
    dpi : int
        DPI for saved images.
    **plot_kwargs
        Optional styling keys: ``figsize`` (default (15, 8)),
        ``fontsize_annotations`` (18, used for all gene labels),
        ``fontsize_labels`` (20), ``fontsize_title`` (30),
        ``fontsize_legend`` (14), and ``show`` (default True; set False to
        skip ``plt.show()`` when batch-saving). Other keys are ignored.
    """
    if output_dir is None:
        output_dir = "."
    figsize = plot_kwargs.get('figsize', (15, 8))
    show = plot_kwargs.get('show', True)

    for cell_type, cell_data in results.items():
        cell_type_dir = os.path.join(output_dir, f"gene_contributions_{cell_type}")
        os.makedirs(cell_type_dir, exist_ok=True)
        print(f"Creating plots for {cell_type}...")

        highlight = None
        if highlight_genes and cell_type in highlight_genes:
            highlight = highlight_genes[cell_type]

        for sample_name in sample_names:
            if sample_name not in cell_data['expressions']:
                print(f" Warning: {sample_name} not found in results, skipping")
                continue

            expression = cell_data['expressions'][sample_name]
            contributions = cell_data['contributions'][sample_name]
            top_genes = cell_data['top_genes'][sample_name]

            # Restrict expression to the genes that have contributions so both
            # per-gene means share the same index.
            mean_expression = expression.loc[contributions.index].mean(axis=1)
            mean_contribution = contributions.mean(axis=1)

            fig, ax = plt.subplots(figsize=figsize)
            try:
                _draw_contribution_scatter(
                    ax, mean_expression, mean_contribution, top_genes,
                    highlight, cell_type, sample_name, plot_kwargs,
                )
                filename = os.path.join(cell_type_dir, f"{sample_name}.png")
                fig.savefig(filename, dpi=dpi, bbox_inches='tight')
                if show:
                    plt.show()
            finally:
                # Release the figure even if drawing or saving fails, so long
                # batches do not accumulate open figures.
                plt.close(fig)

        print(f" Saved plots to {cell_type_dir}/")