moscot 0.3.4__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {moscot-0.3.4 → moscot-0.4.0}/.gitignore +1 -0
- {moscot-0.3.4 → moscot-0.4.0}/.pre-commit-config.yaml +10 -10
- moscot-0.4.0/.run_notebooks.sh +62 -0
- {moscot-0.3.4 → moscot-0.4.0}/PKG-INFO +10 -5
- {moscot-0.3.4 → moscot-0.4.0}/pyproject.toml +30 -8
- moscot-0.4.0/src/moscot/__init__.py +13 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_types.py +9 -9
- moscot-0.4.0/src/moscot/backends/ott/__init__.py +15 -0
- moscot-0.4.0/src/moscot/backends/ott/_utils.py +331 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/ott/output.py +247 -4
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/ott/solver.py +282 -29
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/utils.py +19 -7
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/output.py +99 -40
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/__init__.py +2 -1
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/_mixins.py +135 -224
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/_utils.py +11 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/birth_death.py +43 -53
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/compound_problem.py +52 -17
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/manager.py +4 -4
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/problems/problem.py +367 -44
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/solver.py +9 -6
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/datasets.py +111 -46
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/plotting/_utils.py +2 -3
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/__init__.py +2 -0
- moscot-0.4.0/src/moscot/problems/_utils.py +203 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/cross_modality/_mixins.py +17 -27
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/cross_modality/_translation.py +53 -24
- moscot-0.4.0/src/moscot/problems/generic/__init__.py +14 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/generic/_generic.py +189 -43
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/generic/_mixins.py +11 -27
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/_alignment.py +65 -24
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/_mapping.py +82 -33
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/_mixins.py +131 -110
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +23 -9
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/time/_lineage.py +74 -27
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/time/_mixins.py +45 -169
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/subset_policy.py +5 -0
- moscot-0.4.0/src/moscot/utils/tagged_array.py +373 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/PKG-INFO +10 -5
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/SOURCES.txt +1 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/requires.txt +9 -2
- moscot-0.3.4/src/moscot/__init__.py +0 -13
- moscot-0.3.4/src/moscot/backends/ott/__init__.py +0 -18
- moscot-0.3.4/src/moscot/backends/ott/_utils.py +0 -111
- moscot-0.3.4/src/moscot/problems/_utils.py +0 -85
- moscot-0.3.4/src/moscot/problems/generic/__init__.py +0 -4
- moscot-0.3.4/src/moscot/utils/tagged_array.py +0 -187
- {moscot-0.3.4 → moscot-0.4.0}/.gitmodules +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/.readthedocs.yml +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/LICENSE +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/MANIFEST.in +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/README.rst +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/codecov.yml +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/setup.cfg +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_constants.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_logging.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/_registry.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/base/cost.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/costs/_costs.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/costs/_utils.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/plotting/_plotting.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/py.typed +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot/utils/data.py +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.3.4 → moscot-0.4.0}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -2,18 +2,18 @@ fail_fast: false
|
|
|
2
2
|
default_language_version:
|
|
3
3
|
python: python3
|
|
4
4
|
default_stages:
|
|
5
|
-
- commit
|
|
6
|
-
- push
|
|
5
|
+
- pre-commit
|
|
6
|
+
- pre-push
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.13.0
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
13
|
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev: 24.
|
|
16
|
+
rev: 24.10.0
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
@@ -29,7 +29,7 @@ repos:
|
|
|
29
29
|
additional_dependencies: [toml]
|
|
30
30
|
args: [--order-by-type]
|
|
31
31
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
32
|
-
rev:
|
|
32
|
+
rev: v5.0.0
|
|
33
33
|
hooks:
|
|
34
34
|
- id: check-merge-conflict
|
|
35
35
|
- id: check-ast
|
|
@@ -42,28 +42,28 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.
|
|
45
|
+
rev: v3.19.0
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
49
49
|
- repo: https://github.com/asottile/blacken-docs
|
|
50
|
-
rev: 1.
|
|
50
|
+
rev: 1.19.1
|
|
51
51
|
hooks:
|
|
52
52
|
- id: blacken-docs
|
|
53
53
|
additional_dependencies: [black==23.1.0]
|
|
54
54
|
- repo: https://github.com/rstcheck/rstcheck
|
|
55
|
-
rev: v6.2.
|
|
55
|
+
rev: v6.2.4
|
|
56
56
|
hooks:
|
|
57
57
|
- id: rstcheck
|
|
58
58
|
additional_dependencies: [tomli]
|
|
59
59
|
args: [--config=pyproject.toml]
|
|
60
60
|
- repo: https://github.com/PyCQA/doc8
|
|
61
|
-
rev: v1.1.
|
|
61
|
+
rev: v1.1.2
|
|
62
62
|
hooks:
|
|
63
63
|
- id: doc8
|
|
64
64
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.
|
|
66
|
+
rev: v0.7.2
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
#!/bin/bash

