pyoframe 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoframe/__init__.py +12 -3
- pyoframe/_arithmetic.py +3 -6
- pyoframe/constants.py +20 -14
- pyoframe/{constraints.py → core.py} +504 -74
- pyoframe/io.py +66 -30
- pyoframe/io_mappers.py +66 -34
- pyoframe/model.py +65 -41
- pyoframe/model_element.py +128 -18
- pyoframe/monkey_patch.py +2 -2
- pyoframe/objective.py +16 -13
- pyoframe/solvers.py +300 -109
- pyoframe/user_defined.py +60 -0
- pyoframe/util.py +56 -55
- {pyoframe-0.0.4.dist-info → pyoframe-0.0.6.dist-info}/METADATA +9 -2
- pyoframe-0.0.6.dist-info/RECORD +18 -0
- pyoframe/variables.py +0 -193
- pyoframe-0.0.4.dist-info/RECORD +0 -18
- {pyoframe-0.0.4.dist-info → pyoframe-0.0.6.dist-info}/LICENSE +0 -0
- {pyoframe-0.0.4.dist-info → pyoframe-0.0.6.dist-info}/WHEEL +0 -0
- {pyoframe-0.0.4.dist-info → pyoframe-0.0.6.dist-info}/top_level.txt +0 -0
pyoframe/solvers.py
CHANGED
|
@@ -3,35 +3,91 @@ Code to interface with various solvers
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from abc import abstractmethod, ABC
|
|
6
|
+
from functools import lru_cache
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
from typing import Optional, Union, TYPE_CHECKING
|
|
8
|
+
from typing import Any, Dict, Optional, Type, Union, TYPE_CHECKING
|
|
8
9
|
|
|
9
10
|
import polars as pl
|
|
10
11
|
|
|
11
12
|
from pyoframe.constants import (
|
|
12
13
|
DUAL_KEY,
|
|
13
|
-
NAME_COL,
|
|
14
14
|
SOLUTION_KEY,
|
|
15
|
+
SLACK_COL,
|
|
16
|
+
RC_COL,
|
|
17
|
+
VAR_KEY,
|
|
18
|
+
CONSTRAINT_KEY,
|
|
15
19
|
Result,
|
|
16
20
|
Solution,
|
|
17
21
|
Status,
|
|
18
22
|
)
|
|
19
23
|
import contextlib
|
|
24
|
+
import pyoframe as pf
|
|
20
25
|
|
|
21
26
|
from pathlib import Path
|
|
22
27
|
|
|
23
28
|
if TYPE_CHECKING: # pragma: no cover
|
|
24
29
|
from pyoframe.model import Model
|
|
25
30
|
|
|
31
|
+
available_solvers = []
|
|
32
|
+
solver_registry: Dict[str, Type["Solver"]] = {}
|
|
26
33
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
34
|
+
with contextlib.suppress(ImportError):
|
|
35
|
+
import gurobipy
|
|
36
|
+
|
|
37
|
+
available_solvers.append("gurobi")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _register_solver(solver_name):
    """
    Create a class decorator that records a solver class in `solver_registry`.

    Parameters:
        solver_name : str
            Key under which the decorated class is stored in `solver_registry`.

    Returns:
        A decorator that stores the class it receives and hands it back unchanged.
    """

    def _record(solver_cls):
        # Registration is a pure side effect; the class itself is not modified.
        solver_registry[solver_name] = solver_cls
        return solver_cls

    return _record
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def solve(
|
|
49
|
+
m: "Model",
|
|
50
|
+
solver=None,
|
|
51
|
+
directory: Optional[Union[Path, str]] = None,
|
|
52
|
+
use_var_names=False,
|
|
53
|
+
log_fn=None,
|
|
54
|
+
warmstart_fn=None,
|
|
55
|
+
basis_fn=None,
|
|
56
|
+
solution_file=None,
|
|
57
|
+
log_to_console=True,
|
|
58
|
+
):
|
|
59
|
+
if solver is None:
|
|
60
|
+
if len(available_solvers) == 0:
|
|
61
|
+
raise ValueError(
|
|
62
|
+
"No solvers available. Please install a solving library like gurobipy."
|
|
63
|
+
)
|
|
64
|
+
solver = available_solvers[0]
|
|
65
|
+
|
|
66
|
+
if solver not in solver_registry:
|
|
31
67
|
raise ValueError(f"Solver {solver} not recognized or supported.")
|
|
32
68
|
|
|
69
|
+
solver_cls = solver_registry[solver]
|
|
70
|
+
m.solver = solver_cls(
|
|
71
|
+
m,
|
|
72
|
+
log_to_console,
|
|
73
|
+
params={param: value for param, value in m.params},
|
|
74
|
+
directory=directory,
|
|
75
|
+
)
|
|
76
|
+
m.solver_model = m.solver.create_solver_model(use_var_names)
|
|
77
|
+
m.solver.solver_model = m.solver_model
|
|
78
|
+
|
|
79
|
+
for attr_container in [m.variables, m.constraints, [m]]:
|
|
80
|
+
for container in attr_container:
|
|
81
|
+
for param_name, param_value in container.attr:
|
|
82
|
+
m.solver.set_attr(container, param_name, param_value)
|
|
83
|
+
|
|
84
|
+
result = m.solver.solve(log_fn, warmstart_fn, basis_fn, solution_file)
|
|
85
|
+
result = m.solver.process_result(result)
|
|
86
|
+
m.result = result
|
|
87
|
+
|
|
33
88
|
if result.solution is not None:
|
|
34
|
-
m.objective
|
|
89
|
+
if m.objective is not None:
|
|
90
|
+
m.objective.value = result.solution.objective
|
|
35
91
|
|
|
36
92
|
for variable in m.variables:
|
|
37
93
|
variable.solution = result.solution.primal
|
|
@@ -44,142 +100,277 @@ def solve(m: "Model", solver, **kwargs):
|
|
|
44
100
|
|
|
45
101
|
|
|
46
102
|
class Solver(ABC):
    """
    Abstract interface between a pyoframe model and a solver backend.

    Subclasses build the backend's model object, push user-defined attributes
    to it, run the optimization and translate the outcome into a `Result`.
    """

    def __init__(self, model: "Model", log_to_console, params, directory):
        """
        Parameters:
            model : Model
                The pyoframe model this solver works on.
            log_to_console : bool
                Stored on the instance; how it is honoured is up to the subclass.
            params : dict
                Solver parameters, stored on the instance for the subclass to use.
            directory : Path, str or None
                Stored on the instance; its meaning is defined by the subclass.
        """
        self.directory = directory
        self.params = params
        self.log_to_console: bool = log_to_console
        # Starts empty; nothing in this class assigns the backend model object.
        self.solver_model: Optional[Any] = None
        self._model = model

    @abstractmethod
    def create_solver_model(self, use_var_names) -> Any:
        """Build and return the backend's model object."""

    @abstractmethod
    def set_attr(self, element, param_name, param_value):
        """Set the attribute `param_name` to `param_value` for `element`."""

    @abstractmethod
    def solve(self, log_fn, warmstart_fn, basis_fn, solution_file) -> Result:
        """Run the optimization and return its `Result`."""

    @abstractmethod
    def process_result(self, results: Result) -> Result:
        """Post-process the `Result` returned by `solve` and return it."""

    def load_rc(self):
        """Fetch the RC values of all variables and assign them to every variable of the model."""
        all_rc = self._get_all_rc()
        for var in self._model.variables:
            # Every variable receives the same full table.
            var.RC = all_rc

    def load_slack(self):
        """Fetch the slack values of all constraints and assign them to every constraint of the model."""
        all_slack = self._get_all_slack()
        for constr in self._model.constraints:
            # Every constraint receives the same full table.
            constr.slack = all_slack

    @abstractmethod
    def _get_all_rc(self):
        """Return the RC values for all variables."""

    @abstractmethod
    def _get_all_slack(self):
        """Return the slack values for all constraints."""

    def dispose(self):
        """
        Clean up any resources that wouldn't be cleaned up by the garbage collector.

        The base implementation does nothing. For now, this is only used by the Gurobi
        solver to call .dispose() on the solver model and Gurobi environment, which helps
        close a connection to the Gurobi Compute Server. Note that this effectively
        disables commands that need access to the solver model (like .slack and .RC).
        """
