moscot 0.3.4__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {moscot-0.3.4 → moscot-0.3.5}/.gitignore +1 -0
- {moscot-0.3.4 → moscot-0.3.5}/.pre-commit-config.yaml +6 -6
- moscot-0.3.5/.run_notebooks.sh +62 -0
- {moscot-0.3.4 → moscot-0.3.5}/PKG-INFO +3 -2
- {moscot-0.3.4 → moscot-0.3.5}/pyproject.toml +17 -2
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_types.py +0 -4
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/__init__.py +0 -4
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/_utils.py +27 -3
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/solver.py +6 -2
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/_utils.py +11 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/birth_death.py +35 -10
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/compound_problem.py +49 -14
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/problem.py +46 -29
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/datasets.py +15 -3
- moscot-0.3.5/src/moscot/problems/_utils.py +126 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/cross_modality/_translation.py +51 -22
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/generic/_generic.py +101 -34
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/_alignment.py +64 -23
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/_mapping.py +80 -31
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/_mixins.py +84 -20
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +22 -8
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/time/_lineage.py +69 -19
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/time/_mixins.py +3 -2
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/tagged_array.py +7 -6
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/PKG-INFO +3 -2
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/SOURCES.txt +1 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/requires.txt +2 -1
- moscot-0.3.4/src/moscot/problems/_utils.py +0 -85
- {moscot-0.3.4 → moscot-0.3.5}/.gitmodules +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/.readthedocs.yml +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/LICENSE +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/MANIFEST.in +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/README.rst +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/codecov.yml +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/setup.cfg +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_constants.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_logging.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_registry.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/output.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/utils.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/cost.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/output.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/_mixins.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/manager.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/solver.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/costs/_costs.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/costs/_utils.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/plotting/_plotting.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/plotting/_utils.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/cross_modality/_mixins.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/generic/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/generic/_mixins.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/py.typed +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/data.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/subset_policy.py +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -7,13 +7,13 @@ default_stages:
|
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.10.1
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
13
|
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev: 24.2
|
|
16
|
+
rev: 24.4.2
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
@@ -29,7 +29,7 @@ repos:
|
|
|
29
29
|
additional_dependencies: [toml]
|
|
30
30
|
args: [--order-by-type]
|
|
31
31
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
32
|
-
rev: v4.
|
|
32
|
+
rev: v4.6.0
|
|
33
33
|
hooks:
|
|
34
34
|
- id: check-merge-conflict
|
|
35
35
|
- id: check-ast
|
|
@@ -42,12 +42,12 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.
|
|
45
|
+
rev: v3.16.0
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
49
49
|
- repo: https://github.com/asottile/blacken-docs
|
|
50
|
-
rev: 1.
|
|
50
|
+
rev: 1.18.0
|
|
51
51
|
hooks:
|
|
52
52
|
- id: blacken-docs
|
|
53
53
|
additional_dependencies: [black==23.1.0]
|
|
@@ -63,7 +63,7 @@ repos:
|
|
|
63
63
|
- id: doc8
|
|
64
64
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.
|
|
66
|
+
rev: v0.5.0
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
|
|
3
|
+
# Check if the base directory is provided as an argument
|
|
4
|
+
if [ "$#" -ne 1 ]; then
|
|
5
|
+
echo "Usage: $0 <base_notebook_directory>"
|
|
6
|
+
exit 1
|
|
7
|
+
fi
|
|
8
|
+
|
|
9
|
+
# Base directory for notebooks
|
|
10
|
+
base_dir=$1
|
|
11
|
+
|
|
12
|
+
# Define notebook directories or patterns
|
|
13
|
+
declare -a notebooks=(
|
|
14
|
+
"$base_dir/examples/plotting/*.ipynb"
|
|
15
|
+
"$base_dir/examples/problems/*.ipynb"
|
|
16
|
+
"$base_dir/examples/solvers/*.ipynb"
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Initialize an array to hold valid notebook paths
|
|
20
|
+
declare -a valid_notebooks
|
|
21
|
+
|
|
22
|
+
# Gather all valid notebook files from the patterns
|
|
23
|
+
echo "Gathering notebooks..."
|
|
24
|
+
for pattern in "${notebooks[@]}"; do
|
|
25
|
+
for nb in $pattern; do
|
|
26
|
+
if [[ -f "$nb" ]]; then # Check if the file exists
|
|
27
|
+
valid_notebooks+=("$nb") # Add to the list of valid notebooks
|
|
28
|
+
fi
|
|
29
|
+
done
|
|
30
|
+
done
|
|
31
|
+
|
|
32
|
+
# Check if we have any notebooks to run
|
|
33
|
+
if [ ${#valid_notebooks[@]} -eq 0 ]; then
|
|
34
|
+
echo "No notebooks found to run."
|
|
35
|
+
exit 1
|
|
36
|
+
fi
|
|
37
|
+
|
|
38
|
+
# Echo the notebooks that will be run for clarity
|
|
39
|
+
echo "Preparing to run the following notebooks:"
|
|
40
|
+
for nb in "${valid_notebooks[@]}"; do
|
|
41
|
+
echo "$nb"
|
|
42
|
+
done
|
|
43
|
+
|
|
44
|
+
# Initialize a flag to track the success of all commands
|
|
45
|
+
all_success=true
|
|
46
|
+
|
|
47
|
+
# Execute all valid notebooks
|
|
48
|
+
for nb in "${valid_notebooks[@]}"; do
|
|
49
|
+
echo "Running $nb"
|
|
50
|
+
jupytext -k moscot --execute "$nb" || {
|
|
51
|
+
echo "Failed to run $nb"
|
|
52
|
+
all_success=false
|
|
53
|
+
}
|
|
54
|
+
done
|
|
55
|
+
|
|
56
|
+
# Check if any executions failed
|
|
57
|
+
if [ "$all_success" = false ]; then
|
|
58
|
+
echo "One or more notebooks failed to execute."
|
|
59
|
+
exit 1
|
|
60
|
+
fi
|
|
61
|
+
|
|
62
|
+
echo "All notebooks executed successfully."
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.5
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -66,9 +66,10 @@ Requires-Dist: anndata>=0.9.1
|
|
|
66
66
|
Requires-Dist: scanpy>=1.9.3
|
|
67
67
|
Requires-Dist: wrapt>=1.13.2
|
|
68
68
|
Requires-Dist: docrep>=0.3.2
|
|
69
|
-
Requires-Dist: ott-jax>=0.4.
|
|
69
|
+
Requires-Dist: ott-jax>=0.4.6
|
|
70
70
|
Requires-Dist: cloudpickle>=2.2.0
|
|
71
71
|
Requires-Dist: rich>=13.5
|
|
72
|
+
Requires-Dist: docstring_inheritance>=2.0.0
|
|
72
73
|
Provides-Extra: spatial
|
|
73
74
|
Requires-Dist: squidpy>=1.2.3; extra == "spatial"
|
|
74
75
|
Provides-Extra: dev
|
|
@@ -55,9 +55,10 @@ dependencies = [
|
|
|
55
55
|
"scanpy>=1.9.3",
|
|
56
56
|
"wrapt>=1.13.2",
|
|
57
57
|
"docrep>=0.3.2",
|
|
58
|
-
"ott-jax>=0.4.
|
|
58
|
+
"ott-jax>=0.4.6",
|
|
59
59
|
"cloudpickle>=2.2.0",
|
|
60
60
|
"rich>=13.5",
|
|
61
|
+
"docstring_inheritance>=2.0.0"
|
|
61
62
|
]
|
|
62
63
|
|
|
63
64
|
[project.optional-dependencies]
|
|
@@ -282,7 +283,6 @@ commands =
|
|
|
282
283
|
|
|
283
284
|
[testenv:lint-docs]
|
|
284
285
|
description = Lint the documentation.
|
|
285
|
-
deps =
|
|
286
286
|
extras = docs
|
|
287
287
|
ignore_errors = true
|
|
288
288
|
allowlist_externals = make
|
|
@@ -294,6 +294,21 @@ commands =
|
|
|
294
294
|
# TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490
|
|
295
295
|
# make spelling {posargs}
|
|
296
296
|
|
|
297
|
+
[testenv:examples-docs]
|
|
298
|
+
allowlist_externals = bash
|
|
299
|
+
description = Run the notebooks.
|
|
300
|
+
use_develop = true
|
|
301
|
+
deps =
|
|
302
|
+
ipykernel
|
|
303
|
+
jupytext
|
|
304
|
+
nbconvert
|
|
305
|
+
leidenalg
|
|
306
|
+
extras = docs
|
|
307
|
+
changedir = {tox_root}{/}docs
|
|
308
|
+
commands =
|
|
309
|
+
python -m ipykernel install --user --name=moscot
|
|
310
|
+
bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks
|
|
311
|
+
|
|
297
312
|
[testenv:clean-docs]
|
|
298
313
|
description = Remove the documentation.
|
|
299
314
|
deps =
|
|
@@ -12,7 +12,3 @@ register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
|
|
|
12
12
|
register_cost("cosine", backend="ott")(costs.Cosine)
|
|
13
13
|
register_cost("pnorm_p", backend="ott")(costs.PNormP)
|
|
14
14
|
register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
|
|
15
|
-
register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
|
|
16
|
-
register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
|
|
17
|
-
register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
|
|
18
|
-
register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Any, Literal, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
|
+
import jax.experimental.sparse as jesp
|
|
4
5
|
import jax.numpy as jnp
|
|
5
6
|
import scipy.sparse as sp
|
|
6
7
|
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
|
|
@@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
|
|
|
67
68
|
return (1 - alpha) / alpha
|
|
68
69
|
|
|
69
70
|
|
|
71
|
+
def densify(arr: ArrayLike) -> jax.Array:
|
|
72
|
+
"""If the input is sparse, convert it to dense.
|
|
73
|
+
|
|
74
|
+
Parameters
|
|
75
|
+
----------
|
|
76
|
+
arr
|
|
77
|
+
Array to check.
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
dense :mod:`jax` array.
|
|
82
|
+
"""
|
|
83
|
+
if sp.issparse(arr):
|
|
84
|
+
arr = arr.toarray() # type: ignore[attr-defined]
|
|
85
|
+
elif isinstance(arr, jesp.BCOO):
|
|
86
|
+
arr = arr.todense()
|
|
87
|
+
return jnp.asarray(arr)
|
|
88
|
+
|
|
89
|
+
|
|
70
90
|
def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
71
91
|
"""Ensure that an array is 2-dimensional.
|
|
72
92
|
|
|
@@ -81,9 +101,6 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
|
81
101
|
-------
|
|
82
102
|
2-dimensional :mod:`jax` array.
|
|
83
103
|
"""
|
|
84
|
-
if sp.issparse(arr):
|
|
85
|
-
arr = arr.A # type: ignore[attr-defined]
|
|
86
|
-
arr = jnp.asarray(arr)
|
|
87
104
|
if reshape and arr.ndim == 1:
|
|
88
105
|
return jnp.reshape(arr, (-1, 1))
|
|
89
106
|
if arr.ndim != 2:
|
|
@@ -91,6 +108,13 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
|
91
108
|
return arr
|
|
92
109
|
|
|
93
110
|
|
|
111
|
+
def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
|
|
112
|
+
"""If the input is a scipy sparse matrix, convert it to a jax BCOO."""
|
|
113
|
+
if sp.issparse(arr):
|
|
114
|
+
return jesp.BCOO.from_scipy_sparse(arr)
|
|
115
|
+
return arr
|
|
116
|
+
|
|
117
|
+
|
|
94
118
|
def _instantiate_geodesic_cost(
|
|
95
119
|
arr: jax.Array,
|
|
96
120
|
problem_shape: Tuple[int, int],
|
|
@@ -16,6 +16,8 @@ from moscot.backends.ott._utils import (
|
|
|
16
16
|
_instantiate_geodesic_cost,
|
|
17
17
|
alpha_to_fused_penalty,
|
|
18
18
|
check_shapes,
|
|
19
|
+
convert_scipy_sparse,
|
|
20
|
+
densify,
|
|
19
21
|
ensure_2d,
|
|
20
22
|
)
|
|
21
23
|
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
@@ -76,8 +78,8 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
76
78
|
if not isinstance(cost_fn, costs.CostFn):
|
|
77
79
|
raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.")
|
|
78
80
|
|
|
79
|
-
y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True)
|
|
80
|
-
x = ensure_2d(x.data_src, reshape=True)
|
|
81
|
+
y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True))
|
|
82
|
+
x = densify(ensure_2d(x.data_src, reshape=True))
|
|
81
83
|
if y is not None and x.shape[1] != y.shape[1]:
|
|
82
84
|
raise ValueError(
|
|
83
85
|
f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`."
|
|
@@ -94,6 +96,8 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
94
96
|
)
|
|
95
97
|
|
|
96
98
|
arr = ensure_2d(x.data_src, reshape=False)
|
|
99
|
+
arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr)
|
|
100
|
+
|
|
97
101
|
if x.is_cost_matrix:
|
|
98
102
|
return geometry.Geometry(
|
|
99
103
|
cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
|
|
@@ -158,6 +158,17 @@ def _validate_args_cell_transition(
|
|
|
158
158
|
raise TypeError(f"Expected argument to be either `str` or `dict`, found `{type(arg)}`.")
|
|
159
159
|
|
|
160
160
|
|
|
161
|
+
def _assert_series_match(a: pd.Series, b: pd.Series) -> None:
|
|
162
|
+
"""Assert that two series are equal ignoring the names."""
|
|
163
|
+
pd.testing.assert_series_equal(a, b, check_names=False)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _assert_columns_and_index_match(a: pd.Series, b: pd.DataFrame) -> None:
|
|
167
|
+
"""Assert that a series and a dataframe's index and columns are matching."""
|
|
168
|
+
_assert_series_match(a, b.index.to_series())
|
|
169
|
+
_assert_series_match(a, b.columns.to_series())
|
|
170
|
+
|
|
171
|
+
|
|
161
172
|
def _get_cell_indices(
|
|
162
173
|
adata: AnnData,
|
|
163
174
|
key: Optional[str] = None,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from functools import partial
|
|
1
2
|
from typing import (
|
|
2
3
|
TYPE_CHECKING,
|
|
3
4
|
Any,
|
|
@@ -166,7 +167,14 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
|
|
|
166
167
|
proliferation_key: Optional[str] = None,
|
|
167
168
|
apoptosis_key: Optional[str] = None,
|
|
168
169
|
scaling: Optional[float] = None,
|
|
169
|
-
|
|
170
|
+
beta_max: float = 1.7,
|
|
171
|
+
beta_min: float = 0.3,
|
|
172
|
+
beta_center: float = 0.25,
|
|
173
|
+
beta_width: float = 0.5,
|
|
174
|
+
delta_max: float = 1.7,
|
|
175
|
+
delta_min: float = 0.3,
|
|
176
|
+
delta_center: float = 0.1,
|
|
177
|
+
delta_width: float = 0.2,
|
|
170
178
|
) -> ArrayLike:
|
|
171
179
|
"""Estimate the source or target :term:`marginals` based on marker genes, either with the
|
|
172
180
|
`birth-death process <https://en.wikipedia.org/wiki/Birth%E2%80%93death_process>`_,
|
|
@@ -189,9 +197,22 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
|
|
|
189
197
|
If :obj:`float` is passed, it will be used as a scaling parameter in an exponential kernel
|
|
190
198
|
with proliferation and apoptosis scores.
|
|
191
199
|
If :obj:`None`, parameters corresponding to the birth and death processes will be used.
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
200
|
+
beta_max
|
|
201
|
+
Argument for :func:`~moscot.base.problems.birth_death.beta`
|
|
202
|
+
beta_min
|
|
203
|
+
Argument for :func:`~moscot.base.problems.birth_death.beta`
|
|
204
|
+
beta_center
|
|
205
|
+
Argument for :func:`~moscot.base.problems.birth_death.beta`
|
|
206
|
+
beta_width
|
|
207
|
+
Argument for :func:`~moscot.base.problems.birth_death.beta`
|
|
208
|
+
delta_max
|
|
209
|
+
Argument for :func:`~moscot.base.problems.birth_death.delta`
|
|
210
|
+
delta_min
|
|
211
|
+
Argument for :func:`~moscot.base.problems.birth_death.delta`
|
|
212
|
+
delta_center
|
|
213
|
+
Argument for :func:`~moscot.base.problems.birth_death.delta`
|
|
214
|
+
delta_width
|
|
215
|
+
Argument for :func:`~moscot.base.problems.birth_death.delta`
|
|
195
216
|
|
|
196
217
|
Returns
|
|
197
218
|
-------
|
|
@@ -223,12 +244,18 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
|
|
|
223
244
|
self.apoptosis_key = apoptosis_key
|
|
224
245
|
|
|
225
246
|
if scaling:
|
|
226
|
-
beta_fn = delta_fn = lambda x
|
|
247
|
+
beta_fn = delta_fn = lambda x: x
|
|
227
248
|
else:
|
|
228
|
-
beta_fn
|
|
249
|
+
beta_fn = partial(
|
|
250
|
+
beta, beta_max=beta_max, beta_min=beta_min, beta_center=beta_center, beta_width=beta_width
|
|
251
|
+
)
|
|
252
|
+
delta_fn = partial(
|
|
253
|
+
delta, delta_max=delta_max, delta_min=delta_min, delta_center=delta_center, delta_width=delta_width
|
|
254
|
+
)
|
|
255
|
+
|
|
229
256
|
scaling = 1.0
|
|
230
|
-
birth = estimate(proliferation_key, fn=beta_fn
|
|
231
|
-
death = estimate(apoptosis_key, fn=delta_fn
|
|
257
|
+
birth = estimate(proliferation_key, fn=beta_fn)
|
|
258
|
+
death = estimate(apoptosis_key, fn=delta_fn)
|
|
232
259
|
|
|
233
260
|
prior_growth = np.exp((birth - death) * self.delta / scaling)
|
|
234
261
|
|
|
@@ -287,7 +314,6 @@ def beta(
|
|
|
287
314
|
beta_min: float = 0.3,
|
|
288
315
|
beta_center: float = 0.25,
|
|
289
316
|
beta_width: float = 0.5,
|
|
290
|
-
**_: Any,
|
|
291
317
|
) -> ArrayLike:
|
|
292
318
|
"""Birth process."""
|
|
293
319
|
return _gen_logistic(p, beta_max, beta_min, beta_center, beta_width)
|
|
@@ -299,7 +325,6 @@ def delta(
|
|
|
299
325
|
delta_min: float = 0.3,
|
|
300
326
|
delta_center: float = 0.1,
|
|
301
327
|
delta_width: float = 0.2,
|
|
302
|
-
**_: Any,
|
|
303
328
|
) -> ArrayLike:
|
|
304
329
|
"""Death process."""
|
|
305
330
|
return _gen_logistic(a, delta_max, delta_min, delta_center, delta_width)
|
|
@@ -149,9 +149,10 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
149
149
|
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
150
150
|
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
151
151
|
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
152
|
-
|
|
152
|
+
a: Optional[Union[bool, str, ArrayLike]] = None,
|
|
153
|
+
b: Optional[Union[bool, str, ArrayLike]] = None,
|
|
154
|
+
marginal_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
153
155
|
) -> Dict[Tuple[K, K], B]:
|
|
154
|
-
from moscot.base.problems.birth_death import BirthDeathProblem
|
|
155
156
|
|
|
156
157
|
if TYPE_CHECKING:
|
|
157
158
|
assert isinstance(self._policy, SubsetPolicy)
|
|
@@ -187,10 +188,7 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
187
188
|
if y_data:
|
|
188
189
|
y = dict(y)
|
|
189
190
|
y["tagged_array"] = y_data
|
|
190
|
-
|
|
191
|
-
kwargs["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined]
|
|
192
|
-
kwargs["apoptosis_key"] = self.apoptosis_key # type: ignore[attr-defined]
|
|
193
|
-
problems[src_name, tgt_name] = problem.prepare(xy=xy, x=x, y=y, **kwargs)
|
|
191
|
+
problems[src_name, tgt_name] = problem.prepare(xy=xy, x=x, y=y, a=a, b=b, marginal_kwargs=marginal_kwargs)
|
|
194
192
|
|
|
195
193
|
return problems
|
|
196
194
|
|
|
@@ -200,13 +198,18 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
200
198
|
key: Optional[str],
|
|
201
199
|
subset: Optional[Sequence[Tuple[K, K]]] = None,
|
|
202
200
|
reference: Optional[Any] = None,
|
|
201
|
+
xy: Mapping[str, Any] = types.MappingProxyType({}),
|
|
202
|
+
x: Mapping[str, Any] = types.MappingProxyType({}),
|
|
203
|
+
y: Mapping[str, Any] = types.MappingProxyType({}),
|
|
203
204
|
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
|
|
204
205
|
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
|
|
205
206
|
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
|
|
206
207
|
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
207
208
|
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
208
209
|
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
209
|
-
|
|
210
|
+
a: Optional[Union[bool, str, ArrayLike]] = None,
|
|
211
|
+
b: Optional[Union[bool, str, ArrayLike]] = None,
|
|
212
|
+
marginal_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
210
213
|
) -> "BaseCompoundProblem[K, B]":
|
|
211
214
|
"""Prepare the individual :term:`OT` subproblems.
|
|
212
215
|
|
|
@@ -224,6 +227,12 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
224
227
|
for the :class:`~moscot.utils.subset_policy.ExplicitPolicy`. Only used when ``policy = 'explicit'``.
|
|
225
228
|
reference
|
|
226
229
|
Reference for the :class:`~moscot.utils.subset_policy.SubsetPolicy`. Only used when ``policy = 'star'``.
|
|
230
|
+
xy
|
|
231
|
+
Data for the :term:`linear term`.
|
|
232
|
+
x
|
|
233
|
+
Data for the source :term:`quadratic term`.
|
|
234
|
+
y
|
|
235
|
+
Data for the target :term:`quadratic term`.
|
|
227
236
|
xy_callback
|
|
228
237
|
Callback function used to prepare the data in the :term:`linear term`.
|
|
229
238
|
x_callback
|
|
@@ -236,8 +245,24 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
236
245
|
Keyword arguments for the ``x_callback``.
|
|
237
246
|
y_callback_kwargs
|
|
238
247
|
Keyword arguments for the ``y_callback``.
|
|
239
|
-
|
|
240
|
-
|
|
248
|
+
a
|
|
249
|
+
Source :term:`marginals`. Valid options are:
|
|
250
|
+
|
|
251
|
+
- :class:`str` - key in :attr:`~anndata.AnnData.obs` where the source marginals are stored.
|
|
252
|
+
- :class:`bool` - if :obj:`True`,
|
|
253
|
+
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
|
|
254
|
+
otherwise use uniform marginals.
|
|
255
|
+
- :obj:`None` - uniform marginals.
|
|
256
|
+
b
|
|
257
|
+
Target :term:`marginals`. Valid options are:
|
|
258
|
+
|
|
259
|
+
- :class:`str` - key in :attr:`~anndata.AnnData.obs` where the target marginals are stored.
|
|
260
|
+
- :class:`bool` - if :obj:`True`,
|
|
261
|
+
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
|
|
262
|
+
otherwise use uniform marginals.
|
|
263
|
+
- :obj:`None` - uniform marginals.
|
|
264
|
+
marginal_kwargs
|
|
265
|
+
Keyword arguments for the :meth:`~moscot.base.problems.OTProblem.estimate_marginals` method.
|
|
241
266
|
|
|
242
267
|
Returns
|
|
243
268
|
-------
|
|
@@ -264,13 +289,18 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
264
289
|
# when refactoring the callback, consider changing this
|
|
265
290
|
self._problem_manager = ProblemManager(self, policy=policy)
|
|
266
291
|
problems = self._create_problems(
|
|
267
|
-
|
|
292
|
+
x=x,
|
|
293
|
+
y=y,
|
|
294
|
+
xy=xy,
|
|
295
|
+
a=a,
|
|
296
|
+
b=b,
|
|
268
297
|
x_callback=x_callback,
|
|
269
298
|
y_callback=y_callback,
|
|
270
|
-
|
|
299
|
+
xy_callback=xy_callback,
|
|
271
300
|
x_callback_kwargs=x_callback_kwargs,
|
|
272
301
|
y_callback_kwargs=y_callback_kwargs,
|
|
273
|
-
|
|
302
|
+
xy_callback_kwargs=xy_callback_kwargs,
|
|
303
|
+
marginal_kwargs=marginal_kwargs,
|
|
274
304
|
)
|
|
275
305
|
self._problem_manager.add_problems(problems)
|
|
276
306
|
|
|
@@ -313,6 +343,11 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
|
|
|
313
343
|
problems = self._problem_manager.get_problems(stage=stage)
|
|
314
344
|
|
|
315
345
|
logger.info(f"Solving `{len(problems)}` problems")
|
|
346
|
+
# expose min/max iterations to the user but remove them if they are None
|
|
347
|
+
if "min_iterations" in kwargs and kwargs["min_iterations"] is None:
|
|
348
|
+
kwargs.pop("min_iterations")
|
|
349
|
+
if "max_iterations" in kwargs and kwargs["max_iterations"] is None:
|
|
350
|
+
kwargs.pop("max_iterations")
|
|
316
351
|
for problem in problems.values():
|
|
317
352
|
logger.info(f"Solving problem {problem}.")
|
|
318
353
|
_ = problem.solve(**kwargs)
|
|
@@ -587,14 +622,14 @@ class CompoundProblem(BaseCompoundProblem[K, B], abc.ABC):
|
|
|
587
622
|
linear_cost_matrix = data[mask, :][:, mask_2]
|
|
588
623
|
if sp.issparse(linear_cost_matrix):
|
|
589
624
|
logger.warning("Linear cost matrix being densified.")
|
|
590
|
-
linear_cost_matrix = linear_cost_matrix.
|
|
625
|
+
linear_cost_matrix = linear_cost_matrix.toarray()
|
|
591
626
|
return TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)
|
|
592
627
|
|
|
593
628
|
if term in ("x", "y"):
|
|
594
629
|
quad_cost_matrix = data[mask, :][:, mask]
|
|
595
630
|
if sp.issparse(quad_cost_matrix):
|
|
596
631
|
logger.warning("Quadratic cost matrix being densified.")
|
|
597
|
-
quad_cost_matrix = quad_cost_matrix.
|
|
632
|
+
quad_cost_matrix = quad_cost_matrix.toarray()
|
|
598
633
|
return TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)
|
|
599
634
|
|
|
600
635
|
raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")
|