kumoai 2.12.0.dev202511031731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +35 -7
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +164 -46
  8. kumoai/experimental/rfm/backend/__init__.py +0 -0
  9. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  11. kumoai/experimental/rfm/backend/local/sampler.py +131 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +14 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/sampler.py +287 -0
  20. kumoai/experimental/rfm/base/source.py +18 -0
  21. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  22. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  23. kumoai/experimental/rfm/infer/__init__.py +6 -0
  24. kumoai/experimental/rfm/infer/dtype.py +79 -0
  25. kumoai/experimental/rfm/infer/pkey.py +126 -0
  26. kumoai/experimental/rfm/infer/time_col.py +62 -0
  27. kumoai/experimental/rfm/local_graph_sampler.py +43 -4
  28. kumoai/experimental/rfm/local_pquery_driver.py +222 -27
  29. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  30. kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
  31. kumoai/experimental/rfm/rfm.py +153 -96
  32. kumoai/experimental/rfm/sagemaker.py +138 -0
  33. kumoai/spcs.py +1 -3
  34. kumoai/testing/decorators.py +1 -1
  35. kumoai/utils/progress_logger.py +10 -4
  36. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/RECORD +40 -27
  38. kumoai/experimental/rfm/pquery/backend.py +0 -136
  39. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,287 @@
