relationalai 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. relationalai/clients/client.py +52 -18
  2. relationalai/clients/exec_txn_poller.py +62 -31
  3. relationalai/clients/local.py +23 -8
  4. relationalai/clients/resources/azure/azure.py +36 -11
  5. relationalai/clients/resources/snowflake/__init__.py +3 -3
  6. relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  7. relationalai/clients/resources/snowflake/direct_access_resources.py +63 -22
  8. relationalai/clients/resources/snowflake/engine_service.py +381 -0
  9. relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  10. relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  11. relationalai/clients/resources/snowflake/snowflake.py +163 -172
  12. relationalai/clients/types.py +5 -0
  13. relationalai/errors.py +1 -1
  14. relationalai/semantics/lqp/algorithms.py +173 -0
  15. relationalai/semantics/lqp/builtins.py +199 -2
  16. relationalai/semantics/lqp/executor.py +65 -36
  17. relationalai/semantics/lqp/ir.py +28 -2
  18. relationalai/semantics/lqp/model2lqp.py +215 -45
  19. relationalai/semantics/lqp/passes.py +13 -658
  20. relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  21. relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  22. relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  23. relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  24. relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  25. relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  26. relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  27. relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  28. relationalai/semantics/lqp/utils.py +11 -1
  29. relationalai/semantics/lqp/validators.py +14 -1
  30. relationalai/semantics/metamodel/builtins.py +2 -1
  31. relationalai/semantics/metamodel/compiler.py +2 -1
  32. relationalai/semantics/metamodel/dependency.py +12 -3
  33. relationalai/semantics/metamodel/executor.py +11 -1
  34. relationalai/semantics/metamodel/factory.py +2 -2
  35. relationalai/semantics/metamodel/helpers.py +7 -0
  36. relationalai/semantics/metamodel/ir.py +3 -2
  37. relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  38. relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  39. relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  40. relationalai/semantics/metamodel/typer/checker.py +6 -4
  41. relationalai/semantics/metamodel/typer/typer.py +4 -3
  42. relationalai/semantics/metamodel/visitor.py +4 -3
  43. relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  44. relationalai/semantics/reasoners/optimization/solvers_pb.py +3 -4
  45. relationalai/semantics/rel/compiler.py +2 -1
  46. relationalai/semantics/rel/executor.py +3 -2
  47. relationalai/semantics/tests/lqp/__init__.py +0 -0
  48. relationalai/semantics/tests/lqp/algorithms.py +345 -0
  49. relationalai/tools/cli.py +339 -186
  50. relationalai/tools/cli_controls.py +216 -67
  51. relationalai/tools/cli_helpers.py +410 -6
  52. relationalai/util/format.py +5 -2
  53. {relationalai-0.13.1.dist-info → relationalai-0.13.3.dist-info}/METADATA +1 -1
  54. {relationalai-0.13.1.dist-info → relationalai-0.13.3.dist-info}/RECORD +58 -48
  55. relationalai_test_util/fixtures.py +2 -2
  56. {relationalai-0.13.1.dist-info → relationalai-0.13.3.dist-info}/WHEEL +0 -0
  57. {relationalai-0.13.1.dist-info → relationalai-0.13.3.dist-info}/entry_points.txt +0 -0
  58. {relationalai-0.13.1.dist-info → relationalai-0.13.3.dist-info}/licenses/LICENSE +0 -0
@@ -164,11 +164,18 @@ class ResourcesBase(ABC):
164
164
  pass
165
165
 
166
166
  @abstractmethod
167
- def list_engines(self, state: str|None = None) -> List[Any]:
167
+ def list_engines(
168
+ self,
169
+ state: str | None = None,
170
+ name: str | None = None,
171
+ type: str | None = None,
172
+ size: str | None = None,
173
+ created_by: str | None = None,
174
+ ) -> List[Any]:
168
175
  pass
169
176
 
170
177
  @abstractmethod
171
- def get_engine(self, name: str) -> EngineState | None:
178
+ def get_engine(self, name: str, type: str) -> EngineState | None:
172
179
  pass
173
180
 
174
181
  @abstractmethod
@@ -180,23 +187,31 @@ class ResourcesBase(ABC):
180
187
  pass
181
188
 
182
189
  @abstractmethod
183
- def create_engine(self, name: str, size: str|None, auto_suspend_mins: int|None) -> dict | None:
190
+ def create_engine(
191
+ self,
192
+ name: str,
193
+ type: str | None = None,
194
+ size: str | None = None,
195
+ auto_suspend_mins: int | None = None,
196
+ headers: Dict | None = None,
197
+ settings: dict | None = None,
198
+ ) -> dict | None:
184
199
  pass
185
200
 
186
201
  @abstractmethod
187
- def delete_engine(self, name:str, force:bool=False) -> dict | None:
202
+ def delete_engine(self, name:str, type: str) -> dict | None:
188
203
  pass
189
204
 
190
205
  @abstractmethod
191
- def suspend_engine(self, name: str):
206
+ def suspend_engine(self, name: str, type: str | None = None):
192
207
  pass
193
208
 
194
209
  @abstractmethod
195
- def resume_engine(self, name: str, headers: Dict | None = None) -> dict:
210
+ def resume_engine(self, name: str, type: str | None = None, headers: Dict | None = None) -> dict:
196
211
  pass
197
212
 
198
213
  @abstractmethod
199
- def resume_engine_async(self, name: str) -> dict:
214
+ def resume_engine_async(self, name: str, type: str | None = None, headers: Dict | None = None) -> dict:
200
215
  pass
201
216
 
202
217
  @abstractmethod
@@ -210,7 +225,7 @@ class ResourcesBase(ABC):
210
225
  return engine
211
226
 
212
227
  @abstractmethod
213
- def auto_create_engine_async(self, name: str | None = None) -> str:
228
+ def auto_create_engine_async(self, name: str | None = None, type: str | None = None) -> str:
214
229
  pass
215
230
 
216
231
  _active_engine: EngineState|None = None
@@ -428,14 +443,34 @@ class ProviderBase(ABC):
428
443
 
429
444
  resources: ResourcesBase
430
445
 
431
- def list_engines(self, state: str | None = None):
432
- return self.resources.list_engines(state)
446
+ def list_engines(
447
+ self,
448
+ state: str | None = None,
449
+ name: str | None = None,
450
+ type: str | None = None,
451
+ size: str | None = None,
452
+ created_by: str | None = None,
453
+ ):
454
+ return self.resources.list_engines(state=state, name=name, type=type, size=size, created_by=created_by)
433
455
 
434
- def create_engine(self, name:str, size:str|None=None, auto_suspend_mins:int|None=None):
435
- return self.resources.create_engine(name, size, auto_suspend_mins)
456
+ def create_engine(
457
+ self,
458
+ name: str,
459
+ type: str | None = None,
460
+ size: str | None = None,
461
+ auto_suspend_mins: int | None = None,
462
+ settings: dict | None = None,
463
+ ):
464
+ return self.resources.create_engine(
465
+ name,
466
+ type=type,
467
+ size=size,
468
+ auto_suspend_mins=auto_suspend_mins,
469
+ settings=settings,
470
+ )
436
471
 
437
- def delete_engine(self, name:str):
438
- return self.resources.delete_engine(name)
472
+ def delete_engine(self, name:str, type: str = "LOGIC"):
473
+ return self.resources.delete_engine(name, type)
439
474
 
440
475
  def get_transaction(self, transaction_id:str):
441
476
  return self.resources.get_transaction(transaction_id)
@@ -579,7 +614,6 @@ class Client():
579
614
  self._timed_query(
580
615
  "update_registry",
581
616
  dependencies.generate_update_registry(),
582
- readonly=False,
583
617
  abort_on_error=False,
584
618
  )
585
619
 
@@ -588,7 +622,6 @@ class Client():
588
622
  self._timed_query(
589
623
  "update_packages",
590
624
  dependencies.generate_update_packages(),
591
- readonly=False,
592
625
  abort_on_error=False,
593
626
  )
594
627
  else:
@@ -611,10 +644,11 @@ class Client():
611
644
  finally:
612
645
  self._database = database_name
613
646
 
614
- def _timed_query(self, span_name:str, code: str, readonly=True, abort_on_error=True):
647
+ def _timed_query(self, span_name:str, code: str, abort_on_error=True):
615
648
  with debugging.span(span_name, model=self._database) as end_span:
616
649
  start = time.perf_counter()
617
- res, raw = self._query(code, None, end_span, readonly=readonly, abort_on_error=abort_on_error)
650
+ # NOTE hardcoding to readonly=False, read-only Rel transactions are deprecated.
651
+ res, raw = self._query(code, None, end_span, readonly=False, abort_on_error=abort_on_error)
618
652
  debugging.time(span_name, time.perf_counter() - start, code=code)
619
653
  return res, raw
620
654
 
@@ -27,16 +27,43 @@ class ExecTxnPoller:
27
27
 
28
28
  def __init__(
29
29
  self,
30
+ print_txn_progress: bool,
30
31
  resource: "Resources",
31
- txn_id: str,
32
+ txn_id: Optional[str] = None,
32
33
  headers: Optional[Dict] = None,
33
34
  txn_start_time: Optional[float] = None,
34
35
  ):
36
+ self.print_txn_progress = print_txn_progress
35
37
  self.res = resource
36
38
  self.txn_id = txn_id
37
39
  self.headers = headers or {}
38
40
  self.txn_start_time = txn_start_time or time.time()
39
41
 
42
+ def __enter__(self) -> ExecTxnPoller:
43
+ if not self.print_txn_progress:
44
+ return self
45
+ self.progress = create_progress(
46
+ description=lambda: self.description_with_timing(),
47
+ success_message="", # We'll handle this ourselves
48
+ leading_newline=False,
49
+ trailing_newline=False,
50
+ show_duration_summary=False,
51
+ )
52
+ self.progress.__enter__()
53
+ return self
54
+
55
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
56
+ if not self.print_txn_progress or self.txn_id is None:
57
+ return
58
+ # Update to success message with duration
59
+ total_duration = time.time() - self.txn_start_time
60
+ txn_id = self.txn_id
61
+ self.progress.update_main_status(
62
+ query_complete_message(txn_id, total_duration)
63
+ )
64
+ self.progress.__exit__(exc_type, exc_value, traceback)
65
+ return
66
+
40
67
  def poll(self) -> bool:
41
68
  """
42
69
  Poll for transaction completion with interactive progress display.
@@ -44,48 +71,52 @@ class ExecTxnPoller:
44
71
  Returns:
45
72
  True if transaction completed successfully, False otherwise
46
73
  """
74
+ if not self.txn_id:
75
+ raise ValueError("Transaction ID must be provided for polling.")
76
+ else:
77
+ txn_id = self.txn_id
78
+
79
+ if self.print_txn_progress:
80
+ # Update the main status to include the new txn_id
81
+ self.progress.update_main_status_fn(
82
+ lambda: self.description_with_timing(txn_id),
83
+ )
47
84
 
48
85
  # Don't show duration summary - we handle our own completion message
49
- with create_progress(
50
- description="Evaluating Query...",
51
- success_message="", # We'll handle this ourselves
52
- leading_newline=False,
53
- trailing_newline=False,
54
- show_duration_summary=False,
55
- ) as progress:
56
- def check_status() -> bool:
57
- """Check if transaction is complete."""
58
- elapsed = time.time() - self.txn_start_time
59
- # Update the main status with elapsed time
60
- progress.update_main_status(
61
- query_progress_message(self.txn_id, elapsed)
62
- )
63
- return self.res._check_exec_async_status(self.txn_id, headers=self.headers)
64
-
65
- with debugging.span("wait", txn_id=self.txn_id):
66
- poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
67
-
68
- # Calculate final duration
69
- total_duration = time.time() - self.txn_start_time
70
-
71
- # Update to success message with duration
72
- progress.update_main_status(
73
- query_complete_message(self.txn_id, total_duration)
74
- )
86
+ def check_status() -> bool:
87
+ """Check if transaction is complete."""
88
+ finished = self.res._check_exec_async_status(txn_id, headers=self.headers)
89
+ return finished
90
+
91
+ with debugging.span("wait", txn_id=self.txn_id):
92
+ poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
93
+
75
94
 
76
95
  return True
77
96
 
97
+ def description_with_timing(self, txn_id: str | None = None) -> str:
98
+ elapsed = time.time() - self.txn_start_time
99
+ if txn_id is None:
100
+ return query_progress_header(elapsed)
101
+ else:
102
+ return query_progress_message(txn_id, elapsed)
103
+
104
+ def query_progress_header(duration: float) -> str:
105
+ # Don't print sub-second decimals, because it updates too fast and is distracting.
106
+ duration_str = format_duration(duration, seconds_decimals=False)
107
+ return f"Evaluating Query... {duration_str:>15}\n"
108
+
78
109
  def query_progress_message(id: str, duration: float) -> str:
79
110
  return (
111
+ query_progress_header(duration) +
80
112
  # Print with whitespace to align with the end of the transaction ID
81
- f"Evaluating Query... {format_duration(duration):>18}\n" +
82
- f"{GRAY_COLOR}Query: {id}{ENDC}"
113
+ f"{GRAY_COLOR}ID: {id}{ENDC}"
83
114
  )
84
115
 
85
116
  def query_complete_message(id: str, duration: float, status_header: bool = False) -> str:
86
117
  return (
87
118
  (f"{GREEN_COLOR}✅ " if status_header else "") +
88
119
  # Print with whitespace to align with the end of the transaction ID
89
- f"Query Complete: {format_duration(duration):>24}\n" +
90
- f"{GRAY_COLOR}Query: {id}{ENDC}"
120
+ f"Query Complete: {format_duration(duration):>21}\n" +
121
+ f"{GRAY_COLOR}ID: {id}{ENDC}"
91
122
  )
@@ -172,10 +172,17 @@ class LocalResources(ResourcesBase):
172
172
  def get_engine_sizes(self, cloud_provider: str | None = None):
173
173
  raise NotImplementedError("get_engine_sizes not supported in local mode")
174
174
 
175
- def list_engines(self, state: str | None = None):
175
+ def list_engines(
176
+ self,
177
+ state: str | None = None,
178
+ name: str | None = None,
179
+ type: str | None = None,
180
+ size: str | None = None,
181
+ created_by: str | None = None,
182
+ ):
176
183
  raise NotImplementedError("list_engines not supported in local mode")
177
184
 
178
- def get_engine(self, name: str):
185
+ def get_engine(self, name: str, type: str):
179
186
  raise NotImplementedError("get_engine not supported in local mode")
180
187
 
181
188
  def get_cloud_provider(self) -> str:
@@ -184,25 +191,33 @@ class LocalResources(ResourcesBase):
184
191
  def is_valid_engine_state(self, name: str) -> bool:
185
192
  raise NotImplementedError("is_valid_engine_state not supported in local mode")
186
193
 
187
- def create_engine(self, name: str, size:str|None=None, auto_suspend_mins:int|None=None):
194
+ def create_engine(
195
+ self,
196
+ name: str,
197
+ type: str | None = None,
198
+ size: str | None = None,
199
+ auto_suspend_mins: int | None = None,
200
+ headers: dict | None = None,
201
+ settings: dict | None = None,
202
+ ):
188
203
  raise NotImplementedError("create_engine not supported in local mode")
189
204
 
190
- def delete_engine(self, name: str, force: bool = False):
205
+ def delete_engine(self, name: str, type: str):
191
206
  raise NotImplementedError("delete_engine not supported in local mode")
192
207
 
193
- def suspend_engine(self, name: str):
208
+ def suspend_engine(self, name: str, type: str | None = None):
194
209
  raise NotImplementedError("suspend_engine not supported in local mode")
195
210
 
196
- def resume_engine(self, name: str, headers: Dict | None = None):
211
+ def resume_engine(self, name: str, type: str | None = None, headers: Dict | None = None):
197
212
  raise NotImplementedError("resume_engine not supported in local mode")
198
213
 
199
- def resume_engine_async(self, name: str):
214
+ def resume_engine_async(self, name: str, type: str | None = None, headers: Dict | None = None):
200
215
  raise NotImplementedError("resume_engine_async not supported in local mode")
201
216
 
202
217
  def alter_engine_pool(self, size: str | None = None, mins: int | None = None, maxs: int | None = None):
203
218
  raise NotImplementedError("alter_engine_pool not supported in local mode")
204
219
 
205
- def auto_create_engine_async(self, name: str | None = None) -> str:
220
+ def auto_create_engine_async(self, name: str | None = None, type: str | None = None) -> str:
206
221
  raise NotImplementedError("auto_create_engine_async not supported in local mode")
207
222
 
208
223
  #--------------------------------------------------
@@ -78,34 +78,59 @@ class Resources(ResourcesBase):
78
78
  def get_cloud_provider(self) -> str:
79
79
  return "azure"
80
80
 
81
- def list_engines(self, state:str|None = None):
82
- return api.list_engines(self._api_ctx(), state)
83
-
84
- def get_engine(self, name:str) -> EngineState:
85
- return cast(EngineState, api.get_engine(self._api_ctx(), name))
81
+ def list_engines(
82
+ self,
83
+ state: str | None = None,
84
+ name: str | None = None,
85
+ type: str | None = None,
86
+ size: str | None = None,
87
+ created_by: str | None = None,
88
+ ):
89
+ # Azure only supports state filtering at the API level; other filters are ignored.
90
+ engines = api.list_engines(self._api_ctx(), state)
91
+ # Ensure EngineState shape includes 'type' for callers/tests.
92
+ for eng in engines:
93
+ eng.setdefault("type", "LOGIC")
94
+ return engines
95
+
96
+ def get_engine(self, name: str, type: str) -> EngineState:
97
+ # type is ignored for Azure as it doesn't support multiple engine types
98
+ engine = cast(EngineState, api.get_engine(self._api_ctx(), name))
99
+ engine.setdefault("type", "LOGIC")
100
+ return engine
86
101
 
87
102
  def is_valid_engine_state(self, name:str):
88
103
  return name in VALID_ENGINE_STATES
89
104
 
90
- def create_engine(self, name:str, size:str|None=None, auto_suspend_mins: int|None=None):
105
+ def create_engine(
106
+ self,
107
+ name: str,
108
+ type: str | None = None,
109
+ size: str | None = None,
110
+ auto_suspend_mins: int | None = None,
111
+ headers: dict | None = None,
112
+ settings: dict | None = None,
113
+ ):
114
+ # Azure only supports one engine type, so type parameter is ignored
91
115
  if size is None:
92
116
  size = "M"
93
117
  with debugging.span("create_engine", name=name, size=size):
94
118
  return api.create_engine_wait(self._api_ctx(), name, size)
95
119
 
96
- def delete_engine(self, name:str, force:bool=False):
120
+ def delete_engine(self, name:str, type: str):
97
121
  return api.delete_engine(self._api_ctx(), name)
98
122
 
99
- def suspend_engine(self, name:str):
123
+ def suspend_engine(self, name:str, type: str | None = None): # type is ignored for Azure
100
124
  return api.suspend_engine(self._api_ctx(), name)
101
125
 
102
- def resume_engine(self, name:str, headers={}):
126
+ def resume_engine(self, name:str, type: str | None = None, headers={}):
127
+ # type is ignored for Azure as it doesn't support multiple engine types
103
128
  return api.resume_engine_wait(self._api_ctx(), name)
104
129
 
105
- def resume_engine_async(self, name:str):
130
+ def resume_engine_async(self, name:str, type: str | None = None, headers: Dict | None = None):
106
131
  return api.resume_engine(self._api_ctx(), name)
107
132
 
108
- def auto_create_engine_async(self, name: str | None = None) -> str:
133
+ def auto_create_engine_async(self, name: str | None = None, type: str | None = None) -> str:
109
134
  raise Exception("Azure doesn't support auto_create_engine_async")
110
135
 
111
136
  def alter_engine_pool(self, size: str | None = None, mins: int | None = None, maxs: int | None = None):
@@ -2,8 +2,8 @@
2
2
  Snowflake resources module.
3
3
  """
4
4
  # Import order matters - Resources must be imported first since other classes depend on it
5
- from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE, PrimaryKey, PRINT_TXN_PROGRESS_FLAG
6
-
5
+ from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, PrimaryKey, PRINT_TXN_PROGRESS_FLAG
6
+ from .engine_service import EngineType, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE
7
7
  # These imports depend on Resources, so they come after
8
8
  from .cli_resources import CLIResources
9
9
  from .use_index_resources import UseIndexResources
@@ -12,7 +12,7 @@ from .resources_factory import create_resources_instance
12
12
 
13
13
  __all__ = [
14
14
  'Resources', 'DirectAccessResources', 'Provider', 'Graph', 'SnowflakeClient',
15
- 'APP_NAME', 'PYREL_ROOT_DB', 'CLIResources', 'UseIndexResources', 'ExecContext',
15
+ 'APP_NAME', 'PYREL_ROOT_DB', 'CLIResources', 'UseIndexResources', 'ExecContext', 'EngineType',
16
16
  'INTERNAL_ENGINE_SIZES', 'ENGINE_SIZES_AWS', 'ENGINE_SIZES_AZURE', 'PrimaryKey',
17
17
  'PRINT_TXN_PROGRESS_FLAG', 'create_resources_instance',
18
18
  ]
@@ -11,6 +11,7 @@ from ....tools.constants import RAI_APP_NAME
11
11
  # Import Resources from snowflake - this creates a dependency but no circular import
12
12
  # since snowflake.py doesn't import from this file
13
13
  from .snowflake import Resources, ExecContext
14
+ from .error_handlers import AppMissingErrorHandler, AppFunctionMissingErrorHandler, ServiceNotStartedErrorHandler
14
15
 
15
16
 
16
17
  class CLIResources(Resources):
@@ -20,7 +21,17 @@ class CLIResources(Resources):
20
21
  """
21
22
 
22
23
  def _handle_standard_exec_errors(self, e: Exception, ctx: ExecContext) -> Any | None:
23
- """For CLI resources, re-raise exceptions without transformation."""
24
+ """
25
+ For CLI resources, keep exceptions raw except for few specific cases.
26
+ """
27
+ message = str(e).lower()
28
+ for handler in (
29
+ ServiceNotStartedErrorHandler(),
30
+ AppMissingErrorHandler(),
31
+ AppFunctionMissingErrorHandler(),
32
+ ):
33
+ if handler.matches(e, message, ctx, self):
34
+ handler.handle(e, ctx, self)
24
35
  raise e
25
36
 
26
37
  def list_warehouses(self):
@@ -13,7 +13,7 @@ from ...config import Config, ConfigStore, ENDPOINT_FILE
13
13
  from ...direct_access_client import DirectAccessClient
14
14
  from ...types import EngineState
15
15
  from ...util import get_pyrel_version, safe_json_loads, ms_to_timestamp
16
- from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
16
+ from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException, RAIException
17
17
  from snowflake.snowpark import Session
