kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
from kumoapi.data_source import (
|
|
5
|
+
CreateConnectorArgs,
|
|
6
|
+
DataSourceType,
|
|
7
|
+
GlueConnectorResourceConfig,
|
|
8
|
+
)
|
|
9
|
+
from kumoapi.source_table import GlueSourceTableRequest
|
|
10
|
+
from typing_extensions import Self, override
|
|
11
|
+
|
|
12
|
+
from kumoai import global_state
|
|
13
|
+
from kumoai.connector import Connector
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
_DEFAULT_NAME = 'glue_connector'
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class GlueConnector(Connector):
    r"""Defines a connector to a table stored in AWS Glue catalog. Currently,
    only supports tables in partitioned parquet format. Authenticated via IAM
    permissions on Glue catalog and data location in S3.

    .. code-block:: python

        import kumoai
        connector = kumoai.GlueConnector(
            name="...",
            account="...",
            region="...",
            database="...",
        )

        # List all tables:
        print(connector.table_names())  # Returns: ['articles', 'customers', 'users']

        # Check whether a table is present:
        assert "articles" in connector

        # Fetch a source table (both approaches are equivalent):
        source_table = connector["articles"]
        source_table = connector.table("articles")

    Args:
        name: The name of the connector.
        account: The account of the Glue catalog.
        region: The region of the Glue catalog.
        database: The name of the database in the Glue Catalog

    Raises:
        ValueError: If Kumo is running in Snowpark container services and a
            database is given; Glue connectors are rejected there.
    """  # noqa

    def __init__(
        self,
        name: str,
        account: str,
        region: str,
        database: str,
        _bypass_creation: bool = False,  # INTERNAL ONLY.
    ) -> None:
        # Run the base-class initializer first so that any state the
        # `Connector` base sets up is present on this instance as well.
        super().__init__()

        self._name = name
        self.account = account
        self.region = region
        self.database = database
        if global_state.is_spcs and database is not None:
            raise ValueError(
                "Glue connectors are not supported when running Kumo in "
                "Snowpark container services. Please use a Snowflake "
                "connector instead.")

        if _bypass_creation:
            # The connector is assumed to exist already (see `get_by_name`),
            # so nothing is persisted:
            return

        self._create_connector()

    @override
    @property
    def name(self) -> str:
        r"""Returns the name of this connector."""
        return self._name

    @override
    @property
    def source_type(self) -> DataSourceType:
        r"""Returns the data source type of this connector, which is always
        :obj:`DataSourceType.GLUE`.
        """
        return DataSourceType.GLUE

    @override
    def _source_table_request(
        self,
        table_names: List[str],
    ) -> GlueSourceTableRequest:
        r"""Builds the source table request for :obj:`table_names`, using
        the connector name as the connector identifier.
        """
        return GlueSourceTableRequest(
            connector_id=self.name,
            table_names=table_names,
        )

    def _create_connector(self) -> None:
        r"""Creates and persists a Glue connector in the REST DB.
        Currently only intended for internal use.


        Raises:
            RuntimeError: if connector creation failed
        """
        args = CreateConnectorArgs(
            config=GlueConnectorResourceConfig(
                name=self.name,
                account=self.account,
                region=self.region,
                database=self.database,
            ), )
        global_state.client.connector_api.create_if_not_exist(args)

    def _delete_connector(self) -> None:
        r"""Deletes a connector in the REST DB. Only intended for internal
        use.
        """
        global_state.client.connector_api.delete_if_exists(self.name)

    # Class properties ########################################################

    @classmethod
    def get_by_name(cls, name: str) -> Self:
        r"""Returns an instance of a named Glue Connector, including those created
        in the Kumo UI.

        Args:
            name: The name of the existing connector.

        Raises:
            ValueError: If no stored connector with this name exists.

        Example:
            >>> import kumoai
            >>> connector = kumoai.GlueConnector.get_by_name("name")  # doctest: +SKIP # noqa: E501
        """
        api = global_state.client.connector_api
        resp = api.get(name)
        if resp is None:
            raise ValueError(
                f"There does not exist an existing stored connector with name "
                f"{name}.")
        config = resp.config
        assert isinstance(config, GlueConnectorResourceConfig)
        # The connector already exists remotely, so skip creation:
        return cls(
            name=config.name,
            account=config.account,
            region=config.region,
            database=config.database,
            _bypass_creation=True,
        )

    @override
    def __repr__(self) -> str:
        account_name = f"\"{self.account}\"" if self.account else "None"
        region_name = f"\"{self.region}\"" if self.region else "None"
        database_name = f"\"{self.database}\"" if self.database else "None"
        return (f'{self.__class__.__name__}(account={account_name}, '
                f'region={region_name}, database={database_name})')
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
from kumoapi.data_source import DataSourceType, FileConnectorResourceConfig
|
|
5
|
+
from kumoapi.source_table import (
|
|
6
|
+
S3SourceTableRequest,
|
|
7
|
+
SourceTableConfigRequest,
|
|
8
|
+
SourceTableConfigResponse,
|
|
9
|
+
SourceTableListRequest,
|
|
10
|
+
SourceTableValidateRequest,
|
|
11
|
+
SourceTableValidateResponse,
|
|
12
|
+
)
|
|
13
|
+
from typing_extensions import Self, override
|
|
14
|
+
|
|
15
|
+
from kumoai import global_state
|
|
16
|
+
from kumoai.connector import Connector
|
|
17
|
+
from kumoai.connector.source_table import SourceTable
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
_DEFAULT_NAME = 's3_connector'
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class S3Connector(Connector):
    r"""Defines a connector to a table stored as a file (or partitioned
    set of files) on the Amazon `S3 <https://aws.amazon.com/s3/>`__ object
    store. Any table behind an S3 bucket accessible by the shared external IAM
    role can be accessed through this connector.

    .. code-block:: python

        import kumoai
        connector = kumoai.S3Connector(root_dir="s3://...")  # an S3 path.

        # List all tables:
        print(connector.table_names())  # Returns: ['articles', 'customers', 'users']

        # Check whether a table is present:
        assert "articles" in connector

        # Fetch a source table (both approaches are equivalent):
        source_table = connector["articles"]
        source_table = connector.table("articles")

    Args:
        root_dir: The root directory of this connector. If provided, the root
            directory is used as a prefix for tables in this connector. If not
            provided, all tables must be specified by their full S3 paths.

    Raises:
        ValueError: If Kumo is running in Snowpark container services and
            :obj:`root_dir` starts with ``s3://``.
    """  # noqa

    def __init__(
        self,
        root_dir: Optional[str] = None,
        _connector_id: Optional[str] = None,
    ) -> None:
        # Run the base-class initializer first so that any state the
        # `Connector` base sets up is present on this instance as well.
        super().__init__()

        if _connector_id is not None:
            # UI S3Connector, named:
            self._connector_id = _connector_id
            self.root_dir = None
            return

        self._connector_id = _DEFAULT_NAME
        if root_dir is not None:
            # Remove trailing / to be consistent with boto s3
            root_dir = root_dir.rstrip('/')
        self.root_dir = root_dir
        if global_state.is_spcs and root_dir is not None \
                and root_dir.startswith('s3://'):
            raise ValueError(
                "S3 connectors are not supported when running Kumo in "
                "Snowpark container services. Please use a Snowflake "
                "connector instead.")

    @override
    @property
    def name(self) -> str:
        r"""Not supported by :class:`S3Connector`; returns an internal
        specifier.
        """
        return self._connector_id

    @override
    @property
    def source_type(self) -> DataSourceType:
        r"""Returns the data source type of this connector, which is always
        :obj:`DataSourceType.S3`.
        """
        return DataSourceType.S3

    @override
    def _source_table_request(
        self,
        table_names: List[str],
    ) -> S3SourceTableRequest:
        r"""Builds the source table request for :obj:`table_names`.

        For an unnamed connector without a root directory, every entry of
        :obj:`table_names` is treated as a full path: the root directory is
        taken from the first entry and each entry is reduced to its final
        path component. The list passed in by the caller is left untouched.

        Raises:
            ValueError: If the paths do not all share the same root
                directory, or if a path is not a valid S3 URI.
        """
        root_dir = self.root_dir
        if not root_dir and self.name == _DEFAULT_NAME:
            # Handle None root directories (table name is a path):
            table_path = S3URI(table_names[0]).validate()
            root_dir = table_path.root_dir
            # Collect into a fresh list instead of overwriting the caller's
            # list element by element:
            object_names: List[str] = []
            for table_name in table_names:
                uri = S3URI(table_name)
                if uri.root_dir != root_dir:
                    # TODO(manan): fix
                    raise ValueError(
                        f"Please ensure that all of your tables are behind "
                        f"the same root directory ({root_dir}).")
                object_names.append(uri.object_name)
            table_names = object_names

        # Named connectors are addressed by identifier and send an empty
        # root directory; unnamed ones send the root directory only:
        connector_id = self.name if self.name != _DEFAULT_NAME else None
        root_dir = root_dir if self.name == _DEFAULT_NAME else ""

        # TODO(manan): file type?
        return S3SourceTableRequest(
            s3_root_dir=root_dir,
            connector_id=connector_id,
            table_names=table_names,
        )

    @override
    def table(self, name: str) -> SourceTable:
        r"""Returns a :class:`~kumoai.connector.SourceTable` object
        corresponding to a source table on Amazon S3.

        Args:
            name: The name of the table on S3. If :obj:`root_dir` is provided,
                the path will be specified as :obj:`root_dir/name`. If
                :obj:`root_dir` is not provided, the name should be the full
                path (e.g. starting with ``s3://``).

        Raises:
            :class:`ValueError`: if ``name`` does not exist in the backing
                connector.
        """
        # NOTE only overridden for documentation purposes.
        return super().table(name)

    @override
    def _list_tables(self) -> List[str]:
        r"""Lists the tables of this connector.

        Raises:
            ValueError: If the connector is unnamed and has no root
                directory, since there is nothing to list under.
        """
        connector_id = self.name if self.name != _DEFAULT_NAME else None
        root_dir = self.root_dir if self.name == _DEFAULT_NAME else None
        if root_dir is None and connector_id is None:
            raise ValueError(
                "Listing tables without a specified root directory is not "
                "supported. Please specify a root directory to continue; "
                "alternatively, please access individual tables with their "
                "full S3 paths.")

        req = SourceTableListRequest(connector_id=connector_id,
                                     root_dir=root_dir,
                                     source_type=DataSourceType.S3)
        return global_state.client.source_table_api.list_tables(req)

    @override
    def _get_table_config(self, table_name: str) -> SourceTableConfigResponse:
        r"""Fetches the configuration of :obj:`table_name`.

        For an unnamed connector without a root directory,
        :obj:`table_name` is treated as a full path and split into root
        directory and final path component.
        """
        root_dir = self.root_dir
        if not root_dir and self.name == _DEFAULT_NAME:
            # Handle None root directories (table name is a path):
            table_path = S3URI(table_name).validate()
            root_dir = table_path.root_dir
            table_name = table_path.object_name

        connector_id = self.name if self.name != _DEFAULT_NAME else None
        root_dir = root_dir if self.name == _DEFAULT_NAME else None

        req = SourceTableConfigRequest(
            connector_id=connector_id,
            root_dir=root_dir,
            table_name=table_name,
            source_type=self.source_type,
        )
        return global_state.client.source_table_api.get_table_config(req)

    @override
    def _validate_table(self, table_name: str) -> SourceTableValidateResponse:
        r"""Validates :obj:`table_name` and caches successful results.

        Tables that validated successfully before are answered from the
        per-connector cache without issuing another request.
        """
        if table_name in self._validated_tables:
            return SourceTableValidateResponse(is_valid=True, msg='')

        if self.name == _DEFAULT_NAME:
            # For S3 connector without name, pass root_dir to validate table -
            # need to infer from table name if root_dir is not provided
            if self.root_dir:
                req = SourceTableValidateRequest(
                    connector_id=None,
                    table_name=table_name,
                    source_type=self.source_type,
                    root_dir=self.root_dir,
                )
            else:
                # Split on the last '/' after dropping trailing slashes, so
                # that "s3://bucket/dir/table/" yields table name "table"
                # rather than an empty string. A path without any '/' keeps
                # an empty root directory.
                req_root_dir, _, req_table_name = (
                    table_name.rstrip('/').rpartition('/'))

                req = SourceTableValidateRequest(
                    connector_id=None,
                    table_name=req_table_name,
                    source_type=self.source_type,
                    root_dir=req_root_dir,
                )
        else:
            req = SourceTableValidateRequest(connector_id=self.name,
                                             table_name=table_name,
                                             source_type=self.source_type)

        ret = global_state.client.source_table_api.validate_table(req)
        # Cache the result for the whole session.
        if ret.is_valid:
            self._validated_tables.add(table_name)
        return ret

    # Class properties ########################################################

    @classmethod
    def get_by_name(cls, name: str) -> Self:
        r"""Returns an instance of a named S3 Connector, created in the Kumo UI.

        .. note::
            Named S3 connectors are read-only: if you would like to modify the
            root directory, please do so from the UI.

        Args:
            name: The name of the existing connector.

        Raises:
            ValueError: If no stored connector with this name exists.

        Example:
            >>> import kumoai
            >>> connector = kumoai.S3Connector.get_by_name("name")  # doctest: +SKIP # noqa: E501
        """
        api = global_state.client.connector_api
        resp = api.get(name)
        if resp is None:
            raise ValueError(
                f"There does not exist an existing stored connector with name "
                f"{name}.")
        config = resp.config
        assert isinstance(config, FileConnectorResourceConfig)
        return cls(
            root_dir=None,
            _connector_id=config.name,
        )

    @override
    def __repr__(self) -> str:
        if self.name != _DEFAULT_NAME:
            return f'{self.__class__.__name__}(name={self.name})'
        root_dir_name = f"\"{self.root_dir}\"" if self.root_dir else "None"
        return f'{self.__class__.__name__}(root_dir={root_dir_name})'
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class S3URI:
    r"""Small helper that wraps an S3 path and splits it into its parent
    directory and its final path component.

    A single trailing ``/`` is dropped from the path on construction.
    """
    def __init__(self, uri: str):
        # Drop one trailing slash, if present:
        self.uri: str = uri[:-1] if uri.endswith('/') else uri

    @property
    def is_valid(self) -> bool:
        r"""Whether the path is accepted: always :obj:`True` when running in
        SPCS, otherwise only if it starts with ``s3://``.
        """
        # TODO(zeyuan): For SPCS, the path can be a local filesystem path
        # For train/pred table.
        if not global_state.is_spcs:
            # TODO(manan): implement more checks...
            return self.uri.startswith("s3://")
        return True

    def validate(self) -> Self:
        r"""Returns :obj:`self` if the path is valid.

        Raises:
            ValueError: If the path is not a valid S3 URI.
        """
        if self.is_valid:
            return self
        raise ValueError(f"Path {self.uri} is not a valid S3 URI.")

    @property
    def root_dir(self) -> str:
        r"""Everything before the last ``/`` of the validated path."""
        pieces = self.validate().uri.rsplit('/', 1)
        return pieces[0]

    @property
    def object_name(self) -> str:
        r"""Everything after the last ``/`` of the validated path."""
        pieces = self.validate().uri.rsplit('/', 1)
        return pieces[1]

    # Class properties ########################################################

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}(uri={self.uri}, valid={self.is_valid})'
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from kumoapi.data_source import (
|
|
5
|
+
CreateConnectorArgs,
|
|
6
|
+
DataSourceType,
|
|
7
|
+
KeyPair,
|
|
8
|
+
SnowflakeConnectorResourceConfig,
|
|
9
|
+
UsernamePassword,
|
|
10
|
+
)
|
|
11
|
+
from kumoapi.source_table import SnowflakeSourceTableRequest
|
|
12
|
+
from typing_extensions import Self, override
|
|
13
|
+
|
|
14
|
+
from kumoai import global_state
|
|
15
|
+
from kumoai.connector import Connector
|
|
16
|
+
|
|
17
|
+
_ENV_SNOWFLAKE_USER = 'SNOWFLAKE_USER'
|
|
18
|
+
_ENV_SNOWFLAKE_PASSWORD = 'SNOWFLAKE_PASSWORD'
|
|
19
|
+
_ENV_SNOWFLAKE_PRIVATE_KEY = 'SNOWFLAKE_PRIVATE_KEY'
|
|
20
|
+
_ENV_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE = 'SNOWFLAKE_PRIVATE_KEY_PASSPHRASE'
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SnowflakeConnector(Connector):
    r"""Establishes a connection to a `Snowflake <https://www.snowflake.com/>`_
    database.

    Multiple methods of authentication are available. Username/password
    authentication is supported either via environment variables
    (``SNOWFLAKE_USER`` and ``SNOWFLAKE_PASSWORD``) or via keys in the
    credentials dictionary (:obj:`user` and :obj:`password`).

    A private key can be passed in the same way, either via the
    ``SNOWFLAKE_PRIVATE_KEY`` (and optionally
    ``SNOWFLAKE_PRIVATE_KEY_PASSPHRASE``) environment variables or via the
    :obj:`private_key` (and :obj:`private_key_passphrase`) credential keys.
    If both a private key and a password are available, the private key is
    used. If neither is available, the connector is created without
    credentials.

    .. note::
        Key-pair authentication is coming soon; please contact your Kumo POC if
        you need access.

    .. code-block:: python

        import kumoai

        # Either pass `credentials=dict(user=..., password=...)` or set the
        # 'SNOWFLAKE_USER' and 'SNOWFLAKE_PASSWORD' environment variables:
        connector = kumoai.SnowflakeConnector(
            name="<connector_name>",
            account="<snowflake_account_name>",
            warehouse="<snowflake_warehouse_name>",
            database="<snowflake_database_name>",
            schema_name="<snowflake_schema_name>",
            credentials=credentials,
        )

        # List all tables:
        print(connector.table_names())

        # Check whether a table is present:
        assert "articles" in connector

        # Fetch a source table (both approaches are equivalent):
        source_table = connector["articles"]
        source_table = connector.table("articles")

    Args:
        name: The name of the connector.
        account: The account name.
        warehouse: The name of the warehouse.
        database: The name of the database.
        schema_name: The name of the schema.
        credentials: The username and password corresponding to this Snowflake
            account, if not provided as environment variables.

    Raises:
        ValueError: If a password or private key is available but no
            username is.
    """
    def __init__(
        self,
        name: str,
        account: str,
        warehouse: str,
        database: str,
        schema_name: str,
        credentials: Optional[Dict[str, str]] = None,
        _bypass_creation: bool = False,  # INTERNAL ONLY.
    ) -> None:
        super().__init__()

        self._name = name
        self.account = account
        self.warehouse = warehouse

        # Snowflake DBs and schemas are all in upper-case:
        self.database = database.upper()
        self.schema_name = schema_name.upper()

        if _bypass_creation:
            # TODO(manan, siyang): validate that this connector actually exists
            # in the REST DB:
            return

        # Fully specify credentials, create Kumo connector. Explicit
        # credentials win over globally registered ones; individual keys
        # missing from the dictionary fall back to environment variables:
        credentials = credentials or global_state._snowflake_credentials or {}

        user = credentials.get("user", os.getenv(_ENV_SNOWFLAKE_USER))
        password = credentials.get("password",
                                   os.getenv(_ENV_SNOWFLAKE_PASSWORD))
        private_key = credentials.get("private_key",
                                      os.getenv(_ENV_SNOWFLAKE_PRIVATE_KEY))
        private_key_passphrase = credentials.get(
            "private_key_passphrase",
            os.getenv(_ENV_SNOWFLAKE_PRIVATE_KEY_PASSPHRASE))

        if not password and not private_key:
            # No secret available at all: create the connector without
            # credentials.
            self._create_native_connector()
            return

        # From here on a password or private key is guaranteed to be
        # present, so only the username can still be missing:
        if user is None:
            raise ValueError(
                f"Please pass a valid username to create a Snowflake "
                f"connector. You can do so either via the 'credentials' "
                f"argument or the {_ENV_SNOWFLAKE_USER} environment variable.")

        # Don't pass unused credential fields so that _create_connector can
        # decide which credential class (KeyPair or UsernamePassword) to use
        credentials_args: Dict[str, str] = {"user": user}
        if private_key:
            credentials_args["private_key"] = private_key
            if private_key_passphrase:
                credentials_args[
                    "private_key_passphrase"] = private_key_passphrase
        else:
            credentials_args["password"] = password

        self._create_connector(credentials_args)

    @classmethod
    def get_by_name(cls, name: str) -> Self:
        r"""Returns an instance of a named Snowflake Connector, including
        those created in the Kumo UI.

        Args:
            name: The name of the existing connector.

        Raises:
            ValueError: If no stored connector with this name exists.

        Example:
            >>> import kumoai
            >>> connector = kumoai.SnowflakeConnector.get_by_name("name")  # doctest: +SKIP # noqa: E501
        """
        api = global_state.client.connector_api
        resp = api.get(name)
        if resp is None:
            raise ValueError(
                f"There does not exist an existing stored connector with name "
                f"{name}.")
        config = resp.config
        assert isinstance(config, SnowflakeConnectorResourceConfig)
        # The connector already exists remotely, so skip creation:
        return cls(
            name=config.name,
            account=config.account,
            warehouse=config.warehouse,
            database=config.database,
            schema_name=config.schema_name,
            credentials=None,
            _bypass_creation=True,
        )

    @override
    @property
    def name(self) -> str:
        r"""Returns the name of this connector."""
        return self._name

    @override
    @property
    def source_type(self) -> DataSourceType:
        r"""Returns the data source type of this connector, which is always
        :obj:`DataSourceType.SNOWFLAKE`.
        """
        return DataSourceType.SNOWFLAKE

    @override
    def _source_table_request(
        self,
        table_names: List[str],
    ) -> SnowflakeSourceTableRequest:
        r"""Builds the source table request for :obj:`table_names`, using
        the connector name as the connector identifier.
        """
        return SnowflakeSourceTableRequest(
            connector_id=self.name,
            table_names=table_names,
        )

    def _create_connector(self, credentials: Dict[str, str]) -> None:
        r"""Creates and persists a Snowflake connector in the REST DB.
        Currently only intended for internal use.

        Args:
            credentials: Fully-specified credentials containing the username
                and either the password or the private key (with an optional
                passphrase) for the Snowflake connector.

        Raises:
            RuntimeError: if connector creation failed
        """
        # TODO(manan, siyang): consider avoiding connector persistence in the
        # REST DB, instead moving towards global connectors. For now, to get
        # a Snowflake experience working smoothly, using the old interface:
        # A present password selects username/password authentication;
        # otherwise the private key is used:
        if credentials.get("password") is not None:
            auth = UsernamePassword(
                username=credentials["user"],
                password=credentials["password"],
            )
        else:
            auth = KeyPair(
                user=credentials["user"],
                private_key=credentials["private_key"],
                private_key_passphrase=credentials.get(
                    "private_key_passphrase"),
            )
        args = CreateConnectorArgs(
            config=SnowflakeConnectorResourceConfig(
                name=self.name,
                account=self.account,
                warehouse=self.warehouse,
                database=self.database,
                schema_name=self.schema_name,
            ),
            credentials=auth,
        )
        global_state.client.connector_api.create_if_not_exist(args)

    def _create_native_connector(self) -> None:
        r"""Creates and persists a Snowflake connector without any
        credentials attached. Only intended for internal use.
        """
        args = CreateConnectorArgs(
            config=SnowflakeConnectorResourceConfig(
                name=self.name,
                account=self.account,
                warehouse=self.warehouse,
                database=self.database,
                schema_name=self.schema_name,
            ),
            credentials=None,
        )
        global_state.client.connector_api.create_if_not_exist(args)

    def _delete_connector(self) -> None:
        r"""Deletes a connector in the REST DB. Only intended for internal
        use.
        """
        global_state.client.connector_api.delete_if_exists(self.name)

    # Class properties ########################################################

    @override
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}'
                f'(account=\"{self.account}\", database=\"{self.database}\", '
                f'schema=\"{self.schema_name}\")')
|