pyepo 2.2.3__tar.gz → 2.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. {pyepo-2.2.3 → pyepo-2.2.5}/PKG-INFO +3 -3
  2. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/dataset.py +26 -2
  3. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/cave.py +2 -1
  4. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/bases.py +0 -24
  5. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/grbmodel.py +2 -1
  6. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/tsp.py +0 -3
  7. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/vrp.py +0 -3
  8. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/opt.py +75 -13
  9. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/compile.py +0 -1
  10. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/ortcpmodel.py +0 -1
  11. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/ortmodel.py +0 -1
  12. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/utils.py +1 -1
  13. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/PKG-INFO +3 -3
  14. {pyepo-2.2.3 → pyepo-2.2.5}/pyproject.toml +3 -2
  15. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_10_utils.py +132 -2
  16. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_40_dataset.py +49 -0
  17. {pyepo-2.2.3 → pyepo-2.2.5}/LICENSE +0 -0
  18. {pyepo-2.2.3 → pyepo-2.2.5}/README.md +0 -0
  19. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/EPO.py +0 -0
  20. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/__init__.py +0 -0
  21. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/__init__.py +0 -0
  22. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/_validation.py +0 -0
  23. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/knapsack.py +0 -0
  24. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/portfolio.py +0 -0
  25. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/shortestpath.py +0 -0
  26. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/data/tsp.py +0 -0
  27. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/__init__.py +0 -0
  28. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/compiled.py +0 -0
  29. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/expression.py +0 -0
  30. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/objective.py +0 -0
  31. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/dsl/problem.py +0 -0
  32. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/__init__.py +0 -0
  33. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/_common.py +0 -0
  34. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/abcmodule.py +0 -0
  35. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/blackbox.py +0 -0
  36. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/contrastive.py +0 -0
  37. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/__init__.py +0 -0
  38. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/abcmodule.py +0 -0
  39. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/blackbox.py +0 -0
  40. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/cave.py +0 -0
  41. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/contrastive.py +0 -0
  42. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/perturbed.py +0 -0
  43. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/rank.py +0 -0
  44. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/regularized.py +0 -0
  45. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/surrogate.py +0 -0
  46. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/jax/utils.py +0 -0
  47. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/perturbed.py +0 -0
  48. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/rank.py +0 -0
  49. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/regularized.py +0 -0
  50. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/runtime.py +0 -0
  51. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/surrogate.py +0 -0
  52. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/func/utils.py +0 -0
  53. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/__init__.py +0 -0
  54. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/_common.py +0 -0
  55. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/metrics.py +0 -0
  56. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/mse.py +0 -0
  57. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/regret.py +0 -0
  58. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/metric/unambregret.py +0 -0
  59. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/__init__.py +0 -0
  60. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/_common.py +0 -0
  61. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/__init__.py +0 -0
  62. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/compile.py +0 -0
  63. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/coptmodel.py +0 -0
  64. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/knapsack.py +0 -0
  65. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/portfolio.py +0 -0
  66. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/shortestpath.py +0 -0
  67. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/tsp.py +0 -0
  68. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/copt/vrp.py +0 -0
  69. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/__init__.py +0 -0
  70. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/compile.py +0 -0
  71. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/knapsack.py +0 -0
  72. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/portfolio.py +0 -0
  73. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/grb/shortestpath.py +0 -0
  74. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/__init__.py +0 -0
  75. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/compile.py +0 -0
  76. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/knapsack.py +0 -0
  77. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/mpaxmodel.py +0 -0
  78. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/mpax/shortestpath.py +0 -0
  79. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/__init__.py +0 -0
  80. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/compile.py +0 -0
  81. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/knapsack.py +0 -0
  82. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/omomodel.py +0 -0
  83. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/portfolio.py +0 -0
  84. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/shortestpath.py +0 -0
  85. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/tsp.py +0 -0
  86. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/omo/vrp.py +0 -0
  87. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/__init__.py +0 -0
  88. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/knapsack.py +0 -0
  89. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/ort/shortestpath.py +0 -0
  90. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/predefined.py +0 -0
  91. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/model/utils.py +0 -0
  92. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/py.typed +0 -0
  93. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/twostage/__init__.py +0 -0
  94. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/twostage/autosklearnpred.py +0 -0
  95. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo/twostage/sklearnpred.py +0 -0
  96. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/SOURCES.txt +0 -0
  97. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/dependency_links.txt +0 -0
  98. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/requires.txt +0 -0
  99. {pyepo-2.2.3 → pyepo-2.2.5}/pyepo.egg-info/top_level.txt +0 -0
  100. {pyepo-2.2.3 → pyepo-2.2.5}/setup.cfg +0 -0
  101. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_00_constants.py +0 -0
  102. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_15_dsl.py +0 -0
  103. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_20_data_gen.py +0 -0
  104. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_30_model.py +0 -0
  105. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_50_func.py +0 -0
  106. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_55_jax.py +0 -0
  107. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_60_metric.py +0 -0
  108. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_61_metric_validation.py +0 -0
  109. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_70_twostage.py +0 -0
  110. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_80_integration.py +0 -0
  111. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_85_backend_pipeline.py +0 -0
  112. {pyepo-2.2.3 → pyepo-2.2.5}/tests/test_90_cuda.py +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.3
