pyepo 2.2.1__tar.gz → 2.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyepo-2.2.1 → pyepo-2.2.3}/PKG-INFO +8 -2
- {pyepo-2.2.1 → pyepo-2.2.3}/README.md +7 -1
- pyepo-2.2.3/pyepo/data/_validation.py +32 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/data/dataset.py +69 -71
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/data/knapsack.py +4 -7
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/data/portfolio.py +4 -5
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/data/shortestpath.py +4 -7
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/data/tsp.py +4 -7
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/dsl/compiled.py +18 -9
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/dsl/problem.py +4 -2
- pyepo-2.2.3/pyepo/func/_common.py +65 -0
- pyepo-2.2.3/pyepo/func/abcmodule.py +100 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/blackbox.py +4 -8
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/cave.py +8 -8
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/contrastive.py +7 -26
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/__init__.py +1 -1
- pyepo-2.2.3/pyepo/func/jax/abcmodule.py +96 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/blackbox.py +4 -8
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/cave.py +19 -9
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/contrastive.py +19 -10
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/perturbed.py +22 -13
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/rank.py +28 -17
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/regularized.py +22 -9
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/surrogate.py +12 -5
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/jax/utils.py +7 -8
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/perturbed.py +22 -22
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/rank.py +12 -40
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/regularized.py +20 -19
- pyepo-2.2.3/pyepo/func/runtime.py +129 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/surrogate.py +12 -15
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/utils.py +22 -22
- pyepo-2.2.3/pyepo/metric/_common.py +167 -0
- pyepo-2.2.3/pyepo/metric/metrics.py +120 -0
- pyepo-2.2.3/pyepo/metric/mse.py +42 -0
- pyepo-2.2.3/pyepo/metric/regret.py +163 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/metric/unambregret.py +26 -19
- pyepo-2.2.3/pyepo/model/__init__.py +41 -0
- pyepo-2.2.3/pyepo/model/_common.py +60 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/bases.py +127 -34
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/compile.py +10 -12
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/coptmodel.py +36 -15
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/knapsack.py +2 -46
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/portfolio.py +3 -35
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/shortestpath.py +3 -33
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/tsp.py +29 -67
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/vrp.py +24 -16
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/compile.py +7 -9
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/grbmodel.py +43 -14
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/knapsack.py +2 -44
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/portfolio.py +3 -35
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/shortestpath.py +3 -33
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/tsp.py +12 -11
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/vrp.py +13 -19
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/mpax/compile.py +7 -8
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/mpax/knapsack.py +1 -21
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/mpax/mpaxmodel.py +8 -12
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/mpax/shortestpath.py +1 -31
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/compile.py +11 -18
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/knapsack.py +6 -46
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/omomodel.py +42 -15
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/portfolio.py +2 -35
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/shortestpath.py +2 -33
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/tsp.py +6 -52
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/vrp.py +5 -13
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/opt.py +52 -5
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/ort/compile.py +6 -5
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/ort/knapsack.py +3 -57
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/ort/ortcpmodel.py +18 -7
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/ort/ortmodel.py +16 -29
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/ort/shortestpath.py +2 -44
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/predefined.py +13 -12
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/twostage/autosklearnpred.py +1 -7
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/utils.py +2 -4
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo.egg-info/PKG-INFO +8 -2
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo.egg-info/SOURCES.txt +6 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyproject.toml +2 -7
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_00_constants.py +0 -1
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_10_utils.py +98 -5
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_15_dsl.py +263 -108
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_20_data_gen.py +28 -36
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_30_model.py +465 -49
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_40_dataset.py +65 -22
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_50_func.py +384 -149
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_55_jax.py +95 -122
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_60_metric.py +222 -74
- pyepo-2.2.3/tests/test_61_metric_validation.py +168 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_70_twostage.py +28 -3
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_80_integration.py +50 -30
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_90_cuda.py +7 -10
- pyepo-2.2.1/pyepo/func/abcmodule.py +0 -120
- pyepo-2.2.1/pyepo/func/jax/abcmodule.py +0 -118
- pyepo-2.2.1/pyepo/metric/metrics.py +0 -191
- pyepo-2.2.1/pyepo/metric/mse.py +0 -49
- pyepo-2.2.1/pyepo/metric/regret.py +0 -200
- pyepo-2.2.1/pyepo/model/__init__.py +0 -52
- {pyepo-2.2.1 → pyepo-2.2.3}/LICENSE +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/EPO.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/data/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/dsl/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/dsl/expression.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/dsl/objective.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/func/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/metric/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/copt/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/grb/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/mpax/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/omo/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/ort/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/model/utils.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/py.typed +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/twostage/__init__.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo/twostage/sklearnpred.py +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo.egg-info/dependency_links.txt +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo.egg-info/requires.txt +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/pyepo.egg-info/top_level.txt +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/setup.cfg +0 -0
- {pyepo-2.2.1 → pyepo-2.2.3}/tests/test_85_backend_pipeline.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyepo
|
|
3
|
-
Version: 2.2.
|
|
3
|
+
Version: 2.2.3
|
|
4
4
|
Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
|
|
5
5
|
Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
|
|
6
6
|
License: MIT
|
|
@@ -76,12 +76,18 @@ Dynamic: license-file
|
|
|
76
76
|
|
|
77
77
|
## Features
|
|
78
78
|
|
|
79
|
-
- Implement **SPO+**, **
|
|
79
|
+
- Implement **SPO+**, **PG**, **DPO** (additive and multiplicative perturbations), **PFYL** (additive and multiplicative perturbations), **I-MLE**, **AI-MLE**, L2-regularized **RFWO/RFYL**, **DBB**, **NID**, **CaVE**, **NCE**, and **LTR**
|
|
80
80
|
- Support [Gurobi](https://www.gurobi.com/), [COPT](https://shanshu.ai/copt), [Pyomo](http://www.pyomo.org/), [Google OR-Tools](https://developers.google.com/optimization), and [MPAX](https://github.com/MIT-Lu-Lab/MPAX) API
|
|
81
|
+
- Symbolic modeling with `pyepo.dsl`: define an LP, MIP, or QP once, then compile it to any backend
|
|
82
|
+
- JAX frontend (`pyepo.func.jax`): train any loss in JAX/Flax with `jax.grad`
|
|
81
83
|
- Support parallel computing for optimization solvers
|
|
82
84
|
- Support solution caching to speed up training
|
|
83
85
|
- Support kNN robust loss to improve decision quality
|
|
84
86
|
|
|
87
|
+
## CaVE for Binary Linear Programs
|
|
88
|
+
|
|
89
|
+
For end-to-end learning on **binary linear programs** (TSP, CVRP, knapsack, ...), ``PyEPO`` ships **CaVE**. CaVE replaces the per-step ILP solve with a cone-alignment projection onto the binding-constraint normals at the true optimum, backed by an interior-point QP solver (Clarabel). Because the cone projection is far cheaper than the per-instance ILP solve, CaVE trains an order of magnitude faster than SPO+ at TSP scale.
|
|
90
|
+
|
|
85
91
|
## GPU-Accelerated Solving with MPAX
|
|
86
92
|
|
|
87
93
|
``PyEPO`` integrates [MPAX](https://github.com/MIT-Lu-Lab/MPAX), a JAX-based mathematical programming solver using the PDHG algorithm for GPU-accelerated optimization. Key advantages: (1) **GPU-native solving** — the first-order PDHG method runs efficiently on GPU; (2) **batch solving** — an entire mini-batch can be solved simultaneously via vectorization; (3) **no GPU-CPU data transfer overhead** — both the neural network and the solver stay on GPU, eliminating the data transfer bottleneck.
|
|
@@ -4,12 +4,18 @@
|
|
|
4
4
|
|
|
5
5
|
## Features
|
|
6
6
|
|
|
7
|
-
- Implement **SPO+**, **
|
|
7
|
+
- Implement **SPO+**, **PG**, **DPO** (additive and multiplicative perturbations), **PFYL** (additive and multiplicative perturbations), **I-MLE**, **AI-MLE**, L2-regularized **RFWO/RFYL**, **DBB**, **NID**, **CaVE**, **NCE**, and **LTR**
|
|
8
8
|
- Support [Gurobi](https://www.gurobi.com/), [COPT](https://shanshu.ai/copt), [Pyomo](http://www.pyomo.org/), [Google OR-Tools](https://developers.google.com/optimization), and [MPAX](https://github.com/MIT-Lu-Lab/MPAX) API
|
|
9
|
+
- Symbolic modeling with `pyepo.dsl`: define an LP, MIP, or QP once, then compile it to any backend
|
|
10
|
+
- JAX frontend (`pyepo.func.jax`): train any loss in JAX/Flax with `jax.grad`
|
|
9
11
|
- Support parallel computing for optimization solvers
|
|
10
12
|
- Support solution caching to speed up training
|
|
11
13
|
- Support kNN robust loss to improve decision quality
|
|
12
14
|
|
|
15
|
+
## CaVE for Binary Linear Programs
|
|
16
|
+
|
|
17
|
+
For end-to-end learning on **binary linear programs** (TSP, CVRP, knapsack, ...), ``PyEPO`` ships **CaVE**. CaVE replaces the per-step ILP solve with a cone-alignment projection onto the binding-constraint normals at the true optimum, backed by an interior-point QP solver (Clarabel). Because the cone projection is far cheaper than the per-instance ILP solve, CaVE trains an order of magnitude faster than SPO+ at TSP scale.
|
|
18
|
+
|
|
13
19
|
## GPU-Accelerated Solving with MPAX
|
|
14
20
|
|
|
15
21
|
``PyEPO`` integrates [MPAX](https://github.com/MIT-Lu-Lab/MPAX), a JAX-based mathematical programming solver using the PDHG algorithm for GPU-accelerated optimization. Key advantages: (1) **GPU-native solving** — the first-order PDHG method runs efficiently on GPU; (2) **batch solving** — an entire mini-batch can be solved simultaneously via vectorization; (3) **no GPU-CPU data transfer overhead** — both the neural network and the solver stay on GPU, eliminating the data transfer bottleneck.
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Shared validation for synthetic data generators."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from numbers import Real
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def validate_degree(deg: int) -> None:
|
|
8
|
+
"""Validate a positive integer polynomial degree."""
|
|
9
|
+
if not isinstance(deg, int) or isinstance(deg, bool) or deg <= 0:
|
|
10
|
+
raise ValueError(f"deg = {deg} should be a positive integer.")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def validate_nonnegative(value: float, name: str) -> None:
|
|
14
|
+
"""Validate a finite, non-negative real generator parameter."""
|
|
15
|
+
if not isinstance(value, Real) or isinstance(value, bool):
|
|
16
|
+
raise ValueError(f"{name} = {value} should be a finite non-negative number.")
|
|
17
|
+
number = float(value)
|
|
18
|
+
if not math.isfinite(number) or number < 0:
|
|
19
|
+
raise ValueError(f"{name} = {value} should be a finite non-negative number.")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def validate_probability(value: float, name: str) -> None:
|
|
23
|
+
"""Validate a finite real value in the closed interval [0, 1]."""
|
|
24
|
+
validate_nonnegative(value, name)
|
|
25
|
+
if float(value) > 1:
|
|
26
|
+
raise ValueError(f"{name} = {value} should be in [0, 1].")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def validate_positive_int(value: int, name: str) -> None:
|
|
30
|
+
"""Validate a strictly positive integer parameter."""
|
|
31
|
+
if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
|
|
32
|
+
raise ValueError(f"{name} = {value} should be a positive integer.")
|
|
@@ -15,19 +15,48 @@ from torch.utils.data import Dataset
|
|
|
15
15
|
from tqdm import tqdm
|
|
16
16
|
|
|
17
17
|
from pyepo import EPO
|
|
18
|
+
from pyepo.data._validation import validate_positive_int, validate_probability
|
|
18
19
|
from pyepo.model.opt import optModel
|
|
19
20
|
|
|
20
21
|
if TYPE_CHECKING:
|
|
21
22
|
from pyepo.model.mpax import optMpaxModel as _optMpaxModelT
|
|
22
23
|
|
|
23
24
|
try:
|
|
24
|
-
from pyepo.model.mpax import optMpaxModel
|
|
25
|
+
from pyepo.model.mpax import optMpaxModel as _opt_mpax_model_cls
|
|
25
26
|
except ImportError:
|
|
26
|
-
|
|
27
|
+
_opt_mpax_model_cls = None
|
|
27
28
|
|
|
28
29
|
logger = logging.getLogger(__name__)
|
|
29
30
|
|
|
30
31
|
|
|
32
|
+
def _validate_inputs(
|
|
33
|
+
model: optModel,
|
|
34
|
+
feats: np.ndarray | torch.Tensor,
|
|
35
|
+
costs: np.ndarray | torch.Tensor,
|
|
36
|
+
) -> None:
|
|
37
|
+
"""Validate the common constructor contract for optimization datasets."""
|
|
38
|
+
if not isinstance(model, optModel):
|
|
39
|
+
raise TypeError("arg model is not an optModel")
|
|
40
|
+
if len(feats) != len(costs):
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f"feats and costs must have the same number of instances: {len(feats)} vs {len(costs)}."
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _as_float_tensor(data) -> torch.Tensor:
|
|
47
|
+
"""Convert dataset arrays to the common float32 tensor representation."""
|
|
48
|
+
return torch.as_tensor(data, dtype=torch.float32)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _solution_to_numpy(
|
|
52
|
+
solution: np.ndarray | torch.Tensor | list,
|
|
53
|
+
) -> np.ndarray:
|
|
54
|
+
"""Normalize a solver solution to a NumPy array."""
|
|
55
|
+
if isinstance(solution, torch.Tensor):
|
|
56
|
+
solution = solution.detach().cpu().numpy()
|
|
57
|
+
return np.asarray(solution)
|
|
58
|
+
|
|
59
|
+
|
|
31
60
|
class optDataset(Dataset):
|
|
32
61
|
"""
|
|
33
62
|
PyTorch ``Dataset`` for predict-then-optimize problems.
|
|
@@ -62,44 +91,35 @@ class optDataset(Dataset):
|
|
|
62
91
|
feats: data features
|
|
63
92
|
costs: costs of objective function
|
|
64
93
|
"""
|
|
65
|
-
|
|
66
|
-
raise TypeError("arg model is not an optModel")
|
|
67
|
-
if len(feats) != len(costs):
|
|
68
|
-
raise ValueError(
|
|
69
|
-
f"feats and costs must have the same number of instances: "
|
|
70
|
-
f"{len(feats)} vs {len(costs)}."
|
|
71
|
-
)
|
|
94
|
+
_validate_inputs(model, feats, costs)
|
|
72
95
|
self.model = model
|
|
73
96
|
# data
|
|
74
97
|
self.feats = feats
|
|
75
98
|
self.costs = costs
|
|
76
99
|
# find optimal solutions
|
|
77
|
-
sols, objs = self.
|
|
78
|
-
self.feats =
|
|
79
|
-
self.costs =
|
|
80
|
-
self.sols =
|
|
81
|
-
self.objs =
|
|
100
|
+
sols, objs = self._get_sols()
|
|
101
|
+
self.feats = _as_float_tensor(feats)
|
|
102
|
+
self.costs = _as_float_tensor(costs)
|
|
103
|
+
self.sols = _as_float_tensor(sols)
|
|
104
|
+
self.objs = _as_float_tensor(objs)
|
|
82
105
|
|
|
83
|
-
def
|
|
106
|
+
def _get_sols(self) -> tuple[np.ndarray, np.ndarray]:
|
|
84
107
|
"""
|
|
85
108
|
A method to get optimal solutions for all cost vectors
|
|
86
109
|
"""
|
|
87
110
|
# MPAX fast path: vmap-solve the whole dataset in a single dispatch
|
|
88
|
-
if
|
|
89
|
-
return self.
|
|
111
|
+
if _opt_mpax_model_cls is not None and isinstance(self.model, _opt_mpax_model_cls):
|
|
112
|
+
return self._get_sols_mpax_batch()
|
|
90
113
|
sols = []
|
|
91
114
|
objs = []
|
|
92
115
|
logger.info("Optimizing for optDataset...")
|
|
93
116
|
for c in tqdm(self.costs):
|
|
94
117
|
sol, obj = self._solve(c)
|
|
95
|
-
|
|
96
|
-
if isinstance(sol, torch.Tensor):
|
|
97
|
-
sol = sol.detach().cpu().numpy()
|
|
98
|
-
sols.append(np.asarray(sol))
|
|
118
|
+
sols.append(_solution_to_numpy(sol))
|
|
99
119
|
objs.append(obj)
|
|
100
120
|
return np.stack(sols), np.asarray(objs).reshape(-1, 1)
|
|
101
121
|
|
|
102
|
-
def
|
|
122
|
+
def _get_sols_mpax_batch(self) -> tuple[np.ndarray, np.ndarray]:
|
|
103
123
|
"""
|
|
104
124
|
A method to batch-solve every cost vector in one MPAX vmap call.
|
|
105
125
|
"""
|
|
@@ -209,20 +229,14 @@ class optDatasetKNN(optDataset):
|
|
|
209
229
|
k: number of nearest neighbours selected
|
|
210
230
|
weight: self-weight in the kNN convex combination (1.0 = no smoothing)
|
|
211
231
|
"""
|
|
212
|
-
|
|
213
|
-
raise TypeError("arg model is not an optModel")
|
|
214
|
-
if len(feats) != len(costs):
|
|
215
|
-
raise ValueError(
|
|
216
|
-
f"feats and costs must have the same number of instances: "
|
|
217
|
-
f"{len(feats)} vs {len(costs)}."
|
|
218
|
-
)
|
|
232
|
+
_validate_inputs(model, feats, costs)
|
|
219
233
|
self.model = model
|
|
220
234
|
# at most num_data-1 neighbours exist (self excluded), so k must stay below it
|
|
221
235
|
num_data = len(feats)
|
|
222
|
-
|
|
236
|
+
validate_positive_int(k, "k")
|
|
237
|
+
if k >= num_data:
|
|
223
238
|
raise ValueError(f"Invalid k={k}; must satisfy 1 <= k < num_data ({num_data}).")
|
|
224
|
-
|
|
225
|
-
raise ValueError(f"Invalid weight={weight}; must satisfy 0 <= weight <= 1.")
|
|
239
|
+
validate_probability(weight, "weight")
|
|
226
240
|
# kNN loss parameters
|
|
227
241
|
self.k = k
|
|
228
242
|
self.weight = weight
|
|
@@ -230,13 +244,13 @@ class optDatasetKNN(optDataset):
|
|
|
230
244
|
self.feats = feats
|
|
231
245
|
self.costs = costs
|
|
232
246
|
# find optimal solutions
|
|
233
|
-
sols, objs = self.
|
|
234
|
-
self.feats =
|
|
235
|
-
self.costs =
|
|
236
|
-
self.sols =
|
|
237
|
-
self.objs =
|
|
247
|
+
sols, objs = self._get_sols()
|
|
248
|
+
self.feats = _as_float_tensor(self.feats)
|
|
249
|
+
self.costs = _as_float_tensor(self.costs)
|
|
250
|
+
self.sols = _as_float_tensor(sols)
|
|
251
|
+
self.objs = _as_float_tensor(objs)
|
|
238
252
|
|
|
239
|
-
def
|
|
253
|
+
def _get_sols(self) -> tuple[np.ndarray, np.ndarray]:
|
|
240
254
|
"""
|
|
241
255
|
A method to get optimal solutions for all cost vectors
|
|
242
256
|
"""
|
|
@@ -244,16 +258,14 @@ class optDatasetKNN(optDataset):
|
|
|
244
258
|
objs = []
|
|
245
259
|
logger.info("Optimizing for optDataset...")
|
|
246
260
|
# get kNN costs
|
|
247
|
-
costs_knn = self.
|
|
261
|
+
costs_knn = self._get_knn()
|
|
248
262
|
# solve optimization
|
|
249
263
|
for c_knn in tqdm(costs_knn):
|
|
250
264
|
sol_knn = np.zeros((self.costs.shape[1], self.k))
|
|
251
265
|
obj_knn = np.zeros(self.k)
|
|
252
266
|
for i, c in enumerate(c_knn.T):
|
|
253
267
|
sol_i, obj_i = self._solve(c)
|
|
254
|
-
|
|
255
|
-
sol_i = sol_i.detach().cpu().numpy()
|
|
256
|
-
sol_knn[:, i] = sol_i
|
|
268
|
+
sol_knn[:, i] = _solution_to_numpy(sol_i)
|
|
257
269
|
obj_knn[i] = obj_i
|
|
258
270
|
# get average
|
|
259
271
|
sol = sol_knn.mean(axis=1)
|
|
@@ -264,7 +276,7 @@ class optDatasetKNN(optDataset):
|
|
|
264
276
|
self.costs = costs_knn.mean(axis=2)
|
|
265
277
|
return np.stack(sols), np.asarray(objs).reshape(-1, 1)
|
|
266
278
|
|
|
267
|
-
def
|
|
279
|
+
def _get_knn(self) -> np.ndarray:
|
|
268
280
|
"""
|
|
269
281
|
A method to get kNN costs
|
|
270
282
|
"""
|
|
@@ -334,28 +346,22 @@ class optDatasetConstrs(optDataset):
|
|
|
334
346
|
costs: costs of objective function
|
|
335
347
|
skip_infeas: if True, drop infeasible instances instead of raising
|
|
336
348
|
"""
|
|
337
|
-
|
|
338
|
-
raise TypeError("arg model is not an optModel")
|
|
339
|
-
if len(feats) != len(costs):
|
|
340
|
-
raise ValueError(
|
|
341
|
-
f"feats and costs must have the same number of instances: "
|
|
342
|
-
f"{len(feats)} vs {len(costs)}."
|
|
343
|
-
)
|
|
349
|
+
_validate_inputs(model, feats, costs)
|
|
344
350
|
self.model = model
|
|
345
351
|
self.skip_infeas = skip_infeas
|
|
346
352
|
# data
|
|
347
353
|
self.feats = feats
|
|
348
354
|
self.costs = costs
|
|
349
355
|
# find optimal solutions and binding constraints
|
|
350
|
-
sols, objs, ctrs, valid = self.
|
|
356
|
+
sols, objs, ctrs, valid = self._get_sols()
|
|
351
357
|
# pre-convert to tensors (on CPU) to avoid repeated numpy→tensor copies
|
|
352
|
-
self.feats =
|
|
353
|
-
self.costs =
|
|
354
|
-
self.sols =
|
|
355
|
-
self.objs =
|
|
356
|
-
self.ctrs = [
|
|
358
|
+
self.feats = _as_float_tensor(self.feats[valid])
|
|
359
|
+
self.costs = _as_float_tensor(self.costs[valid])
|
|
360
|
+
self.sols = _as_float_tensor(sols)
|
|
361
|
+
self.objs = _as_float_tensor(objs)
|
|
362
|
+
self.ctrs = [_as_float_tensor(c) for c in ctrs]
|
|
357
363
|
|
|
358
|
-
def
|
|
364
|
+
def _get_sols( # type: ignore[override]
|
|
359
365
|
self,
|
|
360
366
|
) -> tuple[np.ndarray, np.ndarray, list[np.ndarray], list[int]]:
|
|
361
367
|
"""
|
|
@@ -371,9 +377,8 @@ class optDatasetConstrs(optDataset):
|
|
|
371
377
|
logger.info("Optimizing for optDatasetConstrs...")
|
|
372
378
|
model = self.model
|
|
373
379
|
for i, c in enumerate(tqdm(self.costs)):
|
|
374
|
-
model._setFullObj(model._fullCost(c))
|
|
375
380
|
try:
|
|
376
|
-
sol, obj =
|
|
381
|
+
sol, obj = self._solve(c)
|
|
377
382
|
except RuntimeError as e:
|
|
378
383
|
if self.skip_infeas:
|
|
379
384
|
logger.warning("Instance %d had no solution, skipping: %s", i, e)
|
|
@@ -413,12 +418,6 @@ class optDatasetConstrs(optDataset):
|
|
|
413
418
|
raise ValueError("No valid instances (all skipped or empty input).")
|
|
414
419
|
return np.stack(sols), np.asarray(objs), ctrs, valid
|
|
415
420
|
|
|
416
|
-
def __len__(self) -> int:
|
|
417
|
-
"""
|
|
418
|
-
A method to get data size
|
|
419
|
-
"""
|
|
420
|
-
return len(self.feats)
|
|
421
|
-
|
|
422
421
|
def __getitem__( # type: ignore[override]
|
|
423
422
|
self,
|
|
424
423
|
index: int,
|
|
@@ -563,13 +562,12 @@ def _parse_temp_constraint(
|
|
|
563
562
|
"""
|
|
564
563
|
Parse a Gurobi TempConstr into (coefs, rhs, sense) over the cost-vector dim
|
|
565
564
|
"""
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
# unparseable fallback
|
|
571
|
-
if lhs is None or rhs is None or sense is None:
|
|
565
|
+
from pyepo.model.grb.grbmodel import _temp_constr_fields
|
|
566
|
+
|
|
567
|
+
fields = _temp_constr_fields(tc)
|
|
568
|
+
if fields is None:
|
|
572
569
|
return None
|
|
570
|
+
lhs, rhs, sense = fields
|
|
573
571
|
# project LinExpr terms onto cost-vector dim
|
|
574
572
|
coefs = np.zeros(num_cost, dtype=np.float64)
|
|
575
573
|
for i in range(lhs.size()):
|
|
@@ -7,6 +7,8 @@ from __future__ import annotations
|
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
+
from pyepo.data._validation import validate_degree, validate_nonnegative
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
def genData(
|
|
12
14
|
num_data: int,
|
|
@@ -39,13 +41,8 @@ def genData(
|
|
|
39
41
|
Returns:
|
|
40
42
|
tuple: weights of items (np.ndarray), data features (np.ndarray), costs (np.ndarray)
|
|
41
43
|
"""
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
raise ValueError(f"deg = {deg} should be int.")
|
|
45
|
-
if deg <= 0:
|
|
46
|
-
raise ValueError(f"deg = {deg} should be positive.")
|
|
47
|
-
if noise_width < 0:
|
|
48
|
-
raise ValueError(f"noise_width = {noise_width} should be non-negative.")
|
|
44
|
+
validate_degree(deg)
|
|
45
|
+
validate_nonnegative(noise_width, "noise_width")
|
|
49
46
|
# set seed
|
|
50
47
|
rnd = np.random.RandomState(seed)
|
|
51
48
|
# number of data points
|
|
@@ -7,6 +7,8 @@ from __future__ import annotations
|
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
+
from pyepo.data._validation import validate_degree, validate_nonnegative
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
def genData(
|
|
12
14
|
num_data: int,
|
|
@@ -40,11 +42,8 @@ def genData(
|
|
|
40
42
|
Returns:
|
|
41
43
|
tuple: covariance matrix (np.ndarray), data features (np.ndarray), mean returns (np.ndarray)
|
|
42
44
|
"""
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
raise ValueError(f"deg = {deg} should be int.")
|
|
46
|
-
if deg <= 0:
|
|
47
|
-
raise ValueError(f"deg = {deg} should be positive.")
|
|
45
|
+
validate_degree(deg)
|
|
46
|
+
validate_nonnegative(noise_level, "noise_level")
|
|
48
47
|
# set seed
|
|
49
48
|
rnd = np.random.RandomState(seed)
|
|
50
49
|
# number of data points
|
|
@@ -7,6 +7,8 @@ from __future__ import annotations
|
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
|
+
from pyepo.data._validation import validate_degree, validate_nonnegative
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
def genData(
|
|
12
14
|
num_data: int,
|
|
@@ -36,13 +38,8 @@ def genData(
|
|
|
36
38
|
Returns:
|
|
37
39
|
tuple: data features (np.ndarray), costs (np.ndarray)
|
|
38
40
|
"""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
raise ValueError(f"deg = {deg} should be int.")
|
|
42
|
-
if deg <= 0:
|
|
43
|
-
raise ValueError(f"deg = {deg} should be positive.")
|
|
44
|
-
if noise_width < 0:
|
|
45
|
-
raise ValueError(f"noise_width = {noise_width} should be non-negative.")
|
|
41
|
+
validate_degree(deg)
|
|
42
|
+
validate_nonnegative(noise_width, "noise_width")
|
|
46
43
|
# set seed
|
|
47
44
|
rnd = np.random.RandomState(seed)
|
|
48
45
|
# number of data points
|
|
@@ -8,6 +8,8 @@ from __future__ import annotations
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
from scipy.spatial import distance
|
|
10
10
|
|
|
11
|
+
from pyepo.data._validation import validate_degree, validate_nonnegative
|
|
12
|
+
|
|
11
13
|
|
|
12
14
|
def genData(
|
|
13
15
|
num_data: int,
|
|
@@ -38,13 +40,8 @@ def genData(
|
|
|
38
40
|
Returns:
|
|
39
41
|
tuple: data features (np.ndarray), costs (np.ndarray)
|
|
40
42
|
"""
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
raise ValueError(f"deg = {deg} should be int.")
|
|
44
|
-
if deg <= 0:
|
|
45
|
-
raise ValueError(f"deg = {deg} should be positive.")
|
|
46
|
-
if noise_width < 0:
|
|
47
|
-
raise ValueError(f"noise_width = {noise_width} should be non-negative.")
|
|
43
|
+
validate_degree(deg)
|
|
44
|
+
validate_nonnegative(noise_width, "noise_width")
|
|
48
45
|
# set seed
|
|
49
46
|
rnd = np.random.RandomState(seed)
|
|
50
47
|
# number of data points
|
|
@@ -11,11 +11,14 @@ subclass builds the solver model and provides the read / write hooks.
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
+
from copy import deepcopy
|
|
15
|
+
|
|
14
16
|
import numpy as np
|
|
15
17
|
import torch
|
|
16
18
|
|
|
19
|
+
from pyepo.model._common import validate_constraint, validate_objective_shape
|
|
17
20
|
from pyepo.model.opt import optModel
|
|
18
|
-
from pyepo.utils import costToNumpy
|
|
21
|
+
from pyepo.utils import costToNumpy
|
|
19
22
|
|
|
20
23
|
|
|
21
24
|
class compiledBase(optModel):
|
|
@@ -25,11 +28,18 @@ class compiledBase(optModel):
|
|
|
25
28
|
|
|
26
29
|
def __init__(self, problem, params=None):
|
|
27
30
|
# the source DSL Problem and backend solver parameters
|
|
28
|
-
self.problem = problem
|
|
31
|
+
self.problem = deepcopy(problem)
|
|
29
32
|
self.params = dict(params) if params else {}
|
|
30
33
|
super().__init__()
|
|
31
34
|
self._apply_params()
|
|
32
35
|
|
|
36
|
+
def get_config(self) -> dict:
|
|
37
|
+
return {
|
|
38
|
+
**super().get_config(),
|
|
39
|
+
"problem": deepcopy(self.problem),
|
|
40
|
+
"params": self.params.copy(),
|
|
41
|
+
}
|
|
42
|
+
|
|
33
43
|
@property
|
|
34
44
|
def num_cost(self) -> int:
|
|
35
45
|
# predicted cost dimension
|
|
@@ -44,20 +54,18 @@ class compiledBase(optModel):
|
|
|
44
54
|
"""Set the objective from a predicted cost of length ``num_cost``, scattered onto the known fixed costs."""
|
|
45
55
|
prob = self.problem
|
|
46
56
|
coef = costToNumpy(c)
|
|
57
|
+
validate_objective_shape(coef, (prob.num_cost, prob.num_vars))
|
|
47
58
|
# scatter onto fixed costs; an unambiguous full-length vector passes through
|
|
48
59
|
if coef.shape[-1] == prob.num_cost:
|
|
49
60
|
full = prob.fixed_cost.copy()
|
|
50
61
|
full[prob.c_pred_index] += coef
|
|
51
62
|
coef = full
|
|
52
|
-
elif coef.shape[-1] != prob.num_vars:
|
|
53
|
-
raise ValueError("Size of cost vector does not match number of cost variables.")
|
|
54
63
|
self._write_obj(coef)
|
|
55
64
|
|
|
56
65
|
def _setFullObj(self, c):
|
|
57
66
|
"""Set the objective from full-space coefficients (length ``num_vars``), bypassing the predicted-cost scatter."""
|
|
58
67
|
coef = costToNumpy(c)
|
|
59
|
-
|
|
60
|
-
raise ValueError("Size of cost vector does not match number of variables.")
|
|
68
|
+
validate_objective_shape(coef, self.problem.num_vars, full=True)
|
|
61
69
|
self._write_obj(coef)
|
|
62
70
|
|
|
63
71
|
def _fullCost(self, pred_cost):
|
|
@@ -85,15 +93,16 @@ class compiledBase(optModel):
|
|
|
85
93
|
|
|
86
94
|
def addConstr(self, coefs, rhs):
|
|
87
95
|
# add a cut coefs @ x <= rhs over the full variable vector
|
|
88
|
-
|
|
96
|
+
rhs = validate_constraint(coefs, rhs, self.problem.num_vars, full=True)
|
|
97
|
+
coefs = np.array(coefs, dtype=float, copy=True).reshape(-1)
|
|
89
98
|
new_model = self._add_cut(coefs, rhs)
|
|
90
99
|
# track for replay on relax
|
|
91
|
-
new_model._extra_constrs = [*self._extra_constrs, (coefs,
|
|
100
|
+
new_model._extra_constrs = [*self._extra_constrs, (coefs, rhs)]
|
|
92
101
|
return new_model
|
|
93
102
|
|
|
94
103
|
def relax(self):
|
|
95
104
|
# recompile the relaxed problem, preserving backend kwargs
|
|
96
|
-
kwargs =
|
|
105
|
+
kwargs = self.get_config()
|
|
97
106
|
kwargs["problem"] = self.problem.relax()
|
|
98
107
|
model_rel = type(self)(**kwargs)
|
|
99
108
|
# replay user cuts on the relaxation
|
|
@@ -62,6 +62,8 @@ class Problem:
|
|
|
62
62
|
# objective cost layout
|
|
63
63
|
self.cost_param = objective.cost_param
|
|
64
64
|
self.cost_var = objective.cost_var
|
|
65
|
+
self.modelSense = objective.modelSense
|
|
66
|
+
self.cost_var_name = objective.cost_var.name
|
|
65
67
|
# assign flat slices in encounter order (objective var first)
|
|
66
68
|
self._assign_flat()
|
|
67
69
|
# finalize IR
|
|
@@ -69,7 +71,7 @@ class Problem:
|
|
|
69
71
|
|
|
70
72
|
def __repr__(self) -> str:
|
|
71
73
|
# one-line summary of the finalized problem
|
|
72
|
-
sense = "min" if self.
|
|
74
|
+
sense = "min" if self.modelSense == EPO.MINIMIZE else "max"
|
|
73
75
|
n_quad = sum(1 for Q, *_ in self.constrs if Q is not None)
|
|
74
76
|
quad = f" [{n_quad} quad]" if n_quad else ""
|
|
75
77
|
obj_q = " +quad obj" if self.obj_Q is not None else ""
|
|
@@ -160,7 +162,7 @@ class Problem:
|
|
|
160
162
|
Return a new Problem with all integer / binary variables continuous
|
|
161
163
|
(bounds preserved: binary ⇒ [0, 1]); objective and constraints unchanged.
|
|
162
164
|
"""
|
|
163
|
-
new = _copy.
|
|
165
|
+
new = _copy.deepcopy(self)
|
|
164
166
|
new.var_type = np.full(self.num_vars, EPO.CONTINUOUS, dtype=object)
|
|
165
167
|
return new
|
|
166
168
|
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Backend-independent policies shared by the Torch and JAX frontends."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from numbers import Real
|
|
5
|
+
from typing import Optional, TypeVar
|
|
6
|
+
|
|
7
|
+
from pyepo import EPO
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def is_minimize(model_sense) -> bool:
|
|
13
|
+
"""Return the objective direction, rejecting unsupported sense values."""
|
|
14
|
+
if model_sense == EPO.MINIMIZE:
|
|
15
|
+
return True
|
|
16
|
+
if model_sense == EPO.MAXIMIZE:
|
|
17
|
+
return False
|
|
18
|
+
raise ValueError("Invalid modelSense. Must be EPO.MINIMIZE or EPO.MAXIMIZE.")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def solution_pool_tolerance(num_cost: int) -> float:
|
|
22
|
+
"""L1 tolerance used to deduplicate approximate solver solutions."""
|
|
23
|
+
return min(1e-4 * num_cost, 0.1)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def validate_positive(value, name: str) -> None:
|
|
27
|
+
"""Validate a finite, strictly positive real parameter."""
|
|
28
|
+
if not isinstance(value, Real) or isinstance(value, bool):
|
|
29
|
+
raise ValueError(f"{name} must be a finite positive number.")
|
|
30
|
+
number = float(value)
|
|
31
|
+
if not math.isfinite(number) or number <= 0:
|
|
32
|
+
raise ValueError(f"{name} must be a finite positive number.")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def validate_positive_int(value, name: str) -> None:
|
|
36
|
+
"""Validate a strictly positive integer parameter."""
|
|
37
|
+
if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
|
|
38
|
+
raise ValueError(f"{name} must be a positive integer.")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def validate_nonnegative(value, name: str) -> None:
|
|
42
|
+
"""Validate a finite, non-negative real parameter."""
|
|
43
|
+
if not isinstance(value, Real) or isinstance(value, bool):
|
|
44
|
+
raise ValueError(f"{name} must be a finite non-negative number.")
|
|
45
|
+
number = float(value)
|
|
46
|
+
if not math.isfinite(number) or number < 0:
|
|
47
|
+
raise ValueError(f"{name} must be a finite non-negative number.")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def validate_probability(value, name: str) -> None:
|
|
51
|
+
"""Validate a finite real probability in the closed interval [0, 1]."""
|
|
52
|
+
if not isinstance(value, Real) or isinstance(value, bool):
|
|
53
|
+
raise ValueError(f"{name} must be a finite number in [0, 1].")
|
|
54
|
+
number = float(value)
|
|
55
|
+
if not math.isfinite(number) or not 0.0 <= number <= 1.0:
|
|
56
|
+
raise ValueError(f"{name} must be a finite number in [0, 1].")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def require_solution_pool(solpool: Optional[T]) -> T:
|
|
60
|
+
"""Return an initialized solution pool or raise a stable runtime error."""
|
|
61
|
+
if solpool is None:
|
|
62
|
+
raise RuntimeError(
|
|
63
|
+
"Solution pool is unavailable; provide an optDataset when pool-based solving is enabled."
|
|
64
|
+
)
|
|
65
|
+
return solpool
|