oracle-ads 2.13.8__py3-none-any.whl → 2.13.9rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -151,6 +151,8 @@ class AquaMultiModelRef(Serializable):
151
151
  The name of the model.
152
152
  gpu_count : Optional[int]
153
153
  Number of GPUs required for deployment.
154
+ model_task : Optional[str]
155
+ The task that model operates on. Supported tasks are in MultiModelSupportedTaskType
154
156
  env_var : Optional[Dict[str, Any]]
155
157
  Optional environment variables to override during deployment.
156
158
  artifact_location : Optional[str]
@@ -162,6 +164,7 @@ class AquaMultiModelRef(Serializable):
162
164
  gpu_count: Optional[int] = Field(
163
165
  None, description="The gpu count allocation for the model."
164
166
  )
167
+ model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType")
165
168
  env_var: Optional[dict] = Field(
166
169
  default_factory=dict, description="The environment variables of the model."
167
170
  )
ads/aqua/constants.py CHANGED
@@ -43,6 +43,8 @@ HF_METADATA_FOLDER = ".cache/"
43
43
  HF_LOGIN_DEFAULT_TIMEOUT = 2
44
44
  MODEL_NAME_DELIMITER = ";"
45
45
  AQUA_TROUBLESHOOTING_LINK = "https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/troubleshooting-tips.md"
46
+ MODEL_FILE_DESCRIPTION_VERSION = "1.0"
47
+ MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription"
46
48
 
47
49
  TRAINING_METRICS_FINAL = "training_metrics_final"
48
50
  VALIDATION_METRICS_FINAL = "validation_metrics_final"
@@ -58,6 +58,7 @@ from ads.jobs.ads_job import Job
58
58
  from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob
59
59
  from ads.jobs.builders.runtimes.base import Runtime
60
60
  from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
61
+ from ads.model.common.utils import MetadataArtifactPathType
61
62
  from ads.model.model_metadata import (
62
63
  MetadataTaxonomyKeys,
63
64
  ModelCustomMetadata,
@@ -315,6 +316,23 @@ class AquaFineTuningApp(AquaApp):
315
316
  model_by_reference=True,
316
317
  defined_tags=create_fine_tuning_details.defined_tags,
317
318
  )
319
+ defined_metadata_dict = {}
320
+ defined_metadata_list_source = source.defined_metadata_list._to_oci_metadata()
321
+ for defined_metadata in defined_metadata_list_source:
322
+ if (
323
+ defined_metadata.has_artifact
324
+ and defined_metadata.key.lower()
325
+ != AquaModelMetadataKeys.FINE_TUNING_CONFIGURATION.lower()
326
+ ):
327
+ content = self.ds_client.get_model_defined_metadatum_artifact_content(
328
+ source.id, defined_metadata.key
329
+ ).data.content
330
+ defined_metadata_dict[defined_metadata.key] = content
331
+
332
+ for key, value in defined_metadata_dict.items():
333
+ ft_model.create_defined_metadata_artifact(
334
+ key, value, MetadataArtifactPathType.CONTENT
335
+ )
318
336
 
319
337
  ft_job_freeform_tags = {
320
338
  Tags.AQUA_TAG: UNKNOWN,
@@ -15,12 +15,19 @@ from typing import List, Optional
15
15
 
16
16
  import oci
17
17
  from huggingface_hub import hf_api
18
- from pydantic import BaseModel
18
+ from pydantic import BaseModel, Field
19
+ from pydantic.alias_generators import to_camel
19
20
 
20
21
  from ads.aqua import logger
21
22
  from ads.aqua.app import CLIBuilderMixin
22
23
  from ads.aqua.common import utils
23
- from ads.aqua.constants import LIFECYCLE_DETAILS_MISSING_JOBRUN, UNKNOWN_VALUE
24
+ from ads.aqua.config.utils.serializer import Serializable
25
+ from ads.aqua.constants import (
26
+ LIFECYCLE_DETAILS_MISSING_JOBRUN,
27
+ MODEL_FILE_DESCRIPTION_TYPE,
28
+ MODEL_FILE_DESCRIPTION_VERSION,
29
+ UNKNOWN_VALUE,
30
+ )
24
31
  from ads.aqua.data import AquaResourceIdentifier
25
32
  from ads.aqua.model.enums import FineTuningDefinedMetadata
26
33
  from ads.aqua.training.exceptions import exit_code_dict
@@ -304,3 +311,75 @@ class ImportModelDetails(CLIBuilderMixin):
304
311
 
305
312
  def __post_init__(self):
306
313
  self._command = "model register"
314
+
315
+
316
+ class ModelFileInfo(Serializable):
317
+ """Describes the file information of this model.
318
+
319
+ Attributes:
320
+ name (str): The name of the model artifact file.
321
+ version (str): The version of the model artifact file.
322
+ size_in_bytes (int): The size of the model artifact file in bytes.
323
+ """
324
+
325
+ name: str = Field(..., description="The name of model artifact file.")
326
+ version: str = Field(..., description="The version of model artifact file.")
327
+ size_in_bytes: int = Field(
328
+ ..., description="The size of model artifact file in bytes."
329
+ )
330
+
331
+ class Config:
332
+ alias_generator = to_camel
333
+ extra = "allow"
334
+
335
+
336
+ class ModelArtifactInfo(Serializable):
337
+ """Describes the artifact information of this model.
338
+
339
+ Attributes:
340
+ namespace (str): The namespace of the model artifact location.
341
+ bucket_name (str): The bucket name of model artifact location.
342
+ prefix (str): The prefix of model artifact location.
343
+ objects: (List[ModelFileInfo]): A list of model artifact objects.
344
+ """
345
+
346
+ namespace: str = Field(
347
+ ..., description="The name space of model artifact location."
348
+ )
349
+ bucket_name: str = Field(
350
+ ..., description="The bucket name of model artifact location."
351
+ )
352
+ prefix: str = Field(..., description="The prefix of model artifact location.")
353
+ objects: List[ModelFileInfo] = Field(
354
+ ..., description="List of model artifact objects."
355
+ )
356
+
357
+ class Config:
358
+ alias_generator = to_camel
359
+ extra = "allow"
360
+
361
+
362
+ class ModelFileDescription(Serializable):
363
+ """Describes the model file description.
364
+
365
+ Attributes:
366
+ version (str): The version of the model file description. Defaults to `1.0`.
367
+ type (str): The type of model file description. Defaults to `modelOSSReferenceDescription`.
368
+ models List[ModelArtifactInfo]: A list of model artifact information.
369
+ """
370
+
371
+ version: str = Field(
372
+ default=MODEL_FILE_DESCRIPTION_VERSION,
373
+ description="The version of model file description.",
374
+ )
375
+ type: str = Field(
376
+ default=MODEL_FILE_DESCRIPTION_TYPE,
377
+ description="The type of model file description.",
378
+ )
379
+ models: List[ModelArtifactInfo] = Field(
380
+ ..., description="List of model artifact information."
381
+ )
382
+
383
+ class Config:
384
+ alias_generator = to_camel
385
+ extra = "allow"
ads/aqua/model/enums.py CHANGED
@@ -26,5 +26,7 @@ class FineTuningCustomMetadata(ExtendedEnum):
26
26
 
27
27
 
28
28
  class MultiModelSupportedTaskType(ExtendedEnum):
29
- TEXT_GENERATION = "text-generation"
30
- TEXT_GENERATION_ALT = "text_generation"
29
+ TEXT_GENERATION = "text_generation"
30
+ IMAGE_TEXT_TO_TEXT = "image_text_to_text"
31
+ CODE_SYNTHESIS = "code_synthesis"
32
+ EMBEDDING = "text_embedding"
ads/aqua/model/model.py CHANGED
@@ -4,6 +4,7 @@
4
4
  import json
5
5
  import os
6
6
  import pathlib
7
+ import re
7
8
  from datetime import datetime, timedelta
8
9
  from threading import Lock
9
10
  from typing import Any, Dict, List, Optional, Set, Union
@@ -78,6 +79,7 @@ from ads.aqua.model.entities import (
78
79
  AquaModelReadme,
79
80
  AquaModelSummary,
80
81
  ImportModelDetails,
82
+ ModelFileDescription,
81
83
  ModelValidationResult,
82
84
  )
83
85
  from ads.aqua.model.enums import MultiModelSupportedTaskType
@@ -184,8 +186,12 @@ class AquaModelApp(AquaApp):
184
186
  target_project = project_id or PROJECT_OCID
185
187
  target_compartment = compartment_id or COMPARTMENT_OCID
186
188
 
187
- # Skip model copying if it is registered model
188
- if service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None:
189
+ # Skip model copying if it is registered model or fine-tuned model
190
+ if (
191
+ service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None
192
+ or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
193
+ is not None
194
+ ):
189
195
  logger.info(
190
196
  f"Aqua Model {model_id} already exists in the user's compartment."
191
197
  "Skipped copying."
@@ -266,8 +272,8 @@ class AquaModelApp(AquaApp):
266
272
  "Model list cannot be empty. Please provide at least one model for deployment."
267
273
  )
268
274
 
269
- artifact_list = []
270
275
  display_name_list = []
276
+ model_file_description_list: List[ModelFileDescription] = []
271
277
  model_custom_metadata = ModelCustomMetadata()
272
278
 
273
279
  service_inference_containers = (
@@ -294,6 +300,7 @@ class AquaModelApp(AquaApp):
294
300
  for model in models:
295
301
  source_model = DataScienceModel.from_id(model.model_id)
296
302
  display_name = source_model.display_name
303
+ model_file_description = source_model.model_file_description
297
304
  # Update model name in user's input model
298
305
  model.model_name = model.model_name or display_name
299
306
 
@@ -304,18 +311,10 @@ class AquaModelApp(AquaApp):
304
311
  # "Currently only service models are supported for multi model deployment."
305
312
  # )
306
313
 
307
- # TODO uncomment the section below if only the specific types of models should be allowed for multi-model deployment
308
- # if (
309
- # source_model.freeform_tags.get(Tags.TASK, UNKNOWN).lower()
310
- # not in MultiModelSupportedTaskType
311
- # ):
312
- # raise AquaValueError(
313
- # f"Invalid or missing {Tags.TASK} tag for selected model {display_name}. "
314
- # f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
315
- # )
316
-
317
314
  display_name_list.append(display_name)
318
315
 
316
+ self._extract_model_task(model, source_model)
317
+
319
318
  # Retrieve model artifact
320
319
  model_artifact_path = source_model.artifact
321
320
  if not model_artifact_path:
@@ -327,7 +326,15 @@ class AquaModelApp(AquaApp):
327
326
  # Update model artifact location in user's input model
328
327
  model.artifact_location = model_artifact_path
329
328
 
330
- artifact_list.append(model_artifact_path)
329
+ if not model_file_description:
330
+ raise AquaValueError(
331
+ f"Model '{display_name}' (ID: {model.model_id}) has no file description. "
332
+ "Please register the model first."
333
+ )
334
+
335
+ model_file_description_list.append(
336
+ ModelFileDescription(**model_file_description)
337
+ )
331
338
 
332
339
  # Validate deployment container consistency
333
340
  deployment_container = source_model.custom_metadata_list.get(
@@ -405,9 +412,16 @@ class AquaModelApp(AquaApp):
405
412
  .with_custom_metadata_list(model_custom_metadata)
406
413
  )
407
414
 
408
- # Attach artifacts
409
- for artifact in artifact_list:
410
- custom_model.add_artifact(uri=artifact)
415
+ # Update multi model file description to attach artifacts
416
+ custom_model.with_model_file_description(
417
+ json_dict=ModelFileDescription(
418
+ models=[
419
+ models
420
+ for model_file_description in model_file_description_list
421
+ for models in model_file_description.models
422
+ ]
423
+ ).model_dump(by_alias=True)
424
+ )
411
425
 
412
426
  # Finalize creation
413
427
  custom_model.create(model_by_reference=True)
@@ -704,6 +718,26 @@ class AquaModelApp(AquaApp):
704
718
  else:
705
719
  raise AquaRuntimeError("Only registered unverified models can be edited.")
706
720
 
721
+ def _extract_model_task(
722
+ self,
723
+ model: AquaMultiModelRef,
724
+ source_model: DataScienceModel,
725
+ ) -> None:
726
+ """In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user"""
727
+ # user does not supply model task, we extract from model metadata
728
+ if not model.model_task:
729
+ model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
730
+
731
+ task_tag = re.sub(r"-", "_", model.model_task).lower()
732
+ # re-visit logic when more model task types are supported
733
+ if task_tag in MultiModelSupportedTaskType:
734
+ model.model_task = task_tag
735
+ else:
736
+ raise AquaValueError(
737
+ f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. "
738
+ f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
739
+ )
740
+
707
741
  def _fetch_metric_from_metadata(
708
742
  self,
709
743
  custom_metadata_list: ModelCustomMetadata,
@@ -178,9 +178,7 @@ class AquaDeploymentApp(AquaApp):
178
178
  # validate instance shape availability in compartment
179
179
  available_shapes = [
180
180
  shape.name.lower()
181
- for shape in self.list_shapes(
182
- compartment_id=compartment_id
183
- )
181
+ for shape in self.list_shapes(compartment_id=compartment_id)
184
182
  ]
185
183
 
186
184
  if create_deployment_details.instance_shape.lower() not in available_shapes:
@@ -645,7 +643,11 @@ class AquaDeploymentApp(AquaApp):
645
643
  os_path = ObjectStorageDetails.from_path(artifact_path_prefix)
646
644
  artifact_path_prefix = os_path.filepath.rstrip("/")
647
645
 
648
- model_config.append({"params": params, "model_path": artifact_path_prefix})
646
+ # override by-default completion/ chat endpoint with other endpoint (embedding)
647
+ config_data = {"params": params, "model_path": artifact_path_prefix}
648
+ if model.model_task:
649
+ config_data["model_task"] = model.model_task
650
+ model_config.append(config_data)
649
651
  model_name_list.append(model.model_name)
650
652
 
651
653
  env_var.update({AQUA_MULTI_MODEL_CONFIG: json.dumps({"models": model_config})})
@@ -11,6 +11,7 @@ import sys
11
11
  import tempfile
12
12
  from typing import List, Union
13
13
 
14
+ import cloudpickle
14
15
  import fsspec
15
16
  import oracledb
16
17
  import pandas as pd
@@ -126,7 +127,26 @@ def load_data(data_spec, storage_options=None, **kwargs):
126
127
  return data
127
128
 
128
129
 
130
+ def _safe_write(fn, **kwargs):
131
+ try:
132
+ fn(**kwargs)
133
+ except Exception:
134
+ logger.warning(f'Failed to write file {kwargs.get("filename", "UNKNOWN")}')
135
+
136
+
129
137
  def write_data(data, filename, format, storage_options=None, index=False, **kwargs):
138
+ return _safe_write(
139
+ fn=_write_data,
140
+ data=data,
141
+ filename=filename,
142
+ format=format,
143
+ storage_options=storage_options,
144
+ index=index,
145
+ **kwargs,
146
+ )
147
+
148
+
149
+ def _write_data(data, filename, format, storage_options=None, index=False, **kwargs):
130
150
  disable_print()
131
151
  if not format:
132
152
  _, format = os.path.splitext(filename)
@@ -143,11 +163,24 @@ def write_data(data, filename, format, storage_options=None, index=False, **kwar
143
163
 
144
164
 
145
165
  def write_json(json_dict, filename, storage_options=None):
166
+ return _safe_write(
167
+ fn=_write_json,
168
+ json_dict=json_dict,
169
+ filename=filename,
170
+ storage_options=storage_options,
171
+ )
172
+
173
+
174
+ def _write_json(json_dict, filename, storage_options=None):
146
175
  with fsspec.open(filename, mode="w", **storage_options) as f:
147
176
  f.write(json.dumps(json_dict))
148
177
 
149
178
 
150
179
  def write_simple_json(data, path):
180
+ return _safe_write(fn=_write_simple_json, data=data, path=path)
181
+
182
+
183
+ def _write_simple_json(data, path):
151
184
  if ObjectStorageDetails.is_oci_path(path):
152
185
  storage_options = default_signer()
153
186
  else:
@@ -156,6 +189,60 @@ def write_simple_json(data, path):
156
189
  json.dump(data, f, indent=4)
157
190
 
158
191
 
192
+ def write_file(local_filename, remote_filename, storage_options, **kwargs):
193
+ return _safe_write(
194
+ fn=_write_file,
195
+ local_filename=local_filename,
196
+ remote_filename=remote_filename,
197
+ storage_options=storage_options,
198
+ **kwargs,
199
+ )
200
+
201
+
202
+ def _write_file(local_filename, remote_filename, storage_options, **kwargs):
203
+ with open(local_filename) as f1:
204
+ with fsspec.open(
205
+ remote_filename,
206
+ "w",
207
+ **storage_options,
208
+ ) as f2:
209
+ f2.write(f1.read())
210
+
211
+
212
+ def load_pkl(filepath):
213
+ return _safe_write(fn=_load_pkl, filepath=filepath)
214
+
215
+
216
+ def _load_pkl(filepath):
217
+ storage_options = {}
218
+ if ObjectStorageDetails.is_oci_path(filepath):
219
+ storage_options = default_signer()
220
+
221
+ with fsspec.open(filepath, "rb", **storage_options) as f:
222
+ return cloudpickle.load(f)
223
+ return None
224
+
225
+
226
+ def write_pkl(obj, filename, output_dir, storage_options):
227
+ return _safe_write(
228
+ fn=_write_pkl,
229
+ obj=obj,
230
+ filename=filename,
231
+ output_dir=output_dir,
232
+ storage_options=storage_options,
233
+ )
234
+
235
+
236
+ def _write_pkl(obj, filename, output_dir, storage_options):
237
+ pkl_path = os.path.join(output_dir, filename)
238
+ with fsspec.open(
239
+ pkl_path,
240
+ "wb",
241
+ **storage_options,
242
+ ) as f:
243
+ cloudpickle.dump(obj, f)
244
+
245
+
159
246
  def merge_category_columns(data, target_category_columns):
160
247
  result = data.apply(
161
248
  lambda x: "__".join([str(x[col]) for col in target_category_columns]), axis=1
@@ -290,4 +377,8 @@ def disable_print():
290
377
 
291
378
  # Restore
292
379
  def enable_print():
380
+ try:
381
+ sys.stdout.close()
382
+ except Exception:
383
+ pass
293
384
  sys.stdout = sys.__stdout__
@@ -38,6 +38,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
38
38
  super().__init__(config, datasets)
39
39
  self.global_explanation = {}
40
40
  self.local_explanation = {}
41
+ self.explainability_kwargs = {}
41
42
 
42
43
  def set_kwargs(self):
43
44
  model_kwargs_cleaned = self.spec.model_kwargs
@@ -54,6 +55,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
54
55
  self.spec.preprocessing.enabled
55
56
  or model_kwargs_cleaned.get("preprocessing", True)
56
57
  )
58
+ sample_ratio = model_kwargs_cleaned.pop("sample_to_feature_ratio", None)
59
+ if sample_ratio is not None:
60
+ self.explainability_kwargs = {"sample_to_feature_ratio": sample_ratio}
57
61
  return model_kwargs_cleaned, time_budget
58
62
 
59
63
  def preprocess(self, data, series_id): # TODO: re-use self.le for explanations
@@ -445,6 +449,7 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
445
449
  else None,
446
450
  pd.DataFrame(data_i[self.spec.target_column]),
447
451
  task="forecasting",
452
+ **self.explainability_kwargs,
448
453
  )
449
454
 
450
455
  # Generate explanations for the forecast
@@ -518,7 +523,9 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
518
523
  model_params = model.selected_model_params_
519
524
  if len(trials) > 0:
520
525
  score_col = [col for col in trials.columns if "Score" in col][0]
521
- validation_score = trials[trials.Hyperparameters == model_params][score_col].iloc[0]
526
+ validation_score = trials[trials.Hyperparameters == model_params][
527
+ score_col
528
+ ].iloc[0]
522
529
  else:
523
530
  validation_score = 0
524
531
  return -1 * validation_score
@@ -531,8 +538,12 @@ class AutoMLXOperatorModel(ForecastOperatorBaseModel):
531
538
  for s_id in self.forecast_output.list_series_ids():
532
539
  try:
533
540
  metrics = {self.spec.metric.upper(): self.models[s_id]["score"]}
534
- metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=[s_id])
535
- logger.warning("AutoMLX failed to generate training metrics. Recovering validation loss instead")
541
+ metrics_df = pd.DataFrame.from_dict(
542
+ metrics, orient="index", columns=[s_id]
543
+ )
544
+ logger.warning(
545
+ "AutoMLX failed to generate training metrics. Recovering validation loss instead"
546
+ )
536
547
  total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
537
548
  except Exception as e:
538
549
  logger.debug(
@@ -11,7 +11,6 @@ import traceback
11
11
  from abc import ABC, abstractmethod
12
12
  from typing import Tuple
13
13
 
14
- import fsspec
15
14
  import numpy as np
16
15
  import pandas as pd
17
16
  import report_creator as rc
@@ -25,10 +24,13 @@ from ads.opctl.operator.lowcode.common.utils import (
25
24
  disable_print,
26
25
  enable_print,
27
26
  human_time_friendly,
27
+ load_pkl,
28
28
  merged_category_column_name,
29
29
  seconds_to_datetime,
30
30
  write_data,
31
+ write_file,
31
32
  write_json,
33
+ write_pkl,
32
34
  )
33
35
  from ads.opctl.operator.lowcode.forecast.utils import (
34
36
  _build_metrics_df,
@@ -38,8 +40,6 @@ from ads.opctl.operator.lowcode.forecast.utils import (
38
40
  evaluate_train_metrics,
39
41
  get_auto_select_plot,
40
42
  get_forecast_plots,
41
- load_pkl,
42
- write_pkl,
43
43
  )
44
44
 
45
45
  from ..const import (
@@ -493,13 +493,11 @@ class ForecastOperatorBaseModel(ABC):
493
493
  enable_print()
494
494
 
495
495
  report_path = os.path.join(unique_output_dir, self.spec.report_filename)
496
- with open(report_local_path) as f1:
497
- with fsspec.open(
498
- report_path,
499
- "w",
500
- **storage_options,
501
- ) as f2:
502
- f2.write(f1.read())
496
+ write_file(
497
+ local_filename=report_local_path,
498
+ remote_filename=report_path,
499
+ storage_options=storage_options,
500
+ )
503
501
 
504
502
  # forecast csv report
505
503
  # todo: add test data into forecast.csv
@@ -573,9 +571,16 @@ class ForecastOperatorBaseModel(ABC):
573
571
  if self.spec.generate_explanations:
574
572
  try:
575
573
  if not self.formatted_global_explanation.empty:
574
+ # Round to 4 decimal places before writing
575
+ global_expl_rounded = self.formatted_global_explanation.copy()
576
+ global_expl_rounded = global_expl_rounded.apply(
577
+ lambda col: np.round(col, 4)
578
+ if np.issubdtype(col.dtype, np.number)
579
+ else col
580
+ )
576
581
  if self.spec.generate_explanation_files:
577
582
  write_data(
578
- data=self.formatted_global_explanation,
583
+ data=global_expl_rounded,
579
584
  filename=os.path.join(
580
585
  unique_output_dir, self.spec.global_explanation_filename
581
586
  ),
@@ -583,16 +588,23 @@ class ForecastOperatorBaseModel(ABC):
583
588
  storage_options=storage_options,
584
589
  index=True,
585
590
  )
586
- results.set_global_explanations(self.formatted_global_explanation)
591
+ results.set_global_explanations(global_expl_rounded)
587
592
  else:
588
593
  logger.warning(
589
594
  f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
590
595
  )
591
596
 
592
597
  if not self.formatted_local_explanation.empty:
598
+ # Round to 4 decimal places before writing
599
+ local_expl_rounded = self.formatted_local_explanation.copy()
600
+ local_expl_rounded = local_expl_rounded.apply(
601
+ lambda col: np.round(col, 4)
602
+ if np.issubdtype(col.dtype, np.number)
603
+ else col
604
+ )
593
605
  if self.spec.generate_explanation_files:
594
606
  write_data(
595
- data=self.formatted_local_explanation,
607
+ data=local_expl_rounded,
596
608
  filename=os.path.join(
597
609
  unique_output_dir, self.spec.local_explanation_filename
598
610
  ),
@@ -600,7 +612,7 @@ class ForecastOperatorBaseModel(ABC):
600
612
  storage_options=storage_options,
601
613
  index=True,
602
614
  )
603
- results.set_local_explanations(self.formatted_local_explanation)
615
+ results.set_local_explanations(local_expl_rounded)
604
616
  else:
605
617
  logger.warning(
606
618
  f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
@@ -19,12 +19,10 @@ from ads.opctl import logger
19
19
  from ads.opctl.operator.lowcode.common.utils import (
20
20
  disable_print,
21
21
  enable_print,
22
- )
23
- from ads.opctl.operator.lowcode.forecast.utils import (
24
- _select_plot_list,
25
22
  load_pkl,
26
23
  write_pkl,
27
24
  )
25
+ from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list
28
26
 
29
27
  from ..const import DEFAULT_TRIALS, SupportedModels
30
28
  from ..operator_config import ForecastOperatorConfig
@@ -159,20 +157,18 @@ class NeuralProphetOperatorModel(ForecastOperatorBaseModel):
159
157
  upper_bound=self.get_horizon(forecast[upper_bound_col_name]).values,
160
158
  lower_bound=self.get_horizon(forecast[lower_bound_col_name]).values,
161
159
  )
162
- core_columns = set(forecast.columns) - set(
163
- [
164
- "y",
165
- "yhat1",
166
- upper_bound_col_name,
167
- lower_bound_col_name,
168
- "future_regressors_additive",
169
- "future_regressors_multiplicative",
170
- ]
171
- )
160
+ core_columns = set(forecast.columns) - {
161
+ "y",
162
+ "yhat1",
163
+ upper_bound_col_name,
164
+ lower_bound_col_name,
165
+ "future_regressors_additive",
166
+ "future_regressors_multiplicative",
167
+ }
172
168
  exog_variables = set(
173
169
  filter(lambda x: x.startswith("future_regressor_"), list(core_columns))
174
170
  )
175
- combine_terms = list(core_columns - exog_variables - set(["ds"]))
171
+ combine_terms = list(core_columns - exog_variables - {"ds"})
176
172
  temp_df = (
177
173
  forecast[list(core_columns)]
178
174
  .rename({"ds": "Date"}, axis=1)
@@ -1,14 +1,12 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2023, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
  import logging
7
7
  import os
8
8
  from typing import Set
9
9
 
10
- import cloudpickle
11
- import fsspec
12
10
  import numpy as np
13
11
  import pandas as pd
14
12
  import report_creator as rc
@@ -21,7 +19,6 @@ from sklearn.metrics import (
21
19
  r2_score,
22
20
  )
23
21
 
24
- from ads.common.object_storage_details import ObjectStorageDetails
25
22
  from ads.dataset.label_encoder import DataFrameLabelEncoder
26
23
  from ads.opctl import logger
27
24
  from ads.opctl.operator.lowcode.forecast.const import ForecastOutputColumns
@@ -170,26 +167,6 @@ def _build_metrics_per_horizon(
170
167
  return metrics_df
171
168
 
172
169
 
173
- def load_pkl(filepath):
174
- storage_options = {}
175
- if ObjectStorageDetails.is_oci_path(filepath):
176
- storage_options = default_signer()
177
-
178
- with fsspec.open(filepath, "rb", **storage_options) as f:
179
- return cloudpickle.load(f)
180
- return None
181
-
182
-
183
- def write_pkl(obj, filename, output_dir, storage_options):
184
- pkl_path = os.path.join(output_dir, filename)
185
- with fsspec.open(
186
- pkl_path,
187
- "wb",
188
- **storage_options,
189
- ) as f:
190
- cloudpickle.dump(obj, f)
191
-
192
-
193
170
  def _build_metrics_df(y_true, y_pred, series_id):
194
171
  if len(y_true) == 0 or len(y_pred) == 0:
195
172
  return pd.DataFrame()
@@ -251,7 +228,10 @@ def evaluate_train_metrics(output):
251
228
 
252
229
 
253
230
  def _select_plot_list(fn, series_ids, target_category_column):
254
- blocks = [rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None) for s_id in series_ids]
231
+ blocks = [
232
+ rc.Widget(fn(s_id=s_id), label=s_id if target_category_column else None)
233
+ for s_id in series_ids
234
+ ]
255
235
  return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
256
236
 
257
237
 
@@ -264,8 +244,10 @@ def get_auto_select_plot(backtest_results):
264
244
  back_test_csv_columns = backtest_results.columns.tolist()
265
245
  back_test_column = "backtest"
266
246
  metric_column = "metric"
267
- models = [x for x in back_test_csv_columns if x not in [back_test_column, metric_column]]
268
- for i, column in enumerate(models):
247
+ models = [
248
+ x for x in back_test_csv_columns if x not in [back_test_column, metric_column]
249
+ ]
250
+ for column in models:
269
251
  fig.add_trace(
270
252
  go.Scatter(
271
253
  x=backtest_results[back_test_column],
@@ -283,7 +265,7 @@ def get_forecast_plots(
283
265
  horizon,
284
266
  test_data=None,
285
267
  ci_interval_width=0.95,
286
- target_category_column=None
268
+ target_category_column=None,
287
269
  ):
288
270
  def plot_forecast_plotly(s_id):
289
271
  fig = go.Figure()
@@ -380,7 +362,9 @@ def get_forecast_plots(
380
362
  )
381
363
  return fig
382
364
 
383
- return _select_plot_list(plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column)
365
+ return _select_plot_list(
366
+ plot_forecast_plotly, forecast_output.list_series_ids(), target_category_column
367
+ )
384
368
 
385
369
 
386
370
  def convert_target(target: str, target_col: str):
@@ -1,14 +1,14 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: oracle_ads
3
- Version: 2.13.8
3
+ Version: 2.13.9rc1
4
4
  Summary: Oracle Accelerated Data Science SDK
5
- Keywords: Oracle Cloud Infrastructure,OCI,Machine Learning,ML,Artificial Intelligence,AI,Data Science,Cloud,Oracle
5
+ Keywords: Oracle Cloud Infrastructure,OCI,Machine Learning,ML,Artificial Intelligence,AI,Data Science,Cloud,Oracle,GenAI,Generative AI,Forecast,Anomaly,Document Understanding
6
6
  Author: Oracle Data Science
7
7
  Requires-Python: >=3.8
8
8
  Description-Content-Type: text/markdown
9
+ License-Expression: UPL-1.0
9
10
  Classifier: Development Status :: 5 - Production/Stable
10
11
  Classifier: Intended Audience :: Developers
11
- Classifier: License :: OSI Approved :: Universal Permissive License (UPL)
12
12
  Classifier: Operating System :: OS Independent
13
13
  Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
@@ -4,7 +4,7 @@ ads/config.py,sha256=yrCvWEEYcMwWkk9_6IJJZnxbvrOVzsQNMBrCJVafYU8,8106
4
4
  ads/aqua/__init__.py,sha256=7DjwtmZaX-_atIkmZu6XQKHqJUEeemJGR2TlxzMHSXs,973
5
5
  ads/aqua/app.py,sha256=KesfIyVm3T8mj3ugsdVSp05b9RwQAEVw7QN1UB4o4qU,18397
6
6
  ads/aqua/cli.py,sha256=8S0JnhWY9IBZjMyB-5r4I-2nl-WK6yw1iirPsAXICF0,3358
7
- ads/aqua/constants.py,sha256=E_7eaHTMkKjY1VMe8os8xW337giIjESUYvMAnbN9bKw,4981
7
+ ads/aqua/constants.py,sha256=dUl02j5XTAG6sL7XJ9HS5fT0Z869WceRLFIbBwCzmtw,5081
8
8
  ads/aqua/data.py,sha256=HfxLfKiNiPJecMQy0JAztUsT3IdZilHHHOrCJnjZMc4,408
9
9
  ads/aqua/ui.py,sha256=AyX1vFW9f6hoyKN55M6s4iKBLHsOHC41hwRjDfD4NlI,20191
10
10
  ads/aqua/client/__init__.py,sha256=-46EcKQjnWEXxTt85bQzXjA5xsfoBXIGm_syKFlVL1c,178
@@ -12,7 +12,7 @@ ads/aqua/client/client.py,sha256=zlscNhFZVgGnkJ-aj5iZ5v5FedOzpQc4RJDxGPl9VvQ,313
12
12
  ads/aqua/client/openai_client.py,sha256=Gi8nSrtPAUOjxRNu-6UUAqtxWyQIQ5CAvatnm7XfnaM,12501
13
13
  ads/aqua/common/__init__.py,sha256=rZrmh1nho40OCeabXCNWtze-mXi-PGKetcZdxZSn3_0,204
14
14
  ads/aqua/common/decorator.py,sha256=JEN6Cy4DYgQbmIR3ShCjTuBMCnilDxq7jkYMJse1rcM,4112
15
- ads/aqua/common/entities.py,sha256=kLUJu77Sg97VrHb76PvFAAaSWEUum9nYTwzMtOnUo50,8922
15
+ ads/aqua/common/entities.py,sha256=2_Sv07SvekPmU77DfQIycdv0UJOaT6TZM5WyNJyq7GM,9188
16
16
  ads/aqua/common/enums.py,sha256=rTZDOQzTfcgwEl7gjVY3_JotHXkz7wB_edEIB0i6AeQ,3739
17
17
  ads/aqua/common/errors.py,sha256=QONm-2jKBg8AjgOKXm6x-arAV1KIW9pdhfNN1Ys21Wo,3044
18
18
  ads/aqua/common/utils.py,sha256=z93NqufjGzmEpsd7VmSvIpFUawcaoLjBFPSiBCjq2Wk,42001
@@ -52,15 +52,15 @@ ads/aqua/extension/models/ws_models.py,sha256=IgAwu324zlT0XII2nFWQUTeEzqvbFch_9K
52
52
  ads/aqua/finetuning/__init__.py,sha256=vwYT5PluMR0mDQwVIavn_8Icms7LmvfV_FOrJ8fJx8I,296
53
53
  ads/aqua/finetuning/constants.py,sha256=Fx-8LMyF9ZbV9zo5LUYgCv9VniV7djGnM2iW7js2ILE,844
54
54
  ads/aqua/finetuning/entities.py,sha256=ax6tpqrzuF54YNdwJNRSpzhAnkvOeXdnJ18EA-GfIlw,6885
55
- ads/aqua/finetuning/finetuning.py,sha256=SizHmPN1kOlzriQ2GHUvyhL9LxEmntoBFusHhYAz6SI,30220
55
+ ads/aqua/finetuning/finetuning.py,sha256=11DJEEZPa0yu8k0wZvp9IuYEU7IdOd_ZQFUigTqvG0k,31094
56
56
  ads/aqua/model/__init__.py,sha256=j2iylvERdANxgrEDp7b_mLcKMz1CF5Go0qgYCiMwdos,278
57
57
  ads/aqua/model/constants.py,sha256=oOAb4ulsdWBtokCE5SPX7wg8X8SaisLPayua58zhWfY,1856
58
- ads/aqua/model/entities.py,sha256=8P9BEXCroruJHA1RhL66NdmScL-Lql1_7SjnFYk273Y,10089
59
- ads/aqua/model/enums.py,sha256=iJi-AZRh7KR_HK5HUwTkgnTOGVna2Ai5WEzqCjk7Y3s,1079
60
- ads/aqua/model/model.py,sha256=i1cRCdGV1UWyLNwfkikHF0oPhF682ZB-uKqgvJJ7860,86864
58
+ ads/aqua/model/entities.py,sha256=JiKB8SnaUxerRMlwrgpyfQLRuTOB8I14J-Rg5RFPwqw,12660
59
+ ads/aqua/model/enums.py,sha256=bN8GKmgRl40PQrTmd1r-Pqr9VXTIV8gic5-3SAGNnwg,1152
60
+ ads/aqua/model/model.py,sha256=AjsM7o5Dcas4G5imdCQ1VX2Y5bCooMZvDQBKQX8KTUA,88217
61
61
  ads/aqua/modeldeployment/__init__.py,sha256=RJCfU1yazv3hVWi5rS08QVLTpTwZLnlC8wU8diwFjnM,391
62
62
  ads/aqua/modeldeployment/constants.py,sha256=lJF77zwxmlECljDYjwFAMprAUR_zctZHmawiP-4alLg,296
63
- ads/aqua/modeldeployment/deployment.py,sha256=8HWFkc50_DdTM4MEPVzUXYOxvAmXeEHplqsPzK-II8k,56071
63
+ ads/aqua/modeldeployment/deployment.py,sha256=hWpjomp9UyRRsUDAa9VN5ezpqNOpLB7lOZJX8mYwDjI,56265
64
64
  ads/aqua/modeldeployment/entities.py,sha256=qwNH-8eHv-C2QPMITGQkb6haaJRvZ5c0i1H0Aoxeiu4,27100
65
65
  ads/aqua/modeldeployment/inference.py,sha256=rjTF-AM_rHLzL5HCxdLRTrsaSMdB-SzFYUp9dIy5ejw,2109
66
66
  ads/aqua/modeldeployment/utils.py,sha256=Aky4WZ5E564JVZ96X9RYJz_KlB_cAHGzV6mihtd3HV8,22009
@@ -692,7 +692,7 @@ ads/opctl/operator/lowcode/common/const.py,sha256=1dUhgup4L_U0s6BSYmgLPpZAe6xqfS
692
692
  ads/opctl/operator/lowcode/common/data.py,sha256=_0UbW-A0kVQjNOO2aeZoRiebgmKqDqcprPPjZ6KDWdk,4188
693
693
  ads/opctl/operator/lowcode/common/errors.py,sha256=LvQ_Qzh6cqD6uP91DMFFVXPrcc3010EE8LfBH-CH0ho,1534
694
694
  ads/opctl/operator/lowcode/common/transformations.py,sha256=n-Yac9WtI9GLEc5sDKSq75-2q0j59bR_pxlV5EAmkO0,11048
695
- ads/opctl/operator/lowcode/common/utils.py,sha256=z8NqmBk1ScU6R1cTBna9drJxkoD-UGiPqvN9HUw2VR8,9941
695
+ ads/opctl/operator/lowcode/common/utils.py,sha256=9wVM7lNZX_RRO0MVxZrN0FkWBPZweKq_D3_0ON7wjnM,12159
696
696
  ads/opctl/operator/lowcode/feature_store_marketplace/MLoperator,sha256=JO5ulr32WsFnbpk1KN97h8-D70jcFt1kRQ08UMkP4rU,346
697
697
  ads/opctl/operator/lowcode/feature_store_marketplace/README.md,sha256=fN9ROzOPdEZdRgSP_uYvAmD5bD983NC7Irfe_D-mvrw,1356
698
698
  ads/opctl/operator/lowcode/feature_store_marketplace/__init__.py,sha256=rZrmh1nho40OCeabXCNWtze-mXi-PGKetcZdxZSn3_0,204
@@ -718,16 +718,16 @@ ads/opctl/operator/lowcode/forecast/errors.py,sha256=X9zuV2Lqb5N9FuBHHshOFYyhvng
718
718
  ads/opctl/operator/lowcode/forecast/model_evaluator.py,sha256=crtCQ4KIWCueOf2zU-AKD_i3h_cJA_-qAGakdgBazVI,10257
719
719
  ads/opctl/operator/lowcode/forecast/operator_config.py,sha256=3pJzgbSqgPzE7vkce6KkthPQJEgWRRdDOAf1l6aSZpg,8318
720
720
  ads/opctl/operator/lowcode/forecast/schema.yaml,sha256=RoNwjg5jxXMbljtregMkV_rJbax8ir7zdJltC5YfYM8,12438
721
- ads/opctl/operator/lowcode/forecast/utils.py,sha256=0ssrXBAEL5hjQX4avLPkSwFp3sKE8QV5M3K5InqvzYg,14137
721
+ ads/opctl/operator/lowcode/forecast/utils.py,sha256=00prJFK1F3esHlPsPp1WSJ3YoT0NK95f3cH2qNH8AJQ,13578
722
722
  ads/opctl/operator/lowcode/forecast/model/__init__.py,sha256=sAqmLhogrLXb3xI7dPOj9HmSkpTnLh9wkzysuGd8AXk,204
723
723
  ads/opctl/operator/lowcode/forecast/model/arima.py,sha256=PvHoTdDr6RIC4I-YLzed91td6Pq6uxbgluEdu_h0e3c,11766
724
- ads/opctl/operator/lowcode/forecast/model/automlx.py,sha256=4XwS60f7Cs9-oexAn_v0hiWHmrw4jBY_o-_VLzuOd-4,22891
724
+ ads/opctl/operator/lowcode/forecast/model/automlx.py,sha256=XnF3zku8RWrqwBJet5yfx_f6G5nkJ_2e-TzdVJQt7yE,23292
725
725
  ads/opctl/operator/lowcode/forecast/model/autots.py,sha256=UThBBGsEiC3WLSn-BPAuNWT_ZFa3bYMu52keB0vvSt8,13137
726
- ads/opctl/operator/lowcode/forecast/model/base_model.py,sha256=s9WwPpo61YY7teAcmL2MK7cl1GGYAKZu7IkxoReD1I0,35969
726
+ ads/opctl/operator/lowcode/forecast/model/base_model.py,sha256=ENrizwJwhHbJa8DPMqCDEUKqwQaGfR5-fYdTxreQrHU,36613
727
727
  ads/opctl/operator/lowcode/forecast/model/factory.py,sha256=5a9A3ql-bU412BiTB20ob6OxQlkdk8z_tGONMwDXT1k,3900
728
728
  ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py,sha256=2NsWE2WtD_O7uAXw42_3tmG3vb5lk3mdnzCZTph4hao,18903
729
729
  ads/opctl/operator/lowcode/forecast/model/ml_forecast.py,sha256=t5x6EBxOd7XwfT3FGdt-n9gscxaHMm3R2A4Evvxbj38,9646
730
- ads/opctl/operator/lowcode/forecast/model/neuralprophet.py,sha256=60nfNGxjRDsD09Sg7s1YG8G8Qxfcyw0_2rW2PcNy1-c,20021
730
+ ads/opctl/operator/lowcode/forecast/model/neuralprophet.py,sha256=-AS3PPd8Fqn1uaMybJwTnFbmIfUxNPDlgYjGtjy9-E8,19944
731
731
  ads/opctl/operator/lowcode/forecast/model/prophet.py,sha256=jb8bshJf5lDdGJkNH-2SrwN4tdHImP7iD9I8KS4EmZU,17321
732
732
  ads/opctl/operator/lowcode/forecast/whatifserve/__init__.py,sha256=JNDDjLrNorKXMHUuXMifqXea3eheST-lnrcwCl2bWrk,242
733
733
  ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py,sha256=w42anuqAScEQ0vBG3vW4LVLNq1bPdpAWGQEmNhMwZ08,12052
@@ -851,8 +851,8 @@ ads/type_discovery/unknown_detector.py,sha256=yZuYQReO7PUyoWZE7onhhtYaOg6088wf1y
851
851
  ads/type_discovery/zipcode_detector.py,sha256=3AlETg_ZF4FT0u914WXvTT3F3Z6Vf51WiIt34yQMRbw,1421
852
852
  ads/vault/__init__.py,sha256=x9tMdDAOdF5iDHk9u2di_K-ze5Nq068x25EWOBoWwqY,245
853
853
  ads/vault/vault.py,sha256=hFBkpYE-Hfmzu1L0sQwUfYcGxpWmgG18JPndRl0NOXI,8624
854
- oracle_ads-2.13.8.dist-info/entry_points.txt,sha256=9VFnjpQCsMORA4rVkvN8eH6D3uHjtegb9T911t8cqV0,35
855
- oracle_ads-2.13.8.dist-info/licenses/LICENSE.txt,sha256=zoGmbfD1IdRKx834U0IzfFFFo5KoFK71TND3K9xqYqo,1845
856
- oracle_ads-2.13.8.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
857
- oracle_ads-2.13.8.dist-info/METADATA,sha256=LfiYG2aQlavlrsikEbkE1C6CN9ytsyDU32oF0oHClg4,16639
858
- oracle_ads-2.13.8.dist-info/RECORD,,
854
+ oracle_ads-2.13.9rc1.dist-info/entry_points.txt,sha256=9VFnjpQCsMORA4rVkvN8eH6D3uHjtegb9T911t8cqV0,35
855
+ oracle_ads-2.13.9rc1.dist-info/licenses/LICENSE.txt,sha256=zoGmbfD1IdRKx834U0IzfFFFo5KoFK71TND3K9xqYqo,1845
856
+ oracle_ads-2.13.9rc1.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
857
+ oracle_ads-2.13.9rc1.dist-info/METADATA,sha256=322tybBGPwpm9eAuQEcS9dQ_7qwMye0YfNFu75RxQhs,16656
858
+ oracle_ads-2.13.9rc1.dist-info/RECORD,,