kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
- kumoai/experimental/rfm/backend/snow/table.py +177 -50
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +782 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +247 -0
- kumoai/experimental/rfm/base/table.py +404 -203
- kumoai/experimental/rfm/graph.py +374 -172
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +762 -467
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from typing import TYPE_CHECKING, Literal
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from kumoapi.typing import Dtype
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm.base import (
|
|
10
|
+
LocalExpression,
|
|
11
|
+
Sampler,
|
|
12
|
+
SamplerOutput,
|
|
13
|
+
SourceColumn,
|
|
14
|
+
)
|
|
15
|
+
from kumoai.utils import ProgressLogger, quote_ident
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from kumoai.experimental.rfm import Graph
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SQLSampler(Sampler):
    r"""Abstract base class for samplers backed by a SQL data source.

    On construction, it precomputes per-table metadata from the graph: quoted
    source table names, source column information, column data types, and
    SQL reference/projection snippets for every column.
    :meth:`_sample_subgraph` drives hop-wise neighbor sampling on top of the
    backend-specific :meth:`_by_pkey` and :meth:`_by_fkey` lookups, which
    subclasses must implement.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph, verbose=verbose)

        # Quoted source table name of every table in the graph:
        self._source_name_dict: dict[str, str] = {
            table.name: table._quoted_source_name
            for table in graph.tables.values()
        }

        # Source column information of every column flagged as `is_source`;
        # all other columns are skipped:
        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
        for table in graph.tables.values():
            self._source_table_dict[table.name] = {}
            for column in table.columns:
                if not column.is_source:
                    continue
                src_column = table._source_column_dict[column.name]
                self._source_table_dict[table.name][column.name] = src_column

        # Data type of every column in every table:
        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
        for table in graph.tables.values():
            self._table_dtype_dict[table.name] = {}
            for column in table.columns:
                self._table_dtype_dict[table.name][column.name] = column.dtype

        # Two SQL snippets per column:
        # * "ref": how the column is referenced inside a query (the raw
        #   expression value for expression-backed columns, otherwise the
        #   quoted column name).
        # * "proj": how the column is selected (expression-backed columns are
        #   aliased back to their column name via `AS`).
        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
        for table in graph.tables.values():
            column_ref_dict: dict[str, str] = {}
            column_proj_dict: dict[str, str] = {}
            for column in table.columns:
                if column.expr is not None:
                    assert isinstance(column.expr, LocalExpression)
                    column_ref_dict[column.name] = column.expr.value
                    column_proj_dict[column.name] = (
                        f'{column.expr} AS {quote_ident(column.name)}')
                else:
                    column_ref_dict[column.name] = quote_ident(column.name)
                    column_proj_dict[column.name] = quote_ident(column.name)
            self._table_column_ref_dict[table.name] = column_ref_dict
            self._table_column_proj_dict[table.name] = column_proj_dict

    @property
    def source_name_dict(self) -> dict[str, str]:
        r"""The source table names for all tables in the graph."""
        return self._source_name_dict

    @property
    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
        r"""The source column information for all tables in the graph."""
        return self._source_table_dict

    @property
    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
        r"""The data types for all columns in all tables in the graph."""
        return self._table_dtype_dict

    @property
    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
        r"""The SQL reference expression for all columns in all tables in the
        graph.
        """
        return self._table_column_ref_dict

    @property
    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
        r"""The SQL projection expressions for all columns in all tables in the
        graph.
        """
        return self._table_column_proj_dict

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        r"""Samples a subgraph around the given seed entities.

        Args:
            entity_table_name: The name of the table holding the seed
                entities.
            entity_pkey: The primary keys of the seed entities.
            anchor_time: The anchor time per seed entity, or ``'entity'`` to
                read it from the time column of the entity table.
            columns_dict: The columns to return per table. Primary key,
                foreign key and time columns are additionally fetched for
                sampling and dropped again before returning unless requested.
            num_neighbors: The number of neighbors to sample per hop.
                Sampling stops at the first hop with ``0`` neighbors.

        Returns:
            A :class:`SamplerOutput` holding the per-table data frames, the
            per-table batch vectors (the seed index each row belongs to), the
            number of sampled rows per table and hop, and the anchor times as
            ``int64``. Edge information (``row_dict``, ``col_dict``, ...) is
            not populated yet.

        Raises:
            KeyError: If any value in ``entity_pkey`` does not exist in the
                entity table.
            NotImplementedError: If a visited table has outgoing foreign keys,
                since sampling along foreign keys is not supported yet.
        """
        # Make sure to include primary key, foreign key and time columns:
        sample_columns_dict: dict[str, set[str]] = {}
        for table, columns in columns_dict.items():
            sample_columns = columns | {
                foreign_key
                for foreign_key, _ in self.foreign_key_dict[table]
            }
            if primary_key := self.primary_key_dict.get(table):
                sample_columns |= {primary_key}
            if time_column := self.time_column_dict.get(table):
                sample_columns |= {time_column}
            sample_columns_dict[table] = sample_columns

        # Sample Entity Table #################################################

        df, batch = self._by_pkey(
            table_name=entity_table_name,
            index=entity_pkey,
            columns=sample_columns_dict[entity_table_name],
        )
        if len(batch) != len(entity_pkey):
            # `batch` holds positions into `entity_pkey`; every position that
            # was not returned corresponds to a missing primary key:
            mask = np.ones(len(entity_pkey), dtype=bool)
            mask[batch] = False
            raise KeyError(f"The primary keys "
                           f"{entity_pkey.iloc[mask].tolist()} do not exist "
                           f"in the '{entity_table_name}' table")

        # Make sure that entities are returned in expected order:
        perm = batch.argsort()
        batch = batch[perm]
        df = df.iloc[perm].reset_index(drop=True)

        # Fill 'entity' anchor times with actual values:
        if not isinstance(anchor_time, pd.Series):
            time_column = self.time_column_dict[entity_table_name]
            anchor_time = df[time_column]
        assert isinstance(anchor_time, pd.Series)

        # Sampled rows and their seed indices, keyed by `(table, hop)`:
        df_hop_dict: dict[tuple[str, int], pd.DataFrame] = {
            (entity_table_name, 0): df,
        }
        batch_hop_dict: dict[tuple[str, int], np.ndarray] = {
            (entity_table_name, 0): batch,
        }

        # Recursive Neighbor Sampling #########################################

        for hop, neighbors in enumerate(num_neighbors):
            if neighbors == 0:
                break  # Abort early.

            dfs: dict[str, list[pd.DataFrame]] = defaultdict(list)
            batches: dict[str, list[np.ndarray]] = defaultdict(list)

            tables = [table for table, i in batch_hop_dict if i == hop]
            for table in tables:
                df = df_hop_dict[(table, hop)]
                batch = batch_hop_dict[(table, hop)]

                # Iterate over foreign keys in the current table:
                for fkey, dst_table in self.foreign_key_dict[table]:
                    raise NotImplementedError(
                        f"Sampling along foreign key '{fkey}' of table "
                        f"'{table}' (referencing '{dst_table}') is not yet "
                        f"supported")

                # Iterate over foreign keys that reference the current table:
                for src_table, fkey in self.rev_foreign_key_dict[table]:
                    _df, _batch = self._by_fkey(
                        table_name=src_table,
                        foreign_key=fkey,
                        index=df[self.primary_key_dict[table]],
                        num_neighbors=neighbors,
                        anchor_time=anchor_time.iloc[batch],
                        columns=sample_columns_dict[src_table],
                    )
                    # Map positions in `index` back to seed indices:
                    _batch = batch[_batch]

                    # TODO Filter out duplicates if `src_table` has a pkey.
                    dfs[src_table].append(_df)
                    batches[src_table].append(_batch)

            # NOTE(review): the neighbors collected in `dfs`/`batches` are not
            # yet stored under `(table, hop + 1)` in `df_hop_dict` /
            # `batch_hop_dict`, so they are neither expanded further nor
            # returned. Confirm intended design once foreign-key sampling
            # above is implemented.

            # TODO Add edges to all sampled nodes.

        # Post-Processing #####################################################

        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
        num_hops = max(hop for _, hop in df_hop_dict)  # TODO
        # One counter slot per hop for every sampled table (key order follows
        # first appearance in `df_hop_dict`):
        num_sampled_nodes_dict: dict[str, list[int]] = {
            table: [0] * (num_hops + 1)
            for table, _ in df_hop_dict
        }
        for (table, hop), df in df_hop_dict.items():
            dfs_dict[table].append(df)
            batches_dict[table].append(batch_hop_dict[(table, hop)])
            num_sampled_nodes_dict[table][hop] = len(df)

        df_dict: dict[str, pd.DataFrame] = {}
        for table_name, table_dfs in dfs_dict.items():
            # Concatenate data frames across hops:
            if len(table_dfs) > 1:
                table_df = pd.concat(table_dfs, axis=0, ignore_index=True)
            else:
                table_df = table_dfs[0]
            # Post-filter column set, dropping helper columns (keys, time)
            # that were only fetched for sampling. Index by this table's own
            # name (not a leftover loop variable) so every table keeps its
            # own entry and column set:
            df_dict[table_name] = table_df[list(columns_dict[table_name])]

        batch_dict = {  # Concatenate batch vector across hops:
            table_name:
            np.concatenate(table_batches, axis=0)
            if len(table_batches) > 1 else table_batches[0]
            for table_name, table_batches in batches_dict.items()
        }

        return SamplerOutput(
            # Explicit `int64`: plain `int` maps to `int32` on some platforms
            # (e.g., Windows with NumPy < 2), which pandas refuses for
            # datetime64 data.
            anchor_time=anchor_time.astype('int64').to_numpy(),
            df_dict=df_dict,
            inverse_dict={},  # TODO
            batch_dict=batch_dict,
            num_sampled_nodes_dict=num_sampled_nodes_dict,
            row_dict={},  # TODO
            col_dict={},  # TODO
            num_sampled_edges_dict={},  # TODO
        )

    # Abstract Methods ########################################################

    @abstractmethod
    def _by_pkey(
        self,
        table_name: str,
        index: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches the rows of ``table_name`` whose primary key is in
        ``index``.

        Args:
            table_name: The table to fetch rows from.
            index: The primary key values to look up.
            columns: The columns to fetch.

        Returns:
            A tuple of the fetched rows and an array holding, for every
            returned row, its position in ``index``. :meth:`_sample_subgraph`
            relies on this contract: it treats positions that are absent as
            missing keys and reorders rows via ``argsort`` of this array.
        """
        pass

    @abstractmethod
    def _by_fkey(
        self,
        table_name: str,
        foreign_key: str,
        index: pd.Series,
        num_neighbors: int,
        anchor_time: pd.Series | None,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches up to ``num_neighbors`` rows of ``table_name`` per value in
        ``index``, matched via ``foreign_key``.

        Args:
            table_name: The table to fetch rows from.
            foreign_key: The foreign key column of ``table_name`` to match
                against ``index``.
            index: The referenced primary key values.
            num_neighbors: The number of neighbors to sample per value.
            anchor_time: The anchor time aligned with ``index`` (presumably
                used to restrict neighbors to past events — TODO confirm per
                backend), or ``None``.
            columns: The columns to fetch.

        Returns:
            A tuple of the fetched rows and an array holding, for every
            returned row, its position in ``index`` (as relied upon by
            :meth:`_sample_subgraph` to map rows back to seed indices).
        """
        pass
|