kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. kumoai/_version.py +1 -1
  2. kumoai/experimental/rfm/__init__.py +33 -8
  3. kumoai/experimental/rfm/authenticate.py +3 -4
  4. kumoai/experimental/rfm/backend/local/graph_store.py +25 -25
  5. kumoai/experimental/rfm/backend/local/table.py +16 -21
  6. kumoai/experimental/rfm/backend/snow/sampler.py +22 -34
  7. kumoai/experimental/rfm/backend/snow/table.py +67 -33
  8. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  9. kumoai/experimental/rfm/backend/sqlite/sampler.py +21 -26
  10. kumoai/experimental/rfm/backend/sqlite/table.py +54 -26
  11. kumoai/experimental/rfm/base/__init__.py +8 -0
  12. kumoai/experimental/rfm/base/column.py +14 -12
  13. kumoai/experimental/rfm/base/column_expression.py +50 -0
  14. kumoai/experimental/rfm/base/sql_sampler.py +31 -3
  15. kumoai/experimental/rfm/base/sql_table.py +229 -0
  16. kumoai/experimental/rfm/base/table.py +162 -143
  17. kumoai/experimental/rfm/graph.py +242 -95
  18. kumoai/experimental/rfm/infer/__init__.py +6 -4
  19. kumoai/experimental/rfm/infer/dtype.py +3 -3
  20. kumoai/experimental/rfm/infer/pkey.py +4 -2
  21. kumoai/experimental/rfm/infer/stype.py +35 -0
  22. kumoai/experimental/rfm/infer/time_col.py +1 -2
  23. kumoai/experimental/rfm/pquery/executor.py +27 -27
  24. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  25. kumoai/experimental/rfm/rfm.py +86 -80
  26. kumoai/experimental/rfm/sagemaker.py +4 -4
  27. kumoai/utils/__init__.py +1 -2
  28. kumoai/utils/progress_logger.py +178 -12
  29. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +2 -1
  30. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +33 -30
  31. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  32. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  33. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/_version.py CHANGED
@@ -1 +1 @@
- __version__ = '2.14.0.dev202512151351'
+ __version__ = '2.14.0.dev202512211732'
kumoai/experimental/rfm/__init__.py CHANGED
@@ -6,11 +6,11 @@ import socket
  import threading
  from dataclasses import dataclass
  from enum import Enum
- from typing import Dict, Optional, Tuple
  from urllib.parse import urlparse

  import kumoai
  from kumoai.client.client import KumoClient
+ from kumoai.spcs import _get_active_session

  from .authenticate import authenticate
  from .sagemaker import (
@@ -49,7 +49,8 @@ class InferenceBackend(str, Enum):


  def _detect_backend(
- url: str) -> Tuple[InferenceBackend, Optional[str], Optional[str]]:
+ url: str, #
+ ) -> tuple[InferenceBackend, str | None, str | None]:
  parsed = urlparse(url)

  # Remote SageMaker
@@ -73,12 +74,27 @@ def _detect_backend(
  return InferenceBackend.REST, None, None


+ def _get_snowflake_url(snowflake_application: str) -> str:
+ snowpark_session = _get_active_session()
+ if not snowpark_session:
+ raise ValueError(
+ "Client creation failed: snowflake_application is specified "
+ "without an active snowpark session. If running outside "
+ "a snowflake notebook, specify a URL and credentials.")
+ with snowpark_session.connection.cursor() as cur:
+ cur.execute(
+ f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
+ f" ->> SELECT \"dns_name\" from $1")
+ dns_name: str = cur.fetchone()[0]
+ return f"http://{dns_name}:8000/api"
+
+
  @dataclass
  class RfmGlobalState:
  _url: str = '__url_not_provided__'
  _backend: InferenceBackend = InferenceBackend.UNKNOWN
- _region: Optional[str] = None
- _endpoint_name: Optional[str] = None
+ _region: str | None = None
+ _endpoint_name: str | None = None
  _thread_local = threading.local()

  # Thread-safe init-once.
@@ -121,10 +137,10 @@ global_state = RfmGlobalState()


  def init(
- url: Optional[str] = None,
- api_key: Optional[str] = None,
- snowflake_credentials: Optional[Dict[str, str]] = None,
- snowflake_application: Optional[str] = None,
+ url: str | None = None,
+ api_key: str | None = None,
+ snowflake_credentials: dict[str, str] | None = None,
+ snowflake_application: str | None = None,
  log_level: str = "INFO",
  ) -> None:
  with global_state._lock:
@@ -136,6 +152,15 @@ def init(
  "supported.")
  return

+ if snowflake_application:
+ if url is not None:
+ raise ValueError(
+ "Client creation failed: both snowflake_application and "
+ "url are specified. If running from a snowflake notebook, "
+ "specify only snowflake_application.")
+ url = _get_snowflake_url(snowflake_application)
+ api_key = "test:DISABLED"
+
  if url is None:
  url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")

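The new branch above lets a Snowflake-native deployment be initialized from just the application name. A minimal usage sketch under stated assumptions (the application name is hypothetical; the import path is inferred from the file layout):

```python
from kumoai.experimental import rfm

# Inside a Snowflake notebook with an active Snowpark session, only the
# application name is needed: init() resolves the in-account service URL
# via _get_snowflake_url() and sets a placeholder API key internally.
rfm.init(snowflake_application="MY_RFM_APP")  # hypothetical application name

# Outside Snowflake, a URL and API key are still passed explicitly, e.g.
# rfm.init(url="https://kumorfm.ai/api", api_key="<your-api-key>");
# passing both url and snowflake_application raises a ValueError.
```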
kumoai/experimental/rfm/authenticate.py CHANGED
@@ -2,12 +2,11 @@ import logging
  import os
  import platform
  from datetime import datetime
- from typing import Optional

  from kumoai import in_notebook


- def authenticate(api_url: Optional[str] = None) -> None:
+ def authenticate(api_url: str | None = None) -> None:
  """Authenticates the user and sets the Kumo API key for the SDK.

  This function detects the current environment and launches the appropriate
@@ -65,11 +64,11 @@ def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
  import webbrowser
  from getpass import getpass
  from socketserver import TCPServer
- from typing import Any, Dict
+ from typing import Any

  logger = logging.getLogger('kumoai')

- token_status: Dict[str, Any] = {
+ token_status: dict[str, Any] = {
  'token': None,
  'token_name': None,
  'failed': False
kumoai/experimental/rfm/backend/local/graph_store.py CHANGED
@@ -1,5 +1,5 @@
  import warnings
- from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING

  import numpy as np
  import pandas as pd
@@ -7,7 +7,7 @@ from kumoapi.rfm.context import Subgraph
  from kumoapi.typing import Stype

  from kumoai.experimental.rfm.backend.local import LocalTable
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger

  try:
  import torch
@@ -23,12 +23,12 @@ class LocalGraphStore:
  def __init__(
  self,
  graph: 'Graph',
- verbose: Union[bool, ProgressLogger] = True,
+ verbose: bool | ProgressLogger = True,
  ) -> None:

  if not isinstance(verbose, ProgressLogger):
- verbose = InteractiveProgressLogger(
- "Materializing graph",
+ verbose = ProgressLogger.default(
+ msg="Materializing graph",
  verbose=verbose,
  )

@@ -94,7 +94,7 @@ class LocalGraphStore:
  def sanitize(
  self,
  graph: 'Graph',
- ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
+ ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
  r"""Sanitizes raw data according to table schema definition:

  In particular, it:
@@ -103,13 +103,13 @@ class LocalGraphStore:
  * drops duplicate primary keys
  * removes rows with missing primary keys or time values
  """
- df_dict: Dict[str, pd.DataFrame] = {}
+ df_dict: dict[str, pd.DataFrame] = {}
  for table_name, table in graph.tables.items():
  assert isinstance(table, LocalTable)
  df = table._data
  df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)

- mask_dict: Dict[str, np.ndarray] = {}
+ mask_dict: dict[str, np.ndarray] = {}
  for table in graph.tables.values():
  for col in table.columns:
  if col.stype == Stype.timestamp:
@@ -126,7 +126,7 @@ class LocalGraphStore:
  ser = ser.dt.tz_localize(None)
  df_dict[table.name][col.name] = ser

- mask: Optional[np.ndarray] = None
+ mask: np.ndarray | None = None
  if table._time_column is not None:
  ser = df_dict[table.name][table._time_column]
  mask = ser.notna().to_numpy()
@@ -144,8 +144,8 @@ class LocalGraphStore:
  def get_pkey_map_dict(
  self,
  graph: 'Graph',
- ) -> Dict[str, pd.DataFrame]:
- pkey_map_dict: Dict[str, pd.DataFrame] = {}
+ ) -> dict[str, pd.DataFrame]:
+ pkey_map_dict: dict[str, pd.DataFrame] = {}

  for table in graph.tables.values():
  if table._primary_key is None:
@@ -177,12 +177,12 @@ class LocalGraphStore:
  def get_time_data(
  self,
  graph: 'Graph',
- ) -> Tuple[
- Dict[str, np.ndarray],
- Dict[str, Tuple[pd.Timestamp, pd.Timestamp]],
+ ) -> tuple[
+ dict[str, np.ndarray],
+ dict[str, tuple[pd.Timestamp, pd.Timestamp]],
  ]:
- time_dict: Dict[str, np.ndarray] = {}
- min_max_time_dict: Dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+ time_dict: dict[str, np.ndarray] = {}
+ min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
  for table in graph.tables.values():
  if table._time_column is None:
  continue
@@ -207,15 +207,15 @@ class LocalGraphStore:
  def get_csc(
  self,
  graph: 'Graph',
- ) -> Tuple[
- Dict[Tuple[str, str, str], np.ndarray],
- Dict[Tuple[str, str, str], np.ndarray],
+ ) -> tuple[
+ dict[tuple[str, str, str], np.ndarray],
+ dict[tuple[str, str, str], np.ndarray],
  ]:
  # A mapping from raw primary keys to node indices (0 to N-1):
- map_dict: Dict[str, pd.CategoricalDtype] = {}
+ map_dict: dict[str, pd.CategoricalDtype] = {}
  # A dictionary to manage offsets of node indices for invalid rows:
- offset_dict: Dict[str, np.ndarray] = {}
- for table_name in set(edge.dst_table for edge in graph.edges):
+ offset_dict: dict[str, np.ndarray] = {}
+ for table_name in {edge.dst_table for edge in graph.edges}:
  ser = self.df_dict[table_name][graph[table_name]._primary_key]
  if table_name in self.mask_dict.keys():
  mask = self.mask_dict[table_name]
@@ -224,8 +224,8 @@ class LocalGraphStore:
  map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)

  # Build CSC graph representation:
- row_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
- colptr_dict: Dict[Tuple[str, str, str], np.ndarray] = {}
+ row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+ colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
  for src_table, fkey, dst_table in graph.edges:
  src_df = self.df_dict[src_table]
  dst_df = self.df_dict[dst_table]
@@ -287,7 +287,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
  return torch.from_numpy(input).argsort().numpy()


- def _lexsort(inputs: List[np.ndarray]) -> np.ndarray:
+ def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
  assert len(inputs) >= 1

  if not WITH_TORCH:
kumoai/experimental/rfm/backend/local/table.py CHANGED
@@ -1,14 +1,10 @@
  import warnings
- from typing import List, Optional, cast
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType

- from kumoai.experimental.rfm.base import (
- DataBackend,
- SourceColumn,
- SourceForeignKey,
- Table,
- )
+ from kumoai.experimental.rfm.base import DataBackend, SourceColumn, Table
  from kumoai.experimental.rfm.infer import infer_dtype


@@ -57,9 +53,9 @@ class LocalTable(Table):
  self,
  df: pd.DataFrame,
  name: str,
- primary_key: Optional[str] = None,
- time_column: Optional[str] = None,
- end_time_column: Optional[str] = None,
+ primary_key: MissingType | str | None = MissingType.VALUE,
+ time_column: str | None = None,
+ end_time_column: str | None = None,
  ) -> None:

  if df.empty:
@@ -85,17 +81,19 @@ class LocalTable(Table):
  def backend(self) -> DataBackend:
  return cast(DataBackend, DataBackend.LOCAL)

- def _get_source_columns(self) -> List[SourceColumn]:
- source_columns: List[SourceColumn] = []
+ def _get_source_columns(self) -> list[SourceColumn]:
+ source_columns: list[SourceColumn] = []
  for column in self._data.columns:
  ser = self._data[column]
  try:
  dtype = infer_dtype(ser)
  except Exception:
- warnings.warn(f"Data type inference for column '{column}' in "
- f"table '{self.name}' failed. Consider changing "
- f"the data type of the column to use it within "
- f"this table.")
+ warnings.warn(f"Encountered unsupported data type "
+ f"'{ser.dtype}' for column '{column}' in table "
+ f"'{self.name}'. Please change the data type of "
+ f"the column in the `pandas.DataFrame` to use "
+ f"it within this table, or remove it to "
+ f"suppress this warning.")
  continue

  source_column = SourceColumn(
@@ -109,11 +107,8 @@

  return source_columns

- def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
- return []
-
- def _get_sample_df(self) -> pd.DataFrame:
+ def _get_source_sample_df(self) -> pd.DataFrame:
  return self._data

- def _get_num_rows(self) -> Optional[int]:
+ def _get_num_rows(self) -> int | None:
  return len(self._data)
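For reference, a hedged construction sketch reflecting the new signature above; the DataFrame and column names are hypothetical, and the MissingType.VALUE sentinel presumably lets an omitted primary_key be distinguished from an explicit primary_key=None:

```python
import pandas as pd

from kumoai.experimental.rfm.backend.local import LocalTable

# Hypothetical data; column names are illustrative only.
df_users = pd.DataFrame({'user_id': [1, 2, 3], 'age': [25, 31, 47]})

# Passing a key works as before; omitting the argument leaves the
# MissingType.VALUE sentinel in place (assumption: this marks the key as
# "not specified" rather than "absent").
table = LocalTable(df_users, name='users', primary_key='user_id')
```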
kumoai/experimental/rfm/backend/snow/sampler.py CHANGED
@@ -1,39 +1,27 @@
  import json
- from typing import TYPE_CHECKING
+ from collections.abc import Iterator
+ from contextlib import contextmanager

  import numpy as np
  import pandas as pd
  import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery

- from kumoai.experimental.rfm.backend.snow import SnowTable
+ from kumoai.experimental.rfm.backend.snow import Connection
  from kumoai.experimental.rfm.base import SQLSampler
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
- from kumoai.utils import ProgressLogger, quote_ident
+ from kumoai.utils import quote_ident

- if TYPE_CHECKING:
- from kumoai.experimental.rfm import Graph

+ @contextmanager
+ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+ _style = connection._paramstyle
+ connection._paramstyle = style
+ yield
+ connection._paramstyle = _style

- class SnowSampler(SQLSampler):
- def __init__(
- self,
- graph: 'Graph',
- verbose: bool | ProgressLogger = True,
- ) -> None:
- super().__init__(graph=graph, verbose=verbose)
-
- self._fqn_dict: dict[str, str] = {}
- for table in graph.tables.values():
- assert isinstance(table, SnowTable)
- self._connection = table._connection
- self._fqn_dict[table.name] = table.fqn
-
- @property
- def fqn_dict(self) -> dict[str, str]:
- r"""The fully-qualified quoted names for all tables in the graph."""
- return self._fqn_dict

+ class SnowSampler(SQLSampler):
  def _get_min_max_time_dict(
  self,
  table_names: list[str],
@@ -42,7 +30,7 @@ class SnowSampler(SQLSampler):
  for table_name in table_names:
  time_column = self.time_column_dict[table_name]
  select = (f"SELECT\n"
- f" %s as table_name,\n"
+ f" ? as table_name,\n"
  f" MIN({quote_ident(time_column)}) as min_date,\n"
  f" MAX({quote_ident(time_column)}) as max_date\n"
  f"FROM {self.fqn_dict[table_name]}")
@@ -50,14 +38,14 @@ class SnowSampler(SQLSampler):
  sql = "\nUNION ALL\n".join(selects)

  out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
- with self._connection.cursor() as cursor:
+ with paramstyle(self._connection), self._connection.cursor() as cursor:
  cursor.execute(sql, table_names)
  rows = cursor.fetchall()
- for table_name, _min, _max in rows:
- out_dict[table_name] = (
- pd.Timestamp.max if _min is None else pd.Timestamp(_min),
- pd.Timestamp.min if _max is None else pd.Timestamp(_max),
- )
+ for table_name, _min, _max in rows:
+ out_dict[table_name] = (
+ pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+ pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+ )

  return out_dict

@@ -179,7 +167,7 @@ class SnowSampler(SQLSampler):
  sql += " f.value::FLOAT as ID\n"
  else:
  sql += " f.value::VARCHAR as ID\n"
- sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
+ sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
  f")\n"
  f"SELECT TMP.BATCH as __BATCH__, "
  f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
@@ -187,7 +175,7 @@
  f"JOIN {self.fqn_dict[table_name]} ENT\n"
  f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")

- with self._connection.cursor() as cursor:
+ with paramstyle(self._connection), self._connection.cursor() as cursor:
  cursor.execute(sql, (payload, ))
  table = cursor.fetch_arrow_all()

@@ -240,7 +228,7 @@
  if min_offset is not None:
  sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
  sql += (f"\n"
- f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(%s))) f\n"
+ f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
  f")\n"
  f"SELECT TMP.BATCH as __BATCH__, "
  f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
@@ -251,7 +239,7 @@
  if min_offset is not None:
  sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"

- with self._connection.cursor() as cursor:
+ with paramstyle(self._connection), self._connection.cursor() as cursor:
  cursor.execute(sql, (payload, ))
  table = cursor.fetch_arrow_all()

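A short sketch of how the new paramstyle() helper is used: the sampler's SQL now relies on qmark-style ? placeholders, so the connection's paramstyle is swapped around each cursor and restored on exit (connection below stands in for an existing snow Connection; the query is illustrative only):

```python
# `connection` is assumed to be an existing Snowflake `Connection`.
with paramstyle(connection), connection.cursor() as cursor:
    # "?" binds positional parameters while the 'qmark' style is active.
    cursor.execute("SELECT ? AS table_name", ("users",))
    print(cursor.fetchall())
# On exit, the connection's previous paramstyle is restored.
```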
kumoai/experimental/rfm/backend/snow/table.py CHANGED
@@ -1,28 +1,35 @@
  import re
- from typing import List, Optional, Sequence, cast
+ from collections.abc import Sequence
+ from typing import cast

  import pandas as pd
+ from kumoapi.model_plan import MissingType
  from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.snow import Connection
  from kumoai.experimental.rfm.base import (
+ ColumnExpressionSpec,
+ ColumnExpressionType,
  DataBackend,
  SourceColumn,
  SourceForeignKey,
- Table,
+ SQLTable,
  )
  from kumoai.utils import quote_ident


- class SnowTable(Table):
+ class SnowTable(SQLTable):
  r"""A table backed by a :class:`sqlite` database.

  Args:
  connection: The connection to a :class:`snowflake` database.
- name: The name of this table.
+ name: The logical name of this table.
+ source_name: The physical name of this table in the database. If set to
+ ``None``, ``name`` is being used.
  database: The database.
  schema: The schema.
- columns: The selected columns of this table.
+ columns: The selected physical columns of this table.
+ column_expressions: The logical columns of this table.
  primary_key: The name of the primary key of this table, if it exists.
  time_column: The name of the time column of this table, if it exists.
  end_time_column: The name of the end time column of this table, if it
@@ -32,17 +39,20 @@
  self,
  connection: Connection,
  name: str,
+ source_name: str | None = None,
  database: str | None = None,
  schema: str | None = None,
- columns: Optional[Sequence[str]] = None,
- primary_key: Optional[str] = None,
- time_column: Optional[str] = None,
- end_time_column: Optional[str] = None,
+ columns: Sequence[str] | None = None,
+ column_expressions: Sequence[ColumnExpressionType] | None = None,
+ primary_key: MissingType | str | None = MissingType.VALUE,
+ time_column: str | None = None,
+ end_time_column: str | None = None,
  ) -> None:

  if database is not None and schema is None:
- raise ValueError(f"Missing 'schema' for table '{name}' in "
- f"database '{database}'")
+ raise ValueError(f"Unspecified 'schema' for table "
+ f"'{source_name or name}' in database "
+ f"'{database}'")

  self._connection = connection
  self._database = database
@@ -50,12 +60,32 @@

  super().__init__(
  name=name,
+ source_name=source_name,
  columns=columns,
+ column_expressions=column_expressions,
  primary_key=primary_key,
  time_column=time_column,
  end_time_column=end_time_column,
  )

+ @staticmethod
+ def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
+ if snowflake_dtype is None:
+ return None
+ snowflake_dtype = snowflake_dtype.strip().upper()
+ # TODO 'NUMBER(...)' is not always an integer!
+ if snowflake_dtype.startswith('NUMBER'):
+ return Dtype.int
+ elif snowflake_dtype.startswith('VARCHAR'):
+ return Dtype.string
+ elif snowflake_dtype == 'FLOAT':
+ return Dtype.float
+ elif snowflake_dtype == 'BOOLEAN':
+ return Dtype.bool
+ elif re.search('DATE|TIMESTAMP', snowflake_dtype):
+ return Dtype.date
+ return None
+
  @property
  def backend(self) -> DataBackend:
  return cast(DataBackend, DataBackend.SNOWFLAKE)
@@ -63,15 +93,15 @@ class SnowTable(Table):
  @property
  def fqn(self) -> str:
  r"""The fully-qualified quoted table name."""
- names: List[str] = []
+ names: list[str] = []
  if self._database is not None:
  names.append(quote_ident(self._database))
  if self._schema is not None:
  names.append(quote_ident(self._schema))
- return '.'.join(names + [quote_ident(self._name)])
+ return '.'.join(names + [quote_ident(self._source_name)])

- def _get_source_columns(self) -> List[SourceColumn]:
- source_columns: List[SourceColumn] = []
+ def _get_source_columns(self) -> list[SourceColumn]:
+ source_columns: list[SourceColumn] = []
  with self._connection.cursor() as cursor:
  try:
  sql = f"DESCRIBE TABLE {self.fqn}"
@@ -82,24 +112,15 @@
  names.append(self._database)
  if self._schema is not None:
  names.append(self._schema)
- name = '.'.join(names + [self._name])
- raise ValueError(f"Table '{name}' does not exist") from e
+ source_name = '.'.join(names + [self._source_name])
+ raise ValueError(f"Table '{source_name}' does not exist in "
+ f"the remote data backend") from e

  for row in cursor.fetchall():
  column, type, _, null, _, is_pkey, is_unique, *_ = row

- type = type.strip().upper()
- if type.startswith('NUMBER'):
- dtype = Dtype.int
- elif type.startswith('VARCHAR'):
- dtype = Dtype.string
- elif type == 'FLOAT':
- dtype = Dtype.float
- elif type == 'BOOLEAN':
- dtype = Dtype.bool
- elif re.search('DATE|TIMESTAMP', type):
- dtype = Dtype.date
- else:
+ dtype = self.to_dtype(type)
+ if dtype is None:
  continue

  source_column = SourceColumn(
@@ -113,8 +134,8 @@

  return source_columns

- def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
- source_fkeys: List[SourceForeignKey] = []
+ def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+ source_fkeys: list[SourceForeignKey] = []
  with self._connection.cursor() as cursor:
  sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
  cursor.execute(sql)
@@ -123,7 +144,7 @@
  source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
  return source_fkeys

- def _get_sample_df(self) -> pd.DataFrame:
+ def _get_source_sample_df(self) -> pd.DataFrame:
  with self._connection.cursor() as cursor:
  columns = [quote_ident(col) for col in self._source_column_dict]
  sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
@@ -131,5 +152,18 @@
  table = cursor.fetch_arrow_all()
  return table.to_pandas(types_mapper=pd.ArrowDtype)

- def _get_num_rows(self) -> Optional[int]:
+ def _get_num_rows(self) -> int | None:
  return None
+
+ def _get_expression_sample_df(
+ self,
+ specs: Sequence[ColumnExpressionSpec],
+ ) -> pd.DataFrame:
+ with self._connection.cursor() as cursor:
+ columns = [
+ f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+ ]
+ sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+ cursor.execute(sql)
+ table = cursor.fetch_arrow_all()
+ return table.to_pandas(types_mapper=pd.ArrowDtype)
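To illustrate the extracted SnowTable.to_dtype() mapping above, a small sketch with a few representative (illustrative) Snowflake type strings; unsupported types yield None and the corresponding column is skipped:

```python
from kumoapi.typing import Dtype

from kumoai.experimental.rfm.backend.snow import SnowTable

# Behaviour follows the branches in to_dtype() above.
assert SnowTable.to_dtype('NUMBER(38,0)') == Dtype.int    # see TODO: not always an integer
assert SnowTable.to_dtype('VARCHAR(255)') == Dtype.string
assert SnowTable.to_dtype('TIMESTAMP_NTZ(9)') == Dtype.date
assert SnowTable.to_dtype('GEOGRAPHY') is None            # unsupported -> column skipped
assert SnowTable.to_dtype(None) is None
```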
kumoai/experimental/rfm/backend/sqlite/__init__.py CHANGED
@@ -1,5 +1,5 @@
  from pathlib import Path
- from typing import Any, TypeAlias, Union
+ from typing import Any, TypeAlias

  try:
  import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
  Connection: TypeAlias = adbc.AdbcSqliteConnection


- def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
+ def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
  r"""Opens a connection to a :class:`sqlite` database.

  uri: The path to the database file to be opened.