kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
  12. kumoai/experimental/rfm/backend/snow/table.py +159 -52
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
  15. kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
  16. kumoai/experimental/rfm/base/__init__.py +6 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +342 -13
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +27 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +600 -360
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
  45. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
  46. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  47. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  48. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
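The largest single change in this release is the reworked `Table` base class in `kumoai/experimental/rfm/base/table.py` (diff below): columns are now declared via column specifications, metadata inference is split into dedicated `infer_primary_key()` and `infer_time_column()` methods, and output is routed through the new `kumoai.utils.display` helpers. As a rough, hedged sketch of how the revised API reads (the `LocalTable(df=..., name=...)` constructor is taken from the docstring example in the diff; the sample data and exact behavior are illustrative assumptions, not documented usage):

    # Hedged sketch only: assumes a pandas-backed LocalTable as shown in the
    # diff's docstring example; exact signatures may differ per backend.
    import pandas as pd
    import kumoai.experimental.rfm as rfm

    df = pd.DataFrame({
        'CustomerID': [1, 2, 3],
        'CreatedAt': pd.to_datetime(['2024-01-01', '2024-02-01', '2024-03-01']),
    })

    table = rfm.LocalTable(df=df, name='customers')
    table.infer_primary_key()   # new dedicated helper (previously part of infer_metadata)
    table.infer_time_column()   # new dedicated helper (previously part of infer_metadata)
    table.print_metadata()      # renders the renamed "Name"/"Data Type"/... metadata columns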
kumoai/experimental/rfm/base/table.py
@@ -1,30 +1,32 @@
  from abc import ABC, abstractmethod
- from collections import defaultdict
+ from collections.abc import Sequence
  from functools import cached_property
- from typing import Dict, List, Optional, Sequence, Set

+ import numpy as np
  import pandas as pd
+ from kumoapi.model_plan import MissingType
  from kumoapi.source_table import UnavailableSourceTable
  from kumoapi.table import Column as ColumnDefinition
  from kumoapi.table import TableDefinition
- from kumoapi.typing import Stype
+ from kumoapi.typing import Dtype, Stype
  from typing_extensions import Self

- from kumoai import in_notebook, in_snowflake_notebook
  from kumoai.experimental.rfm.base import (
  Column,
+ ColumnSpec,
+ ColumnSpecType,
  DataBackend,
  SourceColumn,
  SourceForeignKey,
  )
+ from kumoai.experimental.rfm.base.utils import to_datetime
  from kumoai.experimental.rfm.infer import (
- contains_categorical,
- contains_id,
- contains_multicategorical,
- contains_timestamp,
+ infer_dtype,
  infer_primary_key,
+ infer_stype,
  infer_time_column,
  )
+ from kumoai.utils import display, quote_ident


  class Table(ABC):
@@ -34,53 +36,48 @@ class Table(ABC):

  Args:
  name: The name of this table.
+ source_name: The source name of this table. If set to ``None``,
+ ``name`` is being used.
  columns: The selected columns of this table.
  primary_key: The name of the primary key of this table, if it exists.
  time_column: The name of the time column of this table, if it exists.
  end_time_column: The name of the end time column of this table, if it
  exists.
  """
+ _NUM_SAMPLE_ROWS = 1_000
+
  def __init__(
  self,
  name: str,
- columns: Optional[Sequence[str]] = None,
- primary_key: Optional[str] = None,
- time_column: Optional[str] = None,
- end_time_column: Optional[str] = None,
+ source_name: str | None = None,
+ columns: Sequence[ColumnSpecType] | None = None,
+ primary_key: MissingType | str | None = MissingType.VALUE,
+ time_column: str | None = None,
+ end_time_column: str | None = None,
  ) -> None:

  self._name = name
- self._primary_key: Optional[str] = None
- self._time_column: Optional[str] = None
- self._end_time_column: Optional[str] = None
-
- if len(self._source_column_dict) == 0:
- raise ValueError(f"Table '{name}' does not hold any column with "
- f"a supported data type")
-
- primary_keys = [
- column.name for column in self._source_column_dict.values()
- if column.is_primary_key
- ]
- if len(primary_keys) == 1: # NOTE No composite keys yet.
- if primary_key is not None and primary_key != primary_keys[0]:
- raise ValueError(f"Found duplicate primary key "
- f"definition '{primary_key}' and "
- f"'{primary_keys[0]}' in table '{name}'")
- primary_key = primary_keys[0]
-
- unique_keys = [
- column.name for column in self._source_column_dict.values()
- if column.is_unique_key
- ]
- if primary_key is None and len(unique_keys) == 1:
- primary_key = unique_keys[0]
-
- self._columns: Dict[str, Column] = {}
- for column_name in columns or list(self._source_column_dict.keys()):
- self.add_column(column_name)
-
- if primary_key is not None:
+ self._source_name = source_name or name
+ self._column_dict: dict[str, Column] = {}
+ self._primary_key: str | None = None
+ self._time_column: str | None = None
+ self._end_time_column: str | None = None
+ self._expr_sample_df = pd.DataFrame(index=range(self._NUM_SAMPLE_ROWS))
+
+ if columns is None:
+ columns = list(self._source_column_dict.keys())
+
+ self.add_columns(columns)
+
+ if isinstance(primary_key, MissingType):
+ # Infer primary key from source metadata, but only set it in case
+ # it is already part of the column set (don't magically add it):
+ if any(column.is_source for column in self.columns):
+ primary_key = self._source_primary_key
+ if (primary_key is not None and primary_key in self
+ and self[primary_key].is_source):
+ self.primary_key = primary_key
+ elif primary_key is not None:
  if primary_key not in self:
  self.add_column(primary_key)
  self.primary_key = primary_key
@@ -100,13 +97,22 @@ class Table(ABC):
  r"""The name of this table."""
  return self._name

- # Data column #############################################################
+ @property
+ def source_name(self) -> str:
+ r"""The source name of this table."""
+ return self._source_name
+
+ @property
+ def _quoted_source_name(self) -> str:
+ return quote_ident(self._source_name)
+
+ # Column ##################################################################

  def has_column(self, name: str) -> bool:
  r"""Returns ``True`` if this table holds a column with name ``name``;
  ``False`` otherwise.
  """
- return name in self._columns
+ return name in self._column_dict

  def column(self, name: str) -> Column:
  r"""Returns the data column named with name ``name`` in this table.
@@ -119,65 +125,113 @@ class Table(ABC):
  """
  if not self.has_column(name):
  raise KeyError(f"Column '{name}' not found in table '{self.name}'")
- return self._columns[name]
+ return self._column_dict[name]

  @property
- def columns(self) -> List[Column]:
+ def columns(self) -> list[Column]:
  r"""Returns a list of :class:`Column` objects that represent the
  columns in this table.
  """
- return list(self._columns.values())
+ return list(self._column_dict.values())

- def add_column(self, name: str) -> Column:
- r"""Adds a column to this table.
+ def add_columns(self, columns: Sequence[ColumnSpecType]) -> None:
+ r"""Adds a set of columns to this table.

  Args:
- name: The name of the column.
+ columns: The columns to add.

  Raises:
- KeyError: If ``name`` is already present in this table.
+ KeyError: If any of the column names already exist in this table.
  """
- if name in self:
- raise KeyError(f"Column '{name}' already exists in table "
- f"'{self.name}'")
-
- if name not in self._source_column_dict:
- raise KeyError(f"Column '{name}' does not exist in the underlying "
- f"source table")
-
- try:
- dtype = self._source_column_dict[name].dtype
- except Exception as e:
- raise RuntimeError(f"Could not obtain data type for column "
- f"'{name}' in table '{self.name}'. Change "
- f"the data type of the column in the source "
- f"table or remove it from the table.") from e
-
- try:
- ser = self._sample_df[name]
- if contains_id(ser, name, dtype):
- stype = Stype.ID
- elif contains_timestamp(ser, name, dtype):
- stype = Stype.timestamp
- elif contains_multicategorical(ser, name, dtype):
- stype = Stype.multicategorical
- elif contains_categorical(ser, name, dtype):
- stype = Stype.categorical
- else:
- stype = dtype.default_stype
- except Exception as e:
- raise RuntimeError(f"Could not obtain semantic type for column "
- f"'{name}' in table '{self.name}'. Change "
- f"the data type of the column in the source "
- f"table or remove it from the table.") from e
-
- self._columns[name] = Column(
- name=name,
- dtype=dtype,
- stype=stype,
- )
+ if len(columns) == 0:
+ return
+
+ column_specs = [ColumnSpec.coerce(column) for column in columns]
+
+ # Obtain a batch-wise sample for all column expressions:
+ expr_specs = [spec for spec in column_specs if not spec.is_source]
+ if len(expr_specs) > 0:
+ dfs = [
+ self._expr_sample_df,
+ self._get_expr_sample_df(expr_specs).reset_index(drop=True),
+ ]
+ size = min(map(len, dfs))
+ df = pd.concat([dfs[0].iloc[:size], dfs[1].iloc[:size]], axis=1)
+ df = df.loc[:, ~df.columns.duplicated(keep='last')]
+ self._expr_sample_df = df
+
+ for column_spec in column_specs:
+ if column_spec.name in self:
+ raise KeyError(f"Column '{column_spec.name}' already exists "
+ f"in table '{self.name}'")
+
+ dtype = column_spec.dtype
+ stype = column_spec.stype
+
+ if column_spec.is_source:
+ if column_spec.name not in self._source_column_dict:
+ raise ValueError(
+ f"Column '{column_spec.name}' does not exist in the "
+ f"underlying source table")
+
+ if dtype is None:
+ dtype = self._source_column_dict[column_spec.name].dtype
+
+ if dtype == Dtype.unsupported:
+ raise ValueError(
+ f"Encountered unsupported data type for column "
+ f"'{column_spec.name}' in table '{self.name}'. Please "
+ f"either change the column's data type or remove the "
+ f"column from this table.")
+
+ if dtype is None:
+ if column_spec.is_source:
+ ser = self._source_sample_df[column_spec.name]
+ else:
+ ser = self._expr_sample_df[column_spec.name]
+ try:
+ dtype = infer_dtype(ser)
+ except Exception as e:
+ raise RuntimeError(
+ f"Encountered unsupported data type '{ser.dtype}' for "
+ f"column '{column_spec.name}' in table '{self.name}'. "
+ f"Please either manually override the columns's data "
+ f"type or remove the column from this table.") from e
+
+ if stype is None:
+ if column_spec.is_source:
+ ser = self._source_sample_df[column_spec.name]
+ else:
+ ser = self._expr_sample_df[column_spec.name]
+ try:
+ stype = infer_stype(ser, column_spec.name, dtype)
+ except Exception as e:
+ raise RuntimeError(
+ f"Could not determine semantic type for column "
+ f"'{column_spec.name}' with data type '{dtype}' in "
+ f"table '{self.name}'. Please either change the "
+ f"column's data type or remove the column from this "
+ f"table.") from e
+
+ self._column_dict[column_spec.name] = Column(
+ name=column_spec.name,
+ expr=column_spec.expr,
+ dtype=dtype,
+ stype=stype,
+ )
+
+ def add_column(self, column: ColumnSpecType) -> Column:
+ r"""Adds a column to this table.
+
+ Args:
+ column: The column to add.

- return self._columns[name]
+ Raises:
+ KeyError: If the column name already exists in this table.
+ """
+ column_spec = ColumnSpec.coerce(column)
+ self.add_columns([column_spec])
+ return self[column_spec.name]

  def remove_column(self, name: str) -> Self:
  r"""Removes a column from this table.
@@ -197,7 +251,7 @@ class Table(ABC):
  self.time_column = None
  if self._end_time_column == name:
  self.end_time_column = None
- del self._columns[name]
+ del self._column_dict[name]

  return self

@@ -210,22 +264,22 @@ class Table(ABC):
  return self._primary_key is not None

  @property
- def primary_key(self) -> Optional[Column]:
+ def primary_key(self) -> Column | None:
  r"""The primary key column of this table.

  The getter returns the primary key column of this table, or ``None`` if
  no such primary key is present.

  The setter sets a column as a primary key on this table, and raises a
- :class:`ValueError` if the primary key has a non-ID semantic type or
- if the column name does not match a column in the data frame.
+ :class:`ValueError` if the primary key has a non-ID compatible data
+ type or if the column name does not match a column in the data frame.
  """
  if self._primary_key is None:
  return None
  return self[self._primary_key]

  @primary_key.setter
- def primary_key(self, name: Optional[str]) -> None:
+ def primary_key(self, name: str | None) -> None:
  if name is not None and name == self._time_column:
  raise ValueError(f"Cannot specify column '{name}' as a primary "
  f"key since it is already defined to be a time "
@@ -255,22 +309,23 @@ class Table(ABC):
  return self._time_column is not None

  @property
- def time_column(self) -> Optional[Column]:
+ def time_column(self) -> Column | None:
  r"""The time column of this table.

  The getter returns the time column of this table, or ``None`` if no
  such time column is present.

  The setter sets a column as a time column on this table, and raises a
- :class:`ValueError` if the time column has a non-timestamp semantic
- type or if the column name does not match a column in the data frame.
+ :class:`ValueError` if the time column has a non-timestamp compatible
+ data type or if the column name does not match a column in the data
+ frame.
  """
  if self._time_column is None:
  return None
  return self[self._time_column]

  @time_column.setter
- def time_column(self, name: Optional[str]) -> None:
+ def time_column(self, name: str | None) -> None:
  if name is not None and name == self._primary_key:
  raise ValueError(f"Cannot specify column '{name}' as a time "
  f"column since it is already defined to be a "
@@ -300,7 +355,7 @@ class Table(ABC):
  return self._end_time_column is not None

  @property
- def end_time_column(self) -> Optional[Column]:
+ def end_time_column(self) -> Column | None:
  r"""The end time column of this table.

  The getter returns the end time column of this table, or ``None`` if no
@@ -308,15 +363,15 @@ class Table(ABC):

  The setter sets a column as an end time column on this table, and
  raises a :class:`ValueError` if the end time column has a non-timestamp
- semantic type or if the column name does not match a column in the data
- frame.
+ compatible data type or if the column name does not match a column in
+ the data frame.
  """
  if self._end_time_column is None:
  return None
  return self[self._end_time_column]

  @end_time_column.setter
- def end_time_column(self, name: Optional[str]) -> None:
+ def end_time_column(self, name: str | None) -> None:
  if name is not None and name == self._primary_key:
  raise ValueError(f"Cannot specify column '{name}' as an end time "
  f"column since it is already defined to be a "
@@ -344,39 +399,39 @@ class Table(ABC):
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
  information about the columns in this table.

- The returned dataframe has columns ``name``, ``dtype``, ``stype``,
- ``is_primary_key``, ``is_time_column`` and ``is_end_time_column``,
- which provide an aggregate view of the properties of the columns of
- this table.
+ The returned dataframe has columns ``"Name"``, ``"Data Type"``,
+ ``"Semantic Type"``, ``"Primary Key"``, ``"Time Column"`` and
+ ``"End Time Column"``, which provide an aggregated view of the
+ properties of the columns of this table.

  Example:
  >>> # doctest: +SKIP
  >>> import kumoai.experimental.rfm as rfm
  >>> table = rfm.LocalTable(df=..., name=...).infer_metadata()
  >>> table.metadata
- name dtype stype is_primary_key is_time_column is_end_time_column
- 0 CustomerID float64 ID True False False
+ Name Data Type Semantic Type Primary Key Time Column End Time Column
+ 0 CustomerID float64 ID True False False
  """ # noqa: E501
  cols = self.columns

  return pd.DataFrame({
- 'name':
+ 'Name':
  pd.Series(dtype=str, data=[c.name for c in cols]),
- 'dtype':
+ 'Data Type':
  pd.Series(dtype=str, data=[c.dtype for c in cols]),
- 'stype':
+ 'Semantic Type':
  pd.Series(dtype=str, data=[c.stype for c in cols]),
- 'is_primary_key':
+ 'Primary Key':
  pd.Series(
  dtype=bool,
  data=[self._primary_key == c.name for c in cols],
  ),
- 'is_time_column':
+ 'Time Column':
  pd.Series(
  dtype=bool,
  data=[self._time_column == c.name for c in cols],
  ),
- 'is_end_time_column':
+ 'End Time Column':
  pd.Series(
  dtype=bool,
  data=[self._end_time_column == c.name for c in cols],
@@ -385,33 +440,98 @@ class Table(ABC):

  def print_metadata(self) -> None:
  r"""Prints the :meth:`~metadata` of this table."""
- num_rows_repr = ''
- if self._num_rows is not None:
- num_rows_repr = ' ({self._num_rows:,} rows)'
-
- if in_snowflake_notebook():
- import streamlit as st
- md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
- st.markdown(md_repr)
- st.dataframe(self.metadata, hide_index=True)
- elif in_notebook():
- from IPython.display import Markdown, display
- md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
- display(Markdown(md_repr))
- df = self.metadata
- try:
- if hasattr(df.style, 'hide'):
- display(df.style.hide(axis='index')) # pandas=2
- else:
- display(df.style.hide_index()) # pandas<1.3
- except ImportError:
- print(df.to_string(index=False)) # missing jinja2
- else:
- print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
- print(self.metadata.to_string(index=False))
+ msg = f"🏷️ Metadata of Table `{self.name}`"
+ if num := self._num_rows:
+ msg += " (1 row)" if num == 1 else f" ({num:,} rows)"
+
+ display.title(msg)
+ display.dataframe(self.metadata)
+
+ def infer_primary_key(self, verbose: bool = True) -> Self:
+ r"""Infers the primary key in this table.
+
+ Args:
+ verbose: Whether to print verbose output.
+ """
+ if self.has_primary_key():
+ return self
+
+ def _set_primary_key(primary_key: str) -> None:
+ self.primary_key = primary_key
+ if verbose:
+ display.message(f"Inferred primary key `{primary_key}` for "
+ f"table `{self.name}`")
+
+ # Inference from source column metadata:
+ if any(column.is_source for column in self.columns):
+ primary_key = self._source_primary_key
+ if (primary_key is not None and primary_key in self
+ and self[primary_key].is_source):
+ _set_primary_key(primary_key)
+ return self
+
+ unique_keys = [
+ column.name for column in self._source_column_dict.values()
+ if column.is_unique_key
+ ]
+ if (len(unique_keys) == 1 # NOTE No composite keys yet.
+ and unique_keys[0] in self
+ and self[unique_keys[0]].is_source):
+ _set_primary_key(unique_keys[0])
+ return self
+
+ # Heuristic-based inference:
+ candidates = [
+ column.name for column in self.columns if column.stype == Stype.ID
+ ]
+ if len(candidates) == 0:
+ for column in self.columns:
+ if self.name.lower() == column.name.lower():
+ candidates.append(column.name)
+ elif (self.name.lower().endswith('s')
+ and self.name.lower()[:-1] == column.name.lower()):
+ candidates.append(column.name)
+
+ if primary_key := infer_primary_key(
+ table_name=self.name,
+ df=self._get_sample_df(),
+ candidates=candidates,
+ ):
+ _set_primary_key(primary_key)
+ return self
+
+ return self
+
+ def infer_time_column(self, verbose: bool = True) -> Self:
+ r"""Infers the time column in this table.
+
+ Args:
+ verbose: Whether to print verbose output.
+ """
+ if self.has_time_column():
+ return self
+
+ # Heuristic-based inference:
+ candidates = [
+ column.name for column in self.columns
+ if column.stype == Stype.timestamp
+ and column.name != self._end_time_column
+ ]
+
+ if time_column := infer_time_column(
+ df=self._get_sample_df(),
+ candidates=candidates,
+ ):
+ self.time_column = time_column
+
+ if verbose:
+ display.message(f"Inferred time column `{time_column}` for "
+ f"table `{self.name}`")
+
+ return self

  def infer_metadata(self, verbose: bool = True) -> Self:
- r"""Infers metadata, *i.e.*, primary keys and time columns, in the
+ r"""Infers metadata, *i.e.*, primary keys and time columns, in this
  table.

  Args:
@@ -419,48 +539,19 @@ class Table(ABC):
  """
  logs = []

- # Try to detect primary key if not set:
  if not self.has_primary_key():
+ self.infer_primary_key(verbose=False)
+ if self.has_primary_key():
+ logs.append(f"primary key `{self._primary_key}`")

- def is_candidate(column: Column) -> bool:
- if column.stype == Stype.ID:
- return True
- if all(column.stype != Stype.ID for column in self.columns):
- if self.name == column.name:
- return True
- if (self.name.endswith('s')
- and self.name[:-1] == column.name):
- return True
- return False
-
- candidates = [
- column.name for column in self.columns if is_candidate(column)
- ]
-
- if primary_key := infer_primary_key(
- table_name=self.name,
- df=self._sample_df,
- candidates=candidates,
- ):
- self.primary_key = primary_key
- logs.append(f"primary key '{primary_key}'")
-
- # Try to detect time column if not set:
  if not self.has_time_column():
- candidates = [
- column.name for column in self.columns
- if column.stype == Stype.timestamp
- and column.name != self._end_time_column
- ]
- if time_column := infer_time_column(
- df=self._sample_df,
- candidates=candidates,
- ):
- self.time_column = time_column
- logs.append(f"time column '{time_column}'")
+ self.infer_time_column(verbose=False)
+ if self.has_time_column():
+ logs.append(f"time column `{self._time_column}`")

  if verbose and len(logs) > 0:
- print(f"Detected {' and '.join(logs)} in table '{self.name}'")
+ display.message(f"Inferred {' and '.join(logs)} for table "
+ f"`{self.name}`")

  return self

@@ -478,6 +569,100 @@ class Table(ABC):
  end_time_col=self._end_time_column,
  )

+ @cached_property
+ def _source_column_dict(self) -> dict[str, SourceColumn]:
+ source_columns = self._get_source_columns()
+ if len(source_columns) == 0:
+ raise ValueError(f"Table '{self.name}' has no columns")
+ return {column.name: column for column in source_columns}
+
+ @cached_property
+ def _source_primary_key(self) -> str | None:
+ primary_keys = [
+ column.name for column in self._source_column_dict.values()
+ if column.is_primary_key
+ ]
+ # NOTE No composite keys yet.
+ return primary_keys[0] if len(primary_keys) == 1 else None
+
+ @cached_property
+ def _source_foreign_key_dict(self) -> dict[str, SourceForeignKey]:
+ return {key.name: key for key in self._get_source_foreign_keys()}
+
+ @cached_property
+ def _source_sample_df(self) -> pd.DataFrame:
+ return self._get_source_sample_df().reset_index(drop=True)
+
+ @cached_property
+ def _num_rows(self) -> int | None:
+ return self._get_num_rows()
+
+ def _get_sample_df(self) -> pd.DataFrame:
+ dfs: list[pd.DataFrame] = []
+ if any(column.is_source for column in self.columns):
+ dfs.append(self._source_sample_df)
+ if any(not column.is_source for column in self.columns):
+ dfs.append(self._expr_sample_df)
+
+ if len(dfs) == 0:
+ return pd.DataFrame(index=range(1000))
+ if len(dfs) == 1:
+ return dfs[0]
+
+ size = min(map(len, dfs))
+ df = pd.concat([dfs[0].iloc[:size], dfs[1].iloc[:size]], axis=1)
+ df = df.loc[:, ~df.columns.duplicated(keep='last')]
+ return df
+
+ @staticmethod
+ def _sanitize(
+ df: pd.DataFrame,
+ dtype_dict: dict[str, Dtype | None] | None = None,
+ stype_dict: dict[str, Stype | None] | None = None,
+ ) -> pd.DataFrame:
+ r"""Sanitzes a :class:`pandas.DataFrame` in-place such that its data
+ types match table data and semantic type specification.
+ """
+ def _to_list(ser: pd.Series, dtype: Dtype | None) -> pd.Series:
+ if (pd.api.types.is_string_dtype(ser)
+ and dtype in {Dtype.intlist, Dtype.floatlist}):
+ try:
+ ser = ser.map(lambda row: np.fromstring(
+ row.strip('[]'),
+ sep=',',
+ dtype=int if dtype == Dtype.intlist else np.float32,
+ ) if row is not None else None)
+ except Exception:
+ pass
+
+ if pd.api.types.is_string_dtype(ser):
+ try:
+ import orjson as json
+ except ImportError:
+ import json
+ try:
+ ser = ser.map(lambda row: json.loads(row)
+ if row is not None else None)
+ except Exception:
+ pass
+
+ return ser
+
+ for column_name in df.columns:
+ dtype = (dtype_dict or {}).get(column_name)
+ stype = (stype_dict or {}).get(column_name)
+
+ if dtype == Dtype.time:
+ df[column_name] = to_datetime(df[column_name])
+ elif stype == Stype.timestamp:
+ df[column_name] = to_datetime(df[column_name])
+ elif dtype is not None and dtype.is_list():
+ df[column_name] = _to_list(df[column_name], dtype)
+ elif stype == Stype.sequence:
+ df[column_name] = _to_list(df[column_name], Dtype.floatlist)
+
+ return df
+

  # Python builtins #########################################################
  def __hash__(self) -> int:
@@ -512,45 +697,26 @@ class Table(ABC):
  @abstractmethod
  def backend(self) -> DataBackend:
  r"""The data backend of this table."""
- pass
-
- @cached_property
- def _source_column_dict(self) -> Dict[str, SourceColumn]:
- return {col.name: col for col in self._get_source_columns()}

  @abstractmethod
- def _get_source_columns(self) -> List[SourceColumn]:
+ def _get_source_columns(self) -> list[SourceColumn]:
  pass

- @cached_property
- def _source_foreign_key_dict(self) -> Dict[str, SourceForeignKey]:
- fkeys = self._get_source_foreign_keys()
- # NOTE Drop all keys that link to different primary keys in the same
- # table since we don't support composite keys yet:
- table_pkeys: Dict[str, Set[str]] = defaultdict(set)
- for fkey in fkeys:
- table_pkeys[fkey.dst_table].add(fkey.primary_key)
- return {
- fkey.name: fkey
- for fkey in fkeys if len(table_pkeys[fkey.dst_table]) == 1
- }
-
  @abstractmethod
- def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
+ def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
  pass

- @cached_property
- def _sample_df(self) -> pd.DataFrame:
- return self._get_sample_df()
-
  @abstractmethod
- def _get_sample_df(self) -> pd.DataFrame:
+ def _get_source_sample_df(self) -> pd.DataFrame:
  pass

- @cached_property
- def _num_rows(self) -> Optional[int]:
- return self._get_num_rows()
+ @abstractmethod
+ def _get_expr_sample_df(
+ self,
+ columns: Sequence[ColumnSpec],
+ ) -> pd.DataFrame:
+ pass

  @abstractmethod
- def _get_num_rows(self) -> Optional[int]:
+ def _get_num_rows(self) -> int | None:
  pass
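The `_sanitize` helper added near the end of the diff coerces string-encoded sample data back into timestamps and numeric lists before further processing. A self-contained sketch of the same parsing idea (standalone code, not the package's helper; `parse_list_column` is a made-up name used only for illustration):

    # Standalone illustration of the string-to-list coercion used by _sanitize.
    import json

    import numpy as np
    import pandas as pd

    def parse_list_column(ser: pd.Series) -> pd.Series:
        """Best-effort parse of cells like '[1, 2, 3]' into numeric arrays."""
        if not pd.api.types.is_string_dtype(ser):
            return ser
        try:
            # Fast path: comma-separated numbers inside brackets.
            return ser.map(lambda row: None if pd.isna(row) else
                           np.fromstring(row.strip('[]'), sep=','))
        except Exception:
            pass
        try:
            # Fallback: treat each cell as a JSON document.
            return ser.map(lambda row: None if pd.isna(row) else json.loads(row))
        except Exception:
            return ser

    ser = pd.Series(['[1, 2, 3]', '[4.5, 6]', None])
    print(parse_list_column(ser))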