relationalai 1.0.0a3__py3-none-any.whl → 1.0.0a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. relationalai/config/config.py +47 -21
  2. relationalai/config/connections/__init__.py +5 -2
  3. relationalai/config/connections/duckdb.py +2 -2
  4. relationalai/config/connections/local.py +31 -0
  5. relationalai/config/connections/snowflake.py +0 -1
  6. relationalai/config/external/raiconfig_converter.py +235 -0
  7. relationalai/config/external/raiconfig_models.py +202 -0
  8. relationalai/config/external/utils.py +31 -0
  9. relationalai/config/shims.py +1 -0
  10. relationalai/semantics/__init__.py +10 -8
  11. relationalai/semantics/backends/sql/sql_compiler.py +1 -4
  12. relationalai/semantics/experimental/__init__.py +0 -0
  13. relationalai/semantics/experimental/builder.py +295 -0
  14. relationalai/semantics/experimental/builtins.py +154 -0
  15. relationalai/semantics/frontend/base.py +67 -42
  16. relationalai/semantics/frontend/core.py +34 -6
  17. relationalai/semantics/frontend/front_compiler.py +209 -37
  18. relationalai/semantics/frontend/pprint.py +6 -2
  19. relationalai/semantics/metamodel/__init__.py +7 -0
  20. relationalai/semantics/metamodel/metamodel.py +2 -0
  21. relationalai/semantics/metamodel/metamodel_analyzer.py +58 -16
  22. relationalai/semantics/metamodel/pprint.py +6 -1
  23. relationalai/semantics/metamodel/rewriter.py +11 -7
  24. relationalai/semantics/metamodel/typer.py +116 -41
  25. relationalai/semantics/reasoners/__init__.py +11 -0
  26. relationalai/semantics/reasoners/graph/__init__.py +35 -0
  27. relationalai/semantics/reasoners/graph/core.py +9028 -0
  28. relationalai/semantics/std/__init__.py +30 -10
  29. relationalai/semantics/std/aggregates.py +641 -12
  30. relationalai/semantics/std/common.py +146 -13
  31. relationalai/semantics/std/constraints.py +71 -1
  32. relationalai/semantics/std/datetime.py +904 -21
  33. relationalai/semantics/std/decimals.py +143 -2
  34. relationalai/semantics/std/floats.py +57 -4
  35. relationalai/semantics/std/integers.py +98 -4
  36. relationalai/semantics/std/math.py +857 -35
  37. relationalai/semantics/std/numbers.py +216 -20
  38. relationalai/semantics/std/re.py +213 -5
  39. relationalai/semantics/std/strings.py +437 -44
  40. relationalai/shims/executor.py +60 -52
  41. relationalai/shims/fixtures.py +85 -0
  42. relationalai/shims/helpers.py +26 -2
  43. relationalai/shims/hoister.py +28 -9
  44. relationalai/shims/mm2v0.py +204 -173
  45. relationalai/tools/cli/cli.py +192 -10
  46. relationalai/tools/cli/components/progress_reader.py +1 -1
  47. relationalai/tools/cli/docs.py +394 -0
  48. relationalai/tools/debugger.py +11 -4
  49. relationalai/tools/qb_debugger.py +435 -0
  50. relationalai/tools/typer_debugger.py +1 -2
  51. relationalai/util/dataclasses.py +3 -5
  52. relationalai/util/docutils.py +1 -2
  53. relationalai/util/error.py +2 -5
  54. relationalai/util/python.py +23 -0
  55. relationalai/util/runtime.py +1 -2
  56. relationalai/util/schema.py +2 -4
  57. relationalai/util/structures.py +4 -2
  58. relationalai/util/tracing.py +8 -2
  59. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/METADATA +8 -5
  60. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/RECORD +118 -95
  61. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/WHEEL +1 -1
  62. v0/relationalai/__init__.py +1 -1
  63. v0/relationalai/clients/client.py +52 -18
  64. v0/relationalai/clients/exec_txn_poller.py +122 -0
  65. v0/relationalai/clients/local.py +23 -8
  66. v0/relationalai/clients/resources/azure/azure.py +36 -11
  67. v0/relationalai/clients/resources/snowflake/__init__.py +4 -4
  68. v0/relationalai/clients/resources/snowflake/cli_resources.py +12 -1
  69. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +124 -100
  70. v0/relationalai/clients/resources/snowflake/engine_service.py +381 -0
  71. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +35 -29
  72. v0/relationalai/clients/resources/snowflake/error_handlers.py +43 -2
  73. v0/relationalai/clients/resources/snowflake/snowflake.py +277 -179
  74. v0/relationalai/clients/resources/snowflake/use_index_poller.py +8 -0
  75. v0/relationalai/clients/types.py +5 -0
  76. v0/relationalai/errors.py +19 -1
  77. v0/relationalai/semantics/lqp/algorithms.py +173 -0
  78. v0/relationalai/semantics/lqp/builtins.py +199 -2
  79. v0/relationalai/semantics/lqp/executor.py +68 -37
  80. v0/relationalai/semantics/lqp/ir.py +28 -2
  81. v0/relationalai/semantics/lqp/model2lqp.py +215 -45
  82. v0/relationalai/semantics/lqp/passes.py +13 -658
  83. v0/relationalai/semantics/lqp/rewrite/__init__.py +12 -0
  84. v0/relationalai/semantics/lqp/rewrite/algorithm.py +385 -0
  85. v0/relationalai/semantics/lqp/rewrite/constants_to_vars.py +70 -0
  86. v0/relationalai/semantics/lqp/rewrite/deduplicate_vars.py +104 -0
  87. v0/relationalai/semantics/lqp/rewrite/eliminate_data.py +108 -0
  88. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  89. v0/relationalai/semantics/lqp/rewrite/period_math.py +77 -0
  90. v0/relationalai/semantics/lqp/rewrite/quantify_vars.py +65 -31
  91. v0/relationalai/semantics/lqp/rewrite/unify_definitions.py +317 -0
  92. v0/relationalai/semantics/lqp/utils.py +11 -1
  93. v0/relationalai/semantics/lqp/validators.py +14 -1
  94. v0/relationalai/semantics/metamodel/builtins.py +2 -1
  95. v0/relationalai/semantics/metamodel/compiler.py +2 -1
  96. v0/relationalai/semantics/metamodel/dependency.py +12 -3
  97. v0/relationalai/semantics/metamodel/executor.py +11 -1
  98. v0/relationalai/semantics/metamodel/factory.py +2 -2
  99. v0/relationalai/semantics/metamodel/helpers.py +7 -0
  100. v0/relationalai/semantics/metamodel/ir.py +3 -2
  101. v0/relationalai/semantics/metamodel/rewrite/dnf_union_splitter.py +30 -20
  102. v0/relationalai/semantics/metamodel/rewrite/flatten.py +50 -13
  103. v0/relationalai/semantics/metamodel/rewrite/format_outputs.py +9 -3
  104. v0/relationalai/semantics/metamodel/typer/checker.py +6 -4
  105. v0/relationalai/semantics/metamodel/typer/typer.py +4 -3
  106. v0/relationalai/semantics/metamodel/visitor.py +4 -3
  107. v0/relationalai/semantics/reasoners/optimization/solvers_dev.py +1 -1
  108. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +336 -86
  109. v0/relationalai/semantics/rel/compiler.py +2 -1
  110. v0/relationalai/semantics/rel/executor.py +3 -2
  111. v0/relationalai/semantics/tests/lqp/__init__.py +0 -0
  112. v0/relationalai/semantics/tests/lqp/algorithms.py +345 -0
  113. v0/relationalai/tools/cli.py +339 -186
  114. v0/relationalai/tools/cli_controls.py +216 -67
  115. v0/relationalai/tools/cli_helpers.py +410 -6
  116. v0/relationalai/util/format.py +5 -2
  117. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/entry_points.txt +0 -0
  118. {relationalai-1.0.0a3.dist-info → relationalai-1.0.0a5.dist-info}/top_level.txt +0 -0