18
18
 
19
19
  # Import UseIndexResources to enable use_index functionality with direct access
@@ -163,11 +163,12 @@ class DirectAccessResources(UseIndexResources):
163
163
  headers=headers,
164
164
  )
165
165
  response = _send_request()
166
- except requests.exceptions.ConnectionError as e:
166
+ except (requests.exceptions.ConnectionError, RAIException) as e:
167
167
  messages = collect_error_messages(e)
168
- if any("nameresolutionerror" in msg for msg in messages):
169
- # when we can not resolve the service endpoint, we assume it is outdated
170
- # hence, we try to retrieve it again and query again.
168
+ if any("nameresolutionerror" in msg for msg in messages) or \
169
+ any("could not find the service associated with endpoint" in msg for msg in messages):
170
+ # when we can not resolve the service endpoint or the service is not found,
171
+ # we assume the endpoint is outdated, so we retrieve it again and retry.
171
172
  self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
172
173
  enforce_update=True,
173
174
  )
@@ -513,7 +514,14 @@ class DirectAccessResources(UseIndexResources):
513
514
  # Engines
514
515
  #--------------------------------------------------
515
516
 
516
- def list_engines(self, state: str | None = None):
517
+ def list_engines(
518
+ self,
519
+ state: str | None = None,
520
+ name: str | None = None,
521
+ type: str | None = None,
522
+ size: str | None = None,
523
+ created_by: str | None = None,
524
+ ):
517
525
  response = self.request("list_engines")
518
526
  if response.status_code != 200:
519
527
  raise ResponseStatusException(
@@ -526,19 +534,31 @@ class DirectAccessResources(UseIndexResources):
526
534
  {
527
535
  "name": engine["name"],
528
536
  "id": engine["id"],
537
+ "type": engine.get("type", "LOGIC"),
529
538
  "size": engine["size"],
530
539
  "state": engine["status"], # callers are expecting 'state'
531
540
  "created_by": engine["created_by"],
532
541
  "created_on": engine["created_on"],
533
542
  "updated_on": engine["updated_on"],
543
+ # Optional fields (present in newer APIs / service-functions path)
544
+ "auto_suspend_mins": engine.get("auto_suspend_mins"),
545
+ "suspends_at": engine.get("suspends_at"),
546
+ "settings": engine.get("settings"),
534
547
  }
535
548
  for engine in response_content.get("engines", [])
536
- if state is None or engine.get("status") == state
549
+ if (state is None or engine.get("status") == state)
550
+ and (name is None or name.upper() in engine.get("name", "").upper())
551
+ and (type is None or engine.get("type", "LOGIC").upper() == type.upper())
552
+ and (size is None or engine.get("size") == size)
553
+ and (created_by is None or created_by.upper() in engine.get("created_by", "").upper())
537
554
  ]
538
555
  return sorted(engines, key=lambda x: x["name"])
539
556
 
540
- def get_engine(self, name: str):
541
- response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
557
+ def get_engine(self, name: str, type: str):
558
+ if type is None:
559
+ raise Exception("Engine type is required. Valid types are: LOGIC, SOLVER, ML")
560
+ engine_type_lower = type.lower()
561
+ response = self.request("get_engine", path_params={"engine_name": name, "engine_type": engine_type_lower}, skip_auto_create=True)
542
562
  if response.status_code == 404: # engine not found return 404
543
563
  return None
544
564
  elif response.status_code != 200:
@@ -552,6 +572,7 @@ class DirectAccessResources(UseIndexResources):
552
572
  "name": engine["name"],
553
573
  "id": engine["id"],
554
574
  "size": engine["size"],
575
+ "type": engine.get("type", type),
555
576
  "state": engine["status"], # callers are expecting 'state'
556
577
  "created_by": engine["created_by"],
557
578
  "created_on": engine["created_on"],
@@ -559,31 +580,47 @@ class DirectAccessResources(UseIndexResources):
559
580
  "version": engine["version"],
