kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
- kumoai/experimental/rfm/backend/snow/table.py +159 -52
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
- kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
- kumoai/experimental/rfm/base/__init__.py +6 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +342 -13
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +27 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +600 -360
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/local/table.py

@@ -1,15 +1,15 @@
-import
-from typing import List, Optional, cast
+from typing import Sequence, cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 
 from kumoai.experimental.rfm.base import (
+    ColumnSpec,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
     Table,
 )
-from kumoai.experimental.rfm.infer import infer_dtype
 
 
 class LocalTable(Table):

@@ -57,9 +57,9 @@ class LocalTable(Table):
         self,
         df: pd.DataFrame,
         name: str,
-        primary_key:
-        time_column:
-        end_time_column:
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if df.empty:

@@ -75,7 +75,6 @@ class LocalTable(Table):
 
         super().__init__(
             name=name,
-            columns=list(df.columns),
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,

@@ -85,35 +84,30 @@ class LocalTable(Table):
     def backend(self) -> DataBackend:
         return cast(DataBackend, DataBackend.LOCAL)
 
-    def _get_source_columns(self) ->
-
-
-
-
-                dtype = infer_dtype(ser)
-            except Exception:
-                warnings.warn(f"Data type inference for column '{column}' in "
-                              f"table '{self.name}' failed. Consider changing "
-                              f"the data type of the column to use it within "
-                              f"this table.")
-                continue
-
-            source_column = SourceColumn(
-                name=column,
-                dtype=dtype,
+    def _get_source_columns(self) -> list[SourceColumn]:
+        return [
+            SourceColumn(
+                name=column_name,
+                dtype=None,
                 is_primary_key=False,
                 is_unique_key=False,
                 is_nullable=True,
-            )
-
+            ) for column_name in self._data.columns
+        ]
 
-
-
-    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
         return []
 
-    def
+    def _get_source_sample_df(self) -> pd.DataFrame:
         return self._data
 
-    def
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        raise RuntimeError(f"Column expressions are not supported in "
+                           f"'{self.__class__.__name__}'. Please apply your "
+                           f"expressions on the `pd.DataFrame` directly.")
+
+    def _get_num_rows(self) -> int | None:
         return len(self._data)
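The hunks above change the `LocalTable` constructor so that `primary_key` defaults to `MissingType.VALUE` while `time_column` and `end_time_column` default to `None`, and the explicit `columns=` argument is dropped. A minimal usage sketch based only on the signature shown above (the import path is an assumption and may differ in the released package):

    import pandas as pd
    from kumoai.experimental.rfm.backend.local.table import LocalTable  # assumed path

    df = pd.DataFrame({
        "user_id": [1, 2, 3],
        "signup_time": pd.to_datetime(["2024-01-01", "2024-02-01", "2024-03-01"]),
    })

    # `primary_key` may now be omitted entirely; `time_column` and
    # `end_time_column` are optional and default to `None`.
    table = LocalTable(df, name="users", primary_key="user_id",
                       time_column="signup_time")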
kumoai/experimental/rfm/backend/snow/sampler.py

@@ -1,4 +1,6 @@
 import json
+from collections.abc import Iterator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING
 
 import numpy as np

@@ -6,15 +8,23 @@ import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
 
-from kumoai.experimental.rfm.backend.snow import SnowTable
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
 
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
+
+
 class SnowSampler(SQLSampler):
     def __init__(
         self,
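The new `paramstyle` helper temporarily switches the Snowflake connection to qmark-style (`?`) bind parameters and restores the previous style on exit, which is what lets the samplers below pass positional parameters to `cursor.execute`. A minimal sketch of the behaviour with a stand-in connection object (no real Snowflake session is assumed):

    from collections.abc import Iterator
    from contextlib import contextmanager

    class FakeConnection:
        # Stand-in that models only the attribute the helper touches.
        _paramstyle = 'pyformat'

    @contextmanager
    def paramstyle(connection, style: str = 'qmark') -> Iterator[None]:
        _style = connection._paramstyle
        connection._paramstyle = style
        yield
        connection._paramstyle = _style

    conn = FakeConnection()
    with paramstyle(conn):
        assert conn._paramstyle == 'qmark'   # `?` placeholders accepted here
    assert conn._paramstyle == 'pyformat'    # previous style restored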
@@ -23,16 +33,9 @@ class SnowSampler(SQLSampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)
 
-        self._fqn_dict: dict[str, str] = {}
         for table in graph.tables.values():
             assert isinstance(table, SnowTable)
             self._connection = table._connection
-            self._fqn_dict[table.name] = table.fqn
-
-    @property
-    def fqn_dict(self) -> dict[str, str]:
-        r"""The fully-qualified quoted names for all tables in the graph."""
-        return self._fqn_dict
 
     def _get_min_max_time_dict(
         self,

@@ -40,24 +43,25 @@ class SnowSampler(SQLSampler):
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
-                      f"
-                      f" MIN({
-                      f" MAX({
-                      f"FROM {self.
+                      f" ? as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, table_names)
             rows = cursor.fetchall()
-
-
-
-
-
+            for table_name, _min, _max in rows:
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
 
         return out_dict
 
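For orientation, `_get_min_max_time_dict` now assembles one `SELECT` per table, joined by `UNION ALL`, with each table name bound as a `?` parameter under the qmark paramstyle. An illustrative shape of the generated SQL for two hypothetical tables (all identifiers below are placeholders, not names from the package):

    # Illustrative only: the SQL produced for tables "users" and "orders".
    sql = (
        "SELECT\n"
        "  ? as table_name,\n"
        '  MIN("SIGNUP_TIME") as min_date,\n'
        '  MAX("SIGNUP_TIME") as max_date\n'
        'FROM "DB"."SCHEMA"."USERS"\n'
        "UNION ALL\n"
        "SELECT\n"
        "  ? as table_name,\n"
        '  MIN("ORDER_TIME") as min_date,\n'
        '  MAX("ORDER_TIME") as max_date\n'
        'FROM "DB"."SCHEMA"."ORDERS"'
    )
    params = ("users", "orders")  # one bind value per UNION ALL branch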
@@ -71,17 +75,27 @@ class SnowSampler(SQLSampler):
         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
         num_rows = min(num_rows, 1_000_000) # Snowflake's upper limit.
 
+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
 
-
-
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
               f"SAMPLE ROW ({num_rows} ROWS)")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"

@@ -91,7 +105,11 @@ class SnowSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
 
     def _sample_target(
         self,
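The random-sampling hunk above now builds the per-table query from the shared projection and column-reference dictionaries. Roughly, the resulting Snowflake SQL takes the following shape (identifiers are illustrative placeholders; the `WHERE` clause is only appended when the key or time column is nullable or missing from the source table):

    # Illustrative shape of the sampling SQL for a hypothetical "USERS" table.
    num_rows = 10_000
    sql = (
        f'SELECT "USER_ID", "SIGNUP_TIME", "AGE"\n'
        f'FROM "DB"."SCHEMA"."USERS"\n'
        f"SAMPLE ROW ({num_rows} ROWS)"
        f'\nWHERE "USER_ID" IS NOT NULL AND "SIGNUP_TIME" IS NOT NULL'
    )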
@@ -126,11 +144,11 @@ class SnowSampler(SQLSampler):
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -161,104 +179,193 @@ class SnowSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
 
-
-        source_table = self.source_table_dict[table_name]
-
-        payload = json.dumps(list(pkey))
+        payload = json.dumps(list(index))
 
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as
-        if
-            sql += " f.value::NUMBER as
-        elif
-            sql += " f.value::FLOAT as
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
         else:
-            sql += " f.value::VARCHAR as
-        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
+        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT
-                f"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.
-                f" ON
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")
 
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
         # Remove any duplicated primary keys in post-processing:
-        tmp = table.append_column('
-        gb = tmp.group_by('
-        table = table.take(gb['
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
 
-        batch = table['
-
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
 
-        return
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, anchor_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if time_column is not None and anchor_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
 
     # Helper Methods ##########################################################
 
     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
 
         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(
-
-
-
-
-
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as
-        if
-            sql += " f.value[0]::NUMBER as
-        elif
-            sql += " f.value[0]::FLOAT as
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
-            sql += " f.value[0]::VARCHAR as
-        sql += " f.value[1]::TIMESTAMP_NTZ as
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
         if min_offset is not None:
-            sql += ",\n f.value[2]::TIMESTAMP_NTZ as
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
         sql += (f"\n"
-                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT
-                f"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.
-                f" ON
-                f" AND
-        if
-            sql += f"
-
-
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
-        batch = table['
-
-
-        return self._sanitize(table_name, table), batch
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
 
-
-
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
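All three sampling paths above (`_by_pkey`, `_by_fkey`, `_by_time`) share the same pattern: the Python side serializes the entity keys (plus optional time bounds) into a single JSON array, binds it as one `?` parameter, and Snowflake's `FLATTEN(INPUT => PARSE_JSON(?))` expands it back into one row per entity, where `f.index` becomes the batch position and `f.value[0]`/`f.value[1]` carry the key and time bound. A minimal sketch of the payload side only, with made-up keys and offsets (no Snowflake connection involved):

    import json
    import pandas as pd

    # Hypothetical entity keys and anchor times, mirroring how `_by_time`
    # builds its payload before binding it to the single `?` placeholder.
    index = [101, 102, 103]
    anchor_time = pd.to_datetime(pd.Series(["2024-05-01", "2024-05-02", "2024-05-03"]))
    end_time = (anchor_time + pd.DateOffset(days=7)).dt.strftime("%Y-%m-%d %H:%M:%S")

    payload = json.dumps(list(zip(index, end_time)))
    # -> '[[101, "2024-05-08 00:00:00"], [102, "2024-05-09 00:00:00"], ...]'
    # Server-side, FLATTEN(PARSE_JSON(?)) yields rows of
    # (f.index, f.value[0]::NUMBER, f.value[1]::TIMESTAMP_NTZ) to join against.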