relationalai 0.12.7__py3-none-any.whl → 0.12.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,219 +1,534 @@
1
+ """Solver model implementation using protobuf format.
2
+
3
+ This module provides the SolverModelPB class for defining optimization and
4
+ constraint programming problems that are serialized to protobuf format and
5
+ solved by external solver engines.
6
+
7
+ Note: This protobuf-based implementation will be deprecated in favor of the
8
+ development version (solvers_dev.py) in future releases.
9
+ """
10
+
1
11
  from __future__ import annotations
2
- from typing import Any, Union
12
+
3
13
  import textwrap
4
- import uuid
5
14
  import time
15
+ import uuid
16
+ from typing import Any, Optional
6
17
 
7
- from relationalai.semantics.metamodel.util import ordered_set
8
- from relationalai.semantics.internal import internal as b # TODO(coey) change b name or remove b.?
9
- from relationalai.semantics.rel.executor import RelExecutor
10
- from .common import make_name
11
18
  from relationalai.experimental.solvers import Solver
19
+ from relationalai.semantics.internal import internal as b
20
+ from relationalai.semantics.rel.executor import RelExecutor
12
21
  from relationalai.tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
13
22
  from relationalai.util.timeout import calc_remaining_timeout_minutes
14
23
 
15
- _Any = Union[b.Producer, str, float, int]
16
- _Number = Union[b.Producer, float, int]
24
+ from .common import make_name
25
+
26
+
27
+ # =============================================================================
28
+ # Solver ProtoBuf Format Constants and Helpers
29
+ # =============================================================================
30
+
31
+ # Variable type codes for the solver protobuf format
32
+ # cont: continuous (real-valued), int: integer, bin: binary (0/1)
33
+ _VARIABLE_TYPE_CODES: dict[str, int] = {
34
+ "cont": 40,
35
+ "int": 41,
36
+ "bin": 42,
37
+ }
38
+
39
+ # First-order operators: arithmetic and mathematical functions
40
+ _FIRST_ORDER_OPERATOR_CODES: dict[str, int] = {
41
+ "+": 10,
42
+ "-": 11,
43
+ "*": 12,
44
+ "/": 13,
45
+ "^": 14,
46
+ "abs": 20,
47
+ "exp": 21,
48
+ "log": 22,
49
+ "range": 50,
50
+ }
51
+
52
+ # First-order comparison operators: relational constraints
53
+ _FIRST_ORDER_COMPARISON_CODES: dict[str, int] = {
54
+ "=": 30,
55
+ "!=": 31,
56
+ "<=": 32,
57
+ ">=": 33,
58
+ "<": 34,
59
+ ">": 35,
60
+ "implies": 62,
61
+ }
62
+
63
+ # Higher-order operators: aggregation and global constraints
64
+ _HIGHER_ORDER_OPERATOR_CODES: dict[str, int] = {
65
+ "sum": 80,
66
+ "min": 82,
67
+ "max": 83,
68
+ "count": 84,
69
+ "all_different": 90,
70
+ }
71
+
72
+ # Mapping from bound type keywords to comparison operators
73
+ _BOUND_TO_COMPARISON_OPERATOR: dict[str, str] = {
74
+ "lower": ">=",
75
+ "upper": "<=",
76
+ "fixed": "=",
77
+ }
78
+
79
+
80
+ def _make_first_order_application_with_result(
81
+ operator_code: int, *args: Any
82
+ ) -> b.Expression:
83
+ """Create a first-order application with a result variable."""
84
+ return _make_first_order_application(operator_code, *args, b.String.ref("res"))
85
+
86
+
87
+ def _make_first_order_application(operator_code: int, *args: Any) -> b.Expression:
88
+ """Create a first-order application expression."""
89
+ if not (2 <= len(args) <= 4):
90
+ raise ValueError(
91
+ f"First-order application requires 2-4 arguments, but got {len(args)}."
92
+ )
93
+ result_ref = args[-1]
94
+ if not isinstance(result_ref, b.Ref):
95
+ raise TypeError(
96
+ f"Last argument must be a Ref, got {type(result_ref).__name__}."
97
+ )
98
+ if result_ref._thing != b.String:
99
+ result_ref = b.String.ref("res")
100
+ application_builtin = b.Relationship.builtins["rel_primitive_solverlib_fo_appl"]
101
+ # Wrap operands in TupleArg for the vararg solverlib_fo_appl primitive:
102
+ # fo_appl(op, (operands...), result)
103
+ return b.Expression(
104
+ application_builtin, operator_code, b.TupleArg(args[:-1]), result_ref
105
+ )
106
+
107
+
108
+ # =============================================================================
109
+ # Main Solver Model Class
110
+ # =============================================================================
111
+
17
112
 
18
113
  class SolverModelPB:
19
- def __init__(self, model: b.Model, num_type: str):
20
- assert num_type in ["cont", "int"], "Invalid numerical type, must be 'cont' or 'int'"
21
- self._model = model # TODO can we remove? only used for _model._to_executor
114
+ """Solver model interface using protobuf format for optimization problems."""
115
+
116
+ def __init__(self, model: b.Model, num_type: str) -> None:
117
+ """Initialize solver model.
118
+
119
+ Args:
120
+ model: The RelationalAI model.
121
+ num_type: Variable type - 'cont' or 'int'.
122
+ """
123
+ if num_type not in ["cont", "int"]:
124
+ raise ValueError(
125
+ f"Invalid numerical type '{num_type}'. Must be 'cont' or 'int'."
126
+ )
127
+ self._model = model
22
128
  self._num_type = num_type
23
129
  self._id = next(b._global_id)
24
- self.variable_relationships = ordered_set()
25
- prefix_u = f"SolverModel_{self._id}_"
26
- prefix_l = f"solvermodel_{self._id}_"
27
-
28
- Variable = model.Concept(prefix_u + "Variable")
29
- self.Variable = Variable
30
- self.MinObjective = model.Concept(prefix_u + "MinObjective")
31
- self.MaxObjective = model.Concept(prefix_u + "MaxObjective")
32
- self.Constraint = model.Concept(prefix_u + "Constraint")
130
+ # Maps relationships to their corresponding variable concepts
131
+ self._variable_relationships: dict[b.Relationship, b.Concept] = {}
132
+ prefix_uppercase = f"SolverModel_{self._id}_"
133
+ prefix_lowercase = prefix_uppercase.lower()
33
134
 
135
+ # Create core concepts for model components
136
+ self.Variable = Variable = model.Concept(prefix_uppercase + "Variable")
137
+ self.MinObjective = model.Concept(prefix_uppercase + "MinObjective")
138
+ self.MaxObjective = model.Concept(prefix_uppercase + "MaxObjective")
139
+ self.Constraint = model.Concept(prefix_uppercase + "Constraint")
34
140
  self._model_info = {
35
141
  "num_variables": Variable,
36
- "num_constraints": self.Constraint,
37
142
  "num_min_objectives": self.MinObjective,
38
143
  "num_max_objectives": self.MaxObjective,
144
+ "num_constraints": self.Constraint,
39
145
  }
146
+ # Add printed_expr property to objectives and constraints for human-readable output
147
+ for concept in [self.MinObjective, self.MaxObjective, self.Constraint]:
148
+ concept.printed_expr = model.Property(
149
+ f"{{{concept._name}}} has {{printed_expr:str}}"
150
+ )
40
151
 
41
- res_type = "int" if num_type == "int" else "float"
42
- self.result_info = model.Relationship("{key:str} has {val:str}", short_name=(prefix_l + "result_info"))
43
- self.point = model.Relationship(f"{{Variable}} has {{val:{res_type}}}", short_name=(prefix_l + "point"))
44
- self.points = model.Relationship(f"point {{i:int}} for {{Variable}} has {{val:{res_type}}}", short_name=(prefix_l + "points"))
152
+ # Create relationships for result extraction
153
+ result_type = "int" if num_type == "int" else "float"
154
+ self.result_info = model.Relationship(
155
+ "{key:str} has {value:str}", short_name=(prefix_lowercase + "result_info")
156
+ )
157
+ # TODO(coey) PyRel is not able to handle "Variable._name" instead of "var" below due
158
+ # to some internal naming bug; this leads to a "Unresolved Type" warning that we
159
+ # will have to live with for now
160
+ self.point = model.Property(
161
+ f"{{var}} has {{value:{result_type}}}",
162
+ short_name=(prefix_lowercase + "point"),
163
+ )
164
+ self.points = model.Property(
165
+ f"point {{i:int}} for {{var}} has {{value:{result_type}}}",
166
+ short_name=(prefix_lowercase + "points"),
167
+ )
45
168
 
