kumoai 2.14.0.dev202512191731-cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731-cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
- kumoai/experimental/rfm/backend/snow/table.py +146 -70
- kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +7 -2
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +540 -306
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +15 -2
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +40 -35
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py:

```diff
@@ -6,9 +6,9 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
-from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm.
+from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 from kumoai.utils import ProgressLogger, quote_ident
 
@@ -25,22 +25,32 @@ class SQLiteSampler(SQLSampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)
 
+        for table in graph.tables.values():
+            assert isinstance(table, SQLiteTable)
+            self._connection = table._connection
+
         if optimize:
             with self._connection.cursor() as cursor:
                 cursor.execute("PRAGMA temp_store = MEMORY")
                 cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
 
-        # Collect database indices
+        # Collect database indices for speeding sampling:
         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
         for table_name, primary_key in self.primary_key_dict.items():
             source_table = self.source_table_dict[table_name]
-            if not source_table
-
+            if primary_key not in source_table:
+                continue  # No physical column.
+            if source_table[primary_key].is_unique_key:
+                continue
+            index_dict[table_name].add((primary_key, ))
         for src_table_name, foreign_key, _ in graph.edges:
             source_table = self.source_table_dict[src_table_name]
+            if foreign_key not in source_table:
+                continue  # No physical column.
             if source_table[foreign_key].is_unique_key:
-
-
+                continue
+            time_column = self.time_column_dict.get(src_table_name)
+            if time_column is not None and time_column in source_table:
                 index_dict[src_table_name].add((foreign_key, time_column))
             else:
                 index_dict[src_table_name].add((foreign_key, ))
```
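The constructor now takes the shared connection from the graph's tables and, when `optimize=True`, tunes SQLite before collecting candidate indices. A minimal standalone sketch of the two PRAGMAs, using the stdlib `sqlite3` module rather than the ADBC connection the sampler actually holds:

```python
# Sketch of the PRAGMA tuning above (stdlib sqlite3; the sampler itself
# issues these through an ADBC cursor).
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("PRAGMA temp_store = MEMORY")    # keep temp tables/indices in RAM
con.execute("PRAGMA cache_size = -2000000")  # negative value is in KiB, ~2 GB
print(con.execute("PRAGMA cache_size").fetchone())  # (-2000000,)
```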
```diff
@@ -49,22 +59,22 @@
         with self._connection.cursor() as cursor:
             for table_name in list(index_dict.keys()):
                 indices = index_dict[table_name]
-
+                source_name = self.source_name_dict[table_name]
+                sql = f"PRAGMA index_list({source_name})"
                 cursor.execute(sql)
                 for _, index_name, *_ in cursor.fetchall():
                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
                     cursor.execute(sql)
-                    index
+                    # Fetch index information and sort by `seqno`:
+                    index_info = tuple(info[2] for info in sorted(
                         cursor.fetchall(), key=lambda x: x[0]))
-                    indices
+                    # Remove all indices in case primary index already exists:
+                    for index in list(indices):
+                        if index_info[0] == index[0]:
+                            indices.discard(index)
                 if len(indices) == 0:
                     del index_dict[table_name]
 
-        num = sum(len(indices) for indices in index_dict.values())
-        index_repr = '1 index' if num == 1 else f'{num} indices'
-        num = len(index_dict)
-        table_repr = '1 table' if num == 1 else f'{num} tables'
-
         if optimize and len(index_dict) > 0:
             if not isinstance(verbose, ProgressLogger):
                 verbose = ProgressLogger.default(
```
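This hunk drops planned indices that already exist, using `PRAGMA index_list` to enumerate a table's indices and `PRAGMA index_info` (sorted by `seqno`) to recover each index's column order. A small illustration with a hypothetical table (row shapes as returned by recent SQLite versions):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (a INT, b INT)")
con.execute("CREATE INDEX idx_a ON t(a)")
# index_list: one row per index: (seq, name, unique, origin, partial)
print(con.execute("PRAGMA index_list(t)").fetchall())
# index_info: one row per indexed column: (seqno, cid, name); sorting by
# seqno, as the code above does, yields the index's column order.
print(con.execute("PRAGMA index_info(idx_a)").fetchall())  # [(0, 0, 'a')]
```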
```diff
@@ -79,16 +89,27 @@
                 name = quote_ident(name)
                 columns = ', '.join(quote_ident(v) for v in index)
                 columns += ' DESC' if len(index) > 1 else ''
+                source_name = self.source_name_dict[table_name]
                 sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
-                       f"ON {
+                       f"ON {source_name}({columns})")
                 cursor.execute(sql)
-
-
+                self._connection.commit()
+                if len(index) > 1:
+                    logger.log(f"Created index on {index} in table "
+                               f"'{table_name}'")
+                else:
+                    logger.log(f"Created index on '{index[0]}' in "
+                               f"table '{table_name}'")
 
         elif len(index_dict) > 0:
+            num = sum(len(indices) for indices in index_dict.values())
+            index_repr = '1 index' if num == 1 else f'{num} indices'
+            num = len(index_dict)
+            table_repr = '1 table' if num == 1 else f'{num} tables'
             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                           f"database querying. For improving runtime, we "
-                          f"strongly suggest to create
+                          f"strongly suggest to create indices for primary "
+                          f"and foreign keys, e.g., automatically by "
                           f"instantiating KumoRFM via "
                           f"`KumoRFM(graph, optimize=True)`.")
 
```
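Index creation appends `DESC` to the trailing column of composite `(foreign_key, time_column)` indices, which matches the `ORDER BY <time> DESC ... LIMIT k` probes issued during neighbor sampling further down. A sketch of the resulting DDL on a hypothetical table:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE orders (user_id INT, ts TEXT, amount REAL)")
# Key/time index in the shape the sampler emits for an (orders, user_id) edge:
con.execute('CREATE INDEX IF NOT EXISTS "orders_user_id_ts" '
            'ON orders("user_id", "ts" DESC)')
print([name for _, name, *_ in con.execute("PRAGMA index_list(orders)")])
# ['orders_user_id_ts']
```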
```diff
@@ -98,23 +119,26 @@
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
             select = (f"SELECT\n"
-                      f"
-                      f" MIN({
-                      f" MAX({
-                      f"FROM {self.
+                      f" {ident} as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         with self._connection.cursor() as cursor:
-            cursor.execute(sql
+            cursor.execute(sql)
             for table_name, _min, _max in cursor.fetchall():
                 out_dict[table_name] = (
                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
                 )
+
         return out_dict
 
     def _sample_entity_table(
```
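The rewritten min/max query fetches the time range of every table in one round trip by chaining per-table selects with `UNION ALL`; a `NULL` min/max (empty table) is mapped to the `pd.Timestamp.max`/`pd.Timestamp.min` sentinels. A runnable sketch of the generated statement for two hypothetical tables:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE users(signup_ts TEXT);
    INSERT INTO users VALUES ('2024-01-05'), ('2024-03-01');
    CREATE TABLE orders(ts TEXT);  -- left empty: MIN/MAX come back as NULL
""")
sql = (
    "SELECT 'users' as table_name, MIN(signup_ts) as min_date, "
    "MAX(signup_ts) as max_date FROM users\n"
    "UNION ALL\n"
    "SELECT 'orders' as table_name, MIN(ts) as min_date, "
    "MAX(ts) as max_date FROM orders"
)
for name, lo, hi in con.execute(sql):
    print(name, lo, hi)
# users 2024-01-05 2024-03-01
# orders None None   <- becomes (Timestamp.max, Timestamp.min) above
```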
```diff
@@ -126,18 +150,28 @@
     ) -> pd.DataFrame:
         # NOTE SQLite does not natively support passing a `random_seed`.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-
-
-
-
-
-
-
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
 
         # TODO Make this query more efficient - it does full table scan.
-
-
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
@@ -147,7 +181,11 @@
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
```
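Entity sampling keeps the `ORDER BY RANDOM() LIMIT n` fallback (SQLite has no seedable `RANDOM()`, per the NOTE) but now also filters out rows whose key or time column is `NULL` before sampling. A sketch with hypothetical data:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE users(user_id INT, signup_ts TEXT);
    INSERT INTO users VALUES
        (1, '2024-01-01'), (2, NULL), (NULL, '2024-02-01'), (3, '2024-03-01');
""")
sql = ("SELECT user_id, signup_ts\n"
       "FROM users\n"
       "WHERE user_id IS NOT NULL AND signup_ts IS NOT NULL\n"
       "ORDER BY RANDOM() LIMIT 2")
print(con.execute(sql).fetchall())  # two random fully-populated rows
```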
```diff
@@ -190,84 +228,163 @@
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
-
+        source_table = self.source_table_dict[table_name]
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} ent\n")
+        if key in source_table and source_table[key].is_unique_key:
+            sql += (f" ON {key_ref} = tmp.__kumo_id__")
+        else:
+            sql += (f" ON ent.rowid = (\n"
+                    f" SELECT rowid\n"
+                    f" FROM {self.source_name_dict[table_name]}\n"
+                    f" WHERE {key_ref} == tmp.__kumo_id__\n"
+                    f" LIMIT 1\n"
+                    f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
 
-
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f" SELECT rowid\n"
+               f" FROM {self.source_name_dict[table_name]}\n"
+               f" WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" ORDER BY {time_ref} DESC\n"
+        sql += (f" LIMIT {num_neighbors}\n"
+                f")")
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-        tmp = tmp.append_column('
+        tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            tmp = tmp.append_column('
-        tmp_name = f'tmp_{table_name}_{
-
-
-
-
+            tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
               f"FROM {quote_ident(tmp_name)} tmp\n"
-              f"JOIN {self.
-              f" ON
-              f" AND
+               f"JOIN {self.source_name_dict[table_name]}\n"
+               f" ON {key_ref} = tmp.__kumo_id__\n"
+               f" AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
-            sql += f"\n AND
+            sql += f"\n AND {time_ref} > tmp.__kumo_start__"
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     def _sample_target_set(
         self,
```
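The three lookups above share one pattern: seed ids (plus optional per-seed anchor times) are bulk-loaded into a temp table via `cursor.adbc_ingest`, `tmp.rowid - 1` travels through the join as `__kumo_batch__` so every result row maps back to its seed, and in `_by_fkey` the correlated `rowid IN (... ORDER BY time DESC LIMIT k)` subquery caps the neighbors fetched per seed. A stdlib-only re-creation of the `_by_fkey` query shape (schema and values hypothetical):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE fact(user_id INT, ts TEXT);
    INSERT INTO fact VALUES
        (1, '2024-01-01'), (1, '2024-01-02'), (1, '2024-01-03'),
        (2, '2024-01-01');
    CREATE TEMP TABLE tmp(__kumo_id__ INT);  -- stands in for adbc_ingest
    INSERT INTO tmp VALUES (1), (2);
""")
sql = """
SELECT tmp.rowid - 1 AS __kumo_batch__, fact.user_id, fact.ts
FROM tmp
JOIN fact ON fact.rowid IN (
    SELECT rowid FROM fact
    WHERE fact.user_id = tmp.__kumo_id__
    ORDER BY fact.ts DESC
    LIMIT 2  -- at most `num_neighbors` rows per seed
)
"""
for batch, user_id, ts in con.execute(sql):
    print(batch, user_id, ts)
# seed 0 (id=1) yields its two most recent rows; seed 1 (id=2) yields one
```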
```diff
@@ -300,11 +417,11 @@
             query.entity_table: np.arange(len(df)),
         }
         for edge_type, (_min, _max) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=_min,
                 max_offset=_max,
@@ -319,7 +436,7 @@
             feat_dict=feat_dict,
             time_dict=time_dict,
             batch_dict=batch_dict,
-            anchor_time=
+            anchor_time=time,
             num_forecasts=query.num_forecasts,
         )
         ys.append(y)
@@ -337,13 +454,3 @@
         y = pd.concat(ys, axis=0, ignore_index=True)
 
         return y, mask
-
-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-        stype_dict = self.table_stype_dict[table_name]
-        for column_name in df.columns:
-            if stype_dict.get(column_name) == Stype.timestamp:
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        return df
```
kumoai/experimental/rfm/backend/sqlite/table.py:

```diff
@@ -1,5 +1,5 @@
 import re
-import
+from collections import Counter
 from collections.abc import Sequence
 from typing import cast
 
@@ -9,27 +9,25 @@ from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.sqlite import Connection
 from kumoai.experimental.rfm.base import (
-
-
+    ColumnSpec,
+    ColumnSpecType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
-
+    Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 from kumoai.utils import quote_ident
 
 
-class SQLiteTable(
+class SQLiteTable(Table):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`sqlite` database.
-        name: The
-        source_name: The
-            ``
-        columns: The selected
-        column_expressions: The logical columns of this table.
+        name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is being used.
+        columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -40,8 +38,7 @@ class SQLiteTable(SQLTable):
         connection: Connection,
         name: str,
         source_name: str | None = None,
-        columns: Sequence[
-        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
         primary_key: MissingType | str | None = MissingType.VALUE,
         time_column: str | None = None,
         end_time_column: str | None = None,
@@ -53,7 +50,6 @@
             name=name,
             source_name=source_name,
             columns=columns,
-            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
@@ -66,16 +62,16 @@
     def _get_source_columns(self) -> list[SourceColumn]:
         source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA table_info({self.
+            sql = f"PRAGMA table_info({self._quoted_source_name})"
             cursor.execute(sql)
             columns = cursor.fetchall()
 
             if len(columns) == 0:
-                raise ValueError(f"Table '{self.
+                raise ValueError(f"Table '{self.source_name}' does not exist "
                                  f"in the SQLite database")
 
             unique_keys: set[str] = set()
-            sql = f"PRAGMA index_list({self.
+            sql = f"PRAGMA index_list({self._quoted_source_name})"
             cursor.execute(sql)
             for _, index_name, is_unique, *_ in cursor.fetchall():
                 if bool(is_unique):
@@ -85,32 +81,19 @@
                     if len(index) == 1:
                         unique_keys.add(index[0][2])
 
-
-
-
-
-
-
-
-
-                    dtype = Dtype.float
-                else:  # NUMERIC affinity.
-                    ser = self._source_sample_df[column]
-                    try:
-                        dtype = infer_dtype(ser)
-                    except Exception:
-                        warnings.warn(f"Encountered unsupported data type "
-                                      f"'{ser.dtype}' with source data type "
-                                      f"'{type}' for column '{column}' in "
-                                      f"table '{self.name}'. If possible, "
-                                      f"change the data type of the column in "
-                                      f"your SQLite database to use it within "
-                                      f"this table.")
-                        continue
+            # Special SQLite case that creates a rowid alias for
+            # `INTEGER PRIMARY KEY` annotated columns:
+            rowid_candidates = [
+                column for _, column, dtype, _, _, is_pkey in columns
+                if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+            ]
+            if len(rowid_candidates) == 1:
+                unique_keys.add(rowid_candidates[0])
 
+            for _, column, dtype, notnull, _, is_pkey in columns:
                 source_column = SourceColumn(
                     name=column,
-                    dtype=dtype,
+                    dtype=self._to_dtype(dtype),
                     is_primary_key=bool(is_pkey),
                     is_unique_key=column in unique_keys,
                     is_nullable=not bool(is_pkey) and not bool(notnull),
```
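The replacement treats a lone `INTEGER PRIMARY KEY` column as a unique key even though it never shows up in `PRAGMA index_list`: SQLite turns such a column into an alias for the implicit `rowid` rather than backing it with an index. Illustration:

```python
import sqlite3

con = sqlite3.connect(":memory:")
# `id` becomes a rowid alias: unique, but with no PRAGMA index_list entry.
con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
print(con.execute("PRAGMA index_list(t)").fetchall())  # []
con.execute("INSERT INTO t(name) VALUES ('a')")
print(con.execute("SELECT id, rowid FROM t").fetchall())  # [(1, 1)]
```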
```diff
@@ -120,35 +103,82 @@
         return source_columns
 
     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA foreign_key_list({self.
+            sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
             cursor.execute(sql)
-
-
-
+            rows = cursor.fetchall()
+            counts = Counter(row[0] for row in rows)
+            for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                if counts[idx] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=foreign_key,
+                        dst_table=dst_table,
+                        primary_key=primary_key,
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
 
     def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
 
     def _get_num_rows(self) -> int | None:
         return None
 
-    def
+    def _get_expr_sample_df(
         self,
-
+        columns: Sequence[ColumnSpec],
     ) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-            f"{
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
             ]
-            sql = (f"SELECT {', '.join(
-                   f"
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if re.search('INT', dtype):
+            return Dtype.int
+        if re.search('TEXT|CHAR|CLOB', dtype):
+            return Dtype.string
+        if re.search('REAL|FLOA|DOUB', dtype):
+            return Dtype.float
+        return None  # NUMERIC affinity.
```