langchain-google-genai 2.0.9__tar.gz → 2.0.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (18)
  1. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/PKG-INFO +3 -3
  2. langchain_google_genai-2.0.11/langchain_google_genai/_common.py +145 -0
  3. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/_function_utils.py +23 -29
  4. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/chat_models.py +47 -119
  5. langchain_google_genai-2.0.11/langchain_google_genai/llms.py +134 -0
  6. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/pyproject.toml +4 -14
  7. langchain_google_genai-2.0.9/langchain_google_genai/_common.py +0 -52
  8. langchain_google_genai-2.0.9/langchain_google_genai/llms.py +0 -406
  9. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/LICENSE +0 -0
  10. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/README.md +0 -0
  11. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/__init__.py +0 -0
  12. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/_enums.py +0 -0
  13. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/_genai_extension.py +0 -0
  14. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/_image_utils.py +0 -0
  15. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/embeddings.py +0 -0
  16. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/genai_aqa.py +0 -0
  17. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/google_vector_store.py +0 -0
  18. {langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/py.typed +0 -0
{langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langchain-google-genai
-Version: 2.0.9
+Version: 2.0.11
 Summary: An integration package connecting Google's genai package and LangChain
 Home-page: https://github.com/langchain-ai/langchain-google
 License: MIT
@@ -12,8 +12,8 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: filetype (>=1.2.0,<2.0.0)
-Requires-Dist: google-generativeai (>=0.8.0,<0.9.0)
-Requires-Dist: langchain-core (>=0.3.27,<0.4.0)
+Requires-Dist: google-ai-generativelanguage (>=0.6.16,<0.7.0)
+Requires-Dist: langchain-core (>=0.3.37,<0.4.0)
 Requires-Dist: pydantic (>=2,<3)
 Project-URL: Repository, https://github.com/langchain-ai/langchain-google
 Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
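
The dependency swap above (google-generativeai replaced by google-ai-generativelanguage) can be checked against an installed copy of the package; a minimal sketch using only the standard-library importlib.metadata (the printed version is illustrative):

    from importlib import metadata

    # Requirements as declared in the installed package's metadata.
    reqs = metadata.requires("langchain-google-genai") or []
    assert any(r.startswith("google-ai-generativelanguage") for r in reqs)
    assert not any(r.startswith("google-generativeai") for r in reqs)
    print(metadata.version("langchain-google-genai"))  # e.g. 2.0.11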
langchain_google_genai-2.0.11/langchain_google_genai/_common.py

@@ -0,0 +1,145 @@
+from importlib import metadata
+from typing import Any, Dict, Optional, Tuple, TypedDict
+
+from google.api_core.gapic_v1.client_info import ClientInfo
+from langchain_core.utils import secret_from_env
+from pydantic import BaseModel, Field, SecretStr
+
+from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
+
+
+class GoogleGenerativeAIError(Exception):
+    """
+    Custom exception class for errors associated with the `Google GenAI` API.
+    """
+
+
+class _BaseGoogleGenerativeAI(BaseModel):
+    """Base class for Google Generative AI LLMs"""
+
+    model: str = Field(
+        ...,
+        description="""The name of the model to use.
+Supported examples:
+    - gemini-pro
+    - models/text-bison-001""",
+    )
+    """Model name to use."""
+    google_api_key: Optional[SecretStr] = Field(
+        alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
+    )
+    """Google AI API key.
+    If not specified will be read from env var ``GOOGLE_API_KEY``."""
+    credentials: Any = None
+    "The default custom credentials (google.auth.credentials.Credentials) to use "
+    "when making API calls. If not provided, credentials will be ascertained from "
+    "the GOOGLE_API_KEY envvar"
+    temperature: float = 0.7
+    """Run inference with this temperature. Must by in the closed interval
+    [0.0, 1.0]."""
+    top_p: Optional[float] = None
+    """Decode using nucleus sampling: consider the smallest set of tokens whose
+    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+    top_k: Optional[int] = None
+    """Decode using top-k sampling: consider the set of top_k most probable tokens.
+    Must be positive."""
+    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
+    """Maximum number of tokens to include in a candidate. Must be greater than zero.
+    If unset, will default to 64."""
+    n: int = 1
+    """Number of chat completions to generate for each prompt. Note that the API may
+    not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
+
+    timeout: Optional[float] = None
+    """The maximum number of seconds to wait for a response."""
+
+    client_options: Optional[Dict] = Field(
+        default=None,
+        description=(
+            "A dictionary of client options to pass to the Google API client, "
+            "such as `api_endpoint`."
+        ),
+    )
+    transport: Optional[str] = Field(
+        default=None,
+        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
+    )
+    additional_headers: Optional[Dict[str, str]] = Field(
+        default=None,
+        description=(
+            "A key-value dictionary representing additional headers for the model call"
+        ),
+    )
+
+    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
+    """The default safety settings to use for all generations.
+
+        For example:
+
+            from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
+
+            safety_settings = {
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            }
+    """  # noqa: E501
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"google_api_key": "GOOGLE_API_KEY"}
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model": self.model,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "max_output_tokens": self.max_output_tokens,
+            "candidate_count": self.n,
+        }
+
+
+def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
+    r"""Returns a custom user agent header.
+
+    Args:
+        module (Optional[str]):
+            Optional. The module for a custom user agent header.
+    Returns:
+        Tuple[str, str]
+    """
+    try:
+        langchain_version = metadata.version("langchain-google-genai")
+    except metadata.PackageNotFoundError:
+        langchain_version = "0.0.0"
+    client_library_version = (
+        f"{langchain_version}-{module}" if module else langchain_version
+    )
+    return client_library_version, f"langchain-google-genai/{client_library_version}"
+
+
+def get_client_info(module: Optional[str] = None) -> "ClientInfo":
+    r"""Returns a client info object with a custom user agent header.
+
+    Args:
+        module (Optional[str]):
+            Optional. The module for a custom user agent header.
+    Returns:
+        google.api_core.gapic_v1.client_info.ClientInfo
+    """
+    client_library_version, user_agent = get_user_agent(module)
+    return ClientInfo(
+        client_library_version=client_library_version,
+        user_agent=user_agent,
+    )
+
+
+class SafetySettingDict(TypedDict):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
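
The user-agent helpers added above can be exercised directly; a minimal sketch (the module argument is illustrative, and to_user_agent() is the standard accessor on google.api_core's ClientInfo):

    from langchain_google_genai._common import get_client_info, get_user_agent

    # Version string plus the "langchain-google-genai/<version>-<module>" header.
    version, user_agent = get_user_agent(module="ChatGoogleGenerativeAI")
    print(version, user_agent)

    # Wraps the same values in a gapic ClientInfo for use on API calls.
    client_info = get_client_info("ChatGoogleGenerativeAI")
    print(client_info.to_user_agent())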
{langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/_function_utils.py

@@ -7,7 +7,6 @@ import logging
 from typing import (
     Any,
     Callable,
-    Collection,
     Dict,
     List,
     Literal,
@@ -22,7 +21,6 @@ from typing import (
 import google.ai.generativelanguage as glm
 import google.ai.generativelanguage_v1beta.types as gapic
 import proto  # type: ignore[import]
-from google.generativeai.types.content_types import ToolDict  # type: ignore[import]
 from langchain_core.tools import BaseTool
 from langchain_core.tools import tool as callable_as_lc_tool
 from langchain_core.utils.function_calling import (
@@ -59,38 +57,21 @@ _ALLOWED_SCHEMA_FIELDS.extend(
 _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)


-class _ToolDictLike(TypedDict):
-    function_declarations: _FunctionDeclarationLikeList
-
-
-class _FunctionDeclarationDict(TypedDict):
-    name: str
-    description: str
-    parameters: Dict[str, Collection[str]]
-
-
-class _ToolDict(TypedDict):
-    function_declarations: Sequence[_FunctionDeclarationDict]
-
-
 # Info: This is a FunctionDeclaration(=fc).
 _FunctionDeclarationLike = Union[
     BaseTool, Type[BaseModel], gapic.FunctionDeclaration, Callable, Dict[str, Any]
 ]

-# Info: This mean one tool.
-_FunctionDeclarationLikeList = Sequence[_FunctionDeclarationLike]
+
+class _ToolDict(TypedDict):
+    function_declarations: Sequence[_FunctionDeclarationLike]


 # Info: This means one tool=Sequence of FunctionDeclaration
 # The dict should be gapic.Tool like. {"function_declarations": [ { "name": ...}.
 # OpenAI like dict is not be accepted. {{'type': 'function', 'function': {'name': ...}
 _ToolsType = Union[
-    gapic.Tool,
-    ToolDict,
-    _ToolDictLike,
-    _FunctionDeclarationLikeList,
-    _FunctionDeclarationLike,
+    gapic.Tool, _ToolDict, _FunctionDeclarationLike, Sequence[_FunctionDeclarationLike]
 ]

@@ -152,12 +133,12 @@ def convert_to_genai_function_declarations(
     gapic_tool = gapic.Tool()
     for tool in tools:
         if isinstance(tool, gapic.Tool):
-            gapic_tool.function_declarations.extend(tool.function_declarations)
+            gapic_tool.function_declarations.extend(tool.function_declarations)  # type: ignore[union-attr]
         elif isinstance(tool, dict) and "function_declarations" not in tool:
             fd = _format_to_gapic_function_declaration(tool)
             gapic_tool.function_declarations.append(fd)
         elif isinstance(tool, dict):
-            function_declarations = cast(_ToolDictLike, tool)["function_declarations"]
+            function_declarations = cast(_ToolDict, tool)["function_declarations"]
             if not isinstance(function_declarations, collections.abc.Sequence):
                 raise ValueError(
                     "function_declarations should be a list"
@@ -170,7 +151,7 @@ def convert_to_genai_function_declarations(
             ]
             gapic_tool.function_declarations.extend(fds)
         else:
-            fd = _format_to_gapic_function_declaration(tool)
+            fd = _format_to_gapic_function_declaration(tool)  # type: ignore[arg-type]
             gapic_tool.function_declarations.append(fd)
     return gapic_tool

@@ -235,13 +216,16 @@ def _format_base_tool_to_function_declaration(
         ),
     )

-    if issubclass(tool.args_schema, BaseModel):
+    if isinstance(tool.args_schema, dict):
+        schema = tool.args_schema
+    elif issubclass(tool.args_schema, BaseModel):
         schema = tool.args_schema.model_json_schema()
     elif issubclass(tool.args_schema, BaseModelV1):
         schema = tool.args_schema.schema()
     else:
         raise NotImplementedError(
-            f"args_schema must be a Pydantic BaseModel, got {tool.args_schema}."
+            "args_schema must be a Pydantic BaseModel or JSON schema, "
+            f"got {tool.args_schema}."
         )
     parameters = _dict_to_gapic_schema(schema)

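
The new isinstance(tool.args_schema, dict) branch above means a tool's args_schema no longer has to be a Pydantic model; a sketch of both accepted forms (the multiply tool is illustrative, and passing a dict schema assumes a langchain-core version whose StructuredTool accepts plain JSON schemas):

    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel

    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    class MultiplyArgs(BaseModel):
        a: int
        b: int

    # Pydantic model: handled by the issubclass(..., BaseModel) branch.
    tool_model = StructuredTool.from_function(multiply, args_schema=MultiplyArgs)

    # Plain JSON schema dict: handled by the new isinstance(..., dict) branch.
    tool_json = StructuredTool.from_function(
        multiply,
        args_schema={
            "type": "object",
            "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
            "required": ["a", "b"],
        },
    )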
@@ -301,10 +285,18 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
             continue
         properties_item: Dict[str, Union[str, int, Dict, List]] = {}
         if v.get("type") or v.get("anyOf") or v.get("type_"):
-            properties_item["type_"] = _get_type_from_schema(v)
+            item_type_ = _get_type_from_schema(v)
+            properties_item["type_"] = item_type_
             if _is_nullable_schema(v):
                 properties_item["nullable"] = True

+        # Replace `v` with chosen definition for array / object json types
+        any_of_types = v.get("anyOf")
+        if any_of_types and item_type_ in [glm.Type.ARRAY, glm.Type.OBJECT]:
+            json_type_ = "array" if item_type_ == glm.Type.ARRAY else "object"
+            # Use Index -1 for consistency with `_get_nullable_type_from_schema`
+            v = [val for val in any_of_types if val.get("type") == json_type_][-1]
+
         if v.get("enum"):
             properties_item["enum"] = v["enum"]

@@ -364,6 +356,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
         )
         if _is_nullable_schema(schema):
             items["nullable"] = True
+        if "required" in schema:
+            items["required"] = schema["required"]
     else:
         # str
         items["type_"] = _get_type_from_schema({"type": schema})
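
The anyOf handling above targets the schemas Pydantic emits for optional container fields; a sketch of the input it is aimed at (the model is illustrative, model_json_schema() is standard Pydantic v2):

    from typing import List, Optional

    from pydantic import BaseModel

    class Query(BaseModel):
        tags: Optional[List[str]] = None

    # Pydantic renders Optional[List[str]] as an anyOf of array and null:
    # {'anyOf': [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}], ...}
    print(Query.model_json_schema()["properties"]["tags"])
    # The converter now picks the array/object member of the anyOf (and marks the
    # property nullable) instead of losing its items/properties definitions.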
{langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/langchain_google_genai/chat_models.py

@@ -35,6 +35,7 @@ from google.ai.generativelanguage_v1beta.types import (
     Content,
     FileData,
     FunctionCall,
+    FunctionDeclaration,
     FunctionResponse,
     GenerateContentRequest,
     GenerateContentResponse,
@@ -44,12 +45,8 @@ from google.ai.generativelanguage_v1beta.types import (
     ToolConfig,
     VideoMetadata,
 )
-from google.generativeai.caching import CachedContent  # type: ignore[import]
-from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
-from google.generativeai.types import caching_types, content_types
-from google.generativeai.types.content_types import (  # type: ignore[import]
-    FunctionDeclarationType,
-    ToolDict,
+from google.ai.generativelanguage_v1beta.types import (
+    Tool as GoogleTool,
 )
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -76,7 +73,7 @@ from langchain_core.output_parsers.openai_tools import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnablePassthrough
-from langchain_core.utils import secret_from_env
+from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from pydantic import (
     BaseModel,
@@ -97,24 +94,32 @@ from typing_extensions import Self
 from langchain_google_genai._common import (
     GoogleGenerativeAIError,
     SafetySettingDict,
+    _BaseGoogleGenerativeAI,
     get_client_info,
 )
 from langchain_google_genai._function_utils import (
     _tool_choice_to_tool_config,
     _ToolChoiceType,
     _ToolConfigDict,
+    _ToolDict,
     convert_to_genai_function_declarations,
     is_basemodel_subclass_safe,
     tool_to_dict,
 )
 from langchain_google_genai._image_utils import ImageBytesLoader
-from langchain_google_genai.llms import _BaseGoogleGenerativeAI

 from . import _genai_extension as genaix

 logger = logging.getLogger(__name__)


+_FunctionDeclarationType = Union[
+    FunctionDeclaration,
+    dict[str, Any],
+    Callable[..., Any],
+]
+
+
 class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
     """
     Custom exception class for errors associated with the `Google GenAI` API.
@@ -301,9 +306,12 @@ def _convert_to_parts(
     return parts


-def _convert_tool_message_to_part(message: ToolMessage | FunctionMessage) -> Part:
+def _convert_tool_message_to_part(
+    message: ToolMessage | FunctionMessage, name: Optional[str] = None
+) -> Part:
     """Converts a tool or function message to a google part."""
-    name = message.name
+    # Legacy agent stores tool name in message.additional_kwargs instead of message.name
+    name = message.name or name or message.additional_kwargs.get("name")
     response: Any
     if not isinstance(message.content, str):
         response = message.content
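
A sketch of the fallback the change above enables (the message contents are illustrative): a legacy-style ToolMessage that carries the tool name only in additional_kwargs still resolves to a named function response:

    from langchain_core.messages import ToolMessage

    legacy = ToolMessage(
        content="42",
        tool_call_id="call_1",
        additional_kwargs={"name": "calculator"},  # name not set on the message itself
    )
    # legacy.name is None, so the chain
    #   message.name or name or message.additional_kwargs.get("name")
    # resolves the function name to "calculator".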
@@ -331,16 +339,17 @@ def _get_ai_message_tool_messages_parts(
        list of Parts.
     """
     # We are interested only in the tool messages that are part of the AI message
-    tool_calls_ids = [tool_call["id"] for tool_call in ai_message.tool_calls]
+    tool_calls_ids = {tool_call["id"]: tool_call for tool_call in ai_message.tool_calls}
     parts = []
     for i, message in enumerate(tool_messages):
         if not tool_calls_ids:
             break
         if message.tool_call_id in tool_calls_ids:
-            # remove the id from the list, so that we do not iterate over it again
-            tool_calls_ids.remove(message.tool_call_id)
-            part = _convert_tool_message_to_part(message)
+            tool_call = tool_calls_ids[message.tool_call_id]
+            part = _convert_tool_message_to_part(message, name=tool_call.get("name"))
             parts.append(part)
+            # remove the id from the dict, so that we do not iterate over it again
+            tool_calls_ids.pop(message.tool_call_id)
     return parts

@@ -360,8 +369,14 @@ def _parse_chat_history(
         message for message in input_messages if isinstance(message, ToolMessage)
     ]
     for i, message in enumerate(messages_without_tool_messages):
-        if i == 0 and isinstance(message, SystemMessage):
-            system_instruction = Content(parts=_convert_to_parts(message.content))
+        if isinstance(message, SystemMessage):
+            system_parts = _convert_to_parts(message.content)
+            if i == 0:
+                system_instruction = Content(parts=system_parts)
+            elif system_instruction is not None:
+                system_instruction.parts.extend(system_parts)
+            else:
+                pass
             continue
         elif isinstance(message, AIMessage):
             role = "model"
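
With the branch above, consecutive leading system messages are merged into a single system_instruction rather than only the first one being honored; a sketch (the messages are illustrative):

    from langchain_core.messages import HumanMessage, SystemMessage

    messages = [
        SystemMessage("You are a careful assistant."),  # becomes system_instruction
        SystemMessage("Answer in one sentence."),       # now appended to it
        HumanMessage("What is a diff hunk?"),
    ]
    # A SystemMessage appearing later, once non-system messages have been seen
    # and no system_instruction exists, is still silently skipped (the pass branch).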
@@ -776,11 +791,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):

     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client_running: Any = Field(default=None, exclude=True)  #: :meta private:
-    google_api_key: Optional[SecretStr] = Field(
-        alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
-    )
-    """Google AI API key.
-    If not specified will be read from env var ``GOOGLE_API_KEY``."""
     default_metadata: Sequence[Tuple[str, str]] = Field(
         default_factory=list
     )  #: :meta private:
@@ -928,8 +938,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         *,
-        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
-        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[_FunctionDeclarationType]] = None,
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
@@ -962,8 +972,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         *,
-        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
-        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[_FunctionDeclarationType]] = None,
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
@@ -1011,8 +1021,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         *,
-        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
-        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[_FunctionDeclarationType]] = None,
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
@@ -1073,8 +1083,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         *,
-        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
-        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[_FunctionDeclarationType]] = None,
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
@@ -1148,8 +1158,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         messages: List[BaseMessage],
         *,
         stop: Optional[List[str]] = None,
-        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
-        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[_FunctionDeclarationType]] = None,
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
@@ -1241,8 +1251,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             )
         else:
             parser = JsonOutputToolsParser()
-        tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None
-        llm = self.bind_tools([schema], tool_choice=tool_choice)
+        tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None  # type: ignore[arg-type]
+        llm = self.bind_tools([schema], tool_choice=tool_choice)  # type: ignore[list-item]
         if include_raw:
             parser_with_fallback = RunnablePassthrough.assign(
                 parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
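
This hunk appears to sit inside with_structured_output, which binds the schema as a forced tool; a usage sketch under that assumption (the schema and prompt are illustrative):

    from pydantic import BaseModel

    from langchain_google_genai import ChatGoogleGenerativeAI

    class Person(BaseModel):
        name: str
        age: int

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    structured = llm.with_structured_output(Person)
    # structured.invoke("Anna is 29 years old.")  # -> Person(name="Anna", age=29)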
@@ -1256,7 +1266,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):

     def bind_tools(
         self,
-        tools: Sequence[Union[ToolDict, GoogleTool]],
+        tools: Sequence[
+            dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool
+        ],
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         *,
         tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
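
The widened bind_tools signature above accepts dicts, types, plain callables, BaseTools and gapic Tools; a sketch with a decorated tool (the tool body and prompt are illustrative):

    from langchain_core.tools import tool

    from langchain_google_genai import ChatGoogleGenerativeAI

    @tool
    def get_weather(city: str) -> str:
        """Return the weather for a city."""
        return f"Sunny in {city}"

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    llm_with_tools = llm.bind_tools([get_weather])
    # llm_with_tools.invoke("What's the weather in Paris?")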
@@ -1293,90 +1305,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             pass
         return self.bind(tools=formatted_tools, **kwargs)

-    def create_cached_content(
-        self,
-        contents: Union[List[BaseMessage], content_types.ContentsType],
-        *,
-        display_name: str | None = None,
-        tools: Union[ToolDict, GoogleTool, None] = None,
-        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
-        ttl: Optional[caching_types.TTLTypes] = None,
-        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
-    ) -> str:
-        """
-
-        Args:
-            display_name: The user-generated meaningful display name
-                of the cached content. `display_name` must be no
-                more than 128 unicode characters.
-            contents: Contents to cache.
-            tools: A list of `Tools` the model may use to generate response.
-            tool_choice: Which tool to require the model to call.
-            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
-                `ttl` and `expire_time` are exclusive arguments.
-            expire_time: Expiration time for cached resource.
-                `ttl` and `expire_time` are exclusive arguments.
-        """
-        system: Optional[content_types.ContentType] = None
-        genai_contents: list = []
-        if all(isinstance(c, BaseMessage) for c in contents):
-            system, genai_contents = _parse_chat_history(
-                contents,
-                convert_system_message_to_human=self.convert_system_message_to_human,
-            )
-        elif any(isinstance(c, BaseMessage) for c in contents):
-            raise ValueError(
-                f"'contents' must either be a list of "
-                f"langchain_core.messages.BaseMessage or a list "
-                f"google.generativeai.types.content_types.ContentType, but not a mix "
-                f"of the two. Received {contents}"
-            )
-        else:
-            for content in contents:
-                if hasattr(content, "role") and content.role == "system":
-                    if system is not None:
-                        warnings.warn(
-                            "Received multiple pieces of content with role 'system'. "
-                            "Should only be one set of system instructions. Ignoring "
-                            "all but the first 'system' content."
-                        )
-                    else:
-                        system = content
-                elif isinstance(content, dict) and content.get("role") == "system":
-                    if system is not None:
-                        warnings.warn(
-                            "Received multiple pieces of content with role 'system'. "
-                            "Should only be one set of system instructions. Ignoring "
-                            "all but the first 'system' content."
-                        )
-                    else:
-                        system = content
-                else:
-                    genai_contents.append(content)
-        if tools:
-            genai_tools = [convert_to_genai_function_declarations(tools)]
-        else:
-            genai_tools = None
-        if tool_choice and genai_tools:
-            all_names = [f.name for t in genai_tools for f in t.function_declarations]
-            tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
-            genai_tool_config = ToolConfig(
-                function_calling_config=tool_config["function_calling_config"]
-            )
-        else:
-            genai_tool_config = None
-        cached_content = CachedContent.create(
-            model=self.model,
-            system_instruction=system,
-            contents=genai_contents,
-            display_name=display_name,
-            tools=genai_tools,
-            tool_config=genai_tool_config,
-            ttl=ttl,
-            expire_time=expire_time,
-        )
-        return cached_content.name
-
     @property
     def _supports_tool_choice(self) -> bool:
         return (
@@ -1387,7 +1315,7 @@


 def _get_tool_name(
-    tool: Union[ToolDict, GoogleTool],
+    tool: Union[_ToolDict, GoogleTool],
 ) -> str:
     genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
     return [f["name"] for f in genai_tool["function_declarations"]][0]  # type: ignore[index]
langchain_google_genai-2.0.11/langchain_google_genai/llms.py

@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+from typing import Any, Iterator, List, Optional
+
+from langchain_core.callbacks import (
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import LangSmithParams
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.messages import HumanMessage
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from pydantic import ConfigDict, model_validator
+from typing_extensions import Self
+
+from langchain_google_genai._common import (
+    _BaseGoogleGenerativeAI,
+)
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+
+
+class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+    """Google GenerativeAI models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_google_genai import GoogleGenerativeAI
+            llm = GoogleGenerativeAI(model="gemini-pro")
+    """
+
+    client: Any = None  #: :meta private:
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
+        """Validates params and passes them to google-generativeai package."""
+
+        self.client = ChatGoogleGenerativeAI(
+            api_key=self.google_api_key,
+            credentials=self.credentials,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            max_tokens=self.max_output_tokens,
+            timeout=self.timeout,
+            model=self.model,
+            client_options=self.client_options,
+            transport=self.transport,
+            additional_headers=self.additional_headers,
+            safety_settings=self.safety_settings,
+        )
+
+        return self
+
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        ls_params = super()._get_ls_params(stop=stop, **kwargs)
+        ls_params["ls_provider"] = "google_genai"
+        if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        return ls_params
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        generations = []
+        for prompt in prompts:
+            chat_result = self.client._generate(
+                [HumanMessage(content=prompt)],
+                stop=stop,
+                run_manager=run_manager,
+                **kwargs,
+            )
+            generations.append(
+                [
+                    Generation(
+                        text=g.message.content,
+                        generation_info={
+                            **g.generation_info,
+                            **{"usage_metadata": g.message.usage_metadata},
+                        },
+                    )
+                    for g in chat_result.generations
+                ]
+            )
+        return LLMResult(generations=generations)
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        for stream_chunk in self.client._stream(
+            [HumanMessage(content=prompt)],
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        ):
+            chunk = GenerationChunk(text=stream_chunk.message.content)
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    chunk.text,
+                    chunk=chunk,
+                    verbose=self.verbose,
+                )
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "google_gemini"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        return self.client.get_num_tokens(text)
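
A usage sketch for the rewritten completion-style wrapper above, which now delegates to ChatGoogleGenerativeAI internally (the model name and prompts are illustrative; GOOGLE_API_KEY is read from the environment):

    from langchain_google_genai import GoogleGenerativeAI

    llm = GoogleGenerativeAI(model="gemini-pro", temperature=0.7)
    # print(llm.invoke("Tell me a joke about tokenizers."))
    # for chunk in llm.stream("Count to five."):
    #     print(chunk, end="")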
{langchain_google_genai-2.0.9 → langchain_google_genai-2.0.11}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-google-genai"
-version = "2.0.9"
+version = "2.0.11"
 description = "An integration package connecting Google's genai package and LangChain"
 authors = []
 readme = "README.md"
@@ -12,8 +12,8 @@ license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = "^0.3.27"
-google-generativeai = "^0.8.0"
+langchain-core = "^0.3.37"
+google-ai-generativelanguage = "^0.6.16"
 pydantic = ">=2,<3"
 filetype = "^1.2.0"

@@ -28,14 +28,12 @@ syrupy = "^4.0.2"
 pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
 numpy = "^1.26.2"
-langchain-tests = "0.3.1"
+langchain-tests = "0.3.12"

 [tool.codespell]
 ignore-words-list = "rouge"


-
-
 [tool.poetry.group.codespell]
 optional = true

@@ -43,8 +41,6 @@ optional = true
 codespell = "^2.2.0"


-
-
 [tool.poetry.group.test_integration]
 optional = true

@@ -52,8 +48,6 @@ optional = true
 pytest = "^7.3.0"


-
-
 [tool.poetry.group.lint]
 optional = true

@@ -61,8 +55,6 @@ optional = true
 ruff = "^0.1.5"


-
-
 [tool.poetry.group.typing.dependencies]
 mypy = "^1.10"
 types-requests = "^2.28.11.5"
@@ -71,8 +63,6 @@ types-protobuf = "^4.24.0.20240302"
 numpy = "^1.26.2"


-
-
 [tool.poetry.group.dev]
 optional = true

langchain_google_genai-2.0.9/langchain_google_genai/_common.py

@@ -1,52 +0,0 @@
-from importlib import metadata
-from typing import Optional, Tuple, TypedDict
-
-from google.api_core.gapic_v1.client_info import ClientInfo
-
-from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
-
-
-class GoogleGenerativeAIError(Exception):
-    """
-    Custom exception class for errors associated with the `Google GenAI` API.
-    """
-
-
-def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
-    r"""Returns a custom user agent header.
-
-    Args:
-        module (Optional[str]):
-            Optional. The module for a custom user agent header.
-    Returns:
-        Tuple[str, str]
-    """
-    try:
-        langchain_version = metadata.version("langchain-google-genai")
-    except metadata.PackageNotFoundError:
-        langchain_version = "0.0.0"
-    client_library_version = (
-        f"{langchain_version}-{module}" if module else langchain_version
-    )
-    return client_library_version, f"langchain-google-genai/{client_library_version}"
-
-
-def get_client_info(module: Optional[str] = None) -> "ClientInfo":
-    r"""Returns a client info object with a custom user agent header.
-
-    Args:
-        module (Optional[str]):
-            Optional. The module for a custom user agent header.
-    Returns:
-        google.api_core.gapic_v1.client_info.ClientInfo
-    """
-    client_library_version, user_agent = get_user_agent(module)
-    return ClientInfo(
-        client_library_version=client_library_version,
-        user_agent=user_agent,
-    )
-
-
-class SafetySettingDict(TypedDict):
-    category: HarmCategory
-    threshold: HarmBlockThreshold
langchain_google_genai-2.0.9/langchain_google_genai/llms.py

@@ -1,406 +0,0 @@
-from __future__ import annotations
-
-from enum import Enum, auto
-from typing import Any, Callable, Dict, Iterator, List, Optional, Union
-
-import google.api_core
-import google.generativeai as genai  # type: ignore[import]
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
-from langchain_core.language_models import LangSmithParams, LanguageModelInput
-from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
-from langchain_core.outputs import Generation, GenerationChunk, LLMResult
-from langchain_core.utils import secret_from_env
-from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
-from typing_extensions import Self
-
-from langchain_google_genai._enums import (
-    HarmBlockThreshold,
-    HarmCategory,
-)
-
-
-class GoogleModelFamily(str, Enum):
-    GEMINI = auto()
-    PALM = auto()
-
-    @classmethod
-    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
-        if "gemini" in value.lower():
-            return GoogleModelFamily.GEMINI
-        elif "text-bison" in value.lower():
-            return GoogleModelFamily.PALM
-        return None
-
-
-def _create_retry_decorator(
-    llm: BaseLLM,
-    *,
-    max_retries: int = 1,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    """Creates a retry decorator for Vertex / Palm LLMs."""
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-        google.api_core.exceptions.GoogleAPIError,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
-def _completion_with_retry(
-    llm: GoogleGenerativeAI,
-    prompt: LanguageModelInput,
-    is_gemini: bool = False,
-    stream: bool = False,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(
-        llm, max_retries=llm.max_retries, run_manager=run_manager
-    )
-
-    @retry_decorator
-    def _completion_with_retry(
-        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
-    ) -> Any:
-        generation_config = kwargs.get("generation_config", {})
-        error_msg = (
-            "Your location is not supported by google-generativeai at the moment. "
-            "Try to use VertexAI LLM from langchain_google_vertexai"
-        )
-        try:
-            if is_gemini:
-                return llm.client.generate_content(
-                    contents=prompt,
-                    stream=stream,
-                    generation_config=generation_config,
-                    safety_settings=kwargs.pop("safety_settings", None),
-                    request_options={"timeout": llm.timeout} if llm.timeout else None,
-                )
-            return llm.client.generate_text(prompt=prompt, **kwargs)
-        except google.api_core.exceptions.FailedPrecondition as exc:
-            if "location is not supported" in exc.message:
-                raise ValueError(error_msg)
-
-    return _completion_with_retry(
-        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
-    )
-
-
-def _strip_erroneous_leading_spaces(text: str) -> str:
-    """Strip erroneous leading spaces from text.
-
-    The PaLM API will sometimes erroneously return a single leading space in all
-    lines > 1. This function strips that space.
-    """
-    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
-    if has_leading_space:
-        return text.replace("\n ", "\n")
-    else:
-        return text
-
-
-class _BaseGoogleGenerativeAI(BaseModel):
-    """Base class for Google Generative AI LLMs"""
-
-    model: str = Field(
-        ...,
-        description="""The name of the model to use.
-Supported examples:
-    - gemini-pro
-    - models/text-bison-001""",
-    )
-    """Model name to use."""
-    google_api_key: Optional[SecretStr] = Field(
-        alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
-    )
-    credentials: Any = None
-    "The default custom credentials (google.auth.credentials.Credentials) to use "
-    "when making API calls. If not provided, credentials will be ascertained from "
-    "the GOOGLE_API_KEY envvar"
-    temperature: float = 0.7
-    """Run inference with this temperature. Must by in the closed interval
-    [0.0, 1.0]."""
-    top_p: Optional[float] = None
-    """Decode using nucleus sampling: consider the smallest set of tokens whose
-    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
-    top_k: Optional[int] = None
-    """Decode using top-k sampling: consider the set of top_k most probable tokens.
-    Must be positive."""
-    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
-    """Maximum number of tokens to include in a candidate. Must be greater than zero.
-    If unset, will default to 64."""
-    n: int = 1
-    """Number of chat completions to generate for each prompt. Note that the API may
-    not return the full n completions if duplicates are generated."""
-    max_retries: int = 6
-    """The maximum number of retries to make when generating."""
-
-    timeout: Optional[float] = None
-    """The maximum number of seconds to wait for a response."""
-
-    client_options: Optional[Dict] = Field(
-        default=None,
-        description=(
-            "A dictionary of client options to pass to the Google API client, "
-            "such as `api_endpoint`."
-        ),
-    )
-    transport: Optional[str] = Field(
-        default=None,
-        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
-    )
-    additional_headers: Optional[Dict[str, str]] = Field(
-        default=None,
-        description=(
-            "A key-value dictionary representing additional headers for the model call"
-        ),
-    )
-
-    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
-    """The default safety settings to use for all generations.
-
-        For example:
-
-            from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
-
-            safety_settings = {
-                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
-                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
-                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-            }
-    """  # noqa: E501
-
-    @property
-    def lc_secrets(self) -> Dict[str, str]:
-        return {"google_api_key": "GOOGLE_API_KEY"}
-
-    @property
-    def _model_family(self) -> str:
-        return GoogleModelFamily(self.model)
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "model": self.model,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "top_k": self.top_k,
-            "max_output_tokens": self.max_output_tokens,
-            "candidate_count": self.n,
-        }
-
-
-class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
-    """Google GenerativeAI models.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_google_genai import GoogleGenerativeAI
-            llm = GoogleGenerativeAI(model="gemini-pro")
-    """
-
-    client: Any = None  #: :meta private:
-    model_config = ConfigDict(
-        populate_by_name=True,
-    )
-
-    @model_validator(mode="after")
-    def validate_environment(self) -> Self:
-        """Validates params and passes them to google-generativeai package."""
-        if self.credentials:
-            genai.configure(
-                credentials=self.credentials,
-                transport=self.transport,
-                client_options=self.client_options,
-            )
-        else:
-            if isinstance(self.google_api_key, SecretStr):
-                google_api_key: Optional[str] = self.google_api_key.get_secret_value()
-            else:
-                google_api_key = self.google_api_key
-            genai.configure(
-                api_key=google_api_key,
-                transport=self.transport,
-                client_options=self.client_options,
-            )
-
-        model_name = self.model
-
-        safety_settings = self.safety_settings
-
-        if safety_settings and (
-            not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
-        ):
-            raise ValueError("Safety settings are only supported for Gemini models")
-
-        if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
-            self.client = genai.GenerativeModel(
-                model_name=model_name, safety_settings=safety_settings
-            )
-        else:
-            self.client = genai
-
-        if self.temperature is not None and not 0 <= self.temperature <= 1:
-            raise ValueError("temperature must be in the range [0.0, 1.0]")
-
-        if self.top_p is not None and not 0 <= self.top_p <= 1:
-            raise ValueError("top_p must be in the range [0.0, 1.0]")
-
-        if self.top_k is not None and self.top_k <= 0:
-            raise ValueError("top_k must be positive")
-
-        if self.max_output_tokens is not None and self.max_output_tokens <= 0:
-            raise ValueError("max_output_tokens must be greater than zero")
-
-        if self.timeout is not None and self.timeout <= 0:
-            raise ValueError("timeout must be greater than zero")
-
-        return self
-
-    def _get_ls_params(
-        self, stop: Optional[List[str]] = None, **kwargs: Any
-    ) -> LangSmithParams:
-        """Get standard params for tracing."""
-        ls_params = super()._get_ls_params(stop=stop, **kwargs)
-        ls_params["ls_provider"] = "google_genai"
-        if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
-            ls_params["ls_max_tokens"] = ls_max_tokens
-        return ls_params
-
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        generations: List[List[Generation]] = []
-        generation_config = {
-            "stop_sequences": stop,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "top_k": self.top_k,
-            "max_output_tokens": self.max_output_tokens,
-            "candidate_count": self.n,
-        }
-        for prompt in prompts:
-            if self._model_family == GoogleModelFamily.GEMINI:
-                res = _completion_with_retry(
-                    self,
-                    prompt=prompt,
-                    stream=False,
-                    is_gemini=True,
-                    run_manager=run_manager,
-                    generation_config=generation_config,
-                    safety_settings=kwargs.pop("safety_settings", None),
-                )
-                generation_info = None
-                if res.usage_metadata is not None:
-                    generation_info = {
-                        "usage_metadata": res.to_dict().get("usage_metadata")
-                    }
-
-                candidates = [
-                    "".join([p.text for p in c.content.parts]) for c in res.candidates
-                ]
-                generations.append(
-                    [
-                        Generation(text=c, generation_info=generation_info)
-                        for c in candidates
-                    ]
-                )
-            else:
-                res = _completion_with_retry(
-                    self,
-                    model=self.model,
-                    prompt=prompt,
-                    stream=False,
-                    is_gemini=False,
-                    run_manager=run_manager,
-                    **generation_config,
-                )
-                prompt_generations = []
-                for candidate in res.candidates:
-                    raw_text = candidate["output"]
-                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
-                    prompt_generations.append(Generation(text=stripped_text))
-                generations.append(prompt_generations)
-
-        return LLMResult(generations=generations)
-
-    def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        generation_config = {
-            "stop_sequences": stop,
-            "temperature": self.temperature,
-            "top_p": self.top_p,
-            "top_k": self.top_k,
-            "max_output_tokens": self.max_output_tokens,
-            "candidate_count": self.n,
-        }
-        generation_config = generation_config | kwargs.get("generation_config", {})
-
-        for stream_resp in _completion_with_retry(
-            self,
-            prompt,
-            stream=True,
-            is_gemini=True,
-            run_manager=run_manager,
-            generation_config=generation_config,
-            safety_settings=kwargs.pop("safety_settings", None),
-            **kwargs,
-        ):
-            chunk = GenerationChunk(text=stream_resp.text)
-            yield chunk
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    stream_resp.text,
-                    chunk=chunk,
-                    verbose=self.verbose,
-                )
-
-    @property
-    def _llm_type(self) -> str:
-        """Return type of llm."""
-        return "google_palm"
-
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        if self._model_family == GoogleModelFamily.GEMINI:
-            result = self.client.count_tokens(text)
-            token_count = result.total_tokens
-        else:
-            result = self.client.count_text_tokens(model=self.model, prompt=text)
-            token_count = result["token_count"]
-
-        return token_count