relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
3
3
  # VERY UNSTABLE API TO EXECUTE FRONTEND PROGRAMS USING V0 EXECUTORS
4
4
  #------------------------------------------------------
5
5
  from functools import lru_cache
6
+ from pandas import DataFrame
6
7
  import json
7
8
  from v0.relationalai import debugging
8
9
  from v0.relationalai.semantics.lqp.executor import LQPExecutor
@@ -10,6 +11,7 @@ from v0.relationalai.semantics.rel.executor import RelExecutor
10
11
  from v0.relationalai.semantics.metamodel import ir as v0, factory as v0_factory
11
12
  from v0.relationalai.semantics.metamodel.visitor import collect_by_type
12
13
  from v0.relationalai.semantics.snowflake import Table as v0Table
14
+ from v0.relationalai.semantics.metamodel import executor as v0executor
13
15
  try:
14
16
  from v0.relationalai.clients.snowflake import Provider as v0Provider #type: ignore
15
17
  except ImportError:
@@ -17,18 +19,17 @@ except ImportError:
17
19
  from v0.relationalai.clients.config import Config
18
20
 
19
21
  # from ..config import Config
20
- from ..config.shims import DRY_RUN
22
+ from ..config.shims import DRY_RUN, ENFORCE_TYPE_CORRECT
21
23
  from ..semantics import Model, Fragment
24
+ from ..semantics.frontend.pprint import format
22
25
  from ..semantics.metamodel import metamodel as mm
23
26
  from ..semantics.metamodel.typer import Typer
24
27
  from ..semantics.metamodel.metamodel_analyzer import Normalize
25
28
  from .mm2v0 import Translator
26
29
 
27
- DEBUG=False
28
30
  PRINT_RESULT=False
29
31
  TYPER_DEBUGGER=False
30
32
 
31
- # DEBUG=True
32
33
  # PRINT_RESULT=True
33
34
  # TYPER_DEBUGGER=True
34
35
 
@@ -50,19 +51,21 @@ def with_source(item: mm.Node):
50
51
  else:
51
52
  return {"file":item.source.file, "line":item.source.line}
52
53
 
53
def execute(query: Fragment, model: Model|None = None, executor=None, export_to="", update=False, meta=None):
    """Compile a frontend fragment (and optional model) down to the metamodel
    and run it through a v0 executor, recording the whole run in a query span.

    `executor` may be an executor instance or a backend name; when falsy we
    pick "lqp" or "rel" based on the `reasoner.rule.use_lqp` config flag
    (defaulting to LQP). Returns whatever `execute_mm` produces.
    """
    if not executor:
        # Default to the LQP backend unless config explicitly opts out.
        executor = "lqp" if get_config().get("reasoner.rule.use_lqp", True) else "rel"

    mm_model = model.to_metamodel() if model else None
    mm_query = query.to_metamodel()
    assert isinstance(mm_query, mm.Node)

    # Wrap the full pipeline in a debugging span carrying the pretty-printed
    # DSL and the source location of the query, so tools can attribute it.
    source_info = with_source(mm_query)
    with debugging.span("query", tag=None, export_to=export_to, dsl=format(query), **source_info) as query_span:
        outcome = execute_mm(mm_query, mm_model, executor, export_to=export_to, update=update, model=model, meta=meta)
        query_span["results"] = outcome
        return outcome
62
65
 
63
- def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp", export_to="", update=False, model: Model|None = None):
66
+ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp", export_to="", update=False, model: Model|None = None, meta=None):
64
67
  # perform type inference
65
- typer = Typer(enforce=False)
68
+ typer = Typer(enforce=ENFORCE_TYPE_CORRECT)
66
69
  # normalize the metamodel
67
70
  normalizer = Normalize()
68
71
  # translate the metamodel into a v0 query
@@ -81,22 +84,27 @@ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp"
81
84
  debugger_msgs.append(json.dumps({ 'id': 'model', 'content': str(mm_model)}))
82
85
  with debugging.span("compile", metamodel=mm_model) as span:
83
86
  span["compile_type"] = "model_v1"
