kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (52)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +94 -85
  7. kumoai/connector/utils.py +1399 -210
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +10 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/source.py +18 -0
  20. kumoai/experimental/rfm/base/table.py +545 -0
  21. kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
  22. kumoai/experimental/rfm/infer/__init__.py +6 -0
  23. kumoai/experimental/rfm/infer/dtype.py +79 -0
  24. kumoai/experimental/rfm/infer/pkey.py +126 -0
  25. kumoai/experimental/rfm/infer/time_col.py +62 -0
  26. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  27. kumoai/experimental/rfm/local_graph_sampler.py +58 -11
  28. kumoai/experimental/rfm/local_graph_store.py +45 -37
  29. kumoai/experimental/rfm/local_pquery_driver.py +342 -46
  30. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  31. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
  32. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  33. kumoai/experimental/rfm/rfm.py +559 -148
  34. kumoai/experimental/rfm/sagemaker.py +138 -0
  35. kumoai/jobs.py +27 -1
  36. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  37. kumoai/pquery/prediction_table.py +5 -3
  38. kumoai/pquery/training_table.py +5 -3
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/trainer/job.py +9 -30
  42. kumoai/trainer/trainer.py +19 -10
  43. kumoai/utils/__init__.py +2 -1
  44. kumoai/utils/progress_logger.py +96 -16
  45. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
  46. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
  47. kumoai/experimental/rfm/local_table.py +0 -448
  48. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
  49. kumoai/experimental/rfm/utils.py +0 -347
  50. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
  51. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
  52. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,41 @@
  import warnings
- from typing import Dict, List, Literal, Optional, Tuple, Union
+ from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union

  import numpy as np
  import pandas as pd
- from kumoapi.pquery import QueryType
- from kumoapi.rfm import PQueryDefinition
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     ASTNode,
+     Column,
+     Condition,
+     Filter,
+     Join,
+     LogicalOperation,
+ )
+ from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, DateOffset, Stype

  import kumoai.kumolib as kumolib
  from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.pquery import PQueryPandasBackend
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor

  _coverage_warned = False


+ class SamplingSpec(NamedTuple):
+     edge_type: Tuple[str, str, str]
+     hop: int
+     start_offset: Optional[DateOffset]
+     end_offset: Optional[DateOffset]
+
+
  class LocalPQueryDriver:
      def __init__(
          self,
          graph_store: LocalGraphStore,
-         query: PQueryDefinition,
-         random_seed: Optional[int],
+         query: ValidatedPredictiveQuery,
+         random_seed: Optional[int] = None,
      ) -> None:
          self._graph_store = graph_store
          self._query = query
@@ -27,14 +44,13 @@ class LocalPQueryDriver:

      def _get_candidates(
          self,
-         anchor_time: Union[pd.Timestamp, Literal['entity']],
          exclude_node: Optional[np.ndarray] = None,
      ) -> np.ndarray:

          if self._query.query_type == QueryType.TEMPORAL:
              assert exclude_node is None

-         table_name = self._query.entity.pkey.table_name
+         table_name = self._query.entity_table
          num_nodes = len(self._graph_store.df_dict[table_name])
          mask_dict = self._graph_store.mask_dict

@@ -61,12 +77,37 @@ class LocalPQueryDriver:

          return candidate

+     def _filter_candidates_by_time(
+         self,
+         candidate: np.ndarray,
+         anchor_time: pd.Timestamp,
+     ) -> np.ndarray:
+
+         entity = self._query.entity_table
+
+         # Filter out entities that do not exist yet in time:
+         time_sec = self._graph_store.time_dict.get(entity)
+         if time_sec is not None:
+             mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
+             candidate = candidate[mask]
+
+         # Filter out entities that no longer exist in time:
+         end_time_col = self._graph_store.end_time_column_dict.get(entity)
+         if end_time_col is not None:
+             ser = self._graph_store.df_dict[entity][end_time_col]
+             ser = ser.iloc[candidate]
+             mask = (anchor_time < ser) | ser.isna().to_numpy()
+             candidate = candidate[mask]
+
+         return candidate
+
      def collect_test(
          self,
          size: int,
          anchor_time: Union[pd.Timestamp, Literal['entity']],
          batch_size: Optional[int] = None,
          max_iterations: int = 20,
+         guarantee_train_examples: bool = True,
      ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
          r"""Collects test nodes and their labels used for evaluation.

@@ -75,13 +116,15 @@ class LocalPQueryDriver:
              anchor_time: The anchor time.
              batch_size: How many nodes to process in a single batch.
              max_iterations: The number of steps to run before aborting.
+             guarantee_train_examples: Ensures that test examples do not occupy
+                 the entire set of entity candidates.

          Returns:
              A triplet holding the nodes, timestamps and labels.
          """
          batch_size = size if batch_size is None else batch_size

-         candidate = self._get_candidates(anchor_time)
+         candidate = self._get_candidates()

          nodes: List[np.ndarray] = []
          times: List[pd.Series] = []
@@ -93,19 +136,12 @@ class LocalPQueryDriver:
              node = candidate[candidate_offset:candidate_offset + batch_size]

              if isinstance(anchor_time, pd.Timestamp):
-                 # Filter out non-existent entities:
-                 time = self._graph_store.time_dict.get(
-                     self._query.entity.pkey.table_name)
-                 if time is not None:
-                     node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-             if isinstance(anchor_time, pd.Timestamp):
+                 node = self._filter_candidates_by_time(node, anchor_time)
                  time = pd.Series(anchor_time).repeat(len(node))
                  time = time.astype('datetime64[ns]').reset_index(drop=True)
              else:
                  assert anchor_time == 'entity'
-                 time = self._graph_store.time_dict[
-                     self._query.entity.pkey.table_name]
+                 time = self._graph_store.time_dict[self._query.entity_table]
                  time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')

              y, mask = self(node, time)
@@ -148,6 +184,16 @@ class LocalPQueryDriver:
                  f"using the 'max_pq_iterations' option. This "
                  f"warning will not be shown again in this run.")

+         if (guarantee_train_examples
+                 and self._query.query_type == QueryType.STATIC
+                 and candidate_offset >= len(candidate)):
+             # In case all valid entities are used as test examples, we can no
+             # longer find any training example. Fallback to a 50/50 split:
+             size = len(node) // 2
+             node = node[:size]
+             time = time.iloc[:size]
+             y = y.iloc[:size]
+
          return node, time, y

      def collect_train(
@@ -172,7 +218,7 @@ class LocalPQueryDriver:
          """
          batch_size = size if batch_size is None else batch_size

-         candidate = self._get_candidates(anchor_time, exclude_node)
+         candidate = self._get_candidates(exclude_node)

          if len(candidate) == 0:
              raise RuntimeError("Failed to generate any context examples "
@@ -182,28 +228,18 @@ class LocalPQueryDriver:
          times: List[pd.Series] = []
          ys: List[pd.Series] = []

-         if isinstance(anchor_time, pd.Timestamp):
-             anchor_time = anchor_time - self._query.target.end_offset
-
          reached_end = False
          num_labels = candidate_offset = 0
          for _ in range(max_iterations):
              node = candidate[candidate_offset:candidate_offset + batch_size]

              if isinstance(anchor_time, pd.Timestamp):
-                 # Filter out non-existent entities:
-                 time = self._graph_store.time_dict.get(
-                     self._query.entity.pkey.table_name)
-                 if time is not None:
-                     node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-             if isinstance(anchor_time, pd.Timestamp):
+                 node = self._filter_candidates_by_time(node, anchor_time)
                  time = pd.Series(anchor_time).repeat(len(node))
                  time = time.astype('datetime64[ns]').reset_index(drop=True)
              else:
                  assert anchor_time == 'entity'
-                 time = self._graph_store.time_dict[
-                     self._query.entity.pkey.table_name]
+                 time = self._graph_store.time_dict[self._query.entity_table]
                  time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')

              y, mask = self(node, time)
@@ -228,7 +264,9 @@ class LocalPQueryDriver:
                      reached_end = True
                      break
                  candidate_offset = 0
-                 anchor_time = anchor_time - self._query.target.end_offset
+                 time_frame = self._query.target_timeframe.timeframe
+                 anchor_time = anchor_time - (time_frame *
+                                              self._query.num_forecasts)
                  if anchor_time < self._graph_store.min_time:
                      reached_end = True
                      break  # No earlier anchor time left. Abort.
@@ -257,13 +295,171 @@ class LocalPQueryDriver:

          return node, time, y

-     def __call__(
+     def is_valid(
+         self,
+         node: np.ndarray,
+         anchor_time: Union[pd.Timestamp, Literal['entity']],
+         batch_size: int = 10_000,
+     ) -> np.ndarray:
+         r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
+         which nodes fulfill entity filter constraints.
+
+         Args:
+             node: The nodes to check for.
+             anchor_time: The anchor time.
+             batch_size: How many nodes to process in a single batch.
+
+         Returns:
+             The mask.
+         """
+         mask: Optional[np.ndarray] = None
+
+         if isinstance(anchor_time, pd.Timestamp):
+             node = self._filter_candidates_by_time(node, anchor_time)
+             time = pd.Series(anchor_time).repeat(len(node))
+             time = time.astype('datetime64[ns]').reset_index(drop=True)
+         else:
+             assert anchor_time == 'entity'
+             time = self._graph_store.time_dict[self._query.entity_table]
+             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
+
+         if isinstance(self._query.entity_ast, Filter):
+             # Mask out via (temporal) entity filter:
+             executor = PQueryPandasExecutor()
+             masks: List[np.ndarray] = []
+             for start in range(0, len(node), batch_size):
+                 feat_dict, time_dict, batch_dict = self._sample(
+                     node[start:start + batch_size],
+                     time.iloc[start:start + batch_size],
+                 )
+                 _mask = executor.execute_filter(
+                     filter=self._query.entity_ast,
+                     feat_dict=feat_dict,
+                     time_dict=time_dict,
+                     batch_dict=batch_dict,
+                     anchor_time=time.iloc[start:start + batch_size],
+                 )[1]
+                 masks.append(_mask)
+
+             _mask = np.concatenate(masks)
+             mask = (mask & _mask) if mask is not None else _mask
+
+         if mask is None:
+             mask = np.ones(len(node), dtype=bool)
+
+         return mask
+
+     def _get_sampling_specs(
+         self,
+         node: ASTNode,
+         hop: int,
+         seed_table_name: str,
+         edge_types: List[Tuple[str, str, str]],
+         num_forecasts: int = 1,
+     ) -> List[SamplingSpec]:
+         if isinstance(node, (Aggregation, Column)):
+             if isinstance(node, Column):
+                 table_name = node.fqn.split('.')[0]
+                 if seed_table_name == table_name:
+                     return []
+             else:
+                 table_name = node._get_target_column_name().split('.')[0]
+
+             target_edge_types = [
+                 edge_type for edge_type in edge_types if
+                 edge_type[2] == seed_table_name and edge_type[0] == table_name
+             ]
+             if len(target_edge_types) != 1:
+                 raise ValueError(
+                     f"Could not find a unique foreign key from table "
+                     f"'{seed_table_name}' to '{table_name}'")
+
+             if isinstance(node, Column):
+                 return [
+                     SamplingSpec(
+                         edge_type=target_edge_types[0],
+                         hop=hop + 1,
+                         start_offset=None,
+                         end_offset=None,
+                     )
+                 ]
+             spec = SamplingSpec(
+                 edge_type=target_edge_types[0],
+                 hop=hop + 1,
+                 start_offset=node.aggr_time_range.start_date_offset,
+                 end_offset=node.aggr_time_range.end_date_offset *
+                 num_forecasts,
+             )
+             return [spec] + self._get_sampling_specs(
+                 node.target, hop=hop + 1, seed_table_name=table_name,
+                 edge_types=edge_types, num_forecasts=num_forecasts)
+         specs = []
+         for child in node.children:
+             specs += self._get_sampling_specs(child, hop, seed_table_name,
+                                               edge_types, num_forecasts)
+         return specs
+
+     def get_sampling_specs(self) -> List[SamplingSpec]:
+         edge_types = self._graph_store.edge_types
+         specs = self._get_sampling_specs(
+             self._query.target_ast, hop=0,
+             seed_table_name=self._query.entity_table, edge_types=edge_types,
+             num_forecasts=self._query.num_forecasts)
+         specs += self._get_sampling_specs(
+             self._query.entity_ast, hop=0,
+             seed_table_name=self._query.entity_table, edge_types=edge_types)
+         if self._query.whatif_ast is not None:
+             specs += self._get_sampling_specs(
+                 self._query.whatif_ast, hop=0,
+                 seed_table_name=self._query.entity_table,
+                 edge_types=edge_types)
+         # Group specs according to edge type and hop:
+         spec_dict: Dict[
+             Tuple[Tuple[str, str, str], int],
+             Tuple[Optional[DateOffset], Optional[DateOffset]],
+         ] = {}
+         for spec in specs:
+             if (spec.edge_type, spec.hop) not in spec_dict:
+                 spec_dict[(spec.edge_type, spec.hop)] = (
+                     spec.start_offset,
+                     spec.end_offset,
+                 )
+             else:
+                 start_offset, end_offset = spec_dict[(
+                     spec.edge_type,
+                     spec.hop,
+                 )]
+                 spec_dict[(spec.edge_type, spec.hop)] = (
+                     min_date_offset(start_offset, spec.start_offset),
+                     max_date_offset(end_offset, spec.end_offset),
+                 )
+
+         return [
+             SamplingSpec(edge, hop, start_offset, end_offset)
+             for (edge, hop), (start_offset, end_offset) in spec_dict.items()
+         ]
+
+     def _sample(
          self,
          node: np.ndarray,
          anchor_time: pd.Series,
-     ) -> Tuple[pd.Series, np.ndarray]:
+     ) -> Tuple[
+             Dict[str, pd.DataFrame],
+             Dict[str, pd.Series],
+             Dict[str, np.ndarray],
+     ]:
+         r"""Samples a subgraph that contains all relevant information to
+         evaluate the predictive query.
+
+         Args:
+             node: The nodes to check for.
+             anchor_time: The anchor time.

-         specs = self._query.get_sampling_specs(self._graph_store.edge_types)
+         Returns:
+             The feature dictionary, the time column dictionary and the batch
+             dictionary.
+         """
+         specs = self.get_sampling_specs()
          num_hops = max([spec.hop for spec in specs] + [0])
          num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
          time_offsets: Dict[
@@ -275,11 +471,10 @@ class LocalPQueryDriver:
              if spec.edge_type not in time_offsets:
                  time_offsets[spec.edge_type] = [[0, 0]
                                                  for _ in range(num_hops)]
-             offset: Optional[int] = _date_offset_to_seconds(
-                 spec.end_offset)
+             offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
              time_offsets[spec.edge_type][spec.hop - 1][1] = offset
              if spec.start_offset is not None:
-                 offset = _date_offset_to_seconds(spec.start_offset)
+                 offset = date_offset_to_seconds(spec.start_offset)
              else:
                  offset = None
              time_offsets[spec.edge_type][spec.hop - 1][0] = offset
@@ -290,7 +485,7 @@ class LocalPQueryDriver:

          edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
          node_types = list(
-             set([self._query.entity.pkey.table_name])
+             set([self._query.entity_table])
              | set(src for src, _, _ in edge_types)
              | set(dst for _, _, dst in edge_types))

@@ -322,37 +517,116 @@ class LocalPQueryDriver:
                  '__'.join(edge_type): np.array(values)
                  for edge_type, values in time_offsets.items()
              },
-             self._query.entity.pkey.table_name,
+             self._query.entity_table,
              node,
              anchor_time.astype(int).to_numpy() // 1000**3,
          )

          feat_dict: Dict[str, pd.DataFrame] = {}
          time_dict: Dict[str, pd.Series] = {}
-         column_dict = self._query.column_dict
-         time_tables = self._query.time_tables
+         column_dict: Dict[str, Set[str]] = {}
+         for col in self._query.all_query_columns:
+             table_name, col_name = col.split('.')
+             if table_name not in column_dict:
+                 column_dict[table_name] = set()
+             if col_name != '*':
+                 column_dict[table_name].add(col_name)
+         time_tables = self.find_time_tables()
          for table_name in set(list(column_dict.keys()) + time_tables):
              df = self._graph_store.df_dict[table_name]
              row_id = node_dict[table_name]
              df = df.iloc[row_id].reset_index(drop=True)
              if table_name in column_dict:
-                 feat_dict[table_name] = df[list(column_dict[table_name])]
+                 if len(column_dict[table_name]) == 0:
+                     # We are dealing with COUNT(table.*), insert a dummy col
+                     # to ensure we don't lose the information on node count
+                     feat_dict[table_name] = pd.DataFrame(
+                         {'ones': [1] * len(df)})
+                 else:
+                     feat_dict[table_name] = df[list(column_dict[table_name])]
              if table_name in time_tables:
                  time_col = self._graph_store.time_column_dict[table_name]
                  time_dict[table_name] = df[time_col]

-         y, mask = PQueryPandasBackend().eval_pquery(
+         return feat_dict, time_dict, batch_dict
+
+     def __call__(
+         self,
+         node: np.ndarray,
+         anchor_time: pd.Series,
+     ) -> Tuple[pd.Series, np.ndarray]:
+
+         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
+
+         y, mask = PQueryPandasExecutor().execute(
              query=self._query,
              feat_dict=feat_dict,
              time_dict=time_dict,
              batch_dict=batch_dict,
              anchor_time=anchor_time,
+             num_forecasts=self._query.num_forecasts,
          )

          return y, mask

-
- def _date_offset_to_seconds(offset: pd.DateOffset) -> int:
+     def find_time_tables(self) -> List[str]:
+         def _find_time_tables(node: ASTNode) -> List[str]:
+             time_tables = []
+             if isinstance(node, Aggregation):
+                 time_tables.append(
+                     node._get_target_column_name().split('.')[0])
+             for child in node.children:
+                 time_tables += _find_time_tables(child)
+             return time_tables
+
+         time_tables = _find_time_tables(
+             self._query.target_ast) + _find_time_tables(self._query.entity_ast)
+         if self._query.whatif_ast is not None:
+             time_tables += _find_time_tables(self._query.whatif_ast)
+         return list(set(time_tables))
+
+     @staticmethod
+     def get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: List[Tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+
+ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
      r"""Convert a :class:`pandas.DateOffset` into a maximum number of
      nanoseconds.

@@ -391,3 +665,25 @@ def _date_offset_to_seconds(offset: pd.DateOffset) -> int:
              total_ns += scaled_value

      return total_ns
+
+
+ def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+     if any(arg is None for arg in args):
+         return None
+
+     anchor = pd.Timestamp('2000-01-01')
+     timestamps = [anchor + arg for arg in args]
+     assert len(timestamps) > 0
+     argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+     return args[argmin]
+
+
+ def max_date_offset(*args: DateOffset) -> DateOffset:
+     if any(arg is None for arg in args):
+         return None
+
+     anchor = pd.Timestamp('2000-01-01')
+     timestamps = [anchor + arg for arg in args]
+     assert len(timestamps) > 0
+     argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+     return args[argmax]
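
The new min_date_offset/max_date_offset helpers above compare pd.DateOffset values, which have no well-defined ordering on their own, by adding each offset to a common anchor timestamp and comparing the resulting timestamps. A minimal standalone sketch of that comparison trick, using the same arbitrary anchor date as the helpers (the example offsets are assumptions for illustration):

    import pandas as pd

    offsets = [pd.DateOffset(days=45), pd.DateOffset(months=1)]
    anchor = pd.Timestamp('2000-01-01')
    # Adding each offset to the same anchor yields comparable timestamps:
    resolved = [anchor + offset for offset in offsets]
    largest = offsets[max(range(len(resolved)), key=lambda i: resolved[i])]
    print(largest)  # <DateOffset: days=45>, since 45 days reaches past one month from this anchor
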
@@ -1,7 +1,7 @@
- from .backend import PQueryBackend
- from .pandas_backend import PQueryPandasBackend
+ from .executor import PQueryExecutor
+ from .pandas_executor import PQueryPandasExecutor

  __all__ = [
-     'PQueryBackend',
-     'PQueryPandasBackend',
+     'PQueryExecutor',
+     'PQueryPandasExecutor',
  ]