relationalai 0.12.9__py3-none-any.whl → 0.12.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. relationalai/__init__.py +9 -0
  2. relationalai/clients/__init__.py +2 -2
  3. relationalai/clients/local.py +571 -0
  4. relationalai/clients/snowflake.py +106 -83
  5. relationalai/debugging.py +5 -2
  6. relationalai/semantics/__init__.py +2 -2
  7. relationalai/semantics/internal/__init__.py +2 -2
  8. relationalai/semantics/internal/internal.py +24 -7
  9. relationalai/semantics/lqp/README.md +34 -0
  10. relationalai/semantics/lqp/constructors.py +2 -1
  11. relationalai/semantics/lqp/executor.py +13 -2
  12. relationalai/semantics/lqp/ir.py +4 -0
  13. relationalai/semantics/lqp/model2lqp.py +41 -2
  14. relationalai/semantics/lqp/passes.py +6 -4
  15. relationalai/semantics/lqp/rewrite/__init__.py +2 -0
  16. relationalai/semantics/lqp/rewrite/annotate_constraints.py +55 -0
  17. relationalai/semantics/lqp/rewrite/extract_keys.py +22 -3
  18. relationalai/semantics/lqp/rewrite/functional_dependencies.py +42 -10
  19. relationalai/semantics/lqp/rewrite/quantify_vars.py +14 -0
  20. relationalai/semantics/lqp/validators.py +3 -0
  21. relationalai/semantics/metamodel/builtins.py +5 -0
  22. relationalai/semantics/metamodel/rewrite/flatten.py +10 -4
  23. relationalai/semantics/metamodel/typer/typer.py +13 -0
  24. relationalai/semantics/metamodel/types.py +2 -1
  25. relationalai/semantics/reasoners/graph/core.py +44 -53
  26. relationalai/tools/debugger.py +4 -2
  27. relationalai/tools/qb_debugger.py +5 -3
  28. {relationalai-0.12.9.dist-info → relationalai-0.12.10.dist-info}/METADATA +2 -2
  29. {relationalai-0.12.9.dist-info → relationalai-0.12.10.dist-info}/RECORD +32 -29
  30. {relationalai-0.12.9.dist-info → relationalai-0.12.10.dist-info}/WHEEL +0 -0
  31. {relationalai-0.12.9.dist-info → relationalai-0.12.10.dist-info}/entry_points.txt +0 -0
  32. {relationalai-0.12.9.dist-info → relationalai-0.12.10.dist-info}/licenses/LICENSE +0 -0
@@ -441,7 +441,8 @@ class Resources(ResourcesBase):
441
441
  code: str,
442
442
  params: List[Any] | Any | None = None,
443
443
  raw: bool = False,
444
- help: bool = True
444
+ help: bool = True,
445
+ skip_auto_create: bool = False
445
446
  ) -> Any:
446
447
  # print(f"\n--- sql---\n{code}\n--- end sql---\n")
447
448
  if not self._session:
@@ -458,7 +459,6 @@ class Resources(ResourcesBase):
458
459
  rai_app = self.config.get("rai_app_name", "")
459
460
  current_role = self.config.get("role")
460
461
  engine = self.get_default_engine_name()
461
- engine_size = self.config.get_default_engine_size()
462
462
  assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
463
463
  assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
464
464
  print("\n")
@@ -467,15 +467,10 @@ class Resources(ResourcesBase):
467
467
  if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
468
468
  exception = SnowflakeAppMissingException(rai_app, current_role)
469
469
  raise exception from None
470
- if _is_engine_issue(orig_message) or _is_database_issue(orig_message):
470
+ # skip creating the engine if the query is a user transaction. exec_async_v2 will handle that case.
471
+ if _is_engine_issue(orig_message) and not skip_auto_create:
471
472
  try:
472
- self._poll_use_index(
473
- app_name=self.get_app_name(),
474
- sources=self.sources,
475
- model=self.database,
476
- engine_name=engine,
477
- engine_size=engine_size
478
- )
473
+ self.auto_create_engine(engine)
479
474
  return self._exec(code, params, raw=raw, help=help)
480
475
  except EngineNameValidationException as e:
