relationalai 0.11.2__py3-none-any.whl → 0.11.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- relationalai/clients/snowflake.py +44 -15
- relationalai/clients/types.py +1 -0
- relationalai/clients/use_index_poller.py +446 -178
- relationalai/early_access/builder/std/__init__.py +1 -1
- relationalai/early_access/dsl/bindings/csv.py +4 -4
- relationalai/semantics/internal/internal.py +22 -4
- relationalai/semantics/lqp/executor.py +69 -18
- relationalai/semantics/lqp/intrinsics.py +23 -0
- relationalai/semantics/lqp/model2lqp.py +16 -6
- relationalai/semantics/lqp/passes.py +3 -4
- relationalai/semantics/lqp/primitives.py +38 -14
- relationalai/semantics/metamodel/builtins.py +152 -11
- relationalai/semantics/metamodel/factory.py +3 -2
- relationalai/semantics/metamodel/helpers.py +78 -2
- relationalai/semantics/reasoners/graph/core.py +343 -40
- relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
- relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
- relationalai/semantics/rel/compiler.py +5 -17
- relationalai/semantics/rel/executor.py +2 -2
- relationalai/semantics/rel/rel.py +6 -0
- relationalai/semantics/rel/rel_utils.py +37 -1
- relationalai/semantics/rel/rewrite/extract_common.py +153 -242
- relationalai/semantics/sql/compiler.py +540 -202
- relationalai/semantics/sql/executor/duck_db.py +21 -0
- relationalai/semantics/sql/executor/result_helpers.py +7 -0
- relationalai/semantics/sql/executor/snowflake.py +9 -2
- relationalai/semantics/sql/rewrite/denormalize.py +4 -6
- relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
- relationalai/semantics/sql/sql.py +120 -46
- relationalai/semantics/std/__init__.py +9 -4
- relationalai/semantics/std/datetime.py +363 -0
- relationalai/semantics/std/math.py +77 -0
- relationalai/semantics/std/re.py +83 -0
- relationalai/semantics/std/strings.py +1 -1
- relationalai/tools/cli_controls.py +445 -60
- relationalai/util/format.py +78 -1
- {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/METADATA +3 -2
- {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/RECORD +41 -39
- relationalai/semantics/std/dates.py +0 -213
- {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/WHEEL +0 -0
- {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/entry_points.txt +0 -0
- {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -88,12 +88,36 @@ abs = f.relation(
|
|
|
88
88
|
|
|
89
89
|
natural_log = f.relation(
    "natural_log",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    # one concrete overload per numeric input type, always producing a Float
    overloads=[
        f.relation("natural_log", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

log10 = f.relation(
    "log10",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    # one concrete overload per numeric input type, always producing a Float
    overloads=[
        f.relation("log10", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

log = f.relation(
    "log",
    [f.input_field("a", types.Number), f.input_field("b", types.Number), f.field("c", types.Float)],
    # both inputs of each overload share the same numeric type; the result is a Float
    overloads=[
        f.relation("log", [f.input_field("a", t), f.input_field("b", t), f.field("c", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)
|
|
@@ -274,9 +298,100 @@ asinh = f.relation(
|
|
|
274
298
|
],
|
|
275
299
|
)
|
|
276
300
|
|
|
301
|
+
tan = f.relation(
    "tan",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    overloads=[
        f.relation("tan", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

tanh = f.relation(
    "tanh",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    overloads=[
        f.relation("tanh", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

atan = f.relation(
    "atan",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    overloads=[
        f.relation("atan", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

atanh = f.relation(
    "atanh",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    overloads=[
        f.relation("atanh", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

cot = f.relation(
    "cot",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    # only a Float overload: other inputs are converted to Float to avoid NaN results
    overloads=[
        f.relation("cot", [f.input_field("a", types.Float), f.field("b", types.Float)])
    ],
)

acot = f.relation(
    "acot",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    # only a Float overload: other inputs are converted to Float to avoid NaN results
    overloads=[
        f.relation("acot", [f.input_field("a", types.Float), f.field("b", types.Float)])
    ],
)

exp = f.relation(
    "exp",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    overloads=[
        f.relation("exp", [f.input_field("a", t), f.field("b", types.Float)])
        for t in (types.Int64, types.Int128, types.Float, types.GenericDecimal)
    ],
)

erf = f.relation(
    "erf",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    # only a Float overload: other inputs are converted to Float to avoid NaN results
    overloads=[
        f.relation("erf", [f.input_field("a", types.Float), f.field("b", types.Float)]),
    ],
)

erfinv = f.relation(
    "erfinv",
    [f.input_field("a", types.Number), f.field("b", types.Float)],
    # only a Float overload: other inputs are converted to Float to avoid NaN results
    overloads=[
        f.relation("erfinv", [f.input_field("a", types.Float), f.field("b", types.Float)]),
    ],
)
|
|
390
|
+
|
|
391
|
+
|
|
277
392
|
# Strings
concat = f.relation(
    "concat",
    [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.String)],
)
num_chars = f.relation(
    "num_chars",
    [f.input_field("a", types.String), f.field("b", types.Int64)],
)
starts_with = f.relation(
    "starts_with",
    [f.input_field("a", types.String), f.input_field("b", types.String)],
)
ends_with = f.relation(
    "ends_with",
    [f.input_field("a", types.String), f.input_field("b", types.String)],
)
contains = f.relation(
    "contains",
    [f.input_field("a", types.String), f.input_field("b", types.String)],
)
|
|
@@ -291,22 +406,41 @@ replace = f.relation("replace", [f.input_field("a", types.String), f.input_field
|
|
|
291
406
|
split = f.relation("split", [
    f.input_field("a", types.String),
    f.input_field("b", types.String),
    f.field("c", types.Int64),
    f.field("d", types.String),
])
# split_part has the same signature as split but is its own builtin,
# because the SQL emitter compiles it differently
split_part = f.relation("split_part", [
    f.input_field("a", types.String),
    f.input_field("b", types.String),
    f.field("c", types.Int64),
    f.field("d", types.String),
])

# regex
regex_match = f.relation("regex_match", [
    f.input_field("a", types.String),
    f.input_field("b", types.String),
])
regex_match_all = f.relation("regex_match_all", [
    f.input_field("a", types.String),
    f.input_field("b", types.String),
    f.input_field("c", types.Int64),
    f.field("d", types.String),
])
capture_group_by_index = f.relation("capture_group_by_index", [
    f.input_field("a", types.String),
    f.input_field("b", types.String),
    f.input_field("c", types.Int64),
    f.input_field("d", types.Int64),
    f.field("e", types.String),
])
capture_group_by_name = f.relation("capture_group_by_name", [
    f.input_field("a", types.String),
    f.input_field("b", types.String),
    f.input_field("c", types.Int64),
    f.input_field("d", types.String),
    f.field("e", types.String),
])
escape_regex_metachars = f.relation("escape_regex_metachars", [
    f.input_field("a", types.String),
    f.field("b", types.String),
])
|
|
295
416
|
|
|
296
417
|
# Dates
date_format = f.relation("date_format", [
    f.input_field("a", types.Date), f.input_field("b", types.String), f.field("c", types.String),
])
datetime_format = f.relation("datetime_format", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String),
    f.input_field("c", types.String), f.field("d", types.String),
])

# Date component extraction: Date -> Int64
date_year = f.relation("date_year", [f.input_field("a", types.Date), f.field("b", types.Int64)])
date_quarter = f.relation("date_quarter", [f.input_field("a", types.Date), f.field("b", types.Int64)])
date_month = f.relation("date_month", [f.input_field("a", types.Date), f.field("b", types.Int64)])
date_week = f.relation("date_week", [f.input_field("a", types.Date), f.field("b", types.Int64)])
date_day = f.relation("date_day", [f.input_field("a", types.Date), f.field("b", types.Int64)])
date_dayofyear = f.relation("date_dayofyear", [f.input_field("a", types.Date), f.field("b", types.Int64)])
date_weekday = f.relation("date_weekday", [f.input_field("a", types.Date), f.field("b", types.Int64)])

# Date / datetime arithmetic
date_add = f.relation("date_add", [
    f.input_field("a", types.Date), f.input_field("b", types.Int64), f.field("c", types.Date),
])
dates_period_days = f.relation("dates_period_days", [
    f.input_field("a", types.Date), f.input_field("b", types.Date), f.field("c", types.Int64),
])
datetimes_period_milliseconds = f.relation("datetimes_period_milliseconds", [
    f.input_field("a", types.DateTime), f.input_field("b", types.DateTime), f.field("c", types.Int64),
])
date_subtract = f.relation("date_subtract", [
    f.input_field("a", types.Date), f.input_field("b", types.Int64), f.field("c", types.Date),
])
datetime_now = f.relation("datetime_now", [f.field("a", types.DateTime)])
datetime_add = f.relation("datetime_add", [
    f.input_field("a", types.DateTime), f.input_field("b", types.Int64), f.field("c", types.DateTime),
])
datetime_subtract = f.relation("datetime_subtract", [
    f.input_field("a", types.DateTime), f.input_field("b", types.Int64), f.field("c", types.DateTime),
])

# Datetime component extraction: (DateTime, String) -> Int64
datetime_year = f.relation("datetime_year", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_quarter = f.relation("datetime_quarter", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_month = f.relation("datetime_month", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_week = f.relation("datetime_week", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_day = f.relation("datetime_day", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_dayofyear = f.relation("datetime_dayofyear", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_hour = f.relation("datetime_hour", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
datetime_minute = f.relation("datetime_minute", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
# NOTE(review): unlike its siblings, datetime_second has no String input "b" and names its
# output "c" — presumably because seconds do not depend on the timezone; confirm intent.
datetime_second = f.relation("datetime_second", [
    f.input_field("a", types.DateTime), f.field("c", types.Int64),
])
datetime_weekday = f.relation("datetime_weekday", [
    f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64),
])
|
|
310
444
|
|
|
311
445
|
# Other
|
|
312
446
|
range = f.relation("range", [
|
|
@@ -319,6 +453,7 @@ range = f.relation("range", [
|
|
|
319
453
|
hash = f.relation("hash", [
    f.input_field("args", types.AnyList),
    f.field("hash", types.Hash),
])

# conversions between Hash values and their String form
uuid_to_string = f.relation("uuid_to_string", [
    f.input_field("a", types.Hash),
    f.field("b", types.String),
])
parse_uuid = f.relation("parse_uuid", [
    f.input_field("a", types.String),
    f.field("b", types.Hash),
])

# Source code in a backend-native language ("lang") to be attached verbatim to the
# transaction, for backends that understand that language
raw_source = f.relation("raw_source", [
    f.input_field("lang", types.String),
    f.input_field("source", types.String),
])
|
|
@@ -462,6 +597,8 @@ parse_int64 = f.relation("parse_int64", [f.input_field("a", types.String), f.fie
|
|
|
462
597
|
parse_int128 = f.relation("parse_int128", [
    f.input_field("a", types.String), f.field("b", types.Int128),
])
parse_float = f.relation("parse_float", [
    f.input_field("a", types.String), f.field("b", types.Float),
])

# period constructors: Int64 -> Int64, ordered from finest to coarsest unit
nanosecond = f.relation("nanosecond", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
microsecond = f.relation("microsecond", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
millisecond = f.relation("millisecond", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
second = f.relation("second", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
minute = f.relation("minute", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
|
|
@@ -484,7 +621,6 @@ cast = f.relation(
|
|
|
484
621
|
# Direct date/datetime construction from components, with less overhead
construct_date = f.relation("construct_date", [
    f.input_field("year", types.Int64),
    f.input_field("month", types.Int64),
    f.input_field("day", types.Int64),
    f.field("date", types.Date),
])
construct_date_from_datetime = f.relation("construct_date_from_datetime", [
    f.input_field("datetime", types.DateTime),
    f.input_field("timezone", types.String),
    f.field("date", types.Date),
])
construct_datetime_ms_tz = f.relation("construct_datetime_ms_tz", [
    f.input_field("year", types.Int64),
    f.input_field("month", types.Int64),
    f.input_field("day", types.Int64),
    f.input_field("hour", types.Int64),
    f.input_field("minute", types.Int64),
    f.input_field("second", types.Int64),
    f.input_field("milliseconds", types.Int64),
    f.input_field("timezone", types.String),
    f.field("datetime", types.DateTime),
])
|
|
489
625
|
|
|
490
626
|
# Solver helpers
|
|
@@ -545,18 +681,23 @@ builtin_relations_by_name = dict((r.name, r) for r in builtin_relations)
|
|
|
545
681
|
|
|
546
682
|
# Groupings of builtin relations, consumed elsewhere to classify builtins by family.

string_binary_builtins = [num_chars, starts_with, ends_with, contains, like_match, lower, upper, strip, regex_match]

date_builtins = [date_year, date_quarter, date_month, date_week, date_day, date_dayofyear, date_add, date_subtract,
                 dates_period_days, datetime_add, datetime_subtract, datetimes_period_milliseconds, datetime_year,
                 datetime_quarter, datetime_month, datetime_week, datetime_day, datetime_dayofyear, datetime_hour,
                 datetime_minute, datetime_second, date_weekday, datetime_weekday]

date_periods = [year, month, week, day, hour, minute, second, millisecond, microsecond, nanosecond]

# Every math function is listed as its generic relation followed by its typed overloads.
# `ceil` was previously present only through `*ceil.overloads`; its base relation is now
# included as well, matching every other entry (e.g. `floor, *floor.overloads`).
math_unary_builtins = [abs, *abs.overloads, sqrt, *sqrt.overloads,
                       natural_log, *natural_log.overloads, log10, *log10.overloads,
                       cbrt, *cbrt.overloads, factorial, *factorial.overloads, cos, *cos.overloads,
                       cosh, *cosh.overloads, acos, *acos.overloads, acosh, *acosh.overloads, sin, *sin.overloads,
                       sinh, *sinh.overloads, asin, *asin.overloads, asinh, *asinh.overloads, tan, *tan.overloads,
                       tanh, *tanh.overloads, atan, *atan.overloads, atanh, *atanh.overloads, ceil, *ceil.overloads,
                       cot, *cot.overloads, acot, *acot.overloads, floor, *floor.overloads, exp, *exp.overloads,
                       erf, *erf.overloads, erfinv, *erfinv.overloads]

math_builtins = [*math_unary_builtins, maximum, *maximum.overloads, minimum, *minimum.overloads, mod, *mod.overloads,
                 pow, *pow.overloads, power, *power.overloads, log, *log.overloads, trunc_div, *trunc_div.overloads]

pragma_builtins = [rule_reasoner_sem_vo, rule_reasoner_phys_vo]
|
|
@@ -185,10 +185,11 @@ def lit(value: Any) -> ir.Value:
|
|
|
185
185
|
return ir.Literal(types.Bool, value)
|
|
186
186
|
elif isinstance(value, decimal.Decimal):
|
|
187
187
|
return ir.Literal(types.Decimal, value)
|
|
188
|
-
|
|
189
|
-
return ir.Literal(types.Date, value)
|
|
188
|
+
# datetime.datetime is a subclass of datetime.date, so check it first
|
|
190
189
|
elif isinstance(value, datetime.datetime):
|
|
191
190
|
return ir.Literal(types.DateTime, value)
|
|
191
|
+
elif isinstance(value, datetime.date):
|
|
192
|
+
return ir.Literal(types.Date, value)
|
|
192
193
|
elif isinstance(value, list):
|
|
193
194
|
return tuple([lit(v) for v in value])
|
|
194
195
|
else:
|
|
@@ -4,7 +4,8 @@ Helpers to analyze the metamodel IR.
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
import re
|
|
7
|
-
from
|
|
7
|
+
from dataclasses import fields
|
|
8
|
+
from typing import cast, Tuple, Iterable, Optional, TypeVar
|
|
8
9
|
from relationalai.semantics.metamodel import ir, visitor, builtins, types, factory as f
|
|
9
10
|
from relationalai.semantics.metamodel.util import NameCache, OrderedSet, FrozenOrderedSet, flatten_tuple, ordered_set
|
|
10
11
|
|
|
@@ -263,7 +264,7 @@ def extract(task: ir.Task, body: OrderedSet[ir.Task], exposed_vars: list[ir.Var]
|
|
|
263
264
|
body.add(f.derive(connection, exposed_vars))
|
|
264
265
|
|
|
265
266
|
# extract the body
|
|
266
|
-
ctx.top_level.append(ir.Logical(task.engine, tuple(), tuple(body)))
|
|
267
|
+
ctx.top_level.append(clone_task(ir.Logical(task.engine, tuple(), tuple(body))))
|
|
267
268
|
|
|
268
269
|
return connection
|
|
269
270
|
|
|
@@ -282,3 +283,78 @@ def create_task_name(name_cache: NameCache, task: ir.Task, prefix: Optional[str]
|
|
|
282
283
|
""" Helper to generate consistent names for tasks extracted from a logical. """
|
|
283
284
|
prefix = prefix if prefix else f"_{task.kind}"
|
|
284
285
|
return name_cache.get_name(task.id, prefix)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# node kinds that clone_task copies; every other value is shared with the original
CLONABLE = (ir.Var, ir.Default, ir.Task)
T = TypeVar('T', bound=ir.Task)
def clone_task(task: T) -> T:
    """
    Create a new task that is a clone of this task. This operation clones only sub-tasks
    and variables (plus Defaults), and preserves variable references.

    This is useful when we are rewriting the metamodel and want to copy parts of a task to
    some other place. It is important to clone to avoid having the same object present
    multiple times in the metamodel.

    Each distinct original node (keyed by its ``id``) is cloned exactly once. All references
    in ``task`` to the same original var or sub-task therefore point to the same clone in
    the result. Nodes are visited with an explicit stack instead of recursion.
    """

    # map from original object id to the rewritten object
    cache = {}

    def from_cache(original):
        """ Return the clone of ``original``, recursing into tuples and frozen ordered sets;
        values that are not clonable are returned unchanged. """
        if isinstance(original, tuple):
            return tuple([from_cache(c) for c in original])
        elif isinstance(original, FrozenOrderedSet):
            # NOTE(review): this rebuilds the collection with ordered_set(...), kept as in the
            # original implementation — confirm node constructors accept the resulting type.
            return ordered_set(*[from_cache(c) for c in original])
        elif isinstance(original, CLONABLE):
            return cache.get(original.id, original)
        else:
            return original

    stack: list[ir.Node] = [task]

    def to_stack(original) -> bool:
        """ Push every clonable, not-yet-cloned node found in ``original``, descending into
        tuples and frozen ordered sets exactly like ``from_cache`` does. Return True iff
        anything was pushed. """
        if isinstance(original, (tuple, FrozenOrderedSet)):
            pushed = False
            for item in original:
                pushed = to_stack(item) or pushed
            return pushed
        if isinstance(original, CLONABLE) and original.id not in cache:
            stack.append(original)
            return True
        return False

    while stack:
        curr = stack[-1]
        if curr.id in cache:
            # A node can be pushed twice before it is cloned (e.g. a Logical's hoisted var
            # that is also referenced from one of its body tasks). Cloning it again would
            # overwrite the cache and give different parents different copies of it.
            stack.pop()
            continue

        # initializable fields only (i.e. ignore the Node id)
        curr_fields = [fld for fld in fields(curr) if fld.init]

        # push any children that still need cloning; curr is revisited after them
        stacked_children = False
        for fld in curr_fields:
            stacked_children = to_stack(getattr(curr, fld.name)) or stacked_children

        if not stacked_children:
            # all children are cloned, so curr can be rebuilt from the cloned field values
            stack.pop()
            children = [from_cache(getattr(curr, fld.name)) for fld in curr_fields]
            cache[curr.id] = curr.__class__(*children)

    clone = cache[task.id]
    assert isinstance(clone, type(task))
    return clone
|