pyepo 2.2.4__tar.gz → 2.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. {pyepo-2.2.4 → pyepo-2.2.6}/PKG-INFO +3 -2
  2. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/_validation.py +1 -2
  3. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/dataset.py +37 -7
  4. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/cave.py +5 -1
  5. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/cave.py +3 -0
  6. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/regularized.py +2 -3
  7. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/surrogate.py +6 -1
  8. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/utils.py +9 -0
  9. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/perturbed.py +6 -74
  10. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/regularized.py +2 -3
  11. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/runtime.py +18 -5
  12. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/surrogate.py +6 -1
  13. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/utils.py +27 -11
  14. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/_common.py +0 -6
  15. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/mse.py +7 -6
  16. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/regret.py +4 -4
  17. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/unambregret.py +5 -5
  18. pyepo-2.2.6/pyepo/model/_mvar_compile.py +83 -0
  19. pyepo-2.2.6/pyepo/model/copt/compile.py +66 -0
  20. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/coptmodel.py +7 -1
  21. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/tsp.py +2 -3
  22. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/vrp.py +11 -2
  23. pyepo-2.2.6/pyepo/model/grb/compile.py +62 -0
  24. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/vrp.py +4 -0
  25. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/compile.py +17 -8
  26. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/knapsack.py +6 -0
  27. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/vrp.py +6 -0
  28. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/opt.py +66 -21
  29. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/compile.py +0 -2
  30. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/ortcpmodel.py +0 -1
  31. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/ortmodel.py +4 -1
  32. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/utils.py +5 -4
  33. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/PKG-INFO +3 -2
  34. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/SOURCES.txt +1 -0
  35. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/requires.txt +1 -1
  36. {pyepo-2.2.4 → pyepo-2.2.6}/pyproject.toml +3 -5
  37. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_00_constants.py +2 -13
  38. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_10_utils.py +122 -50
  39. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_15_dsl.py +75 -90
  40. pyepo-2.2.6/tests/test_20_data_gen.py +125 -0
  41. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_30_model.py +82 -103
  42. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_40_dataset.py +59 -8
  43. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_50_func.py +190 -201
  44. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_55_jax.py +316 -383
  45. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_60_metric.py +43 -77
  46. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_61_metric_validation.py +27 -7
  47. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_70_twostage.py +12 -29
  48. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_80_integration.py +15 -9
  49. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_85_backend_pipeline.py +8 -22
  50. {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_90_cuda.py +9 -27
  51. pyepo-2.2.4/pyepo/model/copt/compile.py +0 -107
  52. pyepo-2.2.4/pyepo/model/grb/compile.py +0 -103
  53. pyepo-2.2.4/tests/test_20_data_gen.py +0 -193
  54. {pyepo-2.2.4 → pyepo-2.2.6}/LICENSE +0 -0
  55. {pyepo-2.2.4 → pyepo-2.2.6}/README.md +0 -0
  56. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/EPO.py +0 -0
  57. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/__init__.py +0 -0
  58. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/__init__.py +0 -0
  59. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/knapsack.py +0 -0
  60. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/portfolio.py +0 -0
  61. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/shortestpath.py +0 -0
  62. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/tsp.py +0 -0
  63. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/__init__.py +0 -0
  64. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/compiled.py +0 -0
  65. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/expression.py +0 -0
  66. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/objective.py +0 -0
  67. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/problem.py +0 -0
  68. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/__init__.py +0 -0
  69. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/_common.py +0 -0
  70. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/abcmodule.py +0 -0
  71. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/blackbox.py +0 -0
  72. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/contrastive.py +0 -0
  73. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/__init__.py +0 -0
  74. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/abcmodule.py +0 -0
  75. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/blackbox.py +0 -0
  76. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/contrastive.py +0 -0
  77. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/perturbed.py +0 -0
  78. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/rank.py +0 -0
  79. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/rank.py +0 -0
  80. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/__init__.py +0 -0
  81. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/metrics.py +0 -0
  82. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/__init__.py +0 -0
  83. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/_common.py +0 -0
  84. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/bases.py +0 -0
  85. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/__init__.py +0 -0
  86. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/knapsack.py +0 -0
  87. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/portfolio.py +0 -0
  88. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/shortestpath.py +0 -0
  89. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/__init__.py +0 -0
  90. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/grbmodel.py +0 -0
  91. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/knapsack.py +0 -0
  92. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/portfolio.py +0 -0
  93. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/shortestpath.py +0 -0
  94. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/tsp.py +0 -0
  95. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/__init__.py +0 -0
  96. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/mpaxmodel.py +0 -0
  97. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/shortestpath.py +0 -0
  98. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/__init__.py +0 -0
  99. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/compile.py +0 -0
  100. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/knapsack.py +0 -0
  101. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/omomodel.py +0 -0
  102. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/portfolio.py +0 -0
  103. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/shortestpath.py +0 -0
  104. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/tsp.py +0 -0
  105. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/__init__.py +0 -0
  106. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/knapsack.py +0 -0
  107. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/shortestpath.py +0 -0
  108. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/predefined.py +0 -0
  109. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/utils.py +0 -0
  110. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/py.typed +0 -0
  111. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/twostage/__init__.py +0 -0
  112. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/twostage/autosklearnpred.py +0 -0
  113. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/twostage/sklearnpred.py +0 -0
  114. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/dependency_links.txt +0 -0
  115. {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/top_level.txt +0 -0
  116. {pyepo-2.2.4 → pyepo-2.2.6}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.4
3
+ Version: 2.2.6
4
4
  Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3.10
16
16
  Classifier: Programming Language :: Python :: 3.11
17
17
  Classifier: Programming Language :: Python :: 3.12
18
18
  Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Programming Language :: Python :: 3.14
19
20
  Classifier: Operating System :: OS Independent
20
21
  Classifier: License :: OSI Approved :: MIT License
21
22
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -29,7 +30,7 @@ Requires-Dist: numpy
29
30
  Requires-Dist: scipy
30
31
  Requires-Dist: pathos
31
32
  Requires-Dist: tqdm
32
- Requires-Dist: scikit_learn
33
+ Requires-Dist: scikit-learn
33
34
  Requires-Dist: torch>=1.13.1
34
35
  Provides-Extra: pyomo
35
36
  Requires-Dist: pyomo>=6.1.2; extra == "pyomo"
@@ -6,8 +6,7 @@ from numbers import Real
6
6
 
7
7
  def validate_degree(deg: int) -> None:
8
8
  """Validate a positive integer polynomial degree."""
9
- if not isinstance(deg, int) or isinstance(deg, bool) or deg <= 0:
10
- raise ValueError(f"deg = {deg} should be a positive integer.")
9
+ validate_positive_int(deg, "deg")
11
10
 
12
11
 
13
12
  def validate_nonnegative(value: float, name: str) -> None:
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, cast
11
11
  import numpy as np
12
12
  import torch
13
13
  from scipy.spatial import distance
14
- from torch.utils.data import Dataset
14
+ from torch.utils.data import DataLoader, Dataset
15
15
  from tqdm import tqdm
16
16
 
17
17
  from pyepo import EPO
@@ -33,14 +33,20 @@ def _validate_inputs(
33
33
  model: optModel,
34
34
  feats: np.ndarray | torch.Tensor,
35
35
  costs: np.ndarray | torch.Tensor,
36
- ) -> None:
37
- """Validate the common constructor contract for optimization datasets."""
36
+ ) -> tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]:
37
+ """Validate and normalize the common constructor contract for optimization datasets."""
38
38
  if not isinstance(model, optModel):
39
39
  raise TypeError("arg model is not an optModel")
40
+ # array-likes become numpy arrays
41
+ if not isinstance(feats, (np.ndarray, torch.Tensor)):
42
+ feats = np.asarray(feats)
43
+ if not isinstance(costs, (np.ndarray, torch.Tensor)):
44
+ costs = np.asarray(costs)
40
45
  if len(feats) != len(costs):
41
46
  raise ValueError(
42
47
  f"feats and costs must have the same number of instances: {len(feats)} vs {len(costs)}."
43
48
  )
49
+ return feats, costs
44
50
 
45
51
 
46
52
  def _as_float_tensor(data) -> torch.Tensor:
@@ -91,7 +97,7 @@ class optDataset(Dataset):
91
97
  feats: data features
92
98
  costs: costs of objective function
93
99
  """
94
- _validate_inputs(model, feats, costs)
100
+ feats, costs = _validate_inputs(model, feats, costs)
95
101
  self.model = model
96
102
  # data
97
103
  self.feats = feats
@@ -229,7 +235,7 @@ class optDatasetKNN(optDataset):
229
235
  k: number of nearest neighbours selected
230
236
  weight: self-weight in the kNN convex combination (1.0 = no smoothing)
231
237
  """
232
- _validate_inputs(model, feats, costs)
238
+ feats, costs = _validate_inputs(model, feats, costs)
233
239
  self.model = model
234
240
  # at most num_data-1 neighbours exist (self excluded), so k must stay below it
235
241
  num_data = len(feats)
@@ -315,7 +321,8 @@ class optDatasetConstrs(optDataset):
315
321
  currently requires a Gurobi-backed ``optModel``.
316
322
 
317
323
  Per-instance row counts differ (different constraints bind at different
318
- vertices), so batches must be assembled with ``collate_tight_constraints``.
324
+ vertices), so batch with ``optDataLoader`` or pass
325
+ ``collate_tight_constraints`` to a PyTorch ``DataLoader``.
319
326
 
320
327
  Reference: Tang & Khalil (2024)
321
328
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
@@ -346,7 +353,7 @@ class optDatasetConstrs(optDataset):
346
353
  costs: costs of objective function
347
354
  skip_infeas: if True, drop infeasible instances instead of raising
348
355
  """
349
- _validate_inputs(model, feats, costs)
356
+ feats, costs = _validate_inputs(model, feats, costs)
350
357
  self.model = model
351
358
  self.skip_infeas = skip_infeas
352
359
  # data
@@ -455,6 +462,29 @@ def collate_tight_constraints(batch):
455
462
  )
456
463
 
457
464
 
465
+ # optDatasetConstrs yields ragged binding-constraint matrices; pad them when batching
466
+ optDatasetConstrs.collate_fn = staticmethod(collate_tight_constraints)
467
+
468
+
469
+ class optDataLoader(DataLoader):
470
+ """
471
+ ``DataLoader`` that applies a dataset's own ``collate_fn`` when present.
472
+
473
+ Datasets with ragged samples (e.g. ``optDatasetConstrs``, whose
474
+ binding-constraint matrices vary in row count) carry a ``collate_fn``; this
475
+ loader uses it so the caller never passes one explicitly. Plain datasets
476
+ fall back to the default PyTorch collation.
477
+ """
478
+
479
+ def __init__(self, dataset, *args, **kwargs):
480
+ # use the dataset's own collate_fn unless the caller supplies one
481
+ if len(args) <= 5 and "collate_fn" not in kwargs:
482
+ collate_fn = getattr(dataset, "collate_fn", None)
483
+ if collate_fn is not None:
484
+ kwargs["collate_fn"] = collate_fn
485
+ super().__init__(dataset, *args, **kwargs)
486
+
487
+
458
488
  def _extract_tight_normals(
459
489
  model: optModel,
460
490
  sol: np.ndarray,
@@ -59,12 +59,16 @@ class coneAlignedCosine(optModule):
59
59
  cutting the per-epoch cost without measurable regret loss.
60
60
 
61
61
  Training data must come from ``pyepo.data.dataset.optDatasetConstrs``
62
- (Gurobi-backed) and be collated with ``collate_tight_constraints``.
62
+ (Gurobi-backed) and be batched with ``pyepo.data.dataset.optDataLoader``
63
+ or a ``DataLoader`` using ``collate_tight_constraints``.
63
64
 
64
65
  Reference: Tang & Khalil (2024)
65
66
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
66
67
  """
67
68
 
