pyepo-2.2.4.tar.gz → pyepo-2.2.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. {pyepo-2.2.4 → pyepo-2.2.5}/PKG-INFO +1 -1
  2. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/dataset.py +26 -2
  3. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/cave.py +2 -1
  4. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/opt.py +46 -18
  5. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/compile.py +0 -1
  6. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/ortcpmodel.py +0 -1
  7. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/ortmodel.py +0 -1
  8. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/utils.py +1 -1
  9. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/PKG-INFO +1 -1
  10. {pyepo-2.2.4 → pyepo-2.2.5}/pyproject.toml +1 -1
  11. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_10_utils.py +92 -2
  12. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_40_dataset.py +49 -0
  13. {pyepo-2.2.4 → pyepo-2.2.5}/LICENSE +0 -0
  14. {pyepo-2.2.4 → pyepo-2.2.5}/README.md +0 -0
  15. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/EPO.py +0 -0
  16. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/__init__.py +0 -0
  17. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/__init__.py +0 -0
  18. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/_validation.py +0 -0
  19. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/knapsack.py +0 -0
  20. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/portfolio.py +0 -0
  21. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/shortestpath.py +0 -0
  22. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/data/tsp.py +0 -0
  23. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/__init__.py +0 -0
  24. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/compiled.py +0 -0
  25. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/expression.py +0 -0
  26. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/objective.py +0 -0
  27. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/dsl/problem.py +0 -0
  28. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/__init__.py +0 -0
  29. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/_common.py +0 -0
  30. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/abcmodule.py +0 -0
  31. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/blackbox.py +0 -0
  32. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/contrastive.py +0 -0
  33. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/__init__.py +0 -0
  34. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/abcmodule.py +0 -0
  35. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/blackbox.py +0 -0
  36. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/cave.py +0 -0
  37. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/contrastive.py +0 -0
  38. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/perturbed.py +0 -0
  39. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/rank.py +0 -0
  40. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/regularized.py +0 -0
  41. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/surrogate.py +0 -0
  42. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/jax/utils.py +0 -0
  43. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/perturbed.py +0 -0
  44. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/rank.py +0 -0
  45. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/regularized.py +0 -0
  46. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/runtime.py +0 -0
  47. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/surrogate.py +0 -0
  48. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/func/utils.py +0 -0
  49. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/__init__.py +0 -0
  50. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/_common.py +0 -0
  51. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/metrics.py +0 -0
  52. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/mse.py +0 -0
  53. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/regret.py +0 -0
  54. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/metric/unambregret.py +0 -0
  55. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/__init__.py +0 -0
  56. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/_common.py +0 -0
  57. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/bases.py +0 -0
  58. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/__init__.py +0 -0
  59. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/compile.py +0 -0
  60. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/coptmodel.py +0 -0
  61. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/knapsack.py +0 -0
  62. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/portfolio.py +0 -0
  63. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/shortestpath.py +0 -0
  64. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/tsp.py +0 -0
  65. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/copt/vrp.py +0 -0
  66. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/__init__.py +0 -0
  67. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/compile.py +0 -0
  68. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/grbmodel.py +0 -0
  69. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/knapsack.py +0 -0
  70. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/portfolio.py +0 -0
  71. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/shortestpath.py +0 -0
  72. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/tsp.py +0 -0
  73. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/grb/vrp.py +0 -0
  74. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/__init__.py +0 -0
  75. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/compile.py +0 -0
  76. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/knapsack.py +0 -0
  77. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/mpaxmodel.py +0 -0
  78. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/mpax/shortestpath.py +0 -0
  79. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/__init__.py +0 -0
  80. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/compile.py +0 -0
  81. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/knapsack.py +0 -0
  82. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/omomodel.py +0 -0
  83. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/portfolio.py +0 -0
  84. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/shortestpath.py +0 -0
  85. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/tsp.py +0 -0
  86. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/omo/vrp.py +0 -0
  87. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/__init__.py +0 -0
  88. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/knapsack.py +0 -0
  89. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/ort/shortestpath.py +0 -0
  90. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/predefined.py +0 -0
  91. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/model/utils.py +0 -0
  92. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/py.typed +0 -0
  93. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/twostage/__init__.py +0 -0
  94. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/twostage/autosklearnpred.py +0 -0
  95. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo/twostage/sklearnpred.py +0 -0
  96. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/SOURCES.txt +0 -0
  97. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/dependency_links.txt +0 -0
  98. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/requires.txt +0 -0
  99. {pyepo-2.2.4 → pyepo-2.2.5}/pyepo.egg-info/top_level.txt +0 -0
  100. {pyepo-2.2.4 → pyepo-2.2.5}/setup.cfg +0 -0
  101. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_00_constants.py +0 -0
  102. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_15_dsl.py +0 -0
  103. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_20_data_gen.py +0 -0
  104. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_30_model.py +0 -0
  105. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_50_func.py +0 -0
  106. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_55_jax.py +0 -0
  107. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_60_metric.py +0 -0
  108. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_61_metric_validation.py +0 -0
  109. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_70_twostage.py +0 -0
  110. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_80_integration.py +0 -0
  111. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_85_backend_pipeline.py +0 -0
  112. {pyepo-2.2.4 → pyepo-2.2.5}/tests/test_90_cuda.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.4
3
+ Version: 2.2.5
4
4
  Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, cast
11
11
  import numpy as np
12
12
  import torch
13
13
  from scipy.spatial import distance
14
- from torch.utils.data import Dataset
14
+ from torch.utils.data import DataLoader, Dataset
15
15
  from tqdm import tqdm
16
16
 
17
17
  from pyepo import EPO
@@ -315,7 +315,8 @@ class optDatasetConstrs(optDataset):
315
315
  currently requires a Gurobi-backed ``optModel``.
316
316
 
317
317
  Per-instance row counts differ (different constraints bind at different
318
- vertices), so batches must be assembled with ``collate_tight_constraints``.
318
+ vertices), so batch with ``optDataLoader`` or pass
319
+ ``collate_tight_constraints`` to a PyTorch ``DataLoader``.
319
320
 
320
321
  Reference: Tang & Khalil (2024)
321
322
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
@@ -455,6 +456,29 @@ def collate_tight_constraints(batch):
455
456
  )
456
457
 
457
458
 
459
+ # optDatasetConstrs yields ragged binding-constraint matrices; pad them when batching
460
+ optDatasetConstrs.collate_fn = staticmethod(collate_tight_constraints)
461
+
462
+
463
+ class optDataLoader(DataLoader):
464
+ """
465
+ ``DataLoader`` that applies a dataset's own ``collate_fn`` when present.
466
+
467
+ Datasets with ragged samples (e.g. ``optDatasetConstrs``, whose
468
+ binding-constraint matrices vary in row count) carry a ``collate_fn``; this
469
+ loader uses it so the caller never passes one explicitly. Plain datasets
470
+ fall back to the default PyTorch collation.
471
+ """
472
+
473
+ def __init__(self, dataset, *args, **kwargs):
474
+ # use the dataset's own collate_fn unless the caller supplies one
475
+ if len(args) <= 5 and "collate_fn" not in kwargs:
476
+ collate_fn = getattr(dataset, "collate_fn", None)
477
+ if collate_fn is not None:
478
+ kwargs["collate_fn"] = collate_fn
479
+ super().__init__(dataset, *args, **kwargs)
480
+
481
+
458
482
  def _extract_tight_normals(
459
483
  model: optModel,
460
484
  sol: np.ndarray,
@@ -59,7 +59,8 @@ class coneAlignedCosine(optModule):
59
59
  cutting the per-epoch cost without measurable regret loss.
60
60
 
61
61
  Training data must come from ``pyepo.data.dataset.optDatasetConstrs``
62
- (Gurobi-backed) and be collated with ``collate_tight_constraints``.
62
+ (Gurobi-backed) and be batched with ``pyepo.data.dataset.optDataLoader``
63
+ or a ``DataLoader`` using ``collate_tight_constraints``.
63
64
 
64
65
  Reference: Tang & Khalil (2024)
65
66
  `<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
@@ -27,41 +27,63 @@ class ModelSpec:
27
27
  """Serializable recipe for building a fresh optimization model."""
28
28
 
29
29
  model_type: type[optModel]
30
+ _args: tuple
30
31
  _config: dict
31
32
 
32
- def __init__(self, model_type: type[optModel], config: dict) -> None:
33
+ def __init__(
34
+ self,
35
+ model_type: type[optModel],
36
+ config: dict,
37
+ args: tuple = (),
38
+ ) -> None:
33
39
  object.__setattr__(self, "model_type", model_type)
40
+ object.__setattr__(self, "_args", deepcopy(args))
34
41
  object.__setattr__(self, "_config", deepcopy(config))
35
42
 
43
+ @property
44
+ def args(self) -> tuple:
45
+ """Return an independent copy of positional constructor arguments."""
46
+ return deepcopy(self._args)
47
+
36
48
  @property
37
49
  def config(self) -> dict:
38
- """Return an independent copy of the constructor configuration."""
50
+ """Return an independent copy of keyword constructor arguments."""
39
51
  return deepcopy(self._config)
40
52
 
41
53
  def build(self) -> optModel:
42
54
  """Build a fresh model without sharing mutable configuration values."""
43
- return self.model_type.from_config(self._config)
55
+ return self.model_type.from_config(self._config, self._args)
56
+
44
57
 
58
+ def _snapshot(value):
59
+ """Deep-copy a constructor argument, keeping the reference if it cannot be copied."""
60
+ try:
61
+ return deepcopy(value)
62
+ except Exception: # noqa: BLE001 -- any copy failure falls back to a reference
63
+ return value
45
64
 
46
- def _capture_init_config(init, args, kwargs) -> dict:
47
- """Flatten a constructor call into keyword arguments that rebuild the model."""
65
+
66
+ def _capture_init_config(init, args, kwargs) -> tuple[tuple, dict]:
67
+ """Flatten a constructor call into arguments that rebuild the model."""
48
68
  sig = inspect.signature(init)
49
69
  bound = sig.bind(None, *args, **kwargs)
70
+ init_args = []
50
71
  config = {}
51
72
  for i, (name, value) in enumerate(bound.arguments.items()):
52
73
  # the first bound argument is self
53
74
  if i == 0:
54
75
  continue
55
76
  kind = sig.parameters[name].kind
56
- # **kwargs: merge captured keywords in directly
77
+ # snapshot each argument; values that cannot be deep-copied keep a reference
57
78
  if kind is inspect.Parameter.VAR_KEYWORD:
58
- config.update(deepcopy(value))
59
- # nameless positionals cannot replay by keyword; override get_config to keep them
60
- elif kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY):
61
- continue
79
+ config.update({k: _snapshot(v) for k, v in value.items()})
80
+ elif kind is inspect.Parameter.VAR_POSITIONAL:
81
+ init_args.extend(_snapshot(v) for v in value)
82
+ elif kind is inspect.Parameter.POSITIONAL_ONLY:
83
+ init_args.append(_snapshot(value))
62
84
  else:
63
- config[name] = deepcopy(value)
64
- return config
85
+ config[name] = _snapshot(value)
86
+ return tuple(init_args), config
65
87
 
66
88
 
67
89
  class optModel(ABC):
@@ -93,8 +115,10 @@ class optModel(ABC):
93
115
 
94
116
  def __init_subclass__(cls, **kwargs) -> None:
95
117
  super().__init_subclass__(**kwargs)
96
- # only wrap a subclass that defines its own __init__
97
- if "__init__" not in cls.__dict__:
118
+ # Only wrap subclasses that define their own __init__ and use the
119
+ # default reconstruction config. Custom get_config implementations may
120
+ # intentionally accept objects that are not deepcopyable/rebuildable.
121
+ if "__init__" not in cls.__dict__ or "get_config" in cls.__dict__:
98
122
  return
99
123
  user_init = cls.__init__
100
124
 
@@ -102,7 +126,7 @@ class optModel(ABC):
102
126
  def _init_capturing(self, *args, **kwargs):
103
127
  # record only the outermost call; nested super().__init__ leaves it intact
104
128
  if "_init_config" not in self.__dict__:
105
- self._init_config = _capture_init_config(user_init, args, kwargs)
129
+ self._init_args, self._init_config = _capture_init_config(user_init, args, kwargs)
106
130
  user_init(self, *args, **kwargs)
107
131
 
108
132
  cls.__init__ = _init_capturing
@@ -124,13 +148,17 @@ class optModel(ABC):
124
148
  return deepcopy(self.__dict__.get("_init_config", {}))
125
149
 
126
150
  @classmethod
127
- def from_config(cls, config: dict) -> Self:
151
+ def from_config(cls, config: dict, args: tuple = ()) -> Self:
128
152
  """Build a model from a configuration produced by ``get_config``."""
129
- return cls(**deepcopy(config))
153
+ return cls(*deepcopy(args), **deepcopy(config))
130
154
 
131
155
  def to_spec(self) -> ModelSpec:
132
156
  """Return a serializable, immutable-snapshot rebuild recipe."""
133
- return ModelSpec(type(self), self.get_config())
157
+ return ModelSpec(
158
+ type(self),
159
+ self.get_config(),
160
+ self.__dict__.get("_init_args", ()),
161
+ )
134
162
 
135
163
  def rebuild(self) -> Self:
136
164
  """Build a structurally equivalent model with clean runtime state."""
@@ -40,7 +40,6 @@ class compiledOrtProblem(compiledBase, optOrtModel):
40
40
  self.problem = deepcopy(problem)
41
41
  self.params = dict(params) if params else {}
42
42
  self.solver = solver
43
- self._extra_constrs = [] # (coef, rhs) cuts replayed on copy
44
43
  optModel.__init__(self) # builds the model via _getModel
45
44
  self._model.SuppressOutput()
46
45
  self._set_obj_sense()
@@ -53,7 +53,6 @@ class optOrtCpModel(optModel):
53
53
  raise ImportError(
54
54
  "OR-Tools is not installed. Please install ortools to use this feature."
55
55
  )
56
- self._extra_constrs = []
57
56
  self._objective_coefs: list[int] | None = None
58
57
  super().__init__()
59
58
  # cache ordered Var list for batched weighted_sum / per-Var Value() loop
@@ -54,7 +54,6 @@ class optOrtModel(optModel):
54
54
  "OR-Tools is not installed. Please install ortools to use this feature."
55
55
  )
56
56
  self.solver = solver
57
- self._extra_constrs = []
58
57
  super().__init__()
59
58
  # suppress output
60
59
  self._model.SuppressOutput()
@@ -21,7 +21,7 @@ _EPS: float = 1e-8
21
21
 
22
22
  def getArgs(model: optModel) -> dict:
23
23
  """
24
- Compatibility wrapper for the explicit model configuration protocol.
24
+ Compatibility wrapper for model reconstruction configuration.
25
25
 
26
26
  Args:
27
27
  model: optimization model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.4
3
+ Version: 2.2.5
4
4
  Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "pyepo"
7
- version = "2.2.4"
7
+ version = "2.2.5"
8
8
  description = "PyTorch/JAX-based End-to-End Predict-then-Optimize Tool"
9
9
  readme = { file = "README.md", content-type = "text/markdown" }
10
10
  license = { text = "MIT" }
@@ -19,7 +19,7 @@ from .conftest import requires_gurobi
19
19
 
20
20
 
21
21
  class ConfigModel(optModel):
22
- """Solver-free model for the explicit reconstruction protocol."""
22
+ """Solver-free model with custom reconstruction config."""
23
23
 
24
24
  def __init__(self, values, label="default"):
25
25
  self.values = np.asarray(values)
@@ -61,6 +61,70 @@ class AutoConfigModel(optModel):
61
61
  return np.zeros(len(self.values)), 0.0
62
62
 
63
63
 
64
+ class VarArgsConfigModel(optModel):
65
+ """Solver-free model whose constructor needs positional replay."""
66
+
67
+ def __init__(self, *values, label="default"):
68
+ self.values = values
69
+ self.label = label
70
+ super().__init__()
71
+
72
+ def _getModel(self):
73
+ return None, list(range(len(self.values)))
74
+
75
+ def setObj(self, c):
76
+ self.cost = np.asarray(c)
77
+
78
+ def solve(self):
79
+ return np.zeros(len(self.values)), 0.0
80
+
81
+
82
+ class PosOnlyConfigModel(optModel):
83
+ """Solver-free model with a positional-only constructor argument."""
84
+
85
+ def __init__(self, values, /, label="default"):
86
+ self.values = values
87
+ self.label = label
88
+ super().__init__()
89
+
90
+ def _getModel(self):
91
+ return None, list(range(len(self.values)))
92
+
93
+ def setObj(self, c):
94
+ self.cost = np.asarray(c)
95
+
96
+ def solve(self):
97
+ return np.zeros(len(self.values)), 0.0
98
+
99
+
100
+ class _NoDeepcopy:
101
+ def __init__(self, values):
102
+ self.values = values
103
+
104
+ def __deepcopy__(self, memo):
105
+ raise TypeError("not deepcopyable")
106
+
107
+
108
+ class CustomConfigModel(optModel):
109
+ """Custom config should control reconstruction for unusual constructor inputs."""
110
+
111
+ def __init__(self, resource=None, values=None):
112
+ self.values = list(values if values is not None else resource.values)
113
+ super().__init__()
114
+
115
+ def get_config(self):
116
+ return {"values": self.values.copy()}
117
+
118
+ def _getModel(self):
119
+ return None, list(range(len(self.values)))
120
+
121
+ def setObj(self, c):
122
+ self.cost = np.asarray(c)
123
+
124
+ def solve(self):
125
+ return np.zeros(len(self.values)), 0.0
126
+
127
+
64
128
  # ============================================================
65
129
  # unionFind (pure)
66
130
  # ============================================================
@@ -179,7 +243,7 @@ class TestCostToNumpy:
179
243
 
180
244
 
181
245
  # ============================================================
182
- # explicit model reconstruction protocol (pure)
246
+ # model reconstruction config (pure)
183
247
  # ============================================================
184
248
 
185
249
 
@@ -262,6 +326,32 @@ class TestModelSpec:
262
326
  assert model.get_config()["nested"] == {"tag": ["x"]}
263
327
  assert model.rebuild().kwargs == {"nested": {"tag": ["x"]}}
264
328
 
329
+ def test_auto_config_replays_varargs(self):
330
+ model = VarArgsConfigModel(1, 2, 3, label="x")
331
+ spec = model.to_spec()
332
+
333
+ assert spec.args == (1, 2, 3)
334
+ assert spec.config == {"label": "x"}
335
+ rebuilt = model.rebuild()
336
+ assert rebuilt.values == (1, 2, 3)
337
+ assert rebuilt.label == "x"
338
+
339
+ def test_auto_config_replays_positional_only_args(self):
340
+ model = PosOnlyConfigModel([1, 2, 3], label="x")
341
+ spec = model.to_spec()
342
+
343
+ assert spec.args == ([1, 2, 3],)
344
+ assert spec.config == {"label": "x"}
345
+ rebuilt = model.rebuild()
346
+ assert rebuilt.values == [1, 2, 3]
347
+ assert rebuilt.label == "x"
348
+
349
+ def test_custom_get_config_accepts_uncopyable_constructor_input(self):
350
+ model = CustomConfigModel(_NoDeepcopy([1, 2, 3]))
351
+ rebuilt = model.rebuild()
352
+
353
+ assert rebuilt.values == [1, 2, 3]
354
+
265
355
 
266
356
  # ============================================================
267
357
  # getArgs (needs a real optModel)
@@ -14,6 +14,7 @@ import torch
14
14
  from pyepo.data import shortestpath
15
15
  from pyepo.data.dataset import (
16
16
  collate_tight_constraints,
17
+ optDataLoader,
17
18
  optDataset,
18
19
  optDatasetConstrs,
19
20
  optDatasetKNN,
@@ -299,3 +300,51 @@ class TestCollateTightConstraints:
299
300
  assert padded.shape == (2, 5, 4)
300
301
  # the shorter matrix is zero-padded at the tail
301
302
  assert torch.allclose(padded[0, 2:], torch.zeros(3, 4))
303
+
304
+
305
+ class TestOptDataLoader:
306
+ """Pure: optDataLoader applies a dataset's collate_fn automatically."""
307
+
308
+ def test_constrs_collate_fn_wired(self):
309
+ assert optDatasetConstrs.collate_fn is collate_tight_constraints
310
+
311
+ def test_uses_dataset_collate_fn(self):
312
+ class _Ragged(list):
313
+ collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
314
+
315
+ loader = optDataLoader(_Ragged([0, 1, 2, 3]), batch_size=2)
316
+ assert next(iter(loader)) == ("auto", 2)
317
+
318
+ def test_accepts_positional_dataloader_args(self):
319
+ loader = optDataLoader([0, 1, 2, 3], 2)
320
+ assert torch.equal(next(iter(loader)), torch.tensor([0, 1]))
321
+
322
+ def test_explicit_collate_fn_wins(self):
323
+ class _Ragged(list):
324
+ collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
325
+
326
+ loader = optDataLoader(
327
+ _Ragged([0, 1, 2, 3]), batch_size=2, collate_fn=lambda b: ("explicit", len(b))
328
+ )
329
+ assert next(iter(loader)) == ("explicit", 2)
330
+
331
+ def test_positional_collate_fn_wins(self):
332
+ class _Ragged(list):
333
+ collate_fn = staticmethod(lambda batch: ("auto", len(batch)))
334
+
335
+ loader = optDataLoader(
336
+ _Ragged([0, 1, 2, 3]),
337
+ 2,
338
+ None,
339
+ None,
340
+ None,
341
+ 0,
342
+ lambda b: ("positional", len(b)),
343
+ )
344
+ assert next(iter(loader)) == ("positional", 2)
345
+
346
+ def test_plain_dataset_uses_default_collate(self):
347
+ data = [(torch.tensor([float(i)]),) for i in range(4)]
348
+ loader = optDataLoader(data, batch_size=2)
349
+ (batch,) = next(iter(loader))
350
+ assert batch.shape == (2, 1)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes