pyepo 2.2.4__tar.gz → 2.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyepo-2.2.4 → pyepo-2.2.5}/PKG-INFO +1 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/dataset.py +26 -2
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/cave.py +2 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/opt.py +46 -18
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/compile.py +0 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/ortcpmodel.py +0 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/ortmodel.py +0 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/utils.py +1 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/PKG-INFO +1 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/pyproject.toml +1 -1
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_10_utils.py +92 -2
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_40_dataset.py +49 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/LICENSE +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/README.md +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/EPO.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/_validation.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/compiled.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/expression.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/objective.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/problem.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/_common.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/abcmodule.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/blackbox.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/contrastive.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/abcmodule.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/blackbox.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/cave.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/contrastive.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/perturbed.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/rank.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/regularized.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/surrogate.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/utils.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/perturbed.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/rank.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/regularized.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/runtime.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/surrogate.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/utils.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/_common.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/metrics.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/mse.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/regret.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/unambregret.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/_common.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/bases.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/compile.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/coptmodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/vrp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/compile.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/grbmodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/vrp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/compile.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/mpaxmodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/compile.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/omomodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/vrp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/predefined.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/utils.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/py.typed +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/twostage/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/twostage/autosklearnpred.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/twostage/sklearnpred.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/SOURCES.txt +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/dependency_links.txt +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/requires.txt +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/top_level.txt +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/setup.cfg +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_00_constants.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_15_dsl.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_20_data_gen.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_30_model.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_50_func.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_55_jax.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_60_metric.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_61_metric_validation.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_70_twostage.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_80_integration.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_85_backend_pipeline.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_90_cuda.py +0 -0
|
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, cast
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
13
13
|
from scipy.spatial import distance
|
|
14
|
-
from torch.utils.data import Dataset
|
|
14
|
+
from torch.utils.data import DataLoader, Dataset
|
|
15
15
|
from tqdm import tqdm
|
|
16
16
|
|
|
17
17
|
from pyepo import EPO
|
|
@@ -315,7 +315,8 @@ class optDatasetConstrs(optDataset):
|
|
|
315
315
|
currently requires a Gurobi-backed ``optModel``.
|
|
316
316
|
|
|
317
317
|
Per-instance row counts differ (different constraints bind at different
|
|
318
|
-
vertices), so
|
|
318
|
+
vertices), so batch with ``optDataLoader`` or pass
|
|
319
|
+
``collate_tight_constraints`` to a PyTorch ``DataLoader``.
|
|
319
320
|
|
|
320
321
|
Reference: Tang & Khalil (2024)
|
|
321
322
|
`<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
|
|
@@ -455,6 +456,29 @@ def collate_tight_constraints(batch):
|
|
|
455
456
|
)
|
|
456
457
|
|
|
457
458
|
|
|
459
|
+
# optDatasetConstrs yields ragged binding-constraint matrices; pad them when batching
|
|
460
|
+
optDatasetConstrs.collate_fn = staticmethod(collate_tight_constraints)
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
class optDataLoader(DataLoader):
|
|
464
|
+
"""
|
|
465
|
+
``DataLoader`` that applies a dataset's own ``collate_fn`` when present.
|
|
466
|
+
|
|
467
|
+
Datasets with ragged samples (e.g. ``optDatasetConstrs``, whose
|
|
468
|
+
binding-constraint matrices vary in row count) carry a ``collate_fn``; this
|
|
469
|
+
loader uses it so the caller never passes one explicitly. Plain datasets
|
|
470
|
+
fall back to the default PyTorch collation.
|
|
471
|
+
"""
|
|
472
|
+
|
|
473
|
+
def __init__(self, dataset, *args, **kwargs):
|
|
474
|
+
# use the dataset's own collate_fn unless the caller supplies one
|
|
475
|
+
if len(args) <= 5 and "collate_fn" not in kwargs:
|
|
476
|
+
collate_fn = getattr(dataset, "collate_fn", None)
|
|
477
|
+
if collate_fn is not None:
|
|
478
|
+
kwargs["collate_fn"] = collate_fn
|
|
479
|
+
super().__init__(dataset, *args, **kwargs)
|
|
480
|
+
|
|
481
|
+
|
|
458
482
|
def _extract_tight_normals(
|
|
459
483
|
model: optModel,
|
|
460
484
|
sol: np.ndarray,
|
|
@@ -59,7 +59,8 @@ class coneAlignedCosine(optModule):
|
|
|
59
59
|
cutting the per-epoch cost without measurable regret loss.
|
|
60
60
|
|
|
61
61
|
Training data must come from ``pyepo.data.dataset.optDatasetConstrs``
|
|
62
|
-
(Gurobi-backed) and be
|
|
62
|
+
(Gurobi-backed) and be batched with ``pyepo.data.dataset.optDataLoader``
|
|
63
|
+
or a ``DataLoader`` using ``collate_tight_constraints``.
|
|
63
64
|
|
|
64
65
|
Reference: Tang & Khalil (2024)
|
|
65
66
|
`<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
|
|
@@ -27,41 +27,63 @@ class ModelSpec:
|
|
|
27
27
|
"""Serializable recipe for building a fresh optimization model."""
|
|
28
28
|
|
|
29
29
|
model_type: type[optModel]
|
|
30
|
+
_args: tuple
|
|
30
31
|
_config: dict
|
|
31
32
|
|
|
32
|
-
def __init__(
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
model_type: type[optModel],
|
|
36
|
+
config: dict,
|
|
37
|
+
args: tuple = (),
|
|
38
|
+
) -> None:
|
|
33
39
|
object.__setattr__(self, "model_type", model_type)
|
|
40
|
+
object.__setattr__(self, "_args", deepcopy(args))
|
|
34
41
|
object.__setattr__(self, "_config", deepcopy(config))
|
|
35
42
|
|
|
43
|
+
@property
|
|
44
|
+
def args(self) -> tuple:
|
|
45
|
+
"""Return an independent copy of positional constructor arguments."""
|
|
46
|
+
return deepcopy(self._args)
|
|
47
|
+
|
|
36
48
|
@property
|
|
37
49
|
def config(self) -> dict:
|
|
38
|
-
"""Return an independent copy of
|
|
50
|
+
"""Return an independent copy of keyword constructor arguments."""
|
|
39
51
|
return deepcopy(self._config)
|
|
40
52
|
|
|
41
53
|
def build(self) -> optModel:
|
|
42
54
|
"""Build a fresh model without sharing mutable configuration values."""
|
|
43
|
-
return self.model_type.from_config(self._config)
|
|
55
|
+
return self.model_type.from_config(self._config, self._args)
|
|
56
|
+
|
|
44
57
|
|
|
58
|
+
def _snapshot(value):
|
|
59
|
+
"""Deep-copy a constructor argument, keeping the reference if it cannot be copied."""
|
|
60
|
+
try:
|
|
61
|
+
return deepcopy(value)
|
|
62
|
+
except Exception: # noqa: BLE001 -- any copy failure falls back to a reference
|
|
63
|
+
return value
|
|
45
64
|
|
|
46
|
-
|
|
47
|
-
|
|
65
|
+
|
|
66
|
+
def _capture_init_config(init, args, kwargs) -> tuple[tuple, dict]:
|
|
67
|
+
"""Flatten a constructor call into arguments that rebuild the model."""
|
|
48
68
|
sig = inspect.signature(init)
|
|
49
69
|
bound = sig.bind(None, *args, **kwargs)
|
|
70
|
+
init_args = []
|
|
50
71
|
config = {}
|
|
51
72
|
for i, (name, value) in enumerate(bound.arguments.items()):
|
|
52
73
|
# the first bound argument is self
|
|
53
74
|
if i == 0:
|
|
54
75
|
continue
|
|
55
76
|
kind = sig.parameters[name].kind
|
|
56
|
-
#
|
|
77
|
+
# snapshot each argument; values that cannot be deep-copied keep a reference
|
|
57
78
|
if kind is inspect.Parameter.VAR_KEYWORD:
|
|
58
|
-
config.update(
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
79
|
+
config.update({k: _snapshot(v) for k, v in value.items()})
|
|
80
|
+
elif kind is inspect.Parameter.VAR_POSITIONAL:
|
|
81
|
+
init_args.extend(_snapshot(v) for v in value)
|
|
82
|
+
elif kind is inspect.Parameter.POSITIONAL_ONLY:
|
|
83
|
+
init_args.append(_snapshot(value))
|
|
62
84
|
else:
|
|
63
|
-
config[name] =
|
|
64
|
-
return config
|
|
85
|
+
config[name] = _snapshot(value)
|
|
86
|
+
return tuple(init_args), config
|
|
65
87
|
|
|
66
88
|
|
|
67
89
|
class optModel(ABC):
|
|
@@ -93,8 +115,10 @@ class optModel(ABC):
|
|
|
93
115
|
|
|
94
116
|
def __init_subclass__(cls, **kwargs) -> None:
|
|
95
117
|
super().__init_subclass__(**kwargs)
|
|
96
|
-
#
|
|
97
|
-
|
|
118
|
+
# Only wrap subclasses that define their own __init__ and use the
|
|
119
|
+
# default reconstruction config. Custom get_config implementations may
|
|
120
|
+
# intentionally accept objects that are not deepcopyable/rebuildable.
|
|
121
|
+
if "__init__" not in cls.__dict__ or "get_config" in cls.__dict__:
|
|
98
122
|
return
|
|
99
123
|
user_init = cls.__init__
|
|
100
124
|
|
|
@@ -102,7 +126,7 @@ class optModel(ABC):
|
|
|
102
126
|
def _init_capturing(self, *args, **kwargs):
|
|
103
127
|
# record only the outermost call; nested super().__init__ leaves it intact
|
|
104
128
|
if "_init_config" not in self.__dict__:
|
|
105
|
-
self._init_config = _capture_init_config(user_init, args, kwargs)
|
|
129
|
+
self._init_args, self._init_config = _capture_init_config(user_init, args, kwargs)
|
|
106
130
|
user_init(self, *args, **kwargs)
|
|
107
131
|
|
|
108
132
|
cls.__init__ = _init_capturing
|
|
@@ -124,13 +148,17 @@ class optModel(ABC):
|
|
|
124
148
|
return deepcopy(self.__dict__.get("_init_config", {}))
|
|
125
149
|
|
|
126
150
|
@classmethod
|
|
127
|
-
def from_config(cls, config: dict) -> Self:
|
|
151
|
+
def from_config(cls, config: dict, args: tuple = ()) -> Self:
|
|
128
152
|
"""Build a model from a configuration produced by ``get_config``."""
|
|
129
|
-
return cls(**deepcopy(config))
|
|
153
|
+
return cls(*deepcopy(args), **deepcopy(config))
|
|
130
154
|
|
|
131
155
|
def to_spec(self) -> ModelSpec:
|
|
132
156
|
"""Return a serializable, immutable-snapshot rebuild recipe."""
|
|
133
|
-
return ModelSpec(
|
|
157
|
+
return ModelSpec(
|
|
158
|
+
type(self),
|
|
159
|
+
self.get_config(),
|
|
160
|
+
self.__dict__.get("_init_args", ()),
|
|
161
|
+
)
|
|
134
162
|
|
|
135
163
|
def rebuild(self) -> Self:
|
|
136
164
|
"""Build a structurally equivalent model with clean runtime state."""
|
|
@@ -40,7 +40,6 @@ class compiledOrtProblem(compiledBase, optOrtModel):
|
|
|
40
40
|
self.problem = deepcopy(problem)
|
|
41
41
|
self.params = dict(params) if params else {}
|
|
42
42
|
self.solver = solver
|
|
43
|
-
self._extra_constrs = [] # (coef, rhs) cuts replayed on copy
|
|
44
43
|
optModel.__init__(self) # builds the model via _getModel
|
|
45
44
|
self._model.SuppressOutput()
|
|
46
45
|
self._set_obj_sense()
|
|
@@ -53,7 +53,6 @@ class optOrtCpModel(optModel):
|
|
|
53
53
|
raise ImportError(
|
|
54
54
|
"OR-Tools is not installed. Please install ortools to use this feature."
|
|
55
55
|
)
|
|
56
|
-
self._extra_constrs = []
|
|
57
56
|
self._objective_coefs: list[int] | None = None
|
|
58
57
|
super().__init__()
|
|
59
58
|
# cache ordered Var list for batched weighted_sum / per-Var Value() loop
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "pyepo"
|
|
7
|
-
version = "2.2.
|
|
7
|
+
version = "2.2.5"
|
|
8
8
|
description = "PyTorch/JAX-based End-to-End Predict-then-Optimize Tool"
|
|
9
9
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
10
10
|
license = { text = "MIT" }
|
|
@@ -19,7 +19,7 @@ from .conftest import requires_gurobi
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class ConfigModel(optModel):
|
|
22
|
-
"""Solver-free model
|
|
22
|
+
"""Solver-free model with custom reconstruction config."""
|
|
23
23
|
|
|
24
24
|
def __init__(self, values, label="default"):
|
|
25
25
|
self.values = np.asarray(values)
|
|
@@ -61,6 +61,70 @@ class AutoConfigModel(optModel):
|
|
|
61
61
|
return np.zeros(len(self.values)), 0.0
|
|
62
62
|
|
|
63
63
|
|
|
64
|
+
class VarArgsConfigModel(optModel):
|
|
65
|
+
"""Solver-free model whose constructor needs positional replay."""
|
|
66
|
+
|
|
67
|
+
def __init__(self, *values, label="default"):
|
|
68
|
+
self.values = values
|
|
69
|
+
self.label = label
|
|
70
|
+
super().__init__()
|
|
71
|
+
|
|
72
|
+
def _getModel(self):
|
|
73
|
+
return None, list(range(len(self.values)))
|
|
74
|
+
|
|
75
|
+
def setObj(self, c):
|
|
76
|
+
self.cost = np.asarray(c)
|
|
77
|
+
|
|
78
|
+
def solve(self):
|
|
79
|
+
return np.zeros(len(self.values)), 0.0
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class PosOnlyConfigModel(optModel):
|
|
83
|
+
"""Solver-free model with a positional-only constructor argument."""
|
|
84
|
+
|
|
85
|
+
def __init__(self, values, /, label="default"):
|
|
86
|
+
self.values = values
|
|
87
|
+
self.label = label
|
|
88
|
+
super().__init__()
|
|
89
|
+
|
|
90
|
+
def _getModel(self):
|
|
91
|
+
return None, list(range(len(self.values)))
|
|
92
|
+
|
|
93
|
+
def setObj(self, c):
|
|
94
|
+
self.cost = np.asarray(c)
|
|
95
|
+
|
|
96
|
+
def solve(self):
|
|
97
|
+
return np.zeros(len(self.values)), 0.0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class _NoDeepcopy:
|
|
101
|
+
def __init__(self, values):
|
|
102
|
+
self.values = values
|
|
103
|
+
|
|
104
|
+
def __deepcopy__(self, memo):
|
|
105
|
+
raise TypeError("not deepcopyable")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class CustomConfigModel(optModel):
|
|
109
|
+
"""Custom config should control reconstruction for unusual constructor inputs."""
|
|
110
|
+
|
|
111
|
+
def __init__(self, resource=None, values=None):
|
|
112
|
+
self.values = list(values if values is not None else resource.values)
|
|
113
|
+
super().__init__()
|
|
114
|
+
|
|
115
|
+
def get_config(self):
|
|
116
|
+
return {"values": self.values.copy()}
|
|
117
|
+
|
|
118
|
+
def _getModel(self):
|
|
119
|
+
return None, list(range(len(self.values)))
|
|
120
|
+
|
|
121
|
+
def setObj(self, c):
|
|
122
|
+
self.cost = np.asarray(c)
|
|
123
|
+
|
|
124
|
+
def solve(self):
|
|
125
|
+
return np.zeros(len(self.values)), 0.0
|
|
126
|
+
|
|
127
|
+
|
|
64
128
|
# ============================================================
|
|
65
129
|
# unionFind (pure)
|
|
66
130
|
# ============================================================
|
|
@@ -179,7 +243,7 @@ class TestCostToNumpy:
|
|
|
179
243
|
|
|
180
244
|
|
|
181
245
|
# ============================================================
|
|
182
|
-
#
|
|
246
|
+
# model reconstruction config (pure)
|
|
183
247
|
# ============================================================
|
|
184
248
|
|
|
185
249
|
|
|
@@ -262,6 +326,32 @@ class TestModelSpec:
|
|
|
262
326
|
assert model.get_config()["nested"] == {"tag": ["x"]}
|
|
263
327
|
assert model.rebuild().kwargs == {"nested": {"tag": ["x"]}}
|
|
264
328
|
|
|
329
|
+
def test_auto_config_replays_varargs(self):
|
|
330
|
+
model = VarArgsConfigModel(1, 2, 3, label="x")
|
|
331
|
+
spec = model.to_spec()
|
|
332
|
+
|
|
333
|
+
assert spec.args == (1, 2, 3)
|
|
334
|
+
assert spec.config == {"label": "x"}
|
|
335
|
+
rebuilt = model.rebuild()
|
|
336
|
+
assert rebuilt.values == (1, 2, 3)
|
|
337
|
+
assert rebuilt.label == "x"
|
|
338
|
+
|
|
339
|
+
def test_auto_config_replays_positional_only_args(self):
|
|
340
|
+
model = PosOnlyConfigModel([1, 2, 3], label="x")
|
|
341
|
+
spec = model.to_spec()
|
|
342
|
+
|
|
343
|
+
assert spec.args == ([1, 2, 3],)
|
|
344
|
+
assert spec.config == {"label": "x"}
|
|
345
|
+
rebuilt = model.rebuild()
|
|
346
|
+
assert rebuilt.values == [1, 2, 3]
|
|
347
|
+
assert rebuilt.label == "x"
|
|
348
|
+
|
|
349
|
+
def test_custom_get_config_accepts_uncopyable_constructor_input(self):
|
|
350
|
+
model = CustomConfigModel(_NoDeepcopy([1, 2, 3]))
|
|
351
|
+
rebuilt = model.rebuild()
|
|
352
|
+
|
|
353
|
+
assert rebuilt.values == [1, 2, 3]
|
|
354
|
+
|
|
265
355
|
|
|
266
356
|
# ============================================================
|
|
267
357
|
# getArgs (needs a real optModel)
|
|
@@ -14,6 +14,7 @@ import torch
|
|
|
14
14
|
from pyepo.data import shortestpath
|
|
15
15
|
from pyepo.data.dataset import (
|
|
16
16
|
collate_tight_constraints,
|
|
17
|
+
optDataLoader,
|
|
17
18
|
optDataset,
|
|
18
19
|
optDatasetConstrs,
|
|
19
20
|
optDatasetKNN,
|
|
@@ -299,3 +300,51 @@ class TestCollateTightConstraints:
|
|
|
299
300
|
assert padded.shape == (2, 5, 4)
|
|
300
301
|
# the shorter matrix is zero-padded at the tail
|
|
301
302
|
assert torch.allclose(padded[0, 2:], torch.zeros(3, 4))
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class TestOptDataLoader:
|
|
306
|
+
"""Pure: optDataLoader applies a dataset's collate_fn automatically."""
|
|
307
|
+
|
|
308
|
+
def test_constrs_collate_fn_wired(self):
|
|
309
|
+
assert optDatasetConstrs.collate_fn is collate_tight_constraints
|
|
310
|
+
|
|
311
|
+
def test_uses_dataset_collate_fn(self):
|
|
312
|
+
class _Ragged(list):
|
|
313
|
+
collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
|
|
314
|
+
|
|
315
|
+
loader = optDataLoader(_Ragged([0, 1, 2, 3]), batch_size=2)
|
|
316
|
+
assert next(iter(loader)) == ("auto", 2)
|
|
317
|
+
|
|
318
|
+
def test_accepts_positional_dataloader_args(self):
|
|
319
|
+
loader = optDataLoader([0, 1, 2, 3], 2)
|
|
320
|
+
assert torch.equal(next(iter(loader)), torch.tensor([0, 1]))
|
|
321
|
+
|
|
322
|
+
def test_explicit_collate_fn_wins(self):
|
|
323
|
+
class _Ragged(list):
|
|
324
|
+
collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
|
|
325
|
+
|
|
326
|
+
loader = optDataLoader(
|
|
327
|
+
_Ragged([0, 1, 2, 3]), batch_size=2, collate_fn=lambda b: ("explicit", len(b))
|
|
328
|
+
)
|
|
329
|
+
assert next(iter(loader)) == ("explicit", 2)
|
|
330
|
+
|
|
331
|
+
def test_positional_collate_fn_wins(self):
|
|
332
|
+
class _Ragged(list):
|
|
333
|
+
collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
|
|
334
|
+
|
|
335
|
+
loader = optDataLoader(
|
|
336
|
+
_Ragged([0, 1, 2, 3]),
|
|
337
|
+
2,
|
|
338
|
+
None,
|
|
339
|
+
None,
|
|
340
|
+
None,
|
|
341
|
+
0,
|
|
342
|
+
lambda b: ("positional", len(b)),
|
|
343
|
+
)
|
|
344
|
+
assert next(iter(loader)) == ("positional", 2)
|
|
345
|
+
|
|
346
|
+
def test_plain_dataset_uses_default_collate(self):
|
|
347
|
+
data = [(torch.tensor([float(i)]),) for i in range(4)]
|
|
348
|
+
loader = optDataLoader(data, batch_size=2)
|
|
349
|
+
(batch,) = next(iter(loader))
|
|
350
|
+
assert batch.shape == (2, 1)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|