moscot 0.3.2__tar.gz → 0.3.4__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of moscot might be problematic. Click here for more details.

Files changed (75)
  1. {moscot-0.3.2 → moscot-0.3.4}/.pre-commit-config.yaml +9 -9
  2. {moscot-0.3.2 → moscot-0.3.4}/PKG-INFO +33 -3
  3. {moscot-0.3.2 → moscot-0.3.4}/README.rst +1 -1
  4. {moscot-0.3.2 → moscot-0.3.4}/pyproject.toml +4 -1
  5. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_types.py +10 -10
  6. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/__init__.py +4 -2
  7. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/_utils.py +25 -2
  8. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/output.py +79 -6
  9. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/solver.py +148 -26
  10. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/_mixins.py +199 -31
  11. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/_utils.py +8 -1
  12. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/birth_death.py +8 -4
  13. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/compound_problem.py +30 -29
  14. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/problem.py +234 -40
  15. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/solver.py +0 -1
  16. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/costs/_costs.py +5 -2
  17. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/costs/_utils.py +15 -4
  18. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/datasets.py +36 -4
  19. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/plotting/_plotting.py +11 -11
  20. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/plotting/_utils.py +9 -6
  21. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/_utils.py +11 -13
  22. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/cross_modality/_mixins.py +62 -5
  23. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/cross_modality/_translation.py +4 -4
  24. moscot-0.3.4/src/moscot/problems/generic/__init__.py +4 -0
  25. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/generic/_generic.py +245 -32
  26. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/generic/_mixins.py +10 -3
  27. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/_alignment.py +11 -7
  28. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/_mapping.py +8 -9
  29. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/_mixins.py +123 -10
  30. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +6 -1
  31. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/time/_lineage.py +7 -6
  32. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/time/_mixins.py +95 -45
  33. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/mouse_proliferation.txt +0 -1
  34. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/subset_policy.py +7 -12
  35. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/tagged_array.py +27 -1
  36. {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/PKG-INFO +33 -3
  37. {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/requires.txt +2 -1
  38. moscot-0.3.2/src/moscot/problems/generic/__init__.py +0 -4
  39. {moscot-0.3.2 → moscot-0.3.4}/.gitignore +0 -0
  40. {moscot-0.3.2 → moscot-0.3.4}/.gitmodules +0 -0
  41. {moscot-0.3.2 → moscot-0.3.4}/.readthedocs.yml +0 -0
  42. {moscot-0.3.2 → moscot-0.3.4}/LICENSE +0 -0
  43. {moscot-0.3.2 → moscot-0.3.4}/MANIFEST.in +0 -0
  44. {moscot-0.3.2 → moscot-0.3.4}/codecov.yml +0 -0
  45. {moscot-0.3.2 → moscot-0.3.4}/setup.cfg +0 -0
  46. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/__init__.py +0 -0
  47. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_constants.py +0 -0
  48. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_logging.py +0 -0
  49. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_registry.py +0 -0
  50. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/__init__.py +0 -0
  51. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/utils.py +0 -0
  52. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/__init__.py +0 -0
  53. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/cost.py +0 -0
  54. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/output.py +0 -0
  55. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/__init__.py +0 -0
  56. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/manager.py +0 -0
  57. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/costs/__init__.py +0 -0
  58. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/plotting/__init__.py +0 -0
  59. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/__init__.py +0 -0
  60. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/cross_modality/__init__.py +0 -0
  61. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/__init__.py +0 -0
  62. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  63. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/time/__init__.py +0 -0
  64. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/py.typed +0 -0
  65. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/__init__.py +0 -0
  66. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  67. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  68. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  69. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  70. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  71. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  72. {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/data.py +0 -0
  73. {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/SOURCES.txt +0 -0
  74. {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/dependency_links.txt +0 -0
  75. {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/top_level.txt +0 -0
@@ -7,29 +7,29 @@ default_stages:
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.4.1
10
+ rev: v1.8.0
11
11
  hooks:
12
12
  - id: mypy
13
13
  additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 23.7.0
16
+ rev: 24.2.0
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
20
20
  - repo: https://github.com/pre-commit/mirrors-prettier
21
- rev: v3.0.0
21
+ rev: v4.0.0-alpha.8
22
22
  hooks:
23
23
  - id: prettier
24
24
  language_version: system
25
25
  - repo: https://github.com/PyCQA/isort
26
- rev: 5.12.0
26
+ rev: 5.13.2
27
27
  hooks:
28
28
  - id: isort
29
29
  additional_dependencies: [toml]
30
30
  args: [--order-by-type]
31
31
  - repo: https://github.com/pre-commit/pre-commit-hooks
32
- rev: v4.4.0
32
+ rev: v4.5.0
33
33
  hooks:
34
34
  - id: check-merge-conflict
35
35
  - id: check-ast
@@ -42,17 +42,17 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.9.0
45
+ rev: v3.15.1
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
49
49
  - repo: https://github.com/asottile/blacken-docs
50
- rev: 1.15.0
50
+ rev: 1.16.0
51
51
  hooks:
52
52
  - id: blacken-docs
53
53
  additional_dependencies: [black==23.1.0]
54
54
  - repo: https://github.com/rstcheck/rstcheck
55
- rev: v6.1.2
55
+ rev: v6.2.0
56
56
  hooks:
57
57
  - id: rstcheck
58
58
  additional_dependencies: [tomli]
@@ -63,7 +63,7 @@ repos:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.0.280
66
+ rev: v0.2.2
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: moscot
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -56,11 +56,41 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
56
56
  Classifier: Topic :: Scientific/Engineering :: Mathematics
57
57
  Requires-Python: >=3.8
58
58
  Description-Content-Type: text/x-rst
59
+ License-File: LICENSE
60
+ Requires-Dist: numpy>=1.20.0
61
+ Requires-Dist: scipy>=1.7.0
62
+ Requires-Dist: pandas>=2.0.1
63
+ Requires-Dist: networkx>=2.6.3
64
+ Requires-Dist: matplotlib>=3.5.0
65
+ Requires-Dist: anndata>=0.9.1
66
+ Requires-Dist: scanpy>=1.9.3
67
+ Requires-Dist: wrapt>=1.13.2
68
+ Requires-Dist: docrep>=0.3.2
69
+ Requires-Dist: ott-jax>=0.4.5
70
+ Requires-Dist: cloudpickle>=2.2.0
71
+ Requires-Dist: rich>=13.5
59
72
  Provides-Extra: spatial
73
+ Requires-Dist: squidpy>=1.2.3; extra == "spatial"
60
74
  Provides-Extra: dev
75
+ Requires-Dist: pre-commit>=3.0.0; extra == "dev"
76
+ Requires-Dist: tox>=4; extra == "dev"
61
77
  Provides-Extra: test
78
+ Requires-Dist: pytest>=7; extra == "test"
79
+ Requires-Dist: pytest-xdist>=3; extra == "test"
80
+ Requires-Dist: pytest-mock>=3.5.0; extra == "test"
81
+ Requires-Dist: pytest-cov>=4; extra == "test"
82
+ Requires-Dist: coverage[toml]>=7; extra == "test"
62
83
  Provides-Extra: docs
63
- License-File: LICENSE
84
+ Requires-Dist: sphinx>=5.1.1; extra == "docs"
85
+ Requires-Dist: sphinx_copybutton>=0.5.0; extra == "docs"
86
+ Requires-Dist: sphinxcontrib-bibtex>=2.3.0; extra == "docs"
87
+ Requires-Dist: sphinxcontrib-spelling>=7.6.2; extra == "docs"
88
+ Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
89
+ Requires-Dist: furo>=2022.09.29; extra == "docs"
90
+ Requires-Dist: sphinx-tippy>=0.4.1; extra == "docs"
91
+ Requires-Dist: myst-nb>=0.17.1; extra == "docs"
92
+ Requires-Dist: ipython>=7.20.0; extra == "docs"
93
+ Requires-Dist: sphinx_design>=0.3.0; extra == "docs"
64
94
 
65
95
  |PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
66
96
 
@@ -125,6 +155,6 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
125
155
  :target: https://moscot.readthedocs.io/en/stable/
126
156
  :alt: Documentation
127
157
 
128
- .. |Downloads| image:: https://pepy.tech/badge/moscot
158
+ .. |Downloads| image:: https://static.pepy.tech/badge/moscot
129
159
  :target: https://pepy.tech/project/moscot
130
160
  :alt: Downloads
@@ -61,6 +61,6 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
61
61
  :target: https://moscot.readthedocs.io/en/stable/
62
62
  :alt: Documentation
63
63
 
64
- .. |Downloads| image:: https://pepy.tech/badge/moscot
64
+ .. |Downloads| image:: https://static.pepy.tech/badge/moscot
65
65
  :target: https://pepy.tech/project/moscot
66
66
  :alt: Downloads
@@ -55,8 +55,9 @@ dependencies = [
55
55
  "scanpy>=1.9.3",
56
56
  "wrapt>=1.13.2",
57
57
  "docrep>=0.3.2",
58
- "ott-jax>=0.4.3",
58
+ "ott-jax>=0.4.5",
59
59
  "cloudpickle>=2.2.0",
60
+ "rich>=13.5",
60
61
  ]
61
62
 
62
63
  [project.optional-dependencies]
@@ -121,6 +122,8 @@ ignore = [
121
122
  "D107",
122
123
  # Missing docstring in magic method
123
124
  "D105",
125
+ # Use `X | Y` for type annotations
126
+ "UP007",
124
127
  ]
125
128
  line-length = 120
126
129
  select = [
@@ -8,7 +8,7 @@ import numpy as np
8
8
  try:
9
9
  from numpy.typing import DTypeLike, NDArray
10
10
 
11
- ArrayLike = NDArray[np.float_]
11
+ ArrayLike = NDArray[np.float64]
12
12
  except (ImportError, TypeError):
13
13
  ArrayLike = np.ndarray # type: ignore[misc]
14
14
  DTypeLike = np.dtype # type: ignore[misc]
@@ -33,20 +33,20 @@ OttCostFn_t = Literal[
33
33
  "euclidean",
34
34
  "sq_euclidean",
35
35
  "cosine",
36
- "PNormP",
37
- "SqPNorm",
38
- "Euclidean",
39
- "SqEuclidean",
40
- "Cosine",
41
- "ElasticL1",
42
- "ElasticSTVS",
43
- "ElasticSqKOverlap",
36
+ "pnorm_p",
37
+ "sq_pnorm",
38
+ "cosine",
39
+ "elastic_l1",
40
+ "elastic_l2",
41
+ "elastic_stvs",
42
+ "elastic_sqk_overlap",
43
+ "geodesic",
44
44
  ]
45
45
  OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
46
46
  GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
47
47
  CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
48
48
  CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
49
- PathLike = Union[os.PathLike, str]
49
+ PathLike = Union[os.PathLike, str] # type: ignore[type-arg]
50
50
  Policy_t = Literal[
51
51
  "sequential",
52
52
  "star",
@@ -1,11 +1,11 @@
1
1
  from ott.geometry import costs
2
2
 
3
3
  from moscot.backends.ott._utils import sinkhorn_divergence
4
- from moscot.backends.ott.output import OTTOutput
4
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
5
5
  from moscot.backends.ott.solver import GWSolver, SinkhornSolver
6
6
  from moscot.costs import register_cost
7
7
 
8
- __all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
8
+ __all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
9
9
 
10
10
  register_cost("euclidean", backend="ott")(costs.Euclidean)
11
11
  register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
@@ -13,4 +13,6 @@ register_cost("cosine", backend="ott")(costs.Cosine)
13
13
  register_cost("pnorm_p", backend="ott")(costs.PNormP)
14
14
  register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
15
15
  register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
16
+ register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
16
17
  register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
18
+ register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
@@ -1,14 +1,17 @@
1
- from typing import Any, Optional, Union
1
+ from typing import Any, Literal, Optional, Tuple, Union
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import scipy.sparse as sp
6
- from ott.geometry import epsilon_scheduler, geometry, pointcloud
6
+ from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
7
7
  from ott.tools import sinkhorn_divergence as sdiv
8
8
 
9
9
  from moscot._logging import logger
10
10
  from moscot._types import ArrayLike, ScaleCost_t
11
11
 
12
+ Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
13
+
14
+
12
15
  __all__ = ["sinkhorn_divergence"]
13
16
 
14
17
 
@@ -86,3 +89,23 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
86
89
  if arr.ndim != 2:
87
90
  raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
88
91
  return arr
92
+
93
+
94
+ def _instantiate_geodesic_cost(
95
+ arr: jax.Array,
96
+ problem_shape: Tuple[int, int],
97
+ t: Optional[float],
98
+ is_linear_term: bool,
99
+ epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
100
+ relative_epsilon: Optional[bool] = None,
101
+ scale_cost: Scale_t = 1.0,
102
+ directed: bool = True,
103
+ **kwargs: Any,
104
+ ) -> geometry.Geometry:
105
+ n_src, n_tgt = problem_shape
106
+ if is_linear_term and n_src + n_tgt != arr.shape[0]:
107
+ raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
108
+ t = epsilon / 4.0 if t is None else t
109
+ cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
110
+ cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
111
+ return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
@@ -6,7 +6,7 @@ import jax
6
6
  import jax.numpy as jnp
7
7
  import numpy as np
8
8
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
9
- from ott.solvers.quadratic import gromov_wasserstein
9
+ from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
10
10
 
11
11
  import matplotlib as mpl
12
12
  import matplotlib.pyplot as plt
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
14
14
  from moscot._types import ArrayLike, Device_t
15
15
  from moscot.base.output import BaseSolverOutput
16
16
 
17
- __all__ = ["OTTOutput"]
17
+ __all__ = ["OTTOutput", "GraphOTTOutput"]
18
18
 
19
19
 
20
20
  class OTTOutput(BaseSolverOutput):
@@ -29,7 +29,13 @@ class OTTOutput(BaseSolverOutput):
29
29
  _NOT_COMPUTED = -1.0 # sentinel value used in `ott`
30
30
 
31
31
  def __init__(
32
- self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput]
32
+ self,
33
+ output: Union[
34
+ sinkhorn.SinkhornOutput,
35
+ sinkhorn_lr.LRSinkhornOutput,
36
+ gromov_wasserstein.GWOutput,
37
+ gromov_wasserstein_lr.LRGWOutput,
38
+ ],
33
39
  ):
34
40
  super().__init__()
35
41
  self._output = output
@@ -168,7 +174,10 @@ class OTTOutput(BaseSolverOutput):
168
174
  def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
169
175
  if x.ndim == 1:
170
176
  return self._output.apply(x, axis=1 - forward)
171
- return self._output.apply(x.T, axis=1 - forward).T # convert to batch first
177
+ return self._output.apply(
178
+ x.T,
179
+ axis=1 - forward,
180
+ ).T # convert to batch first
172
181
 
173
182
  @property
174
183
  def shape(self) -> Tuple[int, int]: # noqa: D102
@@ -218,8 +227,72 @@ class OTTOutput(BaseSolverOutput):
218
227
 
219
228
  @property
220
229
  def rank(self) -> int: # noqa: D102
221
- lin_output = self._output if self.is_linear else self._output.linear_state
222
- return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1
230
+ output = self._output.linear_state if isinstance(self._output, gromov_wasserstein.GWOutput) else self._output
231
+ return (
232
+ len(output.g)
233
+ if isinstance(output, (sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein_lr.LRGWOutput))
234
+ else -1
235
+ )
223
236
 
224
237
  def _ones(self, n: int) -> ArrayLike: # noqa: D102
225
238
  return jnp.ones((n,))
239
+
240
+
241
+ class GraphOTTOutput(OTTOutput):
242
+ """Output of :term:`OT` problems with a graph geometry in the linear term.
243
+
244
+ Parameters
245
+ ----------
246
+ output
247
+ Output of the :mod:`ott` backend.
248
+ shape
249
+ Shape of the problem.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ output: Union[
255
+ sinkhorn.SinkhornOutput,
256
+ sinkhorn_lr.LRSinkhornOutput,
257
+ gromov_wasserstein.GWOutput,
258
+ gromov_wasserstein_lr.LRGWOutput,
259
+ ],
260
+ shape: Tuple[int, int],
261
+ ):
262
+ super().__init__(output)
263
+ self._shape = shape
264
+
265
+ @property
266
+ def shape(self) -> Tuple[int, int]: # noqa: D102
267
+ return self._shape
268
+
269
+ def _expand_data(self, x: jnp.ndarray, forward: bool) -> jnp.ndarray:
270
+ if forward:
271
+ shape = (self.shape[1],) if x.ndim == 1 else (self.shape[1], x.shape[1])
272
+ return jnp.concatenate((x, jnp.zeros(shape)))
273
+ shape = (self.shape[0],) if x.ndim == 1 else (self.shape[0], x.shape[1])
274
+ return jnp.concatenate((jnp.zeros(shape), x))
275
+
276
+ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
277
+ x_expanded = self._expand_data(x, forward=forward)
278
+ # ott-jax only supports lse_mode=False with graph geometry
279
+ res = self._output.apply(x_expanded.T, axis=1 - forward, lse_mode=False).T
280
+ return res[len(x) :] if forward else res[: -len(x)]
281
+
282
+ def to(self, device: Optional[Device_t] = None) -> "GraphOTTOutput": # noqa: D102
283
+ if device is None:
284
+ return GraphOTTOutput(jax.device_put(self._output, device=device), shape=self.shape)
285
+
286
+ if isinstance(device, str) and ":" in device:
287
+ device, ix = device.split(":")
288
+ idx = int(ix)
289
+ else:
290
+ idx = 0
291
+
292
+ if not isinstance(device, xla_ext.Device):
293
+ try:
294
+ device = jax.devices(device)[idx]
295
+ except IndexError:
296
+ raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None
297
+
298
+ return GraphOTTOutput(jax.device_put(self._output, device), shape=self.shape)
@@ -4,22 +4,34 @@ import types
4
4
  from typing import Any, Literal, Mapping, Optional, Set, Tuple, Union
5
5
 
6
6
  import jax
7
- from ott.geometry import costs, epsilon_scheduler, geometry, pointcloud
7
+ import jax.numpy as jnp
8
+ from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
8
9
  from ott.problems.linear import linear_problem
9
10
  from ott.problems.quadratic import quadratic_problem
10
11
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
11
- from ott.solvers.quadratic import gromov_wasserstein
12
+ from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
12
13
 
13
14
  from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
14
- from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
15
- from moscot.backends.ott.output import OTTOutput
15
+ from moscot.backends.ott._utils import (
16
+ _instantiate_geodesic_cost,
17
+ alpha_to_fused_penalty,
18
+ check_shapes,
19
+ ensure_2d,
20
+ )
21
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
22
+ from moscot.base.problems._utils import TimeScalesHeatKernel
16
23
  from moscot.base.solver import OTSolver
17
24
  from moscot.costs import get_cost
18
25
  from moscot.utils.tagged_array import TaggedArray
19
26
 
20
27
  __all__ = ["SinkhornSolver", "GWSolver"]
21
28
 
22
- OTTSolver_t = Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein]
29
+ OTTSolver_t = Union[
30
+ sinkhorn.Sinkhorn,
31
+ sinkhorn_lr.LRSinkhorn,
32
+ gromov_wasserstein.GromovWasserstein,
33
+ gromov_wasserstein_lr.LRGromovWasserstein,
34
+ ]
23
35
  OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
24
36
  Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
25
37
 
@@ -38,14 +50,21 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
38
50
  self._solver: Optional[OTTSolver_t] = None
39
51
  self._problem: Optional[OTTProblem_t] = None
40
52
  self._jit = jit
53
+ self._a: Optional[jnp.ndarray] = None
54
+ self._b: Optional[jnp.ndarray] = None
41
55
 
42
56
  def _create_geometry(
43
57
  self,
44
58
  x: TaggedArray,
59
+ *,
60
+ is_linear_term: bool,
45
61
  epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
46
62
  relative_epsilon: Optional[bool] = None,
47
63
  scale_cost: Scale_t = 1.0,
48
64
  batch_size: Optional[int] = None,
65
+ problem_shape: Optional[Tuple[int, int]] = None,
66
+ t: Optional[float] = None,
67
+ directed: bool = True,
49
68
  **kwargs: Any,
50
69
  ) -> geometry.Geometry:
51
70
  if x.is_point_cloud:
@@ -83,17 +102,80 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
83
102
  return geometry.Geometry(
84
103
  kernel_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
85
104
  )
105
+ if x.is_graph: # we currently only support this for the linear term.
106
+ return self._create_graph_geometry(
107
+ is_linear_term=is_linear_term,
108
+ x=x,
109
+ arr=arr,
110
+ problem_shape=problem_shape,
111
+ t=t,
112
+ epsilon=epsilon,
113
+ relative_epsilon=relative_epsilon,
114
+ scale_cost=scale_cost,
115
+ directed=directed,
116
+ **kwargs,
117
+ )
86
118
  raise NotImplementedError(f"Creating geometry from `tag={x.tag!r}` is not yet implemented.")
87
119
 
88
120
  def _solve( # type: ignore[override]
89
121
  self,
90
122
  prob: OTTProblem_t,
91
123
  **kwargs: Any,
92
- ) -> OTTOutput:
124
+ ) -> Union[OTTOutput, GraphOTTOutput]:
93
125
  solver = jax.jit(self.solver) if self._jit else self.solver
94
126
  out = solver(prob, **kwargs)
127
+ if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic):
128
+ return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type]
95
129
  return OTTOutput(out)
96
130
 
131
+ def _create_graph_geometry(
132
+ self,
133
+ is_linear_term: bool,
134
+ x: TaggedArray,
135
+ arr: jax.Array,
136
+ problem_shape: Optional[Tuple[int, int]],
137
+ t: Optional[float],
138
+ epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
139
+ relative_epsilon: Optional[bool] = None,
140
+ scale_cost: Scale_t = 1.0,
141
+ directed: bool = True,
142
+ **kwargs: Any,
143
+ ) -> geometry.Geometry:
144
+ if x.cost == "geodesic":
145
+ if self.problem_kind == "linear":
146
+ if t is None:
147
+ if epsilon is None:
148
+ raise ValueError("`epsilon` cannot be `None`.")
149
+ return geodesic.Geodesic.from_graph(arr, t=epsilon / 4.0, directed=directed, **kwargs)
150
+
151
+ return _instantiate_geodesic_cost(
152
+ arr=arr,
153
+ problem_shape=problem_shape, # type: ignore[arg-type]
154
+ t=t,
155
+ is_linear_term=True,
156
+ epsilon=epsilon,
157
+ relative_epsilon=relative_epsilon,
158
+ scale_cost=scale_cost,
159
+ directed=directed,
160
+ **kwargs,
161
+ )
162
+ if self.problem_kind == "quadratic":
163
+ problem_shape = x.shape if problem_shape is None else problem_shape
164
+ return _instantiate_geodesic_cost(
165
+ arr=arr,
166
+ problem_shape=problem_shape,
167
+ t=t,
168
+ is_linear_term=is_linear_term,
169
+ epsilon=epsilon,
170
+ relative_epsilon=relative_epsilon,
171
+ scale_cost=scale_cost,
172
+ directed=directed,
173
+ **kwargs,
174
+ )
175
+
176
+ raise NotImplementedError(f"Invalid problem kind `{self.problem_kind}`.")
177
+ raise NotImplementedError(f"If the geometry is a graph, `cost` must be `geodesic`, found `{x.cost}`.")
178
+
97
179
  @property
98
180
  def solver(self) -> OTTSolver_t:
99
181
  """:mod:`ott` solver."""
@@ -160,6 +242,8 @@ class SinkhornSolver(OTTJaxSolver):
160
242
 
161
243
  def _prepare(
162
244
  self,
245
+ a: jnp.ndarray,
246
+ b: jnp.ndarray,
163
247
  xy: Optional[TaggedArray] = None,
164
248
  x: Optional[TaggedArray] = None,
165
249
  y: Optional[TaggedArray] = None,
@@ -170,24 +254,35 @@ class SinkhornSolver(OTTJaxSolver):
170
254
  scale_cost: Scale_t = 1.0,
171
255
  cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
172
256
  cost_matrix_rank: Optional[int] = None,
257
+ time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
173
258
  # problem
174
259
  **kwargs: Any,
175
260
  ) -> linear_problem.LinearProblem:
176
261
  del x, y
262
+ time_scales_heat_kernel = (
263
+ TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
264
+ )
177
265
  if xy is None:
178
266
  raise ValueError(f"Unable to create geometry from `xy={xy}`.")
179
-
267
+ self._a = a
268
+ self._b = b
180
269
  geom = self._create_geometry(
181
270
  xy,
271
+ is_linear_term=True,
182
272
  epsilon=epsilon,
183
273
  relative_epsilon=relative_epsilon,
184
274
  batch_size=batch_size,
275
+ problem_shape=(len(self._a), len(self._b)),
185
276
  scale_cost=scale_cost,
277
+ t=time_scales_heat_kernel.xy,
186
278
  **cost_kwargs,
187
279
  )
188
280
  if cost_matrix_rank is not None:
189
281
  geom = geom.to_LRCGeometry(rank=cost_matrix_rank)
190
- self._problem = linear_problem.LinearProblem(geom, **kwargs)
282
+ if isinstance(geom, geodesic.Geodesic):
283
+ a = jnp.concatenate((a, jnp.zeros_like(self._b)), axis=0)
284
+ b = jnp.concatenate((jnp.zeros_like(self._a), b), axis=0)
285
+ self._problem = linear_problem.LinearProblem(geom, a=a, b=b, **kwargs)
191
286
  return self._problem
192
287
 
193
288
  @property
@@ -201,7 +296,15 @@ class SinkhornSolver(OTTJaxSolver):
201
296
 
202
297
  @classmethod
203
298
  def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
204
- geom_kwargs = {"epsilon", "relative_epsilon", "batch_size", "scale_cost", "cost_kwargs", "cost_matrix_rank"}
299
+ geom_kwargs = {
300
+ "epsilon",
301
+ "relative_epsilon",
302
+ "batch_size",
303
+ "scale_cost",
304
+ "cost_kwargs",
305
+ "cost_matrix_rank",
306
+ "t",
307
+ }
205
308
  problem_kwargs = set(inspect.signature(linear_problem.LinearProblem).parameters.keys())
206
309
  problem_kwargs -= {"geom"}
207
310
  return geom_kwargs | problem_kwargs, {"epsilon"}
@@ -243,24 +346,30 @@ class GWSolver(OTTJaxSolver):
243
346
  ):
244
347
  super().__init__(jit=jit)
245
348
  if rank > -1:
246
- linear_solver_kwargs = dict(linear_solver_kwargs)
247
- linear_solver_kwargs.setdefault("gamma", 10)
248
- linear_solver_kwargs.setdefault("gamma_rescale", True)
249
- linear_ot_solver = sinkhorn_lr.LRSinkhorn(rank=rank, **linear_solver_kwargs)
349
+ kwargs.setdefault("gamma", 10)
350
+ kwargs.setdefault("gamma_rescale", True)
250
351
  initializer = "rank2" if initializer is None else initializer
352
+ self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
353
+ rank=rank,
354
+ initializer=initializer,
355
+ kwargs_init=initializer_kwargs,
356
+ **kwargs,
357
+ )
251
358
  else:
252
359
  linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
253
360
  initializer = None
254
- self._solver = gromov_wasserstein.GromovWasserstein(
255
- rank=rank,
256
- linear_ot_solver=linear_ot_solver,
257
- quad_initializer=initializer,
258
- kwargs_init=initializer_kwargs,
259
- **kwargs,
260
- )
361
+ self._solver = gromov_wasserstein.GromovWasserstein(
362
+ rank=rank,
363
+ linear_ot_solver=linear_ot_solver,
364
+ quad_initializer=initializer,
365
+ kwargs_init=initializer_kwargs,
366
+ **kwargs,
367
+ )
261
368
 
262
369
  def _prepare(
263
370
  self,
371
+ a: jnp.ndarray,
372
+ b: jnp.ndarray,
264
373
  xy: Optional[TaggedArray] = None,
265
374
  x: Optional[TaggedArray] = None,
266
375
  y: Optional[TaggedArray] = None,
@@ -271,32 +380,45 @@ class GWSolver(OTTJaxSolver):
271
380
  scale_cost: Scale_t = 1.0,
272
381
  cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
273
382
  cost_matrix_rank: Optional[int] = None,
383
+ time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
274
384
  # problem
275
385
  alpha: float = 0.5,
276
386
  **kwargs: Any,
277
387
  ) -> quadratic_problem.QuadraticProblem:
388
+ self._a = a
389
+ self._b = b
390
+ time_scales_heat_kernel = (
391
+ TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
392
+ )
278
393
  if x is None or y is None:
279
394
  raise ValueError(f"Unable to create geometry from `x={x}`, `y={y}`.")
280
- geom_kwargs: Any = {
395
+ geom_kwargs: dict[str, Any] = {
281
396
  "epsilon": epsilon,
282
397
  "relative_epsilon": relative_epsilon,
283
398
  "batch_size": batch_size,
284
399
  "scale_cost": scale_cost,
285
- "cost_matrix_rank": cost_matrix_rank,
286
400
  **cost_kwargs,
287
401
  }
288
- geom_xx = self._create_geometry(x, **geom_kwargs)
289
- geom_yy = self._create_geometry(y, **geom_kwargs)
402
+ if cost_matrix_rank is not None:
403
+ geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
404
+ geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
405
+ geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
290
406
  if alpha == 1.0 or xy is None: # GW
291
407
  # arbitrary fused penalty; must be positive
292
408
  geom_xy, fused_penalty = None, 1.0
293
409
  else: # FGW
294
410
  fused_penalty = alpha_to_fused_penalty(alpha)
295
- geom_xy = self._create_geometry(xy, **geom_kwargs)
411
+ geom_xy = self._create_geometry(
412
+ xy,
413
+ t=time_scales_heat_kernel.xy,
414
+ problem_shape=(x.shape[0], y.shape[0]),
415
+ is_linear_term=True,
416
+ **geom_kwargs,
417
+ )
296
418
  check_shapes(geom_xx, geom_yy, geom_xy)
297
419
 
298
420
  self._problem = quadratic_problem.QuadraticProblem(
299
- geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, **kwargs
421
+ geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, a=self._a, b=self._b, **kwargs
300
422
  )
301
423
  return self._problem
302
424