pyoframe 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyoframe/_arithmetic.py CHANGED
@@ -1,99 +1,87 @@
1
- """
2
- Defines helper functions for doing arithmetic operations on expressions (e.g. addition).
3
- """
1
+ """Defines helper functions for doing arithmetic operations on expressions (e.g. addition)."""
2
+
3
+ from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, List, Optional
5
+ from typing import TYPE_CHECKING
6
6
 
7
7
  import polars as pl
8
8
 
9
- from pyoframe.constants import (
9
+ from pyoframe._constants import (
10
10
  COEF_KEY,
11
11
  CONST_TERM,
12
- KEY_TYPE,
13
12
  QUAD_VAR_KEY,
14
13
  RESERVED_COL_KEYS,
15
14
  VAR_KEY,
16
15
  Config,
16
+ ExtrasStrategy,
17
17
  PyoframeError,
18
- UnmatchedStrategy,
19
18
  )
20
19
 
21
20
  if TYPE_CHECKING: # pragma: no cover
22
- from pyoframe.core import Expression
21
+ from pyoframe._core import Expression
23
22
 
23
+ # Mapping of how a sum of two expressions should propagate the extras strategy
24
+ _extras_propagation_rules = {
25
+ (ExtrasStrategy.DROP, ExtrasStrategy.DROP): ExtrasStrategy.DROP,
26
+ (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET): ExtrasStrategy.UNSET,
27
+ (ExtrasStrategy.KEEP, ExtrasStrategy.KEEP): ExtrasStrategy.KEEP,
28
+ (ExtrasStrategy.DROP, ExtrasStrategy.KEEP): ExtrasStrategy.UNSET,
29
+ (ExtrasStrategy.DROP, ExtrasStrategy.UNSET): ExtrasStrategy.DROP,
30
+ (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET): ExtrasStrategy.KEEP,
31
+ }
24
32
 
25
- def _multiply_expressions(self: "Expression", other: "Expression") -> "Expression":
26
- """
27
- Multiply two or more expressions together.
33
+
34
+ def multiply(self: Expression, other: Expression) -> Expression:
35
+ """Multiplies two expressions together.
28
36
 
29
37
  Examples:
30
38
  >>> import pyoframe as pf
31
- >>> m = pf.Model("min")
39
+ >>> m = pf.Model()
32
40
  >>> m.x1 = pf.Variable()
33
41
  >>> m.x2 = pf.Variable()
34
42
  >>> m.x3 = pf.Variable()
35
43
  >>> result = 5 * m.x1 * m.x2
36
44
  >>> result
37
- <Expression size=1 dimensions={} terms=1 degree=2>
38
- 5 x2 * x1
45
+ <Expression terms=1 type=quadratic>
46
+ 5 x2 * x1
39
47
  >>> result * m.x3
40
48
  Traceback (most recent call last):
41
49
  ...
42
- pyoframe.constants.PyoframeError: Failed to multiply expressions:
43
- <Expression size=1 dimensions={} terms=1 degree=2> * <Expression size=1 dimensions={} terms=1>
44
- Due to error:
45
- Cannot multiply a quadratic expression by a non-constant.
50
+ pyoframe._constants.PyoframeError: Cannot multiply the two expressions below because the result would be a cubic. Only quadratic or linear expressions are allowed.
51
+ Expression 1 (quadratic): ((5 * x1) * x2)
52
+ Expression 2 (linear): x3
46
53
  """
47
- try:
48
- return _multiply_expressions_core(self, other)
49
- except PyoframeError as error:
50
- raise PyoframeError(
51
- "Failed to multiply expressions:\n"
52
- + " * ".join(
53
- e.to_str(include_header=True, include_data=False) for e in [self, other]
54
- )
55
- + "\nDue to error:\n"
56
- + str(error)
57
- ) from error
58
-
59
-
60
- def _add_expressions(*expressions: "Expression") -> "Expression":
61
- try:
62
- return _add_expressions_core(*expressions)
63
- except PyoframeError as error:
54
+ self_degree, other_degree = self.degree(), other.degree()
55
+ product_degree = self_degree + other_degree
56
+ if product_degree > 2:
57
+ assert product_degree <= 4, (
58
+ "Unexpected because expressions should not exceed degree 2."
59
+ )
60
+ res_type = "cubic" if product_degree == 3 else "quartic"
64
61
  raise PyoframeError(
65
- "Failed to add expressions:\n"
66
- + " + ".join(
67
- e.to_str(include_header=True, include_data=False) for e in expressions
68
- )
69
- + "\nDue to error:\n"
70
- + str(error)
71
- ) from error
62
+ f"""Cannot multiply the two expressions below because the result would be a {res_type}. Only quadratic or linear expressions are allowed.
63
+ Expression 1 ({self.degree(return_str=True)}):\t{self.name}
64
+ Expression 2 ({other.degree(return_str=True)}):\t{other.name}"""
65
+ )
72
66
 
67
+ if self_degree == 1 and other_degree == 1:
68
+ return _quadratic_multiplication(self, other)
73
69
 
74
- def _multiply_expressions_core(self: "Expression", other: "Expression") -> "Expression":
75
- self_degree, other_degree = self.degree(), other.degree()
76
- if self_degree + other_degree > 2:
77
- # We know one of the two must be a quadratic since 1 + 1 is not greater than 2.
78
- raise PyoframeError("Cannot multiply a quadratic expression by a non-constant.")
70
+ # save names to use in debug messages before any swapping occurs
71
+ self_name, other_name = self.name, other.name
79
72
  if self_degree < other_degree:
80
73
  self, other = other, self
81
74
  self_degree, other_degree = other_degree, self_degree
82
- if other_degree == 1:
83
- assert self_degree == 1, (
84
- "This should always be true since the sum of degrees must be <=2."
85
- )
86
- return _quadratic_multiplication(self, other)
87
75
 
88
76
  assert other_degree == 0, (
89
77
  "This should always be true since other cases have already been handled."
90
78
  )
91
- multiplier = other.data.drop(
92
- VAR_KEY
93
- ) # QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
94
79
 
95
- dims = self.dimensions_unsafe
96
- other_dims = other.dimensions_unsafe
80
+ # QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
81
+ multiplier = other.data.drop(VAR_KEY)
82
+
83
+ dims = self._dimensions_unsafe
84
+ other_dims = other._dimensions_unsafe
97
85
  dims_in_common = [dim for dim in dims if dim in other_dims]
98
86
 
99
87
  data = (
@@ -101,17 +89,19 @@ def _multiply_expressions_core(self: "Expression", other: "Expression") -> "Expr
101
89
  multiplier,
102
90
  on=dims_in_common if len(dims_in_common) > 0 else None,
103
91
  how="inner" if dims_in_common else "cross",
92
+ maintain_order=(
93
+ "left" if Config.maintain_order and dims_in_common else None
94
+ ),
104
95
  )
105
96
  .with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
106
97
  .drop(COEF_KEY + "_right")
107
98
  )
108
99
 
109
- return self._new(data)
100
+ return self._new(data, name=f"({self_name} * {other_name})")
110
101
 
111
102
 
112
- def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expression":
113
- """
114
- Multiply two expressions of degree 1.
103
+ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression:
104
+ """Multiplies two expressions of degree 1.
115
105
 
116
106
  Examples:
117
107
  >>> import polars as pl
@@ -122,18 +112,29 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
122
112
  >>> expr1 = df * m.x1
123
113
  >>> expr2 = df * m.x2 * 2 + 4
124
114
  >>> expr1 * expr2
125
- <Expression size=3 dimensions={'dim': 3} terms=6 degree=2>
126
- [1]: 4 x1 +2 x2 * x1
127
- [2]: 8 x1 +8 x2 * x1
128
- [3]: 12 x1 +18 x2 * x1
115
+ <Expression height=3 terms=6 type=quadratic>
116
+ ┌─────┬───────────────────┐
117
+ │ dim ┆ expression │
118
+ │ (3) ┆ │
119
+ ╞═════╪═══════════════════╡
120
+ │ 1 ┆ 4 x1 +2 x2 * x1 │
121
+ │ 2 ┆ 8 x1 +8 x2 * x1 │
122
+ │ 3 ┆ 12 x1 +18 x2 * x1 │
123
+ └─────┴───────────────────┘
129
124
  >>> (expr1 * expr2) - df * m.x1 * df * m.x2 * 2
130
- <Expression size=3 dimensions={'dim': 3} terms=3>
131
- [1]: 4 x1
132
- [2]: 8 x1
133
- [3]: 12 x1
125
+ <Expression height=3 terms=3 type=linear>
126
+ ┌─────┬────────────┐
127
+ │ dim ┆ expression │
128
+ │ (3) ┆ │
129
+ ╞═════╪════════════╡
130
+ │ 1 ┆ 4 x1 │
131
+ │ 2 ┆ 8 x1 │
132
+ │ 3 ┆ 12 x1 │
133
+ └─────┴────────────┘
134
+
134
135
  """
135
- dims = self.dimensions_unsafe
136
- other_dims = other.dimensions_unsafe
136
+ dims = self._dimensions_unsafe
137
+ other_dims = other._dimensions_unsafe
137
138
  dims_in_common = [dim for dim in dims if dim in other_dims]
138
139
 
139
140
  data = (
@@ -141,11 +142,14 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
141
142
  other.data,
142
143
  on=dims_in_common if len(dims_in_common) > 0 else None,
143
144
  how="inner" if dims_in_common else "cross",
145
+ maintain_order=(
146
+ "left" if Config.maintain_order and dims_in_common else None
147
+ ),
144
148
  )
145
149
  .with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
146
150
  .drop(COEF_KEY + "_right")
147
151
  .rename({VAR_KEY + "_right": QUAD_VAR_KEY})
148
- # Swap VAR_KEY and QUAD_VAR_KEY so that VAR_KEy is always the larger one
152
+ # Swap VAR_KEY and QUAD_VAR_KEY so that VAR_KEY is always the larger one
149
153
  .with_columns(
150
154
  pl.when(pl.col(VAR_KEY) < pl.col(QUAD_VAR_KEY))
151
155
  .then(pl.col(QUAD_VAR_KEY))
@@ -160,147 +164,87 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
160
164
 
161
165
  data = _sum_like_terms(data)
162
166
 
163
- return self._new(data)
164
-
167
+ return self._new(data, name=f"({self.name} * {other.name})")
165
168
 
166
- def _add_expressions_core(*expressions: "Expression") -> "Expression":
167
- # Mapping of how a sum of two expressions should propogate the unmatched strategy
168
- propogatation_strategies = {
169
- (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
170
- (
171
- UnmatchedStrategy.UNSET,
172
- UnmatchedStrategy.UNSET,
173
- ): UnmatchedStrategy.UNSET,
174
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.KEEP): UnmatchedStrategy.KEEP,
175
- (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP): UnmatchedStrategy.UNSET,
176
- (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET): UnmatchedStrategy.DROP,
177
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET): UnmatchedStrategy.KEEP,
178
- }
179
169
 
170
+ def add(*expressions: Expression) -> Expression:
171
+ """Add multiple expressions together."""
180
172
  assert len(expressions) > 1, "Need at least two expressions to add together."
181
173
 
182
- dims = expressions[0].dimensions
183
-
184
- if dims is None:
185
- requires_join = False
186
- dims = []
187
- elif Config.disable_unmatched_checks:
188
- requires_join = any(
189
- expr.unmatched_strategy
190
- not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
191
- for expr in expressions
192
- )
174
+ if Config.disable_extras_checks:
175
+ no_checks_strats = (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET)
193
176
  else:
194
- requires_join = any(
195
- expr.unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
196
- )
177
+ no_checks_strats = (ExtrasStrategy.KEEP,)
197
178
 
198
- has_dim_conflict = any(
199
- sorted(dims) != sorted(expr.dimensions_unsafe) for expr in expressions[1:]
179
+ no_extras_checks_required = (
180
+ all(expr._extras_strategy in no_checks_strats for expr in expressions)
181
+ # if only one dimensioned, then there is no such thing as extra labels,
182
+ # labels will be set by the only dimensioned expression
183
+ or sum(not expr.dimensionless for expr in expressions) <= 1
200
184
  )
201
185
 
202
- # If we cannot use .concat compute the sum in a pairwise manner
203
- if len(expressions) > 2 and (has_dim_conflict or requires_join):
204
- result = expressions[0]
205
- for expr in expressions[1:]:
206
- result = _add_expressions_core(result, expr)
207
- return result
208
-
209
- if has_dim_conflict:
210
- assert len(expressions) == 2
211
- expressions = (
212
- _add_dimension(expressions[0], expressions[1]),
213
- _add_dimension(expressions[1], expressions[0]),
214
- )
215
- assert sorted(expressions[0].dimensions_unsafe) == sorted(
216
- expressions[1].dimensions_unsafe
217
- )
218
-
219
- dims = expressions[0].dimensions_unsafe
220
- # Check no dims conflict
221
- assert all(
222
- sorted(dims) == sorted(expr.dimensions_unsafe) for expr in expressions[1:]
186
+ has_dim_conflict = any(
187
+ sorted(expressions[0]._dimensions_unsafe) != sorted(expr._dimensions_unsafe)
188
+ for expr in expressions[1:]
223
189
  )
224
- if requires_join:
225
- assert len(expressions) == 2
226
- assert dims != []
227
- left, right = expressions[0], expressions[1]
228
190
 
229
- # Order so that drop always comes before keep, and keep always comes before default
230
- if (left.unmatched_strategy, right.unmatched_strategy) in (
231
- (UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
232
- (UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
233
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
234
- ):
235
- left, right = right, left
191
+ # If we cannot use .concat compute the sum in a pairwise manner, so far nobody uses this code
192
+ if len(expressions) > 2: # pragma: no cover
193
+ assert False, "This code has not been tested."
194
+ if has_dim_conflict or not no_extras_checks_required:
195
+ result = expressions[0]
196
+ for expr in expressions[1:]:
197
+ result = add(result, expr)
198
+ return result
199
+ propagate_strat = expressions[0]._extras_strategy
200
+ dims = expressions[0]._dimensions_unsafe
201
+ expr_data = [expr.data for expr in expressions]
202
+ else:
203
+ left, right = expressions[0], expressions[1]
236
204
 
237
- def get_indices(expr):
238
- return expr.data.select(dims).unique(maintain_order=True)
205
+ if has_dim_conflict:
206
+ left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
207
+
208
+ missing_left = [dim for dim in right_dims if dim not in left_dims]
209
+ missing_right = [dim for dim in left_dims if dim not in right_dims]
210
+ common_dims = [dim for dim in left_dims if dim in right_dims]
211
+
212
+ if not (
213
+ set(missing_left) <= set(left._allowed_new_dims)
214
+ and set(missing_right) <= set(right._allowed_new_dims)
215
+ ):
216
+ _raise_addition_error(
217
+ left,
218
+ right,
219
+ f"their\n\tdimensions are different ({left_dims} != {right_dims})",
220
+ "If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition/#adding-expressions-with-differing-dimensions-using-over",
221
+ )
239
222
 
240
- left_data, right_data = left.data, right.data
223
+ left_old = left
224
+ if missing_left:
225
+ left = _broadcast(left, right, common_dims, missing_left)
226
+ if missing_right:
227
+ right = _broadcast(
228
+ right, left_old, common_dims, missing_right, swapped=True
229
+ )
241
230
 
242
- strat = (left.unmatched_strategy, right.unmatched_strategy)
231
+ assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
243
232
 
244
- propogate_strat = propogatation_strategies[strat] # type: ignore
233
+ dims = left._dimensions_unsafe
245
234
 
246
- if strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP):
247
- left_data = left.data.join(get_indices(right), how="inner", on=dims)
248
- right_data = right.data.join(get_indices(left), how="inner", on=dims)
249
- elif strat == (UnmatchedStrategy.UNSET, UnmatchedStrategy.UNSET):
250
- assert not Config.disable_unmatched_checks, (
251
- "This code should not be reached when unmatched checks are disabled."
252
- )
253
- outer_join = get_indices(left).join(
254
- get_indices(right),
255
- how="full",
256
- on=dims,
257
- )
258
- if outer_join.get_column(dims[0]).null_count() > 0:
259
- raise PyoframeError(
260
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
261
- + str(outer_join.filter(outer_join.get_column(dims[0]).is_null()))
262
- )
263
- if outer_join.get_column(dims[0] + "_right").null_count() > 0:
264
- raise PyoframeError(
265
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
266
- + str(
267
- outer_join.filter(
268
- outer_join.get_column(dims[0] + "_right").is_null()
269
- )
270
- )
271
- )
272
- elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP):
273
- left_data = get_indices(right).join(left.data, how="left", on=dims)
274
- elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
275
- left_data = get_indices(right).join(left.data, how="left", on=dims)
276
- if left_data.get_column(COEF_KEY).null_count() > 0:
277
- raise PyoframeError(
278
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
279
- + str(left_data.filter(left_data.get_column(COEF_KEY).is_null()))
280
- )
281
- elif strat == (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET):
282
- assert not Config.disable_unmatched_checks, (
283
- "This code should not be reached when unmatched checks are disabled."
284
- )
285
- unmatched = right.data.join(get_indices(left), how="anti", on=dims)
286
- if len(unmatched) > 0:
287
- raise PyoframeError(
288
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
289
- + str(unmatched)
290
- )
291
- else: # pragma: no cover
292
- assert False, "This code should've never been reached!"
293
-
294
- expr_data = [left_data, right_data]
295
- else:
296
- propogate_strat = expressions[0].unmatched_strategy
297
- expr_data = [expr.data for expr in expressions]
235
+ if not no_extras_checks_required:
236
+ expr_data, propagate_strat = _handle_extra_labels(left, right, dims)
237
+ else:
238
+ propagate_strat = left._extras_strategy
239
+ expr_data = (left.data, right.data)
298
240
 
299
241
  # Add quadratic column if it is needed and doesn't already exist
300
242
  if any(QUAD_VAR_KEY in df.columns for df in expr_data):
301
243
  expr_data = [
302
244
  (
303
- df.with_columns(pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(KEY_TYPE))
245
+ df.with_columns(
246
+ pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
247
+ )
304
248
  if QUAD_VAR_KEY not in df.columns
305
249
  else df
306
250
  )
@@ -315,71 +259,215 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
315
259
  data = pl.concat(expr_data, how="vertical_relaxed")
316
260
  data = _sum_like_terms(data)
317
261
 
318
- new_expr = expressions[0]._new(data)
319
- new_expr.unmatched_strategy = propogate_strat
262
+ full_name = expressions[0].name
263
+ for expr in expressions[1:]:
264
+ name = expr.name
265
+ full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
266
+
267
+ new_expr = expressions[0]._new(data, name=f"({full_name})")
268
+ new_expr._extras_strategy = propagate_strat
320
269
 
321
270
  return new_expr
322
271
 
323
272
 
324
- def _add_dimension(self: "Expression", target: "Expression") -> "Expression":
325
- target_dims = target.dimensions
326
- if target_dims is None:
327
- return self
328
- dims = self.dimensions
329
- if dims is None:
330
- dims_in_common = []
331
- missing_dims = target_dims
332
- else:
333
- dims_in_common = [dim for dim in dims if dim in target_dims]
334
- missing_dims = [dim for dim in target_dims if dim not in dims]
273
+ def _handle_extra_labels(
274
+ left: Expression, right: Expression, dims: list[str]
275
+ ) -> tuple[tuple[pl.DataFrame, pl.DataFrame], ExtrasStrategy]:
276
+ assert dims != []
277
+ # Order so that drop always comes before keep, and keep always comes before default
278
+ if swapped := (
279
+ (left._extras_strategy, right._extras_strategy)
280
+ in (
281
+ (ExtrasStrategy.UNSET, ExtrasStrategy.DROP),
282
+ (ExtrasStrategy.UNSET, ExtrasStrategy.KEEP),
283
+ (ExtrasStrategy.KEEP, ExtrasStrategy.DROP),
284
+ )
285
+ ):
286
+ left, right = right, left
335
287
 
336
- # We're already at the size of our target
337
- if not missing_dims:
338
- return self
288
+ def get_labels(expr):
289
+ return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
339
290
 
340
- if not set(missing_dims) <= set(self.allowed_new_dims):
341
- # TODO actually suggest using e.g. .add_dim("a", "b") instead of just "use .add_dim()"
342
- raise PyoframeError(
343
- f"Dataframe has missing dimensions {missing_dims}. If this is intentional, use .add_dim()\n{self.data}"
291
+ left_data, right_data = left.data, right.data
292
+
293
+ strat = (left._extras_strategy, right._extras_strategy)
294
+
295
+ if strat == (ExtrasStrategy.DROP, ExtrasStrategy.DROP):
296
+ left_data = left.data.join(
297
+ get_labels(right),
298
+ on=dims,
299
+ maintain_order="left" if Config.maintain_order else None,
300
+ )
301
+ right_data = right.data.join(
302
+ get_labels(left),
303
+ on=dims,
304
+ maintain_order="left" if Config.maintain_order else None,
344
305
  )
306
+ elif strat == (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET):
307
+ assert not Config.disable_extras_checks, (
308
+ "This code should not be reached when checks for extra values are disabled."
309
+ )
310
+ left_labels, right_labels = get_labels(left), get_labels(right)
311
+ left_extras = left_labels.join(right_labels, how="anti", on=dims)
312
+ right_extras = right_labels.join(left_labels, how="anti", on=dims)
313
+ if len(left_extras) > 0:
314
+ _raise_extras_error(
315
+ left, right, left_extras, swapped, extras_on_right=False
316
+ )
317
+ if len(right_extras) > 0:
318
+ _raise_extras_error(left, right, right_extras, swapped)
319
+
320
+ elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.KEEP):
321
+ left_data = get_labels(right).join(
322
+ left.data,
323
+ on=dims,
324
+ maintain_order="left" if Config.maintain_order else None,
325
+ )
326
+ elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.UNSET):
327
+ right_labels = get_labels(right)
328
+ left_data = right_labels.join(
329
+ left.data,
330
+ how="left",
331
+ on=dims,
332
+ maintain_order="left" if Config.maintain_order else None,
333
+ )
334
+ if left_data.get_column(COEF_KEY).null_count() > 0:
335
+ _raise_extras_error(
336
+ left,
337
+ right,
338
+ right_labels.join(get_labels(left), how="anti", on=dims),
339
+ swapped,
340
+ )
345
341
 
346
- target_data = target.data.select(target_dims).unique(maintain_order=True)
342
+ elif strat == (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET):
343
+ assert not Config.disable_extras_checks, (
344
+ "This code should not be reached when checks for extra values are disabled."
345
+ )
346
+ extras = right.data.join(get_labels(left), how="anti", on=dims)
347
+ if len(extras) > 0:
348
+ _raise_extras_error(left, right, extras.select(dims), swapped)
349
+ else: # pragma: no cover
350
+ assert False, "This code should've never been reached!"
351
+
352
+ if swapped:
353
+ left_data, right_data = right_data, left_data
354
+
355
+ return (left_data, right_data), _extras_propagation_rules[strat]
356
+
357
+
358
+ def _raise_extras_error(
359
+ left: Expression,
360
+ right: Expression,
361
+ extra_labels: pl.DataFrame,
362
+ swapped: bool,
363
+ extras_on_right: bool = True,
364
+ ):
365
+ if swapped:
366
+ left, right = right, left
367
+ extras_on_right = not extras_on_right
368
+
369
+ expression_num = 2 if extras_on_right else 1
370
+
371
+ with Config.print_polars_config:
372
+ _raise_addition_error(
373
+ left,
374
+ right,
375
+ f"expression {expression_num} has extra labels",
376
+ f"Extra labels in expression {expression_num}:\n{extra_labels}\nUse .drop_extras() or .keep_extras() to indicate how the extra labels should be handled. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition",
377
+ )
347
378
 
348
- if not dims_in_common:
349
- return self._new(self.data.join(target_data, how="cross"))
350
379
 
351
- # If drop, we just do an inner join to get into the shape of the other
352
- if self.unmatched_strategy == UnmatchedStrategy.DROP:
353
- return self._new(self.data.join(target_data, on=dims_in_common, how="inner"))
380
+ def _raise_addition_error(
381
+ left: Expression, right: Expression, reason: str, postfix: str
382
+ ):
383
+ op = "add"
384
+ right_name = right.name
385
+ if right_name[0] == "-":
386
+ op = "subtract"
387
+ right_name = right_name[1:]
388
+ raise PyoframeError(
389
+ f"""Cannot {op} the two expressions below because {reason}.
390
+ Expression 1:\t{left.name}
391
+ Expression 2:\t{right_name}
392
+ {postfix}
393
+ """
394
+ )
354
395
 
355
- result = self.data.join(target_data, on=dims_in_common, how="left")
356
- right_has_missing = result.get_column(missing_dims[0]).null_count() > 0
357
- if right_has_missing:
358
- raise PyoframeError(
359
- f"Cannot add dimension {missing_dims} since it contains unmatched values. If this is intentional, consider using .drop_unmatched()"
396
+
397
+ # TODO consider returning a dataframe instead of an expression to simplify code (e.g. avoid copy_flags)
398
+ def _broadcast(
399
+ self: Expression,
400
+ target: Expression,
401
+ common_dims: list[str],
402
+ missing_dims: list[str],
403
+ swapped: bool = False,
404
+ ) -> Expression:
405
+ target_data = target.data.select(target._dimensions_unsafe).unique(
406
+ maintain_order=Config.maintain_order
407
+ )
408
+
409
+ if not common_dims:
410
+ res = self._new(self.data.join(target_data, how="cross"), name=self.name)
411
+ res._copy_flags(self)
412
+ return res
413
+
414
+ # If drop, we just do an inner join to get into the shape of the other
415
+ if self._extras_strategy == ExtrasStrategy.DROP:
416
+ res = self._new(
417
+ self.data.join(
418
+ target_data,
419
+ on=common_dims,
420
+ maintain_order="left" if Config.maintain_order else None,
421
+ ),
422
+ name=self.name,
423
+ )
424
+ res._copy_flags(self)
425
+ return res
426
+
427
+ result = self.data.join(
428
+ target_data,
429
+ on=common_dims,
430
+ how="left",
431
+ maintain_order="left" if Config.maintain_order else None,
432
+ )
433
+ if result.get_column(missing_dims[0]).null_count() > 0:
434
+ target_labels = target.data.select(target._dimensions_unsafe).unique(
435
+ maintain_order=Config.maintain_order
360
436
  )
361
- return self._new(result)
437
+ _raise_extras_error(
438
+ self,
439
+ target,
440
+ target_labels.join(self.data, how="anti", on=common_dims),
441
+ swapped,
442
+ )
443
+ res = self._new(result, self.name)
444
+ res._copy_flags(self)
445
+ return res
362
446
 
363
447
 
364
448
  def _sum_like_terms(df: pl.DataFrame) -> pl.DataFrame:
365
449
  """Combines terms with the same variables."""
366
450
  dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
367
451
  var_cols = [VAR_KEY] + ([QUAD_VAR_KEY] if QUAD_VAR_KEY in df.columns else [])
368
- df = df.group_by(dims + var_cols, maintain_order=True).sum()
452
+ df = df.group_by(dims + var_cols, maintain_order=Config.maintain_order).sum()
369
453
  return df
370
454
 
371
455
 
372
456
  def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
373
- """
374
- Removes the quadratic column and terms with a zero coefficient, when applicable.
457
+ """Removes the quadratic column and terms with a zero coefficient, when applicable.
375
458
 
376
459
  Specifically, zero coefficient terms are always removed, except if they're the only terms in which case the expression contains a single term.
377
460
  The quadratic column is removed if the expression is not a quadratic.
378
461
 
379
462
  Examples:
380
-
381
463
  >>> import polars as pl
382
- >>> df = pl.DataFrame({ VAR_KEY: [CONST_TERM, 1], QUAD_VAR_KEY: [CONST_TERM, 1], COEF_KEY: [1.0, 0]})
464
+ >>> df = pl.DataFrame(
465
+ ... {
466
+ ... VAR_KEY: [CONST_TERM, 1],
467
+ ... QUAD_VAR_KEY: [CONST_TERM, 1],
468
+ ... COEF_KEY: [1.0, 0],
469
+ ... }
470
+ ... )
383
471
  >>> _simplify_expr_df(df)
384
472
  shape: (1, 2)
385
473
  ┌───────────────┬─────────┐
@@ -389,7 +477,21 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
389
477
  ╞═══════════════╪═════════╡
390
478
  │ 0 ┆ 1.0 │
391
479
  └───────────────┴─────────┘
392
- >>> df = pl.DataFrame({"t": [1, 1, 2, 2, 3, 3], VAR_KEY: [CONST_TERM, 1, CONST_TERM, 1, 1, 2], QUAD_VAR_KEY: [CONST_TERM, CONST_TERM, CONST_TERM, CONST_TERM, CONST_TERM, 1], COEF_KEY: [1, 0, 0, 0, 1, 0]})
480
+ >>> df = pl.DataFrame(
481
+ ... {
482
+ ... "t": [1, 1, 2, 2, 3, 3],
483
+ ... VAR_KEY: [CONST_TERM, 1, CONST_TERM, 1, 1, 2],
484
+ ... QUAD_VAR_KEY: [
485
+ ... CONST_TERM,
486
+ ... CONST_TERM,
487
+ ... CONST_TERM,
488
+ ... CONST_TERM,
489
+ ... CONST_TERM,
490
+ ... 1,
491
+ ... ],
492
+ ... COEF_KEY: [1, 0, 0, 0, 1, 0],
493
+ ... }
494
+ ... )
393
495
  >>> _simplify_expr_df(df)
394
496
  shape: (3, 3)
395
497
  ┌─────┬───────────────┬─────────┐
@@ -406,9 +508,14 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
406
508
  if len(df_filtered) < len(df):
407
509
  dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
408
510
  if dims:
409
- dim_values = df.select(dims).unique(maintain_order=True)
511
+ dim_values = df.select(dims).unique(maintain_order=Config.maintain_order)
410
512
  df = (
411
- dim_values.join(df_filtered, on=dims, how="left")
513
+ dim_values.join(
514
+ df_filtered,
515
+ on=dims,
516
+ how="left",
517
+ maintain_order="left" if Config.maintain_order else None,
518
+ )
412
519
  .with_columns(pl.col(COEF_KEY).fill_null(0))
413
520
  .fill_null(CONST_TERM)
414
521
  )
@@ -417,7 +524,7 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
417
524
  if df.is_empty():
418
525
  df = pl.DataFrame(
419
526
  {VAR_KEY: [CONST_TERM], COEF_KEY: [0]},
420
- schema={VAR_KEY: KEY_TYPE, COEF_KEY: pl.Float64},
527
+ schema={VAR_KEY: Config.id_dtype, COEF_KEY: pl.Float64},
421
528
  )
422
529
 
423
530
  if QUAD_VAR_KEY in df.columns and (df.get_column(QUAD_VAR_KEY) == CONST_TERM).all():
@@ -426,10 +533,11 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
426
533
  return df
427
534
 
428
535
 
429
- def _get_dimensions(df: pl.DataFrame) -> Optional[List[str]]:
430
- """
431
- Returns the dimensions of the DataFrame. Reserved columns do not count as dimensions.
432
- If there are no dimensions, returns None to force caller to handle this special case.
536
+ def _get_dimensions(df: pl.DataFrame) -> list[str] | None:
537
+ """Returns the dimensions of the DataFrame.
538
+
539
+ Reserved columns do not count as dimensions. If there are no dimensions,
540
+ returns `None` to force caller to handle this special case.
433
541
 
434
542
  Examples:
435
543
  >>> import polars as pl