1
+ import copy
2
+ import re
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
10
+ from kumoapi.typing import Stype
11
+
12
+ if TYPE_CHECKING:
13
+ from kumoai.experimental.rfm import Graph
14
+
15
+
16
+ @dataclass
17
+ class EdgeSpec:
18
+ num_neighbors: int | None = None
19
+ time_offsets: tuple[
20
+ pd.DateOffset | None,
21
+ pd.DateOffset,
22
+ ] | None = None
23
+
24
+ def __post_init__(self) -> None:
25
+ if (self.num_neighbors is None) == (self.time_offsets is None):
26
+ raise ValueError("Only one of 'num_neighbors' and 'time_offsets' "
27
+ "must be provided")
28
+
29
+
30
+ @dataclass
31
+ class SamplerOutput:
32
+ df_dict: dict[str, pd.DataFrame]
33
+ inverse_dict: dict[str, np.ndarray]
34
+ batch_dict: dict[str, np.ndarray]
35
+ num_sampled_nodes_dict: dict[str, list[int]]
36
+ row_dict: dict[tuple[str, str, str], np.ndarray] | None = None
37
+ col_dict: dict[tuple[str, str, str], np.ndarray] | None = None
38
+ num_sampled_edges_dict: dict[tuple[str, str, str], list[int]] | None = None
39
+
40
+
41
class Sampler(ABC):
    r"""Abstract base class for sampling subgraphs from a :class:`Graph`.

    The constructor caches the graph metadata required for sampling (edge
    types, primary keys, time columns and semantic types). Backends implement
    :meth:`sample`, and :meth:`sample_subgraph` converts its output into a
    :class:`kumoapi.rfm.context.Subgraph`.

    Args:
        graph: The graph to sample from.
    """
    def __init__(self, graph: 'Graph') -> None:
        # Register every edge together with its reverse direction:
        self._edge_types: list[tuple[str, str, str]] = []
        for edge in graph.edges:
            edge_type = (edge.src_table, edge.fkey, edge.dst_table)
            self._edge_types.append(edge_type)
            self._edge_types.append(Subgraph.rev_edge_type(edge_type))

        # Only tables that define the respective column are registered:
        self._primary_key_dict: dict[str, str] = {
            table.name: table._primary_key
            for table in graph.tables.values()
            if table._primary_key is not None
        }

        self._time_column_dict: dict[str, str] = {
            table.name: table._time_column
            for table in graph.tables.values()
            if table._time_column is not None
        }

        self._end_time_column_dict: dict[str, str] = {
            table.name: table._end_time_column
            for table in graph.tables.values()
            if table._end_time_column is not None
        }

        # Feature columns per table, i.e. all columns except for the primary
        # key and the foreign keys of the table:
        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
        self._table_stype_dict: dict[str, dict[str, Stype]] = {}
        for table in graph.tables.values():
            self._table_stype_dict[table.name] = {}
            for column in table.columns:
                if column == table.primary_key:
                    continue
                if (table.name, column.name) in foreign_keys:
                    continue
                self._table_stype_dict[table.name][column.name] = column.stype

    @property
    def edge_types(self) -> list[tuple[str, str, str]]:
        r"""All edge types of the graph, including reverse edge types."""
        return self._edge_types

    @property
    def primary_key_dict(self) -> dict[str, str]:
        r"""The primary key of each table that defines one."""
        return self._primary_key_dict

    @property
    def time_column_dict(self) -> dict[str, str]:
        r"""The time column of each table that defines one."""
        return self._time_column_dict

    @property
    def end_time_column_dict(self) -> dict[str, str]:
        r"""The end time column of each table that defines one."""
        return self._end_time_column_dict

    @property
    def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
        r"""The semantic types of all feature columns, grouped by table."""
        return self._table_stype_dict

    def sample_subgraph(
        self,
        entity_table_names: tuple[str, ...],
        entity_pkey: pd.Series,
        anchor_time: pd.Series,
        num_neighbors: list[int],
        exclude_cols_dict: dict[str, list[str]] | None = None,
    ) -> Subgraph:
        r"""Samples a :class:`Subgraph` around the given entities.

        Args:
            entity_table_names: The entity tables. Sampling is seeded from the
                first table, and the primary key column is kept for every
                listed table.
            entity_pkey: The primary keys of the entities to sample around.
            anchor_time: The anchor time of each entity. Will be cast to
                ``datetime64[ns]`` if necessary.
            num_neighbors: The number of neighbors to sample, applied to every
                edge type (one :class:`EdgeSpec` per entry).
            exclude_cols_dict: Columns to drop from the features of each
                table, *e.g.*, since they leak target information.

        Returns:
            The sampled subgraph. Tables without any sampled node and edge
            types without any sampled edge are omitted.

        Raises:
            KeyError: If ``exclude_cols_dict`` refers to a table or a feature
                column that does not exist, or if an entity table does not
                define a primary key.
        """
        edge_spec_dict: dict[tuple[str, str, str], list[EdgeSpec]] = {
            edge_type: [EdgeSpec(value) for value in num_neighbors]
            for edge_type in self.edge_types
        }

        # Exclude all columns that leak target information:
        table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
        if exclude_cols_dict is not None:
            # Copy first in order to leave the cached metadata untouched:
            table_stype_dict = copy.deepcopy(table_stype_dict)
            for table_name, exclude_cols in exclude_cols_dict.items():
                for column_name in exclude_cols:
                    del table_stype_dict[table_name][column_name]

        # Collect all columns being used as features:
        column_spec_dict: dict[str, list[str]] = {
            table_name: list(stype_dict.keys())
            for table_name, stype_dict in table_stype_dict.items()
        }
        # Make sure to store primary key information for entity tables:
        for table_name in entity_table_names:
            column_spec_dict[table_name] = (
                [self.primary_key_dict[table_name]] +
                column_spec_dict[table_name])

        if anchor_time.dtype != 'datetime64[ns]':
            anchor_time = anchor_time.astype('datetime64[ns]')
        out = self.sample(
            entity_table_name=entity_table_names[0],
            entity_pkey=entity_pkey,
            anchor_time=anchor_time,
            column_spec_dict=column_spec_dict,
            edge_spec_dict=edge_spec_dict,
            drop_duplicates=True,
            return_edges=True,
        )

        subgraph = Subgraph(
            anchor_time=anchor_time.astype(int).to_numpy(),
            table_dict={},
            link_dict={},
        )

        # Positional view on the anchor times, shared across all tables:
        anchor_time_values = anchor_time.to_numpy()

        for table_name, batch in out.batch_dict.items():
            if len(batch) == 0:
                continue

            primary_key = None
            if table_name in entity_table_names:
                primary_key = self.primary_key_dict.get(table_name, None)

            df = out.df_dict[table_name].reset_index(drop=True)
            if table_name in self.end_time_column_dict:
                # Set end time to NaT for all values greater than anchor time:
                end_time_column = self.end_time_column_dict[table_name]
                ser = df[end_time_column]
                if ser.dtype != 'datetime64[ns]':
                    ser = ser.astype('datetime64[ns]')
                # Compare by position rather than by label: `df` is indexed
                # `0..n-1`, while `anchor_time.iloc[batch]` would carry the
                # labels of the seed rows. A `Series`-to-`Series` comparison
                # raises whenever these labels differ.
                mask = ser.to_numpy() > anchor_time_values[batch]
                # `Series.mask` returns a new series, which avoids writing
                # into data that may be shared with the sampler output:
                df[end_time_column] = ser.mask(mask)

            stype_dict = table_stype_dict[table_name]
            for column_name, stype in stype_dict.items():
                if stype == Stype.text:
                    df[column_name] = _normalize_text(df[column_name])

            subgraph.table_dict[table_name] = Table(
                df=df,
                row=out.inverse_dict.get(table_name),
                batch=batch,
                num_sampled_nodes=out.num_sampled_nodes_dict[table_name],
                stype_dict=stype_dict,
                primary_key=primary_key,
            )

        assert out.row_dict is not None
        assert out.col_dict is not None
        assert out.num_sampled_edges_dict is not None
        for edge_type in out.row_dict.keys():
            row: np.ndarray | None = out.row_dict[edge_type]
            col: np.ndarray | None = out.col_dict[edge_type]

            if row is None or col is None or len(row) == 0:
                continue

            # Do not store reverse edge type if it is an exact replica:
            rev_edge_type = Subgraph.rev_edge_type(edge_type)
            if (rev_edge_type in subgraph.link_dict
                    and np.array_equal(row, out.col_dict[rev_edge_type])
                    and np.array_equal(col, out.row_dict[rev_edge_type])):
                subgraph.link_dict[edge_type] = Link(
                    layout=EdgeLayout.REV,
                    row=None,
                    col=None,
                    num_sampled_edges=out.num_sampled_edges_dict[edge_type],
                )
                continue

            # Do not store non-informative edges:
            layout = EdgeLayout.COO
            if np.array_equal(row, np.arange(len(row))):
                row = None
            if np.array_equal(col, np.arange(len(col))):
                col = None

            # Store in compressed representation if more efficient, i.e. if
            # the pointer array (`num_cols + 1` entries) is shorter than `col`:
            num_cols = subgraph.table_dict[edge_type[2]].num_rows
            if col is not None and len(col) > num_cols + 1:
                layout = EdgeLayout.CSC
                colcount = np.bincount(col, minlength=num_cols)
                col = np.empty(num_cols + 1, dtype=col.dtype)
                col[0] = 0
                np.cumsum(colcount, out=col[1:])

            subgraph.link_dict[edge_type] = Link(
                layout=layout,
                row=row,
                col=col,
                num_sampled_edges=out.num_sampled_edges_dict[edge_type],
            )

        return subgraph

    # Abstract Methods ########################################################

    @abstractmethod
    def sample(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series,
        column_spec_dict: dict[str, list[str]],
        edge_spec_dict: dict[tuple[str, str, str], list[EdgeSpec]],
        drop_duplicates: bool = False,
        return_edges: bool = False,
    ) -> SamplerOutput:
        r"""Samples nodes (and optionally edges) around the given entities.

        Args:
            entity_table_name: The table to seed sampling from.
            entity_pkey: The primary keys of the seed entities.
            anchor_time: The anchor time of each seed entity.
            column_spec_dict: The columns to return for each table.
            edge_spec_dict: The sampling specifications for each edge type.
            drop_duplicates: Backend-specific; presumably whether to
                de-duplicate sampled rows (confirm in the implementation).
            return_edges: Backend-specific; presumably whether to populate
                the edge-level fields of :class:`SamplerOutput` (confirm in
                the implementation).
        """
        pass
