kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/experimental/rfm/__init__.py +22 -22
  6. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  7. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  8. kumoai/experimental/rfm/backend/local/table.py +25 -24
  9. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  10. kumoai/experimental/rfm/backend/snow/table.py +146 -51
  11. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  12. kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
  13. kumoai/experimental/rfm/base/__init__.py +6 -7
  14. kumoai/experimental/rfm/base/column.py +97 -5
  15. kumoai/experimental/rfm/base/expression.py +44 -0
  16. kumoai/experimental/rfm/base/sampler.py +5 -17
  17. kumoai/experimental/rfm/base/source.py +1 -1
  18. kumoai/experimental/rfm/base/sql_sampler.py +68 -9
  19. kumoai/experimental/rfm/base/table.py +284 -120
  20. kumoai/experimental/rfm/graph.py +139 -86
  21. kumoai/experimental/rfm/infer/__init__.py +6 -4
  22. kumoai/experimental/rfm/infer/dtype.py +6 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/infer/stype.py +35 -0
  25. kumoai/experimental/rfm/relbench.py +76 -0
  26. kumoai/experimental/rfm/rfm.py +4 -20
  27. kumoai/trainer/distilled_trainer.py +175 -0
  28. kumoai/utils/display.py +51 -0
  29. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +1 -1
  30. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +33 -30
  31. kumoai/experimental/rfm/base/column_expression.py +0 -16
  32. kumoai/experimental/rfm/base/sql_table.py +0 -113
  33. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  34. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  35. {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0

kumoai/experimental/rfm/backend/snow/sampler.py
@@ -1,16 +1,20 @@
  import json
  from collections.abc import Iterator
  from contextlib import contextmanager
+ from typing import TYPE_CHECKING
 
  import numpy as np
  import pandas as pd
  import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery
 
- from kumoai.experimental.rfm.backend.snow import Connection
- from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+ from kumoai.experimental.rfm.base import SQLSampler, Table
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
- from kumoai.utils import quote_ident
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph
 
 
  @contextmanager
@@ -22,18 +26,30 @@ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
 
 
  class SnowSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         for table in graph.tables.values():
+             assert isinstance(table, SnowTable)
+             self._connection = table._connection
+
      def _get_min_max_time_dict(
          self,
          table_names: list[str],
      ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
          selects: list[str] = []
          for table_name in table_names:
-             time_column = self.time_column_dict[table_name]
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
              select = (f"SELECT\n"
                        f" ? as table_name,\n"
-                       f" MIN({quote_ident(time_column)}) as min_date,\n"
-                       f" MAX({quote_ident(time_column)}) as max_date\n"
-                       f"FROM {self.fqn_dict[table_name]}")
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
              selects.append(select)
          sql = "\nUNION ALL\n".join(selects)
 
@@ -59,17 +75,27 @@ class SnowSampler(SQLSampler):
          # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
          num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
 
+         source_table = self.source_table_dict[table_name]
          filters: list[str] = []
-         primary_key = self.primary_key_dict[table_name]
-         if self.source_table_dict[table_name][primary_key].is_nullable:
-             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-         time_column = self.time_column_dict.get(table_name)
-         if (time_column is not None and
-                 self.source_table_dict[table_name][time_column].is_nullable):
-             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
-
-         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-                f"FROM {self.fqn_dict[table_name]}\n"
+
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")
+
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}\n"
                 f"SAMPLE ROW ({num_rows} ROWS)")
          if len(filters) > 0:
              sql += f"\nWHERE{' AND'.join(filters)}"
@@ -79,7 +105,11 @@ class SnowSampler(SQLSampler):
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()
 
-         return self._sanitize(table_name, table)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )
 
      def _sample_target(
          self,
@@ -152,42 +182,51 @@ class SnowSampler(SQLSampler):
          pkey: pd.Series,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
-
-         pkey_name = self.primary_key_dict[table_name]
-         source_table = self.source_table_dict[table_name]
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
 
          payload = json.dumps(list(pkey))
 
          sql = ("WITH TMP as (\n"
                 " SELECT\n"
-                " f.index as BATCH,\n")
-         if source_table[pkey_name].dtype.is_int():
-             sql += " f.value::NUMBER as ID\n"
-         elif source_table[pkey_name].dtype.is_float():
-             sql += " f.value::FLOAT as ID\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][key].is_int():
+             sql += " f.value::NUMBER as __KUMO_ID__\n"
+         elif self.table_dtype_dict[table_name][key].is_float():
+             sql += " f.value::FLOAT as __KUMO_ID__\n"
          else:
-             sql += " f.value::VARCHAR as ID\n"
+             sql += " f.value::VARCHAR as __KUMO_ID__\n"
          sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                  f")\n"
-                 f"SELECT TMP.BATCH as __BATCH__, "
-                 f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
                  f"FROM TMP\n"
-                 f"JOIN {self.fqn_dict[table_name]} ENT\n"
-                 f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+                 f"JOIN {self.source_name_dict[table_name]} ENT\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__")
 
          with paramstyle(self._connection), self._connection.cursor() as cursor:
              cursor.execute(sql, (payload, ))
              table = cursor.fetch_arrow_all()
 
          # Remove any duplicated primary keys in post-processing:
-         tmp = table.append_column('__TMP__', pa.array(range(len(table))))
-         gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
-         table = table.take(gb['__TMP___min'])
+         tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+         gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+         table = table.take(gb['__KUMO_ID___min'])
 
-         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
 
-         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
 
      # Helper Methods ##########################################################
 
@@ -201,6 +240,7 @@ class SnowSampler(SQLSampler):
          max_offset: pd.DateOffset,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]
 
          end_time = anchor_time + max_offset
          end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
@@ -211,42 +251,47 @@ class SnowSampler(SQLSampler):
          else:
              payload = json.dumps(list(zip(pkey, end_time)))
 
-         # Based on benchmarking, JSON payload is the fastest way to query by
-         # custom indices (compared to large `IN` clauses or temporary tables):
-         source_table = self.source_table_dict[table_name]
-         time_column = self.time_column_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][fkey]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
          sql = ("WITH TMP as (\n"
                 " SELECT\n"
-                " f.index as BATCH,\n")
-         if source_table[fkey].dtype.is_int():
-             sql += " f.value[0]::NUMBER as ID,\n"
-         elif source_table[fkey].dtype.is_float():
-             sql += " f.value[0]::FLOAT as ID,\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][fkey].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+         elif self.table_dtype_dict[table_name][fkey].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
          else:
-             sql += " f.value[0]::VARCHAR as ID,\n"
-         sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
          if min_offset is not None:
-             sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
          sql += (f"\n"
                  f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                  f")\n"
-                 f"SELECT TMP.BATCH as __BATCH__, "
-                 f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
                  f"FROM TMP\n"
-                 f"JOIN {self.fqn_dict[table_name]} FACT\n"
-                 f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
-                 f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
+                 f"JOIN {self.source_name_dict[table_name]} FACT\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                 f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
          if min_offset is not None:
-             sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
+             sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"
 
          with paramstyle(self._connection), self._connection.cursor() as cursor:
              cursor.execute(sql, (payload, ))
              table = cursor.fetch_arrow_all()
 
-         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
-
-         return self._sanitize(table_name, table), batch
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
 
-     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-         return table.to_pandas(types_mapper=pd.ArrowDtype)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
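
For reference, the payload bound to `FLATTEN(INPUT => PARSE_JSON(?))` in the fact lookup above is a JSON list of `[foreign_key, end_time]` entries, extended to `[foreign_key, end_time, start_time]` when `min_offset` is given; `f.index` then recovers the batch position of each entry. A small sketch of the payload construction with made-up keys and a seven-day window:

    import json

    import pandas as pd

    pkey = pd.Series(['u1', 'u2'])  # hypothetical foreign keys
    end_time = pd.Series(pd.to_datetime(['2024-02-01', '2024-02-08']))
    start_time = (end_time - pd.DateOffset(days=7)).dt.strftime("%Y-%m-%d %H:%M:%S")
    end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")

    payload = json.dumps(list(zip(pkey, end_time, start_time)))
    # [["u1", "2024-02-01 00:00:00", "2024-01-25 00:00:00"],
    #  ["u2", "2024-02-08 00:00:00", "2024-02-01 00:00:00"]]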

kumoai/experimental/rfm/backend/snow/table.py
@@ -1,4 +1,5 @@
  import re
+ from collections import Counter
  from collections.abc import Sequence
  from typing import cast
 
@@ -8,27 +9,27 @@ from kumoapi.typing import Dtype
 
  from kumoai.experimental.rfm.backend.snow import Connection
  from kumoai.experimental.rfm.base import (
-     ColumnExpressionType,
+     ColumnSpec,
+     ColumnSpecType,
      DataBackend,
      SourceColumn,
      SourceForeignKey,
-     SQLTable,
+     Table,
  )
  from kumoai.utils import quote_ident
 
 
- class SnowTable(SQLTable):
+ class SnowTable(Table):
      r"""A table backed by a :class:`sqlite` database.
 
      Args:
          connection: The connection to a :class:`snowflake` database.
-         name: The logical name of this table.
-         source_name: The physical name of this table in the database. If set to
-             ``None``, ``name`` is being used.
+         name: The name of this table.
+         source_name: The source name of this table. If set to ``None``,
+             ``name`` is being used.
          database: The database.
          schema: The schema.
-         columns: The selected physical columns of this table.
-         column_expressions: The logical columns of this table.
+         columns: The selected columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
          end_time_column: The name of the end time column of this table, if it
@@ -41,14 +42,21 @@ class SnowTable(SQLTable):
          source_name: str | None = None,
          database: str | None = None,
          schema: str | None = None,
-         columns: Sequence[str] | None = None,
-         column_expressions: Sequence[ColumnExpressionType] | None = None,
+         columns: Sequence[ColumnSpecType] | None = None,
          primary_key: MissingType | str | None = MissingType.VALUE,
          time_column: str | None = None,
          end_time_column: str | None = None,
      ) -> None:
 
-         if database is not None and schema is None:
+         if database is None or schema is None:
+             with connection.cursor() as cursor:
+                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                 result = cursor.fetchone()
+                 database = database or result[0]
+                 assert database is not None
+                 schema = schema or result[1]
+
+         if schema is None:
              raise ValueError(f"Unspecified 'schema' for table "
                               f"'{source_name or name}' in database "
                               f"'{database}'")
@@ -61,19 +69,22 @@ class SnowTable(SQLTable):
              name=name,
              source_name=source_name,
              columns=columns,
-             column_expressions=column_expressions,
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
          )
 
      @property
-     def backend(self) -> DataBackend:
-         return cast(DataBackend, DataBackend.SNOWFLAKE)
+     def source_name(self) -> str:
+         names: list[str] = []
+         if self._database is not None:
+             names.append(self._database)
+         if self._schema is not None:
+             names.append(self._schema)
+         return '.'.join(names + [self._source_name])
 
      @property
-     def fqn(self) -> str:
-         r"""The fully-qualified quoted table name."""
+     def _quoted_source_name(self) -> str:
          names: list[str] = []
          if self._database is not None:
              names.append(quote_ident(self._database))
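
The two properties differ only in quoting: `source_name` joins the raw parts for display and error messages, while `_quoted_source_name` quotes each part via `quote_ident` for safe interpolation into SQL. A stand-in sketch (assuming `quote_ident` double-quotes identifiers, the usual Snowflake convention; the helper below is a simplified substitute for `kumoai.utils.quote_ident`):

    def quote_ident(name: str) -> str:  # simplified stand-in, not the real helper
        return '"' + name.replace('"', '""') + '"'

    parts = ['ANALYTICS', 'PUBLIC', 'Users']  # hypothetical database, schema, table
    print('.'.join(parts))                          # ANALYTICS.PUBLIC.Users (source_name)
    print('.'.join(quote_ident(p) for p in parts))  # "ANALYTICS"."PUBLIC"."Users" (_quoted_source_name)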
@@ -81,42 +92,26 @@ class SnowTable(SQLTable):
              names.append(quote_ident(self._schema))
          return '.'.join(names + [quote_ident(self._source_name)])
 
+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SNOWFLAKE)
+
      def _get_source_columns(self) -> list[SourceColumn]:
          source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
              try:
-                 sql = f"DESCRIBE TABLE {self.fqn}"
+                 sql = f"DESCRIBE TABLE {self._quoted_source_name}"
                  cursor.execute(sql)
              except Exception as e:
-                 names: list[str] = []
-                 if self._database is not None:
-                     names.append(self._database)
-                 if self._schema is not None:
-                     names.append(self._schema)
-                 source_name = '.'.join(names + [self._source_name])
-                 raise ValueError(f"Table '{source_name}' does not exist in "
-                                  f"the remote data backend") from e
+                 raise ValueError(f"Table '{self.source_name}' does not exist "
+                                  f"in the remote data backend") from e
 
              for row in cursor.fetchall():
-                 column, type, _, null, _, is_pkey, is_unique, *_ = row
-
-                 type = type.strip().upper()
-                 if type.startswith('NUMBER'):
-                     dtype = Dtype.int
-                 elif type.startswith('VARCHAR'):
-                     dtype = Dtype.string
-                 elif type == 'FLOAT':
-                     dtype = Dtype.float
-                 elif type == 'BOOLEAN':
-                     dtype = Dtype.bool
-                 elif re.search('DATE|TIMESTAMP', type):
-                     dtype = Dtype.date
-                 else:
-                     continue
+                 column, dtype, _, null, _, is_pkey, is_unique, *_ = row
 
                  source_column = SourceColumn(
                      name=column,
-                     dtype=dtype,
+                     dtype=self._to_dtype(dtype),
                      is_primary_key=is_pkey.strip().upper() == 'Y',
                      is_unique_key=is_unique.strip().upper() == 'Y',
                      is_nullable=null.strip().upper() == 'Y',
@@ -126,22 +121,122 @@ class SnowTable(SQLTable):
          return source_columns
 
      def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-         source_fkeys: list[SourceForeignKey] = []
+         source_foreign_keys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+             sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
              cursor.execute(sql)
-             for row in cursor.fetchall():
-                 _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
-                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-         return source_fkeys
-
-     def _get_sample_df(self) -> pd.DataFrame:
+             rows = cursor.fetchall()
+             counts = Counter(row[13] for row in rows)
+             for row in rows:
+                 if counts[row[13]] == 1:
+                     source_foreign_key = SourceForeignKey(
+                         name=row[8],
+                         dst_table=f'{row[1]}.{row[2]}.{row[3]}',
+                         primary_key=row[4],
+                     )
+                     source_foreign_keys.append(source_foreign_key)
+         return source_foreign_keys
+
+     def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
              columns = [quote_ident(col) for col in self._source_column_dict]
-             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             sql = (f"SELECT {', '.join(columns)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         if table is None:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={
+                 column.name: column.dtype
+                 for column in self._source_column_dict.values()
+             },
+             stype_dict=None,
+         )
 
      def _get_num_rows(self) -> int | None:
          return None
+
+     def _get_expr_sample_df(
+         self,
+         columns: Sequence[ColumnSpec],
+     ) -> pd.DataFrame:
+         with self._connection.cursor() as cursor:
+             projections = [
+                 f"{column.expr} AS {quote_ident(column.name)}"
+                 for column in columns
+             ]
+             sql = (f"SELECT {', '.join(projections)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_all()
+
+         if table is None:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={column.name: column.dtype
+                         for column in columns},
+             stype_dict=None,
+         )
+
+     @staticmethod
+     def _to_dtype(dtype: str | None) -> Dtype | None:
+         if dtype is None:
+             return None
+         dtype = dtype.strip().upper()
+         if dtype.startswith('NUMBER'):
+             try:  # Parse `scale` from 'NUMBER(precision, scale)':
+                 scale = int(dtype.split(',')[-1].split(')')[0])
+                 return Dtype.int if scale == 0 else Dtype.float
+             except Exception:
+                 return Dtype.float
+         if dtype == 'FLOAT':
+             return Dtype.float
+         if dtype.startswith('VARCHAR'):
+             return Dtype.string
+         if dtype.startswith('BINARY'):
+             return Dtype.binary
+         if dtype == 'BOOLEAN':
+             return Dtype.bool
+         if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
+             return Dtype.date
+         if dtype.startswith('TIME'):
+             return Dtype.time
+         if dtype.startswith('VECTOR'):
+             try:  # Parse element data type from 'VECTOR(dtype, dimension)':
+                 dtype = dtype.split(',')[0].split('(')[1].strip()
+                 if dtype == 'INT':
+                     return Dtype.intlist
+                 elif dtype == 'FLOAT':
+                     return Dtype.floatlist
+             except Exception:
+                 pass
+             return Dtype.unsupported
+         if dtype.startswith('ARRAY'):
+             try:  # Parse element data type from 'ARRAY(dtype)':
+                 dtype = dtype.split('(', maxsplit=1)[1]
+                 dtype = dtype.rsplit(')', maxsplit=1)[0]
+                 _dtype = SnowTable._to_dtype(dtype)
+                 if _dtype is not None and _dtype.is_int():
+                     return Dtype.intlist
+                 elif _dtype is not None and _dtype.is_float():
+                     return Dtype.floatlist
+                 elif _dtype is not None and _dtype.is_string():
+                     return Dtype.stringlist
+             except Exception:
+                 pass
+             return Dtype.unsupported
+         # Unsupported data types:
+         if re.search(
+                 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
+                 dtype,
+         ):
+             return Dtype.unsupported
+         return None
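
A few spot checks of the `_to_dtype` mapping, derived from the branches above (assuming the `SnowTable` import path used by the sampler module; the method is private, so this is for illustration only):

    from kumoai.experimental.rfm.backend.snow import SnowTable

    SnowTable._to_dtype('NUMBER(38,0)')        # Dtype.int (scale 0)
    SnowTable._to_dtype('NUMBER(10,2)')        # Dtype.float (non-zero scale)
    SnowTable._to_dtype('VARCHAR(16777216)')   # Dtype.string
    SnowTable._to_dtype('TIMESTAMP_NTZ(9)')    # Dtype.date
    SnowTable._to_dtype('VECTOR(FLOAT, 768)')  # Dtype.floatlist
    SnowTable._to_dtype('ARRAY(VARCHAR)')      # Dtype.stringlist
    SnowTable._to_dtype('GEOGRAPHY')           # Dtype.unsupported
    SnowTable._to_dtype('UNKNOWN_TYPE')        # None (unrecognized type)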