oracle-ads 2.11.14__py3-none-any.whl → 2.11.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. ads/aqua/common/entities.py +17 -0
  2. ads/aqua/common/enums.py +5 -1
  3. ads/aqua/common/utils.py +109 -22
  4. ads/aqua/config/config.py +1 -1
  5. ads/aqua/config/deployment_config_defaults.json +29 -1
  6. ads/aqua/config/resource_limit_names.json +1 -0
  7. ads/aqua/constants.py +35 -18
  8. ads/aqua/evaluation/entities.py +0 -1
  9. ads/aqua/evaluation/evaluation.py +165 -121
  10. ads/aqua/extension/common_ws_msg_handler.py +57 -0
  11. ads/aqua/extension/deployment_handler.py +14 -13
  12. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  13. ads/aqua/extension/errors.py +1 -1
  14. ads/aqua/extension/evaluation_handler.py +4 -7
  15. ads/aqua/extension/evaluation_ws_msg_handler.py +28 -10
  16. ads/aqua/extension/model_handler.py +31 -6
  17. ads/aqua/extension/models/ws_models.py +78 -3
  18. ads/aqua/extension/models_ws_msg_handler.py +49 -0
  19. ads/aqua/extension/ui_websocket_handler.py +7 -1
  20. ads/aqua/model/entities.py +17 -9
  21. ads/aqua/model/model.py +260 -90
  22. ads/aqua/modeldeployment/constants.py +0 -16
  23. ads/aqua/modeldeployment/deployment.py +97 -74
  24. ads/aqua/modeldeployment/entities.py +9 -20
  25. ads/aqua/ui.py +152 -28
  26. ads/common/object_storage_details.py +2 -5
  27. ads/common/serializer.py +2 -3
  28. ads/jobs/builders/infrastructure/dsc_job.py +29 -3
  29. ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
  30. ads/jobs/builders/runtimes/container_runtime.py +83 -4
  31. ads/opctl/operator/common/operator_config.py +1 -0
  32. ads/opctl/operator/lowcode/anomaly/README.md +3 -3
  33. ads/opctl/operator/lowcode/anomaly/__main__.py +5 -6
  34. ads/opctl/operator/lowcode/anomaly/const.py +9 -0
  35. ads/opctl/operator/lowcode/anomaly/model/anomaly_dataset.py +6 -2
  36. ads/opctl/operator/lowcode/anomaly/model/base_model.py +51 -26
  37. ads/opctl/operator/lowcode/anomaly/model/factory.py +41 -13
  38. ads/opctl/operator/lowcode/anomaly/model/isolationforest.py +79 -0
  39. ads/opctl/operator/lowcode/anomaly/model/oneclasssvm.py +79 -0
  40. ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
  41. ads/opctl/operator/lowcode/anomaly/schema.yaml +16 -2
  42. ads/opctl/operator/lowcode/anomaly/utils.py +16 -13
  43. ads/opctl/operator/lowcode/common/data.py +2 -1
  44. ads/opctl/operator/lowcode/common/errors.py +6 -0
  45. ads/opctl/operator/lowcode/common/transformations.py +37 -9
  46. ads/opctl/operator/lowcode/common/utils.py +32 -10
  47. ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
  48. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +14 -18
  49. ads/opctl/operator/lowcode/forecast/model_evaluator.py +15 -4
  50. ads/opctl/operator/lowcode/forecast/schema.yaml +9 -0
  51. ads/opctl/operator/lowcode/recommender/MLoperator +16 -0
  52. ads/opctl/operator/lowcode/recommender/README.md +206 -0
  53. ads/opctl/operator/lowcode/recommender/__init__.py +5 -0
  54. ads/opctl/operator/lowcode/recommender/__main__.py +82 -0
  55. ads/opctl/operator/lowcode/recommender/cmd.py +33 -0
  56. ads/opctl/operator/lowcode/recommender/constant.py +25 -0
  57. ads/opctl/operator/lowcode/recommender/environment.yaml +11 -0
  58. ads/opctl/operator/lowcode/recommender/model/base_model.py +198 -0
  59. ads/opctl/operator/lowcode/recommender/model/factory.py +58 -0
  60. ads/opctl/operator/lowcode/recommender/model/recommender_dataset.py +25 -0
  61. ads/opctl/operator/lowcode/recommender/model/svd.py +88 -0
  62. ads/opctl/operator/lowcode/recommender/operator_config.py +81 -0
  63. ads/opctl/operator/lowcode/recommender/schema.yaml +265 -0
  64. ads/opctl/operator/lowcode/recommender/utils.py +13 -0
  65. ads/pipeline/ads_pipeline_run.py +13 -2
  66. {oracle_ads-2.11.14.dist-info → oracle_ads-2.11.16.dist-info}/METADATA +6 -1
  67. {oracle_ads-2.11.14.dist-info → oracle_ads-2.11.16.dist-info}/RECORD +70 -50
  68. {oracle_ads-2.11.14.dist-info → oracle_ads-2.11.16.dist-info}/LICENSE.txt +0 -0
  69. {oracle_ads-2.11.14.dist-info → oracle_ads-2.11.16.dist-info}/WHEEL +0 -0
  70. {oracle_ads-2.11.14.dist-info → oracle_ads-2.11.16.dist-info}/entry_points.txt +0 -0
ads/aqua/ui.py CHANGED
@@ -1,12 +1,12 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
  import concurrent.futures
6
- from dataclasses import dataclass, field
5
+ from dataclasses import dataclass, field, fields
7
6
  from datetime import datetime, timedelta
7
+ from enum import Enum
8
8
  from threading import Lock
9
- from typing import Dict, List
9
+ from typing import Dict, List, Optional
10
10
 
11
11
  from cachetools import TTLCache
12
12
  from oci.exceptions import ServiceError
@@ -14,6 +14,7 @@ from oci.identity.models import Compartment
14
14
 
15
15
  from ads.aqua import logger
16
16
  from ads.aqua.app import AquaApp
17
+ from ads.aqua.common.entities import ContainerSpec
17
18
  from ads.aqua.common.enums import Tags
18
19
  from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
19
20
  from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
@@ -31,14 +32,84 @@ from ads.config import (
31
32
  from ads.telemetry import telemetry
32
33
 
33
34
 
35
+ class ModelFormat(Enum):
36
+ GGUF = "GGUF"
37
+ SAFETENSORS = "SAFETENSORS"
38
+ UNKNOWN = "UNKNOWN"
39
+
40
+ def to_dict(self):
41
+ return self.value
42
+
43
+
44
+ # todo: the container config spec information is shared across ui and deployment modules, move them
45
+ # within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
46
+
47
+
48
+ @dataclass(repr=False)
49
+ class AquaContainerEvaluationConfig(DataClassSerializable):
50
+ """
51
+ Represents the evaluation configuration for the container.
52
+ """
53
+
54
+ inference_max_threads: Optional[int] = None
55
+ inference_rps: Optional[int] = None
56
+ inference_timeout: Optional[int] = None
57
+ inference_retries: Optional[int] = None
58
+ inference_backoff_factor: Optional[int] = None
59
+ inference_delay: Optional[int] = None
60
+
61
+ @classmethod
62
+ def from_config(cls, config: dict) -> "AquaContainerEvaluationConfig":
63
+ return cls(
64
+ inference_max_threads=config.get("inference_max_threads"),
65
+ inference_rps=config.get("inference_rps"),
66
+ inference_timeout=config.get("inference_timeout"),
67
+ inference_retries=config.get("inference_retries"),
68
+ inference_backoff_factor=config.get("inference_backoff_factor"),
69
+ inference_delay=config.get("inference_delay"),
70
+ )
71
+
72
+ def to_filtered_dict(self):
73
+ return {
74
+ field.name: getattr(self, field.name)
75
+ for field in fields(self)
76
+ if getattr(self, field.name) is not None
77
+ }
78
+
79
+
80
+ @dataclass(repr=False)
81
+ class AquaContainerConfigSpec(DataClassSerializable):
82
+ cli_param: str = None
83
+ server_port: str = None
84
+ health_check_port: str = None
85
+ env_vars: List[dict] = None
86
+ restricted_params: List[str] = None
87
+ evaluation_configuration: AquaContainerEvaluationConfig = field(
88
+ default_factory=AquaContainerEvaluationConfig
89
+ )
90
+
91
+
34
92
  @dataclass(repr=False)
35
93
  class AquaContainerConfigItem(DataClassSerializable):
36
94
  """Represents an item of the AQUA container configuration."""
37
95
 
96
+ class Platform(Enum):
97
+ ARM_CPU = "ARM_CPU"
98
+ NVIDIA_GPU = "NVIDIA_GPU"
99
+
100
+ def to_dict(self):
101
+ return self.value
102
+
103
+ def __repr__(self):
104
+ return repr(self.value)
105
+
38
106
  name: str = None
39
107
  version: str = None
40
108
  display_name: str = None
41
109
  family: str = None
110
+ platforms: List[Platform] = None
111
+ model_formats: List[ModelFormat] = None
112
+ spec: AquaContainerConfigSpec = field(default_factory=AquaContainerConfigSpec)
42
113
 
43
114
 
44
115
  @dataclass(repr=False)
@@ -47,12 +118,23 @@ class AquaContainerConfig(DataClassSerializable):
47
118
  Represents a configuration with AQUA containers to be returned to the client.
48
119
  """
49
120
 
50
- inference: List[AquaContainerConfigItem] = field(default_factory=list)
51
- finetune: List[AquaContainerConfigItem] = field(default_factory=list)
52
- evaluate: List[AquaContainerConfigItem] = field(default_factory=list)
121
+ inference: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
122
+ finetune: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
123
+ evaluate: Dict[str, AquaContainerConfigItem] = field(default_factory=dict)
124
+
125
+ def to_dict(self):
126
+ return {
127
+ "inference": list(self.inference.values()),
128
+ "finetune": list(self.finetune.values()),
129
+ "evaluate": list(self.evaluate.values()),
130
+ }
53
131
 
54
132
  @classmethod
55
- def from_container_index_json(cls, config: Dict) -> "AquaContainerConfig":
133
+ def from_container_index_json(
134
+ cls,
135
+ config: Optional[Dict] = None,
136
+ enable_spec: Optional[bool] = False,
137
+ ) -> "AquaContainerConfig":
56
138
  """
57
139
  Create an AquaContainerConfig instance from a container index JSON.
58
140
 
@@ -60,21 +142,39 @@ class AquaContainerConfig(DataClassSerializable):
60
142
  ----------
61
143
  config : Dict
62
144
  The container index JSON.
145
+ enable_spec: bool
146
+ flag to check if container specification details should be fetched.
63
147
 
64
148
  Returns
65
149
  -------
66
150
  AquaContainerConfig
67
151
  The container configuration instance.
68
152
  """
69
- config = config or {}
70
- inference_items = []
71
- finetune_items = []
72
- evaluate_items = []
153
+ if not config:
154
+ config = get_container_config()
155
+ inference_items = {}
156
+ finetune_items = {}
157
+ evaluate_items = {}
73
158
 
74
159
  # extract inference containers
75
160
  for container_type, containers in config.items():
76
161
  if isinstance(containers, list):
77
162
  for container in containers:
163
+ platforms = [
164
+ AquaContainerConfigItem.Platform[platform]
165
+ for platform in container.get("platforms", [])
166
+ ]
167
+ model_formats = [
168
+ ModelFormat[model_format]
169
+ for model_format in container.get("modelFormats", [])
170
+ ]
171
+ container_spec = (
172
+ config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
173
+ container_type, {}
174
+ )
175
+ if enable_spec
176
+ else None
177
+ )
78
178
  container_item = AquaContainerConfigItem(
79
179
  name=container.get("name", ""),
80
180
  version=container.get("version", ""),
@@ -82,13 +182,35 @@ class AquaContainerConfig(DataClassSerializable):
82
182
  "displayName", container.get("version", "")
83
183
  ),
84
184
  family=container_type,
185
+ platforms=platforms,
186
+ model_formats=model_formats,
187
+ spec=AquaContainerConfigSpec(
188
+ cli_param=container_spec.get(ContainerSpec.CLI_PARM, ""),
189
+ server_port=container_spec.get(
190
+ ContainerSpec.SERVER_PORT, ""
191
+ ),
192
+ health_check_port=container_spec.get(
193
+ ContainerSpec.HEALTH_CHECK_PORT, ""
194
+ ),
195
+ env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
196
+ restricted_params=container_spec.get(
197
+ ContainerSpec.RESTRICTED_PARAMS, []
198
+ ),
199
+ evaluation_configuration=AquaContainerEvaluationConfig.from_config(
200
+ container_spec.get(
201
+ ContainerSpec.EVALUATION_CONFIGURATION, {}
202
+ )
203
+ ),
204
+ )
205
+ if container_spec
206
+ else None,
85
207
  )
86
208
  if container.get("type") == "inference":
87
- inference_items.append(container_item)
209
+ inference_items[container_type] = container_item
88
210
  elif container_type == "odsc-llm-fine-tuning":
89
- finetune_items.append(container_item)
211
+ finetune_items[container_type] = container_item
90
212
  elif container_type == "odsc-llm-evaluate":
91
- evaluate_items.append(container_item)
213
+ evaluate_items[container_type] = container_item
92
214
 
93
215
  return AquaContainerConfig(
94
216
  inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
@@ -175,11 +297,11 @@ class AquaUIApp(AquaApp):
175
297
  try:
176
298
  if not TENANCY_OCID:
177
299
  raise AquaValueError(
178
- f"TENANCY_OCID must be available in environment"
300
+ "TENANCY_OCID must be available in environment"
179
301
  " variables to list the sub compartments."
180
302
  )
181
303
 
182
- if TENANCY_OCID in self._compartments_cache.keys():
304
+ if TENANCY_OCID in self._compartments_cache:
183
305
  logger.info(
184
306
  f"Returning compartments list in {TENANCY_OCID} from cache."
185
307
  )
@@ -196,7 +318,7 @@ class AquaUIApp(AquaApp):
196
318
  access_level="ANY",
197
319
  )
198
320
  )
199
- except ServiceError as se:
321
+ except ServiceError:
200
322
  logger.error(
201
323
  f"ERROR: Unable to list all sub compartment in tenancy {TENANCY_OCID}."
202
324
  )
@@ -207,7 +329,7 @@ class AquaUIApp(AquaApp):
207
329
  compartment_id=TENANCY_OCID,
208
330
  )
209
331
  )
210
- except ServiceError as se:
332
+ except ServiceError:
211
333
  logger.error(
212
334
  f"ERROR: Unable to list all child compartment in tenancy {TENANCY_OCID}."
213
335
  )
@@ -216,7 +338,7 @@ class AquaUIApp(AquaApp):
216
338
  TENANCY_OCID
217
339
  ).data
218
340
  compartments.insert(0, root_compartment)
219
- except ServiceError as se:
341
+ except ServiceError:
220
342
  logger.error(
221
343
  f"ERROR: Unable to get details of the root compartment {TENANCY_OCID}."
222
344
  )
@@ -248,7 +370,7 @@ class AquaUIApp(AquaApp):
248
370
  """
249
371
  if not COMPARTMENT_OCID:
250
372
  logger.error("No compartment id found from environment variables.")
251
- return dict(compartment_id=COMPARTMENT_OCID)
373
+ return {"compartment_id": COMPARTMENT_OCID}
252
374
 
253
375
  def clear_compartments_list_cache(self) -> dict:
254
376
  """Allows caller to clear compartments list cache
@@ -257,9 +379,9 @@ class AquaUIApp(AquaApp):
257
379
  dict with the key used, and True if cache has the key that needs to be deleted.
258
380
  """
259
381
  res = {}
260
- logger.info(f"Clearing list_compartments cache")
382
+ logger.info("Clearing list_compartments cache")
261
383
  with self._cache_lock:
262
- if TENANCY_OCID in self._compartments_cache.keys():
384
+ if TENANCY_OCID in self._compartments_cache:
263
385
  self._compartments_cache.pop(key=TENANCY_OCID)
264
386
  res = {
265
387
  "key": {
@@ -332,6 +454,7 @@ class AquaUIApp(AquaApp):
332
454
  response = os_client.list_buckets(
333
455
  namespace_name=namespace_name,
334
456
  compartment_id=compartment_id,
457
+ limit=1000,
335
458
  **kwargs,
336
459
  ).data
337
460
 
@@ -474,16 +597,16 @@ class AquaUIApp(AquaApp):
474
597
  raise AquaResourceAccessError(
475
598
  f"Could not check limits availability for the shape {instance_shape}. Make sure you have the necessary policy to check limits availability.",
476
599
  service_payload=se.args[0] if se.args else None,
477
- )
600
+ ) from None
478
601
 
479
602
  available = res.available
480
603
 
481
604
  try:
482
605
  cards = int(instance_shape.split(".")[-1])
483
- except:
606
+ except Exception:
484
607
  cards = 1
485
608
 
486
- response = dict(available_count=available)
609
+ response = {"available_count": available}
487
610
 
488
611
  if available < cards:
489
612
  raise AquaValueError(
@@ -516,7 +639,7 @@ class AquaUIApp(AquaApp):
516
639
  is_versioned = False
517
640
  message = f"Model artifact bucket {bucket_uri} is not versioned. Check if the path exists and enable versioning on the bucket to proceed with model creation."
518
641
 
519
- return dict(is_versioned=is_versioned, message=message)
642
+ return {"is_versioned": is_versioned, "message": message}
520
643
 
521
644
  @telemetry(entry_point="plugin=ui&action=list_containers", name="aqua")
522
645
  def list_containers(self) -> AquaContainerConfig:
@@ -526,8 +649,9 @@ class AquaUIApp(AquaApp):
526
649
  Returns
527
650
  -------
528
651
  AquaContainerConfig
529
- The AQUA containers configuration.
652
+ The AQUA containers configurations.
530
653
  """
531
654
  return AquaContainerConfig.from_container_index_json(
532
- config=get_container_config()
655
+ config=get_container_config(),
656
+ enable_spec=True,
533
657
  )
ads/common/object_storage_details.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
3
  # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -7,16 +6,15 @@
7
6
  import json
8
7
  import os
9
8
  import re
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
10
  from dataclasses import dataclass
11
11
  from typing import Dict, List
12
12
  from urllib.parse import urlparse
13
13
 
14
-
15
14
  import oci
16
15
  from ads.common import auth as authutil
17
16
  from ads.common import oci_client
18
17
  from ads.dataset.progress import TqdmProgressBar
19
- from concurrent.futures import ThreadPoolExecutor, as_completed
20
18
 
21
19
  THREAD_POOL_MAX_WORKERS = 10
22
20
 
@@ -169,8 +167,7 @@ class ObjectStorageDetails:
169
167
 
170
168
  def list_objects(self, **kwargs):
171
169
  """Lists objects in a given oss path
172
-
173
- Parameters
170
+ Parameters
174
171
  -------
175
172
  **kwargs:
176
173
  namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported
ads/common/serializer.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
3
  # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -25,10 +24,8 @@ from ads.common import logger
25
24
  from ads.common.auth import default_signer
26
25
 
27
26
  try:
28
- from yaml import CSafeDumper as dumper
29
27
  from yaml import CSafeLoader as loader
30
28
  except:
31
- from yaml import SafeDumper as dumper
32
29
  from yaml import SafeLoader as loader
33
30
 
34
31
 
@@ -99,6 +96,8 @@ class Serializable(ABC):
99
96
  """JSON serializer for objects not serializable by default json code."""
100
97
  if isinstance(obj, datetime):
101
98
  return obj.isoformat()
99
+ if hasattr(obj, "to_dict"):
100
+ return obj.to_dict()
102
101
  raise TypeError(f"Type {type(obj)} not serializable.")
103
102
 
104
103
  @staticmethod
ads/jobs/builders/infrastructure/dsc_job.py CHANGED
@@ -30,6 +30,7 @@ from ads.common.oci_logging import OCILog
30
30
  from ads.common.oci_resource import ResourceNotFoundError
31
31
  from ads.jobs.builders.infrastructure.base import Infrastructure, RunInstance
32
32
  from ads.jobs.builders.infrastructure.dsc_job_runtime import (
33
+ ContainerRuntimeHandler,
33
34
  DataScienceJobRuntimeManager,
34
35
  )
35
36
  from ads.jobs.builders.infrastructure.utils import get_value
@@ -458,7 +459,7 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
458
459
  ----------
459
460
  **kwargs :
460
461
  Keyword arguments for initializing a Data Science Job Run.
461
- The keys can be any keys in supported by OCI JobConfigurationDetails and JobRun, including:
462
+ The keys can be any keys in supported by OCI JobConfigurationDetails, OcirContainerJobEnvironmentConfigurationDetails and JobRun, including:
462
463
  * hyperparameter_values: dict(str, str)
463
464
  * environment_variables: dict(str, str)
464
465
  * command_line_arguments: str
@@ -466,6 +467,11 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
466
467
  * display_name: str
467
468
  * freeform_tags: dict(str, str)
468
469
  * defined_tags: dict(str, dict(str, object))
470
+ * image: str
471
+ * cmd: list[str]
472
+ * entrypoint: list[str]
473
+ * image_digest: str
474
+ * image_signature_id: str
469
475
 
470
476
  If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
471
477
 
@@ -478,14 +484,28 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
478
484
  if not self.id:
479
485
  self.create()
480
486
 
481
- swagger_types = (
487
+ config_swagger_types = (
482
488
  oci.data_science.models.DefaultJobConfigurationDetails().swagger_types.keys()
483
489
  )
490
+ env_config_swagger_types = {}
491
+ if hasattr(oci.data_science.models, "OcirContainerJobEnvironmentConfigurationDetails"):
492
+ env_config_swagger_types = (
493
+ oci.data_science.models.OcirContainerJobEnvironmentConfigurationDetails().swagger_types.keys()
494
+ )
484
495
  config_kwargs = {}
496
+ env_config_kwargs = {}
485
497
  keys = list(kwargs.keys())
486
498
  for key in keys:
487
- if key in swagger_types:
499
+ if key in config_swagger_types:
488
500
  config_kwargs[key] = kwargs.pop(key)
501
+ elif key in env_config_swagger_types:
502
+ value = kwargs.pop(key)
503
+ if key in [
504
+ ContainerRuntime.CONST_CMD,
505
+ ContainerRuntime.CONST_ENTRYPOINT
506
+ ] and isinstance(value, str):
507
+ value = ContainerRuntimeHandler.split_args(value)
508
+ env_config_kwargs[key] = value
489
509
 
490
510
  # remove timestamp from the job name (added in default names, when display_name not specified by user)
491
511
  if self.display_name:
@@ -514,6 +534,12 @@ class DSCJob(OCIDataScienceMixin, oci.data_science.models.Job):
514
534
  config_override.update(config_kwargs)
515
535
  kwargs["job_configuration_override_details"] = config_override
516
536
 
537
+ if env_config_kwargs:
538
+ env_config_kwargs["jobEnvironmentType"] = "OCIR_CONTAINER"
539
+ env_config_override = kwargs.get("job_environment_configuration_override_details", {})
540
+ env_config_override.update(env_config_kwargs)
541
+ kwargs["job_environment_configuration_override_details"] = env_config_override
542
+
517
543
  wait = kwargs.pop("wait", False)
518
544
  run = DataScienceJobRun(**kwargs, **self.auth).create()
519
545
  if wait:
ads/jobs/builders/infrastructure/dsc_job_runtime.py CHANGED
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
  """Contains classes for conversion between ADS runtime and OCI Data Science Job implementation.
7
7
  This module is for ADS developers only.
@@ -305,10 +305,29 @@ class RuntimeHandler:
305
305
  self._extract_envs,
306
306
  self._extract_artifact,
307
307
  self._extract_runtime_minutes,
308
+ self._extract_properties,
308
309
  ]
309
310
  for extraction in extractions:
310
311
  runtime_spec.update(extraction(dsc_job))
311
312
  return self.RUNTIME_CLASS(self._format_env_var(runtime_spec))
313
+
314
+ def _extract_properties(self, dsc_job) -> dict:
315
+ """Extract the job runtime properties from data science job.
316
+
317
+ This is the base method which does not extract the job runtime properties.
318
+ Sub-class should implement the extraction if needed.
319
+
320
+ Parameters
321
+ ----------
322
+ dsc_job : DSCJob or oci.datascience.models.Job
323
+ The data science job containing runtime information.
324
+
325
+ Returns
326
+ -------
327
+ dict
328
+ A runtime specification dictionary for initializing a runtime.
329
+ """
330
+ return {}
312
331
 
313
332
  def _extract_args(self, dsc_job) -> dict:
314
333
  """Extracts the command line arguments from data science job.
@@ -942,9 +961,12 @@ class GitPythonRuntimeHandler(CondaRuntimeHandler):
942
961
  class ContainerRuntimeHandler(RuntimeHandler):
943
962
  RUNTIME_CLASS = ContainerRuntime
944
963
  CMD_DELIMITER = ","
945
- CONST_CONTAINER_IMAGE = "CONTAINER_CUSTOM_IMAGE"
946
- CONST_CONTAINER_ENTRYPOINT = "CONTAINER_ENTRYPOINT"
947
- CONST_CONTAINER_CMD = "CONTAINER_CMD"
964
+
965
+ def translate(self, runtime: Runtime) -> dict:
966
+ payload = super().translate(runtime)
967
+ job_env_config = self._translate_env_config(runtime)
968
+ payload["job_environment_configuration_details"] = job_env_config
969
+ return payload
948
970
 
949
971
  def _translate_artifact(self, runtime: Runtime):
950
972
  """Specifies a dummy script as the job artifact.
@@ -964,29 +986,34 @@ class ContainerRuntimeHandler(RuntimeHandler):
964
986
  os.path.dirname(__file__), "../../templates", "container.py"
965
987
  )
966
988
 
967
- def _translate_env(self, runtime: ContainerRuntime) -> dict:
968
- """Translate the environment variable.
989
+ def _translate_env_config(self, runtime: Runtime) -> dict:
990
+ """Converts runtime properties to ``OcirContainerJobEnvironmentConfigurationDetails`` payload required by OCI Data Science job.
969
991
 
970
992
  Parameters
971
993
  ----------
972
- runtime : GitPythonRuntime
973
- An instance of GitPythonRuntime
994
+ runtime : Runtime
995
+ The runtime containing the properties to be converted.
974
996
 
975
997
  Returns
976
998
  -------
977
999
  dict
978
- A dictionary containing environment variables for OCI data science job.
1000
+ A dictionary storing the ``OcirContainerJobEnvironmentConfigurationDetails`` payload for OCI data science job.
979
1001
  """
980
- if not runtime.image:
981
- raise ValueError("Specify container image for ContainerRuntime.")
982
- envs = super()._translate_env(runtime)
983
- spec_mappings = {
984
- ContainerRuntime.CONST_IMAGE: self.CONST_CONTAINER_IMAGE,
985
- ContainerRuntime.CONST_ENTRYPOINT: self.CONST_CONTAINER_ENTRYPOINT,
986
- ContainerRuntime.CONST_CMD: self.CONST_CONTAINER_CMD,
1002
+ job_environment_configuration_details = {
1003
+ "job_environment_type": runtime.job_env_type
987
1004
  }
988
- envs.update(self._translate_specs(runtime, spec_mappings, self.CMD_DELIMITER))
989
- return envs
1005
+
1006
+ for key, value in ContainerRuntime.attribute_map.items():
1007
+ property = runtime.get_spec(key, None)
1008
+ if key in [
1009
+ ContainerRuntime.CONST_CMD,
1010
+ ContainerRuntime.CONST_ENTRYPOINT
1011
+ ] and isinstance(property, str):
1012
+ property = self.split_args(property)
1013
+ if property is not None:
1014
+ job_environment_configuration_details[value] = property
1015
+
1016
+ return job_environment_configuration_details
990
1017
 
991
1018
  @staticmethod
992
1019
  def split_args(args: str) -> list:
@@ -1031,17 +1058,37 @@ class ContainerRuntimeHandler(RuntimeHandler):
1031
1058
  """
1032
1059
  spec = super()._extract_envs(dsc_job)
1033
1060
  envs = spec.pop(ContainerRuntime.CONST_ENV_VAR, {})
1034
- if self.CONST_CONTAINER_IMAGE not in envs:
1035
- raise IncompatibleRuntime()
1036
- spec[ContainerRuntime.CONST_IMAGE] = envs.pop(self.CONST_CONTAINER_IMAGE)
1037
- cmd = self.split_args(envs.pop(self.CONST_CONTAINER_CMD, ""))
1038
- if cmd:
1039
- spec[ContainerRuntime.CONST_CMD] = cmd
1040
- entrypoint = self.split_args(envs.pop(self.CONST_CONTAINER_ENTRYPOINT, ""))
1041
- if entrypoint:
1042
- spec[ContainerRuntime.CONST_ENTRYPOINT] = entrypoint
1061
+
1043
1062
  if envs:
1044
1063
  spec[ContainerRuntime.CONST_ENV_VAR] = envs
1064
+
1065
+ return spec
1066
+
1067
+ def _extract_properties(self, dsc_job) -> dict:
1068
+ """Extract the runtime properties from data science job.
1069
+
1070
+ Parameters
1071
+ ----------
1072
+ dsc_job : DSCJob or oci.datascience.models.Job
1073
+ The data science job containing runtime information.
1074
+
1075
+ Returns
1076
+ -------
1077
+ dict
1078
+ A runtime specification dictionary for initializing a runtime.
1079
+ """
1080
+ spec = super()._extract_envs(dsc_job)
1081
+
1082
+ job_env_config = getattr(dsc_job, "job_environment_configuration_details", None)
1083
+ job_env_type = getattr(job_env_config, "job_environment_type", None)
1084
+
1085
+ if not (job_env_config and job_env_type == "OCIR_CONTAINER"):
1086
+ raise IncompatibleRuntime()
1087
+
1088
+ for key, value in ContainerRuntime.attribute_map.items():
1089
+ property = getattr(job_env_config, value, None)
1090
+ if property is not None:
1091
+ spec[key] = property
1045
1092
  return spec
1046
1093
 
1047
1094