relationalai 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/client.py +3 -4
- relationalai/clients/exec_txn_poller.py +62 -31
- relationalai/clients/resources/snowflake/direct_access_resources.py +6 -5
- relationalai/clients/resources/snowflake/snowflake.py +54 -51
- relationalai/clients/resources/snowflake/use_index_poller.py +1 -1
- relationalai/semantics/internal/snowflake.py +5 -1
- relationalai/semantics/lqp/algorithms.py +173 -0
- relationalai/semantics/lqp/builtins.py +199 -2
- relationalai/semantics/lqp/executor.py +90 -41
- relationalai/semantics/lqp/export_rewriter.py +40 -0
- relationalai/semantics/lqp/ir.py +28 -2
- relationalai/semantics/lqp/model2lqp.py +218 -45
- relationalai/semantics/lqp/passes.py +13 -658
- relationalai/semantics/lqp/rewrite/__init__.py +12 -0
- relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
- relationalai/semantics/lqp/rewrite/annotate_constraints.py +22 -10
- relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
- relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
- relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +31 -2
- relationalai/semantics/lqp/rewrite/period_math.py +77 -0
- relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
- relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
- relationalai/semantics/lqp/utils.py +11 -1
- relationalai/semantics/lqp/validators.py +14 -1
- relationalai/semantics/metamodel/builtins.py +2 -1
- relationalai/semantics/metamodel/compiler.py +2 -1
- relationalai/semantics/metamodel/dependency.py +12 -3
- relationalai/semantics/metamodel/executor.py +11 -1
- relationalai/semantics/metamodel/factory.py +2 -2
- relationalai/semantics/metamodel/helpers.py +7 -0
- relationalai/semantics/metamodel/ir.py +3 -2
- relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
- relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
- relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
- relationalai/semantics/metamodel/typer/checker.py +6 -4
- relationalai/semantics/metamodel/typer/typer.py +2 -5
- relationalai/semantics/metamodel/visitor.py +4 -3
- relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +3 -4
- relationalai/semantics/rel/compiler.py +2 -1
- relationalai/semantics/rel/executor.py +3 -2
- relationalai/semantics/tests/lqp/__init__.py +0 -0
- relationalai/semantics/tests/lqp/algorithms.py +345 -0
- relationalai/semantics/tests/test_snapshot_abstract.py +2 -1
- relationalai/tools/cli_controls.py +216 -67
- relationalai/util/format.py +5 -2
- {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/METADATA +2 -2
- {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/RECORD +52 -42
- {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/WHEEL +0 -0
- {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/entry_points.txt +0 -0
- {relationalai-0.13.2.dist-info → relationalai-0.13.4.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/client.py
CHANGED
|
@@ -614,7 +614,6 @@ class Client():
|
|
|
614
614
|
self._timed_query(
|
|
615
615
|
"update_registry",
|
|
616
616
|
dependencies.generate_update_registry(),
|
|
617
|
-
readonly=False,
|
|
618
617
|
abort_on_error=False,
|
|
619
618
|
)
|
|
620
619
|
|
|
@@ -623,7 +622,6 @@ class Client():
|
|
|
623
622
|
self._timed_query(
|
|
624
623
|
"update_packages",
|
|
625
624
|
dependencies.generate_update_packages(),
|
|
626
|
-
readonly=False,
|
|
627
625
|
abort_on_error=False,
|
|
628
626
|
)
|
|
629
627
|
else:
|
|
@@ -646,10 +644,11 @@ class Client():
|
|
|
646
644
|
finally:
|
|
647
645
|
self._database = database_name
|
|
648
646
|
|
|
649
|
-
def _timed_query(self, span_name:str, code: str,
|
|
647
|
+
def _timed_query(self, span_name:str, code: str, abort_on_error=True):
|
|
650
648
|
with debugging.span(span_name, model=self._database) as end_span:
|
|
651
649
|
start = time.perf_counter()
|
|
652
|
-
|
|
650
|
+
# NOTE hardcoding to readonly=False, read-only Rel transactions are deprecated.
|
|
651
|
+
res, raw = self._query(code, None, end_span, readonly=False, abort_on_error=abort_on_error)
|
|
653
652
|
debugging.time(span_name, time.perf_counter() - start, code=code)
|
|
654
653
|
return res, raw
|
|
655
654
|
|
|
@@ -27,16 +27,43 @@ class ExecTxnPoller:
|
|
|
27
27
|
|
|
28
28
|
def __init__(
|
|
29
29
|
self,
|
|
30
|
+
print_txn_progress: bool,
|
|
30
31
|
resource: "Resources",
|
|
31
|
-
txn_id: str,
|
|
32
|
+
txn_id: Optional[str] = None,
|
|
32
33
|
headers: Optional[Dict] = None,
|
|
33
34
|
txn_start_time: Optional[float] = None,
|
|
34
35
|
):
|
|
36
|
+
self.print_txn_progress = print_txn_progress
|
|
35
37
|
self.res = resource
|
|
36
38
|
self.txn_id = txn_id
|
|
37
39
|
self.headers = headers or {}
|
|
38
40
|
self.txn_start_time = txn_start_time or time.time()
|
|
39
41
|
|
|
42
|
+
def __enter__(self) -> ExecTxnPoller:
    """Start the interactive progress display if progress printing is enabled.

    When ``print_txn_progress`` is False nothing is started and no
    ``progress`` attribute is created. In both cases the poller itself is
    returned so it can be bound with ``with ... as poller``.
    """
    if self.print_txn_progress:
        # The description is passed as a callable (presumably re-evaluated by
        # the progress widget on refresh). The completion text is written by
        # __exit__, so the widget's own success message is left empty.
        self.progress = create_progress(
            description=lambda: self.description_with_timing(),
            success_message="",  # We'll handle this ourselves
            leading_newline=False,
            trailing_newline=False,
            show_duration_summary=False,
        )
        self.progress.__enter__()
    return self
|
|
54
|
+
|
|
55
|
+
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """Close the progress display started by ``__enter__``.

    Does nothing when ``print_txn_progress`` is False. Otherwise:

    * if a transaction ID has been assigned, the main status is replaced with
      the "Query Complete" message and the total duration measured from
      ``txn_start_time``;
    * the progress display is always exited. This includes the case where no
      transaction ID was ever assigned (e.g. transaction creation raised
      inside the ``with`` block), so the display is never left open.

    Exceptions from the ``with`` body are forwarded to the progress display's
    ``__exit__`` and are not suppressed (this method returns None).
    """
    if not self.print_txn_progress:
        return
    try:
        txn_id = self.txn_id
        if txn_id is not None:
            # NOTE(review): this is shown even when the with-body raised;
            # confirm whether a failure should suppress the "Complete" text.
            total_duration = time.time() - self.txn_start_time
            self.progress.update_main_status(
                query_complete_message(txn_id, total_duration)
            )
    finally:
        # Always pair the __enter__ performed in ExecTxnPoller.__enter__.
        self.progress.__exit__(exc_type, exc_value, traceback)
|
|
66
|
+
|
|
40
67
|
def poll(self) -> bool:
|
|
41
68
|
"""
|
|
42
69
|
Poll for transaction completion with interactive progress display.
|
|
@@ -44,48 +71,52 @@ class ExecTxnPoller:
|
|
|
44
71
|
Returns:
|
|
45
72
|
True if transaction completed successfully, False otherwise
|
|
46
73
|
"""
|
|
74
|
+
if not self.txn_id:
|
|
75
|
+
raise ValueError("Transaction ID must be provided for polling.")
|
|
76
|
+
else:
|
|
77
|
+
txn_id = self.txn_id
|
|
78
|
+
|
|
79
|
+
if self.print_txn_progress:
|
|
80
|
+
# Update the main status to include the new txn_id
|
|
81
|
+
self.progress.update_main_status_fn(
|
|
82
|
+
lambda: self.description_with_timing(txn_id),
|
|
83
|
+
)
|
|
47
84
|
|
|
48
85
|
# Don't show duration summary - we handle our own completion message
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
"""Check if transaction is complete."""
|
|
58
|
-
elapsed = time.time() - self.txn_start_time
|
|
59
|
-
# Update the main status with elapsed time
|
|
60
|
-
progress.update_main_status(
|
|
61
|
-
query_progress_message(self.txn_id, elapsed)
|
|
62
|
-
)
|
|
63
|
-
return self.res._check_exec_async_status(self.txn_id, headers=self.headers)
|
|
64
|
-
|
|
65
|
-
with debugging.span("wait", txn_id=self.txn_id):
|
|
66
|
-
poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
|
|
67
|
-
|
|
68
|
-
# Calculate final duration
|
|
69
|
-
total_duration = time.time() - self.txn_start_time
|
|
70
|
-
|
|
71
|
-
# Update to success message with duration
|
|
72
|
-
progress.update_main_status(
|
|
73
|
-
query_complete_message(self.txn_id, total_duration)
|
|
74
|
-
)
|
|
86
|
+
def check_status() -> bool:
|
|
87
|
+
"""Check if transaction is complete."""
|
|
88
|
+
finished = self.res._check_exec_async_status(txn_id, headers=self.headers)
|
|
89
|
+
return finished
|
|
90
|
+
|
|
91
|
+
with debugging.span("wait", txn_id=self.txn_id):
|
|
92
|
+
poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
|
|
93
|
+
|
|
75
94
|
|
|
76
95
|
return True
|
|
77
96
|
|
|
97
|
+
def description_with_timing(self, txn_id: str | None = None) -> str:
    """Return the progress text for the time elapsed since ``txn_start_time``.

    Without a transaction ID only the "Evaluating Query..." header line is
    produced; with one, ``query_progress_message`` appends the ID line.
    """
    elapsed = time.time() - self.txn_start_time
    if txn_id is not None:
        return query_progress_message(txn_id, elapsed)
    return query_progress_header(elapsed)
|
|
103
|
+
|
|
104
|
+
def query_progress_header(duration: float) -> str:
    """Return the "Evaluating Query..." status line, terminated by a newline.

    The formatted duration is right-aligned in a 15-character field.
    """
    # Sub-second decimals update too fast and are distracting, so omit them.
    shown = format_duration(duration, seconds_decimals=False)
    return "Evaluating Query... " + format(shown, ">15") + "\n"
|
|
108
|
+
|
|
78
109
|
def query_progress_message(id: str, duration: float) -> str:
|
|
79
110
|
return (
|
|
111
|
+
query_progress_header(duration) +
|
|
80
112
|
# Print with whitespace to align with the end of the transaction ID
|
|
81
|
-
f"
|
|
82
|
-
f"{GRAY_COLOR}Query: {id}{ENDC}"
|
|
113
|
+
f"{GRAY_COLOR}ID: {id}{ENDC}"
|
|
83
114
|
)
|
|
84
115
|
|
|
85
116
|
def query_complete_message(id: str, duration: float, status_header: bool = False) -> str:
|
|
86
117
|
return (
|
|
87
118
|
(f"{GREEN_COLOR}✅ " if status_header else "") +
|
|
88
119
|
# Print with whitespace to align with the end of the transaction ID
|
|
89
|
-
f"Query Complete: {format_duration(duration):>
|
|
90
|
-
f"{GRAY_COLOR}
|
|
120
|
+
f"Query Complete: {format_duration(duration):>21}\n" +
|
|
121
|
+
f"{GRAY_COLOR}ID: {id}{ENDC}"
|
|
91
122
|
)
|
|
@@ -13,7 +13,7 @@ from ...config import Config, ConfigStore, ENDPOINT_FILE
|
|
|
13
13
|
from ...direct_access_client import DirectAccessClient
|
|
14
14
|
from ...types import EngineState
|
|
15
15
|
from ...util import get_pyrel_version, safe_json_loads, ms_to_timestamp
|
|
16
|
-
from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
|
|
16
|
+
from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException, RAIException
|
|
17
17
|
from snowflake.snowpark import Session
|
|
18
18
|
|
|
19
19
|
# Import UseIndexResources to enable use_index functionality with direct access
|
|
@@ -163,11 +163,12 @@ class DirectAccessResources(UseIndexResources):
|
|
|
163
163
|
headers=headers,
|
|
164
164
|
)
|
|
165
165
|
response = _send_request()
|
|
166
|
-
except requests.exceptions.ConnectionError as e:
|
|
166
|
+
except (requests.exceptions.ConnectionError, RAIException) as e:
|
|
167
167
|
messages = collect_error_messages(e)
|
|
168
|
-
if any("nameresolutionerror" in msg for msg in messages)
|
|
169
|
-
|
|
170
|
-
#
|
|
168
|
+
if any("nameresolutionerror" in msg for msg in messages) or \
|
|
169
|
+
any("could not find the service associated with endpoint" in msg for msg in messages):
|
|
170
|
+
# when we can not resolve the service endpoint or the service is not found,
|
|
171
|
+
# we assume the endpoint is outdated, so we retrieve it again and retry.
|
|
171
172
|
self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
|
|
172
173
|
enforce_update=True,
|
|
173
174
|
)
|
|
@@ -15,7 +15,7 @@ import hashlib
|
|
|
15
15
|
from dataclasses import dataclass
|
|
16
16
|
|
|
17
17
|
from ....auth.token_handler import TokenHandler
|
|
18
|
-
from relationalai.clients.exec_txn_poller import ExecTxnPoller
|
|
18
|
+
from relationalai.clients.exec_txn_poller import ExecTxnPoller
|
|
19
19
|
import snowflake.snowpark
|
|
20
20
|
|
|
21
21
|
from ....rel_utils import sanitize_identifier, to_fqn_relation_name
|
|
@@ -105,6 +105,9 @@ TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
|
|
|
105
105
|
GUARDRAILS_ABORT_REASON = "guard rail violation"
|
|
106
106
|
|
|
107
107
|
PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
|
|
108
|
+
ENABLE_GUARD_RAILS_FLAG = "enable_guard_rails"
|
|
109
|
+
|
|
110
|
+
ENABLE_GUARD_RAILS_HEADER = "X-RAI-Enable-Guard-Rails"
|
|
108
111
|
|
|
109
112
|
#--------------------------------------------------
|
|
110
113
|
# Helpers
|
|
@@ -113,6 +116,9 @@ PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
|
|
|
113
116
|
def should_print_txn_progress(config) -> bool:
    """Return whether interactive transaction progress output is enabled.

    Reads ``PRINT_TXN_PROGRESS_FLAG`` from *config*; a missing key counts as
    disabled. The stored value is coerced with ``bool()``.
    """
    flag_value = config.get(PRINT_TXN_PROGRESS_FLAG, False)
    return bool(flag_value)


def should_enable_guard_rails(config) -> bool:
    """Return whether guard rails should be requested for transactions.

    Reads ``ENABLE_GUARD_RAILS_FLAG`` from *config*; a missing key counts as
    disabled. The stored value is coerced with ``bool()``.
    NOTE(review): a string value such as "false" would coerce to True —
    confirm config values are real booleans.
    """
    flag_value = config.get(ENABLE_GUARD_RAILS_FLAG, False)
    return bool(flag_value)
|
|
121
|
+
|
|
116
122
|
#--------------------------------------------------
|
|
117
123
|
# Resources
|
|
118
124
|
#--------------------------------------------------
|
|
@@ -1788,65 +1794,62 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1788
1794
|
|
|
1789
1795
|
with debugging.span("transaction", **query_attrs_dict) as txn_span:
|
|
1790
1796
|
txn_start_time = time.time()
|
|
1791
|
-
|
|
1792
|
-
# Prepare headers for transaction creation
|
|
1793
|
-
request_headers['user-agent'] = get_pyrel_version(self.generation)
|
|
1794
|
-
request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
|
|
1795
|
-
request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
|
|
1796
|
-
|
|
1797
|
-
# Create the transaction
|
|
1798
|
-
result = self._create_v2_txn(
|
|
1799
|
-
database=database,
|
|
1800
|
-
engine=engine,
|
|
1801
|
-
raw_code=raw_code,
|
|
1802
|
-
inputs=inputs,
|
|
1803
|
-
headers=request_headers,
|
|
1804
|
-
readonly=readonly,
|
|
1805
|
-
nowait_durable=nowait_durable,
|
|
1806
|
-
bypass_index=bypass_index,
|
|
1807
|
-
language=language,
|
|
1808
|
-
query_timeout_mins=query_timeout_mins,
|
|
1809
|
-
)
|
|
1797
|
+
print_txn_progress = should_print_txn_progress(self.config)
|
|
1810
1798
|
|
|
1811
|
-
|
|
1812
|
-
|
|
1799
|
+
with ExecTxnPoller(
|
|
1800
|
+
print_txn_progress=print_txn_progress,
|
|
1801
|
+
resource=self, txn_id=None, headers=request_headers,
|
|
1802
|
+
txn_start_time=txn_start_time
|
|
1803
|
+
) as poller:
|
|
1804
|
+
with debugging.span("create_v2", **query_attrs_dict) as create_span:
|
|
1805
|
+
# Prepare headers for transaction creation
|
|
1806
|
+
request_headers['user-agent'] = get_pyrel_version(self.generation)
|
|
1807
|
+
request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
|
|
1808
|
+
request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
|
|
1809
|
+
request_headers[ENABLE_GUARD_RAILS_HEADER] = str(should_enable_guard_rails(self.config))
|
|
1810
|
+
|
|
1811
|
+
# Create the transaction
|
|
1812
|
+
result = self._create_v2_txn(
|
|
1813
|
+
database=database,
|
|
1814
|
+
engine=engine,
|
|
1815
|
+
raw_code=raw_code,
|
|
1816
|
+
inputs=inputs,
|
|
1817
|
+
headers=request_headers,
|
|
1818
|
+
readonly=readonly,
|
|
1819
|
+
nowait_durable=nowait_durable,
|
|
1820
|
+
bypass_index=bypass_index,
|
|
1821
|
+
language=language,
|
|
1822
|
+
query_timeout_mins=query_timeout_mins,
|
|
1823
|
+
)
|
|
1813
1824
|
|
|
1814
|
-
|
|
1815
|
-
|
|
1816
|
-
debugging.event("transaction_created", txn_span, txn_id=txn_id)
|
|
1825
|
+
txn_id = result.txn_id
|
|
1826
|
+
state = result.state
|
|
1817
1827
|
|
|
1818
|
-
|
|
1828
|
+
txn_span["txn_id"] = txn_id
|
|
1829
|
+
create_span["txn_id"] = txn_id
|
|
1830
|
+
debugging.event("transaction_created", txn_span, txn_id=txn_id)
|
|
1819
1831
|
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
txn_end_time = time.time()
|
|
1823
|
-
if txn_id in self._pending_transactions:
|
|
1824
|
-
self._pending_transactions.remove(txn_id)
|
|
1832
|
+
# Set the transaction ID now that we have it, to update the progress text
|
|
1833
|
+
poller.txn_id = txn_id
|
|
1825
1834
|
|
|
1826
|
-
|
|
1835
|
+
# fast path: transaction already finished
|
|
1836
|
+
if state in ["COMPLETED", "ABORTED"]:
|
|
1837
|
+
if txn_id in self._pending_transactions:
|
|
1838
|
+
self._pending_transactions.remove(txn_id)
|
|
1827
1839
|
|
|
1828
|
-
|
|
1829
|
-
if print_txn_progress:
|
|
1830
|
-
print(
|
|
1831
|
-
query_complete_message(txn_id, txn_duration, status_header=True)
|
|
1832
|
-
)
|
|
1840
|
+
artifact_info = result.artifact_info
|
|
1833
1841
|
|
|
1834
|
-
|
|
1835
|
-
|
|
1836
|
-
|
|
1837
|
-
|
|
1838
|
-
|
|
1839
|
-
if print_txn_progress:
|
|
1840
|
-
poller = ExecTxnPoller(resource=self, txn_id=txn_id, headers=request_headers, txn_start_time=txn_start_time)
|
|
1842
|
+
# Slow path: transaction not done yet; start polling
|
|
1843
|
+
else:
|
|
1844
|
+
self._pending_transactions.append(txn_id)
|
|
1845
|
+
# Use the interactive poller for transaction status
|
|
1846
|
+
with debugging.span("wait", txn_id=txn_id):
|
|
1841
1847
|
poller.poll()
|
|
1842
|
-
else:
|
|
1843
|
-
poll_with_specified_overhead(
|
|
1844
|
-
lambda: self._check_exec_async_status(txn_id, headers=request_headers), 0.1
|
|
1845
|
-
)
|
|
1846
|
-
artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
|
|
1847
1848
|
|
|
1848
|
-
|
|
1849
|
-
|
|
1849
|
+
artifact_info = self._list_exec_async_artifacts(txn_id, headers=request_headers)
|
|
1850
|
+
|
|
1851
|
+
with debugging.span("fetch"):
|
|
1852
|
+
return self._download_results(artifact_info, txn_id, state)
|
|
1850
1853
|
|
|
1851
1854
|
def get_user_based_engine_name(self):
|
|
1852
1855
|
if not self._session:
|
|
@@ -66,7 +66,7 @@ ERP_CHECK_FREQUENCY = 15
|
|
|
66
66
|
|
|
67
67
|
# Polling behavior constants
|
|
68
68
|
POLL_OVERHEAD_RATE = 0.1 # Overhead rate for exponential backoff
|
|
69
|
-
POLL_MAX_DELAY =
|
|
69
|
+
POLL_MAX_DELAY = 0 # Maximum delay between polls in seconds
|
|
70
70
|
|
|
71
71
|
# SQL query template for getting stream column hashes
|
|
72
72
|
# This query calculates a hash of column metadata (name, type, precision, scale, nullable)
|
|
@@ -214,6 +214,7 @@ class Table():
|
|
|
214
214
|
self._col_names = cols
|
|
215
215
|
self._iceberg_config = config
|
|
216
216
|
self._is_iceberg = config is not None
|
|
217
|
+
self._skip_cdc = False
|
|
217
218
|
info = self._schemas.get((self._database, self._schema))
|
|
218
219
|
if not info:
|
|
219
220
|
info = self._schemas[(self._database, self._schema)] = SchemaInfo(self._database, self._schema)
|
|
@@ -303,7 +304,10 @@ class Table():
|
|
|
303
304
|
|
|
304
305
|
def _compile_lookup(self, compiler:b.Compiler, ctx:b.CompilerContext):
    """Compile a lookup of this table's relation.

    The relation itself is looked up first; the value returned is the lookup
    of the relation's field at index 0. Unless ``_skip_cdc`` is set, the
    table is also recorded in ``Table._used_sources``.
    """
    self._lazy_init()
    # Don't do CDC if the underlying data has been loaded
    # directly via `api.load_data`.
    if not self._skip_cdc:
        Table._used_sources.add(self)
    compiler.lookup(self._rel, ctx)
    first_field = b.RelationshipFieldRef(None, self._rel, 0)
    return compiler.lookup(first_field, ctx)
|
|
309
313
|
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from typing import TypeGuard
|
|
2
|
+
from relationalai.semantics.metamodel import ir, factory, types
|
|
3
|
+
from relationalai.semantics.metamodel.visitor import Rewriter, collect_by_type
|
|
4
|
+
from relationalai.semantics.lqp import ir as lqp
|
|
5
|
+
from relationalai.semantics.lqp.types import meta_type_to_lqp
|
|
6
|
+
from relationalai.semantics.lqp.builtins import (
|
|
7
|
+
has_empty_annotation, has_assign_annotation, has_upsert_annotation,
|
|
8
|
+
has_monoid_annotation, has_monus_annotation, has_script_annotation,
|
|
9
|
+
has_algorithm_annotation, has_while_annotation, global_annotation,
|
|
10
|
+
empty_annotation, assign_annotation, upsert_annotation, monoid_annotation,
|
|
11
|
+
monus_annotation
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Complex tests for Loopy constructs in the metamodel
def is_script(task: ir.Task) -> TypeGuard[ir.Sequence]:
    """Return True when *task* is a script: an ``ir.Sequence`` carrying the
    @script annotation."""
    return isinstance(task, ir.Sequence) and has_script_annotation(task)
|
|
20
|
+
|
|
21
|
+
def is_algorithm_logical(task: ir.Task) -> TypeGuard[ir.Logical]:
    """Return True when *task* is an ``ir.Logical`` whose subtasks are all
    algorithm scripts (see ``is_algorithm_script``).

    Because ``all`` is used, a Logical with an empty body also qualifies.
    NOTE(review): confirm that an empty body is meant to count.
    """
    if isinstance(task, ir.Logical):
        return all(map(is_algorithm_script, task.body))
    return False
|
|
27
|
+
|
|
28
|
+
def is_algorithm_script(task: ir.Task) -> TypeGuard[ir.Sequence]:
    """Return True when *task* is an algorithm script: an ``ir.Sequence``
    carrying both the @script and @algorithm annotations.

    The Sequence type check is done by ``is_script``; it is not repeated here.
    """
    return is_script(task) and has_algorithm_annotation(task)
|
|
33
|
+
|
|
34
|
+
def is_while_loop(task: ir.Task) -> TypeGuard[ir.Loop]:
    """Return True when *task* is a while loop: an ``ir.Loop`` carrying the
    @while annotation."""
    return isinstance(task, ir.Loop) and has_while_annotation(task)
|
|
39
|
+
|
|
40
|
+
def is_while_script(task: ir.Task) -> TypeGuard[ir.Sequence]:
    """Return True when *task* is a while script: an ``ir.Sequence`` carrying
    both the @script and @while annotations.

    The Sequence type check is done by ``is_script``; it is not repeated here.
    """
    return is_script(task) and has_while_annotation(task)
|
|
45
|
+
|
|
46
|
+
# Tools for annotating Loopy constructs
class LoopyAnnoAdder(Rewriter):
    """Rewriter that appends one fixed annotation to every ``ir.Update``.

    Only ``handle_update`` is overridden; every other node kind is handled by
    the base ``Rewriter``.
    """

    def __init__(self, anno: ir.Annotation):
        self.anno = anno
        super().__init__()

    def handle_update(self, node: ir.Update, parent: ir.Node) -> ir.Update:
        # Keep the existing annotations in order and add the new one last.
        annotations = [*node.annotations, self.anno]
        return factory.update(
            node.relation, node.args, node.effect, annotations, node.engine
        )
|
|
56
|
+
|
|
57
|
+
def _add_loopy_annotation(node: ir.Node, anno: ir.Annotation):
    """Return *node* rewritten so every Update in it also carries *anno*."""
    return LoopyAnnoAdder(anno).walk(node)


def mk_global(i: ir.Node):
    """Annotate every Update in *i* with the global annotation."""
    return _add_loopy_annotation(i, global_annotation())


def mk_empty(i: ir.Node):
    """Annotate every Update in *i* with @empty."""
    return _add_loopy_annotation(i, empty_annotation())


def mk_assign(i: ir.Node):
    """Annotate every Update in *i* with @assign."""
    return _add_loopy_annotation(i, assign_annotation())


def mk_upsert(i: ir.Node, arity: int):
    """Annotate every Update in *i* with @upsert for the given *arity*."""
    return _add_loopy_annotation(i, upsert_annotation(arity))


def mk_monoid(i: ir.Node, monoid_type: ir.ScalarType, monoid_op: str, arity: int):
    """Annotate every Update in *i* with a @monoid annotation built from
    *monoid_type*, *monoid_op* and *arity*."""
    return _add_loopy_annotation(i, monoid_annotation(monoid_type, monoid_op, arity))


def mk_monus(i: ir.Node, monoid_type: ir.ScalarType, monoid_op: str, arity: int):
    """Annotate every Update in *i* with a @monus annotation built from
    *monoid_type*, *monoid_op* and *arity*."""
    return _add_loopy_annotation(i, monus_annotation(monoid_type, monoid_op, arity))
|
|
74
|
+
|
|
75
|
+
def construct_monoid(i: ir.Annotation):
    """Build the LQP monoid described by a monoid/monus annotation.

    The annotation's args are scanned for an ``ir.ScalarType``, which becomes
    the base type via ``meta_type_to_lqp``, and for a String literal naming
    the operation. If an arg kind occurs more than once, the last occurrence
    wins. The operation is matched case-insensitively against "or", "sum",
    "min" and "max".

    Raises:
        AssertionError: if no base type or operation name is found, or the
            operation is unsupported. These errors are raised explicitly
            rather than via ``assert``, so they still fire under ``python -O``.
    """
    base_type = None
    op = None
    for arg in i.args:
        if isinstance(arg, ir.ScalarType):
            base_type = meta_type_to_lqp(arg)
        elif isinstance(arg, ir.Literal) and arg.type == types.String:
            op = arg.value
    if not (isinstance(base_type, lqp.Type) and isinstance(op, str)):
        raise AssertionError("Failed to get monoid")

    op_name = op.lower()
    if op_name == "or":
        # OrMonoid takes no type argument (the base type is still required above).
        return lqp.OrMonoid(meta=None)
    typed_monoids = {
        "sum": lqp.SumMonoid,
        "min": lqp.MinMonoid,
        "max": lqp.MaxMonoid,
    }
    monoid_cls = typed_monoids.get(op_name)
    if monoid_cls is None:
        raise AssertionError(f"Failed to get monoid: unsupported operation {op!r}")
    return monoid_cls(type=base_type, meta=None)
|
|
94
|
+
|
|
95
|
+
# Tools for analyzing Loopy constructs
def is_logical_instruction(node: ir.Node) -> TypeGuard[ir.Logical]:
    """Return True when *node* is an ``ir.Logical`` that contains at least one
    ``ir.Update`` and no ``ir.Sequence``, as found by ``collect_by_type``."""
    if not isinstance(node, ir.Logical):
        return False
    # Test emptiness by size: any() over the collected nodes would also depend
    # on each node's own truthiness, not just on whether any were found.
    has_update = len(collect_by_type(ir.Update, node)) > 0
    return has_update and len(collect_by_type(ir.Sequence, node)) == 0
|
|
100
|
+
|
|
101
|
+
def get_instruction_body_rels(node: ir.Logical) -> set[ir.Relation]:
    """Return the relations read via ``ir.Lookup`` inside the Loopy
    instruction *node* (which must satisfy ``is_logical_instruction``)."""
    assert is_logical_instruction(node)
    return {lookup.relation for lookup in collect_by_type(ir.Lookup, node)}
|
|
107
|
+
|
|
108
|
+
def get_instruction_head_rels(node: ir.Logical) -> set[ir.Relation]:
    """Return the relations written via ``ir.Update`` inside the Loopy
    instruction *node* (which must satisfy ``is_logical_instruction``)."""
    assert is_logical_instruction(node)
    return {upd.relation for upd in collect_by_type(ir.Update, node)}
|
|
114
|
+
|
|
115
|
+
# base Loopy instruction: @empty, @assign, @upsert, @monoid, @monus
def is_instruction(update: ir.Task) -> TypeGuard[ir.Logical]:
    """Return True when *update* is a logical instruction with at least one
    Update annotated @empty, @assign, @upsert, @monoid or @monus."""
    if not is_logical_instruction(update):
        return False
    annotation_checks = (
        has_empty_annotation,
        has_assign_annotation,
        has_upsert_annotation,
        has_monoid_annotation,
        has_monus_annotation,
    )
    return any(
        check(u)
        for u in collect_by_type(ir.Update, update)
        for check in annotation_checks
    )
|
|
127
|
+
|
|
128
|
+
# update Loopy instruction @upsert, @monoid, @monus
def is_update_instruction(task: ir.Task) -> TypeGuard[ir.Logical]:
    """Return True when *task* is a logical instruction with at least one
    Update annotated @upsert, @monoid or @monus."""
    if not is_logical_instruction(task):
        return False
    annotation_checks = (
        has_upsert_annotation,
        has_monoid_annotation,
        has_monus_annotation,
    )
    return any(
        check(u)
        for u in collect_by_type(ir.Update, task)
        for check in annotation_checks
    )
|
|
138
|
+
|
|
139
|
+
def is_empty_instruction(node: ir.Node) -> TypeGuard[ir.Logical]:
    """Check if input is an empty Loopy instruction `empty rel = ∅`.

    Returns False if *node* is not a logical instruction or has no Update
    carrying the @empty annotation. Once an @empty Update is present, the
    instruction must be well-formed:

    1. it contains exactly one Update operation, and
    2. its body consists of a single task.

    Raises:
        AssertionError: if an @empty instruction is malformed. The checks are
            raised explicitly rather than via ``assert``, so they are not
            stripped under ``python -O``.
    """
    if not is_logical_instruction(node):
        return False
    updates = collect_by_type(ir.Update, node)
    if not any(has_empty_annotation(update) for update in updates):
        return False

    if len(updates) != 1:
        raise AssertionError("[Loopy] Empty instruction must have single Update operation")
    if len(node.body) != 1:
        raise AssertionError("[Loopy] Empty instruction must have only a single Update operation")

    return True
|
|
155
|
+
|
|
156
|
+
# Splits a Loopy instruction into its head update, body lookups, and other body tasks
def split_instruction(update_logical: ir.Logical) -> tuple[ir.Update,list[ir.Lookup],list[ir.Task]]:
    """Split the top-level body of a Loopy instruction.

    Returns ``(update, lookups, others)``:
    ``update`` is the single ``ir.Update`` in ``update_logical.body``,
    ``lookups`` are the ``ir.Lookup`` tasks in body order, and ``others`` is
    every remaining task in body order. Only the direct children of the body
    are inspected.

    Raises:
        AssertionError: if the body does not contain exactly one Update. The
            "more than one" and "none" cases are both raised explicitly, so a
            None update is never returned, even under ``python -O``.
    """
    assert is_instruction(update_logical)
    lookups: list[ir.Lookup] = []
    others: list[ir.Task] = []
    update = None
    for task in update_logical.body:
        if isinstance(task, ir.Lookup):
            lookups.append(task)
        elif isinstance(task, ir.Update):
            if update is not None:
                raise AssertionError("[Loopy] Update instruction must have exactly one Update operation")
            update = task
        else:
            others.append(task)
    if update is None:
        raise AssertionError("[Loopy] Update instruction must have exactly one Update operation")

    return update, lookups, others
|