kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +130 -110
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  31. kumoai/pquery/training_table.py +16 -2
  32. kumoai/testing/snow.py +3 -3
  33. kumoai/trainer/distilled_trainer.py +175 -0
  34. kumoai/utils/display.py +87 -0
  35. kumoai/utils/progress_logger.py +15 -2
  36. kumoai/utils/sql.py +2 -2
  37. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
  38. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
  39. kumoai/experimental/rfm/base/column_expression.py +0 -50
  40. kumoai/experimental/rfm/base/sql_table.py +0 -229
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
@@ -2,21 +2,31 @@ from abc import ABC, abstractmethod
2
2
  from collections.abc import Sequence
3
3
  from functools import cached_property
4
4
 
5
+ import numpy as np
5
6
  import pandas as pd
6
7
  from kumoapi.model_plan import MissingType
7
8
  from kumoapi.source_table import UnavailableSourceTable
8
9
  from kumoapi.table import Column as ColumnDefinition
9
10
  from kumoapi.table import TableDefinition
10
- from kumoapi.typing import Stype
11
+ from kumoapi.typing import Dtype, Stype
11
12
  from typing_extensions import Self
12
13
 
13
- from kumoai import in_notebook, in_snowflake_notebook
14
- from kumoai.experimental.rfm.base import Column, DataBackend, SourceColumn
14
+ from kumoai.experimental.rfm.base import (
15
+ Column,
16
+ ColumnSpec,
17
+ ColumnSpecType,
18
+ DataBackend,
19
+ SourceColumn,
20
+ SourceForeignKey,
21
+ )
22
+ from kumoai.experimental.rfm.base.utils import to_datetime
15
23
  from kumoai.experimental.rfm.infer import (
24
+ infer_dtype,
16
25
  infer_primary_key,
17
26
  infer_stype,
18
27
  infer_time_column,
19
28
  )
29
+ from kumoai.utils import display, quote_ident
20
30
 
21
31
 
22
32
  class Table(ABC):
@@ -26,39 +36,46 @@ class Table(ABC):
26
36
 
27
37
  Args:
28
38
  name: The name of this table.
39
+ source_name: The source name of this table. If set to ``None``,
40
+ ``name`` is being used.
29
41
  columns: The selected columns of this table.
30
42
  primary_key: The name of the primary key of this table, if it exists.
31
43
  time_column: The name of the time column of this table, if it exists.
32
44
  end_time_column: The name of the end time column of this table, if it
33
45
  exists.
34
46
  """
47
+ _NUM_SAMPLE_ROWS = 1_000
48
+
35
49
  def __init__(
36
50
  self,
37
51
  name: str,
38
- columns: Sequence[str] | None = None,
52
+ source_name: str | None = None,
53
+ columns: Sequence[ColumnSpecType] | None = None,
39
54
  primary_key: MissingType | str | None = MissingType.VALUE,
40
55
  time_column: str | None = None,
41
56
  end_time_column: str | None = None,
42
57
  ) -> None:
43
58
 
44
59
  self._name = name
60
+ self._source_name = source_name or name
61
+ self._column_dict: dict[str, Column] = {}
45
62
  self._primary_key: str | None = None
46
63
  self._time_column: str | None = None
47
64
  self._end_time_column: str | None = None
65
+ self._expr_sample_df = pd.DataFrame(index=range(self._NUM_SAMPLE_ROWS))
48
66
 
49
67
  if columns is None:
50
68
  columns = list(self._source_column_dict.keys())
51
69
 
52
- self._columns: dict[str, Column] = {}
53
- for column_name in columns:
54
- self.add_column(column_name)
70
+ self.add_columns(columns)
55
71
 
56
72
  if isinstance(primary_key, MissingType):
57
- # Inference from source column metadata:
58
- if '_source_column_dict' in self.__dict__:
73
+ # Infer primary key from source metadata, but only set it in case
74
+ # it is already part of the column set (don't magically add it):
75
+ if any(column.is_source for column in self.columns):
59
76
  primary_key = self._source_primary_key
60
77
  if (primary_key is not None and primary_key in self
61
- and self[primary_key].is_physical):
78
+ and self[primary_key].is_source):
62
79
  self.primary_key = primary_key
63
80
  elif primary_key is not None:
64
81
  if primary_key not in self:
@@ -80,13 +97,22 @@ class Table(ABC):
80
97
  r"""The name of this table."""
81
98
  return self._name
82
99
 
100
+ @property
101
+ def source_name(self) -> str:
102
+ r"""The source name of this table."""
103
+ return self._source_name
104
+
105
+ @property
106
+ def _quoted_source_name(self) -> str:
107
+ return quote_ident(self._source_name)
108
+
83
109
  # Column ##################################################################
84
110
 
85
111
  def has_column(self, name: str) -> bool:
86
112
  r"""Returns ``True`` if this table holds a column with name ``name``;
87
113
  ``False`` otherwise.
88
114
  """
89
- return name in self._columns
115
+ return name in self._column_dict
90
116
 
91
117
  def column(self, name: str) -> Column:
92
118
  r"""Returns the data column named with name ``name`` in this table.
@@ -99,51 +125,113 @@ class Table(ABC):
99
125
  """
100
126
  if not self.has_column(name):
101
127
  raise KeyError(f"Column '{name}' not found in table '{self.name}'")
102
- return self._columns[name]
128
+ return self._column_dict[name]
103
129
 
104
130
  @property
105
131
  def columns(self) -> list[Column]:
106
132
  r"""Returns a list of :class:`Column` objects that represent the
107
133
  columns in this table.
108
134
  """
109
- return list(self._columns.values())
135
+ return list(self._column_dict.values())
110
136
 
111
- def add_column(self, name: str) -> Column:
112
- r"""Adds a column to this table.
137
+ def add_columns(self, columns: Sequence[ColumnSpecType]) -> None:
138
+ r"""Adds a set of columns to this table.
113
139
 
114
140
  Args:
115
- name: The name of the column.
141
+ columns: The columns to add.
116
142
 
117
143
  Raises:
118
- KeyError: If ``name`` is already present in this table.
144
+ KeyError: If any of the column names already exist in this table.
119
145
  """
120
- if name in self:
121
- raise KeyError(f"Column '{name}' already exists in table "
122
- f"'{self.name}'")
123
-
124
- if name not in self._source_column_dict:
125
- raise KeyError(f"Column '{name}' does not exist in the underlying "
126
- f"source table")
127
-
128
- dtype = self._source_column_dict[name].dtype
129
-
130
- ser = self._source_sample_df[name]
131
- try:
132
- stype = infer_stype(ser, name, dtype)
133
- except Exception as e:
134
- raise RuntimeError(f"Could not obtain semantic type for column "
135
- f"'{name}' with data type '{dtype}' in table "
136
- f"'{self.name}'. Change the data type of the "
137
- f"column in the source table or remove it from "
138
- f"this table.") from e
139
-
140
- self._columns[name] = Column(
141
- name=name,
142
- stype=stype,
143
- dtype=dtype,
144
- )
146
+ if len(columns) == 0:
147
+ return
145
148
 
146
- return self._columns[name]
149
+ column_specs = [ColumnSpec.coerce(column) for column in columns]
150
+
151
+ # Obtain a batch-wise sample for all column expressions:
152
+ expr_specs = [spec for spec in column_specs if not spec.is_source]
153
+ if len(expr_specs) > 0:
154
+ dfs = [
155
+ self._expr_sample_df,
156
+ self._get_expr_sample_df(expr_specs).reset_index(drop=True),
157
+ ]
158
+ size = min(map(len, dfs))
159
+ df = pd.concat([dfs[0].iloc[:size], dfs[1].iloc[:size]], axis=1)
160
+ df = df.loc[:, ~df.columns.duplicated(keep='last')]
161
+ self._expr_sample_df = df
162
+
163
+ for column_spec in column_specs:
164
+ if column_spec.name in self:
165
+ raise KeyError(f"Column '{column_spec.name}' already exists "
166
+ f"in table '{self.name}'")
167
+
168
+ dtype = column_spec.dtype
169
+ stype = column_spec.stype
170
+
171
+ if column_spec.is_source:
172
+ if column_spec.name not in self._source_column_dict:
173
+ raise ValueError(
174
+ f"Column '{column_spec.name}' does not exist in the "
175
+ f"underlying source table")
176
+
177
+ if dtype is None:
178
+ dtype = self._source_column_dict[column_spec.name].dtype
179
+
180
+ if dtype == Dtype.unsupported:
181
+ raise ValueError(
182
+ f"Encountered unsupported data type for column "
183
+ f"'{column_spec.name}' in table '{self.name}'. Please "
184
+ f"either change the column's data type or remove the "
185
+ f"column from this table.")
186
+
187
+ if dtype is None:
188
+ if column_spec.is_source:
189
+ ser = self._source_sample_df[column_spec.name]
190
+ else:
191
+ ser = self._expr_sample_df[column_spec.name]
192
+ try:
193
+ dtype = infer_dtype(ser)
194
+ except Exception as e:
195
+ raise RuntimeError(
196
+ f"Encountered unsupported data type '{ser.dtype}' for "
197
+ f"column '{column_spec.name}' in table '{self.name}'. "
198
+ f"Please either manually override the columns's data "
199
+ f"type or remove the column from this table.") from e
200
+
201
+ if stype is None:
202
+ if column_spec.is_source:
203
+ ser = self._source_sample_df[column_spec.name]
204
+ else:
205
+ ser = self._expr_sample_df[column_spec.name]
206
+ try:
207
+ stype = infer_stype(ser, column_spec.name, dtype)
208
+ except Exception as e:
209
+ raise RuntimeError(
210
+ f"Could not determine semantic type for column "
211
+ f"'{column_spec.name}' with data type '{dtype}' in "
212
+ f"table '{self.name}'. Please either change the "
213
+ f"column's data type or remove the column from this "
214
+ f"table.") from e
215
+
216
+ self._column_dict[column_spec.name] = Column(
217
+ name=column_spec.name,
218
+ expr=column_spec.expr,
219
+ dtype=dtype,
220
+ stype=stype,
221
+ )
222
+
223
+ def add_column(self, column: ColumnSpecType) -> Column:
224
+ r"""Adds a column to this table.
225
+
226
+ Args:
227
+ column: The column to add.
228
+
229
+ Raises:
230
+ KeyError: If the column name already exists in this table.
231
+ """
232
+ column_spec = ColumnSpec.coerce(column)
233
+ self.add_columns([column_spec])
234
+ return self[column_spec.name]
147
235
 
148
236
  def remove_column(self, name: str) -> Self:
149
237
  r"""Removes a column from this table.
@@ -163,7 +251,7 @@ class Table(ABC):
163
251
  self.time_column = None
164
252
  if self._end_time_column == name:
165
253
  self.end_time_column = None
166
- del self._columns[name]
254
+ del self._column_dict[name]
167
255
 
168
256
  return self
169
257
 
@@ -183,8 +271,8 @@ class Table(ABC):
183
271
  no such primary key is present.
184
272
 
185
273
  The setter sets a column as a primary key on this table, and raises a
186
- :class:`ValueError` if the primary key has a non-ID semantic type or
187
- if the column name does not match a column in the data frame.
274
+ :class:`ValueError` if the primary key has a non-ID compatible data
275
+ type or if the column name does not match a column in the data frame.
188
276
  """
189
277
  if self._primary_key is None:
190
278
  return None
@@ -228,8 +316,9 @@ class Table(ABC):
228
316
  such time column is present.
229
317
 
230
318
  The setter sets a column as a time column on this table, and raises a
231
- :class:`ValueError` if the time column has a non-timestamp semantic
232
- type or if the column name does not match a column in the data frame.
319
+ :class:`ValueError` if the time column has a non-timestamp compatible
320
+ data type or if the column name does not match a column in the data
321
+ frame.
233
322
  """
234
323
  if self._time_column is None:
235
324
  return None
@@ -274,8 +363,8 @@ class Table(ABC):
274
363
 
275
364
  The setter sets a column as an end time column on this table, and
276
365
  raises a :class:`ValueError` if the end time column has a non-timestamp
277
- semantic type or if the column name does not match a column in the data
278
- frame.
366
+ compatible data type or if the column name does not match a column in
367
+ the data frame.
279
368
  """
280
369
  if self._end_time_column is None:
281
370
  return None
@@ -310,39 +399,39 @@ class Table(ABC):
310
399
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
311
400
  information about the columns in this table.
312
401
 
313
- The returned dataframe has columns ``name``, ``dtype``, ``stype``,
314
- ``is_primary_key``, ``is_time_column`` and ``is_end_time_column``,
315
- which provide an aggregate view of the properties of the columns of
316
- this table.
402
+ The returned dataframe has columns ``"Name"``, ``"Data Type"``,
403
+ ``"Semantic Type"``, ``"Primary Key"``, ``"Time Column"`` and
404
+ ``"End Time Column"``, which provide an aggregated view of the
405
+ properties of the columns of this table.
317
406
 
318
407
  Example:
319
408
  >>> # doctest: +SKIP
320
409
  >>> import kumoai.experimental.rfm as rfm
321
410
  >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
322
411
  >>> table.metadata
323
- name dtype stype is_primary_key is_time_column is_end_time_column
324
- 0 CustomerID float64 ID True False False
412
+ Name Data Type Semantic Type Primary Key Time Column End Time Column
413
+ 0 CustomerID float64 ID True False False
325
414
  """ # noqa: E501
326
415
  cols = self.columns
327
416
 
328
417
  return pd.DataFrame({
329
- 'name':
418
+ 'Name':
330
419
  pd.Series(dtype=str, data=[c.name for c in cols]),
331
- 'dtype':
420
+ 'Data Type':
332
421
  pd.Series(dtype=str, data=[c.dtype for c in cols]),
333
- 'stype':
422
+ 'Semantic Type':
334
423
  pd.Series(dtype=str, data=[c.stype for c in cols]),
335
- 'is_primary_key':
424
+ 'Primary Key':
336
425
  pd.Series(
337
426
  dtype=bool,
338
427
  data=[self._primary_key == c.name for c in cols],
339
428
  ),
340
- 'is_time_column':
429
+ 'Time Column':
341
430
  pd.Series(
342
431
  dtype=bool,
343
432
  data=[self._time_column == c.name for c in cols],
344
433
  ),
345
- 'is_end_time_column':
434
+ 'End Time Column':
346
435
  pd.Series(
347
436
  dtype=bool,
348
437
  data=[self._end_time_column == c.name for c in cols],
@@ -351,30 +440,12 @@ class Table(ABC):
351
440
 
352
441
  def print_metadata(self) -> None:
353
442
  r"""Prints the :meth:`~metadata` of this table."""
354
- num_rows_repr = ''
355
- if self._num_rows is not None:
356
- num_rows_repr = ' ({self._num_rows:,} rows)'
357
-
358
- if in_snowflake_notebook():
359
- import streamlit as st
360
- md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
361
- st.markdown(md_repr)
362
- st.dataframe(self.metadata, hide_index=True)
363
- elif in_notebook():
364
- from IPython.display import Markdown, display
365
- md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
366
- display(Markdown(md_repr))
367
- df = self.metadata
368
- try:
369
- if hasattr(df.style, 'hide'):
370
- display(df.style.hide(axis='index')) # pandas=2
371
- else:
372
- display(df.style.hide_index()) # pandas<1.3
373
- except ImportError:
374
- print(df.to_string(index=False)) # missing jinja2
375
- else:
376
- print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
377
- print(self.metadata.to_string(index=False))
443
+ msg = f"🏷️ Metadata of Table `{self.name}`"
444
+ if num := self._num_rows:
445
+ msg += " (1 row)" if num == 1 else f" ({num:,} rows)"
446
+
447
+ display.title(msg)
448
+ display.dataframe(self.metadata)
378
449
 
379
450
  def infer_primary_key(self, verbose: bool = True) -> Self:
380
451
  r"""Infers the primary key in this table.
@@ -388,14 +459,14 @@ class Table(ABC):
388
459
  def _set_primary_key(primary_key: str) -> None:
389
460
  self.primary_key = primary_key
390
461
  if verbose:
391
- print(f"Detected primary key '{primary_key}' in table "
392
- f"'{self.name}'")
462
+ display.message(f"Inferred primary key `{primary_key}` for "
463
+ f"table `{self.name}`")
393
464
 
394
465
  # Inference from source column metadata:
395
- if '_source_column_dict' in self.__dict__:
466
+ if any(column.is_source for column in self.columns):
396
467
  primary_key = self._source_primary_key
397
468
  if (primary_key is not None and primary_key in self
398
- and self[primary_key].is_physical):
469
+ and self[primary_key].is_source):
399
470
  _set_primary_key(primary_key)
400
471
  return self
401
472
 
@@ -405,7 +476,7 @@ class Table(ABC):
405
476
  ]
406
477
  if (len(unique_keys) == 1 # NOTE No composite keys yet.
407
478
  and unique_keys[0] in self
408
- and self[unique_keys[0]].is_physical):
479
+ and self[unique_keys[0]].is_source):
409
480
  _set_primary_key(unique_keys[0])
410
481
  return self
411
482
 
@@ -423,7 +494,7 @@ class Table(ABC):
423
494
 
424
495
  if primary_key := infer_primary_key(
425
496
  table_name=self.name,
426
- df=self._sample_current_df(columns=candidates),
497
+ df=self._get_sample_df(),
427
498
  candidates=candidates,
428
499
  ):
429
500
  _set_primary_key(primary_key)
@@ -448,14 +519,14 @@ class Table(ABC):
448
519
  ]
449
520
 
450
521
  if time_column := infer_time_column(
451
- df=self._sample_current_df(columns=candidates),
522
+ df=self._get_sample_df(),
452
523
  candidates=candidates,
453
524
  ):
454
525
  self.time_column = time_column
455
526
 
456
527
  if verbose:
457
- print(f"Detected time column '{time_column}' in table "
458
- f"'{self.name}'")
528
+ display.message(f"Inferred time column `{time_column}` for "
529
+ f"table `{self.name}`")
459
530
 
460
531
  return self
461
532
 
@@ -471,15 +542,16 @@ class Table(ABC):
471
542
  if not self.has_primary_key():
472
543
  self.infer_primary_key(verbose=False)
473
544
  if self.has_primary_key():
474
- logs.append(f"primary key '{self._primary_key}'")
545
+ logs.append(f"primary key `{self._primary_key}`")
475
546
 
476
547
  if not self.has_time_column():
477
548
  self.infer_time_column(verbose=False)
478
549
  if self.has_time_column():
479
- logs.append(f"time column '{self._time_column}'")
550
+ logs.append(f"time column `{self._time_column}`")
480
551
 
481
552
  if verbose and len(logs) > 0:
482
- print(f"Detected {' and '.join(logs)} in table '{self.name}'")
553
+ display.message(f"Inferred {' and '.join(logs)} for table "
554
+ f"`{self.name}`")
483
555
 
484
556
  return self
485
557
 
@@ -501,31 +573,95 @@ class Table(ABC):
501
573
  def _source_column_dict(self) -> dict[str, SourceColumn]:
502
574
  source_columns = self._get_source_columns()
503
575
  if len(source_columns) == 0:
504
- raise ValueError(f"Table '{self.name}' does not hold any column "
505
- f"with a supported data type")
576
+ raise ValueError(f"Table '{self.name}' has no columns")
506
577
  return {column.name: column for column in source_columns}
507
578
 
508
579
  @cached_property
509
- def _source_sample_df(self) -> pd.DataFrame:
510
- return self._get_source_sample_df()
511
-
512
- @property
513
580
  def _source_primary_key(self) -> str | None:
514
581
  primary_keys = [
515
582
  column.name for column in self._source_column_dict.values()
516
583
  if column.is_primary_key
517
584
  ]
518
- if len(primary_keys) == 1: # NOTE No composite keys yet.
519
- return primary_keys[0]
585
+ # NOTE No composite keys yet.
586
+ return primary_keys[0] if len(primary_keys) == 1 else None
587
+
588
+ @cached_property
589
+ def _source_foreign_key_dict(self) -> dict[str, SourceForeignKey]:
590
+ return {key.name: key for key in self._get_source_foreign_keys()}
520
591
 
521
- return None
592
+ @cached_property
593
+ def _source_sample_df(self) -> pd.DataFrame:
594
+ return self._get_source_sample_df().reset_index(drop=True)
522
595
 
523
596
  @cached_property
524
597
  def _num_rows(self) -> int | None:
525
598
  return self._get_num_rows()
526
599
 
527
- def _sample_current_df(self, columns: Sequence[str]) -> pd.DataFrame:
528
- return self._source_sample_df[columns]
600
+ def _get_sample_df(self) -> pd.DataFrame:
601
+ dfs: list[pd.DataFrame] = []
602
+ if any(column.is_source for column in self.columns):
603
+ dfs.append(self._source_sample_df)
604
+ if any(not column.is_source for column in self.columns):
605
+ dfs.append(self._expr_sample_df)
606
+
607
+ if len(dfs) == 0:
608
+ return pd.DataFrame(index=range(1000))
609
+ if len(dfs) == 1:
610
+ return dfs[0]
611
+
612
+ size = min(map(len, dfs))
613
+ df = pd.concat([dfs[0].iloc[:size], dfs[1].iloc[:size]], axis=1)
614
+ df = df.loc[:, ~df.columns.duplicated(keep='last')]
615
+ return df
616
+
617
+ @staticmethod
618
+ def _sanitize(
619
+ df: pd.DataFrame,
620
+ dtype_dict: dict[str, Dtype | None] | None = None,
621
+ stype_dict: dict[str, Stype | None] | None = None,
622
+ ) -> pd.DataFrame:
623
+ r"""Sanitzes a :class:`pandas.DataFrame` in-place such that its data
624
+ types match table data and semantic type specification.
625
+ """
626
+ def _to_list(ser: pd.Series, dtype: Dtype | None) -> pd.Series:
627
+ if (pd.api.types.is_string_dtype(ser)
628
+ and dtype in {Dtype.intlist, Dtype.floatlist}):
629
+ try:
630
+ ser = ser.map(lambda row: np.fromstring(
631
+ row.strip('[]'),
632
+ sep=',',
633
+ dtype=int if dtype == Dtype.intlist else np.float32,
634
+ ) if row is not None else None)
635
+ except Exception:
636
+ pass
637
+
638
+ if pd.api.types.is_string_dtype(ser):
639
+ try:
640
+ import orjson as json
641
+ except ImportError:
642
+ import json
643
+ try:
644
+ ser = ser.map(lambda row: json.loads(row)
645
+ if row is not None else None)
646
+ except Exception:
647
+ pass
648
+
649
+ return ser
650
+
651
+ for column_name in df.columns:
652
+ dtype = (dtype_dict or {}).get(column_name)
653
+ stype = (stype_dict or {}).get(column_name)
654
+
655
+ if dtype == Dtype.time:
656
+ df[column_name] = to_datetime(df[column_name])
657
+ elif stype == Stype.timestamp:
658
+ df[column_name] = to_datetime(df[column_name])
659
+ elif dtype is not None and dtype.is_list():
660
+ df[column_name] = _to_list(df[column_name], dtype)
661
+ elif stype == Stype.sequence:
662
+ df[column_name] = _to_list(df[column_name], Dtype.floatlist)
663
+
664
+ return df
529
665
 
530
666
  # Python builtins #########################################################
531
667
 
@@ -566,10 +702,21 @@ class Table(ABC):
566
702
  def _get_source_columns(self) -> list[SourceColumn]:
567
703
  pass
568
704
 
705
+ @abstractmethod
706
+ def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
707
+ pass
708
+
569
709
  @abstractmethod
570
710
  def _get_source_sample_df(self) -> pd.DataFrame:
571
711
  pass
572
712
 
713
+ @abstractmethod
714
+ def _get_expr_sample_df(
715
+ self,
716
+ columns: Sequence[ColumnSpec],
717
+ ) -> pd.DataFrame:
718
+ pass
719
+
573
720
  @abstractmethod
574
721
  def _get_num_rows(self) -> int | None:
575
722
  pass
@@ -0,0 +1,36 @@
1
+ import warnings
2
+
3
+ import pandas as pd
4
+ import pyarrow as pa
5
+
6
+
7
+ def is_datetime(ser: pd.Series) -> bool:
8
+ r"""Check whether a :class:`pandas.Series` holds datetime values."""
9
+ if isinstance(ser.dtype, pd.ArrowDtype):
10
+ dtype = ser.dtype.pyarrow_dtype
11
+ return (pa.types.is_timestamp(dtype) or pa.types.is_date(dtype)
12
+ or pa.types.is_time(dtype))
13
+
14
+ return pd.api.types.is_datetime64_any_dtype(ser)
15
+
16
+
17
+ def to_datetime(ser: pd.Series) -> pd.Series:
18
+ """Converts a :class:`pandas.Series` to ``datetime64[ns]`` format."""
19
+ if isinstance(ser.dtype, pd.ArrowDtype):
20
+ ser = pd.Series(ser.to_numpy(), index=ser.index, name=ser.name)
21
+
22
+ if not pd.api.types.is_datetime64_any_dtype(ser):
23
+ with warnings.catch_warnings():
24
+ warnings.filterwarnings(
25
+ 'ignore',
26
+ message='Could not infer format',
27
+ )
28
+ ser = pd.to_datetime(ser, errors='coerce')
29
+
30
+ if isinstance(ser.dtype, pd.DatetimeTZDtype):
31
+ ser = ser.dt.tz_localize(None)
32
+
33
+ if ser.dtype != 'datetime64[ns]':
34
+ ser = ser.astype('datetime64[ns]')
35
+
36
+ return ser