langchain-google-genai 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of langchain-google-genai might be problematic.
- langchain_google_genai/__init__.py +1 -1
- langchain_google_genai/_common.py +48 -0
- langchain_google_genai/_enums.py +4 -4
- langchain_google_genai/_function_utils.py +145 -24
- langchain_google_genai/_genai_extension.py +64 -7
- langchain_google_genai/_image_utils.py +187 -0
- langchain_google_genai/chat_models.py +360 -120
- langchain_google_genai/embeddings.py +90 -40
- langchain_google_genai/llms.py +8 -1
- {langchain_google_genai-1.0.2.dist-info → langchain_google_genai-1.0.4.dist-info}/METADATA +3 -3
- langchain_google_genai-1.0.4.dist-info/RECORD +16 -0
- langchain_google_genai-1.0.2.dist-info/RECORD +0 -15
- {langchain_google_genai-1.0.2.dist-info → langchain_google_genai-1.0.4.dist-info}/LICENSE +0 -0
- {langchain_google_genai-1.0.2.dist-info → langchain_google_genai-1.0.4.dist-info}/WHEEL +0 -0
langchain_google_genai/chat_models.py (hunks follow; some removed lines are truncated in the source view and are reproduced as-is):

@@ -4,6 +4,7 @@ import base64
 import json
 import logging
 import os
+import uuid
 import warnings
 from io import BytesIO
 from typing import (
@@ -22,17 +23,33 @@ from typing import (
 )
 from urllib.parse import urlparse
 
-import google.ai.generativelanguage as glm
 import google.api_core
 
 # TODO: remove ignore once the google package is published with types
-import google.generativeai as genai  # type: ignore[import]
 import proto  # type: ignore[import]
 import requests
+from google.ai.generativelanguage_v1beta.types import (
+    Candidate,
+    Content,
+    FunctionCall,
+    FunctionResponse,
+    GenerateContentRequest,
+    GenerateContentResponse,
+    GenerationConfig,
+    Part,
+    SafetySetting,
+    ToolConfig,
+)
+from google.generativeai.types import Tool as GoogleTool  # type: ignore[import]
+from google.generativeai.types.content_types import (  # type: ignore[import]
+    FunctionDeclarationType,
+    ToolDict,
+)
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
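For orientation (this note is not part of the diff): 1.0.4 stops importing the high-level `google.generativeai` client in this module and instead builds requests from the `google.ai.generativelanguage_v1beta` proto types. A minimal sketch of how those types compose, assuming `google-ai-generativelanguage` is installed:

```python
# Sketch only: composing the v1beta proto types the new imports expose.
from google.ai.generativelanguage_v1beta.types import Content, Part

content = Content(
    role="user",
    parts=[Part(text="Hello!")],  # a Part carries text, a function_call, etc.
)
print(content.role, content.parts[0].text)
```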
@@ -40,10 +57,16 @@ from langchain_core.messages import (
     BaseMessage,
     FunctionMessage,
     HumanMessage,
+    InvalidToolCall,
     SystemMessage,
+    ToolCall,
+    ToolCallChunk,
+    ToolMessage,
 )
+from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import SecretStr, root_validator
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.runnables import Runnable
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
     before_sleep_log,
@@ -53,11 +76,22 @@ from tenacity import (
     wait_exponential,
 )
 
-from langchain_google_genai._common import
+from langchain_google_genai._common import (
+    GoogleGenerativeAIError,
+    SafetySettingDict,
+    get_client_info,
+)
 from langchain_google_genai._function_utils import (
+    _tool_choice_to_tool_config,
+    _ToolChoiceType,
+    _ToolConfigDict,
     convert_to_genai_function_declarations,
+    tool_to_dict,
 )
-from langchain_google_genai.
+from langchain_google_genai._image_utils import ImageBytesLoader
+from langchain_google_genai.llms import _BaseGoogleGenerativeAI
+
+from . import _genai_extension as genaix
 
 IMAGE_TYPES: Tuple = ()
 try:
@@ -261,18 +295,19 @@ def _url_to_pil(image_source: str) -> Image:
 
 def _convert_to_parts(
     raw_content: Union[str, Sequence[Union[str, dict]]],
-) -> List[
+) -> List[Part]:
     """Converts a list of LangChain messages into a google parts."""
     parts = []
     content = [raw_content] if isinstance(raw_content, str) else raw_content
+    image_loader = ImageBytesLoader()
     for part in content:
         if isinstance(part, str):
-            parts.append(
+            parts.append(Part(text=part))
         elif isinstance(part, Mapping):
             # OpenAI Format
             if _is_openai_parts_format(part):
                 if part["type"] == "text":
-                    parts.append(
+                    parts.append(Part(text=part["text"]))
                 elif part["type"] == "image_url":
                     img_url = part["image_url"]
                     if isinstance(img_url, dict):
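The `_convert_to_parts` changes route OpenAI-style `image_url` parts through the new `ImageBytesLoader`. A hedged usage sketch (model name and URL are illustrative; `GOOGLE_API_KEY` must be set):

```python
# Sketch: OpenAI-format multimodal content keeps working after this change.
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
)
response = llm.invoke([message])
```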
@@ -281,7 +316,7 @@ def _convert_to_parts(
                             f"Unrecognized message image format: {img_url}"
                         )
                     img_url = img_url["url"]
-                    parts.append(
+                    parts.append(image_loader.load_part(img_url))
                 else:
                     raise ValueError(f"Unrecognized message part type: {part['type']}")
             else:
@@ -289,7 +324,7 @@ def _convert_to_parts(
                 logger.warning(
                     "Unrecognized message part format. Assuming it's a text part."
                 )
-                parts.append(part)
+                parts.append(Part(text=str(part)))
         else:
             # TODO: Maybe some of Google's native stuff
             # would hit this branch.
@@ -301,33 +336,36 @@ def _convert_to_parts(
 
 def _parse_chat_history(
     input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
-) -> Tuple[Optional[
-    messages: List[
+) -> Tuple[Optional[Content], List[Content]]:
+    messages: List[Content] = []
 
     if convert_system_message_to_human:
         warnings.warn("Convert_system_message_to_human will be deprecated!")
 
-    system_instruction: Optional[
+    system_instruction: Optional[Content] = None
     for i, message in enumerate(input_messages):
         if i == 0 and isinstance(message, SystemMessage):
-            system_instruction = _convert_to_parts(message.content)
+            system_instruction = Content(parts=_convert_to_parts(message.content))
             continue
         elif isinstance(message, AIMessage):
             role = "model"
             raw_function_call = message.additional_kwargs.get("function_call")
             if raw_function_call:
-                function_call =
+                function_call = FunctionCall(
                     {
                         "name": raw_function_call["name"],
                         "args": json.loads(raw_function_call["arguments"]),
                     }
                 )
-                parts = [
+                parts = [Part(function_call=function_call)]
             else:
                 parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message.content)
+            if i == 1 and convert_system_message_to_human and system_instruction:
+                parts = [p for p in system_instruction.parts] + parts
+                system_instruction = None
         elif isinstance(message, FunctionMessage):
             role = "user"
             response: Any
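The new branch above folds a leading `SystemMessage` into the first human turn when `convert_system_message_to_human=True`. A sketch of the calling pattern (model name is illustrative):

```python
# Sketch: the system message is merged into the first HumanMessage's parts,
# then system_instruction is cleared so it is not sent twice.
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
messages = [
    SystemMessage(content="Answer in one short sentence."),
    HumanMessage(content="What is LangChain?"),
]
response = llm.invoke(messages)
```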
@@ -339,8 +377,8 @@ def _parse_chat_history(
             except json.JSONDecodeError:
                 response = message.content  # leave as str representation
             parts = [
-
-                function_response=
+                Part(
+                    function_response=FunctionResponse(
                         name=message.name,
                         response=(
                             {"output": response}
@@ -350,39 +388,138 @@ def _parse_chat_history(
                     )
                 )
             ]
+        elif isinstance(message, ToolMessage):
+            role = "user"
+            prev_message: Optional[BaseMessage] = (
+                input_messages[i - 1] if i > 0 else None
+            )
+            if (
+                prev_message
+                and isinstance(prev_message, AIMessage)
+                and prev_message.tool_calls
+            ):
+                # message.name can be null for ToolMessage
+                name: str = prev_message.tool_calls[0]["name"]
+            else:
+                name = message.name  # type: ignore
+            tool_response: Any
+            if not isinstance(message.content, str):
+                tool_response = message.content
+            else:
+                try:
+                    tool_response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    tool_response = message.content  # leave as str representation
+            parts = [
+                Part(
+                    function_response=FunctionResponse(
+                        name=name,
+                        response=(
+                            {"output": tool_response}
+                            if not isinstance(tool_response, dict)
+                            else tool_response
+                        ),
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
 
-        messages.append(
+        messages.append(Content(role=role, parts=parts))
     return system_instruction, messages
 
 
 def _parse_response_candidate(
-    response_candidate:
+    response_candidate: Candidate, streaming: bool = False
 ) -> AIMessage:
-
-
-
-
-
-
+    content: Union[None, str, List[str]] = None
+    additional_kwargs = {}
+    tool_calls = []
+    invalid_tool_calls = []
+    tool_call_chunks = []
+
+    for part in response_candidate.content.parts:
+        try:
+            text: Optional[str] = part.text
+        except AttributeError:
+            text = None
+
+        if text is not None:
+            if not content:
+                content = text
+            elif isinstance(content, str) and text:
+                content = [content, text]
+            elif isinstance(content, list) and text:
+                content.append(text)
+            elif text:
+                raise Exception("Unexpected content type")
+
+        if part.function_call:
+            # TODO: support multiple function calls
+            if "function_call" in additional_kwargs:
+                raise Exception("Multiple function calls are not currently supported")
+            function_call = {"name": part.function_call.name}
+            # dump to match other function calling llm for now
+            function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
+            function_call["arguments"] = json.dumps(
+                {k: function_call_args_dict[k] for k in function_call_args_dict}
+            )
+            additional_kwargs["function_call"] = function_call
+
+            if streaming:
+                tool_call_chunks.append(
+                    ToolCallChunk(
+                        name=function_call.get("name"),
+                        args=function_call.get("arguments"),
+                        id=function_call.get("id", str(uuid.uuid4())),
+                        index=function_call.get("index"),  # type: ignore
+                    )
+                )
+            else:
+                try:
+                    tool_calls_dicts = parse_tool_calls(
+                        [{"function": function_call}],
+                        return_id=False,
+                    )
+                    tool_calls = [
+                        ToolCall(
+                            name=tool_call["name"],
+                            args=tool_call["args"],
+                            id=tool_call.get("id", str(uuid.uuid4())),
+                        )
+                        for tool_call in tool_calls_dicts
+                    ]
+                except Exception as e:
+                    invalid_tool_calls = [
+                        InvalidToolCall(
+                            name=function_call.get("name"),
+                            args=function_call.get("arguments"),
+                            id=function_call.get("id", str(uuid.uuid4())),
+                            error=str(e),
+                        )
+                    ]
+    if content is None:
+        content = ""
+
+    if streaming:
+        return AIMessageChunk(
+            content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
+            additional_kwargs=additional_kwargs,
+            tool_call_chunks=tool_call_chunks,
         )
-
-
-
-
-
-
-    content = [proto.Message.to_dict(part) for part in parts]
-    return (AIMessageChunk if stream else AIMessage)(
-        content=content, additional_kwargs={}
+
+    return AIMessage(
+        content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
+        additional_kwargs=additional_kwargs,
+        tool_calls=tool_calls,
+        invalid_tool_calls=invalid_tool_calls,
     )
 
 
 def _response_to_result(
-    response:
+    response: GenerateContentResponse,
     stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
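With these changes, `_parse_response_candidate` surfaces Gemini function calls as structured `tool_calls` (or `tool_call_chunks` when streaming) on the returned `AIMessage`, and `_parse_chat_history` accepts `ToolMessage` replies. A hedged round-trip sketch; the `get_weather` tool is invented for illustration:

```python
# Sketch: one tool-calling round trip through the new code paths.
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_google_genai import ChatGoogleGenerativeAI

def get_weather(city: str) -> str:
    """Look up the current weather for a city."""
    return "sunny"  # stand-in for a real lookup

llm_with_tools = ChatGoogleGenerativeAI(model="gemini-pro").bind_tools([get_weather])
history = [HumanMessage(content="Weather in Paris?")]
ai_msg = llm_with_tools.invoke(history)
history.append(ai_msg)
for call in ai_msg.tool_calls:  # populated by _parse_response_candidate
    history.append(
        ToolMessage(content=get_weather(**call["args"]), tool_call_id=call["id"])
    )
final = llm_with_tools.invoke(history)
```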
@@ -400,7 +537,7 @@ def _response_to_result(
         ]
         generations.append(
             (ChatGenerationChunk if stream else ChatGeneration)(
-                message=_parse_response_candidate(candidate,
+                message=_parse_response_candidate(candidate, streaming=stream),
                 generation_info=generation_info,
             )
         )
@@ -440,6 +577,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """
 
     client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
+    default_metadata: Sequence[Tuple[str, str]] = Field(
+        default_factory=list
+    )  #: :meta private:
 
     convert_system_message_to_human: bool = False
     """Whether to merge any leading SystemMessage into the following HumanMessage.
@@ -465,29 +606,6 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
-        additional_headers = values.get("additional_headers") or {}
-        default_metadata = tuple(additional_headers.items())
-
-        if values.get("credentials"):
-            genai.configure(
-                credentials=values.get("credentials"),
-                transport=values.get("transport"),
-                client_options=values.get("client_options"),
-                default_metadata=default_metadata,
-            )
-        else:
-            google_api_key = get_from_dict_or_env(
-                values, "google_api_key", "GOOGLE_API_KEY"
-            )
-            if isinstance(google_api_key, SecretStr):
-                google_api_key = google_api_key.get_secret_value()
-
-            genai.configure(
-                api_key=google_api_key,
-                transport=values.get("transport"),
-                client_options=values.get("client_options"),
-                default_metadata=default_metadata,
-            )
         if (
             values.get("temperature") is not None
             and not 0 <= values["temperature"] <= 1
@@ -499,8 +617,36 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
 
         if values.get("top_k") is not None and values["top_k"] <= 0:
             raise ValueError("top_k must be positive")
-
-        values["
+
+        if not values["model"].startswith("models/"):
+            values["model"] = f"models/{values['model']}"
+
+        additional_headers = values.get("additional_headers") or {}
+        values["default_metadata"] = tuple(additional_headers.items())
+        client_info = get_client_info("ChatGoogleGenerativeAI")
+        google_api_key = None
+        if not values.get("credentials"):
+            google_api_key = get_from_dict_or_env(
+                values, "google_api_key", "GOOGLE_API_KEY"
+            )
+            if isinstance(google_api_key, SecretStr):
+                google_api_key = google_api_key.get_secret_value()
+        transport: Optional[str] = values.get("transport")
+        values["client"] = genaix.build_generative_service(
+            credentials=values.get("credentials"),
+            api_key=google_api_key,
+            client_info=client_info,
+            client_options=values.get("client_options"),
+            transport=transport,
+        )
+        values["async_client"] = genaix.build_generative_async_service(
+            credentials=values.get("credentials"),
+            api_key=google_api_key,
+            client_info=client_info,
+            client_options=values.get("client_options"),
+            transport=transport,
+        )
+
         return values
 
     @property
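Construction is unchanged for callers; the validator now also normalizes the model name and builds both clients eagerly. A sketch (the key value is a placeholder):

```python
# Sketch: the validator reads GOOGLE_API_KEY (or the google_api_key kwarg),
# prefixes the model name with "models/", and builds sync and async clients.
import os
from langchain_google_genai import ChatGoogleGenerativeAI

os.environ["GOOGLE_API_KEY"] = "..."  # placeholder
llm = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.7)
assert llm.model == "models/gemini-pro"
```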
@@ -515,8 +661,10 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         }
 
     def _prepare_params(
-        self,
-
+        self,
+        stop: Optional[List[str]],
+        generation_config: Optional[Dict[str, Any]] = None,
+    ) -> GenerationConfig:
         gen_config = {
             k: v
             for k, v in {
@@ -529,27 +677,37 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             }.items()
             if v is not None
         }
-        if
-            gen_config = {**gen_config, **
-
-        return params
+        if generation_config:
+            gen_config = {**gen_config, **generation_config}
+        return GenerationConfig(**gen_config)
 
     def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> ChatResult:
-
+        request = self._prepare_request(
             messages,
             stop=stop,
-
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
-        response:
-
-            **
-            generation_method=
+        response: GenerateContentResponse = _chat_with_retry(
+            request=request,
+            **kwargs,
+            generation_method=self.client.generate_content,
+            metadata=self.default_metadata,
         )
         return _response_to_result(response)
 
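`_generate` (and its async and streaming twins below) now accepts per-call `tools`, `functions`, `safety_settings`, `tool_config`, and `generation_config` keyword arguments. A hedged sketch of the per-call overrides, using the enums the package already exports:

```python
# Sketch: per-call overrides introduced in this release.
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)

llm = ChatGoogleGenerativeAI(model="gemini-pro")
response = llm.invoke(
    "Tell me a joke.",
    generation_config={"temperature": 0.2, "max_output_tokens": 64},
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    },
)
```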
@@ -558,17 +716,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> ChatResult:
-
+        request = self._prepare_request(
             messages,
             stop=stop,
-
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
-        response:
-
-            **
-            generation_method=
+        response: GenerateContentResponse = await _achat_with_retry(
+            request=request,
+            **kwargs,
+            generation_method=self.async_client.generate_content,
+            metadata=self.default_metadata,
         )
         return _response_to_result(response)
 
@@ -577,18 +746,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-
+        request = self._prepare_request(
             messages,
             stop=stop,
-
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
-        response:
-
-
-
+        response: GenerateContentResponse = _chat_with_retry(
+            request=request,
+            generation_method=self.client.stream_generate_content,
+            **kwargs,
+            metadata=self.default_metadata,
         )
         for chunk in response:
             _chat_result = _response_to_result(chunk, stream=True)
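Streaming mirrors the generate path but parses each chunk with `streaming=True`, so function calls surface as `tool_call_chunks` on `AIMessageChunk`s. A sketch:

```python
# Sketch: streamed chunks are AIMessageChunk objects; text arrives on
# .content and any function calls on .tool_call_chunks.
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
for chunk in llm.stream("Write a haiku about diffs."):
    print(chunk.content, end="", flush=True)
```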
@@ -603,18 +782,28 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-
+        request = self._prepare_request(
             messages,
             stop=stop,
-
+            tools=tools,
+            functions=functions,
+            safety_settings=safety_settings,
+            tool_config=tool_config,
+            generation_config=generation_config,
         )
         async for chunk in await _achat_with_retry(
-
-
-
-
+            request=request,
+            generation_method=self.async_client.stream_generate_content,
+            **kwargs,
+            metadata=self.default_metadata,
         ):
             _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
@@ -623,35 +812,54 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             await run_manager.on_llm_new_token(gen.text)
             yield gen
 
-    def
+    def _prepare_request(
         self,
         messages: List[BaseMessage],
+        *,
         stop: Optional[List[str]] = None,
-
-
-
-
-
-
-
-
-
-
-
-
+        tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None,
+        functions: Optional[Sequence[FunctionDeclarationType]] = None,
+        safety_settings: Optional[SafetySettingDict] = None,
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        generation_config: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
+        formatted_tools = None
+        if tools:
+            formatted_tools = [
+                convert_to_genai_function_declarations(tool) for tool in tools
+            ]
+        elif functions:
+            formatted_tools = [convert_to_genai_function_declarations(functions)]
 
-        params = self._prepare_params(stop, **kwargs)
         system_instruction, history = _parse_chat_history(
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
-
-        if
-
-
+        formatted_tool_config = None
+        if tool_config:
+            formatted_tool_config = ToolConfig(
+                function_calling_config=tool_config["function_calling_config"]
             )
-
-
+        formatted_safety_settings = []
+        if safety_settings:
+            formatted_safety_settings = [
+                SafetySetting(category=c, threshold=t)
+                for c, t in safety_settings.items()
+            ]
+        request = GenerateContentRequest(
+            model=self.model,
+            contents=history,
+            tools=formatted_tools,
+            tool_config=formatted_tool_config,
+            safety_settings=formatted_safety_settings,
+            generation_config=self._prepare_params(
+                stop, generation_config=generation_config
+            ),
+        )
+        if system_instruction:
+            request.system_instruction = system_instruction
+
+        return request
 
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text.
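`_prepare_request` threads an optional `tool_config` into the proto request; only its `function_calling_config` key is read. A hedged sketch of the dict shape it appears to expect (tool, model name, and mode values are illustrative):

```python
# Sketch: a tool_config in the shape _prepare_request consumes.
from langchain_google_genai import ChatGoogleGenerativeAI

def get_weather(city: str) -> str:
    """Look up the current weather for a city."""
    return "sunny"

llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm_forced = llm.bind_tools(
    [get_weather],
    tool_config={
        "function_calling_config": {
            "mode": "ANY",  # force a call; "AUTO" lets the model decide
            "allowed_function_names": ["get_weather"],
        }
    },
)
```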
@@ -664,11 +872,43 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         Returns:
             The integer number of tokens in the text.
         """
-
-
-
-
-        result = self.client.count_text_tokens(model=self.model, prompt=text)
-        token_count = result["token_count"]
+        result = self.client.count_tokens(
+            model=self.model, contents=[Content(parts=[Part(text=text)])]
+        )
+        return result.total_tokens
 
-
+    def bind_tools(
+        self,
+        tools: Sequence[Union[ToolDict, GoogleTool]],
+        tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
+        *,
+        tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        """Bind tool-like objects to this chat model.
+
+        Assumes model is compatible with google-generativeAI tool-calling API.
+
+        Args:
+            tools: A list of tool definitions to bind to this chat model.
+                Can be a pydantic model, callable, or BaseTool. Pydantic
+                models, callables, and BaseTools will be automatically converted to
+                their schema dictionary representation.
+            **kwargs: Any additional parameters to pass to the
+                :class:`~langchain.runnable.Runnable` constructor.
+        """
+        if tool_choice and tool_config:
+            raise ValueError(
+                "Must specify at most one of tool_choice and tool_config, received "
+                f"both:\n\n{tool_choice=}\n\n{tool_config=}"
+            )
+        # Bind dicts for easier serialization/deserialization.
+        genai_tools = [tool_to_dict(convert_to_genai_function_declarations(tools))]
+        if tool_choice:
+            all_names = [
+                f["name"]  # type: ignore[index]
+                for t in genai_tools
+                for f in t["function_declarations"]
+            ]
+            tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
+        return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)