oracle-ads 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. ads/aqua/__init__.py +12 -0
  2. ads/aqua/base.py +324 -0
  3. ads/aqua/cli.py +19 -0
  4. ads/aqua/config/deployment_config_defaults.json +9 -0
  5. ads/aqua/config/resource_limit_names.json +7 -0
  6. ads/aqua/constants.py +45 -0
  7. ads/aqua/data.py +40 -0
  8. ads/aqua/decorator.py +101 -0
  9. ads/aqua/deployment.py +643 -0
  10. ads/aqua/dummy_data/icon.txt +1 -0
  11. ads/aqua/dummy_data/oci_model_deployments.json +56 -0
  12. ads/aqua/dummy_data/oci_models.json +1 -0
  13. ads/aqua/dummy_data/readme.md +26 -0
  14. ads/aqua/evaluation.py +1751 -0
  15. ads/aqua/exception.py +82 -0
  16. ads/aqua/extension/__init__.py +40 -0
  17. ads/aqua/extension/base_handler.py +138 -0
  18. ads/aqua/extension/common_handler.py +21 -0
  19. ads/aqua/extension/deployment_handler.py +202 -0
  20. ads/aqua/extension/evaluation_handler.py +135 -0
  21. ads/aqua/extension/finetune_handler.py +66 -0
  22. ads/aqua/extension/model_handler.py +59 -0
  23. ads/aqua/extension/ui_handler.py +201 -0
  24. ads/aqua/extension/utils.py +23 -0
  25. ads/aqua/finetune.py +579 -0
  26. ads/aqua/job.py +29 -0
  27. ads/aqua/model.py +819 -0
  28. ads/aqua/training/__init__.py +4 -0
  29. ads/aqua/training/exceptions.py +459 -0
  30. ads/aqua/ui.py +453 -0
  31. ads/aqua/utils.py +715 -0
  32. ads/cli.py +37 -6
  33. ads/common/decorator/__init__.py +7 -3
  34. ads/common/decorator/require_nonempty_arg.py +65 -0
  35. ads/common/object_storage_details.py +166 -7
  36. ads/common/oci_client.py +18 -1
  37. ads/common/oci_logging.py +2 -2
  38. ads/common/oci_mixin.py +4 -5
  39. ads/common/serializer.py +34 -5
  40. ads/common/utils.py +75 -10
  41. ads/config.py +40 -1
  42. ads/jobs/ads_job.py +43 -25
  43. ads/jobs/builders/infrastructure/base.py +4 -2
  44. ads/jobs/builders/infrastructure/dsc_job.py +49 -39
  45. ads/jobs/builders/runtimes/base.py +71 -1
  46. ads/jobs/builders/runtimes/container_runtime.py +4 -4
  47. ads/jobs/builders/runtimes/pytorch_runtime.py +10 -63
  48. ads/jobs/templates/driver_pytorch.py +27 -10
  49. ads/model/artifact_downloader.py +84 -14
  50. ads/model/artifact_uploader.py +25 -23
  51. ads/model/datascience_model.py +388 -38
  52. ads/model/deployment/model_deployment.py +10 -2
  53. ads/model/generic_model.py +8 -0
  54. ads/model/model_file_description_schema.json +68 -0
  55. ads/model/model_metadata.py +1 -1
  56. ads/model/service/oci_datascience_model.py +34 -5
  57. ads/opctl/operator/lowcode/anomaly/README.md +2 -1
  58. ads/opctl/operator/lowcode/anomaly/__main__.py +10 -4
  59. ads/opctl/operator/lowcode/anomaly/environment.yaml +2 -1
  60. ads/opctl/operator/lowcode/anomaly/model/automlx.py +12 -6
  61. ads/opctl/operator/lowcode/forecast/README.md +3 -2
  62. ads/opctl/operator/lowcode/forecast/environment.yaml +3 -2
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +12 -23
  64. ads/telemetry/base.py +62 -0
  65. ads/telemetry/client.py +105 -0
  66. ads/telemetry/telemetry.py +6 -3
  67. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/METADATA +37 -7
  68. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/RECORD +71 -36
  69. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/LICENSE.txt +0 -0
  70. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/WHEEL +0 -0
  71. {oracle_ads-2.10.1.dist-info → oracle_ads-2.11.0.dist-info}/entry_points.txt +0 -0
@@ -1,22 +1,35 @@
1
1
  #!/usr/bin/env python
2
2
  # -*- coding: utf-8; -*-
3
3
 
4
- # Copyright (c) 2022, 2023 Oracle and/or its affiliates.
4
+ # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
5
5
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
 
