langchain-google-genai 2.0.10__py3-none-any.whl → 2.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic; see the registry's advisory for this release for more details.

@@ -1,7 +1,9 @@
1
1
  from importlib import metadata
2
- from typing import Optional, Tuple, TypedDict
2
+ from typing import Any, Dict, Optional, Tuple, TypedDict
3
3
 
4
4
  from google.api_core.gapic_v1.client_info import ClientInfo
5
+ from langchain_core.utils import secret_from_env
6
+ from pydantic import BaseModel, Field, SecretStr
5
7
 
6
8
  from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
7
9
 
@@ -12,6 +14,97 @@ class GoogleGenerativeAIError(Exception):
12
14
  """
13
15
 
14
16
 
17
+ class _BaseGoogleGenerativeAI(BaseModel):
18
+ """Base class for Google Generative AI LLMs"""
19
+
20
+ model: str = Field(
21
+ ...,
22
+ description="""The name of the model to use.
23
+ Supported examples:
24
+ - gemini-pro
25
+ - models/text-bison-001""",
26
+ )
27
+ """Model name to use."""
28
+ google_api_key: Optional[SecretStr] = Field(
29
+ alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
30
+ )
31
+ """Google AI API key.
32
+ If not specified will be read from env var ``GOOGLE_API_KEY``."""
33
+ credentials: Any = None
34
+ "The default custom credentials (google.auth.credentials.Credentials) to use "
35
+ "when making API calls. If not provided, credentials will be ascertained from "
36
+ "the GOOGLE_API_KEY envvar"
37
+ temperature: float = 0.7
38
+ """Run inference with this temperature. Must by in the closed interval
39
+ [0.0, 1.0]."""
40
+ top_p: Optional[float] = None
41
+ """Decode using nucleus sampling: consider the smallest set of tokens whose
42
+ probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
43
+ top_k: Optional[int] = None
44
+ """Decode using top-k sampling: consider the set of top_k most probable tokens.
45
+ Must be positive."""
46
+ max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
47
+ """Maximum number of tokens to include in a candidate. Must be greater than zero.
48
+ If unset, will default to 64."""
49
+ n: int = 1
50
+ """Number of chat completions to generate for each prompt. Note that the API may
51
+ not return the full n completions if duplicates are generated."""
52
+ max_retries: int = 6
53
+ """The maximum number of retries to make when generating."""
54
+
55
+ timeout: Optional[float] = None
56
+ """The maximum number of seconds to wait for a response."""
57
+
58
+ client_options: Optional[Dict] = Field(
59
+ default=None,
60
+ description=(
61
+ "A dictionary of client options to pass to the Google API client, "
62
+ "such as `api_endpoint`."
63
+ ),
64
+ )
65
+ transport: Optional[str] = Field(
66
+ default=None,
67
+ description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
68
+ )
69
+ additional_headers: Optional[Dict[str, str]] = Field(
70
+ default=None,
71
+ description=(
72
+ "A key-value dictionary representing additional headers for the model call"
73
+ ),
74
+ )
75
+
76
+ safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
77
+ """The default safety settings to use for all generations.
78
+
79
+ For example:
80
+
81
+ from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
82
+
83
+ safety_settings = {
84
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
85
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
86
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
87
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
88
+ }
89
+ """ # noqa: E501
90
+
91
+ @property
92
+ def lc_secrets(self) -> Dict[str, str]:
93
+ return {"google_api_key": "GOOGLE_API_KEY"}
94
+
95
+ @property
96
+ def _identifying_params(self) -> Dict[str, Any]:
97
+ """Get the identifying parameters."""
98
+ return {
99
+ "model": self.model,
100
+ "temperature": self.temperature,
101
+ "top_p": self.top_p,
102
+ "top_k": self.top_k,
103
+ "max_output_tokens": self.max_output_tokens,
104
+ "candidate_count": self.n,
105
+ }
106
+
107
+
15
108
  def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
16
109
  r"""Returns a custom user agent header.
17
110
 
@@ -7,7 +7,6 @@ import logging
7
7
  from typing import (
8
8
  Any,
9
9
  Callable,
10
- Collection,
11
10
  Dict,
12
11
  List,
13
12
  Literal,
@@ -22,7 +21,6 @@ from typing import (
22
21
  import google.ai.generativelanguage as glm
23
22
  import google.ai.generativelanguage_v1beta.types as gapic
24
23
  import proto # type: ignore[import]
25
- from google.generativeai.types.content_types import ToolDict # type: ignore[import]
26
24
  from langchain_core.tools import BaseTool
27
25
  from langchain_core.tools import tool as callable_as_lc_tool
28
26
  from langchain_core.utils.function_calling import (
@@ -59,38 +57,21 @@ _ALLOWED_SCHEMA_FIELDS.extend(
59
57
  _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)
60
58
 
61
59
 
62
- class _ToolDictLike(TypedDict):
63
- function_declarations: _FunctionDeclarationLikeList
64
-
65
-
66
- class _FunctionDeclarationDict(TypedDict):
67
- name: str
68
- description: str
69
- parameters: Dict[str, Collection[str]]
70
-
71
-
72
- class _ToolDict(TypedDict):
73
- function_declarations: Sequence[_FunctionDeclarationDict]
74
-
75
-
76
60
  # Info: This is a FunctionDeclaration(=fc).
77
61
  _FunctionDeclarationLike = Union[
78
62
  BaseTool, Type[BaseModel], gapic.FunctionDeclaration, Callable, Dict[str, Any]
79
63
  ]
80
64
 
81
- # Info: This mean one tool.
82
- _FunctionDeclarationLikeList = Sequence[_FunctionDeclarationLike]
65
+
66
+ class _ToolDict(TypedDict):
67
+ function_declarations: Sequence[_FunctionDeclarationLike]
83
68
 
84
69
 
85
70
  # Info: This means one tool=Sequence of FunctionDeclaration
86
71
  # The dict should be gapic.Tool like. {"function_declarations": [ { "name": ...}.
87
72
  # OpenAI like dict is not be accepted. {{'type': 'function', 'function': {'name': ...}
88
73
  _ToolsType = Union[
89
- gapic.Tool,
90
- ToolDict,
91
- _ToolDictLike,
92
- _FunctionDeclarationLikeList,
93
- _FunctionDeclarationLike,
74
+ gapic.Tool, _ToolDict, _FunctionDeclarationLike, Sequence[_FunctionDeclarationLike]
94
75
  ]
95
76
 
96
77
 
@@ -152,12 +133,12 @@ def convert_to_genai_function_declarations(
152
133
  gapic_tool = gapic.Tool()
153
134
  for tool in tools:
154
135
  if isinstance(tool, gapic.Tool):
155
- gapic_tool.function_declarations.extend(tool.function_declarations)
136
+ gapic_tool.function_declarations.extend(tool.function_declarations) # type: ignore[union-attr]
156
137
  elif isinstance(tool, dict) and "function_declarations" not in tool:
157
138
  fd = _format_to_gapic_function_declaration(tool)
158
139
  gapic_tool.function_declarations.append(fd)
159
140
  elif isinstance(tool, dict):
160
- function_declarations = cast(_ToolDictLike, tool)["function_declarations"]
141
+ function_declarations = cast(_ToolDict, tool)["function_declarations"]
161
142
  if not isinstance(function_declarations, collections.abc.Sequence):
162
143
  raise ValueError(
163
144
  "function_declarations should be a list"
@@ -170,7 +151,7 @@ def convert_to_genai_function_declarations(
170
151
  ]
171
152
  gapic_tool.function_declarations.extend(fds)
172
153
  else:
173
- fd = _format_to_gapic_function_declaration(tool)
154
+ fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type]
174
155
  gapic_tool.function_declarations.append(fd)
175
156
  return gapic_tool
176
157
 
@@ -375,6 +356,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
375
356
  )
376
357
  if _is_nullable_schema(schema):
377
358
  items["nullable"] = True
359
+ if "required" in schema:
360
+ items["required"] = schema["required"]
378
361
  else:
379
362
  # str
380
363
  items["type_"] = _get_type_from_schema({"type": schema})
@@ -35,6 +35,7 @@ from google.ai.generativelanguage_v1beta.types import (
35
35
  Content,
36
36
  FileData,
37
37
  FunctionCall,
38
+ FunctionDeclaration,
38
39
  FunctionResponse,
39
40
  GenerateContentRequest,
40
41
  GenerateContentResponse,
@@ -44,12 +45,8 @@ from google.ai.generativelanguage_v1beta.types import (
44
45
  ToolConfig,
45
46
  VideoMetadata,
46
47
  )
47
- from google.generativeai.caching import CachedContent # type: ignore[import]
48
- from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
49
- from google.generativeai.types import caching_types, content_types
50
- from google.generativeai.types.content_types import ( # type: ignore[import]
51
- FunctionDeclarationType,
52
- ToolDict,
48
+ from google.ai.generativelanguage_v1beta.types import (
49
+ Tool as GoogleTool,
53
50
  )
54
51
  from langchain_core.callbacks.manager import (
55
52
  AsyncCallbackManagerForLLMRun,
@@ -76,7 +73,7 @@ from langchain_core.output_parsers.openai_tools import (
76
73
  )
77
74
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
78
75
  from langchain_core.runnables import Runnable, RunnablePassthrough
79
- from langchain_core.utils import secret_from_env
76
+ from langchain_core.tools import BaseTool
80
77
  from langchain_core.utils.function_calling import convert_to_openai_tool
81
78
  from pydantic import (
82
79
  BaseModel,
@@ -97,24 +94,32 @@ from typing_extensions import Self
97
94
  from langchain_google_genai._common import (
98
95
  GoogleGenerativeAIError,
99
96
  SafetySettingDict,
97
+ _BaseGoogleGenerativeAI,
100
98
  get_client_info,
101
99
  )
102
100
  from langchain_google_genai._function_utils import (
103
101
  _tool_choice_to_tool_config,
104
102
  _ToolChoiceType,
105
103
  _ToolConfigDict,
104
+ _ToolDict,
106
105
  convert_to_genai_function_declarations,
107
106
  is_basemodel_subclass_safe,
108
107
  tool_to_dict,
109
108
  )
110
109
  from langchain_google_genai._image_utils import ImageBytesLoader
111
- from langchain_google_genai.llms import _BaseGoogleGenerativeAI
112
110
 
113
111
  from . import _genai_extension as genaix
114
112
 
115
113
  logger = logging.getLogger(__name__)
116
114
 
117
115
 
116
+ _FunctionDeclarationType = Union[
117
+ FunctionDeclaration,
118
+ dict[str, Any],
119
+ Callable[..., Any],
120
+ ]
121
+
122
+
118
123
  class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
119
124
  """
120
125
  Custom exception class for errors associated with the `Google GenAI` API.
@@ -786,11 +791,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
786
791
 
787
792
  client: Any = Field(default=None, exclude=True) #: :meta private:
788
793
  async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
789
- google_api_key: Optional[SecretStr] = Field(
790
- alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
791
- )
792
- """Google AI API key.
793
- If not specified will be read from env var ``GOOGLE_API_KEY``."""
794
794
  default_metadata: Sequence[Tuple[str, str]] = Field(
795
795
  default_factory=list
796
796
  ) #: :meta private:
@@ -938,8 +938,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
938
938
  stop: Optional[List[str]] = None,
939
939
  run_manager: Optional[CallbackManagerForLLMRun] = None,
940
940
  *,
941
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
942
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
941
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
942
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
943
943
  safety_settings: Optional[SafetySettingDict] = None,
944
944
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
945
945
  generation_config: Optional[Dict[str, Any]] = None,
@@ -972,8 +972,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
972
972
  stop: Optional[List[str]] = None,
973
973
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
974
974
  *,
975
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
976
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
975
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
976
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
977
977
  safety_settings: Optional[SafetySettingDict] = None,
978
978
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
979
979
  generation_config: Optional[Dict[str, Any]] = None,
@@ -1021,8 +1021,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1021
1021
  stop: Optional[List[str]] = None,
1022
1022
  run_manager: Optional[CallbackManagerForLLMRun] = None,
1023
1023
  *,
1024
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
1025
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
1024
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1025
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1026
1026
  safety_settings: Optional[SafetySettingDict] = None,
1027
1027
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1028
1028
  generation_config: Optional[Dict[str, Any]] = None,
@@ -1083,8 +1083,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1083
1083
  stop: Optional[List[str]] = None,
1084
1084
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
1085
1085
  *,
1086
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
1087
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
1086
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1087
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1088
1088
  safety_settings: Optional[SafetySettingDict] = None,
1089
1089
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1090
1090
  generation_config: Optional[Dict[str, Any]] = None,
@@ -1158,8 +1158,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1158
1158
  messages: List[BaseMessage],
1159
1159
  *,
1160
1160
  stop: Optional[List[str]] = None,
1161
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
1162
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
1161
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1162
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1163
1163
  safety_settings: Optional[SafetySettingDict] = None,
1164
1164
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1165
1165
  tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
@@ -1251,8 +1251,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1251
1251
  )
1252
1252
  else:
1253
1253
  parser = JsonOutputToolsParser()
1254
- tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None
1255
- llm = self.bind_tools([schema], tool_choice=tool_choice)
1254
+ tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None # type: ignore[arg-type]
1255
+ llm = self.bind_tools([schema], tool_choice=tool_choice) # type: ignore[list-item]
1256
1256
  if include_raw:
1257
1257
  parser_with_fallback = RunnablePassthrough.assign(
1258
1258
  parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
@@ -1266,7 +1266,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1266
1266
 
1267
1267
  def bind_tools(
1268
1268
  self,
1269
- tools: Sequence[Union[ToolDict, GoogleTool]],
1269
+ tools: Sequence[
1270
+ dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool
1271
+ ],
1270
1272
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1271
1273
  *,
1272
1274
  tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
@@ -1303,90 +1305,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1303
1305
  pass
1304
1306
  return self.bind(tools=formatted_tools, **kwargs)
1305
1307
 
1306
- def create_cached_content(
1307
- self,
1308
- contents: Union[List[BaseMessage], content_types.ContentsType],
1309
- *,
1310
- display_name: str | None = None,
1311
- tools: Union[ToolDict, GoogleTool, None] = None,
1312
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1313
- ttl: Optional[caching_types.TTLTypes] = None,
1314
- expire_time: Optional[caching_types.ExpireTimeTypes] = None,
1315
- ) -> str:
1316
- """
1317
-
1318
- Args:
1319
- display_name: The user-generated meaningful display name
1320
- of the cached content. `display_name` must be no
1321
- more than 128 unicode characters.
1322
- contents: Contents to cache.
1323
- tools: A list of `Tools` the model may use to generate response.
1324
- tool_choice: Which tool to require the model to call.
1325
- ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
1326
- `ttl` and `expire_time` are exclusive arguments.
1327
- expire_time: Expiration time for cached resource.
1328
- `ttl` and `expire_time` are exclusive arguments.
1329
- """
1330
- system: Optional[content_types.ContentType] = None
1331
- genai_contents: list = []
1332
- if all(isinstance(c, BaseMessage) for c in contents):
1333
- system, genai_contents = _parse_chat_history(
1334
- contents,
1335
- convert_system_message_to_human=self.convert_system_message_to_human,
1336
- )
1337
- elif any(isinstance(c, BaseMessage) for c in contents):
1338
- raise ValueError(
1339
- f"'contents' must either be a list of "
1340
- f"langchain_core.messages.BaseMessage or a list "
1341
- f"google.generativeai.types.content_types.ContentType, but not a mix "
1342
- f"of the two. Received {contents}"
1343
- )
1344
- else:
1345
- for content in contents:
1346
- if hasattr(content, "role") and content.role == "system":
1347
- if system is not None:
1348
- warnings.warn(
1349
- "Received multiple pieces of content with role 'system'. "
1350
- "Should only be one set of system instructions. Ignoring "
1351
- "all but the first 'system' content."
1352
- )
1353
- else:
1354
- system = content
1355
- elif isinstance(content, dict) and content.get("role") == "system":
1356
- if system is not None:
1357
- warnings.warn(
1358
- "Received multiple pieces of content with role 'system'. "
1359
- "Should only be one set of system instructions. Ignoring "
1360
- "all but the first 'system' content."
1361
- )
1362
- else:
1363
- system = content
1364
- else:
1365
- genai_contents.append(content)
1366
- if tools:
1367
- genai_tools = [convert_to_genai_function_declarations(tools)]
1368
- else:
1369
- genai_tools = None
1370
- if tool_choice and genai_tools:
1371
- all_names = [f.name for t in genai_tools for f in t.function_declarations]
1372
- tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
1373
- genai_tool_config = ToolConfig(
1374
- function_calling_config=tool_config["function_calling_config"]
1375
- )
1376
- else:
1377
- genai_tool_config = None
1378
- cached_content = CachedContent.create(
1379
- model=self.model,
1380
- system_instruction=system,
1381
- contents=genai_contents,
1382
- display_name=display_name,
1383
- tools=genai_tools,
1384
- tool_config=genai_tool_config,
1385
- ttl=ttl,
1386
- expire_time=expire_time,
1387
- )
1388
- return cached_content.name
1389
-
1390
1308
  @property
1391
1309
  def _supports_tool_choice(self) -> bool:
1392
1310
  return (
@@ -1397,7 +1315,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1397
1315
 
1398
1316
 
1399
1317
  def _get_tool_name(
1400
- tool: Union[ToolDict, GoogleTool],
1318
+ tool: Union[_ToolDict, GoogleTool],
1401
1319
  ) -> str:
1402
1320
  genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
1403
1321
  return [f["name"] for f in genai_tool["function_declarations"]][0] # type: ignore[index]
@@ -1,208 +1,21 @@
1
1
  from __future__ import annotations
2
2
 
3
- from enum import Enum, auto
4
- from typing import Any, Callable, Dict, Iterator, List, Optional, Union
3
+ from typing import Any, Iterator, List, Optional
5
4
 
6
- import google.api_core
7
- import google.generativeai as genai # type: ignore[import]
8
5
  from langchain_core.callbacks import (
9
- AsyncCallbackManagerForLLMRun,
10
6
  CallbackManagerForLLMRun,
11
7
  )
12
- from langchain_core.language_models import LangSmithParams, LanguageModelInput
13
- from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
8
+ from langchain_core.language_models import LangSmithParams
9
+ from langchain_core.language_models.llms import BaseLLM
10
+ from langchain_core.messages import HumanMessage
14
11
  from langchain_core.outputs import Generation, GenerationChunk, LLMResult
15
- from langchain_core.utils import secret_from_env
16
- from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
12
+ from pydantic import ConfigDict, model_validator
17
13
  from typing_extensions import Self
18
14
 
19
- from langchain_google_genai._enums import (
20
- HarmBlockThreshold,
21
- HarmCategory,
15
+ from langchain_google_genai._common import (
16
+ _BaseGoogleGenerativeAI,
22
17
  )
23
-
24
-
25
- class GoogleModelFamily(str, Enum):
26
- GEMINI = auto()
27
- PALM = auto()
28
-
29
- @classmethod
30
- def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
31
- if "gemini" in value.lower():
32
- return GoogleModelFamily.GEMINI
33
- elif "text-bison" in value.lower():
34
- return GoogleModelFamily.PALM
35
- return None
36
-
37
-
38
- def _create_retry_decorator(
39
- llm: BaseLLM,
40
- *,
41
- max_retries: int = 1,
42
- run_manager: Optional[
43
- Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
44
- ] = None,
45
- ) -> Callable[[Any], Any]:
46
- """Creates a retry decorator for Vertex / Palm LLMs."""
47
-
48
- errors = [
49
- google.api_core.exceptions.ResourceExhausted,
50
- google.api_core.exceptions.ServiceUnavailable,
51
- google.api_core.exceptions.Aborted,
52
- google.api_core.exceptions.DeadlineExceeded,
53
- google.api_core.exceptions.GoogleAPIError,
54
- ]
55
- decorator = create_base_retry_decorator(
56
- error_types=errors, max_retries=max_retries, run_manager=run_manager
57
- )
58
- return decorator
59
-
60
-
61
- def _completion_with_retry(
62
- llm: GoogleGenerativeAI,
63
- prompt: LanguageModelInput,
64
- is_gemini: bool = False,
65
- stream: bool = False,
66
- run_manager: Optional[CallbackManagerForLLMRun] = None,
67
- **kwargs: Any,
68
- ) -> Any:
69
- """Use tenacity to retry the completion call."""
70
- retry_decorator = _create_retry_decorator(
71
- llm, max_retries=llm.max_retries, run_manager=run_manager
72
- )
73
-
74
- @retry_decorator
75
- def _completion_with_retry(
76
- prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
77
- ) -> Any:
78
- generation_config = kwargs.get("generation_config", {})
79
- error_msg = (
80
- "Your location is not supported by google-generativeai at the moment. "
81
- "Try to use VertexAI LLM from langchain_google_vertexai"
82
- )
83
- try:
84
- if is_gemini:
85
- return llm.client.generate_content(
86
- contents=prompt,
87
- stream=stream,
88
- generation_config=generation_config,
89
- safety_settings=kwargs.pop("safety_settings", None),
90
- request_options={"timeout": llm.timeout} if llm.timeout else None,
91
- )
92
- return llm.client.generate_text(prompt=prompt, **kwargs)
93
- except google.api_core.exceptions.FailedPrecondition as exc:
94
- if "location is not supported" in exc.message:
95
- raise ValueError(error_msg)
96
-
97
- return _completion_with_retry(
98
- prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
99
- )
100
-
101
-
102
- def _strip_erroneous_leading_spaces(text: str) -> str:
103
- """Strip erroneous leading spaces from text.
104
-
105
- The PaLM API will sometimes erroneously return a single leading space in all
106
- lines > 1. This function strips that space.
107
- """
108
- has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
109
- if has_leading_space:
110
- return text.replace("\n ", "\n")
111
- else:
112
- return text
113
-
114
-
115
- class _BaseGoogleGenerativeAI(BaseModel):
116
- """Base class for Google Generative AI LLMs"""
117
-
118
- model: str = Field(
119
- ...,
120
- description="""The name of the model to use.
121
- Supported examples:
122
- - gemini-pro
123
- - models/text-bison-001""",
124
- )
125
- """Model name to use."""
126
- google_api_key: Optional[SecretStr] = Field(
127
- alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
128
- )
129
- credentials: Any = None
130
- "The default custom credentials (google.auth.credentials.Credentials) to use "
131
- "when making API calls. If not provided, credentials will be ascertained from "
132
- "the GOOGLE_API_KEY envvar"
133
- temperature: float = 0.7
134
- """Run inference with this temperature. Must by in the closed interval
135
- [0.0, 1.0]."""
136
- top_p: Optional[float] = None
137
- """Decode using nucleus sampling: consider the smallest set of tokens whose
138
- probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
139
- top_k: Optional[int] = None
140
- """Decode using top-k sampling: consider the set of top_k most probable tokens.
141
- Must be positive."""
142
- max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
143
- """Maximum number of tokens to include in a candidate. Must be greater than zero.
144
- If unset, will default to 64."""
145
- n: int = 1
146
- """Number of chat completions to generate for each prompt. Note that the API may
147
- not return the full n completions if duplicates are generated."""
148
- max_retries: int = 6
149
- """The maximum number of retries to make when generating."""
150
-
151
- timeout: Optional[float] = None
152
- """The maximum number of seconds to wait for a response."""
153
-
154
- client_options: Optional[Dict] = Field(
155
- default=None,
156
- description=(
157
- "A dictionary of client options to pass to the Google API client, "
158
- "such as `api_endpoint`."
159
- ),
160
- )
161
- transport: Optional[str] = Field(
162
- default=None,
163
- description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
164
- )
165
- additional_headers: Optional[Dict[str, str]] = Field(
166
- default=None,
167
- description=(
168
- "A key-value dictionary representing additional headers for the model call"
169
- ),
170
- )
171
-
172
- safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
173
- """The default safety settings to use for all generations.
174
-
175
- For example:
176
-
177
- from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
178
-
179
- safety_settings = {
180
- HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
181
- HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
182
- HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
183
- HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
184
- }
185
- """ # noqa: E501
186
-
187
- @property
188
- def lc_secrets(self) -> Dict[str, str]:
189
- return {"google_api_key": "GOOGLE_API_KEY"}
190
-
191
- @property
192
- def _model_family(self) -> str:
193
- return GoogleModelFamily(self.model)
194
-
195
- @property
196
- def _identifying_params(self) -> Dict[str, Any]:
197
- """Get the identifying parameters."""
198
- return {
199
- "model": self.model,
200
- "temperature": self.temperature,
201
- "top_p": self.top_p,
202
- "top_k": self.top_k,
203
- "max_output_tokens": self.max_output_tokens,
204
- "candidate_count": self.n,
205
- }
18
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
206
19
 
207
20
 
208
21
  class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
@@ -223,53 +36,21 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
223
36
  @model_validator(mode="after")
224
37
  def validate_environment(self) -> Self:
225
38
  """Validates params and passes them to google-generativeai package."""
226
- if self.credentials:
227
- genai.configure(
228
- credentials=self.credentials,
229
- transport=self.transport,
230
- client_options=self.client_options,
231
- )
232
- else:
233
- if isinstance(self.google_api_key, SecretStr):
234
- google_api_key: Optional[str] = self.google_api_key.get_secret_value()
235
- else:
236
- google_api_key = self.google_api_key
237
- genai.configure(
238
- api_key=google_api_key,
239
- transport=self.transport,
240
- client_options=self.client_options,
241
- )
242
-
243
- model_name = self.model
244
-
245
- safety_settings = self.safety_settings
246
-
247
- if safety_settings and (
248
- not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
249
- ):
250
- raise ValueError("Safety settings are only supported for Gemini models")
251
-
252
- if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
253
- self.client = genai.GenerativeModel(
254
- model_name=model_name, safety_settings=safety_settings
255
- )
256
- else:
257
- self.client = genai
258
-
259
- if self.temperature is not None and not 0 <= self.temperature <= 1:
260
- raise ValueError("temperature must be in the range [0.0, 1.0]")
261
-
262
- if self.top_p is not None and not 0 <= self.top_p <= 1:
263
- raise ValueError("top_p must be in the range [0.0, 1.0]")
264
-
265
- if self.top_k is not None and self.top_k <= 0:
266
- raise ValueError("top_k must be positive")
267
39
 
268
- if self.max_output_tokens is not None and self.max_output_tokens <= 0:
269
- raise ValueError("max_output_tokens must be greater than zero")
270
-
271
- if self.timeout is not None and self.timeout <= 0:
272
- raise ValueError("timeout must be greater than zero")
40
+ self.client = ChatGoogleGenerativeAI(
41
+ api_key=self.google_api_key,
42
+ credentials=self.credentials,
43
+ temperature=self.temperature,
44
+ top_p=self.top_p,
45
+ top_k=self.top_k,
46
+ max_tokens=self.max_output_tokens,
47
+ timeout=self.timeout,
48
+ model=self.model,
49
+ client_options=self.client_options,
50
+ transport=self.transport,
51
+ additional_headers=self.additional_headers,
52
+ safety_settings=self.safety_settings,
53
+ )
273
54
 
274
55
  return self
275
56
 
@@ -290,58 +71,26 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
290
71
  run_manager: Optional[CallbackManagerForLLMRun] = None,
291
72
  **kwargs: Any,
292
73
  ) -> LLMResult:
293
- generations: List[List[Generation]] = []
294
- generation_config = {
295
- "stop_sequences": stop,
296
- "temperature": self.temperature,
297
- "top_p": self.top_p,
298
- "top_k": self.top_k,
299
- "max_output_tokens": self.max_output_tokens,
300
- "candidate_count": self.n,
301
- }
74
+ generations = []
302
75
  for prompt in prompts:
303
- if self._model_family == GoogleModelFamily.GEMINI:
304
- res = _completion_with_retry(
305
- self,
306
- prompt=prompt,
307
- stream=False,
308
- is_gemini=True,
309
- run_manager=run_manager,
310
- generation_config=generation_config,
311
- safety_settings=kwargs.pop("safety_settings", None),
312
- )
313
- generation_info = None
314
- if res.usage_metadata is not None:
315
- generation_info = {
316
- "usage_metadata": res.to_dict().get("usage_metadata")
317
- }
318
-
319
- candidates = [
320
- "".join([p.text for p in c.content.parts]) for c in res.candidates
76
+ chat_result = self.client._generate(
77
+ [HumanMessage(content=prompt)],
78
+ stop=stop,
79
+ run_manager=run_manager,
80
+ **kwargs,
81
+ )
82
+ generations.append(
83
+ [
84
+ Generation(
85
+ text=g.message.content,
86
+ generation_info={
87
+ **g.generation_info,
88
+ **{"usage_metadata": g.message.usage_metadata},
89
+ },
90
+ )
91
+ for g in chat_result.generations
321
92
  ]
322
- generations.append(
323
- [
324
- Generation(text=c, generation_info=generation_info)
325
- for c in candidates
326
- ]
327
- )
328
- else:
329
- res = _completion_with_retry(
330
- self,
331
- model=self.model,
332
- prompt=prompt,
333
- stream=False,
334
- is_gemini=False,
335
- run_manager=run_manager,
336
- **generation_config,
337
- )
338
- prompt_generations = []
339
- for candidate in res.candidates:
340
- raw_text = candidate["output"]
341
- stripped_text = _strip_erroneous_leading_spaces(raw_text)
342
- prompt_generations.append(Generation(text=stripped_text))
343
- generations.append(prompt_generations)
344
-
93
+ )
345
94
  return LLMResult(generations=generations)
346
95
 
347
96
  def _stream(
@@ -351,31 +100,17 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
351
100
  run_manager: Optional[CallbackManagerForLLMRun] = None,
352
101
  **kwargs: Any,
353
102
  ) -> Iterator[GenerationChunk]:
354
- generation_config = {
355
- "stop_sequences": stop,
356
- "temperature": self.temperature,
357
- "top_p": self.top_p,
358
- "top_k": self.top_k,
359
- "max_output_tokens": self.max_output_tokens,
360
- "candidate_count": self.n,
361
- }
362
- generation_config = generation_config | kwargs.get("generation_config", {})
363
-
364
- for stream_resp in _completion_with_retry(
365
- self,
366
- prompt,
367
- stream=True,
368
- is_gemini=True,
103
+ for stream_chunk in self.client._stream(
104
+ [HumanMessage(content=prompt)],
105
+ stop=stop,
369
106
  run_manager=run_manager,
370
- generation_config=generation_config,
371
- safety_settings=kwargs.pop("safety_settings", None),
372
107
  **kwargs,
373
108
  ):
374
- chunk = GenerationChunk(text=stream_resp.text)
109
+ chunk = GenerationChunk(text=stream_chunk.message.content)
375
110
  yield chunk
376
111
  if run_manager:
377
112
  run_manager.on_llm_new_token(
378
- stream_resp.text,
113
+ chunk.text,
379
114
  chunk=chunk,
380
115
  verbose=self.verbose,
381
116
  )
@@ -383,7 +118,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
383
118
  @property
384
119
  def _llm_type(self) -> str:
385
120
  """Return type of llm."""
386
- return "google_palm"
121
+ return "google_gemini"
387
122
 
388
123
  def get_num_tokens(self, text: str) -> int:
389
124
  """Get the number of tokens present in the text.
@@ -396,11 +131,4 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
396
131
  Returns:
397
132
  The integer number of tokens in the text.
398
133
  """
399
- if self._model_family == GoogleModelFamily.GEMINI:
400
- result = self.client.count_tokens(text)
401
- token_count = result.total_tokens
402
- else:
403
- result = self.client.count_text_tokens(model=self.model, prompt=text)
404
- token_count = result["token_count"]
405
-
406
- return token_count
134
+ return self.client.get_num_tokens(text)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langchain-google-genai
3
- Version: 2.0.10
3
+ Version: 2.0.11
4
4
  Summary: An integration package connecting Google's genai package and LangChain
5
5
  Home-page: https://github.com/langchain-ai/langchain-google
6
6
  License: MIT
@@ -12,7 +12,7 @@ Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Requires-Dist: filetype (>=1.2.0,<2.0.0)
15
- Requires-Dist: google-generativeai (>=0.8.0,<0.9.0)
15
+ Requires-Dist: google-ai-generativelanguage (>=0.6.16,<0.7.0)
16
16
  Requires-Dist: langchain-core (>=0.3.37,<0.4.0)
17
17
  Requires-Dist: pydantic (>=2,<3)
18
18
  Project-URL: Repository, https://github.com/langchain-ai/langchain-google
@@ -1,16 +1,16 @@
1
1
  langchain_google_genai/__init__.py,sha256=Oji-S2KYWrku1wyQEskY84IOfY8MfRhujjJ4d7hbsk4,2758
2
- langchain_google_genai/_common.py,sha256=ASlwE8hEbvOm55BVF_D4rf2nl7RYsnpsi5xbM6DW3Cc,1576
2
+ langchain_google_genai/_common.py,sha256=KaNewDLkaQtDEgTMc631kjOKqKl1QoUKqD9iOIXNf-0,5275
3
3
  langchain_google_genai/_enums.py,sha256=KLPmxS1K83K4HjBIXFaXoL_sFEOv8Hq-2B2PDMKyDgo,197
4
- langchain_google_genai/_function_utils.py,sha256=c0bYzUcWyDnaYQi5tPtBxl7KGV4FswzSb3ywu8tD6yI,18036
4
+ langchain_google_genai/_function_utils.py,sha256=FPZ4CxI4iTFco_W6oWNCNV_lgNbCLRbj7Gf-D0zeTCY,17736
5
5
  langchain_google_genai/_genai_extension.py,sha256=81a4ly5ZHlqMf37uJfdB8K41qE6J5ujLnbUypIfFf2o,20775
6
6
  langchain_google_genai/_image_utils.py,sha256=tPrQyMvVmO8xkuow1SvA91omxUEv9ZUy1EMHNGjMAKY,5202
7
- langchain_google_genai/chat_models.py,sha256=F36_mMwLgnsQIEDJomKLuF4QdXdjkatXR5Ut-nMEvRA,55022
7
+ langchain_google_genai/chat_models.py,sha256=hRCxo2eagOnz7suKPiaQO6W1XyBxKr1UCy9AeK71xRs,50920
8
8
  langchain_google_genai/embeddings.py,sha256=jQRWPXD9twXoVBlXJQG7Duz0fb8UC0kgRzzwAmW3Dic,10146
9
9
  langchain_google_genai/genai_aqa.py,sha256=qB6h3-BSXqe0YLR3eeVllYzmNKK6ofI6xJLdBahUVZo,4300
10
10
  langchain_google_genai/google_vector_store.py,sha256=4wvhIiOmc3Fo046FyafPmT9NBCLek-9bgluvuTfrbpQ,16148
11
- langchain_google_genai/llms.py,sha256=EPUgkz5aqKOyKbztT7br8w60Uo5D_X_bF5qP-zd6iLs,14593
11
+ langchain_google_genai/llms.py,sha256=QNPitkORf86w8WQpTbjuPFCQFkB-qKRMW2phhRBwAEA,4318
12
12
  langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- langchain_google_genai-2.0.10.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
14
- langchain_google_genai-2.0.10.dist-info/METADATA,sha256=2VjXxw5v4_8anWbUPouX2Y3yjG8JmBk9mKTJwIpvEkw,3595
15
- langchain_google_genai-2.0.10.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
16
- langchain_google_genai-2.0.10.dist-info/RECORD,,
13
+ langchain_google_genai-2.0.11.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
14
+ langchain_google_genai-2.0.11.dist-info/METADATA,sha256=QQSJFXoI4IeDHVyAFT5Xtic6aJ5cd_kY1qCP-B6NX1c,3605
15
+ langchain_google_genai-2.0.11.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
16
+ langchain_google_genai-2.0.11.dist-info/RECORD,,