pyepo 2.2.1__tar.gz → 2.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. {pyepo-2.2.1 → pyepo-2.2.2}/PKG-INFO +1 -1
  2. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/metric/regret.py +30 -14
  3. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo.egg-info/PKG-INFO +1 -1
  4. {pyepo-2.2.1 → pyepo-2.2.2}/pyproject.toml +1 -1
  5. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_60_metric.py +75 -1
  6. {pyepo-2.2.1 → pyepo-2.2.2}/LICENSE +0 -0
  7. {pyepo-2.2.1 → pyepo-2.2.2}/README.md +0 -0
  8. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/EPO.py +0 -0
  9. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/__init__.py +0 -0
  10. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/data/__init__.py +0 -0
  11. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/data/dataset.py +0 -0
  12. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/data/knapsack.py +0 -0
  13. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/data/portfolio.py +0 -0
  14. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/data/shortestpath.py +0 -0
  15. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/data/tsp.py +0 -0
  16. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/dsl/__init__.py +0 -0
  17. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/dsl/compiled.py +0 -0
  18. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/dsl/expression.py +0 -0
  19. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/dsl/objective.py +0 -0
  20. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/dsl/problem.py +0 -0
  21. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/__init__.py +0 -0
  22. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/abcmodule.py +0 -0
  23. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/blackbox.py +0 -0
  24. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/cave.py +0 -0
  25. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/contrastive.py +0 -0
  26. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/__init__.py +0 -0
  27. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/abcmodule.py +0 -0
  28. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/blackbox.py +0 -0
  29. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/cave.py +0 -0
  30. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/contrastive.py +0 -0
  31. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/perturbed.py +0 -0
  32. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/rank.py +0 -0
  33. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/regularized.py +0 -0
  34. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/surrogate.py +0 -0
  35. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/jax/utils.py +0 -0
  36. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/perturbed.py +0 -0
  37. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/rank.py +0 -0
  38. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/regularized.py +0 -0
  39. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/surrogate.py +0 -0
  40. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/func/utils.py +0 -0
  41. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/metric/__init__.py +0 -0
  42. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/metric/metrics.py +0 -0
  43. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/metric/mse.py +0 -0
  44. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/metric/unambregret.py +0 -0
  45. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/__init__.py +0 -0
  46. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/bases.py +0 -0
  47. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/__init__.py +0 -0
  48. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/compile.py +0 -0
  49. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/coptmodel.py +0 -0
  50. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/knapsack.py +0 -0
  51. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/portfolio.py +0 -0
  52. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/shortestpath.py +0 -0
  53. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/tsp.py +0 -0
  54. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/copt/vrp.py +0 -0
  55. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/__init__.py +0 -0
  56. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/compile.py +0 -0
  57. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/grbmodel.py +0 -0
  58. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/knapsack.py +0 -0
  59. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/portfolio.py +0 -0
  60. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/shortestpath.py +0 -0
  61. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/tsp.py +0 -0
  62. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/grb/vrp.py +0 -0
  63. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/mpax/__init__.py +0 -0
  64. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/mpax/compile.py +0 -0
  65. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/mpax/knapsack.py +0 -0
  66. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/mpax/mpaxmodel.py +0 -0
  67. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/mpax/shortestpath.py +0 -0
  68. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/__init__.py +0 -0
  69. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/compile.py +0 -0
  70. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/knapsack.py +0 -0
  71. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/omomodel.py +0 -0
  72. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/portfolio.py +0 -0
  73. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/shortestpath.py +0 -0
  74. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/tsp.py +0 -0
  75. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/omo/vrp.py +0 -0
  76. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/opt.py +0 -0
  77. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/ort/__init__.py +0 -0
  78. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/ort/compile.py +0 -0
  79. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/ort/knapsack.py +0 -0
  80. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/ort/ortcpmodel.py +0 -0
  81. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/ort/ortmodel.py +0 -0
  82. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/ort/shortestpath.py +0 -0
  83. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/predefined.py +0 -0
  84. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/model/utils.py +0 -0
  85. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/py.typed +0 -0
  86. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/twostage/__init__.py +0 -0
  87. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/twostage/autosklearnpred.py +0 -0
  88. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/twostage/sklearnpred.py +0 -0
  89. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo/utils.py +0 -0
  90. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo.egg-info/SOURCES.txt +0 -0
  91. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo.egg-info/dependency_links.txt +0 -0
  92. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo.egg-info/requires.txt +0 -0
  93. {pyepo-2.2.1 → pyepo-2.2.2}/pyepo.egg-info/top_level.txt +0 -0
  94. {pyepo-2.2.1 → pyepo-2.2.2}/setup.cfg +0 -0
  95. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_00_constants.py +0 -0
  96. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_10_utils.py +0 -0
  97. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_15_dsl.py +0 -0
  98. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_20_data_gen.py +0 -0
  99. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_30_model.py +0 -0
  100. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_40_dataset.py +0 -0
  101. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_50_func.py +0 -0
  102. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_55_jax.py +0 -0
  103. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_70_twostage.py +0 -0
  104. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_80_integration.py +0 -0
  105. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_85_backend_pipeline.py +0 -0
  106. {pyepo-2.2.1 → pyepo-2.2.2}/tests/test_90_cuda.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.1
