relationalai 0.12.1__py3-none-any.whl → 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,11 +9,13 @@ from ..metamodel import Builtins
 from ..tools.cli_controls import Spinner
 from ..tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
 from .. import debugging
+from .. errors import ResponseStatusException
 import uuid
 import relationalai
 import json
 from ..clients.util import poll_with_specified_overhead
 from ..clients.snowflake import Resources as SnowflakeResources
+from ..clients.snowflake import DirectAccessClient, DirectAccessResources
 from ..util.timeout import calc_remaining_timeout_minutes
 
 rel_sv = rel._tagged(Builtins.SingleValued)
@@ -23,7 +25,8 @@ APP_NAME = relationalai.clients.snowflake.APP_NAME
 ENGINE_TYPE_SOLVER = "SOLVER"
 # TODO (dba) The ERP still uses `worker` instead of `engine`. Change
 # this once we fix this in the ERP.
-ENGINE_ERRORS = ["worker is suspended", "create/resume", "worker not found", "no workers found", "worker was deleted"]
+WORKER_ERRORS = ["worker is suspended", "create/resume", "worker not found", "no workers found", "worker was deleted"]
+ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
 ENGINE_NOT_READY_MSGS = ["worker is in pending", "worker is provisioning", "worker is not ready to accept jobs"]
 
 # --------------------------------------------------
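
The retry paths later in this diff decide whether a failure is engine-related by substring-matching the lowercased error message against these keyword lists. A minimal sketch of that check, with the list contents taken from the constants above (`is_engine_error` is a hypothetical helper, not part of the package):

```python
WORKER_ERRORS = ["worker is suspended", "create/resume", "worker not found", "no workers found", "worker was deleted"]
ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
ENGINE_NOT_READY_MSGS = ["worker is in pending", "worker is provisioning", "worker is not ready to accept jobs"]

def is_engine_error(exc: Exception) -> bool:
    # Case-insensitive substring match, mirroring the `any(kw in msg ...)`
    # checks used in Solver._exec_job below.
    msg = str(exc).lower()
    return any(kw in msg for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS)

assert is_engine_error(Exception("Worker not found: solver-1"))
```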
@@ -213,13 +216,6 @@ class SolverModel:
         config_file_path = getattr(rai_config, 'file_path', None)
         start_time = time.monotonic()
         remaining_timeout_minutes = query_timeout_mins
-        # 1. Materialize the model and store it.
-        # TODO(coey) Currently we must run a dummy query to install the pyrel rules in a separate txn
-        # to the solve_output updates. Ideally pyrel would offer an option to flush the rules separately.
-        self.graph.exec_raw("", query_timeout_mins=remaining_timeout_minutes)
-        remaining_timeout_minutes = calc_remaining_timeout_minutes(
-            start_time, query_timeout_mins, config_file_path=config_file_path,
-        )
         response = self.graph.exec_raw(
             textwrap.dedent(f"""
                 @inline
@@ -265,18 +261,7 @@ class SolverModel:
         remaining_timeout_minutes = calc_remaining_timeout_minutes(
             start_time, query_timeout_mins, config_file_path=config_file_path
         )
-        try:
-            job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
-        except Exception as e:
-            err_message = str(e).lower()
-            if any(kw in err_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS):
-                solver._auto_create_solver_async()
-                remaining_timeout_minutes = calc_remaining_timeout_minutes(
-                    start_time, query_timeout_mins, config_file_path=config_file_path
-                )
-                job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
-            else:
-                raise e
+        job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
 
         # 3. Extract result.
         remaining_timeout_minutes = calc_remaining_timeout_minutes(
@@ -553,7 +538,11 @@ class Solver:
         # may configure each individual solver.
         self.engine_settings = settings
 
-        return self._auto_create_solver_async()
+        # Optimistically set the engine object to a `READY` engine to
+        # avoid checking the engine status on each execution.
+        self.engine:Optional[dict[str,Any]] = {"name": engine_name, "state": "READY"}
+
+        return None
 
    # --------------------------------------------------
    # Helper
@@ -572,6 +561,7 @@ class Solver:
         assert len(engines) == 1 or len(engines) == 0
         if len(engines) != 0:
             engine = engines[0]
+
         if engine:
             # TODO (dba) Logic engines support altering the
             # auto_suspend_mins setting. Currently, we don't have
@@ -653,31 +643,32 @@ class Solver:
 
         self.engine = engine
 
-    def _exec_job_async(self, payload, query_timeout_mins: Optional[int]=None):
-        payload_json = json.dumps(payload)
-        engine_name = self.engine["name"]
-        if query_timeout_mins is None and (timeout_value := self.rai_config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
-            query_timeout_mins = int(timeout_value)
-        if query_timeout_mins is not None:
-            sql_string = textwrap.dedent(f"""
-                CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}', null, {query_timeout_mins})
-            """)
-        else:
-            sql_string = textwrap.dedent(f"""
-                CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}')
-            """)
-        res = self.provider.resources._exec(sql_string)
-        return res[0]["ID"]
-
     def _exec_job(self, payload, log_to_console=True, query_timeout_mins: Optional[int]=None):
-        # Make sure the engine is ready.
-        if self.engine["state"] != "READY":
-            poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
+        if self.engine is None:
+            raise Exception("Engine not initialized.")
 
         with debugging.span("job") as job_span:
-            job_id = self._exec_job_async(payload, query_timeout_mins=query_timeout_mins)
+            # Retry logic. If creating a job fails with an engine
+            # related error we will create/resume/... the engine and
+            # retry.
+            try:
+                job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
+            except Exception as e:
+                err_message = str(e).lower()
+                if isinstance(e, ResponseStatusException):
+                    err_message = e.response.json().get("message", "")
+                if any(kw in err_message.lower() for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS):
+                    self._auto_create_solver_async()
+                    # Wait until the engine is ready.
+                    poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
+                    job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
+                else:
+                    raise e
+
             job_span["job_id"] = job_id
             debugging.event("job_created", job_span, job_id=job_id, engine_name=self.engine["name"], job_type=ENGINE_TYPE_SOLVER)
+            if not isinstance(job_id, str):
+                job_id = ""
             polling_state = PollingState(job_id, "", False, log_to_console)
 
         try:
@@ -693,7 +684,14 @@ class Solver:
         return job_id
 
     def _is_solver_ready(self):
+        if self.engine is None:
+            raise Exception("Engine not initialized.")
+
         result = self.provider.get_solver(self.engine["name"])
+
+        if result is None:
+            raise Exception("No engine available.")
+
         self.engine = result
         state = result["state"]
         if state != "READY" and state != "PENDING":
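
With this change the create/retry handling lives in `Solver._exec_job`: the first `create_job_async` call runs against the optimistically-`READY` engine, and only an engine-related failure triggers `_auto_create_solver_async` plus a readiness poll before a single retry. A standalone sketch of that shape, reusing the keyword lists from the constants above (`recreate_engine` and `wait_until_ready` are hypothetical stand-ins for `_auto_create_solver_async` and the `_is_solver_ready` poll):

```python
ENGINE_RETRYABLE = ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS  # constants from the diff above

def exec_job_with_retry(provider, engine: dict, payload: dict):
    try:
        return provider.create_job_async(engine["name"], payload)
    except Exception as e:
        msg = str(e).lower()
        if not any(kw in msg for kw in ENGINE_RETRYABLE):
            raise  # non-engine failures propagate unchanged
        recreate_engine(engine["name"])   # hypothetical: create/resume the engine
        wait_until_ready(engine["name"])  # hypothetical: poll until state == "READY"
        return provider.create_job_async(engine["name"], payload)
```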
@@ -711,20 +709,11 @@ class Solver:
 
         return status == "COMPLETED" or status == "FAILED" or status == "CANCELED"
 
-    def _get_job_events(self, job_id: str, continuation_token: str = ""):
-        results = self.provider.resources._exec(
-            f"SELECT {APP_NAME}.experimental.get_job_events('{ENGINE_TYPE_SOLVER}', '{job_id}', '{continuation_token}');"
-        )
-        if not results:
-            return {"events": [], "continuation_token": None}
-        row = results[0][0]
-        return json.loads(row)
-
     def _print_solver_logs(self, state: PollingState):
         if state.is_done:
             return
 
-        resp = self._get_job_events(state.job_id, state.continuation_token)
+        resp = self.provider.get_job_events(state.job_id, state.continuation_token)
 
         # Print solver logs to stdout.
         for event in resp["events"]:
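
`_print_solver_logs` now goes through `Provider.get_job_events`, which returns a dict with an `events` list and a `continuation_token` for the next poll. A minimal sketch of consuming that API in a loop (the `provider` object is assumed; `is_job_done` is a hypothetical completion check standing in for the `COMPLETED`/`FAILED`/`CANCELED` status test above):

```python
def stream_job_logs(provider, job_id: str) -> None:
    token = ""
    while True:
        resp = provider.get_job_events(job_id, token)
        for event in resp["events"]:
            print(event)  # the diff prints solver log events to stdout
        token = resp.get("continuation_token") or token
        if provider.is_job_done(job_id):  # hypothetical completion check
            break
```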
@@ -754,7 +743,12 @@ class Provider:
         resources = relationalai.Resources()
         if not isinstance(resources, relationalai.clients.snowflake.Resources):
             raise Exception("Solvers are only supported on SPCS.")
+
         self.resources = resources
+        self.direct_access_client: Optional[DirectAccessClient] = None
+
+        if isinstance(self.resources, DirectAccessResources):
+            self.direct_access_client = self.resources.direct_access_client
 
     def create_solver(
         self,
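
From here on, nearly every `Provider` method in this diff follows the same dispatch shape: if a `DirectAccessClient` was picked up from the resources, issue a `request(...)` against the named endpoint and raise `ResponseStatusException` on any non-200 status; otherwise fall back to the SQL `CALL {APP_NAME}.experimental....` path via `_exec_sql`. A hypothetical helper condensing the repeated direct-access half of that pattern (not part of the package, shown only to make the shape explicit):

```python
def _direct_request(client, endpoint: str, error_msg: str, **kwargs):
    # Call the direct access endpoint and fail loudly on a non-200 response,
    # mirroring the `if response.status_code != 200: raise ...` checks below.
    response = client.request(endpoint, **kwargs)
    if response.status_code != 200:
        raise ResponseStatusException(error_msg, response)
    return response
```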
@@ -770,75 +764,285 @@ class Provider:
         engine_config: dict[str, Any] = {"settings": settings}
         if auto_suspend_mins is not None:
             engine_config["auto_suspend_mins"] = auto_suspend_mins
-        self.resources._exec(
-            f"CALL {APP_NAME}.experimental.create_engine('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});"
+        self.resources._exec_sql(
+            f"CALL {APP_NAME}.experimental.create_engine('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});", None
         )
 
     def create_solver_async(
         self,
         name: str,
         size: str | None = None,
-        settings: dict = {},
+        settings: dict | None = None,
         auto_suspend_mins: int | None = None,
     ):
         if size is None:
             size = "HIGHMEM_X64_S"
-        if settings is None:
-            settings = ""
-        engine_config: dict[str, Any] = {"settings": settings}
-        if auto_suspend_mins is not None:
-            engine_config["auto_suspend_mins"] = auto_suspend_mins
-        self.resources._exec(
-            f"CALL {APP_NAME}.experimental.create_engine_async('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});"
-        )
+
+        if self.direct_access_client is not None:
+            payload:dict[str, Any] = {
+                "name": name,
+                "settings": settings,
+            }
+            if auto_suspend_mins is not None:
+                payload["auto_suspend_mins"] = auto_suspend_mins
+            if size is not None:
+                payload["size"] = size
+            response = self.direct_access_client.request(
+                "create_engine",
+                payload=payload,
+                path_params={"engine_type": "solver"},
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(
+                    f"Failed to create engine {name} with size {size}.", response
+                )
+        else:
+            engine_config: dict[str, Any] = {}
+            if settings is not None:
+                engine_config["settings"] = settings
+            if auto_suspend_mins is not None:
+                engine_config["auto_suspend_mins"] = auto_suspend_mins
+            self.resources._exec_sql(
+                f"CALL {APP_NAME}.experimental.create_engine_async('{ENGINE_TYPE_SOLVER}', '{name}', '{size}', {engine_config});",
+                None
+            )
 
     def delete_solver(self, name: str):
-        self.resources._exec(
-            f"CALL {APP_NAME}.experimental.delete_engine('{ENGINE_TYPE_SOLVER}', '{name}');"
-        )
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "delete_engine", path_params = {"engine_type": ENGINE_TYPE_SOLVER, "engine_name": name}
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to delete engine.", response)
+            return None
+        else:
+            self.resources._exec_sql(
+                f"CALL {APP_NAME}.experimental.delete_engine('{ENGINE_TYPE_SOLVER}', '{name}');",
+                None
+            )
 
     def resume_solver_async(self, name: str):
-        self.resources._exec(
-            f"CALL {APP_NAME}.experimental.resume_engine_async('{ENGINE_TYPE_SOLVER}', '{name}');"
-        )
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "resume_engine", path_params = {"engine_type": ENGINE_TYPE_SOLVER, "engine_name": name}
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to resume engine.", response)
+            return None
+        else:
+            self.resources._exec_sql(
+                f"CALL {APP_NAME}.experimental.resume_engine_async('{ENGINE_TYPE_SOLVER}', '{name}');",
+                None
+            )
+            return None
 
     def get_solver(self, name: str):
-        results = self.resources._exec(
-            f"CALL {APP_NAME}.experimental.get_engine('{ENGINE_TYPE_SOLVER}', '{name}');"
-        )
-        return solver_list_to_dicts(results)[0]
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "get_engine", path_params = {"engine_type": ENGINE_TYPE_SOLVER, "engine_name": name}
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to get engine.", response)
+            solver = response.json()
+            if not solver :
+                return None
+            solver_state = {
+                "name": solver["name"],
+                "id": solver["id"],
+                "size": solver["size"],
+                "state": solver["status"], # callers are expecting 'state'
+                "created_by": solver["created_by"],
+                "created_on": solver["created_on"],
+                "updated_on": solver["updated_on"],
+                "version": solver["version"],
+                "auto_suspend": solver["auto_suspend_mins"],
+                "suspends_at": solver["suspends_at"],
+                "solvers": []
+                if solver["settings"] == ""
+                else [
+                    k
+                    for (k,v) in json.loads(solver["settings"]).items()
+                    if isinstance(v, dict) and v.get("enabled", False)
+                ],
+            }
+            return solver_state
+        else:
+            results = self.resources._exec_sql(
+                f"CALL {APP_NAME}.experimental.get_engine('{ENGINE_TYPE_SOLVER}', '{name}');",
+                None
+            )
+            return solver_list_to_dicts(results)[0]
 
     def list_solvers(self, state: str | None = None):
-        where_clause = f"WHERE TYPE='{ENGINE_TYPE_SOLVER}'"
-        where_clause = (
-            f"{where_clause} AND STATUS = '{state.upper()}'" if state else where_clause
-        )
-        statement = f"SELECT NAME,ID,SIZE,STATUS,CREATED_BY,CREATED_ON,UPDATED_ON,AUTO_SUSPEND_MINS,SETTINGS FROM {APP_NAME}.experimental.engines {where_clause};"
-        results = self.resources._exec(statement)
-        return solver_list_to_dicts(results)
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "list_engines"
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to list engines.", response)
+            response_content = response.json()
+            if not response_content:
+                return []
+            engines = [
+                {
+                    "name": engine["name"],
+                    "id": engine["id"],
+                    "size": engine["size"],
+                    "state": engine["status"], # callers are expecting 'state'
+                    "created_by": engine["created_by"],
+                    "created_on": engine["created_on"],
+                    "updated_on": engine["updated_on"],
+                    "auto_suspend_mins": engine["auto_suspend_mins"],
+                    "solvers": []
+                    if engine["settings"] == ""
+                    else [
+                        k
+                        for (k, v) in json.loads(engine["settings"]).items()
+                        if isinstance(v, dict) and v.get("enabled", False)
+                    ],
+                }
+                for engine in response_content.get("engines", [])
+                if (state is None or engine.get("status") == state) and (engine.get("type") == ENGINE_TYPE_SOLVER)
+            ]
+            return sorted(engines, key=lambda x: x["name"])
+        else:
+            where_clause = f"WHERE TYPE='{ENGINE_TYPE_SOLVER}'"
+            where_clause = (
+                f"{where_clause} AND STATUS = '{state.upper()}'" if state else where_clause
+            )
+            statement = f"SELECT NAME,ID,SIZE,STATUS,CREATED_BY,CREATED_ON,UPDATED_ON,AUTO_SUSPEND_MINS,SETTINGS FROM {APP_NAME}.experimental.engines {where_clause};"
+            results = self.resources._exec_sql(statement, None)
+            return solver_list_to_dicts(results)
 
     # --------------------------------------------------
     # Job API
     # --------------------------------------------------
 
+    def create_job_async(self, engine_name, payload, query_timeout_mins: Optional[int]=None):
+        payload_json = json.dumps(payload)
+
+        if query_timeout_mins is None and (timeout_value := self.resources.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+            query_timeout_mins = int(timeout_value)
+
+        if self.direct_access_client is not None:
+            job = {
+                "job_type":ENGINE_TYPE_SOLVER,
+                "worker_name": engine_name,
+                "timeout_mins": query_timeout_mins,
+                "payload": payload_json,
+            }
+            response = self.direct_access_client.request(
+                "create_job",
+                payload=job,
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to create job.", response)
+            response_content = response.json()
+            return response_content["id"]
+        else:
+            if query_timeout_mins is not None:
+                sql_string = textwrap.dedent(f"""
+                    CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}', null, {query_timeout_mins})
+                """)
+            else:
+                sql_string = textwrap.dedent(f"""
+                    CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}')
+                """)
+            res = self.resources._exec_sql(sql_string, None)
+            return res[0]["ID"]
+
     def list_jobs(self, state=None, limit=None):
-        state_clause = f"AND STATE = '{state.upper()}'" if state else ""
-        limit_clause = f"LIMIT {limit}" if limit else ""
-        results = self.resources._exec(
-            f"SELECT ID,STATE,CREATED_BY,CREATED_ON,FINISHED_AT,DURATION,PAYLOAD,ENGINE_NAME FROM {APP_NAME}.experimental.jobs where type='{ENGINE_TYPE_SOLVER}' {state_clause} ORDER BY created_on DESC {limit_clause};"
-        )
-        return job_list_to_dicts(results)
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "list_jobs"
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to list jobs.", response)
+            response_content = response.json()
+            if not response_content:
+                return []
+            jobs = [
+                {
+                    "id": job["id"],
+                    "state": job["state"],
+                    "created_by": job["created_by"],
+                    "created_on": job["created_on"],
+                    "finished_at": job.get("finished_at", None),
+                    "duration": job["duration"] if "duration" in job else 0,
+                    "solver": json.loads(job["payload"]).get("solver", ""),
+                    "engine": job.get("engine_name", job["worker_name"]),
+                }
+                for job in response_content.get("jobs", [])
+                if state is None or job.get("state") == state
+            ]
+            return sorted(jobs, key=lambda x: x["created_on"], reverse=True)
+        else:
+            state_clause = f"AND STATE = '{state.upper()}'" if state else ""
+            limit_clause = f"LIMIT {limit}" if limit else ""
+            results = self.resources._exec_sql(
+                f"SELECT ID,STATE,CREATED_BY,CREATED_ON,FINISHED_AT,DURATION,PAYLOAD,ENGINE_NAME FROM {APP_NAME}.experimental.jobs where type='{ENGINE_TYPE_SOLVER}' {state_clause} ORDER BY created_on DESC {limit_clause};",
                None
+            )
+            return job_list_to_dicts(results)
 
     def get_job(self, id: str):
-        results = self.resources._exec(
-            f"CALL {APP_NAME}.experimental.get_job('{ENGINE_TYPE_SOLVER}', '{id}');"
-        )
-        return job_list_to_dicts(results)[0]
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "get_job", path_params = {"job_type": ENGINE_TYPE_SOLVER, "job_id": id}
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to get job.", response)
+            response_content = response.json()
+            return response_content["job"]
+        else:
+            results = self.resources._exec_sql(
+                f"CALL {APP_NAME}.experimental.get_job('{ENGINE_TYPE_SOLVER}', '{id}');",
+                None
+            )
+            return job_list_to_dicts(results)[0]
+
+    def get_job_events(self, job_id: str, continuation_token: str = ""):
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "get_job_events",
+                path_params = {"job_type": ENGINE_TYPE_SOLVER, "job_id": job_id, "stream_name": "progress"},
+                query_params={"continuation_token": continuation_token},
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to get job events.", response)
+            response_content = response.json()
+            if not response_content:
+                return {
+                    "events": [],
+                    "continuation_token": None
+                }
+            return response_content
+        else:
+            results = self.resources._exec_sql(
+                f"SELECT {APP_NAME}.experimental.get_job_events('{ENGINE_TYPE_SOLVER}', '{job_id}', '{continuation_token}');",
+                None
+            )
+            if not results:
+                return {"events": [], "continuation_token": None}
+            row = results[0][0]
+            if not isinstance(row, str):
+                row = ""
+            return json.loads(row)
 
     def cancel_job(self, id: str):
-        self.resources._exec(
-            f"CALL {APP_NAME}.experimental.cancel_job('{ENGINE_TYPE_SOLVER}', '{id}');"
-        )
+        if self.direct_access_client is not None:
+            response = self.direct_access_client.request(
+                "cancel_job", path_params = {"job_type": ENGINE_TYPE_SOLVER, "job_id": id}
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to cancel job.", response)
+            return None
+        else:
+            self.resources._exec_sql(
+                f"CALL {APP_NAME}.experimental.cancel_job('{ENGINE_TYPE_SOLVER}', '{id}');",
+                None
+            )
+            return None
 
 
 def solver_list_to_dicts(results):
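
Both `get_solver` and `list_solvers` derive the `solvers` field by parsing the engine's `settings` JSON and keeping the names of the entries that are enabled. A small self-contained illustration of that transformation (the sample settings payload is invented):

```python
import json

# Hypothetical settings blob as it might come back from the engine API.
settings = json.dumps({
    "highs": {"enabled": True},
    "scip": {"enabled": False},
    "log_level": "info",  # non-dict values are ignored by the filter
})

# Same comprehension shape as the diff: keep keys whose value is a dict
# with "enabled" set; an empty settings string yields an empty list.
solvers = [] if settings == "" else [
    k
    for (k, v) in json.loads(settings).items()
    if isinstance(v, dict) and v.get("enabled", False)
]
assert solvers == ["highs"]
```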
@@ -865,7 +1069,6 @@ def solver_list_to_dicts(results):
         for row in results
     ]
 
-
 def job_list_to_dicts(results):
     if not results:
         return []
@@ -346,7 +346,7 @@ def find_select_keys(item: Any, keys:OrderedSet[Key]|None = None, enable_primiti
 
     if isinstance(item, (list, tuple)):
        for it in item:
-            find_select_keys(it, keys)
+            find_select_keys(it, keys, enable_primitive_key=enable_primitive_key)
 
     elif isinstance(item, (Relationship, RelationshipReading)) and item._parent:
         find_select_keys(item._parent, keys)
@@ -390,7 +390,7 @@ def find_select_keys(item: Any, keys:OrderedSet[Key]|None = None, enable_primiti
         find_select_keys(item._arg, keys)
 
     elif isinstance(item, Alias):
-        find_select_keys(item._thing, keys)
+        find_select_keys(item._thing, keys, enable_primitive_key=enable_primitive_key)
 
     elif isinstance(item, Aggregate):
         keys.update( Key(k, True) for k in item._group )
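
The two `find_select_keys` hunks fix the same class of bug: a recursive traversal that takes a keyword flag must forward it on every self-call, or deeper nodes silently fall back to the default. A minimal self-contained illustration of the failure mode (the tree and flag are invented for the example):

```python
def collect(node, out: list, include_leaves: bool = False) -> None:
    if isinstance(node, list):
        for child in node:
            # Bug variant: `collect(child, out)` would drop include_leaves,
            # so leaves below the first level would never be collected.
            collect(child, out, include_leaves=include_leaves)
    elif include_leaves:
        out.append(node)

out: list = []
collect([1, [2, [3]]], out, include_leaves=True)
assert out == [1, 2, 3]
```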
@@ -2418,21 +2418,21 @@ class Fragment():
 
     def meta(self, **kwargs: Any) -> Fragment:
         """Add metadata to the query.
-
+
         Metadata can be used for debugging and observability purposes.
-
+
         Args:
             **kwargs: Metadata key-value pairs
-
+
         Returns:
             Fragment: Returns self for method chaining
-
+
         Example:
             select(Person.name).meta(workload_name="test", priority=1, enabled=True)
         """
         if not kwargs:
             raise ValueError("meta() requires at least one argument")
-
+
         self._meta.update(kwargs)
         return self
 
@@ -2560,7 +2560,7 @@ class Fragment():
         with debugging.span("query", dsl=str(clone), **with_source(clone), meta=clone._meta):
             query_task = qb_model._compiler.fragment(clone)
             qb_model._to_executor().execute(ir_model, query_task, result_cols=result_cols, export_to=table._fqn, update=update, meta=clone._meta)
-
+
 
    #--------------------------------------------------
    # Select / Where
@@ -60,7 +60,7 @@ class LQPExecutor(e.Executor):
         if not self.dry_run:
             self.engine = self._resources.get_default_engine_name()
             if not self.keep_model:
-                atexit.register(self._resources.delete_graph, self.database, True)
+                atexit.register(self._resources.delete_graph, self.database, True, "lqp")
         return self._resources
 
     # Checks the graph index and updates it if necessary
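
The `atexit.register` change threads the new `"lqp"` language argument through to `delete_graph` at interpreter exit: `atexit.register(func, *args)` stores the extra positional arguments and applies them when the hook runs. A tiny self-contained demonstration (`delete_graph` here is a stand-in, not the package's method):

```python
import atexit

def delete_graph(database: str, force: bool, language: str) -> None:
    # Stand-in for Resources.delete_graph; just shows the argument plumbing.
    print(f"deleting {database} (force={force}, language={language})")

# Registered callbacks receive the stored arguments when the process exits.
atexit.register(delete_graph, "my_model", True, "lqp")
```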
@@ -88,7 +88,15 @@ class LQPExecutor(e.Executor):
         assert self.engine is not None
 
         with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
-            resources.poll_use_index(app_name, sources, model, self.engine, engine_size, program_span_id)
+            resources.poll_use_index(
+                app_name=app_name,
+                sources=sources,
+                model=model,
+                engine_name=self.engine,
+                engine_size=engine_size,
+                language="lqp",
+                program_span_id=program_span_id,
+            )
 
     def report_errors(self, problems: list[dict[str, Any]], abort_on_error=True):
         from relationalai import errors
@@ -292,6 +300,27 @@ class LQPExecutor(e.Executor):
             meta=None,
         )
 
+    def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
+        fragment_ids = []
+
+        for write in query_epoch.writes:
+            if isinstance(write.write_type, lqp_ir.Define):
+                fragment_ids.append(write.write_type.fragment.id)
+
+        # Construct new Epoch with Undefine operations for all collected fragment IDs
+        undefine_writes = [
+            lqp_ir.Write(
+                write_type=lqp_ir.Undefine(fragment_id=frag_id, meta=None),
+                meta=None
+            )
+            for frag_id in fragment_ids
+        ]
+
+        return lqp_ir.Epoch(
+            writes=undefine_writes,
+            meta=None,
+        )
+
     def compile_lqp(self, model: ir.Model, task: ir.Task):
         configure = self._construct_configure()
 
@@ -326,7 +355,11 @@ class LQPExecutor(e.Executor):
         if model_txn is not None:
             epochs.append(model_txn.epochs[0])
 
-        epochs.append(query_txn.epochs[0])
+        query_txn_epoch = query_txn.epochs[0]
+
+        epochs.append(query_txn_epoch)
+
+        epochs.append(self._compile_undefine_query(query_txn_epoch))
 
         txn = lqp_ir.Transaction(epochs=epochs, configure=configure, meta=None)
 
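Taken together, these two hunks make each query transaction self-cleaning: the `Define` writes in the query epoch are mirrored by a trailing epoch of `Undefine` writes over the same fragment IDs. A minimal sketch of that pairing, with plain dataclasses standing in for the `lqp_ir` types (assumed shapes, not the real IR):

```python
from dataclasses import dataclass

# Simplified stand-ins for lqp_ir.Define / lqp_ir.Undefine / lqp_ir.Epoch.
@dataclass
class Define: fragment_id: str
@dataclass
class Undefine: fragment_id: str
@dataclass
class Epoch: writes: list

def compile_undefine(query_epoch: Epoch) -> Epoch:
    # Mirror every Define in the query epoch with an Undefine, so the
    # fragments installed for the query are removed in a later epoch.
    return Epoch(writes=[Undefine(w.fragment_id) for w in query_epoch.writes
                         if isinstance(w, Define)])

query = Epoch(writes=[Define("frag1"), Define("frag2")])
epochs = [query, compile_undefine(query)]
assert [w.fragment_id for w in epochs[1].writes] == ["frag1", "frag2"]
```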
@@ -580,8 +580,9 @@ def get_relation_id(ctx: TranslationCtx, relation: ir.Relation, projection: list
     if relation.id in ctx.def_names.id_to_name:
         unique_name = ctx.def_names.id_to_name[relation.id]
     else:
-        prefix = helpers.relation_name_prefix(relation)
-        unique_name = ctx.def_names.get_name_by_id(relation.id, prefix + relation.name)
+        name = helpers.relation_name_prefix(relation) + relation.name
+        name = helpers.sanitize(name)
+        unique_name = ctx.def_names.get_name_by_id(relation.id, name)
 
     return utils.gen_rel_id(ctx, unique_name, types)
 
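The fix routes the generated relation name through `helpers.sanitize` before registering it under the relation's id. This diff doesn't show the package's actual sanitization rules; as a purely illustrative stand-in, a sanitizer of this shape replaces characters a backend identifier can't accept:

```python
import re

def sanitize(name: str) -> str:
    # Hypothetical stand-in: keep word characters, replace the rest with '_'.
    # The real helpers.sanitize may apply different rules.
    return re.sub(r"\W", "_", name)

assert sanitize("my-model::q1") == "my_model__q1"
```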
@@ -265,10 +265,12 @@ class ExtractCommon(Pass):
         for child in common_body:
             body_output_vars.update(ctx.info.task_outputs(child))
 
-        # Compute the union of input vars across all composites, intersected with output
+        # Compute the union of input vars across all non-extracted tasks (basically
+        # composites and binders left behind), intersected with output
         # vars of the common body
         exposed_vars = OrderedSet.from_iterable(ctx.info.task_inputs(sample)) & body_output_vars
-        for composite in composites:
+        non_extracted_tasks = (binders - common_body) | composites
+        for composite in non_extracted_tasks:
             if composite is sample:
                 continue
             # compute common input vars
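
The fix widens the loop from `composites` alone to every task left behind after extraction: the binders not absorbed into the common body, unioned with the composites. With ordered sets that is a set difference followed by a union; a small illustration using plain Python sets (the task names are invented):

```python
binders = {"b1", "b2", "b3"}
common_body = {"b2"}   # tasks extracted into the shared body
composites = {"c1"}

# Tasks left behind after extraction: binders outside the common body,
# plus all composites.
non_extracted_tasks = (binders - common_body) | composites
assert non_extracted_tasks == {"b1", "b3", "c1"}
```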