kumoai 2.12.0.dev202510231830__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. kumoai/__init__.py +41 -35
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/jobs.py +24 -0
  6. kumoai/client/pquery.py +6 -2
  7. kumoai/client/rfm.py +35 -7
  8. kumoai/connector/utils.py +23 -2
  9. kumoai/experimental/rfm/__init__.py +191 -48
  10. kumoai/experimental/rfm/authenticate.py +3 -4
  11. kumoai/experimental/rfm/backend/__init__.py +0 -0
  12. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  13. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  14. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  15. kumoai/experimental/rfm/backend/local/table.py +113 -0
  16. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  17. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  18. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  19. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  20. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  21. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  22. kumoai/experimental/rfm/base/__init__.py +30 -0
  23. kumoai/experimental/rfm/base/column.py +152 -0
  24. kumoai/experimental/rfm/base/expression.py +44 -0
  25. kumoai/experimental/rfm/base/sampler.py +761 -0
  26. kumoai/experimental/rfm/base/source.py +19 -0
  27. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  28. kumoai/experimental/rfm/base/table.py +735 -0
  29. kumoai/experimental/rfm/graph.py +1237 -0
  30. kumoai/experimental/rfm/infer/__init__.py +8 -0
  31. kumoai/experimental/rfm/infer/dtype.py +82 -0
  32. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  33. kumoai/experimental/rfm/infer/pkey.py +128 -0
  34. kumoai/experimental/rfm/infer/stype.py +35 -0
  35. kumoai/experimental/rfm/infer/time_col.py +61 -0
  36. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  37. kumoai/experimental/rfm/pquery/executor.py +27 -27
  38. kumoai/experimental/rfm/pquery/pandas_executor.py +64 -40
  39. kumoai/experimental/rfm/relbench.py +76 -0
  40. kumoai/experimental/rfm/rfm.py +386 -276
  41. kumoai/experimental/rfm/sagemaker.py +138 -0
  42. kumoai/kumolib.cp311-win_amd64.pyd +0 -0
  43. kumoai/pquery/predictive_query.py +10 -6
  44. kumoai/spcs.py +1 -3
  45. kumoai/testing/decorators.py +1 -1
  46. kumoai/testing/snow.py +50 -0
  47. kumoai/trainer/distilled_trainer.py +175 -0
  48. kumoai/trainer/trainer.py +9 -10
  49. kumoai/utils/__init__.py +3 -2
  50. kumoai/utils/display.py +51 -0
  51. kumoai/utils/progress_logger.py +188 -16
  52. kumoai/utils/sql.py +3 -0
  53. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
  54. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +57 -36
  55. kumoai/experimental/rfm/local_graph.py +0 -810
  56. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  57. kumoai/experimental/rfm/local_pquery_driver.py +0 -494
  58. kumoai/experimental/rfm/local_table.py +0 -545
  59. kumoai/experimental/rfm/pquery/backend.py +0 -136
  60. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  61. kumoai/experimental/rfm/utils.py +0 -344
  62. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
  63. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
  64. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
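Item 13 above moves the local graph store from the package root into the new backend.local subpackage. As a rough orientation only: the old import path is visible in the removed files below, while the new path follows the renamed file location. The class name at the new location is an assumption and cannot be verified from this diff alone, since the moved module also changed (+65 -127).

# Old location (removed; see the local_graph_sampler.py hunk below):
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore

# Hypothetical new location after the rename in item 13 (the exported class
# name may have changed as part of the refactor):
from kumoai.experimental.rfm.backend.local.graph_store import LocalGraphStore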
kumoai/experimental/rfm/local_graph_sampler.py
@@ -1,184 +0,0 @@
- from typing import Dict, List, Optional, Tuple
-
- import numpy as np
- import pandas as pd
- from kumoapi.model_plan import RunMode
- from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
- from kumoapi.typing import Stype
-
- import kumoai.kumolib as kumolib
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.utils import normalize_text
-
-
- class LocalGraphSampler:
-     def __init__(self, graph_store: LocalGraphStore) -> None:
-         self._graph_store = graph_store
-         self._sampler = kumolib.NeighborSampler(
-             self._graph_store.node_types,
-             self._graph_store.edge_types,
-             {
-                 '__'.join(edge_type): colptr
-                 for edge_type, colptr in self._graph_store.colptr_dict.items()
-             },
-             {
-                 '__'.join(edge_type): row
-                 for edge_type, row in self._graph_store.row_dict.items()
-             },
-             self._graph_store.time_dict,
-         )
-
-     def __call__(
-         self,
-         entity_table_names: Tuple[str, ...],
-         node: np.ndarray,
-         time: np.ndarray,
-         run_mode: RunMode,
-         num_neighbors: List[int],
-         exclude_cols_dict: Dict[str, List[str]],
-     ) -> Subgraph:
-
-         (
-             row_dict,
-             col_dict,
-             node_dict,
-             batch_dict,
-             num_sampled_nodes_dict,
-             num_sampled_edges_dict,
-         ) = self._sampler.sample(
-             {
-                 '__'.join(edge_type): num_neighbors
-                 for edge_type in self._graph_store.edge_types
-             },
-             {},  # time interval based sampling
-             entity_table_names[0],
-             node,
-             time // 1000**3,  # nanoseconds to seconds
-         )
-
-         table_dict: Dict[str, Table] = {}
-         for table_name, node in node_dict.items():
-             batch = batch_dict[table_name]
-
-             if len(node) == 0:
-                 continue
-
-             df = self._graph_store.df_dict[table_name]
-
-             num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
-             stype_dict = {  # Exclude target columns:
-                 column_name: stype
-                 for column_name, stype in
-                 self._graph_store.stype_dict[table_name].items()
-                 if column_name not in exclude_cols_dict.get(table_name, [])
-             }
-             primary_key: Optional[str] = None
-             if table_name in entity_table_names:
-                 primary_key = self._graph_store.pkey_name_dict.get(table_name)
-
-             columns: List[str] = []
-             if table_name in entity_table_names:
-                 columns += [self._graph_store.pkey_name_dict[table_name]]
-             columns += list(stype_dict.keys())
-
-             if len(columns) == 0:
-                 table_dict[table_name] = Table(
-                     df=pd.DataFrame(index=range(len(node))),
-                     row=None,
-                     batch=batch,
-                     num_sampled_nodes=num_sampled_nodes,
-                     stype_dict=stype_dict,
-                     primary_key=primary_key,
-                 )
-                 continue
-
-             row: Optional[np.ndarray] = None
-             if table_name in self._graph_store.end_time_column_dict:
-                 # Set end time to NaT for all values greater than anchor time:
-                 df = df.iloc[node].reset_index(drop=True)
-                 col_name = self._graph_store.end_time_column_dict[table_name]
-                 ser = df[col_name]
-                 value = ser.astype('datetime64[ns]').astype(int).to_numpy()
-                 mask = value > time[batch]
-                 df.loc[mask, col_name] = pd.NaT
-             else:
-                 # Only store unique rows in `df` above a certain threshold:
-                 unique_node, inverse = np.unique(node, return_inverse=True)
-                 if len(node) > 1.05 * len(unique_node):
-                     df = df.iloc[unique_node].reset_index(drop=True)
-                     row = inverse
-                 else:
-                     df = df.iloc[node].reset_index(drop=True)
-
-             # Filter data frame to minimal set of columns:
-             df = df[columns]
-
-             # Normalize text (if not already pre-processed):
-             for column_name, stype in stype_dict.items():
-                 if stype == Stype.text:
-                     df[column_name] = normalize_text(df[column_name])
-
-             table_dict[table_name] = Table(
-                 df=df,
-                 row=row,
-                 batch=batch,
-                 num_sampled_nodes=num_sampled_nodes,
-                 stype_dict=stype_dict,
-                 primary_key=primary_key,
-             )
-
-         link_dict: Dict[Tuple[str, str, str], Link] = {}
-         for edge_type in self._graph_store.edge_types:
-             edge_type_str = '__'.join(edge_type)
-
-             row = row_dict[edge_type_str]
-             col = col_dict[edge_type_str]
-
-             if len(row) == 0:
-                 continue
-
-             # Do not store reverse edge type if it is a replica:
-             rev_edge_type = Subgraph.rev_edge_type(edge_type)
-             rev_edge_type_str = '__'.join(rev_edge_type)
-             if (rev_edge_type in link_dict
-                     and np.array_equal(row, col_dict[rev_edge_type_str])
-                     and np.array_equal(col, row_dict[rev_edge_type_str])):
-                 link = Link(
-                     layout=EdgeLayout.REV,
-                     row=None,
-                     col=None,
-                     num_sampled_edges=(
-                         num_sampled_edges_dict[edge_type_str].tolist()),
-                 )
-                 link_dict[edge_type] = link
-                 continue
-
-             layout = EdgeLayout.COO
-             if np.array_equal(row, np.arange(len(row))):
-                 row = None
-             if np.array_equal(col, np.arange(len(col))):
-                 col = None
-
-             # Store in compressed representation if more efficient:
-             num_cols = table_dict[edge_type[2]].num_rows
-             if col is not None and len(col) > num_cols + 1:
-                 layout = EdgeLayout.CSC
-                 colcount = np.bincount(col, minlength=num_cols)
-                 col = np.empty(num_cols + 1, dtype=col.dtype)
-                 col[0] = 0
-                 np.cumsum(colcount, out=col[1:])
-
-             link = Link(
-                 layout=layout,
-                 row=row,
-                 col=col,
-                 num_sampled_edges=(
-                     num_sampled_edges_dict[edge_type_str].tolist()),
-             )
-             link_dict[edge_type] = link
-
-         return Subgraph(
-             anchor_time=time,
-             table_dict=table_dict,
-             link_dict=link_dict,
-         )
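For reference, the EdgeLayout.CSC branch in the removed sampler compresses a column-sorted COO column vector into a column-pointer array via np.bincount and np.cumsum. A minimal standalone sketch of that technique (the coo_col_to_csc_colptr name is ours, not part of the package):

import numpy as np


def coo_col_to_csc_colptr(col: np.ndarray, num_cols: int) -> np.ndarray:
    # `col` holds the destination node of every edge, sorted ascending,
    # so counting edges per destination and prefix-summing yields colptr.
    colcount = np.bincount(col, minlength=num_cols)
    colptr = np.empty(num_cols + 1, dtype=col.dtype)
    colptr[0] = 0
    np.cumsum(colcount, out=colptr[1:])
    return colptr


# Five edges into three destination nodes:
print(coo_col_to_csc_colptr(np.array([0, 0, 1, 2, 2]), num_cols=3))  # [0 2 3 5]

The pointer array has num_cols + 1 entries, which is why the removed code only switches to CSC when the edge count exceeds num_cols + 1.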
kumoai/experimental/rfm/local_pquery_driver.py
@@ -1,494 +0,0 @@
- import warnings
- from typing import Dict, List, Literal, Optional, Tuple, Union
-
- import numpy as np
- import pandas as pd
- from kumoapi.pquery import QueryType
- from kumoapi.rfm import PQueryDefinition
-
- import kumoai.kumolib as kumolib
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.pquery import PQueryPandasBackend
-
- _coverage_warned = False
-
-
- class LocalPQueryDriver:
-     def __init__(
-         self,
-         graph_store: LocalGraphStore,
-         query: PQueryDefinition,
-         random_seed: Optional[int] = None,
-     ) -> None:
-         self._graph_store = graph_store
-         self._query = query
-         self._random_seed = random_seed
-         self._rng = np.random.default_rng(random_seed)
-
-     def _get_candidates(
-         self,
-         exclude_node: Optional[np.ndarray] = None,
-     ) -> np.ndarray:
-
-         if self._query.query_type == QueryType.TEMPORAL:
-             assert exclude_node is None
-
-         table_name = self._query.entity.pkey.table_name
-         num_nodes = len(self._graph_store.df_dict[table_name])
-         mask_dict = self._graph_store.mask_dict
-
-         candidate: np.ndarray
-
-         # Case 1: All nodes are valid and nothing to exclude:
-         if exclude_node is None and table_name not in mask_dict:
-             candidate = np.arange(num_nodes)
-
-         # Case 2: Not all nodes are valid - lookup valid nodes:
-         if exclude_node is None:
-             pkey_map = self._graph_store.pkey_map_dict[table_name]
-             candidate = pkey_map['arange'].to_numpy().copy()
-
-         # Case 3: Exclude nodes - use a mask to exclude them:
-         else:
-             mask = np.full((num_nodes, ), fill_value=True, dtype=bool)
-             mask[exclude_node] = False
-             if table_name in mask_dict:
-                 mask &= mask_dict[table_name]
-             candidate = mask.nonzero()[0]
-
-         self._rng.shuffle(candidate)
-
-         return candidate
-
-     def _filter_candidates_by_time(
-         self,
-         candidate: np.ndarray,
-         anchor_time: pd.Timestamp,
-     ) -> np.ndarray:
-
-         entity = self._query.entity.pkey.table_name
-
-         # Filter out entities that do not exist yet in time:
-         time_sec = self._graph_store.time_dict.get(entity)
-         if time_sec is not None:
-             mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
-             candidate = candidate[mask]
-
-         # Filter out entities that no longer exist in time:
-         end_time_col = self._graph_store.end_time_column_dict.get(entity)
-         if end_time_col is not None:
-             ser = self._graph_store.df_dict[entity][end_time_col]
-             ser = ser.iloc[candidate]
-             mask = (anchor_time < ser) | ser.isna().to_numpy()
-             candidate = candidate[mask]
-
-         return candidate
-
-     def collect_test(
-         self,
-         size: int,
-         anchor_time: Union[pd.Timestamp, Literal['entity']],
-         batch_size: Optional[int] = None,
-         max_iterations: int = 20,
-         guarantee_train_examples: bool = True,
-     ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-         r"""Collects test nodes and their labels used for evaluation.
-
-         Args:
-             size: The number of test nodes to collect.
-             anchor_time: The anchor time.
-             batch_size: How many nodes to process in a single batch.
-             max_iterations: The number of steps to run before aborting.
-             guarantee_train_examples: Ensures that test examples do not occupy
-                 the entire set of entity candidates.
-
-         Returns:
-             A triplet holding the nodes, timestamps and labels.
-         """
-         batch_size = size if batch_size is None else batch_size
-
-         candidate = self._get_candidates()
-
-         nodes: List[np.ndarray] = []
-         times: List[pd.Series] = []
-         ys: List[pd.Series] = []
-
-         reached_end = False
-         num_labels = candidate_offset = 0
-         for _ in range(max_iterations):
-             node = candidate[candidate_offset:candidate_offset + batch_size]
-
-             if isinstance(anchor_time, pd.Timestamp):
-                 node = self._filter_candidates_by_time(node, anchor_time)
-                 time = pd.Series(anchor_time).repeat(len(node))
-                 time = time.astype('datetime64[ns]').reset_index(drop=True)
-             else:
-                 assert anchor_time == 'entity'
-                 time = self._graph_store.time_dict[
-                     self._query.entity.pkey.table_name]
-                 time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-             y, mask = self(node, time)
-
-             nodes.append(node[mask])
-             times.append(time[mask].reset_index(drop=True))
-             ys.append(y)
-
-             num_labels += len(y)
-
-             if num_labels > size:
-                 reached_end = True
-                 break  # Sufficient number of labels collected. Abort.
-
-             candidate_offset += batch_size
-             if candidate_offset >= len(candidate):
-                 reached_end = True
-                 break
-
-         if len(nodes) > 1:
-             node = np.concatenate(nodes, axis=0)[:size]
-             time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-             y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-         else:
-             node = nodes[0][:size]
-             time = times[0].iloc[:size]
-             y = ys[0].iloc[:size]
-
-         if len(node) == 0:
-             raise RuntimeError("Failed to collect any test examples for "
-                                "evaluation. Is your predictive query too "
-                                "restrictive?")
-
-         global _coverage_warned
-         if not _coverage_warned and not reached_end and len(node) < size // 2:
-             _coverage_warned = True
-             warnings.warn(f"Failed to collect {size:,} test examples within "
-                           f"{max_iterations} iterations. To improve coverage, "
-                           f"consider increasing the number of PQ iterations "
-                           f"using the 'max_pq_iterations' option. This "
-                           f"warning will not be shown again in this run.")
-
-         if (guarantee_train_examples
-                 and self._query.query_type == QueryType.STATIC
-                 and candidate_offset >= len(candidate)):
-             # In case all valid entities are used as test examples, we can no
-             # longer find any training example. Fallback to a 50/50 split:
-             size = len(node) // 2
-             node = node[:size]
-             time = time.iloc[:size]
-             y = y.iloc[:size]
-
-         return node, time, y
-
-     def collect_train(
-         self,
-         size: int,
-         anchor_time: Union[pd.Timestamp, Literal['entity']],
-         exclude_node: Optional[np.ndarray] = None,
-         batch_size: Optional[int] = None,
-         max_iterations: int = 20,
-     ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-         r"""Collects training nodes and their labels.
-
-         Args:
-             size: The number of test nodes to collect.
-             anchor_time: The anchor time.
-             exclude_node: The nodes to exclude for use as in-context examples.
-             batch_size: How many nodes to process in a single batch.
-             max_iterations: The number of steps to run before aborting.
-
-         Returns:
-             A triplet holding the nodes, timestamps and labels.
-         """
-         batch_size = size if batch_size is None else batch_size
-
-         candidate = self._get_candidates(exclude_node)
-
-         if len(candidate) == 0:
-             raise RuntimeError("Failed to generate any context examples "
-                                "since not enough entities exist")
-
-         nodes: List[np.ndarray] = []
-         times: List[pd.Series] = []
-         ys: List[pd.Series] = []
-
-         reached_end = False
-         num_labels = candidate_offset = 0
-         for _ in range(max_iterations):
-             node = candidate[candidate_offset:candidate_offset + batch_size]
-
-             if isinstance(anchor_time, pd.Timestamp):
-                 node = self._filter_candidates_by_time(node, anchor_time)
-                 time = pd.Series(anchor_time).repeat(len(node))
-                 time = time.astype('datetime64[ns]').reset_index(drop=True)
-             else:
-                 assert anchor_time == 'entity'
-                 time = self._graph_store.time_dict[
-                     self._query.entity.pkey.table_name]
-                 time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-             y, mask = self(node, time)
-
-             nodes.append(node[mask])
-             times.append(time[mask].reset_index(drop=True))
-             ys.append(y)
-
-             num_labels += len(y)
-
-             if num_labels > size:
-                 reached_end = True
-                 break  # Sufficient number of labels collected. Abort.
-
-             candidate_offset += batch_size
-             if candidate_offset >= len(candidate):
-                 # Restart with an earlier anchor time (if applicable).
-                 if self._query.query_type == QueryType.STATIC:
-                     reached_end = True
-                     break  # Cannot jump back in time for static PQs. Abort.
-                 if anchor_time == 'entity':
-                     reached_end = True
-                     break
-                 candidate_offset = 0
-                 anchor_time = anchor_time - (self._query.target.end_offset *
-                                              self._query.num_forecasts)
-                 if anchor_time < self._graph_store.min_time:
-                     reached_end = True
-                     break  # No earlier anchor time left. Abort.
-
-         if len(nodes) > 1:
-             node = np.concatenate(nodes, axis=0)[:size]
-             time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-             y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-         else:
-             node = nodes[0][:size]
-             time = times[0].iloc[:size]
-             y = ys[0].iloc[:size]
-
-         if len(node) == 0:
-             raise ValueError("Failed to collect any context examples. Is your "
-                              "predictive query too restrictive?")
-
-         global _coverage_warned
-         if not _coverage_warned and not reached_end and len(node) < size // 2:
-             _coverage_warned = True
-             warnings.warn(f"Failed to collect {size:,} context examples "
-                           f"within {max_iterations} iterations. To improve "
-                           f"coverage, consider increasing the number of PQ "
-                           f"iterations using the 'max_pq_iterations' option. "
-                           f"This warning will not be shown again in this run.")
-
-         return node, time, y
-
-     def is_valid(
-         self,
-         node: np.ndarray,
-         anchor_time: Union[pd.Timestamp, Literal['entity']],
-         batch_size: int = 10_000,
-     ) -> np.ndarray:
-         r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
-         which nodes fulfill entity filter constraints.
-
-         Args:
-             node: The nodes to check for.
-             anchor_time: The anchor time.
-             batch_size: How many nodes to process in a single batch.
-
-         Returns:
-             The mask.
-         """
-         mask: Optional[np.ndarray] = None
-
-         if isinstance(anchor_time, pd.Timestamp):
-             node = self._filter_candidates_by_time(node, anchor_time)
-             time = pd.Series(anchor_time).repeat(len(node))
-             time = time.astype('datetime64[ns]').reset_index(drop=True)
-         else:
-             assert anchor_time == 'entity'
-             time = self._graph_store.time_dict[
-                 self._query.entity.pkey.table_name]
-             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-         if self._query.entity.filter is not None:
-             # Mask out via (temporal) entity filter:
-             backend = PQueryPandasBackend()
-             masks: List[np.ndarray] = []
-             for start in range(0, len(node), batch_size):
-                 feat_dict, time_dict, batch_dict = self._sample(
-                     node[start:start + batch_size],
-                     time.iloc[start:start + batch_size],
-                 )
-                 _mask = backend.eval_filter(
-                     filter=self._query.entity.filter,
-                     feat_dict=feat_dict,
-                     time_dict=time_dict,
-                     batch_dict=batch_dict,
-                     anchor_time=time.iloc[start:start + batch_size],
-                 )
-                 masks.append(_mask)
-
-             _mask = np.concatenate(masks)
-             mask = (mask & _mask) if mask is not None else _mask
-
-         if mask is None:
-             mask = np.ones(len(node), dtype=bool)
-
-         return mask
-
-     def _sample(
-         self,
-         node: np.ndarray,
-         anchor_time: pd.Series,
-     ) -> Tuple[
-             Dict[str, pd.DataFrame],
-             Dict[str, pd.Series],
-             Dict[str, np.ndarray],
-     ]:
-         r"""Samples a subgraph that contains all relevant information to
-         evaluate the predictive query.
-
-         Args:
-             node: The nodes to check for.
-             anchor_time: The anchor time.
-
-         Returns:
-             The feature dictionary, the time column dictionary and the batch
-             dictionary.
-         """
-         specs = self._query.get_sampling_specs(self._graph_store.edge_types)
-         num_hops = max([spec.hop for spec in specs] + [0])
-         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
-         time_offsets: Dict[
-             Tuple[str, str, str],
-             List[List[Optional[int]]],
-         ] = {}
-         for spec in specs:
-             if spec.end_offset is not None:
-                 if spec.edge_type not in time_offsets:
-                     time_offsets[spec.edge_type] = [[0, 0]
-                                                     for _ in range(num_hops)]
-                 offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
-                 time_offsets[spec.edge_type][spec.hop - 1][1] = offset
-                 if spec.start_offset is not None:
-                     offset = date_offset_to_seconds(spec.start_offset)
-                 else:
-                     offset = None
-                 time_offsets[spec.edge_type][spec.hop - 1][0] = offset
-             else:
-                 if spec.edge_type not in num_neighbors:
-                     num_neighbors[spec.edge_type] = [0] * num_hops
-                 num_neighbors[spec.edge_type][spec.hop - 1] = -1
-
-         edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
-         node_types = list(
-             set([self._query.entity.pkey.table_name])
-             | set(src for src, _, _ in edge_types)
-             | set(dst for _, _, dst in edge_types))
-
-         sampler = kumolib.NeighborSampler(
-             node_types,
-             edge_types,
-             {
-                 '__'.join(edge_type): self._graph_store.colptr_dict[edge_type]
-                 for edge_type in edge_types
-             },
-             {
-                 '__'.join(edge_type): self._graph_store.row_dict[edge_type]
-                 for edge_type in edge_types
-             },
-             {
-                 node_type: time
-                 for node_type, time in self._graph_store.time_dict.items()
-                 if node_type in node_types
-             },
-         )
-
-         anchor_time = anchor_time.astype('datetime64[ns]')
-         _, _, node_dict, batch_dict, _, _ = sampler.sample(
-             {
-                 '__'.join(edge_type): np.array(values)
-                 for edge_type, values in num_neighbors.items()
-             },
-             {
-                 '__'.join(edge_type): np.array(values)
-                 for edge_type, values in time_offsets.items()
-             },
-             self._query.entity.pkey.table_name,
-             node,
-             anchor_time.astype(int).to_numpy() // 1000**3,
-         )
-
-         feat_dict: Dict[str, pd.DataFrame] = {}
-         time_dict: Dict[str, pd.Series] = {}
-         column_dict = self._query.column_dict
-         time_tables = self._query.time_tables
-         for table_name in set(list(column_dict.keys()) + time_tables):
-             df = self._graph_store.df_dict[table_name]
-             row_id = node_dict[table_name]
-             df = df.iloc[row_id].reset_index(drop=True)
-             if table_name in column_dict:
-                 feat_dict[table_name] = df[list(column_dict[table_name])]
-             if table_name in time_tables:
-                 time_col = self._graph_store.time_column_dict[table_name]
-                 time_dict[table_name] = df[time_col]
-
-         return feat_dict, time_dict, batch_dict
-
-     def __call__(
-         self,
-         node: np.ndarray,
-         anchor_time: pd.Series,
-     ) -> Tuple[pd.Series, np.ndarray]:
-
-         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
-
-         y, mask = PQueryPandasBackend().eval_pquery(
-             query=self._query,
-             feat_dict=feat_dict,
-             time_dict=time_dict,
-             batch_dict=batch_dict,
-             anchor_time=anchor_time,
-             num_forecasts=self._query.num_forecasts,
-         )
-
-         return y, mask
-
-
- def date_offset_to_seconds(offset: pd.DateOffset) -> int:
-     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
-     nanoseconds.
-
-     .. note::
-         We are conservative and take months and years as their maximum value.
-         Additional values are then dropped in label computation where we know
-         the actual dates.
-     """
-     # Max durations for months and years in nanoseconds:
-     MAX_DAYS_IN_MONTH = 31
-     MAX_DAYS_IN_YEAR = 366
-
-     # Conversion factors:
-     SECONDS_IN_MINUTE = 60
-     SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
-     SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
-
-     total_ns = 0
-     multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
-
-     for attr, value in offset.__dict__.items():
-         if value is None or value == 0:
-             continue
-         scaled_value = value * multiplier
-         if attr == 'years':
-             total_ns += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
-         elif attr == 'months':
-             total_ns += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
-         elif attr == 'days':
-             total_ns += scaled_value * SECONDS_IN_DAY
-         elif attr == 'hours':
-             total_ns += scaled_value * SECONDS_IN_HOUR
-         elif attr == 'minutes':
-             total_ns += scaled_value * SECONDS_IN_MINUTE
-         elif attr == 'seconds':
-             total_ns += scaled_value
-
-     return total_ns
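For reference, the removed date_offset_to_seconds helper turns a pandas DateOffset into a conservative upper bound in seconds: months and years are taken at their maximum lengths (31 and 366 days), and the slack is dropped later during label computation. A minimal sketch of the same idea, using only the public DateOffset attributes n and kwds (the max_offset_seconds name and the lookup-table layout are ours, not part of the package):

import pandas as pd

# Maximum length of each DateOffset component, in seconds:
MAX_SECONDS = {
    'years': 366 * 24 * 60 * 60,
    'months': 31 * 24 * 60 * 60,
    'days': 24 * 60 * 60,
    'hours': 60 * 60,
    'minutes': 60,
    'seconds': 1,
}


def max_offset_seconds(offset: pd.DateOffset) -> int:
    # Conservative upper bound: months/years use their longest possible span;
    # components not listed above (e.g. microseconds) are ignored here.
    total = 0
    for attr, value in offset.kwds.items():
        total += (value or 0) * offset.n * MAX_SECONDS.get(attr, 0)
    return total


print(max_offset_seconds(pd.DateOffset(months=1)))  # 2678400 (31 days)
print(max_offset_seconds(pd.DateOffset(days=7)))    # 604800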