3
+ Version: 2.2.2
4
4
  Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
@@ -19,6 +19,8 @@ from pyepo.model.mpax import optMpaxModel
19
19
  from pyepo.utils import _EPS, costToNumpy, getArgs
20
20
 
21
21
  if TYPE_CHECKING:
22
+ from collections.abc import Callable
23
+
22
24
  from torch import nn
23
25
  from torch.utils.data import DataLoader
24
26
 
@@ -75,7 +77,7 @@ def _checkLinearObj(optmodel) -> None:
75
77
 
76
78
 
77
79
  def regret(
78
- predmodel: nn.Module,
80
+ predmodel: nn.Module | Callable,
79
81
  optmodel: optModel,
80
82
  dataloader: DataLoader,
81
83
  processes: int = 1,
@@ -91,11 +93,17 @@ def regret(
91
93
  z^*(\\mathbf{c}_i)`. With the default ``reduction="normalized"`` the
92
94
  result is :math:`\\sum_i l_i / \\sum_i |z^*(\\mathbf{c}_i)|`,
93
95
  dimensionless and comparable across problem scales; instances with
94
- near-zero true optima inflate the ratio. The predictor is evaluated
95
- under ``eval()``; its original mode is restored afterwards.
96
+ near-zero true optima inflate the ratio. PyTorch predictors are
97
+ evaluated under ``eval()``; the original mode is restored afterwards.
98
+
99
+ ``predmodel`` may also be a plain callable ``f(x: np.ndarray) ->
100
+ array-like`` for JAX/Flax models; pass a ``functools.partial`` that
101
+ closes over the current parameter pytree, e.g.
102
+ ``functools.partial(model.apply, params)``.
96
103
 
97
104
  Args:
98
- predmodel: a regression neural network for cost prediction
105
+ predmodel: a PyTorch ``nn.Module`` for cost prediction, or a
106
+ JAX callable ``f(x_numpy) -> cost_array``
99
107
  optmodel: a PyEPO optimization model
100
108
  dataloader: PyTorch DataLoader over an ``optDataset`` (yielding
101
109
  ``(x, c, w, z)`` tuples)
@@ -122,9 +130,12 @@ def regret(
122
130
  processes = mp.cpu_count() if processes == 0 else processes
123
131
  losses = []
124
132
  optsum = 0.0
125
- # get device (cpu fallback for parameterless predictors)
126
- param = next(predmodel.parameters(), None)
127
- device = param.device if param is not None else torch.device("cpu")
133
+ _is_torch = isinstance(predmodel, torch.nn.Module)
134
+ # get device (cpu fallback for parameterless predictors; JAX callables always use cpu)
135
+ device = torch.device("cpu")
136
+ if _is_torch:
137
+ param = next(predmodel.parameters(), None)
138
+ device = param.device if param is not None else torch.device("cpu")
128
139
  # multi-core: each worker builds its own optmodel once via the initializer
129
140
  pool = None
130
141
  if processes > 1:
@@ -134,16 +145,20 @@ def regret(
134
145
  initargs=(type(optmodel), getArgs(optmodel)),
135
146
  )
136
147
  # evaluate under eval(); the original mode is restored afterwards
137
- was_training = predmodel.training
138
- predmodel.eval()
148
+ was_training = predmodel.training if _is_torch else False
149
+ if _is_torch:
150
+ predmodel.eval()
139
151
  try:
140
152
  # load data
141
153
  for data in dataloader:
142
154
  x, c, _, z = data
143
- x, c, z = x.to(device), c.to(device), z.to(device)
144
- # predict and batch-solve all instances in one call
145
- with torch.no_grad():
146
- cp = predmodel(x)
155
+ if _is_torch:
156
+ x, c, z = x.to(device), c.to(device), z.to(device)
157
+ with torch.no_grad():
158
+ cp = predmodel(x)
159
+ else:
160
+ # JAX callable: f(x_numpy) -> array-like
161
+ cp = torch.as_tensor(np.asarray(predmodel(x.numpy())), dtype=torch.float32)
147
162
  # full cost so the MPAX backend can batch-set the objective in one call
148
163
  sols, _ = _solve_batch(optmodel._fullCost(cp), optmodel, processes=processes, pool=pool)
149
164
  # vectorized regret accumulation (one host sync per batch)
@@ -158,7 +173,8 @@ def regret(
158
173
  if pool is not None:
159
174
  _close_pool(pool)
160
175
  # restore the original mode even if evaluation raises
161
- predmodel.train(was_training)
176
+ if _is_torch:
177
+ predmodel.train(was_training)
162
178
  loss = np.concatenate(losses) if losses else np.empty(0)
163
179
  # reduce
164
180
  if reduction == "normalized":
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyepo
3
- Version: 2.2.1
3
+ Version: 2.2.2
4
4
  Summary: PyTorch-based End-to-End Predict-then-Optimize Tool
5
5
  Author-email: Bo Tang <bolucas.tang@mail.utoronto.ca>
6
6
  License: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "pyepo"
7
- version = "2.2.1"
7
+ version = "2.2.2"
8
8
  description = "PyTorch-based End-to-End Predict-then-Optimize Tool"
9
9
  readme = { file = "README.md", content-type = "text/markdown" }
10
10
  license = { text = "MIT" }
@@ -20,7 +20,12 @@ from pyepo.metric.mse import MSE
20
20
  from pyepo.metric.regret import _regretFromObj, calRegret
21
21
  from pyepo.metric.unambregret import calUnambRegret
22
22
 
23
- from .conftest import NUM_FEAT, LinearPred, requires_gurobi
23
+ from .conftest import (
24
+ NUM_FEAT, GRID, NUM_DATA, BATCH,
25
+ LinearPred,
26
+ requires_gurobi, requires_jax, requires_mpax,
27
+ _HAS_GUROBI, _HAS_JAX,
28
+ )
24
29
 
25
30
 
26
31
  class _IdentityModel(nn.Module):
@@ -471,6 +476,75 @@ class TestSkScorer:
471
476
  assert not np.isclose(expected, swapped)
472
477
 
473
478
 
479
+ # ============================================================
480
+ # JAX callable path for pyepo.metric.regret
481
+ # ============================================================
482
+
483
+ @requires_mpax
484
+ class TestDataloaderMetricsJax:
485
+ """pyepo.metric.regret accepts a plain callable f(x_numpy) -> array."""
486
+
487
+ @pytest.fixture(scope="class")
488
+ def mpax_data(self):
489
+ import pyepo
490
+ from pyepo.data.dataset import optDataset
491
+ from pyepo.model.mpax.shortestpath import shortestPathModel
492
+
493
+ x, c = pyepo.data.shortestpath.genData(NUM_DATA, NUM_FEAT, GRID, seed=42)
494
+ optmodel = shortestPathModel(grid=GRID)
495
+ dataset = optDataset(optmodel, x, c)
496
+ loader = DataLoader(dataset, batch_size=BATCH, shuffle=False)
497
+ return optmodel, dataset, loader
498
+
499
+ def test_callable_returns_non_negative_float(self, mpax_data):
500
+ import pyepo
501
+ optmodel, _ds, loader = mpax_data
502
+ reg = pyepo.metric.regret(lambda x: x, optmodel, loader)
503
+ assert isinstance(reg, float) and reg >= 0
504
+
505
+ def test_callable_reductions_consistent(self, mpax_data):
506
+ import pyepo
507
+ optmodel, ds, loader = mpax_data
508
+ fn = lambda x: x # noqa: E731
509
+ per = pyepo.metric.regret(fn, optmodel, loader, reduction="none")
510
+ total = pyepo.metric.regret(fn, optmodel, loader, reduction="sum")
511
+ assert isinstance(per, np.ndarray) and len(per) == len(ds)
512
+ assert total == pytest.approx(per.sum(), rel=1e-5)
513
+
514
+
515
+ @pytest.mark.skipif(
516
+ not (_HAS_GUROBI and _HAS_JAX),
517
+ reason="Parity: Gurobi (exact solve) + JAX both required",
518
+ )
519
+ class TestDataloaderMetricsJaxParity:
520
+ """JAX callable and torch nn.Module with identical weights give the same regret."""
521
+
522
+ def test_same_weights_same_regret(self, sp_data):
523
+ import functools
524
+
525
+ import jax.numpy as jnp
526
+ import pyepo
527
+ from flax import linen as nn
528
+
529
+ optmodel, _ds, loader = sp_data
530
+
531
+ # random torch predictor; capture weights as numpy
532
+ torch_pred = LinearPred(NUM_FEAT, optmodel.num_cost)
533
+ torch_pred.eval()
534
+ w = torch_pred.linear.weight.detach().numpy() # (num_cost, num_feat)
535
+ b = torch_pred.linear.bias.detach().numpy() # (num_cost,)
536
+
537
+ # Flax Dense: kernel shape is (num_feat, num_cost) = w.T
538
+ flax_pred = nn.Dense(optmodel.num_cost)
539
+ params = {"params": {"kernel": jnp.asarray(w.T), "bias": jnp.asarray(b)}}
540
+
541
+ torch_reg = pyepo.metric.regret(torch_pred, optmodel, loader)
542
+ jax_reg = pyepo.metric.regret(
543
+ functools.partial(flax_pred.apply, params), optmodel, loader
544
+ )
545
+ assert jax_reg == pytest.approx(torch_reg, abs=1e-4)
546
+
547
+
474
548
  class TestAutoSkScorer:
475
549
 
476
550
  def test_raises_without_autosklearn(self):
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes