langchain-google-genai 2.0.10__tar.gz → 2.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai might be problematic. Click here for more details.
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/PKG-INFO +3 -3
- langchain_google_genai-2.1.0/langchain_google_genai/_common.py +145 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/_function_utils.py +9 -26
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/chat_models.py +72 -116
- langchain_google_genai-2.1.0/langchain_google_genai/llms.py +134 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/pyproject.toml +4 -4
- langchain_google_genai-2.0.10/langchain_google_genai/_common.py +0 -52
- langchain_google_genai-2.0.10/langchain_google_genai/llms.py +0 -406
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/LICENSE +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/README.md +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/__init__.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/_enums.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/_genai_extension.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/_image_utils.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/embeddings.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/genai_aqa.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/google_vector_store.py +0 -0
- {langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/py.typed +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: langchain-google-genai
|
|
3
|
-
Version: 2.0.10
|
|
3
|
+
Version: 2.1.0
|
|
4
4
|
Summary: An integration package connecting Google's genai package and LangChain
|
|
5
5
|
Home-page: https://github.com/langchain-ai/langchain-google
|
|
6
6
|
License: MIT
|
|
@@ -12,8 +12,8 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Requires-Dist: filetype (>=1.2.0,<2.0.0)
|
|
15
|
-
Requires-Dist: google-
|
|
16
|
-
Requires-Dist: langchain-core (>=0.3.
|
|
15
|
+
Requires-Dist: google-ai-generativelanguage (>=0.6.16,<0.7.0)
|
|
16
|
+
Requires-Dist: langchain-core (>=0.3.43,<0.4.0)
|
|
17
17
|
Requires-Dist: pydantic (>=2,<3)
|
|
18
18
|
Project-URL: Repository, https://github.com/langchain-ai/langchain-google
|
|
19
19
|
Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from importlib import metadata
|
|
2
|
+
from typing import Any, Dict, Optional, Tuple, TypedDict
|
|
3
|
+
|
|
4
|
+
from google.api_core.gapic_v1.client_info import ClientInfo
|
|
5
|
+
from langchain_core.utils import secret_from_env
|
|
6
|
+
from pydantic import BaseModel, Field, SecretStr
|
|
7
|
+
|
|
8
|
+
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GoogleGenerativeAIError(Exception):
|
|
12
|
+
"""
|
|
13
|
+
Custom exception class for errors associated with the `Google GenAI` API.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _BaseGoogleGenerativeAI(BaseModel):
|
|
18
|
+
"""Base class for Google Generative AI LLMs"""
|
|
19
|
+
|
|
20
|
+
model: str = Field(
|
|
21
|
+
...,
|
|
22
|
+
description="""The name of the model to use.
|
|
23
|
+
Supported examples:
|
|
24
|
+
- gemini-pro
|
|
25
|
+
- models/text-bison-001""",
|
|
26
|
+
)
|
|
27
|
+
"""Model name to use."""
|
|
28
|
+
google_api_key: Optional[SecretStr] = Field(
|
|
29
|
+
alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
|
|
30
|
+
)
|
|
31
|
+
"""Google AI API key.
|
|
32
|
+
If not specified will be read from env var ``GOOGLE_API_KEY``."""
|
|
33
|
+
credentials: Any = None
|
|
34
|
+
"The default custom credentials (google.auth.credentials.Credentials) to use "
|
|
35
|
+
"when making API calls. If not provided, credentials will be ascertained from "
|
|
36
|
+
"the GOOGLE_API_KEY envvar"
|
|
37
|
+
temperature: float = 0.7
|
|
38
|
+
"""Run inference with this temperature. Must by in the closed interval
|
|
39
|
+
[0.0, 1.0]."""
|
|
40
|
+
top_p: Optional[float] = None
|
|
41
|
+
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
|
42
|
+
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
|
43
|
+
top_k: Optional[int] = None
|
|
44
|
+
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
|
45
|
+
Must be positive."""
|
|
46
|
+
max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
|
|
47
|
+
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
|
48
|
+
If unset, will default to 64."""
|
|
49
|
+
n: int = 1
|
|
50
|
+
"""Number of chat completions to generate for each prompt. Note that the API may
|
|
51
|
+
not return the full n completions if duplicates are generated."""
|
|
52
|
+
max_retries: int = 6
|
|
53
|
+
"""The maximum number of retries to make when generating."""
|
|
54
|
+
|
|
55
|
+
timeout: Optional[float] = None
|
|
56
|
+
"""The maximum number of seconds to wait for a response."""
|
|
57
|
+
|
|
58
|
+
client_options: Optional[Dict] = Field(
|
|
59
|
+
default=None,
|
|
60
|
+
description=(
|
|
61
|
+
"A dictionary of client options to pass to the Google API client, "
|
|
62
|
+
"such as `api_endpoint`."
|
|
63
|
+
),
|
|
64
|
+
)
|
|
65
|
+
transport: Optional[str] = Field(
|
|
66
|
+
default=None,
|
|
67
|
+
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
|
|
68
|
+
)
|
|
69
|
+
additional_headers: Optional[Dict[str, str]] = Field(
|
|
70
|
+
default=None,
|
|
71
|
+
description=(
|
|
72
|
+
"A key-value dictionary representing additional headers for the model call"
|
|
73
|
+
),
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
|
|
77
|
+
"""The default safety settings to use for all generations.
|
|
78
|
+
|
|
79
|
+
For example:
|
|
80
|
+
|
|
81
|
+
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
|
|
82
|
+
|
|
83
|
+
safety_settings = {
|
|
84
|
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
|
85
|
+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
86
|
+
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
|
87
|
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
88
|
+
}
|
|
89
|
+
""" # noqa: E501
|
|
90
|
+
|
|
91
|
+
@property
|
|
92
|
+
def lc_secrets(self) -> Dict[str, str]:
|
|
93
|
+
return {"google_api_key": "GOOGLE_API_KEY"}
|
|
94
|
+
|
|
95
|
+
@property
|
|
96
|
+
def _identifying_params(self) -> Dict[str, Any]:
|
|
97
|
+
"""Get the identifying parameters."""
|
|
98
|
+
return {
|
|
99
|
+
"model": self.model,
|
|
100
|
+
"temperature": self.temperature,
|
|
101
|
+
"top_p": self.top_p,
|
|
102
|
+
"top_k": self.top_k,
|
|
103
|
+
"max_output_tokens": self.max_output_tokens,
|
|
104
|
+
"candidate_count": self.n,
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
|
|
109
|
+
r"""Returns a custom user agent header.
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
module (Optional[str]):
|
|
113
|
+
Optional. The module for a custom user agent header.
|
|
114
|
+
Returns:
|
|
115
|
+
Tuple[str, str]
|
|
116
|
+
"""
|
|
117
|
+
try:
|
|
118
|
+
langchain_version = metadata.version("langchain-google-genai")
|
|
119
|
+
except metadata.PackageNotFoundError:
|
|
120
|
+
langchain_version = "0.0.0"
|
|
121
|
+
client_library_version = (
|
|
122
|
+
f"{langchain_version}-{module}" if module else langchain_version
|
|
123
|
+
)
|
|
124
|
+
return client_library_version, f"langchain-google-genai/{client_library_version}"
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def get_client_info(module: Optional[str] = None) -> "ClientInfo":
|
|
128
|
+
r"""Returns a client info object with a custom user agent header.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
module (Optional[str]):
|
|
132
|
+
Optional. The module for a custom user agent header.
|
|
133
|
+
Returns:
|
|
134
|
+
google.api_core.gapic_v1.client_info.ClientInfo
|
|
135
|
+
"""
|
|
136
|
+
client_library_version, user_agent = get_user_agent(module)
|
|
137
|
+
return ClientInfo(
|
|
138
|
+
client_library_version=client_library_version,
|
|
139
|
+
user_agent=user_agent,
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class SafetySettingDict(TypedDict):
|
|
144
|
+
category: HarmCategory
|
|
145
|
+
threshold: HarmBlockThreshold
|
|
@@ -7,7 +7,6 @@ import logging
|
|
|
7
7
|
from typing import (
|
|
8
8
|
Any,
|
|
9
9
|
Callable,
|
|
10
|
-
Collection,
|
|
11
10
|
Dict,
|
|
12
11
|
List,
|
|
13
12
|
Literal,
|
|
@@ -22,7 +21,6 @@ from typing import (
|
|
|
22
21
|
import google.ai.generativelanguage as glm
|
|
23
22
|
import google.ai.generativelanguage_v1beta.types as gapic
|
|
24
23
|
import proto # type: ignore[import]
|
|
25
|
-
from google.generativeai.types.content_types import ToolDict # type: ignore[import]
|
|
26
24
|
from langchain_core.tools import BaseTool
|
|
27
25
|
from langchain_core.tools import tool as callable_as_lc_tool
|
|
28
26
|
from langchain_core.utils.function_calling import (
|
|
@@ -59,38 +57,21 @@ _ALLOWED_SCHEMA_FIELDS.extend(
|
|
|
59
57
|
_ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)
|
|
60
58
|
|
|
61
59
|
|
|
62
|
-
class _ToolDictLike(TypedDict):
|
|
63
|
-
function_declarations: _FunctionDeclarationLikeList
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class _FunctionDeclarationDict(TypedDict):
|
|
67
|
-
name: str
|
|
68
|
-
description: str
|
|
69
|
-
parameters: Dict[str, Collection[str]]
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class _ToolDict(TypedDict):
|
|
73
|
-
function_declarations: Sequence[_FunctionDeclarationDict]
|
|
74
|
-
|
|
75
|
-
|
|
76
60
|
# Info: This is a FunctionDeclaration(=fc).
|
|
77
61
|
_FunctionDeclarationLike = Union[
|
|
78
62
|
BaseTool, Type[BaseModel], gapic.FunctionDeclaration, Callable, Dict[str, Any]
|
|
79
63
|
]
|
|
80
64
|
|
|
81
|
-
|
|
82
|
-
|
|
65
|
+
|
|
66
|
+
class _ToolDict(TypedDict):
|
|
67
|
+
function_declarations: Sequence[_FunctionDeclarationLike]
|
|
83
68
|
|
|
84
69
|
|
|
85
70
|
# Info: This means one tool=Sequence of FunctionDeclaration
|
|
86
71
|
# The dict should be gapic.Tool like. {"function_declarations": [ { "name": ...}.
|
|
87
72
|
# OpenAI like dict is not be accepted. {{'type': 'function', 'function': {'name': ...}
|
|
88
73
|
_ToolsType = Union[
|
|
89
|
-
gapic.Tool,
|
|
90
|
-
ToolDict,
|
|
91
|
-
_ToolDictLike,
|
|
92
|
-
_FunctionDeclarationLikeList,
|
|
93
|
-
_FunctionDeclarationLike,
|
|
74
|
+
gapic.Tool, _ToolDict, _FunctionDeclarationLike, Sequence[_FunctionDeclarationLike]
|
|
94
75
|
]
|
|
95
76
|
|
|
96
77
|
|
|
@@ -152,12 +133,12 @@ def convert_to_genai_function_declarations(
|
|
|
152
133
|
gapic_tool = gapic.Tool()
|
|
153
134
|
for tool in tools:
|
|
154
135
|
if isinstance(tool, gapic.Tool):
|
|
155
|
-
gapic_tool.function_declarations.extend(tool.function_declarations)
|
|
136
|
+
gapic_tool.function_declarations.extend(tool.function_declarations) # type: ignore[union-attr]
|
|
156
137
|
elif isinstance(tool, dict) and "function_declarations" not in tool:
|
|
157
138
|
fd = _format_to_gapic_function_declaration(tool)
|
|
158
139
|
gapic_tool.function_declarations.append(fd)
|
|
159
140
|
elif isinstance(tool, dict):
|
|
160
|
-
function_declarations = cast(
|
|
141
|
+
function_declarations = cast(_ToolDict, tool)["function_declarations"]
|
|
161
142
|
if not isinstance(function_declarations, collections.abc.Sequence):
|
|
162
143
|
raise ValueError(
|
|
163
144
|
"function_declarations should be a list"
|
|
@@ -170,7 +151,7 @@ def convert_to_genai_function_declarations(
|
|
|
170
151
|
]
|
|
171
152
|
gapic_tool.function_declarations.extend(fds)
|
|
172
153
|
else:
|
|
173
|
-
fd = _format_to_gapic_function_declaration(tool)
|
|
154
|
+
fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type]
|
|
174
155
|
gapic_tool.function_declarations.append(fd)
|
|
175
156
|
return gapic_tool
|
|
176
157
|
|
|
@@ -375,6 +356,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
|
|
|
375
356
|
)
|
|
376
357
|
if _is_nullable_schema(schema):
|
|
377
358
|
items["nullable"] = True
|
|
359
|
+
if "required" in schema:
|
|
360
|
+
items["required"] = schema["required"]
|
|
378
361
|
else:
|
|
379
362
|
# str
|
|
380
363
|
items["type_"] = _get_type_from_schema({"type": schema})
|
{langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/chat_models.py
RENAMED
|
@@ -35,6 +35,7 @@ from google.ai.generativelanguage_v1beta.types import (
|
|
|
35
35
|
Content,
|
|
36
36
|
FileData,
|
|
37
37
|
FunctionCall,
|
|
38
|
+
FunctionDeclaration,
|
|
38
39
|
FunctionResponse,
|
|
39
40
|
GenerateContentRequest,
|
|
40
41
|
GenerateContentResponse,
|
|
@@ -44,12 +45,8 @@ from google.ai.generativelanguage_v1beta.types import (
|
|
|
44
45
|
ToolConfig,
|
|
45
46
|
VideoMetadata,
|
|
46
47
|
)
|
|
47
|
-
from google.
|
|
48
|
-
|
|
49
|
-
from google.generativeai.types import caching_types, content_types
|
|
50
|
-
from google.generativeai.types.content_types import ( # type: ignore[import]
|
|
51
|
-
FunctionDeclarationType,
|
|
52
|
-
ToolDict,
|
|
48
|
+
from google.ai.generativelanguage_v1beta.types import (
|
|
49
|
+
Tool as GoogleTool,
|
|
53
50
|
)
|
|
54
51
|
from langchain_core.callbacks.manager import (
|
|
55
52
|
AsyncCallbackManagerForLLMRun,
|
|
@@ -70,13 +67,13 @@ from langchain_core.messages.ai import UsageMetadata
|
|
|
70
67
|
from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
|
|
71
68
|
from langchain_core.output_parsers.base import OutputParserLike
|
|
72
69
|
from langchain_core.output_parsers.openai_tools import (
|
|
73
|
-
|
|
70
|
+
JsonOutputKeyToolsParser,
|
|
74
71
|
PydanticToolsParser,
|
|
75
72
|
parse_tool_calls,
|
|
76
73
|
)
|
|
77
74
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
78
75
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
|
79
|
-
from langchain_core.
|
|
76
|
+
from langchain_core.tools import BaseTool
|
|
80
77
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
|
81
78
|
from pydantic import (
|
|
82
79
|
BaseModel,
|
|
@@ -92,29 +89,40 @@ from tenacity import (
|
|
|
92
89
|
stop_after_attempt,
|
|
93
90
|
wait_exponential,
|
|
94
91
|
)
|
|
95
|
-
from typing_extensions import Self
|
|
92
|
+
from typing_extensions import Self, is_typeddict
|
|
96
93
|
|
|
97
94
|
from langchain_google_genai._common import (
|
|
98
95
|
GoogleGenerativeAIError,
|
|
99
96
|
SafetySettingDict,
|
|
97
|
+
_BaseGoogleGenerativeAI,
|
|
100
98
|
get_client_info,
|
|
101
99
|
)
|
|
102
100
|
from langchain_google_genai._function_utils import (
|
|
103
101
|
_tool_choice_to_tool_config,
|
|
104
102
|
_ToolChoiceType,
|
|
105
103
|
_ToolConfigDict,
|
|
104
|
+
_ToolDict,
|
|
106
105
|
convert_to_genai_function_declarations,
|
|
107
106
|
is_basemodel_subclass_safe,
|
|
108
107
|
tool_to_dict,
|
|
109
108
|
)
|
|
110
109
|
from langchain_google_genai._image_utils import ImageBytesLoader
|
|
111
|
-
from langchain_google_genai.llms import _BaseGoogleGenerativeAI
|
|
112
110
|
|
|
113
111
|
from . import _genai_extension as genaix
|
|
114
112
|
|
|
113
|
+
WARNED_STRUCTURED_OUTPUT_JSON_MODE = False
|
|
114
|
+
|
|
115
|
+
|
|
115
116
|
logger = logging.getLogger(__name__)
|
|
116
117
|
|
|
117
118
|
|
|
119
|
+
_FunctionDeclarationType = Union[
|
|
120
|
+
FunctionDeclaration,
|
|
121
|
+
dict[str, Any],
|
|
122
|
+
Callable[..., Any],
|
|
123
|
+
]
|
|
124
|
+
|
|
125
|
+
|
|
118
126
|
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
|
|
119
127
|
"""
|
|
120
128
|
Custom exception class for errors associated with the `Google GenAI` API.
|
|
@@ -786,11 +794,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
786
794
|
|
|
787
795
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
|
788
796
|
async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
|
|
789
|
-
google_api_key: Optional[SecretStr] = Field(
|
|
790
|
-
alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
|
|
791
|
-
)
|
|
792
|
-
"""Google AI API key.
|
|
793
|
-
If not specified will be read from env var ``GOOGLE_API_KEY``."""
|
|
794
797
|
default_metadata: Sequence[Tuple[str, str]] = Field(
|
|
795
798
|
default_factory=list
|
|
796
799
|
) #: :meta private:
|
|
@@ -938,8 +941,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
938
941
|
stop: Optional[List[str]] = None,
|
|
939
942
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
940
943
|
*,
|
|
941
|
-
tools: Optional[Sequence[Union[
|
|
942
|
-
functions: Optional[Sequence[
|
|
944
|
+
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
|
945
|
+
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
|
943
946
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
944
947
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
945
948
|
generation_config: Optional[Dict[str, Any]] = None,
|
|
@@ -972,8 +975,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
972
975
|
stop: Optional[List[str]] = None,
|
|
973
976
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
974
977
|
*,
|
|
975
|
-
tools: Optional[Sequence[Union[
|
|
976
|
-
functions: Optional[Sequence[
|
|
978
|
+
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
|
979
|
+
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
|
977
980
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
978
981
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
979
982
|
generation_config: Optional[Dict[str, Any]] = None,
|
|
@@ -1021,8 +1024,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1021
1024
|
stop: Optional[List[str]] = None,
|
|
1022
1025
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
1023
1026
|
*,
|
|
1024
|
-
tools: Optional[Sequence[Union[
|
|
1025
|
-
functions: Optional[Sequence[
|
|
1027
|
+
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
|
1028
|
+
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
|
1026
1029
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
1027
1030
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
1028
1031
|
generation_config: Optional[Dict[str, Any]] = None,
|
|
@@ -1083,8 +1086,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1083
1086
|
stop: Optional[List[str]] = None,
|
|
1084
1087
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
1085
1088
|
*,
|
|
1086
|
-
tools: Optional[Sequence[Union[
|
|
1087
|
-
functions: Optional[Sequence[
|
|
1089
|
+
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
|
1090
|
+
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
|
1088
1091
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
1089
1092
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
1090
1093
|
generation_config: Optional[Dict[str, Any]] = None,
|
|
@@ -1158,8 +1161,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1158
1161
|
messages: List[BaseMessage],
|
|
1159
1162
|
*,
|
|
1160
1163
|
stop: Optional[List[str]] = None,
|
|
1161
|
-
tools: Optional[Sequence[Union[
|
|
1162
|
-
functions: Optional[Sequence[
|
|
1164
|
+
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
|
1165
|
+
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
|
1163
1166
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
1164
1167
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
1165
1168
|
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
|
@@ -1177,6 +1180,16 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1177
1180
|
elif functions:
|
|
1178
1181
|
formatted_tools = [convert_to_genai_function_declarations(functions)]
|
|
1179
1182
|
|
|
1183
|
+
filtered_messages = []
|
|
1184
|
+
for message in messages:
|
|
1185
|
+
if isinstance(message, HumanMessage) and not message.content:
|
|
1186
|
+
warnings.warn(
|
|
1187
|
+
"HumanMessage with empty content was removed to prevent API error"
|
|
1188
|
+
)
|
|
1189
|
+
else:
|
|
1190
|
+
filtered_messages.append(message)
|
|
1191
|
+
messages = filtered_messages
|
|
1192
|
+
|
|
1180
1193
|
system_instruction, history = _parse_chat_history(
|
|
1181
1194
|
messages,
|
|
1182
1195
|
convert_system_message_to_human=self.convert_system_message_to_human,
|
|
@@ -1245,14 +1258,33 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1245
1258
|
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
|
1246
1259
|
if kwargs:
|
|
1247
1260
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
|
1261
|
+
tool_name = _get_tool_name(schema) # type: ignore[arg-type]
|
|
1248
1262
|
if isinstance(schema, type) and is_basemodel_subclass_safe(schema):
|
|
1249
1263
|
parser: OutputParserLike = PydanticToolsParser(
|
|
1250
1264
|
tools=[schema], first_tool_only=True
|
|
1251
1265
|
)
|
|
1252
1266
|
else:
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1267
|
+
global WARNED_STRUCTURED_OUTPUT_JSON_MODE
|
|
1268
|
+
warnings.warn(
|
|
1269
|
+
"ChatGoogleGenerativeAI.with_structured_output with dict schema has "
|
|
1270
|
+
"changed recently to align with behavior of other LangChain chat "
|
|
1271
|
+
"models. More context: "
|
|
1272
|
+
"https://github.com/langchain-ai/langchain-google/pull/772"
|
|
1273
|
+
)
|
|
1274
|
+
WARNED_STRUCTURED_OUTPUT_JSON_MODE = True
|
|
1275
|
+
parser = JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
|
|
1276
|
+
tool_choice = tool_name if self._supports_tool_choice else None
|
|
1277
|
+
try:
|
|
1278
|
+
llm = self.bind_tools(
|
|
1279
|
+
[schema],
|
|
1280
|
+
tool_choice=tool_choice,
|
|
1281
|
+
ls_structured_output_format={
|
|
1282
|
+
"kwargs": {"method": "function_calling"},
|
|
1283
|
+
"schema": convert_to_openai_tool(schema),
|
|
1284
|
+
},
|
|
1285
|
+
)
|
|
1286
|
+
except Exception:
|
|
1287
|
+
llm = self.bind_tools([schema], tool_choice=tool_choice)
|
|
1256
1288
|
if include_raw:
|
|
1257
1289
|
parser_with_fallback = RunnablePassthrough.assign(
|
|
1258
1290
|
parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
|
|
@@ -1266,7 +1298,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1266
1298
|
|
|
1267
1299
|
def bind_tools(
|
|
1268
1300
|
self,
|
|
1269
|
-
tools: Sequence[
|
|
1301
|
+
tools: Sequence[
|
|
1302
|
+
dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool
|
|
1303
|
+
],
|
|
1270
1304
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
1271
1305
|
*,
|
|
1272
1306
|
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
|
@@ -1303,90 +1337,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1303
1337
|
pass
|
|
1304
1338
|
return self.bind(tools=formatted_tools, **kwargs)
|
|
1305
1339
|
|
|
1306
|
-
def create_cached_content(
|
|
1307
|
-
self,
|
|
1308
|
-
contents: Union[List[BaseMessage], content_types.ContentsType],
|
|
1309
|
-
*,
|
|
1310
|
-
display_name: str | None = None,
|
|
1311
|
-
tools: Union[ToolDict, GoogleTool, None] = None,
|
|
1312
|
-
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
|
1313
|
-
ttl: Optional[caching_types.TTLTypes] = None,
|
|
1314
|
-
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
|
|
1315
|
-
) -> str:
|
|
1316
|
-
"""
|
|
1317
|
-
|
|
1318
|
-
Args:
|
|
1319
|
-
display_name: The user-generated meaningful display name
|
|
1320
|
-
of the cached content. `display_name` must be no
|
|
1321
|
-
more than 128 unicode characters.
|
|
1322
|
-
contents: Contents to cache.
|
|
1323
|
-
tools: A list of `Tools` the model may use to generate response.
|
|
1324
|
-
tool_choice: Which tool to require the model to call.
|
|
1325
|
-
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
|
|
1326
|
-
`ttl` and `expire_time` are exclusive arguments.
|
|
1327
|
-
expire_time: Expiration time for cached resource.
|
|
1328
|
-
`ttl` and `expire_time` are exclusive arguments.
|
|
1329
|
-
"""
|
|
1330
|
-
system: Optional[content_types.ContentType] = None
|
|
1331
|
-
genai_contents: list = []
|
|
1332
|
-
if all(isinstance(c, BaseMessage) for c in contents):
|
|
1333
|
-
system, genai_contents = _parse_chat_history(
|
|
1334
|
-
contents,
|
|
1335
|
-
convert_system_message_to_human=self.convert_system_message_to_human,
|
|
1336
|
-
)
|
|
1337
|
-
elif any(isinstance(c, BaseMessage) for c in contents):
|
|
1338
|
-
raise ValueError(
|
|
1339
|
-
f"'contents' must either be a list of "
|
|
1340
|
-
f"langchain_core.messages.BaseMessage or a list "
|
|
1341
|
-
f"google.generativeai.types.content_types.ContentType, but not a mix "
|
|
1342
|
-
f"of the two. Received {contents}"
|
|
1343
|
-
)
|
|
1344
|
-
else:
|
|
1345
|
-
for content in contents:
|
|
1346
|
-
if hasattr(content, "role") and content.role == "system":
|
|
1347
|
-
if system is not None:
|
|
1348
|
-
warnings.warn(
|
|
1349
|
-
"Received multiple pieces of content with role 'system'. "
|
|
1350
|
-
"Should only be one set of system instructions. Ignoring "
|
|
1351
|
-
"all but the first 'system' content."
|
|
1352
|
-
)
|
|
1353
|
-
else:
|
|
1354
|
-
system = content
|
|
1355
|
-
elif isinstance(content, dict) and content.get("role") == "system":
|
|
1356
|
-
if system is not None:
|
|
1357
|
-
warnings.warn(
|
|
1358
|
-
"Received multiple pieces of content with role 'system'. "
|
|
1359
|
-
"Should only be one set of system instructions. Ignoring "
|
|
1360
|
-
"all but the first 'system' content."
|
|
1361
|
-
)
|
|
1362
|
-
else:
|
|
1363
|
-
system = content
|
|
1364
|
-
else:
|
|
1365
|
-
genai_contents.append(content)
|
|
1366
|
-
if tools:
|
|
1367
|
-
genai_tools = [convert_to_genai_function_declarations(tools)]
|
|
1368
|
-
else:
|
|
1369
|
-
genai_tools = None
|
|
1370
|
-
if tool_choice and genai_tools:
|
|
1371
|
-
all_names = [f.name for t in genai_tools for f in t.function_declarations]
|
|
1372
|
-
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
|
|
1373
|
-
genai_tool_config = ToolConfig(
|
|
1374
|
-
function_calling_config=tool_config["function_calling_config"]
|
|
1375
|
-
)
|
|
1376
|
-
else:
|
|
1377
|
-
genai_tool_config = None
|
|
1378
|
-
cached_content = CachedContent.create(
|
|
1379
|
-
model=self.model,
|
|
1380
|
-
system_instruction=system,
|
|
1381
|
-
contents=genai_contents,
|
|
1382
|
-
display_name=display_name,
|
|
1383
|
-
tools=genai_tools,
|
|
1384
|
-
tool_config=genai_tool_config,
|
|
1385
|
-
ttl=ttl,
|
|
1386
|
-
expire_time=expire_time,
|
|
1387
|
-
)
|
|
1388
|
-
return cached_content.name
|
|
1389
|
-
|
|
1390
1340
|
@property
|
|
1391
1341
|
def _supports_tool_choice(self) -> bool:
|
|
1392
1342
|
return (
|
|
@@ -1397,7 +1347,13 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
1397
1347
|
|
|
1398
1348
|
|
|
1399
1349
|
def _get_tool_name(
|
|
1400
|
-
tool: Union[
|
|
1350
|
+
tool: Union[_ToolDict, GoogleTool, Dict],
|
|
1401
1351
|
) -> str:
|
|
1402
|
-
|
|
1403
|
-
|
|
1352
|
+
try:
|
|
1353
|
+
genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
|
|
1354
|
+
return [f["name"] for f in genai_tool["function_declarations"]][0] # type: ignore[index]
|
|
1355
|
+
except ValueError as e: # other TypedDict
|
|
1356
|
+
if is_typeddict(tool):
|
|
1357
|
+
return convert_to_openai_tool(cast(Dict, tool))["function"]["name"]
|
|
1358
|
+
else:
|
|
1359
|
+
raise e
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Iterator, List, Optional
|
|
4
|
+
|
|
5
|
+
from langchain_core.callbacks import (
|
|
6
|
+
CallbackManagerForLLMRun,
|
|
7
|
+
)
|
|
8
|
+
from langchain_core.language_models import LangSmithParams
|
|
9
|
+
from langchain_core.language_models.llms import BaseLLM
|
|
10
|
+
from langchain_core.messages import HumanMessage
|
|
11
|
+
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
|
12
|
+
from pydantic import ConfigDict, model_validator
|
|
13
|
+
from typing_extensions import Self
|
|
14
|
+
|
|
15
|
+
from langchain_google_genai._common import (
|
|
16
|
+
_BaseGoogleGenerativeAI,
|
|
17
|
+
)
|
|
18
|
+
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
22
|
+
"""Google GenerativeAI models.
|
|
23
|
+
|
|
24
|
+
Example:
|
|
25
|
+
.. code-block:: python
|
|
26
|
+
|
|
27
|
+
from langchain_google_genai import GoogleGenerativeAI
|
|
28
|
+
llm = GoogleGenerativeAI(model="gemini-pro")
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
client: Any = None #: :meta private:
|
|
32
|
+
model_config = ConfigDict(
|
|
33
|
+
populate_by_name=True,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
@model_validator(mode="after")
|
|
37
|
+
def validate_environment(self) -> Self:
|
|
38
|
+
"""Validates params and passes them to google-generativeai package."""
|
|
39
|
+
|
|
40
|
+
self.client = ChatGoogleGenerativeAI(
|
|
41
|
+
api_key=self.google_api_key,
|
|
42
|
+
credentials=self.credentials,
|
|
43
|
+
temperature=self.temperature,
|
|
44
|
+
top_p=self.top_p,
|
|
45
|
+
top_k=self.top_k,
|
|
46
|
+
max_tokens=self.max_output_tokens,
|
|
47
|
+
timeout=self.timeout,
|
|
48
|
+
model=self.model,
|
|
49
|
+
client_options=self.client_options,
|
|
50
|
+
transport=self.transport,
|
|
51
|
+
additional_headers=self.additional_headers,
|
|
52
|
+
safety_settings=self.safety_settings,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
return self
|
|
56
|
+
|
|
57
|
+
def _get_ls_params(
|
|
58
|
+
self, stop: Optional[List[str]] = None, **kwargs: Any
|
|
59
|
+
) -> LangSmithParams:
|
|
60
|
+
"""Get standard params for tracing."""
|
|
61
|
+
ls_params = super()._get_ls_params(stop=stop, **kwargs)
|
|
62
|
+
ls_params["ls_provider"] = "google_genai"
|
|
63
|
+
if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
|
|
64
|
+
ls_params["ls_max_tokens"] = ls_max_tokens
|
|
65
|
+
return ls_params
|
|
66
|
+
|
|
67
|
+
def _generate(
|
|
68
|
+
self,
|
|
69
|
+
prompts: List[str],
|
|
70
|
+
stop: Optional[List[str]] = None,
|
|
71
|
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
72
|
+
**kwargs: Any,
|
|
73
|
+
) -> LLMResult:
|
|
74
|
+
generations = []
|
|
75
|
+
for prompt in prompts:
|
|
76
|
+
chat_result = self.client._generate(
|
|
77
|
+
[HumanMessage(content=prompt)],
|
|
78
|
+
stop=stop,
|
|
79
|
+
run_manager=run_manager,
|
|
80
|
+
**kwargs,
|
|
81
|
+
)
|
|
82
|
+
generations.append(
|
|
83
|
+
[
|
|
84
|
+
Generation(
|
|
85
|
+
text=g.message.content,
|
|
86
|
+
generation_info={
|
|
87
|
+
**g.generation_info,
|
|
88
|
+
**{"usage_metadata": g.message.usage_metadata},
|
|
89
|
+
},
|
|
90
|
+
)
|
|
91
|
+
for g in chat_result.generations
|
|
92
|
+
]
|
|
93
|
+
)
|
|
94
|
+
return LLMResult(generations=generations)
|
|
95
|
+
|
|
96
|
+
def _stream(
|
|
97
|
+
self,
|
|
98
|
+
prompt: str,
|
|
99
|
+
stop: Optional[List[str]] = None,
|
|
100
|
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
101
|
+
**kwargs: Any,
|
|
102
|
+
) -> Iterator[GenerationChunk]:
|
|
103
|
+
for stream_chunk in self.client._stream(
|
|
104
|
+
[HumanMessage(content=prompt)],
|
|
105
|
+
stop=stop,
|
|
106
|
+
run_manager=run_manager,
|
|
107
|
+
**kwargs,
|
|
108
|
+
):
|
|
109
|
+
chunk = GenerationChunk(text=stream_chunk.message.content)
|
|
110
|
+
yield chunk
|
|
111
|
+
if run_manager:
|
|
112
|
+
run_manager.on_llm_new_token(
|
|
113
|
+
chunk.text,
|
|
114
|
+
chunk=chunk,
|
|
115
|
+
verbose=self.verbose,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
@property
|
|
119
|
+
def _llm_type(self) -> str:
|
|
120
|
+
"""Return type of llm."""
|
|
121
|
+
return "google_gemini"
|
|
122
|
+
|
|
123
|
+
def get_num_tokens(self, text: str) -> int:
|
|
124
|
+
"""Get the number of tokens present in the text.
|
|
125
|
+
|
|
126
|
+
Useful for checking if an input will fit in a model's context window.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
text: The string input to tokenize.
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
The integer number of tokens in the text.
|
|
133
|
+
"""
|
|
134
|
+
return self.client.get_num_tokens(text)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "langchain-google-genai"
|
|
3
|
-
version = "2.0
|
|
3
|
+
version = "2.1.0"
|
|
4
4
|
description = "An integration package connecting Google's genai package and LangChain"
|
|
5
5
|
authors = []
|
|
6
6
|
readme = "README.md"
|
|
@@ -12,8 +12,8 @@ license = "MIT"
|
|
|
12
12
|
|
|
13
13
|
[tool.poetry.dependencies]
|
|
14
14
|
python = ">=3.9,<4.0"
|
|
15
|
-
langchain-core = "^0.3.
|
|
16
|
-
google-
|
|
15
|
+
langchain-core = "^0.3.43"
|
|
16
|
+
google-ai-generativelanguage = "^0.6.16"
|
|
17
17
|
pydantic = ">=2,<3"
|
|
18
18
|
filetype = "^1.2.0"
|
|
19
19
|
|
|
@@ -28,7 +28,7 @@ syrupy = "^4.0.2"
|
|
|
28
28
|
pytest-watcher = "^0.3.4"
|
|
29
29
|
pytest-asyncio = "^0.21.1"
|
|
30
30
|
numpy = "^1.26.2"
|
|
31
|
-
langchain-tests = "0.3.
|
|
31
|
+
langchain-tests = "0.3.14"
|
|
32
32
|
|
|
33
33
|
[tool.codespell]
|
|
34
34
|
ignore-words-list = "rouge"
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
from importlib import metadata
|
|
2
|
-
from typing import Optional, Tuple, TypedDict
|
|
3
|
-
|
|
4
|
-
from google.api_core.gapic_v1.client_info import ClientInfo
|
|
5
|
-
|
|
6
|
-
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class GoogleGenerativeAIError(Exception):
|
|
10
|
-
"""
|
|
11
|
-
Custom exception class for errors associated with the `Google GenAI` API.
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
|
|
16
|
-
r"""Returns a custom user agent header.
|
|
17
|
-
|
|
18
|
-
Args:
|
|
19
|
-
module (Optional[str]):
|
|
20
|
-
Optional. The module for a custom user agent header.
|
|
21
|
-
Returns:
|
|
22
|
-
Tuple[str, str]
|
|
23
|
-
"""
|
|
24
|
-
try:
|
|
25
|
-
langchain_version = metadata.version("langchain-google-genai")
|
|
26
|
-
except metadata.PackageNotFoundError:
|
|
27
|
-
langchain_version = "0.0.0"
|
|
28
|
-
client_library_version = (
|
|
29
|
-
f"{langchain_version}-{module}" if module else langchain_version
|
|
30
|
-
)
|
|
31
|
-
return client_library_version, f"langchain-google-genai/{client_library_version}"
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def get_client_info(module: Optional[str] = None) -> "ClientInfo":
|
|
35
|
-
r"""Returns a client info object with a custom user agent header.
|
|
36
|
-
|
|
37
|
-
Args:
|
|
38
|
-
module (Optional[str]):
|
|
39
|
-
Optional. The module for a custom user agent header.
|
|
40
|
-
Returns:
|
|
41
|
-
google.api_core.gapic_v1.client_info.ClientInfo
|
|
42
|
-
"""
|
|
43
|
-
client_library_version, user_agent = get_user_agent(module)
|
|
44
|
-
return ClientInfo(
|
|
45
|
-
client_library_version=client_library_version,
|
|
46
|
-
user_agent=user_agent,
|
|
47
|
-
)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
class SafetySettingDict(TypedDict):
|
|
51
|
-
category: HarmCategory
|
|
52
|
-
threshold: HarmBlockThreshold
|
|
@@ -1,406 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from enum import Enum, auto
|
|
4
|
-
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
|
5
|
-
|
|
6
|
-
import google.api_core
|
|
7
|
-
import google.generativeai as genai # type: ignore[import]
|
|
8
|
-
from langchain_core.callbacks import (
|
|
9
|
-
AsyncCallbackManagerForLLMRun,
|
|
10
|
-
CallbackManagerForLLMRun,
|
|
11
|
-
)
|
|
12
|
-
from langchain_core.language_models import LangSmithParams, LanguageModelInput
|
|
13
|
-
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
|
|
14
|
-
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
|
15
|
-
from langchain_core.utils import secret_from_env
|
|
16
|
-
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
|
17
|
-
from typing_extensions import Self
|
|
18
|
-
|
|
19
|
-
from langchain_google_genai._enums import (
|
|
20
|
-
HarmBlockThreshold,
|
|
21
|
-
HarmCategory,
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class GoogleModelFamily(str, Enum):
|
|
26
|
-
GEMINI = auto()
|
|
27
|
-
PALM = auto()
|
|
28
|
-
|
|
29
|
-
@classmethod
|
|
30
|
-
def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
|
|
31
|
-
if "gemini" in value.lower():
|
|
32
|
-
return GoogleModelFamily.GEMINI
|
|
33
|
-
elif "text-bison" in value.lower():
|
|
34
|
-
return GoogleModelFamily.PALM
|
|
35
|
-
return None
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def _create_retry_decorator(
|
|
39
|
-
llm: BaseLLM,
|
|
40
|
-
*,
|
|
41
|
-
max_retries: int = 1,
|
|
42
|
-
run_manager: Optional[
|
|
43
|
-
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
|
44
|
-
] = None,
|
|
45
|
-
) -> Callable[[Any], Any]:
|
|
46
|
-
"""Creates a retry decorator for Vertex / Palm LLMs."""
|
|
47
|
-
|
|
48
|
-
errors = [
|
|
49
|
-
google.api_core.exceptions.ResourceExhausted,
|
|
50
|
-
google.api_core.exceptions.ServiceUnavailable,
|
|
51
|
-
google.api_core.exceptions.Aborted,
|
|
52
|
-
google.api_core.exceptions.DeadlineExceeded,
|
|
53
|
-
google.api_core.exceptions.GoogleAPIError,
|
|
54
|
-
]
|
|
55
|
-
decorator = create_base_retry_decorator(
|
|
56
|
-
error_types=errors, max_retries=max_retries, run_manager=run_manager
|
|
57
|
-
)
|
|
58
|
-
return decorator
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def _completion_with_retry(
|
|
62
|
-
llm: GoogleGenerativeAI,
|
|
63
|
-
prompt: LanguageModelInput,
|
|
64
|
-
is_gemini: bool = False,
|
|
65
|
-
stream: bool = False,
|
|
66
|
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
67
|
-
**kwargs: Any,
|
|
68
|
-
) -> Any:
|
|
69
|
-
"""Use tenacity to retry the completion call."""
|
|
70
|
-
retry_decorator = _create_retry_decorator(
|
|
71
|
-
llm, max_retries=llm.max_retries, run_manager=run_manager
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
@retry_decorator
|
|
75
|
-
def _completion_with_retry(
|
|
76
|
-
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
|
|
77
|
-
) -> Any:
|
|
78
|
-
generation_config = kwargs.get("generation_config", {})
|
|
79
|
-
error_msg = (
|
|
80
|
-
"Your location is not supported by google-generativeai at the moment. "
|
|
81
|
-
"Try to use VertexAI LLM from langchain_google_vertexai"
|
|
82
|
-
)
|
|
83
|
-
try:
|
|
84
|
-
if is_gemini:
|
|
85
|
-
return llm.client.generate_content(
|
|
86
|
-
contents=prompt,
|
|
87
|
-
stream=stream,
|
|
88
|
-
generation_config=generation_config,
|
|
89
|
-
safety_settings=kwargs.pop("safety_settings", None),
|
|
90
|
-
request_options={"timeout": llm.timeout} if llm.timeout else None,
|
|
91
|
-
)
|
|
92
|
-
return llm.client.generate_text(prompt=prompt, **kwargs)
|
|
93
|
-
except google.api_core.exceptions.FailedPrecondition as exc:
|
|
94
|
-
if "location is not supported" in exc.message:
|
|
95
|
-
raise ValueError(error_msg)
|
|
96
|
-
|
|
97
|
-
return _completion_with_retry(
|
|
98
|
-
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
|
|
99
|
-
)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def _strip_erroneous_leading_spaces(text: str) -> str:
|
|
103
|
-
"""Strip erroneous leading spaces from text.
|
|
104
|
-
|
|
105
|
-
The PaLM API will sometimes erroneously return a single leading space in all
|
|
106
|
-
lines > 1. This function strips that space.
|
|
107
|
-
"""
|
|
108
|
-
has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
|
|
109
|
-
if has_leading_space:
|
|
110
|
-
return text.replace("\n ", "\n")
|
|
111
|
-
else:
|
|
112
|
-
return text
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
class _BaseGoogleGenerativeAI(BaseModel):
|
|
116
|
-
"""Base class for Google Generative AI LLMs"""
|
|
117
|
-
|
|
118
|
-
model: str = Field(
|
|
119
|
-
...,
|
|
120
|
-
description="""The name of the model to use.
|
|
121
|
-
Supported examples:
|
|
122
|
-
- gemini-pro
|
|
123
|
-
- models/text-bison-001""",
|
|
124
|
-
)
|
|
125
|
-
"""Model name to use."""
|
|
126
|
-
google_api_key: Optional[SecretStr] = Field(
|
|
127
|
-
alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
|
|
128
|
-
)
|
|
129
|
-
credentials: Any = None
|
|
130
|
-
"The default custom credentials (google.auth.credentials.Credentials) to use "
|
|
131
|
-
"when making API calls. If not provided, credentials will be ascertained from "
|
|
132
|
-
"the GOOGLE_API_KEY envvar"
|
|
133
|
-
temperature: float = 0.7
|
|
134
|
-
"""Run inference with this temperature. Must by in the closed interval
|
|
135
|
-
[0.0, 1.0]."""
|
|
136
|
-
top_p: Optional[float] = None
|
|
137
|
-
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
|
138
|
-
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
|
139
|
-
top_k: Optional[int] = None
|
|
140
|
-
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
|
141
|
-
Must be positive."""
|
|
142
|
-
max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
|
|
143
|
-
"""Maximum number of tokens to include in a candidate. Must be greater than zero.
|
|
144
|
-
If unset, will default to 64."""
|
|
145
|
-
n: int = 1
|
|
146
|
-
"""Number of chat completions to generate for each prompt. Note that the API may
|
|
147
|
-
not return the full n completions if duplicates are generated."""
|
|
148
|
-
max_retries: int = 6
|
|
149
|
-
"""The maximum number of retries to make when generating."""
|
|
150
|
-
|
|
151
|
-
timeout: Optional[float] = None
|
|
152
|
-
"""The maximum number of seconds to wait for a response."""
|
|
153
|
-
|
|
154
|
-
client_options: Optional[Dict] = Field(
|
|
155
|
-
default=None,
|
|
156
|
-
description=(
|
|
157
|
-
"A dictionary of client options to pass to the Google API client, "
|
|
158
|
-
"such as `api_endpoint`."
|
|
159
|
-
),
|
|
160
|
-
)
|
|
161
|
-
transport: Optional[str] = Field(
|
|
162
|
-
default=None,
|
|
163
|
-
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
|
|
164
|
-
)
|
|
165
|
-
additional_headers: Optional[Dict[str, str]] = Field(
|
|
166
|
-
default=None,
|
|
167
|
-
description=(
|
|
168
|
-
"A key-value dictionary representing additional headers for the model call"
|
|
169
|
-
),
|
|
170
|
-
)
|
|
171
|
-
|
|
172
|
-
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
|
|
173
|
-
"""The default safety settings to use for all generations.
|
|
174
|
-
|
|
175
|
-
For example:
|
|
176
|
-
|
|
177
|
-
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
|
|
178
|
-
|
|
179
|
-
safety_settings = {
|
|
180
|
-
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
|
181
|
-
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
182
|
-
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
|
183
|
-
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
184
|
-
}
|
|
185
|
-
""" # noqa: E501
|
|
186
|
-
|
|
187
|
-
@property
|
|
188
|
-
def lc_secrets(self) -> Dict[str, str]:
|
|
189
|
-
return {"google_api_key": "GOOGLE_API_KEY"}
|
|
190
|
-
|
|
191
|
-
@property
|
|
192
|
-
def _model_family(self) -> str:
|
|
193
|
-
return GoogleModelFamily(self.model)
|
|
194
|
-
|
|
195
|
-
@property
|
|
196
|
-
def _identifying_params(self) -> Dict[str, Any]:
|
|
197
|
-
"""Get the identifying parameters."""
|
|
198
|
-
return {
|
|
199
|
-
"model": self.model,
|
|
200
|
-
"temperature": self.temperature,
|
|
201
|
-
"top_p": self.top_p,
|
|
202
|
-
"top_k": self.top_k,
|
|
203
|
-
"max_output_tokens": self.max_output_tokens,
|
|
204
|
-
"candidate_count": self.n,
|
|
205
|
-
}
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
209
|
-
"""Google GenerativeAI models.
|
|
210
|
-
|
|
211
|
-
Example:
|
|
212
|
-
.. code-block:: python
|
|
213
|
-
|
|
214
|
-
from langchain_google_genai import GoogleGenerativeAI
|
|
215
|
-
llm = GoogleGenerativeAI(model="gemini-pro")
|
|
216
|
-
"""
|
|
217
|
-
|
|
218
|
-
client: Any = None #: :meta private:
|
|
219
|
-
model_config = ConfigDict(
|
|
220
|
-
populate_by_name=True,
|
|
221
|
-
)
|
|
222
|
-
|
|
223
|
-
@model_validator(mode="after")
|
|
224
|
-
def validate_environment(self) -> Self:
|
|
225
|
-
"""Validates params and passes them to google-generativeai package."""
|
|
226
|
-
if self.credentials:
|
|
227
|
-
genai.configure(
|
|
228
|
-
credentials=self.credentials,
|
|
229
|
-
transport=self.transport,
|
|
230
|
-
client_options=self.client_options,
|
|
231
|
-
)
|
|
232
|
-
else:
|
|
233
|
-
if isinstance(self.google_api_key, SecretStr):
|
|
234
|
-
google_api_key: Optional[str] = self.google_api_key.get_secret_value()
|
|
235
|
-
else:
|
|
236
|
-
google_api_key = self.google_api_key
|
|
237
|
-
genai.configure(
|
|
238
|
-
api_key=google_api_key,
|
|
239
|
-
transport=self.transport,
|
|
240
|
-
client_options=self.client_options,
|
|
241
|
-
)
|
|
242
|
-
|
|
243
|
-
model_name = self.model
|
|
244
|
-
|
|
245
|
-
safety_settings = self.safety_settings
|
|
246
|
-
|
|
247
|
-
if safety_settings and (
|
|
248
|
-
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
|
|
249
|
-
):
|
|
250
|
-
raise ValueError("Safety settings are only supported for Gemini models")
|
|
251
|
-
|
|
252
|
-
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
|
|
253
|
-
self.client = genai.GenerativeModel(
|
|
254
|
-
model_name=model_name, safety_settings=safety_settings
|
|
255
|
-
)
|
|
256
|
-
else:
|
|
257
|
-
self.client = genai
|
|
258
|
-
|
|
259
|
-
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
|
260
|
-
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
|
261
|
-
|
|
262
|
-
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
|
263
|
-
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
|
264
|
-
|
|
265
|
-
if self.top_k is not None and self.top_k <= 0:
|
|
266
|
-
raise ValueError("top_k must be positive")
|
|
267
|
-
|
|
268
|
-
if self.max_output_tokens is not None and self.max_output_tokens <= 0:
|
|
269
|
-
raise ValueError("max_output_tokens must be greater than zero")
|
|
270
|
-
|
|
271
|
-
if self.timeout is not None and self.timeout <= 0:
|
|
272
|
-
raise ValueError("timeout must be greater than zero")
|
|
273
|
-
|
|
274
|
-
return self
|
|
275
|
-
|
|
276
|
-
def _get_ls_params(
|
|
277
|
-
self, stop: Optional[List[str]] = None, **kwargs: Any
|
|
278
|
-
) -> LangSmithParams:
|
|
279
|
-
"""Get standard params for tracing."""
|
|
280
|
-
ls_params = super()._get_ls_params(stop=stop, **kwargs)
|
|
281
|
-
ls_params["ls_provider"] = "google_genai"
|
|
282
|
-
if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
|
|
283
|
-
ls_params["ls_max_tokens"] = ls_max_tokens
|
|
284
|
-
return ls_params
|
|
285
|
-
|
|
286
|
-
def _generate(
|
|
287
|
-
self,
|
|
288
|
-
prompts: List[str],
|
|
289
|
-
stop: Optional[List[str]] = None,
|
|
290
|
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
291
|
-
**kwargs: Any,
|
|
292
|
-
) -> LLMResult:
|
|
293
|
-
generations: List[List[Generation]] = []
|
|
294
|
-
generation_config = {
|
|
295
|
-
"stop_sequences": stop,
|
|
296
|
-
"temperature": self.temperature,
|
|
297
|
-
"top_p": self.top_p,
|
|
298
|
-
"top_k": self.top_k,
|
|
299
|
-
"max_output_tokens": self.max_output_tokens,
|
|
300
|
-
"candidate_count": self.n,
|
|
301
|
-
}
|
|
302
|
-
for prompt in prompts:
|
|
303
|
-
if self._model_family == GoogleModelFamily.GEMINI:
|
|
304
|
-
res = _completion_with_retry(
|
|
305
|
-
self,
|
|
306
|
-
prompt=prompt,
|
|
307
|
-
stream=False,
|
|
308
|
-
is_gemini=True,
|
|
309
|
-
run_manager=run_manager,
|
|
310
|
-
generation_config=generation_config,
|
|
311
|
-
safety_settings=kwargs.pop("safety_settings", None),
|
|
312
|
-
)
|
|
313
|
-
generation_info = None
|
|
314
|
-
if res.usage_metadata is not None:
|
|
315
|
-
generation_info = {
|
|
316
|
-
"usage_metadata": res.to_dict().get("usage_metadata")
|
|
317
|
-
}
|
|
318
|
-
|
|
319
|
-
candidates = [
|
|
320
|
-
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
|
321
|
-
]
|
|
322
|
-
generations.append(
|
|
323
|
-
[
|
|
324
|
-
Generation(text=c, generation_info=generation_info)
|
|
325
|
-
for c in candidates
|
|
326
|
-
]
|
|
327
|
-
)
|
|
328
|
-
else:
|
|
329
|
-
res = _completion_with_retry(
|
|
330
|
-
self,
|
|
331
|
-
model=self.model,
|
|
332
|
-
prompt=prompt,
|
|
333
|
-
stream=False,
|
|
334
|
-
is_gemini=False,
|
|
335
|
-
run_manager=run_manager,
|
|
336
|
-
**generation_config,
|
|
337
|
-
)
|
|
338
|
-
prompt_generations = []
|
|
339
|
-
for candidate in res.candidates:
|
|
340
|
-
raw_text = candidate["output"]
|
|
341
|
-
stripped_text = _strip_erroneous_leading_spaces(raw_text)
|
|
342
|
-
prompt_generations.append(Generation(text=stripped_text))
|
|
343
|
-
generations.append(prompt_generations)
|
|
344
|
-
|
|
345
|
-
return LLMResult(generations=generations)
|
|
346
|
-
|
|
347
|
-
def _stream(
|
|
348
|
-
self,
|
|
349
|
-
prompt: str,
|
|
350
|
-
stop: Optional[List[str]] = None,
|
|
351
|
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
352
|
-
**kwargs: Any,
|
|
353
|
-
) -> Iterator[GenerationChunk]:
|
|
354
|
-
generation_config = {
|
|
355
|
-
"stop_sequences": stop,
|
|
356
|
-
"temperature": self.temperature,
|
|
357
|
-
"top_p": self.top_p,
|
|
358
|
-
"top_k": self.top_k,
|
|
359
|
-
"max_output_tokens": self.max_output_tokens,
|
|
360
|
-
"candidate_count": self.n,
|
|
361
|
-
}
|
|
362
|
-
generation_config = generation_config | kwargs.get("generation_config", {})
|
|
363
|
-
|
|
364
|
-
for stream_resp in _completion_with_retry(
|
|
365
|
-
self,
|
|
366
|
-
prompt,
|
|
367
|
-
stream=True,
|
|
368
|
-
is_gemini=True,
|
|
369
|
-
run_manager=run_manager,
|
|
370
|
-
generation_config=generation_config,
|
|
371
|
-
safety_settings=kwargs.pop("safety_settings", None),
|
|
372
|
-
**kwargs,
|
|
373
|
-
):
|
|
374
|
-
chunk = GenerationChunk(text=stream_resp.text)
|
|
375
|
-
yield chunk
|
|
376
|
-
if run_manager:
|
|
377
|
-
run_manager.on_llm_new_token(
|
|
378
|
-
stream_resp.text,
|
|
379
|
-
chunk=chunk,
|
|
380
|
-
verbose=self.verbose,
|
|
381
|
-
)
|
|
382
|
-
|
|
383
|
-
@property
|
|
384
|
-
def _llm_type(self) -> str:
|
|
385
|
-
"""Return type of llm."""
|
|
386
|
-
return "google_palm"
|
|
387
|
-
|
|
388
|
-
def get_num_tokens(self, text: str) -> int:
|
|
389
|
-
"""Get the number of tokens present in the text.
|
|
390
|
-
|
|
391
|
-
Useful for checking if an input will fit in a model's context window.
|
|
392
|
-
|
|
393
|
-
Args:
|
|
394
|
-
text: The string input to tokenize.
|
|
395
|
-
|
|
396
|
-
Returns:
|
|
397
|
-
The integer number of tokens in the text.
|
|
398
|
-
"""
|
|
399
|
-
if self._model_family == GoogleModelFamily.GEMINI:
|
|
400
|
-
result = self.client.count_tokens(text)
|
|
401
|
-
token_count = result.total_tokens
|
|
402
|
-
else:
|
|
403
|
-
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
|
404
|
-
token_count = result["token_count"]
|
|
405
|
-
|
|
406
|
-
return token_count
|
|
File without changes
|
|
File without changes
|
{langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/__init__.py
RENAMED
|
File without changes
|
{langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/_enums.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/embeddings.py
RENAMED
|
File without changes
|
{langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/genai_aqa.py
RENAMED
|
File without changes
|
|
File without changes
|
{langchain_google_genai-2.0.10 → langchain_google_genai-2.1.0}/langchain_google_genai/py.typed
RENAMED
|
File without changes
|