kumoai-2.12.0.dev202510231830-cp311-cp311-win_amd64.whl → kumoai-2.14.0.dev202512311733-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +41 -35
- kumoai/_version.py +1 -1
- kumoai/client/client.py +15 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/client/rfm.py +35 -7
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +191 -48
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +735 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +64 -40
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +386 -276
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/kumolib.cp311-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +188 -16
- kumoai/utils/sql.py +3 -0
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +57 -36
- kumoai/experimental/rfm/local_graph.py +0 -810
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -494
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
--- kumoai/experimental/rfm/local_graph_sampler.py
+++ /dev/null
@@ -1,184 +0,0 @@
-from typing import Dict, List, Optional, Tuple
-
-import numpy as np
-import pandas as pd
-from kumoapi.model_plan import RunMode
-from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
-from kumoapi.typing import Stype
-
-import kumoai.kumolib as kumolib
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.utils import normalize_text
-
-
-class LocalGraphSampler:
-    def __init__(self, graph_store: LocalGraphStore) -> None:
-        self._graph_store = graph_store
-        self._sampler = kumolib.NeighborSampler(
-            self._graph_store.node_types,
-            self._graph_store.edge_types,
-            {
-                '__'.join(edge_type): colptr
-                for edge_type, colptr in self._graph_store.colptr_dict.items()
-            },
-            {
-                '__'.join(edge_type): row
-                for edge_type, row in self._graph_store.row_dict.items()
-            },
-            self._graph_store.time_dict,
-        )
-
-    def __call__(
-        self,
-        entity_table_names: Tuple[str, ...],
-        node: np.ndarray,
-        time: np.ndarray,
-        run_mode: RunMode,
-        num_neighbors: List[int],
-        exclude_cols_dict: Dict[str, List[str]],
-    ) -> Subgraph:
-
-        (
-            row_dict,
-            col_dict,
-            node_dict,
-            batch_dict,
-            num_sampled_nodes_dict,
-            num_sampled_edges_dict,
-        ) = self._sampler.sample(
-            {
-                '__'.join(edge_type): num_neighbors
-                for edge_type in self._graph_store.edge_types
-            },
-            {},  # time interval based sampling
-            entity_table_names[0],
-            node,
-            time // 1000**3,  # nanoseconds to seconds
-        )
-
-        table_dict: Dict[str, Table] = {}
-        for table_name, node in node_dict.items():
-            batch = batch_dict[table_name]
-
-            if len(node) == 0:
-                continue
-
-            df = self._graph_store.df_dict[table_name]
-
-            num_sampled_nodes = num_sampled_nodes_dict[table_name].tolist()
-            stype_dict = {  # Exclude target columns:
-                column_name: stype
-                for column_name, stype in
-                self._graph_store.stype_dict[table_name].items()
-                if column_name not in exclude_cols_dict.get(table_name, [])
-            }
-            primary_key: Optional[str] = None
-            if table_name in entity_table_names:
-                primary_key = self._graph_store.pkey_name_dict.get(table_name)
-
-            columns: List[str] = []
-            if table_name in entity_table_names:
-                columns += [self._graph_store.pkey_name_dict[table_name]]
-            columns += list(stype_dict.keys())
-
-            if len(columns) == 0:
-                table_dict[table_name] = Table(
-                    df=pd.DataFrame(index=range(len(node))),
-                    row=None,
-                    batch=batch,
-                    num_sampled_nodes=num_sampled_nodes,
-                    stype_dict=stype_dict,
-                    primary_key=primary_key,
-                )
-                continue
-
-            row: Optional[np.ndarray] = None
-            if table_name in self._graph_store.end_time_column_dict:
-                # Set end time to NaT for all values greater than anchor time:
-                df = df.iloc[node].reset_index(drop=True)
-                col_name = self._graph_store.end_time_column_dict[table_name]
-                ser = df[col_name]
-                value = ser.astype('datetime64[ns]').astype(int).to_numpy()
-                mask = value > time[batch]
-                df.loc[mask, col_name] = pd.NaT
-            else:
-                # Only store unique rows in `df` above a certain threshold:
-                unique_node, inverse = np.unique(node, return_inverse=True)
-                if len(node) > 1.05 * len(unique_node):
-                    df = df.iloc[unique_node].reset_index(drop=True)
-                    row = inverse
-                else:
-                    df = df.iloc[node].reset_index(drop=True)
-
-            # Filter data frame to minimal set of columns:
-            df = df[columns]
-
-            # Normalize text (if not already pre-processed):
-            for column_name, stype in stype_dict.items():
-                if stype == Stype.text:
-                    df[column_name] = normalize_text(df[column_name])
-
-            table_dict[table_name] = Table(
-                df=df,
-                row=row,
-                batch=batch,
-                num_sampled_nodes=num_sampled_nodes,
-                stype_dict=stype_dict,
-                primary_key=primary_key,
-            )
-
-        link_dict: Dict[Tuple[str, str, str], Link] = {}
-        for edge_type in self._graph_store.edge_types:
-            edge_type_str = '__'.join(edge_type)
-
-            row = row_dict[edge_type_str]
-            col = col_dict[edge_type_str]
-
-            if len(row) == 0:
-                continue
-
-            # Do not store reverse edge type if it is a replica:
-            rev_edge_type = Subgraph.rev_edge_type(edge_type)
-            rev_edge_type_str = '__'.join(rev_edge_type)
-            if (rev_edge_type in link_dict
-                    and np.array_equal(row, col_dict[rev_edge_type_str])
-                    and np.array_equal(col, row_dict[rev_edge_type_str])):
-                link = Link(
-                    layout=EdgeLayout.REV,
-                    row=None,
-                    col=None,
-                    num_sampled_edges=(
-                        num_sampled_edges_dict[edge_type_str].tolist()),
-                )
-                link_dict[edge_type] = link
-                continue
-
-            layout = EdgeLayout.COO
-            if np.array_equal(row, np.arange(len(row))):
-                row = None
-            if np.array_equal(col, np.arange(len(col))):
-                col = None
-
-            # Store in compressed representation if more efficient:
-            num_cols = table_dict[edge_type[2]].num_rows
-            if col is not None and len(col) > num_cols + 1:
-                layout = EdgeLayout.CSC
-                colcount = np.bincount(col, minlength=num_cols)
-                col = np.empty(num_cols + 1, dtype=col.dtype)
-                col[0] = 0
-                np.cumsum(colcount, out=col[1:])
-
-            link = Link(
-                layout=layout,
-                row=row,
-                col=col,
-                num_sampled_edges=(
-                    num_sampled_edges_dict[edge_type_str].tolist()),
-            )
-            link_dict[edge_type] = link
-
-        return Subgraph(
-            anchor_time=time,
-            table_dict=table_dict,
-            link_dict=link_dict,
-        )

--- kumoai/experimental/rfm/local_pquery_driver.py
+++ /dev/null
@@ -1,494 +0,0 @@
-import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
-
-import numpy as np
-import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.rfm import PQueryDefinition
-
-import kumoai.kumolib as kumolib
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import PQueryPandasBackend
-
-_coverage_warned = False
-
-
-class LocalPQueryDriver:
-    def __init__(
-        self,
-        graph_store: LocalGraphStore,
-        query: PQueryDefinition,
-        random_seed: Optional[int] = None,
-    ) -> None:
-        self._graph_store = graph_store
-        self._query = query
-        self._random_seed = random_seed
-        self._rng = np.random.default_rng(random_seed)
-
-    def _get_candidates(
-        self,
-        exclude_node: Optional[np.ndarray] = None,
-    ) -> np.ndarray:
-
-        if self._query.query_type == QueryType.TEMPORAL:
-            assert exclude_node is None
-
-        table_name = self._query.entity.pkey.table_name
-        num_nodes = len(self._graph_store.df_dict[table_name])
-        mask_dict = self._graph_store.mask_dict
-
-        candidate: np.ndarray
-
-        # Case 1: All nodes are valid and nothing to exclude:
-        if exclude_node is None and table_name not in mask_dict:
-            candidate = np.arange(num_nodes)
-
-        # Case 2: Not all nodes are valid - lookup valid nodes:
-        if exclude_node is None:
-            pkey_map = self._graph_store.pkey_map_dict[table_name]
-            candidate = pkey_map['arange'].to_numpy().copy()
-
-        # Case 3: Exclude nodes - use a mask to exclude them:
-        else:
-            mask = np.full((num_nodes, ), fill_value=True, dtype=bool)
-            mask[exclude_node] = False
-            if table_name in mask_dict:
-                mask &= mask_dict[table_name]
-            candidate = mask.nonzero()[0]
-
-        self._rng.shuffle(candidate)
-
-        return candidate
-
-    def _filter_candidates_by_time(
-        self,
-        candidate: np.ndarray,
-        anchor_time: pd.Timestamp,
-    ) -> np.ndarray:
-
-        entity = self._query.entity.pkey.table_name
-
-        # Filter out entities that do not exist yet in time:
-        time_sec = self._graph_store.time_dict.get(entity)
-        if time_sec is not None:
-            mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
-            candidate = candidate[mask]
-
-        # Filter out entities that no longer exist in time:
-        end_time_col = self._graph_store.end_time_column_dict.get(entity)
-        if end_time_col is not None:
-            ser = self._graph_store.df_dict[entity][end_time_col]
-            ser = ser.iloc[candidate]
-            mask = (anchor_time < ser) | ser.isna().to_numpy()
-            candidate = candidate[mask]
-
-        return candidate
-
-    def collect_test(
-        self,
-        size: int,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        batch_size: Optional[int] = None,
-        max_iterations: int = 20,
-        guarantee_train_examples: bool = True,
-    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-        r"""Collects test nodes and their labels used for evaluation.
-
-        Args:
-            size: The number of test nodes to collect.
-            anchor_time: The anchor time.
-            batch_size: How many nodes to process in a single batch.
-            max_iterations: The number of steps to run before aborting.
-            guarantee_train_examples: Ensures that test examples do not occupy
-                the entire set of entity candidates.
-
-        Returns:
-            A triplet holding the nodes, timestamps and labels.
-        """
-        batch_size = size if batch_size is None else batch_size
-
-        candidate = self._get_candidates()
-
-        nodes: List[np.ndarray] = []
-        times: List[pd.Series] = []
-        ys: List[pd.Series] = []
-
-        reached_end = False
-        num_labels = candidate_offset = 0
-        for _ in range(max_iterations):
-            node = candidate[candidate_offset:candidate_offset + batch_size]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                node = self._filter_candidates_by_time(node, anchor_time)
-                time = pd.Series(anchor_time).repeat(len(node))
-                time = time.astype('datetime64[ns]').reset_index(drop=True)
-            else:
-                assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
-                time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-            y, mask = self(node, time)
-
-            nodes.append(node[mask])
-            times.append(time[mask].reset_index(drop=True))
-            ys.append(y)
-
-            num_labels += len(y)
-
-            if num_labels > size:
-                reached_end = True
-                break  # Sufficient number of labels collected. Abort.
-
-            candidate_offset += batch_size
-            if candidate_offset >= len(candidate):
-                reached_end = True
-                break
-
-        if len(nodes) > 1:
-            node = np.concatenate(nodes, axis=0)[:size]
-            time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-            y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-        else:
-            node = nodes[0][:size]
-            time = times[0].iloc[:size]
-            y = ys[0].iloc[:size]
-
-        if len(node) == 0:
-            raise RuntimeError("Failed to collect any test examples for "
-                               "evaluation. Is your predictive query too "
-                               "restrictive?")
-
-        global _coverage_warned
-        if not _coverage_warned and not reached_end and len(node) < size // 2:
-            _coverage_warned = True
-            warnings.warn(f"Failed to collect {size:,} test examples within "
-                          f"{max_iterations} iterations. To improve coverage, "
-                          f"consider increasing the number of PQ iterations "
-                          f"using the 'max_pq_iterations' option. This "
-                          f"warning will not be shown again in this run.")
-
-        if (guarantee_train_examples
-                and self._query.query_type == QueryType.STATIC
-                and candidate_offset >= len(candidate)):
-            # In case all valid entities are used as test examples, we can no
-            # longer find any training example. Fallback to a 50/50 split:
-            size = len(node) // 2
-            node = node[:size]
-            time = time.iloc[:size]
-            y = y.iloc[:size]
-
-        return node, time, y
-
-    def collect_train(
-        self,
-        size: int,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        exclude_node: Optional[np.ndarray] = None,
-        batch_size: Optional[int] = None,
-        max_iterations: int = 20,
-    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
-        r"""Collects training nodes and their labels.
-
-        Args:
-            size: The number of test nodes to collect.
-            anchor_time: The anchor time.
-            exclude_node: The nodes to exclude for use as in-context examples.
-            batch_size: How many nodes to process in a single batch.
-            max_iterations: The number of steps to run before aborting.
-
-        Returns:
-            A triplet holding the nodes, timestamps and labels.
-        """
-        batch_size = size if batch_size is None else batch_size
-
-        candidate = self._get_candidates(exclude_node)
-
-        if len(candidate) == 0:
-            raise RuntimeError("Failed to generate any context examples "
-                               "since not enough entities exist")
-
-        nodes: List[np.ndarray] = []
-        times: List[pd.Series] = []
-        ys: List[pd.Series] = []
-
-        reached_end = False
-        num_labels = candidate_offset = 0
-        for _ in range(max_iterations):
-            node = candidate[candidate_offset:candidate_offset + batch_size]
-
-            if isinstance(anchor_time, pd.Timestamp):
-                node = self._filter_candidates_by_time(node, anchor_time)
-                time = pd.Series(anchor_time).repeat(len(node))
-                time = time.astype('datetime64[ns]').reset_index(drop=True)
-            else:
-                assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
-                time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-            y, mask = self(node, time)
-
-            nodes.append(node[mask])
-            times.append(time[mask].reset_index(drop=True))
-            ys.append(y)
-
-            num_labels += len(y)
-
-            if num_labels > size:
-                reached_end = True
-                break  # Sufficient number of labels collected. Abort.
-
-            candidate_offset += batch_size
-            if candidate_offset >= len(candidate):
-                # Restart with an earlier anchor time (if applicable).
-                if self._query.query_type == QueryType.STATIC:
-                    reached_end = True
-                    break  # Cannot jump back in time for static PQs. Abort.
-                if anchor_time == 'entity':
-                    reached_end = True
-                    break
-                candidate_offset = 0
-                anchor_time = anchor_time - (self._query.target.end_offset *
-                                             self._query.num_forecasts)
-                if anchor_time < self._graph_store.min_time:
-                    reached_end = True
-                    break  # No earlier anchor time left. Abort.
-
-        if len(nodes) > 1:
-            node = np.concatenate(nodes, axis=0)[:size]
-            time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
-            y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
-        else:
-            node = nodes[0][:size]
-            time = times[0].iloc[:size]
-            y = ys[0].iloc[:size]
-
-        if len(node) == 0:
-            raise ValueError("Failed to collect any context examples. Is your "
-                             "predictive query too restrictive?")
-
-        global _coverage_warned
-        if not _coverage_warned and not reached_end and len(node) < size // 2:
-            _coverage_warned = True
-            warnings.warn(f"Failed to collect {size:,} context examples "
-                          f"within {max_iterations} iterations. To improve "
-                          f"coverage, consider increasing the number of PQ "
-                          f"iterations using the 'max_pq_iterations' option. "
-                          f"This warning will not be shown again in this run.")
-
-        return node, time, y
-
-    def is_valid(
-        self,
-        node: np.ndarray,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
-        batch_size: int = 10_000,
-    ) -> np.ndarray:
-        r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
-        which nodes fulfill entity filter constraints.
-
-        Args:
-            node: The nodes to check for.
-            anchor_time: The anchor time.
-            batch_size: How many nodes to process in a single batch.
-
-        Returns:
-            The mask.
-        """
-        mask: Optional[np.ndarray] = None
-
-        if isinstance(anchor_time, pd.Timestamp):
-            node = self._filter_candidates_by_time(node, anchor_time)
-            time = pd.Series(anchor_time).repeat(len(node))
-            time = time.astype('datetime64[ns]').reset_index(drop=True)
-        else:
-            assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
-            time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
-
-        if self._query.entity.filter is not None:
-            # Mask out via (temporal) entity filter:
-            backend = PQueryPandasBackend()
-            masks: List[np.ndarray] = []
-            for start in range(0, len(node), batch_size):
-                feat_dict, time_dict, batch_dict = self._sample(
-                    node[start:start + batch_size],
-                    time.iloc[start:start + batch_size],
-                )
-                _mask = backend.eval_filter(
-                    filter=self._query.entity.filter,
-                    feat_dict=feat_dict,
-                    time_dict=time_dict,
-                    batch_dict=batch_dict,
-                    anchor_time=time.iloc[start:start + batch_size],
-                )
-                masks.append(_mask)
-
-            _mask = np.concatenate(masks)
-            mask = (mask & _mask) if mask is not None else _mask
-
-        if mask is None:
-            mask = np.ones(len(node), dtype=bool)
-
-        return mask
-
-    def _sample(
-        self,
-        node: np.ndarray,
-        anchor_time: pd.Series,
-    ) -> Tuple[
-        Dict[str, pd.DataFrame],
-        Dict[str, pd.Series],
-        Dict[str, np.ndarray],
-    ]:
-        r"""Samples a subgraph that contains all relevant information to
-        evaluate the predictive query.
-
-        Args:
-            node: The nodes to check for.
-            anchor_time: The anchor time.
-
-        Returns:
-            The feature dictionary, the time column dictionary and the batch
-            dictionary.
-        """
-        specs = self._query.get_sampling_specs(self._graph_store.edge_types)
-        num_hops = max([spec.hop for spec in specs] + [0])
-        num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
-        time_offsets: Dict[
-            Tuple[str, str, str],
-            List[List[Optional[int]]],
-        ] = {}
-        for spec in specs:
-            if spec.end_offset is not None:
-                if spec.edge_type not in time_offsets:
-                    time_offsets[spec.edge_type] = [[0, 0]
-                                                    for _ in range(num_hops)]
-                offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
-                time_offsets[spec.edge_type][spec.hop - 1][1] = offset
-                if spec.start_offset is not None:
-                    offset = date_offset_to_seconds(spec.start_offset)
-                else:
-                    offset = None
-                time_offsets[spec.edge_type][spec.hop - 1][0] = offset
-            else:
-                if spec.edge_type not in num_neighbors:
-                    num_neighbors[spec.edge_type] = [0] * num_hops
-                num_neighbors[spec.edge_type][spec.hop - 1] = -1
-
-        edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
-        node_types = list(
-            set([self._query.entity.pkey.table_name])
-            | set(src for src, _, _ in edge_types)
-            | set(dst for _, _, dst in edge_types))
-
-        sampler = kumolib.NeighborSampler(
-            node_types,
-            edge_types,
-            {
-                '__'.join(edge_type): self._graph_store.colptr_dict[edge_type]
-                for edge_type in edge_types
-            },
-            {
-                '__'.join(edge_type): self._graph_store.row_dict[edge_type]
-                for edge_type in edge_types
-            },
-            {
-                node_type: time
-                for node_type, time in self._graph_store.time_dict.items()
-                if node_type in node_types
-            },
-        )
-
-        anchor_time = anchor_time.astype('datetime64[ns]')
-        _, _, node_dict, batch_dict, _, _ = sampler.sample(
-            {
-                '__'.join(edge_type): np.array(values)
-                for edge_type, values in num_neighbors.items()
-            },
-            {
-                '__'.join(edge_type): np.array(values)
-                for edge_type, values in time_offsets.items()
-            },
-            self._query.entity.pkey.table_name,
-            node,
-            anchor_time.astype(int).to_numpy() // 1000**3,
-        )
-
-        feat_dict: Dict[str, pd.DataFrame] = {}
-        time_dict: Dict[str, pd.Series] = {}
-        column_dict = self._query.column_dict
-        time_tables = self._query.time_tables
-        for table_name in set(list(column_dict.keys()) + time_tables):
-            df = self._graph_store.df_dict[table_name]
-            row_id = node_dict[table_name]
-            df = df.iloc[row_id].reset_index(drop=True)
-            if table_name in column_dict:
-                feat_dict[table_name] = df[list(column_dict[table_name])]
-            if table_name in time_tables:
-                time_col = self._graph_store.time_column_dict[table_name]
-                time_dict[table_name] = df[time_col]
-
-        return feat_dict, time_dict, batch_dict
-
-    def __call__(
-        self,
-        node: np.ndarray,
-        anchor_time: pd.Series,
-    ) -> Tuple[pd.Series, np.ndarray]:
-
-        feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
-
-        y, mask = PQueryPandasBackend().eval_pquery(
-            query=self._query,
-            feat_dict=feat_dict,
-            time_dict=time_dict,
-            batch_dict=batch_dict,
-            anchor_time=anchor_time,
-            num_forecasts=self._query.num_forecasts,
-        )
-
-        return y, mask
-
-
-def date_offset_to_seconds(offset: pd.DateOffset) -> int:
-    r"""Convert a :class:`pandas.DateOffset` into a maximum number of
-    nanoseconds.
-
-    .. note::
-        We are conservative and take months and years as their maximum value.
-        Additional values are then dropped in label computation where we know
-        the actual dates.
-    """
-    # Max durations for months and years in nanoseconds:
-    MAX_DAYS_IN_MONTH = 31
-    MAX_DAYS_IN_YEAR = 366
-
-    # Conversion factors:
-    SECONDS_IN_MINUTE = 60
-    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
-    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
-
-    total_ns = 0
-    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
-
-    for attr, value in offset.__dict__.items():
-        if value is None or value == 0:
-            continue
-        scaled_value = value * multiplier
-        if attr == 'years':
-            total_ns += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
-        elif attr == 'months':
-            total_ns += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
-        elif attr == 'days':
-            total_ns += scaled_value * SECONDS_IN_DAY
-        elif attr == 'hours':
-            total_ns += scaled_value * SECONDS_IN_HOUR
-        elif attr == 'minutes':
-            total_ns += scaled_value * SECONDS_IN_MINUTE
-        elif attr == 'seconds':
-            total_ns += scaled_value
-
-    return total_ns
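
Note on the removed date_offset_to_seconds helper above: it bounds a pandas.DateOffset by its maximum possible duration in seconds (months as 31 days, years as 366 days), scaled by the offset multiplier. A minimal sketch of the same arithmetic, independent of the deleted module and using explicit keyword arguments rather than pandas internals (names here are illustrative only):

    # Sketch: conservative upper bound (in seconds) for a date offset,
    # mirroring the arithmetic of the removed helper.
    def max_offset_seconds(years=0, months=0, days=0,
                           hours=0, minutes=0, seconds=0, n=1):
        SECONDS_IN_MINUTE = 60
        SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
        SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
        total = (years * 366 + months * 31 + days) * SECONDS_IN_DAY
        total += hours * SECONDS_IN_HOUR + minutes * SECONDS_IN_MINUTE + seconds
        return total * n  # scale by the offset multiplier

    # e.g. an offset of 2 months and 3 days is bounded by 65 days:
    print(max_offset_seconds(months=2, days=3))  # 5616000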