kumoai 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
- kumoai/experimental/rfm/backend/local/sampler.py +229 -45
- kumoai/experimental/rfm/backend/local/table.py +12 -2
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
- kumoai/experimental/rfm/backend/snow/table.py +35 -17
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
- kumoai/experimental/rfm/base/__init__.py +16 -5
- kumoai/experimental/rfm/base/sampler.py +538 -52
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +56 -0
- kumoai/experimental/rfm/base/table.py +12 -1
- kumoai/experimental/rfm/graph.py +26 -9
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +214 -151
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import pyarrow as pa
|
|
7
|
+
from kumoapi.pquery import ValidatedPredictiveQuery
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm.backend.snow import SnowTable
|
|
10
|
+
from kumoai.experimental.rfm.base import SQLSampler
|
|
11
|
+
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
12
|
+
from kumoai.utils import ProgressLogger, quote_ident
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from kumoai.experimental.rfm import Graph
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SnowSampler(SQLSampler):
    r"""A :class:`SQLSampler` that executes all of its lookups on Snowflake.

    Every table of the graph must be a :class:`SnowTable`. Batched lookups
    send their keys as a single JSON payload, which is expanded server-side
    via ``FLATTEN(PARSE_JSON(...))`` and joined against the target table.

    Args:
        graph: The graph whose Snowflake-backed tables are sampled.
        verbose: Forwarded to :class:`SQLSampler`.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph, verbose=verbose)

        self._fqn_dict: dict[str, str] = {}
        for table in graph.tables.values():
            assert isinstance(table, SnowTable)
            # NOTE(review): Only the connection of the last table is kept,
            # which assumes all tables share one Snowflake connection.
            self._connection = table._connection
            self._fqn_dict[table.name] = table.fqn

    @property
    def fqn_dict(self) -> dict[str, str]:
        r"""The fully-qualified quoted names for all tables in the graph."""
        return self._fqn_dict

    def _get_min_max_time_dict(
        self,
        table_names: list[str],
    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
        r"""Returns the ``(min, max)`` time of every table in ``table_names``.

        All tables are queried within a single ``UNION ALL`` statement. A
        table without any non-null time value maps to the sentinel pair
        ``(pd.Timestamp.max, pd.Timestamp.min)``.
        """
        if len(table_names) == 0:
            return {}  # An empty `UNION ALL` would not be a valid query.

        selects: list[str] = []
        for table_name in table_names:
            time_column = quote_ident(self.time_column_dict[table_name])
            selects.append(f"SELECT\n"
                           f"  %s as table_name,\n"
                           f"  MIN({time_column}) as min_date,\n"
                           f"  MAX({time_column}) as max_date\n"
                           f"FROM {self.fqn_dict[table_name]}")
        sql = "\nUNION ALL\n".join(selects)

        with self._connection.cursor() as cursor:
            # Each `%s` placeholder is bound to its table name, in order:
            cursor.execute(sql, table_names)
            rows = cursor.fetchall()

        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
        for table_name, _min, _max in rows:
            out_dict[table_name] = (
                pd.Timestamp.max if _min is None else pd.Timestamp(_min),
                pd.Timestamp.min if _max is None else pd.Timestamp(_max),
            )
        return out_dict

    def _sample_entity_table(
        self,
        table_name: str,
        columns: set[str],
        num_rows: int,
        random_seed: int | None = None,
    ) -> pd.DataFrame:
        r"""Randomly samples up to ``num_rows`` rows of ``columns`` from an
        entity table, skipping rows with a missing primary key or time value.

        The result may contain duplicate primary keys.
        """
        # NOTE Snowflake supports `SEED` only as part of `SYSTEM` sampling,
        # so `random_seed` is not applied to this fixed-size `ROW` sample.
        num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.

        filters: list[str] = []
        primary_key = self.primary_key_dict[table_name]
        source_table = self.source_table_dict[table_name]
        if source_table[primary_key].is_nullable:
            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
        time_column = self.time_column_dict.get(table_name)
        if time_column is not None and source_table[time_column].is_nullable:
            filters.append(f" {quote_ident(time_column)} IS NOT NULL")

        # NOTE(review): `SAMPLE` is part of the `FROM` clause, so the `WHERE`
        # filters presumably apply after sampling and fewer than `num_rows`
        # rows may be returned - confirm this is acceptable.
        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
               f"FROM {self.fqn_dict[table_name]}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
        if len(filters) > 0:
            sql += f"\nWHERE{' AND'.join(filters)}"

        # NOTE This may return duplicate primary keys. This is okay.
        table = self._execute_to_arrow(sql)

        return self._sanitize(table_name, table)

    def _sample_target(
        self,
        query: ValidatedPredictiveQuery,
        entity_df: pd.DataFrame,
        train_index: np.ndarray,
        train_time: pd.Series,
        num_train_examples: int,
        test_index: np.ndarray,
        test_time: pd.Series,
        num_test_examples: int,
        columns_dict: dict[str, set[str]],
        time_offset_dict: dict[
            tuple[str, str, str],
            tuple[pd.DateOffset | None, pd.DateOffset],
        ],
    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
        r"""Computes the targets of the train and test examples.

        The rows of ``entity_df`` selected by ``train_index`` and
        ``test_index`` are processed together, anchored at ``train_time`` and
        ``test_time`` respectively. For every edge in ``time_offset_dict``,
        the fact rows whose time lies in ``(anchor + min_offset,
        anchor + max_offset]`` are fetched. The query is then evaluated via
        :class:`PQueryPandasExecutor`. ``num_train_examples`` and
        ``num_test_examples`` are not used here.

        Returns:
            ``(train_y, train_mask, test_y, test_mask)``. The executor output
            ``y`` is split at ``train_mask.sum()``: the leading entries are
            the train targets, and the remaining ones (re-indexed from zero)
            are the test targets.
        """
        # NOTE For Snowflake, we execute everything at once to pay minimal
        # query initialization costs.
        index = np.concatenate([train_index, test_index])
        time = pd.concat([train_time, test_time], axis=0, ignore_index=True)

        entity_df = entity_df.iloc[index].reset_index(drop=True)

        feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
        time_dict: dict[str, pd.Series] = {}
        time_column = self.time_column_dict.get(query.entity_table)
        if time_column in columns_dict[query.entity_table]:
            time_dict[query.entity_table] = entity_df[time_column]
        batch_dict: dict[str, np.ndarray] = {
            query.entity_table: np.arange(len(entity_df)),
        }
        for edge_type, (min_offset, max_offset) in time_offset_dict.items():
            table_name, fkey, _ = edge_type
            feat_dict[table_name], batch_dict[table_name] = self._by_time(
                table_name=table_name,
                fkey=fkey,
                pkey=entity_df[self.primary_key_dict[query.entity_table]],
                anchor_time=time,
                min_offset=min_offset,
                max_offset=max_offset,
                columns=columns_dict[table_name],
            )
            time_column = self.time_column_dict.get(table_name)
            if time_column in columns_dict[table_name]:
                time_dict[table_name] = feat_dict[table_name][time_column]

        y, mask = PQueryPandasExecutor().execute(
            query=query,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=time,
            num_forecasts=query.num_forecasts,
        )

        train_mask = mask[:len(train_index)]
        test_mask = mask[len(train_index):]

        boundary = int(train_mask.sum())
        train_y = y.iloc[:boundary]
        test_y = y.iloc[boundary:].reset_index(drop=True)

        return train_y, train_mask, test_y, test_mask

    def _by_pkey(
        self,
        table_name: str,
        pkey: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches the rows of ``table_name`` whose primary key is in ``pkey``.

        Returns the fetched ``columns`` together with a ``batch`` array that
        holds, for every returned row, the position of its key in ``pkey``.
        Keys without a matching row are omitted, and at most one row is kept
        per position.
        """
        pkey_name = self.primary_key_dict[table_name]
        source_table = self.source_table_dict[table_name]

        payload = json.dumps(list(pkey))

        sql = ("WITH TMP as (\n"
               "  SELECT\n"
               "    f.index as BATCH,\n")
        # Cast the JSON values to match the primary key column type:
        if source_table[pkey_name].dtype.is_int():
            sql += "    f.value::NUMBER as ID\n"
        elif source_table[pkey_name].dtype.is_float():
            sql += "    f.value::FLOAT as ID\n"
        else:
            sql += "    f.value::VARCHAR as ID\n"
        sql += (f"  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
                f")\n"
                f"SELECT TMP.BATCH as __BATCH__, "
                f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
                f"FROM TMP\n"
                f"JOIN {self.fqn_dict[table_name]} ENT\n"
                f"  ON ENT.{quote_ident(pkey_name)} = TMP.ID")

        table = self._execute_to_arrow(sql, (payload, ))

        # Remove any duplicated primary keys in post-processing. An explicit
        # int64 row index keeps the `min` aggregation well-typed even when the
        # result is empty (`pa.array(range(0))` would infer the null type):
        row_index = pa.array(np.arange(len(table), dtype=np.int64))
        tmp = table.append_column('__TMP__', row_index)
        gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
        table = table.take(gb['__TMP___min'])

        batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
        table = table.remove_column(table.schema.get_field_index('__BATCH__'))

        return table.to_pandas(), batch  # TODO Use `self._sanitize`.

    # Helper Methods ##########################################################

    def _execute_to_arrow(
        self,
        sql: str,
        params: tuple[str, ...] | None = None,
    ) -> pa.Table:
        r"""Executes ``sql`` and returns its full result as an Arrow table.

        ``force_return_table=True`` makes the Snowflake connector return an
        empty table (with schema) rather than ``None`` when the query yields
        no rows, so callers can process empty results uniformly.
        """
        with self._connection.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.fetch_arrow_all(force_return_table=True)

    def _by_time(
        self,
        table_name: str,
        fkey: str,
        pkey: pd.Series,
        anchor_time: pd.Series,
        min_offset: pd.DateOffset | None,
        max_offset: pd.DateOffset,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        r"""Fetches the fact rows linked to ``pkey`` within a time window.

        For position ``i``, the rows of ``table_name`` whose ``fkey`` equals
        ``pkey[i]`` and whose time lies in
        ``(anchor_time[i] + min_offset, anchor_time[i] + max_offset]`` are
        returned. The lower bound is dropped if ``min_offset`` is ``None``.
        The returned ``batch`` array holds ``i`` for every fetched row.
        """
        # NOTE(review): For numpy-backed datetimes, `%S` drops sub-second
        # precision, so the window bounds are truncated to whole seconds -
        # confirm this is intended.
        end_time = anchor_time + max_offset
        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
        if min_offset is not None:
            start_time = anchor_time + min_offset
            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
            payload = json.dumps(list(zip(pkey, end_time, start_time)))
        else:
            payload = json.dumps(list(zip(pkey, end_time)))

        # Based on benchmarking, JSON payload is the fastest way to query by
        # custom indices (compared to large `IN` clauses or temporary tables):
        source_table = self.source_table_dict[table_name]
        time_column = self.time_column_dict[table_name]
        sql = ("WITH TMP as (\n"
               "  SELECT\n"
               "    f.index as BATCH,\n")
        # Cast the JSON key to match the foreign key column type:
        if source_table[fkey].dtype.is_int():
            sql += "    f.value[0]::NUMBER as ID,\n"
        elif source_table[fkey].dtype.is_float():
            sql += "    f.value[0]::FLOAT as ID,\n"
        else:
            sql += "    f.value[0]::VARCHAR as ID,\n"
        sql += "    f.value[1]::TIMESTAMP_NTZ as END_TIME"
        if min_offset is not None:
            sql += ",\n    f.value[2]::TIMESTAMP_NTZ as START_TIME"
        sql += (f"\n"
                f"  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
                f")\n"
                f"SELECT TMP.BATCH as __BATCH__, "
                f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
                f"FROM TMP\n"
                f"JOIN {self.fqn_dict[table_name]} FACT\n"
                f"  ON FACT.{quote_ident(fkey)} = TMP.ID\n"
                f"  AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
        if min_offset is not None:
            sql += f"\n  AND FACT.{quote_ident(time_column)} > TMP.START_TIME"

        table = self._execute_to_arrow(sql, (payload, ))

        batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
        table = table.remove_column(table.schema.get_field_index('__BATCH__'))

        return self._sanitize(table_name, table), batch

    def _sanitize(self, table_name: str, table: pa.Table) -> pd.DataFrame:
        r"""Converts a fetched Arrow table into a :class:`pandas.DataFrame`
        with :class:`pandas.ArrowDtype`-backed columns.
        """
        return table.to_pandas(types_mapper=pd.ArrowDtype)
|
|
@@ -1,11 +1,17 @@
|
|
|
1
1
|
import re
|
|
2
|
-
from typing import List, Optional, Sequence
|
|
2
|
+
from typing import List, Optional, Sequence, cast
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from kumoapi.typing import Dtype
|
|
6
6
|
|
|
7
7
|
from kumoai.experimental.rfm.backend.snow import Connection
|
|
8
|
-
from kumoai.experimental.rfm.base import
|
|
8
|
+
from kumoai.experimental.rfm.base import (
|
|
9
|
+
DataBackend,
|
|
10
|
+
SourceColumn,
|
|
11
|
+
SourceForeignKey,
|
|
12
|
+
Table,
|
|
13
|
+
)
|
|
14
|
+
from kumoai.utils import quote_ident
|
|
9
15
|
|
|
10
16
|
|
|
11
17
|
class SnowTable(Table):
|
|
@@ -51,27 +57,36 @@ class SnowTable(Table):
|
|
|
51
57
|
)
|
|
52
58
|
|
|
53
59
|
@property
|
|
54
|
-
def
|
|
60
|
+
def backend(self) -> DataBackend:
|
|
61
|
+
return cast(DataBackend, DataBackend.SNOWFLAKE)
|
|
62
|
+
|
|
63
|
+
@property
|
|
64
|
+
def fqn(self) -> str:
|
|
65
|
+
r"""The fully-qualified quoted table name."""
|
|
55
66
|
names: List[str] = []
|
|
56
67
|
if self._database is not None:
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
names.append(self._name)
|
|
62
|
-
return '.'.join(names)
|
|
68
|
+
names.append(quote_ident(self._database))
|
|
69
|
+
if self._schema is not None:
|
|
70
|
+
names.append(quote_ident(self._schema))
|
|
71
|
+
return '.'.join(names + [quote_ident(self._name)])
|
|
63
72
|
|
|
64
73
|
def _get_source_columns(self) -> List[SourceColumn]:
|
|
65
74
|
source_columns: List[SourceColumn] = []
|
|
66
75
|
with self._connection.cursor() as cursor:
|
|
67
76
|
try:
|
|
68
|
-
|
|
77
|
+
sql = f"DESCRIBE TABLE {self.fqn}"
|
|
78
|
+
cursor.execute(sql)
|
|
69
79
|
except Exception as e:
|
|
70
|
-
|
|
71
|
-
|
|
80
|
+
names: list[str] = []
|
|
81
|
+
if self._database is not None:
|
|
82
|
+
names.append(self._database)
|
|
83
|
+
if self._schema is not None:
|
|
84
|
+
names.append(self._schema)
|
|
85
|
+
name = '.'.join(names + [self._name])
|
|
86
|
+
raise ValueError(f"Table '{name}' does not exist") from e
|
|
72
87
|
|
|
73
88
|
for row in cursor.fetchall():
|
|
74
|
-
column, type, _,
|
|
89
|
+
column, type, _, null, _, is_pkey, is_unique, *_ = row
|
|
75
90
|
|
|
76
91
|
type = type.strip().upper()
|
|
77
92
|
if type.startswith('NUMBER'):
|
|
@@ -92,6 +107,7 @@ class SnowTable(Table):
|
|
|
92
107
|
dtype=dtype,
|
|
93
108
|
is_primary_key=is_pkey.strip().upper() == 'Y',
|
|
94
109
|
is_unique_key=is_unique.strip().upper() == 'Y',
|
|
110
|
+
is_nullable=null.strip().upper() == 'Y',
|
|
95
111
|
)
|
|
96
112
|
source_columns.append(source_column)
|
|
97
113
|
|
|
@@ -100,16 +116,18 @@ class SnowTable(Table):
|
|
|
100
116
|
def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
|
|
101
117
|
source_fkeys: List[SourceForeignKey] = []
|
|
102
118
|
with self._connection.cursor() as cursor:
|
|
103
|
-
|
|
119
|
+
sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
|
|
120
|
+
cursor.execute(sql)
|
|
104
121
|
for row in cursor.fetchall():
|
|
105
|
-
_, _, _, dst_table, pkey, _, _, _, fkey = row
|
|
122
|
+
_, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
|
|
106
123
|
source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
|
|
107
124
|
return source_fkeys
|
|
108
125
|
|
|
109
126
|
def _get_sample_df(self) -> pd.DataFrame:
|
|
110
127
|
with self._connection.cursor() as cursor:
|
|
111
|
-
columns =
|
|
112
|
-
|
|
128
|
+
columns = [quote_ident(col) for col in self._source_column_dict]
|
|
129
|
+
sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
|
|
130
|
+
cursor.execute(sql)
|
|
113
131
|
table = cursor.fetch_arrow_all()
|
|
114
132
|
return table.to_pandas(types_mapper=pd.ArrowDtype)
|
|
115
133
|
|
|
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
from .table import SQLiteTable # noqa: E402
|
|
25
|
+
from .sampler import SQLiteSampler # noqa: E402
|
|
25
26
|
|
|
26
27
|
__all__ = [
|
|
27
28
|
'connect',
|
|
28
29
|
'Connection',
|
|
29
30
|
'SQLiteTable',
|
|
31
|
+
'SQLiteSampler',
|
|
30
32
|
]
|