kumoai 2.10.1__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kumoai/__init__.py CHANGED
@@ -200,9 +200,11 @@ def init(
200
200
 
201
201
  logger = logging.getLogger('kumoai')
202
202
  log_level = logging.getLevelName(logger.getEffectiveLevel())
203
+
203
204
  logger.info(
204
- "Successfully initialized the Kumo SDK against deployment %s, with "
205
- "log level %s.", url, log_level)
205
+ f"Successfully initialized the Kumo SDK (version {__version__}) "
206
+ f"against deployment {url}, with "
207
+ f"log level {log_level}.")
206
208
 
207
209
 
208
210
  def set_log_level(level: str) -> None:
kumoai/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = '2.10.1'
1
+ __version__ = '2.12.0.dev202511031731'
@@ -611,8 +611,8 @@ class LocalGraph:
611
611
  raise ValueError(f"{edge} is invalid as foreign key "
612
612
  f"'{fkey}' and primary key '{dst_key.name}' "
613
613
  f"have incompatible data types (got "
614
- f"fkey.dtype '{dst_key.dtype}' and "
615
- f"pkey.dtype '{src_key.dtype}')")
614
+ f"fkey.dtype '{src_key.dtype}' and "
615
+ f"pkey.dtype '{dst_key.dtype}')")
616
616
 
617
617
  return self
618
618
 
@@ -1,7 +1,11 @@
1
1
  from .backend import PQueryBackend
2
2
  from .pandas_backend import PQueryPandasBackend
3
+ from .executor import PQueryExecutor
4
+ from .pandas_executor import PQueryPandasExecutor
3
5
 
4
6
  __all__ = [
5
7
  'PQueryBackend',
6
8
  'PQueryPandasBackend',
9
+ 'PQueryExecutor',
10
+ 'PQueryPandasExecutor',
7
11
  ]
