oracle-ads 2.11.19__py3-none-any.whl → 2.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,924 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright (c) 2023 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+
7
+
8
+ import json
9
+ import logging
10
+ from operator import itemgetter
11
+ from typing import (
12
+ Any,
13
+ AsyncIterator,
14
+ Dict,
15
+ Iterator,
16
+ List,
17
+ Literal,
18
+ Optional,
19
+ Type,
20
+ Union,
21
+ Sequence,
22
+ Callable,
23
+ )
24
+
25
+ from langchain_core.callbacks import (
26
+ AsyncCallbackManagerForLLMRun,
27
+ CallbackManagerForLLMRun,
28
+ )
29
+ from langchain_core.language_models import LanguageModelInput
30
+ from langchain_core.language_models.chat_models import (
31
+ BaseChatModel,
32
+ agenerate_from_stream,
33
+ generate_from_stream,
34
+ )
35
+ from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
36
+ from langchain_core.tools import BaseTool
37
+ from langchain_core.output_parsers import (
38
+ JsonOutputParser,
39
+ PydanticOutputParser,
40
+ )
41
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
42
+ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
43
+ from langchain_core.utils.function_calling import convert_to_openai_tool
44
+ from langchain_openai.chat_models.base import (
45
+ _convert_delta_to_message_chunk,
46
+ _convert_message_to_dict,
47
+ _convert_dict_to_message,
48
+ )
49
+
50
+ from pydantic import BaseModel, Field
51
+ from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
52
+ DEFAULT_MODEL_NAME,
53
+ BaseOCIModelDeployment,
54
+ )
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+
59
+ def _is_pydantic_class(obj: Any) -> bool:
60
+ return isinstance(obj, type) and issubclass(obj, BaseModel)
61
+
62
+
63
+ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
64
+ """OCI Data Science Model Deployment chat model integration.
65
+
66
+ To use, you must provide the model HTTP endpoint from your deployed
67
+ chat model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.
68
+
69
+ To authenticate, `oracle-ads` is used to automatically load
70
+ credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
71
+
72
+ Make sure to have the required policies to access the OCI Data
73
+ Science Model Deployment endpoint. See:
74
+ https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
75
+
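+ For example, authentication can be configured before instantiating the
+ model (a minimal sketch; resource principal is one of several methods
+ accepted by ``ads.set_auth``):
+
+ .. code-block:: python
+
+ import ads
+
+ # Use resource principal, e.g. inside an OCI notebook session or job.
+ ads.set_auth("resource_principal")
+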
76
+ Instantiate:
77
+ .. code-block:: python
78
+
79
+ from langchain_community.chat_models import ChatOCIModelDeployment
80
+
81
+ chat = ChatOCIModelDeployment(
82
+ endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
83
+ model="odsc-llm",
84
+ streaming=True,
85
+ max_retries=3,
86
+ model_kwargs={
87
+ "max_token": 512,
88
+ "temperature": 0.2,
89
+ # other model parameters ...
90
+ },
91
+ )
92
+
93
+ Invocation:
94
+ .. code-block:: python
95
+
96
+ messages = [
97
+ ("system", "You are a helpful translator. Translate the user sentence to French."),
98
+ ("human", "Hello World!"),
99
+ ]
100
+ chat.invoke(messages)
101
+
102
+ .. code-block:: python
103
+
104
+ AIMessage(
105
+ content='Bonjour le monde!', response_metadata={'token_usage': {'prompt_tokens': 40, 'total_tokens': 50, 'completion_tokens': 10}, 'model_name': 'odsc-llm', 'system_fingerprint': '', 'finish_reason': 'stop'}, id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0')
106
+
107
+ Streaming:
108
+ .. code-block:: python
109
+
110
+ for chunk in chat.stream(messages):
111
+ print(chunk)
112
+
113
+ .. code-block:: python
114
+
115
+ content='' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
116
+ content='\n' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
117
+ content='B' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
118
+ content='on' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
119
+ content='j' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
120
+ content='our' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
121
+ content=' le' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
122
+ content=' monde' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
123
+ content='!' id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
124
+ content='' response_metadata={'finish_reason': 'stop'} id='run-23df02c6-c43f-42de-87c6-8ad382e125c3'
125
+
126
+ Async:
127
+ .. code-block:: python
128
+
129
+ await chat.ainvoke(messages)
130
+
131
+ # stream:
132
+ # async for chunk in chat.astream(messages)
133
+
134
+ .. code-block:: python
135
+
136
+ AIMessage(content='Bonjour le monde!', response_metadata={'finish_reason': 'stop'}, id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0')
137
+
138
+ Structured output:
139
+ .. code-block:: python
140
+
141
+ from typing import Optional
142
+ from pydantic import BaseModel, Field
143
+
144
+ class Joke(BaseModel):
145
+ setup: str = Field(description="The setup of the joke")
146
+ punchline: str = Field(description="The punchline to the joke")
147
+
148
+ structured_llm = chat.with_structured_output(Joke, method="json_mode")
149
+ structured_llm.invoke(
150
+ "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
151
+ )
152
+
153
+ .. code-block:: python
154
+
155
+ Joke(setup='Why did the cat get stuck in the tree?', punchline='Because it was chasing its tail!')
156
+
157
+ See ``ChatOCIModelDeployment.with_structured_output()`` for more.
158
+
159
+ Customized Usage:
160
+
161
+ You can inherit from the base class and override `_process_response`, `_process_stream_response`,
162
+ or `_construct_json_body` to satisfy your customized needs.
163
+
164
+ .. code-block:: python
165
+
166
+ class MyChatModel(ChatOCIModelDeployment):
167
+ def _process_stream_response(self, response_json: dict) -> ChatGenerationChunk:
168
+ print("My customized streaming result handler.")
169
+ return ChatGenerationChunk(...)
170
+
171
+ def _process_response(self, response_json:dict) -> ChatResult:
172
+ print("My customized output handler.")
173
+ return ChatResult(...)
174
+
175
+ def _construct_json_body(self, messages: list, params: dict) -> dict:
176
+ print("My customized payload handler.")
177
+ return {
178
+ "messages": messages,
179
+ **params,
180
+ }
181
+
182
+ chat = MyChatModel(
183
+ endpoint=f"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/{ocid}/predict",
184
+ model="odsc-llm",
185
+ )
186
+
187
+ chat.invoke("tell me a joke")
188
+
189
+ """ # noqa: E501
190
+
191
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
192
+ """Keyword arguments to pass to the model."""
193
+
194
+ model: str = DEFAULT_MODEL_NAME
195
+ """The name of the model."""
196
+
197
+ stop: Optional[List[str]] = None
198
+ """Stop words to use when generating. Model output is cut off
199
+ at the first occurrence of any of these substrings."""
200
+
201
+ @property
202
+ def _llm_type(self) -> str:
203
+ """Return type of llm."""
204
+ return "oci_model_depolyment_chat_endpoint"
205
+
206
+ @property
207
+ def _identifying_params(self) -> Dict[str, Any]:
208
+ """Get the identifying parameters."""
209
+ _model_kwargs = self.model_kwargs or {}
210
+ return {
211
+ **{"endpoint": self.endpoint, "model_kwargs": _model_kwargs},
212
+ **self._default_params,
213
+ }
214
+
215
+ @property
216
+ def _default_params(self) -> Dict[str, Any]:
217
+ """Get the default parameters."""
218
+ return {
219
+ "model": self.model,
220
+ "stop": self.stop,
221
+ "stream": self.streaming,
222
+ }
223
+
224
+ def _generate(
225
+ self,
226
+ messages: List[BaseMessage],
227
+ stop: Optional[List[str]] = None,
228
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
229
+ **kwargs: Any,
230
+ ) -> ChatResult:
231
+ """Call out to an OCI Model Deployment Online endpoint.
232
+
233
+ Args:
234
+ messages: The messages in the conversation with the chat model.
235
+ stop: Optional list of stop words to use when generating.
236
+
237
+ Returns:
238
+ LangChain ChatResult
239
+
240
+ Raises:
241
+ RuntimeError:
242
+ Raised when invoking the endpoint fails.
243
+
244
+ Example:
245
+
246
+ .. code-block:: python
247
+
248
+ messages = [
249
+ (
250
+ "system",
251
+ "You are a helpful assistant that translates English to French. Translate the user sentence.",
252
+ ),
253
+ ("human", "Hello World!"),
254
+ ]
255
+
256
+ response = chat.invoke(messages)
257
+ """ # noqa: E501
258
+ if self.streaming:
259
+ stream_iter = self._stream(
260
+ messages, stop=stop, run_manager=run_manager, **kwargs
261
+ )
262
+ return generate_from_stream(stream_iter)
263
+
264
+ requests_kwargs = kwargs.pop("requests_kwargs", {})
265
+ params = self._invocation_params(stop, **kwargs)
266
+ body = self._construct_json_body(messages, params)
267
+ res = self.completion_with_retry(
268
+ data=body, run_manager=run_manager, **requests_kwargs
269
+ )
270
+ return self._process_response(res.json())
271
+
272
+ def _stream(
273
+ self,
274
+ messages: List[BaseMessage],
275
+ stop: Optional[List[str]] = None,
276
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
277
+ **kwargs: Any,
278
+ ) -> Iterator[ChatGenerationChunk]:
279
+ """Stream OCI Data Science Model Deployment endpoint on given messages.
280
+
281
+ Args:
282
+ messages (List[BaseMessage]):
283
+ The messages to pass into the model.
284
+ stop (List[str], Optional):
285
+ List of stop words to use when generating.
286
+ kwargs:
287
+ requests_kwargs:
288
+ Additional ``**kwargs`` to pass to requests.post
289
+
290
+ Returns:
291
+ An iterator of ChatGenerationChunk.
292
+
293
+ Raises:
294
+ RuntimeError:
295
+ Raised when invoking the endpoint fails.
296
+
297
+ Example:
298
+
299
+ .. code-block:: python
300
+
301
+ messages = [
302
+ (
303
+ "system",
304
+ "You are a helpful assistant that translates English to French. Translate the user sentence.",
305
+ ),
306
+ ("human", "Hello World!"),
307
+ ]
308
+
309
+ chunk_iter = chat.stream(messages)
310
+
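+ Additional keyword arguments for the HTTP call can be forwarded via
+ ``requests_kwargs`` (an illustrative sketch; ``timeout`` is simply
+ passed through to ``requests.post``):
+
+ .. code-block:: python
+
+ chunk_iter = chat.stream(messages, requests_kwargs={"timeout": 60})
+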
311
+ """ # noqa: E501
312
+ requests_kwargs = kwargs.pop("requests_kwargs", {})
313
+ self.streaming = True
314
+ params = self._invocation_params(stop, **kwargs)
315
+ body = self._construct_json_body(messages, params) # request json body
316
+
317
+ response = self.completion_with_retry(
318
+ data=body, run_manager=run_manager, stream=True, **requests_kwargs
319
+ )
320
+ default_chunk_class = AIMessageChunk
321
+ for line in self._parse_stream(response.iter_lines()):
322
+ chunk = self._handle_sse_line(line, default_chunk_class)
323
+ if run_manager:
324
+ run_manager.on_llm_new_token(chunk.text, chunk=chunk)
325
+ yield chunk
326
+
327
+ async def _agenerate(
328
+ self,
329
+ messages: List[BaseMessage],
330
+ stop: Optional[List[str]] = None,
331
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
332
+ **kwargs: Any,
333
+ ) -> ChatResult:
334
+ """Asynchronously call out to OCI Data Science Model Deployment
335
+ endpoint with the given messages.
336
+
337
+ Args:
338
+ messages (List[BaseMessage]):
339
+ The messages to pass into the model.
340
+ stop (List[str], Optional):
341
+ List of stop words to use when generating.
342
+ kwargs:
343
+ requests_kwargs:
344
+ Additional ``**kwargs`` to pass to requests.post
345
+
346
+ Returns:
347
+ LangChain ChatResult.
348
+
349
+ Raises:
350
+ ValueError:
351
+ Raised when invoking the endpoint fails.
352
+
353
+ Example:
354
+
355
+ .. code-block:: python
356
+
357
+ messages = [
358
+ (
359
+ "system",
360
+ "You are a helpful assistant that translates English to French. Translate the user sentence.",
361
+ ),
362
+ ("human", "I love programming."),
363
+ ]
364
+
365
+ resp = await chat.ainvoke(messages)
366
+
367
+ """ # noqa: E501
368
+ if self.streaming:
369
+ stream_iter = self._astream(
370
+ messages, stop=stop, run_manager=run_manager, **kwargs
371
+ )
372
+ return await agenerate_from_stream(stream_iter)
373
+
374
+ requests_kwargs = kwargs.pop("requests_kwargs", {})
375
+ params = self._invocation_params(stop, **kwargs)
376
+ body = self._construct_json_body(messages, params)
377
+ response = await self.acompletion_with_retry(
378
+ data=body,
379
+ run_manager=run_manager,
380
+ **requests_kwargs,
381
+ )
382
+ return self._process_response(response)
383
+
384
+ async def _astream(
385
+ self,
386
+ messages: List[BaseMessage],
387
+ stop: Optional[List[str]] = None,
388
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
389
+ **kwargs: Any,
390
+ ) -> AsyncIterator[ChatGenerationChunk]:
391
+ """Asynchronously streaming OCI Data Science Model Deployment
392
+ endpoint with the given messages.
393
+
394
+ Args:
395
+ messages (List[BaseMessage]):
396
+ The messages to pass into the model.
397
+ stop (List[str], Optional):
398
+ List of stop words to use when generating.
399
+ kwargs:
400
+ requests_kwargs:
401
+ Additional ``**kwargs`` to pass to requests.post
402
+
403
+ Returns:
404
+ An AsyncIterator of ChatGenerationChunk.
405
+
406
+ Raises:
407
+ ValueError:
408
+ Raised when invoking the endpoint fails.
409
+
410
+ Example:
411
+
412
+ .. code-block:: python
413
+
414
+ messages = [
415
+ (
416
+ "system",
417
+ "You are a helpful assistant that translates English to French. Translate the user sentence.",
418
+ ),
419
+ ("human", "I love programming."),
420
+ ]
421
+
422
+ chunk_iter = chat.astream(messages)
423
+
424
+ """ # noqa: E501
425
+ requests_kwargs = kwargs.pop("requests_kwargs", {})
426
+ self.streaming = True
427
+ params = self._invocation_params(stop, **kwargs)
428
+ body = self._construct_json_body(messages, params) # request json body
429
+
430
+ default_chunk_class = AIMessageChunk
431
+ async for line in await self.acompletion_with_retry(
432
+ data=body, run_manager=run_manager, stream=True, **requests_kwargs
433
+ ):
434
+ chunk = self._handle_sse_line(line, default_chunk_class)
435
+ if run_manager:
436
+ await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
437
+ yield chunk
438
+
439
+ def with_structured_output(
440
+ self,
441
+ schema: Optional[Union[Dict, Type[BaseModel]]] = None,
442
+ *,
443
+ method: Literal["json_mode"] = "json_mode",
444
+ include_raw: bool = False,
445
+ **kwargs: Any,
446
+ ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
447
+ """Model wrapper that returns outputs formatted to match the given schema.
448
+
449
+ Args:
450
+ schema: The output schema as a dict or a Pydantic class. If a Pydantic class
451
+ then the model output will be an object of that class. If a dict then
452
+ the model output will be a dict. With a Pydantic class the returned
453
+ attributes will be validated, whereas with a dict they will not be. If
454
+ `method` is "function_calling" and `schema` is a dict, then the dict
455
+ must match the OpenAI function-calling spec.
456
+ method: The method for steering model generation; currently only
457
+ for "json_mode". If "json_mode" then JSON mode will be used. Note that
458
+ if using "json_mode" then you must include instructions for formatting
459
+ the output into the desired schema in the model call.
460
+ include_raw: If False then only the parsed structured output is returned. If
461
+ an error occurs during model output parsing it will be raised. If True
462
+ then both the raw model response (a BaseMessage) and the parsed model
463
+ response will be returned. If an error occurs during output parsing it
464
+ will be caught and returned as well. The final output is always a dict
465
+ with keys "raw", "parsed", and "parsing_error".
466
+
467
+ Returns:
468
+ A Runnable that takes any ChatModel input and returns as output:
469
+
470
+ If include_raw is True then a dict with keys:
471
+ raw: BaseMessage
472
+ parsed: Optional[_DictOrPydantic]
473
+ parsing_error: Optional[BaseException]
474
+
475
+ If include_raw is False then just _DictOrPydantic is returned,
476
+ where _DictOrPydantic depends on the schema:
477
+
478
+ If schema is a Pydantic class then _DictOrPydantic is the Pydantic
479
+ class.
480
+
481
+ If schema is a dict then _DictOrPydantic is a dict.
482
+
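+ Example (a minimal sketch; assumes ``chat`` is an instantiated
+ ``ChatOCIModelDeployment`` and uses an illustrative ``Joke`` schema):
+
+ .. code-block:: python
+
+ from pydantic import BaseModel, Field
+
+ class Joke(BaseModel):
+ setup: str = Field(description="The setup of the joke")
+ punchline: str = Field(description="The punchline to the joke")
+
+ structured_llm = chat.with_structured_output(Joke, include_raw=True)
+ result = structured_llm.invoke(
+ "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
+ )
+ # result is a dict with keys "raw", "parsed" and "parsing_error".
+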
483
+ """ # noqa: E501
484
+ if kwargs:
485
+ raise ValueError(f"Received unsupported arguments {kwargs}")
486
+ is_pydantic_schema = _is_pydantic_class(schema)
487
+ if method == "json_mode":
488
+ llm = self.bind(response_format={"type": "json_object"})
489
+ output_parser = (
490
+ PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
491
+ if is_pydantic_schema
492
+ else JsonOutputParser()
493
+ )
494
+ else:
495
+ raise ValueError(
496
+ f"Unrecognized method argument. Expected `json_mode`."
497
+ f"Received: `{method}`."
498
+ )
499
+
500
+ if include_raw:
501
+ parser_assign = RunnablePassthrough.assign(
502
+ parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
503
+ )
504
+ parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
505
+ parser_with_fallback = parser_assign.with_fallbacks(
506
+ [parser_none], exception_key="parsing_error"
507
+ )
508
+ return RunnableMap(raw=llm) | parser_with_fallback
509
+ else:
510
+ return llm | output_parser
511
+
512
+ def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
513
+ """Combines the invocation parameters with default parameters."""
514
+ params = self._default_params
515
+ _model_kwargs = self.model_kwargs or {}
516
+ params["stop"] = stop or params.get("stop", [])
517
+ return {**params, **_model_kwargs, **kwargs}
518
+
519
+ def _handle_sse_line(
520
+ self, line: str, default_chunk_cls: Type[BaseMessageChunk] = AIMessageChunk
521
+ ) -> ChatGenerationChunk:
522
+ """Handle a single Server-Sent Events (SSE) line and process it into
523
+ a chat generation chunk.
524
+
525
+ Args:
526
+ line (str): A single line from the SSE stream in string format.
527
+ default_chunk_cls (AIMessageChunk): The default class for message
528
+ chunks to be used during the processing of the stream response.
529
+
530
+ Returns:
531
+ ChatGenerationChunk: The processed chat generation chunk. If an error
532
+ occurs, an empty `ChatGenerationChunk` is returned.
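+
+ Example (an illustrative, OpenAI-style chunk; the exact payload depends
+ on the deployed inference server):
+
+ .. code-block:: python
+
+ line = '{"choices": [{"delta": {"content": "Bonjour"}, "finish_reason": null}]}'
+ chunk = self._handle_sse_line(line)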
533
+ """
534
+ try:
535
+ obj = json.loads(line)
536
+ return self._process_stream_response(obj, default_chunk_cls)
537
+ except Exception as e:
538
+ logger.debug(f"Error occurs when processing line={line}: {str(e)}")
539
+ return ChatGenerationChunk(message=AIMessageChunk(content=""))
540
+
541
+ def _construct_json_body(self, messages: list, params: dict) -> dict:
542
+ """Constructs the request body as a dictionary (JSON).
543
+
544
+ Args:
545
+ messages (list): A list of message objects to be included in the
546
+ request body.
547
+ params (dict): A dictionary of additional parameters to be included
548
+ in the request body.
549
+
550
+ Returns:
551
+ dict: A dictionary representing the JSON request body, including
552
+ converted messages and additional parameters.
553
+
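+ Example of a constructed body (illustrative only; the parameters come
+ from ``_invocation_params``):
+
+ .. code-block:: python
+
+ {
+ "messages": [{"role": "user", "content": "Hello World!"}],
+ "model": "odsc-llm",
+ "stop": None,
+ "stream": False,
+ }
+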
554
+ """
555
+ return {
556
+ "messages": [_convert_message_to_dict(m) for m in messages],
557
+ **params,
558
+ }
559
+
560
+ def _process_stream_response(
561
+ self,
562
+ response_json: dict,
563
+ default_chunk_cls: Type[BaseMessageChunk] = AIMessageChunk,
564
+ ) -> ChatGenerationChunk:
565
+ """Formats streaming response in OpenAI spec.
566
+
567
+ Args:
568
+ response_json (dict): The JSON response from the streaming endpoint.
569
+ default_chunk_cls (type, optional): The default class to use for
570
+ creating message chunks. Defaults to `AIMessageChunk`.
571
+
572
+ Returns:
573
+ ChatGenerationChunk: An object containing the processed message
574
+ chunk and any relevant generation information such as finish
575
+ reason and usage.
576
+
577
+ Raises:
578
+ ValueError: If the response JSON is not well-formed or does not
579
+ contain the expected structure.
580
+ """
581
+ try:
582
+ choice = response_json["choices"][0]
583
+ if not isinstance(choice, dict):
584
+ raise TypeError("Endpoint response is not well formed.")
585
+ except (KeyError, IndexError, TypeError) as e:
586
+ raise ValueError(
587
+ "Error while formatting response payload for chat model of type"
588
+ ) from e
589
+
590
+ chunk = _convert_delta_to_message_chunk(choice["delta"], default_chunk_cls)
591
+ default_chunk_cls = chunk.__class__
592
+ finish_reason = choice.get("finish_reason")
593
+ usage = choice.get("usage")
594
+ gen_info = {}
595
+ if finish_reason is not None:
596
+ gen_info.update({"finish_reason": finish_reason})
597
+ if usage is not None:
598
+ gen_info.update({"usage": usage})
599
+
600
+ return ChatGenerationChunk(
601
+ message=chunk, generation_info=gen_info if gen_info else None
602
+ )
603
+
604
+ def _process_response(self, response_json: dict) -> ChatResult:
605
+ """Formats response in OpenAI spec.
606
+
607
+ Args:
608
+ response_json (dict): The JSON response from the chat model endpoint.
609
+
610
+ Returns:
611
+ ChatResult: An object containing the list of `ChatGeneration` objects
612
+ and additional LLM output information.
613
+
614
+ Raises:
615
+ ValueError: If the response JSON is not well-formed or does not
616
+ contain the expected structure.
617
+
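+ Example of a response payload this method expects (illustrative only;
+ the exact fields depend on the deployed inference server):
+
+ .. code-block:: python
+
+ {
+ "choices": [
+ {
+ "message": {"role": "assistant", "content": "Bonjour le monde!"},
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {"prompt_tokens": 40, "completion_tokens": 10, "total_tokens": 50},
+ }
+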
618
+ """
619
+ generations = []
620
+ try:
621
+ choices = response_json["choices"]
622
+ if not isinstance(choices, list):
623
+ raise TypeError("Endpoint response is not well formed.")
624
+ except (KeyError, TypeError) as e:
625
+ raise ValueError(
626
+ "Error while formatting response payload for chat model of type"
627
+ ) from e
628
+
629
+ for choice in choices:
630
+ message = _convert_dict_to_message(choice["message"])
631
+ generation_info = dict(finish_reason=choice.get("finish_reason"))
632
+ if "logprobs" in choice:
633
+ generation_info["logprobs"] = choice["logprobs"]
634
+
635
+ gen = ChatGeneration(
636
+ message=message,
637
+ generation_info=generation_info,
638
+ )
639
+ generations.append(gen)
640
+
641
+ token_usage = response_json.get("usage", {})
642
+ llm_output = {
643
+ "token_usage": token_usage,
644
+ "model_name": self.model,
645
+ "system_fingerprint": response_json.get("system_fingerprint", ""),
646
+ }
647
+ return ChatResult(generations=generations, llm_output=llm_output)
648
+
649
+ def bind_tools(
650
+ self,
651
+ tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
652
+ **kwargs: Any,
653
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
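+ """Bind tool-like objects to this chat model.
+
+ Each tool is converted to the OpenAI tool schema with
+ ``convert_to_openai_tool`` and bound to the model with ``bind``, so the
+ formatted tools are included in every request.
+
+ Args:
+ tools: A sequence of tool definitions (dicts, Pydantic classes,
+ callables, or ``BaseTool`` instances) to bind to the model.
+ kwargs: Additional parameters passed to ``bind``, for example
+ ``tool_choice``.
+ """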
654
+ formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
655
+ return super().bind(tools=formatted_tools, **kwargs)
656
+
657
+
658
+ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
659
+ """OCI large language chat models deployed with vLLM.
660
+
661
+ To use, you must provide the model HTTP endpoint from your deployed
662
+ model, e.g. https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict.
663
+
664
+ To authenticate, `oracle-ads` is used to automatically load
665
+ credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
666
+
667
+ Make sure to have the required policies to access the OCI Data
668
+ Science Model Deployment endpoint. See:
669
+ https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
670
+
671
+ Example:
672
+
673
+ .. code-block:: python
674
+
675
+ from langchain_community.chat_models import ChatOCIModelDeploymentVLLM
676
+
677
+ chat = ChatOCIModelDeploymentVLLM(
678
+ endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
679
+ frequency_penalty=0.1,
680
+ max_tokens=512,
681
+ temperature=0.2,
682
+ top_p=1.0,
683
+ # other model parameters...
684
+ )
685
+
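+ Tool calling (a minimal sketch; requires a model and vLLM deployment
+ configured for tool calling, e.g. with ``--tool-call-parser``; the
+ ``get_weather`` tool below is illustrative):
+
+ .. code-block:: python
+
+ from langchain_core.tools import tool
+
+ @tool
+ def get_weather(city: str) -> str:
+ "Get the current weather for a given city."
+ return f"It is sunny in {city}."
+
+ chat_with_tools = chat.bind_tools([get_weather], tool_choice="auto")
+ response = chat_with_tools.invoke("What is the weather in Paris?")
+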
686
+ """ # noqa: E501
687
+
688
+ frequency_penalty: float = 0.0
689
+ """Penalizes repeated tokens according to frequency. Between 0 and 1."""
690
+
691
+ logit_bias: Optional[Dict[str, float]] = None
692
+ """Adjust the probability of specific tokens being generated."""
693
+
694
+ max_tokens: Optional[int] = 256
695
+ """The maximum number of tokens to generate in the completion."""
696
+
697
+ n: int = 1
698
+ """Number of output sequences to return for the given prompt."""
699
+
700
+ presence_penalty: float = 0.0
701
+ """Penalizes repeated tokens. Between 0 and 1."""
702
+
703
+ temperature: float = 0.2
704
+ """What sampling temperature to use."""
705
+
706
+ top_p: float = 1.0
707
+ """Total probability mass of tokens to consider at each step."""
708
+
709
+ best_of: Optional[int] = None
710
+ """Generates best_of completions server-side and returns the "best"
711
+ (the one with the highest log probability per token).
712
+ """
713
+
714
+ use_beam_search: Optional[bool] = False
715
+ """Whether to use beam search instead of sampling."""
716
+
717
+ top_k: Optional[int] = -1
718
+ """Number of most likely tokens to consider at each step."""
719
+
720
+ min_p: Optional[float] = 0.0
721
+ """Float that represents the minimum probability for a token to be considered.
722
+ Must be in [0,1]. 0 to disable this."""
723
+
724
+ repetition_penalty: Optional[float] = 1.0
725
+ """Float that penalizes new tokens based on their frequency in the
726
+ generated text. Values > 1 encourage the model to use new tokens."""
727
+
728
+ length_penalty: Optional[float] = 1.0
729
+ """Float that penalizes sequences based on their length. Used only
730
+ when `use_beam_search` is True."""
731
+
732
+ early_stopping: Optional[bool] = False
733
+ """Controls the stopping condition for beam search. It accepts the
734
+ following values: `True`, where the generation stops as soon as there
735
+ are `best_of` complete candidates; `False`, where a heuristic is applied
736
+ to the generation stops when it is very unlikely to find better candidates;
737
+ `never`, where the beam search procedure only stops where there cannot be
738
+ better candidates (canonical beam search algorithm)."""
739
+
740
+ ignore_eos: Optional[bool] = False
741
+ """Whether to ignore the EOS token and continue generating tokens after
742
+ the EOS token is generated."""
743
+
744
+ min_tokens: Optional[int] = 0
745
+ """Minimum number of tokens to generate per output sequence before
746
+ EOS or stop_token_ids can be generated"""
747
+
748
+ stop_token_ids: Optional[List[int]] = None
749
+ """List of tokens that stop the generation when they are generated.
750
+ The returned output will contain the stop tokens unless the stop tokens
751
+ are special tokens."""
752
+
753
+ skip_special_tokens: Optional[bool] = True
754
+ """Whether to skip special tokens in the output. Defaults to True."""
755
+
756
+ spaces_between_special_tokens: Optional[bool] = True
757
+ """Whether to add spaces between special tokens in the output.
758
+ Defaults to True."""
759
+
760
+ tool_choice: Optional[str] = None
761
+ """Whether to use tool calling.
762
+ Defaults to None, in which case tool calling is disabled.
763
+ Tool calling requires model support and vLLM to be configured with `--tool-call-parser`.
764
+ Set this to `auto` for the model to determine whether to make tool calls automatically.
765
+ Set this to `required` to force the model to always call one or more tools.
766
+ """
767
+
768
+ chat_template: Optional[str] = None
769
+ """Use customized chat template.
770
+ Defaults to None. The chat template from the tokenizer will be used.
771
+ """
772
+
773
+ @property
774
+ def _llm_type(self) -> str:
775
+ """Return type of llm."""
776
+ return "oci_model_depolyment_chat_endpoint_vllm"
777
+
778
+ @property
779
+ def _default_params(self) -> Dict[str, Any]:
780
+ """Get the default parameters."""
781
+ params = {
782
+ "model": self.model,
783
+ "stop": self.stop,
784
+ "stream": self.streaming,
785
+ }
786
+ for attr_name in self._get_model_params():
787
+ try:
788
+ value = getattr(self, attr_name)
789
+ if value is not None:
790
+ params.update({attr_name: value})
791
+ except Exception:
792
+ pass
793
+
794
+ return params
795
+
796
+ def _get_model_params(self) -> List[str]:
797
+ """Gets the name of model parameters."""
798
+ return [
799
+ "best_of",
800
+ "early_stopping",
801
+ "frequency_penalty",
802
+ "ignore_eos",
803
+ "length_penalty",
804
+ "logit_bias",
805
+ "logprobs",
806
+ "max_tokens",
807
+ "min_p",
808
+ "min_tokens",
809
+ "n",
810
+ "presence_penalty",
811
+ "repetition_penalty",
812
+ "skip_special_tokens",
813
+ "spaces_between_special_tokens",
814
+ "stop_token_ids",
815
+ "temperature",
816
+ "top_k",
817
+ "top_p",
818
+ "use_beam_search",
819
+ "tool_choice",
820
+ "chat_template",
821
+ ]
822
+
823
+
824
+ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
825
+ """OCI large language chat models deployed with Text Generation Inference.
826
+
827
+ To use, you must provide the model HTTP endpoint from your deployed
828
+ model, e.g. https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict.
829
+
830
+ To authenticate, `oracle-ads` is used to automatically load
831
+ credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html
832
+
833
+ Make sure to have the required policies to access the OCI Data
834
+ Science Model Deployment endpoint. See:
835
+ https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint
836
+
837
+ Example:
838
+
839
+ .. code-block:: python
840
+
841
+ from langchain_community.chat_models import ChatOCIModelDeploymentTGI
842
+
843
+ chat = ChatOCIModelDeploymentTGI(
844
+ endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
845
+ max_tokens=512,
846
+ temperature=0.2,
847
+ frequency_penalty=0.1,
848
+ seed=42,
849
+ # other model parameters...
850
+ )
851
+
852
+ """ # noqa: E501
853
+
854
+ frequency_penalty: Optional[float] = None
855
+ """Penalizes repeated tokens according to frequency. Between 0 and 1."""
856
+
857
+ logit_bias: Optional[Dict[str, float]] = None
858
+ """Adjust the probability of specific tokens being generated."""
859
+
860
+ logprobs: Optional[bool] = None
861
+ """Whether to return log probabilities of the output tokens or not."""
862
+
863
+ max_tokens: int = 256
864
+ """The maximum number of tokens to generate in the completion."""
865
+
866
+ n: int = 1
867
+ """Number of output sequences to return for the given prompt."""
868
+
869
+ presence_penalty: Optional[float] = None
870
+ """Penalizes repeated tokens. Between 0 and 1."""
871
+
872
+ seed: Optional[int] = None
873
+ """To sample deterministically,"""
874
+
875
+ temperature: float = 0.2
876
+ """What sampling temperature to use."""
877
+
878
+ top_p: Optional[float] = None
879
+ """Total probability mass of tokens to consider at each step."""
880
+
881
+ top_logprobs: Optional[int] = None
882
+ """An integer between 0 and 5 specifying the number of most
883
+ likely tokens to return at each token position, each with an
884
+ associated log probability. logprobs must be set to true if
885
+ this parameter is used."""
886
+
887
+ @property
888
+ def _llm_type(self) -> str:
889
+ """Return type of llm."""
890
+ return "oci_model_depolyment_chat_endpoint_tgi"
891
+
892
+ @property
893
+ def _default_params(self) -> Dict[str, Any]:
894
+ """Get the default parameters."""
895
+ params = {
896
+ "model": self.model,
897
+ "stop": self.stop,
898
+ "stream": self.streaming,
899
+ }
900
+ for attr_name in self._get_model_params():
901
+ try:
902
+ value = getattr(self, attr_name)
903
+ if value is not None:
904
+ params.update({attr_name: value})
905
+ except Exception:
906
+ pass
907
+
908
+ return params
909
+
910
+ def _get_model_params(self) -> List[str]:
911
+ """Gets the name of model parameters."""
912
+ return [
913
+ "frequency_penalty",
914
+ "logit_bias",
915
+ "logprobs",
916
+ "max_tokens",
917
+ "n",
918
+ "presence_penalty",
919
+ "seed",
920
+ "temperature",
921
+ "top_k",
922
+ "top_p",
923
+ "top_logprobs",
924
+ ]