pyoframe 0.2.0__py3-none-any.whl → 1.0.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyoframe/__init__.py +21 -14
- pyoframe/_arithmetic.py +265 -159
- pyoframe/_constants.py +416 -0
- pyoframe/_core.py +2575 -0
- pyoframe/_model.py +578 -0
- pyoframe/_model_element.py +175 -0
- pyoframe/_monkey_patch.py +80 -0
- pyoframe/{objective.py → _objective.py} +49 -14
- pyoframe/{util.py → _utils.py} +106 -126
- pyoframe/_version.py +16 -3
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0a0.dist-info}/METADATA +32 -25
- pyoframe-1.0.0a0.dist-info/RECORD +15 -0
- pyoframe/constants.py +0 -140
- pyoframe/core.py +0 -1794
- pyoframe/model.py +0 -408
- pyoframe/model_element.py +0 -184
- pyoframe/monkey_patch.py +0 -54
- pyoframe-0.2.0.dist-info/RECORD +0 -15
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0a0.dist-info}/WHEEL +0 -0
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0a0.dist-info}/licenses/LICENSE +0 -0
- {pyoframe-0.2.0.dist-info → pyoframe-1.0.0a0.dist-info}/top_level.txt +0 -0
pyoframe/_arithmetic.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Defines helper functions for doing arithmetic operations on expressions (e.g. addition).
|
|
3
|
-
"""
|
|
1
|
+
"""Defines helper functions for doing arithmetic operations on expressions (e.g. addition)."""
|
|
4
2
|
|
|
5
|
-
from
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
6
|
|
|
7
7
|
import polars as pl
|
|
8
8
|
|
|
9
|
-
from pyoframe.
|
|
9
|
+
from pyoframe._constants import (
|
|
10
10
|
COEF_KEY,
|
|
11
11
|
CONST_TERM,
|
|
12
12
|
KEY_TYPE,
|
|
@@ -19,81 +19,60 @@ from pyoframe.constants import (
|
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
if TYPE_CHECKING: # pragma: no cover
|
|
22
|
-
from pyoframe.
|
|
22
|
+
from pyoframe._core import Expression
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def
|
|
26
|
-
"""
|
|
27
|
-
Multiply two or more expressions together.
|
|
25
|
+
def multiply(self: Expression, other: Expression) -> Expression:
|
|
26
|
+
"""Multiplies two expressions together.
|
|
28
27
|
|
|
29
28
|
Examples:
|
|
30
29
|
>>> import pyoframe as pf
|
|
31
|
-
>>> m = pf.Model(
|
|
30
|
+
>>> m = pf.Model()
|
|
32
31
|
>>> m.x1 = pf.Variable()
|
|
33
32
|
>>> m.x2 = pf.Variable()
|
|
34
33
|
>>> m.x3 = pf.Variable()
|
|
35
34
|
>>> result = 5 * m.x1 * m.x2
|
|
36
35
|
>>> result
|
|
37
|
-
<Expression
|
|
38
|
-
5
|
|
36
|
+
<Expression terms=1 type=quadratic>
|
|
37
|
+
5 x2 * x1
|
|
39
38
|
>>> result * m.x3
|
|
40
39
|
Traceback (most recent call last):
|
|
41
40
|
...
|
|
42
|
-
pyoframe.
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
Cannot multiply a quadratic expression by a non-constant.
|
|
41
|
+
pyoframe._constants.PyoframeError: Cannot multiply the two expressions below because the result would be a cubic. Only quadratic or linear expressions are allowed.
|
|
42
|
+
Expression 1 (quadratic): ((5 * x1) * x2)
|
|
43
|
+
Expression 2 (linear): x3
|
|
46
44
|
"""
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
"
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
)
|
|
55
|
-
+ "\nDue to error:\n"
|
|
56
|
-
+ str(error)
|
|
57
|
-
) from error
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def _add_expressions(*expressions: "Expression") -> "Expression":
|
|
61
|
-
try:
|
|
62
|
-
return _add_expressions_core(*expressions)
|
|
63
|
-
except PyoframeError as error:
|
|
45
|
+
self_degree, other_degree = self.degree(), other.degree()
|
|
46
|
+
product_degree = self_degree + other_degree
|
|
47
|
+
if product_degree > 2:
|
|
48
|
+
assert product_degree <= 4, (
|
|
49
|
+
"Unexpected because expressions should not exceed degree 2."
|
|
50
|
+
)
|
|
51
|
+
res_type = "cubic" if product_degree == 3 else "quartic"
|
|
64
52
|
raise PyoframeError(
|
|
65
|
-
"
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
+ "\nDue to error:\n"
|
|
70
|
-
+ str(error)
|
|
71
|
-
) from error
|
|
53
|
+
f"""Cannot multiply the two expressions below because the result would be a {res_type}. Only quadratic or linear expressions are allowed.
|
|
54
|
+
Expression 1 ({self.degree(return_str=True)}):\t{self.name}
|
|
55
|
+
Expression 2 ({other.degree(return_str=True)}):\t{other.name}"""
|
|
56
|
+
)
|
|
72
57
|
|
|
58
|
+
if self_degree == 1 and other_degree == 1:
|
|
59
|
+
return _quadratic_multiplication(self, other)
|
|
73
60
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if self_degree + other_degree > 2:
|
|
77
|
-
# We know one of the two must be a quadratic since 1 + 1 is not greater than 2.
|
|
78
|
-
raise PyoframeError("Cannot multiply a quadratic expression by a non-constant.")
|
|
61
|
+
# save names to use in debug messages before any swapping occurs
|
|
62
|
+
self_name, other_name = self.name, other.name
|
|
79
63
|
if self_degree < other_degree:
|
|
80
64
|
self, other = other, self
|
|
81
65
|
self_degree, other_degree = other_degree, self_degree
|
|
82
|
-
if other_degree == 1:
|
|
83
|
-
assert self_degree == 1, (
|
|
84
|
-
"This should always be true since the sum of degrees must be <=2."
|
|
85
|
-
)
|
|
86
|
-
return _quadratic_multiplication(self, other)
|
|
87
66
|
|
|
88
67
|
assert other_degree == 0, (
|
|
89
68
|
"This should always be true since other cases have already been handled."
|
|
90
69
|
)
|
|
91
|
-
multiplier = other.data.drop(
|
|
92
|
-
VAR_KEY
|
|
93
|
-
) # QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
|
|
94
70
|
|
|
95
|
-
|
|
96
|
-
|
|
71
|
+
# QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
|
|
72
|
+
multiplier = other.data.drop(VAR_KEY)
|
|
73
|
+
|
|
74
|
+
dims = self._dimensions_unsafe
|
|
75
|
+
other_dims = other._dimensions_unsafe
|
|
97
76
|
dims_in_common = [dim for dim in dims if dim in other_dims]
|
|
98
77
|
|
|
99
78
|
data = (
|
|
@@ -101,17 +80,19 @@ def _multiply_expressions_core(self: "Expression", other: "Expression") -> "Expr
|
|
|
101
80
|
multiplier,
|
|
102
81
|
on=dims_in_common if len(dims_in_common) > 0 else None,
|
|
103
82
|
how="inner" if dims_in_common else "cross",
|
|
83
|
+
maintain_order=(
|
|
84
|
+
"left" if Config.maintain_order and dims_in_common else None
|
|
85
|
+
),
|
|
104
86
|
)
|
|
105
87
|
.with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
|
|
106
88
|
.drop(COEF_KEY + "_right")
|
|
107
89
|
)
|
|
108
90
|
|
|
109
|
-
return self._new(data)
|
|
91
|
+
return self._new(data, name=f"({self_name} * {other_name})")
|
|
110
92
|
|
|
111
93
|
|
|
112
|
-
def _quadratic_multiplication(self:
|
|
113
|
-
"""
|
|
114
|
-
Multiply two expressions of degree 1.
|
|
94
|
+
def _quadratic_multiplication(self: Expression, other: Expression) -> Expression:
|
|
95
|
+
"""Multiplies two expressions of degree 1.
|
|
115
96
|
|
|
116
97
|
Examples:
|
|
117
98
|
>>> import polars as pl
|
|
@@ -122,18 +103,29 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
|
|
|
122
103
|
>>> expr1 = df * m.x1
|
|
123
104
|
>>> expr2 = df * m.x2 * 2 + 4
|
|
124
105
|
>>> expr1 * expr2
|
|
125
|
-
<Expression
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
106
|
+
<Expression height=3 terms=6 type=quadratic>
|
|
107
|
+
┌─────┬───────────────────┐
|
|
108
|
+
│ dim ┆ expression │
|
|
109
|
+
│ (3) ┆ │
|
|
110
|
+
╞═════╪═══════════════════╡
|
|
111
|
+
│ 1 ┆ 4 x1 +2 x2 * x1 │
|
|
112
|
+
│ 2 ┆ 8 x1 +8 x2 * x1 │
|
|
113
|
+
│ 3 ┆ 12 x1 +18 x2 * x1 │
|
|
114
|
+
└─────┴───────────────────┘
|
|
129
115
|
>>> (expr1 * expr2) - df * m.x1 * df * m.x2 * 2
|
|
130
|
-
<Expression
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
116
|
+
<Expression height=3 terms=3 type=linear>
|
|
117
|
+
┌─────┬────────────┐
|
|
118
|
+
│ dim ┆ expression │
|
|
119
|
+
│ (3) ┆ │
|
|
120
|
+
╞═════╪════════════╡
|
|
121
|
+
│ 1 ┆ 4 x1 │
|
|
122
|
+
│ 2 ┆ 8 x1 │
|
|
123
|
+
│ 3 ┆ 12 x1 │
|
|
124
|
+
└─────┴────────────┘
|
|
125
|
+
|
|
134
126
|
"""
|
|
135
|
-
dims = self.
|
|
136
|
-
other_dims = other.
|
|
127
|
+
dims = self._dimensions_unsafe
|
|
128
|
+
other_dims = other._dimensions_unsafe
|
|
137
129
|
dims_in_common = [dim for dim in dims if dim in other_dims]
|
|
138
130
|
|
|
139
131
|
data = (
|
|
@@ -141,11 +133,14 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
|
|
|
141
133
|
other.data,
|
|
142
134
|
on=dims_in_common if len(dims_in_common) > 0 else None,
|
|
143
135
|
how="inner" if dims_in_common else "cross",
|
|
136
|
+
maintain_order=(
|
|
137
|
+
"left" if Config.maintain_order and dims_in_common else None
|
|
138
|
+
),
|
|
144
139
|
)
|
|
145
140
|
.with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
|
|
146
141
|
.drop(COEF_KEY + "_right")
|
|
147
142
|
.rename({VAR_KEY + "_right": QUAD_VAR_KEY})
|
|
148
|
-
# Swap VAR_KEY and QUAD_VAR_KEY so that
|
|
143
|
+
# Swap VAR_KEY and QUAD_VAR_KEY so that VAR_KEY is always the larger one
|
|
149
144
|
.with_columns(
|
|
150
145
|
pl.when(pl.col(VAR_KEY) < pl.col(QUAD_VAR_KEY))
|
|
151
146
|
.then(pl.col(QUAD_VAR_KEY))
|
|
@@ -160,12 +155,13 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
|
|
|
160
155
|
|
|
161
156
|
data = _sum_like_terms(data)
|
|
162
157
|
|
|
163
|
-
return self._new(data)
|
|
158
|
+
return self._new(data, name=f"({self.name} * {other.name})")
|
|
164
159
|
|
|
165
160
|
|
|
166
|
-
def
|
|
167
|
-
|
|
168
|
-
|
|
161
|
+
def add(*expressions: Expression) -> Expression:
|
|
162
|
+
"""Add multiple expressions together."""
|
|
163
|
+
# Mapping of how a sum of two expressions should propagate the unmatched strategy
|
|
164
|
+
propagation_strategies = {
|
|
169
165
|
(UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
|
|
170
166
|
(
|
|
171
167
|
UnmatchedStrategy.UNSET,
|
|
@@ -186,40 +182,62 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
|
|
|
186
182
|
dims = []
|
|
187
183
|
elif Config.disable_unmatched_checks:
|
|
188
184
|
requires_join = any(
|
|
189
|
-
expr.
|
|
185
|
+
expr._unmatched_strategy
|
|
190
186
|
not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
|
|
191
187
|
for expr in expressions
|
|
192
188
|
)
|
|
193
189
|
else:
|
|
194
190
|
requires_join = any(
|
|
195
|
-
expr.
|
|
191
|
+
expr._unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
|
|
196
192
|
)
|
|
197
193
|
|
|
198
194
|
has_dim_conflict = any(
|
|
199
|
-
sorted(dims) != sorted(expr.
|
|
195
|
+
sorted(dims) != sorted(expr._dimensions_unsafe) for expr in expressions[1:]
|
|
200
196
|
)
|
|
201
197
|
|
|
202
198
|
# If we cannot use .concat compute the sum in a pairwise manner
|
|
203
199
|
if len(expressions) > 2 and (has_dim_conflict or requires_join):
|
|
204
200
|
result = expressions[0]
|
|
205
201
|
for expr in expressions[1:]:
|
|
206
|
-
result =
|
|
202
|
+
result = add(result, expr)
|
|
207
203
|
return result
|
|
208
204
|
|
|
209
205
|
if has_dim_conflict:
|
|
210
206
|
assert len(expressions) == 2
|
|
211
|
-
expressions = (
|
|
212
|
-
_add_dimension(expressions[0], expressions[1]),
|
|
213
|
-
_add_dimension(expressions[1], expressions[0]),
|
|
214
|
-
)
|
|
215
|
-
assert sorted(expressions[0].dimensions_unsafe) == sorted(
|
|
216
|
-
expressions[1].dimensions_unsafe
|
|
217
|
-
)
|
|
218
207
|
|
|
219
|
-
|
|
208
|
+
left, right = expressions[0], expressions[1]
|
|
209
|
+
left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
|
|
210
|
+
|
|
211
|
+
missing_left = [dim for dim in right_dims if dim not in left_dims]
|
|
212
|
+
missing_right = [dim for dim in left_dims if dim not in right_dims]
|
|
213
|
+
common_dims = [dim for dim in left_dims if dim in right_dims]
|
|
214
|
+
|
|
215
|
+
if not (
|
|
216
|
+
set(missing_left) <= set(left._allowed_new_dims)
|
|
217
|
+
and set(missing_right) <= set(right._allowed_new_dims)
|
|
218
|
+
):
|
|
219
|
+
_raise_addition_error(
|
|
220
|
+
left,
|
|
221
|
+
right,
|
|
222
|
+
f"their\n\tdimensions are different ({left_dims} != {right_dims})",
|
|
223
|
+
"If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/learn/concepts/special-functions/#adding-expressions-with-differing-dimensions-using-over",
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
left_old = left
|
|
227
|
+
if missing_left:
|
|
228
|
+
left = _broadcast(left, right, common_dims, missing_left)
|
|
229
|
+
if missing_right:
|
|
230
|
+
right = _broadcast(
|
|
231
|
+
right, left_old, common_dims, missing_right, swapped=True
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
|
|
235
|
+
expressions = (left, right)
|
|
236
|
+
|
|
237
|
+
dims = expressions[0]._dimensions_unsafe
|
|
220
238
|
# Check no dims conflict
|
|
221
239
|
assert all(
|
|
222
|
-
sorted(dims) == sorted(expr.
|
|
240
|
+
sorted(dims) == sorted(expr._dimensions_unsafe) for expr in expressions[1:]
|
|
223
241
|
)
|
|
224
242
|
if requires_join:
|
|
225
243
|
assert len(expressions) == 2
|
|
@@ -227,25 +245,36 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
|
|
|
227
245
|
left, right = expressions[0], expressions[1]
|
|
228
246
|
|
|
229
247
|
# Order so that drop always comes before keep, and keep always comes before default
|
|
230
|
-
if
|
|
231
|
-
(
|
|
232
|
-
(
|
|
233
|
-
|
|
248
|
+
if swap := (
|
|
249
|
+
(left._unmatched_strategy, right._unmatched_strategy)
|
|
250
|
+
in (
|
|
251
|
+
(UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
|
|
252
|
+
(UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
|
|
253
|
+
(UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
|
|
254
|
+
)
|
|
234
255
|
):
|
|
235
256
|
left, right = right, left
|
|
236
257
|
|
|
237
258
|
def get_indices(expr):
|
|
238
|
-
return expr.data.select(dims).unique(maintain_order=
|
|
259
|
+
return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
|
|
239
260
|
|
|
240
261
|
left_data, right_data = left.data, right.data
|
|
241
262
|
|
|
242
|
-
strat = (left.
|
|
263
|
+
strat = (left._unmatched_strategy, right._unmatched_strategy)
|
|
243
264
|
|
|
244
|
-
|
|
265
|
+
propagate_strat = propagation_strategies[strat] # type: ignore
|
|
245
266
|
|
|
246
267
|
if strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP):
|
|
247
|
-
left_data = left.data.join(
|
|
248
|
-
|
|
268
|
+
left_data = left.data.join(
|
|
269
|
+
get_indices(right),
|
|
270
|
+
on=dims,
|
|
271
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
272
|
+
)
|
|
273
|
+
right_data = right.data.join(
|
|
274
|
+
get_indices(left),
|
|
275
|
+
on=dims,
|
|
276
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
277
|
+
)
|
|
249
278
|
elif strat == (UnmatchedStrategy.UNSET, UnmatchedStrategy.UNSET):
|
|
250
279
|
assert not Config.disable_unmatched_checks, (
|
|
251
280
|
"This code should not be reached when unmatched checks are disabled."
|
|
@@ -254,46 +283,54 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
|
|
|
254
283
|
get_indices(right),
|
|
255
284
|
how="full",
|
|
256
285
|
on=dims,
|
|
286
|
+
maintain_order="left_right" if Config.maintain_order else None,
|
|
257
287
|
)
|
|
258
288
|
if outer_join.get_column(dims[0]).null_count() > 0:
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
+ str(outer_join.filter(outer_join.get_column(dims[0]).is_null()))
|
|
289
|
+
unmatched_vals = outer_join.filter(
|
|
290
|
+
outer_join.get_column(dims[0]).is_null()
|
|
262
291
|
)
|
|
292
|
+
_raise_unmatched_values_error(left, right, unmatched_vals, swap)
|
|
263
293
|
if outer_join.get_column(dims[0] + "_right").null_count() > 0:
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
+ str(
|
|
267
|
-
outer_join.filter(
|
|
268
|
-
outer_join.get_column(dims[0] + "_right").is_null()
|
|
269
|
-
)
|
|
270
|
-
)
|
|
294
|
+
unmatched_vals = outer_join.filter(
|
|
295
|
+
outer_join.get_column(dims[0] + "_right").is_null()
|
|
271
296
|
)
|
|
297
|
+
_raise_unmatched_values_error(left, right, unmatched_vals, swap)
|
|
298
|
+
|
|
272
299
|
elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP):
|
|
273
|
-
left_data = get_indices(right).join(
|
|
300
|
+
left_data = get_indices(right).join(
|
|
301
|
+
left.data,
|
|
302
|
+
how="left",
|
|
303
|
+
on=dims,
|
|
304
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
305
|
+
)
|
|
274
306
|
elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
|
|
275
|
-
left_data = get_indices(right).join(
|
|
307
|
+
left_data = get_indices(right).join(
|
|
308
|
+
left.data,
|
|
309
|
+
how="left",
|
|
310
|
+
on=dims,
|
|
311
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
312
|
+
)
|
|
276
313
|
if left_data.get_column(COEF_KEY).null_count() > 0:
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
314
|
+
_raise_unmatched_values_error(
|
|
315
|
+
left,
|
|
316
|
+
right,
|
|
317
|
+
left_data.filter(left_data.get_column(COEF_KEY).is_null()),
|
|
318
|
+
swap,
|
|
280
319
|
)
|
|
320
|
+
|
|
281
321
|
elif strat == (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET):
|
|
282
322
|
assert not Config.disable_unmatched_checks, (
|
|
283
323
|
"This code should not be reached when unmatched checks are disabled."
|
|
284
324
|
)
|
|
285
325
|
unmatched = right.data.join(get_indices(left), how="anti", on=dims)
|
|
286
326
|
if len(unmatched) > 0:
|
|
287
|
-
|
|
288
|
-
"Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
|
|
289
|
-
+ str(unmatched)
|
|
290
|
-
)
|
|
327
|
+
_raise_unmatched_values_error(left, right, unmatched, swap)
|
|
291
328
|
else: # pragma: no cover
|
|
292
329
|
assert False, "This code should've never been reached!"
|
|
293
330
|
|
|
294
331
|
expr_data = [left_data, right_data]
|
|
295
332
|
else:
|
|
296
|
-
|
|
333
|
+
propagate_strat = expressions[0]._unmatched_strategy
|
|
297
334
|
expr_data = [expr.data for expr in expressions]
|
|
298
335
|
|
|
299
336
|
# Add quadratic column if it is needed and doesn't already exist
|
|
@@ -315,71 +352,120 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
|
|
|
315
352
|
data = pl.concat(expr_data, how="vertical_relaxed")
|
|
316
353
|
data = _sum_like_terms(data)
|
|
317
354
|
|
|
318
|
-
|
|
319
|
-
|
|
355
|
+
full_name = expressions[0].name
|
|
356
|
+
for expr in expressions[1:]:
|
|
357
|
+
name = expr.name
|
|
358
|
+
full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
|
|
359
|
+
|
|
360
|
+
new_expr = expressions[0]._new(data, name=f"({full_name})")
|
|
361
|
+
new_expr._unmatched_strategy = propagate_strat
|
|
320
362
|
|
|
321
363
|
return new_expr
|
|
322
364
|
|
|
323
365
|
|
|
324
|
-
def
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
366
|
+
def _raise_unmatched_values_error(
|
|
367
|
+
left: Expression, right: Expression, unmatched_values: pl.DataFrame, swapped: bool
|
|
368
|
+
):
|
|
369
|
+
if swapped:
|
|
370
|
+
left, right = right, left
|
|
371
|
+
|
|
372
|
+
_raise_addition_error(
|
|
373
|
+
left,
|
|
374
|
+
right,
|
|
375
|
+
"of unmatched values",
|
|
376
|
+
f"Unmatched values:\n{unmatched_values}\nIf this is intentional, use .drop_unmatched() or .keep_unmatched().",
|
|
377
|
+
)
|
|
335
378
|
|
|
336
|
-
# We're already at the size of our target
|
|
337
|
-
if not missing_dims:
|
|
338
|
-
return self
|
|
339
379
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
380
|
+
def _raise_addition_error(
|
|
381
|
+
left: Expression, right: Expression, reason: str, postfix: str
|
|
382
|
+
):
|
|
383
|
+
op = "add"
|
|
384
|
+
right_name = right.name
|
|
385
|
+
if right_name[0] == "-":
|
|
386
|
+
op = "subtract"
|
|
387
|
+
right_name = right_name[1:]
|
|
388
|
+
raise PyoframeError(
|
|
389
|
+
f"""Cannot {op} the two expressions below because {reason}.
|
|
390
|
+
Expression 1:\t{left.name}
|
|
391
|
+
Expression 2:\t{right_name}
|
|
392
|
+
{postfix}
|
|
393
|
+
"""
|
|
394
|
+
)
|
|
395
|
+
|
|
345
396
|
|
|
346
|
-
|
|
397
|
+
# TODO consider returning a dataframe instead of an expression to simplify code (e.g. avoid copy_flags)
|
|
398
|
+
def _broadcast(
|
|
399
|
+
self: Expression,
|
|
400
|
+
target: Expression,
|
|
401
|
+
common_dims: list[str],
|
|
402
|
+
missing_dims: list[str],
|
|
403
|
+
swapped: bool = False,
|
|
404
|
+
) -> Expression:
|
|
405
|
+
target_data = target.data.select(target._dimensions_unsafe).unique(
|
|
406
|
+
maintain_order=Config.maintain_order
|
|
407
|
+
)
|
|
347
408
|
|
|
348
|
-
if not
|
|
349
|
-
|
|
409
|
+
if not common_dims:
|
|
410
|
+
res = self._new(self.data.join(target_data, how="cross"), name=self.name)
|
|
411
|
+
res._copy_flags(self)
|
|
412
|
+
return res
|
|
350
413
|
|
|
351
414
|
# If drop, we just do an inner join to get into the shape of the other
|
|
352
|
-
if self.
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
415
|
+
if self._unmatched_strategy == UnmatchedStrategy.DROP:
|
|
416
|
+
res = self._new(
|
|
417
|
+
self.data.join(
|
|
418
|
+
target_data,
|
|
419
|
+
on=common_dims,
|
|
420
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
421
|
+
),
|
|
422
|
+
name=self.name,
|
|
423
|
+
)
|
|
424
|
+
res._copy_flags(self)
|
|
425
|
+
return res
|
|
426
|
+
|
|
427
|
+
result = self.data.join(
|
|
428
|
+
target_data,
|
|
429
|
+
on=common_dims,
|
|
430
|
+
how="left",
|
|
431
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
432
|
+
)
|
|
356
433
|
right_has_missing = result.get_column(missing_dims[0]).null_count() > 0
|
|
357
434
|
if right_has_missing:
|
|
358
|
-
|
|
359
|
-
|
|
435
|
+
_raise_unmatched_values_error(
|
|
436
|
+
self,
|
|
437
|
+
target,
|
|
438
|
+
result.filter(result.get_column(missing_dims[0]).is_null()),
|
|
439
|
+
swapped,
|
|
360
440
|
)
|
|
361
|
-
|
|
441
|
+
res = self._new(result, self.name)
|
|
442
|
+
res._copy_flags(self)
|
|
443
|
+
return res
|
|
362
444
|
|
|
363
445
|
|
|
364
446
|
def _sum_like_terms(df: pl.DataFrame) -> pl.DataFrame:
|
|
365
447
|
"""Combines terms with the same variables."""
|
|
366
448
|
dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
|
|
367
449
|
var_cols = [VAR_KEY] + ([QUAD_VAR_KEY] if QUAD_VAR_KEY in df.columns else [])
|
|
368
|
-
df = df.group_by(dims + var_cols, maintain_order=
|
|
450
|
+
df = df.group_by(dims + var_cols, maintain_order=Config.maintain_order).sum()
|
|
369
451
|
return df
|
|
370
452
|
|
|
371
453
|
|
|
372
454
|
def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
373
|
-
"""
|
|
374
|
-
Removes the quadratic column and terms with a zero coefficient, when applicable.
|
|
455
|
+
"""Removes the quadratic column and terms with a zero coefficient, when applicable.
|
|
375
456
|
|
|
376
457
|
Specifically, zero coefficient terms are always removed, except if they're the only terms in which case the expression contains a single term.
|
|
377
458
|
The quadratic column is removed if the expression is not a quadratic.
|
|
378
459
|
|
|
379
460
|
Examples:
|
|
380
|
-
|
|
381
461
|
>>> import polars as pl
|
|
382
|
-
>>> df = pl.DataFrame(
|
|
462
|
+
>>> df = pl.DataFrame(
|
|
463
|
+
... {
|
|
464
|
+
... VAR_KEY: [CONST_TERM, 1],
|
|
465
|
+
... QUAD_VAR_KEY: [CONST_TERM, 1],
|
|
466
|
+
... COEF_KEY: [1.0, 0],
|
|
467
|
+
... }
|
|
468
|
+
... )
|
|
383
469
|
>>> _simplify_expr_df(df)
|
|
384
470
|
shape: (1, 2)
|
|
385
471
|
┌───────────────┬─────────┐
|
|
@@ -389,7 +475,21 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
389
475
|
╞═══════════════╪═════════╡
|
|
390
476
|
│ 0 ┆ 1.0 │
|
|
391
477
|
└───────────────┴─────────┘
|
|
392
|
-
>>> df = pl.DataFrame(
|
|
478
|
+
>>> df = pl.DataFrame(
|
|
479
|
+
... {
|
|
480
|
+
... "t": [1, 1, 2, 2, 3, 3],
|
|
481
|
+
... VAR_KEY: [CONST_TERM, 1, CONST_TERM, 1, 1, 2],
|
|
482
|
+
... QUAD_VAR_KEY: [
|
|
483
|
+
... CONST_TERM,
|
|
484
|
+
... CONST_TERM,
|
|
485
|
+
... CONST_TERM,
|
|
486
|
+
... CONST_TERM,
|
|
487
|
+
... CONST_TERM,
|
|
488
|
+
... 1,
|
|
489
|
+
... ],
|
|
490
|
+
... COEF_KEY: [1, 0, 0, 0, 1, 0],
|
|
491
|
+
... }
|
|
492
|
+
... )
|
|
393
493
|
>>> _simplify_expr_df(df)
|
|
394
494
|
shape: (3, 3)
|
|
395
495
|
┌─────┬───────────────┬─────────┐
|
|
@@ -406,9 +506,14 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
406
506
|
if len(df_filtered) < len(df):
|
|
407
507
|
dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
|
|
408
508
|
if dims:
|
|
409
|
-
dim_values = df.select(dims).unique(maintain_order=
|
|
509
|
+
dim_values = df.select(dims).unique(maintain_order=Config.maintain_order)
|
|
410
510
|
df = (
|
|
411
|
-
dim_values.join(
|
|
511
|
+
dim_values.join(
|
|
512
|
+
df_filtered,
|
|
513
|
+
on=dims,
|
|
514
|
+
how="left",
|
|
515
|
+
maintain_order="left" if Config.maintain_order else None,
|
|
516
|
+
)
|
|
412
517
|
.with_columns(pl.col(COEF_KEY).fill_null(0))
|
|
413
518
|
.fill_null(CONST_TERM)
|
|
414
519
|
)
|
|
@@ -426,10 +531,11 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
426
531
|
return df
|
|
427
532
|
|
|
428
533
|
|
|
429
|
-
def _get_dimensions(df: pl.DataFrame) ->
|
|
430
|
-
"""
|
|
431
|
-
|
|
432
|
-
|
|
534
|
+
def _get_dimensions(df: pl.DataFrame) -> list[str] | None:
|
|
535
|
+
"""Returns the dimensions of the DataFrame.
|
|
536
|
+
|
|
537
|
+
Reserved columns do not count as dimensions. If there are no dimensions,
|
|
538
|
+
returns `None` to force caller to handle this special case.
|
|
433
539
|
|
|
434
540
|
Examples:
|
|
435
541
|
>>> import polars as pl
|