relationalai 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/snowflake.py +228 -83
- relationalai/clients/types.py +4 -1
- relationalai/clients/use_index_poller.py +72 -48
- relationalai/clients/util.py +9 -0
- relationalai/dsl.py +1 -2
- relationalai/environments/snowbook.py +10 -1
- relationalai/semantics/internal/internal.py +22 -3
- relationalai/semantics/lqp/executor.py +12 -4
- relationalai/semantics/lqp/model2lqp.py +1 -0
- relationalai/semantics/metamodel/executor.py +2 -1
- relationalai/semantics/metamodel/rewrite/flatten.py +8 -7
- relationalai/semantics/reasoners/graph/core.py +1174 -226
- relationalai/semantics/rel/executor.py +20 -11
- relationalai/semantics/sql/executor/snowflake.py +1 -1
- relationalai/tools/cli.py +6 -2
- relationalai/tools/cli_controls.py +334 -352
- relationalai/tools/constants.py +1 -0
- relationalai/tools/query_utils.py +27 -0
- relationalai/util/otel_configuration.py +1 -1
- {relationalai-0.12.0.dist-info → relationalai-0.12.1.dist-info}/METADATA +1 -1
- {relationalai-0.12.0.dist-info → relationalai-0.12.1.dist-info}/RECORD +24 -23
- {relationalai-0.12.0.dist-info → relationalai-0.12.1.dist-info}/WHEEL +0 -0
- {relationalai-0.12.0.dist-info → relationalai-0.12.1.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.0.dist-info → relationalai-0.12.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,7 +18,8 @@ from relationalai.clients.snowflake import APP_NAME
|
|
|
18
18
|
from relationalai.semantics.metamodel import ir, executor as e, factory as f
|
|
19
19
|
from relationalai.semantics.rel import Compiler
|
|
20
20
|
from relationalai.clients.config import Config
|
|
21
|
-
from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation
|
|
21
|
+
from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation, QUERY_ATTRIBUTES_HEADER
|
|
22
|
+
from relationalai.tools.query_utils import prepare_metadata_for_headers
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class RelExecutor(e.Executor):
|
|
@@ -60,7 +61,7 @@ class RelExecutor(e.Executor):
|
|
|
60
61
|
atexit.register(self._resources.delete_graph, self.database, True)
|
|
61
62
|
return self._resources
|
|
62
63
|
|
|
63
|
-
def check_graph_index(self):
|
|
64
|
+
def check_graph_index(self, headers: dict[str, Any] | None = None):
|
|
64
65
|
# Has to happen first, so self.dry_run is populated.
|
|
65
66
|
resources = self.resources
|
|
66
67
|
|
|
@@ -84,7 +85,7 @@ class RelExecutor(e.Executor):
|
|
|
84
85
|
assert self.engine is not None
|
|
85
86
|
|
|
86
87
|
with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
|
|
87
|
-
resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id)
|
|
88
|
+
resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id, headers)
|
|
88
89
|
|
|
89
90
|
def report_errors(self, problems: list[dict[str, Any]], abort_on_error=True):
|
|
90
91
|
from relationalai import errors
|
|
@@ -143,7 +144,7 @@ class RelExecutor(e.Executor):
|
|
|
143
144
|
elif len(all_errors) > 1:
|
|
144
145
|
raise errors.RAIExceptionSet(all_errors)
|
|
145
146
|
|
|
146
|
-
def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool):
|
|
147
|
+
def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
|
|
147
148
|
_exec = self.resources._exec
|
|
148
149
|
output_table = "out" + str(uuid.uuid4()).replace("-", "_")
|
|
149
150
|
txn_id = None
|
|
@@ -153,7 +154,7 @@ class RelExecutor(e.Executor):
|
|
|
153
154
|
with debugging.span("transaction"):
|
|
154
155
|
try:
|
|
155
156
|
with debugging.span("exec_format") as span:
|
|
156
|
-
res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True])
|
|
157
|
+
res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True, False, None])
|
|
157
158
|
txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
|
|
158
159
|
span["txn_id"] = txn_id
|
|
159
160
|
|
|
@@ -229,14 +230,18 @@ class RelExecutor(e.Executor):
|
|
|
229
230
|
else:
|
|
230
231
|
raise e
|
|
231
232
|
if txn_id:
|
|
232
|
-
artifact_info = self.resources._list_exec_async_artifacts(txn_id)
|
|
233
|
+
artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
|
|
233
234
|
with debugging.span("fetch"):
|
|
234
235
|
artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")
|
|
235
236
|
return artifacts
|
|
236
237
|
|
|
237
238
|
def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
|
|
238
|
-
result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False) -> Any:
|
|
239
|
-
|
|
239
|
+
result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False, meta: dict[str, Any] | None = None) -> Any:
|
|
240
|
+
# Format meta as headers
|
|
241
|
+
json_meta = prepare_metadata_for_headers(meta)
|
|
242
|
+
headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
|
|
243
|
+
|
|
244
|
+
self.check_graph_index(headers)
|
|
240
245
|
resources= self.resources
|
|
241
246
|
|
|
242
247
|
rules_code = ""
|
|
@@ -277,12 +282,16 @@ class RelExecutor(e.Executor):
|
|
|
277
282
|
|
|
278
283
|
if not export_to:
|
|
279
284
|
if format == "pandas":
|
|
280
|
-
raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True)
|
|
285
|
+
raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True, headers=headers)
|
|
281
286
|
df, errs = result_helpers.format_results(raw_results, None, cols, generation=Generation.QB) # Pass None for task parameter
|
|
282
287
|
self.report_errors(errs)
|
|
288
|
+
# Rename columns if wide outputs is enabled
|
|
289
|
+
if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
|
|
290
|
+
df.columns = cols[: len(df.columns)]
|
|
291
|
+
|
|
283
292
|
return self._postprocess_df(self.config, df, extra_cols)
|
|
284
293
|
elif format == "snowpark":
|
|
285
|
-
results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True)
|
|
294
|
+
results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True, headers=headers)
|
|
286
295
|
if raw:
|
|
287
296
|
df, errs = result_helpers.format_results(raw, None, cols, generation=Generation.QB) # Pass None for task parameter
|
|
288
297
|
self.report_errors(errs)
|
|
@@ -296,7 +305,7 @@ class RelExecutor(e.Executor):
|
|
|
296
305
|
else:
|
|
297
306
|
result_cols = [col for col in cols if col not in extra_cols]
|
|
298
307
|
assert result_cols
|
|
299
|
-
raw = self._export(full_code, export_to, cols, result_cols, update)
|
|
308
|
+
raw = self._export(full_code, export_to, cols, result_cols, update, headers)
|
|
300
309
|
errors = []
|
|
301
310
|
if raw:
|
|
302
311
|
dataframe, errors = result_helpers.format_results(raw, None, result_cols, generation=Generation.QB)
|
|
@@ -62,7 +62,7 @@ class SnowflakeExecutor(e.Executor):
|
|
|
62
62
|
|
|
63
63
|
def execute(self, model: ir.Model, task: ir.Task, format:Literal["pandas", "snowpark"]="pandas",
|
|
64
64
|
result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
|
|
65
|
-
update: bool = False) -> Union[pd.DataFrame, Any]:
|
|
65
|
+
update: bool = False, meta: dict[str, Any] | None = None) -> Union[pd.DataFrame, Any]:
|
|
66
66
|
""" Execute the SQL query directly. """
|
|
67
67
|
|
|
68
68
|
warehouse = self.resources.config.get("warehouse", None)
|
relationalai/tools/cli.py
CHANGED
|
@@ -719,7 +719,10 @@ def config_check(all_profiles:bool=False):
|
|
|
719
719
|
@cli.command(help="Print version info")
|
|
720
720
|
def version():
|
|
721
721
|
from .. import __version__
|
|
722
|
-
|
|
722
|
+
try:
|
|
723
|
+
from railib import __version__ as railib_version
|
|
724
|
+
except Exception:
|
|
725
|
+
railib_version = None
|
|
723
726
|
|
|
724
727
|
table = Table(show_header=False, border_style="dim", header_style="bold", box=rich_box.SIMPLE)
|
|
725
728
|
def print_version(name, version, latest=None):
|
|
@@ -730,7 +733,8 @@ def version():
|
|
|
730
733
|
|
|
731
734
|
divider()
|
|
732
735
|
print_version("RelationalAI", __version__, latest_version("relationalai"))
|
|
733
|
-
|
|
736
|
+
if railib_version is not None:
|
|
737
|
+
print_version("Rai-sdk", railib_version, latest_version("rai-sdk"))
|
|
734
738
|
print_version("Python", sys.version.split()[0])
|
|
735
739
|
|
|
736
740
|
app_version = None
|