7
7
  import cgi
8
+ import json
8
9
  import logging
10
+ import os
11
+ import shutil
12
+ import tempfile
9
13
  from copy import deepcopy
10
- from typing import Dict, List, Optional, Union
14
+ from typing import Dict, List, Optional, Union, Tuple
11
15
 
12
16
  import pandas
17
+ from jsonschema import ValidationError, validate
18
+
13
19
  from ads.common import utils
14
20
  from ads.common.object_storage_details import ObjectStorageDetails
15
21
  from ads.config import COMPARTMENT_OCID, PROJECT_OCID
16
22
  from ads.feature_engineering.schema import Schema
17
23
  from ads.jobs.builders.base import Builder
24
+ from ads.model.artifact_downloader import (
25
+ LargeArtifactDownloader,
26
+ SmallArtifactDownloader,
27
+ )
28
+ from ads.model.artifact_uploader import LargeArtifactUploader, SmallArtifactUploader
18
29
  from ads.model.model_metadata import (
30
+ MetadataCustomCategory,
19
31
  ModelCustomMetadata,
32
+ ModelCustomMetadataItem,
20
33
  ModelProvenanceMetadata,
21
34
  ModelTaxonomyMetadata,
22
35
  )
@@ -24,16 +37,13 @@ from ads.model.service.oci_datascience_model import (
24
37
  ModelProvenanceNotFoundError,
25
38
  OCIDataScienceModel,
26
39
  )
27
- from ads.model.artifact_downloader import (
28
- LargeArtifactDownloader,
29
- SmallArtifactDownloader,
30
- )
31
- from ads.model.artifact_uploader import LargeArtifactUploader, SmallArtifactUploader
32
40
 
33
41
  logger = logging.getLogger(__name__)
34
42
 
35
43
 
36
44
  _MAX_ARTIFACT_SIZE_IN_BYTES = 2147483648 # 2GB
45
+ MODEL_BY_REFERENCE_VERSION = "1.0"
46
+ MODEL_BY_REFERENCE_JSON_FILE_NAME = "model_description.json"
37
47
 
38
48
 
39
49
  class ModelArtifactSizeError(Exception): # pragma: no cover
@@ -46,6 +56,23 @@ class ModelArtifactSizeError(Exception): # pragma: no cover
46
56
  )
47
57
 
48
58
 
59
class BucketNotVersionedError(Exception):  # pragma: no cover
    """Signals that a model artifact bucket does not have versioning enabled.

    The default message tells the user to enable versioning before creating
    a model by reference; a custom message may be supplied instead.
    """

    def __init__(
        self,
        msg="Model artifact bucket is not versioned. Enable versioning on the bucket to proceed with model creation by reference.",
    ):
        # Forward the (default or custom) message to the base exception.
        Exception.__init__(self, msg)
65
+
66
+
67
class ModelFileDescriptionError(Exception):  # pragma: no cover
    """Signals a missing or invalid model file description.

    Defaults to a message stating that the description file is not set up.
    """

    def __init__(self, msg="Model File Description file is not set up."):
        # Forward the (default or custom) message to the base exception.
        Exception.__init__(self, msg)
70
+
71
+
72
class InvalidArtifactType(Exception):  # pragma: no cover
    """Signals that the supplied artifact location(s) cannot be used as given."""
74
+
75
+
49
76
  class DataScienceModel(Builder):
