relationalai 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,7 +18,8 @@ from relationalai.clients.snowflake import APP_NAME
18
18
  from relationalai.semantics.metamodel import ir, executor as e, factory as f
19
19
  from relationalai.semantics.rel import Compiler
20
20
  from relationalai.clients.config import Config
21
- from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation
21
+ from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation, QUERY_ATTRIBUTES_HEADER
22
+ from relationalai.tools.query_utils import prepare_metadata_for_headers
22
23
 
23
24
 
24
25
  class RelExecutor(e.Executor):
@@ -57,10 +58,10 @@ class RelExecutor(e.Executor):
57
58
  if not self.dry_run:
58
59
  self.engine = self._resources.get_default_engine_name()
59
60
  if not self.keep_model:
60
- atexit.register(self._resources.delete_graph, self.database, True)
61
+ atexit.register(self._resources.delete_graph, self.database, True, "rel")
61
62
  return self._resources
62
63
 
63
- def check_graph_index(self):
64
+ def check_graph_index(self, headers: dict[str, Any] | None = None):
64
65
  # Has to happen first, so self.dry_run is populated.
65
66
  resources = self.resources
66
67
 
@@ -84,7 +85,16 @@ class RelExecutor(e.Executor):
84
85
  assert self.engine is not None
85
86
 
86
87
  with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
87
- resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id)
88
+ resources.poll_use_index(
89
+ app_name=app_name,
90
+ sources=sources,
91
+ model=model,
92
+ engine_name=self.engine,
93
+ engine_size=engine_size,
94
+ language="rel",
95
+ program_span_id=program_span_id,
96
+ headers=headers,
97
+ )
88
98
 
89
99
  def report_errors(self, problems: list[dict[str, Any]], abort_on_error=True):
90
100
  from relationalai import errors
@@ -143,7 +153,7 @@ class RelExecutor(e.Executor):
143
153
  elif len(all_errors) > 1:
144
154
  raise errors.RAIExceptionSet(all_errors)
145
155
 
146
- def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool):
156
+ def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
147
157
  _exec = self.resources._exec
148
158
  output_table = "out" + str(uuid.uuid4()).replace("-", "_")
149
159
  txn_id = None
@@ -153,7 +163,7 @@ class RelExecutor(e.Executor):
153
163
  with debugging.span("transaction"):
154
164
  try:
155
165
  with debugging.span("exec_format") as span:
156
- res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True])
166
+ res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True, False, None])
157
167
  txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
158
168
  span["txn_id"] = txn_id
159
169
 
@@ -229,14 +239,18 @@ class RelExecutor(e.Executor):
229
239
  else:
230
240
  raise e
231
241
  if txn_id:
232
- artifact_info = self.resources._list_exec_async_artifacts(txn_id)
242
+ artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
233
243
  with debugging.span("fetch"):
234
244
  artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")
235
245
  return artifacts
236
246
 
237
247
  def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
238
- result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False) -> Any:
239
- self.check_graph_index()
248
+ result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False, meta: dict[str, Any] | None = None) -> Any:
249
+ # Format meta as headers
250
+ json_meta = prepare_metadata_for_headers(meta)
251
+ headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
252
+
253
+ self.check_graph_index(headers)
240
254
  resources= self.resources
241
255
 
242
256
  rules_code = ""
@@ -277,12 +291,16 @@ class RelExecutor(e.Executor):
277
291
 
278
292
  if not export_to:
279
293
  if format == "pandas":
280
- raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True)
294
+ raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True, headers=headers)
281
295
  df, errs = result_helpers.format_results(raw_results, None, cols, generation=Generation.QB) # Pass None for task parameter
282
296
  self.report_errors(errs)
297
+ # Rename columns if wide outputs is enabled
298
+ if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
299
+ df.columns = cols[: len(df.columns)]
300
+
283
301
  return self._postprocess_df(self.config, df, extra_cols)
284
302
  elif format == "snowpark":
285
- results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True)
303
+ results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True, headers=headers)
286
304
  if raw:
287
305
  df, errs = result_helpers.format_results(raw, None, cols, generation=Generation.QB) # Pass None for task parameter
288
306
  self.report_errors(errs)
@@ -296,7 +314,7 @@ class RelExecutor(e.Executor):
296
314
  else:
297
315
  result_cols = [col for col in cols if col not in extra_cols]
298
316
  assert result_cols
299
- raw = self._export(full_code, export_to, cols, result_cols, update)
317
+ raw = self._export(full_code, export_to, cols, result_cols, update, headers)
300
318
  errors = []
301
319
  if raw:
302
320
  dataframe, errors = result_helpers.format_results(raw, None, result_cols, generation=Generation.QB)
@@ -62,7 +62,7 @@ class SnowflakeExecutor(e.Executor):
62
62
 
63
63
  def execute(self, model: ir.Model, task: ir.Task, format:Literal["pandas", "snowpark"]="pandas",
64
64
  result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
65
- update: bool = False) -> Union[pd.DataFrame, Any]:
65
+ update: bool = False, meta: dict[str, Any] | None = None) -> Union[pd.DataFrame, Any]:
66
66
  """ Execute the SQL query directly. """
67
67
 
68
68
  warehouse = self.resources.config.get("warehouse", None)
relationalai/tools/cli.py CHANGED
@@ -719,7 +719,10 @@ def config_check(all_profiles:bool=False):
719
719
  @cli.command(help="Print version info")
720
720
  def version():
721
721
  from .. import __version__
722
- from railib import __version__ as railib_version
722
+ try:
723
+ from railib import __version__ as railib_version
724
+ except Exception:
725
+ railib_version = None
723
726
 
724
727
  table = Table(show_header=False, border_style="dim", header_style="bold", box=rich_box.SIMPLE)
725
728
  def print_version(name, version, latest=None):
@@ -730,7 +733,8 @@ def version():
730
733
 
731
734
  divider()
732
735
  print_version("RelationalAI", __version__, latest_version("relationalai"))
733
- print_version("Rai-sdk", railib_version, latest_version("rai-sdk"))
736
+ if railib_version is not None:
737
+ print_version("Rai-sdk", railib_version, latest_version("rai-sdk"))
734
738
  print_version("Python", sys.version.split()[0])
735
739
 
736
740
  app_version = None