kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/experimental/rfm/__init__.py +49 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  10. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  11. kumoai/experimental/rfm/backend/local/table.py +32 -14
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +186 -39
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
  18. kumoai/experimental/rfm/base/__init__.py +23 -3
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +761 -0
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  24. kumoai/experimental/rfm/base/table.py +380 -185
  25. kumoai/experimental/rfm/graph.py +404 -144
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +52 -60
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +283 -230
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/pquery/predictive_query.py +10 -6
  38. kumoai/testing/snow.py +50 -0
  39. kumoai/trainer/distilled_trainer.py +175 -0
  40. kumoai/utils/__init__.py +3 -2
  41. kumoai/utils/display.py +51 -0
  42. kumoai/utils/progress_logger.py +178 -12
  43. kumoai/utils/sql.py +3 -0
  44. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
  45. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
  46. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  47. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  48. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  49. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  50. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/local/sampler.py
@@ -0,0 +1,312 @@
+ from typing import TYPE_CHECKING, Literal
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
+ from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class LocalSampler(Sampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         import kumoai.kumolib as kumolib
+
+         self._graph_store = LocalGraphStore(graph, verbose)
+         self._graph_sampler = kumolib.NeighborSampler(
+             list(self.table_stype_dict.keys()),
+             self.edge_types,
+             {
+                 '__'.join(edge_type): colptr
+                 for edge_type, colptr in self._graph_store.colptr_dict.items()
+             },
+             {
+                 '__'.join(edge_type): row
+                 for edge_type, row in self._graph_store.row_dict.items()
+             },
+             self._graph_store.time_dict,
+         )
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         return {
+             key: value
+             for key, value in self._graph_store.min_max_time_dict.items()
+             if key in table_names
+         }
+
+     def _sample_subgraph(
+         self,
+         entity_table_name: str,
+         entity_pkey: pd.Series,
+         anchor_time: pd.Series | Literal['entity'],
+         columns_dict: dict[str, set[str]],
+         num_neighbors: list[int],
+     ) -> SamplerOutput:
+
+         index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+         if isinstance(anchor_time, pd.Series):
+             time = anchor_time.astype(int).to_numpy() // 1000**3  # to seconds
+         else:
+             assert anchor_time == 'entity'
+             time = self._graph_store.time_dict[entity_table_name][index]
+
+         (
+             row_dict,
+             col_dict,
+             node_dict,
+             batch_dict,
+             num_sampled_nodes_dict,
+             num_sampled_edges_dict,
+         ) = self._graph_sampler.sample(
+             {
+                 '__'.join(edge_type): num_neighbors
+                 for edge_type in self.edge_types
+             },
+             {},
+             entity_table_name,
+             index,
+             time,
+         )
+
+         df_dict: dict[str, pd.DataFrame] = {}
+         inverse_dict: dict[str, np.ndarray] = {}
+         for table_name, node in node_dict.items():
+             df = self._graph_store.df_dict[table_name]
+             columns = columns_dict[table_name]
+             if self.end_time_column_dict.get(table_name, None) in columns:
+                 df = df.iloc[node]
+             elif len(columns) == 0:
+                 df = df.iloc[node]
+             else:
+                 # Only store unique rows in `df` above a certain threshold:
+                 unique_node, inverse = np.unique(node, return_inverse=True)
+                 if len(node) > 1.05 * len(unique_node):
+                     df = df.iloc[unique_node]
+                     inverse_dict[table_name] = inverse
+                 else:
+                     df = df.iloc[node]
+             df = df.reset_index(drop=True)
+             df = df[list(columns)]
+             df_dict[table_name] = df
+
+         num_sampled_nodes_dict = {
+             table_name: num_sampled_nodes.tolist()
+             for table_name, num_sampled_nodes in
+             num_sampled_nodes_dict.items()
+         }
+
+         row_dict = {
+             edge_type: row_dict['__'.join(edge_type)]
+             for edge_type in self.edge_types
+         }
+         col_dict = {
+             edge_type: col_dict['__'.join(edge_type)]
+             for edge_type in self.edge_types
+         }
+         num_sampled_edges_dict = {
+             edge_type: num_sampled_edges_dict['__'.join(edge_type)].tolist()
+             for edge_type in self.edge_types
+         }
+
+         return SamplerOutput(
+             anchor_time=time * 1000**3,  # to nanoseconds
+             df_dict=df_dict,
+             inverse_dict=inverse_dict,
+             batch_dict=batch_dict,
+             num_sampled_nodes_dict=num_sampled_nodes_dict,
+             row_dict=row_dict,
+             col_dict=col_dict,
+             num_sampled_edges_dict=num_sampled_edges_dict,
+         )
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         pkey_map = self._graph_store.pkey_map_dict[table_name]
+         if len(pkey_map) > num_rows:
+             pkey_map = pkey_map.sample(
+                 n=num_rows,
+                 random_state=random_seed,
+                 ignore_index=True,
+             )
+         df = self._graph_store.df_dict[table_name]
+         df = df.iloc[pkey_map['arange']][list(columns)]
+         return df
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         train_y, train_mask = self._sample_target_set(
+             query=query,
+             pkey=entity_df[self.primary_key_dict[query.entity_table]],
+             index=train_index,
+             anchor_time=train_time,
+             num_examples=num_train_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         test_y, test_mask = self._sample_target_set(
+             query=query,
+             pkey=entity_df[self.primary_key_dict[query.entity_table]],
+             index=test_index,
+             anchor_time=test_time,
+             num_examples=num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         return train_y, train_mask, test_y, test_mask
+
+     # Helper Methods ##########################################################
+
+     def _sample_target_set(
+         self,
+         query: ValidatedPredictiveQuery,
+         pkey: pd.Series,
+         index: np.ndarray,
+         anchor_time: pd.Series,
+         num_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         batch_size: int = 10_000,
+     ) -> tuple[pd.Series, np.ndarray]:
+
+         num_hops = 1 if len(time_offset_dict) > 0 else 0
+         num_neighbors_dict: dict[str, list[int]] = {}
+         unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
+         for edge_type, (start, end) in time_offset_dict.items():
+             unix_time_offset_dict['__'.join(edge_type)] = [[
+                 date_offset_to_seconds(start) if start is not None else None,
+                 date_offset_to_seconds(end),
+             ]]
+         for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
+             num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
+
+         count = 0
+         ys: list[pd.Series] = []
+         mask = np.full(len(index), False, dtype=bool)
+         for start in range(0, len(index), batch_size):
+             subset = pkey.iloc[index[start:start + batch_size]]
+             time = anchor_time.iloc[start:start + batch_size]
+
+             _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+                 num_neighbors_dict,
+                 unix_time_offset_dict,
+                 query.entity_table,
+                 self._graph_store.get_node_id(query.entity_table, subset),
+                 time.astype(int).to_numpy() // 1000**3,  # to seconds
+             )
+
+             feat_dict: dict[str, pd.DataFrame] = {}
+             time_dict: dict[str, pd.Series] = {}
+             for table_name, columns in columns_dict.items():
+                 df = self._graph_store.df_dict[table_name]
+                 df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+                 df = df[list(columns)]
+                 feat_dict[table_name] = df
+
+                 time_column = self.time_column_dict.get(table_name)
+                 if time_column in columns:
+                     time_dict[table_name] = df[time_column]
+
+             y, _mask = PQueryPandasExecutor().execute(
+                 query=query,
+                 feat_dict=feat_dict,
+                 time_dict=time_dict,
+                 batch_dict=batch_dict,
+                 anchor_time=time,
+                 num_forecasts=query.num_forecasts,
+             )
+             ys.append(y)
+             mask[start:start + batch_size] = _mask
+
+             count += len(y)
+             if count >= num_examples:
+                 break
+
+         if len(ys) == 0:
+             y = pd.Series([], dtype=float)
+         elif len(ys) == 1:
+             y = ys[0]
+         else:
+             y = pd.concat(ys, axis=0, ignore_index=True)
+
+         return y, mask
+
+
+ # Helper Functions ############################################################
+
+
+ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
+     r"""Convert a :class:`pandas.DateOffset` into a number of seconds.
+
+     .. note::
+         We are conservative and take months and years as their maximum value.
+         Additional values are then dropped in label computation where we know
+         the actual dates.
+     """
+     MAX_DAYS_IN_MONTH = 31
+     MAX_DAYS_IN_YEAR = 366
+
+     SECONDS_IN_MINUTE = 60
+     SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
+     SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
+
+     total_sec = 0
+     multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
+
+     for attr, value in offset.__dict__.items():
+         if value is None or value == 0:
+             continue
+         scaled_value = value * multiplier
+         if attr == 'years':
+             total_sec += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
+         elif attr == 'months':
+             total_sec += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
+         elif attr == 'days':
+             total_sec += scaled_value * SECONDS_IN_DAY
+         elif attr == 'hours':
+             total_sec += scaled_value * SECONDS_IN_HOUR
+         elif attr == 'minutes':
+             total_sec += scaled_value * SECONDS_IN_MINUTE
+         elif attr == 'seconds':
+             total_sec += scaled_value
+
+     return total_sec
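
A quick worked example (not part of the diff) of the conservative conversion implemented by `date_offset_to_seconds` above; the numbers simply instantiate the constants used in the function:

import pandas as pd

# 2 months and 3 days, with months counted at their maximum of 31 days:
#   2 * 31 * 86_400 + 3 * 86_400 = 5_616_000 seconds
offset = pd.DateOffset(months=2, days=3)
print(date_offset_to_seconds(offset))  # expected: 5616000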

kumoai/experimental/rfm/backend/local/table.py
@@ -1,9 +1,15 @@
- from typing import List, Optional
+ from typing import Sequence, cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType

- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
- from kumoai.experimental.rfm.infer import infer_dtype
+ from kumoai.experimental.rfm.base import (
+     ColumnSpec,
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     Table,
+ )


  class LocalTable(Table):
@@ -51,9 +57,9 @@ class LocalTable(Table):
          self,
          df: pd.DataFrame,
          name: str,
-         primary_key: Optional[str] = None,
-         time_column: Optional[str] = None,
-         end_time_column: Optional[str] = None,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
      ) -> None:

          if df.empty:
@@ -69,27 +75,39 @@ class LocalTable(Table):

          super().__init__(
              name=name,
-             columns=list(df.columns),
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
          )

-     def _get_source_columns(self) -> List[SourceColumn]:
+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.LOCAL)
+
+     def _get_source_columns(self) -> list[SourceColumn]:
          return [
              SourceColumn(
-                 name=column,
-                 dtype=infer_dtype(self._data[column]),
+                 name=column_name,
+                 dtype=None,
                  is_primary_key=False,
                  is_unique_key=False,
-             ) for column in self._data.columns
+                 is_nullable=True,
+             ) for column_name in self._data.columns
          ]

-     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
          return []

-     def _get_sample_df(self) -> pd.DataFrame:
+     def _get_source_sample_df(self) -> pd.DataFrame:
          return self._data

-     def _get_num_rows(self) -> Optional[int]:
+     def _get_expr_sample_df(
+         self,
+         columns: Sequence[ColumnSpec],
+     ) -> pd.DataFrame:
+         raise RuntimeError(f"Column expressions are not supported in "
+                            f"'{self.__class__.__name__}'. Please apply your "
+                            f"expressions on the `pd.DataFrame` directly.")
+
+     def _get_num_rows(self) -> int | None:
          return len(self._data)
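
For context, the constructor now defaults `primary_key` to `MissingType.VALUE`, presumably a sentinel that distinguishes "argument not passed" from an explicit `None`. A minimal usage sketch under that assumption (data and column names are made up; the import path follows the public KumoRFM examples):

import pandas as pd
from kumoai.experimental.rfm import LocalTable

users = pd.DataFrame({
    'user_id': [1, 2, 3],
    'signup_time': pd.to_datetime(['2024-01-01', '2024-02-01', '2024-03-01']),
})

# Explicit metadata:
users_table = LocalTable(
    df=users,
    name='users',
    primary_key='user_id',
    time_column='signup_time',
)

# Omitting `primary_key` keeps the MissingType.VALUE sentinel (leaving the
# decision to later inference), whereas `primary_key=None` would explicitly
# declare that the table has no primary key.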

kumoai/experimental/rfm/backend/snow/__init__.py
@@ -27,9 +27,11 @@ def connect(**kwargs: Any) -> Connection:


  from .table import SnowTable  # noqa: E402
+ from .sampler import SnowSampler  # noqa: E402

  __all__ = [
      'connect',
      'Connection',
      'SnowTable',
+     'SnowSampler',
  ]

kumoai/experimental/rfm/backend/snow/sampler.py
@@ -0,0 +1,297 @@
+ import json
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+ from kumoai.experimental.rfm.base import SQLSampler, Table
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ @contextmanager
+ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+     _style = connection._paramstyle
+     connection._paramstyle = style
+     yield
+     connection._paramstyle = _style
+
+
+ class SnowSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         for table in graph.tables.values():
+             assert isinstance(table, SnowTable)
+             self._connection = table._connection
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         selects: list[str] = []
+         for table_name in table_names:
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
+             select = (f"SELECT\n"
+                       f" ? as table_name,\n"
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
+             selects.append(select)
+         sql = "\nUNION ALL\n".join(selects)
+
+         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, table_names)
+             rows = cursor.fetchall()
+             for table_name, _min, _max in rows:
+                 out_dict[table_name] = (
+                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                 )
+
+         return out_dict
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
+         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+         source_table = self.source_table_dict[table_name]
+         filters: list[str] = []
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")
+
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}\n"
+                f"SAMPLE ROW ({num_rows} ROWS)")
+         if len(filters) > 0:
+             sql += f"\nWHERE{' AND'.join(filters)}"
+
+         with self._connection.cursor() as cursor:
+             # NOTE This may return duplicate primary keys. This is okay.
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_all()
+
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         # NOTE For Snowflake, we execute everything at once to pay minimal
+         # query initialization costs.
+         index = np.concatenate([train_index, test_index])
+         time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+         entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+         feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+         time_dict: dict[str, pd.Series] = {}
+         time_column = self.time_column_dict.get(query.entity_table)
+         if time_column in columns_dict[query.entity_table]:
+             time_dict[query.entity_table] = entity_df[time_column]
+         batch_dict: dict[str, np.ndarray] = {
+             query.entity_table: np.arange(len(entity_df)),
+         }
+         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+             table_name, fkey, _ = edge_type
+             feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                 table_name=table_name,
+                 fkey=fkey,
+                 pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                 anchor_time=time,
+                 min_offset=min_offset,
+                 max_offset=max_offset,
+                 columns=columns_dict[table_name],
+             )
+             time_column = self.time_column_dict.get(table_name)
+             if time_column in columns_dict[table_name]:
+                 time_dict[table_name] = feat_dict[table_name][time_column]
+
+         y, mask = PQueryPandasExecutor().execute(
+             query=query,
+             feat_dict=feat_dict,
+             time_dict=time_dict,
+             batch_dict=batch_dict,
+             anchor_time=time,
+             num_forecasts=query.num_forecasts,
+         )
+
+         train_mask = mask[:len(train_index)]
+         test_mask = mask[len(train_index):]
+
+         boundary = int(train_mask.sum())
+         train_y = y.iloc[:boundary]
+         test_y = y.iloc[boundary:].reset_index(drop=True)
+
+         return train_y, train_mask, test_y, test_mask
+
+     def _by_pkey(
+         self,
+         table_name: str,
+         pkey: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         payload = json.dumps(list(pkey))
+
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][key].is_int():
+             sql += " f.value::NUMBER as __KUMO_ID__\n"
+         elif self.table_dtype_dict[table_name][key].is_float():
+             sql += " f.value::FLOAT as __KUMO_ID__\n"
+         else:
+             sql += " f.value::VARCHAR as __KUMO_ID__\n"
+         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.source_name_dict[table_name]} ENT\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__")
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         # Remove any duplicated primary keys in post-processing:
+         tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+         gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+         table = table.take(gb['__KUMO_ID___min'])
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
+
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         fkey: str,
+         pkey: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]
+
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             payload = json.dumps(list(zip(pkey, end_time, start_time)))
+         else:
+             payload = json.dumps(list(zip(pkey, end_time)))
+
+         key_ref = self.table_column_ref_dict[table_name][fkey]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][fkey].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+         elif self.table_dtype_dict[table_name][fkey].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
+         else:
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
+         if min_offset is not None:
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
+         sql += (f"\n"
+                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.source_name_dict[table_name]} FACT\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                 f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
+         if min_offset is not None:
+             sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
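
For reference, the statement assembled by `_by_time` above takes roughly the following shape (rendered by hand for the integer-key case with `min_offset` unset; `<projections>`, `<source_table>`, `<key_ref>`, and `<time_ref>` stand for the per-table values looked up in the method):

WITH TMP as (
 SELECT
 f.index as __KUMO_BATCH__,
 f.value[0]::NUMBER as __KUMO_ID__,
 f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__
 FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f
)
SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, <projections>
FROM TMP
JOIN <source_table> FACT
 ON <key_ref> = TMP.__KUMO_ID__
 AND <time_ref> <= TMP.__KUMO_END_TIME__

The single `?` parameter is bound to a JSON array of [foreign key, end time] pairs (triples that also carry a start time when `min_offset` is given), and `f.index` from `FLATTEN` is carried through as `__KUMO_BATCH__` so that each fact row can be mapped back to its anchor entity.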