4
- Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
3
+ Version: 2.2.5
4
+ Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
7
7
  Project-URL: Homepage, https://github.com/khalil-research/PyEPO
@@ -9,7 +9,7 @@ Project-URL: Documentation, https://khalil-research.github.io/PyEPO
9
9
  Project-URL: Repository, https://github.com/khalil-research/PyEPO
10
10
  Project-URL: Issues, https://github.com/khalil-research/PyEPO/issues
11
11
  Project-URL: Paper, https://link.springer.com/article/10.1007/s12532-024-00255-x
12
- Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,linear programming,integer programming
12
+ Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,jax,linear programming,integer programming
13
13
  Classifier: Programming Language :: Python :: 3
14
14
  Classifier: Programming Language :: Python :: 3.9
15
15
  Classifier: Programming Language :: Python :: 3.10
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, cast
11
11
  import numpy as np
12
12
  import torch
13
13
  from scipy.spatial import distance
14
- from torch.utils.data import Dataset
14
+ from torch.utils.data import DataLoader, Dataset
15
15
  from tqdm import tqdm
16
16
 
17
17
  from pyepo import EPO
@@ -315,7 +315,8 @@ class optDatasetConstrs(optDataset):
315
315
  currently requires a Gurobi-backed ``optModel``.
316
316
 
317
317
  Per-instance row counts differ (different constraints bind at different
318
- vertices), so batches must be assembled with ``collate_tight_constraints``.
318
+ vertices), so batch with ``optDataLoader`` or pass
319
+ ``collate_tight_constraints`` to a PyTorch ``DataLoader``.
319
320
 
320
321
  Reference: Tang & Khalil (2024)
321
322
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
@@ -455,6 +456,29 @@ def collate_tight_constraints(batch):
455
456
  )
456
457
 
457
458
 
