kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +130 -110
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  31. kumoai/pquery/training_table.py +16 -2
  32. kumoai/testing/snow.py +3 -3
  33. kumoai/trainer/distilled_trainer.py +175 -0
  34. kumoai/utils/display.py +87 -0
  35. kumoai/utils/progress_logger.py +15 -2
  36. kumoai/utils/sql.py +2 -2
  37. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
  38. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
  39. kumoai/experimental/rfm/base/column_expression.py +0 -50
  40. kumoai/experimental/rfm/base/sql_table.py +0 -229
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -6,9 +6,9 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
-from kumoapi.typing import Stype
 
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 from kumoai.utils import ProgressLogger, quote_ident
 
@@ -25,22 +25,32 @@ class SQLiteSampler(SQLSampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)
 
+        for table in graph.tables.values():
+            assert isinstance(table, SQLiteTable)
+            self._connection = table._connection
+
         if optimize:
             with self._connection.cursor() as cursor:
                 cursor.execute("PRAGMA temp_store = MEMORY")
                 cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
 
-        # Collect database indices to speed-up sampling:
+        # Collect database indices to speed up sampling:
         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
         for table_name, primary_key in self.primary_key_dict.items():
             source_table = self.source_table_dict[table_name]
-            if not source_table[primary_key].is_unique_key:
-                index_dict[table_name].add((primary_key, ))
+            if primary_key not in source_table:
+                continue  # No physical column.
+            if source_table[primary_key].is_unique_key:
+                continue
+            index_dict[table_name].add((primary_key, ))
         for src_table_name, foreign_key, _ in graph.edges:
             source_table = self.source_table_dict[src_table_name]
+            if foreign_key not in source_table:
+                continue  # No physical column.
             if source_table[foreign_key].is_unique_key:
-                pass
-            elif time_column := self.time_column_dict.get(src_table_name):
+                continue
+            time_column = self.time_column_dict.get(src_table_name)
+            if time_column is not None and time_column in source_table:
                 index_dict[src_table_name].add((foreign_key, time_column))
             else:
                 index_dict[src_table_name].add((foreign_key, ))
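Note on the PRAGMA tuning above: `temp_store = MEMORY` keeps temporary B-trees in RAM, and a negative `cache_size` is interpreted in KiB, so `-2000000` is roughly 2 GB. A minimal sketch of the same tuning plus the kind of composite index the sampler collects, against a plain `sqlite3` connection (the `orders` table and its columns are hypothetical):

```python
import sqlite3

con = sqlite3.connect("example.db")  # hypothetical database file
con.execute("PRAGMA temp_store = MEMORY")    # temp B-trees in RAM
con.execute("PRAGMA cache_size = -2000000")  # negative => KiB, ~2 GB

# The sampler requests one index per non-unique primary key and, for fact
# tables with a time column, a composite (foreign_key, time_column) index:
con.execute("CREATE INDEX IF NOT EXISTS idx_orders_user_ts "
            "ON orders(user_id, created_at DESC)")
```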
@@ -49,22 +59,22 @@ class SQLiteSampler(SQLSampler):
         with self._connection.cursor() as cursor:
             for table_name in list(index_dict.keys()):
                 indices = index_dict[table_name]
-                sql = f"PRAGMA index_list({self.fqn_dict[table_name]})"
+                source_name = self.source_name_dict[table_name]
+                sql = f"PRAGMA index_list({source_name})"
                 cursor.execute(sql)
                 for _, index_name, *_ in cursor.fetchall():
                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
                     cursor.execute(sql)
-                    index = tuple(info[2] for info in sorted(
+                    # Fetch index information and sort by `seqno`:
+                    index_info = tuple(info[2] for info in sorted(
                         cursor.fetchall(), key=lambda x: x[0]))
-                    indices.discard(index)
+                    # Remove requested indices whose leading column is
+                    # already covered by an existing index:
+                    for index in list(indices):
+                        if index_info[0] == index[0]:
+                            indices.discard(index)
                 if len(indices) == 0:
                     del index_dict[table_name]
 
-        num = sum(len(indices) for indices in index_dict.values())
-        index_repr = '1 index' if num == 1 else f'{num} indices'
-        num = len(index_dict)
-        table_repr = '1 table' if num == 1 else f'{num} tables'
-
         if optimize and len(index_dict) > 0:
             if not isinstance(verbose, ProgressLogger):
                 verbose = ProgressLogger.default(
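The deduplication above now compares only the leading indexed column (`index_info[0] == index[0]`), so any existing index whose first column matches a requested key is treated as covering it. For reference, a sketch of the row shapes the two PRAGMAs return, per the SQLite PRAGMA docs (table and index names are made up):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE orders (user_id INT, created_at TEXT);
    CREATE INDEX idx_orders ON orders(user_id, created_at);
""")

# index_list rows: (seq, name, unique, origin, partial)
print(con.execute("PRAGMA index_list(orders)").fetchall())

# index_info rows: (seqno, cid, name); sorting by seqno (info[0]) and
# taking name (info[2]) recovers the indexed column order:
print(sorted(con.execute("PRAGMA index_info(idx_orders)").fetchall()))
```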
@@ -79,16 +89,27 @@ class SQLiteSampler(SQLSampler):
                     name = quote_ident(name)
                     columns = ', '.join(quote_ident(v) for v in index)
                     columns += ' DESC' if len(index) > 1 else ''
+                    source_name = self.source_name_dict[table_name]
                     sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
-                           f"ON {self.fqn_dict[table_name]}({columns})")
+                           f"ON {source_name}({columns})")
                     cursor.execute(sql)
-            self._connection.commit()
-            logger.log(f"Created {index_repr} in {table_repr}")
+                    self._connection.commit()
+                    if len(index) > 1:
+                        logger.log(f"Created index on {index} in table "
+                                   f"'{table_name}'")
+                    else:
+                        logger.log(f"Created index on '{index[0]}' in "
+                                   f"table '{table_name}'")
 
         elif len(index_dict) > 0:
+            num = sum(len(indices) for indices in index_dict.values())
+            index_repr = '1 index' if num == 1 else f'{num} indices'
+            num = len(index_dict)
+            table_repr = '1 table' if num == 1 else f'{num} tables'
             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                           f"database querying. For improving runtime, we "
-                          f"strongly suggest to create these indices by "
+                          f"strongly suggest creating indices for primary "
+                          f"and foreign keys, e.g., automatically by "
                           f"instantiating KumoRFM via "
                           f"`KumoRFM(graph, optimize=True)`.")
 
@@ -98,23 +119,26 @@ class SQLiteSampler(SQLSampler):
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-            time_column = self.time_column_dict[table_name]
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
             select = (f"SELECT\n"
-                      f"  ? as table_name,\n"
-                      f"  MIN({quote_ident(time_column)}) as min_date,\n"
-                      f"  MAX({quote_ident(time_column)}) as max_date\n"
-                      f"FROM {self.fqn_dict[table_name]}")
+                      f"  {ident} as table_name,\n"
+                      f"  MIN({column_ref}) as min_date,\n"
+                      f"  MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         with self._connection.cursor() as cursor:
-            cursor.execute(sql, table_names)
+            cursor.execute(sql)
             for table_name, _min, _max in cursor.fetchall():
                 out_dict[table_name] = (
                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
                 )
+
         return out_dict
 
     def _sample_entity_table(
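With the `?` placeholders replaced by inlined string literals (`quote_ident(table_name, char="'")`), `cursor.execute(sql)` now takes no parameters. For two hypothetical tables, the generated statement would look roughly like:

```python
# Sketch of the generated SQL (table and column names are hypothetical):
sql = """\
SELECT
  'users' as table_name,
  MIN("signup_at") as min_date,
  MAX("signup_at") as max_date
FROM users
UNION ALL
SELECT
  'orders' as table_name,
  MIN("created_at") as min_date,
  MAX("created_at") as max_date
FROM orders"""
```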
@@ -126,18 +150,28 @@ class SQLiteSampler(SQLSampler):
     ) -> pd.DataFrame:
         # NOTE SQLite does not natively support passing a `random_seed`.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
 
         # TODO Make this query more efficient - it does full table scan.
-        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-               f"FROM {self.fqn_dict[table_name]}")
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
@@ -147,7 +181,11 @@ class SQLiteSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        return self._sanitize(table_name, table)
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
@@ -190,84 +228,163 @@ class SQLiteSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-        pkey: pd.Series,
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
-        pkey_name = self.primary_key_dict[table_name]
+        source_table = self.source_table_dict[table_name]
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} ent\n")
+        if key in source_table and source_table[key].is_unique_key:
+            sql += (f"  ON {key_ref} = tmp.__kumo_id__")
+        else:
+            sql += (f"  ON ent.rowid = (\n"
+                    f"    SELECT rowid\n"
+                    f"    FROM {self.source_name_dict[table_name]}\n"
+                    f"    WHERE {key_ref} == tmp.__kumo_id__\n"
+                    f"    LIMIT 1\n"
+                    f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
 
-        tmp = pa.table([pa.array(pkey)], names=['id'])
-        tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        if self.source_table_dict[table_name][pkey_name].is_unique_key:
-            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-                   f"FROM {quote_ident(tmp_name)} tmp\n"
-                   f"JOIN {self.fqn_dict[table_name]} ent\n"
-                   f"  ON ent.{quote_ident(pkey_name)} = tmp.id")
-        else:
-            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-                   f"FROM {quote_ident(tmp_name)} tmp\n"
-                   f"JOIN {self.fqn_dict[table_name]} ent\n"
-                   f"  ON ent.rowid = (\n"
-                   f"    SELECT rowid FROM {self.fqn_dict[table_name]}\n"
-                   f"    WHERE {quote_ident(pkey_name)} == tmp.id\n"
-                   f"    LIMIT 1\n"
-                   f")")
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f"    SELECT rowid\n"
+               f"    FROM {self.source_name_dict[table_name]}\n"
+               f"    WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"    AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"    ORDER BY {time_ref} DESC\n"
+        sql += (f"    LIMIT {num_neighbors}\n"
+                f")")
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['__batch__'].to_numpy()
-        table = table.remove_column(table.schema.get_field_index('__batch__'))
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-        fkey: str,
-        pkey: pd.Series,
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
         # NOTE SQLite does not have a native datetime format. Currently, we
         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-        tmp = pa.table([pa.array(pkey)], names=['id'])
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-        tmp = tmp.append_column('end', pa.array(end_time))
+        tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            tmp = tmp.append_column('start', pa.array(start_time))
-        tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
-
-        time_column = self.time_column_dict[table_name]
-        sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-               f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+            tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
               f"FROM {quote_ident(tmp_name)} tmp\n"
-               f"JOIN {self.fqn_dict[table_name]} fact\n"
-               f"  ON fact.{quote_ident(fkey)} = tmp.id\n"
-               f"  AND fact.{quote_ident(time_column)} <= tmp.end")
+               f"JOIN {self.source_name_dict[table_name]}\n"
+               f"  ON {key_ref} = tmp.__kumo_id__\n"
+               f"  AND {time_ref} <= tmp.__kumo_end__")
         if min_offset is not None:
-            sql += f"\n  AND fact.{quote_ident(time_column)} > tmp.start"
+            sql += f"\n  AND {time_ref} > tmp.__kumo_start__"
 
         with self._connection.cursor() as cursor:
             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
 
-        batch = table['__batch__'].to_numpy()
-        table = table.remove_column(table.schema.get_field_index('__batch__'))
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
 
-        return self._sanitize(table_name, table), batch
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     def _sample_target_set(
         self,
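`_by_pkey`, the new `_by_fkey`, and `_by_time` all share one device: the lookup keys (plus optional anchor times) are ingested as a temporary table via ADBC, the query joins against it, and `tmp.rowid - 1 as __kumo_batch__` recovers each row's 0-based position in the input batch, since SQLite assigns `rowid` 1..n in insertion order. A self-contained sketch of the pattern, assuming the `adbc-driver-sqlite` package and a hypothetical `users` table:

```python
import pyarrow as pa
from adbc_driver_sqlite import dbapi

with dbapi.connect() as con, con.cursor() as cursor:  # in-memory database
    cursor.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    cursor.execute("INSERT INTO users VALUES (1, 'a'), (2, 'b'), (3, 'c')")

    # Ingest lookup keys; their implicit 1-based rowid encodes the
    # position of each key in the input batch:
    keys = pa.table([pa.array([3, 1])], names=['__kumo_id__'])
    cursor.adbc_ingest('tmp_keys', keys, mode='replace')

    cursor.execute(
        "SELECT tmp.rowid - 1 AS __kumo_batch__, ent.name "
        "FROM tmp_keys tmp JOIN users ent ON ent.id = tmp.__kumo_id__")
    print(cursor.fetch_arrow_table())  # batch 0 -> 'c', batch 1 -> 'a'
```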
@@ -300,11 +417,11 @@ class SQLiteSampler(SQLSampler):
             query.entity_table: np.arange(len(df)),
         }
         for edge_type, (_min, _max) in time_offset_dict.items():
-            table_name, fkey, _ = edge_type
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-                fkey=fkey,
-                pkey=df[self.primary_key_dict[query.entity_table]],
+                foreign_key=foreign_key,
+                index=df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=_min,
                 max_offset=_max,
@@ -319,7 +436,7 @@ class SQLiteSampler(SQLSampler):
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
-                anchor_time=anchor_time,
+                anchor_time=time,
                 num_forecasts=query.num_forecasts,
             )
             ys.append(y)
@@ -337,13 +454,3 @@ class SQLiteSampler(SQLSampler):
         y = pd.concat(ys, axis=0, ignore_index=True)
 
         return y, mask
-
-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-        stype_dict = self.table_stype_dict[table_name]
-        for column_name in df.columns:
-            if stype_dict.get(column_name) == Stype.timestamp:
-                df[column_name] = pd.to_datetime(df[column_name])
-
-        return df
kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,5 +1,5 @@
 import re
-import warnings
+from collections import Counter
 from collections.abc import Sequence
 from typing import cast
 
@@ -9,27 +9,25 @@ from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.sqlite import Connection
 from kumoai.experimental.rfm.base import (
-    ColumnExpressionSpec,
-    ColumnExpressionType,
+    ColumnSpec,
+    ColumnSpecType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
-    SQLTable,
+    Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 from kumoai.utils import quote_ident
 
 
-class SQLiteTable(SQLTable):
+class SQLiteTable(Table):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`sqlite` database.
-        name: The logical name of this table.
-        source_name: The physical name of this table in the database. If set to
-            ``None``, ``name`` is being used.
-        columns: The selected physical columns of this table.
-        column_expressions: The logical columns of this table.
+        name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is used.
+        columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -40,8 +38,7 @@ class SQLiteTable(SQLTable):
         connection: Connection,
         name: str,
         source_name: str | None = None,
-        columns: Sequence[str] | None = None,
-        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
         primary_key: MissingType | str | None = MissingType.VALUE,
         time_column: str | None = None,
         end_time_column: str | None = None,
@@ -53,7 +50,6 @@ class SQLiteTable(SQLTable):
             name=name,
             source_name=source_name,
             columns=columns,
-            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
@@ -66,16 +62,16 @@ class SQLiteTable(SQLTable):
     def _get_source_columns(self) -> list[SourceColumn]:
         source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA table_info({self.fqn})"
+            sql = f"PRAGMA table_info({self._quoted_source_name})"
             cursor.execute(sql)
             columns = cursor.fetchall()
 
             if len(columns) == 0:
-                raise ValueError(f"Table '{self._source_name}' does not exist "
+                raise ValueError(f"Table '{self.source_name}' does not exist "
                                  f"in the SQLite database")
 
             unique_keys: set[str] = set()
-            sql = f"PRAGMA index_list({self.fqn})"
+            sql = f"PRAGMA index_list({self._quoted_source_name})"
             cursor.execute(sql)
             for _, index_name, is_unique, *_ in cursor.fetchall():
                 if bool(is_unique):
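For reference, `PRAGMA table_info` returns one row per column of the shape `(cid, name, type, notnull, dflt_value, pk)`, which is what the tuple unpacking in the next hunk relies on. A sketch with the stdlib driver (table and columns hypothetical):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT NOT NULL)")

# Rows: (cid, name, type, notnull, dflt_value, pk)
for cid, name, type_, notnull, default, pk in con.execute(
        "PRAGMA table_info(t)"):
    print(name, type_, bool(notnull), bool(pk))
```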
@@ -85,32 +81,19 @@ class SQLiteTable(SQLTable):
                     if len(index) == 1:
                         unique_keys.add(index[0][2])
 
-            for _, column, type, notnull, _, is_pkey in columns:
-                # Determine column affinity:
-                type = type.strip().upper()
-                if re.search('INT', type):
-                    dtype = Dtype.int
-                elif re.search('TEXT|CHAR|CLOB', type):
-                    dtype = Dtype.string
-                elif re.search('REAL|FLOA|DOUB', type):
-                    dtype = Dtype.float
-                else:  # NUMERIC affinity.
-                    ser = self._source_sample_df[column]
-                    try:
-                        dtype = infer_dtype(ser)
-                    except Exception:
-                        warnings.warn(f"Encountered unsupported data type "
                                      f"'{ser.dtype}' with source data type "
-                                      f"'{type}' for column '{column}' in "
-                                      f"table '{self.name}'. If possible, "
-                                      f"change the data type of the column in "
-                                      f"your SQLite database to use it within "
-                                      f"this table.")
-                        continue
+            # Special SQLite case that creates a rowid alias for
+            # `INTEGER PRIMARY KEY` annotated columns:
+            rowid_candidates = [
+                column for _, column, dtype, _, _, is_pkey in columns
+                if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+            ]
+            if len(rowid_candidates) == 1:
+                unique_keys.add(rowid_candidates[0])
 
+            for _, column, dtype, notnull, _, is_pkey in columns:
                 source_column = SourceColumn(
                     name=column,
-                    dtype=dtype,
+                    dtype=self._to_dtype(dtype),
                     is_primary_key=bool(is_pkey),
                     is_unique_key=column in unique_keys,
                     is_nullable=not bool(is_pkey) and not bool(notnull),
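The `rowid_candidates` special case reflects documented SQLite behavior: a column declared exactly `INTEGER PRIMARY KEY` becomes an alias for `rowid` and is unique without any entry in `PRAGMA index_list`, so the unique-key scan above would otherwise miss it. A quick demonstration with the stdlib driver:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)")
con.execute("INSERT INTO t (v) VALUES ('x')")

# `id` aliases the implicit rowid ...
print(con.execute("SELECT id, rowid FROM t").fetchall())  # [(1, 1)]

# ... and no index backs it, unlike e.g. a TEXT primary key:
print(con.execute("PRAGMA index_list(t)").fetchall())     # []
```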
@@ -120,35 +103,82 @@ class SQLiteTable(SQLTable):
         return source_columns
 
     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-        source_fkeys: list[SourceForeignKey] = []
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            sql = f"PRAGMA foreign_key_list({self.fqn})"
+            sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
             cursor.execute(sql)
-            for _, _, dst_table, fkey, pkey, *_ in cursor.fetchall():
-                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-        return source_fkeys
+            rows = cursor.fetchall()
+            counts = Counter(row[0] for row in rows)
+            for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                if counts[idx] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=foreign_key,
+                        dst_table=dst_table,
+                        primary_key=primary_key,
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
 
     def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-            sql = (f"SELECT * FROM {self.fqn} "
-                   f"ORDER BY rowid LIMIT 1000")
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-            return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
 
     def _get_num_rows(self) -> int | None:
         return None
 
-    def _get_expression_sample_df(
+    def _get_expr_sample_df(
         self,
-        specs: Sequence[ColumnExpressionSpec],
+        columns: Sequence[ColumnSpec],
     ) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-            columns = [
-                f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
             ]
-            sql = (f"SELECT {', '.join(columns)} FROM {self.fqn} "
-                   f"ORDER BY rowid LIMIT 1000")
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"ORDER BY rowid "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_table()
-            return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+        if len(table) == 0:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if re.search('INT', dtype):
+            return Dtype.int
+        if re.search('TEXT|CHAR|CLOB', dtype):
+            return Dtype.string
+        if re.search('REAL|FLOA|DOUB', dtype):
+            return Dtype.float
+        return None  # NUMERIC affinity.
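The new `_to_dtype` mirrors SQLite's declared-type affinity rules (substring match for INT, then TEXT/CHAR/CLOB, then REAL/FLOA/DOUB), with NUMERIC and unknown affinities now mapping to `None` rather than triggering the removed sample-based `infer_dtype` fallback. A few illustrative spot-checks, derived directly from the regex rules above:

```python
# Spot-checks of the declared-type -> Dtype mapping as written:
assert SQLiteTable._to_dtype('INTEGER') == Dtype.int
assert SQLiteTable._to_dtype('BIGINT') == Dtype.int            # contains 'INT'
assert SQLiteTable._to_dtype('VARCHAR(255)') == Dtype.string   # contains 'CHAR'
assert SQLiteTable._to_dtype('DOUBLE PRECISION') == Dtype.float
assert SQLiteTable._to_dtype('DECIMAL(10, 2)') is None         # NUMERIC affinity
assert SQLiteTable._to_dtype(None) is None
```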