relationalai 0.11.2__py3-none-any.whl → 0.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. relationalai/clients/snowflake.py +44 -15
  2. relationalai/clients/types.py +1 -0
  3. relationalai/clients/use_index_poller.py +446 -178
  4. relationalai/early_access/builder/std/__init__.py +1 -1
  5. relationalai/early_access/dsl/bindings/csv.py +4 -4
  6. relationalai/semantics/internal/internal.py +22 -4
  7. relationalai/semantics/lqp/executor.py +69 -18
  8. relationalai/semantics/lqp/intrinsics.py +23 -0
  9. relationalai/semantics/lqp/model2lqp.py +16 -6
  10. relationalai/semantics/lqp/passes.py +3 -4
  11. relationalai/semantics/lqp/primitives.py +38 -14
  12. relationalai/semantics/metamodel/builtins.py +152 -11
  13. relationalai/semantics/metamodel/factory.py +3 -2
  14. relationalai/semantics/metamodel/helpers.py +78 -2
  15. relationalai/semantics/reasoners/graph/core.py +343 -40
  16. relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
  17. relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
  18. relationalai/semantics/rel/compiler.py +5 -17
  19. relationalai/semantics/rel/executor.py +2 -2
  20. relationalai/semantics/rel/rel.py +6 -0
  21. relationalai/semantics/rel/rel_utils.py +37 -1
  22. relationalai/semantics/rel/rewrite/extract_common.py +153 -242
  23. relationalai/semantics/sql/compiler.py +540 -202
  24. relationalai/semantics/sql/executor/duck_db.py +21 -0
  25. relationalai/semantics/sql/executor/result_helpers.py +7 -0
  26. relationalai/semantics/sql/executor/snowflake.py +9 -2
  27. relationalai/semantics/sql/rewrite/denormalize.py +4 -6
  28. relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
  29. relationalai/semantics/sql/sql.py +120 -46
  30. relationalai/semantics/std/__init__.py +9 -4
  31. relationalai/semantics/std/datetime.py +363 -0
  32. relationalai/semantics/std/math.py +77 -0
  33. relationalai/semantics/std/re.py +83 -0
  34. relationalai/semantics/std/strings.py +1 -1
  35. relationalai/tools/cli_controls.py +445 -60
  36. relationalai/util/format.py +78 -1
  37. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/METADATA +3 -2
  38. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/RECORD +41 -39
  39. relationalai/semantics/std/dates.py +0 -213
  40. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/WHEEL +0 -0
  41. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/entry_points.txt +0 -0
  42. {relationalai-0.11.2.dist-info → relationalai-0.11.4.dist-info}/licenses/LICENSE +0 -0
@@ -88,12 +88,36 @@ abs = f.relation(
88
88
 
89
89
  natural_log = f.relation(
90
90
  "natural_log",
91
- [f.input_field("a", types.Number), f.field("b", types.Number)],
91
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
92
92
  overloads=[
93
93
  f.relation("natural_log", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
94
94
  f.relation("natural_log", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
95
95
  f.relation("natural_log", [f.input_field("a", types.Float), f.field("b", types.Float)]),
96
- f.relation("natural_log", [f.input_field("a", types.GenericDecimal), f.field("b", types.GenericDecimal)]),
96
+ f.relation("natural_log", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)]),
97
+
98
+ ],
99
+ )
100
+
101
+ log10 = f.relation(
102
+ "log10",
103
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
104
+ overloads=[
105
+ f.relation("log10", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
106
+ f.relation("log10", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
107
+ f.relation("log10", [f.input_field("a", types.Float), f.field("b", types.Float)]),
108
+ f.relation("log10", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)]),
109
+
110
+ ],
111
+ )
112
+
113
+ log = f.relation(
114
+ "log",
115
+ [f.input_field("a", types.Number), f.input_field("b", types.Number), f.field("c", types.Float)],
116
+ overloads=[
117
+ f.relation("log", [f.input_field("a", types.Int64), f.input_field("b", types.Int64), f.field("c", types.Float)]),
118
+ f.relation("log", [f.input_field("a", types.Int128), f.input_field("b", types.Int128), f.field("c", types.Float)]),
119
+ f.relation("log", [f.input_field("a", types.Float), f.input_field("b", types.Float), f.field("c", types.Float)]),
120
+ f.relation("log", [f.input_field("a", types.GenericDecimal), f.input_field("b", types.GenericDecimal), f.field("c", types.Float)]),
97
121
 
98
122
  ],
99
123
  )
@@ -274,9 +298,100 @@ asinh = f.relation(
274
298
  ],
275
299
  )
276
300
 
301
+ tan = f.relation(
302
+ "tan",
303
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
304
+ overloads=[
305
+ f.relation("tan", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
306
+ f.relation("tan", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
307
+ f.relation("tan", [f.input_field("a", types.Float), f.field("b", types.Float)]),
308
+ f.relation("tan", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)])
309
+ ],
310
+ )
311
+
312
+ tanh = f.relation(
313
+ "tanh",
314
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
315
+ overloads=[
316
+ f.relation("tanh", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
317
+ f.relation("tanh", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
318
+ f.relation("tanh", [f.input_field("a", types.Float), f.field("b", types.Float)]),
319
+ f.relation("tanh", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)])
320
+ ],
321
+ )
322
+
323
+ atan = f.relation(
324
+ "atan",
325
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
326
+ overloads=[
327
+ f.relation("atan", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
328
+ f.relation("atan", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
329
+ f.relation("atan", [f.input_field("a", types.Float), f.field("b", types.Float)]),
330
+ f.relation("atan", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)])
331
+ ],
332
+ )
333
+
334
+ atanh = f.relation(
335
+ "atanh",
336
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
337
+ overloads=[
338
+ f.relation("atanh", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
339
+ f.relation("atanh", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
340
+ f.relation("atanh", [f.input_field("a", types.Float), f.field("b", types.Float)]),
341
+ f.relation("atanh", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)])
342
+ ],
343
+ )
344
+ cot = f.relation(
345
+ "cot",
346
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
347
+ # Everything will be converted to float to avoid NaN results with other types
348
+ overloads=[
349
+ f.relation("cot", [f.input_field("a", types.Float), f.field("b", types.Float)])
350
+ ],
351
+ )
352
+
353
+ acot = f.relation(
354
+ "acot",
355
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
356
+ # Everything will be converted to float to avoid NaN results with other types
357
+ overloads=[
358
+ f.relation("acot", [f.input_field("a", types.Float), f.field("b", types.Float)])
359
+ ],
360
+ )
361
+
362
+ exp = f.relation(
363
+ "exp",
364
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
365
+ overloads=[
366
+ f.relation("exp", [f.input_field("a", types.Int64), f.field("b", types.Float)]),
367
+ f.relation("exp", [f.input_field("a", types.Int128), f.field("b", types.Float)]),
368
+ f.relation("exp", [f.input_field("a", types.Float), f.field("b", types.Float)]),
369
+ f.relation("exp", [f.input_field("a", types.GenericDecimal), f.field("b", types.Float)])
370
+ ],
371
+ )
372
+
373
+ erf = f.relation(
374
+ "erf",
375
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
376
+ overloads=[
377
+ # Everything will be converted to float to avoid NaN results with other types
378
+ f.relation("erf", [f.input_field("a", types.Float), f.field("b", types.Float)]),
379
+ ],
380
+ )
381
+
382
+ erfinv = f.relation(
383
+ "erfinv",
384
+ [f.input_field("a", types.Number), f.field("b", types.Float)],
385
+ overloads=[
386
+ # Everything will be converted to float to avoid NaN results with other types
387
+ f.relation("erfinv", [f.input_field("a", types.Float), f.field("b", types.Float)]),
388
+ ],
389
+ )
390
+
391
+
277
392
  # Strings
278
393
  concat = f.relation("concat", [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.String)])
279
- num_chars = f.relation("num_chars", [f.input_field("a", types.String), f.field("b", types.Int128)])
394
+ num_chars = f.relation("num_chars", [f.input_field("a", types.String), f.field("b", types.Int64)])
280
395
  starts_with = f.relation("starts_with", [f.input_field("a", types.String), f.input_field("b", types.String)])
281
396
  ends_with = f.relation("ends_with", [f.input_field("a", types.String), f.input_field("b", types.String)])
282
397
  contains = f.relation("contains", [f.input_field("a", types.String), f.input_field("b", types.String)])
@@ -291,22 +406,41 @@ replace = f.relation("replace", [f.input_field("a", types.String), f.input_field
291
406
  split = f.relation("split", [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.Int64), f.field("d", types.String)])
292
407
  # should be a separate builtin. SQL emitter compiles it differently
293
408
  split_part = f.relation("split_part", [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.Int64), f.field("d", types.String)])
409
+
410
+ # regex
294
411
  regex_match = f.relation("regex_match", [f.input_field("a", types.String), f.input_field("b", types.String)])
412
+ regex_match_all = f.relation("regex_match_all", [f.input_field("a", types.String), f.input_field("b", types.String), f.input_field("c", types.Int64), f.field("d", types.String)])
413
+ capture_group_by_index = f.relation("capture_group_by_index", [f.input_field("a", types.String), f.input_field("b", types.String), f.input_field("c", types.Int64), f.input_field("d", types.Int64), f.field("e", types.String)])
414
+ capture_group_by_name = f.relation("capture_group_by_name", [f.input_field("a", types.String), f.input_field("b", types.String), f.input_field("c", types.Int64), f.input_field("d", types.String), f.field("e", types.String)])
415
+ escape_regex_metachars = f.relation("escape_regex_metachars", [f.input_field("a", types.String), f.field("b", types.String)])
295
416
 
296
417
  # Dates
297
418
  date_format = f.relation("date_format", [f.input_field("a", types.Date), f.input_field("b", types.String), f.field("c", types.String)])
298
419
  datetime_format = f.relation("datetime_format", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.input_field("c", types.String), f.field("d", types.String)])
299
420
  date_year = f.relation("date_year", [f.input_field("a", types.Date), f.field("b", types.Int64)])
421
+ date_quarter = f.relation("date_quarter", [f.input_field("a", types.Date), f.field("b", types.Int64)])
300
422
  date_month = f.relation("date_month", [f.input_field("a", types.Date), f.field("b", types.Int64)])
301
423
  date_week = f.relation("date_week", [f.input_field("a", types.Date), f.field("b", types.Int64)])
302
424
  date_day = f.relation("date_day", [f.input_field("a", types.Date), f.field("b", types.Int64)])
425
+ date_dayofyear = f.relation("date_dayofyear", [f.input_field("a", types.Date), f.field("b", types.Int64)])
426
+ date_weekday = f.relation("date_weekday", [f.input_field("a", types.Date), f.field("b", types.Int64)])
303
427
  date_add = f.relation("date_add", [f.input_field("a", types.Date), f.input_field("b", types.Int64), f.field("c", types.Date)])
304
428
  dates_period_days = f.relation("dates_period_days", [f.input_field("a", types.Date), f.input_field("b", types.Date), f.field("c", types.Int64)])
305
429
  datetimes_period_milliseconds = f.relation("datetimes_period_milliseconds", [f.input_field("a", types.DateTime), f.input_field("b", types.DateTime), f.field("c", types.Int64)])
306
430
  date_subtract = f.relation("date_subtract", [f.input_field("a", types.Date), f.input_field("b", types.Int64), f.field("c", types.Date)])
431
+ datetime_now = f.relation("datetime_now", [f.field("a", types.DateTime)])
307
432
  datetime_add = f.relation("datetime_add", [f.input_field("a", types.DateTime), f.input_field("b", types.Int64), f.field("c", types.DateTime)])
