relationalai 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/direct_access_client.py +5 -0
- relationalai/clients/snowflake.py +259 -91
- relationalai/clients/types.py +4 -1
- relationalai/clients/use_index_poller.py +96 -55
- relationalai/clients/util.py +9 -0
- relationalai/dsl.py +1 -2
- relationalai/environments/snowbook.py +10 -1
- relationalai/experimental/solvers.py +283 -79
- relationalai/semantics/internal/internal.py +24 -5
- relationalai/semantics/lqp/executor.py +22 -6
- relationalai/semantics/lqp/model2lqp.py +4 -2
- relationalai/semantics/metamodel/executor.py +2 -1
- relationalai/semantics/metamodel/rewrite/flatten.py +8 -7
- relationalai/semantics/reasoners/graph/core.py +1174 -226
- relationalai/semantics/rel/executor.py +30 -12
- relationalai/semantics/sql/executor/snowflake.py +1 -1
- relationalai/tools/cli.py +6 -2
- relationalai/tools/cli_controls.py +334 -352
- relationalai/tools/constants.py +1 -0
- relationalai/tools/query_utils.py +27 -0
- relationalai/util/otel_configuration.py +1 -1
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/METADATA +1 -1
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/RECORD +26 -25
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/WHEEL +0 -0
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.0.dist-info → relationalai-0.12.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,7 +18,8 @@ from relationalai.clients.snowflake import APP_NAME
|
|
|
18
18
|
from relationalai.semantics.metamodel import ir, executor as e, factory as f
|
|
19
19
|
from relationalai.semantics.rel import Compiler
|
|
20
20
|
from relationalai.clients.config import Config
|
|
21
|
-
from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation
|
|
21
|
+
from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation, QUERY_ATTRIBUTES_HEADER
|
|
22
|
+
from relationalai.tools.query_utils import prepare_metadata_for_headers
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class RelExecutor(e.Executor):
|
|
@@ -57,10 +58,10 @@ class RelExecutor(e.Executor):
|
|
|
57
58
|
if not self.dry_run:
|
|
58
59
|
self.engine = self._resources.get_default_engine_name()
|
|
59
60
|
if not self.keep_model:
|
|
60
|
-
atexit.register(self._resources.delete_graph, self.database, True)
|
|
61
|
+
atexit.register(self._resources.delete_graph, self.database, True, "rel")
|
|
61
62
|
return self._resources
|
|
62
63
|
|
|
63
|
-
def check_graph_index(self):
|
|
64
|
+
def check_graph_index(self, headers: dict[str, Any] | None = None):
|
|
64
65
|
# Has to happen first, so self.dry_run is populated.
|
|
65
66
|
resources = self.resources
|
|
66
67
|
|
|
@@ -84,7 +85,16 @@ class RelExecutor(e.Executor):
|
|
|
84
85
|
assert self.engine is not None
|
|
85
86
|
|
|
86
87
|
with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
|
|
87
|
-
resources.poll_use_index(
|
|
88
|
+
resources.poll_use_index(
|
|
89
|
+
app_name=app_name,
|
|
90
|
+
sources=sources,
|
|
91
|
+
model=model,
|
|
92
|
+
engine_name=self.engine,
|
|
93
|
+
engine_size=engine_size,
|
|
94
|
+
language="rel",
|
|
95
|
+
program_span_id=program_span_id,
|
|
96
|
+
headers=headers,
|
|
97
|
+
)
|
|
88
98
|
|
|
89
99
|
def report_errors(self, problems: list[dict[str, Any]], abort_on_error=True):
|
|
90
100
|
from relationalai import errors
|
|
@@ -143,7 +153,7 @@ class RelExecutor(e.Executor):
|
|
|
143
153
|
elif len(all_errors) > 1:
|
|
144
154
|
raise errors.RAIExceptionSet(all_errors)
|
|
145
155
|
|
|
146
|
-
def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool):
|
|
156
|
+
def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
|
|
147
157
|
_exec = self.resources._exec
|
|
148
158
|
output_table = "out" + str(uuid.uuid4()).replace("-", "_")
|
|
149
159
|
txn_id = None
|
|
@@ -153,7 +163,7 @@ class RelExecutor(e.Executor):
|
|
|
153
163
|
with debugging.span("transaction"):
|
|
154
164
|
try:
|
|
155
165
|
with debugging.span("exec_format") as span:
|
|
156
|
-
res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True])
|
|
166
|
+
res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True, False, None])
|
|
157
167
|
txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
|
|
158
168
|
span["txn_id"] = txn_id
|
|
159
169
|
|
|
@@ -229,14 +239,18 @@ class RelExecutor(e.Executor):
|
|
|
229
239
|
else:
|
|
230
240
|
raise e
|
|
231
241
|
if txn_id:
|
|
232
|
-
artifact_info = self.resources._list_exec_async_artifacts(txn_id)
|
|
242
|
+
artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
|
|
233
243
|
with debugging.span("fetch"):
|
|
234
244
|
artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")
|
|
235
245
|
return artifacts
|
|
236
246
|
|
|
237
247
|
def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
|
|
238
|
-
result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False) -> Any:
|
|
239
|
-
|
|
248
|
+
result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False, meta: dict[str, Any] | None = None) -> Any:
|
|
249
|
+
# Format meta as headers
|
|
250
|
+
json_meta = prepare_metadata_for_headers(meta)
|
|
251
|
+
headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
|
|
252
|
+
|
|
253
|
+
self.check_graph_index(headers)
|
|
240
254
|
resources= self.resources
|
|
241
255
|
|
|
242
256
|
rules_code = ""
|
|
@@ -277,12 +291,16 @@ class RelExecutor(e.Executor):
|
|
|
277
291
|
|
|
278
292
|
if not export_to:
|
|
279
293
|
if format == "pandas":
|
|
280
|
-
raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True)
|
|
294
|
+
raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True, headers=headers)
|
|
281
295
|
df, errs = result_helpers.format_results(raw_results, None, cols, generation=Generation.QB) # Pass None for task parameter
|
|
282
296
|
self.report_errors(errs)
|
|
297
|
+
# Rename columns if wide outputs is enabled
|
|
298
|
+
if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
|
|
299
|
+
df.columns = cols[: len(df.columns)]
|
|
300
|
+
|
|
283
301
|
return self._postprocess_df(self.config, df, extra_cols)
|
|
284
302
|
elif format == "snowpark":
|
|
285
|
-
results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True)
|
|
303
|
+
results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True, headers=headers)
|
|
286
304
|
if raw:
|
|
287
305
|
df, errs = result_helpers.format_results(raw, None, cols, generation=Generation.QB) # Pass None for task parameter
|
|
288
306
|
self.report_errors(errs)
|
|
@@ -296,7 +314,7 @@ class RelExecutor(e.Executor):
|
|
|
296
314
|
else:
|
|
297
315
|
result_cols = [col for col in cols if col not in extra_cols]
|
|
298
316
|
assert result_cols
|
|
299
|
-
raw = self._export(full_code, export_to, cols, result_cols, update)
|
|
317
|
+
raw = self._export(full_code, export_to, cols, result_cols, update, headers)
|
|
300
318
|
errors = []
|
|
301
319
|
if raw:
|
|
302
320
|
dataframe, errors = result_helpers.format_results(raw, None, result_cols, generation=Generation.QB)
|
|
@@ -62,7 +62,7 @@ class SnowflakeExecutor(e.Executor):
|
|
|
62
62
|
|
|
63
63
|
def execute(self, model: ir.Model, task: ir.Task, format:Literal["pandas", "snowpark"]="pandas",
|
|
64
64
|
result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
|
|
65
|
-
update: bool = False) -> Union[pd.DataFrame, Any]:
|
|
65
|
+
update: bool = False, meta: dict[str, Any] | None = None) -> Union[pd.DataFrame, Any]:
|
|
66
66
|
""" Execute the SQL query directly. """
|
|
67
67
|
|
|
68
68
|
warehouse = self.resources.config.get("warehouse", None)
|
relationalai/tools/cli.py
CHANGED
|
@@ -719,7 +719,10 @@ def config_check(all_profiles:bool=False):
|
|
|
719
719
|
@cli.command(help="Print version info")
|
|
720
720
|
def version():
|
|
721
721
|
from .. import __version__
|
|
722
|
-
|
|
722
|
+
try:
|
|
723
|
+
from railib import __version__ as railib_version
|
|
724
|
+
except Exception:
|
|
725
|
+
railib_version = None
|
|
723
726
|
|
|
724
727
|
table = Table(show_header=False, border_style="dim", header_style="bold", box=rich_box.SIMPLE)
|
|
725
728
|
def print_version(name, version, latest=None):
|
|
@@ -730,7 +733,8 @@ def version():
|
|
|
730
733
|
|
|
731
734
|
divider()
|
|
732
735
|
print_version("RelationalAI", __version__, latest_version("relationalai"))
|
|
733
|
-
|
|
736
|
+
if railib_version is not None:
|
|
737
|
+
print_version("Rai-sdk", railib_version, latest_version("rai-sdk"))
|
|
734
738
|
print_version("Python", sys.version.split()[0])
|
|
735
739
|
|
|
736
740
|
app_version = None
|