relationalai 0.11.3__py3-none-any.whl → 0.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. relationalai/clients/snowflake.py +6 -1
  2. relationalai/clients/use_index_poller.py +349 -188
  3. relationalai/early_access/dsl/bindings/csv.py +2 -2
  4. relationalai/semantics/internal/internal.py +22 -4
  5. relationalai/semantics/lqp/executor.py +61 -12
  6. relationalai/semantics/lqp/intrinsics.py +23 -0
  7. relationalai/semantics/lqp/model2lqp.py +13 -4
  8. relationalai/semantics/lqp/passes.py +2 -3
  9. relationalai/semantics/lqp/primitives.py +12 -1
  10. relationalai/semantics/metamodel/builtins.py +8 -1
  11. relationalai/semantics/metamodel/factory.py +3 -2
  12. relationalai/semantics/reasoners/graph/core.py +54 -2
  13. relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
  14. relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
  15. relationalai/semantics/rel/compiler.py +5 -17
  16. relationalai/semantics/rel/executor.py +2 -2
  17. relationalai/semantics/rel/rel.py +6 -0
  18. relationalai/semantics/rel/rel_utils.py +8 -1
  19. relationalai/semantics/rel/rewrite/extract_common.py +153 -242
  20. relationalai/semantics/sql/compiler.py +120 -39
  21. relationalai/semantics/sql/executor/duck_db.py +21 -0
  22. relationalai/semantics/sql/rewrite/denormalize.py +4 -6
  23. relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
  24. relationalai/semantics/sql/sql.py +27 -0
  25. relationalai/semantics/std/__init__.py +2 -1
  26. relationalai/semantics/std/datetime.py +4 -0
  27. relationalai/semantics/std/re.py +83 -0
  28. relationalai/semantics/std/strings.py +1 -1
  29. relationalai/tools/cli_controls.py +445 -60
  30. relationalai/util/format.py +78 -1
  31. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/METADATA +3 -2
  32. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/RECORD +35 -33
  33. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/WHEEL +0 -0
  34. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/entry_points.txt +0 -0
  35. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/licenses/LICENSE +0 -0
@@ -66,6 +66,22 @@ class RelationInfo:
66
66
  table_selects: list[sql.Select] = field(default_factory=list)
67
67
  dynamic_table_selects: list[sql.Select] = field(default_factory=list)
68
68
 
69
@dataclass
class ImportSpec:
    """One Python import needed by a generated UDF body.

    When ``module`` is unset (or empty) the spec renders as
    ``import <value>``; otherwise it renders as
    ``from <module> import <value>``.
    """

    value: str
    module: Optional[str] = None  # e.g., "scipy.special"

    def render(self) -> str:
        """Return the import as a single line of Python source."""
        if not self.module:
            return f"import {self.value}"
        return f"from {self.module} import {self.value}"
76
+
77
+
78
@dataclass
class UDFConfig:
    """Static description of the Python UDF generated for one builtin relation."""

    # Name of the callable inside ``code`` that is used as the UDF handler.
    handler: str
    # Python source that defines the handler.
    code: str
    # Imports rendered ahead of ``code`` when the UDF body is assembled.
    imports: list[ImportSpec] = field(default_factory=list)
    # Package entries forwarded unchanged to the generated function's package list.
    packages: list[str] = field(default_factory=list)
84
+
69
85
  @dataclass
70
86
  class ModelToSQL:
71
87
  """ Generates SQL from an IR Model, assuming the compiler rewrites were done. """
@@ -89,9 +105,9 @@ class ModelToSQL:
89
105
  return sql.Program(self._sort_dependencies(self._union_output_selects(self._generate_statements(model))))
90
106
 
91
107
  def _generate_statements(self, model: ir.Model) -> list[sql.Node]:
92
- relations = self._get_relations(model)
108
+ table_relations, used_builtins = self._get_relations(model)
93
109
 
94
- self._register_relation_args(relations)
110
+ self._register_relation_args(table_relations)
95
111
  self._register_external_relations(model)
96
112
 
97
113
  statements: list[sql.Node] = []
@@ -139,11 +155,15 @@ class ModelToSQL:
139
155
  )
140
156
 