244
+
245
+
246
+ # Helper Functions ############################################################
247
+
248
+ PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
249
+ MULTISPACE = re.compile(r"\s+")
250
+
251
+
252
+ def _normalize_text(
253
+ ser: pd.Series,
254
+ max_words: int | None = 50,
255
+ ) -> pd.Series:
256
+ r"""Normalizes text into a list of lower-case words.
257
+
258
+ Args:
259
+ ser: The :class:`pandas.Series` to normalize.
260
+ max_words: The maximum number of words to return.
261
+ This will auto-shrink any large text column to avoid blowing up
262
+ context size.
263
+ """
264
+ if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
265
+ return ser
266
+
267
+ def normalize_fn(line: str) -> list[str]:
268
+ line = PUNCTUATION.sub(" ", line)
269
+ line = re.sub(r"<br\s*/?>", " ", line) # Handle <br /> or <br>
270
+ line = MULTISPACE.sub(" ", line)
271
+ words = line.split()
272
+ if max_words is not None:
273
+ words = words[:max_words]
274
+ return words
275
+
276
+ ser = ser.fillna('').astype(str)
277
+
278
+ if max_words is not None:
279
+ # We estimate the number of words as 5 characters + 1 space in an
280
+ # English text on average. We need this pre-filter here, as word
281
+ # splitting on a giant text can be very expensive:
282
+ ser = ser.str[:6 * max_words]
283
+
284
+ ser = ser.str.lower()
285
+ ser = ser.map(normalize_fn)
286
+
287
+ return ser
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kumoapi.typing import Dtype
4
+
5
+
6
@dataclass
class SourceColumn:
    r"""Metadata of a single column as reported by the source table."""
    # The name of the column:
    name: str
    # The data type of the column:
    dtype: Dtype
    # Whether the source declares this column as (part of) its primary key:
    is_primary_key: bool
    # Whether the source declares this column as a unique key:
    is_unique_key: bool
12
+
13
+
14
@dataclass
class SourceForeignKey:
    r"""Metadata of a single foreign key as reported by the source table."""
    # The name of the foreign key column:
    name: str
    # The name of the table the foreign key points to:
    dst_table: str
    # The name of the referenced primary key in `dst_table`:
    primary_key: str
@@ -1,115 +1,35 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Dict, List, Optional
1
+ from abc import ABC, abstractmethod
2
+ from collections import defaultdict
3
+ from functools import cached_property
4
+ from typing import Dict, List, Optional, Sequence, Set
3
5
 
4
6
  import pandas as pd
5
7
  from kumoapi.source_table import UnavailableSourceTable
6
8
  from kumoapi.table import Column as ColumnDefinition
7
9
  from kumoapi.table import TableDefinition
8
- from kumoapi.typing import Dtype, Stype
10
+ from kumoapi.typing import Stype
9
11
  from typing_extensions import Self
10
12
 
11
- from kumoai import in_notebook
12
- from kumoai.experimental.rfm import utils
13
+ from kumoai import in_notebook, in_snowflake_notebook
14
+ from kumoai.experimental.rfm.base import Column, SourceColumn, SourceForeignKey
15
+ from kumoai.experimental.rfm.infer import (
16
+ contains_categorical,
17
+ contains_id,
18
+ contains_multicategorical,
19
+ contains_timestamp,
20
+ infer_primary_key,
21
+ infer_time_column,
22
+ )
13
23
 
14
24
 
15
- @dataclass(init=False, repr=False, eq=False)
16
- class Column:
17
- stype: Stype
18
-
19
- def __init__(
20
- self,
21
- name: str,
22
- dtype: Dtype,
23
- stype: Stype,
24
- is_primary_key: bool = False,
25
- is_time_column: bool = False,
26
- is_end_time_column: bool = False,
27
- ) -> None:
28
- self._name = name
29
- self._dtype = Dtype(dtype)
30
- self._is_primary_key = is_primary_key
31
- self._is_time_column = is_time_column
32
- self._is_end_time_column = is_end_time_column
33
- self.stype = Stype(stype)
34
-
35
- @property
36
- def name(self) -> str:
37
- return self._name
38
-
39
- @property
40
- def dtype(self) -> Dtype:
41
- return self._dtype
42
-
43
- def __setattr__(self, key: str, val: Any) -> None:
44
- if key == 'stype':
45
- if isinstance(val, str):
46
- val = Stype(val)
47
- assert isinstance(val, Stype)
48
- if not val.supports_dtype(self.dtype):
49
- raise ValueError(f"Column '{self.name}' received an "
50
- f"incompatible semantic type (got "
51
- f"dtype='{self.dtype}' and stype='{val}')")
52
- if self._is_primary_key and val != Stype.ID:
53
- raise ValueError(f"Primary key '{self.name}' must have 'ID' "
54
- f"semantic type (got '{val}')")
55
- if self._is_time_column and val != Stype.timestamp:
56
- raise ValueError(f"Time column '{self.name}' must have "
57
- f"'timestamp' semantic type (got '{val}')")
58
- if self._is_end_time_column and val != Stype.timestamp:
59
- raise ValueError(f"End time column '{self.name}' must have "
60
- f"'timestamp' semantic type (got '{val}')")
61
-
62
- super().__setattr__(key, val)
63
-
64
- def __hash__(self) -> int:
65
- return hash((self.name, self.stype, self.dtype))
66
-
67
- def __eq__(self, other: Any) -> bool:
68
- if not isinstance(other, Column):
69
- return False
70
- return hash(self) == hash(other)
71
-
72
- def __repr__(self) -> str:
73
- return (f'{self.__class__.__name__}(name={self.name}, '
74
- f'stype={self.stype}, dtype={self.dtype})')
75
-
76
-
77
- class LocalTable:
78
- r"""A table backed by a :class:`pandas.DataFrame`.
79
-
80
- A :class:`LocalTable` fully specifies the relevant metadata, *i.e.*
81
- selected columns, column semantic types, primary keys and time columns.
82
- :class:`LocalTable` is used to create a :class:`LocalGraph`.
83
-
84
- .. code-block:: python
85
-
86
- import pandas as pd
87
- import kumoai.experimental.rfm as rfm
88
-
89
- # Load data from a CSV file:
90
- df = pd.read_csv("data.csv")
91
-
92
- # Create a table from a `pandas.DataFrame` and infer its metadata ...
93
- table = rfm.LocalTable(df, name="my_table").infer_metadata()
94
-
95
- # ... or create a table explicitly:
96
- table = rfm.LocalTable(
97
- df=df,
98
- name="my_table",
99
- primary_key="id",
100
- time_column="time",
101
- end_time_column=None,
102
- )
103
-
104
- # Verify metadata:
105
- table.print_metadata()
106
-
107
- # Change the semantic type of a column:
108
- table[column].stype = "text"
25
+ class Table(ABC):
26
+ r"""A :class:`Table` fully specifies the relevant metadata of a single
27
+ table, *i.e.* its selected columns, data types, semantic types, primary
28
+ keys and time columns.
109
29
 
110
30
  Args:
111
- df: The data frame to create the table from.
112
- name: The name of the table.
31
+ name: The name of this table.
32
+ columns: The selected columns of this table.
113
33
  primary_key: The name of the primary key of this table, if it exists.
114
34
  time_column: The name of the time column of this table, if it exists.
115
35
  end_time_column: The name of the end time column of this table, if it
