kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/experimental/rfm/__init__.py +49 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  10. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  11. kumoai/experimental/rfm/backend/local/table.py +32 -14
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +186 -39
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
  18. kumoai/experimental/rfm/base/__init__.py +23 -3
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +761 -0
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  24. kumoai/experimental/rfm/base/table.py +380 -185
  25. kumoai/experimental/rfm/graph.py +404 -144
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +52 -60
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +283 -230
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/pquery/predictive_query.py +10 -6
  38. kumoai/testing/snow.py +50 -0
  39. kumoai/trainer/distilled_trainer.py +175 -0
  40. kumoai/utils/__init__.py +3 -2
  41. kumoai/utils/display.py +51 -0
  42. kumoai/utils/progress_logger.py +178 -12
  43. kumoai/utils/sql.py +3 -0
  44. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
  45. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
  46. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  47. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  48. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  49. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  50. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,22 @@
1
1
  import re
2
- from typing import List, Optional, Sequence
2
+ from collections import Counter
3
+ from collections.abc import Sequence
4
+ from typing import cast
3
5
 
4
6
  import pandas as pd
7
+ from kumoapi.model_plan import MissingType
5
8
  from kumoapi.typing import Dtype
6
9
 
7
- from kumoai.experimental.rfm.backend.sqlite import Connection
8
- from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
10
+ from kumoai.experimental.rfm.backend.snow import Connection
11
+ from kumoai.experimental.rfm.base import (
12
+ ColumnSpec,
13
+ ColumnSpecType,
14
+ DataBackend,
15
+ SourceColumn,
16
+ SourceForeignKey,
17
+ Table,
18
+ )
19
+ from kumoai.utils import quote_ident
9
20
 
10
21
 
11
22
  class SnowTable(Table):
@@ -14,6 +25,10 @@ class SnowTable(Table):
14
25
  Args:
15
26
  connection: The connection to a :class:`snowflake` database.
16
27
  name: The name of this table.
28
+ source_name: The source name of this table. If set to ``None``,
29
+ ``name`` is being used.
30
+ database: The database.
31
+ schema: The schema.
17
32
  columns: The selected columns of this table.
18
33
  primary_key: The name of the primary key of this table, if it exists.
19
34
  time_column: The name of the time column of this table, if it exists.
@@ -24,72 +39,204 @@ class SnowTable(Table):
24
39
  self,
25
40
  connection: Connection,
26
41
  name: str,
27
- columns: Optional[Sequence[str]] = None,
28
- primary_key: Optional[str] = None,
29
- time_column: Optional[str] = None,
30
- end_time_column: Optional[str] = None,
42
+ source_name: str | None = None,
43
+ database: str | None = None,
44
+ schema: str | None = None,
45
+ columns: Sequence[ColumnSpecType] | None = None,
46
+ primary_key: MissingType | str | None = MissingType.VALUE,
47
+ time_column: str | None = None,
48
+ end_time_column: str | None = None,
31
49
  ) -> None:
32
50
 
51
+ if database is None or schema is None:
52
+ with connection.cursor() as cursor:
53
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
54
+ result = cursor.fetchone()
55
+ database = database or result[0]
56
+ assert database is not None
57
+ schema = schema or result[1]
58
+
59
+ if schema is None:
60
+ raise ValueError(f"Unspecified 'schema' for table "
61
+ f"'{source_name or name}' in database "
62
+ f"'{database}'")
63
+
33
64
  self._connection = connection
65
+ self._database = database
66
+ self._schema = schema
34
67
 
35
68
  super().__init__(
36
69
  name=name,
70
+ source_name=source_name,
37
71
  columns=columns,
38
72
  primary_key=primary_key,
39
73
  time_column=time_column,
40
74
  end_time_column=end_time_column,
41
75
  )
42
76
 
43
- def _get_source_columns(self) -> List[SourceColumn]:
44
- source_columns: List[SourceColumn] = []
77
+ @property
78
+ def source_name(self) -> str:
79
+ names: list[str] = []
80
+ if self._database is not None:
81
+ names.append(self._database)
82
+ if self._schema is not None:
83
+ names.append(self._schema)
84
+ return '.'.join(names + [self._source_name])
85
+
86
+ @property
87
+ def _quoted_source_name(self) -> str:
88
+ names: list[str] = []
89
+ if self._database is not None:
90
+ names.append(quote_ident(self._database))
91
+ if self._schema is not None:
92
+ names.append(quote_ident(self._schema))
93
+ return '.'.join(names + [quote_ident(self._source_name)])
94
+
95
+ @property
96
+ def backend(self) -> DataBackend:
97
+ return cast(DataBackend, DataBackend.SNOWFLAKE)
98
+
99
+ def _get_source_columns(self) -> list[SourceColumn]:
100
+ source_columns: list[SourceColumn] = []
45
101
  with self._connection.cursor() as cursor:
46
102
  try:
47
- cursor.execute(f"DESCRIBE TABLE {self.name}")
103
+ sql = f"DESCRIBE TABLE {self._quoted_source_name}"
104
+ cursor.execute(sql)
48
105
  except Exception as e:
49
- raise ValueError(f"Table '{self.name}' does not exist") from e
106
+ raise ValueError(f"Table '{self.source_name}' does not exist "
107
+ f"in the remote data backend") from e
50
108
 
51
109
  for row in cursor.fetchall():
52
- column, type, _, _, _, is_pkey, is_unique = row[:7]
53
-
54
- type = type.strip().upper()
55
- if type.startswith('NUMBER'):
56
- dtype = Dtype.int
57
- elif type.startswith('VARCHAR'):
58
- dtype = Dtype.string
59
- elif type == 'FLOAT':
60
- dtype = Dtype.float
61
- elif type == 'BOOLEAN':
62
- dtype = Dtype.bool
63
- elif re.search('DATE|TIMESTAMP', type):
64
- dtype = Dtype.date
65
- else:
66
- continue
110
+ column, dtype, _, null, _, is_pkey, is_unique, *_ = row
67
111
 
68
112
  source_column = SourceColumn(
69
113
  name=column,
70
- dtype=dtype,
114
+ dtype=self._to_dtype(dtype),
71
115
  is_primary_key=is_pkey.strip().upper() == 'Y',
72
116
  is_unique_key=is_unique.strip().upper() == 'Y',
117
+ is_nullable=null.strip().upper() == 'Y',
73
118
  )
74
119
  source_columns.append(source_column)
75
120
 
76
121
  return source_columns
77
122
 
78
- def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
79
- source_fkeys: List[SourceForeignKey] = []
123
+ def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
124
+ source_foreign_keys: list[SourceForeignKey] = []
80
125
  with self._connection.cursor() as cursor:
81
- cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.name}")
82
- for row in cursor.fetchall():
83
- _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
84
- source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
85
- return source_fkeys
126
+ sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
127
+ cursor.execute(sql)
128
+ rows = cursor.fetchall()
129
+ counts = Counter(row[13] for row in rows)
130
+ for row in rows:
131
+ if counts[row[13]] == 1:
132
+ source_foreign_key = SourceForeignKey(
133
+ name=row[8],
134
+ dst_table=f'{row[1]}.{row[2]}.{row[3]}',
135
+ primary_key=row[4],
136
+ )
137
+ source_foreign_keys.append(source_foreign_key)
138
+ return source_foreign_keys
139
+
140
+ def _get_source_sample_df(self) -> pd.DataFrame:
141
+ with self._connection.cursor() as cursor:
142
+ columns = [quote_ident(col) for col in self._source_column_dict]
143
+ sql = (f"SELECT {', '.join(columns)} "
144
+ f"FROM {self._quoted_source_name} "
145
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
146
+ cursor.execute(sql)
147
+ table = cursor.fetch_arrow_all()
86
148
 
87
- def _get_sample_df(self) -> pd.DataFrame:
149
+ if table is None:
150
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
151
+
152
+ return self._sanitize(
153
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
154
+ dtype_dict={
155
+ column.name: column.dtype
156
+ for column in self._source_column_dict.values()
157
+ },
158
+ stype_dict=None,
159
+ )
160
+
161
+ def _get_num_rows(self) -> int | None:
162
+ return None
163
+
164
+ def _get_expr_sample_df(
165
+ self,
166
+ columns: Sequence[ColumnSpec],
167
+ ) -> pd.DataFrame:
88
168
  with self._connection.cursor() as cursor:
89
- columns = ', '.join(self._source_column_dict.keys())
90
- cursor.execute(f"SELECT {columns} FROM {self.name} LIMIT 1000")
169
+ projections = [
170
+ f"{column.expr} AS {quote_ident(column.name)}"
171
+ for column in columns
172
+ ]
173
+ sql = (f"SELECT {', '.join(projections)} "
174
+ f"FROM {self._quoted_source_name} "
175
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
176
+ cursor.execute(sql)
91
177
  table = cursor.fetch_arrow_all()
92
- return table.to_pandas()
93
178
 
94
- def _get_num_rows(self) -> Optional[int]:
179
+ if table is None:
180
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
181
+
182
+ return self._sanitize(
183
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
184
+ dtype_dict={column.name: column.dtype
185
+ for column in columns},
186
+ stype_dict=None,
187
+ )
188
+
189
+ @staticmethod
190
+ def _to_dtype(dtype: str | None) -> Dtype | None:
191
+ if dtype is None:
192
+ return None
193
+ dtype = dtype.strip().upper()
194
+ if dtype.startswith('NUMBER'):
195
+ try: # Parse `scale` from 'NUMBER(precision, scale)':
196
+ scale = int(dtype.split(',')[-1].split(')')[0])
197
+ return Dtype.int if scale == 0 else Dtype.float
198
+ except Exception:
199
+ return Dtype.float
200
+ if dtype == 'FLOAT':
201
+ return Dtype.float
202
+ if dtype.startswith('VARCHAR'):
203
+ return Dtype.string
204
+ if dtype.startswith('BINARY'):
205
+ return Dtype.binary
206
+ if dtype == 'BOOLEAN':
207
+ return Dtype.bool
208
+ if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
209
+ return Dtype.date
210
+ if dtype.startswith('TIME'):
211
+ return Dtype.time
212
+ if dtype.startswith('VECTOR'):
213
+ try: # Parse element data type from 'VECTOR(dtype, dimension)':
214
+ dtype = dtype.split(',')[0].split('(')[1].strip()
215
+ if dtype == 'INT':
216
+ return Dtype.intlist
217
+ elif dtype == 'FLOAT':
218
+ return Dtype.floatlist
219
+ except Exception:
220
+ pass
221
+ return Dtype.unsupported
222
+ if dtype.startswith('ARRAY'):
223
+ try: # Parse element data type from 'ARRAY(dtype)':
224
+ dtype = dtype.split('(', maxsplit=1)[1]
225
+ dtype = dtype.rsplit(')', maxsplit=1)[0]
226
+ _dtype = SnowTable._to_dtype(dtype)
227
+ if _dtype is not None and _dtype.is_int():
228
+ return Dtype.intlist
229
+ elif _dtype is not None and _dtype.is_float():
230
+ return Dtype.floatlist
231
+ elif _dtype is not None and _dtype.is_string():
232
+ return Dtype.stringlist
233
+ except Exception:
234
+ pass
235
+ return Dtype.unsupported
236
+ # Unsupported data types:
237
+ if re.search(
238
+ 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
239
+ dtype,
240
+ ):
241
+ return Dtype.unsupported
95
242
  return None
@@ -1,5 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any, TypeAlias, Union
2
+ from typing import Any, TypeAlias
3
3
 
4
4
  try:
5
5
  import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
11
11
  Connection: TypeAlias = adbc.AdbcSqliteConnection
12
12
 
13
13
 
14
- def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
14
+ def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
15
15
  r"""Opens a connection to a :class:`sqlite` database.
16
16
 
17
17
  uri: The path to the database file to be opened.
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
22
22
 
23
23
 
24
24
  from .table import SQLiteTable # noqa: E402
25
+ from .sampler import SQLiteSampler # noqa: E402
25
26
 
26
27
  __all__ = [
27
28
  'connect',
28
29
  'Connection',
29
30
  'SQLiteTable',
31
+ 'SQLiteSampler',
30
32
  ]