relationalai 0.11.4__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. relationalai/clients/config.py +7 -0
  2. relationalai/clients/direct_access_client.py +113 -0
  3. relationalai/clients/snowflake.py +263 -189
  4. relationalai/clients/types.py +4 -1
  5. relationalai/clients/use_index_poller.py +72 -48
  6. relationalai/clients/util.py +9 -0
  7. relationalai/dsl.py +1 -2
  8. relationalai/early_access/metamodel/rewrite/__init__.py +5 -3
  9. relationalai/early_access/rel/rewrite/__init__.py +1 -1
  10. relationalai/environments/snowbook.py +10 -1
  11. relationalai/errors.py +24 -3
  12. relationalai/semantics/internal/annotations.py +1 -0
  13. relationalai/semantics/internal/internal.py +22 -3
  14. relationalai/semantics/lqp/builtins.py +1 -0
  15. relationalai/semantics/lqp/executor.py +12 -4
  16. relationalai/semantics/lqp/model2lqp.py +1 -0
  17. relationalai/semantics/lqp/passes.py +3 -4
  18. relationalai/semantics/{rel → lqp}/rewrite/__init__.py +6 -0
  19. relationalai/semantics/metamodel/builtins.py +12 -1
  20. relationalai/semantics/metamodel/executor.py +2 -1
  21. relationalai/semantics/metamodel/rewrite/__init__.py +3 -9
  22. relationalai/semantics/metamodel/rewrite/flatten.py +8 -7
  23. relationalai/semantics/reasoners/graph/core.py +1356 -258
  24. relationalai/semantics/rel/builtins.py +5 -1
  25. relationalai/semantics/rel/compiler.py +3 -3
  26. relationalai/semantics/rel/executor.py +20 -11
  27. relationalai/semantics/sql/compiler.py +2 -3
  28. relationalai/semantics/sql/executor/duck_db.py +8 -4
  29. relationalai/semantics/sql/executor/snowflake.py +1 -1
  30. relationalai/tools/cli.py +17 -6
  31. relationalai/tools/cli_controls.py +334 -352
  32. relationalai/tools/constants.py +1 -0
  33. relationalai/tools/query_utils.py +27 -0
  34. relationalai/util/otel_configuration.py +1 -1
  35. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/METADATA +5 -4
  36. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/RECORD +45 -45
  37. relationalai/semantics/metamodel/rewrite/gc_nodes.py +0 -58
  38. relationalai/semantics/metamodel/rewrite/list_types.py +0 -109
  39. /relationalai/semantics/{rel → lqp}/rewrite/cdc.py +0 -0
  40. /relationalai/semantics/{rel → lqp}/rewrite/extract_common.py +0 -0
  41. /relationalai/semantics/{metamodel → lqp}/rewrite/extract_keys.py +0 -0
  42. /relationalai/semantics/{metamodel → lqp}/rewrite/fd_constraints.py +0 -0
  43. /relationalai/semantics/{rel → lqp}/rewrite/quantify_vars.py +0 -0
  44. /relationalai/semantics/{metamodel → lqp}/rewrite/splinter.py +0 -0
  45. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/WHEEL +0 -0
  46. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/entry_points.txt +0 -0
  47. {relationalai-0.11.4.dist-info → relationalai-0.12.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from typing import Iterable, Dict, Optional, List, cast, TYPE_CHECKING
  import json
  import logging
@@ -5,7 +7,11 @@ import uuid

  from relationalai import debugging
  from relationalai.clients.cache_store import GraphIndexCache
- from relationalai.clients.util import get_pyrel_version, poll_with_specified_overhead
+ from relationalai.clients.util import (
+ get_pyrel_version,
+ normalize_datetime,
+ poll_with_specified_overhead,
+ )
  from relationalai.errors import (
  ERPNotRunningError,
  EngineProvisioningFailed,
@@ -29,6 +35,7 @@ from relationalai.tools.constants import WAIT_FOR_STREAM_SYNC, Generation
  # Set up logger for this module
  logger = logging.getLogger(__name__)

+
  try:
  from rich.console import Console
  from rich.table import Table
@@ -63,49 +70,49 @@ POLL_MAX_DELAY = 2.5 # Maximum delay between polls in seconds
  # This query calculates a hash of column metadata (name, type, precision, scale, nullable)
  # to detect if source table schema has changed since stream was created
  STREAM_COLUMN_HASH_QUERY = """
- SELECT
- FQ_OBJECT_NAME,
- SHA2(
- LISTAGG(
- value:name::VARCHAR ||
+ WITH stream_columns AS (
+ SELECT
+ fq_object_name,
+ HASH(
+ value:name::VARCHAR,
  CASE
- WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL
- THEN CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END || '(' || value:precision || ',' || value:scale || ')'
- WHEN value:precision IS NOT NULL AND value:scale IS NULL
- THEN CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END || '(0,' || value:precision || ')'
- WHEN value:length IS NOT NULL
- THEN CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END || '(' || value:length || ')'
+ WHEN value:precision IS NOT NULL AND value:scale IS NOT NULL THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(' || value:precision || ',' || value:scale || ')'
+ WHEN value:precision IS NOT NULL AND value:scale IS NULL THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(0,' || value:precision || ')'
+ WHEN value:length IS NOT NULL THEN CASE value:type::VARCHAR
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END || '(' || value:length || ')'
  ELSE CASE value:type::VARCHAR
- WHEN 'FIXED' THEN 'NUMBER'
- WHEN 'REAL' THEN 'FLOAT'
- WHEN 'TEXT' THEN 'TEXT'
- ELSE value:type::VARCHAR
- END
- END ||
- CASE WHEN value:nullable::BOOLEAN THEN 'YES' ELSE 'NO' END,
- ','
- ) WITHIN GROUP (ORDER BY value:name::VARCHAR),
- 256
- ) AS STREAM_HASH
- FROM {app_name}.api.data_streams,
- LATERAL FLATTEN(input => COLUMNS) f
- WHERE RAI_DATABASE = '{rai_database}' AND FQ_OBJECT_NAME IN ({fqn_list})
- GROUP BY FQ_OBJECT_NAME;
+ WHEN 'FIXED' THEN 'NUMBER'
+ WHEN 'REAL' THEN 'FLOAT'
+ WHEN 'TEXT' THEN 'TEXT'
+ ELSE value:type::VARCHAR
+ END
+ END,
+ IFF(value:nullable::BOOLEAN, 'YES', 'NO')
+ ) AS column_signature
+ FROM {app_name}.api.data_streams,
+ LATERAL FLATTEN(input => columns)
+ WHERE rai_database = '{rai_database}'
+ AND fq_object_name IN ({fqn_list})
+ )
+ SELECT
+ fq_object_name AS FQ_OBJECT_NAME,
+ HEX_ENCODE(HASH_AGG(column_signature)) AS STREAM_HASH
+ FROM stream_columns
+ GROUP BY fq_object_name;
  """

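For orientation, the rewritten template above is still parameterized by the `{app_name}`, `{rai_database}`, and `{fqn_list}` placeholders. A minimal sketch of how the poller might fill it in before execution; the values below are invented for illustration and are not taken from the package:

    # Hypothetical values; the real caller derives these from its configuration.
    fqns = ["DB1.SCHEMA1.ORDERS", "DB1.SCHEMA1.CUSTOMERS"]
    sql = STREAM_COLUMN_HASH_QUERY.format(
        app_name="RELATIONALAI",
        rai_database="my_rai_db",
        fqn_list=", ".join(f"'{fqn}'" for fqn in fqns),
    )
    # Each result row maps an FQ_OBJECT_NAME to a STREAM_HASH that can be compared
    # against the columns_hash recorded for the corresponding source table.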
@@ -296,9 +303,10 @@ class UseIndexPoller:
  Returns:
  List of truly stale sources that need to be deleted/recreated

- A source is truly stale if:
- - The stream doesn't exist (needs to be created), OR
- - The column hashes don't match (needs to be recreated)
+ A source is truly stale if any of the following apply:
+ - The stream doesn't exist (needs to be created)
+ - The source table was recreated after the stream (table creation timestamp is newer)
+ - The column hashes don't match (schema drift needs cleanup)
  """
  stream_hashes = self._get_stream_column_hashes(stale_sources, progress)

@@ -306,14 +314,30 @@ class UseIndexPoller:
  for source in stale_sources:
  source_hash = self.source_info[source].get("columns_hash")
  stream_hash = stream_hashes.get(source)
+ table_created_at_raw = self.source_info[source].get("table_created_at")
+ stream_created_at_raw = self.source_info[source].get("stream_created_at")
+
+ table_created_at = normalize_datetime(table_created_at_raw)
+ stream_created_at = normalize_datetime(stream_created_at_raw)
+
+ recreated_table = False
+ if table_created_at is not None and stream_created_at is not None:
+ # If the source table was recreated (new creation timestamp) but kept
+ # the same column definitions, we still need to recycle the stream so
+ # that Snowflake picks up the new table instance.
+ recreated_table = table_created_at > stream_created_at

  # Log hash comparison for debugging
  logger.debug(f"Source: {source}")
  logger.debug(f" Source table hash: {source_hash}")
  logger.debug(f" Stream hash: {stream_hash}")
  logger.debug(f" Match: {source_hash == stream_hash}")
+ if recreated_table:
+ logger.debug(" Table appears to have been recreated (table_created_at > stream_created_at)")
+ logger.debug(f" table_created_at: {table_created_at}")
+ logger.debug(f" stream_created_at: {stream_created_at}")

- if stream_hash is None or source_hash != stream_hash:
+ if stream_hash is None or source_hash != stream_hash or recreated_table:
  logger.debug(" Action: DELETE (stale)")
  truly_stale.append(source)
  else:
@@ -376,7 +400,7 @@ class UseIndexPoller:
  stale_sources = [
  source
  for source, info in self.source_info.items()
- if info["state"] == "STALE"
+ if info.get("state") == "STALE"
  ]

  if not stale_sources:
@@ -763,7 +787,7 @@ class UseIndexPoller:
  # Log the error for debugging
  logger.warning(f"Failed to enable change tracking on {fqn}: {e}")
  failed_tables.append((fqn, str(e)))
-
+
  # Handle errors based on subtask type
  if len(tables_to_process) <= MAX_INDIVIDUAL_SUBTASKS:
  # Mark the individual subtask as failed and complete it
@@ -80,6 +80,15 @@ def escape_for_f_string(code: str) -> str:
  def escape_for_sproc(code: str) -> str:
  return code.replace("$$", "\\$\\$")

+
+ def normalize_datetime(value: object) -> datetime | None:
+ """Return a timezone-aware UTC datetime or None."""
+ if not isinstance(value, datetime):
+ return None
+ if value.tzinfo is None:
+ return value.replace(tzinfo=timezone.utc)
+ return value.astimezone(timezone.utc)
+
  # @NOTE: `overhead_rate` should fall between 0.05 and 0.5 depending on how time sensitive / expensive the operation in question is.
  def poll_with_specified_overhead(
  f,
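A quick sketch of how the new `normalize_datetime` helper behaves, assuming the standard-library `datetime`/`timezone` imports the module already relies on:

    from datetime import datetime, timezone, timedelta

    normalize_datetime("2024-01-01")                 # not a datetime -> None
    normalize_datetime(datetime(2024, 1, 1, 12, 0))  # naive -> 2024-01-01 12:00:00+00:00
    aware = datetime(2024, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=2)))
    normalize_datetime(aware)                        # aware -> 2024-01-01 10:00:00+00:00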
relationalai/dsl.py CHANGED
@@ -22,6 +22,7 @@ import sys
  from pandas import DataFrame

  from relationalai.environments import runtime_env, SnowbookEnvironment
+ from relationalai.tools.constants import QUERY_ATTRIBUTES_HEADER

  from .clients.client import Client

@@ -34,9 +35,7 @@ from .errors import FilterAsValue, Errors, InvalidPropertySetException, Multiple
  #--------------------------------------------------

  RESERVED_PROPS = ["add", "set", "persist", "unpersist"]
-
  MAX_QUERY_ATTRIBUTE_LENGTH = 255
- QUERY_ATTRIBUTES_HEADER = "X-Query-Attributes"

  Value = Union[
  "Expression",
@@ -1,5 +1,7 @@
- from relationalai.semantics.metamodel.rewrite import Splinter, RewriteListTypes, GarbageCollectNodes, Flatten, \
- DNFUnionSplitter, ExtractKeys, ExtractNestedLogicals, FDConstraints, flatten
+ from relationalai.semantics.metamodel.rewrite import Flatten, \
+ DNFUnionSplitter, ExtractNestedLogicals, flatten
+ from relationalai.semantics.lqp.rewrite import Splinter, \
+ ExtractKeys, FDConstraints

- __all__ = ["Splinter", "RewriteListTypes", "GarbageCollectNodes", "Flatten", "DNFUnionSplitter", "ExtractKeys",
+ __all__ = ["Splinter", "Flatten", "DNFUnionSplitter", "ExtractKeys",
  "ExtractNestedLogicals", "FDConstraints", "flatten"]
@@ -1,4 +1,4 @@
- from relationalai.semantics.rel.rewrite import CDC, ExtractCommon, QuantifyVars
+ from relationalai.semantics.lqp.rewrite import CDC, ExtractCommon, QuantifyVars

  __all__ = [
  "CDC",
@@ -20,7 +20,16 @@ class SnowbookEnvironment(NotebookRuntimeEnvironment, SessionEnvironment):

  def __init__(self):
  super().__init__()
- self.runner = "container" if "snowflake.connector.auth" in sys.modules else "warehouse"
+ # Detect runner type based on module presence:
+ # - Warehouse runtime has '_snowflake' module
+ # - Container runtime has 'snowflake._legacy' module
+ if "_snowflake" in sys.modules:
+ self.runner = "warehouse"
+ elif "snowflake._legacy" in sys.modules:
+ self.runner = "container"
+ else:
+ # Fallback to original check
+ self.runner = "container" if "snowflake.connector.auth" in sys.modules else "warehouse"

  @classmethod
  def detect(cls):
relationalai/errors.py CHANGED
@@ -2397,17 +2397,18 @@ class UnsupportedColumnTypesWarning(RAIWarning):
  """)

  class QueryTimeoutExceededException(RAIException):
- def __init__(self, timeout_mins: int, config_file_path: str | None = None):
+ def __init__(self, timeout_mins: int, query_id: str | None = None, config_file_path: str | None = None):
  self.timeout_mins = timeout_mins
- self.message = f"Query execution time exceeded the specified timeout of {timeout_mins} minutes."
  self.name = "Query Timeout Exceeded"
+ self.message = f"Query execution time exceeded the specified timeout of {self.timeout_mins} minutes."
+ self.query_id = query_id or ""
  self.config_file_path = config_file_path or ""
  self.content = self.format_message()
  super().__init__(self.message, self.name, self.content)

  def format_message(self):
  return textwrap.dedent(f"""
- {self.message}
+ Query execution time exceeded the specified timeout of {self.timeout_mins} minutes{f' for query with ID: {self.query_id}' if self.query_id else ''}.

  Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
  """)
@@ -2432,3 +2433,23 @@ class AzureUnsupportedQueryTimeoutException(RAIException):
  Please remove the 'query_timeout_mins' from your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} when running on platform Azure.
  """)

+ class AzureLegacyDependencyMissingException(RAIException):
+ def __init__(self):
+ self.message = "The Azure platform requires the 'legacy' extras to be installed."
+ self.name = "Azure Legacy Dependency Missing"
+ self.content = self.format_message()
+ super().__init__(self.message, self.name, self.content)
+
+ def format_message(self):
+ return textwrap.dedent("""
+ The Azure platform requires the 'rai-sdk' package, which is not installed.
+
+ To use the Azure platform, please install the legacy extras:
+
+ pip install relationalai[legacy]
+
+ Or if upgrading an existing installation:
+
+ pip install --upgrade relationalai[legacy]
+ """)
+
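Relatedly, the reworked QueryTimeoutExceededException above now accepts a query ID so the formatted message can point at the offending query. A hedged usage sketch based on the new constructor signature; the ID and path below are invented:

    raise QueryTimeoutExceededException(
        timeout_mins=10,
        query_id="01b2c3d4-0000-0000-0000-a1b2c3d4e5f6",  # illustrative query ID
        config_file_path="/path/to/raiconfig.toml",       # illustrative config path
    )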
@@ -5,3 +5,4 @@ external = Relationship.builtins["external"]
  concept_population = Relationship.builtins["concept_population"]
  function = Relationship.builtins["function"]
  from_cdc = Relationship.builtins["from_cdc"]
+ track = Relationship.builtins["track"]
@@ -2338,6 +2338,7 @@ class Fragment():
  self._define.extend(parent._define)
  self._order_by.extend(parent._order_by)
  self._limit = parent._limit
+ self._meta.update(parent._meta)

  def _add_items(self, items:PySequence[Any], to_attr:list[Any]):
  # TODO: ensure that you are _either_ a select, require, or then
@@ -2416,9 +2417,26 @@ class Fragment():
  return f

  def meta(self, **kwargs: Any) -> Fragment:
+ """Add metadata to the query.
+
+ Metadata can be used for debugging and observability purposes.
+
+ Args:
+ **kwargs: Metadata key-value pairs
+
+ Returns:
+ Fragment: Returns self for method chaining
+
+ Example:
+ select(Person.name).meta(workload_name="test", priority=1, enabled=True)
+ """
+ if not kwargs:
+ raise ValueError("meta() requires at least one argument")
+
  self._meta.update(kwargs)
  return self

+
  def annotate(self, *annos:Expression|Relationship|ir.Annotation) -> Fragment:
  self._annotations.extend(annos)
  return self
@@ -2497,7 +2515,7 @@ class Fragment():
  # @TODO for now we set tag to None but we need to work out how to properly propagate user-provided tag here
  with debugging.span("query", tag=None, dsl=str(self), **with_source(self), meta=self._meta) as query_span:
  query_task = qb_model._compiler.fragment(self)
- results = qb_model._to_executor().execute(ir_model, query_task)
+ results = qb_model._to_executor().execute(ir_model, query_task, meta=self._meta)
  query_span["results"] = results
  # For local debugging mostly
  dry_run = qb_model._dry_run or bool(qb_model._config.get("compiler.dry_run", False))
@@ -2524,7 +2542,7 @@ class Fragment():
  # @TODO for now we set tag to None but we need to work out how to properly propagate user-provided tag here
  with debugging.span("query", tag=None, dsl=str(clone), **with_source(clone), meta=clone._meta) as query_span:
  query_task = qb_model._compiler.fragment(clone)
- results = qb_model._to_executor().execute(ir_model, query_task, format="snowpark")
+ results = qb_model._to_executor().execute(ir_model, query_task, format="snowpark", meta=clone._meta)
  query_span["alt_format_results"] = results
  return results

@@ -2541,7 +2559,8 @@ class Fragment():
  clone._source = runtime_env.get_source_pos()
  with debugging.span("query", dsl=str(clone), **with_source(clone), meta=clone._meta):
  query_task = qb_model._compiler.fragment(clone)
- qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update)
+ qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update, meta=clone._meta)
+

  #--------------------------------------------------
  # Select / Where
@@ -11,4 +11,5 @@ adhoc_annotation = f.annotation(adhoc, [])
  annotations_to_emit = FrozenOrderedSet([
  adhoc.name,
  builtins.function.name,
+ builtins.track.name,
  ])
@@ -21,8 +21,8 @@ from relationalai.clients.config import Config
  from relationalai.clients.snowflake import APP_NAME
  from relationalai.clients.types import TransactionAsyncResponse
  from relationalai.clients.util import IdentityParser
- from relationalai.tools.constants import USE_DIRECT_ACCESS
-
+ from relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
+ from relationalai.tools.query_utils import prepare_metadata_for_headers

  class LQPExecutor(e.Executor):
  """Executes LQP using the RAI client."""
@@ -267,7 +267,7 @@ class LQPExecutor(e.Executor):
  if ivm_flag:
  config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
  return construct_configure(config_dict, None)
-
+
  def _compile_intrinsics(self) -> lqp_ir.Epoch:
  """Construct an epoch that defines a number of built-in definitions used by the
  emitter."""
@@ -344,6 +344,10 @@ class LQPExecutor(e.Executor):
  df, errs = result_helpers.format_results(raw_results, cols)
  self.report_errors(errs)

+ # Rename columns if wide outputs is enabled
+ if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
+ df.columns = cols[: len(df.columns)]
+
  # Process exports
  if export_to and not self.dry_run:
  assert cols, "No columns found in the output"
@@ -362,7 +366,7 @@ class LQPExecutor(e.Executor):

  def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
  result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
- update: bool = False) -> DataFrame:
+ update: bool = False, meta: dict[str, Any] | None = None) -> DataFrame:
  self.prepare_data()
  previous_model = self._last_model

@@ -374,6 +378,9 @@ class LQPExecutor(e.Executor):
  if format != "pandas":
  raise ValueError(f"Unsupported format: {format}")

+ # Format meta as headers
+ json_meta = prepare_metadata_for_headers(meta)
+ headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
  raw_results = self.resources.exec_lqp(
  self.database,
  self.engine,
@@ -383,6 +390,7 @@ class LQPExecutor(e.Executor):
  # transactions are serialized.
  readonly=False,
  nowait_durable=True,
+ headers=headers,
  )
  assert isinstance(raw_results, TransactionAsyncResponse)

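Taken together with the Fragment.meta(...) and dsl.py changes above, per-query metadata now travels from the DSL into the transaction request as a header. A minimal sketch of the plumbing, assuming `prepare_metadata_for_headers` (defined in the new relationalai/tools/query_utils.py) JSON-encodes the dict; this is an illustration, not code from the package:

    meta = {"workload_name": "nightly_batch", "priority": 1}
    json_meta = prepare_metadata_for_headers(meta)  # assumed to return a JSON string, or an empty value when there is nothing to send
    headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
    # QUERY_ATTRIBUTES_HEADER ("X-Query-Attributes") moved from dsl.py to relationalai/tools/constants.py;
    # the headers dict is then forwarded to resources.exec_lqp(..., headers=headers).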
@@ -102,6 +102,7 @@ def _get_export_reads(export_ids: list[tuple[lqp.RelationId, int, lqp.Type]]) ->
  data_columns=csv_columns,
  compression="gzip",
  partition_size=200,
+ syntax_escapechar='"', # To follow Snowflake's expected format
  meta=None,
  )
  reads.append(lqp.Read(read_type=lqp.Export(config=export_csv_config, meta=None), meta=None))
@@ -2,13 +2,12 @@ from relationalai.semantics.metamodel.compiler import Pass
  from relationalai.semantics.metamodel import ir, builtins as rel_builtins, factory as f, visitor
  from relationalai.semantics.metamodel.typer import Checker, InferTypes, typer
  from relationalai.semantics.metamodel import helpers, types
- from relationalai.semantics.metamodel.rewrite import (Splinter, ExtractNestedLogicals, ExtractKeys, FDConstraints,
- DNFUnionSplitter, DischargeConstraints)
  from relationalai.semantics.metamodel.util import FrozenOrderedSet

  from relationalai.semantics.metamodel.rewrite import Flatten
- # TODO: Move this into metamodel.rewrite
- from relationalai.semantics.rel.rewrite import QuantifyVars, CDC, ExtractCommon
+
+ from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals
+ from .rewrite import CDC, ExtractCommon, ExtractKeys, FDConstraints, QuantifyVars, Splinter

  from relationalai.semantics.lqp.utils import output_names

@@ -1,9 +1,15 @@
  from .cdc import CDC
  from .extract_common import ExtractCommon
+ from .extract_keys import ExtractKeys
+ from .fd_constraints import FDConstraints
  from .quantify_vars import QuantifyVars
+ from .splinter import Splinter

  __all__ = [
  "CDC",
  "ExtractCommon",
+ "ExtractKeys",
+ "FDConstraints",
  "QuantifyVars",
+ "Splinter",
  ]
@@ -496,6 +496,17 @@ function = f.relation("function", [f.input_field("code", types.Symbol)])
  function_checked_annotation = f.annotation(function, [f.lit("checked")])
  function_annotation = f.annotation(function, [])

+ # Indicates this relation should be tracked in telemetry. Only supported for Relationships.
+ # `RAI_BackIR.with_relation_tracking` produces log messages at the start and end of each
+ # SCC evaluation, if any declarations bear the `track` annotation.
+ track = f.relation("track", [
+ # BackIR evaluation expects 2 parameters on the track annotation: the tracking
+ # library name and tracking relation name, which appear as log metadata fields.
+ f.input_field("library", types.Symbol),
+ f.input_field("relation", types.Symbol)
+ ])
+ track_annotation = f.annotation(track, [])
+
  # All ir nodes marked by this annotation will be removed from the final metamodel before compilation.
  # Specifically it happens in `Flatten` pass when rewrites for `require` happen
  discharged = f.relation("discharged", [])
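Based on the two `Symbol` fields declared above, attaching the new `track` annotation to a relation would presumably look something like the following sketch; the library and relation names are invented for illustration:

    # Illustration only: tag a relation so BackIR logs the start and end of its SCC evaluation.
    orders_tracking = f.annotation(track, [f.lit("my_library"), f.lit("my_orders_relation")])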
@@ -672,7 +683,7 @@ def _compute_builtin_overloads() -> list[ir.Relation]:
  return overloads

  # manually maintain the list of relations that are actually annotations
- builtin_annotations = [external, export, concept_population, from_cdc, from_cast]
+ builtin_annotations = [external, export, concept_population, from_cdc, from_cast, track]
  builtin_annotations_by_name = dict((r.name, r) for r in builtin_annotations)

  builtin_relations = _compute_builtin_relations()
@@ -56,5 +56,6 @@ class Executor():
  rich.print(f"[black]{df}[/black]")
  if not df.empty:
  for col in extra_cols:
- df = df.drop(col, axis=1)
+ if col in df.columns:
+ df = df.drop(col, axis=1)
  return df
@@ -1,12 +1,6 @@
- from .splinter import Splinter
- from .list_types import RewriteListTypes
- from .gc_nodes import GarbageCollectNodes
- from .flatten import Flatten
+ from .discharge_constraints import DischargeConstraints
  from .dnf_union_splitter import DNFUnionSplitter
- from .extract_keys import ExtractKeys
  from .extract_nested_logicals import ExtractNestedLogicals
- from .fd_constraints import FDConstraints
- from .discharge_constraints import DischargeConstraints
+ from .flatten import Flatten

- __all__ = ["Splinter", "RewriteListTypes", "GarbageCollectNodes", "Flatten", "DNFUnionSplitter", "ExtractKeys",
- "ExtractNestedLogicals", "FDConstraints", "DischargeConstraints"]
+ __all__ = ["DischargeConstraints", "DNFUnionSplitter", "ExtractNestedLogicals", "Flatten"]
@@ -558,22 +558,23 @@ class Flatten(Pass):
  def rewrite_wide_output(self, output: ir.Output):
  assert(output.keys)

- # only prefix keys that are not already in the output
- prefix_keys = []
+ # only append keys that are not already in the output
+ suffix_keys = []
  for key in output.keys:
  if all([val is not key for _, val in output.aliases]):
- prefix_keys.append(key)
+ suffix_keys.append(key)

  aliases: OrderedSet[Tuple[str, ir.Value]] = ordered_set()
- # add the keys to the output
- for key in prefix_keys:
- aliases.add((key.name, key))

  # add the remaining args, unless it is already a key
  for name, val in output.aliases:
- if not isinstance(val, ir.Var) or val not in prefix_keys:
+ if not isinstance(val, ir.Var) or val not in suffix_keys:
  aliases.add((name, val))

+ # add the keys to the output
+ for key in suffix_keys:
+ aliases.add((key.name, key))
+
  # TODO - we are assuming that the Rel compiler will translate nullable lookups
  # properly, returning a `Missing` if necessary, like this:
  # (nested_192(_adult, _adult_name) or (not nested_192(_adult, _) and _adult_name = Missing)) and
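The net effect of this hunk is that output keys which are not already selected are now appended after the user's aliases rather than prepended before them. A tiny sketch with invented names:

    # keys = [person_id]; aliases = [("name", name), ("age", age)]
    # 0.11.4 rewrite_wide_output column order: person_id, name, age
    # 0.12.1 rewrite_wide_output column order: name, age, person_id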