kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
- kumoai/experimental/rfm/backend/snow/table.py +159 -52
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
- kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
- kumoai/experimental/rfm/base/__init__.py +6 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +342 -13
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +27 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +600 -360
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py

@@ -6,12 +6,11 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
-from kumoapi.typing import Stype
 
 from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import
+from kumoai.utils import ProgressLogger, quote_ident
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
@@ -35,17 +34,23 @@ class SQLiteSampler(SQLSampler):
             cursor.execute("PRAGMA temp_store = MEMORY")
             cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
 
-        # Collect database indices
+        # Collect database indices for speeding sampling:
         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
         for table_name, primary_key in self.primary_key_dict.items():
             source_table = self.source_table_dict[table_name]
-            if not source_table
-
+            if primary_key not in source_table:
+                continue  # No physical column.
+            if source_table[primary_key].is_unique_key:
+                continue
+            index_dict[table_name].add((primary_key, ))
         for src_table_name, foreign_key, _ in graph.edges:
             source_table = self.source_table_dict[src_table_name]
+            if foreign_key not in source_table:
+                continue  # No physical column.
             if source_table[foreign_key].is_unique_key:
-
-
+                continue
+            time_column = self.time_column_dict.get(src_table_name)
+            if time_column is not None and time_column in source_table:
                 index_dict[src_table_name].add((foreign_key, time_column))
             else:
                 index_dict[src_table_name].add((foreign_key, ))
@@ -54,46 +59,57 @@ class SQLiteSampler(SQLSampler):
         with self._connection.cursor() as cursor:
             for table_name in list(index_dict.keys()):
                 indices = index_dict[table_name]
-
+                source_name = self.source_name_dict[table_name]
+                sql = f"PRAGMA index_list({source_name})"
                 cursor.execute(sql)
                 for _, index_name, *_ in cursor.fetchall():
                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
                     cursor.execute(sql)
-                    index
+                    # Fetch index information and sort by `seqno`:
+                    index_info = tuple(info[2] for info in sorted(
                         cursor.fetchall(), key=lambda x: x[0]))
-                    indices
+                    # Remove all indices in case primary index already exists:
+                    for index in list(indices):
+                        if index_info[0] == index[0]:
+                            indices.discard(index)
                 if len(indices) == 0:
                     del index_dict[table_name]
 
-        num = sum(len(indices) for indices in index_dict.values())
-        index_repr = '1 index' if num == 1 else f'{num} indices'
-        num = len(index_dict)
-        table_repr = '1 table' if num == 1 else f'{num} tables'
-
         if optimize and len(index_dict) > 0:
             if not isinstance(verbose, ProgressLogger):
-                verbose =
-                    "Optimizing SQLite database",
+                verbose = ProgressLogger.default(
+                    msg="Optimizing SQLite database",
                     verbose=verbose,
                 )
 
-            with verbose as logger:
-
-                for
-
-
-
-
-
-
-
-
+            with verbose as logger, self._connection.cursor() as cursor:
+                for table_name, indices in index_dict.items():
+                    for index in indices:
+                        name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                        name = quote_ident(name)
+                        columns = ', '.join(quote_ident(v) for v in index)
+                        columns += ' DESC' if len(index) > 1 else ''
+                        source_name = self.source_name_dict[table_name]
+                        sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                               f"ON {source_name}({columns})")
+                        cursor.execute(sql)
+                        self._connection.commit()
+                        if len(index) > 1:
+                            logger.log(f"Created index on {index} in table "
+                                       f"'{table_name}'")
+                        else:
+                            logger.log(f"Created index on '{index[0]}' in "
+                                       f"table '{table_name}'")
 
         elif len(index_dict) > 0:
+            num = sum(len(indices) for indices in index_dict.values())
+            index_repr = '1 index' if num == 1 else f'{num} indices'
+            num = len(index_dict)
+            table_repr = '1 table' if num == 1 else f'{num} tables'
             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                           f"database querying. For improving runtime, we "
-                          f"strongly suggest to create
+                          f"strongly suggest to create indices for primary "
+                          f"and foreign keys, e.g., automatically by "
                           f"instantiating KumoRFM via "
                           f"`KumoRFM(graph, optimize=True)`.")
 
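The hunk above drives all of its index maintenance through SQLite PRAGMAs. A minimal standalone sketch of the same pattern, using the stdlib sqlite3 module rather than the sampler's ADBC connection, with invented table and column names:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE orders (user_id INTEGER, created_at TEXT)")

    # PRAGMA index_list rows are (seq, name, unique, origin, partial);
    # PRAGMA index_info rows are (seqno, cid, name).
    leading_columns = set()
    for _, index_name, *_ in con.execute("PRAGMA index_list(orders)"):
        info = con.execute(f"PRAGMA index_info('{index_name}')").fetchall()
        columns = tuple(name for _, _, name in sorted(info))  # sort by seqno
        leading_columns.add(columns[0])

    if "user_id" not in leading_columns:
        con.execute("CREATE INDEX IF NOT EXISTS kumo_index_orders_user_id "
                    "ON orders(user_id, created_at DESC)")

The leading-column check mirrors the hunk's `index_info[0] == index[0]` test: an existing index already serves a planned (foreign_key, time_column) index whenever its first column matches.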
@@ -103,12 +119,13 @@ class SQLiteSampler(SQLSampler):
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
                       f"  ? as table_name,\n"
-                      f"  MIN({
-                      f"  MAX({
-                      f"FROM {
+                      f"  MIN({column_ref}) as min_date,\n"
+                      f"  MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
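For concreteness, a hypothetical rendering of the statement this loop assembles for two tables with invented time columns (signup_at on users, created_at on orders); each per-table SELECT binds its table name through a ? placeholder, and the aggregates are stacked with UNION ALL so that all time ranges come back in a single round trip:

    sql = ('SELECT\n'
           '  ? as table_name,\n'
           '  MIN("signup_at") as min_date,\n'
           '  MAX("signup_at") as max_date\n'
           'FROM "users"\n'
           'UNION ALL\n'
           'SELECT\n'
           '  ? as table_name,\n'
           '  MIN("created_at") as min_date,\n'
           '  MAX("created_at") as max_date\n'
           'FROM "orders"')
    params = ("users", "orders")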
@@ -131,18 +148,28 @@ class SQLiteSampler(SQLSampler):
     ) -> pd.DataFrame:
         # NOTE SQLite does not natively support passing a `random_seed`.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-
-
-
-
-
-
-
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
 
         # TODO Make this query more efficient - it does full table scan.
-
-
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
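A standalone stdlib-sqlite3 sketch of the sampling query built above, with an invented table; the IS NOT NULL filter mirrors the nullable-key handling, and ORDER BY RANDOM() is the full-table scan the TODO refers to:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    con.executemany("INSERT INTO users (name) VALUES (?)",
                    [("a",), ("b",), (None,)])

    rows = con.execute("SELECT id, name FROM users\n"
                       "WHERE name IS NOT NULL\n"
                       "ORDER BY RANDOM() LIMIT 2").fetchall()  # 2 random rows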
@@ -152,7 +179,11 @@ class SQLiteSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
@@ -195,84 +226,163 @@ class SQLiteSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
-
+        source_table = self.source_table_dict[table_name]
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} ent\n")
+        if key in source_table and source_table[key].is_unique_key:
+            sql += (f" ON {key_ref} = tmp.__kumo_id__")
+        else:
+            sql += (f" ON ent.rowid = (\n"
+                    f" SELECT rowid\n"
+                    f" FROM {self.source_name_dict[table_name]}\n"
+                    f" WHERE {key_ref} == tmp.__kumo_id__\n"
+                    f" LIMIT 1\n"
+                    f")")
 
-
-
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f" SELECT rowid\n"
+               f" FROM {self.source_name_dict[table_name]}\n"
+               f" WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" ORDER BY {time_ref} DESC\n"
+        sql += (f" LIMIT {num_neighbors}\n"
+                f")")
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-        tmp = tmp.append_column('
+        tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            tmp = tmp.append_column('
-        tmp_name = f'tmp_{table_name}_{
-
-
-
-
+            tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
               f"FROM {quote_ident(tmp_name)} tmp\n"
-               f"JOIN {
-               f"  ON
-               f"  AND
+               f"JOIN {self.source_name_dict[table_name]}\n"
+               f" ON {key_ref} = tmp.__kumo_id__\n"
+               f" AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
-            sql += f"\n  AND
+            sql += f"\n AND {time_ref} > tmp.__kumo_start__"
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     def _sample_target_set(
         self,
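All three lookup helpers above share one mechanism: ingest the request as a temporary Arrow table, join it against the source table, and read each row's batch position back out of the temp table's rowid. A minimal sketch of that round trip, assuming the adbc-driver-sqlite package and invented table names:

    import adbc_driver_sqlite.dbapi as dbapi
    import pyarrow as pa

    with dbapi.connect(":memory:") as con:
        cur = con.cursor()
        cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
        cur.execute("INSERT INTO users VALUES (1, 'a'), (2, 'b'), (3, 'c')")

        # Requested ids in request order; rowid - 1 recovers that order later.
        ids = pa.table([pa.array([3, 1])], names=["__kumo_id__"])
        cur.adbc_ingest("tmp_ids", ids, mode="replace")

        cur.execute("SELECT tmp.rowid - 1 AS __kumo_batch__, ent.name\n"
                    "FROM tmp_ids tmp\n"
                    "JOIN users ent ON ent.id = tmp.__kumo_id__")
        table = cur.fetch_arrow_table()  # batch 0 -> 'c', batch 1 -> 'a'

Carrying the batch column through the join is what lets the callers realign results with their input index, regardless of the order in which SQLite emits rows.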
@@ -305,11 +415,11 @@ class SQLiteSampler(SQLSampler):
             query.entity_table: np.arange(len(df)),
         }
         for edge_type, (_min, _max) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=_min,
                 max_offset=_max,
@@ -324,7 +434,7 @@ class SQLiteSampler(SQLSampler):
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
-                anchor_time=
+                anchor_time=time,
                 num_forecasts=query.num_forecasts,
             )
             ys.append(y)
@@ -342,13 +452,3 @@ class SQLiteSampler(SQLSampler):
         y = pd.concat(ys, axis=0, ignore_index=True)
 
         return y, mask
-
-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-        stype_dict = self.table_stype_dict[table_name]
-        for column_name in df.columns:
-            if stype_dict.get(column_name) == Stype.timestamp:
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        return df
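The `_sanitize` helper removed here is superseded by the `Table._sanitize` calls introduced throughout this diff. Its core behavior, as a standalone pandas sketch with an invented column and a plain string standing in for `Stype.timestamp`:

    import pandas as pd

    df = pd.DataFrame({"created_at": ["2024-01-01 10:00:00", None]})
    stype_dict = {"created_at": "timestamp"}  # stand-in for Stype.timestamp

    for column_name in df.columns:
        if stype_dict.get(column_name) == "timestamp":
            df[column_name] = pd.to_datetime(df[column_name])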
kumoai/experimental/rfm/backend/sqlite/table.py

@@ -1,18 +1,21 @@
 import re
-import
-from
+from collections import Counter
+from collections.abc import Sequence
+from typing import cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.sqlite import Connection
 from kumoai.experimental.rfm.base import (
+    ColumnSpec,
+    ColumnSpecType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
     Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 from kumoai.utils import quote_ident
 
 
@@ -22,6 +25,8 @@ class SQLiteTable(Table):
     Args:
         connection: The connection to a :class:`sqlite` database.
         name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is being used.
         columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
@@ -32,16 +37,18 @@ class SQLiteTable(Table):
         self,
         connection: Connection,
         name: str,
-
-
-
-
+        source_name: str | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         self._connection = connection
 
         super().__init__(
             name=name,
+            source_name=source_name,
             columns=columns,
             primary_key=primary_key,
             time_column=time_column,
@@ -52,18 +59,19 @@ class SQLiteTable(Table):
     def backend(self) -> DataBackend:
         return cast(DataBackend, DataBackend.SQLITE)
 
-    def _get_source_columns(self) ->
-        source_columns:
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA table_info({
+            sql = f"PRAGMA table_info({self._quoted_source_name})"
             cursor.execute(sql)
             columns = cursor.fetchall()
 
             if len(columns) == 0:
-                raise ValueError(f"Table '{self.
+                raise ValueError(f"Table '{self.source_name}' does not exist "
+                                 f"in the SQLite database")
 
             unique_keys: set[str] = set()
-            sql = f"PRAGMA index_list({
+            sql = f"PRAGMA index_list({self._quoted_source_name})"
             cursor.execute(sql)
             for _, index_name, is_unique, *_ in cursor.fetchall():
                 if bool(is_unique):
@@ -73,30 +81,19 @@ class SQLiteTable(Table):
                 if len(index) == 1:
                     unique_keys.add(index[0][2])
 
-
-
-
-
-
-
-
-
-                    dtype = Dtype.float
-                else:  # NUMERIC affinity.
-                    ser = self._sample_df[column]
-                    try:
-                        dtype = infer_dtype(ser)
-                    except Exception:
-                        warnings.warn(
-                            f"Data type inference for column '{column}' in "
-                            f"table '{self.name}' failed. Consider changing "
-                            f"the data type of the column to use it within "
-                            f"this table.")
-                        continue
+            # Special SQLite case that creates a rowid alias for
+            # `INTEGER PRIMARY KEY` annotated columns:
+            rowid_candidates = [
+                column for _, column, dtype, _, _, is_pkey in columns
+                if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+            ]
+            if len(rowid_candidates) == 1:
+                unique_keys.add(rowid_candidates[0])
 
+        for _, column, dtype, notnull, _, is_pkey in columns:
             source_column = SourceColumn(
                 name=column,
-                dtype=dtype,
+                dtype=self._to_dtype(dtype),
                 is_primary_key=bool(is_pkey),
                 is_unique_key=column in unique_keys,
                 is_nullable=not bool(is_pkey) and not bool(notnull),
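The rowid-candidate branch above encodes a SQLite rule worth spelling out: a column declared exactly INTEGER PRIMARY KEY becomes an alias for the implicit rowid and is therefore unique even without an explicit UNIQUE index. A stdlib-sqlite3 check with an invented table:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    con.execute("INSERT INTO users (name) VALUES ('a')")

    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk):
    aliases = [name for _, name, dtype, _, _, is_pkey
               in con.execute("PRAGMA table_info(users)")
               if bool(is_pkey) and dtype.strip().upper() == "INTEGER"]
    assert aliases == ["id"]

    # The alias is observable: `id` and `rowid` always agree.
    assert con.execute("SELECT id = rowid FROM users").fetchone()[0] == 1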
@@ -105,22 +102,83 @@ class SQLiteTable(Table):
 
         return source_columns
 
-    def _get_source_foreign_keys(self) ->
-
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA foreign_key_list({
+            sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
             cursor.execute(sql)
-
-
-
-
-
+            rows = cursor.fetchall()
+            counts = Counter(row[0] for row in rows)
+            for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                if counts[idx] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=foreign_key,
+                        dst_table=dst_table,
+                        primary_key=primary_key,
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
+
+    def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-            return table.to_pandas(types_mapper=pd.ArrowDtype)
 
-
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
+
+    def _get_num_rows(self) -> int | None:
         return None
+
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
+            ]
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if re.search('INT', dtype):
+            return Dtype.int
+        if re.search('TEXT|CHAR|CLOB', dtype):
+            return Dtype.string
+        if re.search('REAL|FLOA|DOUB', dtype):
+            return Dtype.float
+        return None  # NUMERIC affinity.