kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/__init__.py +33 -8
  4. kumoai/experimental/rfm/authenticate.py +3 -4
  5. kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
  6. kumoai/experimental/rfm/backend/local/sampler.py +213 -14
  7. kumoai/experimental/rfm/backend/local/table.py +21 -16
  8. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  9. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  10. kumoai/experimental/rfm/backend/snow/table.py +101 -49
  11. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  14. kumoai/experimental/rfm/base/__init__.py +25 -6
  15. kumoai/experimental/rfm/base/column.py +14 -12
  16. kumoai/experimental/rfm/base/column_expression.py +50 -0
  17. kumoai/experimental/rfm/base/sampler.py +438 -38
  18. kumoai/experimental/rfm/base/source.py +1 -0
  19. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  20. kumoai/experimental/rfm/base/sql_table.py +229 -0
  21. kumoai/experimental/rfm/base/table.py +165 -135
  22. kumoai/experimental/rfm/graph.py +266 -102
  23. kumoai/experimental/rfm/infer/__init__.py +6 -4
  24. kumoai/experimental/rfm/infer/dtype.py +3 -3
  25. kumoai/experimental/rfm/infer/pkey.py +4 -2
  26. kumoai/experimental/rfm/infer/stype.py +35 -0
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
  38. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +41 -35
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/_version.py CHANGED
@@ -1 +1 @@
- __version__ = '2.13.0.dev202512081731'
+ __version__ = '2.14.0.dev202512211732'
kumoai/client/pquery.py CHANGED
@@ -176,8 +176,12 @@ def filter_model_plan(
  # Undefined
  pass

- new_opt_fields.append((field.name, _type, default))
- new_opts.append(getattr(section, field.name))
+ # Forward compatibility - Remove any newly introduced arguments not
+ # returned yet by the backend:
+ value = getattr(section, field.name)
+ if value != MissingType.VALUE:
+ new_opt_fields.append((field.name, _type, default))
+ new_opts.append(value)

  Section = dataclass(
  config=dict(validate_assignment=True),
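Note: the hunk above makes filter_model_plan drop option fields whose values the backend has not returned yet, instead of forwarding them blindly. A minimal, self-contained sketch of that skip-missing pattern (the MissingType sentinel and helper name below are stand-ins for illustration, not the SDK's actual classes):

    from enum import Enum

    class MissingType(Enum):
        # Stand-in sentinel for "value not returned by the backend".
        VALUE = "missing"

    def keep_known_options(section, opt_fields):
        """Keep only option fields the backend actually populated."""
        new_opt_fields, new_opts = [], []
        for name, _type, default in opt_fields:
            value = getattr(section, name)
            if value != MissingType.VALUE:
                new_opt_fields.append((name, _type, default))
                new_opts.append(value)
        return new_opt_fields, new_opts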
kumoai/experimental/rfm/__init__.py CHANGED
@@ -6,11 +6,11 @@ import socket
  import threading
  from dataclasses import dataclass
  from enum import Enum
- from typing import Dict, Optional, Tuple
  from urllib.parse import urlparse

  import kumoai
  from kumoai.client.client import KumoClient
+ from kumoai.spcs import _get_active_session

  from .authenticate import authenticate
  from .sagemaker import (
@@ -49,7 +49,8 @@ class InferenceBackend(str, Enum):


  def _detect_backend(
- url: str) -> Tuple[InferenceBackend, Optional[str], Optional[str]]:
+ url: str, #
+ ) -> tuple[InferenceBackend, str | None, str | None]:
  parsed = urlparse(url)

  # Remote SageMaker
@@ -73,12 +74,27 @@ def _detect_backend(
  return InferenceBackend.REST, None, None


+ def _get_snowflake_url(snowflake_application: str) -> str:
+ snowpark_session = _get_active_session()
+ if not snowpark_session:
+ raise ValueError(
+ "Client creation failed: snowflake_application is specified "
+ "without an active snowpark session. If running outside "
+ "a snowflake notebook, specify a URL and credentials.")
+ with snowpark_session.connection.cursor() as cur:
+ cur.execute(
+ f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
+ f" ->> SELECT \"dns_name\" from $1")
+ dns_name: str = cur.fetchone()[0]
+ return f"http://{dns_name}:8000/api"
+
+
  @dataclass
  class RfmGlobalState:
  _url: str = '__url_not_provided__'
  _backend: InferenceBackend = InferenceBackend.UNKNOWN
- _region: Optional[str] = None
- _endpoint_name: Optional[str] = None
+ _region: str | None = None
+ _endpoint_name: str | None = None
  _thread_local = threading.local()

  # Thread-safe init-once.
@@ -121,10 +137,10 @@ global_state = RfmGlobalState()


  def init(
- url: Optional[str] = None,
- api_key: Optional[str] = None,
- snowflake_credentials: Optional[Dict[str, str]] = None,
- snowflake_application: Optional[str] = None,
+ url: str | None = None,
+ api_key: str | None = None,
+ snowflake_credentials: dict[str, str] | None = None,
+ snowflake_application: str | None = None,
  log_level: str = "INFO",
  ) -> None:
  with global_state._lock:
@@ -136,6 +152,15 @@ def init(
  "supported.")
  return

+ if snowflake_application:
+ if url is not None:
+ raise ValueError(
+ "Client creation failed: both snowflake_application and "
+ "url are specified. If running from a snowflake notebook, "
+ "specify only snowflake_application.")
+ url = _get_snowflake_url(snowflake_application)
+ api_key = "test:DISABLED"
+

  if url is None:
  url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")
kumoai/experimental/rfm/authenticate.py CHANGED
@@ -2,12 +2,11 @@ import logging
  import os
  import platform
  from datetime import datetime
- from typing import Optional

  from kumoai import in_notebook


- def authenticate(api_url: Optional[str] = None) -> None:
+ def authenticate(api_url: str | None = None) -> None:
  """Authenticates the user and sets the Kumo API key for the SDK.

  This function detects the current environment and launches the appropriate
@@ -65,11 +64,11 @@ def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
  import webbrowser
  from getpass import getpass
  from socketserver import TCPServer
- from typing import Any, Dict
+ from typing import Any

  logger = logging.getLogger('kumoai')

- token_status: Dict[str, Any] = {
+ token_status: dict[str, Any] = {
  'token': None,
  'token_name': None,
  'failed': False
kumoai/experimental/rfm/backend/local/graph_store.py CHANGED
@@ -1,5 +1,5 @@
  import warnings
- from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING

  import numpy as np
  import pandas as pd
@@ -7,7 +7,7 @@ from kumoapi.rfm.context import Subgraph
  from kumoapi.typing import Stype

  from kumoai.experimental.rfm.backend.local import LocalTable
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger

  try:
  import torch
@@ -23,37 +23,32 @@ class LocalGraphStore:
  def __init__(
  self,
  graph: 'Graph',
- verbose: Union[bool, ProgressLogger] = True,
+ verbose: bool | ProgressLogger = True,
  ) -> None:

  if not isinstance(verbose, ProgressLogger):
- verbose = InteractiveProgressLogger(
- "Materializing graph",
+ verbose = ProgressLogger.default(
+ msg="Materializing graph",
  verbose=verbose,
  )

  with verbose as logger:
  self.df_dict, self.mask_dict = self.sanitize(graph)
- self.stype_dict = self.get_stype_dict(graph)
  logger.log("Sanitized input data")

- self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
+ self.pkey_map_dict = self.get_pkey_map_dict(graph)
  num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
  if num_pkeys > 1:
  logger.log(f"Collected primary keys from {num_pkeys} tables")
  else:
  logger.log(f"Collected primary key from {num_pkeys} table")

- (
- self.time_column_dict,
- self.end_time_column_dict,
- self.time_dict,
- self.min_time,
- self.max_time,
- ) = self.get_time_data(graph)
- if self.max_time != pd.Timestamp.min:
+ self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+ if len(self.min_max_time_dict) > 0:
+ min_time = min(t for t, _ in self.min_max_time_dict.values())
+ max_time = max(t for _, t in self.min_max_time_dict.values())
  logger.log(f"Identified temporal graph from "
- f"{self.min_time.date()} to {self.max_time.date()}")
+ f"{min_time.date()} to {max_time.date()}")
  else:
  logger.log("Identified static graph without timestamps")

@@ -63,14 +58,6 @@ class LocalGraphStore:
  logger.log(f"Created graph with {num_nodes:,} nodes and "
  f"{num_edges:,} edges")

- @property
- def node_types(self) -> List[str]:
- return list(self.df_dict.keys())
-
- @property
- def edge_types(self) -> List[Tuple[str, str, str]]:
- return list(self.row_dict.keys())
-
  def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
  r"""Returns the node ID given primary keys.

@@ -107,7 +94,7 @@ class LocalGraphStore:
  def sanitize(
  self,
  graph: 'Graph',
- ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
+ ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
  r"""Sanitizes raw data according to table schema definition:

  In particular, it:
@@ -116,13 +103,13 @@ class LocalGraphStore:
  * drops duplicate primary keys
  * removes rows with missing primary keys or time values
  """
- df_dict: Dict[str, pd.DataFrame] = {}
+ df_dict: dict[str, pd.DataFrame] = {}
  for table_name, table in graph.tables.items():
  assert isinstance(table, LocalTable)
  df = table._data
  df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)

- mask_dict: Dict[str, np.ndarray] = {}
+ mask_dict: dict[str, np.ndarray] = {}
  for table in graph.tables.values():
  for col in table.columns:
  if col.stype == Stype.timestamp:
@@ -139,7 +126,7 @@ class LocalGraphStore:
  ser = ser.dt.tz_localize(None)
  df_dict[table.name][col.name] = ser

- mask: Optional[np.ndarray] = None
+ mask: np.ndarray | None = None
  if table._time_column is not None:
  ser = df_dict[table.name][table._time_column]
  mask = ser.notna().to_numpy()
@@ -154,34 +141,16 @@ class LocalGraphStore:

  return df_dict, mask_dict

- def get_stype_dict(self, graph: 'Graph') -> Dict[str, Dict[str, Stype]]:
- stype_dict: Dict[str, Dict[str, Stype]] = {}
- foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
- for table in graph.tables.values():
- stype_dict[table.name] = {}
- for column in table.columns:
- if column == table.primary_key:
- continue
- if (table.name, column.name) in foreign_keys:
- continue
- stype_dict[table.name][column.name] = column.stype
- return stype_dict
-
- def get_pkey_data(
+ def get_pkey_map_dict(
  self,
  graph: 'Graph',
- ) -> Tuple[
- Dict[str, str],
- Dict[str, pd.DataFrame],
- ]:
- pkey_name_dict: Dict[str, str] = {}
- pkey_map_dict: Dict[str, pd.DataFrame] = {}
+ ) -> dict[str, pd.DataFrame]:
+ pkey_map_dict: dict[str, pd.DataFrame] = {}

  for table in graph.tables.values():
  if table._primary_key is None:
  continue

- pkey_name_dict[table.name] = table._primary_key
  pkey = self.df_dict[table.name][table._primary_key]
  pkey_map = pd.DataFrame(
  dict(arange=range(len(pkey))),
@@ -203,27 +172,18 @@ class LocalGraphStore:

  pkey_map_dict[table.name] = pkey_map

- return pkey_name_dict, pkey_map_dict
+ return pkey_map_dict

  def get_time_data(
  self,
  graph: 'Graph',
- ) -> Tuple[
- Dict[str, str],
- Dict[str, str],
- Dict[str, np.ndarray],
- pd.Timestamp,
- pd.Timestamp,
+ ) -> tuple[
+ dict[str, np.ndarray],
+ dict[str, tuple[pd.Timestamp, pd.Timestamp]],
  ]:
- time_column_dict: Dict[str, str] = {}
- end_time_column_dict: Dict[str, str] = {}
- time_dict: Dict[str, np.ndarray] = {}
- min_time = pd.Timestamp.max
- max_time = pd.Timestamp.min
+ time_dict: dict[str, np.ndarray] = {}
+ min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
  for table in graph.tables.values():
- if table._end_time_column is not None:
- end_time_column_dict[table.name] = table._end_time_column
-
  if table._time_column is None:
  continue

@@ -231,34 +191,31 @@ class LocalGraphStore:
  if time.dtype != 'datetime64[ns]':
  time = time.astype('datetime64[ns]')
  time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
- time_column_dict[table.name] = table._time_column

  if table.name in self.mask_dict.keys():
  time = time[self.mask_dict[table.name]]
  if len(time) > 0:
- min_time = min(min_time, time.min())
- max_time = max(max_time, time.max())
+ min_max_time_dict[table.name] = (time.min(), time.max())
+ else:
+ min_max_time_dict[table.name] = (
+ pd.Timestamp.max,
+ pd.Timestamp.min,
+ )

- return (
- time_column_dict,
- end_time_column_dict,
- time_dict,
- min_time,
- max_time,
- )
+ return time_dict, min_max_time_dict

  def get_csc(
  self,
  graph: 'Graph',
- ) -> Tuple[
- Dict[Tuple[str, str, str], np.ndarray],
- Dict[Tuple[str, str, str], np.ndarray],
+ ) -> tuple[
+ dict[tuple[str, str, str], np.ndarray],
+ dict[tuple[str, str, str], np.ndarray],
  ]:
  # A mapping from raw primary keys to node indices (0 to N-1):
- map_dict: Dict[str, pd.CategoricalDtype] = {}
+ map_dict: dict[str, pd.CategoricalDtype] = {}
  # A dictionary to manage offsets of node indices for invalid rows:
- offset_dict: Dict[str, np.ndarray] = {}
- for table_name in set(edge.dst_table for edge in graph.edges):
+ offset_dict: dict[str, np.ndarray] = {}
+ for table_name in {edge.dst_table for edge in graph.edges}:
  ser = self.df_dict[table_name][graph[table_name]._primary_key]
  if table_name in self.mask_dict.keys():
  mask = self.mask_dict[table_name]
@@ -267,8 +224,8 @@ class LocalGraphStore:
  map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)

  # Build CSC graph representation:
- row_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
- colptr_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+ row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+ colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
  for src_table, fkey, dst_table in graph.edges:
  src_df = self.df_dict[src_table]
  dst_df = self.df_dict[dst_table]
@@ -330,7 +287,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
  return torch.from_numpy(input).argsort().numpy()


- def _lexsort(inputs: List[np.ndarray]) -> np.ndarray:
+ def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
  assert len(inputs) >= 1

  if not WITH_TORCH:
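Note: get_time_data now returns a per-table (min, max) timestamp pair instead of a single global range, and the global bounds are derived on demand in __init__. A small sketch of that reduction over a min_max_time_dict-shaped mapping (the sample data is made up):

    import pandas as pd

    min_max_time_dict = {
        'users': (pd.Timestamp('2020-01-01'), pd.Timestamp('2023-06-30')),
        'orders': (pd.Timestamp('2021-03-15'), pd.Timestamp('2024-01-02')),
    }

    if len(min_max_time_dict) > 0:
        min_time = min(t for t, _ in min_max_time_dict.values())  # 2020-01-01
        max_time = max(t for _, t in min_max_time_dict.values())  # 2024-01-02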
kumoai/experimental/rfm/backend/local/sampler.py CHANGED
@@ -1,10 +1,12 @@
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Literal

  import numpy as np
  import pandas as pd
+ from kumoapi.pquery import ValidatedPredictiveQuery

  from kumoai.experimental.rfm.backend.local import LocalGraphStore
- from kumoai.experimental.rfm.base import BackwardSamplerOutput, Sampler
+ from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
  from kumoai.utils import ProgressLogger

  if TYPE_CHECKING:
@@ -17,7 +19,7 @@ class LocalSampler(Sampler):
  graph: 'Graph',
  verbose: bool | ProgressLogger = True,
  ) -> None:
- super().__init__(graph=graph)
+ super().__init__(graph=graph, verbose=verbose)

  import kumoai.kumolib as kumolib

@@ -36,19 +38,32 @@ class LocalSampler(Sampler):
  self._graph_store.time_dict,
  )

- def _sample_backward(
+ def _get_min_max_time_dict(
+ self,
+ table_names: list[str],
+ ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+ return {
+ key: value
+ for key, value in self._graph_store.min_max_time_dict.items()
+ if key in table_names
+ }
+
+ def _sample_subgraph(
  self,
  entity_table_name: str,
  entity_pkey: pd.Series,
- anchor_time: pd.Series,
+ anchor_time: pd.Series | Literal['entity'],
  columns_dict: dict[str, set[str]],
  num_neighbors: list[int],
- ) -> BackwardSamplerOutput:
+ ) -> SamplerOutput:

- num_neighbors_dict: dict[str, list[int]] = {
- '__'.join(edge_type): num_neighbors
- for edge_type in self.edge_types
- }
+ index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+ if isinstance(anchor_time, pd.Series):
+ time = anchor_time.astype(int).to_numpy() // 1000**3 # to seconds
+ else:
+ assert anchor_time == 'entity'
+ time = self._graph_store.time_dict[entity_table_name][index]

  (
  row_dict,
@@ -58,11 +73,14 @@ class LocalSampler(Sampler):
  num_sampled_nodes_dict,
  num_sampled_edges_dict,
  ) = self._graph_sampler.sample(
- num_neighbors_dict,
+ {
+ '__'.join(edge_type): num_neighbors
+ for edge_type in self.edge_types
+ },
  {},
  entity_table_name,
- self._graph_store.get_node_id(entity_table_name, entity_pkey),
- anchor_time.astype(int).to_numpy() // 1000**3, # to seconds
+ index,
+ time,
  )

  df_dict: dict[str, pd.DataFrame] = {}
@@ -105,7 +123,8 @@ class LocalSampler(Sampler):
  for edge_type in self.edge_types
  }

- return BackwardSamplerOutput(
+ return SamplerOutput(
+ anchor_time=time * 1000**3, # to nanoseconds
  df_dict=df_dict,
  inverse_dict=inverse_dict,
  batch_dict=batch_dict,
@@ -114,3 +133,183 @@ class LocalSampler(Sampler):
  col_dict=col_dict,
  num_sampled_edges_dict=num_sampled_edges_dict,
  )
+
+ def _sample_entity_table(
+ self,
+ table_name: str,
+ columns: set[str],
+ num_rows: int,
+ random_seed: int | None = None,
+ ) -> pd.DataFrame:
+ pkey_map = self._graph_store.pkey_map_dict[table_name]
+ if len(pkey_map) > num_rows:
+ pkey_map = pkey_map.sample(
+ n=num_rows,
+ random_state=random_seed,
+ ignore_index=True,
+ )
+ df = self._graph_store.df_dict[table_name]
+ df = df.iloc[pkey_map['arange']][list(columns)]
+ return df
+
+ def _sample_target(
+ self,
+ query: ValidatedPredictiveQuery,
+ entity_df: pd.DataFrame,
+ train_index: np.ndarray,
+ train_time: pd.Series,
+ num_train_examples: int,
+ test_index: np.ndarray,
+ test_time: pd.Series,
+ num_test_examples: int,
+ columns_dict: dict[str, set[str]],
+ time_offset_dict: dict[
+ tuple[str, str, str],
+ tuple[pd.DateOffset | None, pd.DateOffset],
+ ],
+ ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+ train_y, train_mask = self._sample_target_set(
+ query=query,
+ pkey=entity_df[self.primary_key_dict[query.entity_table]],
+ index=train_index,
+ anchor_time=train_time,
+ num_examples=num_train_examples,
+ columns_dict=columns_dict,
+ time_offset_dict=time_offset_dict,
+ )
+
+ test_y, test_mask = self._sample_target_set(
+ query=query,
+ pkey=entity_df[self.primary_key_dict[query.entity_table]],
+ index=test_index,
+ anchor_time=test_time,
+ num_examples=num_test_examples,
+ columns_dict=columns_dict,
+ time_offset_dict=time_offset_dict,
+ )
+
+ return train_y, train_mask, test_y, test_mask
+
+ # Helper Methods ##########################################################
+
+ def _sample_target_set(
+ self,
+ query: ValidatedPredictiveQuery,
+ pkey: pd.Series,
+ index: np.ndarray,
+ anchor_time: pd.Series,
+ num_examples: int,
+ columns_dict: dict[str, set[str]],
+ time_offset_dict: dict[
+ tuple[str, str, str],
+ tuple[pd.DateOffset | None, pd.DateOffset],
+ ],
+ batch_size: int = 10_000,
+ ) -> tuple[pd.Series, np.ndarray]:
+
+ num_hops = 1 if len(time_offset_dict) > 0 else 0
+ num_neighbors_dict: dict[str, list[int]] = {}
+ unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
+ for edge_type, (start, end) in time_offset_dict.items():
+ unix_time_offset_dict['__'.join(edge_type)] = [[
+ date_offset_to_seconds(start) if start is not None else None,
+ date_offset_to_seconds(end),
+ ]]
+ for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
+ num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
+
+ if anchor_time.dtype != 'datetime64[ns]':
+ anchor_time = anchor_time.astype('datetime64')
+
+ count = 0
+ ys: list[pd.Series] = []
+ mask = np.full(len(index), False, dtype=bool)
+ for start in range(0, len(index), batch_size):
+ subset = pkey.iloc[index[start:start + batch_size]]
+ time = anchor_time.iloc[start:start + batch_size]
+
+ _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+ num_neighbors_dict,
+ unix_time_offset_dict,
+ query.entity_table,
+ self._graph_store.get_node_id(query.entity_table, subset),
+ time.astype(int).to_numpy() // 1000**3, # to seconds
+ )
+
+ feat_dict: dict[str, pd.DataFrame] = {}
+ time_dict: dict[str, pd.Series] = {}
+ for table_name, columns in columns_dict.items():
+ df = self._graph_store.df_dict[table_name]
+ df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+ df = df[list(columns)]
+ feat_dict[table_name] = df
+
+ time_column = self.time_column_dict.get(table_name)
+ if time_column in columns:
+ time_dict[table_name] = df[time_column]
+
+ y, _mask = PQueryPandasExecutor().execute(
+ query=query,
+ feat_dict=feat_dict,
+ time_dict=time_dict,
+ batch_dict=batch_dict,
+ anchor_time=time,
+ num_forecasts=query.num_forecasts,
+ )
+ ys.append(y)
+ mask[start:start + batch_size] = _mask
+
+ count += len(y)
+ if count >= num_examples:
+ break
+
+ if len(ys) == 0:
+ y = pd.Series([], dtype=float)
+ elif len(ys) == 1:
+ y = ys[0]
+ else:
+ y = pd.concat(ys, axis=0, ignore_index=True)
+
+ return y, mask
+
+
+ # Helper Functions ############################################################
+
+
+ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
+ r"""Convert a :class:`pandas.DateOffset` into a number of seconds.
+
+ .. note::
+ We are conservative and take months and years as their maximum value.
+ Additional values are then dropped in label computation where we know
+ the actual dates.
+ """
+ MAX_DAYS_IN_MONTH = 31
+ MAX_DAYS_IN_YEAR = 366
+
+ SECONDS_IN_MINUTE = 60
+ SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
+ SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
+
+ total_sec = 0
+ multiplier = getattr(offset, 'n', 1) # The multiplier (if present).
+
+ for attr, value in offset.__dict__.items():
+ if value is None or value == 0:
+ continue
+ scaled_value = value * multiplier
+ if attr == 'years':
+ total_sec += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
+ elif attr == 'months':
+ total_sec += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
+ elif attr == 'days':
+ total_sec += scaled_value * SECONDS_IN_DAY
+ elif attr == 'hours':
+ total_sec += scaled_value * SECONDS_IN_HOUR
+ elif attr == 'minutes':
+ total_sec += scaled_value * SECONDS_IN_MINUTE
+ elif attr == 'seconds':
+ total_sec += scaled_value
+
+ return total_sec
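Note: date_offset_to_seconds deliberately over-approximates months (31 days) and years (366 days), and the over-count is later dropped during label computation where exact dates are known. A worked example of the conservative rule for an offset of one month and two days (values illustrative, arithmetic only, no pandas internals involved):

    # Conservative conversion of DateOffset(months=1, days=2):
    SECONDS_IN_DAY = 24 * 60 * 60          # 86_400
    months_part = 1 * 31 * SECONDS_IN_DAY  # months counted as 31 days
    days_part = 2 * SECONDS_IN_DAY
    total = months_part + days_part        # 2_851_200 seconds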