kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
- kumoai/experimental/rfm/backend/snow/table.py +177 -50
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +782 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +247 -0
- kumoai/experimental/rfm/base/table.py +404 -203
- kumoai/experimental/rfm/graph.py +374 -172
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +762 -467
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/sqlite/__init__.py

@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, TypeAlias, Union
+from typing import Any, TypeAlias
 
 try:
     import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
     Connection: TypeAlias = adbc.AdbcSqliteConnection
 
 
-def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
+def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
     r"""Opens a connection to a :class:`sqlite` database.
 
     uri: The path to the database file to be opened.
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
 
 
 from .table import SQLiteTable  # noqa: E402
+from .sampler import SQLiteSampler  # noqa: E402
 
 __all__ = [
     'connect',
     'Connection',
     'SQLiteTable',
+    'SQLiteSampler',
 ]
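The change modernizes the `connect` signature from `Union[str, Path, None]` to PEP 604 union syntax and re-exports the new `SQLiteSampler`. A minimal usage sketch of the surface above; the `data.db` path is hypothetical:

    from pathlib import Path

    from kumoai.experimental.rfm.backend.sqlite import Connection, connect

    # Open an ADBC connection to a local SQLite file (path is illustrative):
    conn: Connection = connect(Path("data.db"))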
kumoai/experimental/rfm/backend/sqlite/sampler.py (new file)

@@ -0,0 +1,454 @@
+import warnings
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+
+from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+class SQLiteSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SQLiteTable)
+            self._connection = table._connection
+
+        if optimize:
+            with self._connection.cursor() as cursor:
+                cursor.execute("PRAGMA temp_store = MEMORY")
+                cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
+
+        # Collect database indices for speeding up sampling:
+        index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+        for table_name, primary_key in self.primary_key_dict.items():
+            source_table = self.source_table_dict[table_name]
+            if primary_key not in source_table:
+                continue  # No physical column.
+            if source_table[primary_key].is_unique_key:
+                continue
+            index_dict[table_name].add((primary_key, ))
+        for src_table_name, foreign_key, _ in graph.edges:
+            source_table = self.source_table_dict[src_table_name]
+            if foreign_key not in source_table:
+                continue  # No physical column.
+            if source_table[foreign_key].is_unique_key:
+                continue
+            time_column = self.time_column_dict.get(src_table_name)
+            if time_column is not None and time_column in source_table:
+                index_dict[src_table_name].add((foreign_key, time_column))
+            else:
+                index_dict[src_table_name].add((foreign_key, ))
+
+        # Only maintain missing indices:
+        with self._connection.cursor() as cursor:
+            for table_name in list(index_dict.keys()):
+                indices = index_dict[table_name]
+                source_name = self.source_name_dict[table_name]
+                sql = f"PRAGMA index_list({source_name})"
+                cursor.execute(sql)
+                for _, index_name, *_ in cursor.fetchall():
+                    sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                    cursor.execute(sql)
+                    # Fetch index information and sort by `seqno`:
+                    index_info = tuple(info[2] for info in sorted(
+                        cursor.fetchall(), key=lambda x: x[0]))
+                    # Discard candidates whose leading column is already indexed:
+                    for index in list(indices):
+                        if index_info[0] == index[0]:
+                            indices.discard(index)
+                if len(indices) == 0:
+                    del index_dict[table_name]
+
+        if optimize and len(index_dict) > 0:
+            if not isinstance(verbose, ProgressLogger):
+                verbose = ProgressLogger.default(
+                    msg="Optimizing SQLite database",
+                    verbose=verbose,
+                )
+
+            with verbose as logger, self._connection.cursor() as cursor:
+                for table_name, indices in index_dict.items():
+                    for index in indices:
+                        name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                        name = quote_ident(name)
+                        columns = ', '.join(quote_ident(v) for v in index)
+                        columns += ' DESC' if len(index) > 1 else ''
+                        source_name = self.source_name_dict[table_name]
+                        sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                               f"ON {source_name}({columns})")
+                        cursor.execute(sql)
+                        self._connection.commit()
+                        if len(index) > 1:
+                            logger.log(f"Created index on {index} in table "
+                                       f"'{table_name}'")
+                        else:
+                            logger.log(f"Created index on '{index[0]}' in "
+                                       f"table '{table_name}'")
+
+        elif len(index_dict) > 0:
+            num = sum(len(indices) for indices in index_dict.values())
+            index_repr = '1 index' if num == 1 else f'{num} indices'
+            num = len(index_dict)
+            table_repr = '1 table' if num == 1 else f'{num} tables'
+            warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                          f"database querying. To improve runtime, we "
+                          f"strongly suggest creating indices for primary "
+                          f"and foreign keys, e.g., automatically by "
+                          f"instantiating KumoRFM via "
+                          f"`KumoRFM(graph, optimize=True)`.")
+
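The optimization pass only creates indices that are actually missing. For a fact table whose foreign key is not a unique key and which carries a time column, it emits a composite index with the trailing column sorted descending. A sketch mirroring the string assembly above; the `orders`, `user_id`, and `ts` names are hypothetical:

    from kumoai.utils import quote_ident

    index = ('user_id', 'ts')  # hypothetical (foreign key, time column)
    name = quote_ident(f"kumo_index_orders_{'_'.join(index)}")
    columns = ', '.join(quote_ident(v) for v in index) + ' DESC'
    # Assuming quote_ident() double-quotes identifiers, this prints:
    #   CREATE INDEX IF NOT EXISTS "kumo_index_orders_user_id_ts"
    #   ON orders("user_id", "ts" DESC)
    print(f"CREATE INDEX IF NOT EXISTS {name}\nON orders({columns})")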
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            select = (f"SELECT\n"
+                      f"  ? as table_name,\n"
+                      f"  MIN({column_ref}) as min_date,\n"
+                      f"  MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql, table_names)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
+        return out_dict
+
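`_get_min_max_time_dict` fetches every table's time range in a single round trip and maps an empty table (where `MIN`/`MAX` return `NULL`) to the inverted sentinel pair (`pd.Timestamp.max`, `pd.Timestamp.min`), i.e. an empty range. For two hypothetical tables `users` (time column `signup_ts`) and `orders` (time column `ts`), the generated statement takes this shape, with each `?` bound to the table name:

    sql = ('SELECT\n'
           '  ? as table_name,\n'
           '  MIN("signup_ts") as min_date,\n'
           '  MAX("signup_ts") as max_date\n'
           'FROM users\n'
           'UNION ALL\n'
           'SELECT\n'
           '  ? as table_name,\n'
           '  MIN("ts") as min_date,\n'
           '  MAX("ts") as max_date\n'
           'FROM orders')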
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE SQLite does not natively support passing a `random_seed`.
+
+        source_table = self.source_table_dict[table_name]
+        filters: list[str] = []
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        # TODO Make this query more efficient - it does a full table scan.
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+        sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
+
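Because SQLite's `RANDOM()` cannot be seeded, entity rows are drawn by randomly ordering the filtered table and truncating, at the cost of a full table scan. For a hypothetical `users` table with a nullable primary key `user_id` and 1,000 requested rows, the assembled statement reads:

    # Hypothetical output of the SQL assembly above:
    sql = ('SELECT "user_id", "signup_ts"\n'
           'FROM users\n'
           'WHERE "user_id" IS NOT NULL\n'
           'ORDER BY RANDOM() LIMIT 1000')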
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        train_y, train_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        source_table = self.source_table_dict[table_name]
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        tmp_name = f'tmp_{table_name}_{key}_{id(tmp)}'
+
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} ent\n")
+        if key in source_table and source_table[key].is_unique_key:
+            sql += f"  ON {key_ref} = tmp.__kumo_id__"
+        else:
+            sql += (f"  ON ent.rowid = (\n"
+                    f"    SELECT rowid\n"
+                    f"    FROM {self.source_name_dict[table_name]}\n"
+                    f"    WHERE {key_ref} == tmp.__kumo_id__\n"
+                    f"    LIMIT 1\n"
+                    f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
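`_by_pkey` ships the requested keys to SQLite as a temporary Arrow table via ADBC bulk ingest, then joins each key back by `rowid`, picking at most one physical row per key when the key column is not declared unique. A condensed sketch of the same round trip; the `users` table, its `user_id` and `name` columns, and the `data.db` file are hypothetical:

    import pyarrow as pa

    from kumoai.experimental.rfm.backend.sqlite import connect

    conn = connect("data.db")  # hypothetical database file

    # Ship lookup keys into SQLite, then join each key to at most one row:
    tmp = pa.table([pa.array([17, 42, 42])], names=['__kumo_id__'])
    with conn.cursor() as cursor:
        cursor.adbc_ingest('tmp_keys', tmp, mode='replace')
        cursor.execute('SELECT tmp.rowid - 1 as __kumo_batch__, ent."name"\n'
                       'FROM tmp_keys tmp\n'
                       'JOIN users ent ON ent.rowid = (\n'
                       '  SELECT rowid FROM users\n'
                       '  WHERE "user_id" == tmp.__kumo_id__\n'
                       '  LIMIT 1\n'
                       ')')
        result = cursor.fetch_arrow_table()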
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_time__', pa.array(anchor_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]} fact\n"
+               f"ON fact.rowid IN (\n"
+               f"  SELECT rowid\n"
+               f"  FROM {self.source_name_dict[table_name]}\n"
+               f"  WHERE {key_ref} = tmp.__kumo_id__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  AND {time_ref} <= tmp.__kumo_time__\n"
+        if time_column is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f"  ORDER BY {time_ref} DESC\n"
+        sql += (f"  LIMIT {num_neighbors}\n"
+                f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
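For each entity, `_by_fkey`'s correlated `IN` subquery keeps only the `num_neighbors` most recent facts at or before the entity's anchor time. For a hypothetical `orders` fact table keyed by `user_id` with time column `ts` and `num_neighbors=8`, the assembled statement looks like:

    # Hypothetical shape of the per-entity neighbor query built above;
    # timestamps compare as ISO-8601 TEXT:
    sql = ('SELECT tmp.rowid - 1 as __kumo_batch__, "price", "ts"\n'
           'FROM "tmp_orders_user_id_140031" tmp\n'  # temp-table name varies
           'JOIN orders fact\n'
           'ON fact.rowid IN (\n'
           '  SELECT rowid\n'
           '  FROM orders\n'
           '  WHERE "user_id" = tmp.__kumo_id__\n'
           '  AND "ts" <= tmp.__kumo_time__\n'
           '  ORDER BY "ts" DESC\n'
           '  LIMIT 8\n'
           ')')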
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(index)], names=['__kumo_id__'])
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        tmp = tmp.append_column('__kumo_end__', pa.array(end_time))
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('__kumo_start__', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{foreign_key}_{id(tmp)}'
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT "
+               f"tmp.rowid - 1 as __kumo_batch__, "
+               f"{', '.join(projections)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.source_name_dict[table_name]}\n"
+               f"  ON {key_ref} = tmp.__kumo_id__\n"
+               f"  AND {time_ref} <= tmp.__kumo_end__")
+        if min_offset is not None:
+            sql += f"\n  AND {time_ref} > tmp.__kumo_start__"
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__kumo_batch__'].to_numpy()
+        batch_index = table.schema.get_field_index('__kumo_batch__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
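Both `_by_fkey` and `_by_time` depend on the same property: ISO-8601 text sorts lexicographically in chronological order, so the `TEXT`-based window predicates stay correct even without a native datetime type:

    # ISO-8601 TEXT compares lexicographically in chronological order,
    # which makes the `<=`/`>` window predicates above sound:
    assert "2024-03-01 00:00:00" < "2024-03-15 12:30:00" <= "2024-04-01 00:00:00"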
+    def _sample_target_set(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        index: np.ndarray,
+        anchor_time: pd.Series,
+        num_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        batch_size: int = 10_000,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            df = entity_df.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
+
+            feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+            time_dict: dict[str, pd.Series] = {}
+            time_column = self.time_column_dict.get(query.entity_table)
+            if time_column in columns_dict[query.entity_table]:
+                time_dict[query.entity_table] = df[time_column]
+            batch_dict: dict[str, np.ndarray] = {
+                query.entity_table: np.arange(len(df)),
+            }
+            for edge_type, (_min, _max) in time_offset_dict.items():
+                table_name, foreign_key, _ = edge_type
+                feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                    table_name=table_name,
+                    foreign_key=foreign_key,
+                    index=df[self.primary_key_dict[query.entity_table]],
+                    anchor_time=time,
+                    min_offset=_min,
+                    max_offset=_max,
+                    columns=columns_dict[table_name],
+                )
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns_dict[table_name]:
+                    time_dict[table_name] = feat_dict[table_name][time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
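`_sample_target_set` streams entities in batches of 10,000, evaluates the predictive query per batch with `PQueryPandasExecutor`, and stops early once `num_examples` labels have been produced. A minimal sketch of the entry point named in the constructor's warning, assuming `graph` is a `Graph` backed by `SQLiteTable` objects (only the `optimize` flag is attested in this diff):

    from kumoai.experimental.rfm import KumoRFM

    # Hypothetical setup: `graph` is a kumoai.experimental.rfm.Graph whose
    # tables are SQLiteTable objects. With `optimize=True`, SQLiteSampler
    # creates the missing key indices instead of emitting the warning above.
    rfm = KumoRFM(graph, optimize=True)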