relationalai 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,7 +18,8 @@ from relationalai.clients.snowflake import APP_NAME
18
18
  from relationalai.semantics.metamodel import ir, executor as e, factory as f
19
19
  from relationalai.semantics.rel import Compiler
20
20
  from relationalai.clients.config import Config
21
- from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation
21
+ from relationalai.tools.constants import USE_DIRECT_ACCESS, Generation, QUERY_ATTRIBUTES_HEADER
22
+ from relationalai.tools.query_utils import prepare_metadata_for_headers
22
23
 
23
24
 
24
25
  class RelExecutor(e.Executor):
@@ -60,7 +61,7 @@ class RelExecutor(e.Executor):
60
61
  atexit.register(self._resources.delete_graph, self.database, True)
61
62
  return self._resources
62
63
 
63
- def check_graph_index(self):
64
+ def check_graph_index(self, headers: dict[str, Any] | None = None):
64
65
  # Has to happen first, so self.dry_run is populated.
65
66
  resources = self.resources
66
67
 
@@ -84,7 +85,7 @@ class RelExecutor(e.Executor):
84
85
  assert self.engine is not None
85
86
 
86
87
  with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
87
- resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id)
88
+ resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id, headers)
88
89
 
89
90
  def report_errors(self, problems: list[dict[str, Any]], abort_on_error=True):
90
91
  from relationalai import errors
@@ -143,7 +144,7 @@ class RelExecutor(e.Executor):
143
144
  elif len(all_errors) > 1:
144
145
  raise errors.RAIExceptionSet(all_errors)
145
146
 
146
- def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool):
147
+ def _export(self, raw_code: str, dest_fqn: str, actual_cols: list[str], declared_cols: list[str], update:bool, headers: dict[str, Any] | None = None):
147
148
  _exec = self.resources._exec
148
149
  output_table = "out" + str(uuid.uuid4()).replace("-", "_")
149
150
  txn_id = None
@@ -153,7 +154,7 @@ class RelExecutor(e.Executor):
153
154
  with debugging.span("transaction"):
154
155
  try:
155
156
  with debugging.span("exec_format") as span:
156
- res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True])
157
+ res = _exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, NULL, ?, {headers}, ?, ?);", [self.database, self.engine, raw_code, output_table, False, True, False, None])
157
158
  txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
158
159
  span["txn_id"] = txn_id
159
160
 
@@ -229,14 +230,18 @@ class RelExecutor(e.Executor):
229
230
  else:
230
231
  raise e
231
232
  if txn_id:
232
- artifact_info = self.resources._list_exec_async_artifacts(txn_id)
233
+ artifact_info = self.resources._list_exec_async_artifacts(txn_id, headers=headers)
233
234
  with debugging.span("fetch"):
234
235
  artifacts = self.resources._download_results(artifact_info, txn_id, "ABORTED")
235
236
  return artifacts
236
237
 
237
238
  def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
238
- result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False) -> Any:
239
- self.check_graph_index()
239
+ result_cols: list[str] | None = None, export_to: Optional[str] = None, update: bool = False, meta: dict[str, Any] | None = None) -> Any:
240
+ # Format meta as headers
241
+ json_meta = prepare_metadata_for_headers(meta)
242
+ headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
243
+
244
+ self.check_graph_index(headers)
240
245
  resources= self.resources
241
246
 
242
247
  rules_code = ""
@@ -277,12 +282,16 @@ class RelExecutor(e.Executor):
277
282
 
278
283
  if not export_to:
279
284
  if format == "pandas":
280
- raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True)
285
+ raw_results = resources.exec_raw(self.database, self.engine, full_code, False, nowait_durable=True, headers=headers)
281
286
  df, errs = result_helpers.format_results(raw_results, None, cols, generation=Generation.QB) # Pass None for task parameter
282
287
  self.report_errors(errs)
288
+ # Rename columns if wide outputs is enabled
289
+ if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
290
+ df.columns = cols[: len(df.columns)]
291
+
283
292
  return self._postprocess_df(self.config, df, extra_cols)
284
293
  elif format == "snowpark":
285
- results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True)
294
+ results, raw = resources.exec_format(self.database, self.engine, full_code, cols, format=format, readonly=False, nowait_durable=True, headers=headers)
286
295
  if raw:
287
296
  df, errs = result_helpers.format_results(raw, None, cols, generation=Generation.QB) # Pass None for task parameter
288
297
  self.report_errors(errs)
@@ -296,7 +305,7 @@ class RelExecutor(e.Executor):
296
305
  else:
297
306
  result_cols = [col for col in cols if col not in extra_cols]
298
307
  assert result_cols
299
- raw = self._export(full_code, export_to, cols, result_cols, update)
308
+ raw = self._export(full_code, export_to, cols, result_cols, update, headers)
300
309
  errors = []
301
310
  if raw:
302
311
  dataframe, errors = result_helpers.format_results(raw, None, result_cols, generation=Generation.QB)
@@ -62,7 +62,7 @@ class SnowflakeExecutor(e.Executor):
62
62
 
63
63
  def execute(self, model: ir.Model, task: ir.Task, format:Literal["pandas", "snowpark"]="pandas",
64
64
  result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
65
- update: bool = False) -> Union[pd.DataFrame, Any]:
65
+ update: bool = False, meta: dict[str, Any] | None = None) -> Union[pd.DataFrame, Any]:
66
66
  """ Execute the SQL query directly. """
67
67
 
68
68
  warehouse = self.resources.config.get("warehouse", None)
relationalai/tools/cli.py CHANGED
@@ -719,7 +719,10 @@ def config_check(all_profiles:bool=False):
719
719
  @cli.command(help="Print version info")
720
720
  def version():
721
721
  from .. import __version__
722
- from railib import __version__ as railib_version
722
+ try:
723
+ from railib import __version__ as railib_version
724
+ except Exception:
725
+ railib_version = None
723
726
 
724
727
  table = Table(show_header=False, border_style="dim", header_style="bold", box=rich_box.SIMPLE)
725
728
  def print_version(name, version, latest=None):
@@ -730,7 +733,8 @@ def version():
730
733
 
731
734
  divider()
732
735
  print_version("RelationalAI", __version__, latest_version("relationalai"))
733
- print_version("Rai-sdk", railib_version, latest_version("rai-sdk"))
736
+ if railib_version is not None:
737
+ print_version("Rai-sdk", railib_version, latest_version("rai-sdk"))
734
738
  print_version("Python", sys.version.split()[0])
735
739
 
736
740
  app_version = None