kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/sampler.py +5 -17
  18. kumoai/experimental/rfm/base/source.py +1 -1
  19. kumoai/experimental/rfm/base/sql_sampler.py +69 -9
  20. kumoai/experimental/rfm/base/table.py +258 -97
  21. kumoai/experimental/rfm/graph.py +106 -98
  22. kumoai/experimental/rfm/infer/dtype.py +4 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/relbench.py +76 -0
  25. kumoai/experimental/rfm/rfm.py +394 -241
  26. kumoai/experimental/rfm/task_table.py +290 -0
  27. kumoai/trainer/distilled_trainer.py +175 -0
  28. kumoai/utils/display.py +51 -0
  29. kumoai/utils/progress_logger.py +13 -1
  30. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +1 -1
  31. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +34 -31
  32. kumoai/experimental/rfm/base/column_expression.py +0 -50
  33. kumoai/experimental/rfm/base/sql_table.py +0 -229
  34. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
  35. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
  36. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/sampler.py
@@ -1,16 +1,20 @@
  import json
  from collections.abc import Iterator
  from contextlib import contextmanager
+ from typing import TYPE_CHECKING

  import numpy as np
  import pandas as pd
  import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery

- from kumoai.experimental.rfm.backend.snow import Connection
- from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+ from kumoai.experimental.rfm.base import SQLSampler, Table
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
- from kumoai.utils import quote_ident
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph


  @contextmanager
@@ -22,18 +26,30 @@ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:


  class SnowSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         for table in graph.tables.values():
+             assert isinstance(table, SnowTable)
+             self._connection = table._connection
+
      def _get_min_max_time_dict(
          self,
          table_names: list[str],
      ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
          selects: list[str] = []
          for table_name in table_names:
-             time_column = self.time_column_dict[table_name]
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
              select = (f"SELECT\n"
                        f" ? as table_name,\n"
-                       f" MIN({quote_ident(time_column)}) as min_date,\n"
-                       f" MAX({quote_ident(time_column)}) as max_date\n"
-                       f"FROM {self.fqn_dict[table_name]}")
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
              selects.append(select)
          sql = "\nUNION ALL\n".join(selects)

@@ -59,17 +75,27 @@ class SnowSampler(SQLSampler):
          # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
          num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.

+         source_table = self.source_table_dict[table_name]
          filters: list[str] = []
-         primary_key = self.primary_key_dict[table_name]
-         if self.source_table_dict[table_name][primary_key].is_nullable:
-             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-         time_column = self.time_column_dict.get(table_name)
-         if (time_column is not None and
-                 self.source_table_dict[table_name][time_column].is_nullable):
-             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
-
-         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-                f"FROM {self.fqn_dict[table_name]}\n"
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")
+
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}\n"
                 f"SAMPLE ROW ({num_rows} ROWS)")
          if len(filters) > 0:
              sql += f"\nWHERE{' AND'.join(filters)}"
@@ -79,7 +105,11 @@ class SnowSampler(SQLSampler):
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()

-         return self._sanitize(table_name, table)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )

      def _sample_target(
          self,
@@ -152,42 +182,51 @@ class SnowSampler(SQLSampler):
          pkey: pd.Series,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
-
-         pkey_name = self.primary_key_dict[table_name]
-         source_table = self.source_table_dict[table_name]
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]

          payload = json.dumps(list(pkey))

          sql = ("WITH TMP as (\n"
                 " SELECT\n"
-                " f.index as BATCH,\n")
-         if source_table[pkey_name].dtype.is_int():
-             sql += " f.value::NUMBER as ID\n"
-         elif source_table[pkey_name].dtype.is_float():
-             sql += " f.value::FLOAT as ID\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][key].is_int():
+             sql += " f.value::NUMBER as __KUMO_ID__\n"
+         elif self.table_dtype_dict[table_name][key].is_float():
+             sql += " f.value::FLOAT as __KUMO_ID__\n"
          else:
-             sql += " f.value::VARCHAR as ID\n"
+             sql += " f.value::VARCHAR as __KUMO_ID__\n"
          sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                  f")\n"
-                 f"SELECT TMP.BATCH as __BATCH__, "
-                 f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
                  f"FROM TMP\n"
-                 f"JOIN {self.fqn_dict[table_name]} ENT\n"
-                 f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+                 f"JOIN {self.source_name_dict[table_name]} ENT\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__")

          with paramstyle(self._connection), self._connection.cursor() as cursor:
              cursor.execute(sql, (payload, ))
              table = cursor.fetch_arrow_all()

          # Remove any duplicated primary keys in post-processing:
-         tmp = table.append_column('__TMP__', pa.array(range(len(table))))
-         gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
-         table = table.take(gb['__TMP___min'])
+         tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+         gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+         table = table.take(gb['__KUMO_ID___min'])

-         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)

-         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

      # Helper Methods ##########################################################

@@ -201,6 +240,7 @@ class SnowSampler(SQLSampler):
          max_offset: pd.DateOffset,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]

          end_time = anchor_time + max_offset
          end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
@@ -211,42 +251,47 @@ class SnowSampler(SQLSampler):
          else:
              payload = json.dumps(list(zip(pkey, end_time)))

-         # Based on benchmarking, JSON payload is the fastest way to query by
-         # custom indices (compared to large `IN` clauses or temporary tables):
-         source_table = self.source_table_dict[table_name]
-         time_column = self.time_column_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][fkey]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
          sql = ("WITH TMP as (\n"
                 " SELECT\n"
-                " f.index as BATCH,\n")
-         if source_table[fkey].dtype.is_int():
-             sql += " f.value[0]::NUMBER as ID,\n"
-         elif source_table[fkey].dtype.is_float():
-             sql += " f.value[0]::FLOAT as ID,\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][fkey].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+         elif self.table_dtype_dict[table_name][fkey].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
          else:
-             sql += " f.value[0]::VARCHAR as ID,\n"
-         sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
          if min_offset is not None:
-             sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
          sql += (f"\n"
                  f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                  f")\n"
-                 f"SELECT TMP.BATCH as __BATCH__, "
-                 f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
                  f"FROM TMP\n"
-                 f"JOIN {self.fqn_dict[table_name]} FACT\n"
-                 f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
-                 f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
+                 f"JOIN {self.source_name_dict[table_name]} FACT\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                 f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
          if min_offset is not None:
-             sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
+             sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"

          with paramstyle(self._connection), self._connection.cursor() as cursor:
              cursor.execute(sql, (payload, ))
              table = cursor.fetch_arrow_all()

-         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
-
-         return self._sanitize(table_name, table), batch
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)

-     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-         return table.to_pandas(types_mapper=pd.ArrowDtype)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
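
A note on the query pattern used throughout the key-lookup helpers above: the sampler binds one JSON payload, explodes it server-side with FLATTEN(INPUT => PARSE_JSON(?)), and joins the resulting __KUMO_ID__/__KUMO_BATCH__ pairs back against the source table. The comment removed in this release records the motivation: a JSON payload benchmarked faster than large IN clauses or temporary tables for this kind of keyed lookup. The following is a minimal sketch (not part of the package) of the SQL this produces, assuming a hypothetical USERS table with an integer primary key USER_ID and one selected column AGE:

    # Minimal sketch, names are illustrative only.
    import json

    payload = json.dumps([42, 7, 42])  # requested primary keys (duplicates allowed)

    sql = (
        "WITH TMP as (\n"
        " SELECT\n"
        " f.index as __KUMO_BATCH__,\n"
        " f.value::NUMBER as __KUMO_ID__\n"
        " FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
        ")\n"
        'SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, ENT."AGE"\n'
        "FROM TMP\n"
        "JOIN USERS ENT\n"
        ' ON ENT."USER_ID" = TMP.__KUMO_ID__'
    )

    # Under the qmark paramstyle installed by the `paramstyle` context manager,
    # the payload binds to the single `?` placeholder; __KUMO_BATCH__ maps each
    # returned row back to its position in the request:
    # cursor.execute(sql, (payload, ))
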
kumoai/experimental/rfm/backend/snow/table.py
@@ -1,4 +1,5 @@
  import re
+ from collections import Counter
  from collections.abc import Sequence
  from typing import cast

@@ -8,28 +9,27 @@ from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.snow import Connection
  from kumoai.experimental.rfm.base import (
-     ColumnExpressionSpec,
-     ColumnExpressionType,
+     ColumnSpec,
+     ColumnSpecType,
      DataBackend,
      SourceColumn,
      SourceForeignKey,
-     SQLTable,
+     Table,
  )
  from kumoai.utils import quote_ident


- class SnowTable(SQLTable):
+ class SnowTable(Table):
      r"""A table backed by a :class:`sqlite` database.

      Args:
          connection: The connection to a :class:`snowflake` database.
-         name: The logical name of this table.
-         source_name: The physical name of this table in the database. If set to
-             ``None``, ``name`` is being used.
+         name: The name of this table.
+         source_name: The source name of this table. If set to ``None``,
+             ``name`` is being used.
          database: The database.
          schema: The schema.
-         columns: The selected physical columns of this table.
-         column_expressions: The logical columns of this table.
+         columns: The selected columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
          end_time_column: The name of the end time column of this table, if it
@@ -42,14 +42,21 @@ class SnowTable(SQLTable):
          source_name: str | None = None,
          database: str | None = None,
          schema: str | None = None,
-         columns: Sequence[str] | None = None,
-         column_expressions: Sequence[ColumnExpressionType] | None = None,
+         columns: Sequence[ColumnSpecType] | None = None,
          primary_key: MissingType | str | None = MissingType.VALUE,
          time_column: str | None = None,
          end_time_column: str | None = None,
      ) -> None:

-         if database is not None and schema is None:
+         if database is None or schema is None:
+             with connection.cursor() as cursor:
+                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                 result = cursor.fetchone()
+             database = database or result[0]
+             assert database is not None
+             schema = schema or result[1]
+
+         if schema is None:
              raise ValueError(f"Unspecified 'schema' for table "
                               f"'{source_name or name}' in database "
                               f"'{database}'")
@@ -62,37 +69,22 @@ class SnowTable(SQLTable):
              name=name,
              source_name=source_name,
              columns=columns,
-             column_expressions=column_expressions,
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
          )

-     @staticmethod
-     def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
-         if snowflake_dtype is None:
-             return None
-         snowflake_dtype = snowflake_dtype.strip().upper()
-         # TODO 'NUMBER(...)' is not always an integer!
-         if snowflake_dtype.startswith('NUMBER'):
-             return Dtype.int
-         elif snowflake_dtype.startswith('VARCHAR'):
-             return Dtype.string
-         elif snowflake_dtype == 'FLOAT':
-             return Dtype.float
-         elif snowflake_dtype == 'BOOLEAN':
-             return Dtype.bool
-         elif re.search('DATE|TIMESTAMP', snowflake_dtype):
-             return Dtype.date
-         return None
-
      @property
-     def backend(self) -> DataBackend:
-         return cast(DataBackend, DataBackend.SNOWFLAKE)
+     def source_name(self) -> str:
+         names: list[str] = []
+         if self._database is not None:
+             names.append(self._database)
+         if self._schema is not None:
+             names.append(self._schema)
+         return '.'.join(names + [self._source_name])

      @property
-     def fqn(self) -> str:
-         r"""The fully-qualified quoted table name."""
+     def _quoted_source_name(self) -> str:
          names: list[str] = []
          if self._database is not None:
              names.append(quote_ident(self._database))
@@ -100,32 +92,26 @@ class SnowTable(SQLTable):
              names.append(quote_ident(self._schema))
          return '.'.join(names + [quote_ident(self._source_name)])

+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SNOWFLAKE)
+
      def _get_source_columns(self) -> list[SourceColumn]:
          source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
              try:
-                 sql = f"DESCRIBE TABLE {self.fqn}"
+                 sql = f"DESCRIBE TABLE {self._quoted_source_name}"
                  cursor.execute(sql)
              except Exception as e:
-                 names: list[str] = []
-                 if self._database is not None:
-                     names.append(self._database)
-                 if self._schema is not None:
-                     names.append(self._schema)
-                 source_name = '.'.join(names + [self._source_name])
-                 raise ValueError(f"Table '{source_name}' does not exist in "
-                                  f"the remote data backend") from e
+                 raise ValueError(f"Table '{self.source_name}' does not exist "
+                                  f"in the remote data backend") from e

              for row in cursor.fetchall():
-                 column, type, _, null, _, is_pkey, is_unique, *_ = row
-
-                 dtype = self.to_dtype(type)
-                 if dtype is None:
-                     continue
+                 column, dtype, _, null, _, is_pkey, is_unique, *_ = row

                  source_column = SourceColumn(
                      name=column,
-                     dtype=dtype,
+                     dtype=self._to_dtype(dtype),
                      is_primary_key=is_pkey.strip().upper() == 'Y',
                      is_unique_key=is_unique.strip().upper() == 'Y',
                      is_nullable=null.strip().upper() == 'Y',
@@ -135,35 +121,122 @@ class SnowTable(SQLTable):
          return source_columns

      def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-         source_fkeys: list[SourceForeignKey] = []
+         source_foreign_keys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+             sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
              cursor.execute(sql)
-             for row in cursor.fetchall():
-                 _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
-                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-         return source_fkeys
+             rows = cursor.fetchall()
+             counts = Counter(row[13] for row in rows)
+             for row in rows:
+                 if counts[row[13]] == 1:
+                     source_foreign_key = SourceForeignKey(
+                         name=row[8],
+                         dst_table=f'{row[1]}.{row[2]}.{row[3]}',
+                         primary_key=row[4],
+                     )
+                     source_foreign_keys.append(source_foreign_key)
+         return source_foreign_keys

      def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
              columns = [quote_ident(col) for col in self._source_column_dict]
-             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             sql = (f"SELECT {', '.join(columns)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+             if table is None:
+                 raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+             return self._sanitize(
+                 df=table.to_pandas(types_mapper=pd.ArrowDtype),
+                 dtype_dict={
+                     column.name: column.dtype
+                     for column in self._source_column_dict.values()
+                 },
+                 stype_dict=None,
+             )

      def _get_num_rows(self) -> int | None:
          return None

-     def _get_expression_sample_df(
+     def _get_expr_sample_df(
          self,
-         specs: Sequence[ColumnExpressionSpec],
+         columns: Sequence[ColumnSpec],
      ) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             columns = [
-                 f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+             projections = [
+                 f"{column.expr} AS {quote_ident(column.name)}"
+                 for column in columns
              ]
-             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             sql = (f"SELECT {', '.join(projections)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+             if table is None:
+                 raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+             return self._sanitize(
+                 df=table.to_pandas(types_mapper=pd.ArrowDtype),
+                 dtype_dict={column.name: column.dtype
+                             for column in columns},
+                 stype_dict=None,
+             )
+
+     @staticmethod
+     def _to_dtype(dtype: str | None) -> Dtype | None:
+         if dtype is None:
+             return None
+         dtype = dtype.strip().upper()
+         if dtype.startswith('NUMBER'):
+             try:  # Parse `scale` from 'NUMBER(precision, scale)':
+                 scale = int(dtype.split(',')[-1].split(')')[0])
+                 return Dtype.int if scale == 0 else Dtype.float
+             except Exception:
+                 return Dtype.float
+         if dtype == 'FLOAT':
+             return Dtype.float
+         if dtype.startswith('VARCHAR'):
+             return Dtype.string
+         if dtype.startswith('BINARY'):
+             return Dtype.binary
+         if dtype == 'BOOLEAN':
+             return Dtype.bool
+         if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
+             return Dtype.date
+         if dtype.startswith('TIME'):
+             return Dtype.time
+         if dtype.startswith('VECTOR'):
+             try:  # Parse element data type from 'VECTOR(dtype, dimension)':
+                 dtype = dtype.split(',')[0].split('(')[1].strip()
+                 if dtype == 'INT':
+                     return Dtype.intlist
+                 elif dtype == 'FLOAT':
+                     return Dtype.floatlist
+             except Exception:
+                 pass
+             return Dtype.unsupported
+         if dtype.startswith('ARRAY'):
+             try:  # Parse element data type from 'ARRAY(dtype)':
+                 dtype = dtype.split('(', maxsplit=1)[1]
+                 dtype = dtype.rsplit(')', maxsplit=1)[0]
+                 _dtype = SnowTable._to_dtype(dtype)
+                 if _dtype is not None and _dtype.is_int():
+                     return Dtype.intlist
+                 elif _dtype is not None and _dtype.is_float():
+                     return Dtype.floatlist
+                 elif _dtype is not None and _dtype.is_string():
+                     return Dtype.stringlist
+             except Exception:
+                 pass
+             return Dtype.unsupported
+         # Unsupported data types:
+         if re.search(
+                 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
+                 dtype,
+         ):
+             return Dtype.unsupported
+         return None
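
The NUMBER branch above replaces the old blanket Dtype.int mapping (the removed to_dtype carried a TODO noting that NUMBER(...) is not always an integer). A standalone sketch of the scale parsing it performs, with hypothetical inputs for illustration; number_scale is not a package API:

    # Minimal sketch: NUMBER(precision, scale) columns are treated as integers
    # only when the scale is 0; otherwise they map to a float dtype, and a
    # missing or unparsable scale falls back to float as well.
    def number_scale(snowflake_type: str) -> int | None:
        snowflake_type = snowflake_type.strip().upper()
        if not snowflake_type.startswith('NUMBER'):
            return None
        try:  # Parse `scale` from 'NUMBER(precision, scale)':
            return int(snowflake_type.split(',')[-1].split(')')[0])
        except Exception:
            return None


    assert number_scale('NUMBER(38,0)') == 0  # integer-like -> Dtype.int
    assert number_scale('NUMBER(10,2)') == 2  # decimal -> Dtype.float
    assert number_scale('NUMBER') is None     # no scale -> Dtype.float fallback
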