kumoai 2.9.0.dev202509061830__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. kumoai/__init__.py +4 -2
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +10 -5
  4. kumoai/client/rfm.py +3 -2
  5. kumoai/connector/file_upload_connector.py +71 -102
  6. kumoai/connector/utils.py +1367 -236
  7. kumoai/experimental/rfm/__init__.py +2 -2
  8. kumoai/experimental/rfm/authenticate.py +8 -5
  9. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  10. kumoai/experimental/rfm/local_graph.py +90 -80
  11. kumoai/experimental/rfm/local_graph_sampler.py +16 -8
  12. kumoai/experimental/rfm/local_graph_store.py +22 -6
  13. kumoai/experimental/rfm/local_pquery_driver.py +129 -28
  14. kumoai/experimental/rfm/local_table.py +100 -22
  15. kumoai/experimental/rfm/pquery/__init__.py +4 -0
  16. kumoai/experimental/rfm/pquery/backend.py +4 -0
  17. kumoai/experimental/rfm/pquery/executor.py +102 -0
  18. kumoai/experimental/rfm/pquery/pandas_backend.py +71 -30
  19. kumoai/experimental/rfm/pquery/pandas_executor.py +506 -0
  20. kumoai/experimental/rfm/rfm.py +442 -94
  21. kumoai/jobs.py +1 -0
  22. kumoai/trainer/trainer.py +19 -10
  23. kumoai/utils/progress_logger.py +62 -0
  24. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/METADATA +4 -5
  25. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/RECORD +28 -26
  26. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/WHEEL +0 -0
  27. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/licenses/LICENSE +0 -0
  28. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/top_level.txt +0 -0

kumoai/experimental/rfm/local_pquery_driver.py

@@ -18,7 +18,7 @@ class LocalPQueryDriver:
         self,
         graph_store: LocalGraphStore,
         query: PQueryDefinition,
-        random_seed: Optional[int],
+        random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
         self._query = query
@@ -27,7 +27,6 @@ class LocalPQueryDriver:
 
     def _get_candidates(
         self,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
         exclude_node: Optional[np.ndarray] = None,
     ) -> np.ndarray:
 
@@ -61,12 +60,37 @@
 
         return candidate
 
+    def _filter_candidates_by_time(
+        self,
+        candidate: np.ndarray,
+        anchor_time: pd.Timestamp,
+    ) -> np.ndarray:
+
+        entity = self._query.entity.pkey.table_name
+
+        # Filter out entities that do not exist yet in time:
+        time_sec = self._graph_store.time_dict.get(entity)
+        if time_sec is not None:
+            mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
+            candidate = candidate[mask]
+
+        # Filter out entities that no longer exist in time:
+        end_time_col = self._graph_store.end_time_column_dict.get(entity)
+        if end_time_col is not None:
+            ser = self._graph_store.df_dict[entity][end_time_col]
+            ser = ser.iloc[candidate]
+            mask = (anchor_time < ser) | ser.isna().to_numpy()
+            candidate = candidate[mask]
+
+        return candidate
+
     def collect_test(
         self,
         size: int,
         anchor_time: Union[pd.Timestamp, Literal['entity']],
         batch_size: Optional[int] = None,
         max_iterations: int = 20,
+        guarantee_train_examples: bool = True,
     ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
         r"""Collects test nodes and their labels used for evaluation.
 
@@ -75,13 +99,15 @@
             anchor_time: The anchor time.
             batch_size: How many nodes to process in a single batch.
             max_iterations: The number of steps to run before aborting.
+            guarantee_train_examples: Ensures that test examples do not occupy
+                the entire set of entity candidates.
 
         Returns:
             A triplet holding the nodes, timestamps and labels.
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(anchor_time)
+        candidate = self._get_candidates()
 
         nodes: List[np.ndarray] = []
         times: List[pd.Series] = []
@@ -93,13 +119,7 @@
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-                # Filter out non-existent entities:
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
@@ -148,6 +168,16 @@
                 f"using the 'max_pq_iterations' option. This "
                 f"warning will not be shown again in this run.")
 
+        if (guarantee_train_examples
+                and self._query.query_type == QueryType.STATIC
+                and candidate_offset >= len(candidate)):
+            # In case all valid entities are used as test examples, we can no
+            # longer find any training example. Fallback to a 50/50 split:
+            size = len(node) // 2
+            node = node[:size]
+            time = time.iloc[:size]
+            y = y.iloc[:size]
+
         return node, time, y
 
     def collect_train(
@@ -172,7 +202,7 @@
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(anchor_time, exclude_node)
+        candidate = self._get_candidates(exclude_node)
 
         if len(candidate) == 0:
             raise RuntimeError("Failed to generate any context examples "
@@ -182,22 +212,13 @@
         times: List[pd.Series] = []
         ys: List[pd.Series] = []
 
-        if isinstance(anchor_time, pd.Timestamp):
-            anchor_time = anchor_time - self._query.target.end_offset
-
         reached_end = False
         num_labels = candidate_offset = 0
         for _ in range(max_iterations):
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-                # Filter out non-existent entities:
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
@@ -228,7 +249,8 @@
                 reached_end = True
                 break
             candidate_offset = 0
-            anchor_time = anchor_time - self._query.target.end_offset
+            anchor_time = anchor_time - (self._query.target.end_offset *
+                                         self._query.num_forecasts)
            if anchor_time < self._graph_store.min_time:
                 reached_end = True
                 break  # No earlier anchor time left. Abort.
@@ -257,12 +279,81 @@
 
         return node, time, y
 
-    def __call__(
+    def is_valid(
+        self,
+        node: np.ndarray,
+        anchor_time: Union[pd.Timestamp, Literal['entity']],
+        batch_size: int = 10_000,
+    ) -> np.ndarray:
+        r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
+        which nodes fulfill entity filter constraints.
+
+        Args:
+            node: The nodes to check for.
+            anchor_time: The anchor time.
+            batch_size: How many nodes to process in a single batch.
+
+        Returns:
+            The mask.
+        """
+        mask: Optional[np.ndarray] = None
+
+        if isinstance(anchor_time, pd.Timestamp):
+            node = self._filter_candidates_by_time(node, anchor_time)
+            time = pd.Series(anchor_time).repeat(len(node))
+            time = time.astype('datetime64[ns]').reset_index(drop=True)
+        else:
+            assert anchor_time == 'entity'
+            time = self._graph_store.time_dict[
+                self._query.entity.pkey.table_name]
+            time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
+
+        if self._query.entity.filter is not None:
+            # Mask out via (temporal) entity filter:
+            backend = PQueryPandasBackend()
+            masks: List[np.ndarray] = []
+            for start in range(0, len(node), batch_size):
+                feat_dict, time_dict, batch_dict = self._sample(
+                    node[start:start + batch_size],
+                    time.iloc[start:start + batch_size],
+                )
+                _mask = backend.eval_filter(
+                    filter=self._query.entity.filter,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=time.iloc[start:start + batch_size],
+                )
+                masks.append(_mask)
+
+            _mask = np.concatenate(masks)
+            mask = (mask & _mask) if mask is not None else _mask
+
+        if mask is None:
+            mask = np.ones(len(node), dtype=bool)
+
+        return mask
+
+    def _sample(
         self,
         node: np.ndarray,
         anchor_time: pd.Series,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> Tuple[
+            Dict[str, pd.DataFrame],
+            Dict[str, pd.Series],
+            Dict[str, np.ndarray],
+    ]:
+        r"""Samples a subgraph that contains all relevant information to
+        evaluate the predictive query.
+
+        Args:
+            node: The nodes to check for.
+            anchor_time: The anchor time.
 
+        Returns:
+            The feature dictionary, the time column dictionary and the batch
+            dictionary.
+        """
         specs = self._query.get_sampling_specs(self._graph_store.edge_types)
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
@@ -275,11 +366,10 @@
             if spec.edge_type not in time_offsets:
                 time_offsets[spec.edge_type] = [[0, 0]
                                                 for _ in range(num_hops)]
-            offset: Optional[int] = _date_offset_to_seconds(
-                spec.end_offset)
+            offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
             time_offsets[spec.edge_type][spec.hop - 1][1] = offset
             if spec.start_offset is not None:
-                offset = _date_offset_to_seconds(spec.start_offset)
+                offset = date_offset_to_seconds(spec.start_offset)
             else:
                 offset = None
             time_offsets[spec.edge_type][spec.hop - 1][0] = offset
@@ -341,18 +431,29 @@
             time_col = self._graph_store.time_column_dict[table_name]
             time_dict[table_name] = df[time_col]
 
+        return feat_dict, time_dict, batch_dict
+
+    def __call__(
+        self,
+        node: np.ndarray,
+        anchor_time: pd.Series,
+    ) -> Tuple[pd.Series, np.ndarray]:
+
+        feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
+
         y, mask = PQueryPandasBackend().eval_pquery(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
             batch_dict=batch_dict,
             anchor_time=anchor_time,
+            num_forecasts=self._query.num_forecasts,
         )
 
         return y, mask
 
 
-def _date_offset_to_seconds(offset: pd.DateOffset) -> int:
+def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
     nanoseconds.
 
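Note: the new `_filter_candidates_by_time` helper above combines two checks that previously lived inline: a candidate entity must already exist at the anchor time (its start time is at or before the anchor) and must not have ended yet (its end time, if any, is after the anchor). A minimal standalone sketch of the same rule, using plain pandas/numpy and hypothetical column names rather than the kumoai API:

    import numpy as np
    import pandas as pd

    # Hypothetical entity table: each row has a creation time and an
    # optional end time (NaT means the entity is still active).
    df = pd.DataFrame({
        'created_at': pd.to_datetime(['2024-01-01', '2024-03-01', '2024-06-01']),
        'ended_at': pd.to_datetime(['2024-05-01', None, None]),
    })
    candidate = np.array([0, 1, 2])
    anchor_time = pd.Timestamp('2024-04-01')

    # Keep entities that already exist at the anchor time ...
    mask = df['created_at'].to_numpy()[candidate] <= anchor_time.to_datetime64()
    candidate = candidate[mask]

    # ... and that have not ended yet at the anchor time:
    end = df['ended_at'].iloc[candidate]
    mask = (anchor_time < end).to_numpy() | end.isna().to_numpy()
    candidate = candidate[mask]

    print(candidate)  # [0 1] -- entity 2 does not exist yet at the anchor time
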

kumoai/experimental/rfm/local_table.py

@@ -23,11 +23,13 @@ class Column:
         stype: Stype,
         is_primary_key: bool = False,
         is_time_column: bool = False,
+        is_end_time_column: bool = False,
     ) -> None:
         self._name = name
         self._dtype = Dtype(dtype)
         self._is_primary_key = is_primary_key
         self._is_time_column = is_time_column
+        self._is_end_time_column = is_end_time_column
         self.stype = Stype(stype)
 
     @property
@@ -50,9 +52,12 @@ class Column:
         if self._is_primary_key and val != Stype.ID:
             raise ValueError(f"Primary key '{self.name}' must have 'ID' "
                              f"semantic type (got '{val}')")
-        if self.name == self._is_time_column and val != Stype.timestamp:
+        if self._is_time_column and val != Stype.timestamp:
             raise ValueError(f"Time column '{self.name}' must have "
                              f"'timestamp' semantic type (got '{val}')")
+        if self._is_end_time_column and val != Stype.timestamp:
+            raise ValueError(f"End time column '{self.name}' must have "
+                             f"'timestamp' semantic type (got '{val}')")
 
         super().__setattr__(key, val)
 
@@ -93,6 +98,7 @@ class LocalTable:
             name="my_table",
             primary_key="id",
             time_column="time",
+            end_time_column=None,
         )
 
         # Verify metadata:
@@ -106,6 +112,8 @@ class LocalTable:
         name: The name of the table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
+        end_time_column: The name of the end time column of this table, if it
+            exists.
     """
     def __init__(
         self,
@@ -113,6 +121,7 @@
         name: str,
         primary_key: Optional[str] = None,
         time_column: Optional[str] = None,
+        end_time_column: Optional[str] = None,
     ) -> None:
 
         if df.empty:
@@ -130,6 +139,7 @@
         self._name = name
         self._primary_key: Optional[str] = None
         self._time_column: Optional[str] = None
+        self._end_time_column: Optional[str] = None
 
         self._columns: Dict[str, Column] = {}
         for column_name in df.columns:
@@ -141,6 +151,9 @@
         if time_column is not None:
             self.time_column = time_column
 
+        if end_time_column is not None:
+            self.end_time_column = end_time_column
+
     @property
     def name(self) -> str:
         r"""The name of the table."""
@@ -230,6 +243,8 @@
             self.primary_key = None
         if self._time_column == name:
             self.time_column = None
+        if self._end_time_column == name:
+            self.end_time_column = None
         del self._columns[name]
 
         return self
@@ -253,9 +268,8 @@
         :class:`ValueError` if the primary key has a non-ID semantic type or
         if the column name does not match a column in the data frame.
         """
-        if not self.has_primary_key():
+        if self._primary_key is None:
             return None
-        assert self._primary_key is not None
         return self[self._primary_key]
 
     @primary_key.setter
@@ -264,6 +278,10 @@
             raise ValueError(f"Cannot specify column '{name}' as a primary "
                              f"key since it is already defined to be a time "
                              f"column")
+        if name is not None and name == self._end_time_column:
+            raise ValueError(f"Cannot specify column '{name}' as a primary "
+                             f"key since it is already defined to be an end "
+                             f"time column")
 
         if self.primary_key is not None:
             self.primary_key._is_primary_key = False
@@ -295,9 +313,8 @@
         :class:`ValueError` if the time column has a non-timestamp semantic
         type or if the column name does not match a column in the data frame.
         """
-        if not self.has_time_column():
+        if self._time_column is None:
             return None
-        assert self._time_column is not None
         return self[self._time_column]
 
     @time_column.setter
@@ -306,6 +323,10 @@
             raise ValueError(f"Cannot specify column '{name}' as a time "
                              f"column since it is already defined to be a "
                              f"primary key")
+        if name is not None and name == self._end_time_column:
+            raise ValueError(f"Cannot specify column '{name}' as a time "
+                             f"column since it is already defined to be an "
+                             f"end time column")
 
         if self.time_column is not None:
             self.time_column._is_time_column = False
@@ -318,6 +339,52 @@
         self[name]._is_time_column = True
         self._time_column = name
 
+    # End Time column #########################################################
+
+    def has_end_time_column(self) -> bool:
+        r"""Returns ``True`` if this table has an end time column; ``False``
+        otherwise.
+        """
+        return self._end_time_column is not None
+
+    @property
+    def end_time_column(self) -> Optional[Column]:
+        r"""The end time column of this table.
+
+        The getter returns the end time column of this table, or ``None`` if no
+        such end time column is present.
+
+        The setter sets a column as an end time column on this table, and
+        raises a :class:`ValueError` if the end time column has a non-timestamp
+        semantic type or if the column name does not match a column in the data
+        frame.
+        """
+        if self._end_time_column is None:
+            return None
+        return self[self._end_time_column]
+
+    @end_time_column.setter
+    def end_time_column(self, name: Optional[str]) -> None:
+        if name is not None and name == self._primary_key:
+            raise ValueError(f"Cannot specify column '{name}' as an end time "
+                             f"column since it is already defined to be a "
+                             f"primary key")
+        if name is not None and name == self._time_column:
+            raise ValueError(f"Cannot specify column '{name}' as an end time "
+                             f"column since it is already defined to be a "
+                             f"time column")
+
+        if self.end_time_column is not None:
+            self.end_time_column._is_end_time_column = False
+
+        if name is None:
+            self._end_time_column = None
+            return
+
+        self[name].stype = Stype.timestamp
+        self[name]._is_end_time_column = True
+        self._end_time_column = name
+
     # Metadata ################################################################
 
     @property
@@ -326,16 +393,18 @@
         information about the columns in this table.
 
         The returned dataframe has columns ``name``, ``dtype``, ``stype``,
-        ``is_primary_key``, and ``is_time_column``, which provide an aggregate
-        view of the properties of the columns of this table.
+        ``is_primary_key``, ``is_time_column`` and ``is_end_time_column``,
+        which provide an aggregate view of the properties of the columns of
+        this table.
 
         Example:
+            >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
             >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
            >>> table.metadata
-                     name    dtype stype  is_primary_key  is_time_column
-            0  CustomerID  float64    ID            True           False
-        """
+                     name    dtype stype  is_primary_key  is_time_column  is_end_time_column
+            0  CustomerID  float64    ID            True           False               False
+        """  # noqa: E501
         cols = self.columns
 
         return pd.DataFrame({
@@ -355,6 +424,11 @@
                 dtype=bool,
                 data=[self._time_column == c.name for c in cols],
             ),
+            'is_end_time_column':
+            pd.Series(
+                dtype=bool,
+                data=[self._end_time_column == c.name for c in cols],
+            ),
         })
 
     def print_metadata(self) -> None:
@@ -417,6 +491,7 @@
         candidates = [
             column.name for column in self.columns
             if column.stype == Stype.timestamp
+            and column.name != self._end_time_column
         ]
         if time_column := utils.detect_time_column(self._data, candidates):
             self.time_column = time_column
@@ -430,24 +505,26 @@
     # Helpers #################################################################
 
     def _to_api_table_definition(self) -> TableDefinition:
-        cols: List[ColumnDefinition] = []
-        for col in self.columns:
-            cols.append(ColumnDefinition(col.name, col.stype, col.dtype))
-        pkey = self._primary_key
-        time_col = self._time_column
-        source_table = UnavailableSourceTable(table=self.name)
-
         return TableDefinition(
-            cols=cols,
-            source_table=source_table,
-            pkey=pkey,
-            time_col=time_col,
+            cols=[
+                ColumnDefinition(col.name, col.stype, col.dtype)
+                for col in self.columns
+            ],
+            source_table=UnavailableSourceTable(table=self.name),
+            pkey=self._primary_key,
+            time_col=self._time_column,
+            end_time_col=self._end_time_column,
         )
 
     # Python builtins #########################################################
 
     def __hash__(self) -> int:
-        return hash(tuple(self.columns + [self.primary_key, self.time_column]))
+        special_columns = [
+            self.primary_key,
+            self.time_column,
+            self.end_time_column,
+        ]
+        return hash(tuple(self.columns + special_columns))
 
     def __contains__(self, name: str) -> bool:
         return self.has_column(name)
@@ -464,4 +541,5 @@
                 f'  num_columns={len(self.columns)},\n'
                 f'  primary_key={self._primary_key},\n'
                 f'  time_column={self._time_column},\n'
+                f'  end_time_column={self._end_time_column},\n'
                 f')')
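
Taken together, the LocalTable changes add an optional end time column alongside the existing primary key and time column. A minimal usage sketch, assuming the import path shown in the docstring example above and hypothetical table/column names (`subscriptions`, `valid_from`, `valid_to`):

    import pandas as pd
    import kumoai.experimental.rfm as rfm

    df = pd.DataFrame({
        'id': [1, 2, 3],
        'valid_from': pd.to_datetime(['2024-01-01', '2024-02-01', '2024-03-01']),
        'valid_to': pd.to_datetime(['2024-06-01', None, None]),  # NaT = still active
    })

    table = rfm.LocalTable(
        df=df,
        name='subscriptions',        # hypothetical table and column names
        primary_key='id',
        time_column='valid_from',
        end_time_column='valid_to',  # new in this version
    )

    print(table.has_end_time_column())  # True
    print(table.metadata)  # now includes an 'is_end_time_column' column
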

kumoai/experimental/rfm/pquery/__init__.py

@@ -1,7 +1,11 @@
 from .backend import PQueryBackend
 from .pandas_backend import PQueryPandasBackend
+from .executor import PQueryExecutor
+from .pandas_executor import PQueryPandasExecutor
 
 __all__ = [
     'PQueryBackend',
     'PQueryPandasBackend',
+    'PQueryExecutor',
+    'PQueryPandasExecutor',
 ]

kumoai/experimental/rfm/pquery/backend.py

@@ -82,6 +82,7 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
@@ -94,6 +95,7 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
@@ -106,6 +108,7 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
@@ -128,5 +131,6 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass

kumoai/experimental/rfm/pquery/executor.py (new file)

@@ -0,0 +1,102 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Generic, Tuple, TypeVar
+
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+
+TableData = TypeVar('TableData')
+ColumnData = TypeVar('ColumnData')
+IndexData = TypeVar('IndexData')
+
+
+class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
+    @abstractmethod
+    def execute_column(
+        self,
+        column: Column,
+        feat_dict: Dict[str, TableData],
+        filter_na: bool = True,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_aggregation(
+        self,
+        aggr: Aggregation,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_condition(
+        self,
+        condition: Condition,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_logical_operation(
+        self,
+        logical_operation: LogicalOperation,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_join(
+        self,
+        join: Join,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_filter(
+        self,
+        filter: Filter,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute(
+        self,
+        query: ValidatedPredictiveQuery,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
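
The new `PQueryExecutor` mirrors the existing `PQueryBackend`: an abstract base class that is generic over the backend-specific table, column, and index types, with `num_forecasts` threaded through the query-evaluation methods, and `PQueryPandasExecutor` (added in `pandas_executor.py`) as its pandas implementation. For readers unfamiliar with this pattern, a self-contained toy sketch (plain Python, not the kumoai/kumoapi API) of how such a generic executor is specialized:

    from abc import ABC, abstractmethod
    from typing import Dict, Generic, Tuple, TypeVar

    import numpy as np
    import pandas as pd

    TableData = TypeVar('TableData')
    ColumnData = TypeVar('ColumnData')
    IndexData = TypeVar('IndexData')


    class Executor(Generic[TableData, ColumnData, IndexData], ABC):
        """Toy stand-in for the PQueryExecutor pattern (not the kumoai API)."""
        @abstractmethod
        def execute_column(
            self,
            column: str,
            feat_dict: Dict[str, TableData],
        ) -> Tuple[ColumnData, IndexData]:
            ...


    class PandasExecutor(Executor[pd.DataFrame, pd.Series, np.ndarray]):
        """Fixes the type parameters to pandas/numpy data structures."""
        def execute_column(
            self,
            column: str,
            feat_dict: Dict[str, pd.DataFrame],
        ) -> Tuple[pd.Series, np.ndarray]:
            # Return the first matching column plus the positional index of
            # the rows it came from.
            for df in feat_dict.values():
                if column in df.columns:
                    return df[column], np.arange(len(df))
            raise KeyError(column)


    col, idx = PandasExecutor().execute_column(
        'price', {'items': pd.DataFrame({'price': [1.0, 2.0]})})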