langchain-google-genai 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ llm.invoke("Sing a ballad of LangChain.")
 ## Using LLMs
 
 The package also supports generating text with Google's models.
-
+
 ```python
 from langchain_google_genai import GoogleGenerativeAI
 
--- a/langchain_google_genai/_common.py
+++ b/langchain_google_genai/_common.py
@@ -1,4 +1,52 @@
+from importlib import metadata
+from typing import Optional, Tuple, TypedDict
+
+from google.api_core.gapic_v1.client_info import ClientInfo
+
+from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
+
+
 class GoogleGenerativeAIError(Exception):
     """
     Custom exception class for errors associated with the `Google GenAI` API.
     """
+
+
+def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
+    r"""Returns a custom user agent header.
+
+    Args:
+        module (Optional[str]):
+            Optional. The module for a custom user agent header.
+    Returns:
+        Tuple[str, str]
+    """
+    try:
+        langchain_version = metadata.version("langchain-google-genai")
+    except metadata.PackageNotFoundError:
+        langchain_version = "0.0.0"
+    client_library_version = (
+        f"{langchain_version}-{module}" if module else langchain_version
+    )
+    return client_library_version, f"langchain-google-genai/{client_library_version}"
+
+
+def get_client_info(module: Optional[str] = None) -> "ClientInfo":
+    r"""Returns a client info object with a custom user agent header.
+
+    Args:
+        module (Optional[str]):
+            Optional. The module for a custom user agent header.
+    Returns:
+        google.api_core.gapic_v1.client_info.ClientInfo
+    """
+    client_library_version, user_agent = get_user_agent(module)
+    return ClientInfo(
+        client_library_version=client_library_version,
+        user_agent=user_agent,
+    )
+
+
+class SafetySettingDict(TypedDict):
+    category: HarmCategory
+    threshold: HarmBlockThreshold
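For orientation, a minimal sketch of what the new `_common.py` helpers return (assuming the 1.0.4 wheel is installed; the module label "MyIntegration" is an arbitrary example):

```python
from langchain_google_genai._common import get_client_info, get_user_agent

# A version string plus a user agent of the form
# "langchain-google-genai/<version>-<module>".
version, user_agent = get_user_agent("MyIntegration")
print(version, user_agent)

# The same user agent packaged as a ClientInfo, ready to hand to a GAPIC client.
client_info = get_client_info("MyIntegration")
print(client_info.to_user_agent())
```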
--- a/langchain_google_genai/_enums.py
+++ b/langchain_google_genai/_enums.py
@@ -1,6 +1,6 @@
-from google.generativeai.types.safety_types import (  # type: ignore
-    HarmBlockThreshold,
-    HarmCategory,
-)
+import google.ai.generativelanguage_v1beta as genai
+
+HarmBlockThreshold = genai.SafetySetting.HarmBlockThreshold
+HarmCategory = genai.HarmCategory
 
 __all__ = ["HarmBlockThreshold", "HarmCategory"]
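The public names are unchanged, so existing safety-settings mappings keep working; a quick sketch using the re-exported v1beta enums:

```python
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory

# Same names as before, but they now resolve to the v1beta proto enums.
safety_settings = {
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
```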
--- a/langchain_google_genai/_function_utils.py
+++ b/langchain_google_genai/_function_utils.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from typing import (
     Any,
+    Callable,
     Dict,
     List,
     Literal,
@@ -10,15 +11,16 @@ from typing import (
     Type,
     TypedDict,
     Union,
+    cast,
 )
 
 import google.ai.generativelanguage as glm
-from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
-from google.generativeai.types.content_types import (  # type: ignore[import]
-    FunctionCallingConfigType,
-    FunctionDeclarationType,
-    ToolDict,
-    ToolType,
+from google.ai.generativelanguage import (
+    FunctionCallingConfig,
+    FunctionDeclaration,
+)
+from google.ai.generativelanguage import (
+    Tool as GoogleTool,
 )
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import BaseTool
@@ -36,51 +38,41 @@ TYPE_ENUM = {
 
 TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}
 
+_FunctionDeclarationLike = Union[
+    BaseTool, Type[BaseModel], dict, Callable, FunctionDeclaration
+]
+
+
+class _ToolDict(TypedDict):
+    function_declarations: Sequence[_FunctionDeclarationLike]
+
 
 def convert_to_genai_function_declarations(
     tool: Union[
-        GoogleTool, ToolDict, FunctionDeclarationType, Sequence[FunctionDeclarationType]
+        GoogleTool,
+        _ToolDict,
+        _FunctionDeclarationLike,
+        Sequence[_FunctionDeclarationLike],
     ],
-) -> ToolType:
-    """Convert any tool-like object to a ToolType.
-
-    https://github.com/google-gemini/generative-ai-python/blob/668695ebe3e9de496a36eeb95cb2ed2faba9b939/google/generativeai/types/content_types.py#L574
-    """
+) -> GoogleTool:
     if isinstance(tool, GoogleTool):
-        return tool
-    # check whether a dict is supported by glm, otherwise we parse it explicitly
-    if isinstance(tool, dict):
-        first_function_declaration = tool.get("function_declarations", [None])[0]
-        if isinstance(first_function_declaration, glm.FunctionDeclaration):
-            return tool
-        schema = None
-        try:
-            schema = first_function_declaration.parameters
-        except AttributeError:
-            pass
-        if schema is None:
-            schema = first_function_declaration.get("parameters")
-        if schema is None or isinstance(schema, glm.Schema):
-            return tool
-        return glm.Tool(
+        return cast(GoogleTool, tool)
+    if isinstance(tool, type) and issubclass(tool, BaseModel):
+        return GoogleTool(function_declarations=[_convert_to_genai_function(tool)])
+    if callable(tool):
+        return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
+    if isinstance(tool, list):
+        return convert_to_genai_function_declarations({"function_declarations": tool})
+    if isinstance(tool, dict) and "function_declarations" in tool:
+        return GoogleTool(
             function_declarations=[
                 _convert_to_genai_function(fc) for fc in tool["function_declarations"]
             ],
         )
-    elif isinstance(tool, type) and issubclass(tool, BaseModel):
-        return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])
-    elif callable(tool):
-        return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
-    elif isinstance(tool, list):
-        return glm.Tool(
-            function_declarations=[_convert_to_genai_function(fc) for fc in tool]
-        )
-    return glm.Tool(function_declarations=[_convert_to_genai_function(tool)])
+    return GoogleTool(function_declarations=[_convert_to_genai_function(tool)])  # type: ignore[arg-type]
 
 
-def tool_to_dict(tool: Union[glm.Tool, GoogleTool]) -> ToolDict:
-    if isinstance(tool, GoogleTool):
-        tool = tool._proto
+def tool_to_dict(tool: GoogleTool) -> _ToolDict:
     function_declarations = []
     for function_declaration_proto in tool.function_declarations:
         properties: Dict[str, Any] = {}
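A hedged sketch of the normalization paths the rewritten `convert_to_genai_function_declarations` accepts; the `GetWeather` model is a made-up example, not from the diff:

```python
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_google_genai._function_utils import (
    convert_to_genai_function_declarations,
)


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str = Field(description="City name")


# Pydantic models, plain callables, OpenAI-style dicts, and lists of any of
# these all normalize to a single google.ai.generativelanguage Tool proto.
tool = convert_to_genai_function_declarations(GetWeather)
print(tool.function_declarations[0].name)  # -> "GetWeather"
```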
@@ -108,7 +100,7 @@ def tool_to_dict(tool: Union[glm.Tool, GoogleTool]) -> ToolDict:
     return {"function_declarations": function_declarations}
 
 
-def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDeclaration:
+def _convert_to_genai_function(fc: _FunctionDeclarationLike) -> FunctionDeclaration:
     if isinstance(fc, BaseTool):
         return _convert_tool_to_genai_function(fc)
     elif isinstance(fc, type) and issubclass(fc, BaseModel):
@@ -116,10 +108,9 @@ def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDecla
     elif callable(fc):
         return _convert_tool_to_genai_function(callable_as_lc_tool()(fc))
     elif isinstance(fc, dict):
-        return glm.FunctionDeclaration(
-            name=fc["name"],
-            description=fc.get("description"),
-            parameters={
+        formatted_fc = {"name": fc["name"], "description": fc.get("description")}
+        if "parameters" in fc:
+            formatted_fc["parameters"] = {
                 "properties": {
                     k: {
                         "type_": TYPE_ENUM[v["type"]],
@@ -127,19 +118,19 @@ def _convert_to_genai_function(fc: FunctionDeclarationType) -> glm.FunctionDecla
                     }
                     for k, v in fc["parameters"]["properties"].items()
                 },
-                "required": fc["parameters"].get("required", []),
+                "required": fc.get("parameters", []).get("required", []),
                 "type_": TYPE_ENUM[fc["parameters"]["type"]],
-            },
-        )
+            }
+        return FunctionDeclaration(**formatted_fc)
     else:
         raise ValueError(f"Unsupported function call type {fc}")
 
 
-def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
+def _convert_tool_to_genai_function(tool: BaseTool) -> FunctionDeclaration:
     if tool.args_schema:
         schema = dereference_refs(tool.args_schema.schema())
         schema.pop("definitions", None)
-        return glm.FunctionDeclaration(
+        return FunctionDeclaration(
             name=tool.name or schema["title"],
             description=tool.description or schema["description"],
             parameters={
@@ -155,7 +146,7 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
             },
         )
     else:
-        return glm.FunctionDeclaration(
+        return FunctionDeclaration(
             name=tool.name,
             description=tool.description,
             parameters={
@@ -170,10 +161,10 @@ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
 
 def _convert_pydantic_to_genai_function(
     pydantic_model: Type[BaseModel],
-) -> glm.FunctionDeclaration:
+) -> FunctionDeclaration:
     schema = dereference_refs(pydantic_model.schema())
     schema.pop("definitions", None)
-    return glm.FunctionDeclaration(
+    return FunctionDeclaration(
         name=schema["title"],
         description=schema.get("description", ""),
         parameters={
@@ -195,8 +186,13 @@ _ToolChoiceType = Union[
 ]
 
 
+class _FunctionCallingConfigDict(TypedDict):
+    mode: Union[FunctionCallingConfig.Mode, str]
+    allowed_function_names: Optional[List[str]]
+
+
 class _ToolConfigDict(TypedDict):
-    function_calling_config: FunctionCallingConfigType
+    function_calling_config: _FunctionCallingConfigDict
 
 
 def _tool_choice_to_tool_config(
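The shape `_ToolConfigDict` now describes, sketched as a literal; the function name reuses the hypothetical `GetWeather` from the sketch above:

```python
from google.ai.generativelanguage import FunctionCallingConfig

# "mode" may be the proto enum or its string name, per the TypedDict.
tool_config = {
    "function_calling_config": {
        "mode": FunctionCallingConfig.Mode.ANY,
        "allowed_function_names": ["GetWeather"],
    }
}
```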
--- a/langchain_google_genai/_genai_extension.py
+++ b/langchain_google_genai/_genai_extension.py
@@ -12,6 +12,12 @@ from typing import Any, Dict, Iterator, List, MutableSequence, Optional
 
 import google.ai.generativelanguage as genai
 import langchain_core
+from google.ai.generativelanguage_v1beta import (
+    GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
+)
+from google.ai.generativelanguage_v1beta import (
+    GenerativeServiceClient as v1betaGenerativeServiceClient,
+)
 from google.api_core import client_options as client_options_lib
 from google.api_core import exceptions as gapi_exception
 from google.api_core import gapic_v1
@@ -225,15 +231,66 @@ def build_semantic_retriever() -> genai.RetrieverServiceClient:
     )
 
 
-def build_generative_service() -> genai.GenerativeServiceClient:
-    credentials = _get_credentials()
-    return genai.GenerativeServiceClient(
+def _prepare_config(
+    credentials: Optional[credentials.Credentials] = None,
+    api_key: Optional[str] = None,
+    client_options: Optional[Dict[str, Any]] = None,
+    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
+    transport: Optional[str] = None,
+) -> Dict[str, Any]:
+    formatted_client_options = {"api_endpoint": _config.api_endpoint}
+    if client_options:
+        formatted_client_options.update(**client_options)
+    if not credentials and api_key:
+        formatted_client_options["api_key"] = api_key
+    elif not credentials and not api_key:
+        credentials = _get_credentials()
+    client_info = (
+        client_info
+        if client_info
+        else gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT)
+    )
+    config = {
+        "credentials": credentials,
+        "client_info": client_info,
+        "client_options": client_options_lib.ClientOptions(**formatted_client_options),
+        "transport": transport,
+    }
+    return {k: v for k, v in config.items() if v is not None}
+
+
+def build_generative_service(
+    credentials: Optional[credentials.Credentials] = None,
+    api_key: Optional[str] = None,
+    client_options: Optional[Dict[str, Any]] = None,
+    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
+    transport: Optional[str] = None,
+) -> v1betaGenerativeServiceClient:
+    config = _prepare_config(
         credentials=credentials,
-        client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT),
-        client_options=client_options_lib.ClientOptions(
-            api_endpoint=_config.api_endpoint
-        ),
+        api_key=api_key,
+        client_options=client_options,
+        transport=transport,
+        client_info=client_info,
+    )
+    return v1betaGenerativeServiceClient(**config)
+
+
+def build_generative_async_service(
+    credentials: Optional[credentials.Credentials],
+    api_key: Optional[str] = None,
+    client_options: Optional[Dict[str, Any]] = None,
+    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
+    transport: Optional[str] = None,
+) -> v1betaGenerativeServiceAsyncClient:
+    config = _prepare_config(
+        credentials=credentials,
+        api_key=api_key,
+        client_options=client_options,
+        transport=transport,
+        client_info=client_info,
     )
+    return v1betaGenerativeServiceAsyncClient(**config)
 
 
 def list_corpora(
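A minimal sketch of the new builder signatures. Note that `build_generative_async_service` declares `credentials` without a default, so it must be passed explicitly; the API key below is a placeholder:

```python
from langchain_google_genai import _genai_extension as genaix

# API-key auth now flows through ClientOptions rather than genai.configure();
# ADC is only consulted when neither an api_key nor credentials are supplied.
client = genaix.build_generative_service(api_key="YOUR_API_KEY")
async_client = genaix.build_generative_async_service(
    credentials=None, api_key="YOUR_API_KEY"
)
```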
--- /dev/null
+++ b/langchain_google_genai/_image_utils.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+import base64
+import mimetypes
+import os
+import re
+from enum import Enum
+from typing import Any, Dict
+from urllib.parse import urlparse
+
+import requests
+from google.ai.generativelanguage_v1beta.types import Part
+
+
+class Route(Enum):
+    """Image Loading Route"""
+
+    BASE64 = 1
+    LOCAL_FILE = 2
+    URL = 3
+
+
+class ImageBytesLoader:
+    """Loads image bytes from multiple sources given a string.
+
+    Currently supported:
+        - B64 Encoded image string
+        - Local file path
+        - URL
+    """
+
+    def load_bytes(self, image_string: str) -> bytes:
+        """Routes to the correct loader based on the image_string.
+
+        Args:
+            image_string: Can be either:
+                - B64 Encoded image string
+                - Local file path
+                - URL
+
+        Returns:
+            Image bytes.
+        """
+
+        route = self._route(image_string)
+
+        if route == Route.BASE64:
+            return self._bytes_from_b64(image_string)
+
+        if route == Route.URL:
+            return self._bytes_from_url(image_string)
+
+        if route == Route.LOCAL_FILE:
+            return self._bytes_from_file(image_string)
+
+        raise ValueError(
+            "Image string must be one of: Google Cloud Storage URI, "
+            "b64 encoded image string (data:image/...), valid image url, "
+            f"or existing local image file. Instead got '{image_string}'."
+        )
+
+    def load_part(self, image_string: str) -> Part:
+        """Gets Part for loading from Gemini.
+
+        Args:
+            image_string: Can be either:
+                - B64 Encoded image string
+                - Local file path
+                - URL
+        """
+        route = self._route(image_string)
+
+        if route == Route.BASE64:
+            bytes_ = self._bytes_from_b64(image_string)
+
+        if route == Route.URL:
+            bytes_ = self._bytes_from_url(image_string)
+
+        if route == Route.LOCAL_FILE:
+            bytes_ = self._bytes_from_file(image_string)
+
+        inline_data: Dict[str, Any] = {"data": bytes_}
+        mime_type, _ = mimetypes.guess_type(image_string)
+        if mime_type:
+            inline_data["mime_type"] = mime_type
+
+        return Part(inline_data=inline_data)
+
+    def _route(self, image_string: str) -> Route:
+        if image_string.startswith("data:image/"):
+            return Route.BASE64
+
+        if self._is_url(image_string):
+            return Route.URL
+
+        if os.path.exists(image_string):
+            return Route.LOCAL_FILE
+
+        raise ValueError(
+            "Image string must be one of: "
+            "b64 encoded image string (data:image/...), valid image url, "
+            f"or existing local image file. Instead got '{image_string}'."
+        )
+
+    def _bytes_from_b64(self, base64_image: str) -> bytes:
+        """Gets image bytes from a base64 encoded string.
+
+        Args:
+            base64_image: Encoded image in b64 format.
+
+        Returns:
+            Image bytes
+        """
+
+        pattern = r"data:image/\w{2,4};base64,(.*)"
+        match = re.search(pattern, base64_image)
+
+        if match is not None:
+            encoded_string = match.group(1)
+            return base64.b64decode(encoded_string)
+
+        raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}")
+
+    def _bytes_from_file(self, file_path: str) -> bytes:
+        """Gets image bytes from a local file path.
+
+        Args:
+            file_path: Existing file path.
+
+        Returns:
+            Image bytes
+        """
+        with open(file_path, "rb") as image_file:
+            image_bytes = image_file.read()
+        return image_bytes
+
+    def _bytes_from_url(self, url: str) -> bytes:
+        """Gets image bytes from a public url.
+
+        Args:
+            url: Valid url.
+
+        Raises:
+            HTTP Error if there is one.
+
+        Returns:
+            Image bytes
+        """
+
+        response = requests.get(url)
+
+        if not response.ok:
+            response.raise_for_status()
+
+        return response.content
+
+    def _is_url(self, url_string: str) -> bool:
+        """Checks if a url is valid.
+
+        Args:
+            url_string: Url to check.
+
+        Returns:
+            Whether the url is valid.
+        """
+        try:
+            result = urlparse(url_string)
+            return all([result.scheme, result.netloc])
+        except Exception:
+            return False
+
+
+def image_bytes_to_b64_string(
+    image_bytes: bytes, encoding: str = "ascii", image_format: str = "png"
+) -> str:
+    """Encodes image bytes into a b64 encoded string.
+
+    Args:
+        image_bytes: Bytes of the image.
+        encoding: Type of encoding in the string. 'ascii' by default.
+        image_format: Format of the image. 'png' by default.
+
+    Returns:
+        B64 image encoded string.
+    """
+    encoded_bytes = base64.b64encode(image_bytes).decode(encoding)
+    return f"data:image/{image_format};base64,{encoded_bytes}"
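A small self-contained sketch of the new loader; the bytes here are fake, and the commented URL/path routes assume reachable sources:

```python
from langchain_google_genai._image_utils import (
    ImageBytesLoader,
    image_bytes_to_b64_string,
)

loader = ImageBytesLoader()

# Round-trip some fake bytes through the data-URL encoder and back.
fake = b"\x89PNG-not-a-real-image"
data_url = image_bytes_to_b64_string(fake, image_format="png")
assert loader.load_bytes(data_url) == fake

# URLs and local paths route through requests.get() / open() respectively:
# loader.load_part("https://example.com/cat.png")
# loader.load_part("/tmp/cat.png")
```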
--- a/langchain_google_genai/chat_models.py
+++ b/langchain_google_genai/chat_models.py
@@ -23,14 +23,23 @@ from typing import (
 )
 from urllib.parse import urlparse
 
-import google.ai.generativelanguage as glm
 import google.api_core
 
 # TODO: remove ignore once the google package is published with types
-import google.generativeai as genai  # type: ignore[import]
 import proto  # type: ignore[import]
 import requests
-from google.generativeai.types import SafetySettingDict  # type: ignore[import]
+from google.ai.generativelanguage_v1beta.types import (
+    Candidate,
+    Content,
+    FunctionCall,
+    FunctionResponse,
+    GenerateContentRequest,
+    GenerateContentResponse,
+    GenerationConfig,
+    Part,
+    SafetySetting,
+    ToolConfig,
+)
 from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
 from google.generativeai.types.content_types import (  # type: ignore[import]
     FunctionDeclarationType,
@@ -56,7 +65,7 @@ from langchain_core.messages import (
 )
 from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import SecretStr, root_validator
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.runnables import Runnable
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
@@ -67,7 +76,11 @@ from tenacity import (
     wait_exponential,
 )
 
-from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai._common import (
+    GoogleGenerativeAIError,
+    SafetySettingDict,
+    get_client_info,
+)
 from langchain_google_genai._function_utils import (
     _tool_choice_to_tool_config,
     _ToolChoiceType,
@@ -75,7 +88,10 @@ from langchain_google_genai._function_utils import (
     convert_to_genai_function_declarations,
     tool_to_dict,
 )
-from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
+from langchain_google_genai._image_utils import ImageBytesLoader
+from langchain_google_genai.llms import _BaseGoogleGenerativeAI
+
+from . import _genai_extension as genaix
 
 IMAGE_TYPES: Tuple = ()
 try:
@@ -279,18 +295,19 @@ def _url_to_pil(image_source: str) -> Image:
 
 def _convert_to_parts(
     raw_content: Union[str, Sequence[Union[str, dict]]],
-) -> List[genai.types.PartType]:
+) -> List[Part]:
     """Converts a list of LangChain messages into a google parts."""
     parts = []
     content = [raw_content] if isinstance(raw_content, str) else raw_content
+    image_loader = ImageBytesLoader()
     for part in content:
         if isinstance(part, str):
-            parts.append(genai.types.PartDict(text=part))
+            parts.append(Part(text=part))
         elif isinstance(part, Mapping):
             # OpenAI Format
             if _is_openai_parts_format(part):
                 if part["type"] == "text":
-                    parts.append({"text": part["text"]})
+                    parts.append(Part(text=part["text"]))
                 elif part["type"] == "image_url":
                     img_url = part["image_url"]
                     if isinstance(img_url, dict):
@@ -299,7 +316,7 @@ def _convert_to_parts(
                             f"Unrecognized message image format: {img_url}"
                         )
                     img_url = img_url["url"]
-                parts.append({"inline_data": _url_to_pil(img_url)})
+                parts.append(image_loader.load_part(img_url))
             else:
                 raise ValueError(f"Unrecognized message part type: {part['type']}")
         else:
@@ -307,7 +324,7 @@ def _convert_to_parts(
                 logger.warning(
                     "Unrecognized message part format. Assuming it's a text part."
                 )
-                parts.append(part)
+                parts.append(Part(text=str(part)))
         else:
             # TODO: Maybe some of Google's native stuff
             # would hit this branch.
@@ -319,33 +336,36 @@ def _convert_to_parts(
 
 def _parse_chat_history(
     input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
-) -> Tuple[Optional[genai.types.ContentDict], List[genai.types.ContentDict]]:
-    messages: List[genai.types.MessageDict] = []
+) -> Tuple[Optional[Content], List[Content]]:
+    messages: List[Content] = []
 
     if convert_system_message_to_human:
         warnings.warn("Convert_system_message_to_human will be deprecated!")
 
-    system_instruction: Optional[genai.types.ContentDict] = None
+    system_instruction: Optional[Content] = None
     for i, message in enumerate(input_messages):
         if i == 0 and isinstance(message, SystemMessage):
-            system_instruction = _convert_to_parts(message.content)
+            system_instruction = Content(parts=_convert_to_parts(message.content))
             continue
         elif isinstance(message, AIMessage):
            role = "model"
            raw_function_call = message.additional_kwargs.get("function_call")
            if raw_function_call:
-                function_call = glm.FunctionCall(
+                function_call = FunctionCall(
                     {
                         "name": raw_function_call["name"],
                         "args": json.loads(raw_function_call["arguments"]),
                     }
                 )
-                parts = [glm.Part(function_call=function_call)]
+                parts = [Part(function_call=function_call)]
             else:
                 parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message.content)
+            if i == 1 and convert_system_message_to_human and system_instruction:
+                parts = [p for p in system_instruction.parts] + parts
+                system_instruction = None
         elif isinstance(message, FunctionMessage):
             role = "user"
             response: Any
@@ -357,8 +377,8 @@ def _parse_chat_history(
             except json.JSONDecodeError:
                 response = message.content  # leave as str representation
             parts = [
-                glm.Part(
-                    function_response=glm.FunctionResponse(
+                Part(
+                    function_response=FunctionResponse(
                         name=message.name,
                         response=(
                             {"output": response}
@@ -391,8 +411,8 @@ def _parse_chat_history(
             except json.JSONDecodeError:
                 tool_response = message.content  # leave as str representation
             parts = [
-                glm.Part(
-                    function_response=glm.FunctionResponse(
+                Part(
+                    function_response=FunctionResponse(
                         name=name,
                         response=(
                             {"output": tool_response}
@@ -407,12 +427,12 @@ def _parse_chat_history(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
 
-        messages.append({"role": role, "parts": parts})
+        messages.append(Content(role=role, parts=parts))
     return system_instruction, messages
 
 
 def _parse_response_candidate(
-    response_candidate: glm.Candidate, streaming: bool = False
+    response_candidate: Candidate, streaming: bool = False
 ) -> AIMessage:
     content: Union[None, str, List[str]] = None
     additional_kwargs = {}
@@ -499,7 +519,7 @@ def _parse_response_candidate(
 
 
 def _response_to_result(
-    response: glm.GenerateContentResponse,
+    response: GenerateContentResponse,
     stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
@@ -557,6 +577,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """
 
     client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
+    default_metadata: Sequence[Tuple[str, str]] = Field(
+        default_factory=list
+    )  #: :meta private:
 
     convert_system_message_to_human: bool = False
     """Whether to merge any leading SystemMessage into the following HumanMessage.
@@ -582,29 +606,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
-        additional_headers = values.get("additional_headers") or {}
-        default_metadata = tuple(additional_headers.items())
-
-        if values.get("credentials"):
-            genai.configure(
-                credentials=values.get("credentials"),
-                transport=values.get("transport"),
-                client_options=values.get("client_options"),
-                default_metadata=default_metadata,
-            )
-        else:
-            google_api_key = get_from_dict_or_env(
-                values, "google_api_key", "GOOGLE_API_KEY"
-            )
-            if isinstance(google_api_key, SecretStr):
-                google_api_key = google_api_key.get_secret_value()
-
-            genai.configure(
-                api_key=google_api_key,
-                transport=values.get("transport"),
-                client_options=values.get("client_options"),
-                default_metadata=default_metadata,
-            )
         if (
             values.get("temperature") is not None
             and not 0 <= values["temperature"] <= 1
@@ -616,8 +617,36 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
 
         if values.get("top_k") is not None and values["top_k"] <= 0:
             raise ValueError("top_k must be positive")
-        model = values["model"]
-        values["client"] = genai.GenerativeModel(model_name=model)
+
+        if not values["model"].startswith("models/"):
+            values["model"] = f"models/{values['model']}"
+
+        additional_headers = values.get("additional_headers") or {}
+        values["default_metadata"] = tuple(additional_headers.items())
+        client_info = get_client_info("ChatGoogleGenerativeAI")
+        google_api_key = None
+        if not values.get("credentials"):
+            google_api_key = get_from_dict_or_env(
+                values, "google_api_key", "GOOGLE_API_KEY"
+            )
+            if isinstance(google_api_key, SecretStr):
+                google_api_key = google_api_key.get_secret_value()
+        transport: Optional[str] = values.get("transport")
+        values["client"] = genaix.build_generative_service(
+            credentials=values.get("credentials"),
+            api_key=google_api_key,
+            client_info=client_info,
+            client_options=values.get("client_options"),
+            transport=transport,
+        )
+        values["async_client"] = genaix.build_generative_async_service(
+            credentials=values.get("credentials"),
+            api_key=google_api_key,
+            client_info=client_info,
+            client_options=values.get("client_options"),
+            transport=transport,
+        )
+
         return values
 
     @property
@@ -632,8 +661,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         }
 
     def _prepare_params(
-        self, stop: Optional[List[str]], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self,
+        stop: Optional[List[str]],
+        generation_config: Optional[Dict[str, Any]] = None,
+    ) -> GenerationConfig:
         gen_config = {
             k: v
             for k, v in {
@@ -646,27 +677,37 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             }.items()
             if v is not None
         }
-        if "generation_config" in kwargs:
-            gen_config = {**gen_config, **kwargs.pop("generation_config")}
-        params = {"generation_config": gen_config, **kwargs}
-        return params
+        if generation_config:
+            gen_config = {**gen_config, **generation_config}
+        return GenerationConfig(**gen_config)
 
     def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(
+        request = self._prepare_request(
             messages,
             stop=stop,
-            **kwargs,
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
-        response: genai.types.GenerateContentResponse = _chat_with_retry(
-            content=message,
-            **params,
-            generation_method=chat.send_message,
+        response: GenerateContentResponse = _chat_with_retry(
+            request=request,
+            **kwargs,
+            generation_method=self.client.generate_content,
+            metadata=self.default_metadata,
         )
         return _response_to_result(response)
 
@@ -675,17 +716,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(
+        request = self._prepare_request(
             messages,
             stop=stop,
-            **kwargs,
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
-        response: genai.types.GenerateContentResponse = await _achat_with_retry(
-            content=message,
-            **params,
-            generation_method=chat.send_message_async,
+        response: GenerateContentResponse = await _achat_with_retry(
+            request=request,
+            **kwargs,
+            generation_method=self.async_client.generate_content,
+            metadata=self.default_metadata,
         )
         return _response_to_result(response)
 
@@ -694,18 +746,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(
+        request = self._prepare_request(
             messages,
             stop=stop,
-            **kwargs,
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
-        response: genai.types.GenerateContentResponse = _chat_with_retry(
-            content=message,
-            **params,
-            generation_method=chat.send_message,
-            stream=True,
+        response: GenerateContentResponse = _chat_with_retry(
+            request=request,
+            generation_method=self.client.stream_generate_content,
+            **kwargs,
+            metadata=self.default_metadata,
         )
         for chunk in response:
             _chat_result = _response_to_result(chunk, stream=True)
720
782
  messages: List[BaseMessage],
721
783
  stop: Optional[List[str]] = None,
722
784
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
785
+ *,
786
+ tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
787
+ functions: Optional[Sequence[FunctionDeclarationType]] = None,
788
+ safety_settings: Optional[SafetySettingDict] = None,
789
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
790
+ generation_config: Optional[Dict[str, Any]] = None,
723
791
  **kwargs: Any,
724
792
  ) -> AsyncIterator[ChatGenerationChunk]:
725
- params, chat, message = self._prepare_chat(
793
+ request = self._prepare_request(
726
794
  messages,
727
795
  stop=stop,
728
- **kwargs,
796
+ tools=tools,
797
+ functions=functions,
798
+ safety_settings=safety_settings,
799
+ tool_config=tool_config,
800
+ generation_config=generation_config,
729
801
  )
730
802
  async for chunk in await _achat_with_retry(
731
- content=message,
732
- **params,
733
- generation_method=chat.send_message_async,
734
- stream=True,
803
+ request=request,
804
+ generation_method=self.async_client.stream_generate_content,
805
+ **kwargs,
806
+ metadata=self.default_metadata,
735
807
  ):
736
808
  _chat_result = _response_to_result(chunk, stream=True)
737
809
  gen = cast(ChatGenerationChunk, _chat_result.generations[0])
@@ -740,17 +812,17 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
740
812
  await run_manager.on_llm_new_token(gen.text)
741
813
  yield gen
742
814
 
743
- def _prepare_chat(
815
+ def _prepare_request(
744
816
  self,
745
817
  messages: List[BaseMessage],
818
+ *,
746
819
  stop: Optional[List[str]] = None,
747
820
  tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
748
821
  functions: Optional[Sequence[FunctionDeclarationType]] = None,
749
822
  safety_settings: Optional[SafetySettingDict] = None,
750
823
  tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
751
- **kwargs: Any,
752
- ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
753
- client = self.client
824
+ generation_config: Optional[Dict[str, Any]] = None,
825
+ ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
754
826
  formatted_tools = None
755
827
  if tools:
756
828
  formatted_tools = [
@@ -759,25 +831,35 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         elif functions:
             formatted_tools = [convert_to_genai_function_declarations(functions)]
 
-        if formatted_tools or safety_settings:
-            client = genai.GenerativeModel(
-                model_name=self.model,
-                tools=formatted_tools,
-                safety_settings=safety_settings,
-            )
-
-        params = self._prepare_params(stop, tool_config=tool_config, **kwargs)
         system_instruction, history = _parse_chat_history(
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
-        message = history.pop()
-        if self.client._system_instruction != system_instruction:
-            self.client = genai.GenerativeModel(
-                model_name=self.model, system_instruction=system_instruction
+        formatted_tool_config = None
+        if tool_config:
+            formatted_tool_config = ToolConfig(
+                function_calling_config=tool_config["function_calling_config"]
             )
-        chat = client.start_chat(history=history)
-        return params, chat, message
+        formatted_safety_settings = []
+        if safety_settings:
+            formatted_safety_settings = [
+                SafetySetting(category=c, threshold=t)
+                for c, t in safety_settings.items()
+            ]
+        request = GenerateContentRequest(
+            model=self.model,
+            contents=history,
+            tools=formatted_tools,
+            tool_config=formatted_tool_config,
+            safety_settings=formatted_safety_settings,
+            generation_config=self._prepare_params(
+                stop, generation_config=generation_config
+            ),
+        )
+        if system_instruction:
+            request.system_instruction = system_instruction
+
+        return request
 
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text.
790
872
  Returns:
791
873
  The integer number of tokens in the text.
792
874
  """
793
- if self._model_family == GoogleModelFamily.GEMINI:
794
- result = self.client.count_tokens(text)
795
- token_count = result.total_tokens
796
- else:
797
- result = self.client.count_text_tokens(model=self.model, prompt=text)
798
- token_count = result["token_count"]
799
-
800
- return token_count
875
+ result = self.client.count_tokens(
876
+ model=self.model, contents=[Content(parts=[Part(text=text)])]
877
+ )
878
+ return result.total_tokens
801
879
 
802
880
  def bind_tools(
803
881
  self,
@@ -828,7 +906,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
         if tool_choice:
             all_names = [
-                f["name"] for t in genai_tools for f in t["function_declarations"]
+                f["name"]  # type: ignore[index]
+                for t in genai_tools
+                for f in t["function_declarations"]
             ]
             tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
         return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
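A sketch of `bind_tools` with the `tool_choice` plumbing shown above, reusing the hypothetical `GetWeather` model from earlier (assumes `GOOGLE_API_KEY` is set; exact `tool_choice` forms depend on `_tool_choice_to_tool_config`):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")

# Naming a tool restricts allowed_function_names to that tool;
# tool_choice="any" would instead force a call to some bound tool.
llm_with_tools = llm.bind_tools([GetWeather], tool_choice="GetWeather")
ai_msg = llm_with_tools.invoke("What's the weather in Paris?")
print(ai_msg.tool_calls)
```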
--- a/langchain_google_genai/embeddings.py
+++ b/langchain_google_genai/embeddings.py
@@ -1,12 +1,19 @@
 from typing import Any, Dict, List, Optional
 
 # TODO: remove ignore once the google package is published with types
-import google.generativeai as genai  # type: ignore[import]
+from google.ai.generativelanguage_v1beta.types import (
+    BatchEmbedContentsRequest,
+    EmbedContentRequest,
+)
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
-from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai._common import (
+    GoogleGenerativeAIError,
+    get_client_info,
+)
+from langchain_google_genai._genai_extension import build_generative_service
 
 
 class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
@@ -27,6 +34,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
         embeddings.embed_query("What's our Q1 revenue?")
     """
 
+    client: Any  #: :meta private:
     model: str = Field(
         ...,
         description="The name of the embedding model to use. "
@@ -70,44 +78,43 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
-        if values.get("credentials"):
-            genai.configure(
-                credentials=values.get("credentials"),
-                transport=values.get("transport"),
-                client_options=values.get("client_options"),
-            )
-        else:
-            google_api_key = get_from_dict_or_env(
-                values, "google_api_key", "GOOGLE_API_KEY"
-            )
-            if isinstance(google_api_key, SecretStr):
-                google_api_key = google_api_key.get_secret_value()
-
-            genai.configure(
-                api_key=google_api_key,
-                transport=values.get("transport"),
-                client_options=values.get("client_options"),
-            )
+        google_api_key = get_from_dict_or_env(
+            values, "google_api_key", "GOOGLE_API_KEY"
+        )
+        client_info = get_client_info("GoogleGenerativeAIEmbeddings")
+
+        values["client"] = build_generative_service(
+            credentials=values.get("credentials"),
+            api_key=google_api_key,
+            client_info=client_info,
+            client_options=values.get("client_options"),
+        )
         return values
 
-    def _embed(
-        self, texts: List[str], task_type: str, title: Optional[str] = None
-    ) -> List[List[float]]:
-        task_type = self.task_type or "retrieval_document"
-        try:
-            result = genai.embed_content(
-                model=self.model,
-                content=texts,
-                task_type=task_type,
-                title=title,
-                request_options=self.request_options,
-            )
-        except Exception as e:
-            raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
-        return result["embedding"]
+    def _prepare_request(
+        self,
+        text: str,
+        task_type: Optional[str] = None,
+        title: Optional[str] = None,
+        output_dimensionality: Optional[int] = None,
+    ) -> EmbedContentRequest:
+        task_type = self.task_type or task_type or "RETRIEVAL_DOCUMENT"
+        # https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
+        request = EmbedContentRequest(
+            content={"parts": [{"text": text}]},
+            model=self.model,
+            task_type=task_type.upper(),
+            title=title,
+            output_dimensionality=output_dimensionality,
+        )
+        return request
 
     def embed_documents(
-        self, texts: List[str], batch_size: int = 5
+        self,
+        texts: List[str],
+        task_type: Optional[str] = None,
+        titles: Optional[List[str]] = None,
+        output_dimensionality: Optional[int] = None,
     ) -> List[List[float]]:
         """Embed a list of strings. Vertex AI currently
         sets a max batch size of 5 strings.
@@ -115,21 +122,58 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
         Args:
             texts: List[str] The list of strings to embed.
             batch_size: [int] The batch size of embeddings to send to the model
+            task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
+            titles: An optional list of titles for texts provided.
+                Only applicable when TaskType is RETRIEVAL_DOCUMENT.
+            output_dimensionality: Optional reduced dimension for the output embedding.
+                https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
 
         Returns:
             List of embeddings, one for each text.
         """
-        task_type = self.task_type or "retrieval_document"
-        return self._embed(texts, task_type=task_type)
+        titles = titles if titles else [None] * len(texts)  # type: ignore[list-item]
+        requests = [
+            self._prepare_request(
+                text=text,
+                task_type=task_type,
+                title=title,
+                output_dimensionality=output_dimensionality,
+            )
+            for text, title in zip(texts, titles)
+        ]
 
-    def embed_query(self, text: str) -> List[float]:
+        try:
+            result = self.client.batch_embed_contents(
+                BatchEmbedContentsRequest(requests=requests, model=self.model)
+            )
+        except Exception as e:
+            raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
+        return [e.values for e in result.embeddings]
+
+    def embed_query(
+        self,
+        text: str,
+        task_type: Optional[str] = None,
+        title: Optional[str] = None,
+        output_dimensionality: Optional[int] = None,
+    ) -> List[float]:
         """Embed a text.
 
         Args:
             text: The text to embed.
+            task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
+            title: An optional title for the text.
+                Only applicable when TaskType is RETRIEVAL_DOCUMENT.
+            output_dimensionality: Optional reduced dimension for the output embedding.
+                https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
 
         Returns:
             Embedding for the text.
         """
-        task_type = self.task_type or "retrieval_query"
-        return self._embed([text], task_type=task_type)[0]
+        task_type = self.task_type or "RETRIEVAL_QUERY"
+        return self.embed_documents(
+            [text],
+            task_type=task_type,
+            titles=[title] if title else None,
+            output_dimensionality=output_dimensionality,
+        )[0]
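The new per-call embedding knobs, sketched (assumes `GOOGLE_API_KEY` is set; whether the service honors `output_dimensionality` depends on the embedding model):

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

# Batched under the hood via BatchEmbedContentsRequest.
vectors = embeddings.embed_documents(
    ["quarterly report", "annual report"],
    task_type="RETRIEVAL_DOCUMENT",
    titles=["Q1", "FY23"],
    output_dimensionality=256,
)
query_vec = embeddings.embed_query("What's our Q1 revenue?", output_dimensionality=256)
```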
--- a/langchain_google_genai/llms.py
+++ b/langchain_google_genai/llms.py
@@ -174,7 +174,6 @@ Supported examples:
         from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
 
         safety_settings = {
-            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
--- a/langchain_google_genai-1.0.3.dist-info/METADATA
+++ b/langchain_google_genai-1.0.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langchain-google-genai
-Version: 1.0.3
+Version: 1.0.4
 Summary: An integration package connecting Google's genai package and LangChain
 Home-page: https://github.com/langchain-ai/langchain-google
 License: MIT
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Provides-Extra: images
 Requires-Dist: google-generativeai (>=0.5.2,<0.6.0)
-Requires-Dist: langchain-core (>=0.1.45,<0.2)
+Requires-Dist: langchain-core (>=0.1.45,<0.3)
 Requires-Dist: pillow (>=10.1.0,<11.0.0) ; extra == "images"
 Project-URL: Repository, https://github.com/langchain-ai/langchain-google
 Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
--- /dev/null
+++ b/langchain_google_genai-1.0.4.dist-info/RECORD
@@ -0,0 +1,16 @@
+langchain_google_genai/__init__.py,sha256=Oji-S2KYWrku1wyQEskY84IOfY8MfRhujjJ4d7hbsk4,2758
+langchain_google_genai/_common.py,sha256=ASlwE8hEbvOm55BVF_D4rf2nl7RYsnpsi5xbM6DW3Cc,1576
+langchain_google_genai/_enums.py,sha256=KLPmxS1K83K4HjBIXFaXoL_sFEOv8Hq-2B2PDMKyDgo,197
+langchain_google_genai/_function_utils.py,sha256=d0ApSCjoV9Em1CteBaGznilxrw-PDXqQ4sQa5p7cJfM,8232
+langchain_google_genai/_genai_extension.py,sha256=ZwNwLV22RSf9LB7FOCLsoHzLlQDF-EQmRNYM1an2uSw,20769
+langchain_google_genai/_image_utils.py,sha256=-0XgCMdYkvrIktFvUpy-2GPbFgfSVKZICawB2hiJzus,4999
+langchain_google_genai/chat_models.py,sha256=3ubY5qCaZjFSHKHiiP5XCOPDusUxO8kJvpj9DwLtUG4,33014
+langchain_google_genai/embeddings.py,sha256=kQW6pl1TUGKKSxiUjc7rptp0iEu_Rer1m1LLHKVaW14,6578
+langchain_google_genai/genai_aqa.py,sha256=zcC5cdFYtqLK7DGPhYGvWNeHHeU-CQKA9KhewmsA5lw,4303
+langchain_google_genai/google_vector_store.py,sha256=PPIk-4FmD5UUdmYA2u7VcEhGsiztvRVN59QoGLXdfoA,16139
+langchain_google_genai/llms.py,sha256=S7tOy-c37DElcHtkGl8rwvvg1zOzCxb9PEyJ4E-j7qU,13431
+langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+langchain_google_genai-1.0.4.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
+langchain_google_genai-1.0.4.dist-info/METADATA,sha256=nHJMXwe6iuI99J_VO1GQrVgv1BLvVGW5x0mSNswDWq0,3818
+langchain_google_genai-1.0.4.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+langchain_google_genai-1.0.4.dist-info/RECORD,,
--- a/langchain_google_genai-1.0.3.dist-info/RECORD
+++ /dev/null
@@ -1,15 +0,0 @@
-langchain_google_genai/__init__.py,sha256=NiLfU4IyEQSlSl8hbN_fTdX-hrO5tuOflxFlEK0Sy4c,2762
-langchain_google_genai/_common.py,sha256=1r0VrrBSTZfGprmICZ5OV-W5SK31jKRFFCNE3vJ3jmk,136
-langchain_google_genai/_enums.py,sha256=q8IYAqufV-_yZ98FDnsZ3x-1w4804J_e8PrTKT0sdhY,163
-langchain_google_genai/_function_utils.py,sha256=uDWg2Gcuv3PtdUL14sDOVTFbq8gaaGJ360bL4IbN4AI,8705
-langchain_google_genai/_genai_extension.py,sha256=2Uqg7vSF0vu1J4AhAyIPzadtpM5JJwZsBXvpItO2TY4,18736
-langchain_google_genai/chat_models.py,sha256=edkRRStq42pxY1xBzfC2yUfN-jBsM-Vr1dx9QVg-qd4,29969
-langchain_google_genai/embeddings.py,sha256=QZJRd5xQkGzalAGiKeorrsnVmsyaO4NGmGuzQFDoRe0,4807
-langchain_google_genai/genai_aqa.py,sha256=zcC5cdFYtqLK7DGPhYGvWNeHHeU-CQKA9KhewmsA5lw,4303
-langchain_google_genai/google_vector_store.py,sha256=PPIk-4FmD5UUdmYA2u7VcEhGsiztvRVN59QoGLXdfoA,16139
-langchain_google_genai/llms.py,sha256=QPjCs0AHb-2d3GMCS3sZqzG_Yr71YbaGX_Vs8QlLTKU,13518
-langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-langchain_google_genai-1.0.3.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
-langchain_google_genai-1.0.3.dist-info/METADATA,sha256=IPH2elmUC6Rt7ZG-g5eyma97A8r48jftkeq9v8XyQMc,3818
-langchain_google_genai-1.0.3.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
-langchain_google_genai-1.0.3.dist-info/RECORD,,