moscot 0.3.3__tar.gz → 0.3.4__tar.gz

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of moscot might be problematic. More details are linked from the original diff page.

Files changed (75)
  1. {moscot-0.3.3 → moscot-0.3.4}/.pre-commit-config.yaml +8 -8
  2. {moscot-0.3.3 → moscot-0.3.4}/PKG-INFO +2 -2
  3. {moscot-0.3.3 → moscot-0.3.4}/pyproject.toml +3 -1
  4. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_types.py +10 -10
  5. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/__init__.py +4 -2
  6. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/_utils.py +25 -2
  7. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/output.py +65 -2
  8. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/solver.py +126 -13
  9. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/_mixins.py +199 -31
  10. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/_utils.py +8 -1
  11. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/birth_death.py +8 -4
  12. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/compound_problem.py +26 -27
  13. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/problem.py +234 -40
  14. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/solver.py +0 -1
  15. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/costs/_costs.py +5 -2
  16. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/costs/_utils.py +15 -4
  17. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/datasets.py +36 -4
  18. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/plotting/_plotting.py +11 -11
  19. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/plotting/_utils.py +9 -6
  20. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/_utils.py +11 -13
  21. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/cross_modality/_mixins.py +62 -5
  22. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/cross_modality/_translation.py +4 -4
  23. moscot-0.3.4/src/moscot/problems/generic/__init__.py +4 -0
  24. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/generic/_generic.py +245 -32
  25. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/generic/_mixins.py +10 -3
  26. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/_alignment.py +11 -7
  27. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/_mapping.py +8 -9
  28. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/_mixins.py +123 -10
  29. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +6 -1
  30. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/time/_lineage.py +7 -6
  31. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/time/_mixins.py +95 -45
  32. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/mouse_proliferation.txt +0 -1
  33. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/subset_policy.py +7 -12
  34. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/tagged_array.py +27 -1
  35. {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/PKG-INFO +2 -2
  36. {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/requires.txt +1 -1
  37. moscot-0.3.3/src/moscot/problems/generic/__init__.py +0 -4
  38. {moscot-0.3.3 → moscot-0.3.4}/.gitignore +0 -0
  39. {moscot-0.3.3 → moscot-0.3.4}/.gitmodules +0 -0
  40. {moscot-0.3.3 → moscot-0.3.4}/.readthedocs.yml +0 -0
  41. {moscot-0.3.3 → moscot-0.3.4}/LICENSE +0 -0
  42. {moscot-0.3.3 → moscot-0.3.4}/MANIFEST.in +0 -0
  43. {moscot-0.3.3 → moscot-0.3.4}/README.rst +0 -0
  44. {moscot-0.3.3 → moscot-0.3.4}/codecov.yml +0 -0
  45. {moscot-0.3.3 → moscot-0.3.4}/setup.cfg +0 -0
  46. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/__init__.py +0 -0
  47. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_constants.py +0 -0
  48. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_logging.py +0 -0
  49. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_registry.py +0 -0
  50. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/__init__.py +0 -0
  51. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/utils.py +0 -0
  52. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/__init__.py +0 -0
  53. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/cost.py +0 -0
  54. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/output.py +0 -0
  55. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/__init__.py +0 -0
  56. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/manager.py +0 -0
  57. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/costs/__init__.py +0 -0
  58. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/plotting/__init__.py +0 -0
  59. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/__init__.py +0 -0
  60. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/cross_modality/__init__.py +0 -0
  61. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/__init__.py +0 -0
  62. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  63. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/time/__init__.py +0 -0
  64. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/py.typed +0 -0
  65. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/__init__.py +0 -0
  66. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  67. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  68. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  69. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  70. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  71. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  72. {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/data.py +0 -0
  73. {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/SOURCES.txt +0 -0
  74. {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/dependency_links.txt +0 -0
  75. {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/top_level.txt +0 -0
@@ -7,29 +7,29 @@ default_stages:
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.5.1
10
+ rev: v1.8.0
11
11
  hooks:
12
12
  - id: mypy
13
13
  additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 23.7.0
16
+ rev: 24.2.0
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
20
20
  - repo: https://github.com/pre-commit/mirrors-prettier
21
- rev: v3.0.2
21
+ rev: v4.0.0-alpha.8
22
22
  hooks:
23
23
  - id: prettier
24
24
  language_version: system
25
25
  - repo: https://github.com/PyCQA/isort
26
- rev: 5.12.0
26
+ rev: 5.13.2
27
27
  hooks:
28
28
  - id: isort
29
29
  additional_dependencies: [toml]
30
30
  args: [--order-by-type]
31
31
  - repo: https://github.com/pre-commit/pre-commit-hooks
32
- rev: v4.4.0
32
+ rev: v4.5.0
33
33
  hooks:
34
34
  - id: check-merge-conflict
35
35
  - id: check-ast
@@ -42,7 +42,7 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.10.1
45
+ rev: v3.15.1
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
@@ -52,7 +52,7 @@ repos:
52
52
  - id: blacken-docs
53
53
  additional_dependencies: [black==23.1.0]
54
54
  - repo: https://github.com/rstcheck/rstcheck
55
- rev: v6.1.2
55
+ rev: v6.2.0
56
56
  hooks:
57
57
  - id: rstcheck
58
58
  additional_dependencies: [tomli]
@@ -63,7 +63,7 @@ repos:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.0.286
66
+ rev: v0.2.2
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: moscot
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -66,7 +66,7 @@ Requires-Dist: anndata>=0.9.1
66
66
  Requires-Dist: scanpy>=1.9.3
67
67
  Requires-Dist: wrapt>=1.13.2
68
68
  Requires-Dist: docrep>=0.3.2
69
- Requires-Dist: ott-jax>=0.4.3
69
+ Requires-Dist: ott-jax>=0.4.5
70
70
  Requires-Dist: cloudpickle>=2.2.0
71
71
  Requires-Dist: rich>=13.5
72
72
  Provides-Extra: spatial
@@ -55,7 +55,7 @@ dependencies = [
55
55
  "scanpy>=1.9.3",
56
56
  "wrapt>=1.13.2",
57
57
  "docrep>=0.3.2",
58
- "ott-jax>=0.4.3",
58
+ "ott-jax>=0.4.5",
59
59
  "cloudpickle>=2.2.0",
60
60
  "rich>=13.5",
61
61
  ]
@@ -122,6 +122,8 @@ ignore = [
122
122
  "D107",
123
123
  # Missing docstring in magic method
124
124
  "D105",
125
+ # Use `X | Y` for type annotations
126
+ "UP007",
125
127
  ]
126
128
  line-length = 120
127
129
  select = [
@@ -8,7 +8,7 @@ import numpy as np
8
8
  try:
9
9
  from numpy.typing import DTypeLike, NDArray
10
10
 
11
- ArrayLike = NDArray[np.float_]
11
+ ArrayLike = NDArray[np.float64]
12
12
  except (ImportError, TypeError):
13
13
  ArrayLike = np.ndarray # type: ignore[misc]
14
14
  DTypeLike = np.dtype # type: ignore[misc]
@@ -33,20 +33,20 @@ OttCostFn_t = Literal[
33
33
  "euclidean",
34
34
  "sq_euclidean",
35
35
  "cosine",
36
- "PNormP",
37
- "SqPNorm",
38
- "Euclidean",
39
- "SqEuclidean",
40
- "Cosine",
41
- "ElasticL1",
42
- "ElasticSTVS",
43
- "ElasticSqKOverlap",
36
+ "pnorm_p",
37
+ "sq_pnorm",
38
+ "cosine",
39
+ "elastic_l1",
40
+ "elastic_l2",
41
+ "elastic_stvs",
42
+ "elastic_sqk_overlap",
43
+ "geodesic",
44
44
  ]
45
45
  OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
46
46
  GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
47
47
  CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
48
48
  CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
49
- PathLike = Union[os.PathLike, str]
49
+ PathLike = Union[os.PathLike, str] # type: ignore[type-arg]
50
50
  Policy_t = Literal[
51
51
  "sequential",
52
52
  "star",
@@ -1,11 +1,11 @@
1
1
  from ott.geometry import costs
2
2
 
3
3
  from moscot.backends.ott._utils import sinkhorn_divergence
4
- from moscot.backends.ott.output import OTTOutput
4
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
5
5
  from moscot.backends.ott.solver import GWSolver, SinkhornSolver
6
6
  from moscot.costs import register_cost
7
7
 
8
- __all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
8
+ __all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
9
9
 
10
10
  register_cost("euclidean", backend="ott")(costs.Euclidean)
11
11
  register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
@@ -13,4 +13,6 @@ register_cost("cosine", backend="ott")(costs.Cosine)
13
13
  register_cost("pnorm_p", backend="ott")(costs.PNormP)
14
14
  register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
15
15
  register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
16
+ register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
16
17
  register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
18
+ register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
@@ -1,14 +1,17 @@
1
- from typing import Any, Optional, Union
1
+ from typing import Any, Literal, Optional, Tuple, Union
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
5
  import scipy.sparse as sp
6
- from ott.geometry import epsilon_scheduler, geometry, pointcloud
6
+ from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
7
7
  from ott.tools import sinkhorn_divergence as sdiv
8
8
 
9
9
  from moscot._logging import logger
10
10
  from moscot._types import ArrayLike, ScaleCost_t
11
11
 
12
+ Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
13
+
14
+
12
15
  __all__ = ["sinkhorn_divergence"]
13
16
 
14
17
 
@@ -86,3 +89,23 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
86
89
  if arr.ndim != 2:
87
90
  raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
88
91
  return arr
92
+
93
+
94
+ def _instantiate_geodesic_cost(
95
+ arr: jax.Array,
96
+ problem_shape: Tuple[int, int],
97
+ t: Optional[float],
98
+ is_linear_term: bool,
99
+ epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
100
+ relative_epsilon: Optional[bool] = None,
101
+ scale_cost: Scale_t = 1.0,
102
+ directed: bool = True,
103
+ **kwargs: Any,
104
+ ) -> geometry.Geometry:
105
+ n_src, n_tgt = problem_shape
106
+ if is_linear_term and n_src + n_tgt != arr.shape[0]:
107
+ raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
108
+ t = epsilon / 4.0 if t is None else t
109
+ cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
110
+ cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
111
+ return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
14
14
  from moscot._types import ArrayLike, Device_t
15
15
  from moscot.base.output import BaseSolverOutput
16
16
 
17
- __all__ = ["OTTOutput"]
17
+ __all__ = ["OTTOutput", "GraphOTTOutput"]
18
18
 
19
19
 
20
20
  class OTTOutput(BaseSolverOutput):
@@ -174,7 +174,10 @@ class OTTOutput(BaseSolverOutput):
174
174
  def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
175
175
  if x.ndim == 1:
176
176
  return self._output.apply(x, axis=1 - forward)
177
- return self._output.apply(x.T, axis=1 - forward).T # convert to batch first
177
+ return self._output.apply(
178
+ x.T,
179
+ axis=1 - forward,
180
+ ).T # convert to batch first
178
181
 
179
182
  @property
180
183
  def shape(self) -> Tuple[int, int]: # noqa: D102
@@ -233,3 +236,63 @@ class OTTOutput(BaseSolverOutput):
233
236
 
234
237
  def _ones(self, n: int) -> ArrayLike: # noqa: D102
235
238
  return jnp.ones((n,))
239
+
240
+
241
+ class GraphOTTOutput(OTTOutput):
242
+ """Output of :term:`OT` problems with a graph geometry in the linear term.
243
+
244
+ Parameters
245
+ ----------
246
+ output
247
+ Output of the :mod:`ott` backend.
248
+ shape
249
+ Shape of the problem.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ output: Union[
255
+ sinkhorn.SinkhornOutput,
256
+ sinkhorn_lr.LRSinkhornOutput,
257
+ gromov_wasserstein.GWOutput,
258
+ gromov_wasserstein_lr.LRGWOutput,
259
+ ],
260
+ shape: Tuple[int, int],
261
+ ):
262
+ super().__init__(output)
263
+ self._shape = shape
264
+
265
+ @property
266
+ def shape(self) -> Tuple[int, int]: # noqa: D102
267
+ return self._shape
268
+
269
+ def _expand_data(self, x: jnp.ndarray, forward: bool) -> jnp.ndarray:
270
+ if forward:
271
+ shape = (self.shape[1],) if x.ndim == 1 else (self.shape[1], x.shape[1])
272
+ return jnp.concatenate((x, jnp.zeros(shape)))
273
+ shape = (self.shape[0],) if x.ndim == 1 else (self.shape[0], x.shape[1])
274
+ return jnp.concatenate((jnp.zeros(shape), x))
275
+
276
+ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
277
+ x_expanded = self._expand_data(x, forward=forward)
278
+ # ott-jax only supports lse_mode=False with graph geometry
279
+ res = self._output.apply(x_expanded.T, axis=1 - forward, lse_mode=False).T
280
+ return res[len(x) :] if forward else res[: -len(x)]
281
+
282
+ def to(self, device: Optional[Device_t] = None) -> "GraphOTTOutput": # noqa: D102
283
+ if device is None:
284
+ return GraphOTTOutput(jax.device_put(self._output, device=device), shape=self.shape)
285
+
286
+ if isinstance(device, str) and ":" in device:
287
+ device, ix = device.split(":")
288
+ idx = int(ix)
289
+ else:
290
+ idx = 0
291
+
292
+ if not isinstance(device, xla_ext.Device):
293
+ try:
294
+ device = jax.devices(device)[idx]
295
+ except IndexError:
296
+ raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None
297
+
298
+ return GraphOTTOutput(jax.device_put(self._output, device), shape=self.shape)
@@ -4,15 +4,22 @@ import types
4
4
  from typing import Any, Literal, Mapping, Optional, Set, Tuple, Union
5
5
 
6
6
  import jax
7
- from ott.geometry import costs, epsilon_scheduler, geometry, pointcloud
7
+ import jax.numpy as jnp
8
+ from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
8
9
  from ott.problems.linear import linear_problem
9
10
  from ott.problems.quadratic import quadratic_problem
10
11
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
11
12
  from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
12
13
 
13
14
  from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
14
- from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
15
- from moscot.backends.ott.output import OTTOutput
15
+ from moscot.backends.ott._utils import (
16
+ _instantiate_geodesic_cost,
17
+ alpha_to_fused_penalty,
18
+ check_shapes,
19
+ ensure_2d,
20
+ )
21
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
22
+ from moscot.base.problems._utils import TimeScalesHeatKernel
16
23
  from moscot.base.solver import OTSolver
17
24
  from moscot.costs import get_cost
18
25
  from moscot.utils.tagged_array import TaggedArray
@@ -43,14 +50,21 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
43
50
  self._solver: Optional[OTTSolver_t] = None
44
51
  self._problem: Optional[OTTProblem_t] = None
45
52
  self._jit = jit
53
+ self._a: Optional[jnp.ndarray] = None
54
+ self._b: Optional[jnp.ndarray] = None
46
55
 
47
56
  def _create_geometry(
48
57
  self,
49
58
  x: TaggedArray,
59
+ *,
60
+ is_linear_term: bool,
50
61
  epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
51
62
  relative_epsilon: Optional[bool] = None,
52
63
  scale_cost: Scale_t = 1.0,
53
64
  batch_size: Optional[int] = None,
65
+ problem_shape: Optional[Tuple[int, int]] = None,
66
+ t: Optional[float] = None,
67
+ directed: bool = True,
54
68
  **kwargs: Any,
55
69
  ) -> geometry.Geometry:
56
70
  if x.is_point_cloud:
@@ -88,17 +102,80 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
88
102
  return geometry.Geometry(
89
103
  kernel_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
90
104
  )
105
+ if x.is_graph: # we currently only support this for the linear term.
106
+ return self._create_graph_geometry(
107
+ is_linear_term=is_linear_term,
108
+ x=x,
109
+ arr=arr,
110
+ problem_shape=problem_shape,
111
+ t=t,
112
+ epsilon=epsilon,
113
+ relative_epsilon=relative_epsilon,
114
+ scale_cost=scale_cost,
115
+ directed=directed,
116
+ **kwargs,
117
+ )
91
118
  raise NotImplementedError(f"Creating geometry from `tag={x.tag!r}` is not yet implemented.")
92
119
 
93
120
  def _solve( # type: ignore[override]
94
121
  self,
95
122
  prob: OTTProblem_t,
96
123
  **kwargs: Any,
97
- ) -> OTTOutput:
124
+ ) -> Union[OTTOutput, GraphOTTOutput]:
98
125
  solver = jax.jit(self.solver) if self._jit else self.solver
99
126
  out = solver(prob, **kwargs)
127
+ if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic):
128
+ return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type]
100
129
  return OTTOutput(out)
101
130
 
131
+ def _create_graph_geometry(
132
+ self,
133
+ is_linear_term: bool,
134
+ x: TaggedArray,
135
+ arr: jax.Array,
136
+ problem_shape: Optional[Tuple[int, int]],
137
+ t: Optional[float],
138
+ epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
139
+ relative_epsilon: Optional[bool] = None,
140
+ scale_cost: Scale_t = 1.0,
141
+ directed: bool = True,
142
+ **kwargs: Any,
143
+ ) -> geometry.Geometry:
144
+ if x.cost == "geodesic":
145
+ if self.problem_kind == "linear":
146
+ if t is None:
147
+ if epsilon is None:
148
+ raise ValueError("`epsilon` cannot be `None`.")
149
+ return geodesic.Geodesic.from_graph(arr, t=epsilon / 4.0, directed=directed, **kwargs)
150
+
151
+ return _instantiate_geodesic_cost(
152
+ arr=arr,
153
+ problem_shape=problem_shape, # type: ignore[arg-type]
154
+ t=t,
155
+ is_linear_term=True,
156
+ epsilon=epsilon,
157
+ relative_epsilon=relative_epsilon,
158
+ scale_cost=scale_cost,
159
+ directed=directed,
160
+ **kwargs,
161
+ )
162
+ if self.problem_kind == "quadratic":
163
+ problem_shape = x.shape if problem_shape is None else problem_shape
164
+ return _instantiate_geodesic_cost(
165
+ arr=arr,
166
+ problem_shape=problem_shape,
167
+ t=t,
168
+ is_linear_term=is_linear_term,
169
+ epsilon=epsilon,
170
+ relative_epsilon=relative_epsilon,
171
+ scale_cost=scale_cost,
172
+ directed=directed,
173
+ **kwargs,
174
+ )
175
+
176
+ raise NotImplementedError(f"Invalid problem kind `{self.problem_kind}`.")
177
+ raise NotImplementedError(f"If the geometry is a graph, `cost` must be `geodesic`, found `{x.cost}`.")
178
+
102
179
  @property
103
180
  def solver(self) -> OTTSolver_t:
104
181
  """:mod:`ott` solver."""
@@ -165,6 +242,8 @@ class SinkhornSolver(OTTJaxSolver):
165
242
 
166
243
  def _prepare(
167
244
  self,
245
+ a: jnp.ndarray,
246
+ b: jnp.ndarray,
168
247
  xy: Optional[TaggedArray] = None,
169
248
  x: Optional[TaggedArray] = None,
170
249
  y: Optional[TaggedArray] = None,
@@ -175,24 +254,35 @@ class SinkhornSolver(OTTJaxSolver):
175
254
  scale_cost: Scale_t = 1.0,
176
255
  cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
177
256
  cost_matrix_rank: Optional[int] = None,
257
+ time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
178
258
  # problem
179
259
  **kwargs: Any,
180
260
  ) -> linear_problem.LinearProblem:
181
261
  del x, y
262
+ time_scales_heat_kernel = (
263
+ TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
264
+ )
182
265
  if xy is None:
183
266
  raise ValueError(f"Unable to create geometry from `xy={xy}`.")
184
-
267
+ self._a = a
268
+ self._b = b
185
269
  geom = self._create_geometry(
186
270
  xy,
271
+ is_linear_term=True,
187
272
  epsilon=epsilon,
188
273
  relative_epsilon=relative_epsilon,
189
274
  batch_size=batch_size,
275
+ problem_shape=(len(self._a), len(self._b)),
190
276
  scale_cost=scale_cost,
277
+ t=time_scales_heat_kernel.xy,
191
278
  **cost_kwargs,
192
279
  )
193
280
  if cost_matrix_rank is not None:
194
281
  geom = geom.to_LRCGeometry(rank=cost_matrix_rank)
195
- self._problem = linear_problem.LinearProblem(geom, **kwargs)
282
+ if isinstance(geom, geodesic.Geodesic):
283
+ a = jnp.concatenate((a, jnp.zeros_like(self._b)), axis=0)
284
+ b = jnp.concatenate((jnp.zeros_like(self._a), b), axis=0)
285
+ self._problem = linear_problem.LinearProblem(geom, a=a, b=b, **kwargs)
196
286
  return self._problem
197
287
 
198
288
  @property
@@ -206,7 +296,15 @@ class SinkhornSolver(OTTJaxSolver):
206
296
 
207
297
  @classmethod
208
298
  def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
209
- geom_kwargs = {"epsilon", "relative_epsilon", "batch_size", "scale_cost", "cost_kwargs", "cost_matrix_rank"}
299
+ geom_kwargs = {
300
+ "epsilon",
301
+ "relative_epsilon",
302
+ "batch_size",
303
+ "scale_cost",
304
+ "cost_kwargs",
305
+ "cost_matrix_rank",
306
+ "t",
307
+ }
210
308
  problem_kwargs = set(inspect.signature(linear_problem.LinearProblem).parameters.keys())
211
309
  problem_kwargs -= {"geom"}
212
310
  return geom_kwargs | problem_kwargs, {"epsilon"}
@@ -270,6 +368,8 @@ class GWSolver(OTTJaxSolver):
270
368
 
271
369
  def _prepare(
272
370
  self,
371
+ a: jnp.ndarray,
372
+ b: jnp.ndarray,
273
373
  xy: Optional[TaggedArray] = None,
274
374
  x: Optional[TaggedArray] = None,
275
375
  y: Optional[TaggedArray] = None,
@@ -280,32 +380,45 @@ class GWSolver(OTTJaxSolver):
280
380
  scale_cost: Scale_t = 1.0,
281
381
  cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
282
382
  cost_matrix_rank: Optional[int] = None,
383
+ time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
283
384
  # problem
284
385
  alpha: float = 0.5,
285
386
  **kwargs: Any,
286
387
  ) -> quadratic_problem.QuadraticProblem:
388
+ self._a = a
389
+ self._b = b
390
+ time_scales_heat_kernel = (
391
+ TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
392
+ )
287
393
  if x is None or y is None:
288
394
  raise ValueError(f"Unable to create geometry from `x={x}`, `y={y}`.")
289
- geom_kwargs: Any = {
395
+ geom_kwargs: dict[str, Any] = {
290
396
  "epsilon": epsilon,
291
397
  "relative_epsilon": relative_epsilon,
292
398
  "batch_size": batch_size,
293
399
  "scale_cost": scale_cost,
294
- "cost_matrix_rank": cost_matrix_rank,
295
400
  **cost_kwargs,
296
401
  }
297
- geom_xx = self._create_geometry(x, **geom_kwargs)
298
- geom_yy = self._create_geometry(y, **geom_kwargs)
402
+ if cost_matrix_rank is not None:
403
+ geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
404
+ geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
405
+ geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
299
406
  if alpha == 1.0 or xy is None: # GW
300
407
  # arbitrary fused penalty; must be positive
301
408
  geom_xy, fused_penalty = None, 1.0
302
409
  else: # FGW
303
410
  fused_penalty = alpha_to_fused_penalty(alpha)
304
- geom_xy = self._create_geometry(xy, **geom_kwargs)
411
+ geom_xy = self._create_geometry(
412
+ xy,
413
+ t=time_scales_heat_kernel.xy,
414
+ problem_shape=(x.shape[0], y.shape[0]),
415
+ is_linear_term=True,
416
+ **geom_kwargs,
417
+ )
305
418
  check_shapes(geom_xx, geom_yy, geom_xy)
306
419
 
307
420
  self._problem = quadratic_problem.QuadraticProblem(
308
- geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, **kwargs
421
+ geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, a=self._a, b=self._b, **kwargs
309
422
  )
310
423
  return self._problem
311
424