relationalai 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. relationalai/clients/client.py +3 -4
  2. relationalai/clients/exec_txn_poller.py +62 -31
  3. relationalai/clients/resources/snowflake/direct_access_resources.py +6 -5
  4. relationalai/clients/resources/snowflake/snowflake.py +54 -51
  5. relationalai/clients/resources/snowflake/use_index_poller.py +1 -1
  6. relationalai/semantics/internal/snowflake.py +5 -1
  7. relationalai/semantics/lqp/algorithms.py +173 -0
  8. relationalai/semantics/lqp/builtins.py +199 -2
  9. relationalai/semantics/lqp/executor.py +90 -41
  10. relationalai/semantics/lqp/export_rewriter.py +40 -0
  11. relationalai/semantics/lqp/ir.py +28 -2
  12. relationalai/semantics/lqp/model2lqp.py +218 -45
  13. relationalai/semantics/lqp/passes.py +13 -658
  14. relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  15. relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  16. relationalai/semantics/lqp/rewrite/annotate_constraints.py +22 -10
  17. relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  18. relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  19. relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  20. relationalai/semantics/lqp/rewrite/functional_dependencies.py +31 -2
  21. relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  22. relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  23. relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  24. relationalai/semantics/lqp/utils.py +11 -1
  25. relationalai/semantics/lqp/validators.py +14 -1
  26. relationalai/semantics/metamodel/builtins.py +2 -1
  27. relationalai/semantics/metamodel/compiler.py +2 -1
  28. relationalai/semantics/metamodel/dependency.py +12 -3
  29. relationalai/semantics/metamodel/executor.py +11 -1
  30. relationalai/semantics/metamodel/factory.py +2 -2
  31. relationalai/semantics/metamodel/helpers.py +7 -0
  32. relationalai/semantics/metamodel/ir.py +3 -2
  33. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  34. relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  35. relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  36. relationalai/semantics/metamodel/typer/checker.py +6 -4
  37. relationalai/semantics/metamodel/typer/typer.py +2 -5
  38. relationalai/semantics/metamodel/visitor.py +4 -3
  39. relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  40. relationalai/semantics/reasoners/optimization/solvers_pb.py +3 -4
  41. relationalai/semantics/rel/compiler.py +2 -1
  42. relationalai/semantics/rel/executor.py +3 -2
  43. relationalai/semantics/tests/lqp/__init__.py +0 -0
  44. relationalai/semantics/tests/lqp/algorithms.py +345 -0
  45. relationalai/semantics/tests/test_snapshot_abstract.py +2 -1
  46. relationalai/tools/cli_controls.py +216 -67
  47. relationalai/util/format.py +5 -2
  48. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/METADATA +2 -2
  49. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/RECORD +52 -42
  50. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/WHEEL +0 -0
  51. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/entry_points.txt +0 -0
  52. {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/licenses/LICENSE +0 -0
@@ -614,7 +614,6 @@ class Client():
614
614
  self._timed_query(
615
615
  "update_registry",
616
616
  dependencies.generate_update_registry(),
617
- readonly=False,
618
617
  abort_on_error=False,
619
618
  )
620
619
 
@@ -623,7 +622,6 @@ class Client():
623
622
  self._timed_query(
624
623
  "update_packages",
625
624
  dependencies.generate_update_packages(),
626
- readonly=False,
627
625
  abort_on_error=False,
628
626
  )
629
627
  else:
@@ -646,10 +644,11 @@ class Client():
646
644
  finally:
647
645
  self._database = database_name
648
646
 
649
- def _timed_query(self, span_name:str, code: str, readonly=True, abort_on_error=True):
647
+ def _timed_query(self, span_name:str, code: str, abort_on_error=True):
650
648
  with debugging.span(span_name, model=self._database) as end_span:
651
649
  start = time.perf_counter()
652
- res, raw = self._query(code, None, end_span, readonly=readonly, abort_on_error=abort_on_error)
650
+ # NOTE hardcoding to readonly=False, read-only Rel transactions are deprecated.
651
+ res, raw = self._query(code, None, end_span, readonly=False, abort_on_error=abort_on_error)
653
652
  debugging.time(span_name, time.perf_counter() - start, code=code)
654
653
  return res, raw
655
654
 
@@ -27,16 +27,43 @@ class ExecTxnPoller:
27
27
 
28
28
  def __init__(
29
29
  self,
30
+ print_txn_progress: bool,
30
31
  resource: "Resources",
31
- txn_id: str,
32
+ txn_id: Optional[str] = None,
32
33
  headers: Optional[Dict] = None,
33
34
  txn_start_time: Optional[float] = None,
34
35
  ):
36
+ self.print_txn_progress = print_txn_progress
35
37
  self.res = resource
36
38
  self.txn_id = txn_id
37
39
  self.headers = headers or {}
38
40
  self.txn_start_time = txn_start_time or time.time()
39
41
 
42
+ def __enter__(self) -> ExecTxnPoller:
43
+ if not self.print_txn_progress:
44
+ return self
45
+ self.progress = create_progress(
46
+ description=lambda: self.description_with_timing(),
47
+ success_message="", # We'll handle this ourselves
48
+ leading_newline=False,
49
+ trailing_newline=False,
50
+ show_duration_summary=False,
51
+ )
52
+ self.progress.__enter__()
53
+ return self
54
+
55
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
56
+ if not self.print_txn_progress or self.txn_id is None:
57
+ return
58
+ # Update to success message with duration
59
+ total_duration = time.time() - self.txn_start_time
60
+ txn_id = self.txn_id
61
+ self.progress.update_main_status(
62
+ query_complete_message(txn_id, total_duration)
63
+ )
64
+ self.progress.__exit__(exc_type, exc_value, traceback)
65
+ return
66
+
40
67
  def poll(self) -> bool:
41
68
  """
42
69
  Poll for transaction completion with interactive progress display.
@@ -44,48 +71,52 @@ class ExecTxnPoller:
44
71
  Returns:
45
72
  True if transaction completed successfully, False otherwise
46
73
  """
74
+ if not self.txn_id:
75
+ raise ValueError("Transaction ID must be provided for polling.")
76
+ else:
77
+ txn_id = self.txn_id
78
+
79
+ if self.print_txn_progress:
80
+ # Update the main status to include the new txn_id
81
+ self.progress.update_main_status_fn(
82
+ lambda: self.description_with_timing(txn_id),
83
+ )
47
84
 
48
85
  # Don't show duration summary - we handle our own completion message
49
- with create_progress(
50
- description="Evaluating Query...",
51
- success_message="", # We'll handle this ourselves
52
- leading_newline=False,
53
- trailing_newline=False,
54
- show_duration_summary=False,
55
- ) as progress:
56
- def check_status() -> bool:
57
- """Check if transaction is complete."""
58
- elapsed = time.time() - self.txn_start_time
59
- # Update the main status with elapsed time
60
- progress.update_main_status(
61
- query_progress_message(self.txn_id, elapsed)
62
- )
63
- return self.res._check_exec_async_status(self.txn_id, headers=self.headers)
64
-
65
- with debugging.span("wait", txn_id=self.txn_id):
66
- poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
67
-
68
- # Calculate final duration
69
- total_duration = time.time() - self.txn_start_time
70
-
71
- # Update to success message with duration
72
- progress.update_main_status(
73
- query_complete_message(self.txn_id, total_duration)
74
- )
86
+ def check_status() -> bool:
87
+ """Check if transaction is complete."""
88
+ finished = self.res._check_exec_async_status(txn_id, headers=self.headers)
89
+ return finished
90
+
91
+ with debugging.span("wait", txn_id=self.txn_id):
92
+ poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
93
+
75
94
 
76
95
  return True
77
96
 
97
+ def description_with_timing(self, txn_id: str | None = None) -> str:
98
+ elapsed = time.time() - self.txn_start_time
99
+ if txn_id is None:
100
+ return query_progress_header(elapsed)
101
+ else:
102
+ return query_progress_message(txn_id, elapsed)
103
+
104
+ def query_progress_header(duration: float) -> str:
105
+ # Don't print sub-second decimals, because it updates too fast and is distracting.
106
+ duration_str = format_duration(duration, seconds_decimals=False)
107
+ return f"Evaluating Query... {duration_str:>15}\n"
108
+
78
109
  def query_progress_message(id: str, duration: float) -> str:
79
110
  return (
111
+ query_progress_header(duration) +
80
112
  # Print with whitespace to align with the end of the transaction ID
81
- f"Evaluating Query... {format_duration(duration):>18}\n" +
82
- f"{GRAY_COLOR}Query: {id}{ENDC}"
113
+ f"{GRAY_COLOR}ID: {id}{ENDC}"
83
114
  )
84
115
 
85
116
  def query_complete_message(id: str, duration: float, status_header: bool = False) -> str:
86
117
  return (
87
118
  (f"{GREEN_COLOR}✅ " if status_header else "") +
88
119
  # Print with whitespace to align with the end of the transaction ID
89
- f"Query Complete: {format_duration(duration):>24}\n" +
90
- f"{GRAY_COLOR}Query: {id}{ENDC}"
120
+ f"Query Complete: {format_duration(duration):>21}\n" +
121
+ f"{GRAY_COLOR}ID: {id}{ENDC}"
91
122
  )
@@ -13,7 +13,7 @@ from ...config import Config, ConfigStore, ENDPOINT_FILE
13
13
  from ...direct_access_client import DirectAccessClient
14
14
  from ...types import EngineState
15
15
  from ...util import get_pyrel_version, safe_json_loads, ms_to_timestamp
16
- from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
16
+ from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException, RAIException
17
17
  from snowflake.snowpark import Session
18
18
 
19
19
  # Import UseIndexResources to enable use_index functionality with direct access
@@ -163,11 +163,12 @@ class DirectAccessResources(UseIndexResources):
163
163
  headers=headers,
164
164
  )
165
165
  response = _send_request()
166
- except requests.exceptions.ConnectionError as e:
166
+ except (requests.exceptions.ConnectionError, RAIException) as e:
167
167
  messages = collect_error_messages(e)
168
- if any("nameresolutionerror" in msg for msg in messages):
169
- # when we can not resolve the service endpoint, we assume it is outdated
170
- # hence, we try to retrieve it again and query again.
168
+ if any("nameresolutionerror" in msg for msg in messages) or \
169
+ any("could not find the service associated with endpoint" in msg for msg in messages):
170
+ # when we can not resolve the service endpoint or the service is not found,
171
+ # we assume the endpoint is outdated, so we retrieve it again and retry.
171
172
  self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
172
173
  enforce_update=True,
173
174
  )
@@ -15,7 +15,7 @@ import hashlib
15
15
  from dataclasses import dataclass
16
16
 
17
17
  from ....auth.token_handler import TokenHandler
18
- from relationalai.clients.exec_txn_poller import ExecTxnPoller, query_complete_message
18
+ from relationalai.clients.exec_txn_poller import ExecTxnPoller
19
19
  import snowflake.snowpark
20
20
 
21
21
  from ....rel_utils import sanitize_identifier, to_fqn_relation_name
@@ -105,6 +105,9 @@ TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
105
105
  GUARDRAILS_ABORT_REASON = "guard rail violation"
