relationalai 0.11.0__py3-none-any.whl → 0.11.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. relationalai/clients/azure.py +7 -4
  2. relationalai/clients/client.py +15 -14
  3. relationalai/clients/config.py +4 -0
  4. relationalai/clients/snowflake.py +70 -14
  5. relationalai/dsl.py +2 -2
  6. relationalai/early_access/dsl/codegen/weaver.py +1 -2
  7. relationalai/errors.py +37 -0
  8. relationalai/experimental/solvers.py +44 -14
  9. relationalai/semantics/devtools/extract_lqp.py +4 -1
  10. relationalai/semantics/internal/internal.py +212 -26
  11. relationalai/semantics/internal/snowflake.py +7 -5
  12. relationalai/semantics/lqp/executor.py +23 -4
  13. relationalai/semantics/lqp/model2lqp.py +53 -8
  14. relationalai/semantics/lqp/primitives.py +17 -1
  15. relationalai/semantics/lqp/result_helpers.py +1 -1
  16. relationalai/semantics/metamodel/builtins.py +137 -8
  17. relationalai/semantics/metamodel/executor.py +2 -2
  18. relationalai/semantics/metamodel/rewrite/extract_keys.py +16 -5
  19. relationalai/semantics/metamodel/typer/typer.py +23 -16
  20. relationalai/semantics/reasoners/__init__.py +0 -4
  21. relationalai/semantics/reasoners/graph/core.py +564 -36
  22. relationalai/semantics/rel/executor.py +15 -6
  23. relationalai/semantics/rel/rel_utils.py +17 -1
  24. relationalai/semantics/rel/rewrite/cdc.py +6 -0
  25. relationalai/semantics/sql/compiler.py +144 -123
  26. relationalai/semantics/sql/executor/duck_db.py +4 -2
  27. relationalai/semantics/sql/executor/snowflake.py +7 -3
  28. relationalai/semantics/sql/sql.py +14 -4
  29. relationalai/semantics/std/__init__.py +5 -3
  30. relationalai/semantics/std/dates.py +38 -14
  31. relationalai/semantics/std/math.py +35 -2
  32. relationalai/semantics/std/strings.py +12 -1
  33. relationalai/tools/cli_controls.py +9 -3
  34. relationalai/tools/constants.py +1 -0
  35. relationalai/tools/qb_debugger.py +22 -9
  36. relationalai/util/timeout.py +24 -0
  37. {relationalai-0.11.0.dist-info → relationalai-0.11.2.dist-info}/METADATA +1 -1
  38. {relationalai-0.11.0.dist-info → relationalai-0.11.2.dist-info}/RECORD +41 -40
  39. {relationalai-0.11.0.dist-info → relationalai-0.11.2.dist-info}/WHEEL +0 -0
  40. {relationalai-0.11.0.dist-info → relationalai-0.11.2.dist-info}/entry_points.txt +0 -0
  41. {relationalai-0.11.0.dist-info → relationalai-0.11.2.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@ from pandas import DataFrame
12
12
  from relationalai import debugging
13
13
  from relationalai.clients.util import poll_with_specified_overhead
14
14
 
15
- from ..errors import EngineNotFoundException, RAIException
15
+ from ..errors import EngineNotFoundException, RAIException, AzureUnsupportedQueryTimeoutException
16
16
  from ..rel_utils import assert_no_problems
17
17
  from ..loaders.loader import emit_delete_import, import_file, list_available_resources
18
18
  from .config import Config
@@ -227,10 +227,13 @@ class Resources(ResourcesBase):
227
227
  def _exec(self, code:str, params:List[Any]|Any|None = None, raw=False, help=True):
228
228
  raise Exception("Azure doesn't support _exec")
229
229
 
230
- def exec_lqp(self, database: str, engine: str | None, raw_code: bytes, readonly=True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None, bypass_index=False):
230
+ def exec_lqp(self, database: str, engine: str | None, raw_code: bytes, readonly=True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None, bypass_index=False, query_timeout_mins: int | None = None):
231
231
  raise Exception("Azure doesn't support exec_lqp")
232
232
 
233
- def exec_raw(self, database:str, engine:str|None, raw_code:str, readonly=True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None, raw_results=True):
233
+ def exec_raw(self, database:str, engine:str|None, raw_code:str, readonly=True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None, raw_results=True, query_timeout_mins: int | None = None):
234
+ if query_timeout_mins is not None or self.config.get("query_timeout_mins", None) is not None:
235
+ config_file_path = getattr(self.config, 'file_path', None)
236
+ raise AzureUnsupportedQueryTimeoutException(config_file_path=config_file_path)
234
237
  if engine is None:
235
238
  engine = self.get_default_engine_name()
236
239
  try:
@@ -289,7 +292,7 @@ class Resources(ResourcesBase):
289
292
  # Exec format
290
293
  #--------------------------------------------------
291
294
 
292
- def exec_format(self, database: str, engine: str, raw_code: str, task:m.Task, format:str, inputs: Dict | None = None, readonly: bool = True, nowait_durable=False, skip_invalid_data=False, headers: Dict | None = None) -> Any: # @FIXME: Better type annotation
295
+ def exec_format(self, database: str, engine: str, raw_code: str, cols:List[str], format:str, inputs: Dict | None = None, readonly: bool = True, nowait_durable=False, skip_invalid_data=False, headers: Dict | None = None, query_timeout_mins: int | None = None) -> Any: # @FIXME: Better type annotation
293
296
  raise Exception("Azure doesn't support alternative formats yet")
294
297
 
295
298
  def to_model_type(self, model:dsl.Graph, name: str, source:str):
@@ -3,7 +3,7 @@ import atexit
3
3
  from datetime import datetime, timedelta, timezone
4
4
  import re
5
5
  from collections import defaultdict
6
- from typing import Dict, List, Any, Tuple, cast, Callable
6
+ from typing import Dict, List, Any, Optional, Tuple, cast, Callable
7
7
 
8
8
  from abc import ABC, abstractmethod
9
9
  from dataclasses import dataclass
@@ -401,15 +401,15 @@ class ResourcesBase(ABC):
401
401
  pass
402
402
 
403
403
  @abstractmethod
404
- def exec_lqp(self, database: str, engine: str | None, raw_code: bytes, readonly: bool = True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None) -> Any: # @FIXME: Better type annotation
404
+ def exec_lqp(self, database: str, engine: str | None, raw_code: bytes, readonly: bool = True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None, query_timeout_mins: int | None = None) -> Any: # @FIXME: Better type annotation
405
405
  pass
406
406
 
407
407
  @abstractmethod
408
- def exec_raw(self, database: str, engine: str | None, raw_code: str, readonly: bool = True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None) -> Any: # @FIXME: Better type annotation
408
+ def exec_raw(self, database: str, engine: str | None, raw_code: str, readonly: bool = True, *, inputs: Dict | None = None, nowait_durable=False, headers: Dict | None = None, query_timeout_mins: Optional[int]=None) -> Any: # @FIXME: Better type annotation
409
409
  pass
410
410
 
411
411
  @abstractmethod
412
- def exec_format(self, database: str, engine: str, raw_code: str, task:m.Task, format:str, inputs: Dict | None = None, readonly: bool = True, nowait_durable=False, skip_invalid_data=False, headers: Dict | None = None) -> Any: # @FIXME: Better type annotation
412
+ def exec_format(self, database: str, engine: str, raw_code: str, cols:List[str], format:str, inputs: Dict | None = None, readonly: bool = True, nowait_durable=False, skip_invalid_data=False, headers: Dict | None = None, query_timeout_mins: Optional[int]=None) -> Any: # @FIXME: Better type annotation
413
413
  pass
414
414
 
415
415
  @abstractmethod
@@ -695,10 +695,10 @@ class Client():
695
695
  code = self.compiler.compile(dsl.build.raw_task(content))
696
696
  self._install_batch.set(path, code)
697
697
 
698
- def exec_raw(self, code:str, readonly=True, raw_results=True, inputs: Dict | None = None, internal=False, nowait_durable=None, abort_on_error=True, headers: Dict | None = None) -> DataFrame|Any:
698
+ def exec_raw(self, code:str, readonly=True, raw_results=True, inputs: Dict | None = None, internal=False, nowait_durable=None, abort_on_error=True, headers: Dict | None = None, query_timeout_mins: Optional[int]=None) -> DataFrame|Any:
699
699
  task = dsl.build.raw_task(code)
700
700
  debugging.set_source(task)
701
- return self.query(task, read_only=readonly, raw_results=raw_results, inputs=inputs, internal=internal, nowait_durable=nowait_durable, headers=headers, abort_on_error=abort_on_error)
701
+ return self.query(task, read_only=readonly, raw_results=raw_results, inputs=inputs, internal=internal, nowait_durable=nowait_durable, headers=headers, abort_on_error=abort_on_error, query_timeout_mins=query_timeout_mins)
702
702
 
703
703
  def exec_control(self, code:str, cb:Callable[[DataFrame]]|None=None):
704
704
  self._install_batch.control_items.append((code, cb))
@@ -719,12 +719,12 @@ class Client():
719
719
  # Query
720
720
  #--------------------------------------------------
721
721
 
722
- def _query(self, code:str, task:m.Task|None, end_span, readonly=False, inputs: Dict | None = None, nowait_durable=None, headers: Dict | None = None, abort_on_error=True):
722
+ def _query(self, code:str, task:m.Task|None, end_span, readonly=False, inputs: Dict | None = None, nowait_durable=None, headers: Dict | None = None, abort_on_error=True, query_timeout_mins: Optional[int]=None):
723
723
  if nowait_durable is None:
724
724
  nowait_durable = self.isolated
725
725
 
726
726
  try:
727
- results = self.resources.exec_raw(self._database, self.get_engine_name(), code, readonly=readonly, inputs=inputs, nowait_durable=nowait_durable, headers=headers)
727
+ results = self.resources.exec_raw(self._database, self.get_engine_name(), code, readonly=readonly, inputs=inputs, nowait_durable=nowait_durable, headers=headers, query_timeout_mins=query_timeout_mins)
728
728
  dataframe, errors = self.resources.format_results(results, task)
729
729
  end_span["results"] = dataframe
730
730
  end_span["errors"] = errors
@@ -736,13 +736,14 @@ class Client():
736
736
  engine_name = self.get_engine_name()
737
737
  self.resources.resume_engine(engine_name, headers=headers)
738
738
  # invoke _query again to retry the query
739
- return self._query(code, task, end_span, readonly=readonly, inputs=inputs, nowait_durable=nowait_durable, headers=headers, abort_on_error=abort_on_error)
739
+ return self._query(code, task, end_span, readonly=readonly, inputs=inputs, nowait_durable=nowait_durable, headers=headers, abort_on_error=abort_on_error, query_timeout_mins=query_timeout_mins)
740
740
  else:
741
741
  raise e
742
742
 
743
743
 
744
- def _query_format(self, code:str, task:m.Task, end_span, format, readonly=False, skip_invalid_data=False, inputs: Dict | None = None):
745
- results, raw = self.resources.exec_format(self._database, self.get_engine_name(), code, task, readonly=readonly, inputs=inputs, format=format, skip_invalid_data=skip_invalid_data)
744
+ def _query_format(self, code:str, task:m.Task, end_span, format, readonly=False, skip_invalid_data=False, inputs: Dict | None = None, query_timeout_mins: Optional[int]=None):
745
+ cols = task.return_cols(allow_dups=False)
746
+ results, raw = self.resources.exec_format(self._database, self.get_engine_name(), code, cols, readonly=readonly, inputs=inputs, format=format, skip_invalid_data=skip_invalid_data, query_timeout_mins=query_timeout_mins)
746
747
  errors = []
747
748
  if raw:
748
749
  dataframe, errors = self.resources.format_results(raw, task)
@@ -752,7 +753,7 @@ class Client():
752
753
  # return results if raw_results else dataframe
753
754
  return results, raw
754
755
 
755
- def query(self, task:m.Task, rentrant=False, read_only=False, raw_results=False, inputs: Dict | None = None, format="pandas", tag=None, nowait_durable=None, headers: Dict | None = None, internal=False, abort_on_error=True, skip_invalid_data = False) -> DataFrame|Any:
756
+ def query(self, task:m.Task, rentrant=False, read_only=False, raw_results=False, inputs: Dict | None = None, format="pandas", tag=None, nowait_durable=None, headers: Dict | None = None, internal=False, abort_on_error=True, skip_invalid_data = False, query_timeout_mins: Optional[int]=None) -> DataFrame|Any:
756
757
  if not self.dry_run and self.use_graph_index:
757
758
  self.create_database(isolated=self.isolated, headers=headers)
758
759
 
@@ -793,10 +794,10 @@ class Client():
793
794
 
794
795
  start = time.perf_counter()
795
796
  if format == "pandas":
796
- results, raw = self._query(code, task, end_span, readonly=read_only, inputs=inputs, nowait_durable=nowait_durable, headers=headers, abort_on_error=abort_on_error)
797
+ results, raw = self._query(code, task, end_span, readonly=read_only, inputs=inputs, nowait_durable=nowait_durable, headers=headers, abort_on_error=abort_on_error, query_timeout_mins=query_timeout_mins)
797
798
  debugging.time("query", time.perf_counter() - start, DataFrame() if raw_results else results, internal=internal, source_map=source_map)
798
799
  else:
799
- results, raw = self._query_format(code, task, end_span, readonly=read_only, inputs=inputs, format=format, skip_invalid_data=skip_invalid_data)
800
+ results, raw = self._query_format(code, task, end_span, readonly=read_only, inputs=inputs, format=format, skip_invalid_data=skip_invalid_data, query_timeout_mins=query_timeout_mins)
800
801
  debugging.time("query", time.perf_counter() - start, DataFrame(), source_map=source_map, alt_format_results=results)
801
802
 
802
803
  self._install_batch.clear_dirty()
@@ -28,6 +28,10 @@ PUBLIC_CONFIG_KEYS = [
28
28
  "ensure_change_tracking",
29
29
  "download_url_type",
30
30
  "use_direct_access",
31
+ # query_timeout_mins allows to specify a timeout in minutes applied to all queries. When
32
+ # a query execution time exceeds this timeout, the query will be aborted. This is useful
33
+ # to avoid long-running queries that can incur high costs.
34
+ "query_timeout_mins",
31
35
  ]
32
36
 
33
37
  #--------------------------------------------------
@@ -21,7 +21,7 @@ from relationalai.clients.use_index_poller import DirectUseIndexPoller, UseIndex
21
21
  import snowflake.snowpark
22
22
 
23
23
  from relationalai.rel_utils import sanitize_identifier, to_fqn_relation_name
24
- from relationalai.tools.constants import FIELD_PLACEHOLDER, RAI_APP_NAME, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, USE_DIRECT_ACCESS, WAIT_FOR_STREAM_SYNC, Generation
24
+ from relationalai.tools.constants import FIELD_PLACEHOLDER, RAI_APP_NAME, SNOWFLAKE_AUTHS, USE_GRAPH_INDEX, USE_DIRECT_ACCESS, DEFAULT_QUERY_TIMEOUT_MINS, WAIT_FOR_STREAM_SYNC, Generation
25
25
  from .. import std
26
26
  from collections import defaultdict
27
27
  import requests
@@ -1516,8 +1516,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
1516
1516
  request_headers: Dict | None = None,
1517
1517
  bypass_index=False,
1518
1518
  language: str = "rel",
1519
+ query_timeout_mins: int | None = None,
1519
1520
  ):