|
|
49
146
|
|
|
50
147
|
|
|
51
148
|
class FileBasedSolver(Solver):
    """
    Base class for solvers that receive the problem through an LP file.

    The model is written to a file with `Model.to_file`, the subclass loads that
    file into the backend, and values exchanged with the backend are translated
    with the model's `io_mappers` (`var_map` for variables, `const_map` for
    constraints).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Path of the written problem file; set by `create_solver_model`.
        self.problem_file: Optional[Path] = None
        # Files are kept only when the caller supplied a directory.
        self.keep_files = self.directory is not None

    def create_solver_model(self, use_var_names) -> Any:
        """
        Write the model to an LP file and load it into the backend.

        When `self.directory` is set, the file is `<directory>/<model name>.lp`
        (or `pyoframe-problem.lp` for an unnamed model) and the directory is
        created if needed. Otherwise `None` is passed to `Model.to_file`, which
        then decides where the file goes.

        Returns:
            The backend model object built by `create_solver_model_from_lp`.
        """
        problem_file = None
        if self.directory is not None:
            # Path() is a no-op for Path objects, so str and Path are both accepted.
            directory = Path(self.directory)
            # exist_ok avoids the race between checking for the directory and creating it.
            directory.mkdir(parents=True, exist_ok=True)
            filename = (
                self._model.name if self._model.name is not None else "pyoframe-problem"
            )
            problem_file = directory / f"{filename}.lp"
        self.problem_file = self._model.to_file(
            problem_file, use_var_names=use_var_names
        )
        assert self._model.io_mappers is not None
        return self.create_solver_model_from_lp()

    @abstractmethod
    def create_solver_model_from_lp(self) -> Any:
        """Load `self.problem_file` into the backend and return the backend model."""

    def set_attr(self, element, param_name, param_value):
        """
        Set an attribute on the backend, translating DataFrame values first.

        DataFrame values for a variable or constraint are passed through the
        matching io mapper; all other values are forwarded untouched.
        """
        if isinstance(param_value, pl.DataFrame):
            if isinstance(element, pf.Variable):
                param_value = self._model.io_mappers.var_map.apply(param_value)
            elif isinstance(element, pf.Constraint):
                param_value = self._model.io_mappers.const_map.apply(param_value)
        return self.set_attr_unmapped(element, param_name, param_value)

    @abstractmethod
    def set_attr_unmapped(self, element, param_name, param_value):
        """Set an attribute on the backend using already-mapped values."""

    def process_result(self, results: Result) -> Result:
        """
        Undo the io mapping on the primal (and, if present, dual) solution.

        The `Result` is modified in place and returned.
        """
        if results.solution is not None:
            results.solution.primal = self._model.io_mappers.var_map.undo(
                results.solution.primal
            )
            if results.solution.dual is not None:
                results.solution.dual = self._model.io_mappers.const_map.undo(
                    results.solution.dual
                )

        return results

    def _get_all_rc(self):
        """Return the RC values of all variables with the variable mapping undone."""
        return self._model.io_mappers.var_map.undo(self._get_all_rc_unmapped())

    def _get_all_slack(self):
        """Return the slack values of all constraints with the constraint mapping undone."""
        return self._model.io_mappers.const_map.undo(self._get_all_slack_unmapped())

    @abstractmethod
    def _get_all_rc_unmapped(self):
        """Return the RC values of all variables as reported by the backend."""

    @abstractmethod
    def _get_all_slack_unmapped(self):
        """Return the slack values of all constraints as reported by the backend."""
|
|
86
209
|
|
|
210
|
+
|
|
211
|
+
@_register_solver("gurobi")
class GurobiSolver(FileBasedSolver):
    """
    Solver backend that loads the model's LP file into Gurobi through gurobipy.
    """

    # see https://www.gurobi.com/documentation/10.0/refman/optimization_status_codes.html
    CONDITION_MAP = {
        1: "unknown",
        2: "optimal",
        3: "infeasible",
        4: "infeasible_or_unbounded",
        5: "unbounded",
        6: "other",
        7: "iteration_limit",
        8: "terminated_by_limit",
        9: "time_limit",
        10: "optimal",
        11: "user_interrupt",
        12: "other",
        13: "suboptimal",
        14: "unknown",
        15: "terminated_by_limit",
        16: "internal_solver_error",
        17: "internal_solver_error",
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.log_to_console:
            self.params["LogToConsole"] = 0
        # Gurobi environment; created in `create_solver_model_from_lp`.
        self.env = None
        # Per-instance caches for the (objects, name table) pairs below. Kept on
        # the instance rather than in functools.lru_cache, which would hold a
        # reference to every solver (and its Gurobi env/model) for the life of
        # the process.
        self._var_mapping_cache = None
        self._constraint_mapping_cache = None

    def create_solver_model_from_lp(self) -> Any:
        """
        Read the problem file into a Gurobi model.

        A Gurobi environment is created from `self.params` and kept on
        `self.env`. The problem file is deleted afterwards unless
        `self.keep_files` is set.

        Returns:
            The model object returned by `gurobipy.read`.
        """
        assert self.problem_file is not None
        self.env = gurobipy.Env(params=self.params)

        m = gurobipy.read(_path_to_str(self.problem_file), env=self.env)
        if not self.keep_files:
            self.problem_file.unlink()

        return m

    def _get_var_mapping(self):
        """
        Return `(vars, table)` for the solver model, computed once per instance.

        `vars` is the result of `getVars()`; `table` has the variable names under
        `VAR_KEY` and their position in `vars` under `i`.
        """
        assert self.solver_model is not None
        if self._var_mapping_cache is None:
            gurobi_vars = self.solver_model.getVars()
            table = pl.DataFrame(
                {VAR_KEY: self.solver_model.getAttr("VarName", gurobi_vars)}
            ).with_columns(i=pl.int_range(pl.len()))
            self._var_mapping_cache = (gurobi_vars, table)
        return self._var_mapping_cache

    def _get_constraint_mapping(self):
        """
        Return `(constraints, table)` for the solver model, computed once per instance.

        `constraints` is the result of `getConstrs()`; `table` has the constraint
        names under `CONSTRAINT_KEY` and their position in `constraints` under `i`.
        """
        assert self.solver_model is not None
        if self._constraint_mapping_cache is None:
            constraints = self.solver_model.getConstrs()
            table = pl.DataFrame(
                {CONSTRAINT_KEY: self.solver_model.getAttr("ConstrName", constraints)}
            ).with_columns(i=pl.int_range(pl.len()))
            self._constraint_mapping_cache = (constraints, table)
        return self._constraint_mapping_cache

    def set_attr_unmapped(self, element, param_name, param_value):
        """
        Set the Gurobi attribute `param_name` for a model, variable or constraint.

        For a model, `param_value` is passed to `setAttr` as is. For variables and
        constraints, `param_value` is a DataFrame keyed by `VAR_KEY` /
        `CONSTRAINT_KEY` with the values in a column named `param_name`; rows are
        matched to Gurobi objects by name.

        Raises:
            ValueError: if `element` is none of the three supported types.
        """
        assert self.solver_model is not None
        if isinstance(element, pf.Model):
            self.solver_model.setAttr(param_name, param_value)
        elif isinstance(element, pf.Variable):
            v, v_map = self._get_var_mapping()
            # Replace each name by its position in `v`.
            param_value = param_value.join(v_map, on=VAR_KEY, how="left").drop(VAR_KEY)
            self.solver_model.setAttr(
                param_name,
                [v[i] for i in param_value["i"]],
                param_value[param_name],
            )
        elif isinstance(element, pf.Constraint):
            c, c_map = self._get_constraint_mapping()
            # Replace each name by its position in `c`.
            param_value = param_value.join(c_map, on=CONSTRAINT_KEY, how="left").drop(
                CONSTRAINT_KEY
            )
            self.solver_model.setAttr(
                param_name,
                [c[i] for i in param_value["i"]],
                param_value[param_name],
            )
        else:
            raise ValueError(f"Element type {type(element)} not recognized.")

    def solve(self, log_fn, warmstart_fn, basis_fn, solution_file) -> Result:
        """
        Optimize the Gurobi model and collect the outcome.

        Parameters:
            log_fn : Path or str, optional
                If not None, set as Gurobi's "logfile" parameter.
            warmstart_fn : Path or str, optional
                If given, read into the model before optimizing.
            basis_fn : Path or str, optional
                If given, the model is written to this path after optimizing;
                a GurobiError while writing is reported and otherwise ignored.
            solution_file : Path or str, optional
                If given and the status is ok, the model is written to this path.

        Returns:
            A `Result` with the status and, when the status is ok, a `Solution`
            holding the primal values, the dual values (None if Gurobi raises
            when reading "Pi") and the objective value.
        """
        assert self.solver_model is not None
        m = self.solver_model
        if log_fn is not None:
            m.setParam("logfile", _path_to_str(log_fn))
        if warmstart_fn:
            m.read(_path_to_str(warmstart_fn))

        m.optimize()

        if basis_fn:
            try:
                m.write(_path_to_str(basis_fn))
            except gurobipy.GurobiError as err:
                # Best effort: a missing basis must not abort the solve.
                print(f"No model basis stored. Raised error: {err}")

        condition = m.status
        # Unknown status codes are passed through unchanged.
        termination_condition = GurobiSolver.CONDITION_MAP.get(condition, condition)
        status = Status.from_termination_condition(termination_condition)

        if status.is_ok:
            if solution_file:
                m.write(_path_to_str(solution_file))

            objective = m.ObjVal
            gurobi_vars = m.getVars()
            sol = pl.DataFrame(
                {
                    VAR_KEY: m.getAttr("VarName", gurobi_vars),
                    SOLUTION_KEY: m.getAttr("X", gurobi_vars),
                }
            )

            constraints = m.getConstrs()
            try:
                dual = pl.DataFrame(
                    {
                        DUAL_KEY: m.getAttr("Pi", constraints),
                        CONSTRAINT_KEY: m.getAttr("ConstrName", constraints),
                    }
                )
            except gurobipy.GurobiError:
                # Gurobi could not provide "Pi"; report no duals.
                dual = None

            solution = Solution(sol, dual, objective)
        else:
            solution = None

        return Result(status, solution)

    def _get_all_rc_unmapped(self):
        """Return a DataFrame with Gurobi's "RC" attribute and name of every variable."""
        m = self._model.solver_model
        gurobi_vars = m.getVars()
        return pl.DataFrame(
            {
                RC_COL: m.getAttr("RC", gurobi_vars),
                VAR_KEY: m.getAttr("VarName", gurobi_vars),
            }
        )

    def _get_all_slack_unmapped(self):
        """Return a DataFrame with Gurobi's "Slack" attribute and name of every constraint."""
        m = self._model.solver_model
        constraints = m.getConstrs()
        return pl.DataFrame(
            {
                SLACK_COL: m.getAttr("Slack", constraints),
                CONSTRAINT_KEY: m.getAttr("ConstrName", constraints),
            }
        )

    def dispose(self):
        """Call .dispose() on the Gurobi model and environment, if they exist."""
        if self.solver_model is not None:
            self.solver_model.dispose()
        if self.env is not None:
            self.env.dispose()
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _path_to_str(path: Union[Path, str]) -> str:
|
|
183
374
|
"""
|
|
184
375
|
Convert a pathlib.Path to a string.
|
|
185
376
|
"""
|
pyoframe/user_defined.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Contains the base classes to support .params and .attr containers for user-defined parameters and attributes.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class _ContainerAttributeError(AttributeError, KeyError):
    """
    Raised when a name was never stored in a Container.

    It is an AttributeError so that hasattr() and getattr() with a default work,
    and a KeyError so that code written against the previous behaviour
    (`except KeyError`) keeps working.
    """

    def __str__(self) -> str:
        # KeyError would otherwise wrap the message in quotes.
        return str(self.args[0]) if self.args else ""