106
106
 
107
107
  PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
108
+ ENABLE_GUARD_RAILS_FLAG = "enable_guard_rails"
109
+
110
+ ENABLE_GUARD_RAILS_HEADER = "X-RAI-Enable-Guard-Rails"
108
111
 
109
112
  #--------------------------------------------------
110
113
  # Helpers
@@ -113,6 +116,9 @@ PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
113
116
  def should_print_txn_progress(config) -> bool:
114
117
  return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))
115
118
 
119
+ def should_enable_guard_rails(config) -> bool:
120
+ return bool(config.get(ENABLE_GUARD_RAILS_FLAG, False))
121
+
116
122
  #--------------------------------------------------
117
123
  # Resources
118
124
  #--------------------------------------------------
@@ -1788,65 +1794,62 @@ Otherwise, remove it from your '{profile}' configuration profile.
1788
1794
 
1789
1795
  with debugging.span("transaction", **query_attrs_dict) as txn_span:
1790
1796
  txn_start_time = time.time()
1791
- with debugging.span("create_v2", **query_attrs_dict) as create_span:
1792
- # Prepare headers for transaction creation
1793
- request_headers['user-agent'] = get_pyrel_version(self.generation)
1794
- request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
1795
- request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
1796
-
1797
- # Create the transaction
1798
- result = self._create_v2_txn(
1799
- database=database,
1800
- engine=engine,
1801
- raw_code=raw_code,
1802
- inputs=inputs,
1803
- headers=request_headers,
1804
- readonly=readonly,
1805
- nowait_durable=nowait_durable,
1806
- bypass_index=bypass_index,
1807
- language=language,
1808
- query_timeout_mins=query_timeout_mins,
1809
- )
1797
+ print_txn_progress = should_print_txn_progress(self.config)
1810
1798
 
