relationalai 0.12.2__py3-none-any.whl → 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -261,20 +261,7 @@ class SolverModel:
261
261
  remaining_timeout_minutes = calc_remaining_timeout_minutes(
262
262
  start_time, query_timeout_mins, config_file_path=config_file_path
263
263
  )
264
- try:
265
- job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
266
- except Exception as e:
267
- err_message = str(e).lower()
268
- if isinstance(e, ResponseStatusException):
269
- err_message = e.response.json().get("message", "")
270
- if any(kw in err_message.lower() for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS):
271
- solver._auto_create_solver_async()
272
- remaining_timeout_minutes = calc_remaining_timeout_minutes(
273
- start_time, query_timeout_mins, config_file_path=config_file_path
274
- )
275
- job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
276
- else:
277
- raise e
264
+ job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
278
265
 
279
266
  # 3. Extract result.
280
267
  remaining_timeout_minutes = calc_remaining_timeout_minutes(
@@ -660,12 +647,24 @@ class Solver:
660
647
  if self.engine is None:
661
648
  raise Exception("Engine not initialized.")
662
649
 
663
- # Make sure the engine is ready.
664
- if self.engine["state"] != "READY":
665
- poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
666
-
667
650
  with debugging.span("job") as job_span:
668
- job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
651
+ # Retry logic. If creating a job fails with an engine
652
+ # related error we will create/resume/... the engine and
653
+ # retry.
654
+ try:
655
+ job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
656
+ except Exception as e:
657
+ err_message = str(e).lower()
658
+ if isinstance(e, ResponseStatusException):
659
+ err_message = e.response.json().get("message", "")
660
+ if any(kw in err_message.lower() for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS):
661
+ self._auto_create_solver_async()
662
+ # Wait until the engine is ready.
663
+ poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
664
+ job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
665
+ else:
666
+ raise e
667
+
669
668
  job_span["job_id"] = job_id
670
669
  debugging.event("job_created", job_span, job_id=job_id, engine_name=self.engine["name"], job_type=ENGINE_TYPE_SOLVER)
671
670
  if not isinstance(job_id, str):
@@ -300,6 +300,27 @@ class LQPExecutor(e.Executor):
300
300
  meta=None,
301
301
  )
302
302
 
303
+ def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
304
+ fragment_ids = []
305
+
306
+ for write in query_epoch.writes:
307
+ if isinstance(write.write_type, lqp_ir.Define):
308
+ fragment_ids.append(write.write_type.fragment.id)
309
+
310
+ # Construct new Epoch with Undefine operations for all collected fragment IDs
311
+ undefine_writes = [
312
+ lqp_ir.Write(
313
+ write_type=lqp_ir.Undefine(fragment_id=frag_id, meta=None),
314
+ meta=None
315
+ )
316
+ for frag_id in fragment_ids
317
+ ]
318
+
319
+ return lqp_ir.Epoch(
320
+ writes=undefine_writes,
321
+ meta=None,
322
+ )
323
+
303
324
  def compile_lqp(self, model: ir.Model, task: ir.Task):
304
325
  configure = self._construct_configure()
305
326
 
@@ -334,7 +355,11 @@ class LQPExecutor(e.Executor):
334
355
  if model_txn is not None:
335
356
  epochs.append(model_txn.epochs[0])
336
357
 
337
- epochs.append(query_txn.epochs[0])
358
+ query_txn_epoch = query_txn.epochs[0]
359
+
360
+ epochs.append(query_txn_epoch)
361
+
362
+ epochs.append(self._compile_undefine_query(query_txn_epoch))
338
363
 
339
364
  txn = lqp_ir.Transaction(epochs=epochs, configure=configure, meta=None)
340
365
 
@@ -265,10 +265,12 @@ class ExtractCommon(Pass):
265
265
  for child in common_body:
266
266
  body_output_vars.update(ctx.info.task_outputs(child))
267
267
 
268
- # Compute the union of input vars across all composites, intersected with output
268
+ # Compute the union of input vars across all non-extracted tasks (basically
269
+ # composites and binders left behind), intersected with output
269
270
  # vars of the common body
270
271
  exposed_vars = OrderedSet.from_iterable(ctx.info.task_inputs(sample)) & body_output_vars
271
- for composite in composites:
272
+ non_extracted_tasks = (binders - common_body) | composites
273
+ for composite in non_extracted_tasks:
272
274
  if composite is sample:
273
275
  continue
274
276
  # compute common input vars
@@ -112,12 +112,12 @@ log10 = f.relation(
112
112
 
113
113
  log = f.relation(
114
114
  "log",
115
- [f.input_field("a", types.Number), f.input_field("b", types.Number), f.field("c", types.Float)],
115
+ [f.input_field("base", types.Number), f.input_field("value", types.Number), f.field("result", types.Float)],
116
116
  overloads=[
117
- f.relation("log", [f.input_field("a", types.Int64), f.input_field("b", types.Int64), f.field("c", types.Float)]),
118
- f.relation("log", [f.input_field("a", types.Int128), f.input_field("b", types.Int128), f.field("c", types.Float)]),
119
- f.relation("log", [f.input_field("a", types.Float), f.input_field("b", types.Float), f.field("c", types.Float)]),
120
- f.relation("log", [f.input_field("a", types.GenericDecimal), f.input_field("b", types.GenericDecimal), f.field("c", types.Float)]),
117
+ f.relation("log", [f.input_field("base", types.Int64), f.input_field("value", types.Int64), f.field("result", types.Float)]),
118
+ f.relation("log", [f.input_field("base", types.Int128), f.input_field("value", types.Int128), f.field("result", types.Float)]),
119
+ f.relation("log", [f.input_field("base", types.Float), f.input_field("value", types.Float), f.field("result", types.Float)]),
120
+ f.relation("log", [f.input_field("base", types.GenericDecimal), f.input_field("value", types.GenericDecimal), f.field("result", types.Float)]),
121
121
 
122
122
  ],
123
123
  )
@@ -496,7 +496,7 @@ function = f.relation("function", [f.input_field("code", types.Symbol)])
496
496
  function_checked_annotation = f.annotation(function, [f.lit("checked")])
497
497
  function_annotation = f.annotation(function, [])
498
498
 
499
- # Indicates this relation should be tracked in telemetry. Only supported for Relationships.
499
+ # Indicates this relation should be tracked in telemetry. Supported for Relationships and Concepts.
500
500
  # `RAI_BackIR.with_relation_tracking` produces log messages at the start and end of each
501
501
  # SCC evaluation, if any declarations bear the `track` annotation.
502
502
  track = f.relation("track", [
@@ -627,11 +627,13 @@ class BindingAnalysis(visitor.Visitor):
627
627
  x, y = child.args[0], child.args[1]
628
628
  if isinstance(x, ir.Var) and not isinstance(y, ir.Var):
629
629
  grounds.add(x)
630
+ self.output(child, x)
630
631
  elif not isinstance(x, ir.Var) and isinstance(y, ir.Var):
631
632
  grounds.add(y)
633
+ self.output(child, y)
632
634
  elif isinstance(x, ir.Var) and isinstance(y, ir.Var):
633
635
  # mark as potentially grounded, if any is grounded in other atoms then we later ground both
634
- potentially_grounded.add((x, y))
636
+ potentially_grounded.add((child, x, y))
635
637
  else:
636
638
  # grounds only outputs
637
639
  for idx, f in enumerate(child.relation.fields):
@@ -654,23 +656,33 @@ class BindingAnalysis(visitor.Visitor):
654
656
  # grounds the output var
655
657
  grounds.add(child.id_var)
656
658
 
659
+ # add child hoisted vars to grounded so that they can be picked up by the children
660
+ for child in node.body:
661
+ if isinstance(child, helpers.COMPOSITES):
662
+ grounds.update(helpers.hoisted_vars(child.hoisted))
663
+
664
+ # equalities where both sides are already grounded mean that both sides are input
665
+ for child, x, y in potentially_grounded:
666
+ if x in grounds and y in grounds:
667
+ self.input(child, x)
668
+ self.input(child, y)
669
+
657
670
  # deal with potentially grounded vars up to a fixpoint
658
671
  changed = True
659
672
  while changed:
660
673
  changed = False
661
- for x, y in potentially_grounded:
674
+ for child, x, y in potentially_grounded:
662
675
  if x in grounds and y not in grounds:
676
+ self.input(child, x)
677
+ self.output(child, y)
663
678
  grounds.add(y)
664
679
  changed = True
665
680
  elif y in grounds and x not in grounds:
681
+ self.input(child, y)
682
+ self.output(child, x)
666
683
  grounds.add(x)
667
684
  changed = True
668
685
 
669
- # add child hoisted vars to grounded so that they can be picked up by the children
670
- for child in node.body:
671
- if isinstance(child, helpers.COMPOSITES):
672
- grounds.update(helpers.hoisted_vars(child.hoisted))
673
-
674
686
  # now visit the children
675
687
  self._grounded.append(grounds)
676
688
  super().visit_logical(node, parent)
@@ -765,20 +777,28 @@ class BindingAnalysis(visitor.Visitor):
765
777
  self.output(node, arg)
766
778
 
767
779
  if builtins.is_eq(node.relation):
768
- # special case eq because it can be input or output
769
- x, y = node.args[0], node.args[1]
770
- if isinstance(x, ir.Var) and not isinstance(y, ir.Var):
771
- self.output(node, x)
772
- elif not isinstance(x, ir.Var) and isinstance(y, ir.Var):
773
- self.output(node, y)
774
- elif isinstance(x, ir.Var) and isinstance(y, ir.Var):
775
- # in this case it's possible that both are outputs (if both are grounded), else both are inputs
776
- grounds = self._grounded[-1] if self._grounded else ordered_set()
777
- for var in (x, y):
778
- if var in grounds:
779
- self.output(node, var)
780
+ # Most cases are covered already at the parent level if the equality is part of
781
+ # a Logical. The remaining cases are when the equality is a child of a
782
+ # non-Logical.
783
+ if self.info.task_inputs(node) or self.info.task_outputs(node):
784
+ # already covered
785
+ pass
786
+ else:
787
+ x, y = node.args[0], node.args[1]
788
+ if isinstance(x, ir.Var) and not isinstance(y, ir.Var):
789
+ self.output(node, x)
790
+ elif not isinstance(x, ir.Var) and isinstance(y, ir.Var):
791
+ self.output(node, y)
792
+ elif isinstance(x, ir.Var) and isinstance(y, ir.Var):
793
+ grounds = self._grounded[-1] if self._grounded else ordered_set()
794
+ if x in grounds:
795
+ self.input(node, x)
796
+ else:
797
+ self.output(node, x)
798
+ if y in grounds:
799
+ self.input(node, y)
780
800
  else:
781
- self.input(node, var)
801
+ self.output(node, y)
782
802
  else:
783
803
  # register variables depending on the input flag of the relation bound to the lookup
784
804
  for idx, f in enumerate(node.relation.fields):
@@ -348,12 +348,13 @@ def clone_task(task: T) -> T:
348
348
  # if no childrean were stacked, we rewrote all fields of curr, so we can pop it and rewrite it
349
349
  if not stacked_children:
350
350
  stack.pop()
351
- children = []
352
- for f in curr_fields:
353
- children.append(from_cache(getattr(curr, f.name)))
354
- # create a new prev_node with the cloned children
355
- prev_node = curr.__class__(*children)
356
- cache[curr.id] = prev_node
351
+ if curr.id not in cache:
352
+ children = []
353
+ for f in curr_fields:
354
+ children.append(from_cache(getattr(curr, f.name)))
355
+ # create a new prev_node with the cloned children
356
+ prev_node = curr.__class__(*children)
357
+ cache[curr.id] = prev_node
357
358
 
358
359
  # the last node we processed is the rewritten original node
359
360
  assert(isinstance(prev_node, type(task)))
@@ -58,10 +58,7 @@ class LogicalExtractor(Rewriter):
58
58
  # compute the vars to be exposed by the extracted logical; those are keys (what
59
59
  # makes the values unique) + the values (the hoisted variables)
60
60
  exposed_vars = ordered_set()
61
- # start with all the inputs to this logical as keys
62
- # TODO - in the future we can analyze better these inputs to see if we can drop
63
- # some and have narrower intermediate relations
64
- exposed_vars.update(self.info.task_inputs(logical))
61
+
65
62
  # if there are aggregations, make sure we don't expose the projected and input vars,
66
63
  # but expose groupbys
67
64
  for child in node.body: