relationalai 0.11.3__py3-none-any.whl → 0.11.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. relationalai/clients/snowflake.py +6 -1
  2. relationalai/clients/use_index_poller.py +349 -188
  3. relationalai/early_access/dsl/bindings/csv.py +2 -2
  4. relationalai/semantics/internal/internal.py +22 -4
  5. relationalai/semantics/lqp/executor.py +61 -12
  6. relationalai/semantics/lqp/intrinsics.py +23 -0
  7. relationalai/semantics/lqp/model2lqp.py +13 -4
  8. relationalai/semantics/lqp/passes.py +2 -3
  9. relationalai/semantics/lqp/primitives.py +12 -1
  10. relationalai/semantics/metamodel/builtins.py +8 -1
  11. relationalai/semantics/metamodel/factory.py +3 -2
  12. relationalai/semantics/reasoners/graph/core.py +54 -2
  13. relationalai/semantics/reasoners/optimization/solvers_dev.py +20 -1
  14. relationalai/semantics/reasoners/optimization/solvers_pb.py +24 -3
  15. relationalai/semantics/rel/compiler.py +5 -17
  16. relationalai/semantics/rel/executor.py +2 -2
  17. relationalai/semantics/rel/rel.py +6 -0
  18. relationalai/semantics/rel/rel_utils.py +8 -1
  19. relationalai/semantics/rel/rewrite/extract_common.py +153 -242
  20. relationalai/semantics/sql/compiler.py +120 -39
  21. relationalai/semantics/sql/executor/duck_db.py +21 -0
  22. relationalai/semantics/sql/rewrite/denormalize.py +4 -6
  23. relationalai/semantics/sql/rewrite/recursive_union.py +23 -3
  24. relationalai/semantics/sql/sql.py +27 -0
  25. relationalai/semantics/std/__init__.py +2 -1
  26. relationalai/semantics/std/datetime.py +4 -0
  27. relationalai/semantics/std/re.py +83 -0
  28. relationalai/semantics/std/strings.py +1 -1
  29. relationalai/tools/cli_controls.py +445 -60
  30. relationalai/util/format.py +78 -1
  31. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/METADATA +3 -2
  32. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/RECORD +35 -33
  33. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/WHEEL +0 -0
  34. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/entry_points.txt +0 -0
  35. {relationalai-0.11.3.dist-info → relationalai-0.11.4.dist-info}/licenses/LICENSE +0 -0
relationalai/early_access/dsl/bindings/csv.py

@@ -1,5 +1,5 @@
  from io import StringIO
- from typing import Optional
+ from typing import Optional, Hashable

  import numpy as np
  import pandas as pd

@@ -76,7 +76,7 @@ class BindableCsvColumn(BindableColumn, b.Relationship):


  class CsvTable(AbstractBindableTable[BindableCsvColumn]):
-     _basic_type_schema: dict[str, str]
+     _basic_type_schema: dict[Hashable, str]
      _csv_data: list[pd.DataFrame]
      _num_rows: int

relationalai/semantics/internal/internal.py

@@ -514,11 +514,13 @@ class Producer:
      #--------------------------------------------------

      def in_(self, values:list[Any]|Fragment) -> Expression:
+         columns = None
          if isinstance(values, Fragment):
              return self == values
          if not isinstance(values[0], tuple):
              values = [tuple([v]) for v in values]
-         d = data(values)
+         columns = [f"v{i}" for i in range(len(values[0]))]
+         d = data(values, columns)
          return self == d[0]

      #--------------------------------------------------
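
As a point of reference, the normalization that `in_` now performs can be sketched standalone (`values` and `columns` as in the diff above; the library's `data(...)` constructor itself is not reproduced here):

    # Scalars are wrapped into 1-tuples, then positional column names
    # v0, v1, ... are derived from the tuple width before calling data(...).
    values = [1, 2, 3]
    if not isinstance(values[0], tuple):
        values = [tuple([v]) for v in values]
    columns = [f"v{i}" for i in range(len(values[0]))]
    assert values == [(1,), (2,), (3,)]
    assert columns == ["v0"]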
@@ -907,9 +909,9 @@ class Concept(Producer):
              if python_types_to_concepts.get(v):
                  v = python_types_to_concepts[v]
              if isinstance(v, Concept):
-                 setattr(self, k, Property(f"{{{self._name}}} has {{{k}:{v._name}}}", short_name=k, model=self._model))
+                 setattr(self, k, Property(f"{{{self._name}}} has {{{k}:{v._name}}}", parent=self, short_name=k, model=self._model))
              elif isinstance(v, type) and issubclass(v, self._model.Enum): #type: ignore
-                 setattr(self, k, Property(f"{{{self._name}}} has {{{k}:{v._concept._name}}}", short_name=k, model=self._model))
+                 setattr(self, k, Property(f"{{{self._name}}} has {{{k}:{v._concept._name}}}", parent=self, short_name=k, model=self._model))
              elif isinstance(v, Relationship):
                  self._validate_identifier_relationship(v)
                  setattr(self, k, v)

@@ -1189,6 +1191,7 @@ def is_decimal(concept: Concept) -> bool:
  Concept.builtins["Int"] = Concept.builtins["Int128"]
  Concept.builtins["Integer"] = Concept.builtins["Int128"]

+ _np_datetime = np.dtype('datetime64[ns]')
  python_types_to_concepts : dict[Any, Concept] = {
      int: Concept.builtins["Int128"],
      float: Concept.builtins["Float"],

@@ -1213,6 +1216,7 @@ python_types_to_concepts : dict[Any, Concept] = {
      np.dtype('float32'): Concept.builtins["Float"],
      np.dtype('bool'): Concept.builtins["Bool"],
      np.dtype('object'): Concept.builtins["String"], # Often strings are stored as object dtype
+     _np_datetime: Concept.builtins["DateTime"],

      # Pandas extension dtypes
      pd.Int64Dtype(): Concept.builtins["Int128"],

@@ -1655,7 +1659,9 @@ class Expression(Producer):
              raise ValueError(f"Argument index should be positive, got {idx}")
          if len(self._params) <= idx:
              raise ValueError(f"Expression '{self.__str__()}' has only {len(self._params)} arguments")
-         return ArgumentRef(self, self._params[idx])
+         param = self._params[idx]
+         # if param is an Expression then refer the last param of this expression
+         return ArgumentRef(self, param._params[-1] if isinstance(param, Expression) else param)

      def __getattr__(self, name: str):
          last = self._params[-1]

@@ -2090,8 +2096,20 @@ class DataColumn(Producer):
          self._data = data
          self._type = _type
          self._name = name if isinstance(name, str) else f"v{name}"
+         if pd.api.types.is_datetime64_any_dtype(_type):
+             _type = _np_datetime
+         # dates are objects in pandas
+         elif pd.api.types.is_object_dtype(_type) and self._is_date_column():
+             _type = date
          self._ref = python_types_to_concepts[_type].ref(self._name)

+     def _is_date_column(self) -> bool:
+         sample = self._data._data[self._name].dropna()
+         if sample.empty:
+             return False
+         sample_value = sample.iloc[0]
+         return isinstance(sample_value, date) and not isinstance(sample_value, datetime)
+
      def __str__(self):
          return f"DataColumn({self._name}, {self._type})"

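The new date handling in `DataColumn` leans on a pandas fact worth spelling out: `datetime.date` objects are stored under the generic object dtype, while true timestamps use datetime64[ns]. A small illustrative sketch (not the package's code):

    from datetime import date
    import pandas as pd

    dates = pd.Series([date(2024, 1, 1), date(2024, 1, 2)])
    stamps = pd.to_datetime(dates)

    assert pd.api.types.is_object_dtype(dates.dtype)           # dates -> object
    assert pd.api.types.is_datetime64_any_dtype(stamps.dtype)  # datetime64[ns]

    # Sampling a non-null value distinguishes date from datetime columns,
    # since datetime is a subclass of date.
    sample = dates.dropna().iloc[0]
    assert isinstance(sample, date)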
relationalai/semantics/lqp/executor.py

@@ -12,6 +12,7 @@ from relationalai import debugging
  from relationalai.semantics.lqp import result_helpers
  from relationalai.semantics.metamodel import ir, factory as f, executor as e
  from relationalai.semantics.lqp.compiler import Compiler
+ from relationalai.semantics.lqp.intrinsics import mk_intrinsic_datetime_now
  from relationalai.semantics.lqp.types import lqp_type_to_sql
  from lqp import print as lqp_print, ir as lqp_ir
  from lqp.parser import construct_configure

@@ -258,11 +259,47 @@ class LQPExecutor(e.Executor):

          return ", ".join(fields)

+     def _construct_configure(self):
+         config_dict = {}
+         # Only set the IVM flag if there is a value in `config`. Otherwise, let
+         # `construct_configure` set the default value.
+         ivm_flag = self.config.get('reasoner.rule.incremental_maintenance', None)
+         if ivm_flag:
+             config_dict['ivm.maintenance_level'] = lqp_ir.Value(value=ivm_flag, meta=None)
+         return construct_configure(config_dict, None)
+
+     def _compile_intrinsics(self) -> lqp_ir.Epoch:
+         """Construct an epoch that defines a number of built-in definitions used by the
+         emitter."""
+         with debugging.span("compile_intrinsics") as span:
+             debug_info = lqp_ir.DebugInfo(id_to_orig_name={}, meta=None)
+             intrinsics_fragment = lqp_ir.Fragment(
+                 id = lqp_ir.FragmentId(id=b"__pyrel_lqp_intrinsics", meta=None),
+                 declarations = [
+                     mk_intrinsic_datetime_now(),
+                 ],
+                 debug_info = debug_info,
+                 meta = None,
+             )
+
+             span["compile_type"] = "intrinsics"
+             span["lqp"] = lqp_print.to_string(intrinsics_fragment, {"print_names": True, "print_debug": False, "print_csv_filename": False})
+
+         return lqp_ir.Epoch(
+             writes=[
+                 lqp_ir.Write(write_type=lqp_ir.Define(fragment=intrinsics_fragment, meta=None), meta=None)
+             ],
+             meta=None,
+         )
+
      def compile_lqp(self, model: ir.Model, task: ir.Task):
+         configure = self._construct_configure()
+
          model_txn = None
          if self._last_model != model:
              with debugging.span("compile", metamodel=model) as install_span:
                  _, model_txn = self.compiler.compile(model, {"fragment_id": b"model"})
+                 model_txn = txn_with_configure(model_txn, configure)
                  install_span["compile_type"] = "model"
                  install_span["lqp"] = lqp_print.to_string(model_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})
              self._last_model = model

@@ -275,23 +312,26 @@
              }
              result, final_model = self.compiler.compile_inner(query, options)
              export_info, query_txn = result
+             query_txn = txn_with_configure(query_txn, configure)
              compile_span["compile_type"] = "query"
              compile_span["lqp"] = lqp_print.to_string(query_txn, {"print_names": True, "print_debug": False, "print_csv_filename": False})

-         txn = query_txn
+         # Merge the epochs into a single transactions. Long term the query bits should all
+         # go into a WhatIf action and the intrinsics could be fused with either of them. But
+         # for now we just use separate epochs.
+         epochs = []
+
+         epochs.append(self._compile_intrinsics())
+
          if model_txn is not None:
-             # Merge the two LQP transactions into one. Long term the query bits should all
-             # go into a WhatIf action. But for now we just use two separate epochs.
-             model_epoch = model_txn.epochs[0]
-             query_epoch = query_txn.epochs[0]
-             txn = lqp_ir.Transaction(
-                 epochs=[model_epoch, query_epoch],
-                 configure=construct_configure({}, None),
-                 meta=None,
-             )
+             epochs.append(model_txn.epochs[0])

-         # Revalidate now that we've joined two epochs
-         validate_lqp(txn)
+         epochs.append(query_txn.epochs[0])
+
+         txn = lqp_ir.Transaction(epochs=epochs, configure=configure, meta=None)
+
+         # Revalidate now that we've joined all the epochs.
+         validate_lqp(txn)

          txn_proto = convert_transaction(txn)
          # TODO (azreika): Should export_info be encoded as part of the txn_proto? [RAI-40312]

@@ -352,3 +392,12 @@ class LQPExecutor(e.Executor):
              # If processing the results failed, revert to the previous model.
              self._last_model = previous_model
              raise e
+
+ def txn_with_configure(txn: lqp_ir.Transaction, configure: lqp_ir.Configure) -> lqp_ir.Transaction:
+     """ Return a new transaction with the given configure. If the transaction already has
+     a configure, it is replaced. """
+     return lqp_ir.Transaction(
+         epochs=txn.epochs,
+         configure=configure,
+         meta=txn.meta,
+     )
relationalai/semantics/lqp/intrinsics.py (new file)

@@ -0,0 +1,23 @@
+ from datetime import datetime, timezone
+
+ from relationalai.semantics.lqp import ir as lqp
+ from relationalai.semantics.lqp.constructors import mk_abstraction, mk_value, mk_var, mk_type, mk_primitive
+ from relationalai.semantics.lqp.utils import lqp_hash
+
+ def mk_intrinsic_datetime_now() -> lqp.Def:
+     """Constructs a definition of the current datetime."""
+     id = lqp_hash("__pyrel_lqp_intrinsic_datetime_now")
+     out = mk_var("out")
+     out_type = mk_type(lqp.TypeName.DATETIME)
+     now = mk_value(lqp.DateTimeValue(value=datetime.now(timezone.utc), meta=None))
+     datetime_now = mk_abstraction(
+         [(out, out_type)],
+         mk_primitive("rel_primitive_eq", [out, now]),
+     )
+
+     return lqp.Def(
+         name = lqp.RelationId(id=id, meta=None),
+         body = datetime_now,
+         attrs = [],
+         meta = None,
+     )
relationalai/semantics/lqp/model2lqp.py

@@ -192,12 +192,21 @@ def _translate_effect(ctx: TranslationCtx, effect: Union[ir.Output, ir.Update],
      elif isinstance(effect, ir.Output):
          ctx.output_ids.append((rel_id, def_name))

+     # First we collect annotations on the effect itself, e.g. from something like
+     # `select(...).annotate(...)`.
+     annotations = effect.annotations
+     if isinstance(effect, ir.Update):
+         # Then we translate annotations on the relation itself, e.g.
+         # ```
+         # Bar.foo = model.Relationship(...)
+         # Bar.foo.annotate(...)
+         # ```
+         annotations = annotations | effect.relation.annotations
+
      return lqp.Def(
          name = rel_id,
          body = mk_abstraction(projection, new_body),
-         # TODO this only covers the annotations on the effect itself. Annotations on the
-         # relation are not included yet.
-         attrs = _translate_annotations(effect.annotations),
+         attrs = _translate_annotations(annotations),
          meta = None,
      )


@@ -697,4 +706,4 @@ def _translate_join(ctx: TranslationCtx, task: ir.Lookup) -> lqp.Formula:

      output_term = _translate_term(ctx, target)[0]

-     return lqp.Reduce(meta=None, op=op, body=body, terms=[output_term])
\ No newline at end of file
+     return lqp.Reduce(meta=None, op=op, body=body, terms=[output_term])
relationalai/semantics/lqp/passes.py

@@ -8,7 +8,7 @@ from relationalai.semantics.metamodel.util import FrozenOrderedSet

  from relationalai.semantics.metamodel.rewrite import Flatten
  # TODO: Move this into metamodel.rewrite
- from relationalai.semantics.rel.rewrite import QuantifyVars, CDC
+ from relationalai.semantics.rel.rewrite import QuantifyVars, CDC, ExtractCommon

  from relationalai.semantics.lqp.utils import output_names


@@ -25,8 +25,7 @@ def lqp_passes() -> list[Pass]:
          InferTypes(),
          DNFUnionSplitter(),
          ExtractKeys(),
-         # Broken
-         # ExtractCommon(),
+         ExtractCommon(),
          Flatten(),
          Splinter(), # Splits multi-headed rules into multiple rules
          QuantifyVars(), # Adds missing existentials
relationalai/semantics/lqp/primitives.py

@@ -1,7 +1,7 @@
  from relationalai.semantics.metamodel.types import digits_to_bits
  from relationalai.semantics.lqp import ir as lqp
  from relationalai.semantics.lqp.types import is_numeric
- from relationalai.semantics.lqp.utils import UniqueNames
+ from relationalai.semantics.lqp.utils import UniqueNames, lqp_hash
  from relationalai.semantics.lqp.constructors import mk_primitive, mk_specialized_value, mk_type, mk_value, mk_var

  rel_to_lqp = {

@@ -62,6 +62,7 @@ rel_to_lqp = {
      "date_add": "rel_primitive_typed_add_date_period",
      "date_subtract": "rel_primitive_typed_subtract_date_period",
      "dates_period_days": "rel_primitive_date_days_between",
+     "datetime_now": "__pyrel_lqp_intrinsic_datetime_now",
      "datetime_add": "rel_primitive_typed_add_datetime_period",
      "datetime_subtract": "rel_primitive_typed_subtract_datetime_period",
      "datetime_year": "rel_primitive_datetime_year",

@@ -175,6 +176,16 @@ def build_primitive(
      terms, term_types = _reorder_primitive_terms(lqp_name, terms, term_types)
      _assert_primitive_terms(lqp_name, terms, term_types)

+     # Handle intrinsics. To callers of `build_primitive` the distinction between intrinsic
+     # and primitive doesn't matter, so we don't want to burden them with that detail.
+     # Intrinsics are built-in definitions added by the LQP emitter, that user logic can just
+     # refer to.
+     if lqp_name == "__pyrel_lqp_intrinsic_datetime_now":
+         id = lqp.RelationId(id=lqp_hash(lqp_name), meta=None)
+         assert len(terms) == 1
+         assert isinstance(terms[0], lqp.Term)
+         return lqp.Atom(name=id, terms=[terms[0]], meta=None)
+
      return mk_primitive(lqp_name, terms)

  def relname_to_lqp_name(name: str) -> str:
relationalai/semantics/metamodel/builtins.py

@@ -391,7 +391,7 @@ erfinv = f.relation(

  # Strings
  concat = f.relation("concat", [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.String)])
- num_chars = f.relation("num_chars", [f.input_field("a", types.String), f.field("b", types.Int128)])
+ num_chars = f.relation("num_chars", [f.input_field("a", types.String), f.field("b", types.Int64)])
  starts_with = f.relation("starts_with", [f.input_field("a", types.String), f.input_field("b", types.String)])
  ends_with = f.relation("ends_with", [f.input_field("a", types.String), f.input_field("b", types.String)])
  contains = f.relation("contains", [f.input_field("a", types.String), f.input_field("b", types.String)])

@@ -406,7 +406,13 @@ replace = f.relation("replace", [f.input_field("a", types.String), f.input_field
  split = f.relation("split", [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.Int64), f.field("d", types.String)])
  # should be a separate builtin. SQL emitter compiles it differently
  split_part = f.relation("split_part", [f.input_field("a", types.String), f.input_field("b", types.String), f.field("c", types.Int64), f.field("d", types.String)])
+
+ # regex
  regex_match = f.relation("regex_match", [f.input_field("a", types.String), f.input_field("b", types.String)])
+ regex_match_all = f.relation("regex_match_all", [f.input_field("a", types.String), f.input_field("b", types.String), f.input_field("c", types.Int64), f.field("d", types.String)])
+ capture_group_by_index = f.relation("capture_group_by_index", [f.input_field("a", types.String), f.input_field("b", types.String), f.input_field("c", types.Int64), f.input_field("d", types.Int64), f.field("e", types.String)])
+ capture_group_by_name = f.relation("capture_group_by_name", [f.input_field("a", types.String), f.input_field("b", types.String), f.input_field("c", types.Int64), f.input_field("d", types.String), f.field("e", types.String)])
+ escape_regex_metachars = f.relation("escape_regex_metachars", [f.input_field("a", types.String), f.field("b", types.String)])

  # Dates
  date_format = f.relation("date_format", [f.input_field("a", types.Date), f.input_field("b", types.String), f.field("c", types.String)])
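
The new capture-group builtins take a pattern, a string, a match index, and a group selector (an integer index or a group name). Python's stdlib re module offers the same two lookups; this sketch is only an analogy for the intended semantics, not the emitters' implementation:

    import re

    text = "2024-01-15 and 2025-02-20"
    pattern = r"(?P<year>\d{4})-(\d{2})-(\d{2})"

    first = list(re.finditer(pattern, text))[0]
    assert first.group(2) == "01"             # capture group by index
    assert first.group("year") == "2024"      # capture group by name
    assert re.escape("a.b*c") == r"a\.b\*c"   # cf. escape_regex_metachars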
@@ -422,6 +428,7 @@ date_add = f.relation("date_add", [f.input_field("a", types.Date), f.input_field
  dates_period_days = f.relation("dates_period_days", [f.input_field("a", types.Date), f.input_field("b", types.Date), f.field("c", types.Int64)])
  datetimes_period_milliseconds = f.relation("datetimes_period_milliseconds", [f.input_field("a", types.DateTime), f.input_field("b", types.DateTime), f.field("c", types.Int64)])
  date_subtract = f.relation("date_subtract", [f.input_field("a", types.Date), f.input_field("b", types.Int64), f.field("c", types.Date)])
+ datetime_now = f.relation("datetime_now", [f.field("a", types.DateTime)])
  datetime_add = f.relation("datetime_add", [f.input_field("a", types.DateTime), f.input_field("b", types.Int64), f.field("c", types.DateTime)])
  datetime_subtract = f.relation("datetime_subtract", [f.input_field("a", types.DateTime), f.input_field("b", types.Int64), f.field("c", types.DateTime)])
  datetime_year = f.relation("datetime_year", [f.input_field("a", types.DateTime), f.input_field("b", types.String), f.field("c", types.Int64)])
relationalai/semantics/metamodel/factory.py

@@ -185,10 +185,11 @@ def lit(value: Any) -> ir.Value:
          return ir.Literal(types.Bool, value)
      elif isinstance(value, decimal.Decimal):
          return ir.Literal(types.Decimal, value)
-     elif isinstance(value, datetime.date):
-         return ir.Literal(types.Date, value)
+     # datetime.datetime is a subclass of datetime.date, so check it first
      elif isinstance(value, datetime.datetime):
          return ir.Literal(types.DateTime, value)
+     elif isinstance(value, datetime.date):
+         return ir.Literal(types.Date, value)
      elif isinstance(value, list):
          return tuple([lit(v) for v in value])
      else:
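
The reordering above matters because `datetime.datetime` is a subclass of `datetime.date`, so an `isinstance(value, datetime.date)` check also matches full datetimes. A minimal stdlib illustration:

    import datetime

    value = datetime.datetime(2024, 1, 1, 12, 0)
    assert isinstance(value, datetime.date)      # True: subclass relationship
    assert isinstance(value, datetime.datetime)  # True

    # Checking the more specific type first, as the fix does, keeps
    # plain dates and full datetimes apart.
    kind = "DateTime" if isinstance(value, datetime.datetime) else "Date"
    assert kind == "DateTime"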
relationalai/semantics/reasoners/graph/core.py

@@ -3865,12 +3865,19 @@


      @include_in_docs
-     def triangle_count(self):
+     def triangle_count(self, *, of: Optional[Relationship] = None):
          """Returns a binary relationship containing the number of unique triangles each node belongs to.

          A triangle is a set of three nodes where each node has a directed
          or undirected edge to the other two nodes, forming a 3-cycle.

+         Parameters
+         ----------
+         of : Relationship, optional
+             A unary relationship containing a subset of the graph's nodes. When
+             provided, constrains the domain of the triangle count computation: only
+             triangle counts of nodes in this relationship are computed and returned.
+
          Returns
          -------
          Relationship

@@ -3926,6 +3933,31 @@
          3 4 0
          4 5 0

+         >>> # 4. Use 'of' parameter to constrain the set of nodes to compute triangle counts of
+         >>> # Define a subset containing only nodes 1 and 3
+         >>> subset = model.Relationship(f"{{node:{Node}}} is in subset")
+         >>> node = Node.ref()
+         >>> where(union(node.id == 1, node.id == 3)).define(subset(node))
+         >>>
+         >>> # Get triangle counts only of nodes in the subset
+         >>> constrained_triangle_count = graph.triangle_count(of=subset)
+         >>> select(node.id, count).where(constrained_triangle_count(node, count)).inspect()
+         ▰▰▰▰ Setup complete
+         id count
+         0 1 1
+         1 3 1
+
+         Notes
+         -----
+         The ``triangle_count()`` method, called with no parameters, computes and caches
+         the full triangle count relationship, providing efficient reuse across multiple
+         calls to ``triangle_count()``. In contrast, ``triangle_count(of=subset)`` computes a
+         constrained relationship specific to the passed-in ``subset`` and that
+         call site. When a significant fraction of the triangle count relation is needed
+         across a program, ``triangle_count()`` is typically more efficient; this is the
+         typical case. Use ``triangle_count(of=subset)`` only when small subsets of the
+         triangle count relationship are needed collectively across the program.
+
          See Also
          --------
          triangle

@@ -3933,15 +3965,35 @@


          """
+         if of is not None:
+             self._validate_node_subset_parameter(of)
+             return self._triangle_count_of(of)
          return self._triangle_count

      @cached_property
      def _triangle_count(self):
          """Lazily define and cache the self._triangle_count relationship."""
+         return self._create_triangle_count_relationship(nodes_subset=None)
+
+     def _triangle_count_of(self, nodes_subset: Relationship):
+         """
+         Create a triangle count relationship constrained to the subset of nodes
+         in `nodes_subset`. Note this relationship is not cached; it is
+         specific to the callsite.
+         """
+         return self._create_triangle_count_relationship(nodes_subset=nodes_subset)
+
+     def _create_triangle_count_relationship(self, *, nodes_subset: Optional[Relationship]):
+         """Create a triangle count relationship, optionally constrained to a subset of nodes."""
          _triangle_count_rel = self._model.Relationship(f"{{node:{self._NodeConceptStr}}} belongs to {{count:Integer}} triangles")

+         if nodes_subset is None:
+             node_constraint = self.Node  # No constraint on nodes.
+         else:
+             node_constraint = nodes_subset(self.Node)  # Nodes constrained to given subset.
+
          where(
-             self.Node,
+             node_constraint,
              _count := self._nonzero_triangle_count_fragment(self.Node) | 0
          ).define(_triangle_count_rel(self.Node, _count))

relationalai/semantics/reasoners/optimization/solvers_dev.py

@@ -2,12 +2,15 @@ from __future__ import annotations
  from typing import Union
  import textwrap
  import uuid
+ import time

  from relationalai.semantics.snowflake import Table
  from relationalai.semantics import std
  from relationalai.semantics.internal import internal as b # TODO(coey) change b name or remove b.?
  from relationalai.semantics.rel.executor import RelExecutor
  from relationalai.semantics.lqp.executor import LQPExecutor
+ from relationalai.tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
+ from relationalai.util.timeout import calc_remaining_timeout_minutes

  from .common import make_name
  from relationalai.experimental.solvers import Solver

@@ -243,6 +246,17 @@ class SolverModelDev:
          app_name = resources.get_app_name()
          print(app_name)

+         # Note: currently the query timeout is not propagated to the steps 'export model
+         # relations', and 'import result relations'. For those steps the default query
+         # timeout value defined in the config will apply.
+         # TODO: propagate the query timeout to those steps as well.
+         query_timeout_mins = kwargs.get("query_timeout_mins", None)
+         config = self._model._config
+         if query_timeout_mins is None and (timeout_value := config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+             query_timeout_mins = int(timeout_value)
+         config_file_path = getattr(config, 'file_path', None)
+         start_time = time.monotonic()
+
          # 1. export model relations
          print("export model relations")
          # TODO(coey) perf: only export the relations that are actually used in the model

@@ -266,6 +280,9 @@ class SolverModelDev:
          b.select(*rel._field_refs).where(rel(*rel._field_refs)).into(table)

          # 2. execute solver job and wait for completion
+         remaining_timeout_minutes = calc_remaining_timeout_minutes(
+             start_time, query_timeout_mins, config_file_path=config_file_path,
+         )
          print("execute solver job")
          payload = {
              "solver": solver.solver_name.lower(),

@@ -273,7 +290,9 @@ class SolverModelPB:
              "input_id": input_id,
              "data_type": self._data_type
          }
-         job_id = solver._exec_job(payload, log_to_console=log_to_console)
+         job_id = solver._exec_job(
+             payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes,
+         )
          print(f"job id: {job_id}") # TODO(coey) maybe job_id is not useful

          # 3. import result relations
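
Both solver paths now thread a single `query_timeout_mins` budget through several sequential steps, recomputing what remains before each call. A hypothetical sketch of that bookkeeping (this is not the library's `calc_remaining_timeout_minutes`, whose real signature also takes `config_file_path`; it only illustrates the monotonic-clock arithmetic):

    import time

    def remaining_minutes(start_time: float, budget_mins: int | None) -> int | None:
        # Subtract elapsed wall time from the budget, flooring at zero.
        if budget_mins is None:
            return None
        elapsed_mins = (time.monotonic() - start_time) / 60.0
        return max(0, int(budget_mins - elapsed_mins))

    start = time.monotonic()
    # ... export step runs here ...
    print(remaining_minutes(start, 30))  # minutes left for the solver job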
relationalai/semantics/reasoners/optimization/solvers_pb.py

@@ -2,12 +2,15 @@ from __future__ import annotations
  from typing import Any, Union
  import textwrap
  import uuid
+ import time

  from relationalai.semantics.metamodel.util import ordered_set
  from relationalai.semantics.internal import internal as b # TODO(coey) change b name or remove b.?
  from relationalai.semantics.rel.executor import RelExecutor
  from .common import make_name
  from relationalai.experimental.solvers import Solver
+ from relationalai.tools.constants import DEFAULT_QUERY_TIMEOUT_MINS
+ from relationalai.util.timeout import calc_remaining_timeout_minutes

  _Any = Union[b.Producer, str, float, int]
  _Number = Union[b.Producer, float, int]

@@ -222,6 +225,14 @@ class SolverModelPB:
          assert isinstance(executor, RelExecutor)
          prefix_l = f"solvermodel_{self._id}_"

+         query_timeout_mins = kwargs.get("query_timeout_mins", None)
+         config = self._model._config
+         if query_timeout_mins is None and (timeout_value := config.get("query_timeout_mins", DEFAULT_QUERY_TIMEOUT_MINS)) is not None:
+             query_timeout_mins = int(timeout_value)
+         config_file_path = getattr(config, 'file_path', None)
+         start_time = time.monotonic()
+         remaining_timeout_minutes = query_timeout_mins
+
          # 1. Materialize the model and store it.
          print("export model")
          b.select(b.count(self.Variable)).to_df() # TODO(coey) weird hack to avoid uninitialized properties error

@@ -244,14 +255,22 @@
              def config[:envelope, :payload, :data]: model_string
              def config[:envelope, :payload, :path]: "{model_uri}"
              def export {{ config }}
-         """))
+         """), query_timeout_mins=remaining_timeout_minutes)

          # 2. Execute job and wait for completion.
          print("execute solver job")
-         job_id = solver._exec_job(payload, log_to_console=log_to_console)
+         remaining_timeout_minutes = calc_remaining_timeout_minutes(
+             start_time, query_timeout_mins, config_file_path=config_file_path,
+         )
+         job_id = solver._exec_job(
+             payload, log_to_console=log_to_console, query_timeout_mins=remaining_timeout_minutes,
+         )

          # 3. Extract result.
          print("extract result")
+         remaining_timeout_minutes = calc_remaining_timeout_minutes(
+             start_time, query_timeout_mins, config_file_path=config_file_path,
+         )
          extract_str = textwrap.dedent(f"""
              def raw_result {{
                  load_binary["snowflake://APP_STATE.RAI_INTERNAL_STAGE/job-results/{job_id}/result.binpb"]

@@ -289,7 +308,9 @@ class SolverModelPB:
                  ::std::mirror::convert(std::mirror::typeof[Int128], j, i)
              )
          """)
-         executor.execute_raw(extract_str, readonly=False)
+         executor.execute_raw(
+             extract_str, readonly=False, query_timeout_mins=remaining_timeout_minutes,
+         )

          print("finished solve")
          return None
relationalai/semantics/rel/compiler.py

@@ -12,7 +12,7 @@ from relationalai.semantics.metamodel.visitor import ReadWriteVisitor
  from relationalai.semantics.metamodel.util import OrderedSet, group_by, NameCache, ordered_set

  from relationalai.semantics.rel import rel, rel_utils as u, builtins as rel_bt
- from relationalai.semantics.rel.rewrite import CDC, QuantifyVars
+ from relationalai.semantics.rel.rewrite import CDC, QuantifyVars, ExtractCommon

  import math


@@ -32,7 +32,7 @@ class Compiler(c.Compiler):
          InferTypes(),
          DNFUnionSplitter(),
          ExtractKeys(),
-         # rewrite.ExtractCommon(),
+         ExtractCommon(),
          Flatten(),
          Splinter(),
          QuantifyVars(),

@@ -125,21 +125,6 @@ class ModelToRel:
              tuple([rel.Annotation("inline", ())]),
          ))

-         if "pyrel_num_chars" in reads:
-             defs.append(
-                 rel.Def("pyrel_num_chars",
-                     tuple([rel.Var("x"), rel.Var("y")]),
-                     rel.Exists(
-                         tuple([rel.Var("z")]),
-                         rel.And(ordered_set(
-                             rel.atom("::std::common::num_chars", tuple([rel.Var("x"), rel.Var("z")])),
-                             rel.Atom(self._convert_abs(types.Int64, types.Int128), tuple([rel.Var("z"), rel.Var("y")])),
-                         )),
-                     ),
-                     tuple([rel.Annotation("inline", ())]),
-                 ),
-             )
-
          if "pyrel_count" in reads:
              defs.append(
                  rel.Def("pyrel_count",

@@ -249,6 +234,9 @@ class ModelToRel:
              ),
          )

+         if "pyrel_regex_search" in reads:
+             raise NotImplementedError("pyrel_regex_search is not implemented")
+
          return defs

      @staticmethod
relationalai/semantics/rel/executor.py

@@ -305,8 +305,8 @@ class RelExecutor(e.Executor):

      # NOTE(coey): this is added temporarily to support executing Rel for the solvers library in EA.
      # It can be removed once this is no longer needed by the solvers library.
-     def execute_raw(self, raw_rel:str, readonly:bool=True) -> DataFrame:
-         raw_results = self.resources.exec_raw(self.database, self.engine, raw_rel, readonly, nowait_durable=True)
+     def execute_raw(self, raw_rel:str, readonly:bool=True, query_timeout_mins:int|None=None) -> DataFrame:
+         raw_results = self.resources.exec_raw(self.database, self.engine, raw_rel, readonly, nowait_durable=True, query_timeout_mins=query_timeout_mins)
          df, errs = result_helpers.format_results(raw_results, None, generation=Generation.QB) # Pass None for task parameter
          self.report_errors(errs)
          return df
relationalai/semantics/rel/rel.py

@@ -223,6 +223,12 @@ class Printer(BasePrinter):
              self._print("::std::common::int[128,")
              self._print(str(value))
              self._print("]")
+         elif isinstance(value, datetime):
+             if value.tzinfo is None:
+                 value = value.replace(tzinfo=timezone.utc)
+             self._print(value.astimezone(timezone.utc).isoformat())
+         elif isinstance(value, date):
+             self._print(value.isoformat())
          else:
              self._print(str(value))
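
The new Printer branches normalize temporal literals before printing: naive datetimes are assumed to be UTC, and everything is rendered as a UTC ISO-8601 string. A stdlib-only sketch of the same normalization:

    from datetime import date, datetime, timezone

    value = datetime(2024, 1, 1, 12, 30)  # tz-naive
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    print(value.astimezone(timezone.utc).isoformat())  # 2024-01-01T12:30:00+00:00
    print(date(2024, 1, 1).isoformat())                # 2024-01-01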