oracle-ads 2.11.16__py3-none-any.whl → 2.11.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ads/aqua/model/model.py CHANGED
@@ -2,23 +2,31 @@
2
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
  import os
5
+ import pathlib
5
6
  from datetime import datetime, timedelta
6
7
  from threading import Lock
7
- from typing import Dict, Optional, Set, Union
8
+ from typing import Dict, List, Optional, Set, Union
8
9
 
10
+ import oci
9
11
  from cachetools import TTLCache
12
+ from huggingface_hub import snapshot_download
13
+ from oci.data_science.models import JobRun, Model
10
14
 
11
- from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
15
+ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
12
16
  from ads.aqua.app import AquaApp
13
17
  from ads.aqua.common.enums import Tags
14
18
  from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
15
19
  from ads.aqua.common.utils import (
20
+ LifecycleStatus,
21
+ _build_resource_identifier,
16
22
  copy_model_config,
17
23
  create_word_icon,
18
24
  get_artifact_path,
25
+ get_hf_model_info,
19
26
  list_os_files_with_extension,
20
27
  load_config,
21
28
  read_file,
29
+ upload_folder,
22
30
  )
23
31
  from ads.aqua.constants import (
24
32
  AQUA_MODEL_ARTIFACT_CONFIG,
@@ -38,8 +46,22 @@ from ads.aqua.constants import (
38
46
  VALIDATION_METRICS,
39
47
  VALIDATION_METRICS_FINAL,
40
48
  )
41
- from ads.aqua.model.constants import *
42
- from ads.aqua.model.entities import *
49
+ from ads.aqua.model.constants import (
50
+ FineTuningCustomMetadata,
51
+ FineTuningMetricCategories,
52
+ ModelCustomMetadataFields,
53
+ ModelType,
54
+ )
55
+ from ads.aqua.model.entities import (
56
+ AquaFineTuneModel,
57
+ AquaFineTuningMetric,
58
+ AquaModel,
59
+ AquaModelLicense,
60
+ AquaModelSummary,
61
+ ImportModelDetails,
62
+ ModelFormat,
63
+ ModelValidationResult,
64
+ )
43
65
  from ads.aqua.ui import AquaContainerConfig, AquaContainerConfigItem
44
66
  from ads.common.auth import default_signer
45
67
  from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -55,7 +77,6 @@ from ads.config import (
55
77
  from ads.model import DataScienceModel
56
78
  from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
57
79
  from ads.telemetry import telemetry
58
- from oci.data_science.models import JobRun, Model
59
80
 
60
81
 
61
82
  class AquaModelApp(AquaApp):
@@ -176,11 +197,9 @@ class AquaModelApp(AquaApp):
176
197
  if not self._if_show(ds_model):
177
198
  raise AquaRuntimeError(f"Target model `{ds_model.id} `is not Aqua model.")
178
199
 
179
- is_fine_tuned_model = (
180
- True
181
- if ds_model.freeform_tags
200
+ is_fine_tuned_model = bool(
201
+ ds_model.freeform_tags
182
202
  and ds_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
183
- else False
184
203
  )
185
204
 
186
205
  # todo: consolidate this logic in utils for model and deployment use
@@ -218,6 +237,10 @@ class AquaModelApp(AquaApp):
218
237
  ModelCustomMetadataFields.FINETUNE_CONTAINER,
219
238
  ModelCustomMetadataItem(key=ModelCustomMetadataFields.FINETUNE_CONTAINER),
220
239
  ).value
240
+ artifact_location = ds_model.custom_metadata_list.get(
241
+ ModelCustomMetadataFields.ARTIFACT_LOCATION,
242
+ ModelCustomMetadataItem(key=ModelCustomMetadataFields.ARTIFACT_LOCATION),
243
+ ).value
221
244
 
222
245
  aqua_model_attributes = dict(
223
246
  **self._process_model(ds_model, self.region),
@@ -226,6 +249,7 @@ class AquaModelApp(AquaApp):
226
249
  inference_container=inference_container,
227
250
  finetuning_container=finetuning_container,
228
251
  evaluation_container=evaluation_container,
252
+ artifact_location=artifact_location,
229
253
  )
230
254
 
231
255
  if not is_fine_tuned_model:
@@ -260,7 +284,7 @@ class AquaModelApp(AquaApp):
260
284
  logger.debug(str(e))
261
285
  source_name = UNKNOWN
262
286
 
263
- source_identifier = utils._build_resource_identifier(
287
+ source_identifier = _build_resource_identifier(
264
288
  id=source_id,
265
289
  name=source_name,
266
290
  region=self.region,
@@ -278,7 +302,7 @@ class AquaModelApp(AquaApp):
278
302
  )
279
303
  )
280
304
  # TODO: change the argument's name.
281
- lifecycle_state = utils.LifecycleStatus.get_status(
305
+ lifecycle_state = LifecycleStatus.get_status(
282
306
  evaluation_status=ds_model.lifecycle_state,
283
307
  job_run_status=job_run_status,
284
308
  )
@@ -321,7 +345,7 @@ class AquaModelApp(AquaApp):
321
345
  category=category,
322
346
  scores=scores,
323
347
  )
324
- except:
348
+ except Exception:
325
349
  return AquaFineTuningMetric(name=metric_name, category=category, scores=[])
326
350
 
327
351
  def _build_ft_metrics(
@@ -410,9 +434,7 @@ class AquaModelApp(AquaApp):
410
434
  )
411
435
 
412
436
  description = ""
413
- if isinstance(model, DataScienceModel) or isinstance(
414
- model, oci.data_science.models.model.Model
415
- ):
437
+ if isinstance(model, (DataScienceModel, oci.data_science.models.model.Model)):
416
438
  description = model.description
417
439
  elif isinstance(model, oci.resource_search.models.ResourceSummary):
418
440
  description = model.additional_details.get("description")
@@ -438,16 +460,26 @@ class AquaModelApp(AquaApp):
438
460
  == READY_TO_IMPORT_STATUS
439
461
  )
440
462
 
463
+ try:
464
+ model_file = model.custom_metadata_list.get(AQUA_MODEL_ARTIFACT_FILE).value
465
+ except Exception:
466
+ model_file = UNKNOWN
467
+
441
468
  inference_containers = AquaContainerConfig.from_container_index_json().inference
442
469
 
443
- model_format = ModelFormat[
444
- freeform_tags.get(Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS.value).upper()
470
+ model_formats_str = freeform_tags.get(
471
+ Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS.value
472
+ ).upper()
473
+ model_formats = [
474
+ ModelFormat[model_format] for model_format in model_formats_str.split(",")
445
475
  ]
476
+
446
477
  supported_platform: Set[AquaContainerConfigItem.Platform] = set()
447
478
 
448
479
  for container in inference_containers.values():
449
- if model_format in container.model_formats:
450
- supported_platform.update(container.platforms)
480
+ for model_format in model_formats:
481
+ if model_format in container.model_formats:
482
+ supported_platform.update(container.platforms)
451
483
 
452
484
  nvidia_gpu_supported = (
453
485
  AquaContainerConfigItem.Platform.NVIDIA_GPU in supported_platform
@@ -456,26 +488,27 @@ class AquaModelApp(AquaApp):
456
488
  AquaContainerConfigItem.Platform.ARM_CPU in supported_platform
457
489
  )
458
490
 
459
- return dict(
460
- compartment_id=model.compartment_id,
461
- icon=icon or UNKNOWN,
462
- id=model_id,
463
- license=freeform_tags.get(Tags.LICENSE, UNKNOWN),
464
- name=model.display_name,
465
- organization=freeform_tags.get(Tags.ORGANIZATION, UNKNOWN),
466
- task=freeform_tags.get(Tags.TASK, UNKNOWN),
467
- time_created=str(model.time_created),
468
- is_fine_tuned_model=is_fine_tuned_model,
469
- tags=tags,
470
- console_link=console_link,
471
- search_text=search_text,
472
- ready_to_deploy=ready_to_deploy,
473
- ready_to_finetune=ready_to_finetune,
474
- ready_to_import=ready_to_import,
475
- nvidia_gpu_supported=nvidia_gpu_supported,
476
- arm_cpu_supported=arm_cpu_supported,
477
- model_format=model_format,
478
- )
491
+ return {
492
+ "compartment_id": model.compartment_id,
493
+ "icon": icon or UNKNOWN,
494
+ "id": model_id,
495
+ "license": freeform_tags.get(Tags.LICENSE, UNKNOWN),
496
+ "name": model.display_name,
497
+ "organization": freeform_tags.get(Tags.ORGANIZATION, UNKNOWN),
498
+ "task": freeform_tags.get(Tags.TASK, UNKNOWN),
499
+ "time_created": str(model.time_created),
500
+ "is_fine_tuned_model": is_fine_tuned_model,
501
+ "tags": tags,
502
+ "console_link": console_link,
503
+ "search_text": search_text,
504
+ "ready_to_deploy": ready_to_deploy,
505
+ "ready_to_finetune": ready_to_finetune,
506
+ "ready_to_import": ready_to_import,
507
+ "nvidia_gpu_supported": nvidia_gpu_supported,
508
+ "arm_cpu_supported": arm_cpu_supported,
509
+ "model_file": model_file,
510
+ "model_formats": model_formats,
511
+ }
479
512
 
480
513
  @telemetry(entry_point="plugin=model&action=list", name="aqua")
481
514
  def list(
@@ -524,7 +557,7 @@ class AquaModelApp(AquaApp):
524
557
  category="aqua/service/model", action="list"
525
558
  )
526
559
 
527
- if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache.keys():
560
+ if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache:
528
561
  logger.info(
529
562
  f"Returning service models list in {ODSC_MODEL_COMPARTMENT_OCID} from cache."
530
563
  )
@@ -576,7 +609,7 @@ class AquaModelApp(AquaApp):
576
609
  res = {}
577
610
  logger.info("Clearing _service_models_cache")
578
611
  with self._cache_lock:
579
- if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache.keys():
612
+ if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache:
580
613
  self._service_models_cache.pop(key=ODSC_MODEL_COMPARTMENT_OCID)
581
614
  res = {
582
615
  "key": {
@@ -593,10 +626,9 @@ class AquaModelApp(AquaApp):
593
626
  inference_container: str,
594
627
  finetuning_container: str,
595
628
  verified_model: DataScienceModel,
596
- model_format: ModelFormat,
629
+ validation_result: ModelValidationResult,
597
630
  compartment_id: Optional[str],
598
631
  project_id: Optional[str],
599
- model_file: Optional[str],
600
632
  ) -> DataScienceModel:
601
633
  """Create model by reference from the object storage path
602
634
 
@@ -625,11 +657,20 @@ class AquaModelApp(AquaApp):
625
657
  }
626
658
  )
627
659
  tags.update({Tags.BASE_MODEL_CUSTOM: "true"})
628
- tags.update({Tags.MODEL_FORMAT: model_format.value})
660
+
661
+ if validation_result and validation_result.model_formats:
662
+ tags.update(
663
+ {
664
+ Tags.MODEL_FORMAT: ",".join(
665
+ model_format.value
666
+ for model_format in validation_result.model_formats
667
+ )
668
+ }
669
+ )
629
670
 
630
671
  # Remove `ready_to_import` tag that might get copied from service model.
631
672
  tags.pop(Tags.READY_TO_IMPORT, None)
632
- metadata = None
673
+
633
674
  if verified_model:
634
675
  # Verified model is a model in the service catalog that either has no artifacts but contains all the necessary metadata for deploying and fine tuning.
635
676
  # If set, then we copy all the model metadata.
@@ -638,7 +679,6 @@ class AquaModelApp(AquaApp):
638
679
  model = model.with_model_file_description(
639
680
  json_dict=verified_model.model_file_description
640
681
  )
641
-
642
682
  else:
643
683
  metadata = ModelCustomMetadata()
644
684
  if not inference_container:
@@ -658,10 +698,10 @@ class AquaModelApp(AquaApp):
658
698
  "Proceeding with model registration without the fine-tuning container information. "
659
699
  "This model will not be available for fine tuning."
660
700
  )
661
- if model_file:
701
+ if validation_result and validation_result.model_file:
662
702
  metadata.add(
663
703
  key=AQUA_MODEL_ARTIFACT_FILE,
664
- value=model_file,
704
+ value=validation_result.model_file,
665
705
  description=f"The model file for {model_name}",
666
706
  category="Other",
667
707
  )
@@ -678,9 +718,13 @@ class AquaModelApp(AquaApp):
678
718
  description="Evaluation container mapping for SMC",
679
719
  category="Other",
680
720
  )
681
- # TODO: either get task and organization from user or a config file
682
- # tags["task"] = "UNKNOWN"
683
- # tags["organization"] = "UNKNOWN"
721
+
722
+ if validation_result and validation_result.tags:
723
+ tags[Tags.TASK] = validation_result.tags.get(Tags.TASK, UNKNOWN)
724
+ tags[Tags.ORGANIZATION] = validation_result.tags.get(
725
+ Tags.ORGANIZATION, UNKNOWN
726
+ )
727
+ tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)
684
728
 
685
729
  try:
686
730
  # If verified model already has a artifact json, use that.
@@ -695,7 +739,7 @@ class AquaModelApp(AquaApp):
695
739
  copy_model_config(
696
740
  artifact_path=artifact_path, os_path=os_path, auth=default_signer()
697
741
  )
698
- except:
742
+ except Exception:
699
743
  logger.debug(
700
744
  f"Proceeding with model registration without copying model config files at {os_path}. "
701
745
  f"Default configuration will be used for deployment and fine-tuning."
@@ -708,7 +752,6 @@ class AquaModelApp(AquaApp):
708
752
  category="Other",
709
753
  replace=True,
710
754
  )
711
-
712
755
  model = (
713
756
  model.with_custom_metadata_list(metadata)
714
757
  .with_compartment_id(compartment_id or COMPARTMENT_OCID)
@@ -721,7 +764,7 @@ class AquaModelApp(AquaApp):
721
764
  return model
722
765
 
723
766
  @staticmethod
724
- def get_model_files(os_path: str, model_format: ModelFormat) -> [str]:
767
+ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
725
768
  """
726
769
  Get a list of model files based on the given OS path and model format.
727
770
 
@@ -734,13 +777,15 @@ class AquaModelApp(AquaApp):
734
777
 
735
778
  """
736
779
  model_files: List[str] = []
780
+ # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
781
+ # are grouped in one category and validation checks for config.json files only.
737
782
  if model_format == ModelFormat.SAFETENSORS:
738
783
  try:
739
784
  load_config(
740
785
  file_path=os_path,
741
786
  config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
742
787
  )
743
- except AquaValueError:
788
+ except Exception:
744
789
  pass
745
790
  else:
746
791
  model_files.append(AQUA_MODEL_ARTIFACT_CONFIG)
@@ -751,142 +796,328 @@ class AquaModelApp(AquaApp):
751
796
  )
752
797
  return model_files
753
798
 
754
- def validate_model_config(
799
+ @staticmethod
800
+ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]:
801
+ """
802
+ Get a list of model files based on the given OS path and model format.
803
+
804
+ Args:
805
+ model_name (str): The huggingface model name.
806
+ model_format (ModelFormat): The format of the model files.
807
+
808
+ Returns:
809
+ List[str]: A list of model file names.
810
+
811
+ """
812
+ model_files: List[str] = []
813
+
814
+ # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
815
+ # are grouped in one category and returns config.json file only.
816
+
817
+ try:
818
+ model_siblings = get_hf_model_info(repo_id=model_name).siblings
819
+ except Exception as e:
820
+ huggingface_err_message = str(e)
821
+ raise AquaValueError(
822
+ f"Could not get the model files of {model_name} from https://huggingface.co. "
823
+ f"Error: {huggingface_err_message}."
824
+ ) from e
825
+
826
+ if not model_siblings:
827
+ raise AquaValueError(
828
+ f"Failed to fetch the model files of {model_name} from https://huggingface.co."
829
+ )
830
+
831
+ for model_sibling in model_siblings:
832
+ extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper()
833
+ if model_format == ModelFormat.SAFETENSORS:
834
+ if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG:
835
+ model_files.append(model_sibling.rfilename)
836
+ elif extension == model_format.value:
837
+ model_files.append(model_sibling.rfilename)
838
+
839
+ return model_files
840
+
841
+ def _validate_model(
755
842
  self,
756
- import_model_details: ImportModelDetails,
757
- verified_model: Optional[DataScienceModel],
843
+ import_model_details: ImportModelDetails = None,
844
+ model_name: str = None,
845
+ verified_model: DataScienceModel = None,
758
846
  ) -> ModelValidationResult:
759
847
  """
760
848
  Validates the model configuration and returns the model format telemetry model name.
761
849
 
762
850
  Args:
763
- import_model_details (ImportModelDetails): The details of the imported model.
764
- verified_model (Optional[DataScienceModel]): The verified model.
851
+ import_model_details (ImportModelDetails): Model details for importing the model.
852
+ model_name (str): name of the model
853
+ verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from
854
+ the service verified model
765
855
 
766
856
  Returns:
767
857
  ModelValidationResult: The result of the model validation.
768
858
 
769
859
  Raises:
770
860
  AquaRuntimeError: If there is an error while loading the config file or if the model path is incorrect.
771
- AquaValueError: If the model format is not supported by AQUA."""
772
- inference_containers_config = (
773
- AquaContainerConfig.from_container_index_json().inference
774
- )
775
- inference_container = import_model_details.inference_container
776
- model_format: ModelFormat
861
+ AquaValueError: If the model format is not supported by AQUA.
862
+ """
863
+ model_formats = []
777
864
  validation_result: ModelValidationResult = ModelValidationResult()
865
+
866
+ hf_download_config_present = False
867
+
868
+ if import_model_details.download_from_hf:
869
+ safetensors_model_files = self.get_hf_model_files(
870
+ model_name, ModelFormat.SAFETENSORS
871
+ )
872
+ if safetensors_model_files:
873
+ hf_download_config_present = True
874
+ gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF)
875
+ else:
876
+ safetensors_model_files = self.get_model_files(
877
+ import_model_details.os_path, ModelFormat.SAFETENSORS
878
+ )
879
+ gguf_model_files = self.get_model_files(
880
+ import_model_details.os_path, ModelFormat.GGUF
881
+ )
882
+
883
+ if not (safetensors_model_files or gguf_model_files):
884
+ raise AquaRuntimeError(
885
+ f"The model {model_name} does not contain either {ModelFormat.SAFETENSORS.value} "
886
+ f"or {ModelFormat.GGUF.value} files in {import_model_details.os_path} or Hugging Face repository. "
887
+ f"Please check if the path is correct or the model artifacts are available at this location."
888
+ )
889
+
778
890
  if verified_model:
779
891
  aqua_model = self.to_aqua_model(verified_model, self.region)
780
- model_format = aqua_model.model_format
892
+ model_formats = aqua_model.model_formats
781
893
  else:
782
- # Todo: Revisit this logic once a container supports multiple formats
783
- try:
784
- model_format = inference_containers_config[
785
- inference_container
786
- ].model_formats[0]
787
- except (KeyError, IndexError):
788
- logger.warn(
789
- "Unable to fetch model format for the model automatically defaulting to safetensors"
790
- )
791
- model_format = ModelFormat.SAFETENSORS
792
- pass
793
- validation_result.model_format = model_format
794
- if model_format == ModelFormat.SAFETENSORS:
795
- try:
796
- model_config = load_config(
797
- file_path=import_model_details.os_path,
798
- config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
799
- )
800
- except Exception as ex:
801
- logger.error(
802
- f"Exception occurred while loading config file from {import_model_details.os_path}"
803
- f"Exception message: {ex}"
804
- )
805
- raise AquaRuntimeError(
806
- f"The model path {import_model_details.os_path} does not contain the file config.json. "
807
- f"Please check if the path is correct or the model artifacts are available at this location."
808
- )
809
- else:
894
+ if safetensors_model_files:
895
+ model_formats.append(ModelFormat.SAFETENSORS)
896
+ if gguf_model_files:
897
+ model_formats.append(ModelFormat.GGUF)
898
+
899
+ # get tags for models from hf
900
+ if import_model_details.download_from_hf:
901
+ model_info = get_hf_model_info(repo_id=model_name)
902
+
810
903
  try:
811
- metadata_model_type = verified_model.custom_metadata_list.get(
812
- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
813
- ).value
814
- if metadata_model_type:
815
- if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
816
- if (
817
- model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
818
- != metadata_model_type
819
- ):
820
- raise AquaRuntimeError(
821
- f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
822
- f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
823
- f"the model {import_model_details.model}. Please check if the path is correct or "
824
- f"the correct model artifacts are available at this location."
825
- f""
826
- )
904
+ license_value = UNKNOWN
905
+ if model_info.tags:
906
+ license_tag = next(
907
+ (
908
+ tag
909
+ for tag in model_info.tags
910
+ if tag.startswith("license:")
911
+ ),
912
+ UNKNOWN,
913
+ )
914
+ license_value = (
915
+ license_tag.split(":")[1] if license_tag else UNKNOWN
916
+ )
917
+
918
+ hf_tags = {
919
+ Tags.TASK: (model_info and model_info.pipeline_tag) or UNKNOWN,
920
+ Tags.ORGANIZATION: (
921
+ model_info.author
922
+ if model_info and hasattr(model_info, "author")
923
+ else UNKNOWN
924
+ ),
925
+ Tags.LICENSE: license_value,
926
+ }
927
+ validation_result.tags = hf_tags
928
+ except Exception:
929
+ pass
930
+
931
+ validation_result.model_formats = model_formats
932
+
933
+ # now as we know that at least one type of model files exist, validate the content of oss path.
934
+ # for safetensors, we check if config.json files exist, and for gguf format we check if files with
935
+ # gguf extension exist.
936
+ for model_format in model_formats:
937
+ if (
938
+ model_format == ModelFormat.SAFETENSORS
939
+ and len(safetensors_model_files) > 0
940
+ ):
941
+ if import_model_details.download_from_hf:
942
+ # validates config.json exists for safetensors model from hugginface
943
+ if not hf_download_config_present:
944
+ raise AquaRuntimeError(
945
+ f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
946
+ f"by {ModelFormat.SAFETENSORS.value} format model."
947
+ f" Please check if the model name is correct in Hugging Face repository."
948
+ )
949
+ else:
950
+ try:
951
+ model_config = load_config(
952
+ file_path=import_model_details.os_path,
953
+ config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
954
+ )
955
+ except Exception as ex:
956
+ logger.error(
957
+ f"Exception occurred while loading config file from {import_model_details.os_path}"
958
+ f"Exception message: {ex}"
959
+ )
960
+ raise AquaRuntimeError(
961
+ f"The model path {import_model_details.os_path} does not contain the file config.json. "
962
+ f"Please check if the path is correct or the model artifacts are available at this location."
963
+ ) from ex
964
+ else:
965
+ try:
966
+ metadata_model_type = (
967
+ verified_model.custom_metadata_list.get(
968
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
969
+ ).value
970
+ )
971
+ if metadata_model_type:
972
+ if (
973
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
974
+ in model_config
975
+ ):
976
+ if (
977
+ model_config[
978
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
979
+ ]
980
+ != metadata_model_type
981
+ ):
982
+ raise AquaRuntimeError(
983
+ f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
984
+ f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
985
+ f"the model {model_name}. Please check if the path is correct or "
986
+ f"the correct model artifacts are available at this location."
987
+ f""
988
+ )
989
+ else:
990
+ logger.debug(
991
+ f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
992
+ f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
993
+ )
994
+ except Exception:
995
+ pass
996
+ if verified_model:
997
+ validation_result.telemetry_model_name = (
998
+ verified_model.display_name
999
+ )
1000
+ elif (
1001
+ model_config is not None
1002
+ and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1003
+ ):
1004
+ validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1005
+ elif (
1006
+ model_config is not None
1007
+ and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1008
+ ):
1009
+ validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
827
1010
  else:
828
- logger.debug(
829
- f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
830
- f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
1011
+ validation_result.telemetry_model_name = (
1012
+ AQUA_MODEL_TYPE_CUSTOM
831
1013
  )
832
- except:
833
- pass
1014
+ elif model_format == ModelFormat.GGUF and len(gguf_model_files) > 0:
1015
+ if import_model_details.finetuning_container and not safetensors_model_files:
1016
+ raise AquaValueError(
1017
+ "Fine-tuning is currently not supported with GGUF model format."
1018
+ )
1019
+ if verified_model:
1020
+ try:
1021
+ model_file = verified_model.custom_metadata_list.get(
1022
+ AQUA_MODEL_ARTIFACT_FILE
1023
+ ).value
1024
+ except ValueError as err:
1025
+ raise AquaRuntimeError(
1026
+ f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
1027
+ f"Please check if the model has the valid metadata."
1028
+ ) from err
1029
+ else:
1030
+ model_file = import_model_details.model_file
1031
+
1032
+ model_files = gguf_model_files
1033
+ # todo: have a separate error validation class for different type of error messages.
1034
+ if model_file:
1035
+ if model_file not in model_files:
1036
+ raise AquaRuntimeError(
1037
+ f"The model path {import_model_details.os_path} or the Hugging Face "
1038
+ f"model repository for {model_name} does not contain the file "
1039
+ f"{model_file}. Please check if the path is correct or the model "
1040
+ f"artifacts are available at this location."
1041
+ )
1042
+ else:
1043
+ validation_result.model_file = model_file
1044
+ elif len(model_files) == 0:
1045
+ raise AquaRuntimeError(
1046
+ f"The model path {import_model_details.os_path} or the Hugging Face model "
1047
+ f"repository for {model_name} does not contain any GGUF format files. "
1048
+ f"Please check if the path is correct or the model artifacts are available "
1049
+ f"at this location."
1050
+ )
1051
+ elif len(model_files) > 1:
1052
+ raise AquaRuntimeError(
1053
+ f"The model path {import_model_details.os_path} or the Hugging Face model "
1054
+ f"repository for {model_name} contains multiple GGUF format files. "
1055
+ f"Please specify the file that needs to be deployed using the model_file "
1056
+ f"parameter."
1057
+ )
1058
+ else:
1059
+ validation_result.model_file = model_files[0]
1060
+
834
1061
  if verified_model:
835
1062
  validation_result.telemetry_model_name = verified_model.display_name
836
- elif (
837
- model_config is not None
838
- and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
839
- ):
840
- validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
841
- elif (
842
- model_config is not None
843
- and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
844
- ):
845
- validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1063
+ elif import_model_details.download_from_hf:
1064
+ validation_result.telemetry_model_name = model_name
846
1065
  else:
847
1066
  validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
848
1067
 
849
- elif model_format == ModelFormat.GGUF:
850
- if import_model_details.finetuning_container:
851
- raise AquaValueError(
852
- "Finetuning is currently not supported with GGUF model format"
853
- )
854
- if verified_model:
855
- model_file = verified_model.custom_metadata_list.get(
856
- AQUA_MODEL_ARTIFACT_FILE, None
857
- )
1068
+ return validation_result
1069
+
1070
+ @staticmethod
1071
+ def _download_model_from_hf(
1072
+ model_name: str,
1073
+ os_path: str,
1074
+ local_dir: str = None,
1075
+ ) -> str:
1076
+ """This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
1077
+ to object storage location.
1078
+
1079
+ Parameters
1080
+ ----------
1081
+ model_name (str): The huggingface model name.
1082
+ os_path (str): The OS path where the model files are located.
1083
+ local_dir (str): The local temp dir to store the huggingface model.
1084
+
1085
+ Returns
1086
+ -------
1087
+ model_artifact_path (str): Location where the model artifacts are downloaded.
1088
+
1089
+ """
1090
+ # Download the model from hub
1091
+ if not local_dir:
1092
+ local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
1093
+ local_dir = os.path.join(local_dir, model_name)
1094
+ retry = 10
1095
+ i = 0
1096
+ huggingface_download_err_message = None
1097
+ while i < retry:
1098
+ try:
1099
+ # Download to cache folder. The while loop retries when there is a network failure
1100
+ snapshot_download(repo_id=model_name)
1101
+ except Exception as e:
1102
+ huggingface_download_err_message = str(e)
1103
+ i += 1
858
1104
  else:
859
- model_file = import_model_details.model_file
860
- model_files = self.get_model_files(
861
- import_model_details.os_path, model_format
1105
+ break
1106
+ if i == retry:
1107
+ raise Exception(
1108
+ f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
862
1109
  )
863
- if model_file:
864
- if model_file not in model_files:
865
- raise AquaRuntimeError(
866
- f"The model path {import_model_details.os_path} does not contain the file {model_file}. "
867
- f"Please check if the path is correct or the model artifacts are available at this location."
868
- )
869
- else:
870
- validation_result.model_file = model_file
871
- elif len(model_files) == 0:
872
- raise AquaRuntimeError(
873
- f"The model path {import_model_details.os_path} does not contain any GGUF format files. "
874
- f"Please check if the path is correct or the model artifacts are available at this location."
875
- )
876
- elif len(model_files) > 1:
877
- raise AquaRuntimeError(
878
- f"The model path {import_model_details.os_path} contains multiple GGUF format files. Please specify the file that needs to be deployed."
879
- )
880
- else:
881
- validation_result.model_file = model_files[0]
1110
+ os.makedirs(local_dir, exist_ok=True)
1111
+ # Copy the model from the cache to destination
1112
+ snapshot_download(repo_id=model_name, local_dir=local_dir)
1113
+ # Upload to object storage
1114
+ model_artifact_path = upload_folder(
1115
+ os_path=os_path,
1116
+ local_dir=local_dir,
1117
+ model_name=model_name,
1118
+ )
882
1119
 
883
- if verified_model:
884
- validation_result.telemetry_model_name = verified_model.display_name
885
- else:
886
- validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
887
- else:
888
- raise AquaValueError("This model format is currently not supported by AQUA")
889
- return validation_result
1120
+ return model_artifact_path
890
1121
 
891
1122
  def register(
892
1123
  self, import_model_details: ImportModelDetails = None, **kwargs
@@ -928,9 +1159,6 @@ class AquaModelApp(AquaApp):
928
1159
  )
929
1160
  verified_model = DataScienceModel.from_id(model_service_id)
930
1161
 
931
- validation_result = self.validate_model_config(
932
- import_model_details, verified_model
933
- )
934
1162
  # Copy the model name from the service model if `model` is ocid
935
1163
  model_name = (
936
1164
  verified_model.display_name
@@ -938,17 +1166,33 @@ class AquaModelApp(AquaApp):
938
1166
  else import_model_details.model
939
1167
  )
940
1168
 
1169
+ # validate model and artifact
1170
+ validation_result = self._validate_model(
1171
+ import_model_details=import_model_details,
1172
+ model_name=model_name,
1173
+ verified_model=verified_model,
1174
+ )
1175
+
1176
+ # download model from hugginface if indicates
1177
+ if import_model_details.download_from_hf:
1178
+ artifact_path = self._download_model_from_hf(
1179
+ model_name=model_name,
1180
+ os_path=import_model_details.os_path,
1181
+ local_dir=import_model_details.local_dir,
1182
+ ).rstrip("/")
1183
+ else:
1184
+ artifact_path = import_model_details.os_path.rstrip("/")
1185
+
941
1186
  # Create Model catalog entry with pass by reference
942
1187
  ds_model = self._create_model_catalog_entry(
943
- os_path=import_model_details.os_path,
1188
+ os_path=artifact_path,
944
1189
  model_name=model_name,
945
1190
  inference_container=import_model_details.inference_container,
946
1191
  finetuning_container=import_model_details.finetuning_container,
947
1192
  verified_model=verified_model,
1193
+ validation_result=validation_result,
948
1194
  compartment_id=import_model_details.compartment_id,
949
1195
  project_id=import_model_details.project_id,
950
- model_file=validation_result.model_file,
951
- model_format=validation_result.model_format,
952
1196
  )
953
1197
  # registered model will always have inference and evaluation container, but
954
1198
  # fine-tuning container may be not set
@@ -962,7 +1206,7 @@ class AquaModelApp(AquaApp):
962
1206
  finetuning_container = ds_model.custom_metadata_list.get(
963
1207
  ModelCustomMetadataFields.FINETUNE_CONTAINER,
964
1208
  ).value
965
- except:
1209
+ except Exception:
966
1210
  finetuning_container = None
967
1211
 
968
1212
  aqua_model_attributes = dict(
@@ -970,13 +1214,14 @@ class AquaModelApp(AquaApp):
970
1214
  project_id=ds_model.project_id,
971
1215
  model_card=str(
972
1216
  read_file(
973
- file_path=f"{import_model_details.os_path.rstrip('/')}/{README}",
1217
+ file_path=f"{artifact_path}/{README}",
974
1218
  auth=default_signer(),
975
1219
  )
976
1220
  ),
977
1221
  inference_container=inference_container,
978
1222
  finetuning_container=finetuning_container,
979
1223
  evaluation_container=evaluation_container,
1224
+ artifact_location=artifact_path,
980
1225
  )
981
1226
 
982
1227
  self.telemetry.record_event_async(