langchain-google-genai 2.0.0.dev1__tar.gz → 2.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langchain-google-genai has been flagged as potentially problematic.

Files changed (16):
  1. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/PKG-INFO +3 -3
  2. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/_function_utils.py +155 -34
  3. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/_genai_extension.py +1 -1
  4. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/chat_models.py +244 -41
  5. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/genai_aqa.py +1 -1
  6. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/llms.py +5 -2
  7. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/pyproject.toml +7 -7
  8. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/LICENSE +0 -0
  9. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/README.md +0 -0
  10. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/__init__.py +0 -0
  11. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/_common.py +0 -0
  12. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/_enums.py +0 -0
  13. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/_image_utils.py +0 -0
  14. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/embeddings.py +0 -0
  15. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/google_vector_store.py +0 -0
  16. {langchain_google_genai-2.0.0.dev1 → langchain_google_genai-2.0.2}/langchain_google_genai/py.typed +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langchain-google-genai
-Version: 2.0.0.dev1
+Version: 2.0.2
 Summary: An integration package connecting Google's genai package and LangChain
 Home-page: https://github.com/langchain-ai/langchain-google
 License: MIT
@@ -12,8 +12,8 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Provides-Extra: images
-Requires-Dist: google-generativeai (>=0.7.0,<0.8.0)
-Requires-Dist: langchain-core (>=0.3.0.dev4,<0.4.0)
+Requires-Dist: google-generativeai (>=0.8.0,<0.9.0)
+Requires-Dist: langchain-core (>=0.3.13,<0.4)
 Requires-Dist: pillow (>=10.1.0,<11.0.0) ; extra == "images"
 Requires-Dist: pydantic (>=2,<3)
 Project-URL: Repository, https://github.com/langchain-ai/langchain-google
langchain_google_genai/_function_utils.py

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import collections
+import importlib
 import json
 import logging
 from typing import (
@@ -42,9 +43,9 @@ TYPE_ENUM = {
     "boolean": glm.Type.BOOLEAN,
     "array": glm.Type.ARRAY,
     "object": glm.Type.OBJECT,
+    "null": None,
 }

-TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}
 _ALLOWED_SCHEMA_FIELDS = []
 _ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields])
 _ALLOWED_SCHEMA_FIELDS.extend(
@@ -101,12 +102,7 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]:
         elif key == "items":
             converted_schema["items"] = _format_json_schema_to_gapic(value)
         elif key == "properties":
-            if "properties" not in converted_schema:
-                converted_schema["properties"] = {}
-            for pkey, pvalue in value.items():
-                converted_schema["properties"][pkey] = _format_json_schema_to_gapic(
-                    pvalue
-                )
+            converted_schema["properties"] = _get_properties_from_schema(value)
             continue
         elif key == "allOf":
             if len(value) > 1:
@@ -136,8 +132,9 @@ def _dict_to_gapic_schema(schema: Dict[str, Any]) -> Optional[gapic.Schema]:
 def _format_dict_to_function_declaration(
     tool: Union[FunctionDescription, Dict[str, Any]],
 ) -> gapic.FunctionDeclaration:
+    print(tool)
     return gapic.FunctionDeclaration(
-        name=tool.get("name"),
+        name=tool.get("name") or tool.get("title"),
         description=tool.get("description"),
         parameters=_dict_to_gapic_schema(tool.get("parameters", {})),
     )
@@ -157,13 +154,11 @@ def convert_to_genai_function_declarations(
     for tool in tools:
         if isinstance(tool, gapic.Tool):
             gapic_tool.function_declarations.extend(tool.function_declarations)
+        elif isinstance(tool, dict) and "function_declarations" not in tool:
+            fd = _format_to_gapic_function_declaration(tool)
+            gapic_tool.function_declarations.append(fd)
         elif isinstance(tool, dict):
-            if "function_declarations" not in tool:
-                fd = _format_to_gapic_function_declaration(tool)
-                gapic_tool.function_declarations.append(fd)
-                continue
-            tool = cast(_ToolDictLike, tool)
-            function_declarations = tool["function_declarations"]
+            function_declarations = cast(_ToolDictLike, tool)["function_declarations"]
             if not isinstance(function_declarations, collections.abc.Sequence):
                 raise ValueError(
                     "function_declarations should be a list"
@@ -199,18 +194,26 @@ def _format_to_gapic_function_declaration(
 ) -> gapic.FunctionDeclaration:
     if isinstance(tool, BaseTool):
         return _format_base_tool_to_function_declaration(tool)
-    elif isinstance(tool, type) and issubclass(tool, BaseModel):
+    elif isinstance(tool, type) and is_basemodel_subclass_safe(tool):
         return _convert_pydantic_to_genai_function(tool)
     elif isinstance(tool, dict):
-        if all(k in tool for k in ("name", "description")) and "parameters" not in tool:
+        if all(k in tool for k in ("type", "function")) and tool["type"] == "function":
+            function = tool["function"]
+        elif (
+            all(k in tool for k in ("name", "description")) and "parameters" not in tool
+        ):
             function = cast(dict, tool)
-            function["parameters"] = {}
         else:
-            if "parameters" in tool and tool["parameters"].get("properties"):  # type: ignore[index]
+            if (
+                "parameters" in tool and tool["parameters"].get("properties")  # type: ignore[index]
+            ):
                 function = convert_to_openai_tool(cast(dict, tool))["function"]
             else:
                 function = cast(dict, tool)
-                function["parameters"] = {}
+            function["parameters"] = function.get("parameters") or {}
+            # Empty 'properties' field not supported.
+            if not function["parameters"].get("properties"):
+                function["parameters"] = {}
         return _format_dict_to_function_declaration(cast(FunctionDescription, function))
     elif callable(tool):
         return _format_base_tool_to_function_declaration(callable_as_lc_tool()(tool))
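
Since `_format_to_gapic_function_declaration` now recognizes OpenAI-style `{"type": "function", ...}` dicts directly, a minimal sketch (tool name and schema are illustrative):

    from langchain_google_genai._function_utils import (
        convert_to_genai_function_declarations,
    )

    openai_style_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }

    # The new branch unwraps tool["function"] before building the declaration.
    gapic_tool = convert_to_genai_function_declarations([openai_style_tool])
    print(gapic_tool.function_declarations[0].name)  # get_weather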
@@ -269,13 +272,12 @@ def _convert_pydantic_to_genai_function(
         name=tool_name if tool_name else schema.get("title"),
         description=tool_description if tool_description else schema.get("description"),
         parameters={
-            "properties": {
-                k: {
-                    "type_": _get_type_from_schema(v),
-                    "description": v.get("description"),
-                }
-                for k, v in schema["properties"].items()
-            },
+            "properties": _get_properties_from_schema_any(
+                schema.get("properties")
+            ),  # TODO: use _dict_to_gapic_schema() if possible
+            # "items": _get_items_from_schema_any(
+            #     schema
+            # ),  # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema
             "required": schema.get("required", []),
             "type_": TYPE_ENUM[schema["type"]],
         },
@@ -283,23 +285,121 @@
     return function_declaration


+def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]:
+    if isinstance(schema, Dict):
+        return _get_properties_from_schema(schema)
+    return {}
+
+
+def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
+    properties = {}
+    for k, v in schema.items():
+        if not isinstance(k, str):
+            logger.warning(f"Key '{k}' is not supported in schema, type={type(k)}")
+            continue
+        if not isinstance(v, Dict):
+            logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}")
+            continue
+        properties_item: Dict[str, Union[str, int, Dict, List]] = {}
+        if v.get("type") or v.get("anyOf") or v.get("type_"):
+            properties_item["type_"] = _get_type_from_schema(v)
+            if _is_nullable_schema(v):
+                properties_item["nullable"] = True
+
+        if v.get("enum"):
+            properties_item["enum"] = v["enum"]
+
+        description = v.get("description")
+        if description and isinstance(description, str):
+            properties_item["description"] = description
+
+        if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"):
+            properties_item["items"] = _get_items_from_schema_any(v.get("items"))
+
+        if properties_item.get("type_") == glm.Type.OBJECT and v.get("properties"):
+            properties_item["properties"] = _get_properties_from_schema_any(
+                v.get("properties")
+            )
+        if k == "title" and "description" not in properties_item:
+            properties_item["description"] = k + " is " + str(v)
+
+        properties[k] = properties_item
+
+    return properties
+
+
+def _get_items_from_schema_any(schema: Any) -> Dict[str, Any]:
+    if isinstance(schema, (dict, list, str)):
+        return _get_items_from_schema(schema)
+    return {}
+
+
+def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
+    items: Dict = {}
+    if isinstance(schema, List):
+        for i, v in enumerate(schema):
+            items[f"item{i}"] = _get_properties_from_schema_any(v)
+    elif isinstance(schema, Dict):
+        items["type_"] = _get_type_from_schema(schema)
+        if items["type_"] == glm.Type.OBJECT and "properties" in schema:
+            items["properties"] = _get_properties_from_schema_any(schema["properties"])
+        if "title" in schema:
+            items["title"] = schema
+        if "title" in schema or "description" in schema:
+            items["description"] = (
+                schema.get("description") or schema.get("title") or ""
+            )
+        if _is_nullable_schema(schema):
+            items["nullable"] = True
+    else:
+        # str
+        items["type_"] = _get_type_from_schema({"type": schema})
+        if _is_nullable_schema({"type": schema}):
+            items["nullable"] = True
+
+    return items
+
+
 def _get_type_from_schema(schema: Dict[str, Any]) -> int:
+    return _get_nullable_type_from_schema(schema) or glm.Type.STRING
+
+
+def _get_nullable_type_from_schema(schema: Dict[str, Any]) -> Optional[int]:
     if "anyOf" in schema:
-        types = [_get_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]]
+        types = [
+            _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]
+        ]
         types = [t for t in types if t is not None]  # Remove None values
         if types:
             return types[-1]  # TODO: update FunctionDeclaration and pass all types?
         else:
             pass
-    elif "type" in schema:
-        stype = str(schema["type"])
-        if stype in TYPE_ENUM:
-            return TYPE_ENUM[stype]
-        else:
-            pass
+    elif "type" in schema or "type_" in schema:
+        type_ = schema["type"] if "type" in schema else schema["type_"]
+        if isinstance(type_, int):
+            return type_
+        stype = str(schema["type"]) if "type" in schema else str(schema["type_"])
+        return TYPE_ENUM.get(stype, glm.Type.STRING)
     else:
         pass
-    return TYPE_ENUM["string"]  # Default to string if no valid types found
+    return glm.Type.STRING  # Default to string if no valid types found
+
+
+def _is_nullable_schema(schema: Dict[str, Any]) -> bool:
+    if "anyOf" in schema:
+        types = [
+            _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]
+        ]
+        return any(t is None for t in types)
+    elif "type" in schema or "type_" in schema:
+        type_ = schema["type"] if "type" in schema else schema["type_"]
+        if isinstance(type_, int):
+            return False
+        stype = str(schema["type"]) if "type" in schema else str(schema["type_"])
+        return TYPE_ENUM.get(stype, glm.Type.STRING) is None
+    else:
+        pass
+    return False


 _ToolChoiceType = Union[
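
The net effect of these helpers is that `Optional[...]` fields survive conversion instead of silently degrading. A minimal sketch against the internal helper (internal API, so subject to change; the schema fragment is what Pydantic emits for `Optional[str]`):

    from langchain_google_genai._function_utils import _get_properties_from_schema

    props = _get_properties_from_schema(
        {
            "nickname": {
                "anyOf": [{"type": "string"}, {"type": "null"}],
                "description": "Optional nickname",
            }
        }
    )
    # "null" maps to None in TYPE_ENUM, so the field resolves to STRING plus
    # nullable=True rather than being dropped, roughly:
    # {'nickname': {'type_': <Type.STRING: 1>, 'nullable': True,
    #               'description': 'Optional nickname'}}
    print(props)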
@@ -356,3 +456,24 @@ def _tool_choice_to_tool_config(
             "allowed_function_names": allowed_function_names,
         }
     )
+
+
+def is_basemodel_subclass_safe(tool: Type) -> bool:
+    if safe_import("langchain_core.utils.pydantic", "is_basemodel_subclass"):
+        from langchain_core.utils.pydantic import (
+            is_basemodel_subclass,  # type: ignore[import]
+        )
+
+        return is_basemodel_subclass(tool)
+    else:
+        return issubclass(tool, BaseModel)
+
+
+def safe_import(module_name: str, attribute_name: str = "") -> bool:
+    try:
+        module = importlib.import_module(module_name)
+        if attribute_name:
+            return hasattr(module, attribute_name)
+        return True
+    except ImportError:
+        return False
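
`safe_import` probes for an importable module (optionally a specific attribute) so that `is_basemodel_subclass_safe` can prefer the langchain-core helper and fall back to a plain `issubclass` check. A small sketch (`not_a_real_module` is a deliberately bogus name):

    from pydantic import BaseModel

    from langchain_google_genai._function_utils import (
        is_basemodel_subclass_safe,
        safe_import,
    )

    class City(BaseModel):
        name: str

    assert safe_import("json")  # stdlib module: importable
    assert not safe_import("not_a_real_module")  # missing module: fails quietly
    assert is_basemodel_subclass_safe(City)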
langchain_google_genai/_genai_extension.py

@@ -238,7 +238,7 @@ def _prepare_config(
     client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
     transport: Optional[str] = None,
 ) -> Dict[str, Any]:
-    formatted_client_options = {"api_endpoint": _config.api_endpoint}
+    formatted_client_options: dict = {"api_endpoint": _config.api_endpoint}
     if client_options:
         formatted_client_options.update(**client_options)
     if not credentials and api_key:
langchain_google_genai/chat_models.py

@@ -31,6 +31,9 @@ import google.api_core
 # TODO: remove ignore once the google package is published with types
 import proto  # type: ignore[import]
 import requests
+from google.ai.generativelanguage_v1beta import (
+    GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
+)
 from google.ai.generativelanguage_v1beta.types import (
     Blob,
     Candidate,
@@ -46,7 +49,9 @@ from google.ai.generativelanguage_v1beta.types import (
     ToolConfig,
     VideoMetadata,
 )
+from google.generativeai.caching import CachedContent  # type: ignore[import]
 from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
+from google.generativeai.types import caching_types, content_types
 from google.generativeai.types.content_types import (  # type: ignore[import]
     FunctionDeclarationType,
     ToolDict,
@@ -77,7 +82,7 @@ from langchain_core.output_parsers.openai_tools import (
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnablePassthrough
 from langchain_core.utils import secret_from_env
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.function_calling import convert_to_openai_tool
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -104,6 +109,7 @@ from langchain_google_genai._function_utils import (
     _ToolChoiceType,
     _ToolConfigDict,
     convert_to_genai_function_declarations,
+    is_basemodel_subclass_safe,
     tool_to_dict,
 )
 from langchain_google_genai._image_utils import ImageBytesLoader
@@ -563,20 +569,29 @@ def _parse_response_candidate(
 def _response_to_result(
     response: GenerateContentResponse,
     stream: bool = False,
+    prev_usage: Optional[UsageMetadata] = None,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

+    # previous usage metadata needs to be subtracted because gemini api returns
+    # already-accumulated token counts with each chunk
+    prev_input_tokens = prev_usage["input_tokens"] if prev_usage else 0
+    prev_output_tokens = prev_usage["output_tokens"] if prev_usage else 0
+    prev_total_tokens = prev_usage["total_tokens"] if prev_usage else 0
+
     # Get usage metadata
     try:
         input_tokens = response.usage_metadata.prompt_token_count
         output_tokens = response.usage_metadata.candidates_token_count
         total_tokens = response.usage_metadata.total_token_count
-        if input_tokens + output_tokens + total_tokens > 0:
+        cache_read_tokens = response.usage_metadata.cached_content_token_count
+        if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
             lc_usage = UsageMetadata(
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                total_tokens=total_tokens,
+                input_tokens=input_tokens - prev_input_tokens,
+                output_tokens=output_tokens - prev_output_tokens,
+                total_tokens=total_tokens - prev_total_tokens,
+                input_token_details={"cache_read": cache_read_tokens},
             )
         else:
             lc_usage = None
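
The subtraction matters because each streamed chunk reports cumulative counts; passing the running total as `prev_usage` turns them into per-chunk deltas. A worked example with illustrative numbers:

    # chunk 1 reports cumulative usage: input=10, output=3, total=13
    # chunk 2 reports cumulative usage: input=10, output=8, total=18
    prev = {"input_tokens": 10, "output_tokens": 3, "total_tokens": 13}
    chunk2 = {"input_tokens": 10, "output_tokens": 8, "total_tokens": 18}

    # _response_to_result subtracts the previous totals, so chunk 2 carries only
    # the 5 newly generated output tokens:
    delta = {k: chunk2[k] - prev[k] for k in chunk2}
    assert delta == {"input_tokens": 0, "output_tokens": 5, "total_tokens": 5}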
@@ -595,12 +610,17 @@ def _response_to_result(
         ]
         message = _parse_response_candidate(candidate, streaming=stream)
         message.usage_metadata = lc_usage
-        generations.append(
-            (ChatGenerationChunk if stream else ChatGeneration)(
-                message=message,
-                generation_info=generation_info,
+        if stream:
+            generations.append(
+                ChatGenerationChunk(
+                    message=cast(AIMessageChunk, message),
+                    generation_info=generation_info,
+                )
+            )
+        else:
+            generations.append(
+                ChatGeneration(message=message, generation_info=generation_info)
             )
-        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
@@ -609,12 +629,14 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [
-            (ChatGenerationChunk if stream else ChatGeneration)(
-                message=(AIMessageChunk if stream else AIMessage)(content=""),
-                generation_info={},
-            )
-        ]
+        if stream:
+            generations = [
+                ChatGenerationChunk(
+                    message=AIMessageChunk(content=""), generation_info={}
+                )
+            ]
+        else:
+            generations = [ChatGeneration(message=AIMessage(""), generation_info={})]
     return ChatResult(generations=generations, llm_output=llm_output)
@@ -703,7 +725,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     Tool calling:
         .. code-block:: python

-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field


             class GetWeather(BaseModel):
@@ -748,7 +770,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):

             from typing import Optional

-            from langchain_core.pydantic_v1 import BaseModel, Field
+            from pydantic import BaseModel, Field


             class Joke(BaseModel):
@@ -806,7 +828,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):

             {'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}

-    Response metadata
+    Response metadata
         .. code-block:: python

             ai_msg = llm.invoke(messages)
@@ -823,11 +845,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """  # noqa: E501

     client: Any = Field(default=None, exclude=True)  #: :meta private:
-    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    async_client_running: Any = Field(default=None, exclude=True)  #: :meta private:
     google_api_key: Optional[SecretStr] = Field(
         alias="api_key", default_factory=secret_from_env("GOOGLE_API_KEY", default=None)
     )
-    """Google AI API key.
+    """Google AI API key.
     If not specified will be read from env var ``GOOGLE_API_KEY``."""
     default_metadata: Sequence[Tuple[str, str]] = Field(
         default_factory=list
@@ -839,6 +861,14 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     Gemini does not support system messages; any unsupported messages will
     raise an error."""

+    cached_content: Optional[str] = None
+    """The name of the cached content used as context to serve the prediction.
+
+    Note: only used in explicit caching, where users can have control over caching
+    (e.g. what content to cache) and enjoy guaranteed cost savings. Format:
+    ``cachedContents/{cachedContent}``.
+    """
+
     model_config = ConfigDict(
         populate_by_name=True,
     )
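
A hedged usage sketch for the new field (the resource name is a placeholder; real names come back from the caching API, and `GOOGLE_API_KEY` must be set):

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro",
        cached_content="cachedContents/example-id",  # placeholder resource name
    )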
@@ -887,24 +917,31 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             client_options=self.client_options,
             transport=transport,
         )
+        self.async_client_running = None
+        return self

+    @property
+    def async_client(self) -> v1betaGenerativeServiceAsyncClient:
+        google_api_key = None
+        if not self.credentials:
+            if isinstance(self.google_api_key, SecretStr):
+                google_api_key = self.google_api_key.get_secret_value()
+            else:
+                google_api_key = self.google_api_key
         # NOTE: genaix.build_generative_async_service requires
         # a running event loop, which causes an error
         # when initialized inside a ThreadPoolExecutor.
         # this check ensures that async client is only initialized
         # within an asyncio event loop to avoid the error
-        if _is_event_loop_running():
-            self.async_client = genaix.build_generative_async_service(
+        if not self.async_client_running and _is_event_loop_running():
+            self.async_client_running = genaix.build_generative_async_service(
                 credentials=self.credentials,
                 api_key=google_api_key,
-                client_info=client_info,
+                client_info=get_client_info("ChatGoogleGenerativeAI"),
                 client_options=self.client_options,
-                transport=transport,
+                transport=self.transport,
             )
-        else:
-            self.async_client = None
-
-        return self
+        return self.async_client_running

     @property
     def _identifying_params(self) -> Dict[str, Any]:
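
With the lazy property, constructing the model in a sync context no longer pins the async client to `None`; the async service is built on first access inside a running loop. A sketch (assumes `GOOGLE_API_KEY` is set):

    import asyncio

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash")  # no event loop needed here

    async def main() -> None:
        # First access of llm.async_client happens here, inside the loop.
        print(await llm.ainvoke("Say hi in five words."))

    asyncio.run(main())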
@@ -966,6 +1003,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
+        cached_content: Optional[str] = None,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
         **kwargs: Any,
     ) -> ChatResult:
         request = self._prepare_request(
@@ -976,6 +1015,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             safety_settings=safety_settings,
             tool_config=tool_config,
             generation_config=generation_config,
+            cached_content=cached_content or self.cached_content,
+            tool_choice=tool_choice,
         )
         response: GenerateContentResponse = _chat_with_retry(
             request=request,
@@ -996,6 +1037,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
+        cached_content: Optional[str] = None,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
         **kwargs: Any,
     ) -> ChatResult:
         if not self.async_client:
@@ -1021,6 +1064,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             safety_settings=safety_settings,
             tool_config=tool_config,
             generation_config=generation_config,
+            cached_content=cached_content or self.cached_content,
+            tool_choice=tool_choice,
         )
         response: GenerateContentResponse = await _achat_with_retry(
             request=request,
@@ -1041,6 +1086,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
+        cached_content: Optional[str] = None,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         request = self._prepare_request(
@@ -1051,6 +1098,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             safety_settings=safety_settings,
             tool_config=tool_config,
             generation_config=generation_config,
+            cached_content=cached_content or self.cached_content,
+            tool_choice=tool_choice,
         )
         response: GenerateContentResponse = _chat_with_retry(
             request=request,
@@ -1058,9 +1107,31 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             **kwargs,
             metadata=self.default_metadata,
         )
+
+        prev_usage_metadata: UsageMetadata | None = None
         for chunk in response:
-            _chat_result = _response_to_result(chunk, stream=True)
+            _chat_result = _response_to_result(
+                chunk, stream=True, prev_usage=prev_usage_metadata
+            )
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
+            message = cast(AIMessageChunk, gen.message)
+
+            curr_usage_metadata: UsageMetadata | dict[str, int] = (
+                message.usage_metadata or {}
+            )
+
+            prev_usage_metadata = (
+                message.usage_metadata
+                if prev_usage_metadata is None
+                else UsageMetadata(
+                    input_tokens=prev_usage_metadata.get("input_tokens", 0)
+                    + curr_usage_metadata.get("input_tokens", 0),
+                    output_tokens=prev_usage_metadata.get("output_tokens", 0)
+                    + curr_usage_metadata.get("output_tokens", 0),
+                    total_tokens=prev_usage_metadata.get("total_tokens", 0)
+                    + curr_usage_metadata.get("total_tokens", 0),
+                )
+            )

             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
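
Because chunks now carry deltas, naive client-side accumulation gives correct totals. A sketch (assumes `GOOGLE_API_KEY` is set):

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash")

    aggregate = None
    for chunk in llm.stream("Tell me a one-line joke."):
        # AIMessageChunk addition sums usage_metadata field by field.
        aggregate = chunk if aggregate is None else aggregate + chunk

    print(aggregate.usage_metadata)  # per-chunk deltas sum to the true totals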
@@ -1077,6 +1148,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
+        cached_content: Optional[str] = None,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         if not self.async_client:
@@ -1103,15 +1176,38 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             safety_settings=safety_settings,
             tool_config=tool_config,
             generation_config=generation_config,
+            cached_content=cached_content or self.cached_content,
+            tool_choice=tool_choice,
         )
+        prev_usage_metadata: UsageMetadata | None = None
         async for chunk in await _achat_with_retry(
             request=request,
             generation_method=self.async_client.stream_generate_content,
             **kwargs,
             metadata=self.default_metadata,
         ):
-            _chat_result = _response_to_result(chunk, stream=True)
+            _chat_result = _response_to_result(
+                chunk, stream=True, prev_usage=prev_usage_metadata
+            )
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
+            message = cast(AIMessageChunk, gen.message)
+
+            curr_usage_metadata: UsageMetadata | dict[str, int] = (
+                message.usage_metadata or {}
+            )
+
+            prev_usage_metadata = (
+                message.usage_metadata
+                if prev_usage_metadata is None
+                else UsageMetadata(
+                    input_tokens=prev_usage_metadata.get("input_tokens", 0)
+                    + curr_usage_metadata.get("input_tokens", 0),
+                    output_tokens=prev_usage_metadata.get("output_tokens", 0)
+                    + curr_usage_metadata.get("output_tokens", 0),
+                    total_tokens=prev_usage_metadata.get("total_tokens", 0)
+                    + curr_usage_metadata.get("total_tokens", 0),
+                )
+            )

             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
@@ -1126,8 +1222,15 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         functions: Optional[Sequence[FunctionDeclarationType]] = None,
         safety_settings: Optional[SafetySettingDict] = None,
         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
+        cached_content: Optional[str] = None,
     ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
+        if tool_choice and tool_config:
+            raise ValueError(
+                "Must specify at most one of tool_choice and tool_config, received "
+                f"both:\n\n{tool_choice=}\n\n{tool_config=}"
+            )
         formatted_tools = None
         if tools:
             formatted_tools = [convert_to_genai_function_declarations(tools)]
@@ -1138,6 +1241,18 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
+        if tool_choice:
+            if not formatted_tools:
+                msg = (
+                    f"Received {tool_choice=} but no {tools=}. 'tool_choice' can only "
+                    f"be specified if 'tools' is specified."
+                )
+                raise ValueError(msg)
+            all_names = [
+                f.name for t in formatted_tools for f in t.function_declarations
+            ]
+            tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
+
         formatted_tool_config = None
         if tool_config:
             formatted_tool_config = ToolConfig(
@@ -1158,6 +1273,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_config=self._prepare_params(
                 stop, generation_config=generation_config
             ),
+            cached_content=cached_content,
         )
         if system_instruction:
             request.system_instruction = system_instruction
@@ -1189,7 +1305,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
-        if isinstance(schema, type) and is_basemodel_subclass(schema):
+        if isinstance(schema, type) and is_basemodel_subclass_safe(schema):
             parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema], first_tool_only=True
             )
@@ -1233,20 +1349,107 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
                 "Must specify at most one of tool_choice and tool_config, received "
                 f"both:\n\n{tool_choice=}\n\n{tool_config=}"
             )
-        # Bind dicts for easier serialization/deserialization.
-        genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
-        if tool_choice:
-            all_names = [
-                f["name"]  # type: ignore[index]
-                for t in genai_tools
-                for f in t["function_declarations"]
+        try:
+            formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]  # type: ignore[arg-type]
+        except Exception:
+            formatted_tools = [
+                tool_to_dict(convert_to_genai_function_declarations(tools))
             ]
+        if tool_choice:
+            kwargs["tool_choice"] = tool_choice
+        elif tool_config:
+            kwargs["tool_config"] = tool_config
+        else:
+            pass
+        return self.bind(tools=formatted_tools, **kwargs)
+
+    def create_cached_content(
+        self,
+        contents: Union[List[BaseMessage], content_types.ContentsType],
+        *,
+        display_name: str | None = None,
+        tools: Union[ToolDict, GoogleTool, None] = None,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
+        ttl: Optional[caching_types.TTLTypes] = None,
+        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
+    ) -> str:
+        """
+
+        Args:
+            display_name: The user-generated meaningful display name
+                of the cached content. `display_name` must be no
+                more than 128 unicode characters.
+            contents: Contents to cache.
+            tools: A list of `Tools` the model may use to generate response.
+            tool_choice: Which tool to require the model to call.
+            ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
+                `ttl` and `expire_time` are exclusive arguments.
+            expire_time: Expiration time for cached resource.
+                `ttl` and `expire_time` are exclusive arguments.
+        """
+        system: Optional[content_types.ContentType] = None
+        genai_contents: list = []
+        if all(isinstance(c, BaseMessage) for c in contents):
+            system, genai_contents = _parse_chat_history(
+                contents,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
+        elif any(isinstance(c, BaseMessage) for c in contents):
+            raise ValueError(
+                f"'contents' must either be a list of "
+                f"langchain_core.messages.BaseMessage or a list "
+                f"google.generativeai.types.content_types.ContentType, but not a mix "
+                f"of the two. Received {contents}"
+            )
+        else:
+            for content in contents:
+                if hasattr(content, "role") and content.role == "system":
+                    if system is not None:
+                        warnings.warn(
+                            "Received multiple pieces of content with role 'system'. "
+                            "Should only be one set of system instructions. Ignoring "
+                            "all but the first 'system' content."
+                        )
+                    else:
+                        system = content
+                elif isinstance(content, dict) and content.get("role") == "system":
+                    if system is not None:
+                        warnings.warn(
+                            "Received multiple pieces of content with role 'system'. "
+                            "Should only be one set of system instructions. Ignoring "
+                            "all but the first 'system' content."
+                        )
+                    else:
+                        system = content
+                else:
+                    genai_contents.append(content)
+        if tools:
+            genai_tools = [convert_to_genai_function_declarations(tools)]
+        else:
+            genai_tools = None
+        if tool_choice and genai_tools:
+            all_names = [f.name for t in genai_tools for f in t.function_declarations]
             tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
-        return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
+            genai_tool_config = ToolConfig(
+                function_calling_config=tool_config["function_calling_config"]
+            )
+        else:
+            genai_tool_config = None
+        cached_content = CachedContent.create(
+            model=self.model,
+            system_instruction=system,
+            contents=genai_contents,
+            display_name=display_name,
+            tools=genai_tools,
+            tool_config=genai_tool_config,
+            ttl=ttl,
+            expire_time=expire_time,
+        )
+        return cached_content.name

     @property
     def _supports_tool_choice(self) -> bool:
-        return "gemini-1.5-pro" in self.model
+        return "gemini-1.5-pro" in self.model or "gemini-1.5-flash" in self.model


 def _get_tool_name(
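
An end-to-end sketch of the new explicit-caching flow (message contents and model version are illustrative; Gemini caching has model-specific minimum token counts, and `GOOGLE_API_KEY` must be set):

    import datetime

    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-001")

    # Cache a long prompt once; returns a "cachedContents/{id}" resource name.
    cache_name = llm.create_cached_content(
        [
            SystemMessage("You are a careful technical reviewer."),
            HumanMessage("<large document pasted here>"),
        ],
        ttl=datetime.timedelta(hours=1),  # mutually exclusive with expire_time
    )

    # Later calls reuse the cache; cache hits appear under
    # usage_metadata["input_token_details"]["cache_read"].
    ai_msg = llm.invoke("Summarize the document.", cached_content=cache_name)
    print(ai_msg.usage_metadata)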
langchain_google_genai/genai_aqa.py

@@ -117,7 +117,7 @@ class GenAIAqa(RunnableSerializable[AqaInput, AqaOutput]):
         self._client = _AqaModel(**kwargs)

     def invoke(
-        self, input: AqaInput, config: Optional[RunnableConfig] = None
+        self, input: AqaInput, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> AqaOutput:
         """Generates a grounded response using the provided passages."""

langchain_google_genai/llms.py

@@ -13,7 +13,7 @@ from langchain_core.language_models import LangSmithParams, LanguageModelInput
 from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.utils import secret_from_env
-from pydantic import BaseModel, Field, SecretStr, model_validator
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from typing_extensions import Self

 from langchain_google_genai._enums import (
@@ -139,7 +139,7 @@ Supported examples:
     top_k: Optional[int] = None
     """Decode using top-k sampling: consider the set of top_k most probable tokens.
     Must be positive."""
-    max_output_tokens: Optional[int] = None
+    max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
     """Maximum number of tokens to include in a candidate. Must be greater than zero.
     If unset, will default to 64."""
     n: int = 1
@@ -216,6 +216,9 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
     """

     client: Any = None  #: :meta private:
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )

     @model_validator(mode="after")
     def validate_environment(self) -> Self:
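
With `populate_by_name=True`, the `max_tokens` alias added above means the field is settable either way. A sketch (assumes `GOOGLE_API_KEY` is set):

    from langchain_google_genai import GoogleGenerativeAI

    llm_by_alias = GoogleGenerativeAI(model="gemini-1.5-flash", max_tokens=128)
    llm_by_name = GoogleGenerativeAI(model="gemini-1.5-flash", max_output_tokens=128)
    assert llm_by_alias.max_output_tokens == llm_by_name.max_output_tokens == 128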
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-google-genai"
-version = "2.0.0.dev1"
+version = "2.0.2"
 description = "An integration package connecting Google's genai package and LangChain"
 authors = []
 readme = "README.md"
@@ -12,8 +12,8 @@ license = "MIT"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = { version = "^0.3.0.dev4", allow-prereleases = true }
-google-generativeai = "^0.7.0"
+langchain-core = ">=0.3.13,<0.4"
+google-generativeai = "^0.8.0"
 pillow = { version = "^10.1.0", optional = true }
 pydantic = ">=2,<3"
@@ -31,8 +31,8 @@ syrupy = "^4.0.2"
 pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
 numpy = "^1.26.2"
-langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core", branch = "v0.3rc" }
-langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests", branch = "v0.3rc" }
+langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
+langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests" }

 [tool.codespell]
 ignore-words-list = "rouge"
@@ -62,7 +62,7 @@ types-requests = "^2.28.11.5"
 types-google-cloud-ndb = "^2.2.0.1"
 types-pillow = "^10.1.0.2"
 types-protobuf = "^4.24.0.20240302"
-langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core", branch = "v0.3rc" }
+langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
 numpy = "^1.26.2"

 [tool.poetry.group.dev]
@@ -73,7 +73,7 @@ pillow = "^10.1.0"
 types-requests = "^2.31.0.10"
 types-pillow = "^10.1.0.2"
 types-google-cloud-ndb = "^2.2.0.1"
-langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core", branch = "v0.3rc" }
+langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }

 [tool.ruff.lint]
 select = [