1520
1521
  assert language == "rel" or language == "lqp", "Only 'rel' and 'lqp' languages are supported"
1522
+ if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
1523
+ query_timeout_mins = int(timeout_value)
1521
1524
  # Depending on the shape of the input, the behavior of exec_async_v2 changes.
1522
1525
  # When using the new format (with an object), the function retrieves the
1523
1526
  # 'rai' database by hashing the model and username. In contrast, the
@@ -1526,9 +1529,23 @@ Otherwise, remove it from your '{profile}' configuration profile.
1526
1529
  # graph index to ensure the correct database is utilized.
1527
1530
  use_graph_index = self.config.get("use_graph_index", USE_GRAPH_INDEX)
1528
1531
  if use_graph_index and not bypass_index:
1529
- sql_string = f"CALL {APP_NAME}.api.exec_async_v2(?, {{'database': '{database}', 'engine': '{engine}', 'inputs': {inputs}, 'readonly': {readonly}, 'nowait_durable': {nowait_durable}, 'language': '{language}', 'headers': {request_headers}}});"
1532
+ payload = {
1533
+ 'database': database,
1534
+ 'engine': engine,
1535
+ 'inputs': inputs,
1536
+ 'readonly': readonly,
1537
+ 'nowait_durable': nowait_durable,
1538
+ 'language': language,
1539
+ 'headers': request_headers
1540
+ }
1541
+ if query_timeout_mins is not None:
1542
+ payload["timeout_mins"] = query_timeout_mins
1543
+ sql_string = f"CALL {APP_NAME}.api.exec_async_v2(?, {payload});"
1530
1544
  else:
