oracle-ads 2.11.18__py3-none-any.whl → 2.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. ads/aqua/common/utils.py +20 -3
  2. ads/aqua/config/__init__.py +4 -0
  3. ads/aqua/config/config.py +28 -0
  4. ads/aqua/config/evaluation/__init__.py +4 -0
  5. ads/aqua/config/evaluation/evaluation_service_config.py +282 -0
  6. ads/aqua/config/evaluation/evaluation_service_model_config.py +8 -0
  7. ads/aqua/config/utils/__init__.py +4 -0
  8. ads/aqua/config/utils/serializer.py +339 -0
  9. ads/aqua/constants.py +1 -1
  10. ads/aqua/evaluation/entities.py +1 -0
  11. ads/aqua/evaluation/evaluation.py +56 -88
  12. ads/aqua/extension/common_handler.py +2 -3
  13. ads/aqua/extension/common_ws_msg_handler.py +2 -2
  14. ads/aqua/extension/evaluation_handler.py +4 -3
  15. ads/aqua/extension/model_handler.py +26 -1
  16. ads/aqua/extension/utils.py +12 -1
  17. ads/aqua/modeldeployment/deployment.py +31 -51
  18. ads/aqua/ui.py +27 -25
  19. ads/llm/__init__.py +10 -4
  20. ads/llm/chat_template.py +31 -0
  21. ads/llm/guardrails/base.py +3 -2
  22. ads/llm/guardrails/huggingface.py +1 -1
  23. ads/llm/langchain/plugins/chat_models/__init__.py +5 -0
  24. ads/llm/langchain/plugins/chat_models/oci_data_science.py +924 -0
  25. ads/llm/langchain/plugins/llms/__init__.py +5 -0
  26. ads/llm/langchain/plugins/llms/oci_data_science_model_deployment_endpoint.py +939 -0
  27. ads/llm/requirements.txt +2 -2
  28. ads/llm/serialize.py +3 -6
  29. ads/llm/templates/tool_chat_template_hermes.jinja +130 -0
  30. ads/llm/templates/tool_chat_template_mistral_parallel.jinja +94 -0
  31. {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/METADATA +7 -4
  32. {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/RECORD +35 -27
  33. ads/llm/langchain/plugins/base.py +0 -118
  34. ads/llm/langchain/plugins/contant.py +0 -44
  35. ads/llm/langchain/plugins/embeddings.py +0 -64
  36. ads/llm/langchain/plugins/llm_gen_ai.py +0 -301
  37. ads/llm/langchain/plugins/llm_md.py +0 -316
  38. {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/LICENSE.txt +0 -0
  39. {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/WHEEL +0 -0
  40. {oracle_ads-2.11.18.dist-info → oracle_ads-2.12.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,939 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*--
+
+ # Copyright (c) 2023 Oracle and/or its affiliates.
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+
+ import json
+ import logging
+ from typing import (
+     Any,
+     AsyncIterator,
+     Callable,
+     Dict,
+     Iterator,
+     List,
+     Literal,
+     Optional,
+     Union,
+ )
+
+ import aiohttp
+ import requests
+ import traceback
+ from langchain_core.callbacks import (
+     AsyncCallbackManagerForLLMRun,
+     CallbackManagerForLLMRun,
+ )
+ from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+ from langchain_core.load.serializable import Serializable
+ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+ from langchain_core.utils import get_from_dict_or_env, pre_init
+ from langchain_community.utilities.requests import Requests
+ from pydantic import Field
+
+ logger = logging.getLogger(__name__)
+
+
+ DEFAULT_TIME_OUT = 300
+ DEFAULT_CONTENT_TYPE_JSON = "application/json"
+ DEFAULT_MODEL_NAME = "odsc-llm"
+
+
+ class TokenExpiredError(Exception):
+     """Raised when the security token has expired."""
+
+
+ class ServerError(Exception):
+     """Raised when a server error is encountered while making inference."""
+
+
+ def _create_retry_decorator(
+     llm: "BaseOCIModelDeployment",
+     *,
+     run_manager: Optional[
+         Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+     ] = None,
+ ) -> Callable[[Any], Any]:
+     """Create a retry decorator."""
+     errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
+     decorator = create_base_retry_decorator(
+         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
+     )
+     return decorator
+
+
+ class BaseOCIModelDeployment(Serializable):
+     """Base class for LLM deployed on OCI Data Science Model Deployment."""
+
+     auth: dict = Field(default_factory=dict, exclude=True)
+     """ADS auth dictionary for OCI authentication:
+     https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
+     This can be generated by calling `ads.common.auth.api_keys()`
+     or `ads.common.auth.resource_principal()`. If this is not
+     provided then the `ads.common.default_signer()` will be used."""
+
+     endpoint: str = ""
+     """The uri of the endpoint from the deployed Model Deployment model."""
+
+     streaming: bool = False
+     """Whether to stream the results or not."""
+
+     max_retries: int = 3
+     """Maximum number of retries to make when generating."""
+
+     @pre_init
+     def validate_environment(  # pylint: disable=no-self-argument
+         cls, values: Dict
+     ) -> Dict:
+         """Validate that python package exists in environment."""
+         try:
+             import ads
+
+         except ImportError as ex:
+             raise ImportError(
+                 "Could not import ads python package. "
+                 "Please install it with `pip install oracle_ads`."
+             ) from ex
+
+         if not values.get("auth", None):
+             values["auth"] = ads.common.auth.default_signer()
+
+         values["endpoint"] = get_from_dict_or_env(
+             values,
+             "endpoint",
+             "OCI_LLM_ENDPOINT",
+         )
+         return values
+
+     def _headers(
+         self, is_async: Optional[bool] = False, body: Optional[dict] = None
+     ) -> Dict:
+         """Construct and return the headers for a request.
+
+         Args:
+             is_async (bool, optional): Indicates if the request is asynchronous.
+                 Defaults to `False`.
+             body (optional): The request body to be included in the headers if
+                 the request is asynchronous.
+
+         Returns:
+             Dict: A dictionary containing the appropriate headers for the request.
+         """
+         if is_async:
+             signer = self.auth["signer"]
+             _req = requests.Request("POST", self.endpoint, json=body)
+             req = _req.prepare()
+             req = signer(req)
+             headers = {}
+             for key, value in req.headers.items():
+                 headers[key] = value
+
+             if self.streaming:
+                 headers.update(
+                     {"enable-streaming": "true", "Accept": "text/event-stream"}
+                 )
+             return headers
+
+         return (
+             {
+                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
+                 "enable-streaming": "true",
+                 "Accept": "text/event-stream",
+             }
+             if self.streaming
+             else {
+                 "Content-Type": DEFAULT_CONTENT_TYPE_JSON,
+             }
+         )
+
+     def completion_with_retry(
+         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
+     ) -> Any:
+         """Use tenacity to retry the completion call."""
+         retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+
+         @retry_decorator
+         def _completion_with_retry(**kwargs: Any) -> Any:
+             try:
+                 request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT)
+                 data = kwargs.pop("data")
+                 stream = kwargs.pop("stream", self.streaming)
+
+                 request = Requests(
+                     headers=self._headers(), auth=self.auth.get("signer")
+                 )
+                 response = request.post(
+                     url=self.endpoint,
+                     data=data,
+                     timeout=request_timeout,
+                     stream=stream,
+                     **kwargs,
+                 )
+                 self._check_response(response)
+                 return response
+             except TokenExpiredError as e:
+                 raise e
+             except Exception as err:
+                 traceback.print_exc()
+                 logger.debug(
+                     f"Requests payload: {data}. Requests arguments: "
+                     f"url={self.endpoint},timeout={request_timeout},stream={stream}. "
+                     f"Additional request kwargs={kwargs}."
+                 )
+                 raise RuntimeError(
+                     f"Error occurred while invoking the inference endpoint: {str(err)}"
+                 ) from err
+
+         return _completion_with_retry(**kwargs)
+
+     async def acompletion_with_retry(
+         self,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Use tenacity to retry the async completion call."""
+         retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+
+         @retry_decorator
+         async def _completion_with_retry(**kwargs: Any) -> Any:
+             try:
+                 request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT)
+                 data = kwargs.pop("data")
+                 stream = kwargs.pop("stream", self.streaming)
+
+                 request = Requests(headers=self._headers(is_async=True, body=data))
+                 if stream:
+                     response = request.apost(
+                         url=self.endpoint,
+                         data=data,
+                         timeout=request_timeout,
+                     )
+                     return self._aiter_sse(response)
+                 else:
+                     async with request.apost(
+                         url=self.endpoint,
+                         data=data,
+                         timeout=request_timeout,
+                     ) as resp:
+                         self._check_response(resp)
+                         data = await resp.json()
+                         return data
+             except TokenExpiredError as e:
+                 raise e
+             except Exception as err:
+                 traceback.print_exc()
+                 logger.debug(
+                     f"Requests payload: `{data}`. "
+                     f"Stream mode={stream}. "
+                     f"Requests kwargs: url={self.endpoint}, timeout={request_timeout}."
+                 )
+                 raise RuntimeError(
+                     f"Error occurred while invoking the inference endpoint: {str(err)}"
+                 ) from err
+
+         return await _completion_with_retry(**kwargs)
+
+     def _check_response(self, response: Any) -> None:
+         """Handle server error by checking the response status.
+
+         Args:
+             response:
+                 The response object from either `requests` or `aiohttp` library.
+
+         Raises:
+             TokenExpiredError:
+                 If the response status code is 401 and the token refresh is successful.
+             ServerError:
+                 If any other HTTP error occurs.
+         """
+         try:
+             response.raise_for_status()
+         except requests.exceptions.HTTPError as http_err:
+             status_code = (
+                 response.status_code
+                 if hasattr(response, "status_code")
+                 else response.status
+             )
+             if status_code == 401 and self._refresh_signer():
+                 raise TokenExpiredError() from http_err
+
+             raise ServerError(
+                 f"Server error: {str(http_err)}. \nMessage: {response.text}"
+             ) from http_err
+
+     def _parse_stream(self, lines: Iterator[bytes]) -> Iterator[str]:
+         """Parse a stream of byte lines and yield parsed string lines.
+
+         Args:
+             lines (Iterator[bytes]):
+                 An iterator that yields lines in byte format.
+
+         Yields:
+             Iterator[str]:
+                 An iterator that yields parsed lines as strings.
+         """
+         for line in lines:
+             _line = self._parse_stream_line(line)
+             if _line is not None:
+                 yield _line
+
+     async def _parse_stream_async(
+         self,
+         lines: aiohttp.StreamReader,
+     ) -> AsyncIterator[str]:
+         """
+         Asynchronously parse a stream of byte lines and yield parsed string lines.
+
+         Args:
+             lines (aiohttp.StreamReader):
+                 An `aiohttp.StreamReader` object that yields lines in byte format.
+
+         Yields:
+             AsyncIterator[str]:
+                 An asynchronous iterator that yields parsed lines as strings.
+         """
+         async for line in lines:
+             _line = self._parse_stream_line(line)
+             if _line is not None:
+                 yield _line
+
+     def _parse_stream_line(self, line: bytes) -> Optional[str]:
+         """Parse a single byte line and return a processed string line if valid.
+
+         Args:
+             line (bytes): A single line in byte format.
+
+         Returns:
+             Optional[str]:
+                 The processed line as a string if valid, otherwise `None`.
+         """
+         line = line.strip()
+         if not line:
+             return None
+         _line = line.decode("utf-8")
+
+         if _line.lower().startswith("data:"):
+             _line = _line[5:].lstrip()
+
+             if _line.startswith("[DONE]"):
+                 return None
+             return _line
+         return None
+
+     async def _aiter_sse(
+         self,
+         async_cntx_mgr: Any,
+     ) -> AsyncIterator[str]:
+         """Asynchronously iterate over server-sent events (SSE).
+
+         Args:
+             async_cntx_mgr: An asynchronous context manager that yields a client
+                 response object.
+
+         Yields:
+             AsyncIterator[str]: An asynchronous iterator that yields parsed server-sent
+                 event lines as json string.
+         """
+         async with async_cntx_mgr as client_resp:
+             self._check_response(client_resp)
+             async for line in self._parse_stream_async(client_resp.content):
+                 yield line
+
+     def _refresh_signer(self) -> bool:
+         """Attempt to refresh the security token using the signer.
+
+         Returns:
+             bool: `True` if the token was successfully refreshed, `False` otherwise.
+         """
+         if self.auth.get("signer", None) and hasattr(
+             self.auth["signer"], "refresh_security_token"
+         ):
+             self.auth["signer"].refresh_security_token()
+             return True
+         return False
+
+
+ class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
+     """LLM deployed on OCI Data Science Model Deployment.
+
+     To use, you must provide the model HTTP endpoint from your deployed
+     model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
+
+     To authenticate, `oracle-ads` has been used to automatically load
+     credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
+
+     Make sure to have the required policies to access the OCI Data
+     Science Model Deployment endpoint. See:
+     https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
+
+     Example:
+
+         .. code-block:: python
+
+             from langchain_community.llms import OCIModelDeploymentLLM
+
+             llm = OCIModelDeploymentLLM(
+                 endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
+                 model="odsc-llm",
+                 streaming=True,
+                 model_kwargs={"frequency_penalty": 1.0},
+             )
+             llm.invoke("tell me a joke.")
+
+     Customized Usage:
+
+         Users can inherit from the base class and override `_process_response`, `_process_stream_response`,
+         and `_construct_json_body` to satisfy customized needs.
+
+         .. code-block:: python
+
+             from langchain_community.llms import OCIModelDeploymentLLM
+
+             class MyCustomizedModel(OCIModelDeploymentLLM):
+                 def _process_stream_response(self, response_json: dict) -> GenerationChunk:
+                     print("My customized output stream handler.")
+                     return GenerationChunk()
+
+                 def _process_response(self, response_json: dict) -> List[Generation]:
+                     print("My customized output handler.")
+                     return [Generation()]
+
+                 def _construct_json_body(self, prompt: str, param: dict) -> dict:
+                     print("My customized input handler.")
+                     return {}
+
+             llm = MyCustomizedModel(
+                 endpoint=f"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/{ocid}/predict",
+                 model="<model_name>",
+             )
+
+             llm.invoke("tell me a joke.")
+
+     """  # noqa: E501
+
+     model: str = DEFAULT_MODEL_NAME
+     """The name of the model."""
+
+     max_tokens: int = 256
+     """Denotes the number of tokens to predict per generation."""
+
+     temperature: float = 0.2
+     """A non-negative float that tunes the degree of randomness in generation."""
+
+     k: int = -1
+     """Number of most likely tokens to consider at each step."""
+
+     p: float = 0.75
+     """Total probability mass of tokens to consider at each step."""
+
+     best_of: int = 1
+     """Generates best_of completions server-side and returns the "best"
+     (the one with the highest log probability per token).
+     """
+
+     stop: Optional[List[str]] = None
+     """Stop words to use when generating. Model output is cut off
+     at the first occurrence of any of these substrings."""
+
+     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+     """Keyword arguments to pass to the model."""
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of llm."""
+         return "oci_model_deployment_endpoint"
+
+     @classmethod
+     def is_lc_serializable(cls) -> bool:
+         """Return whether this model can be serialized by Langchain."""
+         return True
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters."""
+         return {
+             "best_of": self.best_of,
+             "max_tokens": self.max_tokens,
+             "model": self.model,
+             "stop": self.stop,
+             "stream": self.streaming,
+             "temperature": self.temperature,
+             "top_k": self.k,
+             "top_p": self.p,
+         }
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         """Get the identifying parameters."""
+         _model_kwargs = self.model_kwargs or {}
+         return {
+             **{"endpoint": self.endpoint, "model_kwargs": _model_kwargs},
+             **self._default_params,
+         }
+
+     def _generate(
+         self,
+         prompts: List[str],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> LLMResult:
+         """Call out to OCI Data Science Model Deployment endpoint with k unique prompts.
+
+         Args:
+             prompts: The prompts to pass into the service.
+             stop: Optional list of stop words to use when generating.
+
+         Returns:
+             The full LLM output.
+
+         Example:
+             .. code-block:: python
+
+                 response = llm.invoke("Tell me a joke.")
+                 response = llm.generate(["Tell me a joke."])
+         """
+         generations: List[List[Generation]] = []
+         params = self._invocation_params(stop, **kwargs)
+         for prompt in prompts:
+             body = self._construct_json_body(prompt, params)
+             if self.streaming:
+                 generation = GenerationChunk(text="")
+                 for chunk in self._stream(
+                     prompt, stop=stop, run_manager=run_manager, **kwargs
+                 ):
+                     generation += chunk
+                 generations.append([generation])
+             else:
+                 res = self.completion_with_retry(
+                     data=body,
+                     run_manager=run_manager,
+                     **kwargs,
+                 )
+                 generations.append(self._process_response(res.json()))
+         return LLMResult(generations=generations)
+
+     async def _agenerate(
+         self,
+         prompts: List[str],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> LLMResult:
+         """Call out to OCI Data Science Model Deployment endpoint async with k unique prompts.
+
+         Args:
+             prompts: The prompts to pass into the service.
+             stop: Optional list of stop words to use when generating.
+
+         Returns:
+             The full LLM output.
+
+         Example:
+             .. code-block:: python
+
+                 response = await llm.ainvoke("Tell me a joke.")
+                 response = await llm.agenerate(["Tell me a joke."])
+         """  # noqa: E501
+         generations: List[List[Generation]] = []
+         params = self._invocation_params(stop, **kwargs)
+         for prompt in prompts:
+             body = self._construct_json_body(prompt, params)
+             if self.streaming:
+                 generation = GenerationChunk(text="")
+                 async for chunk in self._astream(
+                     prompt, stop=stop, run_manager=run_manager, **kwargs
+                 ):
+                     generation += chunk
+                 generations.append([generation])
+             else:
+                 res = await self.acompletion_with_retry(
+                     data=body,
+                     run_manager=run_manager,
+                     **kwargs,
+                 )
+                 generations.append(self._process_response(res))
+         return LLMResult(generations=generations)
+
+     def _stream(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Iterator[GenerationChunk]:
+         """Stream OCI Data Science Model Deployment endpoint on given prompt.
+
+         Args:
+             prompt (str):
+                 The prompt to pass into the model.
+             stop (List[str], Optional):
+                 List of stop words to use when generating.
+             kwargs:
+                 requests_kwargs:
+                     Additional ``**kwargs`` to pass to requests.post
+
+         Returns:
+             An iterator of GenerationChunks.
+
+         Example:
+
+             .. code-block:: python
+
+                 response = llm.stream("Tell me a joke.")
+
+         """
+         requests_kwargs = kwargs.pop("requests_kwargs", {})
+         self.streaming = True
+         params = self._invocation_params(stop, **kwargs)
+         body = self._construct_json_body(prompt, params)
+
+         response = self.completion_with_retry(
+             data=body, run_manager=run_manager, stream=True, **requests_kwargs
+         )
+         for line in self._parse_stream(response.iter_lines()):
+             chunk = self._handle_sse_line(line)
+             if run_manager:
+                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+             yield chunk
+
+     async def _astream(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[GenerationChunk]:
+         """Stream OCI Data Science Model Deployment endpoint async on given prompt.
+
+         Args:
+             prompt (str):
+                 The prompt to pass into the model.
+             stop (List[str], Optional):
+                 List of stop words to use when generating.
+             kwargs:
+                 requests_kwargs:
+                     Additional ``**kwargs`` to pass to requests.post
+
+         Returns:
+             An iterator of GenerationChunks.
+
+         Example:
+
+             .. code-block:: python
+
+                 async for chunk in llm.astream("Tell me a joke."):
+                     print(chunk, end="", flush=True)
+
+         """
+         requests_kwargs = kwargs.pop("requests_kwargs", {})
+         self.streaming = True
+         params = self._invocation_params(stop, **kwargs)
+         body = self._construct_json_body(prompt, params)
+
+         async for line in await self.acompletion_with_retry(
+             data=body, run_manager=run_manager, stream=True, **requests_kwargs
+         ):
+             chunk = self._handle_sse_line(line)
+             if run_manager:
+                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+             yield chunk
+
+     def _construct_json_body(self, prompt: str, params: dict) -> dict:
+         """Constructs the request body as a dictionary (JSON)."""
+         return {
+             "prompt": prompt,
+             **params,
+         }
+
+     def _invocation_params(
+         self, stop: Optional[List[str]] = None, **kwargs: Any
+     ) -> dict:
+         """Combines the invocation parameters with default parameters."""
+         params = self._default_params
+         _model_kwargs = self.model_kwargs or {}
+         params["stop"] = stop or params.get("stop", [])
+         return {**params, **_model_kwargs, **kwargs}
+
+     def _process_stream_response(self, response_json: dict) -> GenerationChunk:
+         """Formats streaming response for OpenAI spec into GenerationChunk."""
+         try:
+             choice = response_json["choices"][0]
+             if not isinstance(choice, dict):
+                 raise TypeError("Endpoint response is not well formed.")
+         except (KeyError, IndexError, TypeError) as e:
+             raise ValueError("Error while formatting response payload.") from e
+
+         return GenerationChunk(text=choice.get("text", ""))
+
+     def _process_response(self, response_json: dict) -> List[Generation]:
+         """Formats response in OpenAI spec.
+
+         Args:
+             response_json (dict): The JSON response from the model endpoint.
+
+         Returns:
+             List[Generation]: A list of `Generation` objects parsed from the
+                 response, with additional generation information attached.
+
+         Raises:
+             ValueError: If the response JSON is not well-formed or does not
+                 contain the expected structure.
+
+         """
+         generations = []
+         try:
+             choices = response_json["choices"]
+             if not isinstance(choices, list):
+                 raise TypeError("Endpoint response is not well formed.")
+         except (KeyError, TypeError) as e:
+             raise ValueError("Error while formatting response payload.") from e
+
+         for choice in choices:
+             gen = Generation(
+                 text=choice.get("text"),
+                 generation_info=self._generate_info(choice),
+             )
+             generations.append(gen)
+
+         return generations
+
+     def _generate_info(self, choice: dict) -> Any:
+         """Extracts generation info from the response."""
+         gen_info = {}
+         finish_reason = choice.get("finish_reason", None)
+         logprobs = choice.get("logprobs", None)
+         index = choice.get("index", None)
+         if finish_reason:
+             gen_info.update({"finish_reason": finish_reason})
+         if logprobs is not None:
+             gen_info.update({"logprobs": logprobs})
+         if index is not None:
+             gen_info.update({"index": index})
+
+         return gen_info or None
+
+     def _handle_sse_line(self, line: str) -> GenerationChunk:
+         try:
+             obj = json.loads(line)
+             return self._process_stream_response(obj)
+         except Exception:
+             return GenerationChunk(text="")
+
+
+ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
+     """OCI Data Science Model Deployment TGI Endpoint.
+
+     To use, you must provide the model HTTP endpoint from your deployed
+     model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
+
+     To authenticate, `oracle-ads` has been used to automatically load
+     credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
+
+     Make sure to have the required policies to access the OCI Data
+     Science Model Deployment endpoint. See:
+     https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
+
+     Example:
+         .. code-block:: python
+
+             from langchain_community.llms import OCIModelDeploymentTGI
+
+             llm = OCIModelDeploymentTGI(
+                 endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
+                 api="/v1/completions",
+                 streaming=True,
+                 temperature=0.2,
+                 seed=42,
+                 # other model parameters ...
+             )
+
+     """
+
+     api: Literal["/generate", "/v1/completions"] = "/v1/completions"
+     """API spec."""
+
+     frequency_penalty: float = 0.0
+     """Penalizes repeated tokens according to frequency. Between 0 and 1."""
+
+     seed: Optional[int] = None
+     """Random sampling seed"""
+
+     repetition_penalty: Optional[float] = None
+     """The parameter for repetition penalty. 1.0 means no penalty."""
+
+     suffix: Optional[str] = None
+     """The text to append to the prompt."""
+
+     do_sample: bool = True
+     """If set to True, this parameter enables decoding strategies such as
+     multinomial sampling, beam-search multinomial sampling, Top-K
+     sampling and Top-p sampling.
+     """
+
+     watermark: bool = True
+     """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
+     Defaults to True."""
+
+     return_full_text: bool = False
+     """Whether to prepend the prompt to the generated text. Defaults to False."""
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of llm."""
+         return "oci_model_deployment_tgi_endpoint"
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for invoking OCI model deployment TGI endpoint."""
+         return (
+             {
+                 "model": self.model,  # can be any
+                 "frequency_penalty": self.frequency_penalty,
+                 "max_tokens": self.max_tokens,
+                 "repetition_penalty": self.repetition_penalty,
+                 "temperature": self.temperature,
+                 "top_p": self.p,
+                 "seed": self.seed,
+                 "stream": self.streaming,
+                 "suffix": self.suffix,
+                 "stop": self.stop,
+             }
+             if self.api == "/v1/completions"
+             else {
+                 "best_of": self.best_of,
+                 "max_new_tokens": self.max_tokens,
+                 "temperature": self.temperature,
+                 "top_k": (
+                     self.k if self.k > 0 else None
+                 ),  # `top_k` must be strictly positive
+                 "top_p": self.p,
+                 "do_sample": self.do_sample,
+                 "return_full_text": self.return_full_text,
+                 "watermark": self.watermark,
+                 "stop": self.stop,
+             }
+         )
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         """Get the identifying parameters."""
+         _model_kwargs = self.model_kwargs or {}
+         return {
+             **{
+                 "endpoint": self.endpoint,
+                 "api": self.api,
+                 "model_kwargs": _model_kwargs,
+             },
+             **self._default_params,
+         }
+
+     def _construct_json_body(self, prompt: str, params: dict) -> dict:
+         """Construct request payload."""
+         if self.api == "/v1/completions":
+             return super()._construct_json_body(prompt, params)
+
+         return {
+             "inputs": prompt,
+             "parameters": params,
+         }
+
+     def _process_response(self, response_json: dict) -> List[Generation]:
+         """Formats response."""
+         if self.api == "/v1/completions":
+             return super()._process_response(response_json)
+
+         try:
+             text = response_json["generated_text"]
+         except KeyError as e:
+             raise ValueError(
+                 f"Error while formatting response payload. response_json={response_json}"
+             ) from e
+
+         return [Generation(text=text)]
+
+
+ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
+     """VLLM deployed on OCI Data Science Model Deployment.
+
+     To use, you must provide the model HTTP endpoint from your deployed
+     model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
+
+     To authenticate, `oracle-ads` has been used to automatically load
+     credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
+
+     Make sure to have the required policies to access the OCI Data
+     Science Model Deployment endpoint. See:
+     https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
+
+     Example:
+         .. code-block:: python
+
+             from langchain_community.llms import OCIModelDeploymentVLLM
+
+             llm = OCIModelDeploymentVLLM(
+                 endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
+                 model="odsc-llm",
+                 streaming=False,
+                 temperature=0.2,
+                 max_tokens=512,
+                 n=3,
+                 best_of=3,
+                 # other model parameters
+             )
+
+     """
+
+     n: int = 1
+     """Number of output sequences to return for the given prompt."""
+
+     k: int = -1
+     """Number of most likely tokens to consider at each step."""
+
+     frequency_penalty: float = 0.0
+     """Penalizes repeated tokens according to frequency. Between 0 and 1."""
+
+     presence_penalty: float = 0.0
+     """Penalizes repeated tokens. Between 0 and 1."""
+
+     use_beam_search: bool = False
+     """Whether to use beam search instead of sampling."""
+
+     ignore_eos: bool = False
+     """Whether to ignore the EOS token and continue generating tokens after
+     the EOS token is generated."""
+
+     logprobs: Optional[int] = None
+     """Number of log probabilities to return per output token."""
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of llm."""
+         return "oci_model_deployment_vllm_endpoint"
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for calling vllm."""
+         return {
+             "best_of": self.best_of,
+             "frequency_penalty": self.frequency_penalty,
+             "ignore_eos": self.ignore_eos,
+             "logprobs": self.logprobs,
+             "max_tokens": self.max_tokens,
+             "model": self.model,
+             "n": self.n,
+             "presence_penalty": self.presence_penalty,
+             "stop": self.stop,
+             "stream": self.streaming,
+             "temperature": self.temperature,
+             "top_k": self.k,
+             "top_p": self.p,
+             "use_beam_search": self.use_beam_search,
+         }
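
For orientation, below is a minimal usage sketch of the completion classes added in this file. It assumes the classes are re-exported from `ads.llm` (per the `ads/llm/__init__.py` change listed above) and that a reachable model deployment exists; the region, OCID, and model name are placeholders, not values from this diff.

```python
# Minimal sketch of invoking an OCI Data Science model deployment through the
# new LangChain completion class. Assumptions: `OCIModelDeploymentVLLM` is
# importable from `ads.llm`, and the deployment serves an OpenAI-style
# /v1/completions route ("<region>" and "<MD_OCID>" are placeholders).
import ads
from ads.llm import OCIModelDeploymentVLLM

# Configure the default signer used when no `auth` dict is passed in.
ads.set_auth("resource_principal")  # or "api_key"

llm = OCIModelDeploymentVLLM(
    endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<MD_OCID>/predict",
    model="odsc-llm",
    temperature=0.2,
    max_tokens=128,
)

# Single completion.
print(llm.invoke("Tell me a joke."))

# Token-by-token streaming; the class sets the `enable-streaming` and
# `Accept: text/event-stream` headers and parses the SSE lines internally.
for chunk in llm.stream("Tell me a joke."):
    print(chunk, end="", flush=True)
```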