kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +184 -70
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +191 -86
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/sampler.py +26 -17
  18. kumoai/experimental/rfm/base/source.py +1 -1
  19. kumoai/experimental/rfm/base/sql_sampler.py +182 -19
  20. kumoai/experimental/rfm/base/table.py +275 -109
  21. kumoai/experimental/rfm/graph.py +115 -107
  22. kumoai/experimental/rfm/infer/dtype.py +4 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/relbench.py +76 -0
  25. kumoai/experimental/rfm/rfm.py +530 -304
  26. kumoai/experimental/rfm/task_table.py +292 -0
  27. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  28. kumoai/pquery/training_table.py +16 -2
  29. kumoai/trainer/distilled_trainer.py +175 -0
  30. kumoai/utils/display.py +87 -0
  31. kumoai/utils/progress_logger.py +13 -1
  32. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +1 -1
  33. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +36 -33
  34. kumoai/experimental/rfm/base/column_expression.py +0 -50
  35. kumoai/experimental/rfm/base/sql_table.py +0 -229
  36. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
  37. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
  38. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/column.py

@@ -1,15 +1,89 @@
+ from __future__ import annotations
+
  from dataclasses import dataclass
- from typing import Any
+ from typing import Any, Mapping, TypeAlias

  from kumoapi.typing import Dtype, Stype
+ from typing_extensions import Self
+
+ from kumoai.experimental.rfm.base import Expression
+ from kumoai.mixin import CastMixin
+
+
+ @dataclass(init=False)
+ class ColumnSpec(CastMixin):
+     r"""A column specification for adding a column to a table.
+
+     A column specification can either refer to a physical column present in
+     the data source, or be defined logically via an expression.
+
+     Args:
+         name: The name of the column.
+         expr: A column expression to define logical columns.
+         dtype: The data type of the column.
+         stype: The semantic type of the column.
+     """
+     def __init__(
+         self,
+         name: str,
+         expr: Expression | Mapping[str, str] | str | None = None,
+         dtype: Dtype | str | None = None,
+         stype: Stype | str | None = None,
+     ) -> None:
+
+         self.name = name
+         self.expr = Expression.coerce(expr)
+         self.dtype = Dtype(dtype) if dtype is not None else None
+         self.stype = Stype(stype) if stype is not None else None
+
+     @classmethod
+     def coerce(cls, spec: ColumnSpec | Mapping[str, Any] | str) -> Self:
+         r"""Coerces a column specification into a :class:`ColumnSpec`."""
+         if isinstance(spec, cls):
+             return spec
+         if isinstance(spec, str):
+             return cls(name=spec)
+         if isinstance(spec, Mapping):
+             try:
+                 return cls(**spec)
+             except TypeError:
+                 pass
+         raise TypeError(f"Unable to coerce 'ColumnSpec' from '{spec}'")
+
+     @property
+     def is_source(self) -> bool:
+         r"""Whether the column specification refers to a physical column
+         present in the data source.
+         """
+         return self.expr is None
+
+
+ ColumnSpecType: TypeAlias = ColumnSpec | Mapping[str, Any] | str


  @dataclass(init=False, repr=False, eq=False)
  class Column:
+     r"""Column-level metadata information.
+
+     A column can either refer to a physical column present in the data
+     source, or be defined logically via an expression.
+
+     Args:
+         name: The name of the column.
+         expr: A column expression to define logical columns.
+         dtype: The data type of the column.
+         stype: The semantic type of the column.
+     """
      stype: Stype

-     def __init__(self, name: str, stype: Stype, dtype: Dtype) -> None:
+     def __init__(
+         self,
+         name: str,
+         expr: Expression | None,
+         dtype: Dtype,
+         stype: Stype,
+     ) -> None:
          self._name = name
+         self._expr = expr
          self._dtype = Dtype(dtype)

          self._is_primary_key = False

@@ -20,19 +94,25 @@ class Column:

      @property
      def name(self) -> str:
+         r"""The name of the column."""
          return self._name

      @property
-     def dtype(self) -> Dtype:
-         return self._dtype
+     def expr(self) -> Expression | None:
+         r"""The expression of the column (if defined logically)."""
+         return self._expr

      @property
-     def is_physical(self) -> bool:
-         return True
+     def dtype(self) -> Dtype:
+         r"""The data type of the column."""
+         return self._dtype

      @property
-     def is_logical(self) -> bool:
-         return not self.is_physical
+     def is_source(self) -> bool:
+         r"""Whether the column refers to a physical column present in the
+         data source.
+         """
+         return self.expr is None

      def __setattr__(self, key: str, val: Any) -> None:
          if key == 'stype':

@@ -56,7 +136,7 @@ class Column:
          super().__setattr__(key, val)

      def __hash__(self) -> int:
-         return hash((self.name, self.stype, self.dtype))
+         return hash((self.name, self.expr, self.dtype, self.stype))

      def __eq__(self, other: Any) -> bool:
          if not isinstance(other, Column):

@@ -64,5 +144,9 @@ class Column:
          return hash(self) == hash(other)

      def __repr__(self) -> str:
-         return (f'{self.__class__.__name__}(name={self.name}, '
-                 f'stype={self.stype}, dtype={self.dtype})')
+         parts = [f'name={self.name}']
+         if self.expr is not None:
+             parts.append(f'expr={self.expr}')
+         parts.append(f'dtype={self.dtype}')
+         parts.append(f'stype={self.stype}')
+         return f"{self.__class__.__name__}({', '.join(parts)})"
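The `ColumnSpec.coerce` classmethod above accepts an existing `ColumnSpec`, a bare string, or a mapping of constructor arguments. A minimal usage sketch of the coercion path added in this diff (the exact import path is an assumption; `base/__init__.py` also changed in this release and may re-export these names):

```python
from kumoai.experimental.rfm.base.column import ColumnSpec  # Path assumed.

spec = ColumnSpec.coerce('price')   # A bare string names a physical column.
assert spec.is_source               # No expression attached.

# A mapping is expanded into the constructor, yielding a logical column:
spec = ColumnSpec.coerce({'name': 'total', 'expr': 'price * quantity'})
assert not spec.is_source           # Defined via an expression instead.
```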
kumoai/experimental/rfm/base/expression.py (new file)

@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ from abc import ABC
+ from dataclasses import dataclass
+ from typing import Mapping
+
+
+ class Expression(ABC):
+     """A base expression to define logical columns."""
+     @classmethod
+     def coerce(
+         cls,
+         spec: Expression | Mapping[str, str] | str | None,
+     ) -> Expression | None:
+         r"""Coerces an expression specification into an :class:`Expression`,
+         if possible.
+         """
+         if spec is None:
+             return None
+         if isinstance(spec, Expression):
+             return spec
+         if isinstance(spec, str):
+             return LocalExpression(spec)
+         if isinstance(spec, Mapping):
+             for sub_cls in (LocalExpression, ):
+                 try:
+                     return sub_cls(**spec)
+                 except TypeError:
+                     pass
+         raise TypeError(f"Unable to coerce 'Expression' from '{spec}'")
+
+
+ @dataclass(frozen=True, repr=False)
+ class LocalExpression(Expression):
+     r"""A local expression to define a row-level logical attribute based on
+     physical columns of the data source in the same row.
+
+     Args:
+         value: The value of the expression.
+     """
+     value: str
+
+     def __repr__(self) -> str:
+         return self.value
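The coercion rules of the new `Expression` base class mirror those of `ColumnSpec`: `None` passes through, a string becomes a `LocalExpression`, and a mapping is expanded into a matching subclass constructor. A short sketch (import path assumed):

```python
from kumoai.experimental.rfm.base.expression import Expression, LocalExpression

assert Expression.coerce(None) is None
expr = Expression.coerce('price * quantity')  # str -> LocalExpression
assert isinstance(expr, LocalExpression)
assert repr(expr) == 'price * quantity'       # __repr__ returns the raw value.

# Frozen dataclass equality makes coerced mappings comparable:
assert Expression.coerce({'value': 'a + b'}) == LocalExpression('a + b')
```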
kumoai/experimental/rfm/base/sampler.py

@@ -13,7 +13,6 @@ from kumoapi.pquery.AST import Aggregation, ASTNode
  from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
  from kumoapi.typing import Stype

- from kumoai.experimental.rfm.base import SourceColumn
  from kumoai.utils import ProgressLogger

  if TYPE_CHECKING:

@@ -53,12 +52,24 @@ class Sampler(ABC):
          graph: 'Graph',
          verbose: bool | ProgressLogger = True,
      ) -> None:
+
          self._edge_types: list[tuple[str, str, str]] = []
          for edge in graph.edges:
              edge_type = (edge.src_table, edge.fkey, edge.dst_table)
              self._edge_types.append(edge_type)
              self._edge_types.append(Subgraph.rev_edge_type(edge_type))

+         # Source Table -> [(Foreign Key, Destination Table)]
+         self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+         # Destination Table -> [(Source Table, Foreign Key)]
+         self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+         for table in graph.tables.values():
+             self._foreign_key_dict[table.name] = []
+             self._rev_foreign_key_dict[table.name] = []
+         for src_table, fkey, dst_table in graph.edges:
+             self._foreign_key_dict[src_table].append((fkey, dst_table))
+             self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
          self._primary_key_dict: dict[str, str] = {
              table.name: table._primary_key
              for table in graph.tables.values()
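The two mappings built above index every edge from both ends, so a sampler can walk foreign keys in either direction without rescanning `graph.edges`. A self-contained illustration with a single hypothetical `orders -> users` edge:

```python
# (src_table, fkey, dst_table) triples, as unpacked from `graph.edges` above:
edges = [('orders', 'user_id', 'users')]

foreign_key_dict = {'orders': [], 'users': []}
rev_foreign_key_dict = {'orders': [], 'users': []}
for src_table, fkey, dst_table in edges:
    foreign_key_dict[src_table].append((fkey, dst_table))      # Forward.
    rev_foreign_key_dict[dst_table].append((src_table, fkey))  # Backward.

assert foreign_key_dict == {'orders': [('user_id', 'users')], 'users': []}
assert rev_foreign_key_dict == {'orders': [], 'users': [('orders', 'user_id')]}
```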
@@ -88,10 +99,6 @@
                  continue
              self._table_stype_dict[table.name][column.name] = column.stype

-         self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
-         for table in graph.tables.values():
-             self._source_table_dict[table.name] = table._source_column_dict
-
          self._min_time_dict: dict[str, pd.Timestamp] = {}
          self._max_time_dict: dict[str, pd.Timestamp] = {}

@@ -102,6 +109,16 @@
          r"""All available edge types in the graph."""
          return self._edge_types

+     @property
+     def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+         r"""The foreign keys for all tables in the graph."""
+         return self._foreign_key_dict
+
+     @property
+     def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+         r"""The foreign key back references for all tables in the graph."""
+         return self._rev_foreign_key_dict
+
      @property
      def primary_key_dict(self) -> dict[str, str]:
          r"""All available primary keys in the graph."""

@@ -119,16 +136,11 @@

      @property
      def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
-         r"""The registered semantic types for all columns in all tables in
-         the graph.
+         r"""The registered semantic types for all feature columns in all
+         tables in the graph.
          """
          return self._table_stype_dict

-     @property
-     def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
-         r"""Source column information for all tables in the graph."""
-         return self._source_table_dict
-
      def get_min_time(
          self,
          table_names: list[str] | None = None,

@@ -189,7 +201,7 @@
              exclude_cols_dict: The columns to exclude from the subgraph.
          """
          # Exclude all columns that leak target information:
-         table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
+         table_stype_dict: dict[str, dict[str, Stype]] = self.table_stype_dict
          if exclude_cols_dict is not None:
              table_stype_dict = copy.deepcopy(table_stype_dict)
              for table_name, exclude_cols in exclude_cols_dict.items():

@@ -237,11 +249,8 @@
              # Set end time to NaT for all values greater than anchor time:
              assert table_name not in out.inverse_dict
              ser = df[end_time_column]
-             if ser.dtype != 'datetime64[ns]':
-                 ser = ser.astype('datetime64[ns]')
              mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
-             ser.iloc[mask] = pd.NaT
-             df[end_time_column] = ser
+             df.loc[mask, end_time_column] = pd.NaT

              stype_dict = table_stype_dict[table_name]
              for column_name, stype in stype_dict.items():
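The end-time hunk replaces a copy-cast-assign sequence with a single label-based assignment. A minimal pandas reproduction of the new behavior (column name and values are illustrative):

```python
import pandas as pd

df = pd.DataFrame({'end_time': pd.to_datetime(['2024-01-01', '2024-06-01'])})
anchor = pd.Timestamp('2024-03-01').value  # Nanoseconds since epoch, as above.
mask = df['end_time'].astype(int).to_numpy() > anchor

df.loc[mask, 'end_time'] = pd.NaT  # Single-step masking, as in the new code.
assert df['end_time'].isna().tolist() == [False, True]
```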
kumoai/experimental/rfm/base/source.py

@@ -6,7 +6,7 @@ from kumoapi.typing import Dtype
  @dataclass
  class SourceColumn:
      name: str
-     dtype: Dtype
+     dtype: Dtype | None
      is_primary_key: bool
      is_unique_key: bool
      is_nullable: bool
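With `dtype` now optional, a source column whose type has not (yet) been resolved from the backend can be represented explicitly. A sketch, assuming the fields shown in this hunk are the only required ones (the dataclass may define more):

```python
from kumoai.experimental.rfm.base.source import SourceColumn  # Path assumed.

col = SourceColumn(
    name='user_id',
    dtype=None,  # Type not resolved from the data source yet.
    is_primary_key=True,
    is_unique_key=True,
    is_nullable=False,
)
```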
kumoai/experimental/rfm/base/sql_sampler.py

@@ -1,11 +1,18 @@
  from abc import abstractmethod
+ from collections import defaultdict
  from typing import TYPE_CHECKING, Literal

  import numpy as np
  import pandas as pd
+ from kumoapi.typing import Dtype

- from kumoai.experimental.rfm.base import Sampler, SamplerOutput, SQLTable
- from kumoai.utils import ProgressLogger
+ from kumoai.experimental.rfm.base import (
+     LocalExpression,
+     Sampler,
+     SamplerOutput,
+     SourceColumn,
+ )
+ from kumoai.utils import ProgressLogger, quote_ident

  if TYPE_CHECKING:
      from kumoai.experimental.rfm import Graph

@@ -19,18 +26,71 @@ class SQLSampler(Sampler):
      ) -> None:
          super().__init__(graph=graph, verbose=verbose)

-         self._fqn_dict: dict[str, str] = {}
+         self._source_name_dict: dict[str, str] = {
+             table.name: table._quoted_source_name
+             for table in graph.tables.values()
+         }
+
+         self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+         for table in graph.tables.values():
+             self._source_table_dict[table.name] = {}
+             for column in table.columns:
+                 if not column.is_source:
+                     continue
+                 src_column = table._source_column_dict[column.name]
+                 self._source_table_dict[table.name][column.name] = src_column
+
+         self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
+         for table in graph.tables.values():
+             self._table_dtype_dict[table.name] = {}
+             for column in table.columns:
+                 self._table_dtype_dict[table.name][column.name] = column.dtype
+
+         self._table_column_ref_dict: dict[str, dict[str, str]] = {}
+         self._table_column_proj_dict: dict[str, dict[str, str]] = {}
          for table in graph.tables.values():
-             assert isinstance(table, SQLTable)
-             self._connection = table._connection
-             self._fqn_dict[table.name] = table.fqn
+             column_ref_dict: dict[str, str] = {}
+             column_proj_dict: dict[str, str] = {}
+             for column in table.columns:
+                 if column.expr is not None:
+                     assert isinstance(column.expr, LocalExpression)
+                     column_ref_dict[column.name] = column.expr.value
+                     column_proj_dict[column.name] = (
+                         f'{column.expr} AS {quote_ident(column.name)}')
+                 else:
+                     column_ref_dict[column.name] = quote_ident(column.name)
+                     column_proj_dict[column.name] = quote_ident(column.name)
+             self._table_column_ref_dict[table.name] = column_ref_dict
+             self._table_column_proj_dict[table.name] = column_proj_dict
+
+     @property
+     def source_name_dict(self) -> dict[str, str]:
+         r"""The source table names for all tables in the graph."""
+         return self._source_name_dict
+
+     @property
+     def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+         r"""The source column information for all tables in the graph."""
+         return self._source_table_dict

      @property
-     def fqn_dict(self) -> dict[str, str]:
-         r"""The fully-qualified quoted source name for all table names in the
+     def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
+         r"""The data types for all columns in all tables in the graph."""
+         return self._table_dtype_dict
+
+     @property
+     def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
+         r"""The SQL reference expression for all columns in all tables in the
          graph.
          """
-         return self._fqn_dict
+         return self._table_column_ref_dict
+
+     @property
+     def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
+         r"""The SQL projection expressions for all columns in all tables in the
+         graph.
+         """
+         return self._table_column_proj_dict

      def _sample_subgraph(
          self,
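`table_column_ref_dict` and `table_column_proj_dict` differ only for logical columns: the reference form is the raw expression (usable in `WHERE` or `JOIN` clauses), while the projection form aliases the expression back to the column name for `SELECT` lists. A toy illustration with a stand-in for `quote_ident` (assumed here to double-quote identifiers):

```python
def quote_ident(name: str) -> str:  # Stand-in for kumoai.utils.quote_ident.
    return f'"{name}"'

# Physical column: reference and projection are both the quoted identifier:
assert quote_ident('price') == '"price"'

# Logical column 'total' defined by the expression 'price * quantity':
name, expr = 'total', 'price * quantity'
ref = expr                               # Reference form.
proj = f'{expr} AS {quote_ident(name)}'  # Projection form.
assert proj == 'price * quantity AS "total"'
```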
@@ -41,10 +101,25 @@
          num_neighbors: list[int],
      ) -> SamplerOutput:

+         # Make sure to include primary key, foreign key and time columns:
+         sample_columns_dict: dict[str, set[str]] = {}
+         for table, columns in columns_dict.items():
+             sample_columns = columns | {
+                 foreign_key
+                 for foreign_key, _ in self.foreign_key_dict[table]
+             }
+             if primary_key := self.primary_key_dict.get(table):
+                 sample_columns |= {primary_key}
+             if time_column := self.time_column_dict.get(table):
+                 sample_columns |= {time_column}
+             sample_columns_dict[table] = sample_columns
+
+         # Sample Entity Table ################################################
+
          df, batch = self._by_pkey(
              table_name=entity_table_name,
-             pkey=entity_pkey,
-             columns=columns_dict[entity_table_name],
+             index=entity_pkey,
+             columns=sample_columns_dict[entity_table_name],
          )
          if len(batch) != len(entity_pkey):
              mask = np.ones(len(entity_pkey), dtype=bool)
@@ -53,23 +128,99 @@
                  f"{entity_pkey.iloc[mask].tolist()} do not exist "
                  f"in the '{entity_table_name}' table")

+         # Make sure that entities are returned in expected order:
          perm = batch.argsort()
          batch = batch[perm]
          df = df.iloc[perm].reset_index(drop=True)

+         # Fill 'entity' anchor times with actual values:
          if not isinstance(anchor_time, pd.Series):
              time_column = self.time_column_dict[entity_table_name]
              anchor_time = df[time_column]
+         assert isinstance(anchor_time, pd.Series)
+
+         df_hop_dict: dict[tuple[str, int], pd.DataFrame] = {
+             (entity_table_name, 0): df,
+         }
+         batch_hop_dict: dict[tuple[str, int], np.ndarray] = {
+             (entity_table_name, 0): batch,
+         }
+
+         # Recursive Neighbor Sampling ########################################
+
+         for hop, neighbors in enumerate(num_neighbors):
+             if neighbors == 0:
+                 break  # Abort early.
+
+             dfs: dict[str, list[pd.DataFrame]] = defaultdict(list)
+             batches: dict[str, list[np.ndarray]] = defaultdict(list)
+
+             tables = [table for table, i in batch_hop_dict if i == hop]
+             for table in tables:
+                 df = df_hop_dict[(table, hop)]
+                 batch = batch_hop_dict[(table, hop)]
+
+                 # Iterate over foreign keys in the current table:
+                 for fkey, dst_table in self.foreign_key_dict[table]:
+                     raise NotImplementedError
+
+                 # Iterate over foreign keys that reference the current table:
+                 for src_table, fkey in self.rev_foreign_key_dict[table]:
+                     _df, _batch = self._by_fkey(
+                         table_name=src_table,
+                         foreign_key=fkey,
+                         index=df[self.primary_key_dict[table]],
+                         num_neighbors=neighbors,
+                         anchor_time=anchor_time.iloc[batch],
+                         columns=sample_columns_dict[src_table],
+                     )
+                     _batch = batch[_batch]
+
+                     # TODO Filter out duplicates if `src_table` has a pkey.
+                     dfs[src_table].append(_df)
+                     batches[src_table].append(_batch)
+
+             # TODO Add edges to all sampled nodes.
+
+         # Post-Processing ####################################################
+
+         dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+         batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+         num_hops = max(hop for _, hop in df_hop_dict.keys())  # TODO
+         num_sampled_nodes_dict: dict[str, list[int]] = {
+             table: [0] * (num_hops + 1)
+             for table in [table for table, _ in df_hop_dict.keys()]
+         }
+         for (table, hop), df in df_hop_dict.items():
+             dfs_dict[table].append(df)
+             batches_dict[table].append(batch_hop_dict[(table, hop)])
+             num_sampled_nodes_dict[table][hop] = len(df)
+
+         df_dict = {  # Concatenate data frames across hops:
+             table:
+             pd.concat(dfs, axis=0, ignore_index=True)
+             if len(dfs) > 1 else dfs[0]
+             for table, dfs in dfs_dict.items()
+         }
+         df_dict = {  # Post-filter column set:
+             table_name: df[list(columns_dict[table_name])]
+             for table_name, df in df_dict.items()
+         }
+         batch_dict = {  # Concatenate batch vector across hops:
+             table:
+             np.concatenate(batches, axis=0) if len(batches) > 1 else batches[0]
+             for table, batches in batches_dict.items()
+         }

          return SamplerOutput(
              anchor_time=anchor_time.astype(int).to_numpy(),
-             df_dict={entity_table_name: df},
-             inverse_dict={},
-             batch_dict={entity_table_name: batch},
-             num_sampled_nodes_dict={entity_table_name: [len(batch)]},
-             row_dict={},
-             col_dict={},
-             num_sampled_edges_dict={},
+             df_dict=df_dict,
+             inverse_dict={},  # TODO
+             batch_dict=batch_dict,
+             num_sampled_nodes_dict=num_sampled_nodes_dict,
+             row_dict={},  # TODO
+             col_dict={},  # TODO
+             num_sampled_edges_dict={},  # TODO
          )

      # Abstract Methods #######################################################
@@ -78,7 +229,19 @@
      def _by_pkey(
          self,
          table_name: str,
-         pkey: pd.Series,
+         index: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         pass
+
+     @abstractmethod
+     def _by_fkey(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         num_neighbors: int,
+         anchor_time: pd.Series | None,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
          pass
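For orientation, a minimal sketch of how a SQLite backend might satisfy the new `_by_pkey` contract. This is not the package's actual implementation: it assumes a DB-API connection is available on `self._connection` (as the pre-diff `SQLSampler` kept) and returns, alongside the frame, each returned row's position within the requested `index`:

```python
import numpy as np
import pandas as pd

def _by_pkey(self, table_name, index, columns):
    # Build the SELECT list from the precomputed projection expressions:
    proj = ', '.join(self.table_column_proj_dict[table_name][c] for c in columns)
    pkey = self.primary_key_dict[table_name]
    ref = self.table_column_ref_dict[table_name][pkey]
    placeholders = ', '.join('?' for _ in range(len(index)))
    query = (f'SELECT {proj}, {ref} AS __pkey__ '
             f'FROM {self.source_name_dict[table_name]} '
             f'WHERE {ref} IN ({placeholders})')
    df = pd.read_sql(query, self._connection, params=index.tolist())
    # Map each returned row back to its position in the requested `index`:
    pos = pd.Series(np.arange(len(index)), index=index.to_numpy())
    batch = pos[df['__pkey__'].to_numpy()].to_numpy()
    return df.drop(columns='__pkey__'), batch
```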