langchain-google-genai 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai might be problematic; see the registry's advisory for this release for more details.
- langchain_google_genai/__init__.py +1 -1
- langchain_google_genai/_common.py +48 -0
- langchain_google_genai/_enums.py +4 -4
- langchain_google_genai/_function_utils.py +50 -54
- langchain_google_genai/_genai_extension.py +64 -7
- langchain_google_genai/_image_utils.py +187 -0
- langchain_google_genai/chat_models.py +187 -107
- langchain_google_genai/embeddings.py +85 -41
- langchain_google_genai/llms.py +0 -1
- {langchain_google_genai-1.0.3.dist-info → langchain_google_genai-1.0.4.dist-info}/METADATA +2 -2
- langchain_google_genai-1.0.4.dist-info/RECORD +16 -0
- langchain_google_genai-1.0.3.dist-info/RECORD +0 -15
- {langchain_google_genai-1.0.3.dist-info → langchain_google_genai-1.0.4.dist-info}/LICENSE +0 -0
- {langchain_google_genai-1.0.3.dist-info → langchain_google_genai-1.0.4.dist-info}/WHEEL +0 -0
|
langchain_google_genai/_common.py
CHANGED
|
@@ -1,4 +1,52 @@
|
|
|
1
|
+
from importlib import metadata
|
|
2
|
+
from typing import Optional, Tuple, TypedDict
|
|
3
|
+
|
|
4
|
+
from google.api_core.gapic_v1.client_info import ClientInfo
|
|
5
|
+
|
|
6
|
+
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
|
|
7
|
+
|
|
8
|
+
|
|
1
9
|
class GoogleGenerativeAIError(Exception):
|
|
2
10
|
"""
|
|
3
11
|
Custom exception class for errors associated with the `Google GenAI` API.
|
|
4
12
|
"""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
|
|
16
|
+
r"""Returns a custom user agent header.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
module (Optional[str]):
|
|
20
|
+
Optional. The module for a custom user agent header.
|
|
21
|
+
Returns:
|
|
22
|
+
Tuple[str, str]
|
|
23
|
+
"""
|
|
24
|
+
try:
|
|
25
|
+
langchain_version = metadata.version("langchain-google-genai")
|
|
26
|
+
except metadata.PackageNotFoundError:
|
|
27
|
+
langchain_version = "0.0.0"
|
|
28
|
+
client_library_version = (
|
|
29
|
+
f"{langchain_version}-{module}" if module else langchain_version
|
|
30
|
+
)
|
|
31
|
+
return client_library_version, f"langchain-google-genai/{client_library_version}"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_client_info(module: Optional[str] = None) -> "ClientInfo":
|
|
35
|
+
r"""Returns a client info object with a custom user agent header.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
module (Optional[str]):
|
|
39
|
+
Optional. The module for a custom user agent header.
|
|
40
|
+
Returns:
|
|
41
|
+
google.api_core.gapic_v1.client_info.ClientInfo
|
|
42
|
+
"""
|
|
43
|
+
client_library_version, user_agent = get_user_agent(module)
|
|
44
|
+
return ClientInfo(
|
|
45
|
+
client_library_version=client_library_version,
|
|
46
|
+
user_agent=user_agent,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class SafetySettingDict(TypedDict):
|
|
51
|
+
category: HarmCategory
|
|
52
|
+
threshold: HarmBlockThreshold
|
langchain_google_genai/_enums.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
1
|
+
import google.ai.generativelanguage_v1beta as genai
|
|
2
|
+
|
|
3
|
+
HarmBlockThreshold = genai.SafetySetting.HarmBlockThreshold
|
|
4
|
+
HarmCategory = genai.HarmCategory
|
|
5
5
|
|
|
6
6
|
__all__ = ["HarmBlockThreshold", "HarmCategory"]
|
|
langchain_google_genai/_function_utils.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import (
|
|
4
4
|
Any,
|
|
5
|
+
Callable,
|
|
5
6
|
Dict,
|
|
6
7
|
List,
|
|
7
8
|
Literal,
|
|
@@ -10,15 +11,16 @@ from typing import (
|
|
|
10
11
|
Type,
|
|
11
12
|
TypedDict,
|
|
12
13
|
Union,
|
|
14
|
+
cast,
|
|
13
15
|
)
|
|
14
16
|
|
|
15
17
|
import google.ai.generativelanguage as glm
|
|
16
|
-
from google.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
18
|
+
from google.ai.generativelanguage import (
|
|
19
|
+
FunctionCallingConfig,
|
|
20
|
+
FunctionDeclaration,
|
|
21
|
+
)
|
|
22
|
+
from google.ai.generativelanguage import (
|
|
23
|
+
Tool as GoogleTool,
|
|
22
24
|
)
|
|
23
25
|
from langchain_core.pydantic_v1 import BaseModel
|
|
24
26
|
from langchain_core.tools import BaseTool
|
|
@@ -36,51 +38,41 @@ TYPE_ENUM = {
|
|
|
36
38
|
|
|
37
39
|
TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}
|
|
38
40
|
|
|
41
|
+
_FunctionDeclarationLike = Union[
|
|
42
|
+
BaseTool, Type[BaseModel], dict, Callable, FunctionDeclaration
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _ToolDict(TypedDict):
|
|
47
|
+
function_declarations: Sequence[_FunctionDeclarationLike]
|
|
48
|
+
|
|
39
49
|
|
|
40
50
|
def convert_to_genai_function_declarations(
|
|
41
51
|
tool: Union[
|
|
42
|
-
GoogleTool,
|
|
52
|
+
GoogleTool,
|
|
53
|
+
_ToolDict,
|
|
54
|
+
_FunctionDeclarationLike,
|
|
55
|
+
Sequence[_FunctionDeclarationLike],
|
|
43
56
|
],
|
|
44
|
-
) ->
|
|
45
|
-
"""Convert any tool-like object to a ToolType.
|
|
46
|
-
|
|
47
|
-
https://github.com/google-gemini/generative-ai-python/blob/668695ebe3e9de496a36eeb95cb2ed2faba9b939/google/generativeai/types/content_types.py#L574
|
|
48
|
-
"""
|
|
57
|
+
) -> GoogleTool:
|
|
49
58
|
if isinstance(tool, GoogleTool):
|
|
50
|
-
return tool
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
except AttributeError:
|
|
60
|
-
pass
|
|
61
|
-
if schema is None:
|
|
62
|
-
schema = first_function_declaration.get("parameters")
|
|
63
|
-
if schema is None or isinstance(schema, glm.Schema):
|
|
64
|
-
return tool
|
|
65
|
-
return glm.Tool(
|
|
59
|
+
return cast(GoogleTool, tool)
|
|
60
|
+
if isinstance(tool, type) and issubclass(tool, BaseModel):
|
|
61
|
+
return GoogleTool(function_declarations=[_convert_to_genai_function(tool)])
|
|
62
|
+
if callable(tool):
|
|
63
|
+
return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
|
|
64
|
+
if isinstance(tool, list):
|
|
65
|
+
return convert_to_genai_function_declarations({"function_declarations": tool})
|
|
66
|
+
if isinstance(tool, dict) and "function_declarations" in tool:
|
|
67
|
+
return GoogleTool(
|
|
66
68
|
function_declarations=[
|
|
67
69
|
_convert_to_genai_function(fc) for fc in tool["function_declarations"]
|
|
68
70
|
],
|
|
69
71
|
)
|
|
70
|
-
|
|
71
|
-
return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])
|
|
72
|
-
elif callable(tool):
|
|
73
|
-
return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
|
|
74
|
-
elif isinstance(tool, list):
|
|
75
|
-
return glm.Tool(
|
|
76
|
-
function_declarations=[_convert_to_genai_function(fc) for fc in tool]
|
|
77
|
-
)
|
|
78
|
-
return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])
|
|
72
|
+
return GoogleTool(function_declarations=[_convert_to_genai_function(tool)]) # type: ignore[arg-type]
|
|
79
73
|
|
|
80
74
|
|
|
81
|
-
def tool_to_dict(tool:
|
|
82
|
-
if isinstance(tool, GoogleTool):
|
|
83
|
-
tool = tool._proto
|
|
75
|
+
def tool_to_dict(tool: GoogleTool) -> _ToolDict:
|
|
84
76
|
function_declarations = []
|
|
85
77
|
for function_declaration_proto in tool.function_declarations:
|
|
86
78
|
properties: Dict[str, Any] = {}
|
|
@@ -108,7 +100,7 @@ def tool_to_dict(tool: Union[glm.Tool, GoogleTool]) -> ToolDict:
|
|
|
108
100
|
return {"function_declarations": function_declarations}
|
|
109
101
|
|
|
110
102
|
|
|
111
|
-
def _convert_to_genai_function(fc:
|
|
103
|
+
def _convert_to_genai_function(fc: _FunctionDeclarationLike) -> FunctionDeclaration:
|
|
112
104
|
if isinstance(fc, BaseTool):
|
|
113
105
|
return _convert_tool_to_genai_function(fc)
|
|
114
106
|
elif isinstance(fc, type) and issubclass(fc, BaseModel):
|
|
@@ -116,10 +108,9 @@ def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDecla
|
|
|
116
108
|
elif callable(fc):
|
|
117
109
|
return _convert_tool_to_genai_function(callable_as_lc_tool()(fc))
|
|
118
110
|
elif isinstance(fc, dict):
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
parameters={
|
|
111
|
+
formatted_fc = {"name": fc["name"], "description": fc.get("description")}
|
|
112
|
+
if "parameters" in fc:
|
|
113
|
+
formatted_fc["parameters"] = {
|
|
123
114
|
"properties": {
|
|
124
115
|
k: {
|
|
125
116
|
"type_": TYPE_ENUM[v["type"]],
|
|
@@ -127,19 +118,19 @@ def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDecla
|
|
|
127
118
|
}
|
|
128
119
|
for k, v in fc["parameters"]["properties"].items()
|
|
129
120
|
},
|
|
130
|
-
"required": fc
|
|
121
|
+
"required": fc.get("parameters", []).get("required", []),
|
|
131
122
|
"type_": TYPE_ENUM[fc["parameters"]["type"]],
|
|
132
|
-
}
|
|
133
|
-
)
|
|
123
|
+
}
|
|
124
|
+
return FunctionDeclaration(**formatted_fc)
|
|
134
125
|
else:
|
|
135
126
|
raise ValueError(f"Unsupported function call type {fc}")
|
|
136
127
|
|
|
137
128
|
|
|
138
|
-
def _convert_tool_to_genai_function(tool: BaseTool) ->
|
|
129
|
+
def _convert_tool_to_genai_function(tool: BaseTool) -> FunctionDeclaration:
|
|
139
130
|
if tool.args_schema:
|
|
140
131
|
schema = dereference_refs(tool.args_schema.schema())
|
|
141
132
|
schema.pop("definitions", None)
|
|
142
|
-
return
|
|
133
|
+
return FunctionDeclaration(
|
|
143
134
|
name=tool.name or schema["title"],
|
|
144
135
|
description=tool.description or schema["description"],
|
|
145
136
|
parameters={
|
|
@@ -155,7 +146,7 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
|
|
|
155
146
|
},
|
|
156
147
|
)
|
|
157
148
|
else:
|
|
158
|
-
return
|
|
149
|
+
return FunctionDeclaration(
|
|
159
150
|
name=tool.name,
|
|
160
151
|
description=tool.description,
|
|
161
152
|
parameters={
|
|
@@ -170,10 +161,10 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
|
|
|
170
161
|
|
|
171
162
|
def _convert_pydantic_to_genai_function(
|
|
172
163
|
pydantic_model: Type[BaseModel],
|
|
173
|
-
) ->
|
|
164
|
+
) -> FunctionDeclaration:
|
|
174
165
|
schema = dereference_refs(pydantic_model.schema())
|
|
175
166
|
schema.pop("definitions", None)
|
|
176
|
-
return
|
|
167
|
+
return FunctionDeclaration(
|
|
177
168
|
name=schema["title"],
|
|
178
169
|
description=schema.get("description", ""),
|
|
179
170
|
parameters={
|
|
@@ -195,8 +186,13 @@ _ToolChoiceType = Union[
|
|
|
195
186
|
]
|
|
196
187
|
|
|
197
188
|
|
|
189
|
+
class _FunctionCallingConfigDict(TypedDict):
|
|
190
|
+
mode: Union[FunctionCallingConfig.Mode, str]
|
|
191
|
+
allowed_function_names: Optional[List[str]]
|
|
192
|
+
|
|
193
|
+
|
|
198
194
|
class _ToolConfigDict(TypedDict):
|
|
199
|
-
function_calling_config:
|
|
195
|
+
function_calling_config: _FunctionCallingConfigDict
|
|
200
196
|
|
|
201
197
|
|
|
202
198
|
def _tool_choice_to_tool_config(
|
|
langchain_google_genai/_genai_extension.py
CHANGED
|
@@ -12,6 +12,12 @@ from typing import Any, Dict, Iterator, List, MutableSequence, Optional
|
|
|
12
12
|
|
|
13
13
|
import google.ai.generativelanguage as genai
|
|
14
14
|
import langchain_core
|
|
15
|
+
from google.ai.generativelanguage_v1beta import (
|
|
16
|
+
GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
|
|
17
|
+
)
|
|
18
|
+
from google.ai.generativelanguage_v1beta import (
|
|
19
|
+
GenerativeServiceClient as v1betaGenerativeServiceClient,
|
|
20
|
+
)
|
|
15
21
|
from google.api_core import client_options as client_options_lib
|
|
16
22
|
from google.api_core import exceptions as gapi_exception
|
|
17
23
|
from google.api_core import gapic_v1
|
|
@@ -225,15 +231,66 @@ def build_semantic_retriever() -> genai.RetrieverServiceClient:
|
|
|
225
231
|
)
|
|
226
232
|
|
|
227
233
|
|
|
228
|
-
def
|
|
229
|
-
credentials =
|
|
230
|
-
|
|
234
|
+
def _prepare_config(
|
|
235
|
+
credentials: Optional[credentials.Credentials] = None,
|
|
236
|
+
api_key: Optional[str] = None,
|
|
237
|
+
client_options: Optional[Dict[str, Any]] = None,
|
|
238
|
+
client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
|
|
239
|
+
transport: Optional[str] = None,
|
|
240
|
+
) -> Dict[str, Any]:
|
|
241
|
+
formatted_client_options = {"api_endpoint": _config.api_endpoint}
|
|
242
|
+
if client_options:
|
|
243
|
+
formatted_client_options.update(**client_options)
|
|
244
|
+
if not credentials and api_key:
|
|
245
|
+
formatted_client_options["api_key"] = api_key
|
|
246
|
+
elif not credentials and not api_key:
|
|
247
|
+
credentials = _get_credentials()
|
|
248
|
+
client_info = (
|
|
249
|
+
client_info
|
|
250
|
+
if client_info
|
|
251
|
+
else gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT)
|
|
252
|
+
)
|
|
253
|
+
config = {
|
|
254
|
+
"credentials": credentials,
|
|
255
|
+
"client_info": client_info,
|
|
256
|
+
"client_options": client_options_lib.ClientOptions(**formatted_client_options),
|
|
257
|
+
"transport": transport,
|
|
258
|
+
}
|
|
259
|
+
return {k: v for k, v in config.items() if v is not None}
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def build_generative_service(
|
|
263
|
+
credentials: Optional[credentials.Credentials] = None,
|
|
264
|
+
api_key: Optional[str] = None,
|
|
265
|
+
client_options: Optional[Dict[str, Any]] = None,
|
|
266
|
+
client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
|
|
267
|
+
transport: Optional[str] = None,
|
|
268
|
+
) -> v1betaGenerativeServiceClient:
|
|
269
|
+
config = _prepare_config(
|
|
231
270
|
credentials=credentials,
|
|
232
|
-
|
|
233
|
-
client_options=
|
|
234
|
-
|
|
235
|
-
|
|
271
|
+
api_key=api_key,
|
|
272
|
+
client_options=client_options,
|
|
273
|
+
transport=transport,
|
|
274
|
+
client_info=client_info,
|
|
275
|
+
)
|
|
276
|
+
return v1betaGenerativeServiceClient(**config)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def build_generative_async_service(
|
|
280
|
+
credentials: Optional[credentials.Credentials],
|
|
281
|
+
api_key: Optional[str] = None,
|
|
282
|
+
client_options: Optional[Dict[str, Any]] = None,
|
|
283
|
+
client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
|
|
284
|
+
transport: Optional[str] = None,
|
|
285
|
+
) -> v1betaGenerativeServiceAsyncClient:
|
|
286
|
+
config = _prepare_config(
|
|
287
|
+
credentials=credentials,
|
|
288
|
+
api_key=api_key,
|
|
289
|
+
client_options=client_options,
|
|
290
|
+
transport=transport,
|
|
291
|
+
client_info=client_info,
|
|
236
292
|
)
|
|
293
|
+
return v1betaGenerativeServiceAsyncClient(**config)
|
|
237
294
|
|
|
238
295
|
|
|
239
296
|
def list_corpora(
|
|
langchain_google_genai/_image_utils.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import mimetypes
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any, Dict
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
|
|
11
|
+
import requests
|
|
12
|
+
from google.ai.generativelanguage_v1beta.types import Part
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Route(Enum):
|
|
16
|
+
"""Image Loading Route"""
|
|
17
|
+
|
|
18
|
+
BASE64 = 1
|
|
19
|
+
LOCAL_FILE = 2
|
|
20
|
+
URL = 3
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ImageBytesLoader:
|
|
24
|
+
"""Loads image bytes from multiple sources given a string.
|
|
25
|
+
|
|
26
|
+
Currently supported:
|
|
27
|
+
- B64 Encoded image string
|
|
28
|
+
- Local file path
|
|
29
|
+
- URL
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def load_bytes(self, image_string: str) -> bytes:
|
|
33
|
+
"""Routes to the correct loader based on the image_string.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
image_string: Can be either:
|
|
37
|
+
- B64 Encoded image string
|
|
38
|
+
- Local file path
|
|
39
|
+
- URL
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
Image bytes.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
route = self._route(image_string)
|
|
46
|
+
|
|
47
|
+
if route == Route.BASE64:
|
|
48
|
+
return self._bytes_from_b64(image_string)
|
|
49
|
+
|
|
50
|
+
if route == Route.URL:
|
|
51
|
+
return self._bytes_from_url(image_string)
|
|
52
|
+
|
|
53
|
+
if route == Route.LOCAL_FILE:
|
|
54
|
+
return self._bytes_from_file(image_string)
|
|
55
|
+
|
|
56
|
+
raise ValueError(
|
|
57
|
+
"Image string must be one of: Google Cloud Storage URI, "
|
|
58
|
+
"b64 encoded image string (data:image/...), valid image url, "
|
|
59
|
+
f"or existing local image file. Instead got '{image_string}'."
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
def load_part(self, image_string: str) -> Part:
|
|
63
|
+
"""Gets Part for loading from Gemini.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
image_string: Can be either:
|
|
67
|
+
- B64 Encoded image string
|
|
68
|
+
- Local file path
|
|
69
|
+
- URL
|
|
70
|
+
"""
|
|
71
|
+
route = self._route(image_string)
|
|
72
|
+
|
|
73
|
+
if route == Route.BASE64:
|
|
74
|
+
bytes_ = self._bytes_from_b64(image_string)
|
|
75
|
+
|
|
76
|
+
if route == Route.URL:
|
|
77
|
+
bytes_ = self._bytes_from_url(image_string)
|
|
78
|
+
|
|
79
|
+
if route == Route.LOCAL_FILE:
|
|
80
|
+
bytes_ = self._bytes_from_file(image_string)
|
|
81
|
+
|
|
82
|
+
inline_data: Dict[str, Any] = {"data": bytes_}
|
|
83
|
+
mime_type, _ = mimetypes.guess_type(image_string)
|
|
84
|
+
if mime_type:
|
|
85
|
+
inline_data["mime_type"] = mime_type
|
|
86
|
+
|
|
87
|
+
return Part(inline_data=inline_data)
|
|
88
|
+
|
|
89
|
+
def _route(self, image_string: str) -> Route:
|
|
90
|
+
if image_string.startswith("data:image/"):
|
|
91
|
+
return Route.BASE64
|
|
92
|
+
|
|
93
|
+
if self._is_url(image_string):
|
|
94
|
+
return Route.URL
|
|
95
|
+
|
|
96
|
+
if os.path.exists(image_string):
|
|
97
|
+
return Route.LOCAL_FILE
|
|
98
|
+
|
|
99
|
+
raise ValueError(
|
|
100
|
+
"Image string must be one of: "
|
|
101
|
+
"b64 encoded image string (data:image/...), valid image url, "
|
|
102
|
+
f"or existing local image file. Instead got '{image_string}'."
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
def _bytes_from_b64(self, base64_image: str) -> bytes:
|
|
106
|
+
"""Gets image bytes from a base64 encoded string.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
base64_image: Encoded image in b64 format.
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
Image bytes
|
|
113
|
+
"""
|
|
114
|
+
|
|
115
|
+
pattern = r"data:image/\w{2,4};base64,(.*)"
|
|
116
|
+
match = re.search(pattern, base64_image)
|
|
117
|
+
|
|
118
|
+
if match is not None:
|
|
119
|
+
encoded_string = match.group(1)
|
|
120
|
+
return base64.b64decode(encoded_string)
|
|
121
|
+
|
|
122
|
+
raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}")
|
|
123
|
+
|
|
124
|
+
def _bytes_from_file(self, file_path: str) -> bytes:
|
|
125
|
+
"""Gets image bytes from a local file path.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
file_path: Existing file path.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
Image bytes
|
|
132
|
+
"""
|
|
133
|
+
with open(file_path, "rb") as image_file:
|
|
134
|
+
image_bytes = image_file.read()
|
|
135
|
+
return image_bytes
|
|
136
|
+
|
|
137
|
+
def _bytes_from_url(self, url: str) -> bytes:
|
|
138
|
+
"""Gets image bytes from a public url.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
url: Valid url.
|
|
142
|
+
|
|
143
|
+
Raises:
|
|
144
|
+
HTTP Error if there is one.
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
Image bytes
|
|
148
|
+
"""
|
|
149
|
+
|
|
150
|
+
response = requests.get(url)
|
|
151
|
+
|
|
152
|
+
if not response.ok:
|
|
153
|
+
response.raise_for_status()
|
|
154
|
+
|
|
155
|
+
return response.content
|
|
156
|
+
|
|
157
|
+
def _is_url(self, url_string: str) -> bool:
|
|
158
|
+
"""Checks if a url is valid.
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
url_string: Url to check.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
Whether the url is valid.
|
|
165
|
+
"""
|
|
166
|
+
try:
|
|
167
|
+
result = urlparse(url_string)
|
|
168
|
+
return all([result.scheme, result.netloc])
|
|
169
|
+
except Exception:
|
|
170
|
+
return False
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def image_bytes_to_b64_string(
|
|
174
|
+
image_bytes: bytes, encoding: str = "ascii", image_format: str = "png"
|
|
175
|
+
) -> str:
|
|
176
|
+
"""Encodes image bytes into a b64 encoded string.
|
|
177
|
+
|
|
178
|
+
Args:
|
|
179
|
+
image_bytes: Bytes of the image.
|
|
180
|
+
encoding: Type of encoding in the string. 'ascii' by default.
|
|
181
|
+
image_format: Format of the image. 'png' by default.
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
B64 image encoded string.
|
|
185
|
+
"""
|
|
186
|
+
encoded_bytes = base64.b64encode(image_bytes).decode(encoding)
|
|
187
|
+
return f"data:image/{image_format};base64,{encoded_bytes}"
|
|
langchain_google_genai/chat_models.py
CHANGED
|
@@ -23,14 +23,23 @@ from typing import (
|
|
|
23
23
|
)
|
|
24
24
|
from urllib.parse import urlparse
|
|
25
25
|
|
|
26
|
-
import google.ai.generativelanguage as glm
|
|
27
26
|
import google.api_core
|
|
28
27
|
|
|
29
28
|
# TODO: remove ignore once the google package is published with types
|
|
30
|
-
import google.generativeai as genai # type: ignore[import]
|
|
31
29
|
import proto # type: ignore[import]
|
|
32
30
|
import requests
|
|
33
|
-
from google.
|
|
31
|
+
from google.ai.generativelanguage_v1beta.types import (
|
|
32
|
+
Candidate,
|
|
33
|
+
Content,
|
|
34
|
+
FunctionCall,
|
|
35
|
+
FunctionResponse,
|
|
36
|
+
GenerateContentRequest,
|
|
37
|
+
GenerateContentResponse,
|
|
38
|
+
GenerationConfig,
|
|
39
|
+
Part,
|
|
40
|
+
SafetySetting,
|
|
41
|
+
ToolConfig,
|
|
42
|
+
)
|
|
34
43
|
from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
|
|
35
44
|
from google.generativeai.types.content_types import ( # type: ignore[import]
|
|
36
45
|
FunctionDeclarationType,
|
|
@@ -56,7 +65,7 @@ from langchain_core.messages import (
|
|
|
56
65
|
)
|
|
57
66
|
from langchain_core.output_parsers.openai_tools import parse_tool_calls
|
|
58
67
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
59
|
-
from langchain_core.pydantic_v1 import SecretStr, root_validator
|
|
68
|
+
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
|
60
69
|
from langchain_core.runnables import Runnable
|
|
61
70
|
from langchain_core.utils import get_from_dict_or_env
|
|
62
71
|
from tenacity import (
|
|
@@ -67,7 +76,11 @@ from tenacity import (
|
|
|
67
76
|
wait_exponential,
|
|
68
77
|
)
|
|
69
78
|
|
|
70
|
-
from langchain_google_genai._common import
|
|
79
|
+
from langchain_google_genai._common import (
|
|
80
|
+
GoogleGenerativeAIError,
|
|
81
|
+
SafetySettingDict,
|
|
82
|
+
get_client_info,
|
|
83
|
+
)
|
|
71
84
|
from langchain_google_genai._function_utils import (
|
|
72
85
|
_tool_choice_to_tool_config,
|
|
73
86
|
_ToolChoiceType,
|
|
@@ -75,7 +88,10 @@ from langchain_google_genai._function_utils import (
|
|
|
75
88
|
convert_to_genai_function_declarations,
|
|
76
89
|
tool_to_dict,
|
|
77
90
|
)
|
|
78
|
-
from langchain_google_genai.
|
|
91
|
+
from langchain_google_genai._image_utils import ImageBytesLoader
|
|
92
|
+
from langchain_google_genai.llms import _BaseGoogleGenerativeAI
|
|
93
|
+
|
|
94
|
+
from . import _genai_extension as genaix
|
|
79
95
|
|
|
80
96
|
IMAGE_TYPES: Tuple = ()
|
|
81
97
|
try:
|
|
@@ -279,18 +295,19 @@ def _url_to_pil(image_source: str) -> Image:
|
|
|
279
295
|
|
|
280
296
|
def _convert_to_parts(
|
|
281
297
|
raw_content: Union[str, Sequence[Union[str, dict]]],
|
|
282
|
-
) -> List[
|
|
298
|
+
) -> List[Part]:
|
|
283
299
|
"""Converts a list of LangChain messages into a google parts."""
|
|
284
300
|
parts = []
|
|
285
301
|
content = [raw_content] if isinstance(raw_content, str) else raw_content
|
|
302
|
+
image_loader = ImageBytesLoader()
|
|
286
303
|
for part in content:
|
|
287
304
|
if isinstance(part, str):
|
|
288
|
-
parts.append(
|
|
305
|
+
parts.append(Part(text=part))
|
|
289
306
|
elif isinstance(part, Mapping):
|
|
290
307
|
# OpenAI Format
|
|
291
308
|
if _is_openai_parts_format(part):
|
|
292
309
|
if part["type"] == "text":
|
|
293
|
-
parts.append(
|
|
310
|
+
parts.append(Part(text=part["text"]))
|
|
294
311
|
elif part["type"] == "image_url":
|
|
295
312
|
img_url = part["image_url"]
|
|
296
313
|
if isinstance(img_url, dict):
|
|
@@ -299,7 +316,7 @@ def _convert_to_parts(
|
|
|
299
316
|
f"Unrecognized message image format: {img_url}"
|
|
300
317
|
)
|
|
301
318
|
img_url = img_url["url"]
|
|
302
|
-
parts.append(
|
|
319
|
+
parts.append(image_loader.load_part(img_url))
|
|
303
320
|
else:
|
|
304
321
|
raise ValueError(f"Unrecognized message part type: {part['type']}")
|
|
305
322
|
else:
|
|
@@ -307,7 +324,7 @@ def _convert_to_parts(
|
|
|
307
324
|
logger.warning(
|
|
308
325
|
"Unrecognized message part format. Assuming it's a text part."
|
|
309
326
|
)
|
|
310
|
-
parts.append(part)
|
|
327
|
+
parts.append(Part(text=str(part)))
|
|
311
328
|
else:
|
|
312
329
|
# TODO: Maybe some of Google's native stuff
|
|
313
330
|
# would hit this branch.
|
|
@@ -319,33 +336,36 @@ def _convert_to_parts(
|
|
|
319
336
|
|
|
320
337
|
def _parse_chat_history(
|
|
321
338
|
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
|
|
322
|
-
) -> Tuple[Optional[
|
|
323
|
-
messages: List[
|
|
339
|
+
) -> Tuple[Optional[Content], List[Content]]:
|
|
340
|
+
messages: List[Content] = []
|
|
324
341
|
|
|
325
342
|
if convert_system_message_to_human:
|
|
326
343
|
warnings.warn("Convert_system_message_to_human will be deprecated!")
|
|
327
344
|
|
|
328
|
-
system_instruction: Optional[
|
|
345
|
+
system_instruction: Optional[Content] = None
|
|
329
346
|
for i, message in enumerate(input_messages):
|
|
330
347
|
if i == 0 and isinstance(message, SystemMessage):
|
|
331
|
-
system_instruction = _convert_to_parts(message.content)
|
|
348
|
+
system_instruction = Content(parts=_convert_to_parts(message.content))
|
|
332
349
|
continue
|
|
333
350
|
elif isinstance(message, AIMessage):
|
|
334
351
|
role = "model"
|
|
335
352
|
raw_function_call = message.additional_kwargs.get("function_call")
|
|
336
353
|
if raw_function_call:
|
|
337
|
-
function_call =
|
|
354
|
+
function_call = FunctionCall(
|
|
338
355
|
{
|
|
339
356
|
"name": raw_function_call["name"],
|
|
340
357
|
"args": json.loads(raw_function_call["arguments"]),
|
|
341
358
|
}
|
|
342
359
|
)
|
|
343
|
-
parts = [
|
|
360
|
+
parts = [Part(function_call=function_call)]
|
|
344
361
|
else:
|
|
345
362
|
parts = _convert_to_parts(message.content)
|
|
346
363
|
elif isinstance(message, HumanMessage):
|
|
347
364
|
role = "user"
|
|
348
365
|
parts = _convert_to_parts(message.content)
|
|
366
|
+
if i == 1 and convert_system_message_to_human and system_instruction:
|
|
367
|
+
parts = [p for p in system_instruction.parts] + parts
|
|
368
|
+
system_instruction = None
|
|
349
369
|
elif isinstance(message, FunctionMessage):
|
|
350
370
|
role = "user"
|
|
351
371
|
response: Any
|
|
@@ -357,8 +377,8 @@ def _parse_chat_history(
|
|
|
357
377
|
except json.JSONDecodeError:
|
|
358
378
|
response = message.content # leave as str representation
|
|
359
379
|
parts = [
|
|
360
|
-
|
|
361
|
-
function_response=
|
|
380
|
+
Part(
|
|
381
|
+
function_response=FunctionResponse(
|
|
362
382
|
name=message.name,
|
|
363
383
|
response=(
|
|
364
384
|
{"output": response}
|
|
@@ -391,8 +411,8 @@ def _parse_chat_history(
|
|
|
391
411
|
except json.JSONDecodeError:
|
|
392
412
|
tool_response = message.content # leave as str representation
|
|
393
413
|
parts = [
|
|
394
|
-
|
|
395
|
-
function_response=
|
|
414
|
+
Part(
|
|
415
|
+
function_response=FunctionResponse(
|
|
396
416
|
name=name,
|
|
397
417
|
response=(
|
|
398
418
|
{"output": tool_response}
|
|
@@ -407,12 +427,12 @@ def _parse_chat_history(
|
|
|
407
427
|
f"Unexpected message with type {type(message)} at the position {i}."
|
|
408
428
|
)
|
|
409
429
|
|
|
410
|
-
messages.append(
|
|
430
|
+
messages.append(Content(role=role, parts=parts))
|
|
411
431
|
return system_instruction, messages
|
|
412
432
|
|
|
413
433
|
|
|
414
434
|
def _parse_response_candidate(
|
|
415
|
-
response_candidate:
|
|
435
|
+
response_candidate: Candidate, streaming: bool = False
|
|
416
436
|
) -> AIMessage:
|
|
417
437
|
content: Union[None, str, List[str]] = None
|
|
418
438
|
additional_kwargs = {}
|
|
@@ -499,7 +519,7 @@ def _parse_response_candidate(
|
|
|
499
519
|
|
|
500
520
|
|
|
501
521
|
def _response_to_result(
|
|
502
|
-
response:
|
|
522
|
+
response: GenerateContentResponse,
|
|
503
523
|
stream: bool = False,
|
|
504
524
|
) -> ChatResult:
|
|
505
525
|
"""Converts a PaLM API response into a LangChain ChatResult."""
|
|
@@ -557,6 +577,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
557
577
|
"""
|
|
558
578
|
|
|
559
579
|
client: Any #: :meta private:
|
|
580
|
+
async_client: Any #: :meta private:
|
|
581
|
+
default_metadata: Sequence[Tuple[str, str]] = Field(
|
|
582
|
+
default_factory=list
|
|
583
|
+
) #: :meta private:
|
|
560
584
|
|
|
561
585
|
convert_system_message_to_human: bool = False
|
|
562
586
|
"""Whether to merge any leading SystemMessage into the following HumanMessage.
|
|
@@ -582,29 +606,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
582
606
|
@root_validator()
|
|
583
607
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
584
608
|
"""Validates params and passes them to google-generativeai package."""
|
|
585
|
-
additional_headers = values.get("additional_headers") or {}
|
|
586
|
-
default_metadata = tuple(additional_headers.items())
|
|
587
|
-
|
|
588
|
-
if values.get("credentials"):
|
|
589
|
-
genai.configure(
|
|
590
|
-
credentials=values.get("credentials"),
|
|
591
|
-
transport=values.get("transport"),
|
|
592
|
-
client_options=values.get("client_options"),
|
|
593
|
-
default_metadata=default_metadata,
|
|
594
|
-
)
|
|
595
|
-
else:
|
|
596
|
-
google_api_key = get_from_dict_or_env(
|
|
597
|
-
values, "google_api_key", "GOOGLE_API_KEY"
|
|
598
|
-
)
|
|
599
|
-
if isinstance(google_api_key, SecretStr):
|
|
600
|
-
google_api_key = google_api_key.get_secret_value()
|
|
601
|
-
|
|
602
|
-
genai.configure(
|
|
603
|
-
api_key=google_api_key,
|
|
604
|
-
transport=values.get("transport"),
|
|
605
|
-
client_options=values.get("client_options"),
|
|
606
|
-
default_metadata=default_metadata,
|
|
607
|
-
)
|
|
608
609
|
if (
|
|
609
610
|
values.get("temperature") is not None
|
|
610
611
|
and not 0 <= values["temperature"] <= 1
|
|
@@ -616,8 +617,36 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
616
617
|
|
|
617
618
|
if values.get("top_k") is not None and values["top_k"] <= 0:
|
|
618
619
|
raise ValueError("top_k must be positive")
|
|
619
|
-
|
|
620
|
-
values["
|
|
620
|
+
|
|
621
|
+
if not values["model"].startswith("models/"):
|
|
622
|
+
values["model"] = f"models/{values['model']}"
|
|
623
|
+
|
|
624
|
+
additional_headers = values.get("additional_headers") or {}
|
|
625
|
+
values["default_metadata"] = tuple(additional_headers.items())
|
|
626
|
+
client_info = get_client_info("ChatGoogleGenerativeAI")
|
|
627
|
+
google_api_key = None
|
|
628
|
+
if not values.get("credentials"):
|
|
629
|
+
google_api_key = get_from_dict_or_env(
|
|
630
|
+
values, "google_api_key", "GOOGLE_API_KEY"
|
|
631
|
+
)
|
|
632
|
+
if isinstance(google_api_key, SecretStr):
|
|
633
|
+
google_api_key = google_api_key.get_secret_value()
|
|
634
|
+
transport: Optional[str] = values.get("transport")
|
|
635
|
+
values["client"] = genaix.build_generative_service(
|
|
636
|
+
credentials=values.get("credentials"),
|
|
637
|
+
api_key=google_api_key,
|
|
638
|
+
client_info=client_info,
|
|
639
|
+
client_options=values.get("client_options"),
|
|
640
|
+
transport=transport,
|
|
641
|
+
)
|
|
642
|
+
values["async_client"] = genaix.build_generative_async_service(
|
|
643
|
+
credentials=values.get("credentials"),
|
|
644
|
+
api_key=google_api_key,
|
|
645
|
+
client_info=client_info,
|
|
646
|
+
client_options=values.get("client_options"),
|
|
647
|
+
transport=transport,
|
|
648
|
+
)
|
|
649
|
+
|
|
621
650
|
return values
|
|
622
651
|
|
|
623
652
|
@property
|
|
@@ -632,8 +661,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
632
661
|
}
|
|
633
662
|
|
|
634
663
|
def _prepare_params(
|
|
635
|
-
self,
|
|
636
|
-
|
|
664
|
+
self,
|
|
665
|
+
stop: Optional[List[str]],
|
|
666
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
667
|
+
) -> GenerationConfig:
|
|
637
668
|
gen_config = {
|
|
638
669
|
k: v
|
|
639
670
|
for k, v in {
|
|
@@ -646,27 +677,37 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
646
677
|
}.items()
|
|
647
678
|
if v is not None
|
|
648
679
|
}
|
|
649
|
-
if
|
|
650
|
-
gen_config = {**gen_config, **
|
|
651
|
-
|
|
652
|
-
return params
|
|
680
|
+
if generation_config:
|
|
681
|
+
gen_config = {**gen_config, **generation_config}
|
|
682
|
+
return GenerationConfig(**gen_config)
|
|
653
683
|
|
|
654
684
|
def _generate(
|
|
655
685
|
self,
|
|
656
686
|
messages: List[BaseMessage],
|
|
657
687
|
stop: Optional[List[str]] = None,
|
|
658
688
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
689
|
+
*,
|
|
690
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
691
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
692
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
693
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
694
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
659
695
|
**kwargs: Any,
|
|
660
696
|
) -> ChatResult:
|
|
661
|
-
|
|
697
|
+
request = self._prepare_request(
|
|
662
698
|
messages,
|
|
663
699
|
stop=stop,
|
|
664
|
-
|
|
700
|
+
tools=tools,
|
|
701
|
+
functions=functions,
|
|
702
|
+
safety_settings=safety_settings,
|
|
703
|
+
tool_config=tool_config,
|
|
704
|
+
generation_config=generation_config,
|
|
665
705
|
)
|
|
666
|
-
response:
|
|
667
|
-
|
|
668
|
-
**
|
|
669
|
-
generation_method=
|
|
706
|
+
response: GenerateContentResponse = _chat_with_retry(
|
|
707
|
+
request=request,
|
|
708
|
+
**kwargs,
|
|
709
|
+
generation_method=self.client.generate_content,
|
|
710
|
+
metadata=self.default_metadata,
|
|
670
711
|
)
|
|
671
712
|
return _response_to_result(response)
|
|
672
713
|
|
|
@@ -675,17 +716,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
675
716
|
messages: List[BaseMessage],
|
|
676
717
|
stop: Optional[List[str]] = None,
|
|
677
718
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
719
|
+
*,
|
|
720
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
721
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
722
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
723
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
724
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
678
725
|
**kwargs: Any,
|
|
679
726
|
) -> ChatResult:
|
|
680
|
-
|
|
727
|
+
request = self._prepare_request(
|
|
681
728
|
messages,
|
|
682
729
|
stop=stop,
|
|
683
|
-
|
|
730
|
+
tools=tools,
|
|
731
|
+
functions=functions,
|
|
732
|
+
safety_settings=safety_settings,
|
|
733
|
+
tool_config=tool_config,
|
|
734
|
+
generation_config=generation_config,
|
|
684
735
|
)
|
|
685
|
-
response:
|
|
686
|
-
|
|
687
|
-
**
|
|
688
|
-
generation_method=
|
|
736
|
+
response: GenerateContentResponse = await _achat_with_retry(
|
|
737
|
+
request=request,
|
|
738
|
+
**kwargs,
|
|
739
|
+
generation_method=self.async_client.generate_content,
|
|
740
|
+
metadata=self.default_metadata,
|
|
689
741
|
)
|
|
690
742
|
return _response_to_result(response)
|
|
691
743
|
|
|
@@ -694,18 +746,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
694
746
|
messages: List[BaseMessage],
|
|
695
747
|
stop: Optional[List[str]] = None,
|
|
696
748
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
749
|
+
*,
|
|
750
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
751
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
752
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
753
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
754
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
697
755
|
**kwargs: Any,
|
|
698
756
|
) -> Iterator[ChatGenerationChunk]:
|
|
699
|
-
|
|
757
|
+
request = self._prepare_request(
|
|
700
758
|
messages,
|
|
701
759
|
stop=stop,
|
|
702
|
-
|
|
760
|
+
tools=tools,
|
|
761
|
+
functions=functions,
|
|
762
|
+
safety_settings=safety_settings,
|
|
763
|
+
tool_config=tool_config,
|
|
764
|
+
generation_config=generation_config,
|
|
703
765
|
)
|
|
704
|
-
response:
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
766
|
+
response: GenerateContentResponse = _chat_with_retry(
|
|
767
|
+
request=request,
|
|
768
|
+
generation_method=self.client.stream_generate_content,
|
|
769
|
+
**kwargs,
|
|
770
|
+
metadata=self.default_metadata,
|
|
709
771
|
)
|
|
710
772
|
for chunk in response:
|
|
711
773
|
_chat_result = _response_to_result(chunk, stream=True)
|
|
@@ -720,18 +782,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
720
782
|
messages: List[BaseMessage],
|
|
721
783
|
stop: Optional[List[str]] = None,
|
|
722
784
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
785
|
+
*,
|
|
786
|
+
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
787
|
+
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
788
|
+
safety_settings: Optional[SafetySettingDict] = None,
|
|
789
|
+
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
790
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
723
791
|
**kwargs: Any,
|
|
724
792
|
) -> AsyncIterator[ChatGenerationChunk]:
|
|
725
|
-
|
|
793
|
+
request = self._prepare_request(
|
|
726
794
|
messages,
|
|
727
795
|
stop=stop,
|
|
728
|
-
|
|
796
|
+
tools=tools,
|
|
797
|
+
functions=functions,
|
|
798
|
+
safety_settings=safety_settings,
|
|
799
|
+
tool_config=tool_config,
|
|
800
|
+
generation_config=generation_config,
|
|
729
801
|
)
|
|
730
802
|
async for chunk in await _achat_with_retry(
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
803
|
+
request=request,
|
|
804
|
+
generation_method=self.async_client.stream_generate_content,
|
|
805
|
+
**kwargs,
|
|
806
|
+
metadata=self.default_metadata,
|
|
735
807
|
):
|
|
736
808
|
_chat_result = _response_to_result(chunk, stream=True)
|
|
737
809
|
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
|
@@ -740,17 +812,17 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
740
812
|
await run_manager.on_llm_new_token(gen.text)
|
|
741
813
|
yield gen
|
|
742
814
|
|
|
743
|
-
def
|
|
815
|
+
def _prepare_request(
|
|
744
816
|
self,
|
|
745
817
|
messages: List[BaseMessage],
|
|
818
|
+
*,
|
|
746
819
|
stop: Optional[List[str]] = None,
|
|
747
820
|
tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
|
|
748
821
|
functions: Optional[Sequence[FunctionDeclarationType]] = None,
|
|
749
822
|
safety_settings: Optional[SafetySettingDict] = None,
|
|
750
823
|
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
|
751
|
-
|
|
752
|
-
) -> Tuple[Dict[str, Any]
|
|
753
|
-
client = self.client
|
|
824
|
+
generation_config: Optional[Dict[str, Any]] = None,
|
|
825
|
+
) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
|
|
754
826
|
formatted_tools = None
|
|
755
827
|
if tools:
|
|
756
828
|
formatted_tools = [
|
|
@@ -759,25 +831,35 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
759
831
|
elif functions:
|
|
760
832
|
formatted_tools = [convert_to_genai_function_declarations(functions)]
|
|
761
833
|
|
|
762
|
-
if formatted_tools or safety_settings:
|
|
763
|
-
client = genai.GenerativeModel(
|
|
764
|
-
model_name=self.model,
|
|
765
|
-
tools=formatted_tools,
|
|
766
|
-
safety_settings=safety_settings,
|
|
767
|
-
)
|
|
768
|
-
|
|
769
|
-
params = self._prepare_params(stop, tool_config=tool_config, **kwargs)
|
|
770
834
|
system_instruction, history = _parse_chat_history(
|
|
771
835
|
messages,
|
|
772
836
|
convert_system_message_to_human=self.convert_system_message_to_human,
|
|
773
837
|
)
|
|
774
|
-
|
|
775
|
-
if
|
|
776
|
-
|
|
777
|
-
|
|
838
|
+
formatted_tool_config = None
|
|
839
|
+
if tool_config:
|
|
840
|
+
formatted_tool_config = ToolConfig(
|
|
841
|
+
function_calling_config=tool_config["function_calling_config"]
|
|
778
842
|
)
|
|
779
|
-
|
|
780
|
-
|
|
843
|
+
formatted_safety_settings = []
|
|
844
|
+
if safety_settings:
|
|
845
|
+
formatted_safety_settings = [
|
|
846
|
+
SafetySetting(category=c, threshold=t)
|
|
847
|
+
for c, t in safety_settings.items()
|
|
848
|
+
]
|
|
849
|
+
request = GenerateContentRequest(
|
|
850
|
+
model=self.model,
|
|
851
|
+
contents=history,
|
|
852
|
+
tools=formatted_tools,
|
|
853
|
+
tool_config=formatted_tool_config,
|
|
854
|
+
safety_settings=formatted_safety_settings,
|
|
855
|
+
generation_config=self._prepare_params(
|
|
856
|
+
stop, generation_config=generation_config
|
|
857
|
+
),
|
|
858
|
+
)
|
|
859
|
+
if system_instruction:
|
|
860
|
+
request.system_instruction = system_instruction
|
|
861
|
+
|
|
862
|
+
return request
|
|
781
863
|
|
|
782
864
|
def get_num_tokens(self, text: str) -> int:
|
|
783
865
|
"""Get the number of tokens present in the text.
|
|
@@ -790,14 +872,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
790
872
|
Returns:
|
|
791
873
|
The integer number of tokens in the text.
|
|
792
874
|
"""
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
|
798
|
-
token_count = result["token_count"]
|
|
799
|
-
|
|
800
|
-
return token_count
|
|
875
|
+
result = self.client.count_tokens(
|
|
876
|
+
model=self.model, contents=[Content(parts=[Part(text=text)])]
|
|
877
|
+
)
|
|
878
|
+
return result.total_tokens
|
|
801
879
|
|
|
802
880
|
def bind_tools(
|
|
803
881
|
self,
|
|
@@ -828,7 +906,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
828
906
|
genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
|
|
829
907
|
if tool_choice:
|
|
830
908
|
all_names = [
|
|
831
|
-
f["name"]
|
|
909
|
+
f["name"] # type: ignore[index]
|
|
910
|
+
for t in genai_tools
|
|
911
|
+
for f in t["function_declarations"]
|
|
832
912
|
]
|
|
833
913
|
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
|
|
834
914
|
return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
|
|
@@ -1,12 +1,19 @@
|
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
|
2
2
|
|
|
3
3
|
# TODO: remove ignore once the google package is published with types
|
|
4
|
-
|
|
4
|
+
from google.ai.generativelanguage_v1beta.types import (
|
|
5
|
+
BatchEmbedContentsRequest,
|
|
6
|
+
EmbedContentRequest,
|
|
7
|
+
)
|
|
5
8
|
from langchain_core.embeddings import Embeddings
|
|
6
9
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
|
7
10
|
from langchain_core.utils import get_from_dict_or_env
|
|
8
11
|
|
|
9
|
-
from langchain_google_genai._common import
|
|
12
|
+
from langchain_google_genai._common import (
|
|
13
|
+
GoogleGenerativeAIError,
|
|
14
|
+
get_client_info,
|
|
15
|
+
)
|
|
16
|
+
from langchain_google_genai._genai_extension import build_generative_service
|
|
10
17
|
|
|
11
18
|
|
|
12
19
|
class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
|
|
@@ -27,6 +34,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
|
|
|
27
34
|
embeddings.embed_query("What's our Q1 revenue?")
|
|
28
35
|
"""
|
|
29
36
|
|
|
37
|
+
client: Any #: :meta private:
|
|
30
38
|
model: str = Field(
|
|
31
39
|
...,
|
|
32
40
|
description="The name of the embedding model to use. "
|
|
@@ -70,44 +78,43 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
|
|
|
70
78
|
@root_validator()
|
|
71
79
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
72
80
|
"""Validates params and passes them to google-generativeai package."""
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
google_api_key
|
|
81
|
-
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
google_api_key = google_api_key.get_secret_value()
|
|
85
|
-
|
|
86
|
-
genai.configure(
|
|
87
|
-
api_key=google_api_key,
|
|
88
|
-
transport=values.get("transport"),
|
|
89
|
-
client_options=values.get("client_options"),
|
|
90
|
-
)
|
|
81
|
+
google_api_key = get_from_dict_or_env(
|
|
82
|
+
values, "google_api_key", "GOOGLE_API_KEY"
|
|
83
|
+
)
|
|
84
|
+
client_info = get_client_info("GoogleGenerativeAIEmbeddings")
|
|
85
|
+
|
|
86
|
+
values["client"] = build_generative_service(
|
|
87
|
+
credentials=values.get("credentials"),
|
|
88
|
+
api_key=google_api_key,
|
|
89
|
+
client_info=client_info,
|
|
90
|
+
client_options=values.get("client_options"),
|
|
91
|
+
)
|
|
91
92
|
return values
|
|
92
93
|
|
|
93
|
-
def
|
|
94
|
-
self,
|
|
95
|
-
|
|
96
|
-
task_type =
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
94
|
+
def _prepare_request(
|
|
95
|
+
self,
|
|
96
|
+
text: str,
|
|
97
|
+
task_type: Optional[str] = None,
|
|
98
|
+
title: Optional[str] = None,
|
|
99
|
+
output_dimensionality: Optional[int] = None,
|
|
100
|
+
) -> EmbedContentRequest:
|
|
101
|
+
task_type = self.task_type or task_type or "RETRIEVAL_DOCUMENT"
|
|
102
|
+
# https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
|
|
103
|
+
request = EmbedContentRequest(
|
|
104
|
+
content={"parts": [{"text": text}]},
|
|
105
|
+
model=self.model,
|
|
106
|
+
task_type=task_type.upper(),
|
|
107
|
+
title=title,
|
|
108
|
+
output_dimensionality=output_dimensionality,
|
|
109
|
+
)
|
|
110
|
+
return request
|
|
108
111
|
|
|
109
112
|
def embed_documents(
|
|
110
|
-
self,
|
|
113
|
+
self,
|
|
114
|
+
texts: List[str],
|
|
115
|
+
task_type: Optional[str] = None,
|
|
116
|
+
titles: Optional[List[str]] = None,
|
|
117
|
+
output_dimensionality: Optional[int] = None,
|
|
111
118
|
) -> List[List[float]]:
|
|
112
119
|
"""Embed a list of strings. Vertex AI currently
|
|
113
120
|
sets a max batch size of 5 strings.
|
|
@@ -115,21 +122,58 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
|
|
|
115
122
|
Args:
|
|
116
123
|
texts: List[str] The list of strings to embed.
|
|
117
124
|
batch_size: [int] The batch size of embeddings to send to the model
|
|
125
|
+
task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
|
|
126
|
+
titles: An optional list of titles for texts provided.
|
|
127
|
+
Only applicable when TaskType is RETRIEVAL_DOCUMENT.
|
|
128
|
+
output_dimensionality: Optional reduced dimension for the output embedding.
|
|
129
|
+
https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
|
|
118
130
|
|
|
119
131
|
Returns:
|
|
120
132
|
List of embeddings, one for each text.
|
|
121
133
|
"""
|
|
122
|
-
|
|
123
|
-
|
|
134
|
+
titles = titles if titles else [None] * len(texts) # type: ignore[list-item]
|
|
135
|
+
requests = [
|
|
136
|
+
self._prepare_request(
|
|
137
|
+
text=text,
|
|
138
|
+
task_type=task_type,
|
|
139
|
+
title=title,
|
|
140
|
+
output_dimensionality=output_dimensionality,
|
|
141
|
+
)
|
|
142
|
+
for text, title in zip(texts, titles)
|
|
143
|
+
]
|
|
124
144
|
|
|
125
|
-
|
|
145
|
+
try:
|
|
146
|
+
result = self.client.batch_embed_contents(
|
|
147
|
+
BatchEmbedContentsRequest(requests=requests, model=self.model)
|
|
148
|
+
)
|
|
149
|
+
except Exception as e:
|
|
150
|
+
raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
|
|
151
|
+
return [e.values for e in result.embeddings]
|
|
152
|
+
|
|
153
|
+
def embed_query(
|
|
154
|
+
self,
|
|
155
|
+
text: str,
|
|
156
|
+
task_type: Optional[str] = None,
|
|
157
|
+
title: Optional[str] = None,
|
|
158
|
+
output_dimensionality: Optional[int] = None,
|
|
159
|
+
) -> List[float]:
|
|
126
160
|
"""Embed a text.
|
|
127
161
|
|
|
128
162
|
Args:
|
|
129
163
|
text: The text to embed.
|
|
164
|
+
task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
|
|
165
|
+
title: An optional title for the text.
|
|
166
|
+
Only applicable when TaskType is RETRIEVAL_DOCUMENT.
|
|
167
|
+
output_dimensionality: Optional reduced dimension for the output embedding.
|
|
168
|
+
https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
|
|
130
169
|
|
|
131
170
|
Returns:
|
|
132
171
|
Embedding for the text.
|
|
133
172
|
"""
|
|
134
|
-
task_type = self.task_type or "
|
|
135
|
-
return self.
|
|
173
|
+
task_type = self.task_type or "RETRIEVAL_QUERY"
|
|
174
|
+
return self.embed_documents(
|
|
175
|
+
[text],
|
|
176
|
+
task_type=task_type,
|
|
177
|
+
titles=[title] if title else None,
|
|
178
|
+
output_dimensionality=output_dimensionality,
|
|
179
|
+
)[0]
|
langchain_google_genai/llms.py
CHANGED
|
@@ -174,7 +174,6 @@ Supported examples:
|
|
|
174
174
|
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
|
|
175
175
|
|
|
176
176
|
safety_settings = {
|
|
177
|
-
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
|
|
178
177
|
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
|
179
178
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
|
180
179
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: langchain-google-genai
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.4
|
|
4
4
|
Summary: An integration package connecting Google's genai package and LangChain
|
|
5
5
|
Home-page: https://github.com/langchain-ai/langchain-google
|
|
6
6
|
License: MIT
|
|
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Provides-Extra: images
|
|
15
15
|
Requires-Dist: google-generativeai (>=0.5.2,<0.6.0)
|
|
16
|
-
Requires-Dist: langchain-core (>=0.1.45,<0.
|
|
16
|
+
Requires-Dist: langchain-core (>=0.1.45,<0.3)
|
|
17
17
|
Requires-Dist: pillow (>=10.1.0,<11.0.0) ; extra == "images"
|
|
18
18
|
Project-URL: Repository, https://github.com/langchain-ai/langchain-google
|
|
19
19
|
Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
langchain_google_genai/__init__.py,sha256=Oji-S2KYWrku1wyQEskY84IOfY8MfRhujjJ4d7hbsk4,2758
|
|
2
|
+
langchain_google_genai/_common.py,sha256=ASlwE8hEbvOm55BVF_D4rf2nl7RYsnpsi5xbM6DW3Cc,1576
|
|
3
|
+
langchain_google_genai/_enums.py,sha256=KLPmxS1K83K4HjBIXFaXoL_sFEOv8Hq-2B2PDMKyDgo,197
|
|
4
|
+
langchain_google_genai/_function_utils.py,sha256=d0ApSCjoV9Em1CteBaGznilxrw-PDXqQ4sQa5p7cJfM,8232
|
|
5
|
+
langchain_google_genai/_genai_extension.py,sha256=ZwNwLV22RSf9LB7FOCLsoHzLlQDF-EQmRNYM1an2uSw,20769
|
|
6
|
+
langchain_google_genai/_image_utils.py,sha256=-0XgCMdYkvrIktFvUpy-2GPbFgfSVKZICawB2hiJzus,4999
|
|
7
|
+
langchain_google_genai/chat_models.py,sha256=3ubY5qCaZjFSHKHiiP5XCOPDusUxO8kJvpj9DwLtUG4,33014
|
|
8
|
+
langchain_google_genai/embeddings.py,sha256=kQW6pl1TUGKKSxiUjc7rptp0iEu_Rer1m1LLHKVaW14,6578
|
|
9
|
+
langchain_google_genai/genai_aqa.py,sha256=zcC5cdFYtqLK7DGPhYGvWNeHHeU-CQKA9KhewmsA5lw,4303
|
|
10
|
+
langchain_google_genai/google_vector_store.py,sha256=PPIk-4FmD5UUdmYA2u7VcEhGsiztvRVN59QoGLXdfoA,16139
|
|
11
|
+
langchain_google_genai/llms.py,sha256=S7tOy-c37DElcHtkGl8rwvvg1zOzCxb9PEyJ4E-j7qU,13431
|
|
12
|
+
langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
langchain_google_genai-1.0.4.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
|
|
14
|
+
langchain_google_genai-1.0.4.dist-info/METADATA,sha256=nHJMXwe6iuI99J_VO1GQrVgv1BLvVGW5x0mSNswDWq0,3818
|
|
15
|
+
langchain_google_genai-1.0.4.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
|
16
|
+
langchain_google_genai-1.0.4.dist-info/RECORD,,
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
langchain_google_genai/__init__.py,sha256=NiLfU4IyEQSlSl8hbN_fTdX-hrO5tuOflxFlEK0Sy4c,2762
|
|
2
|
-
langchain_google_genai/_common.py,sha256=1r0VrrBSTZfGprmICZ5OV-W5SK31jKRFFCNE3vJ3jmk,136
|
|
3
|
-
langchain_google_genai/_enums.py,sha256=q8IYAqufV-_yZ98FDnsZ3x-1w4804J_e8PrTKT0sdhY,163
|
|
4
|
-
langchain_google_genai/_function_utils.py,sha256=uDWg2Gcuv3PtdUL14sDOVTFbq8gaaGJ360bL4IbN4AI,8705
|
|
5
|
-
langchain_google_genai/_genai_extension.py,sha256=2Uqg7vSF0vu1J4AhAyIPzadtpM5JJwZsBXvpItO2TY4,18736
|
|
6
|
-
langchain_google_genai/chat_models.py,sha256=edkRRStq42pxY1xBzfC2yUfN-jBsM-Vr1dx9QVg-qd4,29969
|
|
7
|
-
langchain_google_genai/embeddings.py,sha256=QZJRd5xQkGzalAGiKeorrsnVmsyaO4NGmGuzQFDoRe0,4807
|
|
8
|
-
langchain_google_genai/genai_aqa.py,sha256=zcC5cdFYtqLK7DGPhYGvWNeHHeU-CQKA9KhewmsA5lw,4303
|
|
9
|
-
langchain_google_genai/google_vector_store.py,sha256=PPIk-4FmD5UUdmYA2u7VcEhGsiztvRVN59QoGLXdfoA,16139
|
|
10
|
-
langchain_google_genai/llms.py,sha256=QPjCs0AHb-2d3GMCS3sZqzG_Yr71YbaGX_Vs8QlLTKU,13518
|
|
11
|
-
langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
-
langchain_google_genai-1.0.3.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
|
|
13
|
-
langchain_google_genai-1.0.3.dist-info/METADATA,sha256=IPH2elmUC6Rt7ZG-g5eyma97A8r48jftkeq9v8XyQMc,3818
|
|
14
|
-
langchain_google_genai-1.0.3.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
|
15
|
-
langchain_google_genai-1.0.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|