moscot 0.2.0__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {moscot-0.2.0 → moscot-0.3.1}/.pre-commit-config.yaml +9 -9
- {moscot-0.2.0 → moscot-0.3.1}/PKG-INFO +34 -15
- moscot-0.3.1/README.rst +66 -0
- {moscot-0.2.0 → moscot-0.3.1}/pyproject.toml +22 -14
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_constants.py +0 -1
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_types.py +21 -6
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/backends/ott/__init__.py +2 -1
- moscot-0.3.1/src/moscot/backends/ott/_utils.py +88 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/backends/ott/output.py +50 -32
- moscot-0.3.1/src/moscot/backends/ott/solver.py +333 -0
- moscot-0.3.1/src/moscot/backends/utils.py +53 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/cost.py +20 -22
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/output.py +70 -72
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/_mixins.py +60 -102
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/_utils.py +69 -33
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/birth_death.py +103 -87
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/compound_problem.py +158 -220
- moscot-0.3.1/src/moscot/base/problems/manager.py +189 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/problem.py +327 -184
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/solver.py +29 -5
- moscot-0.3.1/src/moscot/costs/_costs.py +141 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/costs/_utils.py +1 -1
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/datasets.py +203 -25
- moscot-0.3.1/src/moscot/plotting/_plotting.py +427 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/plotting/_utils.py +20 -12
- moscot-0.3.1/src/moscot/problems/__init__.py +13 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/_utils.py +32 -26
- moscot-0.3.1/src/moscot/problems/cross_modality/__init__.py +4 -0
- moscot-0.3.1/src/moscot/problems/cross_modality/_mixins.py +195 -0
- moscot-0.3.1/src/moscot/problems/cross_modality/_translation.py +297 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/generic/__init__.py +1 -1
- moscot-0.3.1/src/moscot/problems/generic/_generic.py +481 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/generic/_mixins.py +94 -54
- moscot-0.3.1/src/moscot/problems/space/_alignment.py +251 -0
- moscot-0.3.1/src/moscot/problems/space/_mapping.py +322 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/space/_mixins.py +237 -194
- moscot-0.3.1/src/moscot/problems/spatiotemporal/_spatio_temporal.py +268 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/time/__init__.py +1 -1
- moscot-0.3.1/src/moscot/problems/time/_lineage.py +492 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/time/_mixins.py +355 -257
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/data.py +37 -4
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/subset_policy.py +211 -72
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/tagged_array.py +28 -16
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/PKG-INFO +34 -15
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/SOURCES.txt +3 -5
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/requires.txt +5 -5
- moscot-0.2.0/README.rst +0 -47
- moscot-0.2.0/src/moscot/_docs/__init__.py +0 -0
- moscot-0.2.0/src/moscot/_docs/_docs.py +0 -458
- moscot-0.2.0/src/moscot/_docs/_docs_mixins.py +0 -224
- moscot-0.2.0/src/moscot/_docs/_docs_plot.py +0 -197
- moscot-0.2.0/src/moscot/_docs/_utils.py +0 -19
- moscot-0.2.0/src/moscot/backends/ott/_utils.py +0 -46
- moscot-0.2.0/src/moscot/backends/ott/solver.py +0 -309
- moscot-0.2.0/src/moscot/backends/utils.py +0 -40
- moscot-0.2.0/src/moscot/base/problems/manager.py +0 -130
- moscot-0.2.0/src/moscot/costs/_costs.py +0 -102
- moscot-0.2.0/src/moscot/plotting/_plotting.py +0 -339
- moscot-0.2.0/src/moscot/problems/__init__.py +0 -5
- moscot-0.2.0/src/moscot/problems/generic/_generic.py +0 -339
- moscot-0.2.0/src/moscot/problems/space/_alignment.py +0 -168
- moscot-0.2.0/src/moscot/problems/space/_mapping.py +0 -250
- moscot-0.2.0/src/moscot/problems/spatiotemporal/_spatio_temporal.py +0 -203
- moscot-0.2.0/src/moscot/problems/time/_lineage.py +0 -356
- {moscot-0.2.0 → moscot-0.3.1}/.gitignore +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/.gitmodules +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/.readthedocs.yml +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/LICENSE +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/MANIFEST.in +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/codecov.yml +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/setup.cfg +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_logging.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_registry.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/space/__init__.py +1 -1
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/py.typed +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -7,22 +7,22 @@ default_stages:
|
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.4.1
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
|
-
additional_dependencies: [numpy>=1.
|
|
13
|
+
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev: 23.
|
|
16
|
+
rev: 23.7.0
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
20
20
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
|
21
|
-
rev: v3.0.0
|
|
21
|
+
rev: v3.0.0
|
|
22
22
|
hooks:
|
|
23
23
|
- id: prettier
|
|
24
24
|
language_version: system
|
|
25
|
-
- repo: https://github.com/
|
|
25
|
+
- repo: https://github.com/PyCQA/isort
|
|
26
26
|
rev: 5.12.0
|
|
27
27
|
hooks:
|
|
28
28
|
- id: isort
|
|
@@ -42,12 +42,12 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.
|
|
45
|
+
rev: v3.9.0
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
49
49
|
- repo: https://github.com/asottile/blacken-docs
|
|
50
|
-
rev: 1.
|
|
50
|
+
rev: 1.15.0
|
|
51
51
|
hooks:
|
|
52
52
|
- id: blacken-docs
|
|
53
53
|
additional_dependencies: [black==23.1.0]
|
|
@@ -61,9 +61,9 @@ repos:
|
|
|
61
61
|
rev: v1.1.1
|
|
62
62
|
hooks:
|
|
63
63
|
- id: doc8
|
|
64
|
-
- repo: https://github.com/
|
|
64
|
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.0.
|
|
66
|
+
rev: v0.0.280
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -62,16 +62,18 @@ Provides-Extra: test
|
|
|
62
62
|
Provides-Extra: docs
|
|
63
63
|
License-File: LICENSE
|
|
64
64
|
|
|
65
|
-
|Codecov|
|
|
65
|
+
|PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
|
|
66
66
|
|
|
67
67
|
moscot - multi-omic single-cell optimal transport tools
|
|
68
68
|
=======================================================
|
|
69
69
|
|
|
70
70
|
**moscot** is a scalable framework for Optimal Transport (OT) applications in
|
|
71
71
|
single-cell genomics. It can be used for
|
|
72
|
-
|
|
73
|
-
- spatial
|
|
74
|
-
- spatial
|
|
72
|
+
|
|
73
|
+
- trajectory inference (incorporating spatial and lineage information)
|
|
74
|
+
- mapping cells to their spatial organisation
|
|
75
|
+
- aligning spatial transcriptomics slides
|
|
76
|
+
- translating modalities
|
|
75
77
|
- prototyping of new OT models in single-cell genomics
|
|
76
78
|
|
|
77
79
|
**moscot** is powered by
|
|
@@ -85,27 +87,44 @@ You can install **moscot** via::
|
|
|
85
87
|
|
|
86
88
|
pip install moscot
|
|
87
89
|
|
|
88
|
-
In order to install **moscot** from
|
|
90
|
+
In order to install **moscot** in editable mode, run::
|
|
89
91
|
|
|
90
92
|
git clone https://github.com/theislab/moscot
|
|
91
93
|
cd moscot
|
|
92
|
-
pip install -e .
|
|
94
|
+
pip install -e .
|
|
95
|
+
|
|
96
|
+
For further instructions on how to install jax, please refer to https://github.com/google/jax.
|
|
93
97
|
|
|
94
|
-
|
|
98
|
+
Resources
|
|
99
|
+
---------
|
|
95
100
|
|
|
96
|
-
|
|
101
|
+
Please have a look at our `documentation <https://moscot.readthedocs.io>`_
|
|
102
|
+
|
|
103
|
+
Reference
|
|
104
|
+
---------
|
|
97
105
|
|
|
106
|
+
Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
|
|
98
107
|
|
|
99
108
|
.. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
|
|
100
109
|
:target: https://codecov.io/gh/theislab/moscot
|
|
101
110
|
:alt: Coverage
|
|
102
111
|
|
|
103
|
-
|
|
104
|
-
|
|
112
|
+
.. |PyPI| image:: https://img.shields.io/pypi/v/moscot.svg
|
|
113
|
+
:target: https://pypi.org/project/moscot/
|
|
114
|
+
:alt: PyPI
|
|
105
115
|
|
|
106
|
-
|
|
116
|
+
.. |CI| image:: https://img.shields.io/github/actions/workflow/status/theislab/moscot/test.yml?branch=main
|
|
117
|
+
:target: https://github.com/theislab/moscot/actions
|
|
118
|
+
:alt: CI
|
|
107
119
|
|
|
108
|
-
|
|
109
|
-
|
|
120
|
+
.. |Pre-commit| image:: https://results.pre-commit.ci/badge/github/theislab/moscot/main.svg
|
|
121
|
+
:target: https://results.pre-commit.ci/latest/github/theislab/moscot/main
|
|
122
|
+
:alt: pre-commit.ci status
|
|
123
|
+
|
|
124
|
+
.. |Docs| image:: https://img.shields.io/readthedocs/moscot
|
|
125
|
+
:target: https://moscot.readthedocs.io/en/stable/
|
|
126
|
+
:alt: Documentation
|
|
110
127
|
|
|
111
|
-
|
|
128
|
+
.. |Downloads| image:: https://pepy.tech/badge/moscot
|
|
129
|
+
:target: https://pepy.tech/project/moscot
|
|
130
|
+
:alt: Downloads
|
moscot-0.3.1/README.rst
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
|PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
|
|
2
|
+
|
|
3
|
+
moscot - multi-omic single-cell optimal transport tools
|
|
4
|
+
=======================================================
|
|
5
|
+
|
|
6
|
+
**moscot** is a scalable framework for Optimal Transport (OT) applications in
|
|
7
|
+
single-cell genomics. It can be used for
|
|
8
|
+
|
|
9
|
+
- trajectory inference (incorporating spatial and lineage information)
|
|
10
|
+
- mapping cells to their spatial organisation
|
|
11
|
+
- aligning spatial transcriptomics slides
|
|
12
|
+
- translating modalities
|
|
13
|
+
- prototyping of new OT models in single-cell genomics
|
|
14
|
+
|
|
15
|
+
**moscot** is powered by
|
|
16
|
+
`OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
|
|
17
|
+
Transport toolkit that supports just-in-time compilation, GPU acceleration, automatic
|
|
18
|
+
differentiation and linear memory complexity for OT problems.
|
|
19
|
+
|
|
20
|
+
Installation
|
|
21
|
+
------------
|
|
22
|
+
You can install **moscot** via::
|
|
23
|
+
|
|
24
|
+
pip install moscot
|
|
25
|
+
|
|
26
|
+
In order to install **moscot** in editable mode, run::
|
|
27
|
+
|
|
28
|
+
git clone https://github.com/theislab/moscot
|
|
29
|
+
cd moscot
|
|
30
|
+
pip install -e .
|
|
31
|
+
|
|
32
|
+
For further instructions on how to install jax, please refer to https://github.com/google/jax.
|
|
33
|
+
|
|
34
|
+
Resources
|
|
35
|
+
---------
|
|
36
|
+
|
|
37
|
+
Please have a look at our `documentation <https://moscot.readthedocs.io>`_
|
|
38
|
+
|
|
39
|
+
Reference
|
|
40
|
+
---------
|
|
41
|
+
|
|
42
|
+
Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
|
|
43
|
+
|
|
44
|
+
.. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
|
|
45
|
+
:target: https://codecov.io/gh/theislab/moscot
|
|
46
|
+
:alt: Coverage
|
|
47
|
+
|
|
48
|
+
.. |PyPI| image:: https://img.shields.io/pypi/v/moscot.svg
|
|
49
|
+
:target: https://pypi.org/project/moscot/
|
|
50
|
+
:alt: PyPI
|
|
51
|
+
|
|
52
|
+
.. |CI| image:: https://img.shields.io/github/actions/workflow/status/theislab/moscot/test.yml?branch=main
|
|
53
|
+
:target: https://github.com/theislab/moscot/actions
|
|
54
|
+
:alt: CI
|
|
55
|
+
|
|
56
|
+
.. |Pre-commit| image:: https://results.pre-commit.ci/badge/github/theislab/moscot/main.svg
|
|
57
|
+
:target: https://results.pre-commit.ci/latest/github/theislab/moscot/main
|
|
58
|
+
:alt: pre-commit.ci status
|
|
59
|
+
|
|
60
|
+
.. |Docs| image:: https://img.shields.io/readthedocs/moscot
|
|
61
|
+
:target: https://moscot.readthedocs.io/en/stable/
|
|
62
|
+
:alt: Documentation
|
|
63
|
+
|
|
64
|
+
.. |Downloads| image:: https://pepy.tech/badge/moscot
|
|
65
|
+
:target: https://pepy.tech/project/moscot
|
|
66
|
+
:alt: Downloads
|
|
@@ -47,15 +47,14 @@ maintainers = [
|
|
|
47
47
|
dependencies = [
|
|
48
48
|
"numpy>=1.20.0",
|
|
49
49
|
"scipy>=1.7.0",
|
|
50
|
-
"pandas>=
|
|
50
|
+
"pandas>=2.0.1",
|
|
51
51
|
"networkx>=2.6.3",
|
|
52
52
|
# https://github.com/scverse/scanpy/issues/2411
|
|
53
53
|
"matplotlib>=3.5.0",
|
|
54
|
-
"anndata>=0.
|
|
54
|
+
"anndata>=0.9.1",
|
|
55
55
|
"scanpy>=1.9.3",
|
|
56
56
|
"wrapt>=1.13.2",
|
|
57
|
-
"
|
|
58
|
-
"ott-jax>=0.4.0",
|
|
57
|
+
"ott-jax==0.4.0",
|
|
59
58
|
"cloudpickle>=2.2.0",
|
|
60
59
|
]
|
|
61
60
|
|
|
@@ -79,9 +78,10 @@ docs = [
|
|
|
79
78
|
"sphinx_copybutton>=0.5.0",
|
|
80
79
|
"sphinxcontrib-bibtex>=2.3.0",
|
|
81
80
|
"sphinxcontrib-spelling>=7.6.2",
|
|
81
|
+
"sphinx-autodoc-typehints",
|
|
82
82
|
"furo>=2022.09.29",
|
|
83
|
+
"sphinx-tippy>=0.4.1",
|
|
83
84
|
"myst-nb>=0.17.1",
|
|
84
|
-
"nbsphinx>=0.8.1",
|
|
85
85
|
"ipython>=7.20.0",
|
|
86
86
|
"sphinx_design>=0.3.0",
|
|
87
87
|
]
|
|
@@ -151,9 +151,6 @@ target-version = "py38"
|
|
|
151
151
|
"src/moscot/utils/subset_policy.py" = ["D101", "D102"]
|
|
152
152
|
[tool.ruff.pydocstyle]
|
|
153
153
|
convention = "numpy"
|
|
154
|
-
[tool.ruff.pyupgrade]
|
|
155
|
-
# Preserve types, even if a file imports `from __future__ import annotations`.
|
|
156
|
-
keep-runtime-typing = true
|
|
157
154
|
[tool.ruff.flake8-tidy-imports]
|
|
158
155
|
ban-relative-imports = "all"
|
|
159
156
|
[tool.ruff.flake8-quotes]
|
|
@@ -168,9 +165,10 @@ include = '\.pyi?$'
|
|
|
168
165
|
profile = "black"
|
|
169
166
|
include_trailing_comma = true
|
|
170
167
|
multi_line_output = 3
|
|
171
|
-
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "BIO", "FIRSTPARTY", "LOCALFOLDER"]
|
|
168
|
+
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "GENERIC", "NUMERIC", "PLOTTING", "BIO", "FIRSTPARTY", "LOCALFOLDER"]
|
|
172
169
|
# also contains what we import in notebooks
|
|
173
|
-
|
|
170
|
+
known_generic = ["wrapt", "joblib"]
|
|
171
|
+
known_numeric = ["numpy", "scipy", "jax", "ott", "pandas", "sklearn", "networkx", "statsmodels"]
|
|
174
172
|
known_bio = ["anndata", "scanpy", "squidpy"]
|
|
175
173
|
known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]
|
|
176
174
|
|
|
@@ -178,9 +176,8 @@ known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]
|
|
|
178
176
|
markers = ["fast: marks tests as fask"]
|
|
179
177
|
xfail_strict = true
|
|
180
178
|
filterwarnings = [
|
|
181
|
-
"ignore:X.dtype being converted:FutureWarning",
|
|
182
179
|
"ignore:No data for colormapping:UserWarning",
|
|
183
|
-
"ignore:
|
|
180
|
+
"ignore:The dtype argument will be deprecated in anndata:PendingDeprecationWarning"
|
|
184
181
|
]
|
|
185
182
|
|
|
186
183
|
[tool.coverage.run]
|
|
@@ -214,6 +211,7 @@ ignore_directives = [
|
|
|
214
211
|
"automodule",
|
|
215
212
|
"autoclass",
|
|
216
213
|
"bibliography",
|
|
214
|
+
"glossary",
|
|
217
215
|
"card",
|
|
218
216
|
"grid",
|
|
219
217
|
]
|
|
@@ -251,7 +249,7 @@ show_column_numbers = true
|
|
|
251
249
|
error_summary = true
|
|
252
250
|
ignore_missing_imports = true
|
|
253
251
|
|
|
254
|
-
disable_error_code = ["assignment", "comparison-overlap", "no-untyped-def"]
|
|
252
|
+
disable_error_code = ["assignment", "comparison-overlap", "no-untyped-def", "override"]
|
|
255
253
|
|
|
256
254
|
[tool.doc8]
|
|
257
255
|
max_line_length = 120
|
|
@@ -294,6 +292,7 @@ commands =
|
|
|
294
292
|
|
|
295
293
|
[testenv:clean-docs]
|
|
296
294
|
description = Remove the documentation.
|
|
295
|
+
deps =
|
|
297
296
|
skip_install = true
|
|
298
297
|
changedir = {tox_root}{/}docs
|
|
299
298
|
allowlist_externals = make
|
|
@@ -302,7 +301,6 @@ commands =
|
|
|
302
301
|
|
|
303
302
|
[testenv:build-docs]
|
|
304
303
|
description = Build the documentation.
|
|
305
|
-
use_develop = true
|
|
306
304
|
deps =
|
|
307
305
|
extras = docs
|
|
308
306
|
allowlist_externals = make
|
|
@@ -324,4 +322,14 @@ commands =
|
|
|
324
322
|
python -m twine check {tox_root}{/}dist{/}*
|
|
325
323
|
commands_post =
|
|
326
324
|
python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")'
|
|
325
|
+
|
|
326
|
+
[testenv:format-references]
|
|
327
|
+
description = Format references.bib.
|
|
328
|
+
deps =
|
|
329
|
+
skip_install = true
|
|
330
|
+
allowlist_externals = biber
|
|
331
|
+
commands = biber --tool --output_file={tox_root}{/}docs{/}references.bib --nolog \
|
|
332
|
+
--output_align --output_indent=2 --output_fieldcase=lower \
|
|
333
|
+
--output_legacy_dates --output-field-replace=journaltitle:journal,thesis:phdthesis,institution:school \
|
|
334
|
+
{tox_root}{/}docs{/}references.bib
|
|
327
335
|
"""
|
|
@@ -16,7 +16,7 @@ except (ImportError, TypeError):
|
|
|
16
16
|
ProblemKind_t = Literal["linear", "quadratic", "unknown"]
|
|
17
17
|
Numeric_t = Union[int, float] # type of `time_key` arguments
|
|
18
18
|
Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
|
|
19
|
-
Str_Dict_t = Union[str, Mapping[str, Sequence[Any]]] # type for `cell_transition`
|
|
19
|
+
Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
|
|
20
20
|
SinkFullRankInit = Literal["default", "gaussian", "sorting"]
|
|
21
21
|
LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
|
|
22
22
|
|
|
@@ -24,20 +24,35 @@ SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]]
|
|
|
24
24
|
QuadInitializer_t = Optional[LRInitializer_t]
|
|
25
25
|
|
|
26
26
|
Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t]
|
|
27
|
-
ProblemStage_t = Literal["
|
|
28
|
-
Device_t = Literal["cpu", "gpu", "tpu"]
|
|
27
|
+
ProblemStage_t = Literal["prepared", "solved"]
|
|
28
|
+
Device_t = Union[Literal["cpu", "gpu", "tpu"], str]
|
|
29
29
|
|
|
30
30
|
# TODO(michalk8): autogenerate from the enums
|
|
31
|
-
ScaleCost_t =
|
|
32
|
-
OttCostFn_t = Literal[
|
|
31
|
+
ScaleCost_t = Union[float, Literal["mean", "max_cost", "max_bound", "max_norm", "median"]]
|
|
32
|
+
OttCostFn_t = Literal[
|
|
33
|
+
"euclidean",
|
|
34
|
+
"sq_euclidean",
|
|
35
|
+
"cosine",
|
|
36
|
+
"PNormP",
|
|
37
|
+
"SqPNorm",
|
|
38
|
+
"Euclidean",
|
|
39
|
+
"SqEuclidean",
|
|
40
|
+
"Cosine",
|
|
41
|
+
"ElasticL1",
|
|
42
|
+
"ElasticSTVS",
|
|
43
|
+
"ElasticSqKOverlap",
|
|
44
|
+
]
|
|
45
|
+
OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
|
|
33
46
|
GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
|
|
34
47
|
CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
|
|
48
|
+
CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
|
|
35
49
|
PathLike = Union[os.PathLike, str]
|
|
36
50
|
Policy_t = Literal[
|
|
37
51
|
"sequential",
|
|
38
52
|
"star",
|
|
39
53
|
"external_star",
|
|
54
|
+
"explicit",
|
|
40
55
|
"triu",
|
|
41
56
|
"tril",
|
|
42
|
-
"explicit",
|
|
43
57
|
]
|
|
58
|
+
CostKwargs_t = Union[Mapping[str, Any], Mapping[Literal["x", "y", "xy"], Mapping[str, Any]]]
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from ott.geometry import costs
|
|
2
2
|
|
|
3
|
+
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
3
4
|
from moscot.backends.ott.output import OTTOutput
|
|
4
5
|
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
|
|
5
6
|
from moscot.costs import register_cost
|
|
6
7
|
|
|
7
|
-
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver"]
|
|
8
|
+
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
8
9
|
|
|
9
10
|
register_cost("euclidean", backend="ott")(costs.Euclidean)
|
|
10
11
|
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import scipy.sparse as sp
|
|
6
|
+
from ott.geometry import geometry, pointcloud
|
|
7
|
+
from ott.tools import sinkhorn_divergence as sdiv
|
|
8
|
+
|
|
9
|
+
from moscot._logging import logger
|
|
10
|
+
from moscot._types import ArrayLike, ScaleCost_t
|
|
11
|
+
|
|
12
|
+
__all__ = ["sinkhorn_divergence"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def sinkhorn_divergence(
    point_cloud_1: ArrayLike,
    point_cloud_2: ArrayLike,
    a: Optional[ArrayLike] = None,
    b: Optional[ArrayLike] = None,
    epsilon: Optional[float] = 1e-1,
    scale_cost: ScaleCost_t = 1.0,
    **kwargs: Any,
) -> float:
    """Compute the Sinkhorn divergence between two point clouds.

    Thin wrapper around the ``sinkhorn_divergence`` function of :mod:`ott.tools.sinkhorn_divergence`
    that uses the :mod:`ott` ``PointCloud`` geometry and logs a warning for every term
    that is reported as not converged.

    Parameters
    ----------
    point_cloud_1
        First point cloud; converted to a :mod:`jax` array and forwarded as ``x``.
    point_cloud_2
        Second point cloud; converted to a :mod:`jax` array and forwarded as ``y``.
    a
        Forwarded as ``a`` after conversion to a :mod:`jax` array; :obj:`None` is passed through unchanged.
        Presumably the weights of the first point cloud - confirm against the :mod:`ott` documentation.
    b
        Forwarded as ``b`` after conversion to a :mod:`jax` array; :obj:`None` is passed through unchanged.
        Presumably the weights of the second point cloud - confirm against the :mod:`ott` documentation.
    epsilon
        Forwarded unchanged as ``epsilon``.
    scale_cost
        Forwarded unchanged as ``scale_cost``.
    kwargs
        Additional keyword arguments forwarded to the :mod:`ott` function.

    Returns
    -------
    The ``divergence`` attribute of the :mod:`ott` output, converted to :class:`float`.
    """
    x = jnp.asarray(point_cloud_1)
    y = jnp.asarray(point_cloud_2)
    if a is not None:
        a = jnp.asarray(a)
    if b is not None:
        b = jnp.asarray(b)

    result = sdiv.sinkhorn_divergence(
        pointcloud.PointCloud,
        x=x,
        y=y,
        a=a,
        b=b,
        epsilon=epsilon,
        scale_cost=scale_cost,
        **kwargs,
    )

    # The first two flags are required; star-unpacking tolerates an output without a third (`y/y`) flag.
    xy_ok, xx_ok, *yy_ok = result.converged
    if not xy_ok:
        logger.warning("Solver did not converge in the `x/y` term.")
    if not xx_ok:
        logger.warning("Solver did not converge in the `x/x` term.")
    if yy_ok and not yy_ok[0]:
        logger.warning("Solver did not converge in the `y/y` term.")

    return float(result.divergence)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def check_shapes(geom_x: geometry.Geometry, geom_y: geometry.Geometry, geom_xy: geometry.Geometry) -> None:
    """Check that ``geom_x`` and ``geom_y`` are compatible with the shape of ``geom_xy``.

    Parameters
    ----------
    geom_x
        Geometry whose first dimension must equal the first dimension of ``geom_xy``.
    geom_y
        Geometry whose first dimension must equal the second dimension of ``geom_xy``.
    geom_xy
        Geometry with a 2-dimensional ``shape`` that defines the expected sizes.

    Raises
    ------
    ValueError
        If either first dimension differs from the corresponding dimension of ``geom_xy``.
    """
    n, m = geom_xy.shape
    n_ = geom_x.shape[0]
    m_ = geom_y.shape[0]

    if n_ != n:
        raise ValueError(f"Expected the first geometry to have `{n}` points, found `{n_}`.")
    if m_ != m:
        raise ValueError(f"Expected the second geometry to have `{m}` points, found `{m_}`.")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def alpha_to_fused_penalty(alpha: float) -> float:
    """Convert ``alpha`` to the fused penalty ``(1 - alpha) / alpha``.

    Parameters
    ----------
    alpha
        Value in the interval ``(0, 1]``.

    Returns
    -------
    The value ``(1 - alpha) / alpha``; it is ``0`` when ``alpha = 1``.

    Raises
    ------
    ValueError
        If ``alpha`` is not in ``(0, 1]``.
    """
    # Keep the chained comparison: it is also false for NaN, so NaN ends up in the error branch.
    if 0 < alpha <= 1:
        return (1 - alpha) / alpha
    raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
    """Ensure that an array is 2-dimensional.

    Parameters
    ----------
    arr
        Array to check. Sparse :mod:`scipy` containers are densified first.
    reshape
        Allow reshaping 1-dimensional array to ``[n, 1]``.

    Returns
    -------
    2-dimensional :mod:`jax` array.

    Raises
    ------
    ValueError
        If the array is not 2-dimensional (and was not reshaped from 1 dimension).
    """
    if sp.issparse(arr):
        # Use `toarray()` rather than the `.A` shorthand: newer SciPy releases no longer provide
        # `.A` on sparse arrays, whereas `toarray()` works for sparse matrices and sparse arrays.
        arr = arr.toarray()  # type: ignore[attr-defined]
    arr = jnp.asarray(arr)
    if reshape and arr.ndim == 1:
        return jnp.reshape(arr, (-1, 1))
    if arr.ndim != 2:
        raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
    return arr
|
|
@@ -5,14 +5,12 @@ import jaxlib.xla_extension as xla_ext
|
|
|
5
5
|
import jax
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import numpy as np
|
|
8
|
-
from ott.solvers.linear
|
|
9
|
-
from ott.solvers.
|
|
10
|
-
from ott.solvers.quadratic.gromov_wasserstein import GWOutput as OTTGWOutput
|
|
8
|
+
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
|
9
|
+
from ott.solvers.quadratic import gromov_wasserstein
|
|
11
10
|
|
|
12
11
|
import matplotlib as mpl
|
|
13
12
|
import matplotlib.pyplot as plt
|
|
14
13
|
|
|
15
|
-
from moscot._docs._docs import d
|
|
16
14
|
from moscot._types import ArrayLike, Device_t
|
|
17
15
|
from moscot.base.output import BaseSolverOutput
|
|
18
16
|
|
|
@@ -20,58 +18,59 @@ __all__ = ["OTTOutput"]
|
|
|
20
18
|
|
|
21
19
|
|
|
22
20
|
class OTTOutput(BaseSolverOutput):
|
|
23
|
-
"""Output of various
|
|
21
|
+
"""Output of various :term:`OT` problems.
|
|
24
22
|
|
|
25
23
|
Parameters
|
|
26
24
|
----------
|
|
27
25
|
output
|
|
28
|
-
Output of :mod:`ott` backend.
|
|
26
|
+
Output of the :mod:`ott` backend.
|
|
29
27
|
"""
|
|
30
28
|
|
|
31
|
-
_NOT_COMPUTED = -1.0
|
|
29
|
+
_NOT_COMPUTED = -1.0 # sentinel value used in `ott`
|
|
32
30
|
|
|
33
|
-
def __init__(
|
|
31
|
+
def __init__(
|
|
32
|
+
self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput]
|
|
33
|
+
):
|
|
34
34
|
super().__init__()
|
|
35
35
|
self._output = output
|
|
36
|
-
self._costs = None if isinstance(output,
|
|
36
|
+
self._costs = None if isinstance(output, sinkhorn.SinkhornOutput) else output.costs
|
|
37
37
|
self._errors = output.errors
|
|
38
38
|
|
|
39
|
-
@d.get_sections(base="plot_costs", sections=["Parameters", "Returns"])
|
|
40
39
|
def plot_costs(
|
|
41
40
|
self,
|
|
42
41
|
last: Optional[int] = None,
|
|
43
42
|
title: Optional[str] = None,
|
|
44
43
|
return_fig: bool = False,
|
|
44
|
+
ax: Optional[mpl.axes.Axes] = None,
|
|
45
45
|
figsize: Optional[Tuple[float, float]] = None,
|
|
46
46
|
dpi: Optional[int] = None,
|
|
47
47
|
save: Optional[str] = None,
|
|
48
|
-
ax: Optional[mpl.axes.Axes] = None,
|
|
49
48
|
**kwargs: Any,
|
|
50
49
|
) -> Optional[mpl.figure.Figure]:
|
|
51
|
-
"""Plot regularized OT costs during the iterations.
|
|
50
|
+
"""Plot regularized :term:`OT` costs during the iterations.
|
|
52
51
|
|
|
53
52
|
Parameters
|
|
54
53
|
----------
|
|
55
54
|
last
|
|
56
|
-
How many of the last steps of the algorithm to plot. If
|
|
55
|
+
How many of the last steps of the algorithm to plot. If :obj:`None`, plot the full curve.
|
|
57
56
|
title
|
|
58
|
-
Title of the plot. If
|
|
57
|
+
Title of the plot. If :obj:`None`, it is determined automatically.
|
|
58
|
+
return_fig
|
|
59
|
+
Whether to return the figure.
|
|
60
|
+
ax
|
|
61
|
+
Axes on which to plot.
|
|
59
62
|
figsize
|
|
60
63
|
Size of the figure.
|
|
61
64
|
dpi
|
|
62
65
|
Dots per inch.
|
|
63
66
|
save
|
|
64
67
|
Path where to save the figure.
|
|
65
|
-
return_fig
|
|
66
|
-
Whether to return the figure.
|
|
67
|
-
ax
|
|
68
|
-
Axes on which to plot.
|
|
69
68
|
kwargs
|
|
70
|
-
Keyword arguments for :meth
|
|
69
|
+
Keyword arguments for :meth:`matplotlib.axes.Axes.plot`.
|
|
71
70
|
|
|
72
71
|
Returns
|
|
73
72
|
-------
|
|
74
|
-
|
|
73
|
+
If ``return_fig = True``, return the figure.
|
|
75
74
|
"""
|
|
76
75
|
if self._costs is None:
|
|
77
76
|
raise RuntimeError("No costs to plot.")
|
|
@@ -83,7 +82,6 @@ class OTTOutput(BaseSolverOutput):
|
|
|
83
82
|
fig.savefig(save)
|
|
84
83
|
return fig if return_fig else None
|
|
85
84
|
|
|
86
|
-
@d.dedent
|
|
87
85
|
def plot_errors(
|
|
88
86
|
self,
|
|
89
87
|
last: Optional[int] = None,
|
|
@@ -96,15 +94,34 @@ class OTTOutput(BaseSolverOutput):
|
|
|
96
94
|
ax: Optional[mpl.axes.Axes] = None,
|
|
97
95
|
**kwargs: Any,
|
|
98
96
|
) -> Optional[mpl.figure.Figure]:
|
|
99
|
-
"""Plot errors
|
|
97
|
+
"""Plot errors along iterations.
|
|
100
98
|
|
|
101
99
|
Parameters
|
|
102
100
|
----------
|
|
103
|
-
|
|
101
|
+
last
|
|
102
|
+
Number of errors corresponding at the ``last`` steps of the algorithm to plot. If :obj:`None`,
|
|
103
|
+
plot the full curve.
|
|
104
|
+
title
|
|
105
|
+
Title of the plot. If :obj:`None`, it is determined automatically.
|
|
106
|
+
outer_iteration
|
|
107
|
+
Which outermost iteration's errors to plot.
|
|
108
|
+
Only used when this is the solution to the :term:`quadratic problem`.
|
|
109
|
+
return_fig
|
|
110
|
+
Whether to return the figure.
|
|
111
|
+
ax
|
|
112
|
+
Axes on which to plot.
|
|
113
|
+
figsize
|
|
114
|
+
Size of the figure.
|
|
115
|
+
dpi
|
|
116
|
+
Dots per inch.
|
|
117
|
+
save
|
|
118
|
+
Path where to save the figure.
|
|
119
|
+
kwargs
|
|
120
|
+
Keyword arguments for :meth:`matplotlib.axes.Axes.plot`.
|
|
104
121
|
|
|
105
122
|
Returns
|
|
106
123
|
-------
|
|
107
|
-
|
|
124
|
+
If ``return_fig = True``, return the figure.
|
|
108
125
|
"""
|
|
109
126
|
if self._errors is None:
|
|
110
127
|
raise RuntimeError("No errors to plot.")
|
|
@@ -155,7 +172,7 @@ class OTTOutput(BaseSolverOutput):
|
|
|
155
172
|
|
|
156
173
|
@property
|
|
157
174
|
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
158
|
-
if isinstance(self._output,
|
|
175
|
+
if isinstance(self._output, sinkhorn.SinkhornOutput):
|
|
159
176
|
return self._output.f.shape[0], self._output.g.shape[0]
|
|
160
177
|
return self._output.geom.shape
|
|
161
178
|
|
|
@@ -165,9 +182,12 @@ class OTTOutput(BaseSolverOutput):
|
|
|
165
182
|
|
|
166
183
|
@property
|
|
167
184
|
def is_linear(self) -> bool: # noqa: D102
|
|
168
|
-
return isinstance(self._output, (
|
|
185
|
+
return isinstance(self._output, (sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput))
|
|
169
186
|
|
|
170
187
|
def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102
|
|
188
|
+
if device is None:
|
|
189
|
+
return OTTOutput(jax.device_put(self._output, device=device))
|
|
190
|
+
|
|
171
191
|
if isinstance(device, str) and ":" in device:
|
|
172
192
|
device, ix = device.split(":")
|
|
173
193
|
idx = int(ix)
|
|
@@ -184,9 +204,7 @@ class OTTOutput(BaseSolverOutput):
|
|
|
184
204
|
|
|
185
205
|
@property
|
|
186
206
|
def cost(self) -> float: # noqa: D102
|
|
187
|
-
|
|
188
|
-
return float(self._output.reg_ot_cost)
|
|
189
|
-
return float(self._output.reg_gw_cost)
|
|
207
|
+
return float(self._output.reg_ot_cost if self.is_linear else self._output.reg_gw_cost)
|
|
190
208
|
|
|
191
209
|
@property
|
|
192
210
|
def converged(self) -> bool: # noqa: D102
|
|
@@ -194,14 +212,14 @@ class OTTOutput(BaseSolverOutput):
|
|
|
194
212
|
|
|
195
213
|
@property
|
|
196
214
|
def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102
|
|
197
|
-
if isinstance(self._output,
|
|
215
|
+
if isinstance(self._output, sinkhorn.SinkhornOutput):
|
|
198
216
|
return self._output.f, self._output.g
|
|
199
217
|
return None
|
|
200
218
|
|
|
201
219
|
@property
|
|
202
220
|
def rank(self) -> int: # noqa: D102
|
|
203
221
|
lin_output = self._output if self.is_linear else self._output.linear_state
|
|
204
|
-
return len(lin_output.g) if isinstance(lin_output,
|
|
222
|
+
return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1
|
|
205
223
|
|
|
206
224
|
def _ones(self, n: int) -> ArrayLike: # noqa: D102
|
|
207
|
-
return jnp.ones((n,))
|
|
225
|
+
return jnp.ones((n,))
|