308
433
  datetime_subtract = f.relation("datetime_subtract", [f.input_field("a", types.DateTime), f.input_field("b", types.Int64), f.field("c", types.DateTime)])
434
+ datetime_year = f.relation("datetime_year", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
435
+ datetime_quarter = f.relation("datetime_quarter", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
436
+ datetime_month = f.relation("datetime_month", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
309
437
  datetime_week = f.relation("datetime_week", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
438
+ datetime_day = f.relation("datetime_day", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
439
+ datetime_dayofyear = f.relation("datetime_dayofyear", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
440
+ datetime_hour = f.relation("datetime_hour", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
441
+ datetime_minute = f.relation("datetime_minute", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
442
+ datetime_second = f.relation("datetime_second", [f.input_field("a", types.DateTime), f.field("c", types.Int64)])
443
+ datetime_weekday = f.relation("datetime_weekday", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
310
444
 
311
445
  # Other
312
446
  range = f.relation("range", [
@@ -319,6 +453,7 @@ range = f.relation("range", [
319
453
  hash = f.relation("hash", [f.input_field("args", types.AnyList), f.field("hash", types.Hash)])
320
454
 
321
455
  uuid_to_string = f.relation("uuid_to_string", [f.input_field("a", types.Hash), f.field("b", types.String)])
456
+ parse_uuid = f.relation("parse_uuid", [f.input_field("a", types.String), f.field("b", types.Hash)])
322
457
 
323
458
  # Raw source code to be attached to the transaction, when the backend understands this language
324
459
  raw_source = f.relation("raw_source", [f.input_field("lang", types.String), f.input_field("source", types.String)])
@@ -462,6 +597,8 @@ parse_int64 = f.relation("parse_int64", [f.input_field("a", types.String), f.fie
462
597
  parse_int128 = f.relation("parse_int128", [f.input_field("a", types.String), f.field("b", types.Int128)])
463
598
  parse_float = f.relation("parse_float", [f.input_field("a", types.String), f.field("b", types.Float)])
464
599
 
600
+ nanosecond = f.relation("nanosecond", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
601
+ microsecond = f.relation("microsecond", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
465
602
  millisecond = f.relation("millisecond", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
466
603
  second = f.relation("second", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
467
604
  minute = f.relation("minute", [f.input_field("a", types.Int64), f.field("b", types.Int64)])
@@ -484,7 +621,6 @@ cast = f.relation(
484
621
  # Date construction with less overhead
485
622
  construct_date = f.relation("construct_date", [f.input_field("year", types.Int64), f.input_field("month", types.Int64), f.input_field("day", types.Int64), f.field("date", types.Date)])
486
623
  construct_date_from_datetime = f.relation("construct_date_from_datetime", [f.input_field("datetime", types.DateTime), f.input_field("timezone", types.String), f.field("date", types.Date)])
487
- construct_datetime = f.relation("construct_datetime", [f.input_field("year", types.Int64), f.input_field("month", types.Int64), f.input_field("day", types.Int64), f.input_field("hour", types.Int64), f.input_field("minute", types.Int64), f.input_field("second", types.Int64), f.field("datetime", types.DateTime)])
488
624
  construct_datetime_ms_tz = f.relation("construct_datetime_ms_tz", [f.input_field("year", types.Int64), f.input_field("month", types.Int64), f.input_field("day", types.Int64), f.input_field("hour", types.Int64), f.input_field("minute", types.Int64), f.input_field("second", types.Int64), f.input_field("milliseconds", types.Int64), f.input_field("timezone", types.String), f.field("datetime", types.DateTime)])
489
625
 
490
626
  # Solver helpers
@@ -545,18 +681,23 @@ builtin_relations_by_name = dict((r.name, r) for r in builtin_relations)
545
681
 
546
682
  string_binary_builtins = [num_chars, starts_with, ends_with, contains, like_match, lower, upper, strip, regex_match]
547
683
 
548
- date_builtins = [date_year, date_month, date_week, date_day, date_add, date_subtract, dates_period_days,
549
- datetime_add, datetime_subtract, datetimes_period_milliseconds, datetime_week]
684
+ date_builtins = [date_year, date_quarter, date_month, date_week, date_day, date_dayofyear, date_add, date_subtract,
685
+ dates_period_days, datetime_add, datetime_subtract, datetimes_period_milliseconds, datetime_year,
686
+ datetime_quarter, datetime_month, datetime_week, datetime_day, datetime_dayofyear, datetime_hour,
687
+ datetime_minute, datetime_second, date_weekday, datetime_weekday]
550
688
 
551
- date_periods = [year, month, week, day, hour, minute, second, millisecond]
689
+ date_periods = [year, month, week, day, hour, minute, second, millisecond, microsecond, nanosecond]
552
690
 
553
- math_unary_builtins = [abs, *abs.overloads, natural_log, *natural_log.overloads, sqrt, *sqrt.overloads,
691
+ math_unary_builtins = [abs, *abs.overloads, sqrt, *sqrt.overloads,
692
+ natural_log, *natural_log.overloads, log10, *log10.overloads,
554
693
  cbrt, *cbrt.overloads, factorial, *factorial.overloads, cos, *cos.overloads,
555
694
  cosh, *cosh.overloads, acos, *acos.overloads, acosh, *acosh.overloads, sin, *sin.overloads,
556
- sinh, *sinh.overloads, asin, *asin.overloads, asinh, *asinh.overloads, ceil, *ceil.overloads,
557
- floor, *floor.overloads]
695
+ sinh, *sinh.overloads, asin, *asin.overloads, asinh, *asinh.overloads, tan, *tan.overloads,
696
+ tanh, *tanh.overloads, atan, *atan.overloads, atanh, *atanh.overloads, *ceil.overloads,
697
+ cot, *cot.overloads, acot, *acot.overloads, floor, *floor.overloads, exp, *exp.overloads,
698
+ erf, *erf.overloads, erfinv, *erfinv.overloads]
558
699
 
559
700
  math_builtins = [*math_unary_builtins, maximum, *maximum.overloads, minimum, *minimum.overloads, mod, *mod.overloads,
560
- pow, *pow.overloads, power, *power.overloads, trunc_div, *trunc_div.overloads]
701
+ pow, *pow.overloads, power, *power.overloads, log, *log.overloads, trunc_div, *trunc_div.overloads]
561
702
 
562
703
  pragma_builtins = [rule_reasoner_sem_vo, rule_reasoner_phys_vo]
@@ -185,10 +185,11 @@ def lit(value: Any) -> ir.Value:
185
185
  return ir.Literal(types.Bool, value)
186
186
  elif isinstance(value, decimal.Decimal):
187
187
  return ir.Literal(types.Decimal, value)
188
- elif isinstance(value, datetime.date):
189
- return ir.Literal(types.Date, value)
188
+ # datetime.datetime is a subclass of datetime.date, so check it first
190
189
  elif isinstance(value, datetime.datetime):
191
190
  return ir.Literal(types.DateTime, value)
191
+ elif isinstance(value, datetime.date):
192
+ return ir.Literal(types.Date, value)
192
193
  elif isinstance(value, list):
193
194
  return tuple([lit(v) for v in value])
194
195
  else:
@@ -4,7 +4,8 @@ Helpers to analyze the metamodel IR.
4
4
  from __future__ import annotations
5
5
 
6
6
  import re
7
- from typing import cast, Tuple, Iterable, Optional
7
+ from dataclasses import fields
8
+ from typing import cast, Tuple, Iterable, Optional, TypeVar
8
9
  from relationalai.semantics.metamodel import ir, visitor, builtins, types, factory as f
9
10
  from relationalai.semantics.metamodel.util import NameCache, OrderedSet, FrozenOrderedSet, flatten_tuple, ordered_set
10
11
 
@@ -263,7 +264,7 @@ def extract(task: ir.Task, body: OrderedSet[ir.Task], exposed_vars: list[ir.Var]
263
264
  body.add(f.derive(connection, exposed_vars))
264
265
 
265
266
  # extract the body
266
- ctx.top_level.append(ir.Logical(task.engine, tuple(), tuple(body)))
267
+ ctx.top_level.append(clone_task(ir.Logical(task.engine, tuple(), tuple(body))))
267
268
 
268
269
  return connection
269
270
 
@@ -282,3 +283,78 @@ def create_task_name(name_cache: NameCache, task: ir.Task, prefix: Optional[str]
282
283
  """ Helper to generate consistent names for tasks extracted from a logical. """
283
284
  prefix = prefix if prefix else f"_{task.kind}"
284
285
  return name_cache.get_name(task.id, prefix)
286
+
287
+
288
+ CLONABLE = (ir.Var, ir.Default, ir.Task)
289
+ T = TypeVar('T', bound=ir.Task)
290
+ def clone_task(task: T) -> T:
291
+ """
292
+ Create a new task that is a clone of this task. This operation clones only sub-tasks
293
+ and variables, and preserves variable references.
294
+
295
+ This is useful when we are rewriting the metamodel and want to copy parts of a task to
296
+ some other place. It is important to clone to avoid having the same object present
297
+ multiple times in the metamodel.
298
+ """
299
+
300
+ # map from original object id to the rewritten object
301
+ cache = {}
302
+ def from_cache(original):
303
+ """ Lookup this object from the cache above, dealing with collections and with
304
+ objects that were not rewritten. """
305
+ if isinstance(original, tuple):
306
+ return tuple([from_cache(c) for c in original])
307
+ elif isinstance(original, FrozenOrderedSet):
308
+ return ordered_set(*[from_cache(c) for c in original])
309
+ elif isinstance(original, CLONABLE):
310
+ return cache.get(original.id, original)
311
+ else:
312
+ return original
313
+
314
+ # the last node that was processed and rewritten
315
+ prev_node = None
316
+ stack: list[ir.Node] = [task]
317
+ def to_stack(original):
318
+ """ Add this original node to the stack if it is clonable and was never processed;
319
+ return True iff the node was added to the stack. """
320
+ if isinstance(original, CLONABLE) and original.id not in cache:
321
+ stack.append(original)
322
+ return True
323
+ return False
324
+
325
+ while stack:
326
+ # peek the current node and get the initializable fields (i.e. ignore Node id)
327
+ curr = stack[-1]
328
+ curr_fields = list(filter(lambda f: f.init, fields(curr)))
329
+ stacked_children = False
330
+
331
+ # go over the fields adding to the stack the ones that we need to rewrite
332
+ for field in curr_fields:
333
+ field_value = getattr(curr, field.name)
334
+ if isinstance(field_value, (tuple, FrozenOrderedSet)):
335
+ # node field is a collection (tuple or set)
336
+ for s in field_value:
337
+ if isinstance(s, tuple):
338
+ # the value within the collection is a tuple (can happen for lookup args)
339
+ for c in s:
340
+ stacked_children = to_stack(c) or stacked_children
341
+ else:
342
+ # the value within the collection is a scalar
343
+ stacked_children = to_stack(s) or stacked_children
344
+ else:
345
+ # node field is a scalar
346
+ stacked_children = to_stack(field_value) or stacked_children
347
+
348
+ # if no children were stacked, we rewrote all fields of curr, so we can pop it and rewrite it
349
+ if not stacked_children:
350
+ stack.pop()
351
+ children = []
352
+ for f in curr_fields:
353
+ children.append(from_cache(getattr(curr, f.name)))
354
+ # create a new prev_node with the cloned children
355
+ prev_node = curr.__class__(*children)
356
+ cache[curr.id] = prev_node
357
+
358
+ # the last node we processed is the rewritten original node
359
+ assert(isinstance(prev_node, type(task)))
360
+ return prev_node