oracle-ads 2.11.9__py3-none-any.whl → 2.11.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. ads/aqua/__init__.py +1 -1
  2. ads/aqua/{base.py → app.py} +27 -7
  3. ads/aqua/cli.py +59 -17
  4. ads/aqua/common/__init__.py +5 -0
  5. ads/aqua/{decorator.py → common/decorator.py} +14 -8
  6. ads/aqua/common/enums.py +69 -0
  7. ads/aqua/{exception.py → common/errors.py} +28 -0
  8. ads/aqua/{utils.py → common/utils.py} +168 -77
  9. ads/aqua/config/config.py +18 -0
  10. ads/aqua/constants.py +51 -33
  11. ads/aqua/data.py +15 -26
  12. ads/aqua/evaluation/__init__.py +8 -0
  13. ads/aqua/evaluation/constants.py +53 -0
  14. ads/aqua/evaluation/entities.py +170 -0
  15. ads/aqua/evaluation/errors.py +71 -0
  16. ads/aqua/{evaluation.py → evaluation/evaluation.py} +122 -370
  17. ads/aqua/extension/__init__.py +2 -0
  18. ads/aqua/extension/aqua_ws_msg_handler.py +97 -0
  19. ads/aqua/extension/base_handler.py +0 -7
  20. ads/aqua/extension/common_handler.py +12 -6
  21. ads/aqua/extension/deployment_handler.py +70 -4
  22. ads/aqua/extension/errors.py +10 -0
  23. ads/aqua/extension/evaluation_handler.py +5 -3
  24. ads/aqua/extension/evaluation_ws_msg_handler.py +43 -0
  25. ads/aqua/extension/finetune_handler.py +41 -3
  26. ads/aqua/extension/model_handler.py +56 -4
  27. ads/aqua/extension/models/__init__.py +0 -0
  28. ads/aqua/extension/models/ws_models.py +69 -0
  29. ads/aqua/extension/ui_handler.py +65 -4
  30. ads/aqua/extension/ui_websocket_handler.py +124 -0
  31. ads/aqua/extension/utils.py +1 -1
  32. ads/aqua/finetuning/__init__.py +7 -0
  33. ads/aqua/finetuning/constants.py +17 -0
  34. ads/aqua/finetuning/entities.py +102 -0
  35. ads/aqua/{finetune.py → finetuning/finetuning.py} +162 -136
  36. ads/aqua/model/__init__.py +8 -0
  37. ads/aqua/model/constants.py +46 -0
  38. ads/aqua/model/entities.py +266 -0
  39. ads/aqua/model/enums.py +26 -0
  40. ads/aqua/{model.py → model/model.py} +401 -309
  41. ads/aqua/modeldeployment/__init__.py +8 -0
  42. ads/aqua/modeldeployment/constants.py +26 -0
  43. ads/aqua/{deployment.py → modeldeployment/deployment.py} +288 -227
  44. ads/aqua/modeldeployment/entities.py +142 -0
  45. ads/aqua/modeldeployment/inference.py +75 -0
  46. ads/aqua/ui.py +88 -8
  47. ads/cli.py +55 -7
  48. ads/common/serializer.py +2 -2
  49. ads/config.py +2 -1
  50. ads/jobs/builders/infrastructure/dsc_job.py +49 -6
  51. ads/model/datascience_model.py +1 -1
  52. ads/model/deployment/model_deployment.py +11 -0
  53. ads/model/model_metadata.py +17 -6
  54. ads/opctl/operator/lowcode/anomaly/README.md +0 -2
  55. ads/opctl/operator/lowcode/anomaly/__main__.py +3 -3
  56. ads/opctl/operator/lowcode/anomaly/environment.yaml +0 -2
  57. ads/opctl/operator/lowcode/anomaly/model/automlx.py +2 -2
  58. ads/opctl/operator/lowcode/anomaly/model/autots.py +1 -1
  59. ads/opctl/operator/lowcode/anomaly/model/base_model.py +13 -17
  60. ads/opctl/operator/lowcode/anomaly/operator_config.py +2 -0
  61. ads/opctl/operator/lowcode/anomaly/schema.yaml +1 -2
  62. ads/opctl/operator/lowcode/anomaly/utils.py +3 -2
  63. ads/opctl/operator/lowcode/common/transformations.py +2 -1
  64. ads/opctl/operator/lowcode/common/utils.py +1 -1
  65. ads/opctl/operator/lowcode/forecast/README.md +1 -3
  66. ads/opctl/operator/lowcode/forecast/__main__.py +3 -18
  67. ads/opctl/operator/lowcode/forecast/const.py +2 -0
  68. ads/opctl/operator/lowcode/forecast/environment.yaml +1 -2
  69. ads/opctl/operator/lowcode/forecast/model/arima.py +1 -0
  70. ads/opctl/operator/lowcode/forecast/model/automlx.py +7 -4
  71. ads/opctl/operator/lowcode/forecast/model/autots.py +1 -0
  72. ads/opctl/operator/lowcode/forecast/model/base_model.py +38 -22
  73. ads/opctl/operator/lowcode/forecast/model/factory.py +33 -4
  74. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +15 -1
  75. ads/opctl/operator/lowcode/forecast/model/ml_forecast.py +234 -0
  76. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +9 -1
  77. ads/opctl/operator/lowcode/forecast/model/prophet.py +1 -0
  78. ads/opctl/operator/lowcode/forecast/model_evaluator.py +147 -0
  79. ads/opctl/operator/lowcode/forecast/operator_config.py +2 -1
  80. ads/opctl/operator/lowcode/forecast/schema.yaml +7 -2
  81. ads/opctl/operator/lowcode/forecast/utils.py +18 -44
  82. {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.10.dist-info}/METADATA +9 -12
  83. {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.10.dist-info}/RECORD +86 -61
  84. ads/aqua/job.py +0 -29
  85. {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.10.dist-info}/LICENSE.txt +0 -0
  86. {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.10.dist-info}/WHEEL +0 -0
  87. {oracle_ads-2.11.9.dist-info → oracle_ads-2.11.10.dist-info}/entry_points.txt +0 -0
@@ -5,158 +5,68 @@
5
5
 
6
6
  import json
7
7
  import logging
8
- from dataclasses import dataclass, field, asdict
9
8
  from typing import Dict, List, Union
10
9
 
11
- import requests
12
- from oci.data_science.models import ModelDeployment, ModelDeploymentSummary
10
+ from oci.data_science.models import ModelDeployment
13
11
 
14
- from ads.aqua.base import AquaApp, logger
15
- from ads.aqua.exception import AquaRuntimeError, AquaValueError
16
- from ads.aqua.model import AquaModelApp, Tags
17
- from ads.aqua.utils import (
18
- UNKNOWN,
19
- MODEL_BY_REFERENCE_OSS_PATH_KEY,
20
- load_config,
12
+ from ads.aqua.app import AquaApp, logger
13
+ from ads.aqua.common.enums import (
14
+ Tags,
15
+ InferenceContainerParamType,
16
+ InferenceContainerType,
17
+ InferenceContainerTypeFamily,
18
+ )
19
+ from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
20
+ from ads.aqua.common.utils import (
21
+ get_container_config,
21
22
  get_container_image,
22
- UNKNOWN_DICT,
23
- get_resource_name,
24
23
  get_model_by_reference_paths,
25
24
  get_ocid_substring,
26
- AQUA_MODEL_TYPE_SERVICE,
25
+ get_combined_params,
26
+ get_params_dict,
27
+ get_params_list,
28
+ get_resource_name,
29
+ load_config,
30
+ )
31
+ from ads.aqua.constants import (
27
32
  AQUA_MODEL_TYPE_CUSTOM,
33
+ AQUA_MODEL_TYPE_SERVICE,
34
+ MODEL_BY_REFERENCE_OSS_PATH_KEY,
35
+ UNKNOWN,
36
+ UNKNOWN_DICT,
28
37
  )
29
- from ads.aqua.finetune import FineTuneCustomMetadata
30
38
  from ads.aqua.data import AquaResourceIdentifier
31
- from ads.common.utils import get_console_link, get_log_links
32
- from ads.common.auth import default_signer
39
+ from ads.aqua.finetuning.finetuning import FineTuneCustomMetadata
40
+ from ads.aqua.model import AquaModelApp
41
+ from ads.aqua.modeldeployment.entities import (
42
+ AquaDeployment,
43
+ AquaDeploymentDetail,
44
+ ContainerSpec,
45
+ )
46
+ from ads.aqua.modeldeployment.constants import (
47
+ VLLMInferenceRestrictedParams,
48
+ TGIInferenceRestrictedParams,
49
+ )
50
+ from ads.common.object_storage_details import ObjectStorageDetails
51
+ from ads.common.utils import get_log_links
52
+ from ads.config import (
53
+ AQUA_CONFIG_FOLDER,
54
+ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
55
+ AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
56
+ AQUA_MODEL_DEPLOYMENT_CONFIG,
57
+ AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
58
+ COMPARTMENT_OCID,
59
+ )
60
+ from ads.model.datascience_model import DataScienceModel
33
61
  from ads.model.deployment import (
34
62
  ModelDeployment,
35
63
  ModelDeploymentContainerRuntime,
36
64
  ModelDeploymentInfrastructure,
37
65
  ModelDeploymentMode,
38
66
  )
39
- from ads.common.serializer import DataClassSerializable
40
- from ads.config import (
41
- AQUA_MODEL_DEPLOYMENT_CONFIG,
42
- COMPARTMENT_OCID,
43
- AQUA_CONFIG_FOLDER,
44
- AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
45
- AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
46
- AQUA_SERVED_MODEL_NAME,
47
- )
48
- from ads.common.object_storage_details import ObjectStorageDetails
49
67
  from ads.telemetry import telemetry
50
68
 
51
69
 
52
- @dataclass
53
- class ShapeInfo:
54
- instance_shape: str = None
55
- instance_count: int = None
56
- ocpus: float = None
57
- memory_in_gbs: float = None
58
-
59
-
60
- @dataclass(repr=False)
61
- class AquaDeployment(DataClassSerializable):
62
- """Represents an Aqua Model Deployment"""
63
-
64
- id: str = None
65
- display_name: str = None
66
- aqua_service_model: bool = None
67
- aqua_model_name: str = None
68
- state: str = None
69
- description: str = None
70
- created_on: str = None
71
- created_by: str = None
72
- endpoint: str = None
73
- console_link: str = None
74
- lifecycle_details: str = None
75
- shape_info: field(default_factory=ShapeInfo) = None
76
- tags: dict = None
77
-
78
- @classmethod
79
- def from_oci_model_deployment(
80
- cls,
81
- oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment],
82
- region: str,
83
- ) -> "AquaDeployment":
84
- """Converts oci model deployment response to AquaDeployment instance.
85
-
86
- Parameters
87
- ----------
88
- oci_model_deployment: Union[ModelDeploymentSummary, ModelDeployment]
89
- The instance of either oci.data_science.models.ModelDeployment or
90
- oci.data_science.models.ModelDeploymentSummary class.
91
- region: str
92
- The region of this model deployment.
93
-
94
- Returns
95
- -------
96
- AquaDeployment:
97
- The instance of the Aqua model deployment.
98
- """
99
- instance_configuration = (
100
- oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
101
- )
102
- instance_shape_config_details = (
103
- instance_configuration.model_deployment_instance_shape_config_details
104
- )
105
- instance_count = (
106
- oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
107
- )
108
- shape_info = ShapeInfo(
109
- instance_shape=instance_configuration.instance_shape_name,
110
- instance_count=instance_count,
111
- ocpus=(
112
- instance_shape_config_details.ocpus
113
- if instance_shape_config_details
114
- else None
115
- ),
116
- memory_in_gbs=(
117
- instance_shape_config_details.memory_in_gbs
118
- if instance_shape_config_details
119
- else None
120
- ),
121
- )
122
-
123
- freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT
124
- aqua_service_model_tag = freeform_tags.get(
125
- Tags.AQUA_SERVICE_MODEL_TAG.value, None
126
- )
127
- aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG.value, UNKNOWN)
128
-
129
- return AquaDeployment(
130
- id=oci_model_deployment.id,
131
- display_name=oci_model_deployment.display_name,
132
- aqua_service_model=aqua_service_model_tag is not None,
133
- aqua_model_name=aqua_model_name,
134
- shape_info=shape_info,
135
- state=oci_model_deployment.lifecycle_state,
136
- lifecycle_details=getattr(
137
- oci_model_deployment, "lifecycle_details", UNKNOWN
138
- ),
139
- description=oci_model_deployment.description,
140
- created_on=str(oci_model_deployment.time_created),
141
- created_by=oci_model_deployment.created_by,
142
- endpoint=oci_model_deployment.model_deployment_url,
143
- console_link=get_console_link(
144
- resource="model-deployments",
145
- ocid=oci_model_deployment.id,
146
- region=region,
147
- ),
148
- tags=freeform_tags,
149
- )
150
-
151
-
152
- @dataclass(repr=False)
153
- class AquaDeploymentDetail(AquaDeployment, DataClassSerializable):
154
- """Represents a details of Aqua deployment."""
155
-
156
- log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
157
- log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
158
-
159
-
160
70
  class AquaDeploymentApp(AquaApp):
161
71
  """Provides a suite of APIs to interact with Aqua model deployments within the Oracle
162
72
  Cloud Infrastructure Data Science service, serving as an interface for deploying
@@ -196,9 +106,10 @@ class AquaDeploymentApp(AquaApp):
196
106
  description: str = None,
197
107
  bandwidth_mbps: int = None,
198
108
  web_concurrency: int = None,
199
- server_port: int = 8080,
200
- health_check_port: int = 8080,
109
+ server_port: int = None,
110
+ health_check_port: int = None,
201
111
  env_var: Dict = None,
112
+ container_family: str = None,
202
113
  ) -> "AquaDeployment":
203
114
  """
204
115
  Creates a new Aqua deployment
@@ -231,18 +142,21 @@ class AquaDeploymentApp(AquaApp):
231
142
  The number of worker processes/threads to handle incoming requests
232
143
  with_bucket_uri(bucket_uri)
233
144
  Sets the bucket uri when uploading large size model.
234
- server_port: (int). Defaults to 8080.
145
+ server_port: (int).
235
146
  The server port for docker container image.
236
- health_check_port: (int). Defaults to 8080.
147
+ health_check_port: (int).
237
148
  The health check port for docker container image.
238
149
  env_var : dict, optional
239
150
  Environment variable for the deployment, by default None.
151
+ container_family: str
152
+ The image family of model deployment container runtime. Required for unverified Aqua models.
240
153
  Returns
241
154
  -------
242
155
  AquaDeployment
243
156
  An Aqua deployment instance
244
157
 
245
158
  """
159
+ # TODO validate if the service model has no artifact and if it requires import step before deployment.
246
160
  # Create a model catalog entry in the user compartment
247
161
  aqua_model = AquaModelApp().create(
248
162
  model_id=model_id, compartment_id=compartment_id, project_id=project_id
@@ -250,45 +164,35 @@ class AquaDeploymentApp(AquaApp):
250
164
 
251
165
  tags = {}
252
166
  for tag in [
253
- Tags.AQUA_SERVICE_MODEL_TAG.value,
254
- Tags.AQUA_FINE_TUNED_MODEL_TAG.value,
255
- Tags.AQUA_TAG.value,
167
+ Tags.AQUA_SERVICE_MODEL_TAG,
168
+ Tags.AQUA_FINE_TUNED_MODEL_TAG,
169
+ Tags.AQUA_TAG,
256
170
  ]:
257
171
  if tag in aqua_model.freeform_tags:
258
172
  tags[tag] = aqua_model.freeform_tags[tag]
259
173
 
260
- tags.update({Tags.AQUA_MODEL_NAME_TAG.value: aqua_model.display_name})
174
+ tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
261
175
 
262
176
  # Set up info to get deployment config
263
177
  config_source_id = model_id
264
178
  model_name = aqua_model.display_name
265
179
 
266
- is_fine_tuned_model = (
267
- Tags.AQUA_FINE_TUNED_MODEL_TAG.value in aqua_model.freeform_tags
268
- )
180
+ is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in aqua_model.freeform_tags
269
181
 
270
182
  if is_fine_tuned_model:
271
183
  try:
272
184
  config_source_id = aqua_model.custom_metadata_list.get(
273
- FineTuneCustomMetadata.FINE_TUNE_SOURCE.value
185
+ FineTuneCustomMetadata.FINE_TUNE_SOURCE
274
186
  ).value
275
187
  model_name = aqua_model.custom_metadata_list.get(
276
- FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME.value
188
+ FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME
277
189
  ).value
278
190
  except:
279
191
  raise AquaValueError(
280
- f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE.value} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME.value} is missing "
192
+ f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME} is missing "
281
193
  f"from custom metadata for the model {config_source_id}"
282
194
  )
283
195
 
284
- deployment_config = self.get_deployment_config(config_source_id)
285
- vllm_params = (
286
- deployment_config.get("configuration", UNKNOWN_DICT)
287
- .get(instance_shape, UNKNOWN_DICT)
288
- .get("parameters", UNKNOWN_DICT)
289
- .get("VLLM_PARAMS", UNKNOWN)
290
- )
291
-
292
196
  # set up env vars
293
197
  if not env_var:
294
198
  env_var = dict()
@@ -302,18 +206,11 @@ class AquaDeploymentApp(AquaApp):
302
206
  f"{MODEL_BY_REFERENCE_OSS_PATH_KEY} key is not available in the custom metadata field."
303
207
  )
304
208
 
305
- # todo: remove this after absolute path is removed from env var
306
209
  if ObjectStorageDetails.is_oci_path(model_path_prefix):
307
210
  os_path = ObjectStorageDetails.from_path(model_path_prefix)
308
211
  model_path_prefix = os_path.filepath.rstrip("/")
309
212
 
310
213
  env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
311
- params = f"--served-model-name {AQUA_SERVED_MODEL_NAME} --seed 42 "
312
- if vllm_params:
313
- params += vllm_params
314
- env_var.update({"PARAMS": params})
315
- env_var.update({"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"})
316
- env_var.update({"MODEL_DEPLOY_ENABLE_STREAMING": "true"})
317
214
 
318
215
  if is_fine_tuned_model:
319
216
  _, fine_tune_output_path = get_model_by_reference_paths(
@@ -330,28 +227,94 @@ class AquaDeploymentApp(AquaApp):
330
227
 
331
228
  env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
332
229
 
333
- logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
334
-
230
+ is_custom_container = False
335
231
  try:
336
232
  container_type_key = aqua_model.custom_metadata_list.get(
337
233
  AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
338
234
  ).value
339
235
  except ValueError:
340
- raise AquaValueError(
341
- f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {aqua_model.id}"
236
+ message = (
237
+ f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
238
+ f"for model {aqua_model.id}."
342
239
  )
240
+ logger.debug(message)
241
+ if not container_family:
242
+ raise AquaValueError(
243
+ f"{message}. For unverified Aqua models, container_family parameter should be "
244
+ f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
245
+ )
246
+ container_type_key = container_family
247
+ try:
248
+ # Check if the container override flag is set. If set, then the user has chosen custom image
249
+ if aqua_model.custom_metadata_list.get(
250
+ AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME
251
+ ).value:
252
+ is_custom_container = True
253
+ except Exception:
254
+ pass
343
255
 
344
256
  # fetch image name from config
345
- container_image = get_container_image(
346
- container_type=container_type_key,
257
+ # If the image is of type custom, then `container_type_key` is the inference image
258
+ container_image = (
259
+ get_container_image(
260
+ container_type=container_type_key,
261
+ )
262
+ if not is_custom_container
263
+ else container_type_key
347
264
  )
348
265
  logging.info(
349
266
  f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
350
267
  )
351
268
 
269
+ # Fetch the startup cli command for the container
270
+ # container_index.json will have "containerSpec" section which will provide the cli params for a given container family
271
+ container_config = get_container_config()
272
+ container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
273
+ container_type_key, {}
274
+ )
275
+ # these params cannot be overridden for Aqua deployments
276
+ params = container_spec.get(ContainerSpec.CLI_PARM, "")
277
+ server_port = server_port or container_spec.get(
278
+ ContainerSpec.SERVER_PORT
279
+ ) # Give precendece to the input parameter
280
+ health_check_port = health_check_port or container_spec.get(
281
+ ContainerSpec.HEALTH_CHECK_PORT
282
+ ) # Give precendece to the input parameter
283
+
284
+ deployment_config = self.get_deployment_config(config_source_id)
285
+ vllm_params = (
286
+ deployment_config.get("configuration", UNKNOWN_DICT)
287
+ .get(instance_shape, UNKNOWN_DICT)
288
+ .get("parameters", UNKNOWN_DICT)
289
+ .get(InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN)
290
+ )
291
+
292
+ # validate user provided params
293
+ user_params = env_var.get("PARAMS", UNKNOWN)
294
+ if user_params:
295
+ restricted_params = self._find_restricted_params(
296
+ params, user_params, container_type_key
297
+ )
298
+ if restricted_params:
299
+ raise AquaValueError(
300
+ f"Parameters {restricted_params} are set by Aqua "
301
+ f"and cannot be overridden or are invalid."
302
+ )
303
+
304
+ deployment_params = get_combined_params(vllm_params, user_params)
305
+
306
+ if deployment_params:
307
+ params = f"{params} {deployment_params}"
308
+
309
+ env_var.update({"PARAMS": params})
310
+ for env in container_spec.get(ContainerSpec.ENV_VARS, []):
311
+ if isinstance(env, dict):
312
+ env_var.update(env)
313
+
314
+ logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
315
+
352
316
  # Start model deployment
353
317
  # configure model deployment infrastructure
354
- # todo : any other infrastructure params needed?
355
318
  infrastructure = (
356
319
  ModelDeploymentInfrastructure()
357
320
  .with_project_id(project_id)
@@ -370,7 +333,6 @@ class AquaDeploymentApp(AquaApp):
370
333
  )
371
334
  )
372
335
  # configure model deployment runtime
373
- # todo : any other runtime params needed?
374
336
  container_runtime = (
375
337
  ModelDeploymentContainerRuntime()
376
338
  .with_image(container_image)
@@ -384,7 +346,6 @@ class AquaDeploymentApp(AquaApp):
384
346
  .with_remove_existing_artifact(True)
385
347
  )
386
348
  # configure model deployment and deploy model on container runtime
387
- # todo : any other deployment params needed?
388
349
  deployment = (
389
350
  ModelDeployment()
390
351
  .with_display_name(display_name)
@@ -447,8 +408,8 @@ class AquaDeploymentApp(AquaApp):
447
408
  for model_deployment in model_deployments:
448
409
  oci_aqua = (
449
410
  (
450
- Tags.AQUA_TAG.value in model_deployment.freeform_tags
451
- or Tags.AQUA_TAG.value.lower() in model_deployment.freeform_tags
411
+ Tags.AQUA_TAG in model_deployment.freeform_tags
412
+ or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
452
413
  )
453
414
  if model_deployment.freeform_tags
454
415
  else False
@@ -502,8 +463,8 @@ class AquaDeploymentApp(AquaApp):
502
463
 
503
464
  oci_aqua = (
504
465
  (
505
- Tags.AQUA_TAG.value in model_deployment.freeform_tags
506
- or Tags.AQUA_TAG.value.lower() in model_deployment.freeform_tags
466
+ Tags.AQUA_TAG in model_deployment.freeform_tags
467
+ or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
507
468
  )
508
469
  if model_deployment.freeform_tags
509
470
  else False
@@ -575,71 +536,171 @@ class AquaDeploymentApp(AquaApp):
575
536
  )
576
537
  return config
577
538
 
539
+ def get_deployment_default_params(
540
+ self,
541
+ model_id: str,
542
+ instance_shape: str,
543
+ ) -> List[str]:
544
+ """Gets the default params set in the deployment configs for the given model and instance shape.
578
545
 
579
- @dataclass
580
- class ModelParams:
581
- max_tokens: int = None
582
- temperature: float = None
583
- top_k: float = None
584
- top_p: float = None
585
- model: str = None
546
+ Parameters
547
+ ----------
548
+ model_id: str
549
+ The OCID of the Aqua model.
586
550
 
551
+ instance_shape: (str).
552
+ The shape of the instance used for deployment.
553
+
554
+ Returns
555
+ -------
556
+ List[str]:
557
+ List of parameters from the loaded from deployment config json file. If not available, then an empty list
558
+ is returned.
559
+
560
+ """
561
+ default_params = []
562
+ model = DataScienceModel.from_id(model_id)
563
+ try:
564
+ container_type_key = model.custom_metadata_list.get(
565
+ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
566
+ ).value
567
+ except ValueError:
568
+ container_type_key = UNKNOWN
569
+ logger.debug(
570
+ f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {model_id}."
571
+ )
587
572
 
588
- @dataclass
589
- class MDInferenceResponse(AquaApp):
590
- """Contains APIs for Aqua Model deployments Inference.
573
+ if container_type_key:
574
+ container_type_key = container_type_key.lower()
575
+ if container_type_key in InferenceContainerTypeFamily.values():
576
+ deployment_config = self.get_deployment_config(model_id)
577
+ config_parameters = (
578
+ deployment_config.get("configuration", UNKNOWN_DICT)
579
+ .get(instance_shape, UNKNOWN_DICT)
580
+ .get("parameters", UNKNOWN_DICT)
581
+ )
582
+ if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_key:
583
+ params = config_parameters.get(
584
+ InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN
585
+ )
586
+ elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_key:
587
+ params = config_parameters.get(
588
+ InferenceContainerParamType.PARAM_TYPE_TGI, UNKNOWN
589
+ )
590
+ else:
591
+ params = UNKNOWN
592
+ logger.debug(
593
+ f"Default inference parameters are not available for the model {model_id} and "
594
+ f"instance {instance_shape}."
595
+ )
596
+ if params:
597
+ # account for param that can have --arg but no values, e.g. --trust-remote-code
598
+ default_params.extend(get_params_list(params))
591
599
 
592
- Attributes
593
- ----------
600
+ return default_params
594
601
 
595
- model_params: Dict
596
- prompt: string
602
+ def validate_deployment_params(
603
+ self,
604
+ model_id: str,
605
+ params: List[str] = None,
606
+ container_family: str = None,
607
+ ) -> Dict:
608
+ """Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
609
+ validated, only param keys are validated.
597
610
 
598
- Methods
599
- -------
600
- get_model_deployment_response(self, **kwargs) -> "String"
601
- Creates an instance of model deployment via Aqua
602
- """
611
+ Parameters
612
+ ----------
613
+ model_id: str
614
+ The OCID of the Aqua model.
615
+ params : List[str], optional
616
+ Params passed by the user.
617
+ container_family: str
618
+ The image family of model deployment container runtime. Required for unverified Aqua models.
603
619
 
604
- prompt: str = None
605
- model_params: field(default_factory=ModelParams) = None
620
+ Returns
621
+ -------
622
+ Return a list of restricted params.
606
623
 
607
- @telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
608
- def get_model_deployment_response(self, endpoint):
609
624
  """
610
- Returns MD inference response
625
+ restricted_params = []
626
+ if params:
627
+ model = DataScienceModel.from_id(model_id)
628
+ try:
629
+ container_type_key = model.custom_metadata_list.get(
630
+ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
631
+ ).value
632
+ except ValueError:
633
+ message = (
634
+ f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
635
+ f"for model {model_id}."
636
+ )
637
+ logger.debug(message)
638
+
639
+ if not container_family:
640
+ raise AquaValueError(
641
+ f"{message}. For unverified Aqua models, container_family parameter should be "
642
+ f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
643
+ )
644
+ container_type_key = container_family
645
+
646
+ container_config = get_container_config()
647
+ container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
648
+ container_type_key, {}
649
+ )
650
+ cli_params = container_spec.get(ContainerSpec.CLI_PARM, "")
651
+
652
+ restricted_params = self._find_restricted_params(
653
+ cli_params, params, container_type_key
654
+ )
655
+
656
+ if restricted_params:
657
+ raise AquaValueError(
658
+ f"Parameters {restricted_params} are set by Aqua "
659
+ f"and cannot be overridden or are invalid."
660
+ )
661
+ return dict(valid=True)
662
+
663
+ @staticmethod
664
+ def _find_restricted_params(
665
+ default_params: Union[str, List[str]],
666
+ user_params: Union[str, List[str]],
667
+ container_family: str,
668
+ ) -> List[str]:
669
+ """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
670
+ The default parameters coming from the container index json file cannot be overridden. In addition to this,
671
+ a set of parameters maintained in
611
672
 
612
673
  Parameters
613
674
  ----------
614
- endpoint: str
615
- MD predict url
616
- prompt: str
617
- User prompt.
618
-
619
- model_params: (Dict, optional)
620
- Model parameters to be associated with the message.
621
- Currently supported VLLM+OpenAI parameters.
622
-
623
- --model-params '{
624
- "max_tokens":500,
625
- "temperature": 0.5,
626
- "top_k": 10,
627
- "top_p": 0.5,
628
- "model": "/opt/ds/model/deployed_model",
629
- ...}'
675
+ default_params:
676
+ Inference container parameter string with default values.
677
+ user_params:
678
+ Inference container parameter string with user provided values.
679
+ container_family: str
680
+ The image family of model deployment container runtime.
630
681
 
631
682
  Returns
632
683
  -------
633
- model_response_content
684
+ A list with params keys common between params1 and params2.
685
+
634
686
  """
687
+ restricted_params = []
688
+ if default_params and user_params:
689
+ default_params_dict = get_params_dict(default_params)
690
+ user_params_dict = get_params_dict(user_params)
691
+
692
+ for key, items in user_params_dict.items():
693
+ if (
694
+ key in default_params_dict
695
+ or (
696
+ InferenceContainerType.CONTAINER_TYPE_VLLM in container_family
697
+ and key in VLLMInferenceRestrictedParams
698
+ )
699
+ or (
700
+ InferenceContainerType.CONTAINER_TYPE_TGI in container_family
701
+ and key in TGIInferenceRestrictedParams
702
+ )
703
+ ):
704
+ restricted_params.append(key.lstrip("--"))
635
705
 
636
- params_dict = asdict(self.model_params)
637
- params_dict = {
638
- key: value for key, value in params_dict.items() if value is not None
639
- }
640
- body = {"prompt": self.prompt, **params_dict}
641
- request_kwargs = {"json": body, "headers": {"Content-Type": "application/json"}}
642
- response = requests.post(
643
- endpoint, auth=default_signer()["signer"], **request_kwargs
644
- )
645
- return json.loads(response.content)
706
+ return restricted_params