relationalai 0.11.3__py3-none-any.whl → 0.11.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/snowflake.py +6 -1
- relationalai/clients/use_index_poller.py +349 -188
- relationalai/early_access/dsl/bindings/csv.py +2 -2
- relationalai/semantics/internal/internal.py +22 -4
- relationalai/semantics/lqp/executor.py +61 -12
- relationalai/semantics/lqp/intrinsics.py +23 -0
- relationalai/semantics/lqp/model2lqp.py +13 -4
- relationalai/semantics/lqp/passes.py +2 -3
- relationalai/semantics/lqp/primitives.py +12 -1
- relationalai/semantics/metamodel/builtins.py +8 -1
- relationalai/semantics/metamodel/factory.py +3 -2
- relationalai/semantics/reasoners/graph/core.py +54 -2
- relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
- relationalai/semantics/rel/compiler.py +5 -17
- relationalai/semantics/rel/executor.py +2 -2
- relationalai/semantics/rel/rel.py +6 -0
- relationalai/semantics/rel/rel_utils.py +8 -1
- relationalai/semantics/rel/rewrite/extract_common.py +153 -242
- relationalai/semantics/sql/compiler.py +120 -39
- relationalai/semantics/sql/executor/duck_db.py +21 -0
- relationalai/semantics/sql/rewrite/denormalize.py +4 -6
- relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
- relationalai/semantics/sql/sql.py +27 -0
- relationalai/semantics/std/__init__.py +2 -1
- relationalai/semantics/std/datetime.py +4 -0
- relationalai/semantics/std/re.py +83 -0
- relationalai/semantics/std/strings.py +1 -1
- relationalai/tools/cli_controls.py +445 -60
- relationalai/util/format.py +78 -1
- {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/METADATA +3 -2
- {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/RECORD +35 -33
- {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/WHEEL +0 -0
- {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/entry_points.txt +0 -0
- {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -66,6 +66,22 @@ class RelationInfo:
|
|
|
66
66
|
table_selects: list[sql.Select] = field(default_factory=list)
|
|
67
67
|
dynamic_table_selects: list[sql.Select] = field(default_factory=list)
|
|
68
68
|
|
|
69
|
+
@dataclass
class ImportSpec:
    """One import line to place at the top of a generated Python UDF body.

    With ``module`` unset (or empty) this renders ``import <value>``; otherwise
    it renders ``from <module> import <value>``.
    """

    # Name being imported: a module, or a symbol when ``module`` is given.
    value: str
    # Module to import ``value`` from, e.g. "scipy.special".
    module: Optional[str] = None

    def render(self) -> str:
        """Return the import statement as a single line of Python source."""
        if not self.module:
            return f"import {self.value}"
        return f"from {self.module} import {self.value}"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class UDFConfig:
    """Metadata for a builtin that is compiled to a Python user-defined function.

    ``code`` defines the function named by ``handler``. Each entry of ``imports``
    is rendered as an import line placed in front of ``code``. ``packages`` is
    passed through to the generated ``CREATE FUNCTION`` statement.
    """

    # Name of the Python function (defined in ``code``) used as the UDF entry point.
    handler: str
    # Python source of the handler function.
    code: str
    # Import statements prepended to ``code`` in the UDF body.
    imports: list[ImportSpec] = field(default_factory=list)
    # Package names for the PACKAGES clause, written as they should appear in SQL.
    packages: list[str] = field(default_factory=list)
|
|
84
|
+
|
|
69
85
|
@dataclass
|
|
70
86
|
class ModelToSQL:
|
|
71
87
|
""" Generates SQL from an IR Model, assuming the compiler rewrites were done. """
|
|
@@ -89,9 +105,9 @@ class ModelToSQL:
|
|
|
89
105
|
return sql.Program(self._sort_dependencies(self._union_output_selects(self._generate_statements(model))))
|
|
90
106
|
|
|
91
107
|
def _generate_statements(self, model: ir.Model) -> list[sql.Node]:
|
|
92
|
-
|
|
108
|
+
table_relations, used_builtins = self._get_relations(model)
|
|
93
109
|
|
|
94
|
-
self._register_relation_args(
|
|
110
|
+
self._register_relation_args(table_relations)
|
|
95
111
|
self._register_external_relations(model)
|
|
96
112
|
|
|
97
113
|
statements: list[sql.Node] = []
|
|
@@ -139,11 +155,15 @@ class ModelToSQL:
|
|
|
139
155
|
)
|
|
140
156
|
|
|
141
157
|
# 4. Create physical tables for explicitly declared table relations
|
|
142
|
-
for relation in
|
|
158
|
+
for relation in table_relations:
|
|
143
159
|
info = self.relation_infos.get(relation)
|
|
144
160
|
if info is None or info.table_selects:
|
|
145
161
|
statements.append(self._create_table(relation))
|
|
146
162
|
|
|
163
|
+
#5. Create Snowflake user-defined functions
|
|
164
|
+
if not self._is_duck_db:
|
|
165
|
+
statements.extend(self._create_user_defined_functions(used_builtins))
|
|
166
|
+
|
|
147
167
|
return statements
|
|
148
168
|
|
|
149
169
|
#--------------------------------------------------
|
|
@@ -204,6 +224,62 @@ class ModelToSQL:
|
|
|
204
224
|
)
|
|
205
225
|
)
|
|
206
226
|
|
|
227
|
+
def _create_user_defined_functions(self, relations: list[ir.Relation]) -> list[sql.CreateFunction]:
    """Create ``CREATE FUNCTION`` statements for builtins implemented as Python UDFs.

    Only relations whose name has an entry in the UDF table below produce a
    statement; every other relation in ``relations`` is skipped.

    Args:
        relations: builtin relations used by the model.

    Returns:
        One ``sql.CreateFunction`` per relation that has a UDF definition, in the
        order the relations were given.

    Raises:
        AssertionError: if a UDF relation does not have exactly one output field.
    """
    # Central UDF metadata configuration, keyed by builtin relation name.
    udf_relations: dict[str, UDFConfig] = {
        builtins.acot.name: UDFConfig(
            handler="compute",
            imports=[ImportSpec("math")],
            code="""def compute(x): return math.atan(1 / x) if x != 0 else math.copysign(math.pi / 2, x)"""
        ),
        builtins.erf.name: UDFConfig(
            handler="compute",
            imports=[ImportSpec("math")],
            code="""def compute(x): return math.erf(x)"""
        ),
        builtins.erfinv.name: UDFConfig(
            handler="compute",
            imports=[ImportSpec("erfinv", module="scipy.special")],
            # Package names appear to be emitted verbatim into PACKAGES = (...),
            # so they carry their own SQL quotes.
            packages=["'scipy'"],
            code="""def compute(x): return erfinv(x)"""
        ),
    }

    statements: list[sql.CreateFunction] = []

    for r in relations:
        meta = udf_relations.get(r.name)
        if meta is None:
            continue

        # Input fields become the UDF parameters; the single output field gives
        # the return type. Collect all outputs so a relation with several of them
        # is rejected instead of silently keeping the last one.
        input_columns: list[sql.Column] = []
        return_types: list[str] = []
        for fld in r.fields:
            if fld.input:
                input_columns.append(sql.Column(self._var_name(r.id, fld), self._convert_type(fld.type)))
            else:
                return_types.append(self._convert_type(fld.type))

        assert return_types, f"No return type found for relation '{r.name}'"
        assert len(return_types) == 1, (
            f"Expected a single return field for relation '{r.name}', got {len(return_types)}"
        )

        # The UDF body is the rendered import lines followed by the handler code.
        imports_code = "\n".join(imp.render() for imp in meta.imports)
        python_block = "\n".join(part for part in (imports_code, meta.code) if part)

        statements.append(
            sql.CreateFunction(
                name=r.name,
                inputs=input_columns,
                return_type=return_types[0],
                handler=meta.handler,
                body=python_block,
                packages=meta.packages,
            )
        )

    return statements
|
|
282
|
+
|
|
207
283
|
def _create_statement(self, task: ir.Logical):
|
|
208
284
|
|
|
209
285
|
# TODO - improve the typing info to avoid these casts
|
|
@@ -1186,7 +1262,7 @@ class ModelToSQL:
|
|
|
1186
1262
|
# SF returns values in `""` and to avoid this, we need to cast it to `TEXT` type
|
|
1187
1263
|
part_expr = f"cast({table_sql_var}.value as TEXT)"
|
|
1188
1264
|
index_expr = f"({table_sql_var}.index + 1)" # SF is 0-based internally, adjust to it back
|
|
1189
|
-
assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and
|
|
1265
|
+
assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and fourth arguments (index, part) must be variables"
|
|
1190
1266
|
builtin_vars[part] = part_expr
|
|
1191
1267
|
builtin_vars[index] = index_expr
|
|
1192
1268
|
elif relation == builtins.range:
|
|
@@ -1205,7 +1281,7 @@ class ModelToSQL:
|
|
|
1205
1281
|
table_expr = f"LATERAL FLATTEN(input => ARRAY_GENERATE_RANGE({start}, ({stop} + 1), {step}))"
|
|
1206
1282
|
expr = f"{table_sql_var}.value"
|
|
1207
1283
|
table_expressions[table_sql_var] = table_expr
|
|
1208
|
-
assert isinstance(result, ir.Var), "
|
|
1284
|
+
assert isinstance(result, ir.Var), "Fourth argument (result) must be a variable"
|
|
1209
1285
|
builtin_vars[result] = f"{expr}"
|
|
1210
1286
|
elif relation == builtins.cast:
|
|
1211
1287
|
assert len(args) == 3, f"Expected 3 args for `cast`, got {len(args)}: {args}"
|
|
@@ -1214,7 +1290,7 @@ class ModelToSQL:
|
|
|
1214
1290
|
assert isinstance(result, ir.Var), "Third argument (result) must be a variable"
|
|
1215
1291
|
|
|
1216
1292
|
builtin_vars[result] = original_raw
|
|
1217
|
-
elif relation in
|
|
1293
|
+
elif relation in {builtins.isnan, builtins.isinf}:
|
|
1218
1294
|
arg_expr = self._var_to_expr(args[0], reference, resolve_builtin_var, var_to_construct)
|
|
1219
1295
|
expr = "cast('NaN' AS DOUBLE)" if relation == builtins.isnan else "cast('Infinity' AS DOUBLE)"
|
|
1220
1296
|
wheres.append(sql.Terminal(f"{arg_expr} = {expr}"))
|
|
@@ -1351,8 +1427,8 @@ class ModelToSQL:
|
|
|
1351
1427
|
builtin_vars[rhs] = lhs
|
|
1352
1428
|
date_period_var_type[rhs] = relation.name
|
|
1353
1429
|
elif relation in builtins.date_builtins:
|
|
1354
|
-
if relation in
|
|
1355
|
-
builtins.datetime_subtract
|
|
1430
|
+
if relation in {builtins.date_add, builtins.date_subtract, builtins.datetime_add,
|
|
1431
|
+
builtins.datetime_subtract}:
|
|
1356
1432
|
assert len(args) == 3, f"Expected 3 args for {relation}, got {len(args)}: {args}"
|
|
1357
1433
|
assert isinstance(rhs, ir.Var), f"Period variable must be `ir.Var`, got: {rhs}"
|
|
1358
1434
|
period = date_period_var_type[rhs]
|
|
@@ -1361,10 +1437,10 @@ class ModelToSQL:
|
|
|
1361
1437
|
left = self._var_to_expr(lhs, reference, resolve_builtin_var, var_to_construct)
|
|
1362
1438
|
|
|
1363
1439
|
if self._is_duck_db:
|
|
1364
|
-
op = "+" if relation in
|
|
1440
|
+
op = "+" if relation in {builtins.date_add, builtins.datetime_add} else "-"
|
|
1365
1441
|
expr = f"({left} {op} {period_val} * interval 1 {period})"
|
|
1366
1442
|
else:
|
|
1367
|
-
sign = 1 if relation in
|
|
1443
|
+
sign = 1 if relation in {builtins.date_add, builtins.datetime_add} else -1
|
|
1368
1444
|
expr = f"dateadd({period}, ({sign} * {period_val}), {left})"
|
|
1369
1445
|
|
|
1370
1446
|
result_var = args[2]
|
|
@@ -1439,23 +1515,19 @@ class ModelToSQL:
|
|
|
1439
1515
|
rel_name = relation.name
|
|
1440
1516
|
left = self._var_to_expr(lhs, reference, resolve_builtin_var, var_to_construct)
|
|
1441
1517
|
if relation in builtins.math_unary_builtins:
|
|
1442
|
-
if rel_name
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
sub_expr = f"10, {left}"
|
|
1454
|
-
method = "log"
|
|
1455
|
-
expr = f"{method}({sub_expr})"
|
|
1456
|
-
elif rel_name in (builtins.minimum.name, builtins.maximum.name, builtins.trunc_div.name,
|
|
1518
|
+
method = "ln" if rel_name == builtins.natural_log.name else rel_name
|
|
1519
|
+
sub_expr = left
|
|
1520
|
+
if rel_name == builtins.factorial.name and self._is_duck_db:
|
|
1521
|
+
# Factorial requires an integer operand in DuckDB
|
|
1522
|
+
sub_expr = f"{left}::INTEGER"
|
|
1523
|
+
elif rel_name == builtins.log10.name:
|
|
1524
|
+
# log10 is not supported, so we use log with base 10
|
|
1525
|
+
sub_expr = f"10, {left}"
|
|
1526
|
+
method = "log"
|
|
1527
|
+
expr = f"{method}({sub_expr})"
|
|
1528
|
+
elif rel_name in {builtins.minimum.name, builtins.maximum.name, builtins.trunc_div.name,
|
|
1457
1529
|
builtins.power.name, builtins.mod.name, builtins.pow.name,
|
|
1458
|
-
builtins.log.name
|
|
1530
|
+
builtins.log.name}:
|
|
1459
1531
|
assert len(args) == 3, f"Expected 3 args for {relation}, got {len(args)}: {args}"
|
|
1460
1532
|
|
|
1461
1533
|
result_var = args[2]
|
|
@@ -1480,7 +1552,7 @@ class ModelToSQL:
|
|
|
1480
1552
|
f"but got `{type(result_var).__name__}`: {result_var}"
|
|
1481
1553
|
)
|
|
1482
1554
|
builtin_vars[result_var] = expr
|
|
1483
|
-
elif relation in
|
|
1555
|
+
elif relation in {builtins.parse_int64, builtins.parse_int128} and isinstance(rhs, ir.Var):
|
|
1484
1556
|
builtin_vars[rhs] = self._var_to_expr(lhs, reference, resolve_builtin_var, var_to_construct, False)
|
|
1485
1557
|
elif helpers.is_from_cast(lookup) and isinstance(rhs, ir.Var):
|
|
1486
1558
|
# For the `from cast` relations we keep the raw var, and we will ground it later.
|
|
@@ -2071,21 +2143,27 @@ class ModelToSQL:
|
|
|
2071
2143
|
return f"DECIMAL({base_type.precision},{base_type.scale})"
|
|
2072
2144
|
raise Exception(f"Unknown built-in type: {t}")
|
|
2073
2145
|
|
|
2074
|
-
def _get_relations(self, model: ir.Model) -> list[ir.Relation]:
|
|
2146
|
+
def _get_relations(self, model: ir.Model) -> Tuple[list[ir.Relation], list[ir.Relation]]:
|
|
2147
|
+
rw = ReadWriteVisitor()
|
|
2148
|
+
model.accept(rw)
|
|
2149
|
+
|
|
2150
|
+
root = cast(ir.Logical, model.root)
|
|
2151
|
+
|
|
2152
|
+
# For query compilation exclude read-only tables because we do not need to declare `CREATE TABLE` statements
|
|
2153
|
+
used_relations = rw.writes(root) if self._query_compilation else rw.writes(root) | rw.reads(root)
|
|
2154
|
+
|
|
2075
2155
|
# Filter only relations that require table creation
|
|
2076
|
-
|
|
2077
|
-
r for r in
|
|
2156
|
+
table_relations = [
|
|
2157
|
+
r for r in used_relations
|
|
2078
2158
|
if self._is_table_creation_required(r)
|
|
2079
2159
|
]
|
|
2080
2160
|
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
writable = rw.writes(cast(ir.Logical, model.root))
|
|
2086
|
-
relations = [r for r in relations if r in writable]
|
|
2161
|
+
used_builtins = [
|
|
2162
|
+
r for r in rw.reads(root)
|
|
2163
|
+
if builtins.is_builtin(r)
|
|
2164
|
+
]
|
|
2087
2165
|
|
|
2088
|
-
return
|
|
2166
|
+
return table_relations, used_builtins
|
|
2089
2167
|
|
|
2090
2168
|
def _is_table_creation_required(self, r: ir.Relation) -> bool:
|
|
2091
2169
|
"""
|
|
@@ -2161,7 +2239,7 @@ class ModelToSQL:
|
|
|
2161
2239
|
|
|
2162
2240
|
def _var_name(self, relation_id: int, arg: Union[ir.Var, ir.Field]):
|
|
2163
2241
|
name = helpers.sanitize(self.relation_arg_name_cache.get_name((relation_id, arg.id), arg.name))
|
|
2164
|
-
return f'"{name}"' if name.lower() in
|
|
2242
|
+
return f'"{name}"' if name.lower() in {"any", "order"} else name
|
|
2165
2243
|
|
|
2166
2244
|
def _register_relation_args(self, relations: list[ir.Relation]):
|
|
2167
2245
|
"""
|
|
@@ -2308,6 +2386,7 @@ class ModelToSQL:
|
|
|
2308
2386
|
3. Other statements except SELECT queries
|
|
2309
2387
|
4. SELECT queries
|
|
2310
2388
|
"""
|
|
2389
|
+
udfs = []
|
|
2311
2390
|
create_tables = []
|
|
2312
2391
|
need_sort: dict[str, list[Union[sql.Insert, sql.CreateView, sql.CreateDynamicTable]]] = defaultdict(list)
|
|
2313
2392
|
updates = []
|
|
@@ -2327,12 +2406,14 @@ class ModelToSQL:
|
|
|
2327
2406
|
updates.append(statement)
|
|
2328
2407
|
elif isinstance(statement, sql.Select):
|
|
2329
2408
|
selects.append(statement)
|
|
2409
|
+
elif isinstance(statement, sql.CreateFunction):
|
|
2410
|
+
udfs.append(statement)
|
|
2330
2411
|
else:
|
|
2331
2412
|
miscellaneous_statements.append(statement)
|
|
2332
2413
|
|
|
2333
2414
|
sorted_statements = self._sort_statements_dependency_graph(need_sort)
|
|
2334
2415
|
|
|
2335
|
-
return create_tables + sorted_statements + updates + miscellaneous_statements + selects
|
|
2416
|
+
return udfs + create_tables + sorted_statements + updates + miscellaneous_statements + selects
|
|
2336
2417
|
|
|
2337
2418
|
@staticmethod
|
|
2338
2419
|
def _sort_statements_dependency_graph(statements: dict[str, list[Union[sql.Insert, sql.CreateView, sql.CreateDynamicTable]]]) -> list[sql.Insert]:
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import duckdb
|
|
4
|
+
import math
|
|
4
5
|
from pandas import DataFrame
|
|
5
6
|
from typing import Any, Union, Literal
|
|
7
|
+
from scipy.special import erfinv as special_erfinv
|
|
6
8
|
|
|
7
9
|
from relationalai.semantics.sql import Compiler
|
|
8
10
|
from relationalai.semantics.sql.executor.result_helpers import format_duckdb_columns
|
|
@@ -18,10 +20,29 @@ class DuckDBExecutor(e.Executor):
|
|
|
18
20
|
""" Execute the SQL query directly. """
|
|
19
21
|
if format != "pandas":
|
|
20
22
|
raise ValueError(f"Unsupported format: {format}")
|
|
23
|
+
|
|
21
24
|
connection = duckdb.connect()
|
|
25
|
+
|
|
26
|
+
# Register scalar functions
|
|
27
|
+
connection.create_function("erf", self.erf)
|
|
28
|
+
connection.create_function("acot", self.acot)
|
|
29
|
+
connection.create_function("erfinv", self.erfinv)
|
|
30
|
+
|
|
22
31
|
try:
|
|
23
32
|
sql, _ = self.compiler.compile(model, {"is_duck_db": True})
|
|
24
33
|
arrow_table = connection.query(sql).fetch_arrow_table()
|
|
25
34
|
return format_duckdb_columns(arrow_table.to_pandas(), arrow_table.schema)
|
|
26
35
|
finally:
|
|
27
36
|
connection.close()
|
|
37
|
+
|
|
38
|
+
@staticmethod
def erf(x: float) -> float:
    """Return the Gauss error function of ``x``.

    Registered on the DuckDB connection as the scalar function ``erf``.
    """
    result = math.erf(x)
    return result
|
|
41
|
+
|
|
42
|
+
@staticmethod
def erfinv(x: float) -> float:
    """Return the inverse error function of ``x``.

    Delegates to ``scipy.special.erfinv`` (imported as ``special_erfinv``);
    registered on the DuckDB connection as the scalar function ``erfinv``.
    """
    result = special_erfinv(x)
    return result
|
|
45
|
+
|
|
46
|
+
@staticmethod
def acot(x: float) -> float:
    """Return the inverse cotangent of ``x``, computed as ``atan(1 / x)``.

    At zero, where ``1 / x`` is undefined, the result is ``pi/2`` carrying the
    sign of ``x`` (so ``-0.0`` maps to ``-pi/2``). The Snowflake UDF code for
    ``acot`` in the SQL compiler uses the same formula.
    """
    if x == 0:
        return math.copysign(math.pi / 2, x)
    return math.atan(1 / x)
|
|
@@ -25,7 +25,6 @@ class Denormalize(c.Pass):
|
|
|
25
25
|
|
|
26
26
|
@dataclass
|
|
27
27
|
class OldDenormalize(visitor.Rewriter):
|
|
28
|
-
# TODO: use the new Rewriter when available.
|
|
29
28
|
|
|
30
29
|
denormalized: dict[ir.Relation, ir.Relation] = field(default_factory=dict, init=False, hash=False, compare=False)
|
|
31
30
|
|
|
@@ -168,13 +167,12 @@ class OldDenormalize(visitor.Rewriter):
|
|
|
168
167
|
""" Denormalize the relations that can be denormalized.
|
|
169
168
|
|
|
170
169
|
Group together relations that are keyed by the same "entity". This method defines
|
|
171
|
-
entities as being types that
|
|
170
|
+
entities as being types that have a unary relation containing only that type. All
|
|
172
171
|
relations whose first argument is this type are grouped together.
|
|
173
172
|
|
|
174
173
|
Returns a tuple with 2 elements:
|
|
175
|
-
1.
|
|
176
|
-
2.
|
|
177
|
-
its place.
|
|
174
|
+
1. The new set of relations after denormalization
|
|
175
|
+
2. A dict from relations that were denormalized away to the new relation that took its place.
|
|
178
176
|
"""
|
|
179
177
|
new_relations = ordered_set()
|
|
180
178
|
denormalized: dict[ir.Relation, ir.Relation] = dict()
|
|
@@ -184,7 +182,7 @@ class OldDenormalize(visitor.Rewriter):
|
|
|
184
182
|
entity_relations: dict[ir.Type, ir.Relation] = dict()
|
|
185
183
|
|
|
186
184
|
for r in relations:
|
|
187
|
-
if len(r.fields) == 1 and not types.is_builtin(r.fields[0].type):
|
|
185
|
+
if len(r.fields) == 1 and not types.is_builtin(r.fields[0].type) and not r.name == 'Error':
|
|
188
186
|
e = r.fields[0].type
|
|
189
187
|
entity_types.add(e)
|
|
190
188
|
entity_relations[e] = r
|
|
@@ -49,9 +49,29 @@ class RecursiveUnion(c.Pass):
|
|
|
49
49
|
new_body = [logical for logical in root_logical.body if logical not in recursive_logicals]
|
|
50
50
|
|
|
51
51
|
# Step 4: Add unions for each recursive group
|
|
52
|
-
for
|
|
53
|
-
|
|
54
|
-
|
|
52
|
+
for rel_id, logical_group in recursive_groups.items():
|
|
53
|
+
split_group = ordered_set()
|
|
54
|
+
|
|
55
|
+
for logical in logical_group:
|
|
56
|
+
# Count total ir.Update tasks in this logical
|
|
57
|
+
update_count = sum(isinstance(t, ir.Update) for t in logical.body)
|
|
58
|
+
|
|
59
|
+
# If there's only one, keep the original logical as-is
|
|
60
|
+
if update_count == 1:
|
|
61
|
+
split_group.add(logical)
|
|
62
|
+
continue
|
|
63
|
+
|
|
64
|
+
# Otherwise, keep only updates relevant to this relation (and non-update tasks)
|
|
65
|
+
filtered_body = [
|
|
66
|
+
t for t in logical.body
|
|
67
|
+
if not isinstance(t, ir.Update) or t.relation.id == rel_id
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
if filtered_body:
|
|
71
|
+
split_group.add(f.logical(filtered_body))
|
|
72
|
+
|
|
73
|
+
if split_group:
|
|
74
|
+
new_body.append(f.union(list(split_group)))
|
|
55
75
|
|
|
56
76
|
return model.reconstruct(model.engines, model.relations, model.types, f.logical(new_body), model.annotations)
|
|
57
77
|
|
|
@@ -55,6 +55,17 @@ class CreateView(Node):
|
|
|
55
55
|
name: str
|
|
56
56
|
query: Union[list[Select], CTE]
|
|
57
57
|
|
|
58
|
+
@dataclass(frozen=True)
class CreateFunction(Node):
    """A ``CREATE OR REPLACE FUNCTION`` statement for a scalar user-defined function.

    The defaults describe a Python 3.11 function whose entry point is ``compute``.
    """

    # Function name as it appears in SQL.
    name: str
    # Parameter list, in declaration order.
    inputs: list[Column]
    # SQL type for the RETURNS clause.
    return_type: str
    # Source code printed between the $$ delimiters.
    body: str
    # Value of the LANGUAGE clause.
    language: str = "PYTHON"
    # Value of the RUNTIME_VERSION clause.
    runtime_version: str = "3.11"
    # Name of the function defined in ``body`` that serves as the entry point (HANDLER clause).
    handler: str = "compute"
    # Entries for the optional PACKAGES clause; None or an empty list omits the clause.
    # Entries appear to be printed as-is, so callers supply SQL quoting (e.g. "'scipy'").
    packages: Optional[list[str]] = None
|
|
68
|
+
|
|
58
69
|
@dataclass(frozen=True)
|
|
59
70
|
class Insert(Node):
|
|
60
71
|
table: str
|
|
@@ -286,6 +297,22 @@ class Printer(BasePrinter):
|
|
|
286
297
|
elif isinstance(node, CreateView):
|
|
287
298
|
self._print(f"CREATE VIEW {self._get_table_name(node.name)} AS ")
|
|
288
299
|
self._print_query(indent, node, "UNION")
|
|
300
|
+
elif isinstance(node, CreateFunction):
|
|
301
|
+
self._print(f"CREATE OR REPLACE FUNCTION {self._get_table_name(node.name)} (")
|
|
302
|
+
self._join(node.inputs)
|
|
303
|
+
self._print_nl(")")
|
|
304
|
+
self._print_nl(f"RETURNS {node.return_type}")
|
|
305
|
+
self._print_nl(f"LANGUAGE {node.language}")
|
|
306
|
+
self._print_nl(f"RUNTIME_VERSION = '{node.runtime_version}'")
|
|
307
|
+
self._print_nl(f"HANDLER = '{node.handler}'")
|
|
308
|
+
if node.packages:
|
|
309
|
+
self._print("PACKAGES = (")
|
|
310
|
+
self._join(node.packages)
|
|
311
|
+
self._print_nl(")")
|
|
312
|
+
self._print_nl("AS ")
|
|
313
|
+
self._print_nl("$$")
|
|
314
|
+
self._print_nl(node.body)
|
|
315
|
+
self._print("$$;")
|
|
289
316
|
elif isinstance(node, Insert):
|
|
290
317
|
self._print(f"INSERT INTO {self._get_table_name(node.table)} ")
|
|
291
318
|
if len(node.columns) > 0:
|
|
@@ -3,7 +3,7 @@ from typing import Any
|
|
|
3
3
|
|
|
4
4
|
from relationalai.semantics.internal import internal as i
|
|
5
5
|
from .std import _Date, _DateTime, _Number, _String, _Integer, _make_expr
|
|
6
|
-
from . import datetime, math, strings, decimals, integers, floats, pragmas, constraints
|
|
6
|
+
from . import datetime, math, strings, decimals, integers, floats, pragmas, constraints, re
|
|
7
7
|
|
|
8
8
|
def range(*args: _Integer) -> i.Expression:
|
|
9
9
|
# supports range(stop), range(start, stop), range(start, stop, step)
|
|
@@ -43,6 +43,7 @@ __all__ = [
|
|
|
43
43
|
"datetime",
|
|
44
44
|
"math",
|
|
45
45
|
"strings",
|
|
46
|
+
"re",
|
|
46
47
|
"decimals",
|
|
47
48
|
"integers",
|
|
48
49
|
"floats",
|
|
@@ -153,6 +153,10 @@ class datetime:
|
|
|
153
153
|
std.cast_to_int64(day), std.cast_to_int64(hour), std.cast_to_int64(minute),
|
|
154
154
|
std.cast_to_int64(second), std.cast_to_int64(millisecond), tz, b.DateTime.ref("res"))
|
|
155
155
|
|
|
156
|
+
@classmethod
def now(cls) -> b.Expression:
    """Build a ``datetime_now`` expression whose result is a ``DateTime``.

    Presumably evaluates to the current timestamp; the semantics live in the
    ``datetime_now`` builtin.
    """
    res = b.DateTime.ref("res")
    return _make_expr("datetime_now", res)
|
|
159
|
+
|
|
156
160
|
@classmethod
|
|
157
161
|
def year(cls, datetime: _DateTime, tz: dt.tzinfo|_String|None = None) -> b.Expression:
|
|
158
162
|
tz = _extract_tz(datetime, tz)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from relationalai.semantics.internal import internal as i
|
|
4
|
+
from relationalai.semantics.metamodel.util import OrderedSet
|
|
5
|
+
from .std import _Integer, _String, _make_expr
|
|
6
|
+
from typing import Literal, Any
|
|
7
|
+
from .. import std
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def escape(regex: _String) -> i.Expression:
    """Escape regex metacharacters in ``regex`` (analogous to Python's ``re.escape``).

    Builds an ``escape_regex_metachars`` expression whose result is a String.
    """
    escaped = i.String.ref()
    return _make_expr("escape_regex_metachars", regex, escaped)
|
|
12
|
+
|
|
13
|
+
class Match(i.Producer):
    """Symbolic result of matching a regular expression, modelled on ``re.Match``.

    The match is expressed through the ``regex_match_all`` builtin: ``_offset`` is
    the 1-based position of the match and ``_full_match`` is the matched text.
    ``_expr`` is the expression that grounds the match when it is compiled.
    """

    def __init__(self, regex: _String, string: _String, pos: _Integer = 0, _type: Literal["search", "fullmatch", "match"] = "match"):
        """Match ``regex`` against ``string`` starting at 0-based position ``pos``.

        Args:
            regex: the regular expression.
            string: the string to match against.
            pos: 0-based position at which matching starts.
            _type: ``"match"`` or ``"fullmatch"``; ``"search"`` is not implemented.

        Raises:
            NotImplementedError: if ``_type`` is ``"search"``.
            ValueError: if ``_type`` is not one of the supported kinds.
        """
        super().__init__(i.find_model([regex, string, pos]))
        self.regex = regex
        self.string = string
        self.pos = pos

        if _type == "match":
            # regex_match_all works with 1-based offsets (start() subtracts 1 again).
            self._expr = _regex_match_all(self.regex, self.string, std.cast_to_int64(self.pos + 1))
            self._offset, self._full_match = self._expr._arg_ref(2), self._expr._arg_ref(3)
        elif _type == "search":
            raise NotImplementedError("`search` is not implemented")
        elif _type == "fullmatch":
            _exp = _regex_match_all(self.regex, self.string, std.cast_to_int64(self.pos + 1))
            self._offset, self._full_match = _exp._arg_ref(2), _exp._arg_ref(3)
            # Require the matched text to equal the rest of the string from ``pos``
            # (exact bounds depend on ``std.strings.substring``'s indexing).
            self._expr = self._full_match == std.strings.substring(self.string, std.cast_to_int64(self.pos), std.strings.len(self.string))
        else:
            # Reject unknown kinds instead of leaving ``_expr`` unset.
            raise ValueError(f"Unsupported match type: {_type!r}")

    def group(self, index: _Integer = 0) -> i.Producer:
        """Return the text of capture group ``index``; group 0 is the whole match."""
        # Only short-circuit on a literal Python 0: ``==`` on a symbolic index
        # builds an expression (see fullmatch above) rather than a plain bool.
        if isinstance(index, int) and index == 0:
            return self._full_match
        return _make_expr("capture_group_by_index", self.regex, self.string, std.cast_to_int64(self.pos + 1), std.cast_to_int64(index), i.String.ref("res"))

    def group_by_name(self, name: _String) -> i.Producer:
        """Return the text of the named capture group ``name``."""
        return _make_expr("capture_group_by_name", self.regex, self.string, std.cast_to_int64(self.pos + 1), name, i.String.ref("res"))

    def start(self) -> i.Expression:
        """Return the 0-based index where the match begins."""
        return self._offset - 1

    def end(self) -> i.Expression:
        """Return the 0-based index just past the match (exclusive, as in ``re.Match.end``)."""
        return std.strings.len(self.group(0)) + self.start()

    def span(self) -> tuple[i.Producer, i.Producer]:
        """Return ``(start(), end())`` as a half-open interval."""
        return self.start(), self.end()

    def _to_keys(self) -> OrderedSet[Any]:
        return i.find_keys(self._expr)

    def _compile_lookup(self, compiler: i.Compiler, ctx: i.CompilerContext):
        # Look up the inputs first, then the match expression itself.
        compiler.lookup(self.regex, ctx)
        compiler.lookup(self.string, ctx)
        compiler.lookup(self.pos, ctx)
        return compiler.lookup(self._expr, ctx)

    # NOTE(review): these overrides appear to bypass i.Producer's attribute hooks so
    # Match behaves like a plain Python object for attribute access — confirm.
    def __getattr__(self, name: str) -> Any:
        return object.__getattribute__(self, name)

    def __setattr__(self, name: str, value: Any) -> None:
        object.__setattr__(self, name, value)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def match(regex: _String, string: _String, pos: _Integer = 0) -> Match:
    """Match ``regex`` at 0-based position ``pos`` of ``string`` (default: its start).

    ``pos`` defaults to 0, so existing two-argument calls behave as before. It is
    accepted for consistency with ``search`` and ``fullmatch``.
    """
    return Match(regex, string, pos)
|
|
67
|
+
|
|
68
|
+
def search(regex: _String, string: _String, pos: _Integer = 0) -> Match:
    """Search for ``regex`` anywhere in ``string`` from 0-based ``pos``.

    Not implemented yet: ``Match`` raises ``NotImplementedError`` for this kind.
    """
    return Match(regex, string, pos=pos, _type="search")
|
|
70
|
+
|
|
71
|
+
def fullmatch(regex: _String, string: _String, pos: _Integer = 0) -> Match:
    """Match ``regex`` against the whole of ``string`` from 0-based ``pos``."""
    return Match(regex, string, pos=pos, _type="fullmatch")
|
|
73
|
+
|
|
74
|
+
def findall(regex: _String, string: _String) -> tuple[i.Producer, i.Producer]:
    """Find the matches of ``regex`` in ``string``.

    Returns a ``(rank, text)`` pair: ``text`` ranges over the matched substrings
    and ``rank`` orders them by ascending (offset, text).
    """
    all_matches = _regex_match_all(regex, string)
    offset = all_matches._arg_ref(2)
    matched = all_matches._arg_ref(3)
    return i.rank(i.asc(offset, matched)), matched
|
|
79
|
+
|
|
80
|
+
def _regex_match_all(regex: _String, string: _String, pos: _Integer|None = None) -> i.Expression:
    """Build a ``regex_match_all(regex, string, offset, match)`` expression.

    With ``pos`` omitted, the offset is a fresh Int64 variable, presumably so that
    every match position can be produced. Otherwise the given ``pos`` is used as
    the offset.
    """
    offset = i.Int64.ref() if pos is None else pos
    return _make_expr("regex_match_all", regex, string, offset, i.String.ref())
|
|
@@ -15,7 +15,7 @@ def concat(s0: _String, s1: _String, *args: _String) -> b.Expression:
|
|
|
15
15
|
return res
|
|
16
16
|
|
|
17
17
|
def len(s: _String) -> b.Expression:
|
|
18
|
-
return _make_expr("num_chars", s, b.
|
|
18
|
+
return _make_expr("num_chars", s, b.Int64.ref("res"))
|
|
19
19
|
|
|
20
20
|
def startswith(s0: _String, s1: _String) -> b.Expression:
|
|
21
21
|
return _make_expr("starts_with", s0, s1)
|