481
476
  raise EngineNameValidationException(engine) from e
@@ -1612,6 +1607,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1612
1607
  response = self._exec(
1613
1608
  sql_string,
1614
1609
  raw_code,
1610
+ skip_auto_create=True,
1615
1611
  )
1616
1612
  if not response:
1617
1613
  raise Exception("Failed to create transaction")
@@ -1629,6 +1625,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1629
1625
  bypass_index=False,
1630
1626
  language: str = "rel",
1631
1627
  query_timeout_mins: int | None = None,
1628
+ gi_setup_skipped: bool = False,
1632
1629
  ):
1633
1630
  if inputs is None:
1634
1631
  inputs = {}
@@ -1638,6 +1635,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
1638
1635
  with debugging.span("transaction", **query_attrs_dict) as txn_span:
1639
1636
  with debugging.span("create_v2", **query_attrs_dict) as create_span:
1640
1637
  request_headers['user-agent'] = get_pyrel_version(self.generation)
1638
+ request_headers['gi_setup_skipped'] = str(gi_setup_skipped)
1639
+ request_headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
1641
1640
  response = self._exec_rai_app(
1642
1641
  database=database,
1643
1642
  engine=engine,
@@ -1897,26 +1896,29 @@ Otherwise, remove it from your '{profile}' configuration profile.
1897
1896
  # Exec
1898
1897
  #--------------------------------------------------
1899
1898
 
1900
- def exec_lqp(
1899
+ def _exec_with_gi_retry(
1901
1900
  self,
1902
1901
  database: str,
1903
1902
  engine: str | None,
1904
- raw_code: bytes,
1905
- readonly=True,
1906
- *,
1907
- inputs: Dict | None = None,
1908
- nowait_durable=False,
1909
- headers: Dict | None = None,
1910
- bypass_index=False,
1911
- query_timeout_mins: int | None = None,
1903
+ raw_code: str,
1904
+ inputs: Dict | None,
1905
+ readonly: bool,
1906
+ nowait_durable: bool,
1907
+ headers: Dict | None,
1908
+ bypass_index: bool,
1909
+ language: str,
1910
+ query_timeout_mins: int | None,
1912
1911
  ):
1913
- raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
1912
+ """Execute with graph index retry logic.
1914
1913
 
1914
+ Attempts execution with gi_setup_skipped=True first. If an engine or database
1915
+ issue occurs, polls use_index and retries with gi_setup_skipped=False.
1916
+ """
1915
1917
  try:
1916
1918
  return self._exec_async_v2(
1917
- database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1918
- headers=headers, bypass_index=bypass_index, language='lqp',
1919
- query_timeout_mins=query_timeout_mins,
1919
+ database, engine, raw_code, inputs, readonly, nowait_durable,
1920
+ headers=headers, bypass_index=bypass_index, language=language,
1921
+ query_timeout_mins=query_timeout_mins, gi_setup_skipped=True,
1920
1922
  )
1921
1923
  except Exception as e:
1922
1924
  err_message = str(e).lower()
@@ -1933,13 +1935,32 @@ Otherwise, remove it from your '{profile}' configuration profile.
1933
1935
  )
1934
1936
 
1935
1937
  return self._exec_async_v2(
1936
- database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1937
- headers=headers, bypass_index=bypass_index, language='lqp',
1938
- query_timeout_mins=query_timeout_mins,
1938
+ database, engine, raw_code, inputs, readonly, nowait_durable,
1939
+ headers=headers, bypass_index=bypass_index, language=language,
1940
+ query_timeout_mins=query_timeout_mins, gi_setup_skipped=False,
1939
1941
  )
1940
1942
  else:
1941
1943
  raise e
1942
1944
 
1945
+ def exec_lqp(
1946
+ self,
1947
+ database: str,
1948
+ engine: str | None,
1949
+ raw_code: bytes,
1950
+ readonly=True,
1951
+ *,
1952
+ inputs: Dict | None = None,
1953
+ nowait_durable=False,
1954
+ headers: Dict | None = None,
1955
+ bypass_index=False,
1956
+ query_timeout_mins: int | None = None,
1957
+ ):
1958
+ raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
1959
+ return self._exec_with_gi_retry(
1960
+ database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1961
+ headers, bypass_index, 'lqp', query_timeout_mins
1962
+ )
1963
+
1943
1964
 
1944
1965
  def exec_raw(
1945
1966
  self,
@@ -1955,45 +1976,10 @@ Otherwise, remove it from your '{profile}' configuration profile.
1955
1976
  query_timeout_mins: int | None = None,
1956
1977
  ):
1957
1978
  raw_code = raw_code.replace("'", "\\'")
1958
-
1959
- try:
1960
- return self._exec_async_v2(
1961
- database,
1962
- engine,
1963
- raw_code,
1964
- inputs,
1965
- readonly,
1966
- nowait_durable,
1967
- headers=headers,
1968
- bypass_index=bypass_index,
1969
- query_timeout_mins=query_timeout_mins,
1970
- )
1971
- except Exception as e:
1972
- err_message = str(e).lower()
1973
- if _is_engine_issue(err_message) or _is_database_issue(err_message):
1974
- engine_name = engine or self.get_default_engine_name()
1975
- engine_size = self.config.get_default_engine_size()
1976
- self._poll_use_index(
1977
- app_name=self.get_app_name(),
1978
- sources=self.sources,
1979
- model=database,
1980
- engine_name=engine_name,
1981
- engine_size=engine_size,
1982
- headers=headers,
1983
- )
1984
- return self._exec_async_v2(
1985
- database,
1986
- engine,
1987
- raw_code,
1988
- inputs,
1989
- readonly,
1990
- nowait_durable,
1991
- headers=headers,
1992
- bypass_index=bypass_index,
1993
- query_timeout_mins=query_timeout_mins,
1994
- )
1995
- else:
1996
- raise e
1979
+ return self._exec_with_gi_retry(
1980
+ database, engine, raw_code, inputs, readonly, nowait_durable,
1981
+ headers, bypass_index, 'rel', query_timeout_mins
1982
+ )
1997
1983
 
1998
1984
 
1999
1985
  def format_results(self, results, task:m.Task|None=None) -> Tuple[DataFrame, List[Any]]:
@@ -3314,19 +3300,10 @@ class DirectAccessResources(Resources):
3314
3300
  message = "" # Not used when we check status_code directly
3315
3301
 
3316
3302
  # fix engine on engine error and retry
3317
- # Skip auto-retry if skip_auto_create is True to avoid recursion
3318
- if (_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message):
3319
- engine_name = payload.get("caller_engine_name", "") if payload else ""
3320
- engine_name = engine_name or self.get_default_engine_name()
3321
- engine_size = self.config.get_default_engine_size()
3322
- self._poll_use_index(
3323
- app_name=self.get_app_name(),
3324
- sources=self.sources,
3325
- model=self.database,
3326
- engine_name=engine_name,
3327
- engine_size=engine_size,
3328
- headers=headers,
3329
- )
3303
+ # Skip auto-retry if skip_auto_create is True to avoid recursion or to let _exec_async_v2 poll the index.
3304
+ if _is_engine_issue(message) and not skip_auto_create:
3305
+ engine = payload.get("engine_name", "") if payload else ""
3306
+ self.auto_create_engine(engine)
3330
3307
  response = _send_request()
3331
3308
  except requests.exceptions.ConnectionError as e:
3332
3309
  if "NameResolutionError" in str(e):
@@ -3340,6 +3317,48 @@ class DirectAccessResources(Resources):
3340
3317
  raise e
3341
3318
  return response
3342
3319
 
3320
+ def _txn_request_with_gi_retry(
3321
+ self,
3322
+ payload: Dict,
3323
+ headers: Dict[str, str],
3324
+ query_params: Dict,
3325
+ engine: Union[str, None],
3326
+ ):
3327
+ """Make request with graph index retry logic.
3328
+
3329
+ Attempts request with gi_setup_skipped=True first. If an engine or database
3330
+ issue occurs, polls use_index and retries with gi_setup_skipped=False.
3331
+ """
3332
+ response = self.request(
3333
+ "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True
3334
+ )
3335
+
3336
+ if response.status_code != 200:
3337
+ try:
3338
+ message = response.json().get("message", "")
3339
+ except requests.exceptions.JSONDecodeError:
3340
+ message = ""
3341
+
3342
+ if _is_engine_issue(message) or _is_database_issue(message):
3343
+ engine_name = engine or self.get_default_engine_name()
3344
+ engine_size = self.config.get_default_engine_size()
3345
+ self._poll_use_index(
3346
+ app_name=self.get_app_name(),
3347
+ sources=self.sources,
3348
+ model=self.database,
3349
+ engine_name=engine_name,
3350
+ engine_size=engine_size,
3351
+ headers=headers,
3352
+ )
3353
+ headers['gi_setup_skipped'] = 'False'
3354
+ response = self.request(
3355
+ "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True
3356
+ )
3357
+ else:
3358
+ raise ResponseStatusException("Failed to create transaction.", response)
3359
+
3360
+ return response
3361
+
3343
3362
  def _exec_async_v2(
3344
3363
  self,
3345
3364
  database: str,
@@ -3352,6 +3371,7 @@ class DirectAccessResources(Resources):
3352
3371
  bypass_index=False,
3353
3372
  language: str = "rel",
3354
3373
  query_timeout_mins: int | None = None,
3374
+ gi_setup_skipped: bool = False,
3355
3375
  ):
3356
3376
 
3357
3377
  with debugging.span("transaction") as txn_span:
@@ -3374,12 +3394,15 @@ class DirectAccessResources(Resources):
3374
3394
  payload["timeout_mins"] = query_timeout_mins
3375
3395
  query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
3376
3396
 
3377
- response = self.request(
3378
- "create_txn", payload=payload, headers=headers, query_params=query_params,
3379
- )
3397
+ # Add gi_setup_skipped to headers
3398
+ if headers is None:
3399
+ headers = {}
3400
+ headers["gi_setup_skipped"] = str(gi_setup_skipped)
3401
+ headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
3380
3402
 
3381
- if response.status_code != 200:
3382
- raise ResponseStatusException("Failed to create transaction.", response)
3403
+ response = self._txn_request_with_gi_retry(
3404
+ payload, headers, query_params, engine
3405
+ )
3383
3406
 
3384
3407
  artifact_info = {}
3385
3408
  response_content = response.json()
relationalai/debugging.py CHANGED
@@ -26,6 +26,9 @@ find_block_in = find_block_in # re-export
26
26
  DEBUG = True
27
27
  handled_error = None
28
28
 
29
+ # Configurable debug log file location
30
+ DEBUG_LOG_FILE = os.environ.get('RAI_DEBUG_LOG', 'debug.jsonl')
31
+
29
32
  #--------------------------------------------------
30
33
  # Log Formatters
31
34
  #--------------------------------------------------
@@ -70,7 +73,7 @@ class FlushingFileHandler(logging.FileHandler):
70
73
  def emit(self, record):
71
74
  if not self._initialized:
72
75
  self._initialized = True
73
- with open('debug.jsonl', 'w'):
76
+ with open(DEBUG_LOG_FILE, 'w'):
74
77
  pass
75
78
  super().emit(record)
76
79
  self.flush()
@@ -78,7 +81,7 @@ class FlushingFileHandler(logging.FileHandler):
78
81
  try:
79
82
  # keep the old file-based debugger around and working until it's fully replaced.
80
83
  if DEBUG:
81
- file_handler = FlushingFileHandler('debug.jsonl', mode='a')
84
+ file_handler = FlushingFileHandler(DEBUG_LOG_FILE, mode='a')
82
85
  file_handler.setFormatter(JsonFormatter())
83
86
  logger.addHandler(file_handler)
84
87
  except Exception:
@@ -8,7 +8,7 @@ __include_in_docs__ = True
8
8
 
9
9
  from .internal import (
10
10
  Model, Concept, Relationship, RelationshipReading, Expression, Fragment, Error, Field,
11
- String, Integer, Int64, Int128, Float, Decimal, Bool,
11
+ AnyEntity, String, Integer, Int64, Int128, Float, Decimal, Bool,
12
12
  Date, DateTime,
13
13
  RawSource, Hash,
14
14
  select, where, require, define, distinct, union, data,
@@ -19,7 +19,7 @@ from .internal import (
19
19
 
20
20
  __all__ = [
21
21
  "Model", "Concept", "Relationship", "RelationshipReading", "Expression", "Fragment", "Error", "Field",
22
- "String", "Integer", "Int64", "Int128", "Float", "Decimal", "Bool",
22
+ "AnyEntity", "String", "Integer", "Int64", "Int128", "Float", "Decimal", "Bool",
23
23
  "Date", "DateTime",
24
24
  "RawSource", "Hash",
25
25
  "select", "where", "require", "define", "distinct", "union", "data",
@@ -4,7 +4,7 @@ API for RelationalAI.
4
4
 
5
5
  from .internal import (
6
6
  Model, Concept, Relationship, RelationshipReading, Expression, Fragment, Error, Field,
7
- String, Integer, Int64, Int128, Float, Decimal, Bool,
7
+ AnyEntity, String, Integer, Int64, Int128, Float, Decimal, Bool,
8
8
  Date, DateTime,
9
9
  RawSource, Hash,
10
10
  select, where, require, define, distinct, union, data,
@@ -15,7 +15,7 @@ from .internal import (
15
15
 
16
16
  __all__ = [
17
17
  "Model", "Concept", "Relationship", "RelationshipReading", "Expression", "Fragment", "Error", "Field",
18
- "String", "Integer", "Int64", "Int128", "Float", "Decimal", "Bool",
18
+ "AnyEntity", "String", "Integer", "Int64", "Int128", "Float", "Decimal", "Bool",
19
19
  "Date", "DateTime",
20
20
  "RawSource", "Hash",
21
21
  "select", "where", "require", "define", "distinct", "union", "data",
@@ -621,7 +621,7 @@ class Producer:
621
621
 
622
622
  if self._model and self._model._strict:
623
623
  raise AttributeError(f"{self._name} has no relationship `{name}`")
624
- if topmost_parent is not concept:
624
+ if topmost_parent is not concept and topmost_parent not in Concept.builtin_concepts:
625
625
  topmost_parent._relationships[name] = topmost_parent._get_relationship(name)
626
626
  rich.print(f"[red bold][Implicit Subtype Relationship][/red bold] [yellow]{concept}.{name}[/yellow] appended to topmost parent [yellow]{topmost_parent}[/yellow] instead")
627
627
 
@@ -1165,7 +1165,10 @@ Primitive = Concept.builtins["Primitive"] = Concept("Primitive")
1165
1165
  Error = Concept.builtins["Error"] = ErrorConcept("Error")
1166
1166
 
1167
1167
  def _register_builtin(name):
1168
- c = Concept(name, extends=[Primitive])
1168
+ if name == "AnyEntity":
1169
+ c = Concept(name)
1170
+ else:
1171
+ c = Concept(name, extends=[Primitive])
1169
1172
  Concept.builtin_concepts.add(c)
1170
1173
  Concept.builtins[name] = c
1171
1174
 
@@ -1174,6 +1177,7 @@ for builtin in types.builtin_types:
1174
1177
  if isinstance(builtin, ir.ScalarType):
1175
1178
  _register_builtin(builtin.name)
1176
1179
 
1180
+ AnyEntity = Concept.builtins["AnyEntity"]
1177
1181
  Float = Concept.builtins["Float"]
1178
1182
  Number = Concept.builtins["Number"]
1179
1183
  Int64 = Concept.builtins["Int64"]
@@ -2896,10 +2900,9 @@ class Compiler():
2896
2900
  if concept not in self.types:
2897
2901
  self.to_type(concept)
2898
2902
  self.to_relation(concept)
2899
- if concept._extends:
2900
- rule = self.concept_inheritance_rule(concept)
2901
- if rule:
2902
- rules.append(rule)
2903
+ rule = self.concept_inheritance_rule(concept)
2904
+ if rule:
2905
+ rules.append(rule)
2903
2906
  unresolved = []
2904
2907
  for relationship in model.relationships:
2905
2908
  if relationship not in self.relations:
@@ -3204,8 +3207,11 @@ class Compiler():
3204
3207
  # filter extends to get only non-primitive parents
3205
3208
  parents = []
3206
3209
  for parent in concept._extends:
3207
- if not parent._is_primitive():
3210
+ if not parent._is_primitive() and parent is not AnyEntity:
3208
3211
  parents.append(parent)
3212
+ # always extend AnyEntity for non-primitive types that are not built-in
3213
+ if not concept._is_primitive() and concept not in Concept.builtin_concepts:
3214
+ parents.append(AnyEntity)
3209
3215
  # only extends primitive types, no need for inheritance rules
3210
3216
  if not parents:
3211
3217
  return None
@@ -3218,6 +3224,17 @@ class Compiler():
3218
3224
  *[f.derive(self.to_relation(parent), [var]) for parent in parents]
3219
3225
  ])
3220
3226
 
3227
+ def concept_any_entity_rule(self, entities:list[Concept]):
3228
+ """
3229
+ Generate an inheritance rule for all these entities to AnyEntity.
3230
+ """
3231
+ any_entity_relation = self.to_relation(AnyEntity)
3232
+ var = f.var("v", types.Any)
3233
+ return f.logical([
3234
+ f.union([f.lookup(self.to_relation(e), [var]) for e in entities]),
3235
+ f.derive(any_entity_relation, [var])
3236
+ ])
3237
+
3221
3238
  def relation_dict(self, items:dict[Relationship|Concept, Producer], ctx:CompilerContext) -> dict[ir.Relation, list[ir.Var]]:
3222
3239
  return {self.to_relation(k): unwrap_list(self.lookup(v, ctx)) for k, v in items.items()}
3223
3240
 
@@ -0,0 +1,34 @@
1
+ # Logic Engine LQP Backend
2
+
3
+ The logic engine runs the *Logical Query Protocol* (short *LQP*). This module includes a
4
+ compiler from the semantic metamodel to LQP along with an executor.
5
+
6
+ ## Running against a local logic engine
7
+
8
+ For development and testing, it is possible to run PyRel models against a local logic engine
9
+ server process.
10
+
11
+ To start your local server, please refer to the [logic engine
12
+ docs](https://github.com/RelationalAI/raicode/tree/master/src/Server#starting-the-server).
13
+
14
+ With the local server running, add this to your `raiconfig.toml`:
15
+
16
+ ```toml
17
+ [profile.local]
18
+ platform = "local"
19
+ engine = "local"
20
+ host = "localhost"
21
+ port = 8010
22
+ ```
23
+
24
+ Then set `active_profile = "local"` at the top of the file.
25
+
26
+ **Known limitations:**
27
+
28
+ Local execution does not support running against Snowflake source tables.
29
+
30
+ At the moment, locally created databases cannot be cleaned up by the client. Eventually you
31
+ will need to clear your local pager directory.
32
+
33
+ At the moment, local execution is only supported for fast-path transactions, i.e. those
34
+ which complete in less than 5 seconds. Polling support will be added soon.
@@ -63,5 +63,6 @@ def mk_attribute(name: str, args: list[lqp.Value]) -> lqp.Attribute:
63
63
  def mk_transaction(
64
64
  epochs: list[lqp.Epoch],
65
65
  configure: lqp.Configure = lqp.construct_configure({}, None),
66
+ sync = None
66
67
  ) -> lqp.Transaction:
67
- return lqp.Transaction(epochs=epochs, configure=configure, meta=None)
68
+ return lqp.Transaction(epochs=epochs, configure=configure, sync=sync, meta=None)
@@ -66,6 +66,8 @@ class LQPExecutor(e.Executor):
66
66
  resource_class = rai.clients.snowflake.Resources
67
67
  if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
68
68
  resource_class = rai.clients.snowflake.DirectAccessResources
69
+ if self.config.get("platform", "") == "local":
70
+ resource_class = rai.clients.local.LocalResources
69
71
  # NOTE: language="lqp" is not strictly required for LQP execution, but it
70
72
  # will significantly improve performance.
71
73
  self._resources = resource_class(
@@ -311,6 +313,12 @@ class LQPExecutor(e.Executor):
311
313
  config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
312
314
  return construct_configure(config_dict, None)
313
315
 
316
+ def _should_sync(self, model) :
317
+ if self._last_model != model:
318
+ return lqp_ir.Sync(fragments=[], meta=None)
319
+ else :
320
+ return None
321
+
314
322
  def _compile_intrinsics(self) -> lqp_ir.Epoch:
315
323
  """Construct an epoch that defines a number of built-in definitions used by the
316
324
  emitter."""
@@ -334,6 +342,7 @@ class LQPExecutor(e.Executor):
334
342
  meta=None,
335
343
  )
336
344
 
345
+ # [RAI-40997] We eagerly undefine query fragments so they are not committed to storage
337
346
  def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
338
347
  fragment_ids = []
339
348
 
@@ -363,7 +372,9 @@ class LQPExecutor(e.Executor):
363
372
  epochs = []
364
373
  epochs.append(self._compile_intrinsics())
365
374
 
366
- if self._last_model != model:
375
+ sync = self._should_sync(model)
376
+
377
+ if sync is not None:
367
378
  with debugging.span("compile", metamodel=model) as install_span:
368
379
  install_span["compile_type"] = "model"
369
380
  _, model_epoch = self.compiler.compile(model, {"fragment_id": b"model"})
@@ -383,7 +394,7 @@ class LQPExecutor(e.Executor):
383
394
  epochs.append(self._compile_undefine_query(query_epoch))
384
395
 
385
396
  txn_span["compile_type"] = "query"
386
- txn = mk_transaction(epochs=epochs, configure=configure)
397
+ txn = mk_transaction(epochs=epochs, configure=configure, sync=sync)
387
398
  txn_span["lqp"] = lqp_print.to_string(txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
388
399
 
389
400
  validate_lqp(txn)
@@ -4,6 +4,7 @@ __all__ = [
4
4
  "SourceInfo",
5
5
  "LqpNode",
6
6
  "Declaration",
7
+ "FunctionalDependency",
7
8
  "Def",
8
9
  "Loop",
9
10
  "Abstraction",
@@ -45,6 +46,7 @@ __all__ = [
45
46
  "Read",
46
47
  "Epoch",
47
48
  "Transaction",
49
+ "Sync",
48
50
  "DebugInfo",
49
51
  "Configure",
50
52
  "IVMConfig",
@@ -59,6 +61,7 @@ from lqp.ir import (
59
61
  SourceInfo,
60
62
  LqpNode,
61
63
  Declaration,
64
+ FunctionalDependency,
62
65
  Def,
63
66
  Loop,
64
67
  Abstraction,
@@ -100,6 +103,7 @@ from lqp.ir import (
100
103
  Read,
101
104
  Epoch,
102
105
  Transaction,
106
+ Sync,
103
107
  DebugInfo,
104
108
  Configure,
105
109
  IVMConfig,
@@ -11,7 +11,9 @@ from relationalai.semantics.lqp.constructors import (
11
11
  )
12
12
  from relationalai.semantics.lqp.utils import TranslationCtx, gen_unique_var
13
13
  from relationalai.semantics.lqp.validators import assert_valid_input
14
-
14
+ from relationalai.semantics.lqp.rewrite.functional_dependencies import (
15
+ normalized_fd, contains_only_declarable_constraints
16
+ )
15
17
  from decimal import Decimal as PyDecimal
16
18
  from datetime import datetime, date, timezone
17
19
  from typing import Tuple, cast, Union, Optional
@@ -102,6 +104,43 @@ def _get_export_reads(export_ids: list[tuple[lqp.RelationId, int, lqp.Type]]) ->
102
104
  return (export_filename, col_info, reads)
103
105
 
104
106
  def _translate_to_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
107
+ if contains_only_declarable_constraints(rule):
108
+ return _translate_to_constraint_decls(ctx, rule)
109
+ else:
110
+ return _translate_to_standard_decl(ctx, rule)
111
+
112
+ def _translate_to_constraint_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
113
+ constraint_decls: list[lqp.Declaration] = []
114
+ for task in rule.body:
115
+ assert isinstance(task, ir.Require)
116
+ fd = normalized_fd(task)
117
+ assert fd is not None
118
+
119
+ # check for unresolved types
120
+ if any(types.is_any(var.type) for var in fd.keys + fd.values):
121
+ warn(f"Ignoring FD with unresolved type: {fd}")
122
+ continue
123
+
124
+ lqp_typed_keys = [_translate_term(ctx, key) for key in fd.keys]
125
+ lqp_typed_values = [_translate_term(ctx, value) for value in fd.values]
126
+ lqp_typed_vars:list[Tuple[lqp.Var, lqp.Type]] = lqp_typed_keys + lqp_typed_values # type: ignore
127
+ lqp_guard_atoms = [_translate_to_atom(ctx, atom) for atom in fd.guard]
128
+ lqp_guard = mk_abstraction(lqp_typed_vars, mk_and(lqp_guard_atoms))
129
+ lqp_keys:list[lqp.Var] = [var for (var, _) in lqp_typed_keys] # type: ignore
130
+ lqp_values:list[lqp.Var] = [var for (var, _) in lqp_typed_values] # type: ignore
131
+
132
+ fd_decl = lqp.FunctionalDependency(
133
+ guard=lqp_guard,
134
+ keys=lqp_keys,
135
+ values=lqp_values,
136
+ meta=None
137
+ )
138
+
139
+ constraint_decls.append(fd_decl)
140
+
141
+ return constraint_decls
142
+
143
+ def _translate_to_standard_decl(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
105
144
  effects = collect_by_type((ir.Output, ir.Update), rule)
106
145
  aggregates = collect_by_type(ir.Aggregate, rule)
107
146
  ranks = collect_by_type(ir.Rank, rule)
@@ -452,7 +491,7 @@ def _translate_aggregate(ctx: TranslationCtx, aggr: ir.Aggregate, body: lqp.Form
452
491
 
453
492
  return mk_exists(result_terms, conjunction)
454
493
 
455
- # `input_args`` hold the types of the input arguments, but they may have been modified
494
+ # `input_args` hold the types of the input arguments, but they may have been modified
456
495
  # if we're dealing with a count, so we use `abstr_args` to find the type.
457
496
  (aggr_arg, aggr_arg_type) = abstr_args[-1]
458
497
 
@@ -6,9 +6,11 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet
6
6
 
7
7
  from relationalai.semantics.metamodel.rewrite import Flatten
8
8
 
9
- from ..metamodel.rewrite import DischargeConstraints, DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
10
- from .rewrite import CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars, Splinter, SplitMultiCheckRequires
11
-
9
+ from ..metamodel.rewrite import DNFUnionSplitter, ExtractNestedLogicals, FormatOutputs
10
+ from .rewrite import (
11
+ AnnotateConstraints, CDC, ExtractCommon, ExtractKeys, FunctionAnnotations, QuantifyVars,
12
+ Splinter, SplitMultiCheckRequires
13
+ )
12
14
  from relationalai.semantics.lqp.utils import output_names
13
15
 
14
16
  from typing import cast, List, Sequence, Tuple, Union, Optional, Iterable
@@ -20,7 +22,7 @@ def lqp_passes() -> list[Pass]:
20
22
  return [
21
23
  SplitMultiCheckRequires(),
22
24
  FunctionAnnotations(),
23
- DischargeConstraints(),
25
+ AnnotateConstraints(),
24
26
  Checker(),
25
27
  CDC(), # specialize to physical relations before extracting nested and typing
26
28
  ExtractNestedLogicals(), # before InferTypes to avoid extracting casts
@@ -1,3 +1,4 @@
1
+ from .annotate_constraints import AnnotateConstraints
1
2
  from .cdc import CDC
2
3
  from .extract_common import ExtractCommon
3
4
  from .extract_keys import ExtractKeys
@@ -6,6 +7,7 @@ from .quantify_vars import QuantifyVars
6
7
  from .splinter import Splinter
7
8
 
8
9
  __all__ = [
10
+ "AnnotateConstraints",
9
11
  "CDC",
10
12
  "ExtractCommon",
11
13
  "ExtractKeys",