moscot 0.3.3__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {moscot-0.3.3 → moscot-0.3.5}/.gitignore +1 -0
- {moscot-0.3.3 → moscot-0.3.5}/.pre-commit-config.yaml +9 -9
- moscot-0.3.5/.run_notebooks.sh +62 -0
- {moscot-0.3.3 → moscot-0.3.5}/PKG-INFO +3 -2
- {moscot-0.3.3 → moscot-0.3.5}/pyproject.toml +19 -2
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_types.py +6 -10
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/__init__.py +2 -4
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/_utils.py +52 -5
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/output.py +65 -2
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/ott/solver.py +132 -15
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/_mixins.py +199 -31
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/_utils.py +19 -1
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/birth_death.py +43 -14
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/compound_problem.py +72 -38
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/problem.py +269 -58
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/solver.py +0 -1
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/costs/_costs.py +5 -2
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/costs/_utils.py +15 -4
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/datasets.py +51 -7
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/plotting/_plotting.py +11 -11
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/plotting/_utils.py +9 -6
- moscot-0.3.5/src/moscot/problems/_utils.py +126 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/cross_modality/_mixins.py +62 -5
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/cross_modality/_translation.py +54 -25
- moscot-0.3.5/src/moscot/problems/generic/__init__.py +4 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/generic/_generic.py +332 -52
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/generic/_mixins.py +10 -3
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/_alignment.py +66 -21
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/_mapping.py +85 -37
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/_mixins.py +207 -30
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +28 -9
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/time/_lineage.py +73 -22
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/time/_mixins.py +98 -47
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/mouse_proliferation.txt +0 -1
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/subset_policy.py +7 -12
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/tagged_array.py +33 -6
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/PKG-INFO +3 -2
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/SOURCES.txt +1 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/requires.txt +2 -1
- moscot-0.3.3/src/moscot/problems/_utils.py +0 -87
- moscot-0.3.3/src/moscot/problems/generic/__init__.py +0 -4
- {moscot-0.3.3 → moscot-0.3.5}/.gitmodules +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/.readthedocs.yml +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/LICENSE +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/MANIFEST.in +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/README.rst +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/codecov.yml +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/setup.cfg +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_constants.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_logging.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/_registry.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/backends/utils.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/cost.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/output.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/base/problems/manager.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/py.typed +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot/utils/data.py +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.3.3 → moscot-0.3.5}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -7,29 +7,29 @@ default_stages:
|
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.10.1
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
13
|
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev:
|
|
16
|
+
rev: 24.4.2
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
20
20
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
|
21
|
-
rev:
|
|
21
|
+
rev: v4.0.0-alpha.8
|
|
22
22
|
hooks:
|
|
23
23
|
- id: prettier
|
|
24
24
|
language_version: system
|
|
25
25
|
- repo: https://github.com/PyCQA/isort
|
|
26
|
-
rev: 5.
|
|
26
|
+
rev: 5.13.2
|
|
27
27
|
hooks:
|
|
28
28
|
- id: isort
|
|
29
29
|
additional_dependencies: [toml]
|
|
30
30
|
args: [--order-by-type]
|
|
31
31
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
32
|
-
rev: v4.
|
|
32
|
+
rev: v4.6.0
|
|
33
33
|
hooks:
|
|
34
34
|
- id: check-merge-conflict
|
|
35
35
|
- id: check-ast
|
|
@@ -42,17 +42,17 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.
|
|
45
|
+
rev: v3.16.0
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
49
49
|
- repo: https://github.com/asottile/blacken-docs
|
|
50
|
-
rev: 1.
|
|
50
|
+
rev: 1.18.0
|
|
51
51
|
hooks:
|
|
52
52
|
- id: blacken-docs
|
|
53
53
|
additional_dependencies: [black==23.1.0]
|
|
54
54
|
- repo: https://github.com/rstcheck/rstcheck
|
|
55
|
-
rev: v6.
|
|
55
|
+
rev: v6.2.0
|
|
56
56
|
hooks:
|
|
57
57
|
- id: rstcheck
|
|
58
58
|
additional_dependencies: [tomli]
|
|
@@ -63,7 +63,7 @@ repos:
|
|
|
63
63
|
- id: doc8
|
|
64
64
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.0
|
|
66
|
+
rev: v0.5.0
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
#!/bin/bash
#
# Execute the example notebooks found under <base_notebook_directory> with
# jupytext, using the Jupyter kernel named "moscot". Exits with status 1 when
# no notebooks are found or when at least one notebook fails to execute.