@@ -0,0 +1,102 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Generic, Tuple, TypeVar
3
+
4
+ from kumoapi.pquery import ValidatedPredictiveQuery
5
+ from kumoapi.pquery.AST import (
6
+ Aggregation,
7
+ Column,
8
+ Condition,
9
+ Filter,
10
+ Join,
11
+ LogicalOperation,
12
+ )
13
+
14
+ TableData = TypeVar('TableData')
15
+ ColumnData = TypeVar('ColumnData')
16
+ IndexData = TypeVar('IndexData')
17
+
18
+
19
+ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
20
+ @abstractmethod
21
+ def execute_column(
22
+ self,
23
+ column: Column,
24
+ feat_dict: Dict[str, TableData],
25
+ filter_na: bool = True,
26
+ ) -> Tuple[ColumnData, IndexData]:
27
+ pass
28
+
29
+ @abstractmethod
30
+ def execute_aggregation(
31
+ self,
32
+ aggr: Aggregation,
33
+ feat_dict: Dict[str, TableData],
34
+ time_dict: Dict[str, ColumnData],
35
+ batch_dict: Dict[str, IndexData],
36
+ anchor_time: ColumnData,
37
+ filter_na: bool = True,
38
+ num_forecasts: int = 1,
39
+ ) -> Tuple[ColumnData, IndexData]:
40
+ pass
41
+
42
+ @abstractmethod
43
+ def execute_condition(
44
+ self,
45
+ condition: Condition,
46
+ feat_dict: Dict[str, TableData],
47
+ time_dict: Dict[str, ColumnData],
48
+ batch_dict: Dict[str, IndexData],
49
+ anchor_time: ColumnData,
50
+ filter_na: bool = True,
51
+ num_forecasts: int = 1,
52
+ ) -> Tuple[ColumnData, IndexData]:
53
+ pass
54
+
55
+ @abstractmethod
56
+ def execute_logical_operation(
57
+ self,
58
+ logical_operation: LogicalOperation,
59
+ feat_dict: Dict[str, TableData],
60
+ time_dict: Dict[str, ColumnData],
61
+ batch_dict: Dict[str, IndexData],
62
+ anchor_time: ColumnData,
63
+ filter_na: bool = True,
64
+ num_forecasts: int = 1,
65
+ ) -> Tuple[ColumnData, IndexData]:
66
+ pass
67
+
68
+ @abstractmethod
69
+ def execute_join(
70
+ self,
71
+ join: Join,
72
+ feat_dict: Dict[str, TableData],
73
+ time_dict: Dict[str, ColumnData],
74
+ batch_dict: Dict[str, IndexData],
75
+ anchor_time: ColumnData,
76
+ filter_na: bool = True,
77
+ num_forecasts: int = 1,
78
+ ) -> Tuple[ColumnData, IndexData]:
79
+ pass
80
+
81
+ @abstractmethod
82
+ def execute_filter(
83
+ self,
84
+ filter: Filter,
85
+ feat_dict: Dict[str, TableData],
86
+ time_dict: Dict[str, ColumnData],
87
+ batch_dict: Dict[str, IndexData],
88
+ anchor_time: ColumnData,
89
+ ) -> Tuple[ColumnData, IndexData]:
90
+ pass
91
+
92
+ @abstractmethod
93
+ def execute(
94
+ self,
95
+ query: ValidatedPredictiveQuery,
96
+ feat_dict: Dict[str, TableData],
97
+ time_dict: Dict[str, ColumnData],
98
+ batch_dict: Dict[str, IndexData],
99
+ anchor_time: ColumnData,
100
+ num_forecasts: int = 1,
101
+ ) -> Tuple[ColumnData, IndexData]:
102
+ pass
@@ -0,0 +1,506 @@
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from kumoapi.pquery import ValidatedPredictiveQuery
6
+ from kumoapi.pquery.AST import (
7
+ Aggregation,
8
+ Column,
9
+ Condition,
10
+ Constant,
11
+ Filter,
12
+ Join,
13
+ LogicalOperation,
14
+ )
15
+ from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
16
+
17
+ from kumoai.experimental.rfm.pquery import PQueryExecutor
18
+
19
+
20
+ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
21
+ np.ndarray]):
22
+ def execute_column(
23
+ self,
24
+ column: Column,
25
+ feat_dict: Dict[str, pd.DataFrame],
26
+ filter_na: bool = True,
27
+ ) -> Tuple[pd.Series, np.ndarray]:
28
+ table_name, column_name = column.fqn.split(".")
29
+ if column_name == '*':
30
+ out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
31
+ else:
32
+ out = feat_dict[table_name][column_name]
33
+ out = out.reset_index(drop=True)
34
+
35
+ if pd.api.types.is_float_dtype(out):
36
+ out = out.astype('float32')
37
+
38
+ out.name = None
39
+ out.index.name = None
40
+
41
+ mask = out.notna().to_numpy()
42
+
43
+ if not filter_na:
44
+ return out, mask
45
+
46
+ out = out[mask].reset_index(drop=True)
47
+
48
+ # Cast to primitive dtype:
49
+ if pd.api.types.is_integer_dtype(out):
50
+ out = out.astype('int64')
51
+ elif pd.api.types.is_bool_dtype(out):
52
+ out = out.astype('bool')
53
+
54
+ return out, mask
55
+
56
+ def execute_aggregation_type(
57
+ self,
58
+ op: AggregationType,
59
+ feat: pd.Series,
60
+ batch: np.ndarray,
61
+ batch_size: int,
62
+ filter_na: bool = True,
63
+ ) -> Tuple[pd.Series, np.ndarray]:
64
+
65
+ mask = feat.notna()
66
+ feat, batch = feat[mask], batch[mask]
67
+
68
+ if op == AggregationType.LIST_DISTINCT:
69
+ df = pd.DataFrame(dict(feat=feat, batch=batch))
70
+ df = df.drop_duplicates()
71
+ out = df.groupby('batch')['feat'].agg(list)
72
+
73
+ else:
74
+ df = pd.DataFrame(dict(feat=feat, batch=batch))
75
+ if op == AggregationType.AVG:
76
+ agg = 'mean'
77
+ elif op == AggregationType.COUNT:
78
+ agg = 'size'
79
+ else:
80
+ agg = op.lower()
81
+ out = df.groupby('batch')['feat'].agg(agg)
82
+
83
+ if not pd.api.types.is_datetime64_any_dtype(out):
84
+ out = out.astype('float32')
85
+
86
+ out.name = None
87
+ out.index.name = None
88
+
89
+ if op in {AggregationType.SUM, AggregationType.COUNT}:
90
+ out = out.reindex(range(batch_size), fill_value=0)
91
+ mask = np.ones(batch_size, dtype=bool)
92
+ return out, mask
93
+
94
+ mask = np.zeros(batch_size, dtype=bool)
95
+ mask[batch] = True
96
+
97
+ if filter_na:
98
+ return out.reset_index(drop=True), mask
99
+
100
+ out = out.reindex(range(batch_size), fill_value=pd.NA)
101
+
102
+ return out, mask
103
+
104
+ def execute_aggregation(
105
+ self,
106
+ aggr: Aggregation,
107
+ feat_dict: Dict[str, pd.DataFrame],
108
+ time_dict: Dict[str, pd.Series],
109
+ batch_dict: Dict[str, np.ndarray],
110
+ anchor_time: pd.Series,
111
+ filter_na: bool = True,
112
+ num_forecasts: int = 1,
113
+ ) -> Tuple[pd.Series, np.ndarray]:
114
+ target_table = aggr._get_target_column_name().split('.')[0]
115
+ target_batch = batch_dict[target_table]
116
+ target_time = time_dict[target_table]
117
+ if isinstance(aggr.target, Column):
118
+ target_feat, target_mask = self.execute_column(
119
+ column=aggr.target,
120
+ feat_dict=feat_dict,
121
+ filter_na=False,
122
+ )
123
+ else:
124
+ assert isinstance(aggr.target, Filter)
125
+ target_feat, target_mask = self.execute_filter(
126
+ filter=aggr.target,
127
+ feat_dict=feat_dict,
128
+ time_dict=time_dict,
129
+ batch_dict=batch_dict,
130
+ anchor_time=anchor_time,
131
+ filter_na=False,
132
+ )
133
+
134
+ outs: List[pd.Series] = []
135
+ masks: List[np.ndarray] = []
136
+ for _ in range(num_forecasts):
137
+ anchor_target_time = anchor_time[target_batch]
138
+ anchor_target_time = anchor_target_time.reset_index(drop=True)
139
+
140
+ curr_target_mask = target_mask & (
141
+ target_time
142
+ <= anchor_target_time + aggr.aggr_time_range.end_date_offset)
143
+ if aggr.aggr_time_range.start is not None:
144
+ start_offset = aggr.aggr_time_range.start_date_offset
145
+ curr_target_mask &= (target_time
146
+ > anchor_target_time + start_offset)
147
+ else:
148
+ assert num_forecasts == 1
149
+
150
+ out, mask = self.execute_aggregation_type(
151
+ aggr.aggr,
152
+ feat=target_feat[curr_target_mask],
153
+ batch=target_batch[curr_target_mask],
154
+ batch_size=len(anchor_time),
155
+ filter_na=False if num_forecasts > 1 else filter_na,
156
+ )
157
+ outs.append(out)
158
+ masks.append(mask)
159
+
160
+ if num_forecasts > 1:
161
+ anchor_time = (anchor_time +
162
+ aggr.aggr_time_range.end_date_offset)
163
+ if len(outs) == 1:
164
+ assert len(masks) == 1
165
+ return outs[0], masks[0]
166
+
167
+ out = pd.Series([list(ser) for ser in zip(*outs)])
168
+ mask = np.stack(masks, axis=-1).any(axis=-1) # type: ignore
169
+
170
+ if filter_na:
171
+ out = out[mask].reset_index(drop=True)
172
+
173
+ return out, mask
174
+
175
+ def execute_rel_op(
176
+ self,
177
+ left: pd.Series,
178
+ op: RelOp,
179
+ right: Constant,
180
+ ) -> pd.Series:
181
+
182
+ if right.typed_value() is None:
183
+ if op == RelOp.EQ:
184
+ return left.isna()
185
+ assert op == RelOp.NEQ
186
+ return left.notna()
187
+
188
+ # Promote left to float if right is a float to avoid lossy coercion.
189
+ right_value = right.typed_value()
190
+ if pd.api.types.is_integer_dtype(left) and isinstance(
191
+ right_value, float):
192
+ left = left.astype('float64')
193
+ value = pd.Series([right_value], dtype=left.dtype).iloc[0]
194
+
195
+ if op == RelOp.EQ:
196
+ return (left == value).fillna(False).astype(bool)
197
+ if op == RelOp.NEQ:
198
+ out = (left != value).fillna(False).astype(bool)
199
+ out[left.isna()] = False # N/A != right should always be `False`.
200
+ return out
201
+ if op == RelOp.LEQ:
202
+ return (left <= value).fillna(False).astype(bool)
203
+ if op == RelOp.GEQ:
204
+ return (left >= value).fillna(False).astype(bool)
205
+ if op == RelOp.LT:
206
+ return (left < value).fillna(False).astype(bool)
207
+ if op == RelOp.GT:
208
+ return (left > value).fillna(False).astype(bool)
209
+
210
+ raise NotImplementedError(f"Operator '{op}' not implemented")
211
+
212
+ def execute_member_op(
213
+ self,
214
+ left: pd.Series,
215
+ op: MemberOp,
216
+ right: Constant,
217
+ ) -> pd.Series:
218
+
219
+ if op == MemberOp.IN:
220
+ ser = pd.Series(right.typed_value(), dtype=left.dtype)
221
+ return left.isin(ser).astype(bool)
222
+
223
+ raise NotImplementedError(f"Operator '{op}' not implemented")
224
+
225
+ def execute_condition(
226
+ self,
227
+ condition: Condition,
228
+ feat_dict: Dict[str, pd.DataFrame],
229
+ time_dict: Dict[str, pd.Series],
230
+ batch_dict: Dict[str, np.ndarray],
231
+ anchor_time: pd.Series,
232
+ filter_na: bool = True,
233
+ num_forecasts: int = 1,
234
+ ) -> Tuple[pd.Series, np.ndarray]:
235
+ if num_forecasts > 1:
236
+ raise NotImplementedError("Forecasting not yet implemented for "
237
+ "non-regression tasks")
238
+
239
+ assert isinstance(condition.value, Constant)
240
+ value_is_na = condition.value.typed_value() is None
241
+ if isinstance(condition.target, Column):
242
+ left, mask = self.execute_column(
243
+ column=condition.target,
244
+ feat_dict=feat_dict,
245
+ filter_na=filter_na if not value_is_na else False,
246
+ )
247
+ elif isinstance(condition.target, Join):
248
+ left, mask = self.execute_join(
249
+ join=condition.target,
250
+ feat_dict=feat_dict,
251
+ time_dict=time_dict,
252
+ batch_dict=batch_dict,
253
+ anchor_time=anchor_time,
254
+ filter_na=filter_na if not value_is_na else False,
255
+ )
256
+ else:
257
+ assert isinstance(condition.target, Aggregation)
258
+ left, mask = self.execute_aggregation(
259
+ aggr=condition.target,
260
+ feat_dict=feat_dict,
261
+ time_dict=time_dict,
262
+ batch_dict=batch_dict,
263
+ anchor_time=anchor_time,
264
+ filter_na=filter_na if not value_is_na else False,
265
+ )
266
+
267
+ if filter_na and value_is_na:
268
+ mask = np.ones(len(left), dtype=bool)
269
+
270
+ if isinstance(condition.op, RelOp):
271
+ out = self.execute_rel_op(
272
+ left=left,
273
+ op=condition.op,
274
+ right=condition.value,
275
+ )
276
+ else:
277
+ assert isinstance(condition.op, MemberOp)
278
+ out = self.execute_member_op(
279
+ left=left,
280
+ op=condition.op,
281
+ right=condition.value,
282
+ )
283
+
284
+ return out, mask
285
+
286
+ def execute_bool_op(
287
+ self,
288
+ left: pd.Series,
289
+ op: BoolOp,
290
+ right: pd.Series | None,
291
+ ) -> pd.Series:
292
+
293
+ # TODO Implement Kleene-Priest three-value logic.
294
+ if op == BoolOp.AND:
295
+ assert right is not None
296
+ return left & right
297
+ if op == BoolOp.OR:
298
+ assert right is not None
299
+ return left | right
300
+ if op == BoolOp.NOT:
301
+ return ~left
302
+
303
+ raise NotImplementedError(f"Operator '{op}' not implemented")
304
+
305
+ def execute_logical_operation(
306
+ self,
307
+ logical_operation: LogicalOperation,
308
+ feat_dict: Dict[str, pd.DataFrame],
309
+ time_dict: Dict[str, pd.Series],
310
+ batch_dict: Dict[str, np.ndarray],
311
+ anchor_time: pd.Series,
312
+ filter_na: bool = True,
313
+ num_forecasts: int = 1,
314
+ ) -> Tuple[pd.Series, np.ndarray]:
315
+ if num_forecasts > 1:
316
+ raise NotImplementedError("Forecasting not yet implemented for "
317
+ "non-regression tasks")
318
+
319
+ if isinstance(logical_operation.left, Condition):
320
+ left, mask = self.execute_condition(
321
+ condition=logical_operation.left,
322
+ feat_dict=feat_dict,
323
+ time_dict=time_dict,
324
+ batch_dict=batch_dict,
325
+ anchor_time=anchor_time,
326
+ filter_na=False,
327
+ )
328
+ else:
329
+ assert isinstance(logical_operation.left, LogicalOperation)
330
+ left, mask = self.execute_logical_operation(
331
+ logical_operation=logical_operation.left,
332
+ feat_dict=feat_dict,
333
+ time_dict=time_dict,
334
+ batch_dict=batch_dict,
335
+ anchor_time=anchor_time,
336
+ filter_na=False,
337
+ )
338
+
339
+ right = right_mask = None
340
+ if isinstance(logical_operation.right, Condition):
341
+ right, right_mask = self.execute_condition(
342
+ condition=logical_operation.right,
343
+ feat_dict=feat_dict,
344
+ time_dict=time_dict,
345
+ batch_dict=batch_dict,
346
+ anchor_time=anchor_time,
347
+ filter_na=False,
348
+ )
349
+ elif isinstance(logical_operation.right, LogicalOperation):
350
+ right, right_mask = self.execute_logical_operation(
351
+ logical_operation=logical_operation.right,
352
+ feat_dict=feat_dict,
353
+ time_dict=time_dict,
354
+ batch_dict=batch_dict,
355
+ anchor_time=anchor_time,
356
+ filter_na=False,
357
+ )
358
+
359
+ out = self.execute_bool_op(left, logical_operation.bool_op, right)
360
+
361
+ if right_mask is not None:
362
+ mask &= right_mask
363
+
364
+ if filter_na:
365
+ out = out[mask].reset_index(drop=True)
366
+
367
+ return out, mask
368
+
369
+ def execute_join(
370
+ self,
371
+ join: Join,
372
+ feat_dict: Dict[str, pd.DataFrame],
373
+ time_dict: Dict[str, pd.Series],
374
+ batch_dict: Dict[str, np.ndarray],
375
+ anchor_time: pd.Series,
376
+ filter_na: bool = True,
377
+ num_forecasts: int = 1,
378
+ ) -> Tuple[pd.Series, np.ndarray]:
379
+ if isinstance(join.rhs_target, Aggregation):
380
+ return self.execute_aggregation(
381
+ aggr=join.rhs_target,
382
+ feat_dict=feat_dict,
383
+ time_dict=time_dict,
384
+ batch_dict=batch_dict,
385
+ anchor_time=anchor_time,
386
+ filter_na=True,
387
+ num_forecasts=num_forecasts,
388
+ )
389
+ raise NotImplementedError(
390
+ f'Unexpected {type(join.rhs_target)} nested in Join')
391
+
392
+ def execute_filter(
393
+ self,
394
+ filter: Filter,
395
+ feat_dict: Dict[str, pd.DataFrame],
396
+ time_dict: Dict[str, pd.Series],
397
+ batch_dict: Dict[str, np.ndarray],
398
+ anchor_time: pd.Series,
399
+ filter_na: bool = True,
400
+ ) -> Tuple[pd.Series, np.ndarray]:
401
+ out, mask = self.execute_column(
402
+ column=filter.target,
403
+ feat_dict=feat_dict,
404
+ filter_na=False,
405
+ )
406
+ if isinstance(filter.condition, Condition):
407
+ _mask = self.execute_condition(
408
+ condition=filter.condition,
409
+ feat_dict=feat_dict,
410
+ time_dict=time_dict,
411
+ batch_dict=batch_dict,
412
+ anchor_time=anchor_time,
413
+ filter_na=False,
414
+ )[0].to_numpy()
415
+ else:
416
+ assert isinstance(filter.condition, LogicalOperation)
417
+ _mask = self.execute_logical_operation(
418
+ logical_operation=filter.condition,
419
+ feat_dict=feat_dict,
420
+ time_dict=time_dict,
421
+ batch_dict=batch_dict,
422
+ anchor_time=anchor_time,
423
+ filter_na=False,
424
+ )[0].to_numpy()
425
+ if filter_na:
426
+ return out[_mask & mask].reset_index(drop=True), _mask & mask
427
+ else:
428
+ return out[_mask].reset_index(drop=True), mask & _mask
429
+
430
+ def execute(
431
+ self,
432
+ query: ValidatedPredictiveQuery,
433
+ feat_dict: Dict[str, pd.DataFrame],
434
+ time_dict: Dict[str, pd.Series],
435
+ batch_dict: Dict[str, np.ndarray],
436
+ anchor_time: pd.Series,
437
+ num_forecasts: int = 1,
438
+ ) -> Tuple[pd.Series, np.ndarray]:
439
+ if isinstance(query.entity_ast, Column):
440
+ out, mask = self.execute_column(
441
+ column=query.entity_ast,
442
+ feat_dict=feat_dict,
443
+ filter_na=True,
444
+ )
445
+ else:
446
+ assert isinstance(query.entity_ast, Filter)
447
+ out, mask = self.execute_filter(
448
+ filter=query.entity_ast,
449
+ feat_dict=feat_dict,
450
+ time_dict=time_dict,
451
+ batch_dict=batch_dict,
452
+ anchor_time=anchor_time,
453
+ )
454
+ if isinstance(query.target_ast, Column):
455
+ out, _mask = self.execute_column(
456
+ column=query.target_ast,
457
+ feat_dict=feat_dict,
458
+ filter_na=True,
459
+ )
460
+ elif isinstance(query.target_ast, Condition):
461
+ out, _mask = self.execute_condition(
462
+ condition=query.target_ast,
463
+ feat_dict=feat_dict,
464
+ time_dict=time_dict,
465
+ batch_dict=batch_dict,
466
+ anchor_time=anchor_time,
467
+ filter_na=True,
468
+ num_forecasts=num_forecasts,
469
+ )
470
+ elif isinstance(query.target_ast, Aggregation):
471
+ out, _mask = self.execute_aggregation(
472
+ aggr=query.target_ast,
473
+ feat_dict=feat_dict,
474
+ time_dict=time_dict,
475
+ batch_dict=batch_dict,
476
+ anchor_time=anchor_time,
477
+ filter_na=True,
478
+ num_forecasts=num_forecasts,
479
+ )
480
+ elif isinstance(query.target_ast, Join):
481
+ out, _mask = self.execute_join(
482
+ join=query.target_ast,
483
+ feat_dict=feat_dict,
484
+ time_dict=time_dict,
485
+ batch_dict=batch_dict,
486
+ anchor_time=anchor_time,
487
+ filter_na=True,
488
+ num_forecasts=num_forecasts,
489
+ )
490
+ elif isinstance(query.target_ast, LogicalOperation):
491
+ out, _mask = self.execute_logical_operation(
492
+ logical_operation=query.target_ast,
493
+ feat_dict=feat_dict,
494
+ time_dict=time_dict,
495
+ batch_dict=batch_dict,
496
+ anchor_time=anchor_time,
497
+ filter_na=True,
498
+ num_forecasts=num_forecasts,
499
+ )
500
+ else:
501
+ raise NotImplementedError(
502
+ f'{type(query.target)} compilation missing.')
503
+ out = out[mask[_mask]]
504
+ mask &= _mask
505
+ out = out.reset_index(drop=True)
506
+ return out, mask
@@ -199,6 +199,7 @@ class KumoRFM:
199
199
  max_pq_iterations: int = 20,
200
200
  random_seed: Optional[int] = _RANDOM_SEED,
201
201
  verbose: Union[bool, ProgressLogger] = True,
202
+ use_prediction_time: bool = False,
202
203
  ) -> pd.DataFrame:
203
204
  pass
204
205
 
@@ -217,6 +218,7 @@ class KumoRFM:
217
218
  max_pq_iterations: int = 20,
218
219
  random_seed: Optional[int] = _RANDOM_SEED,
219
220
  verbose: Union[bool, ProgressLogger] = True,
221
+ use_prediction_time: bool = False,
220
222
  ) -> Explanation:
221
223
  pass
222
224
 
@@ -234,6 +236,7 @@ class KumoRFM:
234
236
  max_pq_iterations: int = 20,
235
237
  random_seed: Optional[int] = _RANDOM_SEED,
236
238
  verbose: Union[bool, ProgressLogger] = True,
239
+ use_prediction_time: bool = False,
237
240
  ) -> Union[pd.DataFrame, Explanation]:
238
241
  """Returns predictions for a predictive query.
239
242
 
@@ -264,6 +267,9 @@ class KumoRFM:
264
267
  entities to find valid labels.
265
268
  random_seed: A manual seed for generating pseudo-random numbers.
266
269
  verbose: Whether to print verbose output.
270
+ use_prediction_time: Whether to use the anchor timestamp as an
271
+ additional feature during prediction. This is typically
272
+ beneficial for time series forecasting tasks.
267
273
 
268
274
  Returns:
269
275
  The predictions as a :class:`pandas.DataFrame`.
@@ -353,6 +359,7 @@ class KumoRFM:
353
359
  request = RFMPredictRequest(
354
360
  context=context,
355
361
  run_mode=RunMode(run_mode),
362
+ use_prediction_time=use_prediction_time,
356
363
  )
357
364
  with warnings.catch_warnings():
358
365
  warnings.filterwarnings('ignore', message='gencode')
@@ -503,6 +510,7 @@ class KumoRFM:
503
510
  max_pq_iterations: int = 20,
504
511
  random_seed: Optional[int] = _RANDOM_SEED,
505
512
  verbose: Union[bool, ProgressLogger] = True,
513
+ use_prediction_time: bool = False,
506
514
  ) -> pd.DataFrame:
507
515
  """Evaluates a predictive query.
