Source code for vega.plotting

import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib as mpl
from matplotlib.lines import Line2D
import seaborn as sns
import numpy as np
import pandas as pd
from scanpy.plotting import embedding
from scanpy import settings
from anndata import AnnData
from adjustText import adjust_text
from vega import VEGA
from typing import Union

def volcano(adata: AnnData,
            group1: str,
            group2: str,
            sig_lvl: float = 3.,
            metric_lvl: float = 3.,
            annotate_gmv: Union[str,list] = None,
            s:int = 10,
            fontsize: int = 10,
            textsize: int = 8,
            figsize: Union[tuple,list] = None,
            title: str = False,
            save: Union[str,bool] = False):
    """
    Plot Differential GMV results.
    Please run the Bayesian differential activity method of VEGA before plotting ("model.differential_activity()")

    Parameters
    ----------
    adata
        scanpy single-cell object
    group1
        name of reference group
    group2
        name of out-group
    sig_lvl
        absolute Bayes Factor cutoff (>=0)
    metric_lvl
        mean Absolute Difference cutoff (>=0)
    annotate_gmv
        GMV(s) to be displayed, as a single name or a list of names.
        If None (or empty), all GMVs passing significance thresholds are displayed
    s
        dot size (significant GMVs are drawn at twice this size)
    fontsize
        text size for axis
    textsize
        text size for GMV name display
    figsize
        figure size forwarded to ``plt.subplots``
    title
        title for plot
    save
        path to save figure. The output format is taken from the file extension

    Raises
    ------
    ValueError
        If the Anndata object holds no VEGA differential results, if the requested
        comparison does not exist, or if a name in ``annotate_gmv`` is not a known GMV.
    """
    # Check if Anndata is setup correctly
    if '_vega' not in adata.uns.keys():
        raise ValueError('Anndata not setup. Please setup the Anndata object and train the VEGA model.')
    if 'differential' not in adata.uns['_vega'].keys():
        raise ValueError('No differential activity results found in Anndata. Please run model.differential_activity()')
    # Check if group exists
    key_comp = group1 + ' vs.' + group2
    if key_comp not in adata.uns['_vega']['differential'].keys():
        raise ValueError('Group(s) not found. \nAvailable comparisons:{}'.format(list(adata.uns['_vega']['differential'].keys())))
    dfe_res = adata.uns['_vega']['differential'][key_comp]
    bayes_factor = dfe_res['bayes_factor']
    mad = np.abs(dfe_res['differential_metric'])
    xlim_v = np.abs(bayes_factor).max() + 0.5
    ylim_v = mad.max()+0.5
    idx_sig = np.arange(len(bayes_factor))[(np.abs(bayes_factor)>sig_lvl) & (mad>metric_lvl)]

    # Resolve which GMVs get a text label *before* drawing anything, so that a bad
    # name fails early instead of leaving a half-built figure behind.
    gmv_names = list(adata.uns['_vega']['gmv_names'])
    if not annotate_gmv:
        labelled = [(i, gmv_names[i]) for i in idx_sig]
    else:
        # A bare string is one GMV name, not a sequence of characters.
        if isinstance(annotate_gmv, str):
            annotate_gmv = [annotate_gmv]
        unknown = [name for name in annotate_gmv if name not in gmv_names]
        if unknown:
            raise ValueError('GMV name(s) not found in Anndata: {}'.format(unknown))
        labelled = [(gmv_names.index(name), name) for name in annotate_gmv]

    # Plotting
    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(bayes_factor, mad,
               color='grey', s=s, alpha=0.8, linewidth=0,
               rasterized=settings._vector_friendly)
    ax.scatter(bayes_factor[idx_sig], mad[idx_sig],
               color='red', s=s*2, linewidth=0,
               rasterized=settings._vector_friendly)
    # Dashed guides marking the significance thresholds
    ax.vlines(x=-sig_lvl, ymin=-0.5, ymax=ylim_v, color='black', linestyles='--', linewidth=1., alpha=0.2)
    ax.vlines(x=sig_lvl, ymin=-0.5, ymax=ylim_v, color='black', linestyles='--', linewidth=1., alpha=0.2)
    ax.hlines(y=metric_lvl, xmin=-xlim_v, xmax=xlim_v, color='black', linestyles='--', linewidth=1., alpha=0.2)
    # Draw labels on the axes created above rather than on pyplot's implicit current axes
    texts = [
        ax.text(bayes_factor[i], mad[i], name, fontdict={'size':textsize})
        for i, name in labelled
    ]
    ax.set_xlabel(r'$\log_e$(Bayes factor)', fontsize=fontsize)
    ax.set_ylabel('|Differential Metric|', fontsize=fontsize)
    ax.set_ylim([0,ylim_v])
    ax.set_xlim([-xlim_v,xlim_v])
    if title:
        ax.set_title(title, fontsize=fontsize)
    # Spread overlapping labels and connect them to their dots
    adjust_text(texts, only_move={'texts':'xy'}, arrowprops=dict(arrowstyle="-", color='k', lw=0.5))
    ax.tick_params(axis="x", labelsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize)
    ax.grid(False)
    if save:
        fig.savefig(save, format=save.split('.')[-1], dpi=rcParams['savefig.dpi'], bbox_inches='tight')
    plt.show()
def gmv_embedding(adata: AnnData,
                  x: str,
                  y: str,
                  color: str = None,
                  palette: str = None,
                  title: str = None,
                  save: Union[str,bool] = False,
                  sct_kwds: dict = None):
    """
    Draw single cells as a 2-D scatter in GMV space.

    Parameters
    ----------
    adata
        scanpy single-cell object. Must hold the latent coordinates in ``adata.obsm['X_vega']``
    x
        GMV name used for the x-axis (eg. 'REACTOME_INTERFERON_SIGNALING')
    y
        GMV name used for the y-axis (eg. 'REACTOME_INTERFERON_SIGNALING')
    color
        name used to color single-cells (resolved by ``_get_color_values``)
    palette
        palette / colormap name forwarded to ``_get_color_values``
    title
        plot title
    save
        path to save plot (written as pdf)
    sct_kwds
        kwargs for matplotlib.pyplot.scatter function
    """
    if 'X_vega' not in adata.obsm.keys():
        raise ValueError("No GMV coordinates found in Anndata. Run 'adata.obsm['X_vega'] = model.to_latent()'")
    # Every requested name has to be known to the Anndata object
    if color:
        found = [_check_exist(adata, x), _check_exist(adata, y), _check_exist(adata, color)]
        if not all(found):
            raise ValueError("At least one of passed (x, y, color) names not found in Anndata. Make sure those names exist.")
    else:
        found = [_check_exist(adata, x), _check_exist(adata, y)]
        if not all(found):
            raise ValueError("At least one of passed (x, y) names not found in Anndata. Make sure those names exist.")
    # Column positions of the two GMVs inside the latent matrix
    col_x = list(adata.uns['_vega']['gmv_names']).index(x)
    col_y = list(adata.uns['_vega']['gmv_names']).index(y)
    xs = adata.obsm['X_vega'][:,col_x]
    ys = adata.obsm['X_vega'][:,col_y]
    point_colors = _get_color_values(adata, color, palette)
    # Work on a copy so the caller's dict is never touched
    if sct_kwds is None:
        scatter_args = {}
    else:
        scatter_args = sct_kwds.copy()
    plt.scatter(x=xs, y=ys, c=point_colors, **scatter_args)
    plt.xlabel(x)
    plt.ylabel(y)
    if title:
        plt.title(title)
    if save:
        plt.savefig(save, format='pdf', dpi=150, bbox_inches='tight')
    plt.show()
#def gmv_dotplot():
    #return

def _check_exist(adata, x):
    """
    Check if dimension exist in Anndata.

    Returns True when ``x`` is a column of ``adata.obs``, a column of ``adata.var``
    or one of the names in ``adata.uns['_vega']['gmv_names']``; False otherwise.
    """
    return bool(
        (x in list(adata.obs))
        or (x in list(adata.var))
        or (x in list(adata.uns['_vega']['gmv_names']))
    )


