pyoframe 1.0.0a0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoframe/__init__.py +2 -0
- pyoframe/_arithmetic.py +179 -177
- pyoframe/_constants.py +103 -57
- pyoframe/_core.py +308 -204
- pyoframe/_model.py +49 -29
- pyoframe/_model_element.py +34 -18
- pyoframe/_monkey_patch.py +8 -50
- pyoframe/_objective.py +4 -6
- pyoframe/_param.py +99 -0
- pyoframe/_utils.py +10 -11
- pyoframe/_version.py +2 -2
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.1.0.dist-info}/METADATA +13 -14
- pyoframe-1.1.0.dist-info/RECORD +16 -0
- pyoframe-1.0.0a0.dist-info/RECORD +0 -15
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.1.0.dist-info}/WHEEL +0 -0
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {pyoframe-1.0.0a0.dist-info → pyoframe-1.1.0.dist-info}/top_level.txt +0 -0
pyoframe/__init__.py
CHANGED
|
@@ -11,6 +11,7 @@ from pyoframe._core import Constraint, Expression, Set, Variable, sum, sum_by
|
|
|
11
11
|
from pyoframe._model import Model
|
|
12
12
|
from pyoframe._monkey_patch import patch_dataframe_libraries
|
|
13
13
|
from pyoframe._objective import Objective
|
|
14
|
+
from pyoframe._param import Param
|
|
14
15
|
|
|
15
16
|
try:
|
|
16
17
|
from pyoframe._version import __version__, __version_tuple__ # noqa: F401
|
|
@@ -25,6 +26,7 @@ __all__ = [
|
|
|
25
26
|
"Expression",
|
|
26
27
|
"Constraint",
|
|
27
28
|
"Objective",
|
|
29
|
+
"Param",
|
|
28
30
|
"Set",
|
|
29
31
|
"Config",
|
|
30
32
|
"sum",
|
pyoframe/_arithmetic.py
CHANGED
|
@@ -9,18 +9,27 @@ import polars as pl
|
|
|
9
9
|
from pyoframe._constants import (
|
|
10
10
|
COEF_KEY,
|
|
11
11
|
CONST_TERM,
|
|
12
|
-
KEY_TYPE,
|
|
13
12
|
QUAD_VAR_KEY,
|
|
14
13
|
RESERVED_COL_KEYS,
|
|
15
14
|
VAR_KEY,
|
|
16
15
|
Config,
|
|
16
|
+
ExtrasStrategy,
|
|
17
17
|
PyoframeError,
|
|
18
|
-
UnmatchedStrategy,
|
|
19
18
|
)
|
|
20
19
|
|
|
21
20
|
if TYPE_CHECKING: # pragma: no cover
|
|
22
21
|
from pyoframe._core import Expression
|
|
23
22
|
|
|
23
|
+
# Mapping of how a sum of two expressions should propagate the extras strategy
|
|
24
|
+
_extras_propagation_rules = {
|
|
25
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.DROP): ExtrasStrategy.DROP,
|
|
26
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.UNSET): ExtrasStrategy.UNSET,
|
|
27
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.KEEP): ExtrasStrategy.KEEP,
|
|
28
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.KEEP): ExtrasStrategy.UNSET,
|
|
29
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.UNSET): ExtrasStrategy.DROP,
|
|
30
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.UNSET): ExtrasStrategy.KEEP,
|
|
31
|
+
}
|
|
32
|
+
|
|
24
33
|
|
|
25
34
|
def multiply(self: Expression, other: Expression) -> Expression:
|
|
26
35
|
"""Multiplies two expressions together.
|
|
@@ -33,7 +42,7 @@ def multiply(self: Expression, other: Expression) -> Expression:
|
|
|
33
42
|
>>> m.x3 = pf.Variable()
|
|
34
43
|
>>> result = 5 * m.x1 * m.x2
|
|
35
44
|
>>> result
|
|
36
|
-
<Expression terms=1
|
|
45
|
+
<Expression (quadratic) terms=1>
|
|
37
46
|
5 x2 * x1
|
|
38
47
|
>>> result * m.x3
|
|
39
48
|
Traceback (most recent call last):
|
|
@@ -103,7 +112,7 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
|
|
|
103
112
|
>>> expr1 = df * m.x1
|
|
104
113
|
>>> expr2 = df * m.x2 * 2 + 4
|
|
105
114
|
>>> expr1 * expr2
|
|
106
|
-
<Expression height=3 terms=6
|
|
115
|
+
<Expression (quadratic) height=3 terms=6>
|
|
107
116
|
┌─────┬───────────────────┐
|
|
108
117
|
│ dim ┆ expression │
|
|
109
118
|
│ (3) ┆ │
|
|
@@ -113,7 +122,7 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
|
|
|
113
122
|
│ 3 ┆ 12 x1 +18 x2 * x1 │
|
|
114
123
|
└─────┴───────────────────┘
|
|
115
124
|
>>> (expr1 * expr2) - df * m.x1 * df * m.x2 * 2
|
|
116
|
-
<Expression height=3 terms=3
|
|
125
|
+
<Expression (linear) height=3 terms=3>
|
|
117
126
|
┌─────┬────────────┐
|
|
118
127
|
│ dim ┆ expression │
|
|
119
128
|
│ (3) ┆ │
|
|
@@ -160,184 +169,82 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
|
|
|
160
169
|
|
|
161
170
|
def add(*expressions: Expression) -> Expression:
|
|
162
171
|
"""Add multiple expressions together."""
|
|
163
|
-
# Mapping of how a sum of two expressions should propagate the unmatched strategy
|
|
164
|
-
propagation_strategies = {
|
|
165
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
|
|
166
|
-
(
|
|
167
|
-
UnmatchedStrategy.UNSET,
|
|
168
|
-
UnmatchedStrategy.UNSET,
|
|
169
|
-
): UnmatchedStrategy.UNSET,
|
|
170
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.KEEP): UnmatchedStrategy.KEEP,
|
|
171
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP): UnmatchedStrategy.UNSET,
|
|
172
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET): UnmatchedStrategy.DROP,
|
|
173
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET): UnmatchedStrategy.KEEP,
|
|
174
|
-
}
|
|
175
|
-
|
|
176
172
|
assert len(expressions) > 1, "Need at least two expressions to add together."
|
|
177
173
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if dims is None:
|
|
181
|
-
requires_join = False
|
|
182
|
-
dims = []
|
|
183
|
-
elif Config.disable_unmatched_checks:
|
|
184
|
-
requires_join = any(
|
|
185
|
-
expr._unmatched_strategy
|
|
186
|
-
not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
|
|
187
|
-
for expr in expressions
|
|
188
|
-
)
|
|
174
|
+
if Config.disable_extras_checks:
|
|
175
|
+
no_checks_strats = (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET)
|
|
189
176
|
else:
|
|
190
|
-
|
|
191
|
-
expr._unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
|
|
192
|
-
)
|
|
177
|
+
no_checks_strats = (ExtrasStrategy.KEEP,)
|
|
193
178
|
|
|
194
|
-
|
|
195
|
-
|
|
179
|
+
no_extras_checks_required = (
|
|
180
|
+
all(expr._extras_strategy in no_checks_strats for expr in expressions)
|
|
181
|
+
# if only one dimensioned, then there is no such thing as extra labels,
|
|
182
|
+
# labels will be set by the only dimensioned expression
|
|
183
|
+
or sum(not expr.dimensionless for expr in expressions) <= 1
|
|
196
184
|
)
|
|
197
185
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
for expr in expressions[1:]:
|
|
202
|
-
result = add(result, expr)
|
|
203
|
-
return result
|
|
204
|
-
|
|
205
|
-
if has_dim_conflict:
|
|
206
|
-
assert len(expressions) == 2
|
|
207
|
-
|
|
208
|
-
left, right = expressions[0], expressions[1]
|
|
209
|
-
left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
|
|
210
|
-
|
|
211
|
-
missing_left = [dim for dim in right_dims if dim not in left_dims]
|
|
212
|
-
missing_right = [dim for dim in left_dims if dim not in right_dims]
|
|
213
|
-
common_dims = [dim for dim in left_dims if dim in right_dims]
|
|
214
|
-
|
|
215
|
-
if not (
|
|
216
|
-
set(missing_left) <= set(left._allowed_new_dims)
|
|
217
|
-
and set(missing_right) <= set(right._allowed_new_dims)
|
|
218
|
-
):
|
|
219
|
-
_raise_addition_error(
|
|
220
|
-
left,
|
|
221
|
-
right,
|
|
222
|
-
f"their\n\tdimensions are different ({left_dims} != {right_dims})",
|
|
223
|
-
"If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/learn/concepts/special-functions/#adding-expressions-with-differing-dimensions-using-over",
|
|
224
|
-
)
|
|
225
|
-
|
|
226
|
-
left_old = left
|
|
227
|
-
if missing_left:
|
|
228
|
-
left = _broadcast(left, right, common_dims, missing_left)
|
|
229
|
-
if missing_right:
|
|
230
|
-
right = _broadcast(
|
|
231
|
-
right, left_old, common_dims, missing_right, swapped=True
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
|
|
235
|
-
expressions = (left, right)
|
|
236
|
-
|
|
237
|
-
dims = expressions[0]._dimensions_unsafe
|
|
238
|
-
# Check no dims conflict
|
|
239
|
-
assert all(
|
|
240
|
-
sorted(dims) == sorted(expr._dimensions_unsafe) for expr in expressions[1:]
|
|
186
|
+
has_dim_conflict = any(
|
|
187
|
+
sorted(expressions[0]._dimensions_unsafe) != sorted(expr._dimensions_unsafe)
|
|
188
|
+
for expr in expressions[1:]
|
|
241
189
|
)
|
|
242
|
-
if requires_join:
|
|
243
|
-
assert len(expressions) == 2
|
|
244
|
-
assert dims != []
|
|
245
|
-
left, right = expressions[0], expressions[1]
|
|
246
|
-
|
|
247
|
-
# Order so that drop always comes before keep, and keep always comes before default
|
|
248
|
-
if swap := (
|
|
249
|
-
(left._unmatched_strategy, right._unmatched_strategy)
|
|
250
|
-
in (
|
|
251
|
-
(UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
|
|
252
|
-
(UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
|
|
253
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
|
|
254
|
-
)
|
|
255
|
-
):
|
|
256
|
-
left, right = right, left
|
|
257
|
-
|
|
258
|
-
def get_indices(expr):
|
|
259
|
-
return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
|
|
260
|
-
|
|
261
|
-
left_data, right_data = left.data, right.data
|
|
262
190
|
|
|
263
|
-
|
|
191
|
+
# If we cannot use .concat compute the sum in a pairwise manner, so far nobody uses this code
|
|
192
|
+
if len(expressions) > 2: # pragma: no cover
|
|
193
|
+
assert False, "This code has not been tested."
|
|
194
|
+
if has_dim_conflict or not no_extras_checks_required:
|
|
195
|
+
result = expressions[0]
|
|
196
|
+
for expr in expressions[1:]:
|
|
197
|
+
result = add(result, expr)
|
|
198
|
+
return result
|
|
199
|
+
propagate_strat = expressions[0]._extras_strategy
|
|
200
|
+
dims = expressions[0]._dimensions_unsafe
|
|
201
|
+
expr_data = [expr.data for expr in expressions]
|
|
202
|
+
else:
|
|
203
|
+
left, right = expressions[0], expressions[1]
|
|
264
204
|
|
|
265
|
-
|
|
205
|
+
if has_dim_conflict:
|
|
206
|
+
left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
|
|
266
207
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
on=dims,
|
|
271
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
272
|
-
)
|
|
273
|
-
right_data = right.data.join(
|
|
274
|
-
get_indices(left),
|
|
275
|
-
on=dims,
|
|
276
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
277
|
-
)
|
|
278
|
-
elif strat == (UnmatchedStrategy.UNSET, UnmatchedStrategy.UNSET):
|
|
279
|
-
assert not Config.disable_unmatched_checks, (
|
|
280
|
-
"This code should not be reached when unmatched checks are disabled."
|
|
281
|
-
)
|
|
282
|
-
outer_join = get_indices(left).join(
|
|
283
|
-
get_indices(right),
|
|
284
|
-
how="full",
|
|
285
|
-
on=dims,
|
|
286
|
-
maintain_order="left_right" if Config.maintain_order else None,
|
|
287
|
-
)
|
|
288
|
-
if outer_join.get_column(dims[0]).null_count() > 0:
|
|
289
|
-
unmatched_vals = outer_join.filter(
|
|
290
|
-
outer_join.get_column(dims[0]).is_null()
|
|
291
|
-
)
|
|
292
|
-
_raise_unmatched_values_error(left, right, unmatched_vals, swap)
|
|
293
|
-
if outer_join.get_column(dims[0] + "_right").null_count() > 0:
|
|
294
|
-
unmatched_vals = outer_join.filter(
|
|
295
|
-
outer_join.get_column(dims[0] + "_right").is_null()
|
|
296
|
-
)
|
|
297
|
-
_raise_unmatched_values_error(left, right, unmatched_vals, swap)
|
|
208
|
+
missing_left = [dim for dim in right_dims if dim not in left_dims]
|
|
209
|
+
missing_right = [dim for dim in left_dims if dim not in right_dims]
|
|
210
|
+
common_dims = [dim for dim in left_dims if dim in right_dims]
|
|
298
211
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
305
|
-
)
|
|
306
|
-
elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
|
|
307
|
-
left_data = get_indices(right).join(
|
|
308
|
-
left.data,
|
|
309
|
-
how="left",
|
|
310
|
-
on=dims,
|
|
311
|
-
maintain_order="left" if Config.maintain_order else None,
|
|
312
|
-
)
|
|
313
|
-
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
314
|
-
_raise_unmatched_values_error(
|
|
212
|
+
if not (
|
|
213
|
+
set(missing_left) <= set(left._allowed_new_dims)
|
|
214
|
+
and set(missing_right) <= set(right._allowed_new_dims)
|
|
215
|
+
):
|
|
216
|
+
_raise_addition_error(
|
|
315
217
|
left,
|
|
316
218
|
right,
|
|
317
|
-
|
|
318
|
-
|
|
219
|
+
f"their\n\tdimensions are different ({left_dims} != {right_dims})",
|
|
220
|
+
"If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition/#adding-expressions-with-differing-dimensions-using-over",
|
|
319
221
|
)
|
|
320
222
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
else: # pragma: no cover
|
|
329
|
-
assert False, "This code should've never been reached!"
|
|
223
|
+
left_old = left
|
|
224
|
+
if missing_left:
|
|
225
|
+
left = _broadcast(left, right, common_dims, missing_left)
|
|
226
|
+
if missing_right:
|
|
227
|
+
right = _broadcast(
|
|
228
|
+
right, left_old, common_dims, missing_right, swapped=True
|
|
229
|
+
)
|
|
330
230
|
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
231
|
+
assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
|
|
232
|
+
|
|
233
|
+
dims = left._dimensions_unsafe
|
|
234
|
+
|
|
235
|
+
if not no_extras_checks_required:
|
|
236
|
+
expr_data, propagate_strat = _handle_extra_labels(left, right, dims)
|
|
237
|
+
else:
|
|
238
|
+
propagate_strat = left._extras_strategy
|
|
239
|
+
expr_data = (left.data, right.data)
|
|
335
240
|
|
|
336
241
|
# Add quadratic column if it is needed and doesn't already exist
|
|
337
242
|
if any(QUAD_VAR_KEY in df.columns for df in expr_data):
|
|
338
243
|
expr_data = [
|
|
339
244
|
(
|
|
340
|
-
df.with_columns(
|
|
245
|
+
df.with_columns(
|
|
246
|
+
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
|
|
247
|
+
)
|
|
341
248
|
if QUAD_VAR_KEY not in df.columns
|
|
342
249
|
else df
|
|
343
250
|
)
|
|
@@ -358,23 +265,116 @@ def add(*expressions: Expression) -> Expression:
|
|
|
358
265
|
full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
|
|
359
266
|
|
|
360
267
|
new_expr = expressions[0]._new(data, name=f"({full_name})")
|
|
361
|
-
new_expr.
|
|
268
|
+
new_expr._extras_strategy = propagate_strat
|
|
362
269
|
|
|
363
270
|
return new_expr
|
|
364
271
|
|
|
365
272
|
|
|
366
|
-
def
|
|
367
|
-
left: Expression, right: Expression,
|
|
273
|
+
def _handle_extra_labels(
|
|
274
|
+
left: Expression, right: Expression, dims: list[str]
|
|
275
|
+
) -> tuple[tuple[pl.DataFrame, pl.DataFrame], ExtrasStrategy]:
|
|
276
|
+
assert dims != []
|
|
277
|
+
# Order so that drop always comes before keep, and keep always comes before default
|
|
278
|
+
if swapped := (
|
|
279
|
+
(left._extras_strategy, right._extras_strategy)
|
|
280
|
+
in (
|
|
281
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.DROP),
|
|
282
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.KEEP),
|
|
283
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.DROP),
|
|
284
|
+
)
|
|
285
|
+
):
|
|
286
|
+
left, right = right, left
|
|
287
|
+
|
|
288
|
+
def get_labels(expr):
|
|
289
|
+
return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
|
|
290
|
+
|
|
291
|
+
left_data, right_data = left.data, right.data
|
|
292
|
+
|
|
293
|
+
strat = (left._extras_strategy, right._extras_strategy)
|
|
294
|
+
|
|
295
|
+
if strat == (ExtrasStrategy.DROP, ExtrasStrategy.DROP):
|
|
296
|
+
left_data = left.data.join(
|
|
297
|
+
get_labels(right),
|
|
298
|
+
on=dims,
|
|
299
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
300
|
+
)
|
|
301
|
+
right_data = right.data.join(
|
|
302
|
+
get_labels(left),
|
|
303
|
+
on=dims,
|
|
304
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
305
|
+
)
|
|
306
|
+
elif strat == (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET):
|
|
307
|
+
assert not Config.disable_extras_checks, (
|
|
308
|
+
"This code should not be reached when checks for extra values are disabled."
|
|
309
|
+
)
|
|
310
|
+
left_labels, right_labels = get_labels(left), get_labels(right)
|
|
311
|
+
left_extras = left_labels.join(right_labels, how="anti", on=dims)
|
|
312
|
+
right_extras = right_labels.join(left_labels, how="anti", on=dims)
|
|
313
|
+
if len(left_extras) > 0:
|
|
314
|
+
_raise_extras_error(
|
|
315
|
+
left, right, left_extras, swapped, extras_on_right=False
|
|
316
|
+
)
|
|
317
|
+
if len(right_extras) > 0:
|
|
318
|
+
_raise_extras_error(left, right, right_extras, swapped)
|
|
319
|
+
|
|
320
|
+
elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.KEEP):
|
|
321
|
+
left_data = get_labels(right).join(
|
|
322
|
+
left.data,
|
|
323
|
+
on=dims,
|
|
324
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
325
|
+
)
|
|
326
|
+
elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.UNSET):
|
|
327
|
+
right_labels = get_labels(right)
|
|
328
|
+
left_data = right_labels.join(
|
|
329
|
+
left.data,
|
|
330
|
+
how="left",
|
|
331
|
+
on=dims,
|
|
332
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
333
|
+
)
|
|
334
|
+
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
335
|
+
_raise_extras_error(
|
|
336
|
+
left,
|
|
337
|
+
right,
|
|
338
|
+
right_labels.join(get_labels(left), how="anti", on=dims),
|
|
339
|
+
swapped,
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
elif strat == (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET):
|
|
343
|
+
assert not Config.disable_extras_checks, (
|
|
344
|
+
"This code should not be reached when checks for extra values are disabled."
|
|
345
|
+
)
|
|
346
|
+
extras = right.data.join(get_labels(left), how="anti", on=dims)
|
|
347
|
+
if len(extras) > 0:
|
|
348
|
+
_raise_extras_error(left, right, extras.select(dims), swapped)
|
|
349
|
+
else: # pragma: no cover
|
|
350
|
+
assert False, "This code should've never been reached!"
|
|
351
|
+
|
|
352
|
+
if swapped:
|
|
353
|
+
left_data, right_data = right_data, left_data
|
|
354
|
+
|
|
355
|
+
return (left_data, right_data), _extras_propagation_rules[strat]
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _raise_extras_error(
|
|
359
|
+
left: Expression,
|
|
360
|
+
right: Expression,
|
|
361
|
+
extra_labels: pl.DataFrame,
|
|
362
|
+
swapped: bool,
|
|
363
|
+
extras_on_right: bool = True,
|
|
368
364
|
):
|
|
369
365
|
if swapped:
|
|
370
366
|
left, right = right, left
|
|
367
|
+
extras_on_right = not extras_on_right
|
|
371
368
|
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
369
|
+
expression_num = 2 if extras_on_right else 1
|
|
370
|
+
|
|
371
|
+
with Config.print_polars_config:
|
|
372
|
+
_raise_addition_error(
|
|
373
|
+
left,
|
|
374
|
+
right,
|
|
375
|
+
f"expression {expression_num} has extra labels",
|
|
376
|
+
f"Extra labels in expression {expression_num}:\n{extra_labels}\nUse .drop_extras() or .keep_extras() to indicate how the extra labels should be handled. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition",
|
|
377
|
+
)
|
|
378
378
|
|
|
379
379
|
|
|
380
380
|
def _raise_addition_error(
|
|
@@ -412,7 +412,7 @@ def _broadcast(
|
|
|
412
412
|
return res
|
|
413
413
|
|
|
414
414
|
# If drop, we just do an inner join to get into the shape of the other
|
|
415
|
-
if self.
|
|
415
|
+
if self._extras_strategy == ExtrasStrategy.DROP:
|
|
416
416
|
res = self._new(
|
|
417
417
|
self.data.join(
|
|
418
418
|
target_data,
|
|
@@ -430,12 +430,14 @@ def _broadcast(
|
|
|
430
430
|
how="left",
|
|
431
431
|
maintain_order="left" if Config.maintain_order else None,
|
|
432
432
|
)
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
433
|
+
if result.get_column(missing_dims[0]).null_count() > 0:
|
|
434
|
+
target_labels = target.data.select(target._dimensions_unsafe).unique(
|
|
435
|
+
maintain_order=Config.maintain_order
|
|
436
|
+
)
|
|
437
|
+
_raise_extras_error(
|
|
436
438
|
self,
|
|
437
439
|
target,
|
|
438
|
-
|
|
440
|
+
target_labels.join(self.data, how="anti", on=common_dims),
|
|
439
441
|
swapped,
|
|
440
442
|
)
|
|
441
443
|
res = self._new(result, self.name)
|
|
@@ -522,7 +524,7 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
522
524
|
if df.is_empty():
|
|
523
525
|
df = pl.DataFrame(
|
|
524
526
|
{VAR_KEY: [CONST_TERM], COEF_KEY: [0]},
|
|
525
|
-
schema={VAR_KEY:
|
|
527
|
+
schema={VAR_KEY: Config.id_dtype, COEF_KEY: pl.Float64},
|
|
526
528
|
)
|
|
527
529
|
|
|
528
530
|
if QUAD_VAR_KEY in df.columns and (df.get_column(QUAD_VAR_KEY) == CONST_TERM).all():
|