moscot 0.4.0__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. {moscot-0.4.0 → moscot-0.4.2}/.pre-commit-config.yaml +6 -6
  2. {moscot-0.4.0 → moscot-0.4.2}/.readthedocs.yml +1 -1
  3. {moscot-0.4.0 → moscot-0.4.2}/PKG-INFO +37 -21
  4. {moscot-0.4.0 → moscot-0.4.2}/README.rst +31 -18
  5. {moscot-0.4.0 → moscot-0.4.2}/pyproject.toml +7 -5
  6. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_types.py +5 -6
  7. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/__init__.py +10 -2
  8. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/_utils.py +3 -2
  9. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/output.py +22 -45
  10. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/solver.py +4 -4
  11. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/utils.py +1 -2
  12. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/cost.py +1 -1
  13. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/output.py +15 -29
  14. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/__init__.py +1 -2
  15. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/_utils.py +1 -1
  16. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/birth_death.py +1 -1
  17. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/compound_problem.py +6 -6
  18. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/problem.py +2 -222
  19. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/solver.py +3 -3
  20. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/costs/_costs.py +2 -2
  21. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/datasets.py +1 -1
  22. moscot-0.4.2/src/moscot/neural/base/problems/__init__.py +3 -0
  23. moscot-0.4.2/src/moscot/neural/base/problems/problem.py +243 -0
  24. moscot-0.4.2/src/moscot/neural/problems/__init__.py +3 -0
  25. moscot-0.4.2/src/moscot/neural/problems/generic/__init__.py +3 -0
  26. moscot-0.4.2/src/moscot/neural/problems/generic/_generic.py +78 -0
  27. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/plotting/_plotting.py +4 -4
  28. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/plotting/_utils.py +1 -1
  29. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/__init__.py +0 -2
  30. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/cross_modality/_mixins.py +2 -2
  31. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/cross_modality/_translation.py +2 -2
  32. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/generic/__init__.py +1 -7
  33. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/generic/_generic.py +12 -83
  34. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/generic/_mixins.py +1 -1
  35. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/_alignment.py +2 -2
  36. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/_mapping.py +4 -4
  37. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/_mixins.py +23 -9
  38. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +3 -3
  39. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/time/_lineage.py +5 -5
  40. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/time/_mixins.py +3 -3
  41. moscot-0.4.2/src/moscot/py.typed +0 -0
  42. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/subset_policy.py +3 -3
  43. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/tagged_array.py +1 -1
  44. {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/PKG-INFO +37 -21
  45. {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/SOURCES.txt +6 -0
  46. {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/requires.txt +3 -1
  47. {moscot-0.4.0 → moscot-0.4.2}/.gitignore +0 -0
  48. {moscot-0.4.0 → moscot-0.4.2}/.gitmodules +0 -0
  49. {moscot-0.4.0 → moscot-0.4.2}/.run_notebooks.sh +0 -0
  50. {moscot-0.4.0 → moscot-0.4.2}/LICENSE +0 -0
  51. {moscot-0.4.0 → moscot-0.4.2}/MANIFEST.in +0 -0
  52. {moscot-0.4.0 → moscot-0.4.2}/codecov.yml +0 -0
  53. {moscot-0.4.0 → moscot-0.4.2}/setup.cfg +0 -0
  54. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/__init__.py +0 -0
  55. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_constants.py +0 -0
  56. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_logging.py +0 -0
  57. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_registry.py +0 -0
  58. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/__init__.py +0 -0
  59. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/__init__.py +0 -0
  60. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/_mixins.py +0 -0
  61. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/manager.py +0 -0
  62. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/costs/__init__.py +0 -0
  63. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/costs/_utils.py +0 -0
  64. /moscot-0.4.0/src/moscot/py.typed → /moscot-0.4.2/src/moscot/neural/base/__init__.py +0 -0
  65. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/plotting/__init__.py +0 -0
  66. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/_utils.py +0 -0
  67. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/cross_modality/__init__.py +0 -0
  68. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/__init__.py +0 -0
  69. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  70. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/time/__init__.py +0 -0
  71. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/__init__.py +0 -0
  72. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  73. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  74. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  75. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  76. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  77. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  78. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
  79. {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/data.py +0 -0
  80. {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/dependency_links.txt +0 -0
  81. {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/top_level.txt +0 -0
@@ -7,13 +7,13 @@ default_stages:
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.13.0
10
+ rev: v1.15.0
11
11
  hooks:
12
12
  - id: mypy
13
13
  additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 24.10.0
16
+ rev: 25.1.0
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
@@ -23,7 +23,7 @@ repos:
23
23
  - id: prettier
24
24
  language_version: system
25
25
  - repo: https://github.com/PyCQA/isort
26
- rev: 5.13.2
26
+ rev: 6.0.1
27
27
  hooks:
28
28
  - id: isort
29
29
  additional_dependencies: [toml]
@@ -42,7 +42,7 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.19.0
45
+ rev: v3.19.1
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
@@ -55,7 +55,7 @@ repos:
55
55
  rev: v6.2.4
56
56
  hooks:
57
57
  - id: rstcheck
58
- additional_dependencies: [tomli]
58
+ additional_dependencies: [toml, sphinx]
59
59
  args: [--config=pyproject.toml]
60
60
  - repo: https://github.com/PyCQA/doc8
61
61
  rev: v1.1.2
@@ -63,7 +63,7 @@ repos:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.7.2
66
+ rev: v0.9.10
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -14,7 +14,7 @@ python:
14
14
  install:
15
15
  - method: pip
16
16
  path: .
17
- extra_requirements: [docs]
17
+ extra_requirements: [docs, neural]
18
18
 
19
19
  submodules:
20
20
  include: [docs/notebooks]
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: moscot
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -65,7 +65,7 @@ Requires-Dist: anndata>=0.9.1
65
65
  Requires-Dist: scanpy>=1.9.3
66
66
  Requires-Dist: wrapt>=1.13.2
67
67
  Requires-Dist: docrep>=0.3.2
68
- Requires-Dist: ott-jax[neural]>=0.5.0
68
+ Requires-Dist: ott-jax>=0.5.0
69
69
  Requires-Dist: cloudpickle>=2.2.0
70
70
  Requires-Dist: rich>=13.5
71
71
  Requires-Dist: docstring_inheritance>=2.0.0
@@ -76,6 +76,7 @@ Provides-Extra: neural
76
76
  Requires-Dist: optax; extra == "neural"
77
77
  Requires-Dist: flax; extra == "neural"
78
78
  Requires-Dist: diffrax; extra == "neural"
79
+ Requires-Dist: ott-jax[neural]>=0.5.0; extra == "neural"
79
80
  Provides-Extra: dev
80
81
  Requires-Dist: pre-commit>=3.0.0; extra == "dev"
81
82
  Requires-Dist: tox>=4; extra == "dev"
@@ -85,6 +86,7 @@ Requires-Dist: pytest-xdist>=3; extra == "test"
85
86
  Requires-Dist: pytest-mock>=3.5.0; extra == "test"
86
87
  Requires-Dist: pytest-cov>=4; extra == "test"
87
88
  Requires-Dist: coverage[toml]>=7; extra == "test"
89
+ Requires-Dist: moscot[neural]; extra == "test"
88
90
  Provides-Extra: docs
89
91
  Requires-Dist: sphinx>=5.1.1; extra == "docs"
90
92
  Requires-Dist: sphinx_copybutton>=0.5.0; extra == "docs"
@@ -96,20 +98,37 @@ Requires-Dist: sphinx-tippy>=0.4.1; extra == "docs"
96
98
  Requires-Dist: myst-nb>=0.17.1; extra == "docs"
97
99
  Requires-Dist: ipython>=7.20.0; extra == "docs"
98
100
  Requires-Dist: sphinx_design>=0.3.0; extra == "docs"
101
+ Dynamic: license-file
99
102
 
100
103
  |PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
101
104
 
102
- moscot - multi-omic single-cell optimal transport tools
105
+ Moscot - Multiomics Single-cell Optimal Transport
103
106
  =======================================================
104
107
 
105
- **moscot** is a scalable framework for Optimal Transport (OT) applications in
106
- single-cell genomics. It can be used for
108
+ .. image:: docs/_static/img/light_mode_concept_revised.png
109
+ :width: 800px
110
+ :align: center
111
+ :class: only-light
112
+
113
+ .. image:: docs/_static/img/dark_mode_concept_revised.png
114
+ :width: 800px
115
+ :align: center
116
+ :class: only-dark
117
+
118
+
119
+ **moscot** is a framework for Optimal Transport (OT) applications in
120
+ single-cell genomics. It scales to large datasets and can be used for a
121
+ variety of applications across different modalities.
122
+
123
+ moscot's key applications
124
+ ---------------------------
125
+ - Trajectory inference (incorporating spatial and lineage information).
126
+ - Mapping cells to their spatial organisation.
127
+ - Aligning spatial transcriptomics slides.
128
+ - Translating modalities.
129
+ - Prototyping of new OT models in single-cell genomics.
130
+ - ... and more, check out the `documentation <https://moscot.readthedocs.io>`_ for more information.
107
131
 
108
- - trajectory inference (incorporating spatial and lineage information)
109
- - mapping cells to their spatial organisation
110
- - aligning spatial transcriptomics slides
111
- - translating modalities
112
- - prototyping of new OT models in single-cell genomics
113
132
 
114
133
  **moscot** is powered by
115
134
  `OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
@@ -118,7 +137,7 @@ differentiation and linear memory complexity for OT problems.
118
137
 
119
138
  Installation
120
139
  ------------
121
- You can install **moscot** via::
140
+ Install **moscot** by running::
122
141
 
123
142
  pip install moscot
124
143
 
@@ -130,15 +149,10 @@ In order to install **moscot** from in editable mode, run::
130
149
 
131
150
  For further instructions on how to install JAX, please refer to https://github.com/google/jax.
132
151
 
133
- Resources
134
- ---------
135
-
136
- Please have a look at our `documentation <https://moscot.readthedocs.io>`_
137
-
138
- Reference
139
- ---------
140
-
141
- Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
152
+ Citing moscot
153
+ -------------
154
+ If you find a model useful for your research, please consider citing the `Klein et al., 2025`_ manuscript as
155
+ well as the publication introducing the model, which can be found in the corresponding documentation.
142
156
 
143
157
  .. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
144
158
  :target: https://codecov.io/gh/theislab/moscot
@@ -163,3 +177,5 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
163
177
  .. |Downloads| image:: https://static.pepy.tech/badge/moscot
164
178
  :target: https://pepy.tech/project/moscot
165
179
  :alt: Downloads
180
+
181
+ .. _Klein et al., 2025: https://www.nature.com/articles/s41586-024-08453-2
@@ -1,16 +1,32 @@
1
1
  |PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
2
2
 
3
- moscot - multi-omic single-cell optimal transport tools
3
+ Moscot - Multiomics Single-cell Optimal Transport
4
4
  =======================================================
5
5
 
6
- **moscot** is a scalable framework for Optimal Transport (OT) applications in
7
- single-cell genomics. It can be used for
6
+ .. image:: docs/_static/img/light_mode_concept_revised.png
7
+ :width: 800px
8
+ :align: center
9
+ :class: only-light
10
+
11
+ .. image:: docs/_static/img/dark_mode_concept_revised.png
12
+ :width: 800px
13
+ :align: center
14
+ :class: only-dark
15
+
16
+
17
+ **moscot** is a framework for Optimal Transport (OT) applications in
18
+ single-cell genomics. It scales to large datasets and can be used for a
19
+ variety of applications across different modalities.
20
+
21
+ moscot's key applications
22
+ ---------------------------
23
+ - Trajectory inference (incorporating spatial and lineage information).
24
+ - Mapping cells to their spatial organisation.
25
+ - Aligning spatial transcriptomics slides.
26
+ - Translating modalities.
27
+ - Prototyping of new OT models in single-cell genomics.
28
+ - ... and more, check out the `documentation <https://moscot.readthedocs.io>`_ for more information.
8
29
 
9
- - trajectory inference (incorporating spatial and lineage information)
10
- - mapping cells to their spatial organisation
11
- - aligning spatial transcriptomics slides
12
- - translating modalities
13
- - prototyping of new OT models in single-cell genomics
14
30
 
15
31
  **moscot** is powered by
16
32
  `OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
@@ -19,7 +35,7 @@ differentiation and linear memory complexity for OT problems.
19
35
 
20
36
  Installation
21
37
  ------------
22
- You can install **moscot** via::
38
+ Install **moscot** by running::
23
39
 
24
40
  pip install moscot
25
41
 
@@ -31,15 +47,10 @@ In order to install **moscot** from in editable mode, run::
31
47
 
32
48
  For further instructions on how to install JAX, please refer to https://github.com/google/jax.
33
49
 
34
- Resources
35
- ---------
36
-
37
- Please have a look at our `documentation <https://moscot.readthedocs.io>`_
38
-
39
- Reference
40
- ---------
41
-
42
- Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
50
+ Citing moscot
51
+ -------------
52
+ If you find a model useful for your research, please consider citing the `Klein et al., 2025`_ manuscript as
53
+ well as the publication introducing the model, which can be found in the corresponding documentation.
43
54
 
44
55
  .. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
45
56
  :target: https://codecov.io/gh/theislab/moscot
@@ -64,3 +75,5 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
64
75
  .. |Downloads| image:: https://static.pepy.tech/badge/moscot
65
76
  :target: https://pepy.tech/project/moscot
66
77
  :alt: Downloads
78
+
79
+ .. _Klein et al., 2025: https://www.nature.com/articles/s41586-024-08453-2
@@ -54,7 +54,7 @@ dependencies = [
54
54
  "scanpy>=1.9.3",
55
55
  "wrapt>=1.13.2",
56
56
  "docrep>=0.3.2",
57
- "ott-jax[neural]>=0.5.0",
57
+ "ott-jax>=0.5.0",
58
58
  "cloudpickle>=2.2.0",
59
59
  "rich>=13.5",
60
60
  "docstring_inheritance>=2.0.0",
@@ -70,6 +70,7 @@ neural = [
70
70
  "optax",
71
71
  "flax",
72
72
  "diffrax",
73
+ "ott-jax[neural]>=0.5.0",
73
74
 
74
75
  ]
75
76
  dev = [
@@ -82,6 +83,7 @@ test = [
82
83
  "pytest-mock>=3.5.0",
83
84
  "pytest-cov>=4",
84
85
  "coverage[toml]>=7",
86
+ "moscot[neural]"
85
87
  ]
86
88
  docs = [
87
89
  "sphinx>=5.1.1",
@@ -274,7 +276,7 @@ env_list = lint-code,py{3.10,3.11,3.12}
274
276
  skip_missing_interpreters = true
275
277
 
276
278
  [testenv]
277
- extras = test
279
+ extras = test,neural
278
280
  commands =
279
281
  python -m pytest {tty:--color=yes} {posargs: \
280
282
  --cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
@@ -290,7 +292,7 @@ commands =
290
292
 
291
293
  [testenv:lint-docs]
292
294
  description = Lint the documentation.
293
- extras = docs
295
+ extras = docs,neural
294
296
  ignore_errors = true
295
297
  allowlist_externals = make
296
298
  pass_env = PYENCHANT_LIBRARY_PATH
@@ -310,7 +312,7 @@ deps =
310
312
  jupytext
311
313
  nbconvert
312
314
  leidenalg
313
- extras = docs
315
+ extras = docs,neural
314
316
  changedir = {tox_root}{/}docs
315
317
  commands =
316
318
  python -m ipykernel install --user --name=moscot
@@ -328,7 +330,7 @@ commands =
328
330
  [testenv:build-docs]
329
331
  description = Build the documentation.
330
332
  deps =
331
- extras = docs
333
+ extras = docs,neural
332
334
  allowlist_externals = make
333
335
  changedir = {tox_root}{/}docs
334
336
  commands =
@@ -2,19 +2,18 @@ import os
2
2
  from typing import Any, Literal, Mapping, Optional, Sequence, Union
3
3
 
4
4
  import numpy as np
5
+ from jax import Array as JaxArray
6
+ from numpy.typing import DTypeLike as DTypeLikeNumpy
7
+ from numpy.typing import NDArray
5
8
  from ott.initializers.linear.initializers import SinkhornInitializer
6
9
  from ott.initializers.linear.initializers_lr import LRInitializer
7
10
  from ott.initializers.quadratic.initializers import BaseQuadraticInitializer
8
11
 
9
12
  # TODO(michalk8): polish
10
13
 
11
- try:
12
- from numpy.typing import DTypeLike, NDArray
13
14
 
14
- ArrayLike = NDArray[np.float64]
15
- except (ImportError, TypeError):
16
- ArrayLike = np.ndarray # type: ignore[misc]
17
- DTypeLike = np.dtype # type: ignore[misc]
15
+ ArrayLike = Union[NDArray[np.floating], JaxArray]
16
+ DTypeLike = DTypeLikeNumpy
18
17
 
19
18
  ProblemKind_t = Literal["linear", "quadratic", "unknown"]
20
19
  Numeric_t = Union[int, float] # type of `time_key` arguments
@@ -1,11 +1,19 @@
1
1
  from ott.geometry import costs
2
2
 
3
3
  from moscot.backends.ott._utils import sinkhorn_divergence
4
- from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
4
+ from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
5
5
  from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
6
6
  from moscot.costs import register_cost
7
7
 
8
- __all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]
8
+ __all__ = [
9
+ "OTTOutput",
10
+ "GWSolver",
11
+ "SinkhornSolver",
12
+ "NeuralOutput",
13
+ "sinkhorn_divergence",
14
+ "GENOTLinSolver",
15
+ "GraphOTTOutput",
16
+ ]
9
17
 
10
18
 
11
19
  register_cost("euclidean", backend="ott")(costs.Euclidean)
@@ -184,7 +184,7 @@ def alpha_to_fused_penalty(alpha: float) -> float:
184
184
  return (1 - alpha) / alpha
185
185
 
186
186
 
187
- def densify(arr: ArrayLike) -> jax.Array:
187
+ def densify(arr: Union[ArrayLike, sp.sparray, sp.spmatrix]) -> jax.Array:
188
188
  """If the input is sparse, convert it to dense.
189
189
 
190
190
  Parameters
@@ -197,7 +197,8 @@ def densify(arr: ArrayLike) -> jax.Array:
197
197
  dense :mod:`jax` array.
198
198
  """
199
199
  if sp.issparse(arr):
200
- arr = arr.toarray() # type: ignore[attr-defined]
200
+ arr_sp: Union[sp.sparray, sp.spmatrix] = arr
201
+ arr = arr_sp.toarray()
201
202
  elif isinstance(arr, jesp.BCOO):
202
203
  arr = arr.todense()
203
204
  return jnp.asarray(arr)
@@ -17,7 +17,7 @@ from moscot._types import ArrayLike, Device_t
17
17
  from moscot.backends.ott._utils import get_nearest_neighbors
18
18
  from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput
19
19
 
20
- __all__ = ["OTTOutput", "GraphOTTOutput", "OTTNeuralOutput"]
20
+ __all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"]
21
21
 
22
22
 
23
23
  class OTTOutput(BaseDiscreteSolverOutput):
@@ -182,6 +182,9 @@ class OTTOutput(BaseDiscreteSolverOutput):
182
182
  axis=1 - forward,
183
183
  ).T # convert to batch first
184
184
 
185
+ def _apply_forward(self, x: ArrayLike) -> ArrayLike:
186
+ return self._apply(x, forward=True)
187
+
185
188
  @property
186
189
  def shape(self) -> Tuple[int, int]: # noqa: D102
187
190
  if isinstance(self._output, sinkhorn.SinkhornOutput):
@@ -241,11 +244,11 @@ class OTTOutput(BaseDiscreteSolverOutput):
241
244
  return jnp.ones((n,))
242
245
 
243
246
 
244
- class OTTNeuralOutput(BaseNeuralOutput):
247
+ class NeuralOutput(BaseNeuralOutput):
245
248
  """Output wrapper for GENOT."""
246
249
 
247
250
  def __init__(self, model: GENOT, logs: dict[str, list[float]]):
248
- """Initialize `OTTNeuralOutput`.
251
+ """Initialize `NeuralOutput`.
249
252
 
250
253
  Parameters
251
254
  ----------
@@ -269,8 +272,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
269
272
  self,
270
273
  src_dist: ArrayLike,
271
274
  tgt_dist: ArrayLike,
272
- forward: bool,
273
- func: Callable[[jnp.ndarray], jnp.ndarray],
275
+ func: Callable[[ArrayLike], ArrayLike],
274
276
  save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
275
277
  batch_size: int = 1024,
276
278
  k: int = 30,
@@ -279,9 +281,9 @@ class OTTNeuralOutput(BaseNeuralOutput):
279
281
  recall_target: float = 0.95,
280
282
  aggregate_to_topk: bool = True,
281
283
  ) -> sp.csr_matrix:
282
- row_indices: Union[jnp.ndarray, List[jnp.ndarray]] = []
283
- column_indices: Union[jnp.ndarray, List[jnp.ndarray]] = []
284
- distances_list: Union[jnp.ndarray, List[jnp.ndarray]] = []
284
+ row_indices: List[ArrayLike] = []
285
+ column_indices: List[ArrayLike] = []
286
+ distances_list: List[ArrayLike] = []
285
287
  if length_scale is None:
286
288
  key = jax.random.PRNGKey(seed)
287
289
  src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
@@ -306,20 +308,14 @@ class OTTNeuralOutput(BaseNeuralOutput):
306
308
  row_indices = jnp.concatenate(row_indices)
307
309
  column_indices = jnp.concatenate(column_indices)
308
310
  tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
309
- if forward:
310
- if save_transport_matrix:
311
- self._transport_matrix = tm
312
- else:
313
- tm = tm.T
314
- if save_transport_matrix:
315
- self._inverse_transport_matrix = tm
311
+ if save_transport_matrix:
312
+ self._transport_matrix = tm
316
313
  return tm
317
314
 
318
315
  def project_to_transport_matrix( # type:ignore[override]
319
316
  self,
320
317
  src_cells: ArrayLike,
321
318
  tgt_cells: ArrayLike,
322
- forward: bool = True,
323
319
  condition: ArrayLike = None,
324
320
  save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
325
321
  batch_size: int = 1024,
@@ -351,7 +347,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
351
347
  save_transport_matrix
352
348
  Whether to save the transport matrix.
353
349
  batch_size
354
- Number of data points in the source distribution the neighborhoodgraph is computed
350
+ Number of data points in the source distribution the neighborhood graph is computed
355
351
  for in parallel.
356
352
  k
357
353
  Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
@@ -375,13 +371,12 @@ class OTTNeuralOutput(BaseNeuralOutput):
375
371
  The projected transport matrix.
376
372
  """
377
373
  src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
378
- push = self.push if condition is None else lambda x: self.push(x, condition)
379
- pull = self.pull if condition is None else lambda x: self.pull(x, condition)
380
- func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells)
374
+ conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
375
+ push = self.push if condition is None else conditioned_fn
376
+ func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
381
377
  return self._project_transport_matrix(
382
378
  src_dist=src_dist,
383
379
  tgt_dist=tgt_dist,
384
- forward=forward,
385
380
  func=func,
386
381
  save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
387
382
  batch_size=batch_size,
@@ -406,31 +401,13 @@ class OTTNeuralOutput(BaseNeuralOutput):
406
401
  -------
407
402
  Pushed distribution.
408
403
  """
404
+ if isinstance(x, (bool, int, float, complex)):
405
+ raise ValueError("Expected array, found scalar value.")
409
406
  if x.ndim not in (1, 2):
410
407
  raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
411
- return self._apply(x, cond=cond, forward=True)
412
-
413
- def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
414
- """Pull distribution `x` conditioned on condition `cond`.
415
-
416
- This does not make sense for some neural models and is therefore left unimplemented.
417
-
418
- Parameters
419
- ----------
420
- x
421
- Distribution to push.
422
- cond
423
- Condition of conditional neural OT.
424
-
425
- Raises
426
- ------
427
- NotImplementedError
428
- """
429
- raise NotImplementedError("`pull` does not make sense for neural OT.")
408
+ return self._apply_forward(x, cond=cond)
430
409
 
431
- def _apply(self, x: ArrayLike, forward: bool, cond: Optional[ArrayLike] = None) -> ArrayLike:
432
- if not forward:
433
- raise NotImplementedError("Backward i.e., pull on neural OT is not supported.")
410
+ def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
434
411
  return self._model.transport(x, condition=cond)
435
412
 
436
413
  @property
@@ -445,7 +422,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
445
422
  def to(
446
423
  self,
447
424
  device: Optional[Device_t] = None,
448
- ) -> "OTTNeuralOutput":
425
+ ) -> "NeuralOutput":
449
426
  """Transfer the output to another device or change its data type.
450
427
 
451
428
  Parameters
@@ -471,7 +448,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
471
448
  # raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err
472
449
 
473
450
  # out = jax.device_put(self._model, device)
474
- # return OTTNeuralOutput(out)
451
+ # return NeuralOutput(out)
475
452
  return self # TODO(ilan-gold) move model to device
476
453
 
477
454
  @property
@@ -53,7 +53,7 @@ from moscot.backends.ott._utils import (
53
53
  densify,
54
54
  ensure_2d,
55
55
  )
56
- from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
56
+ from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
57
57
  from moscot.base.problems._utils import TimeScalesHeatKernel
58
58
  from moscot.base.solver import OTSolver
59
59
  from moscot.costs import get_cost
@@ -216,7 +216,7 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
216
216
  problem_shape = x.shape if problem_shape is None else problem_shape
217
217
  return _instantiate_geodesic_cost(
218
218
  arr=arr,
219
- problem_shape=problem_shape,
219
+ problem_shape=problem_shape, # type: ignore[arg-type]
220
220
  t=t,
221
221
  is_linear_term=is_linear_term,
222
222
  epsilon=epsilon,
@@ -699,10 +699,10 @@ class GENOTLinSolver(OTSolver[OTTOutput]):
699
699
  def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
700
700
  return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value]
701
701
 
702
- def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> OTTNeuralOutput: # type: ignore[override]
702
+ def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override]
703
703
  seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests
704
704
  rng = jax.random.PRNGKey(seed)
705
705
  logs = self.solver(
706
706
  data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng
707
707
  ) # TODO(ilan-gold): validation and figure out defualts
708
- return OTTNeuralOutput(self.solver, logs)
708
+ return NeuralOutput(self.solver, logs)
@@ -42,8 +42,7 @@ def register_solver(
42
42
  return _REGISTRY.register(backend) # type: ignore[return-value]
43
43
 
44
44
 
45
- # TODO(@MUCDK) fix mypy error
46
- @register_solver("ott") # type: ignore[arg-type]
45
+ @register_solver("ott")
47
46
  def _(
48
47
  problem_kind: Literal["linear", "quadratic"],
49
48
  solver_name: Optional[Literal["GENOTLinSolver"]] = None,
@@ -58,7 +58,7 @@ class BaseCost(abc.ABC):
58
58
  f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, "
59
59
  f"setting them to the maximum value `{maxx}`."
60
60
  )
61
- cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload]
61
+ cost = np.nan_to_num(cost, nan=maxx) # type: ignore[arg-type, type-var]
62
62
  if np.any(cost < 0):
63
63
  raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.")
64
64
  return cost