kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +130 -110
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  31. kumoai/pquery/training_table.py +16 -2
  32. kumoai/testing/snow.py +3 -3
  33. kumoai/trainer/distilled_trainer.py +175 -0
  34. kumoai/utils/display.py +87 -0
  35. kumoai/utils/progress_logger.py +15 -2
  36. kumoai/utils/sql.py +2 -2
  37. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
  38. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
  39. kumoai/experimental/rfm/base/column_expression.py +0 -50
  40. kumoai/experimental/rfm/base/sql_table.py +0 -229
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import re
2
+ from collections import Counter
2
3
  from collections.abc import Sequence
3
4
  from typing import cast
4
5
 
@@ -8,28 +9,27 @@ from kumoapi.typing import Dtype
8
9
 
9
10
  from kumoai.experimental.rfm.backend.snow import Connection
10
11
  from kumoai.experimental.rfm.base import (
11
- ColumnExpressionSpec,
12
- ColumnExpressionType,
12
+ ColumnSpec,
13
+ ColumnSpecType,
13
14
  DataBackend,
14
15
  SourceColumn,
15
16
  SourceForeignKey,
16
- SQLTable,
17
+ Table,
17
18
  )
18
19
  from kumoai.utils import quote_ident
19
20
 
20
21
 
21
- class SnowTable(SQLTable):
22
+ class SnowTable(Table):
22
23
  r"""A table backed by a :class:`sqlite` database.
23
24
 
24
25
  Args:
25
26
  connection: The connection to a :class:`snowflake` database.
26
- name: The logical name of this table.
27
- source_name: The physical name of this table in the database. If set to
28
- ``None``, ``name`` is being used.
27
+ name: The name of this table.
28
+ source_name: The source name of this table. If set to ``None``,
29
+ ``name`` is being used.
29
30
  database: The database.
30
31
  schema: The schema.
31
- columns: The selected physical columns of this table.
32
- column_expressions: The logical columns of this table.
32
+ columns: The selected columns of this table.
33
33
  primary_key: The name of the primary key of this table, if it exists.
34
34
  time_column: The name of the time column of this table, if it exists.
35
35
  end_time_column: The name of the end time column of this table, if it
@@ -42,14 +42,21 @@ class SnowTable(SQLTable):
42
42
  source_name: str | None = None,
43
43
  database: str | None = None,
44
44
  schema: str | None = None,
45
- columns: Sequence[str] | None = None,
46
- column_expressions: Sequence[ColumnExpressionType] | None = None,
45
+ columns: Sequence[ColumnSpecType] | None = None,
47
46
  primary_key: MissingType | str | None = MissingType.VALUE,
48
47
  time_column: str | None = None,
49
48
  end_time_column: str | None = None,
50
49
  ) -> None:
51
50
 
52
- if database is not None and schema is None:
51
+ if database is None or schema is None:
52
+ with connection.cursor() as cursor:
53
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
54
+ result = cursor.fetchone()
55
+ database = database or result[0]
56
+ assert database is not None
57
+ schema = schema or result[1]
58
+
59
+ if schema is None:
53
60
  raise ValueError(f"Unspecified 'schema' for table "
54
61
  f"'{source_name or name}' in database "
55
62
  f"'{database}'")
@@ -62,70 +69,41 @@ class SnowTable(SQLTable):
62
69
  name=name,
63
70
  source_name=source_name,
64
71
  columns=columns,
65
- column_expressions=column_expressions,
66
72
  primary_key=primary_key,
67
73
  time_column=time_column,
68
74
  end_time_column=end_time_column,
69
75
  )
70
76
 
71
- @staticmethod
72
- def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
73
- if snowflake_dtype is None:
74
- return None
75
- snowflake_dtype = snowflake_dtype.strip().upper()
76
- # TODO 'NUMBER(...)' is not always an integer!
77
- if snowflake_dtype.startswith('NUMBER'):
78
- return Dtype.int
79
- elif snowflake_dtype.startswith('VARCHAR'):
80
- return Dtype.string
81
- elif snowflake_dtype == 'FLOAT':
82
- return Dtype.float
83
- elif snowflake_dtype == 'BOOLEAN':
84
- return Dtype.bool
85
- elif re.search('DATE|TIMESTAMP', snowflake_dtype):
86
- return Dtype.date
87
- return None
77
+ @property
78
+ def source_name(self) -> str:
79
+ names = [self._database, self._schema, self._source_name]
80
+ return '.'.join(names)
88
81
 
89
82
  @property
90
- def backend(self) -> DataBackend:
91
- return cast(DataBackend, DataBackend.SNOWFLAKE)
83
+ def _quoted_source_name(self) -> str:
84
+ names = [self._database, self._schema, self._source_name]
85
+ return '.'.join([quote_ident(name) for name in names])
92
86
 
93
87
  @property
94
- def fqn(self) -> str:
95
- r"""The fully-qualified quoted table name."""
96
- names: list[str] = []
97
- if self._database is not None:
98
- names.append(quote_ident(self._database))
99
- if self._schema is not None:
100
- names.append(quote_ident(self._schema))
101
- return '.'.join(names + [quote_ident(self._source_name)])
88
+ def backend(self) -> DataBackend:
89
+ return cast(DataBackend, DataBackend.SNOWFLAKE)
102
90
 
103
91
  def _get_source_columns(self) -> list[SourceColumn]:
104
92
  source_columns: list[SourceColumn] = []
105
93
  with self._connection.cursor() as cursor:
106
94
  try:
107
- sql = f"DESCRIBE TABLE {self.fqn}"
95
+ sql = f"DESCRIBE TABLE {self._quoted_source_name}"
108
96
  cursor.execute(sql)
109
97
  except Exception as e:
110
- names: list[str] = []
111
- if self._database is not None:
112
- names.append(self._database)
113
- if self._schema is not None:
114
- names.append(self._schema)
115
- source_name = '.'.join(names + [self._source_name])
116
- raise ValueError(f"Table '{source_name}' does not exist in "
117
- f"the remote data backend") from e
98
+ raise ValueError(f"Table '{self.source_name}' does not exist "
99
+ f"in the remote data backend") from e
118
100
 
119
101
  for row in cursor.fetchall():
120
- column, type, _, null, _, is_pkey, is_unique, *_ = row
121
-
122
- dtype = self.to_dtype(type)
123
- if dtype is None:
124
- continue
102
+ column, dtype, _, null, _, is_pkey, is_unique, *_ = row
125
103
 
126
104
  source_column = SourceColumn(
127
105
  name=column,
128
- dtype=dtype,
106
+ dtype=self._to_dtype(dtype),
129
107
  is_primary_key=is_pkey.strip().upper() == 'Y',
130
108
  is_unique_key=is_unique.strip().upper() == 'Y',
131
109
  is_nullable=null.strip().upper() == 'Y',
@@ -135,35 +113,133 @@ class SnowTable(SQLTable):
135
113
  return source_columns
136
114
 
137
115
  def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
138
- source_fkeys: list[SourceForeignKey] = []
116
+ source_foreign_keys: list[SourceForeignKey] = []
139
117
  with self._connection.cursor() as cursor:
140
- sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
118
+ sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
141
119
  cursor.execute(sql)
142
- for row in cursor.fetchall():
143
- _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
144
- source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
145
- return source_fkeys
120
+ rows = cursor.fetchall()
121
+ counts = Counter(row[13] for row in rows)
122
+ for row in rows:
123
+ if counts[row[13]] == 1:
124
+ source_foreign_key = SourceForeignKey(
125
+ name=row[8],
126
+ dst_table=f'{row[1]}.{row[2]}.{row[3]}',
127
+ primary_key=row[4],
128
+ )
129
+ source_foreign_keys.append(source_foreign_key)
130
+ return source_foreign_keys
146
131
 
147
132
  def _get_source_sample_df(self) -> pd.DataFrame:
148
133
  with self._connection.cursor() as cursor:
149
134
  columns = [quote_ident(col) for col in self._source_column_dict]
150
- sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
135
+ sql = (f"SELECT {', '.join(columns)} "
136
+ f"FROM {self._quoted_source_name} "
137
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
151
138
  cursor.execute(sql)
152
139
  table = cursor.fetch_arrow_all()
153
- return table.to_pandas(types_mapper=pd.ArrowDtype)
140
+
141
+ if table is None:
142
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
143
+
144
+ return self._sanitize(
145
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
146
+ dtype_dict={
147
+ column.name: column.dtype
148
+ for column in self._source_column_dict.values()
149
+ },
150
+ stype_dict=None,
151
+ )
154
152
 
155
153
  def _get_num_rows(self) -> int | None:
156
- return None
154
+ with self._connection.cursor() as cursor:
155
+ quoted_source_name = quote_ident(self._source_name, char="'")
156
+ sql = (f"SHOW TABLES LIKE {quoted_source_name} "
157
+ f"IN SCHEMA {quote_ident(self._database)}."
158
+ f"{quote_ident(self._schema)}")
159
+ cursor.execute(sql)
160
+ num_rows = cursor.fetchone()[7]
161
+
162
+ if num_rows == 0:
163
+ raise RuntimeError("Table '{self.source_name}' is empty")
157
164
 
158
- def _get_expression_sample_df(
165
+ return num_rows
166
+
167
+ def _get_expr_sample_df(
159
168
  self,
160
- specs: Sequence[ColumnExpressionSpec],
169
+ columns: Sequence[ColumnSpec],
161
170
  ) -> pd.DataFrame:
162
171
  with self._connection.cursor() as cursor:
163
- columns = [
164
- f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
172
+ projections = [
173
+ f"{column.expr} AS {quote_ident(column.name)}"
174
+ for column in columns
165
175
  ]
166
- sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
176
+ sql = (f"SELECT {', '.join(projections)} "
177
+ f"FROM {self._quoted_source_name} "
178
+ f"LIMIT {self._NUM_SAMPLE_ROWS}")
167
179
  cursor.execute(sql)
168
180
  table = cursor.fetch_arrow_all()
169
- return table.to_pandas(types_mapper=pd.ArrowDtype)
181
+
182
+ if table is None:
183
+ raise RuntimeError(f"Table '{self.source_name}' is empty")
184
+
185
+ return self._sanitize(
186
+ df=table.to_pandas(types_mapper=pd.ArrowDtype),
187
+ dtype_dict={column.name: column.dtype
188
+ for column in columns},
189
+ stype_dict=None,
190
+ )
191
+
192
+ @staticmethod
193
+ def _to_dtype(dtype: str | None) -> Dtype | None:
194
+ if dtype is None:
195
+ return None
196
+ dtype = dtype.strip().upper()
197
+ if dtype.startswith('NUMBER'):
198
+ try: # Parse `scale` from 'NUMBER(precision, scale)':
199
+ scale = int(dtype.split(',')[-1].split(')')[0])
200
+ return Dtype.int if scale == 0 else Dtype.float
201
+ except Exception:
202
+ return Dtype.float
203
+ if dtype == 'FLOAT':
204
+ return Dtype.float
205
+ if dtype.startswith('VARCHAR'):
206
+ return Dtype.string
207
+ if dtype.startswith('BINARY'):
208
+ return Dtype.binary
209
+ if dtype == 'BOOLEAN':
210
+ return Dtype.bool
211
+ if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
212
+ return Dtype.date
213
+ if dtype.startswith('TIME'):
214
+ return Dtype.time
215
+ if dtype.startswith('VECTOR'):
216
+ try: # Parse element data type from 'VECTOR(dtype, dimension)':
217
+ dtype = dtype.split(',')[0].split('(')[1].strip()
218
+ if dtype == 'INT':
219
+ return Dtype.intlist
220
+ elif dtype == 'FLOAT':
221
+ return Dtype.floatlist
222
+ except Exception:
223
+ pass
224
+ return Dtype.unsupported
225
+ if dtype.startswith('ARRAY'):
226
+ try: # Parse element data type from 'ARRAY(dtype)':
227
+ dtype = dtype.split('(', maxsplit=1)[1]
228
+ dtype = dtype.rsplit(')', maxsplit=1)[0]
229
+ _dtype = SnowTable._to_dtype(dtype)
230
+ if _dtype is not None and _dtype.is_int():
231
+ return Dtype.intlist
232
+ elif _dtype is not None and _dtype.is_float():
233
+ return Dtype.floatlist
234
+ elif _dtype is not None and _dtype.is_string():
235
+ return Dtype.stringlist
236
+ except Exception:
237
+ pass
238
+ return Dtype.unsupported
239
+ # Unsupported data types:
240
+ if re.search(
241
+ 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
242
+ dtype,
243
+ ):
244
+ return Dtype.unsupported
245
+ return None