kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +44 -9
  7. kumoai/experimental/rfm/__init__.py +70 -68
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/mapper.py +67 -0
  24. kumoai/experimental/rfm/base/sampler.py +782 -0
  25. kumoai/experimental/rfm/base/source.py +19 -0
  26. kumoai/experimental/rfm/base/sql_sampler.py +366 -0
  27. kumoai/experimental/rfm/base/table.py +741 -0
  28. kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
  29. kumoai/experimental/rfm/infer/__init__.py +8 -0
  30. kumoai/experimental/rfm/infer/dtype.py +82 -0
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +128 -0
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +61 -0
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +775 -481
  39. kumoai/experimental/rfm/sagemaker.py +15 -7
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/pquery/predictive_query.py +10 -6
  42. kumoai/pquery/training_table.py +16 -2
  43. kumoai/testing/decorators.py +1 -1
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +190 -12
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +10 -8
  51. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +54 -30
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. kumoai/experimental/rfm/local_table.py +0 -545
  55. kumoai/experimental/rfm/utils.py +0 -344
  56. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
  57. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
  58. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0

kumoai/experimental/rfm/backend/local/table.py
@@ -0,0 +1,113 @@
+ from typing import Sequence, cast
+
+ import pandas as pd
+ from kumoapi.model_plan import MissingType
+
+ from kumoai.experimental.rfm.base import (
+     ColumnSpec,
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     Table,
+ )
+
+
+ class LocalTable(Table):
+     r"""A table backed by a :class:`pandas.DataFrame`.
+
+     A :class:`LocalTable` fully specifies the relevant metadata, *i.e.*
+     selected columns, column semantic types, primary keys and time columns.
+     :class:`LocalTable` is used to create a :class:`Graph`.
+
+     .. code-block:: python
+
+         import pandas as pd
+         import kumoai.experimental.rfm as rfm
+
+         # Load data from a CSV file:
+         df = pd.read_csv("data.csv")
+
+         # Create a table from a `pandas.DataFrame` and infer its metadata ...
+         table = rfm.LocalTable(df, name="my_table").infer_metadata()
+
+         # ... or create a table explicitly:
+         table = rfm.LocalTable(
+             df=df,
+             name="my_table",
+             primary_key="id",
+             time_column="time",
+             end_time_column=None,
+         )
+
+         # Verify metadata:
+         table.print_metadata()
+
+         # Change the semantic type of a column:
+         table["my_column"].stype = "text"
+
+     Args:
+         df: The data frame to create this table from.
+         name: The name of this table.
+         primary_key: The name of the primary key of this table, if it exists.
+         time_column: The name of the time column of this table, if it exists.
+         end_time_column: The name of the end time column of this table, if it
+             exists.
+     """
+     def __init__(
+         self,
+         df: pd.DataFrame,
+         name: str,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
+     ) -> None:
+
+         if df.empty:
+             raise ValueError("Data frame is empty")
+         if isinstance(df.columns, pd.MultiIndex):
+             raise ValueError("Data frame must not have a multi-index")
+         if not df.columns.is_unique:
+             raise ValueError("Data frame must have unique column names")
+         if any(col == '' for col in df.columns):
+             raise ValueError("Data frame must have non-empty column names")
+
+         self._data = df.copy(deep=False)
+
+         super().__init__(
+             name=name,
+             primary_key=primary_key,
+             time_column=time_column,
+             end_time_column=end_time_column,
+         )
+
+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.LOCAL)
+
+     def _get_source_columns(self) -> list[SourceColumn]:
+         return [
+             SourceColumn(
+                 name=column_name,
+                 dtype=None,
+                 is_primary_key=False,
+                 is_unique_key=False,
+                 is_nullable=True,
+             ) for column_name in self._data.columns
+         ]
+
+     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+         return []
+
+     def _get_source_sample_df(self) -> pd.DataFrame:
+         return self._data
+
+     def _get_expr_sample_df(
+         self,
+         columns: Sequence[ColumnSpec],
+     ) -> pd.DataFrame:
+         raise RuntimeError(f"Column expressions are not supported in "
+                            f"'{self.__class__.__name__}'. Please apply your "
+                            f"expressions to the `pd.DataFrame` directly.")
+
+     def _get_num_rows(self) -> int | None:
+         return len(self._data)
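
As a quick sanity check of the constructor guards above, a minimal sketch (using `rfm.LocalTable` as in the docstring; the data is made up):

    import pandas as pd
    import kumoai.experimental.rfm as rfm

    df = pd.DataFrame({
        "id": [1, 2, 3],
        "time": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
    })
    table = rfm.LocalTable(df=df, name="events", primary_key="id",
                           time_column="time")

    # Empty frames, multi-indexed columns, duplicate or empty column names
    # are all rejected at construction time:
    rfm.LocalTable(pd.DataFrame(), name="bad")  # ValueError: Data frame is empty

Note that `df.copy(deep=False)` stores a shallow copy: the table shares memory with the original frame but is insulated from later column additions or deletions on it.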

kumoai/experimental/rfm/backend/snow/__init__.py
@@ -0,0 +1,37 @@
+ from typing import Any, TypeAlias
+
+ try:
+     import snowflake.connector
+ except ImportError:
+     raise ImportError("No module named 'snowflake'. Please install Kumo SDK "
+                       "with the 'snowflake' extension via "
+                       "`pip install kumoai[snowflake]`.")
+
+ Connection: TypeAlias = snowflake.connector.SnowflakeConnection
+
+
+ def connect(**kwargs: Any) -> Connection:
+     r"""Opens a connection to a :class:`snowflake` database.
+
+     If available, returns a connection to the active session.
+
+     kwargs: Connection arguments, following the :class:`snowflake` protocol.
+     """
+     try:
+         from snowflake.snowpark.context import get_active_session
+         return get_active_session().connection
+     except Exception:
+         pass
+
+     return snowflake.connector.connect(**kwargs)
+
+
+ from .table import SnowTable  # noqa: E402
+ from .sampler import SnowSampler  # noqa: E402
+
+ __all__ = [
+     'connect',
+     'Connection',
+     'SnowTable',
+     'SnowSampler',
+ ]
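
A sketch of both connection paths (`connect` forwards kwargs to `snowflake.connector.connect`; the credentials below are placeholders):

    from kumoai.experimental.rfm.backend import snow

    # Inside a Snowpark session (e.g., a Snowflake notebook or stored
    # procedure), the active session's connection is reused and any kwargs
    # are ignored:
    conn = snow.connect()

    # Elsewhere, standard snowflake-connector-python arguments apply:
    conn = snow.connect(
        account="my_account",
        user="my_user",
        password="...",
    )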

kumoai/experimental/rfm/backend/snow/sampler.py
@@ -0,0 +1,366 @@
+ import json
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+ from kumoai.experimental.rfm.base import SQLSampler, Table
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ @contextmanager
+ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+     _style = connection._paramstyle
+     connection._paramstyle = style
+     try:
+         yield
+     finally:  # Restore the previous paramstyle even if the body raises:
+         connection._paramstyle = _style
+
+
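snowflake-connector-python defaults to the `pyformat` paramstyle, while the SQL built below uses positional `?` placeholders; the context manager above switches the connection to `qmark` only for the duration of a query. A minimal usage sketch (assuming an open `Connection` named `conn`):

    with paramstyle(conn), conn.cursor() as cursor:
        # '?' binds positionally under the 'qmark' paramstyle:
        cursor.execute("SELECT ? AS answer", (42, ))
        print(cursor.fetchone())  # (42,)
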
+ class SnowSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         for table in graph.tables.values():
+             assert isinstance(table, SnowTable)
+             self._connection = table._connection
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         selects: list[str] = []
+         for table_name in table_names:
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
+             select = (f"SELECT\n"
+                       f" ? as table_name,\n"
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
+             selects.append(select)
+         sql = "\nUNION ALL\n".join(selects)
+
+         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, table_names)
+             rows = cursor.fetchall()
+             for table_name, _min, _max in rows:
+                 out_dict[table_name] = (
+                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                 )
+
+         return out_dict
+
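For a graph with two time-stamped tables, the statement sent in a single round trip takes roughly this shape (identifiers are illustrative; the real ones come from `source_name_dict` and `table_column_ref_dict`):

    SELECT
     ? as table_name,
     MIN("CREATED_AT") as min_date,
     MAX("CREATED_AT") as max_date
    FROM "DB"."SCHEMA"."USERS"
    UNION ALL
    SELECT
     ? as table_name,
     MIN("ORDER_TIME") as min_date,
     MAX("ORDER_TIME") as max_date
    FROM "DB"."SCHEMA"."ORDERS"

Empty tables return `NULL` aggregates, which are mapped to the `pd.Timestamp.max`/`pd.Timestamp.min` sentinels above.
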
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         # NOTE Snowflake supports `SEED` only as part of `SYSTEM` sampling.
+         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+         source_table = self.source_table_dict[table_name]
+         filters: list[str] = []
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")
+
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}\n"
+                f"SAMPLE ROW ({num_rows} ROWS)")
+         if len(filters) > 0:
+             sql += f"\nWHERE{' AND'.join(filters)}"
+
+         with self._connection.cursor() as cursor:
+             # NOTE This may return duplicate primary keys. This is okay.
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_all()
+
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )
+
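`SAMPLE ROW (n ROWS)` requests a fixed-size random sample rather than a percentage, which is also why `random_seed` goes unused here: Snowflake accepts `SEED` only with `SYSTEM` (block) sampling. For a table with a nullable primary key, the generated query looks roughly like (identifiers illustrative):

    SELECT "ORDER_ID", "USER_ID", "ORDER_TIME"
    FROM "DB"."SCHEMA"."ORDERS"
    SAMPLE ROW (10000 ROWS)
    WHERE "ORDER_ID" IS NOT NULL
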
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         # NOTE For Snowflake, we execute everything at once to pay minimal
+         # query initialization costs.
+         index = np.concatenate([train_index, test_index])
+         time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+         entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+         feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+         time_dict: dict[str, pd.Series] = {}
+         time_column = self.time_column_dict.get(query.entity_table)
+         if time_column in columns_dict[query.entity_table]:
+             time_dict[query.entity_table] = entity_df[time_column]
+         batch_dict: dict[str, np.ndarray] = {
+             query.entity_table: np.arange(len(entity_df)),
+         }
+         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+             table_name, foreign_key, _ = edge_type
+             feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                 table_name=table_name,
+                 foreign_key=foreign_key,
+                 index=entity_df[self.primary_key_dict[query.entity_table]],
+                 anchor_time=time,
+                 min_offset=min_offset,
+                 max_offset=max_offset,
+                 columns=columns_dict[table_name],
+             )
+             time_column = self.time_column_dict.get(table_name)
+             if time_column in columns_dict[table_name]:
+                 time_dict[table_name] = feat_dict[table_name][time_column]
+
+         y, mask = PQueryPandasExecutor().execute(
+             query=query,
+             feat_dict=feat_dict,
+             time_dict=time_dict,
+             batch_dict=batch_dict,
+             anchor_time=time,
+             num_forecasts=query.num_forecasts,
+         )
+
+         train_mask = mask[:len(train_index)]
+         test_mask = mask[len(train_index):]
+
+         boundary = int(train_mask.sum())
+         train_y = y.iloc[:boundary]
+         test_y = y.iloc[boundary:].reset_index(drop=True)
+
+         return train_y, train_mask, test_y, test_mask
+
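The final split relies on `y` containing one value per `True` entry of `mask`, ordered train-then-test. A toy walkthrough with made-up values:

    import numpy as np
    import pandas as pd

    train_index = np.array([0, 1, 2])  # 3 train anchors
    test_index = np.array([3, 4])      # 2 test anchors
    mask = np.array([True, False, True, True, False])  # from the executor
    y = pd.Series([1.0, 2.0, 3.0])     # one label per True entry

    train_mask = mask[:len(train_index)]               # [True, False, True]
    boundary = int(train_mask.sum())                   # 2 train labels
    train_y = y.iloc[:boundary]                        # [1.0, 2.0]
    test_y = y.iloc[boundary:].reset_index(drop=True)  # [3.0]
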
+     def _by_pkey(
+         self,
+         table_name: str,
+         index: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         payload = json.dumps(list(index))
+
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][key].is_int():
+             sql += " f.value::NUMBER as __KUMO_ID__\n"
+         elif self.table_dtype_dict[table_name][key].is_float():
+             sql += " f.value::FLOAT as __KUMO_ID__\n"
+         else:
+             sql += " f.value::VARCHAR as __KUMO_ID__\n"
+         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.source_name_dict[table_name]}\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__")
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         # Remove any duplicated primary keys in post-processing:
+         tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+         gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+         table = table.take(gb['__KUMO_ID___min'])
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
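The entity keys travel to Snowflake as one JSON array bound to the single `?` parameter; `FLATTEN` explodes it server-side into one row per element, with the array position in `f.index` serving as the batch pointer. A sketch of the payload and the `TMP` rows it yields (values made up):

    import json

    payload = json.dumps([10, 17, 10])
    # FLATTEN(PARSE_JSON('[10, 17, 10]')) produces:
    #   f.index | f.value          TMP: __KUMO_BATCH__ | __KUMO_ID__
    #   --------+---------   =>         ---------------+------------
    #   0       | 10                    0              | 10
    #   1       | 17                    1              | 17
    #   2       | 10                    2              | 10
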
+     def _by_fkey(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         num_neighbors: int,
+         anchor_time: pd.Series | None,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict.get(table_name)
+         # Resolve the time column reference up front; it is also needed for
+         # the `ORDER BY` below when no anchor time is given:
+         time_ref = (self.table_column_ref_dict[table_name][time_column]
+                     if time_column is not None else None)
+
+         if time_column is not None and anchor_time is not None:
+             anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             payload = json.dumps(list(zip(index, anchor_time)))
+         else:
+             payload = json.dumps(list(zip(index)))
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][foreign_key].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__"
+         elif self.table_dtype_dict[table_name][foreign_key].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__"
+         else:
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+         if time_column is not None and anchor_time is not None:
+             sql += (",\n"
+                     " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+         sql += (f"\n"
+                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.source_name_dict[table_name]}\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n")
+         if time_column is not None and anchor_time is not None:
+             sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+         sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                 " PARTITION BY TMP.__KUMO_BATCH__\n")
+         if time_column is not None:
+             sql += f" ORDER BY {time_ref} DESC\n"
+         else:
+             sql += f" ORDER BY {key_ref}\n"
+         sql += f") <= {num_neighbors}"
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
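`QUALIFY` filters on the window function after the join, so each batch entry keeps at most `num_neighbors` rows: the most recent ones when a time column exists (rows after the anchor time are already excluded by the join predicate), otherwise a stable set ordered by key. The resulting query shape (identifiers illustrative):

    SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "ORDER_TIME", "AMOUNT"
    FROM TMP
    JOIN "DB"."SCHEMA"."ORDERS"
     ON "USER_ID" = TMP.__KUMO_ID__
     AND "ORDER_TIME" <= TMP.__KUMO_TIME__
    QUALIFY ROW_NUMBER() OVER (
     PARTITION BY TMP.__KUMO_BATCH__
     ORDER BY "ORDER_TIME" DESC
    ) <= 16
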
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]
+
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             payload = json.dumps(list(zip(index, end_time, start_time)))
+         else:
+             payload = json.dumps(list(zip(index, end_time)))
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][foreign_key].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+         elif self.table_dtype_dict[table_name][foreign_key].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
+         else:
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
+         if min_offset is not None:
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
+         sql += (f"\n"
+                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.source_name_dict[table_name]}\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                 f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
+         if min_offset is not None:
+             sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
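
The time window bounds are computed client-side with `pd.DateOffset` before being serialized into the JSON payload; a small sketch of that arithmetic (dates made up):

    import pandas as pd

    anchor_time = pd.Series(pd.to_datetime(["2024-01-01", "2024-03-15"]))
    end_time = anchor_time + pd.DateOffset(days=30)  # per-row upper bound
    print(end_time.dt.strftime("%Y-%m-%d %H:%M:%S").tolist())
    # ['2024-01-31 00:00:00', '2024-04-14 00:00:00']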