@@ -117,46 +37,62 @@ class LocalTable:
117
37
  """
118
38
  def __init__(
119
39
  self,
120
- df: pd.DataFrame,
121
40
  name: str,
41
+ columns: Optional[Sequence[str]] = None,
122
42
  primary_key: Optional[str] = None,
123
43
  time_column: Optional[str] = None,
124
44
  end_time_column: Optional[str] = None,
125
45
  ) -> None:
126
46
 
127
- if df.empty:
128
- raise ValueError("Data frame must have at least one row")
129
- if isinstance(df.columns, pd.MultiIndex):
130
- raise ValueError("Data frame must not have a multi-index")
131
- if not df.columns.is_unique:
132
- raise ValueError("Data frame must have unique column names")
133
- if any(col == '' for col in df.columns):
134
- raise ValueError("Data frame must have non-empty column names")
135
-
136
- df = df.copy(deep=False)
137
-
138
- self._data = df
139
47
  self._name = name
140
48
  self._primary_key: Optional[str] = None
141
49
  self._time_column: Optional[str] = None
142
50
  self._end_time_column: Optional[str] = None
143
51
 
52
+ if len(self._source_column_dict) == 0:
53
+ raise ValueError(f"Table '{name}' does not hold any column with "
54
+ f"a supported data type")
55
+
56
+ primary_keys = [
57
+ column.name for column in self._source_column_dict.values()
58
+ if column.is_primary_key
59
+ ]
60
+ if len(primary_keys) == 1: # NOTE No composite keys yet.
61
+ if primary_key is not None and primary_key != primary_keys[0]:
62
+ raise ValueError(f"Found duplicate primary key "
63
+ f"definition '{primary_key}' and "
64
+ f"'{primary_keys[0]}' in table '{name}'")
65
+ primary_key = primary_keys[0]
66
+
67
+ unique_keys = [
68
+ column.name for column in self._source_column_dict.values()
69
+ if column.is_unique_key
70
+ ]
71
+ if primary_key is None and len(unique_keys) == 1:
72
+ primary_key = unique_keys[0]
73
+
144
74
  self._columns: Dict[str, Column] = {}
145
- for column_name in df.columns:
75
+ for column_name in columns or list(self._source_column_dict.keys()):
146
76
  self.add_column(column_name)
147
77
 
148
78
  if primary_key is not None:
79
+ if primary_key not in self:
80
+ self.add_column(primary_key)
149
81
  self.primary_key = primary_key
150
82
 
151
83
  if time_column is not None:
84
+ if time_column not in self:
85
+ self.add_column(time_column)
152
86
  self.time_column = time_column
153
87
 
154
88
  if end_time_column is not None:
89
+ if end_time_column not in self:
90
+ self.add_column(end_time_column)
155
91
  self.end_time_column = end_time_column
156
92
 
157
93
  @property
158
94
  def name(self) -> str:
159
- r"""The name of the table."""
95
+ r"""The name of this table."""
160
96
  return self._name
161
97
 
162
98
  # Data column #############################################################
@@ -200,24 +136,35 @@ class LocalTable:
200
136
  raise KeyError(f"Column '{name}' already exists in table "
201
137
  f"'{self.name}'")
202
138
 
203
- if name not in self._data.columns:
204
- raise KeyError(f"Column '{name}' does not exist in the underyling "
205
- f"data frame")
139
+ if name not in self._source_column_dict:
140
+ raise KeyError(f"Column '{name}' does not exist in the underlying "
141
+ f"source table")
206
142
 
207
143
  try:
208
- dtype = utils.to_dtype(self._data[name])
144
+ dtype = self._source_column_dict[name].dtype
209
145
  except Exception as e:
210
- raise RuntimeError(f"Data type inference for column '{name}' in "
211
- f"table '{self.name}' failed. Consider "
212
- f"changing the data type of the column or "
213
- f"removing it from the table.") from e
146
+ raise RuntimeError(f"Could not obtain data type for column "
147
+ f"'{name}' in table '{self.name}'. Change "
148
+ f"the data type of the column in the source "
149
+ f"table or remove it from the table.") from e
150
+
214
151
  try:
215
- stype = utils.infer_stype(self._data[name], name, dtype)
152
+ ser = self._sample_df[name]
153
+ if contains_id(ser, name, dtype):
154
+ stype = Stype.ID
155
+ elif contains_timestamp(ser, name, dtype):
156
+ stype = Stype.timestamp
157
+ elif contains_multicategorical(ser, name, dtype):
158
+ stype = Stype.multicategorical
159
+ elif contains_categorical(ser, name, dtype):
160
+ stype = Stype.categorical
161
+ else:
162
+ stype = dtype.default_stype
216
163
  except Exception as e:
217
- raise RuntimeError(f"Semantic type inference for column '{name}' "
218
- f"in table '{self.name}' failed. Consider "
219
- f"changing the data type of the column or "
220
- f"removing it from the table.") from e
164
+ raise RuntimeError(f"Could not obtain semantic type for column "
165
+ f"'{name}' in table '{self.name}'. Change "
166
+ f"the data type of the column in the source "
167
+ f"table or remove it from the table.") from e
221
168
 
222
169
  self._columns[name] = Column(
223
170
  name=name,
@@ -432,12 +379,20 @@ class LocalTable:
432
379
  })
433
380
 
434
381
  def print_metadata(self) -> None:
435
- r"""Prints the :meth:`~LocalTable.metadata` of the table."""
436
- if in_notebook():
382
+ r"""Prints the :meth:`~metadata` of this table."""
383
# Only report the number of rows if the backend is able to provide it:
num_rows_repr = ''
if self._num_rows is not None:
    # NOTE This needs to be an f-string, otherwise the placeholder is printed
    # verbatim instead of the (thousands-separated) number of rows:
    num_rows_repr = f' ({self._num_rows:,} rows)'
386
+
387
+ if in_snowflake_notebook():
388
+ import streamlit as st
389
+ md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
390
+ st.markdown(md_repr)
391
+ st.dataframe(self.metadata, hide_index=True)
392
+ elif in_notebook():
437
393
  from IPython.display import Markdown, display
438
- display(
439
- Markdown(f"### 🏷️ Metadata of Table `{self.name}` "
440
- f"({len(self._data):,} rows)"))
394
+ md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
395
+ display(Markdown(md_repr))
441
396
  df = self.metadata
442
397
  try:
443
398
  if hasattr(df.style, 'hide'):
@@ -447,8 +402,7 @@ class LocalTable:
447
402
  except ImportError:
448
403
  print(df.to_string(index=False)) # missing jinja2
449
404
  else:
450
- print(f"🏷️ Metadata of Table '{self.name}' "
451
- f"({len(self._data):,} rows):")
405
+ print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
452
406
  print(self.metadata.to_string(index=False))
453
407
 
454
408
  def infer_metadata(self, verbose: bool = True) -> Self:
@@ -478,9 +432,9 @@ class LocalTable:
478
432
  column.name for column in self.columns if is_candidate(column)
479
433
  ]
480
434
 
481
- if primary_key := utils.detect_primary_key(
435
+ if primary_key := infer_primary_key(
482
436
  table_name=self.name,
483
- df=self._data,
437
+ df=self._sample_df,
484
438
  candidates=candidates,
485
439
  ):
486
440
  self.primary_key = primary_key
@@ -493,7 +447,10 @@ class LocalTable:
493
447
  if column.stype == Stype.timestamp
494
448
  and column.name != self._end_time_column
495
449
  ]
496
- if time_column := utils.detect_time_column(self._data, candidates):
450
+ if time_column := infer_time_column(
451
+ df=self._sample_df,
452
+ candidates=candidates,
453
+ ):
497
454
  self.time_column = time_column
498
455
  logs.append(f"time column '{time_column}'")
499
456
 
@@ -543,3 +500,46 @@ class LocalTable:
543
500
  f' time_column={self._time_column},\n'
544
501
  f' end_time_column={self._end_time_column},\n'
545
502
  f')')
503
+
504
+ # Abstract Methods ########################################################
505
+
506
@cached_property
def _source_column_dict(self) -> Dict[str, SourceColumn]:
    r"""The source columns of this table, keyed by column name.

    Computed once per instance from :meth:`_get_source_columns`.
    """
    source_column_dict: Dict[str, SourceColumn] = {}
    for source_column in self._get_source_columns():
        source_column_dict[source_column.name] = source_column
    return source_column_dict

@abstractmethod
def _get_source_columns(self) -> List[SourceColumn]:
    r"""Returns the columns of the underlying source table."""
    pass
513
+
514
@cached_property
def _source_foreign_key_dict(self) -> Dict[str, SourceForeignKey]:
    r"""The usable source foreign keys of this table, keyed by name.

    Computed once per instance from :meth:`_get_source_foreign_keys`.
    """
    source_fkeys = self._get_source_foreign_keys()

    # Collect the distinct primary keys referenced in each destination table:
    pkeys_per_table: Dict[str, Set[str]] = {}
    for source_fkey in source_fkeys:
        pkeys = pkeys_per_table.setdefault(source_fkey.dst_table, set())
        pkeys.add(source_fkey.primary_key)

    # NOTE Composite keys are not supported yet, so every key that links to
    # a table referenced via more than one primary key is dropped:
    source_fkey_dict: Dict[str, SourceForeignKey] = {}
    for source_fkey in source_fkeys:
        if len(pkeys_per_table[source_fkey.dst_table]) != 1:
            continue
        source_fkey_dict[source_fkey.name] = source_fkey
    return source_fkey_dict

@abstractmethod
def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
    r"""Returns the foreign keys of the underlying source table."""
    pass
530
+
531
@cached_property
def _sample_df(self) -> pd.DataFrame:
    r"""A sample of the table's data, fetched once per instance via
    :meth:`_get_sample_df`.
    """
    sample_df = self._get_sample_df()
    return sample_df

@abstractmethod
def _get_sample_df(self) -> pd.DataFrame:
    r"""Returns a data frame holding a sample of the underlying source
    table.
    """
    pass
538
+
539
@cached_property
def _num_rows(self) -> Optional[int]:
    r"""The number of rows of this table as reported by
    :meth:`_get_num_rows` (``None`` if not reported), fetched once per
    instance.
    """
    num_rows = self._get_num_rows()
    return num_rows

@abstractmethod
def _get_num_rows(self) -> Optional[int]:
    r"""Returns the number of rows of the underlying source table, or
    ``None``.
    """
    pass