oracle-ads 2.11.15__py3-none-any.whl → 2.11.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. ads/aqua/app.py +5 -6
  2. ads/aqua/common/entities.py +17 -0
  3. ads/aqua/common/enums.py +14 -1
  4. ads/aqua/common/utils.py +160 -3
  5. ads/aqua/config/config.py +1 -1
  6. ads/aqua/config/deployment_config_defaults.json +29 -1
  7. ads/aqua/config/resource_limit_names.json +1 -0
  8. ads/aqua/constants.py +6 -1
  9. ads/aqua/evaluation/entities.py +0 -1
  10. ads/aqua/evaluation/evaluation.py +47 -14
  11. ads/aqua/extension/common_handler.py +75 -5
  12. ads/aqua/extension/common_ws_msg_handler.py +57 -0
  13. ads/aqua/extension/deployment_handler.py +16 -13
  14. ads/aqua/extension/deployment_ws_msg_handler.py +54 -0
  15. ads/aqua/extension/errors.py +1 -1
  16. ads/aqua/extension/evaluation_ws_msg_handler.py +28 -6
  17. ads/aqua/extension/model_handler.py +134 -8
  18. ads/aqua/extension/models/ws_models.py +78 -3
  19. ads/aqua/extension/models_ws_msg_handler.py +49 -0
  20. ads/aqua/extension/ui_websocket_handler.py +7 -1
  21. ads/aqua/model/entities.py +28 -0
  22. ads/aqua/model/model.py +544 -129
  23. ads/aqua/modeldeployment/deployment.py +102 -43
  24. ads/aqua/modeldeployment/entities.py +9 -20
  25. ads/aqua/ui.py +152 -28
  26. ads/common/object_storage_details.py +2 -5
  27. ads/common/serializer.py +2 -3
  28. ads/jobs/builders/infrastructure/dsc_job.py +41 -12
  29. ads/jobs/builders/infrastructure/dsc_job_runtime.py +74 -27
  30. ads/jobs/builders/runtimes/container_runtime.py +83 -4
  31. ads/opctl/operator/lowcode/anomaly/const.py +1 -0
  32. ads/opctl/operator/lowcode/anomaly/model/base_model.py +23 -7
  33. ads/opctl/operator/lowcode/anomaly/operator_config.py +1 -0
  34. ads/opctl/operator/lowcode/anomaly/schema.yaml +4 -0
  35. ads/opctl/operator/lowcode/common/errors.py +6 -0
  36. ads/opctl/operator/lowcode/forecast/model/arima.py +3 -1
  37. ads/opctl/operator/lowcode/forecast/model/base_model.py +21 -13
  38. ads/opctl/operator/lowcode/forecast/model_evaluator.py +11 -2
  39. ads/pipeline/ads_pipeline_run.py +13 -2
  40. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/METADATA +2 -1
  41. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/RECORD +44 -40
  42. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/LICENSE.txt +0 -0
  43. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/WHEEL +0 -0
  44. {oracle_ads-2.11.15.dist-info → oracle_ads-2.11.17.dist-info}/entry_points.txt +0 -0
ads/aqua/model/model.py CHANGED
@@ -1,27 +1,39 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
  # Copyright (c) 2024 Oracle and/or its affiliates.
4
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
  import os
5
+ import pathlib
6
6
  from datetime import datetime, timedelta
7
7
  from threading import Lock
8
- from typing import List, Optional, Union
8
+ from typing import Dict, List, Optional, Set, Union
9
9
 
10
+ import oci
10
11
  from cachetools import TTLCache
12
+ from huggingface_hub import snapshot_download
11
13
  from oci.data_science.models import JobRun, Model
12
14
 
13
- from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
15
+ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
14
16
  from ads.aqua.app import AquaApp
15
17
  from ads.aqua.common.enums import Tags
16
- from ads.aqua.common.errors import AquaRuntimeError
18
+ from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
17
19
  from ads.aqua.common.utils import (
20
+ LifecycleStatus,
21
+ _build_resource_identifier,
22
+ copy_model_config,
18
23
  create_word_icon,
19
24
  get_artifact_path,
20
- read_file,
21
- copy_model_config,
25
+ get_hf_model_info,
26
+ list_os_files_with_extension,
22
27
  load_config,
28
+ read_file,
29
+ upload_folder,
23
30
  )
24
31
  from ads.aqua.constants import (
32
+ AQUA_MODEL_ARTIFACT_CONFIG,
33
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
34
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
35
+ AQUA_MODEL_ARTIFACT_FILE,
36
+ AQUA_MODEL_TYPE_CUSTOM,
25
37
  LICENSE_TXT,
26
38
  MODEL_BY_REFERENCE_OSS_PATH_KEY,
27
39
  README,
@@ -33,13 +45,24 @@ from ads.aqua.constants import (
33
45
  UNKNOWN,
34
46
  VALIDATION_METRICS,
35
47
  VALIDATION_METRICS_FINAL,
36
- AQUA_MODEL_ARTIFACT_CONFIG,
37
- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
38
- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
39
- AQUA_MODEL_TYPE_CUSTOM,
40
48
  )
41
- from ads.aqua.model.constants import *
42
- from ads.aqua.model.entities import *
49
+ from ads.aqua.model.constants import (
50
+ FineTuningCustomMetadata,
51
+ FineTuningMetricCategories,
52
+ ModelCustomMetadataFields,
53
+ ModelType,
54
+ )
55
+ from ads.aqua.model.entities import (
56
+ AquaFineTuneModel,
57
+ AquaFineTuningMetric,
58
+ AquaModel,
59
+ AquaModelLicense,
60
+ AquaModelSummary,
61
+ ImportModelDetails,
62
+ ModelFormat,
63
+ ModelValidationResult,
64
+ )
65
+ from ads.aqua.ui import AquaContainerConfig, AquaContainerConfigItem
43
66
  from ads.common.auth import default_signer
44
67
  from ads.common.oci_resource import SEARCH_TYPE, OCIResource
45
68
  from ads.common.utils import get_console_link
@@ -174,11 +197,9 @@ class AquaModelApp(AquaApp):
174
197
  if not self._if_show(ds_model):
175
198
  raise AquaRuntimeError(f"Target model `{ds_model.id} `is not Aqua model.")
176
199
 
177
- is_fine_tuned_model = (
178
- True
179
- if ds_model.freeform_tags
200
+ is_fine_tuned_model = bool(
201
+ ds_model.freeform_tags
180
202
  and ds_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
181
- else False
182
203
  )
183
204
 
184
205
  # todo: consolidate this logic in utils for model and deployment use
@@ -216,6 +237,10 @@ class AquaModelApp(AquaApp):
216
237
  ModelCustomMetadataFields.FINETUNE_CONTAINER,
217
238
  ModelCustomMetadataItem(key=ModelCustomMetadataFields.FINETUNE_CONTAINER),
218
239
  ).value
240
+ artifact_location = ds_model.custom_metadata_list.get(
241
+ ModelCustomMetadataFields.ARTIFACT_LOCATION,
242
+ ModelCustomMetadataItem(key=ModelCustomMetadataFields.ARTIFACT_LOCATION),
243
+ ).value
219
244
 
220
245
  aqua_model_attributes = dict(
221
246
  **self._process_model(ds_model, self.region),
@@ -224,6 +249,7 @@ class AquaModelApp(AquaApp):
224
249
  inference_container=inference_container,
225
250
  finetuning_container=finetuning_container,
226
251
  evaluation_container=evaluation_container,
252
+ artifact_location=artifact_location,
227
253
  )
228
254
 
229
255
  if not is_fine_tuned_model:
@@ -236,7 +262,7 @@ class AquaModelApp(AquaApp):
236
262
  try:
237
263
  jobrun_ocid = ds_model.provenance_metadata.training_id
238
264
  jobrun = self.ds_client.get_job_run(jobrun_ocid).data
239
- except Exception as e:
265
+ except Exception:
240
266
  logger.debug(
241
267
  f"Missing jobrun information in the provenance metadata of the given model {model_id}."
242
268
  )
@@ -258,7 +284,7 @@ class AquaModelApp(AquaApp):
258
284
  logger.debug(str(e))
259
285
  source_name = UNKNOWN
260
286
 
261
- source_identifier = utils._build_resource_identifier(
287
+ source_identifier = _build_resource_identifier(
262
288
  id=source_id,
263
289
  name=source_name,
264
290
  region=self.region,
@@ -268,8 +294,7 @@ class AquaModelApp(AquaApp):
268
294
 
269
295
  job_run_status = (
270
296
  jobrun.lifecycle_state
271
- if jobrun
272
- and not jobrun.lifecycle_state == JobRun.LIFECYCLE_STATE_DELETED
297
+ if jobrun and jobrun.lifecycle_state != JobRun.LIFECYCLE_STATE_DELETED
273
298
  else (
274
299
  JobRun.LIFECYCLE_STATE_SUCCEEDED
275
300
  if self.if_artifact_exist(ds_model.id)
@@ -277,7 +302,7 @@ class AquaModelApp(AquaApp):
277
302
  )
278
303
  )
279
304
  # TODO: change the argument's name.
280
- lifecycle_state = utils.LifecycleStatus.get_status(
305
+ lifecycle_state = LifecycleStatus.get_status(
281
306
  evaluation_status=ds_model.lifecycle_state,
282
307
  job_run_status=job_run_status,
283
308
  )
@@ -320,7 +345,7 @@ class AquaModelApp(AquaApp):
320
345
  category=category,
321
346
  scores=scores,
322
347
  )
323
- except:
348
+ except Exception:
324
349
  return AquaFineTuningMetric(name=metric_name, category=category, scores=[])
325
350
 
326
351
  def _build_ft_metrics(
@@ -363,8 +388,21 @@ class AquaModelApp(AquaApp):
363
388
  training_final,
364
389
  ]
365
390
 
391
+ @staticmethod
392
+ def to_aqua_model(
393
+ model: Union[
394
+ DataScienceModel,
395
+ oci.data_science.models.model.Model,
396
+ oci.data_science.models.ModelSummary,
397
+ oci.resource_search.models.ResourceSummary,
398
+ ],
399
+ region: str,
400
+ ) -> AquaModel:
401
+ """Converts a model to an Aqua model."""
402
+ return AquaModel(**AquaModelApp._process_model(model, region))
403
+
404
+ @staticmethod
366
405
  def _process_model(
367
- self,
368
406
  model: Union[
369
407
  DataScienceModel,
370
408
  oci.data_science.models.model.Model,
@@ -389,24 +427,20 @@ class AquaModelApp(AquaApp):
389
427
  else model.id
390
428
  )
391
429
 
392
- console_link = (
393
- get_console_link(
394
- resource="models",
395
- ocid=model_id,
396
- region=region,
397
- ),
430
+ console_link = get_console_link(
431
+ resource="models",
432
+ ocid=model_id,
433
+ region=region,
398
434
  )
399
435
 
400
436
  description = ""
401
- if isinstance(model, DataScienceModel) or isinstance(
402
- model, oci.data_science.models.model.Model
403
- ):
437
+ if isinstance(model, (DataScienceModel, oci.data_science.models.model.Model)):
404
438
  description = model.description
405
439
  elif isinstance(model, oci.resource_search.models.ResourceSummary):
406
440
  description = model.additional_details.get("description")
407
441
 
408
442
  search_text = (
409
- self._build_search_text(tags=tags, description=description)
443
+ AquaModelApp._build_search_text(tags=tags, description=description)
410
444
  if tags
411
445
  else UNKNOWN
412
446
  )
@@ -416,6 +450,7 @@ class AquaModelApp(AquaApp):
416
450
  ready_to_deploy = (
417
451
  freeform_tags.get(Tags.AQUA_TAG, "").upper() == READY_TO_DEPLOY_STATUS
418
452
  )
453
+
419
454
  ready_to_finetune = (
420
455
  freeform_tags.get(Tags.READY_TO_FINE_TUNE, "").upper()
421
456
  == READY_TO_FINE_TUNE_STATUS
@@ -425,23 +460,55 @@ class AquaModelApp(AquaApp):
425
460
  == READY_TO_IMPORT_STATUS
426
461
  )
427
462
 
428
- return dict(
429
- compartment_id=model.compartment_id,
430
- icon=icon or UNKNOWN,
431
- id=model_id,
432
- license=freeform_tags.get(Tags.LICENSE, UNKNOWN),
433
- name=model.display_name,
434
- organization=freeform_tags.get(Tags.ORGANIZATION, UNKNOWN),
435
- task=freeform_tags.get(Tags.TASK, UNKNOWN),
436
- time_created=model.time_created,
437
- is_fine_tuned_model=is_fine_tuned_model,
438
- tags=tags,
439
- console_link=console_link,
440
- search_text=search_text,
441
- ready_to_deploy=ready_to_deploy,
442
- ready_to_finetune=ready_to_finetune,
443
- ready_to_import=ready_to_import,
463
+ try:
464
+ model_file = model.custom_metadata_list.get(AQUA_MODEL_ARTIFACT_FILE).value
465
+ except Exception:
466
+ model_file = UNKNOWN
467
+
468
+ inference_containers = AquaContainerConfig.from_container_index_json().inference
469
+
470
+ model_formats_str = freeform_tags.get(
471
+ Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS.value
472
+ ).upper()
473
+ model_formats = [
474
+ ModelFormat[model_format] for model_format in model_formats_str.split(",")
475
+ ]
476
+
477
+ supported_platform: Set[AquaContainerConfigItem.Platform] = set()
478
+
479
+ for container in inference_containers.values():
480
+ for model_format in model_formats:
481
+ if model_format in container.model_formats:
482
+ supported_platform.update(container.platforms)
483
+
484
+ nvidia_gpu_supported = (
485
+ AquaContainerConfigItem.Platform.NVIDIA_GPU in supported_platform
444
486
  )
487
+ arm_cpu_supported = (
488
+ AquaContainerConfigItem.Platform.ARM_CPU in supported_platform
489
+ )
490
+
491
+ return {
492
+ "compartment_id": model.compartment_id,
493
+ "icon": icon or UNKNOWN,
494
+ "id": model_id,
495
+ "license": freeform_tags.get(Tags.LICENSE, UNKNOWN),
496
+ "name": model.display_name,
497
+ "organization": freeform_tags.get(Tags.ORGANIZATION, UNKNOWN),
498
+ "task": freeform_tags.get(Tags.TASK, UNKNOWN),
499
+ "time_created": str(model.time_created),
500
+ "is_fine_tuned_model": is_fine_tuned_model,
501
+ "tags": tags,
502
+ "console_link": console_link,
503
+ "search_text": search_text,
504
+ "ready_to_deploy": ready_to_deploy,
505
+ "ready_to_finetune": ready_to_finetune,
506
+ "ready_to_import": ready_to_import,
507
+ "nvidia_gpu_supported": nvidia_gpu_supported,
508
+ "arm_cpu_supported": arm_cpu_supported,
509
+ "model_file": model_file,
510
+ "model_formats": model_formats,
511
+ }
445
512
 
446
513
  @telemetry(entry_point="plugin=model&action=list", name="aqua")
447
514
  def list(
@@ -490,7 +557,7 @@ class AquaModelApp(AquaApp):
490
557
  category="aqua/service/model", action="list"
491
558
  )
492
559
 
493
- if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache.keys():
560
+ if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache:
494
561
  logger.info(
495
562
  f"Returning service models list in {ODSC_MODEL_COMPARTMENT_OCID} from cache."
496
563
  )
@@ -540,9 +607,9 @@ class AquaModelApp(AquaApp):
540
607
  dict with the key used, and True if cache has the key that needs to be deleted.
541
608
  """
542
609
  res = {}
543
- logger.info(f"Clearing _service_models_cache")
610
+ logger.info("Clearing _service_models_cache")
544
611
  with self._cache_lock:
545
- if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache.keys():
612
+ if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache:
546
613
  self._service_models_cache.pop(key=ODSC_MODEL_COMPARTMENT_OCID)
547
614
  res = {
548
615
  "key": {
@@ -559,6 +626,7 @@ class AquaModelApp(AquaApp):
559
626
  inference_container: str,
560
627
  finetuning_container: str,
561
628
  verified_model: DataScienceModel,
629
+ validation_result: ModelValidationResult,
562
630
  compartment_id: Optional[str],
563
631
  project_id: Optional[str],
564
632
  ) -> DataScienceModel:
@@ -577,19 +645,32 @@ class AquaModelApp(AquaApp):
577
645
  DataScienceModel: Returns Datascience model instance.
578
646
  """
579
647
  model = DataScienceModel()
580
- tags = (
648
+ tags: Dict[str, str] = (
581
649
  {
582
650
  **verified_model.freeform_tags,
583
651
  Tags.AQUA_SERVICE_MODEL_TAG: verified_model.id,
584
652
  }
585
653
  if verified_model
586
- else {Tags.AQUA_TAG: "active", Tags.BASE_MODEL_CUSTOM: "true"}
654
+ else {
655
+ Tags.AQUA_TAG: "active",
656
+ Tags.BASE_MODEL_CUSTOM: "true",
657
+ }
587
658
  )
588
659
  tags.update({Tags.BASE_MODEL_CUSTOM: "true"})
589
660
 
661
+ if validation_result and validation_result.model_formats:
662
+ tags.update(
663
+ {
664
+ Tags.MODEL_FORMAT: ",".join(
665
+ model_format.value
666
+ for model_format in validation_result.model_formats
667
+ )
668
+ }
669
+ )
670
+
590
671
  # Remove `ready_to_import` tag that might get copied from service model.
591
672
  tags.pop(Tags.READY_TO_IMPORT, None)
592
- metadata = None
673
+
593
674
  if verified_model:
594
675
  # Verified model is a model in the service catalog that either has no artifacts but contains all the necessary metadata for deploying and fine tuning.
595
676
  # If set, then we copy all the model metadata.
@@ -598,7 +679,6 @@ class AquaModelApp(AquaApp):
598
679
  model = model.with_model_file_description(
599
680
  json_dict=verified_model.model_file_description
600
681
  )
601
-
602
682
  else:
603
683
  metadata = ModelCustomMetadata()
604
684
  if not inference_container:
@@ -615,8 +695,15 @@ class AquaModelApp(AquaApp):
615
695
  )
616
696
  else:
617
697
  logger.warn(
618
- f"Proceeding with model registration without the fine-tuning container information. "
619
- f"This model will not be available for fine tuning."
698
+ "Proceeding with model registration without the fine-tuning container information. "
699
+ "This model will not be available for fine tuning."
700
+ )
701
+ if validation_result and validation_result.model_file:
702
+ metadata.add(
703
+ key=AQUA_MODEL_ARTIFACT_FILE,
704
+ value=validation_result.model_file,
705
+ description=f"The model file for {model_name}",
706
+ category="Other",
620
707
  )
621
708
 
622
709
  metadata.add(
@@ -631,9 +718,13 @@ class AquaModelApp(AquaApp):
631
718
  description="Evaluation container mapping for SMC",
632
719
  category="Other",
633
720
  )
634
- # TODO: either get task and organization from user or a config file
635
- # tags["task"] = "UNKNOWN"
636
- # tags["organization"] = "UNKNOWN"
721
+
722
+ if validation_result and validation_result.tags:
723
+ tags[Tags.TASK] = validation_result.tags.get(Tags.TASK, UNKNOWN)
724
+ tags[Tags.ORGANIZATION] = validation_result.tags.get(
725
+ Tags.ORGANIZATION, UNKNOWN
726
+ )
727
+ tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)
637
728
 
638
729
  try:
639
730
  # If verified model already has a artifact json, use that.
@@ -648,7 +739,7 @@ class AquaModelApp(AquaApp):
648
739
  copy_model_config(
649
740
  artifact_path=artifact_path, os_path=os_path, auth=default_signer()
650
741
  )
651
- except:
742
+ except Exception:
652
743
  logger.debug(
653
744
  f"Proceeding with model registration without copying model config files at {os_path}. "
654
745
  f"Default configuration will be used for deployment and fine-tuning."
@@ -661,7 +752,6 @@ class AquaModelApp(AquaApp):
661
752
  category="Other",
662
753
  replace=True,
663
754
  )
664
-
665
755
  model = (
666
756
  model.with_custom_metadata_list(metadata)
667
757
  .with_compartment_id(compartment_id or COMPARTMENT_OCID)
@@ -673,6 +763,362 @@ class AquaModelApp(AquaApp):
673
763
  logger.debug(model)
674
764
  return model
675
765
 
766
+ @staticmethod
767
+ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
768
+ """
769
+ Get a list of model files based on the given OS path and model format.
770
+
771
+ Args:
772
+ os_path (str): The OS path where the model files are located.
773
+ model_format (ModelFormat): The format of the model files.
774
+
775
+ Returns:
776
+ List[str]: A list of model file names.
777
+
778
+ """
779
+ model_files: List[str] = []
780
+ # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
781
+ # are grouped in one category and validation checks for config.json files only.
782
+ if model_format == ModelFormat.SAFETENSORS:
783
+ try:
784
+ load_config(
785
+ file_path=os_path,
786
+ config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
787
+ )
788
+ except Exception:
789
+ pass
790
+ else:
791
+ model_files.append(AQUA_MODEL_ARTIFACT_CONFIG)
792
+
793
+ if model_format == ModelFormat.GGUF:
794
+ model_files.extend(
795
+ list_os_files_with_extension(oss_path=os_path, extension=".gguf")
796
+ )
797
+ return model_files
798
+
799
+ @staticmethod
800
+ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]:
801
+ """
802
+ Get a list of model files based on the given OS path and model format.
803
+
804
+ Args:
805
+ model_name (str): The huggingface model name.
806
+ model_format (ModelFormat): The format of the model files.
807
+
808
+ Returns:
809
+ List[str]: A list of model file names.
810
+
811
+ """
812
+ model_files: List[str] = []
813
+
814
+ # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
815
+ # are grouped in one category and returns config.json file only.
816
+
817
+ try:
818
+ model_siblings = get_hf_model_info(repo_id=model_name).siblings
819
+ except Exception as e:
820
+ huggingface_err_message = str(e)
821
+ raise AquaValueError(
822
+ f"Could not get the model files of {model_name} from https://huggingface.co. "
823
+ f"Error: {huggingface_err_message}."
824
+ ) from e
825
+
826
+ if not model_siblings:
827
+ raise AquaValueError(
828
+ f"Failed to fetch the model files of {model_name} from https://huggingface.co."
829
+ )
830
+
831
+ for model_sibling in model_siblings:
832
+ extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper()
833
+ if model_format == ModelFormat.SAFETENSORS:
834
+ if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG:
835
+ model_files.append(model_sibling.rfilename)
836
+ elif extension == model_format.value:
837
+ model_files.append(model_sibling.rfilename)
838
+
839
+ return model_files
840
+
841
+ def _validate_model(
842
+ self,
843
+ import_model_details: ImportModelDetails = None,
844
+ model_name: str = None,
845
+ verified_model: DataScienceModel = None,
846
+ ) -> ModelValidationResult:
847
+ """
848
+ Validates the model configuration and returns the model format telemetry model name.
849
+
850
+ Args:
851
+ import_model_details (ImportModelDetails): Model details for importing the model.
852
+ model_name (str): name of the model
853
+ verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from
854
+ the service verified model
855
+
856
+ Returns:
857
+ ModelValidationResult: The result of the model validation.
858
+
859
+ Raises:
860
+ AquaRuntimeError: If there is an error while loading the config file or if the model path is incorrect.
861
+ AquaValueError: If the model format is not supported by AQUA.
862
+ """
863
+ model_formats = []
864
+ validation_result: ModelValidationResult = ModelValidationResult()
865
+
866
+ hf_download_config_present = False
867
+
868
+ if import_model_details.download_from_hf:
869
+ safetensors_model_files = self.get_hf_model_files(
870
+ model_name, ModelFormat.SAFETENSORS
871
+ )
872
+ if safetensors_model_files:
873
+ hf_download_config_present = True
874
+ gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF)
875
+ else:
876
+ safetensors_model_files = self.get_model_files(
877
+ import_model_details.os_path, ModelFormat.SAFETENSORS
878
+ )
879
+ gguf_model_files = self.get_model_files(
880
+ import_model_details.os_path, ModelFormat.GGUF
881
+ )
882
+
883
+ if not (safetensors_model_files or gguf_model_files):
884
+ raise AquaRuntimeError(
885
+ f"The model {model_name} does not contain either {ModelFormat.SAFETENSORS.value} "
886
+ f"or {ModelFormat.GGUF.value} files in {import_model_details.os_path} or Hugging Face repository. "
887
+ f"Please check if the path is correct or the model artifacts are available at this location."
888
+ )
889
+
890
+ if verified_model:
891
+ aqua_model = self.to_aqua_model(verified_model, self.region)
892
+ model_formats = aqua_model.model_formats
893
+ else:
894
+ if safetensors_model_files:
895
+ model_formats.append(ModelFormat.SAFETENSORS)
896
+ if gguf_model_files:
897
+ model_formats.append(ModelFormat.GGUF)
898
+
899
+ # get tags for models from hf
900
+ if import_model_details.download_from_hf:
901
+ model_info = get_hf_model_info(repo_id=model_name)
902
+
903
+ try:
904
+ license_value = UNKNOWN
905
+ if model_info.tags:
906
+ license_tag = next(
907
+ (
908
+ tag
909
+ for tag in model_info.tags
910
+ if tag.startswith("license:")
911
+ ),
912
+ UNKNOWN,
913
+ )
914
+ license_value = (
915
+ license_tag.split(":")[1] if license_tag else UNKNOWN
916
+ )
917
+
918
+ hf_tags = {
919
+ Tags.TASK: (model_info and model_info.pipeline_tag) or UNKNOWN,
920
+ Tags.ORGANIZATION: (
921
+ model_info.author
922
+ if model_info and hasattr(model_info, "author")
923
+ else UNKNOWN
924
+ ),
925
+ Tags.LICENSE: license_value,
926
+ }
927
+ validation_result.tags = hf_tags
928
+ except Exception:
929
+ pass
930
+
931
+ validation_result.model_formats = model_formats
932
+
933
+ # now as we know that at least one type of model files exist, validate the content of oss path.
934
+ # for safetensors, we check if config.json files exist, and for gguf format we check if files with
935
+ # gguf extension exist.
936
+ for model_format in model_formats:
937
+ if (
938
+ model_format == ModelFormat.SAFETENSORS
939
+ and len(safetensors_model_files) > 0
940
+ ):
941
+ if import_model_details.download_from_hf:
942
+ # validates config.json exists for safetensors model from hugginface
943
+ if not hf_download_config_present:
944
+ raise AquaRuntimeError(
945
+ f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
946
+ f"by {ModelFormat.SAFETENSORS.value} format model."
947
+ f" Please check if the model name is correct in Hugging Face repository."
948
+ )
949
+ else:
950
+ try:
951
+ model_config = load_config(
952
+ file_path=import_model_details.os_path,
953
+ config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
954
+ )
955
+ except Exception as ex:
956
+ logger.error(
957
+ f"Exception occurred while loading config file from {import_model_details.os_path}"
958
+ f"Exception message: {ex}"
959
+ )
960
+ raise AquaRuntimeError(
961
+ f"The model path {import_model_details.os_path} does not contain the file config.json. "
962
+ f"Please check if the path is correct or the model artifacts are available at this location."
963
+ ) from ex
964
+ else:
965
+ try:
966
+ metadata_model_type = (
967
+ verified_model.custom_metadata_list.get(
968
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
969
+ ).value
970
+ )
971
+ if metadata_model_type:
972
+ if (
973
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
974
+ in model_config
975
+ ):
976
+ if (
977
+ model_config[
978
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
979
+ ]
980
+ != metadata_model_type
981
+ ):
982
+ raise AquaRuntimeError(
983
+ f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
984
+ f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
985
+ f"the model {model_name}. Please check if the path is correct or "
986
+ f"the correct model artifacts are available at this location."
987
+ f""
988
+ )
989
+ else:
990
+ logger.debug(
991
+ f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
992
+ f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
993
+ )
994
+ except Exception:
995
+ pass
996
+ if verified_model:
997
+ validation_result.telemetry_model_name = (
998
+ verified_model.display_name
999
+ )
1000
+ elif (
1001
+ model_config is not None
1002
+ and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1003
+ ):
1004
+ validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1005
+ elif (
1006
+ model_config is not None
1007
+ and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1008
+ ):
1009
+ validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1010
+ else:
1011
+ validation_result.telemetry_model_name = (
1012
+ AQUA_MODEL_TYPE_CUSTOM
1013
+ )
1014
+ elif model_format == ModelFormat.GGUF and len(gguf_model_files) > 0:
1015
+ if import_model_details.finetuning_container and not safetensors_model_files:
1016
+ raise AquaValueError(
1017
+ "Fine-tuning is currently not supported with GGUF model format."
1018
+ )
1019
+ if verified_model:
1020
+ try:
1021
+ model_file = verified_model.custom_metadata_list.get(
1022
+ AQUA_MODEL_ARTIFACT_FILE
1023
+ ).value
1024
+ except ValueError as err:
1025
+ raise AquaRuntimeError(
1026
+ f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
1027
+ f"Please check if the model has the valid metadata."
1028
+ ) from err
1029
+ else:
1030
+ model_file = import_model_details.model_file
1031
+
1032
+ model_files = gguf_model_files
1033
+ # todo: have a separate error validation class for different type of error messages.
1034
+ if model_file:
1035
+ if model_file not in model_files:
1036
+ raise AquaRuntimeError(
1037
+ f"The model path {import_model_details.os_path} or the Hugging Face "
1038
+ f"model repository for {model_name} does not contain the file "
1039
+ f"{model_file}. Please check if the path is correct or the model "
1040
+ f"artifacts are available at this location."
1041
+ )
1042
+ else:
1043
+ validation_result.model_file = model_file
1044
+ elif len(model_files) == 0:
1045
+ raise AquaRuntimeError(
1046
+ f"The model path {import_model_details.os_path} or the Hugging Face model "
1047
+ f"repository for {model_name} does not contain any GGUF format files. "
1048
+ f"Please check if the path is correct or the model artifacts are available "
1049
+ f"at this location."
1050
+ )
1051
+ elif len(model_files) > 1:
1052
+ raise AquaRuntimeError(
1053
+ f"The model path {import_model_details.os_path} or the Hugging Face model "
1054
+ f"repository for {model_name} contains multiple GGUF format files. "
1055
+ f"Please specify the file that needs to be deployed using the model_file "
1056
+ f"parameter."
1057
+ )
1058
+ else:
1059
+ validation_result.model_file = model_files[0]
1060
+
1061
+ if verified_model:
1062
+ validation_result.telemetry_model_name = verified_model.display_name
1063
+ elif import_model_details.download_from_hf:
1064
+ validation_result.telemetry_model_name = model_name
1065
+ else:
1066
+ validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
1067
+
1068
+ return validation_result
1069
+
1070
+ @staticmethod
1071
+ def _download_model_from_hf(
1072
+ model_name: str,
1073
+ os_path: str,
1074
+ local_dir: str = None,
1075
+ ) -> str:
1076
+ """This helper function downloads the model artifact from Hugging Face to a local folder, then uploads
1077
+ to object storage location.
1078
+
1079
+ Parameters
1080
+ ----------
1081
+ model_name (str): The huggingface model name.
1082
+ os_path (str): The OS path where the model files are located.
1083
+ local_dir (str): The local temp dir to store the huggingface model.
1084
+
1085
+ Returns
1086
+ -------
1087
+ model_artifact_path (str): Location where the model artifacts are downloaded.
1088
+
1089
+ """
1090
+ # Download the model from hub
1091
+ if not local_dir:
1092
+ local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
1093
+ local_dir = os.path.join(local_dir, model_name)
1094
+ retry = 10
1095
+ i = 0
1096
+ huggingface_download_err_message = None
1097
+ while i < retry:
1098
+ try:
1099
+ # Download to cache folder. The while loop retries when there is a network failure
1100
+ snapshot_download(repo_id=model_name)
1101
+ except Exception as e:
1102
+ huggingface_download_err_message = str(e)
1103
+ i += 1
1104
+ else:
1105
+ break
1106
+ if i == retry:
1107
+ raise Exception(
1108
+ f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
1109
+ )
1110
+ os.makedirs(local_dir, exist_ok=True)
1111
+ # Copy the model from the cache to destination
1112
+ snapshot_download(repo_id=model_name, local_dir=local_dir)
1113
+ # Upload to object storage
1114
+ model_artifact_path = upload_folder(
1115
+ os_path=os_path,
1116
+ local_dir=local_dir,
1117
+ model_name=model_name,
1118
+ )
1119
+
1120
+ return model_artifact_path
1121
+
676
1122
  def register(
677
1123
  self, import_model_details: ImportModelDetails = None, **kwargs
678
1124
  ) -> AquaModel:
@@ -692,82 +1138,59 @@ class AquaModelApp(AquaApp):
692
1138
  AquaModel:
693
1139
  The registered model as a AquaModel object.
694
1140
  """
695
- verified_model_details: DataScienceModel = None
696
-
697
1141
  if not import_model_details:
698
1142
  import_model_details = ImportModelDetails(**kwargs)
699
1143
 
700
- try:
701
- model_config = load_config(
702
- file_path=import_model_details.os_path,
703
- config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
704
- )
705
- except Exception as ex:
706
- logger.error(
707
- f"Exception occurred while loading config file from {import_model_details.os_path}"
708
- f"Exception message: {ex}"
709
- )
710
- raise AquaRuntimeError(
711
- f"The model path {import_model_details.os_path} does not contain the file config.json. "
712
- f"Please check if the path is correct or the model artifacts are available at this location."
713
- )
714
-
715
- model_service_id = None
716
1144
  # If OCID of a model is passed, we need to copy the defaults for Tags and metadata from the service model.
1145
+ verified_model: Optional[DataScienceModel] = None
717
1146
  if (
718
1147
  import_model_details.model.startswith("ocid")
719
1148
  and "datasciencemodel" in import_model_details.model
720
1149
  ):
721
- model_service_id = import_model_details.model
1150
+ verified_model = DataScienceModel.from_id(import_model_details.model)
722
1151
  else:
723
1152
  # If users passes model name, check if there is model with the same name in the service model catalog. If it is there, then use that model
724
1153
  model_service_id = self._find_matching_aqua_model(
725
1154
  import_model_details.model
726
1155
  )
727
- logger.info(
728
- f"Found service model for {import_model_details.model}: {model_service_id}"
729
- )
730
- if model_service_id:
731
- verified_model_details = DataScienceModel.from_id(model_service_id)
732
- try:
733
- metadata_model_type = verified_model_details.custom_metadata_list.get(
734
- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
735
- ).value
736
- if metadata_model_type:
737
- if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
738
- if (
739
- model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
740
- != metadata_model_type
741
- ):
742
- raise AquaRuntimeError(
743
- f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
744
- f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
745
- f"the model {import_model_details.model}. Please check if the path is correct or "
746
- f"the correct model artifacts are available at this location."
747
- f""
748
- )
749
- else:
750
- logger.debug(
751
- f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
752
- f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
753
- )
754
- except:
755
- pass
1156
+ if model_service_id:
1157
+ logger.info(
1158
+ f"Found service model for {import_model_details.model}: {model_service_id}"
1159
+ )
1160
+ verified_model = DataScienceModel.from_id(model_service_id)
756
1161
 
757
1162
  # Copy the model name from the service model if `model` is ocid
758
1163
  model_name = (
759
- verified_model_details.display_name
760
- if verified_model_details
1164
+ verified_model.display_name
1165
+ if verified_model
761
1166
  else import_model_details.model
762
1167
  )
763
1168
 
1169
+ # validate model and artifact
1170
+ validation_result = self._validate_model(
1171
+ import_model_details=import_model_details,
1172
+ model_name=model_name,
1173
+ verified_model=verified_model,
1174
+ )
1175
+
1176
+ # download model from hugginface if indicates
1177
+ if import_model_details.download_from_hf:
1178
+ artifact_path = self._download_model_from_hf(
1179
+ model_name=model_name,
1180
+ os_path=import_model_details.os_path,
1181
+ local_dir=import_model_details.local_dir,
1182
+ ).rstrip("/")
1183
+ else:
1184
+ artifact_path = import_model_details.os_path.rstrip("/")
1185
+
764
1186
  # Create Model catalog entry with pass by reference
765
1187
  ds_model = self._create_model_catalog_entry(
766
- os_path=import_model_details.os_path,
1188
+ os_path=artifact_path,
767
1189
  model_name=model_name,
768
1190
  inference_container=import_model_details.inference_container,
769
1191
  finetuning_container=import_model_details.finetuning_container,
770
- verified_model=verified_model_details,
1192
+ verified_model=verified_model,
1193
+ validation_result=validation_result,
771
1194
  compartment_id=import_model_details.compartment_id,
772
1195
  project_id=import_model_details.project_id,
773
1196
  )
@@ -783,7 +1206,7 @@ class AquaModelApp(AquaApp):
783
1206
  finetuning_container = ds_model.custom_metadata_list.get(
784
1207
  ModelCustomMetadataFields.FINETUNE_CONTAINER,
785
1208
  ).value
786
- except:
1209
+ except Exception:
787
1210
  finetuning_container = None
788
1211
 
789
1212
  aqua_model_attributes = dict(
@@ -791,29 +1214,20 @@ class AquaModelApp(AquaApp):
791
1214
  project_id=ds_model.project_id,
792
1215
  model_card=str(
793
1216
  read_file(
794
- file_path=f"{import_model_details.os_path.rstrip('/')}/{README}",
1217
+ file_path=f"{artifact_path}/{README}",
795
1218
  auth=default_signer(),
796
1219
  )
797
1220
  ),
798
1221
  inference_container=inference_container,
799
1222
  finetuning_container=finetuning_container,
800
1223
  evaluation_container=evaluation_container,
1224
+ artifact_location=artifact_path,
801
1225
  )
802
1226
 
803
- if verified_model_details:
804
- telemetry_model_name = model_name
805
- else:
806
- if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config:
807
- telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
808
- elif AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
809
- telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
810
- else:
811
- telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
812
-
813
1227
  self.telemetry.record_event_async(
814
1228
  category="aqua/model",
815
1229
  action="register",
816
- detail=telemetry_model_name,
1230
+ detail=validation_result.telemetry_model_name,
817
1231
  )
818
1232
 
819
1233
  return AquaModel(**aqua_model_attributes)
@@ -856,7 +1270,8 @@ class AquaModelApp(AquaApp):
856
1270
  query, type=SEARCH_TYPE.STRUCTURED, tenant_id=TENANCY_OCID, **kwargs
857
1271
  )
858
1272
 
859
- def _build_search_text(self, tags: dict, description: str = None) -> str:
1273
+ @staticmethod
1274
+ def _build_search_text(tags: dict, description: str = None) -> str:
860
1275
  """Constructs search_text field in response."""
861
1276
  description = description or ""
862
1277
  tags_text = (