kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release: this version of kumoai might be problematic.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/connector/source_table.py
@@ -0,0 +1,471 @@
+import asyncio
+import concurrent.futures
+import logging
+from io import StringIO
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Union,
+    overload,
+)
+
+import pandas as pd
+from kumoapi.jobs import JobStatus
+from kumoapi.source_table import (
+    DataSourceType,
+    LLMRequest,
+    SourceColumn,
+    SourceTableDataResponse,
+    SourceTableType,
+)
+from kumoapi.table import TableDefinition
+from typing_extensions import override
+
+from kumoai import global_state
+from kumoai.client.jobs import LLMJobId
+from kumoai.exceptions import HTTPException
+from kumoai.futures import KumoFuture, create_future
+
+if TYPE_CHECKING:
+    from kumoai.connector import Connector
+
+logger = logging.getLogger(__name__)
+
+_DEFAULT_INTERVAL_S = 20.0
+
+
+class SourceTable:
+    r"""A source table is a reference to a table stored behind a backing
+    :class:`~kumoai.connector.Connector`. It can be used to examine basic
+    information about raw data connected to Kumo, including a sample of the
+    table's rows, basic statistics, and column data type information.
+
+    Once you are ready to use a table as part of a
+    :class:`~kumoai.graph.Graph`, you may create a :class:`~kumoai.graph.Table`
+    object from this source table, which includes additional specification
+    (such as column semantic types and column constraints).
+
+    Args:
+        name: The name of this table in the backing connector.
+        connector: The connector containing this table.
+
+    .. note::
+        Source tables can also be augmented with large language models to
+        introduce contextual embeddings for language features. To do so,
+        please consult :meth:`~kumoai.connector.SourceTable.add_llm`.
+
+    Example:
+        >>> import kumoai
+        >>> connector = kumoai.S3Connector(root_dir='s3://...')  # doctest: +SKIP # noqa: E501
+        >>> articles_src = connector['articles']  # doctest: +SKIP
+        >>> articles_src = kumoai.SourceTable('articles', connector)  # doctest: +SKIP # noqa: E501
+    """
+    def __init__(self, name: str, connector: 'Connector') -> None:
+        # TODO(manan): existence check, if not too expensive?
+        self.name = name
+        self.connector = connector
+
+    # Metadata ################################################################
+
+    @property
+    def column_dict(self) -> Dict[str, SourceColumn]:
+        r"""Returns the names of the columns in this table along with their
+        :class:`SourceColumn` information.
+        """
+        return {col.name: col for col in self.columns}
+
+    @property
+    def columns(self) -> List[SourceColumn]:
+        r"""Returns a list of the :class:`SourceColumn` metadata of the
+        columns in this table.
+        """
+        resp: SourceTableDataResponse = self.connector._get_table_data(
+            table_names=[self.name], sample_rows=0)[0]
+        return resp.cols
+
+    # Data Access #############################################################
+
+    def head(self, num_rows: int = 5) -> pd.DataFrame:
+        r"""Returns the first :obj:`num_rows` rows of this source table by
+        reading data from the backing connector.
+
+        Args:
+            num_rows: The number of rows to select. If :obj:`num_rows` is
+                larger than the number of available rows, all rows will be
+                returned.
+
+        Returns:
+            The first :obj:`num_rows` rows of the source table as a
+            :class:`~pandas.DataFrame`.
+        """
+        num_rows = int(num_rows)
+        if num_rows <= 0:
+            raise ValueError(
+                f"'num_rows' must be an integer greater than 0; got {num_rows}"
+            )
+        try:
+            resp = self.connector._get_table_data([self.name], num_rows)[0]
+
+            # TODO(manan, siyang): consider returning `bytes` instead of `json`
+            assert resp.sample_rows is not None
+            return pd.read_json(StringIO(resp.sample_rows), orient='table')
+        except TypeError as e:
+            raise RuntimeError(
+                f"Cannot read head of table {self.name}. "
+                "Please restart the kernel and try again.") from e
+
+    # Language Models #########################################################
+
+    @overload
+    def add_llm(
+        self,
+        model: str,
+        api_key: str,
+        template: str,
+        output_dir: str,
+        output_column_name: str,
+        output_table_name: str,
+        dimensions: Optional[int],
+    ) -> 'SourceTable':
+        pass
+
+    @overload
+    def add_llm(
+        self,
+        model: str,
+        api_key: str,
+        template: str,
+        output_dir: str,
+        output_column_name: str,
+        output_table_name: str,
+        dimensions: Optional[int],
+        *,
+        non_blocking: Literal[False],
+    ) -> 'SourceTable':
+        pass
+
+    @overload
+    def add_llm(
+        self,
+        model: str,
+        api_key: str,
+        template: str,
+        output_dir: str,
+        output_column_name: str,
+        output_table_name: str,
+        dimensions: Optional[int],
+        *,
+        non_blocking: Literal[True],
+    ) -> 'LLMSourceTableFuture':
+        pass
+
+    @overload
+    def add_llm(
+        self,
+        model: str,
+        api_key: str,
+        template: str,
+        output_dir: str,
+        output_column_name: str,
+        output_table_name: str,
+        dimensions: Optional[int] = None,
+        *,
+        non_blocking: bool,
+    ) -> Union['SourceTable', 'LLMSourceTableFuture']:
+        pass
+
+    def add_llm(
+        self,
+        model: str,
+        api_key: str,
+        template: str,
+        output_dir: str,
+        output_column_name: str,
+        output_table_name: str,
+        dimensions: Optional[int] = None,
+        *,
+        non_blocking: bool = False,
+    ) -> Union['SourceTable', 'LLMSourceTableFuture']:
+        r"""Experimental method that returns a new source table including a
+        column computed via an LLM, such as an OpenAI embedding model.
+        Please refer to the example script for more details.
+
+        .. note::
+
+            LLM embedding currently only works for a :obj:`SourceTable`
+            stored in S3.
+
+        .. note::
+
+            Your :obj:`api_key` will be encrypted once we receive it, and it
+            is only decrypted just before we call the OpenAI text embedding
+            service.
+
+        .. note::
+            Please keep track of the token usage in the `OpenAI Dashboard
+            <https://platform.openai.com/usage/activity>`_. If the number of
+            tokens in the data exceeds the limit, the backend will raise an
+            error and no result will be produced.
+
+        .. warning::
+
+            This method only supports text embedding for data with fewer
+            than ~6 million tokens. The number of tokens is estimated
+            following `this guide <https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them>`_.
+
+        .. warning::
+
+            This method is still experimental. Please consult with your Kumo
+            POC before using it.
+
+        Args:
+            model: The LLM model name, *e.g.*, OpenAI's
+                :obj:`"text-embedding-3-small"`.
+            api_key: The API key used to call the LLM service.
+            template: A template string to be passed to the LLM. For
+                example, :obj:`"{A1} and {A2}"` will fuse columns :obj:`A1`
+                and :obj:`A2` into a single string.
+            output_dir: The S3 directory to store the output.
+            output_column_name: The output column name for the LLM.
+            output_table_name: The output table name.
+            dimensions: The desired LLM embedding dimension.
+            non_blocking: Whether to make this call non-blocking.
+
+        Example:
+            >>> import kumoai
+            >>> connector = kumoai.S3Connector(root_dir='s3://...')  # doctest: +SKIP # noqa: E501
+            >>> articles_src = connector['articles']  # doctest: +SKIP
+            >>> articles_src_future = connector["articles"].add_llm(  # doctest: +SKIP # noqa: E501
+            ...     model="text-embedding-3-small",
+            ...     api_key=YOUR_OPENAI_API_KEY,
+            ...     template=("The product {prod_name} in the {section_name} "
+            ...               "section is categorized as {product_type_name} "
+            ...               "and has the following description: "
+            ...               "{detail_desc}"),
+            ...     output_dir=YOUR_OUTPUT_DIR,
+            ...     output_column_name="embedding_column",
+            ...     output_table_name="articles_emb",
+            ...     dimensions=256,
+            ...     non_blocking=True,
+            ... )
+            >>> articles_src_future.status()  # doctest: +SKIP
+            >>> articles_src_future.cancel()  # doctest: +SKIP
+            >>> articles_src = articles_src_future.result()  # doctest: +SKIP
+        """  # noqa
+        if global_state.is_spcs:
+            raise NotImplementedError("add_llm is not available on Snowflake")
+        source_table_type = self._to_api_source_table()
+        req = LLMRequest(
+            source_table_type=source_table_type,
+            template=template,
+            model=model,
+            model_api_key=api_key,
+            output_dir=output_dir,
+            output_column_name=output_column_name,
+            output_table_name=output_table_name,
+            dimensions=dimensions,
+        )
+        api = global_state.client.llm_job_api
+        resp: LLMJobId = api.create(req)
+        logger.info(f"LLMJobId: {resp}")
+        source_table_future = LLMSourceTableFuture(resp, output_table_name,
+                                                   output_dir)
+        if non_blocking:
+            return source_table_future
+        # TODO(zecheng): Add attach for text embedding.
+        return source_table_future.result()
+
+    # Persistence #############################################################
+
+    def _to_api_source_table(self) -> SourceTableType:
+        r"""Casts this source table to an object of type
+        :obj:`SourceTableType` for use with the public REST API.
+        """
+        # TODO(manan): this is necessary because the s3_validate method in
+        # Kumo core does not properly return directories, so we have to
+        # explicitly handle this ourselves here...
+        try:
+            return self.connector._get_table_config(self.name).source_table
+        except HTTPException:
+            name = self.name.rsplit('.', maxsplit=1)[0]
+            out = self.connector._get_table_config(name).source_table
+            self.name = name
+            return out
+
+    @staticmethod
+    def _from_api_table_definition(
+            table_definition: TableDefinition) -> 'SourceTable':
+        r"""Constructs a :class:`SourceTable` from a
+        :class:`kumoapi.table.TableDefinition`.
+        """
+        from kumoai.connector import (
+            BigQueryConnector,
+            DatabricksConnector,
+            FileUploadConnector,
+            GlueConnector,
+            S3Connector,
+            SnowflakeConnector,
+        )
+        from kumoai.connector.s3_connector import S3URI
+        source_type = table_definition.source_table.data_source_type
+        connector: Connector
+        if source_type == DataSourceType.S3:
+            connector_id = table_definition.source_table.connector_id
+            if connector_id in {
+                    'parquet_upload_connector', 'csv_upload_connector'
+            }:
+                # File upload:
+                connector = FileUploadConnector(
+                    file_type=('parquet' if connector_id ==
+                               'parquet_upload_connector' else 'csv'))
+                table_name = table_definition.source_table.source_table_name
+            else:
+                if connector_id is not None:
+                    connector = S3Connector(root_dir=None,
+                                            _connector_id=connector_id)
+                    table_name = (
+                        table_definition.source_table.source_table_name)
+                else:
+                    table_path = S3URI(table_definition.source_table.s3_path)
+                    connector = S3Connector(root_dir=table_path.root_dir)
+                    # Strip the suffix, since Kumo always takes care of that:
+                    table_name = table_path.object_name.rsplit(
+                        '.', maxsplit=1)[0]
+        elif source_type == DataSourceType.SNOWFLAKE:
+            connector_api = global_state.client.connector_api
+            connector_resp = connector_api.get(
+                table_definition.source_table.snowflake_connector_id)
+            assert connector_resp is not None
+            connector_cfg = connector_resp.config
+            connector = SnowflakeConnector(
+                name=connector_cfg.name,
+                account=connector_cfg.account,
+                warehouse=connector_cfg.warehouse,
+                database=connector_cfg.database,
+                schema_name=connector_cfg.schema_name,
+                credentials=None,  # Should be in env; do not load from DB.
+                _bypass_creation=True,
+            )
+            table_name = table_definition.source_table.table
+        elif source_type == DataSourceType.DATABRICKS:
+            connector_api = global_state.client.connector_api
+            connector_resp = connector_api.get(
+                table_definition.source_table.databricks_connector_id)
+            assert connector_resp is not None
+            connector_cfg = connector_resp.config
+            connector = DatabricksConnector(
+                name=connector_cfg.name,
+                host=connector_cfg.host,
+                cluster_id=connector_cfg.cluster_id,
+                warehouse_id=connector_cfg.warehouse_id,
+                catalog=connector_cfg.catalog,
+                credentials=None,  # Should be in env; do not load from DB.
+                _bypass_creation=True,
+            )
+            table_name = table_definition.source_table.table
+        elif source_type == DataSourceType.BIGQUERY:
+            connector_api = global_state.client.connector_api
+            connector_resp = connector_api.get(
+                table_definition.source_table.bigquery_connector_id)
+            assert connector_resp is not None
+            connector_cfg = connector_resp.config
+            connector = BigQueryConnector(
+                name=connector_cfg.name,
+                project_id=connector_cfg.project_id,
+                dataset_id=connector_cfg.dataset_id,
+                credentials=None,  # Should be in env; do not load from DB.
+                _bypass_creation=True,
+            )
+            table_name = table_definition.source_table.table_name
+        elif source_type == DataSourceType.GLUE:
+            connector_api = global_state.client.connector_api
+            connector_resp = connector_api.get(
+                table_definition.source_table.glue_connector_id)
+            assert connector_resp is not None
+            connector_cfg = connector_resp.config
+            connector = GlueConnector(
+                name=connector_cfg.name,
+                account=connector_cfg.account,
+                region=connector_cfg.region,
+                database=connector_cfg.database,
+                _bypass_creation=True,
+            )
+            table_name = table_definition.source_table.table
+        else:
+            raise NotImplementedError()
+
+        return SourceTable(table_name, connector)
+
+    # Class properties ########################################################
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(name={self.name})'
+
+
+class SourceTableFuture(KumoFuture[SourceTable]):
+    r"""A representation of an on-going :class:`SourceTable` generation
+    process.
+    """
+    def __init__(
+        self,
+        job_id: LLMJobId,
+        table_name: str,
+        output_dir: str,
+    ) -> None:
+        self.job_id = job_id
+        self._fut: concurrent.futures.Future = create_future(
+            _poll(job_id, table_name, output_dir))
+
+    @override
+    def result(self) -> SourceTable:
+        return self._fut.result()
+
+    @override
+    def future(self) -> 'concurrent.futures.Future[SourceTable]':
+        return self._fut
+
+    def status(self) -> JobStatus:
+        return _get_status(self.job_id)
+
+
+class LLMSourceTableFuture(SourceTableFuture):
+    r"""A representation of an on-going :class:`SourceTable` generation
+    process backed by an LLM job. This class inherits from
+    :class:`SourceTableFuture`, with some functionality specific to LLM
+    jobs.
+    """
+    def __init__(
+        self,
+        job_id: LLMJobId,
+        table_name: str,
+        output_dir: str,
+    ) -> None:
+        super().__init__(job_id, table_name, output_dir)
+
+    def cancel(self) -> JobStatus:
+        r"""Cancels the LLM job."""
+        api = global_state.client.llm_job_api
+        return api.cancel(self.job_id)
+
+
+def _get_status(job_id: str) -> JobStatus:
+    api = global_state.client.llm_job_api
+    resource: JobStatus = api.get(job_id)
+    return resource
+
+
+async def _poll(job_id: str, table_name: str, output_dir: str) -> SourceTable:
+    status = _get_status(job_id)
+    while not status.is_terminal:
+        await asyncio.sleep(_DEFAULT_INTERVAL_S)
+        status = _get_status(job_id)
+
+    if status != JobStatus.DONE:
+        raise RuntimeError(f"LLM job {job_id} failed with "
+                           f"job status {status}.")
+
+    from kumoai.connector import S3Connector
+    connector = S3Connector(root_dir=output_dir)
+
+    return connector[table_name]
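The new module above wires together metadata inspection (columns, column_dict), row sampling (head), LLM augmentation (add_llm), and asynchronous job polling (SourceTableFuture, _poll). The following is a minimal usage sketch of that flow, not part of the package diff itself; it assumes an already-authenticated kumoai session, and the S3 paths, template columns, and API key are hypothetical placeholders.

import kumoai

# Hypothetical S3 location; any connector in this package exposes the same
# SourceTable interface for metadata inspection.
connector = kumoai.S3Connector(root_dir="s3://my-bucket/data")
articles = kumoai.SourceTable("articles", connector)

# Metadata and sampling, as defined in the module above:
cols = articles.column_dict         # {column name: SourceColumn metadata}
sample = articles.head(num_rows=5)  # first 5 rows as a pandas DataFrame

# Non-blocking LLM augmentation returns an LLMSourceTableFuture, which polls
# the LLM job every _DEFAULT_INTERVAL_S seconds until it reaches a terminal
# state:
future = articles.add_llm(
    model="text-embedding-3-small",
    api_key="sk-...",                        # hypothetical key
    template="{title}: {description}",       # hypothetical source columns
    output_dir="s3://my-bucket/embeddings",  # hypothetical output location
    output_column_name="embedding",
    output_table_name="articles_emb",
    dimensions=256,
    non_blocking=True,
)
print(future.status())          # JobStatus of the backing LLM job
articles_emb = future.result()  # blocks; raises RuntimeError unless DONE

With non_blocking=False (the default), add_llm performs the same submission but calls result() internally, so the caller simply receives the augmented SourceTable once the job completes.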