@@ -172,10 +172,17 @@ class LocalResources(ResourcesBase):
172
172
  def get_engine_sizes(self, cloud_provider: str | None = None):
173
173
  raise NotImplementedError("get_engine_sizes not supported in local mode")
174
174
 
175
- def list_engines(self, state: str | None = None):
175
+ def list_engines(
176
+ self,
177
+ state: str | None = None,
178
+ name: str | None = None,
179
+ type: str | None = None,
180
+ size: str | None = None,
181
+ created_by: str | None = None,
182
+ ):
176
183
  raise NotImplementedError("list_engines not supported in local mode")
177
184
 
178
- def get_engine(self, name: str):
185
+ def get_engine(self, name: str, type: str):
179
186
  raise NotImplementedError("get_engine not supported in local mode")
180
187
 
181
188
  def get_cloud_provider(self) -> str:
@@ -184,25 +191,33 @@ class LocalResources(ResourcesBase):
184
191
  def is_valid_engine_state(self, name: str) -> bool:
185
192
  raise NotImplementedError("is_valid_engine_state not supported in local mode")
186
193
 
187
- def create_engine(self, name: str, size:str|None=None, auto_suspend_mins:int|None=None):
194
+ def create_engine(
195
+ self,
196
+ name: str,
197
+ type: str | None = None,
198
+ size: str | None = None,
199
+ auto_suspend_mins: int | None = None,
200
+ headers: dict | None = None,
201
+ settings: dict | None = None,
202
+ ):
188
203
  raise NotImplementedError("create_engine not supported in local mode")
