pyoframe 0.2.0__py3-none-any.whl → 1.0.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyoframe/_arithmetic.py CHANGED
@@ -1,12 +1,12 @@
1
- """
2
- Defines helper functions for doing arithmetic operations on expressions (e.g. addition).
3
- """
1
+ """Defines helper functions for doing arithmetic operations on expressions (e.g. addition)."""
4
2
 
5
- from typing import TYPE_CHECKING, List, Optional
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
6
 
7
7
  import polars as pl
8
8
 
9
- from pyoframe.constants import (
9
+ from pyoframe._constants import (
10
10
  COEF_KEY,
11
11
  CONST_TERM,
12
12
  KEY_TYPE,
@@ -19,81 +19,60 @@ from pyoframe.constants import (
19
19
  )
20
20
 
21
21
  if TYPE_CHECKING: # pragma: no cover
22
- from pyoframe.core import Expression
22
+ from pyoframe._core import Expression
23
23
 
24
24
 
25
- def _multiply_expressions(self: "Expression", other: "Expression") -> "Expression":
26
- """
27
- Multiply two or more expressions together.
25
+ def multiply(self: Expression, other: Expression) -> Expression:
26
+ """Multiplies two expressions together.
28
27
 
29
28
  Examples:
30
29
  >>> import pyoframe as pf
31
- >>> m = pf.Model("min")
30
+ >>> m = pf.Model()
32
31
  >>> m.x1 = pf.Variable()
33
32
  >>> m.x2 = pf.Variable()
34
33
  >>> m.x3 = pf.Variable()
35
34
  >>> result = 5 * m.x1 * m.x2
36
35
  >>> result
37
- <Expression size=1 dimensions={} terms=1 degree=2>
38
- 5 x2 * x1
36
+ <Expression terms=1 type=quadratic>
37
+ 5 x2 * x1
39
38
  >>> result * m.x3
40
39
  Traceback (most recent call last):
41
40
  ...
42
- pyoframe.constants.PyoframeError: Failed to multiply expressions:
43
- <Expression size=1 dimensions={} terms=1 degree=2> * <Expression size=1 dimensions={} terms=1>
44
- Due to error:
45
- Cannot multiply a quadratic expression by a non-constant.
41
+ pyoframe._constants.PyoframeError: Cannot multiply the two expressions below because the result would be a cubic. Only quadratic or linear expressions are allowed.
42
+ Expression 1 (quadratic): ((5 * x1) * x2)
43
+ Expression 2 (linear): x3
46
44
  """
47
- try:
48
- return _multiply_expressions_core(self, other)
49
- except PyoframeError as error:
50
- raise PyoframeError(
51
- "Failed to multiply expressions:\n"
52
- + " * ".join(
53
- e.to_str(include_header=True, include_data=False) for e in [self, other]
54
- )
55
- + "\nDue to error:\n"
56
- + str(error)
57
- ) from error
58
-
59
-
60
- def _add_expressions(*expressions: "Expression") -> "Expression":
61
- try:
62
- return _add_expressions_core(*expressions)
63
- except PyoframeError as error:
45
+ self_degree, other_degree = self.degree(), other.degree()
46
+ product_degree = self_degree + other_degree
47
+ if product_degree > 2:
48
+ assert product_degree <= 4, (
49
+ "Unexpected because expressions should not exceed degree 2."
50
+ )
51
+ res_type = "cubic" if product_degree == 3 else "quartic"
64
52
  raise PyoframeError(
65
- "Failed to add expressions:\n"
66
- + " + ".join(
67
- e.to_str(include_header=True, include_data=False) for e in expressions
68
- )
69
- + "\nDue to error:\n"
70
- + str(error)
71
- ) from error
53
+ f"""Cannot multiply the two expressions below because the result would be a {res_type}. Only quadratic or linear expressions are allowed.
54
+ Expression 1 ({self.degree(return_str=True)}):\t{self.name}
55
+ Expression 2 ({other.degree(return_str=True)}):\t{other.name}"""
56
+ )
72
57
 
58
+ if self_degree == 1 and other_degree == 1:
59
+ return _quadratic_multiplication(self, other)
73
60
 
74
- def _multiply_expressions_core(self: "Expression", other: "Expression") -> "Expression":
75
- self_degree, other_degree = self.degree(), other.degree()
76
- if self_degree + other_degree > 2:
77
- # We know one of the two must be a quadratic since 1 + 1 is not greater than 2.
78
- raise PyoframeError("Cannot multiply a quadratic expression by a non-constant.")
61
+ # save names to use in debug messages before any swapping occurs
62
+ self_name, other_name = self.name, other.name
79
63
  if self_degree < other_degree:
80
64
  self, other = other, self
81
65
  self_degree, other_degree = other_degree, self_degree
82
- if other_degree == 1:
83
- assert self_degree == 1, (
84
- "This should always be true since the sum of degrees must be <=2."
85
- )
86
- return _quadratic_multiplication(self, other)
87
66
 
88
67
  assert other_degree == 0, (
89
68
  "This should always be true since other cases have already been handled."
90
69
  )
91
- multiplier = other.data.drop(
92
- VAR_KEY
93
- ) # QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
94
70
 
95
- dims = self.dimensions_unsafe
96
- other_dims = other.dimensions_unsafe
71
+ # QUAD_VAR_KEY doesn't need to be dropped since we know it doesn't exist
72
+ multiplier = other.data.drop(VAR_KEY)
73
+
74
+ dims = self._dimensions_unsafe
75
+ other_dims = other._dimensions_unsafe
97
76
  dims_in_common = [dim for dim in dims if dim in other_dims]
98
77
 
99
78
  data = (
@@ -101,17 +80,19 @@ def _multiply_expressions_core(self: "Expression", other: "Expression") -> "Expr
101
80
  multiplier,
102
81
  on=dims_in_common if len(dims_in_common) > 0 else None,
103
82
  how="inner" if dims_in_common else "cross",
83
+ maintain_order=(
84
+ "left" if Config.maintain_order and dims_in_common else None
85
+ ),
104
86
  )
105
87
  .with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
106
88
  .drop(COEF_KEY + "_right")
107
89
  )
108
90
 
109
- return self._new(data)
91
+ return self._new(data, name=f"({self_name} * {other_name})")
110
92
 
111
93
 
112
- def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expression":
113
- """
114
- Multiply two expressions of degree 1.
94
+ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression:
95
+ """Multiplies two expressions of degree 1.
115
96
 
116
97
  Examples:
117
98
  >>> import polars as pl
@@ -122,18 +103,29 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
122
103
  >>> expr1 = df * m.x1
123
104
  >>> expr2 = df * m.x2 * 2 + 4
124
105
  >>> expr1 * expr2
125
- <Expression size=3 dimensions={'dim': 3} terms=6 degree=2>
126
- [1]: 4 x1 +2 x2 * x1
127
- [2]: 8 x1 +8 x2 * x1
128
- [3]: 12 x1 +18 x2 * x1
106
+ <Expression height=3 terms=6 type=quadratic>
107
+ ┌─────┬───────────────────┐
108
+ dim expression │
109
+ │ (3) ┆ │
110
+ ╞═════╪═══════════════════╡
111
+ │ 1 ┆ 4 x1 +2 x2 * x1 │
112
+ │ 2 ┆ 8 x1 +8 x2 * x1 │
113
+ │ 3 ┆ 12 x1 +18 x2 * x1 │
114
+ └─────┴───────────────────┘
129
115
  >>> (expr1 * expr2) - df * m.x1 * df * m.x2 * 2
130
- <Expression size=3 dimensions={'dim': 3} terms=3>
131
- [1]: 4 x1
132
- [2]: 8 x1
133
- [3]: 12 x1
116
+ <Expression height=3 terms=3 type=linear>
117
+ ┌─────┬────────────┐
118
+ dim ┆ expression │
119
+ │ (3) ┆ │
120
+ ╞═════╪════════════╡
121
+ │ 1 ┆ 4 x1 │
122
+ │ 2 ┆ 8 x1 │
123
+ │ 3 ┆ 12 x1 │
124
+ └─────┴────────────┘
125
+
134
126
  """
135
- dims = self.dimensions_unsafe
136
- other_dims = other.dimensions_unsafe
127
+ dims = self._dimensions_unsafe
128
+ other_dims = other._dimensions_unsafe
137
129
  dims_in_common = [dim for dim in dims if dim in other_dims]
138
130
 
139
131
  data = (
@@ -141,11 +133,14 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
141
133
  other.data,
142
134
  on=dims_in_common if len(dims_in_common) > 0 else None,
143
135
  how="inner" if dims_in_common else "cross",
136
+ maintain_order=(
137
+ "left" if Config.maintain_order and dims_in_common else None
138
+ ),
144
139
  )
145
140
  .with_columns(pl.col(COEF_KEY) * pl.col(COEF_KEY + "_right"))
146
141
  .drop(COEF_KEY + "_right")
147
142
  .rename({VAR_KEY + "_right": QUAD_VAR_KEY})
148
- # Swap VAR_KEY and QUAD_VAR_KEY so that VAR_KEy is always the larger one
143
+ # Swap VAR_KEY and QUAD_VAR_KEY so that VAR_KEY is always the larger one
149
144
  .with_columns(
150
145
  pl.when(pl.col(VAR_KEY) < pl.col(QUAD_VAR_KEY))
151
146
  .then(pl.col(QUAD_VAR_KEY))
@@ -160,12 +155,13 @@ def _quadratic_multiplication(self: "Expression", other: "Expression") -> "Expre
160
155
 
161
156
  data = _sum_like_terms(data)
162
157
 
163
- return self._new(data)
158
+ return self._new(data, name=f"({self.name} * {other.name})")
164
159
 
165
160
 
166
- def _add_expressions_core(*expressions: "Expression") -> "Expression":
167
- # Mapping of how a sum of two expressions should propogate the unmatched strategy
168
- propogatation_strategies = {
161
+ def add(*expressions: Expression) -> Expression:
162
+ """Add multiple expressions together."""
163
+ # Mapping of how a sum of two expressions should propagate the unmatched strategy
164
+ propagation_strategies = {
169
165
  (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
170
166
  (
171
167
  UnmatchedStrategy.UNSET,
@@ -186,40 +182,62 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
186
182
  dims = []
187
183
  elif Config.disable_unmatched_checks:
188
184
  requires_join = any(
189
- expr.unmatched_strategy
185
+ expr._unmatched_strategy
190
186
  not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
191
187
  for expr in expressions
192
188
  )
193
189
  else:
194
190
  requires_join = any(
195
- expr.unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
191
+ expr._unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
196
192
  )
197
193
 
198
194
  has_dim_conflict = any(
199
- sorted(dims) != sorted(expr.dimensions_unsafe) for expr in expressions[1:]
195
+ sorted(dims) != sorted(expr._dimensions_unsafe) for expr in expressions[1:]
200
196
  )
201
197
 
202
198
  # If we cannot use .concat compute the sum in a pairwise manner
203
199
  if len(expressions) > 2 and (has_dim_conflict or requires_join):
204
200
  result = expressions[0]
205
201
  for expr in expressions[1:]:
206
- result = _add_expressions_core(result, expr)
202
+ result = add(result, expr)
207
203
  return result
208
204
 
209
205
  if has_dim_conflict:
210
206
  assert len(expressions) == 2
211
- expressions = (
212
- _add_dimension(expressions[0], expressions[1]),
213
- _add_dimension(expressions[1], expressions[0]),
214
- )
215
- assert sorted(expressions[0].dimensions_unsafe) == sorted(
216
- expressions[1].dimensions_unsafe
217
- )
218
207
 
219
- dims = expressions[0].dimensions_unsafe
208
+ left, right = expressions[0], expressions[1]
209
+ left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
210
+
211
+ missing_left = [dim for dim in right_dims if dim not in left_dims]
212
+ missing_right = [dim for dim in left_dims if dim not in right_dims]
213
+ common_dims = [dim for dim in left_dims if dim in right_dims]
214
+
215
+ if not (
216
+ set(missing_left) <= set(left._allowed_new_dims)
217
+ and set(missing_right) <= set(right._allowed_new_dims)
218
+ ):
219
+ _raise_addition_error(
220
+ left,
221
+ right,
222
+ f"their\n\tdimensions are different ({left_dims} != {right_dims})",
223
+ "If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/learn/concepts/special-functions/#adding-expressions-with-differing-dimensions-using-over",
224
+ )
225
+
226
+ left_old = left
227
+ if missing_left:
228
+ left = _broadcast(left, right, common_dims, missing_left)
229
+ if missing_right:
230
+ right = _broadcast(
231
+ right, left_old, common_dims, missing_right, swapped=True
232
+ )
233
+
234
+ assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
235
+ expressions = (left, right)
236
+
237
+ dims = expressions[0]._dimensions_unsafe
220
238
  # Check no dims conflict
221
239
  assert all(
222
- sorted(dims) == sorted(expr.dimensions_unsafe) for expr in expressions[1:]
240
+ sorted(dims) == sorted(expr._dimensions_unsafe) for expr in expressions[1:]
223
241
  )
224
242
  if requires_join:
225
243
  assert len(expressions) == 2
@@ -227,25 +245,36 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
227
245
  left, right = expressions[0], expressions[1]
228
246
 
229
247
  # Order so that drop always comes before keep, and keep always comes before default
230
- if (left.unmatched_strategy, right.unmatched_strategy) in (
231
- (UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
232
- (UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
233
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
248
+ if swap := (
249
+ (left._unmatched_strategy, right._unmatched_strategy)
250
+ in (
251
+ (UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
252
+ (UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
253
+ (UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
254
+ )
234
255
  ):
235
256
  left, right = right, left
236
257
 
237
258
  def get_indices(expr):
238
- return expr.data.select(dims).unique(maintain_order=True)
259
+ return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
239
260
 
240
261
  left_data, right_data = left.data, right.data
241
262
 
242
- strat = (left.unmatched_strategy, right.unmatched_strategy)
263
+ strat = (left._unmatched_strategy, right._unmatched_strategy)
243
264
 
244
- propogate_strat = propogatation_strategies[strat] # type: ignore
265
+ propagate_strat = propagation_strategies[strat] # type: ignore
245
266
 
246
267
  if strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP):
247
- left_data = left.data.join(get_indices(right), how="inner", on=dims)
248
- right_data = right.data.join(get_indices(left), how="inner", on=dims)
268
+ left_data = left.data.join(
269
+ get_indices(right),
270
+ on=dims,
271
+ maintain_order="left" if Config.maintain_order else None,
272
+ )
273
+ right_data = right.data.join(
274
+ get_indices(left),
275
+ on=dims,
276
+ maintain_order="left" if Config.maintain_order else None,
277
+ )
249
278
  elif strat == (UnmatchedStrategy.UNSET, UnmatchedStrategy.UNSET):
250
279
  assert not Config.disable_unmatched_checks, (
251
280
  "This code should not be reached when unmatched checks are disabled."
@@ -254,46 +283,54 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
254
283
  get_indices(right),
255
284
  how="full",
256
285
  on=dims,
286
+ maintain_order="left_right" if Config.maintain_order else None,
257
287
  )
258
288
  if outer_join.get_column(dims[0]).null_count() > 0:
259
- raise PyoframeError(
260
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
261
- + str(outer_join.filter(outer_join.get_column(dims[0]).is_null()))
289
+ unmatched_vals = outer_join.filter(
290
+ outer_join.get_column(dims[0]).is_null()
262
291
  )
292
+ _raise_unmatched_values_error(left, right, unmatched_vals, swap)
263
293
  if outer_join.get_column(dims[0] + "_right").null_count() > 0:
264
- raise PyoframeError(
265
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
266
- + str(
267
- outer_join.filter(
268
- outer_join.get_column(dims[0] + "_right").is_null()
269
- )
270
- )
294
+ unmatched_vals = outer_join.filter(
295
+ outer_join.get_column(dims[0] + "_right").is_null()
271
296
  )
297
+ _raise_unmatched_values_error(left, right, unmatched_vals, swap)
298
+
272
299
  elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP):
273
- left_data = get_indices(right).join(left.data, how="left", on=dims)
300
+ left_data = get_indices(right).join(
301
+ left.data,
302
+ how="left",
303
+ on=dims,
304
+ maintain_order="left" if Config.maintain_order else None,
305
+ )
274
306
  elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
275
- left_data = get_indices(right).join(left.data, how="left", on=dims)
307
+ left_data = get_indices(right).join(
308
+ left.data,
309
+ how="left",
310
+ on=dims,
311
+ maintain_order="left" if Config.maintain_order else None,
312
+ )
276
313
  if left_data.get_column(COEF_KEY).null_count() > 0:
277
- raise PyoframeError(
278
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
279
- + str(left_data.filter(left_data.get_column(COEF_KEY).is_null()))
314
+ _raise_unmatched_values_error(
315
+ left,
316
+ right,
317
+ left_data.filter(left_data.get_column(COEF_KEY).is_null()),
318
+ swap,
280
319
  )
320
+
281
321
  elif strat == (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET):
282
322
  assert not Config.disable_unmatched_checks, (
283
323
  "This code should not be reached when unmatched checks are disabled."
284
324
  )
285
325
  unmatched = right.data.join(get_indices(left), how="anti", on=dims)
286
326
  if len(unmatched) > 0:
287
- raise PyoframeError(
288
- "Dataframe has unmatched values. If this is intentional, use .drop_unmatched() or .keep_unmatched()\n"
289
- + str(unmatched)
290
- )
327
+ _raise_unmatched_values_error(left, right, unmatched, swap)
291
328
  else: # pragma: no cover
292
329
  assert False, "This code should've never been reached!"
293
330
 
294
331
  expr_data = [left_data, right_data]
295
332
  else:
296
- propogate_strat = expressions[0].unmatched_strategy
333
+ propagate_strat = expressions[0]._unmatched_strategy
297
334
  expr_data = [expr.data for expr in expressions]
298
335
 
299
336
  # Add quadratic column if it is needed and doesn't already exist
@@ -315,71 +352,120 @@ def _add_expressions_core(*expressions: "Expression") -> "Expression":
315
352
  data = pl.concat(expr_data, how="vertical_relaxed")
316
353
  data = _sum_like_terms(data)
317
354
 
318
- new_expr = expressions[0]._new(data)
319
- new_expr.unmatched_strategy = propogate_strat
355
+ full_name = expressions[0].name
356
+ for expr in expressions[1:]:
357
+ name = expr.name
358
+ full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
359
+
360
+ new_expr = expressions[0]._new(data, name=f"({full_name})")
361
+ new_expr._unmatched_strategy = propagate_strat
320
362
 
321
363
  return new_expr
322
364
 
323
365
 
324
- def _add_dimension(self: "Expression", target: "Expression") -> "Expression":
325
- target_dims = target.dimensions
326
- if target_dims is None:
327
- return self
328
- dims = self.dimensions
329
- if dims is None:
330
- dims_in_common = []
331
- missing_dims = target_dims
332
- else:
333
- dims_in_common = [dim for dim in dims if dim in target_dims]
334
- missing_dims = [dim for dim in target_dims if dim not in dims]
366
+ def _raise_unmatched_values_error(
367
+ left: Expression, right: Expression, unmatched_values: pl.DataFrame, swapped: bool
368
+ ):
369
+ if swapped:
370
+ left, right = right, left
371
+
372
+ _raise_addition_error(
373
+ left,
374
+ right,
375
+ "of unmatched values",
376
+ f"Unmatched values:\n{unmatched_values}\nIf this is intentional, use .drop_unmatched() or .keep_unmatched().",
377
+ )
335
378
 
336
- # We're already at the size of our target
337
- if not missing_dims:
338
- return self
339
379
 
340
- if not set(missing_dims) <= set(self.allowed_new_dims):
341
- # TODO actually suggest using e.g. .add_dim("a", "b") instead of just "use .add_dim()"
342
- raise PyoframeError(
343
- f"Dataframe has missing dimensions {missing_dims}. If this is intentional, use .add_dim()\n{self.data}"
344
- )
380
+ def _raise_addition_error(
381
+ left: Expression, right: Expression, reason: str, postfix: str
382
+ ):
383
+ op = "add"
384
+ right_name = right.name
385
+ if right_name[0] == "-":
386
+ op = "subtract"
387
+ right_name = right_name[1:]
388
+ raise PyoframeError(
389
+ f"""Cannot {op} the two expressions below because {reason}.
390
+ Expression 1:\t{left.name}
391
+ Expression 2:\t{right_name}
392
+ {postfix}
393
+ """
394
+ )
395
+
345
396
 
346
- target_data = target.data.select(target_dims).unique(maintain_order=True)
397
+ # TODO consider returning a dataframe instead of an expression to simplify code (e.g. avoid copy_flags)
398
+ def _broadcast(
399
+ self: Expression,
400
+ target: Expression,
401
+ common_dims: list[str],
402
+ missing_dims: list[str],
403
+ swapped: bool = False,
404
+ ) -> Expression:
405
+ target_data = target.data.select(target._dimensions_unsafe).unique(
406
+ maintain_order=Config.maintain_order
407
+ )
347
408
 
348
- if not dims_in_common:
349
- return self._new(self.data.join(target_data, how="cross"))
409
+ if not common_dims:
410
+ res = self._new(self.data.join(target_data, how="cross"), name=self.name)
411
+ res._copy_flags(self)
412
+ return res
350
413
 
351
414
  # If drop, we just do an inner join to get into the shape of the other
352
- if self.unmatched_strategy == UnmatchedStrategy.DROP:
353
- return self._new(self.data.join(target_data, on=dims_in_common, how="inner"))
354
-
355
- result = self.data.join(target_data, on=dims_in_common, how="left")
415
+ if self._unmatched_strategy == UnmatchedStrategy.DROP:
416
+ res = self._new(
417
+ self.data.join(
418
+ target_data,
419
+ on=common_dims,
420
+ maintain_order="left" if Config.maintain_order else None,
421
+ ),
422
+ name=self.name,
423
+ )
424
+ res._copy_flags(self)
425
+ return res
426
+
427
+ result = self.data.join(
428
+ target_data,
429
+ on=common_dims,
430
+ how="left",
431
+ maintain_order="left" if Config.maintain_order else None,
432
+ )
356
433
  right_has_missing = result.get_column(missing_dims[0]).null_count() > 0
357
434
  if right_has_missing:
358
- raise PyoframeError(
359
- f"Cannot add dimension {missing_dims} since it contains unmatched values. If this is intentional, consider using .drop_unmatched()"
435
+ _raise_unmatched_values_error(
436
+ self,
437
+ target,
438
+ result.filter(result.get_column(missing_dims[0]).is_null()),
439
+ swapped,
360
440
  )
361
- return self._new(result)
441
+ res = self._new(result, self.name)
442
+ res._copy_flags(self)
443
+ return res
362
444
 
363
445
 
364
446
  def _sum_like_terms(df: pl.DataFrame) -> pl.DataFrame:
365
447
  """Combines terms with the same variables."""
366
448
  dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
367
449
  var_cols = [VAR_KEY] + ([QUAD_VAR_KEY] if QUAD_VAR_KEY in df.columns else [])
368
- df = df.group_by(dims + var_cols, maintain_order=True).sum()
450
+ df = df.group_by(dims + var_cols, maintain_order=Config.maintain_order).sum()
369
451
  return df
370
452
 
371
453
 
372
454
  def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
373
- """
374
- Removes the quadratic column and terms with a zero coefficient, when applicable.
455
+ """Removes the quadratic column and terms with a zero coefficient, when applicable.
375
456
 
376
457
  Specifically, zero coefficient terms are always removed, except if they're the only terms in which case the expression contains a single term.
377
458
  The quadratic column is removed if the expression is not a quadratic.
378
459
 
379
460
  Examples:
380
-
381
461
  >>> import polars as pl
382
- >>> df = pl.DataFrame({ VAR_KEY: [CONST_TERM, 1], QUAD_VAR_KEY: [CONST_TERM, 1], COEF_KEY: [1.0, 0]})
462
+ >>> df = pl.DataFrame(
463
+ ... {
464
+ ... VAR_KEY: [CONST_TERM, 1],
465
+ ... QUAD_VAR_KEY: [CONST_TERM, 1],
466
+ ... COEF_KEY: [1.0, 0],
467
+ ... }
468
+ ... )
383
469
  >>> _simplify_expr_df(df)
384
470
  shape: (1, 2)
385
471
  ┌───────────────┬─────────┐
@@ -389,7 +475,21 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
389
475
  ╞═══════════════╪═════════╡
390
476
  │ 0 ┆ 1.0 │
391
477
  └───────────────┴─────────┘
392
- >>> df = pl.DataFrame({"t": [1, 1, 2, 2, 3, 3], VAR_KEY: [CONST_TERM, 1, CONST_TERM, 1, 1, 2], QUAD_VAR_KEY: [CONST_TERM, CONST_TERM, CONST_TERM, CONST_TERM, CONST_TERM, 1], COEF_KEY: [1, 0, 0, 0, 1, 0]})
478
+ >>> df = pl.DataFrame(
479
+ ... {
480
+ ... "t": [1, 1, 2, 2, 3, 3],
481
+ ... VAR_KEY: [CONST_TERM, 1, CONST_TERM, 1, 1, 2],
482
+ ... QUAD_VAR_KEY: [
483
+ ... CONST_TERM,
484
+ ... CONST_TERM,
485
+ ... CONST_TERM,
486
+ ... CONST_TERM,
487
+ ... CONST_TERM,
488
+ ... 1,
489
+ ... ],
490
+ ... COEF_KEY: [1, 0, 0, 0, 1, 0],
491
+ ... }
492
+ ... )
393
493
  >>> _simplify_expr_df(df)
394
494
  shape: (3, 3)
395
495
  ┌─────┬───────────────┬─────────┐
@@ -406,9 +506,14 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
406
506
  if len(df_filtered) < len(df):
407
507
  dims = [c for c in df.columns if c not in RESERVED_COL_KEYS]
408
508
  if dims:
409
- dim_values = df.select(dims).unique(maintain_order=True)
509
+ dim_values = df.select(dims).unique(maintain_order=Config.maintain_order)
410
510
  df = (
411
- dim_values.join(df_filtered, on=dims, how="left")
511
+ dim_values.join(
512
+ df_filtered,
513
+ on=dims,
514
+ how="left",
515
+ maintain_order="left" if Config.maintain_order else None,
516
+ )
412
517
  .with_columns(pl.col(COEF_KEY).fill_null(0))
413
518
  .fill_null(CONST_TERM)
414
519
  )
@@ -426,10 +531,11 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
426
531
  return df
427
532
 
428
533
 
429
- def _get_dimensions(df: pl.DataFrame) -> Optional[List[str]]:
430
- """
431
- Returns the dimensions of the DataFrame. Reserved columns do not count as dimensions.
432
- If there are no dimensions, returns None to force caller to handle this special case.
534
+ def _get_dimensions(df: pl.DataFrame) -> list[str] | None:
535
+ """Returns the dimensions of the DataFrame.
536
+
537
+ Reserved columns do not count as dimensions. If there are no dimensions,
538
+ returns `None` to force caller to handle this special case.
433
539
 
434
540
  Examples:
435
541
  >>> import polars as pl