kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
  12. kumoai/experimental/rfm/backend/snow/table.py +159 -52
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
  15. kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
  16. kumoai/experimental/rfm/base/__init__.py +6 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +342 -13
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +27 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +600 -360
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
  45. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
  46. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  47. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  48. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -6,12 +6,11 @@ import numpy as np
  import pandas as pd
  import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery
- from kumoapi.typing import Stype

  from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
- from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.base import SQLSampler, Table
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger, quote_ident
+ from kumoai.utils import ProgressLogger, quote_ident

  if TYPE_CHECKING:
      from kumoai.experimental.rfm import Graph
@@ -35,17 +34,23 @@ class SQLiteSampler(SQLSampler):
              cursor.execute("PRAGMA temp_store = MEMORY")
              cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB

-         # Collect database indices to speed-up sampling:
+         # Collect database indices for speeding sampling:
          index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
          for table_name, primary_key in self.primary_key_dict.items():
              source_table = self.source_table_dict[table_name]
-             if not source_table[primary_key].is_unique_key:
-                 index_dict[table_name].add((primary_key, ))
+             if primary_key not in source_table:
+                 continue  # No physical column.
+             if source_table[primary_key].is_unique_key:
+                 continue
+             index_dict[table_name].add((primary_key, ))
          for src_table_name, foreign_key, _ in graph.edges:
              source_table = self.source_table_dict[src_table_name]
+             if foreign_key not in source_table:
+                 continue  # No physical column.
              if source_table[foreign_key].is_unique_key:
-                 pass
-             elif time_column := self.time_column_dict.get(src_table_name):
+                 continue
+             time_column = self.time_column_dict.get(src_table_name)
+             if time_column is not None and time_column in source_table:
                  index_dict[src_table_name].add((foreign_key, time_column))
              else:
                  index_dict[src_table_name].add((foreign_key, ))
@@ -54,46 +59,57 @@ class SQLiteSampler(SQLSampler):
          with self._connection.cursor() as cursor:
              for table_name in list(index_dict.keys()):
                  indices = index_dict[table_name]
-                 sql = f"PRAGMA index_list({quote_ident(table_name)})"
+                 source_name = self.source_name_dict[table_name]
+                 sql = f"PRAGMA index_list({source_name})"
                  cursor.execute(sql)
                  for _, index_name, *_ in cursor.fetchall():
                      sql = f"PRAGMA index_info({quote_ident(index_name)})"
                      cursor.execute(sql)
-                     index = tuple(info[2] for info in sorted(
+                     # Fetch index information and sort by `seqno`:
+                     index_info = tuple(info[2] for info in sorted(
                          cursor.fetchall(), key=lambda x: x[0]))
-                     indices.discard(index)
+                     # Remove all indices in case primary index already exists:
+                     for index in list(indices):
+                         if index_info[0] == index[0]:
+                             indices.discard(index)
                  if len(indices) == 0:
                      del index_dict[table_name]

-         num = sum(len(indices) for indices in index_dict.values())
-         index_repr = '1 index' if num == 1 else f'{num} indices'
-         num = len(index_dict)
-         table_repr = '1 table' if num == 1 else f'{num} tables'
-
          if optimize and len(index_dict) > 0:
              if not isinstance(verbose, ProgressLogger):
-                 verbose = InteractiveProgressLogger(
-                     "Optimizing SQLite database",
+                 verbose = ProgressLogger.default(
+                     msg="Optimizing SQLite database",
                      verbose=verbose,
                  )

-             with verbose as logger:
-                 with self._connection.cursor() as cursor:
-                     for table_name, indices in index_dict.items():
-                         for index in indices:
-                             name = f"kumo_index_{table_name}_{'_'.join(index)}"
-                             columns = ', '.join(quote_ident(v) for v in index)
-                             columns += ' DESC' if len(index) > 1 else ''
-                             sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
-                                    f"ON {quote_ident(table_name)}({columns})")
-                             cursor.execute(sql)
-                 self._connection.commit()
-                 logger.log(f"Created {index_repr} in {table_repr}")
+             with verbose as logger, self._connection.cursor() as cursor:
+                 for table_name, indices in index_dict.items():
+                     for index in indices:
+                         name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                         name = quote_ident(name)
+                         columns = ', '.join(quote_ident(v) for v in index)
+                         columns += ' DESC' if len(index) > 1 else ''
+                         source_name = self.source_name_dict[table_name]
+                         sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                                f"ON {source_name}({columns})")
+                         cursor.execute(sql)
+                         self._connection.commit()
+                         if len(index) > 1:
+                             logger.log(f"Created index on {index} in table "
+                                        f"'{table_name}'")
+                         else:
+                             logger.log(f"Created index on '{index[0]}' in "
+                                        f"table '{table_name}'")

          elif len(index_dict) > 0:
+             num = sum(len(indices) for indices in index_dict.values())
+             index_repr = '1 index' if num == 1 else f'{num} indices'
+             num = len(index_dict)
+             table_repr = '1 table' if num == 1 else f'{num} tables'
              warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
                            f"database querying. For improving runtime, we "
-                           f"strongly suggest to create these indices by "
+                           f"strongly suggest to create indices for primary "
+                           f"and foreign keys, e.g., automatically by "
                            f"instantiating KumoRFM via "
                            f"`KumoRFM(graph, optimize=True)`.")

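Note: the PRAGMA-based bookkeeping above avoids re-creating indices that already exist. A minimal standalone sketch of the same introspection round trip, using plain sqlite3 instead of the package's connection wrapper (table and index names are illustrative):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE orders (user_id INTEGER, ts TEXT)")
    con.execute("CREATE INDEX kumo_index_orders_user_id_ts "
                "ON orders(user_id, ts DESC)")

    # PRAGMA index_list yields one (seq, name, unique, origin, partial)
    # row per index on the table:
    for _, index_name, *_ in con.execute("PRAGMA index_list(orders)"):
        # PRAGMA index_info yields (seqno, cid, name) per indexed column;
        # sorting by `seqno` recovers the column order of the index:
        info = con.execute(f"PRAGMA index_info({index_name})").fetchall()
        index = tuple(row[2] for row in sorted(info, key=lambda x: x[0]))
        print(index_name, index)
    # -> kumo_index_orders_user_id_ts ('user_id', 'ts')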
@@ -103,12 +119,13 @@ class SQLiteSampler(SQLSampler):
      ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
          selects: list[str] = []
          for table_name in table_names:
-             time_column = self.time_column_dict[table_name]
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
              select = (f"SELECT\n"
                        f" ? as table_name,\n"
-                       f" MIN({quote_ident(time_column)}) as min_date,\n"
-                       f" MAX({quote_ident(time_column)}) as max_date\n"
-                       f"FROM {quote_ident(table_name)}")
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
              selects.append(select)
          sql = "\nUNION ALL\n".join(selects)

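Note: the per-table SELECTs built here are stitched together with UNION ALL so that all minimum/maximum timestamps come back in a single round trip. For two hypothetical tables the assembled statement looks roughly like this (names are illustrative; `?` is bound to each table name):

    sql = ("SELECT ? as table_name,\n"
           " MIN(signup_ts) as min_date,\n"
           " MAX(signup_ts) as max_date\n"
           "FROM users\n"
           "UNION ALL\n"
           "SELECT ? as table_name,\n"
           " MIN(order_ts) as min_date,\n"
           " MAX(order_ts) as max_date\n"
           "FROM orders")
    # Executed once, this returns one (table_name, min_date, max_date)
    # row per table instead of one query per table.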
@@ -131,18 +148,28 @@ class SQLiteSampler(SQLSampler):
      ) -> pd.DataFrame:
          # NOTE SQLite does not natively support passing a `random_seed`.

+         source_table = self.source_table_dict[table_name]
          filters: list[str] = []
-         primary_key = self.primary_key_dict[table_name]
-         if self.source_table_dict[table_name][primary_key].is_nullable:
-             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-         time_column = self.time_column_dict.get(table_name)
-         if (time_column is not None and
-                 self.source_table_dict[table_name][time_column].is_nullable):
-             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")

          # TODO Make this query more efficient - it does full table scan.
-         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-                f"FROM {quote_ident(table_name)}")
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}")
          if len(filters) > 0:
              sql += f"\nWHERE{' AND'.join(filters)}"
          sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
@@ -152,7 +179,11 @@ class SQLiteSampler(SQLSampler):
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         return self._sanitize(table_name, table)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )

      def _sample_target(
          self,
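Note: the statement assembled above is a uniform random sample with NULL keys and NULL timestamps filtered out first. A self-contained sketch of the query shape against plain sqlite3 (toy schema, illustrative names):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE users (user_id INTEGER, ts TEXT)")
    con.executemany("INSERT INTO users VALUES (?, ?)",
                    [(1, '2024-01-01'), (2, None), (None, '2024-01-03')])

    num_rows = 2
    sql = ("SELECT user_id, ts\n"
           "FROM users\n"
           "WHERE user_id IS NOT NULL AND ts IS NOT NULL\n"
           f"ORDER BY RANDOM() LIMIT {num_rows}")
    # At most `num_rows` random rows; note this still scans the full table,
    # which is what the TODO in the hunk above refers to.
    print(con.execute(sql).fetchall())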
@@ -195,84 +226,163 @@ class SQLiteSampler(SQLSampler):
      def _by_pkey(
          self,
          table_name: str,
-         pkey: pd.Series,
+         index: pd.Series,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
-         pkey_name = self.primary_key_dict[table_name]
+         source_table = self.source_table_dict[table_name]
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+         tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.source_name_dict[table_name]} ent\n")
+         if key in source_table and source_table[key].is_unique_key:
+             sql += (f" ON {key_ref} = tmp.__kumo_id__")
+         else:
+             sql += (f" ON ent.rowid = (\n"
+                     f" SELECT rowid\n"
+                     f" FROM {self.source_name_dict[table_name]}\n"
+                     f" WHERE {key_ref} == tmp.__kumo_id__\n"
+                     f" LIMIT 1\n"
+                     f")")

-         tmp = pa.table([pa.array(pkey)], names=['id'])
-         tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()

-         if self.source_table_dict[table_name][pkey_name].is_unique_key:
-             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-                    f"FROM {quote_ident(tmp_name)} tmp\n"
-                    f"JOIN {quote_ident(table_name)} ent\n"
-                    f" ON ent.{quote_ident(pkey_name)} = tmp.id")
-         else:
-             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
-                    f"FROM {quote_ident(tmp_name)} tmp\n"
-                    f"JOIN {quote_ident(table_name)} ent\n"
-                    f" ON ent.rowid = (\n"
-                    f" SELECT rowid FROM {quote_ident(table_name)}\n"
-                    f" WHERE {quote_ident(pkey_name)} == tmp.id\n"
-                    f" LIMIT 1\n"
-                    f")")
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
+     def _by_fkey(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         num_neighbors: int,
+         anchor_time: pd.Series | None,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict.get(table_name)
+
+         # NOTE SQLite does not have a native datetime format. Currently, we
+         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+         tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+         if time_column is not None and anchor_time is not None:
+             anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+         tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.source_name_dict[table_name]} fact\n"
+                f"ON fact.rowid IN (\n"
+                f" SELECT rowid\n"
+                f" FROM {self.source_name_dict[table_name]}\n"
+                f" WHERE {key_ref} = tmp.__kumo_id__\n")
+         if time_column is not None and anchor_time is not None:
+             time_ref = self.table_column_ref_dict[table_name][time_column]
+             sql += f" AND {time_ref} <= tmp.__kumo_time__\n"
+         if time_column is not None:
+             time_ref = self.table_column_ref_dict[table_name][time_column]
+             sql += f" ORDER BY {time_ref} DESC\n"
+         sql += (f" LIMIT {num_neighbors}\n"
+                 f")")

          with self._connection.cursor() as cursor:
              cursor.adbc_ingest(tmp_name, tmp, mode='replace')
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         batch = table['__batch__'].to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__batch__'))
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)

-         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

      # Helper Methods ##########################################################

      def _by_time(
          self,
          table_name: str,
-         fkey: str,
-         pkey: pd.Series,
+         foreign_key: str,
+         index: pd.Series,
          anchor_time: pd.Series,
          min_offset: pd.DateOffset | None,
          max_offset: pd.DateOffset,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]
+
          # NOTE SQLite does not have a native datetime format. Currently, we
          # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
-         tmp = pa.table([pa.array(pkey)], names=['id'])
+         tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
          end_time = anchor_time + max_offset
          end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-         tmp = tmp.append_column('end', pa.array(end_time))
+         tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
          if min_offset is not None:
              start_time = anchor_time + min_offset
              start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-             tmp = tmp.append_column('start', pa.array(start_time))
-         tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
-
-         time_column = self.time_column_dict[table_name]
-         sql = (f"SELECT tmp.rowid - 1 as __batch__, "
-                f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+             tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+         tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
                 f"FROM {quote_ident(tmp_name)} tmp\n"
-                f"JOIN {quote_ident(table_name)} fact\n"
-                f" ON fact.{quote_ident(fkey)} = tmp.id\n"
-                f" AND fact.{quote_ident(time_column)} <= tmp.end")
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = tmp.__kumo_id__\n"
+                f" AND {time_ref} <= tmp.__kumo_end__")
          if min_offset is not None:
-             sql += f"\n AND fact.{quote_ident(time_column)} > tmp.start"
+             sql += f"\n AND {time_ref} > tmp.__kumo_start__"

          with self._connection.cursor() as cursor:
              cursor.adbc_ingest(tmp_name, tmp, mode='replace')
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         batch = table['__batch__'].to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__batch__'))
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)

-         return self._sanitize(table_name, table), batch
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

      def _sample_target_set(
          self,
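Note: `_by_pkey` and the new `_by_fkey` share one pattern: the lookup keys are bulk-ingested as a helper table through ADBC, joined against the base table, and `tmp.rowid - 1` travels along as the position of each key in the input batch. A minimal sketch of that pattern, assuming the `adbc_driver_sqlite` package is available (table and column names are illustrative):

    import pyarrow as pa
    from adbc_driver_sqlite import dbapi

    with dbapi.connect(":memory:") as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                "CREATE TABLE users (user_id INTEGER PRIMARY KEY, name TEXT)")
            cursor.execute("INSERT INTO users VALUES (1, 'a'), (2, 'b'), (3, 'c')")

            # Bulk-ingest the lookup keys as a helper table:
            tmp = pa.table([pa.array([3, 1])], names=['__kumo_id__'])
            cursor.adbc_ingest('tmp_lookup', tmp, mode='replace')

            # `tmp.rowid - 1` recovers each matched row's position in the batch:
            cursor.execute(
                "SELECT tmp.rowid - 1 AS __kumo_batch__, ent.name\n"
                "FROM tmp_lookup tmp\n"
                "JOIN users ent ON ent.user_id = tmp.__kumo_id__")
            table = cursor.fetch_arrow_table()

    batch = table['__kumo_batch__'].to_numpy()  # input position per row

Ingesting all keys once and joining beats issuing one point query per key, at the cost of a transient table per call.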
@@ -305,11 +415,11 @@ class SQLiteSampler(SQLSampler):
              query.entity_table: np.arange(len(df)),
          }
          for edge_type, (_min, _max) in time_offset_dict.items():
-             table_name, fkey, _ = edge_type
+             table_name, foreign_key, _ = edge_type
              feat_dict[table_name], batch_dict[table_name] = self._by_time(
                  table_name=table_name,
-                 fkey=fkey,
-                 pkey=df[self.primary_key_dict[query.entity_table]],
+                 foreign_key=foreign_key,
+                 index=df[self.primary_key_dict[query.entity_table]],
                  anchor_time=time,
                  min_offset=_min,
                  max_offset=_max,
@@ -324,7 +434,7 @@ class SQLiteSampler(SQLSampler):
              feat_dict=feat_dict,
              time_dict=time_dict,
              batch_dict=batch_dict,
-             anchor_time=anchor_time,
+             anchor_time=time,
              num_forecasts=query.num_forecasts,
          )
          ys.append(y)
@@ -342,13 +452,3 @@ class SQLiteSampler(SQLSampler):
          y = pd.concat(ys, axis=0, ignore_index=True)

          return y, mask
-
-     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-         df = table.to_pandas(types_mapper=pd.ArrowDtype)
-
-         stype_dict = self.table_stype_dict[table_name]
-         for column_name in df.columns:
-             if stype_dict.get(column_name) == Stype.timestamp:
-                 df[column_name] = pd.to_datetime(df[column_name])
-
-         return df
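Note: the sampler-local `_sanitize` helper removed here, which re-parsed timestamp-typed columns via `pd.to_datetime`, is superseded by the shared `Table._sanitize`, which receives explicit per-table dtype and stype mappings. A rough sketch of the removed helper's behavior, with plain string stype labels standing in for the package's `Stype` enum:

    import pandas as pd

    def sanitize(df: pd.DataFrame, stype_dict: dict[str, str]) -> pd.DataFrame:
        # SQLite returns timestamps as ISO-8601 TEXT; re-parse them into
        # proper datetime columns:
        for column_name in df.columns:
            if stype_dict.get(column_name) == 'timestamp':
                df[column_name] = pd.to_datetime(df[column_name])
        return df

    df = sanitize(pd.DataFrame({'ts': ['2024-01-01 12:00:00']}),
                  {'ts': 'timestamp'})
    print(df['ts'].dtype)  # datetime64[ns]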
kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,18 +1,21 @@
  import re
- import warnings
- from typing import List, Optional, Sequence, cast
+ from collections import Counter
+ from collections.abc import Sequence
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType
  from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.sqlite import Connection
  from kumoai.experimental.rfm.base import (
+     ColumnSpec,
+     ColumnSpecType,
      DataBackend,
      SourceColumn,
      SourceForeignKey,
      Table,
  )
- from kumoai.experimental.rfm.infer import infer_dtype
  from kumoai.utils import quote_ident


@@ -22,6 +25,8 @@ class SQLiteTable(Table):
      Args:
          connection: The connection to a :class:`sqlite` database.
          name: The name of this table.
+         source_name: The source name of this table. If set to ``None``,
+             ``name`` is being used.
          columns: The selected columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
@@ -32,16 +37,18 @@ class SQLiteTable(Table):
          self,
          connection: Connection,
          name: str,
-         columns: Optional[Sequence[str]] = None,
-         primary_key: Optional[str] = None,
-         time_column: Optional[str] = None,
-         end_time_column: Optional[str] = None,
+         source_name: str | None = None,
+         columns: Sequence[ColumnSpecType] | None = None,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
      ) -> None:

          self._connection = connection

          super().__init__(
              name=name,
+             source_name=source_name,
              columns=columns,
              primary_key=primary_key,
              time_column=time_column,
@@ -52,18 +59,19 @@ class SQLiteTable(Table):
      def backend(self) -> DataBackend:
          return cast(DataBackend, DataBackend.SQLITE)

-     def _get_source_columns(self) -> List[SourceColumn]:
-         source_columns: List[SourceColumn] = []
+     def _get_source_columns(self) -> list[SourceColumn]:
+         source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
-             sql = f"PRAGMA table_info({quote_ident(self.name)})"
+             sql = f"PRAGMA table_info({self._quoted_source_name})"
              cursor.execute(sql)
              columns = cursor.fetchall()

              if len(columns) == 0:
-                 raise ValueError(f"Table '{self.name}' does not exist")
+                 raise ValueError(f"Table '{self.source_name}' does not exist "
+                                  f"in the SQLite database")

              unique_keys: set[str] = set()
-             sql = f"PRAGMA index_list({quote_ident(self.name)})"
+             sql = f"PRAGMA index_list({self._quoted_source_name})"
              cursor.execute(sql)
              for _, index_name, is_unique, *_ in cursor.fetchall():
                  if bool(is_unique):
@@ -73,30 +81,19 @@ class SQLiteTable(Table):
                  if len(index) == 1:
                      unique_keys.add(index[0][2])

-             for _, column, type, notnull, _, is_pkey in columns:
-                 # Determine column affinity:
-                 type = type.strip().upper()
-                 if re.search('INT', type):
-                     dtype = Dtype.int
-                 elif re.search('TEXT|CHAR|CLOB', type):
-                     dtype = Dtype.string
-                 elif re.search('REAL|FLOA|DOUB', type):
-                     dtype = Dtype.float
-                 else:  # NUMERIC affinity.
-                     ser = self._sample_df[column]
-                     try:
-                         dtype = infer_dtype(ser)
-                     except Exception:
-                         warnings.warn(
-                             f"Data type inference for column '{column}' in "
-                             f"table '{self.name}' failed. Consider changing "
-                             f"the data type of the column to use it within "
-                             f"this table.")
-                         continue
+             # Special SQLite case that creates a rowid alias for
+             # `INTEGER PRIMARY KEY` annotated columns:
+             rowid_candidates = [
+                 column for _, column, dtype, _, _, is_pkey in columns
+                 if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+             ]
+             if len(rowid_candidates) == 1:
+                 unique_keys.add(rowid_candidates[0])

+             for _, column, dtype, notnull, _, is_pkey in columns:
                  source_column = SourceColumn(
                      name=column,
-                     dtype=dtype,
+                     dtype=self._to_dtype(dtype),
                      is_primary_key=bool(is_pkey),
                      is_unique_key=column in unique_keys,
                      is_nullable=not bool(is_pkey) and not bool(notnull),
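Note: the rowid-candidate branch above encodes a SQLite quirk: a column declared exactly as `INTEGER PRIMARY KEY` becomes an alias for the implicit `rowid` and is therefore unique even though `PRAGMA index_list` reports no index for it, while non-INTEGER primary keys are backed by an automatic unique index. For illustration:

    import sqlite3

    con = sqlite3.connect(":memory:")
    # `id` aliases rowid, so it is unique without any backing index:
    con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    print(con.execute("PRAGMA index_list(users)").fetchall())  # []
    # A TEXT primary key gets an automatic unique index with origin 'pk':
    con.execute("CREATE TABLE tags (tag TEXT PRIMARY KEY)")
    print(con.execute("PRAGMA index_list(tags)").fetchall())
    # e.g. [(0, 'sqlite_autoindex_tags_1', 1, 'pk', 0)]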
@@ -105,22 +102,83 @@ class SQLiteTable(Table):

          return source_columns

-     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-         source_fkeys: List[SourceForeignKey] = []
+     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+         source_foreign_keys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             sql = f"PRAGMA foreign_key_list({quote_ident(self.name)})"
+             sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
              cursor.execute(sql)
-             for _, _, dst_table, fkey, pkey, *_ in cursor.fetchall():
-                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-         return source_fkeys
-
-     def _get_sample_df(self) -> pd.DataFrame:
+             rows = cursor.fetchall()
+             counts = Counter(row[0] for row in rows)
+             for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                 if counts[idx] == 1:
+                     source_foreign_key = SourceForeignKey(
+                         name=foreign_key,
+                         dst_table=dst_table,
+                         primary_key=primary_key,
+                     )
+                     source_foreign_keys.append(source_foreign_key)
+         return source_foreign_keys
+
+     def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             sql = (f"SELECT * FROM {quote_ident(self.name)} "
-                    f"ORDER BY rowid LIMIT 1000")
+             columns = [quote_ident(col) for col in self._source_column_dict]
+             sql = (f"SELECT {', '.join(columns)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"ORDER BY rowid "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()
-         return table.to_pandas(types_mapper=pd.ArrowDtype)

-     def _get_num_rows(self) -> Optional[int]:
+         if len(table) == 0:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={
+                 column.name: column.dtype
+                 for column in self._source_column_dict.values()
+             },
+             stype_dict=None,
+         )
+
+     def _get_num_rows(self) -> int | None:
          return None
+
+     def _get_expr_sample_df(
+         self,
+         columns: Sequence[ColumnSpec],
+     ) -> pd.DataFrame:
+         with self._connection.cursor() as cursor:
+             projections = [
+                 f"{column.expr} AS {quote_ident(column.name)}"
+                 for column in columns
+             ]
+             sql = (f"SELECT {', '.join(projections)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"ORDER BY rowid "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         if len(table) == 0:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={column.name: column.dtype
+                         for column in columns},
+             stype_dict=None,
+         )
+
+     @staticmethod
+     def _to_dtype(dtype: str | None) -> Dtype | None:
+         if dtype is None:
+             return None
+         dtype = dtype.strip().upper()
+         if re.search('INT', dtype):
+             return Dtype.int
+         if re.search('TEXT|CHAR|CLOB', dtype):
+             return Dtype.string
+         if re.search('REAL|FLOA|DOUB', dtype):
+             return Dtype.float
+         return None  # NUMERIC affinity.
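Note: `_to_dtype` mirrors SQLite's declared-type affinity rules: any declared type containing INT maps to integer, TEXT/CHAR/CLOB to string, REAL/FLOA/DOUB to float, and everything else falls through to NUMERIC affinity, returning None so the dtype is presumably inferred from sampled values instead (as the removed NUMERIC branch in `_get_source_columns` used to do). A standalone rendering of the same rules, with plain string labels in place of the package's `Dtype` enum:

    import re

    def to_affinity(declared: str | None) -> str | None:
        # Same branch order as `_to_dtype` above; None means NUMERIC
        # affinity, i.e. the dtype must be inferred from sampled data.
        if declared is None:
            return None
        declared = declared.strip().upper()
        if re.search('INT', declared):
            return 'int'
        if re.search('TEXT|CHAR|CLOB', declared):
            return 'string'
        if re.search('REAL|FLOA|DOUB', declared):
            return 'float'
        return None

    assert to_affinity('VARCHAR(255)') == 'string'
    assert to_affinity('BIGINT') == 'int'
    assert to_affinity('DECIMAL(10, 2)') is None  # NUMERIC affinity.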