pyoframe 1.0.0a0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoframe/_arithmetic.py +176 -174
- pyoframe/_constants.py +94 -47
- pyoframe/_core.py +225 -148
- pyoframe/_model.py +46 -26
- pyoframe/_model_element.py +32 -18
- pyoframe/_monkey_patch.py +15 -13
- pyoframe/_objective.py +2 -4
- pyoframe/_utils.py +8 -9
- pyoframe/_version.py +2 -2
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.0.1.dist-info}/METADATA +11 -12
- pyoframe-1.0.1.dist-info/RECORD +15 -0
- pyoframe-1.0.0a0.dist-info/RECORD +0 -15
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.0.1.dist-info}/WHEEL +0 -0
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.0.1.dist-info}/top_level.txt +0 -0
pyoframe/_arithmetic.py
CHANGED
|
@@ -9,18 +9,27 @@ import polars as pl
|
|
|
9
9
|
from pyoframe._constants import (
|
|
10
10
|
COEF_KEY,
|
|
11
11
|
CONST_TERM,
|
|
12
|
-
KEY_TYPE,
|
|
13
12
|
QUAD_VAR_KEY,
|
|
14
13
|
RESERVED_COL_KEYS,
|
|
15
14
|
VAR_KEY,
|
|
16
15
|
Config,
|
|
16
|
+
ExtrasStrategy,
|
|
17
17
|
PyoframeError,
|
|
18
|
-
UnmatchedStrategy,
|
|
19
18
|
)
|
|
20
19
|
|
|
21
20
|
if TYPE_CHECKING: # pragma: no cover
|
|
22
21
|
from pyoframe._core import Expression
|
|
23
22
|
|
|
23
|
+
# Mapping of how a sum of two expressions should propagate the extras strategy
|
|
24
|
+
_extras_propagation_rules = {
|
|
25
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.DROP): ExtrasStrategy.DROP,
|
|
26
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.UNSET): ExtrasStrategy.UNSET,
|
|
27
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.KEEP): ExtrasStrategy.KEEP,
|
|
28
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.KEEP): ExtrasStrategy.UNSET,
|
|
29
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.UNSET): ExtrasStrategy.DROP,
|
|
30
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.UNSET): ExtrasStrategy.KEEP,
|
|
31
|
+
}
|
|
32
|
+
|
|
24
33
|
|
|
25
34
|
def multiply(self: Expression, other: Expression) -> Expression:
|
|
26
35
|
"""Multiplies two expressions together.
|
|
@@ -160,184 +169,82 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
|
|
|
160
169
|
|
|
161
170
|
def add(*expressions: Expression) -> Expression:
|
|
162
171
|
"""Add multiple expressions together."""
|
|
163
|
-
# Mapping of how a sum of two expressions should propagate the unmatched strategy
|
|
164
|
-
propagation_strategies = {
|
|
165
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
|
|
166
|
-
(
|
|
167
|
-
UnmatchedStrategy.UNSET,
|
|
168
|
-
UnmatchedStrategy.UNSET,
|
|
169
|
-
): UnmatchedStrategy.UNSET,
|
|
170
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.KEEP): UnmatchedStrategy.KEEP,
|
|
171
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP): UnmatchedStrategy.UNSET,
|
|
172
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET): UnmatchedStrategy.DROP,
|
|
173
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET): UnmatchedStrategy.KEEP,
|
|
174
|
-
}
|
|
175
|
-
|
|
176
172
|
assert len(expressions) > 1, "Need at least two expressions to add together."
|
|
177
173
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if dims is None:
|
|
181
|
-
requires_join = False
|
|
182
|
-
dims = []
|
|
183
|
-
elif Config.disable_unmatched_checks:
|
|
184
|
-
requires_join = any(
|
|
185
|
-
expr._unmatched_strategy
|
|
186
|
-
not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
|
|
187
|
-
for expr in expressions
|
|
188
|
-
)
|
|
174
|
+
if Config.disable_extras_checks:
|
|
175
|
+
no_checks_strats = (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET)
|
|
189
176
|
else:
|
|
190
|
-
|
|
191
|
-
expr._unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
|
|
192
|
-
)
|
|
177
|
+
no_checks_strats = (ExtrasStrategy.KEEP,)
|
|
193
178
|
|
|
194
|
-
|
|
195
|
-
|
|
179
|
+
no_extras_checks_required = (
|
|
180
|
+
all(expr._extras_strategy in no_checks_strats for expr in expressions)
|
|
181
|
+
# if only one dimensioned, then there is no such thing as extra labels,
|
|
182
|
+
# labels will be set by the only dimensioned expression
|
|
183
|
+
or sum(not expr.dimensionless for expr in expressions) <= 1
|
|
196
184
|
)
|
|
197
185
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
for expr in expressions[1:]:
|
|
202
|
-
result = add(result, expr)
|
|
203
|
-
return result
|
|
204
|
-
|
|
205
|
-
if has_dim_conflict:
|
|
206
|
-
assert len(expressions) == 2
|
|
207
|
-
|
|
208
|
-
left, right = expressions[0], expressions[1]
|
|
209
|
-
left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
|
|
210
|
-
|
|
211
|
-
missing_left = [dim for dim in right_dims if dim not in left_dims]
|
|
212
|
-
missing_right = [dim for dim in left_dims if dim not in right_dims]
|
|
213
|
-
common_dims = [dim for dim in left_dims if dim in right_dims]
|
|
214
|
-
|
|
215
|
-
if not (
|
|
216
|
-
set(missing_left) <= set(left._allowed_new_dims)
|
|
217
|
-
and set(missing_right) <= set(right._allowed_new_dims)
|
|
218
|
-
):
|
|
219
|
-
_raise_addition_error(
|
|
220
|
-
left,
|
|
221
|
-
right,
|
|
222
|
-
f"their\n\tdimensions are different ({left_dims} != {right_dims})",
|
|
223
|
-
"If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/learn/concepts/special-functions/#adding-expressions-with-differing-dimensions-using-over",
|
|
224
|
-
)
|
|
225
|
-
|
|
226
|
-
left_old = left
|
|
227
|
-
if missing_left:
|
|
228
|
-
left = _broadcast(left, right, common_dims, missing_left)
|
|
229
|
-
if missing_right:
|
|
230
|
-
right = _broadcast(
|
|
231
|
-
right, left_old, common_dims, missing_right, swapped=True
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
|
|
235
|
-
expressions = (left, right)
|
|
236
|
-
|
|
237
|
-
dims = expressions[0]._dimensions_unsafe
|
|
238
|
-
# Check no dims conflict
|
|
239
|
-
assert all(
|
|
240
|
-
sorted(dims) == sorted(expr._dimensions_unsafe) for expr in expressions[1:]
|
|
186
|
+
has_dim_conflict = any(
|
|
187
|
+
sorted(expressions[0]._dimensions_unsafe) != sorted(expr._dimensions_unsafe)
|
|
188
|
+
for expr in expressions[1:]
|
|
241
189
|
)
|
|
242
|
-
if requires_join:
|
|
243
|
-
assert len(expressions) == 2
|
|
244
|
-
assert dims != []
|
|
245
|
-
left, right = expressions[0], expressions[1]
|
|
246
|
-
|
|
247
|
-
# Order so that drop always comes before keep, and keep always comes before default
|
|
248
|
-
if swap := (
|
|
249
|
-
(left._unmatched_strategy, right._unmatched_strategy)
|
|
250
|
-
in (
|
|
251
|
-
(UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
|
|
252
|
-
(UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
|
|
253
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
|
|
254
|
-
)
|
|
255
|
-
):
|
|
256
|
-
left, right = right, left
|
|
257
|
-
|
|
258
|
-
def get_indices(expr):
|
|
259
|
-
return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
|
|
260
|
-
|
|
261
|
-
left_data, right_data = left.data, right.data
|
|
262
190
|
|
|
263
|
-
|
|
191
|
+
# If we cannot use .concat compute the sum in a pairwise manner, so far nobody uses this code
|
|
192
|
+
if len(expressions) > 2: # pragma: no cover
|
|
193
|
+
assert False, "This code has not been tested."
|
|
194
|
+
if has_dim_conflict or not no_extras_checks_required:
|
|
195
|
+
result = expressions[0]
|
|
196
|
+
for expr in expressions[1:]:
|
|
197
|
+
result = add(result, expr)
|
|
198
|
+
return result
|
|
199
|
+
propagate_strat = expressions[0]._extras_strategy
|
|
200
|
+
dims = expressions[0]._dimensions_unsafe
|
|
201
|
+
expr_data = [expr.data for expr in expressions]
|
|
202
|
+
else:
|
|
203
|
+
left, right = expressions[0], expressions[1]
|
|
264
204
|
|
|
265
|
-
|
|
205
|
+
if has_dim_conflict:
|
|
206
|
+
left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
|
|
266
207
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
on=dims,
|
|
271
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
272
|
-
)
|
|
273
|
-
right_data = right.data.join(
|
|
274
|
-
get_indices(left),
|
|
275
|
-
on=dims,
|
|
276
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
277
|
-
)
|
|
278
|
-
elif strat == (UnmatchedStrategy.UNSET, UnmatchedStrategy.UNSET):
|
|
279
|
-
assert not Config.disable_unmatched_checks, (
|
|
280
|
-
"This code should not be reached when unmatched checks are disabled."
|
|
281
|
-
)
|
|
282
|
-
outer_join = get_indices(left).join(
|
|
283
|
-
get_indices(right),
|
|
284
|
-
how="full",
|
|
285
|
-
on=dims,
|
|
286
|
-
maintain_order="left_right" if Config.maintain_order else None,
|
|
287
|
-
)
|
|
288
|
-
if outer_join.get_column(dims[0]).null_count() > 0:
|
|
289
|
-
unmatched_vals = outer_join.filter(
|
|
290
|
-
outer_join.get_column(dims[0]).is_null()
|
|
291
|
-
)
|
|
292
|
-
_raise_unmatched_values_error(left, right, unmatched_vals, swap)
|
|
293
|
-
if outer_join.get_column(dims[0] + "_right").null_count() > 0:
|
|
294
|
-
unmatched_vals = outer_join.filter(
|
|
295
|
-
outer_join.get_column(dims[0] + "_right").is_null()
|
|
296
|
-
)
|
|
297
|
-
_raise_unmatched_values_error(left, right, unmatched_vals, swap)
|
|
208
|
+
missing_left = [dim for dim in right_dims if dim not in left_dims]
|
|
209
|
+
missing_right = [dim for dim in left_dims if dim not in right_dims]
|
|
210
|
+
common_dims = [dim for dim in left_dims if dim in right_dims]
|
|
298
211
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
305
|
-
)
|
|
306
|
-
elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
|
|
307
|
-
left_data = get_indices(right).join(
|
|
308
|
-
left.data,
|
|
309
|
-
how="left",
|
|
310
|
-
on=dims,
|
|
311
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
312
|
-
)
|
|
313
|
-
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
314
|
-
_raise_unmatched_values_error(
|
|
212
|
+
if not (
|
|
213
|
+
set(missing_left) <= set(left._allowed_new_dims)
|
|
214
|
+
and set(missing_right) <= set(right._allowed_new_dims)
|
|
215
|
+
):
|
|
216
|
+
_raise_addition_error(
|
|
315
217
|
left,
|
|
316
218
|
right,
|
|
317
|
-
|
|
318
|
-
|
|
219
|
+
f"their\n\tdimensions are different ({left_dims} != {right_dims})",
|
|
220
|
+
"If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition/#adding-expressions-with-differing-dimensions-using-over",
|
|
319
221
|
)
|
|
320
222
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
else: # pragma: no cover
|
|
329
|
-
assert False, "This code should've never been reached!"
|
|
223
|
+
left_old = left
|
|
224
|
+
if missing_left:
|
|
225
|
+
left = _broadcast(left, right, common_dims, missing_left)
|
|
226
|
+
if missing_right:
|
|
227
|
+
right = _broadcast(
|
|
228
|
+
right, left_old, common_dims, missing_right, swapped=True
|
|
229
|
+
)
|
|
330
230
|
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
231
|
+
assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
|
|
232
|
+
|
|
233
|
+
dims = left._dimensions_unsafe
|
|
234
|
+
|
|
235
|
+
if not no_extras_checks_required:
|
|
236
|
+
expr_data, propagate_strat = _handle_extra_labels(left, right, dims)
|
|
237
|
+
else:
|
|
238
|
+
propagate_strat = left._extras_strategy
|
|
239
|
+
expr_data = (left.data, right.data)
|
|
335
240
|
|
|
336
241
|
# Add quadratic column if it is needed and doesn't already exist
|
|
337
242
|
if any(QUAD_VAR_KEY in df.columns for df in expr_data):
|
|
338
243
|
expr_data = [
|
|
339
244
|
(
|
|
340
|
-
df.with_columns(
|
|
245
|
+
df.with_columns(
|
|
246
|
+
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
|
|
247
|
+
)
|
|
341
248
|
if QUAD_VAR_KEY not in df.columns
|
|
342
249
|
else df
|
|
343
250
|
)
|
|
@@ -358,23 +265,116 @@ def add(*expressions: Expression) -> Expression:
|
|
|
358
265
|
full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
|
|
359
266
|
|
|
360
267
|
new_expr = expressions[0]._new(data, name=f"({full_name})")
|
|
361
|
-
new_expr._unmatched_strategy = propagate_strat
|
|
268
|
+
new_expr._extras_strategy = propagate_strat
|
|
362
269
|
|
|
363
270
|
return new_expr
|
|
364
271
|
|
|
365
272
|
|
|
366
|
-
def _raise_unmatched_values_error(
|
|
367
|
-
left: Expression, right: Expression,
|
|
273
|
+
def _handle_extra_labels(
|
|
274
|
+
left: Expression, right: Expression, dims: list[str]
|
|
275
|
+
) -> tuple[tuple[pl.DataFrame, pl.DataFrame], ExtrasStrategy]:
|
|
276
|
+
assert dims != []
|
|
277
|
+
# Order so that drop always comes before keep, and keep always comes before default
|
|
278
|
+
if swapped := (
|
|
279
|
+
(left._extras_strategy, right._extras_strategy)
|
|
280
|
+
in (
|
|
281
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.DROP),
|
|
282
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.KEEP),
|
|
283
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.DROP),
|
|
284
|
+
)
|
|
285
|
+
):
|
|
286
|
+
left, right = right, left
|
|
287
|
+
|
|
288
|
+
def get_labels(expr):
|
|
289
|
+
return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
|
|
290
|
+
|
|
291
|
+
left_data, right_data = left.data, right.data
|
|
292
|
+
|
|
293
|
+
strat = (left._extras_strategy, right._extras_strategy)
|
|
294
|
+
|
|
295
|
+
if strat == (ExtrasStrategy.DROP, ExtrasStrategy.DROP):
|
|
296
|
+
left_data = left.data.join(
|
|
297
|
+
get_labels(right),
|
|
298
|
+
on=dims,
|
|
299
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
300
|
+
)
|
|
301
|
+
right_data = right.data.join(
|
|
302
|
+
get_labels(left),
|
|
303
|
+
on=dims,
|
|
304
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
305
|
+
)
|
|
306
|
+
elif strat == (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET):
|
|
307
|
+
assert not Config.disable_extras_checks, (
|
|
308
|
+
"This code should not be reached when checks for extra values are disabled."
|
|
309
|
+
)
|
|
310
|
+
left_labels, right_labels = get_labels(left), get_labels(right)
|
|
311
|
+
left_extras = left_labels.join(right_labels, how="anti", on=dims)
|
|
312
|
+
right_extras = right_labels.join(left_labels, how="anti", on=dims)
|
|
313
|
+
if len(left_extras) > 0:
|
|
314
|
+
_raise_extras_error(
|
|
315
|
+
left, right, left_extras, swapped, extras_on_right=False
|
|
316
|
+
)
|
|
317
|
+
if len(right_extras) > 0:
|
|
318
|
+
_raise_extras_error(left, right, right_extras, swapped)
|
|
319
|
+
|
|
320
|
+
elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.KEEP):
|
|
321
|
+
left_data = get_labels(right).join(
|
|
322
|
+
left.data,
|
|
323
|
+
on=dims,
|
|
324
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
325
|
+
)
|
|
326
|
+
elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.UNSET):
|
|
327
|
+
right_labels = get_labels(right)
|
|
328
|
+
left_data = right_labels.join(
|
|
329
|
+
left.data,
|
|
330
|
+
how="left",
|
|
331
|
+
on=dims,
|
|
332
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
333
|
+
)
|
|
334
|
+
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
335
|
+
_raise_extras_error(
|
|
336
|
+
left,
|
|
337
|
+
right,
|
|
338
|
+
right_labels.join(get_labels(left), how="anti", on=dims),
|
|
339
|
+
swapped,
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
elif strat == (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET):
|
|
343
|
+
assert not Config.disable_extras_checks, (
|
|
344
|
+
"This code should not be reached when checks for extra values are disabled."
|
|
345
|
+
)
|
|
346
|
+
extras = right.data.join(get_labels(left), how="anti", on=dims)
|
|
347
|
+
if len(extras) > 0:
|
|
348
|
+
_raise_extras_error(left, right, extras.select(dims), swapped)
|
|
349
|
+
else: # pragma: no cover
|
|
350
|
+
assert False, "This code should've never been reached!"
|
|
351
|
+
|
|
352
|
+
if swapped:
|
|
353
|
+
left_data, right_data = right_data, left_data
|
|
354
|
+
|
|
355
|
+
return (left_data, right_data), _extras_propagation_rules[strat]
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _raise_extras_error(
|
|
359
|
+
left: Expression,
|
|
360
|
+
right: Expression,
|
|
361
|
+
extra_labels: pl.DataFrame,
|
|
362
|
+
swapped: bool,
|
|
363
|
+
extras_on_right: bool = True,
|
|
368
364
|
):
|
|
369
365
|
if swapped:
|
|
370
366
|
left, right = right, left
|
|
367
|
+
extras_on_right = not extras_on_right
|
|
371
368
|
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
369
|
+
expression_num = 2 if extras_on_right else 1
|
|
370
|
+
|
|
371
|
+
with Config.print_polars_config:
|
|
372
|
+
_raise_addition_error(
|
|
373
|
+
left,
|
|
374
|
+
right,
|
|
375
|
+
f"expression {expression_num} has extra labels",
|
|
376
|
+
f"Extra labels in expression {expression_num}:\n{extra_labels}\nUse .drop_extras() or .keep_extras() to indicate how the extra labels should be handled. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition",
|
|
377
|
+
)
|
|
378
378
|
|
|
379
379
|
|
|
380
380
|
def _raise_addition_error(
|
|
@@ -412,7 +412,7 @@ def _broadcast(
|
|
|
412
412
|
return res
|
|
413
413
|
|
|
414
414
|
# If drop, we just do an inner join to get into the shape of the other
|
|
415
|
-
if self._unmatched_strategy == UnmatchedStrategy.DROP:
|
|
415
|
+
if self._extras_strategy == ExtrasStrategy.DROP:
|
|
416
416
|
res = self._new(
|
|
417
417
|
self.data.join(
|
|
418
418
|
target_data,
|
|
@@ -430,12 +430,14 @@ def _broadcast(
|
|
|
430
430
|
how="left",
|
|
431
431
|
maintain_order="left" if Config.maintain_order else None,
|
|
432
432
|
)
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
433
|
+
if result.get_column(missing_dims[0]).null_count() > 0:
|
|
434
|
+
target_labels = target.data.select(target._dimensions_unsafe).unique(
|
|
435
|
+
maintain_order=Config.maintain_order
|
|
436
|
+
)
|
|
437
|
+
_raise_extras_error(
|
|
436
438
|
self,
|
|
437
439
|
target,
|
|
438
|
-
|
|
440
|
+
target_labels.join(self.data, how="anti", on=common_dims),
|
|
439
441
|
swapped,
|
|
440
442
|
)
|
|
441
443
|
res = self._new(result, self.name)
|
|
@@ -522,7 +524,7 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
522
524
|
if df.is_empty():
|
|
523
525
|
df = pl.DataFrame(
|
|
524
526
|
{VAR_KEY: [CONST_TERM], COEF_KEY: [0]},
|
|
525
|
-
schema={VAR_KEY: KEY_TYPE, COEF_KEY: pl.Float64},
|
|
527
|
+
schema={VAR_KEY: Config.id_dtype, COEF_KEY: pl.Float64},
|
|
526
528
|
)
|
|
527
529
|
|
|
528
530
|
if QUAD_VAR_KEY in df.columns and (df.get_column(QUAD_VAR_KEY) == CONST_TERM).all():
|
pyoframe/_constants.py
CHANGED
|
@@ -17,41 +17,54 @@ CONSTRAINT_KEY = "__constraint_id"
|
|
|
17
17
|
SOLUTION_KEY = "solution"
|
|
18
18
|
DUAL_KEY = "dual"
|
|
19
19
|
|
|
20
|
-
# TODO: move as configuration since this could be too small... also add a test to make sure errors occur on overflow.
|
|
21
|
-
KEY_TYPE = pl.UInt32
|
|
22
|
-
|
|
23
20
|
|
|
24
21
|
@dataclass
|
|
25
22
|
class _Solver:
|
|
26
23
|
name: SUPPORTED_SOLVER_TYPES
|
|
27
24
|
supports_integer_variables: bool = True
|
|
28
|
-
|
|
25
|
+
supports_quadratic_constraints: bool = True
|
|
26
|
+
supports_non_convex: bool = True
|
|
29
27
|
supports_duals: bool = True
|
|
30
28
|
supports_objective_sense: bool = True
|
|
31
29
|
supports_write: bool = True
|
|
32
|
-
|
|
30
|
+
accelerate_with_repeat_names: bool = False
|
|
33
31
|
"""
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
that
|
|
38
|
-
|
|
39
|
-
|
|
32
|
+
If True, Pyoframe sets all the variable and constraint names to 'V'
|
|
33
|
+
and 'C', respectively, which, for some solvers, was found to improve
|
|
34
|
+
performance. This setting should only be enabled for a given solver after
|
|
35
|
+
testing that a) it actually improves performance, and b) the solver can
|
|
36
|
+
handle conflicting variable and constraint names.
|
|
37
|
+
So far, only Gurobi has been tested.
|
|
38
|
+
Note, that when enabled, Model.write() is not supported
|
|
39
|
+
(unless solver_uses_variable_names=True) because the outputted files would
|
|
40
|
+
be meaningless as all variables/constraints would have identical names.
|
|
40
41
|
"""
|
|
41
42
|
|
|
43
|
+
def __post_init__(self):
|
|
44
|
+
if self.supports_non_convex:
|
|
45
|
+
assert self.supports_quadratic_constraints, (
|
|
46
|
+
"Non-convex solvers typically support quadratic constraints. Are you sure this is correct?"
|
|
47
|
+
)
|
|
48
|
+
|
|
42
49
|
def __repr__(self):
|
|
43
50
|
return self.name
|
|
44
51
|
|
|
45
52
|
|
|
46
53
|
SUPPORTED_SOLVERS = [
|
|
47
|
-
_Solver("gurobi",
|
|
48
|
-
_Solver(
|
|
54
|
+
_Solver("gurobi", accelerate_with_repeat_names=True),
|
|
55
|
+
_Solver(
|
|
56
|
+
"highs",
|
|
57
|
+
supports_quadratic_constraints=False,
|
|
58
|
+
supports_non_convex=False,
|
|
59
|
+
supports_duals=False,
|
|
60
|
+
),
|
|
49
61
|
_Solver(
|
|
50
62
|
"ipopt",
|
|
51
63
|
supports_integer_variables=False,
|
|
52
64
|
supports_objective_sense=False,
|
|
53
65
|
supports_write=False,
|
|
54
66
|
),
|
|
67
|
+
_Solver("copt", supports_non_convex=False),
|
|
55
68
|
]
|
|
56
69
|
|
|
57
70
|
|
|
@@ -71,7 +84,7 @@ RESERVED_COL_KEYS = (
|
|
|
71
84
|
@dataclass
|
|
72
85
|
class ConfigDefaults:
|
|
73
86
|
default_solver: SUPPORTED_SOLVER_TYPES | _Solver | Literal["raise", "auto"] = "auto"
|
|
74
|
-
|
|
87
|
+
disable_extras_checks: bool = False
|
|
75
88
|
enable_is_duplicated_expression_safety_check: bool = False
|
|
76
89
|
integer_tolerance: float = 1e-8
|
|
77
90
|
float_to_str_precision: int | None = 5
|
|
@@ -85,6 +98,7 @@ class ConfigDefaults:
|
|
|
85
98
|
)
|
|
86
99
|
print_max_terms: int = 5
|
|
87
100
|
maintain_order: bool = True
|
|
101
|
+
id_dtype = pl.UInt32
|
|
88
102
|
|
|
89
103
|
|
|
90
104
|
class _Config:
|
|
@@ -114,16 +128,16 @@ class _Config:
|
|
|
114
128
|
self._settings.default_solver = value
|
|
115
129
|
|
|
116
130
|
@property
|
|
117
|
-
def disable_unmatched_checks(self) -> bool:
|
|
118
|
-
"""When `True`, improves performance by skipping
|
|
131
|
+
def disable_extras_checks(self) -> bool:
|
|
132
|
+
"""When `True`, improves performance by skipping checks for extra values (not recommended).
|
|
119
133
|
|
|
120
|
-
When `True`,
|
|
121
|
-
are treated as if they contained [`.
|
|
122
|
-
(unless [`.
|
|
134
|
+
When `True`, checks for extra values are disabled which effectively means that all expressions
|
|
135
|
+
are treated as if they contained [`.keep_extras()`][pyoframe.Expression.keep_extras]
|
|
136
|
+
(unless [`.drop_extras()`][pyoframe.Expression.drop_extras] was applied).
|
|
123
137
|
|
|
124
138
|
!!! warning
|
|
125
|
-
This might improve performance, but it will suppress the
|
|
126
|
-
behaviors (
|
|
139
|
+
This might improve performance, but it will suppress the errors that alert you of unexpected
|
|
140
|
+
behaviors ([learn more](../../learn/concepts/addition.md)).
|
|
127
141
|
Only consider enabling after you have thoroughly tested your code.
|
|
128
142
|
|
|
129
143
|
Examples:
|
|
@@ -141,26 +155,24 @@ class _Config:
|
|
|
141
155
|
... }
|
|
142
156
|
... ).to_expr()
|
|
143
157
|
|
|
144
|
-
Normally, an error warns users that the two expressions have conflicting
|
|
158
|
+
Normally, an error warns users that the two expressions have conflicting labels:
|
|
145
159
|
>>> population + population_influx
|
|
146
160
|
Traceback (most recent call last):
|
|
147
161
|
...
|
|
148
|
-
pyoframe._constants.PyoframeError: Cannot add the two expressions below because
|
|
162
|
+
pyoframe._constants.PyoframeError: Cannot add the two expressions below because expression 1 has extra labels.
|
|
149
163
|
Expression 1: pop
|
|
150
164
|
Expression 2: influx
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
│
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
But if `Config.disable_unmatched_checks = True`, the error is suppressed and the sum is considered to be `population.keep_unmatched() + population_influx.keep_unmatched()`:
|
|
163
|
-
>>> pf.Config.disable_unmatched_checks = True
|
|
165
|
+
Extra labels in expression 1:
|
|
166
|
+
┌──────────┐
|
|
167
|
+
│ city │
|
|
168
|
+
╞══════════╡
|
|
169
|
+
│ Montreal │
|
|
170
|
+
└──────────┘
|
|
171
|
+
Use .drop_extras() or .keep_extras() to indicate how the extra labels should be handled. Learn more at
|
|
172
|
+
https://bravos-power.github.io/pyoframe/latest/learn/concepts/addition
|
|
173
|
+
|
|
174
|
+
But if `Config.disable_extras_checks = True`, the error is suppressed and the sum is considered to be `population.keep_extras() + population_influx.keep_extras()`:
|
|
175
|
+
>>> pf.Config.disable_extras_checks = True
|
|
164
176
|
>>> population + population_influx
|
|
165
177
|
<Expression height=3 terms=3 type=constant>
|
|
166
178
|
┌───────────┬────────────┐
|
|
@@ -172,11 +184,11 @@ class _Config:
|
|
|
172
184
|
│ Montreal ┆ 1704694 │
|
|
173
185
|
└───────────┴────────────┘
|
|
174
186
|
"""
|
|
175
|
-
return self._settings.disable_unmatched_checks
|
|
187
|
+
return self._settings.disable_extras_checks
|
|
176
188
|
|
|
177
|
-
@disable_unmatched_checks.setter
|
|
178
|
-
def disable_unmatched_checks(self, value: bool):
|
|
179
|
-
self._settings.disable_unmatched_checks = value
|
|
189
|
+
@disable_extras_checks.setter
|
|
190
|
+
def disable_extras_checks(self, value: bool):
|
|
191
|
+
self._settings.disable_extras_checks = value
|
|
180
192
|
|
|
181
193
|
@property
|
|
182
194
|
def enable_is_duplicated_expression_safety_check(self) -> bool:
|
|
@@ -308,17 +320,52 @@ class _Config:
|
|
|
308
320
|
def maintain_order(self, value: bool):
|
|
309
321
|
self._settings.maintain_order = value
|
|
310
322
|
|
|
323
|
+
@property
|
|
324
|
+
def id_dtype(self):
|
|
325
|
+
"""The Polars data type to use for variable and constraint IDs.
|
|
326
|
+
|
|
327
|
+
Defaults to `pl.UInt32` which should be ideal for most users.
|
|
328
|
+
|
|
329
|
+
Users with more than 4 billion variables or constraints can change this to `pl.UInt64`.
|
|
330
|
+
|
|
331
|
+
Users concerned with memory usage and with fewer than 65k variables or constraints can change this to `pl.UInt16`.
|
|
332
|
+
|
|
333
|
+
!!! warning
|
|
334
|
+
Changing this setting after creating a model will lead to errors.
|
|
335
|
+
You should only change this setting before creating any models.
|
|
336
|
+
|
|
337
|
+
Examples:
|
|
338
|
+
An error is automatically raised if the number of variables or constraints exceeds the chosen data type:
|
|
339
|
+
>>> pf.Config.id_dtype = pl.UInt8
|
|
340
|
+
>>> m = pf.Model()
|
|
341
|
+
>>> big_set = pf.Set(x=range(2**8 + 1))
|
|
342
|
+
>>> m.X = pf.Variable()
|
|
343
|
+
>>> m.constraint = m.X.over("x") <= big_set
|
|
344
|
+
Traceback (most recent call last):
|
|
345
|
+
...
|
|
346
|
+
TypeError: Number of constraints exceeds the current data type (UInt8). Consider increasing the data type by changing Config.id_dtype.
|
|
347
|
+
>>> m.X_large = pf.Variable(big_set)
|
|
348
|
+
Traceback (most recent call last):
|
|
349
|
+
...
|
|
350
|
+
TypeError: Number of variables exceeds the current data type (UInt8). Consider increasing the data type by changing Config.id_dtype.
|
|
351
|
+
"""
|
|
352
|
+
return self._settings.id_dtype
|
|
353
|
+
|
|
354
|
+
@id_dtype.setter
|
|
355
|
+
def id_dtype(self, value):
|
|
356
|
+
self._settings.id_dtype = value
|
|
357
|
+
|
|
311
358
|
def reset_defaults(self):
|
|
312
359
|
"""Resets all configuration options to their default values.
|
|
313
360
|
|
|
314
361
|
Examples:
|
|
315
|
-
>>> pf.Config.disable_unmatched_checks
|
|
362
|
+
>>> pf.Config.disable_extras_checks
|
|
316
363
|
False
|
|
317
|
-
>>> pf.Config.disable_unmatched_checks = True
|
|
318
|
-
>>> pf.Config.disable_unmatched_checks
|
|
364
|
+
>>> pf.Config.disable_extras_checks = True
|
|
365
|
+
>>> pf.Config.disable_extras_checks
|
|
319
366
|
True
|
|
320
367
|
>>> pf.Config.reset_defaults()
|
|
321
|
-
>>> pf.Config.disable_unmatched_checks
|
|
368
|
+
>>> pf.Config.disable_extras_checks
|
|
322
369
|
False
|
|
323
370
|
"""
|
|
324
371
|
self._settings = ConfigDefaults()
|
|
@@ -389,8 +436,8 @@ class VType(Enum):
|
|
|
389
436
|
raise ValueError(f"Invalid variable type: {self}") # pragma: no cover
|
|
390
437
|
|
|
391
438
|
|
|
392
|
-
class UnmatchedStrategy(Enum):
|
|
393
|
-
"""An enum to specify how to handle unmatched values in expressions."""
|
|
439
|
+
class ExtrasStrategy(Enum):
|
|
440
|
+
"""An enum to specify how to handle extra values in expressions."""
|
|
394
441
|
|
|
395
442
|
UNSET = "not_set"
|
|
396
443
|
DROP = "drop"
|
|
@@ -404,7 +451,7 @@ VTypeValue = Literal["continuous", "binary", "integer"]
|
|
|
404
451
|
for enum, type in [(ObjSense, ObjSenseValue), (VType, VTypeValue)]:
|
|
405
452
|
assert set(typing.get_args(type)) == {vtype.value for vtype in enum}
|
|
406
453
|
|
|
407
|
-
SUPPORTED_SOLVER_TYPES = Literal["gurobi", "highs", "ipopt"]
|
|
454
|
+
SUPPORTED_SOLVER_TYPES = Literal["gurobi", "highs", "ipopt", "copt"]
|
|
408
455
|
assert set(typing.get_args(SUPPORTED_SOLVER_TYPES)) == {
|
|
409
456
|
s.name for s in SUPPORTED_SOLVERS
|
|
410
457
|
}
|