kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kumoai might be problematic.

Files changed (122)
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/client/jobs.py ADDED
@@ -0,0 +1,471 @@
+ from datetime import datetime
+ from typing import (
+     Any,
+     Dict,
+     Generic,
+     List,
+     Mapping,
+     Optional,
+     Tuple,
+     Type,
+     TypeVar,
+ )
+
+ from kumoapi.common import ValidationResponse
+ from kumoapi.jobs import (
+     ArtifactExportRequest,
+     ArtifactExportResponse,
+     AutoTrainerProgress,
+     BaselineJobRequest,
+     BaselineJobResource,
+     BatchPredictionJobResource,
+     BatchPredictionRequest,
+     CancelBatchPredictionJobResponse,
+     CancelTrainingJobResponse,
+     DistillationJobRequest,
+     DistillationJobResource,
+     ErrorDetails,
+     GeneratePredictionTableJobResource,
+     GeneratePredictionTableRequest,
+     GenerateTrainTableJobResource,
+     GenerateTrainTableRequest,
+     GetEmbeddingsDfUrlResponse,
+     GetPredictionsDfUrlResponse,
+     JobRequestBase,
+     JobResourceBase,
+     JobStatus,
+     PredictionProgress,
+     TrainingJobRequest,
+     TrainingJobResource,
+     TrainingTableSpec,
+ )
+ from kumoapi.json_serde import from_json, to_json_dict
+ from kumoapi.source_table import LLMRequest, LLMResponse, SourceTableType
+ from kumoapi.train import TrainingStage
+ from typing_extensions import override
+
+ from kumoai.client import KumoClient
+ from kumoai.client.utils import (
+     Returns,
+     parse_patch_response,
+     parse_response,
+     raise_on_error,
+ )
+
+ TrainingJobID = str
+ BatchPredictionJobID = str
+ GenerateTrainTableJobID = str
+ GeneratePredictionTableJobID = str
+ LLMJobId = str
+ BaselineJobID = str
+
+ JobRequestType = TypeVar('JobRequestType', bound=JobRequestBase)
+ JobResourceType = TypeVar('JobResourceType', bound=JobResourceBase)
+
+
+ class CommonJobAPI(Generic[JobRequestType, JobResourceType]):
+     def __init__(self, client: KumoClient, base_endpoint: str,
+                  res_type: Type[JobResourceType]) -> None:
+         self._client = client
+         self._base_endpoint = base_endpoint
+         self._res_type = res_type
+
+     def create(self, request: JobRequestType) -> str:
+         response = self._client._post(self._base_endpoint,
+                                       json=to_json_dict(request))
+         raise_on_error(response)
+         return parse_response(Dict[str, str], response)['id']
+
+     def get(self, id: str) -> JobResourceType:
+         response = self._client._get(f'{self._base_endpoint}/{id}')
+         raise_on_error(response)
+         return parse_response(self._res_type, response)
+
+     def list(
+         self,
+         *,
+         pquery_name: Optional[str] = None,
+         pquery_id: Optional[str] = None,
+         job_status: Optional[JobStatus] = None,
+         limit: Optional[int] = None,
+         additional_tags: Mapping[str, str] = {},
+     ) -> List[JobResourceType]:
+         params: Dict[str, Any] = {
+             'pquery_name': pquery_name,
+             'pquery_id': pquery_id,
+             'job_status': job_status,
+             'limit': limit
+         }
+         params.update(additional_tags)
+         response = self._client._get(self._base_endpoint, params=params)
+         raise_on_error(response)
+         resource_elements = response.json()
+         assert isinstance(resource_elements, list)
+         return [from_json(e, self._res_type) for e in resource_elements]
+
+     def delete_tags(self, job_id: str, tags: List[str]) -> bool:
+         r"""Removes the tags from the job.
+
+         Args:
+             job_id (str): The ID of the job.
+             tags (List[str]): The tags to remove.
+         """
+         return self.update_tags(job_id, {t: 'none' for t in tags})
+
+     def update_tags(self, job_id: str,
+                     custom_job_tags: Mapping[str, Optional[str]]) -> bool:
+         r"""Updates the tags of the job.
+
+         Args:
+             job_id (str): The ID of the job.
+             custom_job_tags (Mapping[str, Optional[str]]): The tags to
+                 update. A tag mapped to the value 'none' is removed; a
+                 tag that is not yet present is added.
+         """
+         response = self._client._patch(
+             f'{self._base_endpoint}/{job_id}/tags',
+             data=None,
+             params={
+                 k: str(v)
+                 for k, v in custom_job_tags.items()
+             },
+         )
+         raise_on_error(response)
+         return parse_patch_response(response)
+
+
+ class BaselineJobAPI(CommonJobAPI[BaselineJobRequest, BaselineJobResource]):
+     r"""Typed API definition for the baseline job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         super().__init__(client, '/baseline_jobs', BaselineJobResource)
+
+     def get_config(self, job_id: str) -> BaselineJobRequest:
+         """Load the configuration for a baseline job by ID."""
+         resource = self.get(job_id)
+         return resource.config
+
+
+ class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
+     r"""Typed API definition for the training job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         super().__init__(client, '/training_jobs', TrainingJobResource)
+
+     def get_progress(self, id: TrainingJobID) -> AutoTrainerProgress:
+         response = self._client._get(f'{self._base_endpoint}/{id}/progress')
+         raise_on_error(response)
+         return parse_response(AutoTrainerProgress, response)
+
+     def holdout_data_url(self, id: TrainingJobID,
+                          presigned: bool = True) -> str:
+         response = self._client._get(f'{self._base_endpoint}/{id}/holdout',
+                                      params={'presigned': presigned})
+         raise_on_error(response)
+         return response.text
+
+     def cancel(self, id: str) -> CancelTrainingJobResponse:
+         response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
+         raise_on_error(response)
+         return parse_response(CancelTrainingJobResponse, response)
+
+     def get_config(self, job_id: str) -> TrainingJobRequest:
+         """Load the configuration for a training job by ID."""
+         resource = self.get(job_id)
+         return resource.config
+
+
+ class DistillationJobAPI(CommonJobAPI[DistillationJobRequest,
+                                       DistillationJobResource]):
+     r"""Typed API definition for the distillation job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         super().__init__(client, '/training_jobs/distilled_training_job',
+                          DistillationJobResource)
+
+     def get_config(self, job_id: str) -> DistillationJobRequest:
+         raise NotImplementedError(
+             "Getting the configuration for a distillation job is "
+             "not implemented yet.")
+
+     def get_progress(self, id: str) -> AutoTrainerProgress:
+         raise NotImplementedError(
+             "Getting the progress for a distillation job is not "
+             "implemented yet.")
+
+     def cancel(self, id: str) -> CancelTrainingJobResponse:
+         raise NotImplementedError(
+             "Cancelling a distillation job is not implemented yet.")
+
+
+ class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
+                                          BatchPredictionJobResource]):
+     r"""Typed API definition for the prediction job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         super().__init__(client, '/prediction_jobs',
+                          BatchPredictionJobResource)
+
+     @override
+     def create(self, request: BatchPredictionRequest) -> str:
+         # TODO(manan): eventually, all `create` methods should
+         # return a validation response:
+         raise NotImplementedError
+
+     def maybe_create(
+         self, request: BatchPredictionRequest
+     ) -> Tuple[Optional[str], ValidationResponse]:
+         response = self._client._post(self._base_endpoint,
+                                       json=to_json_dict(request))
+         raise_on_error(response)
+         return parse_response(
+             Returns[Tuple[Optional[str], ValidationResponse]], response)
+
+     def list(
+         self,
+         *,
+         model_id: Optional[TrainingJobID] = None,
+         pquery_name: Optional[str] = None,
+         pquery_id: Optional[str] = None,
+         job_status: Optional[JobStatus] = None,
+         limit: Optional[int] = None,
+         additional_tags: Mapping[str, str] = {},
+     ) -> List[BatchPredictionJobResource]:
+         if model_id:
+             additional_tags = {**additional_tags, 'model_id': model_id}
+         return super().list(pquery_name=pquery_name, pquery_id=pquery_id,
+                             job_status=job_status, limit=limit,
+                             additional_tags=additional_tags)
+
+     def get_progress(self, id: str) -> PredictionProgress:
+         response = self._client._get(f'{self._base_endpoint}/{id}/progress')
+         raise_on_error(response)
+         return parse_response(PredictionProgress, response)
+
+     def cancel(self, id: str) -> CancelBatchPredictionJobResponse:
+         response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
+         raise_on_error(response)
+         return parse_response(CancelBatchPredictionJobResponse, response)
+
+     def get_batch_predictions_url(self, id: str) -> List[str]:
+         """Returns presigned URLs pointing to the locations where the
+         predictions are stored. Depending on the environment where this is run,
+         they could be AWS S3 paths, Snowflake stage paths, or Databricks UC
+         volume paths.
+
+         Args:
+             id (str): ID of the batch prediction job for which predictions are
+                 requested
+         """
+         response = self._client._get(
+             f'{self._base_endpoint}/{id}/get_prediction_df_urls')
+         raise_on_error(response)
+         return parse_response(
+             GetPredictionsDfUrlResponse,
+             response,
+         ).prediction_partitions
+
+     def get_batch_embeddings_url(self, id: str) -> List[str]:
+         """Returns presigned URLs pointing to the locations where the
+         embeddings are stored. Depending on the environment where this is run,
+         they could be AWS S3 paths, Snowflake stage paths, or Databricks UC
+         volume paths.
+
+         Args:
+             id (str): ID of the batch prediction job for which embeddings are
+                 requested
+         """
+         response = self._client._get(
+             f'{self._base_endpoint}/{id}/get_embedding_df_urls')
+         raise_on_error(response)
+         return parse_response(
+             GetEmbeddingsDfUrlResponse,
+             response,
+         ).embedding_partitions
+
+     def get_config(self, job_id: str) -> BatchPredictionRequest:
+         """Load the configuration for a batch prediction job by ID."""
+         resource = self.get(job_id)
+         return resource.config
+
+
+ class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
+                                             GenerateTrainTableJobResource]):
+     r"""Typed API definition for training table generation job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         super().__init__(client, '/gentraintable_jobs',
+                          GenerateTrainTableJobResource)
+
+     def get_table_data(self, id: GenerateTrainTableJobID,
+                        presigned: bool = True) -> List[str]:
+         """Return a list of URLs to access train table parquet data.
+         There might be multiple URLs if the table data is partitioned into
+         multiple files.
+         """
+         return self._get_table_data(id, presigned)
+
+     def _get_table_data(self, id: GenerateTrainTableJobID,
+                         presigned: bool = True,
+                         raw_path: bool = False) -> List[str]:
+         """Helper function to get train table data."""
+         # Raw path to get local file path instead of SPCS stage path
+         params: Dict[str, Any] = {'presigned': presigned, 'raw_path': raw_path}
+         resp = self._client._get(f'{self._base_endpoint}/{id}/table_data',
+                                  params=params)
+         raise_on_error(resp)
+         return parse_response(List[str], resp)
+
+     def get_split_masks(
+             self, id: GenerateTrainTableJobID) -> Dict[TrainingStage, str]:
+         """Return a dictionary of presigned URLs keyed by training stage.
+         Each URL points to a torch-serialized (default pickle protocol) file of
+         the mask tensor for that training stage.
+
+         Example:
+             >>> # code to load a mask tensor:
+             >>> import io
+             >>> import torch
+             >>> import requests
+             >>> masks = get_split_masks('some-gen-traintable-job-id')
+             >>> data_bytes = requests.get(masks[TrainingStage.TEST]).content
+             >>> test_mask_tensor = torch.load(io.BytesIO(data_bytes))
+         """
+         resp = self._client._get(f'{self._base_endpoint}/{id}/split_masks')
+         raise_on_error(resp)
+         return parse_response(Dict[TrainingStage, str], resp)
+
+     def get_progress(self, id: str) -> Dict[str, int]:
+         response = self._client._get(f'{self._base_endpoint}/{id}/progress')
+         raise_on_error(response)
+         return parse_response(Dict[str, int], response)
+
+     def cancel(self, id: str) -> None:
+         response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
+         raise_on_error(response)
+
+     def validate_custom_train_table(
+         self,
+         id: str,
+         source_table_type: SourceTableType,
+         train_table_mod: TrainingTableSpec,
+     ) -> ValidationResponse:
+         response = self._client._post(
+             f'{self._base_endpoint}/{id}/validate_custom_train_table',
+             json=to_json_dict({
+                 'custom_table': source_table_type,
+                 'train_table_mod': train_table_mod,
+             }),
+         )
+         raise_on_error(response)
+         return parse_response(ValidationResponse, response)
+
+     def get_job_error(self, id: str) -> ErrorDetails:
+ """Thin API wrapper for fetching errors from the jobs.
359
+
360
+ Arguments:
361
+ id (str): Id of the job whose related errors are expected to be
362
+ queried.
+         """
+         response = self._client._get(f'{self._base_endpoint}/{id}/get_errors')
+         raise_on_error(response)
+         return parse_response(ErrorDetails, response)
+
+     def get_config(self, job_id: str) -> GenerateTrainTableRequest:
+         """Load the configuration for a training table generation job by ID."""
+         resource = self.get(job_id)
+         return resource.config
+
+
+ class GeneratePredictionTableJobAPI(
+         CommonJobAPI[GeneratePredictionTableRequest,
+                      GeneratePredictionTableJobResource]):
+     r"""Typed API definition for prediction table generation job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         super().__init__(client, '/genpredtable_jobs',
+                          GeneratePredictionTableJobResource)
+
+     def get_anchor_time(self, id: BatchPredictionJobID) -> Optional[datetime]:
+         response = self._client._get(
+             f'{self._base_endpoint}/{id}/get_anchor_time')
+         raise_on_error(response)
+         return parse_response(Returns[Optional[datetime]], response)
+
+     def get_table_data(self, id: GeneratePredictionTableJobID,
+                        presigned: bool = True) -> List[str]:
+         """Return a list of URLs to access prediction table parquet data.
+         There might be multiple URLs if the table data is partitioned into
+         multiple files.
+         """
+         params: Dict[str, Any] = {'presigned': presigned}
+         resp = self._client._get(f'{self._base_endpoint}/{id}/table_data',
+                                  params=params)
+         raise_on_error(resp)
+         return parse_response(List[str], resp)
+
+     def cancel(self, id: str) -> None:
+         response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
+         raise_on_error(response)
+
+     def get_job_error(self, id: str) -> ErrorDetails:
+ """Thin API wrapper for fetching errors from the jobs.
406
+
407
+ Arguments:
408
+ id (str): Id of the job whose related errors are expected to be
409
+ queried.
+         """
+         response = self._client._get(f'{self._base_endpoint}/{id}/get_errors')
+         raise_on_error(response)
+         return parse_response(ErrorDetails, response)
+
+     def get_config(self, job_id: str) -> GeneratePredictionTableRequest:
+         """Load the configuration for a prediction table generation job
+         by ID.
+         """
+         resource = self.get(job_id)
+         return resource.config
+
+
+ class LLMJobAPI:
+     r"""Typed API definition for LLM job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         self._client = client
+         self._base_endpoint = '/llm_embedding_job'
+
+     def create(self, request: LLMRequest) -> LLMJobId:
+         response = self._client._post(
+             self._base_endpoint,
+             json=to_json_dict(request),
+         )
+         raise_on_error(response)
+         return parse_response(LLMResponse, response).job_id
+
+     def get(self, id: LLMJobId) -> JobStatus:
+         response = self._client._get(f'{self._base_endpoint}/status/{id}')
+         raise_on_error(response)
+         return parse_response(JobStatus, response)
+
+     def cancel(self, id: LLMJobId) -> JobStatus:
+         response = self._client._delete(f'{self._base_endpoint}/cancel/{id}')
+         raise_on_error(response)
+         return parse_response(JobStatus, response)
+
+
+ class ArtifactExportJobAPI:
+     r"""Typed API definition for artifact export job resource."""
+     def __init__(self, client: KumoClient) -> None:
+         self._client = client
+         self._base_endpoint = '/artifact'
+
+     def create(self, request: ArtifactExportRequest) -> str:
+         response = self._client._post(
+             self._base_endpoint,
+             json=to_json_dict(request),
+         )
+         raise_on_error(response)
+         return parse_response(ArtifactExportResponse, response).job_id
+
+     # TODO Add an API in artifact export to get
+     # JobStatusReport and not just JobStatus
+     def get(self, id: str) -> JobStatus:
+         response = self._client._get(f'{self._base_endpoint}/{id}')
+         raise_on_error(response)
+         return parse_response(JobStatus, response)
+
+     def cancel(self, id: str) -> JobStatus:
+         response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
+         raise_on_error(response)
+         return parse_response(JobStatus, response)
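
For orientation, here is a minimal usage sketch of the typed job APIs above. It is only a sketch: it assumes an already-configured KumoClient (its constructor lives in kumoai/client/client.py, which is not part of this excerpt), and the job ID is a hypothetical placeholder. Every method call is taken from the code above.

    from kumoai.client import KumoClient
    from kumoai.client.jobs import TrainingJobAPI

    client = KumoClient(...)  # assumed: connection details elided
    api = TrainingJobAPI(client)

    job_id = 'training-job-id'                   # hypothetical job ID
    progress = api.get_progress(job_id)          # AutoTrainerProgress snapshot
    api.update_tags(job_id, {'team': 'search'})  # add or update a custom tag
    api.delete_tags(job_id, ['team'])            # remove it again
    holdout_url = api.holdout_data_url(job_id, presigned=True)
    config = api.get_config(job_id)              # TrainingJobRequest that ran
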
kumoai/client/online.py ADDED
@@ -0,0 +1,78 @@
+ from http import HTTPStatus
+ from typing import Any, List, Optional
+
+ from kumoapi.json_serde import to_json_dict
+ from kumoapi.online_serving import (
+     OnlineServingEndpointRequest,
+     OnlineServingEndpointResource,
+ )
+
+ from kumoai.client import KumoClient
+ from kumoai.client.endpoints import OnlineServingEndpoints
+ from kumoai.client.utils import (
+     parse_id_response,
+     parse_patch_response,
+     parse_response,
+     raise_on_error,
+ )
+
+ OnlineServingEndpointID = str
+
+
+ class OnlineServingEndpointAPI:
+ r"""Typed API definition for Kumo graph definition."""
+     def __init__(self, client: KumoClient) -> None:
+         self._client = client
+         self._base_endpoint = '/online_serving_endpoints'
+
+     # TODO(blaz): document final interface
+     def create(
+         self,
+         req: OnlineServingEndpointRequest,
+         **query_params: Any,
+     ) -> OnlineServingEndpointID:
+ """Creates a new online serving endpoint.
35
+
36
+ Args:
37
+ req (OnlineServingEndpointRequest): request body.
38
+ use_ge (Optional[bool], optional): If present, override graph
39
+ backend option to use GRAPHENGINE if true else MEMORY.
40
+ **query_params: Additional query parameters to pass to the API.
+
+         Returns:
+             OnlineServingEndpointID: unique endpoint resource id.
+         """
+         resp = self._client._post(
+             self._base_endpoint,
+             params=query_params if query_params else None,
+             json=to_json_dict(req),
+         )
+         raise_on_error(resp)
+         return parse_id_response(resp)
+
+     def get_if_exists(
+         self, id: OnlineServingEndpointID
+     ) -> Optional[OnlineServingEndpointResource]:
+         resp = self._client._request(OnlineServingEndpoints.get.with_id(id))
+         if resp.status_code == HTTPStatus.NOT_FOUND:
+             return None
+
+         raise_on_error(resp)
+         return parse_response(OnlineServingEndpointResource, resp)
+
+     def list(self) -> List[OnlineServingEndpointResource]:
+         resp = self._client._request(OnlineServingEndpoints.list)
+         raise_on_error(resp)
+         return parse_response(List[OnlineServingEndpointResource], resp)
+
+     def update(self, id: OnlineServingEndpointID,
+                req: OnlineServingEndpointRequest) -> bool:
+         resp = self._client._request(OnlineServingEndpoints.update.with_id(id),
+                                      data=to_json_dict(req))
+         raise_on_error(resp)
+         return parse_patch_response(resp)
+
+     def delete(self, id: OnlineServingEndpointID) -> None:
+ """This is idempotent and can be called multiple times."""
+         resp = self._client._request(OnlineServingEndpoints.delete.with_id(id))
+         raise_on_error(resp)
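
And a matching sketch for the online serving endpoint API. As above, the KumoClient construction and the OnlineServingEndpointRequest fields are assumptions (both are defined elsewhere in this package, outside this excerpt); only the method calls shown are taken from the code above. Note that delete is idempotent, so calling it on an already-deleted endpoint is safe.

    from kumoapi.online_serving import OnlineServingEndpointRequest

    from kumoai.client import KumoClient
    from kumoai.client.online import OnlineServingEndpointAPI

    client = KumoClient(...)  # assumed: connection details elided
    api = OnlineServingEndpointAPI(client)

    req = OnlineServingEndpointRequest(...)  # assumed: request fields elided
    endpoint_id = api.create(req)

    resource = api.get_if_exists(endpoint_id)
    if resource is None:
        print('endpoint not found (HTTP 404)')

    api.delete(endpoint_id)  # idempotent: a second delete is a no-op
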