kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry, and is provided for informational purposes only. Portions of changed lines that the registry's diff viewer did not render are marked with `…` below.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +52 -104
- kumoai/experimental/rfm/backend/local/sampler.py +125 -55
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +174 -49
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +21 -5
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +422 -35
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +144 -0
- kumoai/experimental/rfm/base/table.py +386 -195
- kumoai/experimental/rfm/graph.py +350 -178
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +630 -408
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +290 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +49 -40
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/local/graph_store.py

@@ -1,13 +1,12 @@
-import …
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING
 
 import numpy as np
 import pandas as pd
 from kumoapi.rfm.context import Subgraph
-from kumoapi.typing import Stype
 
 from kumoai.experimental.rfm.backend.local import LocalTable
-from kumoai.…
+from kumoai.experimental.rfm.base import Table
+from kumoai.utils import ProgressLogger
 
 try:
     import torch
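A pattern worth noting before the remaining hunks: this release migrates annotations from `typing.Dict`/`List`/`Optional`/`Tuple`/`Union` to PEP 585 builtin generics and PEP 604 unions, which is why those names drop out of the import above. A minimal before/after sketch (illustrative, not code from the package):

from typing import Dict, Optional, Tuple  # old spelling, as removed above

def old_style(x: Optional[Dict[str, int]]) -> Tuple[int, int]: ...

# New spelling; builtin generics need Python 3.9+, `|` unions need 3.10+.
# This wheel targets cp311, so both are safe.
def new_style(x: dict[str, int] | None) -> tuple[int, int]: ...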
@@ -23,37 +22,32 @@ class LocalGraphStore:
     def __init__(
         self,
         graph: 'Graph',
-        verbose: …
+        verbose: bool | ProgressLogger = True,
     ) -> None:
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = …
-                "Materializing graph",
+            verbose = ProgressLogger.default(
+                msg="Materializing graph",
                 verbose=verbose,
             )
 
         with verbose as logger:
             self.df_dict, self.mask_dict = self.sanitize(graph)
-            self.stype_dict = self.get_stype_dict(graph)
             logger.log("Sanitized input data")
 
-            self.…
+            self.pkey_map_dict = self.get_pkey_map_dict(graph)
             num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
             if num_pkeys > 1:
                 logger.log(f"Collected primary keys from {num_pkeys} tables")
             else:
                 logger.log(f"Collected primary key from {num_pkeys} table")
 
-            (
-…
-                self.…
-                self.…
-                self.min_time,
-                self.max_time,
-            ) = self.get_time_data(graph)
-            if self.max_time != pd.Timestamp.min:
+            self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+            if len(self.min_max_time_dict) > 0:
+                min_time = min(t for t, _ in self.min_max_time_dict.values())
+                max_time = max(t for _, t in self.min_max_time_dict.values())
                 logger.log(f"Identified temporal graph from "
-                           f"{…
+                           f"{min_time.date()} to {max_time.date()}")
             else:
                 logger.log("Identified static graph without timestamps")
 
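`LocalGraphStore.__init__` now accepts either a plain bool or a caller-supplied `ProgressLogger`, normalizing the bool case through `ProgressLogger.default`. A hedged usage sketch based only on the calls visible in this hunk (the full `ProgressLogger` API may take further arguments), given an already-constructed `Graph` named `graph`:

from kumoai.experimental.rfm.backend.local import LocalGraphStore
from kumoai.utils import ProgressLogger

store = LocalGraphStore(graph, verbose=True)  # default logger built internally

# Reuse one logger across several steps by passing it through directly:
logger = ProgressLogger.default(msg="Materializing graph", verbose=True)
store = LocalGraphStore(graph, verbose=logger)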
@@ -63,14 +57,6 @@ class LocalGraphStore:
             logger.log(f"Created graph with {num_nodes:,} nodes and "
                        f"{num_edges:,} edges")
 
-    @property
-    def node_types(self) -> List[str]:
-        return list(self.df_dict.keys())
-
-    @property
-    def edge_types(self) -> List[Tuple[str, str, str]]:
-        return list(self.row_dict.keys())
-
     def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
         r"""Returns the node ID given primary keys.
 
@@ -107,7 +93,7 @@ class LocalGraphStore:
     def sanitize(
         self,
         graph: 'Graph',
-    ) -> …
+    ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
         r"""Sanitizes raw data according to table schema definition:
 
         In particular, it:
@@ -116,30 +102,24 @@ class LocalGraphStore:
         * drops duplicate primary keys
         * removes rows with missing primary keys or time values
         """
-        df_dict: …
+        df_dict: dict[str, pd.DataFrame] = {}
         for table_name, table in graph.tables.items():
             assert isinstance(table, LocalTable)
-…
-…
+            df_dict[table_name] = Table._sanitize(
+                df=table._data.copy(deep=False).reset_index(drop=True),
+                dtype_dict={
+                    column.name: column.dtype
+                    for column in table.columns
+                },
+                stype_dict={
+                    column.name: column.stype
+                    for column in table.columns
+                },
+            )
 
-        mask_dict: …
+        mask_dict: dict[str, np.ndarray] = {}
         for table in graph.tables.values():
-…
-                if col.stype == Stype.timestamp:
-                    ser = df_dict[table.name][col.name]
-                    if not pd.api.types.is_datetime64_any_dtype(ser):
-                        with warnings.catch_warnings():
-                            warnings.filterwarnings(
-                                'ignore',
-                                message='Could not infer format',
-                            )
-                            ser = pd.to_datetime(ser, errors='coerce')
-                        df_dict[table.name][col.name] = ser
-                    if isinstance(ser.dtype, pd.DatetimeTZDtype):
-                        ser = ser.dt.tz_localize(None)
-                        df_dict[table.name][col.name] = ser
-
-            mask: Optional[np.ndarray] = None
+            mask: np.ndarray | None = None
             if table._time_column is not None:
                 ser = df_dict[table.name][table._time_column]
                 mask = ser.notna().to_numpy()
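The deleted block was the per-column timestamp coercion, now folded into the shared `Table._sanitize`. For reference, the removed lines amounted to this standalone pandas snippet:

import pandas as pd

ser = pd.Series(['2021-01-01 00:00:00+02:00'])
if not pd.api.types.is_datetime64_any_dtype(ser):
    ser = pd.to_datetime(ser, errors='coerce')  # unparseable values become NaT
if isinstance(ser.dtype, pd.DatetimeTZDtype):
    ser = ser.dt.tz_localize(None)  # drop timezone, keep naive datetime64[ns]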
@@ -154,34 +134,16 @@ class LocalGraphStore:
 
         return df_dict, mask_dict
 
-    def …
-        stype_dict: Dict[str, Dict[str, Stype]] = {}
-        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
-        for table in graph.tables.values():
-            stype_dict[table.name] = {}
-            for column in table.columns:
-                if column == table.primary_key:
-                    continue
-                if (table.name, column.name) in foreign_keys:
-                    continue
-                stype_dict[table.name][column.name] = column.stype
-        return stype_dict
-
-    def get_pkey_data(
+    def get_pkey_map_dict(
         self,
         graph: 'Graph',
-    ) -> …
-…
-        Dict[str, pd.DataFrame],
-    ]:
-        pkey_name_dict: Dict[str, str] = {}
-        pkey_map_dict: Dict[str, pd.DataFrame] = {}
+    ) -> dict[str, pd.DataFrame]:
+        pkey_map_dict: dict[str, pd.DataFrame] = {}
 
         for table in graph.tables.values():
             if table._primary_key is None:
                 continue
 
-            pkey_name_dict[table.name] = table._primary_key
             pkey = self.df_dict[table.name][table._primary_key]
             pkey_map = pd.DataFrame(
                 dict(arange=range(len(pkey))),
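The `pkey_map` constructor is cut off at the hunk boundary; the visible `arange` column records each row's position, and indexing the frame by the raw primary key (an assumption here, not visible in the diff) would let `get_node_id` translate raw keys back to row numbers:

import pandas as pd

pkey = pd.Series([101, 205, 307], name='user_id')
pkey_map = pd.DataFrame(dict(arange=range(len(pkey))), index=pkey)
pkey_map.loc[205, 'arange']  # -> 1, the row position of key 205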
@@ -203,62 +165,48 @@ class LocalGraphStore:
 
             pkey_map_dict[table.name] = pkey_map
 
-        return …
+        return pkey_map_dict
 
     def get_time_data(
         self,
         graph: 'Graph',
-    ) -> …
-…
-…
-        Dict[str, np.ndarray],
-        pd.Timestamp,
-        pd.Timestamp,
+    ) -> tuple[
+        dict[str, np.ndarray],
+        dict[str, tuple[pd.Timestamp, pd.Timestamp]],
     ]:
-…
-…
-        time_dict: Dict[str, np.ndarray] = {}
-        min_time = pd.Timestamp.max
-        max_time = pd.Timestamp.min
+        time_dict: dict[str, np.ndarray] = {}
+        min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         for table in graph.tables.values():
-            if table._end_time_column is not None:
-                end_time_column_dict[table.name] = table._end_time_column
-
             if table._time_column is None:
                 continue
 
             time = self.df_dict[table.name][table._time_column]
-            if time.dtype != 'datetime64[ns]':
-                time = time.astype('datetime64[ns]')
             time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
-            time_column_dict[table.name] = table._time_column
 
             if table.name in self.mask_dict.keys():
                 time = time[self.mask_dict[table.name]]
             if len(time) > 0:
-…
-…
+                min_max_time_dict[table.name] = (time.min(), time.max())
+            else:
+                min_max_time_dict[table.name] = (
+                    pd.Timestamp.max,
+                    pd.Timestamp.min,
+                )
 
-        return …
-            time_column_dict,
-            end_time_column_dict,
-            time_dict,
-            min_time,
-            max_time,
-        )
+        return time_dict, min_max_time_dict
 
     def get_csc(
         self,
         graph: 'Graph',
-    ) -> …
-…
-…
+    ) -> tuple[
+        dict[tuple[str, str, str], np.ndarray],
+        dict[tuple[str, str, str], np.ndarray],
     ]:
         # A mapping from raw primary keys to node indices (0 to N-1):
-        map_dict: …
+        map_dict: dict[str, pd.CategoricalDtype] = {}
         # A dictionary to manage offsets of node indices for invalid rows:
-        offset_dict: …
-        for table_name in …
+        offset_dict: dict[str, np.ndarray] = {}
+        for table_name in {edge.dst_table for edge in graph.edges}:
             ser = self.df_dict[table_name][graph[table_name]._primary_key]
             if table_name in self.mask_dict.keys():
                 mask = self.mask_dict[table_name]
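Note the sentinel used when a table has no valid timestamps: `(pd.Timestamp.max, pd.Timestamp.min)` acts as an identity element for the min/max aggregation performed in `__init__`, so empty tables never widen the reported range. A worked example:

import pandas as pd

min_max_time_dict = {
    'orders': (pd.Timestamp('2020-01-01'), pd.Timestamp('2021-06-30')),
    'refunds': (pd.Timestamp.max, pd.Timestamp.min),  # no valid timestamps
}
min(t for t, _ in min_max_time_dict.values())  # -> 2020-01-01
max(t for _, t in min_max_time_dict.values())  # -> 2021-06-30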
@@ -267,8 +215,8 @@ class LocalGraphStore:
             map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
 
         # Build CSC graph representation:
-        row_dict: …
-        colptr_dict: …
+        row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+        colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
         for src_table, fkey, dst_table in graph.edges:
             src_df = self.df_dict[src_table]
             dst_df = self.df_dict[dst_table]
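`map_dict` holds an ordered `pd.CategoricalDtype` per destination table; casting a foreign-key column to that dtype yields dense integer codes in 0..N-1, with -1 for keys that do not resolve. This is standard pandas behavior, shown here with toy data rather than package code:

import pandas as pd

dtype = pd.CategoricalDtype(pd.Series([10, 20, 30]), ordered=True)
fkey = pd.Series([20, 99, 10]).astype(dtype)
fkey.cat.codes.tolist()  # -> [1, -1, 0]; key 99 matches no primary key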
@@ -330,7 +278,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
     return torch.from_numpy(input).argsort().numpy()
 
 
-def _lexsort(inputs: …
+def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
     assert len(inputs) >= 1
 
     if not WITH_TORCH:
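`_lexsort`'s body lies outside this hunk; when torch is unavailable it presumably falls back to numpy, whose `np.lexsort` treats the last key as primary. An illustration of those semantics (not the package's code):

import numpy as np

primary = np.array([1, 0, 1])
secondary = np.array([2, 2, 1])
np.lexsort([secondary, primary])  # -> array([1, 2, 0]): by primary, ties by secondary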
kumoai/experimental/rfm/backend/local/sampler.py

@@ -5,7 +5,7 @@ import pandas as pd
 from kumoapi.pquery import ValidatedPredictiveQuery
 
 from kumoai.experimental.rfm.backend.local import LocalGraphStore
-from kumoai.experimental.rfm.base import Sampler, SamplerOutput…
+from kumoai.experimental.rfm.base import Sampler, SamplerOutput
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 from kumoai.utils import ProgressLogger
 
@@ -19,7 +19,7 @@ class LocalSampler(Sampler):
         graph: 'Graph',
         verbose: bool | ProgressLogger = True,
     ) -> None:
-        super().__init__(graph=graph)
+        super().__init__(graph=graph, verbose=verbose)
 
         import kumoai.kumolib as kumolib
 
@@ -38,19 +38,32 @@ class LocalSampler(Sampler):
             self._graph_store.time_dict,
         )
 
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        return {
+            key: value
+            for key, value in self._graph_store.min_max_time_dict.items()
+            if key in table_names
+        }
+
     def _sample_subgraph(
         self,
         entity_table_name: str,
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
         columns_dict: dict[str, set[str]],
         num_neighbors: list[int],
     ) -> SamplerOutput:
 
-…
-…
-…
-…
+        index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+        if isinstance(anchor_time, pd.Series):
+            time = anchor_time.astype(int).to_numpy() // 1000**3  # to seconds
+        else:
+            assert anchor_time == 'entity'
+            time = self._graph_store.time_dict[entity_table_name][index]
 
         (
             row_dict,
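`anchor_time` may now be the literal string 'entity', in which case each seed row uses its own timestamp from the graph store (the `Literal` name is presumably added to this file's typing imports, outside the displayed hunks). The `pd.Series` branch relies on `datetime64[ns]` values casting to UNIX nanoseconds; a worked check of the conversion:

import pandas as pd

ts = pd.Series(pd.to_datetime(['2021-01-01 00:00:01']))
ts.astype(int).to_numpy() // 1000**3  # -> array([1609459201]), UNIX seconds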
@@ -60,11 +73,14 @@ class LocalSampler(Sampler):
             num_sampled_nodes_dict,
             num_sampled_edges_dict,
         ) = self._graph_sampler.sample(
-…
+            {
+                '__'.join(edge_type): num_neighbors
+                for edge_type in self.edge_types
+            },
             {},
             entity_table_name,
-…
-…
+            index,
+            time,
         )
 
         df_dict: dict[str, pd.DataFrame] = {}
@@ -108,6 +124,7 @@ class LocalSampler(Sampler):
         }
 
         return SamplerOutput(
+            anchor_time=time * 1000**3,  # to nanoseconds
             df_dict=df_dict,
             inverse_dict=inverse_dict,
             batch_dict=batch_dict,
@@ -117,51 +134,80 @@ class LocalSampler(Sampler):
             num_sampled_edges_dict=num_sampled_edges_dict,
         )
 
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        pkey_map = self._graph_store.pkey_map_dict[table_name]
+        if len(pkey_map) > num_rows:
+            pkey_map = pkey_map.sample(
+                n=num_rows,
+                random_state=random_seed,
+                ignore_index=True,
+            )
+        df = self._graph_store.df_dict[table_name]
+        df = df.iloc[pkey_map['arange']][list(columns)]
+        return df
+
     def _sample_target(
         self,
         query: ValidatedPredictiveQuery,
-…
-…
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
         columns_dict: dict[str, set[str]],
         time_offset_dict: dict[
             tuple[str, str, str],
             tuple[pd.DateOffset | None, pd.DateOffset],
         ],
-…
-    ) -> TargetOutput:
-…
-        candidate = pd.Series([0, 1])  # TODO
-        anchor_time = pd.Series(anchor_time).repeat(len(candidate))
-        anchor_time = anchor_time.reset_index(drop=True)
-        if anchor_time.dtype != 'datetime64[ns]':
-            anchor_time = anchor_time.astype('datetime64[ns]')
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
 
-…
+        train_y, train_mask = self._sample_target_set(
             query=query,
-…
-…
+            pkey=entity_df[self.primary_key_dict[query.entity_table]],
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
             columns_dict=columns_dict,
             time_offset_dict=time_offset_dict,
         )
 
-…
-…
-…
-…
-…
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            pkey=entity_df[self.primary_key_dict[query.entity_table]],
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
         )
 
-…
+        return train_y, train_mask, test_y, test_mask
+
+    # Helper Methods ##########################################################
+
+    def _sample_target_set(
         self,
         query: ValidatedPredictiveQuery,
-…
+        pkey: pd.Series,
+        index: np.ndarray,
         anchor_time: pd.Series,
+        num_examples: int,
         columns_dict: dict[str, set[str]],
         time_offset_dict: dict[
             tuple[str, str, str],
             tuple[pd.DateOffset | None, pd.DateOffset],
         ],
+        batch_size: int = 10_000,
     ) -> tuple[pd.Series, np.ndarray]:
+
         num_hops = 1 if len(time_offset_dict) > 0 else 0
         num_neighbors_dict: dict[str, list[int]] = {}
         unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
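The new `_sample_entity_table` downsamples through `pandas.DataFrame.sample`, which draws without replacement; `random_state` makes the draw reproducible and `ignore_index=True` renumbers rows from zero. A minimal sketch of that step with illustrative data:

import pandas as pd

pkey_map = pd.DataFrame({'arange': range(100_000)})
subset = pkey_map.sample(n=10_000, random_state=42, ignore_index=True)
len(subset), subset.index[0]  # -> (10000, 0)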
@@ -173,35 +219,59 @@ class LocalSampler(Sampler):
         for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
             num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
 
-…
-…
-…
-…
-…
-            anchor_time.…
-        )
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            subset = pkey.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
 
-…
-…
-…
-…
-…
-…
-…
-            if time_column := self.time_column_dict.get(table_name):
-                time_dict[table_name] = df[time_column]
+            _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+                num_neighbors_dict,
+                unix_time_offset_dict,
+                query.entity_table,
+                self._graph_store.get_node_id(query.entity_table, subset),
+                time.astype(int).to_numpy() // 1000**3,  # to seconds
+            )
 
-…
-…
-…
-…
-…
-…
-…
-…
+            feat_dict: dict[str, pd.DataFrame] = {}
+            time_dict: dict[str, pd.Series] = {}
+            for table_name, columns in columns_dict.items():
+                df = self._graph_store.df_dict[table_name]
+                df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+                df = df[list(columns)]
+                feat_dict[table_name] = df
+
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns:
+                    time_dict[table_name] = df[time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
 
 
-# Helper …
+# Helper Functions ############################################################
 
 
 def date_offset_to_seconds(offset: pd.DateOffset) -> int:
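The rewritten `_sample_target_set` labels seed entities in chunks of `batch_size` and stops early once `num_examples` labels have been collected, rather than materializing everything at once. The chunking itself is plain slice arithmetic:

import numpy as np

index = np.arange(25_000)
batch_size = 10_000
[len(index[s:s + batch_size]) for s in range(0, len(index), batch_size)]
# -> [10000, 10000, 5000]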
kumoai/experimental/rfm/backend/local/table.py

@@ -1,10 +1,15 @@
-import …
-from typing import List, Optional
+from typing import Sequence, cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 
-from kumoai.experimental.rfm.base import …
-…
+from kumoai.experimental.rfm.base import (
+    ColumnSpec,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+    Table,
+)
 
 
 class LocalTable(Table):
@@ -52,9 +57,9 @@ class LocalTable(Table):
         self,
         df: pd.DataFrame,
         name: str,
-        primary_key: …
-        time_column: …
-        end_time_column: …
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if df.empty:
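`primary_key` now defaults to the `MissingType.VALUE` sentinel instead of `None`, letting the constructor tell "argument omitted" apart from an explicit `primary_key=None`; what each case triggers downstream (e.g. key inference when omitted) is an assumption here, not visible in the diff:

import pandas as pd
from kumoai.experimental.rfm.backend.local import LocalTable

df = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
LocalTable(df, name='users')                    # omitted: sentinel default
LocalTable(df, name='users', primary_key=None)  # explicitly no primary key
LocalTable(df, name='users', primary_key='id')  # explicit key column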
@@ -70,40 +75,39 @@ class LocalTable(Table):
 
         super().__init__(
             name=name,
-            columns=list(df.columns),
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
         )
 
-…
-…
-…
-…
-…
-…
-…
-…
-…
-                    f"the data type of the column to use it within "
-                    f"this table.")
-                continue
-…
-            source_column = SourceColumn(
-                name=column,
-                dtype=dtype,
+    @property
+    def backend(self) -> DataBackend:
+        return cast(DataBackend, DataBackend.LOCAL)
+
+    def _get_source_columns(self) -> list[SourceColumn]:
+        return [
+            SourceColumn(
+                name=column_name,
+                dtype=None,
                 is_primary_key=False,
                 is_unique_key=False,
-…
-…
-…
-        return source_columns
+                is_nullable=True,
+            ) for column_name in self._data.columns
+        ]
 
-    def _get_source_foreign_keys(self) -> …
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
         return []
 
-    def …
+    def _get_source_sample_df(self) -> pd.DataFrame:
         return self._data
 
-    def …
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        raise RuntimeError(f"Column expressions are not supported in "
+                           f"'{self.__class__.__name__}'. Please apply your "
+                           f"expressions on the `pd.DataFrame` directly.")
+
+    def _get_num_rows(self) -> int | None:
         return len(self._data)
|