kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/__init__.py +33 -8
  4. kumoai/experimental/rfm/authenticate.py +3 -4
  5. kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
  6. kumoai/experimental/rfm/backend/local/sampler.py +128 -55
  7. kumoai/experimental/rfm/backend/local/table.py +21 -16
  8. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  9. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  10. kumoai/experimental/rfm/backend/snow/table.py +101 -49
  11. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  14. kumoai/experimental/rfm/base/__init__.py +24 -5
  15. kumoai/experimental/rfm/base/column.py +14 -12
  16. kumoai/experimental/rfm/base/column_expression.py +50 -0
  17. kumoai/experimental/rfm/base/sampler.py +429 -30
  18. kumoai/experimental/rfm/base/source.py +1 -0
  19. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  20. kumoai/experimental/rfm/base/sql_table.py +229 -0
  21. kumoai/experimental/rfm/base/table.py +165 -135
  22. kumoai/experimental/rfm/graph.py +266 -102
  23. kumoai/experimental/rfm/infer/__init__.py +6 -4
  24. kumoai/experimental/rfm/infer/dtype.py +3 -3
  25. kumoai/experimental/rfm/infer/pkey.py +4 -2
  26. kumoai/experimental/rfm/infer/stype.py +35 -0
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/METADATA +3 -2
  38. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/RECORD +41 -35
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -0,0 +1,349 @@
+ import warnings
+ from collections import defaultdict
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.pquery import ValidatedPredictiveQuery
+ from kumoapi.typing import Stype
+
+ from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger, quote_ident
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class SQLiteSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         if optimize:
+             with self._connection.cursor() as cursor:
+                 cursor.execute("PRAGMA temp_store = MEMORY")
+                 cursor.execute("PRAGMA cache_size = -2000000") # 2 GB
+
+         # Collect database indices to speed-up sampling:
+         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+         for table_name, primary_key in self.primary_key_dict.items():
+             source_table = self.source_table_dict[table_name]
+             if not source_table[primary_key].is_unique_key:
+                 index_dict[table_name].add((primary_key, ))
+         for src_table_name, foreign_key, _ in graph.edges:
+             source_table = self.source_table_dict[src_table_name]
+             if source_table[foreign_key].is_unique_key:
+                 pass
+             elif time_column := self.time_column_dict.get(src_table_name):
+                 index_dict[src_table_name].add((foreign_key, time_column))
+             else:
+                 index_dict[src_table_name].add((foreign_key, ))
+
+         # Only maintain missing indices:
+         with self._connection.cursor() as cursor:
+             for table_name in list(index_dict.keys()):
+                 indices = index_dict[table_name]
+                 sql = f"PRAGMA index_list({self.fqn_dict[table_name]})"
+                 cursor.execute(sql)
+                 for _, index_name, *_ in cursor.fetchall():
+                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                     cursor.execute(sql)
+                     index = tuple(info[2] for info in sorted(
+                         cursor.fetchall(), key=lambda x: x[0]))
+                     indices.discard(index)
+                 if len(indices) == 0:
+                     del index_dict[table_name]
+
+         num = sum(len(indices) for indices in index_dict.values())
+         index_repr = '1 index' if num == 1 else f'{num} indices'
+         num = len(index_dict)
+         table_repr = '1 table' if num == 1 else f'{num} tables'
+
+         if optimize and len(index_dict) > 0:
+             if not isinstance(verbose, ProgressLogger):
+                 verbose = ProgressLogger.default(
+                     msg="Optimizing SQLite database",
+                     verbose=verbose,
+                 )
+
+             with verbose as logger, self._connection.cursor() as cursor:
+                 for table_name, indices in index_dict.items():
+                     for index in indices:
+                         name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                         name = quote_ident(name)
+                         columns = ', '.join(quote_ident(v) for v in index)
+                         columns += ' DESC' if len(index) > 1 else ''
+                         sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                                f"ON {self.fqn_dict[table_name]}({columns})")
+                         cursor.execute(sql)
+                 self._connection.commit()
+                 logger.log(f"Created {index_repr} in {table_repr}")
+
+         elif len(index_dict) > 0:
+             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                           f"database querying. For improving runtime, we "
+                           f"strongly suggest to create these indices by "
+                           f"instantiating KumoRFM via "
+                           f"`KumoRFM(graph, optimize=True)`.")
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         selects: list[str] = []
+         for table_name in table_names:
+             time_column = self.time_column_dict[table_name]
+             select = (f"SELECT\n"
+                       f" ? as table_name,\n"
+                       f" MIN({quote_ident(time_column)}) as min_date,\n"
+                       f" MAX({quote_ident(time_column)}) as max_date\n"
+                       f"FROM {self.fqn_dict[table_name]}")
+             selects.append(select)
+         sql = "\nUNION ALL\n".join(selects)
+
+         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         with self._connection.cursor() as cursor:
+             cursor.execute(sql, table_names)
+             for table_name, _min, _max in cursor.fetchall():
+                 out_dict[table_name] = (
+                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                 )
+         return out_dict
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         # NOTE SQLite does not natively support passing a `random_seed`.
+
+         filters: list[str] = []
+         primary_key = self.primary_key_dict[table_name]
+         if self.source_table_dict[table_name][primary_key].is_nullable:
+             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+         time_column = self.time_column_dict.get(table_name)
+         if (time_column is not None and
+                 self.source_table_dict[table_name][time_column].is_nullable):
+             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         # TODO Make this query more efficient - it does full table scan.
+         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+                f"FROM {self.fqn_dict[table_name]}")
+         if len(filters) > 0:
+             sql += f"\nWHERE{' AND'.join(filters)}"
+         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
+
+         with self._connection.cursor() as cursor:
+             # NOTE This may return duplicate primary keys. This is okay.
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         return self._sanitize(table_name, table)
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+         train_y, train_mask = self._sample_target_set(
+             query=query,
+             entity_df=entity_df,
+             index=train_index,
+             anchor_time=train_time,
+             num_examples=num_train_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         test_y, test_mask = self._sample_target_set(
+             query=query,
+             entity_df=entity_df,
+             index=test_index,
+             anchor_time=test_time,
+             num_examples=num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         return train_y, train_mask, test_y, test_mask
+
+     def _by_pkey(
+         self,
+         table_name: str,
+         pkey: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         pkey_name = self.primary_key_dict[table_name]
+
+         tmp = pa.table([pa.array(pkey)], names=['id'])
+         tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+
+         if self.source_table_dict[table_name][pkey_name].is_unique_key:
+             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                    f"FROM {quote_ident(tmp_name)} tmp\n"
+                    f"JOIN {self.fqn_dict[table_name]} ent\n"
+                    f" ON ent.{quote_ident(pkey_name)} = tmp.id")
+         else:
+             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                    f"FROM {quote_ident(tmp_name)} tmp\n"
+                    f"JOIN {self.fqn_dict[table_name]} ent\n"
+                    f" ON ent.rowid = (\n"
+                    f" SELECT rowid FROM {self.fqn_dict[table_name]}\n"
+                    f" WHERE {quote_ident(pkey_name)} == tmp.id\n"
+                    f" LIMIT 1\n"
+                    f")")
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__batch__'].to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+         return table.to_pandas(), batch # TODO Use `self._sanitize`.
+
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         fkey: str,
+         pkey: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         # NOTE SQLite does not have a native datetime format. Currently, we
+         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+         tmp = pa.table([pa.array(pkey)], names=['id'])
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         tmp = tmp.append_column('end', pa.array(end_time))
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             tmp = tmp.append_column('start', pa.array(start_time))
+         tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+
+         time_column = self.time_column_dict[table_name]
+         sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {self.fqn_dict[table_name]} fact\n"
+                f" ON fact.{quote_ident(fkey)} = tmp.id\n"
+                f" AND fact.{quote_ident(time_column)} <= tmp.end")
+         if min_offset is not None:
+             sql += f"\n AND fact.{quote_ident(time_column)} > tmp.start"
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__batch__'].to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+         return self._sanitize(table_name, table), batch
+
+     def _sample_target_set(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         index: np.ndarray,
+         anchor_time: pd.Series,
+         num_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         batch_size: int = 10_000,
+     ) -> tuple[pd.Series, np.ndarray]:
+
+         count = 0
+         ys: list[pd.Series] = []
+         mask = np.full(len(index), False, dtype=bool)
+         for start in range(0, len(index), batch_size):
+             df = entity_df.iloc[index[start:start + batch_size]]
+             time = anchor_time.iloc[start:start + batch_size]
+
+             feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+             time_dict: dict[str, pd.Series] = {}
+             time_column = self.time_column_dict.get(query.entity_table)
+             if time_column in columns_dict[query.entity_table]:
+                 time_dict[query.entity_table] = df[time_column]
+             batch_dict: dict[str, np.ndarray] = {
+                 query.entity_table: np.arange(len(df)),
+             }
+             for edge_type, (_min, _max) in time_offset_dict.items():
+                 table_name, fkey, _ = edge_type
+                 feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                     table_name=table_name,
+                     fkey=fkey,
+                     pkey=df[self.primary_key_dict[query.entity_table]],
+                     anchor_time=time,
+                     min_offset=_min,
+                     max_offset=_max,
+                     columns=columns_dict[table_name],
+                 )
+                 time_column = self.time_column_dict.get(table_name)
+                 if time_column in columns_dict[table_name]:
+                     time_dict[table_name] = feat_dict[table_name][time_column]
+
+             y, _mask = PQueryPandasExecutor().execute(
+                 query=query,
+                 feat_dict=feat_dict,
+                 time_dict=time_dict,
+                 batch_dict=batch_dict,
+                 anchor_time=anchor_time,
+                 num_forecasts=query.num_forecasts,
+             )
+             ys.append(y)
+             mask[start:start + batch_size] = _mask
+
+             count += len(y)
+             if count >= num_examples:
+                 break
+
+         if len(ys) == 0:
+             y = pd.Series([], dtype=float)
+         elif len(ys) == 1:
+             y = ys[0]
+         else:
+             y = pd.concat(ys, axis=0, ignore_index=True)
+
+         return y, mask
+
+     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
+         df = table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         stype_dict = self.table_stype_dict[table_name]
+         for column_name in df.columns:
+             if stype_dict.get(column_name) == Stype.timestamp:
+                 df[column_name] = pd.to_datetime(df[column_name])
+
+         return df
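
Note: the warning emitted above points users to `KumoRFM(graph, optimize=True)` as the way to have the missing indices created automatically. A minimal usage sketch, assuming `KumoRFM` and `Graph` are exported from `kumoai.experimental.rfm`, with the graph construction elided:

    from kumoai.experimental.rfm import Graph, KumoRFM

    graph: Graph = ...  # a graph over SQLite-backed tables (construction elided)
    rfm = KumoRFM(graph, optimize=True)  # per the warning, creates missing indices up front
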
kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,22 +1,35 @@
  import re
  import warnings
- from typing import List, Optional, Sequence
+ from collections.abc import Sequence
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType
  from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.sqlite import Connection
- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+ from kumoai.experimental.rfm.base import (
+     ColumnExpressionSpec,
+     ColumnExpressionType,
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     SQLTable,
+ )
  from kumoai.experimental.rfm.infer import infer_dtype
+ from kumoai.utils import quote_ident


- class SQLiteTable(Table):
+ class SQLiteTable(SQLTable):
      r"""A table backed by a :class:`sqlite` database.

      Args:
          connection: The connection to a :class:`sqlite` database.
-         name: The name of this table.
-         columns: The selected columns of this table.
+         name: The logical name of this table.
+         source_name: The physical name of this table in the database. If set to
+             ``None``, ``name`` is being used.
+         columns: The selected physical columns of this table.
+         column_expressions: The logical columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
          end_time_column: The name of the end time column of this table, if it
@@ -26,32 +39,53 @@ class SQLiteTable(Table):
          self,
          connection: Connection,
          name: str,
-         columns: Optional[Sequence[str]] = None,
-         primary_key: Optional[str] = None,
-         time_column: Optional[str] = None,
-         end_time_column: Optional[str] = None,
+         source_name: str | None = None,
+         columns: Sequence[str] | None = None,
+         column_expressions: Sequence[ColumnExpressionType] | None = None,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
      ) -> None:

          self._connection = connection

          super().__init__(
              name=name,
+             source_name=source_name,
              columns=columns,
+             column_expressions=column_expressions,
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
          )

-     def _get_source_columns(self) -> List[SourceColumn]:
-         source_columns: List[SourceColumn] = []
+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SQLITE)
+
+     def _get_source_columns(self) -> list[SourceColumn]:
+         source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
-             cursor.execute(f"PRAGMA table_info({self.name})")
-             rows = cursor.fetchall()
+             sql = f"PRAGMA table_info({self.fqn})"
+             cursor.execute(sql)
+             columns = cursor.fetchall()
+
+             if len(columns) == 0:
+                 raise ValueError(f"Table '{self._source_name}' does not exist "
+                                  f"in the SQLite database")

-             if len(rows) == 0:
-                 raise ValueError(f"Table '{self.name}' does not exist")
+             unique_keys: set[str] = set()
+             sql = f"PRAGMA index_list({self.fqn})"
+             cursor.execute(sql)
+             for _, index_name, is_unique, *_ in cursor.fetchall():
+                 if bool(is_unique):
+                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                     cursor.execute(sql)
+                     index = cursor.fetchall()
+                     if len(index) == 1:
+                         unique_keys.add(index[0][2])

-             for _, column, type, _, _, is_pkey in rows:
+             for _, column, type, notnull, _, is_pkey in columns:
                  # Determine column affinity:
                  type = type.strip().upper()
                  if re.search('INT', type):
@@ -61,41 +95,60 @@ class SQLiteTable(Table):
                  elif re.search('REAL|FLOA|DOUB', type):
                      dtype = Dtype.float
                  else: # NUMERIC affinity.
-                     ser = self._sample_df[column]
+                     ser = self._source_sample_df[column]
                      try:
                          dtype = infer_dtype(ser)
                      except Exception:
-                         warnings.warn(
-                             f"Data type inference for column '{column}' in "
-                             f"table '{self.name}' failed. Consider changing "
-                             f"the data type of the column to use it within "
-                             f"this table.")
+                         warnings.warn(f"Encountered unsupported data type "
+                                       f"'{ser.dtype}' with source data type "
+                                       f"'{type}' for column '{column}' in "
+                                       f"table '{self.name}'. If possible, "
+                                       f"change the data type of the column in "
+                                       f"your SQLite database to use it within "
+                                       f"this table.")
                          continue

                  source_column = SourceColumn(
                      name=column,
                      dtype=dtype,
                      is_primary_key=bool(is_pkey),
-                     is_unique_key=False,
+                     is_unique_key=column in unique_keys,
+                     is_nullable=not bool(is_pkey) and not bool(notnull),
                  )
                  source_columns.append(source_column)

          return source_columns

-     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-         source_fkeys: List[SourceForeignKey] = []
+     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+         source_fkeys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             cursor.execute(f"PRAGMA foreign_key_list({self.name})")
-             for _, _, dst_table, fkey, pkey, _, _, _ in cursor.fetchall():
+             sql = f"PRAGMA foreign_key_list({self.fqn})"
+             cursor.execute(sql)
+             for _, _, dst_table, fkey, pkey, *_ in cursor.fetchall():
                  source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
          return source_fkeys

-     def _get_sample_df(self) -> pd.DataFrame:
+     def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             cursor.execute(f"SELECT * FROM {self.name} "
-                            f"ORDER BY rowid LIMIT 1000")
+             sql = (f"SELECT * FROM {self.fqn} "
+                    f"ORDER BY rowid LIMIT 1000")
+             cursor.execute(sql)
              table = cursor.fetch_arrow_table()
          return table.to_pandas(types_mapper=pd.ArrowDtype)

-     def _get_num_rows(self) -> Optional[int]:
+     def _get_num_rows(self) -> int | None:
          return None
+
+     def _get_expression_sample_df(
+         self,
+         specs: Sequence[ColumnExpressionSpec],
+     ) -> pd.DataFrame:
+         with self._connection.cursor() as cursor:
+             columns = [
+                 f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+             ]
+             sql = (f"SELECT {', '.join(columns)} FROM {self.fqn} "
+                    f"ORDER BY rowid LIMIT 1000")
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+         return table.to_pandas(types_mapper=pd.ArrowDtype)
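
Note: a hedged construction sketch based on the updated signature above; how the SQLite `connection` is obtained and the table/column names used here are illustrative assumptions, not part of the diff:

    from kumoai.experimental.rfm.base import ColumnExpressionSpec

    table = SQLiteTable(
        connection=connection,            # an existing sqlite `Connection`
        name='users',                     # logical table name
        source_name='raw_users',          # physical table in the database
        columns=['user_id', 'signup_at'],
        column_expressions=[
            ColumnExpressionSpec(name='signup_year',
                                 expr="strftime('%Y', signup_at)"),
        ],
        primary_key='user_id',
        time_column='signup_at',
    )
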
kumoai/experimental/rfm/base/__init__.py
@@ -1,14 +1,33 @@
- from .source import SourceColumn, SourceForeignKey
- from .column import Column
- from .table import Table
- from .sampler import SamplerOutput, TargetOutput, Sampler
+ from kumoapi.common import StrEnum
+
+
+ class DataBackend(StrEnum):
+     LOCAL = 'local'
+     SQLITE = 'sqlite'
+     SNOWFLAKE = 'snowflake'
+
+
+ from .source import SourceColumn, SourceForeignKey # noqa: E402
+ from .column import Column # noqa: E402
+ from .column_expression import ColumnExpressionSpec # noqa: E402
+ from .column_expression import ColumnExpressionType # noqa: E402
+ from .column_expression import ColumnExpression # noqa: E402
+ from .table import Table # noqa: E402
+ from .sql_table import SQLTable # noqa: E402
+ from .sampler import SamplerOutput, Sampler # noqa: E402
+ from .sql_sampler import SQLSampler # noqa: E402

  __all__ = [
+     'DataBackend',
      'SourceColumn',
      'SourceForeignKey',
      'Column',
+     'ColumnExpressionSpec',
+     'ColumnExpressionType',
+     'ColumnExpression',
      'Table',
+     'SQLTable',
      'SamplerOutput',
-     'TargetOutput',
      'Sampler',
+     'SQLSampler',
  ]
kumoai/experimental/rfm/base/column.py
@@ -8,20 +8,14 @@ from kumoapi.typing import Dtype, Stype
  class Column:
      stype: Stype

-     def __init__(
-         self,
-         name: str,
-         dtype: Dtype,
-         stype: Stype,
-         is_primary_key: bool = False,
-         is_time_column: bool = False,
-         is_end_time_column: bool = False,
-     ) -> None:
+     def __init__(self, name: str, stype: Stype, dtype: Dtype) -> None:
          self._name = name
          self._dtype = Dtype(dtype)
-         self._is_primary_key = is_primary_key
-         self._is_time_column = is_time_column
-         self._is_end_time_column = is_end_time_column
+
+         self._is_primary_key = False
+         self._is_time_column = False
+         self._is_end_time_column = False
+
          self.stype = Stype(stype)

      @property
@@ -32,6 +26,14 @@ class Column:
      def dtype(self) -> Dtype:
          return self._dtype

+     @property
+     def is_physical(self) -> bool:
+         return True
+
+     @property
+     def is_logical(self) -> bool:
+         return not self.is_physical
+
      def __setattr__(self, key: str, val: Any) -> None:
          if key == 'stype':
              if isinstance(val, str):
kumoai/experimental/rfm/base/column_expression.py
@@ -0,0 +1,50 @@
+ from dataclasses import dataclass
+ from typing import Any, TypeAlias
+
+ from kumoapi.typing import Dtype, Stype
+
+ from kumoai.experimental.rfm.base import Column
+ from kumoai.mixin import CastMixin
+
+
+ @dataclass(frozen=True)
+ class ColumnExpressionSpec(CastMixin):
+     name: str
+     expr: str
+     dtype: Dtype | None = None
+
+
+ ColumnExpressionType: TypeAlias = ColumnExpressionSpec | dict[str, Any]
+
+
+ @dataclass(init=False, repr=False, eq=False)
+ class ColumnExpression(Column):
+     def __init__(
+         self,
+         name: str,
+         expr: str,
+         stype: Stype,
+         dtype: Dtype,
+     ) -> None:
+         super().__init__(name=name, stype=stype, dtype=dtype)
+         self._expr = expr
+
+     @property
+     def expr(self) -> str:
+         return self._expr
+
+     @property
+     def is_physical(self) -> bool:
+         return False
+
+     def __hash__(self) -> int:
+         return hash((self.name, self.expr, self.stype, self.dtype))
+
+     def __eq__(self, other: Any) -> bool:
+         if not isinstance(other, ColumnExpression):
+             return False
+         return hash(self) == hash(other)
+
+     def __repr__(self) -> str:
+         return (f'{self.__class__.__name__}(name={self.name}, '
+                 f'expr={self.expr}, stype={self.stype}, dtype={self.dtype})')
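
Note: a small sketch of the equality semantics defined above; two logical columns with the same `(name, expr, stype, dtype)` compare equal because `__eq__` is derived from `__hash__`. The concrete `Stype`/`Dtype` members used here are assumptions for illustration:

    from kumoapi.typing import Dtype, Stype

    a = ColumnExpression(name='total', expr='price * quantity',
                         stype=Stype.numerical, dtype=Dtype.float)
    b = ColumnExpression(name='total', expr='price * quantity',
                         stype=Stype.numerical, dtype=Dtype.float)
    assert a == b
    assert a.is_logical and not a.is_physical
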