189
204
 
190
- def delete_engine(self, name: str, force: bool = False):
205
+ def delete_engine(self, name: str, type: str):
191
206
  raise NotImplementedError("delete_engine not supported in local mode")
192
207
 
193
- def suspend_engine(self, name: str):
208
+ def suspend_engine(self, name: str, type: str | None = None):
194
209
  raise NotImplementedError("suspend_engine not supported in local mode")
195
210
 
196
- def resume_engine(self, name: str, headers: Dict | None = None):
211
+ def resume_engine(self, name: str, type: str | None = None, headers: Dict | None = None):
197
212
  raise NotImplementedError("resume_engine not supported in local mode")
198
213
 
199
- def resume_engine_async(self, name: str):
214
+ def resume_engine_async(self, name: str, type: str | None = None, headers: Dict | None = None):
200
215
  raise NotImplementedError("resume_engine_async not supported in local mode")
201
216
 
202
217
  def alter_engine_pool(self, size: str | None = None, mins: int | None = None, maxs: int | None = None):
203
218
  raise NotImplementedError("alter_engine_pool not supported in local mode")
204
219
 
205
- def auto_create_engine_async(self, name: str | None = None) -> str:
220
+ def auto_create_engine_async(self, name: str | None = None, type: str | None = None) -> str:
206
221
  raise NotImplementedError("auto_create_engine_async not supported in local mode")
207
222
 
208
223
  #--------------------------------------------------
@@ -78,34 +78,59 @@ class Resources(ResourcesBase):
78
78
  def get_cloud_provider(self) -> str:
79
79
  return "azure"
80
80
 
81
- def list_engines(self, state:str|None = None):
82
- return api.list_engines(self._api_ctx(), state)
83
-
84
- def get_engine(self, name:str) -> EngineState:
85
- return cast(EngineState, api.get_engine(self._api_ctx(), name))
81
+ def list_engines(
82
+ self,
83
+ state: str | None = None,
84
+ name: str | None = None,
85
+ type: str | None = None,
86
+ size: str | None = None,
87
+ created_by: str | None = None,
88
+ ):
89
+ # Azure only supports state filtering at the API level; other filters are ignored.
90
+ engines = api.list_engines(self._api_ctx(), state)
91
+ # Ensure EngineState shape includes 'type' for callers/tests.
92
+ for eng in engines:
93
+ eng.setdefault("type", "LOGIC")
94
+ return engines
95
+
96
+ def get_engine(self, name: str, type: str) -> EngineState:
97
+ # type is ignored for Azure as it doesn't support multiple engine types
98
+ engine = cast(EngineState, api.get_engine(self._api_ctx(), name))
99
+ engine.setdefault("type", "LOGIC")
100
+ return engine
86
101
 
87
102
  def is_valid_engine_state(self, name:str):
88
103
  return name in VALID_ENGINE_STATES
89
104
 
90
- def create_engine(self, name:str, size:str|None=None, auto_suspend_mins: int|None=None):
105
+ def create_engine(
106
+ self,
107
+ name: str,
108
+ type: str | None = None,
109
+ size: str | None = None,
110
+ auto_suspend_mins: int | None = None,
111
+ headers: dict | None = None,
112
+ settings: dict | None = None,
113
+ ):
114
+ # Azure only supports one engine type, so type parameter is ignored
91
115
  if size is None:
92
116
  size = "M"
93
117
  with debugging.span("create_engine", name=name, size=size):
94
118
  return api.create_engine_wait(self._api_ctx(), name, size)
95
119
 
96
- def delete_engine(self, name:str, force:bool=False):
120
+ def delete_engine(self, name:str, type: str):
97
121
  return api.delete_engine(self._api_ctx(), name)
98
122
 
99
- def suspend_engine(self, name:str):
123
+ def suspend_engine(self, name:str, type: str | None = None): # type is ignored for Azure
100
124
  return api.suspend_engine(self._api_ctx(), name)
101
125
 
102
- def resume_engine(self, name:str, headers={}):
126
+ def resume_engine(self, name:str, type: str | None = None, headers={}):
127
+ # type is ignored for Azure as it doesn't support multiple engine types
103
128
  return api.resume_engine_wait(self._api_ctx(), name)
104
129
 
105
- def resume_engine_async(self, name:str):
130
+ def resume_engine_async(self, name:str, type: str | None = None, headers: Dict | None = None):
106
131
  return api.resume_engine(self._api_ctx(), name)
107
132
 
108
- def auto_create_engine_async(self, name: str | None = None) -> str:
133
+ def auto_create_engine_async(self, name: str | None = None, type: str | None = None) -> str:
109
134
  raise Exception("Azure doesn't support auto_create_engine_async")
110
135
 
111
136
  def alter_engine_pool(self, size: str | None = None, mins: int | None = None, maxs: int | None = None):
@@ -2,8 +2,8 @@
2
2
  Snowflake resources module.
3
3
  """
4
4
  # Import order matters - Resources must be imported first since other classes depend on it
5
- from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE, PrimaryKey
6
-
5
+ from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, PrimaryKey, PRINT_TXN_PROGRESS_FLAG
6
+ from .engine_service import EngineType, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE
7
7
  # These imports depend on Resources, so they come after
8
8
  from .cli_resources import CLIResources
9
9
  from .use_index_resources import UseIndexResources
@@ -12,9 +12,9 @@ from .resources_factory import create_resources_instance
12
12
 
13
13
  __all__ = [
14
14
  'Resources', 'DirectAccessResources', 'Provider', 'Graph', 'SnowflakeClient',
15
- 'APP_NAME', 'PYREL_ROOT_DB', 'CLIResources', 'UseIndexResources', 'ExecContext',
15
+ 'APP_NAME', 'PYREL_ROOT_DB', 'CLIResources', 'UseIndexResources', 'ExecContext', 'EngineType',
16
16
  'INTERNAL_ENGINE_SIZES', 'ENGINE_SIZES_AWS', 'ENGINE_SIZES_AZURE', 'PrimaryKey',
17
- 'create_resources_instance',
17
+ 'PRINT_TXN_PROGRESS_FLAG', 'create_resources_instance',
18
18
  ]
19
19
 
20
20
 
@@ -11,6 +11,7 @@ from ....tools.constants import RAI_APP_NAME
11
11
  # Import Resources from snowflake - this creates a dependency but no circular import
12
12
  # since snowflake.py doesn't import from this file
13
13
  from .snowflake import Resources, ExecContext
14
+ from .error_handlers import AppMissingErrorHandler, AppFunctionMissingErrorHandler, ServiceNotStartedErrorHandler
14
15
 
15
16
 
16
17
  class CLIResources(Resources):
@@ -20,7 +21,17 @@ class CLIResources(Resources):
20
21
  """
21
22
 
22
23
  def _handle_standard_exec_errors(self, e: Exception, ctx: ExecContext) -> Any | None:
23
- """For CLI resources, re-raise exceptions without transformation."""
24
+ """
25
+ For CLI resources, keep exceptions raw except for few specific cases.
26
+ """
27
+ message = str(e).lower()
28
+ for handler in (
29
+ ServiceNotStartedErrorHandler(),
30
+ AppMissingErrorHandler(),
31
+ AppFunctionMissingErrorHandler(),
32
+ ):
33
+ if handler.matches(e, message, ctx, self):
34
+ handler.handle(e, ctx, self)
24
35
  raise e
25
36
 
26
37
  def list_warehouses(self):
@@ -12,12 +12,13 @@ from ....environments import runtime_env, SnowbookEnvironment
12
12
  from ...config import Config, ConfigStore, ENDPOINT_FILE
13
13
  from ...direct_access_client import DirectAccessClient
14
14
  from ...types import EngineState
15
- from ...util import get_pyrel_version, poll_with_specified_overhead, safe_json_loads, ms_to_timestamp
16
- from ....errors import ResponseStatusException, QueryTimeoutExceededException
15
+ from ...util import get_pyrel_version, safe_json_loads, ms_to_timestamp
16
+ from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException, RAIException
17
17
  from snowflake.snowpark import Session
18
18
 
19
19
  # Import UseIndexResources to enable use_index functionality with direct access
20
20
  from .use_index_resources import UseIndexResources
21
+ from .snowflake import TxnCreationResult
21
22
 
22
23
  # Import helper functions from util
23
24
  from .util import is_engine_issue as _is_engine_issue, is_database_issue as _is_database_issue, collect_error_messages
@@ -27,6 +28,7 @@ from typing import Iterable
27
28
 
28
29
  # Constants
29
30
  TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
31
+ TXN_ABORT_REASON_GUARD_RAILS = "guard rail violation"
30
32
 
31
33
 
32
34
  class DirectAccessResources(UseIndexResources):
@@ -161,11 +163,12 @@ class DirectAccessResources(UseIndexResources):
161
163
  headers=headers,
162
164
  )
163
165
  response = _send_request()
164
- except requests.exceptions.ConnectionError as e:
166
+ except (requests.exceptions.ConnectionError, RAIException) as e:
165
167
  messages = collect_error_messages(e)
166
- if any("nameresolutionerror" in msg for msg in messages):
167
- # when we can not resolve the service endpoint, we assume it is outdated
168
- # hence, we try to retrieve it again and query again.
168
+ if any("nameresolutionerror" in msg for msg in messages) or \
169
+ any("could not find the service associated with endpoint" in msg for msg in messages):
170
+ # when we can not resolve the service endpoint or the service is not found,
171
+ # we assume the endpoint is outdated, so we retrieve it again and retry.
169
172
  self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
170
173
  enforce_update=True,
171
174
  )
@@ -217,83 +220,59 @@ class DirectAccessResources(UseIndexResources):
217
220
 
218
221
  return response
219
222
 
220
- def _exec_async_v2(
223
+ def _create_v2_txn(
221
224
  self,
222
225
  database: str,
223
- engine: Union[str, None],
226
+ engine: str | None,
224
227
  raw_code: str,
225
- inputs: Dict | None = None,
226
- readonly=True,
227
- nowait_durable=False,
228
- headers: Dict[str, str] | None = None,
229
- bypass_index=False,
230
- language: str = "rel",
231
- query_timeout_mins: int | None = None,
232
- gi_setup_skipped: bool = False,
233
- ):
234
-
235
- with debugging.span("transaction") as txn_span:
236
- with debugging.span("create_v2") as create_span:
237
-
238
- use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
239
-
240
- payload = {
241
- "dbname": database,
242
- "engine_name": engine,
243
- "query": raw_code,
244
- "v1_inputs": inputs,
245
- "nowait_durable": nowait_durable,
246
- "readonly": readonly,
247
- "language": language,
248
- }
249
- if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
250
- query_timeout_mins = int(timeout_value)
251
- if query_timeout_mins is not None:
252
- payload["timeout_mins"] = query_timeout_mins
253
- query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
254
-
255
- # Add gi_setup_skipped to headers
256
- if headers is None:
257
- headers = {}
258
- headers["gi_setup_skipped"] = str(gi_setup_skipped)
259
- headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
260
-
261
- response = self._txn_request_with_gi_retry(
262
- payload, headers, query_params, engine
263
- )
264
-
265
- artifact_info = {}
266
- response_content = response.json()
228
+ inputs: Dict,
229
+ headers: Dict[str, str],
230
+ readonly: bool,
231
+ nowait_durable: bool,
232
+ bypass_index: bool,
233
+ language: str,
234
+ query_timeout_mins: int | None,
235
+ ) -> TxnCreationResult:
236
+ """
237
+ Create a transaction via direct HTTP access and return the result.
267
238
 
268
- txn_id = response_content["transaction"]['id']
269
- state = response_content["transaction"]['state']
239
+ This override uses HTTP requests instead of SQL stored procedures.
240
+ """
241
+ use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
270
242
 
271
- txn_span["txn_id"] = txn_id
272
- create_span["txn_id"] = txn_id
273
- debugging.event("transaction_created", txn_span, txn_id=txn_id)
243
+ payload = {
244
+ "dbname": database,
245
+ "engine_name": engine,
246
+ "query": raw_code,
247
+ "v1_inputs": inputs,
248
+ "nowait_durable": nowait_durable,
249
+ "readonly": readonly,
250
+ "language": language,
251
+ }
252
+ if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
253
+ query_timeout_mins = int(timeout_value)
254
+ if query_timeout_mins is not None:
255
+ payload["timeout_mins"] = query_timeout_mins
256
+ query_params = {"use_graph_index": str(use_graph_index and not bypass_index)}
257
+
258
+ response = self._txn_request_with_gi_retry(
259
+ payload, headers, query_params, engine
260
+ )
274
261
 
275
- # fast path: transaction already finished
276
- if state in ["COMPLETED", "ABORTED"]:
277
- if txn_id in self._pending_transactions:
278
- self._pending_transactions.remove(txn_id)
262
+ response_content = response.json()
279
263
 
280
- # Process rows to get the rest of the artifacts
281
- for result in response_content.get("results", []):
282
- filename = result['filename']
283
- # making keys uppercase to match the old behavior
284
- artifact_info[filename] = {k.upper(): v for k, v in result.items()}
264
+ txn_id = response_content["transaction"]['id']
265
+ state = response_content["transaction"]['state']
285
266
 
286
- # Slow path: transaction not done yet; start polling
287
- else:
288
- self._pending_transactions.append(txn_id)
289
- with debugging.span("wait", txn_id=txn_id):
290
- poll_with_specified_overhead(
291
- lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
292
- )
293
- artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
267
+ # Build artifact_info if transaction completed immediately (fast path)
268
+ artifact_info: Dict[str, Dict] = {}
269
+ if state in ["COMPLETED", "ABORTED"]:
270
+ for result in response_content.get("results", []):
271
+ filename = result['filename']
272
+ # making keys uppercase to match the old behavior
273
+ artifact_info[filename] = {k.upper(): v for k, v in result.items()}
294
274
 
295
- with debugging.span("fetch"):
296
- return self._download_results(artifact_info, txn_id, state)
275
+ return TxnCreationResult(txn_id=txn_id, state=state, artifact_info=artifact_info)
297
276
 
298
277
  def _prepare_index(
299
278
  self,
@@ -355,15 +334,20 @@ class DirectAccessResources(UseIndexResources):
355
334
  if txn_id in self._pending_transactions:
356
335
  self._pending_transactions.remove(txn_id)
357
336
 
358
- if status == "ABORTED" and transaction.get("abort_reason", "") == TXN_ABORT_REASON_TIMEOUT:
359
- config_file_path = getattr(self.config, 'file_path', None)
360
- timeout_ms = int(transaction.get("timeout_ms", 0))
361
- timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
362
- raise QueryTimeoutExceededException(
363
- timeout_mins=timeout_mins,
364
- query_id=txn_id,
365
- config_file_path=config_file_path,
366
- )
337
+ if status == "ABORTED":
338
+ reason = transaction.get("abort_reason", "")
339
+
340
+ if reason == TXN_ABORT_REASON_TIMEOUT:
341
+ config_file_path = getattr(self.config, 'file_path', None)
342
+ timeout_ms = int(transaction.get("timeout_ms", 0))
343
+ timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
344
+ raise QueryTimeoutExceededException(
345
+ timeout_mins=timeout_mins,
346
+ query_id=txn_id,
347
+ config_file_path=config_file_path,
348
+ )
349
+ elif reason == TXN_ABORT_REASON_GUARD_RAILS:
350
+ raise GuardRailsException(response_content.get("progress", {}))
367
351
 
368
352
  # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
369
353
  return status == "COMPLETED" or status == "ABORTED"
@@ -530,7 +514,14 @@ class DirectAccessResources(UseIndexResources):
530
514
  # Engines
531
515
  #--------------------------------------------------
532
516
 
533
- def list_engines(self, state: str | None = None):
517
+ def list_engines(
518
+ self,
519
+ state: str | None = None,
520
+ name: str | None = None,
521
+ type: str | None = None,
522
+ size: str | None = None,
523
+ created_by: str | None = None,
524
+ ):
534
525
  response = self.request("list_engines")
535
526
  if response.status_code != 200:
536
527
  raise ResponseStatusException(
@@ -543,19 +534,31 @@ class DirectAccessResources(UseIndexResources):
543
534
  {
544
535
  "name": engine["name"],
545
536
  "id": engine["id"],
537
+ "type": engine.get("type", "LOGIC"),
546
538
  "size": engine["size"],
547
539
  "state": engine["status"], # callers are expecting 'state'
548
540
  "created_by": engine["created_by"],
549
541
  "created_on": engine["created_on"],
550
542
  "updated_on": engine["updated_on"],
543
+ # Optional fields (present in newer APIs / service-functions path)
544
+ "auto_suspend_mins": engine.get("auto_suspend_mins"),
545
+ "suspends_at": engine.get("suspends_at"),
546
+ "settings": engine.get("settings"),
551
547
  }
552
548
  for engine in response_content.get("engines", [])
553
- if state is None or engine.get("status") == state
549
+ if (state is None or engine.get("status") == state)
550
+ and (name is None or name.upper() in engine.get("name", "").upper())
551
+ and (type is None or engine.get("type", "LOGIC").upper() == type.upper())
552
+ and (size is None or engine.get("size") == size)
553
+ and (created_by is None or created_by.upper() in engine.get("created_by", "").upper())
554
554
  ]
555
555
  return sorted(engines, key=lambda x: x["name"])
556
556
 
557
- def get_engine(self, name: str):
558
- response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
557
+ def get_engine(self, name: str, type: str):
558
+ if type is None:
559
+ raise Exception("Engine type is required. Valid types are: LOGIC, SOLVER, ML")
560
+ engine_type_lower = type.lower()
561
+ response = self.request("get_engine", path_params={"engine_name": name, "engine_type": engine_type_lower}, skip_auto_create=True)
559
562
  if response.status_code == 404: # engine not found return 404
560
563
  return None
561
564
  elif response.status_code != 200:
@@ -569,6 +572,7 @@ class DirectAccessResources(UseIndexResources):
569
572
  "name": engine["name"],
570
573
  "id": engine["id"],
571
574
  "size": engine["size"],
575
+ "type": engine.get("type", type),
572
576
  "state": engine["status"], # callers are expecting 'state'
573
577
  "created_by": engine["created_by"],
574
578
  "created_on": engine["created_on"],
@@ -576,31 +580,47 @@ class DirectAccessResources(UseIndexResources):
576
580
  "version": engine["version"],
577
581
  "auto_suspend": engine["auto_suspend_mins"],
578
582
  "suspends_at": engine["suspends_at"],
583
+ "settings": engine.get("settings"),
579
584
  }
580
585
  return engine_state
581
586
 
582
587
  def _create_engine(
583
588
  self,
584
589
  name: str,
590
+ type: str = "LOGIC",
585
591
  size: str | None = None,
586
592
  auto_suspend_mins: int | None = None,
587
593
  is_async: bool = False,
588
- headers: Dict[str, str] | None = None
594
+ headers: Dict[str, str] | None = None,
595
+ settings: Dict[str, Any] | None = None,
589
596
  ):
590
597
  # only async engine creation supported via direct access
591
598
  if not is_async:
592
- return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
599
+ return super()._create_engine(
600
+ name,
601
+ type=type,
602
+ size=size,
603
+ auto_suspend_mins=auto_suspend_mins,
604
+ is_async=is_async,
605
+ headers=headers,
606
+ settings=settings,
607
+ )
608
+ engine_type_lower = type.lower()
593
609
  payload:Dict[str, Any] = {
594
610
  "name": name,
595
611
  }
612
+ # Allow passing arbitrary engine settings (API-dependent).
613
+ if settings:
614
+ payload["settings"] = settings
596
615
  if auto_suspend_mins is not None:
597
616
  payload["auto_suspend_mins"] = auto_suspend_mins
598
- if size is not None:
599
- payload["size"] = size
617
+ if size is None:
618
+ size = self.config.get_default_engine_size()
619
+ payload["size"] = size
600
620
  response = self.request(
601
621
  "create_engine",
602
622
  payload=payload,
603
- path_params={"engine_type": "logic"},
623
+ path_params={"engine_type": engine_type_lower},
604
624
  headers=headers,
605
625
  skip_auto_create=True,
606
626
  )
@@ -609,11 +629,11 @@ class DirectAccessResources(UseIndexResources):
609
629
  f"Failed to create engine {name} with size {size}.", response
610
630
  )
611
631
 
612
- def delete_engine(self, name:str, force:bool = False, headers={}):
632
+ def delete_engine(self, name: str, type: str):
613
633
  response = self.request(
614
634
  "delete_engine",
615
- path_params={"engine_name": name, "engine_type": "logic"},
616
- headers=headers,
635
+ path_params={"engine_name": name, "engine_type": type.lower()},
636
+ headers={},
617
637
  skip_auto_create=True,
618
638
  )
619
639
  if response.status_code != 200:
@@ -621,10 +641,12 @@ class DirectAccessResources(UseIndexResources):
621
641
  f"Failed to delete engine {name}.", response
622
642
  )
623
643
 
624
- def suspend_engine(self, name:str):
644
+ def suspend_engine(self, name: str, type: str | None = None):
645
+ if type is None:
646
+ type = "LOGIC"
625
647
  response = self.request(
626
648
  "suspend_engine",
627
- path_params={"engine_name": name, "engine_type": "logic"},
649
+ path_params={"engine_name": name, "engine_type": type.lower()},
628
650
  skip_auto_create=True,
629
651
  )
630
652
  if response.status_code != 200:
@@ -632,11 +654,13 @@ class DirectAccessResources(UseIndexResources):
632
654
  f"Failed to suspend engine {name}.", response
633
655
  )
634
656
 
635
- def resume_engine_async(self, name:str, headers={}):
657
+ def resume_engine_async(self, name: str, type: str | None = None, headers: Dict | None = None):
658
+ if type is None:
659
+ type = "LOGIC"
636
660
  response = self.request(
637
661
  "resume_engine",
638
- path_params={"engine_name": name, "engine_type": "logic"},
639
- headers=headers,
662
+ path_params={"engine_name": name, "engine_type": type.lower()},
663
+ headers=headers or {},
640
664
  skip_auto_create=True,
641
665
  )
642
666
  if response.status_code != 200: