oracle-ads 2.13.12__py3-none-any.whl → 2.13.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ads/aqua/app.py CHANGED
@@ -5,6 +5,7 @@
5
5
  import json
6
6
  import os
7
7
  import traceback
8
+ from concurrent.futures import ThreadPoolExecutor
8
9
  from dataclasses import fields
9
10
  from datetime import datetime, timedelta
10
11
  from itertools import chain
@@ -58,6 +59,8 @@ from ads.telemetry.client import TelemetryClient
58
59
  class AquaApp:
59
60
  """Base Aqua App to contain common components."""
60
61
 
62
+ MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63
+
61
64
  @telemetry(name="aqua")
62
65
  def __init__(self) -> None:
63
66
  if OCI_RESOURCE_PRINCIPAL_VERSION:
@@ -128,20 +131,69 @@ class AquaApp:
128
131
  update_model_provenance_details=update_model_provenance_details,
129
132
  )
130
133
 
131
- # TODO: refactor model evaluation implementation to use it.
132
134
  @staticmethod
133
135
  def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
134
- if is_valid_ocid(source_id):
135
- if "datasciencemodeldeployment" in source_id:
136
- return ModelDeployment.from_id(source_id)
137
- elif "datasciencemodel" in source_id:
138
- return DataScienceModel.from_id(source_id)
136
+ """
137
+ Fetches a model or model deployment based on the provided OCID.
138
+
139
+ Parameters
140
+ ----------
141
+ source_id : str
142
+ OCID of the Data Science model or model deployment.
143
+
144
+ Returns
145
+ -------
146
+ Union[ModelDeployment, DataScienceModel]
147
+ The corresponding resource object.
148
+
149
+ Raises
150
+ ------
151
+ AquaValueError
152
+ If the OCID is invalid or unsupported.
153
+ """
154
+ logger.debug(f"Resolving source for ID: {source_id}")
155
+ if not is_valid_ocid(source_id):
156
+ logger.error(f"Invalid OCID format: {source_id}")
157
+ raise AquaValueError(
158
+ f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
159
+ )
160
+
161
+ if "datasciencemodeldeployment" in source_id:
162
+ logger.debug(f"Identified as ModelDeployment OCID: {source_id}")
163
+ return ModelDeployment.from_id(source_id)
139
164
 
165
+ if "datasciencemodel" in source_id:
166
+ logger.debug(f"Identified as DataScienceModel OCID: {source_id}")
167
+ return DataScienceModel.from_id(source_id)
168
+
169
+ logger.error(f"Unrecognized OCID type: {source_id}")
140
170
  raise AquaValueError(
141
- f"Invalid source {source_id}. "
142
- "Specify either a model or model deployment id."
171
+ f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
143
172
  )
144
173
 
174
+ def get_multi_source(
175
+ self,
176
+ ids: List[str],
177
+ ) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
178
+ """
179
+ Retrieves multiple DataScience resources concurrently.
180
+
181
+ Parameters
182
+ ----------
183
+ ids : List[str]
184
+ A list of DataScience OCIDs.
185
+
186
+ Returns
187
+ -------
188
+ Dict[str, Union[ModelDeployment, DataScienceModel]]
189
+ A mapping from OCID to the corresponding resolved resource object.
190
+ """
191
+ logger.debug(f"Fetching {ids} sources in parallel.")
192
+ with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
193
+ results = list(executor.map(self.get_source, ids))
194
+
195
+ return dict(zip(ids, results))
196
+
145
197
  # TODO: refactor model evaluation implementation to use it.
146
198
  @staticmethod
147
199
  def create_model_version_set(
ads/aqua/cli.py CHANGED
@@ -14,6 +14,7 @@ from ads.aqua.evaluation import AquaEvaluationApp
14
14
  from ads.aqua.finetuning import AquaFineTuningApp
15
15
  from ads.aqua.model import AquaModelApp
16
16
  from ads.aqua.modeldeployment import AquaDeploymentApp
17
+ from ads.aqua.verify_policies import AquaVerifyPoliciesApp
17
18
  from ads.common.utils import LOG_LEVELS
18
19
 
19
20
 
@@ -29,6 +30,7 @@ class AquaCommand:
29
30
  fine_tuning = AquaFineTuningApp
30
31
  deployment = AquaDeploymentApp
31
32
  evaluation = AquaEvaluationApp
33
+ verify_policies = AquaVerifyPoliciesApp
32
34
 
33
35
  def __init__(
34
36
  self,
ads/aqua/client/client.py CHANGED
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
61
61
 
62
62
  def __init__(self, signer: Optional[oci.signer.Signer] = None):
63
63
  """
64
- Initialize the HttpxOCIAuth instance.
64
+ Initializes the authentication handler with the given or default OCI signer.
65
65
 
66
- Args:
67
- signer (oci.signer.Signer): The OCI signer to use for signing requests.
66
+ Parameters
67
+ ----------
68
+ signer : oci.signer.Signer, optional
69
+ The OCI signer instance to use. If None, a default signer will be retrieved.
68
70
  """
69
-
70
- self.signer = signer or authutil.default_signer().get("signer")
71
+ try:
72
+ self.signer = signer or authutil.default_signer().get("signer")
73
+ if not self.signer:
74
+ raise ValueError("OCI signer could not be initialized.")
75
+ except Exception as e:
76
+ logger.error("Failed to initialize OCI signer: %s", e)
77
+ raise
71
78
 
72
79
  def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
73
80
  """
@@ -80,21 +87,31 @@ class HttpxOCIAuth(httpx.Auth):
80
87
  httpx.Request: The signed HTTPX request.
81
88
  """
82
89
  # Create a requests.Request object from the HTTPX request
83
- req = requests.Request(
84
- method=request.method,
85
- url=str(request.url),
86
- headers=dict(request.headers),
87
- data=request.content,
88
- )
89
- prepared_request = req.prepare()
90
+ try:
91
+ req = requests.Request(
92
+ method=request.method,
93
+ url=str(request.url),
94
+ headers=dict(request.headers),
95
+ data=request.content,
96
+ )
97
+ prepared_request = req.prepare()
98
+ self.signer.do_request_sign(prepared_request)
99
+
100
+ # Replace headers on the original HTTPX request with signed headers
101
+ request.headers.update(prepared_request.headers)
102
+ logger.debug("Successfully signed request to %s", request.url)
90
103
 
91
- # Sign the request using the OCI Signer
92
- self.signer.do_request_sign(prepared_request)
104
+ # Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
105
+ if (
106
+ request.method in ["GET", "DELETE"]
107
+ and "content-length" not in request.headers
108
+ ):
109
+ request.headers["content-length"] = "0"
93
110
 
94
- # Update the original HTTPX request with the signed headers
95
- request.headers.update(prepared_request.headers)
111
+ except Exception as e:
112
+ logger.error("Failed to sign request to %s: %s", request.url, e)
113
+ raise
96
114
 
97
- # Proceed with the request
98
115
  yield request
99
116
 
100
117
 
@@ -330,8 +347,8 @@ class BaseClient:
330
347
  "Content-Type": "application/json",
331
348
  "Accept": "text/event-stream" if stream else "application/json",
332
349
  }
333
- if stream:
334
- default_headers["enable-streaming"] = "true"
350
+ # if stream:
351
+ # default_headers["enable-streaming"] = "true"
335
352
  if headers:
336
353
  default_headers.update(headers)
337
354
 
@@ -495,7 +512,7 @@ class Client(BaseClient):
495
512
  prompt: str,
496
513
  payload: Optional[Dict[str, Any]] = None,
497
514
  headers: Optional[Dict[str, str]] = None,
498
- stream: bool = True,
515
+ stream: bool = False,
499
516
  ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
500
517
  """
501
518
  Generate text completion for the given prompt.
@@ -521,7 +538,7 @@ class Client(BaseClient):
521
538
  messages: List[Dict[str, Any]],
522
539
  payload: Optional[Dict[str, Any]] = None,
523
540
  headers: Optional[Dict[str, str]] = None,
524
- stream: bool = True,
541
+ stream: bool = False,
525
542
  ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
526
543
  """
527
544
  Perform a chat interaction with the model.
@@ -32,7 +32,7 @@ class ModelDeploymentBaseEndpoint(ExtendedEnum):
32
32
  """Supported base endpoints for model deployments."""
33
33
 
34
34
  PREDICT = "predict"
35
- PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
35
+ PREDICT_WITH_RESPONSE_STREAM = "predictWithResponseStream"
36
36
 
37
37
 
38
38
  class AquaOpenAIMixin:
@@ -51,9 +51,9 @@ class AquaOpenAIMixin:
51
51
  Returns:
52
52
  str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
53
53
  """
54
- normalized_path = original_path.lower().rstrip("/")
54
+ normalized_path = original_path.rstrip("/")
55
55
 
56
- match = re.search(r"/predict(withresponsestream)?", normalized_path)
56
+ match = re.search(r"/predict(WithResponseStream)?", normalized_path)
57
57
  if not match:
58
58
  logger.debug("Route header cannot be resolved from path: %s", original_path)
59
59
  return ""
@@ -71,7 +71,7 @@ class AquaOpenAIMixin:
71
71
  "Route suffix does not start with a version prefix (e.g., '/v1'). "
72
72
  "This may lead to compatibility issues with OpenAI-style endpoints. "
73
73
  "Consider updating the URL to include a version prefix, "
74
- "such as '/predict/v1' or '/predictwithresponsestream/v1'."
74
+ "such as '/predict/v1' or '/predictWithResponseStream/v1'."
75
75
  )
76
76
  # route_suffix = f"v1/{route_suffix}"
77
77
 
@@ -124,13 +124,13 @@ class AquaOpenAIMixin:
124
124
 
125
125
  def _patch_url(self) -> httpx.URL:
126
126
  """
127
- Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.
127
+ Strips any suffixes from the base URL to retain only the `/predict` or `/predictWithResponseStream` path.
128
128
 
129
129
  Returns:
130
130
  httpx.URL: The normalized base URL with the correct model deployment path.
131
131
  """
132
- base_path = f"{self.base_url.path.lower().rstrip('/')}/"
133
- match = re.search(r"/predict(withresponsestream)?/", base_path)
132
+ base_path = f"{self.base_url.path.rstrip('/')}/"
133
+ match = re.search(r"/predict(WithResponseStream)?/", base_path)
134
134
  if match:
135
135
  trimmed = base_path[: match.end() - 1]
136
136
  return self.base_url.copy_with(path=trimmed)
@@ -144,7 +144,7 @@ class AquaOpenAIMixin:
144
144
 
145
145
  This includes:
146
146
  - Patching headers with streaming and routing info.
147
- - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.
147
+ - Normalizing the URL path to include only `/predict` or `/predictWithResponseStream`.
148
148
 
149
149
  Args:
150
150
  request (httpx.Request): The outgoing HTTPX request.
@@ -176,6 +176,7 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
176
176
  http_client: Optional[httpx.Client] = None,
177
177
  http_client_kwargs: Optional[Dict[str, Any]] = None,
178
178
  _strict_response_validation: bool = False,
179
+ patch_headers: bool = False,
179
180
  **kwargs: Any,
180
181
  ) -> None:
181
182
  """
@@ -196,6 +197,7 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
196
197
  http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
197
198
  http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
198
199
  _strict_response_validation (bool, optional): Enable strict response validation.
200
+ patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
199
201
  **kwargs: Additional keyword arguments passed to the parent __init__.
200
202
  """
201
203
  if http_client is None:
@@ -207,6 +209,8 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
207
209
  logger.debug("API key not provided; using default placeholder for OCI.")
208
210
  api_key = "OCI"
209
211
 
212
+ self.patch_headers = patch_headers
213
+
210
214
  super().__init__(
211
215
  api_key=api_key,
212
216
  organization=organization,
@@ -229,7 +233,8 @@ class OpenAI(openai.OpenAI, AquaOpenAIMixin):
229
233
  Args:
230
234
  request (httpx.Request): The outgoing HTTP request.
231
235
  """
232
- self._prepare_request_common(request)
236
+ if self.patch_headers:
237
+ self._prepare_request_common(request)
233
238
 
234
239
 
235
240
  class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
@@ -248,6 +253,7 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
248
253
  http_client: Optional[httpx.Client] = None,
249
254
  http_client_kwargs: Optional[Dict[str, Any]] = None,
250
255
  _strict_response_validation: bool = False,
256
+ patch_headers: bool = False,
251
257
  **kwargs: Any,
252
258
  ) -> None:
253
259
  """
@@ -269,6 +275,7 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
269
275
  http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
270
276
  http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
271
277
  _strict_response_validation (bool, optional): Enable strict response validation.
278
+ patch_headers (bool, optional): If True, redirects the requests by modifying the headers.
272
279
  **kwargs: Additional keyword arguments passed to the parent __init__.
273
280
  """
274
281
  if http_client is None:
@@ -280,6 +287,8 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
280
287
  logger.debug("API key not provided; using default placeholder for OCI.")
281
288
  api_key = "OCI"
282
289
 
290
+ self.patch_headers = patch_headers
291
+
283
292
  super().__init__(
284
293
  api_key=api_key,
285
294
  organization=organization,
@@ -302,4 +311,5 @@ class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
302
311
  Args:
303
312
  request (httpx.Request): The outgoing HTTP request.
304
313
  """
305
- self._prepare_request_common(request)
314
+ if self.patch_headers:
315
+ self._prepare_request_common(request)
@@ -6,7 +6,7 @@ import re
6
6
  from typing import Any, Dict, List, Optional
7
7
 
8
8
  from oci.data_science.models import Model
9
- from pydantic import BaseModel, Field, model_validator
9
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
10
10
 
11
11
  from ads.aqua import logger
12
12
  from ads.aqua.config.utils.serializer import Serializable
@@ -80,24 +80,29 @@ class GPUShapesIndex(Serializable):
80
80
 
81
81
  class ComputeShapeSummary(Serializable):
82
82
  """
83
- Represents the specifications of a compute instance's shape.
83
+ Represents the specifications of a compute instance shape,
84
+ including CPU, memory, and optional GPU characteristics.
84
85
  """
85
86
 
86
87
  core_count: Optional[int] = Field(
87
- default=None, description="The number of CPU cores available."
88
+ default=None,
89
+ description="Total number of CPU cores available for the compute shape.",
88
90
  )
89
91
  memory_in_gbs: Optional[int] = Field(
90
- default=None, description="The amount of memory (in GB) available."
92
+ default=None,
93
+ description="Amount of memory (in GB) available for the compute shape.",
91
94
  )
92
95
  name: Optional[str] = Field(
93
- default=None, description="The name identifier of the compute shape."
96
+ default=None,
97
+ description="Full name of the compute shape, e.g., 'VM.GPU.A10.2'.",
94
98
  )
95
99
  shape_series: Optional[str] = Field(
96
- default=None, description="The series or category of the compute shape."
100
+ default=None,
101
+ description="Shape family or series, e.g., 'GPU', 'Standard', etc.",
97
102
  )
98
103
  gpu_specs: Optional[GPUSpecs] = Field(
99
104
  default=None,
100
- description="The GPU specifications associated with the compute shape.",
105
+ description="Optional GPU specifications associated with the shape.",
101
106
  )
102
107
 
103
108
  @model_validator(mode="after")
@@ -136,27 +141,46 @@ class ComputeShapeSummary(Serializable):
136
141
  return model
137
142
 
138
143
 
139
- class LoraModuleSpec(Serializable):
144
+ class LoraModuleSpec(BaseModel):
140
145
  """
141
- Lightweight descriptor for LoRA Modules used in fine-tuning models.
146
+ Descriptor for a LoRA (Low-Rank Adaptation) module used in fine-tuning base models.
147
+
148
+ This class is used to define a single fine-tuned module that can be loaded during
149
+ multi-model deployment alongside a base model.
142
150
 
143
151
  Attributes
144
152
  ----------
145
153
  model_id : str
146
- The unique identifier of the fine tuned model.
147
- model_name : str
148
- The name of the fine-tuned model.
149
- model_path : str
150
- The model-by-reference path to the LoRA Module within the model artifact
154
+ The OCID of the fine-tuned model registered in the OCI Model Catalog.
155
+ model_name : Optional[str]
156
+ The unique name used to route inference requests to this model variant.
157
+ model_path : Optional[str]
158
+ The relative path within the artifact pointing to the LoRA adapter weights.
151
159
  """
152
160
 
153
- model_id: Optional[str] = Field(None, description="The fine tuned model OCID to deploy.")
154
- model_name: Optional[str] = Field(None, description="The name of the fine-tuned model.")
161
+ model_config = ConfigDict(protected_namespaces=(), extra="allow")
162
+
163
+ model_id: str = Field(
164
+ ...,
165
+ description="OCID of the fine-tuned model (must be registered in the Model Catalog).",
166
+ )
167
+ model_name: Optional[str] = Field(
168
+ default=None,
169
+ description="Name assigned to the fine-tuned model for serving (used as inference route).",
170
+ )
155
171
  model_path: Optional[str] = Field(
156
- None,
157
- description="The model-by-reference path to the LoRA Module within the model artifact.",
172
+ default=None,
173
+ description="Relative path to the LoRA weights inside the model artifact.",
158
174
  )
159
175
 
176
+ @model_validator(mode="before")
177
+ @classmethod
178
+ def validate_lora_module(cls, data: dict) -> dict:
179
+ """Validates that required structure exists for a LoRA module."""
180
+ if "model_id" not in data or not data["model_id"]:
181
+ raise ValueError("Missing required field: 'model_id' for fine-tuned model.")
182
+ return data
183
+
160
184
 
161
185
  class AquaMultiModelRef(Serializable):
162
186
  """
@@ -203,6 +227,22 @@ class AquaMultiModelRef(Serializable):
203
227
  description="For fine tuned models, the artifact path of the modified model weights",
204
228
  )
205
229
 
230
+ def all_model_ids(self) -> List[str]:
231
+ """
232
+ Returns all associated model OCIDs, including the base model and any fine-tuned models.
233
+
234
+ Returns
235
+ -------
236
+ List[str]
237
+ A list of all model OCIDs associated with this multi-model reference.
238
+ """
239
+ ids = {self.model_id}
240
+ if self.fine_tune_weights:
241
+ ids.update(
242
+ module.model_id for module in self.fine_tune_weights if module.model_id
243
+ )
244
+ return list(ids)
245
+
206
246
  class Config:
207
247
  extra = "ignore"
208
248
  protected_namespaces = ()
ads/aqua/constants.py CHANGED
@@ -55,6 +55,8 @@ SERVICE_MANAGED_CONTAINER_URI_SCHEME = "dsmc://"
55
55
  SUPPORTED_FILE_FORMATS = ["jsonl"]
56
56
  MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
57
57
 
58
+ AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
59
+
58
60
  CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
59
61
  "datasciencemodel": "models",
60
62
  "datasciencemodeldeployment": "model-deployments",
@@ -1,8 +1,8 @@
1
1
  #!/usr/bin/env python
2
2
  # Copyright (c) 2025 Oracle and/or its affiliates.
3
3
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
-
5
-
4
+ import json
5
+ import os
6
6
  from importlib import metadata
7
7
 
8
8
  import huggingface_hub
@@ -18,6 +18,10 @@ from ads.aqua.common.utils import (
18
18
  )
19
19
  from ads.aqua.extension.base_handler import AquaAPIhandler
20
20
  from ads.aqua.extension.errors import Errors
21
+ from ads.common.object_storage_details import ObjectStorageDetails
22
+ from ads.common.utils import read_file
23
+ from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
24
+ from ads.opctl.operator.common.utils import default_signer
21
25
 
22
26
 
23
27
  class ADSVersionHandler(AquaAPIhandler):
@@ -28,6 +32,46 @@ class ADSVersionHandler(AquaAPIhandler):
28
32
  self.finish({"data": metadata.version("oracle_ads")})
29
33
 
30
34
 
35
+ class AquaVersionHandler(AquaAPIhandler):
36
+ @handle_exceptions
37
+ def get(self):
38
+ """
39
+ Returns the current and latest deployed version of AQUA
40
+
41
+ {
42
+ "installed": {
43
+ "aqua": "0.1.3.0",
44
+ "ads": "2.14.2"
45
+ },
46
+ "latest": {
47
+ "aqua": "0.1.4.0",
48
+ "ads": "2.14.4"
49
+ }
50
+ }
51
+
52
+ """
53
+
54
+ current_aqua_version_path = os.path.join(
55
+ os.path.dirname(os.path.abspath(__file__)), "..", "version.json"
56
+ )
57
+ current_aqua_version = json.loads(read_file(current_aqua_version_path))
58
+ current_ads_version = {"ads": metadata.version("oracle_ads")}
59
+ current_version = {"installed": {**current_aqua_version, **current_ads_version}}
60
+ try:
61
+ latest_version_artifact_path = ObjectStorageDetails(
62
+ CONDA_BUCKET_NAME,
63
+ CONDA_BUCKET_NS,
64
+ "service_pack/aqua_latest_version.json",
65
+ ).path
66
+ latest_version = json.loads(
67
+ read_file(latest_version_artifact_path, auth=default_signer())
68
+ )
69
+ except Exception:
70
+ latest_version = {"latest": current_version["installed"]}
71
+ response = {**current_version, **latest_version}
72
+ return self.finish(response)
73
+
74
+
31
75
  class CompatibilityCheckHandler(AquaAPIhandler):
32
76
  """The handler to check if the extension is compatible."""
33
77
 
@@ -118,4 +162,5 @@ __handlers__ = [
118
162
  ("network_status", NetworkStatusHandler),
119
163
  ("hf_login", HFLoginHandler),
120
164
  ("hf_logged_in", HFUserStatusHandler),
165
+ ("aqua_version", AquaVersionHandler),
121
166
  ]
@@ -11,12 +11,15 @@ from ads.aqua.common.decorator import handle_exceptions
11
11
  from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
12
12
  from ads.aqua.common.errors import AquaRuntimeError
13
13
  from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
14
+ from ads.aqua.constants import AQUA_CHAT_TEMPLATE_METADATA_KEY
14
15
  from ads.aqua.extension.base_handler import AquaAPIhandler
15
16
  from ads.aqua.extension.errors import Errors
16
17
  from ads.aqua.model import AquaModelApp
17
18
  from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
18
19
  from ads.config import SERVICE
20
+ from ads.model import DataScienceModel
19
21
  from ads.model.common.utils import MetadataArtifactPathType
22
+ from ads.model.service.oci_datascience_model import OCIDataScienceModel
20
23
 
21
24
 
22
25
  class AquaModelHandler(AquaAPIhandler):
@@ -320,26 +323,65 @@ class AquaHuggingFaceHandler(AquaAPIhandler):
320
323
  )
321
324
 
322
325
 
323
- class AquaModelTokenizerConfigHandler(AquaAPIhandler):
326
+ class AquaModelChatTemplateHandler(AquaAPIhandler):
324
327
  def get(self, model_id):
325
328
  """
326
- Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model.
327
- Expected request format: GET /aqua/models/<model-ocid>/tokenizer
329
+ Handles requests for retrieving the chat template from custom metadata of a specified model.
330
+ Expected request format: GET /aqua/models/<model-ocid>/chat-template
328
331
 
329
332
  """
330
333
 
331
334
  path_list = urlparse(self.request.path).path.strip("/").split("/")
332
- # Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
333
- # path_list=['aqua','models','<model-ocid>','tokenizer']
335
+ # Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template
336
+ # path_list=['aqua','models','<model-ocid>','chat-template']
334
337
  if (
335
338
  len(path_list) == 4
336
339
  and is_valid_ocid(path_list[2])
337
- and path_list[3] == "tokenizer"
340
+ and path_list[3] == "chat-template"
338
341
  ):
339
- return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
342
+ try:
343
+ oci_data_science_model = OCIDataScienceModel.from_id(model_id)
344
+ except Exception as e:
345
+ raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
346
+ return self.finish(oci_data_science_model.get_custom_metadata_artifact("chat_template"))
340
347
 
341
348
  raise HTTPError(400, f"The request {self.request.path} is invalid.")
342
349
 
350
+ @handle_exceptions
351
+ def post(self, model_id: str):
352
+ """
353
+ Handles POST requests to add a custom chat_template metadata artifact to a model.
354
+
355
+ Expected request format:
356
+ POST /aqua/models/<model-ocid>/chat-template
357
+ Body: { "chat_template": "<your_template_string>" }
358
+
359
+ """
360
+ try:
361
+ input_body = self.get_json_body()
362
+ except Exception as e:
363
+ raise HTTPError(400, f"Invalid JSON body: {str(e)}")
364
+
365
+ chat_template = input_body.get("chat_template")
366
+ if not chat_template:
367
+ raise HTTPError(400, "Missing required field: 'chat_template'")
368
+
369
+ try:
370
+ data_science_model = DataScienceModel.from_id(model_id)
371
+ except Exception as e:
372
+ raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
373
+
374
+ try:
375
+ result = data_science_model.create_custom_metadata_artifact(
376
+ metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
377
+ path_type=MetadataArtifactPathType.CONTENT,
378
+ artifact_path_or_content=chat_template.encode()
379
+ )
380
+ except Exception as e:
381
+ raise HTTPError(500, f"Failed to create metadata artifact: {str(e)}")
382
+
383
+ return self.finish(result)
384
+
343
385
 
344
386
  class AquaModelDefinedMetadataArtifactHandler(AquaAPIhandler):
345
387
  """
@@ -381,7 +423,7 @@ __handlers__ = [
381
423
  ("model/?([^/]*)", AquaModelHandler),
382
424
  ("model/?([^/]*)/license", AquaModelLicenseHandler),
383
425
  ("model/?([^/]*)/readme", AquaModelReadmeHandler),
384
- ("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
426
+ ("model/?([^/]*)/chat-template", AquaModelChatTemplateHandler),
385
427
  ("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
386
428
  (
387
429
  "model/?([^/]*)/definedMetadata/?([^/]*)",
@@ -26,6 +26,7 @@ class ModelTask(ExtendedEnum):
26
26
  TEXT_GENERATION = "text-generation"
27
27
  IMAGE_TEXT_TO_TEXT = "image-text-to-text"
28
28
  IMAGE_TO_TEXT = "image-to-text"
29
+ TIME_SERIES_FORECASTING = "time-series-forecasting"
29
30
 
30
31
 
31
32
  class FineTuningMetricCategories(ExtendedEnum):