kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +184 -70
- kumoai/experimental/rfm/backend/snow/table.py +137 -64
- kumoai/experimental/rfm/backend/sqlite/sampler.py +191 -86
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +26 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +182 -19
- kumoai/experimental/rfm/base/table.py +275 -109
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +4 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +530 -304
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +36 -33
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
--- a/kumoai/experimental/rfm/backend/sqlite/sampler.py
+++ b/kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -6,9 +6,9 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
-from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm.
+from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 from kumoai.utils import ProgressLogger, quote_ident
 
@@ -25,22 +25,32 @@ class SQLiteSampler(SQLSampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)
 
+        for table in graph.tables.values():
+            assert isinstance(table, SQLiteTable)
+            self._connection = table._connection
+
         if optimize:
             with self._connection.cursor() as cursor:
                 cursor.execute("PRAGMA temp_store = MEMORY")
                 cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
 
-        # Collect database indices
+        # Collect database indices for speeding sampling:
         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
         for table_name, primary_key in self.primary_key_dict.items():
            source_table = self.source_table_dict[table_name]
-            if not source_table
-
+            if primary_key not in source_table:
+                continue  # No physical column.
+            if source_table[primary_key].is_unique_key:
+                continue
+            index_dict[table_name].add((primary_key, ))
         for src_table_name, foreign_key, _ in graph.edges:
             source_table = self.source_table_dict[src_table_name]
+            if foreign_key not in source_table:
+                continue  # No physical column.
             if source_table[foreign_key].is_unique_key:
-
-
+                continue
+            time_column = self.time_column_dict.get(src_table_name)
+            if time_column is not None and time_column in source_table:
                 index_dict[src_table_name].add((foreign_key, time_column))
             else:
                 index_dict[src_table_name].add((foreign_key, ))
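The constructor above shares a single SQLite connection across all tables and applies two session-level PRAGMAs before collecting index candidates. A minimal sketch of the same tuning, shown on the standard-library sqlite3 driver for illustration (the sampler itself goes through an ADBC connection wrapper):

```python
import sqlite3

con = sqlite3.connect(":memory:")
# Keep temporary B-trees (e.g., for sorting) in RAM instead of on disk:
con.execute("PRAGMA temp_store = MEMORY")
# A negative cache_size is interpreted in KiB: -2000000 is roughly 2 GB
# of page cache for this connection:
con.execute("PRAGMA cache_size = -2000000")
```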
@@ -49,22 +59,22 @@ class SQLiteSampler(SQLSampler):
         with self._connection.cursor() as cursor:
             for table_name in list(index_dict.keys()):
                 indices = index_dict[table_name]
-
+                source_name = self.source_name_dict[table_name]
+                sql = f"PRAGMA index_list({source_name})"
                 cursor.execute(sql)
                 for _, index_name, *_ in cursor.fetchall():
                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
                     cursor.execute(sql)
-                    index
+                    # Fetch index information and sort by `seqno`:
+                    index_info = tuple(info[2] for info in sorted(
                         cursor.fetchall(), key=lambda x: x[0]))
-                    indices
+                    # Remove all indices in case primary index already exists:
+                    for index in list(indices):
+                        if index_info[0] == index[0]:
+                            indices.discard(index)
                 if len(indices) == 0:
                     del index_dict[table_name]
 
-        num = sum(len(indices) for indices in index_dict.values())
-        index_repr = '1 index' if num == 1 else f'{num} indices'
-        num = len(index_dict)
-        table_repr = '1 table' if num == 1 else f'{num} tables'
-
         if optimize and len(index_dict) > 0:
             if not isinstance(verbose, ProgressLogger):
                 verbose = ProgressLogger.default(
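The pruning step above relies on the shape of SQLite's index metadata: `PRAGMA index_list(t)` yields one row per index, and `PRAGMA index_info(name)` yields one `(seqno, cid, name)` row per indexed column, so sorting by the first field and keeping the third recovers the column order. A small sqlite3 sketch of the same introspection:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE t (a INTEGER, b TEXT);
    CREATE INDEX t_a_b ON t(a, b);
""")
for _, index_name, *_ in con.execute("PRAGMA index_list(t)").fetchall():
    rows = con.execute(f"PRAGMA index_info({index_name})").fetchall()
    # Each row is (seqno, cid, name); sorting by seqno recovers column order:
    print(index_name, tuple(name for _, _, name in sorted(rows)))
    # -> t_a_b ('a', 'b')
```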
@@ -79,16 +89,27 @@ class SQLiteSampler(SQLSampler):
                     name = quote_ident(name)
                     columns = ', '.join(quote_ident(v) for v in index)
                     columns += ' DESC' if len(index) > 1 else ''
+                    source_name = self.source_name_dict[table_name]
                     sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
-                           f"ON {
+                           f"ON {source_name}({columns})")
                     cursor.execute(sql)
-
-
+                    self._connection.commit()
+                    if len(index) > 1:
+                        logger.log(f"Created index on {index} in table "
+                                   f"'{table_name}'")
+                    else:
+                        logger.log(f"Created index on '{index[0]}' in "
+                                   f"table '{table_name}'")
 
         elif len(index_dict) > 0:
+            num = sum(len(indices) for indices in index_dict.values())
+            index_repr = '1 index' if num == 1 else f'{num} indices'
+            num = len(index_dict)
+            table_repr = '1 table' if num == 1 else f'{num} tables'
             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                           f"database querying. For improving runtime, we "
-                          f"strongly suggest to create
+                          f"strongly suggest to create indices for primary "
+                          f"and foreign keys, e.g., automatically by "
                           f"instantiating KumoRFM via "
                           f"`KumoRFM(graph, optimize=True)`.")
 
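The generated `CREATE INDEX IF NOT EXISTS` statements mirror the sampler's access paths: equality on a key column, optionally followed by a descending scan over the time column for "most recent neighbors first" lookups. A sketch with hypothetical table and column names:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE orders (user_id INTEGER, created_at TEXT)")
# Composite index matching the sampler's hot path: filter by foreign key,
# then scan the time column in descending order:
con.execute("CREATE INDEX IF NOT EXISTS idx_orders_user_id_created_at "
            "ON orders(user_id, created_at DESC)")
con.commit()
```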
@@ -98,12 +119,13 @@ class SQLiteSampler(SQLSampler):
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
                       f"  ? as table_name,\n"
-                      f"  MIN({
-                      f"  MAX({
-                      f"FROM {self.
+                      f"  MIN({column_ref}) as min_date,\n"
+                      f"  MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
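The date-range query above fans out one `MIN`/`MAX` select per table and fuses them with `UNION ALL`, so all ranges arrive in a single round trip. A self-contained sketch with hypothetical tables:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE orders (created_at TEXT);
    INSERT INTO orders VALUES ('2024-01-01'), ('2024-06-30');
    CREATE TABLE events (ts TEXT);
    INSERT INTO events VALUES ('2023-03-15'), ('2024-02-01');
""")
# One MIN/MAX sub-select per table, fused into a single statement:
sql = "\nUNION ALL\n".join(
    f"SELECT '{name}' AS table_name, MIN({col}) AS min_date, "
    f"MAX({col}) AS max_date FROM {name}"
    for name, col in [("orders", "created_at"), ("events", "ts")])
for row in con.execute(sql):
    print(row)  # ('orders', '2024-01-01', '2024-06-30'), ...
```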
@@ -126,18 +148,28 @@ class SQLiteSampler(SQLSampler):
     ) -> pd.DataFrame:
         # NOTE SQLite does not natively support passing a `random_seed`.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-
-
-
-
-
-
-
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
 
         # TODO Make this query more efficient - it does full table scan.
-
-
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
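`ORDER BY RANDOM() LIMIT n` draws a uniform sample but assigns a random sort key to every row, which is the full table scan the TODO refers to. A quick illustration:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);
    INSERT INTO users (name) VALUES ('a'), ('b'), ('c'), ('d');
""")
# Uniform 2-row sample; cost is O(n) regardless of indices, since every
# row receives a random sort key before the LIMIT is applied:
print(con.execute("SELECT id, name FROM users "
                  "WHERE name IS NOT NULL "
                  "ORDER BY RANDOM() LIMIT 2").fetchall())
```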
@@ -147,7 +179,11 @@ class SQLiteSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
@@ -190,84 +226,163 @@ class SQLiteSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
-
+        source_table = self.source_table_dict[table_name]
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} ent\n")
+        if key in source_table and source_table[key].is_unique_key:
+            sql += (f"  ON {key_ref} = tmp.__kumo_id__")
+        else:
+            sql += (f"  ON ent.rowid = (\n"
+                    f"    SELECT rowid\n"
+                    f"    FROM {self.source_name_dict[table_name]}\n"
+                    f"    WHERE {key_ref} == tmp.__kumo_id__\n"
+                    f"    LIMIT 1\n"
+                    f")")
 
-
-
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f"  SELECT rowid\n"
+               f"  FROM {self.source_name_dict[table_name]}\n"
+               f"  WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  ORDER BY {time_ref} DESC\n"
+        sql += (f"  LIMIT {num_neighbors}\n"
+                f")")
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-        tmp = tmp.append_column('
+        tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            tmp = tmp.append_column('
-        tmp_name = f'tmp_{table_name}_{
-
-
-
-
+            tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
               f"FROM {quote_ident(tmp_name)} tmp\n"
-               f"JOIN {self.
-               f"  ON 
-               f"  AND 
+               f"JOIN {self.source_name_dict[table_name]}\n"
+               f"  ON {key_ref} = tmp.__kumo_id__\n"
+               f"  AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
-            sql += f"\n  AND 
+            sql += f"\n  AND {time_ref} > tmp.__kumo_start__"
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     def _sample_target_set(
         self,
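`_by_pkey`, `_by_fkey`, and `_by_time` all share one round-trip pattern: ingest the lookup keys (plus optional time bounds) as an Arrow table, join the source table against it, and select `tmp.rowid - 1` as `__kumo_batch__` so each result row can be mapped back to its position in the input batch. A minimal sketch of that round trip using the ADBC SQLite driver (`adbc_driver_sqlite`; the table and data here are hypothetical, and the package's own connection wrapper may differ):

```python
import pyarrow as pa
from adbc_driver_sqlite import dbapi

conn = dbapi.connect(":memory:")
with conn.cursor() as cur:
    cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    cur.execute("INSERT INTO users VALUES (1, 'a'), (2, 'b'), (3, 'c')")

    # Ingest the lookup keys as a table ...
    tmp = pa.table([pa.array([3, 1])], names=['__kumo_id__'])
    cur.adbc_ingest('tmp_lookup', tmp, mode='replace')

    # ... and join against it; `tmp.rowid - 1` recovers each key's
    # position in the original batch:
    cur.execute("SELECT tmp.rowid - 1 AS __kumo_batch__, ent.name "
                "FROM tmp_lookup tmp "
                "JOIN users ent ON ent.id = tmp.__kumo_id__")
    print(cur.fetch_arrow_table().to_pydict())
    # e.g. {'__kumo_batch__': [0, 1], 'name': ['c', 'a']}
```

Carrying the batch position through the query keeps the result order-independent: rows can come back in any order and still be scattered into the right output slots.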
@@ -300,11 +415,11 @@ class SQLiteSampler(SQLSampler):
             query.entity_table: np.arange(len(df)),
         }
         for edge_type, (_min, _max) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=_min,
                 max_offset=_max,
@@ -337,13 +452,3 @@ class SQLiteSampler(SQLSampler):
         y = pd.concat(ys, axis=0, ignore_index=True)
 
         return y, mask
-
-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-        stype_dict = self.table_stype_dict[table_name]
-        for column_name in df.columns:
-            if stype_dict.get(column_name) == Stype.timestamp:
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        return df
--- a/kumoai/experimental/rfm/backend/sqlite/table.py
+++ b/kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,5 +1,5 @@
 import re
-import 
+from collections import Counter
 from collections.abc import Sequence
 from typing import cast
 
@@ -9,27 +9,25 @@ from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.sqlite import Connection
 from kumoai.experimental.rfm.base import (
-
-
+    ColumnSpec,
+    ColumnSpecType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
-
+    Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 from kumoai.utils import quote_ident
 
 
-class SQLiteTable(SQLTable):
+class SQLiteTable(Table):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`sqlite` database.
-        name: The
-        source_name: The
-            ``
-        columns: The selected
-        column_expressions: The logical columns of this table.
+        name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is being used.
+        columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -40,8 +38,7 @@ class SQLiteTable(SQLTable):
         connection: Connection,
         name: str,
         source_name: str | None = None,
-        columns: Sequence[
-        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
         primary_key: MissingType | str | None = MissingType.VALUE,
         time_column: str | None = None,
         end_time_column: str | None = None,
@@ -53,7 +50,6 @@ class SQLiteTable(SQLTable):
             name=name,
             source_name=source_name,
             columns=columns,
-            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
@@ -66,16 +62,16 @@ class SQLiteTable(SQLTable):
     def _get_source_columns(self) -> list[SourceColumn]:
         source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA table_info({self.
+            sql = f"PRAGMA table_info({self._quoted_source_name})"
             cursor.execute(sql)
             columns = cursor.fetchall()
 
             if len(columns) == 0:
-                raise ValueError(f"Table '{self.
+                raise ValueError(f"Table '{self.source_name}' does not exist "
                                  f"in the SQLite database")
 
             unique_keys: set[str] = set()
-            sql = f"PRAGMA index_list({self.
+            sql = f"PRAGMA index_list({self._quoted_source_name})"
             cursor.execute(sql)
             for _, index_name, is_unique, *_ in cursor.fetchall():
                 if bool(is_unique):
@@ -85,32 +81,19 @@ class SQLiteTable(SQLTable):
                 if len(index) == 1:
                     unique_keys.add(index[0][2])
 
-
-
-
-
-
-
-
-
-                    dtype = Dtype.float
-                else:  # NUMERIC affinity.
-                    ser = self._source_sample_df[column]
-                    try:
-                        dtype = infer_dtype(ser)
-                    except Exception:
-                        warnings.warn(f"Encountered unsupported data type "
-                                      f"'{ser.dtype}' with source data type "
-                                      f"'{type}' for column '{column}' in "
-                                      f"table '{self.name}'. If possible, "
-                                      f"change the data type of the column in "
-                                      f"your SQLite database to use it within "
-                                      f"this table.")
-                        continue
+            # Special SQLite case that creates a rowid alias for
+            # `INTEGER PRIMARY KEY` annotated columns:
+            rowid_candidates = [
+                column for _, column, dtype, _, _, is_pkey in columns
+                if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+            ]
+            if len(rowid_candidates) == 1:
+                unique_keys.add(rowid_candidates[0])
 
+            for _, column, dtype, notnull, _, is_pkey in columns:
                 source_column = SourceColumn(
                     name=column,
-                    dtype=dtype,
+                    dtype=self._to_dtype(dtype),
                     is_primary_key=bool(is_pkey),
                     is_unique_key=column in unique_keys,
                     is_nullable=not bool(is_pkey) and not bool(notnull),
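The `rowid_candidates` special case reflects documented SQLite behavior: a column declared exactly `INTEGER PRIMARY KEY` becomes an alias for the implicit rowid, so it is unique even though `PRAGMA index_list` reports no index for it. A quick demonstration:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)")
con.execute("INSERT INTO t (v) VALUES ('a'), ('b')")
# No separate index backs the primary key ...
print(con.execute("PRAGMA index_list(t)").fetchall())     # []
# ... because `id` is just another name for the rowid:
print(con.execute("SELECT id, rowid FROM t").fetchall())  # [(1, 1), (2, 2)]
```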
@@ -120,35 +103,82 @@ class SQLiteTable(SQLTable):
         return source_columns
 
     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA foreign_key_list({self.
+            sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
             cursor.execute(sql)
-
-
-
+            rows = cursor.fetchall()
+            counts = Counter(row[0] for row in rows)
+            for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                if counts[idx] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=foreign_key,
+                        dst_table=dst_table,
+                        primary_key=primary_key,
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
 
     def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
 
     def _get_num_rows(self) -> int | None:
         return None
 
-    def 
+    def _get_expr_sample_df(
         self,
-
+        columns: Sequence[ColumnSpec],
     ) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-                f"{
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
             ]
-            sql = (f"SELECT {', '.join(
-                   f"
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if re.search('INT', dtype):
+            return Dtype.int
+        if re.search('TEXT|CHAR|CLOB', dtype):
+            return Dtype.string
+        if re.search('REAL|FLOA|DOUB', dtype):
+            return Dtype.float
+        return None  # NUMERIC affinity.
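`_to_dtype` follows SQLite's declared-type affinity rules: substring matches on `INT`, then `TEXT`/`CHAR`/`CLOB`, then `REAL`/`FLOA`/`DOUB`, with everything else falling through to NUMERIC affinity (returned as `None` and resolved later from sampled data). A standalone sketch of the same mapping, with plain strings standing in for `kumoapi.typing.Dtype`:

```python
import re

def to_affinity(decl: str) -> str:
    # Mirror of `_to_dtype` above, minus the Dtype wrapper:
    decl = decl.strip().upper()
    if re.search('INT', decl):
        return 'INTEGER'
    if re.search('TEXT|CHAR|CLOB', decl):
        return 'TEXT'
    if re.search('REAL|FLOA|DOUB', decl):
        return 'REAL'
    return 'NUMERIC'

assert to_affinity('BIGINT') == 'INTEGER'
assert to_affinity('VARCHAR(255)') == 'TEXT'
assert to_affinity('DOUBLE PRECISION') == 'REAL'
assert to_affinity('DECIMAL(10, 2)') == 'NUMERIC'
```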
--- a/kumoai/experimental/rfm/base/__init__.py
+++ b/kumoai/experimental/rfm/base/__init__.py
@@ -8,12 +8,9 @@ class DataBackend(StrEnum):
 
 
 from .source import SourceColumn, SourceForeignKey  # noqa: E402
-from .
-from .
-from .column_expression import ColumnExpressionType  # noqa: E402
-from .column_expression import ColumnExpression  # noqa: E402
+from .expression import Expression, LocalExpression  # noqa: E402
+from .column import ColumnSpec, ColumnSpecType, Column  # noqa: E402
 from .table import Table  # noqa: E402
-from .sql_table import SQLTable  # noqa: E402
 from .sampler import SamplerOutput, Sampler  # noqa: E402
 from .sql_sampler import SQLSampler  # noqa: E402
 
@@ -21,12 +18,12 @@ __all__ = [
     'DataBackend',
     'SourceColumn',
     'SourceForeignKey',
+    'Expression',
+    'LocalExpression',
+    'ColumnSpec',
+    'ColumnSpecType',
     'Column',
-    'ColumnExpressionSpec',
-    'ColumnExpressionType',
-    'ColumnExpression',
     'Table',
-    'SQLTable',
     'SamplerOutput',
     'Sampler',
     'SQLSampler',