oracle-ads 2.12.11__py3-none-any.whl → 2.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. ads/aqua/__init__.py +7 -1
  2. ads/aqua/app.py +41 -27
  3. ads/aqua/client/client.py +48 -11
  4. ads/aqua/common/entities.py +28 -1
  5. ads/aqua/common/enums.py +32 -21
  6. ads/aqua/common/errors.py +3 -4
  7. ads/aqua/common/utils.py +10 -15
  8. ads/aqua/config/container_config.py +203 -0
  9. ads/aqua/config/evaluation/evaluation_service_config.py +5 -181
  10. ads/aqua/constants.py +1 -1
  11. ads/aqua/evaluation/constants.py +7 -7
  12. ads/aqua/evaluation/errors.py +3 -4
  13. ads/aqua/evaluation/evaluation.py +4 -4
  14. ads/aqua/extension/base_handler.py +4 -0
  15. ads/aqua/extension/model_handler.py +41 -27
  16. ads/aqua/extension/models/ws_models.py +5 -6
  17. ads/aqua/finetuning/constants.py +3 -3
  18. ads/aqua/finetuning/finetuning.py +2 -3
  19. ads/aqua/model/constants.py +7 -7
  20. ads/aqua/model/entities.py +2 -3
  21. ads/aqua/model/enums.py +4 -5
  22. ads/aqua/model/model.py +46 -29
  23. ads/aqua/modeldeployment/deployment.py +6 -14
  24. ads/aqua/modeldeployment/entities.py +5 -3
  25. ads/aqua/server/__init__.py +4 -0
  26. ads/aqua/server/__main__.py +24 -0
  27. ads/aqua/server/app.py +47 -0
  28. ads/aqua/server/aqua_spec.yml +1291 -0
  29. ads/aqua/ui.py +5 -199
  30. ads/common/auth.py +50 -28
  31. ads/common/extended_enum.py +52 -44
  32. ads/common/utils.py +91 -11
  33. ads/config.py +3 -0
  34. ads/llm/__init__.py +12 -8
  35. ads/llm/langchain/plugins/embeddings/__init__.py +4 -0
  36. ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +184 -0
  37. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +32 -23
  38. ads/model/artifact_downloader.py +6 -4
  39. ads/model/common/utils.py +15 -3
  40. ads/model/datascience_model.py +422 -71
  41. ads/model/generic_model.py +3 -3
  42. ads/model/model_metadata.py +70 -24
  43. ads/model/model_version_set.py +5 -3
  44. ads/model/service/oci_datascience_model.py +487 -17
  45. ads/opctl/anomaly_detection.py +11 -0
  46. ads/opctl/backend/marketplace/helm_helper.py +13 -14
  47. ads/opctl/cli.py +4 -5
  48. ads/opctl/cmds.py +28 -32
  49. ads/opctl/config/merger.py +8 -11
  50. ads/opctl/config/resolver.py +25 -30
  51. ads/opctl/forecast.py +11 -0
  52. ads/opctl/operator/cli.py +9 -9
  53. ads/opctl/operator/common/backend_factory.py +56 -60
  54. ads/opctl/operator/common/const.py +5 -5
  55. ads/opctl/operator/common/utils.py +16 -0
  56. ads/opctl/operator/lowcode/anomaly/const.py +8 -9
  57. ads/opctl/operator/lowcode/common/data.py +5 -2
  58. ads/opctl/operator/lowcode/common/transformations.py +2 -12
  59. ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py +43 -48
  60. ads/opctl/operator/lowcode/forecast/__main__.py +5 -5
  61. ads/opctl/operator/lowcode/forecast/const.py +6 -6
  62. ads/opctl/operator/lowcode/forecast/model/arima.py +6 -3
  63. ads/opctl/operator/lowcode/forecast/model/automlx.py +61 -31
  64. ads/opctl/operator/lowcode/forecast/model/base_model.py +66 -40
  65. ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +79 -13
  66. ads/opctl/operator/lowcode/forecast/model/neuralprophet.py +5 -2
  67. ads/opctl/operator/lowcode/forecast/model/prophet.py +28 -15
  68. ads/opctl/operator/lowcode/forecast/model_evaluator.py +13 -15
  69. ads/opctl/operator/lowcode/forecast/schema.yaml +1 -1
  70. ads/opctl/operator/lowcode/forecast/whatifserve/deployment_manager.py +7 -0
  71. ads/opctl/operator/lowcode/forecast/whatifserve/score.py +19 -11
  72. ads/opctl/operator/lowcode/pii/constant.py +6 -7
  73. ads/opctl/operator/lowcode/recommender/constant.py +12 -7
  74. ads/opctl/operator/runtime/marketplace_runtime.py +4 -10
  75. ads/opctl/operator/runtime/runtime.py +4 -6
  76. ads/pipeline/ads_pipeline_run.py +13 -25
  77. ads/pipeline/visualizer/graph_renderer.py +3 -4
  78. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/METADATA +18 -15
  79. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/RECORD +82 -74
  80. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/WHEEL +1 -1
  81. ads/aqua/config/evaluation/evaluation_service_model_config.py +0 -8
  82. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info}/entry_points.txt +0 -0
  83. {oracle_ads-2.12.11.dist-info → oracle_ads-2.13.1.dist-info/licenses}/LICENSE.txt +0 -0
ads/common/utils.py CHANGED
@@ -1,10 +1,8 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
3
  # Copyright (c) 2020, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
- from __future__ import absolute_import, print_function
8
6
 
9
7
  import collections
10
8
  import contextlib
@@ -23,9 +21,8 @@ import tempfile
23
21
  from datetime import datetime
24
22
  from enum import Enum
25
23
  from io import DEFAULT_BUFFER_SIZE
26
- from pathlib import Path
27
24
  from textwrap import fill
28
- from typing import Dict, Optional, Union
25
+ from typing import Any, Dict, Optional, Tuple, Union
29
26
  from urllib import request
30
27
  from urllib.parse import urlparse
31
28
 
@@ -66,6 +63,8 @@ MIN_RATIO_FOR_DOWN_SAMPLING = 1 / 20
66
63
  # Maximum distinct values by cardinality will be used for plotting
67
64
  MAX_DISPLAY_VALUES = 10
68
65
 
66
+ UNKNOWN = ""
67
+
69
68
  # par link of the index json file.
70
69
  PAR_LINK = "https://objectstorage.us-ashburn-1.oraclecloud.com/p/WyjtfVIG0uda-P3-2FmAfwaLlXYQZbvPZmfX1qg0-sbkwEQO6jpwabGr2hMDBmBp/n/ociodscdev/b/service-conda-packs/o/service_pack/index.json"
71
70
 
@@ -85,6 +84,7 @@ mpl.rcParams["axes.prop_cycle"] = cycler(
85
84
  color=["teal", "blueviolet", "forestgreen", "peru", "y", "dodgerblue", "r"]
86
85
  )
87
86
 
87
+
88
88
  # sqlalchemy engines
89
89
  _engines = {}
90
90
 
@@ -152,6 +152,22 @@ def oci_key_location():
152
152
  )
153
153
 
154
154
 
155
+ def text_sanitizer(content):
156
+ if isinstance(content, str):
157
+ return (
158
+ content.replace("“", '"')
159
+ .replace("”", '"')
160
+ .replace("’", "'")
161
+ .replace("‘", "'")
162
+ .replace("—", "-")
163
+ .encode("utf-8", "ignore")
164
+ .decode("utf-8", "ignore")
165
+ )
166
+ if isinstance(content, dict):
167
+ return json.dumps(content)
168
+ return str(content)
169
+
170
+
155
171
  @deprecated(
156
172
  "2.5.10",
157
173
  details="Deprecated, use: from ads.common.auth import AuthState; AuthState().oci_config_path",
@@ -215,6 +231,37 @@ def random_valid_ocid(prefix="ocid1.dataflowapplication.oc1.iad"):
215
231
  return f"{left}.{fake}"
216
232
 
217
233
 
234
+ def parse_bool(value: Any) -> bool:
235
+ """
236
+ Converts a value to boolean. For strings, it interprets 'true', '1', or 'yes'
237
+ (case insensitive) as True; everything else as False.
238
+
239
+ Parameters
240
+ ----------
241
+ value : Any
242
+ The value to convert to boolean.
243
+
244
+ Returns
245
+ -------
246
+ bool
247
+ The boolean interpretation of the value.
248
+ """
249
+ if isinstance(value, bool):
250
+ return value
251
+ if isinstance(value, str):
252
+ return value.strip().lower() in ("true", "1", "yes")
253
+ return bool(value)
254
+
255
+
256
+ def read_file(file_path: str, **kwargs) -> str:
257
+ try:
258
+ with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
259
+ return f.read()
260
+ except Exception as e:
261
+ logger.debug(f"Failed to read file {file_path}. {e}")
262
+ return UNKNOWN
263
+
264
+
218
265
  def get_dataframe_styles(max_width=75):
219
266
  """Styles used for dataframe, example usage:
220
267
 
@@ -501,13 +548,13 @@ def print_user_message(
501
548
  if is_documentation_mode() and is_notebook():
502
549
  if display_type.lower() == "tip":
503
550
  if "\n" in msg:
504
- t = "<b>{}:</b>".format(title.upper().strip()) if title else ""
551
+ t = f"<b>{title.upper().strip()}:</b>" if title else ""
505
552
 
506
553
  user_message = "{}{}".format(
507
554
  t,
508
555
  "".join(
509
556
  [
510
- "<br>&nbsp;&nbsp;+&nbsp;{}".format(x.strip())
557
+ f"<br>&nbsp;&nbsp;+&nbsp;{x.strip()}"
511
558
  for x in msg.strip().split("\n")
512
559
  ]
513
560
  ),
@@ -646,7 +693,7 @@ def ellipsis_strings(raw, n=24):
646
693
  else:
647
694
  n2 = int(n) // 2 - 3
648
695
  n1 = n - n2 - 3
649
- result.append("{0}...{1}".format(s[:n1], s[-n2:]))
696
+ result.append(f"{s[:n1]}...{s[-n2:]}")
650
697
 
651
698
  return result
652
699
 
@@ -942,9 +989,9 @@ def generate_requirement_file(
942
989
  with open(os.path.join(file_path, file_name), "w") as req_file:
943
990
  for lib in requirements:
944
991
  if requirements[lib]:
945
- req_file.write("{}=={}\n".format(lib, requirements[lib]))
992
+ req_file.write(f"{lib}=={requirements[lib]}\n")
946
993
  else:
947
- req_file.write("{}\n".format(lib))
994
+ req_file.write(f"{lib}\n")
948
995
 
949
996
 
950
997
  def _get_feature_type_and_dtype(column):
@@ -966,7 +1013,7 @@ def to_dataframe(
966
1013
  pd.Series,
967
1014
  np.ndarray,
968
1015
  pd.DataFrame,
969
- ]
1016
+ ],
970
1017
  ):
971
1018
  """
972
1019
  Convert to pandas DataFrame.
@@ -1391,7 +1438,7 @@ def remove_file(file_path: str, auth: Optional[Dict] = None) -> None:
1391
1438
  fs = fsspec.filesystem(scheme, **auth)
1392
1439
  try:
1393
1440
  fs.rm(file_path)
1394
- except FileNotFoundError as e:
1441
+ except FileNotFoundError:
1395
1442
  raise FileNotFoundError(f"`{file_path}` not found.")
1396
1443
  except Exception as e:
1397
1444
  raise e
@@ -1786,3 +1833,36 @@ def get_log_links(
1786
1833
  console_link_url = f"https://cloud.oracle.com/logging/log-groups/{log_group_id}?region={region}"
1787
1834
 
1788
1835
  return console_link_url
1836
+
1837
+
1838
+ def parse_content_disposition(header: str) -> Tuple[str, Dict[str, str]]:
1839
+ """
1840
+ Parses a Content-Disposition header into its main disposition and a dictionary of parameters.
1841
+
1842
+ For example:
1843
+ 'attachment; filename="example.txt"'
1844
+ will be parsed into:
1845
+ ('attachment', {'filename': 'example.txt'})
1846
+
1847
+ Parameters
1848
+ ----------
1849
+ header (str): The Content-Disposition header string.
1850
+
1851
+ Returns
1852
+ -------
1853
+ Tuple[str, Dict[str, str]]: A tuple containing the disposition and a dictionary of parameters.
1854
+ """
1855
+ if not header:
1856
+ return "", {}
1857
+
1858
+ parts = header.split(";")
1859
+ # The first part is the main disposition (e.g., "attachment").
1860
+ disposition = parts[0].strip().lower()
1861
+ params: Dict[str, str] = {}
1862
+
1863
+ # Process each subsequent part to extract key-value pairs.
1864
+ for part in parts[1:]:
1865
+ if "=" in part:
1866
+ key, value = part.split("=", 1)
1867
+ params[key.strip().lower()] = value.strip().strip('"')
1868
+ return disposition, params
ads/config.py CHANGED
@@ -80,6 +80,9 @@ AQUA_TELEMETRY_BUCKET_NS = os.environ.get("AQUA_TELEMETRY_BUCKET_NS", CONDA_BUCK
80
80
  DEBUG_TELEMETRY = os.environ.get("DEBUG_TELEMETRY", None)
81
81
  AQUA_SERVICE_NAME = "aqua"
82
82
  DATA_SCIENCE_SERVICE_NAME = "data-science"
83
+ USER = "USER"
84
+ SERVICE = "SERVICE"
85
+
83
86
 
84
87
 
85
88
  THREADED_DEFAULT_TIMEOUT = 5
ads/llm/__init__.py CHANGED
@@ -1,21 +1,25 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
- # Copyright (c) 2023 Oracle and/or its affiliates.
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
5
 
7
6
  try:
8
7
  import langchain
9
- from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10
- OCIModelDeploymentVLLM,
11
- OCIModelDeploymentTGI,
12
- )
8
+
9
+ from ads.llm.chat_template import ChatTemplates
13
10
  from ads.llm.langchain.plugins.chat_models.oci_data_science import (
14
11
  ChatOCIModelDeployment,
15
- ChatOCIModelDeploymentVLLM,
16
12
  ChatOCIModelDeploymentTGI,
13
+ ChatOCIModelDeploymentVLLM,
14
+ )
15
+ from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
16
+ OCIDataScienceEmbedding,
17
+ )
18
+ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
19
+ OCIModelDeploymentLLM,
20
+ OCIModelDeploymentTGI,
21
+ OCIModelDeploymentVLLM,
17
22
  )
18
- from ads.llm.chat_template import ChatTemplates
19
23
  except ImportError as ex:
20
24
  if ex.name == "langchain":
21
25
  raise ImportError(
ads/llm/langchain/plugins/embeddings/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py ADDED
@@ -0,0 +1,184 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) 2025 Oracle and/or its affiliates.
4
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+ from typing import Any, Callable, Dict, List, Mapping, Optional
7
+
8
+ import requests
9
+ from langchain_core.embeddings import Embeddings
10
+ from langchain_core.language_models.llms import create_base_retry_decorator
11
+ from pydantic import BaseModel, Field
12
+
13
+ DEFAULT_HEADER = {
14
+ "Content-Type": "application/json",
15
+ }
16
+
17
+
18
+ class TokenExpiredError(Exception):
19
+ pass
20
+
21
+
22
+ def _create_retry_decorator(llm) -> Callable[[Any], Any]:
23
+ """Creates a retry decorator."""
24
+ errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
25
+ decorator = create_base_retry_decorator(
26
+ error_types=errors, max_retries=llm.max_retries
27
+ )
28
+ return decorator
29
+
30
+
31
+ class OCIDataScienceEmbedding(BaseModel, Embeddings):
32
+ """Embedding model deployed on OCI Data Science Model Deployment.
33
+
34
+ Example:
35
+
36
+ .. code-block:: python
37
+
38
+ from ads.llm import OCIDataScienceEmbedding
39
+
40
+ embeddings = OCIDataScienceEmbedding(
41
+ endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
42
+ )
43
+ """ # noqa: E501
44
+
45
+ auth: dict = Field(default_factory=dict, exclude=True)
46
+ """ADS auth dictionary for OCI authentication:
47
+ https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
48
+ This can be generated by calling `ads.common.auth.api_keys()`
49
+ or `ads.common.auth.resource_principal()`. If this is not
50
+ provided then the `ads.common.default_signer()` will be used."""
51
+
52
+ endpoint: str = ""
53
+ """The uri of the endpoint from the deployed Model Deployment model."""
54
+
55
+ model_kwargs: Optional[Dict] = None
56
+ """Keyword arguments to pass to the model."""
57
+
58
+ endpoint_kwargs: Optional[Dict] = None
59
+ """Optional attributes (except for headers) passed to the request.post
60
+ function.
61
+ """
62
+
63
+ max_retries: int = 1
64
+ """The maximum number of retries to make when generating."""
65
+
66
+ @property
67
+ def _identifying_params(self) -> Mapping[str, Any]:
68
+ """Get the identifying parameters."""
69
+ _model_kwargs = self.model_kwargs or {}
70
+ return {
71
+ **{"endpoint": self.endpoint},
72
+ **{"model_kwargs": _model_kwargs},
73
+ }
74
+
75
+ def _embed_with_retry(self, **kwargs) -> Any:
76
+ """Use tenacity to retry the call."""
77
+ retry_decorator = _create_retry_decorator(self)
78
+
79
+ @retry_decorator
80
+ def _completion_with_retry(**kwargs: Any) -> Any:
81
+ try:
82
+ response = requests.post(self.endpoint, **kwargs)
83
+ response.raise_for_status()
84
+ return response
85
+ except requests.exceptions.HTTPError as http_err:
86
+ if response.status_code == 401 and self._refresh_signer():
87
+ raise TokenExpiredError() from http_err
88
+ else:
89
+ raise ValueError(
90
+ f"Server error: {str(http_err)}. Message: {response.text}"
91
+ ) from http_err
92
+ except Exception as e:
93
+ raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e
94
+
95
+ return _completion_with_retry(**kwargs)
96
+
97
+ def _embedding(self, texts: List[str]) -> List[List[float]]:
98
+ """Call out to OCI Data Science Model Deployment Endpoint.
99
+
100
+ Args:
101
+ texts: A list of texts to embed.
102
+
103
+ Returns:
104
+ A list of list of floats representing the embeddings, or None if an
105
+ error occurs.
106
+ """
107
+ _model_kwargs = self.model_kwargs or {}
108
+ body = self._construct_request_body(texts, _model_kwargs)
109
+ request_kwargs = self._construct_request_kwargs(body)
110
+ response = self._embed_with_retry(**request_kwargs)
111
+ return self._proceses_response(response)
112
+
113
+ def _construct_request_kwargs(self, body: Any) -> dict:
114
+ """Constructs the request kwargs as a dictionary."""
115
+ from ads.model.common.utils import _is_json_serializable
116
+
117
+ _endpoint_kwargs = self.endpoint_kwargs or {}
118
+ headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
119
+ return (
120
+ dict(
121
+ headers=headers,
122
+ json=body,
123
+ auth=self.auth.get("signer"),
124
+ **_endpoint_kwargs,
125
+ )
126
+ if _is_json_serializable(body)
127
+ else dict(
128
+ headers=headers,
129
+ data=body,
130
+ auth=self.auth.get("signer"),
131
+ **_endpoint_kwargs,
132
+ )
133
+ )
134
+
135
+ def _construct_request_body(self, texts: List[str], params: dict) -> Any:
136
+ """Constructs the request body."""
137
+ return {"input": texts}
138
+
139
+ def _proceses_response(self, response: requests.Response) -> List[List[float]]:
140
+ """Extracts results from requests.Response."""
141
+ try:
142
+ res_json = response.json()
143
+ embeddings = res_json["data"][0]["embedding"]
144
+ except Exception as e:
145
+ raise ValueError(
146
+ f"Error raised by inference API: {e}.\nResponse: {response.text}"
147
+ ) from e
148
+ return embeddings
149
+
150
+ def embed_documents(
151
+ self,
152
+ texts: List[str],
153
+ chunk_size: Optional[int] = None,
154
+ ) -> List[List[float]]:
155
+ """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.
156
+
157
+ Args:
158
+ texts: The list of texts to embed.
159
+ chunk_size: The chunk size defines how many input texts will
160
+ be grouped together as request. If None, will use the
161
+ chunk size specified by the class.
162
+
163
+ Returns:
164
+ List of embeddings, one for each text.
165
+ """
166
+ results = []
167
+ _chunk_size = (
168
+ len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
169
+ )
170
+ for i in range(0, len(texts), _chunk_size):
171
+ response = self._embedding(texts[i : i + _chunk_size])
172
+ results.extend(response)
173
+ return results
174
+
175
+ def embed_query(self, text: str) -> List[float]:
176
+ """Compute query embeddings using OCI Data Science Model Deployment Endpoint.
177
+
178
+ Args:
179
+ text: The text to embed.
180
+
181
+ Returns:
182
+ Embeddings for the text.
183
+ """
184
+ return self._embedding([text])[0]
ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py CHANGED
@@ -1,6 +1,6 @@
1
1
  #!/usr/bin/env python
2
2
 
3
- # Copyright (c) 2024 Oracle and/or its affiliates.
3
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
4
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
 
6
6
 
@@ -433,23 +433,6 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
433
433
  model: str = DEFAULT_MODEL_NAME
434
434
  """The name of the model."""
435
435
 
436
- max_tokens: int = 256
437
- """Denotes the number of tokens to predict per generation."""
438
-
439
- temperature: float = 0.2
440
- """A non-negative float that tunes the degree of randomness in generation."""
441
-
442
- k: int = -1
443
- """Number of most likely tokens to consider at each step."""
444
-
445
- p: float = 0.75
446
- """Total probability mass of tokens to consider at each step."""
447
-
448
- best_of: int = 1
449
- """Generates best_of completions server-side and returns the "best"
450
- (the one with the highest log probability per token).
451
- """
452
-
453
436
  stop: Optional[List[str]] = None
454
437
  """Stop words to use when generating. Model output is cut off
455
438
  at the first occurrence of any of these substrings."""
@@ -466,14 +449,9 @@ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
466
449
  def _default_params(self) -> Dict[str, Any]:
467
450
  """Get the default parameters."""
468
451
  return {
469
- "best_of": self.best_of,
470
- "max_tokens": self.max_tokens,
471
452
  "model": self.model,
472
453
  "stop": self.stop,
473
454
  "stream": self.streaming,
474
- "temperature": self.temperature,
475
- "top_k": self.k,
476
- "top_p": self.p,
477
455
  }
