kumoai 2.13.0.dev202511181731__cp311-cp311-macosx_11_0_arm64.whl → 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (35)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/connector/utils.py +23 -2
  4. kumoai/experimental/rfm/__init__.py +20 -45
  5. kumoai/experimental/rfm/backend/__init__.py +0 -0
  6. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  8. kumoai/experimental/rfm/backend/local/sampler.py +242 -0
  9. kumoai/experimental/rfm/backend/local/table.py +109 -0
  10. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  11. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  12. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  14. kumoai/experimental/rfm/base/__init__.py +14 -0
  15. kumoai/experimental/rfm/base/column.py +66 -0
  16. kumoai/experimental/rfm/base/sampler.py +374 -0
  17. kumoai/experimental/rfm/base/source.py +18 -0
  18. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  19. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  20. kumoai/experimental/rfm/infer/__init__.py +6 -0
  21. kumoai/experimental/rfm/infer/dtype.py +79 -0
  22. kumoai/experimental/rfm/infer/pkey.py +126 -0
  23. kumoai/experimental/rfm/infer/time_col.py +62 -0
  24. kumoai/experimental/rfm/local_graph_sampler.py +43 -4
  25. kumoai/experimental/rfm/local_pquery_driver.py +1 -1
  26. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  27. kumoai/experimental/rfm/rfm.py +17 -19
  28. kumoai/experimental/rfm/sagemaker.py +11 -3
  29. kumoai/testing/decorators.py +1 -1
  30. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/METADATA +9 -8
  31. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/RECORD +34 -20
  32. kumoai/experimental/rfm/utils.py +0 -344
  33. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/WHEEL +0 -0
  34. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/licenses/LICENSE +0 -0
  35. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/top_level.txt +0 -0
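Among the changes, the local in-memory RFM implementation moves under kumoai/experimental/rfm/backend/, new snow (Snowflake) and sqlite backends appear, and local_graph.py is promoted to graph.py. The two largest additions, base/sampler.py (+374) and base/source.py (+18), are shown in full below. As a hedged sketch of what the graph.py rename implies for imports — LocalGraph is the previous public entry point, and the new Graph name matches the TYPE_CHECKING import in base/sampler.py below; whether LocalGraph remains as an alias is not verified here:

# Before (2.13.0.dev202511181731):
from kumoai.experimental.rfm import LocalGraph
# After (2.13.0.dev202512091732), assuming LocalGraph was renamed alongside
# its module (local_graph.py -> graph.py):
from kumoai.experimental.rfm import Graph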
kumoai/experimental/rfm/base/sampler.py (new file)
@@ -0,0 +1,374 @@
+ import copy
+ import re
+ from abc import ABC, abstractmethod
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Literal
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.pquery import ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import Aggregation, ASTNode
+ from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
+ from kumoapi.typing import Stype
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ @dataclass
+ class SamplerOutput:
+     df_dict: dict[str, pd.DataFrame]
+     inverse_dict: dict[str, np.ndarray]
+     batch_dict: dict[str, np.ndarray]
+     num_sampled_nodes_dict: dict[str, list[int]]
+     row_dict: dict[tuple[str, str, str], np.ndarray]
+     col_dict: dict[tuple[str, str, str], np.ndarray]
+     num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+
+
+ @dataclass
+ class TargetOutput:
+     entity_pkey: pd.Series
+     anchor_time: pd.Series
+     target: pd.Series
+     num_trials: int
+
+
+ class Sampler(ABC):
+     def __init__(self, graph: 'Graph') -> None:
+         self._edge_types: list[tuple[str, str, str]] = []
+         for edge in graph.edges:
+             edge_type = (edge.src_table, edge.fkey, edge.dst_table)
+             self._edge_types.append(edge_type)
+             self._edge_types.append(Subgraph.rev_edge_type(edge_type))
+
+         self._primary_key_dict: dict[str, str] = {
+             table.name: table._primary_key
+             for table in graph.tables.values()
+             if table._primary_key is not None
+         }
+
+         self._time_column_dict: dict[str, str] = {
+             table.name: table._time_column
+             for table in graph.tables.values()
+             if table._time_column is not None
+         }
+
+         self._end_time_column_dict: dict[str, str] = {
+             table.name: table._end_time_column
+             for table in graph.tables.values()
+             if table._end_time_column is not None
+         }
+
+         foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+         self._table_stype_dict: dict[str, dict[str, Stype]] = {}
+         for table in graph.tables.values():
+             self._table_stype_dict[table.name] = {}
+             for column in table.columns:
+                 if column == table.primary_key:
+                     continue
+                 if (table.name, column.name) in foreign_keys:
+                     continue
+                 self._table_stype_dict[table.name][column.name] = column.stype
+
+     @property
+     def edge_types(self) -> list[tuple[str, str, str]]:
+         return self._edge_types
+
+     @property
+     def primary_key_dict(self) -> dict[str, str]:
+         return self._primary_key_dict
+
+     @property
+     def time_column_dict(self) -> dict[str, str]:
+         return self._time_column_dict
+
+     @property
+     def end_time_column_dict(self) -> dict[str, str]:
+         return self._end_time_column_dict
+
+     @property
+     def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+         return self._table_stype_dict
+
+     def sample_subgraph(
+         self,
+         entity_table_names: tuple[str, ...],
+         entity_pkey: pd.Series,
+         anchor_time: pd.Series,
+         num_neighbors: list[int],
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+     ) -> Subgraph:
+
+         # Exclude all columns that leak target information:
+         table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
+         if exclude_cols_dict is not None:
+             table_stype_dict = copy.deepcopy(table_stype_dict)
+             for table_name, exclude_cols in exclude_cols_dict.items():
+                 for column_name in exclude_cols:
+                     del table_stype_dict[table_name][column_name]
+
+         # Collect all columns being used as features:
+         columns_dict: dict[str, set[str]] = {
+             table_name: set(stype_dict.keys())
+             for table_name, stype_dict in table_stype_dict.items()
+         }
+         # Make sure to store primary key information for entity tables:
+         for table_name in entity_table_names:
+             columns_dict[table_name].add(self.primary_key_dict[table_name])
+
+         if anchor_time.dtype != 'datetime64[ns]':
+             anchor_time = anchor_time.astype('datetime64[ns]')
+
+         out = self._sample_subgraph(
+             entity_table_name=entity_table_names[0],
+             entity_pkey=entity_pkey,
+             anchor_time=anchor_time,
+             columns_dict=columns_dict,
+             num_neighbors=num_neighbors,
+         )
+
+         subgraph = Subgraph(
+             anchor_time=anchor_time.astype(int).to_numpy(),
+             table_dict={},
+             link_dict={},
+         )
+
+         for table_name, batch in out.batch_dict.items():
+             if len(batch) == 0:
+                 continue
+
+             primary_key: str | None = None
+             if table_name in entity_table_names:
+                 primary_key = self.primary_key_dict[table_name]
+
+             df = out.df_dict[table_name].reset_index(drop=True)
+             if end_time_column := self.end_time_column_dict.get(table_name):
+                 # Set end time to NaT for all values greater than anchor time:
+                 assert table_name not in out.inverse_dict
+                 ser = df[end_time_column]
+                 if ser.dtype != 'datetime64[ns]':
+                     ser = ser.astype('datetime64[ns]')
+                 mask = ser > anchor_time.iloc[batch]
+                 ser.iloc[mask] = pd.NaT
+                 df[end_time_column] = ser
+
+             stype_dict = table_stype_dict[table_name]
+             for column_name, stype in stype_dict.items():
+                 if stype == Stype.text:
+                     df[column_name] = _normalize_text(df[column_name])
+
+             subgraph.table_dict[table_name] = Table(
+                 df=df,
+                 row=out.inverse_dict.get(table_name),
+                 batch=batch,
+                 num_sampled_nodes=out.num_sampled_nodes_dict[table_name],
+                 stype_dict=stype_dict,
+                 primary_key=primary_key,
+             )
+
+         for edge_type in out.row_dict.keys():
+             row: np.ndarray | None = out.row_dict[edge_type]
+             col: np.ndarray | None = out.col_dict[edge_type]
+
+             if row is None or col is None or len(row) == 0:
+                 continue
+
+             # Do not store reverse edge type if it is an exact replica:
+             rev_edge_type = Subgraph.rev_edge_type(edge_type)
+             if (rev_edge_type in subgraph.link_dict
+                     and np.array_equal(row, out.col_dict[rev_edge_type])
+                     and np.array_equal(col, out.row_dict[rev_edge_type])):
+                 subgraph.link_dict[edge_type] = Link(
+                     layout=EdgeLayout.REV,
+                     row=None,
+                     col=None,
+                     num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+                 )
+                 continue
+
+             # Do not store non-informative edges:
+             layout = EdgeLayout.COO
+             if np.array_equal(row, np.arange(len(row))):
+                 row = None
+             if np.array_equal(col, np.arange(len(col))):
+                 col = None
+
+             # Store in compressed representation if more efficient:
+             num_cols = subgraph.table_dict[edge_type[2]].num_rows
+             if col is not None and len(col) > num_cols + 1:
+                 layout = EdgeLayout.CSC
+                 colcount = np.bincount(col, minlength=num_cols)
+                 col = np.empty(num_cols + 1, dtype=col.dtype)
+                 col[0] = 0
+                 np.cumsum(colcount, out=col[1:])
+
+             subgraph.link_dict[edge_type] = Link(
+                 layout=layout,
+                 row=row,
+                 col=col,
+                 num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+             )
+
+         return subgraph
+
+     def sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         num_examples: int,
+         anchor_time: pd.Timestamp | Literal['entity'],
+         random_seed: int | None = None,
+     ) -> TargetOutput:
+
+         columns_dict: dict[str, set[str]] = defaultdict(set)
+         for fqn in query.all_query_columns + [query.entity_column]:
+             table_name, column_name = fqn.split('.')
+             columns_dict[table_name].add(column_name)
+
+         if column_name := self.time_column_dict.get(query.entity_table):
+             columns_dict[table_name].add(column_name)
+         if column_name := self.end_time_column_dict.get(query.entity_table):
+             columns_dict[table_name].add(column_name)
+
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ] = {}
+
+         def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
+             if isinstance(node, Aggregation):
+                 table_name = node._get_target_column_name().split('.')[0]
+                 columns_dict[table_name].add(self.time_column_dict[table_name])
+
+                 edge_types = [
+                     edge_type for edge_type in self.edge_types
+                     if edge_type[0] == table_name
+                     and edge_type[2] == query.entity_table
+                 ]
+                 if len(edge_types) != 1:
+                     raise ValueError(f"Could not find a unique foreign key "
+                                      f"from table '{table_name}' to "
+                                      f"'{query.entity_table}'")
+                 if edge_types[0] not in time_offset_dict:
+                     start = node.aggr_time_range.start_date_offset
+                     end = node.aggr_time_range.end_date_offset * num_forecasts
+                 else:
+                     start, end = time_offset_dict[edge_types[0]]
+                     start = min_date_offset(
+                         start,
+                         node.aggr_time_range.start_date_offset,
+                     )
+                     end = max_date_offset(
+                         end,
+                         node.aggr_time_range.end_date_offset * num_forecasts,
+                     )
+                 time_offset_dict[edge_types[0]] = (start, end)
+
+             for child in node.children:
+                 _add_time_offset(child, num_forecasts)
+
+         _add_time_offset(query.target_ast, query.num_forecasts)
+         _add_time_offset(query.entity_ast)
+         if query.whatif_ast is not None:
+             _add_time_offset(query.whatif_ast)
+
+         return self._sample_target(
+             query=query,
+             num_examples=num_examples,
+             anchor_time=anchor_time,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+             random_seed=random_seed,
+         )
+
+     # Abstract Methods #######################################################
+
+     @abstractmethod
+     def _sample_subgraph(
+         self,
+         entity_table_name: str,
+         entity_pkey: pd.Series,
+         anchor_time: pd.Series,
+         columns_dict: dict[str, set[str]],
+         num_neighbors: list[int],
+     ) -> SamplerOutput:
+         pass
+
+     @abstractmethod
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         num_examples: int,
+         anchor_time: pd.Timestamp | Literal['entity'],
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         random_seed: int | None = None,
+     ) -> TargetOutput:
+         pass
+
+
+ # Helper Functions ############################################################
+
+ PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+ MULTISPACE = re.compile(r"\s+")
+
+
+ def _normalize_text(
+     ser: pd.Series,
+     max_words: int | None = 50,
+ ) -> pd.Series:
+     r"""Normalizes text into a list of lower-case words.
+
+     Args:
+         ser: The :class:`pandas.Series` to normalize.
+         max_words: The maximum number of words to return.
+             This will auto-shrink any large text column to avoid blowing up
+             context size.
+     """
+     if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+         return ser
+
+     def normalize_fn(line: str) -> list[str]:
+         line = PUNCTUATION.sub(" ", line)
+         line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+         line = MULTISPACE.sub(" ", line)
+         words = line.split()
+         if max_words is not None:
+             words = words[:max_words]
+         return words
+
+     ser = ser.fillna('').astype(str)
+
+     if max_words is not None:
+         # We estimate the number of words as 5 characters + 1 space in an
+         # English text on average. We need this pre-filter here, as word
+         # splitting on a giant text can be very expensive:
+         ser = ser.str[:6 * max_words]
+
+     ser = ser.str.lower()
+     ser = ser.map(normalize_fn)
+
+     return ser
+
+
+ def min_date_offset(*args: pd.DateOffset | None) -> pd.DateOffset | None:
+     if any(arg is None for arg in args):
+         return None
+
+     anchor = pd.Timestamp('2000-01-01')
+     timestamps = [anchor + arg for arg in args]
+     assert len(timestamps) > 0
+     argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+     return args[argmin]
+
+
+ def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
+     anchor = pd.Timestamp('2000-01-01')
+     timestamps = [anchor + arg for arg in args]
+     assert len(timestamps) > 0
+     argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+     return args[argmax]
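The DateOffset helpers above work around the fact that pd.DateOffset objects are not directly orderable: each offset is applied to a fixed anchor date, and the resulting timestamps are compared instead. A minimal self-contained sketch of the same trick, restated here so it runs without kumoai installed (the example offsets are illustrative only):

import pandas as pd

def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
    # Apply every offset to a fixed anchor date and pick the offset that
    # lands the furthest in the future:
    anchor = pd.Timestamp('2000-01-01')
    timestamps = [anchor + arg for arg in args]
    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
    return args[argmax]

# 45 days after 2000-01-01 (2000-02-15) is later than 1 month (2000-02-01):
print(max_date_offset(pd.DateOffset(days=45), pd.DateOffset(months=1)))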
kumoai/experimental/rfm/base/source.py (new file)
@@ -0,0 +1,18 @@
+ from dataclasses import dataclass
+
+ from kumoapi.typing import Dtype
+
+
+ @dataclass
+ class SourceColumn:
+     name: str
+     dtype: Dtype
+     is_primary_key: bool
+     is_unique_key: bool
+
+
+ @dataclass
+ class SourceForeignKey:
+     name: str
+     dst_table: str
+     primary_key: str
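As a hedged usage sketch of these dataclasses: a backend introspecting a 'users' source table might describe its primary key column and a referencing foreign key as below. All values are illustrative, and the exact Dtype member name is an assumption (consult kumoapi.typing.Dtype for the real enum values):

from kumoapi.typing import Dtype

# Hypothetical schema description; Dtype.int64 is an assumed member name:
pkey_column = SourceColumn(
    name='user_id',
    dtype=Dtype.int64,
    is_primary_key=True,
    is_unique_key=True,
)
orders_to_users = SourceForeignKey(
    name='user_id',         # column in the referencing table
    dst_table='users',      # table being referenced
    primary_key='user_id',  # primary key of the referenced table
)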