oracle-ads 2.11.18__py3-none-any.whl → 2.12.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- ads/aqua/common/utils.py +20 -3
- ads/aqua/config/__init__.py +4 -0
- ads/aqua/config/config.py +28 -0
- ads/aqua/config/evaluation/__init__.py +4 -0
- ads/aqua/config/evaluation/evaluation_service_config.py +282 -0
- ads/aqua/config/evaluation/evaluation_service_model_config.py +8 -0
- ads/aqua/config/utils/__init__.py +4 -0
- ads/aqua/config/utils/serializer.py +339 -0
- ads/aqua/constants.py +1 -1
- ads/aqua/evaluation/entities.py +1 -0
- ads/aqua/evaluation/evaluation.py +56 -88
- ads/aqua/extension/common_handler.py +2 -3
- ads/aqua/extension/common_ws_msg_handler.py +2 -2
- ads/aqua/extension/evaluation_handler.py +4 -3
- ads/aqua/extension/model_handler.py +26 -1
- ads/aqua/extension/utils.py +12 -1
- ads/aqua/modeldeployment/deployment.py +31 -51
- ads/aqua/ui.py +27 -25
- ads/llm/__init__.py +10 -4
- ads/llm/chat_template.py +31 -0
- ads/llm/guardrails/base.py +3 -2
- ads/llm/guardrails/huggingface.py +1 -1
- ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
- ads/llm/langchain/plugins/chat_models/oci_data_science.py +924 -0
- ads/llm/langchain/plugins/llms/__init__.py +5 -0
- ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +939 -0
- ads/llm/requirements.txt +2 -2
- ads/llm/serialize.py +3 -6
- ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
- ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
- {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/METADATA +7 -4
- {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/RECORD +35 -27
- ads/llm/langchain/plugins/base.py +0 -118
- ads/llm/langchain/plugins/contant.py +0 -44
- ads/llm/langchain/plugins/embeddings.py +0 -64
- ads/llm/langchain/plugins/llm_gen_ai.py +0 -301
- ads/llm/langchain/plugins/llm_md.py +0 -316
- {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/LICENSE.txt +0 -0
- {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/WHEEL +0 -0
- {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/entry_points.txt +0 -0
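The most visible change in 2.12.0 is the replacement of the old ads.llm LangChain plugins (base.py, contant.py, embeddings.py, llm_gen_ai.py and llm_md.py, all removed below) with new chat-model and completion-LLM plugins for OCI Data Science Model Deployment, plus bundled tool-calling chat templates. A minimal usage sketch of the new interface follows; the class names and parameters shown are assumptions inferred from the new module paths and are not confirmed by this diff.

# Sketch only: the class names below are assumed from the new module paths
# (chat_models/oci_data_science.py and llms/oci_data_science_model_deployment_endpoint.py);
# verify against the 2.12.0 API before use. Requires oracle-ads 2.12.0 and langchain.
from ads.llm import ChatOCIModelDeployment, OCIModelDeploymentVLLM  # assumed exports

endpoint = "https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict"

# Chat-style invocation against an OCI Data Science Model Deployment endpoint.
chat = ChatOCIModelDeployment(endpoint=endpoint, model="odsc-llm")  # model name assumed
print(chat.invoke("Tell me a joke.").content)

# Completion-style invocation for a vLLM-served deployment.
llm = OCIModelDeploymentVLLM(endpoint=endpoint, model="odsc-llm")  # model name assumed
print(llm.invoke("Tell me a joke."))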
ads/llm/langchain/plugins/embeddings.py (file removed)
@@ -1,64 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*--
-
-# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
-# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
-
-from typing import List, Optional
-from langchain.load.serializable import Serializable
-from langchain.schema.embeddings import Embeddings
-from ads.llm.langchain.plugins.base import GenerativeAiClientModel
-
-
-class GenerativeAIEmbeddings(GenerativeAiClientModel, Embeddings, Serializable):
-    """OCI Generative AI embedding models."""
-
-    model: str = "cohere.embed-english-light-v2.0"
-    """Model name to use."""
-
-    truncate: Optional[str] = None
-    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
-
-    @classmethod
-    def get_lc_namespace(cls) -> List[str]:
-        """Get the namespace of the LangChain object."""
-        return ["ads", "llm"]
-
-    @classmethod
-    def is_lc_serializable(cls) -> bool:
-        """This class can be serialized with default LangChain serialization."""
-        return True
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Embeds a list of strings.
-
-        Args:
-            texts: The list of texts to embed.
-
-        Returns:
-            List of embeddings, one for each text.
-        """
-        from oci.generative_ai_inference.models import (
-            EmbedTextDetails,
-            OnDemandServingMode,
-        )
-
-        details = EmbedTextDetails(
-            compartment_id=self.compartment_id,
-            inputs=texts,
-            serving_mode=OnDemandServingMode(model_id=self.model),
-            truncate=self.truncate,
-        )
-        embeddings = self.client.embed_text(details).data.embeddings
-        return [list(map(float, e)) for e in embeddings]
-
-    def embed_query(self, text: str) -> List[float]:
-        """Embeds a single string.
-
-        Args:
-            text: The text to embed.
-
-        Returns:
-            Embeddings for the text.
-        """
-        return self.embed_documents([text])[0]
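For reference, this is roughly how the class removed above was used before 2.12.0. The field names and defaults come from the deleted source; compartment_id and the OCI client come from the also-removed GenerativeAiClientModel base (base.py), so the constructor arguments are partly assumed.

# Pre-2.12.0 usage sketch for the GenerativeAIEmbeddings class removed above.
# compartment_id/auth are handled by the removed GenerativeAiClientModel base class,
# whose exact constructor is assumed here.
from ads.llm import GenerativeAIEmbeddings

embeddings = GenerativeAIEmbeddings(
    compartment_id="ocid1.compartment.oc1..<ocid>",
    model="cohere.embed-english-light-v2.0",  # default shown in the deleted source
    truncate="END",  # "NONE" | "START" | "END"
)
document_vectors = embeddings.embed_documents(["hello", "world"])  # List[List[float]]
query_vector = embeddings.embed_query("hello")  # List[float]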
ads/llm/langchain/plugins/llm_gen_ai.py (file removed)
@@ -1,301 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*--
-
-# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
-# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-
-from ads.llm.langchain.plugins.base import BaseLLM, GenerativeAiClientModel
-from ads.llm.langchain.plugins.contant import Task
-
-logger = logging.getLogger(__name__)
-
-
-class GenerativeAI(GenerativeAiClientModel, BaseLLM):
-    """GenerativeAI Service.
-
-    To use, you should have the ``oci`` python package installed.
-
-    Example
-    -------
-
-    .. code-block:: python
-
-        from ads.llm import GenerativeAI
-
-        gen_ai = GenerativeAI(compartment_id="ocid1.compartment.oc1..<ocid>")
-
-    """
-
-    task: str = "text_generation"
-    """Task can be either text_generation or text_summarization."""
-
-    model: Optional[str] = "cohere.command"
-    """Model name to use."""
-
-    frequency_penalty: float = None
-    """Penalizes repeated tokens according to frequency. Between 0 and 1."""
-
-    presence_penalty: float = None
-    """Penalizes repeated tokens. Between 0 and 1."""
-
-    truncate: Optional[str] = None
-    """Specify how the client handles inputs longer than the maximum token."""
-
-    length: str = "AUTO"
-    """Indicates the approximate length of the summary. """
-
-    format: str = "PARAGRAPH"
-    """Indicates the style in which the summary will be delivered - in a free form paragraph or in bullet points."""
-
-    extractiveness: str = "AUTO"
-    """Controls how close to the original text the summary is. High extractiveness summaries will lean towards reusing sentences verbatim, while low extractiveness summaries will tend to paraphrase more."""
-
-    additional_command: str = ""
-    """A free-form instruction for modifying how the summaries get generated. """
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            **{
-                "model": self.model,
-                "task": self.task,
-                "client_kwargs": self.client_kwargs,
-                "endpoint_kwargs": self.endpoint_kwargs,
-            },
-            **self._default_params,
-        }
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "GenerativeAI"
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling OCIGenerativeAI API."""
-        # This property is used by _identifying_params(), which then used for serialization
-        # All parameters returning here should be JSON serializable.
-
-        return (
-            {
-                "compartment_id": self.compartment_id,
-                "temperature": self.temperature,
-                "max_tokens": self.max_tokens,
-                "top_k": self.k,
-                "top_p": self.p,
-                "frequency_penalty": self.frequency_penalty,
-                "presence_penalty": self.presence_penalty,
-                "truncate": self.truncate,
-            }
-            if self.task == Task.TEXT_GENERATION
-            else {
-                "compartment_id": self.compartment_id,
-                "temperature": self.temperature,
-                "length": self.length,
-                "format": self.format,
-                "extractiveness": self.extractiveness,
-                "additional_command": self.additional_command,
-            }
-        )
-
-    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
-        params = self._default_params
-        if self.task == Task.TEXT_SUMMARIZATION:
-            return {**params}
-
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            params["stop_sequences"] = self.stop
-        else:
-            params["stop_sequences"] = stop
-        return {**params, **kwargs}
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ):
-        """Call out to GenerativeAI's generate endpoint.
-
-        Parameters
-        ----------
-        prompt (str):
-            The prompt to pass into the model.
-        stop (List[str], Optional):
-            List of stop words to use when generating.
-
-        Returns
-        -------
-        The string generated by the model.
-
-        Example
-        -------
-
-        .. code-block:: python
-
-            response = gen_ai("Tell me a joke.")
-        """
-
-        params = self._invocation_params(stop, **kwargs)
-        self._print_request(prompt, params)
-
-        try:
-            completion = self.completion_with_retry(prompt=prompt, **params)
-        except Exception:
-            logger.error(
-                "Error occur when invoking oci service api."
-                "DEBUG INTO: task=%s, params=%s, prompt=%s",
-                self.task,
-                params,
-                prompt,
-            )
-            raise
-
-        return completion
-
-    def _text_generation(self, request_class, serving_mode, **kwargs):
-        from oci.generative_ai_inference.models import (
-            GenerateTextDetails,
-            GenerateTextResult,
-        )
-
-        compartment_id = kwargs.pop("compartment_id")
-        inference_request = request_class(**kwargs)
-        response = self.client.generate_text(
-            GenerateTextDetails(
-                compartment_id=compartment_id,
-                serving_mode=serving_mode,
-                inference_request=inference_request,
-            ),
-            **self.endpoint_kwargs,
-        ).data
-        response: GenerateTextResult
-        return response.inference_response
-
-    def _cohere_completion(self, serving_mode, **kwargs) -> str:
-        from oci.generative_ai_inference.models import (
-            CohereLlmInferenceRequest,
-            CohereLlmInferenceResponse,
-        )
-
-        response = self._text_generation(
-            CohereLlmInferenceRequest, serving_mode, **kwargs
-        )
-        response: CohereLlmInferenceResponse
-        if kwargs.get("num_generations", 1) == 1:
-            completion = response.generated_texts[0].text
-        else:
-            completion = [result.text for result in response.generated_texts]
-        self._print_response(completion, response)
-        return completion
-
-    def _llama_completion(self, serving_mode, **kwargs) -> str:
-        from oci.generative_ai_inference.models import (
-            LlamaLlmInferenceRequest,
-            LlamaLlmInferenceResponse,
-        )
-
-        # truncate and stop_sequence are not supported.
-        kwargs.pop("truncate", None)
-        kwargs.pop("stop_sequences", None)
-        # top_k must be >1 or -1
-        if "top_k" in kwargs and kwargs["top_k"] == 0:
-            kwargs.pop("top_k")
-
-        # top_p must be 1 when temperature is 0
-        if kwargs.get("temperature") == 0:
-            kwargs["top_p"] = 1
-
-        response = self._text_generation(
-            LlamaLlmInferenceRequest, serving_mode, **kwargs
-        )
-        response: LlamaLlmInferenceResponse
-        if kwargs.get("num_generations", 1) == 1:
-            completion = response.choices[0].text
-        else:
-            completion = [result.text for result in response.choices]
-        self._print_response(completion, response)
-        return completion
-
-    def _cohere_summarize(self, serving_mode, **kwargs) -> str:
-        from oci.generative_ai_inference.models import SummarizeTextDetails
-
-        kwargs["input"] = kwargs.pop("prompt")
-
-        response = self.client.summarize_text(
-            SummarizeTextDetails(serving_mode=serving_mode, **kwargs),
-            **self.endpoint_kwargs,
-        )
-        return response.data.summary
-
-    def completion_with_retry(self, **kwargs: Any) -> Any:
-        from oci.generative_ai_inference.models import OnDemandServingMode
-
-        serving_mode = OnDemandServingMode(model_id=self.model)
-
-        if self.task == Task.TEXT_SUMMARIZATION:
-            return self._cohere_summarize(serving_mode, **kwargs)
-        elif self.model.startswith("cohere"):
-            return self._cohere_completion(serving_mode, **kwargs)
-        elif self.model.startswith("meta.llama"):
-            return self._llama_completion(serving_mode, **kwargs)
-        raise ValueError(f"Model {self.model} is not supported.")
-
-    def batch_completion(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        num_generations: int = 1,
-        **kwargs: Any,
-    ) -> List[str]:
-        """Generates multiple completion for the given prompt.
-
-        Parameters
-        ----------
-        prompt (str):
-            The prompt to pass into the model.
-        stop: (List[str], optional):
-            Optional list of stop words to use when generating. Defaults to None.
-        num_generations (int, optional):
-            Number of completions aims to get. Defaults to 1.
-
-        Raises
-        ------
-        NotImplementedError
-            Raise when invoking batch_completion under summarization task.
-
-        Returns
-        -------
-        List[str]
-            List of multiple completions.
-
-        Example
-        -------
-
-        .. code-block:: python
-
-            responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)
-
-        """
-        if self.task == Task.TEXT_SUMMARIZATION:
-            raise NotImplementedError(
-                f"task={Task.TEXT_SUMMARIZATION} does not support batch_completion. "
-            )
-
-        return self._call(
-            prompt=prompt,
-            stop=stop,
-            run_manager=run_manager,
-            num_generations=num_generations,
-            **kwargs,
-        )
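The deleted class carries its own usage examples in its docstrings; consolidated, they amount to the sketch below (pre-2.12.0 API, parameter values illustrative).

# Consolidated from the docstring examples in the deleted llm_gen_ai.py above.
from ads.llm import GenerativeAI

gen_ai = GenerativeAI(
    compartment_id="ocid1.compartment.oc1..<ocid>",
    task="text_generation",  # or "text_summarization"
    model="cohere.command",  # "meta.llama*" models route to _llama_completion
)
response = gen_ai("Tell me a joke.")

# Multiple completions in one call; raises NotImplementedError for the summarization task.
responses = gen_ai.batch_completion("Tell me a joke.", num_generations=5)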
ads/llm/langchain/plugins/llm_md.py (file removed)
@@ -1,316 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*--
-
-# Copyright (c) 2023 Oracle and/or its affiliates.
-# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
-
-import logging
-from typing import Any, Dict, List, Optional
-
-import requests
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.pydantic_v1 import root_validator
-from langchain.utils import get_from_dict_or_env
-from oci.auth import signers
-
-from ads.llm.langchain.plugins.base import BaseLLM
-from ads.llm.langchain.plugins.contant import (
-    DEFAULT_CONTENT_TYPE_JSON,
-    DEFAULT_TIME_OUT,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class ModelDeploymentLLM(BaseLLM):
-    """Base class for LLM deployed on OCI Model Deployment."""
-
-    endpoint: str = ""
-    """The uri of the endpoint from the deployed Model Deployment model."""
-
-    best_of: int = 1
-    """Generates best_of completions server-side and returns the "best"
-    (the one with the highest log probability per token).
-    """
-
-    @root_validator()
-    def validate_environment(  # pylint: disable=no-self-argument
-        cls, values: Dict
-    ) -> Dict:
-        """Fetch endpoint from environment variable or arguments."""
-        values["endpoint"] = get_from_dict_or_env(
-            values,
-            "endpoint",
-            "OCI_LLM_ENDPOINT",
-        )
-        return values
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Default parameters for the model."""
-        raise NotImplementedError()
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            **{"endpoint": self.endpoint},
-            **self._default_params,
-        }
-
-    def _construct_json_body(self, prompt, params):
-        """Constructs the request body as a dictionary (JSON)."""
-        raise NotImplementedError
-
-    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
-        """Combines the invocation parameters with default parameters."""
-        params = self._default_params
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            params["stop"] = self.stop
-        elif stop is not None:
-            params["stop"] = stop
-        else:
-            # Don't set "stop" in param as None. It should be a list.
-            params["stop"] = []
-
-        return {**params, **kwargs}
-
-    def _process_response(self, response_json: dict):
-        return response_json
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Call out to OCI Data Science Model Deployment endpoint.
-
-        Parameters
-        ----------
-        prompt (str):
-            The prompt to pass into the model.
-        stop (List[str], Optional):
-            List of stop words to use when generating.
-
-        Returns
-        -------
-        The string generated by the model.
-
-        Example
-        -------
-
-        .. code-block:: python
-
-            response = oci_md("Tell me a joke.")
-
-        """
-        params = self._invocation_params(stop, **kwargs)
-        body = self._construct_json_body(prompt, params)
-        self._print_request(prompt, params)
-        response = self.send_request(data=body, endpoint=self.endpoint)
-        completion = self._process_response(response)
-        self._print_response(completion, response)
-        return completion
-
-    def send_request(
-        self,
-        data,
-        endpoint: str,
-        header: dict = None,
-        **kwargs,
-    ) -> Dict:
-        """Sends request to the model deployment endpoint.
-
-        Parameters
-        ----------
-        data (Json serializable):
-            data need to be sent to the endpoint.
-        endpoint (str):
-            The model HTTP endpoint.
-        header (dict, optional):
-            A dictionary of HTTP headers to send to the specified url. Defaults to {}.
-
-        Raises
-        ------
-        Exception:
-            Raise when invoking fails.
-
-        Returns
-        -------
-            A JSON representation of a requests.Response object.
-        """
-        if not header:
-            header = {}
-        header["Content-Type"] = (
-            header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
-            or DEFAULT_CONTENT_TYPE_JSON
-        )
-        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
-        request_kwargs = {"json": data}
-        request_kwargs["headers"] = header
-        signer = self.auth.get("signer")
-
-        attempts = 0
-        while attempts < 2:
-            request_kwargs["auth"] = signer
-            response = requests.post(
-                endpoint, timeout=timeout, **request_kwargs, **kwargs
-            )
-            if response.status_code == 401 and self.is_principal_signer(signer):
-                signer.refresh_security_token()
-                attempts += 1
-                continue
-            break
-
-        try:
-            response.raise_for_status()
-            response_json = response.json()
-
-        except Exception:
-            logger.error(
-                "DEBUG INFO: request_kwargs=%s, status_code=%s, content=%s",
-                request_kwargs,
-                response.status_code,
-                response.content,
-            )
-            raise
-
-        return response_json
-
-    @staticmethod
-    def is_principal_signer(signer):
-        """Checks if the signer is instance principal or resource principal signer."""
-        if (
-            isinstance(signer, signers.InstancePrincipalsSecurityTokenSigner)
-            or isinstance(signer, signers.ResourcePrincipalsFederationSigner)
-            or isinstance(signer, signers.EphemeralResourcePrincipalSigner)
-            or isinstance(signer, signers.EphemeralResourcePrincipalV21Signer)
-            or isinstance(signer, signers.NestedResourcePrincipals)
-            or isinstance(signer, signers.OkeWorkloadIdentityResourcePrincipalSigner)
-        ):
-            return True
-        else:
-            return False
-
-
-class ModelDeploymentTGI(ModelDeploymentLLM):
-    """OCI Data Science Model Deployment TGI Endpoint.
-
-    Example
-    -------
-
-    .. code-block:: python
-
-        from ads.llm import ModelDeploymentTGI
-
-        oci_md = ModelDeploymentTGI(endpoint="<url_of_model_deployment_endpoint>")
-
-    """
-
-    do_sample: bool = True
-    """if set to True, this parameter enables decoding strategies such as
-    multi-nominal sampling, beam-search multi-nominal sampling, Top-K sampling and Top-p sampling.
-    """
-
-    watermark = True
-    """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
-    Defaults to True."""
-
-    return_full_text = False
-    """Whether to prepend the prompt to the generated text. Defaults to False."""
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "oci_model_deployment_tgi_endpoint"
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for invoking OCI model deployment TGI endpoint."""
-        return {
-            "best_of": self.best_of,
-            "max_new_tokens": self.max_tokens,
-            "temperature": self.temperature,
-            "top_k": self.k
-            if self.k > 0
-            else None,  # `top_k` must be strictly positive'
-            "top_p": self.p,
-            "do_sample": self.do_sample,
-            "return_full_text": self.return_full_text,
-            "watermark": self.watermark,
-        }
-
-    def _construct_json_body(self, prompt, params):
-        return {
-            "inputs": prompt,
-            "parameters": params,
-        }
-
-    def _process_response(self, response_json: dict):
-        return str(response_json.get("generated_text", response_json))
-
-
-class ModelDeploymentVLLM(ModelDeploymentLLM):
-    """VLLM deployed on OCI Model Deployment"""
-
-    model: str
-    """Name of the model."""
-
-    n: int = 1
-    """Number of output sequences to return for the given prompt."""
-
-    k: int = -1
-    """Number of most likely tokens to consider at each step."""
-
-    frequency_penalty: float = 0.0
-    """Penalizes repeated tokens according to frequency. Between 0 and 1."""
-
-    presence_penalty: float = 0.0
-    """Penalizes repeated tokens. Between 0 and 1."""
-
-    use_beam_search: bool = False
-    """Whether to use beam search instead of sampling."""
-
-    ignore_eos: bool = False
-    """Whether to ignore the EOS token and continue generating tokens after
-    the EOS token is generated."""
-
-    logprobs: Optional[int] = None
-    """Number of log probabilities to return per output token."""
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "oci_model_deployment_vllm_endpoint"
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        """Get the default parameters for calling vllm."""
-        return {
-            "n": self.n,
-            "best_of": self.best_of,
-            "max_tokens": self.max_tokens,
-            "top_k": self.k,
-            "top_p": self.p,
-            "temperature": self.temperature,
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-            "stop": self.stop,
-            "ignore_eos": self.ignore_eos,
-            "use_beam_search": self.use_beam_search,
-            "logprobs": self.logprobs,
-            "model": self.model,
-        }
-
-    def _construct_json_body(self, prompt, params):
-        return {
-            "prompt": prompt,
-            **params,
-        }
-
-    def _process_response(self, response_json: dict):
-        return response_json["choices"][0]["text"]
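Likewise, the deleted model-deployment plugins document their own usage; combining the ModelDeploymentTGI docstring example with the ModelDeploymentVLLM fields gives roughly the following (pre-2.12.0 API; the endpoint URL and model name are placeholders).

# Based on the docstring example and field defaults in the deleted llm_md.py above.
from ads.llm import ModelDeploymentTGI, ModelDeploymentVLLM

oci_md = ModelDeploymentTGI(endpoint="<url_of_model_deployment_endpoint>")
print(oci_md("Tell me a joke."))

# The vLLM variant additionally requires the name of the served model.
vllm_md = ModelDeploymentVLLM(
    endpoint="<url_of_model_deployment_endpoint>",
    model="<model_name>",
)
print(vllm_md("Tell me a joke."))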
File without changes
File without changes
File without changes