84
- mm_model = typer.infer_model(mm_model)
85
- span["typed_mm"] = str(mm_model)
86
- debugger_msgs.append(json.dumps({ 'id': 'typed_model', 'content': str(mm_model)}))
87
- assert(typer.model_net is not None)
88
- debugger_msgs.append(json.dumps({ 'id': 'model_net', 'content': str(typer.model_net.to_mermaid())}))
89
- # normalization
90
- mm_model = mm_model.mut(root=normalizer.normalize(mm_model.root)) # type: ignore
91
- assert isinstance(mm_model, mm.Model)
92
- if DEBUG:
93
- print("V1 Model:")
94
- print(mm_model)
95
- # translation
96
- v0_model = translator.translate_model(mm_model)
97
- if DEBUG:
98
- print("Translated v0 Model:")
99
- print(v0_model)
87
+ with debugging.span("v1") as span:
88
+ with debugging.span("Original") as span:
89
+ if debugging.DEBUG:
90
+ span["metamodel"] = str(mm_model.root)
91
+ with debugging.span("Typer") as span:
92
+ mm_model = typer.infer_model(mm_model)
93
+ if debugging.DEBUG:
94
+ span["metamodel"] = str(mm_model.root)
95
+ span["typed_mm"] = str(mm_model)
96
+ debugger_msgs.append(json.dumps({ 'id': 'typed_model', 'content': str(mm_model)}))
97
+ assert(typer.model_net is not None)
98
+ debugger_msgs.append(json.dumps({ 'id': 'model_net', 'content': str(typer.model_net.to_mermaid())}))
99
+ with debugging.span("Normalizer") as span:
100
+ mm_model = mm_model.mut(root=normalizer.normalize(mm_model.root)) # type: ignore
101
+ assert isinstance(mm_model, mm.Model)
102
+ if debugging.DEBUG:
103
+ span["metamodel"] = str(mm_model.root)
104
+ with debugging.span("v0 Translation") as span:
105
+ v0_model = translator.translate_model(mm_model)
106
+ if debugging.DEBUG:
107
+ span["metamodel"] = str(v0_model.root)
100
108
 
101
109
  #------------------------------------------------------
102
110
  # Query processing
@@ -105,23 +113,23 @@ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp"
105
113
  debugger_msgs.append(json.dumps({ 'id': 'query', 'content': str(mm_query)}))
106
114
  with debugging.span("compile", metamodel=mm_query) as span:
107
115
  span["compile_type"] = "query_v1"
108
- mm_query = typer.infer_query(mm_query)
109
- span["typed_mm"] = str(mm_query)
110
- debugger_msgs.append(json.dumps({ 'id': 'typed_query', 'content': str(mm_query)}))
111
- assert(typer.last_net is not None)
112
- debugger_msgs.append(json.dumps({ 'id': 'query_net', 'content': str(typer.last_net.to_mermaid())}))
113
- # normalization
114
- mm_query = normalizer.normalize(mm_query) # type: ignore
115
- assert isinstance(mm_query, mm.Task)
116
- if DEBUG:
117
- print("V1 Query:")
118
- print(mm_query)
119
- # translation
120
- v0_query = translator.translate_query(mm_query)
121
- if DEBUG:
122
- print("Translated v0 Query:")
123
- print(v0_query)
124
- assert isinstance(v0_query, v0.Task)
116
+ with debugging.span("v1") as span:
117
+ with debugging.span("Original") as span:
118
+ span["metamodel"] = str(mm_query)
119
+ with debugging.span("Typer") as span:
120
+ mm_query = typer.infer_query(mm_query)
121
+ span["metamodel"] = str(mm_query)
122
+ span["typed_mm"] = str(mm_query)
123
+ debugger_msgs.append(json.dumps({ 'id': 'typed_query', 'content': str(mm_query)}))
124
+ assert(typer.last_net is not None)
125
+ debugger_msgs.append(json.dumps({ 'id': 'query_net', 'content': str(typer.last_net.to_mermaid())}))
126
+ with debugging.span("Normalizer") as span:
127
+ mm_query = normalizer.normalize(mm_query) # type: ignore
128
+ assert isinstance(mm_query, mm.Task)
129
+ with debugging.span("v0 Translation") as span:
130
+ v0_query = translator.translate_query(mm_query)
131
+ span["metamodel"] = str(v0_query)
132
+ assert isinstance(v0_query, v0.Task)
125
133
 
126
134
  if v0_model is None:
127
135
  # there was no model, so create one from the elements refered to by the query
@@ -139,13 +147,14 @@ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp"
139
147
  f.write('\n')
