kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
- kumoai/experimental/rfm/backend/snow/table.py +159 -52
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
- kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
- kumoai/experimental/rfm/base/__init__.py +6 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +342 -13
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +27 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +600 -360
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/sql_sampler.py
@@ -1,13 +1,101 @@
 from abc import abstractmethod
-from
+from collections import defaultdict
+from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
+from kumoapi.rfm.context import Subgraph
+from kumoapi.typing import Dtype
 
-from kumoai.experimental.rfm.base import
+from kumoai.experimental.rfm.base import (
+    LocalExpression,
+    Sampler,
+    SamplerOutput,
+    SourceColumn,
+)
+from kumoai.experimental.rfm.base.mapper import Mapper
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+EdgeType = tuple[str, str, str]
 
 
 class SQLSampler(Sampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        self._source_name_dict: dict[str, str] = {
+            table.name: table._quoted_source_name
+            for table in graph.tables.values()
+        }
+
+        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+        for table in graph.tables.values():
+            self._source_table_dict[table.name] = {}
+            for column in table.columns:
+                if not column.is_source:
+                    continue
+                src_column = table._source_column_dict[column.name]
+                self._source_table_dict[table.name][column.name] = src_column
+
+        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
+        for table in graph.tables.values():
+            self._table_dtype_dict[table.name] = {}
+            for column in table.columns:
+                self._table_dtype_dict[table.name][column.name] = column.dtype
+
+        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
+        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
+        for table in graph.tables.values():
+            column_ref_dict: dict[str, str] = {}
+            column_proj_dict: dict[str, str] = {}
+            for column in table.columns:
+                if column.expr is not None:
+                    assert isinstance(column.expr, LocalExpression)
+                    column_ref_dict[column.name] = column.expr.value
+                    column_proj_dict[column.name] = (
+                        f'{column.expr} AS {quote_ident(column.name)}')
+                else:
+                    column_ref_dict[column.name] = quote_ident(column.name)
+                    column_proj_dict[column.name] = quote_ident(column.name)
+            self._table_column_ref_dict[table.name] = column_ref_dict
+            self._table_column_proj_dict[table.name] = column_proj_dict
+
+    @property
+    def source_name_dict(self) -> dict[str, str]:
+        r"""The source table names for all tables in the graph."""
+        return self._source_name_dict
+
+    @property
+    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+        r"""The source column information for all tables in the graph."""
+        return self._source_table_dict
+
+    @property
+    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
+        r"""The data types for all columns in all tables in the graph."""
+        return self._table_dtype_dict
+
+    @property
+    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
+        r"""The SQL reference expression for all columns in all tables in the
+        graph.
+        """
+        return self._table_column_ref_dict
+
+    @property
+    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
+        r"""The SQL projection expressions for all columns in all tables in the
+        graph.
+        """
+        return self._table_column_proj_dict
+
     def _sample_subgraph(
         self,
         entity_table_name: str,
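The constructor above precomputes, per table, a SQL reference expression and a projection expression for every column: expression-backed columns are emitted as `<expr> AS <quoted name>`, plain columns as quoted identifiers. As a minimal sketch of how a backend might consume these dictionaries to build a SELECT list, here is an invented example; the `users` table, its columns, the `full_name` expression, and the `build_select` helper do not appear in the diff, and the stand-in `quote_ident` merely assumes double-quote quoting:

def quote_ident(name: str) -> str:
    # Stand-in for kumoai.utils.quote_ident; double-quote quoting assumed.
    return '"' + name.replace('"', '""') + '"'

# Invented values, mirroring what __init__ would produce for a hypothetical
# 'users' table with a plain column 'age' and a derived column 'full_name'
# backed by a SQL expression:
table_column_proj_dict = {
    'users': {
        'age': quote_ident('age'),
        'full_name':
        f"first_name || ' ' || last_name AS {quote_ident('full_name')}",
    },
}

def build_select(table: str, columns: set[str]) -> str:
    # A backend can join the projection expressions into a SELECT list:
    projections = ', '.join(table_column_proj_dict[table][col]
                            for col in sorted(columns))
    return f'SELECT {projections} FROM {table}'

# -> SELECT "age", first_name || ' ' || last_name AS "full_name" FROM users
print(build_select('users', {'age', 'full_name'}))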
@@ -17,31 +105,260 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:
 
+        # Make sure to always include primary key, foreign key and time columns
+        # during data fetching since these are needed for graph traversal:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            sample_columns_dict[table] = sample_columns
+        if not isinstance(anchor_time, pd.Series):
+            sample_columns_dict[entity_table_name] |= {
+                self.time_column_dict[entity_table_name]
+            }
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-
-            columns=
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
-
+            mask = np.ones(len(entity_pkey), dtype=bool)
+            mask[batch] = False
+            raise KeyError(f"The primary keys "
+                           f"{entity_pkey.iloc[mask].tolist()} do not exist "
+                           f"in the '{entity_table_name}' table")
 
+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)
 
+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        # Recursive Neighbor Sampling #########################################
+
+        mapper_dict: dict[str, Mapper] = defaultdict(
+            lambda: Mapper(num_examples=len(entity_pkey)))
+        mapper_dict[entity_table_name].add(
+            pkey=df[self.primary_key_dict[entity_table_name]],
+            batch=batch,
+        )
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        dfs_dict[entity_table_name].append(df)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        batches_dict[entity_table_name].append(batch)
+        num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
+            lambda: [0] * (len(num_neighbors) + 1))
+        num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
+
+        rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
+            lambda: [0] * len(num_neighbors))
+
+        # The start index of data frame slices of the previous hop:
+        offset_dict: dict[str, int] = defaultdict(int)
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            for table in list(num_sampled_nodes_dict.keys()):
+                # Only sample from tables that have been visited in the
+                # previous hop:
+                if num_sampled_nodes_dict[table][hop] == 0:
+                    continue
+
+                # Collect the slices of data sampled in the previous hop
+                # (but maintain only required key information):
+                cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
+                if table in self.primary_key_dict:
+                    cols.append(self.primary_key_dict[table])
+                dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
+                df = pd.concat(
+                    dfs,
+                    axis=0,
+                    ignore_index=True,
+                ) if len(dfs) > 1 else dfs[0]
+                batches = batches_dict[table][offset_dict[table]:]
+                batch = (np.concatenate(batches)
+                         if len(batches) > 1 else batches[0])
+                offset_dict[table] = len(batches_dict[table])  # Increase.
+
+                pkey: pd.Series | None = None
+                index: np.ndarray | None = None
+                if table in self.primary_key_dict:
+                    pkey = df[self.primary_key_dict[table]]
+                    index = mapper_dict[table].get(pkey, batch)
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    row = mapper_dict[dst_table].get(df[fkey], batch)
+                    mask = row == -1
+                    if mask.any():
+                        key_df = pd.DataFrame({
+                            'fkey': df[fkey],
+                            'batch': batch,
+                        }).iloc[mask]
+                        # Only maintain unique keys per example:
+                        unique_key_df = key_df.drop_duplicates()
+                        # Fully de-duplicate keys across examples:
+                        code, fkey_index = pd.factorize(unique_key_df['fkey'])
+
+                        _df, _batch = self._by_pkey(
+                            table_name=dst_table,
+                            index=fkey_index,
+                            columns=sample_columns_dict[dst_table],
+                        )  # Ensure result is sorted according to input order:
+                        _df = _df.iloc[_batch.argsort()]
+
+                        # Compute valid entries (without dangling foreign keys)
+                        # in `unique_fkey_df`:
+                        _mask = np.full(len(fkey_index), fill_value=False)
+                        _mask[_batch] = True
+                        _mask = _mask[code]
+
+                        # Reconstruct unique (key, batch) pairs:
+                        code, _ = pd.factorize(unique_key_df['fkey'][_mask])
+                        _df = _df.iloc[code].reset_index(drop=True)
+                        _batch = unique_key_df['batch'].to_numpy()[_mask]
+
+                        # Register node IDs:
+                        mapper_dict[dst_table].add(
+                            pkey=_df[self.primary_key_dict[dst_table]],
+                            batch=_batch,
+                        )
+                        row[mask] = mapper_dict[dst_table].get(
+                            pkey=key_df['fkey'],
+                            batch=key_df['batch'].to_numpy(),
+                        )  # NOTE `row` may still hold `-1` for dangling fkeys.
+
+                        dfs_dict[dst_table].append(_df)
+                        batches_dict[dst_table].append(_batch)
+                        num_sampled_nodes_dict[dst_table][hop + 1] += (  #
+                            len(_batch))
+
+                    mask = row != -1
+
+                    col = index
+                    if col is None:
+                        start = sum(num_sampled_nodes_dict[table][:hop])
+                        end = sum(num_sampled_nodes_dict[table][:hop + 1])
+                        col = np.arange(start, end)
+
+                    row = row[mask]
+                    col = col[mask]
+
+                    edge_type = (table, fkey, dst_table)
+                    edge_type = Subgraph.rev_edge_type(edge_type)
+                    rows_dict[edge_type].append(row)
+                    cols_dict[edge_type].append(col)
+                    num_sampled_edges_dict[edge_type][hop] = len(col)
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    assert pkey is not None and index is not None
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=pkey,
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )
+
+                    edge_type = (src_table, fkey, table)
+                    cols_dict[edge_type].append(index[_batch])
+                    num_sampled_edges_dict[edge_type][hop] = len(_batch)
+
+                    _batch = batch[_batch]
+                    num_nodes = sum(num_sampled_nodes_dict[src_table])
+                    if src_table in self.primary_key_dict:
+                        _pkey = _df[self.primary_key_dict[src_table]]
+                        mapper_dict[src_table].add(_pkey, _batch)
+                        row = mapper_dict[src_table].get(_pkey, _batch)
+
+                        # Only preserve unknown rows:
+                        mask = row >= num_nodes  # type: ignore
+                        mask[pd.Index(row).duplicated()] = False
+                        _df = _df.iloc[mask]
+                        _batch = _batch[mask]
+                    else:
+                        row = np.arange(num_nodes, num_nodes + len(_batch))
+
+                    rows_dict[edge_type].append(row)
+                    num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
+
+                    dfs_dict[src_table].append(_df)
+                    batches_dict[src_table].append(_batch)
+
+        # Post-Processing #####################################################
+
+        df_dict = {
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+
+        # Only store unique rows in `df` above a certain threshold:
+        inverse_dict: dict[str, np.ndarray] = {}
+        for table, df in df_dict.items():
+            if table not in self.primary_key_dict:
+                continue
+            unique, index, inverse = np.unique(
+                df_dict[table][self.primary_key_dict[table]],
+                return_index=True,
+                return_inverse=True,
+            )
+            if len(df) > 1.05 * len(unique):
+                df_dict[table] = df.iloc[index].reset_index(drop=True)
+                inverse_dict[table] = inverse
+
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {
+            table: np.concatenate(batches) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }
+        row_dict = {
+            edge_type: np.concatenate(rows)
+            for edge_type, rows in rows_dict.items()
+        }
+        col_dict = {
+            edge_type: np.concatenate(cols)
+            for edge_type, cols in cols_dict.items()
+        }
+
+        if len(num_sampled_edges_dict) == 0:  # Single table:
+            num_sampled_nodes_dict = {
+                key: value[:1]
+                for key, value in num_sampled_nodes_dict.items()
+            }
 
         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict=
-            inverse_dict=
-            batch_dict=
-            num_sampled_nodes_dict=
-            row_dict=
-            col_dict=
-            num_sampled_edges_dict=
+            df_dict=df_dict,
+            inverse_dict=inverse_dict,
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
         )
 
     # Abstract Methods ########################################################
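The hop loop above leans on `Mapper` (imported from `kumoai.experimental.rfm.base.mapper`, whose own changes are listed in this diff) to translate `(primary key, example)` pairs into local node IDs, with `get` returning `-1` for keys that have not been sampled yet. A toy sketch of that contract, not the actual kumoai implementation, might look as follows:

import numpy as np
import pandas as pd

class ToyMapper:
    """Illustrative, hypothetical stand-in for kumoai's `Mapper`."""
    def __init__(self, num_examples: int) -> None:
        self.num_examples = num_examples
        self._ids: dict[tuple[int, object], int] = {}

    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
        # Assign consecutive local node IDs to unseen (example, key) pairs:
        for example, key in zip(batch.tolist(), pkey.tolist()):
            self._ids.setdefault((example, key), len(self._ids))

    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
        # Look up local node IDs; -1 marks keys not sampled yet:
        return np.array([
            self._ids.get((example, key), -1)
            for example, key in zip(batch.tolist(), pkey.tolist())
        ], dtype=np.int64)

Under a contract like this, `row == -1` in the foreign-key branch marks exactly the keys that still need to be fetched via `_by_pkey`.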
@@ -50,7 +367,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass
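These two abstract hooks are what the concrete backends in this release (see the `backend/sqlite/sampler.py` and `backend/snow/sampler.py` entries above) fill in. As a rough, hypothetical sketch of the expected shape only, assuming a plain `sqlite3` connection stored on a `_connection` attribute (an assumption, not something shown in the diff), a `_by_pkey` override could look like this:

import numpy as np
import pandas as pd

class SQLiteSamplerSketch(SQLSampler):  # hypothetical subclass
    def _by_pkey(
        self,
        table_name: str,
        index: pd.Series,
        columns: set[str],
    ) -> tuple[pd.DataFrame, np.ndarray]:
        # Fetch rows whose primary key appears in `index`. Missing keys are
        # simply absent from the result, so `len(batch) != len(index)`
        # signals dangling keys to the caller.
        pkey = self.primary_key_dict[table_name]
        proj = ', '.join(self.table_column_proj_dict[table_name][col]
                         for col in sorted(columns))
        placeholders = ', '.join('?' * len(index))
        ref = self.table_column_ref_dict[table_name][pkey]
        query = (f'SELECT {proj} FROM {self.source_name_dict[table_name]} '
                 f'WHERE {ref} IN ({placeholders})')
        df = pd.read_sql_query(query, self._connection, params=list(index))
        # Map returned primary keys back to their positions in `index`:
        pos = pd.Series(np.arange(len(index)), index=index.to_numpy())
        batch = pos[df[pkey]].to_numpy()
        return df, batch

The returned `batch` array maps each fetched row back to its position in `index`, which is what `_sample_subgraph` uses both to detect missing primary keys and to restore the caller's ordering via `batch.argsort()`.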