def _continuous_colors(values, palette):
    """
    Map numeric values to RGBA colors, rescaling them to [0, 1] first.

    A colormap called on raw data clips floats to [0, 1] and treats integers as
    lookup-table indices, so the values are normalised to their own min/max before
    the lookup. Non-finite entries are masked.
    """
    if not palette:
        palette = 'viridis'
    cmap = plt.get_cmap(palette)
    masked = np.ma.masked_invalid(np.asarray(values, dtype=float))
    return cmap(mpl.colors.Normalize()(masked))


def _get_color_values(adata, var, palette):
    """
    Value to color. TODO: Add support for gene variable.

    Parameters
    ----------
    adata
        object exposing ``obs``, ``obsm['X_vega']`` and ``uns['_vega']['gmv_names']``
    var
        GMV name or ``adata.obs`` column to color by. If empty/None, a single
        neutral color is returned
    palette
        palette name. Defaults to 'viridis' for continuous values and 'tab10'
        for categorical ones

    Returns
    -------
    The string "lightgray", an array of RGBA colors (continuous values) or a list
    of per-cell colors (categorical values).
    """
    if not var:
        return "lightgray"
    gmv_names = list(adata.uns['_vega']['gmv_names'])
    if var in gmv_names:
        # Color by the activity of one GMV
        return _continuous_colors(adata.obsm['X_vega'][:, gmv_names.index(var)], palette)
    column = adata.obs[var]
    is_categorical = isinstance(column.dtype, pd.CategoricalDtype)
    if is_categorical or not pd.api.types.is_numeric_dtype(column):
        # One color per distinct label; also covers plain string columns,
        # which cannot be sent through a continuous colormap.
        if not palette:
            palette = 'tab10'
        labels = column.unique()
        color_map = dict(zip(labels, sns.color_palette(palette, len(labels))))
        return [color_map[k] for k in column]
    return _continuous_colors(column, palette)
def gmv_plot(adata: AnnData,
             x: str,
             y: str,
             color: str = None,
             title: str = None,
             palette: str = None):
    """
    GMV embedding plot drawn through the Scanpy plotting API.

    Parameters
    ----------
    adata
        scanpy single-cell dataset. Must hold the latent coordinates in ``adata.obsm['X_vega']``
    x
        GMV name for x-coordinates (eg. 'REACTOME_INTERFERON_SIGNALING')
    y
        GMV name for y-coordinates (eg. 'REACTOME_INTERFERON_SIGNALING')
    color
        .obs field to color by
    title
        title for the plot
    palette
        matplotlib colormap to be used
    """
    if 'X_vega' not in adata.obsm.keys():
        raise ValueError("No GMV coordinates found in Anndata. Run 'adata.obsm['X_vega'] = model.to_latent()'")
    # Every requested name has to be known to the Anndata object
    requested = (x, y, color) if color else (x, y)
    found = [_check_exist(adata, name) for name in requested]
    if not all(found):
        if color:
            raise ValueError("At least one of passed (x, y, color) names not found in Anndata. Make sure those names exist.")
        raise ValueError("At least one of passed (x, y) names not found in Anndata. Make sure those names exist.")
    # Scanpy numbers components from 1, hence the +1 on the list positions
    comp_x = list(adata.uns['_vega']['gmv_names']).index(x)+1
    comp_y = list(adata.uns['_vega']['gmv_names']).index(y)+1
    # Let Scanpy build the figure, then relabel its current axes with the GMV names
    scanpy_fig = embedding(adata,
                           basis='X_vega',
                           color=color,
                           components=[comp_x, comp_y],
                           title=title,
                           palette=palette,
                           return_fig=True,
                           show=False)
    axes = scanpy_fig.gca()
    axes.set_xlabel(x)
    axes.set_ylabel(y)
    plt.show()
def loss(model: VEGA,
         plot_validation: bool = True):
    """
    Draw the per-epoch training loss, optionally with the validation loss.

    Parameters
    ----------
    model
        VEGA model (trained). Its ``epoch_history`` is read for the 'train_loss'
        (and, if requested, 'valid_loss') series
    plot_validation
        Whether to plot validation loss as well
    """
    training = model.epoch_history['train_loss']
    # One x position per recorded training epoch
    n_epochs = len(training)
    plt.plot(np.arange(n_epochs), training, label='Training loss', color='blue')
    if plot_validation:
        validation = model.epoch_history['valid_loss']
        plt.plot(np.arange(n_epochs), validation, label='Validation loss', color='orange')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()
def rank_gene_weights(model: VEGA,
                      gmv_list: Union[str,list],
                      n_genes: int = 10,
                      color_in_set: bool = True,
                      n_panels_per_row: int = 3,
                      fontsize: int = 8,
                      star_names: list = None,
                      save: Union[bool,str] = False):
    """
    Plot gene members of input GMVs according to their magnitude (abs(w)).
    Inspired by scanpy.pl.rank_gene_groups() API.

    Parameters
    ----------
    model
        VEGA trained model
    gmv_list
        GMV name, or list of GMV names (one panel is drawn per GMV)
    n_genes
        number of top gene to display
    color_in_set
        Whether to color genes annotated as part of GMVs differently.
    n_panels_per_row
        number of panels max. per row
    fontsize
        text size for gene names
    star_names
        Name of genes to be highlighted with stars. None means no gene is starred
    save
        path to save figure. The output format is taken from the file extension

    Raises
    ------
    ValueError
        If the model is not trained, ``gmv_list`` is empty, ``n_panels_per_row`` is
        smaller than 1, or a requested GMV name is unknown.
    """
    if not model.is_trained_:
        raise ValueError('Model is not trained. Please train the model before.')
    # A bare string is one GMV name, not a sequence of characters.
    if isinstance(gmv_list, str):
        gmv_list = [gmv_list]
    if len(gmv_list) == 0:
        raise ValueError('gmv_list is empty. Please provide at least one GMV name.')
    if n_panels_per_row < 1:
        raise ValueError('n_panels_per_row must be at least 1.')
    star_names = [] if star_names is None else star_names

    w = model.decoder._get_weights().data
    gmv_names = list(model.adata.uns['_vega']['gmv_names'])
    unknown = [k for k in gmv_list if k not in gmv_names]
    if unknown:
        raise ValueError('GMV name(s) not found in model: {}'.format(unknown))
    gene_names = np.array(model.adata.var.index.tolist())
    mask = model.adata.uns['_vega']['mask']

    # Panel grid: as many columns as allowed, as many rows as needed
    n_panelx = min(n_panels_per_row, len(gmv_list))
    n_panely = int(np.ceil(len(gmv_list) / n_panelx))
    from matplotlib import gridspec
    fig = plt.figure(
        figsize=(
            n_panelx * rcParams['figure.figsize'][0],
            n_panely * rcParams['figure.figsize'][1],
        )
    )
    gs = gridspec.GridSpec(nrows=n_panely, ncols=n_panelx, wspace=0.22, hspace=0.3)

    for l, k in enumerate(gmv_list):
        # Get values: rank genes of this GMV by decreasing absolute weight
        i = gmv_names.index(k)
        # .cpu() keeps this working when the model lives on an accelerator
        w_i = w[:,i].detach().cpu().numpy()
        sort_idx = np.argsort(np.abs(w_i))[::-1]
        abs_w = np.abs(w_i)[sort_idx][:n_genes]
        genes = gene_names[sort_idx][:n_genes]
        # Set plot params; extra headroom on top leaves space for the vertical labels
        ymin = np.min(abs_w)
        ymax = np.max(abs_w)
        ymax += 0.3*(ymax - ymin)
        ax = fig.add_subplot(gs[l])
        ax.set_ylim(ymin, ymax)
        ax.set_xlim(-0.9, n_genes - 0.1)
        for ig, gene_name in enumerate(genes):
            if color_in_set:
                # Genes outside the GMV's annotated set are drawn in red
                in_set = bool(mask[sort_idx[ig],i])
                col = 'black' if in_set else 'red'
            else:
                col = 'black'
            label = gene_name + '*' if gene_name in star_names else gene_name
            ax.text(
                ig,
                abs_w[ig],
                label,
                rotation='vertical',
                verticalalignment='bottom',
                horizontalalignment='center',
                fontsize=fontsize,
                color=col
            )
        ax.set_title('{}'.format(k))
        # Only the bottom row gets an x label, only the first column a y label
        if l >= n_panelx * (n_panely - 1):
            ax.set_xlabel('ranking')
        if l % n_panelx == 0:
            ax.set_ylabel('Weight magnitude')
    if color_in_set:
        leg = [Line2D([0], [0], marker='o', color='w', label='In set', markerfacecolor='black', markersize=5),
               Line2D([0], [0], marker='o', color='w', label='Not in set', markerfacecolor='red', markersize=5)]
        # The legend is attached to the last panel drawn
        ax.legend(handles=leg, loc='upper right')
    if save:
        fig.savefig(save, format=save.split('.')[-1], dpi=300, bbox_inches='tight')
    plt.show()
