pertpy 0.9.5__tar.gz → 0.10.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (185) hide show
  1. {pertpy-0.9.5 → pertpy-0.10.0}/.github/pull_request_template.md +3 -3
  2. {pertpy-0.9.5 → pertpy-0.10.0}/.github/release-drafter.yml +2 -2
  3. {pertpy-0.9.5 → pertpy-0.10.0}/.github/workflows/build.yml +2 -2
  4. {pertpy-0.9.5 → pertpy-0.10.0}/.github/workflows/test.yml +4 -7
  5. {pertpy-0.9.5 → pertpy-0.10.0}/.pre-commit-config.yaml +4 -4
  6. {pertpy-0.9.5 → pertpy-0.10.0}/CODE_OF_CONDUCT.md +14 -14
  7. {pertpy-0.9.5 → pertpy-0.10.0}/PKG-INFO +2 -2
  8. {pertpy-0.9.5 → pertpy-0.10.0}/docs/contributing.md +9 -9
  9. {pertpy-0.9.5 → pertpy-0.10.0}/docs/index.md +6 -2
  10. {pertpy-0.9.5 → pertpy-0.10.0}/docs/usage/usage.md +6 -6
  11. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/__init__.py +1 -1
  12. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/_doc.py +1 -2
  13. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/_cell_line.py +3 -5
  14. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/preprocessing/_guide_rna.py +98 -10
  15. pertpy-0.10.0/pertpy/preprocessing/_guide_rna_mixture.py +179 -0
  16. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_augur.py +32 -44
  17. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_cinemaot.py +1 -3
  18. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_coda/_base_coda.py +21 -29
  19. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_dialogue.py +17 -21
  20. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_base.py +4 -12
  21. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_distances/_distances.py +56 -48
  22. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_enrichment.py +1 -3
  23. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_milo.py +4 -12
  24. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_mixscape.py +215 -127
  25. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/_simple.py +1 -3
  26. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_scgen/_scgen.py +1 -3
  27. {pertpy-0.9.5 → pertpy-0.10.0}/pyproject.toml +5 -3
  28. pertpy-0.10.0/tests/preprocessing/test_grna_assignment.py +62 -0
  29. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_distances/test_distances.py +1 -2
  30. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_augur.py +1 -1
  31. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_mixscape.py +81 -5
  32. pertpy-0.9.5/tests/preprocessing/test_grna_assignment.py +0 -50
  33. {pertpy-0.9.5 → pertpy-0.10.0}/.editorconfig +0 -0
  34. {pertpy-0.9.5 → pertpy-0.10.0}/.gitattributes +0 -0
  35. {pertpy-0.9.5 → pertpy-0.10.0}/.github/ISSUE_TEMPLATE/bug_report.yml +0 -0
  36. {pertpy-0.9.5 → pertpy-0.10.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  37. {pertpy-0.9.5 → pertpy-0.10.0}/.github/ISSUE_TEMPLATE/feature_request.yml +0 -0
  38. {pertpy-0.9.5 → pertpy-0.10.0}/.github/labels.yml +0 -0
  39. {pertpy-0.9.5 → pertpy-0.10.0}/.github/workflows/labeler.yml +0 -0
  40. {pertpy-0.9.5 → pertpy-0.10.0}/.github/workflows/release.yml +0 -0
  41. {pertpy-0.9.5 → pertpy-0.10.0}/.github/workflows/release_drafter.yml +0 -0
  42. {pertpy-0.9.5 → pertpy-0.10.0}/.gitignore +0 -0
  43. {pertpy-0.9.5 → pertpy-0.10.0}/.gitmodules +0 -0
  44. {pertpy-0.9.5 → pertpy-0.10.0}/.readthedocs.yml +0 -0
  45. {pertpy-0.9.5 → pertpy-0.10.0}/LICENSE +0 -0
  46. {pertpy-0.9.5 → pertpy-0.10.0}/README.md +0 -0
  47. {pertpy-0.9.5 → pertpy-0.10.0}/codecov.yml +0 -0
  48. {pertpy-0.9.5 → pertpy-0.10.0}/docs/Makefile +0 -0
  49. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_ext/edit_on_github.py +0 -0
  50. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_ext/typed_returns.py +0 -0
  51. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/SCVI_LICENSE +0 -0
  52. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/css/overwrite.css +0 -0
  53. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/css/sphinx_gallery.css +0 -0
  54. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/augur_dp_scatter.png +0 -0
  55. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/augur_important_features.png +0 -0
  56. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/augur_lollipop.png +0 -0
  57. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/augur_scatterplot.png +0 -0
  58. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/de_fold_change.png +0 -0
  59. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/de_multicomparison_fc.png +0 -0
  60. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/de_paired_expression.png +0 -0
  61. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/de_volcano.png +0 -0
  62. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/dialogue_pairplot.png +0 -0
  63. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/dialogue_violin.png +0 -0
  64. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/enrichment_dotplot.png +0 -0
  65. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/enrichment_gsea.png +0 -0
  66. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/milo_da_beeswarm.png +0 -0
  67. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/milo_nhood.png +0 -0
  68. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/milo_nhood_graph.png +0 -0
  69. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/mixscape_barplot.png +0 -0
  70. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/mixscape_heatmap.png +0 -0
  71. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/mixscape_lda.png +0 -0
  72. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/mixscape_perturbscore.png +0 -0
  73. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/mixscape_violin.png +0 -0
  74. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/pseudobulk_samples.png +0 -0
  75. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/sccoda_boxplots.png +0 -0
  76. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/sccoda_effects_barplot.png +0 -0
  77. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/sccoda_rel_abundance_dispersion_plot.png +0 -0
  78. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/sccoda_stacked_barplot.png +0 -0
  79. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/scgen_reg_mean.png +0 -0
  80. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/tasccoda_draw_effects.png +0 -0
  81. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/tasccoda_draw_tree.png +0 -0
  82. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/docstring_previews/tasccoda_effects_umap.png +0 -0
  83. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/icons/code-24px.svg +0 -0
  84. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/icons/computer-24px.svg +0 -0
  85. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/icons/library_books-24px.svg +0 -0
  86. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/icons/play_circle_outline-24px.svg +0 -0
  87. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/placeholder.png +0 -0
  88. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/augur.png +0 -0
  89. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/cinemaot.png +0 -0
  90. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/dge.png +0 -0
  91. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/dialogue.png +0 -0
  92. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/distances.png +0 -0
  93. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/distances_tests.png +0 -0
  94. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/enrichment.png +0 -0
  95. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/guide_rna_assignment.png +0 -0
  96. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/mcfarland.png +0 -0
  97. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/metadata.png +0 -0
  98. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/milo.png +0 -0
  99. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/mixscape.png +0 -0
  100. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/norman.png +0 -0
  101. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/ontology.png +0 -0
  102. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/perturbation_space.png +0 -0
  103. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/placeholder.png +0 -0
  104. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/sccoda.png +0 -0
  105. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/sccoda_extended.png +0 -0
  106. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/scgen_perturbation_prediction.png +0 -0
  107. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/tasccoda.png +0 -0
  108. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_static/tutorials/zhang.png +0 -0
  109. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_templates/autosummary/class.rst +0 -0
  110. {pertpy-0.9.5 → pertpy-0.10.0}/docs/_templates/class_no_inherited.rst +0 -0
  111. {pertpy-0.9.5 → pertpy-0.10.0}/docs/code_of_conduct.md +0 -0
  112. {pertpy-0.9.5 → pertpy-0.10.0}/docs/conf.py +0 -0
  113. {pertpy-0.9.5 → pertpy-0.10.0}/docs/installation.md +0 -0
  114. {pertpy-0.9.5 → pertpy-0.10.0}/docs/make.bat +0 -0
  115. {pertpy-0.9.5 → pertpy-0.10.0}/docs/references.bib +0 -0
  116. {pertpy-0.9.5 → pertpy-0.10.0}/docs/references.md +0 -0
  117. {pertpy-0.9.5 → pertpy-0.10.0}/docs/tutorials/index.md +0 -0
  118. {pertpy-0.9.5 → pertpy-0.10.0}/docs/utils.py +0 -0
  119. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/data/__init__.py +0 -0
  120. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/data/_dataloader.py +0 -0
  121. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/data/_datasets.py +0 -0
  122. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/__init__.py +0 -0
  123. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/_compound.py +0 -0
  124. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/_drug.py +0 -0
  125. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/_look_up.py +0 -0
  126. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/_metadata.py +0 -0
  127. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/metadata/_moa.py +0 -0
  128. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/plot/__init__.py +0 -0
  129. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/preprocessing/__init__.py +0 -0
  130. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/py.typed +0 -0
  131. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/__init__.py +0 -0
  132. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_coda/__init__.py +0 -0
  133. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_coda/_sccoda.py +0 -0
  134. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_coda/_tasccoda.py +0 -0
  135. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/__init__.py +0 -0
  136. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_checks.py +0 -0
  137. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_dge_comparison.py +0 -0
  138. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_edger.py +0 -0
  139. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_pydeseq2.py +0 -0
  140. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_simple_tests.py +0 -0
  141. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_differential_gene_expression/_statsmodels.py +0 -0
  142. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_distances/__init__.py +0 -0
  143. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_distances/_distance_tests.py +0 -0
  144. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_kernel_pca.py +0 -0
  145. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/__init__.py +0 -0
  146. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/_clustering.py +0 -0
  147. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/_comparison.py +0 -0
  148. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/_discriminator_classifiers.py +0 -0
  149. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/_metrics.py +0 -0
  150. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_perturbation_space/_perturbation_space.py +0 -0
  151. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_scgen/__init__.py +0 -0
  152. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_scgen/_base_components.py +0 -0
  153. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_scgen/_scgenvae.py +0 -0
  154. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/_scgen/_utils.py +0 -0
  155. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/decoupler_LICENSE +0 -0
  156. {pertpy-0.9.5 → pertpy-0.10.0}/pertpy/tools/transferlearning_MMD_LICENSE +0 -0
  157. {pertpy-0.9.5 → pertpy-0.10.0}/tests/conftest.py +0 -0
  158. {pertpy-0.9.5 → pertpy-0.10.0}/tests/metadata/test_cell_line.py +0 -0
  159. {pertpy-0.9.5 → pertpy-0.10.0}/tests/metadata/test_compound.py +0 -0
  160. {pertpy-0.9.5 → pertpy-0.10.0}/tests/metadata/test_drug.py +0 -0
  161. {pertpy-0.9.5 → pertpy-0.10.0}/tests/metadata/test_moa.py +0 -0
  162. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_coda/test_sccoda.py +0 -0
  163. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_coda/test_tasccoda.py +0 -0
  164. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/__init__.py +0 -0
  165. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/conftest.py +0 -0
  166. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_base.py +0 -0
  167. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_compare_groups.py +0 -0
  168. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_dge.py +0 -0
  169. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_edger.py +0 -0
  170. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_input_checks.py +0 -0
  171. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_pydeseq2.py +0 -0
  172. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_simple_tests.py +0 -0
  173. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_differential_gene_expression/test_statsmodels.py +0 -0
  174. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_distances/test_distance_tests.py +0 -0
  175. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_perturbation_space/test_comparison.py +0 -0
  176. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_perturbation_space/test_discriminator_classifiers.py +0 -0
  177. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_perturbation_space/test_simple_cluster_space.py +0 -0
  178. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/_perturbation_space/test_simple_perturbation_space.py +0 -0
  179. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/haber_data.csv +0 -0
  180. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/r_result.csv +0 -0
  181. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_cinemaot.py +0 -0
  182. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_dialogue.py +0 -0
  183. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_enrichment.py +0 -0
  184. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_milo.py +0 -0
  185. {pertpy-0.9.5 → pertpy-0.10.0}/tests/tools/test_scgen.py +0 -0
@@ -4,9 +4,9 @@
4
4
 
5
5
  <!-- Please fill in the appropriate checklist below (delete whatever is not relevant). These are the most common things requested on pull requests (PRs). -->
6
6
 
7
- - [ ] Referenced issue is linked
8
- - [ ] If you've fixed a bug or added code that should be tested, add tests!
9
- - [ ] Documentation in `docs` is updated
7
+ - [ ] Referenced issue is linked
8
+ - [ ] If you've fixed a bug or added code that should be tested, add tests!
9
+ - [ ] Documentation in `docs` is updated
10
10
 
11
11
  **Description of changes**
12
12
 
@@ -1,5 +1,5 @@
1
- name-template: "0.9.5 🌈"
2
- tag-template: 0.9.5
1
+ name-template: "0.10.0 🌈"
2
+ tag-template: 0.10.0
3
3
  exclude-labels:
4
4
  - "skip-changelog"
5
5
 
@@ -16,10 +16,10 @@ jobs:
16
16
  steps:
17
17
  - uses: actions/checkout@v4
18
18
 
19
- - name: Set up Python 3.11
19
+ - name: Set up Python
20
20
  uses: actions/setup-python@v5
21
21
  with:
22
- python-version: "3.11"
22
+ python-version: "3.12"
23
23
  cache: "pip"
24
24
  cache-dependency-path: "**/pyproject.toml"
25
25
 
@@ -21,15 +21,12 @@ jobs:
21
21
  fail-fast: false
22
22
  matrix:
23
23
  include:
24
- - os: ubuntu-latest
25
- python: "3.10"
26
- run_mode: "slow"
27
- - os: ubuntu-latest
24
+ - os: ubuntu-22.04 # ubuntu-latest is currently broken for joblib
28
25
  python: "3.12"
29
26
  run_mode: "slow"
30
- # - os: ubuntu-latest
31
- # python: "3.12"
32
- # run_mode: "fast"
27
+ - os: ubuntu-22.04
28
+ python: "3.12"
29
+ run_mode: "fast"
33
30
  # - os: ubuntu-latest
34
31
  # python: "3.12"
35
32
  # run_mode: slow
@@ -2,8 +2,8 @@ fail_fast: false
2
2
  default_language_version:
3
3
  python: python3
4
4
  default_stages:
5
- - commit
6
- - push
5
+ - pre-commit
6
+ - pre-push
7
7
  minimum_pre_commit_version: 2.16.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-prettier
@@ -11,7 +11,7 @@ repos:
11
11
  hooks:
12
12
  - id: prettier
13
13
  - repo: https://github.com/astral-sh/ruff-pre-commit
14
- rev: v0.4.7
14
+ rev: v0.8.6
15
15
  hooks:
16
16
  - id: ruff
17
17
  args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
@@ -27,7 +27,7 @@ repos:
27
27
  - id: trailing-whitespace
28
28
  - id: check-case-conflict
29
29
  - repo: https://github.com/pre-commit/mirrors-mypy
30
- rev: v1.10.0
30
+ rev: v1.14.1
31
31
  hooks:
32
32
  - id: mypy
33
33
  args: [--no-strict-optional, --ignore-missing-imports]
@@ -14,23 +14,23 @@ religion, or sexual identity and orientation.
14
14
  Examples of behavior that contributes to creating a positive environment
15
15
  include:
16
16
 
17
- - Using welcoming and inclusive language
18
- - Being respectful of differing viewpoints and experiences
19
- - Gracefully accepting constructive criticism
20
- - Focusing on what is best for the community
21
- - Showing empathy towards other community members
17
+ - Using welcoming and inclusive language
18
+ - Being respectful of differing viewpoints and experiences
19
+ - Gracefully accepting constructive criticism
20
+ - Focusing on what is best for the community
21
+ - Showing empathy towards other community members
22
22
 
23
23
  Examples of unacceptable behavior by participants include:
24
24
 
25
- - The use of sexualized language or imagery and unwelcome sexual
26
- attention or advances
27
- - Trolling, insulting/derogatory comments, and personal or political
28
- attacks
29
- - Public or private harassment
30
- - Publishing others’ private information, such as a physical or
31
- electronic address, without explicit permission
32
- - Other conduct which could reasonably be considered inappropriate in a
33
- professional setting
25
+ - The use of sexualized language or imagery and unwelcome sexual
26
+ attention or advances
27
+ - Trolling, insulting/derogatory comments, and personal or political
28
+ attacks
29
+ - Public or private harassment
30
+ - Publishing others’ private information, such as a physical or
31
+ electronic address, without explicit permission
32
+ - Other conduct which could reasonably be considered inappropriate in a
33
+ professional setting
34
34
 
35
35
  ## Our Responsibilities
36
36
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pertpy
3
- Version: 0.9.5
3
+ Version: 0.10.0
4
4
  Summary: Perturbation Analysis in the scverse ecosystem.
5
5
  Project-URL: Documentation, https://pertpy.readthedocs.io
6
6
  Project-URL: Source, https://github.com/scverse/pertpy
@@ -57,8 +57,8 @@ Requires-Dist: pyarrow
57
57
  Requires-Dist: requests
58
58
  Requires-Dist: rich
59
59
  Requires-Dist: scanpy[leiden]
60
+ Requires-Dist: scikit-learn>=1.4
60
61
  Requires-Dist: scikit-misc
61
- Requires-Dist: scipy
62
62
  Requires-Dist: scvi-tools
63
63
  Requires-Dist: sparsecca
64
64
  Provides-Extra: coda
@@ -132,11 +132,11 @@ in the cookiecutter-scverse template.
132
132
 
133
133
  Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:
134
134
 
135
- - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
136
- - Google-style docstrings
137
- - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
138
- - [Sphinx autodoc typehints][], to automatically reference annotated input and output types
139
- - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
135
+ - the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
136
+ - Google-style docstrings
137
+ - Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
138
+ - [Sphinx autodoc typehints][], to automatically reference annotated input and output types
139
+ - Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
140
140
 
141
141
  See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
142
142
  on how to write documentation.
@@ -150,10 +150,10 @@ These notebooks come from [pert-tutorials](https://github.com/scverse/pertpy-tut
150
150
 
151
151
  #### Hints
152
152
 
153
- - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
154
- if you do so can sphinx automatically create a link to the external documentation.
155
- - If building the documentation fails because of a missing link that is outside your control, you can add an entry to
156
- the `nitpick_ignore` list in `docs/conf.py`
153
+ - If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
154
+ if you do so can sphinx automatically create a link to the external documentation.
155
+ - If building the documentation fails because of a missing link that is outside your control, you can add an entry to
156
+ the `nitpick_ignore` list in `docs/conf.py`
157
157
 
158
158
  #### Building the docs locally
159
159
 
@@ -54,8 +54,12 @@ Discussions <https://github.com/scverse/pertpy/discussions>
54
54
  references
55
55
  ```
56
56
 
57
- - Consider citing [scanpy Genome Biology (2018)] along with original {doc}`references <references>`.
58
- - A paper for pertpy is in the works.
57
+ ## Citation
58
+
59
+ [Lukas Heumos, Yuge Ji, Lilly May, Tessa Green, Xinyue Zhang, Xichen Wu, Johannes Ostner, Stefan Peidli, Antonia Schumacher, Karin Hrovatin, Michaela Mueller, Faye Chong, Gregor Sturm, Alejandro Tejada, Emma Dann, Mingze Dong, Mojtaba Bahrami, Ilan Gold, Sergei Rybakov, Altana Namsaraeva, Amir Ali Moinfar, Zihe Zheng, Eljas Roellin, Isra Mekki, Chris Sander, Mohammad Lotfollahi, Herbert B. Schiller, Fabian J. Theis
60
+ bioRxiv 2024.08.04.606516; doi: https://doi.org/10.1101/2024.08.04.606516](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1)
61
+
62
+ Consider citing [scanpy Genome Biology (2018)] along with the original {doc}`references <references>`.
59
63
 
60
64
  # Indices and tables
61
65
 
@@ -563,9 +563,9 @@ including cell line annotation, bulk RNA and protein expression data.
563
563
 
564
564
  Available databases for cell line metadata:
565
565
 
566
- - [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/)
567
- - [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/)
568
- - [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/)
566
+ - [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/)
567
+ - [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/)
568
+ - [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/)
569
569
 
570
570
  ### Compound
571
571
 
@@ -573,7 +573,7 @@ The Compound module enables the retrieval of various types of information relate
573
573
 
574
574
  Available databases for compound metadata:
575
575
 
576
- - [PubChem](https://pubchem.ncbi.nlm.nih.gov/)
576
+ - [PubChem](https://pubchem.ncbi.nlm.nih.gov/)
577
577
 
578
578
  ### Mechanism of Action
579
579
 
@@ -581,7 +581,7 @@ This module aims to retrieve metadata of mechanism of action studies related to
581
581
 
582
582
  Available databases for mechanism of action metadata:
583
583
 
584
- - [CLUE](https://clue.io/)
584
+ - [CLUE](https://clue.io/)
585
585
 
586
586
  ### Drug
587
587
 
@@ -589,7 +589,7 @@ This module allows for the retrieval of Drug target information.
589
589
 
590
590
  Available databases for drug metadata:
591
591
 
592
- - [chembl](https://www.ebi.ac.uk/chembl/)
592
+ - [chembl](https://www.ebi.ac.uk/chembl/)
593
593
 
594
594
  ```{eval-rst}
595
595
  .. autosummary::
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = "Lukas Heumos"
4
4
  __email__ = "lukas.heumos@posteo.net"
5
- __version__ = "0.9.5"
5
+ __version__ = "0.10.0"
6
6
 
7
7
  import warnings
8
8
 
@@ -15,6 +15,5 @@ def _doc_params(**kwds): # pragma: no cover
15
15
 
16
16
 
17
17
  doc_common_plot_args = """\
18
- show: if `True`, shows the plot.
19
- return_fig: if `True`, returns figure of the plot.\
18
+ return_fig: if `True`, returns figure of the plot, that can be used for saving.\
20
19
  """
@@ -703,7 +703,6 @@ class CellLine(MetaData):
703
703
  metadata_key: str = "bulk_rna_broad",
704
704
  category: str = "cell line",
705
705
  subset_identifier: str | int | Iterable[str] | Iterable[int] | None = None,
706
- show: bool = True,
707
706
  return_fig: bool = False,
708
707
  ) -> Figure | None:
709
708
  """Visualise the correlation of cell lines with annotated metadata.
@@ -747,7 +746,7 @@ class CellLine(MetaData):
747
746
  if all(isinstance(id, str) for id in subset_identifier_list):
748
747
  if set(subset_identifier_list).issubset(adata.obs[identifier].unique()):
749
748
  subset_identifier_list = np.where(
750
- np.in1d(adata.obs[identifier].values, subset_identifier_list)
749
+ np.isin(adata.obs[identifier].values, subset_identifier_list)
751
750
  )[0]
752
751
  else:
753
752
  raise ValueError("`Subset_identifier` must be found in adata.obs.`identifier`.")
@@ -798,10 +797,9 @@ class CellLine(MetaData):
798
797
  },
799
798
  )
800
799
 
801
- if show:
802
- plt.show()
803
800
  if return_fig:
804
801
  return plt.gcf()
802
+ plt.show()
805
803
  return None
806
804
  else:
807
- raise NotImplementedError
805
+ raise NotImplementedError("Only 'cell line' category is supported for correlation comparison.")
@@ -1,15 +1,19 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import uuid
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Literal
5
+ from warnings import warn
5
6
 
6
7
  import matplotlib.pyplot as plt
7
8
  import numpy as np
8
9
  import pandas as pd
9
10
  import scanpy as sc
10
11
  import scipy
12
+ from rich.progress import track
13
+ from scipy.sparse import issparse
11
14
 
12
15
  from pertpy._doc import _doc_params, doc_common_plot_args
16
+ from pertpy.preprocessing._guide_rna_mixture import PoissonGaussMixture
13
17
 
14
18
  if TYPE_CHECKING:
15
19
  from anndata import AnnData
@@ -17,7 +21,7 @@ if TYPE_CHECKING:
17
21
 
18
22
 
19
23
  class GuideAssignment:
20
- """Offers simple guide assigment based on count thresholds."""
24
+ """Assign cells to guide RNAs."""
21
25
 
22
26
  def assign_by_threshold(
23
27
  self,
@@ -33,12 +37,12 @@ class GuideAssignment:
33
37
  This function expects unnormalized data as input.
34
38
 
35
39
  Args:
36
- adata: Annotated data matrix containing gRNA values
40
+ adata: AnnData object containing gRNA values.
37
41
  assignment_threshold: The count threshold that is required for an assignment to be viable.
38
42
  layer: Key to the layer containing raw count values of the gRNAs.
39
43
  adata.X is used if layer is None. Expects count data.
40
44
  output_layer: Assigned guide will be saved on adata.layers[output_key].
41
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
45
+ only_return_results: If True, the input AnnData is not modified and the result is returned as an :class:`np.ndarray`.
42
46
 
43
47
  Examples:
44
48
  Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.
@@ -67,7 +71,7 @@ class GuideAssignment:
67
71
  assignment_threshold: float,
68
72
  layer: str | None = None,
69
73
  output_key: str = "assigned_guide",
70
- no_grna_assigned_key: str = "NT",
74
+ no_grna_assigned_key: str = "Negative",
71
75
  only_return_results: bool = False,
72
76
  ) -> np.ndarray | None:
73
77
  """Simple threshold based max gRNA assignment function.
@@ -76,13 +80,13 @@ class GuideAssignment:
76
80
  This function expects unnormalized data as input.
77
81
 
78
82
  Args:
79
- adata: Annotated data matrix containing gRNA values
83
+ adata: AnnData object containing gRNA values.
80
84
  assignment_threshold: The count threshold that is required for an assignment to be viable.
81
85
  layer: Key to the layer containing raw count values of the gRNAs.
82
86
  adata.X is used if layer is None. Expects count data.
83
87
  output_key: Assigned guide will be saved on adata.obs[output_key]. default value is `assigned_guide`.
84
88
  no_grna_assigned_key: The key to return if no gRNA is expressed enough.
85
- only_return_results: If True, input AnnData is not modified and the result is returned as an np.ndarray.
89
+ only_return_results: If True, the input AnnData is not modified and the result is returned as an np.ndarray.
86
90
 
87
91
  Examples:
88
92
  Each cell is assigned to the most expressed gRNA if it has at least 5 counts.
@@ -109,6 +113,92 @@ class GuideAssignment:
109
113
 
110
114
  return None
111
115
 
116
+ def assign_mixture_model(
117
+ self,
118
+ adata: AnnData,
119
+ model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
120
+ assigned_guides_key: str = "assigned_guide",
121
+ no_grna_assigned_key: str = "negative",
122
+ max_assignments_per_cell: int = 5,
123
+ multiple_grna_assigned_key: str = "multiple",
124
+ multiple_grna_assignment_string: str = "+",
125
+ only_return_results: bool = False,
126
+ uns_key: str = "guide_assignment_params",
127
+ show_progress: bool = False,
128
+ **mixture_model_kwargs,
129
+ ) -> np.ndarray | None:
130
+ """Assigns gRNAs to cells using a mixture model.
131
+
132
+ Args:
133
+ adata: AnnData object containing gRNA values.
134
+ model: The model to use for the mixture model. Currently only `poisson_gauss_mixture` is supported.
135
+ assigned_guides_key: Assigned guide will be saved on adata.obs[assigned_guides_key].
136
+ no_grna_assigned_key: The key to return if a cell is negative for all gRNAs.
137
+ max_assignments_per_cell: The maximum number of gRNAs that can be assigned to a cell.
138
+ multiple_grna_assigned_key: The key to return if multiple gRNAs are assigned to a cell.
139
+ multiple_grna_assignment_string: The string to use to join multiple gRNAs assigned to a cell.
140
+ only_return_results: If True, the input AnnData is not modified and the result is returned as an np.ndarray.
141
+ show_progress: Whether to show a progress bar.
142
+ mixture_model_kwargs: Are passed to the mixture model.
143
+
144
+ Examples:
145
+ >>> import pertpy as pt
146
+ >>> mdata = pt.dt.papalexi_2021()
147
+ >>> gdo = mdata.mod["gdo"]
148
+ >>> ga = pt.pp.GuideAssignment()
149
+ >>> ga.assign_mixture_model(gdo)
150
+ """
151
+ if model == "poisson_gauss_mixture":
152
+ mixture_model = PoissonGaussMixture(**mixture_model_kwargs)
153
+ else:
154
+ raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")
155
+
156
+ if uns_key not in adata.uns:
157
+ adata.uns[uns_key] = {}
158
+ elif type(adata.uns[uns_key]) is not dict:
159
+ raise ValueError(f"adata.uns['{uns_key}'] should be a dictionary. Please remove it or change the key.")
160
+
161
+ res = pd.DataFrame(0, index=adata.obs_names, columns=adata.var_names)
162
+ fct = track if show_progress else lambda iterable: iterable
163
+ for gene in fct(adata.var_names):
164
+ is_nonzero = (
165
+ np.ravel((adata[:, gene].X != 0).todense()) if issparse(adata.X) else np.ravel(adata[:, gene].X != 0)
166
+ )
167
+ if sum(is_nonzero) < 2:
168
+ warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=2)
169
+ continue
170
+ # We are only fitting the model to the non-zero values, the rest is
171
+ # automatically assigned to the negative class
172
+ data = adata[is_nonzero, gene].X.todense().A1 if issparse(adata.X) else adata[is_nonzero, gene].X
173
+ data = np.ravel(data)
174
+
175
+ if np.any(data < 0):
176
+ raise ValueError(
177
+ "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
178
+ )
179
+
180
+ # Log2 transform the data so positive population is approximately normal
181
+ data = np.log2(data)
182
+ assignments = mixture_model.run_model(data)
183
+ res.loc[adata.obs_names[is_nonzero][assignments == "Positive"], gene] = 1
184
+ adata.uns[uns_key][gene] = mixture_model.params
185
+
186
+ # Assign guides to cells
187
+ # Some cells might have multiple guides assigned
188
+ series = pd.Series(no_grna_assigned_key, index=adata.obs_names)
189
+ num_guides_assigned = res.sum(1)
190
+ series.loc[(num_guides_assigned <= max_assignments_per_cell) & (num_guides_assigned != 0)] = res.apply(
191
+ lambda row: row.index[row == 1].tolist(), axis=1
192
+ ).str.join(multiple_grna_assignment_string)
193
+ series.loc[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
194
+
195
+ if only_return_results:
196
+ return series.values
197
+
198
+ adata.obs[assigned_guides_key] = series.values
199
+
200
+ return None
201
+
112
202
  @_doc_params(common_plot_args=doc_common_plot_args)
113
203
  def plot_heatmap(
114
204
  self,
@@ -117,7 +207,6 @@ class GuideAssignment:
117
207
  layer: str | None = None,
118
208
  order_by: np.ndarray | str | None = None,
119
209
  key_to_save_order: str = None,
120
- show: bool = True,
121
210
  return_fig: bool = False,
122
211
  **kwargs,
123
212
  ) -> Figure | None:
@@ -194,8 +283,7 @@ class GuideAssignment:
194
283
  finally:
195
284
  del adata.obs[temp_col_name]
196
285
 
197
- if show:
198
- plt.show()
199
286
  if return_fig:
200
287
  return fig
288
+ plt.show()
201
289
  return None
@@ -0,0 +1,179 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Mapping
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import numpyro
10
+ import numpyro.distributions as dist
11
+ from jax import random
12
+ from numpyro.infer import MCMC, NUTS
13
+
14
+ ParamsDict = Mapping[str, jnp.ndarray]
15
+
16
+
17
class MixtureModel(ABC):
    """Abstract base class for 2-component mixture models fitted with NumPyro MCMC.

    Subclasses supply the priors (:meth:`initialize_params`) and the per-component
    log likelihoods (:meth:`log_likelihood`); this base class wires them into a
    NumPyro model, runs NUTS and turns the posterior means into hard assignments.

    Args:
        num_warmup: Number of warmup steps for MCMC sampling.
        num_samples: Number of samples to draw after warmup.
        fraction_positive_expected: Prior belief about fraction of positive components.
        poisson_rate_prior: Rate parameter for exponential prior on Poisson component.
        gaussian_mean_prior: Mean and standard deviation for Gaussian prior on positive component mean.
        gaussian_std_prior: Scale parameter for half-normal prior on positive component std.
    """

    def __init__(
        self,
        num_warmup: int = 50,
        num_samples: int = 100,
        fraction_positive_expected: float = 0.15,
        poisson_rate_prior: float = 0.2,
        gaussian_mean_prior: tuple[float, float] = (3, 2),
        gaussian_std_prior: float = 1,
    ) -> None:
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self.fraction_positive_expected = fraction_positive_expected
        self.poisson_rate_prior = poisson_rate_prior
        self.gaussian_mean_prior = gaussian_mean_prior
        self.gaussian_std_prior = gaussian_std_prior

    @abstractmethod
    def initialize_params(self) -> ParamsDict:
        """Initialize model parameters via sampling from priors.

        Returns:
            Dictionary of sampled parameter values.
        """
        pass

    @abstractmethod
    def log_likelihood(self, data: jnp.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Calculate log likelihood of data under current parameters.

        Args:
            data: Input data array.
            params: Current parameter values.

        Returns:
            Log likelihood values for each datapoint, with the mixture components
            along the last axis (index 0 is treated as the negative component by
            :meth:`assignment`).
        """
        pass

    def fit_model(self, data: jnp.ndarray, seed: int = 0) -> MCMC:
        """Fit the mixture model using MCMC with a NUTS kernel.

        Args:
            data: Input data to fit.
            seed: Random seed for reproducibility.

        Returns:
            Fitted MCMC object containing samples.
        """
        nuts_kernel = NUTS(self.mixture_model)
        mcmc = MCMC(nuts_kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, progress_bar=False)
        mcmc.run(random.PRNGKey(seed), data=data)
        return mcmc

    def run_model(self, data: jnp.ndarray, seed: int = 0) -> np.ndarray:
        """Run model fitting and assign components.

        The fitted MCMC object, its samples and the resulting assignments are kept
        on the instance as ``mcmc``, ``samples`` and ``assignments``.

        Args:
            data: Input data array.
            seed: Random seed.

        Returns:
            Array of "Positive"/"Negative" assignments for each datapoint.
        """
        self.mcmc = self.fit_model(data, seed)
        self.samples = self.mcmc.get_samples()
        self.assignments = self.assignment(self.samples, data)
        return self.assignments

    def mixture_model(self, data: jnp.ndarray) -> None:
        """Define mixture model structure for NumPyro.

        Args:
            data: Input data array.
        """
        params = self.initialize_params()

        with numpyro.plate("data", data.shape[0]):
            log_likelihoods = self.log_likelihood(data, params)
            log_mixture_likelihood = jax.scipy.special.logsumexp(log_likelihoods, axis=-1)
            # NOTE(review): the observation site is a unit-variance Normal centred on
            # the log mixture likelihood, not the mixture log-density itself. It is
            # kept unchanged so fitted results stay identical - confirm this is the
            # intended likelihood.
            numpyro.sample("obs", dist.Normal(log_mixture_likelihood, 1.0), obs=data)

    def assignment(self, samples: ParamsDict, data: jnp.ndarray) -> np.ndarray:
        """Assign data points to mixture components.

        Each parameter is collapsed to its mean over the sample axis (axis 0), the
        result is stored in ``self.params``, and every datapoint is assigned to the
        component with the highest log likelihood under those mean parameters.

        Args:
            samples: MCMC samples of parameters.
            data: Input data array.

        Returns:
            Array of component assignments ("Negative" for component 0, "Positive"
            otherwise).
        """
        params = {key: draws.mean(axis=0) for key, draws in samples.items()}
        self.params = params

        log_likelihoods = self.log_likelihood(data, params)
        # Pull the component indices to the host once and label them in a single
        # vectorised step; iterating a JAX array element-wise costs one scalar
        # fetch per datapoint. This also gives a string array for empty input.
        component_index = np.asarray(jnp.argmax(log_likelihoods, axis=-1))
        return np.where(component_index == 0, "Negative", "Positive")
128
+
129
+
130
class PoissonGaussMixture(MixtureModel):
    """Two-component mixture: a Poisson component (index 0) and a Gaussian component (index 1)."""

    def log_likelihood(self, data: np.ndarray, params: ParamsDict) -> jnp.ndarray:
        """Compute the weighted log likelihood of every datapoint under each component.

        Besides returning the likelihoods, this registers a NumPyro factor named
        ``separation_penalty`` that favours parameter settings in which the Gaussian
        mean lies above the Poisson rate.

        Args:
            data: Input data array.
            params: Current parameter values; reads ``poisson_rate``, ``gaussian_mean``,
                ``gaussian_std`` and ``mix_probs``.

        Returns:
            Array with the components stacked on the last axis: Poisson first, Gaussian second.
        """
        rate = params["poisson_rate"]
        mean = params["gaussian_mean"]
        std = params["gaussian_std"]
        weights = params["mix_probs"]

        # Heuristic guard against label switching: configurations with the Gaussian
        # mean to the right of the Poisson rate receive a constant bonus, which is
        # the same as penalising the flipped ordering.
        ordering_bonus = 10 * jnp.heaviside(mean - rate, 0)
        numpyro.factor("separation_penalty", ordering_bonus)

        poisson_term = jnp.log(weights[0]) + dist.Poisson(rate).log_prob(data)
        gaussian_term = jnp.log(weights[1]) + dist.Normal(mean, std).log_prob(data)
        return jnp.stack([poisson_term, gaussian_term], axis=-1)

    def initialize_params(self) -> ParamsDict:
        """Draw every model parameter from its prior.

        The sample sites are created in a fixed order (rate, mean, std, mixing weights).

        Returns:
            Dictionary of sampled parameter values.
        """
        # Dirichlet concentration: expected negative fraction first, positive second.
        concentration = jnp.array([1 - self.fraction_positive_expected, self.fraction_positive_expected])
        mean_loc, mean_scale = self.gaussian_mean_prior
        return {
            "poisson_rate": numpyro.sample("poisson_rate", dist.Exponential(self.poisson_rate_prior)),
            "gaussian_mean": numpyro.sample("gaussian_mean", dist.Normal(mean_loc, mean_scale)),
            "gaussian_std": numpyro.sample("gaussian_std", dist.HalfNormal(self.gaussian_std_prior)),
            "mix_probs": numpyro.sample("mix_probs", dist.Dirichlet(concentration)),
        }