140
148
 
141
149
  if DRY_RUN or get_config().get("compiler.dry_run", False):
142
- results = []
150
+ results = DataFrame()
143
151
  else:
144
152
  # create snowflake tables for all the tables that have been used
145
153
  ts = [v0Table(t.name) for t in translator.used_tables if not t.uri.startswith("dataframe://")]
146
154
  for t in ts:
147
- t._lazy_init()
148
- v0Table._used_sources.add(t)
155
+ if not any(v0t._fqn == t._fqn for v0t in v0Table._used_sources):
156
+ t._lazy_init()
157
+ v0Table._used_sources.add(t)
149
158
 
150
159
  export_table = None
151
160
  if export_to:
@@ -153,18 +162,17 @@ def execute_mm(mm_query: mm.Task, mm_model: mm.Model|None = None, executor="lqp"
153
162
 
154
163
  # get an executor and execute
155
164
  executor = _get_executor(executor, model.name if model else "")
156
- with debugging.span("query", tag=None, export_to=export_to, dsl="", **with_source(mm_query)) as query_span:
157
- if isinstance(executor, (LQPExecutor, RelExecutor)):
158
- results = executor.execute(v0_model, v0_query, export_to=export_table, update=update)
159
- else:
160
- results = executor.execute(v0_model, v0_query)
161
- query_span["results"] = results
162
- if DEBUG or PRINT_RESULT:
165
+ if isinstance(executor, (LQPExecutor, RelExecutor)):
166
+ results = executor.execute(v0_model, v0_query, export_to=export_table, update=update, meta=meta)
167
+ else:
168
+ results = executor.execute(v0_model, v0_query)
169
+ if PRINT_RESULT:
163
170
  print(results)
164
171
  return results
165
172
 
166
173
  @lru_cache()
167
174
  def _get_executor(name: str, database: str = "ttb_test"):
175
+ v0executor.SUPPRESS_TYPE_ERRORS = True
168
176
  if name == "duckdb":
169
177
  from v0.relationalai.semantics.sql.executor.duck_db import DuckDBExecutor
170
178
  return DuckDBExecutor()
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+ import os
3
+
4
+ import v0.relationalai as rai
5
+ from v0.relationalai.clients import config as cfg
6
+
7
def create_engine(engine_name: str, size: str, use_direct_access=False):
    """Provision a LOGIC engine named `engine_name` of the given `size`,
    building the provider config via `make_config`."""
    print('create_engine: about to call make_config')
    provider = rai.Resources(config=make_config(engine_name, use_direct_access=use_direct_access))
    print(f"Creating engine {engine_name}")
    provider.create_engine(name=engine_name, type="LOGIC", size=size)
    print(f"Engine {engine_name} created")
15
+
16
def delete_engine(engine_name: str, use_direct_access=False):
    """Tear down the LOGIC engine named `engine_name`, building the provider
    config via `make_config`."""
    print(f"Deleting engine {engine_name}")
    provider = rai.Resources(config=make_config(engine_name, use_direct_access=use_direct_access))
    provider.delete_engine(engine_name, "LOGIC")
    print(f"Engine {engine_name} deleted")
22
+
23
def make_config(engine_name: str | None = None, fetch_profile: bool = True, use_direct_access = False, show_full_traces = True, show_debug_logs = True, reuse_model = False) -> cfg.Config:
    """Build a test Config, preferring an on-disk raiconfig.toml and falling
    back to environment variables.

    Raises ValueError when no cloud provider is configured or the provider is
    unsupported, or when required Snowflake credentials are missing.
    """
    # First try to load from raiconfig.toml
    try:
        config = cfg.Config()
        if config.file_path is not None:
            # Set test defaults
            config.set("reuse_model", reuse_model)
            # Validate the loaded config by constructing a provider with it.
            rai.Resources(config=config)
            return config
    except Exception as e:
        print(f"Could not load config from file: {e}, trying environment variables")
    # If that fails (or no config file exists), construct from environment variables.
    # Fix: the original re-read RAI_CLOUD_PROVIDER a second time when the first
    # read returned None — a redundant, dead re-check that has been removed.
    cloud_provider = os.getenv("RAI_CLOUD_PROVIDER")

    print('cloud provider:', cloud_provider)

    if cloud_provider is None:
        raise ValueError("RAI_CLOUD_PROVIDER must be set")
    elif cloud_provider == "snowflake":
        sf_username = os.getenv("SF_TEST_ACCOUNT_USERNAME")
        sf_password = os.getenv("SF_TEST_ACCOUNT_PASSWORD")
        sf_account = os.getenv("SF_TEST_ACCOUNT_NAME")
        sf_role = os.getenv("SF_TEST_ROLE_NAME", "RAI_USER")
        sf_warehouse = os.getenv("SF_TEST_WAREHOUSE_NAME")
        sf_app_name = os.getenv("SF_TEST_APP_NAME")
        # NOTE(review): the message below also names SF_TEST_ACCOUNT_NAME, but
        # only username/password are actually checked — confirm whether a
        # missing account should also raise here.
        if sf_username is None or sf_password is None:
            raise ValueError(
                "SF_TEST_ACCOUNT_USERNAME, SF_TEST_ACCOUNT_PASSWORD, SF_TEST_ACCOUNT_NAME must be set if RAI_CLOUD_PROVIDER is set to 'snowflake'"
            )

        current_config = {
            "platform": "snowflake",
            "user": sf_username,
            "password": sf_password,
            "account": sf_account,
            "role": sf_role,
            "warehouse": sf_warehouse,
            "rai_app_name": sf_app_name,
            "use_direct_access": use_direct_access,
            "show_full_traces": show_full_traces,
            "show_debug_logs": show_debug_logs,
            "reuse_model": reuse_model,
        }
        if engine_name:
            current_config["engine"] = engine_name

        # For direct access, we use key-pair as the default for testing and need to configure additional parameters.
        if use_direct_access:
            current_config["authenticator"] = os.getenv("AUTHENTICATOR", "")
            current_config["private_key_file"] = os.getenv("PRIVATE_KEY_FILE", "")
            current_config["private_key_file_pwd"] = os.getenv("SF_TEST_ACCOUNT_KEY_PASSPHRASE", "")

        return cfg.Config(current_config, fetch=fetch_profile)

    else:
        raise ValueError(f"Unsupported cloud provider: {cloud_provider}")
@@ -1,10 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import dataclass
4
3
  from typing import Iterable, Sequence as seq
5
4
 
6
5
  from relationalai.semantics.metamodel import metamodel as mm
7
- from relationalai.semantics.metamodel.builtins import builtins as b
8
6
  from relationalai.semantics.metamodel.rewriter import NO_WALK, Walker
9
7
  from relationalai.util.structures import OrderedSet
10
8
 
@@ -105,6 +103,32 @@ class VarFinder(ContainerWalker):
105
103
  def var(self, node:mm.Var):
106
104
  self.vars.add(node)
107
105
 
106
class NodeFinder(ContainerWalker):
    """Walker that answers whether one metamodel node occurs somewhere inside
    the subtree rooted at a given task."""

    def is_descendant(self, node: mm.Node, ancestor: mm.Task) -> bool:
        """True if node is a descendant of the ancestor task, i.e. one of its children,
        or a child of one of its children, etc.
        """
        self.node = node
        self.found = False
        try:
            self(ancestor)
            return self.found
        finally:
            # Reset walker state so the same finder instance can be reused.
            self.node = None
            self.found = False

    def enter_container(self, container, children):
        # Once the target has been seen there is no point descending further;
        # prune the remainder of the walk.
        if self.found:
            return NO_WALK
        return None

    def _visit_node(self, node: mm.Node):
        if node == self.node:
            self.found = True
            return node
        return super()._visit_node(node)
131
+
108
132
  #------------------------------------------------------
109
133
  # Helper functions shared across shim modules
110
134
  #------------------------------------------------------
@@ -1,15 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
- import itertools
5
4
  from dataclasses import dataclass
6
- from typing import Iterable, Sequence as seq
5
+ from typing import Sequence as seq
7
6
 
8
7
  from relationalai.semantics.metamodel import metamodel as mm
9
8
  from relationalai.semantics.metamodel.builtins import builtins as b
10
- from relationalai.semantics.metamodel.rewriter import NO_WALK, Walker
11
9
  from relationalai.util.structures import OrderedSet
12
- from .helpers import is_output_update, is_main_output, ContainerWalker
10
+ from .helpers import is_output_update, is_main_output, ContainerWalker, NodeFinder
13
11
 
14
12
 
15
13
  @dataclass(frozen=True)
@@ -60,6 +58,8 @@ class Scope:
60
58
  class Hoister(ContainerWalker):
61
59
  def __init__(self):
62
60
  super().__init__()
61
+ # helper for finding children nodes
62
+ self.finder = NodeFinder()
63
63
  # the result of the whole analysis
64
64
  self.hoists: dict[mm.Container, OrderedSet[mm.Var]] = defaultdict(OrderedSet[mm.Var])
65
65
  self.container_vars: dict[mm.Container, set[mm.Var]] = defaultdict(set)
@@ -130,6 +130,15 @@ class Hoister(ContainerWalker):
130
130
  self.hoists[child] = parent_hoists
131
131
  self.allowed_vars[child] = allowed
132
132
  self.allowed_vars[container] = allowed
133
+ # this might be a match/union that just acts as a filter for some sibling logical,
134
+ # so we need to create an artificial demand for the allowed vars in the parent container
135
+ # to hoist them up if there's no other output providing them.
136
+ # ex: define(foo(Person)).where(union(Person.age > 10, Person.age < 5))
137
+ # parent = self.scope.parent.get(container, None)
138
+ # if parent:
139
+ # for allowed_var in allowed:
140
+ # self._register_use(allowed_var, input=True, container=parent)
141
+ # self.scope.var_refs[allowed_var]
133
142
 
134
143
  if container.scope:
135
144
  # when leaving the scope, compute hoisted vars
@@ -141,6 +150,13 @@ class Hoister(ContainerWalker):
141
150
  for ref in var_refs:
142
151
  if ref.is_input() and not any(self.scope.is_ancestor(other.node, ref.node) for other in var_refs if other != ref and other.is_output()):
143
152
  requires_hoist.add(ref)
153
+ # if this is a require, the ref is in the check or error, and there's another
154
+ # ref in the domain, we need to hoist in the domain to make it available to
155
+ # check/error
156
+ elif isinstance(container, mm.Require):
157
+ if self.finder.is_descendant(ref.node, container.check) or ( container.error and self.finder.is_descendant(ref.node, container.error)):
158
+ if any(self.finder.is_descendant(other.node, container.domain) for other in var_refs if other != ref):
159
+ requires_hoist.add(ref)
144
160
 
145
161
  # for references that require a hoist, find some other output and hoist
146
162
  # that up until the least common ancestor of the two references
@@ -150,7 +166,9 @@ class Hoister(ContainerWalker):
150
166
  # find least common ancestor of the two references
151
167
  lca = self.scope.least_common_ancestor(ref.node, other.node)
152
168
  assert(lca is not None)
153
- if var not in self.container_vars[lca]:
169
+ # we should never hoist vars where the common ancestor is a match/union because that
170
+ # means we're hopping variables across branches
171
+ if var not in self.container_vars[lca] and not isinstance(lca, (mm.Match, mm.Union)):
154
172
  self._hoist_until(var, other.node, lca)
155
173
 
156
174
  # pop container from stack
@@ -188,9 +206,9 @@ class Hoister(ContainerWalker):
188
206
  self.scope.var_refs[var].add(VarRef(parent, input))
189
207
  self.scope.available_vars[parent].add(var)
190
208
 
191
def lookup(self, lookup_node: mm.Lookup):
    """Register every argument of the lookup with the scope tracker.

    Non-input fields are registered as outputs; so is every argument of the
    built-in `=` relation (equality binds in either direction). All other
    input fields are registered as inputs.
    """
    # The relation is fixed for the whole lookup, so resolve equality once.
    is_equality = lookup_node.relation == b.core["="]
    for arg, field in zip(lookup_node.args, lookup_node.relation.fields):
        consumed = field.input and not is_equality
        self._register_use(arg, input=consumed)
@@ -209,7 +227,8 @@ class Hoister(ContainerWalker):
209
227
  self._register_use(arg, input=True, container=main_output_container)
210
228
 
211
229
  def aggregate(self, a: mm.Aggregate):
212
- self._register_use(a.projection, input=True)
230
+ for p in a.projection:
231
+ self._register_use(p, input=True)
213
232
  for g in a.group:
214
233
  self._register_use(g, input=True)
215
234
  for arg, field in zip(a.args, a.aggregation.fields):