kumoai-2.10.0.dev202509231831-cp313-cp313-macosx_11_0_arm64.whl → kumoai-2.14.0.dev202512161731-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of kumoai might be problematic.

Files changed (53)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +37 -8
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
  12. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  13. kumoai/experimental/rfm/backend/local/table.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
  20. kumoai/experimental/rfm/base/__init__.py +25 -0
  21. kumoai/experimental/rfm/base/column.py +66 -0
  22. kumoai/experimental/rfm/base/sampler.py +773 -0
  23. kumoai/experimental/rfm/base/source.py +19 -0
  24. kumoai/experimental/rfm/base/sql_sampler.py +60 -0
  25. kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
  26. kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
  27. kumoai/experimental/rfm/infer/__init__.py +6 -0
  28. kumoai/experimental/rfm/infer/dtype.py +79 -0
  29. kumoai/experimental/rfm/infer/pkey.py +126 -0
  30. kumoai/experimental/rfm/infer/time_col.py +62 -0
  31. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  32. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  33. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  34. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
  35. kumoai/experimental/rfm/rfm.py +669 -246
  36. kumoai/experimental/rfm/sagemaker.py +138 -0
  37. kumoai/jobs.py +1 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/trainer.py +12 -10
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +239 -4
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
  47. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -176
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -404
  50. kumoai/experimental/rfm/utils.py +0 -344
  51. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
  52. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
  53. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/local_pquery_driver.py
@@ -1,404 +0,0 @@
-import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.rfm import PQueryDefinition
-
-import kumoai.kumolib as kumolib
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import PQueryPandasBackend
-
-_coverage_warned = False
-
-
-class LocalPQueryDriver:
-    def __init__(
-        self,
-        graph_store: LocalGraphStore,
-        query: PQueryDefinition,
-        random_seed: Optional[int],
-    ) -> None:
-        self._graph_store = graph_store
-        self._query = query
-        self._random_seed = random_seed
-        self._rng = np.random.default_rng(random_seed)
-
-    def _get_candidates(
-        self,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        exclude_node: Optional[np.ndarray] = None,
-    ) -> np.ndarray:
-
-        if self._query.query_type == QueryType.TEMPORAL:
-            assert exclude_node is None
-
-        table_name = self._query.entity.pkey.table_name
-        num_nodes = len(self._graph_store.df_dict[table_name])
-        mask_dict = self._graph_store.mask_dict
-
-        candidate: np.ndarray
-
-        # Case 1: All nodes are valid and nothing to exclude:
-        if exclude_node is None and table_name not in mask_dict:
-            candidate = np.arange(num_nodes)
-
-        # Case 2: Not all nodes are valid - lookup valid nodes:
-        elif exclude_node is None:
-            pkey_map = self._graph_store.pkey_map_dict[table_name]
-            candidate = pkey_map['arange'].to_numpy().copy()
-
-        # Case 3: Exclude nodes - use a mask to exclude them:
-        else:
-            mask = np.full((num_nodes, ), fill_value=True, dtype=bool)
-            mask[exclude_node] = False
-            if table_name in mask_dict:
-                mask &= mask_dict[table_name]
-            candidate = mask.nonzero()[0]
-
-        self._rng.shuffle(candidate)
-
-        return candidate
-
-    def collect_test(
-        self,
-        size: int,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        batch_size: Optional[int] = None,
-        max_iterations: int = 20,
-        guarantee_train_examples: bool = True,
-    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-        r"""Collects test nodes and their labels used for evaluation.
-
-        Args:
-            size: The number of test nodes to collect.
-            anchor_time: The anchor time.
-            batch_size: How many nodes to process in a single batch.
-            max_iterations: The number of steps to run before aborting.
-            guarantee_train_examples: Ensures that test examples do not occupy
-                the entire set of entity candidates.
-
-        Returns:
-            A triplet holding the nodes, timestamps and labels.
-        """
-        batch_size = size if batch_size is None else batch_size
-
-        candidate = self._get_candidates(anchor_time)
-
-        nodes: List[np.ndarray] = []
-        times: List[pd.Series] = []
-        ys: List[pd.Series] = []
-
-        reached_end = False
-        num_labels = candidate_offset = 0
-        for _ in range(max_iterations):
-            node = candidate[candidate_offset:candidate_offset + batch_size]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                # Filter out non-existent entities:
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                time = pd.Series(anchor_time).repeat(len(node))
-                time = time.astype('datetime64[ns]').reset_index(drop=True)
-            else:
-                assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
-                time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-            y, mask = self(node, time)
-
-            nodes.append(node[mask])
-            times.append(time[mask].reset_index(drop=True))
-            ys.append(y)
-
-            num_labels += len(y)
-
-            if num_labels > size:
-                reached_end = True
-                break  # Sufficient number of labels collected. Abort.
-
-            candidate_offset += batch_size
-            if candidate_offset >= len(candidate):
-                reached_end = True
-                break
-
-        if len(nodes) > 1:
-            node = np.concatenate(nodes, axis=0)[:size]
-            time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-            y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-        else:
-            node = nodes[0][:size]
-            time = times[0].iloc[:size]
-            y = ys[0].iloc[:size]
-
-        if len(node) == 0:
-            raise RuntimeError("Failed to collect any test examples for "
-                               "evaluation. Is your predictive query too "
-                               "restrictive?")
-
-        global _coverage_warned
-        if not _coverage_warned and not reached_end and len(node) < size // 2:
-            _coverage_warned = True
-            warnings.warn(f"Failed to collect {size:,} test examples within "
-                          f"{max_iterations} iterations. To improve coverage, "
-                          f"consider increasing the number of PQ iterations "
-                          f"using the 'max_pq_iterations' option. This "
-                          f"warning will not be shown again in this run.")
-
-        if (guarantee_train_examples
-                and self._query.query_type == QueryType.STATIC
-                and candidate_offset >= len(candidate)):
-            # In case all valid entities are used as test examples, we can no
-            # longer find any training example. Fallback to a 50/50 split:
-            size = len(node) // 2
-            node = node[:size]
-            time = time.iloc[:size]
-            y = y.iloc[:size]
-
-        return node, time, y
-
-    def collect_train(
-        self,
-        size: int,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        exclude_node: Optional[np.ndarray] = None,
-        batch_size: Optional[int] = None,
-        max_iterations: int = 20,
-    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-        r"""Collects training nodes and their labels.
-
-        Args:
-            size: The number of training nodes to collect.
-            anchor_time: The anchor time.
-            exclude_node: The nodes to exclude for use as in-context examples.
-            batch_size: How many nodes to process in a single batch.
-            max_iterations: The number of steps to run before aborting.
-
-        Returns:
-            A triplet holding the nodes, timestamps and labels.
-        """
-        batch_size = size if batch_size is None else batch_size
-
-        candidate = self._get_candidates(anchor_time, exclude_node)
-
-        if len(candidate) == 0:
-            raise RuntimeError("Failed to generate any context examples "
-                               "since not enough entities exist")
-
-        nodes: List[np.ndarray] = []
-        times: List[pd.Series] = []
-        ys: List[pd.Series] = []
-
-        reached_end = False
-        num_labels = candidate_offset = 0
-        for _ in range(max_iterations):
-            node = candidate[candidate_offset:candidate_offset + batch_size]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                # Filter out non-existent entities:
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                time = pd.Series(anchor_time).repeat(len(node))
-                time = time.astype('datetime64[ns]').reset_index(drop=True)
-            else:
-                assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
-                time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-            y, mask = self(node, time)
-
-            nodes.append(node[mask])
-            times.append(time[mask].reset_index(drop=True))
-            ys.append(y)
-
-            num_labels += len(y)
-
-            if num_labels > size:
-                reached_end = True
-                break  # Sufficient number of labels collected. Abort.
-
-            candidate_offset += batch_size
-            if candidate_offset >= len(candidate):
-                # Restart with an earlier anchor time (if applicable).
-                if self._query.query_type == QueryType.STATIC:
-                    reached_end = True
-                    break  # Cannot jump back in time for static PQs. Abort.
-                if anchor_time == 'entity':
-                    reached_end = True
-                    break
-                candidate_offset = 0
-                anchor_time = anchor_time - (self._query.target.end_offset *
                                             self._query.num_forecasts)
-                if anchor_time < self._graph_store.min_time:
-                    reached_end = True
-                    break  # No earlier anchor time left. Abort.
-
-        if len(nodes) > 1:
-            node = np.concatenate(nodes, axis=0)[:size]
-            time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-            y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-        else:
-            node = nodes[0][:size]
-            time = times[0].iloc[:size]
-            y = ys[0].iloc[:size]
-
-        if len(node) == 0:
-            raise ValueError("Failed to collect any context examples. Is your "
-                             "predictive query too restrictive?")
-
-        global _coverage_warned
-        if not _coverage_warned and not reached_end and len(node) < size // 2:
-            _coverage_warned = True
-            warnings.warn(f"Failed to collect {size:,} context examples "
-                          f"within {max_iterations} iterations. To improve "
-                          f"coverage, consider increasing the number of PQ "
-                          f"iterations using the 'max_pq_iterations' option. "
-                          f"This warning will not be shown again in this run.")
-
-        return node, time, y
-
-    def __call__(
-        self,
-        node: np.ndarray,
-        anchor_time: pd.Series,
-    ) -> Tuple[pd.Series, np.ndarray]:
-
-        specs = self._query.get_sampling_specs(self._graph_store.edge_types)
-        num_hops = max([spec.hop for spec in specs] + [0])
-        num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
-        time_offsets: Dict[
-            Tuple[str, str, str],
-            List[List[Optional[int]]],
-        ] = {}
-        for spec in specs:
-            if spec.end_offset is not None:
-                if spec.edge_type not in time_offsets:
-                    time_offsets[spec.edge_type] = [[0, 0]
                                                    for _ in range(num_hops)]
-                offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
-                time_offsets[spec.edge_type][spec.hop - 1][1] = offset
-                if spec.start_offset is not None:
-                    offset = date_offset_to_seconds(spec.start_offset)
-                else:
-                    offset = None
-                time_offsets[spec.edge_type][spec.hop - 1][0] = offset
-            else:
-                if spec.edge_type not in num_neighbors:
-                    num_neighbors[spec.edge_type] = [0] * num_hops
-                num_neighbors[spec.edge_type][spec.hop - 1] = -1
-
-        edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
-        node_types = list(
-            set([self._query.entity.pkey.table_name])
-            | set(src for src, _, _ in edge_types)
-            | set(dst for _, _, dst in edge_types))
-
-        sampler = kumolib.NeighborSampler(
-            node_types,
-            edge_types,
-            {
-                '__'.join(edge_type): self._graph_store.colptr_dict[edge_type]
-                for edge_type in edge_types
-            },
-            {
-                '__'.join(edge_type): self._graph_store.row_dict[edge_type]
-                for edge_type in edge_types
-            },
-            {
-                node_type: time
-                for node_type, time in self._graph_store.time_dict.items()
-                if node_type in node_types
-            },
-        )
-
-        anchor_time = anchor_time.astype('datetime64[ns]')
-        _, _, node_dict, batch_dict, _, _ = sampler.sample(
-            {
-                '__'.join(edge_type): np.array(values)
-                for edge_type, values in num_neighbors.items()
-            },
-            {
-                '__'.join(edge_type): np.array(values)
-                for edge_type, values in time_offsets.items()
-            },
-            self._query.entity.pkey.table_name,
-            node,
-            anchor_time.astype(int).to_numpy() // 1000**3,
-        )
-
-        feat_dict: Dict[str, pd.DataFrame] = {}
-        time_dict: Dict[str, pd.Series] = {}
-        column_dict = self._query.column_dict
-        time_tables = self._query.time_tables
-        for table_name in set(list(column_dict.keys()) + time_tables):
-            df = self._graph_store.df_dict[table_name]
-            row_id = node_dict[table_name]
-            df = df.iloc[row_id].reset_index(drop=True)
-            if table_name in column_dict:
-                feat_dict[table_name] = df[list(column_dict[table_name])]
-            if table_name in time_tables:
-                time_col = self._graph_store.time_column_dict[table_name]
-                time_dict[table_name] = df[time_col]
-
-        y, mask = PQueryPandasBackend().eval_pquery(
-            query=self._query,
-            feat_dict=feat_dict,
-            time_dict=time_dict,
-            batch_dict=batch_dict,
-            anchor_time=anchor_time,
-            num_forecasts=self._query.num_forecasts,
-        )
-
-        return y, mask
-
-
-def date_offset_to_seconds(offset: pd.DateOffset) -> int:
-    r"""Convert a :class:`pandas.DateOffset` into a maximum number of
-    seconds.
-
-    .. note::
-        We are conservative and take months and years as their maximum value.
-        Additional values are then dropped in label computation where we know
-        the actual dates.
-    """
-    # Maximum durations of months and years, in days:
-    MAX_DAYS_IN_MONTH = 31
-    MAX_DAYS_IN_YEAR = 366
-
-    # Conversion factors:
-    SECONDS_IN_MINUTE = 60
-    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
-    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
-
-    total_ns = 0
-    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
-
-    for attr, value in offset.__dict__.items():
-        if value is None or value == 0:
-            continue
-        scaled_value = value * multiplier
-        if attr == 'years':
-            total_ns += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
-        elif attr == 'months':
-            total_ns += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
-        elif attr == 'days':
-            total_ns += scaled_value * SECONDS_IN_DAY
-        elif attr == 'hours':
-            total_ns += scaled_value * SECONDS_IN_HOUR
-        elif attr == 'minutes':
-            total_ns += scaled_value * SECONDS_IN_MINUTE
-        elif attr == 'seconds':
-            total_ns += scaled_value
-
-    return total_ns
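
The removed date_offset_to_seconds helper converts a pandas.DateOffset into a conservative number of seconds (months counted as 31 days, years as 366 days), which the neighbor sampler then compares against anchor timestamps expressed in whole seconds (anchor_time.astype(int).to_numpy() // 1000**3). Below is a minimal, self-contained sketch of that conversion for a few common offsets; the approx_offset_seconds name is illustrative and is not part of the kumoai package:

import pandas as pd

SECONDS_IN_DAY = 24 * 60 * 60

def approx_offset_seconds(offset: pd.DateOffset) -> int:
    # Conservative upper bound, mirroring the removed helper: months are
    # counted as 31 days and years as 366 days, so an offset is never
    # under-estimated.
    factors = {
        'years': 366 * SECONDS_IN_DAY,
        'months': 31 * SECONDS_IN_DAY,
        'days': SECONDS_IN_DAY,
        'hours': 60 * 60,
        'minutes': 60,
        'seconds': 1,
    }
    n = getattr(offset, 'n', 1)  # The offset multiplier (if present).
    return sum(value * n * factors[attr]
               for attr, value in offset.__dict__.items()
               if attr in factors and value)

assert approx_offset_seconds(pd.DateOffset(days=7)) == 604_800
assert approx_offset_seconds(pd.DateOffset(months=1)) == 2_678_400  # 31 days
assert approx_offset_seconds(pd.DateOffset(days=2, hours=12)) == 216_000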