kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +115 -107
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/pquery/training_table.py +16 -2
  31. kumoai/testing/snow.py +3 -3
  32. kumoai/trainer/distilled_trainer.py +175 -0
  33. kumoai/utils/display.py +87 -0
  34. kumoai/utils/progress_logger.py +15 -2
  35. kumoai/utils/sql.py +2 -2
  36. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
  37. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +40 -35
  38. kumoai/experimental/rfm/base/column_expression.py +0 -50
  39. kumoai/experimental/rfm/base/sql_table.py +0 -229
  40. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  41. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  42. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/sql_sampler.py
@@ -1,15 +1,26 @@
 from abc import abstractmethod
+from collections import defaultdict
 from typing import TYPE_CHECKING, Literal

 import numpy as np
 import pandas as pd
+from kumoapi.rfm.context import Subgraph
+from kumoapi.typing import Dtype

-from kumoai.experimental.rfm.base import Sampler, SamplerOutput, SQLTable
-from kumoai.utils import ProgressLogger
+from kumoai.experimental.rfm.base import (
+    LocalExpression,
+    Sampler,
+    SamplerOutput,
+    SourceColumn,
+)
+from kumoai.experimental.rfm.base.mapper import Mapper
+from kumoai.utils import ProgressLogger, quote_ident

 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph

+EdgeType = tuple[str, str, str]
+

 class SQLSampler(Sampler):
     def __init__(
@@ -19,18 +30,71 @@ class SQLSampler(Sampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)

-        self._fqn_dict: dict[str, str] = {}
+        self._source_name_dict: dict[str, str] = {
+            table.name: table._quoted_source_name
+            for table in graph.tables.values()
+        }
+
+        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+        for table in graph.tables.values():
+            self._source_table_dict[table.name] = {}
+            for column in table.columns:
+                if not column.is_source:
+                    continue
+                src_column = table._source_column_dict[column.name]
+                self._source_table_dict[table.name][column.name] = src_column
+
+        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
         for table in graph.tables.values():
-            assert isinstance(table, SQLTable)
-            self._connection = table._connection
-            self._fqn_dict[table.name] = table.fqn
+            self._table_dtype_dict[table.name] = {}
+            for column in table.columns:
+                self._table_dtype_dict[table.name][column.name] = column.dtype
+
+        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
+        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
+        for table in graph.tables.values():
+            column_ref_dict: dict[str, str] = {}
+            column_proj_dict: dict[str, str] = {}
+            for column in table.columns:
+                if column.expr is not None:
+                    assert isinstance(column.expr, LocalExpression)
+                    column_ref_dict[column.name] = column.expr.value
+                    column_proj_dict[column.name] = (
+                        f'{column.expr} AS {quote_ident(column.name)}')
+                else:
+                    column_ref_dict[column.name] = quote_ident(column.name)
+                    column_proj_dict[column.name] = quote_ident(column.name)
+            self._table_column_ref_dict[table.name] = column_ref_dict
+            self._table_column_proj_dict[table.name] = column_proj_dict
+
+    @property
+    def source_name_dict(self) -> dict[str, str]:
+        r"""The source table names for all tables in the graph."""
+        return self._source_name_dict

     @property
-    def fqn_dict(self) -> dict[str, str]:
-        r"""The fully-qualified quoted source name for all table names in the
+    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+        r"""The source column information for all tables in the graph."""
+        return self._source_table_dict
+
+    @property
+    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
+        r"""The data types for all columns in all tables in the graph."""
+        return self._table_dtype_dict
+
+    @property
+    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
+        r"""The SQL reference expression for all columns in all tables in the
         graph.
         """
-        return self._fqn_dict
+        return self._table_column_ref_dict
+
+    @property
+    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
+        r"""The SQL projection expressions for all columns in all tables in the
+        graph.
+        """
+        return self._table_column_proj_dict

     def _sample_subgraph(
         self,
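
For a plain source column, both lookups hold the quoted identifier; for an expression-backed column, the reference form keeps the raw SQL (usable inside WHERE and JOIN clauses) while the projection form aliases the expression back to the column name (usable in SELECT lists). A hypothetical illustration (the orders table, its columns, and the price * quantity expression are invented for this sketch, not taken from the package):

    # Assuming `sampler` is a SQLSampler built over such a graph:
    sampler.table_column_ref_dict['orders']['total']    # 'price * quantity'
    sampler.table_column_proj_dict['orders']['total']   # 'price * quantity AS "total"'
    sampler.table_column_ref_dict['orders']['user_id']  # '"user_id"'
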
@@ -41,10 +105,28 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:

+        # Make sure to always include primary key, foreign key and time columns
+        # during data fetching since these are needed for graph traversal:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            sample_columns_dict[table] = sample_columns
+        if not isinstance(anchor_time, pd.Series):
+            sample_columns_dict[entity_table_name] |= {
+                self.time_column_dict[entity_table_name]
+            }
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-            pkey=entity_pkey,
-            columns=columns_dict[entity_table_name],
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
             mask = np.ones(len(entity_pkey), dtype=bool)
@@ -53,23 +135,230 @@ class SQLSampler(Sampler):
                 f"{entity_pkey.iloc[mask].tolist()} do not exist "
                 f"in the '{entity_table_name}' table")

+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)

+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        # Recursive Neighbor Sampling #########################################
+
+        mapper_dict: dict[str, Mapper] = defaultdict(
+            lambda: Mapper(num_examples=len(entity_pkey)))
+        mapper_dict[entity_table_name].add(
+            pkey=df[self.primary_key_dict[entity_table_name]],
+            batch=batch,
+        )
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        dfs_dict[entity_table_name].append(df)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        batches_dict[entity_table_name].append(batch)
+        num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
+            lambda: [0] * (len(num_neighbors) + 1))
+        num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
+
+        rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
+            lambda: [0] * len(num_neighbors))
+
+        # The start index of data frame slices of the previous hop:
+        offset_dict: dict[str, int] = defaultdict(int)
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            for table in list(num_sampled_nodes_dict.keys()):
+                # Only sample from tables that have been visited in the
+                # previous hop:
+                if num_sampled_nodes_dict[table][hop] == 0:
+                    continue
+
+                # Collect the slices of data sampled in the previous hop
+                # (but maintain only required key information):
+                cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
+                if table in self.primary_key_dict:
+                    cols.append(self.primary_key_dict[table])
+                dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
+                df = pd.concat(
+                    dfs,
+                    axis=0,
+                    ignore_index=True,
+                ) if len(dfs) > 1 else dfs[0]
+                batches = batches_dict[table][offset_dict[table]:]
+                batch = (np.concatenate(batches)
+                         if len(batches) > 1 else batches[0])
+                offset_dict[table] = len(batches_dict[table])  # Increase.
+
+                pkey: pd.Series | None = None
+                index: np.ndarray | None = None
+                if table in self.primary_key_dict:
+                    pkey = df[self.primary_key_dict[table]]
+                    index = mapper_dict[table].get(pkey, batch)
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    row = mapper_dict[dst_table].get(df[fkey], batch)
+                    mask = row == -1
+                    if mask.any():
+                        key_df = pd.DataFrame({
+                            'fkey': df[fkey],
+                            'batch': batch,
+                        }).iloc[mask]
+                        # Only maintain unique keys per example:
+                        unique_key_df = key_df.drop_duplicates()
+                        # Fully de-duplicate keys across examples:
+                        code, fkey_index = pd.factorize(unique_key_df['fkey'])
+
+                        _df, _batch = self._by_pkey(
+                            table_name=dst_table,
+                            index=fkey_index,
+                            columns=sample_columns_dict[dst_table],
+                        )  # Ensure result is sorted according to input order:
+                        _df = _df.iloc[_batch.argsort()]
+
+                        # Compute valid entries (without dangling foreign keys)
+                        # in `unique_key_df`:
+                        _mask = np.full(len(fkey_index), fill_value=False)
+                        _mask[_batch] = True
+                        _mask = _mask[code]
+
+                        # Reconstruct unique (key, batch) pairs:
+                        code, _ = pd.factorize(unique_key_df['fkey'][_mask])
+                        _df = _df.iloc[code].reset_index(drop=True)
+                        _batch = unique_key_df['batch'].to_numpy()[_mask]
+
+                        # Register node IDs:
+                        mapper_dict[dst_table].add(
+                            pkey=_df[self.primary_key_dict[dst_table]],
+                            batch=_batch,
+                        )
+                        row[mask] = mapper_dict[dst_table].get(
+                            pkey=key_df['fkey'],
+                            batch=key_df['batch'].to_numpy(),
+                        )  # NOTE `row` may still hold `-1` for dangling fkeys.
+
+                        dfs_dict[dst_table].append(_df)
+                        batches_dict[dst_table].append(_batch)
+                        num_sampled_nodes_dict[dst_table][hop + 1] += (  #
+                            len(_batch))
+
+                    mask = row != -1
+
+                    col = index
+                    if col is None:
+                        start = sum(num_sampled_nodes_dict[table][:hop])
+                        end = sum(num_sampled_nodes_dict[table][:hop + 1])
+                        col = np.arange(start, end)
+
+                    row = row[mask]
+                    col = col[mask]
+
+                    edge_type = (table, fkey, dst_table)
+                    edge_type = Subgraph.rev_edge_type(edge_type)
+                    rows_dict[edge_type].append(row)
+                    cols_dict[edge_type].append(col)
+                    num_sampled_edges_dict[edge_type][hop] = len(col)
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    assert pkey is not None and index is not None
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=pkey,
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )

+                    edge_type = (src_table, fkey, table)
+                    cols_dict[edge_type].append(index[_batch])
+                    num_sampled_edges_dict[edge_type][hop] = len(_batch)
+
+                    _batch = batch[_batch]
+                    num_nodes = sum(num_sampled_nodes_dict[src_table])
+                    if src_table in self.primary_key_dict:
+                        _pkey = _df[self.primary_key_dict[src_table]]
+                        mapper_dict[src_table].add(_pkey, _batch)
+                        row = mapper_dict[src_table].get(_pkey, _batch)
+
+                        # Only preserve unknown rows:
+                        mask = row >= num_nodes  # type: ignore
+                        mask[pd.Index(row).duplicated()] = False
+                        _df = _df.iloc[mask]
+                        _batch = _batch[mask]
+                    else:
+                        row = np.arange(num_nodes, num_nodes + len(_batch))
+
+                    rows_dict[edge_type].append(row)
+                    num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
+
+                    dfs_dict[src_table].append(_df)
+                    batches_dict[src_table].append(_batch)
+
+        # Post-Processing #####################################################
+
+        df_dict = {
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+
+        # Only store unique rows in `df` above a certain threshold:
+        inverse_dict: dict[str, np.ndarray] = {}
+        for table, df in df_dict.items():
+            if table not in self.primary_key_dict:
+                continue
+            unique, index, inverse = np.unique(
+                df_dict[table][self.primary_key_dict[table]],
+                return_index=True,
+                return_inverse=True,
+            )
+            if len(df) > 1.05 * len(unique):
+                df_dict[table] = df.iloc[index].reset_index(drop=True)
+                inverse_dict[table] = inverse
+
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {
+            table: np.concatenate(batches) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }
+        row_dict = {
+            edge_type: np.concatenate(rows)
+            for edge_type, rows in rows_dict.items()
+        }
+        col_dict = {
+            edge_type: np.concatenate(cols)
+            for edge_type, cols in cols_dict.items()
+        }
+
+        if len(num_sampled_edges_dict) == 0:  # Single table:
+            num_sampled_nodes_dict = {
+                key: value[:1]
+                for key, value in num_sampled_nodes_dict.items()
+            }

         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict={entity_table_name: df},
-            inverse_dict={},
-            batch_dict={entity_table_name: batch},
-            num_sampled_nodes_dict={entity_table_name: [len(batch)]},
-            row_dict={},
-            col_dict={},
-            num_sampled_edges_dict={},
+            df_dict=df_dict,
+            inverse_dict=inverse_dict,
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
         )

     # Abstract Methods ########################################################
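
The sampling loop above relies on the new Mapper helper (kumoai/experimental/rfm/base/mapper.py, a file added in this release), whose implementation is not shown in this diff. Judging from its call sites, it assigns consecutive local node IDs to (primary key, example index) pairs and returns -1 for unseen pairs. A minimal sketch under those assumptions:

    import numpy as np
    import pandas as pd

    class Mapper:
        # Sketch only; the shipped implementation may differ.
        def __init__(self, num_examples: int) -> None:
            self.num_examples = num_examples
            self._id_dict: dict[tuple, int] = {}

        def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
            # Register unseen (key, example) pairs under the next free node ID:
            for key, example in zip(pkey.tolist(), batch.tolist()):
                self._id_dict.setdefault((key, example), len(self._id_dict))

        def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
            # Unknown pairs map to -1, matching the dangling-key checks above:
            return np.array([
                self._id_dict.get((key, example), -1)
                for key, example in zip(pkey.tolist(), batch.tolist())
            ], dtype=np.int64)
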
@@ -78,7 +367,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-        pkey: pd.Series,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass
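
Concrete backends (see the snow and sqlite sampler files in the list above) implement these two hooks. A minimal sketch of _by_pkey for a SQLite-style backend, assuming the subclass keeps a DB-API connection in self._connection; the class name and SQL shape are illustrative, not the package's actual implementation:

    import numpy as np
    import pandas as pd

    from kumoai.experimental.rfm.base.sql_sampler import SQLSampler

    class ExampleSQLiteSampler(SQLSampler):  # Hypothetical subclass.
        def _by_pkey(
            self,
            table_name: str,
            index: pd.Series,
            columns: set[str],
        ) -> tuple[pd.DataFrame, np.ndarray]:
            pkey_ref = self.table_column_ref_dict[table_name][
                self.primary_key_dict[table_name]]
            projs = ', '.join(self.table_column_proj_dict[table_name][col]
                              for col in columns)
            placeholders = ', '.join('?' for _ in range(len(index)))
            sql = (f'SELECT {pkey_ref} AS __pos_key__, {projs} '
                   f'FROM {self.source_name_dict[table_name]} '
                   f'WHERE {pkey_ref} IN ({placeholders})')
            df = pd.read_sql(sql, self._connection, params=list(index))
            # Map each fetched row back to its position in the input `index`:
            pos = pd.Series(np.arange(len(index)), index=list(index))
            batch = pos[df['__pos_key__']].to_numpy()
            return df.drop(columns='__pos_key__'), batch

The returned batch array holds, for each fetched row, its position in the input index, which is what the missing-entity check and the argsort-based reordering in _sample_subgraph expect.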