relationalai 0.12.2__py3-none-any.whl → 0.12.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/snowflake.py +117 -28
- relationalai/clients/use_index_poller.py +3 -0
- relationalai/experimental/solvers.py +18 -19
- relationalai/semantics/internal/snowflake.py +2 -3
- relationalai/semantics/lqp/executor.py +39 -9
- relationalai/semantics/lqp/model2lqp.py +0 -1
- relationalai/semantics/lqp/rewrite/extract_common.py +30 -8
- relationalai/semantics/metamodel/builtins.py +6 -6
- relationalai/semantics/metamodel/dependency.py +44 -21
- relationalai/semantics/metamodel/helpers.py +7 -6
- relationalai/semantics/metamodel/rewrite/extract_nested_logicals.py +1 -4
- relationalai/semantics/metamodel/rewrite/flatten.py +1 -13
- relationalai/semantics/reasoners/graph/core.py +803 -121
- relationalai/semantics/rel/executor.py +13 -6
- relationalai/semantics/sql/executor/snowflake.py +2 -2
- relationalai/semantics/std/math.py +2 -2
- {relationalai-0.12.2.dist-info → relationalai-0.12.4.dist-info}/METADATA +1 -1
- {relationalai-0.12.2.dist-info → relationalai-0.12.4.dist-info}/RECORD +21 -21
- {relationalai-0.12.2.dist-info → relationalai-0.12.4.dist-info}/WHEEL +0 -0
- {relationalai-0.12.2.dist-info → relationalai-0.12.4.dist-info}/entry_points.txt +0 -0
- {relationalai-0.12.2.dist-info → relationalai-0.12.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -83,6 +83,7 @@ FIELD_MAP = {
|
|
|
83
83
|
VALID_IMPORT_STATES = ["PENDING", "PROCESSING", "QUARANTINED", "LOADED"]
|
|
84
84
|
ENGINE_ERRORS = ["engine is suspended", "create/resume", "engine not found", "no engines found", "engine was deleted"]
|
|
85
85
|
ENGINE_NOT_READY_MSGS = ["engine is in pending", "engine is provisioning"]
|
|
86
|
+
DATABASE_ERRORS = ["database not found"]
|
|
86
87
|
PYREL_ROOT_DB = 'pyrel_root_db'
|
|
87
88
|
|
|
88
89
|
TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
|
|
@@ -281,6 +282,8 @@ def _sanitize_user_name(user: str) -> str:
|
|
|
281
282
|
def _is_engine_issue(response_message: str) -> bool:
|
|
282
283
|
return any(kw in response_message.lower() for kw in ENGINE_ERRORS + ENGINE_NOT_READY_MSGS)
|
|
283
284
|
|
|
285
|
+
def _is_database_issue(response_message: str) -> bool:
|
|
286
|
+
return any(kw in response_message.lower() for kw in DATABASE_ERRORS)
|
|
284
287
|
|
|
285
288
|
|
|
286
289
|
#--------------------------------------------------
|
|
@@ -298,6 +301,7 @@ class Resources(ResourcesBase):
|
|
|
298
301
|
dry_run: bool = False,
|
|
299
302
|
reset_session: bool = False,
|
|
300
303
|
generation: Generation | None = None,
|
|
304
|
+
language: str = "rel",
|
|
301
305
|
):
|
|
302
306
|
super().__init__(profile, config=config)
|
|
303
307
|
self._token_handler: TokenHandler | None = None
|
|
@@ -315,6 +319,8 @@ class Resources(ResourcesBase):
|
|
|
315
319
|
# self.sources contains fully qualified Snowflake table/view names
|
|
316
320
|
self.sources: set[str] = set()
|
|
317
321
|
self._sproc_models = None
|
|
322
|
+
self.database = ""
|
|
323
|
+
self.language = language
|
|
318
324
|
atexit.register(self.cancel_pending_transactions)
|
|
319
325
|
|
|
320
326
|
@property
|
|
@@ -452,6 +458,7 @@ class Resources(ResourcesBase):
|
|
|
452
458
|
rai_app = self.config.get("rai_app_name", "")
|
|
453
459
|
current_role = self.config.get("role")
|
|
454
460
|
engine = self.get_default_engine_name()
|
|
461
|
+
engine_size = self.config.get_default_engine_size()
|
|
455
462
|
assert isinstance(rai_app, str), f"rai_app_name must be a string, not {type(rai_app)}"
|
|
456
463
|
assert isinstance(engine, str), f"engine must be a string, not {type(engine)}"
|
|
457
464
|
print("\n")
|
|
@@ -460,9 +467,15 @@ class Resources(ResourcesBase):
|
|
|
460
467
|
if re.search(f"database '{rai_app}' does not exist or not authorized.".lower(), orig_message):
|
|
461
468
|
exception = SnowflakeAppMissingException(rai_app, current_role)
|
|
462
469
|
raise exception from None
|
|
463
|
-
if
|
|
470
|
+
if _is_engine_issue(orig_message) or _is_database_issue(orig_message):
|
|
464
471
|
try:
|
|
465
|
-
self.
|
|
472
|
+
self._poll_use_index(
|
|
473
|
+
app_name=self.get_app_name(),
|
|
474
|
+
sources=self.sources,
|
|
475
|
+
model=self.database,
|
|
476
|
+
engine_name=engine,
|
|
477
|
+
engine_size=engine_size
|
|
478
|
+
)
|
|
466
479
|
return self._exec(code, params, raw=raw, help=help)
|
|
467
480
|
except EngineNameValidationException as e:
|
|
468
481
|
raise EngineNameValidationException(engine) from e
|
|
@@ -767,7 +780,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
767
780
|
keep_database = not force and self.config.get("reuse_model", True)
|
|
768
781
|
with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
|
|
769
782
|
#TODO add headers to release_index
|
|
770
|
-
response = self._exec(f"call {APP_NAME}.api.release_index('{name}', OBJECT_CONSTRUCT('keep_database', {keep_database}, 'language', '{language}'));")
|
|
783
|
+
response = self._exec(f"call {APP_NAME}.api.release_index('{name}', OBJECT_CONSTRUCT('keep_database', {keep_database}, 'language', '{language}', 'user_agent', '{get_pyrel_version(self.generation)}'));")
|
|
771
784
|
if response:
|
|
772
785
|
result = next(iter(response))
|
|
773
786
|
obj = json.loads(result["RELEASE_INDEX"])
|
|
@@ -788,14 +801,13 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
788
801
|
headers = debugging.gen_current_propagation_headers()
|
|
789
802
|
self._exec(f"call {APP_NAME}.api.clone_database('{target_name}', '{source_name}', {nowait_durable}, {headers});")
|
|
790
803
|
|
|
791
|
-
def
|
|
804
|
+
def _poll_use_index(
|
|
792
805
|
self,
|
|
793
806
|
app_name: str,
|
|
794
807
|
sources: Iterable[str],
|
|
795
808
|
model: str,
|
|
796
809
|
engine_name: str,
|
|
797
810
|
engine_size: str | None = None,
|
|
798
|
-
language: str = "rel",
|
|
799
811
|
program_span_id: str | None = None,
|
|
800
812
|
headers: Dict | None = None,
|
|
801
813
|
):
|
|
@@ -806,12 +818,36 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
806
818
|
model,
|
|
807
819
|
engine_name,
|
|
808
820
|
engine_size,
|
|
809
|
-
language,
|
|
821
|
+
self.language,
|
|
810
822
|
program_span_id,
|
|
811
823
|
headers,
|
|
812
824
|
self.generation
|
|
813
825
|
).poll()
|
|
814
826
|
|
|
827
|
+
def maybe_poll_use_index(
|
|
828
|
+
self,
|
|
829
|
+
app_name: str,
|
|
830
|
+
sources: Iterable[str],
|
|
831
|
+
model: str,
|
|
832
|
+
engine_name: str,
|
|
833
|
+
engine_size: str | None = None,
|
|
834
|
+
program_span_id: str | None = None,
|
|
835
|
+
headers: Dict | None = None,
|
|
836
|
+
):
|
|
837
|
+
"""Only call _poll_use_index if there are sources to process."""
|
|
838
|
+
sources_list = list(sources)
|
|
839
|
+
self.database = model
|
|
840
|
+
if sources_list:
|
|
841
|
+
return self._poll_use_index(
|
|
842
|
+
app_name=app_name,
|
|
843
|
+
sources=sources_list,
|
|
844
|
+
model=model,
|
|
845
|
+
engine_name=engine_name,
|
|
846
|
+
engine_size=engine_size,
|
|
847
|
+
program_span_id=program_span_id,
|
|
848
|
+
headers=headers,
|
|
849
|
+
)
|
|
850
|
+
|
|
815
851
|
#--------------------------------------------------
|
|
816
852
|
# Models
|
|
817
853
|
#--------------------------------------------------
|
|
@@ -1868,9 +1904,19 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1868
1904
|
)
|
|
1869
1905
|
except Exception as e:
|
|
1870
1906
|
err_message = str(e).lower()
|
|
1871
|
-
if _is_engine_issue(err_message):
|
|
1872
|
-
self.
|
|
1873
|
-
self.
|
|
1907
|
+
if _is_engine_issue(err_message) or _is_database_issue(err_message):
|
|
1908
|
+
engine_name = engine or self.get_default_engine_name()
|
|
1909
|
+
engine_size = self.config.get_default_engine_size()
|
|
1910
|
+
self._poll_use_index(
|
|
1911
|
+
app_name=self.get_app_name(),
|
|
1912
|
+
sources=self.sources,
|
|
1913
|
+
model=database,
|
|
1914
|
+
engine_name=engine_name,
|
|
1915
|
+
engine_size=engine_size,
|
|
1916
|
+
headers=headers,
|
|
1917
|
+
)
|
|
1918
|
+
|
|
1919
|
+
return self._exec_async_v2(
|
|
1874
1920
|
database, engine, raw_code_b64, inputs, readonly, nowait_durable,
|
|
1875
1921
|
headers=headers, bypass_index=bypass_index, language='lqp',
|
|
1876
1922
|
query_timeout_mins=query_timeout_mins,
|
|
@@ -1908,8 +1954,17 @@ Otherwise, remove it from your '{profile}' configuration profile.
|
|
|
1908
1954
|
)
|
|
1909
1955
|
except Exception as e:
|
|
1910
1956
|
err_message = str(e).lower()
|
|
1911
|
-
if _is_engine_issue(err_message):
|
|
1912
|
-
self.
|
|
1957
|
+
if _is_engine_issue(err_message) or _is_database_issue(err_message):
|
|
1958
|
+
engine_name = engine or self.get_default_engine_name()
|
|
1959
|
+
engine_size = self.config.get_default_engine_size()
|
|
1960
|
+
self._poll_use_index(
|
|
1961
|
+
app_name=self.get_app_name(),
|
|
1962
|
+
sources=self.sources,
|
|
1963
|
+
model=database,
|
|
1964
|
+
engine_name=engine_name,
|
|
1965
|
+
engine_size=engine_size,
|
|
1966
|
+
headers=headers,
|
|
1967
|
+
)
|
|
1913
1968
|
return self._exec_async_v2(
|
|
1914
1969
|
database,
|
|
1915
1970
|
engine,
|
|
@@ -2972,13 +3027,12 @@ class SnowflakeClient(Client):
|
|
|
2972
3027
|
|
|
2973
3028
|
query_attrs_dict = json.loads(headers.get("X-Query-Attributes", "{}")) if headers else {}
|
|
2974
3029
|
with debugging.span("poll_use_index", sources=self.resources.sources, model=model, engine=engine_name, **query_attrs_dict):
|
|
2975
|
-
self.
|
|
3030
|
+
self.maybe_poll_use_index(
|
|
2976
3031
|
app_name=app_name,
|
|
2977
3032
|
sources=self.resources.sources,
|
|
2978
3033
|
model=model,
|
|
2979
3034
|
engine_name=engine_name,
|
|
2980
3035
|
engine_size=engine_size,
|
|
2981
|
-
language="rel",
|
|
2982
3036
|
program_span_id=program_span_id,
|
|
2983
3037
|
headers=headers
|
|
2984
3038
|
)
|
|
@@ -2989,29 +3043,24 @@ class SnowflakeClient(Client):
|
|
|
2989
3043
|
if isolated and not self.keep_model:
|
|
2990
3044
|
atexit.register(self.delete_database)
|
|
2991
3045
|
|
|
2992
|
-
|
|
2993
|
-
# if data is ready, break the loop
|
|
2994
|
-
# if data is not ready, print the status of the tables or engines
|
|
2995
|
-
# if data is not ready and there are errors, collect the errors and raise exceptions
|
|
2996
|
-
def poll_use_index(
|
|
3046
|
+
def maybe_poll_use_index(
|
|
2997
3047
|
self,
|
|
2998
3048
|
app_name: str,
|
|
2999
3049
|
sources: Iterable[str],
|
|
3000
3050
|
model: str,
|
|
3001
3051
|
engine_name: str,
|
|
3002
3052
|
engine_size: str | None = None,
|
|
3003
|
-
language: str = "rel",
|
|
3004
3053
|
program_span_id: str | None = None,
|
|
3005
3054
|
headers: Dict | None = None,
|
|
3006
3055
|
):
|
|
3056
|
+
"""Only call _poll_use_index if there are sources to process."""
|
|
3007
3057
|
assert isinstance(self.resources, Resources)
|
|
3008
|
-
return self.resources.
|
|
3058
|
+
return self.resources.maybe_poll_use_index(
|
|
3009
3059
|
app_name=app_name,
|
|
3010
3060
|
sources=sources,
|
|
3011
3061
|
model=model,
|
|
3012
3062
|
engine_name=engine_name,
|
|
3013
3063
|
engine_size=engine_size,
|
|
3014
|
-
language=language,
|
|
3015
3064
|
program_span_id=program_span_id,
|
|
3016
3065
|
headers=headers
|
|
3017
3066
|
)
|
|
@@ -3136,6 +3185,7 @@ class DirectAccessResources(Resources):
|
|
|
3136
3185
|
dry_run: bool = False,
|
|
3137
3186
|
reset_session: bool = False,
|
|
3138
3187
|
generation: Optional[Generation] = None,
|
|
3188
|
+
language: str = "rel",
|
|
3139
3189
|
):
|
|
3140
3190
|
super().__init__(
|
|
3141
3191
|
generation=generation,
|
|
@@ -3144,11 +3194,13 @@ class DirectAccessResources(Resources):
|
|
|
3144
3194
|
connection=connection,
|
|
3145
3195
|
reset_session=reset_session,
|
|
3146
3196
|
dry_run=dry_run,
|
|
3197
|
+
language=language,
|
|
3147
3198
|
)
|
|
3148
3199
|
self._endpoint_info = ConfigStore(ENDPOINT_FILE)
|
|
3149
3200
|
self._service_endpoint = ""
|
|
3150
3201
|
self._direct_access_client = None
|
|
3151
3202
|
self.generation = generation
|
|
3203
|
+
self.database = ""
|
|
3152
3204
|
|
|
3153
3205
|
@property
|
|
3154
3206
|
def service_endpoint(self) -> str:
|
|
@@ -3226,9 +3278,18 @@ class DirectAccessResources(Resources):
|
|
|
3226
3278
|
|
|
3227
3279
|
# fix engine on engine error and retry
|
|
3228
3280
|
# Skip auto-retry if skip_auto_create is True to avoid recursion
|
|
3229
|
-
if _is_engine_issue(message) and not skip_auto_create:
|
|
3230
|
-
|
|
3231
|
-
self.
|
|
3281
|
+
if (_is_engine_issue(message) and not skip_auto_create) or _is_database_issue(message):
|
|
3282
|
+
engine_name = payload.get("caller_engine_name", "") if payload else ""
|
|
3283
|
+
engine_name = engine_name or self.get_default_engine_name()
|
|
3284
|
+
engine_size = self.config.get_default_engine_size()
|
|
3285
|
+
self._poll_use_index(
|
|
3286
|
+
app_name=self.get_app_name(),
|
|
3287
|
+
sources=self.sources,
|
|
3288
|
+
model=self.database,
|
|
3289
|
+
engine_name=engine_name,
|
|
3290
|
+
engine_size=engine_size,
|
|
3291
|
+
headers=headers,
|
|
3292
|
+
)
|
|
3232
3293
|
response = _send_request()
|
|
3233
3294
|
except requests.exceptions.ConnectionError as e:
|
|
3234
3295
|
if "NameResolutionError" in str(e):
|
|
@@ -3356,14 +3417,13 @@ class DirectAccessResources(Resources):
|
|
|
3356
3417
|
|
|
3357
3418
|
return response.json()
|
|
3358
3419
|
|
|
3359
|
-
def
|
|
3420
|
+
def _poll_use_index(
|
|
3360
3421
|
self,
|
|
3361
3422
|
app_name: str,
|
|
3362
3423
|
sources: Iterable[str],
|
|
3363
3424
|
model: str,
|
|
3364
3425
|
engine_name: str,
|
|
3365
3426
|
engine_size: str | None = None,
|
|
3366
|
-
language: str = "rel",
|
|
3367
3427
|
program_span_id: str | None = None,
|
|
3368
3428
|
headers: Dict | None = None,
|
|
3369
3429
|
):
|
|
@@ -3374,12 +3434,36 @@ class DirectAccessResources(Resources):
|
|
|
3374
3434
|
model=model,
|
|
3375
3435
|
engine_name=engine_name,
|
|
3376
3436
|
engine_size=engine_size,
|
|
3377
|
-
language=language,
|
|
3437
|
+
language=self.language,
|
|
3378
3438
|
program_span_id=program_span_id,
|
|
3379
3439
|
headers=headers,
|
|
3380
3440
|
generation=self.generation,
|
|
3381
3441
|
).poll()
|
|
3382
3442
|
|
|
3443
|
+
def maybe_poll_use_index(
|
|
3444
|
+
self,
|
|
3445
|
+
app_name: str,
|
|
3446
|
+
sources: Iterable[str],
|
|
3447
|
+
model: str,
|
|
3448
|
+
engine_name: str,
|
|
3449
|
+
engine_size: str | None = None,
|
|
3450
|
+
program_span_id: str | None = None,
|
|
3451
|
+
headers: Dict | None = None,
|
|
3452
|
+
):
|
|
3453
|
+
"""Only call _poll_use_index if there are sources to process."""
|
|
3454
|
+
sources_list = list(sources)
|
|
3455
|
+
self.database = model
|
|
3456
|
+
if sources_list:
|
|
3457
|
+
return self._poll_use_index(
|
|
3458
|
+
app_name=app_name,
|
|
3459
|
+
sources=sources_list,
|
|
3460
|
+
model=model,
|
|
3461
|
+
engine_name=engine_name,
|
|
3462
|
+
engine_size=engine_size,
|
|
3463
|
+
program_span_id=program_span_id,
|
|
3464
|
+
headers=headers,
|
|
3465
|
+
)
|
|
3466
|
+
|
|
3383
3467
|
def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> bool:
|
|
3384
3468
|
"""Check whether the given transaction has completed."""
|
|
3385
3469
|
|
|
@@ -3522,7 +3606,12 @@ class DirectAccessResources(Resources):
|
|
|
3522
3606
|
with debugging.span("release_index", name=name, keep_database=keep_database, language=language):
|
|
3523
3607
|
response = self.request(
|
|
3524
3608
|
"release_index",
|
|
3525
|
-
payload={
|
|
3609
|
+
payload={
|
|
3610
|
+
"model_name": name,
|
|
3611
|
+
"keep_database": keep_database,
|
|
3612
|
+
"language": language,
|
|
3613
|
+
"user_agent": get_pyrel_version(self.generation),
|
|
3614
|
+
},
|
|
3526
3615
|
headers=prop_hdrs,
|
|
3527
3616
|
)
|
|
3528
3617
|
if (
|
|
@@ -261,20 +261,7 @@ class SolverModel:
|
|
|
261
261
|
remaining_timeout_minutes = calc_remaining_timeout_minutes(
|
|
262
262
|
start_time, query_timeout_mins, config_file_path=config_file_path
|
|
263
263
|
)
|
|
264
|
-
|
|
265
|
-
job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
|
|
266
|
-
except Exception as e:
|
|
267
|
-
err_message = str(e).lower()
|
|
268
|
-
if isinstance(e, ResponseStatusException):
|
|
269
|
-
err_message = e.response.json().get("message", "")
|
|
270
|
-
if any(kw in err_message.lower() for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS):
|
|
271
|
-
solver._auto_create_solver_async()
|
|
272
|
-
remaining_timeout_minutes = calc_remaining_timeout_minutes(
|
|
273
|
-
start_time, query_timeout_mins, config_file_path=config_file_path
|
|
274
|
-
)
|
|
275
|
-
job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
|
|
276
|
-
else:
|
|
277
|
-
raise e
|
|
264
|
+
job_id = solver._exec_job(payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes)
|
|
278
265
|
|
|
279
266
|
# 3. Extract result.
|
|
280
267
|
remaining_timeout_minutes = calc_remaining_timeout_minutes(
|
|
@@ -660,12 +647,24 @@ class Solver:
|
|
|
660
647
|
if self.engine is None:
|
|
661
648
|
raise Exception("Engine not initialized.")
|
|
662
649
|
|
|
663
|
-
# Make sure the engine is ready.
|
|
664
|
-
if self.engine["state"] != "READY":
|
|
665
|
-
poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
|
|
666
|
-
|
|
667
650
|
with debugging.span("job") as job_span:
|
|
668
|
-
|
|
651
|
+
# Retry logic. If creating a job fails with an engine
|
|
652
|
+
# related error we will create/resume/... the engine and
|
|
653
|
+
# retry.
|
|
654
|
+
try:
|
|
655
|
+
job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
|
|
656
|
+
except Exception as e:
|
|
657
|
+
err_message = str(e).lower()
|
|
658
|
+
if isinstance(e, ResponseStatusException):
|
|
659
|
+
err_message = e.response.json().get("message", "")
|
|
660
|
+
if any(kw in err_message.lower() for kw in ENGINE_ERRORS + WORKER_ERRORS + ENGINE_NOT_READY_MSGS):
|
|
661
|
+
self._auto_create_solver_async()
|
|
662
|
+
# Wait until the engine is ready.
|
|
663
|
+
poll_with_specified_overhead(lambda: self._is_solver_ready(), 0.1)
|
|
664
|
+
job_id = self.provider.create_job_async(self.engine["name"], payload, query_timeout_mins=query_timeout_mins)
|
|
665
|
+
else:
|
|
666
|
+
raise e
|
|
667
|
+
|
|
669
668
|
job_span["job_id"] = job_id
|
|
670
669
|
debugging.event("job_created", job_span, job_id=job_id, engine_name=self.engine["name"], job_type=ENGINE_TYPE_SOLVER)
|
|
671
670
|
if not isinstance(job_id, str):
|
|
@@ -264,12 +264,11 @@ class Table():
|
|
|
264
264
|
else:
|
|
265
265
|
me = self._rel._field_refs[0]
|
|
266
266
|
b.where(self).define(concept(me))
|
|
267
|
-
#
|
|
268
|
-
rel_func = b.Relationship if keys else b.Property
|
|
267
|
+
# All the fields are treated as properties
|
|
269
268
|
for field in self._rel._fields[1:]:
|
|
270
269
|
field_name = sanitize_identifier(field.name.lower())
|
|
271
270
|
if field_name not in key_dict:
|
|
272
|
-
r =
|
|
271
|
+
r = b.Property(
|
|
273
272
|
f"{{{concept}}} has {{{field_name}:{field.type_str}}}",
|
|
274
273
|
parent=concept,
|
|
275
274
|
short_name=field_name,
|
|
@@ -56,7 +56,15 @@ class LQPExecutor(e.Executor):
|
|
|
56
56
|
resource_class = rai.clients.snowflake.Resources
|
|
57
57
|
if self.config.get("use_direct_access", USE_DIRECT_ACCESS):
|
|
58
58
|
resource_class = rai.clients.snowflake.DirectAccessResources
|
|
59
|
-
|
|
59
|
+
# NOTE: language="lqp" is not strictly required for LQP execution, but it
|
|
60
|
+
# will significantly improve performance.
|
|
61
|
+
self._resources = resource_class(
|
|
62
|
+
dry_run=self.dry_run,
|
|
63
|
+
config=self.config,
|
|
64
|
+
generation=rai.Generation.QB,
|
|
65
|
+
connection=self.connection,
|
|
66
|
+
language="lqp",
|
|
67
|
+
)
|
|
60
68
|
if not self.dry_run:
|
|
61
69
|
self.engine = self._resources.get_default_engine_name()
|
|
62
70
|
if not self.keep_model:
|
|
@@ -88,13 +96,12 @@ class LQPExecutor(e.Executor):
|
|
|
88
96
|
assert self.engine is not None
|
|
89
97
|
|
|
90
98
|
with debugging.span("poll_use_index", sources=sources, model=model, engine=engine_name):
|
|
91
|
-
resources.
|
|
99
|
+
resources.maybe_poll_use_index(
|
|
92
100
|
app_name=app_name,
|
|
93
101
|
sources=sources,
|
|
94
102
|
model=model,
|
|
95
103
|
engine_name=self.engine,
|
|
96
104
|
engine_size=engine_size,
|
|
97
|
-
language="lqp",
|
|
98
105
|
program_span_id=program_span_id,
|
|
99
106
|
)
|
|
100
107
|
|
|
@@ -280,6 +287,8 @@ class LQPExecutor(e.Executor):
|
|
|
280
287
|
"""Construct an epoch that defines a number of built-in definitions used by the
|
|
281
288
|
emitter."""
|
|
282
289
|
with debugging.span("compile_intrinsics") as span:
|
|
290
|
+
span["compile_type"] = "intrinsics"
|
|
291
|
+
|
|
283
292
|
debug_info = lqp_ir.DebugInfo(id_to_orig_name={}, meta=None)
|
|
284
293
|
intrinsics_fragment = lqp_ir.Fragment(
|
|
285
294
|
id = lqp_ir.FragmentId(id=b"__pyrel_lqp_intrinsics", meta=None),
|
|
@@ -290,7 +299,7 @@ class LQPExecutor(e.Executor):
|
|
|
290
299
|
meta = None,
|
|
291
300
|
)
|
|
292
301
|
|
|
293
|
-
|
|
302
|
+
|
|
294
303
|
span["lqp"] = lqp_print.to_string(intrinsics_fragment, {"print_names": True, "print_debug": False, "print_csv_filename": False})
|
|
295
304
|
|
|
296
305
|
return lqp_ir.Epoch(
|
|
@@ -300,19 +309,41 @@ class LQPExecutor(e.Executor):
|
|
|
300
309
|
meta=None,
|
|
301
310
|
)
|
|
302
311
|
|
|
312
|
+
def _compile_undefine_query(self, query_epoch: lqp_ir.Epoch) -> lqp_ir.Epoch:
|
|
313
|
+
fragment_ids = []
|
|
314
|
+
|
|
315
|
+
for write in query_epoch.writes:
|
|
316
|
+
if isinstance(write.write_type, lqp_ir.Define):
|
|
317
|
+
fragment_ids.append(write.write_type.fragment.id)
|
|
318
|
+
|
|
319
|
+
# Construct new Epoch with Undefine operations for all collected fragment IDs
|
|
320
|
+
undefine_writes = [
|
|
321
|
+
lqp_ir.Write(
|
|
322
|
+
write_type=lqp_ir.Undefine(fragment_id=frag_id, meta=None),
|
|
323
|
+
meta=None
|
|
324
|
+
)
|
|
325
|
+
for frag_id in fragment_ids
|
|
326
|
+
]
|
|
327
|
+
|
|
328
|
+
return lqp_ir.Epoch(
|
|
329
|
+
writes=undefine_writes,
|
|
330
|
+
meta=None,
|
|
331
|
+
)
|
|
332
|
+
|
|
303
333
|
def compile_lqp(self, model: ir.Model, task: ir.Task):
|
|
304
334
|
configure = self._construct_configure()
|
|
305
335
|
|
|
306
336
|
model_txn = None
|
|
307
337
|
if self._last_model != model:
|
|
308
338
|
with debugging.span("compile", metamodel=model) as install_span:
|
|
339
|
+
install_span["compile_type"] = "model"
|
|
309
340
|
_, model_txn = self.compiler.compile(model, {"fragment_id": b"model"})
|
|
310
341
|
model_txn = txn_with_configure(model_txn, configure)
|
|
311
|
-
install_span["compile_type"] = "model"
|
|
312
342
|
install_span["lqp"] = lqp_print.to_string(model_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
|
|
313
343
|
self._last_model = model
|
|
314
344
|
|
|
315
345
|
with debugging.span("compile", metamodel=task) as compile_span:
|
|
346
|
+
compile_span["compile_type"] = "query"
|
|
316
347
|
query = f.compute_model(f.logical([task]))
|
|
317
348
|
options = {
|
|
318
349
|
"wide_outputs": self.wide_outputs,
|
|
@@ -321,7 +352,6 @@ class LQPExecutor(e.Executor):
|
|
|
321
352
|
result, final_model = self.compiler.compile_inner(query, options)
|
|
322
353
|
export_info, query_txn = result
|
|
323
354
|
query_txn = txn_with_configure(query_txn, configure)
|
|
324
|
-
compile_span["compile_type"] = "query"
|
|
325
355
|
compile_span["lqp"] = lqp_print.to_string(query_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
|
|
326
356
|
|
|
327
357
|
# Merge the epochs into a single transactions. Long term the query bits should all
|
|
@@ -334,11 +364,11 @@ class LQPExecutor(e.Executor):
|
|
|
334
364
|
if model_txn is not None:
|
|
335
365
|
epochs.append(model_txn.epochs[0])
|
|
336
366
|
|
|
337
|
-
|
|
367
|
+
query_txn_epoch = query_txn.epochs[0]
|
|
368
|
+
epochs.append(query_txn_epoch)
|
|
369
|
+
epochs.append(self._compile_undefine_query(query_txn_epoch))
|
|
338
370
|
|
|
339
371
|
txn = lqp_ir.Transaction(epochs=epochs, configure=configure, meta=None)
|
|
340
|
-
|
|
341
|
-
# Revalidate now that we've joined all the epochs.
|
|
342
372
|
validate_lqp(txn)
|
|
343
373
|
|
|
344
374
|
txn_proto = convert_transaction(txn)
|
|
@@ -2,8 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from typing import Optional
|
|
5
|
-
from relationalai.semantics.metamodel
|
|
6
|
-
from relationalai.semantics.metamodel import ir, factory as f, helpers
|
|
5
|
+
from relationalai.semantics.metamodel import ir, factory as f, helpers, visitor
|
|
7
6
|
from relationalai.semantics.metamodel.compiler import Pass, group_tasks
|
|
8
7
|
from relationalai.semantics.metamodel.util import OrderedSet, ordered_set
|
|
9
8
|
from relationalai.semantics.metamodel import dependency
|
|
@@ -52,7 +51,7 @@ class ExtractCommon(Pass):
|
|
|
52
51
|
#--------------------------------------------------
|
|
53
52
|
def rewrite(self, model: ir.Model, options:dict={}) -> ir.Model:
|
|
54
53
|
# create the context
|
|
55
|
-
ctx = ExtractCommon.Context(model)
|
|
54
|
+
ctx = ExtractCommon.Context(model, options)
|
|
56
55
|
|
|
57
56
|
# rewrite the root
|
|
58
57
|
replacement = self.handle(model.root, ctx)
|
|
@@ -76,9 +75,10 @@ class ExtractCommon(Pass):
|
|
|
76
75
|
#--------------------------------------------------
|
|
77
76
|
|
|
78
77
|
class Context():
|
|
79
|
-
def __init__(self, model: ir.Model):
|
|
78
|
+
def __init__(self, model: ir.Model, options: dict):
|
|
80
79
|
self.rewrite_ctx = helpers.RewriteContext()
|
|
81
80
|
self.info = dependency.analyze(model.root)
|
|
81
|
+
self.options = options
|
|
82
82
|
|
|
83
83
|
def handle(self, task: ir.Task, ctx: Context):
|
|
84
84
|
# currently we only extract if it's a sequence of Logicals, but we could in the
|
|
@@ -107,7 +107,7 @@ class ExtractCommon(Pass):
|
|
|
107
107
|
# extracted logic).
|
|
108
108
|
plan = None
|
|
109
109
|
if len(binders) > 1 and composites_and_effects:
|
|
110
|
-
extractables =
|
|
110
|
+
extractables = self._get_extractables(ctx, composites_and_effects)
|
|
111
111
|
# only makes sense to extract common if at least one nested composite will be
|
|
112
112
|
# extracted during Flatten
|
|
113
113
|
if extractables:
|
|
@@ -265,10 +265,12 @@ class ExtractCommon(Pass):
|
|
|
265
265
|
for child in common_body:
|
|
266
266
|
body_output_vars.update(ctx.info.task_outputs(child))
|
|
267
267
|
|
|
268
|
-
# Compute the union of input vars across all
|
|
268
|
+
# Compute the union of input vars across all non-extracted tasks (basically
|
|
269
|
+
# composites and binders left behind), intersected with output
|
|
269
270
|
# vars of the common body
|
|
270
271
|
exposed_vars = OrderedSet.from_iterable(ctx.info.task_inputs(sample)) & body_output_vars
|
|
271
|
-
|
|
272
|
+
non_extracted_tasks = (binders - common_body) | composites
|
|
273
|
+
for composite in non_extracted_tasks:
|
|
272
274
|
if composite is sample:
|
|
273
275
|
continue
|
|
274
276
|
# compute common input vars
|
|
@@ -323,7 +325,6 @@ class ExtractCommon(Pass):
|
|
|
323
325
|
|
|
324
326
|
return ExtractCommon.ExtractionPlan(common_body, remaining, exposed_vars, local_dependencies, distribute_common_reference)
|
|
325
327
|
|
|
326
|
-
|
|
327
328
|
def _compute_local_dependencies(self, ctx: Context, binders: OrderedSet[ir.Task], composite: ir.Task, exposed_vars: OrderedSet[ir.Var]):
|
|
328
329
|
"""
|
|
329
330
|
The tasks in common_body will be extracted into a logical that will expose the exposed_vars.
|
|
@@ -360,3 +361,24 @@ class ExtractCommon(Pass):
|
|
|
360
361
|
if inputs:
|
|
361
362
|
vars_needed.update(inputs - vars_exposed)
|
|
362
363
|
return local_body
|
|
364
|
+
|
|
365
|
+
def _get_extractables(self, ctx: Context, composites: OrderedSet[ir.Task]):
|
|
366
|
+
"""
|
|
367
|
+
Extractables are tasks that will eventually be extracted by the Flatten pass later.
|
|
368
|
+
Given a set of tasks, return the extractable ones.
|
|
369
|
+
"""
|
|
370
|
+
def _extractable(t: ir.Task):
|
|
371
|
+
# With GNF outputs (i.e., wide_outputs = False), the output tasks will be
|
|
372
|
+
# extracted into separate top-level single-column outputs.
|
|
373
|
+
if isinstance(t, ir.Output) and not ctx.options.get("wide_outputs", False):
|
|
374
|
+
return True
|
|
375
|
+
|
|
376
|
+
extractable_types = (ir.Update, ir.Aggregate, ir.Match, ir.Union, ir.Rank)
|
|
377
|
+
return isinstance(t, ir.Logical) and len(visitor.collect_by_type(extractable_types, t)) > 0
|
|
378
|
+
|
|
379
|
+
extractables = []
|
|
380
|
+
for t in composites:
|
|
381
|
+
if _extractable(t):
|
|
382
|
+
extractables.append(t)
|
|
383
|
+
|
|
384
|
+
return extractables
|
|
@@ -112,12 +112,12 @@ log10 = f.relation(
|
|
|
112
112
|
|
|
113
113
|
log = f.relation(
|
|
114
114
|
"log",
|
|
115
|
-
[f.input_field("
|
|
115
|
+
[f.input_field("base", types.Number), f.input_field("value", types.Number), f.field("result", types.Float)],
|
|
116
116
|
overloads=[
|
|
117
|
-
f.relation("log", [f.input_field("
|
|
118
|
-
f.relation("log", [f.input_field("
|
|
119
|
-
f.relation("log", [f.input_field("
|
|
120
|
-
f.relation("log", [f.input_field("
|
|
117
|
+
f.relation("log", [f.input_field("base", types.Int64), f.input_field("value", types.Int64), f.field("result", types.Float)]),
|
|
118
|
+
f.relation("log", [f.input_field("base", types.Int128), f.input_field("value", types.Int128), f.field("result", types.Float)]),
|
|
119
|
+
f.relation("log", [f.input_field("base", types.Float), f.input_field("value", types.Float), f.field("result", types.Float)]),
|
|
120
|
+
f.relation("log", [f.input_field("base", types.GenericDecimal), f.input_field("value", types.GenericDecimal), f.field("result", types.Float)]),
|
|
121
121
|
|
|
122
122
|
],
|
|
123
123
|
)
|
|
@@ -496,7 +496,7 @@ function = f.relation("function", [f.input_field("code", types.Symbol)])
|
|
|
496
496
|
function_checked_annotation = f.annotation(function, [f.lit("checked")])
|
|
497
497
|
function_annotation = f.annotation(function, [])
|
|
498
498
|
|
|
499
|
-
# Indicates this relation should be tracked in telemetry.
|
|
499
|
+
# Indicates this relation should be tracked in telemetry. Supported for Relationships and Concepts.
|
|
500
500
|
# `RAI_BackIR.with_relation_tracking` produces log messages at the start and end of each
|
|
501
501
|
# SCC evaluation, if any declarations bear the `track` annotation.
|
|
502
502
|
track = f.relation("track", [
|