oracle-ads 2.12.3__py3-none-any.whl → 2.12.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ads/aqua/evaluation/entities.py CHANGED
@@ -9,19 +9,18 @@ aqua.evaluation.entities
  This module contains dataclasses for aqua evaluation.
  """
 
- from dataclasses import dataclass, field
- from typing import List, Optional, Union
+ from pydantic import Field
+ from typing import Any, Dict, List, Optional, Union
 
  from ads.aqua.data import AquaResourceIdentifier
- from ads.common.serializer import DataClassSerializable
+ from ads.aqua.config.utils.serializer import Serializable
 
 
- @dataclass(repr=False)
- class CreateAquaEvaluationDetails(DataClassSerializable):
- """Dataclass to create aqua model evaluation.
+ class CreateAquaEvaluationDetails(Serializable):
+ """Class for creating aqua model evaluation.
 
- Fields
- ------
+ Properties
+ ----------
  evaluation_source_id: str
  The evaluation source id. Must be either model or model deployment ocid.
  evaluation_name: str
@@ -83,69 +82,64 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
  ocpus: Optional[float] = None
  log_group_id: Optional[str] = None
  log_id: Optional[str] = None
- metrics: Optional[List] = None
+ metrics: Optional[List[str]] = None
  force_overwrite: Optional[bool] = False
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class AquaEvalReport(DataClassSerializable):
+ class AquaEvalReport(Serializable):
  evaluation_id: str = ""
  content: str = ""
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class ModelParams(DataClassSerializable):
- max_tokens: str = ""
- top_p: str = ""
- top_k: str = ""
- temperature: str = ""
- presence_penalty: Optional[float] = 0.0
- frequency_penalty: Optional[float] = 0.0
- stop: Optional[Union[str, List[str]]] = field(default_factory=list)
- model: Optional[str] = "odsc-llm"
-
-
- @dataclass(repr=False)
- class AquaEvalParams(ModelParams, DataClassSerializable):
+ class AquaEvalParams(Serializable):
  shape: str = ""
  dataset_path: str = ""
  report_path: str = ""
 
+ class Config:
+ extra = "allow"
 
- @dataclass(repr=False)
- class AquaEvalMetric(DataClassSerializable):
+ class AquaEvalMetric(Serializable):
  key: str
  name: str
  description: str = ""
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class AquaEvalMetricSummary(DataClassSerializable):
+ class AquaEvalMetricSummary(Serializable):
  metric: str = ""
  score: str = ""
  grade: str = ""
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class AquaEvalMetrics(DataClassSerializable):
+ class AquaEvalMetrics(Serializable):
  id: str
  report: str
- metric_results: List[AquaEvalMetric] = field(default_factory=list)
- metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list)
+ metric_results: List[AquaEvalMetric] = Field(default_factory=list)
+ metric_summary_result: List[AquaEvalMetricSummary] = Field(default_factory=list)
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class AquaEvaluationCommands(DataClassSerializable):
+ class AquaEvaluationCommands(Serializable):
  evaluation_id: str
  evaluation_target_id: str
- input_data: dict
- metrics: list
+ input_data: Dict[str, Any]
+ metrics: List[str]
  output_dir: str
- params: dict
+ params: Dict[str, Any]
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class AquaEvaluationSummary(DataClassSerializable):
+ class AquaEvaluationSummary(Serializable):
  """Represents a summary of Aqua evalution."""
 
  id: str
@@ -154,17 +148,18 @@ class AquaEvaluationSummary(DataClassSerializable):
  lifecycle_state: str
  lifecycle_details: str
  time_created: str
- tags: dict
- experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
- source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
- job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
- parameters: AquaEvalParams = field(default_factory=AquaEvalParams)
+ tags: Dict[str, Any]
+ experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+ source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+ job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+ parameters: AquaEvalParams = Field(default_factory=AquaEvalParams)
 
+ class Config:
+ extra = "ignore"
 
- @dataclass(repr=False)
- class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable):
+ class AquaEvaluationDetail(AquaEvaluationSummary):
  """Represents a details of Aqua evalution."""
 
- log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
- log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
- introspection: dict = field(default_factory=dict)
+ log_group: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+ log: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
+ introspection: dict = Field(default_factory=dict)
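In practical terms, the evaluation entities move from plain dataclasses to pydantic-backed models (judging by the `Field(default_factory=...)` usage and per-class `Config.extra` settings, the new `Serializable` base wraps pydantic's `BaseModel`), so unknown keys are now kept or dropped per class and type mismatches surface as validation errors. A minimal sketch of that behaviour using plain pydantic; the class names below are illustrative stand-ins, not the ads classes:

    from typing import List

    from pydantic import BaseModel, Field


    class EvalParams(BaseModel):  # stand-in for AquaEvalParams (extra = "allow")
        shape: str = ""
        dataset_path: str = ""
        report_path: str = ""

        class Config:
            extra = "allow"  # unknown keys (e.g. sampling params) are kept


    class EvalMetrics(BaseModel):  # stand-in for AquaEvalMetrics (extra = "ignore")
        id: str
        report: str
        metric_results: List[str] = Field(default_factory=list)

        class Config:
            extra = "ignore"  # unknown keys are silently dropped


    params = EvalParams(shape="VM.GPU.A10.1", temperature=0.2)   # "temperature" is kept
    metrics = EvalMetrics(id="eval-id", report="report", bogus=1)  # "bogus" is dropped
    print(params.dict())   # includes "temperature"
    print(metrics.dict())  # no "bogus" key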
ads/aqua/evaluation/evaluation.py CHANGED
@@ -7,7 +7,6 @@ import os
  import re
  import tempfile
  from concurrent.futures import ThreadPoolExecutor, as_completed
- from dataclasses import asdict, fields
  from datetime import datetime, timedelta
  from pathlib import Path
  from threading import Lock
@@ -46,7 +45,6 @@ from ads.aqua.common.utils import (
  upload_local_to_os,
  )
  from ads.aqua.config.config import get_evaluation_service_config
- from ads.aqua.config.evaluation.evaluation_service_config import EvaluationServiceConfig
  from ads.aqua.constants import (
  CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
  EVALUATION_REPORT,
@@ -75,7 +73,6 @@ from ads.aqua.evaluation.entities import (
  AquaEvaluationSummary,
  AquaResourceIdentifier,
  CreateAquaEvaluationDetails,
- ModelParams,
  )
  from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
  from ads.aqua.ui import AquaContainerConfig
@@ -161,10 +158,11 @@ class AquaEvaluationApp(AquaApp):
  try:
  create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
  except Exception as ex:
+ custom_errors = {
+ ".".join(map(str, e["loc"])): e["msg"] for e in json.loads(ex.json())
+ }
  raise AquaValueError(
- "Invalid create evaluation parameters. "
- "Allowable parameters are: "
- f"{', '.join([field.name for field in fields(CreateAquaEvaluationDetails)])}."
+ f"Invalid create evaluation parameters. Error details: {custom_errors}."
  ) from ex
 
  if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id):
@@ -175,15 +173,7 @@ class AquaEvaluationApp(AquaApp):
 
  # The model to evaluate
  evaluation_source = None
- # The evaluation service config
- evaluation_config: EvaluationServiceConfig = get_evaluation_service_config()
- # The evaluation inference configuration. The inference configuration will be extracted
- # based on the inferencing container family.
  eval_inference_configuration: Dict = {}
- # The evaluation inference model sampling params. The system parameters that will not be
- # visible for user, but will be applied implicitly for evaluation. The service model params
- # will be extracted based on the container family and version.
- eval_inference_service_model_params: Dict = {}
 
  if (
  DataScienceResource.MODEL_DEPLOYMENT
@@ -200,29 +190,14 @@ class AquaEvaluationApp(AquaApp):
  runtime = ModelDeploymentContainerRuntime.from_dict(
  evaluation_source.runtime.to_dict()
  )
- container_config = AquaContainerConfig.from_container_index_json(
+ inference_config = AquaContainerConfig.from_container_index_json(
  enable_spec=True
- )
- for (
- inference_container_family,
- inference_container_info,
- ) in container_config.inference.items():
- if (
- inference_container_info.name
- == runtime.image[: runtime.image.rfind(":")]
- ):
+ ).inference
+ for container in inference_config.values():
+ if container.name == runtime.image[: runtime.image.rfind(":")]:
  eval_inference_configuration = (
- evaluation_config.get_merged_inference_params(
- inference_container_family
- ).to_dict()
- )
- eval_inference_service_model_params = (
- evaluation_config.get_merged_inference_model_params(
- inference_container_family,
- inference_container_info.version,
- )
+ container.spec.evaluation_configuration
  )
-
  except Exception:
  logger.debug(
  f"Could not load inference config details for the evaluation source id: "
@@ -277,19 +252,12 @@ class AquaEvaluationApp(AquaApp):
  )
  evaluation_dataset_path = dst_uri
 
- evaluation_model_parameters = None
- try:
- evaluation_model_parameters = AquaEvalParams(
- shape=create_aqua_evaluation_details.shape_name,
- dataset_path=evaluation_dataset_path,
- report_path=create_aqua_evaluation_details.report_path,
- **create_aqua_evaluation_details.model_parameters,
- )
- except Exception as ex:
- raise AquaValueError(
- "Invalid model parameters. Model parameters should "
- f"be a dictionary with keys: {', '.join(list(ModelParams.__annotations__.keys()))}."
- ) from ex
+ evaluation_model_parameters = AquaEvalParams(
+ shape=create_aqua_evaluation_details.shape_name,
+ dataset_path=evaluation_dataset_path,
+ report_path=create_aqua_evaluation_details.report_path,
+ **create_aqua_evaluation_details.model_parameters,
+ )
 
  target_compartment = (
  create_aqua_evaluation_details.compartment_id or COMPARTMENT_OCID
@@ -370,7 +338,7 @@ class AquaEvaluationApp(AquaApp):
  evaluation_model_taxonomy_metadata = ModelTaxonomyMetadata()
  evaluation_model_taxonomy_metadata[
  MetadataTaxonomyKeys.HYPERPARAMETERS
- ].value = {"model_params": dict(asdict(evaluation_model_parameters))}
+ ].value = {"model_params": evaluation_model_parameters.to_dict()}
 
  evaluation_model = (
  DataScienceModel()
@@ -443,7 +411,6 @@ class AquaEvaluationApp(AquaApp):
  dataset_path=evaluation_dataset_path,
  report_path=create_aqua_evaluation_details.report_path,
  model_parameters={
- **eval_inference_service_model_params,
  **create_aqua_evaluation_details.model_parameters,
  },
  metrics=create_aqua_evaluation_details.metrics,
@@ -580,16 +547,14 @@ class AquaEvaluationApp(AquaApp):
  **{
  "AIP_SMC_EVALUATION_ARGUMENTS": json.dumps(
  {
- **asdict(
- self._build_launch_cmd(
- evaluation_id=evaluation_id,
- evaluation_source_id=evaluation_source_id,
- dataset_path=dataset_path,
- report_path=report_path,
- model_parameters=model_parameters,
- metrics=metrics,
- ),
- ),
+ **self._build_launch_cmd(
+ evaluation_id=evaluation_id,
+ evaluation_source_id=evaluation_source_id,
+ dataset_path=dataset_path,
+ report_path=report_path,
+ model_parameters=model_parameters,
+ metrics=metrics,
+ ).to_dict(),
  **(inference_configuration or {}),
  },
  ),
@@ -662,9 +627,9 @@ class AquaEvaluationApp(AquaApp):
  "format": Path(dataset_path).suffix,
  "url": dataset_path,
  },
- metrics=metrics,
+ metrics=metrics or [],
  output_dir=report_path,
- params=model_parameters,
+ params=model_parameters or {},
  )
 
  @telemetry(entry_point="plugin=evaluation&action=get", name="aqua")
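A noteworthy behavioural change above: `CreateAquaEvaluationDetails(**kwargs)` now raises a pydantic `ValidationError`, and the handler flattens `ex.json()` into a field-path → message dict before wrapping it in `AquaValueError`, instead of merely listing the allowable dataclass fields. A rough illustration of that mapping with a plain pydantic model (the class and the exact error messages below are illustrative):

    import json

    from pydantic import BaseModel, ValidationError


    class CreateDetails(BaseModel):  # illustrative stand-in, not the ads class
        evaluation_source_id: str
        evaluation_name: str
        block_storage_size: int


    try:
        CreateDetails(evaluation_name="eval-1", block_storage_size="not-a-number")
    except ValidationError as ex:
        # Same flattening as in the handler: "loc" path joined with dots -> "msg".
        custom_errors = {
            ".".join(map(str, e["loc"])): e["msg"] for e in json.loads(ex.json())
        }
        print(custom_errors)
        # e.g. {'evaluation_source_id': 'Field required',
        #       'block_storage_size': 'Input should be a valid integer, ...'}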
ads/aqua/extension/deployment_handler.py CHANGED
@@ -54,6 +54,33 @@ class AquaDeploymentHandler(AquaAPIhandler):
  else:
  raise HTTPError(400, f"The request {self.request.path} is invalid.")
 
+ @handle_exceptions
+ def delete(self, model_deployment_id):
+ return self.finish(AquaDeploymentApp().delete(model_deployment_id))
+
+ @handle_exceptions
+ def put(self, *args, **kwargs):
+ """
+ Handles put request for the activating and deactivating OCI datascience model deployments
+ Raises
+ ------
+ HTTPError
+ Raises HTTPError if inputs are missing or are invalid
+ """
+ url_parse = urlparse(self.request.path)
+ paths = url_parse.path.strip("/").split("/")
+ if len(paths) != 4 or paths[0] != "aqua" or paths[1] != "deployments":
+ raise HTTPError(400, f"The request {self.request.path} is invalid.")
+
+ model_deployment_id = paths[2]
+ action = paths[3]
+ if action == "activate":
+ return self.finish(AquaDeploymentApp().activate(model_deployment_id))
+ elif action == "deactivate":
+ return self.finish(AquaDeploymentApp().deactivate(model_deployment_id))
+ else:
+ raise HTTPError(400, f"The request {self.request.path} is invalid.")
+
  @handle_exceptions
  def post(self, *args, **kwargs):
  """
@@ -103,6 +130,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
  memory_in_gbs = input_data.get("memory_in_gbs")
  model_file = input_data.get("model_file")
  private_endpoint_id = input_data.get("private_endpoint_id")
+ container_image_uri = input_data.get("container_image_uri")
+ cmd_var = input_data.get("cmd_var")
 
  self.finish(
  AquaDeploymentApp().create(
@@ -126,6 +155,8 @@ class AquaDeploymentHandler(AquaAPIhandler):
  memory_in_gbs=memory_in_gbs,
  model_file=model_file,
  private_endpoint_id=private_endpoint_id,
+ container_image_uri=container_image_uri,
+ cmd_var=cmd_var,
  )
  )
 
@@ -266,5 +297,7 @@ __handlers__ = [
  ("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
  ("deployments/config/?([^/]*)", AquaDeploymentHandler),
  ("deployments/?([^/]*)", AquaDeploymentHandler),
+ ("deployments/?([^/]*)/activate", AquaDeploymentHandler),
+ ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
  ("inference", AquaDeploymentInferenceHandler),
  ]
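Taken together, the new routes and handler methods expose deployment lifecycle operations over the extension's REST surface: `PUT aqua/deployments/<ocid>/activate` and `.../deactivate` toggle a deployment, and `DELETE aqua/deployments/<ocid>` removes it. A hedged sketch of calling them; the base URL and any auth headers depend on your notebook-server setup and are assumptions, not part of this diff:

    import requests

    BASE = "http://localhost:8888/aqua/deployments"  # assumed notebook-server endpoint
    OCID = "ocid1.datasciencemodeldeployment.oc1..example"  # placeholder OCID

    requests.put(f"{BASE}/{OCID}/deactivate")  # stop the deployment
    requests.put(f"{BASE}/{OCID}/activate")    # start it again
    requests.delete(f"{BASE}/{OCID}")          # delete it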
ads/aqua/extension/errors.py CHANGED
@@ -8,3 +8,4 @@ class Errors(str):
  NO_INPUT_DATA = "No input data provided."
  MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
  MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
+ INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'"
ads/aqua/extension/evaluation_handler.py CHANGED
@@ -12,7 +12,6 @@ from ads.aqua.evaluation import AquaEvaluationApp
  from ads.aqua.evaluation.entities import CreateAquaEvaluationDetails
  from ads.aqua.extension.base_handler import AquaAPIhandler
  from ads.aqua.extension.errors import Errors
- from ads.aqua.extension.utils import validate_function_parameters
  from ads.config import COMPARTMENT_OCID
 
 
@@ -47,10 +46,6 @@ class AquaEvaluationHandler(AquaAPIhandler):
  if not input_data:
  raise HTTPError(400, Errors.NO_INPUT_DATA)
 
- validate_function_parameters(
- data_class=CreateAquaEvaluationDetails, input_data=input_data
- )
-
  self.finish(
  # TODO: decide what other kwargs will be needed for create aqua evaluation.
  AquaEvaluationApp().create(
ads/aqua/extension/model_handler.py CHANGED
@@ -9,7 +9,10 @@ from tornado.web import HTTPError
 
  from ads.aqua.common.decorator import handle_exceptions
  from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
- from ads.aqua.common.utils import get_hf_model_info, list_hf_models
+ from ads.aqua.common.utils import (
+ get_hf_model_info,
+ list_hf_models,
+ )
  from ads.aqua.extension.base_handler import AquaAPIhandler
  from ads.aqua.extension.errors import Errors
  from ads.aqua.model import AquaModelApp
@@ -73,6 +76,8 @@ class AquaModelHandler(AquaAPIhandler):
  paths = url_parse.path.strip("/")
  if paths.startswith("aqua/model/cache"):
  return self.finish(AquaModelApp().clear_model_list_cache())
+ elif id:
+ return self.finish(AquaModelApp().delete_model(id))
  else:
  raise HTTPError(400, f"The request {self.request.path} is invalid.")
 
@@ -123,6 +128,7 @@ class AquaModelHandler(AquaAPIhandler):
  download_from_hf = (
  str(input_data.get("download_from_hf", "false")).lower() == "true"
  )
+ inference_container_uri = input_data.get("inference_container_uri")
 
  return self.finish(
  AquaModelApp().register(
@@ -134,9 +140,40 @@ class AquaModelHandler(AquaAPIhandler):
  compartment_id=compartment_id,
  project_id=project_id,
  model_file=model_file,
+ inference_container_uri=inference_container_uri,
  )
  )
 
+ @handle_exceptions
+ def put(self, id):
+ try:
+ input_data = self.get_json_body()
+ except Exception as ex:
+ raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
+
+ if not input_data:
+ raise HTTPError(400, Errors.NO_INPUT_DATA)
+
+ inference_container = input_data.get("inference_container")
+ inference_containers = AquaModelApp.list_valid_inference_containers()
+ if (
+ inference_container is not None
+ and inference_container not in inference_containers
+ ):
+ raise HTTPError(
+ 400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container")
+ )
+
+ enable_finetuning = input_data.get("enable_finetuning")
+ task = input_data.get("task")
+ app=AquaModelApp()
+ self.finish(
+ app.edit_registered_model(
+ id, inference_container, enable_finetuning, task
+ )
+ )
+ app.clear_model_details_cache(model_id=id)
+
 
  class AquaModelLicenseHandler(AquaAPIhandler):
  """Handler for Aqua Model license REST APIs."""
ads/aqua/model/model.py CHANGED
@@ -10,11 +10,15 @@ from typing import Dict, List, Optional, Set, Union
  import oci
  from cachetools import TTLCache
  from huggingface_hub import snapshot_download
- from oci.data_science.models import JobRun, Model
+ from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails
 
  from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
  from ads.aqua.app import AquaApp
- from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags
+ from ads.aqua.common.enums import (
+ FineTuningContainerTypeFamily,
+ InferenceContainerTypeFamily,
+ Tags,
+ )
  from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
  from ads.aqua.common.utils import (
  LifecycleStatus,
@@ -23,6 +27,7 @@ from ads.aqua.common.utils import (
  create_word_icon,
  generate_tei_cmd_var,
  get_artifact_path,
+ get_container_config,
  get_hf_model_info,
  list_os_files_with_extension,
  load_config,
@@ -78,7 +83,11 @@ from ads.config import (
  TENANCY_OCID,
  )
  from ads.model import DataScienceModel
- from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
+ from ads.model.model_metadata import (
+ MetadataCustomCategory,
+ ModelCustomMetadata,
+ ModelCustomMetadataItem,
+ )
  from ads.telemetry import telemetry
 
 
@@ -333,6 +342,96 @@ class AquaModelApp(AquaApp):
 
  return model_details
 
+ @telemetry(entry_point="plugin=model&action=delete", name="aqua")
+ def delete_model(self, model_id):
+ ds_model = DataScienceModel.from_id(model_id)
+ is_registered_model = ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
+ is_fine_tuned_model = ds_model.freeform_tags.get(
+ Tags.AQUA_FINE_TUNED_MODEL_TAG, None
+ )
+ if is_registered_model or is_fine_tuned_model:
+ return ds_model.delete()
+ else:
+ raise AquaRuntimeError(
+ f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
+ )
+
+ @telemetry(entry_point="plugin=model&action=delete", name="aqua")
+ def edit_registered_model(self, id, inference_container, enable_finetuning, task):
+ """Edits the default config of unverified registered model.
+
+ Parameters
+ ----------
+ id: str
+ The model OCID.
+ inference_container: str.
+ The inference container family name
+ enable_finetuning: str
+ Flag to enable or disable finetuning over the model. Defaults to None
+ task:
+ The usecase type of the model. e.g , text-generation , text_embedding etc.
+
+ Returns
+ -------
+ Model:
+ The instance of oci.data_science.models.Model.
+
+ """
+ ds_model = DataScienceModel.from_id(id)
+ if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
+ if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
+ raise AquaRuntimeError(
+ f"Failed to edit model:{id}. Only registered unverified models can be edited."
+ )
+ else:
+ custom_metadata_list = ds_model.custom_metadata_list
+ freeform_tags = ds_model.freeform_tags
+ if inference_container:
+ custom_metadata_list.add(
+ key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
+ value=inference_container,
+ category=MetadataCustomCategory.OTHER,
+ description="Deployment container mapping for SMC",
+ replace=True,
+ )
+ if enable_finetuning is not None:
+ if enable_finetuning.lower() == "true":
+ custom_metadata_list.add(
+ key=ModelCustomMetadataFields.FINETUNE_CONTAINER,
+ value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY,
+ category=MetadataCustomCategory.OTHER,
+ description="Fine-tuning container mapping for SMC",
+ replace=True,
+ )
+ freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"})
+ elif enable_finetuning.lower() == "false":
+ try:
+ custom_metadata_list.remove(
+ ModelCustomMetadataFields.FINETUNE_CONTAINER
+ )
+ freeform_tags.pop(Tags.READY_TO_FINE_TUNE)
+ except Exception as ex:
+ raise AquaRuntimeError(
+ f"The given model already doesn't support finetuning: {ex}"
+ )
+
+ custom_metadata_list.remove("modelDescription")
+ if task:
+ freeform_tags.update({Tags.TASK: task})
+ updated_custom_metadata_list = [
+ Metadata(**metadata)
+ for metadata in custom_metadata_list.to_dict()["data"]
+ ]
+ update_model_details = UpdateModelDetails(
+ custom_metadata_list=updated_custom_metadata_list,
+ freeform_tags=freeform_tags,
+ )
+ AquaApp().update_model(id, update_model_details)
+ else:
+ raise AquaRuntimeError(
+ f"Failed to edit model:{id}. Only registered unverified models can be edited."
+ )
+
  def _fetch_metric_from_metadata(
  self,
  custom_metadata_list: ModelCustomMetadata,
@@ -629,6 +728,32 @@ class AquaModelApp(AquaApp):
  }
  return res
 
+ def clear_model_details_cache(self, model_id):
+ """
+ Allows user to clear model details cache item
+ Returns
+ -------
+ dict with the key used, and True if cache has the key that needs to be deleted.
+ """
+ res = {}
+ logger.info(f"Clearing _service_model_details_cache for {model_id}")
+ with self._cache_lock:
+ if model_id in self._service_model_details_cache:
+ self._service_model_details_cache.pop(key=model_id)
+ res = {"key": {"model_id": model_id}, "cache_deleted": True}
+
+ return res
+
+ @staticmethod
+ def list_valid_inference_containers():
+ containers = list(
+ AquaContainerConfig.from_container_index_json(
+ config=get_container_config(), enable_spec=True
+ ).inference.values()
+ )
+ family_values = [item.family for item in containers]
+ return family_values
+
  def _create_model_catalog_entry(
  self,
  os_path: str,
ads/aqua/modeldeployment/deployment.py CHANGED
@@ -532,6 +532,18 @@ class AquaDeploymentApp(AquaApp):
 
  return results
 
+ @telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
+ def delete(self,model_deployment_id:str):
+ return self.ds_client.delete_model_deployment(model_deployment_id=model_deployment_id).data
+
+ @telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua")
+ def deactivate(self,model_deployment_id:str):
+ return self.ds_client.deactivate_model_deployment(model_deployment_id=model_deployment_id).data
+
+ @telemetry(entry_point="plugin=deployment&action=activate",name="aqua")
+ def activate(self,model_deployment_id:str):
+ return self.ds_client.activate_model_deployment(model_deployment_id=model_deployment_id).data
+
  @telemetry(entry_point="plugin=deployment&action=get", name="aqua")
  def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
  """Gets the information of Aqua model deployment.
ads/dataset/dataset.py CHANGED
@@ -202,7 +202,7 @@ class ADSDataset(PandasDataset):
  self.sampled_df.head(5)
  .style.set_table_styles(utils.get_dataframe_styles())
  .set_table_attributes("class=table")
- .hide_index()
+ .hide()
  .to_html()
  )
  )
@@ -261,7 +261,7 @@ class ADSDataset(PandasDataset):
  utils.horizontal_scrollable_div(
  self.style.set_table_styles(utils.get_dataframe_styles())
  .set_table_attributes("class=table")
- .hide_index()
+ .hide()
  .to_html()
  )
  )
ads/dataset/factory.py CHANGED
@@ -366,7 +366,7 @@ class DatasetFactory:
  display(
  HTML(
  list_df.style.set_table_attributes("class=table")
- .hide_index()
+ .hide()
  .to_html()
  )
  )
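The `.hide_index()` to `.hide()` swap in `dataset.py` and `factory.py` tracks the pandas `Styler` API: `hide_index()` was deprecated in pandas 1.4 and removed in 2.x, while `Styler.hide()` hides the index by default (its default axis is the index). A quick equivalence check:

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

    # pandas >= 1.4: Styler.hide() with its default axis hides the index,
    # which is what the removed hide_index() used to do.
    html = df.style.hide().to_html()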
ads/opctl/operator/lowcode/forecast/cmd.py CHANGED
@@ -1,18 +1,12 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
- # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  from typing import Dict
 
- import click
-
- from ads.opctl import logger
- from ads.opctl.operator.common.utils import _load_yaml_from_uri
  from ads.opctl.operator.common.operator_yaml_generator import YamlGenerator
-
- from .const import SupportedModels
+ from ads.opctl.operator.common.utils import _load_yaml_from_uri
 
 
  def init(**kwargs: Dict) -> str:
@@ -39,7 +33,7 @@ def init(**kwargs: Dict) -> str:
  # type=click.Choice(SupportedModels.values()),
  # default=SupportedModels.Auto,
  # )
- model_type = "auto"
+ model_type = "prophet"
 
  return YamlGenerator(
  schema=_load_yaml_from_uri(__file__.replace("cmd.py", "schema.yaml"))
ads/opctl/operator/lowcode/forecast/const.py CHANGED
@@ -1,7 +1,6 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
- # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,7 +16,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
  LGBForecast = "lgbforecast"
  AutoMLX = "automlx"
  AutoTS = "autots"
- Auto = "auto"
+ # Auto = "auto"
 
 
  class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
@@ -28,7 +27,7 @@ class SpeedAccuracyMode(str, metaclass=ExtendedEnumMeta):
  HIGH_ACCURACY = "HIGH_ACCURACY"
  BALANCED = "BALANCED"
  FAST_APPROXIMATE = "FAST_APPROXIMATE"
- ratio = dict()
+ ratio = {}
  ratio[HIGH_ACCURACY] = 1 # 100 % data used for generating explanations
  ratio[BALANCED] = 0.5 # 50 % data used for generating explanations
  ratio[FAST_APPROXIMATE] = 0 # constant
ads/opctl/operator/lowcode/forecast/model/factory.py CHANGED
@@ -1,20 +1,19 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
- # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
- from ..const import SupportedModels, AUTO_SELECT
+ from ..const import AUTO_SELECT, SupportedModels
+ from ..model_evaluator import ModelEvaluator
  from ..operator_config import ForecastOperatorConfig
  from .arima import ArimaOperatorModel
  from .automlx import AutoMLXOperatorModel
  from .autots import AutoTSOperatorModel
  from .base_model import ForecastOperatorBaseModel
+ from .forecast_datasets import ForecastDatasets
  from .neuralprophet import NeuralProphetOperatorModel
  from .prophet import ProphetOperatorModel
- from .forecast_datasets import ForecastDatasets
- from .ml_forecast import MLForecastOperatorModel
- from ..model_evaluator import ModelEvaluator
+
 
  class UnSupportedModelError(Exception):
  def __init__(self, model_type: str):
@@ -33,9 +32,9 @@ class ForecastOperatorModelFactory:
  SupportedModels.Prophet: ProphetOperatorModel,
  SupportedModels.Arima: ArimaOperatorModel,
  SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
- SupportedModels.LGBForecast: MLForecastOperatorModel,
+ # SupportedModels.LGBForecast: MLForecastOperatorModel,
  SupportedModels.AutoMLX: AutoMLXOperatorModel,
- SupportedModels.AutoTS: AutoTSOperatorModel
+ SupportedModels.AutoTS: AutoTSOperatorModel,
  }
 
  @classmethod
@@ -65,14 +64,14 @@ class ForecastOperatorModelFactory:
  model_type = operator_config.spec.model
  if model_type == AUTO_SELECT:
  model_type = cls.auto_select_model(datasets, operator_config)
- operator_config.spec.model_kwargs = dict()
+ operator_config.spec.model_kwargs = {}
  if model_type not in cls._MAP:
  raise UnSupportedModelError(model_type)
  return cls._MAP[model_type](config=operator_config, datasets=datasets)
 
  @classmethod
  def auto_select_model(
- cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
+ cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
  ) -> str:
  """
  Selects AutoMLX or Arima model based on column count.
@@ -90,8 +89,10 @@ class ForecastOperatorModelFactory:
  str
  The type of the model.
  """
- all_models = operator_config.spec.model_kwargs.get("model_list", cls._MAP.keys())
+ all_models = operator_config.spec.model_kwargs.get(
+ "model_list", cls._MAP.keys()
+ )
  num_backtests = operator_config.spec.model_kwargs.get("num_backtests", 5)
  sample_ratio = operator_config.spec.model_kwargs.get("sample_ratio", 0.20)
  model_evaluator = ModelEvaluator(all_models, num_backtests, sample_ratio)
- return model_evaluator.find_best_model(datasets, operator_config)
+ return model_evaluator.find_best_model(datasets, operator_config)
ads/opctl/operator/lowcode/forecast/model/ml_forecast.py CHANGED
@@ -2,7 +2,8 @@
 
  # Copyright (c) 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
- import numpy as np
+ import traceback
+
  import pandas as pd
 
  from ads.common.decorator import runtime_dependency
@@ -164,7 +165,7 @@ class MLForecastOperatorModel(ForecastOperatorBaseModel):
  self.errors_dict[self.spec.model] = {
  "model_name": self.spec.model,
  "error": str(e),
- "error_trace": traceback.format_exc()
+ "error_trace": traceback.format_exc(),
  }
  logger.warn(f"Encountered Error: {e}. Skipping.")
  logger.warn(traceback.format_exc())
@@ -173,7 +174,7 @@ class MLForecastOperatorModel(ForecastOperatorBaseModel):
  def _build_model(self) -> pd.DataFrame:
  data_train = self.datasets.get_all_data_long(include_horizon=False)
  data_test = self.datasets.get_all_data_long_forecast_horizon()
- self.models = dict()
+ self.models = {}
  model_kwargs = self.set_kwargs()
  self.forecast_output = ForecastOutput(
  confidence_interval_width=self.spec.confidence_interval_width,
ads/opctl/operator/lowcode/forecast/operator_config.py CHANGED
@@ -1,7 +1,6 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
- # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  import os
@@ -9,13 +8,17 @@ from dataclasses import dataclass, field
  from typing import Dict, List
 
  from ads.common.serializer import DataClassSerializable
+ from ads.opctl.operator.common.operator_config import (
+ InputData,
+ OperatorConfig,
+ OutputDirectory,
+ )
  from ads.opctl.operator.common.utils import _load_yaml_from_uri
- from ads.opctl.operator.common.operator_config import OperatorConfig, OutputDirectory, InputData
-
- from .const import SupportedMetrics, SpeedAccuracyMode
- from .const import SupportedModels
  from ads.opctl.operator.lowcode.common.utils import find_output_dirname
 
+ from .const import SpeedAccuracyMode, SupportedMetrics, SupportedModels
+
+
  @dataclass(repr=True)
  class TestData(InputData):
  """Class representing operator specification test data details."""
@@ -90,13 +93,17 @@ class ForecastOperatorSpec(DataClassSerializable):
 
  def __post_init__(self):
  """Adjusts the specification details."""
- self.output_directory = self.output_directory or OutputDirectory(url=find_output_dirname(self.output_directory))
+ self.output_directory = self.output_directory or OutputDirectory(
+ url=find_output_dirname(self.output_directory)
+ )
  self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
- self.model = self.model or SupportedModels.Auto
+ self.model = self.model or SupportedModels.Prophet
  self.confidence_interval_width = self.confidence_interval_width or 0.80
  self.report_filename = self.report_filename or "report.html"
  self.preprocessing = (
- self.preprocessing if self.preprocessing is not None else DataPreprocessor(enabled=True)
+ self.preprocessing
+ if self.preprocessing is not None
+ else DataPreprocessor(enabled=True)
  )
  # For Report Generation. When user doesn't specify defaults to True
  self.generate_report = (
@@ -138,7 +145,7 @@ class ForecastOperatorSpec(DataClassSerializable):
  )
  self.target_column = self.target_column or "Sales"
  self.errors_dict_filename = "errors.json"
- self.model_kwargs = self.model_kwargs or dict()
+ self.model_kwargs = self.model_kwargs or {}
 
 
  @dataclass(repr=True)
ads/opctl/operator/lowcode/forecast/schema.yaml CHANGED
@@ -374,12 +374,12 @@ spec:
  model:
  type: string
  required: false
- default: auto-select
+ default: prophet
  allowed:
  - prophet
  - arima
  - neuralprophet
- - lgbforecast
+ # - lgbforecast
  - automlx
  - autots
  - auto-select
ads/oracledb/oracle_db.py CHANGED
@@ -1,7 +1,6 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
- # Copyright (c) 2021, 2023 Oracle and/or its affiliates.
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
  """
@@ -17,19 +16,20 @@ If user uses DSN string copied from OCI console with OCI database setup for TLS
  Note: We need to account for cx_Oracle though oracledb can operate in thick mode. The end user may be is using one of the old conda packs or an environment where cx_Oracle is the only available driver.
  """
 
- from ads.common.utils import ORACLE_DEFAULT_PORT
-
  import logging
- import numpy as np
  import os
- import pandas as pd
  import tempfile
- from time import time
- from typing import Dict, Optional, List, Union, Iterator
  import zipfile
+ from time import time
+ from typing import Dict, Iterator, List, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+
  from ads.common.decorator.runtime_dependency import (
  OptionalDependency,
  )
+ from ads.common.utils import ORACLE_DEFAULT_PORT
 
  logger = logging.getLogger("ads.oracle_connector")
  CX_ORACLE = "cx_Oracle"
@@ -40,17 +40,17 @@ try:
  import oracledb as oracle_driver # Both the driver share same signature for the APIs that we are using.
 
  PYTHON_DRIVER_NAME = PYTHON_ORACLEDB
- except:
+ except ModuleNotFoundError:
  logger.info("oracledb package not found. Trying to load cx_Oracle")
  try:
  import cx_Oracle as oracle_driver
 
  PYTHON_DRIVER_NAME = CX_ORACLE
- except ModuleNotFoundError:
+ except ModuleNotFoundError as err2:
  raise ModuleNotFoundError(
  f"Neither `oracledb` nor `cx_Oracle` module was not found. Please run "
  f"`pip install {OptionalDependency.DATA}`."
- )
+ ) from err2
 
 
  class OracleRDBMSConnection(oracle_driver.Connection):
@@ -75,7 +75,7 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  logger.info(
  "Running oracledb driver in thick mode. For mTLS based connection, thick mode is default."
  )
- except:
+ except Exception:
  logger.info(
  "Could not use thick mode. The driver is running in thin mode. System might prompt for passphrase"
  )
@@ -154,7 +154,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  batch_size=100000,
  encoding="utf-8",
  ):
-
  if if_exists not in ["fail", "replace", "append"]:
  raise ValueError(
  f"Unknown option `if_exists`={if_exists}. Valid options are 'fail', 'replace', 'append'"
@@ -173,7 +172,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  df_orcl.columns = df_orcl.columns.str.replace(r"\W+", "_", regex=True)
  table_exist = True
  with self.cursor() as cursor:
-
  if if_exists != "replace":
  try:
  cursor.execute(f"SELECT 1 from {table_name} FETCH NEXT 1 ROWS ONLY")
@@ -275,7 +273,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  yield lst[i : i + batch_size]
 
  for batch in chunks(record_data, batch_size=batch_size):
-
  cursor.executemany(sql, batch, batcherrors=True)
 
  for error in cursor.getbatcherrors():
@@ -304,7 +301,6 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  def query(
  self, sql: str, bind_variables: Optional[Dict], chunksize=None
  ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
-
  start_time = time()
 
  cursor = self.cursor()
@@ -315,10 +311,8 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  cursor.execute(sql, **bind_variables)
  columns = [row[0] for row in cursor.description]
  df = iter(
- (
- pd.DataFrame(data=rows, columns=columns)
- for rows in self._fetch_by_batch(cursor, chunksize)
- )
+ pd.DataFrame(data=rows, columns=columns)
+ for rows in self._fetch_by_batch(cursor, chunksize)
  )
 
  else:
@@ -332,3 +326,21 @@ class OracleRDBMSConnection(oracle_driver.Connection):
  )
 
  return df
+
+
+ def get_adw_connection(vault_secret_id: str) -> "oracledb.Connection":
+ """Creates ADW connection from the credentials stored in the vault"""
+ import oracledb
+
+ from ads.secrets.adb import ADBSecretKeeper
+
+ secret = vault_secret_id
+
+ logging.getLogger().debug("A secret id was used to retrieve credentials.")
+ creds = ADBSecretKeeper.load_secret(secret).to_dict()
+ user = creds.pop("user_name", None)
+ password = creds.pop("password", None)
+ if not user or not password:
+ raise ValueError(f"The user or password is missing in {secret}")
+ logging.getLogger().debug("Downloaded secrets successfully.")
+ return oracledb.connect(user=user, password=password, **creds)
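The new module-level helper builds an ADW connection purely from a secret saved with `ADBSecretKeeper`, popping `user_name` and `password` and passing the remaining keys (such as the `dsn` added for TLS connections, see `ads/secrets/adb.py` below) straight to `oracledb.connect`. A hedged usage sketch; the secret OCID is a placeholder and the authentication setup is assumed:

    import ads
    from ads.oracledb.oracle_db import get_adw_connection

    ads.set_auth("resource_principal")  # or api_key, depending on your environment

    # Secret previously saved via ADBSecretKeeper (user_name, password, dsn, ...).
    connection = get_adw_connection("ocid1.vaultsecret.oc1..example")
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1 FROM dual")
        print(cursor.fetchone())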
ads/secrets/adb.py CHANGED
@@ -1,17 +1,18 @@
  #!/usr/bin/env python
- # -*- coding: utf-8 -*--
 
- # Copyright (c) 2021, 2022 Oracle and/or its affiliates.
+ # Copyright (c) 2021, 2024 Oracle and/or its affiliates.
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
- import ads
- from ads.secrets import SecretKeeper, Secret
  import json
  import os
  import tempfile
  import zipfile
+
  from tqdm.auto import tqdm
 
+ import ads
+ from ads.secrets import Secret, SecretKeeper
+
  logger = ads.getLogger("ads.secrets")
 
  from dataclasses import dataclass, field
@@ -25,7 +26,7 @@ class ADBSecret(Secret):
 
  user_name: str
  password: str
- service_name: str
+ service_name: str = field(default=None)
  wallet_location: str = field(
  default=None, metadata={"serializable": False}
  ) # Not saved in vault
@@ -40,6 +41,7 @@ class ADBSecret(Secret):
  wallet_secret_ids: list = field(
  repr=False, default_factory=list
  ) # Not exposed through environment or `to_dict` function
+ dsn: str = field(default=None)
 
  def __post_init__(self):
  self.wallet_file_name = (
@@ -76,6 +78,22 @@ class ADBSecretKeeper(SecretKeeper):
  >>> print(adw_keeper.secret_id) # Prints the secret_id of the stored credentials
  >>> adw_keeper.export_vault_details("adw_employee_att.json", format="json") # Save the secret id and vault info to a json file
 
+
+ >>> # Saving credentials for TLS connection
+ >>> from ads.secrets.adw import ADBSecretKeeper
+ >>> vault_id = "ocid1.vault.oc1..<unique_ID>"
+ >>> kid = "ocid1.ke..<unique_ID>"
+
+ >>> import ads
+ >>> ads.set_auth("resource_principal") # If using resource principal for authentication
+ >>> connection_parameters={
+ ... "user_name":"admin",
+ ... "password":"<your password>",
+ ... "dsn":"<dsn string>"
+ ... }
+ >>> adw_keeper = ADBSecretKeeper(vault_id=vault_id, key_id=kid, **connection_parameters)
+ >>> adw_keeper.save("adw_employee", "My DB credentials", freeform_tags={"schema":"emp"})
+
  >>> # Loading credentails
  >>> import ads
  >>> ads.set_auth("resource_principal") # If using resource principal for authentication
@@ -133,6 +151,7 @@ class ADBSecretKeeper(SecretKeeper):
  wallet_dir: str = None,
  repository_path: str = None,
  repository_key: str = None,
+ dsn: str = None,
  **kwargs,
  ):
  """
@@ -152,6 +171,8 @@ class ADBSecretKeeper(SecretKeeper):
  Path to credentials repository. For more details refer `ads.database.connection`
  repository_key: (str, optional). Default None.
  Configuration key for loading the right configuration from repository. For more details refer `ads.database.connection`
+ dsn: (str, optional). Default None.
+ dsn string copied from the OCI console for TLS connection
  kwargs:
  vault_id: str. OCID of the vault where the secret is stored. Required for saving secret.
  key_id: str. OCID of the key used for encrypting the secret. Required for saving secret.
@@ -180,6 +201,7 @@ class ADBSecretKeeper(SecretKeeper):
  password=password,
  service_name=service_name,
  wallet_location=wallet_location,
+ dsn=dsn,
  )
  self.wallet_dir = wallet_dir
 
@@ -252,7 +274,7 @@ class ADBSecretKeeper(SecretKeeper):
  logger.debug(f"Setting wallet file to {self.data.wallet_location}")
  data.wallet_location = self.data.wallet_location
  elif data.wallet_secret_ids and len(data.wallet_secret_ids) > 0:
- logger.debug(f"Secret ids corresponding to the wallet files found.")
+ logger.debug("Secret ids corresponding to the wallet files found.")
  # If the secret ids for wallet files are available in secret, then we
  # can generate the wallet file.
 
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.3
  Name: oracle_ads
- Version: 2.12.3
+ Version: 2.12.4
  Summary: Oracle Accelerated Data Science SDK
  Keywords: Oracle Cloud Infrastructure,OCI,Machine Learning,ML,Artificial Intelligence,AI,Data Science,Cloud,Oracle
  Author: Oracle Data Science
@@ -26,21 +26,21 @@ ads/aqua/dummy_data/oci_models.json,sha256=mxUU8o3plmAFfr06fQmIQuiGe2qFFBlUB7QNP
  ads/aqua/dummy_data/readme.md,sha256=AlBPt0HBSOFA5HbYVsFsdTm-BC3R5NRpcKrTxdjEnlI,1256
  ads/aqua/evaluation/__init__.py,sha256=Fd7WL7MpQ1FtJjlftMY2KHli5cz1wr5MDu3hGmV89a0,298
  ads/aqua/evaluation/constants.py,sha256=GvcXvPIw-VDKw4a8WNKs36uWdT-f7VJrWSpnnRnthGg,1533
- ads/aqua/evaluation/entities.py,sha256=mlu_ohjNPxxyDh4s_dFFxXJqKrwutAn48Wqr_odcT2M,5713
+ ads/aqua/evaluation/entities.py,sha256=bDIEtIwyNkUK-1S5jsbbne6xy49U-UmtuzzNuYf0tgk,5430
  ads/aqua/evaluation/errors.py,sha256=qzR63YEIA8haCh4HcBHFFm7j4g6jWDfGszqrPkXx9zQ,4564
- ads/aqua/evaluation/evaluation.py,sha256=LoEBUHwJwMa-seLa8d4qzPRMic_4AjeYcxMbftkoXa0,59885
+ ads/aqua/evaluation/evaluation.py,sha256=iOSznRW2AioEgJnJ4xIrkyqDiEdTsGBa7LgkzHprnfQ,58011
  ads/aqua/extension/__init__.py,sha256=mRArjU6UZpZYVr0qHSSkPteA_CKcCZIczOFaK421m9o,1453
  ads/aqua/extension/aqua_ws_msg_handler.py,sha256=PcRhBqGpq5aOPP0ibhaKfmkA8ajimldsvJC32o9JeTw,3291
  ads/aqua/extension/base_handler.py,sha256=MuVxsJG66NdatL-Hh99UD3VQOQw1ir-q2YBajwh9cJk,5132
  ads/aqua/extension/common_handler.py,sha256=Oz3riHDy5pFfbArLge5iaaRoK8PEAnkBvhqqVGbUsvE,4196
  ads/aqua/extension/common_ws_msg_handler.py,sha256=pMX79tmJKTKog684o6vuwZkAD47l8SxtRx5TNn8se7k,2230
- ads/aqua/extension/deployment_handler.py,sha256=UOhtlYNEHSXOG2oCQ9pLNZzOkcY0mbm7EeMhRc_TuKg,9600
+ ads/aqua/extension/deployment_handler.py,sha256=i2UAZQ8_uVgg32OmM1vif3kplAVuRwxZsjgTfUSKnH8,11025
  ads/aqua/extension/deployment_ws_msg_handler.py,sha256=JX3ZHRtscrflSxT7ZTEEI_p_owtk3m5FZq3QXE96AGY,2013
- ads/aqua/extension/errors.py,sha256=i37EnRzxGgvxzUNoyEORzHYmB296DGOUb6pm7VwEyTU,451
- ads/aqua/extension/evaluation_handler.py,sha256=RT2W7WDtxNIT0uirLfTcDlmTPYCuMuWRhiDxYZYliZs,4542
+ ads/aqua/extension/errors.py,sha256=ojDolyr3_0UCCwKqPtiZZyMQuX35jr8h8MQRP6HcBs4,519
+ ads/aqua/extension/evaluation_handler.py,sha256=fJH73fa0xmkEiP8SxKL4A4dJgj-NoL3z_G-w_WW2zJs,4353
  ads/aqua/extension/evaluation_ws_msg_handler.py,sha256=dv0iwOSTxYj1kQ1rPEoDmGgFBzLUCLXq5h7rpmY2T1M,2098
  ads/aqua/extension/finetune_handler.py,sha256=abiDXNhkhtoV9hrYhCzwhDjdQKlqQ_KSqxKWntkvh3E,3288
- ads/aqua/extension/model_handler.py,sha256=lsa8cRblUbITOtn2K9HuPWrl_CVGV2GXHq2aiGh4K5U,9130
+ ads/aqua/extension/model_handler.py,sha256=Mlx12n8cssb7Cti0zpDNRHzIDk-xPC7pXzeHf8eY66E,10398
  ads/aqua/extension/models_ws_msg_handler.py,sha256=3CPfzWl1xfrE2Dpn_WYP9zY0kY5zlsAE8tU_6Y2-i18,1801
  ads/aqua/extension/ui_handler.py,sha256=3TibTMeqcsSWfPsorspFrhIV0PRh8_4FoWpudycT80g,10664
  ads/aqua/extension/ui_websocket_handler.py,sha256=oLFjaDrqkSERbhExdvxjLJX0oRcP-DVJ_aWn0qy0uvo,5084
@@ -55,10 +55,10 @@ ads/aqua/model/__init__.py,sha256=j2iylvERdANxgrEDp7b_mLcKMz1CF5Go0qgYCiMwdos,27
  ads/aqua/model/constants.py,sha256=H239zDu3koa3UTdw-uQveXHX2NDwidclVcS4QIrCTJo,1593
  ads/aqua/model/entities.py,sha256=9SsdJfoBH7fDKGXQYs8pKLiZ-SqFnXaZrJod4FWU3mI,9670
  ads/aqua/model/enums.py,sha256=t8GbK2nblIPm3gClR8W31RmbtTuqpoSzoN4W3JfD6AI,1004
- ads/aqua/model/model.py,sha256=Vkm1oszD6Lw1rl8Yxf2azuWI1zF4jl-QE5Sk5SEDKWM,57414
+ ads/aqua/model/model.py,sha256=IwfN9I3p7KDzhM5moiEBh9sxU6pGtIARKxJcyDOGslA,62711
  ads/aqua/modeldeployment/__init__.py,sha256=RJCfU1yazv3hVWi5rS08QVLTpTwZLnlC8wU8diwFjnM,391
  ads/aqua/modeldeployment/constants.py,sha256=lJF77zwxmlECljDYjwFAMprAUR_zctZHmawiP-4alLg,296
- ads/aqua/modeldeployment/deployment.py,sha256=OE_jpPCGNxC6-p88kk7Xx1yQ1rKALgALRgcOnfLZb0A,29970
+ ads/aqua/modeldeployment/deployment.py,sha256=8qx4cxzuln5FZpAXTZlvaHCio2fzFJxO4PrrAS1_b6A,30652
  ads/aqua/modeldeployment/entities.py,sha256=7aoE2HemsFEvkQynAI4PCfZBcfPJrvbyZeEYvc7OIAA,5111
  ads/aqua/modeldeployment/inference.py,sha256=JPqzbHJoM-PpIU_Ft9lHudO9_1vFr7OPQ2GHjPoAufU,2142
  ads/aqua/training/__init__.py,sha256=w2DNWltXtASQgbrHyvKo0gMs5_chZoG-CSDMI4qe7i0,202
@@ -149,11 +149,11 @@ ads/dataset/correlation.py,sha256=OKdUO-bhZTQnZ3flrju1se6ToB8v6yod_uYeFPRLHfU,80
  ads/dataset/correlation_plot.py,sha256=LLGy9ZzkQ-V9yPo58T2Jjj6qdBeqwJcave5yUgBOYK4,17266
  ads/dataset/dask_series.py,sha256=2BhjLDyKL4-dHq7tBgipd8-2VxR0kJdzImu3sxjdNOg,5427
  ads/dataset/dataframe_transformer.py,sha256=xPBG8nvYh35hfXcwutj8FDrO1DnSM4_NgW6EMRQWW7s,3592
- ads/dataset/dataset.py,sha256=nXNUBEBxX8D5OJNyn8AgvYYQFNMTdRD4R-6-UaxavQs,73481
+ ads/dataset/dataset.py,sha256=UhcvDBg1zxz3dADroaHAO8iLoPtGhmLOMEdry13Lyqg,73469
  ads/dataset/dataset_browser.py,sha256=E-Cx0lJPAHicDUi6nIDWBCd_M_LM7EXpdT1S80l7EPE,11879
  ads/dataset/dataset_with_target.py,sha256=n11qv3bPaZU_XXwZYBtzZYVmX2sn_wuc422HFBXoE_8,38574
  ads/dataset/exception.py,sha256=Z96xkd9hzbn0NrMsmubcrXLGIU6nP2-0M02T9C0Xwg0,602
- ads/dataset/factory.py,sha256=bD2T-fYuKTo9Wk8OscZeAjYpzym54eMfU_eYWvBl4Fk,37594
+ ads/dataset/factory.py,sha256=n-KSiNFJR__8PQ39vTdN0uvPXx0Qjnt00dADZAW6aeE,37588
  ads/dataset/feature_engineering_transformer.py,sha256=IbR-V7YW-WgGRpNUvFyO_rpmWqeyL4WqlZcJO4gy0v8,1140
  ads/dataset/feature_selection.py,sha256=FJLsQ0obLW9lSFibxNpmUI592SLSXnL_2hD9Y8E5BWI,4144
  ads/dataset/forecasting_dataset.py,sha256=-qNeCcFmm-1FDN_EmG7tYEs8-MSzmHyIb9PHeSCk0PM,980
@@ -678,22 +678,22 @@ ads/opctl/operator/lowcode/forecast/MLoperator,sha256=xM8yBUQObjG_6Mg36f3Vv8b9N3
  ads/opctl/operator/lowcode/forecast/README.md,sha256=kbCCEdo-0pwKlZp9ctnWUK6Z31n69IsnG0i26b202Zg,9768
  ads/opctl/operator/lowcode/forecast/__init__.py,sha256=sAqmLhogrLXb3xI7dPOj9HmSkpTnLh9wkzysuGd8AXk,204
  ads/opctl/operator/lowcode/forecast/__main__.py,sha256=5Vh-kClwxTsvZLEuECyQBvbZFfH37HQW2G09RwX11Kw,2503
- ads/opctl/operator/lowcode/forecast/cmd.py,sha256=Q-R3yfK7aPfE4-0zIqzLFSjnz1tVMxJ1bbvrCirVZHQ,1246
- ads/opctl/operator/lowcode/forecast/const.py,sha256=PBEhOGZaFWzkd5H9Vw687lq2A5q5RZNlS6Mj6ZelOuw,2618
+ ads/opctl/operator/lowcode/forecast/cmd.py,sha256=uwU-QvnYwxoRFXZv7_JFkzAUnjTNoSsHEme2FF-9Rl0,1151
+ ads/opctl/operator/lowcode/forecast/const.py,sha256=jyoXhrRXFipcATwGIU_3rFRZL-r6hvbKNUVO2uG2siY,2597
  ads/opctl/operator/lowcode/forecast/environment.yaml,sha256=eVMf9pcjADI14_GRGdZOB_gK5_MyG_-cX037TXqzFho,330
  ads/opctl/operator/lowcode/forecast/errors.py,sha256=X9zuV2Lqb5N9FuBHHshOFYyhvng5r9KGLHnQijZ5b8c,911
  ads/opctl/operator/lowcode/forecast/model_evaluator.py,sha256=dSV1aj25wzv0V3y72YdYj4rCPjXAog13ppxYDNY9HQU,8913
- ads/opctl/operator/lowcode/forecast/operator_config.py,sha256=XskXuOWtZZb6_EcR_t6XAEdr6jt1wT30oBcWt-8zeWA,6396
- ads/opctl/operator/lowcode/forecast/schema.yaml,sha256=Zfhh_wfWxNeTtN4bqAe623Vf0HbQWCLyNx8LkiCTCgo,10138
+ ads/opctl/operator/lowcode/forecast/operator_config.py,sha256=vG7n-RIiazujH0UtJ0uarx9IKDIAS0b4WcCo1dNLVL0,6422
+ ads/opctl/operator/lowcode/forecast/schema.yaml,sha256=twmsn0wPPkgdVk8tKPZL3zBlxqecuXL0GSlIz3I8ZEI,10136
  ads/opctl/operator/lowcode/forecast/utils.py,sha256=oc6eBH9naYg4BB14KS2HL0uFdZHMgKsxx9vG28dJrXA,14347
  ads/opctl/operator/lowcode/forecast/model/__init__.py,sha256=sAqmLhogrLXb3xI7dPOj9HmSkpTnLh9wkzysuGd8AXk,204
  ads/opctl/operator/lowcode/forecast/model/arima.py,sha256=6ZXtzXcqoEMVF9DChzX0cnTJ-9tXKdbPiiSPQq4a9oM,10914
  ads/opctl/operator/lowcode/forecast/model/automlx.py,sha256=D7U-y-sTdkiqynk_l86z1HNSjn9c58DJTU7l8T33BJk,14856
  ads/opctl/operator/lowcode/forecast/model/autots.py,sha256=QxU24eZeaRpnC5rTqBFe6-5ylMorPN0sCamHUiNQVaE,13162
  ads/opctl/operator/lowcode/forecast/model/base_model.py,sha256=s4_lvasasCqvrj49ubD0H_2wA9pvh16_f5BiivqvL20,30876
- ads/opctl/operator/lowcode/forecast/model/factory.py,sha256=NV_m2sEgj3byHHqLs9Vbth7d5yfvFuXj8QI3-y9x2Po,3488
+ ads/opctl/operator/lowcode/forecast/model/factory.py,sha256=RrE6JJcUmkypjD6IQOR53I9GCg7jQO380r53oLmVK6A,3439
  ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py,sha256=02gOA-0KKtD0VYj87SsgRMq4EP2VSnhfuxoH1suAIO0,16968
- ads/opctl/operator/lowcode/forecast/model/ml_forecast.py,sha256=EOFZR5wjZcpKACW3ZNnxd31Okz_ehOSaO5_dKL-Ktgw,9558
+ ads/opctl/operator/lowcode/forecast/model/ml_forecast.py,sha256=6ynnmfVESR5rBjh5FaX1YEXYziIydEJ4t4IDpiUe-Jg,9554
  ads/opctl/operator/lowcode/forecast/model/neuralprophet.py,sha256=pRmhLHjP027gmPbkgqzR2SZYKvj1rG9Heev2P8mSZ_k,19347
  ads/opctl/operator/lowcode/forecast/model/prophet.py,sha256=0OBnyVP9bFpo1zSAqA5qtobZxICRTLVT9mwPOlHb3sM,14554
  ads/opctl/operator/lowcode/pii/MLoperator,sha256=GKCuiXRwfGLyBqELbtgtg-kJPtNWNVA-kSprYTqhF64,6406
@@ -744,7 +744,7 @@ ads/opctl/spark/cli.py,sha256=ylgV9L2pHY6RzT9WRzpzvFPjA6EPWYi04txpS0X_Xaw,1264
  ads/opctl/spark/cmds.py,sha256=QFrWaHzKqHHpnX_uyztMMMvvJ5s2dOV5edc05AnoNeM,5055
  ads/opctl/templates/diagnostic_report_template.jinja2,sha256=YfcEyTTrM-i2WgmS6b1X5ifPy0Pf0xG3WkKi6OunYco,3341
  ads/oracledb/__init__.py,sha256=xMyuwB5xsIEW9MFmvyjmF1YnRarsIjeFe2Ib-aprCG4,210
- ads/oracledb/oracle_db.py,sha256=_8Z8DL45RrWdaVZA464ICDtgm8tBoPGoX_wQTozDPHE,12889
+ ads/oracledb/oracle_db.py,sha256=mb70joLXAnm_ieROFWtG0LvsPNz4URh5dpDDP73_YOo,13570
  ads/pipeline/__init__.py,sha256=AAxC4BtaiTO4fj5odxTPWBToqxSKfKzQzRHW_9ozIOY,1268
  ads/pipeline/ads_pipeline.py,sha256=NkeryW1guYghFkbOlPdN-Kh_LlyZMwJV3c6eAC56V28,84882
  ads/pipeline/ads_pipeline_run.py,sha256=sNczf-1B0sROoFno9LbbND5HDUPtTTHOpFlIXB-IUH4,28374
@@ -763,7 +763,7 @@ ads/pipeline/visualizer/base.py,sha256=2TYw4EwTmiMqF7Q9SsSAqXEzJj0859XB6-jqteNUg
  ads/pipeline/visualizer/graph_renderer.py,sha256=u1o9K6pVdvd2Z_xkEJkItJ3ewy0V8xMC0mSh8dlIB-I,9208
  ads/pipeline/visualizer/text_renderer.py,sha256=nACiVMGkiv0MKNaki4S3MXym98Jb7ofRJmH1zNgmaC4,2625
  ads/secrets/__init__.py,sha256=UPcdOB6VfEr4QA61fPFP-_Oqhi3PSR9a4ciKc8HEMFI,381
- ads/secrets/adb.py,sha256=1XGn40I-0jKWYuJr9iuY5E4JGlu4Zt2JhfZlZcFLsE0,15817
+ ads/secrets/adb.py,sha256=f-ttTGbdc3OMGQm-MrDki2aAOSI_0GvW1skzJe38SjU,16648
  ads/secrets/auth_token.py,sha256=NCdqYMFXdkqsQpvjLPyQFW01cMq7LinxSYM8Rywd4q0,3705
  ads/secrets/big_data_service.py,sha256=elZTVTXOqoANF0yF2jeDOCH_0RBz34CnjNCaVH34rg0,13926
  ads/secrets/mysqldb.py,sha256=hVkWV6drmkmLzLX8WeZr4yriMZvzf-n4am15SRTiIgc,5668
@@ -813,8 +813,8 @@ ads/type_discovery/unknown_detector.py,sha256=yZuYQReO7PUyoWZE7onhhtYaOg6088wf1y
  ads/type_discovery/zipcode_detector.py,sha256=3AlETg_ZF4FT0u914WXvTT3F3Z6Vf51WiIt34yQMRbw,1421
  ads/vault/__init__.py,sha256=x9tMdDAOdF5iDHk9u2di_K-ze5Nq068x25EWOBoWwqY,245
  ads/vault/vault.py,sha256=hFBkpYE-Hfmzu1L0sQwUfYcGxpWmgG18JPndRl0NOXI,8624
- oracle_ads-2.12.3.dist-info/entry_points.txt,sha256=9VFnjpQCsMORA4rVkvN8eH6D3uHjtegb9T911t8cqV0,35
- oracle_ads-2.12.3.dist-info/LICENSE.txt,sha256=zoGmbfD1IdRKx834U0IzfFFFo5KoFK71TND3K9xqYqo,1845
- oracle_ads-2.12.3.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
- oracle_ads-2.12.3.dist-info/METADATA,sha256=y5hVMbLVQSS4QuN-cS9TW6TfE8Z8ntnDvgNMMs0-wbw,16217
- oracle_ads-2.12.3.dist-info/RECORD,,
+ oracle_ads-2.12.4.dist-info/entry_points.txt,sha256=9VFnjpQCsMORA4rVkvN8eH6D3uHjtegb9T911t8cqV0,35
+ oracle_ads-2.12.4.dist-info/LICENSE.txt,sha256=zoGmbfD1IdRKx834U0IzfFFFo5KoFK71TND3K9xqYqo,1845
+ oracle_ads-2.12.4.dist-info/WHEEL,sha256=CpUCUxeHQbRN5UGRQHYRJorO5Af-Qy_fHMctcQ8DSGI,82
+ oracle_ads-2.12.4.dist-info/METADATA,sha256=vOzUD-W4JvNIVdWdZw28o0VjPZ4CRk0CWLOLHuCABQM,16217
+ oracle_ads-2.12.4.dist-info/RECORD,,
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: flit 3.9.0
+ Generator: flit 3.10.1
  Root-Is-Purelib: true
  Tag: py3-none-any