pyepo 2.2.3__tar.gz → 2.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyepo-2.2.3 → pyepo-2.2.5}/PKG-INFO +3 -3
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/dataset.py +26 -2
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/cave.py +2 -1
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/bases.py +0 -24
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/grbmodel.py +2 -1
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/tsp.py +0 -3
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/vrp.py +0 -3
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/opt.py +75 -13
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/compile.py +0 -1
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/ortcpmodel.py +0 -1
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/ortmodel.py +0 -1
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/utils.py +1 -1
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/PKG-INFO +3 -3
- {pyepo-2.2.3 → pyepo-2.2.5}/pyproject.toml +3 -2
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_10_utils.py +132 -2
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_40_dataset.py +49 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/LICENSE +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/README.md +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/EPO.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/_validation.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/knapsack.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/portfolio.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/shortestpath.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/tsp.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/compiled.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/expression.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/objective.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/problem.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/_common.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/abcmodule.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/blackbox.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/contrastive.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/abcmodule.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/blackbox.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/cave.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/contrastive.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/perturbed.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/rank.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/regularized.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/surrogate.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/utils.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/perturbed.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/rank.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/regularized.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/runtime.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/surrogate.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/utils.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/_common.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/metrics.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/mse.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/regret.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/unambregret.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/_common.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/compile.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/coptmodel.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/knapsack.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/portfolio.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/shortestpath.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/tsp.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/vrp.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/compile.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/knapsack.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/portfolio.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/shortestpath.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/compile.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/knapsack.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/mpaxmodel.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/shortestpath.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/compile.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/knapsack.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/omomodel.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/portfolio.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/shortestpath.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/tsp.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/vrp.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/knapsack.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/shortestpath.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/predefined.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/utils.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/py.typed +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/twostage/__init__.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/twostage/autosklearnpred.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/twostage/sklearnpred.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/SOURCES.txt +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/dependency_links.txt +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/requires.txt +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/top_level.txt +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/setup.cfg +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_00_constants.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_15_dsl.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_20_data_gen.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_30_model.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_50_func.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_55_jax.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_60_metric.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_61_metric_validation.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_70_twostage.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_80_integration.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_85_backend_pipeline.py +0 -0
- {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_90_cuda.py +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyepo
|
|
3
|
-
Version: 2.2.3
|
|
4
|
-
Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
|
|
3
|
+
Version: 2.2.5
|
|
4
|
+
Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
|
|
5
5
|
Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
|
|
6
6
|
License: MIT
|
|
7
7
|
Project-URL: Homepage, https://github.com/khalil-research/PyEPO
|
|
@@ -9,7 +9,7 @@ Project-URL: Documentation, https://khalil-research.github.io/PyEPO
|
|
|
9
9
|
Project-URL: Repository, https://github.com/khalil-research/PyEPO
|
|
10
10
|
Project-URL: Issues, https://github.com/khalil-research/PyEPO/issues
|
|
11
11
|
Project-URL: Paper, https://link.springer.com/article/10.1007/s12532-024-00255-x
|
|
12
|
-
Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,linear programming,integer programming
|
|
12
|
+
Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,jax,linear programming,integer programming
|
|
13
13
|
Classifier: Programming Language :: Python :: 3
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.9
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.10
|
|
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, cast
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
13
13
|
from scipy.spatial import distance
|
|
14
|
-
from torch.utils.data import Dataset
|
|
14
|
+
from torch.utils.data import DataLoader, Dataset
|
|
15
15
|
from tqdm import tqdm
|
|
16
16
|
|
|
17
17
|
from pyepo import EPO
|
|
@@ -315,7 +315,8 @@ class optDatasetConstrs(optDataset):
|
|
|
315
315
|
currently requires a Gurobi-backed ``optModel``.
|
|
316
316
|
|
|
317
317
|
Per-instance row counts differ (different constraints bind at different
|
|
318
|
-
vertices), so
|
|
318
|
+
vertices), so batch with ``optDataLoader`` or pass
|
|
319
|
+
``collate_tight_constraints`` to a PyTorch ``DataLoader``.
|
|
319
320
|
|
|
320
321
|
Reference: Tang & Khalil (2024)
|
|
321
322
|
`<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
|
|
@@ -455,6 +456,29 @@ def collate_tight_constraints(batch):
|
|
|
455
456
|
)
|
|
456
457
|
|
|
457
458
|
|
|
459
|
+
# optDatasetConstrs yields ragged binding-constraint matrices; pad them when batching
|
|
460
|
+
optDatasetConstrs.collate_fn = staticmethod(collate_tight_constraints)
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
class optDataLoader(DataLoader):
|
|
464
|
+
"""
|
|
465
|
+
``DataLoader`` that applies a dataset's own ``collate_fn`` when present.
|
|
466
|
+
|
|
467
|
+
Datasets with ragged samples (e.g. ``optDatasetConstrs``, whose
|
|
468
|
+
binding-constraint matrices vary in row count) carry a ``collate_fn``; this
|
|
469
|
+
loader uses it so the caller never passes one explicitly. Plain datasets
|
|
470
|
+
fall back to the default PyTorch collation.
|
|
471
|
+
"""
|
|
472
|
+
|
|
473
|
+
def __init__(self, dataset, *args, **kwargs):
|
|
474
|
+
# use the dataset's own collate_fn unless the caller supplies one
|
|
475
|
+
if len(args) <= 5 and "collate_fn" not in kwargs:
|
|
476
|
+
collate_fn = getattr(dataset, "collate_fn", None)
|
|
477
|
+
if collate_fn is not None:
|
|
478
|
+
kwargs["collate_fn"] = collate_fn
|
|
479
|
+
super().__init__(dataset, *args, **kwargs)
|
|
480
|
+
|
|
481
|
+
|
|
458
482
|
def _extract_tight_normals(
|
|
459
483
|
model: optModel,
|
|
460
484
|
sol: np.ndarray,
|
|
@@ -59,7 +59,8 @@ class coneAlignedCosine(optModule):
|
|
|
59
59
|
cutting the per-epoch cost without measurable regret loss.
|
|
60
60
|
|
|
61
61
|
Training data must come from ``pyepo.data.dataset.optDatasetConstrs``
|
|
62
|
-
(Gurobi-backed) and be
|
|
62
|
+
(Gurobi-backed) and be batched with ``pyepo.data.dataset.optDataLoader``
|
|
63
|
+
or a ``DataLoader`` using ``collate_tight_constraints``.
|
|
63
64
|
|
|
64
65
|
Reference: Tang & Khalil (2024)
|
|
65
66
|
`<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
|
|
@@ -16,7 +16,6 @@ from __future__ import annotations
|
|
|
16
16
|
|
|
17
17
|
import math
|
|
18
18
|
from collections import defaultdict
|
|
19
|
-
from copy import deepcopy
|
|
20
19
|
from itertools import combinations
|
|
21
20
|
from numbers import Integral, Real
|
|
22
21
|
from typing import TYPE_CHECKING
|
|
@@ -105,9 +104,6 @@ class shortestPathBase(optModel):
|
|
|
105
104
|
self.arcs = _get_grid_arcs(self.grid)
|
|
106
105
|
super().__init__(*args, **kwargs)
|
|
107
106
|
|
|
108
|
-
def get_config(self) -> dict:
|
|
109
|
-
return {**super().get_config(), "grid": self.grid}
|
|
110
|
-
|
|
111
107
|
@property
|
|
112
108
|
def num_cost(self) -> int:
|
|
113
109
|
return len(self.arcs)
|
|
@@ -208,14 +204,6 @@ class portfolioBase(optModel):
|
|
|
208
204
|
raise ValueError("gamma must be greater than or equal to zero.")
|
|
209
205
|
super().__init__(*args, **kwargs)
|
|
210
206
|
|
|
211
|
-
def get_config(self) -> dict:
|
|
212
|
-
return {
|
|
213
|
-
**super().get_config(),
|
|
214
|
-
"num_assets": self.num_assets,
|
|
215
|
-
"covariance": self.covariance.copy(),
|
|
216
|
-
"gamma": self.gamma,
|
|
217
|
-
}
|
|
218
|
-
|
|
219
207
|
@property
|
|
220
208
|
def num_cost(self) -> int:
|
|
221
209
|
return self.num_assets
|
|
@@ -263,9 +251,6 @@ class tspABBase(optModel):
|
|
|
263
251
|
self._extra_constrs: list = []
|
|
264
252
|
super().__init__(*args, **kwargs)
|
|
265
253
|
|
|
266
|
-
def get_config(self) -> dict:
|
|
267
|
-
return {**super().get_config(), "num_nodes": self.num_nodes}
|
|
268
|
-
|
|
269
254
|
@property
|
|
270
255
|
def num_cost(self) -> int:
|
|
271
256
|
# use edges; backend's self.x has 2*num_edges directed Vars
|
|
@@ -381,15 +366,6 @@ class vrpABBase(optModel):
|
|
|
381
366
|
self._extra_constrs: list = []
|
|
382
367
|
super().__init__(*args, **kwargs)
|
|
383
368
|
|
|
384
|
-
def get_config(self) -> dict:
|
|
385
|
-
return {
|
|
386
|
-
**super().get_config(),
|
|
387
|
-
"num_nodes": self.num_nodes,
|
|
388
|
-
"demands": deepcopy(self.demands),
|
|
389
|
-
"capacity": self.capacity,
|
|
390
|
-
"num_vehicle": self.num_vehicle,
|
|
391
|
-
}
|
|
392
|
-
|
|
393
369
|
@property
|
|
394
370
|
def num_cost(self) -> int:
|
|
395
371
|
# one predicted cost per undirected edge
|
|
@@ -171,7 +171,8 @@ class optGrbModel(optModel):
|
|
|
171
171
|
else:
|
|
172
172
|
# LinExpr(coeffs, vars) builds the affine expression in one C call
|
|
173
173
|
vars_list = new_model._vars_list
|
|
174
|
-
|
|
174
|
+
if vars_list is None:
|
|
175
|
+
raise RuntimeError("Gurobi variable list is unavailable.")
|
|
175
176
|
expr = gp.LinExpr(coefs_np.tolist(), vars_list) <= rhs
|
|
176
177
|
new_model._model.addConstr(expr)
|
|
177
178
|
# track for replay on relax
|
|
@@ -178,9 +178,6 @@ class tspDFJModel(tspABModel):
|
|
|
178
178
|
self._recycled_keys: set = set()
|
|
179
179
|
super().__init__(num_nodes, *args, **kwargs)
|
|
180
180
|
|
|
181
|
-
def get_config(self) -> dict:
|
|
182
|
-
return {**super().get_config(), "recycle_cuts": self.recycle_cuts}
|
|
183
|
-
|
|
184
181
|
def _getModel(self) -> tuple:
|
|
185
182
|
"""
|
|
186
183
|
A method to build Gurobi model
|
|
@@ -70,9 +70,6 @@ class vrpRCIModel(vrpABModel):
|
|
|
70
70
|
self._recycled_keys: set = set()
|
|
71
71
|
super().__init__(num_nodes, demands, capacity, num_vehicle)
|
|
72
72
|
|
|
73
|
-
def get_config(self) -> dict:
|
|
74
|
-
return {**super().get_config(), "recycle_cuts": self.recycle_cuts}
|
|
75
|
-
|
|
76
73
|
def _getModel(self) -> tuple:
|
|
77
74
|
"""
|
|
78
75
|
A method to build Gurobi model
|
|
@@ -5,6 +5,8 @@ Abstract optimization model
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
+
import functools
|
|
9
|
+
import inspect
|
|
8
10
|
from abc import ABC, abstractmethod
|
|
9
11
|
from copy import deepcopy
|
|
10
12
|
from dataclasses import dataclass
|
|
@@ -25,20 +27,63 @@ class ModelSpec:
|
|
|
25
27
|
"""Serializable recipe for building a fresh optimization model."""
|
|
26
28
|
|
|
27
29
|
model_type: type[optModel]
|
|
30
|
+
_args: tuple
|
|
28
31
|
_config: dict
|
|
29
32
|
|
|
30
|
-
def __init__(
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
model_type: type[optModel],
|
|
36
|
+
config: dict,
|
|
37
|
+
args: tuple = (),
|
|
38
|
+
) -> None:
|
|
31
39
|
object.__setattr__(self, "model_type", model_type)
|
|
40
|
+
object.__setattr__(self, "_args", deepcopy(args))
|
|
32
41
|
object.__setattr__(self, "_config", deepcopy(config))
|
|
33
42
|
|
|
43
|
+
@property
|
|
44
|
+
def args(self) -> tuple:
|
|
45
|
+
"""Return an independent copy of positional constructor arguments."""
|
|
46
|
+
return deepcopy(self._args)
|
|
47
|
+
|
|
34
48
|
@property
|
|
35
49
|
def config(self) -> dict:
|
|
36
|
-
"""Return an independent copy of
|
|
50
|
+
"""Return an independent copy of keyword constructor arguments."""
|
|
37
51
|
return deepcopy(self._config)
|
|
38
52
|
|
|
39
53
|
def build(self) -> optModel:
|
|
40
54
|
"""Build a fresh model without sharing mutable configuration values."""
|
|
41
|
-
return self.model_type.from_config(self._config)
|
|
55
|
+
return self.model_type.from_config(self._config, self._args)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _snapshot(value):
|
|
59
|
+
"""Deep-copy a constructor argument, keeping the reference if it cannot be copied."""
|
|
60
|
+
try:
|
|
61
|
+
return deepcopy(value)
|
|
62
|
+
except Exception: # noqa: BLE001 -- any copy failure falls back to a reference
|
|
63
|
+
return value
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _capture_init_config(init, args, kwargs) -> tuple[tuple, dict]:
|
|
67
|
+
"""Flatten a constructor call into arguments that rebuild the model."""
|
|
68
|
+
sig = inspect.signature(init)
|
|
69
|
+
bound = sig.bind(None, *args, **kwargs)
|
|
70
|
+
init_args = []
|
|
71
|
+
config = {}
|
|
72
|
+
for i, (name, value) in enumerate(bound.arguments.items()):
|
|
73
|
+
# the first bound argument is self
|
|
74
|
+
if i == 0:
|
|
75
|
+
continue
|
|
76
|
+
kind = sig.parameters[name].kind
|
|
77
|
+
# snapshot each argument; values that cannot be deep-copied keep a reference
|
|
78
|
+
if kind is inspect.Parameter.VAR_KEYWORD:
|
|
79
|
+
config.update({k: _snapshot(v) for k, v in value.items()})
|
|
80
|
+
elif kind is inspect.Parameter.VAR_POSITIONAL:
|
|
81
|
+
init_args.extend(_snapshot(v) for v in value)
|
|
82
|
+
elif kind is inspect.Parameter.POSITIONAL_ONLY:
|
|
83
|
+
init_args.append(_snapshot(value))
|
|
84
|
+
else:
|
|
85
|
+
config[name] = _snapshot(value)
|
|
86
|
+
return tuple(init_args), config
|
|
42
87
|
|
|
43
88
|
|
|
44
89
|
class optModel(ABC):
|
|
@@ -53,11 +98,6 @@ class optModel(ABC):
|
|
|
53
98
|
and MPAX (``optMpaxModel``); subclass ``optModel`` directly to integrate
|
|
54
99
|
any other solver or algorithm.
|
|
55
100
|
|
|
56
|
-
Models that take constructor arguments should override ``get_config`` and
|
|
57
|
-
cooperatively merge ``super().get_config()``. The resulting configuration
|
|
58
|
-
powers ``rebuild()``, multiprocessing workers, and sklearn scorers without
|
|
59
|
-
inspecting constructor signatures or runtime solver state.
|
|
60
|
-
|
|
61
101
|
The default objective sense is minimization; set
|
|
62
102
|
``self.modelSense = EPO.MAXIMIZE`` in ``_getModel`` or ``__init__`` for
|
|
63
103
|
maximization problems (some backends, e.g. Gurobi and COPT, detect this
|
|
@@ -73,6 +113,24 @@ class optModel(ABC):
|
|
|
73
113
|
arcs: list
|
|
74
114
|
_cost_vars: list
|
|
75
115
|
|
|
116
|
+
def __init_subclass__(cls, **kwargs) -> None:
|
|
117
|
+
super().__init_subclass__(**kwargs)
|
|
118
|
+
# Only wrap subclasses that define their own __init__ and use the
|
|
119
|
+
# default reconstruction config. Custom get_config implementations may
|
|
120
|
+
# intentionally accept objects that are not deepcopyable/rebuildable.
|
|
121
|
+
if "__init__" not in cls.__dict__ or "get_config" in cls.__dict__:
|
|
122
|
+
return
|
|
123
|
+
user_init = cls.__init__
|
|
124
|
+
|
|
125
|
+
@functools.wraps(user_init)
|
|
126
|
+
def _init_capturing(self, *args, **kwargs):
|
|
127
|
+
# record only the outermost call; nested super().__init__ leaves it intact
|
|
128
|
+
if "_init_config" not in self.__dict__:
|
|
129
|
+
self._init_args, self._init_config = _capture_init_config(user_init, args, kwargs)
|
|
130
|
+
user_init(self, *args, **kwargs)
|
|
131
|
+
|
|
132
|
+
cls.__init__ = _init_capturing
|
|
133
|
+
|
|
76
134
|
def __init__(self) -> None:
|
|
77
135
|
# Cache for models whose solver variables do not map one-to-one to
|
|
78
136
|
# predicted costs (for example directed TSP/VRP formulations).
|
|
@@ -86,17 +144,21 @@ class optModel(ABC):
|
|
|
86
144
|
return "optModel " + self.__class__.__name__
|
|
87
145
|
|
|
88
146
|
def get_config(self) -> dict:
|
|
89
|
-
"""Return the
|
|
90
|
-
return {}
|
|
147
|
+
"""Return the constructor configuration for this model."""
|
|
148
|
+
return deepcopy(self.__dict__.get("_init_config", {}))
|
|
91
149
|
|
|
92
150
|
@classmethod
|
|
93
|
-
def from_config(cls, config: dict) -> Self:
|
|
151
|
+
def from_config(cls, config: dict, args: tuple = ()) -> Self:
|
|
94
152
|
"""Build a model from a configuration produced by ``get_config``."""
|
|
95
|
-
return cls(**deepcopy(config))
|
|
153
|
+
return cls(*deepcopy(args), **deepcopy(config))
|
|
96
154
|
|
|
97
155
|
def to_spec(self) -> ModelSpec:
|
|
98
156
|
"""Return a serializable, immutable-snapshot rebuild recipe."""
|
|
99
|
-
return ModelSpec(
|
|
157
|
+
return ModelSpec(
|
|
158
|
+
type(self),
|
|
159
|
+
self.get_config(),
|
|
160
|
+
self.__dict__.get("_init_args", ()),
|
|
161
|
+
)
|
|
100
162
|
|
|
101
163
|
def rebuild(self) -> Self:
|
|
102
164
|
"""Build a structurally equivalent model with clean runtime state."""
|
|
@@ -40,7 +40,6 @@ class compiledOrtProblem(compiledBase, optOrtModel):
|
|
|
40
40
|
self.problem = deepcopy(problem)
|
|
41
41
|
self.params = dict(params) if params else {}
|
|
42
42
|
self.solver = solver
|
|
43
|
-
self._extra_constrs = [] # (coef, rhs) cuts replayed on copy
|
|
44
43
|
optModel.__init__(self) # builds the model via _getModel
|
|
45
44
|
self._model.SuppressOutput()
|
|
46
45
|
self._set_obj_sense()
|
|
@@ -53,7 +53,6 @@ class optOrtCpModel(optModel):
|
|
|
53
53
|
raise ImportError(
|
|
54
54
|
"OR-Tools is not installed. Please install ortools to use this feature."
|
|
55
55
|
)
|
|
56
|
-
self._extra_constrs = []
|
|
57
56
|
self._objective_coefs: list[int] | None = None
|
|
58
57
|
super().__init__()
|
|
59
58
|
# cache ordered Var list for batched weighted_sum / per-Var Value() loop
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyepo
|
|
3
|
-
Version: 2.2.3
|
|
4
|
-
Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
|
|
3
|
+
Version: 2.2.5
|
|
4
|
+
Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
|
|
5
5
|
Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
|
|
6
6
|
License: MIT
|
|
7
7
|
Project-URL: Homepage, https://github.com/khalil-research/PyEPO
|
|
@@ -9,7 +9,7 @@ Project-URL: Documentation, https://khalil-research.github.io/PyEPO
|
|
|
9
9
|
Project-URL: Repository, https://github.com/khalil-research/PyEPO
|
|
10
10
|
Project-URL: Issues, https://github.com/khalil-research/PyEPO/issues
|
|
11
11
|
Project-URL: Paper, https://link.springer.com/article/10.1007/s12532-024-00255-x
|
|
12
|
-
Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,linear programming,integer programming
|
|
12
|
+
Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,jax,linear programming,integer programming
|
|
13
13
|
Classifier: Programming Language :: Python :: 3
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.9
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.10
|
|
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "pyepo"
|
|
7
|
-
version = "2.2.3"
|
|
8
|
-
description = "PyTorch-based End-to-End Predict-then-Optimize Tool"
|
|
7
|
+
version = "2.2.5"
|
|
8
|
+
description = "PyTorch/JAX-based End-to-End Predict-then-Optimize Tool"
|
|
9
9
|
readme = { file = "README.md", content-type = "text/markdown" }
|
|
10
10
|
license = { text = "MIT" }
|
|
11
11
|
authors = [{ name = "Bo Tang", email = "bolucas.tang@mail.utoronto.ca" }]
|
|
@@ -17,6 +17,7 @@ keywords = [
|
|
|
17
17
|
"optimization",
|
|
18
18
|
"deep learning",
|
|
19
19
|
"pytorch",
|
|
20
|
+
"jax",
|
|
20
21
|
"linear programming",
|
|
21
22
|
"integer programming",
|
|
22
23
|
]
|
|
@@ -19,7 +19,7 @@ from .conftest import requires_gurobi
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class ConfigModel(optModel):
|
|
22
|
-
"""Solver-free model
|
|
22
|
+
"""Solver-free model with custom reconstruction config."""
|
|
23
23
|
|
|
24
24
|
def __init__(self, values, label="default"):
|
|
25
25
|
self.values = np.asarray(values)
|
|
@@ -43,6 +43,88 @@ class ConfigModel(optModel):
|
|
|
43
43
|
return np.zeros(len(self.values)), 0.0
|
|
44
44
|
|
|
45
45
|
|
|
46
|
+
class AutoConfigModel(optModel):
|
|
47
|
+
"""Solver-free model that relies on optModel's automatic config capture."""
|
|
48
|
+
|
|
49
|
+
def __init__(self, values, **kwargs):
|
|
50
|
+
self.values = values
|
|
51
|
+
self.kwargs = kwargs
|
|
52
|
+
super().__init__()
|
|
53
|
+
|
|
54
|
+
def _getModel(self):
|
|
55
|
+
return None, list(range(len(self.values)))
|
|
56
|
+
|
|
57
|
+
def setObj(self, c):
|
|
58
|
+
self.cost = np.asarray(c)
|
|
59
|
+
|
|
60
|
+
def solve(self):
|
|
61
|
+
return np.zeros(len(self.values)), 0.0
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class VarArgsConfigModel(optModel):
|
|
65
|
+
"""Solver-free model whose constructor needs positional replay."""
|
|
66
|
+
|
|
67
|
+
def __init__(self, *values, label="default"):
|
|
68
|
+
self.values = values
|
|
69
|
+
self.label = label
|
|
70
|
+
super().__init__()
|
|
71
|
+
|
|
72
|
+
def _getModel(self):
|
|
73
|
+
return None, list(range(len(self.values)))
|
|
74
|
+
|
|
75
|
+
def setObj(self, c):
|
|
76
|
+
self.cost = np.asarray(c)
|
|
77
|
+
|
|
78
|
+
def solve(self):
|
|
79
|
+
return np.zeros(len(self.values)), 0.0
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class PosOnlyConfigModel(optModel):
|
|
83
|
+
"""Solver-free model with a positional-only constructor argument."""
|
|
84
|
+
|
|
85
|
+
def __init__(self, values, /, label="default"):
|
|
86
|
+
self.values = values
|
|
87
|
+
self.label = label
|
|
88
|
+
super().__init__()
|
|
89
|
+
|
|
90
|
+
def _getModel(self):
|
|
91
|
+
return None, list(range(len(self.values)))
|
|
92
|
+
|
|
93
|
+
def setObj(self, c):
|
|
94
|
+
self.cost = np.asarray(c)
|
|
95
|
+
|
|
96
|
+
def solve(self):
|
|
97
|
+
return np.zeros(len(self.values)), 0.0
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class _NoDeepcopy:
|
|
101
|
+
def __init__(self, values):
|
|
102
|
+
self.values = values
|
|
103
|
+
|
|
104
|
+
def __deepcopy__(self, memo):
|
|
105
|
+
raise TypeError("not deepcopyable")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class CustomConfigModel(optModel):
|
|
109
|
+
"""Custom config should control reconstruction for unusual constructor inputs."""
|
|
110
|
+
|
|
111
|
+
def __init__(self, resource=None, values=None):
|
|
112
|
+
self.values = list(values if values is not None else resource.values)
|
|
113
|
+
super().__init__()
|
|
114
|
+
|
|
115
|
+
def get_config(self):
|
|
116
|
+
return {"values": self.values.copy()}
|
|
117
|
+
|
|
118
|
+
def _getModel(self):
|
|
119
|
+
return None, list(range(len(self.values)))
|
|
120
|
+
|
|
121
|
+
def setObj(self, c):
|
|
122
|
+
self.cost = np.asarray(c)
|
|
123
|
+
|
|
124
|
+
def solve(self):
|
|
125
|
+
return np.zeros(len(self.values)), 0.0
|
|
126
|
+
|
|
127
|
+
|
|
46
128
|
# ============================================================
|
|
47
129
|
# unionFind (pure)
|
|
48
130
|
# ============================================================
|
|
@@ -161,7 +243,7 @@ class TestCostToNumpy:
|
|
|
161
243
|
|
|
162
244
|
|
|
163
245
|
# ============================================================
|
|
164
|
-
#
|
|
246
|
+
# model reconstruction config (pure)
|
|
165
247
|
# ============================================================
|
|
166
248
|
|
|
167
249
|
|
|
@@ -222,6 +304,54 @@ class TestModelSpec:
|
|
|
222
304
|
np.testing.assert_array_equal(sol, [0.0, 0.0, 0.0])
|
|
223
305
|
assert obj == 0.0
|
|
224
306
|
|
|
307
|
+
def test_auto_config_snapshots_constructor_inputs(self):
|
|
308
|
+
values = [1, 2, 3]
|
|
309
|
+
nested = {"tag": ["x"]}
|
|
310
|
+
model = AutoConfigModel(values, nested=nested)
|
|
311
|
+
values[0] = 99
|
|
312
|
+
nested["tag"][0] = "changed"
|
|
313
|
+
|
|
314
|
+
rebuilt = model.rebuild()
|
|
315
|
+
|
|
316
|
+
assert rebuilt.values == [1, 2, 3]
|
|
317
|
+
assert rebuilt.kwargs == {"nested": {"tag": ["x"]}}
|
|
318
|
+
|
|
319
|
+
def test_auto_config_export_is_independent(self):
|
|
320
|
+
model = AutoConfigModel([1, 2, 3], nested={"tag": ["x"]})
|
|
321
|
+
config = model.get_config()
|
|
322
|
+
config["values"][0] = 99
|
|
323
|
+
config["nested"]["tag"][0] = "changed"
|
|
324
|
+
|
|
325
|
+
assert model.get_config()["values"] == [1, 2, 3]
|
|
326
|
+
assert model.get_config()["nested"] == {"tag": ["x"]}
|
|
327
|
+
assert model.rebuild().kwargs == {"nested": {"tag": ["x"]}}
|
|
328
|
+
|
|
329
|
+
def test_auto_config_replays_varargs(self):
|
|
330
|
+
model = VarArgsConfigModel(1, 2, 3, label="x")
|
|
331
|
+
spec = model.to_spec()
|
|
332
|
+
|
|
333
|
+
assert spec.args == (1, 2, 3)
|
|
334
|
+
assert spec.config == {"label": "x"}
|
|
335
|
+
rebuilt = model.rebuild()
|
|
336
|
+
assert rebuilt.values == (1, 2, 3)
|
|
337
|
+
assert rebuilt.label == "x"
|
|
338
|
+
|
|
339
|
+
def test_auto_config_replays_positional_only_args(self):
|
|
340
|
+
model = PosOnlyConfigModel([1, 2, 3], label="x")
|
|
341
|
+
spec = model.to_spec()
|
|
342
|
+
|
|
343
|
+
assert spec.args == ([1, 2, 3],)
|
|
344
|
+
assert spec.config == {"label": "x"}
|
|
345
|
+
rebuilt = model.rebuild()
|
|
346
|
+
assert rebuilt.values == [1, 2, 3]
|
|
347
|
+
assert rebuilt.label == "x"
|
|
348
|
+
|
|
349
|
+
def test_custom_get_config_accepts_uncopyable_constructor_input(self):
|
|
350
|
+
model = CustomConfigModel(_NoDeepcopy([1, 2, 3]))
|
|
351
|
+
rebuilt = model.rebuild()
|
|
352
|
+
|
|
353
|
+
assert rebuilt.values == [1, 2, 3]
|
|
354
|
+
|
|
225
355
|
|
|
226
356
|
# ============================================================
|
|
227
357
|
# getArgs (needs a real optModel)
|
|
@@ -14,6 +14,7 @@ import torch
|
|
|
14
14
|
from pyepo.data import shortestpath
|
|
15
15
|
from pyepo.data.dataset import (
|
|
16
16
|
collate_tight_constraints,
|
|
17
|
+
optDataLoader,
|
|
17
18
|
optDataset,
|
|
18
19
|
optDatasetConstrs,
|
|
19
20
|
optDatasetKNN,
|
|
@@ -299,3 +300,51 @@ class TestCollateTightConstraints:
|
|
|
299
300
|
assert padded.shape == (2, 5, 4)
|
|
300
301
|
# the shorter matrix is zero-padded at the tail
|
|
301
302
|
assert torch.allclose(padded[0, 2:], torch.zeros(3, 4))
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class TestOptDataLoader:
|
|
306
|
+
"""Pure: optDataLoader applies a dataset's collate_fn automatically."""
|
|
307
|
+
|
|
308
|
+
def test_constrs_collate_fn_wired(self):
|
|
309
|
+
assert optDatasetConstrs.collate_fn is collate_tight_constraints
|
|
310
|
+
|
|
311
|
+
def test_uses_dataset_collate_fn(self):
|
|
312
|
+
class _Ragged(list):
|
|
313
|
+
collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
|
|
314
|
+
|
|
315
|
+
loader = optDataLoader(_Ragged([0, 1, 2, 3]), batch_size=2)
|
|
316
|
+
assert next(iter(loader)) == ("auto", 2)
|
|
317
|
+
|
|
318
|
+
def test_accepts_positional_dataloader_args(self):
|
|
319
|
+
loader = optDataLoader([0, 1, 2, 3], 2)
|
|
320
|
+
assert torch.equal(next(iter(loader)), torch.tensor([0, 1]))
|
|
321
|
+
|
|
322
|
+
def test_explicit_collate_fn_wins(self):
|
|
323
|
+
class _Ragged(list):
|
|
324
|
+
collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
|
|
325
|
+
|
|
326
|
+
loader = optDataLoader(
|
|
327
|
+
_Ragged([0, 1, 2, 3]), batch_size=2, collate_fn=lambda b: ("explicit", len(b))
|
|
328
|
+
)
|
|
329
|
+
assert next(iter(loader)) == ("explicit", 2)
|
|
330
|
+
|
|
331
|
+
def test_positional_collate_fn_wins(self):
|
|
332
|
+
class _Ragged(list):
|
|
333
|
+
collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
|
|
334
|
+
|
|
335
|
+
loader = optDataLoader(
|
|
336
|
+
_Ragged([0, 1, 2, 3]),
|
|
337
|
+
2,
|
|
338
|
+
None,
|
|
339
|
+
None,
|
|
340
|
+
None,
|
|
341
|
+
0,
|
|
342
|
+
lambda b: ("positional", len(b)),
|
|
343
|
+
)
|
|
344
|
+
assert next(iter(loader)) == ("positional", 2)
|
|
345
|
+
|
|
346
|
+
def test_plain_dataset_uses_default_collate(self):
|
|
347
|
+
data = [(torch.tensor([float(i)]),) for i in range(4)]
|
|
348
|
+
loader = optDataLoader(data, batch_size=2)
|
|
349
|
+
(batch,) = next(iter(loader))
|
|
350
|
+
assert batch.shape == (2, 1)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|