langchain-google-genai 2.0.10 (tar.gz) → 2.0.11 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic; see the package registry's advisory page for more details.

Files changed (18)
  1. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/PKG-INFO +2 -2
  2. langchain_google_genai-2.0.11/langchain_google_genai/_common.py +145 -0
  3. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/_function_utils.py +9 -26
  4. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/chat_models.py +29 -111
  5. langchain_google_genai-2.0.11/langchain_google_genai/llms.py +134 -0
  6. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/pyproject.toml +2 -2
  7. langchain_google_genai-2.0.10/langchain_google_genai/_common.py +0 -52
  8. langchain_google_genai-2.0.10/langchain_google_genai/llms.py +0 -406
  9. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/LICENSE +0 -0
  10. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/README.md +0 -0
  11. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/__init__.py +0 -0
  12. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/_enums.py +0 -0
  13. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/_genai_extension.py +0 -0
  14. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/_image_utils.py +0 -0
  15. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/embeddings.py +0 -0
  16. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/genai_aqa.py +0 -0
  17. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/google_vector_store.py +0 -0
  18. {langchain_google_genai-2.0.10 → langchain_google_genai-2.0.11}/langchain_google_genai/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langchain-google-genai
3
- Version: 2.0.10
3
+ Version: 2.0.11
4
4
  Summary: An integration package connecting Google's genai package and LangChain
5
5
  Home-page: https://github.com/langchain-ai/langchain-google
6
6
  License: MIT
@@ -12,7 +12,7 @@ Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Requires-Dist: filetype (>=1.2.0,<2.0.0)
15
- Requires-Dist: google-generativeai (>=0.8.0,<0.9.0)
15
+ Requires-Dist: google-ai-generativelanguage (>=0.6.16,<0.7.0)
16
16
  Requires-Dist: langchain-core (>=0.3.37,<0.4.0)
17
17
  Requires-Dist: pydantic (>=2,<3)
18
18
  Project-URL: Repository, https://github.com/langchain-ai/langchain-google