46
- b.define(b.RawSource("rel", textwrap.dedent(f"""
169
+ # Install raw rel to work around lack of support for rel_primitive_solverlib_print_expr
170
+ install_rel = f"""
47
171
  declare {self.MinObjective._name}
48
172
  declare {self.MaxObjective._name}
49
173
  declare {self.Constraint._name}
50
174
 
51
- declare {prefix_l}variable_name
52
- declare {prefix_l}minobjective_name
53
- declare {prefix_l}maxobjective_name
54
- declare {prefix_l}constraint_name
55
- declare {prefix_l}minobjective_serialized
56
- declare {prefix_l}maxobjective_serialized
57
- declare {prefix_l}constraint_serialized
175
+ declare {prefix_lowercase}variable_name
176
+ declare {prefix_lowercase}minobjective_name
177
+ declare {prefix_lowercase}maxobjective_name
178
+ declare {prefix_lowercase}constraint_name
179
+ declare {prefix_lowercase}minobjective_serialized
180
+ declare {prefix_lowercase}maxobjective_serialized
181
+ declare {prefix_lowercase}constraint_serialized
182
+
183
+ def {prefix_lowercase}minobjective_printed_expr(h, s):
184
+ rel_primitive_solverlib_print_expr({prefix_lowercase}minobjective_serialized[h], {prefix_lowercase}variable_name, s)
185
+
186
+ def {prefix_lowercase}maxobjective_printed_expr(h, s):
187
+ rel_primitive_solverlib_print_expr({prefix_lowercase}maxobjective_serialized[h], {prefix_lowercase}variable_name, s)
188
+
189
+ def {prefix_lowercase}constraint_printed_expr(h, s):
190
+ rel_primitive_solverlib_print_expr({prefix_lowercase}constraint_serialized[h], {prefix_lowercase}variable_name, s)
191
+ """
192
+ b.define(b.RawSource("rel", textwrap.dedent(install_rel)))
58
193
 
59
- def {prefix_l}minobjective_printed_expr(h, s):
60
- rel_primitive_solverlib_print_expr({prefix_l}minobjective_serialized[h], {prefix_l}variable_name, s)
194
+ # -------------------------------------------------------------------------
195
+ # Variable Handling
196
+ # -------------------------------------------------------------------------
61
197
 
62
- def {prefix_l}maxobjective_printed_expr(h, s):
63
- rel_primitive_solverlib_print_expr({prefix_l}maxobjective_serialized[h], {prefix_l}variable_name, s)
198
+ def solve_for(
199
+ self,
200
+ expr,
201
+ where: Optional[list[Any]] = None,
202
+ populate: bool = True,
203
+ **kwargs: Any,
204
+ ) -> b.Concept:
205
+ """Define decision variables.
64
206
 
65
- def {prefix_l}constraint_printed_expr(h, s):
66
- rel_primitive_solverlib_print_expr({prefix_l}constraint_serialized[h], {prefix_l}variable_name, s)
67
- """)))
207
+ Args:
208
+ expr: Relationship or expression defining variables.
209
+ where: Optional grounding conditions.
210
+ populate: Whether to populate relationship with solver results.
211
+ **kwargs: Optional properties (name, type, lower, upper, fixed).
68
212
 
69
- # TODO(coey) assert that it is a property? not just a relationship.
70
- def solve_for(self, expr: b.Relationship | b.Fragment, populate: bool = True, **kwargs):
71
- where = []
213
+ Returns:
214
+ Variable concept.
215
+ """
216
+ if where is None:
217
+ where = []
72
218
  if isinstance(expr, b.Fragment):
73
- assert expr._select and len(expr._select) == 1 and expr._where, "Fragment input for `solve_for` must have exactly one select and a where clause"
74
- rel = expr._select[0]
75
- where = expr._where
219
+ # TODO(coey): Remove in future
220
+ raise ValueError(
221
+ "The select fragment argument to `solve_for` is deprecated. "
222
+ "Instead, use the `where = [conditions...]` kwarg to specify optional grounding conditions."
223
+ )
224
+ elif isinstance(expr, b.Expression):
225
+ relationship = expr._op
226
+ if not isinstance(relationship, b.Relationship):
227
+ raise TypeError(
228
+ f"Expression operator must be a Relationship, got {type(relationship).__name__}."
229
+ )
230
+ params = expr._params
76
231
  elif isinstance(expr, b.Relationship):
77
- rel = expr
232
+ relationship = expr
233
+ params = [
234
+ b.field_to_type(self._model, field) for field in relationship._fields
235
+ ]
78
236
  else:
79
- raise ValueError(f"Invalid expression type {type(expr)} for `solve_for`; must be a Relationship or Fragment")
237
+ raise TypeError(
238
+ f"Invalid expression type for solve_for: {type(expr).__name__}. "
239
+ f"Expected Relationship or Expression."
240
+ )
241
+
242
+ if len(params) != len(relationship._fields):
243
+ raise ValueError(
244
+ f"Parameter count mismatch: Got {len(params)} params "
245
+ f"but relationship has {len(relationship._fields)} fields."
246
+ )
247
+ if relationship in self._variable_relationships:
248
+ raise ValueError(
249
+ f"Variables are already defined for relationship {relationship}."
250
+ )
80
251
 
81
- assert rel._parent, "Relationship for `solve_for` must have a parent"
82
- assert rel._short_name, "Relationship for `solve_for` must have a short name"
83
- self.variable_relationships.add(rel)
252
+ # Create a specialized Variable concept for this relationship
253
+ # Each decision variable gets its own concept subtype
254
+ Var = self._model.Concept(
255
+ f"{self.Variable._name}_{str(relationship).replace('.', '_')}",
256
+ extends=[self.Variable],
257
+ )
258
+ self._variable_relationships[relationship] = Var
84
259
 
85
- ent = b.select(rel._parent).where(*where)
86
- var = self.Variable.new(entity=ent, relationship=rel._short_name)
87
- defs = [var]
260
+ # Build field dict from relationship parameters (excluding the value field)
261
+ fields = {}
262
+ for i in range(len(params) - 1):
263
+ if i == 0 and relationship._parent is not None:
264
+ concept = relationship._parent
265
+ if not isinstance(concept, b.Concept):
266
+ raise TypeError(
267
+ f"Relationship parent must be a Concept, got {type(concept).__name__}."
268
+ )
269
+ else:
270
+ concept = params[i]
271
+ field_name = relationship._field_names[i]
272
+ # Prevent "Implicit Subtype Relationship" warnings by explicitly registering
273
+ # the relationship on the parent Variable concept before using it on subtypes
274
+ self.Variable._relationships[field_name] = self.Variable._get_relationship(
275
+ field_name
276
+ )
277
+ fields[field_name] = concept
278
+ var = Var.new(**fields)
279
+ b.define(var).where(*where)
88
280
 
89
- # handle optional variable properties
90
- for (key, val) in kwargs.items():
281
+ # Handle optional variable properties
282
+ for key, value in kwargs.items():
91
283
  if key == "name":
92
- assert isinstance(val, (_Any, list)), f"Expected {key} to be a value or list, got {type(val)}"
93
- defs.append(var.name(make_name(val)))
284
+ definition = self.Variable.name(var, make_name(value))
94
285
  elif key == "type":
95
- assert isinstance(val, str), f"Expected {key} to be a string, got {type(val)}"
96
- assert val in _var_types, f"Invalid variable type {val} for `solve_for`"
97
- ser = _make_fo_appl_with_res(_var_types[val], var)
98
- defs.append(self.Constraint.new(serialized=ser))
286
+ if not isinstance(value, str):
287
+ raise TypeError(
288
+ f"Variable 'type' must be a string, but got {type(value).__name__}."
289
+ )
290
+ if value not in _VARIABLE_TYPE_CODES:
291
+ valid_types = ", ".join(_VARIABLE_TYPE_CODES.keys())
292
+ raise ValueError(
293
+ f"Invalid variable type '{value}'. Valid types are: {valid_types}."
294
+ )
295
+ serialized_expr = _make_first_order_application_with_result(
296
+ _VARIABLE_TYPE_CODES[value], var
297
+ )
298
+ definition = self.Constraint.new(serialized=serialized_expr)
99
299
  elif key in ("lower", "upper", "fixed"):
100
- assert isinstance(val, _Number), f"Expected {key} to be a number, got {type(val)}"
101
- op = ">=" if key == "lower" else ("<=" if key == "upper" else "=")
102
- ser = _make_fo_appl_with_res(_fo_comparisons[op], var, val)
103
- defs.append(self.Constraint.new(serialized=ser))
300
+ if not isinstance(value, (b.Producer, float, int)):
301
+ raise TypeError(
302
+ f"Variable '{key}' must be a number, but got {type(value).__name__}."
303
+ )
304
+ # Map bound types to comparison operators
305
+ operator = _BOUND_TO_COMPARISON_OPERATOR[key]
306
+ serialized_expr = _make_first_order_application_with_result(
307
+ _FIRST_ORDER_COMPARISON_CODES[operator], var, value
308
+ )
309
+ definition = self.Constraint.new(serialized=serialized_expr)
104
310
  else:
105
- raise ValueError(f"Invalid keyword argument {key} for `solve_for`")
106
-
107
- b.define(*defs)
311
+ raise ValueError(f"Invalid keyword argument '{key}' for solve_for.")
312
+ b.define(definition).where(*where)
108
313
 
109
314
  if populate:
110
- # TODO do something different in future. maybe delete/insert into variable relationships after solve.
111
- # get variable values from the result point (populated by the solver)
112
- val = (b.Integer if self._num_type == "int" else b.Float).ref()
113
- b.define(rel(val)).where(self.point(var, val))
114
- return var
115
-
116
- # min/max must take a number or a Producer
117
- def minimize(self, expr: _Number, name = None):
118
- return self._obj(self.MinObjective, expr, name)
119
-
120
- def maximize(self, expr: _Number, name = None):
121
- return self._obj(self.MaxObjective, expr, name)
122
-
123
- def _obj(self, concept: b.Concept, expr: _Number, name):
124
- sym_expr = _rewrite(expr, self)
125
- if not sym_expr:
126
- # expr is not symbolic (a constant)
127
- # TODO should we warn if objective is constant or do further checks?
128
- sym_expr = _make_fo_appl_with_res(0, expr)
129
- elif isinstance(sym_expr, b.ConceptMember):
130
- # expr probably refers to a single variable, so we need to wrap it for valid protobuf
131
- sym_expr = _make_fo_appl_with_res(0, sym_expr)
132
- obj = concept.new(serialized=sym_expr)
133
- defs = [obj]
315
+ # Automatically populate the variable relationship with solver results
316
+ # This defines the original relationship to pull values from self.point after solving
317
+ value_ref = (b.Integer if self._num_type == "int" else b.Float).ref()
318
+ b.define(
319
+ relationship(
320
+ *[getattr(Var, field_name) for field_name in fields], value_ref
321
+ )
322
+ ).where(self.point(Var, value_ref))
323
+
324
+ return Var
325
+
326
+ # -------------------------------------------------------------------------
327
+ # Objective Functions
328
+ # -------------------------------------------------------------------------
329
+
330
+ def minimize(
331
+ self,
332
+ expr: b.Producer | float | int | b.Fragment,
333
+ name: Optional[str | list[str]] = None,
334
+ ) -> None:
335
+ """Add minimization objective.
336
+
337
+ Args:
338
+ expr: Expression to minimize.
339
+ name: Optional objective name.
340
+ """
341
+ return self._add_objective(self.MinObjective, expr, name)
342
+
343
+ def maximize(
344
+ self,
345
+ expr: b.Producer | float | int | b.Fragment,
346
+ name: Optional[str | list[str]] = None,
347
+ ) -> None:
348
+ """Add maximization objective.
349
+
350
+ Args:
351
+ expr: Expression to maximize.
352
+ name: Optional objective name.
353
+ """
354
+ return self._add_objective(self.MaxObjective, expr, name)
355
+
356
+ def _add_objective(
357
+ self,
358
+ objective_concept: b.Concept,
359
+ expr: b.Producer | float | int | b.Fragment,
360
+ name: Optional[str | list[str]],
361
+ ) -> None:
362
+ context = SymbolifyContext(self)
363
+ symbolic_expr = context.rewrite(expr)
364
+ if not isinstance(symbolic_expr, Symbolic):
365
+ # Expr is not symbolic (a constant) - wrap it as a trivial expression
366
+ symbolic_expr = _make_first_order_application_with_result(0, expr)
367
+ else:
368
+ # Unwrap the symbolic expression
369
+ unwrapped_expr = symbolic_expr.expr
370
+ # Check if it's a bare variable (needs wrapping for protobuf)
371
+ is_bare_variable = isinstance(unwrapped_expr, b.ConceptMember)
372
+ is_fragment_with_variable = (
373
+ isinstance(unwrapped_expr, b.Fragment)
374
+ and unwrapped_expr._select
375
+ and isinstance(unwrapped_expr._select[0], b.ConceptMember)
376
+ )
377
+ if is_bare_variable or is_fragment_with_variable:
378
+ # The protobuf format requires all objectives to be expressions, not bare variables
379
+ symbolic_expr = _make_first_order_application_with_result(
380
+ 0, unwrapped_expr
381
+ )
382
+ else:
383
+ symbolic_expr = unwrapped_expr
384
+
385
+ if isinstance(symbolic_expr, Symbolic):
386
+ raise ValueError(
387
+ "Internal error. Expression is still Symbolic after unwrapping."
388
+ )
389
+
390
+ objective = objective_concept.new(serialized=symbolic_expr)
391
+ definitions = [objective]
134
392
  if name is not None:
135
- defs.append(obj.name(make_name(name)))
136
- b.define(*defs)
137
- return obj
138
-
139
- # satisfy must take a require Fragment
140
- def satisfy(self, expr: b.Fragment, check: bool = False, name = None):
141
- assert expr._require, "Fragment input for `satisfy` must have a require clause"
142
- assert not expr._select and not expr._define, "Fragment input for `satisfy` must not have a select or define clause"
393
+ definitions.append(objective.name(make_name(name)))
394
+ b.define(*definitions)
395
+
396
+ def satisfy(
397
+ self,
398
+ expr: b.Fragment,
399
+ check: bool = False,
400
+ name: Optional[str | list[str]] = None,
401
+ ) -> None:
402
+ """Add constraints.
403
+
404
+ Args:
405
+ expr: Fragment with require clause.
406
+ check: Whether to keep require in model validation.
407
+ name: Optional constraint name.
408
+ """
409
+ if not isinstance(expr, b.Fragment):
410
+ raise TypeError(
411
+ f"The satisfy method expects a Fragment, but got {type(expr).__name__}."
412
+ )
413
+ if not expr._require:
414
+ raise ValueError("Fragment for satisfy must have a require clause.")
415
+ if expr._select or expr._define:
416
+ raise ValueError(
417
+ "Fragment for satisfy must not have select or define clauses."
418
+ )
143
419
  if not check:
144
- # remove the `require` from the model roots so it is not checked
420
+ # Remove the `require` from the model roots so it is not checked as an integrity constraint
145
421
  b._remove_roots([expr])
146
- # TODO maybe ensure no variables in `where`s, now `.new`s?
147
- sym_reqs = []
148
- for req in expr._require:
149
- sym_req = _rewrite(req, self)
150
- assert sym_req, f"Cannot symbolify requirement {req} in `satisfy`"
151
- sym_reqs.append(sym_req)
152
- ser = b.union(*sym_reqs) if len(sym_reqs) > 1 else sym_reqs[0]
153
- # TODO(coey) nested select not working properly on supply_chain, so have to put the where later, for now
154
- # cons = self.Constraint.new(serialized=b.select(ser).where(*expr._where))
155
- cons = self.Constraint.new(serialized=ser)
156
- defs = [cons]
157
- if name is not None:
158
- defs.append(cons.name(make_name(name)))
159
- b.define(*defs).where(*expr._where)
160
- return cons
161
-
162
- # print counts of the number of model components
163
- def summarize(self):
164
- to_count = [
165
- self.Variable,
166
- self.MinObjective.serialized,
167
- self.MaxObjective.serialized,
168
- self.Constraint.serialized,
169
- ]
170
- counts = b.select(*[(b.count(c) | 0) for c in to_count]).to_df() # TODO(coey) do we need the |0?
171
- assert counts.shape == (1, len(to_count)), f"Unexpected counts shape {counts.shape}"
172
- (vars, min_objs, max_objs, cons) = counts.iloc[0]
173
- print(f"Solver model has {vars} variables, {min_objs} minimization objectives, {max_objs} maximization objectives, and {cons} constraints")
174
- return None
175
-
176
- # print the variables and components of the model in human-readable format
177
- def print(self, with_names: bool = False):
178
- print("Printing solver model.")
179
- # print variables
422
+ context = SymbolifyContext(self)
423
+
424
+ symbolic_where_clauses = context.rewrite_where(*expr._where)
425
+ definitions = []
426
+ for requirement in expr._require:
427
+ symbolic_requirement = context.rewrite(requirement)
428
+ if not isinstance(symbolic_requirement, Symbolic):
429
+ raise ValueError(
430
+ f"Cannot symbolify requirement {requirement} in satisfy. "
431
+ f"The requirement must contain solver variables or expressions."
432
+ )
433
+ constraint = self.Constraint.new(serialized=symbolic_requirement.expr)
434
+ definitions.append(constraint)
435
+ if name is not None:
436
+ definitions.append(constraint.name(make_name(name)))
437
+ b.define(*definitions).where(*symbolic_where_clauses)
438
+
439
+ # -------------------------------------------------------------------------
440
+ # Model Inspection and Display
441
+ # -------------------------------------------------------------------------
442
+
443
+ @staticmethod
444
+ def _print_dataframe(df: Any) -> None:
445
+ """Print dataframe with consistent formatting for long strings."""
446
+ for row in df.itertuples(index=False):
447
+ print(" ".join(str(val) for val in row))
448
+ print()
449
+
450
+ def summarize(self) -> None:
451
+ """Print counts of variables, objectives, and constraints in the model."""
452
+ counts_df = b.select(
453
+ *[(b.count(item) | 0) for (_, item) in self._model_info.items()]
454
+ ).to_df()
455
+ if counts_df.shape != (1, 4):
456
+ raise ValueError("Unexpected counts dataframe shape.")
457
+ num_vars, num_min_objs, num_max_objs, num_constraints = counts_df.iloc[0]
458
+ print(
459
+ f"Solver model has {num_vars} variables, {num_min_objs} minimization objectives, {num_max_objs} maximization objectives, and {num_constraints} constraints."
460
+ )
461
+
462
+ def print(self, with_names: bool = False) -> None:
463
+ """Print model components.
464
+
465
+ Args:
466
+ with_names: Whether to print expression string names (if available).
467
+ """
468
+ # Print variables
180
469
  var_df = b.select(self.Variable.name | "_").where(self.Variable).to_df()
181
470
  if var_df.empty:
182
- print("No variables defined in the solver model.")
183
- return None
184
- print(f"{var_df.shape[0]} variables:")
185
- print(var_df.to_string(index=False, header=False))
186
-
187
- # print components
188
- comps = [
189
- (self.MinObjective, "minimization objectives"),
190
- (self.MaxObjective, "maximization objectives"),
191
- (self.Constraint, "constraints"),
471
+ print("No variables defined.")
472
+ return
473
+ print("Solver model:")
474
+ print()
475
+ print(f"Variables ({var_df.shape[0]}):")
476
+ self._print_dataframe(var_df)
477
+
478
+ # Print components
479
+ components = [
480
+ (self.MinObjective, "Min objectives"),
481
+ (self.MaxObjective, "Max objectives"),
482
+ (self.Constraint, "Constraints"),
192
483
  ]
193
- p = b.String.ref()
194
- for (e, s) in comps:
195
- sel = [e.name | "", p] if with_names else [p]
196
- comp_df = b.select(*sel).where(e.printed_expr(p)).to_df()
197
- if not comp_df.empty:
198
- print(f"{comp_df.shape[0]} {s}:")
199
- print(comp_df.to_string(index=False, header=False))
200
- return None
201
-
202
- # solve the model given a solver and solver options
203
- def solve(self, solver: Solver, log_to_console: bool = False, **kwargs):
204
- options = kwargs
205
- options["version"] = 1
206
-
207
- # Validate options.
208
- for k, v in options.items():
209
- if not isinstance(k, str):
210
- raise ValueError(f"Invalid parameter key. Expected string, got {type(k)} for {k}.")
211
- if not isinstance(v, (int, float, str, bool)):
212
- raise ValueError(
213
- f"Invalid parameter value. Expected string, integer, float, or boolean, got {type(v)} for {k}."
484
+ printed_expr_ref = b.String.ref()
485
+ for component_concept, component_label in components:
486
+ selection = (
487
+ [component_concept.name | "", printed_expr_ref]
488
+ if with_names
489
+ else [printed_expr_ref]
490
+ )
491
+ component_df = (
492
+ b.select(*selection)
493
+ .where(component_concept.printed_expr(printed_expr_ref))
494
+ .to_df()
495
+ )
496
+ if not component_df.empty:
497
+ print(f"{component_label} ({component_df.shape[0]}):")
498
+ self._print_dataframe(component_df)
499
+
500
+ # -------------------------------------------------------------------------
501
+ # Solving and Result Handling
502
+ # -------------------------------------------------------------------------
503
+
504
+ def solve(
505
+ self, solver: Solver, log_to_console: bool = False, **kwargs: Any
506
+ ) -> None:
507
+ """Solve the model.
508
+
509
+ Args:
510
+ solver: Solver instance.
511
+ log_to_console: Whether to show solver output.
512
+ **kwargs: Solver options and parameters.
513
+ """
514
+ options = {**kwargs, "version": 1}
515
+
516
+ # Validate solver options
517
+ for option_key, option_value in options.items():
518
+ if not isinstance(option_key, str):
519
+ raise TypeError(
520
+ f"Solver option keys must be strings, but got {type(option_key).__name__} for key {option_key!r}."
521
+ )
522
+ if not isinstance(option_value, (int, float, str, bool)):
523
+ raise TypeError(
524
+ f"Solver option values must be int, float, str, or bool, "
525
+ f"but got {type(option_value).__name__} for option {option_key!r}."
214
526
  )
215
527
 
216
- # Run the solve query and insert the extracted result.
528
+ # Three-phase solve process:
529
+ # 1. Export model to Snowflake as protobuf
530
+ # 2. Execute solver job (external solver reads from Snowflake)
531
+ # 3. Extract and load results back into the model
217
532
  input_id = uuid.uuid4()
218
533
  model_uri = f"snowflake://APP_STATE.RAI_INTERNAL_STAGE/job-inputs/solver/{input_id}/model.binpb"
219
534
  sf_input_uri = f"snowflake://job-inputs/solver/{input_id}/model.binpb"
@@ -222,28 +537,40 @@ class SolverModelPB:
222
537
  payload["model_uri"] = sf_input_uri
223
538
 
224
539
  executor = self._model._to_executor()
225
- assert isinstance(executor, RelExecutor)
226
- prefix_l = f"solvermodel_{self._id}_"
540
+ if not isinstance(executor, RelExecutor):
541
+ raise ValueError(f"Expected RelExecutor, got {type(executor).__name__}.")
542
+ prefix_lowercase = f"solvermodel_{self._id}_"
227
543
 
228
544
  query_timeout_mins = kwargs.get("query_timeout_mins", None)
229
545
  config = self._model._config
230
- if query_timeout_mins is None and (timeout_value := config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
546
+ if (
547
+ query_timeout_mins is None
548
+ and (
549
+ timeout_value := config.get(
550
+ "query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS
551
+ )
552
+ )
553
+ is not None
554
+ ):
231
555
  query_timeout_mins = int(timeout_value)
232
- config_file_path = getattr(config, 'file_path', None)
556
+ config_file_path = getattr(config, "file_path", None)
233
557
  start_time = time.monotonic()
234
558
  remaining_timeout_minutes = query_timeout_mins
235
559
 
236
- # 1. Materialize the model and store it.
560
+ # Step 1: Materialize the model and store it in Snowflake
237
561
  print("export model")
238
- b.select(b.count(self.Variable)).to_df() # TODO(coey) weird hack to avoid uninitialized properties error
239
- executor.execute_raw(textwrap.dedent(f"""
562
+ # TODO(coey): Weird hack to avoid uninitialized properties error
563
+ # This forces evaluation of the Variable concept before export
564
+ b.select(b.count(self.Variable)).to_df()
565
+ export_model_relation = f"""
240
566
  // TODO maybe only want to pass names if printing - like in old setup
567
+ // Collect all model components into a relation for serialization
241
568
  def model_relation {{
242
569
  (:variable, {self.Variable._name});
243
- (:variable_name, {prefix_l}variable_name);
244
- (:min_objective, {prefix_l}minobjective_serialized);
245
- (:max_objective, {prefix_l}maxobjective_serialized);
246
- (:constraint, {prefix_l}constraint_serialized);
570
+ (:variable_name, {prefix_lowercase}variable_name);
571
+ (:min_objective, {prefix_lowercase}minobjective_serialized);
572
+ (:max_objective, {prefix_lowercase}maxobjective_serialized);
573
+ (:constraint, {prefix_lowercase}constraint_serialized);
247
574
  }}
248
575
 
249
576
  @no_diagnostics(:EXPERIMENTAL)
@@ -255,23 +582,33 @@ class SolverModelPB:
255
582
  def config[:envelope, :payload, :data]: model_string
256
583
  def config[:envelope, :payload, :path]: "{model_uri}"
257
584
  def export {{ config }}
258
- """), query_timeout_mins=remaining_timeout_minutes)
585
+ """
586
+ executor.execute_raw(
587
+ textwrap.dedent(export_model_relation),
588
+ query_timeout_mins=remaining_timeout_minutes,
589
+ )
259
590
 
260
- # 2. Execute job and wait for completion.
591
+ # Step 2: Execute solver job and wait for completion
261
592
  print("execute solver job")
262
593
  remaining_timeout_minutes = calc_remaining_timeout_minutes(
263
- start_time, query_timeout_mins, config_file_path=config_file_path,
594
+ start_time,
595
+ query_timeout_mins,
596
+ config_file_path=config_file_path,
264
597
  )
265
598
  job_id = solver._exec_job(
266
- payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes,
599
+ payload,
600
+ log_to_console=log_to_console,
601
+ query_timeout_mins=remaining_timeout_minutes,
267
602
  )
268
603
 
269
- # 3. Extract result.
604
+ # Step 3: Extract and insert solver results into the model
270
605
  print("extract result")
271
606
  remaining_timeout_minutes = calc_remaining_timeout_minutes(
272
- start_time, query_timeout_mins, config_file_path=config_file_path,
607
+ start_time,
608
+ query_timeout_mins,
609
+ config_file_path=config_file_path,
273
610
  )
274
- extract_str = textwrap.dedent(f"""
611
+ extract_results_relation = f"""
275
612
  def raw_result {{
276
613
  load_binary["snowflake://APP_STATE.RAI_INTERNAL_STAGE/job-results/{job_id}/result.binpb"]
277
614
  }}
@@ -285,231 +622,542 @@ class SolverModelPB:
285
622
  def delete[:{self.point._name}]: {self.point._name}
286
623
  def delete[:{self.points._name}]: {self.points._name}
287
624
 
288
- def insert(:{self.result_info._name}, key, val):
289
- exists((k) | string(extracted[k], val) and ::std::mirror::lower(k, key))
290
- """)
625
+ def insert(:{self.result_info._name}, key, value):
626
+ exists((original_key) | string(extracted[original_key], value) and ::std::mirror::lower(original_key, key))
627
+ """
291
628
  if self._num_type == "int":
292
- extract_str += textwrap.dedent(f"""
293
- def insert(:{self.point._name}, var, val):
294
- exists((x) | extracted(:point, var, x) and
295
- ::std::mirror::convert(std::mirror::typeof[Int128], x, val)
629
+ insert_points_relation = f"""
630
+ def insert(:{self.point._name}, variable, value):
631
+ exists((float_value) | extracted(:point, variable, float_value) and
632
+ ::std::mirror::convert(std::mirror::typeof[Int128], float_value, value)
296
633
  )
297
- def insert(:{self.points._name}, i, var, val):
298
- exists((j, x) | extracted(:points, var, j, x) and
299
- ::std::mirror::convert(std::mirror::typeof[Int128], j, i) and
300
- ::std::mirror::convert(std::mirror::typeof[Int128], x, val)
634
+ def insert(:{self.points._name}, point_index, variable, value):
635
+ exists((float_index, float_value) | extracted(:points, variable, float_index, float_value) and
636
+ ::std::mirror::convert(std::mirror::typeof[Int128], float_index, point_index) and
637
+ ::std::mirror::convert(std::mirror::typeof[Int128], float_value, value)
301
638
  )
302
- """)
639
+ """
303
640
  else:
304
- extract_str += textwrap.dedent(f"""
305
- def insert(:{self.point._name}, var, val): extracted(:point, var, val)
306
- def insert(:{self.points._name}, i, var, val):
307
- exists((j) | extracted(:points, var, j, val) and
308
- ::std::mirror::convert(std::mirror::typeof[Int128], j, i)
641
+ insert_points_relation = f"""
642
+ def insert(:{self.point._name}, variable, value): extracted(:point, variable, value)
643
+ def insert(:{self.points._name}, point_index, variable, value):
644
+ exists((float_index) | extracted(:points, variable, float_index, value) and
645
+ ::std::mirror::convert(std::mirror::typeof[Int128], float_index, point_index)
309
646
  )
310
- """)
647
+ """
311
648
  executor.execute_raw(
312
- extract_str, readonly=False, query_timeout_mins=remaining_timeout_minutes,
649
+ textwrap.dedent(extract_results_relation)
650
+ + textwrap.dedent(insert_points_relation),
651
+ readonly=False,
652
+ query_timeout_mins=remaining_timeout_minutes,
313
653
  )
314
654
 
315
655
  print("finished solve")
316
- return None
656
+ print()
317
657
 
318
- # load a particular point index from `points` into `point`
319
- # so it is accessible from the variable relationship
320
- def load_point(self, i: int):
321
- if not isinstance(i, int) and i >= 0:
322
- raise ValueError(f"Expected nonnegative integer index for point, got {i}")
658
+ def load_point(self, point_index: int) -> None:
659
+ """Load a solution point.
660
+
661
+ Args:
662
+ point_index: Solution point index (0-based).
663
+ """
664
+ if not isinstance(point_index, int):
665
+ raise TypeError(
666
+ f"Point index must be an integer, but got {type(point_index).__name__}."
667
+ )
668
+ if point_index < 0:
669
+ raise ValueError(
670
+ f"Point index must be non-negative, but got {point_index}."
671
+ )
323
672
  executor = self._model._to_executor()
324
- assert isinstance(executor, RelExecutor)
325
- executor.execute_raw(textwrap.dedent(f"""
673
+ if not isinstance(executor, RelExecutor):
674
+ raise ValueError(
675
+ f"Expected RelExecutor, but got {type(executor).__name__}."
676
+ )
677
+ load_point_relation = f"""
326
678
  def delete[:{self.point._name}]: {self.point._name}
327
- def insert(:{self.point._name}, var, val): {self.points._name}(int128[{i}], var, val)
328
- """), readonly=False)
329
- return None
330
-
331
- # print summary of the solver result
332
- def summarize_result(self):
333
- to_get = ["error", "termination_status", "solve_time_sec", "objective_value", "solver_version", "result_count"]
334
- k, v = b.String.ref(), b.String.ref()
335
- df = b.select(k, v).where(self.result_info(k, v), k.in_(to_get)).to_df()
336
- assert not df.empty, "No result information"
337
- print(df.to_string(index=False, header=False))
338
- return df
339
-
340
- # select variable names and values in the primal result point(s)
341
- def variable_values(self, multiple: bool = False):
342
- var = self.Variable.ref()
343
- val = (b.Integer if self._num_type == "int" else b.Float).ref()
679
+ def insert(:{self.point._name}, variable, value): {self.points._name}(int128[{point_index}], variable, value)
680
+ """
681
+ executor.execute_raw(textwrap.dedent(load_point_relation), readonly=False)
682
+
683
+ def summarize_result(self) -> Any:
684
+ """Print solver result summary.
685
+
686
+ Returns:
687
+ DataFrame with result information.
688
+ """
689
+ info_keys_to_retrieve = [
690
+ "error",
691
+ "termination_status",
692
+ "solve_time_sec",
693
+ "objective_value",
694
+ "solver_version",
695
+ "result_count",
696
+ ]
697
+ key, value_ref = b.String.ref(), b.String.ref()
698
+ result_df = (
699
+ b.select(key, value_ref)
700
+ .where(self.result_info(key, value_ref), key.in_(info_keys_to_retrieve))
701
+ .to_df()
702
+ )
703
+ if result_df.empty:
704
+ raise ValueError(
705
+ "No result information is available. Has the model been solved?"
706
+ )
707
+ print("Solver result:")
708
+ print(result_df.to_string(index=False, header=False))
709
+ print()
710
+ return result_df
711
+
712
+ def variable_values(self, multiple: bool = False) -> b.Fragment:
713
+ """Retrieve variable values.
714
+
715
+ Args:
716
+ multiple: Whether to return all solution points.
717
+
718
+ Returns:
719
+ Fragment for selecting values.
720
+ """
721
+ variable_ref = self.Variable.ref()
722
+ value_ref = (b.Integer if self._num_type == "int" else b.Float).ref()
344
723
  if multiple:
345
- i = b.Integer.ref()
346
- return b.select(i, var.name, val).where(self.points(i, var, val))
347
- else:
348
- return b.select(var.name, val).where(self.point(var, val))
724
+ point_index = b.Integer.ref()
725
+ return b.select(point_index, variable_ref.name, value_ref).where(
726
+ self.points(point_index, variable_ref, value_ref)
727
+ )
728
+ return b.select(variable_ref.name, value_ref).where(
729
+ self.point(variable_ref, value_ref)
730
+ )
731
+
732
+ # Valid result info keys that can be accessed as attributes
733
+ _RESULT_INFO_KEYS = frozenset(
734
+ [
735
+ "error",
736
+ "termination_status",
737
+ "solver_version",
738
+ "printed_model",
739
+ "solve_time_sec",
740
+ "objective_value",
741
+ "result_count",
742
+ ]
743
+ )
744
+
745
+ def __getattr__(self, name: str) -> Any:
746
+ """Get result attribute (e.g., num_variables, termination_status, objective_value).
747
+
748
+ Args:
749
+ name: Attribute name.
349
750
 
350
- # get scalar result information after solving
351
- def __getattr__(self, name: str):
352
- df = None
751
+ Returns:
752
+ Attribute value or None.
753
+ """
754
+ # Try to get dataframe from model info or result info
353
755
  if name in self._model_info:
354
- df = b.select(b.count(self._model_info[name]) | 0).to_df()
355
- elif name in {"error", "termination_status", "solver_version", "printed_model", "solve_time_sec", "objective_value", "result_count"}:
356
- val = b.String.ref()
357
- df = b.select(val).where(self.result_info(name, val)).to_df()
358
- # extract scalar from df
359
- if df is not None:
360
- if not df.shape == (1, 1):
361
- raise ValueError(f"Expected exactly one value for {name}, but df has shape {df.shape}")
362
- v = df.iloc[0, 0]
363
- if isinstance(v, str):
364
- if name == "solve_time_sec":
365
- return float(v)
366
- elif name == "objective_value":
367
- return int(v) if self._num_type == "int" else float(v)
368
- elif name == "result_count":
369
- return int(v)
370
- return v
371
- return None
372
-
373
- # TODO maybe structure rewriting code to be more like the compiler passes rather than in one big if-else
374
- def _rewrite(expr: Any, sm: SolverModelPB) -> Any:
375
- if isinstance(expr, (int, float, str)):
376
- return None
377
-
378
- elif isinstance(expr, (b.TypeRef, b.Concept)):
379
- return None
380
-
381
- elif isinstance(expr, b.Ref):
382
- thing = _rewrite(expr._thing, sm)
383
- if thing:
384
- return thing.ref()
385
- return None
386
-
387
- elif isinstance(expr, (b.Relationship, b.RelationshipRef, b.RelationshipFieldRef)):
388
- rel = expr if isinstance(expr, b.Relationship) else expr._relationship
389
- if rel in sm.variable_relationships:
390
- return sm.Variable(sm.Variable.ref(), entity=expr._parent, relationship=rel._short_name)
391
- return None
392
-
393
- elif isinstance(expr, b.Expression):
394
- op = _rewrite(expr._op, sm) # TODO what cases is this useful for?
395
- op_rewritten = op is not None
396
- params_rewritten = False
397
- params = []
398
- for p in expr._params:
399
- rp = _rewrite(p, sm)
400
- if rp:
401
- params_rewritten = True
402
- params.append(rp)
403
- else:
404
- params.append(p)
405
- if op_rewritten:
406
- assert not params_rewritten, f"Solver rewrites cannot handle expression {expr} with symbolic operator and symbolic parameters"
407
- return b.Expression(op, *params)
408
- if not params_rewritten:
409
- return None
410
- # some arguments involve solver variables, so rewrite the expression
411
- assert isinstance(expr._op, b.Relationship), f"Solver rewrites cannot handle expression {expr}"
412
- op = expr._op._name
413
- assert isinstance(op, str)
414
- if op in _fo_operators:
415
- return _make_fo_appl(_fo_operators[op], *params)
416
- elif op in _fo_comparisons:
417
- return _make_fo_appl_with_res(_fo_comparisons[op], *params)
756
+ result_df = b.select(b.count(self._model_info[name]) | 0).to_df()
757
+ elif name in self._RESULT_INFO_KEYS:
758
+ value_ref = b.String.ref()
759
+ result_df = (
760
+ b.select(value_ref).where(self.result_info(name, value_ref)).to_df()
761
+ )
418
762
  else:
419
- raise NotImplementedError(f"Solver rewrites cannot handle operator {op}")
420
-
421
- elif isinstance(expr, b.Aggregate):
422
- # only the last argument can be symbolic
423
- start_args = expr._args[:-1]
424
- for arg in start_args:
425
- assert not _rewrite(arg, sm), f"Solver rewrites cannot handle expression {expr}; only the last argument can be symbolic"
426
- sym_arg = _rewrite(expr._args[-1], sm)
427
- if not sym_arg:
428
763
  return None
429
- op = expr._op._name
430
- assert isinstance(op, str)
431
- if op in _ho_operators:
432
- appl = b.Relationship.builtins["rel_primitive_solverlib_ho_appl"]
433
- agg = b.Aggregate(appl, *start_args, sym_arg, _ho_operators[op])
434
- agg._group = expr._group
435
- agg._where = expr._where
436
- return agg
437
- else:
438
- raise NotImplementedError(f"Solver rewrites cannot handle aggregate operator {op}")
439
-
440
- elif isinstance(expr, b.Union):
441
- # return union of the symbolified expressions, if any are symbolic
442
- args_rewritten = False
443
- args = []
444
- for arg in expr._args:
445
- ra = _rewrite(arg, sm)
446
- if ra:
447
- args_rewritten = True
448
- args.append(ra)
764
+
765
+ # Extract and convert scalar value
766
+ if result_df.shape != (1, 1):
767
+ raise ValueError(
768
+ f"Expected exactly one value for attribute '{name}', "
769
+ f"but got dataframe with shape {result_df.shape}."
770
+ )
771
+
772
+ result_value = result_df.iloc[0, 0]
773
+ if not isinstance(result_value, str):
774
+ return result_value
775
+
776
+ # Convert string results to appropriate types
777
+ if name == "solve_time_sec":
778
+ return float(result_value)
779
+ if name == "objective_value":
780
+ return int(result_value) if self._num_type == "int" else float(result_value)
781
+ if name == "result_count":
782
+ return int(result_value)
783
+ return result_value
784
+
785
+
786
+ # =============================================================================
787
+ # Symbolic Expression Classes
788
+ # =============================================================================
789
+
790
+
791
+ class Symbolic:
792
+ """Wrapper for symbolified solver expressions."""
793
+
794
+ def __init__(self, expr: Any) -> None:
795
+ if isinstance(expr, Symbolic):
796
+ raise TypeError("Cannot wrap a Symbolic expression in another Symbolic.")
797
+ self.expr = expr
798
+
799
+
800
+ class SymbolifyContext:
801
+ """Context for rewriting expressions into solver-compatible symbolic form."""
802
+
803
+ def __init__(self, solver_model: SolverModelPB) -> None:
804
+ self.model = solver_model._model
805
+ self.solver_model = solver_model
806
+ # Maps original variables (or refs) to symbolic variables bound in where clauses
807
+ self.variable_map: dict[Any, Any] = {}
808
+
809
+ # -------------------------------------------------------------------------
810
+ # Public Rewriting Methods
811
+ # -------------------------------------------------------------------------
812
+
813
+ def rewrite_where(self, *exprs: Any) -> list[Any]:
814
+ """Rewrite where clause expressions.
815
+
816
+ Args:
817
+ *exprs: Where clause expressions.
818
+
819
+ Returns:
820
+ Rewritten expressions.
821
+ """
822
+ rewritten_expressions: list[Any] = []
823
+ # Two-pass strategy: first handle variable relationships to populate variable_map,
824
+ # then rewrite other expressions that may reference those variables
825
+ # First pass: identify and handle variable relationship expressions
826
+ for expression in exprs:
827
+ if (
828
+ isinstance(expression, b.Expression)
829
+ and isinstance(expression._op, b.Relationship)
830
+ and expression._op in self.solver_model._variable_relationships
831
+ ):
832
+ rewritten_expressions.append(
833
+ self._handle_variable_relationship(expression)
834
+ )
449
835
  else:
450
- args.append(arg)
451
- if args_rewritten:
452
- return b.union(*args)
453
- return None
836
+ rewritten_expressions.append(None)
837
+ # Second pass: rewrite remaining non-variable expressions
838
+ for i, expr in enumerate(exprs):
839
+ if rewritten_expressions[i] is None:
840
+ rewritten_expressions[i] = (
841
+ expr
842
+ if expr in self.variable_map
843
+ else self._rewrite_nonsymbolic(expr)
844
+ )
845
+ return rewritten_expressions
454
846
 
455
- elif isinstance(expr, b.Fragment):
456
- # only support selects with one item
457
- assert not expr._define and not expr._require and len(expr._select) == 1, "Solver rewrites only support fragments with a single select and no define or require clauses"
458
- sym_select = _rewrite(expr._select[0], sm)
459
- if sym_select:
460
- return b.select(sym_select).where(*expr._where)
461
- return None
847
+ def rewrite(self, expr: Any) -> Optional[Symbolic | Any]:
848
+ """Rewrite expressions to symbolify solver variables."""
849
+ if expr is None:
850
+ return None
462
851
 
463
- raise NotImplementedError(f"Solver rewrites cannot handle {expr} of type {type(expr)}")
852
+ elif isinstance(expr, (int, float, str)):
853
+ return None
464
854
 
855
+ elif isinstance(expr, b.ConceptFilter):
856
+ concept = expr._op
857
+ assert isinstance(concept, b.Concept)
858
+ (ident, kwargs) = expr._params
859
+ assert ident is None
860
+ assert isinstance(kwargs, dict)
861
+ new_kwargs = {}
862
+ values_were_rewritten = False
863
+ for key, value in kwargs.items():
864
+ rewritten_value = self.rewrite(value)
865
+ if isinstance(rewritten_value, Symbolic):
866
+ raise ValueError(
867
+ f"Cannot symbolify ConceptFilter argument {key} with symbolic value."
868
+ )
869
+ if rewritten_value is not None:
870
+ values_were_rewritten = True
871
+ new_kwargs[key] = rewritten_value
872
+ else:
873
+ new_kwargs[key] = value
874
+ if values_were_rewritten:
875
+ return b.ConceptFilter(concept, ident, new_kwargs)
876
+ return None
465
877
 
466
- def _make_fo_appl_with_res(op: int, *args: Any):
467
- return _make_fo_appl(op, *args, b.String.ref("res"))
878
+ elif isinstance(expr, (b.DataColumn, b.TypeRef, b.Concept)):
879
+ return None
468
880
 
469
- def _make_fo_appl(op: int, *args: Any):
470
- assert 2 <= len(args) <= 4
471
- res = args[-1]
472
- assert isinstance(res, b.Ref)
473
- if res._thing != b.String:
474
- res = b.String.ref("res")
475
- appl = b.Relationship.builtins["rel_primitive_solverlib_fo_appl"]
476
- return b.Expression(appl, op, b.TupleArg(args[:-1]), res)
881
+ elif isinstance(expr, b.Alias):
882
+ return self.rewrite(expr._thing)
477
883
 
884
+ elif isinstance(expr, b.Ref):
885
+ if expr in self.variable_map:
886
+ return Symbolic(self.variable_map[expr])
887
+ thing = self.rewrite(expr._thing)
888
+ if thing is not None:
889
+ raise ValueError(
890
+ f"Internal error. Ref._thing rewrite unexpectedly returned {thing}."
891
+ )
892
+ return None
478
893
 
479
- _var_types = {
480
- "cont": 40,
481
- "int": 41,
482
- "bin": 42,
483
- }
894
+ elif isinstance(expr, b.Relationship):
895
+ if expr in self.variable_map:
896
+ return Symbolic(self.variable_map[expr])
897
+ variable_result = self._get_variable_ref(expr, expr._parent)
898
+ if variable_result is not None:
899
+ self.variable_map[expr] = variable_result
900
+ return Symbolic(variable_result)
901
+ return None
484
902
 
485
- _fo_operators = {
486
- "+": 10,
487
- "-": 11,
488
- "*": 12,
489
- "/": 13,
490
- "^": 14,
491
- "abs": 20,
492
- "exp": 21,
493
- "log": 22,
494
- "range": 50,
495
- }
903
+ elif isinstance(expr, b.RelationshipRef):
904
+ if expr in self.variable_map:
905
+ return Symbolic(self.variable_map[expr])
906
+ relationship = expr._relationship
907
+ if isinstance(relationship, b.Relationship):
908
+ variable_result = self._get_variable_ref(relationship, expr._parent)
909
+ if variable_result is not None:
910
+ self.variable_map[expr] = variable_result
911
+ return Symbolic(variable_result)
912
+ rewritten_parent = self.rewrite(expr._parent)
913
+ if isinstance(rewritten_parent, Symbolic):
914
+ raise ValueError(
915
+ "Internal error. RelationshipRef parent rewrite returned Symbolic."
916
+ )
917
+ if rewritten_parent is not None:
918
+ return b.RelationshipRef(rewritten_parent, relationship)
919
+ return None
496
920
 
497
- _fo_comparisons = {
498
- "=": 30,
499
- "!=": 31,
500
- "<=": 32,
501
- ">=": 33,
502
- "<": 34,
503
- ">": 35,
504
- "implies": 62,
505
- }
921
+ elif isinstance(expr, b.RelationshipFieldRef):
922
+ relationship = expr._relationship
923
+ if not isinstance(relationship, b.Relationship):
924
+ # TODO(coey): Handle relationship:RelationshipReading
925
+ return None
506
926
 
507
- _ho_operators = {
508
- "sum": 80,
509
- # "product":81,
510
- "min": 82,
511
- "max": 83,
512
- "count": 84,
513
- "all_different": 90,
514
- }
927
+ # Rewrite the relationship reference
928
+ relationship_expression = (
929
+ relationship
930
+ if expr._parent is None
931
+ else b.RelationshipRef(expr._parent, relationship)
932
+ )
933
+ variable_result = self.rewrite(relationship_expression)
934
+ if variable_result is None:
935
+ return None
936
+
937
+ # Handle symbolic result - return as-is if it's the last field
938
+ if isinstance(variable_result, Symbolic):
939
+ if expr._field_ix == len(relationship._fields) - 1:
940
+ return variable_result
941
+ variable_result = variable_result.expr
942
+
943
+ return getattr(variable_result, relationship._field_names[expr._field_ix])
944
+
945
+ elif isinstance(expr, b.Expression):
946
+ operator = self.rewrite(expr._op)
947
+ if isinstance(operator, Symbolic):
948
+ raise ValueError(
949
+ "Internal error: Expression operator rewrite returned Symbolic."
950
+ )
951
+ params_were_rewritten = False
952
+ has_symbolic_params = False
953
+ params = []
954
+ for param in expr._params:
955
+ rewritten_param = self.rewrite(param)
956
+ if isinstance(rewritten_param, Symbolic):
957
+ has_symbolic_params = True
958
+ params_were_rewritten = True
959
+ params.append(rewritten_param.expr)
960
+ elif rewritten_param is not None:
961
+ params_were_rewritten = True
962
+ params.append(rewritten_param)
963
+ else:
964
+ params.append(param)
965
+ if operator is not None:
966
+ if has_symbolic_params:
967
+ raise NotImplementedError(
968
+ f"Solver rewrites cannot handle expression {expr} "
969
+ f"with both a symbolic operator and symbolic parameters."
970
+ )
971
+ return b.Expression(operator, *params)
972
+ if not has_symbolic_params:
973
+ return b.Expression(expr._op, *params)
974
+ if not params_were_rewritten:
975
+ return None
976
+
977
+ # Some arguments involve solver variables, so rewrite into solver protobuf format
978
+ # This converts operations like x + y into fo_appl(ADD_OP, (x, y), res)
979
+ if not has_symbolic_params:
980
+ raise ValueError(
981
+ "Internal error. Expected symbolic parameters but none were found."
982
+ )
983
+ if not isinstance(expr._op, b.Relationship):
984
+ raise NotImplementedError(
985
+ f"Solver rewrites cannot handle expression {expr} "
986
+ f"with operator type {type(expr._op).__name__}."
987
+ )
988
+ operator_name = expr._op._name
989
+ if not isinstance(operator_name, str):
990
+ raise ValueError(
991
+ f"Internal error. Operator name is {type(operator_name).__name__}, expected str."
992
+ )
993
+ if operator_name in _FIRST_ORDER_OPERATOR_CODES:
994
+ return Symbolic(
995
+ _make_first_order_application(
996
+ _FIRST_ORDER_OPERATOR_CODES[operator_name], *params
997
+ )
998
+ )
999
+ elif operator_name in _FIRST_ORDER_COMPARISON_CODES:
1000
+ return Symbolic(
1001
+ _make_first_order_application_with_result(
1002
+ _FIRST_ORDER_COMPARISON_CODES[operator_name], *params
1003
+ )
1004
+ )
1005
+ else:
1006
+ raise NotImplementedError(
1007
+ f"Solver rewrites cannot handle operator '{operator_name}'."
1008
+ )
1009
+
1010
+ elif isinstance(expr, b.Aggregate):
1011
+ # Only the last argument can be symbolic
1012
+ preceding_args = [self._rewrite_nonsymbolic(arg) for arg in expr._args[:-1]]
1013
+ group = [self._rewrite_nonsymbolic(arg) for arg in expr._group]
1014
+ # TODO(coey): Should this be done with a subcontext (for variable_map)?
1015
+ where = self.rewrite_where(*expr._where._where)
1016
+ rewritten = (
1017
+ preceding_args != expr._args[:-1]
1018
+ or group != expr._group
1019
+ or where != expr._where
1020
+ )
1021
+ symbolic_arg = self.rewrite(expr._args[-1])
1022
+ if symbolic_arg is None and not rewritten:
1023
+ return None
1024
+ if not isinstance(symbolic_arg, Symbolic):
1025
+ if symbolic_arg is None:
1026
+ symbolic_arg = expr._args[-1]
1027
+ aggregate_expr = b.Aggregate(expr._op, *preceding_args, symbolic_arg)
1028
+ return aggregate_expr.per(*group).where(*where)
1029
+
1030
+ # The last argument is symbolic - convert to higher-order application
1031
+ # Example: sum(x for x in variables) becomes ho_appl(..., x, SUM_OP)
1032
+ operator_name = expr._op._name
1033
+ if not isinstance(operator_name, str):
1034
+ raise ValueError(
1035
+ f"Internal error. Aggregate operator name is {type(operator_name).__name__}, expected str."
1036
+ )
1037
+ if operator_name not in _HIGHER_ORDER_OPERATOR_CODES:
1038
+ raise NotImplementedError(
1039
+ f"Solver rewrites cannot handle aggregate operator '{operator_name}'. "
1040
+ f"Supported operators: {', '.join(_HIGHER_ORDER_OPERATOR_CODES.keys())}"
1041
+ )
1042
+ higher_order_application_builtin = b.Relationship.builtins[
1043
+ "rel_primitive_solverlib_ho_appl"
1044
+ ]
1045
+ aggregate_expr = b.Aggregate(
1046
+ higher_order_application_builtin,
1047
+ *preceding_args,
1048
+ symbolic_arg.expr,
1049
+ _HIGHER_ORDER_OPERATOR_CODES[operator_name],
1050
+ )
1051
+ return Symbolic(aggregate_expr.per(*group).where(*where))
1052
+
1053
+ elif isinstance(expr, b.Union):
1054
+ # Return union of the symbolified expressions, if any are symbolic
1055
+ args_were_rewritten = False
1056
+ has_symbolic_args = False
1057
+ args = []
1058
+ for union_arg in expr._args:
1059
+ rewritten_arg = self.rewrite(union_arg)
1060
+ if isinstance(rewritten_arg, Symbolic):
1061
+ has_symbolic_args = True
1062
+ args.append(rewritten_arg.expr)
1063
+ elif rewritten_arg is not None:
1064
+ args_were_rewritten = True
1065
+ args.append(rewritten_arg)
1066
+ else:
1067
+ args.append(union_arg)
1068
+ if has_symbolic_args:
1069
+ return Symbolic(b.union(*args))
1070
+ elif args_were_rewritten:
1071
+ return b.union(*args)
1072
+ return None
1073
+
1074
+ elif isinstance(expr, b.Fragment):
1075
+ # Only support selects with one item
1076
+ if expr._define or expr._require:
1077
+ raise ValueError(
1078
+ "Solver rewrites do not support fragments with define or require clauses."
1079
+ )
1080
+ if len(expr._select) != 1:
1081
+ raise ValueError(
1082
+ f"Solver rewrites require fragments with exactly one select item, "
1083
+ f"but got {len(expr._select)}."
1084
+ )
1085
+ # TODO(coey): Should this be done with a subcontext (for variable_map)?
1086
+ where = self.rewrite_where(*expr._where)
1087
+ symbolic_select = self.rewrite(expr._select[0])
1088
+ if isinstance(symbolic_select, Symbolic):
1089
+ return Symbolic(b.select(symbolic_select.expr).where(*where))
1090
+ elif symbolic_select is not None:
1091
+ return b.select(symbolic_select).where(*where)
1092
+ return None
1093
+
1094
+ raise NotImplementedError(
1095
+ f"Solver rewrites cannot handle {expr} of type {type(expr).__name__}."
1096
+ )
1097
+
1098
+ # -------------------------------------------------------------------------
1099
+ # Private Helper Methods
1100
+ # -------------------------------------------------------------------------
1101
+
1102
+ def _handle_variable_relationship(self, expr: b.Expression) -> Any:
1103
+ """Create symbolic reference for variable relationship expression."""
1104
+ relationship = expr._op
1105
+ if not isinstance(relationship, b.Relationship):
1106
+ raise TypeError(
1107
+ f"Expected Relationship in variable expression, but got {type(relationship).__name__}."
1108
+ )
1109
+ params = expr._params
1110
+ if len(params) != len(relationship._fields):
1111
+ raise ValueError(
1112
+ f"Parameter count mismatch: Got {len(params)} params "
1113
+ f"but relationship has {len(relationship._fields)} fields."
1114
+ )
1115
+ last_param = params[-1]
1116
+ if isinstance(last_param, b.Alias):
1117
+ last_param = last_param._thing
1118
+ if not isinstance(last_param, (b.Concept, b.Ref)):
1119
+ raise TypeError(
1120
+ f"Last parameter must be a Concept or Ref, but got {type(last_param).__name__}."
1121
+ )
1122
+ # Extract and rewrite field parameters to build the symbolic variable reference
1123
+ # This maps the relationship fields to their grounding values
1124
+ fields = {}
1125
+ for i in range(len(params) - 1):
1126
+ rewritten_param = self.rewrite(params[i])
1127
+ assert not isinstance(rewritten_param, Symbolic)
1128
+ fields[relationship._field_names[i]] = (
1129
+ rewritten_param if rewritten_param is not None else params[i]
1130
+ )
1131
+ # Create new ref corresponding to the decision variable
1132
+ variable_ref = self.solver_model._variable_relationships[relationship].ref()
1133
+ self.variable_map[last_param] = variable_ref
1134
+ # Return new condition to ground the variable
1135
+ return b.where(
1136
+ *[
1137
+ getattr(variable_ref, field_name) == field_value
1138
+ for field_name, field_value in fields.items()
1139
+ ]
1140
+ )
1141
+
1142
+ def _rewrite_nonsymbolic(self, expr: Any) -> Any:
1143
+ """Rewrite expression ensuring non-symbolic result."""
1144
+ new_expr = self.rewrite(expr)
1145
+ if isinstance(new_expr, Symbolic):
1146
+ raise ValueError(
1147
+ f"Internal error. Non-symbolic rewrite unexpectedly returned Symbolic for {expr}."
1148
+ )
1149
+ return expr if new_expr is None else new_expr
1150
+
1151
+ def _get_variable_ref(
1152
+ self, relationship: b.Relationship, parent_producer: b.Producer | None
1153
+ ) -> Optional[Any]:
1154
+ """Get variable reference for relationship, or None if not a solver variable."""
1155
+ # Check if this relationship corresponds to a decision variable
1156
+ VariableConcept = self.solver_model._variable_relationships.get(relationship)
1157
+ if VariableConcept is None:
1158
+ return None
515
1159
 
1160
+ properties = {}
1161
+ if parent_producer is not None:
1162
+ properties[relationship._field_names[0]] = parent_producer
1163
+ return VariableConcept(VariableConcept.ref(), **properties)