kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +53 -107
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +41 -80
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +147 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +11 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +108 -88
- kumoai/experimental/rfm/base/__init__.py +26 -2
- kumoai/experimental/rfm/base/column.py +6 -12
- kumoai/experimental/rfm/base/column_expression.py +16 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +113 -0
- kumoai/experimental/rfm/base/table.py +174 -76
- kumoai/experimental/rfm/graph.py +444 -84
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +77 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -240
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +6 -2
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +42 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -0,0 +1,349 @@
+import warnings
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.typing import Stype
+
+from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+class SQLiteSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        if optimize:
+            with self._connection.cursor() as cursor:
+                cursor.execute("PRAGMA temp_store = MEMORY")
+                cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
+
+        # Collect database indices to speed-up sampling:
+        index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+        for table_name, primary_key in self.primary_key_dict.items():
+            source_table = self.source_table_dict[table_name]
+            if not source_table[primary_key].is_unique_key:
+                index_dict[table_name].add((primary_key, ))
+        for src_table_name, foreign_key, _ in graph.edges:
+            source_table = self.source_table_dict[src_table_name]
+            if source_table[foreign_key].is_unique_key:
+                pass
+            elif time_column := self.time_column_dict.get(src_table_name):
+                index_dict[src_table_name].add((foreign_key, time_column))
+            else:
+                index_dict[src_table_name].add((foreign_key, ))
+
+        # Only maintain missing indices:
+        with self._connection.cursor() as cursor:
+            for table_name in list(index_dict.keys()):
+                indices = index_dict[table_name]
+                sql = f"PRAGMA index_list({self.fqn_dict[table_name]})"
+                cursor.execute(sql)
+                for _, index_name, *_ in cursor.fetchall():
+                    sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                    cursor.execute(sql)
+                    index = tuple(info[2] for info in sorted(
+                        cursor.fetchall(), key=lambda x: x[0]))
+                    indices.discard(index)
+                if len(indices) == 0:
+                    del index_dict[table_name]
+
+        num = sum(len(indices) for indices in index_dict.values())
+        index_repr = '1 index' if num == 1 else f'{num} indices'
+        num = len(index_dict)
+        table_repr = '1 table' if num == 1 else f'{num} tables'
+
+        if optimize and len(index_dict) > 0:
+            if not isinstance(verbose, ProgressLogger):
+                verbose = ProgressLogger.default(
+                    msg="Optimizing SQLite database",
+                    verbose=verbose,
+                )
+
+            with verbose as logger, self._connection.cursor() as cursor:
+                for table_name, indices in index_dict.items():
+                    for index in indices:
+                        name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                        name = quote_ident(name)
+                        columns = ', '.join(quote_ident(v) for v in index)
+                        columns += ' DESC' if len(index) > 1 else ''
+                        sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                               f"ON {self.fqn_dict[table_name]}({columns})")
+                        cursor.execute(sql)
+                self._connection.commit()
+                logger.log(f"Created {index_repr} in {table_repr}")
+
+        elif len(index_dict) > 0:
+            warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                          f"database querying. For improving runtime, we "
+                          f"strongly suggest to create these indices by "
+                          f"instantiating KumoRFM via "
+                          f"`KumoRFM(graph, optimize=True)`.")
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            time_column = self.time_column_dict[table_name]
+            select = (f"SELECT\n"
+                      f" ? as table_name,\n"
+                      f" MIN({quote_ident(time_column)}) as min_date,\n"
+                      f" MAX({quote_ident(time_column)}) as max_date\n"
+                      f"FROM {self.fqn_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql, table_names)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
+        return out_dict
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE SQLite does not natively support passing a `random_seed`.
+
+        filters: list[str] = []
+        primary_key = self.primary_key_dict[table_name]
+        if self.source_table_dict[table_name][primary_key].is_nullable:
+            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+        time_column = self.time_column_dict.get(table_name)
+        if (time_column is not None and
+                self.source_table_dict[table_name][time_column].is_nullable):
+            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+        # TODO Make this query more efficient - it does full table scan.
+        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+               f"FROM {self.fqn_dict[table_name]}")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+        sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        return self._sanitize(table_name, table)
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        train_y, train_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        pkey: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pkey_name = self.primary_key_dict[table_name]
+
+        tmp = pa.table([pa.array(pkey)], names=['id'])
+        tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+
+        if self.source_table_dict[table_name][pkey_name].is_unique_key:
+            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                   f"FROM {quote_ident(tmp_name)} tmp\n"
+                   f"JOIN {self.fqn_dict[table_name]} ent\n"
+                   f" ON ent.{quote_ident(pkey_name)} = tmp.id")
+        else:
+            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                   f"FROM {quote_ident(tmp_name)} tmp\n"
+                   f"JOIN {self.fqn_dict[table_name]} ent\n"
+                   f" ON ent.rowid = (\n"
+                   f"  SELECT rowid FROM {self.fqn_dict[table_name]}\n"
+                   f"  WHERE {quote_ident(pkey_name)} == tmp.id\n"
+                   f"  LIMIT 1\n"
+                   f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__batch__'].to_numpy()
+        table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+        return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        fkey: str,
+        pkey: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(pkey)], names=['id'])
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        tmp = tmp.append_column('end', pa.array(end_time))
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('start', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+
+        time_column = self.time_column_dict[table_name]
+        sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+               f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.fqn_dict[table_name]} fact\n"
+               f" ON fact.{quote_ident(fkey)} = tmp.id\n"
+               f" AND fact.{quote_ident(time_column)} <= tmp.end")
+        if min_offset is not None:
+            sql += f"\n AND fact.{quote_ident(time_column)} > tmp.start"
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__batch__'].to_numpy()
+        table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+        return self._sanitize(table_name, table), batch
+
+    def _sample_target_set(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        index: np.ndarray,
+        anchor_time: pd.Series,
+        num_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        batch_size: int = 10_000,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            df = entity_df.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
+
+            feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+            time_dict: dict[str, pd.Series] = {}
+            time_column = self.time_column_dict.get(query.entity_table)
+            if time_column in columns_dict[query.entity_table]:
+                time_dict[query.entity_table] = df[time_column]
+            batch_dict: dict[str, np.ndarray] = {
+                query.entity_table: np.arange(len(df)),
+            }
+            for edge_type, (_min, _max) in time_offset_dict.items():
+                table_name, fkey, _ = edge_type
+                feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                    table_name=table_name,
+                    fkey=fkey,
+                    pkey=df[self.primary_key_dict[query.entity_table]],
+                    anchor_time=time,
+                    min_offset=_min,
+                    max_offset=_max,
+                    columns=columns_dict[table_name],
+                )
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns_dict[table_name]:
+                    time_dict[table_name] = feat_dict[table_name][time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
+
+    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
+        df = table.to_pandas(types_mapper=pd.ArrowDtype)
+
+        stype_dict = self.table_stype_dict[table_name]
+        for column_name in df.columns:
+            if stype_dict.get(column_name) == Stype.timestamp:
+                df[column_name] = pd.to_datetime(df[column_name])
+
+        return df
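For reference, the index-creation loop in `SQLiteSampler.__init__` above assembles statements of the following shape. This is an illustrative sketch only, not package code: the `orders` table, `user_id` foreign key, and `created_at` time column are hypothetical, and a simplified stand-in replaces `kumoai.utils.quote_ident`.

```python
# Illustrative sketch: mirrors the index naming and SQL pattern of the
# `optimize=True` path for a hypothetical fact table.
def quote_ident(name: str) -> str:  # simplified stand-in for kumoai.utils.quote_ident
    return '"' + name.replace('"', '""') + '"'

table_name = 'orders'              # hypothetical table
index = ('user_id', 'created_at')  # (foreign key, time column)

name = quote_ident(f"kumo_index_{table_name}_{'_'.join(index)}")
columns = ', '.join(quote_ident(v) for v in index)
columns += ' DESC' if len(index) > 1 else ''  # trailing time column sorted descending

sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
       f"ON {quote_ident(table_name)}({columns})")
print(sql)
# CREATE INDEX IF NOT EXISTS "kumo_index_orders_user_id_created_at"
# ON "orders"("user_id", "created_at" DESC)
```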
--- a/kumoai/experimental/rfm/backend/sqlite/table.py
+++ b/kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,22 +1,34 @@
 import re
-
+import warnings
+from collections.abc import Sequence
+from typing import cast
 
-import
-from kumoapi.
-from
+import pandas as pd
+from kumoapi.model_plan import MissingType
+from kumoapi.typing import Dtype
 
-from kumoai.experimental.rfm import utils
 from kumoai.experimental.rfm.backend.sqlite import Connection
-from kumoai.experimental.rfm.base import
-
-
-
+from kumoai.experimental.rfm.base import (
+    ColumnExpressionType,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+    SQLTable,
+)
+from kumoai.experimental.rfm.infer import infer_dtype
+from kumoai.utils import quote_ident
+
+
+class SQLiteTable(SQLTable):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`sqlite` database.
-        name: The name of this table.
-
+        name: The logical name of this table.
+        source_name: The physical name of this table in the database. If set to
+            ``None``, ``name`` is being used.
+        columns: The selected physical columns of this table.
+        column_expressions: The logical columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
        time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -26,92 +38,100 @@ class SQLiteTable(Table):
         self,
         connection: Connection,
         name: str,
-
-
-
-
+        source_name: str | None = None,
+        columns: Sequence[str] | None = None,
+        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         self._connection = connection
-        self._dtype_dict: Dict[str, Dtype] = {}
-
-        with connection.cursor() as cursor:
-            cursor.execute(f"PRAGMA table_info({name})")
-            for _, column, dtype, _, _, is_pkey in cursor.fetchall():
-                if bool(is_pkey):
-                    if primary_key is not None and primary_key != column:
-                        raise ValueError(f"Found duplicate primary key "
-                                         f"definition '{primary_key}' and "
-                                         f"'{column}' in table '{name}'")
-                    primary_key = column
-
-                # Determine colun affinity:
-                dtype = dtype.strip().upper()
-                if re.search('INT', dtype):
-                    self._dtype_dict[column] = Dtype.int
-                elif re.search('TEXT|CHAR|CLOB', dtype):
-                    self._dtype_dict[column] = Dtype.string
-                elif re.search('REAL|FLOA|DOUB', dtype):
-                    self._dtype_dict[column] = Dtype.float
-                else:  # NUMERIC affinity.
-                    self._dtype_dict[column] = Dtype.unsupported
-
-            if len(self._dtype_dict) > 0:
-                column_names = ', '.join(self._dtype_dict.keys())
-                cursor.execute(f"SELECT {column_names} FROM {name} "
-                               f"ORDER BY rowid LIMIT 1000")
-                self._sample = cursor.fetch_arrow_table()
-
-        for column_name in list(self._dtype_dict.keys()):
-            if self._dtype_dict[column_name] == Dtype.unsupported:
-                dtype = self._sample[column_name].type
-                if pa.types.is_integer(dtype):
-                    self._dtype_dict[column_name] = Dtype.int
-                elif pa.types.is_floating(dtype):
-                    self._dtype_dict[column_name] = Dtype.float
-                elif pa.types.is_decimal(dtype):
-                    self._dtype_dict[column_name] = Dtype.float
-                elif pa.types.is_string(dtype):
-                    self._dtype_dict[column_name] = Dtype.string
-                else:
-                    del self._dtype_dict[column_name]
-
-        if len(self._dtype_dict) == 0:
-            raise RuntimeError(f"Table '{name}' does not exist or does not "
-                               f"hold any column with a supported data type")
 
         super().__init__(
             name=name,
-
+            source_name=source_name,
+            columns=columns,
+            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @property
+    def backend(self) -> DataBackend:
+        return cast(DataBackend, DataBackend.SQLITE)
+
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
+        with self._connection.cursor() as cursor:
+            sql = f"PRAGMA table_info({self.fqn})"
+            cursor.execute(sql)
+            columns = cursor.fetchall()
+
+            if len(columns) == 0:
+                raise ValueError(f"Table '{self._source_name}' does not exist "
+                                 f"in the SQLite database")
+
+            unique_keys: set[str] = set()
+            sql = f"PRAGMA index_list({self.fqn})"
+            cursor.execute(sql)
+            for _, index_name, is_unique, *_ in cursor.fetchall():
+                if bool(is_unique):
+                    sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                    cursor.execute(sql)
+                    index = cursor.fetchall()
+                    if len(index) == 1:
+                        unique_keys.add(index[0][2])
+
+        for _, column, type, notnull, _, is_pkey in columns:
+            # Determine column affinity:
+            type = type.strip().upper()
+            if re.search('INT', type):
+                dtype = Dtype.int
+            elif re.search('TEXT|CHAR|CLOB', type):
+                dtype = Dtype.string
+            elif re.search('REAL|FLOA|DOUB', type):
+                dtype = Dtype.float
+            else:  # NUMERIC affinity.
+                ser = self._sample_df[column]
+                try:
+                    dtype = infer_dtype(ser)
+                except Exception:
+                    warnings.warn(
+                        f"Data type inference for column '{column}' in "
+                        f"table '{self.name}' failed. Consider changing "
+                        f"the data type of the column in the database or "
+                        f"remove this column from this table.")
+                    continue
+
+            source_column = SourceColumn(
+                name=column,
+                dtype=dtype,
+                is_primary_key=bool(is_pkey),
+                is_unique_key=column in unique_keys,
+                is_nullable=not bool(is_pkey) and not bool(notnull),
+            )
+            source_columns.append(source_column)
+
+        return source_columns
+
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+        source_fkeys: list[SourceForeignKey] = []
+        with self._connection.cursor() as cursor:
+            sql = f"PRAGMA foreign_key_list({self.fqn})"
+            cursor.execute(sql)
+            for _, _, dst_table, fkey, pkey, *_ in cursor.fetchall():
+                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
+        return source_fkeys
+
+    def _get_sample_df(self) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            sql = (f"SELECT * FROM {self.fqn} "
+                   f"ORDER BY rowid LIMIT 1000")
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+        return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+    def _get_num_rows(self) -> int | None:
         return None
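The rewritten `_get_source_columns` above maps SQLite's declared column types to dtypes via type affinity, falling back to sample-based inference for NUMERIC affinity. A minimal sketch of that mapping follows; the declared type strings are examples, and plain strings stand in for `kumoapi.typing.Dtype` members.

```python
import re

# Sketch of the affinity-based mapping used above; strings stand in for Dtype.
def affinity_dtype(declared_type: str) -> str:
    t = declared_type.strip().upper()
    if re.search('INT', t):
        return 'int'
    if re.search('TEXT|CHAR|CLOB', t):
        return 'string'
    if re.search('REAL|FLOA|DOUB', t):
        return 'float'
    return 'infer_from_sample'  # NUMERIC affinity: defer to infer_dtype() on sampled rows

assert affinity_dtype('INTEGER') == 'int'
assert affinity_dtype('VARCHAR(255)') == 'string'
assert affinity_dtype('DOUBLE PRECISION') == 'float'
assert affinity_dtype('DECIMAL(10, 2)') == 'infer_from_sample'
```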
--- a/kumoai/experimental/rfm/base/__init__.py
+++ b/kumoai/experimental/rfm/base/__init__.py
@@ -1,7 +1,31 @@
-from .
-
+from kumoapi.common import StrEnum
+
+
+class DataBackend(StrEnum):
+    LOCAL = 'local'
+    SQLITE = 'sqlite'
+    SNOWFLAKE = 'snowflake'
+
+
+from .source import SourceColumn, SourceForeignKey  # noqa: E402
+from .column import Column  # noqa: E402
+from .column_expression import ColumnExpressionSpec  # noqa: E402
+from .column_expression import ColumnExpressionType  # noqa: E402
+from .table import Table  # noqa: E402
+from .sql_table import SQLTable  # noqa: E402
+from .sampler import SamplerOutput, Sampler  # noqa: E402
+from .sql_sampler import SQLSampler  # noqa: E402
 
 __all__ = [
+    'DataBackend',
+    'SourceColumn',
+    'SourceForeignKey',
     'Column',
+    'ColumnExpressionSpec',
+    'ColumnExpressionType',
     'Table',
+    'SQLTable',
+    'SamplerOutput',
+    'Sampler',
+    'SQLSampler',
 ]
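Since `DataBackend` derives from `kumoapi.common.StrEnum`, its members should compare equal to their string values. A minimal illustration under that assumption:

```python
from kumoai.experimental.rfm.base import DataBackend

# Assumes StrEnum semantics: members behave like plain strings.
assert DataBackend.SQLITE == 'sqlite'
assert DataBackend('snowflake') is DataBackend.SNOWFLAKE
assert DataBackend.LOCAL.value == 'local'
```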
--- a/kumoai/experimental/rfm/base/column.py
+++ b/kumoai/experimental/rfm/base/column.py
@@ -8,20 +8,14 @@ from kumoapi.typing import Dtype, Stype
 class Column:
     stype: Stype
 
-    def __init__(
-        self,
-        name: str,
-        dtype: Dtype,
-        stype: Stype,
-        is_primary_key: bool = False,
-        is_time_column: bool = False,
-        is_end_time_column: bool = False,
-    ) -> None:
+    def __init__(self, name: str, stype: Stype, dtype: Dtype) -> None:
         self._name = name
         self._dtype = Dtype(dtype)
-
-        self.
-        self.
+
+        self._is_primary_key = False
+        self._is_time_column = False
+        self._is_end_time_column = False
+
         self.stype = Stype(stype)
 
     @property
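After this change, `Column` is constructed from only `name`, `stype`, and `dtype`; the primary-key/time-column flags default to `False` and are presumably set later by the owning table. A hypothetical construction, with the column name and semantic type made up for illustration:

```python
from kumoapi.typing import Dtype, Stype
from kumoai.experimental.rfm.base import Column

# Hypothetical example; 'price' and the chosen stype/dtype are illustrative.
col = Column(name='price', stype=Stype.numerical, dtype=Dtype.float)
assert col.stype == Stype.numerical
```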
--- /dev/null
+++ b/kumoai/experimental/rfm/base/column_expression.py
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+from typing import Any, TypeAlias
+
+from kumoapi.typing import Dtype
+
+from kumoai.mixin import CastMixin
+
+
+@dataclass(frozen=True)
+class ColumnExpressionSpec(CastMixin):
+    name: str
+    expr: str
+    dtype: Dtype | None = None
+
+
+ColumnExpressionType: TypeAlias = ColumnExpressionSpec | dict[str, Any]
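`ColumnExpressionType` accepts either a `ColumnExpressionSpec` or a plain dict with the same fields (presumably coerced via `CastMixin`). A hypothetical example; the expression syntax and column names are illustrative, not taken from the package:

```python
from kumoapi.typing import Dtype
from kumoai.experimental.rfm.base import ColumnExpressionSpec

spec = ColumnExpressionSpec(
    name='total_price',            # hypothetical logical column name
    expr='quantity * unit_price',  # hypothetical expression
    dtype=Dtype.float,
)

# Equivalent dict form permitted by the ColumnExpressionType alias:
as_dict = {'name': 'total_price', 'expr': 'quantity * unit_price'}
```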