kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +4 -5
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
  12. kumoai/experimental/rfm/backend/snow/table.py +166 -56
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
  15. kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
  16. kumoai/experimental/rfm/base/__init__.py +8 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +36 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +10 -5
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +606 -361
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +192 -13
  44. kumoai/utils/sql.py +2 -2
  45. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
  46. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
  47. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
  48. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
  49. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/local/table.py

@@ -1,15 +1,15 @@
-import warnings
-from typing import List, Optional, cast
+from typing import Sequence, cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 
 from kumoai.experimental.rfm.base import (
+    ColumnSpec,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
     Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 
 
 class LocalTable(Table):
@@ -57,9 +57,9 @@ class LocalTable(Table):
         self,
         df: pd.DataFrame,
         name: str,
-        primary_key: Optional[str] = None,
-        time_column: Optional[str] = None,
-        end_time_column: Optional[str] = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if df.empty:
@@ -75,7 +75,6 @@ class LocalTable(Table):
 
         super().__init__(
             name=name,
-            columns=list(df.columns),
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
@@ -85,35 +84,30 @@ class LocalTable(Table):
     def backend(self) -> DataBackend:
         return cast(DataBackend, DataBackend.LOCAL)
 
-    def _get_source_columns(self) -> List[SourceColumn]:
-        source_columns: List[SourceColumn] = []
-        for column in self._data.columns:
-            ser = self._data[column]
-            try:
-                dtype = infer_dtype(ser)
-            except Exception:
-                warnings.warn(f"Data type inference for column '{column}' in "
-                              f"table '{self.name}' failed. Consider changing "
-                              f"the data type of the column to use it within "
-                              f"this table.")
-                continue
-
-            source_column = SourceColumn(
-                name=column,
-                dtype=dtype,
+    def _get_source_columns(self) -> list[SourceColumn]:
+        return [
+            SourceColumn(
+                name=column_name,
+                dtype=None,
                 is_primary_key=False,
                 is_unique_key=False,
                 is_nullable=True,
-            )
-            source_columns.append(source_column)
+            ) for column_name in self._data.columns
+        ]
 
-        return source_columns
-
-    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
         return []
 
-    def _get_sample_df(self) -> pd.DataFrame:
+    def _get_source_sample_df(self) -> pd.DataFrame:
         return self._data
 
-    def _get_num_rows(self) -> Optional[int]:
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        raise RuntimeError(f"Column expressions are not supported in "
+                           f"'{self.__class__.__name__}'. Please apply your "
+                           f"expressions on the `pd.DataFrame` directly.")
+
+    def _get_num_rows(self) -> int | None:
         return len(self._data)
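The constructor change above swaps the `Optional[str] = None` default for a `MissingType` sentinel on `primary_key`, which presumably lets the base `Table` distinguish "argument not passed" from an explicit `None`. A minimal usage sketch under that reading; the import path and column names are assumptions, not taken from this diff:

    import pandas as pd

    from kumoai.experimental.rfm import LocalTable  # import path assumed

    df = pd.DataFrame({
        "user_id": [1, 2, 3],
        "signup_time": pd.to_datetime(["2024-01-01", "2024-02-01", "2024-03-01"]),
    })

    # Passing a column name still works as before; omitting `primary_key`
    # now leaves the `MissingType.VALUE` sentinel in place rather than `None`.
    table = LocalTable(df=df, name="users", primary_key="user_id",
                       time_column="signup_time")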
kumoai/experimental/rfm/backend/snow/sampler.py

@@ -1,35 +1,51 @@
-from typing import TYPE_CHECKING, Literal
+import json
+import math
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
 
-from kumoai.experimental.rfm.backend.snow import SnowTable
-from kumoai.experimental.rfm.base import Sampler, SamplerOutput
-from kumoai.utils import ProgressLogger, quote_ident
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
 
-class SnowSampler(Sampler):
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
+
+
+class SnowSampler(SQLSampler):
     def __init__(
         self,
         graph: 'Graph',
         verbose: bool | ProgressLogger = True,
     ) -> None:
-        super().__init__(graph=graph)
+        super().__init__(graph=graph, verbose=verbose)
 
-        self._fqn_dict: dict[str, str] = {}
         for table in graph.tables.values():
             assert isinstance(table, SnowTable)
             self._connection = table._connection
-            self._fqn_dict[table.name] = table.fqn
+
+        self._num_rows_dict: dict[str, int] = {
+            table.name: cast(int, table._num_rows)
+            for table in graph.tables.values()
+        }
 
     @property
-    def fqn_dict(self) -> dict[str, str]:
-        r"""The fully-qualified quoted names for all tables in the graph."""
-        return self._fqn_dict
+    def num_rows_dict(self) -> dict[str, int]:
+        return self._num_rows_dict
 
     def _get_min_max_time_dict(
         self,
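The new `paramstyle` helper temporarily switches the Snowflake connection to the `qmark` parameter style so the samplers below can bind values to `?` placeholders, restoring the previous style afterwards. A standalone sketch of the same attribute-swap pattern (the names here are illustrative, not part of the package); note that a `try`/`finally` around the `yield` also restores the attribute if the body raises:

    from collections.abc import Iterator
    from contextlib import contextmanager

    @contextmanager
    def swap_attr(obj: object, name: str, value: object) -> Iterator[None]:
        # Temporarily set `obj.<name>` to `value`; restore the old value on exit.
        previous = getattr(obj, name)
        setattr(obj, name, value)
        try:
            yield
        finally:
            setattr(obj, name, previous)

    class FakeConnection:
        _paramstyle = "pyformat"

    conn = FakeConnection()
    with swap_attr(conn, "_paramstyle", "qmark"):
        assert conn._paramstyle == "qmark"
    assert conn._paramstyle == "pyformat"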
@@ -37,37 +53,28 @@ class SnowSampler(Sampler):
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-            time_column = self.time_column_dict[table_name]
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
-                      f"  %s as table_name,\n"
-                      f"  MIN({quote_ident(time_column)}) as min_date,\n"
-                      f"  MAX({quote_ident(time_column)}) as max_date\n"
-                      f"FROM {self.fqn_dict[table_name]}")
+                      f"  ? as table_name,\n"
+                      f"  MIN({column_ref}) as min_date,\n"
+                      f"  MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, table_names)
             rows = cursor.fetchall()
-            for table_name, _min, _max in rows:
-                out_dict[table_name] = (
-                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
-                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
-                )
+        for table_name, _min, _max in rows:
+            out_dict[table_name] = (
+                pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+            )
 
         return out_dict
 
-    def _sample_subgraph(
-        self,
-        entity_table_name: str,
-        entity_pkey: pd.Series,
-        anchor_time: pd.Series | Literal['entity'],
-        columns_dict: dict[str, set[str]],
-        num_neighbors: list[int],
-    ) -> SamplerOutput:
-        raise NotImplementedError
-
     def _sample_entity_table(
         self,
         table_name: str,
@@ -78,17 +85,27 @@ class SnowSampler(Sampler):
         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f"  {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f"  {quote_ident(time_column)} IS NOT NULL")
 
-        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-               f"FROM {self.fqn_dict[table_name]}\n"
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f"  {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f"  {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
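For reference, `_sample_entity_table` now assembles its statement from per-table projection and reference strings (`table_column_proj_dict` / `table_column_ref_dict`) rather than calling `quote_ident` directly. A small sketch of the resulting query shape, with a hypothetical `quote` helper standing in for those lookups:

    def quote(name: str) -> str:
        # Hypothetical stand-in for the projection/reference lookups above.
        return '"' + name.replace('"', '""') + '"'

    def entity_sample_sql(source: str, columns: list[str], not_null: list[str],
                          num_rows: int) -> str:
        filters = [f"  {quote(col)} IS NOT NULL" for col in not_null]
        sql = (f"SELECT {', '.join(quote(col) for col in columns)}\n"
               f"FROM {source}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
        if filters:
            sql += f"\nWHERE{' AND'.join(filters)}"
        return sql

    print(entity_sample_sql('DB.SCHEMA."USERS"', ["USER_ID", "SIGNUP_TIME"],
                            not_null=["USER_ID"], num_rows=10_000))
    # SELECT "USER_ID", "SIGNUP_TIME"
    # FROM DB.SCHEMA."USERS"
    # SAMPLE ROW (10000 ROWS)
    # WHERE  "USER_ID" IS NOT NULL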
@@ -98,7 +115,11 @@ class SnowSampler(Sampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()
 
-        return table.to_pandas(types_mapper=pd.ArrowDtype)
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
@@ -116,4 +137,271 @@ class SnowSampler(Sampler):
             tuple[pd.DateOffset | None, pd.DateOffset],
         ],
     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
-        raise NotImplementedError
+
+        # NOTE For Snowflake, we execute everything at once to pay minimal
+        # query initialization costs.
+        index = np.concatenate([train_index, test_index])
+        time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+        entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+        feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+        time_dict: dict[str, pd.Series] = {}
+        time_column = self.time_column_dict.get(query.entity_table)
+        if time_column in columns_dict[query.entity_table]:
+            time_dict[query.entity_table] = entity_df[time_column]
+        batch_dict: dict[str, np.ndarray] = {
+            query.entity_table: np.arange(len(entity_df)),
+        }
+        for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+            table_name, foreign_key, _ = edge_type
+            feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                table_name=table_name,
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
+                anchor_time=time,
+                min_offset=min_offset,
+                max_offset=max_offset,
+                columns=columns_dict[table_name],
+            )
+            time_column = self.time_column_dict.get(table_name)
+            if time_column in columns_dict[table_name]:
+                time_dict[table_name] = feat_dict[table_name][time_column]
+
+        y, mask = PQueryPandasExecutor().execute(
+            query=query,
+            feat_dict=feat_dict,
+            time_dict=time_dict,
+            batch_dict=batch_dict,
+            anchor_time=time,
+            num_forecasts=query.num_forecasts,
+        )
+
+        train_mask = mask[:len(train_index)]
+        test_mask = mask[len(train_index):]
+
+        boundary = int(train_mask.sum())
+        train_y = y.iloc[:boundary]
+        test_y = y.iloc[boundary:].reset_index(drop=True)
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        payload = json.dumps(list(index))
+
+        sql = ("WITH TMP as (\n"
+               "  SELECT\n"
+               "    f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += "    f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += "    f.value::FLOAT as __KUMO_ID__\n"
+        else:
+            sql += "    f.value::VARCHAR as __KUMO_ID__\n"
+        sql += (f"  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f"  ON {key_ref} = TMP.__KUMO_ID__")
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        # Remove any duplicated primary keys in post-processing:
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        end_time: pd.Series | None = None
+        start_time: pd.Series | None = None
+        if time_column is not None and anchor_time is not None:
+            # In order to avoid a full table scan, we limit foreign key
+            # sampling to a certain time range, approximated by the number of
+            # rows, timestamp ranges and `num_neighbors` value.
+            # Downstream, this helps Snowflake to apply partition pruning:
+            dst_table_name = [
+                dst_table
+                for key, dst_table in self.foreign_key_dict[table_name]
+                if key == foreign_key
+            ][0]
+            num_facts = self.num_rows_dict[table_name]
+            num_entities = self.num_rows_dict[dst_table_name]
+            min_time = self.get_min_time([table_name])
+            max_time = self.get_max_time([table_name])
+            freq = num_facts / num_entities
+            freq = freq / max((max_time - min_time).total_seconds(), 1)
+            offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
+
+            end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            start_time = anchor_time - offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               "  SELECT\n"
+               "    f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += "    f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += "    f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += "    f.value[0]::VARCHAR as __KUMO_ID__"
+        if end_time is not None and start_time is not None:
+            sql += (",\n"
+                    "    f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
+                    "    f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
+        sql += (f"\n"
+                f"  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f"  ON {key_ref} = TMP.__KUMO_ID__\n")
+        if end_time is not None and start_time is not None:
+            assert time_column is not None
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += (f"  AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
+                    f"  AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+                    f"WHERE {time_ref} <= '{end_time.max()}'\n"
+                    f"  AND {time_ref} > '{start_time.min()}'\n")
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                "  PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f"  ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f"  ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = ("WITH TMP as (\n"
+               "  SELECT\n"
+               "    f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += "    f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += "    f.value[0]::FLOAT as __KUMO_ID__,\n"
+        else:
+            sql += "    f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += "    f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
+        if min_offset is not None:
+            sql += ",\n    f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
+        sql += (f"\n"
+                f"  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f"  ON {key_ref} = TMP.__KUMO_ID__\n"
+                f"  AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
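All three helpers (`_by_pkey`, `_by_fkey`, `_by_time`) follow the same pattern: the lookup keys (plus optional per-row time bounds) are serialized into a single JSON array, bound to one `?` placeholder, expanded server-side with `TABLE(FLATTEN(INPUT => PARSE_JSON(?)))`, and joined back against the source table. The array position `f.index` travels along as `__KUMO_BATCH__`, so every returned neighbor can be mapped back to its input row, and `QUALIFY ROW_NUMBER() ... <= num_neighbors` caps the neighbors per input row in `_by_fkey`. The look-back window used there works out to roughly five average inter-event gaps per requested neighbor: an entity averaging one fact per day with `num_neighbors=16` gets a window of about 80 days. A toy sketch of the payload and the generated query shape (the identifiers `DB.SCHEMA."TRANSACTIONS"`, `"USER_ID"`, `"EVENT_TIME"`, `"AMOUNT"` are illustrative, not from this diff; the lower time bound is dropped for brevity):

    import json

    # One JSON payload for all lookups: [key value, upper time bound] per row.
    rows = [(17, "2024-03-01 00:00:00"), (42, "2024-03-02 00:00:00")]
    payload = json.dumps(rows)

    # Shape of the generated statement (cf. `_by_fkey` above):
    sql = """\
    WITH TMP as (
      SELECT
        f.index as __KUMO_BATCH__,
        f.value[0]::NUMBER as __KUMO_ID__,
        f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__
      FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f
    )
    SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "EVENT_TIME", "AMOUNT"
    FROM TMP
    JOIN DB.SCHEMA."TRANSACTIONS"
      ON "USER_ID" = TMP.__KUMO_ID__
      AND "EVENT_TIME" <= TMP.__KUMO_END_TIME__
    QUALIFY ROW_NUMBER() OVER (
      PARTITION BY TMP.__KUMO_BATCH__
      ORDER BY "EVENT_TIME" DESC
    ) <= 16
    """

    # Executed as `cursor.execute(sql, (payload,))` under the `qmark` paramstyle;
    # `__KUMO_BATCH__` maps every returned row back to its position in `rows`.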