langchain-google-genai 0.0.1rc0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic. Click here for more details.

@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 LangChain, Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,72 @@
1
+ Metadata-Version: 2.1
2
+ Name: langchain-google-genai
3
+ Version: 0.0.1rc0
4
+ Summary: An integration package connecting Google's genai package and LangChain
5
+ Requires-Python: >=3.9,<4.0
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: Programming Language :: Python :: 3.9
8
+ Classifier: Programming Language :: Python :: 3.10
9
+ Classifier: Programming Language :: Python :: 3.11
10
+ Requires-Dist: google-generativeai (>=0.3.1,<0.4.0)
11
+ Requires-Dist: langchain-core (>=0.0.12)
12
+ Description-Content-Type: text/markdown
13
+
14
+ # langchain-google
15
+
16
+ This partner package contains the newer Google LangChain integrations.
17
+
18
+ ## Installation
19
+
20
+ ```bash
21
+ pip install -U langchain-google-genai
22
+ ```
23
+
24
+ ## Chat Models
25
+
26
+ This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
27
+
28
+ To use, install the requirements, and configure your environment.
29
+
30
+ ```bash
31
+ export GOOGLE_API_KEY=your-api-key
32
+ ```
33
+
34
+ Then initialize
35
+
36
+ ```python
37
+ from langchain_google_genai import ChatGoogleGenerativeAI
38
+
39
+ llm = ChatGoogleGenerativeAI(model="gemini-pro")
40
+ llm.invoke("Sing a ballad of LangChain.")
41
+ ```
42
+
43
+ #### Multimodal inputs
44
+
45
+ Gemini vision model supports image inputs when providing a single chat message. Example:
46
+
47
+ ```python
48
+ from langchain_core.messages import HumanMessage
49
+ from langchain_google_genai import ChatGoogleGenerativeAI
50
+
51
+ llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
52
+ # example
53
+ message = HumanMessage(
54
+ content=[
55
+ {
56
+ "type": "text",
57
+ "text": "What's in this image?",
58
+ }, # You can optionally provide text parts
59
+ {"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
60
+ ]
61
+ )
62
+ llm.invoke([message])
63
+ ```
64
+
65
+ The value of `image_url` can be any of the following:
66
+
67
+ - A public image URL
68
+ - An accessible GCS file (e.g., "gs://path/to/file.png")
69
+ - A local file path
70
+ - A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
71
+ - A PIL image
72
+
@@ -0,0 +1,58 @@
1
+ # langchain-google
2
+
3
+ This partner package contains the newer Google LangChain integrations.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install -U langchain-google-genai
9
+ ```
10
+
11
+ ## Chat Models
12
+
13
+ This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
14
+
15
+ To use, install the requirements, and configure your environment.
16
+
17
+ ```bash
18
+ export GOOGLE_API_KEY=your-api-key
19
+ ```
20
+
21
+ Then initialize
22
+
23
+ ```python
24
+ from langchain_google_genai import ChatGoogleGenerativeAI
25
+
26
+ llm = ChatGoogleGenerativeAI(model="gemini-pro")
27
+ llm.invoke("Sing a ballad of LangChain.")
28
+ ```
29
+
30
+ #### Multimodal inputs
31
+
32
+ Gemini vision model supports image inputs when providing a single chat message. Example:
33
+
34
+ ```python
35
+ from langchain_core.messages import HumanMessage
36
+ from langchain_google_genai import ChatGoogleGenerativeAI
37
+
38
+ llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
39
+ # example
40
+ message = HumanMessage(
41
+ content=[
42
+ {
43
+ "type": "text",
44
+ "text": "What's in this image?",
45
+ }, # You can optionally provide text parts
46
+ {"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
47
+ ]
48
+ )
49
+ llm.invoke([message])
50
+ ```
51
+
52
+ The value of `image_url` can be any of the following:
53
+
54
+ - A public image URL
55
+ - An accessible GCS file (e.g., "gs://path/to/file.png")
56
+ - A local file path
57
+ - A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
58
+ - A PIL image
@@ -0,0 +1,3 @@
1
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
2
+
3
+ __all__ = ["ChatGoogleGenerativeAI"]
@@ -0,0 +1,556 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import base64
5
+ import logging
6
+ import os
7
+ from io import BytesIO
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Awaitable,
12
+ Callable,
13
+ Dict,
14
+ Iterator,
15
+ List,
16
+ Mapping,
17
+ Optional,
18
+ Sequence,
19
+ Tuple,
20
+ Type,
21
+ Union,
22
+ cast,
23
+ )
24
+ from urllib.parse import urlparse
25
+
26
+ import requests
27
+ from langchain_core.callbacks.manager import (
28
+ AsyncCallbackManagerForLLMRun,
29
+ CallbackManagerForLLMRun,
30
+ )
31
+ from langchain_core.language_models.chat_models import BaseChatModel
32
+ from langchain_core.messages import (
33
+ AIMessage,
34
+ AIMessageChunk,
35
+ BaseMessage,
36
+ ChatMessage,
37
+ ChatMessageChunk,
38
+ HumanMessage,
39
+ HumanMessageChunk,
40
+ )
41
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
42
+ from langchain_core.pydantic_v1 import Field, root_validator
43
+ from langchain_core.utils import get_from_dict_or_env
44
+ from tenacity import (
45
+ before_sleep_log,
46
+ retry,
47
+ retry_if_exception_type,
48
+ stop_after_attempt,
49
+ wait_exponential,
50
+ )
51
+
52
# Module-level logger used by the helpers and the chat model below.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only needed for annotations; the runtime import happens lazily inside
    # the functions that actually talk to the API.
    # TODO: remove ignore once the google package is published with types
    import google.generativeai as genai  # type: ignore[import]

# Classes accepted as "already an image". Stays empty when Pillow is missing,
# in which case ``isinstance(x, IMAGE_TYPES)`` is simply always False.
IMAGE_TYPES: Tuple = ()
try:
    import PIL
    from PIL.Image import Image
except ImportError:
    # Pillow is optional; helpers check ``PIL is None`` before using it.
    PIL = None  # type: ignore
    Image = None  # type: ignore
else:
    IMAGE_TYPES = (Image,)
66
+
67
+
68
class ChatGoogleGenerativeAIError(Exception):
    """Error type for problems specific to the `Google GenAI` integration.

    Raised when a request cannot be expressed for, or is rejected by, the
    Google genai API while using ``ChatGoogleGenerativeAI`` -- for example
    when a message has an unsupported type or role.
    """
76
+
77
+
78
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for Gemini API calls.

    The decorator retries on Google API exceptions with exponential backoff
    (multiplier 2, waits clamped to 1-60 seconds), gives up after 10 attempts
    and, because ``reraise=True``, surfaces the last underlying exception.
    Each retry is logged at WARNING level before sleeping.

    Returns:
        Callable[[Any], Any]: A decorator to wrap the function that performs
            the API call.
    """
    from google.api_core import exceptions as google_exceptions

    # NOTE(review): GoogleAPIError looks like the common base class of the
    # other two entries, so listing it means every Google API error is
    # retried, not only the transient ones -- confirm this is intended.
    retryable = (
        google_exceptions.ResourceExhausted,
        google_exceptions.ServiceUnavailable,
        google_exceptions.GoogleAPIError,
    )
    backoff = wait_exponential(multiplier=2, min=1, max=60)

    return retry(
        reraise=True,
        stop=stop_after_attempt(10),
        wait=backoff,
        retry=retry_if_exception_type(retryable),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
108
+
109
+
110
def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
    """Call a chat generation method under the module's retry policy.

    Args:
        generation_method (Callable): The function that performs the request.
        **kwargs (Any): Keyword arguments forwarded unchanged to
            ``generation_method``.

    Returns:
        Any: Whatever ``generation_method`` returns.

    Raises:
        ChatGoogleGenerativeAIError: If the API reports ``InvalidArgument``.
            This error type is not part of the retry policy, so the call
            fails immediately instead of being retried.
    """
    with_retries = _create_retry_decorator()
    from google.api_core.exceptions import InvalidArgument  # type: ignore

    # The inner name is kept as-is: tenacity includes it in retry log lines.
    def _chat_with_retry(**call_kwargs: Any) -> Any:
        try:
            return generation_method(**call_kwargs)
        except InvalidArgument as e:
            # Do not retry for these errors.
            raise ChatGoogleGenerativeAIError(
                f"Invalid argument provided to Gemini: {e}"
            ) from e

    return with_retries(_chat_with_retry)(**kwargs)
141
+
142
+
143
def _get_role(message: BaseMessage) -> str:
    """Return the Gemini role (``"user"`` or ``"model"``) for a message.

    ``ChatMessage`` instances keep their own role, which must already be one
    of the two supported values; ``HumanMessage`` maps to ``"user"`` and
    ``AIMessage`` to ``"model"``.

    Raises:
        ChatGoogleGenerativeAIError: For a ``ChatMessage`` with any other
            role, or for any other message type.
    """
    if isinstance(message, ChatMessage):
        chat_role = message.role
        if chat_role in ("user", "model"):
            return chat_role
        raise ChatGoogleGenerativeAIError(
            "Gemini only supports user and model roles when"
            " providing it with Chat messages."
        )
    if isinstance(message, HumanMessage):
        return "user"
    if isinstance(message, AIMessage):
        return "model"
    # TODO: Gemini doesn't seem to have a concept of system messages yet.
    raise ChatGoogleGenerativeAIError(
        f"Message of '{message.type}' type not supported by Gemini."
        " Please only provide it with Human or AI (user/assistant) messages."
    )
161
+
162
+
163
+ def _is_openai_parts_format(part: dict) -> bool:
164
+ return "type" in part
165
+
166
+
167
+ def _is_vision_model(model: str) -> bool:
168
+ return "vision" in model
169
+
170
+
171
+ def _is_url(s: str) -> bool:
172
+ try:
173
+ result = urlparse(s)
174
+ return all([result.scheme, result.netloc])
175
+ except Exception as e:
176
+ logger.debug(f"Unable to parse URL: {e}")
177
+ return False
178
+
179
+
180
+ def _is_b64(s: str) -> bool:
181
+ return s.startswith("data:image")
182
+
183
+
184
def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
    """Download one image from Google Cloud Storage and open it with PIL.

    Args:
        path: Location of the object. The string is split on ``/``; the third
            piece is used as the bucket name and the remainder as the object
            prefix, i.e. the ``gs://<bucket>/<object path>`` layout.
        project: Optional project passed to ``storage.Client``.

    Returns:
        The opened PIL image.

    Raises:
        ImportError: If ``google-cloud-storage`` or Pillow is not installed.
        ValueError: If the prefix matches no object, or more than one object.
    """
    try:
        from google.cloud import storage  # type: ignore[attr-defined]
    except ImportError as e:
        raise ImportError(
            "google-cloud-storage is required to load images from GCS."
            " Install it with `pip install google-cloud-storage`"
        ) from e
    if PIL is None:
        raise ImportError(
            "PIL is required to load images. Please install it "
            "with `pip install pillow`"
        )

    gcs_client = storage.Client(project=project)
    pieces = path.split("/")
    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
    if not blobs:
        # Without this guard ``blobs[0]`` below fails with a bare IndexError
        # that gives the caller no hint about which path was missing.
        raise ValueError(f"Found no candidate for {path}!")
    if len(blobs) > 1:
        raise ValueError(f"Found more than one candidate for {path}!")
    img_bytes = blobs[0].download_as_bytes()
    return PIL.Image.open(BytesIO(img_bytes))
205
+
206
+
207
def _url_to_pil(image_source: str, *, timeout: Optional[float] = 30.0) -> Image:
    """Resolve an image reference into a PIL image.

    The source is tried, in order, as: an object that is already one of
    ``IMAGE_TYPES`` (returned unchanged), a URL (``gs://`` goes through
    ``_load_image_from_gcs``, anything else is fetched with ``requests``),
    a ``data:image...`` base64 data URI, and finally a path that exists on
    the local filesystem.

    Args:
        image_source: The image reference described above.
        timeout: Timeout in seconds handed to ``requests.get`` for HTTP(S)
            downloads. ``None`` disables the timeout.

    Returns:
        The loaded PIL image.

    Raises:
        ImportError: If Pillow is not installed.
        ValueError: If the source is not recognised or loading fails for any
            reason; the original exception is attached as ``__cause__``.
    """
    if PIL is None:
        raise ImportError(
            "PIL is required to load images. Please install it "
            "with `pip install pillow`"
        )
    try:
        if isinstance(image_source, IMAGE_TYPES):
            return image_source  # type: ignore[return-value]
        elif _is_url(image_source):
            if image_source.startswith("gs://"):
                return _load_image_from_gcs(image_source)
            # A request without a timeout can block forever on a stalled
            # server, so always bound the wait.
            response = requests.get(image_source, timeout=timeout)
            response.raise_for_status()
            return PIL.Image.open(BytesIO(response.content))
        elif _is_b64(image_source):
            # Drop the "data:image/...;base64," header, keep the payload.
            _, encoded = image_source.split(",", 1)
            data = base64.b64decode(encoded)
            return PIL.Image.open(BytesIO(data))
        elif os.path.exists(image_source):
            return PIL.Image.open(image_source)
        else:
            raise ValueError(
                "The provided string is not a valid URL, base64, or file path."
            )
    except Exception as e:
        raise ValueError(f"Unable to process the provided image source: {e}") from e
234
+
235
+
236
def _convert_to_parts(
    content: Sequence[Union[str, dict]],
) -> List[genai.types.PartType]:
    """Convert the items of one message's content into Gemini parts.

    Plain strings become text parts. Mappings with a ``"type"`` key are read
    as OpenAI-style parts: ``"text"`` yields a text part and ``"image_url"``
    (either a string or a dict with a ``"url"`` key) is loaded through
    ``_url_to_pil`` and stored under ``"inline_data"``. Mappings without a
    ``"type"`` key are passed through unchanged with a warning.

    Raises:
        ValueError: For an unknown ``"type"`` value or an ``image_url`` dict
            without a ``"url"`` key.
        ChatGoogleGenerativeAIError: For items that are neither strings nor
            mappings.
    """
    import google.generativeai as genai

    converted = []
    for item in content:
        if isinstance(item, str):
            converted.append(genai.types.PartDict(text=item, inline_data=None))
            continue

        if not isinstance(item, Mapping):
            # TODO: Maybe some of Google's native stuff
            # would hit this branch.
            raise ChatGoogleGenerativeAIError(
                "Gemini only supports text and inline_data parts."
            )

        if not _is_openai_parts_format(item):
            # Unknown mapping shape: forward it untouched and hope for the best.
            logger.warning(
                "Unrecognized message part format. Assuming it's a text part."
            )
            converted.append(item)
            continue

        # OpenAI format
        kind = item["type"]
        if kind == "text":
            converted.append({"text": item["text"]})
        elif kind == "image_url":
            source = item["image_url"]
            if isinstance(source, dict):
                if "url" not in source:
                    raise ValueError(f"Unrecognized message image format: {source}")
                source = source["url"]
            converted.append({"inline_data": _url_to_pil(source)})
        else:
            raise ValueError(f"Unrecognized message part type: {kind}")
    return converted
275
+
276
+
277
def _messages_to_genai_contents(
    input_messages: Sequence[BaseMessage],
) -> List[genai.types.ContentDict]:
    """Convert LangChain messages into Gemini ``{"role", "parts"}`` dicts.

    String content becomes a single-element parts list; any other content is
    run through ``_convert_to_parts``.

    Raises:
        ChatGoogleGenerativeAIError: If two consecutive messages resolve to
            the same role.
    """
    contents: List[genai.types.MessageDict] = []
    previous_role: Optional[str] = None
    for message in input_messages:
        role = _get_role(message)
        body = message.content
        parts = [body] if isinstance(body, str) else _convert_to_parts(body)
        contents.append({"role": role, "parts": parts})
        # Cannot have multiple messages from the same role in a row.
        if previous_role is not None and role == previous_role:
            raise ChatGoogleGenerativeAIError(
                "Cannot have multiple messages from the same role in a row."
                " Consider merging them into a single message with multiple"
                f" parts.\nReceived: {contents}"
            )
        previous_role = role
    return contents
299
+
300
+
301
+ def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
302
+ """Converts a list of Gemini API Part objects into a list of LangChain messages."""
303
+ if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
304
+ # Simple text response. The typical response
305
+ return parts[0].text
306
+ elif not parts:
307
+ logger.warning("Gemini produced an empty response.")
308
+ return ""
309
+ messages = []
310
+ for part in parts:
311
+ if part.text is not None:
312
+ messages.append(
313
+ {
314
+ "type": "text",
315
+ "text": part.text,
316
+ }
317
+ )
318
+ else:
319
+ # TODO: Handle inline_data if that's a thing?
320
+ raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
321
+ return messages
322
+
323
+
324
def _response_to_result(
    response: genai.types.GenerateContentResponse,
    ai_msg_t: Type[BaseMessage] = AIMessage,
    human_msg_t: Type[BaseMessage] = HumanMessage,
    chat_msg_t: Type[BaseMessage] = ChatMessage,
    generation_t: Type[ChatGeneration] = ChatGeneration,
) -> ChatResult:
    """Convert a Gemini API response into a LangChain ChatResult.

    Args:
        response: The response (or streamed chunk) to convert.
        ai_msg_t: Message class used for the ``"model"`` role.
        human_msg_t: Message class used for the ``"user"`` role.
        chat_msg_t: Message class used for any other role; it receives the
            role as a keyword argument.
        generation_t: Class wrapping each message and its generation info.

    Returns:
        A ChatResult with one generation per candidate. When the response has
        no candidates, a single generation holding an empty ``ai_msg_t``
        message is returned instead. ``llm_output`` carries the prompt
        feedback when it could be converted to a dict.
    """
    llm_output = {}
    feedback = response.prompt_feedback
    if feedback:
        try:
            llm_output["prompt_feedback"] = type(feedback).to_dict(
                feedback, use_integers_for_enums=False
            )
        except Exception as e:
            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")

    role_map = {"model": ai_msg_t, "user": human_msg_t}
    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        content = candidate.content
        parts_content = _parts_to_content(content.parts)
        message_cls = role_map.get(content.role)
        if message_cls is None:
            logger.warning(
                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
            )
            msg = chat_msg_t(content=parts_content, role=content.role)
        else:
            msg = message_cls(content=parts_content)

        generation_info = {}
        if candidate.finish_reason:
            generation_info["finish_reason"] = candidate.finish_reason.name
        if candidate.safety_ratings:
            generation_info["safety_ratings"] = [
                type(rating).to_dict(rating) for rating in candidate.safety_ratings
            ]
        generations.append(generation_t(message=msg, generation_info=generation_info))

    if not response.candidates:
        # Likely a "prompt feedback" violation (e.g., toxic input)
        # Raising an error would be different than how OpenAI handles it,
        # so we'll just log a warning and continue with an empty message.
        logger.warning(
            "Gemini produced an empty response. Continuing with empty message\n"
            f"Feedback: {response.prompt_feedback}"
        )
        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
    return ChatResult(generations=generations, llm_output=llm_output)
376
+
377
+
378
class ChatGoogleGenerativeAI(BaseChatModel):
    """`Google Generative AI` Chat models API.

    To use you must have the google.generativeai Python package installed and
    either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the ChatGoogle
           constructor.

    Example:
        .. code-block:: python

            from langchain_google_genai import ChatGoogleGenerativeAI
            chat = ChatGoogleGenerativeAI(model="gemini-pro")
            chat.invoke("Write me a ballad about LangChain")

    """

    model: str = Field(
        ...,
        description="""The name of the model to use.
Supported examples:
    - gemini-pro""",
    )
    max_output_tokens: Optional[int] = Field(
        default=None, description="Max output tokens"
    )

    client: Any  #: :meta private:
    google_api_key: Optional[str] = None
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed
       interval [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
       probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
       Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
       not return the full n completions if duplicates are generated."""

    _generative_model: Any  #: :meta private:

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret constructor argument to its environment variable."""
        return {"google_api_key": "GOOGLE_API_KEY"}

    @property
    def _llm_type(self) -> str:
        """Identifier string for this chat model type."""
        return "chat-google-generative-ai"

    @property
    def _is_geminiai(self) -> bool:
        """True when the configured model name contains ``gemini``."""
        return self.model is not None and "gemini" in self.model

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Declare this class as serializable."""
        return True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key, configure the SDK and validate sampling options.

        Raises:
            ChatGoogleGenerativeAIError: If ``google.generativeai`` cannot be
                imported.
            ValueError: If ``temperature`` or ``top_p`` is outside [0.0, 1.0]
                or ``top_k`` is not positive.
        """
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key)
        except ImportError as e:
            raise ChatGoogleGenerativeAIError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`"
            ) from e

        values["client"] = genai
        if (
            values.get("temperature") is not None
            and not 0 <= values["temperature"] <= 1
        ):
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values.get("top_k") is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")
        model = values["model"]
        values["_generative_model"] = genai.GenerativeModel(model_name=model)
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _generation_method(self) -> Callable:
        """The bound SDK method used to generate content."""
        return self._generative_model.generate_content

    @property
    def _async_generation_method(self) -> Awaitable:
        # TODO Add support once Google uncomments the async client
        # NOTE(review): this currently returns the synchronous method.
        return self._generative_model.generate_content

    def _prepare_params(
        self, messages: Sequence[BaseMessage], stop: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Build the keyword arguments for the generation call.

        Returns a dict with ``contents`` (the converted messages) and
        ``generation_config`` (only the options that are not ``None``).
        """
        contents = _messages_to_genai_contents(messages)
        gen_config = {
            k: v
            for k, v in {
                "candidate_count": self.n,
                "temperature": self.temperature,
                "stop_sequences": stop,
                "max_output_tokens": self.max_output_tokens,
                # Previously validated but never sent, so they had no effect.
                "top_k": self.top_k,
                "top_p": self.top_p,
            }.items()
            if v is not None
        }
        params = {
            "generation_config": gen_config,
            "contents": contents,
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one blocking generation request and convert the response."""
        params = self._prepare_params(messages, stop)
        response: genai.types.GenerateContentResponse = chat_with_retry(
            **params,
            generation_method=self._generation_method,
            **kwargs,
        )
        return _response_to_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run ``_generate`` in the default executor of the running loop."""
        from functools import partial

        # run_in_executor only forwards positional arguments, so keyword
        # arguments must be bound up front. The async run manager is not
        # forwarded: _generate expects the synchronous kind and does not use it.
        blocking_call = partial(self._generate, messages, stop=stop, **kwargs)
        return await asyncio.get_running_loop().run_in_executor(None, blocking_call)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the response, yielding the first generation of each chunk."""
        params = self._prepare_params(messages, stop)
        response: genai.types.GenerateContentResponse = chat_with_retry(
            **params,
            generation_method=self._generation_method,
            **kwargs,
            stream=True,
        )
        for chunk in response:
            _chat_result = _response_to_result(
                chunk,
                ai_msg_t=AIMessageChunk,
                human_msg_t=HumanMessageChunk,
                chat_msg_t=ChatMessageChunk,
                generation_t=ChatGenerationChunk,
            )
            gen = cast(ChatGenerationChunk, _chat_result.generations[0])
            yield gen
            if run_manager:
                run_manager.on_llm_new_token(gen.text)
@@ -0,0 +1,94 @@
1
+ [tool.poetry]
2
+ name = "langchain-google-genai"
3
+ version = "0.0.1-rc0"
4
+ description = "An integration package connecting Google's genai package and LangChain"
5
+ authors = []
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = ">=3.9,<4.0"
10
+ langchain-core = ">=0.0.12"
11
+ google-generativeai = "^0.3.1"
12
+
13
+ [tool.poetry.group.test]
14
+ optional = true
15
+
16
+ [tool.poetry.group.test.dependencies]
17
+ pytest = "^7.3.0"
18
+ freezegun = "^1.2.2"
19
+ pytest-mock = "^3.10.0"
20
+ syrupy = "^4.0.2"
21
+ pytest-watcher = "^0.3.4"
22
+ pytest-asyncio = "^0.21.1"
23
+ langchain-core = {path = "../../core", develop = true}
24
+
25
+ [tool.poetry.group.codespell]
26
+ optional = true
27
+
28
+ [tool.poetry.group.codespell.dependencies]
29
+ codespell = "^2.2.0"
30
+
31
+ [tool.poetry.group.test_integration]
32
+ optional = true
33
+
34
+ [tool.poetry.group.test_integration.dependencies]
35
+
36
+ [tool.poetry.group.lint]
37
+ optional = true
38
+
39
+ [tool.poetry.group.lint.dependencies]
40
+ ruff = "^0.1.5"
41
+
42
+ [tool.poetry.group.typing.dependencies]
43
+ mypy = "^0.991"
44
+ langchain-core = {path = "../../core", develop = true}
45
+
46
+ [tool.poetry.group.dev]
47
+ optional = true
48
+
49
+ [tool.poetry.group.dev.dependencies]
50
+ langchain-core = {path = "../../core", develop = true}
51
+ pillow = "^10.1.0"
52
+ types-requests = "^2.31.0.10"
53
+ types-pillow = "^10.1.0.2"
54
+ types-google-cloud-ndb = "^2.2.0.1"
55
+
56
+ [tool.ruff]
57
+ select = [
58
+ "E", # pycodestyle
59
+ "F", # pyflakes
60
+ "I", # isort
61
+ ]
62
+
63
+ [tool.mypy]
64
+ disallow_untyped_defs = "True"
65
+ exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
66
+
67
+ [tool.coverage.run]
68
+ omit = [
69
+ "tests/*",
70
+ ]
71
+
72
+ [build-system]
73
+ requires = ["poetry-core>=1.0.0"]
74
+ build-backend = "poetry.core.masonry.api"
75
+
76
+ [tool.pytest.ini_options]
77
+ # --strict-markers will raise errors on unknown marks.
78
+ # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
79
+ #
80
+ # https://docs.pytest.org/en/7.1.x/reference/reference.html
81
+ # --strict-config any warnings encountered while parsing the `pytest`
82
+ # section of the configuration file raise errors.
83
+ #
84
+ # https://github.com/tophat/syrupy
85
+ # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
86
+ addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
87
+ # Registering custom markers.
88
+ # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
89
+ markers = [
90
+ "requires: mark tests as requiring a specific library",
91
+ "asyncio: mark tests as requiring asyncio",
92
+ "compile: mark placeholder test used to compile integration tests without running them",
93
+ ]
94
+ asyncio_mode = "auto"