langchain-google-genai 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic. (The original page linked to further details here; that link was not preserved.)

@@ -4,6 +4,7 @@ import base64
4
4
  import json
5
5
  import logging
6
6
  import os
7
+ import uuid
7
8
  import warnings
8
9
  from io import BytesIO
9
10
  from typing import (
@@ -22,17 +23,33 @@ from typing import (
22
23
  )
23
24
  from urllib.parse import urlparse
24
25
 
25
- import google.ai.generativelanguage as glm
26
26
  import google.api_core
27
27
 
28
28
  # TODO: remove ignore once the google package is published with types
29
- import google.generativeai as genai # type: ignore[import]
30
29
  import proto # type: ignore[import]
31
30
  import requests
31
+ from google.ai.generativelanguage_v1beta.types import (
32
+ Candidate,
33
+ Content,
34
+ FunctionCall,
35
+ FunctionResponse,
36
+ GenerateContentRequest,
37
+ GenerateContentResponse,
38
+ GenerationConfig,
39
+ Part,
40
+ SafetySetting,
41
+ ToolConfig,
42
+ )
43
+ from google.generativeai.types import Tool as GoogleTool # type: ignore[import]
44
+ from google.generativeai.types.content_types import ( # type: ignore[import]
45
+ FunctionDeclarationType,
46
+ ToolDict,
47
+ )
32
48
  from langchain_core.callbacks.manager import (
33
49
  AsyncCallbackManagerForLLMRun,
34
50
  CallbackManagerForLLMRun,
35
51
  )
52
+ from langchain_core.language_models import LanguageModelInput
36
53
  from langchain_core.language_models.chat_models import BaseChatModel
37
54
  from langchain_core.messages import (
38
55
  AIMessage,
@@ -40,10 +57,16 @@ from langchain_core.messages import (
40
57
  BaseMessage,
41
58
  FunctionMessage,
42
59
  HumanMessage,
60
+ InvalidToolCall,
43
61
  SystemMessage,
62
+ ToolCall,
63
+ ToolCallChunk,
64
+ ToolMessage,
44
65
  )
66
+ from langchain_core.output_parsers.openai_tools import parse_tool_calls
45
67
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
46
- from langchain_core.pydantic_v1 import SecretStr, root_validator
68
+ from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
69
+ from langchain_core.runnables import Runnable
47
70
  from langchain_core.utils import get_from_dict_or_env
48
71
  from tenacity import (
49
72
  before_sleep_log,
@@ -53,11 +76,22 @@ from tenacity import (
53
76
  wait_exponential,
54
77
  )
55
78
 
56
- from langchain_google_genai._common import GoogleGenerativeAIError
79
+ from langchain_google_genai._common import (
80
+ GoogleGenerativeAIError,
81
+ SafetySettingDict,
82
+ get_client_info,
83
+ )
57
84
  from langchain_google_genai._function_utils import (
85
+ _tool_choice_to_tool_config,
86
+ _ToolChoiceType,
87
+ _ToolConfigDict,
58
88
  convert_to_genai_function_declarations,
89
+ tool_to_dict,
59
90
  )
60
- from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
91
+ from langchain_google_genai._image_utils import ImageBytesLoader
92
+ from langchain_google_genai.llms import _BaseGoogleGenerativeAI
93
+
94
+ from . import _genai_extension as genaix
61
95
 
62
96
  IMAGE_TYPES: Tuple = ()
63
97
  try:
@@ -261,18 +295,19 @@ def _url_to_pil(image_source: str) -> Image:
261
295
 
262
296
  def _convert_to_parts(
263
297
  raw_content: Union[str, Sequence[Union[str, dict]]],
264
- ) -> List[genai.types.PartType]:
298
+ ) -> List[Part]:
265
299
  """Converts a list of LangChain messages into a google parts."""
266
300
  parts = []
267
301
  content = [raw_content] if isinstance(raw_content, str) else raw_content
302
+ image_loader = ImageBytesLoader()
268
303
  for part in content:
269
304
  if isinstance(part, str):
270
- parts.append(genai.types.PartDict(text=part))
305
+ parts.append(Part(text=part))
271
306
  elif isinstance(part, Mapping):
272
307
  # OpenAI Format
273
308
  if _is_openai_parts_format(part):
274
309
  if part["type"] == "text":
275
- parts.append({"text": part["text"]})
310
+ parts.append(Part(text=part["text"]))
276
311
  elif part["type"] == "image_url":
277
312
  img_url = part["image_url"]
278
313
  if isinstance(img_url, dict):
@@ -281,7 +316,7 @@ def _convert_to_parts(
281
316
  f"Unrecognized message image format: {img_url}"
282
317
  )
283
318
  img_url = img_url["url"]
284
- parts.append({"inline_data": _url_to_pil(img_url)})
319
+ parts.append(image_loader.load_part(img_url))
285
320
  else:
286
321
  raise ValueError(f"Unrecognized message part type: {part['type']}")
287
322
  else:
@@ -289,7 +324,7 @@ def _convert_to_parts(
289
324
  logger.warning(
290
325
  "Unrecognized message part format. Assuming it's a text part."
291
326
  )
292
- parts.append(part)
327
+ parts.append(Part(text=str(part)))
293
328
  else:
294
329
  # TODO: Maybe some of Google's native stuff
295
330
  # would hit this branch.
@@ -301,33 +336,36 @@ def _convert_to_parts(
301
336
 
302
337
  def _parse_chat_history(
303
338
  input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
304
- ) -> Tuple[Optional[genai.types.ContentDict], List[genai.types.ContentDict]]:
305
- messages: List[genai.types.MessageDict] = []
339
+ ) -> Tuple[Optional[Content], List[Content]]:
340
+ messages: List[Content] = []
306
341
 
307
342
  if convert_system_message_to_human:
308
343
  warnings.warn("Convert_system_message_to_human will be deprecated!")
309
344
 
310
- system_instruction: Optional[genai.types.ContentDict] = None
345
+ system_instruction: Optional[Content] = None
311
346
  for i, message in enumerate(input_messages):
312
347
  if i == 0 and isinstance(message, SystemMessage):
313
- system_instruction = _convert_to_parts(message.content)
348
+ system_instruction = Content(parts=_convert_to_parts(message.content))
314
349
  continue
315
350
  elif isinstance(message, AIMessage):
316
351
  role = "model"
317
352
  raw_function_call = message.additional_kwargs.get("function_call")
318
353
  if raw_function_call:
319
- function_call = glm.FunctionCall(
354
+ function_call = FunctionCall(
320
355
  {
321
356
  "name": raw_function_call["name"],
322
357
  "args": json.loads(raw_function_call["arguments"]),
323
358
  }
324
359
  )
325
- parts = [glm.Part(function_call=function_call)]
360
+ parts = [Part(function_call=function_call)]
326
361
  else:
327
362
  parts = _convert_to_parts(message.content)
328
363
  elif isinstance(message, HumanMessage):
329
364
  role = "user"
330
365
  parts = _convert_to_parts(message.content)
366
+ if i == 1 and convert_system_message_to_human and system_instruction:
367
+ parts = [p for p in system_instruction.parts] + parts
368
+ system_instruction = None
331
369
  elif isinstance(message, FunctionMessage):
332
370
  role = "user"
333
371
  response: Any
@@ -339,8 +377,8 @@ def _parse_chat_history(
339
377
  except json.JSONDecodeError:
340
378
  response = message.content # leave as str representation
341
379
  parts = [
342
- glm.Part(
343
- function_response=glm.FunctionResponse(
380
+ Part(
381
+ function_response=FunctionResponse(
344
382
  name=message.name,
345
383
  response=(
346
384
  {"output": response}
@@ -350,39 +388,138 @@ def _parse_chat_history(
350
388
  )
351
389
  )
352
390
  ]
391
+ elif isinstance(message, ToolMessage):
392
+ role = "user"
393
+ prev_message: Optional[BaseMessage] = (
394
+ input_messages[i - 1] if i > 0 else None
395
+ )
396
+ if (
397
+ prev_message
398
+ and isinstance(prev_message, AIMessage)
399
+ and prev_message.tool_calls
400
+ ):
401
+ # message.name can be null for ToolMessage
402
+ name: str = prev_message.tool_calls[0]["name"]
403
+ else:
404
+ name = message.name # type: ignore
405
+ tool_response: Any
406
+ if not isinstance(message.content, str):
407
+ tool_response = message.content
408
+ else:
409
+ try:
410
+ tool_response = json.loads(message.content)
411
+ except json.JSONDecodeError:
412
+ tool_response = message.content # leave as str representation
413
+ parts = [
414
+ Part(
415
+ function_response=FunctionResponse(
416
+ name=name,
417
+ response=(
418
+ {"output": tool_response}
419
+ if not isinstance(tool_response, dict)
420
+ else tool_response
421
+ ),
422
+ )
423
+ )
424
+ ]
353
425
  else:
354
426
  raise ValueError(
355
427
  f"Unexpected message with type {type(message)} at the position {i}."
356
428
  )
357
429
 
358
- messages.append({"role": role, "parts": parts})
430
+ messages.append(Content(role=role, parts=parts))
359
431
  return system_instruction, messages
360
432
 
361
433
 
362
434
  def _parse_response_candidate(
363
- response_candidate: glm.Candidate, stream: bool
435
+ response_candidate: Candidate, streaming: bool = False
364
436
  ) -> AIMessage:
365
- first_part = response_candidate.content.parts[0]
366
- if first_part.function_call:
367
- function_call = proto.Message.to_dict(first_part.function_call)
368
- function_call["arguments"] = json.dumps(function_call.pop("args", {}))
369
- return (AIMessageChunk if stream else AIMessage)(
370
- content="", additional_kwargs={"function_call": function_call}
437
+ content: Union[None, str, List[str]] = None
438
+ additional_kwargs = {}
439
+ tool_calls = []
440
+ invalid_tool_calls = []
441
+ tool_call_chunks = []
442
+
443
+ for part in response_candidate.content.parts:
444
+ try:
445
+ text: Optional[str] = part.text
446
+ except AttributeError:
447
+ text = None
448
+
449
+ if text is not None:
450
+ if not content:
451
+ content = text
452
+ elif isinstance(content, str) and text:
453
+ content = [content, text]
454
+ elif isinstance(content, list) and text:
455
+ content.append(text)
456
+ elif text:
457
+ raise Exception("Unexpected content type")
458
+
459
+ if part.function_call:
460
+ # TODO: support multiple function calls
461
+ if "function_call" in additional_kwargs:
462
+ raise Exception("Multiple function calls are not currently supported")
463
+ function_call = {"name": part.function_call.name}
464
+ # dump to match other function calling llm for now
465
+ function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
466
+ function_call["arguments"] = json.dumps(
467
+ {k: function_call_args_dict[k] for k in function_call_args_dict}
468
+ )
469
+ additional_kwargs["function_call"] = function_call
470
+
471
+ if streaming:
472
+ tool_call_chunks.append(
473
+ ToolCallChunk(
474
+ name=function_call.get("name"),
475
+ args=function_call.get("arguments"),
476
+ id=function_call.get("id", str(uuid.uuid4())),
477
+ index=function_call.get("index"), # type: ignore
478
+ )
479
+ )
480
+ else:
481
+ try:
482
+ tool_calls_dicts = parse_tool_calls(
483
+ [{"function": function_call}],
484
+ return_id=False,
485
+ )
486
+ tool_calls = [
487
+ ToolCall(
488
+ name=tool_call["name"],
489
+ args=tool_call["args"],
490
+ id=tool_call.get("id", str(uuid.uuid4())),
491
+ )
492
+ for tool_call in tool_calls_dicts
493
+ ]
494
+ except Exception as e:
495
+ invalid_tool_calls = [
496
+ InvalidToolCall(
497
+ name=function_call.get("name"),
498
+ args=function_call.get("arguments"),
499
+ id=function_call.get("id", str(uuid.uuid4())),
500
+ error=str(e),
501
+ )
502
+ ]
503
+ if content is None:
504
+ content = ""
505
+
506
+ if streaming:
507
+ return AIMessageChunk(
508
+ content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
509
+ additional_kwargs=additional_kwargs,
510
+ tool_call_chunks=tool_call_chunks,
371
511
  )
372
- else:
373
- parts = response_candidate.content.parts
374
-
375
- if len(parts) == 1 and parts[0].text:
376
- content: Union[str, List[Union[str, Dict]]] = parts[0].text
377
- else:
378
- content = [proto.Message.to_dict(part) for part in parts]
379
- return (AIMessageChunk if stream else AIMessage)(
380
- content=content, additional_kwargs={}
512
+
513
+ return AIMessage(
514
+ content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
515
+ additional_kwargs=additional_kwargs,
516
+ tool_calls=tool_calls,
517
+ invalid_tool_calls=invalid_tool_calls,
381
518
  )
382
519
 
383
520
 
384
521
  def _response_to_result(
385
- response: glm.GenerateContentResponse,
522
+ response: GenerateContentResponse,
386
523
  stream: bool = False,
387
524
  ) -> ChatResult:
388
525
  """Converts a PaLM API response into a LangChain ChatResult."""
@@ -400,7 +537,7 @@ def _response_to_result(
400
537
  ]
401
538
  generations.append(
402
539
  (ChatGenerationChunk if stream else ChatGeneration)(
403
- message=_parse_response_candidate(candidate, stream=stream),
540
+ message=_parse_response_candidate(candidate, streaming=stream),
404
541
  generation_info=generation_info,
405
542
  )
406
543
  )
@@ -440,6 +577,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
440
577
  """
441
578
 
442
579
  client: Any #: :meta private:
580
+ async_client: Any #: :meta private:
581
+ default_metadata: Sequence[Tuple[str, str]] = Field(
582
+ default_factory=list
583
+ ) #: :meta private:
443
584
 
444
585
  convert_system_message_to_human: bool = False
445
586
  """Whether to merge any leading SystemMessage into the following HumanMessage.
@@ -465,29 +606,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
465
606
  @root_validator()
466
607
  def validate_environment(cls, values: Dict) -> Dict:
467
608
  """Validates params and passes them to google-generativeai package."""
468
- additional_headers = values.get("additional_headers") or {}
469
- default_metadata = tuple(additional_headers.items())
470
-
471
- if values.get("credentials"):
472
- genai.configure(
473
- credentials=values.get("credentials"),
474
- transport=values.get("transport"),
475
- client_options=values.get("client_options"),
476
- default_metadata=default_metadata,
477
- )
478
- else:
479
- google_api_key = get_from_dict_or_env(
480
- values, "google_api_key", "GOOGLE_API_KEY"
481
- )
482
- if isinstance(google_api_key, SecretStr):
483
- google_api_key = google_api_key.get_secret_value()
484
-
485
- genai.configure(
486
- api_key=google_api_key,
487
- transport=values.get("transport"),
488
- client_options=values.get("client_options"),
489
- default_metadata=default_metadata,
490
- )
491
609
  if (
492
610
  values.get("temperature") is not None
493
611
  and not 0 <= values["temperature"] <= 1
@@ -499,8 +617,36 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
499
617
 
500
618
  if values.get("top_k") is not None and values["top_k"] <= 0:
501
619
  raise ValueError("top_k must be positive")
502
- model = values["model"]
503
- values["client"] = genai.GenerativeModel(model_name=model)
620
+
621
+ if not values["model"].startswith("models/"):
622
+ values["model"] = f"models/{values['model']}"
623
+
624
+ additional_headers = values.get("additional_headers") or {}
625
+ values["default_metadata"] = tuple(additional_headers.items())
626
+ client_info = get_client_info("ChatGoogleGenerativeAI")
627
+ google_api_key = None
628
+ if not values.get("credentials"):
629
+ google_api_key = get_from_dict_or_env(
630
+ values, "google_api_key", "GOOGLE_API_KEY"
631
+ )
632
+ if isinstance(google_api_key, SecretStr):
633
+ google_api_key = google_api_key.get_secret_value()
634
+ transport: Optional[str] = values.get("transport")
635
+ values["client"] = genaix.build_generative_service(
636
+ credentials=values.get("credentials"),
637
+ api_key=google_api_key,
638
+ client_info=client_info,
639
+ client_options=values.get("client_options"),
640
+ transport=transport,
641
+ )
642
+ values["async_client"] = genaix.build_generative_async_service(
643
+ credentials=values.get("credentials"),
644
+ api_key=google_api_key,
645
+ client_info=client_info,
646
+ client_options=values.get("client_options"),
647
+ transport=transport,
648
+ )
649
+
504
650
  return values
505
651
 
506
652
  @property
@@ -515,8 +661,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
515
661
  }
516
662
 
517
663
  def _prepare_params(
518
- self, stop: Optional[List[str]], **kwargs: Any
519
- ) -> Dict[str, Any]:
664
+ self,
665
+ stop: Optional[List[str]],
666
+ generation_config: Optional[Dict[str, Any]] = None,
667
+ ) -> GenerationConfig:
520
668
  gen_config = {
521
669
  k: v
522
670
  for k, v in {
@@ -529,27 +677,37 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
529
677
  }.items()
530
678
  if v is not None
531
679
  }
532
- if "generation_config" in kwargs:
533
- gen_config = {**gen_config, **kwargs.pop("generation_config")}
534
- params = {"generation_config": gen_config, **kwargs}
535
- return params
680
+ if generation_config:
681
+ gen_config = {**gen_config, **generation_config}
682
+ return GenerationConfig(**gen_config)
536
683
 
537
684
  def _generate(
538
685
  self,
539
686
  messages: List[BaseMessage],
540
687
  stop: Optional[List[str]] = None,
541
688
  run_manager: Optional[CallbackManagerForLLMRun] = None,
689
+ *,
690
+ tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
691
+ functions: Optional[Sequence[FunctionDeclarationType]] = None,
692
+ safety_settings: Optional[SafetySettingDict] = None,
693
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
694
+ generation_config: Optional[Dict[str, Any]] = None,
542
695
  **kwargs: Any,
543
696
  ) -> ChatResult:
544
- params, chat, message = self._prepare_chat(
697
+ request = self._prepare_request(
545
698
  messages,
546
699
  stop=stop,
547
- **kwargs,
700
+ tools=tools,
701
+ functions=functions,
702
+ safety_settings=safety_settings,
703
+ tool_config=tool_config,
704
+ generation_config=generation_config,
548
705
  )
549
- response: genai.types.GenerateContentResponse = _chat_with_retry(
550
- content=message,
551
- **params,
552
- generation_method=chat.send_message,
706
+ response: GenerateContentResponse = _chat_with_retry(
707
+ request=request,
708
+ **kwargs,
709
+ generation_method=self.client.generate_content,
710
+ metadata=self.default_metadata,
553
711
  )
554
712
  return _response_to_result(response)
555
713
 
@@ -558,17 +716,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
558
716
  messages: List[BaseMessage],
559
717
  stop: Optional[List[str]] = None,
560
718
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
719
+ *,
720
+ tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
721
+ functions: Optional[Sequence[FunctionDeclarationType]] = None,
722
+ safety_settings: Optional[SafetySettingDict] = None,
723
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
724
+ generation_config: Optional[Dict[str, Any]] = None,
561
725
  **kwargs: Any,
562
726
  ) -> ChatResult:
563
- params, chat, message = self._prepare_chat(
727
+ request = self._prepare_request(
564
728
  messages,
565
729
  stop=stop,
566
- **kwargs,
730
+ tools=tools,
731
+ functions=functions,
732
+ safety_settings=safety_settings,
733
+ tool_config=tool_config,
734
+ generation_config=generation_config,
567
735
  )
568
- response: genai.types.GenerateContentResponse = await _achat_with_retry(
569
- content=message,
570
- **params,
571
- generation_method=chat.send_message_async,
736
+ response: GenerateContentResponse = await _achat_with_retry(
737
+ request=request,
738
+ **kwargs,
739
+ generation_method=self.async_client.generate_content,
740
+ metadata=self.default_metadata,
572
741
  )
573
742
  return _response_to_result(response)
574
743
 
@@ -577,18 +746,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
577
746
  messages: List[BaseMessage],
578
747
  stop: Optional[List[str]] = None,
579
748
  run_manager: Optional[CallbackManagerForLLMRun] = None,
749
+ *,
750
+ tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
751
+ functions: Optional[Sequence[FunctionDeclarationType]] = None,
752
+ safety_settings: Optional[SafetySettingDict] = None,
753
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
754
+ generation_config: Optional[Dict[str, Any]] = None,
580
755
  **kwargs: Any,
581
756
  ) -> Iterator[ChatGenerationChunk]:
582
- params, chat, message = self._prepare_chat(
757
+ request = self._prepare_request(
583
758
  messages,
584
759
  stop=stop,
585
- **kwargs,
760
+ tools=tools,
761
+ functions=functions,
762
+ safety_settings=safety_settings,
763
+ tool_config=tool_config,
764
+ generation_config=generation_config,
586
765
  )
587
- response: genai.types.GenerateContentResponse = _chat_with_retry(
588
- content=message,
589
- **params,
590
- generation_method=chat.send_message,
591
- stream=True,
766
+ response: GenerateContentResponse = _chat_with_retry(
767
+ request=request,
768
+ generation_method=self.client.stream_generate_content,
769
+ **kwargs,
770
+ metadata=self.default_metadata,
592
771
  )
593
772
  for chunk in response:
594
773
  _chat_result = _response_to_result(chunk, stream=True)
@@ -603,18 +782,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
603
782
  messages: List[BaseMessage],
604
783
  stop: Optional[List[str]] = None,
605
784
  run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
785
+ *,
786
+ tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
787
+ functions: Optional[Sequence[FunctionDeclarationType]] = None,
788
+ safety_settings: Optional[SafetySettingDict] = None,
789
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
790
+ generation_config: Optional[Dict[str, Any]] = None,
606
791
  **kwargs: Any,
607
792
  ) -> AsyncIterator[ChatGenerationChunk]:
608
- params, chat, message = self._prepare_chat(
793
+ request = self._prepare_request(
609
794
  messages,
610
795
  stop=stop,
611
- **kwargs,
796
+ tools=tools,
797
+ functions=functions,
798
+ safety_settings=safety_settings,
799
+ tool_config=tool_config,
800
+ generation_config=generation_config,
612
801
  )
613
802
  async for chunk in await _achat_with_retry(
614
- content=message,
615
- **params,
616
- generation_method=chat.send_message_async,
617
- stream=True,
803
+ request=request,
804
+ generation_method=self.async_client.stream_generate_content,
805
+ **kwargs,
806
+ metadata=self.default_metadata,
618
807
  ):
619
808
  _chat_result = _response_to_result(chunk, stream=True)
620
809
  gen = cast(ChatGenerationChunk, _chat_result.generations[0])
@@ -623,35 +812,54 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
623
812
  await run_manager.on_llm_new_token(gen.text)
624
813
  yield gen
625
814
 
626
- def _prepare_chat(
815
+ def _prepare_request(
627
816
  self,
628
817
  messages: List[BaseMessage],
818
+ *,
629
819
  stop: Optional[List[str]] = None,
630
- **kwargs: Any,
631
- ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
632
- client = self.client
633
- functions = kwargs.pop("functions", None)
634
- safety_settings = kwargs.pop("safety_settings", self.safety_settings)
635
- if functions or safety_settings:
636
- tools = (
637
- convert_to_genai_function_declarations(functions) if functions else None
638
- )
639
- client = genai.GenerativeModel(
640
- model_name=self.model, tools=tools, safety_settings=safety_settings
641
- )
820
+ tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
821
+ functions: Optional[Sequence[FunctionDeclarationType]] = None,
822
+ safety_settings: Optional[SafetySettingDict] = None,
823
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
824
+ generation_config: Optional[Dict[str, Any]] = None,
825
+ ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
826
+ formatted_tools = None
827
+ if tools:
828
+ formatted_tools = [
829
+ convert_to_genai_function_declarations(tool) for tool in tools
830
+ ]
831
+ elif functions:
832
+ formatted_tools = [convert_to_genai_function_declarations(functions)]
642
833
 
643
- params = self._prepare_params(stop, **kwargs)
644
834
  system_instruction, history = _parse_chat_history(
645
835
  messages,
646
836
  convert_system_message_to_human=self.convert_system_message_to_human,
647
837
  )
648
- message = history.pop()
649
- if self.client._system_instruction != system_instruction:
650
- self.client = genai.GenerativeModel(
651
- model_name=self.model, system_instruction=system_instruction
838
+ formatted_tool_config = None
839
+ if tool_config:
840
+ formatted_tool_config = ToolConfig(
841
+ function_calling_config=tool_config["function_calling_config"]
652
842
  )
653
- chat = client.start_chat(history=history)
654
- return params, chat, message
843
+ formatted_safety_settings = []
844
+ if safety_settings:
845
+ formatted_safety_settings = [
846
+ SafetySetting(category=c, threshold=t)
847
+ for c, t in safety_settings.items()
848
+ ]
849
+ request = GenerateContentRequest(
850
+ model=self.model,
851
+ contents=history,
852
+ tools=formatted_tools,
853
+ tool_config=formatted_tool_config,
854
+ safety_settings=formatted_safety_settings,
855
+ generation_config=self._prepare_params(
856
+ stop, generation_config=generation_config
857
+ ),
858
+ )
859
+ if system_instruction:
860
+ request.system_instruction = system_instruction
861
+
862
+ return request
655
863
 
656
864
  def get_num_tokens(self, text: str) -> int:
657
865
  """Get the number of tokens present in the text.
@@ -664,11 +872,43 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
664
872
  Returns:
665
873
  The integer number of tokens in the text.
666
874
  """
667
- if self._model_family == GoogleModelFamily.GEMINI:
668
- result = self.client.count_tokens(text)
669
- token_count = result.total_tokens
670
- else:
671
- result = self.client.count_text_tokens(model=self.model, prompt=text)
672
- token_count = result["token_count"]
875
+ result = self.client.count_tokens(
876
+ model=self.model, contents=[Content(parts=[Part(text=text)])]
877
+ )
878
+ return result.total_tokens
673
879
 
674
- return token_count
880
+ def bind_tools(
881
+ self,
882
+ tools: Sequence[Union[ToolDict, GoogleTool]],
883
+ tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
884
+ *,
885
+ tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
886
+ **kwargs: Any,
887
+ ) -> Runnable[LanguageModelInput, BaseMessage]:
888
+ """Bind tool-like objects to this chat model.
889
+
890
+ Assumes model is compatible with google-generativeAI tool-calling API.
891
+
892
+ Args:
893
+ tools: A list of tool definitions to bind to this chat model.
894
+ Can be a pydantic model, callable, or BaseTool. Pydantic
895
+ models, callables, and BaseTools will be automatically converted to
896
+ their schema dictionary representation.
897
+ **kwargs: Any additional parameters to pass to the
898
+ :class:`~langchain.runnable.Runnable` constructor.
899
+ """
900
+ if tool_choice and tool_config:
901
+ raise ValueError(
902
+ "Must specify at most one of tool_choice and tool_config, received "
903
+ f"both:\n\n{tool_choice=}\n\n{tool_config=}"
904
+ )
905
+ # Bind dicts for easier serialization/deserialization.
906
+ genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
907
+ if tool_choice:
908
+ all_names = [
909
+ f["name"] # type: ignore[index]
910
+ for t in genai_tools
911
+ for f in t["function_declarations"]
912
+ ]
913
+ tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
914
+ return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)