kumoai-2.14.0.dev202601011731-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (122)
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/backend/local/__init__.py
@@ -0,0 +1,42 @@
+ try:
+     import kumoai.kumolib  # noqa: F401
+ except Exception as e:
+     import platform
+
+     _msg = f"""RFM is not supported in your environment.
+
+ 💻 Your Environment:
+   Python version: {platform.python_version()}
+   Operating system: {platform.system()}
+   CPU architecture: {platform.machine()}
+   glibc version: {platform.libc_ver()[1]}
+
+ ✅ Supported Environments:
+   * Python versions: 3.10, 3.11, 3.12, 3.13
+   * Operating systems and CPU architectures:
+     * Linux (x86_64)
+     * macOS (arm64)
+     * Windows (x86_64)
+   * glibc versions: >=2.28
+
+ ❌ Unsupported Environments:
+   * Python versions: 3.8, 3.9, 3.14
+   * Operating systems and CPU architectures:
+     * Linux (arm64)
+     * macOS (x86_64)
+     * Windows (arm64)
+   * glibc versions: <2.28
+
+ Please create a feature request at 'https://github.com/kumo-ai/kumo-rfm'."""
+
+     raise RuntimeError(_msg) from e
+
+ from .table import LocalTable
+ from .graph_store import LocalGraphStore
+ from .sampler import LocalSampler
+
+ __all__ = [
+     'LocalTable',
+     'LocalGraphStore',
+     'LocalSampler',
+ ]
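Note on the hunk above: importing `kumoai.experimental.rfm.backend.local` first probes the compiled `kumoai.kumolib` extension and raises a `RuntimeError` describing the current environment if the wheel does not match it. A minimal sketch of how a caller might turn that guard into a capability check; the `local_backend_available` helper below is hypothetical and not part of kumoai:

    # Sketch only: probe whether the compiled local RFM backend is importable.
    # `local_backend_available` is a made-up helper name, not a kumoai API.
    def local_backend_available() -> bool:
        try:
            import kumoai.experimental.rfm.backend.local  # noqa: F401
            return True
        except RuntimeError as exc:  # raised by the guard shown above
            print(f"Local RFM backend unavailable: {exc}")
            return False

    if local_backend_available():
        from kumoai.experimental.rfm.backend.local import LocalGraphStore  # noqa: F401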
kumoai/experimental/rfm/backend/local/graph_store.py
@@ -0,0 +1,297 @@
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.rfm.context import Subgraph
+
+ from kumoai.experimental.rfm.backend.local import LocalTable
+ from kumoai.experimental.rfm.base import Table
+ from kumoai.utils import ProgressLogger
+
+ try:
+     import torch
+     WITH_TORCH = True
+ except ImportError:
+     WITH_TORCH = False
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class LocalGraphStore:
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+
+         if not isinstance(verbose, ProgressLogger):
+             verbose = ProgressLogger.default(
+                 msg="Materializing graph",
+                 verbose=verbose,
+             )
+
+         with verbose as logger:
+             self.df_dict, self.mask_dict = self.sanitize(graph)
+             logger.log("Sanitized input data")
+
+             self.pkey_map_dict = self.get_pkey_map_dict(graph)
+             num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
+             if num_pkeys > 1:
+                 logger.log(f"Collected primary keys from {num_pkeys} tables")
+             else:
+                 logger.log(f"Collected primary key from {num_pkeys} table")
+
+             self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+             if len(self.min_max_time_dict) > 0:
+                 min_time = min(t for t, _ in self.min_max_time_dict.values())
+                 max_time = max(t for _, t in self.min_max_time_dict.values())
+                 logger.log(f"Identified temporal graph from "
+                            f"{min_time.date()} to {max_time.date()}")
+             else:
+                 logger.log("Identified static graph without timestamps")
+
+             self.row_dict, self.colptr_dict = self.get_csc(graph)
+             num_nodes = sum(len(df) for df in self.df_dict.values())
+             num_edges = sum(len(row) for row in self.row_dict.values())
+             logger.log(f"Created graph with {num_nodes:,} nodes and "
+                        f"{num_edges:,} edges")
+
+     def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
+         r"""Returns the node ID given primary keys.
+
+         Args:
+             table_name: The table name.
+             pkey: The primary keys to receive node IDs for.
+         """
+         if table_name not in self.df_dict.keys():
+             raise KeyError(f"Table '{table_name}' does not exist")
+
+         if table_name not in self.pkey_map_dict.keys():
+             raise ValueError(f"Table '{table_name}' does not have a primary "
+                              f"key")
+
+         if len(pkey) == 0:
+             raise KeyError(f"No primary keys passed for table '{table_name}'")
+
+         pkey_map = self.pkey_map_dict[table_name]
+
+         try:
+             pkey = pkey.astype(type(pkey_map.index[0]))
+         except ValueError as e:
+             raise ValueError(f"Could not cast primary keys "
+                              f"{pkey.tolist()} to the expected data "
+                              f"type '{pkey_map.index.dtype}'") from e
+
+         try:
+             return pkey_map.loc[pkey]['arange'].to_numpy()
+         except KeyError as e:
+             missing = ~np.isin(pkey, pkey_map.index)
+             raise KeyError(f"The primary keys {pkey[missing].tolist()} do "
+                            f"not exist in the '{table_name}' table") from e
+
+     def sanitize(
+         self,
+         graph: 'Graph',
+     ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
+         r"""Sanitizes raw data according to table schema definition:
+
+         In particular, it:
+         * converts timestamp data to `pd.Datetime`
+         * drops timezone information from timestamps
+         * drops duplicate primary keys
+         * removes rows with missing primary keys or time values
+         """
+         df_dict: dict[str, pd.DataFrame] = {}
+         for table_name, table in graph.tables.items():
+             assert isinstance(table, LocalTable)
+             df_dict[table_name] = Table._sanitize(
+                 df=table._data.copy(deep=False).reset_index(drop=True),
+                 dtype_dict={
+                     column.name: column.dtype
+                     for column in table.columns
+                 },
+                 stype_dict={
+                     column.name: column.stype
+                     for column in table.columns
+                 },
+             )
+
+         mask_dict: dict[str, np.ndarray] = {}
+         for table in graph.tables.values():
+             mask: np.ndarray | None = None
+             if table._time_column is not None:
+                 ser = df_dict[table.name][table._time_column]
+                 mask = ser.notna().to_numpy()
+
+             if table._primary_key is not None:
+                 ser = df_dict[table.name][table._primary_key]
+                 _mask = (~ser.duplicated().to_numpy()) & ser.notna().to_numpy()
+                 mask = _mask if mask is None else (_mask & mask)
+
+             if mask is not None and not mask.all():
+                 mask_dict[table.name] = mask
+
+         return df_dict, mask_dict
+
+     def get_pkey_map_dict(
+         self,
+         graph: 'Graph',
+     ) -> dict[str, pd.DataFrame]:
+         pkey_map_dict: dict[str, pd.DataFrame] = {}
+
+         for table in graph.tables.values():
+             if table._primary_key is None:
+                 continue
+
+             pkey = self.df_dict[table.name][table._primary_key]
+             pkey_map = pd.DataFrame(
+                 dict(arange=range(len(pkey))),
+                 index=pkey,
+             )
+             if table.name in self.mask_dict:
+                 pkey_map = pkey_map[self.mask_dict[table.name]]
+
+             if len(pkey_map) == 0:
+                 error_msg = f"Found no valid rows in table '{table.name}'. "
+                 if table.has_time_column():
+                     error_msg += ("Please make sure that there exists valid "
+                                   "non-N/A primary key and time column pairs "
+                                   "in this table.")
+                 else:
+                     error_msg += ("Please make sure that there exists valid "
+                                   "non-N/A primary keys in this table.")
+                 raise ValueError(error_msg)
+
+             pkey_map_dict[table.name] = pkey_map
+
+         return pkey_map_dict
+
+     def get_time_data(
+         self,
+         graph: 'Graph',
+     ) -> tuple[
+             dict[str, np.ndarray],
+             dict[str, tuple[pd.Timestamp, pd.Timestamp]],
+     ]:
+         time_dict: dict[str, np.ndarray] = {}
+         min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         for table in graph.tables.values():
+             if table._time_column is None:
+                 continue
+
+             time = self.df_dict[table.name][table._time_column]
+             time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
+
+             if table.name in self.mask_dict.keys():
+                 time = time[self.mask_dict[table.name]]
+             if len(time) > 0:
+                 min_max_time_dict[table.name] = (time.min(), time.max())
+             else:
+                 min_max_time_dict[table.name] = (
+                     pd.Timestamp.max,
+                     pd.Timestamp.min,
+                 )
+
+         return time_dict, min_max_time_dict
+
+     def get_csc(
+         self,
+         graph: 'Graph',
+     ) -> tuple[
+             dict[tuple[str, str, str], np.ndarray],
+             dict[tuple[str, str, str], np.ndarray],
+     ]:
+         # A mapping from raw primary keys to node indices (0 to N-1):
+         map_dict: dict[str, pd.CategoricalDtype] = {}
+         # A dictionary to manage offsets of node indices for invalid rows:
+         offset_dict: dict[str, np.ndarray] = {}
+         for table_name in {edge.dst_table for edge in graph.edges}:
+             ser = self.df_dict[table_name][graph[table_name]._primary_key]
+             if table_name in self.mask_dict.keys():
+                 mask = self.mask_dict[table_name]
+                 ser = ser[mask]
+                 offset_dict[table_name] = np.cumsum(~mask)[mask]
+             map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
+
+         # Build CSC graph representation:
+         row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+         colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
+         for src_table, fkey, dst_table in graph.edges:
+             src_df = self.df_dict[src_table]
+             dst_df = self.df_dict[dst_table]
+
+             src = np.arange(len(src_df))
+             dst = src_df[fkey].astype(map_dict[dst_table]).cat.codes.to_numpy()
+             dst = dst.astype(int)
+             mask = dst >= 0
+             if dst_table in offset_dict.keys():
+                 dst = dst + offset_dict[dst_table][dst]
+             if src_table in self.mask_dict.keys():
+                 mask &= self.mask_dict[src_table]
+             src, dst = src[mask], dst[mask]
+
+             # Sort by destination/column (and time within neighborhoods):
+             # `lexsort` is expensive (especially in numpy) so avoid it if
+             # possible by grouping `time` and `node_id` together:
+             if src_table in self.time_dict:
+                 src_time = self.time_dict[src_table][src]
+                 min_time = int(src_time.min())
+                 max_time = int(src_time.max())
+                 offset = (max_time - min_time) + 1
+                 if offset * len(dst_df) <= np.iinfo(np.int64).max:
+                     index = dst * offset + (src_time - min_time)
+                     perm = _argsort(index)
+                 else:  # Safe route to avoid `int64` overflow:
+                     perm = _lexsort([src_time, dst])
+             else:
+                 perm = _argsort(dst)
+
+             row, col = src[perm], dst[perm]
+
+             # Convert into compressed representation:
+             colcount = np.bincount(col, minlength=len(dst_df))
+             colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
+             colptr[0] = 0
+             np.cumsum(colcount, out=colptr[1:])
+             edge_type = (src_table, fkey, dst_table)
+             row_dict[edge_type] = row
+             colptr_dict[edge_type] = colptr
+
+             # Reverse connection - no sort and no time handling needed since
+             # the reverse mapping is 1-to-many.
+             row, col = dst, src
+             colcount = np.bincount(col, minlength=len(src_df))
+             colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
+             colptr[0] = 0
+             np.cumsum(colcount, out=colptr[1:])
+             edge_type = Subgraph.rev_edge_type(edge_type)
+             row_dict[edge_type] = row
+             colptr_dict[edge_type] = colptr
+
+         return row_dict, colptr_dict
+
+
+ def _argsort(input: np.ndarray) -> np.ndarray:
+     if not WITH_TORCH:
+         return np.argsort(input)
+     return torch.from_numpy(input).argsort().numpy()
+
+
+ def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
+     assert len(inputs) >= 1
+
+     if not WITH_TORCH:
+         return np.lexsort(inputs)
+
+     try:
+         out = torch.from_numpy(inputs[0]).argsort(stable=True)
+     except Exception:
+         return np.lexsort(inputs)  # PyTorch<1.9 without `stable` support.
+
+     for input in inputs[1:]:
+         index = torch.from_numpy(input)[out]
+         index = index.argsort(stable=True)
+         out = out[index]
+
+     return out.numpy()
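The `get_csc` method above builds a compressed sparse column (CSC) layout per edge type: edges are sorted by destination node (and by time within each neighborhood), `np.bincount` counts the edges per destination, and a prefix sum turns those counts into column pointers. A standalone sketch of the same construction on toy data, independent of kumoai and using only NumPy:

    import numpy as np

    # Toy edge list: edge i connects src[i] -> dst[i]; 4 destination nodes in total.
    src = np.array([0, 1, 2, 3, 4])
    dst = np.array([2, 0, 2, 1, 0])

    perm = np.argsort(dst)                    # group edges by destination
    row = src[perm]                           # row indices, ordered per column
    colcount = np.bincount(dst, minlength=4)  # number of edges per destination
    colptr = np.empty(len(colcount) + 1, dtype=colcount.dtype)
    colptr[0] = 0
    np.cumsum(colcount, out=colptr[1:])       # exclusive prefix sum

    # Neighbors of destination node 2 live in row[colptr[2]:colptr[3]]:
    print(row[colptr[2]:colptr[3]])           # -> [0 2]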
kumoai/experimental/rfm/backend/local/sampler.py
@@ -0,0 +1,312 @@
+ from typing import TYPE_CHECKING, Literal
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
+ from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class LocalSampler(Sampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         import kumoai.kumolib as kumolib
+
+         self._graph_store = LocalGraphStore(graph, verbose)
+         self._graph_sampler = kumolib.NeighborSampler(
+             list(self.table_stype_dict.keys()),
+             self.edge_types,
+             {
+                 '__'.join(edge_type): colptr
+                 for edge_type, colptr in self._graph_store.colptr_dict.items()
+             },
+             {
+                 '__'.join(edge_type): row
+                 for edge_type, row in self._graph_store.row_dict.items()
+             },
+             self._graph_store.time_dict,
+         )
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         return {
+             key: value
+             for key, value in self._graph_store.min_max_time_dict.items()
+             if key in table_names
+         }
+
+     def _sample_subgraph(
+         self,
+         entity_table_name: str,
+         entity_pkey: pd.Series,
+         anchor_time: pd.Series | Literal['entity'],
+         columns_dict: dict[str, set[str]],
+         num_neighbors: list[int],
+     ) -> SamplerOutput:
+
+         index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+         if isinstance(anchor_time, pd.Series):
+             time = anchor_time.astype(int).to_numpy() // 1000**3  # to seconds
+         else:
+             assert anchor_time == 'entity'
+             time = self._graph_store.time_dict[entity_table_name][index]
+
+         (
+             row_dict,
+             col_dict,
+             node_dict,
+             batch_dict,
+             num_sampled_nodes_dict,
+             num_sampled_edges_dict,
+         ) = self._graph_sampler.sample(
+             {
+                 '__'.join(edge_type): num_neighbors
+                 for edge_type in self.edge_types
+             },
+             {},
+             entity_table_name,
+             index,
+             time,
+         )
+
+         df_dict: dict[str, pd.DataFrame] = {}
+         inverse_dict: dict[str, np.ndarray] = {}
+         for table_name, node in node_dict.items():
+             df = self._graph_store.df_dict[table_name]
+             columns = columns_dict[table_name]
+             if self.end_time_column_dict.get(table_name, None) in columns:
+                 df = df.iloc[node]
+             elif len(columns) == 0:
+                 df = df.iloc[node]
+             else:
+                 # Only store unique rows in `df` above a certain threshold:
+                 unique_node, inverse = np.unique(node, return_inverse=True)
+                 if len(node) > 1.05 * len(unique_node):
+                     df = df.iloc[unique_node]
+                     inverse_dict[table_name] = inverse
+                 else:
+                     df = df.iloc[node]
+             df = df.reset_index(drop=True)
+             df = df[list(columns)]
+             df_dict[table_name] = df
+
+         num_sampled_nodes_dict = {
+             table_name: num_sampled_nodes.tolist()
+             for table_name, num_sampled_nodes in
+             num_sampled_nodes_dict.items()
+         }
+
+         row_dict = {
+             edge_type: row_dict['__'.join(edge_type)]
+             for edge_type in self.edge_types
+         }
+         col_dict = {
+             edge_type: col_dict['__'.join(edge_type)]
+             for edge_type in self.edge_types
+         }
+         num_sampled_edges_dict = {
+             edge_type: num_sampled_edges_dict['__'.join(edge_type)].tolist()
+             for edge_type in self.edge_types
+         }
+
+         return SamplerOutput(
+             anchor_time=time * 1000**3,  # to nanoseconds
+             df_dict=df_dict,
+             inverse_dict=inverse_dict,
+             batch_dict=batch_dict,
+             num_sampled_nodes_dict=num_sampled_nodes_dict,
+             row_dict=row_dict,
+             col_dict=col_dict,
+             num_sampled_edges_dict=num_sampled_edges_dict,
+         )
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         pkey_map = self._graph_store.pkey_map_dict[table_name]
+         if len(pkey_map) > num_rows:
+             pkey_map = pkey_map.sample(
+                 n=num_rows,
+                 random_state=random_seed,
+                 ignore_index=True,
+             )
+         df = self._graph_store.df_dict[table_name]
+         df = df.iloc[pkey_map['arange']][list(columns)]
+         return df
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         train_y, train_mask = self._sample_target_set(
+             query=query,
+             pkey=entity_df[self.primary_key_dict[query.entity_table]],
+             index=train_index,
+             anchor_time=train_time,
+             num_examples=num_train_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         test_y, test_mask = self._sample_target_set(
+             query=query,
+             pkey=entity_df[self.primary_key_dict[query.entity_table]],
+             index=test_index,
+             anchor_time=test_time,
+             num_examples=num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         return train_y, train_mask, test_y, test_mask
+
+     # Helper Methods ##########################################################
+
+     def _sample_target_set(
+         self,
+         query: ValidatedPredictiveQuery,
+         pkey: pd.Series,
+         index: np.ndarray,
+         anchor_time: pd.Series,
+         num_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         batch_size: int = 10_000,
+     ) -> tuple[pd.Series, np.ndarray]:
+
+         num_hops = 1 if len(time_offset_dict) > 0 else 0
+         num_neighbors_dict: dict[str, list[int]] = {}
+         unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
+         for edge_type, (start, end) in time_offset_dict.items():
+             unix_time_offset_dict['__'.join(edge_type)] = [[
+                 date_offset_to_seconds(start) if start is not None else None,
+                 date_offset_to_seconds(end),
+             ]]
+         for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
+             num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
+
+         count = 0
+         ys: list[pd.Series] = []
+         mask = np.full(len(index), False, dtype=bool)
+         for start in range(0, len(index), batch_size):
+             subset = pkey.iloc[index[start:start + batch_size]]
+             time = anchor_time.iloc[start:start + batch_size]
+
+             _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+                 num_neighbors_dict,
+                 unix_time_offset_dict,
+                 query.entity_table,
+                 self._graph_store.get_node_id(query.entity_table, subset),
+                 time.astype(int).to_numpy() // 1000**3,  # to seconds
+             )
+
+             feat_dict: dict[str, pd.DataFrame] = {}
+             time_dict: dict[str, pd.Series] = {}
+             for table_name, columns in columns_dict.items():
+                 df = self._graph_store.df_dict[table_name]
+                 df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+                 df = df[list(columns)]
+                 feat_dict[table_name] = df
+
+                 time_column = self.time_column_dict.get(table_name)
+                 if time_column in columns:
+                     time_dict[table_name] = df[time_column]
+
+             y, _mask = PQueryPandasExecutor().execute(
+                 query=query,
+                 feat_dict=feat_dict,
+                 time_dict=time_dict,
+                 batch_dict=batch_dict,
+                 anchor_time=time,
+                 num_forecasts=query.num_forecasts,
+             )
+             ys.append(y)
+             mask[start:start + batch_size] = _mask
+
+             count += len(y)
+             if count >= num_examples:
+                 break
+
+         if len(ys) == 0:
+             y = pd.Series([], dtype=float)
+         elif len(ys) == 1:
+             y = ys[0]
+         else:
+             y = pd.concat(ys, axis=0, ignore_index=True)
+
+         return y, mask
+
+
+ # Helper Functions ############################################################
+
+
+ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
+     r"""Convert a :class:`pandas.DateOffset` into a number of seconds.
+
+     .. note::
+         We are conservative and take months and years as their maximum value.
+         Additional values are then dropped in label computation where we know
+         the actual dates.
+     """
+     MAX_DAYS_IN_MONTH = 31
+     MAX_DAYS_IN_YEAR = 366
+
+     SECONDS_IN_MINUTE = 60
+     SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
+     SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
+
+     total_sec = 0
+     multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
+
+     for attr, value in offset.__dict__.items():
+         if value is None or value == 0:
+             continue
+         scaled_value = value * multiplier
+         if attr == 'years':
+             total_sec += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
+         elif attr == 'months':
+             total_sec += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
+         elif attr == 'days':
+             total_sec += scaled_value * SECONDS_IN_DAY
+         elif attr == 'hours':
+             total_sec += scaled_value * SECONDS_IN_HOUR
+         elif attr == 'minutes':
+             total_sec += scaled_value * SECONDS_IN_MINUTE
+         elif attr == 'seconds':
+             total_sec += scaled_value
+
+     return total_sec
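As a quick illustration of `date_offset_to_seconds` above: the conversion is deliberately conservative, counting months as 31 days and years as 366 days, with the offset's `n` multiplier scaling each component. A small usage sketch, assuming the local backend imports in your environment (otherwise the package-level guard shown earlier raises a `RuntimeError`):

    import pandas as pd

    from kumoai.experimental.rfm.backend.local.sampler import date_offset_to_seconds

    # One day -> 86,400 seconds.
    assert date_offset_to_seconds(pd.DateOffset(days=1)) == 24 * 60 * 60

    # Months are taken at their maximum length of 31 days:
    # two months -> 2 * 31 * 86,400 = 5,356,800 seconds.
    assert date_offset_to_seconds(pd.DateOffset(months=2)) == 2 * 31 * 24 * 60 * 60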