kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (36)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/sampler.py +5 -17
  18. kumoai/experimental/rfm/base/source.py +1 -1
  19. kumoai/experimental/rfm/base/sql_sampler.py +69 -9
  20. kumoai/experimental/rfm/base/table.py +258 -97
  21. kumoai/experimental/rfm/graph.py +106 -98
  22. kumoai/experimental/rfm/infer/dtype.py +4 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/relbench.py +76 -0
  25. kumoai/experimental/rfm/rfm.py +394 -241
  26. kumoai/experimental/rfm/task_table.py +290 -0
  27. kumoai/trainer/distilled_trainer.py +175 -0
  28. kumoai/utils/display.py +51 -0
  29. kumoai/utils/progress_logger.py +13 -1
  30. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +1 -1
  31. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +34 -31
  32. kumoai/experimental/rfm/base/column_expression.py +0 -50
  33. kumoai/experimental/rfm/base/sql_table.py +0 -229
  34. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
  35. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
  36. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py

@@ -6,9 +6,9 @@ import numpy as np
  import pandas as pd
  import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery
- from kumoapi.typing import Stype

- from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+ from kumoai.experimental.rfm.base import SQLSampler, Table
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
  from kumoai.utils import ProgressLogger, quote_ident

@@ -25,22 +25,32 @@ class SQLiteSampler(SQLSampler):
      ) -> None:
          super().__init__(graph=graph, verbose=verbose)

+         for table in graph.tables.values():
+             assert isinstance(table, SQLiteTable)
+             self._connection = table._connection
+
          if optimize:
              with self._connection.cursor() as cursor:
                  cursor.execute("PRAGMA temp_store = MEMORY")
                  cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB

-         # Collect database indices to speed-up sampling:
+         # Collect database indices to speed up sampling:
          index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
          for table_name, primary_key in self.primary_key_dict.items():
              source_table = self.source_table_dict[table_name]
-             if not source_table[primary_key].is_unique_key:
-                 index_dict[table_name].add((primary_key, ))
+             if primary_key not in source_table:
+                 continue  # No physical column.
+             if source_table[primary_key].is_unique_key:
+                 continue
+             index_dict[table_name].add((primary_key, ))
          for src_table_name, foreign_key, _ in graph.edges:
              source_table = self.source_table_dict[src_table_name]
+             if foreign_key not in source_table:
+                 continue  # No physical column.
              if source_table[foreign_key].is_unique_key:
-                 pass
-             elif time_column := self.time_column_dict.get(src_table_name):
+                 continue
+             time_column = self.time_column_dict.get(src_table_name)
+             if time_column is not None and time_column in source_table:
                  index_dict[src_table_name].add((foreign_key, time_column))
              else:
                  index_dict[src_table_name].add((foreign_key, ))
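
The two PRAGMAs above are standard SQLite tuning knobs: `temp_store = MEMORY` keeps temporary tables and indices in RAM, and a negative `cache_size` is interpreted in KiB, so `-2000000` requests roughly 2 GB of page cache. A minimal stdlib `sqlite3` sketch (the sampler itself talks to SQLite through an ADBC connection):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("PRAGMA temp_store = MEMORY")    # temp tables/indices in RAM
    con.execute("PRAGMA cache_size = -2000000")  # negative value = KiB, ~2 GB
    print(con.execute("PRAGMA cache_size").fetchone())  # (-2000000,)
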
@@ -49,22 +59,22 @@ class SQLiteSampler(SQLSampler):
          with self._connection.cursor() as cursor:
              for table_name in list(index_dict.keys()):
                  indices = index_dict[table_name]
-                 sql = f"PRAGMA index_list({self.fqn_dict[table_name]})"
+                 source_name = self.source_name_dict[table_name]
+                 sql = f"PRAGMA index_list({source_name})"
                  cursor.execute(sql)
                  for _, index_name, *_ in cursor.fetchall():
                      sql = f"PRAGMA index_info({quote_ident(index_name)})"
                      cursor.execute(sql)
-                     index = tuple(info[2] for info in sorted(
+                     # Fetch index information and sort by `seqno`:
+                     index_info = tuple(info[2] for info in sorted(
                          cursor.fetchall(), key=lambda x: x[0]))
-                     indices.discard(index)
+                     # Drop indices whose leading column is already indexed:
+                     for index in list(indices):
+                         if index_info[0] == index[0]:
+                             indices.discard(index)
                  if len(indices) == 0:
                      del index_dict[table_name]

-         num = sum(len(indices) for indices in index_dict.values())
-         index_repr = '1 index' if num == 1 else f'{num} indices'
-         num = len(index_dict)
-         table_repr = '1 table' if num == 1 else f'{num} tables'
-
          if optimize and len(index_dict) > 0:
              if not isinstance(verbose, ProgressLogger):
                  verbose = ProgressLogger.default(
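
The discard logic above is driven by two introspection PRAGMAs: `PRAGMA index_list` enumerates a table's indices, and `PRAGMA index_info` yields `(seqno, cid, name)` rows per index, which the code sorts by `seqno` to recover the indexed columns in order. A small `sqlite3` sketch of the same introspection (hypothetical table and index names):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE fact (fkey INTEGER, t TEXT)")
    con.execute("CREATE INDEX fact_by_fkey ON fact (fkey, t DESC)")

    for _, index_name, *_ in con.execute("PRAGMA index_list(fact)"):
        info = con.execute(f"PRAGMA index_info('{index_name}')").fetchall()
        # Sort by seqno (first tuple element) to get index column order:
        print(index_name, tuple(name for _, _, name in sorted(info)))
        # -> fact_by_fkey ('fkey', 't')
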
@@ -79,16 +89,27 @@ class SQLiteSampler(SQLSampler):
                      name = quote_ident(name)
                      columns = ', '.join(quote_ident(v) for v in index)
                      columns += ' DESC' if len(index) > 1 else ''
+                     source_name = self.source_name_dict[table_name]
                      sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
-                            f"ON {self.fqn_dict[table_name]}({columns})")
+                            f"ON {source_name}({columns})")
                      cursor.execute(sql)
-             self._connection.commit()
-             logger.log(f"Created {index_repr} in {table_repr}")
+                     self._connection.commit()
+                     if len(index) > 1:
+                         logger.log(f"Created index on {index} in table "
+                                    f"'{table_name}'")
+                     else:
+                         logger.log(f"Created index on '{index[0]}' in "
+                                    f"table '{table_name}'")

          elif len(index_dict) > 0:
+             num = sum(len(indices) for indices in index_dict.values())
+             index_repr = '1 index' if num == 1 else f'{num} indices'
+             num = len(index_dict)
+             table_repr = '1 table' if num == 1 else f'{num} tables'
              warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                            f"database querying. For improving runtime, we "
-                           f"strongly suggest to create these indices by "
+                           f"strongly suggest creating indices for primary "
+                           f"and foreign keys, e.g., automatically by "
                            f"instantiating KumoRFM via "
                            f"`KumoRFM(graph, optimize=True)`.")

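
A note on the generated DDL: `DESC` is appended after the joined column list, so it applies only to the trailing time column of a composite `(foreign_key, time_column)` index, and `IF NOT EXISTS` makes the optimization pass idempotent. A sketch of a hypothetical resulting statement:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE fact (fkey INTEGER, t TEXT)")
    ddl = ('CREATE INDEX IF NOT EXISTS "fact_fkey_t"\n'
           'ON fact("fkey", "t" DESC)')
    con.execute(ddl)
    con.execute(ddl)  # no-op on the second run
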
@@ -98,12 +119,13 @@ class SQLiteSampler(SQLSampler):
      ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
          selects: list[str] = []
          for table_name in table_names:
-             time_column = self.time_column_dict[table_name]
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
              select = (f"SELECT\n"
                        f"  ? as table_name,\n"
-                       f"  MIN({quote_ident(time_column)}) as min_date,\n"
-                       f"  MAX({quote_ident(time_column)}) as max_date\n"
-                       f"FROM {self.fqn_dict[table_name]}")
+                       f"  MIN({column_ref}) as min_date,\n"
+                       f"  MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
              selects.append(select)
          sql = "\nUNION ALL\n".join(selects)

@@ -126,18 +148,28 @@ class SQLiteSampler(SQLSampler):
      ) -> pd.DataFrame:
          # NOTE SQLite does not natively support passing a `random_seed`.

+         source_table = self.source_table_dict[table_name]
          filters: list[str] = []
-         primary_key = self.primary_key_dict[table_name]
-         if self.source_table_dict[table_name][primary_key].is_nullable:
-             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-         time_column = self.time_column_dict.get(table_name)
-         if (time_column is not None and
-                 self.source_table_dict[table_name][time_column].is_nullable):
-             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")

          # TODO Make this query more efficient - it does full table scan.
-         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-                f"FROM {self.fqn_dict[table_name]}")
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}")
          if len(filters) > 0:
              sql += f"\nWHERE{' AND'.join(filters)}"
          sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
@@ -147,7 +179,11 @@ class SQLiteSampler(SQLSampler):
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         return self._sanitize(table_name, table)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )

      def _sample_target(
          self,
@@ -193,37 +229,47 @@ class SQLiteSampler(SQLSampler):
          pkey: pd.Series,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
-         pkey_name = self.primary_key_dict[table_name]
-
-         tmp = pa.table([pa.array(pkey)], names=['id'])
-         tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+         source_table = self.source_table_dict[table_name]
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
+         tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.source_name_dict[table_name]} ent\n")

-         if self.source_table_dict[table_name][pkey_name].is_unique_key:
-             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-                    f"FROM {quote_ident(tmp_name)} tmp\n"
-                    f"JOIN {self.fqn_dict[table_name]} ent\n"
-                    f"  ON ent.{quote_ident(pkey_name)} = tmp.id")
+         if key in source_table and source_table[key].is_unique_key:
+             sql += (f"  ON {key_ref} = tmp.__kumo_id__")
          else:
-             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-                    f"FROM {quote_ident(tmp_name)} tmp\n"
-                    f"JOIN {self.fqn_dict[table_name]} ent\n"
-                    f"  ON ent.rowid = (\n"
-                    f"    SELECT rowid FROM {self.fqn_dict[table_name]}\n"
-                    f"    WHERE {quote_ident(pkey_name)} == tmp.id\n"
-                    f"    LIMIT 1\n"
-                    f")")
+             sql += (f"  ON ent.rowid = (\n"
+                     f"    SELECT rowid\n"
+                     f"    FROM {self.source_name_dict[table_name]}\n"
+                     f"    WHERE {key_ref} == tmp.__kumo_id__\n"
+                     f"    LIMIT 1\n"
+                     f")")

          with self._connection.cursor() as cursor:
              cursor.adbc_ingest(tmp_name, tmp, mode='replace')
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         batch = table['__batch__'].to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__batch__'))
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)

-         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

      # Helper Methods ##########################################################

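
The lookup above follows a batch-probe pattern: the requested keys are ingested into a temporary table, `tmp.rowid - 1` tags every result row with its zero-based position in the input batch, and for non-unique keys a correlated subquery with `LIMIT 1` returns exactly one row per probe. A stdlib `sqlite3` sketch of the same shape (hypothetical names; `executemany` stands in for `adbc_ingest`):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE ent (pk INTEGER, val TEXT)")
    con.executemany("INSERT INTO ent VALUES (?, ?)",
                    [(1, "a"), (2, "b"), (2, "b2"), (3, "c")])
    con.execute("CREATE TEMP TABLE tmp (__kumo_id__ INTEGER)")
    con.executemany("INSERT INTO tmp VALUES (?)", [(3,), (1,), (2,)])

    rows = con.execute(
        "SELECT tmp.rowid - 1 AS __kumo_batch__, ent.pk, ent.val "
        "FROM tmp JOIN ent ON ent.rowid = ("
        "  SELECT rowid FROM ent WHERE ent.pk == tmp.__kumo_id__ LIMIT 1)"
    ).fetchall()
    print(rows)  # e.g. [(0, 3, 'c'), (1, 1, 'a'), (2, 2, 'b')]
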
@@ -237,37 +283,50 @@ class SQLiteSampler(SQLSampler):
          max_offset: pd.DateOffset,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]
+
          # NOTE SQLite does not have a native datetime format. Currently, we
          # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-         tmp = pa.table([pa.array(pkey)], names=['id'])
+         tmp = pa.table([pa.array(pkey)], names=['__kumo_id__'])
          end_time = anchor_time + max_offset
          end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-         tmp = tmp.append_column('end', pa.array(end_time))
+         tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
          if min_offset is not None:
              start_time = anchor_time + min_offset
              start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-             tmp = tmp.append_column('start', pa.array(start_time))
+             tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
          tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'

-         time_column = self.time_column_dict[table_name]
-         sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+         key_ref = self.table_column_ref_dict[table_name][fkey]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
                 f"FROM {quote_ident(tmp_name)} tmp\n"
-                f"JOIN {self.fqn_dict[table_name]} fact\n"
-                f"  ON fact.{quote_ident(fkey)} = tmp.id\n"
-                f"  AND fact.{quote_ident(time_column)} <= tmp.end")
+                f"JOIN {self.source_name_dict[table_name]} fact\n"
+                f"  ON {key_ref} = tmp.__kumo_id__\n"
+                f"  AND {time_ref} <= tmp.__kumo_end__")
          if min_offset is not None:
-             sql += f"\n  AND fact.{quote_ident(time_column)} > tmp.start"
+             sql += f"\n  AND {time_ref} > tmp.__kumo_start__"

          with self._connection.cursor() as cursor:
              cursor.adbc_ingest(tmp_name, tmp, mode='replace')
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         batch = table['__batch__'].to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__batch__'))
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)

-         return self._sanitize(table_name, table), batch
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

      def _sample_target_set(
          self,
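
The time filters above compare timestamps as plain `TEXT`, which is sound because fixed-format `%Y-%m-%d %H:%M:%S` strings sort lexicographically in chronological order; `<=` and `>` therefore delimit the window between `anchor_time + min_offset` (exclusive) and `anchor_time + max_offset` (inclusive) correctly. A quick check:

    import sqlite3

    con = sqlite3.connect(":memory:")
    print(con.execute(
        "SELECT '2024-01-02 00:00:00' <= '2024-01-10 00:00:00', "
        "       '2024-01-02 00:00:00' >  '2023-12-31 23:59:59'").fetchone())
    # -> (1, 1): both comparisons hold as plain string comparisons
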
@@ -337,13 +396,3 @@ class SQLiteSampler(SQLSampler):
          y = pd.concat(ys, axis=0, ignore_index=True)

          return y, mask
-
-     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-         df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-         stype_dict = self.table_stype_dict[table_name]
-         for column_name in df.columns:
-             if stype_dict.get(column_name) == Stype.timestamp:
-                 df[column_name] = pd.to_datetime(df[column_name])
-
-         return df
kumoai/experimental/rfm/backend/sqlite/table.py

@@ -1,5 +1,5 @@
  import re
- import warnings
+ from collections import Counter
  from collections.abc import Sequence
  from typing import cast

@@ -9,27 +9,25 @@ from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.sqlite import Connection
  from kumoai.experimental.rfm.base import (
-     ColumnExpressionSpec,
-     ColumnExpressionType,
+     ColumnSpec,
+     ColumnSpecType,
      DataBackend,
      SourceColumn,
      SourceForeignKey,
-     SQLTable,
+     Table,
  )
- from kumoai.experimental.rfm.infer import infer_dtype
  from kumoai.utils import quote_ident


- class SQLiteTable(SQLTable):
+ class SQLiteTable(Table):
      r"""A table backed by a :class:`sqlite` database.

      Args:
          connection: The connection to a :class:`sqlite` database.
-         name: The logical name of this table.
-         source_name: The physical name of this table in the database. If set to
-             ``None``, ``name`` is being used.
-         columns: The selected physical columns of this table.
-         column_expressions: The logical columns of this table.
+         name: The name of this table.
+         source_name: The source name of this table. If set to ``None``,
+             ``name`` is used.
+         columns: The selected columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
          end_time_column: The name of the end time column of this table, if it
@@ -40,8 +38,7 @@ class SQLiteTable(SQLTable):
          connection: Connection,
          name: str,
          source_name: str | None = None,
-         columns: Sequence[str] | None = None,
-         column_expressions: Sequence[ColumnExpressionType] | None = None,
+         columns: Sequence[ColumnSpecType] | None = None,
          primary_key: MissingType | str | None = MissingType.VALUE,
          time_column: str | None = None,
          end_time_column: str | None = None,
@@ -53,7 +50,6 @@ class SQLiteTable(SQLTable):
              name=name,
              source_name=source_name,
              columns=columns,
-             column_expressions=column_expressions,
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
@@ -66,16 +62,16 @@ class SQLiteTable(SQLTable):
      def _get_source_columns(self) -> list[SourceColumn]:
          source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
-             sql = f"PRAGMA table_info({self.fqn})"
+             sql = f"PRAGMA table_info({self._quoted_source_name})"
              cursor.execute(sql)
              columns = cursor.fetchall()

              if len(columns) == 0:
-                 raise ValueError(f"Table '{self._source_name}' does not exist "
+                 raise ValueError(f"Table '{self.source_name}' does not exist "
                                   f"in the SQLite database")

              unique_keys: set[str] = set()
-             sql = f"PRAGMA index_list({self.fqn})"
+             sql = f"PRAGMA index_list({self._quoted_source_name})"
              cursor.execute(sql)
              for _, index_name, is_unique, *_ in cursor.fetchall():
                  if bool(is_unique):
@@ -85,32 +81,19 @@ class SQLiteTable(SQLTable):
                  if len(index) == 1:
                      unique_keys.add(index[0][2])

-             for _, column, type, notnull, _, is_pkey in columns:
-                 # Determine column affinity:
-                 type = type.strip().upper()
-                 if re.search('INT', type):
-                     dtype = Dtype.int
-                 elif re.search('TEXT|CHAR|CLOB', type):
-                     dtype = Dtype.string
-                 elif re.search('REAL|FLOA|DOUB', type):
-                     dtype = Dtype.float
-                 else:  # NUMERIC affinity.
-                     ser = self._source_sample_df[column]
-                     try:
-                         dtype = infer_dtype(ser)
-                     except Exception:
-                         warnings.warn(f"Encountered unsupported data type "
-                                       f"'{ser.dtype}' with source data type "
-                                       f"'{type}' for column '{column}' in "
-                                       f"table '{self.name}'. If possible, "
-                                       f"change the data type of the column in "
-                                       f"your SQLite database to use it within "
-                                       f"this table.")
-                         continue
+             # Special SQLite case that creates a rowid alias for
+             # `INTEGER PRIMARY KEY` annotated columns:
+             rowid_candidates = [
+                 column for _, column, dtype, _, _, is_pkey in columns
+                 if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+             ]
+             if len(rowid_candidates) == 1:
+                 unique_keys.add(rowid_candidates[0])

+             for _, column, dtype, notnull, _, is_pkey in columns:
                  source_column = SourceColumn(
                      name=column,
-                     dtype=dtype,
+                     dtype=self._to_dtype(dtype),
                      is_primary_key=bool(is_pkey),
                      is_unique_key=column in unique_keys,
                      is_nullable=not bool(is_pkey) and not bool(notnull),
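
Background on the `rowid_candidates` special case: SQLite treats a column declared exactly as `INTEGER PRIMARY KEY` as an alias for the implicit `rowid`, so the column is unique by construction yet shows up in neither `PRAGMA index_list` nor any UNIQUE constraint. A minimal demonstration:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x TEXT)")
    con.execute("INSERT INTO t (x) VALUES ('a')")
    print(con.execute("SELECT id, rowid FROM t").fetchone())  # (1, 1)
    print(con.execute("PRAGMA index_list(t)").fetchall())     # [] -- no index
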
@@ -120,35 +103,82 @@ class SQLiteTable(SQLTable):
          return source_columns

      def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-         source_fkeys: list[SourceForeignKey] = []
+         source_foreign_keys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             sql = f"PRAGMA foreign_key_list({self.fqn})"
+             sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
              cursor.execute(sql)
-             for _, _, dst_table, fkey, pkey, *_ in cursor.fetchall():
-                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-         return source_fkeys
+             rows = cursor.fetchall()
+             counts = Counter(row[0] for row in rows)
+             for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                 if counts[idx] == 1:
+                     source_foreign_key = SourceForeignKey(
+                         name=foreign_key,
+                         dst_table=dst_table,
+                         primary_key=primary_key,
+                     )
+                     source_foreign_keys.append(source_foreign_key)
+         return source_foreign_keys

      def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             sql = (f"SELECT * FROM {self.fqn} "
-                    f"ORDER BY rowid LIMIT 1000")
+             columns = [quote_ident(col) for col in self._source_column_dict]
+             sql = (f"SELECT {', '.join(columns)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"ORDER BY rowid "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         if len(table) == 0:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={
+                 column.name: column.dtype
+                 for column in self._source_column_dict.values()
+             },
+             stype_dict=None,
+         )

      def _get_num_rows(self) -> int | None:
          return None

-     def _get_expression_sample_df(
+     def _get_expr_sample_df(
          self,
-         specs: Sequence[ColumnExpressionSpec],
+         columns: Sequence[ColumnSpec],
      ) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             columns = [
-                 f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+             projections = [
+                 f"{column.expr} AS {quote_ident(column.name)}"
+                 for column in columns
              ]
-             sql = (f"SELECT {', '.join(columns)} FROM {self.fqn} "
-                    f"ORDER BY rowid LIMIT 1000")
+             sql = (f"SELECT {', '.join(projections)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"ORDER BY rowid "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         if len(table) == 0:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={column.name: column.dtype
+                         for column in columns},
+             stype_dict=None,
+         )
+
+     @staticmethod
+     def _to_dtype(dtype: str | None) -> Dtype | None:
+         if dtype is None:
+             return None
+         dtype = dtype.strip().upper()
+         if re.search('INT', dtype):
+             return Dtype.int
+         if re.search('TEXT|CHAR|CLOB', dtype):
+             return Dtype.string
+         if re.search('REAL|FLOA|DOUB', dtype):
+             return Dtype.float
+         return None  # NUMERIC affinity.
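
The substring checks in `_to_dtype` follow SQLite's declared-type affinity rules: any type containing `INT` maps to integer, `TEXT`/`CHAR`/`CLOB` to string, `REAL`/`FLOA`/`DOUB` to float, and everything else falls back to NUMERIC affinity (returned as `None` here). A small standalone check of the same rules, simplified in that SQLite's full rules additionally distinguish BLOB and untyped columns:

    import re

    def to_affinity(decl: str) -> str:
        decl = decl.strip().upper()
        if re.search('INT', decl):
            return 'INTEGER'
        if re.search('TEXT|CHAR|CLOB', decl):
            return 'TEXT'
        if re.search('REAL|FLOA|DOUB', decl):
            return 'REAL'
        return 'NUMERIC'

    assert to_affinity('BIGINT') == 'INTEGER'
    assert to_affinity('VARCHAR(255)') == 'TEXT'
    assert to_affinity('DOUBLE PRECISION') == 'REAL'
    assert to_affinity('DECIMAL(10, 2)') == 'NUMERIC'
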
kumoai/experimental/rfm/base/__init__.py

@@ -8,12 +8,9 @@ class DataBackend(StrEnum):


  from .source import SourceColumn, SourceForeignKey  # noqa: E402
- from .column import Column  # noqa: E402
- from .column_expression import ColumnExpressionSpec  # noqa: E402
- from .column_expression import ColumnExpressionType  # noqa: E402
- from .column_expression import ColumnExpression  # noqa: E402
+ from .expression import Expression, LocalExpression  # noqa: E402
+ from .column import ColumnSpec, ColumnSpecType, Column  # noqa: E402
  from .table import Table  # noqa: E402
- from .sql_table import SQLTable  # noqa: E402
  from .sampler import SamplerOutput, Sampler  # noqa: E402
  from .sql_sampler import SQLSampler  # noqa: E402

@@ -21,12 +18,12 @@ __all__ = [
      'DataBackend',
      'SourceColumn',
      'SourceForeignKey',
+     'Expression',
+     'LocalExpression',
+     'ColumnSpec',
+     'ColumnSpecType',
      'Column',
-     'ColumnExpressionSpec',
-     'ColumnExpressionType',
-     'ColumnExpression',
      'Table',
-     'SQLTable',
      'SamplerOutput',
      'Sampler',
      'SQLSampler',