relationalai 0.13.4__py3-none-any.whl → 0.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/exec_txn_poller.py +51 -20
- relationalai/clients/local.py +15 -7
- relationalai/clients/resources/snowflake/__init__.py +2 -2
- relationalai/clients/resources/snowflake/direct_access_resources.py +8 -4
- relationalai/clients/resources/snowflake/snowflake.py +16 -11
- relationalai/experimental/solvers.py +8 -0
- relationalai/semantics/lqp/executor.py +3 -3
- relationalai/semantics/lqp/model2lqp.py +34 -28
- relationalai/semantics/lqp/passes.py +6 -3
- relationalai/semantics/lqp/result_helpers.py +76 -12
- relationalai/semantics/lqp/rewrite/__init__.py +2 -0
- relationalai/semantics/lqp/rewrite/extract_common.py +3 -1
- relationalai/semantics/lqp/rewrite/extract_keys.py +85 -20
- relationalai/semantics/lqp/rewrite/flatten_script.py +301 -0
- relationalai/semantics/lqp/rewrite/functional_dependencies.py +12 -7
- relationalai/semantics/lqp/rewrite/quantify_vars.py +12 -3
- relationalai/semantics/lqp/rewrite/unify_definitions.py +9 -3
- relationalai/semantics/metamodel/dependency.py +9 -0
- relationalai/semantics/metamodel/executor.py +17 -10
- relationalai/semantics/metamodel/rewrite/__init__.py +2 -1
- relationalai/semantics/metamodel/rewrite/flatten.py +1 -2
- relationalai/semantics/metamodel/rewrite/format_outputs.py +131 -46
- relationalai/semantics/metamodel/rewrite/handle_aggregations_and_ranks.py +237 -0
- relationalai/semantics/metamodel/typer/typer.py +1 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +101 -107
- relationalai/semantics/rel/compiler.py +7 -3
- relationalai/semantics/rel/executor.py +1 -1
- relationalai/tools/txn_progress.py +188 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/METADATA +1 -1
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/RECORD +33 -30
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/WHEEL +0 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/entry_points.txt +0 -0
- {relationalai-0.13.4.dist-info → relationalai-0.13.5.dist-info}/licenses/LICENSE +0 -0
relationalai/clients/exec_txn_poller.py
CHANGED
@@ -5,11 +5,14 @@ from typing import Dict, Optional, TYPE_CHECKING
 
 from relationalai import debugging
 from relationalai.clients.util import poll_with_specified_overhead
+from relationalai.clients.config import Config
 from relationalai.tools.cli_controls import create_progress
 from relationalai.util.format import format_duration
+from relationalai.tools.txn_progress import format_execution_tree
 
 if TYPE_CHECKING:
     from relationalai.clients.resources.snowflake import Resources
+    from relationalai.clients.resources.snowflake.snowflake import TxnStatusResponse
 
 # Polling behavior constants
 POLL_OVERHEAD_RATE = 0.1 # Overhead rate for exponential backoff
@@ -19,6 +22,14 @@ GREEN_COLOR = '\033[92m'
 GRAY_COLOR = '\033[90m'
 ENDC = '\033[0m'
 
+PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
+PRINT_INTERNAL_TXN_PROGRESS_FLAG = "print_txn_progress_internal"
+
+def should_print_txn_progress(config: Config) -> bool:
+    return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))
+
+def should_print_internal_txn_progress(config) -> bool:
+    return bool(config.get(PRINT_INTERNAL_TXN_PROGRESS_FLAG, False))
 
 class ExecTxnPoller:
     """
@@ -27,17 +38,19 @@ class ExecTxnPoller:
 
     def __init__(
         self,
-
-        resource: "Resources",
+        config: Config,
+        resource: Optional["Resources"] = None,
         txn_id: Optional[str] = None,
         headers: Optional[Dict] = None,
-        txn_start_time: Optional[float] = None
+        txn_start_time: Optional[float] = None
     ):
-        self.print_txn_progress =
+        self.print_txn_progress = should_print_txn_progress(config)
         self.res = resource
         self.txn_id = txn_id
         self.headers = headers or {}
         self.txn_start_time = txn_start_time or time.time()
+        self.print_internal_txn_progress = should_print_internal_txn_progress(config)
+        self.last_status: Optional[TxnStatusResponse] = None
 
     def __enter__(self) -> ExecTxnPoller:
         if not self.print_txn_progress:
@@ -53,17 +66,23 @@
         return self
 
     def __exit__(self, exc_type, exc_value, traceback) -> None:
-        if not self.print_txn_progress
+        if not self.print_txn_progress:
            return
        # Update to success message with duration
        total_duration = time.time() - self.txn_start_time
        txn_id = self.txn_id
        self.progress.update_main_status(
-            query_complete_message(txn_id, total_duration)
+            query_complete_message(txn_id, total_duration, internal_txn_progress=self._get_internal_progress())
        )
        self.progress.__exit__(exc_type, exc_value, traceback)
        return
 
+    def _get_internal_progress(self) -> Optional[Dict]:
+        """Get internal transaction progress if enabled and available."""
+        if self.print_internal_txn_progress and self.last_status:
+            return self.last_status.progress
+        return None
+
     def poll(self) -> bool:
         """
         Poll for transaction completion with interactive progress display.
@@ -79,44 +98,56 @@
         if self.print_txn_progress:
             # Update the main status to include the new txn_id
             self.progress.update_main_status_fn(
-                lambda: self.description_with_timing(txn_id),
+                lambda: self.description_with_timing(txn_id, self._get_internal_progress()),
             )
 
         # Don't show duration summary - we handle our own completion message
         def check_status() -> bool:
             """Check if transaction is complete."""
-
-
+            if self.res is None:
+                raise ValueError("Resource must be provided for polling.")
+            self.last_status = self.res._check_exec_async_status(txn_id, headers=self.headers)
+            return self.last_status.finished
 
-        with debugging.span("wait", txn_id=
+        with debugging.span("wait", txn_id=txn_id):
             poll_with_specified_overhead(check_status, overhead_rate=POLL_OVERHEAD_RATE)
 
-
         return True
 
-    def description_with_timing(self, txn_id: str | None = None) -> str:
+    def description_with_timing(self, txn_id: str | None = None, internal_txn_progress: Dict | None = None) -> str:
        elapsed = time.time() - self.txn_start_time
        if txn_id is None:
            return query_progress_header(elapsed)
        else:
-            return query_progress_message(txn_id, elapsed)
+            return query_progress_message(txn_id, elapsed, internal_txn_progress)
 
 def query_progress_header(duration: float) -> str:
     # Don't print sub-second decimals, because it updates too fast and is distracting.
     duration_str = format_duration(duration, seconds_decimals=False)
     return f"Evaluating Query... {duration_str:>15}\n"
 
-def query_progress_message(id: str, duration: float) -> str:
-
+def query_progress_message(id: str, duration: float, internal_txn_progress: Dict | None = None) -> str:
+    result = (
         query_progress_header(duration) +
         # Print with whitespace to align with the end of the transaction ID
         f"{GRAY_COLOR}ID: {id}{ENDC}"
     )
+    if internal_txn_progress is not None:
+        result += format_execution_tree(internal_txn_progress)
+    return result
 
-def query_complete_message(id: str, duration: float, status_header: bool = False) -> str:
-
+def query_complete_message(id: str | None, duration: float, status_header: bool = False, internal_txn_progress: Dict | None = None) -> str:
+    out = (
         (f"{GREEN_COLOR}✅ " if status_header else "") +
         # Print with whitespace to align with the end of the transaction ID
-        f"Query Complete: {format_duration(duration):>21}
-
-
+        f"Query Complete: {format_duration(duration):>21}"
+    )
+    if id is None:
+        out += ENDC
+    else:
+        out += f"\n{GRAY_COLOR}ID: {id}{ENDC}"
+
+    if internal_txn_progress is not None:
+        out += format_execution_tree(internal_txn_progress)
+
+    return out
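Note on the change above: the progress-display switches now live next to the poller and are read from the client configuration. A minimal sketch of the flag semantics, using a dict-backed stand-in for relationalai's Config (only the two flag names and the config.get(key, default) access pattern come from the diff; the FakeConfig class is illustrative, not part of the released code):

    class FakeConfig:
        # Hypothetical stand-in for relationalai's Config; assumed to expose .get(key, default).
        def __init__(self, values: dict):
            self._values = values

        def get(self, key, default=None):
            return self._values.get(key, default)

    config = FakeConfig({
        "print_txn_progress": True,             # PRINT_TXN_PROGRESS_FLAG
        "print_txn_progress_internal": False,   # PRINT_INTERNAL_TXN_PROGRESS_FLAG
    })

    assert bool(config.get("print_txn_progress", False)) is True
    assert bool(config.get("print_txn_progress_internal", False)) is False

With both flags unset the poller stays quiet as before; when the internal flag is set, the last transaction status's progress payload is additionally rendered through format_execution_tree.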
relationalai/clients/local.py
CHANGED
@@ -4,6 +4,7 @@ import base64
 import json
 from urllib.parse import quote, urlencode
 import pyarrow as pa
+import time
 import requests
 from email import message_from_bytes, policy
 from email.message import EmailMessage
@@ -18,6 +19,7 @@ from .config import Config
 from .types import TransactionAsyncResponse
 from .util import get_pyrel_version
 from ..errors import ResponseStatusException
+from ..clients.exec_txn_poller import ExecTxnPoller
 from .. import debugging
 
 @dataclass
@@ -112,7 +114,7 @@ class LocalResources(ResourcesBase):
 
     def reset(self):
         raise NotImplementedError("reset not supported in local mode")
-
+
     #--------------------------------------------------
     # Check direct access is enabled (0 implemented)
     #--------------------------------------------------
@@ -332,7 +334,7 @@
     #--------------------------------------------------
     # Exec Async
     #--------------------------------------------------
-
+
     def _parse_multipart_response(self, response: requests.Response) -> Dict[str, Any]:
         response_map = {}
         response_map['results'] = {}
@@ -464,11 +466,17 @@
             "readonly": readonly,
         }
 
-
-
-
-
-
+        txn_start_time = time.time()
+        with ExecTxnPoller(
+            self.config,
+            txn_id=None,
+            txn_start_time=txn_start_time
+        ) as _poller: # unused, except for __enter__ and __exit__ display
+            parsed_response = self._create_transaction(
+                target_endpoint="create_txn",
+                payload=payload,
+                headers=headers
+            )
 
         state = parsed_response["state"]
         if state not in ["COMPLETED", "ABORTED"]:
relationalai/clients/resources/snowflake/__init__.py
CHANGED
@@ -2,7 +2,7 @@
 Snowflake resources module.
 """
 # Import order matters - Resources must be imported first since other classes depend on it
-from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, PrimaryKey
+from .snowflake import Resources, Provider, Graph, SnowflakeClient, APP_NAME, PYREL_ROOT_DB, ExecContext, PrimaryKey
 from .engine_service import EngineType, INTERNAL_ENGINE_SIZES, ENGINE_SIZES_AWS, ENGINE_SIZES_AZURE
 # These imports depend on Resources, so they come after
 from .cli_resources import CLIResources
@@ -14,7 +14,7 @@ __all__ = [
     'Resources', 'DirectAccessResources', 'Provider', 'Graph', 'SnowflakeClient',
     'APP_NAME', 'PYREL_ROOT_DB', 'CLIResources', 'UseIndexResources', 'ExecContext', 'EngineType',
     'INTERNAL_ENGINE_SIZES', 'ENGINE_SIZES_AWS', 'ENGINE_SIZES_AZURE', 'PrimaryKey',
-    '
+    'create_resources_instance',
 ]
 
 
relationalai/clients/resources/snowflake/direct_access_resources.py
CHANGED
@@ -18,7 +18,7 @@ from snowflake.snowpark import Session
 
 # Import UseIndexResources to enable use_index functionality with direct access
 from .use_index_resources import UseIndexResources
-from .snowflake import TxnCreationResult
+from .snowflake import TxnCreationResult, TxnStatusResponse
 
 # Import helper functions from util
 from .util import is_engine_issue as _is_engine_issue, is_database_issue as _is_database_issue, collect_error_messages
@@ -314,7 +314,7 @@ class DirectAccessResources(UseIndexResources):
 
             return response.json()
 
-    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) ->
+    def _check_exec_async_status(self, txn_id: str, headers: Dict[str, str] | None = None) -> TxnStatusResponse:
         """Check whether the given transaction has completed."""
 
         with debugging.span("check_status"):
@@ -349,8 +349,12 @@
             elif reason == TXN_ABORT_REASON_GUARD_RAILS:
                 raise GuardRailsException(response_content.get("progress", {}))
 
-
-
+        return TxnStatusResponse(
+            txn_id=txn_id,
+            finished=status in ["COMPLETED", "ABORTED"],
+            abort_reason=response_content.get("abort_reason", None),
+            progress=response_content.get("progress", None),
+        )
 
     def _list_exec_async_artifacts(self, txn_id: str, headers: Dict[str, str] | None = None) -> Dict[str, Dict]:
         """Grab the list of artifacts produced in the transaction and the URLs to retrieve their contents."""
relationalai/clients/resources/snowflake/snowflake.py
CHANGED
@@ -15,7 +15,7 @@ import hashlib
 from dataclasses import dataclass
 
 from ....auth.token_handler import TokenHandler
-from
+from ....clients.exec_txn_poller import ExecTxnPoller
 import snowflake.snowpark
 
 from ....rel_utils import sanitize_identifier, to_fqn_relation_name
@@ -104,7 +104,6 @@ TERMINAL_TXN_STATES = ["COMPLETED", "ABORTED"]
 TXN_ABORT_REASON_TIMEOUT = "transaction timeout"
 GUARDRAILS_ABORT_REASON = "guard rail violation"
 
-PRINT_TXN_PROGRESS_FLAG = "print_txn_progress"
 ENABLE_GUARD_RAILS_FLAG = "enable_guard_rails"
 
 ENABLE_GUARD_RAILS_HEADER = "X-RAI-Enable-Guard-Rails"
@@ -113,9 +112,6 @@ ENABLE_GUARD_RAILS_HEADER = "X-RAI-Enable-Guard-Rails"
 # Helpers
 #--------------------------------------------------
 
-def should_print_txn_progress(config) -> bool:
-    return bool(config.get(PRINT_TXN_PROGRESS_FLAG, False))
-
 def should_enable_guard_rails(config) -> bool:
     return bool(config.get(ENABLE_GUARD_RAILS_FLAG, False))
 
@@ -157,6 +153,14 @@ class TxnCreationResult:
     artifact_info: Dict[str, Dict] # Populated if fast-path (state is COMPLETED/ABORTED)
 
 
+@dataclass
+class TxnStatusResponse:
+    """Transaction progress response for transaction status checks."""
+    txn_id: str
+    finished: bool
+    abort_reason: str | None = None
+    progress: Dict | None = None
+
 class Resources(ResourcesBase):
     def __init__(
         self,
@@ -1409,7 +1413,7 @@ Otherwise, remove it from your '{profile}' configuration profile.
     # Exec Async
     #--------------------------------------------------
 
-    def _check_exec_async_status(self, txn_id: str, headers: Dict | None = None):
+    def _check_exec_async_status(self, txn_id: str, headers: Dict | None = None) -> TxnStatusResponse:
         """Check whether the given transaction has completed."""
         if headers is None:
             headers = {}
@@ -1439,8 +1443,11 @@ Otherwise, remove it from your '{profile}' configuration profile.
             elif response_row.get("ABORT_REASON", "") == GUARDRAILS_ABORT_REASON:
                 raise GuardRailsException()
 
-
-
+        return TxnStatusResponse(
+            txn_id=txn_id,
+            finished=status in ["COMPLETED", "ABORTED"],
+            abort_reason=response_row.get("ABORT_REASON", None),
+        )
 
 
     def _list_exec_async_artifacts(self, txn_id: str, headers: Dict | None = None) -> Dict[str, Dict]:
@@ -1794,10 +1801,8 @@ Otherwise, remove it from your '{profile}' configuration profile.
 
         with debugging.span("transaction", **query_attrs_dict) as txn_span:
             txn_start_time = time.time()
-            print_txn_progress = should_print_txn_progress(self.config)
-
             with ExecTxnPoller(
-
+                config=self.config,
                 resource=self, txn_id=None, headers=request_headers,
                 txn_start_time=txn_start_time
             ) as poller:
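The new TxnStatusResponse dataclass above is what both _check_exec_async_status implementations now return, and what the poller keeps in last_status. A small, self-contained sketch of consuming such a value (the fetch_status helper and its contents are hypothetical; only the dataclass fields come from the diff):

    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass
    class TxnStatusResponse:
        # Fields mirror the dataclass added in this release.
        txn_id: str
        finished: bool
        abort_reason: Optional[str] = None
        progress: Optional[Dict] = None

    def fetch_status(txn_id: str) -> TxnStatusResponse:
        # Hypothetical stand-in for Resources._check_exec_async_status().
        return TxnStatusResponse(txn_id=txn_id, finished=True, progress={"stages": []})

    status = fetch_status("txn-123")
    if status.finished:
        # The poller stores the last status so a progress tree can be rendered on exit.
        print(status.txn_id, status.abort_reason, status.progress)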
relationalai/experimental/solvers.py
CHANGED
@@ -533,6 +533,14 @@ class Solver:
         self.engine_size = engine_size or settings.pop("engine_size", None)
         self.engine_auto_suspend_mins = auto_suspend_mins or settings.pop("auto_suspend_mins", None)
 
+        # Set default CSV store setting if not already configured
+        if "store" not in settings:
+            settings["store"] = {}
+        if "csv" not in settings["store"]:
+            settings["store"]["csv"] = {}
+        if "enabled" not in settings["store"]["csv"]:
+            settings["store"]["csv"]["enabled"] = True
+
         # The settings are used when creating a solver engine, they
         # may configure each individual solver.
         self.engine_settings = settings
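The added block defaults the solver engine's CSV store to enabled while leaving any explicit user setting untouched. An equivalent way to express the same guarded defaults, shown only as an illustration (the shipped code uses the explicit membership checks above, not setdefault):

    def apply_csv_store_default(settings: dict) -> dict:
        # Create missing levels, but never overwrite a value the caller already set.
        settings.setdefault("store", {}).setdefault("csv", {}).setdefault("enabled", True)
        return settings

    assert apply_csv_store_default({}) == {"store": {"csv": {"enabled": True}}}
    # An explicit opt-out is preserved:
    assert apply_csv_store_default({"store": {"csv": {"enabled": False}}}) == {"store": {"csv": {"enabled": False}}}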
relationalai/semantics/lqp/executor.py
CHANGED
@@ -460,8 +460,8 @@ class LQPExecutor(e.Executor):
         txid = raw_results.transaction['id']
 
         try:
-            cols, extra_cols = self._compute_cols(task, final_model)
-            df, errs = result_helpers.format_results(raw_results, cols)
+            cols, extra_cols, key_locs = self._compute_cols(task, final_model)
+            df, errs = result_helpers.format_results(raw_results, cols, key_locs)
             self.report_errors(errs)
 
             # Rename columns if wide outputs is enabled
@@ -488,7 +488,7 @@
                 return DataFrame([full_path], columns=["path"])
             else:
                 raise ValueError("The CSV export was not successful!")
-
+
             return self._postprocess_df(self.config, df, extra_cols)
 
         except Exception as e:
relationalai/semantics/lqp/model2lqp.py
CHANGED
@@ -126,37 +126,43 @@ def _translate_to_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Decla
 def _translate_to_constraint_decls(ctx: TranslationCtx, rule: ir.Logical) -> list[lqp.Declaration]:
     constraint_decls: list[lqp.Declaration] = []
     for task in rule.body:
-
-
-        assert fd is not None
-
-        # check for unresolved types
-        if any(types.is_any(var.type) for var in fd.keys + fd.values):
-            warn(f"Ignoring FD with unresolved type: {fd}")
+        if isinstance(task, ir.Logical):
+            constraint_decls.extend(_translate_to_constraint_decls(ctx, task))
             continue
-
-
-
-
-
-        lqp_guard = mk_abstraction(lqp_typed_vars, mk_and(lqp_guard_atoms))
-        lqp_keys:list[lqp.Var] = [var for (var, _) in lqp_typed_keys] # type: ignore
-        lqp_values:list[lqp.Var] = [var for (var, _) in lqp_typed_values] # type: ignore
-        lqp_id = utils.lqp_hash(fd.canonical_str)
-        lqp_name:lqp.RelationId = lqp.RelationId(id=lqp_id, meta=None)
-
-        fd_decl = lqp.FunctionalDependency(
-            name=lqp_name,
-            guard=lqp_guard,
-            keys=lqp_keys,
-            values=lqp_values,
-            meta=None
-        )
-
-        constraint_decls.append(fd_decl)
-
+        else:
+            assert isinstance(task, ir.Require)
+            decl = _translate_to_constraint_decl(ctx, task)
+            if decl is not None:
+                constraint_decls.append(decl)
     return constraint_decls
 
+def _translate_to_constraint_decl(ctx: TranslationCtx, rule: ir.Require) -> Optional[lqp.Declaration]:
+    fd = normalized_fd(rule)
+    assert fd is not None
+
+    # check for unresolved types
+    if any(types.is_any(var.type) for var in fd.keys + fd.values):
+        warn(f"Ignoring FD with unresolved type: {fd}")
+        return None
+
+    lqp_typed_keys = [_translate_term(ctx, key) for key in fd.keys]
+    lqp_typed_values = [_translate_term(ctx, value) for value in fd.values]
+    lqp_typed_vars:list[Tuple[lqp.Var, lqp.Type]] = lqp_typed_keys + lqp_typed_values # type: ignore
+    lqp_guard_atoms = [_translate_to_atom(ctx, atom) for atom in fd.guard]
+    lqp_guard = mk_abstraction(lqp_typed_vars, mk_and(lqp_guard_atoms))
+    lqp_keys:list[lqp.Var] = [var for (var, _) in lqp_typed_keys] # type: ignore
+    lqp_values:list[lqp.Var] = [var for (var, _) in lqp_typed_values] # type: ignore
+    lqp_id = utils.lqp_hash(fd.canonical_str)
+    lqp_name:lqp.RelationId = lqp.RelationId(id=lqp_id, meta=None)
+
+    return lqp.FunctionalDependency(
+        name=lqp_name,
+        guard=lqp_guard,
+        keys=lqp_keys,
+        values=lqp_values,
+        meta=None
+    )
+
 def _translate_algorithms(ctx: TranslationCtx, task: ir.Logical) -> list[lqp.Declaration]:
     assert is_algorithm_logical(task)
     decls: list[lqp.Declaration] = []
relationalai/semantics/lqp/passes.py
CHANGED
@@ -2,11 +2,12 @@ from relationalai.semantics.metamodel.compiler import Pass
 from relationalai.semantics.metamodel.typer import Checker, InferTypes
 
 from ..metamodel.rewrite import (
-    DNFUnionSplitter,
+    DNFUnionSplitter, Flatten, FormatOutputs, ExtractNestedLogicals,
+    # HandleAggregationsAndRanks
 )
 from .rewrite import (
     AlgorithmPass, AnnotateConstraints, CDC, ConstantsToVars, DeduplicateVars,
-    ExtractCommon, EliminateData, ExtractKeys, FunctionAnnotations, PeriodMath,
+    ExtractCommon, EliminateData, ExtractKeys, FlattenScript, FunctionAnnotations, PeriodMath,
     QuantifyVars, Splinter, SplitMultiCheckRequires, UnifyDefinitions,
 )
 
@@ -17,13 +18,15 @@ def lqp_passes() -> list[Pass]:
         AnnotateConstraints(),
         Checker(),
         CDC(), # specialize to physical relations before extracting nested and typing
-        ExtractNestedLogicals(),
+        ExtractNestedLogicals(),
         InferTypes(),
         DNFUnionSplitter(), # Handle unions that require DNF decomposition
         ExtractKeys(), # Create a logical for each valid combinations of keys
         FormatOutputs(),
         ExtractCommon(), # Extracts tasks that will become common after Flatten into their own definition
         Flatten(), # Move nested tasks to the top level, and various related things touched along the way
+        FlattenScript(), # Additional flattening specific to scripts
+        # HandleAggregationsAndRanks(), # Handle aggregation and rank dependencies
         Splinter(), # Splits multi-headed rules into multiple rules
         QuantifyVars(), # Adds missing existentials
         EliminateData(), # Turns Data nodes into ordinary relations.
relationalai/semantics/lqp/result_helpers.py
CHANGED
@@ -13,8 +13,12 @@ from relationalai.clients.result_helpers import format_columns, format_value, me
     sort_data_frame_result
 from relationalai.tools.constants import Generation
 
-
-
+# Convert LQP results into the expected single wide table dataframe for the end user
+# - Requires identifying and unrolling all GNF relations and populating any resulting nulls
+# - Relies on expected ordering of results, as we do not have IDs or names associated with columns here.
+# - At the end, col names are rewritten with the expected output names blindly; we trust the stitching
+#   has been done correctly for the names to line up with the correct columns
+def format_results(results, result_cols:List[str]|None = None, key_locations:List[int]|None = None) -> Tuple[DataFrame, List[Any]]:
     with debugging.span("format_results"):
         data_frame = DataFrame()
         problems = defaultdict(
@@ -37,8 +41,15 @@ def format_results(results, result_cols:List[str]|None = None) -> Tuple[DataFra
 
         # Check if there are any results to process
        if len(results.results):
-            ret_cols = result_cols or []
-
+            ret_cols = result_cols or [] # output column names
+            key_locations = key_locations or [] # where output keys are located in outputs
+            out_keys_n = len(key_locations) # number of keys in output
+            assert out_keys_n <= len(ret_cols)
+            out_vals_n = len(ret_cols) - out_keys_n # number of values in output
+
+            # only create cols for values, we handle keys separately as they are not GNF
+            has_cols:List[DataFrame] = [DataFrame() for _ in range(0, out_vals_n)]
+            keys_data_frame = DataFrame()
             key_len = 0
 
             for result in results.results:
@@ -46,7 +57,7 @@ def format_results(results, result_cols:List[str]|None = None) -> Tuple[DataFra
                 result_frame = result["table"].to_pandas()
                 types = [
                     t
-                    for t in
+                    for t in relation_id.split("/")
                     if t != "" and not t.startswith(":")
                 ]
 
@@ -168,13 +179,15 @@ def format_results(results, result_cols:List[str]|None = None) -> Tuple[DataFra
                 else:
                     result_frame = format_columns(result_frame, types, Generation.QB)
                 result["table"] = result_frame
-                if "/:output" in
-                    and "_cols_col" in
+                if "/:output" in relation_id \
+                    and "_cols_col" in relation_id:
                     # Match rows with an id like "/:output.*_cols_col[0-9]+"
-
-
+                    # These should be all of the GNF value outputs
+                    matched = re.search(r"_cols_col([0-9]+)", relation_id)
+                    assert matched, f"Column id not found for: {relation_id}"
                     col_ix = int(matched.group(1))
 
+                    # Generate col names and write them into the df (idn for keys, vn for cols)
                     key_cols = [f"id{i}" for i in range(0, len(result_frame.columns) - 1)]
                     key_len = len(key_cols)
                     result_frame.columns = [*key_cols, f"v{col_ix}"]
@@ -183,20 +196,71 @@ def format_results(results, result_cols:List[str]|None = None) -> Tuple[DataFra
                         has_cols[col_ix] = result_frame
                     else:
                         has_cols[col_ix] = pd.concat([has_cols[col_ix], result_frame], ignore_index=True)
-                elif ":output" in
+                elif ":output" in relation_id \
+                    and "_keys" in relation_id:
+                    # data for all keys (wide), to merge in later
+                    keys_data_frame = result_frame
+
+                    # Rename wide key col names to match key cols in df_wide_reset
+                    keys_data_frame.columns = pd.RangeIndex(len(keys_data_frame.columns))
+                    keys_data_frame = keys_data_frame.rename(columns=lambda c: f"id{c}")
+
+                elif ":output" in relation_id: # wide outputs case
                     data_frame = pd.concat(
                         [data_frame, result_frame], ignore_index=True
                     )
 
+            # GNF values case: stitch together output vals and keys into one wide dataframe
             if any(not col.empty for col in has_cols):
+                # Merge value cols together by their key cols
                 key_cols = [f"id{i}" for i in range(0, key_len)]
                 df_wide_reset = reduce(lambda left, right: merge_columns(left, right, key_cols), has_cols)
-
+
+                # Join wide keys with wide vals (keys all at the front; still needs reordering)
+                data_frame = pd.merge(keys_data_frame, df_wide_reset, on=key_cols, how='outer')
+
+                # Reorder outputs
+                if key_locations:
+                    data_frame = _shift_keys(data_frame, keys_data_frame, out_keys_n, key_locations)
+
+                else: # if no keys in output, just drop all of the key cols
+                    data_frame = data_frame.drop(columns=key_cols)
+
+            # Empty values case: reorder/drop keys as needed
+            elif not keys_data_frame.empty:
+                if key_locations: # Reorder outputs
+                    data_frame = _shift_keys(keys_data_frame, keys_data_frame, out_keys_n, key_locations)
+                else: # if there are no keys to output, we may still need to populate nulls for output values
+                    # Take into account the cols that could contain values (even though they're empty)
+                    key_cols = [f"id{i}" for i in range(0, len(keys_data_frame.columns))]
+                    has_cols.append(keys_data_frame) # include the keys so we know how many nulls to generate
+                    df_wide_reset = reduce(lambda left, right: merge_columns(left, right, key_cols), has_cols)
+                    data_frame = df_wide_reset.drop(columns=key_cols)
 
             data_frame = sort_data_frame_result(data_frame)
 
-
+            # Overwrite column names with user-defined names
+            # The assumption is that the extra keys have been chopped off the front, and the
+            #   remaining columns are in the correct order and require renaming
+            if len(ret_cols) and len(data_frame.columns) <= len(ret_cols):
                 if result_cols is not None:
                     data_frame.columns = result_cols[: len(data_frame.columns)]
 
     return (data_frame, list(problems.values()))
+
+# Reorder `res` df to match user-specified output order and drop non-output key cols
+# E.g., Current df looks like:
+#   [out_key_pos1, out_key_pos4, hidden_key_1, value_pos2, value_pos3, value_pos5]
+# Target df looks like:
+#   [out_key_pos1, value_pos2, value_pos3, out_key_pos4, value_pos5]
+def _shift_keys(res:DataFrame, keys_data_frame:DataFrame, out_keys_n:int, key_locations:List[int]):
+    offset = len(keys_data_frame.columns) # index of first value in data frame
+    assert out_keys_n <= offset
+    extra_keys_n = offset - out_keys_n # number of keys to drop (those not in output)
+
+    # Shift output keys into the correct spots and drop the rest
+    for i in key_locations:
+        res = res[res.columns[1:].insert(i + offset - 1, res.columns[0])]
+        offset -= 1
+    res = res.drop(columns=res.columns[:extra_keys_n])
+    return res
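The comments on _shift_keys above spell out the target column layout. A toy pandas sketch of that reordering contract on made-up data (column names and values are invented for illustration; this does not reuse the library's merge_columns/_shift_keys code):

    import pandas as pd

    # "Keys first" layout produced by the merge step:
    # two output keys, one hidden key, then the value columns.
    df = pd.DataFrame({
        "out_key_pos1": [1, 2],
        "out_key_pos4": ["a", "b"],
        "hidden_key_1": [10, 20],
        "value_pos2":   [0.1, 0.2],
        "value_pos3":   [0.3, 0.4],
        "value_pos5":   [0.5, 0.6],
    })

    # Target layout: output keys moved to their user-requested positions,
    # non-output (hidden) keys dropped.
    target_order = ["out_key_pos1", "value_pos2", "value_pos3", "out_key_pos4", "value_pos5"]
    reordered = df[target_order]
    print(list(reordered.columns))
    # ['out_key_pos1', 'value_pos2', 'value_pos3', 'out_key_pos4', 'value_pos5']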
relationalai/semantics/lqp/rewrite/__init__.py
CHANGED
@@ -7,6 +7,7 @@ from .eliminate_data import EliminateData
 from .extract_common import ExtractCommon
 from .extract_keys import ExtractKeys
 from .function_annotations import FunctionAnnotations, SplitMultiCheckRequires
+from .flatten_script import FlattenScript
 from .period_math import PeriodMath
 from .quantify_vars import QuantifyVars
 from .splinter import Splinter
@@ -22,6 +23,7 @@ __all__ = [
     "ExtractCommon",
     "ExtractKeys",
     "FunctionAnnotations",
+    "FlattenScript",
     "PeriodMath",
     "QuantifyVars",
     "Splinter",
relationalai/semantics/lqp/rewrite/extract_common.py
CHANGED
@@ -315,8 +315,10 @@ def _compute_local_dependencies(ctx: ExtractCommon.Context, binders: OrderedSet[
     return local_body
 
 def _is_binder(task: ir.Task):
+    binder_types = (ir.Lookup, ir.Construct, ir.Exists, ir.Data, ir.Not)
+
     # If the task itself is a binder
-    if
+    if isinstance(task, binder_types):
         return True
 
     # If the task is a Logical containing only binders
|