pyoframe 1.0.0a0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyoframe/__init__.py CHANGED
@@ -11,6 +11,7 @@ from pyoframe._core import Constraint, Expression, Set, Variable, sum, sum_by
11
11
  from pyoframe._model import Model
12
12
  from pyoframe._monkey_patch import patch_dataframe_libraries
13
13
  from pyoframe._objective import Objective
14
+ from pyoframe._param import Param
14
15
 
15
16
  try:
16
17
  from pyoframe._version import __version__, __version_tuple__ # noqa: F401
@@ -25,6 +26,7 @@ __all__ = [
25
26
  "Expression",
26
27
  "Constraint",
27
28
  "Objective",
29
+ "Param",
28
30
  "Set",
29
31
  "Config",
30
32
  "sum",
pyoframe/_arithmetic.py CHANGED
@@ -9,18 +9,27 @@ import polars as pl
9
9
  from pyoframe._constants import (
10
10
  COEF_KEY,
11
11
  CONST_TERM,
12
- KEY_TYPE,
13
12
  QUAD_VAR_KEY,
14
13
  RESERVED_COL_KEYS,
15
14
  VAR_KEY,
16
15
  Config,
16
+ ExtrasStrategy,
17
17
  PyoframeError,
18
- UnmatchedStrategy,
19
18
  )
20
19
 
21
20
  if TYPE_CHECKING: # pragma: no cover
22
21
  from pyoframe._core import Expression
23
22
 
23
+ # Mapping of how a sum of two expressions should propagate the extras strategy
24
+ _extras_propagation_rules = {
25
+ (ExtrasStrategy.DROP, ExtrasStrategy.DROP): ExtrasStrategy.DROP,
26
+ (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET): ExtrasStrategy.UNSET,
27
+ (ExtrasStrategy.KEEP, ExtrasStrategy.KEEP): ExtrasStrategy.KEEP,
28
+ (ExtrasStrategy.DROP, ExtrasStrategy.KEEP): ExtrasStrategy.UNSET,
29
+ (ExtrasStrategy.DROP, ExtrasStrategy.UNSET): ExtrasStrategy.DROP,
30
+ (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET): ExtrasStrategy.KEEP,
31
+ }
32
+
24
33
 
25
34
  def multiply(self: Expression, other: Expression) -> Expression:
26
35
  """Multiplies two expressions together.
@@ -33,7 +42,7 @@ def multiply(self: Expression, other: Expression) -> Expression:
33
42
  >>> m.x3 = pf.Variable()
34
43
  >>> result = 5 * m.x1 * m.x2
35
44
  >>> result
36
- <Expression terms=1 type=quadratic>
45
+ <Expression (quadratic) terms=1>
37
46
  5 x2 * x1
38
47
  >>> result * m.x3
39
48
  Traceback (most recent call last):
@@ -103,7 +112,7 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
103
112
  >>> expr1 = df * m.x1
104
113
  >>> expr2 = df * m.x2 * 2 + 4
105
114
  >>> expr1 * expr2
106
- <Expression height=3 terms=6 type=quadratic>
115
+ <Expression (quadratic) height=3 terms=6>
107
116
  ┌─────┬───────────────────┐
108
117
  │ dim ┆ expression │
109
118
  │ (3) ┆ │
@@ -113,7 +122,7 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
113
122
  │ 3 ┆ 12 x1 +18 x2 * x1 │
114
123
  └─────┴───────────────────┘
115
124
  >>> (expr1 * expr2) - df * m.x1 * df * m.x2 * 2
116
- <Expression height=3 terms=3 type=linear>
125
+ <Expression (linear) height=3 terms=3>
117
126
  ┌─────┬────────────┐
118
127
  │ dim ┆ expression │
119
128
  │ (3) ┆ │
@@ -160,184 +169,82 @@ def _quadratic_multiplication(self: Expression, other: Expression) -> Expression
160
169
 
161
170
  def add(*expressions: Expression) -> Expression:
162
171
  """Add multiple expressions together."""
163
- # Mapping of how a sum of two expressions should propagate the unmatched strategy
164
- propagation_strategies = {
165
- (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP): UnmatchedStrategy.DROP,
166
- (
167
- UnmatchedStrategy.UNSET,
168
- UnmatchedStrategy.UNSET,
169
- ): UnmatchedStrategy.UNSET,
170
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.KEEP): UnmatchedStrategy.KEEP,
171
- (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP): UnmatchedStrategy.UNSET,
172
- (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET): UnmatchedStrategy.DROP,
173
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET): UnmatchedStrategy.KEEP,
174
- }
175
-
176
172
  assert len(expressions) > 1, "Need at least two expressions to add together."
177
173
 
178
- dims = expressions[0].dimensions
179
-
180
- if dims is None:
181
- requires_join = False
182
- dims = []
183
- elif Config.disable_unmatched_checks:
184
- requires_join = any(
185
- expr._unmatched_strategy
186
- not in (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET)
187
- for expr in expressions
188
- )
174
+ if Config.disable_extras_checks:
175
+ no_checks_strats = (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET)
189
176
  else:
190
- requires_join = any(
191
- expr._unmatched_strategy != UnmatchedStrategy.KEEP for expr in expressions
192
- )
177
+ no_checks_strats = (ExtrasStrategy.KEEP,)
193
178
 
194
- has_dim_conflict = any(
195
- sorted(dims) != sorted(expr._dimensions_unsafe) for expr in expressions[1:]
179
+ no_extras_checks_required = (
180
+ all(expr._extras_strategy in no_checks_strats for expr in expressions)
181
+ # if only one dimensioned, then there is no such thing as extra labels,
182
+ # labels will be set by the only dimensioned expression
183
+ or sum(not expr.dimensionless for expr in expressions) <= 1
196
184
  )
197
185
 
198
- # If we cannot use .concat compute the sum in a pairwise manner
199
- if len(expressions) > 2 and (has_dim_conflict or requires_join):
200
- result = expressions[0]
201
- for expr in expressions[1:]:
202
- result = add(result, expr)
203
- return result
204
-
205
- if has_dim_conflict:
206
- assert len(expressions) == 2
207
-
208
- left, right = expressions[0], expressions[1]
209
- left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
210
-
211
- missing_left = [dim for dim in right_dims if dim not in left_dims]
212
- missing_right = [dim for dim in left_dims if dim not in right_dims]
213
- common_dims = [dim for dim in left_dims if dim in right_dims]
214
-
215
- if not (
216
- set(missing_left) <= set(left._allowed_new_dims)
217
- and set(missing_right) <= set(right._allowed_new_dims)
218
- ):
219
- _raise_addition_error(
220
- left,
221
- right,
222
- f"their\n\tdimensions are different ({left_dims} != {right_dims})",
223
- "If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/learn/concepts/special-functions/#adding-expressions-with-differing-dimensions-using-over",
224
- )
225
-
226
- left_old = left
227
- if missing_left:
228
- left = _broadcast(left, right, common_dims, missing_left)
229
- if missing_right:
230
- right = _broadcast(
231
- right, left_old, common_dims, missing_right, swapped=True
232
- )
233
-
234
- assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
235
- expressions = (left, right)
236
-
237
- dims = expressions[0]._dimensions_unsafe
238
- # Check no dims conflict
239
- assert all(
240
- sorted(dims) == sorted(expr._dimensions_unsafe) for expr in expressions[1:]
186
+ has_dim_conflict = any(
187
+ sorted(expressions[0]._dimensions_unsafe) != sorted(expr._dimensions_unsafe)
188
+ for expr in expressions[1:]
241
189
  )
242
- if requires_join:
243
- assert len(expressions) == 2
244
- assert dims != []
245
- left, right = expressions[0], expressions[1]
246
-
247
- # Order so that drop always comes before keep, and keep always comes before default
248
- if swap := (
249
- (left._unmatched_strategy, right._unmatched_strategy)
250
- in (
251
- (UnmatchedStrategy.UNSET, UnmatchedStrategy.DROP),
252
- (UnmatchedStrategy.UNSET, UnmatchedStrategy.KEEP),
253
- (UnmatchedStrategy.KEEP, UnmatchedStrategy.DROP),
254
- )
255
- ):
256
- left, right = right, left
257
-
258
- def get_indices(expr):
259
- return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
260
-
261
- left_data, right_data = left.data, right.data
262
190
 
263
- strat = (left._unmatched_strategy, right._unmatched_strategy)
191
+ # If we cannot use .concat compute the sum in a pairwise manner, so far nobody uses this code
192
+ if len(expressions) > 2: # pragma: no cover
193
+ assert False, "This code has not been tested."
194
+ if has_dim_conflict or not no_extras_checks_required:
195
+ result = expressions[0]
196
+ for expr in expressions[1:]:
197
+ result = add(result, expr)
198
+ return result
199
+ propagate_strat = expressions[0]._extras_strategy
200
+ dims = expressions[0]._dimensions_unsafe
201
+ expr_data = [expr.data for expr in expressions]
202
+ else:
203
+ left, right = expressions[0], expressions[1]
264
204
 
265
- propagate_strat = propagation_strategies[strat] # type: ignore
205
+ if has_dim_conflict:
206
+ left_dims, right_dims = left._dimensions_unsafe, right._dimensions_unsafe
266
207
 
267
- if strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.DROP):
268
- left_data = left.data.join(
269
- get_indices(right),
270
- on=dims,
271
- maintain_order="left" if Config.maintain_order else None,
272
- )
273
- right_data = right.data.join(
274
- get_indices(left),
275
- on=dims,
276
- maintain_order="left" if Config.maintain_order else None,
277
- )
278
- elif strat == (UnmatchedStrategy.UNSET, UnmatchedStrategy.UNSET):
279
- assert not Config.disable_unmatched_checks, (
280
- "This code should not be reached when unmatched checks are disabled."
281
- )
282
- outer_join = get_indices(left).join(
283
- get_indices(right),
284
- how="full",
285
- on=dims,
286
- maintain_order="left_right" if Config.maintain_order else None,
287
- )
288
- if outer_join.get_column(dims[0]).null_count() > 0:
289
- unmatched_vals = outer_join.filter(
290
- outer_join.get_column(dims[0]).is_null()
291
- )
292
- _raise_unmatched_values_error(left, right, unmatched_vals, swap)
293
- if outer_join.get_column(dims[0] + "_right").null_count() > 0:
294
- unmatched_vals = outer_join.filter(
295
- outer_join.get_column(dims[0] + "_right").is_null()
296
- )
297
- _raise_unmatched_values_error(left, right, unmatched_vals, swap)
208
+ missing_left = [dim for dim in right_dims if dim not in left_dims]
209
+ missing_right = [dim for dim in left_dims if dim not in right_dims]
210
+ common_dims = [dim for dim in left_dims if dim in right_dims]
298
211
 
299
- elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.KEEP):
300
- left_data = get_indices(right).join(
301
- left.data,
302
- how="left",
303
- on=dims,
304
- maintain_order="left" if Config.maintain_order else None,
305
- )
306
- elif strat == (UnmatchedStrategy.DROP, UnmatchedStrategy.UNSET):
307
- left_data = get_indices(right).join(
308
- left.data,
309
- how="left",
310
- on=dims,
311
- maintain_order="left" if Config.maintain_order else None,
312
- )
313
- if left_data.get_column(COEF_KEY).null_count() > 0:
314
- _raise_unmatched_values_error(
212
+ if not (
213
+ set(missing_left) <= set(left._allowed_new_dims)
214
+ and set(missing_right) <= set(right._allowed_new_dims)
215
+ ):
216
+ _raise_addition_error(
315
217
  left,
316
218
  right,
317
- left_data.filter(left_data.get_column(COEF_KEY).is_null()),
318
- swap,
219
+ f"their\n\tdimensions are different ({left_dims} != {right_dims})",
220
+ "If this is intentional, use .over(…) to broadcast. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition/#adding-expressions-with-differing-dimensions-using-over",
319
221
  )
320
222
 
321
- elif strat == (UnmatchedStrategy.KEEP, UnmatchedStrategy.UNSET):
322
- assert not Config.disable_unmatched_checks, (
323
- "This code should not be reached when unmatched checks are disabled."
324
- )
325
- unmatched = right.data.join(get_indices(left), how="anti", on=dims)
326
- if len(unmatched) > 0:
327
- _raise_unmatched_values_error(left, right, unmatched, swap)
328
- else: # pragma: no cover
329
- assert False, "This code should've never been reached!"
223
+ left_old = left
224
+ if missing_left:
225
+ left = _broadcast(left, right, common_dims, missing_left)
226
+ if missing_right:
227
+ right = _broadcast(
228
+ right, left_old, common_dims, missing_right, swapped=True
229
+ )
330
230
 
331
- expr_data = [left_data, right_data]
332
- else:
333
- propagate_strat = expressions[0]._unmatched_strategy
334
- expr_data = [expr.data for expr in expressions]
231
+ assert sorted(left._dimensions_unsafe) == sorted(right._dimensions_unsafe)
232
+
233
+ dims = left._dimensions_unsafe
234
+
235
+ if not no_extras_checks_required:
236
+ expr_data, propagate_strat = _handle_extra_labels(left, right, dims)
237
+ else:
238
+ propagate_strat = left._extras_strategy
239
+ expr_data = (left.data, right.data)
335
240
 
336
241
  # Add quadratic column if it is needed and doesn't already exist
337
242
  if any(QUAD_VAR_KEY in df.columns for df in expr_data):
338
243
  expr_data = [
339
244
  (
340
- df.with_columns(pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(KEY_TYPE))
245
+ df.with_columns(
246
+ pl.lit(CONST_TERM).alias(QUAD_VAR_KEY).cast(Config.id_dtype)
247
+ )
341
248
  if QUAD_VAR_KEY not in df.columns
342
249
  else df
343
250
  )
@@ -358,23 +265,116 @@ def add(*expressions: Expression) -> Expression:
358
265
  full_name += f" - {name[1:]}" if name[0] == "-" else f" + {name}"
359
266
 
360
267
  new_expr = expressions[0]._new(data, name=f"({full_name})")
361
- new_expr._unmatched_strategy = propagate_strat
268
+ new_expr._extras_strategy = propagate_strat
362
269
 
363
270
  return new_expr
364
271
 
365
272
 
366
- def _raise_unmatched_values_error(
367
- left: Expression, right: Expression, unmatched_values: pl.DataFrame, swapped: bool
273
+ def _handle_extra_labels(
274
+ left: Expression, right: Expression, dims: list[str]
275
+ ) -> tuple[tuple[pl.DataFrame, pl.DataFrame], ExtrasStrategy]:
276
+ assert dims != []
277
+ # Order so that drop always comes before keep, and keep always comes before default
278
+ if swapped := (
279
+ (left._extras_strategy, right._extras_strategy)
280
+ in (
281
+ (ExtrasStrategy.UNSET, ExtrasStrategy.DROP),
282
+ (ExtrasStrategy.UNSET, ExtrasStrategy.KEEP),
283
+ (ExtrasStrategy.KEEP, ExtrasStrategy.DROP),
284
+ )
285
+ ):
286
+ left, right = right, left
287
+
288
+ def get_labels(expr):
289
+ return expr.data.select(dims).unique(maintain_order=Config.maintain_order)
290
+
291
+ left_data, right_data = left.data, right.data
292
+
293
+ strat = (left._extras_strategy, right._extras_strategy)
294
+
295
+ if strat == (ExtrasStrategy.DROP, ExtrasStrategy.DROP):
296
+ left_data = left.data.join(
297
+ get_labels(right),
298
+ on=dims,
299
+ maintain_order="left" if Config.maintain_order else None,
300
+ )
301
+ right_data = right.data.join(
302
+ get_labels(left),
303
+ on=dims,
304
+ maintain_order="left" if Config.maintain_order else None,
305
+ )
306
+ elif strat == (ExtrasStrategy.UNSET, ExtrasStrategy.UNSET):
307
+ assert not Config.disable_extras_checks, (
308
+ "This code should not be reached when checks for extra values are disabled."
309
+ )
310
+ left_labels, right_labels = get_labels(left), get_labels(right)
311
+ left_extras = left_labels.join(right_labels, how="anti", on=dims)
312
+ right_extras = right_labels.join(left_labels, how="anti", on=dims)
313
+ if len(left_extras) > 0:
314
+ _raise_extras_error(
315
+ left, right, left_extras, swapped, extras_on_right=False
316
+ )
317
+ if len(right_extras) > 0:
318
+ _raise_extras_error(left, right, right_extras, swapped)
319
+
320
+ elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.KEEP):
321
+ left_data = get_labels(right).join(
322
+ left.data,
323
+ on=dims,
324
+ maintain_order="left" if Config.maintain_order else None,
325
+ )
326
+ elif strat == (ExtrasStrategy.DROP, ExtrasStrategy.UNSET):
327
+ right_labels = get_labels(right)
328
+ left_data = right_labels.join(
329
+ left.data,
330
+ how="left",
331
+ on=dims,
332
+ maintain_order="left" if Config.maintain_order else None,
333
+ )
334
+ if left_data.get_column(COEF_KEY).null_count() > 0:
335
+ _raise_extras_error(
336
+ left,
337
+ right,
338
+ right_labels.join(get_labels(left), how="anti", on=dims),
339
+ swapped,
340
+ )
341
+
342
+ elif strat == (ExtrasStrategy.KEEP, ExtrasStrategy.UNSET):
343
+ assert not Config.disable_extras_checks, (
344
+ "This code should not be reached when checks for extra values are disabled."
345
+ )
346
+ extras = right.data.join(get_labels(left), how="anti", on=dims)
347
+ if len(extras) > 0:
348
+ _raise_extras_error(left, right, extras.select(dims), swapped)
349
+ else: # pragma: no cover
350
+ assert False, "This code should've never been reached!"
351
+
352
+ if swapped:
353
+ left_data, right_data = right_data, left_data
354
+
355
+ return (left_data, right_data), _extras_propagation_rules[strat]
356
+
357
+
358
+ def _raise_extras_error(
359
+ left: Expression,
360
+ right: Expression,
361
+ extra_labels: pl.DataFrame,
362
+ swapped: bool,
363
+ extras_on_right: bool = True,
368
364
  ):
369
365
  if swapped:
370
366
  left, right = right, left
367
+ extras_on_right = not extras_on_right
371
368
 
372
- _raise_addition_error(
373
- left,
374
- right,
375
- "of unmatched values",
376
- f"Unmatched values:\n{unmatched_values}\nIf this is intentional, use .drop_unmatched() or .keep_unmatched().",
377
- )
369
+ expression_num = 2 if extras_on_right else 1
370
+
371
+ with Config.print_polars_config:
372
+ _raise_addition_error(
373
+ left,
374
+ right,
375
+ f"expression {expression_num} has extra labels",
376
+ f"Extra labels in expression {expression_num}:\n{extra_labels}\nUse .drop_extras() or .keep_extras() to indicate how the extra labels should be handled. Learn more at\n\thttps://bravos-power.github.io/pyoframe/latest/learn/concepts/addition",
377
+ )
378
378
 
379
379
 
380
380
  def _raise_addition_error(
@@ -412,7 +412,7 @@ def _broadcast(
412
412
  return res
413
413
 
414
414
  # If drop, we just do an inner join to get into the shape of the other
415
- if self._unmatched_strategy == UnmatchedStrategy.DROP:
415
+ if self._extras_strategy == ExtrasStrategy.DROP:
416
416
  res = self._new(
417
417
  self.data.join(
418
418
  target_data,
@@ -430,12 +430,14 @@ def _broadcast(
430
430
  how="left",
431
431
  maintain_order="left" if Config.maintain_order else None,
432
432
  )
433
- right_has_missing = result.get_column(missing_dims[0]).null_count() > 0
434
- if right_has_missing:
435
- _raise_unmatched_values_error(
433
+ if result.get_column(missing_dims[0]).null_count() > 0:
434
+ target_labels = target.data.select(target._dimensions_unsafe).unique(
435
+ maintain_order=Config.maintain_order
436
+ )
437
+ _raise_extras_error(
436
438
  self,
437
439
  target,
438
- result.filter(result.get_column(missing_dims[0]).is_null()),
440
+ target_labels.join(self.data, how="anti", on=common_dims),
439
441
  swapped,
440
442
  )
441
443
  res = self._new(result, self.name)
@@ -522,7 +524,7 @@ def _simplify_expr_df(df: pl.DataFrame) -> pl.DataFrame:
522
524
  if df.is_empty():
523
525
  df = pl.DataFrame(
524
526
  {VAR_KEY: [CONST_TERM], COEF_KEY: [0]},
525
- schema={VAR_KEY: KEY_TYPE, COEF_KEY: pl.Float64},
527
+ schema={VAR_KEY: Config.id_dtype, COEF_KEY: pl.Float64},
526
528
  )
527
529
 
528
530
  if QUAD_VAR_KEY in df.columns and (df.get_column(QUAD_VAR_KEY) == CONST_TERM).all():