pyoframe 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoframe/__init__.py +21 -14
- pyoframe/_arithmetic.py +346 -238
- pyoframe/_constants.py +463 -0
- pyoframe/_core.py +2652 -0
- pyoframe/_model.py +598 -0
- pyoframe/_model_element.py +189 -0
- pyoframe/_monkey_patch.py +82 -0
- pyoframe/{objective.py → _objective.py} +50 -17
- pyoframe/{util.py → _utils.py} +108 -129
- pyoframe/_version.py +16 -3
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0.dist-info}/METADATA +37 -31
- pyoframe-1.0.0.dist-info/RECORD +15 -0
- pyoframe/constants.py +0 -140
- pyoframe/core.py +0 -1794
- pyoframe/model.py +0 -408
- pyoframe/model_element.py +0 -184
- pyoframe/monkey_patch.py +0 -54
- pyoframe-0.2.0.dist-info/RECORD +0 -15
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0.dist-info}/WHEEL +0 -0
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0.dist-info}/top_level.txt +0 -0
pyoframe/_arithmetic.py
CHANGED
|
@@ -1,99 +1,87 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
"""Defines helper functions for doing arithmetic operations on expressions (e.g. addition)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import TYPE_CHECKING
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
6
|
|
|
7
7
|
import polars as pl
|
|
8
8
|
|
|
9
|
-
from pyoframe.
|
|
9
|
+
from pyoframe._constants import (
|
|
10
10
|
COEF_KEY,
|
|
11
11
|
CONST_TERM,
|
|
12
|
-
KEY_TYPE,
|
|
13
12
|
QUAD_VAR_KEY,
|
|
14
13
|
RESERVED_COL_KEYS,
|
|
15
14
|
VAR_KEY,
|
|
16
15
|
Config,
|
|
16
|
+
ExtrasStrategy,
|
|
17
17
|
PyoframeError,
|
|
18
|
-
UnmatchedStrategy,
|
|
19
18
|
)
|
|
20
19
|
|
|
21
20
|
if TYPE_CHECKING: # pragma: no cover
|
|
22
|
-
from pyoframe.
|
|
21
|
+
from pyoframe._core import Expression
|
|
23
22
|
|
|
23
|
+
# Mapping of how a sum of two expressions should propagate the extras strategy
|
|
24
|
+
_extras_propagation_rules = {
|
|
25
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.DROP): ExtrasStrategy.DROP,
|
|
26
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.UNSET): ExtrasStrategy.UNSET,
|
|
27
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.KEEP): ExtrasStrategy.KEEP,
|
|
28
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.KEEP): ExtrasStrategy.UNSET,
|
|
29
|
+
(ExtrasStrategy.DROP, ExtrasStrategy.UNSET): ExtrasStrategy.DROP,
|
|
30
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.UNSET): ExtrasStrategy.KEEP,
|
|
31
|
+
}
|
|
24
32
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
33
|
+
|
|
34
|
+
def multiply(self: Expression, other: Expression) -> Expression:
|
|
35
|
+
"""Multiplies two expressions together.
|
|
28
36
|
|
|
29
37
|
Examples:
|
|
30
38
|
>>> import pyoframe as pf
|
|
31
|
-
>>> m = pf.Model(
|
|
39
|
+
>>> m = pf.Model()
|
|
32
40
|
>>> m.x1 = pf.Variable()
|
|
33
41
|
>>> m.x2 = pf.Variable()
|
|
34
42
|
>>> m.x3 = pf.Variable()
|
|
35
43
|
>>> result = 5 * m.x1 * m.x2
|
|
36
44
|
>>> result
|
|
37
|
-
<Expression
|
|
38
|
-
5
|
|
45
|
+
<Expression terms=1 type=quadratic>
|
|
46
|
+
5 x2 * x1
|
|
39
47
|
>>> result * m.x3
|
|
40
48
|
Traceback (most recent call last):
|
|
41
49
|
...
|
|
42
|
-
pyoframe.
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
Cannot multiply a quadratic expression by a non-constant.
|
|
50
|
+
pyoframe._constants.PyoframeError: Cannot multiply the two expressions below because the result would be a cubic. Only quadratic or linear expressions are allowed.
|
|
51
|
+
Expression 1 (quadratic): ((5 * x1) * x2)
|
|
52
|
+
Expression 2 (linear): x3
|
|
46
53
|
"""
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
"
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
)
|
|
55
|
-
+ "\nDue to error:\n"
|
|
56
|
-
+ str(error)
|
|
57
|
-
) from error
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def _add_expressions(*expressions: "Expression") -> "Expression":
|
|
61
|
-
try:
|
|
62
|
-
return _add_expressions_core(*expressions)
|
|
63
|
-
except PyoframeError as error:
|
|
54
|
+
self_degree, other_degree = self.degree(), other.degree()
|
|
55
|
+
product_degree = self_degree + other_degree
|
|
56
|
+
if product_degree > 2:
|
|
57
|
+
assert product_degree <= 4, (
|
|
58
|
+
"Unexpected because expressions should not exceed degree 2."
|
|
59
|
+
)
|
|
60
|
+
res_type = "cubic" if product_degree == 3 else "quartic"
|
|
64
61
|
raise PyoframeError(
|
|
65
|
-
"
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
+ "\nDue to error:\n"
|
|
70
|
-
+ str(error)
|
|
71
|
-
) from error
|
|
62
|
+
f"""Cannot multiply the two expressions below because the result would be a {res_type}. Only quadratic or linear expressions are allowed.
|
|
63
|
+
Expression 1 ({self.degree(return_str=True)}):\t{self.name}
|
|
64
|
+
Expression 2 ({other.degree(return_str=True)}):\t{other.name}"""
|
|
65
|
+
)
|
|
72
66
|
|
|
67
|
+
if self_degree == 1 and other_degree == 1:
|
|
68
|
+
return _quadratic_multiplication(self, other)
|
|
73
69
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if self_degree + other_degree > 2:
|
|
77
|
-
# We know one of the two must be a quadratic since 1 + 1 is not greater than 2.
|
|
78
|
-
raise PyoframeError("Cannot multiply a quadratic expression by a non-constant.")
|
|
70
|
+
# save names to use in debug messages before any swapping occurs
|
|
71
|
+
self_name, other_name = self.name, other.name
|
|
79
72
|
if self_degree < other_degree:
|
|
80
73
|
self, other = other, self
|
|
81
74
|
self_degree, other_degree = other_degree, self_degree
|
|
82
|
-
if other_degree == 1:
|
|
83
|
-
assert self_degree == 1, (
|
|
84
|
-
"This should always be true since the sum of degrees must be <=2."
|
|
85
|
-
)
|
|
86
|
-
return _quadratic_multiplication(self, other)
|
|
87
75
|
|
|
88
76
|
assert other_degree == 0, (
|
|
89
77
|
"This should always be true since other cases have already been handled."
|
|
90
78
|
)
|
|
91
|
-
multiplier = other.data.drop(
|
|
92
|
-
VAR_KEY
|
|
93
|
-
) # QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
|
|
94
79
|
|
|
95
|
-
|
|
96
|
-
|
|
80
|
+
# QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
|
|
81
|
+
multiplier = other.data.drop(VAR_KEY)
|
|
82
|
+
|
|
83
|
+
dims = self._dimensions_unsafe
|
|
84
|
+
other_dims = other._dimensions_unsafe
|
|
97
85
|
dims_in_common = [dim for dim in dims if dim in other_dims]
|
|
98
86
|
|
|
99
87
|
data = (
|
|
@@ -101,17 +89,19 @@ def _multiply_expressions_core(self: "Expression", other: "Expression") -> "Expr
|
|
|
101
89
|
multiplier,
|
|
102
90
|
on=dims_in_common if len(dims_in_common) > 0 else None,
|
|
103
91
|
how="inner" if dims_in_common else "cross",
|
|
92
|
+
maintain_order=(
|
|
93
|
+
"left" if Config.maintain_order and dims_in_common else None
|
|
94
|
+
),
|
|
104
95
|
)
|
|
105
96
|
.with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
|
|
106
97
|
.drop(COEF_KEY + "_right")
|
|
107
98
|
)
|
|
108
99
|
|
|
109
|
-
return self._new(data)
|
|
100
|
+
return self._new(data, name=f"({self_name} * {other_name})")
|
|
110
101
|
|
|
111
102
|
|
|
112
|
-
def _quadratic_multiplication(self:
|
|
113
|
-
"""
|
|
114
|
-
Multiply two expressions of degree 1.
|
|
103
|
+
def _quadratic_multiplication(self: Expression, other: Expression) -> Expression:
|
|
104
|
+
"""Multiplies two expressions of degree 1.
|
|
115
105
|
|
|
116
106
|
Examples:
|
|
117
107
|
>>> import polars as pl
|
|
@@ -122,18 +112,29 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
|
|
|
122
112
|
>>> expr1 = df * m.x1
|
|
123
113
|
>>> expr2 = df * m.x2 * 2 + 4
|
|
124
114
|
>>> expr1 * expr2
|
|
125
|
-
<Expression
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
115
|
+
<Expression height=3 terms=6 type=quadratic>
|
|
116
|
+
┌─────┬───────────────────┐
|
|
117
|
+
│ dim ┆ expression │
|
|
118
|
+
│ (3) ┆ │
|
|
119
|
+
╞═════╪═══════════════════╡
|
|
120
|
+
│ 1 ┆ 4 x1 +2 x2 * x1 │
|
|
121
|
+
│ 2 ┆ 8 x1 +8 x2 * x1 │
|
|
122
|
+
│ 3 ┆ 12 x1 +18 x2 * x1 │
|
|
123
|
+
└─────┴───────────────────┘
|
|
129
124
|
>>> (expr1 * expr2) - df * m.x1 * df * m.x2 * 2
|
|
130
|
-
<Expression
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
125
|
+
<Expression height=3 terms=3 type=linear>
|
|
126
|
+
┌─────┬────────────┐
|
|
127
|
+
│ dim ┆ expression │
|
|
128
|
+
│ (3) ┆ │
|
|
129
|
+
╞═════╪════════════╡
|
|
130
|
+
│ 1 ┆ 4 x1 │
|
|
131
|
+
│ 2 ┆ 8 x1 │
|
|
132
|
+
│ 3 ┆ 12 x1 │
|
|
133
|
+
└─────┴────────────┘
|
|
134
|
+
|
|
134
135
|
"""
|
|
135
|
-
dims = self.
|
|
136
|
-
other_dims = other.
|
|
136
|
+
dims = self._dimensions_unsafe
|
|
137
|
+
other_dims = other._dimensions_unsafe
|
|
137
138
|
dims_in_common = [dim for dim in dims if dim in other_dims]
|
|
138
139
|
|
|
139
140
|
data = (
|
|
@@ -141,11 +142,14 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
|
|
|
141
142
|
other.data,
|
|
142
143
|
on=dims_in_common if len(dims_in_common) > 0 else None,
|
|
143
144
|
how="inner" if dims_in_common else "cross",
|
|
145
|
+
maintain_order=(
|
|
146
|
+
"left" if Config.maintain_order and dims_in_common else None
|
|
147
|
+
),
|
|
144
148
|
)
|
|
145
149
|
.with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
|
|
146
150
|
.drop(COEF_KEY + "_right")
|
|
147
151
|
.rename({VAR_KEY + "_right": QUAD_VAR_KEY})
|
|
148
|
-
# Swap VAR_KEY and QUAD_VAR_KEY so that
|
|
152
|
+
# Swap VAR_KEY and QUAD_VAR_KEY so that VAR_KEY is always the larger one
|
|
149
153
|
.with_columns(
|
|
150
154
|
pl.when(pl.col(VAR_KEY) < pl.col(QUAD_VAR_KEY))
|
|
151
155
|
.then(pl.col(QUAD_VAR_KEY))
|
|
@@ -160,147 +164,87 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
|
|
|
160
164
|
|
|
161
165
|
data = _sum_like_terms(data)
|
|
162
166
|
|
|
163
|
-
return self._new(data)
|
|
164
|
-
|
|
167
|
+
return self._new(data, name=f"({self.name} * {other.name})")
|
|
165
168
|
|
|
166
|
-
def _add_expressions_core(*expressions: "Expression") -> "Expression":
|
|
167
|
-
# Mapping of how a sum of two expressions should propogate the unmatched strategy
|
|
168
|
-
propogatation_strategies = {
|
|
169
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
|
|
170
|
-
(
|
|
171
|
-
UnmatchedStrategy.UNSET,
|
|
172
|
-
UnmatchedStrategy.UNSET,
|
|
173
|
-
): UnmatchedStrategy.UNSET,
|
|
174
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.KEEP): UnmatchedStrategy.KEEP,
|
|
175
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP): UnmatchedStrategy.UNSET,
|
|
176
|
-
(UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET): UnmatchedStrategy.DROP,
|
|
177
|
-
(UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET): UnmatchedStrategy.KEEP,
|
|
178
|
-
}
|
|
179
169
|
|
|
170
|
+
def add(*expressions: Expression) -> Expression:
|
|
171
|
+
"""Add multiple expressions together."""
|
|
180
172
|
assert len(expressions) > 1, "Need at least two expressions to add together."
|
|
181
173
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
if dims is None:
|
|
185
|
-
requires_join = False
|
|
186
|
-
dims = []
|
|
187
|
-
elif Config.disable_unmatched_checks:
|
|
188
|
-
requires_join = any(
|
|
189
|
-
expr.unmatched_strategy
|
|
190
|
-
not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
|
|
191
|
-
for expr in expressions
|
|
192
|
-
)
|
|
174
|
+
if Config.disable_extras_checks:
|
|
175
|
+
no_checks_strats = (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET)
|
|
193
176
|
else:
|
|
194
|
-
|
|
195
|
-
expr.unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
|
|
196
|
-
)
|
|
177
|
+
no_checks_strats = (ExtrasStrategy.KEEP,)
|
|
197
178
|
|
|
198
|
-
|
|
199
|
-
|
|
179
|
+
no_extras_checks_required = (
|
|
180
|
+
all(expr._extras_strategy in no_checks_strats for expr in expressions)
|
|
181
|
+
# if only one dimensioned, then there is no such thing as extra labels,
|
|
182
|
+
# labels will be set by the only dimensioned expression
|
|
183
|
+
or sum(not expr.dimensionless for expr in expressions) <= 1
|
|
200
184
|
)
|
|
201
185
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
for expr in expressions[1:]:
|
|
206
|
-
result = _add_expressions_core(result, expr)
|
|
207
|
-
return result
|
|
208
|
-
|
|
209
|
-
if has_dim_conflict:
|
|
210
|
-
assert len(expressions) == 2
|
|
211
|
-
expressions = (
|
|
212
|
-
_add_dimension(expressions[0], expressions[1]),
|
|
213
|
-
_add_dimension(expressions[1], expressions[0]),
|
|
214
|
-
)
|
|
215
|
-
assert sorted(expressions[0].dimensions_unsafe) == sorted(
|
|
216
|
-
expressions[1].dimensions_unsafe
|
|
217
|
-
)
|
|
218
|
-
|
|
219
|
-
dims = expressions[0].dimensions_unsafe
|
|
220
|
-
# Check no dims conflict
|
|
221
|
-
assert all(
|
|
222
|
-
sorted(dims) == sorted(expr.dimensions_unsafe) for expr in expressions[1:]
|
|
186
|
+
has_dim_conflict = any(
|
|
187
|
+
sorted(expressions[0]._dimensions_unsafe) != sorted(expr._dimensions_unsafe)
|
|
188
|
+
for expr in expressions[1:]
|
|
223
189
|
)
|
|
224
|
-
if requires_join:
|
|
225
|
-
assert len(expressions) == 2
|
|
226
|
-
assert dims != []
|
|
227
|
-
left, right = expressions[0], expressions[1]
|
|
228
190
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
191
|
+
# If we cannot use .concat compute the sum in a pairwise manner, so far nobody uses this code
|
|
192
|
+
if len(expressions) > 2: # pragma: no cover
|
|
193
|
+
assert False, "This code has not been tested."
|
|
194
|
+
if has_dim_conflict or not no_extras_checks_required:
|
|
195
|
+
result = expressions[0]
|
|
196
|
+
for expr in expressions[1:]:
|
|
197
|
+
result = add(result, expr)
|
|
198
|
+
return result
|
|
199
|
+
propagate_strat = expressions[0]._extras_strategy
|
|
200
|
+
dims = expressions[0]._dimensions_unsafe
|
|
201
|
+
expr_data = [expr.data for expr in expressions]
|
|
202
|
+
else:
|
|
203
|
+
left, right = expressions[0], expressions[1]
|
|
236
204
|
|
|
237
|
-
|
|
238
|
-
|
|
205
|
+
if has_dim_conflict:
|
|
206
|
+
left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
|
|
207
|
+
|
|
208
|
+
missing_left = [dim for dim in right_dims if dim not in left_dims]
|
|
209
|
+
missing_right = [dim for dim in left_dims if dim not in right_dims]
|
|
210
|
+
common_dims = [dim for dim in left_dims if dim in right_dims]
|
|
211
|
+
|
|
212
|
+
if not (
|
|
213
|
+
set(missing_left) <= set(left._allowed_new_dims)
|
|
214
|
+
and set(missing_right) <= set(right._allowed_new_dims)
|
|
215
|
+
):
|
|
216
|
+
_raise_addition_error(
|
|
217
|
+
left,
|
|
218
|
+
right,
|
|
219
|
+
f"their\n\tdimensions are different ({left_dims} != {right_dims})",
|
|
220
|
+
"If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition/#adding-expressions-with-differing-dimensions-using-over",
|
|
221
|
+
)
|
|
239
222
|
|
|
240
|
-
|
|
223
|
+
left_old = left
|
|
224
|
+
if missing_left:
|
|
225
|
+
left = _broadcast(left, right, common_dims, missing_left)
|
|
226
|
+
if missing_right:
|
|
227
|
+
right = _broadcast(
|
|
228
|
+
right, left_old, common_dims, missing_right, swapped=True
|
|
229
|
+
)
|
|
241
230
|
|
|
242
|
-
|
|
231
|
+
assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
|
|
243
232
|
|
|
244
|
-
|
|
233
|
+
dims = left._dimensions_unsafe
|
|
245
234
|
|
|
246
|
-
if
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
"This code should not be reached when unmatched checks are disabled."
|
|
252
|
-
)
|
|
253
|
-
outer_join = get_indices(left).join(
|
|
254
|
-
get_indices(right),
|
|
255
|
-
how="full",
|
|
256
|
-
on=dims,
|
|
257
|
-
)
|
|
258
|
-
if outer_join.get_column(dims[0]).null_count() > 0:
|
|
259
|
-
raise PyoframeError(
|
|
260
|
-
"Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
|
|
261
|
-
+ str(outer_join.filter(outer_join.get_column(dims[0]).is_null()))
|
|
262
|
-
)
|
|
263
|
-
if outer_join.get_column(dims[0] + "_right").null_count() > 0:
|
|
264
|
-
raise PyoframeError(
|
|
265
|
-
"Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
|
|
266
|
-
+ str(
|
|
267
|
-
outer_join.filter(
|
|
268
|
-
outer_join.get_column(dims[0] + "_right").is_null()
|
|
269
|
-
)
|
|
270
|
-
)
|
|
271
|
-
)
|
|
272
|
-
elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP):
|
|
273
|
-
left_data = get_indices(right).join(left.data, how="left", on=dims)
|
|
274
|
-
elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
|
|
275
|
-
left_data = get_indices(right).join(left.data, how="left", on=dims)
|
|
276
|
-
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
277
|
-
raise PyoframeError(
|
|
278
|
-
"Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
|
|
279
|
-
+ str(left_data.filter(left_data.get_column(COEF_KEY).is_null()))
|
|
280
|
-
)
|
|
281
|
-
elif strat == (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET):
|
|
282
|
-
assert not Config.disable_unmatched_checks, (
|
|
283
|
-
"This code should not be reached when unmatched checks are disabled."
|
|
284
|
-
)
|
|
285
|
-
unmatched = right.data.join(get_indices(left), how="anti", on=dims)
|
|
286
|
-
if len(unmatched) > 0:
|
|
287
|
-
raise PyoframeError(
|
|
288
|
-
"Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
|
|
289
|
-
+ str(unmatched)
|
|
290
|
-
)
|
|
291
|
-
else: # pragma: no cover
|
|
292
|
-
assert False, "This code should've never been reached!"
|
|
293
|
-
|
|
294
|
-
expr_data = [left_data, right_data]
|
|
295
|
-
else:
|
|
296
|
-
propogate_strat = expressions[0].unmatched_strategy
|
|
297
|
-
expr_data = [expr.data for expr in expressions]
|
|
235
|
+
if not no_extras_checks_required:
|
|
236
|
+
expr_data, propagate_strat = _handle_extra_labels(left, right, dims)
|
|
237
|
+
else:
|
|
238
|
+
propagate_strat = left._extras_strategy
|
|
239
|
+
expr_data = (left.data, right.data)
|
|
298
240
|
|
|
299
241
|
# Add quadratic column if it is needed and doesn't already exist
|
|
300
242
|
if any(QUAD_VAR_KEY in df.columns for df in expr_data):
|
|
301
243
|
expr_data = [
|
|
302
244
|
(
|
|
303
|
-
df.with_columns(
|
|
245
|
+
df.with_columns(
|
|
246
|
+
pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
|
|
247
|
+
)
|
|
304
248
|
if QUAD_VAR_KEY not in df.columns
|
|
305
249
|
else df
|
|
306
250
|
)
|
|
@@ -315,71 +259,215 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
|
|
|
315
259
|
data = pl.concat(expr_data, how="vertical_relaxed")
|
|
316
260
|
data = _sum_like_terms(data)
|
|
317
261
|
|
|
318
|
-
|
|
319
|
-
|
|
262
|
+
full_name = expressions[0].name
|
|
263
|
+
for expr in expressions[1:]:
|
|
264
|
+
name = expr.name
|
|
265
|
+
full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
|
|
266
|
+
|
|
267
|
+
new_expr = expressions[0]._new(data, name=f"({full_name})")
|
|
268
|
+
new_expr._extras_strategy = propagate_strat
|
|
320
269
|
|
|
321
270
|
return new_expr
|
|
322
271
|
|
|
323
272
|
|
|
324
|
-
def
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
if
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
273
|
+
def _handle_extra_labels(
|
|
274
|
+
left: Expression, right: Expression, dims: list[str]
|
|
275
|
+
) -> tuple[tuple[pl.DataFrame, pl.DataFrame], ExtrasStrategy]:
|
|
276
|
+
assert dims != []
|
|
277
|
+
# Order so that drop always comes before keep, and keep always comes before default
|
|
278
|
+
if swapped := (
|
|
279
|
+
(left._extras_strategy, right._extras_strategy)
|
|
280
|
+
in (
|
|
281
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.DROP),
|
|
282
|
+
(ExtrasStrategy.UNSET, ExtrasStrategy.KEEP),
|
|
283
|
+
(ExtrasStrategy.KEEP, ExtrasStrategy.DROP),
|
|
284
|
+
)
|
|
285
|
+
):
|
|
286
|
+
left, right = right, left
|
|
335
287
|
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
return self
|
|
288
|
+
def get_labels(expr):
|
|
289
|
+
return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
|
|
339
290
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
291
|
+
left_data, right_data = left.data, right.data
|
|
292
|
+
|
|
293
|
+
strat = (left._extras_strategy, right._extras_strategy)
|
|
294
|
+
|
|
295
|
+
if strat == (ExtrasStrategy.DROP, ExtrasStrategy.DROP):
|
|
296
|
+
left_data = left.data.join(
|
|
297
|
+
get_labels(right),
|
|
298
|
+
on=dims,
|
|
299
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
300
|
+
)
|
|
301
|
+
right_data = right.data.join(
|
|
302
|
+
get_labels(left),
|
|
303
|
+
on=dims,
|
|
304
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
344
305
|
)
|
|
306
|
+
elif strat == (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET):
|
|
307
|
+
assert not Config.disable_extras_checks, (
|
|
308
|
+
"This code should not be reached when checks for extra values are disabled."
|
|
309
|
+
)
|
|
310
|
+
left_labels, right_labels = get_labels(left), get_labels(right)
|
|
311
|
+
left_extras = left_labels.join(right_labels, how="anti", on=dims)
|
|
312
|
+
right_extras = right_labels.join(left_labels, how="anti", on=dims)
|
|
313
|
+
if len(left_extras) > 0:
|
|
314
|
+
_raise_extras_error(
|
|
315
|
+
left, right, left_extras, swapped, extras_on_right=False
|
|
316
|
+
)
|
|
317
|
+
if len(right_extras) > 0:
|
|
318
|
+
_raise_extras_error(left, right, right_extras, swapped)
|
|
319
|
+
|
|
320
|
+
elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.KEEP):
|
|
321
|
+
left_data = get_labels(right).join(
|
|
322
|
+
left.data,
|
|
323
|
+
on=dims,
|
|
324
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
325
|
+
)
|
|
326
|
+
elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.UNSET):
|
|
327
|
+
right_labels = get_labels(right)
|
|
328
|
+
left_data = right_labels.join(
|
|
329
|
+
left.data,
|
|
330
|
+
how="left",
|
|
331
|
+
on=dims,
|
|
332
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
333
|
+
)
|
|
334
|
+
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
335
|
+
_raise_extras_error(
|
|
336
|
+
left,
|
|
337
|
+
right,
|
|
338
|
+
right_labels.join(get_labels(left), how="anti", on=dims),
|
|
339
|
+
swapped,
|
|
340
|
+
)
|
|
345
341
|
|
|
346
|
-
|
|
342
|
+
elif strat == (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET):
|
|
343
|
+
assert not Config.disable_extras_checks, (
|
|
344
|
+
"This code should not be reached when checks for extra values are disabled."
|
|
345
|
+
)
|
|
346
|
+
extras = right.data.join(get_labels(left), how="anti", on=dims)
|
|
347
|
+
if len(extras) > 0:
|
|
348
|
+
_raise_extras_error(left, right, extras.select(dims), swapped)
|
|
349
|
+
else: # pragma: no cover
|
|
350
|
+
assert False, "This code should've never been reached!"
|
|
351
|
+
|
|
352
|
+
if swapped:
|
|
353
|
+
left_data, right_data = right_data, left_data
|
|
354
|
+
|
|
355
|
+
return (left_data, right_data), _extras_propagation_rules[strat]
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _raise_extras_error(
|
|
359
|
+
left: Expression,
|
|
360
|
+
right: Expression,
|
|
361
|
+
extra_labels: pl.DataFrame,
|
|
362
|
+
swapped: bool,
|
|
363
|
+
extras_on_right: bool = True,
|
|
364
|
+
):
|
|
365
|
+
if swapped:
|
|
366
|
+
left, right = right, left
|
|
367
|
+
extras_on_right = not extras_on_right
|
|
368
|
+
|
|
369
|
+
expression_num = 2 if extras_on_right else 1
|
|
370
|
+
|
|
371
|
+
with Config.print_polars_config:
|
|
372
|
+
_raise_addition_error(
|
|
373
|
+
left,
|
|
374
|
+
right,
|
|
375
|
+
f"expression {expression_num} has extra labels",
|
|
376
|
+
f"Extra labels in expression {expression_num}:\n{extra_labels}\nUse .drop_extras() or .keep_extras() to indicate how the extra labels should be handled. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition",
|
|
377
|
+
)
|
|
347
378
|
|
|
348
|
-
if not dims_in_common:
|
|
349
|
-
return self._new(self.data.join(target_data, how="cross"))
|
|
350
379
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
380
|
+
def _raise_addition_error(
|
|
381
|
+
left: Expression, right: Expression, reason: str, postfix: str
|
|
382
|
+
):
|
|
383
|
+
op = "add"
|
|
384
|
+
right_name = right.name
|
|
385
|
+
if right_name[0] == "-":
|
|
386
|
+
op = "subtract"
|
|
387
|
+
right_name = right_name[1:]
|
|
388
|
+
raise PyoframeError(
|
|
389
|
+
f"""Cannot {op} the two expressions below because {reason}.
|
|
390
|
+
Expression 1:\t{left.name}
|
|
391
|
+
Expression 2:\t{right_name}
|
|
392
|
+
{postfix}
|
|
393
|
+
"""
|
|
394
|
+
)
|
|
354
395
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
396
|
+
|
|
397
|
+
# TODO consider returning a dataframe instead of an expression to simplify code (e.g. avoid copy_flags)
|
|
398
|
+
def _broadcast(
|
|
399
|
+
self: Expression,
|
|
400
|
+
target: Expression,
|
|
401
|
+
common_dims: list[str],
|
|
402
|
+
missing_dims: list[str],
|
|
403
|
+
swapped: bool = False,
|
|
404
|
+
) -> Expression:
|
|
405
|
+
target_data = target.data.select(target._dimensions_unsafe).unique(
|
|
406
|
+
maintain_order=Config.maintain_order
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
if not common_dims:
|
|
410
|
+
res = self._new(self.data.join(target_data, how="cross"), name=self.name)
|
|
411
|
+
res._copy_flags(self)
|
|
412
|
+
return res
|
|
413
|
+
|
|
414
|
+
# If drop, we just do an inner join to get into the shape of the other
|
|
415
|
+
if self._extras_strategy == ExtrasStrategy.DROP:
|
|
416
|
+
res = self._new(
|
|
417
|
+
self.data.join(
|
|
418
|
+
target_data,
|
|
419
|
+
on=common_dims,
|
|
420
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
421
|
+
),
|
|
422
|
+
name=self.name,
|
|
423
|
+
)
|
|
424
|
+
res._copy_flags(self)
|
|
425
|
+
return res
|
|
426
|
+
|
|
427
|
+
result = self.data.join(
|
|
428
|
+
target_data,
|
|
429
|
+
on=common_dims,
|
|
430
|
+
how="left",
|
|
431
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
432
|
+
)
|
|
433
|
+
if result.get_column(missing_dims[0]).null_count() > 0:
|
|
434
|
+
target_labels = target.data.select(target._dimensions_unsafe).unique(
|
|
435
|
+
maintain_order=Config.maintain_order
|
|
360
436
|
)
|
|
361
|
-
|
|
437
|
+
_raise_extras_error(
|
|
438
|
+
self,
|
|
439
|
+
target,
|
|
440
|
+
target_labels.join(self.data, how="anti", on=common_dims),
|
|
441
|
+
swapped,
|
|
442
|
+
)
|
|
443
|
+
res = self._new(result, self.name)
|
|
444
|
+
res._copy_flags(self)
|
|
445
|
+
return res
|
|
362
446
|
|
|
363
447
|
|
|
364
448
|
def _sum_like_terms(df: pl.DataFrame) -> pl.DataFrame:
|
|
365
449
|
"""Combines terms with the same variables."""
|
|
366
450
|
dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
|
|
367
451
|
var_cols = [VAR_KEY] + ([QUAD_VAR_KEY] if QUAD_VAR_KEY in df.columns else [])
|
|
368
|
-
df = df.group_by(dims + var_cols, maintain_order=
|
|
452
|
+
df = df.group_by(dims + var_cols, maintain_order=Config.maintain_order).sum()
|
|
369
453
|
return df
|
|
370
454
|
|
|
371
455
|
|
|
372
456
|
def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
373
|
-
"""
|
|
374
|
-
Removes the quadratic column and terms with a zero coefficient, when applicable.
|
|
457
|
+
"""Removes the quadratic column and terms with a zero coefficient, when applicable.
|
|
375
458
|
|
|
376
459
|
Specifically, zero coefficient terms are always removed, except if they're the only terms in which case the expression contains a single term.
|
|
377
460
|
The quadratic column is removed if the expression is not a quadratic.
|
|
378
461
|
|
|
379
462
|
Examples:
|
|
380
|
-
|
|
381
463
|
>>> import polars as pl
|
|
382
|
-
>>> df = pl.DataFrame(
|
|
464
|
+
>>> df = pl.DataFrame(
|
|
465
|
+
... {
|
|
466
|
+
... VAR_KEY: [CONST_TERM, 1],
|
|
467
|
+
... QUAD_VAR_KEY: [CONST_TERM, 1],
|
|
468
|
+
... COEF_KEY: [1.0, 0],
|
|
469
|
+
... }
|
|
470
|
+
... )
|
|
383
471
|
>>> _simplify_expr_df(df)
|
|
384
472
|
shape: (1, 2)
|
|
385
473
|
┌───────────────┬─────────┐
|
|
@@ -389,7 +477,21 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
389
477
|
╞═══════════════╪═════════╡
|
|
390
478
|
│ 0 ┆ 1.0 │
|
|
391
479
|
└───────────────┴─────────┘
|
|
392
|
-
>>> df = pl.DataFrame(
|
|
480
|
+
>>> df = pl.DataFrame(
|
|
481
|
+
... {
|
|
482
|
+
... "t": [1, 1, 2, 2, 3, 3],
|
|
483
|
+
... VAR_KEY: [CONST_TERM, 1, CONST_TERM, 1, 1, 2],
|
|
484
|
+
... QUAD_VAR_KEY: [
|
|
485
|
+
... CONST_TERM,
|
|
486
|
+
... CONST_TERM,
|
|
487
|
+
... CONST_TERM,
|
|
488
|
+
... CONST_TERM,
|
|
489
|
+
... CONST_TERM,
|
|
490
|
+
... 1,
|
|
491
|
+
... ],
|
|
492
|
+
... COEF_KEY: [1, 0, 0, 0, 1, 0],
|
|
493
|
+
... }
|
|
494
|
+
... )
|
|
393
495
|
>>> _simplify_expr_df(df)
|
|
394
496
|
shape: (3, 3)
|
|
395
497
|
┌─────┬───────────────┬─────────┐
|
|
@@ -406,9 +508,14 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
406
508
|
if len(df_filtered) < len(df):
|
|
407
509
|
dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
|
|
408
510
|
if dims:
|
|
409
|
-
dim_values = df.select(dims).unique(maintain_order=
|
|
511
|
+
dim_values = df.select(dims).unique(maintain_order=Config.maintain_order)
|
|
410
512
|
df = (
|
|
411
|
-
dim_values.join(
|
|
513
|
+
dim_values.join(
|
|
514
|
+
df_filtered,
|
|
515
|
+
on=dims,
|
|
516
|
+
how="left",
|
|
517
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
518
|
+
)
|
|
412
519
|
.with_columns(pl.col(COEF_KEY).fill_null(0))
|
|
413
520
|
.fill_null(CONST_TERM)
|
|
414
521
|
)
|
|
@@ -417,7 +524,7 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
417
524
|
if df.is_empty():
|
|
418
525
|
df = pl.DataFrame(
|
|
419
526
|
{VAR_KEY: [CONST_TERM], COEF_KEY: [0]},
|
|
420
|
-
schema={VAR_KEY:
|
|
527
|
+
schema={VAR_KEY: Config.id_dtype, COEF_KEY: pl.Float64},
|
|
421
528
|
)
|
|
422
529
|
|
|
423
530
|
if QUAD_VAR_KEY in df.columns and (df.get_column(QUAD_VAR_KEY) == CONST_TERM).all():
|
|
@@ -426,10 +533,11 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
426
533
|
return df
|
|
427
534
|
|
|
428
535
|
|
|
429
|
-
def _get_dimensions(df: pl.DataFrame) ->
|
|
430
|
-
"""
|
|
431
|
-
|
|
432
|
-
|
|
536
|
+
def _get_dimensions(df: pl.DataFrame) -> list[str] | None:
|
|
537
|
+
"""Returns the dimensions of the DataFrame.
|
|
538
|
+
|
|
539
|
+
Reserved columns do not count as dimensions. If there are no dimensions,
|
|
540
|
+
returns `None` to force caller to handle this special case.
|
|
433
541
|
|
|
434
542
|
Examples:
|
|
435
543
|
>>> import polars as pl
|