langchain-google-genai 2.1.6__tar.gz → 2.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic; see the package's registry page for more details.

Files changed (16):
  1. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/PKG-INFO +2 -2
  2. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/__init__.py +26 -24
  3. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/_common.py +8 -8
  4. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/_genai_extension.py +5 -5
  5. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/_image_utils.py +3 -3
  6. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/chat_models.py +46 -60
  7. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/embeddings.py +124 -20
  8. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/llms.py +12 -0
  9. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/pyproject.toml +5 -5
  10. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/LICENSE +0 -0
  11. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/README.md +0 -0
  12. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/_enums.py +0 -0
  13. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/_function_utils.py +0 -0
  14. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/genai_aqa.py +0 -0
  15. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/google_vector_store.py +0 -0
  16. {langchain_google_genai-2.1.6 → langchain_google_genai-2.1.8}/langchain_google_genai/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langchain-google-genai
3
- Version: 2.1.6
3
+ Version: 2.1.8
4
4
  Summary: An integration package connecting Google's genai package and LangChain
5
5
  Home-page: https://github.com/langchain-ai/langchain-google
6
6
  License: MIT
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Requires-Dist: filetype (>=1.2.0,<2.0.0)
15
15
  Requires-Dist: google-ai-generativelanguage (>=0.6.18,<0.7.0)
16
- Requires-Dist: langchain-core (>=0.3.66,<0.4.0)
16
+ Requires-Dist: langchain-core (>=0.3.68,<0.4.0)
17
17
  Requires-Dist: pydantic (>=2,<3)
18
18
  Project-URL: Repository, https://github.com/langchain-ai/langchain-google
19
19
  Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
@@ -4,55 +4,57 @@ This module integrates Google's Generative AI models, specifically the Gemini se
4
4
 
5
5
  **Chat Models**
6
6
 
7
- The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
7
+ The ``ChatGoogleGenerativeAI`` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
8
8
 
9
9
  **LLMs**
10
10
 
11
- The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
11
+ The ``GoogleGenerativeAI`` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
12
12
 
13
13
  **Embeddings**
14
14
 
15
- The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
15
+ The ``GoogleGenerativeAIEmbeddings`` class provides functionalities to generate embeddings using Google's models.
16
16
  These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
17
+
17
18
  **Installation**
18
19
 
19
20
  To install the package, use pip:
20
21
 
21
- ```python
22
- pip install -U langchain-google-genai
23
- ```
24
- ## Using Chat Models
22
+ .. code-block:: python
23
+ pip install -U langchain-google-genai
24
+
25
+ **Using Chat Models**
25
26
 
26
27
  After setting up your environment with the required API key, you can interact with the Google Gemini models.
27
28
 
28
- ```python
29
- from langchain_google_genai import ChatGoogleGenerativeAI
29
+ .. code-block:: python
30
+
31
+ from langchain_google_genai import ChatGoogleGenerativeAI
30
32
 
31
- llm = ChatGoogleGenerativeAI(model="gemini-pro")
32
- llm.invoke("Sing a ballad of LangChain.")
33
- ```
33
+ llm = ChatGoogleGenerativeAI(model="gemini-pro")
34
+ llm.invoke("Sing a ballad of LangChain.")
34
35
 
35
- ## Using LLMs
36
+ **Using LLMs**
36
37
 
37
38
  The package also supports generating text with Google's models.
38
39
 
39
- ```python
40
- from langchain_google_genai import GoogleGenerativeAI
40
+ .. code-block:: python
41
41
 
42
- llm = GoogleGenerativeAI(model="gemini-pro")
43
- llm.invoke("Once upon a time, a library called LangChain")
44
- ```
42
+ from langchain_google_genai import GoogleGenerativeAI
45
43
 
46
- ## Embedding Generation
44
+ llm = GoogleGenerativeAI(model="gemini-pro")
45
+ llm.invoke("Once upon a time, a library called LangChain")
46
+
47
+ **Embedding Generation**
47
48
 
48
49
  The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
49
50
 