141
157
  # 4. Create physical tables for explicitly declared table relations
142
- for relation in relations:
158
+ for relation in table_relations:
143
159
  info = self.relation_infos.get(relation)
144
160
  if info is None or info.table_selects:
145
161
  statements.append(self._create_table(relation))
146
162
 
163
+ # 5. Create Snowflake user-defined functions
164
+ if not self._is_duck_db:
165
+ statements.extend(self._create_user_defined_functions(used_builtins))
166
+
147
167
  return statements
148
168
 
149
169
  #--------------------------------------------------
@@ -204,6 +224,62 @@ class ModelToSQL:
204
224
  )
205
225
  )
206
226
 
227
def _create_user_defined_functions(self, relations: list[ir.Relation]) -> list[sql.CreateFunction]:
    """Build a ``sql.CreateFunction`` node for every given relation that has a UDF definition.

    Relations whose name has no entry in the UDF table below are skipped.

    Args:
        relations: Candidate relations; only ``acot``, ``erf`` and ``erfinv``
            are currently mapped to a Python UDF.

    Returns:
        One ``sql.CreateFunction`` per matching relation, in input order.

    Raises:
        AssertionError: If a matching relation has no non-input field from
            which the return type can be derived.
    """
    # Central UDF metadata configuration, keyed by builtin relation name.
    udf_configs: dict[str, UDFConfig] = {
        builtins.acot.name: UDFConfig(
            handler="compute",
            imports=[ImportSpec("math")],
            code="""def compute(x): return math.atan(1 / x) if x != 0 else math.copysign(math.pi / 2, x)""",
        ),
        builtins.erf.name: UDFConfig(
            handler="compute",
            imports=[ImportSpec("math")],
            code="""def compute(x): return math.erf(x)""",
        ),
        builtins.erfinv.name: UDFConfig(
            handler="compute",
            imports=[ImportSpec("erfinv", module="scipy.special")],
            packages=["'scipy'"],
            code="""def compute(x): return erfinv(x)""",
        ),
    }

    created: list[sql.CreateFunction] = []

    for relation in relations:
        config = udf_configs.get(relation.name)
        if config is None:
            continue

        # Input fields become function parameters; the non-input field supplies
        # the return type. A single return field per builtin is expected; if
        # there were several, the last one would win.
        result_type = None
        params: list[sql.Column] = []
        for fld in relation.fields:
            if not fld.input:
                result_type = self._convert_type(fld.type)
                continue
            params.append(sql.Column(self._var_name(relation.id, fld), self._convert_type(fld.type)))

        # Assemble the full Python body: rendered imports first, then the handler code.
        header = "\n".join(spec.render() for spec in config.imports)
        source = "\n".join(filter(None, (header, config.code)))

        assert result_type, f"No return type found for relation '{relation.name}'"
        created.append(
            sql.CreateFunction(
                name=relation.name,
                inputs=params,
                return_type=result_type,
                handler=config.handler,
                body=source,
                packages=config.packages,
            )
        )

    return created
282
+
207
283
  def _create_statement(self, task: ir.Logical):
208
284
 
209
285
  # TODO - improve the typing info to avoid these casts
@@ -1186,7 +1262,7 @@ class ModelToSQL:
1186
1262
  # SF returns values in `""` and to avoid this, we need to cast it to `TEXT` type
1187
1263
  part_expr = f"cast({table_sql_var}.value as TEXT)"
1188
1264
  index_expr = f"({table_sql_var}.index + 1)" # SF is 0-based internally, adjust to it back
1189
- assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and forth arguments (index, part) must be variables"
1265
+ assert isinstance(index, ir.Var) and isinstance(part, ir.Var), "Third and fourth arguments (index, part) must be variables"
1190
1266
  builtin_vars[part] = part_expr
1191
1267
  builtin_vars[index] = index_expr
1192
1268
  elif relation == builtins.range:
@@ -1205,7 +1281,7 @@ class ModelToSQL:
1205
1281
  table_expr = f"LATERAL FLATTEN(input => ARRAY_GENERATE_RANGE({start}, ({stop} + 1), {step}))"
1206
1282
  expr = f"{table_sql_var}.value"
1207
1283
  table_expressions[table_sql_var] = table_expr