560
581
  "auto_suspend": engine["auto_suspend_mins"],
561
582
  "suspends_at": engine["suspends_at"],
583
+ "settings": engine.get("settings"),
562
584
  }
563
585
  return engine_state
564
586
 
565
587
  def _create_engine(
566
588
  self,
567
589
  name: str,
590
+ type: str = "LOGIC",
568
591
  size: str | None = None,
569
592
  auto_suspend_mins: int | None = None,
570
593
  is_async: bool = False,
571
- headers: Dict[str, str] | None = None
594
+ headers: Dict[str, str] | None = None,
595
+ settings: Dict[str, Any] | None = None,
572
596
  ):
573
597
  # only async engine creation supported via direct access
574
598
  if not is_async:
575
- return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
599
+ return super()._create_engine(
600
+ name,
601
+ type=type,
602
+ size=size,
603
+ auto_suspend_mins=auto_suspend_mins,
604
+ is_async=is_async,
605
+ headers=headers,
606
+ settings=settings,
607
+ )
608
+ engine_type_lower = type.lower()
576
609
  payload:Dict[str, Any] = {
577
610
  "name": name,
578
611
  }
612
+ # Allow passing arbitrary engine settings (API-dependent).
613
+ if settings:
614
+ payload["settings"] = settings
579
615
  if auto_suspend_mins is not None:
580
616
  payload["auto_suspend_mins"] = auto_suspend_mins
581
- if size is not None:
582
- payload["size"] = size
617
+ if size is None:
618
+ size = self.config.get_default_engine_size()
619
+ payload["size"] = size
583
620
  response = self.request(
584
621
  "create_engine",
585
622
  payload=payload,
586
- path_params={"engine_type": "logic"},
623
+ path_params={"engine_type": engine_type_lower},
587
624
  headers=headers,
588
625
  skip_auto_create=True,
589
626
  )
@@ -592,11 +629,11 @@ class DirectAccessResources(UseIndexResources):
592
629
  f"Failed to create engine {name} with size {size}.", response
593
630
  )
594
631
 
595
- def delete_engine(self, name:str, force:bool = False, headers={}):
632
+ def delete_engine(self, name: str, type: str):
596
633
  response = self.request(
597
634
  "delete_engine",
598
- path_params={"engine_name": name, "engine_type": "logic"},
599
- headers=headers,
635
+ path_params={"engine_name": name, "engine_type": type.lower()},
636
+ headers={},
600
637
  skip_auto_create=True,
601
638
  )
602
639
  if response.status_code != 200:
@@ -604,10 +641,12 @@ class DirectAccessResources(UseIndexResources):
604
641
  f"Failed to delete engine {name}.", response
605
642
  )
606
643
 
607
- def suspend_engine(self, name:str):
644
+ def suspend_engine(self, name: str, type: str | None = None):
645
+ if type is None:
646
+ type = "LOGIC"
608
647
  response = self.request(
609
648
  "suspend_engine",
610
- path_params={"engine_name": name, "engine_type": "logic"},
649
+ path_params={"engine_name": name, "engine_type": type.lower()},
611
650
  skip_auto_create=True,
612
651
  )
613
652
  if response.status_code != 200:
@@ -615,11 +654,13 @@ class DirectAccessResources(UseIndexResources):
615
654
  f"Failed to suspend engine {name}.", response
616
655
  )
617
656
 
618
- def resume_engine_async(self, name:str, headers={}):
657
+ def resume_engine_async(self, name: str, type: str | None = None, headers: Dict | None = None):
658
+ if type is None:
659
+ type = "LOGIC"
619
660
  response = self.request(
620
661
  "resume_engine",
621
- path_params={"engine_name": name, "engine_type": "logic"},
622
- headers=headers,
662
+ path_params={"engine_name": name, "engine_type": type.lower()},
663
+ headers=headers or {},
623
664
  skip_auto_create=True,
624
665
  )
625
666
  if response.status_code != 200: