kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +31 -14
  10. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +75 -23
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +71 -28
  16. kumoai/experimental/rfm/base/__init__.py +24 -3
  17. kumoai/experimental/rfm/base/column.py +6 -12
  18. kumoai/experimental/rfm/base/column_expression.py +16 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +1 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +113 -0
  23. kumoai/experimental/rfm/base/table.py +136 -105
  24. kumoai/experimental/rfm/graph.py +296 -89
  25. kumoai/experimental/rfm/infer/dtype.py +46 -59
  26. kumoai/experimental/rfm/infer/pkey.py +4 -2
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +4 -2
  38. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +41 -34
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0

kumoai/experimental/rfm/backend/snow/table.py
@@ -1,20 +1,34 @@
 import re
-from typing import List, Optional, Sequence
+from collections.abc import Sequence
+from typing import cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 from kumoapi.typing import Dtype
 
-from kumoai.experimental.rfm.backend.sqlite import Connection
-from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+from kumoai.experimental.rfm.backend.snow import Connection
+from kumoai.experimental.rfm.base import (
+    ColumnExpressionType,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+    SQLTable,
+)
+from kumoai.utils import quote_ident
 
 
-class SnowTable(Table):
+class SnowTable(SQLTable):
    r"""A table backed by a :class:`snowflake` database.
 
    Args:
        connection: The connection to a :class:`snowflake` database.
-       name: The name of this table.
-       columns: The selected columns of this table.
+       name: The logical name of this table.
+       source_name: The physical name of this table in the database. If set to
+           ``None``, ``name`` is being used.
+       database: The database.
+       schema: The schema.
+       columns: The selected physical columns of this table.
+       column_expressions: The logical columns of this table.
        primary_key: The name of the primary key of this table, if it exists.
        time_column: The name of the time column of this table, if it exists.
        end_time_column: The name of the end time column of this table, if it
@@ -24,32 +38,67 @@ class SnowTable(Table):
        self,
        connection: Connection,
        name: str,
-       columns: Optional[Sequence[str]] = None,
-       primary_key: Optional[str] = None,
-       time_column: Optional[str] = None,
-       end_time_column: Optional[str] = None,
+       source_name: str | None = None,
+       database: str | None = None,
+       schema: str | None = None,
+       columns: Sequence[str] | None = None,
+       column_expressions: Sequence[ColumnExpressionType] | None = None,
+       primary_key: MissingType | str | None = MissingType.VALUE,
+       time_column: str | None = None,
+       end_time_column: str | None = None,
    ) -> None:
 
+       if database is not None and schema is None:
+           raise ValueError(f"Unspecified 'schema' for table "
+                            f"'{source_name or name}' in database "
+                            f"'{database}'")
+
        self._connection = connection
+       self._database = database
+       self._schema = schema
 
        super().__init__(
            name=name,
+           source_name=source_name,
            columns=columns,
+           column_expressions=column_expressions,
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
        )
 
-   def _get_source_columns(self) -> List[SourceColumn]:
-       source_columns: List[SourceColumn] = []
+   @property
+   def backend(self) -> DataBackend:
+       return cast(DataBackend, DataBackend.SNOWFLAKE)
+
+   @property
+   def fqn(self) -> str:
+       r"""The fully-qualified quoted table name."""
+       names: list[str] = []
+       if self._database is not None:
+           names.append(quote_ident(self._database))
+       if self._schema is not None:
+           names.append(quote_ident(self._schema))
+       return '.'.join(names + [quote_ident(self._source_name)])
+
+   def _get_source_columns(self) -> list[SourceColumn]:
+       source_columns: list[SourceColumn] = []
        with self._connection.cursor() as cursor:
            try:
-               cursor.execute(f"DESCRIBE TABLE {self.name}")
+               sql = f"DESCRIBE TABLE {self.fqn}"
+               cursor.execute(sql)
            except Exception as e:
-               raise ValueError(f"Table '{self.name}' does not exist") from e
+               names: list[str] = []
+               if self._database is not None:
+                   names.append(self._database)
+               if self._schema is not None:
+                   names.append(self._schema)
+               source_name = '.'.join(names + [self._source_name])
+               raise ValueError(f"Table '{source_name}' does not exist in "
+                                f"the remote data backend") from e
 
            for row in cursor.fetchall():
-               column, type, _, _, _, is_pkey, is_unique = row[:7]
+               column, type, _, null, _, is_pkey, is_unique, *_ = row
 
                type = type.strip().upper()
                if type.startswith('NUMBER'):
@@ -70,26 +119,29 @@ class SnowTable(Table):
                    dtype=dtype,
                    is_primary_key=is_pkey.strip().upper() == 'Y',
                    is_unique_key=is_unique.strip().upper() == 'Y',
+                   is_nullable=null.strip().upper() == 'Y',
                )
                source_columns.append(source_column)
 
        return source_columns
 
-   def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-       source_fkeys: List[SourceForeignKey] = []
+   def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+       source_fkeys: list[SourceForeignKey] = []
        with self._connection.cursor() as cursor:
-           cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.name}")
+           sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+           cursor.execute(sql)
            for row in cursor.fetchall():
-               _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
+               _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
        return source_fkeys
 
    def _get_sample_df(self) -> pd.DataFrame:
        with self._connection.cursor() as cursor:
-           columns = ', '.join(self._source_column_dict.keys())
-           cursor.execute(f"SELECT {columns} FROM {self.name} LIMIT 1000")
+           columns = [quote_ident(col) for col in self._source_column_dict]
+           sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+           cursor.execute(sql)
            table = cursor.fetch_arrow_all()
-           return table.to_pandas()
+           return table.to_pandas(types_mapper=pd.ArrowDtype)
 
-   def _get_num_rows(self) -> Optional[int]:
+   def _get_num_rows(self) -> int | None:
        return None
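
Note: for orientation, a minimal sketch of the new SnowTable constructor surface shown above. The connection object and all database/schema/table/column names are placeholders, and the import path is inferred from the file location rather than stated in this diff:

    from kumoai.experimental.rfm.backend.snow.table import SnowTable

    # `conn` stands in for an already-opened snowflake `Connection`.
    table = SnowTable(
        connection=conn,
        name="users",             # logical name used by the graph
        source_name="USERS_V2",   # physical table name in the database
        database="ANALYTICS",
        schema="PUBLIC",
        primary_key="USER_ID",
    )
    # DESCRIBE/SELECT statements are now issued against `table.fqn`, i.e. the
    # quoted '<database>.<schema>.<source_name>' identifier built above.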

kumoai/experimental/rfm/backend/sqlite/__init__.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, TypeAlias, Union
+from typing import Any, TypeAlias
 
 try:
     import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
 Connection: TypeAlias = adbc.AdbcSqliteConnection
 
 
-def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
+def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
    r"""Opens a connection to a :class:`sqlite` database.
 
    uri: The path to the database file to be opened.
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
 
 
 from .table import SQLiteTable  # noqa: E402
+from .sampler import SQLiteSampler  # noqa: E402
 
 __all__ = [
     'connect',
     'Connection',
     'SQLiteTable',
+    'SQLiteSampler',
 ]
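
Note: the updated connect signature accepts a plain string, a pathlib.Path, or None. A small usage sketch (the file name is a placeholder):

    from pathlib import Path

    from kumoai.experimental.rfm.backend.sqlite import connect

    conn = connect(Path("example.db"))  # placeholder path; a str also works
    with conn.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())
    conn.close()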

kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -0,0 +1,349 @@
+import warnings
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.typing import Stype
+
+from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+class SQLiteSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        if optimize:
+            with self._connection.cursor() as cursor:
+                cursor.execute("PRAGMA temp_store = MEMORY")
+                cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
+
+        # Collect database indices to speed-up sampling:
+        index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+        for table_name, primary_key in self.primary_key_dict.items():
+            source_table = self.source_table_dict[table_name]
+            if not source_table[primary_key].is_unique_key:
+                index_dict[table_name].add((primary_key, ))
+        for src_table_name, foreign_key, _ in graph.edges:
+            source_table = self.source_table_dict[src_table_name]
+            if source_table[foreign_key].is_unique_key:
+                pass
+            elif time_column := self.time_column_dict.get(src_table_name):
+                index_dict[src_table_name].add((foreign_key, time_column))
+            else:
+                index_dict[src_table_name].add((foreign_key, ))
+
+        # Only maintain missing indices:
+        with self._connection.cursor() as cursor:
+            for table_name in list(index_dict.keys()):
+                indices = index_dict[table_name]
+                sql = f"PRAGMA index_list({self.fqn_dict[table_name]})"
+                cursor.execute(sql)
+                for _, index_name, *_ in cursor.fetchall():
+                    sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                    cursor.execute(sql)
+                    index = tuple(info[2] for info in sorted(
+                        cursor.fetchall(), key=lambda x: x[0]))
+                    indices.discard(index)
+                if len(indices) == 0:
+                    del index_dict[table_name]
+
+        num = sum(len(indices) for indices in index_dict.values())
+        index_repr = '1 index' if num == 1 else f'{num} indices'
+        num = len(index_dict)
+        table_repr = '1 table' if num == 1 else f'{num} tables'
+
+        if optimize and len(index_dict) > 0:
+            if not isinstance(verbose, ProgressLogger):
+                verbose = ProgressLogger.default(
+                    msg="Optimizing SQLite database",
+                    verbose=verbose,
+                )
+
+            with verbose as logger, self._connection.cursor() as cursor:
+                for table_name, indices in index_dict.items():
+                    for index in indices:
+                        name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                        name = quote_ident(name)
+                        columns = ', '.join(quote_ident(v) for v in index)
+                        columns += ' DESC' if len(index) > 1 else ''
+                        sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                               f"ON {self.fqn_dict[table_name]}({columns})")
+                        cursor.execute(sql)
+                self._connection.commit()
+                logger.log(f"Created {index_repr} in {table_repr}")
+
+        elif len(index_dict) > 0:
+            warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                          f"database querying. For improving runtime, we "
+                          f"strongly suggest to create these indices by "
+                          f"instantiating KumoRFM via "
+                          f"`KumoRFM(graph, optimize=True)`.")
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            time_column = self.time_column_dict[table_name]
+            select = (f"SELECT\n"
+                      f"  ? as table_name,\n"
+                      f"  MIN({quote_ident(time_column)}) as min_date,\n"
+                      f"  MAX({quote_ident(time_column)}) as max_date\n"
+                      f"FROM {self.fqn_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql, table_names)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
+        return out_dict
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE SQLite does not natively support passing a `random_seed`.
+
+        filters: list[str] = []
+        primary_key = self.primary_key_dict[table_name]
+        if self.source_table_dict[table_name][primary_key].is_nullable:
+            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+        time_column = self.time_column_dict.get(table_name)
+        if (time_column is not None and
+                self.source_table_dict[table_name][time_column].is_nullable):
+            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+        # TODO Make this query more efficient - it does full table scan.
+        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+               f"FROM {self.fqn_dict[table_name]}")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+        sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        return self._sanitize(table_name, table)
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        train_y, train_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        pkey: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pkey_name = self.primary_key_dict[table_name]
+
+        tmp = pa.table([pa.array(pkey)], names=['id'])
+        tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+
+        if self.source_table_dict[table_name][pkey_name].is_unique_key:
+            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                   f"FROM {quote_ident(tmp_name)} tmp\n"
+                   f"JOIN {self.fqn_dict[table_name]} ent\n"
+                   f"  ON ent.{quote_ident(pkey_name)} = tmp.id")
+        else:
+            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                   f"FROM {quote_ident(tmp_name)} tmp\n"
+                   f"JOIN {self.fqn_dict[table_name]} ent\n"
+                   f"  ON ent.rowid = (\n"
+                   f"    SELECT rowid FROM {self.fqn_dict[table_name]}\n"
+                   f"    WHERE {quote_ident(pkey_name)} == tmp.id\n"
+                   f"    LIMIT 1\n"
+                   f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__batch__'].to_numpy()
+        table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+        return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        fkey: str,
+        pkey: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(pkey)], names=['id'])
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        tmp = tmp.append_column('end', pa.array(end_time))
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('start', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+
+        time_column = self.time_column_dict[table_name]
+        sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+               f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.fqn_dict[table_name]} fact\n"
+               f"  ON fact.{quote_ident(fkey)} = tmp.id\n"
+               f"  AND fact.{quote_ident(time_column)} <= tmp.end")
+        if min_offset is not None:
+            sql += f"\n  AND fact.{quote_ident(time_column)} > tmp.start"
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__batch__'].to_numpy()
+        table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+        return self._sanitize(table_name, table), batch
+
+    def _sample_target_set(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        index: np.ndarray,
+        anchor_time: pd.Series,
+        num_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        batch_size: int = 10_000,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            df = entity_df.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
+
+            feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+            time_dict: dict[str, pd.Series] = {}
+            time_column = self.time_column_dict.get(query.entity_table)
+            if time_column in columns_dict[query.entity_table]:
+                time_dict[query.entity_table] = df[time_column]
+            batch_dict: dict[str, np.ndarray] = {
+                query.entity_table: np.arange(len(df)),
+            }
+            for edge_type, (_min, _max) in time_offset_dict.items():
+                table_name, fkey, _ = edge_type
+                feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                    table_name=table_name,
+                    fkey=fkey,
+                    pkey=df[self.primary_key_dict[query.entity_table]],
+                    anchor_time=time,
+                    min_offset=_min,
+                    max_offset=_max,
+                    columns=columns_dict[table_name],
+                )
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns_dict[table_name]:
+                    time_dict[table_name] = feat_dict[table_name][time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
+
+    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
+        df = table.to_pandas(types_mapper=pd.ArrowDtype)
+
+        stype_dict = self.table_stype_dict[table_name]
+        for column_name in df.columns:
+            if stype_dict.get(column_name) == Stype.timestamp:
+                df[column_name] = pd.to_datetime(df[column_name])
+
+        return df
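
Note: per the warning emitted in __init__ above, the new optimize flag is surfaced through KumoRFM. A rough sketch, assuming KumoRFM is importable from kumoai.experimental.rfm and that graph is an already-constructed Graph:

    from kumoai.experimental.rfm import KumoRFM

    # With optimize=True, the SQLite sampler creates the missing indices
    # (and sets the PRAGMAs) instead of emitting the missing-index warning.
    model = KumoRFM(graph, optimize=True)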