1531
- sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {request_headers});"
1545
+ if query_timeout_mins is not None:
1546
+ sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {query_timeout_mins}, {request_headers});"
1547
+ else:
1548
+ sql_string = f"CALL {APP_NAME}.api.exec_async_v2('{database}','{engine}', ?, {inputs}, {readonly}, {nowait_durable}, '{language}', {request_headers});"
1532
1549
  response = self._exec(
1533
1550
  sql_string,
1534
1551
  raw_code,
@@ -1548,6 +1565,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1548
1565
  headers: Dict | None = None,
1549
1566
  bypass_index=False,
1550
1567
  language: str = "rel",
1568
+ query_timeout_mins: int | None = None,
1551
1569
  ):
1552
1570
  if inputs is None:
1553
1571
  inputs = {}
@@ -1567,6 +1585,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1567
1585
  request_headers=request_headers,
1568
1586
  bypass_index=bypass_index,
1569
1587
  language=language,
1588
+ query_timeout_mins=query_timeout_mins,
1570
1589
  )
1571
1590
 
1572
1591
  artifact_info = {}
@@ -1825,14 +1844,16 @@ Otherwise, remove it from your '{profile}' configuration profile.
1825
1844
  inputs: Dict | None = None,
1826
1845
  nowait_durable=False,
