kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +4 -5
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
  12. kumoai/experimental/rfm/backend/snow/table.py +166 -56
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
  15. kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
  16. kumoai/experimental/rfm/base/__init__.py +8 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +36 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +10 -5
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +606 -361
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +192 -13
  44. kumoai/utils/sql.py +2 -2
  45. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
  46. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
  47. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
  48. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
  49. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -1,30 +1,117 @@
- from typing import TYPE_CHECKING, Literal
+ import warnings
+ from collections import defaultdict
+ from typing import TYPE_CHECKING

  import numpy as np
  import pandas as pd
+ import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery

  from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
- from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+ from kumoai.experimental.rfm.base import SQLSampler, Table
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
  from kumoai.utils import ProgressLogger, quote_ident

  if TYPE_CHECKING:
      from kumoai.experimental.rfm import Graph


- class SQLiteSampler(Sampler):
+ class SQLiteSampler(SQLSampler):
      def __init__(
          self,
          graph: 'Graph',
          verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
-         super().__init__(graph=graph)
+         super().__init__(graph=graph, verbose=verbose)

          for table in graph.tables.values():
              assert isinstance(table, SQLiteTable)
              self._connection = table._connection

-         # TODO Check for indices being present.
+         if optimize:
+             with self._connection.cursor() as cursor:
+                 cursor.execute("PRAGMA temp_store = MEMORY")
+                 cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
+
+         # Collect database indices for speeding sampling:
+         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+         for table_name, primary_key in self.primary_key_dict.items():
+             source_table = self.source_table_dict[table_name]
+             if primary_key not in source_table:
+                 continue  # No physical column.
+             if source_table[primary_key].is_unique_key:
+                 continue
+             index_dict[table_name].add((primary_key, ))
+         for src_table_name, foreign_key, _ in graph.edges:
+             source_table = self.source_table_dict[src_table_name]
+             if foreign_key not in source_table:
+                 continue  # No physical column.
+             if source_table[foreign_key].is_unique_key:
+                 continue
+             time_column = self.time_column_dict.get(src_table_name)
+             if time_column is not None and time_column in source_table:
+                 index_dict[src_table_name].add((foreign_key, time_column))
+             else:
+                 index_dict[src_table_name].add((foreign_key, ))
+
+         # Only maintain missing indices:
+         with self._connection.cursor() as cursor:
+             for table_name in list(index_dict.keys()):
+                 indices = index_dict[table_name]
+                 source_name = self.source_name_dict[table_name]
+                 sql = f"PRAGMA index_list({source_name})"
+                 cursor.execute(sql)
+                 for _, index_name, *_ in cursor.fetchall():
+                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                     cursor.execute(sql)
+                     # Fetch index information and sort by `seqno`:
+                     index_info = tuple(info[2] for info in sorted(
+                         cursor.fetchall(), key=lambda x: x[0]))
+                     # Remove all indices in case primary index already exists:
+                     for index in list(indices):
+                         if index_info[0] == index[0]:
+                             indices.discard(index)
+                 if len(indices) == 0:
+                     del index_dict[table_name]
+
+         if optimize and len(index_dict) > 0:
+             if not isinstance(verbose, ProgressLogger):
+                 verbose = ProgressLogger.default(
+                     msg="Optimizing SQLite database",
+                     verbose=verbose,
+                 )
+
+             with verbose as logger, self._connection.cursor() as cursor:
+                 for table_name, indices in index_dict.items():
+                     for index in indices:
+                         name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                         name = quote_ident(name)
+                         columns = ', '.join(quote_ident(v) for v in index)
+                         columns += ' DESC' if len(index) > 1 else ''
+                         source_name = self.source_name_dict[table_name]
+                         sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                                f"ON {source_name}({columns})")
+                         cursor.execute(sql)
+                         self._connection.commit()
+                         if len(index) > 1:
+                             logger.log(f"Created index on {index} in table "
+                                        f"'{table_name}'")
+                         else:
+                             logger.log(f"Created index on '{index[0]}' in "
+                                        f"table '{table_name}'")
+
+         elif len(index_dict) > 0:
+             num = sum(len(indices) for indices in index_dict.values())
+             index_repr = '1 index' if num == 1 else f'{num} indices'
+             num = len(index_dict)
+             table_repr = '1 table' if num == 1 else f'{num} tables'
+             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                           f"database querying. For improving runtime, we "
+                           f"strongly suggest to create indices for primary "
+                           f"and foreign keys, e.g., automatically by "
+                           f"instantiating KumoRFM via "
+                           f"`KumoRFM(graph, optimize=True)`.")

      def _get_min_max_time_dict(
          self,
@@ -32,12 +119,13 @@ class SQLiteSampler(Sampler):
      ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
          selects: list[str] = []
          for table_name in table_names:
-             time_column = self.time_column_dict[table_name]
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
              select = (f"SELECT\n"
                        f" ? as table_name,\n"
-                       f" MIN({quote_ident(time_column)}) as min_date,\n"
-                       f" MAX({quote_ident(time_column)}) as max_date\n"
-                       f"FROM {quote_ident(table_name)}")
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
              selects.append(select)
          sql = "\nUNION ALL\n".join(selects)

@@ -51,16 +139,6 @@ class SQLiteSampler(Sampler):
          )
          return out_dict

-     def _sample_subgraph(
-         self,
-         entity_table_name: str,
-         entity_pkey: pd.Series,
-         anchor_time: pd.Series | Literal['entity'],
-         columns_dict: dict[str, set[str]],
-         num_neighbors: list[int],
-     ) -> SamplerOutput:
-         raise NotImplementedError
-
      def _sample_entity_table(
          self,
          table_name: str,
@@ -70,18 +148,28 @@
      ) -> pd.DataFrame:
          # NOTE SQLite does not natively support passing a `random_seed`.

+         source_table = self.source_table_dict[table_name]
          filters: list[str] = []
-         primary_key = self.primary_key_dict[table_name]
-         if self.source_table_dict[table_name][primary_key].is_nullable:
-             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-         time_column = self.time_column_dict.get(table_name)
-         if (time_column is not None and
-                 self.source_table_dict[table_name][time_column].is_nullable):
-             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")

          # TODO Make this query more efficient - it does full table scan.
-         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-                f"FROM {quote_ident(table_name)}")
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}")
          if len(filters) > 0:
              sql += f"\nWHERE{' AND'.join(filters)}"
          sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
@@ -91,7 +179,11 @@
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()

-         return table.to_pandas(types_mapper=pd.ArrowDtype)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )

      def _sample_target(
          self,
@@ -109,4 +201,254 @@
              tuple[pd.DateOffset | None, pd.DateOffset],
          ],
      ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
-         raise NotImplementedError
+         train_y, train_mask = self._sample_target_set(
+             query=query,
+             entity_df=entity_df,
+             index=train_index,
+             anchor_time=train_time,
+             num_examples=num_train_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         test_y, test_mask = self._sample_target_set(
+             query=query,
+             entity_df=entity_df,
+             index=test_index,
+             anchor_time=test_time,
+             num_examples=num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         return train_y, train_mask, test_y, test_mask
+
+     def _by_pkey(
+         self,
+         table_name: str,
+         index: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         source_table = self.source_table_dict[table_name]
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+         tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.source_name_dict[table_name]} ent\n")
+         if key in source_table and source_table[key].is_unique_key:
+             sql += (f" ON {key_ref} = tmp.__kumo_id__")
+         else:
+             sql += (f" ON ent.rowid = (\n"
+                     f" SELECT rowid\n"
+                     f" FROM {self.source_name_dict[table_name]}\n"
+                     f" WHERE {key_ref} == tmp.__kumo_id__\n"
+                     f" LIMIT 1\n"
+                     f")")
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
+     def _by_fkey(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         num_neighbors: int,
+         anchor_time: pd.Series | None,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict.get(table_name)
+
+         # NOTE SQLite does not have a native datetime format. Currently, we
+         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+         tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+         if time_column is not None and anchor_time is not None:
+             anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+         tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.source_name_dict[table_name]} fact\n"
+                f"ON fact.rowid IN (\n"
+                f" SELECT rowid\n"
+                f" FROM {self.source_name_dict[table_name]}\n"
+                f" WHERE {key_ref} = tmp.__kumo_id__\n")
+         if time_column is not None and anchor_time is not None:
+             time_ref = self.table_column_ref_dict[table_name][time_column]
+             sql += f" AND {time_ref} <= tmp.__kumo_time__\n"
+         if time_column is not None:
+             time_ref = self.table_column_ref_dict[table_name][time_column]
+             sql += f" ORDER BY {time_ref} DESC\n"
+         sql += (f" LIMIT {num_neighbors}\n"
+                 f")")
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]
+
+         # NOTE SQLite does not have a native datetime format. Currently, we
+         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+         tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+         tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT "
+                f"tmp.rowid - 1 as __kumo_batch__, "
+                f"{', '.join(projections)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = tmp.__kumo_id__\n"
+                f" AND {time_ref} <= tmp.__kumo_end__")
+         if min_offset is not None:
+             sql += f"\n AND {time_ref} > tmp.__kumo_start__"
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__kumo_batch__'].to_numpy()
+         batch_index = table.schema.get_field_index('__kumo_batch__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
+     def _sample_target_set(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         index: np.ndarray,
+         anchor_time: pd.Series,
+         num_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         batch_size: int = 10_000,
+     ) -> tuple[pd.Series, np.ndarray]:
+
+         count = 0
+         ys: list[pd.Series] = []
+         mask = np.full(len(index), False, dtype=bool)
+         for start in range(0, len(index), batch_size):
+             df = entity_df.iloc[index[start:start + batch_size]]
+             time = anchor_time.iloc[start:start + batch_size]
+
+             feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+             time_dict: dict[str, pd.Series] = {}
+             time_column = self.time_column_dict.get(query.entity_table)
+             if time_column in columns_dict[query.entity_table]:
+                 time_dict[query.entity_table] = df[time_column]
+             batch_dict: dict[str, np.ndarray] = {
+                 query.entity_table: np.arange(len(df)),
+             }
+             for edge_type, (_min, _max) in time_offset_dict.items():
+                 table_name, foreign_key, _ = edge_type
+                 feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                     table_name=table_name,
+                     foreign_key=foreign_key,
+                     index=df[self.primary_key_dict[query.entity_table]],
+                     anchor_time=time,
+                     min_offset=_min,
+                     max_offset=_max,
+                     columns=columns_dict[table_name],
+                 )
+                 time_column = self.time_column_dict.get(table_name)
+                 if time_column in columns_dict[table_name]:
+                     time_dict[table_name] = feat_dict[table_name][time_column]
+
+             y, _mask = PQueryPandasExecutor().execute(
+                 query=query,
+                 feat_dict=feat_dict,
+                 time_dict=time_dict,
+                 batch_dict=batch_dict,
+                 anchor_time=time,
+                 num_forecasts=query.num_forecasts,
+             )
+             ys.append(y)
+             mask[start:start + batch_size] = _mask
+
+             count += len(y)
+             if count >= num_examples:
+                 break
+
+         if len(ys) == 0:
+             y = pd.Series([], dtype=float)
+         elif len(ys) == 1:
+             y = ys[0]
+         else:
+             y = pd.concat(ys, axis=0, ignore_index=True)
+
+         return y, mask
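
Note on the sampling pattern above: `_by_pkey`, `_by_fkey`, and `_by_time` all stage the batch of lookup keys as a temporary table (ingested via ADBC from a pyarrow table), join it against the fact table, and carry `tmp.rowid - 1` through the query so that result rows can be mapped back to their position in the input batch. A minimal standalone sketch of that pattern, using `adbc_driver_sqlite` and pyarrow directly; the `orders`, `user_id`, and `ts` names are illustrative only, and the real sampler resolves identifiers through its `source_name_dict` / `table_column_ref_dict` mappings:

import pandas as pd
import pyarrow as pa
import adbc_driver_sqlite.dbapi as sqlite_dbapi

conn = sqlite_dbapi.connect("example.db")
cursor = conn.cursor()
# Hypothetical fact table; the real code targets the graph's source tables.
cursor.execute("CREATE TABLE IF NOT EXISTS orders (user_id INT, ts TEXT, amount REAL)")

# Stage the lookup keys as a temporary table. `rowid - 1` later recovers the
# position of each key within the input batch.
keys = pa.table([pa.array([1, 2, 3])], names=["__kumo_id__"])
cursor.adbc_ingest("tmp_keys", keys, mode="replace")

# For every key, fetch up to 10 of its most recent rows via a correlated
# subquery on rowid, mirroring the `_by_fkey` query shape.
cursor.execute("""
    SELECT tmp.rowid - 1 AS __kumo_batch__, fact.*
    FROM tmp_keys tmp
    JOIN orders fact ON fact.rowid IN (
        SELECT rowid FROM orders
        WHERE user_id = tmp.__kumo_id__
        ORDER BY ts DESC
        LIMIT 10
    )
""")
result = cursor.fetch_arrow_table()

batch = result["__kumo_batch__"].to_numpy()         # maps rows back to input positions
df = result.to_pandas(types_mapper=pd.ArrowDtype)   # feature rows
cursor.close()
conn.close()

The per-key LIMIT explains why the `optimize` path in `__init__` creates `(foreign_key, time_column)` indices: without them, each correlated subquery degenerates into a table scan.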
kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,18 +1,21 @@
  import re
- import warnings
- from typing import List, Optional, Sequence, cast
+ from collections import Counter
+ from collections.abc import Sequence
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType
  from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.sqlite import Connection
  from kumoai.experimental.rfm.base import (
+     ColumnSpec,
+     ColumnSpecType,
      DataBackend,
      SourceColumn,
      SourceForeignKey,
      Table,
  )
- from kumoai.experimental.rfm.infer import infer_dtype
  from kumoai.utils import quote_ident


@@ -22,6 +25,8 @@ class SQLiteTable(Table):
      Args:
          connection: The connection to a :class:`sqlite` database.
          name: The name of this table.
+         source_name: The source name of this table. If set to ``None``,
+             ``name`` is being used.
          columns: The selected columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
@@ -32,16 +37,18 @@
          self,
          connection: Connection,
          name: str,
-         columns: Optional[Sequence[str]] = None,
-         primary_key: Optional[str] = None,
-         time_column: Optional[str] = None,
-         end_time_column: Optional[str] = None,
+         source_name: str | None = None,
+         columns: Sequence[ColumnSpecType] | None = None,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
      ) -> None:

          self._connection = connection

          super().__init__(
              name=name,
+             source_name=source_name,
              columns=columns,
              primary_key=primary_key,
              time_column=time_column,
@@ -52,64 +59,126 @@
      def backend(self) -> DataBackend:
          return cast(DataBackend, DataBackend.SQLITE)

-     def _get_source_columns(self) -> List[SourceColumn]:
-         source_columns: List[SourceColumn] = []
+     def _get_source_columns(self) -> list[SourceColumn]:
+         source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
-             sql = f"PRAGMA table_info({quote_ident(self.name)})"
+             sql = f"PRAGMA table_info({self._quoted_source_name})"
              cursor.execute(sql)
-             rows = cursor.fetchall()
+             columns = cursor.fetchall()

-             if len(rows) == 0:
-                 raise ValueError(f"Table '{self.name}' does not exist")
-
-             for _, column, type, notnull, _, is_pkey in rows:
-                 # Determine column affinity:
-                 type = type.strip().upper()
-                 if re.search('INT', type):
-                     dtype = Dtype.int
-                 elif re.search('TEXT|CHAR|CLOB', type):
-                     dtype = Dtype.string
-                 elif re.search('REAL|FLOA|DOUB', type):
-                     dtype = Dtype.float
-                 else:  # NUMERIC affinity.
-                     ser = self._sample_df[column]
-                     try:
-                         dtype = infer_dtype(ser)
-                     except Exception:
-                         warnings.warn(
-                             f"Data type inference for column '{column}' in "
-                             f"table '{self.name}' failed. Consider changing "
-                             f"the data type of the column to use it within "
-                             f"this table.")
-                         continue
+             if len(columns) == 0:
+                 raise ValueError(f"Table '{self.source_name}' does not exist "
+                                  f"in the SQLite database")

+             unique_keys: set[str] = set()
+             sql = f"PRAGMA index_list({self._quoted_source_name})"
+             cursor.execute(sql)
+             for _, index_name, is_unique, *_ in cursor.fetchall():
+                 if bool(is_unique):
+                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                     cursor.execute(sql)
+                     index = cursor.fetchall()
+                     if len(index) == 1:
+                         unique_keys.add(index[0][2])
+
+             # Special SQLite case that creates a rowid alias for
+             # `INTEGER PRIMARY KEY` annotated columns:
+             rowid_candidates = [
+                 column for _, column, dtype, _, _, is_pkey in columns
+                 if bool(is_pkey) and dtype.strip().upper() == 'INTEGER'
+             ]
+             if len(rowid_candidates) == 1:
+                 unique_keys.add(rowid_candidates[0])
+
+             for _, column, dtype, notnull, _, is_pkey in columns:
                  source_column = SourceColumn(
                      name=column,
-                     dtype=dtype,
+                     dtype=self._to_dtype(dtype),
                      is_primary_key=bool(is_pkey),
-                     is_unique_key=False,
+                     is_unique_key=column in unique_keys,
                      is_nullable=not bool(is_pkey) and not bool(notnull),
                  )
                  source_columns.append(source_column)

          return source_columns

-     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-         source_fkeys: List[SourceForeignKey] = []
+     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+         source_foreign_keys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             sql = f"PRAGMA foreign_key_list({quote_ident(self.name)})"
+             sql = f"PRAGMA foreign_key_list({self._quoted_source_name})"
              cursor.execute(sql)
-             for _, _, dst_table, fkey, pkey, _, _, _ in cursor.fetchall():
-                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-         return source_fkeys
-
-     def _get_sample_df(self) -> pd.DataFrame:
+             rows = cursor.fetchall()
+             counts = Counter(row[0] for row in rows)
+             for idx, _, dst_table, foreign_key, primary_key, *_ in rows:
+                 if counts[idx] == 1:
+                     source_foreign_key = SourceForeignKey(
+                         name=foreign_key,
+                         dst_table=dst_table,
+                         primary_key=primary_key,
+                     )
+                     source_foreign_keys.append(source_foreign_key)
+         return source_foreign_keys
+
+     def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             sql = (f"SELECT * FROM {quote_ident(self.name)} "
-                    f"ORDER BY rowid LIMIT 1000")
+             columns = [quote_ident(col) for col in self._source_column_dict]
+             sql = (f"SELECT {', '.join(columns)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"ORDER BY rowid "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_table()
-         return table.to_pandas(types_mapper=pd.ArrowDtype)

-     def _get_num_rows(self) -> Optional[int]:
+         if len(table) == 0:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={
+                 column.name: column.dtype
+                 for column in self._source_column_dict.values()
+             },
+             stype_dict=None,
+         )
+
+     def _get_num_rows(self) -> int | None:
          return None
+
+     def _get_expr_sample_df(
+         self,
+         columns: Sequence[ColumnSpec],
+     ) -> pd.DataFrame:
+         with self._connection.cursor() as cursor:
+             projections = [
+                 f"{column.expr} AS {quote_ident(column.name)}"
+                 for column in columns
+             ]
+             sql = (f"SELECT {', '.join(projections)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"ORDER BY rowid "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         if len(table) == 0:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={column.name: column.dtype
+                         for column in columns},
+             stype_dict=None,
+         )
+
+     @staticmethod
+     def _to_dtype(dtype: str | None) -> Dtype | None:
+         if dtype is None:
+             return None
+         dtype = dtype.strip().upper()
+         if re.search('INT', dtype):
+             return Dtype.int
+         if re.search('TEXT|CHAR|CLOB', dtype):
+             return Dtype.string
+         if re.search('REAL|FLOA|DOUB', dtype):
+             return Dtype.float
+         return None  # NUMERIC affinity.
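
Note on the schema introspection above: `_get_source_columns` relies entirely on SQLite's PRAGMA interface, with `table_info` for declared types and primary keys and `index_list`/`index_info` for single-column unique indices, while the regexes in `_to_dtype` follow SQLite's declared-type affinity rules (anything unmatched falls back to NUMERIC affinity and downstream inference). A rough standalone illustration using the stdlib `sqlite3` module, with string dtype names standing in for `kumoapi.typing.Dtype` and a made-up `users` schema:

import re
import sqlite3

def to_dtype(decl: str | None) -> str | None:
    # Same affinity buckets as `SQLiteTable._to_dtype`.
    if decl is None:
        return None
    decl = decl.strip().upper()
    if re.search('INT', decl):
        return 'int'
    if re.search('TEXT|CHAR|CLOB', decl):
        return 'string'
    if re.search('REAL|FLOA|DOUB', decl):
        return 'float'
    return None  # NUMERIC affinity: left to downstream inference.

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE, age NUMERIC)")

# Single-column unique indices mark columns as unique keys.
unique_keys: set[str] = set()
for _, index_name, is_unique, *_ in conn.execute("PRAGMA index_list(users)"):
    info = list(conn.execute(f"PRAGMA index_info({index_name})"))
    if is_unique and len(info) == 1:
        unique_keys.add(info[0][2])

# The diff additionally treats a lone `INTEGER PRIMARY KEY` column as a
# unique rowid alias; that special case is omitted here for brevity.
for _, name, decl, notnull, _, is_pkey in conn.execute("PRAGMA table_info(users)"):
    print(name, to_dtype(decl), bool(is_pkey), name in unique_keys)

On this schema, `email` is detected through its auto-created unique index, while `age` has NUMERIC affinity, yields no declared dtype, and would instead be inferred from sampled rows.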