69
+ # Clarabel-only workers
70
+ _pool_needs_optmodel = False
71
+
68
72
  def __init__(
69
73
  self,
70
74
  optmodel: optModel,
@@ -34,6 +34,9 @@ class coneAlignedCosine(optModule):
34
34
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
35
35
  """
36
36
 
37
+ # Clarabel-only workers
38
+ _pool_needs_optmodel = False
39
+
37
40
  def __init__(
38
41
  self,
39
42
  optmodel,
@@ -139,9 +139,8 @@ def _away_step_frank_wolfe(theta, module, use_cache=False):
139
139
  w = w.at[bidx, match_idx].add(gamma_fw * (has_match & use_fw).astype(theta.dtype))
140
140
  vt = vt.at[bidx, free_idx].set(jnp.where(add_new[:, None], v, vt[bidx, free_idx]))
141
141
  vn = vn.at[bidx, free_idx].set(jnp.where(add_new, vnv, vn[bidx, free_idx]))
142
- w = w.at[bidx, free_idx].set(
143
- jnp.where(add_new, gamma_fw + w[bidx, free_idx], w[bidx, free_idx])
144
- )
142
+ # displaced mass is dropped
143
+ w = w.at[bidx, free_idx].set(jnp.where(add_new, gamma_fw, w[bidx, free_idx]))
145
144
  # away subtract, then clear FP residue so dropped atoms leave the active set
146
145
  w = w.at[bidx, away_idx].add(-gamma_away)
147
146
  w = jnp.where(w < 1e-12, 0.0, jnp.maximum(w, 0.0))
@@ -14,7 +14,7 @@ import jax.numpy as jnp
14
14
  from pyepo.func._common import is_minimize, validate_positive
15
15
  from pyepo.func.jax.abcmodule import optModule
16
16
  from pyepo.func.jax.utils import _full_cost, _solve_or_cache
17
- from pyepo.utils import _EPS
17
+ from pyepo.utils import _EPS, objective_offset
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from pyepo.func.runtime import Reduction
@@ -68,6 +68,11 @@ def _spoplus_value_and_grad(pred_cost, true_cost, true_sol, true_obj, module):
68
68
  # solve the perturbed problem
69
69
  sol, obj = _solve_or_cache(2.0 * pred_cost - true_cost, module)
70
70
  z = jnp.squeeze(true_obj, axis=-1) if true_obj.ndim > 1 else true_obj
71
+ # drop the bare objective constant
72
+ offset = objective_offset(module.optmodel)
73
+ if offset:
74
+ obj = obj - offset
75
+ z = z - offset
71
76
  inner = 2.0 * jnp.einsum("bi,bi->b", pred_cost, true_sol)
72
77
  # loss and subgradient
73
78
  if is_minimize(module.optmodel.modelSense):
@@ -12,6 +12,7 @@ import numpy as np
12
12
  from pyepo.func._common import is_minimize, solution_pool_tolerance
13
13
  from pyepo.func.utils import _solve_batch_np
14
14
  from pyepo.model.mpax import optMpaxModel
15
+ from pyepo.utils import objective_offset
15
16
 
16
17
  try:
17
18
  from jax.extend.core import concrete_or_error as _concrete_or_error
@@ -83,6 +84,10 @@ def _solve_batch_mpax(cost, optmodel):
83
84
  # obj in true sense
84
85
  if not minimize:
85
86
  obj = -obj
87
+ # add the bare objective constant
88
+ offset = objective_offset(optmodel)
89
+ if offset:
90
+ obj = obj + offset
86
91
  return sol, obj
87
92
 
88
93
 
@@ -127,6 +132,10 @@ def _cache_in_pass(cost, optmodel, solpool):
127
132
  select = jnp.argmin if is_minimize(optmodel.modelSense) else jnp.argmax
128
133
  ind = select(solpool_obj, axis=1)
129
134
  obj = jnp.take_along_axis(solpool_obj, ind[:, None], axis=1).squeeze(1)
135
+ # add the bare objective constant
136
+ offset = objective_offset(optmodel)
137
+ if offset:
138
+ obj = obj + offset
130
139
  return solpool[ind], obj
131
140
 
132
141
 
@@ -13,20 +13,16 @@ from torch.autograd import Function
13
13
 
14
14
  from pyepo.func._common import (
15
15
  is_minimize,
16
- require_solution_pool,
17
16
  validate_positive,
18
17
  validate_positive_int,
19
18
  )
20
19
  from pyepo.func.abcmodule import optModule
21
20
  from pyepo.func.utils import (
22
21
  _mask_pred,
22
+ _solve_or_cache,
23
23
  _torch_generator,
24
- _update_solution_pool,
25
24
  sumGammaDistribution,
26
25
  )
27
- from pyepo.func.utils import (
28
- _solve_batch as _solve_batch_2d,
29
- )
30
26
  from pyepo.utils import _EPS
31
27
 
32
28
  if TYPE_CHECKING:
@@ -625,76 +621,12 @@ class adaptiveImplicitMLEFunc(implicitMLEFunc):
625
621
 
626
622
  def _solve_or_cache_3d(ptb_c: torch.Tensor, module: optModule) -> torch.Tensor:
627
623
  """
628
- Solve or use cached solutions for perturbed costs (3D: n_samples × batch × vars).
629
- Delegates to the shared 2D functions in utils after flattening.
630
- """
631
- optmodel = module.optmodel
632
- processes = module.processes
633
- pool = module.pool
634
- solpool = module.solpool
635
- if module._branch_rng.uniform() <= module.solve_ratio:
636
- ptb_sols, solpool = _solve_in_pass_3d(ptb_c, optmodel, processes, pool, solpool)
637
- else:
638
- solpool = require_solution_pool(solpool)
639
- ptb_sols, solpool = _cache_in_pass_3d(ptb_c, optmodel, solpool)
640
- module.solpool = solpool
641
- return ptb_sols
642
-
643
-
644
- def _solve_in_pass_3d(
645
- ptb_c: torch.Tensor,
646
- optmodel: optModel,
647
- processes: int,
648
- pool,
649
- solpool: torch.Tensor | None = None,
650
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
651
- """
652
- Solve optimization for perturbed 3D costs and update solution pool.
653
-
654
- Args:
655
- ptb_c: perturbed costs, shape (batch, n_samples, vars)
656
- optmodel: optimization model
657
- processes: number of processors
658
- pool: process pool
659
- solpool: solution pool
660
-
661
- Returns:
662
- tuple: (solutions shape (batch, n_samples, vars), updated solpool)
663
- """
664
- ins_num, n_samples, num_vars = ptb_c.shape
665
- # flatten (batch, n_samples, vars) → (batch * n_samples, vars)
666
- flat_c = ptb_c.reshape(-1, num_vars)
667
- # solve using shared 2D function
668
- flat_sols, _ = _solve_batch_2d(flat_c, optmodel, processes, pool)
669
- # update pool while flat_sols is still contiguous
670
- if solpool is not None:
671
- solpool = _update_solution_pool(flat_sols, solpool)
672
- if solpool.device != ptb_c.device:
673
- solpool = solpool.to(ptb_c.device)
674
- # reshape (batch * n_samples, vars) → (batch, n_samples, vars)
675
- ptb_sols = flat_sols.reshape(ins_num, n_samples, num_vars)
676
- return ptb_sols, solpool
677
-
678
-
679
- def _cache_in_pass_3d(
680
- ptb_c: torch.Tensor,
681
- optmodel: optModel,
682
- solpool: torch.Tensor,
683
- ) -> tuple[torch.Tensor, torch.Tensor]:
684
- """
685
- Use solution pool for perturbed 3D costs (batch × n_samples × vars).
686
- Unlike the 2D version in utils, this handles the extra sample dimension.
624
+ Solve or use cached solutions for perturbed costs (batch × n_samples × vars).
625
+ Flattens the sample axis and delegates to the shared 2D path in utils.
687
626
  """
688
- # move solpool to the correct device
689
- if solpool.device != ptb_c.device:
690
- solpool = solpool.to(ptb_c.device)
691
- # compute objective values: (batch, n_samples, pool_size)
692
- solpool_obj = torch.einsum("bnd,sd->bns", ptb_c, solpool)
693
- # best solution in pool
694
- select = torch.argmin if is_minimize(optmodel.modelSense) else torch.argmax
695
- best_inds = select(solpool_obj, dim=2)
696
- ptb_sols = solpool[best_inds]
697
- return ptb_sols, solpool
627
+ batch, n_samples, num_vars = ptb_c.shape
628
+ sol, _ = _solve_or_cache(ptb_c.reshape(-1, num_vars), module)
629
+ return sol.reshape(batch, n_samples, num_vars)
698
630
 
699
631
 
700
632
  # acronym aliases
@@ -120,9 +120,8 @@ def _away_step_frank_wolfe(
120
120
  vertex_norms[batch_idx, free_idx] = torch.where(
121
121
  add_new, v_norm_sq, vertex_norms[batch_idx, free_idx]
122
122
  )
123
- weights[batch_idx, free_idx] = torch.where(
124
- add_new, gamma_fw + weights[batch_idx, free_idx], weights[batch_idx, free_idx]
125
- )
123
+ # displaced mass is dropped
124
+ weights[batch_idx, free_idx] = torch.where(add_new, gamma_fw, weights[batch_idx, free_idx])
126
125
  # away subtract, then clear FP residue so dropped atoms leave the active set
127
126
  weights[batch_idx, away_idx] = weights[batch_idx, away_idx] - gamma_away
128
127
  weights = weights.clamp(min=0.0)
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import itertools
5
6
  import multiprocessing as mp
6
7
  import weakref
7
8
  from dataclasses import dataclass
@@ -63,20 +64,27 @@ def normalize_processes(
63
64
  return cpu_count if processes == 0 else processes
64
65
 
65
66
 
67
+ # unique pool ids
68
+ _pool_ids = itertools.count()
69
+
70
+
66
71
  def create_solver_pool(
67
72
  optmodel: optModel,
68
73
  processes: int,
69
74
  *,
70
75
  owner=None,
76
+ with_solver: bool = True,
71
77
  ) -> ProcessingPool | None:
72
78
  """Create a worker pool, optionally tied to an owner's lifetime."""
73
79
  if processes == 1:
74
80
  return None
75
- pool = ProcessingPool(
76
- processes,
77
- initializer=_init_worker_model,
78
- initargs=(optmodel.to_spec(),),
81
+ # optional per-worker optmodel preload
82
+ init_kwargs = (
83
+ {"initializer": _init_worker_model, "initargs": (optmodel.to_spec(),)}
84
+ if with_solver
85
+ else {}
79
86
  )
87
+ pool = ProcessingPool(processes, id=f"pyepo-{next(_pool_ids)}", **init_kwargs)
80
88
  if owner is not None:
81
89
  weakref.finalize(owner, _close_pool, pool)
82
90
  return pool
@@ -117,7 +125,12 @@ def init_runtime(
117
125
  raise ValueError(f"No reduction '{reduction}'.")
118
126
 
119
127
  normalized_processes = normalize_processes(optmodel, processes, logger)
120
- pool = create_solver_pool(optmodel, normalized_processes, owner=owner)
128
+ pool = create_solver_pool(
129
+ optmodel,
130
+ normalized_processes,
131
+ owner=owner,
132
+ with_solver=getattr(owner, "_pool_needs_optmodel", True),
133
+ )
121
134
  logger.info("Num of cores: %d", normalized_processes)
122
135
  return RuntimeState(
123
136
  optmodel=optmodel,
@@ -13,7 +13,7 @@ from torch.autograd import Function
13
13
  from pyepo.func._common import is_minimize, validate_positive
14
14
  from pyepo.func.abcmodule import optModule
15
15
  from pyepo.func.utils import _solve_or_cache
16
- from pyepo.utils import _EPS
16
+ from pyepo.utils import _EPS, objective_offset
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  from pyepo.data.dataset import optDataset
@@ -115,6 +115,11 @@ class SPOPlusFunc(Function):
115
115
  # _check_sol(c, w, z)
116
116
  # solve
117
117
  sol, obj = _solve_or_cache(2 * cp - c, module)
118
+ # drop the bare objective constant
119
+ offset = objective_offset(module.optmodel)
120
+ if offset:
121
+ obj = obj - offset
122
+ z = z - offset
118
123
  # calculate loss
119
124
  if is_minimize(module.optmodel.modelSense):
120
125
  loss = -obj + 2 * torch.einsum("bi,bi->b", cp, w) - z.squeeze(dim=-1)
@@ -20,7 +20,7 @@ from pyepo.func._common import (
20
20
  validate_positive_int,
21
21
  )
22
22
  from pyepo.model.mpax import optMpaxModel
23
- from pyepo.utils import costToNumpy
23
+ from pyepo.utils import costToNumpy, objective_offset
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from pyepo.func.abcmodule import optModule
@@ -87,8 +87,9 @@ def _solve_batch(
87
87
  """
88
88
  A function to solve optimization in the forward/backward pass
89
89
  """
90
- # get device
90
+ # get device and dtype
91
91
  device = cp.device if isinstance(cp, torch.Tensor) else torch.device("cpu")
92
+ dtype = cp.dtype if isinstance(cp, torch.Tensor) else torch.float32
92
93
  # MPAX batch solving
93
94
  if isinstance(optmodel, optMpaxModel):
94
95
  # get params
@@ -113,11 +114,17 @@ def _solve_batch(
113
114
  # obj sense
114
115
  if not is_minimize(optmodel.modelSense):
115
116
  obj = -obj
117
+ # add the bare objective constant
118
+ offset = objective_offset(optmodel)
119
+ if offset:
120
+ obj = obj + offset
121
+ # match input dtype
122
+ sol, obj = sol.to(dtype), obj.to(dtype)
116
123
  # host solving on numpy costs
117
124
  else:
118
125
  sol_np, obj_np = _solve_batch_np(costToNumpy(cp), optmodel, processes, pool)
119
- sol = torch.as_tensor(sol_np).to(device)
120
- obj = torch.as_tensor(obj_np).to(device)
126
+ sol = torch.as_tensor(sol_np).to(device=device, dtype=dtype)
127
+ obj = torch.as_tensor(obj_np).to(device=device, dtype=dtype)
121
128
  return sol, obj
122
129
 
123
130
 
@@ -130,6 +137,8 @@ def _solve_batch_np(
130
137
  """
131
138
  A function to solve a batch of numpy costs on the host, shared by both frontends
132
139
  """
140
+ # match input precision
141
+ out_dtype = np.result_type(cp.dtype, np.float32)
133
142
  # single-core
134
143
  if processes == 1:
135
144
  sol_list: list = []
@@ -140,13 +149,13 @@ def _solve_batch_np(
140
149
  sol_list.append(solp)
141
150
  obj_list.append(objp)
142
151
  # stack + dtype convert in a single call
143
- sol = np.asarray(sol_list, dtype=np.float32)
144
- obj = np.asarray(obj_list, dtype=np.float32)
152
+ sol = np.asarray(sol_list, dtype=out_dtype)
153
+ obj = np.asarray(obj_list, dtype=out_dtype)
145
154
  # multi-core (workers pre-loaded with optmodel via pool initializer)
146
155
  else:
147
156
  res = pool.amap(_solve_with_obj_in_worker, cp).get()
148
- sol = np.stack([r[0] for r in res]).astype(np.float32)
149
- obj = np.asarray([r[1] for r in res], dtype=np.float32)
157
+ sol = np.stack([r[0] for r in res]).astype(out_dtype)
158
+ obj = np.asarray([r[1] for r in res], dtype=out_dtype)
150
159
  return sol, obj
151
160
 
152
161
 
@@ -163,6 +172,9 @@ def _update_solution_pool(
163
172
  return torch.unique(sol, dim=0).clone()
164
173
  if sol.device != solpool.device:
165
174
  sol = sol.to(solpool.device)
175
+ # match pool dtype
176
+ if solpool.dtype != sol.dtype:
177
+ solpool = solpool.to(sol.dtype)
166
178
  sol_uniq = torch.unique(sol, dim=0)
167
179
  # capped L1-tolerance dedup: first-order solvers re-emit near-identical vertices
168
180
  tol = solution_pool_tolerance(sol.shape[-1])
@@ -181,15 +193,19 @@ def _cache_in_pass(
181
193
  """
182
194
  A function to use solution pool in the forward/backward pass
183
195
  """
184
- # move solpool to the correct device
185
- if solpool.device != cp.device:
186
- solpool = solpool.to(cp.device)
196
+ # move solpool to the correct device and dtype
197
+ if solpool.device != cp.device or solpool.dtype != cp.dtype:
198
+ solpool = solpool.to(device=cp.device, dtype=cp.dtype)
187
199
  # best solution in pool
188
200
  solpool_obj = torch.matmul(cp, solpool.T)
189
201
  select = torch.argmin if is_minimize(optmodel.modelSense) else torch.argmax
190
202
  ind = select(solpool_obj, dim=1)
191
203
  obj = solpool_obj.gather(1, ind.view(-1, 1)).squeeze(1)
192
204
  sol = solpool[ind]
205
+ # add the bare objective constant
206
+ offset = objective_offset(optmodel)
207
+ if offset:
208
+ obj = obj + offset
193
209
  return sol, obj, solpool
194
210
 
195
211
 
@@ -30,12 +30,6 @@ def regret_from_objective(obj, true_obj, model_sense):
30
30
  raise ValueError("Invalid modelSense.")
31
31
 
32
32
 
33
- def objective_offset(optmodel: optModel) -> float:
34
- """Return a compiled DSL problem's bare objective constant, if present."""
35
- problem = getattr(optmodel, "problem", None)
36
- return float(problem.obj_offset) if problem is not None else 0.0
37
-
38
-
39
33
  def require_linear_objective(optmodel: optModel) -> None:
40
34
  """Reject models carrying a quadratic objective term."""
41
35
  problem = getattr(optmodel, "problem", None)
@@ -5,15 +5,13 @@ Mean Squared Error
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from typing import TYPE_CHECKING, cast
8
+ from typing import TYPE_CHECKING
9
9
 
10
10
  import torch
11
11
 
12
12
  from pyepo.metric._common import torch_evaluation, validate_prediction_batch
13
13
 
14
14
  if TYPE_CHECKING:
15
- from collections.abc import Sized
16
-
17
15
  from torch import nn
18
16
  from torch.utils.data import DataLoader
19
17
 
@@ -24,19 +22,22 @@ def MSE(predmodel: nn.Module, dataloader: DataLoader) -> float:
24
22
 
25
23
  Args:
26
24
  predmodel: a regression neural network for cost prediction
27
- dataloader: Torch dataloader from optDataSet
25
+ dataloader: Torch dataloader from optDataSet (fields beyond
26
+ ``(x, c, w, z)`` are ignored)
28
27
 
29
28
  Returns:
30
29
  float: MSE loss
31
30
  """
32
31
  loss = 0
32
+ total = 0
33
33
  with torch_evaluation(predmodel) as device, torch.no_grad():
34
34
  # load data
35
35
  for data in dataloader:
36
- x, c, _, _ = data
36
+ x, c, _, _ = data[:4]
37
37
  x, c = x.to(device), c.to(device)
38
38
  # predict
39
39
  cp = predmodel(x)
40
40
  validate_prediction_batch(cp, c)
41
41
  loss += ((cp - c) ** 2).mean(dim=1).sum().item()
42
- return loss / len(cast("Sized", dataloader.dataset))
42
+ total += x.shape[0]
43
+ return loss / total if total else 0.0
@@ -15,14 +15,13 @@ from pyepo.func.runtime import create_solver_pool, normalize_processes
15
15
  from pyepo.func.utils import _close_pool, _solve_batch
16
16
  from pyepo.metric._common import (
17
17
  normalize_regret,
18
- objective_offset,
19
18
  regret_from_objective,
20
19
  require_linear_objective,
21
20
  torch_evaluation,
22
21
  validate_cost_vectors,
23
22
  validate_prediction_batch,
24
23
  )
25
- from pyepo.utils import costToNumpy
24
+ from pyepo.utils import costToNumpy, objective_offset
26
25
 
27
26
  if TYPE_CHECKING:
28
27
  from collections.abc import Callable
@@ -65,7 +64,8 @@ def regret(
65
64
  JAX callable ``f(x_numpy) -> cost_array``
66
65
  optmodel: a PyEPO optimization model
67
66
  dataloader: PyTorch DataLoader over an ``optDataset`` (yielding
68
- ``(x, c, w, z)`` tuples)
67
+ ``(x, c, w, z, ...)`` tuples; fields beyond the first four,
68
+ e.g. CaVE tight constraints, are ignored)
69
69
  processes: number of processors, 1 for single-core, 0 for all of
70
70
  cores; a fresh worker pool is spawned per call, each worker
71
71
  rebuilding the model from its constructor args
@@ -89,7 +89,7 @@ def regret(
89
89
  with torch_evaluation(torch_model) as device:
90
90
  # load data
91
91
  for data in dataloader:
92
- x, c, _, z = data
92
+ x, c, _, z = data[:4]
93
93
  if torch_model is not None:
94
94
  x, c, z = x.to(device), c.to(device), z.to(device)
95
95
  with torch.no_grad():
@@ -14,7 +14,6 @@ import torch
14
14
  from pyepo import EPO
15
15
  from pyepo.metric._common import (
16
16
  normalize_regret,
17
- objective_offset,
18
17
  regret_from_objective,
19
18
  require_linear_objective,
20
19
  torch_evaluation,
@@ -23,7 +22,7 @@ from pyepo.metric._common import (
23
22
  validate_retry_count,
24
23
  validate_tolerance,
25
24
  )
26
- from pyepo.utils import costToNumpy
25
+ from pyepo.utils import costToNumpy, objective_offset
27
26
 
28
27
  if TYPE_CHECKING:
29
28
  from torch import nn
@@ -57,7 +56,8 @@ def unambRegret(
57
56
  Args:
58
57
  predmodel: a regression neural network for cost prediction
59
58
  optmodel: a PyEPO optimization model
60
- dataloader: PyTorch DataLoader over an ``optDataset``
59
+ dataloader: PyTorch DataLoader over an ``optDataset`` (fields beyond
60
+ ``(x, c, w, z)`` are ignored)
61
61
  tolerance: precision used when rounding predicted costs to find ties
62
62
  max_iter: maximum number of solve retries with relaxed tolerance
63
63
 
@@ -72,8 +72,8 @@ def unambRegret(
72
72
  with torch_evaluation(predmodel) as device:
73
73
  # load data
74
74
  for data in dataloader:
75
- x, c, w, z = data
76
- x, c, w, z = x.to(device), c.to(device), w.to(device), z.to(device)
75
+ x, c, _, z = data[:4]
76
+ x, c, z = x.to(device), c.to(device), z.to(device)
77
77
  # pred cost
78
78
  with torch.no_grad():
79
79
  cp = costToNumpy(predmodel(x))