1827
1846
  headers: Dict | None = None,
1828
- bypass_index=False
1847
+ bypass_index=False,
1848
+ query_timeout_mins: int | None = None,
1829
1849
  ):
1830
1850
  raw_code_b64 = base64.b64encode(raw_code).decode("utf-8")
1831
1851
 
1832
1852
  try:
1833
1853
  return self._exec_async_v2(
1834
1854
  database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1835
- headers=headers, bypass_index=bypass_index, language='lqp'
1855
+ headers=headers, bypass_index=bypass_index, language='lqp',
1856
+ query_timeout_mins=query_timeout_mins,
1836
1857
  )
1837
1858
  except Exception as e:
1838
1859
  err_message = str(e).lower()
@@ -1840,7 +1861,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
1840
1861
  self.auto_create_engine(engine)
1841
1862
  self._exec_async_v2(
1842
1863
  database, engine, raw_code_b64, inputs, readonly, nowait_durable,
1843
- headers=headers, bypass_index=bypass_index, language='lqp'
1864
+ headers=headers, bypass_index=bypass_index, language='lqp',
1865
+ query_timeout_mins=query_timeout_mins,
1844
1866
  )
1845
1867
  else:
1846
1868
  raise e
@@ -1856,19 +1878,38 @@ Otherwise, remove it from your '{profile}' configuration profile.
1856
1878
  inputs: Dict | None = None,
1857
1879
  nowait_durable=False,
1858
1880
  headers: Dict | None = None,
1859
- bypass_index=False
1881
+ bypass_index=False,
1882
+ query_timeout_mins: int | None = None,
1860
1883
  ):
1861
1884
  raw_code = raw_code.replace("'", "\\'")
1862
1885
 
1863
1886
  try:
1864
1887
  return self._exec_async_v2(
1865
- database, engine, raw_code, inputs, readonly, nowait_durable, headers=headers, bypass_index=bypass_index
1888
+ database,
1889
+ engine,
1890
+ raw_code,
1891
+ inputs,
1892
+ readonly,
1893
+ nowait_durable,
1894
+ headers=headers,
1895
+ bypass_index=bypass_index,
1896
+ query_timeout_mins=query_timeout_mins,
1866
1897
  )
1867
1898
  except Exception as e:
1868
1899
  err_message = str(e).lower()
1869
1900
  if _is_engine_issue(err_message):
1870
1901
  self.auto_create_engine(engine)
1871
- return self._exec_async_v2(database, engine, raw_code, inputs, readonly, nowait_durable, headers=headers, bypass_index=bypass_index)
1902
+ return self._exec_async_v2(
1903
+ database,
1904
+ engine,
1905
+ raw_code,
1906
+ inputs,
1907
+ readonly,
1908
+ nowait_durable,
1909
+ headers=headers,
1910
+ bypass_index=bypass_index,
1911
+ query_timeout_mins=query_timeout_mins,
1912
+ )
1872
1913
  else:
1873
1914
  raise e
1874
1915
 
@@ -1885,13 +1926,14 @@ Otherwise, remove it from your '{profile}' configuration profile.
1885
1926
  database: str,
1886
1927
  engine: str,
1887
1928
  raw_code: str,
1888
- task: m.Task,
1929
+ cols: List[str],
1889
1930
  format: str,
1890
1931
  inputs: Dict | None = None,
1891
1932
  readonly=True,
1892
1933
  nowait_durable=False,
1893
1934
  skip_invalid_data=False,
1894
1935
  headers: Dict | None = None,
1936
+ query_timeout_mins: int | None = None,
1895
1937
  ):
1896
1938
  if inputs is None:
1897
1939
  inputs = {}
@@ -1899,6 +1941,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
1899
1941
  headers = {}
1900
1942
  if 'user-agent' not in headers:
1901
1943
  headers['user-agent'] = get_pyrel_version(self.generation)
1944
+ if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
1945
+ query_timeout_mins = int(timeout_value)
1902
1946
  # TODO: add headers
1903
1947
  start = time.perf_counter()
1904
1948
  output_table = "out" + str(uuid.uuid4()).replace("-", "_")
@@ -1909,18 +1953,25 @@ Otherwise, remove it from your '{profile}' configuration profile.
1909
1953
  col_names_map = None
1910
1954
  artifacts = None
1911
1955
  assert self._session
1912
- temp = self._session.createDataFrame([], StructType([StructField(name, StringType()) for name in task.return_cols(allow_dups=False)]))
1956
+ temp = self._session.createDataFrame([], StructType([StructField(name, StringType()) for name in cols]))
1913
1957
  with debugging.span("transaction") as txn_span:
1914
1958
  try:
1915
1959
  # In the graph index case we need to use the new exec_into_table proc as it obfuscates the db name
1916
1960
  with debugging.span("exec_format"):
1917
1961
  if use_graph_index:
1918
- res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
1962
+ # we do not provide a default value for query_timeout_mins so that we can control the default on app level
1963
+ if query_timeout_mins is not None:
1964
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
1965
+ else:
1966
+ res = self._exec(f"call {APP_NAME}.api.exec_into_table(?, ?, ?, ?, ?, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
1919
1967
  txn_id = json.loads(res[0]["EXEC_INTO_TABLE"])["rai_transaction_id"]
1920
1968
  rejected_rows = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows", [])
1921
1969
  rejected_rows_count = json.loads(res[0]["EXEC_INTO_TABLE"]).get("rejected_rows_count", 0)
1922
1970
  else:
1923
- res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
1971
+ if query_timeout_mins is not None:
1972
+ res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data, query_timeout_mins])
1973
+ else:
1974
+ res = self._exec(f"call {APP_NAME}.api.exec_into(?, ?, ?, ?, ?, {inputs}, ?, {headers}, ?);", [database, engine, raw_code, output_table, readonly, nowait_durable, skip_invalid_data])
1924
1975
  txn_id = json.loads(res[0]["EXEC_INTO"])["rai_transaction_id"]
1925
1976
  rejected_rows = json.loads(res[0]["EXEC_INTO"]).get("rejected_rows", [])
1926
1977
  rejected_rows_count = json.loads(res[0]["EXEC_INTO"]).get("rejected_rows_count", 0)
@@ -1932,7 +1983,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
1932
1983
  if out_sample:
1933
1984
  keys = set([k.lower() for k in out_sample[0].as_dict().keys()])
1934
1985
  col_names_map = {}
1935
- for ix, name in enumerate(task.return_cols(allow_dups=False)):
1986
+ for ix, name in enumerate(cols):
1936
1987
  col_key = f"col{ix:03}"
1937
1988
  if col_key in keys:
1938
1989
  col_names_map[col_key] = IdentityParser(name).identity
@@ -3095,6 +3146,7 @@ class DirectAccessResources(Resources):
3095
3146
  headers: Dict[str, str] | None = None,
3096
3147
  bypass_index=False,
3097
3148
  language: str = "rel",
3149
+ query_timeout_mins: int | None = None,
3098
3150
  ):
3099
3151
 
3100
3152
  with debugging.span("transaction") as txn_span:
@@ -3111,6 +3163,10 @@ class DirectAccessResources(Resources):
3111
3163
  "readonly": readonly,
3112
3164
  "language": language,
3113
3165
  }
3166
+ if query_timeout_mins is None and (timeout_value := self.config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
3167
+ query_timeout_mins = int(timeout_value)
3168
+ if query_timeout_mins is not None:
3169
+ payload["timeout_mins"] = query_timeout_mins
3114
3170
  query_params={"use_graph_index": str(use_graph_index and not bypass_index)}
3115
3171
 
3116
3172
  response = self.request(
relationalai/dsl.py CHANGED
@@ -1736,9 +1736,9 @@ class Graph:
1736
1736
  file_path = os.path.join(root, file)
1737
1737
  self._client.load_raw_file(file_path)
1738
1738
 
1739
- def exec_raw(self, code:str, readonly=False, raw_results=True, abort_on_error=True, inputs={}):
1739
+ def exec_raw(self, code:str, readonly=False, raw_results=True, abort_on_error=True, inputs={}, query_timeout_mins: Optional[int]=None):
1740
1740
  try:
1741
- return self._client.exec_raw(code, readonly=readonly, raw_results=raw_results, inputs=inputs, abort_on_error=abort_on_error)
1741
+ return self._client.exec_raw(code, readonly=readonly, raw_results=raw_results, inputs=inputs, abort_on_error=abort_on_error, query_timeout_mins=query_timeout_mins)
1742
1742
  except KeyboardInterrupt as e:
1743
1743
  print("Canceling transactions...")
1744
1744
  self.resources.cancel_pending_transactions()
@@ -103,8 +103,7 @@ class Weaver:
103
103
  # binding for each supertype with a reference scheme.
104
104
  #=
105
105
  constructor_binding = (
106
- self._binder.lookup_constructor_binding(concept, binding.column)
107
- if constructed_concept is concept
106
+ binding if constructed_concept is concept
108
107
  else self._binder.lookup_constructor_binding_by_source(constructed_concept, binding.column.table)
109
108
  )
110
109
 
relationalai/errors.py CHANGED
@@ -2395,3 +2395,40 @@ class UnsupportedColumnTypesWarning(RAIWarning):
2395
2395
 
2396
2396
  {note}
2397
2397
  """)
2398
+
2399
+ class QueryTimeoutExceededException(RAIException):
2400
+ def __init__(self, timeout_mins: int, config_file_path: str | None = None):
2401
+ self.timeout_mins = timeout_mins
2402
+ self.message = f"Query execution time exceeded the specified timeout of {timeout_mins} minutes."
2403
+ self.name = "Query Timeout Exceeded"
2404
+ self.config_file_path = config_file_path or ""
2405
+ self.content = self.format_message()
2406
+ super().__init__(self.message, self.name, self.content)
2407
+
2408
+ def format_message(self):
2409
+ return textwrap.dedent(f"""
2410
+ {self.message}
2411
+
2412
+ Consider increasing the 'query_timeout_mins' parameter in your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} to allow more time for query execution.
2413
+ """)
2414
+
2415
+
2416
+ #--------------------------------------------------
2417
+ # Azure Exceptions
2418
+ #--------------------------------------------------
2419
+
2420
+ class AzureUnsupportedQueryTimeoutException(RAIException):
2421
+ def __init__(self, config_file_path: str | None = None):
2422
+ self.message = "Query timeouts aren't supported on platform Azure."
2423
+ self.name = "Azure Unsupported Query Timeout Error"
2424
+ self.config_file_path = config_file_path or ""
2425
+ self.content = self.format_message()
2426
+ super().__init__(self.message, self.name, self.content)
2427
+
2428
+ def format_message(self):
2429
+ return textwrap.dedent(f"""
2430
+ {self.message}
2431
+
2432
+ Please remove the 'query_timeout_mins' from your configuration file{f' (stored in {self.config_file_path})' if self.config_file_path else ''} when running on platform Azure.
2433
+ """)
2434
+
@@ -1,17 +1,20 @@
1
1
  from __future__ import annotations
2
- from typing import Any, List
2
+ import time
3
+ from typing import Any, List, Optional
3
4
  from dataclasses import dataclass
4
5
  import textwrap
5
6
  from .. import dsl, std
6
7
  from ..std import rel
7
8
  from ..metamodel import Builtins
8
9
  from ..tools.cli_controls import Spinner
10
+ from ..tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
9
11
  from .. import debugging
10
12
  import uuid
11
13
  import relationalai
12
14
  import json
13
15
  from ..clients.util import poll_with_specified_overhead
14
16
  from ..clients.snowflake import Resources as SnowflakeResources
17
+ from ..util.timeout import calc_remaining_timeout_minutes
15
18
 
16
19
  rel_sv = rel._tagged(Builtins.SingleValued)
17
20
 
@@ -203,10 +206,20 @@ class SolverModel:
203
206
  payload["options"] = options
204
207
  payload["model_uri"] = sf_input_uri
205
208
 
209
+ rai_config = self.graph._config
210
+ query_timeout_mins = kwargs.get("query_timeout_mins", None)
211
+ if query_timeout_mins is None and (timeout_value := rai_config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
212
+ query_timeout_mins = int(timeout_value)
213
+ config_file_path = getattr(rai_config, 'file_path', None)
214
+ start_time = time.monotonic()
215
+ remaining_timeout_minutes = query_timeout_mins
206
216
  # 1. Materialize the model and store it.
207
217
  # TODO(coey) Currently we must run a dummy query to install the pyrel rules in a separate txn
208
218
  # to the solve_output updates. Ideally pyrel would offer an option to flush the rules separately.
209
- self.graph.exec_raw("")
219
+ self.graph.exec_raw("", query_timeout_mins=remaining_timeout_minutes)
220
+ remaining_timeout_minutes = calc_remaining_timeout_minutes(
221
+ start_time, query_timeout_mins, config_file_path=config_file_path,
222
+ )
210
223
  response = self.graph.exec_raw(
211
224
  textwrap.dedent(f"""
212
225
  @inline
@@ -235,7 +248,8 @@ class SolverModel:
235
248
  def config[:envelope, :payload, :data]: {scope}model_string
236
249
  def config[:envelope, :payload, :path]: "{model_uri}"
237
250
  def export {{ config }}
238
- """)
251
+ """),
252
+ query_timeout_mins=remaining_timeout_minutes,
239
253
  )
240
254
  txn = response.transaction or {}
241
255
  # The above `exec_raw` will throw an error if the transaction
@@ -248,17 +262,26 @@ class SolverModel:
248
262
  raise Exception(f"Transaction that materializes the solver inputs did not complete! ID: `{txn['id']}` State `{txn['state']}`")
249
263
 
250
264
  # 2. Execute job and wait for completion.
265
+ remaining_timeout_minutes = calc_remaining_timeout_minutes(
266
+ start_time, query_timeout_mins, config_file_path=config_file_path
267
+ )
251
268
  try:
252
- job_id = solver._exec_job(payload, log_to_console=log_to_console)
269
+ job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
253
270
  except Exception as e:
254
271
  err_message = str(e).lower()
255
272
  if any(kw in err_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS):
256
273
  solver._auto_create_solver_async()
257
- job_id = solver._exec_job(payload, log_to_console=log_to_console)
274
+ remaining_timeout_minutes = calc_remaining_timeout_minutes(
275
+ start_time, query_timeout_mins, config_file_path=config_file_path
276
+ )
277
+ job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
258
278
  else:
259
279
  raise e
260
280
 
261
281
  # 3. Extract result.
282
+ remaining_timeout_minutes = calc_remaining_timeout_minutes(
283
+ start_time, query_timeout_mins, config_file_path=config_file_path
284
+ )
262
285
  res = self.graph.exec_raw(
263
286
  textwrap.dedent(f"""
264
287
  ic result_not_empty("Solver result is empty.") requires not empty(result)
@@ -273,6 +296,7 @@ class SolverModel:
273
296
  def output[:solver_error]: {scope}solve_output[:"i_{self.solve_index}", :error]
274
297
  """),
275
298
  readonly=False,
299
+ query_timeout_mins=remaining_timeout_minutes,
276
300
  )
277
301
  errors = []
278
302
  for result in res.results:
@@ -506,10 +530,10 @@ class Solver:
506
530
  self.provider = Provider(resources=resources)
507
531
  self.solver_name = solver_name.lower()
508
532
 
509
- rai_config = self.provider.resources.config
533
+ self.rai_config = self.provider.resources.config
510
534
  settings: dict[str, Any] = {}
511
- if "experimental" in rai_config:
512
- exp_config = rai_config.get("experimental", {})
535
+ if "experimental" in self.rai_config:
536
+ exp_config = self.rai_config.get("experimental", {})
513
537
  if isinstance(exp_config, dict):
514
538
  if "solvers" in exp_config:
515
539
  settings = exp_config["solvers"].copy()
@@ -629,23 +653,29 @@ class Solver:
629
653
 
630
654
  self.engine = engine
631
655
 
632
- def _exec_job_async(self, payload):
656
+ def _exec_job_async(self, payload, query_timeout_mins: Optional[int]=None):
633
657
  payload_json = json.dumps(payload)
634
658
  engine_name = self.engine["name"]
635
- res = self.provider.resources._exec(
636
- textwrap.dedent(f"""
659
+ if query_timeout_mins is None and (timeout_value := self.rai_config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
660
+ query_timeout_mins = int(timeout_value)
661
+ if query_timeout_mins is not None:
662
+ sql_string = textwrap.dedent(f"""
663
+ CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}', null, {query_timeout_mins})
664
+ """)
665
+ else:
666
+ sql_string = textwrap.dedent(f"""
637
667
  CALL {APP_NAME}.experimental.exec_job_async('{ENGINE_TYPE_SOLVER}', '{engine_name}', '{payload_json}')
638
668
  """)
639
- )
669
+ res = self.provider.resources._exec(sql_string)
640
670
  return res[0]["ID"]
641
671
 
642
- def _exec_job(self, payload, log_to_console=True):
672
+ def _exec_job(self, payload, log_to_console=True, query_timeout_mins: Optional[int]=None):
643
673
  # Make sure the engine is ready.
644
674
  if self.engine["state"] != "READY":
645
675
  poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
646
676
 
647
677
  with debugging.span("job") as job_span:
648
- job_id = self._exec_job_async(payload)
678
+ job_id = self._exec_job_async(payload, query_timeout_mins=query_timeout_mins)
649
679
  job_span["job_id"] = job_id
650
680
  debugging.event("job_created", job_span, job_id=job_id, engine_name=self.engine["name"], job_type=ENGINE_TYPE_SOLVER)
651
681
  polling_state = PollingState(job_id, "", False, log_to_console)
@@ -9,7 +9,7 @@ from contextlib import contextmanager
9
9
 
10
10
  from relationalai.clients.snowflake import Resources as snowflake_api
11
11
  from relationalai.semantics.internal import internal
12
- from typing import Dict
12
+ from typing import Dict, Optional
13
13
 
14
14
  def main():
15
15
  parser = argparse.ArgumentParser(description="Extract LQP requests to run locally")
@@ -44,6 +44,7 @@ def instrumented_exec_rai_app(captured_calls, call_counter):
44
44
  request_headers=None,
45
45
  bypass_index=False,
46
46
  language: str = "rel",
47
+ query_timeout_mins: Optional[int] = None,
47
48
  ):
48
49
  result = original_exec_rai(
49
50
  self,
@@ -56,6 +57,7 @@ def instrumented_exec_rai_app(captured_calls, call_counter):
56
57
  request_headers=request_headers,
57
58
  bypass_index=bypass_index,
58
59
  language=language,
60
+ query_timeout_mins=query_timeout_mins,
59
61
  )
60
62
 
61
63
  call_counter[0] += 1
@@ -67,6 +69,7 @@ def instrumented_exec_rai_app(captured_calls, call_counter):
67
69
  "readonly": readonly,
68
70
  "nowait_durable": nowait_durable,
69
71
  "language": language,
72
+ "timeout_mins": query_timeout_mins,
70
73
  "raw_code": raw_code,
71
74
  }
72
75
  captured_calls.append(exec_call_json)