50
77
  """Represents a Data Science Model.
51
78
 
@@ -84,6 +111,8 @@ class DataScienceModel(Builder):
84
111
  Model version set ID
85
112
  version_label: str
86
113
  Model version label
114
+ model_file_description: dict
115
+ Contains object path details for models created by reference.
87
116
 
88
117
  Methods
89
118
  -------
@@ -129,12 +158,16 @@ class DataScienceModel(Builder):
129
158
  Sets model custom metadata.
130
159
  with_provenance_metadata(self, metadata: Union[ModelProvenanceMetadata, Dict]) -> "DataScienceModel"
131
160
  Sets model provenance metadata.
132
- with_artifact(self, uri: str)
133
- Sets the artifact location. Can be a local.
161
+ with_artifact(self, *uri: str)
162
+ Sets the artifact location. Can be a local. For models created by reference, uri can take in single arg or multiple args in case
163
+ of a fine-tuned or multimodel setting.
134
164
  with_model_version_set_id(self, model_version_set_id: str):
135
165
  Sets the model version set ID.
136
166
  with_version_label(self, version_label: str):
137
167
  Sets the model version label.
168
+ with_model_file_description: dict
169
+ Sets path details for models created by reference. Input can be either a dict, string or json file and
170
+ the schema is dictated by model_file_description_schema.json
138
171
 
139
172
 
140
173
  Examples
@@ -170,7 +203,11 @@ class DataScienceModel(Builder):
170
203
  CONST_PROVENANCE_METADATA = "provenanceMetadata"
171
204
  CONST_ARTIFACT = "artifact"
172
205
  CONST_MODEL_VERSION_SET_ID = "modelVersionSetId"
206
+ CONST_MODEL_VERSION_SET_NAME = "modelVersionSetName"
173
207
  CONST_MODEL_VERSION_LABEL = "versionLabel"
208
+ CONST_TIME_CREATED = "timeCreated"
209
+ CONST_LIFECYCLE_STATE = "lifecycleState"
210
+ CONST_MODEL_FILE_DESCRIPTION = "modelDescription"
174
211
 
175
212
  attribute_map = {
176
213
  CONST_ID: "id",
@@ -187,7 +224,11 @@ class DataScienceModel(Builder):
187
224
  CONST_PROVENANCE_METADATA: "provenance_metadata",
188
225
  CONST_ARTIFACT: "artifact",
189
226
  CONST_MODEL_VERSION_SET_ID: "model_version_set_id",
227
+ CONST_MODEL_VERSION_SET_NAME: "model_version_set_name",
190
228
  CONST_MODEL_VERSION_LABEL: "version_label",
229
+ CONST_TIME_CREATED: "time_created",
230
+ CONST_LIFECYCLE_STATE: "lifecycle_state",
231
+ CONST_MODEL_FILE_DESCRIPTION: "model_file_description",
191
232
  }
192
233
 
193
234
  def __init__(self, spec: Dict = None, **kwargs) -> None:
@@ -221,6 +262,7 @@ class DataScienceModel(Builder):
221
262
  self._init_complex_attributes()
222
263
  # Specify oci datascience model instance
223
264
  self.dsc_model = self._to_oci_dsc_model()
265
+ self.local_copy_dir = None
224
266
 
225
267
  @property
226
268
  def id(self) -> Optional[str]:
@@ -242,6 +284,19 @@ class DataScienceModel(Builder):
242
284
  return self.dsc_model.status
243
285
  return None
244
286
 
287
@property
def lifecycle_state(self) -> Union[str, None]:
    """Lifecycle state of the model.

    Returns
    -------
    str
        The ``status`` reported by the attached ``dsc_model``, or ``None``
        when no ``dsc_model`` is set.
    """
    return self.dsc_model.status if self.dsc_model else None
299
+
245
300
  @property
246
301
  def kind(self) -> str:
247
302
  """The kind of the object as showing in a YAML."""
@@ -266,6 +321,10 @@ class DataScienceModel(Builder):
266
321
  """
267
322
  return self.set_spec(self.CONST_PROJECT_ID, project_id)
268
323
 
324
@property
def time_created(self) -> str:
    """Returns the value held under the ``timeCreated`` spec key."""
    spec_key = self.CONST_TIME_CREATED
    return self.get_spec(spec_key)
327
+
269
328
  @property
270
329
  def description(self) -> str:
271
330
  return self.get_spec(self.CONST_DESCRIPTION)
@@ -358,7 +417,7 @@ class DataScienceModel(Builder):
358
417
  return self.set_spec(self.CONST_DEFINED_TAG, kwargs)
359
418
 
360
419
  @property
361
- def input_schema(self) -> Schema:
420
+ def input_schema(self) -> Union[Schema, Dict]:
362
421
  """Returns model input schema.
363
422
 
364
423
  Returns
@@ -382,11 +441,15 @@ class DataScienceModel(Builder):
382
441
  The DataScienceModel instance (self)
383
442
  """
384
443
  if schema and isinstance(schema, Dict):
385
- schema = Schema.from_dict(schema)
444
+ try:
445
+ schema = Schema.from_dict(schema)
446
+ except Exception as err:
447
+ logger.warn(err)
448
+
386
449
  return self.set_spec(self.CONST_INPUT_SCHEMA, schema)
387
450
 
388
451
  @property
389
- def output_schema(self) -> Schema:
452
+ def output_schema(self) -> Union[Schema, Dict]:
390
453
  """Returns model output schema.
391
454
 
392
455
  Returns
@@ -410,7 +473,11 @@ class DataScienceModel(Builder):
410
473
  The DataScienceModel instance (self)
411
474
  """
412
475
  if schema and isinstance(schema, Dict):
