pyepo 2.2.4__tar.gz → 2.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyepo-2.2.4 → pyepo-2.2.6}/PKG-INFO +3 -2
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/_validation.py +1 -2
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/dataset.py +37 -7
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/cave.py +5 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/cave.py +3 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/regularized.py +2 -3
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/surrogate.py +6 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/utils.py +9 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/perturbed.py +6 -74
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/regularized.py +2 -3
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/runtime.py +18 -5
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/surrogate.py +6 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/utils.py +27 -11
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/_common.py +0 -6
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/mse.py +7 -6
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/regret.py +4 -4
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/unambregret.py +5 -5
- pyepo-2.2.6/pyepo/model/_mvar_compile.py +83 -0
- pyepo-2.2.6/pyepo/model/copt/compile.py +66 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/coptmodel.py +7 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/tsp.py +2 -3
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/vrp.py +11 -2
- pyepo-2.2.6/pyepo/model/grb/compile.py +62 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/vrp.py +4 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/compile.py +17 -8
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/knapsack.py +6 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/vrp.py +6 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/opt.py +66 -21
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/compile.py +0 -2
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/ortcpmodel.py +0 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/ortmodel.py +4 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/utils.py +5 -4
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/PKG-INFO +3 -2
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/SOURCES.txt +1 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/requires.txt +1 -1
- {pyepo-2.2.4 → pyepo-2.2.6}/pyproject.toml +3 -5
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_00_constants.py +2 -13
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_10_utils.py +122 -50
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_15_dsl.py +75 -90
- pyepo-2.2.6/tests/test_20_data_gen.py +125 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_30_model.py +82 -103
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_40_dataset.py +59 -8
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_50_func.py +190 -201
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_55_jax.py +316 -383
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_60_metric.py +43 -77
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_61_metric_validation.py +27 -7
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_70_twostage.py +12 -29
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_80_integration.py +15 -9
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_85_backend_pipeline.py +8 -22
- {pyepo-2.2.4 → pyepo-2.2.6}/tests/test_90_cuda.py +9 -27
- pyepo-2.2.4/pyepo/model/copt/compile.py +0 -107
- pyepo-2.2.4/pyepo/model/grb/compile.py +0 -103
- pyepo-2.2.4/tests/test_20_data_gen.py +0 -193
- {pyepo-2.2.4 → pyepo-2.2.6}/LICENSE +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/README.md +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/EPO.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/data/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/compiled.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/expression.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/objective.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/dsl/problem.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/_common.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/abcmodule.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/blackbox.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/contrastive.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/abcmodule.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/blackbox.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/contrastive.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/perturbed.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/jax/rank.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/func/rank.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/metric/metrics.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/_common.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/bases.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/copt/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/grbmodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/grb/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/mpaxmodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/mpax/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/compile.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/omomodel.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/portfolio.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/omo/tsp.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/knapsack.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/ort/shortestpath.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/predefined.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/model/utils.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/py.typed +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/twostage/__init__.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/twostage/autosklearnpred.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo/twostage/sklearnpred.py +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/dependency_links.txt +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/pyepo.egg-info/top_level.txt +0 -0
- {pyepo-2.2.4 → pyepo-2.2.6}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyepo
|
|
3
|
-
Version: 2.2.4
|
|
3
|
+
Version: 2.2.6
|
|
4
4
|
Summary: PyTorch/JAX-based End-to-End Predict-then-Optimize Tool
|
|
5
5
|
Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
|
|
6
6
|
License: MIT
|
|
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
18
|
Classifier: Programming Language :: Python :: 3.13
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
19
20
|
Classifier: Operating System :: OS Independent
|
|
20
21
|
Classifier: License :: OSI Approved :: MIT License
|
|
21
22
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
@@ -29,7 +30,7 @@ Requires-Dist: numpy
|
|
|
29
30
|
Requires-Dist: scipy
|
|
30
31
|
Requires-Dist: pathos
|
|
31
32
|
Requires-Dist: tqdm
|
|
32
|
-
Requires-Dist:
|
|
33
|
+
Requires-Dist: scikit-learn
|
|
33
34
|
Requires-Dist: torch>=1.13.1
|
|
34
35
|
Provides-Extra: pyomo
|
|
35
36
|
Requires-Dist: pyomo>=6.1.2; extra == "pyomo"
|
|
@@ -6,8 +6,7 @@ from numbers import Real
|
|
|
6
6
|
|
|
7
7
|
def validate_degree(deg: int) -> None:
|
|
8
8
|
"""Validate a positive integer polynomial degree."""
|
|
9
|
-
|
|
10
|
-
raise ValueError(f"deg = {deg} should be a positive integer.")
|
|
9
|
+
validate_positive_int(deg, "deg")
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
def validate_nonnegative(value: float, name: str) -> None:
|
|
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, cast
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
13
13
|
from scipy.spatial import distance
|
|
14
|
-
from torch.utils.data import Dataset
|
|
14
|
+
from torch.utils.data import DataLoader, Dataset
|
|
15
15
|
from tqdm import tqdm
|
|
16
16
|
|
|
17
17
|
from pyepo import EPO
|
|
@@ -33,14 +33,20 @@ def _validate_inputs(
|
|
|
33
33
|
model: optModel,
|
|
34
34
|
feats: np.ndarray | torch.Tensor,
|
|
35
35
|
costs: np.ndarray | torch.Tensor,
|
|
36
|
-
) -> None:
|
|
37
|
-
"""Validate the common constructor contract for optimization datasets."""
|
|
36
|
+
) -> tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]:
|
|
37
|
+
"""Validate and normalize the common constructor contract for optimization datasets."""
|
|
38
38
|
if not isinstance(model, optModel):
|
|
39
39
|
raise TypeError("arg model is not an optModel")
|
|
40
|
+
# array-likes become numpy arrays
|
|
41
|
+
if not isinstance(feats, (np.ndarray, torch.Tensor)):
|
|
42
|
+
feats = np.asarray(feats)
|
|
43
|
+
if not isinstance(costs, (np.ndarray, torch.Tensor)):
|
|
44
|
+
costs = np.asarray(costs)
|
|
40
45
|
if len(feats) != len(costs):
|
|
41
46
|
raise ValueError(
|
|
42
47
|
f"feats and costs must have the same number of instances: {len(feats)} vs {len(costs)}."
|
|
43
48
|
)
|
|
49
|
+
return feats, costs
|
|
44
50
|
|
|
45
51
|
|
|
46
52
|
def _as_float_tensor(data) -> torch.Tensor:
|
|
@@ -91,7 +97,7 @@ class optDataset(Dataset):
|
|
|
91
97
|
feats: data features
|
|
92
98
|
costs: costs of objective function
|
|
93
99
|
"""
|
|
94
|
-
_validate_inputs(model, feats, costs)
|
|
100
|
+
feats, costs = _validate_inputs(model, feats, costs)
|
|
95
101
|
self.model = model
|
|
96
102
|
# data
|
|
97
103
|
self.feats = feats
|
|
@@ -229,7 +235,7 @@ class optDatasetKNN(optDataset):
|
|
|
229
235
|
k: number of nearest neighbours selected
|
|
230
236
|
weight: self-weight in the kNN convex combination (1.0 = no smoothing)
|
|
231
237
|
"""
|
|
232
|
-
_validate_inputs(model, feats, costs)
|
|
238
|
+
feats, costs = _validate_inputs(model, feats, costs)
|
|
233
239
|
self.model = model
|
|
234
240
|
# at most num_data-1 neighbours exist (self excluded), so k must stay below it
|
|
235
241
|
num_data = len(feats)
|
|
@@ -315,7 +321,8 @@ class optDatasetConstrs(optDataset):
|
|
|
315
321
|
currently requires a Gurobi-backed ``optModel``.
|
|
316
322
|
|
|
317
323
|
Per-instance row counts differ (different constraints bind at different
|
|
318
|
-
vertices), so
|
|
324
|
+
vertices), so batch with ``optDataLoader`` or pass
|
|
325
|
+
``collate_tight_constraints`` to a PyTorch ``DataLoader``.
|
|
319
326
|
|
|
320
327
|
Reference: Tang & Khalil (2024)
|
|
321
328
|
`<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
|
|
@@ -346,7 +353,7 @@ class optDatasetConstrs(optDataset):
|
|
|
346
353
|
costs: costs of objective function
|
|
347
354
|
skip_infeas: if True, drop infeasible instances instead of raising
|
|
348
355
|
"""
|
|
349
|
-
_validate_inputs(model, feats, costs)
|
|
356
|
+
feats, costs = _validate_inputs(model, feats, costs)
|
|
350
357
|
self.model = model
|
|
351
358
|
self.skip_infeas = skip_infeas
|
|
352
359
|
# data
|
|
@@ -455,6 +462,29 @@ def collate_tight_constraints(batch):
|
|
|
455
462
|
)
|
|
456
463
|
|
|
457
464
|
|
|
465
|
+
# optDatasetConstrs yields ragged binding-constraint matrices; pad them when batching
|
|
466
|
+
optDatasetConstrs.collate_fn = staticmethod(collate_tight_constraints)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class optDataLoader(DataLoader):
|
|
470
|
+
"""
|
|
471
|
+
``DataLoader`` that applies a dataset's own ``collate_fn`` when present.
|
|
472
|
+
|
|
473
|
+
Datasets with ragged samples (e.g. ``optDatasetConstrs``, whose
|
|
474
|
+
binding-constraint matrices vary in row count) carry a ``collate_fn``; this
|
|
475
|
+
loader uses it so the caller never passes one explicitly. Plain datasets
|
|
476
|
+
fall back to the default PyTorch collation.
|
|
477
|
+
"""
|
|
478
|
+
|
|
479
|
+
def __init__(self, dataset, *args, **kwargs):
|
|
480
|
+
# use the dataset's own collate_fn unless the caller supplies one
|
|
481
|
+
if len(args) <= 5 and "collate_fn" not in kwargs:
|
|
482
|
+
collate_fn = getattr(dataset, "collate_fn", None)
|
|
483
|
+
if collate_fn is not None:
|
|
484
|
+
kwargs["collate_fn"] = collate_fn
|
|
485
|
+
super().__init__(dataset, *args, **kwargs)
|
|
486
|
+
|
|
487
|
+
|
|
458
488
|
def _extract_tight_normals(
|
|
459
489
|
model: optModel,
|
|
460
490
|
sol: np.ndarray,
|
|
@@ -59,12 +59,16 @@ class coneAlignedCosine(optModule):
|
|
|
59
59
|
cutting the per-epoch cost without measurable regret loss.
|
|
60
60
|
|
|
61
61
|
Training data must come from ``pyepo.data.dataset.optDatasetConstrs``
|
|
62
|
-
(Gurobi-backed) and be
|
|
62
|
+
(Gurobi-backed) and be batched with ``pyepo.data.dataset.optDataLoader``
|
|
63
|
+
or a ``DataLoader`` using ``collate_tight_constraints``.
|
|
63
64
|
|
|
64
65
|
Reference: Tang & Khalil (2024)
|
|
65
66
|
`<https://link.springer.com/chapter/10.1007/978-3-031-60599-4_12>`_
|
|
66
67
|
"""
|
|
67
68
|
|
|
69
|
+
# Clarabel-only workers
|
|
70
|
+
_pool_needs_optmodel = False
|
|
71
|
+
|
|
68
72
|
def __init__(
|
|
69
73
|
self,
|
|
70
74
|
optmodel: optModel,
|
|
@@ -139,9 +139,8 @@ def _away_step_frank_wolfe(theta, module, use_cache=False):
|
|
|
139
139
|
w = w.at[bidx, match_idx].add(gamma_fw * (has_match & use_fw).astype(theta.dtype))
|
|
140
140
|
vt = vt.at[bidx, free_idx].set(jnp.where(add_new[:, None], v, vt[bidx, free_idx]))
|
|
141
141
|
vn = vn.at[bidx, free_idx].set(jnp.where(add_new, vnv, vn[bidx, free_idx]))
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
)
|
|
142
|
+
# displaced mass is dropped
|
|
143
|
+
w = w.at[bidx, free_idx].set(jnp.where(add_new, gamma_fw, w[bidx, free_idx]))
|
|
145
144
|
# away subtract, then clear FP residue so dropped atoms leave the active set
|
|
146
145
|
w = w.at[bidx, away_idx].add(-gamma_away)
|
|
147
146
|
w = jnp.where(w < 1e-12, 0.0, jnp.maximum(w, 0.0))
|
|
@@ -14,7 +14,7 @@ import jax.numpy as jnp
|
|
|
14
14
|
from pyepo.func._common import is_minimize, validate_positive
|
|
15
15
|
from pyepo.func.jax.abcmodule import optModule
|
|
16
16
|
from pyepo.func.jax.utils import _full_cost, _solve_or_cache
|
|
17
|
-
from pyepo.utils import _EPS
|
|
17
|
+
from pyepo.utils import _EPS, objective_offset
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from pyepo.func.runtime import Reduction
|
|
@@ -68,6 +68,11 @@ def _spoplus_value_and_grad(pred_cost, true_cost, true_sol, true_obj, module):
|
|
|
68
68
|
# solve the perturbed problem
|
|
69
69
|
sol, obj = _solve_or_cache(2.0 * pred_cost - true_cost, module)
|
|
70
70
|
z = jnp.squeeze(true_obj, axis=-1) if true_obj.ndim > 1 else true_obj
|
|
71
|
+
# drop the bare objective constant
|
|
72
|
+
offset = objective_offset(module.optmodel)
|
|
73
|
+
if offset:
|
|
74
|
+
obj = obj - offset
|
|
75
|
+
z = z - offset
|
|
71
76
|
inner = 2.0 * jnp.einsum("bi,bi->b", pred_cost, true_sol)
|
|
72
77
|
# loss and subgradient
|
|
73
78
|
if is_minimize(module.optmodel.modelSense):
|
|
@@ -12,6 +12,7 @@ import numpy as np
|
|
|
12
12
|
from pyepo.func._common import is_minimize, solution_pool_tolerance
|
|
13
13
|
from pyepo.func.utils import _solve_batch_np
|
|
14
14
|
from pyepo.model.mpax import optMpaxModel
|
|
15
|
+
from pyepo.utils import objective_offset
|
|
15
16
|
|
|
16
17
|
try:
|
|
17
18
|
from jax.extend.core import concrete_or_error as _concrete_or_error
|
|
@@ -83,6 +84,10 @@ def _solve_batch_mpax(cost, optmodel):
|
|
|
83
84
|
# obj in true sense
|
|
84
85
|
if not minimize:
|
|
85
86
|
obj = -obj
|
|
87
|
+
# add the bare objective constant
|
|
88
|
+
offset = objective_offset(optmodel)
|
|
89
|
+
if offset:
|
|
90
|
+
obj = obj + offset
|
|
86
91
|
return sol, obj
|
|
87
92
|
|
|
88
93
|
|
|
@@ -127,6 +132,10 @@ def _cache_in_pass(cost, optmodel, solpool):
|
|
|
127
132
|
select = jnp.argmin if is_minimize(optmodel.modelSense) else jnp.argmax
|
|
128
133
|
ind = select(solpool_obj, axis=1)
|
|
129
134
|
obj = jnp.take_along_axis(solpool_obj, ind[:, None], axis=1).squeeze(1)
|
|
135
|
+
# add the bare objective constant
|
|
136
|
+
offset = objective_offset(optmodel)
|
|
137
|
+
if offset:
|
|
138
|
+
obj = obj + offset
|
|
130
139
|
return solpool[ind], obj
|
|
131
140
|
|
|
132
141
|
|
|
@@ -13,20 +13,16 @@ from torch.autograd import Function
|
|
|
13
13
|
|
|
14
14
|
from pyepo.func._common import (
|
|
15
15
|
is_minimize,
|
|
16
|
-
require_solution_pool,
|
|
17
16
|
validate_positive,
|
|
18
17
|
validate_positive_int,
|
|
19
18
|
)
|
|
20
19
|
from pyepo.func.abcmodule import optModule
|
|
21
20
|
from pyepo.func.utils import (
|
|
22
21
|
_mask_pred,
|
|
22
|
+
_solve_or_cache,
|
|
23
23
|
_torch_generator,
|
|
24
|
-
_update_solution_pool,
|
|
25
24
|
sumGammaDistribution,
|
|
26
25
|
)
|
|
27
|
-
from pyepo.func.utils import (
|
|
28
|
-
_solve_batch as _solve_batch_2d,
|
|
29
|
-
)
|
|
30
26
|
from pyepo.utils import _EPS
|
|
31
27
|
|
|
32
28
|
if TYPE_CHECKING:
|
|
@@ -625,76 +621,12 @@ class adaptiveImplicitMLEFunc(implicitMLEFunc):
|
|
|
625
621
|
|
|
626
622
|
def _solve_or_cache_3d(ptb_c: torch.Tensor, module: optModule) -> torch.Tensor:
|
|
627
623
|
"""
|
|
628
|
-
Solve or use cached solutions for perturbed costs (
|
|
629
|
-
|
|
630
|
-
"""
|
|
631
|
-
optmodel = module.optmodel
|
|
632
|
-
processes = module.processes
|
|
633
|
-
pool = module.pool
|
|
634
|
-
solpool = module.solpool
|
|
635
|
-
if module._branch_rng.uniform() <= module.solve_ratio:
|
|
636
|
-
ptb_sols, solpool = _solve_in_pass_3d(ptb_c, optmodel, processes, pool, solpool)
|
|
637
|
-
else:
|
|
638
|
-
solpool = require_solution_pool(solpool)
|
|
639
|
-
ptb_sols, solpool = _cache_in_pass_3d(ptb_c, optmodel, solpool)
|
|
640
|
-
module.solpool = solpool
|
|
641
|
-
return ptb_sols
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
def _solve_in_pass_3d(
|
|
645
|
-
ptb_c: torch.Tensor,
|
|
646
|
-
optmodel: optModel,
|
|
647
|
-
processes: int,
|
|
648
|
-
pool,
|
|
649
|
-
solpool: torch.Tensor | None = None,
|
|
650
|
-
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
651
|
-
"""
|
|
652
|
-
Solve optimization for perturbed 3D costs and update solution pool.
|
|
653
|
-
|
|
654
|
-
Args:
|
|
655
|
-
ptb_c: perturbed costs, shape (batch, n_samples, vars)
|
|
656
|
-
optmodel: optimization model
|
|
657
|
-
processes: number of processors
|
|
658
|
-
pool: process pool
|
|
659
|
-
solpool: solution pool
|
|
660
|
-
|
|
661
|
-
Returns:
|
|
662
|
-
tuple: (solutions shape (batch, n_samples, vars), updated solpool)
|
|
663
|
-
"""
|
|
664
|
-
ins_num, n_samples, num_vars = ptb_c.shape
|
|
665
|
-
# flatten (batch, n_samples, vars) → (batch * n_samples, vars)
|
|
666
|
-
flat_c = ptb_c.reshape(-1, num_vars)
|
|
667
|
-
# solve using shared 2D function
|
|
668
|
-
flat_sols, _ = _solve_batch_2d(flat_c, optmodel, processes, pool)
|
|
669
|
-
# update pool while flat_sols is still contiguous
|
|
670
|
-
if solpool is not None:
|
|
671
|
-
solpool = _update_solution_pool(flat_sols, solpool)
|
|
672
|
-
if solpool.device != ptb_c.device:
|
|
673
|
-
solpool = solpool.to(ptb_c.device)
|
|
674
|
-
# reshape (batch * n_samples, vars) → (batch, n_samples, vars)
|
|
675
|
-
ptb_sols = flat_sols.reshape(ins_num, n_samples, num_vars)
|
|
676
|
-
return ptb_sols, solpool
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
def _cache_in_pass_3d(
|
|
680
|
-
ptb_c: torch.Tensor,
|
|
681
|
-
optmodel: optModel,
|
|
682
|
-
solpool: torch.Tensor,
|
|
683
|
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
684
|
-
"""
|
|
685
|
-
Use solution pool for perturbed 3D costs (batch × n_samples × vars).
|
|
686
|
-
Unlike the 2D version in utils, this handles the extra sample dimension.
|
|
624
|
+
Solve or use cached solutions for perturbed costs (batch × n_samples × vars).
|
|
625
|
+
Flattens the sample axis and delegates to the shared 2D path in utils.
|
|
687
626
|
"""
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
# compute objective values: (batch, n_samples, pool_size)
|
|
692
|
-
solpool_obj = torch.einsum("bnd,sd->bns", ptb_c, solpool)
|
|
693
|
-
# best solution in pool
|
|
694
|
-
select = torch.argmin if is_minimize(optmodel.modelSense) else torch.argmax
|
|
695
|
-
best_inds = select(solpool_obj, dim=2)
|
|
696
|
-
ptb_sols = solpool[best_inds]
|
|
697
|
-
return ptb_sols, solpool
|
|
627
|
+
batch, n_samples, num_vars = ptb_c.shape
|
|
628
|
+
sol, _ = _solve_or_cache(ptb_c.reshape(-1, num_vars), module)
|
|
629
|
+
return sol.reshape(batch, n_samples, num_vars)
|
|
698
630
|
|
|
699
631
|
|
|
700
632
|
# acronym aliases
|
|
@@ -120,9 +120,8 @@ def _away_step_frank_wolfe(
|
|
|
120
120
|
vertex_norms[batch_idx, free_idx] = torch.where(
|
|
121
121
|
add_new, v_norm_sq, vertex_norms[batch_idx, free_idx]
|
|
122
122
|
)
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
)
|
|
123
|
+
# displaced mass is dropped
|
|
124
|
+
weights[batch_idx, free_idx] = torch.where(add_new, gamma_fw, weights[batch_idx, free_idx])
|
|
126
125
|
# away subtract, then clear FP residue so dropped atoms leave the active set
|
|
127
126
|
weights[batch_idx, away_idx] = weights[batch_idx, away_idx] - gamma_away
|
|
128
127
|
weights = weights.clamp(min=0.0)
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import itertools
|
|
5
6
|
import multiprocessing as mp
|
|
6
7
|
import weakref
|
|
7
8
|
from dataclasses import dataclass
|
|
@@ -63,20 +64,27 @@ def normalize_processes(
|
|
|
63
64
|
return cpu_count if processes == 0 else processes
|
|
64
65
|
|
|
65
66
|
|
|
67
|
+
# unique pool ids
|
|
68
|
+
_pool_ids = itertools.count()
|
|
69
|
+
|
|
70
|
+
|
|
66
71
|
def create_solver_pool(
|
|
67
72
|
optmodel: optModel,
|
|
68
73
|
processes: int,
|
|
69
74
|
*,
|
|
70
75
|
owner=None,
|
|
76
|
+
with_solver: bool = True,
|
|
71
77
|
) -> ProcessingPool | None:
|
|
72
78
|
"""Create a worker pool, optionally tied to an owner's lifetime."""
|
|
73
79
|
if processes == 1:
|
|
74
80
|
return None
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
initializer
|
|
78
|
-
|
|
81
|
+
# optional per-worker optmodel preload
|
|
82
|
+
init_kwargs = (
|
|
83
|
+
{"initializer": _init_worker_model, "initargs": (optmodel.to_spec(),)}
|
|
84
|
+
if with_solver
|
|
85
|
+
else {}
|
|
79
86
|
)
|
|
87
|
+
pool = ProcessingPool(processes, id=f"pyepo-{next(_pool_ids)}", **init_kwargs)
|
|
80
88
|
if owner is not None:
|
|
81
89
|
weakref.finalize(owner, _close_pool, pool)
|
|
82
90
|
return pool
|
|
@@ -117,7 +125,12 @@ def init_runtime(
|
|
|
117
125
|
raise ValueError(f"No reduction '{reduction}'.")
|
|
118
126
|
|
|
119
127
|
normalized_processes = normalize_processes(optmodel, processes, logger)
|
|
120
|
-
pool = create_solver_pool(
|
|
128
|
+
pool = create_solver_pool(
|
|
129
|
+
optmodel,
|
|
130
|
+
normalized_processes,
|
|
131
|
+
owner=owner,
|
|
132
|
+
with_solver=getattr(owner, "_pool_needs_optmodel", True),
|
|
133
|
+
)
|
|
121
134
|
logger.info("Num of cores: %d", normalized_processes)
|
|
122
135
|
return RuntimeState(
|
|
123
136
|
optmodel=optmodel,
|
|
@@ -13,7 +13,7 @@ from torch.autograd import Function
|
|
|
13
13
|
from pyepo.func._common import is_minimize, validate_positive
|
|
14
14
|
from pyepo.func.abcmodule import optModule
|
|
15
15
|
from pyepo.func.utils import _solve_or_cache
|
|
16
|
-
from pyepo.utils import _EPS
|
|
16
|
+
from pyepo.utils import _EPS, objective_offset
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
19
19
|
from pyepo.data.dataset import optDataset
|
|
@@ -115,6 +115,11 @@ class SPOPlusFunc(Function):
|
|
|
115
115
|
# _check_sol(c, w, z)
|
|
116
116
|
# solve
|
|
117
117
|
sol, obj = _solve_or_cache(2 * cp - c, module)
|
|
118
|
+
# drop the bare objective constant
|
|
119
|
+
offset = objective_offset(module.optmodel)
|
|
120
|
+
if offset:
|
|
121
|
+
obj = obj - offset
|
|
122
|
+
z = z - offset
|
|
118
123
|
# calculate loss
|
|
119
124
|
if is_minimize(module.optmodel.modelSense):
|
|
120
125
|
loss = -obj + 2 * torch.einsum("bi,bi->b", cp, w) - z.squeeze(dim=-1)
|
|
@@ -20,7 +20,7 @@ from pyepo.func._common import (
|
|
|
20
20
|
validate_positive_int,
|
|
21
21
|
)
|
|
22
22
|
from pyepo.model.mpax import optMpaxModel
|
|
23
|
-
from pyepo.utils import costToNumpy
|
|
23
|
+
from pyepo.utils import costToNumpy, objective_offset
|
|
24
24
|
|
|
25
25
|
if TYPE_CHECKING:
|
|
26
26
|
from pyepo.func.abcmodule import optModule
|
|
@@ -87,8 +87,9 @@ def _solve_batch(
|
|
|
87
87
|
"""
|
|
88
88
|
A function to solve optimization in the forward/backward pass
|
|
89
89
|
"""
|
|
90
|
-
# get device
|
|
90
|
+
# get device and dtype
|
|
91
91
|
device = cp.device if isinstance(cp, torch.Tensor) else torch.device("cpu")
|
|
92
|
+
dtype = cp.dtype if isinstance(cp, torch.Tensor) else torch.float32
|
|
92
93
|
# MPAX batch solving
|
|
93
94
|
if isinstance(optmodel, optMpaxModel):
|
|
94
95
|
# get params
|
|
@@ -113,11 +114,17 @@ def _solve_batch(
|
|
|
113
114
|
# obj sense
|
|
114
115
|
if not is_minimize(optmodel.modelSense):
|
|
115
116
|
obj = -obj
|
|
117
|
+
# add the bare objective constant
|
|
118
|
+
offset = objective_offset(optmodel)
|
|
119
|
+
if offset:
|
|
120
|
+
obj = obj + offset
|
|
121
|
+
# match input dtype
|
|
122
|
+
sol, obj = sol.to(dtype), obj.to(dtype)
|
|
116
123
|
# host solving on numpy costs
|
|
117
124
|
else:
|
|
118
125
|
sol_np, obj_np = _solve_batch_np(costToNumpy(cp), optmodel, processes, pool)
|
|
119
|
-
sol = torch.as_tensor(sol_np).to(device)
|
|
120
|
-
obj = torch.as_tensor(obj_np).to(device)
|
|
126
|
+
sol = torch.as_tensor(sol_np).to(device=device, dtype=dtype)
|
|
127
|
+
obj = torch.as_tensor(obj_np).to(device=device, dtype=dtype)
|
|
121
128
|
return sol, obj
|
|
122
129
|
|
|
123
130
|
|
|
@@ -130,6 +137,8 @@ def _solve_batch_np(
|
|
|
130
137
|
"""
|
|
131
138
|
A function to solve a batch of numpy costs on the host, shared by both frontends
|
|
132
139
|
"""
|
|
140
|
+
# match input precision
|
|
141
|
+
out_dtype = np.result_type(cp.dtype, np.float32)
|
|
133
142
|
# single-core
|
|
134
143
|
if processes == 1:
|
|
135
144
|
sol_list: list = []
|
|
@@ -140,13 +149,13 @@ def _solve_batch_np(
|
|
|
140
149
|
sol_list.append(solp)
|
|
141
150
|
obj_list.append(objp)
|
|
142
151
|
# stack + dtype convert in a single call
|
|
143
|
-
sol = np.asarray(sol_list, dtype=
|
|
144
|
-
obj = np.asarray(obj_list, dtype=
|
|
152
|
+
sol = np.asarray(sol_list, dtype=out_dtype)
|
|
153
|
+
obj = np.asarray(obj_list, dtype=out_dtype)
|
|
145
154
|
# multi-core (workers pre-loaded with optmodel via pool initializer)
|
|
146
155
|
else:
|
|
147
156
|
res = pool.amap(_solve_with_obj_in_worker, cp).get()
|
|
148
|
-
sol = np.stack([r[0] for r in res]).astype(
|
|
149
|
-
obj = np.asarray([r[1] for r in res], dtype=
|
|
157
|
+
sol = np.stack([r[0] for r in res]).astype(out_dtype)
|
|
158
|
+
obj = np.asarray([r[1] for r in res], dtype=out_dtype)
|
|
150
159
|
return sol, obj
|
|
151
160
|
|
|
152
161
|
|
|
@@ -163,6 +172,9 @@ def _update_solution_pool(
|
|
|
163
172
|
return torch.unique(sol, dim=0).clone()
|
|
164
173
|
if sol.device != solpool.device:
|
|
165
174
|
sol = sol.to(solpool.device)
|
|
175
|
+
# match pool dtype
|
|
176
|
+
if solpool.dtype != sol.dtype:
|
|
177
|
+
solpool = solpool.to(sol.dtype)
|
|
166
178
|
sol_uniq = torch.unique(sol, dim=0)
|
|
167
179
|
# capped L1-tolerance dedup: first-order solvers re-emit near-identical vertices
|
|
168
180
|
tol = solution_pool_tolerance(sol.shape[-1])
|
|
@@ -181,15 +193,19 @@ def _cache_in_pass(
|
|
|
181
193
|
"""
|
|
182
194
|
A function to use solution pool in the forward/backward pass
|
|
183
195
|
"""
|
|
184
|
-
# move solpool to the correct device
|
|
185
|
-
if solpool.device != cp.device:
|
|
186
|
-
solpool = solpool.to(cp.device)
|
|
196
|
+
# move solpool to the correct device and dtype
|
|
197
|
+
if solpool.device != cp.device or solpool.dtype != cp.dtype:
|
|
198
|
+
solpool = solpool.to(device=cp.device, dtype=cp.dtype)
|
|
187
199
|
# best solution in pool
|
|
188
200
|
solpool_obj = torch.matmul(cp, solpool.T)
|
|
189
201
|
select = torch.argmin if is_minimize(optmodel.modelSense) else torch.argmax
|
|
190
202
|
ind = select(solpool_obj, dim=1)
|
|
191
203
|
obj = solpool_obj.gather(1, ind.view(-1, 1)).squeeze(1)
|
|
192
204
|
sol = solpool[ind]
|
|
205
|
+
# add the bare objective constant
|
|
206
|
+
offset = objective_offset(optmodel)
|
|
207
|
+
if offset:
|
|
208
|
+
obj = obj + offset
|
|
193
209
|
return sol, obj, solpool
|
|
194
210
|
|
|
195
211
|
|
|
@@ -30,12 +30,6 @@ def regret_from_objective(obj, true_obj, model_sense):
|
|
|
30
30
|
raise ValueError("Invalid modelSense.")
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
def objective_offset(optmodel: optModel) -> float:
|
|
34
|
-
"""Return a compiled DSL problem's bare objective constant, if present."""
|
|
35
|
-
problem = getattr(optmodel, "problem", None)
|
|
36
|
-
return float(problem.obj_offset) if problem is not None else 0.0
|
|
37
|
-
|
|
38
|
-
|
|
39
33
|
def require_linear_objective(optmodel: optModel) -> None:
|
|
40
34
|
"""Reject models carrying a quadratic objective term."""
|
|
41
35
|
problem = getattr(optmodel, "problem", None)
|
|
@@ -5,15 +5,13 @@ Mean Squared Error
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from typing import TYPE_CHECKING
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
9
|
|
|
10
10
|
import torch
|
|
11
11
|
|
|
12
12
|
from pyepo.metric._common import torch_evaluation, validate_prediction_batch
|
|
13
13
|
|
|
14
14
|
if TYPE_CHECKING:
|
|
15
|
-
from collections.abc import Sized
|
|
16
|
-
|
|
17
15
|
from torch import nn
|
|
18
16
|
from torch.utils.data import DataLoader
|
|
19
17
|
|
|
@@ -24,19 +22,22 @@ def MSE(predmodel: nn.Module, dataloader: DataLoader) -> float:
|
|
|
24
22
|
|
|
25
23
|
Args:
|
|
26
24
|
predmodel: a regression neural network for cost prediction
|
|
27
|
-
dataloader: Torch dataloader from optDataSet
|
|
25
|
+
dataloader: Torch dataloader from optDataSet (fields beyond
|
|
26
|
+
``(x, c, w, z)`` are ignored)
|
|
28
27
|
|
|
29
28
|
Returns:
|
|
30
29
|
float: MSE loss
|
|
31
30
|
"""
|
|
32
31
|
loss = 0
|
|
32
|
+
total = 0
|
|
33
33
|
with torch_evaluation(predmodel) as device, torch.no_grad():
|
|
34
34
|
# load data
|
|
35
35
|
for data in dataloader:
|
|
36
|
-
x, c, _, _ = data
|
|
36
|
+
x, c, _, _ = data[:4]
|
|
37
37
|
x, c = x.to(device), c.to(device)
|
|
38
38
|
# predict
|
|
39
39
|
cp = predmodel(x)
|
|
40
40
|
validate_prediction_batch(cp, c)
|
|
41
41
|
loss += ((cp - c) ** 2).mean(dim=1).sum().item()
|
|
42
|
-
|
|
42
|
+
total += x.shape[0]
|
|
43
|
+
return loss / total if total else 0.0
|
|
@@ -15,14 +15,13 @@ from pyepo.func.runtime import create_solver_pool, normalize_processes
|
|
|
15
15
|
from pyepo.func.utils import _close_pool, _solve_batch
|
|
16
16
|
from pyepo.metric._common import (
|
|
17
17
|
normalize_regret,
|
|
18
|
-
objective_offset,
|
|
19
18
|
regret_from_objective,
|
|
20
19
|
require_linear_objective,
|
|
21
20
|
torch_evaluation,
|
|
22
21
|
validate_cost_vectors,
|
|
23
22
|
validate_prediction_batch,
|
|
24
23
|
)
|
|
25
|
-
from pyepo.utils import costToNumpy
|
|
24
|
+
from pyepo.utils import costToNumpy, objective_offset
|
|
26
25
|
|
|
27
26
|
if TYPE_CHECKING:
|
|
28
27
|
from collections.abc import Callable
|
|
@@ -65,7 +64,8 @@ def regret(
|
|
|
65
64
|
JAX callable ``f(x_numpy) -> cost_array``
|
|
66
65
|
optmodel: a PyEPO optimization model
|
|
67
66
|
dataloader: PyTorch DataLoader over an ``optDataset`` (yielding
|
|
68
|
-
``(x, c, w, z)`` tuples
|
|
67
|
+
``(x, c, w, z, ...)`` tuples; fields beyond the first four,
|
|
68
|
+
e.g. CaVE tight constraints, are ignored)
|
|
69
69
|
processes: number of processors, 1 for single-core, 0 for all of
|
|
70
70
|
cores; a fresh worker pool is spawned per call, each worker
|
|
71
71
|
rebuilding the model from its constructor args
|
|
@@ -89,7 +89,7 @@ def regret(
|
|
|
89
89
|
with torch_evaluation(torch_model) as device:
|
|
90
90
|
# load data
|
|
91
91
|
for data in dataloader:
|
|
92
|
-
x, c, _, z = data
|
|
92
|
+
x, c, _, z = data[:4]
|
|
93
93
|
if torch_model is not None:
|
|
94
94
|
x, c, z = x.to(device), c.to(device), z.to(device)
|
|
95
95
|
with torch.no_grad():
|
|
@@ -14,7 +14,6 @@ import torch
|
|
|
14
14
|
from pyepo import EPO
|
|
15
15
|
from pyepo.metric._common import (
|
|
16
16
|
normalize_regret,
|
|
17
|
-
objective_offset,
|
|
18
17
|
regret_from_objective,
|
|
19
18
|
require_linear_objective,
|
|
20
19
|
torch_evaluation,
|
|
@@ -23,7 +22,7 @@ from pyepo.metric._common import (
|
|
|
23
22
|
validate_retry_count,
|
|
24
23
|
validate_tolerance,
|
|
25
24
|
)
|
|
26
|
-
from pyepo.utils import costToNumpy
|
|
25
|
+
from pyepo.utils import costToNumpy, objective_offset
|
|
27
26
|
|
|
28
27
|
if TYPE_CHECKING:
|
|
29
28
|
from torch import nn
|
|
@@ -57,7 +56,8 @@ def unambRegret(
|
|
|
57
56
|
Args:
|
|
58
57
|
predmodel: a regression neural network for cost prediction
|
|
59
58
|
optmodel: a PyEPO optimization model
|
|
60
|
-
dataloader: PyTorch DataLoader over an ``optDataset``
|
|
59
|
+
dataloader: PyTorch DataLoader over an ``optDataset`` (fields beyond
|
|
60
|
+
``(x, c, w, z)`` are ignored)
|
|
61
61
|
tolerance: precision used when rounding predicted costs to find ties
|
|
62
62
|
max_iter: maximum number of solve retries with relaxed tolerance
|
|
63
63
|
|
|
@@ -72,8 +72,8 @@ def unambRegret(
|
|
|
72
72
|
with torch_evaluation(predmodel) as device:
|
|
73
73
|
# load data
|
|
74
74
|
for data in dataloader:
|
|
75
|
-
x, c,
|
|
76
|
-
x, c,
|
|
75
|
+
x, c, _, z = data[:4]
|
|
76
|
+
x, c, z = x.to(device), c.to(device), z.to(device)
|
|
77
77
|
# pred cost
|
|
78
78
|
with torch.no_grad():
|
|
79
79
|
cp = costToNumpy(predmodel(x))
|