kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/__init__.py +33 -8
  4. kumoai/experimental/rfm/authenticate.py +3 -4
  5. kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
  6. kumoai/experimental/rfm/backend/local/sampler.py +213 -14
  7. kumoai/experimental/rfm/backend/local/table.py +21 -16
  8. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  9. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  10. kumoai/experimental/rfm/backend/snow/table.py +101 -49
  11. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  14. kumoai/experimental/rfm/base/__init__.py +25 -6
  15. kumoai/experimental/rfm/base/column.py +14 -12
  16. kumoai/experimental/rfm/base/column_expression.py +50 -0
  17. kumoai/experimental/rfm/base/sampler.py +438 -38
  18. kumoai/experimental/rfm/base/source.py +1 -0
  19. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  20. kumoai/experimental/rfm/base/sql_table.py +229 -0
  21. kumoai/experimental/rfm/base/table.py +165 -135
  22. kumoai/experimental/rfm/graph.py +266 -102
  23. kumoai/experimental/rfm/infer/__init__.py +6 -4
  24. kumoai/experimental/rfm/infer/dtype.py +3 -3
  25. kumoai/experimental/rfm/infer/pkey.py +4 -2
  26. kumoai/experimental/rfm/infer/stype.py +35 -0
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
  38. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +41 -35
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/local/table.py

@@ -1,9 +1,10 @@
  import warnings
- from typing import List, Optional
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType

- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+ from kumoai.experimental.rfm.base import DataBackend, SourceColumn, Table
  from kumoai.experimental.rfm.infer import infer_dtype


@@ -52,9 +53,9 @@ class LocalTable(Table):
          self,
          df: pd.DataFrame,
          name: str,
-         primary_key: Optional[str] = None,
-         time_column: Optional[str] = None,
-         end_time_column: Optional[str] = None,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
      ) -> None:

          if df.empty:
@@ -76,17 +77,23 @@ class LocalTable(Table):
              end_time_column=end_time_column,
          )

-     def _get_source_columns(self) -> List[SourceColumn]:
-         source_columns: List[SourceColumn] = []
+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.LOCAL)
+
+     def _get_source_columns(self) -> list[SourceColumn]:
+         source_columns: list[SourceColumn] = []
          for column in self._data.columns:
              ser = self._data[column]
              try:
                  dtype = infer_dtype(ser)
              except Exception:
-                 warnings.warn(f"Data type inference for column '{column}' in "
-                               f"table '{self.name}' failed. Consider changing "
-                               f"the data type of the column to use it within "
-                               f"this table.")
+                 warnings.warn(f"Encountered unsupported data type "
+                               f"'{ser.dtype}' for column '{column}' in table "
+                               f"'{self.name}'. Please change the data type of "
+                               f"the column in the `pandas.DataFrame` to use "
+                               f"it within this table, or remove it to "
+                               f"suppress this warning.")
                  continue

              source_column = SourceColumn(
@@ -94,16 +101,14 @@ class LocalTable(Table):
                  dtype=dtype,
                  is_primary_key=False,
                  is_unique_key=False,
+                 is_nullable=True,
              )
              source_columns.append(source_column)

          return source_columns

-     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-         return []
-
-     def _get_sample_df(self) -> pd.DataFrame:
+     def _get_source_sample_df(self) -> pd.DataFrame:
          return self._data

-     def _get_num_rows(self) -> Optional[int]:
+     def _get_num_rows(self) -> int | None:
          return len(self._data)
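For context, the switch from `primary_key: Optional[str] = None` to `primary_key: MissingType | str | None = MissingType.VALUE` (here and in `SnowTable` below) replaces `None` as the default with an explicit sentinel. A minimal sketch of that pattern follows, assuming the intent is to distinguish "argument omitted" from an explicit `primary_key=None`; it is illustrative only and not the `kumoapi.model_plan` implementation:

    import enum


    class MissingType(enum.Enum):
        VALUE = enum.auto()  # sentinel meaning "no value was passed"


    def resolve_primary_key(primary_key: "MissingType | str | None",
                            inferred: str | None) -> str | None:
        # Hypothetical helper: fall back to inference only when the caller
        # omitted the argument entirely; an explicit `None` disables the key.
        if primary_key is MissingType.VALUE:
            return inferred
        return primary_key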
kumoai/experimental/rfm/backend/snow/__init__.py

@@ -27,9 +27,11 @@ def connect(**kwargs: Any) -> Connection:


  from .table import SnowTable  # noqa: E402
+ from .sampler import SnowSampler  # noqa: E402

  __all__ = [
      'connect',
      'Connection',
      'SnowTable',
+     'SnowSampler',
  ]
kumoai/experimental/rfm/backend/snow/sampler.py (new file)

@@ -0,0 +1,252 @@
+ import json
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow as pa
+ from kumoapi.pquery import ValidatedPredictiveQuery
+
+ from kumoai.experimental.rfm.backend.snow import Connection
+ from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+ from kumoai.utils import quote_ident
+
+
+ @contextmanager
+ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+     _style = connection._paramstyle
+     connection._paramstyle = style
+     yield
+     connection._paramstyle = _style
+
+
+ class SnowSampler(SQLSampler):
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         selects: list[str] = []
+         for table_name in table_names:
+             time_column = self.time_column_dict[table_name]
+             select = (f"SELECT\n"
+                       f" ? as table_name,\n"
+                       f" MIN({quote_ident(time_column)}) as min_date,\n"
+                       f" MAX({quote_ident(time_column)}) as max_date\n"
+                       f"FROM {self.fqn_dict[table_name]}")
+             selects.append(select)
+         sql = "\nUNION ALL\n".join(selects)
+
+         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, table_names)
+             rows = cursor.fetchall()
+             for table_name, _min, _max in rows:
+                 out_dict[table_name] = (
+                     pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                     pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                 )
+
+         return out_dict
+
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
+         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+         filters: list[str] = []
+         primary_key = self.primary_key_dict[table_name]
+         if self.source_table_dict[table_name][primary_key].is_nullable:
+             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+         time_column = self.time_column_dict.get(table_name)
+         if (time_column is not None and
+                 self.source_table_dict[table_name][time_column].is_nullable):
+             filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+                f"FROM {self.fqn_dict[table_name]}\n"
+                f"SAMPLE ROW ({num_rows} ROWS)")
+         if len(filters) > 0:
+             sql += f"\nWHERE{' AND'.join(filters)}"
+
+         with self._connection.cursor() as cursor:
+             # NOTE This may return duplicate primary keys. This is okay.
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_all()
+
+         return self._sanitize(table_name, table)
+
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+         # NOTE For Snowflake, we execute everything at once to pay minimal
+         # query initialization costs.
+         index = np.concatenate([train_index, test_index])
+         time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+         entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+         feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+         time_dict: dict[str, pd.Series] = {}
+         time_column = self.time_column_dict.get(query.entity_table)
+         if time_column in columns_dict[query.entity_table]:
+             time_dict[query.entity_table] = entity_df[time_column]
+         batch_dict: dict[str, np.ndarray] = {
+             query.entity_table: np.arange(len(entity_df)),
+         }
+         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+             table_name, fkey, _ = edge_type
+             feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                 table_name=table_name,
+                 fkey=fkey,
+                 pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                 anchor_time=time,
+                 min_offset=min_offset,
+                 max_offset=max_offset,
+                 columns=columns_dict[table_name],
+             )
+             time_column = self.time_column_dict.get(table_name)
+             if time_column in columns_dict[table_name]:
+                 time_dict[table_name] = feat_dict[table_name][time_column]
+
+         y, mask = PQueryPandasExecutor().execute(
+             query=query,
+             feat_dict=feat_dict,
+             time_dict=time_dict,
+             batch_dict=batch_dict,
+             anchor_time=time,
+             num_forecasts=query.num_forecasts,
+         )
+
+         train_mask = mask[:len(train_index)]
+         test_mask = mask[len(train_index):]
+
+         boundary = int(train_mask.sum())
+         train_y = y.iloc[:boundary]
+         test_y = y.iloc[boundary:].reset_index(drop=True)
+
+         return train_y, train_mask, test_y, test_mask
+
+     def _by_pkey(
+         self,
+         table_name: str,
+         pkey: pd.Series,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+
+         pkey_name = self.primary_key_dict[table_name]
+         source_table = self.source_table_dict[table_name]
+
+         payload = json.dumps(list(pkey))
+
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as BATCH,\n")
+         if source_table[pkey_name].dtype.is_int():
+             sql += " f.value::NUMBER as ID\n"
+         elif source_table[pkey_name].dtype.is_float():
+             sql += " f.value::FLOAT as ID\n"
+         else:
+             sql += " f.value::VARCHAR as ID\n"
+         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT TMP.BATCH as __BATCH__, "
+                 f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.fqn_dict[table_name]} ENT\n"
+                 f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         # Remove any duplicated primary keys in post-processing:
+         tmp = table.append_column('__TMP__', pa.array(range(len(table))))
+         gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
+         table = table.take(gb['__TMP___min'])
+
+         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+
+         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+
+     # Helper Methods ##########################################################
+
+     def _by_time(
+         self,
+         table_name: str,
+         fkey: str,
+         pkey: pd.Series,
+         anchor_time: pd.Series,
+         min_offset: pd.DateOffset | None,
+         max_offset: pd.DateOffset,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+
+         end_time = anchor_time + max_offset
+         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         if min_offset is not None:
+             start_time = anchor_time + min_offset
+             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             payload = json.dumps(list(zip(pkey, end_time, start_time)))
+         else:
+             payload = json.dumps(list(zip(pkey, end_time)))
+
+         # Based on benchmarking, JSON payload is the fastest way to query by
+         # custom indices (compared to large `IN` clauses or temporary tables):
+         source_table = self.source_table_dict[table_name]
+         time_column = self.time_column_dict[table_name]
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as BATCH,\n")
+         if source_table[fkey].dtype.is_int():
+             sql += " f.value[0]::NUMBER as ID,\n"
+         elif source_table[fkey].dtype.is_float():
+             sql += " f.value[0]::FLOAT as ID,\n"
+         else:
+             sql += " f.value[0]::VARCHAR as ID,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+         if min_offset is not None:
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+         sql += (f"\n"
+                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT TMP.BATCH as __BATCH__, "
+                 f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.fqn_dict[table_name]} FACT\n"
+                 f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
+                 f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
+         if min_offset is not None:
+             sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
+         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+
+         return self._sanitize(table_name, table), batch
+
+     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
+         return table.to_pandas(types_mapper=pd.ArrowDtype)
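For readers unfamiliar with the `FLATTEN(PARSE_JSON(?))` pattern that `_by_pkey` and `_by_time` build above: the batch of lookup keys is bound once as a single JSON string, expanded server-side into rows, and joined against the target table, which the in-code comment reports benchmarked faster than large `IN` clauses or temporary tables. A minimal sketch of the kind of query generated for a primary-key lookup (table and column names are made up):

    import json

    # Hypothetical example: fetch rows for primary keys 10, 42, and 7, keeping
    # track of each key's position in the batch via `f.index`.
    payload = json.dumps([10, 42, 7])
    sql = (
        "WITH TMP as (\n"
        "  SELECT\n"
        "    f.index as BATCH,\n"
        "    f.value::NUMBER as ID\n"
        "  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
        ")\n"
        'SELECT TMP.BATCH as __BATCH__, ENT."USER_ID", ENT."AGE"\n'
        "FROM TMP\n"
        'JOIN "MY_DB"."MY_SCHEMA"."USERS" ENT\n'
        '  ON ENT."USER_ID" = TMP.ID'
    )
    # Executed with the connection temporarily switched to 'qmark' binding via
    # the `paramstyle` context manager above:
    #     cursor.execute(sql, (payload,))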
kumoai/experimental/rfm/backend/snow/table.py

@@ -1,22 +1,35 @@
  import re
- from typing import List, Optional, Sequence
+ from collections.abc import Sequence
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType
  from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.snow import Connection
- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
-
-
- class SnowTable(Table):
+ from kumoai.experimental.rfm.base import (
+     ColumnExpressionSpec,
+     ColumnExpressionType,
+     DataBackend,
+     SourceColumn,
+     SourceForeignKey,
+     SQLTable,
+ )
+ from kumoai.utils import quote_ident
+
+
+ class SnowTable(SQLTable):
      r"""A table backed by a :class:`sqlite` database.

      Args:
          connection: The connection to a :class:`snowflake` database.
-         name: The name of this table.
+         name: The logical name of this table.
+         source_name: The physical name of this table in the database. If set to
+             ``None``, ``name`` is being used.
          database: The database.
          schema: The schema.
-         columns: The selected columns of this table.
+         columns: The selected physical columns of this table.
+         column_expressions: The logical columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
          end_time_column: The name of the end time column of this table, if it
@@ -26,17 +39,20 @@ class SnowTable(Table):
          self,
          connection: Connection,
          name: str,
+         source_name: str | None = None,
          database: str | None = None,
          schema: str | None = None,
-         columns: Optional[Sequence[str]] = None,
-         primary_key: Optional[str] = None,
-         time_column: Optional[str] = None,
-         end_time_column: Optional[str] = None,
+         columns: Sequence[str] | None = None,
+         column_expressions: Sequence[ColumnExpressionType] | None = None,
+         primary_key: MissingType | str | None = MissingType.VALUE,
+         time_column: str | None = None,
+         end_time_column: str | None = None,
      ) -> None:

          if database is not None and schema is None:
-             raise ValueError(f"Missing 'schema' for table '{name}' in "
-                              f"database '{database}'")
+             raise ValueError(f"Unspecified 'schema' for table "
+                              f"'{source_name or name}' in database "
+                              f"'{database}'")

          self._connection = connection
          self._database = database
@@ -44,47 +60,67 @@ class SnowTable(Table):

          super().__init__(
              name=name,
+             source_name=source_name,
              columns=columns,
+             column_expressions=column_expressions,
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
          )

+     @staticmethod
+     def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
+         if snowflake_dtype is None:
+             return None
+         snowflake_dtype = snowflake_dtype.strip().upper()
+         # TODO 'NUMBER(...)' is not always an integer!
+         if snowflake_dtype.startswith('NUMBER'):
+             return Dtype.int
+         elif snowflake_dtype.startswith('VARCHAR'):
+             return Dtype.string
+         elif snowflake_dtype == 'FLOAT':
+             return Dtype.float
+         elif snowflake_dtype == 'BOOLEAN':
+             return Dtype.bool
+         elif re.search('DATE|TIMESTAMP', snowflake_dtype):
+             return Dtype.date
+         return None
+
      @property
-     def fqn_name(self) -> str:
-         names: List[str] = []
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SNOWFLAKE)
+
+     @property
+     def fqn(self) -> str:
+         r"""The fully-qualified quoted table name."""
+         names: list[str] = []
          if self._database is not None:
-             assert self._schema is not None
-             names.extend([self._database, self._schema])
-         elif self._schema is not None:
-             names.append(self._schema)
-         names.append(self._name)
-         return '.'.join(names)
-
-     def _get_source_columns(self) -> List[SourceColumn]:
-         source_columns: List[SourceColumn] = []
+             names.append(quote_ident(self._database))
+         if self._schema is not None:
+             names.append(quote_ident(self._schema))
+         return '.'.join(names + [quote_ident(self._source_name)])
+
+     def _get_source_columns(self) -> list[SourceColumn]:
+         source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
              try:
-                 cursor.execute(f"DESCRIBE TABLE {self.fqn_name}")
+                 sql = f"DESCRIBE TABLE {self.fqn}"
+                 cursor.execute(sql)
              except Exception as e:
-                 raise ValueError(
-                     f"Table '{self.fqn_name}' does not exist") from e
+                 names: list[str] = []
+                 if self._database is not None:
+                     names.append(self._database)
+                 if self._schema is not None:
+                     names.append(self._schema)
+                 source_name = '.'.join(names + [self._source_name])
+                 raise ValueError(f"Table '{source_name}' does not exist in "
+                                  f"the remote data backend") from e

              for row in cursor.fetchall():
-                 column, type, _, _, _, is_pkey, is_unique = row[:7]
-
-                 type = type.strip().upper()
-                 if type.startswith('NUMBER'):
-                     dtype = Dtype.int
-                 elif type.startswith('VARCHAR'):
-                     dtype = Dtype.string
-                 elif type == 'FLOAT':
-                     dtype = Dtype.float
-                 elif type == 'BOOLEAN':
-                     dtype = Dtype.bool
-                 elif re.search('DATE|TIMESTAMP', type):
-                     dtype = Dtype.date
-                 else:
+                 column, type, _, null, _, is_pkey, is_unique, *_ = row
+
+                 dtype = self.to_dtype(type)
+                 if dtype is None:
                      continue

                  source_column = SourceColumn(
@@ -92,26 +128,42 @@ class SnowTable(Table):
                      dtype=dtype,
                      is_primary_key=is_pkey.strip().upper() == 'Y',
                      is_unique_key=is_unique.strip().upper() == 'Y',
+                     is_nullable=null.strip().upper() == 'Y',
                  )
                  source_columns.append(source_column)

          return source_columns

-     def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-         source_fkeys: List[SourceForeignKey] = []
+     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+         source_fkeys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.fqn_name}")
+             sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+             cursor.execute(sql)
              for row in cursor.fetchall():
-                 _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
+                 _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
                  source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
          return source_fkeys

-     def _get_sample_df(self) -> pd.DataFrame:
+     def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             columns = ', '.join(self._source_column_dict.keys())
-             cursor.execute(f"SELECT {columns} FROM {self.fqn_name} LIMIT 1000")
+             columns = [quote_ident(col) for col in self._source_column_dict]
+             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             cursor.execute(sql)
              table = cursor.fetch_arrow_all()
          return table.to_pandas(types_mapper=pd.ArrowDtype)

-     def _get_num_rows(self) -> Optional[int]:
+     def _get_num_rows(self) -> int | None:
          return None
+
+     def _get_expression_sample_df(
+         self,
+         specs: Sequence[ColumnExpressionSpec],
+     ) -> pd.DataFrame:
+         with self._connection.cursor() as cursor:
+             columns = [
+                 f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+             ]
+             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             cursor.execute(sql)
+             table = cursor.fetch_arrow_all()
+         return table.to_pandas(types_mapper=pd.ArrowDtype)
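As a reading aid for the new `SnowTable.to_dtype` helper above, a few example inputs and the `Dtype` they map to under the branches shown (these calls are not part of the package, and the `NUMBER(...)` caveat from the TODO still applies):

    from kumoai.experimental.rfm.backend.snow import SnowTable
    from kumoapi.typing import Dtype

    assert SnowTable.to_dtype('NUMBER(38,0)') == Dtype.int        # TODO caveat: may hold decimals
    assert SnowTable.to_dtype('VARCHAR(16777216)') == Dtype.string
    assert SnowTable.to_dtype('TIMESTAMP_NTZ(9)') == Dtype.date   # matches 'DATE|TIMESTAMP'
    assert SnowTable.to_dtype('BINARY') is None                   # unsupported -> column skipped
    assert SnowTable.to_dtype(None) is None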
kumoai/experimental/rfm/backend/sqlite/__init__.py

@@ -1,5 +1,5 @@
  from pathlib import Path
- from typing import Any, TypeAlias, Union
+ from typing import Any, TypeAlias

  try:
      import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
  Connection: TypeAlias = adbc.AdbcSqliteConnection


- def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
+ def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
      r"""Opens a connection to a :class:`sqlite` database.

          uri: The path to the database file to be opened.
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:


  from .table import SQLiteTable  # noqa: E402
+ from .sampler import SQLiteSampler  # noqa: E402

  __all__ = [
      'connect',
      'Connection',
      'SQLiteTable',
+     'SQLiteSampler',
  ]