kumoai 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/jobs.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +138 -28
- kumoai/experimental/rfm/backend/snow/table.py +16 -13
- kumoai/experimental/rfm/backend/sqlite/sampler.py +73 -15
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +23 -1
- kumoai/experimental/rfm/base/sql_sampler.py +252 -11
- kumoai/experimental/rfm/base/table.py +15 -29
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +9 -9
- kumoai/experimental/rfm/infer/dtype.py +3 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/rfm.py +195 -114
- kumoai/experimental/rfm/task_table.py +2 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/utils/display.py +44 -8
- kumoai/utils/progress_logger.py +2 -1
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +25 -23
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202601051732.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/sql_sampler.py
CHANGED
@@ -1,8 +1,10 @@
 from abc import abstractmethod
+from collections import defaultdict
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 import pandas as pd
+from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.base import (
@@ -11,11 +13,14 @@ from kumoai.experimental.rfm.base import (
     SamplerOutput,
     SourceColumn,
 )
+from kumoai.experimental.rfm.base.mapper import Mapper
 from kumoai.utils import ProgressLogger, quote_ident
 
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
+EdgeType = tuple[str, str, str]
+
 
 class SQLSampler(Sampler):
     def __init__(
@@ -100,11 +105,28 @@ class SQLSampler(Sampler):
         num_neighbors: list[int],
     ) -> SamplerOutput:
 
-        #
+        # Make sure to always include primary key, foreign key and time columns
+        # during data fetching since these are needed for graph traversal:
+        sample_columns_dict: dict[str, set[str]] = {}
+        for table, columns in columns_dict.items():
+            sample_columns = columns | {
+                foreign_key
+                for foreign_key, _ in self.foreign_key_dict[table]
+            }
+            if primary_key := self.primary_key_dict.get(table):
+                sample_columns |= {primary_key}
+            sample_columns_dict[table] = sample_columns
+        if not isinstance(anchor_time, pd.Series):
+            sample_columns_dict[entity_table_name] |= {
+                self.time_column_dict[entity_table_name]
+            }
+
+        # Sample Entity Table #################################################
+
         df, batch = self._by_pkey(
             table_name=entity_table_name,
-
-            columns=
+            index=entity_pkey,
+            columns=sample_columns_dict[entity_table_name],
         )
         if len(batch) != len(entity_pkey):
             mask = np.ones(len(entity_pkey), dtype=bool)
@@ -113,23 +135,230 @@ class SQLSampler(Sampler):
                 f"{entity_pkey.iloc[mask].tolist()} do not exist "
                 f"in the '{entity_table_name}' table")
 
+        # Make sure that entities are returned in expected order:
         perm = batch.argsort()
         batch = batch[perm]
         df = df.iloc[perm].reset_index(drop=True)
 
+        # Fill 'entity' anchor times with actual values:
         if not isinstance(anchor_time, pd.Series):
             time_column = self.time_column_dict[entity_table_name]
             anchor_time = df[time_column]
+        assert isinstance(anchor_time, pd.Series)
+
+        # Recursive Neighbor Sampling #########################################
+
+        mapper_dict: dict[str, Mapper] = defaultdict(
+            lambda: Mapper(num_examples=len(entity_pkey)))
+        mapper_dict[entity_table_name].add(
+            pkey=df[self.primary_key_dict[entity_table_name]],
+            batch=batch,
+        )
+
+        dfs_dict: dict[str, list[pd.DataFrame]] = defaultdict(list)
+        dfs_dict[entity_table_name].append(df)
+        batches_dict: dict[str, list[np.ndarray]] = defaultdict(list)
+        batches_dict[entity_table_name].append(batch)
+        num_sampled_nodes_dict: dict[str, list[int]] = defaultdict(
+            lambda: [0] * (len(num_neighbors) + 1))
+        num_sampled_nodes_dict[entity_table_name][0] = len(entity_pkey)
+
+        rows_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        cols_dict: dict[EdgeType, list[np.ndarray]] = defaultdict(list)
+        num_sampled_edges_dict: dict[EdgeType, list[int]] = defaultdict(
+            lambda: [0] * len(num_neighbors))
+
+        # The start index of data frame slices of the previous hop:
+        offset_dict: dict[str, int] = defaultdict(int)
+
+        for hop, neighbors in enumerate(num_neighbors):
+            if neighbors == 0:
+                break  # Abort early.
+
+            for table in list(num_sampled_nodes_dict.keys()):
+                # Only sample from tables that have been visited in the
+                # previous hop:
+                if num_sampled_nodes_dict[table][hop] == 0:
+                    continue
+
+                # Collect the slices of data sampled in the previous hop
+                # (but maintain only required key information):
+                cols = [fkey for fkey, _ in self.foreign_key_dict[table]]
+                if table in self.primary_key_dict:
+                    cols.append(self.primary_key_dict[table])
+                dfs = [df[cols] for df in dfs_dict[table][offset_dict[table]:]]
+                df = pd.concat(
+                    dfs,
+                    axis=0,
+                    ignore_index=True,
+                ) if len(dfs) > 1 else dfs[0]
+                batches = batches_dict[table][offset_dict[table]:]
+                batch = (np.concatenate(batches)
+                         if len(batches) > 1 else batches[0])
+                offset_dict[table] = len(batches_dict[table])  # Increase.
+
+                pkey: pd.Series | None = None
+                index: np.ndarray | None = None
+                if table in self.primary_key_dict:
+                    pkey = df[self.primary_key_dict[table]]
+                    index = mapper_dict[table].get(pkey, batch)
+
+                # Iterate over foreign keys in the current table:
+                for fkey, dst_table in self.foreign_key_dict[table]:
+                    row = mapper_dict[dst_table].get(df[fkey], batch)
+                    mask = row == -1
+                    if mask.any():
+                        key_df = pd.DataFrame({
+                            'fkey': df[fkey],
+                            'batch': batch,
+                        }).iloc[mask]
+                        # Only maintain unique keys per example:
+                        unique_key_df = key_df.drop_duplicates()
+                        # Fully de-duplicate keys across examples:
+                        code, fkey_index = pd.factorize(unique_key_df['fkey'])
+
+                        _df, _batch = self._by_pkey(
+                            table_name=dst_table,
+                            index=fkey_index,
+                            columns=sample_columns_dict[dst_table],
+                        )  # Ensure result is sorted according to input order:
+                        _df = _df.iloc[_batch.argsort()]
+
+                        # Compute valid entries (without dangling foreign keys)
+                        # in `unique_key_df`:
+                        _mask = np.full(len(fkey_index), fill_value=False)
+                        _mask[_batch] = True
+                        _mask = _mask[code]
+
+                        # Reconstruct unique (key, batch) pairs:
+                        code, _ = pd.factorize(unique_key_df['fkey'][_mask])
+                        _df = _df.iloc[code].reset_index(drop=True)
+                        _batch = unique_key_df['batch'].to_numpy()[_mask]
+
+                        # Register node IDs:
+                        mapper_dict[dst_table].add(
+                            pkey=_df[self.primary_key_dict[dst_table]],
+                            batch=_batch,
+                        )
+                        row[mask] = mapper_dict[dst_table].get(
+                            pkey=key_df['fkey'],
+                            batch=key_df['batch'].to_numpy(),
+                        )  # NOTE `row` may still hold `-1` for dangling fkeys.
+
+                        dfs_dict[dst_table].append(_df)
+                        batches_dict[dst_table].append(_batch)
+                        num_sampled_nodes_dict[dst_table][hop + 1] += (  #
+                            len(_batch))
+
+                    mask = row != -1
+
+                    col = index
+                    if col is None:
+                        start = sum(num_sampled_nodes_dict[table][:hop])
+                        end = sum(num_sampled_nodes_dict[table][:hop + 1])
+                        col = np.arange(start, end)
+
+                    row = row[mask]
+                    col = col[mask]
+
+                    edge_type = (table, fkey, dst_table)
+                    edge_type = Subgraph.rev_edge_type(edge_type)
+                    rows_dict[edge_type].append(row)
+                    cols_dict[edge_type].append(col)
+                    num_sampled_edges_dict[edge_type][hop] = len(col)
+
+                # Iterate over foreign keys that reference the current table:
+                for src_table, fkey in self.rev_foreign_key_dict[table]:
+                    assert pkey is not None and index is not None
+                    _df, _batch = self._by_fkey(
+                        table_name=src_table,
+                        foreign_key=fkey,
+                        index=pkey,
+                        num_neighbors=neighbors,
+                        anchor_time=anchor_time.iloc[batch],
+                        columns=sample_columns_dict[src_table],
+                    )
+
+                    edge_type = (src_table, fkey, table)
+                    cols_dict[edge_type].append(index[_batch])
+                    num_sampled_edges_dict[edge_type][hop] = len(_batch)
+
+                    _batch = batch[_batch]
+                    num_nodes = sum(num_sampled_nodes_dict[src_table])
+                    if src_table in self.primary_key_dict:
+                        _pkey = _df[self.primary_key_dict[src_table]]
+                        mapper_dict[src_table].add(_pkey, _batch)
+                        row = mapper_dict[src_table].get(_pkey, _batch)
+
+                        # Only preserve unknown rows:
+                        mask = row >= num_nodes  # type: ignore
+                        mask[pd.Index(row).duplicated()] = False
+                        _df = _df.iloc[mask]
+                        _batch = _batch[mask]
+                    else:
+                        row = np.arange(num_nodes, num_nodes + len(_batch))
+
+                    rows_dict[edge_type].append(row)
+                    num_sampled_nodes_dict[src_table][hop + 1] += len(_batch)
+
+                    dfs_dict[src_table].append(_df)
+                    batches_dict[src_table].append(_batch)
+
+        # Post-Processing #####################################################
+
+        df_dict = {
+            table:
+            pd.concat(dfs, axis=0, ignore_index=True)
+            if len(dfs) > 1 else dfs[0]
+            for table, dfs in dfs_dict.items()
+        }
+
+        # Only store unique rows in `df` above a certain threshold:
+        inverse_dict: dict[str, np.ndarray] = {}
+        for table, df in df_dict.items():
+            if table not in self.primary_key_dict:
+                continue
+            unique, index, inverse = np.unique(
+                df_dict[table][self.primary_key_dict[table]],
+                return_index=True,
+                return_inverse=True,
+            )
+            if len(df) > 1.05 * len(unique):
+                df_dict[table] = df.iloc[index].reset_index(drop=True)
+                inverse_dict[table] = inverse
+
+        df_dict = {  # Post-filter column set:
+            table: df[list(columns_dict[table])]
+            for table, df in df_dict.items()
+        }
+        batch_dict = {
+            table: np.concatenate(batches) if len(batches) > 1 else batches[0]
+            for table, batches in batches_dict.items()
+        }
+        row_dict = {
+            edge_type: np.concatenate(rows)
+            for edge_type, rows in rows_dict.items()
+        }
+        col_dict = {
+            edge_type: np.concatenate(cols)
+            for edge_type, cols in cols_dict.items()
+        }
+
+        if len(num_sampled_edges_dict) == 0:  # Single table:
+            num_sampled_nodes_dict = {
+                key: value[:1]
+                for key, value in num_sampled_nodes_dict.items()
+            }
 
         return SamplerOutput(
             anchor_time=anchor_time.astype(int).to_numpy(),
-            df_dict=
-            inverse_dict=
-            batch_dict=
-            num_sampled_nodes_dict=
-            row_dict=
-            col_dict=
-            num_sampled_edges_dict=
+            df_dict=df_dict,
+            inverse_dict=inverse_dict,
+            batch_dict=batch_dict,
+            num_sampled_nodes_dict=num_sampled_nodes_dict,
+            row_dict=row_dict,
+            col_dict=col_dict,
+            num_sampled_edges_dict=num_sampled_edges_dict,
         )
 
     # Abstract Methods ########################################################
@@ -138,7 +367,19 @@ class SQLSampler(Sampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pass
+
+    @abstractmethod
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
         pass
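Note: the sampler above relies on the new Mapper helper from kumoai/experimental/rfm/base/mapper.py (+69 lines, not rendered in this diff). Judging purely from its call sites in sql_sampler.py, a Mapper hands out consecutive node indices for unique (pkey, batch) pairs and returns -1 for pairs it has not seen. The sketch below is a hypothetical reconstruction of that interface for orientation only; names and behavior are inferred from usage, not taken from the actual implementation.

import numpy as np
import pandas as pd


class Mapper:
    def __init__(self, num_examples: int) -> None:
        self.num_examples = num_examples  # Number of seed entities.
        self._index: dict[tuple, int] = {}  # Maps (pkey, batch) -> node index.

    def add(self, pkey: pd.Series, batch: np.ndarray) -> None:
        # Register previously unseen (pkey, batch) pairs under consecutive IDs:
        for key in zip(pkey.tolist(), batch.tolist()):
            self._index.setdefault(key, len(self._index))

    def get(self, pkey: pd.Series, batch: np.ndarray) -> np.ndarray:
        # Look up node indices; -1 marks pairs that were never registered:
        return np.array([
            self._index.get(key, -1)
            for key in zip(pkey.tolist(), batch.tolist())
        ], dtype=np.int64)

Under this reading, row == -1 in the sampler marks foreign keys whose referenced rows still need to be fetched, and indices at or above an earlier node count identify rows discovered in the current hop.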
kumoai/experimental/rfm/base/table.py
CHANGED
@@ -1,4 +1,3 @@
-import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import cached_property
@@ -20,6 +19,7 @@ from kumoai.experimental.rfm.base import (
     SourceColumn,
     SourceForeignKey,
 )
+from kumoai.experimental.rfm.base.utils import to_datetime
 from kumoai.experimental.rfm.infer import (
     infer_dtype,
     infer_primary_key,
@@ -399,39 +399,39 @@ class Table(ABC):
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the columns in this table.
 
-        The returned dataframe has columns ``
-        ``
-        which provide an
-        this table.
+        The returned dataframe has columns ``"Name"``, ``"Data Type"``,
+        ``"Semantic Type"``, ``"Primary Key"``, ``"Time Column"`` and
+        ``"End Time Column"``, which provide an aggregated view of the
+        properties of the columns of this table.
 
         Example:
             >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
             >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
             >>> table.metadata
-
-            0  CustomerID  float64
+                     Name Data Type Semantic Type  Primary Key  Time Column  End Time Column
+            0  CustomerID   float64            ID         True        False            False
         """  # noqa: E501
         cols = self.columns
 
         return pd.DataFrame({
-            '
+            'Name':
            pd.Series(dtype=str, data=[c.name for c in cols]),
-            '
+            'Data Type':
            pd.Series(dtype=str, data=[c.dtype for c in cols]),
-            '
+            'Semantic Type':
            pd.Series(dtype=str, data=[c.stype for c in cols]),
-            '
+            'Primary Key':
            pd.Series(
                dtype=bool,
                data=[self._primary_key == c.name for c in cols],
            ),
-            '
+            'Time Column':
            pd.Series(
                dtype=bool,
                data=[self._time_column == c.name for c in cols],
            ),
-            '
+            'End Time Column':
            pd.Series(
                dtype=bool,
                data=[self._end_time_column == c.name for c in cols],
@@ -623,20 +623,6 @@ class Table(ABC):
r"""Sanitzes a :class:`pandas.DataFrame` in-place such that its data
         types match table data and semantic type specification.
         """
-        def _to_datetime(ser: pd.Series) -> pd.Series:
-            if not pd.api.types.is_datetime64_any_dtype(ser):
-                with warnings.catch_warnings():
-                    warnings.filterwarnings(
-                        'ignore',
-                        message='Could not infer format',
-                    )
-                    ser = pd.to_datetime(ser, errors='coerce')
-            if isinstance(ser.dtype, pd.DatetimeTZDtype):
-                ser = ser.dt.tz_localize(None)
-            if ser.dtype != 'datetime64[ns]':
-                ser = ser.astype('datetime64[ns]')
-            return ser
-
         def _to_list(ser: pd.Series, dtype: Dtype | None) -> pd.Series:
             if (pd.api.types.is_string_dtype(ser)
                     and dtype in {Dtype.intlist, Dtype.floatlist}):
@@ -667,9 +653,9 @@ class Table(ABC):
             stype = (stype_dict or {}).get(column_name)
 
             if dtype == Dtype.time:
-                df[column_name] = _to_datetime(df[column_name])
+                df[column_name] = to_datetime(df[column_name])
             elif stype == Stype.timestamp:
-                df[column_name] = _to_datetime(df[column_name])
+                df[column_name] = to_datetime(df[column_name])
             elif dtype is not None and dtype.is_list():
                 df[column_name] = _to_list(df[column_name], dtype)
             elif stype == Stype.sequence:
kumoai/experimental/rfm/base/utils.py
ADDED
@@ -0,0 +1,36 @@
+import warnings
+
+import pandas as pd
+import pyarrow as pa
+
+
+def is_datetime(ser: pd.Series) -> bool:
+    r"""Check whether a :class:`pandas.Series` holds datetime values."""
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        dtype = ser.dtype.pyarrow_dtype
+        return (pa.types.is_timestamp(dtype) or pa.types.is_date(dtype)
+                or pa.types.is_time(dtype))
+
+    return pd.api.types.is_datetime64_any_dtype(ser)
+
+
+def to_datetime(ser: pd.Series) -> pd.Series:
+    """Converts a :class:`pandas.Series` to ``datetime64[ns]`` format."""
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        ser = pd.Series(ser.to_numpy(), index=ser.index, name=ser.name)
+
+    if not pd.api.types.is_datetime64_any_dtype(ser):
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                'ignore',
+                message='Could not infer format',
+            )
+            ser = pd.to_datetime(ser, errors='coerce')
+
+    if isinstance(ser.dtype, pd.DatetimeTZDtype):
+        ser = ser.dt.tz_localize(None)
+
+    if ser.dtype != 'datetime64[ns]':
+        ser = ser.astype('datetime64[ns]')
+
+    return ser
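For orientation, a minimal usage sketch of the new helpers follows (illustrative values only; it assumes nothing beyond the import path introduced above):

import pandas as pd

from kumoai.experimental.rfm.base.utils import is_datetime, to_datetime

ser = pd.Series(['2024-01-01', '2024-06-15', 'not a date'])
out = to_datetime(ser)  # Unparsable entries are coerced to NaT.
assert out.dtype == 'datetime64[ns]'
assert is_datetime(out)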
kumoai/experimental/rfm/graph.py
CHANGED
@@ -649,28 +649,28 @@ class Graph:
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the tables in this graph.
 
-        The returned dataframe has columns ``
-        ``
-        view of the properties of the tables of this graph.
+        The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
+        ``"Time Column"``, and ``"End Time Column"``, which provide an
+        aggregated view of the properties of the tables of this graph.
 
         Example:
             >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
             >>> graph = rfm.Graph(tables=...).infer_metadata()
             >>> graph.metadata  # doctest: +SKIP
-
-            0  users
+                Name Primary Key Time Column End Time Column
+            0  users     user_id           -               -
         """
         tables = list(self.tables.values())
 
         return pd.DataFrame({
-            '
+            'Name':
            pd.Series(dtype=str, data=[t.name for t in tables]),
-            '
+            'Primary Key':
            pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
-            '
+            'Time Column':
            pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
-            '
+            'End Time Column':
            pd.Series(
                dtype=str,
                data=[t._end_time_column or '-' for t in tables],
kumoai/experimental/rfm/infer/dtype.py
CHANGED
@@ -3,6 +3,8 @@ import pandas as pd
 import pyarrow as pa
 from kumoapi.typing import Dtype
 
+from kumoai.experimental.rfm.base.utils import is_datetime
+
 PANDAS_TO_DTYPE: dict[str, Dtype] = {
     'bool': Dtype.bool,
     'boolean': Dtype.bool,
@@ -34,7 +36,7 @@ def infer_dtype(ser: pd.Series) -> Dtype:
     Returns:
         The data type.
     """
-    if
+    if is_datetime(ser):
         return Dtype.date
     if pd.api.types.is_timedelta64_dtype(ser.dtype):
         return Dtype.timedelta
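The deleted check is truncated in this rendering, but the switch to is_datetime also covers Arrow-backed date and time columns, which pd.api.types.is_datetime64_any_dtype does not recognize for date types. A minimal illustration, assuming only the import path shown in this diff:

import datetime

import pandas as pd
import pyarrow as pa

from kumoai.experimental.rfm.base.utils import is_datetime

ser = pd.Series([datetime.date(2024, 1, 1)], dtype=pd.ArrowDtype(pa.date32()))
assert is_datetime(ser)  # Matched via the pyarrow branch.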
kumoai/experimental/rfm/infer/time_col.py
CHANGED
@@ -3,6 +3,8 @@ import warnings
 
 import pandas as pd
 
+from kumoai.experimental.rfm.base.utils import to_datetime
+
 
 def infer_time_column(
     df: pd.DataFrame,
@@ -43,11 +45,11 @@ def infer_time_column(
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', message='Could not infer format')
         min_timestamp_dict = {
-            key:
+            key: to_datetime(df[key].iloc[:10_000])
             for key in candidates
         }
         min_timestamp_dict = {
-            key: value.min()
+            key: value.min()
             for key, value in min_timestamp_dict.items()
         }
         min_timestamp_dict = {