kumoai 2.10.0.dev202509231831__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512161731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kumoai might be problematic.

Files changed (53)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +37 -8
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
  12. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  13. kumoai/experimental/rfm/backend/local/table.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
  20. kumoai/experimental/rfm/base/__init__.py +25 -0
  21. kumoai/experimental/rfm/base/column.py +66 -0
  22. kumoai/experimental/rfm/base/sampler.py +773 -0
  23. kumoai/experimental/rfm/base/source.py +19 -0
  24. kumoai/experimental/rfm/base/sql_sampler.py +60 -0
  25. kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
  26. kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
  27. kumoai/experimental/rfm/infer/__init__.py +6 -0
  28. kumoai/experimental/rfm/infer/dtype.py +79 -0
  29. kumoai/experimental/rfm/infer/pkey.py +126 -0
  30. kumoai/experimental/rfm/infer/time_col.py +62 -0
  31. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  32. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  33. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  34. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
  35. kumoai/experimental/rfm/rfm.py +669 -246
  36. kumoai/experimental/rfm/sagemaker.py +138 -0
  37. kumoai/jobs.py +1 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/trainer.py +12 -10
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +239 -4
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
  47. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -176
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -404
  50. kumoai/experimental/rfm/utils.py +0 -344
  51. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
  52. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
  53. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
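
The largest change in this release is the restructuring of kumoai/experimental/rfm: the former local_graph.py/local_table.py modules move into backend-specific packages (backend/local, backend/snow, backend/sqlite) built on shared base/ abstractions, with metadata inference split out into infer/. A minimal sketch of the new import surface, pieced together from the hunks below (the commented-out Graph construction is hypothetical; its exact constructor is not shown in this diff):

    import pandas as pd

    import kumoai.experimental.rfm as rfm
    # Backend-specific table classes now live under `backend/`:
    from kumoai.experimental.rfm.backend.local import LocalTable

    df = pd.read_csv("users.csv")

    # `LocalTable` keeps its pandas-backed constructor and metadata inference
    # (see backend/local/table.py below):
    users = rfm.LocalTable(df, name="users", primary_key="user_id")
    users.infer_metadata()

    # Hypothetical: the renamed `Graph` (formerly `LocalGraph`) ties tables
    # together; its construction/link API is not part of these hunks.
    # graph = rfm.Graph(...)
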
kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py}

@@ -1,14 +1,13 @@
  import warnings
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

  import numpy as np
  import pandas as pd
  from kumoapi.rfm.context import Subgraph
  from kumoapi.typing import Stype

- from kumoai.experimental.rfm import LocalGraph
- from kumoai.experimental.rfm.utils import normalize_text
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.experimental.rfm.backend.local import LocalTable
+ from kumoai.utils import ProgressLogger

  try:
      import torch
@@ -16,42 +15,40 @@ try:
  except ImportError:
      WITH_TORCH = False

+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+

  class LocalGraphStore:
      def __init__(
          self,
-         graph: LocalGraph,
-         preprocess: bool = False,
+         graph: 'Graph',
          verbose: Union[bool, ProgressLogger] = True,
      ) -> None:

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(
-                 "Materializing graph",
+             verbose = ProgressLogger.default(
+                 msg="Materializing graph",
                  verbose=verbose,
              )

          with verbose as logger:
-             self.df_dict, self.mask_dict = self.sanitize(graph, preprocess)
-             self.stype_dict = self.get_stype_dict(graph)
+             self.df_dict, self.mask_dict = self.sanitize(graph)
              logger.log("Sanitized input data")

-             self.pkey_name_dict, self.pkey_map_dict = self.get_pkey_data(graph)
+             self.pkey_map_dict = self.get_pkey_map_dict(graph)
              num_pkeys = sum(t.has_primary_key() for t in graph.tables.values())
              if num_pkeys > 1:
                  logger.log(f"Collected primary keys from {num_pkeys} tables")
              else:
                  logger.log(f"Collected primary key from {num_pkeys} table")

-             (
-                 self.time_column_dict,
-                 self.time_dict,
-                 self.min_time,
-                 self.max_time,
-             ) = self.get_time_data(graph)
-             if self.max_time != pd.Timestamp.min:
+             self.time_dict, self.min_max_time_dict = self.get_time_data(graph)
+             if len(self.min_max_time_dict) > 0:
+                 min_time = min(t for t, _ in self.min_max_time_dict.values())
+                 max_time = max(t for _, t in self.min_max_time_dict.values())
                  logger.log(f"Identified temporal graph from "
-                            f"{self.min_time.date()} to {self.max_time.date()}")
+                            f"{min_time.date()} to {max_time.date()}")
              else:
                  logger.log("Identified static graph without timestamps")

@@ -61,14 +58,6 @@ class LocalGraphStore:
              logger.log(f"Created graph with {num_nodes:,} nodes and "
                         f"{num_edges:,} edges")

-     @property
-     def node_types(self) -> List[str]:
-         return list(self.df_dict.keys())
-
-     @property
-     def edge_types(self) -> List[Tuple[str, str, str]]:
-         return list(self.row_dict.keys())
-
      def get_node_id(self, table_name: str, pkey: pd.Series) -> np.ndarray:
          r"""Returns the node ID given primary keys.

@@ -104,8 +93,7 @@

      def sanitize(
          self,
-         graph: LocalGraph,
-         preprocess: bool = False,
+         graph: 'Graph',
      ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
          r"""Sanitizes raw data according to table schema definition:

@@ -114,17 +102,12 @@
          * drops timezone information from timestamps
          * drops duplicate primary keys
          * removes rows with missing primary keys or time values
-
-         If ``preprocess`` is set to ``True``, it will additionally pre-process
-         data for faster model processing. In particular, it:
-         * tokenizes any text column that is not a foreign key
          """
-         df_dict: Dict[str, pd.DataFrame] = {
-             table_name: table._data.copy(deep=False).reset_index(drop=True)
-             for table_name, table in graph.tables.items()
-         }
-
-         foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+         df_dict: Dict[str, pd.DataFrame] = {}
+         for table_name, table in graph.tables.items():
+             assert isinstance(table, LocalTable)
+             df = table._data
+             df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)

          mask_dict: Dict[str, np.ndarray] = {}
          for table in graph.tables.values():
@@ -143,12 +126,6 @@
                      ser = ser.dt.tz_localize(None)
                  df_dict[table.name][col.name] = ser

-                 # Normalize text in advance (but exclude foreign keys):
-                 if (preprocess and col.stype == Stype.text
-                         and (table.name, col.name) not in foreign_keys):
-                     ser = df_dict[table.name][col.name]
-                     df_dict[table.name][col.name] = normalize_text(ser)
-
              mask: Optional[np.ndarray] = None
              if table._time_column is not None:
                  ser = df_dict[table.name][table._time_column]
@@ -164,34 +141,16 @@

          return df_dict, mask_dict

-     def get_stype_dict(self, graph: LocalGraph) -> Dict[str, Dict[str, Stype]]:
-         stype_dict: Dict[str, Dict[str, Stype]] = {}
-         foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
-         for table in graph.tables.values():
-             stype_dict[table.name] = {}
-             for column in table.columns:
-                 if column == table.primary_key:
-                     continue
-                 if (table.name, column.name) in foreign_keys:
-                     continue
-                 stype_dict[table.name][column.name] = column.stype
-         return stype_dict
-
-     def get_pkey_data(
+     def get_pkey_map_dict(
          self,
-         graph: LocalGraph,
-     ) -> Tuple[
-         Dict[str, str],
-         Dict[str, pd.DataFrame],
-     ]:
-         pkey_name_dict: Dict[str, str] = {}
+         graph: 'Graph',
+     ) -> Dict[str, pd.DataFrame]:
          pkey_map_dict: Dict[str, pd.DataFrame] = {}

          for table in graph.tables.values():
              if table._primary_key is None:
                  continue

-             pkey_name_dict[table.name] = table._primary_key
              pkey = self.df_dict[table.name][table._primary_key]
              pkey_map = pd.DataFrame(
                  dict(arange=range(len(pkey))),
@@ -201,49 +160,53 @@
                  pkey_map = pkey_map[self.mask_dict[table.name]]

              if len(pkey_map) == 0:
-                 raise ValueError(
-                     f"Found no valid rows in table '{table.name}' since there "
-                     f"exists not a single row with a non-N/A primary key."
-                     f"Consider fixing your underlying data or removing this "
-                     f"table from the graph.")
+                 error_msg = f"Found no valid rows in table '{table.name}'. "
+                 if table.has_time_column():
+                     error_msg += ("Please make sure that there exists valid "
+                                   "non-N/A primary key and time column pairs "
+                                   "in this table.")
+                 else:
+                     error_msg += ("Please make sure that there exists valid "
+                                   "non-N/A primary keys in this table.")
+                 raise ValueError(error_msg)

              pkey_map_dict[table.name] = pkey_map

-         return pkey_name_dict, pkey_map_dict
+         return pkey_map_dict

      def get_time_data(
          self,
-         graph: LocalGraph,
+         graph: 'Graph',
      ) -> Tuple[
-         Dict[str, str],
          Dict[str, np.ndarray],
-         pd.Timestamp,
-         pd.Timestamp,
+         Dict[str, Tuple[pd.Timestamp, pd.Timestamp]],
      ]:
-         time_column_dict: Dict[str, str] = {}
          time_dict: Dict[str, np.ndarray] = {}
-         min_time = pd.Timestamp.max
-         max_time = pd.Timestamp.min
+         min_max_time_dict: Dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
          for table in graph.tables.values():
              if table._time_column is None:
                  continue

              time = self.df_dict[table.name][table._time_column]
-             time_dict[table.name] = time.astype('datetime64[ns]').astype(
-                 int).to_numpy() // 1000**3
-             time_column_dict[table.name] = table._time_column
+             if time.dtype != 'datetime64[ns]':
+                 time = time.astype('datetime64[ns]')
+             time_dict[table.name] = time.astype(int).to_numpy() // 1000**3

              if table.name in self.mask_dict.keys():
                  time = time[self.mask_dict[table.name]]
              if len(time) > 0:
-                 min_time = min(min_time, time.min())
-                 max_time = max(max_time, time.max())
+                 min_max_time_dict[table.name] = (time.min(), time.max())
+             else:
+                 min_max_time_dict[table.name] = (
+                     pd.Timestamp.max,
+                     pd.Timestamp.min,
+                 )

-         return time_column_dict, time_dict, min_time, max_time
+         return time_dict, min_max_time_dict

      def get_csc(
          self,
-         graph: LocalGraph,
+         graph: 'Graph',
      ) -> Tuple[
          Dict[Tuple[str, str, str], np.ndarray],
          Dict[Tuple[str, str, str], np.ndarray],
kumoai/experimental/rfm/backend/local/sampler.py (new file)

@@ -0,0 +1,315 @@
+ from typing import TYPE_CHECKING, Literal
+
+ import numpy as np
+ import pandas as pd
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
+ from kumoai.experimental.rfm.base import Sampler, SamplerOutput
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
+
+
+ class LocalSampler(Sampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         import kumoai.kumolib as kumolib
+
+         self._graph_store = LocalGraphStore(graph, verbose)
+         self._graph_sampler = kumolib.NeighborSampler(
+             list(self.table_stype_dict.keys()),
+             self.edge_types,
+             {
+                 '__'.join(edge_type): colptr
+                 for edge_type, colptr in self._graph_store.colptr_dict.items()
+             },
+             {
+                 '__'.join(edge_type): row
+                 for edge_type, row in self._graph_store.row_dict.items()
+             },
+             self._graph_store.time_dict,
+         )
+
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         return {
+             key: value
+             for key, value in self._graph_store.min_max_time_dict.items()
+             if key in table_names
+         }
+
+     def _sample_subgraph(
+         self,
+         entity_table_name: str,
+         entity_pkey: pd.Series,
+         anchor_time: pd.Series | Literal['entity'],
+         columns_dict: dict[str, set[str]],
+         num_neighbors: list[int],
+     ) -> SamplerOutput:
+
+         index = self._graph_store.get_node_id(entity_table_name, entity_pkey)
+
+         if isinstance(anchor_time, pd.Series):
+             time = anchor_time.astype(int).to_numpy() // 1000**3  # to seconds
+         else:
+             assert anchor_time == 'entity'
+             time = self._graph_store.time_dict[entity_table_name][index]
+
+         (
+             row_dict,
+             col_dict,
+             node_dict,
+             batch_dict,
+             num_sampled_nodes_dict,
+             num_sampled_edges_dict,
+         ) = self._graph_sampler.sample(
+             {
+                 '__'.join(edge_type): num_neighbors
+                 for edge_type in self.edge_types
+             },
+             {},
+             entity_table_name,
+             index,
+             time,
+         )
+
+         df_dict: dict[str, pd.DataFrame] = {}
+         inverse_dict: dict[str, np.ndarray] = {}
+         for table_name, node in node_dict.items():
+             df = self._graph_store.df_dict[table_name]
+             columns = columns_dict[table_name]
+             if self.end_time_column_dict.get(table_name, None) in columns:
+                 df = df.iloc[node]
+             elif len(columns) == 0:
+                 df = df.iloc[node]
+             else:
+                 # Only store unique rows in `df` above a certain threshold:
+                 unique_node, inverse = np.unique(node, return_inverse=True)
+                 if len(node) > 1.05 * len(unique_node):
+                     df = df.iloc[unique_node]
+                     inverse_dict[table_name] = inverse
+                 else:
+                     df = df.iloc[node]
+             df = df.reset_index(drop=True)
+             df = df[list(columns)]
+             df_dict[table_name] = df
+
+         num_sampled_nodes_dict = {
+             table_name: num_sampled_nodes.tolist()
+             for table_name, num_sampled_nodes in
+             num_sampled_nodes_dict.items()
+         }
+
+         row_dict = {
+             edge_type: row_dict['__'.join(edge_type)]
+             for edge_type in self.edge_types
+         }
+         col_dict = {
+             edge_type: col_dict['__'.join(edge_type)]
+             for edge_type in self.edge_types
+         }
+         num_sampled_edges_dict = {
+             edge_type: num_sampled_edges_dict['__'.join(edge_type)].tolist()
+             for edge_type in self.edge_types
+         }
+
+         return SamplerOutput(
+             anchor_time=time * 1000**3,  # to nanoseconds
+             df_dict=df_dict,
+             inverse_dict=inverse_dict,
+             batch_dict=batch_dict,
+             num_sampled_nodes_dict=num_sampled_nodes_dict,
+             row_dict=row_dict,
+             col_dict=col_dict,
+             num_sampled_edges_dict=num_sampled_edges_dict,
+         )
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         pkey_map = self._graph_store.pkey_map_dict[table_name]
+         if len(pkey_map) > num_rows:
+             pkey_map = pkey_map.sample(
+                 n=num_rows,
+                 random_state=random_seed,
+                 ignore_index=True,
+             )
+         df = self._graph_store.df_dict[table_name]
+         df = df.iloc[pkey_map['arange']][list(columns)]
+         return df
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         train_y, train_mask = self._sample_target_set(
+             query=query,
+             pkey=entity_df[self.primary_key_dict[query.entity_table]],
+             index=train_index,
+             anchor_time=train_time,
+             num_examples=num_train_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         test_y, test_mask = self._sample_target_set(
+             query=query,
+             pkey=entity_df[self.primary_key_dict[query.entity_table]],
+             index=test_index,
+             anchor_time=test_time,
+             num_examples=num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         return train_y, train_mask, test_y, test_mask
+
+     # Helper Methods ##########################################################
+
+     def _sample_target_set(
+         self,
+         query: ValidatedPredictiveQuery,
+         pkey: pd.Series,
+         index: np.ndarray,
+         anchor_time: pd.Series,
+         num_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+         batch_size: int = 10_000,
+     ) -> tuple[pd.Series, np.ndarray]:
+
+         num_hops = 1 if len(time_offset_dict) > 0 else 0
+         num_neighbors_dict: dict[str, list[int]] = {}
+         unix_time_offset_dict: dict[str, list[list[int | None]]] = {}
+         for edge_type, (start, end) in time_offset_dict.items():
+             unix_time_offset_dict['__'.join(edge_type)] = [[
+                 date_offset_to_seconds(start) if start is not None else None,
+                 date_offset_to_seconds(end),
+             ]]
+         for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
+             num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
+
+         if anchor_time.dtype != 'datetime64[ns]':
+             anchor_time = anchor_time.astype('datetime64')
+
+         count = 0
+         ys: list[pd.Series] = []
+         mask = np.full(len(index), False, dtype=bool)
+         for start in range(0, len(index), batch_size):
+             subset = pkey.iloc[index[start:start + batch_size]]
+             time = anchor_time.iloc[start:start + batch_size]
+
+             _, _, node_dict, batch_dict, _, _ = self._graph_sampler.sample(
+                 num_neighbors_dict,
+                 unix_time_offset_dict,
+                 query.entity_table,
+                 self._graph_store.get_node_id(query.entity_table, subset),
+                 time.astype(int).to_numpy() // 1000**3,  # to seconds
+             )
+
+             feat_dict: dict[str, pd.DataFrame] = {}
+             time_dict: dict[str, pd.Series] = {}
+             for table_name, columns in columns_dict.items():
+                 df = self._graph_store.df_dict[table_name]
+                 df = df.iloc[node_dict[table_name]].reset_index(drop=True)
+                 df = df[list(columns)]
+                 feat_dict[table_name] = df
+
+                 time_column = self.time_column_dict.get(table_name)
+                 if time_column in columns:
+                     time_dict[table_name] = df[time_column]
+
+             y, _mask = PQueryPandasExecutor().execute(
+                 query=query,
+                 feat_dict=feat_dict,
+                 time_dict=time_dict,
+                 batch_dict=batch_dict,
+                 anchor_time=time,
+                 num_forecasts=query.num_forecasts,
+             )
+             ys.append(y)
+             mask[start:start + batch_size] = _mask
+
+             count += len(y)
+             if count >= num_examples:
+                 break
+
+         if len(ys) == 0:
+             y = pd.Series([], dtype=float)
+         elif len(ys) == 1:
+             y = ys[0]
+         else:
+             y = pd.concat(ys, axis=0, ignore_index=True)
+
+         return y, mask
+
+
+ # Helper Functions ############################################################
+
+
+ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
+     r"""Convert a :class:`pandas.DateOffset` into a number of seconds.
+
+     .. note::
+         We are conservative and take months and years as their maximum value.
+         Additional values are then dropped in label computation where we know
+         the actual dates.
+     """
+     MAX_DAYS_IN_MONTH = 31
+     MAX_DAYS_IN_YEAR = 366
+
+     SECONDS_IN_MINUTE = 60
+     SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
+     SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR
+
+     total_sec = 0
+     multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).
+
+     for attr, value in offset.__dict__.items():
+         if value is None or value == 0:
+             continue
+         scaled_value = value * multiplier
+         if attr == 'years':
+             total_sec += scaled_value * MAX_DAYS_IN_YEAR * SECONDS_IN_DAY
+         elif attr == 'months':
+             total_sec += scaled_value * MAX_DAYS_IN_MONTH * SECONDS_IN_DAY
+         elif attr == 'days':
+             total_sec += scaled_value * SECONDS_IN_DAY
+         elif attr == 'hours':
+             total_sec += scaled_value * SECONDS_IN_HOUR
+         elif attr == 'minutes':
+             total_sec += scaled_value * SECONDS_IN_MINUTE
+         elif attr == 'seconds':
+             total_sec += scaled_value
+
+     return total_sec
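
The date_offset_to_seconds helper added above deliberately over-approximates calendar units (months as 31 days, years as 366 days) so that the sampled time window is never too small; the docstring notes that exact dates are applied later during label computation. A small worked example under that convention (the import path simply mirrors the new module; the numbers are only illustrative):

    import pandas as pd

    from kumoai.experimental.rfm.backend.local.sampler import (
        date_offset_to_seconds,
    )

    # 2 months are counted as 2 * 31 days, plus 3 hours:
    offset = pd.DateOffset(months=2, hours=3)
    assert date_offset_to_seconds(offset) == 2 * 31 * 86_400 + 3 * 3_600
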
kumoai/experimental/rfm/backend/local/table.py (new file)

@@ -0,0 +1,119 @@
+ import warnings
+ from typing import List, Optional, cast
+
+ import pandas as pd
+
+ from kumoai.experimental.rfm.base import (
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     Table,
+ )
+ from kumoai.experimental.rfm.infer import infer_dtype
+
+
+ class LocalTable(Table):
+     r"""A table backed by a :class:`pandas.DataFrame`.
+
+     A :class:`LocalTable` fully specifies the relevant metadata, *i.e.*
+     selected columns, column semantic types, primary keys and time columns.
+     :class:`LocalTable` is used to create a :class:`Graph`.
+
+     .. code-block:: python
+
+         import pandas as pd
+         import kumoai.experimental.rfm as rfm
+
+         # Load data from a CSV file:
+         df = pd.read_csv("data.csv")
+
+         # Create a table from a `pandas.DataFrame` and infer its metadata ...
+         table = rfm.LocalTable(df, name="my_table").infer_metadata()
+
+         # ... or create a table explicitly:
+         table = rfm.LocalTable(
+             df=df,
+             name="my_table",
+             primary_key="id",
+             time_column="time",
+             end_time_column=None,
+         )
+
+         # Verify metadata:
+         table.print_metadata()
+
+         # Change the semantic type of a column:
+         table[column].stype = "text"
+
+     Args:
+         df: The data frame to create this table from.
+         name: The name of this table.
+         primary_key: The name of the primary key of this table, if it exists.
+         time_column: The name of the time column of this table, if it exists.
+         end_time_column: The name of the end time column of this table, if it
+             exists.
+     """
+     def __init__(
+         self,
+         df: pd.DataFrame,
+         name: str,
+         primary_key: Optional[str] = None,
+         time_column: Optional[str] = None,
+         end_time_column: Optional[str] = None,
+     ) -> None:
+
+         if df.empty:
+             raise ValueError("Data frame is empty")
+         if isinstance(df.columns, pd.MultiIndex):
+             raise ValueError("Data frame must not have a multi-index")
+         if not df.columns.is_unique:
+             raise ValueError("Data frame must have unique column names")
+         if any(col == '' for col in df.columns):
+             raise ValueError("Data frame must have non-empty column names")
+
+         self._data = df.copy(deep=False)
+
+         super().__init__(
+             name=name,
+             columns=list(df.columns),
+             primary_key=primary_key,
+             time_column=time_column,
+             end_time_column=end_time_column,
+         )
+
+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.LOCAL)
+
+     def _get_source_columns(self) -> List[SourceColumn]:
+         source_columns: List[SourceColumn] = []
+         for column in self._data.columns:
+             ser = self._data[column]
+             try:
+                 dtype = infer_dtype(ser)
+             except Exception:
+                 warnings.warn(f"Data type inference for column '{column}' in "
+                               f"table '{self.name}' failed. Consider changing "
+                               f"the data type of the column to use it within "
+                               f"this table.")
+                 continue
+
+             source_column = SourceColumn(
+                 name=column,
+                 dtype=dtype,
+                 is_primary_key=False,
+                 is_unique_key=False,
+                 is_nullable=True,
+             )
+             source_columns.append(source_column)
+
+         return source_columns
+
+     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+         return []
+
+     def _get_sample_df(self) -> pd.DataFrame:
+         return self._data
+
+     def _get_num_rows(self) -> Optional[int]:
+         return len(self._data)