508
516
 
@@ -526,6 +534,9 @@ class KumoRFM:
526
534
  entities to find valid labels.
527
535
  random_seed: A manual seed for generating pseudo-random numbers.
528
536
  verbose: Whether to print verbose output.
537
+ use_prediction_time: Whether to use the anchor timestamp as an
538
+ additional feature during prediction. This is typically
539
+ beneficial for time series forecasting tasks.
529
540
 
530
541
  Returns:
531
542
  The metrics as a :class:`pandas.DataFrame`
@@ -569,6 +580,7 @@ class KumoRFM:
569
580
  context=context,
570
581
  run_mode=RunMode(run_mode),
571
582
  metrics=metrics,
583
+ use_prediction_time=use_prediction_time,
572
584
  )
573
585
  with warnings.catch_warnings():
574
586
  warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -1006,7 +1018,7 @@ class KumoRFM:
1006
1018
  elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
1007
1019
  supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
1008
1020
  elif task_type == TaskType.REGRESSION:
1009
- supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
1021
+ supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
1010
1022
  elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
1011
1023
  supported_metrics = [
1012
1024
  'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
kumoai/trainer/trainer.py CHANGED
@@ -20,7 +20,6 @@ from kumoapi.jobs import (
20
20
  TrainingJobResource,
21
21
  )
22
22
  from kumoapi.model_plan import ModelPlan
23
- from kumoapi.task import TaskType
24
23
 
25
24
  from kumoai import global_state
26
25
  from kumoai.artifact_export.config import OutputConfig
@@ -405,15 +404,15 @@ class Trainer:
405
404
  pred_table_data_path = prediction_table.table_data_uri
406
405
 
407
406
  api = global_state.client.batch_prediction_job_api
408
-
409
- from kumoai.pquery.predictive_query import PredictiveQuery
410
- pquery = PredictiveQuery.load_from_training_job(training_job_id)
411
- if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
412
- if binary_classification_threshold is None:
413
- logger.warning("No binary classification threshold provided. "
414
- "Using default threshold of 0.5.")
415
- binary_classification_threshold = 0.5
416
-
407
+ # Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
408
+ # from kumoai.pquery.predictive_query import PredictiveQuery
409
+ # pquery = PredictiveQuery.load_from_training_job(training_job_id)
410
+ # if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
411
+ # if binary_classification_threshold is None:
412
+ # logger.warning(
413
+ # "No binary classification threshold provided. "
414
+ # "Using default threshold of 0.5.")
415
+ # binary_classification_threshold = 0.5
417
416
  job_id, response = api.maybe_create(
418
417
  BatchPredictionRequest(
419
418
  dict(custom_tags),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.10.1
3
+ Version: 2.12.0.dev202511031731
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.38.0
26
+ Requires-Dist: kumo-api==0.40.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21
@@ -1,7 +1,7 @@
1
1
  kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
2
2
  kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
3
- kumoai/_version.py,sha256=fy1qvJHPd7FWOAuuVNKh9cFE7RrqWon8x59x3EjYTCc,23
4
- kumoai/__init__.py,sha256=x3DjDsWBgWSNwo7mDwb3XAoRm2NuSO09yvhQTL9tBT8,10673
3
+ kumoai/_version.py,sha256=X5C9cHVsjznMq0N29k8V18IjmrXq8NyKWG7IEMkjaBc,39
4
+ kumoai/__init__.py,sha256=LU1zmKYc0KV5hy2VGKUuXgSvbJwj2rSRQ_R_bpHyl1o,10708
5
5
  kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
6
6
  kumoai/futures.py,sha256=oJFIfdCM_3nWIqQteBKYMY4fPhoYlYWE_JA2o6tx-ng,3737
7
7
  kumoai/kumolib.cpython-311-darwin.so,sha256=AmB_Fysmud1y7Gm5CuBQ5lWDuSzpxVDV_iTA2cjH1s8,232544
@@ -12,17 +12,19 @@ kumoai/spcs.py,sha256=N4ddeoHAc4I3bKrDitsb91lUx5VKvCyPyMT3zWiuCcY,4275
12
12
  kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
13
13
  kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  kumoai/experimental/rfm/local_graph_sampler.py,sha256=o60_sdMa_fr60DrdmCIaE6lKQAD2msp1t-GGubFNt-o,6738
15
- kumoai/experimental/rfm/local_graph.py,sha256=2LTllKKnjkThM7lr6jg_miypWN_oLC3YmIcZHLkAa4U,30076
15
+ kumoai/experimental/rfm/local_graph.py,sha256=2iJDlsGVzqCe1bD_puXWlhwGkn7YnQyJ4p4C-fwCZNE,30076
16
16
  kumoai/experimental/rfm/local_pquery_driver.py,sha256=xqAQ9fJfkqM1axknFpg0NLQbIYmExh-s7vGdUyDEkwA,18600
17
17
  kumoai/experimental/rfm/__init__.py,sha256=F1aUOCLDN2yrIRDAiOlogDfXKUkUQgp8Mt0pVX9rLX8,1641
18
18
  kumoai/experimental/rfm/utils.py,sha256=3IiBvT_aLBkkcJh3H11_50yt_XlEzHR0cm9Kprrtl8k,11123
19
19
  kumoai/experimental/rfm/local_table.py,sha256=r8xZ33Mjs6JD8ud6h23tZ99Dag2DvZ4h6tWjmGrKQg4,19605
20
- kumoai/experimental/rfm/rfm.py,sha256=cKHwIinGmnRg1_QiDpzhC_ZnNTWhvX3tA4ZhXgjlNEU,45513
20
+ kumoai/experimental/rfm/rfm.py,sha256=BcC0EqXfz2OhMT-g8gBGv7M6yUTboj-PyGWIQZPUf70,46227
21
21
  kumoai/experimental/rfm/local_graph_store.py,sha256=8BqonuaMftAAsjgZpB369i5AeNd1PkisMbbEqc0cKBo,13847
22
22
  kumoai/experimental/rfm/authenticate.py,sha256=FiuHMvP7V3zBZUlHMDMbNLhc-UgDZgz4hjVSTuQ7DRw,18888
23
23
  kumoai/experimental/rfm/pquery/backend.py,sha256=6wtB0yFpxQUraBSA2TbKMVSIMD0dcLwYV5P4SQx2g_k,3287
24
- kumoai/experimental/rfm/pquery/__init__.py,sha256=bsNcdn7DnPw9kpSQ_bQVmQX1RmXzPQhzfA1y6G-n7I8,146
24
+ kumoai/experimental/rfm/pquery/__init__.py,sha256=9uLXixjp78y0IzO2F__lFqKNm37OGhN3iDh56akWLNU,283
25
25
  kumoai/experimental/rfm/pquery/pandas_backend.py,sha256=pgHCErSo6U-KJMhgIYijYt96uubtFB2WtsrTdLU7NYc,15396
26
+ kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=BgF3saosisgLHx1RyLj-HSEbMp4xLatNuARdKWwiiLY,17326
27
+ kumoai/experimental/rfm/pquery/executor.py,sha256=f7-pJhL0BgFU9E4o4gQpQyArOvyrZtwxFmks34-QOAE,2741
26
28
  kumoai/experimental/rfm/infer/multicategorical.py,sha256=0-cLpDnGryhr76QhZNO-klKokJ6MUSfxXcGdQ61oykY,1102
27
29
  kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
28
30
  kumoai/experimental/rfm/infer/id.py,sha256=ZIO0DWIoiEoS_8MVc5lkqBfkTWWQ0yGCgjkwLdaYa_Q,908
@@ -90,9 +92,9 @@ kumoai/trainer/job.py,sha256=Wk69nzFhbvuA3nEvtCstI04z5CxkgvQ6tHnGchE0Lkg,44938
90
92
  kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06ogtEVc,4024
91
93
  kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
92
94
  kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
93
- kumoai/trainer/trainer.py,sha256=nPeZMMp17TtRFd4lKbF-TlMPnhYR4_VyPDPI0T9W9PU,20094
94
- kumoai-2.10.1.dist-info/RECORD,,
95
- kumoai-2.10.1.dist-info/WHEEL,sha256=sunMa2yiYbrNLGeMVDqEA0ayyJbHlex7SCn1TZrEq60,136
96
- kumoai-2.10.1.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
97
- kumoai-2.10.1.dist-info/METADATA,sha256=6XcvFFVYccShSlMdWchiyehRoiG93v2gU18n7D6pwD4,2036
98
- kumoai-2.10.1.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
95
+ kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
96
+ kumoai-2.12.0.dev202511031731.dist-info/RECORD,,
97
+ kumoai-2.12.0.dev202511031731.dist-info/WHEEL,sha256=sunMa2yiYbrNLGeMVDqEA0ayyJbHlex7SCn1TZrEq60,136
98
+ kumoai-2.12.0.dev202511031731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
99
+ kumoai-2.12.0.dev202511031731.dist-info/METADATA,sha256=yf8LuBryiRverUZLTN389Y_94sZrrzNnITX6sDAlfy0,2052
100
+ kumoai-2.12.0.dev202511031731.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102