moscot 0.3.3__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of moscot might be problematic. Click here for more details.
- {moscot-0.3.3 → moscot-0.3.4}/.pre-commit-config.yaml +8 -8
- {moscot-0.3.3 → moscot-0.3.4}/PKG-INFO +2 -2
- {moscot-0.3.3 → moscot-0.3.4}/pyproject.toml +3 -1
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_types.py +10 -10
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/__init__.py +4 -2
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/_utils.py +25 -2
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/output.py +65 -2
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/ott/solver.py +126 -13
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/_mixins.py +199 -31
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/_utils.py +8 -1
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/birth_death.py +8 -4
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/compound_problem.py +26 -27
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/problem.py +234 -40
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/solver.py +0 -1
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/costs/_costs.py +5 -2
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/costs/_utils.py +15 -4
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/datasets.py +36 -4
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/plotting/_plotting.py +11 -11
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/plotting/_utils.py +9 -6
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/_utils.py +11 -13
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/cross_modality/_mixins.py +62 -5
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/cross_modality/_translation.py +4 -4
- moscot-0.3.4/src/moscot/problems/generic/__init__.py +4 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/generic/_generic.py +245 -32
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/generic/_mixins.py +10 -3
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/_alignment.py +11 -7
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/_mapping.py +8 -9
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/_mixins.py +123 -10
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +6 -1
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/time/_lineage.py +7 -6
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/time/_mixins.py +95 -45
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/mouse_proliferation.txt +0 -1
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/subset_policy.py +7 -12
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/tagged_array.py +27 -1
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/PKG-INFO +2 -2
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/requires.txt +1 -1
- moscot-0.3.3/src/moscot/problems/generic/__init__.py +0 -4
- {moscot-0.3.3 → moscot-0.3.4}/.gitignore +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/.gitmodules +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/.readthedocs.yml +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/LICENSE +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/MANIFEST.in +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/README.rst +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/codecov.yml +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/setup.cfg +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_constants.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_logging.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/_registry.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/backends/utils.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/cost.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/output.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/base/problems/manager.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/py.typed +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot/utils/data.py +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/SOURCES.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.4}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -7,29 +7,29 @@ default_stages:
|
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.8.0
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
13
|
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev:
|
|
16
|
+
rev: 24.2.0
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
20
20
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
|
21
|
-
rev:
|
|
21
|
+
rev: v4.0.0-alpha.8
|
|
22
22
|
hooks:
|
|
23
23
|
- id: prettier
|
|
24
24
|
language_version: system
|
|
25
25
|
- repo: https://github.com/PyCQA/isort
|
|
26
|
-
rev: 5.
|
|
26
|
+
rev: 5.13.2
|
|
27
27
|
hooks:
|
|
28
28
|
- id: isort
|
|
29
29
|
additional_dependencies: [toml]
|
|
30
30
|
args: [--order-by-type]
|
|
31
31
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
32
|
-
rev: v4.
|
|
32
|
+
rev: v4.5.0
|
|
33
33
|
hooks:
|
|
34
34
|
- id: check-merge-conflict
|
|
35
35
|
- id: check-ast
|
|
@@ -42,7 +42,7 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.
|
|
45
|
+
rev: v3.15.1
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
@@ -52,7 +52,7 @@ repos:
|
|
|
52
52
|
- id: blacken-docs
|
|
53
53
|
additional_dependencies: [black==23.1.0]
|
|
54
54
|
- repo: https://github.com/rstcheck/rstcheck
|
|
55
|
-
rev: v6.
|
|
55
|
+
rev: v6.2.0
|
|
56
56
|
hooks:
|
|
57
57
|
- id: rstcheck
|
|
58
58
|
additional_dependencies: [tomli]
|
|
@@ -63,7 +63,7 @@ repos:
|
|
|
63
63
|
- id: doc8
|
|
64
64
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.
|
|
66
|
+
rev: v0.2.2
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.4
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -66,7 +66,7 @@ Requires-Dist: anndata>=0.9.1
|
|
|
66
66
|
Requires-Dist: scanpy>=1.9.3
|
|
67
67
|
Requires-Dist: wrapt>=1.13.2
|
|
68
68
|
Requires-Dist: docrep>=0.3.2
|
|
69
|
-
Requires-Dist: ott-jax>=0.4.
|
|
69
|
+
Requires-Dist: ott-jax>=0.4.5
|
|
70
70
|
Requires-Dist: cloudpickle>=2.2.0
|
|
71
71
|
Requires-Dist: rich>=13.5
|
|
72
72
|
Provides-Extra: spatial
|
|
@@ -55,7 +55,7 @@ dependencies = [
|
|
|
55
55
|
"scanpy>=1.9.3",
|
|
56
56
|
"wrapt>=1.13.2",
|
|
57
57
|
"docrep>=0.3.2",
|
|
58
|
-
"ott-jax>=0.4.
|
|
58
|
+
"ott-jax>=0.4.5",
|
|
59
59
|
"cloudpickle>=2.2.0",
|
|
60
60
|
"rich>=13.5",
|
|
61
61
|
]
|
|
@@ -122,6 +122,8 @@ ignore = [
|
|
|
122
122
|
"D107",
|
|
123
123
|
# Missing docstring in magic method
|
|
124
124
|
"D105",
|
|
125
|
+
# Use `X | Y` for type annotations
|
|
126
|
+
"UP007",
|
|
125
127
|
]
|
|
126
128
|
line-length = 120
|
|
127
129
|
select = [
|
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
|
8
8
|
try:
|
|
9
9
|
from numpy.typing import DTypeLike, NDArray
|
|
10
10
|
|
|
11
|
-
ArrayLike = NDArray[np.
|
|
11
|
+
ArrayLike = NDArray[np.float64]
|
|
12
12
|
except (ImportError, TypeError):
|
|
13
13
|
ArrayLike = np.ndarray # type: ignore[misc]
|
|
14
14
|
DTypeLike = np.dtype # type: ignore[misc]
|
|
@@ -33,20 +33,20 @@ OttCostFn_t = Literal[
|
|
|
33
33
|
"euclidean",
|
|
34
34
|
"sq_euclidean",
|
|
35
35
|
"cosine",
|
|
36
|
-
"
|
|
37
|
-
"
|
|
38
|
-
"
|
|
39
|
-
"
|
|
40
|
-
"
|
|
41
|
-
"
|
|
42
|
-
"
|
|
43
|
-
"
|
|
36
|
+
"pnorm_p",
|
|
37
|
+
"sq_pnorm",
|
|
38
|
+
"cosine",
|
|
39
|
+
"elastic_l1",
|
|
40
|
+
"elastic_l2",
|
|
41
|
+
"elastic_stvs",
|
|
42
|
+
"elastic_sqk_overlap",
|
|
43
|
+
"geodesic",
|
|
44
44
|
]
|
|
45
45
|
OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
|
|
46
46
|
GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
|
|
47
47
|
CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
|
|
48
48
|
CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
|
|
49
|
-
PathLike = Union[os.PathLike, str]
|
|
49
|
+
PathLike = Union[os.PathLike, str] # type: ignore[type-arg]
|
|
50
50
|
Policy_t = Literal[
|
|
51
51
|
"sequential",
|
|
52
52
|
"star",
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from ott.geometry import costs
|
|
2
2
|
|
|
3
3
|
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
4
|
-
from moscot.backends.ott.output import OTTOutput
|
|
4
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
5
5
|
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
|
|
6
6
|
from moscot.costs import register_cost
|
|
7
7
|
|
|
8
|
-
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
8
|
+
__all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
9
9
|
|
|
10
10
|
register_cost("euclidean", backend="ott")(costs.Euclidean)
|
|
11
11
|
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
|
|
@@ -13,4 +13,6 @@ register_cost("cosine", backend="ott")(costs.Cosine)
|
|
|
13
13
|
register_cost("pnorm_p", backend="ott")(costs.PNormP)
|
|
14
14
|
register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
|
|
15
15
|
register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
|
|
16
|
+
register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
|
|
16
17
|
register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
|
|
18
|
+
register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
|
|
@@ -1,14 +1,17 @@
|
|
|
1
|
-
from typing import Any, Optional, Union
|
|
1
|
+
from typing import Any, Literal, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
import jax.numpy as jnp
|
|
5
5
|
import scipy.sparse as sp
|
|
6
|
-
from ott.geometry import epsilon_scheduler, geometry, pointcloud
|
|
6
|
+
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
|
|
7
7
|
from ott.tools import sinkhorn_divergence as sdiv
|
|
8
8
|
|
|
9
9
|
from moscot._logging import logger
|
|
10
10
|
from moscot._types import ArrayLike, ScaleCost_t
|
|
11
11
|
|
|
12
|
+
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
|
|
13
|
+
|
|
14
|
+
|
|
12
15
|
__all__ = ["sinkhorn_divergence"]
|
|
13
16
|
|
|
14
17
|
|
|
@@ -86,3 +89,23 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
|
86
89
|
if arr.ndim != 2:
|
|
87
90
|
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
|
|
88
91
|
return arr
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _instantiate_geodesic_cost(
|
|
95
|
+
arr: jax.Array,
|
|
96
|
+
problem_shape: Tuple[int, int],
|
|
97
|
+
t: Optional[float],
|
|
98
|
+
is_linear_term: bool,
|
|
99
|
+
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
|
|
100
|
+
relative_epsilon: Optional[bool] = None,
|
|
101
|
+
scale_cost: Scale_t = 1.0,
|
|
102
|
+
directed: bool = True,
|
|
103
|
+
**kwargs: Any,
|
|
104
|
+
) -> geometry.Geometry:
|
|
105
|
+
n_src, n_tgt = problem_shape
|
|
106
|
+
if is_linear_term and n_src + n_tgt != arr.shape[0]:
|
|
107
|
+
raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
|
|
108
|
+
t = epsilon / 4.0 if t is None else t
|
|
109
|
+
cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
|
|
110
|
+
cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
|
|
111
|
+
return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
|
|
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
|
|
|
14
14
|
from moscot._types import ArrayLike, Device_t
|
|
15
15
|
from moscot.base.output import BaseSolverOutput
|
|
16
16
|
|
|
17
|
-
__all__ = ["OTTOutput"]
|
|
17
|
+
__all__ = ["OTTOutput", "GraphOTTOutput"]
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class OTTOutput(BaseSolverOutput):
|
|
@@ -174,7 +174,10 @@ class OTTOutput(BaseSolverOutput):
|
|
|
174
174
|
def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
|
|
175
175
|
if x.ndim == 1:
|
|
176
176
|
return self._output.apply(x, axis=1 - forward)
|
|
177
|
-
return self._output.apply(
|
|
177
|
+
return self._output.apply(
|
|
178
|
+
x.T,
|
|
179
|
+
axis=1 - forward,
|
|
180
|
+
).T # convert to batch first
|
|
178
181
|
|
|
179
182
|
@property
|
|
180
183
|
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
@@ -233,3 +236,63 @@ class OTTOutput(BaseSolverOutput):
|
|
|
233
236
|
|
|
234
237
|
def _ones(self, n: int) -> ArrayLike: # noqa: D102
|
|
235
238
|
return jnp.ones((n,))
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class GraphOTTOutput(OTTOutput):
|
|
242
|
+
"""Output of :term:`OT` problems with a graph geometry in the linear term.
|
|
243
|
+
|
|
244
|
+
Parameters
|
|
245
|
+
----------
|
|
246
|
+
output
|
|
247
|
+
Output of the :mod:`ott` backend.
|
|
248
|
+
shape
|
|
249
|
+
Shape of the problem.
|
|
250
|
+
"""
|
|
251
|
+
|
|
252
|
+
def __init__(
|
|
253
|
+
self,
|
|
254
|
+
output: Union[
|
|
255
|
+
sinkhorn.SinkhornOutput,
|
|
256
|
+
sinkhorn_lr.LRSinkhornOutput,
|
|
257
|
+
gromov_wasserstein.GWOutput,
|
|
258
|
+
gromov_wasserstein_lr.LRGWOutput,
|
|
259
|
+
],
|
|
260
|
+
shape: Tuple[int, int],
|
|
261
|
+
):
|
|
262
|
+
super().__init__(output)
|
|
263
|
+
self._shape = shape
|
|
264
|
+
|
|
265
|
+
@property
|
|
266
|
+
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
267
|
+
return self._shape
|
|
268
|
+
|
|
269
|
+
def _expand_data(self, x: jnp.ndarray, forward: bool) -> jnp.ndarray:
|
|
270
|
+
if forward:
|
|
271
|
+
shape = (self.shape[1],) if x.ndim == 1 else (self.shape[1], x.shape[1])
|
|
272
|
+
return jnp.concatenate((x, jnp.zeros(shape)))
|
|
273
|
+
shape = (self.shape[0],) if x.ndim == 1 else (self.shape[0], x.shape[1])
|
|
274
|
+
return jnp.concatenate((jnp.zeros(shape), x))
|
|
275
|
+
|
|
276
|
+
def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
|
|
277
|
+
x_expanded = self._expand_data(x, forward=forward)
|
|
278
|
+
# ott-jax only supports lse_mode=False with graph geometry
|
|
279
|
+
res = self._output.apply(x_expanded.T, axis=1 - forward, lse_mode=False).T
|
|
280
|
+
return res[len(x) :] if forward else res[: -len(x)]
|
|
281
|
+
|
|
282
|
+
def to(self, device: Optional[Device_t] = None) -> "GraphOTTOutput": # noqa: D102
|
|
283
|
+
if device is None:
|
|
284
|
+
return GraphOTTOutput(jax.device_put(self._output, device=device), shape=self.shape)
|
|
285
|
+
|
|
286
|
+
if isinstance(device, str) and ":" in device:
|
|
287
|
+
device, ix = device.split(":")
|
|
288
|
+
idx = int(ix)
|
|
289
|
+
else:
|
|
290
|
+
idx = 0
|
|
291
|
+
|
|
292
|
+
if not isinstance(device, xla_ext.Device):
|
|
293
|
+
try:
|
|
294
|
+
device = jax.devices(device)[idx]
|
|
295
|
+
except IndexError:
|
|
296
|
+
raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None
|
|
297
|
+
|
|
298
|
+
return GraphOTTOutput(jax.device_put(self._output, device), shape=self.shape)
|
|
@@ -4,15 +4,22 @@ import types
|
|
|
4
4
|
from typing import Any, Literal, Mapping, Optional, Set, Tuple, Union
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
|
-
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
|
|
8
9
|
from ott.problems.linear import linear_problem
|
|
9
10
|
from ott.problems.quadratic import quadratic_problem
|
|
10
11
|
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
|
11
12
|
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
|
|
12
13
|
|
|
13
14
|
from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
|
|
14
|
-
from moscot.backends.ott._utils import
|
|
15
|
-
|
|
15
|
+
from moscot.backends.ott._utils import (
|
|
16
|
+
_instantiate_geodesic_cost,
|
|
17
|
+
alpha_to_fused_penalty,
|
|
18
|
+
check_shapes,
|
|
19
|
+
ensure_2d,
|
|
20
|
+
)
|
|
21
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
22
|
+
from moscot.base.problems._utils import TimeScalesHeatKernel
|
|
16
23
|
from moscot.base.solver import OTSolver
|
|
17
24
|
from moscot.costs import get_cost
|
|
18
25
|
from moscot.utils.tagged_array import TaggedArray
|
|
@@ -43,14 +50,21 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
43
50
|
self._solver: Optional[OTTSolver_t] = None
|
|
44
51
|
self._problem: Optional[OTTProblem_t] = None
|
|
45
52
|
self._jit = jit
|
|
53
|
+
self._a: Optional[jnp.ndarray] = None
|
|
54
|
+
self._b: Optional[jnp.ndarray] = None
|
|
46
55
|
|
|
47
56
|
def _create_geometry(
|
|
48
57
|
self,
|
|
49
58
|
x: TaggedArray,
|
|
59
|
+
*,
|
|
60
|
+
is_linear_term: bool,
|
|
50
61
|
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
|
|
51
62
|
relative_epsilon: Optional[bool] = None,
|
|
52
63
|
scale_cost: Scale_t = 1.0,
|
|
53
64
|
batch_size: Optional[int] = None,
|
|
65
|
+
problem_shape: Optional[Tuple[int, int]] = None,
|
|
66
|
+
t: Optional[float] = None,
|
|
67
|
+
directed: bool = True,
|
|
54
68
|
**kwargs: Any,
|
|
55
69
|
) -> geometry.Geometry:
|
|
56
70
|
if x.is_point_cloud:
|
|
@@ -88,17 +102,80 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
88
102
|
return geometry.Geometry(
|
|
89
103
|
kernel_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
|
|
90
104
|
)
|
|
105
|
+
if x.is_graph: # we currently only support this for the linear term.
|
|
106
|
+
return self._create_graph_geometry(
|
|
107
|
+
is_linear_term=is_linear_term,
|
|
108
|
+
x=x,
|
|
109
|
+
arr=arr,
|
|
110
|
+
problem_shape=problem_shape,
|
|
111
|
+
t=t,
|
|
112
|
+
epsilon=epsilon,
|
|
113
|
+
relative_epsilon=relative_epsilon,
|
|
114
|
+
scale_cost=scale_cost,
|
|
115
|
+
directed=directed,
|
|
116
|
+
**kwargs,
|
|
117
|
+
)
|
|
91
118
|
raise NotImplementedError(f"Creating geometry from `tag={x.tag!r}` is not yet implemented.")
|
|
92
119
|
|
|
93
120
|
def _solve( # type: ignore[override]
|
|
94
121
|
self,
|
|
95
122
|
prob: OTTProblem_t,
|
|
96
123
|
**kwargs: Any,
|
|
97
|
-
) -> OTTOutput:
|
|
124
|
+
) -> Union[OTTOutput, GraphOTTOutput]:
|
|
98
125
|
solver = jax.jit(self.solver) if self._jit else self.solver
|
|
99
126
|
out = solver(prob, **kwargs)
|
|
127
|
+
if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic):
|
|
128
|
+
return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type]
|
|
100
129
|
return OTTOutput(out)
|
|
101
130
|
|
|
131
|
+
def _create_graph_geometry(
|
|
132
|
+
self,
|
|
133
|
+
is_linear_term: bool,
|
|
134
|
+
x: TaggedArray,
|
|
135
|
+
arr: jax.Array,
|
|
136
|
+
problem_shape: Optional[Tuple[int, int]],
|
|
137
|
+
t: Optional[float],
|
|
138
|
+
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
|
|
139
|
+
relative_epsilon: Optional[bool] = None,
|
|
140
|
+
scale_cost: Scale_t = 1.0,
|
|
141
|
+
directed: bool = True,
|
|
142
|
+
**kwargs: Any,
|
|
143
|
+
) -> geometry.Geometry:
|
|
144
|
+
if x.cost == "geodesic":
|
|
145
|
+
if self.problem_kind == "linear":
|
|
146
|
+
if t is None:
|
|
147
|
+
if epsilon is None:
|
|
148
|
+
raise ValueError("`epsilon` cannot be `None`.")
|
|
149
|
+
return geodesic.Geodesic.from_graph(arr, t=epsilon / 4.0, directed=directed, **kwargs)
|
|
150
|
+
|
|
151
|
+
return _instantiate_geodesic_cost(
|
|
152
|
+
arr=arr,
|
|
153
|
+
problem_shape=problem_shape, # type: ignore[arg-type]
|
|
154
|
+
t=t,
|
|
155
|
+
is_linear_term=True,
|
|
156
|
+
epsilon=epsilon,
|
|
157
|
+
relative_epsilon=relative_epsilon,
|
|
158
|
+
scale_cost=scale_cost,
|
|
159
|
+
directed=directed,
|
|
160
|
+
**kwargs,
|
|
161
|
+
)
|
|
162
|
+
if self.problem_kind == "quadratic":
|
|
163
|
+
problem_shape = x.shape if problem_shape is None else problem_shape
|
|
164
|
+
return _instantiate_geodesic_cost(
|
|
165
|
+
arr=arr,
|
|
166
|
+
problem_shape=problem_shape,
|
|
167
|
+
t=t,
|
|
168
|
+
is_linear_term=is_linear_term,
|
|
169
|
+
epsilon=epsilon,
|
|
170
|
+
relative_epsilon=relative_epsilon,
|
|
171
|
+
scale_cost=scale_cost,
|
|
172
|
+
directed=directed,
|
|
173
|
+
**kwargs,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
raise NotImplementedError(f"Invalid problem kind `{self.problem_kind}`.")
|
|
177
|
+
raise NotImplementedError(f"If the geometry is a graph, `cost` must be `geodesic`, found `{x.cost}`.")
|
|
178
|
+
|
|
102
179
|
@property
|
|
103
180
|
def solver(self) -> OTTSolver_t:
|
|
104
181
|
""":mod:`ott` solver."""
|
|
@@ -165,6 +242,8 @@ class SinkhornSolver(OTTJaxSolver):
|
|
|
165
242
|
|
|
166
243
|
def _prepare(
|
|
167
244
|
self,
|
|
245
|
+
a: jnp.ndarray,
|
|
246
|
+
b: jnp.ndarray,
|
|
168
247
|
xy: Optional[TaggedArray] = None,
|
|
169
248
|
x: Optional[TaggedArray] = None,
|
|
170
249
|
y: Optional[TaggedArray] = None,
|
|
@@ -175,24 +254,35 @@ class SinkhornSolver(OTTJaxSolver):
|
|
|
175
254
|
scale_cost: Scale_t = 1.0,
|
|
176
255
|
cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
177
256
|
cost_matrix_rank: Optional[int] = None,
|
|
257
|
+
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
|
|
178
258
|
# problem
|
|
179
259
|
**kwargs: Any,
|
|
180
260
|
) -> linear_problem.LinearProblem:
|
|
181
261
|
del x, y
|
|
262
|
+
time_scales_heat_kernel = (
|
|
263
|
+
TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
|
|
264
|
+
)
|
|
182
265
|
if xy is None:
|
|
183
266
|
raise ValueError(f"Unable to create geometry from `xy={xy}`.")
|
|
184
|
-
|
|
267
|
+
self._a = a
|
|
268
|
+
self._b = b
|
|
185
269
|
geom = self._create_geometry(
|
|
186
270
|
xy,
|
|
271
|
+
is_linear_term=True,
|
|
187
272
|
epsilon=epsilon,
|
|
188
273
|
relative_epsilon=relative_epsilon,
|
|
189
274
|
batch_size=batch_size,
|
|
275
|
+
problem_shape=(len(self._a), len(self._b)),
|
|
190
276
|
scale_cost=scale_cost,
|
|
277
|
+
t=time_scales_heat_kernel.xy,
|
|
191
278
|
**cost_kwargs,
|
|
192
279
|
)
|
|
193
280
|
if cost_matrix_rank is not None:
|
|
194
281
|
geom = geom.to_LRCGeometry(rank=cost_matrix_rank)
|
|
195
|
-
|
|
282
|
+
if isinstance(geom, geodesic.Geodesic):
|
|
283
|
+
a = jnp.concatenate((a, jnp.zeros_like(self._b)), axis=0)
|
|
284
|
+
b = jnp.concatenate((jnp.zeros_like(self._a), b), axis=0)
|
|
285
|
+
self._problem = linear_problem.LinearProblem(geom, a=a, b=b, **kwargs)
|
|
196
286
|
return self._problem
|
|
197
287
|
|
|
198
288
|
@property
|
|
@@ -206,7 +296,15 @@ class SinkhornSolver(OTTJaxSolver):
|
|
|
206
296
|
|
|
207
297
|
@classmethod
|
|
208
298
|
def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
|
|
209
|
-
geom_kwargs = {
|
|
299
|
+
geom_kwargs = {
|
|
300
|
+
"epsilon",
|
|
301
|
+
"relative_epsilon",
|
|
302
|
+
"batch_size",
|
|
303
|
+
"scale_cost",
|
|
304
|
+
"cost_kwargs",
|
|
305
|
+
"cost_matrix_rank",
|
|
306
|
+
"t",
|
|
307
|
+
}
|
|
210
308
|
problem_kwargs = set(inspect.signature(linear_problem.LinearProblem).parameters.keys())
|
|
211
309
|
problem_kwargs -= {"geom"}
|
|
212
310
|
return geom_kwargs | problem_kwargs, {"epsilon"}
|
|
@@ -270,6 +368,8 @@ class GWSolver(OTTJaxSolver):
|
|
|
270
368
|
|
|
271
369
|
def _prepare(
|
|
272
370
|
self,
|
|
371
|
+
a: jnp.ndarray,
|
|
372
|
+
b: jnp.ndarray,
|
|
273
373
|
xy: Optional[TaggedArray] = None,
|
|
274
374
|
x: Optional[TaggedArray] = None,
|
|
275
375
|
y: Optional[TaggedArray] = None,
|
|
@@ -280,32 +380,45 @@ class GWSolver(OTTJaxSolver):
|
|
|
280
380
|
scale_cost: Scale_t = 1.0,
|
|
281
381
|
cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
282
382
|
cost_matrix_rank: Optional[int] = None,
|
|
383
|
+
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
|
|
283
384
|
# problem
|
|
284
385
|
alpha: float = 0.5,
|
|
285
386
|
**kwargs: Any,
|
|
286
387
|
) -> quadratic_problem.QuadraticProblem:
|
|
388
|
+
self._a = a
|
|
389
|
+
self._b = b
|
|
390
|
+
time_scales_heat_kernel = (
|
|
391
|
+
TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
|
|
392
|
+
)
|
|
287
393
|
if x is None or y is None:
|
|
288
394
|
raise ValueError(f"Unable to create geometry from `x={x}`, `y={y}`.")
|
|
289
|
-
geom_kwargs: Any = {
|
|
395
|
+
geom_kwargs: dict[str, Any] = {
|
|
290
396
|
"epsilon": epsilon,
|
|
291
397
|
"relative_epsilon": relative_epsilon,
|
|
292
398
|
"batch_size": batch_size,
|
|
293
399
|
"scale_cost": scale_cost,
|
|
294
|
-
"cost_matrix_rank": cost_matrix_rank,
|
|
295
400
|
**cost_kwargs,
|
|
296
401
|
}
|
|
297
|
-
|
|
298
|
-
|
|
402
|
+
if cost_matrix_rank is not None:
|
|
403
|
+
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
|
|
404
|
+
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
|
|
405
|
+
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
|
|
299
406
|
if alpha == 1.0 or xy is None: # GW
|
|
300
407
|
# arbitrary fused penalty; must be positive
|
|
301
408
|
geom_xy, fused_penalty = None, 1.0
|
|
302
409
|
else: # FGW
|
|
303
410
|
fused_penalty = alpha_to_fused_penalty(alpha)
|
|
304
|
-
geom_xy = self._create_geometry(
|
|
411
|
+
geom_xy = self._create_geometry(
|
|
412
|
+
xy,
|
|
413
|
+
t=time_scales_heat_kernel.xy,
|
|
414
|
+
problem_shape=(x.shape[0], y.shape[0]),
|
|
415
|
+
is_linear_term=True,
|
|
416
|
+
**geom_kwargs,
|
|
417
|
+
)
|
|
305
418
|
check_shapes(geom_xx, geom_yy, geom_xy)
|
|
306
419
|
|
|
307
420
|
self._problem = quadratic_problem.QuadraticProblem(
|
|
308
|
-
geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, **kwargs
|
|
421
|
+
geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, a=self._a, b=self._b, **kwargs
|
|
309
422
|
)
|
|
310
423
|
return self._problem
|
|
311
424
|
|