kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +184 -70
- kumoai/experimental/rfm/backend/snow/table.py +137 -64
- kumoai/experimental/rfm/backend/sqlite/sampler.py +191 -86
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +26 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +182 -19
- kumoai/experimental/rfm/base/table.py +275 -109
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +4 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +530 -304
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +36 -33
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/column.py:

@@ -1,15 +1,89 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Mapping, TypeAlias

 from kumoapi.typing import Dtype, Stype
+from typing_extensions import Self
+
+from kumoai.experimental.rfm.base import Expression
+from kumoai.mixin import CastMixin
+
+
+@dataclass(init=False)
+class ColumnSpec(CastMixin):
+    r"""A column specification for adding a column to a table.
+
+    A column specification can either refer to a physical column present in
+    the data source, or be defined logically via an expression.
+
+    Args:
+        name: The name of the column.
+        expr: A column expression to define logical columns.
+        dtype: The data type of the column.
+    """
+    def __init__(
+        self,
+        name: str,
+        expr: Expression | Mapping[str, str] | str | None = None,
+        dtype: Dtype | str | None = None,
+        stype: Stype | str | None = None,
+    ) -> None:
+
+        self.name = name
+        self.expr = Expression.coerce(expr)
+        self.dtype = Dtype(dtype) if dtype is not None else None
+        self.stype = Stype(stype) if stype is not None else None
+
+    @classmethod
+    def coerce(cls, spec: ColumnSpec | Mapping[str, Any] | str) -> Self:
+        r"""Coerces a column specification into a :class:`ColumnSpec`."""
+        if isinstance(spec, cls):
+            return spec
+        if isinstance(spec, str):
+            return cls(name=spec)
+        if isinstance(spec, Mapping):
+            try:
+                return cls(**spec)
+            except TypeError:
+                pass
+        raise TypeError(f"Unable to coerce 'ColumnSpec' from '{spec}'")
+
+    @property
+    def is_source(self) -> bool:
+        r"""Whether the column specification refers to a physical column
+        present in the data source.
+        """
+        return self.expr is None
+
+
+ColumnSpecType: TypeAlias = ColumnSpec | Mapping[str, Any] | str


 @dataclass(init=False, repr=False, eq=False)
 class Column:
+    r"""Column-level metadata information.
+
+    A column can either refer to a physical column present in the data source,
+    or be defined logically via an expression.
+
+    Args:
+        name: The name of the column.
+        expr: A column expression to define logical columns.
+        dtype: The data type of the column.
+        stype: The semantic type of the column.
+    """
     stype: Stype

-    def __init__(
+    def __init__(
+        self,
+        name: str,
+        expr: Expression | None,
+        dtype: Dtype,
+        stype: Stype,
+    ) -> None:
         self._name = name
+        self._expr = expr
         self._dtype = Dtype(dtype)

         self._is_primary_key = False
@@ -20,19 +94,25 @@ class Column:

     @property
     def name(self) -> str:
+        r"""The name of the column."""
         return self._name

     @property
-    def
-
+    def expr(self) -> Expression | None:
+        r"""The expression of the column (if defined logically)."""
+        return self._expr

     @property
-    def
-
+    def dtype(self) -> Dtype:
+        r"""The data type of the column."""
+        return self._dtype

     @property
-    def
-
+    def is_source(self) -> bool:
+        r"""Whether the column refers to a physical column present in the data
+        source.
+        """
+        return self.expr is None

     def __setattr__(self, key: str, val: Any) -> None:
         if key == 'stype':
@@ -56,7 +136,7 @@ class Column:
         super().__setattr__(key, val)

     def __hash__(self) -> int:
-        return hash((self.name, self.
+        return hash((self.name, self.expr, self.dtype, self.stype))

     def __eq__(self, other: Any) -> bool:
         if not isinstance(other, Column):
@@ -64,5 +144,9 @@ class Column:
         return hash(self) == hash(other)

     def __repr__(self) -> str:
-
-
+        parts = [f'name={self.name}']
+        if self.expr is not None:
+            parts.append(f'expr={self.expr}')
+        parts.append(f'dtype={self.dtype}')
+        parts.append(f'stype={self.stype}')
+        return f"{self.__class__.__name__}({', '.join(parts)})"
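Reviewer note: `ColumnSpec.coerce` accepts the three spec forms shown above (an existing spec, a plain string naming a physical column, or a mapping of constructor arguments). A minimal usage sketch, importing straight from the module shown in this hunk (whether `ColumnSpec` is also re-exported from `kumoai.experimental.rfm.base` is not visible in this diff):

    from kumoai.experimental.rfm.base.column import ColumnSpec

    spec = ColumnSpec.coerce('price')  # A plain string names a physical column.
    assert spec.is_source

    # A mapping defines a logical column backed by an expression:
    spec = ColumnSpec.coerce({'name': 'total', 'expr': 'price * quantity'})
    assert not spec.is_source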
kumoai/experimental/rfm/base/expression.py (new file):

@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from abc import ABC
+from dataclasses import dataclass
+from typing import Mapping
+
+
+class Expression(ABC):
+    """A base expression to define logical columns."""
+    @classmethod
+    def coerce(
+        cls,
+        spec: Expression | Mapping[str, str] | str | None,
+    ) -> Expression | None:
+        r"""Coerces an expression specification into an :class:`Expression`, if
+        possible.
+        """
+        if spec is None:
+            return None
+        if isinstance(spec, Expression):
+            return spec
+        if isinstance(spec, str):
+            return LocalExpression(spec)
+        if isinstance(spec, Mapping):
+            for sub_cls in (LocalExpression, ):
+                try:
+                    return sub_cls(**spec)
+                except TypeError:
+                    pass
+        raise TypeError(f"Unable to coerce 'Expression' from '{spec}'")
+
+
+@dataclass(frozen=True, repr=False)
+class LocalExpression(Expression):
+    r"""A local expression to define a row-level logical attribute based on
+    physical columns of the data source in the same row.
+
+    Args:
+        value: The value of the expression.
+    """
+    value: str
+
+    def __repr__(self) -> str:
+        return self.value
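Reviewer note: the coercion cascade is the whole public surface of this new module, so it can be exercised directly. A short sketch (runnable wherever this dev build is installed):

    from kumoai.experimental.rfm.base.expression import Expression, LocalExpression

    expr = Expression.coerce('price * quantity')
    assert isinstance(expr, LocalExpression)
    assert repr(expr) == 'price * quantity'  # __repr__ returns the raw value.

    assert Expression.coerce(expr) is expr   # Expressions pass through unchanged.
    assert Expression.coerce(None) is None
    assert Expression.coerce({'value': 'a + b'}) == LocalExpression('a + b')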
kumoai/experimental/rfm/base/sampler.py:

@@ -13,7 +13,6 @@ from kumoapi.pquery.AST import Aggregation, ASTNode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype

-from kumoai.experimental.rfm.base import SourceColumn
 from kumoai.utils import ProgressLogger

 if TYPE_CHECKING:
@@ -53,12 +52,24 @@ class Sampler(ABC):
         graph: 'Graph',
         verbose: bool | ProgressLogger = True,
     ) -> None:
+
         self._edge_types: list[tuple[str, str, str]] = []
         for edge in graph.edges:
             edge_type = (edge.src_table, edge.fkey, edge.dst_table)
             self._edge_types.append(edge_type)
             self._edge_types.append(Subgraph.rev_edge_type(edge_type))

+        # Source Table -> [(Foreign Key, Destination Table)]
+        self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        # Destination Table -> [(Source Table, Foreign Key)]
+        self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        for table in graph.tables.values():
+            self._foreign_key_dict[table.name] = []
+            self._rev_foreign_key_dict[table.name] = []
+        for src_table, fkey, dst_table in graph.edges:
+            self._foreign_key_dict[src_table].append((fkey, dst_table))
+            self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
         self._primary_key_dict: dict[str, str] = {
             table.name: table._primary_key
             for table in graph.tables.values()
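Reviewer note: the two new lookup tables are plain inversions of the edge list. A standalone sketch with hypothetical table names (no kumoai imports needed):

    # Edges as (src_table, fkey, dst_table), mirroring the loop in Sampler.__init__:
    edges = [('orders', 'user_id', 'users'), ('orders', 'item_id', 'items')]
    tables = ['users', 'items', 'orders']

    foreign_key_dict = {name: [] for name in tables}      # src -> [(fkey, dst)]
    rev_foreign_key_dict = {name: [] for name in tables}  # dst -> [(src, fkey)]
    for src_table, fkey, dst_table in edges:
        foreign_key_dict[src_table].append((fkey, dst_table))
        rev_foreign_key_dict[dst_table].append((src_table, fkey))

    assert foreign_key_dict['orders'] == [('user_id', 'users'),
                                          ('item_id', 'items')]
    assert rev_foreign_key_dict['users'] == [('orders', 'user_id')]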
@@ -88,10 +99,6 @@ class Sampler(ABC):
                 continue
             self._table_stype_dict[table.name][column.name] = column.stype

-        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
-        for table in graph.tables.values():
-            self._source_table_dict[table.name] = table._source_column_dict
-
         self._min_time_dict: dict[str, pd.Timestamp] = {}
         self._max_time_dict: dict[str, pd.Timestamp] = {}

@@ -102,6 +109,16 @@ class Sampler(ABC):
         r"""All available edge types in the graph."""
         return self._edge_types

+    @property
+    def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign keys for all tables in the graph."""
+        return self._foreign_key_dict
+
+    @property
+    def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign key back references for all tables in the graph."""
+        return self._rev_foreign_key_dict
+
     @property
     def primary_key_dict(self) -> dict[str, str]:
         r"""All available primary keys in the graph."""
@@ -119,16 +136,11 @@ class Sampler(ABC):

     @property
     def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
-        r"""The registered semantic types for all columns in all tables
-        the graph.
+        r"""The registered semantic types for all feature columns in all tables
+        in the graph.
         """
         return self._table_stype_dict

-    @property
-    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
-        r"""Source column information for all tables in the graph."""
-        return self._source_table_dict
-
     def get_min_time(
         self,
         table_names: list[str] | None = None,
@@ -189,7 +201,7 @@ class Sampler(ABC):
             exclude_cols_dict: The columns to exclude from the subgraph.
         """
         # Exclude all columns that leak target information:
-        table_stype_dict: dict[str, dict[str, Stype]] = self.
+        table_stype_dict: dict[str, dict[str, Stype]] = self.table_stype_dict
         if exclude_cols_dict is not None:
             table_stype_dict = copy.deepcopy(table_stype_dict)
             for table_name, exclude_cols in exclude_cols_dict.items():
@@ -237,11 +249,8 @@ class Sampler(ABC):
         # Set end time to NaT for all values greater than anchor time:
         assert table_name not in out.inverse_dict
         ser = df[end_time_column]
-        if ser.dtype != 'datetime64[ns]':
-            ser = ser.astype('datetime64[ns]')
         mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
-
-        df[end_time_column] = ser
+        df.loc[mask, end_time_column] = pd.NaT

         stype_dict = table_stype_dict[table_name]
         for column_name, stype in stype_dict.items():
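Reviewer note: the rewritten block drops the defensive dtype cast and masks in place instead of writing the whole series back. A self-contained illustration of the new masking behavior:

    import pandas as pd

    df = pd.DataFrame({'end_time': pd.to_datetime(['2024-01-01', '2024-06-01'])})
    anchor_time = pd.Timestamp('2024-03-01').value  # Nanoseconds since the epoch.

    mask = df['end_time'].astype(int).to_numpy() > anchor_time
    df.loc[mask, 'end_time'] = pd.NaT  # End times beyond the anchor become NaT.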
kumoai/experimental/rfm/base/sql_sampler.py:

@@ -1,11 +1,18 @@
 from abc import abstractmethod
+from collections import defaultdict
 from typing import TYPE_CHECKING, Literal

 import numpy as np
 import pandas as pd
+from kumoapi.typing import Dtype

-from kumoai.experimental.rfm.base import
-
+from kumoai.experimental.rfm.base import (
+    LocalExpression,
+    Sampler,
+    SamplerOutput,
+    SourceColumn,
+)
+from kumoai.utils import ProgressLogger, quote_ident

 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
@@ -19,18 +26,71 @@ class SQLSampler(Sampler):
     ) -> None:
         super().__init__(graph=graph, verbose=verbose)

-        self.
+        self._source_name_dict: dict[str, str] = {
+            table.name: table._quoted_source_name
+            for table in graph.tables.values()
+        }
+
+        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+        for table in graph.tables.values():
+            self._source_table_dict[table.name] = {}
+            for column in table.columns:
+                if not column.is_source:
+                    continue
+                src_column = table._source_column_dict[column.name]
+                self._source_table_dict[table.name][column.name] = src_column
+
+        self._table_dtype_dict: dict[str, dict[str, Dtype]] = {}
+        for table in graph.tables.values():
+            self._table_dtype_dict[table.name] = {}
+            for column in table.columns:
+                self._table_dtype_dict[table.name][column.name] = column.dtype
+
+        self._table_column_ref_dict: dict[str, dict[str, str]] = {}
+        self._table_column_proj_dict: dict[str, dict[str, str]] = {}
         for table in graph.tables.values():
-
-
-
+            column_ref_dict: dict[str, str] = {}
+            column_proj_dict: dict[str, str] = {}
+            for column in table.columns:
+                if column.expr is not None:
+                    assert isinstance(column.expr, LocalExpression)
+                    column_ref_dict[column.name] = column.expr.value
+                    column_proj_dict[column.name] = (
+                        f'{column.expr} AS {quote_ident(column.name)}')
+                else:
+                    column_ref_dict[column.name] = quote_ident(column.name)
+                    column_proj_dict[column.name] = quote_ident(column.name)
+            self._table_column_ref_dict[table.name] = column_ref_dict
+            self._table_column_proj_dict[table.name] = column_proj_dict
+
+    @property
+    def source_name_dict(self) -> dict[str, str]:
+        r"""The source table names for all tables in the graph."""
+        return self._source_name_dict
+
+    @property
+    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+        r"""The source column information for all tables in the graph."""
+        return self._source_table_dict

     @property
-    def
-        r"""The
+    def table_dtype_dict(self) -> dict[str, dict[str, Dtype]]:
+        r"""The data types for all columns in all tables in the graph."""
+        return self._table_dtype_dict
+
+    @property
+    def table_column_ref_dict(self) -> dict[str, dict[str, str]]:
+        r"""The SQL reference expressions for all columns in all tables in the
         graph.
         """
-        return self.
+        return self._table_column_ref_dict
+
+    @property
+    def table_column_proj_dict(self) -> dict[str, dict[str, str]]:
+        r"""The SQL projection expressions for all columns in all tables in the
+        graph.
+        """
+        return self._table_column_proj_dict

     def _sample_subgraph(
         self,
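Reviewer note: the ref/proj pair gives every column two SQL spellings: how to reference it inside a predicate, and how to project it in a SELECT list. A sketch of the resulting strings, using a stand-in for `kumoai.utils.quote_ident` (its exact quoting rules are not shown in this diff; ANSI double-quoting is assumed):

    def quote_ident(name: str) -> str:  # Stand-in; ANSI double-quoting assumed.
        return '"' + name.replace('"', '""') + '"'

    # A physical column is referenced and projected as its quoted name:
    ref = proj = quote_ident('price')             # '"price"'

    # A logical column is referenced as its raw expression and projected with an alias:
    expr = 'price * quantity'
    ref = expr                                    # 'price * quantity'
    proj = f'{expr} AS {quote_ident("total")}'    # 'price * quantity AS "total"'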
@@ -41,10 +101,25 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:

+        # Make sure to include primary key, foreign key and time columns:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            if time_column := self.time_column_dict.get(table):
+                sample_columns |= {time_column}
+            sample_columns_dict[table] = sample_columns
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-
-            columns=
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
             mask = np.ones(len(entity_pkey), dtype=bool)
@@ -53,23 +128,99 @@ class SQLSampler(Sampler):
                 f"{entity_pkey.iloc[mask].tolist()} do not exist "
                 f"in the '{entity_table_name}' table")

+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)

+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        df_hop_dict: dict[tuple[str, int], pd.DataFrame] = {
+            (entity_table_name, 0): df,
+        }
+        batch_hop_dict: dict[tuple[str, int], np.ndarray] = {
+            (entity_table_name, 0): batch,
+        }
+
+        # Recursive Neighbor Sampling #########################################
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            dfs: dict[str, list[pd.DataFrame]] = defaultdict(list)
+            batches: dict[str, list[np.ndarray]] = defaultdict(list)
+
+            tables = [table for table, i in batch_hop_dict if i == hop]
+            for table in tables:
+                df = df_hop_dict[(table, hop)]
+                batch = batch_hop_dict[(table, hop)]
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    raise NotImplementedError
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=df[self.primary_key_dict[table]],
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )
+                    _batch = batch[_batch]
+
+                    # TODO Filter out duplicates if `src_table` has a pkey.
+                    dfs[src_table].append(_df)
+                    batches[src_table].append(_batch)
+
+            # TODO Add edges to all sampled nodes.
+
+        # Post-Processing #####################################################
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        num_hops = max(hop for _, hop in df_hop_dict.keys())  # TODO
+        num_sampled_nodes_dict: dict[str, list[int]] = {
+            table: [0] * (num_hops + 1)
+            for table in [table for table, _ in df_hop_dict.keys()]
+        }
+        for (table, hop), df in df_hop_dict.items():
+            dfs_dict[table].append(df)
+            batches_dict[table].append(batch_hop_dict[(table, hop)])
+            num_sampled_nodes_dict[table][hop] = len(df)
+
+        df_dict = {  # Concatenate data frames across hops:
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {  # Concatenate batch vector across hops:
+            table:
+            np.concatenate(batches, axis=0) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }

         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict=
-            inverse_dict={},
-            batch_dict=
-            num_sampled_nodes_dict=
-            row_dict={},
-            col_dict={},
-            num_sampled_edges_dict={},
+            df_dict=df_dict,
+            inverse_dict={},  # TODO
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict={},  # TODO
+            col_dict={},  # TODO
+            num_sampled_edges_dict={},  # TODO
         )

     # Abstract Methods ########################################################
@@ -78,7 +229,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass