oracle-ads 2.11.15__py3-none-any.whl → 2.11.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/common/entities.py +17 -0
- ads/aqua/common/enums.py +5 -1
- ads/aqua/common/utils.py +32 -2
- ads/aqua/config/config.py +1 -1
- ads/aqua/config/deployment_config_defaults.json +29 -1
- ads/aqua/config/resource_limit_names.json +1 -0
- ads/aqua/constants.py +5 -1
- ads/aqua/evaluation/entities.py +0 -1
- ads/aqua/evaluation/evaluation.py +47 -14
- ads/aqua/extension/common_ws_msg_handler.py +57 -0
- ads/aqua/extension/deployment_handler.py +14 -13
- ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
- ads/aqua/extension/errors.py +1 -1
- ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
- ads/aqua/extension/model_handler.py +31 -6
- ads/aqua/extension/models/ws_models.py +78 -3
- ads/aqua/extension/models_ws_msg_handler.py +49 -0
- ads/aqua/extension/ui_websocket_handler.py +7 -1
- ads/aqua/model/entities.py +11 -1
- ads/aqua/model/model.py +260 -90
- ads/aqua/modeldeployment/deployment.py +52 -7
- ads/aqua/modeldeployment/entities.py +9 -20
- ads/aqua/ui.py +152 -28
- ads/common/object_storage_details.py +2 -5
- ads/common/serializer.py +2 -3
- ads/jobs/builders/infrastructure/dsc_job.py +29 -3
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
- ads/jobs/builders/runtimes/container_runtime.py +83 -4
- ads/opctl/operator/lowcode/anomaly/const.py +1 -0
- ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
- ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
- ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
- ads/opctl/operator/lowcode/common/errors.py +6 -0
- ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
- ads/pipeline/ads_pipeline_run.py +13 -2
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/METADATA +1 -1
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/RECORD +41 -37
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.16.dist-info}/entry_points.txt +0 -0
@@ -1,12 +1,14 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
|
6
5
|
from dataclasses import dataclass, field
|
7
6
|
from typing import Union
|
8
7
|
|
9
|
-
from oci.data_science.models import
|
8
|
+
from oci.data_science.models import (
|
9
|
+
ModelDeployment,
|
10
|
+
ModelDeploymentSummary,
|
11
|
+
)
|
10
12
|
|
11
13
|
from ads.aqua.common.enums import Tags
|
12
14
|
from ads.aqua.constants import UNKNOWN, UNKNOWN_DICT
|
@@ -24,18 +26,6 @@ class ModelParams:
|
|
24
26
|
model: str = None
|
25
27
|
|
26
28
|
|
27
|
-
class ContainerSpec:
|
28
|
-
"""
|
29
|
-
Class to hold to hold keys within the container spec.
|
30
|
-
"""
|
31
|
-
|
32
|
-
CONTAINER_SPEC = "containerSpec"
|
33
|
-
CLI_PARM = "cliParam"
|
34
|
-
SERVER_PORT = "serverPort"
|
35
|
-
HEALTH_CHECK_PORT = "healthCheckPort"
|
36
|
-
ENV_VARS = "envVars"
|
37
|
-
|
38
|
-
|
39
29
|
@dataclass
|
40
30
|
class ShapeInfo:
|
41
31
|
instance_shape: str = None
|
@@ -61,6 +51,7 @@ class AquaDeployment(DataClassSerializable):
|
|
61
51
|
lifecycle_details: str = None
|
62
52
|
shape_info: field(default_factory=ShapeInfo) = None
|
63
53
|
tags: dict = None
|
54
|
+
environment_variables: dict = None
|
64
55
|
|
65
56
|
@classmethod
|
66
57
|
def from_oci_model_deployment(
|
@@ -83,15 +74,12 @@ class AquaDeployment(DataClassSerializable):
|
|
83
74
|
AquaDeployment:
|
84
75
|
The instance of the Aqua model deployment.
|
85
76
|
"""
|
86
|
-
instance_configuration =
|
87
|
-
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
|
88
|
-
)
|
77
|
+
instance_configuration = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
|
89
78
|
instance_shape_config_details = (
|
90
79
|
instance_configuration.model_deployment_instance_shape_config_details
|
91
80
|
)
|
92
|
-
instance_count =
|
93
|
-
|
94
|
-
)
|
81
|
+
instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
|
82
|
+
environment_variables = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.environment_variables
|
95
83
|
shape_info = ShapeInfo(
|
96
84
|
instance_shape=instance_configuration.instance_shape_name,
|
97
85
|
instance_count=instance_count,
|
@@ -131,6 +119,7 @@ class AquaDeployment(DataClassSerializable):
|
|
131
119
|
region=region,
|
132
120
|
),
|
133
121
|
tags=freeform_tags,
|
122
|
+
environment_variables=environment_variables,
|
134
123
|
)
|
135
124
|
|
136
125
|
|
ads/aqua/ui.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*-
|
3
2
|
# Copyright (c) 2024 Oracle and/or its affiliates.
|
4
3
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5
4
|
import concurrent.futures
|
6
|
-
from dataclasses import dataclass, field
|
5
|
+
from dataclasses import dataclass, field, fields
|
7
6
|
from datetime import datetime, timedelta
|
7
|
+
from enum import Enum
|
8
8
|
from threading import Lock
|
9
|
-
from typing import Dict, List
|
9
|
+
from typing import Dict, List, Optional
|
10
10
|
|
11
11
|
from cachetools import TTLCache
|
12
12
|
from oci.exceptions import ServiceError
|
@@ -14,6 +14,7 @@ from oci.identity.models import Compartment
|
|
14
14
|
|
15
15
|
from ads.aqua import logger
|
16
16
|
from ads.aqua.app import AquaApp
|
17
|
+
from ads.aqua.common.entities import ContainerSpec
|
17
18
|
from ads.aqua.common.enums import Tags
|
18
19
|
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
|
19
20
|
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
|
@@ -31,14 +32,84 @@ from ads.config import (
|
|
31
32
|
from ads.telemetry import telemetry
|
32
33
|
|
33
34
|
|
35
|
+
class ModelFormat(Enum):
|
36
|
+
GGUF = "GGUF"
|
37
|
+
SAFETENSORS = "SAFETENSORS"
|
38
|
+
UNKNOWN = "UNKNOWN"
|
39
|
+
|
40
|
+
def to_dict(self):
|
41
|
+
return self.value
|
42
|
+
|
43
|
+
|
44
|
+
# todo: the container config spec information is shared across ui and deployment modules, move them
|
45
|
+
# within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
|
46
|
+
|
47
|
+
|
48
|
+
@dataclass(repr=False)
|
49
|
+
class AquaContainerEvaluationConfig(DataClassSerializable):
|
50
|
+
"""
|
51
|
+
Represents the evaluation configuration for the container.
|
52
|
+
"""
|
53
|
+
|
54
|
+
inference_max_threads: Optional[int] = None
|
55
|
+
inference_rps: Optional[int] = None
|
56
|
+
inference_timeout: Optional[int] = None
|
57
|
+
inference_retries: Optional[int] = None
|
58
|
+
inference_backoff_factor: Optional[int] = None
|
59
|
+
inference_delay: Optional[int] = None
|
60
|
+
|
61
|
+
@classmethod
|
62
|
+
def from_config(cls, config: dict) -> "AquaContainerEvaluationConfig":
|
63
|
+
return cls(
|
64
|
+
inference_max_threads=config.get("inference_max_threads"),
|
65
|
+
inference_rps=config.get("inference_rps"),
|
66
|
+
inference_timeout=config.get("inference_timeout"),
|
67
|
+
inference_retries=config.get("inference_retries"),
|
68
|
+
inference_backoff_factor=config.get("inference_backoff_factor"),
|
69
|
+
inference_delay=config.get("inference_delay"),
|
70
|
+
)
|
71
|
+
|
72
|
+
def to_filtered_dict(self):
|
73
|
+
return {
|
74
|
+
field.name: getattr(self, field.name)
|
75
|
+
for field in fields(self)
|
76
|
+
if getattr(self, field.name) is not None
|
77
|
+
}
|
78
|
+
|
79
|
+
|
80
|
+
@dataclass(repr=False)
|
81
|
+
class AquaContainerConfigSpec(DataClassSerializable):
|
82
|
+
cli_param: str = None
|
83
|
+
server_port: str = None
|
84
|
+
health_check_port: str = None
|
85
|
+
env_vars: List[dict] = None
|
86
|
+
restricted_params: List[str] = None
|
87
|
+
evaluation_configuration: AquaContainerEvaluationConfig = field(
|
88
|
+
default_factory=AquaContainerEvaluationConfig
|
89
|
+
)
|
90
|
+
|
91
|
+
|
34
92
|
@dataclass(repr=False)
|
35
93
|
class AquaContainerConfigItem(DataClassSerializable):
|
36
94
|
"""Represents an item of the AQUA container configuration."""
|
37
95
|
|
96
|
+
class Platform(Enum):
|
97
|
+
ARM_CPU = "ARM_CPU"
|
98
|
+
NVIDIA_GPU = "NVIDIA_GPU"
|
99
|
+
|
100
|
+
def to_dict(self):
|
101
|
+
return self.value
|
102
|
+
|
103
|
+
def __repr__(self):
|
104
|
+
return repr(self.value)
|
105
|
+
|
38
106
|
name: str = None
|
39
107
|
version: str = None
|
40
108
|
display_name: str = None
|
41
109
|
family: str = None
|
110
|
+
platforms: List[Platform] = None
|
111
|
+
model_formats: List[ModelFormat] = None
|
112
|
+
spec: AquaContainerConfigSpec = field(default_factory=AquaContainerConfigSpec)
|
42
113
|
|
43
114
|
|
44
115
|
@dataclass(repr=False)
|
@@ -47,12 +118,23 @@ class AquaContainerConfig(DataClassSerializable):
|
|
47
118
|
Represents a configuration with AQUA containers to be returned to the client.
|
48
119
|
"""
|
49
120
|
|
50
|
-
inference:
|
51
|
-
finetune:
|
52
|
-
evaluate:
|
121
|
+
inference: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
|
122
|
+
finetune: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
|
123
|
+
evaluate: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
|
124
|
+
|
125
|
+
def to_dict(self):
|
126
|
+
return {
|
127
|
+
"inference": list(self.inference.values()),
|
128
|
+
"finetune": list(self.finetune.values()),
|
129
|
+
"evaluate": list(self.evaluate.values()),
|
130
|
+
}
|
53
131
|
|
54
132
|
@classmethod
|
55
|
-
def from_container_index_json(
|
133
|
+
def from_container_index_json(
|
134
|
+
cls,
|
135
|
+
config: Optional[Dict] = None,
|
136
|
+
enable_spec: Optional[bool] = False,
|
137
|
+
) -> "AquaContainerConfig":
|
56
138
|
"""
|
57
139
|
Create an AquaContainerConfig instance from a container index JSON.
|
58
140
|
|
@@ -60,21 +142,39 @@ class AquaContainerConfig(DataClassSerializable):
|
|
60
142
|
----------
|
61
143
|
config : Dict
|
62
144
|
The container index JSON.
|
145
|
+
enable_spec: bool
|
146
|
+
flag to check if container specification details should be fetched.
|
63
147
|
|
64
148
|
Returns
|
65
149
|
-------
|
66
150
|
AquaContainerConfig
|
67
151
|
The container configuration instance.
|
68
152
|
"""
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
153
|
+
if not config:
|
154
|
+
config = get_container_config()
|
155
|
+
inference_items = {}
|
156
|
+
finetune_items = {}
|
157
|
+
evaluate_items = {}
|
73
158
|
|
74
159
|
# extract inference containers
|
75
160
|
for container_type, containers in config.items():
|
76
161
|
if isinstance(containers, list):
|
77
162
|
for container in containers:
|
163
|
+
platforms = [
|
164
|
+
AquaContainerConfigItem.Platform[platform]
|
165
|
+
for platform in container.get("platforms", [])
|
166
|
+
]
|
167
|
+
model_formats = [
|
168
|
+
ModelFormat[model_format]
|
169
|
+
for model_format in container.get("modelFormats", [])
|
170
|
+
]
|
171
|
+
container_spec = (
|
172
|
+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
|
173
|
+
container_type, {}
|
174
|
+
)
|
175
|
+
if enable_spec
|
176
|
+
else None
|
177
|
+
)
|
78
178
|
container_item = AquaContainerConfigItem(
|
79
179
|
name=container.get("name", ""),
|
80
180
|
version=container.get("version", ""),
|
@@ -82,13 +182,35 @@ class AquaContainerConfig(DataClassSerializable):
|
|
82
182
|
"displayName", container.get("version", "")
|
83
183
|
),
|
84
184
|
family=container_type,
|
185
|
+
platforms=platforms,
|
186
|
+
model_formats=model_formats,
|
187
|
+
spec=AquaContainerConfigSpec(
|
188
|
+
cli_param=container_spec.get(ContainerSpec.CLI_PARM, ""),
|
189
|
+
server_port=container_spec.get(
|
190
|
+
ContainerSpec.SERVER_PORT, ""
|
191
|
+
),
|
192
|
+
health_check_port=container_spec.get(
|
193
|
+
ContainerSpec.HEALTH_CHECK_PORT, ""
|
194
|
+
),
|
195
|
+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
|
196
|
+
restricted_params=container_spec.get(
|
197
|
+
ContainerSpec.RESTRICTED_PARAMS, []
|
198
|
+
),
|
199
|
+
evaluation_configuration=AquaContainerEvaluationConfig.from_config(
|
200
|
+
container_spec.get(
|
201
|
+
ContainerSpec.EVALUATION_CONFIGURATION, {}
|
202
|
+
)
|
203
|
+
),
|
204
|
+
)
|
205
|
+
if container_spec
|
206
|
+
else None,
|
85
207
|
)
|
86
208
|
if container.get("type") == "inference":
|
87
|
-
inference_items
|
209
|
+
inference_items[container_type] = container_item
|
88
210
|
elif container_type == "odsc-llm-fine-tuning":
|
89
|
-
finetune_items
|
211
|
+
finetune_items[container_type] = container_item
|
90
212
|
elif container_type == "odsc-llm-evaluate":
|
91
|
-
evaluate_items
|
213
|
+
evaluate_items[container_type] = container_item
|
92
214
|
|
93
215
|
return AquaContainerConfig(
|
94
216
|
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
|
@@ -175,11 +297,11 @@ class AquaUIApp(AquaApp):
|
|
175
297
|
try:
|
176
298
|
if not TENANCY_OCID:
|
177
299
|
raise AquaValueError(
|
178
|
-
|
300
|
+
"TENANCY_OCID must be available in environment"
|
179
301
|
" variables to list the sub compartments."
|
180
302
|
)
|
181
303
|
|
182
|
-
if TENANCY_OCID in self._compartments_cache
|
304
|
+
if TENANCY_OCID in self._compartments_cache:
|
183
305
|
logger.info(
|
184
306
|
f"Returning compartments list in {TENANCY_OCID} from cache."
|
185
307
|
)
|
@@ -196,7 +318,7 @@ class AquaUIApp(AquaApp):
|
|
196
318
|
access_level="ANY",
|
197
319
|
)
|
198
320
|
)
|
199
|
-
except ServiceError
|
321
|
+
except ServiceError:
|
200
322
|
logger.error(
|
201
323
|
f"ERROR: Unable to list all sub compartment in tenancy {TENANCY_OCID}."
|
202
324
|
)
|
@@ -207,7 +329,7 @@ class AquaUIApp(AquaApp):
|
|
207
329
|
compartment_id=TENANCY_OCID,
|
208
330
|
)
|
209
331
|
)
|
210
|
-
except ServiceError
|
332
|
+
except ServiceError:
|
211
333
|
logger.error(
|
212
334
|
f"ERROR: Unable to list all child compartment in tenancy {TENANCY_OCID}."
|
213
335
|
)
|
@@ -216,7 +338,7 @@ class AquaUIApp(AquaApp):
|
|
216
338
|
TENANCY_OCID
|
217
339
|
).data
|
218
340
|
compartments.insert(0, root_compartment)
|
219
|
-
except ServiceError
|
341
|
+
except ServiceError:
|
220
342
|
logger.error(
|
221
343
|
f"ERROR: Unable to get details of the root compartment {TENANCY_OCID}."
|
222
344
|
)
|
@@ -248,7 +370,7 @@ class AquaUIApp(AquaApp):
|
|
248
370
|
"""
|
249
371
|
if not COMPARTMENT_OCID:
|
250
372
|
logger.error("No compartment id found from environment variables.")
|
251
|
-
return
|
373
|
+
return {"compartment_id": COMPARTMENT_OCID}
|
252
374
|
|
253
375
|
def clear_compartments_list_cache(self) -> dict:
|
254
376
|
"""Allows caller to clear compartments list cache
|
@@ -257,9 +379,9 @@ class AquaUIApp(AquaApp):
|
|
257
379
|
dict with the key used, and True if cache has the key that needs to be deleted.
|
258
380
|
"""
|
259
381
|
res = {}
|
260
|
-
logger.info(
|
382
|
+
logger.info("Clearing list_compartments cache")
|
261
383
|
with self._cache_lock:
|
262
|
-
if TENANCY_OCID in self._compartments_cache
|
384
|
+
if TENANCY_OCID in self._compartments_cache:
|
263
385
|
self._compartments_cache.pop(key=TENANCY_OCID)
|
264
386
|
res = {
|
265
387
|
"key": {
|
@@ -332,6 +454,7 @@ class AquaUIApp(AquaApp):
|
|
332
454
|
response = os_client.list_buckets(
|
333
455
|
namespace_name=namespace_name,
|
334
456
|
compartment_id=compartment_id,
|
457
|
+
limit=1000,
|
335
458
|
**kwargs,
|
336
459
|
).data
|
337
460
|
|
@@ -474,16 +597,16 @@ class AquaUIApp(AquaApp):
|
|
474
597
|
raise AquaResourceAccessError(
|
475
598
|
f"Could not check limits availability for the shape {instance_shape}. Make sure you have the necessary policy to check limits availability.",
|
476
599
|
service_payload=se.args[0] if se.args else None,
|
477
|
-
)
|
600
|
+
) from None
|
478
601
|
|
479
602
|
available = res.available
|
480
603
|
|
481
604
|
try:
|
482
605
|
cards = int(instance_shape.split(".")[-1])
|
483
|
-
except:
|
606
|
+
except Exception:
|
484
607
|
cards = 1
|
485
608
|
|
486
|
-
response =
|
609
|
+
response = {"available_count": available}
|
487
610
|
|
488
611
|
if available < cards:
|
489
612
|
raise AquaValueError(
|
@@ -516,7 +639,7 @@ class AquaUIApp(AquaApp):
|
|
516
639
|
is_versioned = False
|
517
640
|
message = f"Model artifact bucket {bucket_uri} is not versioned. Check if the path exists and enable versioning on the bucket to proceed with model creation."
|
518
641
|
|
519
|
-
return
|
642
|
+
return {"is_versioned": is_versioned, "message": message}
|
520
643
|
|
521
644
|
@telemetry(entry_point="plugin=ui&action=list_containers", name="aqua")
|
522
645
|
def list_containers(self) -> AquaContainerConfig:
|
@@ -526,8 +649,9 @@ class AquaUIApp(AquaApp):
|
|
526
649
|
Returns
|
527
650
|
-------
|
528
651
|
AquaContainerConfig
|
529
|
-
The AQUA containers
|
652
|
+
The AQUA containers configurations.
|
530
653
|
"""
|
531
654
|
return AquaContainerConfig.from_container_index_json(
|
532
|
-
config=get_container_config()
|
655
|
+
config=get_container_config(),
|
656
|
+
enable_spec=True,
|
533
657
|
)
|
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8 -*--
|
3
2
|
|
4
3
|
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
@@ -7,16 +6,15 @@
|
|
7
6
|
import json
|
8
7
|
import os
|
9
8
|
import re
|
9
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
10
10
|
from dataclasses import dataclass
|
11
11
|
from typing import Dict, List
|
12
12
|
from urllib.parse import urlparse
|
13
13
|
|
14
|
-
|
15
14
|
import oci
|
16
15
|
from ads.common import auth as authutil
|
17
16
|
from ads.common import oci_client
|
18
17
|
from ads.dataset.progress import TqdmProgressBar
|
19
|
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
20
18
|
|
21
19
|
THREAD_POOL_MAX_WORKERS = 10
|
22
20
|
|
@@ -169,8 +167,7 @@ class ObjectStorageDetails:
|
|
169
167
|
|
170
168
|
def list_objects(self, **kwargs):
|
171
169
|
"""Lists objects in a given oss path
|
172
|
-
|
173
|
-
Parameters
|
170
|
+
Parameters
|
174
171
|
-------
|
175
172
|
**kwargs:
|
176
173
|
namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
|
ads/common/serializer.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
|
-
# -*- coding: utf-8; -*-
|
3
2
|
|
4
3
|
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
4
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
@@ -25,10 +24,8 @@ from ads.common import logger
|
|
25
24
|
from ads.common.auth import default_signer
|
26
25
|
|
27
26
|
try:
|
28
|
-
from yaml import CSafeDumper as dumper
|
29
27
|
from yaml import CSafeLoader as loader
|
30
28
|
except:
|
31
|
-
from yaml import SafeDumper as dumper
|
32
29
|
from yaml import SafeLoader as loader
|
33
30
|
|
34
31
|
|
@@ -99,6 +96,8 @@ class Serializable(ABC):
|
|
99
96
|
"""JSON serializer for objects not serializable by default json code."""
|
100
97
|
if isinstance(obj, datetime):
|
101
98
|
return obj.isoformat()
|
99
|
+
if hasattr(obj, "to_dict"):
|
100
|
+
return obj.to_dict()
|
102
101
|
raise TypeError(f"Type {type(obj)} not serializable.")
|
103
102
|
|
104
103
|
@staticmethod
|
@@ -30,6 +30,7 @@ from ads.common.oci_logging import OCILog
|
|
30
30
|
from ads.common.oci_resource import ResourceNotFoundError
|
31
31
|
from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
|
32
32
|
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
|
33
|
+
ContainerRuntimeHandler,
|
33
34
|
DataScienceJobRuntimeManager,
|
34
35
|
)
|
35
36
|
from ads.jobs.builders.infrastructure.utils import get_value
|
@@ -458,7 +459,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
458
459
|
----------
|
459
460
|
**kwargs :
|
460
461
|
Keyword arguments for initializing a Data Science Job Run.
|
461
|
-
The keys can be any keys in supported by OCI JobConfigurationDetails and JobRun, including:
|
462
|
+
The keys can be any keys in supported by OCI JobConfigurationDetails, OcirContainerJobEnvironmentConfigurationDetails and JobRun, including:
|
462
463
|
* hyperparameter_values: dict(str, str)
|
463
464
|
* environment_variables: dict(str, str)
|
464
465
|
* command_line_arguments: str
|
@@ -466,6 +467,11 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
466
467
|
* display_name: str
|
467
468
|
* freeform_tags: dict(str, str)
|
468
469
|
* defined_tags: dict(str, dict(str, object))
|
470
|
+
* image: str
|
471
|
+
* cmd: list[str]
|
472
|
+
* entrypoint: list[str]
|
473
|
+
* image_digest: str
|
474
|
+
* image_signature_id: str
|
469
475
|
|
470
476
|
If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
|
471
477
|
|
@@ -478,14 +484,28 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
478
484
|
if not self.id:
|
479
485
|
self.create()
|
480
486
|
|
481
|
-
|
487
|
+
config_swagger_types = (
|
482
488
|
oci.data_science.models.DefaultJobConfigurationDetails().swagger_types.keys()
|
483
489
|
)
|
490
|
+
env_config_swagger_types = {}
|
491
|
+
if hasattr(oci.data_science.models, "OcirContainerJobEnvironmentConfigurationDetails"):
|
492
|
+
env_config_swagger_types = (
|
493
|
+
oci.data_science.models.OcirContainerJobEnvironmentConfigurationDetails().swagger_types.keys()
|
494
|
+
)
|
484
495
|
config_kwargs = {}
|
496
|
+
env_config_kwargs = {}
|
485
497
|
keys = list(kwargs.keys())
|
486
498
|
for key in keys:
|
487
|
-
if key in
|
499
|
+
if key in config_swagger_types:
|
488
500
|
config_kwargs[key] = kwargs.pop(key)
|
501
|
+
elif key in env_config_swagger_types:
|
502
|
+
value = kwargs.pop(key)
|
503
|
+
if key in [
|
504
|
+
ContainerRuntime.CONST_CMD,
|
505
|
+
ContainerRuntime.CONST_ENTRYPOINT
|
506
|
+
] and isinstance(value, str):
|
507
|
+
value = ContainerRuntimeHandler.split_args(value)
|
508
|
+
env_config_kwargs[key] = value
|
489
509
|
|
490
510
|
# remove timestamp from the job name (added in default names, when display_name not specified by user)
|
491
511
|
if self.display_name:
|
@@ -514,6 +534,12 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
|
|
514
534
|
config_override.update(config_kwargs)
|
515
535
|
kwargs["job_configuration_override_details"] = config_override
|
516
536
|
|
537
|
+
if env_config_kwargs:
|
538
|
+
env_config_kwargs["jobEnvironmentType"] = "OCIR_CONTAINER"
|
539
|
+
env_config_override = kwargs.get("job_environment_configuration_override_details", {})
|
540
|
+
env_config_override.update(env_config_kwargs)
|
541
|
+
kwargs["job_environment_configuration_override_details"] = env_config_override
|
542
|
+
|
517
543
|
wait = kwargs.pop("wait", False)
|
518
544
|
run = DataScienceJobRun(**kwargs, **self.auth).create()
|
519
545
|
if wait:
|
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/env python
|
2
2
|
# -*- coding: utf-8; -*-
|
3
3
|
|
4
|
-
# Copyright (c) 2021,
|
4
|
+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
|
5
5
|
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6
6
|
"""Contains classes for conversion between ADS runtime and OCI Data Science Job implementation.
|
7
7
|
This module is for ADS developers only.
|
@@ -305,10 +305,29 @@ class RuntimeHandler:
|
|
305
305
|
self._extract_envs,
|
306
306
|
self._extract_artifact,
|
307
307
|
self._extract_runtime_minutes,
|
308
|
+
self._extract_properties,
|
308
309
|
]
|
309
310
|
for extraction in extractions:
|
310
311
|
runtime_spec.update(extraction(dsc_job))
|
311
312
|
return self.RUNTIME_CLASS(self._format_env_var(runtime_spec))
|
313
|
+
|
314
|
+
def _extract_properties(self, dsc_job) -> dict:
|
315
|
+
"""Extract the job runtime properties from data science job.
|
316
|
+
|
317
|
+
This is the base method which does not extract the job runtime properties.
|
318
|
+
Sub-class should implement the extraction if needed.
|
319
|
+
|
320
|
+
Parameters
|
321
|
+
----------
|
322
|
+
dsc_job : DSCJob or oci.datascience.models.Job
|
323
|
+
The data science job containing runtime information.
|
324
|
+
|
325
|
+
Returns
|
326
|
+
-------
|
327
|
+
dict
|
328
|
+
A runtime specification dictionary for initializing a runtime.
|
329
|
+
"""
|
330
|
+
return {}
|
312
331
|
|
313
332
|
def _extract_args(self, dsc_job) -> dict:
|
314
333
|
"""Extracts the command line arguments from data science job.
|
@@ -942,9 +961,12 @@ class GitPythonRuntimeHandler(CondaRuntimeHandler):
|
|
942
961
|
class ContainerRuntimeHandler(RuntimeHandler):
|
943
962
|
RUNTIME_CLASS = ContainerRuntime
|
944
963
|
CMD_DELIMITER = ","
|
945
|
-
|
946
|
-
|
947
|
-
|
964
|
+
|
965
|
+
def translate(self, runtime: Runtime) -> dict:
|
966
|
+
payload = super().translate(runtime)
|
967
|
+
job_env_config = self._translate_env_config(runtime)
|
968
|
+
payload["job_environment_configuration_details"] = job_env_config
|
969
|
+
return payload
|
948
970
|
|
949
971
|
def _translate_artifact(self, runtime: Runtime):
|
950
972
|
"""Specifies a dummy script as the job artifact.
|
@@ -964,29 +986,34 @@ class ContainerRuntimeHandler(RuntimeHandler):
|
|
964
986
|
os.path.dirname(__file__), "../../templates", "container.py"
|
965
987
|
)
|
966
988
|
|
967
|
-
def
|
968
|
-
"""
|
989
|
+
def _translate_env_config(self, runtime: Runtime) -> dict:
|
990
|
+
"""Converts runtime properties to ``OcirContainerJobEnvironmentConfigurationDetails`` payload required by OCI Data Science job.
|
969
991
|
|
970
992
|
Parameters
|
971
993
|
----------
|
972
|
-
runtime :
|
973
|
-
|
994
|
+
runtime : Runtime
|
995
|
+
The runtime containing the properties to be converted.
|
974
996
|
|
975
997
|
Returns
|
976
998
|
-------
|
977
999
|
dict
|
978
|
-
A dictionary
|
1000
|
+
A dictionary storing the ``OcirContainerJobEnvironmentConfigurationDetails`` payload for OCI data science job.
|
979
1001
|
"""
|
980
|
-
|
981
|
-
|
982
|
-
envs = super()._translate_env(runtime)
|
983
|
-
spec_mappings = {
|
984
|
-
ContainerRuntime.CONST_IMAGE: self.CONST_CONTAINER_IMAGE,
|
985
|
-
ContainerRuntime.CONST_ENTRYPOINT: self.CONST_CONTAINER_ENTRYPOINT,
|
986
|
-
ContainerRuntime.CONST_CMD: self.CONST_CONTAINER_CMD,
|
1002
|
+
job_environment_configuration_details = {
|
1003
|
+
"job_environment_type": runtime.job_env_type
|
987
1004
|
}
|
988
|
-
|
989
|
-
|
1005
|
+
|
1006
|
+
for key, value in ContainerRuntime.attribute_map.items():
|
1007
|
+
property = runtime.get_spec(key, None)
|
1008
|
+
if key in [
|
1009
|
+
ContainerRuntime.CONST_CMD,
|
1010
|
+
ContainerRuntime.CONST_ENTRYPOINT
|
1011
|
+
] and isinstance(property, str):
|
1012
|
+
property = self.split_args(property)
|
1013
|
+
if property is not None:
|
1014
|
+
job_environment_configuration_details[value] = property
|
1015
|
+
|
1016
|
+
return job_environment_configuration_details
|
990
1017
|
|
991
1018
|
@staticmethod
|
992
1019
|
def split_args(args: str) -> list:
|
@@ -1031,17 +1058,37 @@ class ContainerRuntimeHandler(RuntimeHandler):
|
|
1031
1058
|
"""
|
1032
1059
|
spec = super()._extract_envs(dsc_job)
|
1033
1060
|
envs = spec.pop(ContainerRuntime.CONST_ENV_VAR, {})
|
1034
|
-
|
1035
|
-
raise IncompatibleRuntime()
|
1036
|
-
spec[ContainerRuntime.CONST_IMAGE] = envs.pop(self.CONST_CONTAINER_IMAGE)
|
1037
|
-
cmd = self.split_args(envs.pop(self.CONST_CONTAINER_CMD, ""))
|
1038
|
-
if cmd:
|
1039
|
-
spec[ContainerRuntime.CONST_CMD] = cmd
|
1040
|
-
entrypoint = self.split_args(envs.pop(self.CONST_CONTAINER_ENTRYPOINT, ""))
|
1041
|
-
if entrypoint:
|
1042
|
-
spec[ContainerRuntime.CONST_ENTRYPOINT] = entrypoint
|
1061
|
+
|
1043
1062
|
if envs:
|
1044
1063
|
spec[ContainerRuntime.CONST_ENV_VAR] = envs
|
1064
|
+
|
1065
|
+
return spec
|
1066
|
+
|
1067
|
+
def _extract_properties(self, dsc_job) -> dict:
|
1068
|
+
"""Extract the runtime properties from data science job.
|
1069
|
+
|
1070
|
+
Parameters
|
1071
|
+
----------
|
1072
|
+
dsc_job : DSCJob or oci.datascience.models.Job
|
1073
|
+
The data science job containing runtime information.
|
1074
|
+
|
1075
|
+
Returns
|
1076
|
+
-------
|
1077
|
+
dict
|
1078
|
+
A runtime specification dictionary for initializing a runtime.
|
1079
|
+
"""
|
1080
|
+
spec = super()._extract_envs(dsc_job)
|
1081
|
+
|
1082
|
+
job_env_config = getattr(dsc_job, "job_environment_configuration_details", None)
|
1083
|
+
job_env_type = getattr(job_env_config, "job_environment_type", None)
|
1084
|
+
|
1085
|
+
if not (job_env_config and job_env_type == "OCIR_CONTAINER"):
|
1086
|
+
raise IncompatibleRuntime()
|
1087
|
+
|
1088
|
+
for key, value in ContainerRuntime.attribute_map.items():
|
1089
|
+
property = getattr(job_env_config, value, None)
|
1090
|
+
if property is not None:
|
1091
|
+
spec[key] = property
|
1045
1092
|
return spec
|
1046
1093
|
|
1047
1094
|
|