kumoai 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (30)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
  4. kumoai/experimental/rfm/backend/local/sampler.py +229 -45
  5. kumoai/experimental/rfm/backend/local/table.py +12 -2
  6. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  7. kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
  8. kumoai/experimental/rfm/backend/snow/table.py +35 -17
  9. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
  10. kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
  11. kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
  12. kumoai/experimental/rfm/base/__init__.py +16 -5
  13. kumoai/experimental/rfm/base/sampler.py +538 -52
  14. kumoai/experimental/rfm/base/source.py +1 -0
  15. kumoai/experimental/rfm/base/sql_sampler.py +56 -0
  16. kumoai/experimental/rfm/base/table.py +12 -1
  17. kumoai/experimental/rfm/graph.py +26 -9
  18. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  19. kumoai/experimental/rfm/rfm.py +214 -151
  20. kumoai/pquery/predictive_query.py +10 -6
  21. kumoai/testing/snow.py +50 -0
  22. kumoai/utils/__init__.py +2 -0
  23. kumoai/utils/sql.py +3 -0
  24. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
  25. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
  26. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  27. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  28. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
  29. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
  30. {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/sampler.py
@@ -0,0 +1,264 @@
+ import json
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.snow import SnowTable
+ from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger, quote_ident
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class SnowSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         self._fqn_dict: dict[str, str] = {}
+         for table in graph.tables.values():
+             assert isinstance(table, SnowTable)
+             self._connection = table._connection
+             self._fqn_dict[table.name] = table.fqn
+
+     @property
+     def fqn_dict(self) -> dict[str, str]:
+         r"""The fully-qualified quoted names for all tables in the graph."""
+         return self._fqn_dict
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         selects: list[str] = []
+         for table_name in table_names:
+             time_column = self.time_column_dict[table_name]
+             select = (f"SELECT\n"
+                       f" %s as table_name,\n"
+                       f" MIN({quote_ident(time_column)}) as min_date,\n"
+                       f" MAX({quote_ident(time_column)}) as max_date\n"
+                       f"FROM {self.fqn_dict[table_name]}")
+             selects.append(select)
+         sql = "\nUNION ALL\n".join(selects)
+
+         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         with self._connection.cursor() as cursor:
+             cursor.execute(sql, table_names)
+             rows = cursor.fetchall()
+         for table_name, _min, _max in rows:
+             out_dict[table_name] = (
+                 pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                 pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+             )
+
+         return out_dict
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         # NOTE Snowflake supports `SEED` only as part of `SYSTEM` sampling.
+         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+         filters: list[str] = []
+         primary_key = self.primary_key_dict[table_name]
+         if self.source_table_dict[table_name][primary_key].is_nullable:
+             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+         time_column = self.time_column_dict.get(table_name)
+         if (time_column is not None and
+                 self.source_table_dict[table_name][time_column].is_nullable):
+             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+                f"FROM {self.fqn_dict[table_name]}\n"
+                f"SAMPLE ROW ({num_rows} ROWS)")
+         if len(filters) > 0:
+             sql += f"\nWHERE{' AND'.join(filters)}"
+
+         with self._connection.cursor() as cursor:
+             # NOTE This may return duplicate primary keys. This is okay.
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_all()
+
+         return self._sanitize(table_name, table)
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         # NOTE For Snowflake, we execute everything at once to pay minimal
+         # query initialization costs.
+         index = np.concatenate([train_index, test_index])
+         time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+         entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+         feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+         time_dict: dict[str, pd.Series] = {}
+         time_column = self.time_column_dict.get(query.entity_table)
+         if time_column in columns_dict[query.entity_table]:
+             time_dict[query.entity_table] = entity_df[time_column]
+         batch_dict: dict[str, np.ndarray] = {
+             query.entity_table: np.arange(len(entity_df)),
+         }
+         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+             table_name, fkey, _ = edge_type
+             feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                 table_name=table_name,
+                 fkey=fkey,
+                 pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                 anchor_time=time,
+                 min_offset=min_offset,
+                 max_offset=max_offset,
+                 columns=columns_dict[table_name],
+             )
+             time_column = self.time_column_dict.get(table_name)
+             if time_column in columns_dict[table_name]:
+                 time_dict[table_name] = feat_dict[table_name][time_column]
+
+         y, mask = PQueryPandasExecutor().execute(
+             query=query,
+             feat_dict=feat_dict,
+             time_dict=time_dict,
+             batch_dict=batch_dict,
+             anchor_time=time,
+             num_forecasts=query.num_forecasts,
+         )
+
+         train_mask = mask[:len(train_index)]
+         test_mask = mask[len(train_index):]
+
+         boundary = int(train_mask.sum())
+         train_y = y.iloc[:boundary]
+         test_y = y.iloc[boundary:].reset_index(drop=True)
+
+         return train_y, train_mask, test_y, test_mask
+
+     def _by_pkey(
+         self,
+         table_name: str,
+         pkey: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+
+         pkey_name = self.primary_key_dict[table_name]
+         source_table = self.source_table_dict[table_name]
+
+         payload = json.dumps(list(pkey))
+
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as BATCH,\n")
+         if source_table[pkey_name].dtype.is_int():
+             sql += " f.value::NUMBER as ID\n"
+         elif source_table[pkey_name].dtype.is_float():
+             sql += " f.value::FLOAT as ID\n"
+         else:
+             sql += " f.value::VARCHAR as ID\n"
+         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
+                 f")\n"
+                 f"SELECT TMP.BATCH as __BATCH__, "
+                 f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.fqn_dict[table_name]} ENT\n"
+                 f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+
+         with self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         # Remove any duplicated primary keys in post-processing:
+         tmp = table.append_column('__TMP__', pa.array(range(len(table))))
+         gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
+         table = table.take(gb['__TMP___min'])
+
+         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+
+         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         fkey: str,
+         pkey: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             payload = json.dumps(list(zip(pkey, end_time, start_time)))
+         else:
+             payload = json.dumps(list(zip(pkey, end_time)))
+
+         # Based on benchmarking, a JSON payload is the fastest way to query by
+         # custom indices (compared to large `IN` clauses or temporary tables):
+         source_table = self.source_table_dict[table_name]
+         time_column = self.time_column_dict[table_name]
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as BATCH,\n")
+         if source_table[fkey].dtype.is_int():
+             sql += " f.value[0]::NUMBER as ID,\n"
+         elif source_table[fkey].dtype.is_float():
+             sql += " f.value[0]::FLOAT as ID,\n"
+         else:
+             sql += " f.value[0]::VARCHAR as ID,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+         if min_offset is not None:
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+         sql += (f"\n"
+                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
+                 f")\n"
+                 f"SELECT TMP.BATCH as __BATCH__, "
+                 f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.fqn_dict[table_name]} FACT\n"
+                 f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
+                 f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
+         if min_offset is not None:
+             sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
+
+         with self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+
+         return self._sanitize(table_name, table), batch
+
+     def _sanitize(self, table_name: str, table: pa.Table) -> pd.DataFrame:
+         return table.to_pandas(types_mapper=pd.ArrowDtype)
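The `_by_pkey` and `_by_time` helpers above ship all lookup keys to Snowflake in one JSON payload and join the target table against `TABLE(FLATTEN(INPUT => PARSE_JSON(%s)))`, which the in-code note attributes to benchmarking against large `IN` clauses and temporary tables. A minimal sketch of that pattern in isolation, assuming an already-open Snowflake connection `conn` and a hypothetical `EVENTS` table with `USER_ID` and `EVENT_TIME` columns (neither appears in the package):

import json

# Each payload element is (key, cutoff); __BATCH__ maps result rows back to
# the element's position in the payload.
pairs = [("u1", "2024-01-31 00:00:00"), ("u2", "2024-02-15 00:00:00")]
payload = json.dumps(pairs)

sql = (
    "WITH TMP as (\n"
    "  SELECT\n"
    "    f.index as BATCH,\n"
    "    f.value[0]::VARCHAR as ID,\n"
    "    f.value[1]::TIMESTAMP_NTZ as END_TIME\n"
    "  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
    ")\n"
    'SELECT TMP.BATCH as __BATCH__, FACT."USER_ID", FACT."EVENT_TIME"\n'
    "FROM TMP\n"
    'JOIN "EVENTS" FACT\n'
    '  ON FACT."USER_ID" = TMP.ID\n'
    '  AND FACT."EVENT_TIME" <= TMP.END_TIME')

with conn.cursor() as cursor:
    cursor.execute(sql, (payload, ))   # one round trip for all keys
    table = cursor.fetch_arrow_all()   # Arrow table incl. the __BATCH__ column

The sampler then drops `__BATCH__` from the result and uses it as the per-row batch assignment.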
kumoai/experimental/rfm/backend/snow/table.py
@@ -1,11 +1,17 @@
  import re
- from typing import List, Optional, Sequence
+ from typing import List, Optional, Sequence, cast
 
  import pandas as pd
  from kumoapi.typing import Dtype
 
  from kumoai.experimental.rfm.backend.snow import Connection
- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+ from kumoai.experimental.rfm.base import (
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     Table,
+ )
+ from kumoai.utils import quote_ident
 
 
  class SnowTable(Table):
@@ -51,27 +57,36 @@ class SnowTable(Table):
          )
 
      @property
-     def fqn_name(self) -> str:
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SNOWFLAKE)
+
+     @property
+     def fqn(self) -> str:
+         r"""The fully-qualified quoted table name."""
          names: List[str] = []
          if self._database is not None:
-             assert self._schema is not None
-             names.extend([self._database, self._schema])
-         elif self._schema is not None:
-             names.append(self._schema)
-         names.append(self._name)
-         return '.'.join(names)
+             names.append(quote_ident(self._database))
+         if self._schema is not None:
+             names.append(quote_ident(self._schema))
+         return '.'.join(names + [quote_ident(self._name)])
 
      def _get_source_columns(self) -> List[SourceColumn]:
          source_columns: List[SourceColumn] = []
          with self._connection.cursor() as cursor:
              try:
-                 cursor.execute(f"DESCRIBE TABLE {self.fqn_name}")
+                 sql = f"DESCRIBE TABLE {self.fqn}"
+                 cursor.execute(sql)
              except Exception as e:
-                 raise ValueError(
-                     f"Table '{self.fqn_name}' does not exist") from e
+                 names: list[str] = []
+                 if self._database is not None:
+                     names.append(self._database)
+                 if self._schema is not None:
+                     names.append(self._schema)
+                 name = '.'.join(names + [self._name])
+                 raise ValueError(f"Table '{name}' does not exist") from e
 
              for row in cursor.fetchall():
-                 column, type, _, _, _, is_pkey, is_unique = row[:7]
+                 column, type, _, null, _, is_pkey, is_unique, *_ = row
 
                  type = type.strip().upper()
                  if type.startswith('NUMBER'):
@@ -92,6 +107,7 @@ class SnowTable(Table):
                      dtype=dtype,
                      is_primary_key=is_pkey.strip().upper() == 'Y',
                      is_unique_key=is_unique.strip().upper() == 'Y',
+                     is_nullable=null.strip().upper() == 'Y',
                  )
                  source_columns.append(source_column)
 
@@ -100,16 +116,18 @@ class SnowTable(Table):
      def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
          source_fkeys: List[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.fqn_name}")
+             sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+             cursor.execute(sql)
              for row in cursor.fetchall():
-                 _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
+                 _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
                  source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
          return source_fkeys
 
      def _get_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             columns = ', '.join(self._source_column_dict.keys())
-             cursor.execute(f"SELECT {columns} FROM {self.fqn_name} LIMIT 1000")
+             columns = [quote_ident(col) for col in self._source_column_dict]
+             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             cursor.execute(sql)
              table = cursor.fetch_arrow_all()
              return table.to_pandas(types_mapper=pd.ArrowDtype)
 
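The renamed `fqn` property now quotes every component of the fully-qualified name via the new `kumoai.utils.quote_ident` helper. Its exact implementation is not shown in this diff; a sketch assuming standard SQL double-quote escaping illustrates the resulting names:

# Hypothetical stand-in for kumoai.utils.quote_ident (the packaged helper may
# differ); standard identifier quoting doubles any embedded quotes:
def quote_ident(name: str) -> str:
    return '"' + name.replace('"', '""') + '"'

# With database="ANALYTICS", schema="PUBLIC", and table name="Order Items",
# SnowTable.fqn would render as:
print('.'.join(quote_ident(p) for p in ["ANALYTICS", "PUBLIC", "Order Items"]))
# -> "ANALYTICS"."PUBLIC"."Order Items"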
kumoai/experimental/rfm/backend/sqlite/__init__.py
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
 
 
  from .table import SQLiteTable  # noqa: E402
+ from .sampler import SQLiteSampler  # noqa: E402
 
  __all__ = [
      'connect',
      'Connection',
      'SQLiteTable',
+     'SQLiteSampler',
  ]
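With the added re-export, the SQLite sampler is importable directly from the backend package, mirroring the Snowflake backend (only the import itself is implied by this diff):

from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler, SQLiteTable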