kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
- kumoai/experimental/rfm/backend/snow/table.py +146 -70
- kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +130 -110
- kumoai/experimental/rfm/infer/dtype.py +7 -2
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +540 -306
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +15 -2
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/sampler.py

@@ -1,16 +1,21 @@
 import json
+import math
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
 
-from kumoai.experimental.rfm.backend.snow import Connection
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import quote_ident
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
 
 
 @contextmanager
@@ -22,30 +27,51 @@ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
 
 
 class SnowSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SnowTable)
+            self._connection = table._connection
+
+        self._num_rows_dict: dict[str, int] = {
+            table.name: cast(int, table._num_rows)
+            for table in graph.tables.values()
+        }
+
+    @property
+    def num_rows_dict(self) -> dict[str, int]:
+        return self._num_rows_dict
+
     def _get_min_max_time_dict(
         self,
         table_names: list[str],
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
             select = (f"SELECT\n"
-                      f"
-                      f" MIN({
-                      f" MAX({
-                      f"FROM {self.
+                      f" {ident} as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with
-            cursor.execute(sql
-
-
-
-
-
-        )
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
 
         return out_dict
 
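The rewritten `_get_min_max_time_dict` resolves each table's time column through `table_column_ref_dict` and fetches the whole time range in a single round trip. As a rough sketch, for a hypothetical graph with two time-stamped tables (`USERS` with time column `CREATED_AT` and `ORDERS` with `ORDER_DATE`; all names invented for illustration), the generated statement would look like:

```python
# Hypothetical rendering of the UNION ALL statement assembled above; one
# SELECT per requested table, executed as a single query.
sql = (
    "SELECT\n"
    "  'USERS' as table_name,\n"
    "  MIN(CREATED_AT) as min_date,\n"
    "  MAX(CREATED_AT) as max_date\n"
    "FROM MY_DB.MY_SCHEMA.USERS\n"
    "UNION ALL\n"
    "SELECT\n"
    "  'ORDERS' as table_name,\n"
    "  MIN(ORDER_DATE) as min_date,\n"
    "  MAX(ORDER_DATE) as max_date\n"
    "FROM MY_DB.MY_SCHEMA.ORDERS"
)
# cursor.fetchall() then yields one (table_name, min, max) row per table;
# NULL bounds (e.g. an empty table) map to the pd.Timestamp.max /
# pd.Timestamp.min sentinels, which downstream reads as an empty range.
```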
@@ -59,17 +85,27 @@ class SnowSampler(SQLSampler):
         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
 
-
-
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
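The row-sampling path now derives projections and NOT NULL filters from `table_column_proj_dict` / `table_column_ref_dict` rather than quoting raw column names, and it also guards against key or time columns that are absent from the source table. Under the same invented `USERS` schema (nullable primary key `USER_ID`, nullable time column `CREATED_AT`, feature `AGE`), the rendered statement would be roughly:

```python
# Approximate statement produced for a hypothetical USERS table whose primary
# key and time column are both nullable in the source. Snowflake caps
# row-count sampling at 1,000,000 rows, hence the min() above.
sql = (
    "SELECT USER_ID, CREATED_AT, AGE\n"
    "FROM MY_DB.MY_SCHEMA.USERS\n"
    "SAMPLE ROW (10000 ROWS)\n"
    "WHERE USER_ID IS NOT NULL AND CREATED_AT IS NOT NULL"
)
```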
@@ -79,7 +115,11 @@ class SnowSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
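The sampled Arrow result is now passed through the shared `Table._sanitize` helper, with the pandas conversion done via `types_mapper=pd.ArrowDtype`. A minimal, self-contained illustration of what that mapper changes (plain pandas/pyarrow, nothing kumoai-specific):

```python
import pandas as pd
import pyarrow as pa

# A nullable integer column, as a Snowflake cursor's fetch_arrow_all() might
# return it for a key column containing NULLs:
table = pa.table({"USER_ID": pa.array([1, None, 3], type=pa.int64())})

# The default conversion upcasts nullable integers to float64 + NaN:
print(table.to_pandas()["USER_ID"].dtype)  # float64

# With pd.ArrowDtype, the column stays a nullable Arrow-backed int64:
print(table.to_pandas(types_mapper=pd.ArrowDtype)["USER_ID"].dtype)  # int64[pyarrow]
```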
@@ -114,11 +154,11 @@ class SnowSampler(SQLSampler):
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -149,104 +189,219 @@ class SnowSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
 
-
-        source_table = self.source_table_dict[table_name]
-
-        payload = json.dumps(list(pkey))
+        payload = json.dumps(list(index))
 
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as
-        if
-            sql += " f.value::NUMBER as
-        elif
-            sql += " f.value::FLOAT as
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
         else:
-            sql += " f.value::VARCHAR as
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT
-                f"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.
-                f" ON
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")
 
         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
         # Remove any duplicated primary keys in post-processing:
-        tmp = table.append_column('
-        gb = tmp.group_by('
-        table = table.take(gb['
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
 
-
-
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
-
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        end_time: pd.Series | None = None
+        start_time: pd.Series | None = None
+        if time_column is not None and anchor_time is not None:
+            # In order to avoid a full table scan, we limit foreign key
+            # sampling to a certain time range, approximated by the number of
+            # rows, timestamp ranges and `num_neighbors` value.
+            # Downstream, this helps Snowflake to apply partition pruning:
+            dst_table_name = [
+                dst_table
+                for key, dst_table in self.foreign_key_dict[table_name]
+                if key == foreign_key
+            ][0]
+            num_facts = self.num_rows_dict[table_name]
+            num_entities = self.num_rows_dict[dst_table_name]
+            min_time = self.get_min_time([table_name])
+            max_time = self.get_max_time([table_name])
+            freq = num_facts / num_entities
+            freq = freq / max((max_time - min_time).total_seconds(), 1)
+            offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
+
+            end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            start_time = anchor_time - offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if end_time is not None and start_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
+                    " f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if end_time is not None and start_time is not None:
+            assert time_column is not None
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += (f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
+                    f" AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+                    f"WHERE {time_ref} <= '{end_time.max()}'\n"
+                    f" AND {time_ref} > '{start_time.min()}'\n")
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
 
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
        if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(
-
-
-
-
-
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as
-        if
-            sql += " f.value[0]::NUMBER as
-        elif
-            sql += " f.value[0]::FLOAT as
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
-            sql += " f.value[0]::VARCHAR as
-        sql += " f.value[1]::TIMESTAMP_NTZ as
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
         if min_offset is not None:
-            sql += ",\n f.value[2]::TIMESTAMP_NTZ as
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
         sql += (f"\n"
                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT
-                f"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.
-                f" ON
-                f" AND
-        if
-            sql += f"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"
 
         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
-        batch = table['
-
-
-        return self._sanitize(table_name, table), batch
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
 
-
-
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
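Most of the added bulk is the new `_by_fkey`, which joins the fact table against a JSON payload of entity keys via `FLATTEN(INPUT => PARSE_JSON(?))` and enforces the per-entity neighbor cap with `QUALIFY ROW_NUMBER() OVER (PARTITION BY ...) <= num_neighbors`. The look-back window it computes exists only so Snowflake can prune partitions; it is an estimate, not a hard bound. A standalone sketch of that arithmetic with invented statistics (the real values come from `num_rows_dict` and `get_min_time` / `get_max_time`):

```python
import math

import pandas as pd

# Invented statistics for a fact table referencing an entity table:
num_facts = 50_000_000                 # rows in the fact table
num_entities = 100_000                 # rows in the entity table
min_time = pd.Timestamp("2020-01-01")  # earliest fact timestamp
max_time = pd.Timestamp("2024-01-01")  # latest fact timestamp
num_neighbors = 32                     # neighbors requested per entity

# Average number of facts per entity per second of history:
freq = num_facts / num_entities
freq = freq / max((max_time - min_time).total_seconds(), 1)

# Look back far enough to expect ~5x the requested neighbors, so the
# time-bounded JOIN rarely starves the sample while still letting Snowflake
# skip partitions outside [anchor_time - offset, anchor_time]:
offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
print(offset)  # 467 days 12:28:48
```

Because the `QUALIFY` clause enforces the exact cap per `__KUMO_BATCH__`, the window only needs to be generous, not precise.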