moscot 0.3.4__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. {moscot-0.3.4 → moscot-0.4.0}/.gitignore +1 -0
  2. {moscot-0.3.4 → moscot-0.4.0}/.pre-commit-config.yaml +10 -10
  3. moscot-0.4.0/.run_notebooks.sh +62 -0
  4. {moscot-0.3.4 → moscot-0.4.0}/PKG-INFO +10 -5
  5. {moscot-0.3.4 → moscot-0.4.0}/pyproject.toml +30 -8
  6. moscot-0.4.0/src/moscot/__init__.py +13 -0
  7. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_types.py +9 -9
  8. moscot-0.4.0/src/moscot/backends/ott/__init__.py +15 -0
  9. moscot-0.4.0/src/moscot/backends/ott/_utils.py +331 -0
  10. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/ott/output.py +247 -4
  11. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/ott/solver.py +282 -29
  12. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/utils.py +19 -7
  13. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/output.py +99 -40
  14. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/__init__.py +2 -1
  15. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/_mixins.py +135 -224
  16. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/_utils.py +11 -0
  17. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/birth_death.py +43 -53
  18. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/compound_problem.py +52 -17
  19. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/manager.py +4 -4
  20. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/problem.py +367 -44
  21. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/solver.py +9 -6
  22. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/datasets.py +111 -46
  23. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/plotting/_utils.py +2 -3
  24. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/__init__.py +2 -0
  25. moscot-0.4.0/src/moscot/problems/_utils.py +203 -0
  26. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/cross_modality/_mixins.py +17 -27
  27. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/cross_modality/_translation.py +53 -24
  28. moscot-0.4.0/src/moscot/problems/generic/__init__.py +14 -0
  29. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/generic/_generic.py +189 -43
  30. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/generic/_mixins.py +11 -27
  31. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/_alignment.py +65 -24
  32. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/_mapping.py +82 -33
  33. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/_mixins.py +131 -110
  34. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +23 -9
  35. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/time/_lineage.py +74 -27
  36. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/time/_mixins.py +45 -169
  37. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/subset_policy.py +5 -0
  38. moscot-0.4.0/src/moscot/utils/tagged_array.py +373 -0
  39. {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/PKG-INFO +10 -5
  40. {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/SOURCES.txt +1 -0
  41. {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/requires.txt +9 -2
  42. moscot-0.3.4/src/moscot/__init__.py +0 -13
  43. moscot-0.3.4/src/moscot/backends/ott/__init__.py +0 -18
  44. moscot-0.3.4/src/moscot/backends/ott/_utils.py +0 -111
  45. moscot-0.3.4/src/moscot/problems/_utils.py +0 -85
  46. moscot-0.3.4/src/moscot/problems/generic/__init__.py +0 -4
  47. moscot-0.3.4/src/moscot/utils/tagged_array.py +0 -187
  48. {moscot-0.3.4 → moscot-0.4.0}/.gitmodules +0 -0
  49. {moscot-0.3.4 → moscot-0.4.0}/.readthedocs.yml +0 -0
  50. {moscot-0.3.4 → moscot-0.4.0}/LICENSE +0 -0
  51. {moscot-0.3.4 → moscot-0.4.0}/MANIFEST.in +0 -0
  52. {moscot-0.3.4 → moscot-0.4.0}/README.rst +0 -0
  53. {moscot-0.3.4 → moscot-0.4.0}/codecov.yml +0 -0
  54. {moscot-0.3.4 → moscot-0.4.0}/setup.cfg +0 -0
  55. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_constants.py +0 -0
  56. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_logging.py +0 -0
  57. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_registry.py +0 -0
  58. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/__init__.py +0 -0
  59. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/__init__.py +0 -0
  60. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/cost.py +0 -0
  61. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/costs/__init__.py +0 -0
  62. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/costs/_costs.py +0 -0
  63. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/costs/_utils.py +0 -0
  64. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/plotting/__init__.py +0 -0
  65. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/plotting/_plotting.py +0 -0
  66. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/cross_modality/__init__.py +0 -0
  67. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/__init__.py +0 -0
  68. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  69. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/time/__init__.py +0 -0
  70. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/py.typed +0 -0
  71. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/__init__.py +0 -0
  72. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  73. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  74. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  75. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  76. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  77. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  78. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
  79. {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/data.py +0 -0
  80. {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/dependency_links.txt +0 -0
  81. {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/top_level.txt +0 -0
@@ -154,3 +154,4 @@ packages.dot
154
154
 
155
155
  # plotting tests
156
156
  tests/plotting/actual_figures/
157
+ tests/plotting/figures/
@@ -2,18 +2,18 @@ fail_fast: false
2
2
  default_language_version:
3
3
  python: python3
4
4
  default_stages:
5
- - commit
6
- - push
5
+ - pre-commit
6
+ - pre-push
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.8.0
10
+ rev: v1.13.0
11
11
  hooks:
12
12
  - id: mypy
13
13
  additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 24.2.0
16
+ rev: 24.10.0
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
@@ -29,7 +29,7 @@ repos:
29
29
  additional_dependencies: [toml]
30
30
  args: [--order-by-type]
31
31
  - repo: https://github.com/pre-commit/pre-commit-hooks
32
- rev: v4.5.0
32
+ rev: v5.0.0
33
33
  hooks:
34
34
  - id: check-merge-conflict
35
35
  - id: check-ast
@@ -42,28 +42,28 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.15.1
45
+ rev: v3.19.0
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
49
49
  - repo: https://github.com/asottile/blacken-docs
50
- rev: 1.16.0
50
+ rev: 1.19.1
51
51
  hooks:
52
52
  - id: blacken-docs
53
53
  additional_dependencies: [black==23.1.0]
54
54
  - repo: https://github.com/rstcheck/rstcheck
55
- rev: v6.2.0
55
+ rev: v6.2.4
56
56
  hooks:
57
57
  - id: rstcheck
58
58
  additional_dependencies: [tomli]
59
59
  args: [--config=pyproject.toml]
60
60
  - repo: https://github.com/PyCQA/doc8
61
- rev: v1.1.1
61
+ rev: v1.1.2
62
62
  hooks:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.2.2
66
+ rev: v0.7.2
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -0,0 +1,62 @@
1
#!/bin/bash
#
# Execute every example notebook found under
#   <base_notebook_directory>/examples/{plotting,problems,solvers}/*.ipynb
# with `jupytext --execute` using the `moscot` kernel.
# Every notebook is attempted; the script exits with status 1 when the argument
# is missing, when no notebooks are found, or when at least one notebook fails.

# Check if the base directory is provided as an argument
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <base_notebook_directory>"
    exit 1
fi

# Base directory for notebooks
base_dir=$1

# Notebook sub-directories below "$base_dir/examples"
declare -a notebook_dirs=(
    "plotting"
    "problems"
    "solvers"
)

# Globs that match nothing expand to nothing instead of the literal pattern.
shopt -s nullglob

# Initialize an array to hold valid notebook paths
declare -a valid_notebooks=()

# Gather all valid notebook files.
# The directory part is quoted so that a base directory containing spaces is not
# word-split; only `*.ipynb` stays unquoted so it is still expanded as a glob.
echo "Gathering notebooks..."
for sub_dir in "${notebook_dirs[@]}"; do
    for nb in "$base_dir/examples/$sub_dir"/*.ipynb; do
        if [[ -f "$nb" ]]; then # skip anything matching the glob that is not a regular file
            valid_notebooks+=("$nb")
        fi
    done
done

# Check if we have any notebooks to run
if [ ${#valid_notebooks[@]} -eq 0 ]; then
    echo "No notebooks found to run."
    exit 1
fi

# Echo the notebooks that will be run for clarity
echo "Preparing to run the following notebooks:"
for nb in "${valid_notebooks[@]}"; do
    echo "$nb"
done

# Initialize a flag to track the success of all commands
all_success=true

# Execute all valid notebooks; keep going after a failure so every notebook is reported
for nb in "${valid_notebooks[@]}"; do
    echo "Running $nb"
    jupytext -k moscot --execute "$nb" || {
        echo "Failed to run $nb"
        all_success=false
    }
done

# Check if any executions failed
if [ "$all_success" = false ]; then
    echo "One or more notebooks failed to execute."
    exit 1
fi

echo "All notebooks executed successfully."
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: moscot
3
- Version: 0.3.4
3
+ Version: 0.4.0
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -49,28 +49,33 @@ Classifier: Operating System :: MacOS :: MacOS X
49
49
  Classifier: Operating System :: Microsoft :: Windows
50
50
  Classifier: Typing :: Typed
51
51
  Classifier: Programming Language :: Python :: 3
52
- Classifier: Programming Language :: Python :: 3.8
53
52
  Classifier: Programming Language :: Python :: 3.9
54
53
  Classifier: Programming Language :: Python :: 3.10
55
54
  Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
56
55
  Classifier: Topic :: Scientific/Engineering :: Mathematics
57
- Requires-Python: >=3.8
56
+ Requires-Python: >=3.10
58
57
  Description-Content-Type: text/x-rst
59
58
  License-File: LICENSE
60
59
  Requires-Dist: numpy>=1.20.0
61
60
  Requires-Dist: scipy>=1.7.0
62
61
  Requires-Dist: pandas>=2.0.1
63
- Requires-Dist: networkx>=2.6.3
62
+ Requires-Dist: networkx>=3.2
64
63
  Requires-Dist: matplotlib>=3.5.0
65
64
  Requires-Dist: anndata>=0.9.1
66
65
  Requires-Dist: scanpy>=1.9.3
67
66
  Requires-Dist: wrapt>=1.13.2
68
67
  Requires-Dist: docrep>=0.3.2
69
- Requires-Dist: ott-jax>=0.4.5
68
+ Requires-Dist: ott-jax[neural]>=0.5.0
70
69
  Requires-Dist: cloudpickle>=2.2.0
71
70
  Requires-Dist: rich>=13.5
71
+ Requires-Dist: docstring_inheritance>=2.0.0
72
+ Requires-Dist: mudata>=0.2.2
72
73
  Provides-Extra: spatial
73
74
  Requires-Dist: squidpy>=1.2.3; extra == "spatial"
75
+ Provides-Extra: neural
76
+ Requires-Dist: optax; extra == "neural"
77
+ Requires-Dist: flax; extra == "neural"
78
+ Requires-Dist: diffrax; extra == "neural"
74
79
  Provides-Extra: dev
75
80
  Requires-Dist: pre-commit>=3.0.0; extra == "dev"
76
81
  Requires-Dist: tox>=4; extra == "dev"
@@ -7,7 +7,7 @@ name = "moscot"
7
7
  dynamic = ["version"]
8
8
  description = "Multi-omic single-cell optimal transport tools"
9
9
  readme = "README.rst"
10
- requires-python = ">=3.8"
10
+ requires-python = ">=3.10"
11
11
  license = {file = "LICENSE"}
12
12
  classifiers = [
13
13
  "Development Status :: 4 - Beta",
@@ -19,7 +19,6 @@ classifiers = [
19
19
  "Operating System :: Microsoft :: Windows",
20
20
  "Typing :: Typed",
21
21
  "Programming Language :: Python :: 3",
22
- "Programming Language :: Python :: 3.8",
23
22
  "Programming Language :: Python :: 3.9",
24
23
  "Programming Language :: Python :: 3.10",
25
24
  "Topic :: Scientific/Engineering :: Bio-Informatics",
@@ -48,22 +47,31 @@ dependencies = [
48
47
  "numpy>=1.20.0",
49
48
  "scipy>=1.7.0",
50
49
  "pandas>=2.0.1",
51
- "networkx>=2.6.3",
50
+ "networkx>=3.2",
52
51
  # https://github.com/scverse/scanpy/issues/2411
53
52
  "matplotlib>=3.5.0",
54
53
  "anndata>=0.9.1",
55
54
  "scanpy>=1.9.3",
56
55
  "wrapt>=1.13.2",
57
56
  "docrep>=0.3.2",
58
- "ott-jax>=0.4.5",
57
+ "ott-jax[neural]>=0.5.0",
59
58
  "cloudpickle>=2.2.0",
60
59
  "rich>=13.5",
60
+ "docstring_inheritance>=2.0.0",
61
+ "mudata>=0.2.2"
61
62
  ]
62
63
 
63
64
  [project.optional-dependencies]
64
65
  spatial = [
65
66
  "squidpy>=1.2.3"
66
67
  ]
68
+
69
+ neural = [
70
+ "optax",
71
+ "flax",
72
+ "diffrax",
73
+
74
+ ]
67
75
  dev = [
68
76
  "pre-commit>=3.0.0",
69
77
  "tox>=4",
@@ -225,7 +233,7 @@ ignore_roles = [
225
233
 
226
234
  [tool.mypy]
227
235
  mypy_path = "$MYPY_CONFIG_FILE_DIR/src"
228
- python_version = "3.9"
236
+ python_version = "3.10"
229
237
  plugins = "numpy.typing.mypy_plugin"
230
238
 
231
239
  ignore_errors = false
@@ -262,16 +270,16 @@ max_line_length = 120
262
270
  legacy_tox_ini = """
263
271
  [tox]
264
272
  min_version = 4.0
265
- env_list = lint-code,py{3.8,3.9,3.10,3.11}
273
+ env_list = lint-code,py{3.10,3.11,3.12}
266
274
  skip_missing_interpreters = true
267
275
 
268
276
  [testenv]
269
277
  extras = test
270
- pass_env = PYTEST_*,CI
271
278
  commands =
272
279
  python -m pytest {tty:--color=yes} {posargs: \
273
280
  --cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
274
281
  --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered}
282
+ passenv = PYTEST_*,CI
275
283
 
276
284
  [testenv:lint-code]
277
285
  description = Lint the code.
@@ -282,7 +290,6 @@ commands =
282
290
 
283
291
  [testenv:lint-docs]
284
292
  description = Lint the documentation.
285
- deps =
286
293
  extras = docs
287
294
  ignore_errors = true
288
295
  allowlist_externals = make
@@ -294,6 +301,21 @@ commands =
294
301
  # TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490
295
302
  # make spelling {posargs}
296
303
 
304
+ [testenv:examples-docs]
305
+ allowlist_externals = bash
306
+ description = Run the notebooks.
307
+ use_develop = true
308
+ deps =
309
+ ipykernel
310
+ jupytext
311
+ nbconvert
312
+ leidenalg
313
+ extras = docs
314
+ changedir = {tox_root}{/}docs
315
+ commands =
316
+ python -m ipykernel install --user --name=moscot
317
+ bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks
318
+
297
319
  [testenv:clean-docs]
298
320
  description = Remove the documentation.
299
321
  deps =
@@ -0,0 +1,13 @@
1
+ from importlib import metadata
2
+
3
+ from moscot import backends, base, costs, datasets, plotting, problems, utils
4
+
5
try:
    md = metadata.metadata(__name__)
    __version__ = md.get("version", "")  # type: ignore[attr-defined]
    __author__ = md.get("Author", "")  # type: ignore[attr-defined]
    __maintainer__ = md.get("Maintainer-email", "")  # type: ignore[attr-defined]
except ImportError:
    # `importlib.metadata.metadata` raises `PackageNotFoundError` (an `ImportError`
    # subclass) when no installed distribution matches, e.g. when importing from an
    # uninstalled source tree. Keep the dunder attributes defined anyway so that
    # `moscot.__version__` does not raise `AttributeError`.
    md = None
    __version__ = __author__ = __maintainer__ = ""

del metadata, md
@@ -2,6 +2,9 @@ import os
2
2
  from typing import Any, Literal, Mapping, Optional, Sequence, Union
3
3
 
4
4
  import numpy as np
5
+ from ott.initializers.linear.initializers import SinkhornInitializer
6
+ from ott.initializers.linear.initializers_lr import LRInitializer
7
+ from ott.initializers.quadratic.initializers import BaseQuadraticInitializer
5
8
 
6
9
  # TODO(michalk8): polish
7
10
 
@@ -17,13 +20,14 @@ ProblemKind_t = Literal["linear", "quadratic", "unknown"]
17
20
  Numeric_t = Union[int, float] # type of `time_key` arguments
18
21
  Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
19
22
  Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
20
- SinkFullRankInit = Literal["default", "gaussian", "sorting"]
21
- LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
23
+ SinkhornInitializerTag_t = Literal["default", "gaussian", "sorting"]
24
+ LRInitializerTag_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
22
25
 
23
- SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]]
24
- QuadInitializer_t = Optional[LRInitializer_t]
26
+ LRInitializer_t = Optional[Union[LRInitializer, LRInitializerTag_t]]
27
+ SinkhornInitializer_t = Optional[Union[SinkhornInitializer, SinkhornInitializerTag_t]]
28
+ QuadInitializer_t = Optional[Union[BaseQuadraticInitializer]]
25
29
 
26
- Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t]
30
+ Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t, LRInitializer_t]
27
31
  ProblemStage_t = Literal["prepared", "solved"]
28
32
  Device_t = Union[Literal["cpu", "gpu", "tpu"], str]
29
33
 
@@ -36,10 +40,6 @@ OttCostFn_t = Literal[
36
40
  "pnorm_p",
37
41
  "sq_pnorm",
38
42
  "cosine",
39
- "elastic_l1",
40
- "elastic_l2",
41
- "elastic_stvs",
42
- "elastic_sqk_overlap",
43
43
  "geodesic",
44
44
  ]
45
45
  OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
@@ -0,0 +1,15 @@
1
+ from ott.geometry import costs
2
+
3
+ from moscot.backends.ott._utils import sinkhorn_divergence
4
+ from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
5
+ from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
6
+ from moscot.costs import register_cost
7
+
8
__all__ = [
    "OTTOutput",
    "GWSolver",
    "SinkhornSolver",
    "OTTNeuralOutput",
    "sinkhorn_divergence",
    "GENOTLinSolver",
]

# Register ott's cost classes under moscot's cost names for the "ott" backend,
# in the same order as listed here.
for _cost_name, _cost_cls in (
    ("euclidean", costs.Euclidean),
    ("sq_euclidean", costs.SqEuclidean),
    ("cosine", costs.Cosine),
    ("pnorm_p", costs.PNormP),
    ("sq_pnorm", costs.SqPNorm),
):
    register_cost(_cost_name, backend="ott")(_cost_cls)
# Do not leak the loop variables as module attributes.
del _cost_name, _cost_cls
@@ -0,0 +1,331 @@
1
+ from collections import defaultdict
2
+ from functools import partial
3
+ from typing import Any, Dict, Iterable, Literal, Optional, Tuple, Union
4
+
5
+ import jax
6
+ import jax.experimental.sparse as jesp
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import scipy.sparse as sp
10
+ from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
11
+ from ott.initializers.linear import initializers as init_lib
12
+ from ott.initializers.linear import initializers_lr as lr_init_lib
13
+ from ott.neural import datasets
14
+ from ott.solvers import utils as solver_utils
15
+ from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
16
+
17
+ from moscot._logging import logger
18
+ from moscot._types import ArrayLike, ScaleCost_t
19
+
20
+ Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
21
+
22
+
23
+ __all__ = ["sinkhorn_divergence"]
24
+
25
+
26
class InitializerResolver:
    """Resolve OT solver initializers from string tags.

    Both static methods accept either an already-constructed ott initializer,
    which is returned unchanged, or a string tag naming one of the supported
    initializer classes, which is then instantiated.
    """

    @staticmethod
    def lr_from_str(
        initializer: Union[str, lr_init_lib.LRInitializer],
        rank: int,
        **kwargs: Any,
    ) -> lr_init_lib.LRInitializer:
        """Create a low-rank initializer.

        Parameters
        ----------
        initializer
            Existing :class:`~ott.initializers.linear.initializers_lr.LRInitializer`
            (returned as-is) or one of ``'k-means'``, ``'generalized-k-means'``,
            ``'random'``, ``'rank2'``.
        rank
            Rank passed to the newly created initializer. Ignored when an instance is given.
        kwargs
            Additional keyword arguments for the initializer's constructor. Ignored when
            an instance is given.

        Returns
        -------
        The low-rank initializer.

        Raises
        ------
        NotImplementedError
            If ``initializer`` is neither an instance nor a supported tag.
        """
        if isinstance(initializer, lr_init_lib.LRInitializer):
            return initializer
        if initializer == "k-means":
            return lr_init_lib.KMeansInitializer(rank=rank, **kwargs)
        if initializer == "generalized-k-means":
            return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs)
        if initializer == "random":
            return lr_init_lib.RandomInitializer(rank=rank, **kwargs)
        if initializer == "rank2":
            return lr_init_lib.Rank2Initializer(rank=rank, **kwargs)
        raise NotImplementedError(f"Initializer `{initializer}` is not implemented.")

    @staticmethod
    def from_str(
        initializer: Union[str, init_lib.SinkhornInitializer],
        **kwargs: Any,
    ) -> init_lib.SinkhornInitializer:
        """Create a (full-rank) Sinkhorn initializer.

        Parameters
        ----------
        initializer
            Existing :class:`~ott.initializers.linear.initializers.SinkhornInitializer`
            (returned as-is) or one of ``'default'``, ``'gaussian'``, ``'sorting'``,
            ``'subsample'``.
        kwargs
            Additional keyword arguments for the initializer's constructor. Ignored when
            an instance is given.

        Returns
        -------
        The Sinkhorn initializer.

        Raises
        ------
        NotImplementedError
            If ``initializer`` is neither an instance nor a supported tag.
        """
        if isinstance(initializer, init_lib.SinkhornInitializer):
            return initializer
        if initializer == "default":
            return init_lib.DefaultInitializer(**kwargs)
        if initializer == "gaussian":
            return init_lib.GaussianInitializer(**kwargs)
        if initializer == "sorting":
            return init_lib.SortingInitializer(**kwargs)
        if initializer == "subsample":
            return init_lib.SubsampleInitializer(**kwargs)
        raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")
108
+
109
+
110
def sinkhorn_divergence(
    point_cloud_1: ArrayLike,
    point_cloud_2: ArrayLike,
    a: Optional[ArrayLike] = None,
    b: Optional[ArrayLike] = None,
    epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-1,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
    scale_cost: ScaleCost_t = 1.0,
    batch_size: Optional[int] = None,
    **kwargs: Any,
) -> float:
    """Compute the Sinkhorn divergence between two point clouds with ott.

    Parameters
    ----------
    point_cloud_1
        First point cloud, converted with :func:`jax.numpy.asarray`.
    point_cloud_2
        Second point cloud, converted with :func:`jax.numpy.asarray`.
    a
        Optional weights of ``point_cloud_1``.
    b
        Optional weights of ``point_cloud_2``.
    epsilon
        Entropic regularization forwarded to the point-cloud geometries.
    tau_a
        Forwarded to the Sinkhorn solves as ``tau_a``.
    tau_b
        Forwarded to the Sinkhorn solves as ``tau_b``.
    scale_cost
        Cost scaling forwarded to the point-cloud geometries.
    batch_size
        Batch size forwarded to the point-cloud geometries.
    kwargs
        Additional keyword arguments for
        :func:`ott.tools.sinkhorn_divergence.sinkhorn_divergence`.

    Returns
    -------
    The divergence value as a Python :class:`float`. A warning is logged for
    each of the ``x/y``, ``x/x`` and (when present) ``y/y`` solves that did
    not converge.
    """
    x = jnp.asarray(point_cloud_1)
    y = jnp.asarray(point_cloud_2)
    weights_x = a if a is None else jnp.asarray(a)
    weights_y = b if b is None else jnp.asarray(b)
    solve_kwargs = {"tau_a": tau_a, "tau_b": tau_b}

    result = sinkhorn_div(
        pointcloud.PointCloud,
        x=x,
        y=y,
        batch_size=batch_size,
        a=weights_x,
        b=weights_y,
        scale_cost=scale_cost,
        epsilon=epsilon,
        solve_kwargs=solve_kwargs,
        **kwargs,
    )
    # The second element of the returned tuple is the divergence output object.
    div_output = result[1]

    converged_xy, converged_xx, *converged_yy = div_output.converged
    checks = [
        (converged_xy, "Solver did not converge in the `x/y` term."),
        (converged_xx, "Solver did not converge in the `x/x` term."),
    ]
    # The `y/y` flag may be absent, hence the starred unpacking above.
    if converged_yy:
        checks.append((converged_yy[0], "Solver did not converge in the `y/y` term."))
    for converged, message in checks:
        if not converged:
            logger.warning(message)

    return float(div_output.divergence)
152
+
153
+
154
@partial(jax.jit, static_argnames=["k"])
def get_nearest_neighbors(
    input_batch: jnp.ndarray,
    target: jnp.ndarray,
    k: int = 30,
    recall_target: float = 0.95,
    aggregate_to_topk: bool = True,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Approximately find, for each row of ``input_batch``, the ``k`` closest rows of ``target``.

    Closeness is measured with the cost matrix of an ott ``PointCloud`` built
    with its default cost function (presumably squared Euclidean; confirm
    against the installed ott version). The search uses
    :func:`jax.lax.approx_min_k`, so results are approximate and governed by
    ``recall_target``.

    Parameters
    ----------
    input_batch
        Query points.
    target
        Points to search in. Must have at least ``k`` rows.
    k
        Number of neighbors. This is a static argument under :func:`jax.jit`.
    recall_target
        Forwarded to :func:`jax.lax.approx_min_k`.
    aggregate_to_topk
        Forwarded to :func:`jax.lax.approx_min_k`.

    Returns
    -------
    The ``(values, indices)`` pair returned by :func:`jax.lax.approx_min_k`.

    Raises
    ------
    ValueError
        If ``target`` has fewer than ``k`` rows.
    """
    n_target = target.shape[0]
    if k > n_target:
        raise ValueError(f"k is {k}, but must be smaller or equal than {target.shape[0]}.")
    cost_matrix = pointcloud.PointCloud(input_batch, target).cost_matrix
    return jax.lax.approx_min_k(
        cost_matrix,
        k=k,
        recall_target=recall_target,
        aggregate_to_topk=aggregate_to_topk,
    )
169
+
170
+
171
def check_shapes(geom_x: geometry.Geometry, geom_y: geometry.Geometry, geom_xy: geometry.Geometry) -> None:
    """Check that ``geom_x`` and ``geom_y`` match the two sides of ``geom_xy``.

    Parameters
    ----------
    geom_x
        Geometry whose number of points must equal the number of rows of ``geom_xy``.
    geom_y
        Geometry whose number of points must equal the number of columns of ``geom_xy``.
    geom_xy
        Reference geometry with shape ``(n, m)``.

    Raises
    ------
    ValueError
        On the first mismatch, checking ``geom_x`` before ``geom_y``.
    """
    n_rows, n_cols = geom_xy.shape
    comparisons = (
        ("first", n_rows, geom_x.shape[0]),
        ("second", n_cols, geom_y.shape[0]),
    )
    for which, expected, found in comparisons:
        if expected != found:
            raise ValueError(f"Expected the {which} geometry to have `{expected}` points, found `{found}`.")
178
+
179
+
180
def alpha_to_fused_penalty(alpha: float) -> float:
    """Convert an interpolation weight ``alpha`` into a fused penalty ``(1 - alpha) / alpha``.

    Parameters
    ----------
    alpha
        Weight in ``(0, 1]``. ``alpha = 1`` maps to a penalty of ``0``.

    Returns
    -------
    The fused penalty ``(1 - alpha) / alpha``.

    Raises
    ------
    ValueError
        If ``alpha`` is not in ``(0, 1]``, including NaN.
    """
    if 0 < alpha <= 1:
        return (1 - alpha) / alpha
    raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
185
+
186
+
187
def densify(arr: ArrayLike) -> jax.Array:
    """Return a dense :mod:`jax` array for ``arr``.

    SciPy sparse matrices are densified with ``toarray()``, and
    :class:`jax.experimental.sparse.BCOO` arrays with ``todense()``. Anything
    else is passed straight to :func:`jax.numpy.asarray`.

    Parameters
    ----------
    arr
        Dense or sparse array.

    Returns
    -------
    Dense :mod:`jax` array.
    """
    if sp.issparse(arr):
        dense = arr.toarray()  # type: ignore[attr-defined]
    elif isinstance(arr, jesp.BCOO):
        dense = arr.todense()
    else:
        dense = arr
    return jnp.asarray(dense)
204
+
205
+
206
def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
    """Ensure that an array is 2-dimensional and cast it with ``astype(jnp.float64)``.

    Parameters
    ----------
    arr
        Array to check.
    reshape
        Allow reshaping 1-dimensional array to ``[n, 1]``.

    Returns
    -------
    2-dimensional array cast to ``jnp.float64``. A reshaped 1-dimensional input
    is returned as a :mod:`jax` array. A 2-dimensional input keeps the array
    type returned by its own ``astype``.

    Raises
    ------
    ValueError
        If the array is not 2-dimensional and cannot be reshaped into 2 dimensions.
    """
    if reshape and arr.ndim == 1:
        # Cast before reshaping so that 1-D inputs get the same dtype treatment
        # as 2-D inputs (previously e.g. integer 1-D arrays were left as ints).
        return jnp.reshape(arr.astype(jnp.float64), (-1, 1))
    if arr.ndim != 2:
        raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
    return arr.astype(jnp.float64)
225
+
226
+
227
def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
    """Convert a SciPy sparse matrix to :class:`jax.experimental.sparse.BCOO`.

    Inputs that are not SciPy sparse (e.g. an existing ``BCOO``) are returned
    unchanged, as the same object.
    """
    return jesp.BCOO.from_scipy_sparse(arr) if sp.issparse(arr) else arr
232
+
233
+
234
def _instantiate_geodesic_cost(
    arr: jax.Array,
    problem_shape: Tuple[int, int],
    t: Optional[float],
    is_linear_term: bool,
    epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
    relative_epsilon: Optional[bool] = None,
    scale_cost: Scale_t = 1.0,
    directed: bool = True,
    **kwargs: Any,
) -> geometry.Geometry:
    """Build a dense geometry from the geodesic cost matrix of a graph.

    Parameters
    ----------
    arr
        Graph passed to ``ott.geometry.geodesic.Geodesic.from_graph``
        (presumably an adjacency/weight matrix; confirm with callers).
    problem_shape
        ``(n_src, n_tgt)``, the number of source and target points.
    t
        Time parameter for ``Geodesic.from_graph``. If `None`, ``epsilon / 4`` is used.
    is_linear_term
        If `True`, ``arr`` must have ``n_src + n_tgt`` rows, and only the
        ``[:n_src, n_src:]`` block of the geodesic cost matrix is kept.
        Otherwise the full matrix is used.
    epsilon
        Forwarded to the returned geometry. Also used to derive ``t`` when ``t`` is `None`.
    relative_epsilon
        Forwarded to the returned geometry.
    scale_cost
        Forwarded to the returned geometry.
    directed
        Forwarded to ``Geodesic.from_graph``.
    kwargs
        Additional keyword arguments for ``Geodesic.from_graph``.

    Returns
    -------
    :class:`ott.geometry.geometry.Geometry` wrapping the (possibly sliced) cost matrix.

    Raises
    ------
    ValueError
        If ``is_linear_term`` is set and ``arr`` does not have ``n_src + n_tgt``
        rows, or if both ``t`` and ``epsilon`` are `None`.
    """
    n_src, n_tgt = problem_shape
    if is_linear_term and n_src + n_tgt != arr.shape[0]:
        raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
    if t is None:
        # `t` is derived from `epsilon`, so at least one of them is required;
        # without this check `None / 4.0` would raise an unhelpful TypeError.
        if epsilon is None:
            raise ValueError("Expected either `t` or `epsilon` to be specified, found both to be `None`.")
        t = epsilon / 4.0
    cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
    # For the linear term only the source-by-target block is relevant.
    cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
    return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
252
+
253
+
254
def data_match_fn(
    src_lin: Optional[jnp.ndarray] = None,
    tgt_lin: Optional[jnp.ndarray] = None,
    src_quad: Optional[jnp.ndarray] = None,
    tgt_quad: Optional[jnp.ndarray] = None,
    *,
    typ: Literal["lin", "quad", "fused"],
    **data_match_fn_kwargs,
) -> jnp.ndarray:
    """Dispatch to ott's data-matching helpers according to ``typ``.

    Parameters
    ----------
    src_lin, tgt_lin
        Linear-term source/target data, used for ``typ='lin'`` and ``typ='fused'``.
    src_quad, tgt_quad
        Quadratic-term source/target data, used for ``typ='quad'`` and ``typ='fused'``.
    typ
        ``'lin'`` calls ``match_linear``. ``'quad'`` calls ``match_quadratic``
        with ``xx``/``yy``. ``'fused'`` calls ``match_quadratic`` with
        ``xx``/``yy`` and additionally ``x``/``y``.
    data_match_fn_kwargs
        Forwarded to the selected ott helper.

    Returns
    -------
    The value returned by the selected ott helper.

    Raises
    ------
    NotImplementedError
        If ``typ`` is not one of the supported values.
    """
    if typ == "lin":
        return solver_utils.match_linear(x=src_lin, y=tgt_lin, **data_match_fn_kwargs)
    if typ in ("quad", "fused"):
        # Only the fused variant also receives the linear-term data.
        linear_part = {"x": src_lin, "y": tgt_lin} if typ == "fused" else {}
        return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, **linear_part, **data_match_fn_kwargs)
    raise NotImplementedError(f"Unknown type: {typ}.")
270
+
271
+
272
class Loader:
    """Endless iterator that yields random batches from an OT dataset.

    Each batch draws ``batch_size`` indices independently (with replacement)
    from ``range(len(dataset))``. For every key of the sampled items, the
    values are stacked with :func:`jax.numpy.vstack`. ``__next__`` never
    raises :class:`StopIteration`.
    """

    def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
        self.dataset = dataset
        self.batch_size = batch_size
        # Private generator so that sampling is reproducible for a given seed.
        self._rng = np.random.default_rng(seed)

    def __iter__(self):
        return self

    def __next__(self) -> Dict[str, jnp.ndarray]:
        columns: Dict[str, list] = defaultdict(list)
        for _ in range(self.batch_size):
            sample = self.dataset[self._rng.integers(0, len(self.dataset))]
            for key, value in sample.items():
                columns[key].append(value)
        return {key: jnp.vstack(values) for key, values in columns.items()}

    def __len__(self):
        # Length of the underlying dataset, not a number of batches.
        return len(self.dataset)
292
+
293
+
294
class MultiLoader:
    """Iterator that interleaves several loaders for OT problems with conditions.

    At every step one of the wrapped loaders is chosen uniformly at random and
    its next element is returned. A wrapped iterator that is exhausted is
    restarted from its loader, and its first element is returned instead.

    Parameters
    ----------
    datasets
        Loaders to sample from.
    seed
        Random seed for choosing the loader at each step.
    """

    def __init__(
        self,
        datasets: Iterable["Loader"],
        seed: Optional[int] = None,
    ):
        self.datasets = tuple(datasets)
        self._rng = np.random.default_rng(seed)
        # Active iterators, one per entry of `self.datasets`.
        self._iterators: list[Any] = []
        # Number of elements produced since the last `__iter__` call.
        self._it = 0

    def __next__(self) -> Dict[str, jnp.ndarray]:
        if not self._iterators:
            # Allow `next()` without a prior `iter()` call.
            self._iterators = [iter(ds) for ds in self.datasets]
        self._it += 1

        ix = self._rng.choice(len(self._iterators))
        try:
            return next(self._iterators[ix])
        except StopIteration:
            # Reset the exhausted iterator and return its first element.
            self._iterators[ix] = iterator = iter(self.datasets[ix])
            return next(iterator)

    def __iter__(self) -> "MultiLoader":
        self._it = 0
        self._iterators = [iter(ds) for ds in self.datasets]
        return self

    def __len__(self) -> int:
        # Length of the largest wrapped loader (0 if there are none).
        return max((len(ds) for ds in self.datasets), default=0)