kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
  12. kumoai/experimental/rfm/backend/snow/table.py +159 -52
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
  15. kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
  16. kumoai/experimental/rfm/base/__init__.py +6 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +342 -13
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +27 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +600 -360
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
  45. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
  46. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  47. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  48. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/table.py
@@ -1,11 +1,16 @@
1
1
  import re
2
- from typing import List, Optional, Sequence, cast
2
+ from collections import Counter
3
+ from collections.abc import Sequence
4
+ from typing import cast
3
5
 
4
6
  import pandas as pd
7
+ from kumoapi.model_plan import MissingType
5
8
  from kumoapi.typing import Dtype
6
9
 
7
10
  from kumoai.experimental.rfm.backend.snow import Connection
8
11
  from kumoai.experimental.rfm.base import (
12
+ ColumnSpec,
13
+ ColumnSpecType,
9
14
  DataBackend,
10
15
  SourceColumn,
11
16
  SourceForeignKey,
@@ -20,6 +25,8 @@ class SnowTable(Table):
20
25
  Args:
21
26
  connection: The connection to a :class:`snowflake` database.
22
27
  name: The name of this table.
28
+ source_name: The source name of this table. If set to ``None``,
29
+ ``name`` is being used.
23
30
  database: The database.
24
31
  schema: The schema.
25
32
  columns: The selected columns of this table.
@@ -32,17 +39,27 @@ class SnowTable(Table):
32
39
  self,
33
40
  connection: Connection,
34
41
  name: str,
42
+ source_name: str | None = None,
35
43
  database: str | None = None,
36
44
  schema: str | None = None,
37
- columns: Optional[Sequence[str]] = None,
38
- primary_key: Optional[str] = None,
39
- time_column: Optional[str] = None,
40
- end_time_column: Optional[str] = None,
45
+ columns: Sequence[ColumnSpecType] | None = None,
46
+ primary_key: MissingType | str | None = MissingType.VALUE,
47
+ time_column: str | None = None,
48
+ end_time_column: str | None = None,
41
49
  ) -> None:
42
50
 
43
- if database is not None and schema is None:
44
- raise ValueError(f"Missing 'schema' for table '{name}' in "
45
- f"database '{database}'")
51
+ if database is None or schema is None:
52
+ with connection.cursor() as cursor:
53
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
54
+ result = cursor.fetchone()
55
+ database = database or result[0]
56
+ assert database is not None
57
+ schema = schema or result[1]
58
+
59
+ if schema is None:
60
+ raise ValueError(f"Unspecified 'schema' for table "
61
+ f"'{source_name or name}' in database "
62
+ f"'{database}'")
46
63
 
47
64
  self._connection = connection
48
65
  self._database = database
@@ -50,6 +67,7 @@ class SnowTable(Table):
50
67
 
51
68
  super().__init__(
52
69
  name=name,
70
+ source_name=source_name,
53
71
  columns=columns,
54
72
  primary_key=primary_key,
55
73
  time_column=time_column,
@@ -57,54 +75,43 @@ class SnowTable(Table):
57
75
  )
58
76
 
59
77
  @property
60
- def backend(self) -> DataBackend:
61
- return cast(DataBackend, DataBackend.SNOWFLAKE)
78
+ def source_name(self) -> str:
79
+ names: list[str] = []
80
+ if self._database is not None:
81
+ names.append(self._database)
82
+ if self._schema is not None:
83
+ names.append(self._schema)
84
+ return '.'.join(names + [self._source_name])
62
85
 
63
86
  @property
64
- def fqn(self) -> str:
65
- r"""The fully-qualified quoted table name."""
66
- names: List[str] = []
87
+ def _quoted_source_name(self) -> str:
88
+ names: list[str] = []
67
89
  if self._database is not None:
68
90
  names.append(quote_ident(self._database))
69
91
  if self._schema is not None:
70
92
  names.append(quote_ident(self._schema))
71
- return '.'.join(names + [quote_ident(self._name)])
93
+ return '.'.join(names + [quote_ident(self._source_name)])
94
+
95
+ @property
96
+ def backend(self) -> DataBackend:
97
+ return cast(DataBackend, DataBackend.SNOWFLAKE)
72
98
 
73
- def _get_source_columns(self) -> List[SourceColumn]:
74
- source_columns: List[SourceColumn] = []
99
+ def _get_source_columns(self) -> list[SourceColumn]:
100
+ source_columns: list[SourceColumn] = []
75
101
  with self._connection.cursor() as cursor:
76
102
  try:
77
- sql = f"DESCRIBE TABLE {self.fqn}"
103
+ sql = f"DESCRIBE TABLE {self._quoted_source_name}"
78
104
  cursor.execute(sql)
79
105
  except Exception as e:
80
- names: list[str] = []
81
- if self._database is not None:
82
- names.append(self._database)
83
- if self._schema is not None:
84
- names.append(self._schema)
85
- name = '.'.join(names + [self._name])
86
- raise ValueError(f"Table '{name}' does not exist") from e
106
+ raise ValueError(f"Table '{self.source_name}' does not exist "
107
+ f"in the remote data backend") from e
87
108
 
88
109
  for row in cursor.fetchall():
89
- column, type, _, null, _, is_pkey, is_unique, *_ = row
90
-
91
- type = type.strip().upper()
92
- if type.startswith('NUMBER'):
93
- dtype = Dtype.int
94
- elif type.startswith('VARCHAR'):
95
- dtype = Dtype.string
96
- elif type == 'FLOAT':
97
- dtype = Dtype.float
98
- elif type == 'BOOLEAN':
99
- dtype = Dtype.bool
100
- elif re.search('DATE|TIMESTAMP', type):
101
- dtype = Dtype.date
102
- else:
103
- continue
110
+ column, dtype, _, null, _, is_pkey, is_unique, *_ = row
104
111
 
105
112
  source_column = SourceColumn(
106
113
  name=column,
107
- dtype=dtype,
114
+ dtype=self._to_dtype(dtype),
108
115
  is_primary_key=is_pkey.strip().upper() == 'Y',
109
116
  is_unique_key=is_unique.strip().upper() == 'Y',
110
117
  is_nullable=null.strip().upper() == 'Y',
@@ -113,23 +120,123 @@ class SnowTable(Table):
113
120
 
114
121
  return source_columns
115
122
 
116
- def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
117
- source_fkeys: List[SourceForeignKey] = []
123
+ def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
124
+ source_foreign_keys: list[SourceForeignKey] = []
118
125
  with self._connection.cursor() as cursor:
119
- sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
126
+ sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
120
127
  cursor.execute(sql)
121
- for row in cursor.fetchall():
122
- _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
123
- source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
124
- return source_fkeys
125
-
126
- def _get_sample_df(self) -> pd.DataFrame:
128
+ rows = cursor.fetchall()
129
+ counts = Counter(row[13] for row in rows)
130
+ for row in rows:
131
+ if counts[row[13]] == 1:
132
+ source_foreign_key = SourceForeignKey(
133
+ name=row[8],
134
+ dst_table=f'{row[1]}.{row[2]}.{row[3]}',
135
+ primary_key=row[4],
136
+ )
137
+ source_foreign_keys.append(source_foreign_key)
138
+ return source_foreign_keys
139
+
140
+ def _get_source_sample_df(self) -> pd.DataFrame:
127
141
  with self._connection.cursor() as cursor:
128
142
  columns = [quote_ident(col) for col in self._source_column_dict]
129
- sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
143
+ sql = (f"SELECT {', '.join(columns)} "
144
+ f"FROM {self._quoted_source_name} "
145
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
146
+ cursor.execute(sql)
147
+ table = cursor.fetch_arrow_all()
148
+
149
+ if table is None:
150
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
151
+
152
+ return self._sanitize(
153
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
154
+ dtype_dict={
155
+ column.name: column.dtype
156
+ for column in self._source_column_dict.values()
157
+ },
158
+ stype_dict=None,
159
+ )
160
+
161
+ def _get_num_rows(self) -> int | None:
162
+ return None
163
+
164
+ def _get_expr_sample_df(
165
+ self,
166
+ columns: Sequence[ColumnSpec],
167
+ ) -> pd.DataFrame:
168
+ with self._connection.cursor() as cursor:
169
+ projections = [
170
+ f"{column.expr} AS {quote_ident(column.name)}"
171
+ for column in columns
172
+ ]
173
+ sql = (f"SELECT {', '.join(projections)} "
174
+ f"FROM {self._quoted_source_name} "
175
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
130
176
  cursor.execute(sql)
131
177
  table = cursor.fetch_arrow_all()
132
- return table.to_pandas(types_mapper=pd.ArrowDtype)
133
178
 
134
- def _get_num_rows(self) -> Optional[int]:
179
+ if table is None:
180
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
181
+
182
+ return self._sanitize(
183
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
184
+ dtype_dict={column.name: column.dtype
185
+ for column in columns},
186
+ stype_dict=None,
187
+ )
188
+
189
+ @staticmethod
190
+ def _to_dtype(dtype: str | None) -> Dtype | None:
191
+ if dtype is None:
192
+ return None
193
+ dtype = dtype.strip().upper()
194
+ if dtype.startswith('NUMBER'):
195
+ try: # Parse `scale` from 'NUMBER(precision, scale)':
196
+ scale = int(dtype.split(',')[-1].split(')')[0])
197
+ return Dtype.int if scale == 0 else Dtype.float
198
+ except Exception:
199
+ return Dtype.float
200
+ if dtype == 'FLOAT':
201
+ return Dtype.float
202
+ if dtype.startswith('VARCHAR'):
203
+ return Dtype.string
204
+ if dtype.startswith('BINARY'):
205
+ return Dtype.binary
206
+ if dtype == 'BOOLEAN':
207
+ return Dtype.bool
208
+ if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
209
+ return Dtype.date
210
+ if dtype.startswith('TIME'):
211
+ return Dtype.time
212
+ if dtype.startswith('VECTOR'):
213
+ try: # Parse element data type from 'VECTOR(dtype, dimension)':
214
+ dtype = dtype.split(',')[0].split('(')[1].strip()
215
+ if dtype == 'INT':
216
+ return Dtype.intlist
217
+ elif dtype == 'FLOAT':
218
+ return Dtype.floatlist
219
+ except Exception:
220
+ pass
221
+ return Dtype.unsupported
222
+ if dtype.startswith('ARRAY'):
223
+ try: # Parse element data type from 'ARRAY(dtype)':
224
+ dtype = dtype.split('(', maxsplit=1)[1]
225
+ dtype = dtype.rsplit(')', maxsplit=1)[0]
226
+ _dtype = SnowTable._to_dtype(dtype)
227
+ if _dtype is not None and _dtype.is_int():
228
+ return Dtype.intlist
229
+ elif _dtype is not None and _dtype.is_float():
230
+ return Dtype.floatlist
231
+ elif _dtype is not None and _dtype.is_string():
232
+ return Dtype.stringlist
233
+ except Exception:
234
+ pass
235
+ return Dtype.unsupported
236
+ # Unsupported data types:
237
+ if re.search(
238
+ 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
239
+ dtype,
240
+ ):
241
+ return Dtype.unsupported
135
242
  return None
kumoai/experimental/rfm/backend/sqlite/__init__.py
@@ -1,5 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any, TypeAlias, Union
2
+ from typing import Any, TypeAlias
3
3
 
4
4
  try:
5
5
  import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
11
11
  Connection: TypeAlias = adbc.AdbcSqliteConnection
12
12
 
13
13
 
14
- def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
14
+ def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
15
15
  r"""Opens a connection to a :class:`sqlite` database.
16
16
 
17
17
  uri: The path to the database file to be opened.