kumoai 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl → 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/backend/local/table.py +32 -167
- kumoai/experimental/rfm/backend/snow/__init__.py +3 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +58 -81
- kumoai/experimental/rfm/base/__init__.py +5 -0
- kumoai/experimental/rfm/base/sampler.py +134 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +95 -27
- kumoai/experimental/rfm/graph.py +220 -52
- kumoai/experimental/rfm/infer/__init__.py +6 -2
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/{utils.py → infer/pkey.py} +2 -101
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +42 -1
- kumoai/experimental/rfm/local_graph_store.py +1 -16
- kumoai/experimental/rfm/rfm.py +1 -11
- kumoai/kumolib.cp310-win_amd64.pyd +0 -0
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/METADATA +2 -1
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/RECORD +24 -20
- kumoai/experimental/rfm/infer/stype.py +0 -35
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/top_level.txt +0 -0
kumoai/__init__.py
CHANGED
@@ -280,7 +280,19 @@ __all__ = [
 ]
 
 
+def in_snowflake_notebook() -> bool:
+    try:
+        from snowflake.snowpark.context import get_active_session
+        import streamlit  # noqa: F401
+        get_active_session()
+        return True
+    except Exception:
+        return False
+
+
 def in_notebook() -> bool:
+    if in_snowflake_notebook():
+        return True
     try:
         from IPython import get_ipython
         shell = get_ipython()
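The new `in_snowflake_notebook()` helper detects a Snowflake notebook by probing for an active Snowpark session and an importable `streamlit` module, and `in_notebook()` now short-circuits on it before falling back to the IPython check. A minimal sketch of how a caller might branch on these helpers (the dispatch logic below is illustrative, not part of the package):

import kumoai

# Hypothetical runtime dispatch built on the two helpers above.
if kumoai.in_snowflake_notebook():
    environment = 'snowflake-notebook'  # Snowpark session + streamlit found.
elif kumoai.in_notebook():
    environment = 'ipython-notebook'    # Generic IPython/Jupyter shell.
else:
    environment = 'terminal'
print(f"Running in: {environment}")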
kumoai/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.13.0.dev202512021731'
+__version__ = '2.13.0.dev202512041731'
kumoai/experimental/rfm/backend/local/table.py
CHANGED

@@ -1,14 +1,10 @@
-
+import warnings
+from typing import List, Optional
 
-import numpy as np
 import pandas as pd
-import pyarrow as pa
-from kumoapi.typing import Dtype, Stype
-from typing_extensions import Self
 
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.experimental.rfm.infer import infer_stype
+from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+from kumoai.experimental.rfm.infer import infer_dtype
 
 
 class LocalTable(Table):

@@ -62,7 +58,7 @@ class LocalTable(Table):
     ) -> None:
 
         if df.empty:
-            raise ValueError("Data frame
+            raise ValueError("Data frame is empty")
         if isinstance(df.columns, pd.MultiIndex):
             raise ValueError("Data frame must not have a multi-index")
         if not df.columns.is_unique:

@@ -80,165 +76,34 @@
             end_time_column=end_time_column,
         )
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            column.name for column in self.columns if is_candidate(column)
-        ]
-
-        if primary_key := utils.detect_primary_key(
-            table_name=self.name,
-            df=self._data,
-            candidates=candidates,
-        ):
-            self.primary_key = primary_key
-            logs.append(f"primary key '{primary_key}'")
-
-        # Try to detect time column if not set:
-        if not self.has_time_column():
-            candidates = [
-                column.name for column in self.columns
-                if column.stype == Stype.timestamp
-                and column.name != self._end_time_column
-            ]
-            if time_column := utils.detect_time_column(self._data, candidates):
-                self.time_column = time_column
-                logs.append(f"time column '{time_column}'")
-
-        if verbose and len(logs) > 0:
-            print(f"Detected {' and '.join(logs)} in table '{self.name}'")
-
-        return self
-
-    def _has_source_column(self, name: str) -> bool:
-        return name in self._data.columns
-
-    def _get_source_dtype(self, name: str) -> Dtype:
-        return to_dtype(self._data[name])
-
-    def _get_source_stype(self, name: str, dtype: Dtype) -> Stype:
-        return infer_stype(self._data[name], name, dtype)
-
-    def _get_source_foreign_keys(self) -> List[Tuple[str, str, str]]:
+    def _get_source_columns(self) -> List[SourceColumn]:
+        source_columns: List[SourceColumn] = []
+        for column in self._data.columns:
+            ser = self._data[column]
+            try:
+                dtype = infer_dtype(ser)
+            except Exception:
+                warnings.warn(f"Data type inference for column '{column}' in "
+                              f"table '{self.name}' failed. Consider changing "
+                              f"the data type of the column to use it within "
+                              f"this table.")
+                continue
+
+            source_column = SourceColumn(
+                name=column,
+                dtype=dtype,
+                is_primary_key=False,
+                is_unique_key=False,
+            )
+            source_columns.append(source_column)
+
+        return source_columns
+
+    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
         return []
 
-    def
-        return
-            table_name=self.name,
-            df=self._data,
-            candidates=candidates,
-        )
-
-    def _infer_time_column(self, candidates: List[str]) -> Optional[str]:
-        return utils.detect_time_column(df=self._data, candidates=candidates)
+    def _get_sample_df(self) -> pd.DataFrame:
+        return self._data
 
-    def
+    def _get_num_rows(self) -> Optional[int]:
         return len(self._data)
-
-
-# Data Type ###################################################################
-
-PANDAS_TO_DTYPE: Dict[Any, Dtype] = {
-    np.dtype('bool'): Dtype.bool,
-    pd.BooleanDtype(): Dtype.bool,
-    pa.bool_(): Dtype.bool,
-    np.dtype('byte'): Dtype.int,
-    pd.UInt8Dtype(): Dtype.int,
-    np.dtype('int16'): Dtype.int,
-    pd.Int16Dtype(): Dtype.int,
-    np.dtype('int32'): Dtype.int,
-    pd.Int32Dtype(): Dtype.int,
-    np.dtype('int64'): Dtype.int,
-    pd.Int64Dtype(): Dtype.int,
-    np.dtype('float32'): Dtype.float,
-    pd.Float32Dtype(): Dtype.float,
-    np.dtype('float64'): Dtype.float,
-    pd.Float64Dtype(): Dtype.float,
-    np.dtype('object'): Dtype.string,
-    pd.StringDtype(storage='python'): Dtype.string,
-    pd.StringDtype(storage='pyarrow'): Dtype.string,
-    pa.string(): Dtype.string,
-    pa.binary(): Dtype.binary,
-    np.dtype('datetime64[ns]'): Dtype.date,
-    np.dtype('timedelta64[ns]'): Dtype.timedelta,
-    pa.list_(pa.float32()): Dtype.floatlist,
-    pa.list_(pa.int64()): Dtype.intlist,
-    pa.list_(pa.string()): Dtype.stringlist,
-}
-
-
-def to_dtype(ser: pd.Series) -> Dtype:
-    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
-
-    Args:
-        ser: A :class:`pandas.Series` to analyze.
-
-    Returns:
-        The data type.
-    """
-    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
-        return Dtype.date
-
-    if isinstance(ser.dtype, pd.CategoricalDtype):
-        return Dtype.string
-
-    if pd.api.types.is_object_dtype(ser.dtype):
-        index = ser.iloc[:1000].first_valid_index()
-        if index is not None and pd.api.types.is_list_like(ser[index]):
-            pos = ser.index.get_loc(index)
-            assert isinstance(pos, int)
-            ser = ser.iloc[pos:pos + 1000].dropna()
-
-            if not ser.map(pd.api.types.is_list_like).all():
-                raise ValueError("Data contains a mix of list-like and "
-                                 "non-list-like values")
-
-            # Remove all empty Python lists without known data type:
-            ser = ser[ser.map(lambda x: not isinstance(x, list) or len(x) > 0)]
-
-            # Infer unique data types in this series:
-            dtypes = ser.apply(lambda x: PANDAS_TO_DTYPE.get(
-                np.array(x).dtype, Dtype.string)).unique().tolist()
-
-            invalid_dtypes = set(dtypes) - {
-                Dtype.string,
-                Dtype.int,
-                Dtype.float,
-            }
-            if len(invalid_dtypes) > 0:
-                raise ValueError(f"Data contains unsupported list data types: "
-                                 f"{list(invalid_dtypes)}")
-
-            if Dtype.string in dtypes:
-                return Dtype.stringlist
-
-            if dtypes == [Dtype.int]:
-                return Dtype.intlist
-
-            return Dtype.floatlist
-
-    if ser.dtype not in PANDAS_TO_DTYPE:
-        raise ValueError(f"Unsupported data type '{ser.dtype}'")
-
-    return PANDAS_TO_DTYPE[ser.dtype]
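The net effect of this refactor: per-column dtype inference moves out of LocalTable (the removed PANDAS_TO_DTYPE table and to_dtype() helper) and into the shared infer_dtype() from kumoai.experimental.rfm.infer, and a column that fails inference is now skipped with a warning instead of failing table construction. A standalone sketch of the new per-column control flow (the wrapper function and dict output are illustrative; the SourceColumn fields follow the diff):

import warnings
import pandas as pd

def collect_source_columns(df: pd.DataFrame, infer_dtype) -> list[dict]:
    # Mirrors LocalTable._get_source_columns: attempt inference per column,
    # warn and drop the column on failure rather than raising.
    columns: list[dict] = []
    for name in df.columns:
        try:
            dtype = infer_dtype(df[name])
        except Exception:
            warnings.warn(f"Data type inference for column '{name}' failed")
            continue
        columns.append(dict(name=name, dtype=dtype,
                            is_primary_key=False, is_unique_key=False))
    return columns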
kumoai/experimental/rfm/backend/snow/table.py
ADDED

@@ -0,0 +1,117 @@
+import re
+from typing import List, Optional, Sequence
+
+import pandas as pd
+from kumoapi.typing import Dtype
+
+from kumoai.experimental.rfm.backend.snow import Connection
+from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+
+
+class SnowTable(Table):
+    r"""A table backed by a :class:`snowflake` database.
+
+    Args:
+        connection: The connection to a :class:`snowflake` database.
+        name: The name of this table.
+        database: The database.
+        schema: The schema.
+        columns: The selected columns of this table.
+        primary_key: The name of the primary key of this table, if it exists.
+        time_column: The name of the time column of this table, if it exists.
+        end_time_column: The name of the end time column of this table, if it
+            exists.
+    """
+    def __init__(
+        self,
+        connection: Connection,
+        name: str,
+        database: str | None = None,
+        schema: str | None = None,
+        columns: Optional[Sequence[str]] = None,
+        primary_key: Optional[str] = None,
+        time_column: Optional[str] = None,
+        end_time_column: Optional[str] = None,
+    ) -> None:
+
+        if database is not None and schema is None:
+            raise ValueError(f"Missing 'schema' for table '{name}' in "
+                             f"database '{database}'")
+
+        self._connection = connection
+        self._database = database
+        self._schema = schema
+
+        super().__init__(
+            name=name,
+            columns=columns,
+            primary_key=primary_key,
+            time_column=time_column,
+            end_time_column=end_time_column,
+        )
+
+    @property
+    def fqn_name(self) -> str:
+        names: List[str] = []
+        if self._database is not None:
+            assert self._schema is not None
+            names.extend([self._database, self._schema])
+        elif self._schema is not None:
+            names.append(self._schema)
+        names.append(self._name)
+        return '.'.join(names)
+
+    def _get_source_columns(self) -> List[SourceColumn]:
+        source_columns: List[SourceColumn] = []
+        with self._connection.cursor() as cursor:
+            try:
+                cursor.execute(f"DESCRIBE TABLE {self.fqn_name}")
+            except Exception as e:
+                raise ValueError(
+                    f"Table '{self.fqn_name}' does not exist") from e
+
+            for row in cursor.fetchall():
+                column, type, _, _, _, is_pkey, is_unique = row[:7]
+
+                type = type.strip().upper()
+                if type.startswith('NUMBER'):
+                    dtype = Dtype.int
+                elif type.startswith('VARCHAR'):
+                    dtype = Dtype.string
+                elif type == 'FLOAT':
+                    dtype = Dtype.float
+                elif type == 'BOOLEAN':
+                    dtype = Dtype.bool
+                elif re.search('DATE|TIMESTAMP', type):
+                    dtype = Dtype.date
+                else:
+                    continue
+
+                source_column = SourceColumn(
+                    name=column,
+                    dtype=dtype,
+                    is_primary_key=is_pkey.strip().upper() == 'Y',
+                    is_unique_key=is_unique.strip().upper() == 'Y',
+                )
+                source_columns.append(source_column)
+
+        return source_columns
+
+    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+        source_fkeys: List[SourceForeignKey] = []
+        with self._connection.cursor() as cursor:
+            cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.fqn_name}")
+            for row in cursor.fetchall():
+                _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
+                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
+        return source_fkeys
+
+    def _get_sample_df(self) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            columns = ', '.join(self._source_column_dict.keys())
+            cursor.execute(f"SELECT {columns} FROM {self.fqn_name} LIMIT 1000")
+            table = cursor.fetch_arrow_all()
+            return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+    def _get_num_rows(self) -> Optional[int]:
+        return None
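SnowTable reads column metadata from DESCRIBE TABLE, whose rows carry the column name, its Snowflake type, and 'Y'/'N' primary/unique key flags; only five type families are mapped, and unrecognized types are skipped. A condensed, dependency-free sketch of that mapping (the wrapper function and string return values are illustrative):

import re
from typing import Optional

def snowflake_type_to_dtype(type_str: str) -> Optional[str]:
    # Mirrors the DESCRIBE TABLE branch in SnowTable._get_source_columns.
    t = type_str.strip().upper()
    if t.startswith('NUMBER'):
        return 'int'      # NUMBER(38,0), NUMBER(10,2), ...
    if t.startswith('VARCHAR'):
        return 'string'
    if t == 'FLOAT':
        return 'float'
    if t == 'BOOLEAN':
        return 'bool'
    if re.search('DATE|TIMESTAMP', t):
        return 'date'     # DATE, TIMESTAMP_NTZ(9), TIMESTAMP_TZ, ...
    return None           # Unmapped types are dropped from the table.

assert snowflake_type_to_dtype('NUMBER(38,0)') == 'int'
assert snowflake_type_to_dtype('TIMESTAMP_NTZ(9)') == 'date'
assert snowflake_type_to_dtype('VARIANT') is None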
kumoai/experimental/rfm/backend/sqlite/table.py
CHANGED

@@ -1,13 +1,13 @@
 import re
-
+import warnings
+from typing import List, Optional, Sequence
 
-import
-from kumoapi.typing import Dtype
-from typing_extensions import Self
+import pandas as pd
+from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.sqlite import Connection
-from kumoai.experimental.rfm.base import Table
-from kumoai.experimental.rfm.infer import
+from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+from kumoai.experimental.rfm.infer import infer_dtype
 
 
 class SQLiteTable(Table):

@@ -33,92 +33,69 @@ class SQLiteTable(Table):
     ) -> None:
 
         self._connection = connection
-        self._dtype_dict: Dict[str, Dtype] = {}
-
-        with connection.cursor() as cursor:
-            cursor.execute(f"PRAGMA table_info({name})")
-            for _, column, dtype, _, _, is_pkey in cursor.fetchall():
-                if bool(is_pkey):
-                    if primary_key is not None and primary_key != column:
-                        raise ValueError(f"Found duplicate primary key "
-                                         f"definition '{primary_key}' and "
-                                         f"'{column}' in table '{name}'")
-                    primary_key = column
-
-                # Determine colun affinity:
-                dtype = dtype.strip().upper()
-                if re.search('INT', dtype):
-                    self._dtype_dict[column] = Dtype.int
-                elif re.search('TEXT|CHAR|CLOB', dtype):
-                    self._dtype_dict[column] = Dtype.string
-                elif re.search('REAL|FLOA|DOUB', dtype):
-                    self._dtype_dict[column] = Dtype.float
-                else:  # NUMERIC affinity.
-                    self._dtype_dict[column] = Dtype.unsupported
-
-            if len(self._dtype_dict) > 0:
-                column_names = ', '.join(self._dtype_dict.keys())
-                cursor.execute(f"SELECT {column_names} FROM {name} "
-                               f"ORDER BY rowid LIMIT 1000")
-                self._sample = cursor.fetch_arrow_table()
-
-                for column_name in list(self._dtype_dict.keys()):
-                    if self._dtype_dict[column_name] == Dtype.unsupported:
-                        dtype = self._sample[column_name].type
-                        if pa.types.is_integer(dtype):
-                            self._dtype_dict[column_name] = Dtype.int
-                        elif pa.types.is_floating(dtype):
-                            self._dtype_dict[column_name] = Dtype.float
-                        elif pa.types.is_decimal(dtype):
-                            self._dtype_dict[column_name] = Dtype.float
-                        elif pa.types.is_string(dtype):
-                            self._dtype_dict[column_name] = Dtype.string
-                        else:
-                            del self._dtype_dict[column_name]
-
-        if len(self._dtype_dict) == 0:
-            raise RuntimeError(f"Table '{name}' does not exist or does not "
-                               f"hold any column with a supported data type")
 
         super().__init__(
             name=name,
-            columns=columns
+            columns=columns,
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
        )
 
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def _get_source_columns(self) -> List[SourceColumn]:
+        source_columns: List[SourceColumn] = []
+        with self._connection.cursor() as cursor:
+            cursor.execute(f"PRAGMA table_info({self.name})")
+            rows = cursor.fetchall()
+
+            if len(rows) == 0:
+                raise ValueError(f"Table '{self.name}' does not exist")
+
+            for _, column, type, _, _, is_pkey in rows:
+                # Determine column affinity:
+                type = type.strip().upper()
+                if re.search('INT', type):
+                    dtype = Dtype.int
+                elif re.search('TEXT|CHAR|CLOB', type):
+                    dtype = Dtype.string
+                elif re.search('REAL|FLOA|DOUB', type):
+                    dtype = Dtype.float
+                else:  # NUMERIC affinity.
+                    ser = self._sample_df[column]
+                    try:
+                        dtype = infer_dtype(ser)
+                    except Exception:
+                        warnings.warn(
+                            f"Data type inference for column '{column}' in "
+                            f"table '{self.name}' failed. Consider changing "
+                            f"the data type of the column to use it within "
+                            f"this table.")
+                        continue
+
+                source_column = SourceColumn(
+                    name=column,
+                    dtype=dtype,
+                    is_primary_key=bool(is_pkey),
+                    is_unique_key=False,
+                )
+                source_columns.append(source_column)
+
+        return source_columns
+
+    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+        source_fkeys: List[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
             cursor.execute(f"PRAGMA foreign_key_list({self.name})")
             for _, _, dst_table, fkey, pkey, _, _, _ in cursor.fetchall():
-
-                return
-
-    def _infer_primary_key(self, candidates: List[str]) -> Optional[str]:
-        return None  # TODO
+                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
+        return source_fkeys
 
-    def
-
+    def _get_sample_df(self) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {self.name} "
+                           f"ORDER BY rowid LIMIT 1000")
+            table = cursor.fetch_arrow_table()
+            return table.to_pandas(types_mapper=pd.ArrowDtype)
 
-    def
+    def _get_num_rows(self) -> Optional[int]:
         return None
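The rewritten _get_source_columns keeps SQLite's type-affinity heuristics (substring matches on the declared column type, following SQLite's affinity rules) but routes the NUMERIC fallback through the shared infer_dtype() on a 1000-row sample instead of the removed pyarrow-based checks. A small sketch of just the affinity step (the wrapper function and string return values are illustrative):

import re

def sqlite_affinity(declared_type: str) -> str:
    # Mirrors the branch in SQLiteTable._get_source_columns.
    t = declared_type.strip().upper()
    if re.search('INT', t):
        return 'int'
    if re.search('TEXT|CHAR|CLOB', t):
        return 'string'
    if re.search('REAL|FLOA|DOUB', t):
        return 'float'
    return 'numeric'  # Resolved by sampling rows + infer_dtype in the diff.

assert sqlite_affinity('BIGINT') == 'int'
assert sqlite_affinity('VARCHAR(255)') == 'string'
assert sqlite_affinity('DECIMAL(10,2)') == 'numeric'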
kumoai/experimental/rfm/base/sampler.py
ADDED

@@ -0,0 +1,134 @@
+import copy
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+from kumoapi.rfm.context import Subgraph
+from kumoapi.typing import Stype
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+@dataclass
+class EdgeSpec:
+    num_neighbors: int | None = None
+    time_offsets: tuple[
+        pd.DateOffset | None,
+        pd.DateOffset,
+    ] | None = None
+
+    def __post_init__(self) -> None:
+        if (self.num_neighbors is None) == (self.time_offsets is None):
+            raise ValueError("Only one of 'num_neighbors' and 'time_offsets' "
+                             "must be provided")
+
+
+@dataclass
+class SamplerOutput:
+    df_dict: dict[str, pd.DataFrame]
+    batch_dict: dict[str, pd.DataFrame]
+    num_sampled_nodes_dict: dict[str, list[int]]
+    edge_index_dict: dict[tuple[str, str, str], np.ndarray] | None = None
+    num_sampled_edges_dict: dict[tuple[str, str, str], list[int]] | None = None
+
+
+class Sampler(ABC):
+    def __init__(self, graph: 'Graph') -> None:
+        self._edge_types: list[tuple[str, str, str]] = []
+        for edge in graph.edges:
+            edge_type = (edge.src_table, edge.fkey, edge.dst_table)
+            self._edge_types.append(edge_type)
+            self._edge_types.append(Subgraph.rev_edge_type(edge_type))
+
+        self._primary_key_dict: dict[str, str] = {
+            table.name: table._primary_key
+            for table in graph.tables.values()
+            if table._primary_key is not None
+        }
+
+        self._time_column_dict: dict[str, str] = {
+            table.name: table._time_column
+            for table in graph.tables.values()
+            if table._time_column is not None
+        }
+
+        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        self._stype_dict: dict[str, dict[str, Stype]] = {}
+        for table in graph.tables.values():
+            self._stype_dict[table.name] = {}
+            for column in table.columns:
+                if column == table.primary_key:
+                    continue
+                if (table.name, column.name) in foreign_keys:
+                    continue
+                self._stype_dict[table.name][column.name] = column.stype
+
+    @property
+    def edge_types(self) -> list[tuple[str, str, str]]:
+        return self._edge_types
+
+    @property
+    def primary_key_dict(self) -> dict[str, str]:
+        return self._primary_key_dict
+
+    @property
+    def time_column_dict(self) -> dict[str, str]:
+        return self._time_column_dict
+
+    @property
+    def stype_dict(self) -> dict[str, dict[str, Stype]]:
+        return self._stype_dict
+
+    def sample_subgraph(
+        self,
+        entity_table_names: tuple[str, ...],
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series,
+        num_neighbors: list[int],
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+    ) -> Subgraph:
+
+        edge_spec: dict[tuple[str, str, str], list[EdgeSpec]] = {
+            edge_type: [EdgeSpec(value) for value in num_neighbors]
+            for edge_type in self.edge_types
+        }
+
+        stype_dict: dict[str, dict[str, Stype]] = self._stype_dict
+        if exclude_cols_dict is not None:
+            stype_dict = copy.deepcopy(stype_dict)
+            for table_name, exclude_cols in exclude_cols_dict.items():
+                for column_name in exclude_cols:
+                    del stype_dict[table_name][column_name]
+
+        column_spec: dict[str, list[str]] = {
+            table_name: list(stypes.keys())
+            for table_name, stypes in stype_dict.items()
+        }
+        for table_name in entity_table_names:
+            column_spec[table_name].append(self.primary_key_dict[table_name])
+
+        return self._sample(
+            entity_table_name=entity_table_names[0],
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
+            column_spec=column_spec,
+            edge_spec=edge_spec,
+            return_edges=True,
+        )
+
+    # Abstract Methods ########################################################
+
+    @abstractmethod
+    def _sample(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series,
+        column_spec: dict[str, list[str]],
+        edge_spec: dict[tuple[str, str, str], list[EdgeSpec]],
+        return_edges: bool = False,
+    ) -> SamplerOutput:
+        pass