langchain-google-genai 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai might be problematic. Click here for more details.
- langchain_google_genai/__init__.py +1 -1
- langchain_google_genai/_common.py +48 -0
- langchain_google_genai/_enums.py +4 -4
- langchain_google_genai/_function_utils.py +50 -54
- langchain_google_genai/_genai_extension.py +64 -7
- langchain_google_genai/_image_utils.py +187 -0
- langchain_google_genai/chat_models.py +211 -107
- langchain_google_genai/embeddings.py +85 -41
- langchain_google_genai/llms.py +0 -1
- {langchain_google_genai-1.0.3.dist-info → langchain_google_genai-1.0.5.dist-info}/METADATA +2 -2
- langchain_google_genai-1.0.5.dist-info/RECORD +16 -0
- langchain_google_genai-1.0.3.dist-info/RECORD +0 -15
- {langchain_google_genai-1.0.3.dist-info → langchain_google_genai-1.0.5.dist-info}/LICENSE +0 -0
- {langchain_google_genai-1.0.3.dist-info → langchain_google_genai-1.0.5.dist-info}/WHEEL +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import base64
|
|
4
5
|
import json
|
|
5
6
|
import logging
|
|
@@ -23,14 +24,23 @@ from typing import (
|
|
|
23
24
|
)
|
|
24
25
|
from urllib.parse import urlparse
|
|
25
26
|
|
|
26
|
-
import google.ai.generativelanguage as glm
|
|
27
27
|
import google.api_core
|
|
28
28
|
|
|
29
29
|
# TODO: remove ignore once the google package is published with types
|
|
30
|
-
import google.generativeai as genai # type: ignore[import]
|
|
31
30
|
import proto # type: ignore[import]
|
|
32
31
|
import requests
|
|
33
|
-
from google.
|
|
32
|
+
from google.ai.generativelanguage_v1beta.types import (
|
|
33
|
+
Candidate,
|
|
34
|
+
Content,
|
|
35
|
+
FunctionCall,
|
|
36
|
+
FunctionResponse,
|
|
37
|
+
GenerateContentRequest,
|
|
38
|
+
GenerateContentResponse,
|
|
39
|
+
GenerationConfig,
|
|
40
|
+
Part,
|
|
41
|
+
SafetySetting,
|
|
42
|
+
ToolConfig,
|
|
43
|
+
)
|
|
34
44
|
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
|
|
35
45
|
from google.generativeai.types.content_types import ( # type: ignore[import]
|
|
36
46
|
FunctionDeclarationType,
|
|
@@ -56,7 +66,7 @@ from langchain_core.messages import (
|
|
|
56
66
|
)
|
|
57
67
|
from langchain_core.output_parsers.openai_tools import parse_tool_calls
|
|
58
68
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
59
|
-
from langchain_core.pydantic_v1 import SecretStr, root_validator
|
|
69
|
+
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
|
60
70
|
from langchain_core.runnables import Runnable
|
|
61
71
|
from langchain_core.utils import get_from_dict_or_env
|
|
62
72
|
from tenacity import (
|
|
@@ -67,7 +77,11 @@ from tenacity import (
|
|
|
67
77
|
wait_exponential,
|
|
68
78
|
)
|
|
69
79
|
|
|
70
|
-
from langchain_google_genai._common import
|
|
80
|
+
from langchain_google_genai._common import (
|
|
81
|
+
GoogleGenerativeAIError,
|
|
82
|
+
SafetySettingDict,
|
|
83
|
+
get_client_info,
|
|
84
|
+
)
|
|
71
85
|
from langchain_google_genai._function_utils import (
|
|
72
86
|
_tool_choice_to_tool_config,
|
|
73
87
|
_ToolChoiceType,
|
|
@@ -75,7 +89,10 @@ from langchain_google_genai._function_utils import (
|
|
|
75
89
|
convert_to_genai_function_declarations,
|
|
76
90
|
tool_to_dict,
|
|
77
91
|
)
|
|
78
|
-
from langchain_google_genai.
|
|
92
|
+
from langchain_google_genai._image_utils import ImageBytesLoader
|
|
93
|
+
from langchain_google_genai.llms import _BaseGoogleGenerativeAI
|
|
94
|
+
|
|
95
|
+
from . import _genai_extension as genaix
|
|
79
96
|
|
|
80
97
|
IMAGE_TYPES: Tuple = ()
|
|
81
98
|
try:
|
|
@@ -279,18 +296,19 @@ def _url_to_pil(image_source: str) -> Image:
|
|
|
279
296
|
|
|
280
297
|
def _convert_to_parts(
|
|
281
298
|
raw_content: Union[str, Sequence[Union[str, dict]]],
|
|
282
|
-
) -> List[
|
|
299
|
+
) -> List[Part]:
|
|
283
300
|
"""Converts a list of LangChain messages into a google parts."""
|
|
284
301
|
parts = []
|
|
285
302
|
content = [raw_content] if isinstance(raw_content, str) else raw_content
|
|
303
|
+
image_loader = ImageBytesLoader()
|
|
286
304
|
for part in content:
|
|
287
305
|
if isinstance(part, str):
|
|
288
|
-
parts.append(
|
|
306
|
+
parts.append(Part(text=part))
|
|
289
307
|
elif isinstance(part, Mapping):
|
|
290
308
|
# OpenAI Format
|
|
291
309
|
if _is_openai_parts_format(part):
|
|
292
310
|
if part["type"] == "text":
|
|
293
|
-
parts.append(
|
|
311
|
+
parts.append(Part(text=part["text"]))
|
|
294
312
|
elif part["type"] == "image_url":
|
|
295
313
|
img_url = part["image_url"]
|
|
296
314
|
if isinstance(img_url, dict):
|
|
@@ -299,7 +317,7 @@ def _convert_to_parts(
|
|
|
299
317
|
f"Unrecognized message image format: {img_url}"
|
|
300
318
|
)
|
|
301
319
|
img_url = img_url["url"]
|
|
302
|
-
parts.append(
|
|
320
|
+
parts.append(image_loader.load_part(img_url))
|
|
303
321
|
else:
|
|
304
322
|
raise ValueError(f"Unrecognized message part type: {part['type']}")
|
|
305
323
|
else:
|
|
@@ -307,7 +325,7 @@ def _convert_to_parts(
|
|
|
307
325
|
logger.warning(
|
|
308
326
|
"Unrecognized message part format. Assuming it's a text part."
|
|
309
327
|
)
|
|
310
|
-
parts.append(part)
|
|
328
|
+
parts.append(Part(text=str(part)))
|
|
311
329
|
else:
|
|
312
330
|
# TODO: Maybe some of Google's native stuff
|
|
313
331
|
# would hit this branch.
|
|
@@ -319,33 +337,36 @@ def _convert_to_parts(
|
|
|
319
337
|
|
|
320
338
|
def _parse_chat_history(
|
|
321
339
|
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
|
|
322
|
-
) -> Tuple[Optional[
|
|
323
|
-
messages: List[
|
|
340
|
+
) -> Tuple[Optional[Content], List[Content]]:
|
|
341
|
+
messages: List[Content] = []
|
|
324
342
|
|
|
325
343
|
if convert_system_message_to_human:
|
|
326
344
|
warnings.warn("Convert_system_message_to_human will be deprecated!")
|
|
327
345
|
|
|
328
|
-
system_instruction: Optional[
|
|
346
|
+
system_instruction: Optional[Content] = None
|
|
329
347
|
for i, message in enumerate(input_messages):
|
|
330
348
|
if i == 0 and isinstance(message, SystemMessage):
|
|
331
|
-
system_instruction = _convert_to_parts(message.content)
|
|
349
|
+
system_instruction = Content(parts=_convert_to_parts(message.content))
|
|
332
350
|
continue
|
|
333
351
|
elif isinstance(message, AIMessage):
|
|
334
352
|
role = "model"
|
|
335
353
|
raw_function_call = message.additional_kwargs.get("function_call")
|
|
336
354
|
if raw_function_call:
|
|
337
|
-
function_call =
|
|
355
|
+
function_call = FunctionCall(
|
|
338
356
|
{
|
|
339
357
|
"name": raw_function_call["name"],
|
|
340
358
|
"args": json.loads(raw_function_call["arguments"]),
|
|
341
359
|
}
|
|
342
360
|
)
|
|
343
|
-
parts = [
|
|
361
|
+
parts = [Part(function_call=function_call)]
|
|
344
362
|
else:
|
|
345
363
|
parts = _convert_to_parts(message.content)
|
|
346
364
|
elif isinstance(message, HumanMessage):
|
|
347
365
|
role = "user"
|
|
348
366
|
parts = _convert_to_parts(message.content)
|
|
367
|
+
if i == 1 and convert_system_message_to_human and system_instruction:
|
|
368
|
+
parts = [p for p in system_instruction.parts] + parts
|
|
369
|
+
system_instruction = None
|
|
349
370
|
elif isinstance(message, FunctionMessage):
|
|
350
371
|
role = "user"
|
|
351
372
|
response: Any
|
|
@@ -357,8 +378,8 @@ def _parse_chat_history(
|
|
|
357
378
|
except json.JSONDecodeError:
|
|
358
379
|
response = message.content # leave as str representation
|
|
359
380
|
parts = [
|
|
360
|
-
|
|
361
|
-
function_response=
|
|
381
|
+
Part(
|
|
382
|
+
function_response=FunctionResponse(
|
|
362
383
|
name=message.name,
|
|
363
384
|
response=(
|
|
364
385
|
{"output": response}
|
|
@@ -391,8 +412,8 @@ def _parse_chat_history(
|
|
|
391
412
|
except json.JSONDecodeError:
|
|
392
413
|
tool_response = message.content # leave as str representation
|
|
393
414
|
parts = [
|
|
394
|
-
|
|
395
|
-
function_response=
|
|
415
|
+
Part(
|
|
416
|
+
function_response=FunctionResponse(
|
|
396
417
|
name=name,
|
|
397
418
|
response=(
|
|
398
419
|
{"output": tool_response}
|
|
@@ -407,12 +428,12 @@ def _parse_chat_history(
|
|
|
407
428
|
f"Unexpected message with type {type(message)} at the position {i}."
|
|
408
429
|
)
|
|
409
430
|
|
|
410
|
-
messages.append(
|
|
431
|
+
messages.append(Content(role=role, parts=parts))
|
|
411
432
|
return system_instruction, messages
|
|
412
433
|
|
|
413
434
|
|
|
414
435
|
def _parse_response_candidate(
|
|
415
|
-
response_candidate:
|
|
436
|
+
response_candidate: Candidate, streaming: bool = False
|
|
416
437
|
) -> AIMessage:
|
|
417
438
|
content: Union[None, str, List[str]] = None
|
|
418
439
|
additional_kwargs = {}
|
|
@@ -499,7 +520,7 @@ def _parse_response_candidate(
|
|
|
499
520
|
|
|
500
521
|
|
|
501
522
|
def _response_to_result(
|
|
502
|
-
response:
|
|
523
|
+
response: GenerateContentResponse,
|
|
503
524
|
stream: bool = False,
|
|
504
525
|
) -> ChatResult:
|
|
505
526
|
"""Converts a PaLM API response into a LangChain ChatResult."""
|
|
@@ -538,6 +559,14 @@ def _response_to_result(
|
|
|
538
559
|
return ChatResult(generations=generations, llm_output=llm_output)
|
|
539
560
|
|
|
540
561
|
|
|
562
|
+
def _is_event_loop_running() -> bool:
|
|
563
|
+
try:
|
|
564
|
+
asyncio.get_running_loop()
|
|
565
|
+
return True
|
|
566
|
+
except RuntimeError:
|
|
567
|
+
return False
|
|
568
|
+
|
|
569
|
+
|
|
541
570
|
class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
542
571
|
"""`Google Generative AI` Chat models API.
|
|
543
572
|
|
|
@@ -557,6 +586,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
557
586
|
"""
|
|
558
587
|
|
|
559
588
|
client: Any #: :meta private:
|
|
589
|
+
async_client: Any #: :meta private:
|
|
590
|
+
default_metadata: Sequence[Tuple[str, str]] = Field(
|
|
591
|
+
default_factory=list
|
|
592
|
+
) #: :meta private:
|
|
560
593
|
|
|
561
594
|
convert_system_message_to_human: bool = False
|
|
562
595
|
"""Whether to merge any leading SystemMessage into the following HumanMessage.
|
|
@@ -582,29 +615,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
582
615
|
@root_validator()
|
|
583
616
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
584
617
|
"""Validates params and passes them to google-generativeai package."""
|
|
585
|
-
additional_headers = values.get("additional_headers") or {}
|
|
586
|
-
default_metadata = tuple(additional_headers.items())
|
|
587
|
-
|
|
588
|
-
if values.get("credentials"):
|
|
589
|
-
genai.configure(
|
|
590
|
-
credentials=values.get("credentials"),
|
|
591
|
-
transport=values.get("transport"),
|
|
592
|
-
client_options=values.get("client_options"),
|
|
593
|
-
default_metadata=default_metadata,
|
|
594
|
-
)
|
|
595
|
-
else:
|
|
596
|
-
google_api_key = get_from_dict_or_env(
|
|
597
|
-
values, "google_api_key", "GOOGLE_API_KEY"
|
|
598
|
-
)
|
|
599
|
-
if isinstance(google_api_key, SecretStr):
|
|
600
|
-
google_api_key = google_api_key.get_secret_value()
|
|
601
|
-
|
|
602
|
-
genai.configure(
|
|
603
|
-
api_key=google_api_key,
|
|
604
|
-
transport=values.get("transport"),
|
|
605
|
-
client_options=values.get("client_options"),
|
|
606
|
-
default_metadata=default_metadata,
|
|
607
|
-
)
|
|
608
618
|
if (
|
|
609
619
|
values.get("temperature") is not None
|
|
610
620
|
and not 0 <= values["temperature"] <= 1
|
|
@@ -616,8 +626,45 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
616
626
|
|
|
617
627
|
if values.get("top_k") is not None and values["top_k"] <= 0:
|
|
618
628
|
raise ValueError("top_k must be positive")
|
|
619
|
-
|
|
620
|
-
values["
|
|
629
|
+
|
|
630
|
+
if not values["model"].startswith("models/"):
|
|
631
|
+
values["model"] = f"models/{values['model']}"
|
|
632
|
+
|
|
633
|
+
additional_headers = values.get("additional_headers") or {}
|
|
634
|
+
values["default_metadata"] = tuple(additional_headers.items())
|
|
635
|
+
client_info = get_client_info("ChatGoogleGenerativeAI")
|
|
636
|
+
google_api_key = None
|
|
637
|
+
if not values.get("credentials"):
|
|
638
|
+
google_api_key = get_from_dict_or_env(
|
|
639
|
+
values, "google_api_key", "GOOGLE_API_KEY"
|
|
640
|
+
)
|
|
641
|
+
if isinstance(google_api_key, SecretStr):
|
|
642
|
+
google_api_key = google_api_key.get_secret_value()
|
|
643
|
+
transport: Optional[str] = values.get("transport")
|
|
644
|
+
values["client"] = genaix.build_generative_service(
|
|
645
|
+
credentials=values.get("credentials"),
|
|
646
|
+
api_key=google_api_key,
|
|
647
|
+
client_info=client_info,
|
|
648
|
+
client_options=values.get("client_options"),
|
|
649
|
+
transport=transport,
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
# NOTE: genaix.build_generative_async_service requires
|
|
653
|
+
# a running event loop, which causes an error
|
|
654
|
+
# when initialized inside a ThreadPoolExecutor.
|
|
655
|
+
# this check ensures that async client is only initialized
|
|
656
|
+
# within an asyncio event loop to avoid the error
|
|
657
|
+
if _is_event_loop_running():
|
|
658
|
+
values["async_client"] = genaix.build_generative_async_service(
|
|
659
|
+
credentials=values.get("credentials"),
|
|
660
|
+
api_key=google_api_key,
|
|
661
|
+
client_info=client_info,
|
|
662
|
+
client_options=values.get("client_options"),
|
|
663
|
+
transport=transport,
|
|
664
|
+
)
|
|
665
|
+
else:
|
|
666
|
+
values["async_client"] = None
|
|
667
|
+
|
|
621
668
|
return values
|
|
622
669
|
|
|
623
670
|
@property
|
|
@@ -632,8 +679,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
632
679
|
}
|
|
633
680
|
|
|
634
681
|
def _prepare_params(
|
|
635
|
-
self,
|
|
636
|
-
|
|
682
|
+
self,
|
|
683
|
+
stop: Optional[List[str]],
|
|
684
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
685
|
+
) -> GenerationConfig:
|
|
637
686
|
gen_config = {
|
|
638
687
|
k: v
|
|
639
688
|
for k, v in {
|
|
@@ -646,27 +695,37 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
646
695
|
}.items()
|
|
647
696
|
if v is not None
|
|
648
697
|
}
|
|
649
|
-
if
|
|
650
|
-
gen_config = {**gen_config, **
|
|
651
|
-
|
|
652
|
-
return params
|
|
698
|
+
if generation_config:
|
|
699
|
+
gen_config = {**gen_config, **generation_config}
|
|
700
|
+
return GenerationConfig(**gen_config)
|
|
653
701
|
|
|
654
702
|
def _generate(
|
|
655
703
|
self,
|
|
656
704
|
messages: List[BaseMessage],
|
|
657
705
|
stop: Optional[List[str]] = None,
|
|
658
706
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
707
|
+
*,
|
|
708
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
709
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
710
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
711
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
712
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
659
713
|
**kwargs: Any,
|
|
660
714
|
) -> ChatResult:
|
|
661
|
-
|
|
715
|
+
request = self._prepare_request(
|
|
662
716
|
messages,
|
|
663
717
|
stop=stop,
|
|
664
|
-
|
|
718
|
+
tools=tools,
|
|
719
|
+
functions=functions,
|
|
720
|
+
safety_settings=safety_settings,
|
|
721
|
+
tool_config=tool_config,
|
|
722
|
+
generation_config=generation_config,
|
|
665
723
|
)
|
|
666
|
-
response:
|
|
667
|
-
|
|
668
|
-
**
|
|
669
|
-
generation_method=
|
|
724
|
+
response: GenerateContentResponse = _chat_with_retry(
|
|
725
|
+
request=request,
|
|
726
|
+
**kwargs,
|
|
727
|
+
generation_method=self.client.generate_content,
|
|
728
|
+
metadata=self.default_metadata,
|
|
670
729
|
)
|
|
671
730
|
return _response_to_result(response)
|
|
672
731
|
|
|
@@ -675,17 +734,34 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
675
734
|
messages: List[BaseMessage],
|
|
676
735
|
stop: Optional[List[str]] = None,
|
|
677
736
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
737
|
+
*,
|
|
738
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
739
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
740
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
741
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
742
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
678
743
|
**kwargs: Any,
|
|
679
744
|
) -> ChatResult:
|
|
680
|
-
|
|
745
|
+
if not self.async_client:
|
|
746
|
+
raise RuntimeError(
|
|
747
|
+
"Initialize ChatGoogleGenerativeAI with a running event loop "
|
|
748
|
+
"to use async methods."
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
request = self._prepare_request(
|
|
681
752
|
messages,
|
|
682
753
|
stop=stop,
|
|
683
|
-
|
|
754
|
+
tools=tools,
|
|
755
|
+
functions=functions,
|
|
756
|
+
safety_settings=safety_settings,
|
|
757
|
+
tool_config=tool_config,
|
|
758
|
+
generation_config=generation_config,
|
|
684
759
|
)
|
|
685
|
-
response:
|
|
686
|
-
|
|
687
|
-
**
|
|
688
|
-
generation_method=
|
|
760
|
+
response: GenerateContentResponse = await _achat_with_retry(
|
|
761
|
+
request=request,
|
|
762
|
+
**kwargs,
|
|
763
|
+
generation_method=self.async_client.generate_content,
|
|
764
|
+
metadata=self.default_metadata,
|
|
689
765
|
)
|
|
690
766
|
return _response_to_result(response)
|
|
691
767
|
|
|
@@ -694,18 +770,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
694
770
|
messages: List[BaseMessage],
|
|
695
771
|
stop: Optional[List[str]] = None,
|
|
696
772
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
773
|
+
*,
|
|
774
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
775
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
776
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
777
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
778
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
697
779
|
**kwargs: Any,
|
|
698
780
|
) -> Iterator[ChatGenerationChunk]:
|
|
699
|
-
|
|
781
|
+
request = self._prepare_request(
|
|
700
782
|
messages,
|
|
701
783
|
stop=stop,
|
|
702
|
-
|
|
784
|
+
tools=tools,
|
|
785
|
+
functions=functions,
|
|
786
|
+
safety_settings=safety_settings,
|
|
787
|
+
tool_config=tool_config,
|
|
788
|
+
generation_config=generation_config,
|
|
703
789
|
)
|
|
704
|
-
response:
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
790
|
+
response: GenerateContentResponse = _chat_with_retry(
|
|
791
|
+
request=request,
|
|
792
|
+
generation_method=self.client.stream_generate_content,
|
|
793
|
+
**kwargs,
|
|
794
|
+
metadata=self.default_metadata,
|
|
709
795
|
)
|
|
710
796
|
for chunk in response:
|
|
711
797
|
_chat_result = _response_to_result(chunk, stream=True)
|
|
@@ -720,18 +806,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
720
806
|
messages: List[BaseMessage],
|
|
721
807
|
stop: Optional[List[str]] = None,
|
|
722
808
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
809
|
+
*,
|
|
810
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
811
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
812
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
813
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
814
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
723
815
|
**kwargs: Any,
|
|
724
816
|
) -> AsyncIterator[ChatGenerationChunk]:
|
|
725
|
-
|
|
817
|
+
request = self._prepare_request(
|
|
726
818
|
messages,
|
|
727
819
|
stop=stop,
|
|
728
|
-
|
|
820
|
+
tools=tools,
|
|
821
|
+
functions=functions,
|
|
822
|
+
safety_settings=safety_settings,
|
|
823
|
+
tool_config=tool_config,
|
|
824
|
+
generation_config=generation_config,
|
|
729
825
|
)
|
|
730
826
|
async for chunk in await _achat_with_retry(
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
827
|
+
request=request,
|
|
828
|
+
generation_method=self.async_client.stream_generate_content,
|
|
829
|
+
**kwargs,
|
|
830
|
+
metadata=self.default_metadata,
|
|
735
831
|
):
|
|
736
832
|
_chat_result = _response_to_result(chunk, stream=True)
|
|
737
833
|
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
|
@@ -740,17 +836,17 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
740
836
|
await run_manager.on_llm_new_token(gen.text)
|
|
741
837
|
yield gen
|
|
742
838
|
|
|
743
|
-
def
|
|
839
|
+
def _prepare_request(
|
|
744
840
|
self,
|
|
745
841
|
messages: List[BaseMessage],
|
|
842
|
+
*,
|
|
746
843
|
stop: Optional[List[str]] = None,
|
|
747
844
|
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
748
845
|
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
749
846
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
750
847
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
751
|
-
|
|
752
|
-
) -> Tuple[Dict[str, Any]
|
|
753
|
-
client = self.client
|
|
848
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
849
|
+
) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
|
|
754
850
|
formatted_tools = None
|
|
755
851
|
if tools:
|
|
756
852
|
formatted_tools = [
|
|
@@ -759,25 +855,35 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
759
855
|
elif functions:
|
|
760
856
|
formatted_tools = [convert_to_genai_function_declarations(functions)]
|
|
761
857
|
|
|
762
|
-
if formatted_tools or safety_settings:
|
|
763
|
-
client = genai.GenerativeModel(
|
|
764
|
-
model_name=self.model,
|
|
765
|
-
tools=formatted_tools,
|
|
766
|
-
safety_settings=safety_settings,
|
|
767
|
-
)
|
|
768
|
-
|
|
769
|
-
params = self._prepare_params(stop, tool_config=tool_config, **kwargs)
|
|
770
858
|
system_instruction, history = _parse_chat_history(
|
|
771
859
|
messages,
|
|
772
860
|
convert_system_message_to_human=self.convert_system_message_to_human,
|
|
773
861
|
)
|
|
774
|
-
|
|
775
|
-
if
|
|
776
|
-
|
|
777
|
-
|
|
862
|
+
formatted_tool_config = None
|
|
863
|
+
if tool_config:
|
|
864
|
+
formatted_tool_config = ToolConfig(
|
|
865
|
+
function_calling_config=tool_config["function_calling_config"]
|
|
778
866
|
)
|
|
779
|
-
|
|
780
|
-
|
|
867
|
+
formatted_safety_settings = []
|
|
868
|
+
if safety_settings:
|
|
869
|
+
formatted_safety_settings = [
|
|
870
|
+
SafetySetting(category=c, threshold=t)
|
|
871
|
+
for c, t in safety_settings.items()
|
|
872
|
+
]
|
|
873
|
+
request = GenerateContentRequest(
|
|
874
|
+
model=self.model,
|
|
875
|
+
contents=history,
|
|
876
|
+
tools=formatted_tools,
|
|
877
|
+
tool_config=formatted_tool_config,
|
|
878
|
+
safety_settings=formatted_safety_settings,
|
|
879
|
+
generation_config=self._prepare_params(
|
|
880
|
+
stop, generation_config=generation_config
|
|
881
|
+
),
|
|
882
|
+
)
|
|
883
|
+
if system_instruction:
|
|
884
|
+
request.system_instruction = system_instruction
|
|
885
|
+
|
|
886
|
+
return request
|
|
781
887
|
|
|
782
888
|
def get_num_tokens(self, text: str) -> int:
|
|
783
889
|
"""Get the number of tokens present in the text.
|
|
@@ -790,14 +896,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
790
896
|
Returns:
|
|
791
897
|
The integer number of tokens in the text.
|
|
792
898
|
"""
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
|
798
|
-
token_count = result["token_count"]
|
|
799
|
-
|
|
800
|
-
return token_count
|
|
899
|
+
result = self.client.count_tokens(
|
|
900
|
+
model=self.model, contents=[Content(parts=[Part(text=text)])]
|
|
901
|
+
)
|
|
902
|
+
return result.total_tokens
|
|
801
903
|
|
|
802
904
|
def bind_tools(
|
|
803
905
|
self,
|
|
@@ -828,7 +930,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
828
930
|
genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
|
|
829
931
|
if tool_choice:
|
|
830
932
|
all_names = [
|
|
831
|
-
f["name"]
|
|
933
|
+
f["name"] # type: ignore[index]
|
|
934
|
+
for t in genai_tools
|
|
935
|
+
for f in t["function_declarations"]
|
|
832
936
|
]
|
|
833
937
|
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
|
|
834
938
|
return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
|