langchain-google-genai 0.0.1rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai might be problematic.
- langchain_google_genai-0.0.1rc0/LICENSE +21 -0
- langchain_google_genai-0.0.1rc0/PKG-INFO +72 -0
- langchain_google_genai-0.0.1rc0/README.md +58 -0
- langchain_google_genai-0.0.1rc0/langchain_google_genai/__init__.py +3 -0
- langchain_google_genai-0.0.1rc0/langchain_google_genai/chat_models.py +556 -0
- langchain_google_genai-0.0.1rc0/pyproject.toml +94 -0
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 LangChain, Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
@@ -0,0 +1,72 @@
+Metadata-Version: 2.1
+Name: langchain-google-genai
+Version: 0.0.1rc0
+Summary: An integration package connecting Google's genai package and LangChain
+Requires-Python: >=3.9,<4.0
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Requires-Dist: google-generativeai (>=0.3.1,<0.4.0)
+Requires-Dist: langchain-core (>=0.0.12)
+Description-Content-Type: text/markdown
+
+# langchain-google-genai
+
+This partner package contains the LangChain integrations for Google's generative AI models.
+
+## Installation
+
+```bash
+pip install -U langchain-google-genai
+```
+
+## Chat Models
+
+This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
+
+To use, install the requirements and configure your environment.
+
+```bash
+export GOOGLE_API_KEY=your-api-key
+```
+
+Then initialize:
+
+```python
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro")
+llm.invoke("Sing a ballad of LangChain.")
+```
+
+#### Multimodal inputs
+
+The Gemini vision model supports image inputs when provided in a single chat message. Example:
+
+```python
+from langchain_core.messages import HumanMessage
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
+# example
+message = HumanMessage(
+    content=[
+        {
+            "type": "text",
+            "text": "What's in this image?",
+        },  # You can optionally provide text parts
+        {"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
+    ]
+)
+llm.invoke([message])
+```
+
+The value of `image_url` can be any of the following:
+
+- A public image URL
+- An accessible GCS file (e.g., "gs://path/to/file.png")
+- A local file path
+- A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
+- A PIL image
+
@@ -0,0 +1,58 @@
+# langchain-google-genai
+
+This partner package contains the LangChain integrations for Google's generative AI models.
+
+## Installation
+
+```bash
+pip install -U langchain-google-genai
+```
+
+## Chat Models
+
+This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
+
+To use, install the requirements and configure your environment.
+
+```bash
+export GOOGLE_API_KEY=your-api-key
+```
+
+Then initialize:
+
+```python
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro")
+llm.invoke("Sing a ballad of LangChain.")
+```
+
+#### Multimodal inputs
+
+The Gemini vision model supports image inputs when provided in a single chat message. Example:
+
+```python
+from langchain_core.messages import HumanMessage
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
+# example
+message = HumanMessage(
+    content=[
+        {
+            "type": "text",
+            "text": "What's in this image?",
+        },  # You can optionally provide text parts
+        {"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
+    ]
+)
+llm.invoke([message])
+```
+
+The value of `image_url` can be any of the following:
+
+- A public image URL
+- An accessible GCS file (e.g., "gs://path/to/file.png")
+- A local file path
+- A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
+- A PIL image
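For illustration, here is a minimal sketch of two of those `image_url` forms beyond a public URL. The `./photo.png` path is a hypothetical placeholder; both forms rely on the package's image loading described above, and the vision model requires an API key as configured earlier.

```python
import base64

from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")

# 1) A local file path (placeholder path; any readable image file works).
local_message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": "./photo.png"},
    ]
)

# 2) The same file passed as a base64 data URL.
with open("./photo.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")
data_url_message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": f"data:image/png;base64,{encoded}"},
    ]
)

llm.invoke([local_message])
llm.invoke([data_url_message])
```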
@@ -0,0 +1,556 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import logging
+import os
+from io import BytesIO
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
+from urllib.parse import urlparse
+
+import requests
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatMessage,
+    ChatMessageChunk,
+    HumanMessage,
+    HumanMessageChunk,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    # TODO: remove ignore once the google package is published with types
+    import google.generativeai as genai  # type: ignore[import]
+IMAGE_TYPES: Tuple = ()
+try:
+    import PIL
+    from PIL.Image import Image
+
+    IMAGE_TYPES = IMAGE_TYPES + (Image,)
+except ImportError:
+    PIL = None  # type: ignore
+    Image = None  # type: ignore
+
+
+class ChatGoogleGenerativeAIError(Exception):
+    """
+    Custom exception class for errors associated with the `Google GenAI` API.
+
+    This exception is raised when there are specific issues related to the
+    Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
+    message types or roles.
+    """
+
+
+def _create_retry_decorator() -> Callable[[Any], Any]:
+    """
+    Creates and returns a preconfigured tenacity retry decorator.
+
+    The retry decorator is configured to handle specific Google API exceptions
+    such as ResourceExhausted and ServiceUnavailable. It uses an exponential
+    backoff strategy for retries.
+
+    Returns:
+        Callable[[Any], Any]: A retry decorator configured for handling specific
+            Google API exceptions.
+    """
+    import google.api_core.exceptions
+
+    multiplier = 2
+    min_seconds = 1
+    max_seconds = 60
+    max_retries = 10
+
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(max_retries),
+        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
+            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
+            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+    """
+    Executes a chat generation method with retry logic using tenacity.
+
+    This function is a wrapper that applies a retry mechanism to a provided
+    chat generation function. It is useful for handling intermittent issues
+    like network errors or temporary service unavailability.
+
+    Args:
+        generation_method (Callable): The chat generation method to be executed.
+        **kwargs (Any): Additional keyword arguments to pass to the generation method.
+
+    Returns:
+        Any: The result from the chat generation method.
+    """
+    retry_decorator = _create_retry_decorator()
+    from google.api_core.exceptions import InvalidArgument  # type: ignore
+
+    @retry_decorator
+    def _chat_with_retry(**kwargs: Any) -> Any:
+        try:
+            return generation_method(**kwargs)
+        except InvalidArgument as e:
+            # Do not retry for these errors.
+            raise ChatGoogleGenerativeAIError(
+                f"Invalid argument provided to Gemini: {e}"
+            ) from e
+        except Exception as e:
+            raise e
+
+    return _chat_with_retry(**kwargs)
+
+
+def _get_role(message: BaseMessage) -> str:
+    if isinstance(message, ChatMessage):
+        if message.role not in ("user", "model"):
+            raise ChatGoogleGenerativeAIError(
+                "Gemini only supports user and model roles when"
+                " providing it with Chat messages."
+            )
+        return message.role
+    elif isinstance(message, HumanMessage):
+        return "user"
+    elif isinstance(message, AIMessage):
+        return "model"
+    else:
+        # TODO: Gemini doesn't seem to have a concept of system messages yet.
+        raise ChatGoogleGenerativeAIError(
+            f"Message of '{message.type}' type not supported by Gemini."
+            " Please only provide it with Human or AI (user/assistant) messages."
+        )
+
+
+def _is_openai_parts_format(part: dict) -> bool:
+    return "type" in part
+
+
+def _is_vision_model(model: str) -> bool:
+    return "vision" in model
+
+
+def _is_url(s: str) -> bool:
+    try:
+        result = urlparse(s)
+        return all([result.scheme, result.netloc])
+    except Exception as e:
+        logger.debug(f"Unable to parse URL: {e}")
+        return False
+
+
+def _is_b64(s: str) -> bool:
+    return s.startswith("data:image")
+
+
+def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
+    try:
+        from google.cloud import storage  # type: ignore[attr-defined]
+    except ImportError:
+        raise ImportError(
+            "google-cloud-storage is required to load images from GCS."
+            " Install it with `pip install google-cloud-storage`"
+        )
+    if PIL is None:
+        raise ImportError(
+            "PIL is required to load images. Please install it "
+            "with `pip install pillow`"
+        )
+
+    gcs_client = storage.Client(project=project)
+    pieces = path.split("/")
+    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
+    if len(blobs) > 1:
+        raise ValueError(f"Found more than one candidate for {path}!")
+    img_bytes = blobs[0].download_as_bytes()
+    return PIL.Image.open(BytesIO(img_bytes))
+
+
+def _url_to_pil(image_source: str) -> Image:
+    if PIL is None:
+        raise ImportError(
+            "PIL is required to load images. Please install it "
+            "with `pip install pillow`"
+        )
+    try:
+        if isinstance(image_source, IMAGE_TYPES):
+            return image_source  # type: ignore[return-value]
+        elif _is_url(image_source):
+            if image_source.startswith("gs://"):
+                return _load_image_from_gcs(image_source)
+            response = requests.get(image_source)
+            response.raise_for_status()
+            return PIL.Image.open(BytesIO(response.content))
+        elif _is_b64(image_source):
+            _, encoded = image_source.split(",", 1)
+            data = base64.b64decode(encoded)
+            return PIL.Image.open(BytesIO(data))
+        elif os.path.exists(image_source):
+            return PIL.Image.open(image_source)
+        else:
+            raise ValueError(
+                "The provided string is not a valid URL, base64, or file path."
+            )
+    except Exception as e:
+        raise ValueError(f"Unable to process the provided image source: {e}")
+
+
+def _convert_to_parts(
+    content: Sequence[Union[str, dict]],
+) -> List[genai.types.PartType]:
+    """Converts a list of LangChain message parts into Google GenAI parts."""
+    import google.generativeai as genai
+
+    parts = []
+    for part in content:
+        if isinstance(part, str):
+            parts.append(genai.types.PartDict(text=part, inline_data=None))
+        elif isinstance(part, Mapping):
+            # OpenAI Format
+            if _is_openai_parts_format(part):
+                if part["type"] == "text":
+                    parts.append({"text": part["text"]})
+                elif part["type"] == "image_url":
+                    img_url = part["image_url"]
+                    if isinstance(img_url, dict):
+                        if "url" not in img_url:
+                            raise ValueError(
+                                f"Unrecognized message image format: {img_url}"
+                            )
+                        img_url = img_url["url"]
+                    parts.append({"inline_data": _url_to_pil(img_url)})
+                else:
+                    raise ValueError(f"Unrecognized message part type: {part['type']}")
+            else:
+                # Yolo
+                logger.warning(
+                    "Unrecognized message part format. Assuming it's a text part."
+                )
+                parts.append(part)
+        else:
+            # TODO: Maybe some of Google's native stuff
+            # would hit this branch.
+            raise ChatGoogleGenerativeAIError(
+                "Gemini only supports text and inline_data parts."
+            )
+    return parts
+
+
+def _messages_to_genai_contents(
+    input_messages: Sequence[BaseMessage],
+) -> List[genai.types.ContentDict]:
+    """Converts a list of messages into Gemini API content dicts."""
+
+    messages: List[genai.types.MessageDict] = []
+    for i, message in enumerate(input_messages):
+        role = _get_role(message)
+        if isinstance(message.content, str):
+            parts = [message.content]
+        else:
+            parts = _convert_to_parts(message.content)
+        messages.append({"role": role, "parts": parts})
+        if i > 0:
+            # Cannot have multiple messages from the same role in a row.
+            if role == messages[-2]["role"]:
+                raise ChatGoogleGenerativeAIError(
+                    "Cannot have multiple messages from the same role in a row."
+                    " Consider merging them into a single message with multiple"
+                    f" parts.\nReceived: {messages}"
+                )
+    return messages
+
+
+def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
+    """Converts a list of Gemini API Part objects into a list of LangChain messages."""
+    if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
+        # Simple text response. The typical response
+        return parts[0].text
+    elif not parts:
+        logger.warning("Gemini produced an empty response.")
+        return ""
+    messages = []
+    for part in parts:
+        if part.text is not None:
+            messages.append(
+                {
+                    "type": "text",
+                    "text": part.text,
+                }
+            )
+        else:
+            # TODO: Handle inline_data if that's a thing?
+            raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
+    return messages
+
+
+def _response_to_result(
+    response: genai.types.GenerateContentResponse,
+    ai_msg_t: Type[BaseMessage] = AIMessage,
+    human_msg_t: Type[BaseMessage] = HumanMessage,
+    chat_msg_t: Type[BaseMessage] = ChatMessage,
+    generation_t: Type[ChatGeneration] = ChatGeneration,
+) -> ChatResult:
+    """Converts a Gemini API response into a LangChain ChatResult."""
+    llm_output = {}
+    if response.prompt_feedback:
+        try:
+            prompt_feedback = type(response.prompt_feedback).to_dict(
+                response.prompt_feedback, use_integers_for_enums=False
+            )
+            llm_output["prompt_feedback"] = prompt_feedback
+        except Exception as e:
+            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+
+    generations: List[ChatGeneration] = []
+
+    role_map = {
+        "model": ai_msg_t,
+        "user": human_msg_t,
+    }
+    for candidate in response.candidates:
+        content = candidate.content
+        parts_content = _parts_to_content(content.parts)
+        if content.role not in role_map:
+            logger.warning(
+                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
+            )
+            msg = chat_msg_t(content=parts_content, role=content.role)
+        else:
+            msg = role_map[content.role](content=parts_content)
+        generation_info = {}
+        if candidate.finish_reason:
+            generation_info["finish_reason"] = candidate.finish_reason.name
+        if candidate.safety_ratings:
+            generation_info["safety_ratings"] = [
+                type(rating).to_dict(rating) for rating in candidate.safety_ratings
+            ]
+        generations.append(generation_t(message=msg, generation_info=generation_info))
+    if not response.candidates:
+        # Likely a "prompt feedback" violation (e.g., toxic input)
+        # Raising an error would be different than how OpenAI handles it,
+        # so we'll just log a warning and continue with an empty message.
+        logger.warning(
+            "Gemini produced an empty response. Continuing with empty message\n"
+            f"Feedback: {response.prompt_feedback}"
+        )
+        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+    return ChatResult(generations=generations, llm_output=llm_output)
+
+
+class ChatGoogleGenerativeAI(BaseChatModel):
+    """`Google Generative AI` Chat models API.
+
+    To use you must have the google.generativeai Python package installed and
+    either:
+
+        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
+        2. Pass your API key using the google_api_key kwarg to the ChatGoogleGenerativeAI
+           constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_google_genai import ChatGoogleGenerativeAI
+            chat = ChatGoogleGenerativeAI(model="gemini-pro")
+            chat.invoke("Write me a ballad about LangChain")
+
+    """
+
+    model: str = Field(
+        ...,
+        description="""The name of the model to use.
+Supported examples:
+    - gemini-pro""",
+    )
+    max_output_tokens: int = Field(default=None, description="Max output tokens")
+
+    client: Any  #: :meta private:
+    google_api_key: Optional[str] = None
+    temperature: Optional[float] = None
+    """Run inference with this temperature. Must be in the closed
+    interval [0.0, 1.0]."""
+    top_k: Optional[int] = None
+    """Decode using top-k sampling: consider the set of top_k most probable tokens.
+    Must be positive."""
+    n: int = 1
+    """Number of chat completions to generate for each prompt. Note that the API may
+    not return the full n completions if duplicates are generated."""
+
+    _generative_model: Any  #: :meta private:
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"google_api_key": "GOOGLE_API_KEY"}
+
+    @property
+    def _llm_type(self) -> str:
+        return "chat-google-generative-ai"
+
+    @property
+    def _is_geminiai(self) -> bool:
+        return self.model is not None and "gemini" in self.model
+
+    @classmethod
+    def is_lc_serializable(self) -> bool:
+        return True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        google_api_key = get_from_dict_or_env(
+            values, "google_api_key", "GOOGLE_API_KEY"
+        )
+        try:
+            import google.generativeai as genai
+
+            genai.configure(api_key=google_api_key)
+        except ImportError:
+            raise ChatGoogleGenerativeAIError(
+                "Could not import google.generativeai python package. "
+                "Please install it with `pip install google-generativeai`"
+            )
+
+        values["client"] = genai
+        if (
+            values.get("temperature") is not None
+            and not 0 <= values["temperature"] <= 1
+        ):
+            raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+        if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
+            raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+        if values.get("top_k") is not None and values["top_k"] <= 0:
+            raise ValueError("top_k must be positive")
+        model = values["model"]
+        values["_generative_model"] = genai.GenerativeModel(model_name=model)
+        return values
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model": self.model,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "n": self.n,
+        }
+
+    @property
+    def _generation_method(self) -> Callable:
+        return self._generative_model.generate_content
+
+    @property
+    def _async_generation_method(self) -> Awaitable:
+        # TODO Add support once Google uncomments the async client
+        return self._generative_model.generate_content
+
+    def _prepare_params(
+        self, messages: Sequence[BaseMessage], stop: Optional[List[str]]
+    ) -> Dict[str, Any]:
+        contents = _messages_to_genai_contents(messages)
+        gen_config = {
+            k: v
+            for k, v in {
+                "candidate_count": self.n,
+                "temperature": self.temperature,
+                "stop_sequences": stop,
+                "max_output_tokens": self.max_output_tokens,
+            }.items()
+            if v is not None
+        }
+        params = {
+            "generation_config": gen_config,
+            "contents": contents,
+        }
+        return params
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        params = self._prepare_params(messages, stop)
+        response: genai.types.GenerateContentResponse = chat_with_retry(
+            **params,
+            generation_method=self._generation_method,
+            **kwargs,
+        )
+        return _response_to_result(response)
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self._generate, messages, stop, run_manager, **kwargs
+        )
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        params = self._prepare_params(messages, stop)
+        response: genai.types.GenerateContentResponse = chat_with_retry(
+            **params,
+            generation_method=self._generation_method,
+            **kwargs,
+            stream=True,
+        )
+        for chunk in response:
+            _chat_result = _response_to_result(
+                chunk,
+                ai_msg_t=AIMessageChunk,
+                human_msg_t=HumanMessageChunk,
+                chat_msg_t=ChatMessageChunk,
+                generation_t=ChatGenerationChunk,
+            )
+            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
+            yield gen
+            if run_manager:
+                run_manager.on_llm_new_token(gen.text)
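As a usage note for the class added above, here is a minimal, hedged sketch of synchronous streaming and error handling. It assumes `GOOGLE_API_KEY` is exported and `google-generativeai` is installed; the prompt text is arbitrary.

```python
from langchain_google_genai import ChatGoogleGenerativeAI
# ChatGoogleGenerativeAIError is defined in chat_models.py above.
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError

llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.5)

try:
    # BaseChatModel.stream() drives the _stream() method shown above,
    # yielding AIMessageChunk objects as Gemini produces tokens.
    for chunk in llm.stream("Explain retrieval-augmented generation in two sentences."):
        print(chunk.content, end="", flush=True)
except ChatGoogleGenerativeAIError as exc:
    # Raised for non-retryable problems such as invalid arguments
    # (see chat_with_retry above); quota errors are retried with backoff first.
    print(f"Gemini request failed: {exc}")
```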
@@ -0,0 +1,94 @@
+[tool.poetry]
+name = "langchain-google-genai"
+version = "0.0.1-rc0"
+description = "An integration package connecting Google's genai package and LangChain"
+authors = []
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = ">=3.9,<4.0"
+langchain-core = ">=0.0.12"
+google-generativeai = "^0.3.1"
+
+[tool.poetry.group.test]
+optional = true
+
+[tool.poetry.group.test.dependencies]
+pytest = "^7.3.0"
+freezegun = "^1.2.2"
+pytest-mock = "^3.10.0"
+syrupy = "^4.0.2"
+pytest-watcher = "^0.3.4"
+pytest-asyncio = "^0.21.1"
+langchain-core = {path = "../../core", develop = true}
+
+[tool.poetry.group.codespell]
+optional = true
+
+[tool.poetry.group.codespell.dependencies]
+codespell = "^2.2.0"
+
+[tool.poetry.group.test_integration]
+optional = true
+
+[tool.poetry.group.test_integration.dependencies]
+
+[tool.poetry.group.lint]
+optional = true
+
+[tool.poetry.group.lint.dependencies]
+ruff = "^0.1.5"
+
+[tool.poetry.group.typing.dependencies]
+mypy = "^0.991"
+langchain-core = {path = "../../core", develop = true}
+
+[tool.poetry.group.dev]
+optional = true
+
+[tool.poetry.group.dev.dependencies]
+langchain-core = {path = "../../core", develop = true}
+pillow = "^10.1.0"
+types-requests = "^2.31.0.10"
+types-pillow = "^10.1.0.2"
+types-google-cloud-ndb = "^2.2.0.1"
+
+[tool.ruff]
+select = [
+  "E",  # pycodestyle
+  "F",  # pyflakes
+  "I",  # isort
+]
+
+[tool.mypy]
+disallow_untyped_defs = "True"
+exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
+
+[tool.coverage.run]
+omit = [
+    "tests/*",
+]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.pytest.ini_options]
+# --strict-markers will raise errors on unknown marks.
+# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
+#
+# https://docs.pytest.org/en/7.1.x/reference/reference.html
+# --strict-config any warnings encountered while parsing the `pytest`
+# section of the configuration file raise errors.
+#
+# https://github.com/tophat/syrupy
+# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
+addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
+# Registering custom markers.
+# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
+markers = [
+  "requires: mark tests as requiring a specific library",
+  "asyncio: mark tests as requiring asyncio",
+  "compile: mark placeholder test used to compile integration tests without running them",
+]
+asyncio_mode = "auto"
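For local development against this configuration, here is a hedged example of how the optional dependency groups might be installed and exercised with Poetry. It assumes Poetry >= 1.2 and that the commands run from this package's directory inside the LangChain monorepo, since `langchain-core` is declared as a relative path dependency.

```bash
# Install the package plus the optional test and lint groups.
poetry install --with test,lint

# Run the unit tests with the addopts configured in [tool.pytest.ini_options].
poetry run pytest
```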