relationalai 0.12.6__py3-none-any.whl → 0.12.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. relationalai/clients/snowflake.py +48 -7
  2. relationalai/clients/use_index_poller.py +11 -1
  3. relationalai/early_access/lqp/constructors/__init__.py +2 -2
  4. relationalai/early_access/metamodel/rewrite/__init__.py +2 -2
  5. relationalai/semantics/internal/internal.py +1 -4
  6. relationalai/semantics/internal/snowflake.py +14 -1
  7. relationalai/semantics/lqp/constructors.py +0 -5
  8. relationalai/semantics/lqp/executor.py +34 -10
  9. relationalai/semantics/lqp/intrinsics.py +2 -2
  10. relationalai/semantics/lqp/model2lqp.py +10 -7
  11. relationalai/semantics/lqp/passes.py +29 -9
  12. relationalai/semantics/lqp/primitives.py +15 -15
  13. relationalai/semantics/lqp/rewrite/__init__.py +2 -2
  14. relationalai/semantics/lqp/rewrite/{fd_constraints.py → function_annotations.py} +4 -4
  15. relationalai/semantics/lqp/utils.py +17 -13
  16. relationalai/semantics/metamodel/builtins.py +1 -0
  17. relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
  18. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +1 -1
  19. relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +5 -6
  20. relationalai/semantics/metamodel/rewrite/flatten.py +18 -149
  21. relationalai/semantics/metamodel/rewrite/format_outputs.py +165 -0
  22. relationalai/semantics/reasoners/graph/core.py +98 -70
  23. relationalai/semantics/reasoners/optimization/__init__.py +55 -10
  24. relationalai/semantics/reasoners/optimization/common.py +63 -8
  25. relationalai/semantics/reasoners/optimization/solvers_dev.py +39 -33
  26. relationalai/semantics/reasoners/optimization/solvers_pb.py +1033 -385
  27. relationalai/semantics/rel/compiler.py +4 -3
  28. relationalai/semantics/rel/executor.py +30 -8
  29. relationalai/semantics/snowflake/__init__.py +2 -2
  30. relationalai/semantics/sql/executor/snowflake.py +6 -2
  31. relationalai/semantics/tests/test_snapshot_abstract.py +5 -4
  32. relationalai/tools/cli.py +10 -0
  33. relationalai/tools/cli_controls.py +15 -0
  34. {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/METADATA +2 -2
  35. {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/RECORD +38 -37
  36. {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/WHEEL +0 -0
  37. {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/entry_points.txt +0 -0
  38. {relationalai-0.12.6.dist-info → relationalai-0.12.8.dist-info}/licenses/LICENSE +0 -0
@@ -851,7 +851,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
          self.generation
      )
      # If cache is valid (data freshness has not expired), skip polling
-     if not poller.cache.is_valid():
+     if poller.cache.is_valid():
+         cached_sources = len(poller.cache.sources)
+         total_sources = len(sources_list)
+         cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+         message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+         if cached_timestamp:
+             print(f"\n{message} (cached at {cached_timestamp})\n")
+         else:
+             print(f"\n{message}\n")
+     else:
          return poller.poll()

      #--------------------------------------------------
@@ -2842,6 +2852,15 @@ class SnowflakeTable(dsl.Type):
      else:
          self(snowflake_id=id).set(**{prop_ident: val})

+     # Because we're bypassing a bunch of the normal Type.add machinery here,
+     # we need to manually account for the case where people are using value types.
+     def wrapped(x):
+         if not model._config.get("compiler.use_value_types", False):
+             return x
+         other_id = dsl.create_var()
+         model._action(dsl.build.construct(self._type, [x, other_id]))
+         return other_id
+
      # new UInt128 schema mapping rules
      with model.rule(dynamic=True, globalize=True, source=self._source):
          id = dsl.create_var()
@@ -2851,7 +2870,7 @@ class SnowflakeTable(dsl.Type):
          # for avoiding a non-blocking warning
          edb(dsl.Symbol("METADATA$KEY"), id)
          std.rel.UInt128(id)
-         self.add(id, snowflake_id=id)
+         self.add(wrapped(id), snowflake_id=id)

      for prop, prop_type in self._schema["columns"].items():
          _prop = prop
@@ -2873,7 +2892,7 @@ class SnowflakeTable(dsl.Type):
          model._check_property(_prop._prop)
          raw_relation = getattr(std.rel, prop_ident)
          dsl.tag(raw_relation, dsl.Builtins.FunctionAnnotation)
-         raw_relation.add(id, val)
+         raw_relation.add(wrapped(id), val)

      def namespace(self):
          return f"{self._parent._parent._name}.{self._parent._name}"
@@ -3275,12 +3294,24 @@ class DirectAccessResources(Resources):
      try:
          response = _send_request()
          if response.status_code != 200:
+             # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
+             # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
+             # For skip_auto_create=False, continue to auto-creation logic below
+             if response.status_code == 404 and skip_auto_create:
+                 return response
+
              try:
                  message = response.json().get("message", "")
              except requests.exceptions.JSONDecodeError:
-                 raise ResponseStatusException(
-                     f"Failed to parse error response from endpoint {endpoint}.", response
-                 )
+                 # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
+                 # this should have been caught by the 404 check above, so this is an error.
+                 # For skip_auto_create=False, we explicitly check status_code below,
+                 # so we don't need to parse the message.
+                 if skip_auto_create:
+                     raise ResponseStatusException(
+                         f"Failed to parse error response from endpoint {endpoint}.", response
+                     )
+                 message = ""  # Not used when we check status_code directly

              # fix engine on engine error and retry
              # Skip auto-retry if skip_auto_create is True to avoid recursion
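The early `return response` on 404 only fires when `skip_auto_create=True`, letting the caller treat a missing resource as "not found" rather than as a hard error or a trigger for auto-creation. A minimal sketch of that calling pattern (illustrative only; `_request_engine` and `get_engine_or_none` are hypothetical names, not part of this diff):

```python
def get_engine_or_none(resources, name: str):
    # Hypothetical helper: issues the GET with skip_auto_create=True so a 404
    # response is returned to the caller instead of triggering auto-create/retry.
    response = resources._request_engine(name, skip_auto_create=True)
    if response.status_code == 404:
        return None  # engine does not exist; the caller may decide to auto-create it
    response.raise_for_status()
    return response.json()
```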
@@ -3473,7 +3504,17 @@ class DirectAccessResources(Resources):
          generation=self.generation,
      )
      # If cache is valid (data freshness has not expired), skip polling
-     if not poller.cache.is_valid():
+     if poller.cache.is_valid():
+         cached_sources = len(poller.cache.sources)
+         total_sources = len(sources_list)
+         cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+         message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+         if cached_timestamp:
+             print(f"\n{message} (cached at {cached_timestamp})\n")
+         else:
+             print(f"\n{message}\n")
+     else:
          return poller.poll()

      def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
@@ -250,7 +250,17 @@ class UseIndexPoller:
      # Cache was used - show how many sources were cached
      total_sources = len(self.cache.sources)
      cached_sources = total_sources - len(self.sources)
-     progress.add_sub_task(f"Using cached data for {cached_sources}/{total_sources} data streams", task_id="cache_usage", category=TASK_CATEGORY_CACHE)
+
+     # Get the timestamp when sources were cached
+     entry = self.cache._metadata.get("cachedIndices", {}).get(self.cache.key, {})
+     cached_timestamp = entry.get("last_use_index_update_on", "")
+
+     message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+     # Format the message with timestamp
+     if cached_timestamp:
+         message += f" (cached at {cached_timestamp})"
+
+     progress.add_sub_task(message, task_id="cache_usage", category=TASK_CATEGORY_CACHE)
      # Complete the subtask immediately since it's just informational
      progress.complete_sub_task("cache_usage")

@@ -2,12 +2,12 @@ import warnings

  from relationalai.semantics.lqp.constructors import (
      mk_abstraction, mk_and, mk_exists, mk_or, mk_pragma, mk_primitive,
-     mk_specialized_value, mk_type, mk_value, mk_var
+     mk_specialized_value, mk_type, mk_value
  )

  __all__ = [
      "mk_abstraction", "mk_and", "mk_exists", "mk_or", "mk_pragma", "mk_primitive", "mk_specialized_value", "mk_type",
-     "mk_value", "mk_var"
+     "mk_value"
  ]

  warnings.warn(
@@ -1,7 +1,7 @@
  from relationalai.semantics.metamodel.rewrite import Flatten, \
      DNFUnionSplitter, ExtractNestedLogicals, flatten
  from relationalai.semantics.lqp.rewrite import Splinter, \
-     ExtractKeys, FDConstraints
+     ExtractKeys, FunctionAnnotations

  __all__ = ["Splinter", "Flatten", "DNFUnionSplitter", "ExtractKeys",
-     "ExtractNestedLogicals", "FDConstraints", "flatten"]
+     "ExtractNestedLogicals", "FunctionAnnotations", "flatten"]
@@ -2550,8 +2550,6 @@ class Fragment():
      from .snowflake import Table
      assert isinstance(table, Table), "Only Snowflake tables are supported for now"

-     result_cols = table._col_names
-
      clone = Fragment(parent=self)
      clone._is_export = True
      qb_model = clone._model or Model("anon")
@@ -2559,8 +2557,7 @@ class Fragment():
      clone._source = runtime_env.get_source_pos()
      with debugging.span("query", dsl=str(clone), **with_source(clone), meta=clone._meta):
          query_task = qb_model._compiler.fragment(clone)
-         qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update, meta=clone._meta)
-
+         qb_model._to_executor().execute(ir_model, query_task, export_to=table, update=update, meta=clone._meta)

      #--------------------------------------------------
      # Select / Where
@@ -12,7 +12,18 @@ from . import internal as b, annotations as anns
  from relationalai import debugging
  from relationalai.errors import UnsupportedColumnTypesWarning
  from snowflake.snowpark.context import get_active_session
+ from typing import ClassVar, Optional

+ #--------------------------------------------------
+ # Iceberg Configuration
+ #--------------------------------------------------
+ @dataclass
+ class IcebergConfig:
+     """Configuration for exporting to Iceberg tables."""
+     external_volume: str | None = None
+     default: ClassVar[Optional["IcebergConfig"]]
+
+ IcebergConfig.default = IcebergConfig()
  #--------------------------------------------------
  # Helpers
  #--------------------------------------------------
@@ -191,7 +202,7 @@ class Table():
      _schemas:dict[tuple[str, str], SchemaInfo] = {}
      _used_sources:OrderedSet[Table] = ordered_set()

-     def __init__(self, fqn:str, cols:list[str]|None=None, schema:dict[str, str|b.Concept]|None=None) -> None:
+     def __init__(self, fqn:str, cols:list[str]|None=None, schema:dict[str, str|b.Concept]|None=None, config: IcebergConfig|None=None) -> None:
          self._fqn = fqn
          parser = IdentityParser(fqn, require_all_parts=True)
          self._database, self._schema, self._table, self._fqn = parser.to_list()
@@ -201,6 +212,8 @@ class Table():
          self._ref = self._concept.ref("row_id")
          self._cols = {}
          self._col_names = cols
+         self._iceberg_config = config
+         self._is_iceberg = config is not None
          info = self._schemas.get((self._database, self._schema))
          if not info:
              info = self._schemas[(self._database, self._schema)] = SchemaInfo(self._database, self._schema)
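The new `config` parameter threads an `IcebergConfig` through to the export path (see the `LQPExecutor._export` changes below), so an export destination can be created as an Iceberg table. A minimal usage sketch, assuming `Table` and `IcebergConfig` are importable from the modules shown in this diff; the FQN and volume name are placeholders:

```python
from relationalai.semantics.snowflake import Table  # re-export path assumed from this diff
from relationalai.semantics.internal.snowflake import IcebergConfig  # defined in this diff

# Export destination backed by an Iceberg table on an external volume (placeholder names).
dest = Table(
    "MY_DB.MY_SCHEMA.MY_EXPORT",
    config=IcebergConfig(external_volume="my_external_volume"),
)
# dest._is_iceberg is True, so the executor emits CREATE ICEBERG TABLE instead of CREATE TABLE.
```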
@@ -1,6 +1,5 @@
  from typing import Tuple
  from relationalai.semantics.lqp import ir as lqp
- from relationalai.semantics.metamodel.ir import sanitize

  def mk_and(args: list[lqp.Formula]) -> lqp.Formula:
      # Flatten nested conjunctions
@@ -49,10 +48,6 @@ def mk_specialized_value(value) -> lqp.SpecializedValue:
  def mk_value(value) -> lqp.Value:
      return lqp.Value(value=value, meta=None)

- def mk_var(name: str) -> lqp.Var:
-     _name = '_' if name == '_' else sanitize(name)
-     return lqp.Var(name=_name, meta=None)
-
  def mk_type(typename: lqp.TypeName, parameters: list[lqp.Value]=[]) -> lqp.Type:
      return lqp.Type(type_name=typename, parameters=parameters, meta=None)

@@ -4,7 +4,7 @@ import atexit
  import re

  from pandas import DataFrame
- from typing import Any, Optional, Literal
+ from typing import Any, Optional, Literal, TYPE_CHECKING
  from snowflake.snowpark import Session
  import relationalai as rai

@@ -20,10 +20,14 @@ from relationalai.semantics.lqp.ir import convert_transaction, validate_lqp
  from relationalai.clients.config import Config
  from relationalai.clients.snowflake import APP_NAME
  from relationalai.clients.types import TransactionAsyncResponse
- from relationalai.clients.util import IdentityParser
+ from relationalai.clients.util import IdentityParser, escape_for_f_string
  from relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
  from relationalai.tools.query_utils import prepare_metadata_for_headers

+ if TYPE_CHECKING:
+     from relationalai.semantics.snowflake import Table
+
+
  class LQPExecutor(e.Executor):
      """Executes LQP using the RAI client."""

@@ -172,12 +176,12 @@ class LQPExecutor(e.Executor):
      elif len(all_errors) > 1:
          raise errors.RAIExceptionSet(all_errors)

-     def _export(self, txn_id: str, export_info: tuple, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool):
+     def _export(self, txn_id: str, export_info: tuple, dest: Table, actual_cols: list[str], declared_cols: list[str], update: bool):
          # At this point of the export, we assume that a CSV file has already been written
          # to the Snowflake Native App stage area. Thus, the purpose of this method is to
          # copy the data from the CSV file to the destination table.
          _exec = self.resources._exec
-         dest_database, dest_schema, dest_table, _ = IdentityParser(dest_fqn, require_all_parts=True).to_list()
+         dest_database, dest_schema, dest_table, _ = IdentityParser(dest._fqn, require_all_parts=True).to_list()
          filename = export_info[0]
          result_table_name = filename + "_table"

@@ -203,8 +207,28 @@ class LQPExecutor(e.Executor):
          # destination table. This step also cleans up the result table.
          out_sample = _exec(f"select * from {APP_NAME}.results.{result_table_name} limit 1;")
          names = self._build_projection(declared_cols, actual_cols, column_fields, out_sample)
+         dest_fqn = dest._fqn
          try:
              if not update:
+                 createTableLogic = f"""
+                     CREATE TABLE {dest_fqn} AS
+                     SELECT {names}
+                     FROM {APP_NAME}.results.{result_table_name};
+                 """
+                 if dest._is_iceberg:
+                     assert dest._iceberg_config is not None
+                     external_volume_clause = ""
+                     if dest._iceberg_config.external_volume:
+                         external_volume_clause = f"EXTERNAL_VOLUME = '{dest._iceberg_config.external_volume}'"
+                     createTableLogic = f"""
+                         CREATE ICEBERG TABLE {dest_fqn}
+                         CATALOG = "SNOWFLAKE"
+                         {external_volume_clause}
+                         AS
+                         SELECT {names}
+                         FROM {APP_NAME}.results.{result_table_name};
+                     """
+
                  _exec(f"""
                      BEGIN
                      -- Check if table exists
@@ -227,9 +251,7 @@ class LQPExecutor(e.Executor):
                      ELSE
                          -- Create table based on the SELECT
                          EXECUTE IMMEDIATE '
-                             CREATE TABLE {dest_fqn} AS
-                             SELECT {names}
-                             FROM {APP_NAME}.results.{result_table_name};
+                             {escape_for_f_string(createTableLogic)}
                          ';
                      END IF;
                  END;
@@ -376,7 +398,7 @@ class LQPExecutor(e.Executor):
          return final_model, export_info, txn_proto

      # TODO (azreika): This should probably be split up into exporting and other processing. There are quite a lot of arguments here...
-     def _process_results(self, task: ir.Task, final_model: ir.Model, raw_results: TransactionAsyncResponse, result_cols: Optional[list[str]], export_info: Optional[tuple], export_to: Optional[str], update: bool) -> DataFrame:
+     def _process_results(self, task: ir.Task, final_model: ir.Model, raw_results: TransactionAsyncResponse, export_info: Optional[tuple], export_to: Optional[Table], update: bool) -> DataFrame:
          cols, extra_cols = self._compute_cols(task, final_model)

          df, errs = result_helpers.format_results(raw_results, cols)
@@ -391,6 +413,8 @@ class LQPExecutor(e.Executor):
          assert cols, "No columns found in the output"
          assert isinstance(raw_results, TransactionAsyncResponse) and raw_results.transaction, "Invalid transaction result"

+         result_cols = export_to._col_names
+
          if result_cols is not None:
              assert all(col in result_cols or col in extra_cols for col in cols)
          else:
@@ -403,7 +427,7 @@ class LQPExecutor(e.Executor):
          return self._postprocess_df(self.config, df, extra_cols)

      def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
-                 result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
+                 export_to: Optional[Table] = None,
                  update: bool = False, meta: dict[str, Any] | None = None) -> DataFrame:
          self.prepare_data()
          previous_model = self._last_model
@@ -433,7 +457,7 @@ class LQPExecutor(e.Executor):
          assert isinstance(raw_results, TransactionAsyncResponse)

          try:
-             return self._process_results(task, final_model, raw_results, result_cols, export_info, export_to, update)
+             return self._process_results(task, final_model, raw_results, export_info, export_to, update)
          except Exception as e:
              # If processing the results failed, revert to the previous model.
              self._last_model = previous_model
@@ -1,13 +1,13 @@
  from datetime import datetime, timezone

  from relationalai.semantics.lqp import ir as lqp
- from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, mk_var, mk_type, mk_primitive
+ from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, mk_type, mk_primitive
  from relationalai.semantics.lqp.utils import lqp_hash

  def mk_intrinsic_datetime_now() -> lqp.Def:
      """Constructs a definition of the current datetime."""
      id = lqp_hash("__pyrel_lqp_intrinsic_datetime_now")
-     out = mk_var("out")
+     out = lqp.Var(name="out", meta=None)
      out_type = mk_type(lqp.TypeName.DATETIME)
      now = mk_value(lqp.DateTimeValue(value=datetime.now(timezone.utc), meta=None))
      datetime_now = mk_abstraction(
@@ -7,7 +7,7 @@ from relationalai.semantics.lqp.pragmas import pragma_to_lqp_name
  from relationalai.semantics.lqp.types import meta_type_to_lqp
  from relationalai.semantics.lqp.constructors import (
      mk_abstraction, mk_and, mk_exists, mk_or, mk_pragma, mk_primitive,
-     mk_specialized_value, mk_type, mk_value, mk_var, mk_attribute
+     mk_specialized_value, mk_type, mk_value, mk_attribute
  )
  from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var
  from relationalai.semantics.lqp.validators import assert_valid_input
@@ -253,7 +253,7 @@ def _translate_rank(ctx: TranslationCtx, rank: ir.Rank, body: lqp.Formula) -> lq
      # to convert it to Int128.
      result_var, _ = _translate_term(ctx, rank.result)
      # The primitive will return an Int64 result, so we need a var to hold the intermediary.
-     result_64_var = gen_unique_var(ctx, "rank_result_64")
+     result_64_var = gen_unique_var(ctx, "v_rank")
      result_64_type = mk_type(lqp.TypeName.INT)

      cast = lqp.Cast(input=result_64_var, result=result_var, meta=None)
@@ -306,7 +306,7 @@ def _translate_descending_rank(ctx: TranslationCtx, limit: int, result: lqp.Var,
      aggr_abstr_args = new_abstr_args + [(count_var, count_type)]
      count_aggr = lqp.Reduce(
          op=lqp_operator(
-             ctx.var_names,
+             ctx,
              "count",
              "count",
              mk_type(lqp.TypeName.INT)
@@ -431,7 +431,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
          (sum_var, sum_type) = abstr_args[-2]

          result = lqp.Reduce(
-             op=lqp_avg_op(ctx.var_names, aggr.aggregation.name, sum_var.name, sum_type),
+             op=lqp_avg_op(ctx, aggr.aggregation.name, sum_var.name, sum_type),
              body=mk_abstraction(abstr_args, body),
              terms=[sum_result, count_result],
              meta=None,
@@ -464,7 +464,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form

          # Group-bys do not need to be handled at all, since they are introduced outside already
          reduce = lqp.Reduce(
-             op=lqp_operator(ctx.var_names, aggr.aggregation.name, aggr_arg.name, aggr_arg_type),
+             op=lqp_operator(ctx, aggr.aggregation.name, aggr_arg.name, aggr_arg_type),
              body=mk_abstraction(abstr_args, body),
              terms=output_vars,
              meta=None
@@ -523,9 +523,8 @@ def _translate_term(ctx: TranslationCtx, term: ir.Value) -> Tuple[lqp.Term, lqp.
          # TODO: ScalarType is not like other terms, should be handled separately.
          return to_lqp_value(term.name, types.String), meta_type_to_lqp(types.String)
      elif isinstance(term, ir.Var):
-         name = ctx.var_names.get_name_by_id(term.id, term.name)
          t = meta_type_to_lqp(term.type)
-         return mk_var(name), t
+         return _translate_var(ctx, term), t
      else:
          assert isinstance(term, ir.Literal), f"Cannot translate value {term!r} of type {type(term)} to LQP Term; neither Var nor Literal."
          v = to_lqp_value(term.value, term.type)
@@ -801,3 +800,7 @@ def _translate_join(ctx: TranslationCtx, task: ir.Lookup) -> lqp.Formula:
      output_term = _translate_term(ctx, target)[0]

      return lqp.Reduce(meta=None, op=op, body=body, terms=[output_term])
+
+ def _translate_var(ctx: TranslationCtx, term: ir.Var) -> lqp.Var:
+     name = ctx.var_names.get_name_by_id(term.id, term.name)
+     return lqp.Var(name=name, meta=None)
@@ -6,17 +6,19 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet

  from relationalai.semantics.metamodel.rewrite import Flatten

- from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals
- from .rewrite import CDC, ExtractCommon, ExtractKeys, FDConstraints, QuantifyVars, Splinter
+ from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
+ from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter

  from relationalai.semantics.lqp.utils import output_names

  from typing import cast, List, Sequence, Tuple, Union, Optional, Iterable
  from collections import defaultdict
+ import pandas as pd
+ import hashlib

  def lqp_passes() -> list[Pass]:
      return [
-         FDConstraints(),
+         FunctionAnnotations(),
          DischargeConstraints(),
          Checker(),
          CDC(), # specialize to physical relations before extracting nested and typing
@@ -25,6 +27,7 @@ def lqp_passes() -> list[Pass]:
          DNFUnionSplitter(),
          ExtractKeys(),
          ExtractCommon(),
+         FormatOutputs(),
          Flatten(),
          Splinter(), # Splits multi-headed rules into multiple rules
          QuantifyVars(), # Adds missing existentials
@@ -337,7 +340,7 @@ class UnifyDefinitions(Pass):
      )

  # Creates intermediary relations for all Data nodes and replaces said Data nodes
- # with a Lookup into these created relations.
+ # with a Lookup into these created relations. Reuse duplicate created relations.
  class EliminateData(Pass):
      def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
          r = self.DataRewriter()
@@ -350,17 +353,25 @@ class EliminateData(Pass):
          # Counter for naming new relations.
          # It must be that new_count == len new_updates == len new_relations.
          new_count: int
+         # Cache for Data nodes to avoid creating duplicate intermediary relations
+         data_cache: dict[str, ir.Relation]

          def __init__(self):
              self.new_relations = []
              self.new_updates = []
              self.new_count = 0
+             self.data_cache = {}
              super().__init__()

-         # Create a new intermediary relation representing the Data (and pop it in
-         # new_updates/new_relations) and replace this Data with a Lookup of said
-         # intermediary.
-         def handle_data(self, node: ir.Data, parent: ir.Node) -> ir.Lookup:
+         # Create a cache key for a Data node based on its structure and content
+         def _data_cache_key(self, node: ir.Data) -> str:
+             values = pd.util.hash_pandas_object(node.data).values
+             return hashlib.sha256(bytes(values)).hexdigest()
+
+         def _intermediary_relation(self, node: ir.Data) -> ir.Relation:
+             cache_key = self._data_cache_key(node)
+             if cache_key in self.data_cache:
+                 return self.data_cache[cache_key]
              self.new_count += 1
              intermediary_name = f"formerly_Data_{self.new_count}"

@@ -379,7 +390,6 @@ class EliminateData(Pass):
                          f.lookup(rel_builtins.eq, [f.literal(val), var])
                          for (val, var) in zip(row, node.vars)
                      ],
-                     hoisted = node.vars,
                  )
                  for row in node
              ],
@@ -390,6 +400,16 @@ class EliminateData(Pass):
              ])
              self.new_updates.append(intermediary_update)

+             # Cache the result for reuse
+             self.data_cache[cache_key] = intermediary_relation
+
+             return intermediary_relation
+
+         # Create a new intermediary relation representing the Data (and pop it in
+         # new_updates/new_relations) and replace this Data with a Lookup of said
+         # intermediary.
+         def handle_data(self, node: ir.Data, parent: ir.Node) -> ir.Lookup:
+             intermediary_relation = self._intermediary_relation(node)
              replacement_lookup = f.lookup(intermediary_relation, node.vars)

              return replacement_lookup
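The cache key for a `Data` node is a content hash of its rows, so two literal datasets with identical contents are backed by a single intermediary relation instead of two. A standalone sketch of that hashing scheme using plain pandas, outside the pass machinery:

```python
import hashlib
import pandas as pd

def dataframe_cache_key(df: pd.DataFrame) -> str:
    # pd.util.hash_pandas_object yields one uint64 per row; hashing those bytes
    # gives a stable key for DataFrames with identical row contents.
    row_hashes = pd.util.hash_pandas_object(df).values
    return hashlib.sha256(bytes(row_hashes)).hexdigest()

a = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})
b = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})
# Identical data hashes to the same key, so EliminateData would reuse one relation.
assert dataframe_cache_key(a) == dataframe_cache_key(b)
```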
@@ -1,8 +1,8 @@
  from relationalai.semantics.metamodel.types import digits_to_bits
  from relationalai.semantics.lqp import ir as lqp
  from relationalai.semantics.lqp.types import is_numeric
- from relationalai.semantics.lqp.utils import UniqueNames, lqp_hash
- from relationalai.semantics.lqp.constructors import mk_primitive, mk_specialized_value, mk_type, mk_value, mk_var
+ from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var, lqp_hash
+ from relationalai.semantics.lqp.constructors import mk_primitive, mk_specialized_value, mk_type, mk_value

  rel_to_lqp = {
      "=": "rel_primitive_eq",
@@ -205,15 +205,15 @@ def is_monotype(name: str) -> bool:

  # We take the name and type of the variable that we're summing over, so that we can generate
  # recognizable names for the variables in the reduce operation and preserve the type.
- def lqp_avg_op(names: UniqueNames, op_name: str, sum_name: str, sum_type: lqp.Type) -> lqp.Abstraction:
+ def lqp_avg_op(ctx: TranslationCtx, op_name: str, sum_name: str, sum_type: lqp.Type) -> lqp.Abstraction:
      count_type = mk_type(lqp.TypeName.INT)
      vars = [
-         (mk_var(names.get_name(sum_name)), sum_type),
-         (mk_var(names.get_name("counter")), count_type),
-         (mk_var(names.get_name(sum_name)), sum_type),
-         (mk_var(names.get_name("one")), count_type),
-         (mk_var(names.get_name("sum")), sum_type),
-         (mk_var(names.get_name("count")), count_type),
+         (gen_unique_var(ctx, sum_name), sum_type),
+         (gen_unique_var(ctx, "counter"), count_type),
+         (gen_unique_var(ctx, sum_name), sum_type),
+         (gen_unique_var(ctx, "one"), count_type),
+         (gen_unique_var(ctx, "sum"), sum_type),
+         (gen_unique_var(ctx, "count"), count_type),
      ]

      x1 = vars[0][0]
@@ -233,10 +233,10 @@ def lqp_avg_op(names: UniqueNames, op_name: str, sum_name: str, sum_type: lqp.Ty
      return lqp.Abstraction(vars=vars, value=body, meta=None)

  # Default handler for aggregation operations in LQP.
- def lqp_agg_op(names: UniqueNames, op_name: str, aggr_arg_name: str, aggr_arg_type: lqp.Type) -> lqp.Abstraction:
-     x = mk_var(names.get_name(f"x_{aggr_arg_name}"))
-     y = mk_var(names.get_name(f"y_{aggr_arg_name}"))
-     z = mk_var(names.get_name(f"z_{aggr_arg_name}"))
+ def lqp_agg_op(ctx: TranslationCtx, op_name: str, aggr_arg_name: str, aggr_arg_type: lqp.Type) -> lqp.Abstraction:
+     x = gen_unique_var(ctx, f"x_{aggr_arg_name}")
+     y = gen_unique_var(ctx, f"y_{aggr_arg_name}")
+     z = gen_unique_var(ctx, f"z_{aggr_arg_name}")
      ts = [(x, aggr_arg_type), (y, aggr_arg_type), (z, aggr_arg_type)]

      name = agg_to_lqp.get(op_name, op_name)
@@ -244,9 +244,9 @@ def lqp_agg_op(names: UniqueNames, op_name: str, aggr_arg_name: str, aggr_arg_ty

      return lqp.Abstraction(vars=ts, value=body, meta=None)

- def lqp_operator(names: UniqueNames, op_name: str, aggr_arg_name: str, aggr_arg_type: lqp.Type) -> lqp.Abstraction:
+ def lqp_operator(ctx: TranslationCtx, op_name: str, aggr_arg_name: str, aggr_arg_type: lqp.Type) -> lqp.Abstraction:
      # TODO: Can we just pass through unknown operations?
      if op_name not in agg_to_lqp:
          raise NotImplementedError(f"Unsupported aggregation: {op_name}")

-     return lqp_agg_op(names, op_name, aggr_arg_name, aggr_arg_type)
+     return lqp_agg_op(ctx, op_name, aggr_arg_name, aggr_arg_type)
@@ -1,7 +1,7 @@
  from .cdc import CDC
  from .extract_common import ExtractCommon
  from .extract_keys import ExtractKeys
- from .fd_constraints import FDConstraints
+ from .function_annotations import FunctionAnnotations
  from .quantify_vars import QuantifyVars
  from .splinter import Splinter

@@ -9,7 +9,7 @@ __all__ = [
      "CDC",
      "ExtractCommon",
      "ExtractKeys",
-     "FDConstraints",
+     "FunctionAnnotations",
      "QuantifyVars",
      "Splinter",
  ]
@@ -5,7 +5,7 @@ from relationalai.semantics.metamodel import ir, compiler as c, visitor as v, bu
  from relationalai.semantics.metamodel.util import OrderedSet, ordered_set


- class FDConstraints(c.Pass):
+ class FunctionAnnotations(c.Pass):
      """
      Pass marks all appropriate relations with `function` annotation.
      Criteria:
@@ -17,7 +17,7 @@ class FDConstraints(c.Pass):
          collect_fd = CollectFunctionalRelationsVisitor()
          new_model = collect_fd.walk(model)
          # mark relations collected by previous visitor with `@function` annotation
-         return FDConstraintsVisitor(collect_fd.functional_relations).walk(new_model)
+         return FunctionalAnnotationsVisitor(collect_fd.functional_relations).walk(new_model)


  @dataclass
@@ -57,9 +57,9 @@ class CollectFunctionalRelationsVisitor(v.Rewriter):


  @dataclass
- class FDConstraintsVisitor(v.Rewriter):
+ class FunctionalAnnotationsVisitor(v.Rewriter):
      """
-     This visitor marks functional_relations with `functional` annotation.
+     This visitor marks functional_relations with `function` annotation.
      """

      def __init__(self, functional_relations: OrderedSet):