kumoai 2.13.0.dev202511261731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (34)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/connector/utils.py +23 -2
  4. kumoai/experimental/rfm/__init__.py +20 -45
  5. kumoai/experimental/rfm/backend/__init__.py +0 -0
  6. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  8. kumoai/experimental/rfm/backend/local/sampler.py +116 -0
  9. kumoai/experimental/rfm/backend/local/table.py +109 -0
  10. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  11. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  12. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  14. kumoai/experimental/rfm/base/__init__.py +14 -0
  15. kumoai/experimental/rfm/base/column.py +66 -0
  16. kumoai/experimental/rfm/base/sampler.py +373 -0
  17. kumoai/experimental/rfm/base/source.py +18 -0
  18. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  19. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  20. kumoai/experimental/rfm/infer/__init__.py +6 -0
  21. kumoai/experimental/rfm/infer/dtype.py +79 -0
  22. kumoai/experimental/rfm/infer/pkey.py +126 -0
  23. kumoai/experimental/rfm/infer/time_col.py +62 -0
  24. kumoai/experimental/rfm/local_graph_sampler.py +43 -2
  25. kumoai/experimental/rfm/local_pquery_driver.py +1 -1
  26. kumoai/experimental/rfm/rfm.py +7 -17
  27. kumoai/experimental/rfm/sagemaker.py +11 -3
  28. kumoai/testing/decorators.py +1 -1
  29. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/METADATA +9 -8
  30. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/RECORD +33 -19
  31. kumoai/experimental/rfm/utils.py +0 -344
  32. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/WHEEL +0 -0
  33. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/licenses/LICENSE +0 -0
  34. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/kumoai/experimental/rfm/base/sampler.py
@@ -0,0 +1,373 @@
+import copy
+import re
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import Aggregation, ASTNode
+from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
+from kumoapi.typing import Stype
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+@dataclass
+class BackwardSamplerOutput:
+    df_dict: dict[str, pd.DataFrame]
+    inverse_dict: dict[str, np.ndarray]
+    batch_dict: dict[str, np.ndarray]
+    num_sampled_nodes_dict: dict[str, list[int]]
+    row_dict: dict[tuple[str, str, str], np.ndarray]
+    col_dict: dict[tuple[str, str, str], np.ndarray]
+    num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+
+
+@dataclass
+class ForwardSamplerOutput:
+    entity_pkey: pd.Series
+    anchor_time: pd.Series
+    target: pd.Series
+
+
+class Sampler(ABC):
+    def __init__(self, graph: 'Graph') -> None:
+        self._edge_types: list[tuple[str, str, str]] = []
+        for edge in graph.edges:
+            edge_type = (edge.src_table, edge.fkey, edge.dst_table)
+            self._edge_types.append(edge_type)
+            self._edge_types.append(Subgraph.rev_edge_type(edge_type))
+
+        self._primary_key_dict: dict[str, str] = {
+            table.name: table._primary_key
+            for table in graph.tables.values()
+            if table._primary_key is not None
+        }
+
+        self._time_column_dict: dict[str, str] = {
+            table.name: table._time_column
+            for table in graph.tables.values()
+            if table._time_column is not None
+        }
+
+        self._end_time_column_dict: dict[str, str] = {
+            table.name: table._end_time_column
+            for table in graph.tables.values()
+            if table._end_time_column is not None
+        }
+
+        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        self._table_stype_dict: dict[str, dict[str, Stype]] = {}
+        for table in graph.tables.values():
+            self._table_stype_dict[table.name] = {}
+            for column in table.columns:
+                if column == table.primary_key:
+                    continue
+                if (table.name, column.name) in foreign_keys:
+                    continue
+                self._table_stype_dict[table.name][column.name] = column.stype
+
+    @property
+    def edge_types(self) -> list[tuple[str, str, str]]:
+        return self._edge_types
+
+    @property
+    def primary_key_dict(self) -> dict[str, str]:
+        return self._primary_key_dict
+
+    @property
+    def time_column_dict(self) -> dict[str, str]:
+        return self._time_column_dict
+
+    @property
+    def end_time_column_dict(self) -> dict[str, str]:
+        return self._end_time_column_dict
+
+    @property
+    def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+        return self._table_stype_dict
+
+    def sample_subgraph(
+        self,
+        entity_table_names: tuple[str, ...],
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series,
+        num_neighbors: list[int],
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+    ) -> Subgraph:
+
+        # Exclude all columns that leak target information:
+        table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
+        if exclude_cols_dict is not None:
+            table_stype_dict = copy.deepcopy(table_stype_dict)
+            for table_name, exclude_cols in exclude_cols_dict.items():
+                for column_name in exclude_cols:
+                    del table_stype_dict[table_name][column_name]
+
+        # Collect all columns being used as features:
+        columns_dict: dict[str, set[str]] = {
+            table_name: set(stype_dict.keys())
+            for table_name, stype_dict in table_stype_dict.items()
+        }
+        # Make sure to store primary key information for entity tables:
+        for table_name in entity_table_names:
+            columns_dict[table_name].add(self.primary_key_dict[table_name])
+
+        if anchor_time.dtype != 'datetime64[ns]':
+            anchor_time = anchor_time.astype('datetime64[ns]')
+
+        out = self._sample_backward(
+            entity_table_name=entity_table_names[0],
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
+            columns_dict=columns_dict,
+            num_neighbors=num_neighbors,
+        )
+
+        subgraph = Subgraph(
+            anchor_time=anchor_time.astype(int).to_numpy(),
+            table_dict={},
+            link_dict={},
+        )
+
+        for table_name, batch in out.batch_dict.items():
+            if len(batch) == 0:
+                continue
+
+            primary_key: str | None = None
+            if table_name in entity_table_names:
+                primary_key = self.primary_key_dict[table_name]
+
+            df = out.df_dict[table_name].reset_index(drop=True)
+            if end_time_column := self.end_time_column_dict.get(table_name):
+                # Set end time to NaT for all values greater than anchor time:
+                assert table_name not in out.inverse_dict
+                ser = df[end_time_column]
+                if ser.dtype != 'datetime64[ns]':
+                    ser = ser.astype('datetime64[ns]')
+                mask = ser > anchor_time.iloc[batch]
+                ser.iloc[mask] = pd.NaT
+                df[end_time_column] = ser
+
+            stype_dict = table_stype_dict[table_name]
+            for column_name, stype in stype_dict.items():
+                if stype == Stype.text:
+                    df[column_name] = _normalize_text(df[column_name])
+
+            subgraph.table_dict[table_name] = Table(
+                df=df,
+                row=out.inverse_dict.get(table_name),
+                batch=batch,
+                num_sampled_nodes=out.num_sampled_nodes_dict[table_name],
+                stype_dict=stype_dict,
+                primary_key=primary_key,
+            )
+
+        for edge_type in out.row_dict.keys():
+            row: np.ndarray | None = out.row_dict[edge_type]
+            col: np.ndarray | None = out.col_dict[edge_type]
+
+            if row is None or col is None or len(row) == 0:
+                continue
+
+            # Do not store reverse edge type if it is an exact replica:
+            rev_edge_type = Subgraph.rev_edge_type(edge_type)
+            if (rev_edge_type in subgraph.link_dict
+                    and np.array_equal(row, out.col_dict[rev_edge_type])
+                    and np.array_equal(col, out.row_dict[rev_edge_type])):
+                subgraph.link_dict[edge_type] = Link(
+                    layout=EdgeLayout.REV,
+                    row=None,
+                    col=None,
+                    num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+                )
+                continue
+
+            # Do not store non-informative edges:
+            layout = EdgeLayout.COO
+            if np.array_equal(row, np.arange(len(row))):
+                row = None
+            if np.array_equal(col, np.arange(len(col))):
+                col = None
+
+            # Store in compressed representation if more efficient:
+            num_cols = subgraph.table_dict[edge_type[2]].num_rows
+            if col is not None and len(col) > num_cols + 1:
+                layout = EdgeLayout.CSC
+                colcount = np.bincount(col, minlength=num_cols)
+                col = np.empty(num_cols + 1, dtype=col.dtype)
+                col[0] = 0
+                np.cumsum(colcount, out=col[1:])
+
+            subgraph.link_dict[edge_type] = Link(
+                layout=layout,
+                row=row,
+                col=col,
+                num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+            )
+
+        return subgraph
+
+    def sample_forward(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_examples: int,
+        anchor_time: pd.Timestamp | Literal['entity'],
+        random_seed: int | None = None,
+    ) -> ForwardSamplerOutput:
+
+        columns_dict: dict[str, set[str]] = defaultdict(set)
+        for fqn in query.all_query_columns + [query.entity_column]:
+            table_name, column_name = fqn.split('.')
+            columns_dict[table_name].add(column_name)
+
+        if time_column := self.time_column_dict.get(query.entity_table):
+            columns_dict[query.entity_table].add(time_column)
+        if end_time := self.end_time_column_dict.get(query.entity_table):
+            columns_dict[query.entity_table].add(end_time)
+
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ] = {}
+
+        def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
+            if isinstance(node, Aggregation):
+                table_name = node._get_target_column_name().split('.')[0]
+                columns_dict[table_name].add(self.time_column_dict[table_name])
+
+                edge_types = [
+                    edge_type for edge_type in self.edge_types
+                    if edge_type[0] == table_name
+                    and edge_type[2] == query.entity_table
+                ]
+                if len(edge_types) != 1:
+                    raise ValueError(f"Could not find a unique foreign key "
+                                     f"from table '{table_name}' to "
+                                     f"'{query.entity_table}'")
+                if edge_types[0] not in time_offset_dict:
+                    start = node.aggr_time_range.start_date_offset
+                    end = node.aggr_time_range.end_date_offset * num_forecasts
+                else:
+                    start, end = time_offset_dict[edge_types[0]]
+                    start = min_date_offset(
+                        start,
+                        node.aggr_time_range.start_date_offset,
+                    )
+                    end = max_date_offset(
+                        end,
+                        node.aggr_time_range.end_date_offset * num_forecasts,
+                    )
+                time_offset_dict[edge_types[0]] = (start, end)
+
+            for child in node.children:
+                _add_time_offset(child, num_forecasts)
+
+        _add_time_offset(query.target_ast, query.num_forecasts)
+        _add_time_offset(query.entity_ast)
+        if query.whatif_ast is not None:
+            _add_time_offset(query.whatif_ast)
+
+        return self._sample_forward(
+            query=query,
+            num_examples=num_examples,
+            anchor_time=anchor_time,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+            random_seed=random_seed,
+        )
+
+    # Abstract Methods ####################################################
+
+    @abstractmethod
+    def _sample_backward(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series,
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
+    ) -> BackwardSamplerOutput:
+        pass
+
+    @abstractmethod
+    def _sample_forward(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_examples: int,
+        anchor_time: pd.Timestamp | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        random_seed: int | None = None,
+    ) -> ForwardSamplerOutput:
+        pass
+
+
+# Helper Functions ############################################################
+
+PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+MULTISPACE = re.compile(r"\s+")
+
+
+def _normalize_text(
+    ser: pd.Series,
+    max_words: int | None = 50,
+) -> pd.Series:
+    r"""Normalizes text into a list of lower-case words.
+
+    Args:
+        ser: The :class:`pandas.Series` to normalize.
+        max_words: The maximum number of words to return.
+            This will auto-shrink any large text column to avoid blowing up
+            context size.
+    """
+    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+        return ser
+
+    def normalize_fn(line: str) -> list[str]:
+        line = PUNCTUATION.sub(" ", line)
+        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+        line = MULTISPACE.sub(" ", line)
+        words = line.split()
+        if max_words is not None:
+            words = words[:max_words]
+        return words
+
+    ser = ser.fillna('').astype(str)
+
+    if max_words is not None:
+        # We estimate a word as 5 characters plus 1 space in average English
+        # text. We need this pre-filter here, as word splitting on a giant
+        # text can be very expensive:
+        ser = ser.str[:6 * max_words]
+
+    ser = ser.str.lower()
+    ser = ser.map(normalize_fn)
+
+    return ser
+
+
+def min_date_offset(*args: pd.DateOffset | None) -> pd.DateOffset | None:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
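
The new `Sampler` base class follows a template-method design: the public `sample_subgraph()` and `sample_forward()` handle feature-column selection, target-leakage exclusion, text normalization, and edge-layout compression (COO, CSC, or a reused reverse layout), while each backend (`backend/local`, `backend/sqlite`, `backend/snow` in the file list above) only implements `_sample_backward()` and `_sample_forward()`. A minimal sketch of a custom backend against this interface; the class name and its empty outputs are invented for illustration and are not part of this release:

    import pandas as pd

    from kumoai.experimental.rfm.base.sampler import (
        BackwardSamplerOutput,
        ForwardSamplerOutput,
        Sampler,
    )


    class EmptySampler(Sampler):
        r"""A hypothetical backend that samples nothing."""
        def _sample_backward(self, entity_table_name, entity_pkey,
                             anchor_time, columns_dict, num_neighbors):
            # No tables and no edges; `sample_subgraph()` then returns a
            # `Subgraph` that only carries the anchor times:
            return BackwardSamplerOutput(
                df_dict={}, inverse_dict={}, batch_dict={},
                num_sampled_nodes_dict={}, row_dict={}, col_dict={},
                num_sampled_edges_dict={},
            )

        def _sample_forward(self, query, num_examples, anchor_time,
                            columns_dict, time_offset_dict, random_seed=None):
            # No training examples:
            return ForwardSamplerOutput(
                entity_pkey=pd.Series([], dtype=object),
                anchor_time=pd.Series([], dtype='datetime64[ns]'),
                target=pd.Series([], dtype=object),
            )

Instantiation still requires a `Graph`, since `Sampler.__init__()` derives edge types and key/time-column metadata from it. The `min_date_offset()`/`max_date_offset()` helpers exist because `pd.DateOffset` objects of mixed units cannot be ordered directly; each offset is applied to a fixed anchor timestamp and the resulting timestamps are compared instead. For example:

    import pandas as pd

    from kumoai.experimental.rfm.base.sampler import (
        max_date_offset,
        min_date_offset,
    )

    # From 2000-01-01, 45 days reaches 2000-02-15 while 1 month only
    # reaches 2000-02-01, so 45 days is the larger offset here:
    max_date_offset(pd.DateOffset(days=45), pd.DateOffset(months=1))
    # -> <DateOffset: days=45>

    min_date_offset(None, pd.DateOffset(days=7))
    # -> None (a `None` start offset marks an unbounded time range)
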
--- /dev/null
+++ b/kumoai/experimental/rfm/base/source.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+
+from kumoapi.typing import Dtype
+
+
+@dataclass
+class SourceColumn:
+    name: str
+    dtype: Dtype
+    is_primary_key: bool
+    is_unique_key: bool
+
+
+@dataclass
+class SourceForeignKey:
+    name: str
+    dst_table: str
+    primary_key: str
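
`SourceColumn` and `SourceForeignKey` give every backend a uniform record type for reporting the schema it discovers when introspecting a source table. A hypothetical usage sketch; the `orders`/`users` schema is invented, and the exact `Dtype` member name should be checked against `kumoapi.typing`:

    from kumoapi.typing import Dtype

    from kumoai.experimental.rfm.base.source import (
        SourceColumn,
        SourceForeignKey,
    )

    # Invented metadata for an "orders" table, as a backend might report
    # it after schema introspection:
    order_id = SourceColumn(
        name='order_id',
        dtype=Dtype.int,  # assumed integer member of `Dtype`
        is_primary_key=True,
        is_unique_key=True,
    )
    user_fkey = SourceForeignKey(
        name='user_id',         # foreign key column in "orders"
        dst_table='users',      # referenced table
        primary_key='user_id',  # primary key of the referenced table
    )
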