kumoai 2.12.1__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/pquery.py +6 -2
  5. kumoai/connector/utils.py +23 -2
  6. kumoai/experimental/rfm/__init__.py +162 -46
  7. kumoai/experimental/rfm/backend/__init__.py +0 -0
  8. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  10. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  11. kumoai/experimental/rfm/backend/local/table.py +119 -0
  12. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
  18. kumoai/experimental/rfm/base/__init__.py +23 -0
  19. kumoai/experimental/rfm/base/column.py +66 -0
  20. kumoai/experimental/rfm/base/sampler.py +773 -0
  21. kumoai/experimental/rfm/base/source.py +19 -0
  22. kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
  23. kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
  24. kumoai/experimental/rfm/infer/__init__.py +6 -0
  25. kumoai/experimental/rfm/infer/dtype.py +79 -0
  26. kumoai/experimental/rfm/infer/pkey.py +126 -0
  27. kumoai/experimental/rfm/infer/time_col.py +62 -0
  28. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  29. kumoai/experimental/rfm/rfm.py +233 -174
  30. kumoai/experimental/rfm/sagemaker.py +138 -0
  31. kumoai/spcs.py +1 -3
  32. kumoai/testing/decorators.py +1 -1
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +2 -0
  35. kumoai/utils/sql.py +3 -0
  36. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +40 -23
  38. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  39. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ import re
2
+ from typing import List, Optional, Sequence, cast
3
+
4
+ import pandas as pd
5
+ from kumoapi.typing import Dtype
6
+
7
+ from kumoai.experimental.rfm.backend.snow import Connection
8
+ from kumoai.experimental.rfm.base import (
9
+ DataBackend,
10
+ SourceColumn,
11
+ SourceForeignKey,
12
+ Table,
13
+ )
14
+ from kumoai.utils import quote_ident
15
+
16
+
17
class SnowTable(Table):
    r"""A table backed by a :class:`snowflake` database.

    Args:
        connection: The connection to a :class:`snowflake` database.
        name: The name of this table.
        database: The database. If given, ``schema`` must be given as well.
        schema: The schema.
        columns: The selected columns of this table.
        primary_key: The name of the primary key of this table, if it exists.
        time_column: The name of the time column of this table, if it exists.
        end_time_column: The name of the end time column of this table, if it
            exists.

    Raises:
        ValueError: If ``database`` is passed without a ``schema``.
    """
    def __init__(
        self,
        connection: Connection,
        name: str,
        database: str | None = None,
        schema: str | None = None,
        columns: Optional[Sequence[str]] = None,
        primary_key: Optional[str] = None,
        time_column: Optional[str] = None,
        end_time_column: Optional[str] = None,
    ) -> None:

        if database is not None and schema is None:
            raise ValueError(f"Missing 'schema' for table '{name}' in "
                             f"database '{database}'")

        # Assigned before `super().__init__()` in case the base class already
        # queries the source table during construction.
        # NOTE(review): confirm against `Table.__init__`.
        self._connection = connection
        self._database = database
        self._schema = schema

        super().__init__(
            name=name,
            columns=columns,
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
        )

    @property
    def backend(self) -> DataBackend:
        return cast(DataBackend, DataBackend.SNOWFLAKE)

    def _name_parts(self) -> List[str]:
        r"""The unquoted ``database``, ``schema`` and table name parts of this
        table, omitting ``database``/``schema`` if they are not set."""
        parts = [part for part in (self._database, self._schema)
                 if part is not None]
        return parts + [self._name]

    @property
    def fqn(self) -> str:
        r"""The fully-qualified quoted table name."""
        return '.'.join(quote_ident(part) for part in self._name_parts())

    @staticmethod
    def _to_dtype(sql_type: str) -> Optional[Dtype]:
        r"""Maps a Snowflake column type (as reported by ``DESCRIBE TABLE``)
        to a :class:`Dtype`, or returns :obj:`None` for unsupported types."""
        sql_type = sql_type.strip().upper()
        if sql_type.startswith('NUMBER'):
            # `NUMBER(precision, scale)` with a non-zero scale holds
            # fractional (decimal) values and must not be treated as integer:
            match = re.search(r'\(\s*\d+\s*,\s*(\d+)\s*\)', sql_type)
            if match is not None and int(match.group(1)) > 0:
                return Dtype.float
            return Dtype.int
        if sql_type.startswith('VARCHAR'):
            return Dtype.string
        if sql_type == 'FLOAT':
            return Dtype.float
        if sql_type == 'BOOLEAN':
            return Dtype.bool
        if re.search('DATE|TIMESTAMP', sql_type):
            return Dtype.date
        return None

    def _get_source_columns(self) -> List[SourceColumn]:
        r"""Reads the column metadata of this table via ``DESCRIBE TABLE``.

        Columns whose type cannot be mapped to a :class:`Dtype` are skipped.

        Raises:
            ValueError: If the table cannot be described (e.g., it does not
                exist).
        """
        source_columns: List[SourceColumn] = []
        with self._connection.cursor() as cursor:
            try:
                sql = f"DESCRIBE TABLE {self.fqn}"
                cursor.execute(sql)
            except Exception as e:
                name = '.'.join(self._name_parts())
                raise ValueError(f"Table '{name}' does not exist") from e

            for row in cursor.fetchall():
                # Leading `DESCRIBE TABLE` columns: name, type, kind, null?,
                # default, primary key, unique key.
                column, sql_type, _, null, _, is_pkey, is_unique = row[:7]

                dtype = self._to_dtype(sql_type)
                if dtype is None:
                    continue

                source_column = SourceColumn(
                    name=column,
                    dtype=dtype,
                    is_primary_key=is_pkey.strip().upper() == 'Y',
                    is_unique_key=is_unique.strip().upper() == 'Y',
                    is_nullable=null.strip().upper() == 'Y',
                )
                source_columns.append(source_column)

        return source_columns

    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
        r"""Reads the foreign keys of this table via
        ``SHOW IMPORTED KEYS``."""
        source_fkeys: List[SourceForeignKey] = []
        with self._connection.cursor() as cursor:
            sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
            cursor.execute(sql)
            for row in cursor.fetchall():
                # Assumes the documented column order: created_on,
                # pk_database_name, pk_schema_name, pk_table_name,
                # pk_column_name, fk_database_name, fk_schema_name,
                # fk_table_name, fk_column_name, ...
                _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
        return source_fkeys

    def _get_sample_df(self) -> pd.DataFrame:
        r"""Fetches (up to) the first 1000 rows of the source columns."""
        columns = [quote_ident(col) for col in self._source_column_dict]
        sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
        with self._connection.cursor() as cursor:
            cursor.execute(sql)
            table = cursor.fetch_arrow_all()
        if table is None:
            # The Snowflake connector returns `None` (instead of an empty
            # Arrow table) if the query yields zero rows:
            return pd.DataFrame(columns=list(self._source_column_dict))
        return table.to_pandas(types_mapper=pd.ArrowDtype)

    def _get_num_rows(self) -> Optional[int]:
        return None
@@ -0,0 +1,32 @@
1
+ from pathlib import Path
2
+ from typing import Any, TypeAlias, Union
3
+
4
+ try:
5
+ import adbc_driver_sqlite.dbapi as adbc
6
+ except ImportError:
7
+ raise ImportError("No module named 'adbc_driver_sqlite'. Please install "
8
+ "Kumo SDK with the 'sqlite' extension via "
9
+ "`pip install kumoai[sqlite]`.")
10
+
11
Connection: TypeAlias = adbc.AdbcSqliteConnection


def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
    r"""Opens a connection to a :class:`sqlite` database.

    Args:
        uri: The path to the database file to be opened. A
            :class:`~pathlib.Path` is converted to a string first, since the
            ADBC SQLite driver expects a string URI.
        kwargs: Additional connection arguments, following the
            :class:`adbc_driver_sqlite` protocol.
    """
    if isinstance(uri, Path):
        uri = str(uri)
    return adbc.connect(uri, **kwargs)
22
+
23
+
24
+ from .table import SQLiteTable # noqa: E402
25
+ from .sampler import SQLiteSampler # noqa: E402
26
+
27
# Public API of the SQLite backend.
__all__ = [
    'connect', 'Connection',
    'SQLiteTable', 'SQLiteSampler',
]
@@ -0,0 +1,112 @@
1
+ from typing import TYPE_CHECKING, Literal
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from kumoapi.pquery import ValidatedPredictiveQuery
6
+
7
+ from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
8
+ from kumoai.experimental.rfm.base import Sampler, SamplerOutput
9
+ from kumoai.utils import ProgressLogger, quote_ident
10
+
11
+ if TYPE_CHECKING:
12
+ from kumoai.experimental.rfm import Graph
13
+
14
+
15
class SQLiteSampler(Sampler):
    r"""A sampler that issues SQL queries against a :class:`sqlite`
    database.

    Args:
        graph: The graph to sample from. Every table of the graph must be a
            :class:`SQLiteTable`.
        verbose: Currently not used by this sampler.
    """
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph)

        for table in graph.tables.values():
            assert isinstance(table, SQLiteTable)
            # NOTE(review): the connection of the last table wins; this
            # presumably assumes all tables share the same connection.
            self._connection = table._connection

        # TODO Check for indices being present.

    def _get_min_max_time_dict(
        self,
        table_names: list[str],
    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
        r"""Returns the ``(min, max)`` value of the time column of each table
        in ``table_names``.

        Tables without any non-``NULL`` time value are mapped to
        ``(pd.Timestamp.max, pd.Timestamp.min)``.
        """
        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
        if len(table_names) == 0:
            # Joining zero selects would yield an empty (invalid) statement:
            return out_dict

        selects: list[str] = []
        for table_name in table_names:
            time_column = quote_ident(self.time_column_dict[table_name])
            # The table name is bound as a `?` parameter, so it is returned
            # as a plain value next to the aggregates:
            selects.append(f"SELECT\n"
                           f"  ? as table_name,\n"
                           f"  MIN({time_column}) as min_date,\n"
                           f"  MAX({time_column}) as max_date\n"
                           f"FROM {quote_ident(table_name)}")
        sql = "\nUNION ALL\n".join(selects)

        with self._connection.cursor() as cursor:
            cursor.execute(sql, table_names)
            for table_name, _min, _max in cursor.fetchall():
                out_dict[table_name] = (
                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
                )
        return out_dict

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        raise NotImplementedError

    def _sample_entity_table(
        self,
        table_name: str,
        columns: set[str],
        num_rows: int,
        random_seed: int | None = None,
    ) -> pd.DataFrame:
        r"""Randomly samples up to ``num_rows`` rows of ``columns`` from
        ``table_name``, skipping rows with a ``NULL`` primary key or time
        column (if those columns are nullable)."""
        # NOTE SQLite does not natively support passing a `random_seed`.
        source_table = self.source_table_dict[table_name]

        filters: list[str] = []
        primary_key = self.primary_key_dict[table_name]
        if source_table[primary_key].is_nullable:
            filters.append(f"{quote_ident(primary_key)} IS NOT NULL")
        time_column = self.time_column_dict.get(table_name)
        if (time_column is not None
                and source_table[time_column].is_nullable):
            filters.append(f"{quote_ident(time_column)} IS NOT NULL")

        # TODO Make this query more efficient - it does full table scan.
        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
               f"FROM {quote_ident(table_name)}")
        if len(filters) > 0:
            sql += f"\nWHERE {' AND '.join(filters)}"
        sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"

        with self._connection.cursor() as cursor:
            # NOTE This may return duplicate primary keys. This is okay.
            cursor.execute(sql)
            table = cursor.fetch_arrow_table()

        return table.to_pandas(types_mapper=pd.ArrowDtype)

    def _sample_target(
        self,
        query: ValidatedPredictiveQuery,
        entity_df: pd.DataFrame,
        train_index: np.ndarray,
        train_time: pd.Series,
        num_train_examples: int,
        test_index: np.ndarray,
        test_time: pd.Series,
        num_test_examples: int,
        columns_dict: dict[str, set[str]],
        time_offset_dict: dict[
            tuple[str, str, str],
            tuple[pd.DateOffset | None, pd.DateOffset],
        ],
    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
        raise NotImplementedError
@@ -0,0 +1,115 @@
1
+ import re
2
+ import warnings
3
+ from typing import List, Optional, Sequence, cast
4
+
5
+ import pandas as pd
6
+ from kumoapi.typing import Dtype
7
+
8
+ from kumoai.experimental.rfm.backend.sqlite import Connection
9
+ from kumoai.experimental.rfm.base import (
10
+ DataBackend,
11
+ SourceColumn,
12
+ SourceForeignKey,
13
+ Table,
14
+ )
15
+ from kumoai.experimental.rfm.infer import infer_dtype
16
+ from kumoai.utils import quote_ident
17
+
18
+
19
class SQLiteTable(Table):
    r"""A table backed by a :class:`sqlite` database.

    Args:
        connection: The connection to a :class:`sqlite` database.
        name: The name of this table.
        columns: The selected columns of this table.
        primary_key: The name of the primary key of this table, if it exists.
        time_column: The name of the time column of this table, if it exists.
        end_time_column: The name of the end time column of this table, if it
            exists.
    """
    def __init__(
        self,
        connection: Connection,
        name: str,
        columns: Optional[Sequence[str]] = None,
        primary_key: Optional[str] = None,
        time_column: Optional[str] = None,
        end_time_column: Optional[str] = None,
    ) -> None:

        self._connection = connection

        super().__init__(
            name=name,
            columns=columns,
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
        )

    @property
    def backend(self) -> DataBackend:
        return cast(DataBackend, DataBackend.SQLITE)

    def _get_source_columns(self) -> List[SourceColumn]:
        r"""Reads the column metadata of this table via
        ``PRAGMA table_info``.

        Raises:
            ValueError: If the table does not exist.
        """
        with self._connection.cursor() as cursor:
            sql = f"PRAGMA table_info({quote_ident(self.name)})"
            cursor.execute(sql)
            rows = cursor.fetchall()

        if len(rows) == 0:
            raise ValueError(f"Table '{self.name}' does not exist")

        source_columns: List[SourceColumn] = []
        # Row layout: cid, name, type, notnull, dflt_value, pk.
        for _, column, sql_type, notnull, _, is_pkey in rows:
            # Determine column affinity from the declared type:
            sql_type = sql_type.strip().upper()
            if re.search('INT', sql_type):
                dtype = Dtype.int
            elif re.search('TEXT|CHAR|CLOB', sql_type):
                dtype = Dtype.string
            elif re.search('REAL|FLOA|DOUB', sql_type):
                dtype = Dtype.float
            else:
                # NUMERIC affinity, but also BLOB and untyped columns: infer
                # the data type from the sampled data instead.
                ser = self._sample_df[column]
                try:
                    dtype = infer_dtype(ser)
                except Exception:
                    warnings.warn(
                        f"Data type inference for column '{column}' in "
                        f"table '{self.name}' failed. Consider changing "
                        f"the data type of the column to use it within "
                        f"this table.")
                    continue

            # NOTE(review): SQLite permits NULLs in most non-INTEGER primary
            # key columns unless declared NOT NULL; primary keys are treated
            # as non-nullable here regardless.
            source_column = SourceColumn(
                name=column,
                dtype=dtype,
                is_primary_key=bool(is_pkey),
                is_unique_key=False,
                is_nullable=not bool(is_pkey) and not bool(notnull),
            )
            source_columns.append(source_column)

        return source_columns

    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
        r"""Reads the foreign keys of this table via
        ``PRAGMA foreign_key_list``.

        Foreign keys declared as ``REFERENCES parent`` (without a column list)
        are reported with an empty ``to`` column; they are resolved against
        the primary key of the parent table. Foreign keys that cannot be
        resolved this way are skipped.
        """
        source_fkeys: List[SourceForeignKey] = []
        with self._connection.cursor() as cursor:
            sql = f"PRAGMA foreign_key_list({quote_ident(self.name)})"
            cursor.execute(sql)
            rows = cursor.fetchall()

            # Row layout: id, seq, table, from, to, on_update, on_delete,
            # match.
            for _, seq, dst_table, fkey, pkey, _, _, _ in rows:
                if not pkey:
                    # The referenced column is the `seq`-th (0-based) column
                    # of the parent's primary key, whose `pk` value in
                    # `PRAGMA table_info` is its 1-based position:
                    cursor.execute(
                        f"PRAGMA table_info({quote_ident(dst_table)})")
                    pkey_dict = {
                        pk: name
                        for _, name, _, _, _, pk in cursor.fetchall()
                        if pk > 0
                    }
                    pkey = pkey_dict.get(seq + 1)
                    if pkey is None:
                        continue
                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
        return source_fkeys

    def _get_sample_df(self) -> pd.DataFrame:
        r"""Fetches (up to) the first 1000 rows of this table."""
        # NOTE(review): `rowid` does not exist for `WITHOUT ROWID` tables or
        # views, for which this query fails.
        with self._connection.cursor() as cursor:
            sql = (f"SELECT * FROM {quote_ident(self.name)} "
                   f"ORDER BY rowid LIMIT 1000")
            cursor.execute(sql)
            table = cursor.fetch_arrow_table()
        return table.to_pandas(types_mapper=pd.ArrowDtype)

    def _get_num_rows(self) -> Optional[int]:
        return None
@@ -0,0 +1,23 @@
1
+ from kumoapi.common import StrEnum
2
+
3
+
4
class DataBackend(StrEnum):
    r"""The kind of data source a table is backed by."""
    LOCAL = 'local'
    SQLITE = 'sqlite'  # Used by `SQLiteTable`.
    SNOWFLAKE = 'snowflake'  # Used by `SnowTable`.
8
+
9
+
10
+ from .source import SourceColumn, SourceForeignKey # noqa: E402
11
+ from .column import Column # noqa: E402
12
+ from .table import Table # noqa: E402
13
+ from .sampler import SamplerOutput, Sampler # noqa: E402
14
+
15
# Public API of the backend-agnostic base package.
__all__ = [
    'DataBackend',
    'SourceColumn', 'SourceForeignKey',
    'Column', 'Table',
    'SamplerOutput', 'Sampler',
]
+ ]
@@ -0,0 +1,66 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+
4
+ from kumoapi.typing import Dtype, Stype
5
+
6
+
7
@dataclass(init=False, repr=False, eq=False)
class Column:
    r"""A column of a table together with its semantic type.

    The semantic type :obj:`stype` may be reassigned after construction.
    Every assignment is validated against the column's data type and its
    role (primary key, time column or end time column).

    Args:
        name: The name of the column.
        dtype: The data type of the column.
        stype: The semantic type of the column.
        is_primary_key: Whether this column is the primary key.
        is_time_column: Whether this column is the time column.
        is_end_time_column: Whether this column is the end time column.

    Raises:
        ValueError: If ``stype`` is incompatible with ``dtype`` or with the
            role of the column.
    """
    stype: Stype

    def __init__(
        self,
        name: str,
        dtype: Dtype,
        stype: Stype,
        is_primary_key: bool = False,
        is_time_column: bool = False,
        is_end_time_column: bool = False,
    ) -> None:
        self._name = name
        self._dtype = Dtype(dtype)
        self._is_primary_key = is_primary_key
        self._is_time_column = is_time_column
        self._is_end_time_column = is_end_time_column
        # Set last, since its validation in `__setattr__` reads the
        # attributes above:
        self.stype = Stype(stype)

    @property
    def name(self) -> str:
        return self._name

    @property
    def dtype(self) -> Dtype:
        return self._dtype

    def __setattr__(self, key: str, val: Any) -> None:
        if key == 'stype':
            if isinstance(val, str):
                val = Stype(val)
            assert isinstance(val, Stype)
            if not val.supports_dtype(self.dtype):
                raise ValueError(f"Column '{self.name}' received an "
                                 f"incompatible semantic type (got "
                                 f"dtype='{self.dtype}' and stype='{val}')")
            if self._is_primary_key and val != Stype.ID:
                raise ValueError(f"Primary key '{self.name}' must have 'ID' "
                                 f"semantic type (got '{val}')")
            if self._is_time_column and val != Stype.timestamp:
                raise ValueError(f"Time column '{self.name}' must have "
                                 f"'timestamp' semantic type (got '{val}')")
            if self._is_end_time_column and val != Stype.timestamp:
                raise ValueError(f"End time column '{self.name}' must have "
                                 f"'timestamp' semantic type (got '{val}')")

        super().__setattr__(key, val)

    def __hash__(self) -> int:
        return hash((self.name, self.stype, self.dtype))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Column):
            return False
        # Compare the fields directly; comparing hashes would report
        # distinct columns as equal on a hash collision.
        return ((self.name, self.stype, self.dtype) ==
                (other.name, other.stype, other.dtype))

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(name={self.name}, '
                f'stype={self.stype}, dtype={self.dtype})')