478
456
 
479
457
  @property
@@ -788,6 +766,23 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
788
766
 
789
767
  """
790
768
 
769
+ max_tokens: int = 256
770
+ """Denotes the number of tokens to predict per generation."""
771
+
772
+ temperature: float = 0.2
773
+ """A non-negative float that tunes the degree of randomness in generation."""
774
+
775
+ k: int = -1
776
+ """Number of most likely tokens to consider at each step."""
777
+
778
+ p: float = 0.75
779
+ """Total probability mass of tokens to consider at each step."""
780
+
781
+ best_of: int = 1
782
+ """Generates best_of completions server-side and returns the "best"
783
+ (the one with the highest log probability per token).
784
+ """
785
+
791
786
  api: Literal["/generate", "/v1/completions"] = "/v1/completions"
792
787
  """Api spec."""
793
788
 
@@ -922,6 +917,20 @@ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
922
917
 
923
918
  """
924
919
 
920
+ max_tokens: int = 256
921
+ """Denotes the number of tokens to predict per generation."""
922
+
923
+ temperature: float = 0.2
924
+ """A non-negative float that tunes the degree of randomness in generation."""
925
+
926
+ p: float = 0.75
927
+ """Total probability mass of tokens to consider at each step."""
928
+
929
+ best_of: int = 1
930
+ """Generates best_of completions server-side and returns the "best"
931
+ (the one with the highest log probability per token).
932
+ """
933
+
925
934
  n: int = 1
926
935
  """Number of output sequences to return for the given prompt."""
927
936
 
ads/model/artifact_downloader.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8; -*-
3
2
 
4
3
  # Copyright (c) 2022, 2024 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -12,9 +11,9 @@ from typing import Dict, Optional
12
11
  from zipfile import ZipFile
13
12
 
14
13
  from ads.common import utils
14
+ from ads.common.object_storage_details import ObjectStorageDetails
15
15
  from ads.common.utils import extract_region
16
16
  from ads.model.service.oci_datascience_model import OCIDataScienceModel
17
- from ads.common.object_storage_details import ObjectStorageDetails
18
17
 
19
18
 
20
19
  class ArtifactDownloader(ABC):
@@ -169,9 +168,12 @@ class LargeArtifactDownloader(ArtifactDownloader):
169
168
 
170
169
  def _download(self):
171
170
  """Downloads model artifacts."""
172
- self.progress.update(f"Importing model artifacts from catalog")
171
+ self.progress.update("Importing model artifacts from catalog")
173
172
 
174
- if self.dsc_model.is_model_by_reference() and self.model_file_description:
173
+ if (
174
+ self.dsc_model.is_model_created_by_reference()
175
+ and self.model_file_description
176
+ ):
175
177
  self.download_from_model_file_description()
176
178
  self.progress.update()
177
179
  return
ads/model/common/utils.py CHANGED
@@ -1,5 +1,4 @@
1
1
  #!/usr/bin/env python
2
- # -*- coding: utf-8 -*--
3
2
 
4
3
  # Copyright (c) 2022, 2023 Oracle and/or its affiliates.
5
4
  # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -7,16 +6,29 @@
7
6
  import json
8
7
  import os
9
8
  import tempfile
10
- import yaml
11
9
  from typing import Any, Dict, Optional
12
10
  from zipfile import ZipFile
13
- from ads.common import utils
14
11
 
12
+ import yaml
13
+
14
+ from ads.common import utils
15
+ from ads.common.extended_enum import ExtendedEnum
15
16
 
16
17
  DEPRECATE_AS_ONNX_WARNING = "This attribute `as_onnx` will be deprecated in the future. You can choose specific format by setting `model_save_serializer`."
17
18
  DEPRECATE_USE_TORCH_SCRIPT_WARNING = "This attribute `use_torch_script` will be deprecated in the future. You can choose specific format by setting `model_save_serializer`."
18
19
 
19
20
 
21
+ class MetadataArtifactPathType(ExtendedEnum):
22
+ """
23
+ Enum for defining metadata artifact path type.
24
+ Can be either local path or OSS path. It can also be the content itself.
25
+ """
26
+
27
+ LOCAL = "local"
28
+ OSS = "oss"
29
+ CONTENT = "content"
30
+
31
+
20
32
  def _extract_locals(
21
33
  locals: Dict[str, Any], filter_out_nulls: Optional[bool] = True
22
34
  ) -> Dict[str, Any]: