kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +44 -9
- kumoai/experimental/rfm/__init__.py +70 -68
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
- kumoai/experimental/rfm/backend/snow/table.py +245 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +783 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +722 -0
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +84 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +63 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +783 -481
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +55 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
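The listing above shows the experimental RFM package being reorganized: the single-file `local_*` modules are removed in favor of `backend/` (local, sqlite, snow) and `base/` subpackages. As a rough orientation only, a sketch of the new import layout, assembled from the import statements visible in the diffs below rather than from a verified public API:

    # Illustrative only; paths taken from imports in the 2.15.0.dev code shown below.
    from kumoai.experimental.rfm import Graph                   # graph.py (was local_graph.py)
    from kumoai.experimental.rfm.backend.local import LocalGraphStore, LocalTable
    from kumoai.experimental.rfm.base import Sampler, SamplerOutput, Table
    from kumoai.utils import ProgressLogger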
kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py}

@@ -1,14 +1,12 @@
-import
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING

 import numpy as np
 import pandas as pd
 from kumoapi.rfm.context import Subgraph
-from kumoapi.typing import Stype

-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.utils import
+from kumoai.experimental.rfm.backend.local import LocalTable
+from kumoai.experimental.rfm.base import Table
+from kumoai.utils import ProgressLogger

 try:
     import torch
@@ -16,43 +14,40 @@ try:
 except ImportError:
     WITH_TORCH = False

+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+

 class LocalGraphStore:
     def __init__(
         self,
-        graph:
-
-        verbose: Union[bool, ProgressLogger] = True,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
     ) -> None:

         if not isinstance(verbose, ProgressLogger):
-            verbose =
-                "Materializing graph",
+            verbose = ProgressLogger.default(
+                msg="Materializing graph",
                 verbose=verbose,
             )

         with verbose as logger:
-            self.df_dict, self.mask_dict = self.sanitize(graph
-            self.stype_dict = self.get_stype_dict(graph)
+            self.df_dict, self.mask_dict = self.sanitize(graph)
             logger.log("Sanitized input data")

-            self.
+            self.pkey_map_dict = self.get_pkey_map_dict(graph)
             num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
             if num_pkeys > 1:
                 logger.log(f"Collected primary keys from {num_pkeys} tables")
             else:
                 logger.log(f"Collected primary key from {num_pkeys} table")

-            (
-
-                self.
-                self.
-                self.min_time,
-                self.max_time,
-            ) = self.get_time_data(graph)
-            if self.max_time != pd.Timestamp.min:
+            self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+            if len(self.min_max_time_dict) > 0:
+                min_time = min(t for t, _ in self.min_max_time_dict.values())
+                max_time = max(t for _, t in self.min_max_time_dict.values())
                 logger.log(f"Identified temporal graph from "
-                           f"{
+                           f"{min_time.date()} to {max_time.date()}")
             else:
                 logger.log("Identified static graph without timestamps")

@@ -62,14 +57,6 @@ class LocalGraphStore:
             logger.log(f"Created graph with {num_nodes:,} nodes and "
                        f"{num_edges:,} edges")

-    @property
-    def node_types(self) -> List[str]:
-        return list(self.df_dict.keys())
-
-    @property
-    def edge_types(self) -> List[Tuple[str, str, str]]:
-        return list(self.row_dict.keys())
-
     def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
         r"""Returns the node ID given primary keys.

@@ -105,9 +92,8 @@ class LocalGraphStore:

     def sanitize(
         self,
-        graph:
-
-    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
+        graph: 'Graph',
+    ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
         r"""Sanitizes raw data according to table schema definition:

         In particular, it:
@@ -115,42 +101,25 @@
         * drops timezone information from timestamps
         * drops duplicate primary keys
         * removes rows with missing primary keys or time values
-
-        If ``preprocess`` is set to ``True``, it will additionally pre-process
-        data for faster model processing. In particular, it:
-        * tokenizes any text column that is not a foreign key
         """
-        df_dict:
-
-
-
-
-
+        df_dict: dict[str, pd.DataFrame] = {}
+        for table_name, table in graph.tables.items():
+            assert isinstance(table, LocalTable)
+            df_dict[table_name] = Table._sanitize(
+                df=table._data.copy(deep=False).reset_index(drop=True),
+                dtype_dict={
+                    column.name: column.dtype
+                    for column in table.columns
+                },
+                stype_dict={
+                    column.name: column.stype
+                    for column in table.columns
+                },
+            )

-        mask_dict:
+        mask_dict: dict[str, np.ndarray] = {}
         for table in graph.tables.values():
-
-                if col.stype == Stype.timestamp:
-                    ser = df_dict[table.name][col.name]
-                    if not pd.api.types.is_datetime64_any_dtype(ser):
-                        with warnings.catch_warnings():
-                            warnings.filterwarnings(
-                                'ignore',
-                                message='Could not infer format',
-                            )
-                            ser = pd.to_datetime(ser, errors='coerce')
-                        df_dict[table.name][col.name] = ser
-                    if isinstance(ser.dtype, pd.DatetimeTZDtype):
-                        ser = ser.dt.tz_localize(None)
-                        df_dict[table.name][col.name] = ser
-
-                # Normalize text in advance (but exclude foreign keys):
-                if (preprocess and col.stype == Stype.text
-                        and (table.name, col.name) not in foreign_keys):
-                    ser = df_dict[table.name][col.name]
-                    df_dict[table.name][col.name] = normalize_text(ser)
-
-            mask: Optional[np.ndarray] = None
+            mask: np.ndarray | None = None
             if table._time_column is not None:
                 ser = df_dict[table.name][table._time_column]
                 mask = ser.notna().to_numpy()
@@ -165,34 +134,16 @@

         return df_dict, mask_dict

-    def
-        stype_dict: Dict[str, Dict[str, Stype]] = {}
-        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
-        for table in graph.tables.values():
-            stype_dict[table.name] = {}
-            for column in table.columns:
-                if column == table.primary_key:
-                    continue
-                if (table.name, column.name) in foreign_keys:
-                    continue
-                stype_dict[table.name][column.name] = column.stype
-        return stype_dict
-
-    def get_pkey_data(
+    def get_pkey_map_dict(
         self,
-        graph:
-    ) ->
-
-        Dict[str, pd.DataFrame],
-    ]:
-        pkey_name_dict: Dict[str, str] = {}
-        pkey_map_dict: Dict[str, pd.DataFrame] = {}
+        graph: 'Graph',
+    ) -> dict[str, pd.DataFrame]:
+        pkey_map_dict: dict[str, pd.DataFrame] = {}

         for table in graph.tables.values():
             if table._primary_key is None:
                 continue

-            pkey_name_dict[table.name] = table._primary_key
             pkey = self.df_dict[table.name][table._primary_key]
             pkey_map = pd.DataFrame(
                 dict(arange=range(len(pkey))),
@@ -214,61 +165,48 @@

             pkey_map_dict[table.name] = pkey_map

-        return
+        return pkey_map_dict

     def get_time_data(
         self,
-        graph:
-    ) ->
-
-
-        Dict[str, np.ndarray],
-        pd.Timestamp,
-        pd.Timestamp,
+        graph: 'Graph',
+    ) -> tuple[
+        dict[str, np.ndarray],
+        dict[str, tuple[pd.Timestamp, pd.Timestamp]],
     ]:
-
-
-        time_dict: Dict[str, np.ndarray] = {}
-        min_time = pd.Timestamp.max
-        max_time = pd.Timestamp.min
+        time_dict: dict[str, np.ndarray] = {}
+        min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         for table in graph.tables.values():
-            if table._end_time_column is not None:
-                end_time_column_dict[table.name] = table._end_time_column
-
             if table._time_column is None:
                 continue

             time = self.df_dict[table.name][table._time_column]
-            time_dict[table.name] = time.astype(
-                int).to_numpy() // 1000**3
-            time_column_dict[table.name] = table._time_column
+            time_dict[table.name] = time.astype(int).to_numpy() // 1000**3

             if table.name in self.mask_dict.keys():
                 time = time[self.mask_dict[table.name]]
             if len(time) > 0:
-
-
+                min_max_time_dict[table.name] = (time.min(), time.max())
+            else:
+                min_max_time_dict[table.name] = (
+                    pd.Timestamp.max,
+                    pd.Timestamp.min,
+                )

-        return
-            time_column_dict,
-            end_time_column_dict,
-            time_dict,
-            min_time,
-            max_time,
-        )
+        return time_dict, min_max_time_dict

     def get_csc(
         self,
-        graph:
-    ) ->
-
-
+        graph: 'Graph',
+    ) -> tuple[
+        dict[tuple[str, str, str], np.ndarray],
+        dict[tuple[str, str, str], np.ndarray],
     ]:
         # A mapping from raw primary keys to node indices (0 to N-1):
-        map_dict:
+        map_dict: dict[str, pd.CategoricalDtype] = {}
         # A dictionary to manage offsets of node indices for invalid rows:
-        offset_dict:
-        for table_name in
+        offset_dict: dict[str, np.ndarray] = {}
+        for table_name in {edge.dst_table for edge in graph.edges}:
             ser = self.df_dict[table_name][graph[table_name]._primary_key]
             if table_name in self.mask_dict.keys():
                 mask = self.mask_dict[table_name]
@@ -277,8 +215,8 @@
             map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)

         # Build CSC graph representation:
-        row_dict:
-        colptr_dict:
+        row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+        colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
         for src_table, fkey, dst_table in graph.edges:
             src_df = self.df_dict[src_table]
             dst_df = self.df_dict[dst_table]
@@ -340,7 +278,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
    return torch.from_numpy(input).argsort().numpy()


-def _lexsort(inputs:
+def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
    assert len(inputs) >= 1

    if not WITH_TORCH:
kumoai/experimental/rfm/backend/local/sampler.py (new)

@@ -0,0 +1,312 @@
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+
+from kumoai.experimental.rfm.backend.local import LocalGraphStore
+from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+class LocalSampler(Sampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        import kumoai.kumolib as kumolib
+
+        self._graph_store = LocalGraphStore(graph, verbose)
+        self._graph_sampler = kumolib.NeighborSampler(
+            list(self.table_stype_dict.keys()),
+            self.edge_types,
+            {
+                '__'.join(edge_type): colptr
+                for edge_type, colptr in self._graph_store.colptr_dict.items()
+            },
+            {
+                '__'.join(edge_type): row
+                for edge_type, row in self._graph_store.row_dict.items()
+            },
+            self._graph_store.time_dict,
+        )
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        return {
+            key: value
+            for key, value in self._graph_store.min_max_time_dict.items()
+            if key in table_names
+        }
+
+    def _sample_subgraph(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
+    ) -> SamplerOutput:
+
+        index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+        if isinstance(anchor_time, pd.Series):
+            time = anchor_time.astype(int).to_numpy() // 1000**3  # to seconds
+        else:
+            assert anchor_time == 'entity'
+            time = self._graph_store.time_dict[entity_table_name][index]
+
+        (
+            row_dict,
+            col_dict,
+            node_dict,
+            batch_dict,
+            num_sampled_nodes_dict,
+            num_sampled_edges_dict,
+        ) = self._graph_sampler.sample(
+            {
+                '__'.join(edge_type): num_neighbors
+                for edge_type in self.edge_types
+            },
+            {},
+            entity_table_name,
+            index,
+            time,
+        )
+
+        df_dict: dict[str, pd.DataFrame] = {}
+        inverse_dict: dict[str, np.ndarray] = {}
+        for table_name, node in node_dict.items():
+            df = self._graph_store.df_dict[table_name]
+            columns = columns_dict[table_name]
+            if self.end_time_column_dict.get(table_name, None) in columns:
+                df = df.iloc[node]
+            elif len(columns) == 0:
+                df = df.iloc[node]
+            else:
+                # Only store unique rows in `df` above a certain threshold:
+                unique_node, inverse = np.unique(node, return_inverse=True)
+                if len(node) > 1.05 * len(unique_node):
+                    df = df.iloc[unique_node]
+                    inverse_dict[table_name] = inverse
+                else:
+                    df = df.iloc[node]
+            df = df.reset_index(drop=True)
+            df = df[list(columns)]
+            df_dict[table_name] = df
+
+        num_sampled_nodes_dict = {
+            table_name: num_sampled_nodes.tolist()
+            for table_name, num_sampled_nodes in
+            num_sampled_nodes_dict.items()
+        }
+
+        row_dict = {
+            edge_type: row_dict['__'.join(edge_type)]
+            for edge_type in self.edge_types
+        }
+        col_dict = {
+            edge_type: col_dict['__'.join(edge_type)]
+            for edge_type in self.edge_types
+        }
+        num_sampled_edges_dict = {
+            edge_type: num_sampled_edges_dict['__'.join(edge_type)].tolist()
+            for edge_type in self.edge_types
+        }
+
+        return SamplerOutput(
+            anchor_time=time * 1000**3,  # to nanoseconds
+            df_dict=df_dict,
+            inverse_dict=inverse_dict,
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
+        )
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        pkey_map = self._graph_store.pkey_map_dict[table_name]
+        if len(pkey_map) > num_rows:
+            pkey_map = pkey_map.sample(
+                n=num_rows,
+                random_state=random_seed,
+                ignore_index=True,
+            )
+        df = self._graph_store.df_dict[table_name]
+        df = df.iloc[pkey_map['arange']][list(columns)]
+        return df
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+        train_y, train_mask = self._sample_target_set(
+            query=query,
+            pkey=entity_df[self.primary_key_dict[query.entity_table]],
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            pkey=entity_df[self.primary_key_dict[query.entity_table]],
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        return train_y, train_mask, test_y, test_mask
+
+    # Helper Methods ##########################################################
+
+    def _sample_target_set(
+        self,
+        query: ValidatedPredictiveQuery,
+        pkey: pd.Series,
+        index: np.ndarray,
+        anchor_time: pd.Series,
+        num_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        batch_size: int = 10_000,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        num_hops = 1 if len(time_offset_dict) > 0 else 0
+        num_neighbors_dict: dict[str, list[int]] = {}
+        unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
+        for edge_type, (start, end) in time_offset_dict.items():
+            unix_time_offset_dict['__'.join(edge_type)] = [[
+                date_offset_to_seconds(start) if start is not None else None,
+                date_offset_to_seconds(end),
+            ]]
+        for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
+            num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
+
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            subset = pkey.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
+
+            _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+                num_neighbors_dict,
+                unix_time_offset_dict,
+                query.entity_table,
+                self._graph_store.get_node_id(query.entity_table, subset),
+                time.astype(int).to_numpy() // 1000**3,  # to seconds
+            )
+
+            feat_dict: dict[str, pd.DataFrame] = {}
+            time_dict: dict[str, pd.Series] = {}
+            for table_name, columns in columns_dict.items():
+                df = self._graph_store.df_dict[table_name]
+                df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+                df = df[list(columns)]
+                feat_dict[table_name] = df
+
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns:
+                    time_dict[table_name] = df[time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
+
+
+# Helper Functions ############################################################
+
+
+def date_offset_to_seconds(offset: pd.DateOffset) -> int:
+    r"""Convert a :class:`pandas.DateOffset` into a number of seconds.
+
+    .. note::
+        We are conservative and take months and years as their maximum value.
+        Additional values are then dropped in label computation where we know
+        the actual dates.
+    """
+    MAX_DAYS_IN_MONTH = 31
+    MAX_DAYS_IN_YEAR = 366
+
+    SECONDS_IN_MINUTE = 60
+    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
+    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
+
+    total_sec = 0
+    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
+
+    for attr, value in offset.__dict__.items():
+        if value is None or value == 0:
+            continue
+        scaled_value = value * multiplier
+        if attr == 'years':
+            total_sec += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
+        elif attr == 'months':
+            total_sec += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
+        elif attr == 'days':
+            total_sec += scaled_value * SECONDS_IN_DAY
+        elif attr == 'hours':
+            total_sec += scaled_value * SECONDS_IN_HOUR
+        elif attr == 'minutes':
+            total_sec += scaled_value * SECONDS_IN_MINUTE
+        elif attr == 'seconds':
+            total_sec += scaled_value
+
+    return total_sec
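For reference, a minimal sketch of how the conservative conversion in `date_offset_to_seconds` above plays out (illustrative only; it assumes the helper is importable from the new `backend/local/sampler.py` module and that the `DateOffset` keyword arguments are exposed via `offset.__dict__` as the loop expects):

    import pandas as pd

    # Hypothetical import path for illustration, derived from the file location above.
    from kumoai.experimental.rfm.backend.local.sampler import date_offset_to_seconds

    # Months are upper-bounded at 31 days: 31 * 86_400 = 2_678_400 seconds.
    print(date_offset_to_seconds(pd.DateOffset(months=1)))  # 2678400
    # Plain day offsets convert exactly: 2 * 86_400 = 172_800 seconds.
    print(date_offset_to_seconds(pd.DateOffset(days=2)))    # 172800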