413
- schema = Schema.from_dict(schema)
476
+ try:
477
+ schema = Schema.from_dict(schema)
478
+ except Exception as err:
479
+ logger.warn(err)
480
+
414
481
  return self.set_spec(self.CONST_OUTPUT_SCHEMA, schema)
415
482
 
416
483
  @property
@@ -486,10 +553,10 @@ class DataScienceModel(Builder):
486
553
  return self.set_spec(self.CONST_PROVENANCE_METADATA, metadata)
487
554
 
488
555
  @property
489
- def artifact(self) -> str:
556
+ def artifact(self) -> Union[str, list]:
490
557
  return self.get_spec(self.CONST_ARTIFACT)
491
558
 
492
- def with_artifact(self, uri: str):
559
+ def with_artifact(self, uri: str, *args):
493
560
  """Sets the artifact location. Can be a local.
494
561
 
495
562
  Parameters
@@ -498,13 +565,16 @@ class DataScienceModel(Builder):
498
565
  Path to artifact directory or to the ZIP archive.
499
566
  It could contain a serialized model(required) as well as any files needed for deployment.
500
567
  The content of the source folder will be zipped and uploaded to the model catalog.
501
-
568
+ For models created by reference, uri can take in single arg or multiple args in case of a fine-tuned or
569
+ multimodel setting.
502
570
  Examples
503
571
  --------
504
572
  >>> .with_artifact(uri="./model1/")
505
573
  >>> .with_artifact(uri="./model1.zip")
574
+ >>> .with_artifact("./model1", "./model2")
506
575
  """
507
- return self.set_spec(self.CONST_ARTIFACT, uri)
576
+
577
+ return self.set_spec(self.CONST_ARTIFACT, [uri] + list(args) if args else uri)
508
578
 
509
579
  @property
510
580
  def model_version_set_id(self) -> str:
@@ -520,6 +590,10 @@ class DataScienceModel(Builder):
520
590
  """
521
591
  return self.set_spec(self.CONST_MODEL_VERSION_SET_ID, model_version_set_id)
522
592
 
593
@property
def model_version_set_name(self) -> str:
    """Returns the value held under the ``modelVersionSetName`` spec key."""
    spec_key = self.CONST_MODEL_VERSION_SET_NAME
    return self.get_spec(spec_key)
596
+
523
597
  @property
524
598
  def version_label(self) -> str:
525
599
  return self.get_spec(self.CONST_MODEL_VERSION_LABEL)
@@ -534,6 +608,58 @@ class DataScienceModel(Builder):
534
608
  """
535
609
  return self.set_spec(self.CONST_MODEL_VERSION_LABEL, version_label)
536
610
 
611
@property
def model_file_description(self) -> dict:
    """Returns the value held under the ``modelDescription`` spec key.

    This holds the object path details for models created by reference.
    """
    spec_key = self.CONST_MODEL_FILE_DESCRIPTION
    return self.get_spec(spec_key)
614
+
615
def with_model_file_description(
    self, json_dict: dict = None, json_string: str = None, json_uri: str = None
):
    """Sets the json file description for model passed by reference.

    The first non-empty input among ``json_dict``, ``json_string`` and
    ``json_uri`` (checked in that order) is used. An empty dict or empty
    string is treated as "not provided". The selected content is validated
    against ``model_file_description_schema.json``, which is read from the
    directory containing this module.

    Parameters
    ----------
    json_dict : dict, optional
        json dict, by default None
    json_string : str, optional
        json string, by default None
    json_uri : str, optional
        URI location of file containing json, by default None

    Returns
    -------
    The result of ``set_spec`` for the model file description key
    (presumably the DataScienceModel instance, as for the other ``with_*``
    setters -- confirm against ``Builder.set_spec``).

    Raises
    ------
    ValueError
        If none of the three inputs is provided.
    ModelFileDescriptionError
        If the content does not conform to the schema.

    Examples
    --------
    >>> DataScienceModel().with_model_file_description(json_string="<json_string>")
    >>> DataScienceModel().with_model_file_description(json_dict=dict())
    >>> DataScienceModel().with_model_file_description(json_uri="./model_description.json")
    """
    if json_dict:
        json_data = json_dict
    elif json_string:
        json_data = json.loads(json_string)
    elif json_uri:
        # Explicit encoding so the result does not depend on the platform's
        # locale default (same as the schema file below).
        with open(json_uri, "r", encoding="utf-8") as json_file:
            json_data = json.load(json_file)
    else:
        raise ValueError(
            "Must provide either a valid json dict, json string or URI location."
        )

    schema_file_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "model_file_description_schema.json",
    )
    with open(schema_file_path, encoding="utf-8") as schema_file:
        schema = json.load(schema_file)

    try:
        validate(json_data, schema)
    except ValidationError as ve:
        message = (
            "model_file_description_schema.json validation failed. "
            f"See Exception: {ve}"
        )
        # Use the module logger (not the root logger) and keep the original
        # validation error attached as the cause.
        logger.error(message)
        raise ModelFileDescriptionError(message) from ve

    return self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, json_data)
662
+
537
663
  def create(self, **kwargs) -> "DataScienceModel":
538
664
  """Creates datascience model.
539
665
 
@@ -570,6 +696,9 @@ class DataScienceModel(Builder):
570
696
  The connection timeout in seconds for the client.
571
697
  parallel_process_count: (int, optional).
572
698
  The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
699
+ model_by_reference: (bool, optional)
700
+ Whether model artifact is made available to Model Store by reference. Requires artifact location to be
701
+ provided using with_artifact method.
573
702
 
574
703
  Returns
575
704
  -------
@@ -591,6 +720,23 @@ class DataScienceModel(Builder):
591
720
  if not self.display_name:
592
721
  self.display_name = self._random_display_name()
593
722
 
723
+ model_by_reference = kwargs.pop("model_by_reference", False)
724
+ if model_by_reference:
725
+ # Update custom metadata
726
+ logger.info("Update custom metadata field with model by reference flag.")
727
+ metadata_item = ModelCustomMetadataItem(
728
+ key=self.CONST_MODEL_FILE_DESCRIPTION,
729
+ value="true",
730
+ description="model by reference flag",
731
+ category=MetadataCustomCategory.OTHER,
732
+ )
733
+ if self.custom_metadata_list:
734
+ self.custom_metadata_list._add(metadata_item, replace=True)
735
+ else:
736
+ custom_metadata = ModelCustomMetadata()
737
+ custom_metadata._add(metadata_item)
738
+ self.with_custom_metadata_list(custom_metadata)
739
+
594
740
  payload = deepcopy(self._spec)
595
741
  payload.pop("id", None)
596
742
  logger.debug(f"Creating a model with payload {payload}")
@@ -616,6 +762,7 @@ class DataScienceModel(Builder):
616
762
  auth=kwargs.pop("auth", None),
617
763
  timeout=kwargs.pop("timeout", None),
618
764
  parallel_process_count=kwargs.pop("parallel_process_count", None),
765
+ model_by_reference=model_by_reference,
619
766
  )
620
767
 
621
768
  # Sync up model
@@ -633,6 +780,7 @@ class DataScienceModel(Builder):
633
780
  remove_existing_artifact: Optional[bool] = True,
634
781
  timeout: Optional[int] = None,
635
782
  parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
783
+ model_by_reference: Optional[bool] = False,
636
784
  ) -> None:
637
785
  """Uploads model artifacts to the model catalog.
638
786
 
@@ -663,9 +811,16 @@ class DataScienceModel(Builder):
663
811
  The connection timeout in seconds for the client.
664
812
  parallel_process_count: (int, optional)
665
813
  The number of worker processes to use in parallel for uploading individual parts of a multipart upload.
814
+ model_by_reference: (bool, optional)
815
+ Whether model artifact is made available to Model Store by reference.
666
816
  """
667
817
  # Upload artifact to the model catalog
668
- if not self.artifact:
818
+ if model_by_reference and self.model_file_description:
819
+ logger.info(
820
+ "Model artifact will be uploaded using model_file_description contents, "
821
+ "artifact location will not be used."
822
+ )
823
+ elif not self.artifact:
669
824
  logger.warn(
670
825
  "Model artifact location not provided. "
671
826
  "Provide the artifact location to upload artifacts to the model catalog."
@@ -679,14 +834,24 @@ class DataScienceModel(Builder):
679
834
  "timeout": timeout,
680
835
  }
681
836
 
682
- if ObjectStorageDetails.is_oci_path(self.artifact):
683
- if bucket_uri and bucket_uri != self.artifact:
684
- logger.warn(
685
- "The `bucket_uri` will be ignored and the value of `self.artifact` will be used instead."
837
+ if model_by_reference:
838
+ self._validate_prepare_file_description_artifact()
839
+ else:
840
+ if isinstance(self.artifact, list):
841
+ raise InvalidArtifactType(
842
+ "Multiple artifacts are only allowed for models created by reference."
686
843
  )
687
- bucket_uri = self.artifact
688
844
 
689
- if bucket_uri or utils.folder_size(self.artifact) > _MAX_ARTIFACT_SIZE_IN_BYTES:
845
+ if ObjectStorageDetails.is_oci_path(self.artifact):
846
+ if bucket_uri and bucket_uri != self.artifact:
847
+ logger.warn(
848
+ "The `bucket_uri` will be ignored and the value of `self.artifact` will be used instead."
849
+ )
850
+ bucket_uri = self.artifact
851
+
852
+ if not model_by_reference and (
853
+ bucket_uri or utils.folder_size(self.artifact) > _MAX_ARTIFACT_SIZE_IN_BYTES
854
+ ):
690
855
  if not bucket_uri:
691
856
  raise ModelArtifactSizeError(
692
857
  max_artifact_size=utils.human_size(_MAX_ARTIFACT_SIZE_IN_BYTES)
@@ -707,9 +872,16 @@ class DataScienceModel(Builder):
707
872
  dsc_model=self.dsc_model,
708
873
  artifact_path=self.artifact,
709
874
  )
710
-
711
875
  artifact_uploader.upload()
712
876
 
877
+ self._remove_file_description_artifact()
878
+
879
+ def _remove_file_description_artifact(self):
880
+ """Removes temporary model file description artifact for model by reference."""
881
+ # delete if local copy directory was created
882
+ if self.local_copy_dir:
883
+ shutil.rmtree(self.local_copy_dir, ignore_errors=True)
884
+
713
885
  def download_artifact(
714
886
  self,
715
887
  target_dir: str,
@@ -767,13 +939,37 @@ class DataScienceModel(Builder):
767
939
  **(self.dsc_model.__class__.kwargs or {}),
768
940
  "timeout": timeout,
769
941
  }
942
+ try:
943
+ model_by_reference = self.custom_metadata_list.get(
944
+ self.CONST_MODEL_FILE_DESCRIPTION
945
+ ).value
946
+ logging.info(
947
+ f"modelDescription tag found in custom metadata list with value {model_by_reference}"
948
+ )
949
+ except ValueError:
950
+ model_by_reference = False
770
951
 
771
- artifact_info = self.dsc_model.get_artifact_info()
772
- artifact_size = int(artifact_info.get("content-length"))
773
- if not bucket_uri and artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES:
774
- raise ModelArtifactSizeError(utils.human_size(_MAX_ARTIFACT_SIZE_IN_BYTES))
952
+ if model_by_reference:
953
+ _, artifact_size = self._download_file_description_artifact()
954
+ logging.warning(
955
+ f"Model {self.dsc_model.id} was created by reference, artifacts will be downloaded from the bucket {bucket_uri}"
956
+ )
957
+ # artifacts will be downloaded from model_file_description
958
+ bucket_uri = None
959
+ else:
960
+ artifact_info = self.dsc_model.get_artifact_info()
961
+ artifact_size = int(artifact_info.get("content-length"))
775
962
 
776
- if artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES or bucket_uri:
963
+ if not bucket_uri and artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES:
964
+ raise ModelArtifactSizeError(
965
+ utils.human_size(_MAX_ARTIFACT_SIZE_IN_BYTES)
966
+ )
967
+
968
+ if (
969
+ artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES
970
+ or bucket_uri
971
+ or model_by_reference
972
+ ):
777
973
  artifact_downloader = LargeArtifactDownloader(
778
974
  dsc_model=self.dsc_model,
779
975
  target_dir=target_dir,
@@ -783,6 +979,7 @@ class DataScienceModel(Builder):
783
979
  bucket_uri=bucket_uri,
784
980
  overwrite_existing_artifact=overwrite_existing_artifact,
785
981
  remove_existing_artifact=remove_existing_artifact,
982
+ model_file_description=self.model_file_description,
786
983
  )
787
984
  else:
788
985
  artifact_downloader = SmallArtifactDownloader(
@@ -790,7 +987,6 @@ class DataScienceModel(Builder):
790
987
  target_dir=target_dir,
791
988
  force_overwrite=force_overwrite,
792
989
  )
793
-
794
990
  artifact_downloader.download()
795
991
 
796
992
  def update(self, **kwargs) -> "DataScienceModel":
@@ -965,9 +1161,13 @@ class DataScienceModel(Builder):
965
1161
  for infra_attr, dsc_attr in self.attribute_map.items():
966
1162
  value = self.get_spec(infra_attr)
967
1163
  if infra_attr in COMPLEX_ATTRIBUTES_CONVERTER and value:
968
- dsc_spec[dsc_attr] = getattr(
969
- self.get_spec(infra_attr), COMPLEX_ATTRIBUTES_CONVERTER[infra_attr]
970
- )()
1164
+ if isinstance(value, dict):
1165
+ dsc_spec[dsc_attr] = json.dumps(value)
1166
+ else:
1167
+ dsc_spec[dsc_attr] = getattr(
1168
+ self.get_spec(infra_attr),
1169
+ COMPLEX_ATTRIBUTES_CONVERTER[infra_attr],
1170
+ )()
971
1171
  else:
972
1172
  dsc_spec[dsc_attr] = value
973
1173
 
@@ -990,8 +1190,8 @@ class DataScienceModel(Builder):
990
1190
  The DataScienceModel instance (self).
991
1191
  """
992
1192
  COMPLEX_ATTRIBUTES_CONVERTER = {
993
- self.CONST_INPUT_SCHEMA: Schema.from_json,
994
- self.CONST_OUTPUT_SCHEMA: Schema.from_json,
1193
+ self.CONST_INPUT_SCHEMA: [Schema.from_json, json.loads],
1194
+ self.CONST_OUTPUT_SCHEMA: [Schema.from_json, json.loads],
995
1195
  self.CONST_CUSTOM_METADATA: ModelCustomMetadata._from_oci_metadata,
996
1196
  self.CONST_DEFINED_METADATA: ModelTaxonomyMetadata._from_oci_metadata,
997
1197
  }
@@ -1002,7 +1202,16 @@ class DataScienceModel(Builder):
1002
1202
  value = utils.get_value(dsc_model, dsc_attr)
1003
1203
  if value:
1004
1204
  if infra_attr in COMPLEX_ATTRIBUTES_CONVERTER:
1005
- value = COMPLEX_ATTRIBUTES_CONVERTER[infra_attr](value)
1205
+ converter = COMPLEX_ATTRIBUTES_CONVERTER[infra_attr]
1206
+ if isinstance(converter, List):
1207
+ for converter_item in converter:
1208
+ try:
1209
+ value = converter_item(value)
1210
+ except Exception as err:
1211
+ logger.warn(err)
1212
+ pass
1213
+ else:
1214
+ value = converter(value)
1006
1215
  self.set_spec(infra_attr, value)
1007
1216
 
1008
1217
  # Update provenance metadata
@@ -1020,7 +1229,14 @@ class DataScienceModel(Builder):
1020
1229
  try:
1021
1230
  artifact_info = self.dsc_model.get_artifact_info()
1022
1231
  _, file_name_info = cgi.parse_header(artifact_info["Content-Disposition"])
1023
- self.set_spec(self.CONST_ARTIFACT, file_name_info["filename"])
1232
+
1233
+ if self.dsc_model.is_model_by_reference():
1234
+ _, file_extension = os.path.splitext(file_name_info["filename"])
1235
+ if file_extension.lower() == ".json":
1236
+ bucket_uri, _ = self._download_file_description_artifact()
1237
+ self.set_spec(self.CONST_ARTIFACT, bucket_uri)
1238
+ else:
1239
+ self.set_spec(self.CONST_ARTIFACT, file_name_info["filename"])
1024
1240
  except:
1025
1241
  pass
1026
1242
 
@@ -1088,3 +1304,137 @@ class DataScienceModel(Builder):
1088
1304
  if f"with_{item}" in self.__dir__():
1089
1305
  return self.get_spec(item)
1090
1306
  raise AttributeError(f"Attribute {item} not found.")
1307
+
1308
def _validate_prepare_file_description_artifact(self):
    """Prepares the model description json file used as the artifact of a model created by reference.

    When no model file description has been set, every artifact location is
    checked for bucket versioning, and the description is then generated
    from those locations. In all cases the description is written to
    ``model_description.json`` in a fresh temporary directory (remembered in
    ``local_copy_dir``), and that file becomes the artifact to upload.

    Raises
    ------
    BucketNotVersionedError
        If any artifact location is in a bucket without versioning.
    """
    if not self.model_file_description:
        locations = self.artifact
        if isinstance(locations, str):
            # Normalize a single location to a list.
            locations = [locations]

        for location in locations:
            storage = ObjectStorageDetails.from_path(location)
            if storage.is_bucket_versioned():
                continue
            message = f"Model artifact bucket {location} is not versioned. Enable versioning on the bucket to proceed with model creation by reference."
            logger.error(message)
            raise BucketNotVersionedError(message)

        self.with_model_file_description(
            json_dict=self._prepare_file_description_artifact(locations)
        )

    # Temporary directory that holds the generated description file.
    self.local_copy_dir = tempfile.mkdtemp()
    description_path = os.path.join(
        self.local_copy_dir, MODEL_BY_REFERENCE_JSON_FILE_NAME
    )
    with open(description_path, "w") as description_file:
        json.dump(self.model_file_description, description_file, indent=2)

    self.with_artifact(description_path)
1336
+
1337
@staticmethod
def _prepare_file_description_artifact(bucket_uri: list) -> dict:
    """Prepares the json description of a model that is passed by reference.

    For every location, the objects under it are listed together with their
    version id (looked up by etag) and size.

    Parameters
    ----------
    bucket_uri: list
        Object storage locations of the model files. Each must be an OCI
        object storage path and must not end with ``.zip``.

    Returns
    -------
    dict
        json dict with the model by reference artifact details

    Raises
    ------
    InvalidArtifactType
        If a location is not an OCI object storage path or is a zip file.
    ModelFileDescriptionError
        If no objects are found under a location, or if the version of a
        listed object cannot be determined.
    """
    # create json content
    content = {
        "version": MODEL_BY_REFERENCE_VERSION,
        "type": "modelOSSReferenceDescription",
        "models": [],
    }

    for uri in bucket_uri:
        if not ObjectStorageDetails.is_oci_path(uri) or uri.endswith(".zip"):
            msg = "Artifact path cannot be a zip file or local directory for model creation by reference."
            logger.error(msg)
            raise InvalidArtifactType(msg)

        # read list from objects from artifact location
        oss_details = ObjectStorageDetails.from_path(uri)

        # first retrieve the etag and version id
        object_versions = oss_details.list_object_versions(fields="etag")
        version_dict = {
            obj.etag: obj.version_id
            for obj in object_versions
            if obj.etag is not None
        }

        # add version id based on etag for each object
        objects = oss_details.list_objects(fields="name,etag,size").objects

        if not objects:
            raise ModelFileDescriptionError(
                f"The path {oss_details.path} does not exist or no objects were found in the path. "
            )

        object_list = []
        for obj in objects:
            # An object whose etag is missing from the version listing would
            # otherwise surface as a bare KeyError with no context.
            if obj.etag not in version_dict:
                raise ModelFileDescriptionError(
                    f"Unable to determine the version of object {obj.name} in the path {oss_details.path}."
                )
            object_list.append(
                {
                    "name": obj.name,
                    "version": version_dict[obj.etag],
                    "sizeInBytes": obj.size,
                }
            )

        content["models"].append(
            {
                "namespace": oss_details.namespace,
                "bucketName": oss_details.bucket,
                "prefix": oss_details.filepath,
                "objects": object_list,
            }
        )

    return content
1399
+
1400
+ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], int]:
1401
+ """Loads the json file from model artifact, updates the
1402
+ model file description property, and returns the bucket uri and artifact size details.
1403
+
1404
+ Returns
1405
+ -------
1406
+ bucket_uri: Union[str, List[str]]
1407
+ Location(s) of bucket where model artifacts are present
1408
+ artifact_size: int
1409
+ estimated size of the model files in bytes
1410
+
1411
+ """
1412
+ if not self.model_file_description:
1413
+ # get model file description from model artifact json
1414
+ with tempfile.TemporaryDirectory() as temp_dir:
1415
+ artifact_downloader = SmallArtifactDownloader(
1416
+ dsc_model=self.dsc_model,
1417
+ target_dir=temp_dir,
1418
+ )
1419
+ artifact_downloader.download()
1420
+ # create temp directory for model description file
1421
+ json_file_path = os.path.join(
1422
+ temp_dir, MODEL_BY_REFERENCE_JSON_FILE_NAME
1423
+ )
1424
+ self.with_model_file_description(json_uri=json_file_path)
1425
+
1426
+ model_file_desc_dict = self.model_file_description
1427
+ models = model_file_desc_dict["models"]
1428
+
1429
+ bucket_uri = list()
1430
+ artifact_size = 0
1431
+ for model in models:
1432
+ namespace = model["namespace"]
1433
+ bucket_name = model["bucketName"]
1434
+ prefix = model["prefix"]
1435
+ objects = model["objects"]
1436
+ uri = f"oci://{bucket_name}@{namespace}/{prefix}"
1437
+ artifact_size += sum([obj["sizeInBytes"] for obj in objects])
1438
+ bucket_uri.append(uri)
1439
+
1440
+ return bucket_uri[0] if len(bucket_uri) == 1 else bucket_uri, artifact_size