50
- ```python
51
- from langchain_google_genai import GoogleGenerativeAIEmbeddings
51
+ .. code-block:: python
52
+
53
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
54
+
55
+ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
56
+ embeddings.embed_query("hello, world!")
52
57
 
53
- embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
54
- embeddings.embed_query("hello, world!")
55
- ```
56
58
  """ # noqa: E501
57
59
 
58
60
  from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory, Modality
@@ -39,20 +39,19 @@ Supported examples:
39
39
  "when making API calls. If not provided, credentials will be ascertained from "
40
40
  "the GOOGLE_API_KEY envvar"
41
41
  temperature: float = 0.7
42
- """Run inference with this temperature. Must by in the closed interval
43
- [0.0, 2.0]."""
42
+ """Run inference with this temperature. Must be within ``[0.0, 2.0]``."""
44
43
  top_p: Optional[float] = None
45
44
  """Decode using nucleus sampling: consider the smallest set of tokens whose
46
- probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
45
+ probability sum is at least ``top_p``. Must be within ``[0.0, 1.0]``."""
47
46
  top_k: Optional[int] = None
48
- """Decode using top-k sampling: consider the set of top_k most probable tokens.
47
+ """Decode using top-k sampling: consider the set of ``top_k`` most probable tokens.
49
48
  Must be positive."""
50
49
  max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
51
50
  """Maximum number of tokens to include in a candidate. Must be greater than zero.
52
- If unset, will default to 64."""
51
+ If unset, will default to ``64``."""
53
52
  n: int = 1
54
53
  """Number of chat completions to generate for each prompt. Note that the API may
55
- not return the full n completions if duplicates are generated."""
54
+ not return the full ``n`` completions if duplicates are generated."""
56
55
  max_retries: int = 6
57
56
  """The maximum number of retries to make when generating."""
58
57
 
@@ -94,6 +93,7 @@ Supported examples:
94
93
 
95
94
  For example:
96
95
 
96
+ .. code-block:: python
97
97
  from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
98
98
 
99
99
  safety_settings = {
@@ -102,7 +102,7 @@ Supported examples:
102
102
  HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
103
103
  HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
104
104
  }
105
- """ # noqa: E501
105
+ """ # noqa: E501
106
106
 
107
107
  @property
108
108
  def lc_secrets(self) -> Dict[str, str]:
@@ -149,7 +149,7 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
149
149
  module (Optional[str]):
150
150
  Optional. The module for a custom user agent header.
151
151
  Returns:
152
- google.api_core.gapic_v1.client_info.ClientInfo
152
+ ``google.api_core.gapic_v1.client_info.ClientInfo``
153
153
  """
154
154
  client_library_version, user_agent = get_user_agent(module)
155
155
  return ClientInfo(
@@ -174,12 +174,12 @@ class TestCredentials(credentials.Credentials):
174
174
 
175
175
  @property
176
176
  def expired(self) -> bool:
177
- """Returns `False`, test credentials never expire."""
177
+ """Returns ``False``, test credentials never expire."""
178
178
  return False
179
179
 
180
180
  @property
181
181
  def valid(self) -> bool:
182
- """Returns `True`, test credentials are always valid."""
182
+ """Returns ``True``, test credentials are always valid."""
183
183
  return True
184
184
 
185
185
  def refresh(self, request: Any) -> None:
@@ -206,11 +206,11 @@ class TestCredentials(credentials.Credentials):
206
206
  def _get_credentials() -> Optional[credentials.Credentials]:
207
207
  """Returns credential from config if set or fake credentials for unit testing.
208
208
 
209
- If _config.testing is True, a fake credential is returned.
209
+ If ``_config.testing`` is ``True``, a fake credential is returned.
210
210
  Otherwise, we are in a real environment and will use credentials if provided
211
- or None is returned.
211
+ or ``None`` is returned.
212
212
 
213
- If None is passed to the clients later on, the actual credentials will be
213
+ If ``None`` is passed to the clients later on, the actual credentials will be
214
214
  inferred by the rules specified in google.auth package.
215
215
  """
216
216
  if _config.testing:
@@ -30,7 +30,7 @@ class ImageBytesLoader:
30
30
  """
31
31
 
32
32
  def load_bytes(self, image_string: str) -> bytes:
33
- """Routes to the correct loader based on the image_string.
33
+ """Routes to the correct loader based on the ``'image_string'``.
34
34
 
35
35
  Args:
36
36
  image_string: Can be either:
@@ -178,8 +178,8 @@ def image_bytes_to_b64_string(
178
178
 
179
179
  Args:
180
180
  image_bytes: Bytes of the image.
181
- encoding: Type of encoding in the string. 'ascii' by default.
182
- image_format: Format of the image. 'png' by default.
181
+ encoding: Type of encoding in the string. ``'ascii'`` by default.
182
+ image_format: Format of the image. ``'png'`` by default.
183
183
 
184
184
  Returns:
185
185
  B64 image encoded string.
@@ -31,7 +31,7 @@ from typing import (
31
31
  import filetype # type: ignore[import]
32
32
  import google.api_core
33
33
 
34
- # TODO: remove ignore once the google package is published with types
34
+ # TODO: remove ignore once the Google package is published with types
35
35
  import proto # type: ignore[import]
36
36
  from google.ai.generativelanguage_v1beta import (
37
37
  GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
@@ -72,7 +72,7 @@ from langchain_core.messages import (
72
72
  ToolMessage,
73
73
  is_data_content_block,
74
74
  )
75
- from langchain_core.messages.ai import UsageMetadata
75
+ from langchain_core.messages.ai import UsageMetadata, add_usage, subtract_usage
76
76
  from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
77
77
  from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
78
78
  from langchain_core.output_parsers.base import OutputParserLike
@@ -295,7 +295,7 @@ def _is_openai_image_block(block: dict) -> bool:
295
295
  def _convert_to_parts(
296
296
  raw_content: Union[str, Sequence[Union[str, dict]]],
297
297
  ) -> List[Part]:
298
- """Converts a list of LangChain messages into a google parts."""
298
+ """Converts a list of LangChain messages into a Google parts."""
299
299
  parts = []
300
300
  content = [raw_content] if isinstance(raw_content, str) else raw_content
301
301
  image_loader = ImageBytesLoader()
@@ -413,7 +413,7 @@ def _convert_to_parts(
413
413
  def _convert_tool_message_to_parts(
414
414
  message: ToolMessage | FunctionMessage, name: Optional[str] = None
415
415
  ) -> list[Part]:
416
- """Converts a tool or function message to a google part."""
416
+ """Converts a tool or function message to a Google part."""
417
417
  # Legacy agent stores tool name in message.additional_kwargs instead of message.name
418
418
  name = message.name or name or message.additional_kwargs.get("name")
419
419
  response: Any
@@ -716,35 +716,43 @@ def _response_to_result(
716
716
  """Converts a PaLM API response into a LangChain ChatResult."""
717
717
  llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
718
718
 
719
- # previous usage metadata needs to be subtracted because gemini api returns
720
- # already-accumulated token counts with each chunk
721
- prev_input_tokens = prev_usage["input_tokens"] if prev_usage else 0
722
- prev_output_tokens = prev_usage["output_tokens"] if prev_usage else 0
723
- prev_total_tokens = prev_usage["total_tokens"] if prev_usage else 0
724
-
725
719
  # Get usage metadata
726
720
  try:
727
721
  input_tokens = response.usage_metadata.prompt_token_count
728
- output_tokens = response.usage_metadata.candidates_token_count
729
- total_tokens = response.usage_metadata.total_token_count
730
722
  thought_tokens = response.usage_metadata.thoughts_token_count
723
+ output_tokens = response.usage_metadata.candidates_token_count + thought_tokens
724
+ total_tokens = response.usage_metadata.total_token_count
731
725
  cache_read_tokens = response.usage_metadata.cached_content_token_count
732
726
  if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
733
727
  if thought_tokens > 0:
734
- lc_usage = UsageMetadata(
735
- input_tokens=input_tokens - prev_input_tokens,
736
- output_tokens=output_tokens - prev_output_tokens,
737
- total_tokens=total_tokens - prev_total_tokens,
728
+ cumulative_usage = UsageMetadata(
729
+ input_tokens=input_tokens,
730
+ output_tokens=output_tokens,
731
+ total_tokens=total_tokens,
738
732
  input_token_details={"cache_read": cache_read_tokens},
739
733
  output_token_details={"reasoning": thought_tokens},
740
734
  )
741
735
  else:
742
- lc_usage = UsageMetadata(
743
- input_tokens=input_tokens - prev_input_tokens,
744
- output_tokens=output_tokens - prev_output_tokens,
745
- total_tokens=total_tokens - prev_total_tokens,
736
+ cumulative_usage = UsageMetadata(
737
+ input_tokens=input_tokens,
738
+ output_tokens=output_tokens,
739
+ total_tokens=total_tokens,
746
740
  input_token_details={"cache_read": cache_read_tokens},
747
741
  )
742
+ # previous usage metadata needs to be subtracted because gemini api returns
743
+ # already-accumulated token counts with each chunk
744
+ lc_usage = subtract_usage(cumulative_usage, prev_usage)
745
+ if prev_usage and cumulative_usage["input_tokens"] < prev_usage.get(
746
+ "input_tokens", 0
747
+ ):
748
+ # Gemini 1.5 and 2.0 return a lower cumulative count of prompt tokens
749
+ # in the final chunk. We take this count to be ground truth because
750
+ # it's consistent with the reported total tokens. So we need to
751
+ # ensure this chunk compensates (the subtract_usage funcction floors
752
+ # at zero).
753
+ lc_usage["input_tokens"] = cumulative_usage[
754
+ "input_tokens"
755
+ ] - prev_usage.get("input_tokens", 0)
748
756
  else:
749
757
  lc_usage = None
750
758
  except AttributeError:
@@ -816,8 +824,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
816
824
  To use, you must have either:
817
825
 
818
826
  1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
819
- 2. Pass your API key using the google_api_key kwarg
820
- to the ChatGoogleGenerativeAI constructor.
827
+ 2. Pass your API key using the ``google_api_key`` kwarg to the ChatGoogleGenerativeAI constructor.
821
828
 
822
829
  .. code-block:: python
823
830
 
@@ -885,8 +892,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
885
892
 
886
893
  Context Caching:
887
894
  Context caching allows you to store and reuse content (e.g., PDFs, images) for faster processing.
888
- The `cached_content` parameter accepts a cache name created via the Google Generative AI API.
889
- Below are two examples: caching a single file directly and caching multiple files using `Part`.
895
+ The ``cached_content`` parameter accepts a cache name created via the Google Generative AI API.
896
+ Below are two examples: caching a single file directly and caching multiple files using ``Part``.
890
897
 
891
898
  Single File Example:
892
899
  This caches a single file and queries it.
@@ -1132,12 +1139,15 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1132
1139
 
1133
1140
  response_mime_type: Optional[str] = None
1134
1141
  """Optional. Output response mimetype of the generated candidate text. Only
1135
- supported in Gemini 1.5 and later models. Supported mimetype:
1136
- * "text/plain": (default) Text output.
1137
- * "application/json": JSON response in the candidates.
1138
- * "text/x.enum": Enum in plain text.
1139
- The model also needs to be prompted to output the appropriate response
1140
- type, otherwise the behavior is undefined. This is a preview feature.
1142
+ supported in Gemini 1.5 and later models.
1143
+
1144
+ Supported mimetype:
1145
+ * ``'text/plain'``: (default) Text output.
1146
+ * ``'application/json'``: JSON response in the candidates.
1147
+ * ``'text/x.enum'``: Enum in plain text.
1148
+
1149
+ The model also needs to be prompted to output the appropriate response
1150
+ type, otherwise the behavior is undefined. This is a preview feature.
1141
1151
  """
1142
1152
 
1143
1153
  response_schema: Optional[Dict[str, Any]] = None
@@ -1222,9 +1232,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1222
1232
  if self.top_k is not None and self.top_k <= 0:
1223
1233
  raise ValueError("top_k must be positive")
1224
1234
 
1225
- if not any(
1226
- self.model.startswith(prefix) for prefix in ("models/", "tunedModels/")
1227
- ):
1235
+ if not any(self.model.startswith(prefix) for prefix in ("models/",)):
1228
1236
  self.model = f"models/{self.model}"
1229
1237
 
1230
1238
  additional_headers = self.additional_headers or {}
@@ -1320,7 +1328,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1320
1328
 
1321
1329
  else:
1322
1330
  raise ValueError(
1323
- "Tools are already defined." "code_execution tool can't be defined"
1331
+ "Tools are already defined.code_execution tool can't be defined"
1324
1332
  )
1325
1333
 
1326
1334
  return super().invoke(input, config, stop=stop, **kwargs)
@@ -1522,7 +1530,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1522
1530
  metadata=self.default_metadata,
1523
1531
  )
1524
1532
 
1525
- prev_usage_metadata: UsageMetadata | None = None
1533
+ prev_usage_metadata: UsageMetadata | None = None # cumulative usage
1526
1534
  for chunk in response:
1527
1535
  _chat_result = _response_to_result(
1528
1536
  chunk, stream=True, prev_usage=prev_usage_metadata
@@ -1530,21 +1538,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1530
1538
  gen = cast(ChatGenerationChunk, _chat_result.generations[0])
1531
1539
  message = cast(AIMessageChunk, gen.message)
1532
1540
 
1533
- curr_usage_metadata: UsageMetadata | dict[str, int] = (
1534
- message.usage_metadata or {}
1535
- )
1536
-
1537
1541
  prev_usage_metadata = (
1538
1542
  message.usage_metadata
1539
1543
  if prev_usage_metadata is None
1540
- else UsageMetadata(
1541
- input_tokens=prev_usage_metadata.get("input_tokens", 0)
1542
- + curr_usage_metadata.get("input_tokens", 0),
1543
- output_tokens=prev_usage_metadata.get("output_tokens", 0)
1544
- + curr_usage_metadata.get("output_tokens", 0),
1545
- total_tokens=prev_usage_metadata.get("total_tokens", 0)
1546
- + curr_usage_metadata.get("total_tokens", 0),
1547
- )
1544
+ else add_usage(prev_usage_metadata, message.usage_metadata)
1548
1545
  )
1549
1546
 
1550
1547
  if run_manager:
@@ -1594,7 +1591,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1594
1591
  tool_choice=tool_choice,
1595
1592
  **kwargs,
1596
1593
  )
1597
- prev_usage_metadata: UsageMetadata | None = None
1594
+ prev_usage_metadata: UsageMetadata | None = None # cumulative usage
1598
1595
  async for chunk in await _achat_with_retry(
1599
1596
  request=request,
1600
1597
  generation_method=self.async_client.stream_generate_content,
@@ -1607,21 +1604,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1607
1604
  gen = cast(ChatGenerationChunk, _chat_result.generations[0])
1608
1605
  message = cast(AIMessageChunk, gen.message)
1609
1606
 
1610
- curr_usage_metadata: UsageMetadata | dict[str, int] = (
1611
- message.usage_metadata or {}
1612
- )
1613
-
1614
1607
  prev_usage_metadata = (
1615
1608
  message.usage_metadata
1616
1609
  if prev_usage_metadata is None
1617
- else UsageMetadata(
1618
- input_tokens=prev_usage_metadata.get("input_tokens", 0)
1619
- + curr_usage_metadata.get("input_tokens", 0),
1620
- output_tokens=prev_usage_metadata.get("output_tokens", 0)
1621
- + curr_usage_metadata.get("output_tokens", 0),
1622
- total_tokens=prev_usage_metadata.get("total_tokens", 0)
1623
- + curr_usage_metadata.get("total_tokens", 0),
1624
- )
1610
+ else add_usage(prev_usage_metadata, message.usage_metadata)
1625
1611
  )
1626
1612
 
1627
1613
  if run_manager:
@@ -17,7 +17,10 @@ from langchain_google_genai._common import (
17
17
  GoogleGenerativeAIError,
18
18
  get_client_info,
19
19
  )
20
- from langchain_google_genai._genai_extension import build_generative_service
20
+ from langchain_google_genai._genai_extension import (
21
+ build_generative_async_service,
22
+ build_generative_service,
23
+ )
21
24
 
22
25
  _MAX_TOKENS_PER_BATCH = 20000
23
26
  _DEFAULT_BATCH_SIZE = 100
@@ -29,8 +32,8 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
29
32
  To use, you must have either:
30
33
 
31
34
  1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
32
- 2. Pass your API key using the google_api_key kwarg
33
- to the GoogleGenerativeAIEmbeddings constructor.
35
+ 2. Pass your API key using the google_api_key kwarg to the
36
+ GoogleGenerativeAIEmbeddings constructor.
34
37
 
35
38
  Example:
36
39
  .. code-block:: python
@@ -42,16 +45,17 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
42
45
  """
43
46
 
44
47
  client: Any = None #: :meta private:
48
+ async_client: Any = None #: :meta private:
45
49
  model: str = Field(
46
50
  ...,
47
51
  description="The name of the embedding model to use. "
48
- "Example: models/embedding-001",
52
+ "Example: ``'models/embedding-001'``",
49
53
  )
50
54
  task_type: Optional[str] = Field(
51
55
  default=None,
52
56
  description="The task type. Valid options include: "
53
- "task_type_unspecified, retrieval_query, retrieval_document, "
54
- "semantic_similarity, classification, and clustering",
57
+ "``'task_type_unspecified'``, ``'retrieval_query'``, ``'retrieval_document'``, "
58
+ "``'semantic_similarity'``, ``'classification'``, and ``'clustering'``",
55
59
  )
56
60
  google_api_key: Optional[SecretStr] = Field(
57
61
  default_factory=secret_from_env("GOOGLE_API_KEY", default=None),
@@ -76,7 +80,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
76
80
  )
77
81
  transport: Optional[str] = Field(
78
82
  default=None,
79
- description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
83
+ description="A string, one of: [``'rest'``, ``'grpc'``, ``'grpc_asyncio'``].",
80
84
  )
81
85
  request_options: Optional[Dict] = Field(
82
86
  default=None,
@@ -93,6 +97,9 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
93
97
  google_api_key = self.google_api_key
94
98
  client_info = get_client_info("GoogleGenerativeAIEmbeddings")
95
99
 
100
+ if not any(self.model.startswith(prefix) for prefix in ("models/",)):
101
+ self.model = f"models/{self.model}"
102
+
96
103
  self.client = build_generative_service(
97
104
  credentials=self.credentials,
98
105
  api_key=google_api_key,
@@ -100,6 +107,13 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
100
107
  client_options=self.client_options,
101
108
  transport=self.transport,
102
109
  )
110
+ self.async_client = build_generative_async_service(
111
+ credentials=self.credentials,
112
+ api_key=google_api_key,
113
+ client_info=client_info,
114
+ client_options=self.client_options,
115
+ transport=self.transport,
116
+ )
103
117
  return self
104
118
 
105
119
  @staticmethod
@@ -166,12 +180,12 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
166
180
  def _prepare_request(
167
181
  self,
168
182
  text: str,
183
+ *,
169
184
  task_type: Optional[str] = None,
170
185
  title: Optional[str] = None,
171
186
  output_dimensionality: Optional[int] = None,
172
187
  ) -> EmbedContentRequest:
173
188
  task_type = self.task_type or task_type or "RETRIEVAL_DOCUMENT"
174
- # https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
175
189
  request = EmbedContentRequest(
176
190
  content={"parts": [{"text": text}]},
177
191
  model=self.model,
@@ -190,17 +204,17 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
190
204
  titles: Optional[List[str]] = None,
191
205
  output_dimensionality: Optional[int] = None,
192
206
  ) -> List[List[float]]:
193
- """Embed a list of strings. Google Generative AI currently
194
- sets a max batch size of 100 strings.
207
+ """Embed a list of strings using the `batch endpoint <https://ai.google.dev/api/embeddings#method:-models.batchembedcontents>`__.
208
+
209
+ Google Generative AI currently sets a max batch size of 100 strings.
195
210
 
196
211
  Args:
197
212
  texts: List[str] The list of strings to embed.
198
213
  batch_size: [int] The batch size of embeddings to send to the model
199
- task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
214
+ task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
200
215
  titles: An optional list of titles for texts provided.
201
- Only applicable when TaskType is RETRIEVAL_DOCUMENT.
202
- output_dimensionality: Optional reduced dimension for the output embedding.
203
- https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest
216
+ Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
217
+ output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.
204
218
  Returns:
205
219
  List of embeddings, one for each text.
206
220
  """
@@ -237,26 +251,26 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
237
251
  def embed_query(
238
252
  self,
239
253
  text: str,
254
+ *,
240
255
  task_type: Optional[str] = None,
241
256
  title: Optional[str] = None,
242
257
  output_dimensionality: Optional[int] = None,
243
258
  ) -> List[float]:
244
- """Embed a text, using the non-batch endpoint:
245
- https://ai.google.dev/api/rest/v1/models/embedContent#EmbedContentRequest
259
+ """Embed a text, using the `non-batch endpoint <https://ai.google.dev/api/embeddings#method:-models.embedcontent>`__.
246
260
 
247
261
  Args:
248
262
  text: The text to embed.
249
- task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
263
+ task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
250
264
  title: An optional title for the text.
251
- Only applicable when TaskType is RETRIEVAL_DOCUMENT.
252
- output_dimensionality: Optional reduced dimension for the output embedding.
265
+ Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
266
+ output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.
253
267
 
254
268
  Returns:
255
269
  Embedding for the text.
256
270
  """
257
271
  task_type_to_use = task_type if task_type else self.task_type
258
272
  if task_type_to_use is None:
259
- task_type_to_use = "RETRIEVAL_QUERY" # Default to RETRIEVAL_QUERY
273
+ task_type_to_use = "RETRIEVAL_QUERY"
260
274
  try:
261
275
  request: EmbedContentRequest = self._prepare_request(
262
276
  text=text,
@@ -268,3 +282,93 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
268
282
  except Exception as e:
269
283
  raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
270
284
  return list(result.embedding.values)
285
+
286
+ async def aembed_documents(
287
+ self,
288
+ texts: List[str],
289
+ *,
290
+ batch_size: int = _DEFAULT_BATCH_SIZE,
291
+ task_type: Optional[str] = None,
292
+ titles: Optional[List[str]] = None,
293
+ output_dimensionality: Optional[int] = None,
294
+ ) -> List[List[float]]:
295
+ """Embed a list of strings using the `batch endpoint <https://ai.google.dev/api/embeddings#method:-models.batchembedcontents>`__.
296
+
297
+ Google Generative AI currently sets a max batch size of 100 strings.
298
+
299
+ Args:
300
+ texts: List[str] The list of strings to embed.
301
+ batch_size: [int] The batch size of embeddings to send to the model
302
+ task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
303
+ titles: An optional list of titles for texts provided.
304
+ Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
305
+ output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.
306
+ Returns:
307
+ List of embeddings, one for each text.
308
+ """
309
+ embeddings: List[List[float]] = []
310
+ batch_start_index = 0
311
+ for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
312
+ if titles:
313
+ titles_batch = titles[
314
+ batch_start_index : batch_start_index + len(batch)
315
+ ]
316
+ batch_start_index += len(batch)
317
+ else:
318
+ titles_batch = [None] * len(batch) # type: ignore[list-item]
319
+
320
+ requests = [
321
+ self._prepare_request(
322
+ text=text,
323
+ task_type=task_type,
324
+ title=title,
325
+ output_dimensionality=output_dimensionality,
326
+ )
327
+ for text, title in zip(batch, titles_batch)
328
+ ]
329
+
330
+ try:
331
+ result = await self.async_client.batch_embed_contents(
332
+ BatchEmbedContentsRequest(requests=requests, model=self.model)
333
+ )
334
+ except Exception as e:
335
+ raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
336
+ embeddings.extend([list(e.values) for e in result.embeddings])
337
+ return embeddings
338
+
339
+ async def aembed_query(
340
+ self,
341
+ text: str,
342
+ *,
343
+ task_type: Optional[str] = None,
344
+ title: Optional[str] = None,
345
+ output_dimensionality: Optional[int] = None,
346
+ ) -> List[float]:
347
+ """Embed a text, using the `non-batch endpoint <https://ai.google.dev/api/embeddings#method:-models.embedcontent>`__.
348
+
349
+ Args:
350
+ text: The text to embed.
351
+ task_type: `task_type <https://ai.google.dev/api/embeddings#tasktype>`__
352
+ title: An optional title for the text.
353
+ Only applicable when TaskType is ``'RETRIEVAL_DOCUMENT'``.
354
+ output_dimensionality: Optional `reduced dimension for the output embedding <https://ai.google.dev/api/embeddings#EmbedContentRequest>`__.
355
+
356
+ Returns:
357
+ Embedding for the text.
358
+ """
359
+ task_type_to_use = task_type if task_type else self.task_type
360
+ if task_type_to_use is None:
361
+ task_type_to_use = "RETRIEVAL_QUERY"
362
+ try:
363
+ request: EmbedContentRequest = self._prepare_request(
364
+ text=text,
365
+ task_type=task_type,
366
+ title=title,
367
+ output_dimensionality=output_dimensionality,
368
+ )
369
+ result: EmbedContentResponse = await self.async_client.embed_content(
370
+ request
371
+ )
372
+ except Exception as e:
373
+ raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
374
+ return list(result.embedding.values)
@@ -63,6 +63,9 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
63
63
  def validate_environment(self) -> Self:
64
64
  """Validates params and passes them to google-generativeai package."""
65
65
 
66
+ if not any(self.model.startswith(prefix) for prefix in ("models/",)):
67
+ self.model = f"models/{self.model}"
68
+
66
69
  self.client = ChatGoogleGenerativeAI(
67
70
  api_key=self.google_api_key,
68
71
  credentials=self.credentials,
@@ -86,6 +89,15 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
86
89
  """Get standard params for tracing."""
87
90
  ls_params = super()._get_ls_params(stop=stop, **kwargs)
88
91
  ls_params["ls_provider"] = "google_genai"
92
+
93
+ models_prefix = "models/"
94
+ ls_model_name = (
95
+ self.model[len(models_prefix) :]
96
+ if self.model and self.model.startswith(models_prefix)
97
+ else self.model
98
+ )
99
+ ls_params["ls_model_name"] = ls_model_name
100
+
89
101
  if ls_max_tokens := kwargs.get("max_output_tokens", self.max_output_tokens):
90
102
  ls_params["ls_max_tokens"] = ls_max_tokens
91
103
  return ls_params
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "langchain-google-genai"
3
- version = "2.1.6"
3
+ version = "2.1.8"
4
4
  description = "An integration package connecting Google's genai package and LangChain"
5
5
  authors = []
6
6
  readme = "README.md"
@@ -12,7 +12,7 @@ license = "MIT"
12
12
 
13
13
  [tool.poetry.dependencies]
14
14
  python = ">=3.9,<4.0"
15
- langchain-core = "^0.3.66"
15
+ langchain-core = "^0.3.68"
16
16
  google-ai-generativelanguage = "^0.6.18"
17
17
  pydantic = ">=2,<3"
18
18
  filetype = "^1.2.0"
@@ -29,7 +29,7 @@ pytest-watcher = "^0.3.4"
29
29
  pytest-asyncio = "^0.21.1"
30
30
  pytest-retry = "^1.7.0"
31
31
  numpy = ">=1.26.2"
32
- langchain-tests = "0.3.19"
32
+ langchain-tests = "0.3.20"
33
33
 
34
34
  [tool.codespell]
35
35
  ignore-words-list = "rouge"
@@ -58,7 +58,7 @@ ruff = "^0.1.5"
58
58
 
59
59
  [tool.poetry.group.typing.dependencies]
60
60
  mypy = "^1.10"
61
- types-requests = "^2.28.11.5"
61
+ types-requests = "^2.31.0"
62
62
  types-google-cloud-ndb = "^2.2.0.1"
63
63
  types-protobuf = "^4.24.0.20240302"
64
64
  numpy = ">=1.26.2"
@@ -68,7 +68,7 @@ numpy = ">=1.26.2"
68
68
  optional = true
69
69
 
70
70
  [tool.poetry.group.dev.dependencies]
71
- types-requests = "^2.31.0.10"
71
+ types-requests = "^2.31.0"
72
72
  types-google-cloud-ndb = "^2.2.0.1"
73
73
 
74
74
  [tool.ruff.lint]