kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +177 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +23 -3
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/sampler.py +782 -0
  23. kumoai/experimental/rfm/base/source.py +2 -1
  24. kumoai/experimental/rfm/base/sql_sampler.py +247 -0
  25. kumoai/experimental/rfm/base/table.py +404 -203
  26. kumoai/experimental/rfm/graph.py +374 -172
  27. kumoai/experimental/rfm/infer/__init__.py +6 -4
  28. kumoai/experimental/rfm/infer/dtype.py +7 -4
  29. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  30. kumoai/experimental/rfm/infer/pkey.py +4 -2
  31. kumoai/experimental/rfm/infer/stype.py +35 -0
  32. kumoai/experimental/rfm/infer/time_col.py +1 -2
  33. kumoai/experimental/rfm/pquery/executor.py +27 -27
  34. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  35. kumoai/experimental/rfm/relbench.py +76 -0
  36. kumoai/experimental/rfm/rfm.py +762 -467
  37. kumoai/experimental/rfm/sagemaker.py +4 -4
  38. kumoai/experimental/rfm/task_table.py +292 -0
  39. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  40. kumoai/pquery/predictive_query.py +10 -6
  41. kumoai/pquery/training_table.py +16 -2
  42. kumoai/testing/snow.py +50 -0
  43. kumoai/trainer/distilled_trainer.py +175 -0
  44. kumoai/utils/__init__.py +3 -2
  45. kumoai/utils/display.py +87 -0
  46. kumoai/utils/progress_logger.py +190 -12
  47. kumoai/utils/sql.py +3 -0
  48. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
  49. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
  50. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  51. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  52. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
@@ -6,9 +6,10 @@ from kumoapi.typing import Dtype
6
6
@dataclass
class SourceColumn:
    r"""Metadata about a single column as reported by the source table.

    Attributes:
        name: Name of the column in the source table.
        dtype: Data type of the column, or ``None`` (presumably when it could
            not be determined from the source - confirm with callers).
        is_primary_key: Whether this column is the primary key.
        is_unique_key: Whether this column is declared as a unique key.
        is_nullable: Whether this column is allowed to hold null values.
    """
    name: str
    dtype: Dtype | None
    is_primary_key: bool
    is_unique_key: bool
    is_nullable: bool
12
13
 
13
14
 
14
15
  @dataclass
@@ -0,0 +1,247 @@
1
+ from abc import abstractmethod
2
+ from collections import defaultdict
3
+ from typing import TYPE_CHECKING, Literal
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from kumoapi.typing import Dtype
8
+
9
+ from kumoai.experimental.rfm.base import (
10
+ LocalExpression,
11
+ Sampler,
12
+ SamplerOutput,
13
+ SourceColumn,
14
+ )
15
+ from kumoai.utils import ProgressLogger, quote_ident
16
+
17
+ if TYPE_CHECKING:
18
+ from kumoai.experimental.rfm import Graph
19
+
20
+
21
class SQLSampler(Sampler):
    r"""Base class for samplers that read subgraphs from a SQL backend.

    The constructor precomputes per-table SQL metadata (quoted source names,
    source column information, column dtypes, and column reference/projection
    expressions). :meth:`_sample_subgraph` drives hop-by-hop neighbor
    sampling. Subclasses must implement the actual queries in
    :meth:`_by_pkey` and :meth:`_by_fkey`.

    Args:
        graph: The graph whose tables are sampled.
        verbose: Progress-logging flag or a custom :class:`ProgressLogger`,
            forwarded to :class:`Sampler`.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph, verbose=verbose)

        self._source_name_dict: dict[str, str] = {
            table.name: table._quoted_source_name
            for table in graph.tables.values()
        }

        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
        for table in graph.tables.values():
            self._source_table_dict[table.name] = {}
            for column in table.columns:
                if not column.is_source:
                    continue  # Only columns flagged `is_source` are recorded.
                src_column = table._source_column_dict[column.name]
                self._source_table_dict[table.name][column.name] = src_column

        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
        for table in graph.tables.values():
            self._table_dtype_dict[table.name] = {}
            for column in table.columns:
                self._table_dtype_dict[table.name][column.name] = column.dtype

        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
        for table in graph.tables.values():
            column_ref_dict: dict[str, str] = {}
            column_proj_dict: dict[str, str] = {}
            for column in table.columns:
                if column.expr is not None:
                    # Expression columns are referenced by their expression
                    # and projected under the column's own (quoted) name.
                    assert isinstance(column.expr, LocalExpression)
                    column_ref_dict[column.name] = column.expr.value
                    # NOTE(review): this uses `str(column.expr)` while the
                    # reference above uses `column.expr.value`; confirm that
                    # `LocalExpression.__str__` renders the raw SQL value.
                    column_proj_dict[column.name] = (
                        f'{column.expr} AS {quote_ident(column.name)}')
                else:
                    column_ref_dict[column.name] = quote_ident(column.name)
                    column_proj_dict[column.name] = quote_ident(column.name)
            self._table_column_ref_dict[table.name] = column_ref_dict
            self._table_column_proj_dict[table.name] = column_proj_dict

    @property
    def source_name_dict(self) -> dict[str, str]:
        r"""The source table names for all tables in the graph."""
        return self._source_name_dict

    @property
    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
        r"""The source column information for all tables in the graph."""
        return self._source_table_dict

    @property
    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
        r"""The data types for all columns in all tables in the graph."""
        return self._table_dtype_dict

    @property
    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
        r"""The SQL reference expression for all columns in all tables in the
        graph.
        """
        return self._table_column_ref_dict

    @property
    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
        r"""The SQL projection expressions for all columns in all tables in the
        graph.
        """
        return self._table_column_proj_dict

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        r"""Samples a subgraph rooted at the given entities.

        Args:
            entity_table_name: Name of the table holding the seed entities.
            entity_pkey: Primary key values of the seed entities.
            anchor_time: Anchor times aligned by position with
                ``entity_pkey``, or ``'entity'`` to use the values of the
                entity table's time column.
            columns_dict: The columns to return for every table.
            num_neighbors: Per-hop values forwarded to :meth:`_by_fkey` as
                ``num_neighbors``. Sampling stops at the first ``0``.

        Returns:
            The sampled subgraph. ``inverse_dict``, ``row_dict``,
            ``col_dict`` and ``num_sampled_edges_dict`` are not populated
            yet.

        Raises:
            KeyError: If some requested primary keys do not exist in the
                entity table.
            NotImplementedError: If a visited table has outgoing foreign
                keys. Following those is not supported yet.
        """
        # Make sure to include primary key, foreign key and time columns:
        sample_columns_dict: dict[str, set[str]] = {}
        for table, columns in columns_dict.items():
            sample_columns = columns | {
                foreign_key
                for foreign_key, _ in self.foreign_key_dict[table]
            }
            if primary_key := self.primary_key_dict.get(table):
                sample_columns |= {primary_key}
            if time_column := self.time_column_dict.get(table):
                sample_columns |= {time_column}
            sample_columns_dict[table] = sample_columns

        # Sample Entity Table #################################################

        df, batch = self._by_pkey(
            table_name=entity_table_name,
            index=entity_pkey,
            columns=sample_columns_dict[entity_table_name],
        )
        if len(batch) != len(entity_pkey):
            # `batch` holds positions into `entity_pkey` that were found;
            # everything else is missing:
            mask = np.ones(len(entity_pkey), dtype=bool)
            mask[batch] = False
            raise KeyError(f"The primary keys "
                           f"{entity_pkey.iloc[mask].tolist()} do not exist "
                           f"in the '{entity_table_name}' table")

        # Make sure that entities are returned in expected order:
        perm = batch.argsort()
        batch = batch[perm]
        df = df.iloc[perm].reset_index(drop=True)

        # Fill 'entity' anchor times with actual values:
        if not isinstance(anchor_time, pd.Series):
            time_column = self.time_column_dict[entity_table_name]
            anchor_time = df[time_column]
        assert isinstance(anchor_time, pd.Series)

        # Sampled rows and their seed positions, keyed by (table, hop):
        df_hop_dict: dict[tuple[str, int], pd.DataFrame] = {
            (entity_table_name, 0): df,
        }
        batch_hop_dict: dict[tuple[str, int], np.ndarray] = {
            (entity_table_name, 0): batch,
        }

        # Recursive Neighbor Sampling #########################################

        for hop, neighbors in enumerate(num_neighbors):
            if neighbors == 0:
                break  # Abort early.

            dfs: dict[str, list[pd.DataFrame]] = defaultdict(list)
            batches: dict[str, list[np.ndarray]] = defaultdict(list)

            tables = [table for table, i in batch_hop_dict if i == hop]
            for table in tables:
                df = df_hop_dict[(table, hop)]
                batch = batch_hop_dict[(table, hop)]

                # Iterate over foreign keys in the current table:
                for fkey, dst_table in self.foreign_key_dict[table]:
                    raise NotImplementedError

                # Iterate over foreign keys that reference the current table:
                for src_table, fkey in self.rev_foreign_key_dict[table]:
                    _df, _batch = self._by_fkey(
                        table_name=src_table,
                        foreign_key=fkey,
                        index=df[self.primary_key_dict[table]],
                        num_neighbors=neighbors,
                        anchor_time=anchor_time.iloc[batch],
                        columns=sample_columns_dict[src_table],
                    )
                    # `_batch` indexes rows of `df`; map it to seed positions:
                    _batch = batch[_batch]

                    # TODO Filter out duplicates if `src_table` has a pkey.
                    dfs[src_table].append(_df)
                    batches[src_table].append(_batch)

            # Register the nodes sampled in this hop so that the next hop
            # expands from them and post-processing returns them:
            for table, hop_dfs in dfs.items():
                hop_batches = batches[table]
                df_hop_dict[(table, hop + 1)] = (
                    pd.concat(hop_dfs, axis=0, ignore_index=True)
                    if len(hop_dfs) > 1 else hop_dfs[0])
                batch_hop_dict[(table, hop + 1)] = (
                    np.concatenate(hop_batches, axis=0)
                    if len(hop_batches) > 1 else hop_batches[0])

            # TODO Add edges to all sampled nodes.

        # Post-Processing #####################################################

        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
        num_hops = max(hop for _, hop in df_hop_dict.keys())  # TODO
        num_sampled_nodes_dict: dict[str, list[int]] = {
            table: [0] * (num_hops + 1)
            for table, _ in df_hop_dict.keys()
        }
        # `df_hop_dict` is filled hop by hop, so per-table lists stay ordered
        # by hop:
        for (table, hop), df in df_hop_dict.items():
            dfs_dict[table].append(df)
            batches_dict[table].append(batch_hop_dict[(table, hop)])
            num_sampled_nodes_dict[table][hop] = len(df)

        df_dict = {  # Concatenate data frames across hops:
            table_name:
            pd.concat(table_dfs, axis=0, ignore_index=True)
            if len(table_dfs) > 1 else table_dfs[0]
            for table_name, table_dfs in dfs_dict.items()
        }
        df_dict = {  # Post-filter column set (per table, by its own name):
            table_name: table_df[list(columns_dict[table_name])]
            for table_name, table_df in df_dict.items()
        }
        batch_dict = {  # Concatenate batch vector across hops:
            table_name:
            np.concatenate(table_batches, axis=0)
            if len(table_batches) > 1 else table_batches[0]
            for table_name, table_batches in batches_dict.items()
        }

        return SamplerOutput(
            anchor_time=anchor_time.astype(int).to_numpy(),
            df_dict=df_dict,
            inverse_dict={},  # TODO
            batch_dict=batch_dict,
            num_sampled_nodes_dict=num_sampled_nodes_dict,
            row_dict={},  # TODO
            col_dict={},  # TODO
            num_sampled_edges_dict={},  # TODO
        )

    # Abstract Methods ########################################################

    @abstractmethod
    def _by_pkey(
        self,
        table_name: str,
        index: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches the rows of ``table_name`` whose primary key is in
        ``index``.

        Returns:
            A tuple of the fetched rows (restricted to ``columns``) and an
            array with, for each returned row, its position in ``index``.
            :meth:`_sample_subgraph` relies on this positional mapping.
        """
        pass

    @abstractmethod
    def _by_fkey(
        self,
        table_name: str,
        foreign_key: str,
        index: pd.Series,
        num_neighbors: int,
        anchor_time: pd.Series | None,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches rows of ``table_name`` whose ``foreign_key`` matches a
        value in ``index``.

        ``anchor_time`` is aligned by position with ``index``. How it and
        ``num_neighbors`` restrict the result is up to the implementation.

        Returns:
            A tuple of the fetched rows (restricted to ``columns``) and an
            array with, for each returned row, the position in ``index`` it
            was matched by. :meth:`_sample_subgraph` relies on this
            positional mapping.
        """
        pass