relationalai 0.12.6__py3-none-any.whl → 0.12.8__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- relationalai/clients/snowflake.py +48 -7
- relationalai/clients/use_index_poller.py +11 -1
- relationalai/early_access/lqp/constructors/__init__.py +2 -2
- relationalai/early_access/metamodel/rewrite/__init__.py +2 -2
- relationalai/semantics/internal/internal.py +1 -4
- relationalai/semantics/internal/snowflake.py +14 -1
- relationalai/semantics/lqp/constructors.py +0 -5
- relationalai/semantics/lqp/executor.py +34 -10
- relationalai/semantics/lqp/intrinsics.py +2 -2
- relationalai/semantics/lqp/model2lqp.py +10 -7
- relationalai/semantics/lqp/passes.py +29 -9
- relationalai/semantics/lqp/primitives.py +15 -15
- relationalai/semantics/lqp/rewrite/__init__.py +2 -2
- relationalai/semantics/lqp/rewrite/{fd_constraints.py → function_annotations.py} +4 -4
- relationalai/semantics/lqp/utils.py +17 -13
- relationalai/semantics/metamodel/builtins.py +1 -0
- relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
- relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +1 -1
- relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +5 -6
- relationalai/semantics/metamodel/rewrite/flatten.py +18 -149
- relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
- relationalai/semantics/reasoners/graph/core.py +98 -70
- relationalai/semantics/reasoners/optimization/__init__.py +55 -10
- relationalai/semantics/reasoners/optimization/common.py +63 -8
- relationalai/semantics/reasoners/optimization/solvers_dev.py +39 -33
- relationalai/semantics/reasoners/optimization/solvers_pb.py +1033 -385
- relationalai/semantics/rel/compiler.py +4 -3
- relationalai/semantics/rel/executor.py +30 -8
- relationalai/semantics/snowflake/__init__.py +2 -2
- relationalai/semantics/sql/executor/snowflake.py +6 -2
- relationalai/semantics/tests/test_snapshot_abstract.py +5 -4
- relationalai/tools/cli.py +10 -0
- relationalai/tools/cli_controls.py +15 -0
- {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/METADATA +2 -2
- {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/RECORD +38 -37
- {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/WHEEL +0 -0
- {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/licenses/LICENSE +0 -0
@@ -851,7 +851,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
             self.generation
         )
         # If cache is valid (data freshness has not expired), skip polling
-        if …
+        if poller.cache.is_valid():
+            cached_sources = len(poller.cache.sources)
+            total_sources = len(sources_list)
+            cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+            message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+            if cached_timestamp:
+                print(f"\n{message} (cached at {cached_timestamp})\n")
+            else:
+                print(f"\n{message}\n")
+        else:
             return poller.poll()

     #--------------------------------------------------
@@ -2842,6 +2852,15 @@ class SnowflakeTable(dsl.Type):
                 else:
                     self(snowflake_id=id).set(**{prop_ident: val})

+        # Because we're bypassing a bunch of the normal Type.add machinery here,
+        # we need to manually account for the case where people are using value types.
+        def wrapped(x):
+            if not model._config.get("compiler.use_value_types", False):
+                return x
+            other_id = dsl.create_var()
+            model._action(dsl.build.construct(self._type, [x, other_id]))
+            return other_id
+
         # new UInt128 schema mapping rules
         with model.rule(dynamic=True, globalize=True, source=self._source):
             id = dsl.create_var()

@@ -2851,7 +2870,7 @@ class SnowflakeTable(dsl.Type):
             # for avoiding a non-blocking warning
             edb(dsl.Symbol("METADATA$KEY"), id)
             std.rel.UInt128(id)
-            self.add(id, snowflake_id=id)
+            self.add(wrapped(id), snowflake_id=id)

             for prop, prop_type in self._schema["columns"].items():
                 _prop = prop

@@ -2873,7 +2892,7 @@ class SnowflakeTable(dsl.Type):
                     model._check_property(_prop._prop)
                 raw_relation = getattr(std.rel, prop_ident)
                 dsl.tag(raw_relation, dsl.Builtins.FunctionAnnotation)
-                raw_relation.add(id, val)
+                raw_relation.add(wrapped(id), val)

     def namespace(self):
         return f"{self._parent._parent._name}.{self._parent._name}"
@@ -3275,12 +3294,24 @@ class DirectAccessResources(Resources):
        try:
            response = _send_request()
            if response.status_code != 200:
+                # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
+                # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
+                # For skip_auto_create=False, continue to auto-creation logic below
+                if response.status_code == 404 and skip_auto_create:
+                    return response
+
                try:
                    message = response.json().get("message", "")
                except requests.exceptions.JSONDecodeError:
-                    …
-                    …
-                    …
+                    # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
+                    # this should have been caught by the 404 check above, so this is an error.
+                    # For skip_auto_create=False, we explicitly check status_code below,
+                    # so we don't need to parse the message.
+                    if skip_auto_create:
+                        raise ResponseStatusException(
+                            f"Failed to parse error response from endpoint {endpoint}.", response
+                        )
+                    message = ""  # Not used when we check status_code directly

                # fix engine on engine error and retry
                # Skip auto-retry if skip_auto_create is True to avoid recursion
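The comments above describe the new error-handling gate: a 404 combined with skip_auto_create=True is handed straight back to the caller (get_engine turns it into None), while other non-200 responses still fall through to the auto-create and retry logic. A condensed, self-contained sketch of that control flow follows; the response class and return strings are stand-ins, not the package's real client code.

    # Condensed sketch of the gate described above; FakeResponse and the return
    # values are illustrative stand-ins, not the package's real client code.
    class FakeResponse:
        def __init__(self, status_code: int):
            self.status_code = status_code

    def handle(response: "FakeResponse", skip_auto_create: bool):
        if response.status_code != 200:
            if response.status_code == 404 and skip_auto_create:
                # e.g. get_engine: a missing engine is not an error here,
                # so the caller gets the raw response and can return None.
                return response
            return "fall through to auto-create / retry logic"
        return "ok"

    print(handle(FakeResponse(404), skip_auto_create=True))   # returns the response itself
    print(handle(FakeResponse(404), skip_auto_create=False))  # falls through to auto-create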
@@ -3473,7 +3504,17 @@ class DirectAccessResources(Resources):
            generation=self.generation,
        )
        # If cache is valid (data freshness has not expired), skip polling
-        if …
+        if poller.cache.is_valid():
+            cached_sources = len(poller.cache.sources)
+            total_sources = len(sources_list)
+            cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+            message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+            if cached_timestamp:
+                print(f"\n{message} (cached at {cached_timestamp})\n")
+            else:
+                print(f"\n{message}\n")
+        else:
             return poller.poll()

    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
@@ -250,7 +250,17 @@ class UseIndexPoller:
            # Cache was used - show how many sources were cached
            total_sources = len(self.cache.sources)
            cached_sources = total_sources - len(self.sources)
-            …
+
+            # Get the timestamp when sources were cached
+            entry = self.cache._metadata.get("cachedIndices", {}).get(self.cache.key, {})
+            cached_timestamp = entry.get("last_use_index_update_on", "")
+
+            message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+            # Format the message with timestamp
+            if cached_timestamp:
+                message += f" (cached at {cached_timestamp})"
+
+            progress.add_sub_task(message, task_id="cache_usage", category=TASK_CATEGORY_CACHE)
             # Complete the subtask immediately since it's just informational
             progress.complete_sub_task("cache_usage")

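Both the poller and the client hunks above read the cached timestamp out of the poller cache's _metadata dictionary. The sketch below shows the metadata shape those chained .get(...) calls assume and how the user-facing message is assembled; the key names come from the diff, while the cache key, counts, and timestamp are illustrative only.

    # Illustrative sketch of the metadata shape the lookups above assume.
    _metadata = {
        "cachedIndices": {
            "my_cache_key": {"last_use_index_update_on": "2024-05-01T12:00:00Z"},
        },
    }
    cache_key = "my_cache_key"            # stand-in for poller.cache.key / self.cache.key
    cached_sources, total_sources = 3, 5  # illustrative counts

    entry = _metadata.get("cachedIndices", {}).get(cache_key, {})
    cached_timestamp = entry.get("last_use_index_update_on", "")

    message = f"Using cached data for {cached_sources}/{total_sources} data streams"
    if cached_timestamp:
        message += f" (cached at {cached_timestamp})"
    print(message)  # Using cached data for 3/5 data streams (cached at 2024-05-01T12:00:00Z)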
@@ -2,12 +2,12 @@ import warnings

 from relationalai.semantics.lqp.constructors import (
     mk_abstraction, mk_and, mk_exists, mk_or, mk_pragma, mk_primitive,
-    mk_specialized_value, mk_type, mk_value …
+    mk_specialized_value, mk_type, mk_value
 )

 __all__ = [
     "mk_abstraction", "mk_and", "mk_exists", "mk_or", "mk_pragma", "mk_primitive", "mk_specialized_value", "mk_type",
-    "mk_value" …
+    "mk_value"
 ]

 warnings.warn(

@@ -1,7 +1,7 @@
 from relationalai.semantics.metamodel.rewrite import Flatten, \
     DNFUnionSplitter, ExtractNestedLogicals, flatten
 from relationalai.semantics.lqp.rewrite import Splinter, \
-    ExtractKeys, …
+    ExtractKeys, FunctionAnnotations

 __all__ = ["Splinter", "Flatten", "DNFUnionSplitter", "ExtractKeys",
-           "ExtractNestedLogicals", " …
+           "ExtractNestedLogicals", "FunctionAnnotations", "flatten"]
@@ -2550,8 +2550,6 @@ class Fragment():
         from .snowflake import Table
         assert isinstance(table, Table), "Only Snowflake tables are supported for now"

-        result_cols = table._col_names
-
         clone = Fragment(parent=self)
         clone._is_export = True
         qb_model = clone._model or Model("anon")

@@ -2559,8 +2557,7 @@ class Fragment():
         clone._source = runtime_env.get_source_pos()
         with debugging.span("query", dsl=str(clone), **with_source(clone), meta=clone._meta):
             query_task = qb_model._compiler.fragment(clone)
-            qb_model._to_executor().execute(ir_model, query_task, …
-            …
+            qb_model._to_executor().execute(ir_model, query_task, export_to=table, update=update, meta=clone._meta)

     #--------------------------------------------------
     # Select / Where
@@ -12,7 +12,18 @@ from . import internal as b, annotations as anns
 from relationalai import debugging
 from relationalai.errors import UnsupportedColumnTypesWarning
 from snowflake.snowpark.context import get_active_session
+from typing import ClassVar, Optional

+#--------------------------------------------------
+# Iceberg Configuration
+#--------------------------------------------------
+@dataclass
+class IcebergConfig:
+    """Configuration for exporting to Iceberg tables."""
+    external_volume: str | None = None
+    default: ClassVar[Optional["IcebergConfig"]]
+
+IcebergConfig.default = IcebergConfig()
 #--------------------------------------------------
 # Helpers
 #--------------------------------------------------

@@ -191,7 +202,7 @@ class Table():
     _schemas:dict[tuple[str, str], SchemaInfo] = {}
     _used_sources:OrderedSet[Table] = ordered_set()

-    def __init__(self, fqn:str, cols:list[str]|None=None, schema:dict[str, str|b.Concept]|None=None) -> None:
+    def __init__(self, fqn:str, cols:list[str]|None=None, schema:dict[str, str|b.Concept]|None=None, config: IcebergConfig|None=None) -> None:
         self._fqn = fqn
         parser = IdentityParser(fqn, require_all_parts=True)
         self._database, self._schema, self._table, self._fqn = parser.to_list()

@@ -201,6 +212,8 @@ class Table():
         self._ref = self._concept.ref("row_id")
         self._cols = {}
         self._col_names = cols
+        self._iceberg_config = config
+        self._is_iceberg = config is not None
         info = self._schemas.get((self._database, self._schema))
         if not info:
             info = self._schemas[(self._database, self._schema)] = SchemaInfo(self._database, self._schema)
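The new config parameter is how an IcebergConfig reaches the table and, later, the export path (the _is_iceberg and _iceberg_config attributes used in the executor hunks further down). A hypothetical call pattern implied by the new signature is sketched here; the import path, table FQN, and volume name are assumptions, and constructing a Table normally requires an active Snowflake session, so this is illustrative only.

    # Hypothetical usage implied by the new signature; the import path and names
    # are assumptions, and a live Snowflake session would normally be required.
    from relationalai.semantics.snowflake import Table, IcebergConfig  # assumed import path

    iceberg_dest = Table(
        "MY_DB.MY_SCHEMA.EVENTS_ICEBERG",
        config=IcebergConfig(external_volume="my_external_volume"),
    )
    plain_dest = Table("MY_DB.MY_SCHEMA.EVENTS")  # no config, so _is_iceberg stays False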
@@ -1,6 +1,5 @@
 from typing import Tuple
 from relationalai.semantics.lqp import ir as lqp
-from relationalai.semantics.metamodel.ir import sanitize

 def mk_and(args: list[lqp.Formula]) -> lqp.Formula:
     # Flatten nested conjunctions

@@ -49,10 +48,6 @@ def mk_specialized_value(value) -> lqp.SpecializedValue:
 def mk_value(value) -> lqp.Value:
     return lqp.Value(value=value, meta=None)

-def mk_var(name: str) -> lqp.Var:
-    _name = '_' if name == '_' else sanitize(name)
-    return lqp.Var(name=_name, meta=None)
-
 def mk_type(typename: lqp.TypeName, parameters: list[lqp.Value]=[]) -> lqp.Type:
     return lqp.Type(type_name=typename, parameters=parameters, meta=None)

@@ -4,7 +4,7 @@ import atexit
 import re

 from pandas import DataFrame
-from typing import Any, Optional, Literal
+from typing import Any, Optional, Literal, TYPE_CHECKING
 from snowflake.snowpark import Session
 import relationalai as rai

@@ -20,10 +20,14 @@ from relationalai.semantics.lqp.ir import convert_transaction, validate_lqp
 from relationalai.clients.config import Config
 from relationalai.clients.snowflake import APP_NAME
 from relationalai.clients.types import TransactionAsyncResponse
-from relationalai.clients.util import IdentityParser
+from relationalai.clients.util import IdentityParser, escape_for_f_string
 from relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
 from relationalai.tools.query_utils import prepare_metadata_for_headers

+if TYPE_CHECKING:
+    from relationalai.semantics.snowflake import Table
+
+
 class LQPExecutor(e.Executor):
     """Executes LQP using the RAI client."""

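The executor only needs Table for type annotations, so it is imported under typing.TYPE_CHECKING: type checkers see the symbol, but nothing is imported at runtime, which avoids a circular import between the executor and the snowflake module. A minimal, generic sketch of the idiom (not specific to this package) is shown below.

    # Generic sketch of the TYPE_CHECKING idiom; the module name is made up.
    from __future__ import annotations   # annotations stay as strings at runtime
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers, never executed at runtime,
        # so it cannot introduce a circular import.
        from myapp.tables import Table

    def export_to(dest: Table) -> None:
        print(f"exporting to {dest!r}")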
@@ -172,12 +176,12 @@ class LQPExecutor(e.Executor):
         elif len(all_errors) > 1:
             raise errors.RAIExceptionSet(all_errors)

-    def _export(self, txn_id: str, export_info: tuple, …
+    def _export(self, txn_id: str, export_info: tuple, dest: Table, actual_cols: list[str], declared_cols: list[str], update: bool):
         # At this point of the export, we assume that a CSV file has already been written
         # to the Snowflake Native App stage area. Thus, the purpose of this method is to
         # copy the data from the CSV file to the destination table.
         _exec = self.resources._exec
-        dest_database, dest_schema, dest_table, _ = IdentityParser( …
+        dest_database, dest_schema, dest_table, _ = IdentityParser(dest._fqn, require_all_parts=True).to_list()
         filename = export_info[0]
         result_table_name = filename + "_table"

@@ -203,8 +207,28 @@ class LQPExecutor(e.Executor):
         # destination table. This step also cleans up the result table.
         out_sample = _exec(f"select * from {APP_NAME}.results.{result_table_name} limit 1;")
         names = self._build_projection(declared_cols, actual_cols, column_fields, out_sample)
+        dest_fqn = dest._fqn
         try:
             if not update:
+                createTableLogic = f"""
+                    CREATE TABLE {dest_fqn} AS
+                    SELECT {names}
+                    FROM {APP_NAME}.results.{result_table_name};
+                """
+                if dest._is_iceberg:
+                    assert dest._iceberg_config is not None
+                    external_volume_clause = ""
+                    if dest._iceberg_config.external_volume:
+                        external_volume_clause = f"EXTERNAL_VOLUME = '{dest._iceberg_config.external_volume}'"
+                    createTableLogic = f"""
+                        CREATE ICEBERG TABLE {dest_fqn}
+                        CATALOG = "SNOWFLAKE"
+                        {external_volume_clause}
+                        AS
+                        SELECT {names}
+                        FROM {APP_NAME}.results.{result_table_name};
+                    """
+
                 _exec(f"""
                     BEGIN
                     -- Check if table exists

@@ -227,9 +251,7 @@ class LQPExecutor(e.Executor):
                     ELSE
                         -- Create table based on the SELECT
                         EXECUTE IMMEDIATE '
-                            …
-                            SELECT {names}
-                            FROM {APP_NAME}.results.{result_table_name};
+                            {escape_for_f_string(createTableLogic)}
                         ';
                     END IF;
                     END;
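For an Iceberg destination the export now builds a CREATE ICEBERG TABLE ... CATALOG = "SNOWFLAKE" ... AS SELECT ... statement (optionally with an EXTERNAL_VOLUME clause) and splices it into the existing EXECUTE IMMEDIATE block via escape_for_f_string. The sketch below reproduces the string assembly with invented placeholder values so the resulting SQL shape is visible.

    # Placeholder values; only the string-building mirrors the diff above.
    dest_fqn = "MY_DB.MY_SCHEMA.EVENTS_ICEBERG"
    APP_NAME = "RELATIONALAI"               # stand-in for the real app name
    result_table_name = "export_results_table"
    names = "col_a, col_b"                  # stand-in projection list
    external_volume = "my_external_volume"  # from IcebergConfig.external_volume

    external_volume_clause = f"EXTERNAL_VOLUME = '{external_volume}'" if external_volume else ""
    create_table_logic = f"""
        CREATE ICEBERG TABLE {dest_fqn}
        CATALOG = "SNOWFLAKE"
        {external_volume_clause}
        AS
        SELECT {names}
        FROM {APP_NAME}.results.{result_table_name};
    """
    print(create_table_logic)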
@@ -376,7 +398,7 @@ class LQPExecutor(e.Executor):
         return final_model, export_info, txn_proto

     # TODO (azreika): This should probably be split up into exporting and other processing. There are quite a lot of arguments here...
-    def _process_results(self, task: ir.Task, final_model: ir.Model, raw_results: TransactionAsyncResponse, …
+    def _process_results(self, task: ir.Task, final_model: ir.Model, raw_results: TransactionAsyncResponse, export_info: Optional[tuple], export_to: Optional[Table], update: bool) -> DataFrame:
         cols, extra_cols = self._compute_cols(task, final_model)

         df, errs = result_helpers.format_results(raw_results, cols)

@@ -391,6 +413,8 @@ class LQPExecutor(e.Executor):
             assert cols, "No columns found in the output"
             assert isinstance(raw_results, TransactionAsyncResponse) and raw_results.transaction, "Invalid transaction result"

+            result_cols = export_to._col_names
+
             if result_cols is not None:
                 assert all(col in result_cols or col in extra_cols for col in cols)
             else:

@@ -403,7 +427,7 @@ class LQPExecutor(e.Executor):
         return self._postprocess_df(self.config, df, extra_cols)

     def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
-                …
+                export_to: Optional[Table] = None,
                 update: bool = False, meta: dict[str, Any] | None = None) -> DataFrame:
         self.prepare_data()
         previous_model = self._last_model

@@ -433,7 +457,7 @@ class LQPExecutor(e.Executor):
         assert isinstance(raw_results, TransactionAsyncResponse)

         try:
-            return self._process_results(task, final_model, raw_results, …
+            return self._process_results(task, final_model, raw_results, export_info, export_to, update)
         except Exception as e:
             # If processing the results failed, revert to the previous model.
             self._last_model = previous_model
@@ -1,13 +1,13 @@
 from datetime import datetime, timezone

 from relationalai.semantics.lqp import ir as lqp
-from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, …
+from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, mk_type, mk_primitive
 from relationalai.semantics.lqp.utils import lqp_hash

 def mk_intrinsic_datetime_now() -> lqp.Def:
     """Constructs a definition of the current datetime."""
     id = lqp_hash("__pyrel_lqp_intrinsic_datetime_now")
-    out = …
+    out = lqp.Var(name="out", meta=None)
     out_type = mk_type(lqp.TypeName.DATETIME)
     now = mk_value(lqp.DateTimeValue(value=datetime.now(timezone.utc), meta=None))
     datetime_now = mk_abstraction(

@@ -7,7 +7,7 @@ from relationalai.semantics.lqp.pragmas import pragma_to_lqp_name
 from relationalai.semantics.lqp.types import meta_type_to_lqp
 from relationalai.semantics.lqp.constructors import (
     mk_abstraction, mk_and, mk_exists, mk_or, mk_pragma, mk_primitive,
-    mk_specialized_value, mk_type, mk_value, …
+    mk_specialized_value, mk_type, mk_value, mk_attribute
 )
 from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var
 from relationalai.semantics.lqp.validators import assert_valid_input

@@ -253,7 +253,7 @@ def _translate_rank(ctx: TranslationCtx, rank: ir.Rank, body: lqp.Formula) -> lq
     # to convert it to Int128.
     result_var, _ = _translate_term(ctx, rank.result)
     # The primitive will return an Int64 result, so we need a var to hold the intermediary.
-    result_64_var = gen_unique_var(ctx, " …
+    result_64_var = gen_unique_var(ctx, "v_rank")
     result_64_type = mk_type(lqp.TypeName.INT)

     cast = lqp.Cast(input=result_64_var, result=result_var, meta=None)

@@ -306,7 +306,7 @@ def _translate_descending_rank(ctx: TranslationCtx, limit: int, result: lqp.Var,
     aggr_abstr_args = new_abstr_args + [(count_var, count_type)]
     count_aggr = lqp.Reduce(
         op=lqp_operator(
-            ctx …
+            ctx,
             "count",
             "count",
             mk_type(lqp.TypeName.INT)

@@ -431,7 +431,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
         (sum_var, sum_type) = abstr_args[-2]

         result = lqp.Reduce(
-            op=lqp_avg_op(ctx …
+            op=lqp_avg_op(ctx, aggr.aggregation.name, sum_var.name, sum_type),
             body=mk_abstraction(abstr_args, body),
             terms=[sum_result, count_result],
             meta=None,

@@ -464,7 +464,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form

         # Group-bys do not need to be handled at all, since they are introduced outside already
         reduce = lqp.Reduce(
-            op=lqp_operator(ctx …
+            op=lqp_operator(ctx, aggr.aggregation.name, aggr_arg.name, aggr_arg_type),
             body=mk_abstraction(abstr_args, body),
             terms=output_vars,
             meta=None

@@ -523,9 +523,8 @@ def _translate_term(ctx: TranslationCtx, term: ir.Value) -> Tuple[lqp.Term, lqp.
         # TODO: ScalarType is not like other terms, should be handled separately.
         return to_lqp_value(term.name, types.String), meta_type_to_lqp(types.String)
     elif isinstance(term, ir.Var):
-        name = ctx.var_names.get_name_by_id(term.id, term.name)
         t = meta_type_to_lqp(term.type)
-        return …
+        return _translate_var(ctx, term), t
     else:
         assert isinstance(term, ir.Literal), f"Cannot translate value {term!r} of type {type(term)} to LQP Term; neither Var nor Literal."
         v = to_lqp_value(term.value, term.type)

@@ -801,3 +800,7 @@ def _translate_join(ctx: TranslationCtx, task: ir.Lookup) -> lqp.Formula:
     output_term = _translate_term(ctx, target)[0]

     return lqp.Reduce(meta=None, op=op, body=body, terms=[output_term])
+
+def _translate_var(ctx: TranslationCtx, term: ir.Var) -> lqp.Var:
+    name = ctx.var_names.get_name_by_id(term.id, term.name)
+    return lqp.Var(name=name, meta=None)
@@ -6,17 +6,19 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet

 from relationalai.semantics.metamodel.rewrite import Flatten

-from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals
-from .rewrite import CDC, ExtractCommon, ExtractKeys, …
+from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
+from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter

 from relationalai.semantics.lqp.utils import output_names

 from typing import cast, List, Sequence, Tuple, Union, Optional, Iterable
 from collections import defaultdict
+import pandas as pd
+import hashlib

 def lqp_passes() -> list[Pass]:
     return [
-        …
+        FunctionAnnotations(),
         DischargeConstraints(),
         Checker(),
         CDC(), # specialize to physical relations before extracting nested and typing

@@ -25,6 +27,7 @@ def lqp_passes() -> list[Pass]:
         DNFUnionSplitter(),
         ExtractKeys(),
         ExtractCommon(),
+        FormatOutputs(),
         Flatten(),
         Splinter(), # Splits multi-headed rules into multiple rules
         QuantifyVars(), # Adds missing existentials

@@ -337,7 +340,7 @@ class UnifyDefinitions(Pass):
         )

 # Creates intermediary relations for all Data nodes and replaces said Data nodes
-# with a Lookup into these created relations.
+# with a Lookup into these created relations. Reuse duplicate created relations.
 class EliminateData(Pass):
     def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
         r = self.DataRewriter()
@@ -350,17 +353,25 @@ class EliminateData(Pass):
         # Counter for naming new relations.
         # It must be that new_count == len new_updates == len new_relations.
         new_count: int
+        # Cache for Data nodes to avoid creating duplicate intermediary relations
+        data_cache: dict[str, ir.Relation]

         def __init__(self):
             self.new_relations = []
             self.new_updates = []
             self.new_count = 0
+            self.data_cache = {}
             super().__init__()

-        # Create a …
-        …
-        …
-        …
+        # Create a cache key for a Data node based on its structure and content
+        def _data_cache_key(self, node: ir.Data) -> str:
+            values = pd.util.hash_pandas_object(node.data).values
+            return hashlib.sha256(bytes(values)).hexdigest()
+
+        def _intermediary_relation(self, node: ir.Data) -> ir.Relation:
+            cache_key = self._data_cache_key(node)
+            if cache_key in self.data_cache:
+                return self.data_cache[cache_key]
             self.new_count += 1
             intermediary_name = f"formerly_Data_{self.new_count}"

@@ -379,7 +390,6 @@ class EliminateData(Pass):
                         f.lookup(rel_builtins.eq, [f.literal(val), var])
                         for (val, var) in zip(row, node.vars)
                     ],
-                    hoisted = node.vars,
                 )
                 for row in node
             ],

@@ -390,6 +400,16 @@ class EliminateData(Pass):
             ])
             self.new_updates.append(intermediary_update)

+            # Cache the result for reuse
+            self.data_cache[cache_key] = intermediary_relation
+
+            return intermediary_relation
+
+        # Create a new intermediary relation representing the Data (and pop it in
+        # new_updates/new_relations) and replace this Data with a Lookup of said
+        # intermediary.
+        def handle_data(self, node: ir.Data, parent: ir.Node) -> ir.Lookup:
+            intermediary_relation = self._intermediary_relation(node)
             replacement_lookup = f.lookup(intermediary_relation, node.vars)

             return replacement_lookup
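The deduplication key above is a SHA-256 digest of pandas' per-row hashes of the Data node's contents, so two Data nodes carrying identical rows resolve to the same intermediary relation. The standalone sketch below uses the same pandas/hashlib calls on throwaway DataFrames (the real pass hashes node.data):

    import hashlib
    import pandas as pd

    def data_cache_key(df: pd.DataFrame) -> str:
        # hash_pandas_object returns one uint64 per row; digesting the raw bytes
        # yields a stable key for DataFrames with identical content.
        values = pd.util.hash_pandas_object(df).values
        return hashlib.sha256(bytes(values)).hexdigest()

    df_a = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    df_b = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    df_c = pd.DataFrame({"x": [9, 9, 9], "y": ["a", "b", "c"]})

    assert data_cache_key(df_a) == data_cache_key(df_b)  # same rows, same key
    assert data_cache_key(df_a) != data_cache_key(df_c)  # different rows, different key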
@@ -1,8 +1,8 @@
 from relationalai.semantics.metamodel.types import digits_to_bits
 from relationalai.semantics.lqp import ir as lqp
 from relationalai.semantics.lqp.types import is_numeric
-from relationalai.semantics.lqp.utils import …
-from relationalai.semantics.lqp.constructors import mk_primitive, mk_specialized_value, mk_type, mk_value
+from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var, lqp_hash
+from relationalai.semantics.lqp.constructors import mk_primitive, mk_specialized_value, mk_type, mk_value

 rel_to_lqp = {
     "=": "rel_primitive_eq",

@@ -205,15 +205,15 @@ def is_monotype(name: str) -> bool:

 # We take the name and type of the variable that we're summing over, so that we can generate
 # recognizable names for the variables in the reduce operation and preserve the type.
-def lqp_avg_op( …
+def lqp_avg_op(ctx: TranslationCtx, op_name: str, sum_name: str, sum_type: lqp.Type) -> lqp.Abstraction:
     count_type = mk_type(lqp.TypeName.INT)
     vars = [
-        ( …
-        ( …
-        ( …
-        ( …
-        ( …
-        ( …
+        (gen_unique_var(ctx, sum_name), sum_type),
+        (gen_unique_var(ctx, "counter"), count_type),
+        (gen_unique_var(ctx, sum_name), sum_type),
+        (gen_unique_var(ctx, "one"), count_type),
+        (gen_unique_var(ctx, "sum"), sum_type),
+        (gen_unique_var(ctx, "count"), count_type),
     ]

     x1 = vars[0][0]

@@ -233,10 +233,10 @@ lqp_avg_op(names: UniqueNames, op_name: str, sum_name: str, sum_type: lqp.Ty
     return lqp.Abstraction(vars=vars, value=body, meta=None)

 # Default handler for aggregation operations in LQP.
-def lqp_agg_op( …
-    x = …
-    y = …
-    z = …
+def lqp_agg_op(ctx: TranslationCtx, op_name: str, aggr_arg_name: str, aggr_arg_type: lqp.Type) -> lqp.Abstraction:
+    x = gen_unique_var(ctx, f"x_{aggr_arg_name}")
+    y = gen_unique_var(ctx, f"y_{aggr_arg_name}")
+    z = gen_unique_var(ctx, f"z_{aggr_arg_name}")
     ts = [(x, aggr_arg_type), (y, aggr_arg_type), (z, aggr_arg_type)]

     name = agg_to_lqp.get(op_name, op_name)

@@ -244,9 +244,9 @@ def lqp_agg_op(names: UniqueNames, op_name: str, aggr_arg_name: str, aggr_arg_ty

     return lqp.Abstraction(vars=ts, value=body, meta=None)

-def lqp_operator( …
+def lqp_operator(ctx: TranslationCtx, op_name: str, aggr_arg_name: str, aggr_arg_type: lqp.Type) -> lqp.Abstraction:
     # TODO: Can we just pass through unknown operations?
     if op_name not in agg_to_lqp:
         raise NotImplementedError(f"Unsupported aggregation: {op_name}")

-    return lqp_agg_op( …
+    return lqp_agg_op(ctx, op_name, aggr_arg_name, aggr_arg_type)
@@ -1,7 +1,7 @@
 from .cdc import CDC
 from .extract_common import ExtractCommon
 from .extract_keys import ExtractKeys
-from . …
+from .function_annotations import FunctionAnnotations
 from .quantify_vars import QuantifyVars
 from .splinter import Splinter

@@ -9,7 +9,7 @@ __all__ = [
     "CDC",
     "ExtractCommon",
     "ExtractKeys",
-    " …
+    "FunctionAnnotations",
     "QuantifyVars",
     "Splinter",
 ]

@@ -5,7 +5,7 @@ from relationalai.semantics.metamodel import ir, compiler as c, visitor as v, bu
 from relationalai.semantics.metamodel.util import OrderedSet, ordered_set


-class …
+class FunctionAnnotations(c.Pass):
     """
     Pass marks all appropriate relations with `function` annotation.
     Criteria:

@@ -17,7 +17,7 @@ class FDConstraints(c.Pass):
         collect_fd = CollectFunctionalRelationsVisitor()
         new_model = collect_fd.walk(model)
         # mark relations collected by previous visitor with `@function` annotation
-        return …
+        return FunctionalAnnotationsVisitor(collect_fd.functional_relations).walk(new_model)


 @dataclass

@@ -57,9 +57,9 @@ class CollectFunctionalRelationsVisitor(v.Rewriter):


 @dataclass
-class …
+class FunctionalAnnotationsVisitor(v.Rewriter):
     """
-    This visitor marks functional_relations with ` …
+    This visitor marks functional_relations with `function` annotation.
     """

     def __init__(self, functional_relations: OrderedSet):