kumoai 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl → 2.15.0.dev202601141731__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +178 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +22 -4
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/mapper.py +69 -0
  23. kumoai/experimental/rfm/base/sampler.py +696 -47
  24. kumoai/experimental/rfm/base/source.py +2 -1
  25. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  26. kumoai/experimental/rfm/base/table.py +384 -207
  27. kumoai/experimental/rfm/base/utils.py +36 -0
  28. kumoai/experimental/rfm/graph.py +359 -187
  29. kumoai/experimental/rfm/infer/__init__.py +6 -4
  30. kumoai/experimental/rfm/infer/dtype.py +10 -5
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +4 -2
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +5 -4
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +770 -467
  39. kumoai/experimental/rfm/sagemaker.py +4 -4
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +192 -13
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
  51. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -42
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  55. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  56. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,22 @@
1
1
  import re
2
- from typing import List, Optional, Sequence
2
+ from collections import Counter
3
+ from collections.abc import Sequence
4
+ from typing import cast
3
5
 
4
6
  import pandas as pd
7
+ from kumoapi.model_plan import MissingType
5
8
  from kumoapi.typing import Dtype
6
9
 
7
10
  from kumoai.experimental.rfm.backend.snow import Connection
8
- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
11
+ from kumoai.experimental.rfm.base import (
12
+ ColumnSpec,
13
+ ColumnSpecType,
14
+ DataBackend,
15
+ SourceColumn,
16
+ SourceForeignKey,
17
+ Table,
18
+ )
19
+ from kumoai.utils import quote_ident
9
20
 
10
21
 
11
22
  class SnowTable(Table):
@@ -14,6 +25,8 @@ class SnowTable(Table):
14
25
  Args:
15
26
  connection: The connection to a :class:`snowflake` database.
16
27
  name: The name of this table.
28
+ source_name: The source name of this table. If set to ``None``,
29
+ ``name`` is being used.
17
30
  database: The database.
18
31
  schema: The schema.
19
32
  columns: The selected columns of this table.
@@ -26,17 +39,27 @@ class SnowTable(Table):
26
39
  self,
27
40
  connection: Connection,
28
41
  name: str,
42
+ source_name: str | None = None,
29
43
  database: str | None = None,
30
44
  schema: str | None = None,
31
- columns: Optional[Sequence[str]] = None,
32
- primary_key: Optional[str] = None,
33
- time_column: Optional[str] = None,
34
- end_time_column: Optional[str] = None,
45
+ columns: Sequence[ColumnSpecType] | None = None,
46
+ primary_key: MissingType | str | None = MissingType.VALUE,
47
+ time_column: str | None = None,
48
+ end_time_column: str | None = None,
35
49
  ) -> None:
36
50
 
37
- if database is not None and schema is None:
38
- raise ValueError(f"Missing 'schema' for table '{name}' in "
39
- f"database '{database}'")
51
+ if database is None or schema is None:
52
+ with connection.cursor() as cursor:
53
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
54
+ result = cursor.fetchone()
55
+ database = database or result[0]
56
+ assert database is not None
57
+ schema = schema or result[1]
58
+
59
+ if schema is None:
60
+ raise ValueError(f"Unspecified 'schema' for table "
61
+ f"'{source_name or name}' in database "
62
+ f"'{database}'")
40
63
 
41
64
  self._connection = connection
42
65
  self._database = database
@@ -44,6 +67,7 @@ class SnowTable(Table):
44
67
 
45
68
  super().__init__(
46
69
  name=name,
70
+ source_name=source_name,
47
71
  columns=columns,
48
72
  primary_key=primary_key,
49
73
  time_column=time_column,
@@ -51,67 +75,171 @@ class SnowTable(Table):
51
75
  )
52
76
 
53
77
  @property
54
- def fqn_name(self) -> str:
55
- names: List[str] = []
56
- if self._database is not None:
57
- assert self._schema is not None
58
- names.extend([self._database, self._schema])
59
- elif self._schema is not None:
60
- names.append(self._schema)
61
- names.append(self._name)
78
+ def source_name(self) -> str:
79
+ names = [self._database, self._schema, self._source_name]
62
80
  return '.'.join(names)
63
81
 
64
- def _get_source_columns(self) -> List[SourceColumn]:
65
- source_columns: List[SourceColumn] = []
82
+ @property
83
+ def _quoted_source_name(self) -> str:
84
+ names = [self._database, self._schema, self._source_name]
85
+ return '.'.join([quote_ident(name) for name in names])
86
+
87
+ @property
88
+ def backend(self) -> DataBackend:
89
+ return cast(DataBackend, DataBackend.SNOWFLAKE)
90
+
91
+ def _get_source_columns(self) -> list[SourceColumn]:
92
+ source_columns: list[SourceColumn] = []
66
93
  with self._connection.cursor() as cursor:
67
94
  try:
68
- cursor.execute(f"DESCRIBE TABLE {self.fqn_name}")
95
+ sql = f"DESCRIBE TABLE {self._quoted_source_name}"
96
+ cursor.execute(sql)
69
97
  except Exception as e:
70
- raise ValueError(
71
- f"Table '{self.fqn_name}' does not exist") from e
98
+ raise ValueError(f"Table '{self.source_name}' does not exist "
99
+ f"in the remote data backend") from e
72
100
 
73
101
  for row in cursor.fetchall():
74
- column, type, _, _, _, is_pkey, is_unique = row[:7]
75
-
76
- type = type.strip().upper()
77
- if type.startswith('NUMBER'):
78
- dtype = Dtype.int
79
- elif type.startswith('VARCHAR'):
80
- dtype = Dtype.string
81
- elif type == 'FLOAT':
82
- dtype = Dtype.float
83
- elif type == 'BOOLEAN':
84
- dtype = Dtype.bool
85
- elif re.search('DATE|TIMESTAMP', type):
86
- dtype = Dtype.date
87
- else:
88
- continue
102
+ column, dtype, _, null, _, is_pkey, is_unique, *_ = row
89
103
 
90
104
  source_column = SourceColumn(
91
105
  name=column,
92
- dtype=dtype,
106
+ dtype=self._to_dtype(dtype),
93
107
  is_primary_key=is_pkey.strip().upper() == 'Y',
94
108
  is_unique_key=is_unique.strip().upper() == 'Y',
109
+ is_nullable=null.strip().upper() == 'Y',
95
110
  )
96
111
  source_columns.append(source_column)
97
112
 
98
113
  return source_columns
99
114
 
100
- def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
101
- source_fkeys: List[SourceForeignKey] = []
115
+ def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
116
+ source_foreign_keys: list[SourceForeignKey] = []
102
117
  with self._connection.cursor() as cursor:
103
- cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.fqn_name}")
104
- for row in cursor.fetchall():
105
- _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
106
- source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
107
- return source_fkeys
118
+ sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
119
+ cursor.execute(sql)
120
+ rows = cursor.fetchall()
121
+ counts = Counter(row[13] for row in rows)
122
+ for row in rows:
123
+ if counts[row[13]] == 1:
124
+ source_foreign_key = SourceForeignKey(
125
+ name=row[8],
126
+ dst_table=f'{row[1]}.{row[2]}.{row[3]}',
127
+ primary_key=row[4],
128
+ )
129
+ source_foreign_keys.append(source_foreign_key)
130
+ return source_foreign_keys
131
+
132
+ def _get_source_sample_df(self) -> pd.DataFrame:
133
+ with self._connection.cursor() as cursor:
134
+ columns = [quote_ident(col) for col in self._source_column_dict]
135
+ sql = (f"SELECT {', '.join(columns)} "
136
+ f"FROM {self._quoted_source_name} "
137
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
138
+ cursor.execute(sql)
139
+ table = cursor.fetch_arrow_all()
108
140
 
109
- def _get_sample_df(self) -> pd.DataFrame:
141
+ if table is None:
142
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
143
+
144
+ return self._sanitize(
145
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
146
+ dtype_dict={
147
+ column.name: column.dtype
148
+ for column in self._source_column_dict.values()
149
+ },
150
+ stype_dict=None,
151
+ )
152
+
153
+ def _get_num_rows(self) -> int | None:
110
154
  with self._connection.cursor() as cursor:
111
- columns = ', '.join(self._source_column_dict.keys())
112
- cursor.execute(f"SELECT {columns} FROM {self.fqn_name} LIMIT 1000")
155
+ quoted_source_name = quote_ident(self._source_name, char="'")
156
+ sql = (f"SHOW TABLES LIKE {quoted_source_name} "
157
+ f"IN SCHEMA {quote_ident(self._database)}."
158
+ f"{quote_ident(self._schema)}")
159
+ cursor.execute(sql)
160
+ num_rows = cursor.fetchone()[7]
161
+
162
+ if num_rows == 0:
163
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
164
+
165
+ return num_rows
166
+
167
+ def _get_expr_sample_df(
168
+ self,
169
+ columns: Sequence[ColumnSpec],
170
+ ) -> pd.DataFrame:
171
+ with self._connection.cursor() as cursor:
172
+ projections = [
173
+ f"{column.expr} AS {quote_ident(column.name)}"
174
+ for column in columns
175
+ ]
176
+ sql = (f"SELECT {', '.join(projections)} "
177
+ f"FROM {self._quoted_source_name} "
178
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
179
+ cursor.execute(sql)
113
180
  table = cursor.fetch_arrow_all()
114
- return table.to_pandas(types_mapper=pd.ArrowDtype)
115
181
 
116
- def _get_num_rows(self) -> Optional[int]:
182
+ if table is None:
183
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
184
+
185
+ return self._sanitize(
186
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
187
+ dtype_dict={column.name: column.dtype
188
+ for column in columns},
189
+ stype_dict=None,
190
+ )
191
+
192
+ @staticmethod
193
+ def _to_dtype(dtype: str | None) -> Dtype | None:
194
+ if dtype is None:
195
+ return None
196
+ dtype = dtype.strip().upper()
197
+ if dtype.startswith('NUMBER'):
198
+ try: # Parse `scale` from 'NUMBER(precision, scale)':
199
+ scale = int(dtype.split(',')[-1].split(')')[0])
200
+ return Dtype.int if scale == 0 else Dtype.float
201
+ except Exception:
202
+ return Dtype.float
203
+ if dtype == 'FLOAT':
204
+ return Dtype.float
205
+ if dtype.startswith('VARCHAR'):
206
+ return Dtype.string
207
+ if dtype.startswith('BINARY'):
208
+ return Dtype.binary
209
+ if dtype == 'BOOLEAN':
210
+ return Dtype.bool
211
+ if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
212
+ return Dtype.date
213
+ if dtype.startswith('TIME'):
214
+ return Dtype.time
215
+ if dtype.startswith('VECTOR'):
216
+ try: # Parse element data type from 'VECTOR(dtype, dimension)':
217
+ dtype = dtype.split(',')[0].split('(')[1].strip()
218
+ if dtype == 'INT':
219
+ return Dtype.intlist
220
+ elif dtype == 'FLOAT':
221
+ return Dtype.floatlist
222
+ except Exception:
223
+ pass
224
+ return Dtype.unsupported
225
+ if dtype.startswith('ARRAY'):
226
+ try: # Parse element data type from 'ARRAY(dtype)':
227
+ dtype = dtype.split('(', maxsplit=1)[1]
228
+ dtype = dtype.rsplit(')', maxsplit=1)[0]
229
+ _dtype = SnowTable._to_dtype(dtype)
230
+ if _dtype is not None and _dtype.is_int():
231
+ return Dtype.intlist
232
+ elif _dtype is not None and _dtype.is_float():
233
+ return Dtype.floatlist
234
+ elif _dtype is not None and _dtype.is_string():
235
+ return Dtype.stringlist
236
+ except Exception:
237
+ pass
238
+ return Dtype.unsupported
239
+ # Unsupported data types:
240
+ if re.search(
241
+ 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
242
+ dtype,
243
+ ):
244
+ return Dtype.unsupported
117
245
  return None
@@ -1,5 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any, TypeAlias, Union
2
+ from typing import Any, TypeAlias
3
3
 
4
4
  try:
5
5
  import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
11
11
  Connection: TypeAlias = adbc.AdbcSqliteConnection
12
12
 
13
13
 
14
- def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
14
+ def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
15
15
  r"""Opens a connection to a :class:`sqlite` database.
16
16
 
17
17
  uri: The path to the database file to be opened.
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
22
22
 
23
23
 
24
24
  from .table import SQLiteTable # noqa: E402
25
+ from .sampler import SQLiteSampler # noqa: E402
25
26
 
26
27
  __all__ = [
27
28
  'connect',
28
29
  'Connection',
29
30
  'SQLiteTable',
31
+ 'SQLiteSampler',
30
32
  ]