# Source code for vega.vega_count._vegamodel

import logging

from anndata import AnnData
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, RNASeqMixin, VAEMixin
from scvi._compat import Literal
from typing import Dict, Iterable, Optional, Sequence, Union
import torch
from ._vegamodule import SparseVAE
from vega.utils import *

# Logger named after this module (NOTE(review): not referenced in the code shown here).
logger = logging.getLogger(name=__name__)

class VegaSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """
    VEGA: VAE Enhanced by Gene Annotations [Seninge2021]_.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
    gmt_paths
        A single or list of paths to .GMT files with gene annotations for GMVs initialization.
        Only used when ``adata.uns['_vega']`` does not already contain a ``'mask'``;
        an existing mask is reused as-is and ``gmt_paths`` is then ignored.
    add_nodes
        Number of additional fully-connected decoder nodes (unannotated GMVs).
    min_genes
        Minimum gene size for GMVs.
    max_genes
        Maximum gene size for GMVs.
    positive_decoder
        Whether to constrain decoder to positive weights.
        NOTE(review): this constructor does not forward it to ``SparseVAE``;
        confirm where the constraint is applied.
    n_hidden
        Number of nodes per hidden layer.
    n_layers
        Number of hidden layers used for encoder NN.
    gene_likelihood
        Likelihood function for the generative model.
    dropout_rate
        Dropout rate for neural networks.
    z_dropout
        Forwarded unchanged to ``SparseVAE`` as ``z_dropout``.
    use_cuda
        Using GPU with CUDA.
        NOTE(review): not read by this constructor; it is only captured in ``init_params_``.
    **model_kwargs
        Extra keyword arguments forwarded to ``SparseVAE``.

    Raises
    ------
    ValueError
        If ``adata.uns['_vega']`` holds no ``'mask'`` and no ``gmt_paths`` are given.
    """

    def __init__(
        self,
        adata: AnnData,
        gmt_paths: Optional[Union[str, Sequence[str]]] = None,
        add_nodes: int = 1,
        min_genes: int = 0,
        max_genes: int = 5000,
        positive_decoder: bool = True,
        n_hidden: int = 600,
        n_layers: int = 2,
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        dropout_rate: float = 0.1,
        z_dropout: float = 0,
        use_cuda: bool = True,
        **model_kwargs,
    ):
        super().__init__(adata)
        # Mask-construction settings, kept on the instance.
        self.add_nodes_ = add_nodes
        self.min_genes_ = min_genes
        self.max_genes_ = max_genes

        # Validate before touching adata.uns so that a failed construction
        # leaves the caller's AnnData unmodified.
        has_mask = 'mask' in self.adata.uns.get('_vega', {})
        if not has_mask and not gmt_paths:
            raise ValueError(
                'No existing mask found in Anndata object and no .gmt files passed to VEGA. '
                'Please provide .gmt file paths to initialize a new mask or use an Anndata '
                'object used for training of a previous VEGA model.'
            )

        if '_vega' not in self.adata.uns:
            print('Initializing `_vega` field in adata.uns', flush=True)
            self.adata.uns['_vega'] = {}

        if not has_mask:
            # Build a new mask from the .gmt annotations; an existing mask
            # is reused untouched and gmt_paths is ignored in that case.
            create_mask(self.adata, gmt_paths, add_nodes, self.min_genes_, self.max_genes_)

        self.gmv_mask = self.adata.uns['_vega']['mask']
        # Mask rows are counted as genes and columns as GMVs.
        self.n_gmvs = self.gmv_mask.shape[1]
        self.n_genes = self.gmv_mask.shape[0]

        n_cats_per_cov = (
            self.scvi_setup_dict_["extra_categoricals"]["n_cats_per_key"]
            if "extra_categoricals" in self.scvi_setup_dict_
            else None
        )

        self.module = SparseVAE(
            n_input=self.n_genes,
            gmv_mask=self.gmv_mask,
            n_batch=self.summary_stats["n_batch"],
            n_continuous_cov=self.summary_stats["n_continuous_covs"],
            n_cats_per_cov=n_cats_per_cov,
            n_hidden=n_hidden,
            n_layers=n_layers,
            gene_likelihood=gene_likelihood,
            dropout_rate=dropout_rate,
            z_dropout=z_dropout,
            encode_covariates=False,
            **model_kwargs,
        )
        # Snapshot the constructor arguments via locals() (BaseModelClass helper;
        # presumably used by scvi-tools to re-create the model — confirm).
        # Parameter names must not be rebound above this line.
        self.init_params_ = self._get_init_params(locals())

    def differential_activity(
        self,
        adata: Optional[AnnData] = None,
        groupby: Optional[str] = None,
        group1: Optional[Iterable[str]] = None,
        group2: Optional[str] = None,
        idx1: Optional[Union[Sequence[int], Sequence[bool]]] = None,
        idx2: Optional[Union[Sequence[int], Sequence[bool]]] = None,
        mode: Literal["vanilla", "change"] = "change",
        delta: float = 0.25,
        batch_size: Optional[int] = None,
        all_stats: bool = True,
        batch_correction: bool = False,
        batchid1: Optional[Iterable[str]] = None,
        batchid2: Optional[Iterable[str]] = None,
        fdr_target: float = 0.05,
        silent: bool = False,
        **kwargs,
    ):
        """
        Differential activity test over the model's GMVs (not yet implemented).

        NOTE(review): this module contains no implementation of the test itself.
        The method validates ``adata`` and then raises ``NotImplementedError``,
        rather than failing on an undefined name.

        Parameters
        ----------
        adata
            AnnData passed to ``self._validate_anndata``. ``None`` presumably
            selects the model's own AnnData (scvi-tools convention; confirm).
        groupby, group1, group2, idx1, idx2, mode, delta, batch_size, all_stats,
        batch_correction, batchid1, batchid2, fdr_target, silent, **kwargs
            Accepted for interface compatibility; unused until the test is implemented.

        Raises
        ------
        NotImplementedError
            Always, after ``adata`` has been validated.
        """
        self._validate_anndata(adata)
        raise NotImplementedError(
            'VegaSCVI.differential_activity is not implemented: '
            'the differential-testing computation is missing from this module.'
        )