# Run every example notebook found below a base directory with jupytext,
# using the `moscot` Jupyter kernel. Exits with status 1 on wrong usage, when
# no notebook is found, or when at least one notebook fails to execute.

# Check if the base directory is provided as an argument
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <base_notebook_directory>"
    exit 1
fi

# Base directory for notebooks
base_dir=$1

# Directories that are searched for `*.ipynb` files
declare -a notebook_dirs=(
    "$base_dir/examples/plotting"
    "$base_dir/examples/problems"
    "$base_dir/examples/solvers"
)

# Initialize an array to hold valid notebook paths
declare -a valid_notebooks=()

# Gather all valid notebook files from the directories
echo "Gathering notebooks..."
for dir in "${notebook_dirs[@]}"; do
    # Only the `*.ipynb` suffix is left unquoted, so a base directory that
    # contains whitespace or glob characters is taken literally instead of
    # being word-split or expanded.
    for nb in "$dir"/*.ipynb; do
        # A glob without matches stays a literal pattern; `-f` filters it out.
        if [[ -f "$nb" ]]; then
            valid_notebooks+=("$nb")
        fi
    done
done

# Check if we have any notebooks to run
if [ "${#valid_notebooks[@]}" -eq 0 ]; then
    echo "No notebooks found to run."
    exit 1
fi

# Echo the notebooks that will be run for clarity
echo "Preparing to run the following notebooks:"
for nb in "${valid_notebooks[@]}"; do
    echo "$nb"
done

# Initialize a flag to track the success of all commands
all_success=true

# Execute all valid notebooks; keep going after a failure so that every
# broken notebook is reported in a single run
for nb in "${valid_notebooks[@]}"; do
    echo "Running $nb"
    jupytext -k moscot --execute "$nb" || {
        echo "Failed to run $nb"
        all_success=false
    }
done

# Check if any executions failed
if [ "$all_success" = false ]; then
    echo "One or more notebooks failed to execute."
    exit 1
fi

echo "All notebooks executed successfully."
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -49,28 +49,33 @@ Classifier: Operating System :: MacOS :: MacOS X
|
|
|
49
49
|
Classifier: Operating System :: Microsoft :: Windows
|
|
50
50
|
Classifier: Typing :: Typed
|
|
51
51
|
Classifier: Programming Language :: Python :: 3
|
|
52
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
53
52
|
Classifier: Programming Language :: Python :: 3.9
|
|
54
53
|
Classifier: Programming Language :: Python :: 3.10
|
|
55
54
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
56
55
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
57
|
-
Requires-Python: >=3.
|
|
56
|
+
Requires-Python: >=3.10
|
|
58
57
|
Description-Content-Type: text/x-rst
|
|
59
58
|
License-File: LICENSE
|
|
60
59
|
Requires-Dist: numpy>=1.20.0
|
|
61
60
|
Requires-Dist: scipy>=1.7.0
|
|
62
61
|
Requires-Dist: pandas>=2.0.1
|
|
63
|
-
Requires-Dist: networkx>=2
|
|
62
|
+
Requires-Dist: networkx>=3.2
|
|
64
63
|
Requires-Dist: matplotlib>=3.5.0
|
|
65
64
|
Requires-Dist: anndata>=0.9.1
|
|
66
65
|
Requires-Dist: scanpy>=1.9.3
|
|
67
66
|
Requires-Dist: wrapt>=1.13.2
|
|
68
67
|
Requires-Dist: docrep>=0.3.2
|
|
69
|
-
Requires-Dist: ott-jax>=0.
|
|
68
|
+
Requires-Dist: ott-jax[neural]>=0.5.0
|
|
70
69
|
Requires-Dist: cloudpickle>=2.2.0
|
|
71
70
|
Requires-Dist: rich>=13.5
|
|
71
|
+
Requires-Dist: docstring_inheritance>=2.0.0
|
|
72
|
+
Requires-Dist: mudata>=0.2.2
|
|
72
73
|
Provides-Extra: spatial
|
|
73
74
|
Requires-Dist: squidpy>=1.2.3; extra == "spatial"
|
|
75
|
+
Provides-Extra: neural
|
|
76
|
+
Requires-Dist: optax; extra == "neural"
|
|
77
|
+
Requires-Dist: flax; extra == "neural"
|
|
78
|
+
Requires-Dist: diffrax; extra == "neural"
|
|
74
79
|
Provides-Extra: dev
|
|
75
80
|
Requires-Dist: pre-commit>=3.0.0; extra == "dev"
|
|
76
81
|
Requires-Dist: tox>=4; extra == "dev"
|
|
@@ -7,7 +7,7 @@ name = "moscot"
|
|
|
7
7
|
dynamic = ["version"]
|
|
8
8
|
description = "Multi-omic single-cell optimal transport tools"
|
|
9
9
|
readme = "README.rst"
|
|
10
|
-
requires-python = ">=3.
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
11
|
license = {file = "LICENSE"}
|
|
12
12
|
classifiers = [
|
|
13
13
|
"Development Status :: 4 - Beta",
|
|
@@ -19,7 +19,6 @@ classifiers = [
|
|
|
19
19
|
"Operating System :: Microsoft :: Windows",
|
|
20
20
|
"Typing :: Typed",
|
|
21
21
|
"Programming Language :: Python :: 3",
|
|
22
|
-
"Programming Language :: Python :: 3.8",
|
|
23
22
|
"Programming Language :: Python :: 3.9",
|
|
24
23
|
"Programming Language :: Python :: 3.10",
|
|
25
24
|
"Topic :: Scientific/Engineering :: Bio-Informatics",
|
|
@@ -48,22 +47,31 @@ dependencies = [
|
|
|
48
47
|
"numpy>=1.20.0",
|
|
49
48
|
"scipy>=1.7.0",
|
|
50
49
|
"pandas>=2.0.1",
|
|
51
|
-
"networkx>=2
|
|
50
|
+
"networkx>=3.2",
|
|
52
51
|
# https://github.com/scverse/scanpy/issues/2411
|
|
53
52
|
"matplotlib>=3.5.0",
|
|
54
53
|
"anndata>=0.9.1",
|
|
55
54
|
"scanpy>=1.9.3",
|
|
56
55
|
"wrapt>=1.13.2",
|
|
57
56
|
"docrep>=0.3.2",
|
|
58
|
-
"ott-jax>=0.
|
|
57
|
+
"ott-jax[neural]>=0.5.0",
|
|
59
58
|
"cloudpickle>=2.2.0",
|
|
60
59
|
"rich>=13.5",
|
|
60
|
+
"docstring_inheritance>=2.0.0",
|
|
61
|
+
"mudata>=0.2.2"
|
|
61
62
|
]
|
|
62
63
|
|
|
63
64
|
[project.optional-dependencies]
|
|
64
65
|
spatial = [
|
|
65
66
|
"squidpy>=1.2.3"
|
|
66
67
|
]
|
|
68
|
+
|
|
69
|
+
neural = [
|
|
70
|
+
"optax",
|
|
71
|
+
"flax",
|
|
72
|
+
"diffrax",
|
|
73
|
+
|
|
74
|
+
]
|
|
67
75
|
dev = [
|
|
68
76
|
"pre-commit>=3.0.0",
|
|
69
77
|
"tox>=4",
|
|
@@ -225,7 +233,7 @@ ignore_roles = [
|
|
|
225
233
|
|
|
226
234
|
[tool.mypy]
|
|
227
235
|
mypy_path = "$MYPY_CONFIG_FILE_DIR/src"
|
|
228
|
-
python_version = "3.
|
|
236
|
+
python_version = "3.10"
|
|
229
237
|
plugins = "numpy.typing.mypy_plugin"
|
|
230
238
|
|
|
231
239
|
ignore_errors = false
|
|
@@ -262,16 +270,16 @@ max_line_length = 120
|
|
|
262
270
|
legacy_tox_ini = """
|
|
263
271
|
[tox]
|
|
264
272
|
min_version = 4.0
|
|
265
|
-
env_list = lint-code,py{3.
|
|
273
|
+
env_list = lint-code,py{3.10,3.11,3.12}
|
|
266
274
|
skip_missing_interpreters = true
|
|
267
275
|
|
|
268
276
|
[testenv]
|
|
269
277
|
extras = test
|
|
270
|
-
pass_env = PYTEST_*,CI
|
|
271
278
|
commands =
|
|
272
279
|
python -m pytest {tty:--color=yes} {posargs: \
|
|
273
280
|
--cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \
|
|
274
281
|
--no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered}
|
|
282
|
+
passenv = PYTEST_*,CI
|
|
275
283
|
|
|
276
284
|
[testenv:lint-code]
|
|
277
285
|
description = Lint the code.
|
|
@@ -282,7 +290,6 @@ commands =
|
|
|
282
290
|
|
|
283
291
|
[testenv:lint-docs]
|
|
284
292
|
description = Lint the documentation.
|
|
285
|
-
deps =
|
|
286
293
|
extras = docs
|
|
287
294
|
ignore_errors = true
|
|
288
295
|
allowlist_externals = make
|
|
@@ -294,6 +301,21 @@ commands =
|
|
|
294
301
|
# TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490
|
|
295
302
|
# make spelling {posargs}
|
|
296
303
|
|
|
304
|
+
[testenv:examples-docs]
|
|
305
|
+
allowlist_externals = bash
|
|
306
|
+
description = Run the notebooks.
|
|
307
|
+
use_develop = true
|
|
308
|
+
deps =
|
|
309
|
+
ipykernel
|
|
310
|
+
jupytext
|
|
311
|
+
nbconvert
|
|
312
|
+
leidenalg
|
|
313
|
+
extras = docs
|
|
314
|
+
changedir = {tox_root}{/}docs
|
|
315
|
+
commands =
|
|
316
|
+
python -m ipykernel install --user --name=moscot
|
|
317
|
+
bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks
|
|
318
|
+
|
|
297
319
|
[testenv:clean-docs]
|
|
298
320
|
description = Remove the documentation.
|
|
299
321
|
deps =
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from importlib import metadata
|
|
2
|
+
|
|
3
|
+
from moscot import backends, base, costs, datasets, plotting, problems, utils
|
|
4
|
+
|
|
5
|
+
# Populate the package dunder attributes from the installed distribution metadata.
try:
    md = metadata.metadata(__name__)
    __version__ = md.get("version", "")  # type: ignore[attr-defined]
    __author__ = md.get("Author", "")  # type: ignore[attr-defined]
    __maintainer__ = md.get("Maintainer-email", "")  # type: ignore[attr-defined]
except ImportError:
    # `importlib.metadata.PackageNotFoundError` derives from `ImportError`, so this
    # branch also runs when no distribution metadata is found for this package.
    md = None
    # Use the same empty-string fallback as the `md.get(..., "")` calls above so the
    # attributes always exist, instead of raising `AttributeError` on access.
    __version__ = __author__ = __maintainer__ = ""

# Keep the helper names out of the package namespace.
del metadata, md
|
|
@@ -2,6 +2,9 @@ import os
|
|
|
2
2
|
from typing import Any, Literal, Mapping, Optional, Sequence, Union
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
|
+
from ott.initializers.linear.initializers import SinkhornInitializer
|
|
6
|
+
from ott.initializers.linear.initializers_lr import LRInitializer
|
|
7
|
+
from ott.initializers.quadratic.initializers import BaseQuadraticInitializer
|
|
5
8
|
|
|
6
9
|
# TODO(michalk8): polish
|
|
7
10
|
|
|
@@ -17,13 +20,14 @@ ProblemKind_t = Literal["linear", "quadratic", "unknown"]
|
|
|
17
20
|
Numeric_t = Union[int, float] # type of `time_key` arguments
|
|
18
21
|
Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
|
|
19
22
|
Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
|
|
20
|
-
|
|
21
|
-
|
|
23
|
+
SinkhornInitializerTag_t = Literal["default", "gaussian", "sorting"]
|
|
24
|
+
LRInitializerTag_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
|
|
22
25
|
|
|
23
|
-
|
|
24
|
-
|
|
26
|
+
LRInitializer_t = Optional[Union[LRInitializer, LRInitializerTag_t]]
|
|
27
|
+
SinkhornInitializer_t = Optional[Union[SinkhornInitializer, SinkhornInitializerTag_t]]
|
|
28
|
+
QuadInitializer_t = Optional[Union[BaseQuadraticInitializer]]
|
|
25
29
|
|
|
26
|
-
Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t]
|
|
30
|
+
Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t, LRInitializer_t]
|
|
27
31
|
ProblemStage_t = Literal["prepared", "solved"]
|
|
28
32
|
Device_t = Union[Literal["cpu", "gpu", "tpu"], str]
|
|
29
33
|
|
|
@@ -36,10 +40,6 @@ OttCostFn_t = Literal[
|
|
|
36
40
|
"pnorm_p",
|
|
37
41
|
"sq_pnorm",
|
|
38
42
|
"cosine",
|
|
39
|
-
"elastic_l1",
|
|
40
|
-
"elastic_l2",
|
|
41
|
-
"elastic_stvs",
|
|
42
|
-
"elastic_sqk_overlap",
|
|
43
43
|
"geodesic",
|
|
44
44
|
]
|
|
45
45
|
OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from ott.geometry import costs
|
|
2
|
+
|
|
3
|
+
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
4
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
|
|
5
|
+
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
|
|
6
|
+
from moscot.costs import register_cost
|
|
7
|
+
|
|
8
|
+
# Names exported by `from moscot.backends.ott import *`.
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]


# Register the OTT cost classes for the "ott" backend under their string names.
# These calls run at import time.
register_cost("euclidean", backend="ott")(costs.Euclidean)
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
register_cost("cosine", backend="ott")(costs.Cosine)
register_cost("pnorm_p", backend="ott")(costs.PNormP)
register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import Any, Dict, Iterable, Literal, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.experimental.sparse as jesp
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import numpy as np
|
|
9
|
+
import scipy.sparse as sp
|
|
10
|
+
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
|
|
11
|
+
from ott.initializers.linear import initializers as init_lib
|
|
12
|
+
from ott.initializers.linear import initializers_lr as lr_init_lib
|
|
13
|
+
from ott.neural import datasets
|
|
14
|
+
from ott.solvers import utils as solver_utils
|
|
15
|
+
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
|
|
16
|
+
|
|
17
|
+
from moscot._logging import logger
|
|
18
|
+
from moscot._types import ArrayLike, ScaleCost_t
|
|
19
|
+
|
|
20
|
+
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
__all__ = ["sinkhorn_divergence"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class InitializerResolver:
    """Factory for the initializers used by the OTT solvers.

    Both static methods accept either a ready-made initializer instance, which
    is handed back untouched, or a string tag selecting the initializer class
    to instantiate.
    """

    @staticmethod
    def lr_from_str(
        initializer: str,
        rank: int,
        **kwargs: Any,
    ) -> lr_init_lib.LRInitializer:
        """Resolve a low-rank initializer.

        Parameters
        ----------
        initializer : str
            An ``LRInitializer`` instance (returned as is) or one of the tags
            ``'k-means'``, ``'generalized-k-means'``, ``'random'``, ``'rank2'``.
        rank : int
            Rank forwarded to the constructor of the selected initializer.
        **kwargs : Any
            Extra keyword arguments forwarded to the constructor.

        Returns
        -------
        LRInitializer
            The passed instance or a newly created initializer.

        Raises
        ------
        NotImplementedError
            If ``initializer`` is neither an instance nor a known tag.
        """
        if isinstance(initializer, lr_init_lib.LRInitializer):
            return initializer

        # The class is looked up only for the matching tag.
        if initializer == "k-means":
            factory = lr_init_lib.KMeansInitializer
        elif initializer == "generalized-k-means":
            factory = lr_init_lib.GeneralizedKMeansInitializer
        elif initializer == "random":
            factory = lr_init_lib.RandomInitializer
        elif initializer == "rank2":
            factory = lr_init_lib.Rank2Initializer
        else:
            raise NotImplementedError(f"Initializer `{initializer}` is not implemented.")
        return factory(rank=rank, **kwargs)

    @staticmethod
    def from_str(
        initializer: str,
        **kwargs: Any,
    ) -> init_lib.SinkhornInitializer:
        """Resolve a Sinkhorn initializer.

        Parameters
        ----------
        initializer : str
            A ``SinkhornInitializer`` instance (returned as is) or one of the tags
            ``'default'``, ``'gaussian'``, ``'sorting'``, ``'subsample'``.
        **kwargs : Any
            Keyword arguments forwarded to the constructor of the selected initializer.

        Returns
        -------
        SinkhornInitializer
            The passed instance or a newly created initializer.

        Raises
        ------
        NotImplementedError
            If ``initializer`` is neither an instance nor a known tag.
        """
        if isinstance(initializer, init_lib.SinkhornInitializer):
            return initializer

        # The class is looked up only for the matching tag.
        if initializer == "default":
            factory = init_lib.DefaultInitializer
        elif initializer == "gaussian":
            factory = init_lib.GaussianInitializer
        elif initializer == "sorting":
            factory = init_lib.SortingInitializer
        elif initializer == "subsample":
            factory = init_lib.SubsampleInitializer
        else:
            raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")
        return factory(**kwargs)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def sinkhorn_divergence(
    point_cloud_1: ArrayLike,
    point_cloud_2: ArrayLike,
    a: Optional[ArrayLike] = None,
    b: Optional[ArrayLike] = None,
    epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-1,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
    scale_cost: ScaleCost_t = 1.0,
    batch_size: Optional[int] = None,
    **kwargs: Any,
) -> float:
    """Compute the Sinkhorn divergence between two point clouds.

    Parameters
    ----------
    point_cloud_1
        First point cloud, converted to a :mod:`jax` array.
    point_cloud_2
        Second point cloud, converted to a :mod:`jax` array.
    a
        Optional weights forwarded as ``a``; converted to a :mod:`jax` array if given.
    b
        Optional weights forwarded as ``b``; converted to a :mod:`jax` array if given.
    epsilon
        Forwarded as ``epsilon``.
    tau_a
        Forwarded to the solver as ``solve_kwargs['tau_a']``.
    tau_b
        Forwarded to the solver as ``solve_kwargs['tau_b']``.
    scale_cost
        Forwarded as ``scale_cost``.
    batch_size
        Forwarded as ``batch_size``.
    kwargs
        Additional keyword arguments forwarded to the underlying divergence function.

    Returns
    -------
    The divergence as a Python :class:`float`. A warning is logged for every term
    whose convergence flag is falsy.
    """
    x = jnp.asarray(point_cloud_1)
    y = jnp.asarray(point_cloud_2)
    if a is not None:
        a = jnp.asarray(a)
    if b is not None:
        b = jnp.asarray(b)

    solve_kwargs = {
        "tau_a": tau_a,
        "tau_b": tau_b,
    }
    result = sinkhorn_div(
        pointcloud.PointCloud,
        x=x,
        y=y,
        batch_size=batch_size,
        a=a,
        b=b,
        scale_cost=scale_cost,
        epsilon=epsilon,
        solve_kwargs=solve_kwargs,
        **kwargs,
    )
    # only the second entry of the result is used
    output = result[1]

    # the `y/y` flag is optional, the first two flags are required
    xy_conv, xx_conv, *remaining = output.converged
    checks = [
        (xy_conv, "Solver did not converge in the `x/y` term."),
        (xx_conv, "Solver did not converge in the `x/x` term."),
    ]
    if remaining:
        checks.append((remaining[0], "Solver did not converge in the `y/y` term."))
    for converged, message in checks:
        if not converged:
            logger.warning(message)

    return float(output.divergence)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# `recall_target` and `aggregate_to_topk` are forwarded to `approx_min_k` as plain
# parameters, so they are marked static as well: when passed explicitly they would
# otherwise be traced by `jax.jit`.
@partial(jax.jit, static_argnames=["k", "recall_target", "aggregate_to_topk"])
def get_nearest_neighbors(
    input_batch: jnp.ndarray,
    target: jnp.ndarray,
    k: int = 30,
    recall_target: float = 0.95,
    aggregate_to_topk: bool = True,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Get the approximate ``k`` nearest neighbors of the input batch in the target.

    Parameters
    ----------
    input_batch
        Points for which the neighbors are searched.
    target
        Points among which the neighbors are searched.
    k
        Number of neighbors; must not exceed ``target.shape[0]``.
    recall_target
        Forwarded to :func:`jax.lax.approx_min_k`.
    aggregate_to_topk
        Forwarded to :func:`jax.lax.approx_min_k`.

    Returns
    -------
    The result of :func:`jax.lax.approx_min_k` applied to the pairwise cost matrix.

    Raises
    ------
    ValueError
        If ``target`` has fewer than ``k`` points.
    """
    n_target = target.shape[0]
    if n_target < k:
        raise ValueError(f"k is {k}, but must be smaller or equal than {n_target}.")
    # NOTE(review): the costs come from a default-constructed `PointCloud`; confirm
    # which cost function that implies before relying on the returned values as
    # plain Euclidean distances.
    pairwise_costs = pointcloud.PointCloud(input_batch, target).cost_matrix
    return jax.lax.approx_min_k(
        pairwise_costs, k=k, recall_target=recall_target, aggregate_to_topk=aggregate_to_topk
    )
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def check_shapes(geom_x: geometry.Geometry, geom_y: geometry.Geometry, geom_xy: geometry.Geometry) -> None:
    """Check that two geometries agree with the shape of the joint geometry.

    Parameters
    ----------
    geom_x
        Geometry whose first dimension must equal ``geom_xy.shape[0]``.
    geom_y
        Geometry whose first dimension must equal ``geom_xy.shape[1]``.
    geom_xy
        Geometry of shape ``[n, m]`` that defines the expected sizes.

    Returns
    -------
    Nothing, only validates the shapes.

    Raises
    ------
    ValueError
        If either size does not match.
    """
    n, m = geom_xy.shape
    n_ = geom_x.shape[0]
    m_ = geom_y.shape[0]

    if n_ != n:
        raise ValueError(f"Expected the first geometry to have `{n}` points, found `{n_}`.")
    if m_ != m:
        raise ValueError(f"Expected the second geometry to have `{m}` points, found `{m_}`.")
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def alpha_to_fused_penalty(alpha: float) -> float:
    """Convert the weight ``alpha`` to a fused penalty.

    Parameters
    ----------
    alpha
        Weight in the interval ``(0, 1]``.

    Returns
    -------
    The value ``(1 - alpha) / alpha``.

    Raises
    ------
    ValueError
        If ``alpha`` is not in ``(0, 1]``.
    """
    # the positive form of the check also rejects NaN
    if 0 < alpha <= 1:
        return (1 - alpha) / alpha
    raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def densify(arr: ArrayLike) -> jax.Array:
    """Return the input as a dense :mod:`jax` array.

    Parameters
    ----------
    arr
        Array to convert. :mod:`scipy` sparse inputs and :mod:`jax` ``BCOO``
        inputs are made dense first.

    Returns
    -------
    dense :mod:`jax` array.
    """
    if sp.issparse(arr):
        return jnp.asarray(arr.toarray())  # type: ignore[attr-defined]
    if isinstance(arr, jesp.BCOO):
        return jnp.asarray(arr.todense())
    return jnp.asarray(arr)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
    """Ensure that an array is 2-dimensional.

    Parameters
    ----------
    arr
        Array to check.
    reshape
        Allow reshaping 1-dimensional array to ``[n, 1]``.

    Returns
    -------
    2-dimensional array, cast with ``astype(jnp.float64)``.

    Raises
    ------
    ValueError
        If the array is not 2-dimensional and could not be reshaped.
    """
    if reshape and arr.ndim == 1:
        # Do not return early: the reshaped column must receive the same
        # floating-point cast as an input that was already 2-dimensional.
        arr = jnp.reshape(arr, (-1, 1))
    if arr.ndim != 2:
        raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
    return arr.astype(jnp.float64)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
    """Convert a :mod:`scipy` sparse matrix to a :mod:`jax` ``BCOO`` matrix.

    Parameters
    ----------
    arr
        Matrix to convert.

    Returns
    -------
    The converted matrix if ``arr`` is a :mod:`scipy` sparse matrix, otherwise ``arr`` itself.
    """
    return jesp.BCOO.from_scipy_sparse(arr) if sp.issparse(arr) else arr
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _instantiate_geodesic_cost(
    arr: jax.Array,
    problem_shape: Tuple[int, int],
    t: Optional[float],
    is_linear_term: bool,
    epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
    relative_epsilon: Optional[bool] = None,
    scale_cost: Scale_t = 1.0,
    directed: bool = True,
    **kwargs: Any,
) -> geometry.Geometry:
    """Create a geometry from the geodesic cost matrix of a graph.

    Parameters
    ----------
    arr
        Graph passed to ``Geodesic.from_graph``.
    problem_shape
        Number of source and target points, ``(n_src, n_tgt)``.
    t
        Forwarded to ``Geodesic.from_graph``. If :obj:`None`, ``epsilon / 4.0`` is used.
    is_linear_term
        If :obj:`True`, ``arr`` must have ``n_src + n_tgt`` rows and only the
        ``[:n_src, n_src:]`` block of the cost matrix is kept.
    epsilon
        Forwarded to the geometry; also used to derive ``t`` when ``t`` is :obj:`None`.
    relative_epsilon
        Forwarded to the geometry.
    scale_cost
        Forwarded to the geometry.
    directed
        Forwarded to ``Geodesic.from_graph``.
    kwargs
        Additional keyword arguments for ``Geodesic.from_graph``.

    Returns
    -------
    The geometry built from the (possibly sliced) geodesic cost matrix.

    Raises
    ------
    ValueError
        If ``is_linear_term`` is set and ``arr`` does not have ``n_src + n_tgt`` rows.
    TypeError
        If both ``t`` and ``epsilon`` are :obj:`None`.
    """
    n_src, n_tgt = problem_shape
    if is_linear_term and n_src + n_tgt != arr.shape[0]:
        raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
    if t is None:
        if epsilon is None:
            # Same exception type as the former `None / 4.0` failure, but with
            # a message that says which arguments are missing.
            raise TypeError("Expected either `t` or `epsilon` to be specified, found both to be `None`.")
        t = epsilon / 4.0
    cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
    # for the linear term, keep only the source-to-target block
    cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
    return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def data_match_fn(
    src_lin: Optional[jnp.ndarray] = None,
    tgt_lin: Optional[jnp.ndarray] = None,
    src_quad: Optional[jnp.ndarray] = None,
    tgt_quad: Optional[jnp.ndarray] = None,
    *,
    typ: Literal["lin", "quad", "fused"],
    **data_match_fn_kwargs,
) -> jnp.ndarray:
    """Dispatch to the linear or quadratic matching function.

    Parameters
    ----------
    src_lin
        Source data of the linear term, used for ``'lin'`` and ``'fused'``.
    tgt_lin
        Target data of the linear term, used for ``'lin'`` and ``'fused'``.
    src_quad
        Source data of the quadratic term, used for ``'quad'`` and ``'fused'``.
    tgt_quad
        Target data of the quadratic term, used for ``'quad'`` and ``'fused'``.
    typ
        Which matching to perform.
    data_match_fn_kwargs
        Keyword arguments forwarded to the selected matching function.

    Returns
    -------
    The result of the selected matching function.

    Raises
    ------
    NotImplementedError
        If ``typ`` is not one of ``'lin'``, ``'quad'``, ``'fused'``.
    """
    if typ not in ("lin", "quad", "fused"):
        raise NotImplementedError(f"Unknown type: {typ}.")

    if typ == "lin":
        return solver_utils.match_linear(x=src_lin, y=tgt_lin, **data_match_fn_kwargs)

    # `quad` and `fused` share the quadratic matcher, `fused` also passes the linear data
    quad_args = {"xx": src_quad, "yy": tgt_quad}
    if typ == "fused":
        quad_args.update(x=src_lin, y=tgt_lin)
    return solver_utils.match_quadratic(**quad_args, **data_match_fn_kwargs)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class Loader:
    """Endless sampler of random batches from an indexable dataset.

    Every batch is built from ``batch_size`` items drawn independently at
    uniformly random indices, so an item may occur several times in a batch.
    The iterator never raises :class:`StopIteration`.

    Parameters
    ----------
    dataset
        Dataset supporting ``len()`` and integer indexing; each item is a
        mapping from a key to an array.
    batch_size
        Number of items drawn per batch.
    seed
        Seed of the random number generator.
    """

    def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
        self.dataset = dataset
        self.batch_size = batch_size
        self._rng = np.random.default_rng(seed)

    def __iter__(self):
        return self

    def __next__(self) -> Dict[str, jnp.ndarray]:
        """Draw one batch and stack the sampled values per key."""
        columns: Dict[str, list] = {}
        for _ in range(self.batch_size):
            index = self._rng.integers(0, len(self.dataset))
            sample = self.dataset[index]
            for key, value in sample.items():
                columns.setdefault(key, []).append(value)
        return {key: jnp.vstack(values) for key, values in columns.items()}

    def __len__(self):
        return len(self.dataset)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class MultiLoader:
    """Loader that samples from several wrapped data loaders.

    On every step one of the wrapped loaders is picked at random and its next
    element is returned.

    Parameters
    ----------
    datasets
        Loaders to sample from.
    seed
        Random seed.
    """

    def __init__(
        self,
        datasets: Iterable[Loader],
        seed: Optional[int] = None,
    ):
        self.datasets = tuple(datasets)
        self._rng = np.random.default_rng(seed)
        # filled in by `__iter__`, one iterator per wrapped loader
        self._iterators: list = []
        self._it = 0

    def __next__(self) -> Dict[str, jnp.ndarray]:
        self._it += 1

        slot = self._rng.choice(len(self._iterators))
        if self._it >= len(self):
            # once the step counter reaches the length, restart the selected
            # iterator so that its first element is returned
            self._iterators[slot] = iter(self.datasets[slot])
        return next(self._iterators[slot])

    def __iter__(self) -> "MultiLoader":
        self._it = 0
        self._iterators = [iter(loader) for loader in self.datasets]
        return self

    def __len__(self) -> int:
        # length of the longest wrapped loader, `0` if there is none
        return max(map(len, self.datasets), default=0)
|