kumoai 2.13.0.dev202511131731__cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (98) hide show
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
kumoai/client/jobs.py ADDED
@@ -0,0 +1,447 @@
1
+ from datetime import datetime
2
+ from typing import (
3
+ Any,
4
+ Dict,
5
+ Generic,
6
+ List,
7
+ Mapping,
8
+ Optional,
9
+ Tuple,
10
+ Type,
11
+ TypeVar,
12
+ )
13
+
14
+ from kumoapi.common import ValidationResponse
15
+ from kumoapi.jobs import (
16
+ ArtifactExportRequest,
17
+ ArtifactExportResponse,
18
+ AutoTrainerProgress,
19
+ BaselineJobRequest,
20
+ BaselineJobResource,
21
+ BatchPredictionJobResource,
22
+ BatchPredictionRequest,
23
+ CancelBatchPredictionJobResponse,
24
+ CancelTrainingJobResponse,
25
+ ErrorDetails,
26
+ GeneratePredictionTableJobResource,
27
+ GeneratePredictionTableRequest,
28
+ GenerateTrainTableJobResource,
29
+ GenerateTrainTableRequest,
30
+ GetEmbeddingsDfUrlResponse,
31
+ GetPredictionsDfUrlResponse,
32
+ JobRequestBase,
33
+ JobResourceBase,
34
+ JobStatus,
35
+ PredictionProgress,
36
+ TrainingJobRequest,
37
+ TrainingJobResource,
38
+ TrainingTableSpec,
39
+ )
40
+ from kumoapi.json_serde import from_json, to_json_dict
41
+ from kumoapi.source_table import LLMRequest, LLMResponse, SourceTableType
42
+ from kumoapi.train import TrainingStage
43
+ from typing_extensions import override
44
+
45
+ from kumoai.client import KumoClient
46
+ from kumoai.client.utils import (
47
+ Returns,
48
+ parse_patch_response,
49
+ parse_response,
50
+ raise_on_error,
51
+ )
52
+
53
# Type aliases for job identifiers; all are plain strings on the wire.
TrainingJobID = str
BatchPredictionJobID = str
GenerateTrainTableJobID = str
GeneratePredictionTableJobID = str
LLMJobId = str
BaselineJobID = str

# Request/resource type pair bound by each concrete job API subclass of
# ``CommonJobAPI``.
JobRequestType = TypeVar('JobRequestType', bound=JobRequestBase)
JobResourceType = TypeVar('JobResourceType', bound=JobResourceBase)
62
+
63
+
64
class CommonJobAPI(Generic[JobRequestType, JobResourceType]):
    r"""Generic typed REST API shared by all Kumo job resources.

    Provides ``create``/``get``/``list`` and tag management. Subclasses
    bind a concrete request/resource type pair and a base endpoint path
    (e.g. ``/training_jobs``).
    """
    def __init__(self, client: KumoClient, base_endpoint: str,
                 res_type: Type[JobResourceType]) -> None:
        self._client = client
        self._base_endpoint = base_endpoint
        # Concrete resource class used to deserialize API responses:
        self._res_type = res_type

    def create(self, request: JobRequestType) -> str:
        r"""Submits ``request`` and returns the ID of the created job."""
        response = self._client._post(self._base_endpoint,
                                      json=to_json_dict(request))
        raise_on_error(response)
        return parse_response(Dict[str, str], response)['id']

    def get(self, id: str) -> JobResourceType:
        r"""Fetches the job resource with the given ID."""
        response = self._client._get(f'{self._base_endpoint}/{id}')
        raise_on_error(response)
        return parse_response(self._res_type, response)

    def list(
        self,
        *,
        pquery_name: Optional[str] = None,
        pquery_id: Optional[str] = None,
        job_status: Optional[JobStatus] = None,
        limit: Optional[int] = None,
        additional_tags: Mapping[str, str] = {},
    ) -> List[JobResourceType]:
        r"""Lists job resources, optionally filtered.

        Args:
            pquery_name (Optional[str]): Filter by predictive query name.
            pquery_id (Optional[str]): Filter by predictive query ID.
            job_status (Optional[JobStatus]): Filter by job status.
            limit (Optional[int]): Maximum number of results.
            additional_tags (Mapping[str, str]): Extra tag filters merged
                into the query parameters. NOTE: the ``{}`` default is only
                read, never mutated, so the shared-mutable-default pitfall
                does not apply here.
        """
        params: Dict[str, Any] = {
            'pquery_name': pquery_name,
            'pquery_id': pquery_id,
            'job_status': job_status,
            'limit': limit
        }
        params.update(additional_tags)
        response = self._client._get(self._base_endpoint, params=params)
        raise_on_error(response)
        resource_elements = response.json()
        assert isinstance(resource_elements, list)
        return [from_json(e, self._res_type) for e in resource_elements]

    def delete_tags(self, job_id: str, tags: List[str]) -> bool:
        r"""Removes the tags from the job.

        Args:
            job_id (str): The ID of the job.
            tags (List[str]): The tags to remove.
        """
        # The backend treats the literal string value 'none' as removal:
        return self.update_tags(job_id, {t: 'none' for t in tags})

    def update_tags(self, job_id: str,
                    custom_job_tags: Mapping[str, Optional[str]]) -> bool:
        r"""Updates the tags of the job.

        Args:
            job_id (str): The ID of the job.
            custom_job_tags (Mapping[str, Optional[str]]): The tags to update.
                Note that the value 'none' will remove the tag. If the tag is
                not present, it will be added.
        """
        response = self._client._patch(
            f'{self._base_endpoint}/{job_id}/tags',
            data=None,
            # Values are stringified because they are sent as query params:
            params={
                k: str(v)
                for k, v in custom_job_tags.items()
            },
        )
        raise_on_error(response)
        return parse_patch_response(response)
133
+
134
+
135
class BaselineJobAPI(CommonJobAPI[BaselineJobRequest, BaselineJobResource]):
    r"""Typed API wrapper around the baseline job resource endpoints."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/baseline_jobs', BaselineJobResource)

    def get_config(self, job_id: str) -> BaselineJobRequest:
        """Fetch the request configuration of the baseline job by ID."""
        return self.get(job_id).config
144
+
145
+
146
class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
    r"""Typed API wrapper around the training job resource endpoints."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/training_jobs', TrainingJobResource)

    def get_progress(self, id: TrainingJobID) -> AutoTrainerProgress:
        """Return the auto-trainer's current progress for the given job."""
        resp = self._client._get(f'{self._base_endpoint}/{id}/progress')
        raise_on_error(resp)
        return parse_response(AutoTrainerProgress, resp)

    def holdout_data_url(self, id: TrainingJobID,
                         presigned: bool = True) -> str:
        """Return a URL (presigned by default) to the holdout dataset."""
        resp = self._client._get(f'{self._base_endpoint}/{id}/holdout',
                                 params={'presigned': presigned})
        raise_on_error(resp)
        # The endpoint returns the URL as a raw text body, not JSON:
        return resp.text

    def cancel(self, id: str) -> CancelTrainingJobResponse:
        """Request cancellation of a running training job."""
        resp = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(resp)
        return parse_response(CancelTrainingJobResponse, resp)

    def get_config(self, job_id: str) -> TrainingJobRequest:
        """Load the configuration for a training job by ID."""
        return self.get(job_id).config
172
+
173
+
174
class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
                                         BatchPredictionJobResource]):
    r"""Typed API definition for the prediction job resource."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/prediction_jobs',
                         BatchPredictionJobResource)

    @override
    def create(self, request: BatchPredictionRequest) -> str:
        r"""Unsupported for this resource; use :meth:`maybe_create`."""
        # TODO(manan): eventually, all `create` methods should
        # return a validation response:
        raise NotImplementedError

    def maybe_create(
        self, request: BatchPredictionRequest
    ) -> Tuple[Optional[str], ValidationResponse]:
        r"""Submits the request; returns the created job's ID (or ``None``
        when validation fails) together with the validation outcome."""
        response = self._client._post(self._base_endpoint,
                                      json=to_json_dict(request))
        raise_on_error(response)
        return parse_response(
            Returns[Tuple[Optional[str], ValidationResponse]], response)

    def list(
        self,
        *,
        model_id: Optional[TrainingJobID] = None,
        pquery_name: Optional[str] = None,
        pquery_id: Optional[str] = None,
        job_status: Optional[JobStatus] = None,
        limit: Optional[int] = None,
        additional_tags: Mapping[str, str] = {},
    ) -> List[BatchPredictionJobResource]:
        r"""Lists prediction jobs; ``model_id`` is applied as a tag filter
        on top of the base filters."""
        if model_id:
            # Copy, never mutate the caller's (or the default) mapping:
            additional_tags = {**additional_tags, 'model_id': model_id}
        return super().list(pquery_name=pquery_name, pquery_id=pquery_id,
                            job_status=job_status, limit=limit,
                            additional_tags=additional_tags)

    def get_progress(self, id: str) -> PredictionProgress:
        r"""Returns the current progress of the prediction job."""
        response = self._client._get(f'{self._base_endpoint}/{id}/progress')
        raise_on_error(response)
        return parse_response(PredictionProgress, response)

    def cancel(self, id: str) -> CancelBatchPredictionJobResponse:
        r"""Requests cancellation of the prediction job."""
        response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(response)
        return parse_response(CancelBatchPredictionJobResponse, response)

    def get_batch_predictions_url(self, id: str) -> List[str]:
        """Returns presigned URLs pointing to the locations where the
        predictions are stored. Depending on the environment where this is run,
        they could be AWS S3 paths, Snowflake stage paths, or Databricks UC
        volume paths.

        Args:
            id (str): ID of the batch prediction job for which predictions are
                requested
        """
        response = self._client._get(
            f'{self._base_endpoint}/{id}/get_prediction_df_urls')
        raise_on_error(response)
        return parse_response(
            GetPredictionsDfUrlResponse,
            response,
        ).prediction_partitions

    def get_batch_embeddings_url(self, id: str) -> List[str]:
        """Returns presigned URLs pointing to the locations where the
        embeddings are stored. Depending on the environment where this is run,
        they could be AWS S3 paths, Snowflake stage paths, or Databricks UC
        volume paths.

        Args:
            id (str): ID of the batch prediction job for which embeddings are
                requested
        """
        response = self._client._get(
            f'{self._base_endpoint}/{id}/get_embedding_df_urls')
        raise_on_error(response)
        return parse_response(
            GetEmbeddingsDfUrlResponse,
            response,
        ).embedding_partitions

    def get_config(self, job_id: str) -> BatchPredictionRequest:
        """Load the configuration for a batch prediction job by ID."""
        resource = self.get(job_id)
        return resource.config
262
+
263
+
264
class GenerateTrainTableJobAPI(CommonJobAPI[GenerateTrainTableRequest,
                                            GenerateTrainTableJobResource]):
    r"""Typed API definition for training table generation job resource."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/gentraintable_jobs',
                         GenerateTrainTableJobResource)

    def get_table_data(self, id: GenerateTrainTableJobID,
                       presigned: bool = True) -> List[str]:
        """Return a list of URLs to access train table parquet data.
        There might be multiple URLs if the table data is partitioned into
        multiple files.
        """
        return self._get_table_data(id, presigned)

    def _get_table_data(self, id: GenerateTrainTableJobID,
                        presigned: bool = True,
                        raw_path: bool = False) -> List[str]:
        """Helper function to get train table data."""
        # Raw path to get local file path instead of SPCS stage path
        params: Dict[str, Any] = {'presigned': presigned, 'raw_path': raw_path}
        resp = self._client._get(f'{self._base_endpoint}/{id}/table_data',
                                 params=params)
        raise_on_error(resp)
        return parse_response(List[str], resp)

    def get_split_masks(
            self, id: GenerateTrainTableJobID) -> Dict[TrainingStage, str]:
        """Return a dictionary of presigned URLs keyed by training stage.
        Each URL points to a torch-serialized (default pickle protocol) file of
        the mask tensor for that training stage.

        Example:
            >>> # code to load a mask tensor:
            >>> import io
            >>> import torch
            >>> import requests
            >>> masks = get_split_masks('some-gen-traintable-job-id')
            >>> data_bytes = requests.get(masks[TrainingStage.TEST]).content
            >>> test_mask_tensor = torch.load(io.BytesIO(data_bytes))
        """
        resp = self._client._get(f'{self._base_endpoint}/{id}/split_masks')
        raise_on_error(resp)
        return parse_response(Dict[TrainingStage, str], resp)

    def get_progress(self, id: str) -> Dict[str, int]:
        """Return progress counters for the train-table generation job."""
        response = self._client._get(f'{self._base_endpoint}/{id}/progress')
        raise_on_error(response)
        return parse_response(Dict[str, int], response)

    def cancel(self, id: str) -> None:
        """Request cancellation of the train-table generation job."""
        response = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(response)

    def validate_custom_train_table(
        self,
        id: str,
        source_table_type: SourceTableType,
        train_table_mod: TrainingTableSpec,
    ) -> ValidationResponse:
        """Validate a user-supplied custom train table for this job.

        Args:
            id (str): ID of the train-table generation job.
            source_table_type (SourceTableType): The custom table source.
            train_table_mod (TrainingTableSpec): The train table
                modification spec.
        """
        response = self._client._post(
            f'{self._base_endpoint}/{id}/validate_custom_train_table',
            json=to_json_dict({
                'custom_table': source_table_type,
                'train_table_mod': train_table_mod,
            }),
        )
        # Fix: check the HTTP status before parsing. Previously an error
        # response was fed straight to `parse_response`, unlike every other
        # endpoint in this module.
        raise_on_error(response)
        return parse_response(ValidationResponse, response)

    def get_job_error(self, id: str) -> ErrorDetails:
        """Thin API wrapper for fetching errors from the jobs.

        Arguments:
            id (str): Id of the job whose related errors are expected to be
                queried.
        """
        response = self._client._get(f'{self._base_endpoint}/{id}/get_errors')
        raise_on_error(response)
        return parse_response(ErrorDetails, response)

    def get_config(self, job_id: str) -> GenerateTrainTableRequest:
        """Load the configuration for a training table generation job by ID."""
        resource = self.get(job_id)
        return resource.config
348
+
349
+
350
class GeneratePredictionTableJobAPI(
        CommonJobAPI[GeneratePredictionTableRequest,
                     GeneratePredictionTableJobResource]):
    r"""Typed API definition for prediction table generation job resource."""
    def __init__(self, client: KumoClient) -> None:
        super().__init__(client, '/genpredtable_jobs',
                         GeneratePredictionTableJobResource)

    def get_anchor_time(self, id: BatchPredictionJobID) -> Optional[datetime]:
        """Return the job's anchor time, if one is set."""
        resp = self._client._get(
            f'{self._base_endpoint}/{id}/get_anchor_time')
        raise_on_error(resp)
        return parse_response(Returns[Optional[datetime]], resp)

    def get_table_data(self, id: GeneratePredictionTableJobID,
                       presigned: bool = True) -> List[str]:
        """Return a list of URLs to access prediction table parquet data.
        There might be multiple URLs if the table data is partitioned into
        multiple files.
        """
        query: Dict[str, Any] = {'presigned': presigned}
        resp = self._client._get(f'{self._base_endpoint}/{id}/table_data',
                                 params=query)
        raise_on_error(resp)
        return parse_response(List[str], resp)

    def cancel(self, id: str) -> None:
        """Request cancellation of the prediction-table generation job."""
        resp = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(resp)

    def get_job_error(self, id: str) -> ErrorDetails:
        """Thin API wrapper for fetching errors from the jobs.

        Arguments:
            id (str): Id of the job whose related errors are expected to be
                queried.
        """
        resp = self._client._get(f'{self._base_endpoint}/{id}/get_errors')
        raise_on_error(resp)
        return parse_response(ErrorDetails, resp)

    def get_config(self, job_id: str) -> GeneratePredictionTableRequest:
        """Load the configuration for a
        prediction table generation job by ID.
        """
        return self.get(job_id).config
397
+
398
+
399
class LLMJobAPI:
    r"""Typed API definition for LLM job resource."""
    def __init__(self, client: KumoClient) -> None:
        self._client = client
        self._base_endpoint = '/llm_embedding_job'

    def create(self, request: LLMRequest) -> LLMJobId:
        """Submit an LLM embedding job; returns the new job's ID."""
        response = self._client._post(
            self._base_endpoint,
            json=to_json_dict(request),
        )
        raise_on_error(response)
        return parse_response(LLMResponse, response).job_id

    def get(self, id: LLMJobId) -> JobStatus:
        """Return the current status of the LLM embedding job."""
        response = self._client._get(f'{self._base_endpoint}/status/{id}')
        raise_on_error(response)
        return parse_response(JobStatus, response)

    def cancel(self, id: LLMJobId) -> JobStatus:
        """Cancel the LLM embedding job and return its resulting status."""
        response = self._client._delete(f'{self._base_endpoint}/cancel/{id}')
        # Fix: previously this returned the raw HTTP response object despite
        # the `JobStatus` annotation, and never checked for errors. Validate
        # and parse like `get` and every other `cancel` in this module.
        raise_on_error(response)
        return parse_response(JobStatus, response)
421
+
422
+
423
class ArtifactExportJobAPI:
    r"""Typed API wrapper around the artifact export job endpoints."""
    def __init__(self, client: KumoClient) -> None:
        self._client = client
        self._base_endpoint = '/artifact'

    def create(self, request: ArtifactExportRequest) -> str:
        """Kick off an artifact export; returns the new job's ID."""
        resp = self._client._post(
            self._base_endpoint,
            json=to_json_dict(request),
        )
        raise_on_error(resp)
        return parse_response(ArtifactExportResponse, resp).job_id

    # TODO Add an API in artifact export to get
    # JobStatusReport and not just JobStatus
    def get(self, id: str) -> JobStatus:
        """Return the current status of the export job."""
        resp = self._client._get(f'{self._base_endpoint}/{id}')
        raise_on_error(resp)
        return parse_response(JobStatus, resp)

    def cancel(self, id: str) -> JobStatus:
        """Cancel the export job and return its resulting status."""
        resp = self._client._post(f'{self._base_endpoint}/{id}/cancel')
        raise_on_error(resp)
        return parse_response(JobStatus, resp)
@@ -0,0 +1,78 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, List, Optional
3
+
4
+ from kumoapi.json_serde import to_json_dict
5
+ from kumoapi.online_serving import (
6
+ OnlineServingEndpointRequest,
7
+ OnlineServingEndpointResource,
8
+ )
9
+
10
+ from kumoai.client import KumoClient
11
+ from kumoai.client.endpoints import OnlineServingEndpoints
12
+ from kumoai.client.utils import (
13
+ parse_id_response,
14
+ parse_patch_response,
15
+ parse_response,
16
+ raise_on_error,
17
+ )
18
+
19
# Unique identifier of an online serving endpoint (plain string on the wire).
OnlineServingEndpointID = str
20
+
21
+
22
+ class OnlineServingEndpointAPI:
23
+ r"""Typed API definition for Kumo graph definition."""
24
+ def __init__(self, client: KumoClient) -> None:
25
+ self._client = client
26
+ self._base_endpoint = '/online_serving_endpoints'
27
+
28
+ # TODO(blaz): document final interface
29
+ def create(
30
+ self,
31
+ req: OnlineServingEndpointRequest,
32
+ **query_params: Any,
33
+ ) -> OnlineServingEndpointID:
34
+ """Creates a new online serving endpoint.
35
+
36
+ Args:
37
+ req (OnlineServingEndpointRequest): request body.
38
+ use_ge (Optional[bool], optional): If present, override graph
39
+ backend option to use GRAPHENGINE if true else MEMORY.
40
+ **query_params: Additional query parameters to pass to the API.
41
+
42
+ Returns:
43
+ OnlineServingEndpointID: unique endpoint resource id.
44
+ """
45
+ resp = self._client._post(
46
+ self._base_endpoint,
47
+ params=query_params if query_params else None,
48
+ json=to_json_dict(req),
49
+ )
50
+ raise_on_error(resp)
51
+ return parse_id_response(resp)
52
+
53
+ def get_if_exists(
54
+ self, id: OnlineServingEndpointID
55
+ ) -> Optional[OnlineServingEndpointResource]:
56
+ resp = self._client._request(OnlineServingEndpoints.get.with_id(id))
57
+ if resp.status_code == HTTPStatus.NOT_FOUND:
58
+ return None
59
+
60
+ raise_on_error(resp)
61
+ return parse_response(OnlineServingEndpointResource, resp)
62
+
63
+ def list(self) -> List[OnlineServingEndpointResource]:
64
+ resp = self._client._request(OnlineServingEndpoints.list)
65
+ raise_on_error(resp)
66
+ return parse_response(List[OnlineServingEndpointResource], resp)
67
+
68
+ def update(self, id: OnlineServingEndpointID,
69
+ req: OnlineServingEndpointRequest) -> bool:
70
+ resp = self._client._request(OnlineServingEndpoints.update.with_id(id),
71
+ data=to_json_dict(req))
72
+ raise_on_error(resp)
73
+ return parse_patch_response(resp)
74
+
75
+ def delete(self, id: OnlineServingEndpointID) -> None:
76
+ """This is idempotent and can be called multiple times."""
77
+ resp = self._client._request(OnlineServingEndpoints.delete.with_id(id))
78
+ raise_on_error(resp)