kumoai 2.9.0.dev202509061830__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/client.py +10 -5
- kumoai/client/rfm.py +3 -2
- kumoai/connector/file_upload_connector.py +71 -102
- kumoai/connector/utils.py +1367 -236
- kumoai/experimental/rfm/__init__.py +2 -2
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +90 -80
- kumoai/experimental/rfm/local_graph_sampler.py +16 -8
- kumoai/experimental/rfm/local_graph_store.py +22 -6
- kumoai/experimental/rfm/local_pquery_driver.py +129 -28
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -0
- kumoai/experimental/rfm/pquery/backend.py +4 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_backend.py +71 -30
- kumoai/experimental/rfm/pquery/pandas_executor.py +506 -0
- kumoai/experimental/rfm/rfm.py +442 -94
- kumoai/jobs.py +1 -0
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/progress_logger.py +62 -0
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/METADATA +4 -5
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/RECORD +28 -26
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/WHEEL +0 -0
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/local_pquery_driver.py

@@ -18,7 +18,7 @@ class LocalPQueryDriver:
         self,
         graph_store: LocalGraphStore,
         query: PQueryDefinition,
-        random_seed: Optional[int],
+        random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
         self._query = query
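Note: the seed is now optional. A minimal construction sketch, assuming `graph_store` (a LocalGraphStore) and `query` (a PQueryDefinition) already exist:

    driver = LocalPQueryDriver(graph_store, query)  # random_seed defaults to None
    seeded_driver = LocalPQueryDriver(graph_store, query, random_seed=42)  # explicit seed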
@@ -27,7 +27,6 @@ class LocalPQueryDriver:
 
     def _get_candidates(
         self,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
         exclude_node: Optional[np.ndarray] = None,
     ) -> np.ndarray:
 
@@ -61,12 +60,37 @@ class LocalPQueryDriver:
 
         return candidate
 
+    def _filter_candidates_by_time(
+        self,
+        candidate: np.ndarray,
+        anchor_time: pd.Timestamp,
+    ) -> np.ndarray:
+
+        entity = self._query.entity.pkey.table_name
+
+        # Filter out entities that do not exist yet in time:
+        time_sec = self._graph_store.time_dict.get(entity)
+        if time_sec is not None:
+            mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
+            candidate = candidate[mask]
+
+        # Filter out entities that no longer exist in time:
+        end_time_col = self._graph_store.end_time_column_dict.get(entity)
+        if end_time_col is not None:
+            ser = self._graph_store.df_dict[entity][end_time_col]
+            ser = ser.iloc[candidate]
+            mask = (anchor_time < ser) | ser.isna().to_numpy()
+            candidate = candidate[mask]
+
+        return candidate
+
     def collect_test(
         self,
         size: int,
         anchor_time: Union[pd.Timestamp, Literal['entity']],
         batch_size: Optional[int] = None,
         max_iterations: int = 20,
+        guarantee_train_examples: bool = True,
     ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
         r"""Collects test nodes and their labels used for evaluation.
 
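Note: entity candidates are now filtered by a dedicated helper that drops entities that do not exist yet at the anchor time and, when an end time column is present, entities that no longer exist. A standalone sketch of the same masking logic on made-up data (not the shipped helper):

    import numpy as np
    import pandas as pd

    anchor_time = pd.Timestamp('2024-06-01')
    anchor_sec = anchor_time.value // (1000**3)          # nanoseconds -> seconds
    time_sec = np.array([1_600_000_000, 1_800_000_000])  # per-entity creation time (s)
    candidate = np.array([0, 1])

    # Drop entities created after the anchor time:
    candidate = candidate[time_sec[candidate] <= anchor_sec]

    # Drop entities whose end time lies before the anchor time (NaT = still alive):
    end_time = pd.Series(pd.to_datetime(['2025-01-01', None]))
    ser = end_time.iloc[candidate]
    candidate = candidate[(anchor_time < ser).to_numpy() | ser.isna().to_numpy()]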
@@ -75,13 +99,15 @@ class LocalPQueryDriver:
             anchor_time: The anchor time.
             batch_size: How many nodes to process in a single batch.
             max_iterations: The number of steps to run before aborting.
+            guarantee_train_examples: Ensures that test examples do not occupy
+                the entire set of entity candidates.
 
         Returns:
             A triplet holding the nodes, timestamps and labels.
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(anchor_time)
+        candidate = self._get_candidates()
 
         nodes: List[np.ndarray] = []
         times: List[pd.Series] = []
@@ -93,13 +119,7 @@ class LocalPQueryDriver:
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
@@ -148,6 +168,16 @@ class LocalPQueryDriver:
                     f"using the 'max_pq_iterations' option. This "
                     f"warning will not be shown again in this run.")
 
+        if (guarantee_train_examples
+                and self._query.query_type == QueryType.STATIC
+                and candidate_offset >= len(candidate)):
+            # In case all valid entities are used as test examples, we can no
+            # longer find any training example. Fallback to a 50/50 split:
+            size = len(node) // 2
+            node = node[:size]
+            time = time.iloc[:size]
+            y = y.iloc[:size]
+
         return node, time, y
 
     def collect_train(
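Note: for static queries the new flag keeps the test set from consuming every valid entity. A tiny illustration of the 50/50 fallback on made-up data:

    import numpy as np

    node = np.arange(10)     # every remaining valid entity ended up as a test example
    size = len(node) // 2    # keep only half of them for testing ...
    test_node = node[:size]  # ... so the other half stays available as training context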
@@ -172,7 +202,7 @@ class LocalPQueryDriver:
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(anchor_time, exclude_node)
+        candidate = self._get_candidates(exclude_node)
 
         if len(candidate) == 0:
             raise RuntimeError("Failed to generate any context examples "
@@ -182,22 +212,13 @@ class LocalPQueryDriver:
         times: List[pd.Series] = []
         ys: List[pd.Series] = []
 
-        if isinstance(anchor_time, pd.Timestamp):
-            anchor_time = anchor_time - self._query.target.end_offset
-
         reached_end = False
         num_labels = candidate_offset = 0
         for _ in range(max_iterations):
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
@@ -228,7 +249,8 @@ class LocalPQueryDriver:
                     reached_end = True
                     break
                 candidate_offset = 0
-                anchor_time = anchor_time - self._query.target.end_offset
+                anchor_time = anchor_time - (self._query.target.end_offset *
+                                             self._query.num_forecasts)
                 if anchor_time < self._graph_store.min_time:
                     reached_end = True
                     break  # No earlier anchor time left. Abort.
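Note: when one anchor time is exhausted, the driver now steps back by the full forecast horizon instead of a single target window. A worked example, assuming a hypothetical 7-day end offset and `num_forecasts = 4`:

    import pandas as pd

    end_offset = pd.DateOffset(days=7)
    num_forecasts = 4
    anchor_time = pd.Timestamp('2024-06-29')

    # Previously the step was one window (7 days); it is now the whole
    # horizon (4 * 7 = 28 days):
    anchor_time = anchor_time - (end_offset * num_forecasts)
    print(anchor_time)  # 2024-06-01 00:00:00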
@@ -257,12 +279,81 @@ class LocalPQueryDriver:
 
         return node, time, y
 
-    def __call__(
+    def is_valid(
+        self,
+        node: np.ndarray,
+        anchor_time: Union[pd.Timestamp, Literal['entity']],
+        batch_size: int = 10_000,
+    ) -> np.ndarray:
+        r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
+        which nodes fulfill entity filter constraints.
+
+        Args:
+            node: The nodes to check for.
+            anchor_time: The anchor time.
+            batch_size: How many nodes to process in a single batch.
+
+        Returns:
+            The mask.
+        """
+        mask: Optional[np.ndarray] = None
+
+        if isinstance(anchor_time, pd.Timestamp):
+            node = self._filter_candidates_by_time(node, anchor_time)
+            time = pd.Series(anchor_time).repeat(len(node))
+            time = time.astype('datetime64[ns]').reset_index(drop=True)
+        else:
+            assert anchor_time == 'entity'
+            time = self._graph_store.time_dict[
+                self._query.entity.pkey.table_name]
+            time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
+
+        if self._query.entity.filter is not None:
+            # Mask out via (temporal) entity filter:
+            backend = PQueryPandasBackend()
+            masks: List[np.ndarray] = []
+            for start in range(0, len(node), batch_size):
+                feat_dict, time_dict, batch_dict = self._sample(
+                    node[start:start + batch_size],
+                    time.iloc[start:start + batch_size],
+                )
+                _mask = backend.eval_filter(
+                    filter=self._query.entity.filter,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=time.iloc[start:start + batch_size],
+                )
+                masks.append(_mask)
+
+            _mask = np.concatenate(masks)
+            mask = (mask & _mask) if mask is not None else _mask
+
+        if mask is None:
+            mask = np.ones(len(node), dtype=bool)
+
+        return mask
+
+    def _sample(
         self,
         node: np.ndarray,
         anchor_time: pd.Series,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> Tuple[
+        Dict[str, pd.DataFrame],
+        Dict[str, pd.Series],
+        Dict[str, np.ndarray],
+    ]:
+        r"""Samples a subgraph that contains all relevant information to
+        evaluate the predictive query.
+
+        Args:
+            node: The nodes to check for.
+            anchor_time: The anchor time.
 
+        Returns:
+            The feature dictionary, the time column dictionary and the batch
+            dictionary.
+        """
         specs = self._query.get_sampling_specs(self._graph_store.edge_types)
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
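Note: `is_valid` combines the temporal existence check with the query's (temporal) entity filter and returns a boolean validity mask. A hypothetical call, reusing the `driver` from the sketch above:

    import numpy as np
    import pandas as pd

    node = np.array([0, 1, 2, 3])
    mask = driver.is_valid(node, anchor_time=pd.Timestamp('2024-06-01'))
    # mask[i] is True for nodes that fulfill the entity filter constraints
    # at the given anchor time.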
@@ -275,11 +366,10 @@ class LocalPQueryDriver:
             if spec.edge_type not in time_offsets:
                 time_offsets[spec.edge_type] = [[0, 0]
                                                 for _ in range(num_hops)]
-            offset: Optional[int] =
-                spec.end_offset)
+            offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
             time_offsets[spec.edge_type][spec.hop - 1][1] = offset
             if spec.start_offset is not None:
-                offset =
+                offset = date_offset_to_seconds(spec.start_offset)
             else:
                 offset = None
             time_offsets[spec.edge_type][spec.hop - 1][0] = offset
@@ -341,18 +431,29 @@ class LocalPQueryDriver:
             time_col = self._graph_store.time_column_dict[table_name]
             time_dict[table_name] = df[time_col]
 
+        return feat_dict, time_dict, batch_dict
+
+    def __call__(
+        self,
+        node: np.ndarray,
+        anchor_time: pd.Series,
+    ) -> Tuple[pd.Series, np.ndarray]:
+
+        feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
+
         y, mask = PQueryPandasBackend().eval_pquery(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
             batch_dict=batch_dict,
             anchor_time=anchor_time,
+            num_forecasts=self._query.num_forecasts,
         )
 
         return y, mask
 
 
-def
+def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
     nanoseconds.
 
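Note: label computation (`__call__`) now forwards `num_forecasts` to the pandas backend. A hypothetical invocation, reusing `driver` and `node` from the sketches above:

    import pandas as pd

    anchor = pd.Series([pd.Timestamp('2024-06-01')] * len(node),
                       dtype='datetime64[ns]')
    y, mask = driver(node, anchor)  # per-node labels plus a validity mask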
kumoai/experimental/rfm/local_table.py

@@ -23,11 +23,13 @@ class Column:
         stype: Stype,
         is_primary_key: bool = False,
         is_time_column: bool = False,
+        is_end_time_column: bool = False,
     ) -> None:
         self._name = name
         self._dtype = Dtype(dtype)
         self._is_primary_key = is_primary_key
         self._is_time_column = is_time_column
+        self._is_end_time_column = is_end_time_column
         self.stype = Stype(stype)
 
     @property
@@ -50,9 +52,12 @@ class Column:
         if self._is_primary_key and val != Stype.ID:
             raise ValueError(f"Primary key '{self.name}' must have 'ID' "
                              f"semantic type (got '{val}')")
-        if self.
+        if self._is_time_column and val != Stype.timestamp:
             raise ValueError(f"Time column '{self.name}' must have "
                              f"'timestamp' semantic type (got '{val}')")
+        if self._is_end_time_column and val != Stype.timestamp:
+            raise ValueError(f"End time column '{self.name}' must have "
+                             f"'timestamp' semantic type (got '{val}')")
 
         super().__setattr__(key, val)
 
@@ -93,6 +98,7 @@ class LocalTable:
             name="my_table",
             primary_key="id",
             time_column="time",
+            end_time_column=None,
         )
 
         # Verify metadata:
@@ -106,6 +112,8 @@ class LocalTable:
         name: The name of the table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
+        end_time_column: The name of the end time column of this table, if it
+            exists.
     """
     def __init__(
         self,
@@ -113,6 +121,7 @@ class LocalTable:
         name: str,
         primary_key: Optional[str] = None,
         time_column: Optional[str] = None,
+        end_time_column: Optional[str] = None,
     ) -> None:
 
         if df.empty:
@@ -130,6 +139,7 @@ class LocalTable:
         self._name = name
         self._primary_key: Optional[str] = None
         self._time_column: Optional[str] = None
+        self._end_time_column: Optional[str] = None
 
         self._columns: Dict[str, Column] = {}
         for column_name in df.columns:
@@ -141,6 +151,9 @@ class LocalTable:
         if time_column is not None:
             self.time_column = time_column
 
+        if end_time_column is not None:
+            self.end_time_column = end_time_column
+
     @property
     def name(self) -> str:
         r"""The name of the table."""
@@ -230,6 +243,8 @@ class LocalTable:
             self.primary_key = None
         if self._time_column == name:
             self.time_column = None
+        if self._end_time_column == name:
+            self.end_time_column = None
         del self._columns[name]
 
         return self
@@ -253,9 +268,8 @@ class LocalTable:
         :class:`ValueError` if the primary key has a non-ID semantic type or
         if the column name does not match a column in the data frame.
         """
-        if
+        if self._primary_key is None:
             return None
-        assert self._primary_key is not None
         return self[self._primary_key]
 
     @primary_key.setter
@@ -264,6 +278,10 @@ class LocalTable:
             raise ValueError(f"Cannot specify column '{name}' as a primary "
                              f"key since it is already defined to be a time "
                              f"column")
+        if name is not None and name == self._end_time_column:
+            raise ValueError(f"Cannot specify column '{name}' as a primary "
+                             f"key since it is already defined to be an end "
+                             f"time column")
 
         if self.primary_key is not None:
             self.primary_key._is_primary_key = False
@@ -295,9 +313,8 @@ class LocalTable:
         :class:`ValueError` if the time column has a non-timestamp semantic
         type or if the column name does not match a column in the data frame.
         """
-        if
+        if self._time_column is None:
             return None
-        assert self._time_column is not None
         return self[self._time_column]
 
     @time_column.setter
@@ -306,6 +323,10 @@ class LocalTable:
             raise ValueError(f"Cannot specify column '{name}' as a time "
                              f"column since it is already defined to be a "
                              f"primary key")
+        if name is not None and name == self._end_time_column:
+            raise ValueError(f"Cannot specify column '{name}' as a time "
+                             f"column since it is already defined to be an "
+                             f"end time column")
 
         if self.time_column is not None:
             self.time_column._is_time_column = False
@@ -318,6 +339,52 @@ class LocalTable:
         self[name]._is_time_column = True
         self._time_column = name
 
+    # End Time column #########################################################
+
+    def has_end_time_column(self) -> bool:
+        r"""Returns ``True`` if this table has an end time column; ``False``
+        otherwise.
+        """
+        return self._end_time_column is not None
+
+    @property
+    def end_time_column(self) -> Optional[Column]:
+        r"""The end time column of this table.
+
+        The getter returns the end time column of this table, or ``None`` if no
+        such end time column is present.
+
+        The setter sets a column as an end time column on this table, and
+        raises a :class:`ValueError` if the end time column has a non-timestamp
+        semantic type or if the column name does not match a column in the data
+        frame.
+        """
+        if self._end_time_column is None:
+            return None
+        return self[self._end_time_column]
+
+    @end_time_column.setter
+    def end_time_column(self, name: Optional[str]) -> None:
+        if name is not None and name == self._primary_key:
+            raise ValueError(f"Cannot specify column '{name}' as an end time "
+                             f"column since it is already defined to be a "
+                             f"primary key")
+        if name is not None and name == self._time_column:
+            raise ValueError(f"Cannot specify column '{name}' as an end time "
+                             f"column since it is already defined to be a "
+                             f"time column")
+
+        if self.end_time_column is not None:
+            self.end_time_column._is_end_time_column = False
+
+        if name is None:
+            self._end_time_column = None
+            return
+
+        self[name].stype = Stype.timestamp
+        self[name]._is_end_time_column = True
+        self._end_time_column = name
+
     # Metadata ################################################################
 
     @property
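Note: an end time column marks when an entity stops existing (e.g., an account close date); it must carry the `timestamp` semantic type and must differ from both the primary key and the time column. A hypothetical setup (table and column names are made up):

    import pandas as pd
    import kumoai.experimental.rfm as rfm

    df = pd.DataFrame({
        'AccountID': [1, 2],
        'opened_at': pd.to_datetime(['2023-01-01', '2023-02-01']),
        'closed_at': pd.to_datetime(['2024-01-01', pd.NaT]),
    })
    table = rfm.LocalTable(df, name='accounts', primary_key='AccountID',
                           time_column='opened_at')
    table.end_time_column = 'closed_at'    # via the new setter
    table.has_end_time_column()            # True
    # table.end_time_column = 'opened_at'  # would raise ValueError (already the time column)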
@@ -326,16 +393,18 @@ class LocalTable:
         information about the columns in this table.
 
         The returned dataframe has columns ``name``, ``dtype``, ``stype``,
-        ``is_primary_key``, and ``
-        view of the properties of the columns of
+        ``is_primary_key``, ``is_time_column`` and ``is_end_time_column``,
+        which provide an aggregate view of the properties of the columns of
+        this table.
 
         Example:
+            >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
             >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
             >>> table.metadata
-            name dtype
-            0 CustomerID float64
-        """
+                     name    dtype stype  is_primary_key  is_time_column  is_end_time_column
+            0  CustomerID  float64    ID            True           False               False
+        """  # noqa: E501
         cols = self.columns
 
         return pd.DataFrame({
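Note: the extended metadata frame can be inspected directly. Continuing the hypothetical `table` from the sketch above (output shown approximately):

    print(table.metadata[['name', 'is_time_column', 'is_end_time_column']])
    #         name  is_time_column  is_end_time_column
    # 0  AccountID           False               False
    # 1  opened_at            True               False
    # 2  closed_at           False                True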
@@ -355,6 +424,11 @@ class LocalTable:
                 dtype=bool,
                 data=[self._time_column == c.name for c in cols],
             ),
+            'is_end_time_column':
+            pd.Series(
+                dtype=bool,
+                data=[self._end_time_column == c.name for c in cols],
+            ),
         })
 
     def print_metadata(self) -> None:
@@ -417,6 +491,7 @@ class LocalTable:
         candidates = [
             column.name for column in self.columns
             if column.stype == Stype.timestamp
+            and column.name != self._end_time_column
         ]
         if time_column := utils.detect_time_column(self._data, candidates):
             self.time_column = time_column
@@ -430,24 +505,26 @@ class LocalTable:
     # Helpers #################################################################
 
     def _to_api_table_definition(self) -> TableDefinition:
-        cols: List[ColumnDefinition] = []
-        for col in self.columns:
-            cols.append(ColumnDefinition(col.name, col.stype, col.dtype))
-        pkey = self._primary_key
-        time_col = self._time_column
-        source_table = UnavailableSourceTable(table=self.name)
-
         return TableDefinition(
-            cols=cols,
-            source_table=source_table,
-            pkey=pkey,
-            time_col=time_col,
+            cols=[
+                ColumnDefinition(col.name, col.stype, col.dtype)
+                for col in self.columns
+            ],
+            source_table=UnavailableSourceTable(table=self.name),
+            pkey=self._primary_key,
+            time_col=self._time_column,
+            end_time_col=self._end_time_column,
         )
 
     # Python builtins #########################################################
 
     def __hash__(self) -> int:
-
+        special_columns = [
+            self.primary_key,
+            self.time_column,
+            self.end_time_column,
+        ]
+        return hash(tuple(self.columns + special_columns))
 
     def __contains__(self, name: str) -> bool:
         return self.has_column(name)
@@ -464,4 +541,5 @@ class LocalTable:
                 f' num_columns={len(self.columns)},\n'
                 f' primary_key={self._primary_key},\n'
                 f' time_column={self._time_column},\n'
+                f' end_time_column={self._end_time_column},\n'
                 f')')
kumoai/experimental/rfm/pquery/__init__.py

@@ -1,7 +1,11 @@
 from .backend import PQueryBackend
 from .pandas_backend import PQueryPandasBackend
+from .executor import PQueryExecutor
+from .pandas_executor import PQueryPandasExecutor
 
 __all__ = [
     'PQueryBackend',
     'PQueryPandasBackend',
+    'PQueryExecutor',
+    'PQueryPandasExecutor',
 ]
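Note: the new executor abstractions are re-exported next to the existing backend classes. A minimal import sketch (assuming the dev build above is installed):

    from kumoai.experimental.rfm.pquery import (
        PQueryBackend,
        PQueryExecutor,
        PQueryPandasBackend,
        PQueryPandasExecutor,
    )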
kumoai/experimental/rfm/pquery/backend.py

@@ -82,6 +82,7 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
@@ -94,6 +95,7 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
@@ -106,6 +108,7 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
@@ -128,5 +131,6 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
kumoai/experimental/rfm/pquery/executor.py (new file)

@@ -0,0 +1,102 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Generic, Tuple, TypeVar
+
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+
+TableData = TypeVar('TableData')
+ColumnData = TypeVar('ColumnData')
+IndexData = TypeVar('IndexData')
+
+
+class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
+    @abstractmethod
+    def execute_column(
+        self,
+        column: Column,
+        feat_dict: Dict[str, TableData],
+        filter_na: bool = True,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_aggregation(
+        self,
+        aggr: Aggregation,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_condition(
+        self,
+        condition: Condition,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_logical_operation(
+        self,
+        logical_operation: LogicalOperation,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_join(
+        self,
+        join: Join,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute_filter(
+        self,
+        filter: Filter,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass
+
+    @abstractmethod
+    def execute(
+        self,
+        query: ValidatedPredictiveQuery,
+        feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
+        num_forecasts: int = 1,
+    ) -> Tuple[ColumnData, IndexData]:
+        pass