459
+ # optDatasetConstrs yields ragged binding-constraint matrices; pad them when batching
460
+ optDatasetConstrs.collate_fn = staticmethod(collate_tight_constraints)
461
+
462
+
463
+ class optDataLoader(DataLoader):
464
+ """
465
+ ``DataLoader`` that applies a dataset's own ``collate_fn`` when present.
466
+
467
+ Datasets with ragged samples (e.g. ``optDatasetConstrs``, whose
468
+ binding-constraint matrices vary in row count) carry a ``collate_fn``; this
469
+ loader uses it so the caller never passes one explicitly. Plain datasets
470
+ fall back to the default PyTorch collation.
471
+ """
472
+
473
+ def __init__(self, dataset, *args, **kwargs):
474
+ # use the dataset's own collate_fn unless the caller supplies one
475
+ if len(args) <= 5 and "collate_fn" not in kwargs:
476
+ collate_fn = getattr(dataset, "collate_fn", None)
477
+ if collate_fn is not None:
478
+ kwargs["collate_fn"] = collate_fn
479
+ super().__init__(dataset, *args, **kwargs)
480
+
481
+
458
482
  def _extract_tight_normals(
459
483
  model: optModel,
460
484
  sol: np.ndarray,
@@ -59,7 +59,8 @@ class coneAlignedCosine(optModule):
59
59
  cutting the per-epoch cost without measurable regret loss.
60
60
 
61
61
  Training data must come from ``pyepo.data.dataset.optDatasetConstrs``
62
- (Gurobi-backed) and be collated with ``collate_tight_constraints``.
62
+ (Gurobi-backed) and be batched with ``pyepo.data.dataset.optDataLoader``
63
+ or a ``DataLoader`` using ``collate_tight_constraints``.
63
64
 
64
65
  Reference: Tang & Khalil (2024)
65
66
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
@@ -16,7 +16,6 @@ from __future__ import annotations
16
16
 
17
17
  import math
18
18
  from collections import defaultdict
19
- from copy import deepcopy
20
19
  from itertools import combinations
21
20
  from numbers import Integral, Real
22
21
  from typing import TYPE_CHECKING
@@ -105,9 +104,6 @@ class shortestPathBase(optModel):
105
104
  self.arcs = _get_grid_arcs(self.grid)
106
105
  super().__init__(*args, **kwargs)
107
106
 
108
- def get_config(self) -> dict:
109
- return {**super().get_config(), "grid": self.grid}
110
-
111
107
  @property
112
108
  def num_cost(self) -> int:
113
109
  return len(self.arcs)
@@ -208,14 +204,6 @@ class portfolioBase(optModel):
208
204
  raise ValueError("gamma must be greater than or equal to zero.")
209
205
  super().__init__(*args, **kwargs)
210
206
 
211
- def get_config(self) -> dict:
212
- return {
213
- **super().get_config(),
214
- "num_assets": self.num_assets,
215
- "covariance": self.covariance.copy(),
216
- "gamma": self.gamma,
217
- }
218
-
219
207
  @property
220
208
  def num_cost(self) -> int:
221
209
  return self.num_assets
@@ -263,9 +251,6 @@ class tspABBase(optModel):
263
251
  self._extra_constrs: list = []
264
252
  super().__init__(*args, **kwargs)
265
253
 
266
- def get_config(self) -> dict:
267
- return {**super().get_config(), "num_nodes": self.num_nodes}
268
-
269
254
  @property
270
255
  def num_cost(self) -> int:
271
256
  # use edges; backend's self.x has 2*num_edges directed Vars
@@ -381,15 +366,6 @@ class vrpABBase(optModel):
381
366
  self._extra_constrs: list = []
382
367
  super().__init__(*args, **kwargs)
383
368
 
384
- def get_config(self) -> dict:
385
- return {
386
- **super().get_config(),
387
- "num_nodes": self.num_nodes,
388
- "demands": deepcopy(self.demands),
389
- "capacity": self.capacity,
390
- "num_vehicle": self.num_vehicle,
391
- }
392
-
393
369
  @property
394
370
  def num_cost(self) -> int:
395
371
  # one predicted cost per undirected edge
@@ -171,7 +171,8 @@ class optGrbModel(optModel):
171
171
  else:
172
172
  # LinExpr(coeffs, vars) builds the affine expression in one C call
173
173
  vars_list = new_model._vars_list
174
- assert vars_list is not None
174
+ if vars_list is None:
175
+ raise RuntimeError("Gurobi variable list is unavailable.")
175
176
  expr = gp.LinExpr(coefs_np.tolist(), vars_list) <= rhs
176
177
  new_model._model.addConstr(expr)
177
178
  # track for replay on relax
@@ -178,9 +178,6 @@ class tspDFJModel(tspABModel):
178
178
  self._recycled_keys: set = set()
179
179
  super().__init__(num_nodes, *args, **kwargs)
180
180
 
181
- def get_config(self) -> dict:
182
- return {**super().get_config(), "recycle_cuts": self.recycle_cuts}
183
-
184
181
  def _getModel(self) -> tuple:
185
182
  """
186
183
  A method to build Gurobi model
@@ -70,9 +70,6 @@ class vrpRCIModel(vrpABModel):
70
70
  self._recycled_keys: set = set()
71
71
  super().__init__(num_nodes, demands, capacity, num_vehicle)
72
72
 
73
- def get_config(self) -> dict:
74
- return {**super().get_config(), "recycle_cuts": self.recycle_cuts}
75
-
76
73
  def _getModel(self) -> tuple:
77
74
  """
78
75
  A method to build Gurobi model
@@ -5,6 +5,8 @@ Abstract optimization model
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
+ import functools
9
+ import inspect
8
10
  from abc import ABC, abstractmethod
9
11
  from copy import deepcopy
10
12
  from dataclasses import dataclass
@@ -25,20 +27,63 @@ class ModelSpec:
25
27
  """Serializable recipe for building a fresh optimization model."""
26
28
 
27
29
  model_type: type[optModel]
30
+ _args: tuple
28
31
  _config: dict
29
32
 
30
- def __init__(self, model_type: type[optModel], config: dict) -> None:
33
+ def __init__(
34
+ self,
35
+ model_type: type[optModel],
36
+ config: dict,
37
+ args: tuple = (),
38
+ ) -> None:
31
39
  object.__setattr__(self, "model_type", model_type)
40
+ object.__setattr__(self, "_args", deepcopy(args))
32
41
  object.__setattr__(self, "_config", deepcopy(config))
33
42
 
43
+ @property
44
+ def args(self) -> tuple:
45
+ """Return an independent copy of positional constructor arguments."""
46
+ return deepcopy(self._args)
47
+
34
48
  @property
35
49
  def config(self) -> dict:
36
- """Return an independent copy of the constructor configuration."""
50
+ """Return an independent copy of keyword constructor arguments."""
37
51
  return deepcopy(self._config)
38
52
 
39
53
  def build(self) -> optModel:
40
54
  """Build a fresh model without sharing mutable configuration values."""
41
- return self.model_type.from_config(self._config)
55
+ return self.model_type.from_config(self._config, self._args)
56
+
57
+
58
+ def _snapshot(value):
59
+ """Deep-copy a constructor argument, keeping the reference if it cannot be copied."""
60
+ try:
61
+ return deepcopy(value)
62
+ except Exception: # noqa: BLE001 -- any copy failure falls back to a reference
63
+ return value
64
+
65
+
66
+ def _capture_init_config(init, args, kwargs) -> tuple[tuple, dict]:
67
+ """Flatten a constructor call into arguments that rebuild the model."""
68
+ sig = inspect.signature(init)
69
+ bound = sig.bind(None, *args, **kwargs)
70
+ init_args = []
71
+ config = {}
72
+ for i, (name, value) in enumerate(bound.arguments.items()):
73
+ # the first bound argument is self
74
+ if i == 0:
75
+ continue
76
+ kind = sig.parameters[name].kind
77
+ # snapshot each argument; values that cannot be deep-copied keep a reference
78
+ if kind is inspect.Parameter.VAR_KEYWORD:
79
+ config.update({k: _snapshot(v) for k, v in value.items()})
80
+ elif kind is inspect.Parameter.VAR_POSITIONAL:
81
+ init_args.extend(_snapshot(v) for v in value)
82
+ elif kind is inspect.Parameter.POSITIONAL_ONLY:
83
+ init_args.append(_snapshot(value))
84
+ else:
85
+ config[name] = _snapshot(value)
86
+ return tuple(init_args), config
42
87
 
43
88
 
44
89
  class optModel(ABC):
@@ -53,11 +98,6 @@ class optModel(ABC):
53
98
  and MPAX (``optMpaxModel``); subclass ``optModel`` directly to integrate
54
99
  any other solver or algorithm.
55
100
 
56
- Models that take constructor arguments should override ``get_config`` and
57
- cooperatively merge ``super().get_config()``. The resulting configuration
58
- powers ``rebuild()``, multiprocessing workers, and sklearn scorers without
59
- inspecting constructor signatures or runtime solver state.
60
-
61
101
  The default objective sense is minimization; set
62
102
  ``self.modelSense = EPO.MAXIMIZE`` in ``_getModel`` or ``__init__`` for
63
103
  maximization problems (some backends, e.g. Gurobi and COPT, detect this
@@ -73,6 +113,24 @@ class optModel(ABC):
73
113
  arcs: list
74
114
  _cost_vars: list
75
115
 
116
+ def __init_subclass__(cls, **kwargs) -> None:
117
+ super().__init_subclass__(**kwargs)
118
+ # Only wrap subclasses that define their own __init__ and use the
119
+ # default reconstruction config. Custom get_config implementations may
120
+ # intentionally accept objects that are not deepcopyable/rebuildable.
121
+ if "__init__" not in cls.__dict__ or "get_config" in cls.__dict__:
122
+ return
123
+ user_init = cls.__init__
124
+
125
+ @functools.wraps(user_init)
126
+ def _init_capturing(self, *args, **kwargs):
127
+ # record only the outermost call; nested super().__init__ leaves it intact
128
+ if "_init_config" not in self.__dict__:
129
+ self._init_args, self._init_config = _capture_init_config(user_init, args, kwargs)
130
+ user_init(self, *args, **kwargs)
131
+
132
+ cls.__init__ = _init_capturing
133
+
76
134
  def __init__(self) -> None:
77
135
  # Cache for models whose solver variables do not map one-to-one to
78
136
  # predicted costs (for example directed TSP/VRP formulations).
@@ -86,17 +144,21 @@ class optModel(ABC):
86
144
  return "optModel " + self.__class__.__name__
87
145
 
88
146
  def get_config(self) -> dict:
89
- """Return the explicit constructor configuration for this model."""
90
- return {}
147
+ """Return the constructor configuration for this model."""
148
+ return deepcopy(self.__dict__.get("_init_config", {}))
91
149
 
92
150
  @classmethod
93
- def from_config(cls, config: dict) -> Self:
151
+ def from_config(cls, config: dict, args: tuple = ()) -> Self:
94
152
  """Build a model from a configuration produced by ``get_config``."""
95
- return cls(**deepcopy(config))
153
+ return cls(*deepcopy(args), **deepcopy(config))
96
154
 
97
155
  def to_spec(self) -> ModelSpec:
98
156
  """Return a serializable, immutable-snapshot rebuild recipe."""
99
- return ModelSpec(type(self), self.get_config())
157
+ return ModelSpec(
158
+ type(self),
159
+ self.get_config(),
160
+ self.__dict__.get("_init_args", ()),
161
+ )
100
162
 
101
163
  def rebuild(self) -> Self:
102
164
  """Build a structurally equivalent model with clean runtime state."""
@@ -40,7 +40,6 @@ class compiledOrtProblem(compiledBase, optOrtModel):
40
40
  self.problem = deepcopy(problem)
41
41
  self.params = dict(params) if params else {}
42
42
  self.solver = solver
43
- self._extra_constrs = [] # (coef, rhs) cuts replayed on copy
44
43
  optModel.__init__(self) # builds the model via _getModel
45
44
  self._model.SuppressOutput()
46
45
  self._set_obj_sense()
@@ -53,7 +53,6 @@ class optOrtCpModel(optModel):
53
53
  raise ImportError(
54
54
  "OR-Tools is not installed. Please install ortools to use this feature."
55
55
  )
56
- self._extra_constrs = []
57
56
  self._objective_coefs: list[int] | None = None
58
57
  super().__init__()
59
58
  # cache ordered Var list for batched weighted_sum / per-Var Value() loop
@@ -54,7 +54,6 @@ class optOrtModel(optModel):
54
54
  "OR-Tools is not installed. Please install ortools to use this feature."
55
55
  )
56
56
  self.solver = solver
57
- self._extra_constrs = []
58
57
  super().__init__()
59
58
  # suppress output
60
59
  self._model.SuppressOutput()
@@ -21,7 +21,7 @@ _EPS: float = 1e-8
21
21
 
22
22
  def getArgs(model: optModel) -> dict:
23
23
  """
24
- Compatibility wrapper for the explicit model configuration protocol.
24
+ Compatibility wrapper for model reconstruction configuration.
25
25
 
26
26
  Args:
27
27
  model: optimization model
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.3
4
- Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
3
+ Version: 2.2.5
4
+ Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
7
7
  Project-URL: Homepage, https://github.com/khalil-research/PyEPO
@@ -9,7 +9,7 @@ Project-URL: Documentation, https://khalil-research.github.io/PyEPO
9
9
  Project-URL: Repository, https://github.com/khalil-research/PyEPO
10
10
  Project-URL: Issues, https://github.com/khalil-research/PyEPO/issues
11
11
  Project-URL: Paper, https://link.springer.com/article/10.1007/s12532-024-00255-x
12
- Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,linear programming,integer programming
12
+ Keywords: predict-then-optimize,end-to-end,decision-focused learning,optimization,deep learning,pytorch,jax,linear programming,integer programming
13
13
  Classifier: Programming Language :: Python :: 3
14
14
  Classifier: Programming Language :: Python :: 3.9
15
15
  Classifier: Programming Language :: Python :: 3.10
@@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "pyepo"
7
- version = "2.2.3"
8
- description = "PyTorch-based End-to-End Predict-then-Optimize Tool"
7
+ version = "2.2.5"
8
+ description = "PyTorch/JAX-based End-to-End Predict-then-Optimize Tool"
9
9
  readme = { file = "README.md", content-type = "text/markdown" }
10
10
  license = { text = "MIT" }
11
11
  authors = [{ name = "Bo Tang", email = "bolucas.tang@mail.utoronto.ca" }]
@@ -17,6 +17,7 @@ keywords = [
17
17
  "optimization",
18
18
  "deep learning",
19
19
  "pytorch",
20
+ "jax",
20
21
  "linear programming",
21
22
  "integer programming",
22
23
  ]
@@ -19,7 +19,7 @@ from .conftest import requires_gurobi
19
19
 
20
20
 
21
21
  class ConfigModel(optModel):
22
- """Solver-free model for the explicit reconstruction protocol."""
22
+ """Solver-free model with custom reconstruction config."""
23
23
 
24
24
  def __init__(self, values, label="default"):
25
25
  self.values = np.asarray(values)
@@ -43,6 +43,88 @@ class ConfigModel(optModel):
43
43
  return np.zeros(len(self.values)), 0.0
44
44
 
45
45
 
46
+ class AutoConfigModel(optModel):
47
+ """Solver-free model that relies on optModel's automatic config capture."""
48
+
49
+ def __init__(self, values, **kwargs):
50
+ self.values = values
51
+ self.kwargs = kwargs
52
+ super().__init__()
53
+
54
+ def _getModel(self):
55
+ return None, list(range(len(self.values)))
56
+
57
+ def setObj(self, c):
58
+ self.cost = np.asarray(c)
59
+
60
+ def solve(self):
61
+ return np.zeros(len(self.values)), 0.0
62
+
63
+
64
+ class VarArgsConfigModel(optModel):
65
+ """Solver-free model whose constructor needs positional replay."""
66
+
67
+ def __init__(self, *values, label="default"):
68
+ self.values = values
69
+ self.label = label
70
+ super().__init__()
71
+
72
+ def _getModel(self):
73
+ return None, list(range(len(self.values)))
74
+
75
+ def setObj(self, c):
76
+ self.cost = np.asarray(c)
77
+
78
+ def solve(self):
79
+ return np.zeros(len(self.values)), 0.0
80
+
81
+
82
+ class PosOnlyConfigModel(optModel):
83
+ """Solver-free model with a positional-only constructor argument."""
84
+
85
+ def __init__(self, values, /, label="default"):
86
+ self.values = values
87
+ self.label = label
88
+ super().__init__()
89
+
90
+ def _getModel(self):
91
+ return None, list(range(len(self.values)))
92
+
93
+ def setObj(self, c):
94
+ self.cost = np.asarray(c)
95
+
96
+ def solve(self):
97
+ return np.zeros(len(self.values)), 0.0
98
+
99
+
100
+ class _NoDeepcopy:
101
+ def __init__(self, values):
102
+ self.values = values
103
+
104
+ def __deepcopy__(self, memo):
105
+ raise TypeError("not deepcopyable")
106
+
107
+
108
+ class CustomConfigModel(optModel):
109
+ """Custom config should control reconstruction for unusual constructor inputs."""
110
+
111
+ def __init__(self, resource=None, values=None):
112
+ self.values = list(values if values is not None else resource.values)
113
+ super().__init__()
114
+
115
+ def get_config(self):
116
+ return {"values": self.values.copy()}
117
+
118
+ def _getModel(self):
119
+ return None, list(range(len(self.values)))
120
+
121
+ def setObj(self, c):
122
+ self.cost = np.asarray(c)
123
+
124
+ def solve(self):
125
+ return np.zeros(len(self.values)), 0.0
126
+
127
+
46
128
  # ============================================================
47
129
  # unionFind (pure)
48
130
  # ============================================================
@@ -161,7 +243,7 @@ class TestCostToNumpy:
161
243
 
162
244
 
163
245
  # ============================================================
164
- # explicit model reconstruction protocol (pure)
246
+ # model reconstruction config (pure)
165
247
  # ============================================================
166
248
 
167
249
 
@@ -222,6 +304,54 @@ class TestModelSpec:
222
304
  np.testing.assert_array_equal(sol, [0.0, 0.0, 0.0])
223
305
  assert obj == 0.0
224
306
 
307
+ def test_auto_config_snapshots_constructor_inputs(self):
308
+ values = [1, 2, 3]
309
+ nested = {"tag": ["x"]}
310
+ model = AutoConfigModel(values, nested=nested)
311
+ values[0] = 99
312
+ nested["tag"][0] = "changed"
313
+
314
+ rebuilt = model.rebuild()
315
+
316
+ assert rebuilt.values == [1, 2, 3]
317
+ assert rebuilt.kwargs == {"nested": {"tag": ["x"]}}
318
+
319
+ def test_auto_config_export_is_independent(self):
320
+ model = AutoConfigModel([1, 2, 3], nested={"tag": ["x"]})
321
+ config = model.get_config()
322
+ config["values"][0] = 99
323
+ config["nested"]["tag"][0] = "changed"
324
+
325
+ assert model.get_config()["values"] == [1, 2, 3]
326
+ assert model.get_config()["nested"] == {"tag": ["x"]}
327
+ assert model.rebuild().kwargs == {"nested": {"tag": ["x"]}}
328
+
329
+ def test_auto_config_replays_varargs(self):
330
+ model = VarArgsConfigModel(1, 2, 3, label="x")
331
+ spec = model.to_spec()
332
+
333
+ assert spec.args == (1, 2, 3)
334
+ assert spec.config == {"label": "x"}
335
+ rebuilt = model.rebuild()
336
+ assert rebuilt.values == (1, 2, 3)
337
+ assert rebuilt.label == "x"
338
+
339
+ def test_auto_config_replays_positional_only_args(self):
340
+ model = PosOnlyConfigModel([1, 2, 3], label="x")
341
+ spec = model.to_spec()
342
+
343
+ assert spec.args == ([1, 2, 3],)
344
+ assert spec.config == {"label": "x"}
345
+ rebuilt = model.rebuild()
346
+ assert rebuilt.values == [1, 2, 3]
347
+ assert rebuilt.label == "x"
348
+
349
+ def test_custom_get_config_accepts_uncopyable_constructor_input(self):
350
+ model = CustomConfigModel(_NoDeepcopy([1, 2, 3]))
351
+ rebuilt = model.rebuild()
352
+
353
+ assert rebuilt.values == [1, 2, 3]
354
+
225
355
 
226
356
  # ============================================================
227
357
  # getArgs (needs a real optModel)
@@ -14,6 +14,7 @@ import torch
14
14
  from pyepo.data import shortestpath
15
15
  from pyepo.data.dataset import (
16
16
  collate_tight_constraints,
17
+ optDataLoader,
17
18
  optDataset,
18
19
  optDatasetConstrs,
19
20
  optDatasetKNN,
@@ -299,3 +300,51 @@ class TestCollateTightConstraints:
299
300
  assert padded.shape == (2, 5, 4)
300
301
  # the shorter matrix is zero-padded at the tail
301
302
  assert torch.allclose(padded[0, 2:], torch.zeros(3, 4))
303
+
304
+
305
+ class TestOptDataLoader:
306
+ """Pure: optDataLoader applies a dataset's collate_fn automatically."""
307
+
308
+ def test_constrs_collate_fn_wired(self):
309
+ assert optDatasetConstrs.collate_fn is collate_tight_constraints
310
+
311
+ def test_uses_dataset_collate_fn(self):
312
+ class _Ragged(list):
313
+ collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
314
+
315
+ loader = optDataLoader(_Ragged([0, 1, 2, 3]), batch_size=2)
316
+ assert next(iter(loader)) == ("auto", 2)
317
+
318
+ def test_accepts_positional_dataloader_args(self):
319
+ loader = optDataLoader([0, 1, 2, 3], 2)
320
+ assert torch.equal(next(iter(loader)), torch.tensor([0, 1]))
321
+
322
+ def test_explicit_collate_fn_wins(self):
323
+ class _Ragged(list):
324
+ collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
325
+
326
+ loader = optDataLoader(
327
+ _Ragged([0, 1, 2, 3]), batch_size=2, collate_fn=lambda b: ("explicit", len(b))
328
+ )
329
+ assert next(iter(loader)) == ("explicit", 2)
330
+
331
+ def test_positional_collate_fn_wins(self):
332
+ class _Ragged(list):
333
+ collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
334
+
335
+ loader = optDataLoader(
336
+ _Ragged([0, 1, 2, 3]),
337
+ 2,
338
+ None,
339
+ None,
340
+ None,
341
+ 0,
342
+ lambda b: ("positional", len(b)),
343
+ )
344
+ assert next(iter(loader)) == ("positional", 2)
345
+
346
+ def test_plain_dataset_uses_default_collate(self):
347
+ data = [(torch.tensor([float(i)]),) for i in range(4)]
348
+ loader = optDataLoader(data, batch_size=2)
349
+ (batch,) = next(iter(loader))
350
+ assert batch.shape == (2, 1)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes