relationalai 1.0.0a2__py3-none-any.whl → 1.0.0a4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (57)
  1. relationalai/config/shims.py +1 -0
  2. relationalai/semantics/__init__.py +7 -1
  3. relationalai/semantics/frontend/base.py +19 -13
  4. relationalai/semantics/frontend/core.py +30 -2
  5. relationalai/semantics/frontend/front_compiler.py +38 -11
  6. relationalai/semantics/frontend/pprint.py +1 -1
  7. relationalai/semantics/metamodel/rewriter.py +6 -2
  8. relationalai/semantics/metamodel/typer.py +70 -26
  9. relationalai/semantics/reasoners/__init__.py +11 -0
  10. relationalai/semantics/reasoners/graph/__init__.py +38 -0
  11. relationalai/semantics/reasoners/graph/core.py +9015 -0
  12. relationalai/shims/executor.py +4 -1
  13. relationalai/shims/hoister.py +9 -0
  14. relationalai/shims/mm2v0.py +47 -34
  15. relationalai/tools/cli/cli.py +138 -0
  16. relationalai/tools/cli/docs.py +394 -0
  17. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/METADATA +5 -3
  18. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/RECORD +57 -43
  19. v0/relationalai/__init__.py +69 -22
  20. v0/relationalai/clients/__init__.py +15 -2
  21. v0/relationalai/clients/client.py +4 -4
  22. v0/relationalai/clients/exec_txn_poller.py +91 -0
  23. v0/relationalai/clients/local.py +5 -5
  24. v0/relationalai/clients/resources/__init__.py +8 -0
  25. v0/relationalai/clients/{azure.py → resources/azure/azure.py} +12 -12
  26. v0/relationalai/clients/resources/snowflake/__init__.py +20 -0
  27. v0/relationalai/clients/resources/snowflake/cli_resources.py +87 -0
  28. v0/relationalai/clients/resources/snowflake/direct_access_resources.py +717 -0
  29. v0/relationalai/clients/resources/snowflake/engine_state_handlers.py +309 -0
  30. v0/relationalai/clients/resources/snowflake/error_handlers.py +199 -0
  31. v0/relationalai/clients/resources/snowflake/resources_factory.py +99 -0
  32. v0/relationalai/clients/{snowflake.py → resources/snowflake/snowflake.py} +642 -1399
  33. v0/relationalai/clients/{use_index_poller.py → resources/snowflake/use_index_poller.py} +51 -12
  34. v0/relationalai/clients/resources/snowflake/use_index_resources.py +188 -0
  35. v0/relationalai/clients/resources/snowflake/util.py +387 -0
  36. v0/relationalai/early_access/dsl/ir/executor.py +4 -4
  37. v0/relationalai/early_access/dsl/snow/api.py +2 -1
  38. v0/relationalai/errors.py +18 -0
  39. v0/relationalai/experimental/solvers.py +7 -7
  40. v0/relationalai/semantics/devtools/benchmark_lqp.py +4 -5
  41. v0/relationalai/semantics/devtools/extract_lqp.py +1 -1
  42. v0/relationalai/semantics/internal/snowflake.py +1 -1
  43. v0/relationalai/semantics/lqp/executor.py +7 -12
  44. v0/relationalai/semantics/lqp/rewrite/extract_keys.py +25 -3
  45. v0/relationalai/semantics/metamodel/util.py +6 -5
  46. v0/relationalai/semantics/reasoners/optimization/solvers_pb.py +335 -84
  47. v0/relationalai/semantics/rel/executor.py +14 -11
  48. v0/relationalai/semantics/sql/executor/snowflake.py +9 -5
  49. v0/relationalai/semantics/tests/test_snapshot_abstract.py +1 -1
  50. v0/relationalai/tools/cli.py +26 -30
  51. v0/relationalai/tools/cli_helpers.py +10 -2
  52. v0/relationalai/util/otel_configuration.py +2 -1
  53. v0/relationalai/util/otel_handler.py +1 -1
  54. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/WHEEL +0 -0
  55. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/entry_points.txt +0 -0
  56. {relationalai-1.0.0a2.dist-info → relationalai-1.0.0a4.dist-info}/top_level.txt +0 -0
  57. v0/relationalai/clients/{cache_store.py → resources/snowflake/cache_store.py} +0 -0
v0/relationalai/clients/resources/snowflake/direct_access_resources.py (new file; entry 28 above)
@@ -0,0 +1,717 @@
+"""
+Direct Access Resources - Resources class for Direct Service Access.
+This class overrides methods to use direct HTTP requests instead of Snowflake service functions.
+"""
+from __future__ import annotations
+from typing import Any, Dict, List, Optional, Union
+import requests
+
+from .... import debugging
+from ....tools.constants import USE_GRAPH_INDEX, DEFAULT_QUERY_TIMEOUT_MINS, Generation
+from ....environments import runtime_env, SnowbookEnvironment
+from ...config import Config, ConfigStore, ENDPOINT_FILE
+from ...direct_access_client import DirectAccessClient
+from ...types import EngineState
+from ...util import get_pyrel_version, poll_with_specified_overhead, safe_json_loads, ms_to_timestamp
+from ....errors import GuardRailsException, ResponseStatusException, QueryTimeoutExceededException
+from snowflake.snowpark import Session
+
+# Import UseIndexResources to enable use_index functionality with direct access
+from .use_index_resources import UseIndexResources
+
+# Import helper functions from util
+from .util import is_engine_issue as _is_engine_issue, is_database_issue as _is_database_issue, collect_error_messages
+
+from .use_index_poller import DirectUseIndexPoller
+from typing import Iterable
+
+# Constants
+TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
+TXN_ABORT_REASON_GUARD_RAILS = "guard rail violation"
+
+
+class DirectAccessResources(UseIndexResources):
+    """
+    Resources class for Direct Service Access avoiding Snowflake service functions.
+    Uses HTTP requests instead of Snowflake SQL for execution.
+    """
+    def __init__(
+        self,
+        profile: Union[str, None] = None,
+        config: Union[Config, None] = None,
+        connection: Union[Session, None] = None,
+        dry_run: bool = False,
+        reset_session: bool = False,
+        generation: Optional[Generation] = None,
+        language: str = "rel",
+    ):
+        super().__init__(
+            generation=generation,
+            profile=profile,
+            config=config,
+            connection=connection,
+            reset_session=reset_session,
+            dry_run=dry_run,
+            language=language,
+        )
+        self._endpoint_info = ConfigStore(ENDPOINT_FILE)
+        self._service_endpoint = ""
+        self._direct_access_client = None
+        # database and language are already set by UseIndexResources.__init__
+
+    @property
+    def service_endpoint(self) -> str:
+        return self._retrieve_service_endpoint()
+
+    def _retrieve_service_endpoint(self, enforce_update=False) -> str:
+        account = self.config.get("account")
+        app_name = self.config.get("rai_app_name")
+        service_endpoint_key = f"{account}.{app_name}.service_endpoint"
+        if self._service_endpoint and not enforce_update:
+            return self._service_endpoint
+        if self._endpoint_info.get(service_endpoint_key, "") and not enforce_update:
+            self._service_endpoint = str(self._endpoint_info.get(service_endpoint_key, ""))
+            return self._service_endpoint
+
+        is_snowflake_notebook = isinstance(runtime_env, SnowbookEnvironment)
+        query = f"CALL {self.get_app_name()}.app.service_endpoint({not is_snowflake_notebook});"
+        result = self._exec(query)
+        assert result, f"Could not retrieve service endpoint for {self.get_app_name()}"
+        if is_snowflake_notebook:
+            self._service_endpoint = f"http://{result[0]['SERVICE_ENDPOINT']}"
+        else:
+            self._service_endpoint = f"https://{result[0]['SERVICE_ENDPOINT']}"
+
+        self._endpoint_info.set(service_endpoint_key, self._service_endpoint)
+        # save the endpoint to `ENDPOINT_FILE` to avoid calling the endpoint with every
+        # pyrel execution
+        try:
+            self._endpoint_info.save()
+        except Exception:
+            print("Failed to persist endpoints to file. This might slow down future executions.")
+
+        return self._service_endpoint
+
+    @property
+    def direct_access_client(self) -> DirectAccessClient:
+        if self._direct_access_client:
+            return self._direct_access_client
+        try:
+            service_endpoint = self.service_endpoint
+            self._direct_access_client = DirectAccessClient(
+                self.config, self.token_handler, service_endpoint, self.generation,
+            )
+        except Exception as e:
+            raise e
+        return self._direct_access_client
+
+    def request(
+        self,
+        endpoint: str,
+        payload: Dict[str, Any] | None = None,
+        headers: Dict[str, str] | None = None,
+        path_params: Dict[str, str] | None = None,
+        query_params: Dict[str, str] | None = None,
+        skip_auto_create: bool = False,
+        skip_engine_db_error_retry: bool = False,
+    ) -> requests.Response:
+        with debugging.span("direct_access_request"):
+            def _send_request():
+                return self.direct_access_client.request(
+                    endpoint=endpoint,
+                    payload=payload,
+                    headers=headers,
+                    path_params=path_params,
+                    query_params=query_params,
+                )
+            try:
+                response = _send_request()
+                if response.status_code != 200:
+                    # For 404 responses with skip_auto_create=True, return immediately to let caller handle it
+                    # (e.g., get_engine needs to check 404 and return None for auto_create_engine)
+                    # For skip_auto_create=False, continue to auto-creation logic below
+                    if response.status_code == 404 and skip_auto_create:
+                        return response
+
+                    try:
+                        message = response.json().get("message", "")
+                    except requests.exceptions.JSONDecodeError:
+                        # Can't parse JSON response. For skip_auto_create=True (e.g., get_engine),
+                        # this should have been caught by the 404 check above, so this is an error.
+                        # For skip_auto_create=False, we explicitly check status_code below,
+                        # so we don't need to parse the message.
+                        if skip_auto_create:
+                            raise ResponseStatusException(
+                                f"Failed to parse error response from endpoint {endpoint}.", response
+                            )
+                        message = "" # Not used when we check status_code directly
+
+                    # fix engine on engine error and retry
+                    # Skip setting up GI if skip_auto_create is True to avoid recursion or skip_engine_db_error_retry is true to let _exec_async_v2 perform the retry with the correct headers.
+                    if ((_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message)) and not skip_engine_db_error_retry:
+                        engine_name = payload.get("caller_engine_name", "") if payload else ""
+                        engine_name = engine_name or self.get_default_engine_name()
+                        engine_size = self.config.get_default_engine_size()
+                        # Use the mixin's _poll_use_index method
+                        self._poll_use_index(
+                            app_name=self.get_app_name(),
+                            sources=self.sources,
+                            model=self.database,
+                            engine_name=engine_name,
+                            engine_size=engine_size,
+                            headers=headers,
+                        )
+                        response = _send_request()
+            except requests.exceptions.ConnectionError as e:
+                messages = collect_error_messages(e)
+                if any("nameresolutionerror" in msg for msg in messages):
+                    # when we can not resolve the service endpoint, we assume it is outdated
+                    # hence, we try to retrieve it again and query again.
+                    self.direct_access_client.service_endpoint = self._retrieve_service_endpoint(
+                        enforce_update=True,
+                    )
+                    return _send_request()
+                # raise in all other cases
+                raise e
+            return response
+
+    def _txn_request_with_gi_retry(
+        self,
+        payload: Dict,
+        headers: Dict[str, str],
+        query_params: Dict,
+        engine: Union[str, None],
+    ):
+        """Make request with graph index retry logic.
+
+        Attempts request with gi_setup_skipped=True first. If an engine or database
+        issue occurs, polls use_index and retries with gi_setup_skipped=False.
+        """
+        response = self.request(
+            "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
+        )
+
+        if response.status_code != 200:
+            try:
+                message = response.json().get("message", "")
+            except requests.exceptions.JSONDecodeError:
+                message = ""
+
+            if _is_engine_issue(message) or _is_database_issue(message):
+                engine_name = engine or self.get_default_engine_name()
+                engine_size = self.config.get_default_engine_size()
+                # Use the mixin's _poll_use_index method
+                self._poll_use_index(
+                    app_name=self.get_app_name(),
+                    sources=self.sources,
+                    model=self.database,
+                    engine_name=engine_name,
+                    engine_size=engine_size,
+                    headers=headers,
+                )
+                headers['gi_setup_skipped'] = 'False'
+                response = self.request(
+                    "create_txn", payload=payload, headers=headers, query_params=query_params, skip_auto_create=True, skip_engine_db_error_retry=True
+                )
+            else:
+                raise ResponseStatusException("Failed to create transaction.", response)
+
+        return response
+
+    def _exec_async_v2(
+        self,
+        database: str,
+        engine: Union[str, None],
+        raw_code: str,
+        inputs: Dict | None = None,
+        readonly=True,
+        nowait_durable=False,
+        headers: Dict[str, str] | None = None,
+        bypass_index=False,
+        language: str = "rel",
+        query_timeout_mins: int | None = None,
+        gi_setup_skipped: bool = False,
+    ):
+
+        with debugging.span("transaction") as txn_span:
+            with debugging.span("create_v2") as create_span:
+
+                use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
+
+                payload = {
+                    "dbname": database,
+                    "engine_name": engine,
+                    "query": raw_code,
+                    "v1_inputs": inputs,
+                    "nowait_durable": nowait_durable,
+                    "readonly": readonly,
+                    "language": language,
+                }
+                if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+                    query_timeout_mins = int(timeout_value)
+                if query_timeout_mins is not None:
+                    payload["timeout_mins"] = query_timeout_mins
+                query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
+
+                # Add gi_setup_skipped to headers
+                if headers is None:
+                    headers = {}
+                headers["gi_setup_skipped"] = str(gi_setup_skipped)
+                headers['pyrel_program_id'] = debugging.get_program_span_id() or ""
+
+                response = self._txn_request_with_gi_retry(
+                    payload, headers, query_params, engine
+                )
+
+                artifact_info = {}
+                response_content = response.json()
+
+                txn_id = response_content["transaction"]['id']
+                state = response_content["transaction"]['state']
+
+                txn_span["txn_id"] = txn_id
+                create_span["txn_id"] = txn_id
+                debugging.event("transaction_created", txn_span, txn_id=txn_id)
+
+                # fast path: transaction already finished
+                if state in ["COMPLETED", "ABORTED"]:
+                    if txn_id in self._pending_transactions:
+                        self._pending_transactions.remove(txn_id)
+
+                    # Process rows to get the rest of the artifacts
+                    for result in response_content.get("results", []):
+                        filename = result['filename']
+                        # making keys uppercase to match the old behavior
+                        artifact_info[filename] = {k.upper(): v for k, v in result.items()}
+
+                # Slow path: transaction not done yet; start polling
+                else:
+                    self._pending_transactions.append(txn_id)
+                    with debugging.span("wait", txn_id=txn_id):
+                        poll_with_specified_overhead(
+                            lambda: self._check_exec_async_status(txn_id, headers=headers), 0.1
+                        )
+                    artifact_info = self._list_exec_async_artifacts(txn_id, headers=headers)
+
+            with debugging.span("fetch"):
+                return self._download_results(artifact_info, txn_id, state)
+
+    def _prepare_index(
+        self,
+        model: str,
+        engine_name: str,
+        engine_size: str = "",
+        language: str = "rel",
+        rai_relations: List[str] | None = None,
+        pyrel_program_id: str | None = None,
+        skip_pull_relations: bool = False,
+        headers: Dict | None = None,
+    ):
+        """
+        Prepare the index for the given engine and model.
+        """
+        with debugging.span("prepare_index"):
+            if headers is None:
+                headers = {}
+
+            payload = {
+                "model_name": model,
+                "caller_engine_name": engine_name,
+                "language": language,
+                "pyrel_program_id": pyrel_program_id,
+                "skip_pull_relations": skip_pull_relations,
+                "rai_relations": rai_relations or [],
+                "user_agent": get_pyrel_version(self.generation),
+            }
+            # Only include engine_size if it has a non-empty string value
+            if engine_size and engine_size.strip():
+                payload["caller_engine_size"] = engine_size
+
+            response = self.request(
+                "prepare_index", payload=payload, headers=headers
+            )
+
+            if response.status_code != 200:
+                raise ResponseStatusException("Failed to prepare index.", response)
+
+            return response.json()
+
+    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
+        """Check whether the given transaction has completed."""
+
+        with debugging.span("check_status"):
+            response = self.request(
+                "get_txn",
+                headers=headers,
+                path_params={"txn_id": txn_id},
+            )
+            assert response, f"No results from get_transaction('{txn_id}')"
+
+            response_content = response.json()
+            transaction = response_content["transaction"]
+            status: str = transaction['state']
+
+            # remove the transaction from the pending list if it's completed or aborted
+            if status in ["COMPLETED", "ABORTED"]:
+                if txn_id in self._pending_transactions:
+                    self._pending_transactions.remove(txn_id)
+
+            if status == "ABORTED":
+                reason = transaction.get("abort_reason", "")
+
+                if reason == TXN_ABORT_REASON_TIMEOUT:
+                    config_file_path = getattr(self.config, 'file_path', None)
+                    timeout_ms = int(transaction.get("timeout_ms", 0))
+                    timeout_mins = timeout_ms // 60000 if timeout_ms > 0 else int(self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS) or DEFAULT_QUERY_TIMEOUT_MINS)
+                    raise QueryTimeoutExceededException(
+                        timeout_mins=timeout_mins,
+                        query_id=txn_id,
+                        config_file_path=config_file_path,
+                    )
+                elif reason == TXN_ABORT_REASON_GUARD_RAILS:
+                    raise GuardRailsException(response_content.get("progress", {}))
+
+            # @TODO: Find some way to tunnel the ABORT_REASON out. Azure doesn't have this, but it's handy
+            return status == "COMPLETED" or status == "ABORTED"
+
+    def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
+        """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
+        with debugging.span("list_results"):
+            response = self.request(
+                "get_txn_artifacts",
+                headers=headers,
+                path_params={"txn_id": txn_id},
+            )
+            assert response, f"No results from get_transaction_artifacts('{txn_id}')"
+            artifact_info = {}
+            for result in response.json()["results"]:
+                filename = result['filename']
+                # making keys uppercase to match the old behavior
+                artifact_info[filename] = {k.upper(): v for k, v in result.items()}
+            return artifact_info
+
+    def get_transaction_problems(self, txn_id: str) -> List[Dict[str, Any]]:
+        with debugging.span("get_transaction_problems"):
+            response = self.request(
+                "get_txn_problems",
+                path_params={"txn_id": txn_id},
+            )
+            response_content = response.json()
+            if not response_content:
+                return []
+            return response_content.get("problems", [])
+
+    def get_transaction_events(self, transaction_id: str, continuation_token: str = ''):
+        response = self.request(
+            "get_txn_events",
+            path_params={"txn_id": transaction_id, "stream_name": "profiler"},
+            query_params={"continuation_token": continuation_token},
+        )
+        response_content = response.json()
+        if not response_content:
+            return {
+                "events": [],
+                "continuation_token": None
+            }
+        return response_content
+
+    #--------------------------------------------------
+    # Databases
+    #--------------------------------------------------
+
+    def get_installed_packages(self, database: str) -> Union[Dict, None]:
+        use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
+        if use_graph_index:
+            response = self.request(
+                "get_model_package_versions",
+                payload={"model_name": database},
+            )
+        else:
+            response = self.request(
+                "get_package_versions",
+                path_params={"db_name": database},
+            )
+        if response.status_code == 404 and response.json().get("message", "") == "database not found":
+            return None
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to retrieve package versions for {database}.", response
+            )
+
+        content = response.json()
+        if not content:
+            return None
+
+        return safe_json_loads(content["package_versions"])
+
+    def get_database(self, database: str):
+        with debugging.span("get_database", dbname=database):
+            if not database:
+                raise ValueError("Database name must be provided to get database.")
+            response = self.request(
+                "get_db",
+                path_params={},
+                query_params={"name": database},
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(f"Failed to get db. db:{database}", response)
+
+            response_content = response.json()
+
+            if (response_content.get("databases") and len(response_content["databases"]) == 1):
+                db = response_content["databases"][0]
+                return {
+                    "id": db["id"],
+                    "name": db["name"],
+                    "created_by": db.get("created_by"),
+                    "created_on": ms_to_timestamp(db.get("created_on")),
+                    "deleted_by": db.get("deleted_by"),
+                    "deleted_on": ms_to_timestamp(db.get("deleted_on")),
+                    "state": db["state"],
+                }
+            else:
+                return None
+
+    def create_graph(self, name: str):
+        with debugging.span("create_model", dbname=name):
+            return self._create_database(name,"")
+
+    def delete_graph(self, name:str, force=False, language: str = "rel"):
+        prop_hdrs = debugging.gen_current_propagation_headers()
+        if self.config.get("use_graph_index", USE_GRAPH_INDEX):
+            keep_database = not force and self.config.get("reuse_model", True)
+            with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
+                response = self.request(
+                    "release_index",
+                    payload={
+                        "model_name": name,
+                        "keep_database": keep_database,
+                        "language": language,
+                        "user_agent": get_pyrel_version(self.generation),
+                    },
+                    headers=prop_hdrs,
+                )
+                if (
+                    response.status_code != 200
+                    and not (
+                        response.status_code == 404
+                        and "database not found" in response.json().get("message", "")
+                    )
+                ):
+                    raise ResponseStatusException(f"Failed to release index. Model: {name} ", response)
+        else:
+            with debugging.span("delete_model", name=name):
+                self._delete_database(name, headers=prop_hdrs)
+
+    def clone_graph(self, target_name:str, source_name:str, nowait_durable=True, force=False):
+        if force and self.get_graph(target_name):
+            self.delete_graph(target_name)
+        with debugging.span("clone_model", target_name=target_name, source_name=source_name):
+            return self._create_database(target_name,source_name)
+
+    def _delete_database(self, name:str, headers:Dict={}):
+        with debugging.span("_delete_database", dbname=name):
+            response = self.request(
+                "delete_db",
+                path_params={"db_name": name},
+                query_params={},
+                headers=headers,
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(f"Failed to delete db. db:{name} ", response)
+
+    def _create_database(self, name:str, source_name:str):
+        with debugging.span("_create_database", dbname=name):
+            payload = {
+                "name": name,
+                "source_name": source_name,
+            }
+            response = self.request(
+                "create_db", payload=payload, headers={}, query_params={},
+            )
+            if response.status_code != 200:
+                raise ResponseStatusException(f"Failed to create db. db:{name}", response)
+
+    #--------------------------------------------------
+    # Engines
+    #--------------------------------------------------
+
+    def list_engines(self, state: str | None = None):
+        response = self.request("list_engines")
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                "Failed to retrieve engines.", response
+            )
+        response_content = response.json()
+        if not response_content:
+            return []
+        engines = [
+            {
+                "name": engine["name"],
+                "id": engine["id"],
+                "size": engine["size"],
+                "state": engine["status"], # callers are expecting 'state'
+                "created_by": engine["created_by"],
+                "created_on": engine["created_on"],
+                "updated_on": engine["updated_on"],
+            }
+            for engine in response_content.get("engines", [])
+            if state is None or engine.get("status") == state
+        ]
+        return sorted(engines, key=lambda x: x["name"])
+
+    def get_engine(self, name: str):
+        response = self.request("get_engine", path_params={"engine_name": name, "engine_type": "logic"}, skip_auto_create=True)
+        if response.status_code == 404: # engine not found return 404
+            return None
+        elif response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to retrieve engine {name}.", response
+            )
+        engine = response.json()
+        if not engine:
+            return None
+        engine_state: EngineState = {
+            "name": engine["name"],
+            "id": engine["id"],
+            "size": engine["size"],
+            "state": engine["status"], # callers are expecting 'state'
+            "created_by": engine["created_by"],
+            "created_on": engine["created_on"],
+            "updated_on": engine["updated_on"],
+            "version": engine["version"],
+            "auto_suspend": engine["auto_suspend_mins"],
+            "suspends_at": engine["suspends_at"],
+        }
+        return engine_state
+
+    def _create_engine(
+        self,
+        name: str,
+        size: str | None = None,
+        auto_suspend_mins: int | None = None,
+        is_async: bool = False,
+        headers: Dict[str, str] | None = None
+    ):
+        # only async engine creation supported via direct access
+        if not is_async:
+            return super()._create_engine(name, size, auto_suspend_mins, is_async, headers=headers)
+        payload:Dict[str, Any] = {
+            "name": name,
+        }
+        if auto_suspend_mins is not None:
+            payload["auto_suspend_mins"] = auto_suspend_mins
+        if size is not None:
+            payload["size"] = size
+        response = self.request(
+            "create_engine",
+            payload=payload,
+            path_params={"engine_type": "logic"},
+            headers=headers,
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to create engine {name} with size {size}.", response
+            )
+
+    def delete_engine(self, name:str, force:bool = False, headers={}):
+        response = self.request(
+            "delete_engine",
+            path_params={"engine_name": name, "engine_type": "logic"},
+            headers=headers,
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to delete engine {name}.", response
+            )
+
+    def suspend_engine(self, name:str):
+        response = self.request(
+            "suspend_engine",
+            path_params={"engine_name": name, "engine_type": "logic"},
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to suspend engine {name}.", response
+            )
+
+    def resume_engine_async(self, name:str, headers={}):
+        response = self.request(
+            "resume_engine",
+            path_params={"engine_name": name, "engine_type": "logic"},
+            headers=headers,
+            skip_auto_create=True,
+        )
+        if response.status_code != 200:
+            raise ResponseStatusException(
+                f"Failed to resume engine {name}.", response
+            )
+        return {}
+
+    def _poll_use_index(
+        self,
+        app_name: str,
+        sources: Iterable[str],
+        model: str,
+        engine_name: str,
+        engine_size: str | None = None,
+        program_span_id: str | None = None,
+        headers: Dict | None = None,
+    ):
+        """Poll use_index to prepare indices for the given sources using DirectUseIndexPoller."""
+        return DirectUseIndexPoller(
+            self,
+            app_name=app_name,
+            sources=sources,
+            model=model,
+            engine_name=engine_name,
+            engine_size=engine_size,
+            language=self.language,
+            program_span_id=program_span_id,
+            headers=headers,
+            generation=self.generation,
+        ).poll()
+
+    def maybe_poll_use_index(
+        self,
+        app_name: str,
+        sources: Iterable[str],
+        model: str,
+        engine_name: str,
+        engine_size: str | None = None,
+        program_span_id: str | None = None,
+        headers: Dict | None = None,
+    ):
+        """Only call poll() if there are sources to process and cache is not valid."""
+        sources_list = list(sources)
+        self.database = model
+        if sources_list:
+            poller = DirectUseIndexPoller(
+                self,
+                app_name=app_name,
+                sources=sources_list,
+                model=model,
+                engine_name=engine_name,
+                engine_size=engine_size,
+                language=self.language,
+                program_span_id=program_span_id,
+                headers=headers,
+                generation=self.generation,
+            )
+            # If cache is valid (data freshness has not expired), skip polling
+            if poller.cache.is_valid():
+                cached_sources = len(poller.cache.sources)
+                total_sources = len(sources_list)
+                cached_timestamp = poller.cache._metadata.get("cachedIndices", {}).get(poller.cache.key, {}).get("last_use_index_update_on", "")
+
+                message = f"Using cached data for {cached_sources}/{total_sources} data streams"
+                if cached_timestamp:
+                    print(f"\n{message} (cached at {cached_timestamp})\n")
+                else:
+                    print(f"\n{message}\n")
+            else:
+                return poller.poll()