1811
- txn_id = result.txn_id
1812
- state = result.state
1799
+ with ExecTxnPoller(
1800
+ print_txn_progress=print_txn_progress,
1801
+ resource=self, txn_id=None, headers=request_headers,
1802
+ txn_start_time=txn_start_time
1803
+ ) as poller:
1804
+ with debugging.span("create_v2", **query_attrs_dict) as create_span:
1805
+ # Prepare headers for transaction creation
1806
+ request_headers['user-agent'] = get_pyrel_version(self.generation)
1807
+ request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
1808
+ request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
1809
+ request_headers[ENABLE_GUARD_RAILS_HEADER] = str(should_enable_guard_rails(self.config))
1810
+
1811
+ # Create the transaction
1812
+ result = self._create_v2_txn(
1813
+ database=database,
1814
+ engine=engine,
1815
+ raw_code=raw_code,
1816
+ inputs=inputs,
1817
+ headers=request_headers,
1818
+ readonly=readonly,
1819
+ nowait_durable=nowait_durable,
1820
+ bypass_index=bypass_index,
1821
+ language=language,
1822
+ query_timeout_mins=query_timeout_mins,
1823
+ )
1813
1824
 
1814
- txn_span["txn_id"] = txn_id
1815
- create_span["txn_id"] = txn_id
1816
- debugging.event("transaction_created", txn_span, txn_id=txn_id)
1825
+ txn_id = result.txn_id
1826
+ state = result.state
1817
1827
 
1818
- print_txn_progress = should_print_txn_progress(self.config)
1828
+ txn_span["txn_id"] = txn_id
1829
+ create_span["txn_id"] = txn_id
1830
+ debugging.event("transaction_created", txn_span, txn_id=txn_id)
1819
1831
 
1820
- # fast path: transaction already finished
1821
- if state in ["COMPLETED", "ABORTED"]:
1822
- txn_end_time = time.time()
1823
- if txn_id in self._pending_transactions:
1824
- self._pending_transactions.remove(txn_id)
1832
+ # Set the transaction ID now that we have it, to update the progress text
1833
+ poller.txn_id = txn_id
1825
1834
 
1826
- artifact_info = result.artifact_info
1835
+ # fast path: transaction already finished
1836
+ if state in ["COMPLETED", "ABORTED"]:
1837
+ if txn_id in self._pending_transactions:
1838
+ self._pending_transactions.remove(txn_id)
1827
1839
 
1828
- txn_duration = txn_end_time - txn_start_time
1829
- if print_txn_progress:
1830
- print(
1831
- query_complete_message(txn_id, txn_duration, status_header=True)
1832
- )
1840
+ artifact_info = result.artifact_info
1833
1841
 
1834
- # Slow path: transaction not done yet; start polling
1835
- else:
1836
- self._pending_transactions.append(txn_id)
1837
- # Use the interactive poller for transaction status
1838
- with debugging.span("wait", txn_id=txn_id):
1839
- if print_txn_progress:
1840
- poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
1842
+ # Slow path: transaction not done yet; start polling
1843
+ else:
1844
+ self._pending_transactions.append(txn_id)
1845
+ # Use the interactive poller for transaction status
1846
+ with debugging.span("wait", txn_id=txn_id):
1841
1847
  poller.poll()
1842
- else:
1843
- poll_with_specified_overhead(
1844
- lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
1845
- )
1846
- artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
1847
1848
 
1848
- with debugging.span("fetch"):
1849
- return self._download_results(artifact_info, txn_id, state)
1849
+ artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
1850
+
1851
+ with debugging.span("fetch"):
1852
+ return self._download_results(artifact_info, txn_id, state)
1850
1853
 
1851
1854
  def get_user_based_engine_name(self):
1852
1855
  if not self._session:
@@ -66,7 +66,7 @@ ERP_CHECK_FREQUENCY = 15
66
66
 
67
67
  # Polling behavior constants
68
68
  POLL_OVERHEAD_RATE = 0.1 # Overhead rate for exponential backoff
69
- POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
69
+ POLL_MAX_DELAY = 0 # Maximum delay between polls in seconds
70
70
 
71
71
  # SQL query template for getting stream column hashes
72
72
  # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
@@ -214,6 +214,7 @@ class Table():
214
214
  self._col_names = cols
215
215
  self._iceberg_config = config
216
216
  self._is_iceberg = config is not None
217
+ self._skip_cdc = False
217
218
  info = self._schemas.get((self._database, self._schema))
218
219
  if not info:
219
220
  info = self._schemas[(self._database, self._schema)] = SchemaInfo(self._database, self._schema)
@@ -303,7 +304,10 @@ class Table():
303
304
 
304
305
  def _compile_lookup(self, compiler:b.Compiler, ctx:b.CompilerContext):
305
306
  self._lazy_init()
306
- Table._used_sources.add(self)
307
+ if not self._skip_cdc:
308
+ # Don't do CDC if the underlying data has been loaded
309
+ # directly via `api.load_data`.
310
+ Table._used_sources.add(self)
307
311
  compiler.lookup(self._rel, ctx)
308
312
  return compiler.lookup(b.RelationshipFieldRef(None, self._rel, 0), ctx)
309
313
 
@@ -0,0 +1,173 @@
1
+ from typing import TypeGuard
2
+ from relationalai.semantics.metamodel import ir, factory, types
3
+ from relationalai.semantics.metamodel.visitor import Rewriter, collect_by_type
4
+ from relationalai.semantics.lqp import ir as lqp
5
+ from relationalai.semantics.lqp.types import meta_type_to_lqp
6
+ from relationalai.semantics.lqp.builtins import (
7
+ has_empty_annotation, has_assign_annotation, has_upsert_annotation,
8
+ has_monoid_annotation, has_monus_annotation, has_script_annotation,
9
+ has_algorithm_annotation, has_while_annotation, global_annotation,
10
+ empty_annotation, assign_annotation, upsert_annotation, monoid_annotation,
11
+ monus_annotation
12
+ )
13
+
14
+ # Complex tests for Loopy constructs in the metamodel
15
def is_script(task: ir.Task) -> TypeGuard[ir.Sequence]:
    """True when `task` is a script: a Sequence carrying the @script annotation."""
    return isinstance(task, ir.Sequence) and has_script_annotation(task)
20
+
21
def is_algorithm_logical(task: ir.Task) -> TypeGuard[ir.Logical]:
    """True when `task` is a Logical whose every subtask is an algorithm script."""
    if not isinstance(task, ir.Logical):
        return False
    for subtask in task.body:
        if not is_algorithm_script(subtask):
            return False
    return True
27
+
28
def is_algorithm_script(task: ir.Task) -> TypeGuard[ir.Sequence]:
    """True when `task` is a Sequence annotated with both @script and @algorithm."""
    return (
        isinstance(task, ir.Sequence)
        and is_script(task)
        and has_algorithm_annotation(task)
    )
33
+
34
def is_while_loop(task: ir.Task) -> TypeGuard[ir.Loop]:
    """True when `task` is a while loop: a Loop carrying the @while annotation."""
    return isinstance(task, ir.Loop) and has_while_annotation(task)
39
+
40
def is_while_script(task: ir.Task) -> TypeGuard[ir.Sequence]:
    """True when `task` is a Sequence annotated with both @script and @while."""
    return (
        isinstance(task, ir.Sequence)
        and is_script(task)
        and has_while_annotation(task)
    )
45
+
46
+ # Tools for annotating Loopy constructs
47
class LoopyAnnoAdder(Rewriter):
    """Rewriter that appends one fixed annotation to every Update node it visits."""

    def __init__(self, anno: ir.Annotation):
        # Remember the annotation to append before delegating to Rewriter setup,
        # mirroring the order used elsewhere in this module.
        self.anno = anno
        super().__init__()

    def handle_update(self, node: ir.Update, parent: ir.Node) -> ir.Update:
        # Rebuild the Update with the extra annotation appended; all other
        # fields are carried over unchanged.
        annos = [*node.annotations, self.anno]
        return factory.update(node.relation, node.args, node.effect, annos, node.engine)
56
+
57
def mk_global(i: ir.Node):
    """Add the @global annotation to every Update inside `i`."""
    adder = LoopyAnnoAdder(global_annotation())
    return adder.walk(i)
59
+
60
def mk_empty(i: ir.Node):
    """Add the @empty annotation to every Update inside `i`."""
    adder = LoopyAnnoAdder(empty_annotation())
    return adder.walk(i)
62
+
63
def mk_assign(i: ir.Node):
    """Add the @assign annotation to every Update inside `i`."""
    adder = LoopyAnnoAdder(assign_annotation())
    return adder.walk(i)
65
+
66
def mk_upsert(i: ir.Node, arity: int):
    """Add the @upsert annotation (with the given key arity) to every Update inside `i`."""
    adder = LoopyAnnoAdder(upsert_annotation(arity))
    return adder.walk(i)
68
+
69
def mk_monoid(i: ir.Node, monoid_type: ir.ScalarType, monoid_op: str, arity: int):
    """Add the @monoid annotation (type, op, arity) to every Update inside `i`."""
    adder = LoopyAnnoAdder(monoid_annotation(monoid_type, monoid_op, arity))
    return adder.walk(i)
71
+
72
def mk_monus(i: ir.Node, monoid_type: ir.ScalarType, monoid_op: str, arity: int):
    """Add the @monus annotation (type, op, arity) to every Update inside `i`."""
    adder = LoopyAnnoAdder(monus_annotation(monoid_type, monoid_op, arity))
    return adder.walk(i)
74
+
75
def construct_monoid(i: ir.Annotation):
    """Build the LQP monoid described by a @monoid/@monus annotation.

    The annotation's args are scanned for a ScalarType (the monoid's element
    type, converted via `meta_type_to_lqp`) and a String literal naming the
    operation: "or", "sum", "min", or "max" (case-insensitive).

    Raises:
        AssertionError: if the type or op is missing, or the op is unknown.
    """
    base_type = None
    op = None
    for arg in i.args:
        if isinstance(arg, ir.ScalarType):
            base_type = meta_type_to_lqp(arg)
        elif isinstance(arg, ir.Literal) and arg.type == types.String:
            op = arg.value
    # Explicit raise (not a bare `assert`) so validation survives `python -O`.
    if not (isinstance(base_type, lqp.Type) and isinstance(op, str)):
        raise AssertionError("Failed to get monoid")
    op_name = op.lower()  # lower once instead of once per branch
    if op_name == "or":
        return lqp.OrMonoid(meta=None)
    if op_name == "sum":
        return lqp.SumMonoid(type=base_type, meta=None)
    if op_name == "min":
        return lqp.MinMonoid(type=base_type, meta=None)
    if op_name == "max":
        return lqp.MaxMonoid(type=base_type, meta=None)
    raise AssertionError("Failed to get monoid")
94
+
95
+ # Tools for analyzing Loopy constructs
96
def is_logical_instruction(node: ir.Node) -> TypeGuard[ir.Logical]:
    """True when `node` is a Logical that contains Updates but no nested Sequence."""
    if not isinstance(node, ir.Logical):
        return False
    has_update = any(collect_by_type(ir.Update, node))
    has_sequence = any(collect_by_type(ir.Sequence, node))
    return has_update and not has_sequence
100
+
101
def get_instruction_body_rels(node: ir.Logical) -> set[ir.Relation]:
    """Relations read by the instruction: one per Lookup found under `node`."""
    assert is_logical_instruction(node)
    # NOTE: the loop variable is a Lookup node (the original misleadingly
    # called it `update`).
    return {lookup.relation for lookup in collect_by_type(ir.Lookup, node)}
107
+
108
def get_instruction_head_rels(node: ir.Logical) -> set[ir.Relation]:
    """Relations written by the instruction: one per Update found under `node`."""
    assert is_logical_instruction(node)
    return {update.relation for update in collect_by_type(ir.Update, node)}
114
+
115
+ # base Loopy instruction: @empty, @assign, @upsert, @monoid, @monus
116
def is_instruction(update: ir.Task) -> TypeGuard[ir.Logical]:
    """True for a base Loopy instruction: a logical instruction containing at
    least one Update annotated with @empty, @assign, @upsert, @monoid or @monus."""
    if not is_logical_instruction(update):
        return False
    return any(
        has_empty_annotation(u)
        or has_assign_annotation(u)
        or has_upsert_annotation(u)
        or has_monoid_annotation(u)
        or has_monus_annotation(u)
        for u in collect_by_type(ir.Update, update)
    )
127
+
128
+ # update Loopy instruction @upsert, @monoid, @monus
129
def is_update_instruction(task: ir.Task) -> TypeGuard[ir.Logical]:
    """True for an update Loopy instruction: a logical instruction containing
    at least one Update annotated with @upsert, @monoid or @monus."""
    if not is_logical_instruction(task):
        return False
    return any(
        has_upsert_annotation(u)
        or has_monoid_annotation(u)
        or has_monus_annotation(u)
        for u in collect_by_type(ir.Update, task)
    )
138
+
139
def is_empty_instruction(node: ir.Node) -> TypeGuard[ir.Logical]:
    """Check if `node` is an empty Loopy instruction, `empty rel = ∅`."""
    if not is_logical_instruction(node):
        return False
    updates = collect_by_type(ir.Update, node)
    if not any(has_empty_annotation(u) for u in updates):
        return False

    # Prerequisites hold; now enforce that the instruction is well-formed:
    #   1. exactly one Update operation (the @empty one)
    #   2. nothing else in the instruction body
    assert len(updates) == 1, "[Loopy] Empty instruction must have single Update operation"
    assert len(node.body) == 1, "[Loopy] Empty instruction must have only a single Update operation"

    return True
155
+
156
+ # Splits a Loopy instruction into its head updates, body lookups, and other body tasks
157
def split_instruction(update_logical: ir.Logical) -> tuple[ir.Update, list[ir.Lookup], list[ir.Task]]:
    """Split a Loopy instruction into (its single Update, its Lookups, other tasks).

    Args:
        update_logical: a Logical that satisfies `is_instruction`.

    Returns:
        (update, lookups, others) — the one Update node, all Lookup nodes,
        and every remaining task in body order.

    Raises:
        AssertionError: if the body holds zero or more than one Update.
    """
    assert is_instruction(update_logical)
    lookups: list[ir.Lookup] = []
    others: list[ir.Task] = []
    update = None
    for task in update_logical.body:
        if isinstance(task, ir.Lookup):
            lookups.append(task)
        elif isinstance(task, ir.Update):
            if update is not None:
                raise AssertionError("[Loopy] Update instruction must have exactly one Update operation")
            update = task
        else:
            others.append(task)
    # Explicit raise for consistency with the in-loop check above, and so the
    # validation survives `python -O` (a bare `assert` would be stripped).
    if update is None:
        raise AssertionError("[Loopy] Update instruction must have exactly one Update operation")

    return update, lookups, others