moscot 0.2.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. {moscot-0.2.0 → moscot-0.3.1}/.pre-commit-config.yaml +9 -9
  2. {moscot-0.2.0 → moscot-0.3.1}/PKG-INFO +34 -15
  3. moscot-0.3.1/README.rst +66 -0
  4. {moscot-0.2.0 → moscot-0.3.1}/pyproject.toml +22 -14
  5. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_constants.py +0 -1
  6. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_types.py +21 -6
  7. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/backends/ott/__init__.py +2 -1
  8. moscot-0.3.1/src/moscot/backends/ott/_utils.py +88 -0
  9. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/backends/ott/output.py +50 -32
  10. moscot-0.3.1/src/moscot/backends/ott/solver.py +333 -0
  11. moscot-0.3.1/src/moscot/backends/utils.py +53 -0
  12. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/cost.py +20 -22
  13. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/output.py +70 -72
  14. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/_mixins.py +60 -102
  15. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/_utils.py +69 -33
  16. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/birth_death.py +103 -87
  17. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/compound_problem.py +158 -220
  18. moscot-0.3.1/src/moscot/base/problems/manager.py +189 -0
  19. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/problem.py +327 -184
  20. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/solver.py +29 -5
  21. moscot-0.3.1/src/moscot/costs/_costs.py +141 -0
  22. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/costs/_utils.py +1 -1
  23. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/datasets.py +203 -25
  24. moscot-0.3.1/src/moscot/plotting/_plotting.py +427 -0
  25. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/plotting/_utils.py +20 -12
  26. moscot-0.3.1/src/moscot/problems/__init__.py +13 -0
  27. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/_utils.py +32 -26
  28. moscot-0.3.1/src/moscot/problems/cross_modality/__init__.py +4 -0
  29. moscot-0.3.1/src/moscot/problems/cross_modality/_mixins.py +195 -0
  30. moscot-0.3.1/src/moscot/problems/cross_modality/_translation.py +297 -0
  31. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/generic/__init__.py +1 -1
  32. moscot-0.3.1/src/moscot/problems/generic/_generic.py +481 -0
  33. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/generic/_mixins.py +94 -54
  34. moscot-0.3.1/src/moscot/problems/space/_alignment.py +251 -0
  35. moscot-0.3.1/src/moscot/problems/space/_mapping.py +322 -0
  36. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/space/_mixins.py +237 -194
  37. moscot-0.3.1/src/moscot/problems/spatiotemporal/_spatio_temporal.py +268 -0
  38. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/time/__init__.py +1 -1
  39. moscot-0.3.1/src/moscot/problems/time/_lineage.py +492 -0
  40. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/time/_mixins.py +355 -257
  41. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/data.py +37 -4
  42. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/subset_policy.py +211 -72
  43. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/tagged_array.py +28 -16
  44. {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/PKG-INFO +34 -15
  45. {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/SOURCES.txt +3 -5
  46. {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/requires.txt +5 -5
  47. moscot-0.2.0/README.rst +0 -47
  48. moscot-0.2.0/src/moscot/_docs/__init__.py +0 -0
  49. moscot-0.2.0/src/moscot/_docs/_docs.py +0 -458
  50. moscot-0.2.0/src/moscot/_docs/_docs_mixins.py +0 -224
  51. moscot-0.2.0/src/moscot/_docs/_docs_plot.py +0 -197
  52. moscot-0.2.0/src/moscot/_docs/_utils.py +0 -19
  53. moscot-0.2.0/src/moscot/backends/ott/_utils.py +0 -46
  54. moscot-0.2.0/src/moscot/backends/ott/solver.py +0 -309
  55. moscot-0.2.0/src/moscot/backends/utils.py +0 -40
  56. moscot-0.2.0/src/moscot/base/problems/manager.py +0 -130
  57. moscot-0.2.0/src/moscot/costs/_costs.py +0 -102
  58. moscot-0.2.0/src/moscot/plotting/_plotting.py +0 -339
  59. moscot-0.2.0/src/moscot/problems/__init__.py +0 -5
  60. moscot-0.2.0/src/moscot/problems/generic/_generic.py +0 -339
  61. moscot-0.2.0/src/moscot/problems/space/_alignment.py +0 -168
  62. moscot-0.2.0/src/moscot/problems/space/_mapping.py +0 -250
  63. moscot-0.2.0/src/moscot/problems/spatiotemporal/_spatio_temporal.py +0 -203
  64. moscot-0.2.0/src/moscot/problems/time/_lineage.py +0 -356
  65. {moscot-0.2.0 → moscot-0.3.1}/.gitignore +0 -0
  66. {moscot-0.2.0 → moscot-0.3.1}/.gitmodules +0 -0
  67. {moscot-0.2.0 → moscot-0.3.1}/.readthedocs.yml +0 -0
  68. {moscot-0.2.0 → moscot-0.3.1}/LICENSE +0 -0
  69. {moscot-0.2.0 → moscot-0.3.1}/MANIFEST.in +0 -0
  70. {moscot-0.2.0 → moscot-0.3.1}/codecov.yml +0 -0
  71. {moscot-0.2.0 → moscot-0.3.1}/setup.cfg +0 -0
  72. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/__init__.py +0 -0
  73. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_logging.py +0 -0
  74. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/_registry.py +0 -0
  75. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/backends/__init__.py +0 -0
  76. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/__init__.py +0 -0
  77. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/base/problems/__init__.py +0 -0
  78. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/costs/__init__.py +0 -0
  79. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/plotting/__init__.py +0 -0
  80. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/space/__init__.py +1 -1
  81. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  82. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/py.typed +0 -0
  83. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/__init__.py +0 -0
  84. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  85. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  86. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  87. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  88. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  89. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  90. {moscot-0.2.0 → moscot-0.3.1}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
  91. {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/dependency_links.txt +0 -0
  92. {moscot-0.2.0 → moscot-0.3.1}/src/moscot.egg-info/top_level.txt +0 -0
@@ -7,22 +7,22 @@ default_stages:
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.1.1
10
+ rev: v1.4.1
11
11
  hooks:
12
12
  - id: mypy
13
- additional_dependencies: [numpy>=1.21.0, jax]
13
+ additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 23.1.0
16
+ rev: 23.7.0
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
20
20
  - repo: https://github.com/pre-commit/mirrors-prettier
21
- rev: v3.0.0-alpha.6
21
+ rev: v3.0.0
22
22
  hooks:
23
23
  - id: prettier
24
24
  language_version: system
25
- - repo: https://github.com/timothycrosley/isort
25
+ - repo: https://github.com/PyCQA/isort
26
26
  rev: 5.12.0
27
27
  hooks:
28
28
  - id: isort
@@ -42,12 +42,12 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.3.1
45
+ rev: v3.9.0
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
49
49
  - repo: https://github.com/asottile/blacken-docs
50
- rev: 1.13.0
50
+ rev: 1.15.0
51
51
  hooks:
52
52
  - id: blacken-docs
53
53
  additional_dependencies: [black==23.1.0]
@@ -61,9 +61,9 @@ repos:
61
61
  rev: v1.1.1
62
62
  hooks:
63
63
  - id: doc8
64
- - repo: https://github.com/charliermarsh/ruff-pre-commit
64
+ - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.0.257
66
+ rev: v0.0.280
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: moscot
3
- Version: 0.2.0
3
+ Version: 0.3.1
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -62,16 +62,18 @@ Provides-Extra: test
62
62
  Provides-Extra: docs
63
63
  License-File: LICENSE
64
64
 
65
- |Codecov|
65
+ |PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
66
66
 
67
67
  moscot - multi-omic single-cell optimal transport tools
68
68
  =======================================================
69
69
 
70
70
  **moscot** is a scalable framework for Optimal Transport (OT) applications in
71
71
  single-cell genomics. It can be used for
72
- - temporal and spatio-temporal trajectory inference
73
- - spatial mapping
74
- - spatial alignment
72
+
73
+ - trajectory inference (incorporating spatial and lineage information)
74
+ - mapping cells to their spatial organisation
75
+ - aligning spatial transcriptomics slides
76
+ - translating modalities
75
77
  - prototyping of new OT models in single-cell genomics
76
78
 
77
79
  **moscot** is powered by
@@ -85,27 +87,44 @@ You can install **moscot** via::
85
87
 
86
88
  pip install moscot
87
89
 
88
- In order to install **moscot** from source, run::
90
+ In order to install **moscot** in editable mode, run::
89
91
 
90
92
  git clone https://github.com/theislab/moscot
91
93
  cd moscot
92
- pip install -e .'[dev]'
94
+ pip install -e .
95
+
96
+ For further instructions on how to install jax, please refer to https://github.com/google/jax.
93
97
 
94
- If used with GPU, additionally run::
98
+ Resources
99
+ ---------
95
100
 
96
- pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
101
+ Please have a look at our `documentation <https://moscot.readthedocs.io>`_
102
+
103
+ Reference
104
+ ---------
97
105
 
106
+ Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
98
107
 
99
108
  .. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
100
109
  :target: https://codecov.io/gh/theislab/moscot
101
110
  :alt: Coverage
102
111
 
103
- Resources
104
- ---------
112
+ .. |PyPI| image:: https://img.shields.io/pypi/v/moscot.svg
113
+ :target: https://pypi.org/project/moscot/
114
+ :alt: PyPI
105
115
 
106
- Please have a look at our `documentation <https://moscot.readthedocs.io>`_
116
+ .. |CI| image:: https://img.shields.io/github/actions/workflow/status/theislab/moscot/test.yml?branch=main
117
+ :target: https://github.com/theislab/moscot/actions
118
+ :alt: CI
107
119
 
108
- Reference
109
- ---------
120
+ .. |Pre-commit| image:: https://results.pre-commit.ci/badge/github/theislab/moscot/main.svg
121
+ :target: https://results.pre-commit.ci/latest/github/theislab/moscot/main
122
+ :alt: pre-commit.ci status
123
+
124
+ .. |Docs| image:: https://img.shields.io/readthedocs/moscot
125
+ :target: https://moscot.readthedocs.io/en/stable/
126
+ :alt: Documentation
110
127
 
111
- Our manuscript will be available soon.
128
+ .. |Downloads| image:: https://pepy.tech/badge/moscot
129
+ :target: https://pepy.tech/project/moscot
130
+ :alt: Downloads
@@ -0,0 +1,66 @@
1
+ |PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
2
+
3
+ moscot - multi-omic single-cell optimal transport tools
4
+ =======================================================
5
+
6
+ **moscot** is a scalable framework for Optimal Transport (OT) applications in
7
+ single-cell genomics. It can be used for
8
+
9
+ - trajectory inference (incorporating spatial and lineage information)
10
+ - mapping cells to their spatial organisation
11
+ - aligning spatial transcriptomics slides
12
+ - translating modalities
13
+ - prototyping of new OT models in single-cell genomics
14
+
15
+ **moscot** is powered by
16
+ `OTT <https://ott-jax.readthedocs.io>`_ which is a JAX-based Optimal
17
+ Transport toolkit that supports just-in-time compilation, GPU acceleration, automatic
18
+ differentiation and linear memory complexity for OT problems.
19
+
20
+ Installation
21
+ ------------
22
+ You can install **moscot** via::
23
+
24
+ pip install moscot
25
+
26
+ In order to install **moscot** in editable mode, run::
27
+
28
+ git clone https://github.com/theislab/moscot
29
+ cd moscot
30
+ pip install -e .
31
+
32
+ For further instructions on how to install jax, please refer to https://github.com/google/jax.
33
+
34
+ Resources
35
+ ---------
36
+
37
+ Please have a look at our `documentation <https://moscot.readthedocs.io>`_
38
+
39
+ Reference
40
+ ---------
41
+
42
+ Our preprint "Mapping cells through time and space with moscot" can be found `here <https://www.biorxiv.org/content/10.1101/2023.05.11.540374v1>`_.
43
+
44
+ .. |Codecov| image:: https://codecov.io/gh/theislab/moscot/branch/master/graph/badge.svg?token=Rgtm5Tsblo
45
+ :target: https://codecov.io/gh/theislab/moscot
46
+ :alt: Coverage
47
+
48
+ .. |PyPI| image:: https://img.shields.io/pypi/v/moscot.svg
49
+ :target: https://pypi.org/project/moscot/
50
+ :alt: PyPI
51
+
52
+ .. |CI| image:: https://img.shields.io/github/actions/workflow/status/theislab/moscot/test.yml?branch=main
53
+ :target: https://github.com/theislab/moscot/actions
54
+ :alt: CI
55
+
56
+ .. |Pre-commit| image:: https://results.pre-commit.ci/badge/github/theislab/moscot/main.svg
57
+ :target: https://results.pre-commit.ci/latest/github/theislab/moscot/main
58
+ :alt: pre-commit.ci status
59
+
60
+ .. |Docs| image:: https://img.shields.io/readthedocs/moscot
61
+ :target: https://moscot.readthedocs.io/en/stable/
62
+ :alt: Documentation
63
+
64
+ .. |Downloads| image:: https://pepy.tech/badge/moscot
65
+ :target: https://pepy.tech/project/moscot
66
+ :alt: Downloads
@@ -47,15 +47,14 @@ maintainers = [
47
47
  dependencies = [
48
48
  "numpy>=1.20.0",
49
49
  "scipy>=1.7.0",
50
- "pandas>=1.4.0",
50
+ "pandas>=2.0.1",
51
51
  "networkx>=2.6.3",
52
52
  # https://github.com/scverse/scanpy/issues/2411
53
53
  "matplotlib>=3.5.0",
54
- "anndata>=0.8.0",
54
+ "anndata>=0.9.1",
55
55
  "scanpy>=1.9.3",
56
56
  "wrapt>=1.13.2",
57
- "docrep>=0.3.2",
58
- "ott-jax>=0.4.0",
57
+ "ott-jax==0.4.0",
59
58
  "cloudpickle>=2.2.0",
60
59
  ]
61
60
 
@@ -79,9 +78,10 @@ docs = [
79
78
  "sphinx_copybutton>=0.5.0",
80
79
  "sphinxcontrib-bibtex>=2.3.0",
81
80
  "sphinxcontrib-spelling>=7.6.2",
81
+ "sphinx-autodoc-typehints",
82
82
  "furo>=2022.09.29",
83
+ "sphinx-tippy>=0.4.1",
83
84
  "myst-nb>=0.17.1",
84
- "nbsphinx>=0.8.1",
85
85
  "ipython>=7.20.0",
86
86
  "sphinx_design>=0.3.0",
87
87
  ]
@@ -151,9 +151,6 @@ target-version = "py38"
151
151
  "src/moscot/utils/subset_policy.py" = ["D101", "D102"]
152
152
  [tool.ruff.pydocstyle]
153
153
  convention = "numpy"
154
- [tool.ruff.pyupgrade]
155
- # Preserve types, even if a file imports `from __future__ import annotations`.
156
- keep-runtime-typing = true
157
154
  [tool.ruff.flake8-tidy-imports]
158
155
  ban-relative-imports = "all"
159
156
  [tool.ruff.flake8-quotes]
@@ -168,9 +165,10 @@ include = '\.pyi?$'
168
165
  profile = "black"
169
166
  include_trailing_comma = true
170
167
  multi_line_output = 3
171
- sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "PLOTTING", "BIO", "FIRSTPARTY", "LOCALFOLDER"]
168
+ sections = ["FUTURE", "STDLIB", "THIRDPARTY", "GENERIC", "NUMERIC", "PLOTTING", "BIO", "FIRSTPARTY", "LOCALFOLDER"]
172
169
  # also contains what we import in notebooks
173
- known_numeric = ["numpy", "scipy", "jax", "ott", "pandas", "sklearn", "networkx"]
170
+ known_generic = ["wrapt", "joblib"]
171
+ known_numeric = ["numpy", "scipy", "jax", "ott", "pandas", "sklearn", "networkx", "statsmodels"]
174
172
  known_bio = ["anndata", "scanpy", "squidpy"]
175
173
  known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]
176
174
 
@@ -178,9 +176,8 @@ known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]
178
176
  markers = ["fast: marks tests as fast"]
179
177
  xfail_strict = true
180
178
  filterwarnings = [
181
- "ignore:X.dtype being converted:FutureWarning",
182
179
  "ignore:No data for colormapping:UserWarning",
183
- "ignore:jax\\.experimental\\.pjit\\.PartitionSpec:DeprecationWarning",
180
+ "ignore:The dtype argument will be deprecated in anndata:PendingDeprecationWarning"
184
181
  ]
185
182
 
186
183
  [tool.coverage.run]
@@ -214,6 +211,7 @@ ignore_directives = [
214
211
  "automodule",
215
212
  "autoclass",
216
213
  "bibliography",
214
+ "glossary",
217
215
  "card",
218
216
  "grid",
219
217
  ]
@@ -251,7 +249,7 @@ show_column_numbers = true
251
249
  error_summary = true
252
250
  ignore_missing_imports = true
253
251
 
254
- disable_error_code = ["assignment", "comparison-overlap", "no-untyped-def"]
252
+ disable_error_code = ["assignment", "comparison-overlap", "no-untyped-def", "override"]
255
253
 
256
254
  [tool.doc8]
257
255
  max_line_length = 120
@@ -294,6 +292,7 @@ commands =
294
292
 
295
293
  [testenv:clean-docs]
296
294
  description = Remove the documentation.
295
+ deps =
297
296
  skip_install = true
298
297
  changedir = {tox_root}{/}docs
299
298
  allowlist_externals = make
@@ -302,7 +301,6 @@ commands =
302
301
 
303
302
  [testenv:build-docs]
304
303
  description = Build the documentation.
305
- use_develop = true
306
304
  deps =
307
305
  extras = docs
308
306
  allowlist_externals = make
@@ -324,4 +322,14 @@ commands =
324
322
  python -m twine check {tox_root}{/}dist{/}*
325
323
  commands_post =
326
324
  python -c 'import pathlib; print(f"Package is under:", pathlib.Path("{tox_root}") / "dist")'
325
+
326
+ [testenv:format-references]
327
+ description = Format references.bib.
328
+ deps =
329
+ skip_install = true
330
+ allowlist_externals = biber
331
+ commands = biber --tool --output_file={tox_root}{/}docs{/}references.bib --nolog \
332
+ --output_align --output_indent=2 --output_fieldcase=lower \
333
+ --output_legacy_dates --output-field-replace=journaltitle:journal,thesis:phdthesis,institution:school \
334
+ {tox_root}{/}docs{/}references.bib
327
335
  """
@@ -4,7 +4,6 @@ STAR = "star"
4
4
  EXTERNAL_STAR = "external_star"
5
5
  TRIU = "triu"
6
6
  TRIL = "tril"
7
- PAIRWISE = "pairwise"
8
7
  EXPLICIT = "explicit"
9
8
  DUMMY = "dummy"
10
9
  # plotting
@@ -16,7 +16,7 @@ except (ImportError, TypeError):
16
16
  ProblemKind_t = Literal["linear", "quadratic", "unknown"]
17
17
  Numeric_t = Union[int, float] # type of `time_key` arguments
18
18
  Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
19
- Str_Dict_t = Union[str, Mapping[str, Sequence[Any]]] # type for `cell_transition`
19
+ Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
20
20
  SinkFullRankInit = Literal["default", "gaussian", "sorting"]
21
21
  LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
22
22
 
@@ -24,20 +24,35 @@ SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]]
24
24
  QuadInitializer_t = Optional[LRInitializer_t]
25
25
 
26
26
  Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t]
27
- ProblemStage_t = Literal["initialized", "prepared", "solved"]
28
- Device_t = Literal["cpu", "gpu", "tpu"]
27
+ ProblemStage_t = Literal["prepared", "solved"]
28
+ Device_t = Union[Literal["cpu", "gpu", "tpu"], str]
29
29
 
30
30
  # TODO(michalk8): autogenerate from the enums
31
- ScaleCost_t = Optional[Union[float, Literal["mean", "max_cost", "max_bound", "max_norm", "median"]]]
32
- OttCostFn_t = Literal["euclidean", "sq_euclidean", "cosine", "bures", "unbalanced_bures"]
31
+ ScaleCost_t = Union[float, Literal["mean", "max_cost", "max_bound", "max_norm", "median"]]
32
+ OttCostFn_t = Literal[
33
+ "euclidean",
34
+ "sq_euclidean",
35
+ "cosine",
36
+ "PNormP",
37
+ "SqPNorm",
38
+ "Euclidean",
39
+ "SqEuclidean",
40
+ "Cosine",
41
+ "ElasticL1",
42
+ "ElasticSTVS",
43
+ "ElasticSqKOverlap",
44
+ ]
45
+ OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
33
46
  GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
34
47
  CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
48
+ CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
35
49
  PathLike = Union[os.PathLike, str]
36
50
  Policy_t = Literal[
37
51
  "sequential",
38
52
  "star",
39
53
  "external_star",
54
+ "explicit",
40
55
  "triu",
41
56
  "tril",
42
- "explicit",
43
57
  ]
58
+ CostKwargs_t = Union[Mapping[str, Any], Mapping[Literal["x", "y", "xy"], Mapping[str, Any]]]
@@ -1,10 +1,11 @@
1
1
  from ott.geometry import costs
2
2
 
3
+ from moscot.backends.ott._utils import sinkhorn_divergence
3
4
  from moscot.backends.ott.output import OTTOutput
4
5
  from moscot.backends.ott.solver import GWSolver, SinkhornSolver
5
6
  from moscot.costs import register_cost
6
7
 
7
- __all__ = ["OTTOutput", "GWSolver", "SinkhornSolver"]
8
+ __all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
8
9
 
9
10
  register_cost("euclidean", backend="ott")(costs.Euclidean)
10
11
  register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
@@ -0,0 +1,88 @@
1
+ from typing import Any, Optional
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import scipy.sparse as sp
6
+ from ott.geometry import geometry, pointcloud
7
+ from ott.tools import sinkhorn_divergence as sdiv
8
+
9
+ from moscot._logging import logger
10
+ from moscot._types import ArrayLike, ScaleCost_t
11
+
12
+ __all__ = ["sinkhorn_divergence"]
13
+
14
+
15
def sinkhorn_divergence(
    point_cloud_1: ArrayLike,
    point_cloud_2: ArrayLike,
    a: Optional[ArrayLike] = None,
    b: Optional[ArrayLike] = None,
    epsilon: Optional[float] = 1e-1,
    scale_cost: ScaleCost_t = 1.0,
    **kwargs: Any,
) -> float:
    """Compute the Sinkhorn divergence between two point clouds.

    Thin wrapper around :func:`ott.tools.sinkhorn_divergence.sinkhorn_divergence`
    that builds :class:`ott.geometry.pointcloud.PointCloud` geometries.

    Parameters
    ----------
    point_cloud_1
        First point cloud; converted with :func:`jax.numpy.asarray`.
    point_cloud_2
        Second point cloud; converted with :func:`jax.numpy.asarray`.
    a
        Weights of the first point cloud. If :obj:`None`, passed to :mod:`ott` as such.
    b
        Weights of the second point cloud. If :obj:`None`, passed to :mod:`ott` as such.
    epsilon
        Entropic regularization, forwarded to :mod:`ott`.
    scale_cost
        Cost scaling, forwarded to :mod:`ott`.
    kwargs
        Additional keyword arguments forwarded to :mod:`ott`.

    Returns
    -------
    The divergence as a Python :class:`float`. A warning is logged for every
    term whose solver did not converge.
    """
    x = jnp.asarray(point_cloud_1)
    y = jnp.asarray(point_cloud_2)
    if a is not None:
        a = jnp.asarray(a)
    if b is not None:
        b = jnp.asarray(b)

    div = sdiv.sinkhorn_divergence(
        pointcloud.PointCloud,
        x=x,
        y=y,
        a=a,
        b=b,
        epsilon=epsilon,
        scale_cost=scale_cost,
        **kwargs,
    )

    # NOTE(review): flags appear to be ordered (x/y, x/x[, y/y]) -- matches the
    # warning labels below; the `y/y` flag may be absent, hence the star-unpacking.
    conv_xy, conv_xx, *rest = div.converged
    if not conv_xy:
        logger.warning("Solver did not converge in the `x/y` term.")
    if not conv_xx:
        logger.warning("Solver did not converge in the `x/x` term.")
    if rest and not rest[0]:
        logger.warning("Solver did not converge in the `y/y` term.")

    return float(div.divergence)
49
+
50
+
51
def check_shapes(geom_x: geometry.Geometry, geom_y: geometry.Geometry, geom_xy: geometry.Geometry) -> None:
    """Validate that ``geom_x`` and ``geom_y`` agree with the shape of ``geom_xy``.

    Parameters
    ----------
    geom_x
        Geometry whose first dimension must equal the first dimension of ``geom_xy``.
    geom_y
        Geometry whose first dimension must equal the second dimension of ``geom_xy``.
    geom_xy
        Geometry with a 2-element shape ``(n, m)``.

    Raises
    ------
    ValueError
        If either point count disagrees with ``geom_xy``.
    """
    expected_rows, expected_cols = geom_xy.shape
    found_rows = geom_x.shape[0]
    found_cols = geom_y.shape[0]

    if expected_rows != found_rows:
        raise ValueError(f"Expected the first geometry to have `{expected_rows}` points, found `{found_rows}`.")
    if expected_cols != found_cols:
        raise ValueError(f"Expected the second geometry to have `{expected_cols}` points, found `{found_cols}`.")
58
+
59
+
60
def alpha_to_fused_penalty(alpha: float) -> float:
    """Convert the weight ``alpha`` into the fused penalty ``(1 - alpha) / alpha``.

    For example, ``alpha = 1`` maps to ``0`` and ``alpha = 0.5`` maps to ``1``.

    Parameters
    ----------
    alpha
        Weight in the half-open interval ``(0, 1]``.

    Returns
    -------
    The fused penalty ``(1 - alpha) / alpha``.

    Raises
    ------
    ValueError
        If ``alpha`` is not in ``(0, 1]`` (this includes NaN, for which the
        chained comparison is false).
    """
    # NOTE(review): presumably ``alpha`` weights the quadratic term of a fused
    # problem -- confirm against the callers before relying on that reading.
    if not (0 < alpha <= 1):
        raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
    return (1 - alpha) / alpha
65
+
66
+
67
def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
    """Ensure that an array is 2-dimensional.

    Parameters
    ----------
    arr
        Array to check. :mod:`scipy.sparse` inputs are densified first.
    reshape
        Allow reshaping 1-dimensional array to ``[n, 1]``.

    Returns
    -------
    2-dimensional :mod:`jax` array.

    Raises
    ------
    ValueError
        If the array is not 2-dimensional (after the optional reshape).
    """
    if sp.issparse(arr):
        # `.toarray()` works for both sparse matrices and sparse arrays; the `.A`
        # shorthand used previously is deprecated and was removed in SciPy 1.14.
        arr = arr.toarray()
    arr = jnp.asarray(arr)
    if reshape and arr.ndim == 1:
        return jnp.reshape(arr, (-1, 1))
    if arr.ndim != 2:
        raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
    return arr
@@ -5,14 +5,12 @@ import jaxlib.xla_extension as xla_ext
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
8
- from ott.solvers.linear.sinkhorn import SinkhornOutput as OTTSinkhornOutput
9
- from ott.solvers.linear.sinkhorn_lr import LRSinkhornOutput as OTTLRSinkhornOutput
10
- from ott.solvers.quadratic.gromov_wasserstein import GWOutput as OTTGWOutput
8
+ from ott.solvers.linear import sinkhorn, sinkhorn_lr
9
+ from ott.solvers.quadratic import gromov_wasserstein
11
10
 
12
11
  import matplotlib as mpl
13
12
  import matplotlib.pyplot as plt
14
13
 
15
- from moscot._docs._docs import d
16
14
  from moscot._types import ArrayLike, Device_t
17
15
  from moscot.base.output import BaseSolverOutput
18
16
 
@@ -20,58 +18,59 @@ __all__ = ["OTTOutput"]
20
18
 
21
19
 
22
20
  class OTTOutput(BaseSolverOutput):
23
- """Output of various optimal transport problems.
21
+ """Output of various :term:`OT` problems.
24
22
 
25
23
  Parameters
26
24
  ----------
27
25
  output
28
- Output of :mod:`ott` backend.
26
+ Output of the :mod:`ott` backend.
29
27
  """
30
28
 
31
- _NOT_COMPUTED = -1.0
29
+ _NOT_COMPUTED = -1.0 # sentinel value used in `ott`
32
30
 
33
- def __init__(self, output: Union[OTTSinkhornOutput, OTTLRSinkhornOutput, OTTGWOutput]):
31
+ def __init__(
32
+ self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput]
33
+ ):
34
34
  super().__init__()
35
35
  self._output = output
36
- self._costs = None if isinstance(output, OTTSinkhornOutput) else output.costs
36
+ self._costs = None if isinstance(output, sinkhorn.SinkhornOutput) else output.costs
37
37
  self._errors = output.errors
38
38
 
39
- @d.get_sections(base="plot_costs", sections=["Parameters", "Returns"])
40
39
  def plot_costs(
41
40
  self,
42
41
  last: Optional[int] = None,
43
42
  title: Optional[str] = None,
44
43
  return_fig: bool = False,
44
+ ax: Optional[mpl.axes.Axes] = None,
45
45
  figsize: Optional[Tuple[float, float]] = None,
46
46
  dpi: Optional[int] = None,
47
47
  save: Optional[str] = None,
48
- ax: Optional[mpl.axes.Axes] = None,
49
48
  **kwargs: Any,
50
49
  ) -> Optional[mpl.figure.Figure]:
51
- """Plot regularized OT costs during the iterations.
50
+ """Plot regularized :term:`OT` costs during the iterations.
52
51
 
53
52
  Parameters
54
53
  ----------
55
54
  last
56
- How many of the last steps of the algorithm to plot. If `None`, plot the full curve.
55
+ How many of the last steps of the algorithm to plot. If :obj:`None`, plot the full curve.
57
56
  title
58
- Title of the plot. If `None`, it is determined automatically.
57
+ Title of the plot. If :obj:`None`, it is determined automatically.
58
+ return_fig
59
+ Whether to return the figure.
60
+ ax
61
+ Axes on which to plot.
59
62
  figsize
60
63
  Size of the figure.
61
64
  dpi
62
65
  Dots per inch.
63
66
  save
64
67
  Path where to save the figure.
65
- return_fig
66
- Whether to return the figure.
67
- ax
68
- Axes on which to plot.
69
68
  kwargs
70
- Keyword arguments for :meth:`~matplotlib.axes.Axes.plot`.
69
+ Keyword arguments for :meth:`matplotlib.axes.Axes.plot`.
71
70
 
72
71
  Returns
73
72
  -------
74
- The figure if ``return_fig = True``.
73
+ If ``return_fig = True``, return the figure.
75
74
  """
76
75
  if self._costs is None:
77
76
  raise RuntimeError("No costs to plot.")
@@ -83,7 +82,6 @@ class OTTOutput(BaseSolverOutput):
83
82
  fig.savefig(save)
84
83
  return fig if return_fig else None
85
84
 
86
- @d.dedent
87
85
  def plot_errors(
88
86
  self,
89
87
  last: Optional[int] = None,
@@ -96,15 +94,34 @@ class OTTOutput(BaseSolverOutput):
96
94
  ax: Optional[mpl.axes.Axes] = None,
97
95
  **kwargs: Any,
98
96
  ) -> Optional[mpl.figure.Figure]:
99
- """Plot errors during the iterations.
97
+ """Plot errors along iterations.
100
98
 
101
99
  Parameters
102
100
  ----------
103
- %(plot_costs.parameters)s
101
+ last
102
+ Number of errors corresponding to the ``last`` steps of the algorithm to plot. If :obj:`None`,
103
+ plot the full curve.
104
+ title
105
+ Title of the plot. If :obj:`None`, it is determined automatically.
106
+ outer_iteration
107
+ Which outermost iteration's errors to plot.
108
+ Only used when this is the solution to the :term:`quadratic problem`.
109
+ return_fig
110
+ Whether to return the figure.
111
+ ax
112
+ Axes on which to plot.
113
+ figsize
114
+ Size of the figure.
115
+ dpi
116
+ Dots per inch.
117
+ save
118
+ Path where to save the figure.
119
+ kwargs
120
+ Keyword arguments for :meth:`matplotlib.axes.Axes.plot`.
104
121
 
105
122
  Returns
106
123
  -------
107
- %(plot_costs.returns)s
124
+ If ``return_fig = True``, return the figure.
108
125
  """
109
126
  if self._errors is None:
110
127
  raise RuntimeError("No errors to plot.")
@@ -155,7 +172,7 @@ class OTTOutput(BaseSolverOutput):
155
172
 
156
173
  @property
157
174
  def shape(self) -> Tuple[int, int]: # noqa: D102
158
- if isinstance(self._output, OTTSinkhornOutput):
175
+ if isinstance(self._output, sinkhorn.SinkhornOutput):
159
176
  return self._output.f.shape[0], self._output.g.shape[0]
160
177
  return self._output.geom.shape
161
178
 
@@ -165,9 +182,12 @@ class OTTOutput(BaseSolverOutput):
165
182
 
166
183
  @property
167
184
  def is_linear(self) -> bool: # noqa: D102
168
- return isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput))
185
+ return isinstance(self._output, (sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput))
169
186
 
170
187
  def to(self, device: Optional[Device_t] = None) -> "OTTOutput": # noqa: D102
188
+ if device is None:
189
+ return OTTOutput(jax.device_put(self._output, device=device))
190
+
171
191
  if isinstance(device, str) and ":" in device:
172
192
  device, ix = device.split(":")
173
193
  idx = int(ix)
@@ -184,9 +204,7 @@ class OTTOutput(BaseSolverOutput):
184
204
 
185
205
  @property
186
206
  def cost(self) -> float: # noqa: D102
187
- if isinstance(self._output, (OTTSinkhornOutput, OTTLRSinkhornOutput)):
188
- return float(self._output.reg_ot_cost)
189
- return float(self._output.reg_gw_cost)
207
+ return float(self._output.reg_ot_cost if self.is_linear else self._output.reg_gw_cost)
190
208
 
191
209
  @property
192
210
  def converged(self) -> bool: # noqa: D102
@@ -194,14 +212,14 @@ class OTTOutput(BaseSolverOutput):
194
212
 
195
213
  @property
196
214
  def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102
197
- if isinstance(self._output, OTTSinkhornOutput):
215
+ if isinstance(self._output, sinkhorn.SinkhornOutput):
198
216
  return self._output.f, self._output.g
199
217
  return None
200
218
 
201
219
  @property
202
220
  def rank(self) -> int: # noqa: D102
203
221
  lin_output = self._output if self.is_linear else self._output.linear_state
204
- return len(lin_output.g) if isinstance(lin_output, OTTLRSinkhornOutput) else -1
222
+ return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1
205
223
 
206
224
  def _ones(self, n: int) -> ArrayLike: # noqa: D102
207
- return jnp.ones((n,)) # type: ignore[return-value]
225
+ return jnp.ones((n,))