kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +53 -107
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +41 -80
  10. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +147 -0
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +11 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +108 -88
  16. kumoai/experimental/rfm/base/__init__.py +26 -2
  17. kumoai/experimental/rfm/base/column.py +6 -12
  18. kumoai/experimental/rfm/base/column_expression.py +16 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +19 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +113 -0
  23. kumoai/experimental/rfm/base/table.py +174 -76
  24. kumoai/experimental/rfm/graph.py +444 -84
  25. kumoai/experimental/rfm/infer/__init__.py +6 -0
  26. kumoai/experimental/rfm/infer/dtype.py +77 -0
  27. kumoai/experimental/rfm/infer/pkey.py +128 -0
  28. kumoai/experimental/rfm/infer/time_col.py +61 -0
  29. kumoai/experimental/rfm/pquery/executor.py +27 -27
  30. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  31. kumoai/experimental/rfm/rfm.py +299 -240
  32. kumoai/experimental/rfm/sagemaker.py +4 -4
  33. kumoai/pquery/predictive_query.py +10 -6
  34. kumoai/testing/snow.py +50 -0
  35. kumoai/utils/__init__.py +3 -2
  36. kumoai/utils/progress_logger.py +178 -12
  37. kumoai/utils/sql.py +3 -0
  38. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +6 -2
  39. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +42 -30
  40. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  41. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  42. kumoai/experimental/rfm/utils.py +0 -344
  43. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
  44. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
  45. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
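The renames above consolidate the local backend implementation under kumoai/experimental/rfm/backend/local/, next to the new snow and sqlite backends. A minimal, hypothetical sketch of the import-path change implied by item 7; the module paths follow from the file renames, while the LocalGraphStore class name is taken from the deleted module's imports further below and is assumed to be unchanged:

    # Hypothetical usage sketch; only the module paths are confirmed by the
    # renames in this diff, the exported class name is an assumption.
    try:
        # 2.14.0.dev…: graph store now lives in the backend.local package.
        from kumoai.experimental.rfm.backend.local.graph_store import LocalGraphStore
    except ImportError:
        # 2.13.0.dev… and earlier: old flat module path.
        from kumoai.experimental.rfm.local_graph_store import LocalGraphStore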
kumoai/experimental/rfm/local_pquery_driver.py (deleted)
@@ -1,689 +0,0 @@
-import warnings
-from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
-
-import numpy as np
-import pandas as pd
-from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
-from kumoapi.pquery.AST import (
-    Aggregation,
-    ASTNode,
-    Column,
-    Condition,
-    Filter,
-    Join,
-    LogicalOperation,
-)
-from kumoapi.task import TaskType
-from kumoapi.typing import AggregationType, DateOffset, Stype
-
-import kumoai.kumolib as kumolib
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-
-_coverage_warned = False
-
-
-class SamplingSpec(NamedTuple):
-    edge_type: Tuple[str, str, str]
-    hop: int
-    start_offset: Optional[DateOffset]
-    end_offset: Optional[DateOffset]
-
-
-class LocalPQueryDriver:
-    def __init__(
-        self,
-        graph_store: LocalGraphStore,
-        query: ValidatedPredictiveQuery,
-        random_seed: Optional[int] = None,
-    ) -> None:
-        self._graph_store = graph_store
-        self._query = query
-        self._random_seed = random_seed
-        self._rng = np.random.default_rng(random_seed)
-
-    def _get_candidates(
-        self,
-        exclude_node: Optional[np.ndarray] = None,
-    ) -> np.ndarray:
-
-        if self._query.query_type == QueryType.TEMPORAL:
-            assert exclude_node is None
-
-        table_name = self._query.entity_table
-        num_nodes = len(self._graph_store.df_dict[table_name])
-        mask_dict = self._graph_store.mask_dict
-
-        candidate: np.ndarray
-
-        # Case 1: All nodes are valid and nothing to exclude:
-        if exclude_node is None and table_name not in mask_dict:
-            candidate = np.arange(num_nodes)
-
-        # Case 2: Not all nodes are valid - lookup valid nodes:
-        if exclude_node is None:
-            pkey_map = self._graph_store.pkey_map_dict[table_name]
-            candidate = pkey_map['arange'].to_numpy().copy()
-
-        # Case 3: Exclude nodes - use a mask to exclude them:
-        else:
-            mask = np.full((num_nodes, ), fill_value=True, dtype=bool)
-            mask[exclude_node] = False
-            if table_name in mask_dict:
-                mask &= mask_dict[table_name]
-            candidate = mask.nonzero()[0]
-
-        self._rng.shuffle(candidate)
-
-        return candidate
-
-    def _filter_candidates_by_time(
-        self,
-        candidate: np.ndarray,
-        anchor_time: pd.Timestamp,
-    ) -> np.ndarray:
-
-        entity = self._query.entity_table
-
-        # Filter out entities that do not exist yet in time:
-        time_sec = self._graph_store.time_dict.get(entity)
-        if time_sec is not None:
-            mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
-            candidate = candidate[mask]
-
-        # Filter out entities that no longer exist in time:
-        end_time_col = self._graph_store.end_time_column_dict.get(entity)
-        if end_time_col is not None:
-            ser = self._graph_store.df_dict[entity][end_time_col]
-            ser = ser.iloc[candidate]
-            mask = (anchor_time < ser) | ser.isna().to_numpy()
-            candidate = candidate[mask]
-
-        return candidate
-
-    def collect_test(
-        self,
-        size: int,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        batch_size: Optional[int] = None,
-        max_iterations: int = 20,
-        guarantee_train_examples: bool = True,
-    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-        r"""Collects test nodes and their labels used for evaluation.
-
-        Args:
-            size: The number of test nodes to collect.
-            anchor_time: The anchor time.
-            batch_size: How many nodes to process in a single batch.
-            max_iterations: The number of steps to run before aborting.
-            guarantee_train_examples: Ensures that test examples do not occupy
-                the entire set of entity candidates.
-
-        Returns:
-            A triplet holding the nodes, timestamps and labels.
-        """
-        batch_size = size if batch_size is None else batch_size
-
-        candidate = self._get_candidates()
-
-        nodes: List[np.ndarray] = []
-        times: List[pd.Series] = []
-        ys: List[pd.Series] = []
-
-        reached_end = False
-        num_labels = candidate_offset = 0
-        for _ in range(max_iterations):
-            node = candidate[candidate_offset:candidate_offset + batch_size]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                node = self._filter_candidates_by_time(node, anchor_time)
-                time = pd.Series(anchor_time).repeat(len(node))
-                time = time.astype('datetime64[ns]').reset_index(drop=True)
-            else:
-                assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[self._query.entity_table]
-                time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-            y, mask = self(node, time)
-
-            nodes.append(node[mask])
-            times.append(time[mask].reset_index(drop=True))
-            ys.append(y)
-
-            num_labels += len(y)
-
-            if num_labels > size:
-                reached_end = True
-                break  # Sufficient number of labels collected. Abort.
-
-            candidate_offset += batch_size
-            if candidate_offset >= len(candidate):
-                reached_end = True
-                break
-
-        if len(nodes) > 1:
-            node = np.concatenate(nodes, axis=0)[:size]
-            time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-            y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-        else:
-            node = nodes[0][:size]
-            time = times[0].iloc[:size]
-            y = ys[0].iloc[:size]
-
-        if len(node) == 0:
-            raise RuntimeError("Failed to collect any test examples for "
-                               "evaluation. Is your predictive query too "
-                               "restrictive?")
-
-        global _coverage_warned
-        if not _coverage_warned and not reached_end and len(node) < size // 2:
-            _coverage_warned = True
-            warnings.warn(f"Failed to collect {size:,} test examples within "
-                          f"{max_iterations} iterations. To improve coverage, "
-                          f"consider increasing the number of PQ iterations "
-                          f"using the 'max_pq_iterations' option. This "
-                          f"warning will not be shown again in this run.")
-
-        if (guarantee_train_examples
-                and self._query.query_type == QueryType.STATIC
-                and candidate_offset >= len(candidate)):
-            # In case all valid entities are used as test examples, we can no
-            # longer find any training example. Fallback to a 50/50 split:
-            size = len(node) // 2
-            node = node[:size]
-            time = time.iloc[:size]
-            y = y.iloc[:size]
-
-        return node, time, y
-
-    def collect_train(
-        self,
-        size: int,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        exclude_node: Optional[np.ndarray] = None,
-        batch_size: Optional[int] = None,
-        max_iterations: int = 20,
-    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-        r"""Collects training nodes and their labels.
-
-        Args:
-            size: The number of test nodes to collect.
-            anchor_time: The anchor time.
-            exclude_node: The nodes to exclude for use as in-context examples.
-            batch_size: How many nodes to process in a single batch.
-            max_iterations: The number of steps to run before aborting.
-
-        Returns:
-            A triplet holding the nodes, timestamps and labels.
-        """
-        batch_size = size if batch_size is None else batch_size
-
-        candidate = self._get_candidates(exclude_node)
-
-        if len(candidate) == 0:
-            raise RuntimeError("Failed to generate any context examples "
-                               "since not enough entities exist")
-
-        nodes: List[np.ndarray] = []
-        times: List[pd.Series] = []
-        ys: List[pd.Series] = []
-
-        reached_end = False
-        num_labels = candidate_offset = 0
-        for _ in range(max_iterations):
-            node = candidate[candidate_offset:candidate_offset + batch_size]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                node = self._filter_candidates_by_time(node, anchor_time)
-                time = pd.Series(anchor_time).repeat(len(node))
-                time = time.astype('datetime64[ns]').reset_index(drop=True)
-            else:
-                assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[self._query.entity_table]
-                time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-            y, mask = self(node, time)
-
-            nodes.append(node[mask])
-            times.append(time[mask].reset_index(drop=True))
-            ys.append(y)
-
-            num_labels += len(y)
-
-            if num_labels > size:
-                reached_end = True
-                break  # Sufficient number of labels collected. Abort.
-
-            candidate_offset += batch_size
-            if candidate_offset >= len(candidate):
-                # Restart with an earlier anchor time (if applicable).
-                if self._query.query_type == QueryType.STATIC:
-                    reached_end = True
-                    break  # Cannot jump back in time for static PQs. Abort.
-                if anchor_time == 'entity':
-                    reached_end = True
-                    break
-                candidate_offset = 0
-                time_frame = self._query.target_timeframe.timeframe
-                anchor_time = anchor_time - (time_frame *
-                                             self._query.num_forecasts)
-                if anchor_time < self._graph_store.min_time:
-                    reached_end = True
-                    break  # No earlier anchor time left. Abort.
-
-        if len(nodes) > 1:
-            node = np.concatenate(nodes, axis=0)[:size]
-            time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-            y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-        else:
-            node = nodes[0][:size]
-            time = times[0].iloc[:size]
-            y = ys[0].iloc[:size]
-
-        if len(node) == 0:
-            raise ValueError("Failed to collect any context examples. Is your "
-                             "predictive query too restrictive?")
-
-        global _coverage_warned
-        if not _coverage_warned and not reached_end and len(node) < size // 2:
-            _coverage_warned = True
-            warnings.warn(f"Failed to collect {size:,} context examples "
-                          f"within {max_iterations} iterations. To improve "
-                          f"coverage, consider increasing the number of PQ "
-                          f"iterations using the 'max_pq_iterations' option. "
-                          f"This warning will not be shown again in this run.")
-
-        return node, time, y
-
-    def is_valid(
-        self,
-        node: np.ndarray,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        batch_size: int = 10_000,
-    ) -> np.ndarray:
-        r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
-        which nodes fulfill entity filter constraints.
-
-        Args:
-            node: The nodes to check for.
-            anchor_time: The anchor time.
-            batch_size: How many nodes to process in a single batch.
-
-        Returns:
-            The mask.
-        """
-        mask: Optional[np.ndarray] = None
-
-        if isinstance(anchor_time, pd.Timestamp):
-            node = self._filter_candidates_by_time(node, anchor_time)
-            time = pd.Series(anchor_time).repeat(len(node))
-            time = time.astype('datetime64[ns]').reset_index(drop=True)
-        else:
-            assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[self._query.entity_table]
-            time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-        if isinstance(self._query.entity_ast, Filter):
-            # Mask out via (temporal) entity filter:
-            executor = PQueryPandasExecutor()
-            masks: List[np.ndarray] = []
-            for start in range(0, len(node), batch_size):
-                feat_dict, time_dict, batch_dict = self._sample(
-                    node[start:start + batch_size],
-                    time.iloc[start:start + batch_size],
-                )
-                _mask = executor.execute_filter(
-                    filter=self._query.entity_ast,
-                    feat_dict=feat_dict,
-                    time_dict=time_dict,
-                    batch_dict=batch_dict,
-                    anchor_time=time.iloc[start:start + batch_size],
-                )[1]
-                masks.append(_mask)
-
-            _mask = np.concatenate(masks)
-            mask = (mask & _mask) if mask is not None else _mask
-
-        if mask is None:
-            mask = np.ones(len(node), dtype=bool)
-
-        return mask
-
-    def _get_sampling_specs(
-        self,
-        node: ASTNode,
-        hop: int,
-        seed_table_name: str,
-        edge_types: List[Tuple[str, str, str]],
-        num_forecasts: int = 1,
-    ) -> List[SamplingSpec]:
-        if isinstance(node, (Aggregation, Column)):
-            if isinstance(node, Column):
-                table_name = node.fqn.split('.')[0]
-                if seed_table_name == table_name:
-                    return []
-            else:
-                table_name = node._get_target_column_name().split('.')[0]
-
-            target_edge_types = [
-                edge_type for edge_type in edge_types if
-                edge_type[2] == seed_table_name and edge_type[0] == table_name
-            ]
-            if len(target_edge_types) != 1:
-                raise ValueError(
-                    f"Could not find a unique foreign key from table "
-                    f"'{seed_table_name}' to '{table_name}'")
-
-            if isinstance(node, Column):
-                return [
-                    SamplingSpec(
-                        edge_type=target_edge_types[0],
-                        hop=hop + 1,
-                        start_offset=None,
-                        end_offset=None,
-                    )
-                ]
-            spec = SamplingSpec(
-                edge_type=target_edge_types[0],
-                hop=hop + 1,
-                start_offset=node.aggr_time_range.start_date_offset,
-                end_offset=node.aggr_time_range.end_date_offset *
-                num_forecasts,
-            )
-            return [spec] + self._get_sampling_specs(
-                node.target, hop=hop + 1, seed_table_name=table_name,
-                edge_types=edge_types, num_forecasts=num_forecasts)
-        specs = []
-        for child in node.children:
-            specs += self._get_sampling_specs(child, hop, seed_table_name,
-                                              edge_types, num_forecasts)
-        return specs
-
-    def get_sampling_specs(self) -> List[SamplingSpec]:
-        edge_types = self._graph_store.edge_types
-        specs = self._get_sampling_specs(
-            self._query.target_ast, hop=0,
-            seed_table_name=self._query.entity_table, edge_types=edge_types,
-            num_forecasts=self._query.num_forecasts)
-        specs += self._get_sampling_specs(
-            self._query.entity_ast, hop=0,
-            seed_table_name=self._query.entity_table, edge_types=edge_types)
-        if self._query.whatif_ast is not None:
-            specs += self._get_sampling_specs(
-                self._query.whatif_ast, hop=0,
-                seed_table_name=self._query.entity_table,
-                edge_types=edge_types)
-        # Group specs according to edge type and hop:
-        spec_dict: Dict[
-            Tuple[Tuple[str, str, str], int],
-            Tuple[Optional[DateOffset], Optional[DateOffset]],
-        ] = {}
-        for spec in specs:
-            if (spec.edge_type, spec.hop) not in spec_dict:
-                spec_dict[(spec.edge_type, spec.hop)] = (
-                    spec.start_offset,
-                    spec.end_offset,
-                )
-            else:
-                start_offset, end_offset = spec_dict[(
-                    spec.edge_type,
-                    spec.hop,
-                )]
-                spec_dict[(spec.edge_type, spec.hop)] = (
-                    min_date_offset(start_offset, spec.start_offset),
-                    max_date_offset(end_offset, spec.end_offset),
-                )
-
-        return [
-            SamplingSpec(edge, hop, start_offset, end_offset)
-            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
-        ]
-
-    def _sample(
-        self,
-        node: np.ndarray,
-        anchor_time: pd.Series,
-    ) -> Tuple[
-            Dict[str, pd.DataFrame],
-            Dict[str, pd.Series],
-            Dict[str, np.ndarray],
-    ]:
-        r"""Samples a subgraph that contains all relevant information to
-        evaluate the predictive query.
-
-        Args:
-            node: The nodes to check for.
-            anchor_time: The anchor time.
-
-        Returns:
-            The feature dictionary, the time column dictionary and the batch
-            dictionary.
-        """
-        specs = self.get_sampling_specs()
-        num_hops = max([spec.hop for spec in specs] + [0])
-        num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
-        time_offsets: Dict[
-            Tuple[str, str, str],
-            List[List[Optional[int]]],
-        ] = {}
-        for spec in specs:
-            if spec.end_offset is not None:
-                if spec.edge_type not in time_offsets:
-                    time_offsets[spec.edge_type] = [[0, 0]
-                                                    for _ in range(num_hops)]
-                offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
-                time_offsets[spec.edge_type][spec.hop - 1][1] = offset
-                if spec.start_offset is not None:
-                    offset = date_offset_to_seconds(spec.start_offset)
-                else:
-                    offset = None
-                time_offsets[spec.edge_type][spec.hop - 1][0] = offset
-            else:
-                if spec.edge_type not in num_neighbors:
-                    num_neighbors[spec.edge_type] = [0] * num_hops
-                num_neighbors[spec.edge_type][spec.hop - 1] = -1
-
-        edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
-        node_types = list(
-            set([self._query.entity_table])
-            | set(src for src, _, _ in edge_types)
-            | set(dst for _, _, dst in edge_types))
-
-        sampler = kumolib.NeighborSampler(
-            node_types,
-            edge_types,
-            {
-                '__'.join(edge_type): self._graph_store.colptr_dict[edge_type]
-                for edge_type in edge_types
-            },
-            {
-                '__'.join(edge_type): self._graph_store.row_dict[edge_type]
-                for edge_type in edge_types
-            },
-            {
-                node_type: time
-                for node_type, time in self._graph_store.time_dict.items()
-                if node_type in node_types
-            },
-        )
-
-        anchor_time = anchor_time.astype('datetime64[ns]')
-        _, _, node_dict, batch_dict, _, _ = sampler.sample(
-            {
-                '__'.join(edge_type): np.array(values)
-                for edge_type, values in num_neighbors.items()
-            },
-            {
-                '__'.join(edge_type): np.array(values)
-                for edge_type, values in time_offsets.items()
-            },
-            self._query.entity_table,
-            node,
-            anchor_time.astype(int).to_numpy() // 1000**3,
-        )
-
-        feat_dict: Dict[str, pd.DataFrame] = {}
-        time_dict: Dict[str, pd.Series] = {}
-        column_dict: Dict[str, Set[str]] = {}
-        for col in self._query.all_query_columns:
-            table_name, col_name = col.split('.')
-            if table_name not in column_dict:
-                column_dict[table_name] = set()
-            if col_name != '*':
-                column_dict[table_name].add(col_name)
-        time_tables = self.find_time_tables()
-        for table_name in set(list(column_dict.keys()) + time_tables):
-            df = self._graph_store.df_dict[table_name]
-            row_id = node_dict[table_name]
-            df = df.iloc[row_id].reset_index(drop=True)
-            if table_name in column_dict:
-                if len(column_dict[table_name]) == 0:
-                    # We are dealing with COUNT(table.*), insert a dummy col
-                    # to ensure we don't lose the information on node count
-                    feat_dict[table_name] = pd.DataFrame(
-                        {'ones': [1] * len(df)})
-                else:
-                    feat_dict[table_name] = df[list(column_dict[table_name])]
-            if table_name in time_tables:
-                time_col = self._graph_store.time_column_dict[table_name]
-                time_dict[table_name] = df[time_col]
-
-        return feat_dict, time_dict, batch_dict
-
-    def __call__(
-        self,
-        node: np.ndarray,
-        anchor_time: pd.Series,
-    ) -> Tuple[pd.Series, np.ndarray]:
-
-        feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
-
-        y, mask = PQueryPandasExecutor().execute(
-            query=self._query,
-            feat_dict=feat_dict,
-            time_dict=time_dict,
-            batch_dict=batch_dict,
-            anchor_time=anchor_time,
-            num_forecasts=self._query.num_forecasts,
-        )
-
-        return y, mask
-
-    def find_time_tables(self) -> List[str]:
-        def _find_time_tables(node: ASTNode) -> List[str]:
-            time_tables = []
-            if isinstance(node, Aggregation):
-                time_tables.append(
-                    node._get_target_column_name().split('.')[0])
-            for child in node.children:
-                time_tables += _find_time_tables(child)
-            return time_tables
-
-        time_tables = _find_time_tables(
-            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
-        if self._query.whatif_ast is not None:
-            time_tables += _find_time_tables(self._query.whatif_ast)
-        return list(set(time_tables))
-
-    @staticmethod
-    def get_task_type(
-        query: ValidatedPredictiveQuery,
-        edge_types: List[Tuple[str, str, str]],
-    ) -> TaskType:
-        if isinstance(query.target_ast, (Condition, LogicalOperation)):
-            return TaskType.BINARY_CLASSIFICATION
-
-        target = query.target_ast
-        if isinstance(target, Join):
-            target = target.rhs_target
-        if isinstance(target, Aggregation):
-            if target.aggr == AggregationType.LIST_DISTINCT:
-                table_name, col_name = target._get_target_column_name().split(
-                    '.')
-                target_edge_types = [
-                    edge_type for edge_type in edge_types
-                    if edge_type[0] == table_name and edge_type[1] == col_name
-                ]
-                if len(target_edge_types) != 1:
-                    raise NotImplementedError(
-                        f"Multilabel-classification queries based on "
-                        f"'LIST_DISTINCT' are not supported yet. If you "
-                        f"planned to write a link prediction query instead, "
-                        f"make sure to register '{col_name}' as a "
-                        f"foreign key.")
-                return TaskType.TEMPORAL_LINK_PREDICTION
-
-            return TaskType.REGRESSION
-
-        assert isinstance(target, Column)
-
-        if target.stype in {Stype.ID, Stype.categorical}:
-            return TaskType.MULTICLASS_CLASSIFICATION
-
-        if target.stype in {Stype.numerical}:
-            return TaskType.REGRESSION
-
-        raise NotImplementedError("Task type not yet supported")
-
-
-def date_offset_to_seconds(offset: pd.DateOffset) -> int:
-    r"""Convert a :class:`pandas.DateOffset` into a maximum number of
-    nanoseconds.
-
-    .. note::
-        We are conservative and take months and years as their maximum value.
-        Additional values are then dropped in label computation where we know
-        the actual dates.
-    """
-    # Max durations for months and years in nanoseconds:
-    MAX_DAYS_IN_MONTH = 31
-    MAX_DAYS_IN_YEAR = 366
-
-    # Conversion factors:
-    SECONDS_IN_MINUTE = 60
-    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
-    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
-
-    total_ns = 0
-    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
-
-    for attr, value in offset.__dict__.items():
-        if value is None or value == 0:
-            continue
-        scaled_value = value * multiplier
-        if attr == 'years':
-            total_ns += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
-        elif attr == 'months':
-            total_ns += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
-        elif attr == 'days':
-            total_ns += scaled_value * SECONDS_IN_DAY
-        elif attr == 'hours':
-            total_ns += scaled_value * SECONDS_IN_HOUR
-        elif attr == 'minutes':
-            total_ns += scaled_value * SECONDS_IN_MINUTE
-        elif attr == 'seconds':
-            total_ns += scaled_value
-
-    return total_ns
-
-
-def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
-    if any(arg is None for arg in args):
-        return None
-
-    anchor = pd.Timestamp('2000-01-01')
-    timestamps = [anchor + arg for arg in args]
-    assert len(timestamps) > 0
-    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
-    return args[argmin]
-
-
-def max_date_offset(*args: DateOffset) -> DateOffset:
-    if any(arg is None for arg in args):
-        return None
-
-    anchor = pd.Timestamp('2000-01-01')
-    timestamps = [anchor + arg for arg in args]
-    assert len(timestamps) > 0
-    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
-    return args[argmax]
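
The removed min_date_offset and max_date_offset helpers above order pd.DateOffset objects by applying each one to a fixed anchor timestamp and comparing the resulting timestamps, since offsets themselves do not define an ordering. A self-contained sketch of that comparison trick; the smaller_offset name is illustrative and not part of the package:

    import pandas as pd

    def smaller_offset(a: pd.DateOffset, b: pd.DateOffset) -> pd.DateOffset:
        # Apply both offsets to the same anchor date and compare the results,
        # mirroring how min_date_offset/max_date_offset pick an offset above.
        anchor = pd.Timestamp('2000-01-01')
        return a if anchor + a <= anchor + b else b

    # One month from 2000-01-01 spans 31 days, so it is the smaller offset here:
    print(smaller_offset(pd.DateOffset(days=45), pd.DateOffset(months=1)))
    # <DateOffset: months=1>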