kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
  4. kumoai/experimental/rfm/backend/local/sampler.py +213 -14
  5. kumoai/experimental/rfm/backend/local/table.py +12 -2
  6. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  7. kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
  8. kumoai/experimental/rfm/backend/snow/table.py +35 -17
  9. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
  10. kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
  11. kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
  12. kumoai/experimental/rfm/base/__init__.py +17 -6
  13. kumoai/experimental/rfm/base/sampler.py +438 -38
  14. kumoai/experimental/rfm/base/source.py +1 -0
  15. kumoai/experimental/rfm/base/sql_sampler.py +56 -0
  16. kumoai/experimental/rfm/base/table.py +12 -1
  17. kumoai/experimental/rfm/graph.py +26 -9
  18. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  19. kumoai/experimental/rfm/rfm.py +214 -151
  20. kumoai/pquery/predictive_query.py +10 -6
  21. kumoai/testing/snow.py +50 -0
  22. kumoai/utils/__init__.py +2 -0
  23. kumoai/utils/sql.py +3 -0
  24. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
  25. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
  26. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  27. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  28. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
  29. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
  30. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -0,0 +1,354 @@
+ import warnings
+ from collections import defaultdict
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.pquery import ValidatedPredictiveQuery
+ from kumoapi.typing import Stype
+
+ from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+ from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import InteractiveProgressLogger, ProgressLogger, quote_ident
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class SQLiteSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         for table in graph.tables.values():
+             assert isinstance(table, SQLiteTable)
+             self._connection = table._connection
+
+         if optimize:
+             with self._connection.cursor() as cursor:
+                 cursor.execute("PRAGMA temp_store = MEMORY")
+                 cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
+
+         # Collect database indices to speed-up sampling:
+         index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+         for table_name, primary_key in self.primary_key_dict.items():
+             source_table = self.source_table_dict[table_name]
+             if not source_table[primary_key].is_unique_key:
+                 index_dict[table_name].add((primary_key, ))
+         for src_table_name, foreign_key, _ in graph.edges:
+             source_table = self.source_table_dict[src_table_name]
+             if source_table[foreign_key].is_unique_key:
+                 pass
+             elif time_column := self.time_column_dict.get(src_table_name):
+                 index_dict[src_table_name].add((foreign_key, time_column))
+             else:
+                 index_dict[src_table_name].add((foreign_key, ))
+
+         # Only maintain missing indices:
+         with self._connection.cursor() as cursor:
+             for table_name in list(index_dict.keys()):
+                 indices = index_dict[table_name]
+                 sql = f"PRAGMA index_list({quote_ident(table_name)})"
+                 cursor.execute(sql)
+                 for _, index_name, *_ in cursor.fetchall():
+                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                     cursor.execute(sql)
+                     index = tuple(info[2] for info in sorted(
+                         cursor.fetchall(), key=lambda x: x[0]))
+                     indices.discard(index)
+                 if len(indices) == 0:
+                     del index_dict[table_name]
+
+         num = sum(len(indices) for indices in index_dict.values())
+         index_repr = '1 index' if num == 1 else f'{num} indices'
+         num = len(index_dict)
+         table_repr = '1 table' if num == 1 else f'{num} tables'
+
+         if optimize and len(index_dict) > 0:
+             if not isinstance(verbose, ProgressLogger):
+                 verbose = InteractiveProgressLogger(
+                     "Optimizing SQLite database",
+                     verbose=verbose,
+                 )
+
+             with verbose as logger:
+                 with self._connection.cursor() as cursor:
+                     for table_name, indices in index_dict.items():
+                         for index in indices:
+                             name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                             columns = ', '.join(quote_ident(v) for v in index)
+                             columns += ' DESC' if len(index) > 1 else ''
+                             sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                                    f"ON {quote_ident(table_name)}({columns})")
+                             cursor.execute(sql)
+                     self._connection.commit()
+                 logger.log(f"Created {index_repr} in {table_repr}")
+
+         elif len(index_dict) > 0:
+             warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                           f"database querying. For improving runtime, we "
+                           f"strongly suggest to create these indices by "
+                           f"instantiating KumoRFM via "
+                           f"`KumoRFM(graph, optimize=True)`.")
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         selects: list[str] = []
+         for table_name in table_names:
+             time_column = self.time_column_dict[table_name]
+             select = (f"SELECT\n"
+                       f" ? as table_name,\n"
+                       f" MIN({quote_ident(time_column)}) as min_date,\n"
+                       f" MAX({quote_ident(time_column)}) as max_date\n"
+                       f"FROM {quote_ident(table_name)}")
+             selects.append(select)
+         sql = "\nUNION ALL\n".join(selects)
+
+         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         with self._connection.cursor() as cursor:
+             cursor.execute(sql, table_names)
+             for table_name, _min, _max in cursor.fetchall():
+                 out_dict[table_name] = (
+                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                 )
+         return out_dict
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         # NOTE SQLite does not natively support passing a `random_seed`.
+
+         filters: list[str] = []
+         primary_key = self.primary_key_dict[table_name]
+         if self.source_table_dict[table_name][primary_key].is_nullable:
+             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+         time_column = self.time_column_dict.get(table_name)
+         if (time_column is not None and
+                 self.source_table_dict[table_name][time_column].is_nullable):
+             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         # TODO Make this query more efficient - it does full table scan.
+         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+                f"FROM {quote_ident(table_name)}")
+         if len(filters) > 0:
+             sql += f"\nWHERE{' AND'.join(filters)}"
+         sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
+
+         with self._connection.cursor() as cursor:
+             # NOTE This may return duplicate primary keys. This is okay.
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         return self._sanitize(table_name, table)
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+         train_y, train_mask = self._sample_target_set(
+             query=query,
+             entity_df=entity_df,
+             index=train_index,
+             anchor_time=train_time,
+             num_examples=num_train_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         test_y, test_mask = self._sample_target_set(
+             query=query,
+             entity_df=entity_df,
+             index=test_index,
+             anchor_time=test_time,
+             num_examples=num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         return train_y, train_mask, test_y, test_mask
+
+     def _by_pkey(
+         self,
+         table_name: str,
+         pkey: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         pkey_name = self.primary_key_dict[table_name]
+
+         tmp = pa.table([pa.array(pkey)], names=['id'])
+         tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+
+         if self.source_table_dict[table_name][pkey_name].is_unique_key:
+             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                    f"FROM {quote_ident(tmp_name)} tmp\n"
+                    f"JOIN {quote_ident(table_name)} ent\n"
+                    f" ON ent.{quote_ident(pkey_name)} = tmp.id")
+         else:
+             sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                    f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                    f"FROM {quote_ident(tmp_name)} tmp\n"
+                    f"JOIN {quote_ident(table_name)} ent\n"
+                    f" ON ent.rowid = (\n"
+                    f" SELECT rowid FROM {quote_ident(table_name)}\n"
+                    f" WHERE {quote_ident(pkey_name)} == tmp.id\n"
+                    f" LIMIT 1\n"
+                    f")")
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__batch__'].to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         fkey: str,
+         pkey: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         # NOTE SQLite does not have a native datetime format. Currently, we
+         # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+         tmp = pa.table([pa.array(pkey)], names=['id'])
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         tmp = tmp.append_column('end', pa.array(end_time))
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             tmp = tmp.append_column('start', pa.array(start_time))
+         tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+
+         time_column = self.time_column_dict[table_name]
+         sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+                f"FROM {quote_ident(tmp_name)} tmp\n"
+                f"JOIN {quote_ident(table_name)} fact\n"
+                f" ON fact.{quote_ident(fkey)} = tmp.id\n"
+                f" AND fact.{quote_ident(time_column)} <= tmp.end")
+         if min_offset is not None:
+             sql += f"\n AND fact.{quote_ident(time_column)} > tmp.start"
+
+         with self._connection.cursor() as cursor:
+             cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_table()
+
+         batch = table['__batch__'].to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+         return self._sanitize(table_name, table), batch
+
+     def _sample_target_set(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         index: np.ndarray,
+         anchor_time: pd.Series,
+         num_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         batch_size: int = 10_000,
+     ) -> tuple[pd.Series, np.ndarray]:
+
+         count = 0
+         ys: list[pd.Series] = []
+         mask = np.full(len(index), False, dtype=bool)
+         for start in range(0, len(index), batch_size):
+             df = entity_df.iloc[index[start:start + batch_size]]
+             time = anchor_time.iloc[start:start + batch_size]
+
+             feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+             time_dict: dict[str, pd.Series] = {}
+             time_column = self.time_column_dict.get(query.entity_table)
+             if time_column in columns_dict[query.entity_table]:
+                 time_dict[query.entity_table] = df[time_column]
+             batch_dict: dict[str, np.ndarray] = {
+                 query.entity_table: np.arange(len(df)),
+             }
+             for edge_type, (_min, _max) in time_offset_dict.items():
+                 table_name, fkey, _ = edge_type
+                 feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                     table_name=table_name,
+                     fkey=fkey,
+                     pkey=df[self.primary_key_dict[query.entity_table]],
+                     anchor_time=time,
+                     min_offset=_min,
+                     max_offset=_max,
+                     columns=columns_dict[table_name],
+                 )
+                 time_column = self.time_column_dict.get(table_name)
+                 if time_column in columns_dict[table_name]:
+                     time_dict[table_name] = feat_dict[table_name][time_column]
+
+             y, _mask = PQueryPandasExecutor().execute(
+                 query=query,
+                 feat_dict=feat_dict,
+                 time_dict=time_dict,
+                 batch_dict=batch_dict,
+                 anchor_time=anchor_time,
+                 num_forecasts=query.num_forecasts,
+             )
+             ys.append(y)
+             mask[start:start + batch_size] = _mask
+
+             count += len(y)
+             if count >= num_examples:
+                 break
+
+         if len(ys) == 0:
+             y = pd.Series([], dtype=float)
+         elif len(ys) == 1:
+             y = ys[0]
+         else:
+             y = pd.concat(ys, axis=0, ignore_index=True)
+
+         return y, mask
+
+     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
+         df = table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         stype_dict = self.table_stype_dict[table_name]
+         for column_name in df.columns:
+             if stype_dict.get(column_name) == Stype.timestamp:
+                 df[column_name] = pd.to_datetime(df[column_name])
+
+         return df
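
The `_by_time` helper above materializes per-entity anchor windows as a temporary table and joins the fact table against it, carrying `tmp.rowid - 1` through the join so each fact row can be scattered back to its seed entity. A rough standalone sketch of that pattern with the standard-library `sqlite3` module and pandas (the actual sampler ingests a PyArrow table via ADBC's `adbc_ingest` and fetches Arrow results; all table and column names below are illustrative):

    import sqlite3

    import pandas as pd

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE events (user_id INTEGER, created_at TEXT, amount REAL)")
    con.executemany("INSERT INTO events VALUES (?, ?, ?)", [
        (1, "2024-01-05 00:00:00", 10.0),
        (1, "2024-03-01 00:00:00", 25.0),
        (2, "2024-02-10 00:00:00", 5.0),
    ])

    # Seed entities and their per-row upper time bounds (ISO-8601 text, as the
    # sampler assumes) go into a temporary table:
    seeds = pd.DataFrame({
        "id": [1, 2],
        "end": ["2024-02-01 00:00:00", "2024-03-01 00:00:00"],
    })
    seeds.to_sql("tmp_seeds", con, index=False, if_exists="replace")

    # Join facts against the seeds; `tmp.rowid - 1` recovers the batch index:
    rows = con.execute(
        "SELECT tmp.rowid - 1 AS __batch__, fact.* "
        "FROM tmp_seeds tmp "
        "JOIN events fact "
        "  ON fact.user_id = tmp.id "
        " AND fact.created_at <= tmp.end"
    ).fetchall()
    print(rows)
    # e.g. [(0, 1, '2024-01-05 00:00:00', 10.0), (1, 2, '2024-02-10 00:00:00', 5.0)]

Because timestamps are stored as ISO-8601 text, the lexicographic comparison in SQL matches chronological order, which is what the sampler relies on.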
kumoai/experimental/rfm/backend/sqlite/table.py
@@ -1,13 +1,19 @@
  import re
  import warnings
- from typing import List, Optional, Sequence
+ from typing import List, Optional, Sequence, cast

  import pandas as pd
  from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.sqlite import Connection
- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+ from kumoai.experimental.rfm.base import (
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     Table,
+ )
  from kumoai.experimental.rfm.infer import infer_dtype
+ from kumoai.utils import quote_ident


  class SQLiteTable(Table):
@@ -42,16 +48,32 @@ class SQLiteTable(Table):
              end_time_column=end_time_column,
          )

+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SQLITE)
+
      def _get_source_columns(self) -> List[SourceColumn]:
          source_columns: List[SourceColumn] = []
          with self._connection.cursor() as cursor:
-             cursor.execute(f"PRAGMA table_info({self.name})")
-             rows = cursor.fetchall()
+             sql = f"PRAGMA table_info({quote_ident(self.name)})"
+             cursor.execute(sql)
+             columns = cursor.fetchall()

-             if len(rows) == 0:
+             if len(columns) == 0:
                  raise ValueError(f"Table '{self.name}' does not exist")

-             for _, column, type, _, _, is_pkey in rows:
+             unique_keys: set[str] = set()
+             sql = f"PRAGMA index_list({quote_ident(self.name)})"
+             cursor.execute(sql)
+             for _, index_name, is_unique, *_ in cursor.fetchall():
+                 if bool(is_unique):
+                     sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                     cursor.execute(sql)
+                     index = cursor.fetchall()
+                     if len(index) == 1:
+                         unique_keys.add(index[0][2])
+
+             for _, column, type, notnull, _, is_pkey in columns:
                  # Determine column affinity:
                  type = type.strip().upper()
                  if re.search('INT', type):
@@ -76,7 +98,8 @@ class SQLiteTable(Table):
                  name=column,
                  dtype=dtype,
                  is_primary_key=bool(is_pkey),
-                 is_unique_key=False,
+                 is_unique_key=column in unique_keys,
+                 is_nullable=not bool(is_pkey) and not bool(notnull),
              )
              source_columns.append(source_column)

@@ -85,15 +108,17 @@ class SQLiteTable(Table):
      def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
          source_fkeys: List[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             cursor.execute(f"PRAGMA foreign_key_list({self.name})")
-             for _, _, dst_table, fkey, pkey, _, _, _ in cursor.fetchall():
+             sql = f"PRAGMA foreign_key_list({quote_ident(self.name)})"
+             cursor.execute(sql)
+             for _, _, dst_table, fkey, pkey, *_ in cursor.fetchall():
                  source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
          return source_fkeys

      def _get_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             cursor.execute(f"SELECT * FROM {self.name} "
-                            f"ORDER BY rowid LIMIT 1000")
+             sql = (f"SELECT * FROM {quote_ident(self.name)} "
+                    f"ORDER BY rowid LIMIT 1000")
+             cursor.execute(sql)
              table = cursor.fetch_arrow_table()
          return table.to_pandas(types_mapper=pd.ArrowDtype)

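The revised `_get_source_columns` derives `is_nullable` from the `notnull` and `pk` flags of `PRAGMA table_info`, and marks `is_unique_key` for columns backed by a single-column unique index reported by `PRAGMA index_list`/`PRAGMA index_info`. A minimal illustration of those PRAGMAs with the standard-library `sqlite3` module (the schema is made up for the example):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute(
        "CREATE TABLE users ("
        "  id INTEGER PRIMARY KEY,"
        "  email TEXT NOT NULL UNIQUE,"
        "  age INTEGER)"
    )

    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    for _, name, type_, notnull, _, is_pk in con.execute("PRAGMA table_info(users)"):
        nullable = not bool(is_pk) and not bool(notnull)
        print(name, type_, "nullable" if nullable else "required")

    # Single-column UNIQUE constraints surface as unique indices:
    for _, index_name, is_unique, *_ in con.execute("PRAGMA index_list(users)"):
        if is_unique:
            cols = [row[2] for row in con.execute(f"PRAGMA index_info({index_name})")]
            print("unique index on:", cols)  # e.g. ['email']
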
kumoai/experimental/rfm/base/__init__.py
@@ -1,14 +1,25 @@
- from .source import SourceColumn, SourceForeignKey
- from .column import Column
- from .table import Table
- from .sampler import BackwardSamplerOutput, ForwardSamplerOutput, Sampler
+ from kumoapi.common import StrEnum
+
+
+ class DataBackend(StrEnum):
+     LOCAL = 'local'
+     SQLITE = 'sqlite'
+     SNOWFLAKE = 'snowflake'
+
+
+ from .source import SourceColumn, SourceForeignKey  # noqa: E402
+ from .column import Column  # noqa: E402
+ from .table import Table  # noqa: E402
+ from .sampler import SamplerOutput, Sampler  # noqa: E402
+ from .sql_sampler import SQLSampler  # noqa: E402

  __all__ = [
+     'DataBackend',
      'SourceColumn',
      'SourceForeignKey',
      'Column',
      'Table',
-     'BackwardSamplerOutput',
-     'ForwardSamplerOutput',
+     'SamplerOutput',
      'Sampler',
+     'SQLSampler',
  ]
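
`DataBackend` is defined before the submodule imports (hence the `# noqa: E402` markers) so that backend modules such as `sqlite/table.py` can import it from the base package without a circular import. Since it is a `StrEnum`, members behave like their lowercase string values, which keeps backend dispatch simple. A small sketch (the `pick` helper is hypothetical, not part of the package):

    from kumoai.experimental.rfm.base import DataBackend

    # StrEnum members compare equal to their string values:
    assert DataBackend.SQLITE == 'sqlite'

    def pick(name: str) -> DataBackend:
        # Hypothetical helper: map a config string onto the enum by value.
        return DataBackend(name)

    assert pick('snowflake') is DataBackend.SNOWFLAKE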