1208
- assert isinstance(result, ir.Var), "Forth argument (result) must be a variable"
1284
+ assert isinstance(result, ir.Var), "Fourth argument (result) must be a variable"
1209
1285
  builtin_vars[result] = f"{expr}"
1210
1286
  elif relation == builtins.cast:
1211
1287
  assert len(args) == 3, f"Expected 3 args for `cast`, got {len(args)}: {args}"
@@ -1214,7 +1290,7 @@ class ModelToSQL:
1214
1290
  assert isinstance(result, ir.Var), "Third argument (result) must be a variable"
1215
1291
 
1216
1292
  builtin_vars[result] = original_raw
1217
- elif relation in (builtins.isnan, builtins.isinf):
1293
+ elif relation in {builtins.isnan, builtins.isinf}:
1218
1294
  arg_expr = self._var_to_expr(args[0], reference, resolve_builtin_var, var_to_construct)
1219
1295
  expr = "cast('NaN' AS DOUBLE)" if relation == builtins.isnan else "cast('Infinity' AS DOUBLE)"
1220
1296
  wheres.append(sql.Terminal(f"{arg_expr} = {expr}"))
@@ -1351,8 +1427,8 @@ class ModelToSQL:
1351
1427
  builtin_vars[rhs] = lhs
1352
1428
  date_period_var_type[rhs] = relation.name
1353
1429
  elif relation in builtins.date_builtins:
1354
- if relation in (builtins.date_add, builtins.date_subtract, builtins.datetime_add,
1355
- builtins.datetime_subtract):
1430
+ if relation in {builtins.date_add, builtins.date_subtract, builtins.datetime_add,
1431
+ builtins.datetime_subtract}:
1356
1432
  assert len(args) == 3, f"Expected 3 args for {relation}, got {len(args)}: {args}"
1357
1433
  assert isinstance(rhs, ir.Var), f"Period variable must be `ir.Var`, got: {rhs}"
1358
1434
  period = date_period_var_type[rhs]
@@ -1361,10 +1437,10 @@ class ModelToSQL:
1361
1437
  left = self._var_to_expr(lhs, reference, resolve_builtin_var, var_to_construct)
1362
1438
 
1363
1439
  if self._is_duck_db:
1364
- op = "+" if relation in (builtins.date_add, builtins.datetime_add) else "-"
1440
+ op = "+" if relation in {builtins.date_add, builtins.datetime_add} else "-"
1365
1441
  expr = f"({left} {op} {period_val} * interval 1 {period})"
1366
1442
  else:
1367
- sign = 1 if relation in (builtins.date_add, builtins.datetime_add) else -1
1443
+ sign = 1 if relation in {builtins.date_add, builtins.datetime_add} else -1
1368
1444
  expr = f"dateadd({period}, ({sign} * {period_val}), {left})"
1369
1445
 
1370
1446
  result_var = args[2]
@@ -1439,23 +1515,19 @@ class ModelToSQL:
1439
1515
  rel_name = relation.name
1440
1516
  left = self._var_to_expr(lhs, reference, resolve_builtin_var, var_to_construct)
1441
1517
  if relation in builtins.math_unary_builtins:
1442
- if rel_name in {builtins.acot.name, builtins.erf.name, builtins.erfinv.name}:
1443
- # TODO: implement acot using atan or atan2, it needs to be handled as a union of two cases
1444
- raise Exception(f"The function {rel_name} is not supported.")
1445
- else:
1446
- method = "ln" if rel_name == builtins.natural_log.name else rel_name
1447
- sub_expr = left
1448
- if rel_name == builtins.factorial.name and self._is_duck_db:
1449
- # Factorial requires an integer operand in DuckDB
1450
- sub_expr = f"{left}::INTEGER"
1451
- elif rel_name == builtins.log10.name:
1452
- # log10 is not supported, so we use log with base 10
1453
- sub_expr = f"10, {left}"
1454
- method = "log"
1455
- expr = f"{method}({sub_expr})"
1456
- elif rel_name in (builtins.minimum.name, builtins.maximum.name, builtins.trunc_div.name,
1518
+ method = "ln" if rel_name == builtins.natural_log.name else rel_name
1519
+ sub_expr = left
1520
+ if rel_name == builtins.factorial.name and self._is_duck_db:
1521
+ # Factorial requires an integer operand in DuckDB
1522
+ sub_expr = f"{left}::INTEGER"
1523
+ elif rel_name == builtins.log10.name:
1524
+ # log10 is not supported, so we use log with base 10
1525
+ sub_expr = f"10, {left}"
1526
+ method = "log"
1527
+ expr = f"{method}({sub_expr})"
1528
+ elif rel_name in {builtins.minimum.name, builtins.maximum.name, builtins.trunc_div.name,
1457
1529
  builtins.power.name, builtins.mod.name, builtins.pow.name,
1458
- builtins.log.name):
1530
+ builtins.log.name}:
1459
1531
  assert len(args) == 3, f"Expected 3 args for {relation}, got {len(args)}: {args}"
1460
1532
 
1461
1533
  result_var = args[2]
@@ -1480,7 +1552,7 @@ class ModelToSQL:
1480
1552
  f"but got `{type(result_var).__name__}`: {result_var}"
1481
1553
  )
1482
1554
  builtin_vars[result_var] = expr
1483
- elif relation in (builtins.parse_int64, builtins.parse_int128) and isinstance(rhs, ir.Var):
1555
+ elif relation in {builtins.parse_int64, builtins.parse_int128} and isinstance(rhs, ir.Var):
1484
1556
  builtin_vars[rhs] = self._var_to_expr(lhs, reference, resolve_builtin_var, var_to_construct, False)
1485
1557
  elif helpers.is_from_cast(lookup) and isinstance(rhs, ir.Var):
1486
1558
  # For the `from cast` relations we keep the raw var, and we will ground it later.
@@ -2071,21 +2143,27 @@ class ModelToSQL:
2071
2143
  return f"DECIMAL({base_type.precision},{base_type.scale})"
2072
2144
  raise Exception(f"Unknown built-in type: {t}")
2073
2145
 
2074
- def _get_relations(self, model: ir.Model) -> list[ir.Relation]:
2146
+ def _get_relations(self, model: ir.Model) -> Tuple[list[ir.Relation], list[ir.Relation]]:
2147
+ rw = ReadWriteVisitor()
2148
+ model.accept(rw)
2149
+
2150
+ root = cast(ir.Logical, model.root)
2151
+
2152
+ # For query compilation exclude read-only tables because we do not need to declare `CREATE TABLE` statements
2153
+ used_relations = rw.writes(root) if self._query_compilation else rw.writes(root) | rw.reads(root)
2154
+
2075
2155
  # Filter only relations that require table creation
