kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
  12. kumoai/experimental/rfm/backend/snow/table.py +159 -52
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
  15. kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
  16. kumoai/experimental/rfm/base/__init__.py +6 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +342 -13
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +27 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +600 -360
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
  45. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
  46. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  47. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  48. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/local/table.py
@@ -1,15 +1,15 @@
-import warnings
-from typing import List, Optional, cast
+from typing import Sequence, cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 
 from kumoai.experimental.rfm.base import (
+    ColumnSpec,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
     Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 
 
 class LocalTable(Table):
@@ -57,9 +57,9 @@ class LocalTable(Table):
         self,
         df: pd.DataFrame,
         name: str,
-        primary_key: Optional[str] = None,
-        time_column: Optional[str] = None,
-        end_time_column: Optional[str] = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if df.empty:
@@ -75,7 +75,6 @@ class LocalTable(Table):
 
         super().__init__(
             name=name,
-            columns=list(df.columns),
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
@@ -85,35 +84,30 @@
     def backend(self) -> DataBackend:
         return cast(DataBackend, DataBackend.LOCAL)
 
-    def _get_source_columns(self) -> List[SourceColumn]:
-        source_columns: List[SourceColumn] = []
-        for column in self._data.columns:
-            ser = self._data[column]
-            try:
-                dtype = infer_dtype(ser)
-            except Exception:
-                warnings.warn(f"Data type inference for column '{column}' in "
-                              f"table '{self.name}' failed. Consider changing "
-                              f"the data type of the column to use it within "
-                              f"this table.")
-                continue
-
-            source_column = SourceColumn(
-                name=column,
-                dtype=dtype,
+    def _get_source_columns(self) -> list[SourceColumn]:
+        return [
+            SourceColumn(
+                name=column_name,
+                dtype=None,
                 is_primary_key=False,
                 is_unique_key=False,
                 is_nullable=True,
-            )
-            source_columns.append(source_column)
+            ) for column_name in self._data.columns
+        ]
 
-        return source_columns
-
-    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
         return []
 
-    def _get_sample_df(self) -> pd.DataFrame:
+    def _get_source_sample_df(self) -> pd.DataFrame:
         return self._data
 
-    def _get_num_rows(self) -> Optional[int]:
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        raise RuntimeError(f"Column expressions are not supported in "
+                           f"'{self.__class__.__name__}'. Please apply your "
+                           f"expressions on the `pd.DataFrame` directly.")
+
+    def _get_num_rows(self) -> int | None:
         return len(self._data)
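
The constructor change above swaps the Optional[str] defaults for a MissingType sentinel on primary_key. A minimal usage sketch follows; the public import path and the "leave the sentinel in place to trigger key inference" semantics are assumptions, not something this diff confirms.

import pandas as pd

# Hypothetical usage of the new LocalTable signature. Passing an explicit
# column name (or None) overrides the MissingType default; the constructor
# calls are commented out because the exact import path is assumed.
df = pd.DataFrame({
    "user_id": [1, 2, 3],
    "signup_date": pd.to_datetime(["2024-01-01", "2024-02-15", "2024-03-20"]),
})
# from kumoai.experimental.rfm import LocalTable
# table = LocalTable(df=df, name="users")  # keys left to inference
# table = LocalTable(df=df, name="users", primary_key="user_id",
#                    time_column="signup_date")  # keys set explicitly
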
kumoai/experimental/rfm/backend/snow/sampler.py
@@ -1,4 +1,6 @@
 import json
+from collections.abc import Iterator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -6,15 +8,23 @@ import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
 
-from kumoai.experimental.rfm.backend.snow import SnowTable
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import ProgressLogger, quote_ident
+from kumoai.utils import ProgressLogger
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
 
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
+
+
 class SnowSampler(SQLSampler):
     def __init__(
         self,
@@ -23,16 +33,9 @@ class SnowSampler(SQLSampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)
 
-        self._fqn_dict: dict[str, str] = {}
         for table in graph.tables.values():
             assert isinstance(table, SnowTable)
             self._connection = table._connection
-            self._fqn_dict[table.name] = table.fqn
-
-    @property
-    def fqn_dict(self) -> dict[str, str]:
-        r"""The fully-qualified quoted names for all tables in the graph."""
-        return self._fqn_dict
 
     def _get_min_max_time_dict(
         self,
@@ -40,24 +43,25 @@
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-            time_column = self.time_column_dict[table_name]
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
-                      f" %s as table_name,\n"
-                      f" MIN({quote_ident(time_column)}) as min_date,\n"
-                      f" MAX({quote_ident(time_column)}) as max_date\n"
-                      f"FROM {self.fqn_dict[table_name]}")
+                      f" ? as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, table_names)
             rows = cursor.fetchall()
-        for table_name, _min, _max in rows:
-            out_dict[table_name] = (
-                pd.Timestamp.max if _min is None else pd.Timestamp(_min),
-                pd.Timestamp.min if _max is None else pd.Timestamp(_max),
-            )
+            for table_name, _min, _max in rows:
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
 
         return out_dict
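
The ? placeholders in the query above rely on qmark-style parameter binding, which the new paramstyle() context manager switches on for the duration of the call. A rough sketch of the pattern, with a stand-in connection object and made-up identifiers:

# Sketch only: execute a qmark-style query through the paramstyle() helper
# defined earlier in this diff. `connection` stands in for the Snowflake
# DB-API connection the sampler holds; the SQL identifiers are illustrative.
sql = ('SELECT ? as table_name, MIN("CREATED_AT") as min_date, '
       'MAX("CREATED_AT") as max_date FROM MY_DB.MY_SCHEMA.USERS')
with paramstyle(connection), connection.cursor() as cursor:
    cursor.execute(sql, ("users", ))
    rows = cursor.fetchall()
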
@@ -71,17 +75,27 @@
         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
 
-        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-               f"FROM {self.fqn_dict[table_name]}\n"
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
@@ -91,7 +105,11 @@
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()
 
-        return self._sanitize(table_name, table)
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
@@ -126,11 +144,11 @@
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name, fkey, _ = edge_type
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-                fkey=fkey,
-                pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -161,104 +179,193 @@
     def _by_pkey(
         self,
         table_name: str,
-        pkey: pd.Series,
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
 
-        pkey_name = self.primary_key_dict[table_name]
-        source_table = self.source_table_dict[table_name]
-
-        payload = json.dumps(list(pkey))
+        payload = json.dumps(list(index))
 
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as BATCH,\n")
-        if source_table[pkey_name].dtype.is_int():
-            sql += " f.value::NUMBER as ID\n"
-        elif source_table[pkey_name].dtype.is_float():
-            sql += " f.value::FLOAT as ID\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
         else:
-            sql += " f.value::VARCHAR as ID\n"
-        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
+        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                f")\n"
-               f"SELECT TMP.BATCH as __BATCH__, "
-               f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+               f"SELECT "
+               f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+               f"{', '.join(projections)}\n"
               f"FROM TMP\n"
-               f"JOIN {self.fqn_dict[table_name]} ENT\n"
-               f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+               f"JOIN {self.source_name_dict[table_name]}\n"
+               f" ON {key_ref} = TMP.__KUMO_ID__")
 
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
         # Remove any duplicated primary keys in post-processing:
-        tmp = table.append_column('__TMP__', pa.array(range(len(table))))
-        gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
-        table = table.take(gb['__TMP___min'])
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
 
-        batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-        table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
 
-        return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, anchor_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if time_column is not None and anchor_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-        fkey: str,
-        pkey: pd.Series,
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
 
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(pkey, end_time, start_time)))
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(pkey, end_time)))
-
-        # Based on benchmarking, JSON payload is the fastest way to query by
-        # custom indices (compared to large `IN` clauses or temporary tables):
-        source_table = self.source_table_dict[table_name]
-        time_column = self.time_column_dict[table_name]
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as BATCH,\n")
-        if source_table[fkey].dtype.is_int():
-            sql += " f.value[0]::NUMBER as ID,\n"
-        elif source_table[fkey].dtype.is_float():
-            sql += " f.value[0]::FLOAT as ID,\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
-            sql += " f.value[0]::VARCHAR as ID,\n"
-        sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
         if min_offset is not None:
-            sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
         sql += (f"\n"
-                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                f")\n"
-                f"SELECT TMP.BATCH as __BATCH__, "
-                f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
               f"FROM TMP\n"
-                f"JOIN {self.fqn_dict[table_name]} FACT\n"
-                f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
-                f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
-        if min_offset is not None:
-            sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
-
-        with self._connection.cursor() as cursor:
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
-        batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-        table = table.remove_column(table.schema.get_field_index('__BATCH__'))
-
-        return self._sanitize(table_name, table), batch
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
 
-    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-        return table.to_pandas(types_mapper=pd.ArrowDtype)
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
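
For reference, the new _by_fkey helper assembles a per-entity neighbor lookup of roughly the shape sketched below for an integer foreign key with an anchor time: the JSON payload bound to the ? placeholder carries (key, anchor_time) pairs, FLATTEN expands it into rows, and QUALIFY ROW_NUMBER() caps the neighbors kept per entity. Table and column names here are hypothetical; the real query substitutes quoted references and projections taken from table_column_ref_dict and table_column_proj_dict.

# Illustrative only: an approximate rendering of the SQL built by `_by_fkey`
# for num_neighbors=16, with made-up identifiers.
example_sql = """
WITH TMP as (
 SELECT
 f.index as __KUMO_BATCH__,
 f.value[0]::NUMBER as __KUMO_ID__,
 f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__
 FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f
)
SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "ORDERS"."PRICE"
FROM TMP
JOIN MY_DB.MY_SCHEMA."ORDERS"
 ON "ORDERS"."USER_ID" = TMP.__KUMO_ID__
 AND "ORDERS"."EVENT_TIME" <= TMP.__KUMO_TIME__
QUALIFY ROW_NUMBER() OVER (
 PARTITION BY TMP.__KUMO_BATCH__
 ORDER BY "ORDERS"."EVENT_TIME" DESC
) <= 16
"""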