pertpy 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. pertpy/__init__.py +4 -2
  2. pertpy/data/__init__.py +66 -1
  3. pertpy/data/_dataloader.py +28 -26
  4. pertpy/data/_datasets.py +261 -92
  5. pertpy/metadata/__init__.py +6 -0
  6. pertpy/metadata/_cell_line.py +795 -0
  7. pertpy/metadata/_compound.py +128 -0
  8. pertpy/metadata/_drug.py +238 -0
  9. pertpy/metadata/_look_up.py +569 -0
  10. pertpy/metadata/_metadata.py +70 -0
  11. pertpy/metadata/_moa.py +125 -0
  12. pertpy/plot/__init__.py +0 -13
  13. pertpy/preprocessing/__init__.py +2 -0
  14. pertpy/preprocessing/_guide_rna.py +89 -6
  15. pertpy/tools/__init__.py +48 -15
  16. pertpy/tools/_augur.py +329 -32
  17. pertpy/tools/_cinemaot.py +145 -6
  18. pertpy/tools/_coda/_base_coda.py +1237 -116
  19. pertpy/tools/_coda/_sccoda.py +66 -36
  20. pertpy/tools/_coda/_tasccoda.py +46 -39
  21. pertpy/tools/_dialogue.py +180 -77
  22. pertpy/tools/_differential_gene_expression/__init__.py +20 -0
  23. pertpy/tools/_differential_gene_expression/_base.py +657 -0
  24. pertpy/tools/_differential_gene_expression/_checks.py +41 -0
  25. pertpy/tools/_differential_gene_expression/_dge_comparison.py +86 -0
  26. pertpy/tools/_differential_gene_expression/_edger.py +125 -0
  27. pertpy/tools/_differential_gene_expression/_formulaic.py +189 -0
  28. pertpy/tools/_differential_gene_expression/_pydeseq2.py +95 -0
  29. pertpy/tools/_differential_gene_expression/_simple_tests.py +162 -0
  30. pertpy/tools/_differential_gene_expression/_statsmodels.py +72 -0
  31. pertpy/tools/_distances/_distance_tests.py +29 -24
  32. pertpy/tools/_distances/_distances.py +584 -98
  33. pertpy/tools/_enrichment.py +460 -0
  34. pertpy/tools/_kernel_pca.py +1 -1
  35. pertpy/tools/_milo.py +406 -49
  36. pertpy/tools/_mixscape.py +677 -55
  37. pertpy/tools/_perturbation_space/_clustering.py +10 -3
  38. pertpy/tools/_perturbation_space/_comparison.py +112 -0
  39. pertpy/tools/_perturbation_space/_discriminator_classifiers.py +524 -0
  40. pertpy/tools/_perturbation_space/_perturbation_space.py +146 -52
  41. pertpy/tools/_perturbation_space/_simple.py +52 -11
  42. pertpy/tools/_scgen/__init__.py +1 -1
  43. pertpy/tools/_scgen/_base_components.py +2 -3
  44. pertpy/tools/_scgen/_scgen.py +706 -0
  45. pertpy/tools/_scgen/_utils.py +3 -5
  46. pertpy/tools/decoupler_LICENSE +674 -0
  47. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/METADATA +48 -20
  48. pertpy-0.8.0.dist-info/RECORD +57 -0
  49. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/WHEEL +1 -1
  50. pertpy/plot/_augur.py +0 -234
  51. pertpy/plot/_cinemaot.py +0 -81
  52. pertpy/plot/_coda.py +0 -1001
  53. pertpy/plot/_dialogue.py +0 -91
  54. pertpy/plot/_guide_rna.py +0 -82
  55. pertpy/plot/_milopy.py +0 -284
  56. pertpy/plot/_mixscape.py +0 -594
  57. pertpy/plot/_scgen.py +0 -337
  58. pertpy/tools/_differential_gene_expression.py +0 -99
  59. pertpy/tools/_metadata/__init__.py +0 -0
  60. pertpy/tools/_metadata/_cell_line.py +0 -613
  61. pertpy/tools/_metadata/_look_up.py +0 -342
  62. pertpy/tools/_perturbation_space/_discriminator_classifier.py +0 -381
  63. pertpy/tools/_scgen/_jax_scgen.py +0 -370
  64. pertpy-0.6.0.dist-info/RECORD +0 -50
  65. /pertpy/tools/_scgen/{_jax_scgenvae.py → _scgenvae.py} +0 -0
  66. {pertpy-0.6.0.dist-info → pertpy-0.8.0.dist-info}/licenses/LICENSE +0 -0
