pyepo 2.2.2__tar.gz → 2.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. {pyepo-2.2.2 → pyepo-2.2.3}/PKG-INFO +8 -2
  2. {pyepo-2.2.2 → pyepo-2.2.3}/README.md +7 -1
  3. pyepo-2.2.3/pyepo/data/_validation.py +32 -0
  4. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/data/dataset.py +69 -71
  5. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/data/knapsack.py +4 -7
  6. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/data/portfolio.py +4 -5
  7. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/data/shortestpath.py +4 -7
  8. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/data/tsp.py +4 -7
  9. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/dsl/compiled.py +18 -9
  10. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/dsl/problem.py +4 -2
  11. pyepo-2.2.3/pyepo/func/_common.py +65 -0
  12. pyepo-2.2.3/pyepo/func/abcmodule.py +100 -0
  13. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/blackbox.py +4 -8
  14. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/cave.py +8 -8
  15. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/contrastive.py +7 -26
  16. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/__init__.py +1 -1
  17. pyepo-2.2.3/pyepo/func/jax/abcmodule.py +96 -0
  18. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/blackbox.py +4 -8
  19. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/cave.py +19 -9
  20. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/contrastive.py +19 -10
  21. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/perturbed.py +22 -13
  22. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/rank.py +28 -17
  23. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/regularized.py +22 -9
  24. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/surrogate.py +12 -5
  25. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/jax/utils.py +7 -8
  26. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/perturbed.py +22 -22
  27. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/rank.py +12 -40
  28. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/regularized.py +20 -19
  29. pyepo-2.2.3/pyepo/func/runtime.py +129 -0
  30. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/surrogate.py +12 -15
  31. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/utils.py +22 -22
  32. pyepo-2.2.3/pyepo/metric/_common.py +167 -0
  33. pyepo-2.2.3/pyepo/metric/metrics.py +120 -0
  34. pyepo-2.2.3/pyepo/metric/mse.py +42 -0
  35. pyepo-2.2.3/pyepo/metric/regret.py +163 -0
  36. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/metric/unambregret.py +26 -19
  37. pyepo-2.2.3/pyepo/model/__init__.py +41 -0
  38. pyepo-2.2.3/pyepo/model/_common.py +60 -0
  39. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/bases.py +127 -34
  40. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/compile.py +10 -12
  41. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/coptmodel.py +36 -15
  42. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/knapsack.py +2 -46
  43. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/portfolio.py +3 -35
  44. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/shortestpath.py +3 -33
  45. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/tsp.py +29 -67
  46. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/vrp.py +24 -16
  47. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/compile.py +7 -9
  48. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/grbmodel.py +43 -14
  49. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/knapsack.py +2 -44
  50. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/portfolio.py +3 -35
  51. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/shortestpath.py +3 -33
  52. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/tsp.py +12 -11
  53. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/vrp.py +13 -19
  54. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/mpax/compile.py +7 -8
  55. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/mpax/knapsack.py +1 -21
  56. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/mpax/mpaxmodel.py +8 -12
  57. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/mpax/shortestpath.py +1 -31
  58. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/compile.py +11 -18
  59. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/knapsack.py +6 -46
  60. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/omomodel.py +42 -15
  61. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/portfolio.py +2 -35
  62. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/shortestpath.py +2 -33
  63. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/tsp.py +6 -52
  64. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/vrp.py +5 -13
  65. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/opt.py +52 -5
  66. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/ort/compile.py +6 -5
  67. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/ort/knapsack.py +3 -57
  68. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/ort/ortcpmodel.py +18 -7
  69. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/ort/ortmodel.py +16 -29
  70. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/ort/shortestpath.py +2 -44
  71. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/predefined.py +13 -12
  72. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/twostage/autosklearnpred.py +1 -7
  73. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/utils.py +2 -4
  74. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo.egg-info/PKG-INFO +8 -2
  75. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo.egg-info/SOURCES.txt +6 -0
  76. {pyepo-2.2.2 → pyepo-2.2.3}/pyproject.toml +2 -7
  77. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_00_constants.py +0 -1
  78. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_10_utils.py +98 -5
  79. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_15_dsl.py +263 -108
  80. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_20_data_gen.py +28 -36
  81. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_30_model.py +465 -49
  82. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_40_dataset.py +65 -22
  83. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_50_func.py +384 -149
  84. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_55_jax.py +95 -122
  85. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_60_metric.py +159 -85
  86. pyepo-2.2.3/tests/test_61_metric_validation.py +168 -0
  87. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_70_twostage.py +28 -3
  88. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_80_integration.py +50 -30
  89. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_90_cuda.py +7 -10
  90. pyepo-2.2.2/pyepo/func/abcmodule.py +0 -120
  91. pyepo-2.2.2/pyepo/func/jax/abcmodule.py +0 -118
  92. pyepo-2.2.2/pyepo/metric/metrics.py +0 -191
  93. pyepo-2.2.2/pyepo/metric/mse.py +0 -49
  94. pyepo-2.2.2/pyepo/metric/regret.py +0 -216
  95. pyepo-2.2.2/pyepo/model/__init__.py +0 -52
  96. {pyepo-2.2.2 → pyepo-2.2.3}/LICENSE +0 -0
  97. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/EPO.py +0 -0
  98. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/__init__.py +0 -0
  99. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/data/__init__.py +0 -0
  100. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/dsl/__init__.py +0 -0
  101. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/dsl/expression.py +0 -0
  102. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/dsl/objective.py +0 -0
  103. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/func/__init__.py +0 -0
  104. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/metric/__init__.py +0 -0
  105. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/copt/__init__.py +0 -0
  106. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/grb/__init__.py +0 -0
  107. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/mpax/__init__.py +0 -0
  108. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/omo/__init__.py +0 -0
  109. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/ort/__init__.py +0 -0
  110. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/model/utils.py +0 -0
  111. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/py.typed +0 -0
  112. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/twostage/__init__.py +0 -0
  113. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo/twostage/sklearnpred.py +0 -0
  114. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo.egg-info/dependency_links.txt +0 -0
  115. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo.egg-info/requires.txt +0 -0
  116. {pyepo-2.2.2 → pyepo-2.2.3}/pyepo.egg-info/top_level.txt +0 -0
  117. {pyepo-2.2.2 → pyepo-2.2.3}/setup.cfg +0 -0
  118. {pyepo-2.2.2 → pyepo-2.2.3}/tests/test_85_backend_pipeline.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.2
3
+ Version: 2.2.3
4
4
  Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
@@ -76,12 +76,18 @@ Dynamic: license-file
76
76
 
77
77
  ## Features
78
78
 
79
- - Implement **SPO+**, **DBB**, **NID**, **DPO** (additive and multiplicative perturbations), **PFYL** (additive and multiplicative perturbations), L2-regularized **RFWO/RFYL**, **NCE**, **LTR**, **I-MLE**, **AI-MLE**, and **PG**
79
+ - Implement **SPO+**, **PG**, **DPO** (additive and multiplicative perturbations), **PFYL** (additive and multiplicative perturbations), **I-MLE**, **AI-MLE**, L2-regularized **RFWO/RFYL**, **DBB**, **NID**, **CaVE**, **NCE**, and **LTR**
80
80
  - Support [Gurobi](https://www.gurobi.com/), [COPT](https://shanshu.ai/copt), [Pyomo](http://www.pyomo.org/), [Google OR-Tools](https://developers.google.com/optimization), and [MPAX](https://github.com/MIT-Lu-Lab/MPAX) API
81
+ - Symbolic modeling with `pyepo.dsl`: define an LP, MIP, or QP once, then compile it to any backend
82
+ - JAX frontend (`pyepo.func.jax`): train any loss in JAX/Flax with `jax.grad`
81
83
  - Support parallel computing for optimization solvers
82
84
  - Support solution caching to speed up training
83
85
  - Support kNN robust loss to improve decision quality
84
86
 
87
+ ## CaVE for Binary Linear Programs
88
+
89
+ For end-to-end learning on **binary linear programs** (TSP, CVRP, knapsack, ...), ``PyEPO`` ships **CaVE**. CaVE replaces the per-step ILP solve with a cone-alignment projection onto the binding-constraint normals at the true optimum, backed by an interior-point QP solver (Clarabel). Because the cone projection is far cheaper than the per-instance ILP solve, CaVE trains an order of magnitude faster than SPO+ at TSP scale.
90
+
85
91
  ## GPU-Accelerated Solving with MPAX
86
92
 
87
93
  ``PyEPO`` integrates [MPAX](https://github.com/MIT-Lu-Lab/MPAX), a JAX-based mathematical programming solver using the PDHG algorithm for GPU-accelerated optimization. Key advantages: (1) **GPU-native solving** — the first-order PDHG method runs efficiently on GPU; (2) **batch solving** — an entire mini-batch can be solved simultaneously via vectorization; (3) **no GPU-CPU data transfer overhead** — both the neural network and the solver stay on GPU, eliminating the data transfer bottleneck.
@@ -4,12 +4,18 @@
4
4
 
5
5
  ## Features
6
6
 
7
- - Implement **SPO+**, **DBB**, **NID**, **DPO** (additive and multiplicative perturbations), **PFYL** (additive and multiplicative perturbations), L2-regularized **RFWO/RFYL**, **NCE**, **LTR**, **I-MLE**, **AI-MLE**, and **PG**
7
+ - Implement **SPO+**, **PG**, **DPO** (additive and multiplicative perturbations), **PFYL** (additive and multiplicative perturbations), **I-MLE**, **AI-MLE**, L2-regularized **RFWO/RFYL**, **DBB**, **NID**, **CaVE**, **NCE**, and **LTR**
8
8
  - Support [Gurobi](https://www.gurobi.com/), [COPT](https://shanshu.ai/copt), [Pyomo](http://www.pyomo.org/), [Google OR-Tools](https://developers.google.com/optimization), and [MPAX](https://github.com/MIT-Lu-Lab/MPAX) API
9
+ - Symbolic modeling with `pyepo.dsl`: define an LP, MIP, or QP once, then compile it to any backend
10
+ - JAX frontend (`pyepo.func.jax`): train any loss in JAX/Flax with `jax.grad`
9
11
  - Support parallel computing for optimization solvers
10
12
  - Support solution caching to speed up training
11
13
  - Support kNN robust loss to improve decision quality
12
14
 
15
+ ## CaVE for Binary Linear Programs
16
+
17
+ For end-to-end learning on **binary linear programs** (TSP, CVRP, knapsack, ...), ``PyEPO`` ships **CaVE**. CaVE replaces the per-step ILP solve with a cone-alignment projection onto the binding-constraint normals at the true optimum, backed by an interior-point QP solver (Clarabel). Because the cone projection is far cheaper than the per-instance ILP solve, CaVE trains an order of magnitude faster than SPO+ at TSP scale.
18
+
13
19
  ## GPU-Accelerated Solving with MPAX
14
20
 
15
21
  ``PyEPO`` integrates [MPAX](https://github.com/MIT-Lu-Lab/MPAX), a JAX-based mathematical programming solver using the PDHG algorithm for GPU-accelerated optimization. Key advantages: (1) **GPU-native solving** — the first-order PDHG method runs efficiently on GPU; (2) **batch solving** — an entire mini-batch can be solved simultaneously via vectorization; (3) **no GPU-CPU data transfer overhead** — both the neural network and the solver stay on GPU, eliminating the data transfer bottleneck.
@@ -0,0 +1,32 @@
1
+ """Shared validation for synthetic data generators."""
2
+
3
+ import math
4
+ from numbers import Real
5
+
6
+
7
def validate_degree(deg: int) -> None:
    """Validate a positive integer polynomial degree.

    Any ``numbers.Integral`` is accepted (plain ``int`` as well as NumPy
    integer scalars such as ``np.int64``); ``bool`` is rejected even though
    it subclasses ``int``.

    Args:
        deg: polynomial degree requested from a synthetic data generator

    Raises:
        ValueError: if ``deg`` is not a positive integer
    """
    from numbers import Integral

    # bool is an Integral subclass; True must not silently pass as degree 1
    if isinstance(deg, bool) or not isinstance(deg, Integral) or deg <= 0:
        raise ValueError(f"deg = {deg} should be a positive integer.")
11
+
12
+
13
def validate_nonnegative(value: float, name: str) -> None:
    """Check that a generator parameter is a finite real number that is >= 0.

    ``bool`` values are refused even though they are ``numbers.Real``.

    Args:
        value: candidate parameter value
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is non-real, a bool, NaN, infinite, or negative
    """
    if isinstance(value, Real) and not isinstance(value, bool):
        as_float = float(value)
        if math.isfinite(as_float) and as_float >= 0:
            return
    raise ValueError(f"{name} = {value} should be a finite non-negative number.")
20
+
21
+
22
def validate_probability(value: float, name: str) -> None:
    """Validate a finite real value in the closed interval [0, 1].

    Every invalid input (non-real, bool, NaN, infinite, negative, or above 1)
    reports the same ``[0, 1]`` message, so callers always see the full valid
    range rather than only the lower bound.

    Args:
        value: candidate parameter value (e.g. a kNN self-weight)
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is not a finite real number in [0, 1]
    """
    if isinstance(value, Real) and not isinstance(value, bool):
        # NaN and +/-inf fail the chained comparison, so no separate isfinite check is needed
        if 0.0 <= float(value) <= 1.0:
            return
    raise ValueError(f"{name} = {value} should be in [0, 1].")
27
+
28
+
29
def validate_positive_int(value: int, name: str) -> None:
    """Validate a strictly positive integer parameter.

    Any ``numbers.Integral`` is accepted, including NumPy integer scalars such
    as ``np.int64`` (which a plain ``isinstance(value, int)`` check rejects);
    ``bool`` is refused even though it subclasses ``int``.

    Args:
        value: candidate parameter value (e.g. the kNN neighbour count ``k``)
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is not a positive integer
    """
    from numbers import Integral

    # bool is an Integral subclass; True must not silently pass as 1
    if isinstance(value, bool) or not isinstance(value, Integral) or value <= 0:
        raise ValueError(f"{name} = {value} should be a positive integer.")
@@ -15,19 +15,48 @@ from torch.utils.data import Dataset
15
15
  from tqdm import tqdm
16
16
 
17
17
  from pyepo import EPO
18
+ from pyepo.data._validation import validate_positive_int, validate_probability
18
19
  from pyepo.model.opt import optModel
19
20
 
20
21
  if TYPE_CHECKING:
21
22
  from pyepo.model.mpax import optMpaxModel as _optMpaxModelT
22
23
 
23
24
  try:
24
- from pyepo.model.mpax import optMpaxModel
25
+ from pyepo.model.mpax import optMpaxModel as _opt_mpax_model_cls
25
26
  except ImportError:
26
- optMpaxModel = None # type: ignore[assignment]
27
+ _opt_mpax_model_cls = None
27
28
 
28
29
  logger = logging.getLogger(__name__)
29
30
 
30
31
 
32
def _validate_inputs(
    model: optModel,
    feats: np.ndarray | torch.Tensor,
    costs: np.ndarray | torch.Tensor,
) -> None:
    """Check the constructor arguments shared by all optimization datasets.

    Args:
        model: optimization model the dataset will solve with
        feats: per-instance features
        costs: per-instance cost vectors

    Raises:
        TypeError: if ``model`` is not an ``optModel``
        ValueError: if ``feats`` and ``costs`` have different lengths
    """
    if not isinstance(model, optModel):
        raise TypeError("arg model is not an optModel")
    num_feats = len(feats)
    num_costs = len(costs)
    if num_feats == num_costs:
        return
    raise ValueError(
        f"feats and costs must have the same number of instances: {num_feats} vs {num_costs}."
    )
44
+
45
+
46
+ def _as_float_tensor(data) -> torch.Tensor:
47
+ """Convert dataset arrays to the common float32 tensor representation."""
48
+ return torch.as_tensor(data, dtype=torch.float32)
49
+
50
+
51
+ def _solution_to_numpy(
52
+ solution: np.ndarray | torch.Tensor | list,
53
+ ) -> np.ndarray:
54
+ """Normalize a solver solution to a NumPy array."""
55
+ if isinstance(solution, torch.Tensor):
56
+ solution = solution.detach().cpu().numpy()
57
+ return np.asarray(solution)
58
+
59
+
31
60
  class optDataset(Dataset):
32
61
  """
33
62
  PyTorch ``Dataset`` for predict-then-optimize problems.
@@ -62,44 +91,35 @@ class optDataset(Dataset):
62
91
  feats: data features
63
92
  costs: costs of objective function
64
93
  """
65
- if not isinstance(model, optModel):
66
- raise TypeError("arg model is not an optModel")
67
- if len(feats) != len(costs):
68
- raise ValueError(
69
- f"feats and costs must have the same number of instances: "
70
- f"{len(feats)} vs {len(costs)}."
71
- )
94
+ _validate_inputs(model, feats, costs)
72
95
  self.model = model
73
96
  # data
74
97
  self.feats = feats
75
98
  self.costs = costs
76
99
  # find optimal solutions
77
- sols, objs = self._getSols()
78
- self.feats = torch.as_tensor(feats, dtype=torch.float32)
79
- self.costs = torch.as_tensor(costs, dtype=torch.float32)
80
- self.sols = torch.as_tensor(sols, dtype=torch.float32)
81
- self.objs = torch.as_tensor(objs, dtype=torch.float32)
100
+ sols, objs = self._get_sols()
101
+ self.feats = _as_float_tensor(feats)
102
+ self.costs = _as_float_tensor(costs)
103
+ self.sols = _as_float_tensor(sols)
104
+ self.objs = _as_float_tensor(objs)
82
105
 
83
- def _getSols(self) -> tuple[np.ndarray, np.ndarray]:
106
+ def _get_sols(self) -> tuple[np.ndarray, np.ndarray]:
84
107
  """
85
108
  A method to get optimal solutions for all cost vectors
86
109
  """
87
110
  # MPAX fast path: vmap-solve the whole dataset in a single dispatch
88
- if optMpaxModel is not None and isinstance(self.model, optMpaxModel):
89
- return self._getSolsMpaxBatch()
111
+ if _opt_mpax_model_cls is not None and isinstance(self.model, _opt_mpax_model_cls):
112
+ return self._get_sols_mpax_batch()
90
113
  sols = []
91
114
  objs = []
92
115
  logger.info("Optimizing for optDataset...")
93
116
  for c in tqdm(self.costs):
94
117
  sol, obj = self._solve(c)
95
- # to numpy
96
- if isinstance(sol, torch.Tensor):
97
- sol = sol.detach().cpu().numpy()
98
- sols.append(np.asarray(sol))
118
+ sols.append(_solution_to_numpy(sol))
99
119
  objs.append(obj)
100
120
  return np.stack(sols), np.asarray(objs).reshape(-1, 1)
101
121
 
102
- def _getSolsMpaxBatch(self) -> tuple[np.ndarray, np.ndarray]:
122
+ def _get_sols_mpax_batch(self) -> tuple[np.ndarray, np.ndarray]:
103
123
  """
104
124
  A method to batch-solve every cost vector in one MPAX vmap call.
105
125
  """
@@ -209,20 +229,14 @@ class optDatasetKNN(optDataset):
209
229
  k: number of nearest neighbours selected
210
230
  weight: self-weight in the kNN convex combination (1.0 = no smoothing)
211
231
  """
212
- if not isinstance(model, optModel):
213
- raise TypeError("arg model is not an optModel")
214
- if len(feats) != len(costs):
215
- raise ValueError(
216
- f"feats and costs must have the same number of instances: "
217
- f"{len(feats)} vs {len(costs)}."
218
- )
232
+ _validate_inputs(model, feats, costs)
219
233
  self.model = model
220
234
  # at most num_data-1 neighbours exist (self excluded), so k must stay below it
221
235
  num_data = len(feats)
222
- if not 1 <= k < num_data:
236
+ validate_positive_int(k, "k")
237
+ if k >= num_data:
223
238
  raise ValueError(f"Invalid k={k}; must satisfy 1 <= k < num_data ({num_data}).")
224
- if not 0.0 <= weight <= 1.0:
225
- raise ValueError(f"Invalid weight={weight}; must satisfy 0 <= weight <= 1.")
239
+ validate_probability(weight, "weight")
226
240
  # kNN loss parameters
227
241
  self.k = k
228
242
  self.weight = weight
@@ -230,13 +244,13 @@ class optDatasetKNN(optDataset):
230
244
  self.feats = feats
231
245
  self.costs = costs
232
246
  # find optimal solutions
233
- sols, objs = self._getSols()
234
- self.feats = torch.as_tensor(self.feats, dtype=torch.float32)
235
- self.costs = torch.as_tensor(self.costs, dtype=torch.float32)
236
- self.sols = torch.as_tensor(sols, dtype=torch.float32)
237
- self.objs = torch.as_tensor(objs, dtype=torch.float32)
247
+ sols, objs = self._get_sols()
248
+ self.feats = _as_float_tensor(self.feats)
249
+ self.costs = _as_float_tensor(self.costs)
250
+ self.sols = _as_float_tensor(sols)
251
+ self.objs = _as_float_tensor(objs)
238
252
 
239
- def _getSols(self) -> tuple[np.ndarray, np.ndarray]:
253
+ def _get_sols(self) -> tuple[np.ndarray, np.ndarray]:
240
254
  """
241
255
  A method to get optimal solutions for all cost vectors
242
256
  """
@@ -244,16 +258,14 @@ class optDatasetKNN(optDataset):
244
258
  objs = []
245
259
  logger.info("Optimizing for optDataset...")
246
260
  # get kNN costs
247
- costs_knn = self._getKNN()
261
+ costs_knn = self._get_knn()
248
262
  # solve optimization
249
263
  for c_knn in tqdm(costs_knn):
250
264
  sol_knn = np.zeros((self.costs.shape[1], self.k))
251
265
  obj_knn = np.zeros(self.k)
252
266
  for i, c in enumerate(c_knn.T):
253
267
  sol_i, obj_i = self._solve(c)
254
- if isinstance(sol_i, torch.Tensor):
255
- sol_i = sol_i.detach().cpu().numpy()
256
- sol_knn[:, i] = sol_i
268
+ sol_knn[:, i] = _solution_to_numpy(sol_i)
257
269
  obj_knn[i] = obj_i
258
270
  # get average
259
271
  sol = sol_knn.mean(axis=1)
@@ -264,7 +276,7 @@ class optDatasetKNN(optDataset):
264
276
  self.costs = costs_knn.mean(axis=2)
265
277
  return np.stack(sols), np.asarray(objs).reshape(-1, 1)
266
278
 
267
- def _getKNN(self) -> np.ndarray:
279
+ def _get_knn(self) -> np.ndarray:
268
280
  """
269
281
  A method to get kNN costs
270
282
  """
@@ -334,28 +346,22 @@ class optDatasetConstrs(optDataset):
334
346
  costs: costs of objective function
335
347
  skip_infeas: if True, drop infeasible instances instead of raising
336
348
  """
337
- if not isinstance(model, optModel):
338
- raise TypeError("arg model is not an optModel")
339
- if len(feats) != len(costs):
340
- raise ValueError(
341
- f"feats and costs must have the same number of instances: "
342
- f"{len(feats)} vs {len(costs)}."
343
- )
349
+ _validate_inputs(model, feats, costs)
344
350
  self.model = model
345
351
  self.skip_infeas = skip_infeas
346
352
  # data
347
353
  self.feats = feats
348
354
  self.costs = costs
349
355
  # find optimal solutions and binding constraints
350
- sols, objs, ctrs, valid = self._getSols()
356
+ sols, objs, ctrs, valid = self._get_sols()
351
357
  # pre-convert to tensors (on CPU) to avoid repeated numpy→tensor copies
352
- self.feats = torch.as_tensor(self.feats[valid], dtype=torch.float32)
353
- self.costs = torch.as_tensor(self.costs[valid], dtype=torch.float32)
354
- self.sols = torch.as_tensor(sols, dtype=torch.float32)
355
- self.objs = torch.as_tensor(objs, dtype=torch.float32)
356
- self.ctrs = [torch.as_tensor(c, dtype=torch.float32) for c in ctrs]
358
+ self.feats = _as_float_tensor(self.feats[valid])
359
+ self.costs = _as_float_tensor(self.costs[valid])
360
+ self.sols = _as_float_tensor(sols)
361
+ self.objs = _as_float_tensor(objs)
362
+ self.ctrs = [_as_float_tensor(c) for c in ctrs]
357
363
 
358
- def _getSols( # type: ignore[override]
364
+ def _get_sols( # type: ignore[override]
359
365
  self,
360
366
  ) -> tuple[np.ndarray, np.ndarray, list[np.ndarray], list[int]]:
361
367
  """
@@ -371,9 +377,8 @@ class optDatasetConstrs(optDataset):
371
377
  logger.info("Optimizing for optDatasetConstrs...")
372
378
  model = self.model
373
379
  for i, c in enumerate(tqdm(self.costs)):
374
- model._setFullObj(model._fullCost(c))
375
380
  try:
376
- sol, obj = model.solve()
381
+ sol, obj = self._solve(c)
377
382
  except RuntimeError as e:
378
383
  if self.skip_infeas:
379
384
  logger.warning("Instance %d had no solution, skipping: %s", i, e)
@@ -413,12 +418,6 @@ class optDatasetConstrs(optDataset):
413
418
  raise ValueError("No valid instances (all skipped or empty input).")
414
419
  return np.stack(sols), np.asarray(objs), ctrs, valid
415
420
 
416
- def __len__(self) -> int:
417
- """
418
- A method to get data size
419
- """
420
- return len(self.feats)
421
-
422
421
  def __getitem__( # type: ignore[override]
423
422
  self,
424
423
  index: int,
@@ -563,13 +562,12 @@ def _parse_temp_constraint(
563
562
  """
564
563
  Parse a Gurobi TempConstr into (coefs, rhs, sense) over the cost-vector dim
565
564
  """
566
- # TempConstr internals
567
- lhs = getattr(tc, "_lhs", None)
568
- rhs = getattr(tc, "_rhs", None)
569
- sense = getattr(tc, "_sense", None)
570
- # unparseable fallback
571
- if lhs is None or rhs is None or sense is None:
565
+ from pyepo.model.grb.grbmodel import _temp_constr_fields
566
+
567
+ fields = _temp_constr_fields(tc)
568
+ if fields is None:
572
569
  return None
570
+ lhs, rhs, sense = fields
573
571
  # project LinExpr terms onto cost-vector dim
574
572
  coefs = np.zeros(num_cost, dtype=np.float64)
575
573
  for i in range(lhs.size()):
@@ -7,6 +7,8 @@ from __future__ import annotations
7
7
 
8
8
  import numpy as np
9
9
 
10
+ from pyepo.data._validation import validate_degree, validate_nonnegative
11
+
10
12
 
11
13
  def genData(
12
14
  num_data: int,
@@ -39,13 +41,8 @@ def genData(
39
41
  Returns:
40
42
  tuple: weights of items (np.ndarray), data features (np.ndarray), costs (np.ndarray)
41
43
  """
42
- # positive integer parameter
43
- if not isinstance(deg, int):
44
- raise ValueError(f"deg = {deg} should be int.")
45
- if deg <= 0:
46
- raise ValueError(f"deg = {deg} should be positive.")
47
- if noise_width < 0:
48
- raise ValueError(f"noise_width = {noise_width} should be non-negative.")
44
+ validate_degree(deg)
45
+ validate_nonnegative(noise_width, "noise_width")
49
46
  # set seed
50
47
  rnd = np.random.RandomState(seed)
51
48
  # number of data points
@@ -7,6 +7,8 @@ from __future__ import annotations
7
7
 
8
8
  import numpy as np
9
9
 
10
+ from pyepo.data._validation import validate_degree, validate_nonnegative
11
+
10
12
 
11
13
  def genData(
12
14
  num_data: int,
@@ -40,11 +42,8 @@ def genData(
40
42
  Returns:
41
43
  tuple: covariance matrix (np.ndarray), data features (np.ndarray), mean returns (np.ndarray)
42
44
  """
43
- # positive integer parameter
44
- if not isinstance(deg, int):
45
- raise ValueError(f"deg = {deg} should be int.")
46
- if deg <= 0:
47
- raise ValueError(f"deg = {deg} should be positive.")
45
+ validate_degree(deg)
46
+ validate_nonnegative(noise_level, "noise_level")
48
47
  # set seed
49
48
  rnd = np.random.RandomState(seed)
50
49
  # number of data points
@@ -7,6 +7,8 @@ from __future__ import annotations
7
7
 
8
8
  import numpy as np
9
9
 
10
+ from pyepo.data._validation import validate_degree, validate_nonnegative
11
+
10
12
 
11
13
  def genData(
12
14
  num_data: int,
@@ -36,13 +38,8 @@ def genData(
36
38
  Returns:
37
39
  tuple: data features (np.ndarray), costs (np.ndarray)
38
40
  """
39
- # positive integer parameter
40
- if not isinstance(deg, int):
41
- raise ValueError(f"deg = {deg} should be int.")
42
- if deg <= 0:
43
- raise ValueError(f"deg = {deg} should be positive.")
44
- if noise_width < 0:
45
- raise ValueError(f"noise_width = {noise_width} should be non-negative.")
41
+ validate_degree(deg)
42
+ validate_nonnegative(noise_width, "noise_width")
46
43
  # set seed
47
44
  rnd = np.random.RandomState(seed)
48
45
  # number of data points
@@ -8,6 +8,8 @@ from __future__ import annotations
8
8
  import numpy as np
9
9
  from scipy.spatial import distance
10
10
 
11
+ from pyepo.data._validation import validate_degree, validate_nonnegative
12
+
11
13
 
12
14
  def genData(
13
15
  num_data: int,
@@ -38,13 +40,8 @@ def genData(
38
40
  Returns:
39
41
  tuple: data features (np.ndarray), costs (np.ndarray)
40
42
  """
41
- # positive integer parameter
42
- if not isinstance(deg, int):
43
- raise ValueError(f"deg = {deg} should be int.")
44
- if deg <= 0:
45
- raise ValueError(f"deg = {deg} should be positive.")
46
- if noise_width < 0:
47
- raise ValueError(f"noise_width = {noise_width} should be non-negative.")
43
+ validate_degree(deg)
44
+ validate_nonnegative(noise_width, "noise_width")
48
45
  # set seed
49
46
  rnd = np.random.RandomState(seed)
50
47
  # number of data points
@@ -11,11 +11,14 @@ subclass builds the solver model and provides the read / write hooks.
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
+ from copy import deepcopy
15
+
14
16
  import numpy as np
15
17
  import torch
16
18
 
19
+ from pyepo.model._common import validate_constraint, validate_objective_shape
17
20
  from pyepo.model.opt import optModel
18
- from pyepo.utils import costToNumpy, getArgs
21
+ from pyepo.utils import costToNumpy
19
22
 
20
23
 
21
24
  class compiledBase(optModel):
@@ -25,11 +28,18 @@ class compiledBase(optModel):
25
28
 
26
29
  def __init__(self, problem, params=None):
27
30
  # the source DSL Problem and backend solver parameters
28
- self.problem = problem
31
+ self.problem = deepcopy(problem)
29
32
  self.params = dict(params) if params else {}
30
33
  super().__init__()
31
34
  self._apply_params()
32
35
 
36
def get_config(self) -> dict:
    """Return constructor kwargs that rebuild an equivalent compiled model.

    Extends the parent's config with a deep copy of the source DSL problem
    and a shallow copy of the backend solver params. Top-level mutation of
    the returned dict therefore does not alter this instance.
    """
    config = {**super().get_config()}
    config["problem"] = deepcopy(self.problem)
    config["params"] = self.params.copy()
    return config
42
+
33
43
  @property
34
44
  def num_cost(self) -> int:
35
45
  # predicted cost dimension
@@ -44,20 +54,18 @@ class compiledBase(optModel):
44
54
  """Set the objective from a predicted cost of length ``num_cost``, scattered onto the known fixed costs."""
45
55
  prob = self.problem
46
56
  coef = costToNumpy(c)
57
+ validate_objective_shape(coef, (prob.num_cost, prob.num_vars))
47
58
  # scatter onto fixed costs; an unambiguous full-length vector passes through
48
59
  if coef.shape[-1] == prob.num_cost:
49
60
  full = prob.fixed_cost.copy()
50
61
  full[prob.c_pred_index] += coef
51
62
  coef = full
52
- elif coef.shape[-1] != prob.num_vars:
53
- raise ValueError("Size of cost vector does not match number of cost variables.")
54
63
  self._write_obj(coef)
55
64
 
56
65
  def _setFullObj(self, c):
57
66
  """Set the objective from full-space coefficients (length ``num_vars``), bypassing the predicted-cost scatter."""
58
67
  coef = costToNumpy(c)
59
- if coef.shape[-1] != self.problem.num_vars:
60
- raise ValueError("Size of cost vector does not match number of variables.")
68
+ validate_objective_shape(coef, self.problem.num_vars, full=True)
61
69
  self._write_obj(coef)
62
70
 
63
71
  def _fullCost(self, pred_cost):
@@ -85,15 +93,16 @@ class compiledBase(optModel):
85
93
 
86
94
  def addConstr(self, coefs, rhs):
87
95
  # add a cut coefs @ x <= rhs over the full variable vector
88
- coefs = np.asarray(coefs, dtype=float).reshape(-1)
96
+ rhs = validate_constraint(coefs, rhs, self.problem.num_vars, full=True)
97
+ coefs = np.array(coefs, dtype=float, copy=True).reshape(-1)
89
98
  new_model = self._add_cut(coefs, rhs)
90
99
  # track for replay on relax
91
- new_model._extra_constrs = [*self._extra_constrs, (coefs, float(rhs))]
100
+ new_model._extra_constrs = [*self._extra_constrs, (coefs, rhs)]
92
101
  return new_model
93
102
 
94
103
  def relax(self):
95
104
  # recompile the relaxed problem, preserving backend kwargs
96
- kwargs = getArgs(self)
105
+ kwargs = self.get_config()
97
106
  kwargs["problem"] = self.problem.relax()
98
107
  model_rel = type(self)(**kwargs)
99
108
  # replay user cuts on the relaxation
@@ -62,6 +62,8 @@ class Problem:
62
62
  # objective cost layout
63
63
  self.cost_param = objective.cost_param
64
64
  self.cost_var = objective.cost_var
65
+ self.modelSense = objective.modelSense
66
+ self.cost_var_name = objective.cost_var.name
65
67
  # assign flat slices in encounter order (objective var first)
66
68
  self._assign_flat()
67
69
  # finalize IR
@@ -69,7 +71,7 @@ class Problem:
69
71
 
70
72
  def __repr__(self) -> str:
71
73
  # one-line summary of the finalized problem
72
- sense = "min" if self.objective.modelSense == EPO.MINIMIZE else "max"
74
+ sense = "min" if self.modelSense == EPO.MINIMIZE else "max"
73
75
  n_quad = sum(1 for Q, *_ in self.constrs if Q is not None)
74
76
  quad = f" [{n_quad} quad]" if n_quad else ""
75
77
  obj_q = " +quad obj" if self.obj_Q is not None else ""
@@ -160,7 +162,7 @@ class Problem:
160
162
  Return a new Problem with all integer / binary variables continuous
161
163
  (bounds preserved: binary ⇒ [0, 1]); objective and constraints unchanged.
162
164
  """
163
- new = _copy.copy(self)
165
+ new = _copy.deepcopy(self)
164
166
  new.var_type = np.full(self.num_vars, EPO.CONTINUOUS, dtype=object)
165
167
  return new
166
168
 
@@ -0,0 +1,65 @@
1
+ """Backend-independent policies shared by the Torch and JAX frontends."""
2
+
3
+ import math
4
+ from numbers import Real
5
+ from typing import Optional, TypeVar
6
+
7
+ from pyepo import EPO
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
def is_minimize(model_sense) -> bool:
    """Map a model sense to ``True`` (minimize) or ``False`` (maximize).

    Raises:
        ValueError: if ``model_sense`` is neither ``EPO.MINIMIZE`` nor ``EPO.MAXIMIZE``
    """
    for sense, minimizing in ((EPO.MINIMIZE, True), (EPO.MAXIMIZE, False)):
        if model_sense == sense:
            return minimizing
    raise ValueError("Invalid modelSense. Must be EPO.MINIMIZE or EPO.MAXIMIZE.")
19
+
20
+
21
def solution_pool_tolerance(num_cost: int) -> float:
    """Return the L1 tolerance used to deduplicate approximate solver solutions.

    The tolerance grows linearly with the cost dimension (``1e-4`` per entry)
    and is capped at ``0.1``.
    """
    cap = 0.1
    scaled = 1e-4 * num_cost
    # same selection (including ties) as min(scaled, cap)
    return cap if cap < scaled else scaled
24
+
25
+
26
def validate_positive(value, name: str) -> None:
    """Ensure ``value`` is a finite real number strictly greater than zero.

    Args:
        value: candidate parameter value
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is non-real, a bool, NaN, infinite, or <= 0
    """
    if isinstance(value, bool) or not isinstance(value, Real):
        raise ValueError(f"{name} must be a finite positive number.")
    as_float = float(value)
    if math.isfinite(as_float) and as_float > 0:
        return
    raise ValueError(f"{name} must be a finite positive number.")
33
+
34
+
35
def validate_positive_int(value, name: str) -> None:
    """Validate a strictly positive integer parameter.

    Any ``numbers.Integral`` is accepted, including NumPy integer scalars such
    as ``np.int64`` (which a plain ``isinstance(value, int)`` check rejects);
    ``bool`` is refused even though it subclasses ``int``.

    Args:
        value: candidate parameter value
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is not a positive integer
    """
    from numbers import Integral

    # bool is an Integral subclass; True must not silently pass as 1
    if isinstance(value, bool) or not isinstance(value, Integral) or value <= 0:
        raise ValueError(f"{name} must be a positive integer.")
39
+
40
+
41
def validate_nonnegative(value, name: str) -> None:
    """Ensure ``value`` is a finite real number that is >= 0.

    Args:
        value: candidate parameter value
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is non-real, a bool, NaN, infinite, or negative
    """
    if isinstance(value, Real) and not isinstance(value, bool):
        as_float = float(value)
        if math.isfinite(as_float) and as_float >= 0:
            return
    raise ValueError(f"{name} must be a finite non-negative number.")
48
+
49
+
50
def validate_probability(value, name: str) -> None:
    """Ensure ``value`` is a finite real number inside the closed interval [0, 1].

    Args:
        value: candidate probability
        name: parameter name used in the error message

    Raises:
        ValueError: if ``value`` is non-real, a bool, NaN, infinite, or outside [0, 1]
    """
    if isinstance(value, Real) and not isinstance(value, bool):
        as_float = float(value)
        if math.isfinite(as_float) and 0.0 <= as_float <= 1.0:
            return
    raise ValueError(f"{name} must be a finite number in [0, 1].")
57
+
58
+
59
def require_solution_pool(solpool: Optional[T]) -> T:
    """Return ``solpool`` unchanged, failing loudly if it was never initialized.

    Raises:
        RuntimeError: if ``solpool`` is ``None``
    """
    if solpool is not None:
        return solpool
    raise RuntimeError(
        "Solution pool is unavailable; provide an optDataset when pool-based solving is enabled."
    )