@@ -0,0 +1,145 @@
1
+ from importlib import metadata
2
+ from typing import Any, Dict, Optional, Tuple, TypedDict
3
+
4
+ from google.api_core.gapic_v1.client_info import ClientInfo
5
+ from langchain_core.utils import secret_from_env
6
+ from pydantic import BaseModel, Field, SecretStr
7
+
8
+ from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
9
+
10
+
11
+ class GoogleGenerativeAIError(Exception):
12
+ """
13
+ Custom exception class for errors associated with the `Google GenAI` API.
14
+ """
15
+
16
+
17
+ class _BaseGoogleGenerativeAI(BaseModel):
18
+ """Base class for Google Generative AI LLMs"""
19
+
20
+ model: str = Field(
21
+ ...,
22
+ description="""The name of the model to use.
23
+ Supported examples:
24
+ - gemini-pro
25
+ - models/text-bison-001""",
26
+ )
27
+ """Model name to use."""
28
+ google_api_key: Optional[SecretStr] = Field(
29
+ alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
30
+ )
31
+ """Google AI API key.
32
+ If not specified will be read from env var ``GOOGLE_API_KEY``."""
33
+ credentials: Any = None
34
+ "The default custom credentials (google.auth.credentials.Credentials) to use "
35
+ "when making API calls. If not provided, credentials will be ascertained from "
36
+ "the GOOGLE_API_KEY envvar"
37
+ temperature: float = 0.7
38
+ """Run inference with this temperature. Must by in the closed interval
39
+ [0.0, 1.0]."""
40
+ top_p: Optional[float] = None
41
+ """Decode using nucleus sampling: consider the smallest set of tokens whose
42
+ probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
43
+ top_k: Optional[int] = None
44
+ """Decode using top-k sampling: consider the set of top_k most probable tokens.
45
+ Must be positive."""
46
+ max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
47
+ """Maximum number of tokens to include in a candidate. Must be greater than zero.
48
+ If unset, will default to 64."""
49
+ n: int = 1
50
+ """Number of chat completions to generate for each prompt. Note that the API may
51
+ not return the full n completions if duplicates are generated."""
52
+ max_retries: int = 6
53
+ """The maximum number of retries to make when generating."""
54
+
55
+ timeout: Optional[float] = None
56
+ """The maximum number of seconds to wait for a response."""
57
+
58
+ client_options: Optional[Dict] = Field(
59
+ default=None,
60
+ description=(
61
+ "A dictionary of client options to pass to the Google API client, "
62
+ "such as `api_endpoint`."
63
+ ),
64
+ )
65
+ transport: Optional[str] = Field(
66
+ default=None,
67
+ description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
68
+ )
69
+ additional_headers: Optional[Dict[str, str]] = Field(
70
+ default=None,
71
+ description=(
72
+ "A key-value dictionary representing additional headers for the model call"
73
+ ),
74
+ )
75
+
76
+ safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
77
+ """The default safety settings to use for all generations.
78
+
79
+ For example:
80
+
81
+ from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
82
+
83
+ safety_settings = {
84
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
85
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
86
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
87
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
88
+ }
89
+ """ # noqa: E501
90
+
91
+ @property
92
+ def lc_secrets(self) -> Dict[str, str]:
93
+ return {"google_api_key": "GOOGLE_API_KEY"}
94
+
95
+ @property
96
+ def _identifying_params(self) -> Dict[str, Any]:
97
+ """Get the identifying parameters."""
98
+ return {
99
+ "model": self.model,
100
+ "temperature": self.temperature,
101
+ "top_p": self.top_p,
102
+ "top_k": self.top_k,
103
+ "max_output_tokens": self.max_output_tokens,
104
+ "candidate_count": self.n,
105
+ }
106
+
107
+
108
+ def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
109
+ r"""Returns a custom user agent header.
110
+
111
+ Args:
112
+ module (Optional[str]):
113
+ Optional. The module for a custom user agent header.
114
+ Returns:
115
+ Tuple[str, str]
116
+ """
117
+ try:
118
+ langchain_version = metadata.version("langchain-google-genai")
119
+ except metadata.PackageNotFoundError:
120
+ langchain_version = "0.0.0"
121
+ client_library_version = (
122
+ f"{langchain_version}-{module}" if module else langchain_version
123
+ )
124
+ return client_library_version, f"langchain-google-genai/{client_library_version}"
125
+
126
+
127
+ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
128
+ r"""Returns a client info object with a custom user agent header.
129
+
130
+ Args:
131
+ module (Optional[str]):
132
+ Optional. The module for a custom user agent header.
133
+ Returns:
134
+ google.api_core.gapic_v1.client_info.ClientInfo
135
+ """
136
+ client_library_version, user_agent = get_user_agent(module)
137
+ return ClientInfo(
138
+ client_library_version=client_library_version,
139
+ user_agent=user_agent,
140
+ )
141
+
142
+
143
+ class SafetySettingDict(TypedDict):
144
+ category: HarmCategory
145
+ threshold: HarmBlockThreshold
@@ -7,7 +7,6 @@ import logging
7
7
  from typing import (
8
8
  Any,
9
9
  Callable,
10
- Collection,
11
10
  Dict,
12
11
  List,
13
12
  Literal,
@@ -22,7 +21,6 @@ from typing import (
22
21
  import google.ai.generativelanguage as glm
23
22
  import google.ai.generativelanguage_v1beta.types as gapic
24
23
  import proto # type: ignore[import]
25
- from google.generativeai.types.content_types import ToolDict # type: ignore[import]
26
24
  from langchain_core.tools import BaseTool
27
25
  from langchain_core.tools import tool as callable_as_lc_tool
28
26
  from langchain_core.utils.function_calling import (
@@ -59,38 +57,21 @@ _ALLOWED_SCHEMA_FIELDS.extend(
59
57
  _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)
60
58
 
61
59
 
62
- class _ToolDictLike(TypedDict):
63
- function_declarations: _FunctionDeclarationLikeList
64
-
65
-
66
- class _FunctionDeclarationDict(TypedDict):
67
- name: str
68
- description: str
69
- parameters: Dict[str, Collection[str]]
70
-
71
-
72
- class _ToolDict(TypedDict):
73
- function_declarations: Sequence[_FunctionDeclarationDict]
74
-
75
-
76
60
  # Info: This is a FunctionDeclaration(=fc).
77
61
  _FunctionDeclarationLike = Union[
78
62
  BaseTool, Type[BaseModel], gapic.FunctionDeclaration, Callable, Dict[str, Any]
79
63
  ]
80
64
 
81
- # Info: This mean one tool.
82
- _FunctionDeclarationLikeList = Sequence[_FunctionDeclarationLike]
65
+
66
+ class _ToolDict(TypedDict):
67
+ function_declarations: Sequence[_FunctionDeclarationLike]
83
68
 
84
69
 
85
70
  # Info: This means one tool=Sequence of FunctionDeclaration
86
71
  # The dict should be gapic.Tool like. {"function_declarations": [ { "name": ...}.
87
72
  # OpenAI like dict is not be accepted. {{'type': 'function', 'function': {'name': ...}
88
73
  _ToolsType = Union[
89
- gapic.Tool,
90
- ToolDict,
91
- _ToolDictLike,
92
- _FunctionDeclarationLikeList,
93
- _FunctionDeclarationLike,
74
+ gapic.Tool, _ToolDict, _FunctionDeclarationLike, Sequence[_FunctionDeclarationLike]
94
75
  ]
95
76
 
96
77
 
@@ -152,12 +133,12 @@ def convert_to_genai_function_declarations(
152
133
  gapic_tool = gapic.Tool()
153
134
  for tool in tools:
154
135
  if isinstance(tool, gapic.Tool):
155
- gapic_tool.function_declarations.extend(tool.function_declarations)
136
+ gapic_tool.function_declarations.extend(tool.function_declarations) # type: ignore[union-attr]
156
137
  elif isinstance(tool, dict) and "function_declarations" not in tool:
157
138
  fd = _format_to_gapic_function_declaration(tool)
158
139
  gapic_tool.function_declarations.append(fd)
159
140
  elif isinstance(tool, dict):
160
- function_declarations = cast(_ToolDictLike, tool)["function_declarations"]
141
+ function_declarations = cast(_ToolDict, tool)["function_declarations"]
161
142
  if not isinstance(function_declarations, collections.abc.Sequence):
162
143
  raise ValueError(
163
144
  "function_declarations should be a list"
@@ -170,7 +151,7 @@ def convert_to_genai_function_declarations(
170
151
  ]
171
152
  gapic_tool.function_declarations.extend(fds)
172
153
  else:
173
- fd = _format_to_gapic_function_declaration(tool)
154
+ fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type]
174
155
  gapic_tool.function_declarations.append(fd)
175
156
  return gapic_tool
176
157
 
@@ -375,6 +356,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
375
356
  )
376
357
  if _is_nullable_schema(schema):
377
358
  items["nullable"] = True
359
+ if "required" in schema:
360
+ items["required"] = schema["required"]
378
361
  else:
379
362
  # str
380
363
  items["type_"] = _get_type_from_schema({"type": schema})
@@ -35,6 +35,7 @@ from google.ai.generativelanguage_v1beta.types import (
35
35
  Content,
36
36
  FileData,
37
37
  FunctionCall,
38
+ FunctionDeclaration,
38
39
  FunctionResponse,
39
40
  GenerateContentRequest,
40
41
  GenerateContentResponse,
@@ -44,12 +45,8 @@ from google.ai.generativelanguage_v1beta.types import (
44
45
  ToolConfig,
45
46
  VideoMetadata,
46
47
  )
47
- from google.generativeai.caching import CachedContent # type: ignore[import]
48
- from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
49
- from google.generativeai.types import caching_types, content_types
50
- from google.generativeai.types.content_types import ( # type: ignore[import]
51
- FunctionDeclarationType,
52
- ToolDict,
48
+ from google.ai.generativelanguage_v1beta.types import (
49
+ Tool as GoogleTool,
53
50
  )
54
51
  from langchain_core.callbacks.manager import (
55
52
  AsyncCallbackManagerForLLMRun,
@@ -76,7 +73,7 @@ from langchain_core.output_parsers.openai_tools import (
76
73
  )
77
74
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
78
75
  from langchain_core.runnables import Runnable, RunnablePassthrough
79
- from langchain_core.utils import secret_from_env
76
+ from langchain_core.tools import BaseTool
80
77
  from langchain_core.utils.function_calling import convert_to_openai_tool
81
78
  from pydantic import (
82
79
  BaseModel,
@@ -97,24 +94,32 @@ from typing_extensions import Self
97
94
  from langchain_google_genai._common import (
98
95
  GoogleGenerativeAIError,
99
96
  SafetySettingDict,
97
+ _BaseGoogleGenerativeAI,
100
98
  get_client_info,
101
99
  )
102
100
  from langchain_google_genai._function_utils import (
103
101
  _tool_choice_to_tool_config,
104
102
  _ToolChoiceType,
105
103
  _ToolConfigDict,
104
+ _ToolDict,
106
105
  convert_to_genai_function_declarations,
107
106
  is_basemodel_subclass_safe,
108
107
  tool_to_dict,
109
108
  )
110
109
  from langchain_google_genai._image_utils import ImageBytesLoader
111
- from langchain_google_genai.llms import _BaseGoogleGenerativeAI
112
110
 
113
111
  from . import _genai_extension as genaix
114
112
 
115
113
  logger = logging.getLogger(__name__)
116
114
 
117
115
 
116
+ _FunctionDeclarationType = Union[
117
+ FunctionDeclaration,
118
+ dict[str, Any],
119
+ Callable[..., Any],
120
+ ]
121
+
122
+
118
123
  class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
119
124
  """
120
125
  Custom exception class for errors associated with the `Google GenAI` API.
@@ -786,11 +791,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
786
791
 
787
792
  client: Any = Field(default=None, exclude=True) #: :meta private:
788
793
  async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
789
- google_api_key: Optional[SecretStr] = Field(
790
- alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
791
- )
792
- """Google AI API key.
793
- If not specified will be read from env var ``GOOGLE_API_KEY``."""
794
794
  default_metadata: Sequence[Tuple[str, str]] = Field(
795
795
  default_factory=list
796
796
  ) #: :meta private:
@@ -938,8 +938,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
938
938
  stop: Optional[List[str]] = None,
939
939
  run_manager: Optional[CallbackManagerForLLMRun] = None,
940
940
  *,
941
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
942
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
941
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
942
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
943
943
  safety_settings: Optional[SafetySettingDict] = None,
944
944
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
945
945
  generation_config: Optional[Dict[str, Any]] = None,
@@ -972,8 +972,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
972
972
  stop: Optional[List[str]] = None,
973
973
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
974
974
  *,
975
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
976
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
975
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
976
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
977
977
  safety_settings: Optional[SafetySettingDict] = None,
978
978
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
979
979
  generation_config: Optional[Dict[str, Any]] = None,
@@ -1021,8 +1021,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1021
1021
  stop: Optional[List[str]] = None,
1022
1022
  run_manager: Optional[CallbackManagerForLLMRun] = None,
1023
1023
  *,
1024
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
1025
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
1024
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1025
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1026
1026
  safety_settings: Optional[SafetySettingDict] = None,
1027
1027
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1028
1028
  generation_config: Optional[Dict[str, Any]] = None,
@@ -1083,8 +1083,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1083
1083
  stop: Optional[List[str]] = None,
1084
1084
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
1085
1085
  *,
1086
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
1087
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
1086
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1087
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1088
1088
  safety_settings: Optional[SafetySettingDict] = None,
1089
1089
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1090
1090
  generation_config: Optional[Dict[str, Any]] = None,
@@ -1158,8 +1158,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1158
1158
  messages: List[BaseMessage],
1159
1159
  *,
1160
1160
  stop: Optional[List[str]] = None,
1161
- tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
1162
- functions: Optional[Sequence[FunctionDeclarationType]] = None,
1161
+ tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1162
+ functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1163
1163
  safety_settings: Optional[SafetySettingDict] = None,
1164
1164
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1165
1165
  tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
@@ -1251,8 +1251,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1251
1251
  )
1252
1252
  else:
1253
1253
  parser = JsonOutputToolsParser()
1254
- tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None
1255
- llm = self.bind_tools([schema], tool_choice=tool_choice)
1254
+ tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None # type: ignore[arg-type]
1255
+ llm = self.bind_tools([schema], tool_choice=tool_choice) # type: ignore[list-item]
1256
1256
  if include_raw:
1257
1257
  parser_with_fallback = RunnablePassthrough.assign(
1258
1258
  parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
@@ -1266,7 +1266,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1266
1266
 
1267
1267
  def bind_tools(
1268
1268
  self,
1269
- tools: Sequence[Union[ToolDict, GoogleTool]],
1269
+ tools: Sequence[
1270
+ dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool
1271
+ ],
1270
1272
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1271
1273
  *,
1272
1274
  tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
@@ -1303,90 +1305,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1303
1305
  pass
1304
1306
  return self.bind(tools=formatted_tools, **kwargs)
1305
1307
 
1306
- def create_cached_content(
1307
- self,
1308
- contents: Union[List[BaseMessage], content_types.ContentsType],
1309
- *,
1310
- display_name: str | None = None,
1311
- tools: Union[ToolDict, GoogleTool, None] = None,
1312
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1313
- ttl: Optional[caching_types.TTLTypes] = None,
1314
- expire_time: Optional[caching_types.ExpireTimeTypes] = None,
1315
- ) -> str:
1316
- """
1317
-
1318
- Args:
1319
- display_name: The user-generated meaningful display name
1320
- of the cached content. `display_name` must be no
1321
- more than 128 unicode characters.
1322
- contents: Contents to cache.
1323
- tools: A list of `Tools` the model may use to generate response.
1324
- tool_choice: Which tool to require the model to call.
1325
- ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
1326
- `ttl` and `expire_time` are exclusive arguments.
1327
- expire_time: Expiration time for cached resource.
1328
- `ttl` and `expire_time` are exclusive arguments.
1329
- """
1330
- system: Optional[content_types.ContentType] = None
1331
- genai_contents: list = []
1332
- if all(isinstance(c, BaseMessage) for c in contents):
1333
- system, genai_contents = _parse_chat_history(
1334
- contents,
1335
- convert_system_message_to_human=self.convert_system_message_to_human,
1336
- )
1337
- elif any(isinstance(c, BaseMessage) for c in contents):
1338
- raise ValueError(
1339
- f"'contents' must either be a list of "
1340
- f"langchain_core.messages.BaseMessage or a list "
1341
- f"google.generativeai.types.content_types.ContentType, but not a mix "
1342
- f"of the two. Received {contents}"
1343
- )
1344
- else:
1345
- for content in contents:
1346
- if hasattr(content, "role") and content.role == "system":
1347
- if system is not None:
1348
- warnings.warn(
1349
- "Received multiple pieces of content with role 'system'. "
1350
- "Should only be one set of system instructions. Ignoring "
1351
- "all but the first 'system' content."
1352
- )
1353
- else:
1354
- system = content
1355
- elif isinstance(content, dict) and content.get("role") == "system":
1356
- if system is not None:
1357
- warnings.warn(
1358
- "Received multiple pieces of content with role 'system'. "
1359
- "Should only be one set of system instructions. Ignoring "
1360
- "all but the first 'system' content."
1361
- )
1362
- else:
1363
- system = content
1364
- else:
1365
- genai_contents.append(content)
1366
- if tools:
1367
- genai_tools = [convert_to_genai_function_declarations(tools)]
1368
- else:
1369
- genai_tools = None
1370
- if tool_choice and genai_tools:
1371
- all_names = [f.name for t in genai_tools for f in t.function_declarations]
1372
- tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
1373
- genai_tool_config = ToolConfig(
1374
- function_calling_config=tool_config["function_calling_config"]
1375
- )
1376
- else:
1377
- genai_tool_config = None
1378
- cached_content = CachedContent.create(
1379
- model=self.model,
1380
- system_instruction=system,
1381
- contents=genai_contents,
1382
- display_name=display_name,
1383
- tools=genai_tools,
1384
- tool_config=genai_tool_config,
1385
- ttl=ttl,
1386
- expire_time=expire_time,
1387
- )
1388
- return cached_content.name
1389
-
1390
1308
  @property
1391
1309
  def _supports_tool_choice(self) -> bool:
1392
1310
  return (
@@ -1397,7 +1315,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1397
1315
 
1398
1316
 
1399
1317
  def _get_tool_name(
1400
- tool: Union[ToolDict, GoogleTool],
1318
+ tool: Union[_ToolDict, GoogleTool],
1401
1319
  ) -> str:
1402
1320
  genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
1403
1321
  return [f["name"] for f in genai_tool["function_declarations"]][0] # type: ignore[index]
@@ -0,0 +1,134 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Iterator, List, Optional
4
+
5
+ from langchain_core.callbacks import (
6
+ CallbackManagerForLLMRun,
7
+ )
8
+ from langchain_core.language_models import LangSmithParams
9
+ from langchain_core.language_models.llms import BaseLLM
10
+ from langchain_core.messages import HumanMessage
11
+ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
12
+ from pydantic import ConfigDict, model_validator
13
+ from typing_extensions import Self
14
+
15
+ from langchain_google_genai._common import (
16
+ _BaseGoogleGenerativeAI,
17
+ )
18
+ from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
19
+
20
+
21
+ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
22
+ """Google GenerativeAI models.
23
+
24
+ Example:
25
+ .. code-block:: python
26
+
27
+ from langchain_google_genai import GoogleGenerativeAI
28
+ llm = GoogleGenerativeAI(model="gemini-pro")
29
+ """
30
+
31
+ client: Any = None #: :meta private:
32
+ model_config = ConfigDict(
33
+ populate_by_name=True,
34
+ )
35
+
36
+ @model_validator(mode="after")
37
+ def validate_environment(self) -> Self:
38
+ """Validates params and passes them to google-generativeai package."""
39
+
40
+ self.client = ChatGoogleGenerativeAI(
41
+ api_key=self.google_api_key,
42
+ credentials=self.credentials,
43
+ temperature=self.temperature,
44
+ top_p=self.top_p,
45
+ top_k=self.top_k,
46
+ max_tokens=self.max_output_tokens,
47
+ timeout=self.timeout,
48
+ model=self.model,
49
+ client_options=self.client_options,
50
+ transport=self.transport,
51
+ additional_headers=self.additional_headers,
52
+ safety_settings=self.safety_settings,
53
+ )
54
+
55
+ return self
56
+
57
+ def _get_ls_params(
58
+ self, stop: Optional[List[str]] = None, **kwargs: Any
59
+ ) -> LangSmithParams:
60
+ """Get standard params for tracing."""
61
+ ls_params = super()._get_ls_params(stop=stop, **kwargs)
62
+ ls_params["ls_provider"] = "google_genai"
63
+ if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
64
+ ls_params["ls_max_tokens"] = ls_max_tokens
65
+ return ls_params
66
+
67
+ def _generate(
68
+ self,
69
+ prompts: List[str],
70
+ stop: Optional[List[str]] = None,
71
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
72
+ **kwargs: Any,
73
+ ) -> LLMResult:
74
+ generations = []
75
+ for prompt in prompts:
76
+ chat_result = self.client._generate(
77
+ [HumanMessage(content=prompt)],
78
+ stop=stop,
79
+ run_manager=run_manager,
80
+ **kwargs,
81
+ )
82
+ generations.append(
83
+ [
84
+ Generation(
85
+ text=g.message.content,
86
+ generation_info={
87
+ **g.generation_info,
88
+ **{"usage_metadata": g.message.usage_metadata},
89
+ },
90
+ )
91
+ for g in chat_result.generations
92
+ ]
93
+ )
94
+ return LLMResult(generations=generations)
95
+
96
+ def _stream(
97
+ self,
98
+ prompt: str,
99
+ stop: Optional[List[str]] = None,
100
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
101
+ **kwargs: Any,
102
+ ) -> Iterator[GenerationChunk]:
103
+ for stream_chunk in self.client._stream(
104
+ [HumanMessage(content=prompt)],
105
+ stop=stop,
106
+ run_manager=run_manager,
107
+ **kwargs,
108
+ ):
109
+ chunk = GenerationChunk(text=stream_chunk.message.content)
110
+ yield chunk
111
+ if run_manager:
112
+ run_manager.on_llm_new_token(
113
+ chunk.text,
114
+ chunk=chunk,
115
+ verbose=self.verbose,
116
+ )
117
+
118
+ @property
119
+ def _llm_type(self) -> str:
120
+ """Return type of llm."""
121
+ return "google_gemini"
122
+
123
+ def get_num_tokens(self, text: str) -> int:
124
+ """Get the number of tokens present in the text.
125
+
126
+ Useful for checking if an input will fit in a model's context window.
127
+
128
+ Args:
129
+ text: The string input to tokenize.
130
+
131
+ Returns:
132
+ The integer number of tokens in the text.
133
+ """
134
+ return self.client.get_num_tokens(text)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "langchain-google-genai"
3
- version = "2.0.10"
3
+ version = "2.0.11"
4
4
  description = "An integration package connecting Google's genai package and LangChain"
5
5
  authors = []
6
6
  readme = "README.md"
@@ -13,7 +13,7 @@ license = "MIT"
13
13
  [tool.poetry.dependencies]
14
14
  python = ">=3.9,<4.0"
15
15
  langchain-core = "^0.3.37"
16
- google-generativeai = "^0.8.0"
16
+ google-ai-generativelanguage = "^0.6.16"
17
17
  pydantic = ">=2,<3"
18
18
  filetype = "^1.2.0"
19
19
 
@@ -1,52 +0,0 @@
1
- from importlib import metadata
2
- from typing import Optional, Tuple, TypedDict
3
-
4
- from google.api_core.gapic_v1.client_info import ClientInfo
5
-
6
- from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
7
-
8
-
9
- class GoogleGenerativeAIError(Exception):
10
- """
11
- Custom exception class for errors associated with the `Google GenAI` API.
12
- """
13
-
14
-
15
- def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
16
- r"""Returns a custom user agent header.
17
-
18
- Args:
19
- module (Optional[str]):
20
- Optional. The module for a custom user agent header.
21
- Returns:
22
- Tuple[str, str]
23
- """
24
- try:
25
- langchain_version = metadata.version("langchain-google-genai")
26
- except metadata.PackageNotFoundError:
27
- langchain_version = "0.0.0"
28
- client_library_version = (
29
- f"{langchain_version}-{module}" if module else langchain_version
30
- )
31
- return client_library_version, f"langchain-google-genai/{client_library_version}"
32
-
33
-
34
- def get_client_info(module: Optional[str] = None) -> "ClientInfo":
35
- r"""Returns a client info object with a custom user agent header.
36
-
37
- Args:
38
- module (Optional[str]):
39
- Optional. The module for a custom user agent header.
40
- Returns:
41
- google.api_core.gapic_v1.client_info.ClientInfo
42
- """
43
- client_library_version, user_agent = get_user_agent(module)
44
- return ClientInfo(
45
- client_library_version=client_library_version,
46
- user_agent=user_agent,
47
- )
48
-
49
-
50
- class SafetySettingDict(TypedDict):
51
- category: HarmCategory
52
- threshold: HarmBlockThreshold
@@ -1,406 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from enum import Enum, auto
4
- from typing import Any, Callable, Dict, Iterator, List, Optional, Union
5
-
6
- import google.api_core
7
- import google.generativeai as genai # type: ignore[import]
8
- from langchain_core.callbacks import (
9
- AsyncCallbackManagerForLLMRun,
10
- CallbackManagerForLLMRun,
11
- )
12
- from langchain_core.language_models import LangSmithParams, LanguageModelInput
13
- from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
14
- from langchain_core.outputs import Generation, GenerationChunk, LLMResult
15
- from langchain_core.utils import secret_from_env
16
- from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
17
- from typing_extensions import Self
18
-
19
- from langchain_google_genai._enums import (
20
- HarmBlockThreshold,
21
- HarmCategory,
22
- )
23
-
24
-
25
class GoogleModelFamily(str, Enum):
    """Coarse model family (Gemini or PaLM) inferred from a model name.

    Construct it with a model name, e.g. ``GoogleModelFamily("gemini-pro")``;
    lookups that are not literal member values fall through to ``_missing_``,
    which classifies the name by substring.
    """

    GEMINI = auto()
    PALM = auto()

    @classmethod
    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
        """Classify a model name by case-insensitive substring match.

        Args:
            value: The value passed to ``GoogleModelFamily(...)``.

        Returns:
            ``GEMINI`` if the name contains ``"gemini"``, ``PALM`` if it contains
            ``"text-bison"``, otherwise ``None`` (which makes ``Enum`` raise
            ``ValueError``).
        """
        # Non-string lookups (e.g. None) previously crashed with AttributeError
        # on ``.lower()``; returning None lets Enum raise its usual ValueError.
        if not isinstance(value, str):
            return None
        name = value.lower()
        if "gemini" in name:
            return GoogleModelFamily.GEMINI
        if "text-bison" in name:
            return GoogleModelFamily.PALM
        return None
36
-
37
-
38
def _create_retry_decorator(
    llm: BaseLLM,
    *,
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used around Google Generative AI LLM calls.

    Args:
        llm: The calling LLM. Not read here; kept for signature compatibility.
        max_retries: Forwarded to ``create_base_retry_decorator``.
        run_manager: Optional callback manager forwarded to
            ``create_base_retry_decorator``.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for the
        Google API exception types listed below.
    """
    api_exc = google.api_core.exceptions
    # NOTE(review): assuming google-api-core's hierarchy, GoogleAPIError is the
    # common base of the other four types, so effectively every Google API
    # error is retried. Confirm that is intended before narrowing this list.
    retryable_errors = [
        api_exc.ResourceExhausted,
        api_exc.ServiceUnavailable,
        api_exc.Aborted,
        api_exc.DeadlineExceeded,
        api_exc.GoogleAPIError,
    ]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=max_retries,
        run_manager=run_manager,
    )
59
-
60
-
61
def _completion_with_retry(
    llm: GoogleGenerativeAI,
    prompt: LanguageModelInput,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call ``llm.client`` once per attempt, wrapped in the retry decorator.

    Args:
        llm: The LLM whose ``client``, ``max_retries`` and ``timeout`` are used.
        prompt: The prompt to send.
        is_gemini: If True, call ``llm.client.generate_content``; otherwise call
            ``llm.client.generate_text(prompt=prompt, **kwargs)``.
        stream: Forwarded to ``generate_content`` (Gemini path only).
        run_manager: Forwarded to ``_create_retry_decorator``.
        **kwargs: On the Gemini path, ``generation_config`` and
            ``safety_settings`` are read; on the other path every kwarg is
            forwarded to ``generate_text``.

    Returns:
        Whatever the underlying client call returns.

    Raises:
        ValueError: If the API reports that the caller's location is not
            supported.
        google.api_core.exceptions.FailedPrecondition: For any other failed
            precondition. It may be retried first if the decorator's error
            list covers it.
    """
    retry_decorator = _create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _attempt(
        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
    ) -> Any:
        generation_config = kwargs.get("generation_config", {})
        error_msg = (
            "Your location is not supported by google-generativeai at the moment. "
            "Try to use VertexAI LLM from langchain_google_vertexai"
        )
        try:
            if is_gemini:
                return llm.client.generate_content(
                    contents=prompt,
                    stream=stream,
                    generation_config=generation_config,
                    safety_settings=kwargs.pop("safety_settings", None),
                    request_options={"timeout": llm.timeout} if llm.timeout else None,
                )
            return llm.client.generate_text(prompt=prompt, **kwargs)
        except google.api_core.exceptions.FailedPrecondition as exc:
            if "location is not supported" in (exc.message or ""):
                raise ValueError(error_msg) from exc
            # Previously any other FailedPrecondition was swallowed and the
            # function returned None, which crashed callers later with an
            # unrelated AttributeError. Surface the real error instead.
            raise

    return _attempt(prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs)
100
-
101
-
102
- def _strip_erroneous_leading_spaces(text: str) -> str:
103
- """Strip erroneous leading spaces from text.
104
-
105
- The PaLM API will sometimes erroneously return a single leading space in all
106
- lines > 1. This function strips that space.
107
- """
108
- has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
109
- if has_leading_space:
110
- return text.replace("\n ", "\n")
111
- else:
112
- return text
113
-
114
-
115
class _BaseGoogleGenerativeAI(BaseModel):
    """Base class for Google Generative AI LLMs.

    Holds the model name, authentication settings, sampling parameters and
    client/transport options shared by the concrete LLM wrapper.
    """

    model: str = Field(
        ...,
        description="""The name of the model to use.
Supported examples:
    - gemini-pro
    - models/text-bison-001""",
    )
    """Model name to use."""
    google_api_key: Optional[SecretStr] = Field(
        alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
    )
    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to use
    when making API calls. If not provided, the API key (``google_api_key``, which
    defaults to the GOOGLE_API_KEY env var) is used instead."""
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_retries: int = 6
    """The maximum number of retries to make when generating."""

    timeout: Optional[float] = None
    """The maximum number of seconds to wait for a response."""

    client_options: Optional[Dict] = Field(
        default=None,
        description=(
            "A dictionary of client options to pass to the Google API client, "
            "such as `api_endpoint`."
        ),
    )
    transport: Optional[str] = Field(
        default=None,
        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
    )
    # NOTE(review): not referenced by the LLM code in this module; confirm
    # whether it is consumed elsewhere or can be wired through.
    additional_headers: Optional[Dict[str, str]] = Field(
        default=None,
        description=(
            "A key-value dictionary representing additional headers for the model call"
        ),
    )

    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
    """The default safety settings to use for all generations.

        For example:

            from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory

            safety_settings = {
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            }
            """  # noqa: E501

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret field names to the env vars that supply them."""
        return {"google_api_key": "GOOGLE_API_KEY"}

    @property
    def _model_family(self) -> GoogleModelFamily:
        """Family of ``self.model``; raises ValueError for unrecognised names."""
        return GoogleModelFamily(self.model)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
206
-
207
-
208
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
    """Google GenerativeAI models.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAI
            llm = GoogleGenerativeAI(model="gemini-pro")
    """

    client: Any = None  #: :meta private:
    model_config = ConfigDict(
        populate_by_name=True,
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validates params and passes them to google-generativeai package.

        Calls ``genai.configure`` with ``credentials`` when set, otherwise with
        the API key. It then sets ``self.client`` to a ``genai.GenerativeModel``
        for Gemini models, or to the ``genai`` module itself for PaLM models.
        Finally it range-checks the sampling parameters.

        Raises:
            ValueError: If the model name matches no known family, if safety
                settings are given for a non-Gemini model, or if
                ``temperature``, ``top_p``, ``top_k``, ``max_output_tokens`` or
                ``timeout`` is out of range.
        """
        if self.credentials:
            genai.configure(
                credentials=self.credentials,
                transport=self.transport,
                client_options=self.client_options,
            )
        else:
            if isinstance(self.google_api_key, SecretStr):
                google_api_key: Optional[str] = self.google_api_key.get_secret_value()
            else:
                google_api_key = self.google_api_key
            genai.configure(
                api_key=google_api_key,
                transport=self.transport,
                client_options=self.client_options,
            )

        model_name = self.model
        safety_settings = self.safety_settings
        # Resolve the family once; GoogleModelFamily raises ValueError for
        # model names it cannot classify.
        model_family = GoogleModelFamily(model_name)

        if safety_settings and model_family != GoogleModelFamily.GEMINI:
            raise ValueError("Safety settings are only supported for Gemini models")

        if model_family == GoogleModelFamily.GEMINI:
            self.client = genai.GenerativeModel(
                model_name=model_name, safety_settings=safety_settings
            )
        else:
            # PaLM calls in this module use module-level functions
            # (generate_text, count_text_tokens), so the module is the client.
            self.client = genai

        # NOTE(review): this enforces the field's documented [0.0, 1.0] range;
        # newer Gemini models may accept higher temperatures. Confirm before
        # relaxing.
        if self.temperature is not None and not 0 <= self.temperature <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if self.top_p is not None and not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if self.top_k is not None and self.top_k <= 0:
            raise ValueError("top_k must be positive")

        if self.max_output_tokens is not None and self.max_output_tokens <= 0:
            raise ValueError("max_output_tokens must be greater than zero")

        if self.timeout is not None and self.timeout <= 0:
            raise ValueError("timeout must be greater than zero")

        return self

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        ls_params = super()._get_ls_params(stop=stop, **kwargs)
        ls_params["ls_provider"] = "google_genai"
        if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        return ls_params

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Send one completion request per prompt and collect the candidates.

        Args:
            prompts: Prompts to complete; each one is a separate request.
            stop: Optional stop sequences, sent as ``stop_sequences``.
            run_manager: Forwarded to ``_completion_with_retry``.
            **kwargs: Only ``safety_settings`` is read (Gemini path only).

        Returns:
            An ``LLMResult`` holding one list of ``Generation`` objects per
            prompt, with one ``Generation`` per returned candidate.
        """
        generations: List[List[Generation]] = []
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        # Pop once, before the loop. Popping inside the loop handed the
        # caller's safety settings to the first prompt only; every later prompt
        # silently ran with None.
        safety_settings = kwargs.pop("safety_settings", None)
        is_gemini = self._model_family == GoogleModelFamily.GEMINI

        for prompt in prompts:
            if is_gemini:
                res = _completion_with_retry(
                    self,
                    prompt=prompt,
                    stream=False,
                    is_gemini=True,
                    run_manager=run_manager,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )
                generation_info = None
                if res.usage_metadata is not None:
                    generation_info = {
                        "usage_metadata": res.to_dict().get("usage_metadata")
                    }

                # A candidate's text is the concatenation of its content parts.
                candidates = [
                    "".join(part.text for part in c.content.parts)
                    for c in res.candidates
                ]
                generations.append(
                    [
                        Generation(text=c, generation_info=generation_info)
                        for c in candidates
                    ]
                )
            else:
                res = _completion_with_retry(
                    self,
                    model=self.model,
                    prompt=prompt,
                    stream=False,
                    is_gemini=False,
                    run_manager=run_manager,
                    **generation_config,
                )
                prompt_generations = []
                for candidate in res.candidates:
                    raw_text = candidate["output"]
                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
                    prompt_generations.append(Generation(text=stripped_text))
                generations.append(prompt_generations)

        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream a completion for ``prompt`` chunk by chunk.

        Args:
            prompt: The prompt to complete.
            stop: Optional stop sequences, sent as ``stop_sequences``.
            run_manager: Receives ``on_llm_new_token`` for every chunk.
            **kwargs: ``generation_config`` is merged over the instance
                defaults. ``safety_settings`` is consumed here. Everything else
                is forwarded to ``_completion_with_retry``.

        Yields:
            One ``GenerationChunk`` per streamed response chunk.
        """
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        # Pop, not get: leaving the key in kwargs forwarded generation_config a
        # second time through **kwargs below and raised TypeError ("got
        # multiple values for keyword argument"). `or {}` also tolerates an
        # explicit generation_config=None.
        generation_config = generation_config | (
            kwargs.pop("generation_config", None) or {}
        )
        safety_settings = kwargs.pop("safety_settings", None)

        # NOTE(review): always takes the Gemini path, even for PaLM models.
        for stream_resp in _completion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=True,
            run_manager=run_manager,
            generation_config=generation_config,
            safety_settings=safety_settings,
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_resp.text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    stream_resp.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        # NOTE(review): reports "google_palm" for Gemini models too; presumably
        # kept for compatibility with existing identifiers. Confirm before
        # changing.
        return "google_palm"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self._model_family == GoogleModelFamily.GEMINI:
            result = self.client.count_tokens(text)
            token_count = result.total_tokens
        else:
            result = self.client.count_text_tokens(model=self.model, prompt=text)
            token_count = result["token_count"]

        return token_count