@@ -8,8 +8,8 @@ import numpy as np
8
8
  import numpyro as npy
9
9
  import numpyro.distributions as npd
10
10
  from anndata import AnnData
11
- from jax import random
12
- from jax.config import config
11
+ from jax import config, random
12
+ from lamin_utils import logger
13
13
  from mudata import MuData
14
14
  from numpyro.infer import Predictive
15
15
  from rich import print
@@ -23,7 +23,6 @@ config.update("jax_enable_x64", True)
23
23
 
24
24
 
25
25
  class Sccoda(CompositionalModel2):
26
-
27
26
  """
28
27
  Statistical model for single-cell differential composition analysis with specification of a reference cell type.
29
28
  This is the standard scCODA model and recommended for all uses.
@@ -75,13 +74,13 @@ class Sccoda(CompositionalModel2):
75
74
  adata: AnnData object.
76
75
  type : Specify the input adata type, which could be either a cell-level AnnData or an aggregated sample-level AnnData.
77
76
  generate_sample_level: Whether to generate an AnnData object on the sample level or create an empty AnnData object.
78
- cell_type_identifier: If type is "cell_level", specify column name in adata.obs that specifies the cell types. Defaults to None.
79
- sample_identifier: If type is "cell_level", specify column name in adata.obs that specifies the sample. Defaults to None.
80
- covariate_uns: If type is "cell_level", specify key for adata.uns, where covariate values are stored. Defaults to None.
81
- covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored. Defaults to None.
82
- covariate_df: If type is "cell_level", specify dataFrame with covariates. Defaults to None.
83
- modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
84
- modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
77
+ cell_type_identifier: If type is "cell_level", specify column name in adata.obs that specifies the cell types.
78
+ sample_identifier: If type is "cell_level", specify column name in adata.obs that specifies the sample.
79
+ covariate_uns: If type is "cell_level", specify key for adata.uns, where covariate values are stored.
80
+ covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored.
81
+ covariate_df: If type is "cell_level", specify dataFrame with covariates.
82
+ modality_key_1: Key to the cell-level AnnData in the MuData object.
83
+ modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
85
84
 
86
85
  Returns:
87
86
  MuData: MuData object with cell-level AnnData (`mudata[modality_key_1]`) and aggregated sample-level AnnData (`mudata[modality_key_2]`).
@@ -90,8 +89,11 @@ class Sccoda(CompositionalModel2):
90
89
  >>> import pertpy as pt
91
90
  >>> haber_cells = pt.dt.haber_2017_regions()
92
91
  >>> sccoda = pt.tl.Sccoda()
93
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
94
- sample_identifier="batch", covariate_obs=["condition"])
92
+ >>> mdata = sccoda.load(haber_cells,
93
+ >>> type="cell_level",
94
+ >>> generate_sample_level=True,
95
+ >>> cell_type_identifier="cell_label",
96
+ >>> sample_identifier="batch", covariate_obs=["condition"])
95
97
  """
96
98
  if type == "cell_level":
97
99
  if generate_sample_level:
@@ -126,10 +128,10 @@ class Sccoda(CompositionalModel2):
126
128
  Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
127
129
  To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
128
130
  reference_cell_type: Column name that sets the reference cell type.
129
- Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. Defaults to "automatic".
131
+ Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
130
132
  automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
131
- to be considered as a possible reference cell type. Defaults to 0.05.
132
- modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
133
+ to be considered as a possible reference cell type.
134
+ modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object.
133
135
 
134
136
  Returns:
135
137
  Return an AnnData (if input data is an AnnData object) or return a MuData (if input data is a MuData object)
@@ -144,8 +146,12 @@ class Sccoda(CompositionalModel2):
144
146
  >>> import pertpy as pt
145
147
  >>> haber_cells = pt.dt.haber_2017_regions()
146
148
  >>> sccoda = pt.tl.Sccoda()
147
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
148
- sample_identifier="batch", covariate_obs=["condition"])
149
+ >>> mdata = sccoda.load(haber_cells,
150
+ >>> type="cell_level",
151
+ >>> generate_sample_level=True,
152
+ >>> cell_type_identifier="cell_label",
153
+ >>> sample_identifier="batch",
154
+ >>> covariate_obs=["condition"])
149
155
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
150
156
  """
151
157
  if isinstance(data, MuData):
@@ -193,10 +199,14 @@ class Sccoda(CompositionalModel2):
193
199
  >>> import pertpy as pt
194
200
  >>> haber_cells = pt.dt.haber_2017_regions()
195
201
  >>> sccoda = pt.tl.Sccoda()
196
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
197
- sample_identifier="batch", covariate_obs=["condition"])
202
+ >>> mdata = sccoda.load(haber_cells,
203
+ >>> type="cell_level",
204
+ >>> generate_sample_level=True,
205
+ >>> cell_type_identifier="cell_label",
206
+ >>> sample_identifier="batch",
207
+ >>> covariate_obs=["condition"])
198
208
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
199
- >>> adata = sccoda.set_init_mcmc_states(rng_key=42, ref_index=0, sample_adata=mdata['coda'])
209
+ >>> adata = sccoda.set_init_mcmc_states(rng_key=42, ref_index=0, sample_adata=mdata["coda"])
200
210
  """
201
211
  # data dimensions
202
212
  N, D = sample_adata.obsm["covariate_matrix"].shape
@@ -300,10 +310,10 @@ class Sccoda(CompositionalModel2):
300
310
 
301
311
  Args:
302
312
  data: AnnData object or MuData object.
303
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
304
- rng_key: The rng state used for the prior simulation. If None, a random state will be selected. Defaults to None.
305
- num_prior_samples: Number of prior samples calculated. Defaults to 500.
306
- use_posterior_predictive: If True, the posterior predictive will be calculated. Defaults to True.
313
+ modality_key: If data is a MuData object, specify which modality to use.
314
+ rng_key: The rng state used for the prior simulation. If None, a random state will be selected.
315
+ num_prior_samples: Number of prior samples calculated.
316
+ use_posterior_predictive: If True, the posterior predictive will be calculated.
307
317
 
308
318
  Returns:
309
319
  az.InferenceData: arviz_data with all MCMC information
@@ -312,8 +322,12 @@ class Sccoda(CompositionalModel2):
312
322
  >>> import pertpy as pt
313
323
  >>> haber_cells = pt.dt.haber_2017_regions()
314
324
  >>> sccoda = pt.tl.Sccoda()
315
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
316
- sample_identifier="batch", covariate_obs=["condition"])
325
+ >>> mdata = sccoda.load(haber_cells,
326
+ >>> type="cell_level",
327
+ >>> generate_sample_level=True,
328
+ >>> cell_type_identifier="cell_label",
329
+ >>> sample_identifier="batch",
330
+ >>> covariate_obs=["condition"])
317
331
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
318
332
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
319
333
  >>> arviz_data = sccoda.make_arviz(mdata, num_prior_samples=100)
@@ -322,7 +336,7 @@ class Sccoda(CompositionalModel2):
322
336
  try:
323
337
  sample_adata = data[modality_key]
324
338
  except IndexError:
325
- print("When data is a MuData object, modality_key must be specified!")
339
+ logger.error("When data is a MuData object, modality_key must be specified!")
326
340
  raise
327
341
  if isinstance(data, AnnData):
328
342
  sample_adata = data
@@ -365,7 +379,7 @@ class Sccoda(CompositionalModel2):
365
379
 
366
380
  if rng_key is None:
367
381
  rng = np.random.default_rng()
368
- rng_key = random.PRNGKey(rng.integers(0, 10000))
382
+ rng_key = random.key(rng.integers(0, 10000))
369
383
 
370
384
  if use_posterior_predictive:
371
385
  posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
@@ -414,8 +428,12 @@ class Sccoda(CompositionalModel2):
414
428
  >>> import pertpy as pt
415
429
  >>> haber_cells = pt.dt.haber_2017_regions()
416
430
  >>> sccoda = pt.tl.Sccoda()
417
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
418
- sample_identifier="batch", covariate_obs=["condition"])
431
+ >>> mdata = sccoda.load(haber_cells,
432
+ >>> type="cell_level",
433
+ >>> generate_sample_level=True,
434
+ >>> cell_type_identifier="cell_label",
435
+ >>> sample_identifier="batch",
436
+ >>> covariate_obs=["condition"])
419
437
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
420
438
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
421
439
  """
@@ -429,8 +447,12 @@ class Sccoda(CompositionalModel2):
429
447
  >>> import pertpy as pt
430
448
  >>> haber_cells = pt.dt.haber_2017_regions()
431
449
  >>> sccoda = pt.tl.Sccoda()
432
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
433
- sample_identifier="batch", covariate_obs=["condition"])
450
+ >>> mdata = sccoda.load(haber_cells,
451
+ >>> type="cell_level",
452
+ >>> generate_sample_level=True,
453
+ >>> cell_type_identifier="cell_label",
454
+ >>> sample_identifier="batch",
455
+ >>> covariate_obs=["condition"])
434
456
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
435
457
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
436
458
  >>> credible_effects = sccoda.credible_effects(mdata)
@@ -445,8 +467,12 @@ class Sccoda(CompositionalModel2):
445
467
  >>> import pertpy as pt
446
468
  >>> haber_cells = pt.dt.haber_2017_regions()
447
469
  >>> sccoda = pt.tl.Sccoda()
448
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
449
- sample_identifier="batch", covariate_obs=["condition"])
470
+ >>> mdata = sccoda.load(haber_cells,
471
+ >>> type="cell_level",
472
+ >>> generate_sample_level=True,
473
+ >>> cell_type_identifier="cell_label",
474
+ >>> sample_identifier="batch",
475
+ >>> covariate_obs=["condition"])
450
476
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
451
477
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
452
478
  >>> sccoda.summary(mdata)
@@ -461,8 +487,12 @@ class Sccoda(CompositionalModel2):
461
487
  >>> import pertpy as pt
462
488
  >>> haber_cells = pt.dt.haber_2017_regions()
463
489
  >>> sccoda = pt.tl.Sccoda()
464
- >>> mdata = sccoda.load(haber_cells, type="cell_level", generate_sample_level=True, cell_type_identifier="cell_label", \
465
- sample_identifier="batch", covariate_obs=["condition"])
490
+ >>> mdata = sccoda.load(haber_cells,
491
+ >>> type="cell_level",
492
+ >>> generate_sample_level=True,
493
+ >>> cell_type_identifier="cell_label",
494
+ >>> sample_identifier="batch",
495
+ >>> covariate_obs=["condition"])
466
496
  >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
467
497
  >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
468
498
  >>> sccoda.set_fdr(mdata, est_fdr=0.4)
@@ -3,18 +3,16 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Literal
4
4
 
5
5
  import arviz as az
6
- import ete3 as ete
7
6
  import jax.numpy as jnp
8
7
  import numpy as np
9
8
  import numpyro as npy
10
9
  import numpyro.distributions as npd
11
10
  import toytree as tt
12
11
  from anndata import AnnData
13
- from jax import random
14
- from jax.config import config
12
+ from jax import config, random
13
+ from lamin_utils import logger
15
14
  from mudata import MuData
16
15
  from numpyro.infer import Predictive
17
- from rich import print
18
16
 
19
17
  from pertpy.tools._coda._base_coda import (
20
18
  CompositionalModel2,
@@ -87,25 +85,25 @@ class Tasccoda(CompositionalModel2):
87
85
  Args:
88
86
  adata: AnnData object.
89
87
  type: Specify the input adata type, which could be either a cell-level AnnData or an aggregated sample-level AnnData.
90
- cell_type_identifier: If type is "cell_level", specify column name in adata.obs that specifies the cell types. Defaults to None.
91
- sample_identifier: If type is "cell_level", specify column name in adata.obs that specifies the sample. Defaults to None.
92
- covariate_uns: If type is "cell_level", specify key for adata.uns, where covariate values are stored. Defaults to None.
93
- covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored. Defaults to None.
94
- covariate_df: If type is "cell_level", specify dataFrame with covariates. Defaults to None.
95
- dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object. Defaults to None.
96
- levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
97
- levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level. Defaults to None.
98
- add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}. Defaults to False.
99
- key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2]. Defaults to "tree".
100
- modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
101
- modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
88
+ cell_type_identifier: If type is "cell_level", specify column name in adata.obs that specifies the cell types.
89
+ sample_identifier: If type is "cell_level", specify column name in adata.obs that specifies the sample.
90
+ covariate_uns: If type is "cell_level", specify key for adata.uns, where covariate values are stored.
91
+ covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored.
92
+ covariate_df: If type is "cell_level", specify dataFrame with covariates.
93
+ dendrogram_key: Key to the scanpy.tl.dendrogram result in `.uns` of original cell level anndata object.
94
+ levels_orig: List that indicates which columns in `.obs` of the original data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
95
+ levels_agg: List that indicates which columns in `.var` of the aggregated data correspond to tree levels. The list must begin with the root level, and end with the leaf level.
96
+ add_level_name: If True, internal nodes in the tree will be named as "{level_name}_{node_name}" instead of just {level_name}.
97
+ key_added: If not specified, the tree is stored in .uns[‘tree’]. If `data` is AnnData, save tree in `data`. If `data` is MuData, save tree in data[modality_2].
98
+ modality_key_1: Key to the cell-level AnnData in the MuData object.
99
+ modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object.
102
100
 
103
101
  Returns:
104
102
  MuData: MuData object with cell-level AnnData (`mudata[modality_key_1]`) and aggregated sample-level AnnData (`mudata[modality_key_2]`).
105
103
 
106
104
  Examples:
107
105
  >>> import pertpy as pt
108
- >>> adata = pt.dt.smillie()
106
+ >>> adata = pt.dt.tasccoda_example()
109
107
  >>> tasccoda = pt.tl.Tasccoda()
110
108
  >>> mdata = tasccoda.load(
111
109
  >>> adata, type="sample_level",
@@ -148,21 +146,22 @@ class Tasccoda(CompositionalModel2):
148
146
  pen_args: dict = None,
149
147
  modality_key: str = "coda",
150
148
  ) -> AnnData | MuData:
151
- """Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for tascCODA. Also sets model parameters, model type (tree_agg), effect selection type (sslaso) and performs tree processing.
149
+ """Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for tascCODA.
152
150
 
153
151
  Args:
154
152
  data: Anndata object with cell counts as .X and covariates saved in .obs or a MuData object.
155
153
  formula: R-style formula for building the covariate matrix.
156
- Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
157
- To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
154
+ Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
155
+ To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
158
156
  reference_cell_type: Column name that sets the reference cell type.
159
- Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen. Defaults to "automatic".
160
- automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
161
- to be considered as a possible reference cell type. Defaults to 0.05.
157
+ If "automatic", the cell type with the lowest dispersion in relative abundance that is present in at least 90% of samlpes will be chosen.
158
+ automatic_reference_absence_threshold: If using reference_cell_type = "automatic",
159
+ determine the maximum fraction of zero entries for a cell type
160
+ to be considered as a possible reference cell type.
162
161
  tree_key: Key in `adata.uns` that contains the tree structure
163
162
  pen_args: Dictionary with penalty arguments. With `reg="scaled_3"`, the parameters phi (aggregation bias), lambda_1, lambda_0 can be set here.
164
163
  See the tascCODA paper for an explanation of these parameters. Default: lambda_0 = 50, lambda_1 = 5, phi = 0.
165
- modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".
164
+ modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object.
166
165
 
167
166
  Returns:
168
167
  Return an AnnData (if input data is an AnnData object) or return a MuData (if input data is a MuData object)
@@ -175,7 +174,7 @@ class Tasccoda(CompositionalModel2):
175
174
 
176
175
  Examples:
177
176
  >>> import pertpy as pt
178
- >>> adata = pt.dt.smillie()
177
+ >>> adata = pt.dt.tasccoda_example()
179
178
  >>> tasccoda = pt.tl.Tasccoda()
180
179
  >>> mdata = tasccoda.load(
181
180
  >>> adata, type="sample_level",
@@ -199,8 +198,16 @@ class Tasccoda(CompositionalModel2):
199
198
  if tree_key is None:
200
199
  raise ValueError("Please specify the key in .uns that contains the tree structure!")
201
200
 
201
+ # Scoped import due to installation issues
202
+ try:
203
+ import ete3 as ete
204
+ except ImportError:
205
+ raise ImportError(
206
+ "To use tasccoda please install additional dependencies as `pip install pertpy[coda]`"
207
+ ) from None
208
+
202
209
  # toytree tree - only for legacy reasons, can be removed in the final version
203
- if isinstance(adata.uns[tree_key], tt.tree):
210
+ if isinstance(adata.uns[tree_key], tt.core.ToyTree):
204
211
  # Collapse singularities in the tree
205
212
  phy_tree = collapse_singularities(adata.uns[tree_key])
206
213
 
@@ -315,7 +322,7 @@ class Tasccoda(CompositionalModel2):
315
322
 
316
323
  Examples:
317
324
  >>> import pertpy as pt
318
- >>> adata = pt.dt.smillie()
325
+ >>> adata = pt.dt.tasccoda_example()
319
326
  >>> tasccoda = pt.tl.Tasccoda()
320
327
  >>> mdata = tasccoda.load(
321
328
  >>> adata, type="sample_level",
@@ -325,7 +332,7 @@ class Tasccoda(CompositionalModel2):
325
332
  >>> mdata = tasccoda.prepare(
326
333
  >>> mdata, formula="Health", reference_cell_type="automatic", tree_key="lineage", pen_args={"phi": 0}
327
334
  >>> )
328
- >>> adata = tasccoda.set_init_mcmc_states(rng_key=42, ref_index=[0,1], sample_adata=mdata['coda'])
335
+ >>> adata = tasccoda.set_init_mcmc_states(rng_key=42, ref_index=[0, 1], sample_adata=mdata["coda"])
329
336
  """
