relationalai 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
@@ -9,11 +9,13 @@ from ..metamodel import Builtins
  from ..tools.cli_controls import Spinner
  from ..tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
  from .. import debugging
+ from .. errors import ResponseStatusException
  import uuid
  import relationalai
  import json
  from ..clients.util import poll_with_specified_overhead
  from ..clients.snowflake import Resources as SnowflakeResources
+ from ..clients.snowflake import DirectAccessClient, DirectAccessResources
  from ..util.timeout import calc_remaining_timeout_minutes

  rel_sv = rel._tagged(Builtins.SingleValued)
@@ -23,7 +25,8 @@ APP_NAME = relationalai.clients.snowflake.APP_NAME
  ENGINE_TYPE_SOLVER = "SOLVER"
  # TODO (dba) The ERP still uses `worker` instead of `engine`. Change
  # this once we fix this in the ERP.
- ENGINE_ERRORS = ["worker is suspended", "create/resume", "worker not found", "no workers found", "worker was deleted"]
+ WORKER_ERRORS = ["worker is suspended", "create/resume", "worker not found", "no workers found", "worker was deleted"]
+ ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
  ENGINE_NOT_READY_MSGS = ["worker is in pending", "worker is provisioning", "worker is not ready to accept jobs"]

  # --------------------------------------------------
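
Note on the hunk above: the old ENGINE_ERRORS list matched only worker-phrased messages; the rename keeps those under WORKER_ERRORS and adds engine-phrased equivalents, so callers now scan both vocabularies. A minimal sketch of that classification, using the constants exactly as defined above (the surrounding retry machinery is assumed):

    WORKER_ERRORS = ["worker is suspended", "create/resume", "worker not found", "no workers found", "worker was deleted"]
    ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
    ENGINE_NOT_READY_MSGS = ["worker is in pending", "worker is provisioning", "worker is not ready to accept jobs"]

    def is_recoverable_engine_error(err_message: str) -> bool:
        # Case-insensitive substring scan across both vocabularies, mirroring
        # the `any(kw in err_message.lower() for kw in ...)` check used below.
        msg = err_message.lower()
        return any(kw in msg for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS)
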
@@ -213,13 +216,6 @@ class SolverModel:
  config_file_path = getattr(rai_config, 'file_path', None)
  start_time = time.monotonic()
  remaining_timeout_minutes = query_timeout_mins
- # 1. Materialize the model and store it.
- # TODO(coey) Currently we must run a dummy query to install the pyrel rules in a separate txn
- # to the solve_output updates. Ideally pyrel would offer an option to flush the rules separately.
- self.graph.exec_raw("", query_timeout_mins=remaining_timeout_minutes)
- remaining_timeout_minutes = calc_remaining_timeout_minutes(
-     start_time, query_timeout_mins, config_file_path=config_file_path,
- )
  response = self.graph.exec_raw(
      textwrap.dedent(f"""
          @inline
@@ -269,7 +265,9 @@ class SolverModel:
      job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
  except Exception as e:
      err_message = str(e).lower()
-     if any(kw in err_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS):
+     if isinstance(e, ResponseStatusException):
+         err_message = e.response.json().get("message", "")
+     if any(kw in err_message.lower() for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS):
          solver._auto_create_solver_async()
          remaining_timeout_minutes = calc_remaining_timeout_minutes(
              start_time, query_timeout_mins, config_file_path=config_file_path
@@ -553,7 +551,11 @@ class Solver:
      # may configure each individual solver.
      self.engine_settings = settings

-     return self._auto_create_solver_async()
+     # Optimistically set the engine object to a `READY` engine to
+     # avoid checking the engine status on each execution.
+     self.engine:Optional[dict[str,Any]] = {"name": engine_name, "state": "READY"}
+
+     return None

  # --------------------------------------------------
  # Helper
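
The constructor change above trades an eager status check for an optimistic one: the engine dict is assumed READY, and job execution only polls when that assumption is contradicted. A compressed sketch of the pattern (the class scaffolding and the poll callback are assumptions, not the package's actual classes):

    from typing import Any, Callable, Optional

    class OptimisticEngine:
        def __init__(self, engine_name: str):
            # Assume READY up front; the real state is fetched only if a job
            # later observes a non-READY state.
            self.engine: Optional[dict[str, Any]] = {"name": engine_name, "state": "READY"}

        def ensure_ready(self, poll_ready: Callable[[], None]) -> None:
            if self.engine is None:
                raise Exception("Engine not initialized.")
            if self.engine["state"] != "READY":
                poll_ready()  # pay the status-check cost only when needed
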
@@ -572,6 +574,7 @@ class Solver:
  assert len(engines) == 1 or len(engines) == 0
  if len(engines) != 0:
      engine = engines[0]
+
  if engine:
      # TODO (dba) Logic engines support altering the
      # auto_suspend_mins setting. Currently, we don't have
@@ -653,31 +656,20 @@ class Solver:

      self.engine = engine

- def _exec_job_async(self, payload, query_timeout_mins: Optional[int]=None):
-     payload_json = json.dumps(payload)
-     engine_name = self.engine["name"]
-     if query_timeout_mins is None and (timeout_value := self.rai_config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
-         query_timeout_mins = int(timeout_value)
-     if query_timeout_mins is not None:
-         sql_string = textwrap.dedent(f"""
-             CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}', null, {query_timeout_mins})
-         """)
-     else:
-         sql_string = textwrap.dedent(f"""
-             CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}')
-         """)
-     res = self.provider.resources._exec(sql_string)
-     return res[0]["ID"]
-
  def _exec_job(self, payload, log_to_console=True, query_timeout_mins: Optional[int]=None):
+     if self.engine is None:
+         raise Exception("Engine not initialized.")
+
      # Make sure the engine is ready.
      if self.engine["state"] != "READY":
          poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)

      with debugging.span("job") as job_span:
-         job_id = self._exec_job_async(payload, query_timeout_mins=query_timeout_mins)
+         job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
          job_span["job_id"] = job_id
          debugging.event("job_created", job_span, job_id=job_id, engine_name=self.engine["name"], job_type=ENGINE_TYPE_SOLVER)
+     if not isinstance(job_id, str):
+         job_id = ""
      polling_state = PollingState(job_id, "", False, log_to_console)

      try:
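
poll_with_specified_overhead comes from ..clients.util and its body is not part of this diff. Based on the call site `poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)`, a plausible reading is a predicate loop whose sleep scales with elapsed time so polling cost stays near the given fraction; a sketch under that assumption only:

    import time

    def poll_with_specified_overhead(predicate, overhead_rate: float, max_delay: float = 30.0) -> None:
        # Assumed semantics: call predicate() until it returns truthy, sleeping
        # roughly overhead_rate * elapsed between attempts (bounded above).
        start = time.monotonic()
        while not predicate():
            elapsed = time.monotonic() - start
            time.sleep(min(max(overhead_rate * elapsed, 0.1), max_delay))
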
@@ -693,7 +685,14 @@ class Solver:
      return job_id

  def _is_solver_ready(self):
+     if self.engine is None:
+         raise Exception("Engine not initialized.")
+
      result = self.provider.get_solver(self.engine["name"])
+
+     if result is None:
+         raise Exception("No engine available.")
+
      self.engine = result
      state = result["state"]
      if state != "READY" and state != "PENDING":
@@ -711,20 +710,11 @@ class Solver:

      return status == "COMPLETED" or status == "FAILED" or status == "CANCELED"

- def _get_job_events(self, job_id: str, continuation_token: str = ""):
-     results = self.provider.resources._exec(
-         f"SELECT {APP_NAME}.experimental.get_job_events('{ENGINE_TYPE_SOLVER}', '{job_id}', '{continuation_token}');"
-     )
-     if not results:
-         return {"events": [], "continuation_token": None}
-     row = results[0][0]
-     return json.loads(row)
-
  def _print_solver_logs(self, state: PollingState):
      if state.is_done:
          return

-     resp = self._get_job_events(state.job_id, state.continuation_token)
+     resp = self.provider.get_job_events(state.job_id, state.continuation_token)

      # Print solver logs to stdout.
      for event in resp["events"]:
@@ -754,7 +744,12 @@ class Provider:
      resources = relationalai.Resources()
      if not isinstance(resources, relationalai.clients.snowflake.Resources):
          raise Exception("Solvers are only supported on SPCS.")
+
      self.resources = resources
+     self.direct_access_client: Optional[DirectAccessClient] = None
+
+     if isinstance(self.resources, DirectAccessResources):
+         self.direct_access_client = self.resources.direct_access_client

  def create_solver(
      self,
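
Every Provider method rewritten in the next hunk follows the same dispatch: when a DirectAccessClient was detected at construction, issue the request over it and convert any non-200 response into a ResponseStatusException; otherwise fall back to the SQL procedure path. A stripped-down sketch of that shape (the client and resources interfaces here are assumptions; only the branching mirrors the diff):

    class DualPathProvider:
        def __init__(self, resources, direct_access_client=None):
            self.resources = resources
            self.direct_access_client = direct_access_client  # None => SQL fallback

        def delete_engine(self, name: str) -> None:
            if self.direct_access_client is not None:
                response = self.direct_access_client.request(
                    "delete_engine",
                    path_params={"engine_type": "SOLVER", "engine_name": name},
                )
                if response.status_code != 200:
                    # The real code raises ResponseStatusException(message, response).
                    raise RuntimeError(f"Failed to delete engine: {response.status_code}")
            else:
                self.resources._exec_sql(
                    f"CALL app.experimental.delete_engine('SOLVER', '{name}');", None
                )
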
@@ -770,75 +765,285 @@ class Provider:
      engine_config: dict[str, Any] = {"settings": settings}
      if auto_suspend_mins is not None:
          engine_config["auto_suspend_mins"] = auto_suspend_mins
-     self.resources._exec(
-         f"CALL {APP_NAME}.experimental.create_engine('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});"
+     self.resources._exec_sql(
+         f"CALL {APP_NAME}.experimental.create_engine('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});", None
      )

  def create_solver_async(
      self,
      name: str,
      size: str | None = None,
-     settings: dict = {},
+     settings: dict | None = None,
      auto_suspend_mins: int | None = None,
  ):
      if size is None:
          size = "HIGHMEM_X64_S"
-     if settings is None:
-         settings = ""
-     engine_config: dict[str, Any] = {"settings": settings}
-     if auto_suspend_mins is not None:
-         engine_config["auto_suspend_mins"] = auto_suspend_mins
-     self.resources._exec(
-         f"CALL {APP_NAME}.experimental.create_engine_async('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});"
-     )
+
+     if self.direct_access_client is not None:
+         payload:dict[str, Any] = {
+             "name": name,
+             "settings": settings,
+         }
+         if auto_suspend_mins is not None:
+             payload["auto_suspend_mins"] = auto_suspend_mins
+         if size is not None:
+             payload["size"] = size
+         response = self.direct_access_client.request(
+             "create_engine",
+             payload=payload,
+             path_params={"engine_type": "solver"},
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException(
+                 f"Failed to create engine {name} with size {size}.", response
+             )
+     else:
+         engine_config: dict[str, Any] = {}
+         if settings is not None:
+             engine_config["settings"] = settings
+         if auto_suspend_mins is not None:
+             engine_config["auto_suspend_mins"] = auto_suspend_mins
+         self.resources._exec_sql(
+             f"CALL {APP_NAME}.experimental.create_engine_async('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});",
+             None
+         )

  def delete_solver(self, name: str):
-     self.resources._exec(
-         f"CALL {APP_NAME}.experimental.delete_engine('{ENGINE_TYPE_SOLVER}', '{name}');"
-     )
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "delete_engine", path_params = {"engine_type": ENGINE_TYPE_SOLVER, "engine_name": name}
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to delete engine.", response)
+         return None
+     else:
+         self.resources._exec_sql(
+             f"CALL {APP_NAME}.experimental.delete_engine('{ENGINE_TYPE_SOLVER}', '{name}');",
+             None
+         )

  def resume_solver_async(self, name: str):
-     self.resources._exec(
-         f"CALL {APP_NAME}.experimental.resume_engine_async('{ENGINE_TYPE_SOLVER}', '{name}');"
-     )
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "resume_engine", path_params = {"engine_type": ENGINE_TYPE_SOLVER, "engine_name": name}
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to resume engine.", response)
+         return None
+     else:
+         self.resources._exec_sql(
+             f"CALL {APP_NAME}.experimental.resume_engine_async('{ENGINE_TYPE_SOLVER}', '{name}');",
+             None
+         )
+         return None

  def get_solver(self, name: str):
-     results = self.resources._exec(
-         f"CALL {APP_NAME}.experimental.get_engine('{ENGINE_TYPE_SOLVER}', '{name}');"
-     )
-     return solver_list_to_dicts(results)[0]
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "get_engine", path_params = {"engine_type": ENGINE_TYPE_SOLVER, "engine_name": name}
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to get engine.", response)
+         solver = response.json()
+         if not solver :
+             return None
+         solver_state = {
+             "name": solver["name"],
+             "id": solver["id"],
+             "size": solver["size"],
+             "state": solver["status"],  # callers are expecting 'state'
+             "created_by": solver["created_by"],
+             "created_on": solver["created_on"],
+             "updated_on": solver["updated_on"],
+             "version": solver["version"],
+             "auto_suspend": solver["auto_suspend_mins"],
+             "suspends_at": solver["suspends_at"],
+             "solvers": []
+             if solver["settings"] == ""
+             else [
+                 k
+                 for (k,v) in json.loads(solver["settings"]).items()
+                 if isinstance(v, dict) and v.get("enabled", False)
+             ],
+         }
+         return solver_state
+     else:
+         results = self.resources._exec_sql(
+             f"CALL {APP_NAME}.experimental.get_engine('{ENGINE_TYPE_SOLVER}', '{name}');",
+             None
+         )
+         return solver_list_to_dicts(results)[0]

  def list_solvers(self, state: str | None = None):
-     where_clause = f"WHERE TYPE='{ENGINE_TYPE_SOLVER}'"
-     where_clause = (
-         f"{where_clause} AND STATUS = '{state.upper()}'" if state else where_clause
-     )
-     statement = f"SELECT NAME,ID,SIZE,STATUS,CREATED_BY,CREATED_ON,UPDATED_ON,AUTO_SUSPEND_MINS,SETTINGS FROM {APP_NAME}.experimental.engines {where_clause};"
-     results = self.resources._exec(statement)
-     return solver_list_to_dicts(results)
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "list_engines"
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to list engines.", response)
+         response_content = response.json()
+         if not response_content:
+             return []
+         engines = [
+             {
+                 "name": engine["name"],
+                 "id": engine["id"],
+                 "size": engine["size"],
+                 "state": engine["status"],  # callers are expecting 'state'
+                 "created_by": engine["created_by"],
+                 "created_on": engine["created_on"],
+                 "updated_on": engine["updated_on"],
+                 "auto_suspend_mins": engine["auto_suspend_mins"],
+                 "solvers": []
+                 if engine["settings"] == ""
+                 else [
+                     k
+                     for (k, v) in json.loads(engine["settings"]).items()
+                     if isinstance(v, dict) and v.get("enabled", False)
+                 ],
+             }
+             for engine in response_content.get("engines", [])
+             if (state is None or engine.get("status") == state) and (engine.get("type") == ENGINE_TYPE_SOLVER)
+         ]
+         return sorted(engines, key=lambda x: x["name"])
+     else:
+         where_clause = f"WHERE TYPE='{ENGINE_TYPE_SOLVER}'"
+         where_clause = (
+             f"{where_clause} AND STATUS = '{state.upper()}'" if state else where_clause
+         )
+         statement = f"SELECT NAME,ID,SIZE,STATUS,CREATED_BY,CREATED_ON,UPDATED_ON,AUTO_SUSPEND_MINS,SETTINGS FROM {APP_NAME}.experimental.engines {where_clause};"
+         results = self.resources._exec_sql(statement, None)
+         return solver_list_to_dicts(results)

  # --------------------------------------------------
  # Job API
  # --------------------------------------------------

+ def create_job_async(self, engine_name, payload, query_timeout_mins: Optional[int]=None):
+     payload_json = json.dumps(payload)
+
+     if query_timeout_mins is None and (timeout_value := self.resources.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+         query_timeout_mins = int(timeout_value)
+
+     if self.direct_access_client is not None:
+         job = {
+             "job_type":ENGINE_TYPE_SOLVER,
+             "worker_name": engine_name,
+             "timeout_mins": query_timeout_mins,
+             "payload": payload_json,
+         }
+         response = self.direct_access_client.request(
+             "create_job",
+             payload=job,
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to create job.", response)
+         response_content = response.json()
+         return response_content["id"]
+     else:
+         if query_timeout_mins is not None:
+             sql_string = textwrap.dedent(f"""
+                 CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}', null, {query_timeout_mins})
+             """)
+         else:
+             sql_string = textwrap.dedent(f"""
+                 CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}')
+             """)
+         res = self.resources._exec_sql(sql_string, None)
+         return res[0]["ID"]
+
  def list_jobs(self, state=None, limit=None):
-     state_clause = f"AND STATE = '{state.upper()}'" if state else ""
-     limit_clause = f"LIMIT {limit}" if limit else ""
-     results = self.resources._exec(
-         f"SELECT ID,STATE,CREATED_BY,CREATED_ON,FINISHED_AT,DURATION,PAYLOAD,ENGINE_NAME FROM {APP_NAME}.experimental.jobs where type='{ENGINE_TYPE_SOLVER}' {state_clause} ORDER BY created_on DESC {limit_clause};"
-     )
-     return job_list_to_dicts(results)
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "list_jobs"
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to list jobs.", response)
+         response_content = response.json()
+         if not response_content:
+             return []
+         jobs = [
+             {
+                 "id": job["id"],
+                 "state": job["state"],
+                 "created_by": job["created_by"],
+                 "created_on": job["created_on"],
+                 "finished_at": job.get("finished_at", None),
+                 "duration": job["duration"] if "duration" in job else 0,
+                 "solver": json.loads(job["payload"]).get("solver", ""),
+                 "engine": job.get("engine_name", job["worker_name"]),
+             }
+             for job in response_content.get("jobs", [])
+             if state is None or job.get("state") == state
+         ]
+         return sorted(jobs, key=lambda x: x["created_on"], reverse=True)
+     else:
+         state_clause = f"AND STATE = '{state.upper()}'" if state else ""
+         limit_clause = f"LIMIT {limit}" if limit else ""
+         results = self.resources._exec_sql(
+             f"SELECT ID,STATE,CREATED_BY,CREATED_ON,FINISHED_AT,DURATION,PAYLOAD,ENGINE_NAME FROM {APP_NAME}.experimental.jobs where type='{ENGINE_TYPE_SOLVER}' {state_clause} ORDER BY created_on DESC {limit_clause};",
+             None
+         )
+         return job_list_to_dicts(results)

  def get_job(self, id: str):
-     results = self.resources._exec(
-         f"CALL {APP_NAME}.experimental.get_job('{ENGINE_TYPE_SOLVER}', '{id}');"
-     )
-     return job_list_to_dicts(results)[0]
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "get_job", path_params = {"job_type": ENGINE_TYPE_SOLVER, "job_id": id}
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to get job.", response)
+         response_content = response.json()
+         return response_content["job"]
+     else:
+         results = self.resources._exec_sql(
+             f"CALL {APP_NAME}.experimental.get_job('{ENGINE_TYPE_SOLVER}', '{id}');",
+             None
+         )
+         return job_list_to_dicts(results)[0]
+
+ def get_job_events(self, job_id: str, continuation_token: str = ""):
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "get_job_events",
+             path_params = {"job_type": ENGINE_TYPE_SOLVER, "job_id": job_id, "stream_name": "progress"},
+             query_params={"continuation_token": continuation_token},
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to get job events.", response)
+         response_content = response.json()
+         if not response_content:
+             return {
+                 "events": [],
+                 "continuation_token": None
+             }
+         return response_content
+     else:
+         results = self.resources._exec_sql(
+             f"SELECT {APP_NAME}.experimental.get_job_events('{ENGINE_TYPE_SOLVER}', '{job_id}', '{continuation_token}');",
+             None
+         )
+         if not results:
+             return {"events": [], "continuation_token": None}
+         row = results[0][0]
+         if not isinstance(row, str):
+             row = ""
+         return json.loads(row)

  def cancel_job(self, id: str):
-     self.resources._exec(
-         f"CALL {APP_NAME}.experimental.cancel_job('{ENGINE_TYPE_SOLVER}', '{id}');"
-     )
+     if self.direct_access_client is not None:
+         response = self.direct_access_client.request(
+             "cancel_job", path_params = {"job_type": ENGINE_TYPE_SOLVER, "job_id": id}
+         )
+         if response.status_code != 200:
+             raise ResponseStatusException("Failed to cancel job.", response)
+         return None
+     else:
+         self.resources._exec_sql(
+             f"CALL {APP_NAME}.experimental.cancel_job('{ENGINE_TYPE_SOLVER}', '{id}');",
+             None
+         )
+         return None


  def solver_list_to_dicts(results):
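
Both get_solver and list_solvers derive a `solvers` list the same way: parse the engine's settings JSON and keep the keys whose config dict is enabled. Pulled out as a standalone helper for illustration (the helper name is hypothetical; the logic is taken verbatim from the comprehensions above):

    import json

    def enabled_solvers(settings_json: str) -> list[str]:
        # An empty settings blob means no solver backends are configured.
        if settings_json == "":
            return []
        return [
            k
            for k, v in json.loads(settings_json).items()
            if isinstance(v, dict) and v.get("enabled", False)
        ]

    print(enabled_solvers('{"highs": {"enabled": true}, "scip": {"enabled": false}}'))  # ['highs']
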
@@ -865,7 +1070,6 @@ def solver_list_to_dicts(results):
          for row in results
      ]

-
  def job_list_to_dicts(results):
      if not results:
          return []
@@ -346,7 +346,7 @@ def find_select_keys(item: Any, keys:OrderedSet[Key]|None = None, enable_primiti

  if isinstance(item, (list, tuple)):
      for it in item:
-         find_select_keys(it, keys)
+         find_select_keys(it, keys, enable_primitive_key=enable_primitive_key)

  elif isinstance(item, (Relationship, RelationshipReading)) and item._parent:
      find_select_keys(item._parent, keys)
@@ -390,7 +390,7 @@ def find_select_keys(item: Any, keys:OrderedSet[Key]|None = None, enable_primiti
      find_select_keys(item._arg, keys)

  elif isinstance(item, Alias):
-     find_select_keys(item._thing, keys)
+     find_select_keys(item._thing, keys, enable_primitive_key=enable_primitive_key)

  elif isinstance(item, Aggregate):
      keys.update( Key(k, True) for k in item._group )
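
Both find_select_keys fixes above are the same bug class: a recursive helper that accepts a keyword flag but drops it on recursion, silently reverting the flag to its default below the top level. Distilled with hypothetical names:

    def collect(item, out=None, *, keep_primitives=False):
        out = [] if out is None else out
        if isinstance(item, (list, tuple)):
            for it in item:
                # The bug: collect(it, out) would reset keep_primitives to False here.
                collect(it, out, keep_primitives=keep_primitives)
        elif keep_primitives or not isinstance(item, (int, float, str)):
            out.append(item)
        return out

    print(collect([1, [2, "x"]], keep_primitives=True))  # [1, 2, 'x']
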
@@ -2338,6 +2338,7 @@ class Fragment():
      self._define.extend(parent._define)
      self._order_by.extend(parent._order_by)
      self._limit = parent._limit
+     self._meta.update(parent._meta)

  def _add_items(self, items:PySequence[Any], to_attr:list[Any]):
      # TODO: ensure that you are _either_ a select, require, or then
@@ -2416,9 +2417,26 @@ class Fragment():
      return f

  def meta(self, **kwargs: Any) -> Fragment:
+     """Add metadata to the query.
+
+     Metadata can be used for debugging and observability purposes.
+
+     Args:
+         **kwargs: Metadata key-value pairs
+
+     Returns:
+         Fragment: Returns self for method chaining
+
+     Example:
+         select(Person.name).meta(workload_name="test", priority=1, enabled=True)
+     """
+     if not kwargs:
+         raise ValueError("meta() requires at least one argument")
+
      self._meta.update(kwargs)
      return self

+
  def annotate(self, *annos:Expression|Relationship|ir.Annotation) -> Fragment:
      self._annotations.extend(annos)
      return self
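
Because meta() returns self, metadata composes with the rest of a fragment chain, as in the docstring's select(Person.name).meta(workload_name="test", ...) example. A minimal stand-in showing just that chaining contract (the real Fragment carries far more state):

    from typing import Any

    class Fragment:
        def __init__(self) -> None:
            self._meta: dict[str, Any] = {}

        def meta(self, **kwargs: Any) -> "Fragment":
            if not kwargs:
                raise ValueError("meta() requires at least one argument")
            self._meta.update(kwargs)
            return self  # enables chains like select(...).meta(...).inspect()

    f = Fragment().meta(workload_name="test", priority=1)
    print(f._meta)  # {'workload_name': 'test', 'priority': 1}
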
@@ -2497,7 +2515,7 @@ class Fragment():
  # @TODO for now we set tag to None but we need to work out how to properly propagate user-provided tag here
  with debugging.span("query", tag=None, dsl=str(self), **with_source(self), meta=self._meta) as query_span:
      query_task = qb_model._compiler.fragment(self)
-     results = qb_model._to_executor().execute(ir_model, query_task)
+     results = qb_model._to_executor().execute(ir_model, query_task, meta=self._meta)
      query_span["results"] = results
  # For local debugging mostly
  dry_run = qb_model._dry_run or bool(qb_model._config.get("compiler.dry_run", False))
@@ -2524,7 +2542,7 @@ class Fragment():
  # @TODO for now we set tag to None but we need to work out how to properly propagate user-provided tag here
  with debugging.span("query", tag=None, dsl=str(clone), **with_source(clone), meta=clone._meta) as query_span:
      query_task = qb_model._compiler.fragment(clone)
-     results = qb_model._to_executor().execute(ir_model, query_task, format="snowpark")
+     results = qb_model._to_executor().execute(ir_model, query_task, format="snowpark", meta=clone._meta)
      query_span["alt_format_results"] = results
  return results

@@ -2541,7 +2559,8 @@ class Fragment():
  clone._source = runtime_env.get_source_pos()
  with debugging.span("query", dsl=str(clone), **with_source(clone), meta=clone._meta):
      query_task = qb_model._compiler.fragment(clone)
-     qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update)
+     qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update, meta=clone._meta)
+

  #--------------------------------------------------
  # Select / Where
@@ -21,8 +21,8 @@ from relationalai.clients.config import Config
  from relationalai.clients.snowflake import APP_NAME
  from relationalai.clients.types import TransactionAsyncResponse
  from relationalai.clients.util import IdentityParser
- from relationalai.tools.constants import USE_DIRECT_ACCESS
-
+ from relationalai.tools.constants import USE_DIRECT_ACCESS, QUERY_ATTRIBUTES_HEADER
+ from relationalai.tools.query_utils import prepare_metadata_for_headers

  class LQPExecutor(e.Executor):
      """Executes LQP using the RAI client."""
@@ -60,7 +60,7 @@ class LQPExecutor(e.Executor):
      if not self.dry_run:
          self.engine = self._resources.get_default_engine_name()
          if not self.keep_model:
-             atexit.register(self._resources.delete_graph, self.database, True)
+             atexit.register(self._resources.delete_graph, self.database, True, "lqp")
      return self._resources

  # Checks the graph index and updates it if necessary
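
The atexit change works because atexit.register forwards extra positional arguments to the callback at interpreter shutdown, so the appended "lqp" becomes delete_graph's next positional parameter. This is standard-library behavior, shown here with a stand-in callback:

    import atexit

    def delete_graph(database: str, force: bool, language: str) -> None:
        print(f"deleting {database} (force={force}, language={language})")

    # Prints "deleting my_db (force=True, language=lqp)" at interpreter exit.
    atexit.register(delete_graph, "my_db", True, "lqp")
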
@@ -88,7 +88,15 @@ class LQPExecutor(e.Executor):
      assert self.engine is not None

      with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
-         resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id)
+         resources.poll_use_index(
+             app_name=app_name,
+             sources=sources,
+             model=model,
+             engine_name=self.engine,
+             engine_size=engine_size,
+             language="lqp",
+             program_span_id=program_span_id,
+         )

  def report_errors(self, problems: list[dict[str, Any]], abort_on_error=True):
      from relationalai import errors
@@ -267,7 +275,7 @@ class LQPExecutor(e.Executor):
      if ivm_flag:
          config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
      return construct_configure(config_dict, None)
-
+
  def _compile_intrinsics(self) -> lqp_ir.Epoch:
      """Construct an epoch that defines a number of built-in definitions used by the
      emitter."""
@@ -344,6 +352,10 @@ class LQPExecutor(e.Executor):
      df, errs = result_helpers.format_results(raw_results, cols)
      self.report_errors(errs)

+     # Rename columns if wide outputs is enabled
+     if self.wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
+         df.columns = cols[: len(df.columns)]
+
      # Process exports
      if export_to and not self.dry_run:
          assert cols, "No columns found in the output"
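
The wide-outputs guard above renames columns only when the non-extra column count matches what actually came back, so a shape mismatch leaves pandas' default names untouched rather than raising. Illustratively (the values are hypothetical):

    import pandas as pd

    cols = ["name", "age", "extra_key"]   # desired names; the last is an extra column
    extra_cols = ["extra_key"]
    df = pd.DataFrame([["ada", 36]])      # results arrived without the extra column

    wide_outputs = True
    if wide_outputs and len(cols) - len(extra_cols) == len(df.columns):
        df.columns = cols[: len(df.columns)]

    print(list(df.columns))  # ['name', 'age']
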
@@ -362,7 +374,7 @@ class LQPExecutor(e.Executor):

  def execute(self, model: ir.Model, task: ir.Task, format: Literal["pandas", "snowpark"] = "pandas",
              result_cols: Optional[list[str]] = None, export_to: Optional[str] = None,
-             update: bool = False) -> DataFrame:
+             update: bool = False, meta: dict[str, Any] | None = None) -> DataFrame:
      self.prepare_data()
      previous_model = self._last_model

@@ -374,6 +386,9 @@ class LQPExecutor(e.Executor):
  if format != "pandas":
      raise ValueError(f"Unsupported format: {format}")

+ # Format meta as headers
+ json_meta = prepare_metadata_for_headers(meta)
+ headers = {QUERY_ATTRIBUTES_HEADER: json_meta} if json_meta else {}
  raw_results = self.resources.exec_lqp(
      self.database,
      self.engine,
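
prepare_metadata_for_headers lives in relationalai.tools.query_utils and is not shown in this diff; from the call site it must turn the meta dict into a single header value, or something falsy to skip the header. An illustrative stand-in under that assumption only:

    import json
    from typing import Any, Optional

    # Hypothetical stand-in: the real implementation may filter, truncate, or
    # encode differently, and the real header name is the QUERY_ATTRIBUTES_HEADER
    # constant rather than the literal used here.
    def prepare_metadata_for_headers(meta: Optional[dict[str, Any]]) -> Optional[str]:
        if not meta:
            return None
        return json.dumps(meta, separators=(",", ":"), default=str)

    json_meta = prepare_metadata_for_headers({"workload_name": "test"})
    headers = {"X-Query-Attributes": json_meta} if json_meta else {}
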
@@ -383,6 +398,7 @@ class LQPExecutor(e.Executor):
          # transactions are serialized.
          readonly=False,
          nowait_durable=True,
+         headers=headers,
      )
      assert isinstance(raw_results, TransactionAsyncResponse)