moscot 0.4.2__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {moscot-0.4.2 → moscot-0.5.0}/.pre-commit-config.yaml +1 -1
- {moscot-0.4.2 → moscot-0.5.0}/PKG-INFO +7 -7
- {moscot-0.4.2 → moscot-0.5.0}/pyproject.toml +7 -9
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/__init__.py +2 -3
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/_utils.py +2 -2
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/output.py +5 -224
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/solver.py +3 -207
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/utils.py +8 -12
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/output.py +1 -20
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/_mixins.py +1 -1
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/birth_death.py +1 -1
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/costs/_costs.py +3 -3
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/datasets.py +1 -1
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +1 -1
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/subset_policy.py +7 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/PKG-INFO +7 -7
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/SOURCES.txt +0 -6
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/requires.txt +6 -7
- moscot-0.4.2/src/moscot/neural/base/__init__.py +0 -0
- moscot-0.4.2/src/moscot/neural/base/problems/__init__.py +0 -3
- moscot-0.4.2/src/moscot/neural/base/problems/problem.py +0 -243
- moscot-0.4.2/src/moscot/neural/problems/__init__.py +0 -3
- moscot-0.4.2/src/moscot/neural/problems/generic/__init__.py +0 -3
- moscot-0.4.2/src/moscot/neural/problems/generic/_generic.py +0 -78
- {moscot-0.4.2 → moscot-0.5.0}/.gitignore +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/.gitmodules +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/.readthedocs.yml +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/.run_notebooks.sh +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/LICENSE +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/MANIFEST.in +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/README.rst +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/codecov.yml +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/setup.cfg +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_constants.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_logging.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_registry.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_types.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/cost.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/_utils.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/compound_problem.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/manager.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/problem.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/solver.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/costs/_utils.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/plotting/_plotting.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/plotting/_utils.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/_utils.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/cross_modality/_mixins.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/cross_modality/_translation.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/generic/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/generic/_generic.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/generic/_mixins.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/_alignment.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/_mapping.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/_mixins.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/time/_lineage.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/time/_mixins.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/py.typed +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/data.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/tagged_array.py +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.4.2
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -65,18 +65,18 @@ Requires-Dist: anndata>=0.9.1
|
|
|
65
65
|
Requires-Dist: scanpy>=1.9.3
|
|
66
66
|
Requires-Dist: wrapt>=1.13.2
|
|
67
67
|
Requires-Dist: docrep>=0.3.2
|
|
68
|
-
Requires-Dist:
|
|
68
|
+
Requires-Dist: jax>=0.6.1
|
|
69
|
+
Requires-Dist: ott-jax>=0.6.0
|
|
69
70
|
Requires-Dist: cloudpickle>=2.2.0
|
|
70
71
|
Requires-Dist: rich>=13.5
|
|
71
72
|
Requires-Dist: docstring_inheritance>=2.0.0
|
|
72
73
|
Requires-Dist: mudata>=0.2.2
|
|
74
|
+
Requires-Dist: optax
|
|
75
|
+
Requires-Dist: flax
|
|
76
|
+
Requires-Dist: diffrax
|
|
77
|
+
Requires-Dist: ott-jax[neural]>=0.5.0
|
|
73
78
|
Provides-Extra: spatial
|
|
74
79
|
Requires-Dist: squidpy>=1.2.3; extra == "spatial"
|
|
75
|
-
Provides-Extra: neural
|
|
76
|
-
Requires-Dist: optax; extra == "neural"
|
|
77
|
-
Requires-Dist: flax; extra == "neural"
|
|
78
|
-
Requires-Dist: diffrax; extra == "neural"
|
|
79
|
-
Requires-Dist: ott-jax[neural]>=0.5.0; extra == "neural"
|
|
80
80
|
Provides-Extra: dev
|
|
81
81
|
Requires-Dist: pre-commit>=3.0.0; extra == "dev"
|
|
82
82
|
Requires-Dist: tox>=4; extra == "dev"
|
|
@@ -54,11 +54,16 @@ dependencies = [
|
|
|
54
54
|
"scanpy>=1.9.3",
|
|
55
55
|
"wrapt>=1.13.2",
|
|
56
56
|
"docrep>=0.3.2",
|
|
57
|
-
"
|
|
57
|
+
"jax>=0.6.1",
|
|
58
|
+
"ott-jax>=0.6.0",
|
|
58
59
|
"cloudpickle>=2.2.0",
|
|
59
60
|
"rich>=13.5",
|
|
60
61
|
"docstring_inheritance>=2.0.0",
|
|
61
|
-
"mudata>=0.2.2"
|
|
62
|
+
"mudata>=0.2.2",
|
|
63
|
+
"optax",
|
|
64
|
+
"flax",
|
|
65
|
+
"diffrax",
|
|
66
|
+
"ott-jax[neural]>=0.5.0"
|
|
62
67
|
]
|
|
63
68
|
|
|
64
69
|
[project.optional-dependencies]
|
|
@@ -66,13 +71,6 @@ spatial = [
|
|
|
66
71
|
"squidpy>=1.2.3"
|
|
67
72
|
]
|
|
68
73
|
|
|
69
|
-
neural = [
|
|
70
|
-
"optax",
|
|
71
|
-
"flax",
|
|
72
|
-
"diffrax",
|
|
73
|
-
"ott-jax[neural]>=0.5.0",
|
|
74
|
-
|
|
75
|
-
]
|
|
76
74
|
dev = [
|
|
77
75
|
"pre-commit>=3.0.0",
|
|
78
76
|
"tox>=4",
|
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
from ott.geometry import costs
|
|
2
2
|
|
|
3
3
|
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
4
|
-
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
|
|
5
|
-
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
|
|
4
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
5
|
+
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
|
|
6
6
|
from moscot.costs import register_cost
|
|
7
7
|
|
|
8
8
|
__all__ = [
|
|
9
9
|
"OTTOutput",
|
|
10
10
|
"GWSolver",
|
|
11
11
|
"SinkhornSolver",
|
|
12
|
-
"NeuralOutput",
|
|
13
12
|
"sinkhorn_divergence",
|
|
14
13
|
"GENOTLinSolver",
|
|
15
14
|
"GraphOTTOutput",
|
|
@@ -10,7 +10,6 @@ import scipy.sparse as sp
|
|
|
10
10
|
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
|
|
11
11
|
from ott.initializers.linear import initializers as init_lib
|
|
12
12
|
from ott.initializers.linear import initializers_lr as lr_init_lib
|
|
13
|
-
from ott.neural import datasets
|
|
14
13
|
from ott.solvers import utils as solver_utils
|
|
15
14
|
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
|
|
16
15
|
|
|
@@ -18,6 +17,7 @@ from moscot._logging import logger
|
|
|
18
17
|
from moscot._types import ArrayLike, ScaleCost_t
|
|
19
18
|
|
|
20
19
|
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
|
|
20
|
+
OTDataset = Any # to be removed when neural part is being removed from moscot
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
__all__ = ["sinkhorn_divergence"]
|
|
@@ -272,7 +272,7 @@ def data_match_fn(
|
|
|
272
272
|
|
|
273
273
|
class Loader:
|
|
274
274
|
|
|
275
|
-
def __init__(self, dataset:
|
|
275
|
+
def __init__(self, dataset: OTDataset, batch_size: int, seed: Optional[int] = None):
|
|
276
276
|
self.dataset = dataset
|
|
277
277
|
self.batch_size = batch_size
|
|
278
278
|
self._rng = np.random.default_rng(seed)
|
|
@@ -1,12 +1,8 @@
|
|
|
1
|
-
from typing import Any,
|
|
2
|
-
|
|
3
|
-
import jaxlib.xla_extension as xla_ext
|
|
1
|
+
from typing import Any, Optional, Tuple, Union
|
|
4
2
|
|
|
5
3
|
import jax
|
|
6
4
|
import jax.numpy as jnp
|
|
7
5
|
import numpy as np
|
|
8
|
-
import scipy.sparse as sp
|
|
9
|
-
from ott.neural.methods.flows.genot import GENOT
|
|
10
6
|
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
|
11
7
|
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
|
|
12
8
|
|
|
@@ -14,10 +10,9 @@ import matplotlib as mpl
|
|
|
14
10
|
import matplotlib.pyplot as plt
|
|
15
11
|
|
|
16
12
|
from moscot._types import ArrayLike, Device_t
|
|
17
|
-
from moscot.
|
|
18
|
-
from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput
|
|
13
|
+
from moscot.base.output import BaseDiscreteSolverOutput
|
|
19
14
|
|
|
20
|
-
__all__ = ["OTTOutput", "GraphOTTOutput"
|
|
15
|
+
__all__ = ["OTTOutput", "GraphOTTOutput"]
|
|
21
16
|
|
|
22
17
|
|
|
23
18
|
class OTTOutput(BaseDiscreteSolverOutput):
|
|
@@ -209,7 +204,7 @@ class OTTOutput(BaseDiscreteSolverOutput):
|
|
|
209
204
|
else:
|
|
210
205
|
idx = 0
|
|
211
206
|
|
|
212
|
-
if not isinstance(device, xla_ext.Device):
|
|
207
|
+
if not isinstance(device, jax.Device):
|
|
213
208
|
try:
|
|
214
209
|
device = jax.devices(device)[idx]
|
|
215
210
|
except IndexError:
|
|
@@ -244,220 +239,6 @@ class OTTOutput(BaseDiscreteSolverOutput):
|
|
|
244
239
|
return jnp.ones((n,))
|
|
245
240
|
|
|
246
241
|
|
|
247
|
-
class NeuralOutput(BaseNeuralOutput):
|
|
248
|
-
"""Output wrapper for GENOT."""
|
|
249
|
-
|
|
250
|
-
def __init__(self, model: GENOT, logs: dict[str, list[float]]):
|
|
251
|
-
"""Initialize `NeuralOutput`.
|
|
252
|
-
|
|
253
|
-
Parameters
|
|
254
|
-
----------
|
|
255
|
-
model : GENOT
|
|
256
|
-
The OTT-Jax GENOT model
|
|
257
|
-
"""
|
|
258
|
-
self._logs = logs
|
|
259
|
-
self._model = model
|
|
260
|
-
|
|
261
|
-
@property
|
|
262
|
-
def logs(self):
|
|
263
|
-
"""Logs of the training. A dictionary containing what the numeric values are i.e., loss.
|
|
264
|
-
|
|
265
|
-
Returns
|
|
266
|
-
-------
|
|
267
|
-
dict[str, list[float]]
|
|
268
|
-
"""
|
|
269
|
-
return self._logs
|
|
270
|
-
|
|
271
|
-
def _project_transport_matrix(
|
|
272
|
-
self,
|
|
273
|
-
src_dist: ArrayLike,
|
|
274
|
-
tgt_dist: ArrayLike,
|
|
275
|
-
func: Callable[[ArrayLike], ArrayLike],
|
|
276
|
-
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
|
|
277
|
-
batch_size: int = 1024,
|
|
278
|
-
k: int = 30,
|
|
279
|
-
length_scale: Optional[float] = None,
|
|
280
|
-
seed: int = 42,
|
|
281
|
-
recall_target: float = 0.95,
|
|
282
|
-
aggregate_to_topk: bool = True,
|
|
283
|
-
) -> sp.csr_matrix:
|
|
284
|
-
row_indices: List[ArrayLike] = []
|
|
285
|
-
column_indices: List[ArrayLike] = []
|
|
286
|
-
distances_list: List[ArrayLike] = []
|
|
287
|
-
if length_scale is None:
|
|
288
|
-
key = jax.random.PRNGKey(seed)
|
|
289
|
-
src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
|
|
290
|
-
tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))]
|
|
291
|
-
length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch)))
|
|
292
|
-
for index in range(0, len(src_dist), batch_size):
|
|
293
|
-
distances, indices = get_nearest_neighbors(
|
|
294
|
-
func(src_dist[index : index + batch_size, :]),
|
|
295
|
-
tgt_dist,
|
|
296
|
-
k,
|
|
297
|
-
recall_target=recall_target,
|
|
298
|
-
aggregate_to_topk=aggregate_to_topk,
|
|
299
|
-
)
|
|
300
|
-
distances = jnp.exp(-((distances / length_scale) ** 2))
|
|
301
|
-
distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1)
|
|
302
|
-
distances_list.append(distances.flatten())
|
|
303
|
-
column_indices.append(indices.flatten())
|
|
304
|
-
row_indices.append(
|
|
305
|
-
jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist)))
|
|
306
|
-
)
|
|
307
|
-
distances = jnp.concatenate(distances_list)
|
|
308
|
-
row_indices = jnp.concatenate(row_indices)
|
|
309
|
-
column_indices = jnp.concatenate(column_indices)
|
|
310
|
-
tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
|
|
311
|
-
if save_transport_matrix:
|
|
312
|
-
self._transport_matrix = tm
|
|
313
|
-
return tm
|
|
314
|
-
|
|
315
|
-
def project_to_transport_matrix( # type:ignore[override]
|
|
316
|
-
self,
|
|
317
|
-
src_cells: ArrayLike,
|
|
318
|
-
tgt_cells: ArrayLike,
|
|
319
|
-
condition: ArrayLike = None,
|
|
320
|
-
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
|
|
321
|
-
batch_size: int = 1024,
|
|
322
|
-
k: int = 30,
|
|
323
|
-
length_scale: Optional[float] = None,
|
|
324
|
-
seed: int = 42,
|
|
325
|
-
recall_target: float = 0.95,
|
|
326
|
-
aggregate_to_topk: bool = True,
|
|
327
|
-
) -> sp.csr_matrix:
|
|
328
|
-
"""Project conditional neural OT map onto cells.
|
|
329
|
-
|
|
330
|
-
In constrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells,
|
|
331
|
-
but a cell can also be mapped to a location between two cells. This function computes
|
|
332
|
-
a pseudo-transport matrix considering the neighborhood of where a cell is mapped to.
|
|
333
|
-
Therefore, a neighborhood graph of `k` target cells is computed around each transported cell
|
|
334
|
-
of the source distribution. The assignment likelihood of each mapped cell to the target cells is then
|
|
335
|
-
computed with a Gaussian kernel with parameter `length_scale`.
|
|
336
|
-
|
|
337
|
-
Parameters
|
|
338
|
-
----------
|
|
339
|
-
condition
|
|
340
|
-
Condition `src_cells` correspond to.
|
|
341
|
-
src_cells
|
|
342
|
-
Cells which are to be mapped.
|
|
343
|
-
tgt_cells
|
|
344
|
-
Cells from which the neighborhood graph around the mapped `src_cells` are computed.
|
|
345
|
-
forward
|
|
346
|
-
Whether to map cells based on the forward transport map or backward transport map.
|
|
347
|
-
save_transport_matrix
|
|
348
|
-
Whether to save the transport matrix.
|
|
349
|
-
batch_size
|
|
350
|
-
Number of data points in the source distribution the neighborhood graph is computed
|
|
351
|
-
for in parallel.
|
|
352
|
-
k
|
|
353
|
-
Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
|
|
354
|
-
length_scale
|
|
355
|
-
Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`,
|
|
356
|
-
`length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the
|
|
357
|
-
mapped source and target distribution.
|
|
358
|
-
seed
|
|
359
|
-
Random seed for sampling the pairs of distributions for computing the variance in case `length_scale`
|
|
360
|
-
is `None`.
|
|
361
|
-
recall_target
|
|
362
|
-
Recall target for the approximation.
|
|
363
|
-
aggregate_to_topk
|
|
364
|
-
When true, the nearest neighbor aggregates approximate results to the top-k in sorted order.
|
|
365
|
-
When false, returns the approximate results unsorted.
|
|
366
|
-
In this case, the number of the approximate results is implementation defined and is greater or
|
|
367
|
-
equal to the specified k.
|
|
368
|
-
|
|
369
|
-
Returns
|
|
370
|
-
-------
|
|
371
|
-
The projected transport matrix.
|
|
372
|
-
"""
|
|
373
|
-
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
|
|
374
|
-
conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
|
|
375
|
-
push = self.push if condition is None else conditioned_fn
|
|
376
|
-
func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
|
|
377
|
-
return self._project_transport_matrix(
|
|
378
|
-
src_dist=src_dist,
|
|
379
|
-
tgt_dist=tgt_dist,
|
|
380
|
-
func=func,
|
|
381
|
-
save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
|
|
382
|
-
batch_size=batch_size,
|
|
383
|
-
k=k,
|
|
384
|
-
length_scale=length_scale,
|
|
385
|
-
seed=seed,
|
|
386
|
-
recall_target=recall_target,
|
|
387
|
-
aggregate_to_topk=aggregate_to_topk,
|
|
388
|
-
)
|
|
389
|
-
|
|
390
|
-
def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
|
|
391
|
-
"""Push distribution `x` conditioned on condition `cond`.
|
|
392
|
-
|
|
393
|
-
Parameters
|
|
394
|
-
----------
|
|
395
|
-
x
|
|
396
|
-
Distribution to push.
|
|
397
|
-
cond
|
|
398
|
-
Condition of conditional neural OT.
|
|
399
|
-
|
|
400
|
-
Returns
|
|
401
|
-
-------
|
|
402
|
-
Pushed distribution.
|
|
403
|
-
"""
|
|
404
|
-
if isinstance(x, (bool, int, float, complex)):
|
|
405
|
-
raise ValueError("Expected array, found scalar value.")
|
|
406
|
-
if x.ndim not in (1, 2):
|
|
407
|
-
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
|
|
408
|
-
return self._apply_forward(x, cond=cond)
|
|
409
|
-
|
|
410
|
-
def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
|
|
411
|
-
return self._model.transport(x, condition=cond)
|
|
412
|
-
|
|
413
|
-
@property
|
|
414
|
-
def is_linear(self) -> bool: # noqa: D102
|
|
415
|
-
return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT
|
|
416
|
-
|
|
417
|
-
@property
|
|
418
|
-
def shape(self) -> Tuple[int, int]:
|
|
419
|
-
"""%(shape)s."""
|
|
420
|
-
raise NotImplementedError()
|
|
421
|
-
|
|
422
|
-
def to(
|
|
423
|
-
self,
|
|
424
|
-
device: Optional[Device_t] = None,
|
|
425
|
-
) -> "NeuralOutput":
|
|
426
|
-
"""Transfer the output to another device or change its data type.
|
|
427
|
-
|
|
428
|
-
Parameters
|
|
429
|
-
----------
|
|
430
|
-
device
|
|
431
|
-
If not `None`, the output will be transferred to `device`.
|
|
432
|
-
|
|
433
|
-
Returns
|
|
434
|
-
-------
|
|
435
|
-
The output on a saved on `device`.
|
|
436
|
-
"""
|
|
437
|
-
# # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep
|
|
438
|
-
# if isinstance(device, str) and ":" in device:
|
|
439
|
-
# device, ix = device.split(":")
|
|
440
|
-
# idx = int(ix)
|
|
441
|
-
# else:
|
|
442
|
-
# idx = 0
|
|
443
|
-
|
|
444
|
-
# if not isinstance(device, xla_ext.Device):
|
|
445
|
-
# try:
|
|
446
|
-
# device = jax.devices(device)[idx]
|
|
447
|
-
# except IndexError as err:
|
|
448
|
-
# raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err
|
|
449
|
-
|
|
450
|
-
# out = jax.device_put(self._model, device)
|
|
451
|
-
# return NeuralOutput(out)
|
|
452
|
-
return self # TODO(ilan-gold) move model to device
|
|
453
|
-
|
|
454
|
-
@property
|
|
455
|
-
def converged(self) -> bool:
|
|
456
|
-
"""%(converged)s."""
|
|
457
|
-
# always return True for now
|
|
458
|
-
return True
|
|
459
|
-
|
|
460
|
-
|
|
461
242
|
class GraphOTTOutput(OTTOutput):
|
|
462
243
|
"""Output of :term:`OT` problems with a graph geometry in the linear term.
|
|
463
244
|
|
|
@@ -509,7 +290,7 @@ class GraphOTTOutput(OTTOutput):
|
|
|
509
290
|
else:
|
|
510
291
|
idx = 0
|
|
511
292
|
|
|
512
|
-
if not isinstance(device, xla_ext.Device):
|
|
293
|
+
if not isinstance(device, jax.Device):
|
|
513
294
|
try:
|
|
514
295
|
device = jax.devices(device)[idx]
|
|
515
296
|
except IndexError:
|
|
@@ -1,12 +1,9 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
import functools
|
|
3
2
|
import inspect
|
|
4
|
-
import math
|
|
5
3
|
import types
|
|
6
4
|
from typing import (
|
|
7
5
|
Any,
|
|
8
6
|
Hashable,
|
|
9
|
-
List,
|
|
10
7
|
Literal,
|
|
11
8
|
Mapping,
|
|
12
9
|
NamedTuple,
|
|
@@ -17,21 +14,13 @@ from typing import (
|
|
|
17
14
|
Union,
|
|
18
15
|
)
|
|
19
16
|
|
|
20
|
-
import optax
|
|
21
|
-
|
|
22
17
|
import jax
|
|
23
18
|
import jax.numpy as jnp
|
|
24
|
-
import numpy as np
|
|
25
19
|
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
|
|
26
|
-
from ott.neural.datasets import OTData, OTDataset
|
|
27
|
-
from ott.neural.methods.flows import dynamics, genot
|
|
28
|
-
from ott.neural.networks.layers import time_encoder
|
|
29
|
-
from ott.neural.networks.velocity_field import VelocityField
|
|
30
20
|
from ott.problems.linear import linear_problem
|
|
31
21
|
from ott.problems.quadratic import quadratic_problem
|
|
32
22
|
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
|
33
23
|
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
|
|
34
|
-
from ott.solvers.utils import uniform_sampler
|
|
35
24
|
|
|
36
25
|
from moscot._logging import logger
|
|
37
26
|
from moscot._types import (
|
|
@@ -43,23 +32,20 @@ from moscot._types import (
|
|
|
43
32
|
)
|
|
44
33
|
from moscot.backends.ott._utils import (
|
|
45
34
|
InitializerResolver,
|
|
46
|
-
Loader,
|
|
47
|
-
MultiLoader,
|
|
48
35
|
_instantiate_geodesic_cost,
|
|
49
36
|
alpha_to_fused_penalty,
|
|
50
37
|
check_shapes,
|
|
51
38
|
convert_scipy_sparse,
|
|
52
|
-
data_match_fn,
|
|
53
39
|
densify,
|
|
54
40
|
ensure_2d,
|
|
55
41
|
)
|
|
56
|
-
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
|
|
42
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
57
43
|
from moscot.base.problems._utils import TimeScalesHeatKernel
|
|
58
44
|
from moscot.base.solver import OTSolver
|
|
59
45
|
from moscot.costs import get_cost
|
|
60
|
-
from moscot.utils.tagged_array import
|
|
46
|
+
from moscot.utils.tagged_array import TaggedArray
|
|
61
47
|
|
|
62
|
-
__all__ = ["SinkhornSolver", "GWSolver"
|
|
48
|
+
__all__ = ["SinkhornSolver", "GWSolver"]
|
|
63
49
|
|
|
64
50
|
OTTSolver_t = Union[
|
|
65
51
|
sinkhorn.Sinkhorn,
|
|
@@ -516,193 +502,3 @@ class GWSolver(OTTJaxSolver):
|
|
|
516
502
|
problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"}
|
|
517
503
|
problem_kwargs |= {"alpha"}
|
|
518
504
|
return geom_kwargs | problem_kwargs, {"epsilon"}
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
class GENOTLinSolver(OTSolver[OTTOutput]):
|
|
522
|
-
"""Solver class for genot.GENOT linear :cite:`klein2023generative`."""
|
|
523
|
-
|
|
524
|
-
def __init__(self, **kwargs: Any) -> None:
|
|
525
|
-
"""Initiate the class with any kwargs passed to the ott-jax class."""
|
|
526
|
-
super().__init__()
|
|
527
|
-
self._train_sampler: Optional[MultiLoader] = None
|
|
528
|
-
self._valid_sampler: Optional[MultiLoader] = None
|
|
529
|
-
self._neural_kwargs = kwargs
|
|
530
|
-
|
|
531
|
-
@property
|
|
532
|
-
def problem_kind(self) -> ProblemKind_t: # noqa: D102
|
|
533
|
-
return "linear"
|
|
534
|
-
|
|
535
|
-
def _prepare( # type: ignore[override]
|
|
536
|
-
self,
|
|
537
|
-
distributions: DistributionCollection[K],
|
|
538
|
-
sample_pairs: List[Tuple[Any, Any]],
|
|
539
|
-
train_size: float = 0.9,
|
|
540
|
-
batch_size: int = 1024,
|
|
541
|
-
is_conditional: bool = True,
|
|
542
|
-
**kwargs: Any,
|
|
543
|
-
) -> Tuple[MultiLoader, MultiLoader]:
|
|
544
|
-
train_loaders = []
|
|
545
|
-
validate_loaders = []
|
|
546
|
-
seed = kwargs.get("seed")
|
|
547
|
-
is_aligned = kwargs.get("is_aligned", False)
|
|
548
|
-
if train_size == 1.0:
|
|
549
|
-
for sample_pair in sample_pairs:
|
|
550
|
-
source_key = sample_pair[0]
|
|
551
|
-
target_key = sample_pair[1]
|
|
552
|
-
src_data = OTData(
|
|
553
|
-
lin=distributions[source_key].xy,
|
|
554
|
-
condition=distributions[source_key].conditions if is_conditional else None,
|
|
555
|
-
)
|
|
556
|
-
tgt_data = OTData(
|
|
557
|
-
lin=distributions[target_key].xy,
|
|
558
|
-
condition=distributions[target_key].conditions if is_conditional else None,
|
|
559
|
-
)
|
|
560
|
-
dataset = OTDataset(src_data=src_data, tgt_data=tgt_data, seed=seed, is_aligned=is_aligned)
|
|
561
|
-
loader = Loader(dataset, batch_size=batch_size, seed=seed)
|
|
562
|
-
train_loaders.append(loader)
|
|
563
|
-
validate_loaders.append(loader)
|
|
564
|
-
else:
|
|
565
|
-
if train_size > 1.0 or train_size <= 0.0:
|
|
566
|
-
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")
|
|
567
|
-
|
|
568
|
-
seed = kwargs.get("seed", 0)
|
|
569
|
-
for sample_pair in sample_pairs:
|
|
570
|
-
source_key = sample_pair[0]
|
|
571
|
-
target_key = sample_pair[1]
|
|
572
|
-
source_data: ArrayLike = distributions[source_key].xy
|
|
573
|
-
target_data: ArrayLike = distributions[target_key].xy
|
|
574
|
-
source_split_data = self._split_data(
|
|
575
|
-
source_data,
|
|
576
|
-
conditions=distributions[source_key].conditions,
|
|
577
|
-
train_size=train_size,
|
|
578
|
-
seed=seed,
|
|
579
|
-
a=distributions[source_key].a,
|
|
580
|
-
b=distributions[source_key].b,
|
|
581
|
-
)
|
|
582
|
-
target_split_data = self._split_data(
|
|
583
|
-
target_data,
|
|
584
|
-
conditions=distributions[target_key].conditions,
|
|
585
|
-
train_size=train_size,
|
|
586
|
-
seed=seed,
|
|
587
|
-
a=distributions[target_key].a,
|
|
588
|
-
b=distributions[target_key].b,
|
|
589
|
-
)
|
|
590
|
-
src_data_train = OTData(
|
|
591
|
-
lin=source_split_data.data_train,
|
|
592
|
-
condition=source_split_data.conditions_train if is_conditional else None,
|
|
593
|
-
)
|
|
594
|
-
tgt_data_train = OTData(
|
|
595
|
-
lin=target_split_data.data_train,
|
|
596
|
-
condition=target_split_data.conditions_train if is_conditional else None,
|
|
597
|
-
)
|
|
598
|
-
train_dataset = OTDataset(
|
|
599
|
-
src_data=src_data_train, tgt_data=tgt_data_train, seed=seed, is_aligned=is_aligned
|
|
600
|
-
)
|
|
601
|
-
train_loader = Loader(train_dataset, batch_size=batch_size, seed=seed)
|
|
602
|
-
src_data_validate = OTData(
|
|
603
|
-
lin=source_split_data.data_valid,
|
|
604
|
-
condition=source_split_data.conditions_valid if is_conditional else None,
|
|
605
|
-
)
|
|
606
|
-
tgt_data_validate = OTData(
|
|
607
|
-
lin=target_split_data.data_valid,
|
|
608
|
-
condition=target_split_data.conditions_valid if is_conditional else None,
|
|
609
|
-
)
|
|
610
|
-
validate_dataset = OTDataset(
|
|
611
|
-
src_data=src_data_validate, tgt_data=tgt_data_validate, seed=seed, is_aligned=is_aligned
|
|
612
|
-
)
|
|
613
|
-
validate_loader = Loader(validate_dataset, batch_size=batch_size, seed=seed)
|
|
614
|
-
train_loaders.append(train_loader)
|
|
615
|
-
validate_loaders.append(validate_loader)
|
|
616
|
-
source_dim = self._neural_kwargs.get("input_dim", 0)
|
|
617
|
-
target_dim = source_dim
|
|
618
|
-
condition_dim = self._neural_kwargs.get("cond_dim", 0)
|
|
619
|
-
# TODO(ilan-gold): What are reasonable defaults here?
|
|
620
|
-
neural_vf = VelocityField(
|
|
621
|
-
output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim],
|
|
622
|
-
condition_dims=(
|
|
623
|
-
self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim])
|
|
624
|
-
if is_conditional
|
|
625
|
-
else None
|
|
626
|
-
),
|
|
627
|
-
hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]),
|
|
628
|
-
time_dims=self._neural_kwargs.get("velocity_field_time_dims", None),
|
|
629
|
-
time_encoder=self._neural_kwargs.get(
|
|
630
|
-
"velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024)
|
|
631
|
-
),
|
|
632
|
-
)
|
|
633
|
-
seed = self._neural_kwargs.get("seed", 0)
|
|
634
|
-
rng = jax.random.PRNGKey(seed)
|
|
635
|
-
data_match_fn_kwargs = self._neural_kwargs.get(
|
|
636
|
-
"data_match_fn_kwargs",
|
|
637
|
-
{} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0},
|
|
638
|
-
)
|
|
639
|
-
time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler)
|
|
640
|
-
optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4))
|
|
641
|
-
self._solver = genot.GENOT(
|
|
642
|
-
vf=neural_vf,
|
|
643
|
-
flow=self._neural_kwargs.get(
|
|
644
|
-
"flow",
|
|
645
|
-
dynamics.ConstantNoiseFlow(0.1),
|
|
646
|
-
),
|
|
647
|
-
data_match_fn=functools.partial(
|
|
648
|
-
self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs
|
|
649
|
-
),
|
|
650
|
-
source_dim=source_dim,
|
|
651
|
-
target_dim=target_dim,
|
|
652
|
-
condition_dim=condition_dim if is_conditional else None,
|
|
653
|
-
optimizer=optimizer,
|
|
654
|
-
time_sampler=time_sampler,
|
|
655
|
-
rng=rng,
|
|
656
|
-
latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None),
|
|
657
|
-
**self._neural_kwargs.get("velocity_field_train_state_kwargs", {}),
|
|
658
|
-
)
|
|
659
|
-
return (
|
|
660
|
-
MultiLoader(datasets=train_loaders, seed=seed),
|
|
661
|
-
MultiLoader(datasets=validate_loaders, seed=seed),
|
|
662
|
-
)
|
|
663
|
-
|
|
664
|
-
def _split_data( # TODO: adapt for Gromov terms
|
|
665
|
-
self,
|
|
666
|
-
x: ArrayLike,
|
|
667
|
-
conditions: Optional[ArrayLike],
|
|
668
|
-
train_size: float,
|
|
669
|
-
seed: int,
|
|
670
|
-
a: Optional[ArrayLike] = None,
|
|
671
|
-
b: Optional[ArrayLike] = None,
|
|
672
|
-
) -> SingleDistributionData:
|
|
673
|
-
n_samples_x = x.shape[0]
|
|
674
|
-
n_train_x = math.ceil(train_size * n_samples_x)
|
|
675
|
-
rng = np.random.default_rng(seed)
|
|
676
|
-
x = rng.permutation(x)
|
|
677
|
-
if a is not None:
|
|
678
|
-
a = rng.permutation(a)
|
|
679
|
-
if b is not None:
|
|
680
|
-
b = rng.permutation(b)
|
|
681
|
-
|
|
682
|
-
return SingleDistributionData(
|
|
683
|
-
data_train=x[:n_train_x],
|
|
684
|
-
data_valid=x[n_train_x:],
|
|
685
|
-
conditions_train=conditions[:n_train_x] if conditions is not None else None,
|
|
686
|
-
conditions_valid=conditions[n_train_x:] if conditions is not None else None,
|
|
687
|
-
a_train=a[:n_train_x] if a is not None else None,
|
|
688
|
-
a_valid=a[n_train_x:] if a is not None else None,
|
|
689
|
-
b_train=b[:n_train_x] if b is not None else None,
|
|
690
|
-
b_valid=b[n_train_x:] if b is not None else None,
|
|
691
|
-
)
|
|
692
|
-
|
|
693
|
-
@property
|
|
694
|
-
def solver(self) -> genot.GENOT:
|
|
695
|
-
"""Underlying optimal transport solver."""
|
|
696
|
-
return self._solver
|
|
697
|
-
|
|
698
|
-
@classmethod
|
|
699
|
-
def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
|
|
700
|
-
return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value]
|
|
701
|
-
|
|
702
|
-
def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override]
|
|
703
|
-
seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests
|
|
704
|
-
rng = jax.random.PRNGKey(seed)
|
|
705
|
-
logs = self.solver(
|
|
706
|
-
data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng
|
|
707
|
-
) # TODO(ilan-gold): validation and figure out defualts
|
|
708
|
-
return NeuralOutput(self.solver, logs)
|