def _as_selection(value):
    """Turn a display_* argument into a list of labels, or None when it means 'all'."""
    if isinstance(value, str):
        return None if value == 'all' else [value]
    return list(value)


def weight_heatmap(model: VEGA,
                   cluster: bool = True,
                   cmap: str = 'viridis',
                   display_gmvs: Union[str,list] = 'all',
                   display_genes: Union[str,list] = 'all',
                   title: str = None,
                   figsize: Union[tuple,list]=None,
                   save: Union[bool,str] = False,
                   hm_kwargs: dict = None):
    """
    Heatmap plots of weights.

    Parameters
    ----------
    model
        VEGA trained model
    cluster
        if True, use hierarchical clustering (seaborn.clustermap)
    cmap
        colormap to use
    display_gmvs
        if 'all', display all latent variables weights. Else (name or list of names) only the subset
    display_genes
        if 'all', display all gene weights of GMV. Else (name or list of names) only the subset
    title
        figure title
    figsize
        figure size
    save
        path to save figure. The output format is taken from the file extension
    hm_kwargs
        kwargs for sns.clustermap or sns.heatmap (depending on if ``cluster=True``).
        A ``cbar_kws`` entry is merged with the default colorbar label
    """
    # .cpu() keeps this working when the model lives on an accelerator
    w = model.decoder._get_weights().data.cpu().numpy()
    gmv_names = model.adata.uns['_vega']['gmv_names']
    gene_names = model.adata.var_names
    df = pd.DataFrame(data=w, index=gene_names, columns=gmv_names)
    gmv_subset = _as_selection(display_gmvs)
    if gmv_subset is not None:
        df = df[gmv_subset]
    gene_subset = _as_selection(display_genes)
    if gene_subset is not None:
        df = df.loc[gene_subset]

    # Copy so the caller's dict is never modified, and merge any user colorbar
    # options with the default label instead of passing cbar_kws twice.
    kwargs = {} if hm_kwargs is None else dict(hm_kwargs)
    cbar_kws = {'label': 'Weight magnitude'}
    cbar_kws.update(kwargs.pop('cbar_kws', None) or {})

    if cluster:
        # clustermap builds its own figure, so the size has to be handed to it
        if figsize:
            kwargs.setdefault('figsize', figsize)
        grid = sns.clustermap(df.T, cmap=cmap, cbar_kws=cbar_kws, **kwargs)
        # clustermap returns a grid object; the heatmap itself lives on ax_heatmap
        heat_ax = grid.ax_heatmap
    else:
        if figsize:
            plt.figure(figsize=figsize)
        heat_ax = sns.heatmap(df.T, cmap=cmap, cbar_kws=cbar_kws, **kwargs)
    fig = heat_ax.figure
    heat_ax.set_xlabel('Genes')
    if title:
        if cluster:
            # The space right above the clustered heatmap holds the dendrogram
            fig.suptitle(title)
        else:
            heat_ax.set_title(title)
    if save:
        print('Saving figure at %s'%(save))
        fig.savefig(save, format=save.split('.')[-1], dpi=300, bbox_inches='tight')
    plt.show()
def _make_pretty(string):
    """
    Turn a raw GMV identifier into a display label.

    Names containing 'UNANNOTATED' are kept whole; any other name loses its first
    underscore-separated token, has the remaining tokens joined by spaces and gets
    ' activity' appended. The result is lower-cased, then capitalized.
    """
    if 'UNANNOTATED' in string:
        pretty = string.lower()
    else:
        # Drop the leading token (everything before the first underscore)
        _, *words = string.split('_')
        pretty = ' '.join(words).lower() + ' activity'
    return pretty.capitalize()