moscot 0.4.2__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. {moscot-0.4.2 → moscot-0.5.0}/.pre-commit-config.yaml +1 -1
  2. {moscot-0.4.2 → moscot-0.5.0}/PKG-INFO +7 -7
  3. {moscot-0.4.2 → moscot-0.5.0}/pyproject.toml +7 -9
  4. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/__init__.py +2 -3
  5. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/_utils.py +2 -2
  6. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/output.py +5 -224
  7. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/ott/solver.py +3 -207
  8. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/utils.py +8 -12
  9. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/output.py +1 -20
  10. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/_mixins.py +1 -1
  11. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/birth_death.py +1 -1
  12. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/costs/_costs.py +3 -3
  13. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/datasets.py +1 -1
  14. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +1 -1
  15. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/subset_policy.py +7 -0
  16. {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/PKG-INFO +7 -7
  17. {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/SOURCES.txt +0 -6
  18. {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/requires.txt +6 -7
  19. moscot-0.4.2/src/moscot/neural/base/__init__.py +0 -0
  20. moscot-0.4.2/src/moscot/neural/base/problems/__init__.py +0 -3
  21. moscot-0.4.2/src/moscot/neural/base/problems/problem.py +0 -243
  22. moscot-0.4.2/src/moscot/neural/problems/__init__.py +0 -3
  23. moscot-0.4.2/src/moscot/neural/problems/generic/__init__.py +0 -3
  24. moscot-0.4.2/src/moscot/neural/problems/generic/_generic.py +0 -78
  25. {moscot-0.4.2 → moscot-0.5.0}/.gitignore +0 -0
  26. {moscot-0.4.2 → moscot-0.5.0}/.gitmodules +0 -0
  27. {moscot-0.4.2 → moscot-0.5.0}/.readthedocs.yml +0 -0
  28. {moscot-0.4.2 → moscot-0.5.0}/.run_notebooks.sh +0 -0
  29. {moscot-0.4.2 → moscot-0.5.0}/LICENSE +0 -0
  30. {moscot-0.4.2 → moscot-0.5.0}/MANIFEST.in +0 -0
  31. {moscot-0.4.2 → moscot-0.5.0}/README.rst +0 -0
  32. {moscot-0.4.2 → moscot-0.5.0}/codecov.yml +0 -0
  33. {moscot-0.4.2 → moscot-0.5.0}/setup.cfg +0 -0
  34. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/__init__.py +0 -0
  35. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_constants.py +0 -0
  36. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_logging.py +0 -0
  37. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_registry.py +0 -0
  38. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/_types.py +0 -0
  39. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/backends/__init__.py +0 -0
  40. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/__init__.py +0 -0
  41. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/cost.py +0 -0
  42. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/__init__.py +0 -0
  43. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/_utils.py +0 -0
  44. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/compound_problem.py +0 -0
  45. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/manager.py +0 -0
  46. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/problems/problem.py +0 -0
  47. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/base/solver.py +0 -0
  48. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/costs/__init__.py +0 -0
  49. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/costs/_utils.py +0 -0
  50. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/plotting/__init__.py +0 -0
  51. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/plotting/_plotting.py +0 -0
  52. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/plotting/_utils.py +0 -0
  53. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/__init__.py +0 -0
  54. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/_utils.py +0 -0
  55. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/cross_modality/__init__.py +0 -0
  56. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/cross_modality/_mixins.py +0 -0
  57. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/cross_modality/_translation.py +0 -0
  58. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/generic/__init__.py +0 -0
  59. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/generic/_generic.py +0 -0
  60. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/generic/_mixins.py +0 -0
  61. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/__init__.py +0 -0
  62. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/_alignment.py +0 -0
  63. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/_mapping.py +0 -0
  64. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/space/_mixins.py +0 -0
  65. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  66. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/time/__init__.py +0 -0
  67. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/time/_lineage.py +0 -0
  68. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/problems/time/_mixins.py +0 -0
  69. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/py.typed +0 -0
  70. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/__init__.py +0 -0
  71. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  72. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  73. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  74. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  75. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  76. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  77. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
  78. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/data.py +0 -0
  79. {moscot-0.4.2 → moscot-0.5.0}/src/moscot/utils/tagged_array.py +0 -0
  80. {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/dependency_links.txt +0 -0
  81. {moscot-0.4.2 → moscot-0.5.0}/src/moscot.egg-info/top_level.txt +0 -0
@@ -63,7 +63,7 @@ repos:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.9.10
66
+ rev: v0.11.7
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: moscot
3
- Version: 0.4.2
3
+ Version: 0.5.0
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -65,18 +65,18 @@ Requires-Dist: anndata>=0.9.1
65
65
  Requires-Dist: scanpy>=1.9.3
66
66
  Requires-Dist: wrapt>=1.13.2
67
67
  Requires-Dist: docrep>=0.3.2
68
- Requires-Dist: ott-jax>=0.5.0
68
+ Requires-Dist: jax>=0.6.1
69
+ Requires-Dist: ott-jax>=0.6.0
69
70
  Requires-Dist: cloudpickle>=2.2.0
70
71
  Requires-Dist: rich>=13.5
71
72
  Requires-Dist: docstring_inheritance>=2.0.0
72
73
  Requires-Dist: mudata>=0.2.2
74
+ Requires-Dist: optax
75
+ Requires-Dist: flax
76
+ Requires-Dist: diffrax
77
+ Requires-Dist: ott-jax[neural]>=0.5.0
73
78
  Provides-Extra: spatial
74
79
  Requires-Dist: squidpy>=1.2.3; extra == "spatial"
75
- Provides-Extra: neural
76
- Requires-Dist: optax; extra == "neural"
77
- Requires-Dist: flax; extra == "neural"
78
- Requires-Dist: diffrax; extra == "neural"
79
- Requires-Dist: ott-jax[neural]>=0.5.0; extra == "neural"
80
80
  Provides-Extra: dev
81
81
  Requires-Dist: pre-commit>=3.0.0; extra == "dev"
82
82
  Requires-Dist: tox>=4; extra == "dev"
@@ -54,11 +54,16 @@ dependencies = [
54
54
  "scanpy>=1.9.3",
55
55
  "wrapt>=1.13.2",
56
56
  "docrep>=0.3.2",
57
- "ott-jax>=0.5.0",
57
+ "jax>=0.6.1",
58
+ "ott-jax>=0.6.0",
58
59
  "cloudpickle>=2.2.0",
59
60
  "rich>=13.5",
60
61
  "docstring_inheritance>=2.0.0",
61
- "mudata>=0.2.2"
62
+ "mudata>=0.2.2",
63
+ "optax",
64
+ "flax",
65
+ "diffrax",
66
+ "ott-jax[neural]>=0.5.0"
62
67
  ]
63
68
 
64
69
  [project.optional-dependencies]
@@ -66,13 +71,6 @@ spatial = [
66
71
  "squidpy>=1.2.3"
67
72
  ]
68
73
 
69
- neural = [
70
- "optax",
71
- "flax",
72
- "diffrax",
73
- "ott-jax[neural]>=0.5.0",
74
-
75
- ]
76
74
  dev = [
77
75
  "pre-commit>=3.0.0",
78
76
  "tox>=4",
@@ -1,15 +1,14 @@
1
1
  from ott.geometry import costs
2
2
 
3
3
  from moscot.backends.ott._utils import sinkhorn_divergence
4
- from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
5
- from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
4
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
5
+ from moscot.backends.ott.solver import GWSolver, SinkhornSolver
6
6
  from moscot.costs import register_cost
7
7
 
8
8
  __all__ = [
9
9
  "OTTOutput",
10
10
  "GWSolver",
11
11
  "SinkhornSolver",
12
- "NeuralOutput",
13
12
  "sinkhorn_divergence",
14
13
  "GENOTLinSolver",
15
14
  "GraphOTTOutput",
@@ -10,7 +10,6 @@ import scipy.sparse as sp
10
10
  from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
11
11
  from ott.initializers.linear import initializers as init_lib
12
12
  from ott.initializers.linear import initializers_lr as lr_init_lib
13
- from ott.neural import datasets
14
13
  from ott.solvers import utils as solver_utils
15
14
  from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
16
15
 
@@ -18,6 +17,7 @@ from moscot._logging import logger
18
17
  from moscot._types import ArrayLike, ScaleCost_t
19
18
 
20
19
  Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
20
+ OTDataset = Any # to be removed when neural part is being removed from moscot
21
21
 
22
22
 
23
23
  __all__ = ["sinkhorn_divergence"]
@@ -272,7 +272,7 @@ def data_match_fn(
272
272
 
273
273
  class Loader:
274
274
 
275
- def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
275
+ def __init__(self, dataset: OTDataset, batch_size: int, seed: Optional[int] = None):
276
276
  self.dataset = dataset
277
277
  self.batch_size = batch_size
278
278
  self._rng = np.random.default_rng(seed)
@@ -1,12 +1,8 @@
1
- from typing import Any, Callable, List, Optional, Tuple, Union
2
-
3
- import jaxlib.xla_extension as xla_ext
1
+ from typing import Any, Optional, Tuple, Union
4
2
 
5
3
  import jax
6
4
  import jax.numpy as jnp
7
5
  import numpy as np
8
- import scipy.sparse as sp
9
- from ott.neural.methods.flows.genot import GENOT
10
6
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
11
7
  from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
12
8
 
@@ -14,10 +10,9 @@ import matplotlib as mpl
14
10
  import matplotlib.pyplot as plt
15
11
 
16
12
  from moscot._types import ArrayLike, Device_t
17
- from moscot.backends.ott._utils import get_nearest_neighbors
18
- from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput
13
+ from moscot.base.output import BaseDiscreteSolverOutput
19
14
 
20
- __all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"]
15
+ __all__ = ["OTTOutput", "GraphOTTOutput"]
21
16
 
22
17
 
23
18
  class OTTOutput(BaseDiscreteSolverOutput):
@@ -209,7 +204,7 @@ class OTTOutput(BaseDiscreteSolverOutput):
209
204
  else:
210
205
  idx = 0
211
206
 
212
- if not isinstance(device, xla_ext.Device):
207
+ if not isinstance(device, jax.Device):
213
208
  try:
214
209
  device = jax.devices(device)[idx]
215
210
  except IndexError:
@@ -244,220 +239,6 @@ class OTTOutput(BaseDiscreteSolverOutput):
244
239
  return jnp.ones((n,))
245
240
 
246
241
 
247
- class NeuralOutput(BaseNeuralOutput):
248
- """Output wrapper for GENOT."""
249
-
250
- def __init__(self, model: GENOT, logs: dict[str, list[float]]):
251
- """Initialize `NeuralOutput`.
252
-
253
- Parameters
254
- ----------
255
- model : GENOT
256
- The OTT-Jax GENOT model
257
- """
258
- self._logs = logs
259
- self._model = model
260
-
261
- @property
262
- def logs(self):
263
- """Logs of the training. A dictionary containing what the numeric values are i.e., loss.
264
-
265
- Returns
266
- -------
267
- dict[str, list[float]]
268
- """
269
- return self._logs
270
-
271
- def _project_transport_matrix(
272
- self,
273
- src_dist: ArrayLike,
274
- tgt_dist: ArrayLike,
275
- func: Callable[[ArrayLike], ArrayLike],
276
- save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
277
- batch_size: int = 1024,
278
- k: int = 30,
279
- length_scale: Optional[float] = None,
280
- seed: int = 42,
281
- recall_target: float = 0.95,
282
- aggregate_to_topk: bool = True,
283
- ) -> sp.csr_matrix:
284
- row_indices: List[ArrayLike] = []
285
- column_indices: List[ArrayLike] = []
286
- distances_list: List[ArrayLike] = []
287
- if length_scale is None:
288
- key = jax.random.PRNGKey(seed)
289
- src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
290
- tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))]
291
- length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch)))
292
- for index in range(0, len(src_dist), batch_size):
293
- distances, indices = get_nearest_neighbors(
294
- func(src_dist[index : index + batch_size, :]),
295
- tgt_dist,
296
- k,
297
- recall_target=recall_target,
298
- aggregate_to_topk=aggregate_to_topk,
299
- )
300
- distances = jnp.exp(-((distances / length_scale) ** 2))
301
- distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1)
302
- distances_list.append(distances.flatten())
303
- column_indices.append(indices.flatten())
304
- row_indices.append(
305
- jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist)))
306
- )
307
- distances = jnp.concatenate(distances_list)
308
- row_indices = jnp.concatenate(row_indices)
309
- column_indices = jnp.concatenate(column_indices)
310
- tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
311
- if save_transport_matrix:
312
- self._transport_matrix = tm
313
- return tm
314
-
315
- def project_to_transport_matrix( # type:ignore[override]
316
- self,
317
- src_cells: ArrayLike,
318
- tgt_cells: ArrayLike,
319
- condition: ArrayLike = None,
320
- save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
321
- batch_size: int = 1024,
322
- k: int = 30,
323
- length_scale: Optional[float] = None,
324
- seed: int = 42,
325
- recall_target: float = 0.95,
326
- aggregate_to_topk: bool = True,
327
- ) -> sp.csr_matrix:
328
- """Project conditional neural OT map onto cells.
329
-
330
- In constrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells,
331
- but a cell can also be mapped to a location between two cells. This function computes
332
- a pseudo-transport matrix considering the neighborhood of where a cell is mapped to.
333
- Therefore, a neighborhood graph of `k` target cells is computed around each transported cell
334
- of the source distribution. The assignment likelihood of each mapped cell to the target cells is then
335
- computed with a Gaussian kernel with parameter `length_scale`.
336
-
337
- Parameters
338
- ----------
339
- condition
340
- Condition `src_cells` correspond to.
341
- src_cells
342
- Cells which are to be mapped.
343
- tgt_cells
344
- Cells from which the neighborhood graph around the mapped `src_cells` are computed.
345
- forward
346
- Whether to map cells based on the forward transport map or backward transport map.
347
- save_transport_matrix
348
- Whether to save the transport matrix.
349
- batch_size
350
- Number of data points in the source distribution the neighborhood graph is computed
351
- for in parallel.
352
- k
353
- Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
354
- length_scale
355
- Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`,
356
- `length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the
357
- mapped source and target distribution.
358
- seed
359
- Random seed for sampling the pairs of distributions for computing the variance in case `length_scale`
360
- is `None`.
361
- recall_target
362
- Recall target for the approximation.
363
- aggregate_to_topk
364
- When true, the nearest neighbor aggregates approximate results to the top-k in sorted order.
365
- When false, returns the approximate results unsorted.
366
- In this case, the number of the approximate results is implementation defined and is greater or
367
- equal to the specified k.
368
-
369
- Returns
370
- -------
371
- The projected transport matrix.
372
- """
373
- src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
374
- conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
375
- push = self.push if condition is None else conditioned_fn
376
- func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
377
- return self._project_transport_matrix(
378
- src_dist=src_dist,
379
- tgt_dist=tgt_dist,
380
- func=func,
381
- save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
382
- batch_size=batch_size,
383
- k=k,
384
- length_scale=length_scale,
385
- seed=seed,
386
- recall_target=recall_target,
387
- aggregate_to_topk=aggregate_to_topk,
388
- )
389
-
390
- def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
391
- """Push distribution `x` conditioned on condition `cond`.
392
-
393
- Parameters
394
- ----------
395
- x
396
- Distribution to push.
397
- cond
398
- Condition of conditional neural OT.
399
-
400
- Returns
401
- -------
402
- Pushed distribution.
403
- """
404
- if isinstance(x, (bool, int, float, complex)):
405
- raise ValueError("Expected array, found scalar value.")
406
- if x.ndim not in (1, 2):
407
- raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
408
- return self._apply_forward(x, cond=cond)
409
-
410
- def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
411
- return self._model.transport(x, condition=cond)
412
-
413
- @property
414
- def is_linear(self) -> bool: # noqa: D102
415
- return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT
416
-
417
- @property
418
- def shape(self) -> Tuple[int, int]:
419
- """%(shape)s."""
420
- raise NotImplementedError()
421
-
422
- def to(
423
- self,
424
- device: Optional[Device_t] = None,
425
- ) -> "NeuralOutput":
426
- """Transfer the output to another device or change its data type.
427
-
428
- Parameters
429
- ----------
430
- device
431
- If not `None`, the output will be transferred to `device`.
432
-
433
- Returns
434
- -------
435
- The output on a saved on `device`.
436
- """
437
- # # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep
438
- # if isinstance(device, str) and ":" in device:
439
- # device, ix = device.split(":")
440
- # idx = int(ix)
441
- # else:
442
- # idx = 0
443
-
444
- # if not isinstance(device, xla_ext.Device):
445
- # try:
446
- # device = jax.devices(device)[idx]
447
- # except IndexError as err:
448
- # raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err
449
-
450
- # out = jax.device_put(self._model, device)
451
- # return NeuralOutput(out)
452
- return self # TODO(ilan-gold) move model to device
453
-
454
- @property
455
- def converged(self) -> bool:
456
- """%(converged)s."""
457
- # always return True for now
458
- return True
459
-
460
-
461
242
  class GraphOTTOutput(OTTOutput):
462
243
  """Output of :term:`OT` problems with a graph geometry in the linear term.
463
244
 
@@ -509,7 +290,7 @@ class GraphOTTOutput(OTTOutput):
509
290
  else:
510
291
  idx = 0
511
292
 
512
- if not isinstance(device, xla_ext.Device):
293
+ if not isinstance(device, jax.Device):
513
294
  try:
514
295
  device = jax.devices(device)[idx]
515
296
  except IndexError:
@@ -1,12 +1,9 @@
1
1
  import abc
2
- import functools
3
2
  import inspect
4
- import math
5
3
  import types
6
4
  from typing import (
7
5
  Any,
8
6
  Hashable,
9
- List,
10
7
  Literal,
11
8
  Mapping,
12
9
  NamedTuple,
@@ -17,21 +14,13 @@ from typing import (
17
14
  Union,
18
15
  )
19
16
 
20
- import optax
21
-
22
17
  import jax
23
18
  import jax.numpy as jnp
24
- import numpy as np
25
19
  from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
26
- from ott.neural.datasets import OTData, OTDataset
27
- from ott.neural.methods.flows import dynamics, genot
28
- from ott.neural.networks.layers import time_encoder
29
- from ott.neural.networks.velocity_field import VelocityField
30
20
  from ott.problems.linear import linear_problem
31
21
  from ott.problems.quadratic import quadratic_problem
32
22
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
33
23
  from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
34
- from ott.solvers.utils import uniform_sampler
35
24
 
36
25
  from moscot._logging import logger
37
26
  from moscot._types import (
@@ -43,23 +32,20 @@ from moscot._types import (
43
32
  )
44
33
  from moscot.backends.ott._utils import (
45
34
  InitializerResolver,
46
- Loader,
47
- MultiLoader,
48
35
  _instantiate_geodesic_cost,
49
36
  alpha_to_fused_penalty,
50
37
  check_shapes,
51
38
  convert_scipy_sparse,
52
- data_match_fn,
53
39
  densify,
54
40
  ensure_2d,
55
41
  )
56
- from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
42
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
57
43
  from moscot.base.problems._utils import TimeScalesHeatKernel
58
44
  from moscot.base.solver import OTSolver
59
45
  from moscot.costs import get_cost
60
- from moscot.utils.tagged_array import DistributionCollection, TaggedArray
46
+ from moscot.utils.tagged_array import TaggedArray
61
47
 
62
- __all__ = ["SinkhornSolver", "GWSolver", "GENOTLinSolver"]
48
+ __all__ = ["SinkhornSolver", "GWSolver"]
63
49
 
64
50
  OTTSolver_t = Union[
65
51
  sinkhorn.Sinkhorn,
@@ -516,193 +502,3 @@ class GWSolver(OTTJaxSolver):
516
502
  problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"}
517
503
  problem_kwargs |= {"alpha"}
518
504
  return geom_kwargs | problem_kwargs, {"epsilon"}
519
-
520
-
521
- class GENOTLinSolver(OTSolver[OTTOutput]):
522
- """Solver class for genot.GENOT linear :cite:`klein2023generative`."""
523
-
524
- def __init__(self, **kwargs: Any) -> None:
525
- """Initiate the class with any kwargs passed to the ott-jax class."""
526
- super().__init__()
527
- self._train_sampler: Optional[MultiLoader] = None
528
- self._valid_sampler: Optional[MultiLoader] = None
529
- self._neural_kwargs = kwargs
530
-
531
- @property
532
- def problem_kind(self) -> ProblemKind_t: # noqa: D102
533
- return "linear"
534
-
535
- def _prepare( # type: ignore[override]
536
- self,
537
- distributions: DistributionCollection[K],
538
- sample_pairs: List[Tuple[Any, Any]],
539
- train_size: float = 0.9,
540
- batch_size: int = 1024,
541
- is_conditional: bool = True,
542
- **kwargs: Any,
543
- ) -> Tuple[MultiLoader, MultiLoader]:
544
- train_loaders = []
545
- validate_loaders = []
546
- seed = kwargs.get("seed")
547
- is_aligned = kwargs.get("is_aligned", False)
548
- if train_size == 1.0:
549
- for sample_pair in sample_pairs:
550
- source_key = sample_pair[0]
551
- target_key = sample_pair[1]
552
- src_data = OTData(
553
- lin=distributions[source_key].xy,
554
- condition=distributions[source_key].conditions if is_conditional else None,
555
- )
556
- tgt_data = OTData(
557
- lin=distributions[target_key].xy,
558
- condition=distributions[target_key].conditions if is_conditional else None,
559
- )
560
- dataset = OTDataset(src_data=src_data, tgt_data=tgt_data, seed=seed, is_aligned=is_aligned)
561
- loader = Loader(dataset, batch_size=batch_size, seed=seed)
562
- train_loaders.append(loader)
563
- validate_loaders.append(loader)
564
- else:
565
- if train_size > 1.0 or train_size <= 0.0:
566
- raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")
567
-
568
- seed = kwargs.get("seed", 0)
569
- for sample_pair in sample_pairs:
570
- source_key = sample_pair[0]
571
- target_key = sample_pair[1]
572
- source_data: ArrayLike = distributions[source_key].xy
573
- target_data: ArrayLike = distributions[target_key].xy
574
- source_split_data = self._split_data(
575
- source_data,
576
- conditions=distributions[source_key].conditions,
577
- train_size=train_size,
578
- seed=seed,
579
- a=distributions[source_key].a,
580
- b=distributions[source_key].b,
581
- )
582
- target_split_data = self._split_data(
583
- target_data,
584
- conditions=distributions[target_key].conditions,
585
- train_size=train_size,
586
- seed=seed,
587
- a=distributions[target_key].a,
588
- b=distributions[target_key].b,
589
- )
590
- src_data_train = OTData(
591
- lin=source_split_data.data_train,
592
- condition=source_split_data.conditions_train if is_conditional else None,
593
- )
594
- tgt_data_train = OTData(
595
- lin=target_split_data.data_train,
596
- condition=target_split_data.conditions_train if is_conditional else None,
597
- )
598
- train_dataset = OTDataset(
599
- src_data=src_data_train, tgt_data=tgt_data_train, seed=seed, is_aligned=is_aligned
600
- )
601
- train_loader = Loader(train_dataset, batch_size=batch_size, seed=seed)
602
- src_data_validate = OTData(
603
- lin=source_split_data.data_valid,
604
- condition=source_split_data.conditions_valid if is_conditional else None,
605
- )
606
- tgt_data_validate = OTData(
607
- lin=target_split_data.data_valid,
608
- condition=target_split_data.conditions_valid if is_conditional else None,
609
- )
610
- validate_dataset = OTDataset(
611
- src_data=src_data_validate, tgt_data=tgt_data_validate, seed=seed, is_aligned=is_aligned
612
- )
613
- validate_loader = Loader(validate_dataset, batch_size=batch_size, seed=seed)
614
- train_loaders.append(train_loader)
615
- validate_loaders.append(validate_loader)
616
- source_dim = self._neural_kwargs.get("input_dim", 0)
617
- target_dim = source_dim
618
- condition_dim = self._neural_kwargs.get("cond_dim", 0)
619
- # TODO(ilan-gold): What are reasonable defaults here?
620
- neural_vf = VelocityField(
621
- output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim],
622
- condition_dims=(
623
- self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim])
624
- if is_conditional
625
- else None
626
- ),
627
- hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]),
628
- time_dims=self._neural_kwargs.get("velocity_field_time_dims", None),
629
- time_encoder=self._neural_kwargs.get(
630
- "velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024)
631
- ),
632
- )
633
- seed = self._neural_kwargs.get("seed", 0)
634
- rng = jax.random.PRNGKey(seed)
635
- data_match_fn_kwargs = self._neural_kwargs.get(
636
- "data_match_fn_kwargs",
637
- {} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0},
638
- )
639
- time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler)
640
- optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4))
641
- self._solver = genot.GENOT(
642
- vf=neural_vf,
643
- flow=self._neural_kwargs.get(
644
- "flow",
645
- dynamics.ConstantNoiseFlow(0.1),
646
- ),
647
- data_match_fn=functools.partial(
648
- self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs
649
- ),
650
- source_dim=source_dim,
651
- target_dim=target_dim,
652
- condition_dim=condition_dim if is_conditional else None,
653
- optimizer=optimizer,
654
- time_sampler=time_sampler,
655
- rng=rng,
656
- latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None),
657
- **self._neural_kwargs.get("velocity_field_train_state_kwargs", {}),
658
- )
659
- return (
660
- MultiLoader(datasets=train_loaders, seed=seed),
661
- MultiLoader(datasets=validate_loaders, seed=seed),
662
- )
663
-
664
- def _split_data( # TODO: adapt for Gromov terms
665
- self,
666
- x: ArrayLike,
667
- conditions: Optional[ArrayLike],
668
- train_size: float,
669
- seed: int,
670
- a: Optional[ArrayLike] = None,
671
- b: Optional[ArrayLike] = None,
672
- ) -> SingleDistributionData:
673
- n_samples_x = x.shape[0]
674
- n_train_x = math.ceil(train_size * n_samples_x)
675
- rng = np.random.default_rng(seed)
676
- x = rng.permutation(x)
677
- if a is not None:
678
- a = rng.permutation(a)
679
- if b is not None:
680
- b = rng.permutation(b)
681
-
682
- return SingleDistributionData(
683
- data_train=x[:n_train_x],
684
- data_valid=x[n_train_x:],
685
- conditions_train=conditions[:n_train_x] if conditions is not None else None,
686
- conditions_valid=conditions[n_train_x:] if conditions is not None else None,
687
- a_train=a[:n_train_x] if a is not None else None,
688
- a_valid=a[n_train_x:] if a is not None else None,
689
- b_train=b[:n_train_x] if b is not None else None,
690
- b_valid=b[n_train_x:] if b is not None else None,
691
- )
692
-
693
- @property
694
- def solver(self) -> genot.GENOT:
695
- """Underlying optimal transport solver."""
696
- return self._solver
697
-
698
- @classmethod
699
- def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
700
- return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value]
701
-
702
- def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override]
703
- seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests
704
- rng = jax.random.PRNGKey(seed)
705
- logs = self.solver(
706
- data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng
707
- ) # TODO(ilan-gold): validation and figure out defualts
708
- return NeuralOutput(self.solver, logs)