moscot 0.3.2__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of moscot might be problematic; see the registry's release notice for this version for more details.
- {moscot-0.3.2 → moscot-0.3.4}/.pre-commit-config.yaml +9 -9
- {moscot-0.3.2 → moscot-0.3.4}/PKG-INFO +33 -3
- {moscot-0.3.2 → moscot-0.3.4}/README.rst +1 -1
- {moscot-0.3.2 → moscot-0.3.4}/pyproject.toml +4 -1
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_types.py +10 -10
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/__init__.py +4 -2
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/_utils.py +25 -2
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/output.py +79 -6
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/ott/solver.py +148 -26
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/_mixins.py +199 -31
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/_utils.py +8 -1
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/birth_death.py +8 -4
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/compound_problem.py +30 -29
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/problem.py +234 -40
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/solver.py +0 -1
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/costs/_costs.py +5 -2
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/costs/_utils.py +15 -4
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/datasets.py +36 -4
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/plotting/_plotting.py +11 -11
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/plotting/_utils.py +9 -6
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/_utils.py +11 -13
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/cross_modality/_mixins.py +62 -5
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/cross_modality/_translation.py +4 -4
- moscot-0.3.4/src/moscot/problems/generic/__init__.py +4 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/generic/_generic.py +245 -32
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/generic/_mixins.py +10 -3
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/_alignment.py +11 -7
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/_mapping.py +8 -9
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/_mixins.py +123 -10
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/_spatio_temporal.py +6 -1
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/time/_lineage.py +7 -6
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/time/_mixins.py +95 -45
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/mouse_proliferation.txt +0 -1
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/subset_policy.py +7 -12
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/tagged_array.py +27 -1
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/PKG-INFO +33 -3
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/requires.txt +2 -1
- moscot-0.3.2/src/moscot/problems/generic/__init__.py +0 -4
- {moscot-0.3.2 → moscot-0.3.4}/.gitignore +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/.gitmodules +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/.readthedocs.yml +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/LICENSE +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/MANIFEST.in +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/codecov.yml +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/setup.cfg +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_constants.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_logging.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/_registry.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/backends/utils.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/cost.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/output.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/base/problems/manager.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/costs/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/plotting/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/cross_modality/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/space/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/spatiotemporal/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/problems/time/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/py.typed +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/__init__.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_dmel.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_hg38.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/allTFs_mm.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/human_apoptosis.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/human_proliferation.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/_data/mouse_apoptosis.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot/utils/data.py +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/SOURCES.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/dependency_links.txt +0 -0
- {moscot-0.3.2 → moscot-0.3.4}/src/moscot.egg-info/top_level.txt +0 -0
|
@@ -7,29 +7,29 @@ default_stages:
|
|
|
7
7
|
minimum_pre_commit_version: 3.0.0
|
|
8
8
|
repos:
|
|
9
9
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
10
|
-
rev: v1.
|
|
10
|
+
rev: v1.8.0
|
|
11
11
|
hooks:
|
|
12
12
|
- id: mypy
|
|
13
13
|
additional_dependencies: [numpy>=1.25.0]
|
|
14
14
|
files: ^src
|
|
15
15
|
- repo: https://github.com/psf/black
|
|
16
|
-
rev:
|
|
16
|
+
rev: 24.2.0
|
|
17
17
|
hooks:
|
|
18
18
|
- id: black
|
|
19
19
|
additional_dependencies: [toml]
|
|
20
20
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
|
21
|
-
rev:
|
|
21
|
+
rev: v4.0.0-alpha.8
|
|
22
22
|
hooks:
|
|
23
23
|
- id: prettier
|
|
24
24
|
language_version: system
|
|
25
25
|
- repo: https://github.com/PyCQA/isort
|
|
26
|
-
rev: 5.
|
|
26
|
+
rev: 5.13.2
|
|
27
27
|
hooks:
|
|
28
28
|
- id: isort
|
|
29
29
|
additional_dependencies: [toml]
|
|
30
30
|
args: [--order-by-type]
|
|
31
31
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
32
|
-
rev: v4.
|
|
32
|
+
rev: v4.5.0
|
|
33
33
|
hooks:
|
|
34
34
|
- id: check-merge-conflict
|
|
35
35
|
- id: check-ast
|
|
@@ -42,17 +42,17 @@ repos:
|
|
|
42
42
|
- id: check-yaml
|
|
43
43
|
- id: check-toml
|
|
44
44
|
- repo: https://github.com/asottile/pyupgrade
|
|
45
|
-
rev: v3.
|
|
45
|
+
rev: v3.15.1
|
|
46
46
|
hooks:
|
|
47
47
|
- id: pyupgrade
|
|
48
48
|
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
|
|
49
49
|
- repo: https://github.com/asottile/blacken-docs
|
|
50
|
-
rev: 1.
|
|
50
|
+
rev: 1.16.0
|
|
51
51
|
hooks:
|
|
52
52
|
- id: blacken-docs
|
|
53
53
|
additional_dependencies: [black==23.1.0]
|
|
54
54
|
- repo: https://github.com/rstcheck/rstcheck
|
|
55
|
-
rev: v6.
|
|
55
|
+
rev: v6.2.0
|
|
56
56
|
hooks:
|
|
57
57
|
- id: rstcheck
|
|
58
58
|
additional_dependencies: [tomli]
|
|
@@ -63,7 +63,7 @@ repos:
|
|
|
63
63
|
- id: doc8
|
|
64
64
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
65
65
|
# Ruff version.
|
|
66
|
-
rev: v0.
|
|
66
|
+
rev: v0.2.2
|
|
67
67
|
hooks:
|
|
68
68
|
- id: ruff
|
|
69
69
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: moscot
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.4
|
|
4
4
|
Summary: Multi-omic single-cell optimal transport tools
|
|
5
5
|
Author: Dominik Klein, Giovanni Palla, Michal Klein, Zoe Piran, Marius Lange
|
|
6
6
|
Maintainer-email: Dominik Klein <dominik.klein@helmholtz-muenchen.de>, Giovanni Palla <giovanni.palla@helmholtz-muenchen.de>, Michal Klein <michal.klein@helmholtz-muenchen.de>
|
|
@@ -56,11 +56,41 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
|
56
56
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
57
57
|
Requires-Python: >=3.8
|
|
58
58
|
Description-Content-Type: text/x-rst
|
|
59
|
+
License-File: LICENSE
|
|
60
|
+
Requires-Dist: numpy>=1.20.0
|
|
61
|
+
Requires-Dist: scipy>=1.7.0
|
|
62
|
+
Requires-Dist: pandas>=2.0.1
|
|
63
|
+
Requires-Dist: networkx>=2.6.3
|
|
64
|
+
Requires-Dist: matplotlib>=3.5.0
|
|
65
|
+
Requires-Dist: anndata>=0.9.1
|
|
66
|
+
Requires-Dist: scanpy>=1.9.3
|
|
67
|
+
Requires-Dist: wrapt>=1.13.2
|
|
68
|
+
Requires-Dist: docrep>=0.3.2
|
|
69
|
+
Requires-Dist: ott-jax>=0.4.5
|
|
70
|
+
Requires-Dist: cloudpickle>=2.2.0
|
|
71
|
+
Requires-Dist: rich>=13.5
|
|
59
72
|
Provides-Extra: spatial
|
|
73
|
+
Requires-Dist: squidpy>=1.2.3; extra == "spatial"
|
|
60
74
|
Provides-Extra: dev
|
|
75
|
+
Requires-Dist: pre-commit>=3.0.0; extra == "dev"
|
|
76
|
+
Requires-Dist: tox>=4; extra == "dev"
|
|
61
77
|
Provides-Extra: test
|
|
78
|
+
Requires-Dist: pytest>=7; extra == "test"
|
|
79
|
+
Requires-Dist: pytest-xdist>=3; extra == "test"
|
|
80
|
+
Requires-Dist: pytest-mock>=3.5.0; extra == "test"
|
|
81
|
+
Requires-Dist: pytest-cov>=4; extra == "test"
|
|
82
|
+
Requires-Dist: coverage[toml]>=7; extra == "test"
|
|
62
83
|
Provides-Extra: docs
|
|
63
|
-
|
|
84
|
+
Requires-Dist: sphinx>=5.1.1; extra == "docs"
|
|
85
|
+
Requires-Dist: sphinx_copybutton>=0.5.0; extra == "docs"
|
|
86
|
+
Requires-Dist: sphinxcontrib-bibtex>=2.3.0; extra == "docs"
|
|
87
|
+
Requires-Dist: sphinxcontrib-spelling>=7.6.2; extra == "docs"
|
|
88
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
|
|
89
|
+
Requires-Dist: furo>=2022.09.29; extra == "docs"
|
|
90
|
+
Requires-Dist: sphinx-tippy>=0.4.1; extra == "docs"
|
|
91
|
+
Requires-Dist: myst-nb>=0.17.1; extra == "docs"
|
|
92
|
+
Requires-Dist: ipython>=7.20.0; extra == "docs"
|
|
93
|
+
Requires-Dist: sphinx_design>=0.3.0; extra == "docs"
|
|
64
94
|
|
|
65
95
|
|PyPI| |Downloads| |CI| |Pre-commit| |Codecov| |Docs|
|
|
66
96
|
|
|
@@ -125,6 +155,6 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
|
|
|
125
155
|
:target: https://moscot.readthedocs.io/en/stable/
|
|
126
156
|
:alt: Documentation
|
|
127
157
|
|
|
128
|
-
.. |Downloads| image:: https://pepy.tech/badge/moscot
|
|
158
|
+
.. |Downloads| image:: https://static.pepy.tech/badge/moscot
|
|
129
159
|
:target: https://pepy.tech/project/moscot
|
|
130
160
|
:alt: Downloads
|
|
@@ -61,6 +61,6 @@ Our preprint "Mapping cells through time and space with moscot" can be found `he
|
|
|
61
61
|
:target: https://moscot.readthedocs.io/en/stable/
|
|
62
62
|
:alt: Documentation
|
|
63
63
|
|
|
64
|
-
.. |Downloads| image:: https://pepy.tech/badge/moscot
|
|
64
|
+
.. |Downloads| image:: https://static.pepy.tech/badge/moscot
|
|
65
65
|
:target: https://pepy.tech/project/moscot
|
|
66
66
|
:alt: Downloads
|
|
@@ -55,8 +55,9 @@ dependencies = [
|
|
|
55
55
|
"scanpy>=1.9.3",
|
|
56
56
|
"wrapt>=1.13.2",
|
|
57
57
|
"docrep>=0.3.2",
|
|
58
|
-
"ott-jax>=0.4.
|
|
58
|
+
"ott-jax>=0.4.5",
|
|
59
59
|
"cloudpickle>=2.2.0",
|
|
60
|
+
"rich>=13.5",
|
|
60
61
|
]
|
|
61
62
|
|
|
62
63
|
[project.optional-dependencies]
|
|
@@ -121,6 +122,8 @@ ignore = [
|
|
|
121
122
|
"D107",
|
|
122
123
|
# Missing docstring in magic method
|
|
123
124
|
"D105",
|
|
125
|
+
# Use `X | Y` for type annotations
|
|
126
|
+
"UP007",
|
|
124
127
|
]
|
|
125
128
|
line-length = 120
|
|
126
129
|
select = [
|
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
|
8
8
|
try:
|
|
9
9
|
from numpy.typing import DTypeLike, NDArray
|
|
10
10
|
|
|
11
|
-
ArrayLike = NDArray[np.
|
|
11
|
+
ArrayLike = NDArray[np.float64]
|
|
12
12
|
except (ImportError, TypeError):
|
|
13
13
|
ArrayLike = np.ndarray # type: ignore[misc]
|
|
14
14
|
DTypeLike = np.dtype # type: ignore[misc]
|
|
@@ -33,20 +33,20 @@ OttCostFn_t = Literal[
|
|
|
33
33
|
"euclidean",
|
|
34
34
|
"sq_euclidean",
|
|
35
35
|
"cosine",
|
|
36
|
-
"
|
|
37
|
-
"
|
|
38
|
-
"
|
|
39
|
-
"
|
|
40
|
-
"
|
|
41
|
-
"
|
|
42
|
-
"
|
|
43
|
-
"
|
|
36
|
+
"pnorm_p",
|
|
37
|
+
"sq_pnorm",
|
|
38
|
+
"cosine",
|
|
39
|
+
"elastic_l1",
|
|
40
|
+
"elastic_l2",
|
|
41
|
+
"elastic_stvs",
|
|
42
|
+
"elastic_sqk_overlap",
|
|
43
|
+
"geodesic",
|
|
44
44
|
]
|
|
45
45
|
OttCostFnMap_t = Union[OttCostFn_t, Mapping[Literal["xy", "x", "y"], OttCostFn_t]]
|
|
46
46
|
GenericCostFn_t = Literal["barcode_distance", "leaf_distance", "custom"]
|
|
47
47
|
CostFn_t = Union[str, GenericCostFn_t, OttCostFn_t]
|
|
48
48
|
CostFnMap_t = Union[Union[OttCostFn_t, GenericCostFn_t], Mapping[str, Union[OttCostFn_t, GenericCostFn_t]]]
|
|
49
|
-
PathLike = Union[os.PathLike, str]
|
|
49
|
+
PathLike = Union[os.PathLike, str] # type: ignore[type-arg]
|
|
50
50
|
Policy_t = Literal[
|
|
51
51
|
"sequential",
|
|
52
52
|
"star",
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from ott.geometry import costs
|
|
2
2
|
|
|
3
3
|
from moscot.backends.ott._utils import sinkhorn_divergence
|
|
4
|
-
from moscot.backends.ott.output import OTTOutput
|
|
4
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
5
5
|
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
|
|
6
6
|
from moscot.costs import register_cost
|
|
7
7
|
|
|
8
|
-
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
8
|
+
__all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
|
|
9
9
|
|
|
10
10
|
register_cost("euclidean", backend="ott")(costs.Euclidean)
|
|
11
11
|
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
|
|
@@ -13,4 +13,6 @@ register_cost("cosine", backend="ott")(costs.Cosine)
|
|
|
13
13
|
register_cost("pnorm_p", backend="ott")(costs.PNormP)
|
|
14
14
|
register_cost("sq_pnorm", backend="ott")(costs.SqPNorm)
|
|
15
15
|
register_cost("elastic_l1", backend="ott")(costs.ElasticL1)
|
|
16
|
+
register_cost("elastic_l2", backend="ott")(costs.ElasticL2)
|
|
16
17
|
register_cost("elastic_stvs", backend="ott")(costs.ElasticSTVS)
|
|
18
|
+
register_cost("elastic_sqk_overlap", backend="ott")(costs.ElasticSqKOverlap)
|
|
@@ -1,14 +1,17 @@
|
|
|
1
|
-
from typing import Any, Optional, Union
|
|
1
|
+
from typing import Any, Literal, Optional, Tuple, Union
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
4
|
import jax.numpy as jnp
|
|
5
5
|
import scipy.sparse as sp
|
|
6
|
-
from ott.geometry import epsilon_scheduler, geometry, pointcloud
|
|
6
|
+
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
|
|
7
7
|
from ott.tools import sinkhorn_divergence as sdiv
|
|
8
8
|
|
|
9
9
|
from moscot._logging import logger
|
|
10
10
|
from moscot._types import ArrayLike, ScaleCost_t
|
|
11
11
|
|
|
12
|
+
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
|
|
13
|
+
|
|
14
|
+
|
|
12
15
|
__all__ = ["sinkhorn_divergence"]
|
|
13
16
|
|
|
14
17
|
|
|
@@ -86,3 +89,23 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
|
|
|
86
89
|
if arr.ndim != 2:
|
|
87
90
|
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
|
|
88
91
|
return arr
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _instantiate_geodesic_cost(
|
|
95
|
+
arr: jax.Array,
|
|
96
|
+
problem_shape: Tuple[int, int],
|
|
97
|
+
t: Optional[float],
|
|
98
|
+
is_linear_term: bool,
|
|
99
|
+
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
|
|
100
|
+
relative_epsilon: Optional[bool] = None,
|
|
101
|
+
scale_cost: Scale_t = 1.0,
|
|
102
|
+
directed: bool = True,
|
|
103
|
+
**kwargs: Any,
|
|
104
|
+
) -> geometry.Geometry:
|
|
105
|
+
n_src, n_tgt = problem_shape
|
|
106
|
+
if is_linear_term and n_src + n_tgt != arr.shape[0]:
|
|
107
|
+
raise ValueError(f"Expected `x` to have `{n_src + n_tgt}` points, found `{arr.shape[0]}`.")
|
|
108
|
+
t = epsilon / 4.0 if t is None else t
|
|
109
|
+
cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix
|
|
110
|
+
cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full
|
|
111
|
+
return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost)
|
|
@@ -6,7 +6,7 @@ import jax
|
|
|
6
6
|
import jax.numpy as jnp
|
|
7
7
|
import numpy as np
|
|
8
8
|
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
|
9
|
-
from ott.solvers.quadratic import gromov_wasserstein
|
|
9
|
+
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
|
|
10
10
|
|
|
11
11
|
import matplotlib as mpl
|
|
12
12
|
import matplotlib.pyplot as plt
|
|
@@ -14,7 +14,7 @@ import matplotlib.pyplot as plt
|
|
|
14
14
|
from moscot._types import ArrayLike, Device_t
|
|
15
15
|
from moscot.base.output import BaseSolverOutput
|
|
16
16
|
|
|
17
|
-
__all__ = ["OTTOutput"]
|
|
17
|
+
__all__ = ["OTTOutput", "GraphOTTOutput"]
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class OTTOutput(BaseSolverOutput):
|
|
@@ -29,7 +29,13 @@ class OTTOutput(BaseSolverOutput):
|
|
|
29
29
|
_NOT_COMPUTED = -1.0 # sentinel value used in `ott`
|
|
30
30
|
|
|
31
31
|
def __init__(
|
|
32
|
-
self,
|
|
32
|
+
self,
|
|
33
|
+
output: Union[
|
|
34
|
+
sinkhorn.SinkhornOutput,
|
|
35
|
+
sinkhorn_lr.LRSinkhornOutput,
|
|
36
|
+
gromov_wasserstein.GWOutput,
|
|
37
|
+
gromov_wasserstein_lr.LRGWOutput,
|
|
38
|
+
],
|
|
33
39
|
):
|
|
34
40
|
super().__init__()
|
|
35
41
|
self._output = output
|
|
@@ -168,7 +174,10 @@ class OTTOutput(BaseSolverOutput):
|
|
|
168
174
|
def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
|
|
169
175
|
if x.ndim == 1:
|
|
170
176
|
return self._output.apply(x, axis=1 - forward)
|
|
171
|
-
return self._output.apply(
|
|
177
|
+
return self._output.apply(
|
|
178
|
+
x.T,
|
|
179
|
+
axis=1 - forward,
|
|
180
|
+
).T # convert to batch first
|
|
172
181
|
|
|
173
182
|
@property
|
|
174
183
|
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
@@ -218,8 +227,72 @@ class OTTOutput(BaseSolverOutput):
|
|
|
218
227
|
|
|
219
228
|
@property
|
|
220
229
|
def rank(self) -> int: # noqa: D102
|
|
221
|
-
|
|
222
|
-
return
|
|
230
|
+
output = self._output.linear_state if isinstance(self._output, gromov_wasserstein.GWOutput) else self._output
|
|
231
|
+
return (
|
|
232
|
+
len(output.g)
|
|
233
|
+
if isinstance(output, (sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein_lr.LRGWOutput))
|
|
234
|
+
else -1
|
|
235
|
+
)
|
|
223
236
|
|
|
224
237
|
def _ones(self, n: int) -> ArrayLike: # noqa: D102
|
|
225
238
|
return jnp.ones((n,))
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class GraphOTTOutput(OTTOutput):
|
|
242
|
+
"""Output of :term:`OT` problems with a graph geometry in the linear term.
|
|
243
|
+
|
|
244
|
+
Parameters
|
|
245
|
+
----------
|
|
246
|
+
output
|
|
247
|
+
Output of the :mod:`ott` backend.
|
|
248
|
+
shape
|
|
249
|
+
Shape of the problem.
|
|
250
|
+
"""
|
|
251
|
+
|
|
252
|
+
def __init__(
|
|
253
|
+
self,
|
|
254
|
+
output: Union[
|
|
255
|
+
sinkhorn.SinkhornOutput,
|
|
256
|
+
sinkhorn_lr.LRSinkhornOutput,
|
|
257
|
+
gromov_wasserstein.GWOutput,
|
|
258
|
+
gromov_wasserstein_lr.LRGWOutput,
|
|
259
|
+
],
|
|
260
|
+
shape: Tuple[int, int],
|
|
261
|
+
):
|
|
262
|
+
super().__init__(output)
|
|
263
|
+
self._shape = shape
|
|
264
|
+
|
|
265
|
+
@property
|
|
266
|
+
def shape(self) -> Tuple[int, int]: # noqa: D102
|
|
267
|
+
return self._shape
|
|
268
|
+
|
|
269
|
+
def _expand_data(self, x: jnp.ndarray, forward: bool) -> jnp.ndarray:
|
|
270
|
+
if forward:
|
|
271
|
+
shape = (self.shape[1],) if x.ndim == 1 else (self.shape[1], x.shape[1])
|
|
272
|
+
return jnp.concatenate((x, jnp.zeros(shape)))
|
|
273
|
+
shape = (self.shape[0],) if x.ndim == 1 else (self.shape[0], x.shape[1])
|
|
274
|
+
return jnp.concatenate((jnp.zeros(shape), x))
|
|
275
|
+
|
|
276
|
+
def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
|
|
277
|
+
x_expanded = self._expand_data(x, forward=forward)
|
|
278
|
+
# ott-jax only supports lse_mode=False with graph geometry
|
|
279
|
+
res = self._output.apply(x_expanded.T, axis=1 - forward, lse_mode=False).T
|
|
280
|
+
return res[len(x) :] if forward else res[: -len(x)]
|
|
281
|
+
|
|
282
|
+
def to(self, device: Optional[Device_t] = None) -> "GraphOTTOutput": # noqa: D102
|
|
283
|
+
if device is None:
|
|
284
|
+
return GraphOTTOutput(jax.device_put(self._output, device=device), shape=self.shape)
|
|
285
|
+
|
|
286
|
+
if isinstance(device, str) and ":" in device:
|
|
287
|
+
device, ix = device.split(":")
|
|
288
|
+
idx = int(ix)
|
|
289
|
+
else:
|
|
290
|
+
idx = 0
|
|
291
|
+
|
|
292
|
+
if not isinstance(device, xla_ext.Device):
|
|
293
|
+
try:
|
|
294
|
+
device = jax.devices(device)[idx]
|
|
295
|
+
except IndexError:
|
|
296
|
+
raise IndexError(f"Unable to fetch the device with `id={idx}`.") from None
|
|
297
|
+
|
|
298
|
+
return GraphOTTOutput(jax.device_put(self._output, device), shape=self.shape)
|
|
@@ -4,22 +4,34 @@ import types
|
|
|
4
4
|
from typing import Any, Literal, Mapping, Optional, Set, Tuple, Union
|
|
5
5
|
|
|
6
6
|
import jax
|
|
7
|
-
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
|
|
8
9
|
from ott.problems.linear import linear_problem
|
|
9
10
|
from ott.problems.quadratic import quadratic_problem
|
|
10
11
|
from ott.solvers.linear import sinkhorn, sinkhorn_lr
|
|
11
|
-
from ott.solvers.quadratic import gromov_wasserstein
|
|
12
|
+
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
|
|
12
13
|
|
|
13
14
|
from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
|
|
14
|
-
from moscot.backends.ott._utils import
|
|
15
|
-
|
|
15
|
+
from moscot.backends.ott._utils import (
|
|
16
|
+
_instantiate_geodesic_cost,
|
|
17
|
+
alpha_to_fused_penalty,
|
|
18
|
+
check_shapes,
|
|
19
|
+
ensure_2d,
|
|
20
|
+
)
|
|
21
|
+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
|
|
22
|
+
from moscot.base.problems._utils import TimeScalesHeatKernel
|
|
16
23
|
from moscot.base.solver import OTSolver
|
|
17
24
|
from moscot.costs import get_cost
|
|
18
25
|
from moscot.utils.tagged_array import TaggedArray
|
|
19
26
|
|
|
20
27
|
__all__ = ["SinkhornSolver", "GWSolver"]
|
|
21
28
|
|
|
22
|
-
OTTSolver_t = Union[
|
|
29
|
+
OTTSolver_t = Union[
|
|
30
|
+
sinkhorn.Sinkhorn,
|
|
31
|
+
sinkhorn_lr.LRSinkhorn,
|
|
32
|
+
gromov_wasserstein.GromovWasserstein,
|
|
33
|
+
gromov_wasserstein_lr.LRGromovWasserstein,
|
|
34
|
+
]
|
|
23
35
|
OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
|
|
24
36
|
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
|
|
25
37
|
|
|
@@ -38,14 +50,21 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
38
50
|
self._solver: Optional[OTTSolver_t] = None
|
|
39
51
|
self._problem: Optional[OTTProblem_t] = None
|
|
40
52
|
self._jit = jit
|
|
53
|
+
self._a: Optional[jnp.ndarray] = None
|
|
54
|
+
self._b: Optional[jnp.ndarray] = None
|
|
41
55
|
|
|
42
56
|
def _create_geometry(
|
|
43
57
|
self,
|
|
44
58
|
x: TaggedArray,
|
|
59
|
+
*,
|
|
60
|
+
is_linear_term: bool,
|
|
45
61
|
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
|
|
46
62
|
relative_epsilon: Optional[bool] = None,
|
|
47
63
|
scale_cost: Scale_t = 1.0,
|
|
48
64
|
batch_size: Optional[int] = None,
|
|
65
|
+
problem_shape: Optional[Tuple[int, int]] = None,
|
|
66
|
+
t: Optional[float] = None,
|
|
67
|
+
directed: bool = True,
|
|
49
68
|
**kwargs: Any,
|
|
50
69
|
) -> geometry.Geometry:
|
|
51
70
|
if x.is_point_cloud:
|
|
@@ -83,17 +102,80 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
|
|
|
83
102
|
return geometry.Geometry(
|
|
84
103
|
kernel_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
|
|
85
104
|
)
|
|
105
|
+
if x.is_graph: # we currently only support this for the linear term.
|
|
106
|
+
return self._create_graph_geometry(
|
|
107
|
+
is_linear_term=is_linear_term,
|
|
108
|
+
x=x,
|
|
109
|
+
arr=arr,
|
|
110
|
+
problem_shape=problem_shape,
|
|
111
|
+
t=t,
|
|
112
|
+
epsilon=epsilon,
|
|
113
|
+
relative_epsilon=relative_epsilon,
|
|
114
|
+
scale_cost=scale_cost,
|
|
115
|
+
directed=directed,
|
|
116
|
+
**kwargs,
|
|
117
|
+
)
|
|
86
118
|
raise NotImplementedError(f"Creating geometry from `tag={x.tag!r}` is not yet implemented.")
|
|
87
119
|
|
|
88
120
|
def _solve( # type: ignore[override]
|
|
89
121
|
self,
|
|
90
122
|
prob: OTTProblem_t,
|
|
91
123
|
**kwargs: Any,
|
|
92
|
-
) -> OTTOutput:
|
|
124
|
+
) -> Union[OTTOutput, GraphOTTOutput]:
|
|
93
125
|
solver = jax.jit(self.solver) if self._jit else self.solver
|
|
94
126
|
out = solver(prob, **kwargs)
|
|
127
|
+
if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic):
|
|
128
|
+
return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type]
|
|
95
129
|
return OTTOutput(out)
|
|
96
130
|
|
|
131
|
+
def _create_graph_geometry(
|
|
132
|
+
self,
|
|
133
|
+
is_linear_term: bool,
|
|
134
|
+
x: TaggedArray,
|
|
135
|
+
arr: jax.Array,
|
|
136
|
+
problem_shape: Optional[Tuple[int, int]],
|
|
137
|
+
t: Optional[float],
|
|
138
|
+
epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
|
|
139
|
+
relative_epsilon: Optional[bool] = None,
|
|
140
|
+
scale_cost: Scale_t = 1.0,
|
|
141
|
+
directed: bool = True,
|
|
142
|
+
**kwargs: Any,
|
|
143
|
+
) -> geometry.Geometry:
|
|
144
|
+
if x.cost == "geodesic":
|
|
145
|
+
if self.problem_kind == "linear":
|
|
146
|
+
if t is None:
|
|
147
|
+
if epsilon is None:
|
|
148
|
+
raise ValueError("`epsilon` cannot be `None`.")
|
|
149
|
+
return geodesic.Geodesic.from_graph(arr, t=epsilon / 4.0, directed=directed, **kwargs)
|
|
150
|
+
|
|
151
|
+
return _instantiate_geodesic_cost(
|
|
152
|
+
arr=arr,
|
|
153
|
+
problem_shape=problem_shape, # type: ignore[arg-type]
|
|
154
|
+
t=t,
|
|
155
|
+
is_linear_term=True,
|
|
156
|
+
epsilon=epsilon,
|
|
157
|
+
relative_epsilon=relative_epsilon,
|
|
158
|
+
scale_cost=scale_cost,
|
|
159
|
+
directed=directed,
|
|
160
|
+
**kwargs,
|
|
161
|
+
)
|
|
162
|
+
if self.problem_kind == "quadratic":
|
|
163
|
+
problem_shape = x.shape if problem_shape is None else problem_shape
|
|
164
|
+
return _instantiate_geodesic_cost(
|
|
165
|
+
arr=arr,
|
|
166
|
+
problem_shape=problem_shape,
|
|
167
|
+
t=t,
|
|
168
|
+
is_linear_term=is_linear_term,
|
|
169
|
+
epsilon=epsilon,
|
|
170
|
+
relative_epsilon=relative_epsilon,
|
|
171
|
+
scale_cost=scale_cost,
|
|
172
|
+
directed=directed,
|
|
173
|
+
**kwargs,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
raise NotImplementedError(f"Invalid problem kind `{self.problem_kind}`.")
|
|
177
|
+
raise NotImplementedError(f"If the geometry is a graph, `cost` must be `geodesic`, found `{x.cost}`.")
|
|
178
|
+
|
|
97
179
|
@property
|
|
98
180
|
def solver(self) -> OTTSolver_t:
|
|
99
181
|
""":mod:`ott` solver."""
|
|
@@ -160,6 +242,8 @@ class SinkhornSolver(OTTJaxSolver):
|
|
|
160
242
|
|
|
161
243
|
def _prepare(
|
|
162
244
|
self,
|
|
245
|
+
a: jnp.ndarray,
|
|
246
|
+
b: jnp.ndarray,
|
|
163
247
|
xy: Optional[TaggedArray] = None,
|
|
164
248
|
x: Optional[TaggedArray] = None,
|
|
165
249
|
y: Optional[TaggedArray] = None,
|
|
@@ -170,24 +254,35 @@ class SinkhornSolver(OTTJaxSolver):
|
|
|
170
254
|
scale_cost: Scale_t = 1.0,
|
|
171
255
|
cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
172
256
|
cost_matrix_rank: Optional[int] = None,
|
|
257
|
+
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
|
|
173
258
|
# problem
|
|
174
259
|
**kwargs: Any,
|
|
175
260
|
) -> linear_problem.LinearProblem:
|
|
176
261
|
del x, y
|
|
262
|
+
time_scales_heat_kernel = (
|
|
263
|
+
TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
|
|
264
|
+
)
|
|
177
265
|
if xy is None:
|
|
178
266
|
raise ValueError(f"Unable to create geometry from `xy={xy}`.")
|
|
179
|
-
|
|
267
|
+
self._a = a
|
|
268
|
+
self._b = b
|
|
180
269
|
geom = self._create_geometry(
|
|
181
270
|
xy,
|
|
271
|
+
is_linear_term=True,
|
|
182
272
|
epsilon=epsilon,
|
|
183
273
|
relative_epsilon=relative_epsilon,
|
|
184
274
|
batch_size=batch_size,
|
|
275
|
+
problem_shape=(len(self._a), len(self._b)),
|
|
185
276
|
scale_cost=scale_cost,
|
|
277
|
+
t=time_scales_heat_kernel.xy,
|
|
186
278
|
**cost_kwargs,
|
|
187
279
|
)
|
|
188
280
|
if cost_matrix_rank is not None:
|
|
189
281
|
geom = geom.to_LRCGeometry(rank=cost_matrix_rank)
|
|
190
|
-
|
|
282
|
+
if isinstance(geom, geodesic.Geodesic):
|
|
283
|
+
a = jnp.concatenate((a, jnp.zeros_like(self._b)), axis=0)
|
|
284
|
+
b = jnp.concatenate((jnp.zeros_like(self._a), b), axis=0)
|
|
285
|
+
self._problem = linear_problem.LinearProblem(geom, a=a, b=b, **kwargs)
|
|
191
286
|
return self._problem
|
|
192
287
|
|
|
193
288
|
@property
|
|
@@ -201,7 +296,15 @@ class SinkhornSolver(OTTJaxSolver):
|
|
|
201
296
|
|
|
202
297
|
@classmethod
|
|
203
298
|
def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
|
|
204
|
-
geom_kwargs = {
|
|
299
|
+
geom_kwargs = {
|
|
300
|
+
"epsilon",
|
|
301
|
+
"relative_epsilon",
|
|
302
|
+
"batch_size",
|
|
303
|
+
"scale_cost",
|
|
304
|
+
"cost_kwargs",
|
|
305
|
+
"cost_matrix_rank",
|
|
306
|
+
"t",
|
|
307
|
+
}
|
|
205
308
|
problem_kwargs = set(inspect.signature(linear_problem.LinearProblem).parameters.keys())
|
|
206
309
|
problem_kwargs -= {"geom"}
|
|
207
310
|
return geom_kwargs | problem_kwargs, {"epsilon"}
|
|
@@ -243,24 +346,30 @@ class GWSolver(OTTJaxSolver):
|
|
|
243
346
|
):
|
|
244
347
|
super().__init__(jit=jit)
|
|
245
348
|
if rank > -1:
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
linear_solver_kwargs.setdefault("gamma_rescale", True)
|
|
249
|
-
linear_ot_solver = sinkhorn_lr.LRSinkhorn(rank=rank, **linear_solver_kwargs)
|
|
349
|
+
kwargs.setdefault("gamma", 10)
|
|
350
|
+
kwargs.setdefault("gamma_rescale", True)
|
|
250
351
|
initializer = "rank2" if initializer is None else initializer
|
|
352
|
+
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
|
|
353
|
+
rank=rank,
|
|
354
|
+
initializer=initializer,
|
|
355
|
+
kwargs_init=initializer_kwargs,
|
|
356
|
+
**kwargs,
|
|
357
|
+
)
|
|
251
358
|
else:
|
|
252
359
|
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
|
|
253
360
|
initializer = None
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
361
|
+
self._solver = gromov_wasserstein.GromovWasserstein(
|
|
362
|
+
rank=rank,
|
|
363
|
+
linear_ot_solver=linear_ot_solver,
|
|
364
|
+
quad_initializer=initializer,
|
|
365
|
+
kwargs_init=initializer_kwargs,
|
|
366
|
+
**kwargs,
|
|
367
|
+
)
|
|
261
368
|
|
|
262
369
|
def _prepare(
|
|
263
370
|
self,
|
|
371
|
+
a: jnp.ndarray,
|
|
372
|
+
b: jnp.ndarray,
|
|
264
373
|
xy: Optional[TaggedArray] = None,
|
|
265
374
|
x: Optional[TaggedArray] = None,
|
|
266
375
|
y: Optional[TaggedArray] = None,
|
|
@@ -271,32 +380,45 @@ class GWSolver(OTTJaxSolver):
|
|
|
271
380
|
scale_cost: Scale_t = 1.0,
|
|
272
381
|
cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
|
|
273
382
|
cost_matrix_rank: Optional[int] = None,
|
|
383
|
+
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
|
|
274
384
|
# problem
|
|
275
385
|
alpha: float = 0.5,
|
|
276
386
|
**kwargs: Any,
|
|
277
387
|
) -> quadratic_problem.QuadraticProblem:
|
|
388
|
+
self._a = a
|
|
389
|
+
self._b = b
|
|
390
|
+
time_scales_heat_kernel = (
|
|
391
|
+
TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
|
|
392
|
+
)
|
|
278
393
|
if x is None or y is None:
|
|
279
394
|
raise ValueError(f"Unable to create geometry from `x={x}`, `y={y}`.")
|
|
280
|
-
geom_kwargs: Any = {
|
|
395
|
+
geom_kwargs: dict[str, Any] = {
|
|
281
396
|
"epsilon": epsilon,
|
|
282
397
|
"relative_epsilon": relative_epsilon,
|
|
283
398
|
"batch_size": batch_size,
|
|
284
399
|
"scale_cost": scale_cost,
|
|
285
|
-
"cost_matrix_rank": cost_matrix_rank,
|
|
286
400
|
**cost_kwargs,
|
|
287
401
|
}
|
|
288
|
-
|
|
289
|
-
|
|
402
|
+
if cost_matrix_rank is not None:
|
|
403
|
+
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
|
|
404
|
+
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
|
|
405
|
+
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
|
|
290
406
|
if alpha == 1.0 or xy is None: # GW
|
|
291
407
|
# arbitrary fused penalty; must be positive
|
|
292
408
|
geom_xy, fused_penalty = None, 1.0
|
|
293
409
|
else: # FGW
|
|
294
410
|
fused_penalty = alpha_to_fused_penalty(alpha)
|
|
295
|
-
geom_xy = self._create_geometry(
|
|
411
|
+
geom_xy = self._create_geometry(
|
|
412
|
+
xy,
|
|
413
|
+
t=time_scales_heat_kernel.xy,
|
|
414
|
+
problem_shape=(x.shape[0], y.shape[0]),
|
|
415
|
+
is_linear_term=True,
|
|
416
|
+
**geom_kwargs,
|
|
417
|
+
)
|
|
296
418
|
check_shapes(geom_xx, geom_yy, geom_xy)
|
|
297
419
|
|
|
298
420
|
self._problem = quadratic_problem.QuadraticProblem(
|
|
299
|
-
geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, **kwargs
|
|
421
|
+
geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, a=self._a, b=self._b, **kwargs
|
|
300
422
|
)
|
|
301
423
|
return self._problem
|
|
302
424
|
|