# Check if the base directory is provided as an argument
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <base_notebook_directory>"
    exit 1
fi

# Base directory for notebooks
base_dir=$1

# Sub-directories of $base_dir whose *.ipynb files are executed
declare -a notebook_dirs=(
    "examples/plotting"
    "examples/problems"
    "examples/solvers"
)

# Initialize an array to hold valid notebook paths
declare -a valid_notebooks=()

# A glob that matches nothing expands to nothing instead of to itself.
shopt -s nullglob

# Gather all valid notebook files. The directory part of each glob is quoted so
# that a base directory containing spaces is not split into several words; only
# the trailing `*.ipynb` is left unquoted so that it is still expanded.
echo "Gathering notebooks..."
for dir in "${notebook_dirs[@]}"; do
    for nb in "$base_dir/$dir"/*.ipynb; do
        if [[ -f "$nb" ]]; then # Skip anything that is not a regular file
            valid_notebooks+=("$nb")
        fi
    done
done

shopt -u nullglob

# Check if we have any notebooks to run
if [ ${#valid_notebooks[@]} -eq 0 ]; then
    echo "No notebooks found to run."
    exit 1
fi

# Echo the notebooks that will be run for clarity
echo "Preparing to run the following notebooks:"
for nb in "${valid_notebooks[@]}"; do
    echo "$nb"
done

# Track whether every notebook executed successfully; keep going after a
# failure so that all broken notebooks are reported in a single run.
all_success=true

# Execute all valid notebooks
for nb in "${valid_notebooks[@]}"; do
    echo "Running $nb"
    jupytext -k moscot --execute "$nb" || {
        echo "Failed to run $nb"
        all_success=false
    }
done

# Check if any executions failed
if [ "$all_success" = false ]; then
    echo "One or more notebooks failed to execute."
    exit 1
fi

echo "All notebooks executed successfully."
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.5
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -66,9 +66,10 @@ Requires-Dist: anndata>=0.9.1
|
|
|
66
66
|
Requires-Dist: scanpy>=1.9.3
|
|
67
67
|
Requires-Dist: wrapt>=1.13.2
|
|
68
68
|
Requires-Dist: docrep>=0.3.2
|
|
69
|
-
Requires-Dist: ott-jax>=0.4.
|
|
69
|
+
Requires-Dist: ott-jax>=0.4.6
|
|
70
70
|
Requires-Dist: cloudpickle>=2.2.0
|
|
71
71
|
Requires-Dist: rich>=13.5
|
|
72
|
+
Requires-Dist: docstring_inheritance>=2.0.0
|
|
72
73
|
Provides-Extra: spatial
|
|
73
74
|
Requires-Dist: squidpy>=1.2.3; extra == "spatial"
|
|
74
75
|
Provides-Extra: dev
|
|
@@ -55,9 +55,10 @@ dependencies = [
|
|
|
55
55
|
"scanpy>=1.9.3",
|
|
56
56
|
"wrapt>=1.13.2",
|
|
57
57
|
"docrep>=0.3.2",
|
|
58
|
-
"ott-jax>=0.4.
|
|
58
|
+
"ott-jax>=0.4.6",
|
|
59
59
|
"cloudpickle>=2.2.0",
|
|
60
60
|
"rich>=13.5",
|
|
61
|
+
"docstring_inheritance>=2.0.0"
|
|
61
62
|
]
|
|
62
63
|
|
|
63
64
|
[project.optional-dependencies]
|
|
@@ -122,6 +123,8 @@ ignore = [
|
|
|
122
123
|
"D107",
|
|
123
124
|
# Missing docstring in magic method
|
|
124
125
|
"D105",
|
|
126
|
+
# Use `X | Y` for type annotations
|
|
127
|
+
"UP007",
|
|
125
128
|
]
|
|
126
129
|
line-length = 120
|
|
127
130
|
select = [
|
|
@@ -280,7 +283,6 @@ commands =
|
|
|
280
283
|
|
|
281
284
|
[testenv:lint-docs]
|
|
282
285
|
description = Lint the documentation.
|
|
283
|
-
deps =
|
|
284
286
|
extras = docs
|
|
285
287
|
ignore_errors = true
|
|
286
288
|
allowlist_externals = make
|
|
@@ -292,6 +294,21 @@ commands =
|
|
|
292
294
|
# TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490
|
|
293
295
|
# make spelling {posargs}
|
|
294
296
|
|
|
297
|
+
[testenv:examples-docs]
|
|
298
|
+
allowlist_externals = bash
|
|
299
|
+
description = Run the notebooks.
|
|
300
|
+
use_develop = true
|
|
301
|
+
deps =
|
|
302
|
+
ipykernel
|
|
303
|
+
jupytext
|
|
304
|
+
nbconvert
|
|
305
|
+
leidenalg
|
|
306
|
+
extras = docs
|
|
307
|
+
changedir = {tox_root}{/}docs
|
|
308
|
+
commands =
|
|
309
|
+
python -m ipykernel install --user --name=moscot
|
|
310
|
+
bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks
|
|
311
|
+
|
|
295
312
|
[testenv:clean-docs]
|
|
296
313
|
description = Remove the documentation.
|
|
297
314
|
deps =
|
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
|
8
8
|
try:
|
|
9
9
|
from numpy.typing import DTypeLike, NDArray
|
|
10
10
|
|
|
11
|
-
ArrayLike = NDArray[np.
|
|
11
|
+
ArrayLike = NDArray[np.float64]
|
|
12
12
|
except (ImportError, TypeError):
|
|
13
13
|
ArrayLike = np.ndarray # type: ignore[misc]
|
|
14
14
|
DTypeLike = np.dtype # type: ignore[misc]
|
|
@@ -33,20 +33,16 @@ OttCostFn_t = Literal[
|
|
|
33
33
|
"euclidean",
|
|
34
34
|
"sq_euclidean",
|
|
35
35
|
"cosine",
|
|
36
|
-
"
|
|
37
|
-
"
|
|
38
|
-
"
|
|
39
|
-
"
|
|
40
|
-
"Cosine",
|
|
41
|
-
"ElasticL1",
|
|
42
|
-
"ElasticSTVS",
|
|
43
|
-
"ElasticSqKOverlap",
|
|
36
|
+
"pnorm_p",
|
|
37
|
+
"sq_pnorm",
|
|
38
|
+
"cosine",
|
|
39
|
+
"geodesic",
|
|
44
40
|
]
|
|
45
41
|
OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
|
|
46
42
|
GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
|
|
47
43
|
CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
|
|
48
44
|
CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
|
|
49
|
-
PathLike = Union[os.PathLike, str]
|
|
45
|
+
PathLike = Union[os.PathLike, str] # type: ignore[type-arg]
|
|
50
46
|
Policy_t = Literal[
|
|
51
47
|
"sequential",
|
|
52
48
|
"star",
|
|
@@ -1,16 +1,14 @@
|
|
|
1
1
|
from ott.geometry import costs
|
|
2
2
|
|
|
3
3
|
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
4
|
-
from moscot.backends.ott.output import OTTOutput
|
|
4
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
5
5
|
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
|
|
6
6
|
from moscot.costs import register_cost
|
|
7
7
|
|
|
8
|
-
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
8
|
+
__all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
9
9
|
|
|
10
10
|
register_cost("euclidean", backend="ott")(costs.Euclidean)
|
|
11
11
|
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
|
|
12
12
|
register_cost("cosine", backend="ott")(costs.Cosine)
|
|
13
13
|
register_cost("pnorm_p", backend="ott")(costs.PNormP)
|
|
14
14
|
register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
|
|
15
|
-
register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
|
|
16
|
-
register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
|
|
@@ -1,14 +1,18 @@
|
|
|
1
|
-
from typing import Any, Optional, Union
|
|
1
|
+
from typing import Any, Literal, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
|
+
import jax.experimental.sparse as jesp
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import scipy.sparse as sp
|
|
6
|
-
from ott.geometry import epsilon_scheduler, geometry, pointcloud
|
|
7
|
+
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
|
|
7
8
|
from ott.tools import sinkhorn_divergence as sdiv
|
|
8
9
|
|
|
9
10
|
from moscot._logging import logger
|
|
10
11
|
from moscot._types import ArrayLike, ScaleCost_t
|
|
11
12
|
|
|
13
|
+
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
|
|
14
|
+
|
|
15
|
+
|
|
12
16
|
__all__ = ["sinkhorn_divergence"]
|
|
13
17
|
|
|
14
18
|
|
|
@@ -64,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
|
|
|
64
68
|
return (1 - alpha) / alpha
|
|
65
69
|
|
|
66
70
|
|
|
71
|
+
def densify(arr: ArrayLike) -> jax.Array:
    """Materialize a possibly sparse array as a dense :mod:`jax` array.

    Parameters
    ----------
    arr
        Array to convert. :mod:`scipy.sparse` matrices are converted with ``toarray()``,
        :class:`jax.experimental.sparse.BCOO` arrays with ``todense()``; any other input
        is handed to :func:`jax.numpy.asarray` unchanged.

    Returns
    -------
    dense :mod:`jax` array.
    """
    if isinstance(arr, jesp.BCOO):
        return jnp.asarray(arr.todense())
    if sp.issparse(arr):
        return jnp.asarray(arr.toarray())  # type: ignore[attr-defined]
    return jnp.asarray(arr)
|
|
88
|
+
|
|
89
|
+
|
|
67
90
|
def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
68
91
|
"""Ensure that an array is 2-dimensional.
|
|
69
92
|
|
|
@@ -78,11 +101,35 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
|
78
101
|
-------
|
|
79
102
|
2-dimensional :mod:`jax` array.
|
|
80
103
|
"""
|
|
81
|
-
if sp.issparse(arr):
|
|
82
|
-
arr = arr.A # type: ignore[attr-defined]
|
|
83
|
-
arr = jnp.asarray(arr)
|
|
84
104
|
if reshape and arr.ndim == 1:
|
|
85
105
|
return jnp.reshape(arr, (-1, 1))
|
|
86
106
|
if arr.ndim != 2:
|
|
87
107
|
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
|
|
88
108
|
return arr
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
    """Turn a :mod:`scipy.sparse` matrix into a :class:`jax.experimental.sparse.BCOO` array.

    Anything that is not a :mod:`scipy.sparse` matrix (e.g. an existing ``BCOO``) is returned as-is.
    """
    return jesp.BCOO.from_scipy_sparse(arr) if sp.issparse(arr) else arr
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _instantiate_geodesic_cost(
    arr: jax.Array,
    problem_shape: Tuple[int, int],
    t: Optional[float],
    is_linear_term: bool,
    epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
    relative_epsilon: Optional[bool] = None,
    scale_cost: Scale_t = 1.0,
    directed: bool = True,
    **kwargs: Any,
) -> geometry.Geometry:
    """Build a geometry whose cost matrix comes from geodesic distances on a graph.

    Parameters
    ----------
    arr
        Graph passed to :meth:`ott.geometry.geodesic.Geodesic.from_graph`. When ``is_linear_term``
        is `True`, its first dimension must be ``n_src + n_tgt``.
    problem_shape
        ``(n_src, n_tgt)``.
    t
        Forwarded as ``t`` to :meth:`~ott.geometry.geodesic.Geodesic.from_graph`.
        If `None`, ``epsilon / 4`` is used, which requires ``epsilon`` to be set.
    is_linear_term
        If `True`, return only the ``[:n_src, n_src:]`` (source-to-target) block of the full
        geodesic cost matrix; otherwise return the full matrix.
    epsilon
        Forwarded to :class:`ott.geometry.geometry.Geometry`.
    relative_epsilon
        Forwarded to :class:`ott.geometry.geometry.Geometry`.
    scale_cost
        Forwarded to :class:`ott.geometry.geometry.Geometry`.
    directed
        Forwarded to :meth:`~ott.geometry.geodesic.Geodesic.from_graph`.
    kwargs
        Additional keyword arguments for :meth:`~ott.geometry.geodesic.Geodesic.from_graph`.

    Returns
    -------
    Geometry wrapping the (possibly sliced) geodesic cost matrix.

    Raises
    ------
    ValueError
        If ``is_linear_term`` is `True` and ``arr`` does not have ``n_src + n_tgt`` rows,
        or if both ``t`` and ``epsilon`` are `None`.
    """
    n_src, n_tgt = problem_shape
    if is_linear_term and n_src + n_tgt != arr.shape[0]:
        raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
    if t is None:
        # `t` defaults to `epsilon / 4`; with the default `epsilon=None` this would otherwise
        # fail with an opaque `TypeError` about dividing `NoneType` by `float`.
        if epsilon is None:
            raise ValueError("Expected `epsilon` to be specified when `t` is `None`, since `t` defaults to `epsilon / 4`.")
        t = epsilon / 4.0
    cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
    # For the linear term the graph spans source and target nodes jointly; keep only the
    # source-rows x target-columns block.
    cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
    return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
|
|
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
|
|
|
14
14
|
from moscot._types import ArrayLike, Device_t
|
|
15
15
|
from moscot.base.output import BaseSolverOutput
|
|
16
16
|
|
|
17
|
-
__all__ = ["OTTOutput"]
|
|
17
|
+
__all__ = ["OTTOutput", "GraphOTTOutput"]
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class OTTOutput(BaseSolverOutput):
|
|
@@ -174,7 +174,10 @@ class OTTOutput(BaseSolverOutput):
|
|
|
174
174
|
def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
|
|
175
175
|
if x.ndim == 1:
|
|
176
176
|
return self._output.apply(x, axis=1 - forward)
|
|
177
|
-
return self._output.apply(
|
|
177
|
+
return self._output.apply(
|
|
178
|
+
x.T,
|
|
179
|
+
axis=1 - forward,
|
|
180
|
+
).T # convert to batch first
|
|
178
181
|
|
|
179
182
|
@property
|
|
180
183
|
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
@@ -233,3 +236,63 @@ class OTTOutput(BaseSolverOutput):
|
|
|
233
236
|
|
|
234
237
|
def _ones(self, n: int) -> ArrayLike: # noqa: D102
|
|
235
238
|
return jnp.ones((n,))
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class GraphOTTOutput(OTTOutput):
    """Output of :term:`OT` problems with a graph geometry in the linear term.

    :meth:`_apply` zero-pads its input with entries for the other side before calling the
    wrapped output's ``apply`` (``shape[1]`` entries appended when pushing forward, ``shape[0]``
    entries prepended when pulling back). It then returns only the part of the result that
    corresponds to the opposite side.

    Parameters
    ----------
    output
        Output of the :mod:`ott` backend.
    shape
        Shape of the problem.
    """

    def __init__(
        self,
        output: Union[
            sinkhorn.SinkhornOutput,
            sinkhorn_lr.LRSinkhornOutput,
            gromov_wasserstein.GWOutput,
            gromov_wasserstein_lr.LRGWOutput,
        ],
        shape: Tuple[int, int],
    ):
        super().__init__(output)
        # Kept explicitly and returned by `shape`, overriding the parent's property.
        self._shape = shape

    @property
    def shape(self) -> Tuple[int, int]:  # noqa: D102
        return self._shape

    def _expand_data(self, x: jnp.ndarray, forward: bool) -> jnp.ndarray:
        # Forward: append `shape[1]` zero rows after `x`.
        # Backward: prepend `shape[0]` zero rows before `x`.
        n_pad = self.shape[1] if forward else self.shape[0]
        padding = jnp.zeros((n_pad,) if x.ndim == 1 else (n_pad, x.shape[1]))
        parts = (x, padding) if forward else (padding, x)
        return jnp.concatenate(parts)

    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        n_in = len(x)
        padded = self._expand_data(x, forward=forward)
        # ott-jax only supports lse_mode=False with graph geometry
        result = self._output.apply(padded.T, axis=1 - forward, lse_mode=False).T
        # Drop the rows that belong to the side `x` came from.
        if forward:
            return result[n_in:]
        return result[:-n_in]

    def to(self, device: Optional[Device_t] = None) -> "GraphOTTOutput":  # noqa: D102
        if device is None:
            return GraphOTTOutput(jax.device_put(self._output, device=device), shape=self.shape)

        # A string such as "gpu:1" selects the platform and the device index.
        idx = 0
        if isinstance(device, str) and ":" in device:
            device, ix = device.split(":")
            idx = int(ix)

        if not isinstance(device, xla_ext.Device):
            try:
                device = jax.devices(device)[idx]
            except IndexError:
                raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None

        return GraphOTTOutput(jax.device_put(self._output, device), shape=self.shape)
|