kumoai 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/jobs.py +2 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/snow/sampler.py +83 -14
- kumoai/experimental/rfm/backend/sqlite/sampler.py +68 -12
- kumoai/experimental/rfm/base/mapper.py +67 -0
- kumoai/experimental/rfm/base/sampler.py +21 -0
- kumoai/experimental/rfm/base/sql_sampler.py +233 -10
- kumoai/experimental/rfm/base/table.py +41 -53
- kumoai/experimental/rfm/graph.py +57 -60
- kumoai/experimental/rfm/infer/dtype.py +2 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +529 -303
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +24 -20
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/sampler.py

```diff
@@ -59,6 +59,17 @@ class Sampler(ABC):
         self._edge_types.append(edge_type)
         self._edge_types.append(Subgraph.rev_edge_type(edge_type))

+        # Source Table -> [(Foreign Key, Destination Table)]
+        self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        # Destination Table -> [(Source Table, Foreign Key)]
+        self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+        for table in graph.tables.values():
+            self._foreign_key_dict[table.name] = []
+            self._rev_foreign_key_dict[table.name] = []
+        for src_table, fkey, dst_table in graph.edges:
+            self._foreign_key_dict[src_table].append((fkey, dst_table))
+            self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
         self._primary_key_dict: dict[str, str] = {
             table.name: table._primary_key
             for table in graph.tables.values()
@@ -98,6 +109,16 @@ class Sampler(ABC):
         r"""All available edge types in the graph."""
         return self._edge_types

+    @property
+    def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign keys for all tables in the graph."""
+        return self._foreign_key_dict
+
+    @property
+    def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+        r"""The foreign key back references for all tables in the graph."""
+        return self._rev_foreign_key_dict
+
     @property
     def primary_key_dict(self) -> dict[str, str]:
         r"""All available primary keys in the graph."""
```
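For orientation, the two structures added here are plain adjacency lists keyed by table name: one following foreign keys forward, one following them backward. A minimal sketch of how they relate, using a hypothetical `orders`/`customers`/`products` schema (all table and key names below are illustrative, not taken from the package):

```python
# Toy edges in the same (src_table, foreign_key, dst_table) layout as
# `graph.edges`; the table and key names are made up for illustration.
edges = [
    ('orders', 'customer_id', 'customers'),
    ('orders', 'product_id', 'products'),
]
foreign_key_dict = {t: [] for t in ('orders', 'customers', 'products')}
rev_foreign_key_dict = {t: [] for t in ('orders', 'customers', 'products')}
for src_table, fkey, dst_table in edges:
    foreign_key_dict[src_table].append((fkey, dst_table))
    rev_foreign_key_dict[dst_table].append((src_table, fkey))

# Forward: which tables does `orders` point to?
assert foreign_key_dict['orders'] == [('customer_id', 'customers'),
                                      ('product_id', 'products')]
# Backward: who points at `customers`?
assert rev_foreign_key_dict['customers'] == [('orders', 'customer_id')]
```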
kumoai/experimental/rfm/base/sql_sampler.py

```diff
@@ -1,8 +1,10 @@
 from abc import abstractmethod
+from collections import defaultdict
 from typing import TYPE_CHECKING, Literal

 import numpy as np
 import pandas as pd
+from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Dtype

 from kumoai.experimental.rfm.base import (
@@ -11,11 +13,14 @@ from kumoai.experimental.rfm.base import (
     SamplerOutput,
     SourceColumn,
 )
+from kumoai.experimental.rfm.base.mapper import Mapper
 from kumoai.utils import ProgressLogger, quote_ident

 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph

+EdgeType = tuple[str, str, str]
+

 class SQLSampler(Sampler):
     def __init__(
```
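The newly imported `Mapper` is used further down to translate (primary key, example) pairs into local node IDs, returning `-1` for pairs it has not seen yet. Its implementation lives in the new `base/mapper.py` and is not shown in this diff; the following is only a sketch of the behavior its call sites appear to rely on, not the actual `kumoai.experimental.rfm.base.mapper.Mapper`:

```python
import numpy as np
import pandas as pd


class MapperSketch:
    """Hypothetical stand-in for the imported `Mapper` (behavior inferred
    from its call sites below; the real class may differ)."""
    def __init__(self, num_examples: int) -> None:
        self.num_examples = num_examples
        self._ids: dict[tuple[int, object], int] = {}

    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
        # Register each (example, key) pair under the next free local ID:
        for b, k in zip(batch.tolist(), pkey.tolist()):
            self._ids.setdefault((b, k), len(self._ids))

    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
        # Look up local IDs; unknown pairs map to -1:
        return np.array([
            self._ids.get((b, k), -1)
            for b, k in zip(batch.tolist(), pkey.tolist())
        ])
```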
kumoai/experimental/rfm/base/sql_sampler.py

```diff
@@ -100,10 +105,28 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:

+        # Make sure to always include primary key, foreign key and time columns
+        # during data fetching since these are needed for graph traversal:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            sample_columns_dict[table] = sample_columns
+        if not isinstance(anchor_time, pd.Series):
+            sample_columns_dict[entity_table_name] |= {
+                self.time_column_dict[entity_table_name]
+            }
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-
-            columns=
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
             mask = np.ones(len(entity_pkey), dtype=bool)
```
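The effect of the widening above is easiest to see on a toy schema: the fetch set gains the key columns needed for traversal, while `columns_dict` itself stays untouched and is re-applied as a post-filter at the end of sampling. A self-contained sketch (hypothetical table and column names):

```python
columns_dict = {'orders': {'price'}, 'customers': {'age'}}
foreign_key_dict = {'orders': [('customer_id', 'customers')],
                    'customers': []}
primary_key_dict = {'orders': 'order_id', 'customers': 'customer_id'}

sample_columns_dict: dict[str, set[str]] = {}
for table, columns in columns_dict.items():
    sample_columns = columns | {
        fkey for fkey, _ in foreign_key_dict[table]
    }
    if primary_key := primary_key_dict.get(table):
        sample_columns |= {primary_key}
    sample_columns_dict[table] = sample_columns

assert sample_columns_dict['orders'] == {'price', 'customer_id', 'order_id'}
assert columns_dict['orders'] == {'price'}  # User-facing set is unchanged.
```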
kumoai/experimental/rfm/base/sql_sampler.py

```diff
@@ -112,23 +135,211 @@
                 f"{entity_pkey.iloc[mask].tolist()} do not exist "
                 f"in the '{entity_table_name}' table")

+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)

+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        # Recursive Neighbor Sampling #########################################
+
+        mapper_dict: dict[str, Mapper] = defaultdict(
+            lambda: Mapper(num_examples=len(entity_pkey)))
+        mapper_dict[entity_table_name].add(
+            pkey=df[self.primary_key_dict[entity_table_name]],
+            batch=batch,
+        )
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        dfs_dict[entity_table_name].append(df)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        batches_dict[entity_table_name].append(batch)
+        num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
+            lambda: [0] * (len(num_neighbors) + 1))
+        num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
+
+        rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
+            lambda: [0] * len(num_neighbors))
+
+        # The start index of data frame slices of the previous hop:
+        offset_dict: dict[str, int] = defaultdict(int)
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            for table in list(num_sampled_nodes_dict.keys()):
+                # Only sample from tables that have been visited in the
+                # previous hop:
+                if num_sampled_nodes_dict[table][hop] == 0:
+                    continue
+
+                # Collect the slices of data sampled in the previous hop
+                # (but maintain only required key information):
+                cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
+                if table in self.primary_key_dict:
+                    cols.append(self.primary_key_dict[table])
+                dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
+                df = pd.concat(
+                    dfs,
+                    axis=0,
+                    ignore_index=True,
+                ) if len(dfs) > 1 else dfs[0]
+                batches = batches_dict[table][offset_dict[table]:]
+                batch = (np.concatenate(batches)
+                         if len(batches) > 1 else batches[0])
+                offset_dict[table] = len(batches_dict[table])  # Increase.
+
+                pkey: pd.Series | None = None
+                index: np.ndarray | None = None
+                if table in self.primary_key_dict:
+                    pkey = df[self.primary_key_dict[table]]
+                    index = mapper_dict[table].get(pkey, batch)
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    row = mapper_dict[dst_table].get(df[fkey], batch)
+                    mask = row == -1
+                    if mask.any():
+                        key_df = pd.DataFrame({
+                            'fkey': df[fkey],
+                            'batch': batch,
+                        }).iloc[mask]
+                        # Only maintain unique keys per example:
+                        unique_key_df = key_df.drop_duplicates()
+                        # Fully de-duplicate keys across examples:
+                        code, fkey_index = pd.factorize(unique_key_df['fkey'])
+
+                        _df, _batch = self._by_pkey(
+                            table_name=dst_table,
+                            index=fkey_index,
+                            columns=sample_columns_dict[dst_table],
+                        )  # Ensure result is sorted according to input order:
+                        _df = _df.iloc[_batch.argsort()]
+
+                        # Compute valid entries (without dangling foreign keys)
+                        # in `unique_key_df`:
+                        _mask = np.full(len(fkey_index), fill_value=False)
+                        _mask[_batch] = True
+                        _mask = _mask[code]
+
+                        # Reconstruct unique (key, batch) pairs:
+                        code, _ = pd.factorize(unique_key_df['fkey'][_mask])
+                        _df = _df.iloc[code].reset_index(drop=True)
+                        _batch = unique_key_df['batch'].to_numpy()[_mask]
+
+                        # Register node IDs:
+                        mapper_dict[dst_table].add(_df[fkey], _batch)
+                        row[mask] = mapper_dict[dst_table].get(
+                            pkey=key_df['fkey'],
+                            batch=key_df['batch'].to_numpy(),
+                        )  # NOTE `row` may still hold `-1` for dangling fkeys.
+
+                        dfs_dict[dst_table].append(_df)
+                        batches_dict[dst_table].append(_batch)
+                        num_sampled_nodes_dict[dst_table][hop + 1] += (  #
+                            len(_batch))
+
+                    mask = row != -1
+
+                    col = index
+                    if col is None:
+                        num_nodes = num_sampled_nodes_dict[table][hop]
+                        col = np.arange(num_nodes, num_nodes + len(row))
+
+                    row = row[mask]
+                    col = col[mask]
+
+                    edge_type = (table, fkey, dst_table)
+                    edge_type = Subgraph.rev_edge_type(edge_type)
+                    rows_dict[edge_type].append(row)
+                    cols_dict[edge_type].append(col)
+                    num_sampled_edges_dict[edge_type][hop] = len(col)
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    assert pkey is not None and index is not None
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=pkey,
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )
+
+                    edge_type = (src_table, fkey, table)
+                    cols_dict[edge_type].append(index[_batch])
+                    num_sampled_edges_dict[edge_type][hop] = len(_batch)
+
+                    _batch = batch[_batch]
+                    num_nodes = sum(num_sampled_nodes_dict[src_table])
+                    if src_table in self.primary_key_dict:
+                        _pkey = _df[self.primary_key_dict[src_table]]
+                        mapper_dict[src_table].add(_pkey, _batch)
+                        row = mapper_dict[src_table].get(_pkey, _batch)
+
+                        # Only preserve unknown rows:
+                        mask = row >= num_nodes  # type: ignore
+                        mask[pd.Series(row).duplicated().to_numpy()] = False
+                        _df = _df.iloc[mask]
+                        _batch = _batch[mask]
+                    else:
+                        row = np.arange(num_nodes, num_nodes + len(_batch))
+
+                    rows_dict[edge_type].append(row)
+                    num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
+
+                    dfs_dict[src_table].append(_df)
+                    batches_dict[src_table].append(_batch)
+
+        # Post-Processing #####################################################
+
+        df_dict = {
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {
+            table: np.concatenate(batches) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }
+        row_dict = {
+            edge_type: np.concatenate(rows)
+            for edge_type, rows in rows_dict.items()
+        }
+        col_dict = {
+            edge_type: np.concatenate(cols)
+            for edge_type, cols in cols_dict.items()
+        }
+
+        if len(num_sampled_edges_dict) == 0:  # Single table:
+            num_sampled_nodes_dict = {
+                key: value[:1]
+                for key, value in num_sampled_nodes_dict.items()
+            }

         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict=
-            inverse_dict={},
-            batch_dict=
-            num_sampled_nodes_dict=
-            row_dict=
-            col_dict=
-            num_sampled_edges_dict=
+            df_dict=df_dict,
+            inverse_dict={},  # TODO
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
         )

     # Abstract Methods ########################################################
```
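The trickiest part of the hunk above is the two-stage de-duplication of missing foreign keys: `drop_duplicates()` keeps one `(fkey, batch)` pair per example, and `pd.factorize` then collapses identical keys across examples so each key is fetched from the backend only once, with `code` mapping the pairs back onto the fetched rows. A stand-alone illustration of those two calls:

```python
import pandas as pd

key_df = pd.DataFrame({
    'fkey':  [10, 10, 20, 10, 20],
    'batch': [0,  0,  0,  1,  1],
})
unique_key_df = key_df.drop_duplicates()  # (10,0), (20,0), (10,1), (20,1)
code, fkey_index = pd.factorize(unique_key_df['fkey'])

print(fkey_index.tolist())  # [10, 20] -> unique keys to fetch once each
print(code.tolist())        # [0, 1, 0, 1] -> row in the fetched result
```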
kumoai/experimental/rfm/base/sql_sampler.py

```diff
@@ -137,7 +348,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass
```
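The backend-specific samplers (`backend/sqlite/sampler.py`, `backend/snow/sampler.py`) implement these two hooks with SQL. To make the contract concrete, here is a minimal in-memory sketch with the semantics the calling code appears to assume: the returned `batch` array holds positions into `index`, missing keys are silently dropped, and `_by_fkey` additionally caps fan-out per key and filters by anchor time. This is an illustration only, assuming unique primary keys, a non-empty `index`, and a single time column:

```python
import numpy as np
import pandas as pd


def by_pkey(df: pd.DataFrame, pkey: str, index: pd.Series,
            columns: set[str]) -> tuple[pd.DataFrame, np.ndarray]:
    # Map each key value back to its position in `index`:
    pos = pd.Series(np.arange(len(index)), index=index.to_numpy())
    hit = df[df[pkey].isin(index)]
    return hit[list(columns)], pos[hit[pkey]].to_numpy()


def by_fkey(df: pd.DataFrame, fkey: str, time_col: str, index: pd.Series,
            num_neighbors: int, anchor_time: pd.Series | None,
            columns: set[str]) -> tuple[pd.DataFrame, np.ndarray]:
    out_dfs, out_batches = [], []
    for i, key in enumerate(index):
        rows = df[df[fkey] == key]
        if anchor_time is not None:  # Only sample rows from the past:
            rows = rows[rows[time_col] <= anchor_time.iloc[i]]
        rows = rows.tail(num_neighbors)  # Cap fan-out per entity.
        out_dfs.append(rows)
        out_batches.append(np.full(len(rows), i))
    return (pd.concat(out_dfs, ignore_index=True)[list(columns)],
            np.concatenate(out_batches))
```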
kumoai/experimental/rfm/base/table.py

```diff
@@ -5,6 +5,7 @@ from functools import cached_property

 import numpy as np
 import pandas as pd
+import pyarrow as pa
 from kumoapi.model_plan import MissingType
 from kumoapi.source_table import UnavailableSourceTable
 from kumoapi.table import Column as ColumnDefinition
@@ -12,7 +13,6 @@ from kumoapi.table import TableDefinition
 from kumoapi.typing import Dtype, Stype
 from typing_extensions import Self

-from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.experimental.rfm.base import (
     Column,
     ColumnSpec,
@@ -27,7 +27,7 @@ from kumoai.experimental.rfm.infer import (
     infer_stype,
     infer_time_column,
 )
-from kumoai.utils import quote_ident
+from kumoai.utils import display, quote_ident


 class Table(ABC):
@@ -196,7 +196,7 @@ class Table(ABC):
             raise RuntimeError(
                 f"Encountered unsupported data type '{ser.dtype}' for "
                 f"column '{column_spec.name}' in table '{self.name}'. "
-                f"Please either manually
+                f"Please either manually override the column's data "
                 f"type or remove the column from this table.") from e

         if stype is None:
@@ -272,8 +272,8 @@ class Table(ABC):
         no such primary key is present.

         The setter sets a column as a primary key on this table, and raises a
-        :class:`ValueError` if the primary key has a non-ID
-        if the column name does not match a column in the data frame.
+        :class:`ValueError` if the primary key has a non-ID compatible data
+        type or if the column name does not match a column in the data frame.
         """
         if self._primary_key is None:
             return None
@@ -317,8 +317,9 @@ class Table(ABC):
         such time column is present.

         The setter sets a column as a time column on this table, and raises a
-        :class:`ValueError` if the time column has a non-timestamp
-        type or if the column name does not match a column in the data
+        :class:`ValueError` if the time column has a non-timestamp compatible
+        data type or if the column name does not match a column in the data
+        frame.
         """
         if self._time_column is None:
             return None
@@ -363,8 +364,8 @@ class Table(ABC):

         The setter sets a column as an end time column on this table, and
         raises a :class:`ValueError` if the end time column has a non-timestamp
-
-        frame.
+        compatible data type or if the column name does not match a column in
+        the data frame.
         """
         if self._end_time_column is None:
             return None
@@ -399,39 +400,39 @@ class Table(ABC):
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the columns in this table.

-        The returned dataframe has columns ``
-        ``
-        which provide an
-        this table.
+        The returned dataframe has columns ``"Name"``, ``"Data Type"``,
+        ``"Semantic Type"``, ``"Primary Key"``, ``"Time Column"`` and
+        ``"End Time Column"``, which provide an aggregated view of the
+        properties of the columns of this table.

         Example:
             >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
             >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
             >>> table.metadata
-
-            0  CustomerID   float64
+                     Name Data Type Semantic Type Primary Key Time Column End Time Column
+            0  CustomerID   float64            ID        True       False           False
         """  # noqa: E501
         cols = self.columns

         return pd.DataFrame({
-            '
+            'Name':
            pd.Series(dtype=str, data=[c.name for c in cols]),
-            '
+            'Data Type':
            pd.Series(dtype=str, data=[c.dtype for c in cols]),
-            '
+            'Semantic Type':
            pd.Series(dtype=str, data=[c.stype for c in cols]),
-            '
+            'Primary Key':
            pd.Series(
                dtype=bool,
                data=[self._primary_key == c.name for c in cols],
            ),
-            '
+            'Time Column':
            pd.Series(
                dtype=bool,
                data=[self._time_column == c.name for c in cols],
            ),
-            '
+            'End Time Column':
            pd.Series(
                dtype=bool,
                data=[self._end_time_column == c.name for c in cols],
```
kumoai/experimental/rfm/base/table.py

```diff
@@ -440,30 +441,12 @@ class Table(ABC):

     def print_metadata(self) -> None:
         r"""Prints the :meth:`~metadata` of this table."""
-
+        msg = f"🏷️ Metadata of Table `{self.name}`"
         if num := self._num_rows:
-
-
-
-
-            md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
-            st.markdown(md_repr)
-            st.dataframe(self.metadata, hide_index=True)
-        elif in_notebook():
-            from IPython.display import Markdown, display
-            md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
-            display(Markdown(md_repr))
-            df = self.metadata
-            try:
-                if hasattr(df.style, 'hide'):
-                    display(df.style.hide(axis='index'))  # pandas=2
-                else:
-                    display(df.style.hide_index())  # pandas<1.3
-            except ImportError:
-                print(df.to_string(index=False))  # missing jinja2
-        else:
-            print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
-            print(self.metadata.to_string(index=False))
+            msg += " (1 row)" if num == 1 else f" ({num:,} rows)"
+
+        display.title(msg)
+        display.dataframe(self.metadata)

     def infer_primary_key(self, verbose: bool = True) -> Self:
         r"""Infers the primary key in this table.
```
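The removed Streamlit/IPython/plain-text branches now live behind the new `kumoai/utils/display.py` helper (also added in this release, +87 lines). Only `display.title`, `display.message`, and `display.dataframe` are visible from this diff; the following is a rough sketch of what such a helper could look like, reconstructed purely from the deleted code (the real module may differ, e.g. by also handling Snowflake notebooks):

```python
import pandas as pd


def _in_notebook() -> bool:
    try:
        from IPython import get_ipython
        return get_ipython() is not None
    except ImportError:
        return False


def title(msg: str) -> None:
    # Render as a Markdown heading in notebooks, plain text otherwise:
    if _in_notebook():
        from IPython.display import Markdown, display
        display(Markdown(f"### {msg}"))
    else:
        print(msg)


def message(msg: str) -> None:
    print(msg)


def dataframe(df: pd.DataFrame) -> None:
    # Rich table in notebooks (requires jinja2), plain text otherwise:
    if _in_notebook():
        from IPython.display import display
        display(df.style.hide(axis='index'))  # pandas>=1.4
    else:
        print(df.to_string(index=False))
```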
kumoai/experimental/rfm/base/table.py

```diff
@@ -477,8 +460,8 @@ class Table(ABC):
         def _set_primary_key(primary_key: str) -> None:
             self.primary_key = primary_key
             if verbose:
-
-
+                display.message(f"Inferred primary key `{primary_key}` for "
+                                f"table `{self.name}`")

         # Inference from source column metadata:
         if any(column.is_source for column in self.columns):
@@ -543,8 +526,8 @@ class Table(ABC):
             self.time_column = time_column

             if verbose:
-
-
+                display.message(f"Inferred time column `{time_column}` for "
+                                f"table `{self.name}`")

         return self

@@ -560,15 +543,16 @@ class Table(ABC):
         if not self.has_primary_key():
             self.infer_primary_key(verbose=False)
             if self.has_primary_key():
-                logs.append(f"primary key
+                logs.append(f"primary key `{self._primary_key}`")

         if not self.has_time_column():
             self.infer_time_column(verbose=False)
             if self.has_time_column():
-                logs.append(f"time column
+                logs.append(f"time column `{self._time_column}`")

         if verbose and len(logs) > 0:
-
+            display.message(f"Inferred {' and '.join(logs)} for table "
+                            f"`{self.name}`")

         return self

```
kumoai/experimental/rfm/base/table.py

```diff
@@ -641,14 +625,18 @@ class Table(ABC):
         types match table data and semantic type specification.
         """
         def _to_datetime(ser: pd.Series) -> pd.Series:
-            if not pd.api.types.is_datetime64_any_dtype(ser)
+            if (not pd.api.types.is_datetime64_any_dtype(ser)
+                    and not (isinstance(ser.dtype, pd.ArrowDtype) and
+                             pa.types.is_timestamp(ser.dtype.pyarrow_dtype))):
                 with warnings.catch_warnings():
                     warnings.filterwarnings(
                         'ignore',
                         message='Could not infer format',
                     )
                     ser = pd.to_datetime(ser, errors='coerce')
-            if isinstance(ser.dtype, pd.DatetimeTZDtype)
+            if (isinstance(ser.dtype, pd.DatetimeTZDtype)
+                    or (isinstance(ser.dtype, pd.ArrowDtype)
+                        and ser.dtype.pyarrow_dtype.tz is not None)):
                 ser = ser.dt.tz_localize(None)
             if ser.dtype != 'datetime64[ns]':
                 ser = ser.astype('datetime64[ns]')
```