moscot 0.3.4 (tar.gz) → 0.3.5 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {moscot-0.3.4 → moscot-0.3.5}/.gitignore +1 -0
  2. {moscot-0.3.4 → moscot-0.3.5}/.pre-commit-config.yaml +6 -6
  3. moscot-0.3.5/.run_notebooks.sh +62 -0
  4. {moscot-0.3.4 → moscot-0.3.5}/PKG-INFO +3 -2
  5. {moscot-0.3.4 → moscot-0.3.5}/pyproject.toml +17 -2
  6. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_types.py +0 -4
  7. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/__init__.py +0 -4
  8. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/_utils.py +27 -3
  9. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/solver.py +6 -2
  10. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/_utils.py +11 -0
  11. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/birth_death.py +35 -10
  12. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/compound_problem.py +49 -14
  13. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/problem.py +46 -29
  14. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/datasets.py +15 -3
  15. moscot-0.3.5/src/moscot/problems/_utils.py +126 -0
  16. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/cross_modality/_translation.py +51 -22
  17. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/generic/_generic.py +101 -34
  18. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/_alignment.py +64 -23
  19. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/_mapping.py +80 -31
  20. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/_mixins.py +84 -20
  21. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +22 -8
  22. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/time/_lineage.py +69 -19
  23. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/time/_mixins.py +3 -2
  24. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/tagged_array.py +7 -6
  25. {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/PKG-INFO +3 -2
  26. {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/SOURCES.txt +1 -0
  27. {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/requires.txt +2 -1
  28. moscot-0.3.4/src/moscot/problems/_utils.py +0 -85
  29. {moscot-0.3.4 → moscot-0.3.5}/.gitmodules +0 -0
  30. {moscot-0.3.4 → moscot-0.3.5}/.readthedocs.yml +0 -0
  31. {moscot-0.3.4 → moscot-0.3.5}/LICENSE +0 -0
  32. {moscot-0.3.4 → moscot-0.3.5}/MANIFEST.in +0 -0
  33. {moscot-0.3.4 → moscot-0.3.5}/README.rst +0 -0
  34. {moscot-0.3.4 → moscot-0.3.5}/codecov.yml +0 -0
  35. {moscot-0.3.4 → moscot-0.3.5}/setup.cfg +0 -0
  36. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/__init__.py +0 -0
  37. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_constants.py +0 -0
  38. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_logging.py +0 -0
  39. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/_registry.py +0 -0
  40. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/__init__.py +0 -0
  41. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/ott/output.py +0 -0
  42. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/backends/utils.py +0 -0
  43. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/__init__.py +0 -0
  44. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/cost.py +0 -0
  45. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/output.py +0 -0
  46. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/__init__.py +0 -0
  47. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/_mixins.py +0 -0
  48. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/problems/manager.py +0 -0
  49. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/base/solver.py +0 -0
  50. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/costs/__init__.py +0 -0
  51. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/costs/_costs.py +0 -0
  52. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/costs/_utils.py +0 -0
  53. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/plotting/__init__.py +0 -0
  54. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/plotting/_plotting.py +0 -0
  55. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/plotting/_utils.py +0 -0
  56. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/__init__.py +0 -0
  57. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/cross_modality/__init__.py +0 -0
  58. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/cross_modality/_mixins.py +0 -0
  59. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/generic/__init__.py +0 -0
  60. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/generic/_mixins.py +0 -0
  61. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/space/__init__.py +0 -0
  62. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
  63. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/problems/time/__init__.py +0 -0
  64. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/py.typed +0 -0
  65. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/__init__.py +0 -0
  66. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
  67. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
  68. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
  69. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
  70. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/human_proliferation.txt +0 -0
  71. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
  72. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/_data/mouse_proliferation.txt +0 -0
  73. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/data.py +0 -0
  74. {moscot-0.3.4 → moscot-0.3.5}/src/moscot/utils/subset_policy.py +0 -0
  75. {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/dependency_links.txt +0 -0
  76. {moscot-0.3.4 → moscot-0.3.5}/src/moscot.egg-info/top_level.txt +0 -0
@@ -154,3 +154,4 @@ packages.dot
154
154
 
155
155
  # plotting tests
156
156
  tests/plotting/actual_figures/
157
+ tests/plotting/figures/
@@ -7,13 +7,13 @@ default_stages:
7
7
  minimum_pre_commit_version: 3.0.0
8
8
  repos:
9
9
  - repo: https://github.com/pre-commit/mirrors-mypy
10
- rev: v1.8.0
10
+ rev: v1.10.1
11
11
  hooks:
12
12
  - id: mypy
13
13
  additional_dependencies: [numpy>=1.25.0]
14
14
  files: ^src
15
15
  - repo: https://github.com/psf/black
16
- rev: 24.2.0
16
+ rev: 24.4.2
17
17
  hooks:
18
18
  - id: black
19
19
  additional_dependencies: [toml]
@@ -29,7 +29,7 @@ repos:
29
29
  additional_dependencies: [toml]
30
30
  args: [--order-by-type]
31
31
  - repo: https://github.com/pre-commit/pre-commit-hooks
32
- rev: v4.5.0
32
+ rev: v4.6.0
33
33
  hooks:
34
34
  - id: check-merge-conflict
35
35
  - id: check-ast
@@ -42,12 +42,12 @@ repos:
42
42
  - id: check-yaml
43
43
  - id: check-toml
44
44
  - repo: https://github.com/asottile/pyupgrade
45
- rev: v3.15.1
45
+ rev: v3.16.0
46
46
  hooks:
47
47
  - id: pyupgrade
48
48
  args: [--py3-plus, --py38-plus, --keep-runtime-typing]
49
49
  - repo: https://github.com/asottile/blacken-docs
50
- rev: 1.16.0
50
+ rev: 1.18.0
51
51
  hooks:
52
52
  - id: blacken-docs
53
53
  additional_dependencies: [black==23.1.0]
@@ -63,7 +63,7 @@ repos:
63
63
  - id: doc8
64
64
  - repo: https://github.com/astral-sh/ruff-pre-commit
65
65
  # Ruff version.
66
- rev: v0.2.2
66
+ rev: v0.5.0
67
67
  hooks:
68
68
  - id: ruff
69
69
  args: [--fix, --exit-non-zero-on-fix]
@@ -0,0 +1,62 @@
1
+ #!/bin/bash
2
+
3
+ # Check if the base directory is provided as an argument
4
+ if [ "$#" -ne 1 ]; then
5
+ echo "Usage: $0 <base_notebook_directory>"
6
+ exit 1
7
+ fi
8
+
9
+ # Base directory for notebooks
10
+ base_dir=$1
11
+
12
+ # Define notebook directories or patterns
13
+ declare -a notebooks=(
14
+ "$base_dir/examples/plotting/*.ipynb"
15
+ "$base_dir/examples/problems/*.ipynb"
16
+ "$base_dir/examples/solvers/*.ipynb"
17
+ )
18
+
19
+ # Initialize an array to hold valid notebook paths
20
+ declare -a valid_notebooks
21
+
22
+ # Gather all valid notebook files from the patterns
23
+ echo "Gathering notebooks..."
24
+ for pattern in "${notebooks[@]}"; do
25
+ for nb in $pattern; do
26
+ if [[ -f "$nb" ]]; then # Check if the file exists
27
+ valid_notebooks+=("$nb") # Add to the list of valid notebooks
28
+ fi
29
+ done
30
+ done
31
+
32
+ # Check if we have any notebooks to run
33
+ if [ ${#valid_notebooks[@]} -eq 0 ]; then
34
+ echo "No notebooks found to run."
35
+ exit 1
36
+ fi
37
+
38
+ # Echo the notebooks that will be run for clarity
39
+ echo "Preparing to run the following notebooks:"
40
+ for nb in "${valid_notebooks[@]}"; do
41
+ echo "$nb"
42
+ done
43
+
44
+ # Initialize a flag to track the success of all commands
45
+ all_success=true
46
+
47
+ # Execute all valid notebooks
48
+ for nb in "${valid_notebooks[@]}"; do
49
+ echo "Running $nb"
50
+ jupytext -k moscot --execute "$nb" || {
51
+ echo "Failed to run $nb"
52
+ all_success=false
53
+ }
54
+ done
55
+
56
+ # Check if any executions failed
57
+ if [ "$all_success" = false ]; then
58
+ echo "One or more notebooks failed to execute."
59
+ exit 1
60
+ fi
61
+
62
+ echo "All notebooks executed successfully."
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: moscot
3
- Version: 0.3.4
3
+ Version: 0.3.5
4
4
  Summary: Multi-omic single-cell optimal transport tools
5
5
  Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
6
6
  Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
@@ -66,9 +66,10 @@ Requires-Dist: anndata>=0.9.1
66
66
  Requires-Dist: scanpy>=1.9.3
67
67
  Requires-Dist: wrapt>=1.13.2
68
68
  Requires-Dist: docrep>=0.3.2
69
- Requires-Dist: ott-jax>=0.4.5
69
+ Requires-Dist: ott-jax>=0.4.6
70
70
  Requires-Dist: cloudpickle>=2.2.0
71
71
  Requires-Dist: rich>=13.5
72
+ Requires-Dist: docstring_inheritance>=2.0.0
72
73
  Provides-Extra: spatial
73
74
  Requires-Dist: squidpy>=1.2.3; extra == "spatial"
74
75
  Provides-Extra: dev
@@ -55,9 +55,10 @@ dependencies = [
55
55
  "scanpy>=1.9.3",
56
56
  "wrapt>=1.13.2",
57
57
  "docrep>=0.3.2",
58
- "ott-jax>=0.4.5",
58
+ "ott-jax>=0.4.6",
59
59
  "cloudpickle>=2.2.0",
60
60
  "rich>=13.5",
61
+ "docstring_inheritance>=2.0.0"
61
62
  ]
62
63
 
63
64
  [project.optional-dependencies]
@@ -282,7 +283,6 @@ commands =
282
283
 
283
284
  [testenv:lint-docs]
284
285
  description = Lint the documentation.
285
- deps =
286
286
  extras = docs
287
287
  ignore_errors = true
288
288
  allowlist_externals = make
@@ -294,6 +294,21 @@ commands =
294
294
  # TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490
295
295
  # make spelling {posargs}
296
296
 
297
+ [testenv:examples-docs]
298
+ allowlist_externals = bash
299
+ description = Run the notebooks.
300
+ use_develop = true
301
+ deps =
302
+ ipykernel
303
+ jupytext
304
+ nbconvert
305
+ leidenalg
306
+ extras = docs
307
+ changedir = {tox_root}{/}docs
308
+ commands =
309
+ python -m ipykernel install --user --name=moscot
310
+ bash {tox_root}/.run_notebooks.sh {tox_root}{/}docs/notebooks
311
+
297
312
  [testenv:clean-docs]
298
313
  description = Remove the documentation.
299
314
  deps =
@@ -36,10 +36,6 @@ OttCostFn_t = Literal[
36
36
  "pnorm_p",
37
37
  "sq_pnorm",
38
38
  "cosine",
39
- "elastic_l1",
40
- "elastic_l2",
41
- "elastic_stvs",
42
- "elastic_sqk_overlap",
43
39
  "geodesic",
44
40
  ]
45
41
  OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
@@ -12,7 +12,3 @@ register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
12
12
  register_cost("cosine", backend="ott")(costs.Cosine)
13
13
  register_cost("pnorm_p", backend="ott")(costs.PNormP)
14
14
  register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
15
- register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
16
- register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
17
- register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
18
- register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
@@ -1,6 +1,7 @@
1
1
  from typing import Any, Literal, Optional, Tuple, Union
2
2
 
3
3
  import jax
4
+ import jax.experimental.sparse as jesp
4
5
  import jax.numpy as jnp
5
6
  import scipy.sparse as sp
6
7
  from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
@@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
67
68
  return (1 - alpha) / alpha
68
69
 
69
70
 
71
+ def densify(arr: ArrayLike) -> jax.Array:
72
+ """If the input is sparse, convert it to dense.
73
+
74
+ Parameters
75
+ ----------
76
+ arr
77
+ Array to check.
78
+
79
+ Returns
80
+ -------
81
+ dense :mod:`jax` array.
82
+ """
83
+ if sp.issparse(arr):
84
+ arr = arr.toarray() # type: ignore[attr-defined]
85
+ elif isinstance(arr, jesp.BCOO):
86
+ arr = arr.todense()
87
+ return jnp.asarray(arr)
88
+
89
+
70
90
  def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
71
91
  """Ensure that an array is 2-dimensional.
72
92
 
@@ -81,9 +101,6 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
81
101
  -------
82
102
  2-dimensional :mod:`jax` array.
83
103
  """
84
- if sp.issparse(arr):
85
- arr = arr.A # type: ignore[attr-defined]
86
- arr = jnp.asarray(arr)
87
104
  if reshape and arr.ndim == 1:
88
105
  return jnp.reshape(arr, (-1, 1))
89
106
  if arr.ndim != 2:
@@ -91,6 +108,13 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
91
108
  return arr
92
109
 
93
110
 
111
+ def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
112
+ """If the input is a scipy sparse matrix, convert it to a jax BCOO."""
113
+ if sp.issparse(arr):
114
+ return jesp.BCOO.from_scipy_sparse(arr)
115
+ return arr
116
+
117
+
94
118
  def _instantiate_geodesic_cost(
95
119
  arr: jax.Array,
96
120
  problem_shape: Tuple[int, int],
@@ -16,6 +16,8 @@ from moscot.backends.ott._utils import (
16
16
  _instantiate_geodesic_cost,
17
17
  alpha_to_fused_penalty,
18
18
  check_shapes,
19
+ convert_scipy_sparse,
20
+ densify,
19
21
  ensure_2d,
20
22
  )
21
23
  from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
@@ -76,8 +78,8 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
76
78
  if not isinstance(cost_fn, costs.CostFn):
77
79
  raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.")
78
80
 
79
- y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True)
80
- x = ensure_2d(x.data_src, reshape=True)
81
+ y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True))
82
+ x = densify(ensure_2d(x.data_src, reshape=True))
81
83
  if y is not None and x.shape[1] != y.shape[1]:
82
84
  raise ValueError(
83
85
  f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`."
@@ -94,6 +96,8 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
94
96
  )
95
97
 
96
98
  arr = ensure_2d(x.data_src, reshape=False)
99
+ arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr)
100
+
97
101
  if x.is_cost_matrix:
98
102
  return geometry.Geometry(
99
103
  cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
@@ -158,6 +158,17 @@ def _validate_args_cell_transition(
158
158
  raise TypeError(f"Expected argument to be either `str` or `dict`, found `{type(arg)}`.")
159
159
 
160
160
 
161
+ def _assert_series_match(a: pd.Series, b: pd.Series) -> None:
162
+ """Assert that two series are equal ignoring the names."""
163
+ pd.testing.assert_series_equal(a, b, check_names=False)
164
+
165
+
166
+ def _assert_columns_and_index_match(a: pd.Series, b: pd.DataFrame) -> None:
167
+ """Assert that a series and a dataframe's index and columns are matching."""
168
+ _assert_series_match(a, b.index.to_series())
169
+ _assert_series_match(a, b.columns.to_series())
170
+
171
+
161
172
  def _get_cell_indices(
162
173
  adata: AnnData,
163
174
  key: Optional[str] = None,
@@ -1,3 +1,4 @@
1
+ from functools import partial
1
2
  from typing import (
2
3
  TYPE_CHECKING,
3
4
  Any,
@@ -166,7 +167,14 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
166
167
  proliferation_key: Optional[str] = None,
167
168
  apoptosis_key: Optional[str] = None,
168
169
  scaling: Optional[float] = None,
169
- **kwargs: Any,
170
+ beta_max: float = 1.7,
171
+ beta_min: float = 0.3,
172
+ beta_center: float = 0.25,
173
+ beta_width: float = 0.5,
174
+ delta_max: float = 1.7,
175
+ delta_min: float = 0.3,
176
+ delta_center: float = 0.1,
177
+ delta_width: float = 0.2,
170
178
  ) -> ArrayLike:
171
179
  """Estimate the source or target :term:`marginals` based on marker genes, either with the
172
180
  `birth-death process <https://en.wikipedia.org/wiki/Birth%E2%80%93death_process>`_,
@@ -189,9 +197,22 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
189
197
  If :obj:`float` is passed, it will be used as a scaling parameter in an exponential kernel
190
198
  with proliferation and apoptosis scores.
191
199
  If :obj:`None`, parameters corresponding to the birth and death processes will be used.
192
- kwargs
193
- Keyword arguments for :func:`~moscot.base.problems.birth_death.beta` and
194
- :func:`~moscot.base.problems.birth_death.delta`.
200
+ beta_max
201
+ Argument for :func:`~moscot.base.problems.birth_death.beta`
202
+ beta_min
203
+ Argument for :func:`~moscot.base.problems.birth_death.beta`
204
+ beta_center
205
+ Argument for :func:`~moscot.base.problems.birth_death.beta`
206
+ beta_width
207
+ Argument for :func:`~moscot.base.problems.birth_death.beta`
208
+ delta_max
209
+ Argument for :func:`~moscot.base.problems.birth_death.delta`
210
+ delta_min
211
+ Argument for :func:`~moscot.base.problems.birth_death.delta`
212
+ delta_center
213
+ Argument for :func:`~moscot.base.problems.birth_death.delta`
214
+ delta_width
215
+ Argument for :func:`~moscot.base.problems.birth_death.delta`
195
216
 
196
217
  Returns
197
218
  -------
@@ -223,12 +244,18 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
223
244
  self.apoptosis_key = apoptosis_key
224
245
 
225
246
  if scaling:
226
- beta_fn = delta_fn = lambda x, *_, **__: x
247
+ beta_fn = delta_fn = lambda x: x
227
248
  else:
228
- beta_fn, delta_fn = beta, delta
249
+ beta_fn = partial(
250
+ beta, beta_max=beta_max, beta_min=beta_min, beta_center=beta_center, beta_width=beta_width
251
+ )
252
+ delta_fn = partial(
253
+ delta, delta_max=delta_max, delta_min=delta_min, delta_center=delta_center, delta_width=delta_width
254
+ )
255
+
229
256
  scaling = 1.0
230
- birth = estimate(proliferation_key, fn=beta_fn, **kwargs)
231
- death = estimate(apoptosis_key, fn=delta_fn, **kwargs)
257
+ birth = estimate(proliferation_key, fn=beta_fn)
258
+ death = estimate(apoptosis_key, fn=delta_fn)
232
259
 
233
260
  prior_growth = np.exp((birth - death) * self.delta / scaling)
234
261
 
@@ -287,7 +314,6 @@ def beta(
287
314
  beta_min: float = 0.3,
288
315
  beta_center: float = 0.25,
289
316
  beta_width: float = 0.5,
290
- **_: Any,
291
317
  ) -> ArrayLike:
292
318
  """Birth process."""
293
319
  return _gen_logistic(p, beta_max, beta_min, beta_center, beta_width)
@@ -299,7 +325,6 @@ def delta(
299
325
  delta_min: float = 0.3,
300
326
  delta_center: float = 0.1,
301
327
  delta_width: float = 0.2,
302
- **_: Any,
303
328
  ) -> ArrayLike:
304
329
  """Death process."""
305
330
  return _gen_logistic(a, delta_max, delta_min, delta_center, delta_width)
@@ -149,9 +149,10 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
149
149
  xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
150
150
  x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
151
151
  y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
152
- **kwargs: Any,
152
+ a: Optional[Union[bool, str, ArrayLike]] = None,
153
+ b: Optional[Union[bool, str, ArrayLike]] = None,
154
+ marginal_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
153
155
  ) -> Dict[Tuple[K, K], B]:
154
- from moscot.base.problems.birth_death import BirthDeathProblem
155
156
 
156
157
  if TYPE_CHECKING:
157
158
  assert isinstance(self._policy, SubsetPolicy)
@@ -187,10 +188,7 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
187
188
  if y_data:
188
189
  y = dict(y)
189
190
  y["tagged_array"] = y_data
190
- if isinstance(problem, BirthDeathProblem):
191
- kwargs["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined]
192
- kwargs["apoptosis_key"] = self.apoptosis_key # type: ignore[attr-defined]
193
- problems[src_name, tgt_name] = problem.prepare(xy=xy, x=x, y=y, **kwargs)
191
+ problems[src_name, tgt_name] = problem.prepare(xy=xy, x=x, y=y, a=a, b=b, marginal_kwargs=marginal_kwargs)
194
192
 
195
193
  return problems
196
194
 
@@ -200,13 +198,18 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
200
198
  key: Optional[str],
201
199
  subset: Optional[Sequence[Tuple[K, K]]] = None,
202
200
  reference: Optional[Any] = None,
201
+ xy: Mapping[str, Any] = types.MappingProxyType({}),
202
+ x: Mapping[str, Any] = types.MappingProxyType({}),
203
+ y: Mapping[str, Any] = types.MappingProxyType({}),
203
204
  xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
204
205
  x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
205
206
  y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
206
207
  xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
207
208
  x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
208
209
  y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
209
- **kwargs: Any,
210
+ a: Optional[Union[bool, str, ArrayLike]] = None,
211
+ b: Optional[Union[bool, str, ArrayLike]] = None,
212
+ marginal_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
210
213
  ) -> "BaseCompoundProblem[K, B]":
211
214
  """Prepare the individual :term:`OT` subproblems.
212
215
 
@@ -224,6 +227,12 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
224
227
  for the :class:`~moscot.utils.subset_policy.ExplicitPolicy`. Only used when ``policy = 'explicit'``.
225
228
  reference
226
229
  Reference for the :class:`~moscot.utils.subset_policy.SubsetPolicy`. Only used when ``policy = 'star'``.
230
+ xy
231
+ Data for the :term:`linear term`.
232
+ x
233
+ Data for the source :term:`quadratic term`.
234
+ y
235
+ Data for the target :term:`quadratic term`.
227
236
  xy_callback
228
237
  Callback function used to prepare the data in the :term:`linear term`.
229
238
  x_callback
@@ -236,8 +245,24 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
236
245
  Keyword arguments for the ``x_callback``.
237
246
  y_callback_kwargs
238
247
  Keyword arguments for the ``y_callback``.
239
- kwargs
240
- Keyword arguments for the subproblems' :meth:`~moscot.base.problems.OTProblem.prepare` method.
248
+ a
249
+ Source :term:`marginals`. Valid options are:
250
+
251
+ - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the source marginals are stored.
252
+ - :class:`bool` - if :obj:`True`,
253
+ :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
254
+ otherwise use uniform marginals.
255
+ - :obj:`None` - uniform marginals.
256
+ b
257
+ Target :term:`marginals`. Valid options are:
258
+
259
+ - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the target marginals are stored.
260
+ - :class:`bool` - if :obj:`True`,
261
+ :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
262
+ otherwise use uniform marginals.
263
+ - :obj:`None` - uniform marginals.
264
+ marginal_kwargs
265
+ Keyword arguments for the :meth:`~moscot.base.problems.OTProblem.estimate_marginals` method.
241
266
 
242
267
  Returns
243
268
  -------
@@ -264,13 +289,18 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
264
289
  # when refactoring the callback, consider changing this
265
290
  self._problem_manager = ProblemManager(self, policy=policy)
266
291
  problems = self._create_problems(
267
- xy_callback=xy_callback,
292
+ x=x,
293
+ y=y,
294
+ xy=xy,
295
+ a=a,
296
+ b=b,
268
297
  x_callback=x_callback,
269
298
  y_callback=y_callback,
270
- xy_callback_kwargs=xy_callback_kwargs,
299
+ xy_callback=xy_callback,
271
300
  x_callback_kwargs=x_callback_kwargs,
272
301
  y_callback_kwargs=y_callback_kwargs,
273
- **kwargs,
302
+ xy_callback_kwargs=xy_callback_kwargs,
303
+ marginal_kwargs=marginal_kwargs,
274
304
  )
275
305
  self._problem_manager.add_problems(problems)
276
306
 
@@ -313,6 +343,11 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
313
343
  problems = self._problem_manager.get_problems(stage=stage)
314
344
 
315
345
  logger.info(f"Solving `{len(problems)}` problems")
346
+ # expose min/max iterations to the user but remove them if they are None
347
+ if "min_iterations" in kwargs and kwargs["min_iterations"] is None:
348
+ kwargs.pop("min_iterations")
349
+ if "max_iterations" in kwargs and kwargs["max_iterations"] is None:
350
+ kwargs.pop("max_iterations")
316
351
  for problem in problems.values():
317
352
  logger.info(f"Solving problem {problem}.")
318
353
  _ = problem.solve(**kwargs)
@@ -587,14 +622,14 @@ class CompoundProblem(BaseCompoundProblem[K, B], abc.ABC):
587
622
  linear_cost_matrix = data[mask, :][:, mask_2]
588
623
  if sp.issparse(linear_cost_matrix):
589
624
  logger.warning("Linear cost matrix being densified.")
590
- linear_cost_matrix = linear_cost_matrix.A
625
+ linear_cost_matrix = linear_cost_matrix.toarray()
591
626
  return TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)
592
627
 
593
628
  if term in ("x", "y"):
594
629
  quad_cost_matrix = data[mask, :][:, mask]
595
630
  if sp.issparse(quad_cost_matrix):
596
631
  logger.warning("Quadratic cost matrix being densified.")
597
- quad_cost_matrix = quad_cost_matrix.A
632
+ quad_cost_matrix = quad_cost_matrix.toarray()
598
633
  return TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)
599
634
 
600
635
  raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")