kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/client/jobs.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import (
|
|
3
|
+
Any,
|
|
4
|
+
Dict,
|
|
5
|
+
Generic,
|
|
6
|
+
List,
|
|
7
|
+
Mapping,
|
|
8
|
+
Optional,
|
|
9
|
+
Tuple,
|
|
10
|
+
Type,
|
|
11
|
+
TypeVar,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from kumoapi.common import ValidationResponse
|
|
15
|
+
from kumoapi.jobs import (
|
|
16
|
+
ArtifactExportRequest,
|
|
17
|
+
ArtifactExportResponse,
|
|
18
|
+
AutoTrainerProgress,
|
|
19
|
+
BaselineJobRequest,
|
|
20
|
+
BaselineJobResource,
|
|
21
|
+
BatchPredictionJobResource,
|
|
22
|
+
BatchPredictionRequest,
|
|
23
|
+
CancelBatchPredictionJobResponse,
|
|
24
|
+
CancelTrainingJobResponse,
|
|
25
|
+
DistillationJobRequest,
|
|
26
|
+
DistillationJobResource,
|
|
27
|
+
ErrorDetails,
|
|
28
|
+
GeneratePredictionTableJobResource,
|
|
29
|
+
GeneratePredictionTableRequest,
|
|
30
|
+
GenerateTrainTableJobResource,
|
|
31
|
+
GenerateTrainTableRequest,
|
|
32
|
+
GetEmbeddingsDfUrlResponse,
|
|
33
|
+
GetPredictionsDfUrlResponse,
|
|
34
|
+
JobRequestBase,
|
|
35
|
+
JobResourceBase,
|
|
36
|
+
JobStatus,
|
|
37
|
+
PredictionProgress,
|
|
38
|
+
TrainingJobRequest,
|
|
39
|
+
TrainingJobResource,
|
|
40
|
+
TrainingTableSpec,
|
|
41
|
+
)
|
|
42
|
+
from kumoapi.json_serde import from_json, to_json_dict
|
|
43
|
+
from kumoapi.source_table import LLMRequest, LLMResponse, SourceTableType
|
|
44
|
+
from kumoapi.train import TrainingStage
|
|
45
|
+
from typing_extensions import override
|
|
46
|
+
|
|
47
|
+
from kumoai.client import KumoClient
|
|
48
|
+
from kumoai.client.utils import (
|
|
49
|
+
Returns,
|
|
50
|
+
parse_patch_response,
|
|
51
|
+
parse_response,
|
|
52
|
+
raise_on_error,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# Job identifiers are plain strings; these aliases only document, in type
# annotations, which kind of job an ID refers to.
TrainingJobID = BatchPredictionJobID = GenerateTrainTableJobID = str
GeneratePredictionTableJobID = LLMJobId = BaselineJobID = str
|
|
61
|
+
|
|
62
|
+
# Type parameters of `CommonJobAPI`: the request payload a job is created
# from (`create`), and the resource returned when a job is fetched or
# listed (`get` / `list`).
JobRequestType = TypeVar(
    'JobRequestType',
    bound=JobRequestBase,
)
JobResourceType = TypeVar(
    'JobResourceType',
    bound=JobResourceBase,
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class CommonJobAPI(Generic[JobRequestType, JobResourceType]):
    r"""Generic client for a job resource served under one REST prefix.

    Subclasses bind a concrete request/resource pair and a base endpoint
    (e.g. ``'/training_jobs'``). This class provides the shared ``create``,
    ``get`` and ``list`` calls, plus tag management, against that endpoint.

    Args:
        client (KumoClient): The client used to issue HTTP requests.
        base_endpoint (str): The URL path prefix of the job resource.
        res_type (Type[JobResourceType]): The resource class that responses
            of :meth:`get` and :meth:`list` are deserialized into.
    """
    def __init__(self, client: KumoClient, base_endpoint: str,
                 res_type: Type[JobResourceType]) -> None:
        self._client = client
        self._base_endpoint = base_endpoint
        self._res_type = res_type

    def create(self, request: JobRequestType) -> str:
        r"""Submits a new job.

        Args:
            request (JobRequestType): The job request, serialized to JSON as
                the POST body.

        Returns:
            str: The ``'id'`` field of the server response.
        """
        response = self._client._post(self._base_endpoint,
                                      json=to_json_dict(request))
        raise_on_error(response)
        return parse_response(Dict[str, str], response)['id']

    def get(self, id: str) -> JobResourceType:
        r"""Fetches the job resource with the given ID.

        Args:
            id (str): The ID of the job.

        Returns:
            JobResourceType: The deserialized job resource.
        """
        response = self._client._get(f'{self._base_endpoint}/{id}')
        raise_on_error(response)
        return parse_response(self._res_type, response)

    def list(
        self,
        *,
        pquery_name: Optional[str] = None,
        pquery_id: Optional[str] = None,
        job_status: Optional[JobStatus] = None,
        limit: Optional[int] = None,
        additional_tags: Optional[Mapping[str, str]] = None,
    ) -> List[JobResourceType]:
        r"""Lists job resources, optionally filtered.

        All filters are sent as query parameters. Filters left as ``None``
        are presumably dropped by the HTTP layer (TODO confirm against
        ``KumoClient._get``).

        Args:
            pquery_name (str, optional): Filter by predictive query name.
            pquery_id (str, optional): Filter by predictive query ID.
            job_status (JobStatus, optional): Filter by job status.
            limit (int, optional): Maximum number of results to request.
            additional_tags (Mapping[str, str], optional): Extra query
                parameters. A key equal to one of the named filters above
                overrides that filter.

        Returns:
            List[JobResourceType]: The deserialized job resources.
        """
        params: Dict[str, Any] = {
            'pquery_name': pquery_name,
            'pquery_id': pquery_id,
            'job_status': job_status,
            'limit': limit,
        }
        # ``None`` default instead of a shared mutable ``{}`` default.
        if additional_tags:
            params.update(additional_tags)
        response = self._client._get(self._base_endpoint, params=params)
        raise_on_error(response)
        resource_elements = response.json()
        assert isinstance(resource_elements, list)
        return [from_json(e, self._res_type) for e in resource_elements]

    def delete_tags(self, job_id: str, tags: List[str]) -> bool:
        r"""Removes the tags from the job.

        Implemented as :meth:`update_tags` with every tag set to the
        ``'none'`` sentinel value.

        Args:
            job_id (str): The ID of the job.
            tags (List[str]): The tags to remove.

        Returns:
            bool: The result of parsing the PATCH response.
        """
        return self.update_tags(job_id, {t: 'none' for t in tags})

    def update_tags(self, job_id: str,
                    custom_job_tags: Mapping[str, Optional[str]]) -> bool:
        r"""Updates the tags of the job.

        Args:
            job_id (str): The ID of the job.
            custom_job_tags (Mapping[str, Optional[str]]): The tags to update.
                Note that the value 'none' will remove the tag. If the tag is
                not present, it will be added.

        Returns:
            bool: The result of parsing the PATCH response.
        """
        # Tags travel as query parameters; every value is passed through
        # ``str()``.
        # NOTE(review): a ``None`` value is therefore sent as ``'None'``, not
        # the ``'none'`` removal sentinel. Confirm that the server treats
        # both the same.
        response = self._client._patch(
            f'{self._base_endpoint}/{job_id}/tags',
            data=None,
            params={
                k: str(v)
                for k, v in custom_job_tags.items()
            },
        )
        raise_on_error(response)
        return parse_patch_response(response)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class BaselineJobAPI(CommonJobAPI[BaselineJobRequest, BaselineJobResource]):
    r"""Client for baseline jobs served under ``/baseline_jobs``.

    Adds config retrieval on top of the shared :class:`CommonJobAPI` calls.
    """
    def __init__(self, client: KumoClient) -> None:
        endpoint = '/baseline_jobs'
        super().__init__(client, endpoint, BaselineJobResource)

    def get_config(self, job_id: str) -> BaselineJobRequest:
        """Return the request a baseline job was created with.

        Args:
            job_id (str): The ID of the baseline job.

        Returns:
            BaselineJobRequest: The ``config`` field of the fetched resource.
        """
        return self.get(job_id).config
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
    r"""Client for training jobs served under ``/training_jobs``.

    Besides the shared :class:`CommonJobAPI` calls, exposes training
    progress, the holdout data location, cancellation and config retrieval.
    """
    def __init__(self, client: KumoClient) -> None:
        endpoint = '/training_jobs'
        super().__init__(client, endpoint, TrainingJobResource)

    def get_progress(self, id: TrainingJobID) -> AutoTrainerProgress:
        r"""Fetch the progress report of a training job.

        Args:
            id (str): The ID of the training job.
        """
        url = f'{self._base_endpoint}/{id}/progress'
        resp = self._client._get(url)
        raise_on_error(resp)
        return parse_response(AutoTrainerProgress, resp)

    def holdout_data_url(self, id: TrainingJobID,
                         presigned: bool = True) -> str:
        r"""Fetch the location of a training job's holdout data.

        Args:
            id (str): The ID of the training job.
            presigned (bool): Forwarded as the ``presigned`` query parameter.

        Returns:
            str: The raw text body of the response.
        """
        url = f'{self._base_endpoint}/{id}/holdout'
        query = {'presigned': presigned}
        resp = self._client._get(url, params=query)
        raise_on_error(resp)
        return resp.text

    def cancel(self, id: str) -> CancelTrainingJobResponse:
        r"""Request cancellation of a training job.

        Args:
            id (str): The ID of the training job.
        """
        url = f'{self._base_endpoint}/{id}/cancel'
        resp = self._client._post(url)
        raise_on_error(resp)
        return parse_response(CancelTrainingJobResponse, resp)

    def get_config(self, job_id: str) -> TrainingJobRequest:
        """Return the request a training job was created with.

        Args:
            job_id (str): The ID of the training job.

        Returns:
            TrainingJobRequest: The ``config`` field of the fetched resource.
        """
        return self.get(job_id).config
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class DistillationJobAPI(CommonJobAPI[DistillationJobRequest,
                                      DistillationJobResource]):
    r"""Client for distillation jobs, nested under ``/training_jobs``.

    Only the operations inherited from :class:`CommonJobAPI` work. Config
    retrieval, progress and cancellation currently raise
    :class:`NotImplementedError`.
    """
    def __init__(self, client: KumoClient) -> None:
        endpoint = '/training_jobs/distilled_training_job'
        super().__init__(client, endpoint, DistillationJobResource)

    def get_config(self, job_id: str) -> DistillationJobRequest:
        r"""Not yet supported for distillation jobs.

        Raises:
            NotImplementedError: Always.
        """
        msg = ("Getting the configuration for a distillation job is "
               "not implemented yet.")
        raise NotImplementedError(msg)

    def get_progress(self, id: str) -> AutoTrainerProgress:
        r"""Not yet supported for distillation jobs.

        Raises:
            NotImplementedError: Always.
        """
        msg = ("Getting the progress for a distillation job is not "
               "implemented yet.")
        raise NotImplementedError(msg)

    def cancel(self, id: str) -> CancelTrainingJobResponse:
        r"""Not yet supported for distillation jobs.

        Raises:
            NotImplementedError: Always.
        """
        msg = "Cancelling a distillation job is not implemented yet."
        raise NotImplementedError(msg)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
                                         BatchPredictionJobResource]):
    r"""Typed API definition for the prediction job resource."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/prediction_jobs',
                         BatchPredictionJobResource)

    @override
    def create(self, request: BatchPredictionRequest) -> str:
        r"""Not supported; use :meth:`maybe_create` instead.

        Raises:
            NotImplementedError: Always.
        """
        # TODO(manan): eventually, all `create` methods should
        # return a validation response:
        raise NotImplementedError(
            "Use `maybe_create` to create a batch prediction job; it also "
            "returns the validation response.")

    def maybe_create(
        self, request: BatchPredictionRequest
    ) -> Tuple[Optional[str], ValidationResponse]:
        r"""Submits a batch prediction job and returns its validation result.

        Args:
            request (BatchPredictionRequest): The job request, serialized to
                JSON as the POST body.

        Returns:
            Tuple[Optional[str], ValidationResponse]: The job ID (``Optional``;
            presumably ``None`` when validation fails, TODO confirm) and the
            validation response.
        """
        response = self._client._post(self._base_endpoint,
                                      json=to_json_dict(request))
        raise_on_error(response)
        return parse_response(
            Returns[Tuple[Optional[str], ValidationResponse]], response)

    def list(
        self,
        *,
        model_id: Optional[TrainingJobID] = None,
        pquery_name: Optional[str] = None,
        pquery_id: Optional[str] = None,
        job_status: Optional[JobStatus] = None,
        limit: Optional[int] = None,
        additional_tags: Optional[Mapping[str, str]] = None,
    ) -> List[BatchPredictionJobResource]:
        r"""Lists batch prediction jobs, optionally filtered.

        Args:
            model_id (str, optional): If truthy, it is added to the query as
                the ``'model_id'`` filter, overriding any ``'model_id'`` key
                in ``additional_tags``.
            pquery_name (str, optional): Filter by predictive query name.
            pquery_id (str, optional): Filter by predictive query ID.
            job_status (JobStatus, optional): Filter by job status.
            limit (int, optional): Maximum number of results to request.
            additional_tags (Mapping[str, str], optional): Extra query
                parameters.

        Returns:
            List[BatchPredictionJobResource]: The matching jobs.
        """
        # Always hand the parent a fresh dict: this avoids a shared mutable
        # default and never passes ``None`` down.
        tags: Dict[str, str] = dict(additional_tags or {})
        if model_id:
            tags['model_id'] = model_id
        return super().list(pquery_name=pquery_name, pquery_id=pquery_id,
                            job_status=job_status, limit=limit,
                            additional_tags=tags)

    def get_progress(self, id: str) -> PredictionProgress:
        r"""Fetches the progress report of a batch prediction job.

        Args:
            id (str): ID of the batch prediction job.
        """
        response = self._client._get(f'{self._base_endpoint}/{id}/progress')
        raise_on_error(response)
        return parse_response(PredictionProgress, response)

    def cancel(self, id: str) -> CancelBatchPredictionJobResponse:
        r"""Requests cancellation of a batch prediction job.

        Args:
            id (str): ID of the batch prediction job.
        """
        response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(response)
        return parse_response(CancelBatchPredictionJobResponse, response)

    def get_batch_predictions_url(self, id: str) -> List[str]:
        """Returns presigned URLs pointing to the locations where the
        predictions are stored. Depending on the environment where this is run,
        they could be AWS S3 paths, Snowflake stage paths, or Databricks UC
        volume paths.

        Args:
            id (str): ID of the batch prediction job for which predictions are
                requested.

        Returns:
            List[str]: The ``prediction_partitions`` of the response.
        """
        response = self._client._get(
            f'{self._base_endpoint}/{id}/get_prediction_df_urls')
        raise_on_error(response)
        return parse_response(
            GetPredictionsDfUrlResponse,
            response,
        ).prediction_partitions

    def get_batch_embeddings_url(self, id: str) -> List[str]:
        """Returns presigned URLs pointing to the locations where the
        embeddings are stored. Depending on the environment where this is run,
        they could be AWS S3 paths, Snowflake stage paths, or Databricks UC
        volume paths.

        Args:
            id (str): ID of the batch prediction job for which embeddings are
                requested.

        Returns:
            List[str]: The ``embedding_partitions`` of the response.
        """
        response = self._client._get(
            f'{self._base_endpoint}/{id}/get_embedding_df_urls')
        raise_on_error(response)
        return parse_response(
            GetEmbeddingsDfUrlResponse,
            response,
        ).embedding_partitions

    def get_config(self, job_id: str) -> BatchPredictionRequest:
        """Load the configuration for a batch prediction job by ID.

        Returns:
            BatchPredictionRequest: The ``config`` field of the job resource.
        """
        resource = self.get(job_id)
        return resource.config
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
                                            GenerateTrainTableJobResource]):
    r"""Typed API definition for training table generation job resource."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/gentraintable_jobs',
                         GenerateTrainTableJobResource)

    def get_table_data(self, id: GenerateTrainTableJobID,
                       presigned: bool = True) -> List[str]:
        """Return a list of URLs to access train table parquet data.

        There might be multiple URLs if the table data is partitioned into
        multiple files.

        Args:
            id (str): ID of the training table generation job.
            presigned (bool): Forwarded as the ``presigned`` query parameter.
        """
        return self._get_table_data(id, presigned)

    def _get_table_data(self, id: GenerateTrainTableJobID,
                        presigned: bool = True,
                        raw_path: bool = False) -> List[str]:
        """Helper function to get train table data.

        Args:
            id (str): ID of the training table generation job.
            presigned (bool): Forwarded as the ``presigned`` query parameter.
            raw_path (bool): If ``True``, request the local file path instead
                of the SPCS stage path.
        """
        params: Dict[str, Any] = {'presigned': presigned, 'raw_path': raw_path}
        resp = self._client._get(f'{self._base_endpoint}/{id}/table_data',
                                 params=params)
        raise_on_error(resp)
        return parse_response(List[str], resp)

    def get_split_masks(
            self, id: GenerateTrainTableJobID) -> Dict[TrainingStage, str]:
        """Return a dictionary of presigned URLs keyed by training stage.

        Each URL points to a torch-serialized (default pickle protocol) file of
        the mask tensor for that training stage. Because the payload is a
        pickle, only load it from URLs you trust.

        Example:
            >>> # code to load a mask tensor:
            >>> import io
            >>> import torch
            >>> import requests
            >>> masks = get_split_masks('some-gen-traintable-job-id')
            >>> data_bytes = requests.get(masks[TrainingStage.TEST]).content
            >>> test_mask_tensor = torch.load(io.BytesIO(data_bytes))
        """
        resp = self._client._get(f'{self._base_endpoint}/{id}/split_masks')
        raise_on_error(resp)
        return parse_response(Dict[TrainingStage, str], resp)

    def get_progress(self, id: str) -> Dict[str, int]:
        """Fetch the progress of a training table generation job.

        Returns:
            Dict[str, int]: The progress payload, parsed as a str->int map.
        """
        response = self._client._get(f'{self._base_endpoint}/{id}/progress')
        raise_on_error(response)
        return parse_response(Dict[str, int], response)

    def cancel(self, id: str) -> None:
        """Request cancellation of a training table generation job."""
        response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(response)

    def validate_custom_train_table(
        self,
        id: str,
        source_table_type: SourceTableType,
        train_table_mod: TrainingTableSpec,
    ) -> ValidationResponse:
        """Validate a custom training table for a generation job.

        Args:
            id (str): ID of the training table generation job.
            source_table_type (SourceTableType): The custom table, sent under
                the ``'custom_table'`` key.
            train_table_mod (TrainingTableSpec): The training table spec, sent
                under the ``'train_table_mod'`` key.

        Returns:
            ValidationResponse: The parsed validation result.
        """
        response = self._client._post(
            f'{self._base_endpoint}/{id}/validate_custom_train_table',
            json=to_json_dict({
                'custom_table': source_table_type,
                'train_table_mod': train_table_mod,
            }),
        )
        # Surface HTTP errors like every other call in this module, instead
        # of trying to parse an error body as a ValidationResponse.
        raise_on_error(response)
        return parse_response(ValidationResponse, response)

    def get_job_error(self, id: str) -> ErrorDetails:
        """Thin API wrapper for fetching errors from the jobs.

        Arguments:
            id (str): Id of the job whose related errors are expected to be
                queried.
        """
        response = self._client._get(f'{self._base_endpoint}/{id}/get_errors')
        raise_on_error(response)
        return parse_response(ErrorDetails, response)

    def get_config(self, job_id: str) -> GenerateTrainTableRequest:
        """Load the configuration for a training table generation job by ID."""
        resource = self.get(job_id)
        return resource.config
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class GeneratePredictionTableJobAPI(
        CommonJobAPI[GeneratePredictionTableRequest,
                     GeneratePredictionTableJobResource]):
    r"""Typed API definition for prediction table generation job resource."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/genpredtable_jobs',
                         GeneratePredictionTableJobResource)

    def get_anchor_time(
            self, id: GeneratePredictionTableJobID) -> Optional[datetime]:
        """Fetch the anchor time of a prediction table generation job.

        Args:
            id (str): ID of the prediction table generation job. This is a
                generation job ID, not a batch prediction job ID.

        Returns:
            Optional[datetime]: The anchor time, or ``None`` when the server
            returns none (presumably when no anchor time applies; TODO
            confirm).
        """
        response = self._client._get(
            f'{self._base_endpoint}/{id}/get_anchor_time')
        raise_on_error(response)
        return parse_response(Returns[Optional[datetime]], response)

    def get_table_data(self, id: GeneratePredictionTableJobID,
                       presigned: bool = True) -> List[str]:
        """Return a list of URLs to access prediction table parquet data.

        There might be multiple URLs if the table data is partitioned into
        multiple files.

        Args:
            id (str): ID of the prediction table generation job.
            presigned (bool): Forwarded as the ``presigned`` query parameter.
        """
        params: Dict[str, Any] = {'presigned': presigned}
        resp = self._client._get(f'{self._base_endpoint}/{id}/table_data',
                                 params=params)
        raise_on_error(resp)
        return parse_response(List[str], resp)

    def cancel(self, id: str) -> None:
        """Request cancellation of a prediction table generation job."""
        response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(response)

    def get_job_error(self, id: str) -> ErrorDetails:
        """Thin API wrapper for fetching errors from the jobs.

        Arguments:
            id (str): Id of the job whose related errors are expected to be
                queried.
        """
        response = self._client._get(f'{self._base_endpoint}/{id}/get_errors')
        raise_on_error(response)
        return parse_response(ErrorDetails, response)

    def get_config(self, job_id: str) -> GeneratePredictionTableRequest:
        """Load the configuration for a
        prediction table generation job by ID.
        """
        resource = self.get(job_id)
        return resource.config
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class LLMJobAPI:
    r"""Typed API definition for LLM job resource.

    Creates, polls and cancels jobs through the ``/llm_embedding_job``
    endpoint.
    """
    def __init__(self, client: KumoClient) -> None:
        self._client = client
        self._base_endpoint = '/llm_embedding_job'

    def create(self, request: LLMRequest) -> LLMJobId:
        r"""Submits an LLM job.

        Args:
            request (LLMRequest): The job request, serialized to JSON.

        Returns:
            str: The ``job_id`` of the parsed :class:`LLMResponse`.
        """
        response = self._client._post(
            self._base_endpoint,
            json=to_json_dict(request),
        )
        raise_on_error(response)
        return parse_response(LLMResponse, response).job_id

    def get(self, id: LLMJobId) -> JobStatus:
        r"""Fetches the status of an LLM job.

        Args:
            id (str): The ID of the LLM job.
        """
        response = self._client._get(f'{self._base_endpoint}/status/{id}')
        raise_on_error(response)
        return parse_response(JobStatus, response)

    def cancel(self, id: LLMJobId) -> JobStatus:
        r"""Cancels an LLM job and returns its resulting status.

        Args:
            id (str): The ID of the LLM job.
        """
        response = self._client._delete(f'{self._base_endpoint}/cancel/{id}')
        # Check for failures and decode the body so callers get the
        # ``JobStatus`` that the signature promises, as other cancel calls do.
        # NOTE(review): assumes the DELETE endpoint responds with a serialized
        # JobStatus, as the annotation declares. Confirm against the server.
        raise_on_error(response)
        return parse_response(JobStatus, response)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class ArtifactExportJobAPI:
    r"""Client for artifact export jobs served under ``/artifact``."""
    def __init__(self, client: KumoClient) -> None:
        self._client = client
        self._base_endpoint = '/artifact'

    def create(self, request: ArtifactExportRequest) -> str:
        r"""Start an artifact export job.

        Args:
            request (ArtifactExportRequest): The export request, serialized
                to JSON as the POST body.

        Returns:
            str: The ``job_id`` of the parsed export response.
        """
        payload = to_json_dict(request)
        resp = self._client._post(self._base_endpoint, json=payload)
        raise_on_error(resp)
        parsed = parse_response(ArtifactExportResponse, resp)
        return parsed.job_id

    # TODO Add an API in artifact export to get
    # JobStatusReport and not just JobStatus
    def get(self, id: str) -> JobStatus:
        r"""Fetch the status of an artifact export job."""
        url = f'{self._base_endpoint}/{id}'
        resp = self._client._get(url)
        raise_on_error(resp)
        return parse_response(JobStatus, resp)

    def cancel(self, id: str) -> JobStatus:
        r"""Cancel an artifact export job and return its resulting status."""
        url = f'{self._base_endpoint}/{id}/cancel'
        resp = self._client._post(url)
        raise_on_error(resp)
        return parse_response(JobStatus, resp)
|
kumoai/client/online.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from http import HTTPStatus
|
|
2
|
+
from typing import Any, List, Optional
|
|
3
|
+
|
|
4
|
+
from kumoapi.json_serde import to_json_dict
|
|
5
|
+
from kumoapi.online_serving import (
|
|
6
|
+
OnlineServingEndpointRequest,
|
|
7
|
+
OnlineServingEndpointResource,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
from kumoai.client import KumoClient
|
|
11
|
+
from kumoai.client.endpoints import OnlineServingEndpoints
|
|
12
|
+
from kumoai.client.utils import (
|
|
13
|
+
parse_id_response,
|
|
14
|
+
parse_patch_response,
|
|
15
|
+
parse_response,
|
|
16
|
+
raise_on_error,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Alias of str used to annotate online serving endpoint IDs.
OnlineServingEndpointID = str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class OnlineServingEndpointAPI:
    r"""Typed API definition for Kumo online serving endpoints.

    Wraps the ``OnlineServingEndpoints`` endpoint definitions to create,
    fetch, list, update and delete online serving endpoint resources.
    """
    def __init__(self, client: KumoClient) -> None:
        self._client = client
        self._base_endpoint = '/online_serving_endpoints'

    # TODO(blaz): document final interface
    def create(
        self,
        req: OnlineServingEndpointRequest,
        **query_params: Any,
    ) -> OnlineServingEndpointID:
        """Creates a new online serving endpoint.

        Args:
            req (OnlineServingEndpointRequest): request body.
            **query_params: Additional query parameters to pass to the API.
                For example, ``use_ge`` (Optional[bool]): if present,
                overrides the graph backend option to use GRAPHENGINE if
                true, else MEMORY.

        Returns:
            OnlineServingEndpointID: unique endpoint resource id.
        """
        resp = self._client._post(
            self._base_endpoint,
            # Send no query string at all when no extra params are given.
            params=query_params if query_params else None,
            json=to_json_dict(req),
        )
        raise_on_error(resp)
        return parse_id_response(resp)

    def get_if_exists(
        self, id: OnlineServingEndpointID
    ) -> Optional[OnlineServingEndpointResource]:
        """Fetch an endpoint resource, or ``None`` if the server returns 404.

        Args:
            id (str): The endpoint ID.
        """
        resp = self._client._request(OnlineServingEndpoints.get.with_id(id))
        if resp.status_code == HTTPStatus.NOT_FOUND:
            return None

        raise_on_error(resp)
        return parse_response(OnlineServingEndpointResource, resp)

    def list(self) -> List[OnlineServingEndpointResource]:
        """List all online serving endpoint resources."""
        resp = self._client._request(OnlineServingEndpoints.list)
        raise_on_error(resp)
        return parse_response(List[OnlineServingEndpointResource], resp)

    def update(self, id: OnlineServingEndpointID,
               req: OnlineServingEndpointRequest) -> bool:
        """Update an endpoint with a new request.

        Args:
            id (str): The endpoint ID.
            req (OnlineServingEndpointRequest): The new request body.

        Returns:
            bool: The result of parsing the PATCH response.
        """
        # NOTE(review): the body is passed as ``data=`` here, while ``create``
        # uses ``json=``. Confirm that ``KumoClient._request`` JSON-encodes
        # ``data`` rather than form-encoding it.
        resp = self._client._request(OnlineServingEndpoints.update.with_id(id),
                                     data=to_json_dict(req))
        raise_on_error(resp)
        return parse_patch_response(resp)

    def delete(self, id: OnlineServingEndpointID) -> None:
        """This is idempotent and can be called multiple times."""
        resp = self._client._request(OnlineServingEndpoints.delete.with_id(id))
        raise_on_error(resp)
|