kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
- kumoai/experimental/rfm/base/__init__.py +23 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
- kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +224 -167
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +39 -23
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/table.py
@@ -0,0 +1,135 @@
import re
from typing import List, Optional, Sequence, cast

import pandas as pd
from kumoapi.typing import Dtype

from kumoai.experimental.rfm.backend.snow import Connection
from kumoai.experimental.rfm.base import (
    DataBackend,
    SourceColumn,
    SourceForeignKey,
    Table,
)
from kumoai.utils import quote_ident


class SnowTable(Table):
    r"""A table backed by a :class:`snowflake` database.

    Args:
        connection: The connection to a :class:`snowflake` database.
        name: The name of this table.
        database: The database.
        schema: The schema.
        columns: The selected columns of this table.
        primary_key: The name of the primary key of this table, if it exists.
        time_column: The name of the time column of this table, if it exists.
        end_time_column: The name of the end time column of this table, if it
            exists.
    """
    def __init__(
        self,
        connection: Connection,
        name: str,
        database: str | None = None,
        schema: str | None = None,
        columns: Optional[Sequence[str]] = None,
        primary_key: Optional[str] = None,
        time_column: Optional[str] = None,
        end_time_column: Optional[str] = None,
    ) -> None:

        if database is not None and schema is None:
            raise ValueError(f"Missing 'schema' for table '{name}' in "
                             f"database '{database}'")

        self._connection = connection
        self._database = database
        self._schema = schema

        super().__init__(
            name=name,
            columns=columns,
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
        )

    @property
    def backend(self) -> DataBackend:
        return cast(DataBackend, DataBackend.SNOWFLAKE)

    @property
    def fqn(self) -> str:
        r"""The fully-qualified quoted table name."""
        names: List[str] = []
        if self._database is not None:
            names.append(quote_ident(self._database))
        if self._schema is not None:
            names.append(quote_ident(self._schema))
        return '.'.join(names + [quote_ident(self._name)])

    def _get_source_columns(self) -> List[SourceColumn]:
        source_columns: List[SourceColumn] = []
        with self._connection.cursor() as cursor:
            try:
                sql = f"DESCRIBE TABLE {self.fqn}"
                cursor.execute(sql)
            except Exception as e:
                names: list[str] = []
                if self._database is not None:
                    names.append(self._database)
                if self._schema is not None:
                    names.append(self._schema)
                name = '.'.join(names + [self._name])
                raise ValueError(f"Table '{name}' does not exist") from e

            for row in cursor.fetchall():
                column, type, _, null, _, is_pkey, is_unique = row[:7]

                type = type.strip().upper()
                if type.startswith('NUMBER'):
                    dtype = Dtype.int
                elif type.startswith('VARCHAR'):
                    dtype = Dtype.string
                elif type == 'FLOAT':
                    dtype = Dtype.float
                elif type == 'BOOLEAN':
                    dtype = Dtype.bool
                elif re.search('DATE|TIMESTAMP', type):
                    dtype = Dtype.date
                else:
                    continue

                source_column = SourceColumn(
                    name=column,
                    dtype=dtype,
                    is_primary_key=is_pkey.strip().upper() == 'Y',
                    is_unique_key=is_unique.strip().upper() == 'Y',
                    is_nullable=null.strip().upper() == 'Y',
                )
                source_columns.append(source_column)

        return source_columns

    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
        source_fkeys: List[SourceForeignKey] = []
        with self._connection.cursor() as cursor:
            sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
            cursor.execute(sql)
            for row in cursor.fetchall():
                _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
        return source_fkeys

    def _get_sample_df(self) -> pd.DataFrame:
        with self._connection.cursor() as cursor:
            columns = [quote_ident(col) for col in self._source_column_dict]
            sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
            cursor.execute(sql)
            table = cursor.fetch_arrow_all()
            return table.to_pandas(types_mapper=pd.ArrowDtype)

    def _get_num_rows(self) -> Optional[int]:
        return None
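A minimal construction sketch for the class above. The table, database, schema, and column names are hypothetical, and obtaining a `Connection` is left to `backend/snow/__init__.py`, which is not shown in this diff excerpt; the quoted output shown assumes `quote_ident` double-quotes identifiers.

from kumoai.experimental.rfm.backend.snow import Connection
from kumoai.experimental.rfm.backend.snow.table import SnowTable

def describe_users(conn: Connection) -> None:
    # `database` is optional, but requires `schema` when set (see __init__).
    table = SnowTable(
        connection=conn,
        name='USERS',
        database='ANALYTICS',
        schema='PUBLIC',
        primary_key='USER_ID',
    )
    print(table.fqn)  # e.g. '"ANALYTICS"."PUBLIC"."USERS"'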
kumoai/experimental/rfm/backend/sqlite/__init__.py
@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Any, TypeAlias, Union

try:
    import adbc_driver_sqlite.dbapi as adbc
except ImportError:
    raise ImportError("No module named 'adbc_driver_sqlite'. Please install "
                      "Kumo SDK with the 'sqlite' extension via "
                      "`pip install kumoai[sqlite]`.")

Connection: TypeAlias = adbc.AdbcSqliteConnection


def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
    r"""Opens a connection to a :class:`sqlite` database.

    Args:
        uri: The path to the database file to be opened.
        kwargs: Additional connection arguments, following the
            :class:`adbc_driver_sqlite` protocol.
    """
    return adbc.connect(uri, **kwargs)


from .table import SQLiteTable  # noqa: E402
from .sampler import SQLiteSampler  # noqa: E402

__all__ = [
    'connect',
    'Connection',
    'SQLiteTable',
    'SQLiteSampler',
]
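A short usage sketch of the helper above; the 'example.db' path is illustrative, and the context-manager usage assumes the ADBC DB-API objects support `with` (as the cursors used throughout this diff do).

from kumoai.experimental.rfm.backend.sqlite import connect

# Pass no `uri` for an in-memory database, per the adbc_driver_sqlite default.
with connect('example.db') as conn:
    with conn.cursor() as cursor:
        cursor.execute('SELECT 1')
        print(cursor.fetchone())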
kumoai/experimental/rfm/backend/sqlite/sampler.py
@@ -0,0 +1,112 @@
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
from kumoapi.pquery import ValidatedPredictiveQuery

from kumoai.experimental.rfm.backend.sqlite import SQLiteTable
from kumoai.experimental.rfm.base import Sampler, SamplerOutput
from kumoai.utils import ProgressLogger, quote_ident

if TYPE_CHECKING:
    from kumoai.experimental.rfm import Graph


class SQLiteSampler(Sampler):
    def __init__(
        self,
        graph: 'Graph',
        verbose: bool | ProgressLogger = True,
    ) -> None:
        super().__init__(graph=graph)

        for table in graph.tables.values():
            assert isinstance(table, SQLiteTable)
            self._connection = table._connection

        # TODO Check for indices being present.

    def _get_min_max_time_dict(
        self,
        table_names: list[str],
    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
        selects: list[str] = []
        for table_name in table_names:
            time_column = self.time_column_dict[table_name]
            select = (f"SELECT\n"
                      f"  ? as table_name,\n"
                      f"  MIN({quote_ident(time_column)}) as min_date,\n"
                      f"  MAX({quote_ident(time_column)}) as max_date\n"
                      f"FROM {quote_ident(table_name)}")
            selects.append(select)
        sql = "\nUNION ALL\n".join(selects)

        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
        with self._connection.cursor() as cursor:
            cursor.execute(sql, table_names)
            for table_name, _min, _max in cursor.fetchall():
                out_dict[table_name] = (
                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
                )
        return out_dict

    def _sample_subgraph(
        self,
        entity_table_name: str,
        entity_pkey: pd.Series,
        anchor_time: pd.Series | Literal['entity'],
        columns_dict: dict[str, set[str]],
        num_neighbors: list[int],
    ) -> SamplerOutput:
        raise NotImplementedError

    def _sample_entity_table(
        self,
        table_name: str,
        columns: set[str],
        num_rows: int,
        random_seed: int | None = None,
    ) -> pd.DataFrame:
        # NOTE SQLite does not natively support passing a `random_seed`.

        filters: list[str] = []
        primary_key = self.primary_key_dict[table_name]
        if self.source_table_dict[table_name][primary_key].is_nullable:
            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
        time_column = self.time_column_dict.get(table_name)
        if (time_column is not None and
                self.source_table_dict[table_name][time_column].is_nullable):
            filters.append(f" {quote_ident(time_column)} IS NOT NULL")

        # TODO Make this query more efficient - it does full table scan.
        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
               f"FROM {quote_ident(table_name)}")
        if len(filters) > 0:
            sql += f"\nWHERE{' AND'.join(filters)}"
        sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"

        with self._connection.cursor() as cursor:
            # NOTE This may return duplicate primary keys. This is okay.
            cursor.execute(sql)
            table = cursor.fetch_arrow_table()

        return table.to_pandas(types_mapper=pd.ArrowDtype)

    def _sample_target(
        self,
        query: ValidatedPredictiveQuery,
        entity_df: pd.DataFrame,
        train_index: np.ndarray,
        train_time: pd.Series,
        num_train_examples: int,
        test_index: np.ndarray,
        test_time: pd.Series,
        num_test_examples: int,
        columns_dict: dict[str, set[str]],
        time_offset_dict: dict[
            tuple[str, str, str],
            tuple[pd.DateOffset | None, pd.DateOffset],
        ],
    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
        raise NotImplementedError
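For illustration, given a hypothetical `users` table with nullable primary key `user_id` and nullable time column `signup_time`, `_sample_entity_table(..., num_rows=1000)` renders SQL of the following shape (column order follows the unordered `columns` set, and the double-quoting assumes `quote_ident` quotes identifiers that way):

SELECT "user_id", "signup_time"
FROM "users"
WHERE "user_id" IS NOT NULL AND "signup_time" IS NOT NULL
ORDER BY RANDOM() LIMIT 1000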
kumoai/experimental/rfm/backend/sqlite/table.py
@@ -0,0 +1,115 @@
import re
import warnings
from typing import List, Optional, Sequence, cast

import pandas as pd
from kumoapi.typing import Dtype

from kumoai.experimental.rfm.backend.sqlite import Connection
from kumoai.experimental.rfm.base import (
    DataBackend,
    SourceColumn,
    SourceForeignKey,
    Table,
)
from kumoai.experimental.rfm.infer import infer_dtype
from kumoai.utils import quote_ident


class SQLiteTable(Table):
    r"""A table backed by a :class:`sqlite` database.

    Args:
        connection: The connection to a :class:`sqlite` database.
        name: The name of this table.
        columns: The selected columns of this table.
        primary_key: The name of the primary key of this table, if it exists.
        time_column: The name of the time column of this table, if it exists.
        end_time_column: The name of the end time column of this table, if it
            exists.
    """
    def __init__(
        self,
        connection: Connection,
        name: str,
        columns: Optional[Sequence[str]] = None,
        primary_key: Optional[str] = None,
        time_column: Optional[str] = None,
        end_time_column: Optional[str] = None,
    ) -> None:

        self._connection = connection

        super().__init__(
            name=name,
            columns=columns,
            primary_key=primary_key,
            time_column=time_column,
            end_time_column=end_time_column,
        )

    @property
    def backend(self) -> DataBackend:
        return cast(DataBackend, DataBackend.SQLITE)

    def _get_source_columns(self) -> List[SourceColumn]:
        source_columns: List[SourceColumn] = []
        with self._connection.cursor() as cursor:
            sql = f"PRAGMA table_info({quote_ident(self.name)})"
            cursor.execute(sql)
            rows = cursor.fetchall()

        if len(rows) == 0:
            raise ValueError(f"Table '{self.name}' does not exist")

        for _, column, type, notnull, _, is_pkey in rows:
            # Determine column affinity:
            type = type.strip().upper()
            if re.search('INT', type):
                dtype = Dtype.int
            elif re.search('TEXT|CHAR|CLOB', type):
                dtype = Dtype.string
            elif re.search('REAL|FLOA|DOUB', type):
                dtype = Dtype.float
            else:  # NUMERIC affinity.
                ser = self._sample_df[column]
                try:
                    dtype = infer_dtype(ser)
                except Exception:
                    warnings.warn(
                        f"Data type inference for column '{column}' in "
                        f"table '{self.name}' failed. Consider changing "
                        f"the data type of the column to use it within "
                        f"this table.")
                    continue

            source_column = SourceColumn(
                name=column,
                dtype=dtype,
                is_primary_key=bool(is_pkey),
                is_unique_key=False,
                is_nullable=not bool(is_pkey) and not bool(notnull),
            )
            source_columns.append(source_column)

        return source_columns

    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
        source_fkeys: List[SourceForeignKey] = []
        with self._connection.cursor() as cursor:
            sql = f"PRAGMA foreign_key_list({quote_ident(self.name)})"
            cursor.execute(sql)
            for _, _, dst_table, fkey, pkey, _, _, _ in cursor.fetchall():
                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
        return source_fkeys

    def _get_sample_df(self) -> pd.DataFrame:
        with self._connection.cursor() as cursor:
            sql = (f"SELECT * FROM {quote_ident(self.name)} "
                   f"ORDER BY rowid LIMIT 1000")
            cursor.execute(sql)
            table = cursor.fetch_arrow_table()
            return table.to_pandas(types_mapper=pd.ArrowDtype)

    def _get_num_rows(self) -> Optional[int]:
        return None
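A hedged construction sketch for `SQLiteTable`, reusing the `connect()` helper from `backend/sqlite/__init__.py` above; the database path and column names are illustrative.

from kumoai.experimental.rfm.backend.sqlite import SQLiteTable, connect

conn = connect('example.db')  # illustrative path
users = SQLiteTable(
    connection=conn,
    name='users',              # hypothetical table and column names
    primary_key='user_id',
    time_column='signup_time',
)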
kumoai/experimental/rfm/base/__init__.py
@@ -0,0 +1,23 @@
from kumoapi.common import StrEnum


class DataBackend(StrEnum):
    LOCAL = 'local'
    SQLITE = 'sqlite'
    SNOWFLAKE = 'snowflake'


from .source import SourceColumn, SourceForeignKey  # noqa: E402
from .column import Column  # noqa: E402
from .table import Table  # noqa: E402
from .sampler import SamplerOutput, Sampler  # noqa: E402

__all__ = [
    'DataBackend',
    'SourceColumn',
    'SourceForeignKey',
    'Column',
    'Table',
    'SamplerOutput',
    'Sampler',
]
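A quick sketch of `DataBackend`, assuming `kumoapi.common.StrEnum` follows the standard str-mixin enum semantics (members compare equal to their string values and can be looked up by value):

from kumoai.experimental.rfm.base import DataBackend

assert DataBackend.SQLITE == 'sqlite'
assert DataBackend('snowflake') is DataBackend.SNOWFLAKE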
kumoai/experimental/rfm/base/column.py
@@ -0,0 +1,66 @@
from dataclasses import dataclass
from typing import Any

from kumoapi.typing import Dtype, Stype


@dataclass(init=False, repr=False, eq=False)
class Column:
    stype: Stype

    def __init__(
        self,
        name: str,
        dtype: Dtype,
        stype: Stype,
        is_primary_key: bool = False,
        is_time_column: bool = False,
        is_end_time_column: bool = False,
    ) -> None:
        self._name = name
        self._dtype = Dtype(dtype)
        self._is_primary_key = is_primary_key
        self._is_time_column = is_time_column
        self._is_end_time_column = is_end_time_column
        self.stype = Stype(stype)

    @property
    def name(self) -> str:
        return self._name

    @property
    def dtype(self) -> Dtype:
        return self._dtype

    def __setattr__(self, key: str, val: Any) -> None:
        if key == 'stype':
            if isinstance(val, str):
                val = Stype(val)
            assert isinstance(val, Stype)
            if not val.supports_dtype(self.dtype):
                raise ValueError(f"Column '{self.name}' received an "
                                 f"incompatible semantic type (got "
                                 f"dtype='{self.dtype}' and stype='{val}')")
            if self._is_primary_key and val != Stype.ID:
                raise ValueError(f"Primary key '{self.name}' must have 'ID' "
                                 f"semantic type (got '{val}')")
            if self._is_time_column and val != Stype.timestamp:
                raise ValueError(f"Time column '{self.name}' must have "
                                 f"'timestamp' semantic type (got '{val}')")
            if self._is_end_time_column and val != Stype.timestamp:
                raise ValueError(f"End time column '{self.name}' must have "
                                 f"'timestamp' semantic type (got '{val}')")

        super().__setattr__(key, val)

    def __hash__(self) -> int:
        return hash((self.name, self.stype, self.dtype))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Column):
            return False
        return hash(self) == hash(other)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(name={self.name}, '
                f'stype={self.stype}, dtype={self.dtype})')