2076
- relations = [
2077
- r for r in model.relations
2156
+ table_relations = [
2157
+ r for r in used_relations
2078
2158
  if self._is_table_creation_required(r)
2079
2159
  ]
2080
2160
 
2081
- # Optionally exclude read-only tables
2082
- if self._query_compilation:
2083
- rw = ReadWriteVisitor()
2084
- model.accept(rw)
2085
- writable = rw.writes(cast(ir.Logical, model.root))
2086
- relations = [r for r in relations if r in writable]
2161
+ used_builtins = [
2162
+ r for r in rw.reads(root)
2163
+ if builtins.is_builtin(r)
2164
+ ]
2087
2165
 
2088
- return relations
2166
+ return table_relations, used_builtins
2089
2167
 
2090
2168
  def _is_table_creation_required(self, r: ir.Relation) -> bool:
2091
2169
  """
@@ -2161,7 +2239,7 @@ class ModelToSQL:
2161
2239
 
2162
2240
  def _var_name(self, relation_id: int, arg: Union[ir.Var, ir.Field]):
2163
2241
  name = helpers.sanitize(self.relation_arg_name_cache.get_name((relation_id, arg.id), arg.name))
2164
- return f'"{name}"' if name.lower() in ("any", "order") else name
2242
+ return f'"{name}"' if name.lower() in {"any", "order"} else name
2165
2243
 
2166
2244
  def _register_relation_args(self, relations: list[ir.Relation]):
2167
2245
  """
@@ -2308,6 +2386,7 @@ class ModelToSQL:
2308
2386
  3. Other statements except SELECT queries
2309
2387
  4. SELECT queries
2310
2388
  """
2389
+ udfs = []
2311
2390
  create_tables = []
2312
2391
  need_sort: dict[str, list[Union[sql.Insert, sql.CreateView, sql.CreateDynamicTable]]] = defaultdict(list)
2313
2392
  updates = []
@@ -2327,12 +2406,14 @@ class ModelToSQL:
2327
2406
  updates.append(statement)
2328
2407
  elif isinstance(statement, sql.Select):
2329
2408
  selects.append(statement)
2409
+ elif isinstance(statement, sql.CreateFunction):
2410
+ udfs.append(statement)
2330
2411
  else:
2331
2412
  miscellaneous_statements.append(statement)
2332
2413
 
2333
2414
  sorted_statements = self._sort_statements_dependency_graph(need_sort)
2334
2415
 
2335
- return create_tables + sorted_statements + updates + miscellaneous_statements + selects
2416
+ return udfs + create_tables + sorted_statements + updates + miscellaneous_statements + selects
2336
2417
 
2337
2418
  @staticmethod
2338
2419
  def _sort_statements_dependency_graph(statements: dict[str, list[Union[sql.Insert, sql.CreateView, sql.CreateDynamicTable]]]) -> list[sql.Insert]:
@@ -1,8 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import duckdb
4
+ import math
4
5
  from pandas import DataFrame
5
6
  from typing import Any, Union, Literal
7
+ from scipy.special import erfinv as special_erfinv
6
8
 
7
9
  from relationalai.semantics.sql import Compiler
8
10
  from relationalai.semantics.sql.executor.result_helpers import format_duckdb_columns
@@ -18,10 +20,29 @@ class DuckDBExecutor(e.Executor):
18
20
  """ Execute the SQL query directly. """
19
21
  if format != "pandas":
20
22
  raise ValueError(f"Unsupported format: {format}")
23
+
21
24
  connection = duckdb.connect()
25
+
26
+ # Register scalar functions
27
+ connection.create_function("erf", self.erf)
28
+ connection.create_function("acot", self.acot)
29
+ connection.create_function("erfinv", self.erfinv)
30
+
22
31
  try:
23
32
  sql, _ = self.compiler.compile(model, {"is_duck_db": True})
24
33
  arrow_table = connection.query(sql).fetch_arrow_table()
25
34
  return format_duckdb_columns(arrow_table.to_pandas(), arrow_table.schema)
26
35
  finally:
27
36
  connection.close()
37
+
38
+ @staticmethod
39
+ def erf(x: float) -> float:
40
+ return math.erf(x)
41
+
42
+ @staticmethod
43
+ def erfinv(x: float) -> float:
44
+ return special_erfinv(x)
45
+
46
+ @staticmethod
47
+ def acot(x: float) -> float:
48
+ return math.atan(1 / x) if x != 0 else math.copysign(math.pi / 2, x)
@@ -25,7 +25,6 @@ class Denormalize(c.Pass):
25
25
 
26
26
  @dataclass
27
27
  class OldDenormalize(visitor.Rewriter):
28
- # TODO: use the new Rewriter when available.
29
28
 
30
29
  denormalized: dict[ir.Relation, ir.Relation] = field(default_factory=dict, init=False, hash=False, compare=False)
31
30
 
@@ -168,13 +167,12 @@ class OldDenormalize(visitor.Rewriter):
168
167
  """ Denormalize the relations that can be denormalized.
169
168
 
170
169
  Group together relations that are keyed by the same "entity". This method defines
171
- entities as being types that haver a unary relation containing only that type. All
170
+ entities as being types that have a unary relation containing only that type. All
172
171
  relations whose first argument is this type are grouped together.
173
172
 
174
173
  Returns a tuple with 2 elements:
175
- 1. the new set of relations after denormalization
176
- 2. a dict from relations that were denormalized away to the new relation that took
177
- its place.
174
+ 1. The new set of relations after denormalization
175
+ 2. A dict from relations that were denormalized away to the new relation that took its place.
178
176
  """
179
177
  new_relations = ordered_set()
180
178
  denormalized: dict[ir.Relation, ir.Relation] = dict()
@@ -184,7 +182,7 @@ class OldDenormalize(visitor.Rewriter):
184
182
  entity_relations: dict[ir.Type, ir.Relation] = dict()
185
183
 
186
184
  for r in relations:
187
- if len(r.fields) == 1 and not types.is_builtin(r.fields[0].type):
185
+ if len(r.fields) == 1 and not types.is_builtin(r.fields[0].type) and not r.name == 'Error':
188
186
  e = r.fields[0].type
189
187
  entity_types.add(e)
190
188
  entity_relations[e] = r
@@ -49,9 +49,29 @@ class RecursiveUnion(c.Pass):
49
49
  new_body = [logical for logical in root_logical.body if logical not in recursive_logicals]
50
50
 
51
51
  # Step 4: Add unions for each recursive group
52
- for group in recursive_groups.values():
53
- if group:
54
- new_body.append(f.union(list(group)))
52
+ for rel_id, logical_group in recursive_groups.items():
53
+ split_group = ordered_set()
54
+
55
+ for logical in logical_group:
56
+ # Count total ir.Update tasks in this logical
57
+ update_count = sum(isinstance(t, ir.Update) for t in logical.body)
58
+
59
+ # If there's only one, keep the original logical as-is
60
+ if update_count == 1:
61
+ split_group.add(logical)
62
+ continue
63
+
64
+ # Otherwise, keep only updates relevant to this relation (and non-update tasks)
65
+ filtered_body = [
66
+ t for t in logical.body
67
+ if not isinstance(t, ir.Update) or t.relation.id == rel_id
68
+ ]
69
+
70
+ if filtered_body:
71
+ split_group.add(f.logical(filtered_body))
72
+
73
+ if split_group:
74
+ new_body.append(f.union(list(split_group)))
55
75
 
56
76
  return model.reconstruct(model.engines, model.relations, model.types, f.logical(new_body), model.annotations)
57
77
 
@@ -55,6 +55,17 @@ class CreateView(Node):
55
55
  name: str
56
56
  query: Union[list[Select], CTE]
57
57
 
58
@dataclass(frozen=True)
class CreateFunction(Node):
    """SQL node describing a user-defined function definition."""

    # Function name.
    name: str
    # Parameter columns (name and SQL type) of the function.
    inputs: list[Column]
    # SQL type returned by the function.
    return_type: str
    # Source code of the function body.
    body: str
    # Implementation language of ``body``.
    language: str = "PYTHON"
    # Runtime version requested for the function.
    runtime_version: str = "3.11"
    # Name of the callable inside ``body`` that is invoked.
    handler: str = "compute"
    # Optional package entries required by ``body``; ``None`` means no package list.
    packages: Optional[list[str]] = None
68
+
58
69
  @dataclass(frozen=True)
59
70
  class Insert(Node):
60
71
  table: str
@@ -286,6 +297,22 @@ class Printer(BasePrinter):
286
297
  elif isinstance(node, CreateView):
287
298
  self._print(f"CREATE VIEW {self._get_table_name(node.name)} AS ")
288
299
  self._print_query(indent, node, "UNION")
300
+ elif isinstance(node, CreateFunction):
301
+ self._print(f"CREATE OR REPLACE FUNCTION {self._get_table_name(node.name)} (")
302
+ self._join(node.inputs)
303
+ self._print_nl(")")
304
+ self._print_nl(f"RETURNS {node.return_type}")
305
+ self._print_nl(f"LANGUAGE {node.language}")
306
+ self._print_nl(f"RUNTIME_VERSION = '{node.runtime_version}'")
307
+ self._print_nl(f"HANDLER = '{node.handler}'")
308
+ if node.packages:
309
+ self._print("PACKAGES = (")
310
+ self._join(node.packages)
311
+ self._print_nl(")")
312
+ self._print_nl("AS ")
313
+ self._print_nl("$$")
314
+ self._print_nl(node.body)
315
+ self._print("$$;")
289
316
  elif isinstance(node, Insert):
290
317
  self._print(f"INSERT INTO {self._get_table_name(node.table)} ")
291
318
  if len(node.columns) > 0:
@@ -3,7 +3,7 @@ from typing import Any
3
3
 
4
4
  from relationalai.semantics.internal import internal as i
5
5
  from .std import _Date, _DateTime, _Number, _String, _Integer, _make_expr
6
- from . import datetime, math, strings, decimals, integers, floats, pragmas, constraints
6
+ from . import datetime, math, strings, decimals, integers, floats, pragmas, constraints, re
7
7
 
8
8
  def range(*args: _Integer) -> i.Expression:
9
9
  # supports range(stop), range(start, stop), range(start, stop, step)
@@ -43,6 +43,7 @@ __all__ = [
43
43
  "datetime",
44
44
  "math",
45
45
  "strings",
46
+ "re",
46
47
  "decimals",
47
48
  "integers",
48
49
  "floats",
@@ -153,6 +153,10 @@ class datetime:
153
153
  std.cast_to_int64(day), std.cast_to_int64(hour), std.cast_to_int64(minute),
154
154
  std.cast_to_int64(second), std.cast_to_int64(millisecond), tz, b.DateTime.ref("res"))
155
155
 
156
@classmethod
def now(cls) -> b.Expression:
    """Return a ``datetime_now`` expression whose result is a ``DateTime`` reference named ``res``."""
    result_ref = b.DateTime.ref("res")
    return _make_expr("datetime_now", result_ref)
159
+
156
160
  @classmethod
157
161
  def year(cls, datetime: _DateTime, tz: dt.tzinfo|_String|None = None) -> b.Expression:
158
162
  tz = _extract_tz(datetime, tz)
@@ -0,0 +1,83 @@
1
+ from __future__ import annotations
2
+
3
+ from relationalai.semantics.internal import internal as i
4
+ from relationalai.semantics.metamodel.util import OrderedSet
5
+ from .std import _Integer, _String, _make_expr
6
+ from typing import Literal, Any
7
+ from .. import std
8
+
9
+
10
def escape(regex: _String) -> i.Expression:
    """Return an ``escape_regex_metachars`` expression over ``regex`` with a fresh string result reference."""
    escaped_ref = i.String.ref()
    return _make_expr("escape_regex_metachars", regex, escaped_ref)
12
+
13
class Match(i.Producer):
    """Producer describing a regular-expression match of ``regex`` against ``string``.

    The match is expressed through the ``regex_match_all`` builtin. ``pos`` is
    incremented by one before it is handed to that builtin, and ``start()``
    subtracts one from the offset the builtin reports.
    NOTE(review): this looks like a 0-based API over a 1-based builtin — confirm.
    """

    def __init__(self, regex: _String, string: _String, pos: _Integer = 0, _type: Literal["search", "fullmatch", "match"] = "match"):
        """Create the match expression.

        Args:
            regex: Pattern to match.
            string: Text to match against.
            pos: Position in ``string`` at which matching starts.
            _type: Which matching mode to build: ``"match"`` or ``"fullmatch"``.

        Raises:
            NotImplementedError: If ``_type`` is ``"search"``.
            ValueError: If ``_type`` is not a recognized mode.
        """
        super().__init__(i.find_model([regex, string, pos]))
        self.regex = regex
        self.string = string
        self.pos = pos

        if _type == "search":
            raise NotImplementedError("`search` is not implemented")
        if _type not in ("match", "fullmatch"):
            # Fail at construction instead of leaving `_expr`, `_offset` and
            # `_full_match` unset, which would only surface later as an
            # AttributeError far from the actual mistake.
            raise ValueError(f"Unsupported match type: {_type!r}")

        # Both supported modes are built from the same underlying expression.
        found = _regex_match_all(self.regex, self.string, std.cast_to_int64(self.pos + 1))
        self._offset, self._full_match = found._arg_ref(2), found._arg_ref(3)
        if _type == "match":
            self._expr = found
        else:
            # fullmatch: additionally require the matched text to equal the
            # substring of `string` taken from `pos` through its length.
            self._expr = self._full_match == std.strings.substring(self.string, std.cast_to_int64(self.pos), std.strings.len(self.string))

    def group(self, index: _Integer = 0) -> i.Producer:
        """Return the whole match for index 0, otherwise the capture group at ``index``."""
        if index == 0:
            return self._full_match
        else:
            return _make_expr("capture_group_by_index", self.regex, self.string, std.cast_to_int64(self.pos + 1), std.cast_to_int64(index), i.String.ref("res"))

    def group_by_name(self, name: _String) -> i.Producer:
        """Return the capture group called ``name``."""
        return _make_expr("capture_group_by_name", self.regex, self.string, std.cast_to_int64(self.pos + 1), name, i.String.ref("res"))

    def start(self) -> i.Expression:
        """Return the reported match offset minus one."""
        return self._offset - 1

    def end(self) -> i.Expression:
        """Return ``len(group(0)) + start() - 1``.

        NOTE(review): this is the index of the last matched character
        (inclusive), whereas Python's ``re.Match.end()`` is exclusive —
        confirm this is intended.
        """
        return std.strings.len(self.group(0)) + self.start() - 1

    def span(self) -> tuple[i.Producer, i.Producer]:
        """Return the ``(start(), end())`` pair."""
        return self.start(), self.end()

    def _to_keys(self) -> OrderedSet[Any]:
        """Return the keys found for the underlying match expression."""
        return i.find_keys(self._expr)

    def _compile_lookup(self, compiler:i.Compiler, ctx:i.CompilerContext):
        """Look up the operands first, then return the lookup of the match expression."""
        compiler.lookup(self.regex, ctx)
        compiler.lookup(self.string, ctx)
        compiler.lookup(self.pos, ctx)
        return compiler.lookup(self._expr, ctx)

    def __getattr__(self, name: str) -> Any:
        # Use plain object attribute lookup so unknown names raise
        # AttributeError rather than going through any inherited __getattr__.
        return object.__getattribute__(self, name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Store attributes with plain object semantics, bypassing any
        # inherited __setattr__.
        object.__setattr__(self, name, value)
63
+
64
+
65
def match(regex: _String, string: _String, pos: _Integer = 0) -> Match:
    """Return a ``Match`` of ``regex`` against ``string`` in the default ``"match"`` mode.

    Args:
        regex: Pattern to match.
        string: Text to match against.
        pos: Position in ``string`` at which matching starts. Defaults to 0,
            which is what this function always used before the parameter
            existed; it mirrors the ``pos`` parameter of ``search`` and
            ``fullmatch``.
    """
    return Match(regex, string, pos)
67
+
68
def search(regex: _String, string: _String, pos: _Integer = 0) -> Match:
    """Build a ``Match`` in ``"search"`` mode.

    ``Match`` currently raises ``NotImplementedError`` for this mode.
    """
    return Match(regex, string, pos=pos, _type="search")
70
+
71
def fullmatch(regex: _String, string: _String, pos: _Integer = 0) -> Match:
    """Build a ``Match`` in ``"fullmatch"`` mode, starting at ``pos``."""
    return Match(regex, string, pos=pos, _type="fullmatch")
73
+
74
def findall(regex: _String, string: _String) -> tuple[i.Producer, i.Producer]:
    """Return ``(rank, match)`` for the matches of ``regex`` in ``string``.

    ``rank`` orders the matches ascending by their reported offset and then
    by the matched text.
    """
    matches = _regex_match_all(regex, string)
    offset = matches._arg_ref(2)
    text = matches._arg_ref(3)
    return i.rank(i.asc(offset, text)), text
79
+
80
def _regex_match_all(regex: _String, string: _String, pos: _Integer|None = None) -> i.Expression:
    """Build the ``regex_match_all`` expression.

    When ``pos`` is ``None`` a fresh ``Int64`` reference is used in its place,
    leaving the position unconstrained by the caller.
    """
    position = i.Int64.ref() if pos is None else pos
    return _make_expr("regex_match_all", regex, string, position, i.String.ref())
@@ -15,7 +15,7 @@ def concat(s0: _String, s1: _String, *args: _String) -> b.Expression:
15
15
  return res
16
16
 
17
17
  def len(s: _String) -> b.Expression:
18
- return _make_expr("num_chars", s, b.Int128.ref("res"))
18
+ return _make_expr("num_chars", s, b.Int64.ref("res"))
19
19
 
20
20
  def startswith(s0: _String, s1: _String) -> b.Expression:
21
21
  return _make_expr("starts_with", s0, s1)