moscot 0.3.3__tar.gz → 0.3.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. {moscot-0.3.3 → moscot-0.3.5}/.gitignore +1 -0
  2. {moscot-0.3.3 → moscot-0.3.5}/.pre-commit-config.yaml +9 -9
  3. moscot-0.3.5/.run_notebooks.sh +62 -0
  4. {moscot-0.3.3 → moscot-0.3.5}/PKG-INFO +3 -2
  5. {moscot-0.3.3 → moscot-0.3.5}/pyproject.toml +19 -2
  6. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_types.py +6 -10
  7. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/__init__.py +2 -4
  8. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/_utils.py +52 -5
  9. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/output.py +65 -2
  10. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/solver.py +132 -15
  11. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/_mixins.py +199 -31
  12. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/_utils.py +19 -1
  13. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/birth_death.py +43 -14
  14. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/compound_problem.py +72 -38
  15. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/problem.py +269 -58
  16. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/solver.py +0 -1
  17. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/costs/_costs.py +5 -2
  18. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/costs/_utils.py +15 -4
  19. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/datasets.py +51 -7
  20. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/plotting/_plotting.py +11 -11
  21. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/plotting/_utils.py +9 -6
  22. moscot-0.3.5/src/moscot/problems/_utils.py +126 -0
  23. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/cross_modality/_mixins.py +62 -5
  24. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/cross_modality/_translation.py +54 -25
  25. moscot-0.3.5/src/moscot/problems/generic/__init__.py +4 -0
  26. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/generic/_generic.py +332 -52
  27. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/generic/_mixins.py +10 -3
  28. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/_alignment.py +66 -21
  29. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/_mapping.py +85 -37
  30. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/_mixins.py +207 -30
  31. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +28 -9
  32. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/time/_lineage.py +73 -22
  33. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/time/_mixins.py +98 -47
  34. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/mouse_proliferation.txt +0 -1
  35. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/subset_policy.py +7 -12
  36. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/tagged_array.py +33 -6
  37. {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/PKG-INFO +3 -2
  38. {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/SOURCES.txt +1 -0
  39. {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/requires.txt +2 -1
  40. moscot-0.3.3/src/moscot/problems/_utils.py +0 -87
  41. moscot-0.3.3/src/moscot/problems/generic/__init__.py +0 -4
  42. {moscot-0.3.3 → moscot-0.3.5}/.gitmodules +0 -0
  43. {moscot-0.3.3 → moscot-0.3.5}/.readthedocs.yml +0 -0
  44. {moscot-0.3.3 → moscot-0.3.5}/LICENSE +0 -0
  45. {moscot-0.3.3 → moscot-0.3.5}/MANIFEST.in +0 -0
  46. {moscot-0.3.3 → moscot-0.3.5}/README.rst +0 -0
  47. {moscot-0.3.3 → moscot-0.3.5}/codecov.yml +0 -0
  48. {moscot-0.3.3 → moscot-0.3.5}/setup.cfg +0 -0
  49. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/__init__.py +0 -0
  50. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_constants.py +0 -0
  51. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_logging.py +0 -0
  52. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_registry.py +0 -0
  53. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/__init__.py +0 -0
  54. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/utils.py +0 -0
  55. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/__init__.py +0 -0
  56. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/cost.py +0 -0
  57. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/output.py +0 -0
  58. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/__init__.py +0 -0
  59. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/manager.py +0 -0
  60. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/costs/__init__.py +0 -0
  61. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/plotting/__init__.py +0 -0
  62. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/__init__.py +0 -0
  63. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/cross_modality/__init__.py +0 -0
  64. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/__init__.py +0 -0
  65. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  66. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/time/__init__.py +0 -0
  67. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/py.typed +0 -0
  68. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/__init__.py +0 -0
  69. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  70. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  71. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  72. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  73. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  74. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  75. {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/data.py +0 -0
  76. {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/dependency_links.txt +0 -0
  77. {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/top_level.txt +0 -0
@@ -154,3 +154,4 @@ packages.dot
154
154
 
155
155
  # plotting tests
156
156
  tests/plotting/actual_figures/
157
+ tests/plotting/figures/
@@ -7,29 +7,29 @@ default_stages:
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.5.1
10
+ rev: v1.10.1
11
11
  hooks:
12
12
  - id: mypy
13
13
  additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 23.7.0
16
+ rev: 24.4.2
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
20
20
  - repo: https://github.com/pre-commit/mirrors-prettier
21
- rev: v3.0.2
21
+ rev: v4.0.0-alpha.8
22
22
  hooks:
23
23
  - id: prettier
24
24
  language_version: system
25
25
  - repo: https://github.com/PyCQA/isort
26
- rev: 5.12.0
26
+ rev: 5.13.2
27
27
  hooks:
28
28
  - id: isort
29
29
  additional_dependencies: [toml]
30
30
  args: [--order-by-type]
31
31
  - repo: https://github.com/pre-commit/pre-commit-hooks
32
- rev: v4.4.0
32
+ rev: v4.6.0
33
33
  hooks:
34
34
  - id: check-merge-conflict
35
35
  - id: check-ast
@@ -42,17 +42,17 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.10.1
45
+ rev: v3.16.0
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
49
49
  - repo: https://github.com/asottile/blacken-docs
50
- rev: 1.16.0
50
+ rev: 1.18.0
51
51
  hooks:
52
52
  - id: blacken-docs
53
53
  additional_dependencies: [black==23.1.0]
54
54
  - repo: https://github.com/rstcheck/rstcheck
55
- rev: v6.1.2
55
+ rev: v6.2.0
56
56
  hooks:
57
57
  - id: rstcheck
58
58
  additional_dependencies: [tomli]
@@ -63,7 +63,7 @@ repos:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.0.286
66
+ rev: v0.5.0
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -0,0 +1,62 @@
1
+ #!/bin/bash
2
+
3
+ # Check if the base directory is provided as an argument
4
+ if [ "$#" -ne 1 ]; then
5
+ echo "Usage: $0 <base_notebook_directory>"
6
+ exit 1
7
+ fi
8
+
9
+ # Base directory for notebooks
10
+ base_dir=$1
11
+
12
+ # Define notebook directories or patterns
13
+ declare -a notebooks=(
14
+ "$base_dir/examples/plotting/*.ipynb"
15
+ "$base_dir/examples/problems/*.ipynb"
16
+ "$base_dir/examples/solvers/*.ipynb"
17
+ )
18
+
19
+ # Initialize an array to hold valid notebook paths
20
+ declare -a valid_notebooks
21
+
22
+ # Gather all valid notebook files from the patterns
23
+ echo "Gathering notebooks..."
24
+ for pattern in "${notebooks[@]}"; do
25
+ for nb in $pattern; do
26
+ if [[ -f "$nb" ]]; then # Check if the file exists
27
+ valid_notebooks+=("$nb") # Add to the list of valid notebooks
28
+ fi
29
+ done
30
+ done
31
+
32
+ # Check if we have any notebooks to run
33
+ if [ ${#valid_notebooks[@]} -eq 0 ]; then
34
+ echo "No notebooks found to run."
35
+ exit 1
36
+ fi
37
+
38
+ # Echo the notebooks that will be run for clarity
39
+ echo "Preparing to run the following notebooks:"
40
+ for nb in "${valid_notebooks[@]}"; do
41
+ echo "$nb"
42
+ done
43
+
44
+ # Initialize a flag to track the success of all commands
45
+ all_success=true
46
+
47
+ # Execute all valid notebooks
48
+ for nb in "${valid_notebooks[@]}"; do
49
+ echo "Running $nb"
50
+ jupytext -k moscot --execute "$nb" || {
51
+ echo "Failed to run $nb"
52
+ all_success=false
53
+ }
54
+ done
55
+
56
+ # Check if any executions failed
57
+ if [ "$all_success" = false ]; then
58
+ echo "One or more notebooks failed to execute."
59
+ exit 1
60
+ fi
61
+
62
+ echo "All notebooks executed successfully."
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: moscot
3
- Version: 0.3.3
3
+ Version: 0.3.5
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -66,9 +66,10 @@ Requires-Dist: anndata>=0.9.1
66
66
  Requires-Dist: scanpy>=1.9.3
67
67
  Requires-Dist: wrapt>=1.13.2
68
68
  Requires-Dist: docrep>=0.3.2
69
- Requires-Dist: ott-jax>=0.4.3
69
+ Requires-Dist: ott-jax>=0.4.6
70
70
  Requires-Dist: cloudpickle>=2.2.0
71
71
  Requires-Dist: rich>=13.5
72
+ Requires-Dist: docstring_inheritance>=2.0.0
72
73
  Provides-Extra: spatial
73
74
  Requires-Dist: squidpy>=1.2.3; extra == "spatial"
74
75
  Provides-Extra: dev
@@ -55,9 +55,10 @@ dependencies = [
55
55
  "scanpy>=1.9.3",
56
56
  "wrapt>=1.13.2",
57
57
  "docrep>=0.3.2",
58
- "ott-jax>=0.4.3",
58
+ "ott-jax>=0.4.6",
59
59
  "cloudpickle>=2.2.0",
60
60
  "rich>=13.5",
61
+ "docstring_inheritance>=2.0.0"
61
62
  ]
62
63
 
63
64
  [project.optional-dependencies]
@@ -122,6 +123,8 @@ ignore = [
122
123
  "D107",
123
124
  # Missing docstring in magic method
124
125
  "D105",
126
+ # Use `X | Y` for type annotations
127
+ "UP007",
125
128
  ]
126
129
  line-length = 120
127
130
  select = [
@@ -280,7 +283,6 @@ commands =
280
283
 
281
284
  [testenv:lint-docs]
282
285
  description = Lint the documentation.
283
- deps =
284
286
  extras = docs
285
287
  ignore_errors = true
286
288
  allowlist_externals = make
@@ -292,6 +294,21 @@ commands =
292
294
  # TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490
293
295
  # make spelling {posargs}
294
296
 
297
+ [testenv:examples-docs]
298
+ allowlist_externals = bash
299
+ description = Run the notebooks.
300
+ use_develop = true
301
+ deps =
302
+ ipykernel
303
+ jupytext
304
+ nbconvert
305
+ leidenalg
306
+ extras = docs
307
+ changedir = {tox_root}{/}docs
308
+ commands =
309
+ python -m ipykernel install --user --name=moscot
310
+ bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks
311
+
295
312
  [testenv:clean-docs]
296
313
  description = Remove the documentation.
297
314
  deps =
@@ -8,7 +8,7 @@ import numpy as np
8
8
  try:
9
9
  from numpy.typing import DTypeLike, NDArray
10
10
 
11
- ArrayLike = NDArray[np.float_]
11
+ ArrayLike = NDArray[np.float64]
12
12
  except (ImportError, TypeError):
13
13
  ArrayLike = np.ndarray # type: ignore[misc]
14
14
  DTypeLike = np.dtype # type: ignore[misc]
@@ -33,20 +33,16 @@ OttCostFn_t = Literal[
33
33
  "euclidean",
34
34
  "sq_euclidean",
35
35
  "cosine",
36
- "PNormP",
37
- "SqPNorm",
38
- "Euclidean",
39
- "SqEuclidean",
40
- "Cosine",
41
- "ElasticL1",
42
- "ElasticSTVS",
43
- "ElasticSqKOverlap",
36
+ "pnorm_p",
37
+ "sq_pnorm",
38
+ "cosine",
39
+ "geodesic",
44
40
  ]
45
41
  OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
46
42
  GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
47
43
  CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
48
44
  CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
49
- PathLike = Union[os.PathLike, str]
45
+ PathLike = Union[os.PathLike, str] # type: ignore[type-arg]
50
46
  Policy_t = Literal[
51
47
  "sequential",
52
48
  "star",
@@ -1,16 +1,14 @@
1
1
  from ott.geometry import costs
2
2
 
3
3
  from moscot.backends.ott._utils import sinkhorn_divergence
4
- from moscot.backends.ott.output import OTTOutput
4
+ from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
5
5
  from moscot.backends.ott.solver import GWSolver, SinkhornSolver
6
6
  from moscot.costs import register_cost
7
7
 
8
- __all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
8
+ __all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
9
9
 
10
10
  register_cost("euclidean", backend="ott")(costs.Euclidean)
11
11
  register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
12
12
  register_cost("cosine", backend="ott")(costs.Cosine)
13
13
  register_cost("pnorm_p", backend="ott")(costs.PNormP)
14
14
  register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
15
- register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
16
- register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
@@ -1,14 +1,18 @@
1
- from typing import Any, Optional, Union
1
+ from typing import Any, Literal, Optional, Tuple, Union
2
2
 
3
3
  import jax
4
+ import jax.experimental.sparse as jesp
4
5
  import jax.numpy as jnp
5
6
  import scipy.sparse as sp
6
- from ott.geometry import epsilon_scheduler, geometry, pointcloud
7
+ from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
7
8
  from ott.tools import sinkhorn_divergence as sdiv
8
9
 
9
10
  from moscot._logging import logger
10
11
  from moscot._types import ArrayLike, ScaleCost_t
11
12
 
13
+ Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
14
+
15
+
12
16
  __all__ = ["sinkhorn_divergence"]
13
17
 
14
18
 
@@ -64,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
64
68
  return (1 - alpha) / alpha
65
69
 
66
70
 
71
+ def densify(arr: ArrayLike) -> jax.Array:
72
+ """If the input is sparse, convert it to dense.
73
+
74
+ Parameters
75
+ ----------
76
+ arr
77
+ Array to check.
78
+
79
+ Returns
80
+ -------
81
+ dense :mod:`jax` array.
82
+ """
83
+ if sp.issparse(arr):
84
+ arr = arr.toarray() # type: ignore[attr-defined]
85
+ elif isinstance(arr, jesp.BCOO):
86
+ arr = arr.todense()
87
+ return jnp.asarray(arr)
88
+
89
+
67
90
  def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
68
91
  """Ensure that an array is 2-dimensional.
69
92
 
@@ -78,11 +101,35 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
78
101
  -------
79
102
  2-dimensional :mod:`jax` array.
80
103
  """
81
- if sp.issparse(arr):
82
- arr = arr.A # type: ignore[attr-defined]
83
- arr = jnp.asarray(arr)
84
104
  if reshape and arr.ndim == 1:
85
105
  return jnp.reshape(arr, (-1, 1))
86
106
  if arr.ndim != 2:
87
107
  raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
88
108
  return arr
109
+
110
+
111
+ def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
112
+ """If the input is a scipy sparse matrix, convert it to a jax BCOO."""
113
+ if sp.issparse(arr):
114
+ return jesp.BCOO.from_scipy_sparse(arr)
115
+ return arr
116
+
117
+
118
+ def _instantiate_geodesic_cost(
119
+ arr: jax.Array,
120
+ problem_shape: Tuple[int, int],
121
+ t: Optional[float],
122
+ is_linear_term: bool,
123
+ epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
124
+ relative_epsilon: Optional[bool] = None,
125
+ scale_cost: Scale_t = 1.0,
126
+ directed: bool = True,
127
+ **kwargs: Any,
128
+ ) -> geometry.Geometry:
129
+ n_src, n_tgt = problem_shape
130
+ if is_linear_term and n_src + n_tgt != arr.shape[0]:
131
+ raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
132
+ t = epsilon / 4.0 if t is None else t
133
+ cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
134
+ cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
135
+ return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
14
14
  from moscot._types import ArrayLike, Device_t
15
15
  from moscot.base.output import BaseSolverOutput
16
16
 
17
- __all__ = ["OTTOutput"]
17
+ __all__ = ["OTTOutput", "GraphOTTOutput"]
18
18
 
19
19
 
20
20
  class OTTOutput(BaseSolverOutput):
@@ -174,7 +174,10 @@ class OTTOutput(BaseSolverOutput):
174
174
  def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
175
175
  if x.ndim == 1:
176
176
  return self._output.apply(x, axis=1 - forward)
177
- return self._output.apply(x.T, axis=1 - forward).T # convert to batch first
177
+ return self._output.apply(
178
+ x.T,
179
+ axis=1 - forward,
180
+ ).T # convert to batch first
178
181
 
179
182
  @property
180
183
  def shape(self) -> Tuple[int, int]: # noqa: D102
@@ -233,3 +236,63 @@ class OTTOutput(BaseSolverOutput):
233
236
 
234
237
  def _ones(self, n: int) -> ArrayLike: # noqa: D102
235
238
  return jnp.ones((n,))
239
+
240
+
241
+ class GraphOTTOutput(OTTOutput):
242
+ """Output of :term:`OT` problems with a graph geometry in the linear term.
243
+
244
+ Parameters
245
+ ----------
246
+ output
247
+ Output of the :mod:`ott` backend.
248
+ shape
249
+ Shape of the problem.
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ output: Union[
255
+ sinkhorn.SinkhornOutput,
256
+ sinkhorn_lr.LRSinkhornOutput,
257
+ gromov_wasserstein.GWOutput,
258
+ gromov_wasserstein_lr.LRGWOutput,
259
+ ],
260
+ shape: Tuple[int, int],
261
+ ):
262
+ super().__init__(output)
263
+ self._shape = shape
264
+
265
+ @property
266
+ def shape(self) -> Tuple[int, int]: # noqa: D102
267
+ return self._shape
268
+
269
+ def _expand_data(self, x: jnp.ndarray, forward: bool) -> jnp.ndarray:
270
+ if forward:
271
+ shape = (self.shape[1],) if x.ndim == 1 else (self.shape[1], x.shape[1])
272
+ return jnp.concatenate((x, jnp.zeros(shape)))
273
+ shape = (self.shape[0],) if x.ndim == 1 else (self.shape[0], x.shape[1])
274
+ return jnp.concatenate((jnp.zeros(shape), x))
275
+
276
+ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
277
+ x_expanded = self._expand_data(x, forward=forward)
278
+ # ott-jax only supports lse_mode=False with graph geometry
279
+ res = self._output.apply(x_expanded.T, axis=1 - forward, lse_mode=False).T
280
+ return res[len(x) :] if forward else res[: -len(x)]
281
+
282
+ def to(self, device: Optional[Device_t] = None) -> "GraphOTTOutput": # noqa: D102
283
+ if device is None:
284
+ return GraphOTTOutput(jax.device_put(self._output, device=device), shape=self.shape)
285
+
286
+ if isinstance(device, str) and ":" in device:
287
+ device, ix = device.split(":")
288
+ idx = int(ix)
289
+ else:
290
+ idx = 0
291
+
292
+ if not isinstance(device, xla_ext.Device):
293
+ try:
294
+ device = jax.devices(device)[idx]
295
+ except IndexError:
296
+ raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None
297
+
298
+ return GraphOTTOutput(jax.device_put(self._output, device), shape=self.shape)