kumoai 2.10.0.dev202510061830__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202511261731__cp313-cp313-macosx_11_0_arm64.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in the public registry.
@@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.model_plan import RunMode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype
 
@@ -33,7 +32,6 @@ class LocalGraphSampler:
         entity_table_names: Tuple[str, ...],
         node: np.ndarray,
         time: np.ndarray,
-        run_mode: RunMode,
         num_neighbors: List[int],
         exclude_cols_dict: Dict[str, List[str]],
     ) -> Subgraph:
@@ -92,15 +90,23 @@ class LocalGraphSampler:
             )
             continue
 
-        # Only store unique rows in `df` above a certain threshold:
-        unique_node, inverse_node = np.unique(node, return_inverse=True)
-        if len(node) > 1.05 * len(unique_node):
-            df = df.iloc[unique_node]
-            row = inverse_node
+        row: Optional[np.ndarray] = None
+        if table_name in self._graph_store.end_time_column_dict:
+            # Set end time to NaT for all values greater than anchor time:
+            df = df.iloc[node].reset_index(drop=True)
+            col_name = self._graph_store.end_time_column_dict[table_name]
+            ser = df[col_name]
+            value = ser.astype('datetime64[ns]').astype(int).to_numpy()
+            mask = value > time[batch]
+            df.loc[mask, col_name] = pd.NaT
         else:
-            df = df.iloc[node]
-            row = None
-        df = df.reset_index(drop=True)
+            # Only store unique rows in `df` above a certain threshold:
+            unique_node, inverse = np.unique(node, return_inverse=True)
+            if len(node) > 1.05 * len(unique_node):
+                df = df.iloc[unique_node].reset_index(drop=True)
+                row = inverse
+            else:
+                df = df.iloc[node].reset_index(drop=True)
 
         # Filter data frame to minimal set of columns:
         df = df[columns]
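A minimal sketch of the two branches above, on hypothetical table data (`valid_to`, `node`, and `anchor_ns` are stand-ins, not names from the diff): tables that carry an end-time column are gathered one row per sampled node, with end times beyond the anchor time masked to `NaT`; all other tables keep the deduplication heuristic, which only pays off once more than ~5% of the sampled rows are duplicates.

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'valid_to': pd.to_datetime(
        ['2021-01-01', '2023-01-01', '2024-06-01'])})
    node = np.array([0, 1, 1, 2])                 # sampled row ids, with duplicates
    anchor_ns = pd.Timestamp('2022-01-01').value  # anchor time in nanoseconds

    # End-time branch: one row per sampled node; end times lying in the
    # future relative to the anchor are hidden from the model:
    out = df.iloc[node].reset_index(drop=True)
    value = out['valid_to'].astype('datetime64[ns]').astype(int).to_numpy()
    out.loc[value > anchor_ns, 'valid_to'] = pd.NaT

    # Dedup branch: store unique rows plus an inverse mapping instead:
    unique_node, inverse = np.unique(node, return_inverse=True)
    if len(node) > 1.05 * len(unique_node):
        out, row = df.iloc[unique_node].reset_index(drop=True), inverse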
@@ -45,6 +45,7 @@ class LocalGraphStore:
 
         (
             self.time_column_dict,
+            self.end_time_column_dict,
             self.time_dict,
             self.min_time,
             self.max_time,
@@ -219,16 +220,21 @@ class LocalGraphStore:
         self,
         graph: LocalGraph,
     ) -> Tuple[
+        Dict[str, str],
         Dict[str, str],
         Dict[str, np.ndarray],
         pd.Timestamp,
         pd.Timestamp,
     ]:
         time_column_dict: Dict[str, str] = {}
+        end_time_column_dict: Dict[str, str] = {}
         time_dict: Dict[str, np.ndarray] = {}
         min_time = pd.Timestamp.max
         max_time = pd.Timestamp.min
         for table in graph.tables.values():
+            if table._end_time_column is not None:
+                end_time_column_dict[table.name] = table._end_time_column
+
             if table._time_column is None:
                 continue
 
@@ -243,7 +249,13 @@ class LocalGraphStore:
             min_time = min(min_time, time.min())
             max_time = max(max_time, time.max())
 
-        return time_column_dict, time_dict, min_time, max_time
+        return (
+            time_column_dict,
+            end_time_column_dict,
+            time_dict,
+            min_time,
+            max_time,
+        )
 
     def get_csc(
         self,
@@ -1,23 +1,40 @@
 import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.rfm import PQueryDefinition
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    ASTNode,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, DateOffset, Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import PQueryPandasBackend
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 
 _coverage_warned = False
 
 
+class SamplingSpec(NamedTuple):
+    edge_type: Tuple[str, str, str]
+    hop: int
+    start_offset: Optional[DateOffset]
+    end_offset: Optional[DateOffset]
+
+
 class LocalPQueryDriver:
     def __init__(
         self,
         graph_store: LocalGraphStore,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
@@ -27,14 +44,13 @@ class LocalPQueryDriver:
 
     def _get_candidates(
         self,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
         exclude_node: Optional[np.ndarray] = None,
     ) -> np.ndarray:
 
         if self._query.query_type == QueryType.TEMPORAL:
             assert exclude_node is None
 
-            table_name = self._query.entity.pkey.table_name
+            table_name = self._query.entity_table
             num_nodes = len(self._graph_store.df_dict[table_name])
             mask_dict = self._graph_store.mask_dict
 
@@ -61,6 +77,30 @@ class LocalPQueryDriver:
 
         return candidate
 
+    def _filter_candidates_by_time(
+        self,
+        candidate: np.ndarray,
+        anchor_time: pd.Timestamp,
+    ) -> np.ndarray:
+
+        entity = self._query.entity_table
+
+        # Filter out entities that do not exist yet in time:
+        time_sec = self._graph_store.time_dict.get(entity)
+        if time_sec is not None:
+            mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
+            candidate = candidate[mask]
+
+        # Filter out entities that no longer exist in time:
+        end_time_col = self._graph_store.end_time_column_dict.get(entity)
+        if end_time_col is not None:
+            ser = self._graph_store.df_dict[entity][end_time_col]
+            ser = ser.iloc[candidate]
+            mask = (anchor_time < ser) | ser.isna().to_numpy()
+            candidate = candidate[mask]
+
+        return candidate
+
     def collect_test(
         self,
         size: int,
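A toy run of the helper's two masks (all values hypothetical; creation times are stored as epoch seconds in `time_dict`, end times as datetimes in the table itself): entities born after the anchor drop out, entities whose end time lies at or before the anchor drop out, and a missing end time counts as "still alive".

    import numpy as np
    import pandas as pd

    candidate = np.array([0, 1, 2])
    anchor = pd.Timestamp('2021-03-01')
    time_sec = np.array([0, 100, 200])  # creation times (epoch seconds)
    end = pd.Series(pd.to_datetime(['2021-01-01', None, '2021-06-01']))

    # Keep entities that already exist at the anchor time:
    candidate = candidate[time_sec[candidate] <= anchor.value // (1000**3)]
    # Keep entities that have not yet ended (NaT means no end time):
    ser = end.iloc[candidate]
    mask = (anchor < ser) | ser.isna().to_numpy()
    candidate = candidate[mask]   # -> [1, 2]: entity 0 ended before the anchor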
@@ -84,7 +124,7 @@ class LocalPQueryDriver:
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(anchor_time)
+        candidate = self._get_candidates()
 
         nodes: List[np.ndarray] = []
         times: List[pd.Series] = []
@@ -96,19 +136,12 @@ class LocalPQueryDriver:
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-                # Filter out non-existent entities:
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
                 assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[self._query.entity_table]
                 time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
             y, mask = self(node, time)
@@ -185,7 +218,7 @@ class LocalPQueryDriver:
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(anchor_time, exclude_node)
+        candidate = self._get_candidates(exclude_node)
 
         if len(candidate) == 0:
             raise RuntimeError("Failed to generate any context examples "
@@ -201,19 +234,12 @@ class LocalPQueryDriver:
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-                # Filter out non-existent entities:
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
                 assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[self._query.entity_table]
                 time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
             y, mask = self(node, time)
@@ -238,7 +264,8 @@ class LocalPQueryDriver:
                 reached_end = True
                 break
             candidate_offset = 0
-            anchor_time = anchor_time - (self._query.target.end_offset *
+            time_frame = self._query.target_timeframe.timeframe
+            anchor_time = anchor_time - (time_frame *
                                          self._query.num_forecasts)
             if anchor_time < self._graph_store.min_time:
                 reached_end = True
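The rewritten stepping logic walks the anchor time backwards by one full forecast horizon per pass. A worked example with hypothetical values: a 7-day timeframe with `num_forecasts = 4` steps back 28 days at a time.

    import pandas as pd

    anchor_time = pd.Timestamp('2024-03-01')
    time_frame = pd.DateOffset(days=7)   # assumed target timeframe
    num_forecasts = 4

    anchor_time = anchor_time - time_frame * num_forecasts
    assert anchor_time == pd.Timestamp('2024-02-02')  # 28 days earlier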
@@ -288,37 +315,30 @@ class LocalPQueryDriver:
         mask: Optional[np.ndarray] = None
 
         if isinstance(anchor_time, pd.Timestamp):
-            # Mask out non-existent entities:
-            time = self._graph_store.time_dict.get(
-                self._query.entity.pkey.table_name)
-            if time is not None:
-                mask = time[node] <= (anchor_time.value // (1000**3))
-
-        if isinstance(anchor_time, pd.Timestamp):
+            node = self._filter_candidates_by_time(node, anchor_time)
             time = pd.Series(anchor_time).repeat(len(node))
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
-        if self._query.entity.filter is not None:
+        if isinstance(self._query.entity_ast, Filter):
             # Mask out via (temporal) entity filter:
-            backend = PQueryPandasBackend()
+            executor = PQueryPandasExecutor()
             masks: List[np.ndarray] = []
             for start in range(0, len(node), batch_size):
                 feat_dict, time_dict, batch_dict = self._sample(
                     node[start:start + batch_size],
                     time.iloc[start:start + batch_size],
                 )
-                _mask = backend.eval_filter(
-                    filter=self._query.entity.filter,
+                _mask = executor.execute_filter(
+                    filter=self._query.entity_ast,
                     feat_dict=feat_dict,
                     time_dict=time_dict,
                     batch_dict=batch_dict,
                     anchor_time=time.iloc[start:start + batch_size],
-                )
+                )[1]
                 masks.append(_mask)
 
             _mask = np.concatenate(masks)
@@ -329,6 +349,96 @@ class LocalPQueryDriver:
 
         return mask
 
+    def _get_sampling_specs(
+        self,
+        node: ASTNode,
+        hop: int,
+        seed_table_name: str,
+        edge_types: List[Tuple[str, str, str]],
+        num_forecasts: int = 1,
+    ) -> List[SamplingSpec]:
+        if isinstance(node, (Aggregation, Column)):
+            if isinstance(node, Column):
+                table_name = node.fqn.split('.')[0]
+                if seed_table_name == table_name:
+                    return []
+            else:
+                table_name = node._get_target_column_name().split('.')[0]
+
+            target_edge_types = [
+                edge_type for edge_type in edge_types if
+                edge_type[2] == seed_table_name and edge_type[0] == table_name
+            ]
+            if len(target_edge_types) != 1:
+                raise ValueError(
+                    f"Could not find a unique foreign key from table "
+                    f"'{seed_table_name}' to '{table_name}'")
+
+            if isinstance(node, Column):
+                return [
+                    SamplingSpec(
+                        edge_type=target_edge_types[0],
+                        hop=hop + 1,
+                        start_offset=None,
+                        end_offset=None,
+                    )
+                ]
+            spec = SamplingSpec(
+                edge_type=target_edge_types[0],
+                hop=hop + 1,
+                start_offset=node.aggr_time_range.start_date_offset,
+                end_offset=node.aggr_time_range.end_date_offset *
+                num_forecasts,
+            )
+            return [spec] + self._get_sampling_specs(
+                node.target, hop=hop + 1, seed_table_name=table_name,
+                edge_types=edge_types, num_forecasts=num_forecasts)
+        specs = []
+        for child in node.children:
+            specs += self._get_sampling_specs(child, hop, seed_table_name,
+                                              edge_types, num_forecasts)
+        return specs
+
+    def get_sampling_specs(self) -> List[SamplingSpec]:
+        edge_types = self._graph_store.edge_types
+        specs = self._get_sampling_specs(
+            self._query.target_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types,
+            num_forecasts=self._query.num_forecasts)
+        specs += self._get_sampling_specs(
+            self._query.entity_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types)
+        if self._query.whatif_ast is not None:
+            specs += self._get_sampling_specs(
+                self._query.whatif_ast, hop=0,
+                seed_table_name=self._query.entity_table,
+                edge_types=edge_types)
+        # Group specs according to edge type and hop:
+        spec_dict: Dict[
+            Tuple[Tuple[str, str, str], int],
+            Tuple[Optional[DateOffset], Optional[DateOffset]],
+        ] = {}
+        for spec in specs:
+            if (spec.edge_type, spec.hop) not in spec_dict:
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    spec.start_offset,
+                    spec.end_offset,
+                )
+            else:
+                start_offset, end_offset = spec_dict[(
+                    spec.edge_type,
+                    spec.hop,
+                )]
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    min_date_offset(start_offset, spec.start_offset),
+                    max_date_offset(end_offset, spec.end_offset),
+                )
+
+        return [
+            SamplingSpec(edge, hop, start_offset, end_offset)
+            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
+        ]
+
     def _sample(
         self,
         node: np.ndarray,
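To see the grouping at the end of `get_sampling_specs` in isolation: two specs that meet on the same `(edge_type, hop)` key are merged into one spec whose window covers both, via the `min_date_offset`/`max_date_offset` helpers added at the bottom of this file. A toy merge with pandas-style offsets (the concrete edge type and offsets are made up):

    import pandas as pd

    a = SamplingSpec(('orders', 'user_id', 'users'), hop=1,
                     start_offset=pd.DateOffset(days=-30),
                     end_offset=pd.DateOffset(days=0))
    b = SamplingSpec(('orders', 'user_id', 'users'), hop=1,
                     start_offset=pd.DateOffset(days=-7),
                     end_offset=pd.DateOffset(days=7))

    merged = SamplingSpec(
        a.edge_type, a.hop,
        min_date_offset(a.start_offset, b.start_offset),  # -> days=-30
        max_date_offset(a.end_offset, b.end_offset),      # -> days=7
    )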
@@ -349,7 +459,7 @@ class LocalPQueryDriver:
         The feature dictionary, the time column dictionary and the batch
         dictionary.
         """
-        specs = self._query.get_sampling_specs(self._graph_store.edge_types)
+        specs = self.get_sampling_specs()
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
         time_offsets: Dict[
@@ -375,7 +485,7 @@ class LocalPQueryDriver:
 
         edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
         node_types = list(
-            set([self._query.entity.pkey.table_name])
+            set([self._query.entity_table])
             | set(src for src, _, _ in edge_types)
             | set(dst for _, _, dst in edge_types))
@@ -407,21 +517,33 @@ class LocalPQueryDriver:
                 '__'.join(edge_type): np.array(values)
                 for edge_type, values in time_offsets.items()
             },
-            self._query.entity.pkey.table_name,
+            self._query.entity_table,
             node,
            anchor_time.astype(int).to_numpy() // 1000**3,
         )
 
         feat_dict: Dict[str, pd.DataFrame] = {}
         time_dict: Dict[str, pd.Series] = {}
-        column_dict = self._query.column_dict
-        time_tables = self._query.time_tables
+        column_dict: Dict[str, Set[str]] = {}
+        for col in self._query.all_query_columns:
+            table_name, col_name = col.split('.')
+            if table_name not in column_dict:
+                column_dict[table_name] = set()
+            if col_name != '*':
+                column_dict[table_name].add(col_name)
+        time_tables = self.find_time_tables()
         for table_name in set(list(column_dict.keys()) + time_tables):
             df = self._graph_store.df_dict[table_name]
             row_id = node_dict[table_name]
             df = df.iloc[row_id].reset_index(drop=True)
             if table_name in column_dict:
-                feat_dict[table_name] = df[list(column_dict[table_name])]
+                if len(column_dict[table_name]) == 0:
+                    # We are dealing with COUNT(table.*), insert a dummy col
+                    # to ensure we don't lose the information on node count
+                    feat_dict[table_name] = pd.DataFrame(
+                        {'ones': [1] * len(df)})
+                else:
+                    feat_dict[table_name] = df[list(column_dict[table_name])]
             if table_name in time_tables:
                 time_col = self._graph_store.time_column_dict[table_name]
                 time_dict[table_name] = df[time_col]
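The inlined column collection replaces the old `column_dict` property. A toy trace (the query columns are hypothetical): a plain `table.column` entry lands in its table's set, while a `table.*` entry only registers the table and leaves the set empty, which is exactly what triggers the `'ones'` dummy column above for pure `COUNT(table.*)` targets.

    all_query_columns = ['users.age', 'orders.*']   # hypothetical

    column_dict: dict = {}
    for col in all_query_columns:
        table_name, col_name = col.split('.')
        if table_name not in column_dict:
            column_dict[table_name] = set()
        if col_name != '*':
            column_dict[table_name].add(col_name)

    assert column_dict == {'users': {'age'}, 'orders': set()}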
@@ -436,7 +558,7 @@ class LocalPQueryDriver:
 
         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
 
-        y, mask = PQueryPandasBackend().eval_pquery(
+        y, mask = PQueryPandasExecutor().execute(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
@@ -447,6 +569,62 @@ class LocalPQueryDriver:
 
         return y, mask
 
+    def find_time_tables(self) -> List[str]:
+        def _find_time_tables(node: ASTNode) -> List[str]:
+            time_tables = []
+            if isinstance(node, Aggregation):
+                time_tables.append(
+                    node._get_target_column_name().split('.')[0])
+            for child in node.children:
+                time_tables += _find_time_tables(child)
+            return time_tables
+
+        time_tables = _find_time_tables(
+            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
+        if self._query.whatif_ast is not None:
+            time_tables += _find_time_tables(self._query.whatif_ast)
+        return list(set(time_tables))
+
+    @staticmethod
+    def get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
 
 def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
@@ -487,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
         total_ns += scaled_value
 
     return total_ns
+
+
+def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
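`DateOffset` objects are not directly comparable (is one month more or less than 30 days?), so both helpers rank them by applying each offset to a fixed anchor timestamp; a `None` argument (no time bound) propagates to the result. For example, assuming pandas-style offsets:

    import pandas as pd

    min_date_offset(pd.DateOffset(months=1), pd.DateOffset(days=45))  # -> months=1
    max_date_offset(pd.DateOffset(months=1), pd.DateOffset(days=45))  # -> days=45
    min_date_offset(pd.DateOffset(days=3), None)                      # -> None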