moscot 0.4.0__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {moscot-0.4.0 → moscot-0.4.2}/.pre-commit-config.yaml +6 -6
- {moscot-0.4.0 → moscot-0.4.2}/.readthedocs.yml +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/PKG-INFO +37 -21
- {moscot-0.4.0 → moscot-0.4.2}/README.rst +31 -18
- {moscot-0.4.0 → moscot-0.4.2}/pyproject.toml +7 -5
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_types.py +5 -6
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/__init__.py +10 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/_utils.py +3 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/output.py +22 -45
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/ott/solver.py +4 -4
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/utils.py +1 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/cost.py +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/output.py +15 -29
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/__init__.py +1 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/_utils.py +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/birth_death.py +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/compound_problem.py +6 -6
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/problem.py +2 -222
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/solver.py +3 -3
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/costs/_costs.py +2 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/datasets.py +1 -1
- moscot-0.4.2/src/moscot/neural/base/problems/__init__.py +3 -0
- moscot-0.4.2/src/moscot/neural/base/problems/problem.py +243 -0
- moscot-0.4.2/src/moscot/neural/problems/__init__.py +3 -0
- moscot-0.4.2/src/moscot/neural/problems/generic/__init__.py +3 -0
- moscot-0.4.2/src/moscot/neural/problems/generic/_generic.py +78 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/plotting/_plotting.py +4 -4
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/plotting/_utils.py +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/__init__.py +0 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/cross_modality/_mixins.py +2 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/cross_modality/_translation.py +2 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/generic/__init__.py +1 -7
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/generic/_generic.py +12 -83
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/generic/_mixins.py +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/_alignment.py +2 -2
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/_mapping.py +4 -4
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/_mixins.py +23 -9
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +3 -3
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/time/_lineage.py +5 -5
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/time/_mixins.py +3 -3
- moscot-0.4.2/src/moscot/py.typed +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/subset_policy.py +3 -3
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/tagged_array.py +1 -1
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/PKG-INFO +37 -21
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/SOURCES.txt +6 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/requires.txt +3 -1
- {moscot-0.4.0 → moscot-0.4.2}/.gitignore +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/.gitmodules +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/.run_notebooks.sh +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/LICENSE +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/MANIFEST.in +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/codecov.yml +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/setup.cfg +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_constants.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_logging.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/_registry.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/_mixins.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/base/problems/manager.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/costs/_utils.py +0 -0
- /moscot-0.4.0/src/moscot/py.typed → /moscot-0.4.2/src/moscot/neural/base/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/_utils.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot/utils/data.py +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.4.0 → moscot-0.4.2}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -7,13 +7,13 @@ default_stages:
|
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.15.0
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
13
|
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev:
|
|
16
|
+
rev: 25.1.0
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
@@ -23,7 +23,7 @@ repos:
|
|
|
23
23
|
- id: prettier
|
|
24
24
|
language_version: system
|
|
25
25
|
- repo: https://github.com/PyCQA/isort
|
|
26
|
-
rev:
|
|
26
|
+
rev: 6.0.1
|
|
27
27
|
hooks:
|
|
28
28
|
- id: isort
|
|
29
29
|
additional_dependencies: [toml]
|
|
@@ -42,7 +42,7 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.19.
|
|
45
|
+
rev: v3.19.1
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
@@ -55,7 +55,7 @@ repos:
|
|
|
55
55
|
rev: v6.2.4
|
|
56
56
|
hooks:
|
|
57
57
|
- id: rstcheck
|
|
58
|
-
additional_dependencies: [
|
|
58
|
+
additional_dependencies: [toml, sphinx]
|
|
59
59
|
args: [--config=pyproject.toml]
|
|
60
60
|
- repo: https://github.com/PyCQA/doc8
|
|
61
61
|
rev: v1.1.2
|
|
@@ -63,7 +63,7 @@ repos:
|
|
|
63
63
|
- id: doc8
|
|
64
64
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.
|
|
66
|
+
rev: v0.9.10
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.4.0
|
|
3
|
+
Version: 0.4.2
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -65,7 +65,7 @@ Requires-Dist: anndata>=0.9.1
|
|
|
65
65
|
Requires-Dist: scanpy>=1.9.3
|
|
66
66
|
Requires-Dist: wrapt>=1.13.2
|
|
67
67
|
Requires-Dist: docrep>=0.3.2
|
|
68
|
-
Requires-Dist: ott-jax
|
|
68
|
+
Requires-Dist: ott-jax>=0.5.0
|
|
69
69
|
Requires-Dist: cloudpickle>=2.2.0
|
|
70
70
|
Requires-Dist: rich>=13.5
|
|
71
71
|
Requires-Dist: docstring_inheritance>=2.0.0
|
|
@@ -76,6 +76,7 @@ Provides-Extra: neural
|
|
|
76
76
|
Requires-Dist: optax; extra == "neural"
|
|
77
77
|
Requires-Dist: flax; extra == "neural"
|
|
78
78
|
Requires-Dist: diffrax; extra == "neural"
|
|
79
|
+
Requires-Dist: ott-jax[neural]>=0.5.0; extra == "neural"
|
|
79
80
|
Provides-Extra: dev
|
|
80
81
|
Requires-Dist: pre-commit>=3.0.0; extra == "dev"
|
|
81
82
|
Requires-Dist: tox>=4; extra == "dev"
|
|
@@ -85,6 +86,7 @@ Requires-Dist: pytest-xdist>=3; extra == "test"
|
|
|
85
86
|
Requires-Dist: pytest-mock>=3.5.0; extra == "test"
|
|
86
87
|
Requires-Dist: pytest-cov>=4; extra == "test"
|
|
87
88
|
Requires-Dist: coverage[toml]>=7; extra == "test"
|
|
89
|
+
Requires-Dist: moscot[neural]; extra == "test"
|
|
88
90
|
Provides-Extra: docs
|
|
89
91
|
Requires-Dist: sphinx>=5.1.1; extra == "docs"
|
|
90
92
|
Requires-Dist: sphinx_copybutton>=0.5.0; extra == "docs"
|
|
@@ -96,20 +98,37 @@ Requires-Dist: sphinx-tippy>=0.4.1; extra == "docs"
|
|
|
96
98
|
Requires-Dist: myst-nb>=0.17.1; extra == "docs"
|
|
97
99
|
Requires-Dist: ipython>=7.20.0; extra == "docs"
|
|
98
100
|
Requires-Dist: sphinx_design>=0.3.0; extra == "docs"
|
|
101
|
+
Dynamic: license-file
|
|
99
102
|
|
|
100
103
|
|PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
|
|
101
104
|
|
|
102
|
-
|
|
105
|
+
Moscot - Multiomics Single-cell Optimal Transport
|
|
103
106
|
=======================================================
|
|
104
107
|
|
|
105
|
-
|
|
106
|
-
|
|
108
|
+
.. image:: docs/_static/img/light_mode_concept_revised.png
|
|
109
|
+
:width: 800px
|
|
110
|
+
:align: center
|
|
111
|
+
:class: only-light
|
|
112
|
+
|
|
113
|
+
.. image:: docs/_static/img/dark_mode_concept_revised.png
|
|
114
|
+
:width: 800px
|
|
115
|
+
:align: center
|
|
116
|
+
:class: only-dark
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
**moscot** is a framework for Optimal Transport (OT) applications in
|
|
120
|
+
single-cell genomics. It scales to large datasets and can be used for a
|
|
121
|
+
variety of applications across different modalities.
|
|
122
|
+
|
|
123
|
+
moscot's key applications
|
|
124
|
+
---------------------------
|
|
125
|
+
- Trajectory inference (incorporating spatial and lineage information).
|
|
126
|
+
- Mapping cells to their spatial organisation.
|
|
127
|
+
- Aligning spatial transcriptomics slides.
|
|
128
|
+
- Translating modalities.
|
|
129
|
+
- prototyping of new OT models in single-cell genomics.
|
|
130
|
+
- ... and more, check out the `documentation <https://moscot.readthedocs.io>`_ for more information.
|
|
107
131
|
|
|
108
|
-
- trajectory inference (incorporating spatial and lineage information)
|
|
109
|
-
- mapping cells to their spatial organisation
|
|
110
|
-
- aligning spatial transcriptomics slides
|
|
111
|
-
- translating modalities
|
|
112
|
-
- prototyping of new OT models in single-cell genomics
|
|
113
132
|
|
|
114
133
|
**moscot** is powered by
|
|
115
134
|
`OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
|
|
@@ -118,7 +137,7 @@ differentiation and linear memory complexity for OT problems.
|
|
|
118
137
|
|
|
119
138
|
Installation
|
|
120
139
|
------------
|
|
121
|
-
|
|
140
|
+
Install **moscot** by running::
|
|
122
141
|
|
|
123
142
|
pip install moscot
|
|
124
143
|
|
|
@@ -130,15 +149,10 @@ In order to install **moscot** from in editable mode, run::
|
|
|
130
149
|
|
|
131
150
|
For further instructions how to install jax, please refer to https://github.com/google/jax.
|
|
132
151
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
Reference
|
|
139
|
-
---------
|
|
140
|
-
|
|
141
|
-
Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
|
|
152
|
+
Citing moscot
|
|
153
|
+
-------------
|
|
154
|
+
If you find a model useful for your research, please consider citing the `Klein et al., 2025`_ manuscript as
|
|
155
|
+
well as the publication introducing the model, which can be found in the corresponding documentation.
|
|
142
156
|
|
|
143
157
|
.. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
|
|
144
158
|
:target: https://codecov.io/gh/theislab/moscot
|
|
@@ -163,3 +177,5 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
|
|
|
163
177
|
.. |Downloads| image:: https://static.pepy.tech/badge/moscot
|
|
164
178
|
:target: https://pepy.tech/project/moscot
|
|
165
179
|
:alt: Downloads
|
|
180
|
+
|
|
181
|
+
.. _Klein et al., 2025: https://www.nature.com/articles/s41586-024-08453-2
|
|
@@ -1,16 +1,32 @@
|
|
|
1
1
|
|PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
Moscot - Multiomics Single-cell Optimal Transport
|
|
4
4
|
=======================================================
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
.. image:: docs/_static/img/light_mode_concept_revised.png
|
|
7
|
+
:width: 800px
|
|
8
|
+
:align: center
|
|
9
|
+
:class: only-light
|
|
10
|
+
|
|
11
|
+
.. image:: docs/_static/img/dark_mode_concept_revised.png
|
|
12
|
+
:width: 800px
|
|
13
|
+
:align: center
|
|
14
|
+
:class: only-dark
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
**moscot** is a framework for Optimal Transport (OT) applications in
|
|
18
|
+
single-cell genomics. It scales to large datasets and can be used for a
|
|
19
|
+
variety of applications across different modalities.
|
|
20
|
+
|
|
21
|
+
moscot's key applications
|
|
22
|
+
---------------------------
|
|
23
|
+
- Trajectory inference (incorporating spatial and lineage information).
|
|
24
|
+
- Mapping cells to their spatial organisation.
|
|
25
|
+
- Aligning spatial transcriptomics slides.
|
|
26
|
+
- Translating modalities.
|
|
27
|
+
- prototyping of new OT models in single-cell genomics.
|
|
28
|
+
- ... and more, check out the `documentation <https://moscot.readthedocs.io>`_ for more information.
|
|
8
29
|
|
|
9
|
-
- trajectory inference (incorporating spatial and lineage information)
|
|
10
|
-
- mapping cells to their spatial organisation
|
|
11
|
-
- aligning spatial transcriptomics slides
|
|
12
|
-
- translating modalities
|
|
13
|
-
- prototyping of new OT models in single-cell genomics
|
|
14
30
|
|
|
15
31
|
**moscot** is powered by
|
|
16
32
|
`OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
|
|
@@ -19,7 +35,7 @@ differentiation and linear memory complexity for OT problems.
|
|
|
19
35
|
|
|
20
36
|
Installation
|
|
21
37
|
------------
|
|
22
|
-
|
|
38
|
+
Install **moscot** by running::
|
|
23
39
|
|
|
24
40
|
pip install moscot
|
|
25
41
|
|
|
@@ -31,15 +47,10 @@ In order to install **moscot** from in editable mode, run::
|
|
|
31
47
|
|
|
32
48
|
For further instructions how to install jax, please refer to https://github.com/google/jax.
|
|
33
49
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
Reference
|
|
40
|
-
---------
|
|
41
|
-
|
|
42
|
-
Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
|
|
50
|
+
Citing moscot
|
|
51
|
+
-------------
|
|
52
|
+
If you find a model useful for your research, please consider citing the `Klein et al., 2025`_ manuscript as
|
|
53
|
+
well as the publication introducing the model, which can be found in the corresponding documentation.
|
|
43
54
|
|
|
44
55
|
.. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
|
|
45
56
|
:target: https://codecov.io/gh/theislab/moscot
|
|
@@ -64,3 +75,5 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
|
|
|
64
75
|
.. |Downloads| image:: https://static.pepy.tech/badge/moscot
|
|
65
76
|
:target: https://pepy.tech/project/moscot
|
|
66
77
|
:alt: Downloads
|
|
78
|
+
|
|
79
|
+
.. _Klein et al., 2025: https://www.nature.com/articles/s41586-024-08453-2
|
|
@@ -54,7 +54,7 @@ dependencies = [
|
|
|
54
54
|
"scanpy>=1.9.3",
|
|
55
55
|
"wrapt>=1.13.2",
|
|
56
56
|
"docrep>=0.3.2",
|
|
57
|
-
"ott-jax
|
|
57
|
+
"ott-jax>=0.5.0",
|
|
58
58
|
"cloudpickle>=2.2.0",
|
|
59
59
|
"rich>=13.5",
|
|
60
60
|
"docstring_inheritance>=2.0.0",
|
|
@@ -70,6 +70,7 @@ neural = [
|
|
|
70
70
|
"optax",
|
|
71
71
|
"flax",
|
|
72
72
|
"diffrax",
|
|
73
|
+
"ott-jax[neural]>=0.5.0",
|
|
73
74
|
|
|
74
75
|
]
|
|
75
76
|
dev = [
|
|
@@ -82,6 +83,7 @@ test = [
|
|
|
82
83
|
"pytest-mock>=3.5.0",
|
|
83
84
|
"pytest-cov>=4",
|
|
84
85
|
"coverage[toml]>=7",
|
|
86
|
+
"moscot[neural]"
|
|
85
87
|
]
|
|
86
88
|
docs = [
|
|
87
89
|
"sphinx>=5.1.1",
|
|
@@ -274,7 +276,7 @@ env_list = lint-code,py{3.10,3.11,3.12}
|
|
|
274
276
|
skip_missing_interpreters = true
|
|
275
277
|
|
|
276
278
|
[testenv]
|
|
277
|
-
extras = test
|
|
279
|
+
extras = test,neural
|
|
278
280
|
commands =
|
|
279
281
|
python -m pytest {tty:--color=yes} {posargs: \
|
|
280
282
|
--cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
|
|
@@ -290,7 +292,7 @@ commands =
|
|
|
290
292
|
|
|
291
293
|
[testenv:lint-docs]
|
|
292
294
|
description = Lint the documentation.
|
|
293
|
-
extras = docs
|
|
295
|
+
extras = docs,neural
|
|
294
296
|
ignore_errors = true
|
|
295
297
|
allowlist_externals = make
|
|
296
298
|
pass_env = PYENCHANT_LIBRARY_PATH
|
|
@@ -310,7 +312,7 @@ deps =
|
|
|
310
312
|
jupytext
|
|
311
313
|
nbconvert
|
|
312
314
|
leidenalg
|
|
313
|
-
extras = docs
|
|
315
|
+
extras = docs,neural
|
|
314
316
|
changedir = {tox_root}{/}docs
|
|
315
317
|
commands =
|
|
316
318
|
python -m ipykernel install --user --name=moscot
|
|
@@ -328,7 +330,7 @@ commands =
|
|
|
328
330
|
[testenv:build-docs]
|
|
329
331
|
description = Build the documentation.
|
|
330
332
|
deps =
|
|
331
|
-
extras = docs
|
|
333
|
+
extras = docs,neural
|
|
332
334
|
allowlist_externals = make
|
|
333
335
|
changedir = {tox_root}{/}docs
|
|
334
336
|
commands =
|
|
@@ -2,19 +2,18 @@ import os
|
|
|
2
2
|
from typing import Any, Literal, Mapping, Optional, Sequence, Union
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
|
+
from jax import Array as JaxArray
|
|
6
|
+
from numpy.typing import DTypeLike as DTypeLikeNumpy
|
|
7
|
+
from numpy.typing import NDArray
|
|
5
8
|
from ott.initializers.linear.initializers import SinkhornInitializer
|
|
6
9
|
from ott.initializers.linear.initializers_lr import LRInitializer
|
|
7
10
|
from ott.initializers.quadratic.initializers import BaseQuadraticInitializer
|
|
8
11
|
|
|
9
12
|
# TODO(michalk8): polish
|
|
10
13
|
|
|
11
|
-
try:
|
|
12
|
-
from numpy.typing import DTypeLike, NDArray
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
ArrayLike = np.ndarray # type: ignore[misc]
|
|
17
|
-
DTypeLike = np.dtype # type: ignore[misc]
|
|
15
|
+
ArrayLike = Union[NDArray[np.floating], JaxArray]
|
|
16
|
+
DTypeLike = DTypeLikeNumpy
|
|
18
17
|
|
|
19
18
|
ProblemKind_t = Literal["linear", "quadratic", "unknown"]
|
|
20
19
|
Numeric_t = Union[int, float] # type of `time_key` arguments
|
|
@@ -1,11 +1,19 @@
|
|
|
1
1
|
from ott.geometry import costs
|
|
2
2
|
|
|
3
3
|
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
4
|
-
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
|
|
4
|
+
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
|
|
5
5
|
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
|
|
6
6
|
from moscot.costs import register_cost
|
|
7
7
|
|
|
8
|
-
__all__ = [
|
|
8
|
+
__all__ = [
|
|
9
|
+
"OTTOutput",
|
|
10
|
+
"GWSolver",
|
|
11
|
+
"SinkhornSolver",
|
|
12
|
+
"NeuralOutput",
|
|
13
|
+
"sinkhorn_divergence",
|
|
14
|
+
"GENOTLinSolver",
|
|
15
|
+
"GraphOTTOutput",
|
|
16
|
+
]
|
|
9
17
|
|
|
10
18
|
|
|
11
19
|
register_cost("euclidean", backend="ott")(costs.Euclidean)
|
|
@@ -184,7 +184,7 @@ def alpha_to_fused_penalty(alpha: float) -> float:
|
|
|
184
184
|
return (1 - alpha) / alpha
|
|
185
185
|
|
|
186
186
|
|
|
187
|
-
def densify(arr: ArrayLike) -> jax.Array:
|
|
187
|
+
def densify(arr: Union[ArrayLike, sp.sparray, sp.spmatrix]) -> jax.Array:
|
|
188
188
|
"""If the input is sparse, convert it to dense.
|
|
189
189
|
|
|
190
190
|
Parameters
|
|
@@ -197,7 +197,8 @@ def densify(arr: ArrayLike) -> jax.Array:
|
|
|
197
197
|
dense :mod:`jax` array.
|
|
198
198
|
"""
|
|
199
199
|
if sp.issparse(arr):
|
|
200
|
-
|
|
200
|
+
arr_sp: Union[sp.sparray, sp.spmatrix] = arr
|
|
201
|
+
arr = arr_sp.toarray()
|
|
201
202
|
elif isinstance(arr, jesp.BCOO):
|
|
202
203
|
arr = arr.todense()
|
|
203
204
|
return jnp.asarray(arr)
|
|
@@ -17,7 +17,7 @@ from moscot._types import ArrayLike, Device_t
|
|
|
17
17
|
from moscot.backends.ott._utils import get_nearest_neighbors
|
|
18
18
|
from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput
|
|
19
19
|
|
|
20
|
-
__all__ = ["OTTOutput", "GraphOTTOutput", "OTTNeuralOutput"]
|
|
20
|
+
__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"]
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class OTTOutput(BaseDiscreteSolverOutput):
|
|
@@ -182,6 +182,9 @@ class OTTOutput(BaseDiscreteSolverOutput):
|
|
|
182
182
|
axis=1 - forward,
|
|
183
183
|
).T # convert to batch first
|
|
184
184
|
|
|
185
|
+
def _apply_forward(self, x: ArrayLike) -> ArrayLike:
|
|
186
|
+
return self._apply(x, forward=True)
|
|
187
|
+
|
|
185
188
|
@property
|
|
186
189
|
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
187
190
|
if isinstance(self._output, sinkhorn.SinkhornOutput):
|
|
@@ -241,11 +244,11 @@ class OTTOutput(BaseDiscreteSolverOutput):
|
|
|
241
244
|
return jnp.ones((n,))
|
|
242
245
|
|
|
243
246
|
|
|
244
|
-
class OTTNeuralOutput(BaseNeuralOutput):
|
|
247
|
+
class NeuralOutput(BaseNeuralOutput):
|
|
245
248
|
"""Output wrapper for GENOT."""
|
|
246
249
|
|
|
247
250
|
def __init__(self, model: GENOT, logs: dict[str, list[float]]):
|
|
248
|
-
"""Initialize `
|
|
251
|
+
"""Initialize `NeuralOutput`.
|
|
249
252
|
|
|
250
253
|
Parameters
|
|
251
254
|
----------
|
|
@@ -269,8 +272,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
269
272
|
self,
|
|
270
273
|
src_dist: ArrayLike,
|
|
271
274
|
tgt_dist: ArrayLike,
|
|
272
|
-
|
|
273
|
-
func: Callable[[jnp.ndarray], jnp.ndarray],
|
|
275
|
+
func: Callable[[ArrayLike], ArrayLike],
|
|
274
276
|
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
|
|
275
277
|
batch_size: int = 1024,
|
|
276
278
|
k: int = 30,
|
|
@@ -279,9 +281,9 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
279
281
|
recall_target: float = 0.95,
|
|
280
282
|
aggregate_to_topk: bool = True,
|
|
281
283
|
) -> sp.csr_matrix:
|
|
282
|
-
row_indices:
|
|
283
|
-
column_indices:
|
|
284
|
-
distances_list:
|
|
284
|
+
row_indices: List[ArrayLike] = []
|
|
285
|
+
column_indices: List[ArrayLike] = []
|
|
286
|
+
distances_list: List[ArrayLike] = []
|
|
285
287
|
if length_scale is None:
|
|
286
288
|
key = jax.random.PRNGKey(seed)
|
|
287
289
|
src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
|
|
@@ -306,20 +308,14 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
306
308
|
row_indices = jnp.concatenate(row_indices)
|
|
307
309
|
column_indices = jnp.concatenate(column_indices)
|
|
308
310
|
tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
|
|
309
|
-
if
|
|
310
|
-
|
|
311
|
-
self._transport_matrix = tm
|
|
312
|
-
else:
|
|
313
|
-
tm = tm.T
|
|
314
|
-
if save_transport_matrix:
|
|
315
|
-
self._inverse_transport_matrix = tm
|
|
311
|
+
if save_transport_matrix:
|
|
312
|
+
self._transport_matrix = tm
|
|
316
313
|
return tm
|
|
317
314
|
|
|
318
315
|
def project_to_transport_matrix( # type:ignore[override]
|
|
319
316
|
self,
|
|
320
317
|
src_cells: ArrayLike,
|
|
321
318
|
tgt_cells: ArrayLike,
|
|
322
|
-
forward: bool = True,
|
|
323
319
|
condition: ArrayLike = None,
|
|
324
320
|
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
|
|
325
321
|
batch_size: int = 1024,
|
|
@@ -351,7 +347,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
351
347
|
save_transport_matrix
|
|
352
348
|
Whether to save the transport matrix.
|
|
353
349
|
batch_size
|
|
354
|
-
Number of data points in the source distribution the
|
|
350
|
+
Number of data points in the source distribution the neighborhood graph is computed
|
|
355
351
|
for in parallel.
|
|
356
352
|
k
|
|
357
353
|
Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
|
|
@@ -375,13 +371,12 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
375
371
|
The projected transport matrix.
|
|
376
372
|
"""
|
|
377
373
|
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
|
|
374
|
+
conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
|
|
375
|
+
push = self.push if condition is None else conditioned_fn
|
|
376
|
+
func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
|
|
381
377
|
return self._project_transport_matrix(
|
|
382
378
|
src_dist=src_dist,
|
|
383
379
|
tgt_dist=tgt_dist,
|
|
384
|
-
forward=forward,
|
|
385
380
|
func=func,
|
|
386
381
|
save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
|
|
387
382
|
batch_size=batch_size,
|
|
@@ -406,31 +401,13 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
406
401
|
-------
|
|
407
402
|
Pushed distribution.
|
|
408
403
|
"""
|
|
404
|
+
if isinstance(x, (bool, int, float, complex)):
|
|
405
|
+
raise ValueError("Expected array, found scalar value.")
|
|
409
406
|
if x.ndim not in (1, 2):
|
|
410
407
|
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
|
|
411
|
-
return self.
|
|
412
|
-
|
|
413
|
-
def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
|
|
414
|
-
"""Pull distribution `x` conditioned on condition `cond`.
|
|
415
|
-
|
|
416
|
-
This does not make sense for some neural models and is therefore left unimplemented.
|
|
417
|
-
|
|
418
|
-
Parameters
|
|
419
|
-
----------
|
|
420
|
-
x
|
|
421
|
-
Distribution to push.
|
|
422
|
-
cond
|
|
423
|
-
Condition of conditional neural OT.
|
|
424
|
-
|
|
425
|
-
Raises
|
|
426
|
-
------
|
|
427
|
-
NotImplementedError
|
|
428
|
-
"""
|
|
429
|
-
raise NotImplementedError("`pull` does not make sense for neural OT.")
|
|
408
|
+
return self._apply_forward(x, cond=cond)
|
|
430
409
|
|
|
431
|
-
def
|
|
432
|
-
if not forward:
|
|
433
|
-
raise NotImplementedError("Backward i.e., pull on neural OT is not supported.")
|
|
410
|
+
def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
|
|
434
411
|
return self._model.transport(x, condition=cond)
|
|
435
412
|
|
|
436
413
|
@property
|
|
@@ -445,7 +422,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
445
422
|
def to(
|
|
446
423
|
self,
|
|
447
424
|
device: Optional[Device_t] = None,
|
|
448
|
-
) -> "OTTNeuralOutput":
|
|
425
|
+
) -> "NeuralOutput":
|
|
449
426
|
"""Transfer the output to another device or change its data type.
|
|
450
427
|
|
|
451
428
|
Parameters
|
|
@@ -471,7 +448,7 @@ class OTTNeuralOutput(BaseNeuralOutput):
|
|
|
471
448
|
# raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err
|
|
472
449
|
|
|
473
450
|
# out = jax.device_put(self._model, device)
|
|
474
|
-
# return OTTNeuralOutput(out)
|
|
451
|
+
# return NeuralOutput(out)
|
|
475
452
|
return self # TODO(ilan-gold) move model to device
|
|
476
453
|
|
|
477
454
|
@property
|
|
@@ -53,7 +53,7 @@ from moscot.backends.ott._utils import (
|
|
|
53
53
|
densify,
|
|
54
54
|
ensure_2d,
|
|
55
55
|
)
|
|
56
|
-
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
|
|
56
|
+
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
|
|
57
57
|
from moscot.base.problems._utils import TimeScalesHeatKernel
|
|
58
58
|
from moscot.base.solver import OTSolver
|
|
59
59
|
from moscot.costs import get_cost
|
|
@@ -216,7 +216,7 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
216
216
|
problem_shape = x.shape if problem_shape is None else problem_shape
|
|
217
217
|
return _instantiate_geodesic_cost(
|
|
218
218
|
arr=arr,
|
|
219
|
-
problem_shape=problem_shape,
|
|
219
|
+
problem_shape=problem_shape, # type: ignore[arg-type]
|
|
220
220
|
t=t,
|
|
221
221
|
is_linear_term=is_linear_term,
|
|
222
222
|
epsilon=epsilon,
|
|
@@ -699,10 +699,10 @@ class GENOTLinSolver(OTSolver[OTTOutput]):
|
|
|
699
699
|
def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
|
|
700
700
|
return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value]
|
|
701
701
|
|
|
702
|
-
def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) ->
|
|
702
|
+
def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override]
|
|
703
703
|
seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests
|
|
704
704
|
rng = jax.random.PRNGKey(seed)
|
|
705
705
|
logs = self.solver(
|
|
706
706
|
data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng
|
|
707
707
|
) # TODO(ilan-gold): validation and figure out defualts
|
|
708
|
-
return OTTNeuralOutput(self.solver, logs)
|
|
708
|
+
return NeuralOutput(self.solver, logs)
|
|
@@ -42,8 +42,7 @@ def register_solver(
|
|
|
42
42
|
return _REGISTRY.register(backend) # type: ignore[return-value]
|
|
43
43
|
|
|
44
44
|
|
|
45
|
-
|
|
46
|
-
@register_solver("ott") # type: ignore[arg-type]
|
|
45
|
+
@register_solver("ott")
|
|
47
46
|
def _(
|
|
48
47
|
problem_kind: Literal["linear", "quadratic"],
|
|
49
48
|
solver_name: Optional[Literal["GENOTLinSolver"]] = None,
|
|
@@ -58,7 +58,7 @@ class BaseCost(abc.ABC):
|
|
|
58
58
|
f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, "
|
|
59
59
|
f"setting them to the maximum value `{maxx}`."
|
|
60
60
|
)
|
|
61
|
-
cost = np.nan_to_num(cost, nan=maxx) # type: ignore[
|
|
61
|
+
cost = np.nan_to_num(cost, nan=maxx) # type: ignore[arg-type, type-var]
|
|
62
62
|
if np.any(cost < 0):
|
|
63
63
|
raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.")
|
|
64
64
|
return cost
|