class Container:
    """
    A container for user-defined attributes or parameters.

    Names starting with an underscore are stored as regular instance attributes;
    every other name is stored in an internal dictionary, in assignment order.
    Iterating the container yields `(name, value)` pairs.

    Parameters:
        preprocess : Callable[str, Any], optional
            A function to preprocess user-defined values before adding them to the container.

    Examples:
        >>> params = Container()
        >>> params.a = 1
        >>> params.b = 2
        >>> params.a
        1
        >>> params.b
        2
        >>> for k, v in params:
        ...     print(k, v)
        a 1
        b 2
        >>> hasattr(params, "c")
        False
    """

    def __init__(self, preprocess=None):
        self._preprocess = preprocess
        self._attributes = {}

    def __setattr__(self, name: str, value: Any) -> None:
        # Private names bypass the container and live on the instance itself.
        if name.startswith("_"):
            return super().__setattr__(name, value)
        if self._preprocess is not None:
            value = self._preprocess(name, value)
        self._attributes[name] = value

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup failed, so a private name reaching this
        # point is genuinely missing and the call below raises AttributeError.
        if name.startswith("_"):
            return super().__getattribute__(name)
        try:
            return self._attributes[name]
        except KeyError:
            raise _ContainerAttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __iter__(self):
        return iter(self._attributes.items())
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class AttrContainerMixin:
    """
    Mixin that gives an object an `.attr` container for user-defined attributes.

    Values assigned on `.attr` are first passed through `_preprocess_attr`.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Cooperative init: forward everything to the next class in the MRO.
        super().__init__(*args, **kwargs)
        self.attr = Container(preprocess=self._preprocess_attr)

    def _preprocess_attr(self, name: str, value: Any) -> Any:
        """
        Preprocess a user-defined value before it is added to the `.attr` container.

        The default implementation returns `value` unchanged; subclasses can
        override it.
        """
        return value
|
pyoframe/util.py
CHANGED
|
@@ -2,68 +2,17 @@
|
|
|
2
2
|
File containing utility functions and classes.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from
|
|
5
|
+
from typing import Any, Iterable, Optional, Union, List, Dict
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
8
|
|
|
9
9
|
import polars as pl
|
|
10
10
|
import pandas as pd
|
|
11
|
+
from functools import wraps
|
|
11
12
|
|
|
12
13
|
from pyoframe.constants import COEF_KEY, CONST_TERM, RESERVED_COL_KEYS, VAR_KEY
|
|
13
14
|
|
|
14
15
|
|
|
15
|
-
class IdCounterMixin(ABC):
|
|
16
|
-
"""
|
|
17
|
-
Provides a method that assigns a unique ID to each row in a DataFrame.
|
|
18
|
-
IDs start at 1 and go up consecutively. No zero ID is assigned since it is reserved for the constant variable term.
|
|
19
|
-
IDs are only unique for the subclass since different subclasses have different counters.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
# Keys are the subclass names and values are the next unasigned ID.
|
|
23
|
-
_id_counters: Dict[str, int] = defaultdict(lambda: 1)
|
|
24
|
-
|
|
25
|
-
@classmethod
|
|
26
|
-
def _reset_counters(cls):
|
|
27
|
-
"""
|
|
28
|
-
Resets all the ID counters.
|
|
29
|
-
This function is called before every unit test to reset the code state.
|
|
30
|
-
"""
|
|
31
|
-
cls._id_counters = defaultdict(lambda: 1)
|
|
32
|
-
|
|
33
|
-
def _assign_ids(self, df: pl.DataFrame) -> pl.DataFrame:
|
|
34
|
-
"""
|
|
35
|
-
Adds the column `to_column` to the DataFrame `df` with the next batch
|
|
36
|
-
of unique consecutive IDs.
|
|
37
|
-
"""
|
|
38
|
-
cls_name = self.__class__.__name__
|
|
39
|
-
cur_count = self._id_counters[cls_name]
|
|
40
|
-
id_col_name = self.get_id_column_name()
|
|
41
|
-
|
|
42
|
-
if df.height == 0:
|
|
43
|
-
df = df.with_columns(pl.lit(cur_count).alias(id_col_name))
|
|
44
|
-
else:
|
|
45
|
-
df = df.with_columns(
|
|
46
|
-
pl.int_range(cur_count, cur_count + pl.len()).alias(id_col_name)
|
|
47
|
-
)
|
|
48
|
-
df = df.with_columns(pl.col(id_col_name).cast(pl.UInt32))
|
|
49
|
-
self._id_counters[cls_name] += df.height
|
|
50
|
-
return df
|
|
51
|
-
|
|
52
|
-
@classmethod
|
|
53
|
-
@abstractmethod
|
|
54
|
-
def get_id_column_name(cls) -> str:
|
|
55
|
-
"""
|
|
56
|
-
Returns the name of the column containing the IDs.
|
|
57
|
-
"""
|
|
58
|
-
|
|
59
|
-
@property
|
|
60
|
-
@abstractmethod
|
|
61
|
-
def ids(self) -> pl.DataFrame:
|
|
62
|
-
"""
|
|
63
|
-
Returns a dataframe with the IDs and any other relevant columns (i.e. the dimension columns).
|
|
64
|
-
"""
|
|
65
|
-
|
|
66
|
-
|
|
67
16
|
def get_obj_repr(obj: object, _props: Iterable[str] = (), **kwargs):
|
|
68
17
|
"""
|
|
69
18
|
Helper function to generate __repr__ strings for classes. See usage for examples.
|
|
@@ -269,3 +218,55 @@ def cast_coef_to_string(
|
|
|
269
218
|
return df.with_columns(pl.concat_str("_sign", column_name).alias(column_name)).drop(
|
|
270
219
|
"_sign"
|
|
271
220
|
)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def unwrap_single_values(func):
    """
    Decorator for functions that return DataFrames.

    When the wrapped function returns a polars DataFrame of shape (1, 1), the
    single value it holds is returned instead. Any other result is passed
    through untouched.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        holds_single_value = isinstance(outcome, pl.DataFrame) and outcome.shape == (
            1,
            1,
        )
        return outcome.item() if holds_single_value else outcome

    return wrapper
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def dataframe_to_tupled_list(
    df: "pl.DataFrame", num_max_elements: Optional[int] = None
) -> str:
    """
    Converts a dataframe into a list of tuples. Used to print a Set to the console. See examples for behaviour.

    Rows of a single-column dataframe are shown as bare values, rows of wider
    dataframes as tuples. When the dataframe has more than `num_max_elements`
    rows, only the first `num_max_elements` are shown, followed by `...`.

    Examples:
        >>> df = pl.DataFrame({"x": [1, 2, 3, 4, 5]})
        >>> dataframe_to_tupled_list(df)
        '[1, 2, 3, 4, 5]'
        >>> dataframe_to_tupled_list(df, 3)
        '[1, 2, 3, ...]'
        >>> dataframe_to_tupled_list(df, 0)
        '[...]'

        >>> df = pl.DataFrame({"x": [1, 2, 3, 4, 5], "y": [2, 3, 4, 5, 6]})
        >>> dataframe_to_tupled_list(df, 3)
        '[(1, 2), (2, 3), (3, 4), ...]'
    """
    truncated = num_max_elements is not None and len(df) > num_max_elements
    if truncated:
        df = df.head(num_max_elements)

    if len(df.columns) == 1:
        items = [repr(row[0]) for row in df.iter_rows()]
    else:
        items = [repr(row) for row in df.iter_rows()]

    # Appending the marker as its own item keeps the separators right even
    # when no rows are shown.
    if truncated:
        items.append("...")
    return "[" + ", ".join(items) + "]"
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@dataclass
class FuncArgs:
    """
    Holds a list of positional arguments together with a dict of keyword arguments.

    NOTE(review): presumably used to record the arguments of a function call;
    the usage is not visible in this file.
    """

    # Positional arguments; required.
    args: List
    # Keyword arguments; every instance gets its own empty dict by default.
    kwargs: Dict = field(default_factory=dict)
|