kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/experimental/rfm/__init__.py +22 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +25 -24
- kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
- kumoai/experimental/rfm/backend/snow/table.py +146 -51
- kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
- kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
- kumoai/experimental/rfm/base/__init__.py +6 -7
- kumoai/experimental/rfm/base/column.py +97 -5
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +5 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +68 -9
- kumoai/experimental/rfm/base/table.py +284 -120
- kumoai/experimental/rfm/graph.py +139 -86
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +6 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +4 -20
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +51 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +33 -30
- kumoai/experimental/rfm/base/column_expression.py +0 -16
- kumoai/experimental/rfm/base/sql_table.py +0 -113
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py (+127 -78):

@@ -6,9 +6,9 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
-from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm.
+from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 from kumoai.utils import ProgressLogger, quote_ident
 
@@ -25,22 +25,32 @@ class SQLiteSampler(SQLSampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)
 
+        for table in graph.tables.values():
+            assert isinstance(table, SQLiteTable)
+            self._connection = table._connection
+
         if optimize:
             with self._connection.cursor() as cursor:
                 cursor.execute("PRAGMA temp_store = MEMORY")
                 cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
 
-        # Collect database indices
+        # Collect database indices for speeding sampling:
        index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
        for table_name, primary_key in self.primary_key_dict.items():
            source_table = self.source_table_dict[table_name]
-            if not source_table
-
+            if primary_key not in source_table:
+                continue  # No physical column.
+            if source_table[primary_key].is_unique_key:
+                continue
+            index_dict[table_name].add((primary_key, ))
        for src_table_name, foreign_key, _ in graph.edges:
            source_table = self.source_table_dict[src_table_name]
+            if foreign_key not in source_table:
+                continue  # No physical column.
            if source_table[foreign_key].is_unique_key:
-
-
+                continue
+            time_column = self.time_column_dict.get(src_table_name)
+            if time_column is not None and time_column in source_table:
                index_dict[src_table_name].add((foreign_key, time_column))
            else:
                index_dict[src_table_name].add((foreign_key, ))
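For context on the two PRAGMAs applied above: `temp_store = MEMORY` keeps temporary tables and indices in RAM, and a negative `cache_size` is interpreted by SQLite as a size in KiB, so `-2000000` is roughly 2 GB of page cache. A minimal stdlib `sqlite3` sketch (database and values are illustrative, not from the package):

```python
import sqlite3

con = sqlite3.connect(":memory:")  # illustrative database
cur = con.cursor()

# Keep temporary tables/indices in RAM instead of on disk:
cur.execute("PRAGMA temp_store = MEMORY")
# A negative cache_size means KiB, so -2000000 is ~2 GB of page cache:
cur.execute("PRAGMA cache_size = -2000000")
print(cur.execute("PRAGMA cache_size").fetchone())  # (-2000000,)
```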
@@ -49,22 +59,22 @@ class SQLiteSampler(SQLSampler):
         with self._connection.cursor() as cursor:
             for table_name in list(index_dict.keys()):
                 indices = index_dict[table_name]
-
+                source_name = self.source_name_dict[table_name]
+                sql = f"PRAGMA index_list({source_name})"
                 cursor.execute(sql)
                 for _, index_name, *_ in cursor.fetchall():
                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
                     cursor.execute(sql)
-                    index
+                    # Fetch index information and sort by `seqno`:
+                    index_info = tuple(info[2] for info in sorted(
                         cursor.fetchall(), key=lambda x: x[0]))
-                    indices
+                    # Remove all indices in case primary index already exists:
+                    for index in list(indices):
+                        if index_info[0] == index[0]:
+                            indices.discard(index)
                 if len(indices) == 0:
                     del index_dict[table_name]
 
-        num = sum(len(indices) for indices in index_dict.values())
-        index_repr = '1 index' if num == 1 else f'{num} indices'
-        num = len(index_dict)
-        table_repr = '1 table' if num == 1 else f'{num} tables'
-
         if optimize and len(index_dict) > 0:
             if not isinstance(verbose, ProgressLogger):
                 verbose = ProgressLogger.default(
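`PRAGMA index_list` and `PRAGMA index_info` are SQLite's standard index-introspection hooks; the hunk above uses them to drop any planned index whose leading column is already covered by an existing one. A self-contained sketch of that introspection (schema invented for illustration):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER, country TEXT)")
con.execute("CREATE INDEX users_by_id ON users(id)")

for _, index_name, *_ in con.execute("PRAGMA index_list(users)"):
    # index_info rows are (seqno, cid, name); sorting tuples sorts by
    # seqno first, so info[0][2] is the index's leading column:
    info = sorted(con.execute(f"PRAGMA index_info({index_name})"))
    print(index_name, "->", info[0][2])  # users_by_id -> id
```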
@@ -79,16 +89,27 @@ class SQLiteSampler(SQLSampler):
                     name = quote_ident(name)
                     columns = ', '.join(quote_ident(v) for v in index)
                     columns += ' DESC' if len(index) > 1 else ''
+                    source_name = self.source_name_dict[table_name]
                     sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
-                           f"ON {
+                           f"ON {source_name}({columns})")
                     cursor.execute(sql)
-
-
+                    self._connection.commit()
+                    if len(index) > 1:
+                        logger.log(f"Created index on {index} in table "
+                                   f"'{table_name}'")
+                    else:
+                        logger.log(f"Created index on '{index[0]}' in "
+                                   f"table '{table_name}'")
 
         elif len(index_dict) > 0:
+            num = sum(len(indices) for indices in index_dict.values())
+            index_repr = '1 index' if num == 1 else f'{num} indices'
+            num = len(index_dict)
+            table_repr = '1 table' if num == 1 else f'{num} tables'
             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                           f"database querying. For improving runtime, we "
-                          f"strongly suggest to create
+                          f"strongly suggest to create indices for primary "
+                          f"and foreign keys, e.g., automatically by "
                           f"instantiating KumoRFM via "
                           f"`KumoRFM(graph, optimize=True)`.")
 
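For `(foreign_key, time_column)` pairs the sampler creates a composite index with a trailing `DESC` column, which suits equality lookups on the key combined with time-bounded scans. A hedged stdlib sketch of the statement shape it issues (table and column names invented):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE events (user_id INTEGER, ts TEXT)")

# Composite index: equality lookups on user_id, ordered scans on ts.
# DESC on the trailing column mirrors what the sampler generates:
con.execute("CREATE INDEX IF NOT EXISTS events_user_ts "
            "ON events(user_id, ts DESC)")
con.commit()
```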
@@ -98,12 +119,13 @@ class SQLiteSampler(SQLSampler):
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
                       f" ? as table_name,\n"
-                      f" MIN({
-                      f" MAX({
-                      f"FROM {self.
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
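The time-range query aggregates `MIN`/`MAX` per table and stitches the per-table SELECTs together with `UNION ALL`, so all bounds come back in one round trip. A runnable sketch of the same pattern (tables and data invented):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE orders (ts TEXT)")
con.execute("CREATE TABLE visits (ts TEXT)")
con.executemany("INSERT INTO orders VALUES (?)",
                [("2024-01-01",), ("2024-06-30",)])
con.executemany("INSERT INTO visits VALUES (?)", [("2024-03-15",)])

# One SELECT per table, joined by UNION ALL; the `?` placeholders
# bind the table names into the result rows, in order:
sql = "\nUNION ALL\n".join(
    f"SELECT ? as table_name, MIN(ts) as min_date, MAX(ts) as max_date "
    f"FROM {name}" for name in ("orders", "visits"))
print(con.execute(sql, ("orders", "visits")).fetchall())
# [('orders', '2024-01-01', '2024-06-30'),
#  ('visits', '2024-03-15', '2024-03-15')]
```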
@@ -126,18 +148,28 @@ class SQLiteSampler(SQLSampler):
     ) -> pd.DataFrame:
         # NOTE SQLite does not natively support passing a `random_seed`.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-
-
-
-
-
-
-
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
 
         # TODO Make this query more efficient - it does full table scan.
-
-
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
@@ -147,7 +179,11 @@ class SQLiteSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
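Entity sampling now adds `IS NOT NULL` filters only for key and time columns that can actually be null, then orders by `RANDOM()` with a `LIMIT` (a full table scan, as the TODO concedes). A stdlib sketch of the generated SQL shape (schema and rows invented):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER, signup_ts TEXT)")
con.executemany("INSERT INTO users VALUES (?, ?)",
                [(1, "2024-01-01"), (2, None), (None, "2024-02-01")])

# Filters carry a leading space so that " AND".join(...) composes:
filters = [" id IS NOT NULL", " signup_ts IS NOT NULL"]
sql = ("SELECT id, signup_ts FROM users"
       + ("\nWHERE" + " AND".join(filters) if filters else "")
       + "\nORDER BY RANDOM() LIMIT 2")
print(con.execute(sql).fetchall())  # only (1, '2024-01-01') qualifies
```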
@@ -193,37 +229,47 @@ class SQLiteSampler(SQLSampler):
         pkey: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
-
-
-
-
+        source_table = self.source_table_dict[table_name]
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
+        tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} ent\n")
 
-        if
-        sql
-            f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-            f"FROM {quote_ident(tmp_name)} tmp\n"
-            f"JOIN {self.fqn_dict[table_name]} ent\n"
-            f" ON ent.{quote_ident(pkey_name)} = tmp.id")
+        if key in source_table and source_table[key].is_unique_key:
+            sql += (f" ON {key_ref} = tmp.__kumo_id__")
         else:
-            sql
-
-
-
-
-
-            f" WHERE {quote_ident(pkey_name)} == tmp.id\n"
-            f" LIMIT 1\n"
-            f")")
+            sql += (f" ON ent.rowid = (\n"
+                    f" SELECT rowid\n"
+                    f" FROM {self.source_name_dict[table_name]}\n"
+                    f" WHERE {key_ref} == tmp.__kumo_id__\n"
+                    f" LIMIT 1\n"
+                    f")")
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
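The `tmp.rowid - 1 as __kumo_batch__` trick is worth spelling out: the requested keys are ingested into a temporary table, and because SQLite assigns 1-based rowids in insertion order, `rowid - 1` recovers each key's position in the input batch after the join. A stdlib sketch of the idea (the real code ingests Arrow tables via `cursor.adbc_ingest`; names here are invented):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
con.executemany("INSERT INTO users VALUES (?, ?)",
                [(10, "a"), (20, "b"), (30, "c")])

pkeys = [30, 10]  # batch order that must be recoverable
con.execute("CREATE TEMP TABLE tmp (__kumo_id__ INTEGER)")
con.executemany("INSERT INTO tmp VALUES (?)", [(k,) for k in pkeys])

# rowid - 1 yields each key's 0-based position in the input batch:
sql = ("SELECT tmp.rowid - 1 AS __kumo_batch__, ent.name\n"
       "FROM tmp JOIN users ent ON ent.id = tmp.__kumo_id__")
print(sorted(con.execute(sql).fetchall()))  # [(0, 'c'), (1, 'a')]
```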
@@ -237,37 +283,50 @@ class SQLiteSampler(SQLSampler):
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(pkey)], names=['
+        tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-        tmp = tmp.append_column('
+        tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            tmp = tmp.append_column('
+            tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
         tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
 
-
-
-
+        key_ref = self.table_column_ref_dict[table_name][fkey]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
               f"FROM {quote_ident(tmp_name)} tmp\n"
-              f"JOIN {self.
-              f" ON
-              f" AND
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f" ON {key_ref} = tmp.__kumo_id__\n"
+               f" AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
-            sql += f"\n AND
+            sql += f"\n AND {time_ref} > tmp.__kumo_start__"
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['
-
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     def _sample_target_set(
         self,
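Since SQLite has no native datetime type, anchor and offset bounds are serialized to `%Y-%m-%d %H:%M:%S` TEXT and compared lexicographically in the join condition, which is only sound because such strings sort chronologically. Sketch (schema and values invented):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE events (user_id INTEGER, ts TEXT)")
con.executemany("INSERT INTO events VALUES (?, ?)", [
    (1, "2024-01-05 00:00:00"),
    (1, "2024-02-05 00:00:00"),
])

# Per-key end times travel alongside the keys in the temp table:
con.execute("CREATE TEMP TABLE tmp (__kumo_id__ INTEGER, __kumo_end__ TEXT)")
con.execute("INSERT INTO tmp VALUES (1, '2024-01-31 23:59:59')")

sql = ("SELECT tmp.rowid - 1 AS __kumo_batch__, fact.ts\n"
       "FROM tmp JOIN events fact\n"
       " ON fact.user_id = tmp.__kumo_id__\n"
       " AND fact.ts <= tmp.__kumo_end__")  # TEXT comparison, sorts by time
print(con.execute(sql).fetchall())  # [(0, '2024-01-05 00:00:00')]
```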
@@ -337,13 +396,3 @@ class SQLiteSampler(SQLSampler):
         y = pd.concat(ys, axis=0, ignore_index=True)
 
         return y, mask
-
-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-        stype_dict = self.table_stype_dict[table_name]
-        for column_name in df.columns:
-            if stype_dict.get(column_name) == Stype.timestamp:
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        return df
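The per-sampler `_sanitize` helper is deleted here; judging by the call sites above, its role moves to `Table._sanitize`, which takes explicit `dtype_dict`/`stype_dict` arguments. For reference, a simplified sketch of what the removed helper did for timestamp columns (a plain string stands in for the package's `Stype` enum):

```python
import pandas as pd

# Sketch of the removed behavior: coerce columns whose semantic type
# is "timestamp" into proper datetime values after the Arrow round trip.
def sanitize_timestamps(df: pd.DataFrame, stype_dict: dict) -> pd.DataFrame:
    for column_name in df.columns:
        if stype_dict.get(column_name) == "timestamp":
            df[column_name] = pd.to_datetime(df[column_name])
    return df
```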
kumoai/experimental/rfm/backend/sqlite/table.py (+94 -47):

@@ -1,5 +1,5 @@
 import re
-import
+from collections import Counter
 from collections.abc import Sequence
 from typing import cast
 
@@ -9,26 +9,25 @@ from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.sqlite import Connection
 from kumoai.experimental.rfm.base import (
-
+    ColumnSpec,
+    ColumnSpecType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
-
+    Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 from kumoai.utils import quote_ident
 
 
-class SQLiteTable(
+class SQLiteTable(Table):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`sqlite` database.
-        name: The
-        source_name: The
-        ``
-        columns: The selected
-        column_expressions: The logical columns of this table.
+        name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is being used.
+        columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -39,8 +38,7 @@ class SQLiteTable(SQLTable):
         connection: Connection,
         name: str,
         source_name: str | None = None,
-        columns: Sequence[
-        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
         primary_key: MissingType | str | None = MissingType.VALUE,
         time_column: str | None = None,
         end_time_column: str | None = None,
@@ -52,7 +50,6 @@ class SQLiteTable(SQLTable):
             name=name,
             source_name=source_name,
             columns=columns,
-            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
@@ -65,16 +62,16 @@ class SQLiteTable(SQLTable):
     def _get_source_columns(self) -> list[SourceColumn]:
         source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA table_info({self.
+            sql = f"PRAGMA table_info({self._quoted_source_name})"
             cursor.execute(sql)
             columns = cursor.fetchall()
 
             if len(columns) == 0:
-                raise ValueError(f"Table '{self.
+                raise ValueError(f"Table '{self.source_name}' does not exist "
                                  f"in the SQLite database")
 
             unique_keys: set[str] = set()
-            sql = f"PRAGMA index_list({self.
+            sql = f"PRAGMA index_list({self._quoted_source_name})"
             cursor.execute(sql)
             for _, index_name, is_unique, *_ in cursor.fetchall():
                 if bool(is_unique):
@@ -84,30 +81,19 @@ class SQLiteTable(SQLTable):
                     if len(index) == 1:
                         unique_keys.add(index[0][2])
 
-
-
-
-
-
-
-
-
-                        dtype = Dtype.float
-                    else:  # NUMERIC affinity.
-                        ser = self._sample_df[column]
-                        try:
-                            dtype = infer_dtype(ser)
-                        except Exception:
-                            warnings.warn(
-                                f"Data type inference for column '{column}' in "
-                                f"table '{self.name}' failed. Consider changing "
-                                f"the data type of the column in the database or "
-                                f"remove this column from this table.")
-                            continue
+            # Special SQLite case that creates a rowid alias for
+            # `INTEGER PRIMARY KEY` annotated columns:
+            rowid_candidates = [
+                column for _, column, dtype, _, _, is_pkey in columns
+                if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+            ]
+            if len(rowid_candidates) == 1:
+                unique_keys.add(rowid_candidates[0])
 
+            for _, column, dtype, notnull, _, is_pkey in columns:
                 source_column = SourceColumn(
                     name=column,
-                    dtype=dtype,
+                    dtype=self._to_dtype(dtype),
                     is_primary_key=bool(is_pkey),
                     is_unique_key=column in unique_keys,
                     is_nullable=not bool(is_pkey) and not bool(notnull),
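Background on the rowid-alias check above: in SQLite, a column declared `INTEGER PRIMARY KEY` is an alias for `rowid` and is therefore unique even though no separate unique index appears in `PRAGMA index_list`; the hunk detects that case directly from `PRAGMA table_info` output. A stdlib sketch (schema invented):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")

# PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk):
columns = con.execute("PRAGMA table_info(users)").fetchall()
rowid_candidates = [
    name for _, name, dtype, _, _, is_pk in columns
    if bool(is_pk) and dtype.strip().upper() == 'INTEGER'
]
print(rowid_candidates)  # ['id'] -- aliases rowid, hence unique
```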
@@ -117,21 +103,82 @@ class SQLiteTable(SQLTable):
         return source_columns
 
     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA foreign_key_list({self.
+            sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
             cursor.execute(sql)
-
-
-
-
-
+            rows = cursor.fetchall()
+            counts = Counter(row[0] for row in rows)
+            for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                if counts[idx] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=foreign_key,
+                        dst_table=dst_table,
+                        primary_key=primary_key,
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
+
+    def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
 
     def _get_num_rows(self) -> int | None:
         return None
+
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
+            ]
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if re.search('INT', dtype):
+            return Dtype.int
+        if re.search('TEXT|CHAR|CLOB', dtype):
+            return Dtype.string
+        if re.search('REAL|FLOA|DOUB', dtype):
+            return Dtype.float
+        return None  # NUMERIC affinity.
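The new `_to_dtype` mirrors SQLite's type-affinity rules: declared types containing `INT` map to integers, `TEXT`/`CHAR`/`CLOB` to strings, `REAL`/`FLOA`/`DOUB` to floats, and anything else (NUMERIC affinity) is left undecided, presumably for sample-based inference. A simplified, self-contained restatement with plain strings in place of the package's `Dtype` enum:

```python
import re

# Simplified re-statement of SQLiteTable._to_dtype; returns plain
# strings instead of the package's Dtype enum members:
def to_dtype(decl: str | None) -> str | None:
    if decl is None:
        return None
    decl = decl.strip().upper()
    if re.search('INT', decl):
        return 'int'
    if re.search('TEXT|CHAR|CLOB', decl):
        return 'string'
    if re.search('REAL|FLOA|DOUB', decl):
        return 'float'
    return None  # NUMERIC affinity: defer to inference

print(to_dtype('VARCHAR(255)'))  # string
print(to_dtype('BIGINT'))        # int
```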
kumoai/experimental/rfm/base/__init__.py (+6 -7):

@@ -8,11 +8,9 @@ class DataBackend(StrEnum):
 
 
 from .source import SourceColumn, SourceForeignKey  # noqa: E402
-from .
-from .
-from .column_expression import ColumnExpressionType  # noqa: E402
+from .expression import Expression, LocalExpression  # noqa: E402
+from .column import ColumnSpec, ColumnSpecType, Column  # noqa: E402
 from .table import Table  # noqa: E402
-from .sql_table import SQLTable  # noqa: E402
 from .sampler import SamplerOutput, Sampler  # noqa: E402
 from .sql_sampler import SQLSampler  # noqa: E402
 
@@ -20,11 +18,12 @@ __all__ = [
     'DataBackend',
     'SourceColumn',
     'SourceForeignKey',
+    'Expression',
+    'LocalExpression',
+    'ColumnSpec',
+    'ColumnSpecType',
     'Column',
-    'ColumnExpressionSpec',
-    'ColumnExpressionType',
     'Table',
-    'SQLTable',
     'SamplerOutput',
     'Sampler',
     'SQLSampler',