kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic.

Files changed (122)
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/pquery/pandas_executor.py
@@ -0,0 +1,530 @@
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Constant,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
+
+from kumoai.experimental.rfm.pquery import PQueryExecutor
+
+
+class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
+                                          np.ndarray]):
+    def execute_column(
+        self,
+        column: Column,
+        feat_dict: dict[str, pd.DataFrame],
+        filter_na: bool = True,
+    ) -> tuple[pd.Series, np.ndarray]:
+        table_name, column_name = column.fqn.split(".")
+        if column_name == '*':
+            out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
+        else:
+            out = feat_dict[table_name][column_name]
+            out = out.reset_index(drop=True)
+
+        if pd.api.types.is_float_dtype(out):
+            out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        mask = out.notna().to_numpy()
+
+        if not filter_na:
+            return out, mask
+
+        out = out[mask].reset_index(drop=True)
+
+        # Cast to primitive dtype:
+        if pd.api.types.is_integer_dtype(out):
+            out = out.astype('int64')
+        elif pd.api.types.is_bool_dtype(out):
+            out = out.astype('bool')
+
+        return out, mask
+
+    def execute_aggregation_type(
+        self,
+        op: AggregationType,
+        feat: pd.Series,
+        batch: np.ndarray,
+        batch_size: int,
+        filter_na: bool = True,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        mask = feat.notna()
+        feat, batch = feat[mask], batch[mask]
+
+        if op == AggregationType.LIST_DISTINCT:
+            df = pd.DataFrame(dict(feat=feat, batch=batch))
+            df = df.drop_duplicates()
+            out = df.groupby('batch')['feat'].agg(list)
+
+        else:
+            df = pd.DataFrame(dict(feat=feat, batch=batch))
+            if op == AggregationType.AVG:
+                agg = 'mean'
+            elif op == AggregationType.COUNT:
+                agg = 'size'
+            else:
+                agg = op.lower()
+            out = df.groupby('batch')['feat'].agg(agg)
+
+            if not pd.api.types.is_datetime64_any_dtype(out):
+                out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        if op in {AggregationType.SUM, AggregationType.COUNT}:
+            out = out.reindex(range(batch_size), fill_value=0)
+            mask = np.ones(batch_size, dtype=bool)
+            return out, mask
+
+        mask = np.zeros(batch_size, dtype=bool)
+        mask[batch] = True
+
+        if filter_na:
+            return out.reset_index(drop=True), mask
+
+        out = out.reindex(range(batch_size), fill_value=pd.NA)
+
+        return out, mask
+
+    def execute_aggregation(
+        self,
+        aggr: Aggregation,
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> tuple[pd.Series, np.ndarray]:
+        target_table = aggr._get_target_column_name().split('.')[0]
+        target_batch = batch_dict[target_table]
+        target_time = time_dict[target_table]
+        if isinstance(aggr.target, Column):
+            target_feat, target_mask = self.execute_column(
+                column=aggr.target,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(aggr.target, Filter)
+            target_feat, target_mask = self.execute_filter(
+                filter=aggr.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+            )
+
+        outs: list[pd.Series] = []
+        masks: list[np.ndarray] = []
+        for _ in range(num_forecasts):
+            anchor_target_time = anchor_time.iloc[target_batch]
+            anchor_target_time = anchor_target_time.reset_index(drop=True)
+
+            time_filter_mask = (target_time <= anchor_target_time +
+                                aggr.aggr_time_range.end_date_offset)
+            if aggr.aggr_time_range.start is not None:
+                start_offset = aggr.aggr_time_range.start_date_offset
+                time_filter_mask &= (target_time
+                                     > anchor_target_time + start_offset)
+            else:
+                assert num_forecasts == 1
+            curr_target_mask = target_mask & time_filter_mask
+
+            out, mask = self.execute_aggregation_type(
+                aggr.aggr,
+                feat=target_feat[time_filter_mask[target_mask].reset_index(
+                    drop=True)],
+                batch=target_batch[curr_target_mask],
+                batch_size=len(anchor_time),
+                filter_na=False if num_forecasts > 1 else filter_na,
+            )
+            outs.append(out)
+            masks.append(mask)
+
+            if num_forecasts > 1:
+                anchor_time = (anchor_time +
+                               aggr.aggr_time_range.end_date_offset)
+        if len(outs) == 1:
+            assert len(masks) == 1
+            return outs[0], masks[0]
+
+        out = pd.Series([list(ser) for ser in zip(*outs)])
+        mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask
+
+    def execute_rel_op(
+        self,
+        left: pd.Series,
+        op: RelOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if right.typed_value() is None:
+            if op == RelOp.EQ:
+                return left.isna()
+            assert op == RelOp.NEQ
+            return left.notna()
+
+        # Promote left to float if right is a float to avoid lossy coercion.
+        right_value = right.typed_value()
+        if pd.api.types.is_integer_dtype(left) and isinstance(
+                right_value, float):
+            left = left.astype('float64')
+        value = pd.Series([right_value], dtype=left.dtype).iloc[0]
+
+        if op == RelOp.EQ:
+            return (left == value).fillna(False).astype(bool)
+        if op == RelOp.NEQ:
+            out = (left != value).fillna(False).astype(bool)
+            out[left.isna()] = False  # N/A != right should always be `False`.
+            return out
+        if op == RelOp.LEQ:
+            return (left <= value).fillna(False).astype(bool)
+        if op == RelOp.GEQ:
+            return (left >= value).fillna(False).astype(bool)
+        if op == RelOp.LT:
+            return (left < value).fillna(False).astype(bool)
+        if op == RelOp.GT:
+            return (left > value).fillna(False).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_member_op(
+        self,
+        left: pd.Series,
+        op: MemberOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if op == MemberOp.IN:
+            ser = pd.Series(right.typed_value(), dtype=left.dtype)
+            return left.isin(ser).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_condition(
+        self,
+        condition: Condition,
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> tuple[pd.Series, np.ndarray]:
+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
+        assert isinstance(condition.value, Constant)
+        value_is_na = condition.value.typed_value() is None
+        if isinstance(condition.target, Column):
+            left, mask = self.execute_column(
+                column=condition.target,
+                feat_dict=feat_dict,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        elif isinstance(condition.target, Join):
+            left, mask = self.execute_join(
+                join=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        else:
+            assert isinstance(condition.target, Aggregation)
+            left, mask = self.execute_aggregation(
+                aggr=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
+            )
+
+        if filter_na and value_is_na:
+            mask = np.ones(len(left), dtype=bool)
+
+        if isinstance(condition.op, RelOp):
+            out = self.execute_rel_op(
+                left=left,
+                op=condition.op,
+                right=condition.value,
+            )
+        else:
+            assert isinstance(condition.op, MemberOp)
+            out = self.execute_member_op(
+                left=left,
+                op=condition.op,
+                right=condition.value,
+            )
+
+        return out, mask
+
+    def execute_bool_op(
+        self,
+        left: pd.Series,
+        op: BoolOp,
+        right: pd.Series | None,
+    ) -> pd.Series:
+
+        # TODO Implement Kleene-Priest three-value logic.
+        if op == BoolOp.AND:
+            assert right is not None
+            return left & right
+        if op == BoolOp.OR:
+            assert right is not None
+            return left | right
+        if op == BoolOp.NOT:
+            return ~left
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_logical_operation(
+        self,
+        logical_operation: LogicalOperation,
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> tuple[pd.Series, np.ndarray]:
+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
+        if isinstance(logical_operation.left, Condition):
+            left, mask = self.execute_condition(
+                condition=logical_operation.left,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+        else:
+            assert isinstance(logical_operation.left, LogicalOperation)
+            left, mask = self.execute_logical_operation(
+                logical_operation=logical_operation.left,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        right = right_mask = None
+        if isinstance(logical_operation.right, Condition):
+            right, right_mask = self.execute_condition(
+                condition=logical_operation.right,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+        elif isinstance(logical_operation.right, LogicalOperation):
+            right, right_mask = self.execute_logical_operation(
+                logical_operation=logical_operation.right,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        out = self.execute_bool_op(left, logical_operation.bool_op, right)
+
+        if right_mask is not None:
+            mask &= right_mask
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask
+
+    def execute_join(
+        self,
+        join: Join,
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> tuple[pd.Series, np.ndarray]:
+        if isinstance(join.rhs_target, Aggregation):
+            return self.execute_aggregation(
+                aggr=join.rhs_target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        raise NotImplementedError(
+            f'Unexpected {type(join.rhs_target)} nested in Join')
+
+    def execute_filter(
+        self,
+        filter: Filter,
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+    ) -> tuple[pd.Series, np.ndarray]:
+        out, mask = self.execute_column(
+            column=filter.target,
+            feat_dict=feat_dict,
+            filter_na=False,
+        )
+        if isinstance(filter.condition, Condition):
+            _mask = self.execute_condition(
+                condition=filter.condition,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )[0].to_numpy()
+        else:
+            assert isinstance(filter.condition, LogicalOperation)
+            _mask = self.execute_logical_operation(
+                logical_operation=filter.condition,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )[0].to_numpy()
+        if filter_na:
+            return out[_mask & mask].reset_index(drop=True), _mask & mask
+        else:
+            return out[_mask].reset_index(drop=True), mask & _mask
+
+    def execute(
+        self,
+        query: ValidatedPredictiveQuery,
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        num_forecasts: int = 1,
+    ) -> tuple[pd.Series, np.ndarray]:
+        if isinstance(query.entity_ast, Column):
+            out, mask = self.execute_column(
+                column=query.entity_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(query.entity_ast, Filter)
+            out, mask = self.execute_filter(
+                filter=query.entity_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+            )
+        if isinstance(query.target_ast, Column):
+            out, _mask = self.execute_column(
+                column=query.target_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        elif isinstance(query.target_ast, Condition):
+            out, _mask = self.execute_condition(
+                condition=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, Aggregation):
+            out, _mask = self.execute_aggregation(
+                aggr=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, Join):
+            out, _mask = self.execute_join(
+                join=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, LogicalOperation):
+            out, _mask = self.execute_logical_operation(
+                logical_operation=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        else:
+            raise NotImplementedError(
+                f'{type(query.target_ast)} compilation missing.')
+        if query.whatif_ast is not None:
+            if isinstance(query.whatif_ast, Condition):
+                mask &= self.execute_condition(
+                    condition=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            elif isinstance(query.whatif_ast, LogicalOperation):
+                mask &= self.execute_logical_operation(
+                    logical_operation=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            else:
+                raise ValueError(
+                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
+
+        out = out[mask[_mask]]
+        mask &= _mask
+        out = out.reset_index(drop=True)
+        return out, mask
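
PQueryPandasExecutor evaluates a ValidatedPredictiveQuery AST over in-memory pandas data: feat_dict holds one DataFrame per table, batch_dict maps each row to its anchor example, and aggregations are grouped per anchor. The snippet below is an illustrative sketch, not part of the package; it assumes PQueryPandasExecutor can be constructed without arguments and exercises only the self-contained aggregation helper.

    import numpy as np
    import pandas as pd
    from kumoapi.typing import AggregationType

    from kumoai.experimental.rfm.pquery.pandas_executor import PQueryPandasExecutor

    executor = PQueryPandasExecutor()  # assumes a no-argument constructor

    # One value per fact row; `batch` assigns each row to an anchor example.
    feat = pd.Series([10.0, 5.0, None, 2.0])
    batch = np.array([0, 0, 1, 2])

    out, mask = executor.execute_aggregation_type(
        op=AggregationType.SUM,
        feat=feat,
        batch=batch,
        batch_size=4,
    )
    # NA rows are dropped before grouping; SUM/COUNT results are reindexed to
    # `batch_size` entries with fill value 0, so `mask` is all True and `out`
    # holds [15.0, 0.0, 2.0, 0.0] as float32.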
kumoai/experimental/rfm/relbench.py
@@ -0,0 +1,76 @@
+import difflib
+import json
+from functools import lru_cache
+from urllib.request import urlopen
+
+import pooch
+import pyarrow as pa
+
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.backend.local import LocalTable
+
+PREFIX = 'rel-'
+CACHE_DIR = pooch.os_cache('relbench')
+HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
+            'relbench/datasets/hashes.json')
+
+
+@lru_cache
+def get_registry() -> pooch.Pooch:
+    with urlopen(HASH_URL) as r:
+        hashes = json.load(r)
+
+    return pooch.create(
+        path=CACHE_DIR,
+        base_url='https://relbench.stanford.edu/download/',
+        registry=hashes,
+    )
+
+
+def from_relbench(dataset: str, verbose: bool = True) -> Graph:
+    dataset = dataset.lower()
+    if dataset.startswith(PREFIX):
+        dataset = dataset[len(PREFIX):]
+
+    registry = get_registry()
+
+    datasets = [key.split('/')[0][len(PREFIX):] for key in registry.registry]
+    if dataset not in datasets:
+        matches = difflib.get_close_matches(dataset, datasets, n=1)
+        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
+        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
+                         f"datasets are {str(datasets)[1:-1]}.")
+
+    registry.fetch(
+        f'{PREFIX}{dataset}/db.zip',
+        processor=pooch.Unzip(extract_dir='.'),
+        progressbar=verbose,
+    )
+
+    graph = Graph(tables=[])
+    edges: list[tuple[str, str, str]] = []
+    for path in (CACHE_DIR / f'{PREFIX}{dataset}' / 'db').glob('*.parquet'):
+        data = pa.parquet.read_table(path)
+        metadata = {
+            key.decode('utf-8'): json.loads(value.decode('utf-8'))
+            for key, value in data.schema.metadata.items()
+            if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
+        }
+
+        table = LocalTable(
+            df=data.to_pandas(),
+            name=path.stem,
+            primary_key=metadata['pkey_col'],
+            time_column=metadata['time_col'],
+        )
+        graph.add_table(table)
+
+        edges.extend([
+            (path.stem, fkey, dst_table)
+            for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
+        ])
+
+    for edge in edges:
+        graph.link(*edge)
+
+    return graph
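
from_relbench fetches a RelBench database archive into the local pooch cache, turns each parquet file into a LocalTable (using the pkey_col and time_col recorded in the parquet metadata), and links foreign keys via graph.link. A minimal usage sketch, assuming network access and that 'rel-f1' is one of the published RelBench datasets:

    from kumoai.experimental.rfm.relbench import from_relbench

    # The 'rel-' prefix is optional and lookup is case-insensitive; unknown
    # names raise a ValueError with a closest-match hint.
    graph = from_relbench('rel-f1', verbose=True)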