330
337
  N, D = sample_adata.obsm["covariate_matrix"].shape
331
338
  P = sample_adata.X.shape[1]
@@ -469,17 +476,17 @@ class Tasccoda(CompositionalModel2):
469
476
 
470
477
  Args:
471
478
  data: AnnData object or MuData object.
472
- modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
473
- rng_key: The rng state used for the prior simulation. If None, a random state will be selected. Defaults to None.
474
- num_prior_samples: Number of prior samples calculated. Defaults to 500.
475
- use_posterior_predictive: If True, the posterior predictive will be calculated. Defaults to True.
479
+ modality_key: If data is a MuData object, specify which modality to use.
480
+ rng_key: The rng state used for the prior simulation. If None, a random state will be selected.
481
+ num_prior_samples: Number of prior samples calculated.
482
+ use_posterior_predictive: If True, the posterior predictive will be calculated.
476
483
 
477
484
  Returns:
478
485
  arviz.InferenceData: arviz_data
479
486
 
480
487
  Examples:
481
488
  >>> import pertpy as pt
482
- >>> adata = pt.dt.smillie()
489
+ >>> adata = pt.dt.tasccoda_example()
483
490
  >>> tasccoda = pt.tl.Tasccoda()
484
491
  >>> mdata = tasccoda.load(
485
492
  >>> adata, type="sample_level",
@@ -496,7 +503,7 @@ class Tasccoda(CompositionalModel2):
496
503
  try:
497
504
  sample_adata = data[modality_key]
498
505
  except IndexError:
499
- print("When data is a MuData object, modality_key must be specified!")
506
+ logger.error("When data is a MuData object, modality_key must be specified!")
500
507
  raise
501
508
  if isinstance(data, AnnData):
502
509
  sample_adata = data
@@ -543,7 +550,7 @@ class Tasccoda(CompositionalModel2):
543
550
 
544
551
  if rng_key is None:
545
552
  rng = np.random.default_rng()
546
- rng_key = random.PRNGKey(rng.integers(0, 10000))
553
+ rng_key = random.key(rng.integers(0, 10000))
547
554
 
548
555
  if use_posterior_predictive:
549
556
  posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
@@ -590,7 +597,7 @@ class Tasccoda(CompositionalModel2):
590
597
  """
591
598
  Examples:
592
599
  >>> import pertpy as pt
593
- >>> adata = pt.dt.smillie()
600
+ >>> adata = pt.dt.tasccoda_example()
594
601
  >>> tasccoda = pt.tl.Tasccoda()
595
602
  >>> mdata = tasccoda.load(
596
603
  >>> adata, type="sample_level",
@@ -610,7 +617,7 @@ class Tasccoda(CompositionalModel2):
610
617
  """
611
618
  Examples:
612
619
  >>> import pertpy as pt
613
- >>> adata = pt.dt.smillie()
620
+ >>> adata = pt.dt.tasccoda_example()
614
621
  >>> tasccoda = pt.tl.Tasccoda()
615
622
  >>> mdata = tasccoda.load(
616
623
  >>> adata, type="sample_level",
@@ -631,7 +638,7 @@ class Tasccoda(CompositionalModel2):
631
638
  """
632
639
  Examples:
633
640
  >>> import pertpy as pt
634
- >>> adata = pt.dt.smillie()
641
+ >>> adata = pt.dt.tasccoda_example()
635
642
  >>> tasccoda = pt.tl.Tasccoda()
636
643
  >>> mdata = tasccoda.load(
637
644
  >>> adata, type="sample_level",
@@ -652,7 +659,7 @@ class Tasccoda(CompositionalModel2):
652
659
  """
653
660
  Examples:
654
661
  >>> import pertpy as pt
655
- >>> adata = pt.dt.smillie()
662
+ >>> adata = pt.dt.tasccoda_example()
656
663
  >>> tasccoda = pt.tl.Tasccoda()
657
664
  >>> mdata = tasccoda.load(
658
665
  >>> adata, type="sample_level",