langchain-google-genai 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.

Potentially problematic release: this version of langchain-google-genai might be problematic.

@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import asyncio
  import base64
  import json
  import logging
@@ -23,14 +24,23 @@ from typing import (
  )
  from urllib.parse import urlparse

- import google.ai.generativelanguage as glm
  import google.api_core

  # TODO: remove ignore once the google package is published with types
- import google.generativeai as genai  # type: ignore[import]
  import proto  # type: ignore[import]
  import requests
- from google.generativeai.types import SafetySettingDict  # type: ignore[import]
+ from google.ai.generativelanguage_v1beta.types import (
+     Candidate,
+     Content,
+     FunctionCall,
+     FunctionResponse,
+     GenerateContentRequest,
+     GenerateContentResponse,
+     GenerationConfig,
+     Part,
+     SafetySetting,
+     ToolConfig,
+ )
  from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
  from google.generativeai.types.content_types import (  # type: ignore[import]
      FunctionDeclarationType,
@@ -56,7 +66,7 @@ from langchain_core.messages import (
  )
  from langchain_core.output_parsers.openai_tools import parse_tool_calls
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from langchain_core.pydantic_v1 import SecretStr, root_validator
+ from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
  from langchain_core.runnables import Runnable
  from langchain_core.utils import get_from_dict_or_env
  from tenacity import (
@@ -67,7 +77,11 @@ from tenacity import (
      wait_exponential,
  )

- from langchain_google_genai._common import GoogleGenerativeAIError
+ from langchain_google_genai._common import (
+     GoogleGenerativeAIError,
+     SafetySettingDict,
+     get_client_info,
+ )
  from langchain_google_genai._function_utils import (
      _tool_choice_to_tool_config,
      _ToolChoiceType,
@@ -75,7 +89,10 @@ from langchain_google_genai._function_utils import (
      convert_to_genai_function_declarations,
      tool_to_dict,
  )
- from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
+ from langchain_google_genai._image_utils import ImageBytesLoader
+ from langchain_google_genai.llms import _BaseGoogleGenerativeAI
+
+ from . import _genai_extension as genaix

  IMAGE_TYPES: Tuple = ()
  try:
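The headline change in this release is already visible in the imports: the wrapper stops building dict-shaped payloads through google.generativeai and instead constructs protos from google.ai.generativelanguage_v1beta.types, with clients wired up through the package's private _genai_extension module. A minimal sketch of the new message shapes, assuming google-ai-generativelanguage is installed (it is a dependency of this package):

    from google.ai.generativelanguage_v1beta.types import Content, Part

    # Proto messages replace the loose PartDict/ContentDict dicts used in 1.0.3.
    part = Part(text="Hello, Gemini")
    content = Content(role="user", parts=[part])
    print(content)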
@@ -279,18 +296,19 @@ def _url_to_pil(image_source: str) -> Image:

  def _convert_to_parts(
      raw_content: Union[str, Sequence[Union[str, dict]]],
- ) -> List[genai.types.PartType]:
+ ) -> List[Part]:
      """Converts a list of LangChain messages into a google parts."""
      parts = []
      content = [raw_content] if isinstance(raw_content, str) else raw_content
+     image_loader = ImageBytesLoader()
      for part in content:
          if isinstance(part, str):
-             parts.append(genai.types.PartDict(text=part))
+             parts.append(Part(text=part))
          elif isinstance(part, Mapping):
              # OpenAI Format
              if _is_openai_parts_format(part):
                  if part["type"] == "text":
-                     parts.append({"text": part["text"]})
+                     parts.append(Part(text=part["text"]))
                  elif part["type"] == "image_url":
                      img_url = part["image_url"]
                      if isinstance(img_url, dict):
@@ -299,7 +317,7 @@ def _convert_to_parts(
                                  f"Unrecognized message image format: {img_url}"
                              )
                          img_url = img_url["url"]
-                     parts.append({"inline_data": _url_to_pil(img_url)})
+                     parts.append(image_loader.load_part(img_url))
                  else:
                      raise ValueError(f"Unrecognized message part type: {part['type']}")
              else:
@@ -307,7 +325,7 @@ def _convert_to_parts(
                  logger.warning(
                      "Unrecognized message part format. Assuming it's a text part."
                  )
-                 parts.append(part)
+                 parts.append(Part(text=str(part)))
          else:
              # TODO: Maybe some of Google's native stuff
              # would hit this branch.
@@ -319,33 +337,36 @@

  def _parse_chat_history(
      input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
- ) -> Tuple[Optional[genai.types.ContentDict], List[genai.types.ContentDict]]:
-     messages: List[genai.types.MessageDict] = []
+ ) -> Tuple[Optional[Content], List[Content]]:
+     messages: List[Content] = []

      if convert_system_message_to_human:
          warnings.warn("Convert_system_message_to_human will be deprecated!")

-     system_instruction: Optional[genai.types.ContentDict] = None
+     system_instruction: Optional[Content] = None
      for i, message in enumerate(input_messages):
          if i == 0 and isinstance(message, SystemMessage):
-             system_instruction = _convert_to_parts(message.content)
+             system_instruction = Content(parts=_convert_to_parts(message.content))
              continue
          elif isinstance(message, AIMessage):
              role = "model"
              raw_function_call = message.additional_kwargs.get("function_call")
              if raw_function_call:
-                 function_call = glm.FunctionCall(
+                 function_call = FunctionCall(
                      {
                          "name": raw_function_call["name"],
                          "args": json.loads(raw_function_call["arguments"]),
                      }
                  )
-                 parts = [glm.Part(function_call=function_call)]
+                 parts = [Part(function_call=function_call)]
              else:
                  parts = _convert_to_parts(message.content)
          elif isinstance(message, HumanMessage):
              role = "user"
              parts = _convert_to_parts(message.content)
+             if i == 1 and convert_system_message_to_human and system_instruction:
+                 parts = [p for p in system_instruction.parts] + parts
+                 system_instruction = None
          elif isinstance(message, FunctionMessage):
              role = "user"
              response: Any
@@ -357,8 +378,8 @@ def _parse_chat_history(
          except json.JSONDecodeError:
              response = message.content  # leave as str representation
          parts = [
-             glm.Part(
-                 function_response=glm.FunctionResponse(
+             Part(
+                 function_response=FunctionResponse(
                      name=message.name,
                      response=(
                          {"output": response}
@@ -391,8 +412,8 @@ def _parse_chat_history(
          except json.JSONDecodeError:
              tool_response = message.content  # leave as str representation
          parts = [
-             glm.Part(
-                 function_response=glm.FunctionResponse(
+             Part(
+                 function_response=FunctionResponse(
                      name=name,
                      response=(
                          {"output": tool_response}
@@ -407,12 +428,12 @@ def _parse_chat_history(
                  f"Unexpected message with type {type(message)} at the position {i}."
              )

-         messages.append({"role": role, "parts": parts})
+         messages.append(Content(role=role, parts=parts))
      return system_instruction, messages


  def _parse_response_candidate(
-     response_candidate: glm.Candidate, streaming: bool = False
+     response_candidate: Candidate, streaming: bool = False
  ) -> AIMessage:
      content: Union[None, str, List[str]] = None
      additional_kwargs = {}
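As the hunk above shows, _parse_chat_history now returns (system_instruction, history) as proto Content objects instead of role/parts dicts. For a short exchange the history is roughly the following (a sketch, not the exact runtime output):

    from google.ai.generativelanguage_v1beta.types import Content, Part

    history = [
        Content(role="user", parts=[Part(text="Hi")]),
        Content(role="model", parts=[Part(text="Hello! How can I help?")]),
    ]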
@@ -499,7 +520,7 @@ def _parse_response_candidate(


  def _response_to_result(
-     response: glm.GenerateContentResponse,
+     response: GenerateContentResponse,
      stream: bool = False,
  ) -> ChatResult:
      """Converts a PaLM API response into a LangChain ChatResult."""
@@ -538,6 +559,14 @@ def _response_to_result(
      return ChatResult(generations=generations, llm_output=llm_output)


+ def _is_event_loop_running() -> bool:
+     try:
+         asyncio.get_running_loop()
+         return True
+     except RuntimeError:
+         return False
+
+
  class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
      """`Google Generative AI` Chat models API.
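The new _is_event_loop_running helper uses only the standard library: asyncio.get_running_loop() raises RuntimeError unless called from inside a loop. A standalone demonstration of the same check:

    import asyncio

    def is_event_loop_running() -> bool:
        try:
            asyncio.get_running_loop()  # raises RuntimeError outside a loop
            return True
        except RuntimeError:
            return False

    print(is_event_loop_running())  # False: no loop at module level

    async def main() -> None:
        print(is_event_loop_running())  # True: inside asyncio.run()

    asyncio.run(main())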
 
@@ -557,6 +586,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
      """

      client: Any  #: :meta private:
+     async_client: Any  #: :meta private:
+     default_metadata: Sequence[Tuple[str, str]] = Field(
+         default_factory=list
+     )  #: :meta private:

      convert_system_message_to_human: bool = False
      """Whether to merge any leading SystemMessage into the following HumanMessage.
@@ -582,29 +615,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
          """Validates params and passes them to google-generativeai package."""
-         additional_headers = values.get("additional_headers") or {}
-         default_metadata = tuple(additional_headers.items())
-
-         if values.get("credentials"):
-             genai.configure(
-                 credentials=values.get("credentials"),
-                 transport=values.get("transport"),
-                 client_options=values.get("client_options"),
-                 default_metadata=default_metadata,
-             )
-         else:
-             google_api_key = get_from_dict_or_env(
-                 values, "google_api_key", "GOOGLE_API_KEY"
-             )
-             if isinstance(google_api_key, SecretStr):
-                 google_api_key = google_api_key.get_secret_value()
-
-             genai.configure(
-                 api_key=google_api_key,
-                 transport=values.get("transport"),
-                 client_options=values.get("client_options"),
-                 default_metadata=default_metadata,
-             )
          if (
              values.get("temperature") is not None
              and not 0 <= values["temperature"] <= 1
@@ -616,8 +626,45 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):

          if values.get("top_k") is not None and values["top_k"] <= 0:
              raise ValueError("top_k must be positive")
-         model = values["model"]
-         values["client"] = genai.GenerativeModel(model_name=model)
+
+         if not values["model"].startswith("models/"):
+             values["model"] = f"models/{values['model']}"
+
+         additional_headers = values.get("additional_headers") or {}
+         values["default_metadata"] = tuple(additional_headers.items())
+         client_info = get_client_info("ChatGoogleGenerativeAI")
+         google_api_key = None
+         if not values.get("credentials"):
+             google_api_key = get_from_dict_or_env(
+                 values, "google_api_key", "GOOGLE_API_KEY"
+             )
+             if isinstance(google_api_key, SecretStr):
+                 google_api_key = google_api_key.get_secret_value()
+         transport: Optional[str] = values.get("transport")
+         values["client"] = genaix.build_generative_service(
+             credentials=values.get("credentials"),
+             api_key=google_api_key,
+             client_info=client_info,
+             client_options=values.get("client_options"),
+             transport=transport,
+         )
+
+         # NOTE: genaix.build_generative_async_service requires
+         # a running event loop, which causes an error
+         # when initialized inside a ThreadPoolExecutor.
+         # this check ensures that async client is only initialized
+         # within an asyncio event loop to avoid the error
+         if _is_event_loop_running():
+             values["async_client"] = genaix.build_generative_async_service(
+                 credentials=values.get("credentials"),
+                 api_key=google_api_key,
+                 client_info=client_info,
+                 client_options=values.get("client_options"),
+                 transport=transport,
+             )
+         else:
+             values["async_client"] = None
+
          return values
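Two behavioral shifts hide in this hunk: the global genai.configure(...) call is gone in favor of per-instance clients built by _genai_extension, and the model name is normalized to a models/ resource path before use. The normalization rule itself is simple:

    # the same prefix rule validate_environment now applies
    model = "gemini-pro"
    if not model.startswith("models/"):
        model = f"models/{model}"
    assert model == "models/gemini-pro"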
 
      @property
@@ -632,8 +679,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          }

      def _prepare_params(
-         self, stop: Optional[List[str]], **kwargs: Any
-     ) -> Dict[str, Any]:
+         self,
+         stop: Optional[List[str]],
+         generation_config: Optional[Dict[str, Any]] = None,
+     ) -> GenerationConfig:
          gen_config = {
              k: v
              for k, v in {
@@ -646,27 +695,37 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
              }.items()
              if v is not None
          }
-         if "generation_config" in kwargs:
-             gen_config = {**gen_config, **kwargs.pop("generation_config")}
-         params = {"generation_config": gen_config, **kwargs}
-         return params
+         if generation_config:
+             gen_config = {**gen_config, **generation_config}
+         return GenerationConfig(**gen_config)

      def _generate(
          self,
          messages: List[BaseMessage],
          stop: Optional[List[str]] = None,
          run_manager: Optional[CallbackManagerForLLMRun] = None,
+         *,
+         tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+         functions: Optional[Sequence[FunctionDeclarationType]] = None,
+         safety_settings: Optional[SafetySettingDict] = None,
+         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+         generation_config: Optional[Dict[str, Any]] = None,
          **kwargs: Any,
      ) -> ChatResult:
-         params, chat, message = self._prepare_chat(
+         request = self._prepare_request(
              messages,
              stop=stop,
-             **kwargs,
+             tools=tools,
+             functions=functions,
+             safety_settings=safety_settings,
+             tool_config=tool_config,
+             generation_config=generation_config,
          )
-         response: genai.types.GenerateContentResponse = _chat_with_retry(
-             content=message,
-             **params,
-             generation_method=chat.send_message,
+         response: GenerateContentResponse = _chat_with_retry(
+             request=request,
+             **kwargs,
+             generation_method=self.client.generate_content,
+             metadata=self.default_metadata,
          )
          return _response_to_result(response)
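_generate now declares tools, functions, safety_settings, tool_config, and generation_config as keyword-only parameters instead of swallowing them via **kwargs, so per-call options flow through invoke by name. A hedged usage sketch (the HarmCategory/HarmBlockThreshold re-exports come from langchain_google_genai; the model name and thresholds are illustrative):

    from langchain_google_genai import (
        ChatGoogleGenerativeAI,
        HarmBlockThreshold,
        HarmCategory,
    )

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    result = llm.invoke(
        "Tell me a joke about bears.",
        generation_config={"temperature": 0.2},
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        },
    )
    print(result.content)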
 
@@ -675,17 +734,34 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          messages: List[BaseMessage],
          stop: Optional[List[str]] = None,
          run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         *,
+         tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+         functions: Optional[Sequence[FunctionDeclarationType]] = None,
+         safety_settings: Optional[SafetySettingDict] = None,
+         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+         generation_config: Optional[Dict[str, Any]] = None,
          **kwargs: Any,
      ) -> ChatResult:
-         params, chat, message = self._prepare_chat(
+         if not self.async_client:
+             raise RuntimeError(
+                 "Initialize ChatGoogleGenerativeAI with a running event loop "
+                 "to use async methods."
+             )
+
+         request = self._prepare_request(
              messages,
              stop=stop,
-             **kwargs,
+             tools=tools,
+             functions=functions,
+             safety_settings=safety_settings,
+             tool_config=tool_config,
+             generation_config=generation_config,
          )
-         response: genai.types.GenerateContentResponse = await _achat_with_retry(
-             content=message,
-             **params,
-             generation_method=chat.send_message_async,
+         response: GenerateContentResponse = await _achat_with_retry(
+             request=request,
+             **kwargs,
+             generation_method=self.async_client.generate_content,
+             metadata=self.default_metadata,
          )
          return _response_to_result(response)
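Because validate_environment only builds async_client when a loop is already running, async calls require the model to be constructed inside that loop; otherwise _agenerate raises the RuntimeError above. A minimal sketch, assuming GOOGLE_API_KEY is set and using an illustrative model name:

    import asyncio

    from langchain_google_genai import ChatGoogleGenerativeAI

    async def main() -> None:
        # Constructed inside the running loop, so async_client is initialized.
        llm = ChatGoogleGenerativeAI(model="gemini-pro")
        result = await llm.ainvoke("Say hello in French.")
        print(result.content)

    asyncio.run(main())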
 
@@ -694,18 +770,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          messages: List[BaseMessage],
          stop: Optional[List[str]] = None,
          run_manager: Optional[CallbackManagerForLLMRun] = None,
+         *,
+         tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+         functions: Optional[Sequence[FunctionDeclarationType]] = None,
+         safety_settings: Optional[SafetySettingDict] = None,
+         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+         generation_config: Optional[Dict[str, Any]] = None,
          **kwargs: Any,
      ) -> Iterator[ChatGenerationChunk]:
-         params, chat, message = self._prepare_chat(
+         request = self._prepare_request(
              messages,
              stop=stop,
-             **kwargs,
+             tools=tools,
+             functions=functions,
+             safety_settings=safety_settings,
+             tool_config=tool_config,
+             generation_config=generation_config,
          )
-         response: genai.types.GenerateContentResponse = _chat_with_retry(
-             content=message,
-             **params,
-             generation_method=chat.send_message,
-             stream=True,
+         response: GenerateContentResponse = _chat_with_retry(
+             request=request,
+             generation_method=self.client.stream_generate_content,
+             **kwargs,
+             metadata=self.default_metadata,
          )
          for chunk in response:
              _chat_result = _response_to_result(chunk, stream=True)
@@ -720,18 +806,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          messages: List[BaseMessage],
          stop: Optional[List[str]] = None,
          run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         *,
+         tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+         functions: Optional[Sequence[FunctionDeclarationType]] = None,
+         safety_settings: Optional[SafetySettingDict] = None,
+         tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+         generation_config: Optional[Dict[str, Any]] = None,
          **kwargs: Any,
      ) -> AsyncIterator[ChatGenerationChunk]:
-         params, chat, message = self._prepare_chat(
+         request = self._prepare_request(
              messages,
              stop=stop,
-             **kwargs,
+             tools=tools,
+             functions=functions,
+             safety_settings=safety_settings,
+             tool_config=tool_config,
+             generation_config=generation_config,
          )
          async for chunk in await _achat_with_retry(
-             content=message,
-             **params,
-             generation_method=chat.send_message_async,
-             stream=True,
+             request=request,
+             generation_method=self.async_client.stream_generate_content,
+             **kwargs,
+             metadata=self.default_metadata,
          ):
              _chat_result = _response_to_result(chunk, stream=True)
              gen = cast(ChatGenerationChunk, _chat_result.generations[0])
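Streaming swaps the old ChatSession.send_message(..., stream=True) path for stream_generate_content on the raw service client; the LangChain-facing API is unchanged. A usage sketch (model name illustrative):

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    for chunk in llm.stream("Count from one to five."):
        print(chunk.content, end="", flush=True)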
@@ -740,17 +836,17 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
                  await run_manager.on_llm_new_token(gen.text)
              yield gen

-     def _prepare_chat(
+     def _prepare_request(
          self,
          messages: List[BaseMessage],
+         *,
          stop: Optional[List[str]] = None,
          tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
          functions: Optional[Sequence[FunctionDeclarationType]] = None,
          safety_settings: Optional[SafetySettingDict] = None,
          tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
-         **kwargs: Any,
-     ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
-         client = self.client
+         generation_config: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
          formatted_tools = None
          if tools:
              formatted_tools = [
@@ -759,25 +855,35 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          elif functions:
              formatted_tools = [convert_to_genai_function_declarations(functions)]

-         if formatted_tools or safety_settings:
-             client = genai.GenerativeModel(
-                 model_name=self.model,
-                 tools=formatted_tools,
-                 safety_settings=safety_settings,
-             )
-
-         params = self._prepare_params(stop, tool_config=tool_config, **kwargs)
          system_instruction, history = _parse_chat_history(
              messages,
              convert_system_message_to_human=self.convert_system_message_to_human,
          )
-         message = history.pop()
-         if self.client._system_instruction != system_instruction:
-             self.client = genai.GenerativeModel(
-                 model_name=self.model, system_instruction=system_instruction
+         formatted_tool_config = None
+         if tool_config:
+             formatted_tool_config = ToolConfig(
+                 function_calling_config=tool_config["function_calling_config"]
              )
-         chat = client.start_chat(history=history)
-         return params, chat, message
+         formatted_safety_settings = []
+         if safety_settings:
+             formatted_safety_settings = [
+                 SafetySetting(category=c, threshold=t)
+                 for c, t in safety_settings.items()
+             ]
+         request = GenerateContentRequest(
+             model=self.model,
+             contents=history,
+             tools=formatted_tools,
+             tool_config=formatted_tool_config,
+             safety_settings=formatted_safety_settings,
+             generation_config=self._prepare_params(
+                 stop, generation_config=generation_config
+             ),
+         )
+         if system_instruction:
+             request.system_instruction = system_instruction
+
+         return request
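_prepare_chat, which returned a (params, ChatSession, message) triple, becomes _prepare_request, which assembles a single GenerateContentRequest proto. (Note the leftover return annotation still says Tuple[...] even though only the request is returned.) Roughly the shape it produces, sketched with placeholder values:

    from google.ai.generativelanguage_v1beta.types import (
        Content,
        GenerateContentRequest,
        GenerationConfig,
        Part,
    )

    request = GenerateContentRequest(
        model="models/gemini-pro",
        contents=[Content(role="user", parts=[Part(text="Hello")])],
        generation_config=GenerationConfig(temperature=0.2),
    )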
 
      def get_num_tokens(self, text: str) -> int:
          """Get the number of tokens present in the text.
@@ -790,14 +896,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          Returns:
              The integer number of tokens in the text.
          """
-         if self._model_family == GoogleModelFamily.GEMINI:
-             result = self.client.count_tokens(text)
-             token_count = result.total_tokens
-         else:
-             result = self.client.count_text_tokens(model=self.model, prompt=text)
-             token_count = result["token_count"]
-
-         return token_count
+         result = self.client.count_tokens(
+             model=self.model, contents=[Content(parts=[Part(text=text)])]
+         )
+         return result.total_tokens
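Token counting loses its GoogleModelFamily branch: every model now goes through the service client's count_tokens with a proto Content. The public method is unchanged (this issues a count_tokens RPC, so it needs a valid GOOGLE_API_KEY):

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    print(llm.get_num_tokens("The quick brown fox jumps over the lazy dog"))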
 
      def bind_tools(
          self,
@@ -828,7 +930,9 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
          if tool_choice:
              all_names = [
-                 f["name"] for t in genai_tools for f in t["function_declarations"]
+                 f["name"]  # type: ignore[index]
+                 for t in genai_tools
+                 for f in t["function_declarations"]
              ]
              tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
          return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
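The bind_tools hunk itself is cosmetic (a reflowed comprehension plus a type-ignore), but it is the entry point for the tool_config plumbing added above. A hedged usage sketch, assuming the tool_choice parameter whose handling is shown in this hunk:

    from langchain_core.tools import tool

    from langchain_google_genai import ChatGoogleGenerativeAI

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    # tool_choice="multiply" forces a call to that function via tool_config.
    llm_with_tools = llm.bind_tools([multiply], tool_choice="multiply")
    print(llm_with_tools.invoke("What is 6 times 7?").tool_calls)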