langchain-google-genai 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_google_genai/_function_utils.py +116 -0
- langchain_google_genai/chat_models.py +148 -132
- langchain_google_genai/llms.py +68 -30
- {langchain_google_genai-0.0.6.dist-info → langchain_google_genai-0.0.8.dist-info}/METADATA +4 -1
- langchain_google_genai-0.0.8.dist-info/RECORD +11 -0
- {langchain_google_genai-0.0.6.dist-info → langchain_google_genai-0.0.8.dist-info}/WHEEL +1 -1
- langchain_google_genai-0.0.6.dist-info/RECORD +0 -10
- {langchain_google_genai-0.0.6.dist-info → langchain_google_genai-0.0.8.dist-info}/LICENSE +0 -0
langchain_google_genai/_function_utils.py
ADDED

@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from typing import (
+    Dict,
+    List,
+    Type,
+    Union,
+)
+
+import google.ai.generativelanguage as glm
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import BaseTool
+from langchain_core.utils.json_schema import dereference_refs
+
+FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]
+
+TYPE_ENUM = {
+    "string": glm.Type.STRING,
+    "number": glm.Type.NUMBER,
+    "integer": glm.Type.INTEGER,
+    "boolean": glm.Type.BOOLEAN,
+    "array": glm.Type.ARRAY,
+    "object": glm.Type.OBJECT,
+}
+
+
+def convert_to_genai_function_declarations(
+    function_calls: List[FunctionCallType],
+) -> List[glm.Tool]:
+    return [
+        glm.Tool(
+            function_declarations=[_convert_to_genai_function(fc)],
+        )
+        for fc in function_calls
+    ]
+
+
+def _convert_to_genai_function(fc: FunctionCallType) -> glm.FunctionDeclaration:
+    if isinstance(fc, BaseTool):
+        return _convert_tool_to_genai_function(fc)
+    elif isinstance(fc, type) and issubclass(fc, BaseModel):
+        return _convert_pydantic_to_genai_function(fc)
+    elif isinstance(fc, dict):
+        return glm.FunctionDeclaration(
+            name=fc["name"],
+            description=fc.get("description"),
+            parameters={
+                "properties": {
+                    k: {
+                        "type_": TYPE_ENUM[v["type"]],
+                        "description": v.get("description"),
+                    }
+                    for k, v in fc["parameters"]["properties"].items()
+                },
+                "required": fc["parameters"].get("required", []),
+                "type_": TYPE_ENUM[fc["parameters"]["type"]],
+            },
+        )
+    else:
+        raise ValueError(f"Unsupported function call type {fc}")
+
+
+def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
+    if tool.args_schema:
+        schema = dereference_refs(tool.args_schema.schema())
+        schema.pop("definitions", None)
+
+        return glm.FunctionDeclaration(
+            name=tool.name or schema["title"],
+            description=tool.description or schema["description"],
+            parameters={
+                "properties": {
+                    k: {
+                        "type_": TYPE_ENUM[v["type"]],
+                        "description": v.get("description"),
+                    }
+                    for k, v in schema["properties"].items()
+                },
+                "required": schema["required"],
+                "type_": TYPE_ENUM[schema["type"]],
+            },
+        )
+    else:
+        return glm.FunctionDeclaration(
+            name=tool.name,
+            description=tool.description,
+            parameters={
+                "properties": {
+                    "__arg1": {"type_": TYPE_ENUM["string"]},
+                },
+                "required": ["__arg1"],
+                "type_": TYPE_ENUM["object"],
+            },
+        )
+
+
+def _convert_pydantic_to_genai_function(
+    pydantic_model: Type[BaseModel],
+) -> glm.FunctionDeclaration:
+    schema = dereference_refs(pydantic_model.schema())
+    schema.pop("definitions", None)
+    return glm.FunctionDeclaration(
+        name=schema["title"],
+        description=schema.get("description", ""),
+        parameters={
+            "properties": {
+                k: {
+                    "type_": TYPE_ENUM[v["type"]],
+                    "description": v.get("description"),
+                }
+                for k, v in schema["properties"].items()
+            },
+            "required": schema["required"],
+            "type_": TYPE_ENUM[schema["type"]],
+        },
+    )
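
The new helper accepts LangChain BaseTool instances, Pydantic models, and plain function-schema dicts, wrapping each one in its own glm.Tool. A minimal sketch of how it might be exercised follows; the WeatherInput model and the "search" schema are invented for illustration and are not part of the package.

# Sketch: converting a Pydantic model and a dict schema into glm.Tool objects.
# Assumes langchain-google-genai 0.0.8 is installed; WeatherInput and the
# "search" schema below are made-up examples.
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_google_genai._function_utils import (
    convert_to_genai_function_declarations,
)


class WeatherInput(BaseModel):
    """Look up the current weather for a city."""

    city: str = Field(description="City name, e.g. 'Berlin'")


search_schema = {
    "name": "search",
    "description": "Search the web for a query.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search query"},
        },
        "required": ["query"],
    },
}

# One glm.Tool is produced per input, each holding a single FunctionDeclaration.
tools = convert_to_genai_function_declarations([WeatherInput, search_schema])
for tool in tools:
    print(tool.function_declarations[0].name)  # "WeatherInput", then "search"
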
langchain_google_genai/chat_models.py
CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import base64
+import json
 import logging
 import os
 from io import BytesIO
@@ -15,14 +16,17 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
-    Type,
     Union,
     cast,
 )
 from urllib.parse import urlparse
 
+import google.ai.generativelanguage as glm
+import google.api_core
+
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
+import proto  # type: ignore[import]
 import requests
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -33,14 +37,12 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
-    ChatMessage,
-    ChatMessageChunk,
+    FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
     before_sleep_log,
@@ -51,6 +53,10 @@ from tenacity import (
 )
 
 from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai._function_utils import (
+    convert_to_genai_function_declarations,
+)
+from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
 
 IMAGE_TYPES: Tuple = ()
 try:
@@ -87,8 +93,6 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
         Callable[[Any], Any]: A retry decorator configured for handling specific
         Google API exceptions.
     """
-    import google.api_core.exceptions
-
     multiplier = 2
     min_seconds = 1
     max_seconds = 60
@@ -123,14 +127,22 @@ def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
         Any: The result from the chat generation method.
     """
     retry_decorator = _create_retry_decorator()
-    from google.api_core.exceptions import InvalidArgument  # type: ignore
 
     @retry_decorator
     def _chat_with_retry(**kwargs: Any) -> Any:
         try:
             return generation_method(**kwargs)
-        except InvalidArgument as e:
-            # Do not retry for these errors.
+        # Do not retry for these errors.
+        except google.api_core.exceptions.FailedPrecondition as exc:
+            if "location is not supported" in exc.message:
+                error_msg = (
+                    "Your location is not supported by google-generativeai "
+                    "at the moment. Try to use ChatVertexAI LLM from "
+                    "langchain_google_vertexai."
+                )
+                raise ValueError(error_msg)
+
+        except google.api_core.exceptions.InvalidArgument as e:
             raise ChatGoogleGenerativeAIError(
                 f"Invalid argument provided to Gemini: {e}"
             ) from e
@@ -312,14 +324,47 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
             continue
         elif isinstance(message, AIMessage):
             role = "model"
+            raw_function_call = message.additional_kwargs.get("function_call")
+            if raw_function_call:
+                function_call = glm.FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": json.loads(raw_function_call["arguments"]),
+                    }
+                )
+                parts = [glm.Part(function_call=function_call)]
+            else:
+                parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
+            parts = _convert_to_parts(message.content)
+        elif isinstance(message, FunctionMessage):
+            role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
+            parts = [
+                glm.Part(
+                    function_response=glm.FunctionResponse(
+                        name=message.name,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
 
-        parts = _convert_to_parts(message.content)
         if raw_system_message:
             if role == "model":
                 raise ValueError(
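
The new branches above determine how function calls and tool results round-trip through _parse_chat_history. A rough sketch of the message shapes they handle; the "get_weather" tool and its arguments are purely illustrative.

# Sketch of the message shapes handled by the new AIMessage / FunctionMessage branches.
import json

from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage

history = [
    HumanMessage(content="What's the weather in Berlin?"),
    # An assistant turn that requested a tool call: arguments travel as a JSON
    # string and are json.loads()-ed back into a dict for glm.FunctionCall.
    AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "get_weather",
                "arguments": json.dumps({"city": "Berlin"}),
            }
        },
    ),
    # The tool result goes back as a FunctionMessage; non-dict content is
    # wrapped as {"output": ...} before becoming a glm.FunctionResponse.
    FunctionMessage(name="get_weather", content=json.dumps({"temp_c": 21})),
]
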
@@ -332,71 +377,51 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
     return messages
 
 
-def _parts_to_content(...):
-    ...
-    return ...
-    ...
-    return messages
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = proto.Message.to_dict(first_part.function_call)
+        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
+    else:
+        parts = response_candidate.content.parts
+
+        if len(parts) == 1 and parts[0].text:
+            content: Union[str, List[Union[str, Dict]]] = parts[0].text
+        else:
+            content = [proto.Message.to_dict(part) for part in parts]
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )
 
 
 def _response_to_result(
-    response: genai.types.GenerateContentResponse,
-    ai_msg_t: Type[BaseMessage] = AIMessage,
-    human_msg_t: Type[BaseMessage] = HumanMessage,
-    chat_msg_t: Type[BaseMessage] = ChatMessage,
-    generation_t: Type[ChatGeneration] = ChatGeneration,
+    response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
-    llm_output = {}
-    if response.prompt_feedback:
-        try:
-            prompt_feedback = type(response.prompt_feedback).to_dict(
-                response.prompt_feedback, use_integers_for_enums=False
-            )
-            llm_output["prompt_feedback"] = prompt_feedback
-        except Exception as e:
-            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
 
     generations: List[ChatGeneration] = []
 
-    role_map = {
-        "model": ai_msg_t,
-        "user": human_msg_t,
-    }
     for candidate in response.candidates:
-        content = candidate.content
-        parts_content = _parts_to_content(content.parts)
-        if content.role not in role_map:
-            logger.warning(
-                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-            )
-            msg = chat_msg_t(content=parts_content, role=content.role)
-        else:
-            msg = role_map[content.role](content=parts_content)
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
-        ...
-        generations.append(
+        generation_info["safety_ratings"] = [
+            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+            for safety_rating in candidate.safety_ratings
+        ]
+        generations.append(
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
+                generation_info=generation_info,
+            )
+        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
@@ -405,11 +430,16 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [
+        generations = [
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
+        ]
     return ChatResult(generations=generations, llm_output=llm_output)
 
 
-class ChatGoogleGenerativeAI(BaseChatModel):
+class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """`Google Generative AI` Chat models API.
 
     To use, you must have either:
@@ -427,53 +457,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):
 
     """
 
-    model: str = Field(
-        ...,
-        description="""The name of the model to use.
-Supported examples:
-    - gemini-pro""",
-    )
-    max_output_tokens: int = Field(default=None, description="Max output tokens")
-
     client: Any  #: :meta private:
-
-    temperature: Optional[float] = None
-    """Run inference with this temperature. Must by in the closed
-       interval [0.0, 1.0]."""
-    top_k: Optional[int] = None
-    """Decode using top-k sampling: consider the set of top_k most probable tokens.
-       Must be positive."""
-    top_p: Optional[int] = None
-    """The maximum cumulative probability of tokens to consider when sampling.
-
-       The model uses combined Top-k and nucleus sampling.
-
-       Tokens are sorted based on their assigned probabilities so
-       that only the most likely tokens are considered. Top-k
-       sampling directly limits the maximum number of tokens to
-       consider, while Nucleus sampling limits number of tokens
-       based on the cumulative probability.
-
-       Note: The default value varies by model, see the
-       `Model.top_p` attribute of the `Model` returned the
-       `genai.get_model` function.
-    """
-    n: int = Field(default=1, alias="candidate_count")
-    """Number of chat completions to generate for each prompt. Note that the API may
-       not return the full n completions if duplicates are generated."""
+
     convert_system_message_to_human: bool = False
     """Whether to merge any leading SystemMessage into the following HumanMessage.
 
     Gemini does not support system messages; any unsupported messages will
     raise an error."""
-    client_options: Optional[Dict] = Field(
-        None,
-        description="Client options to pass to the Google API client.",
-    )
-    transport: Optional[str] = Field(
-        None,
-        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
-    )
 
     class Config:
         allow_population_by_field_name = True
@@ -486,10 +476,6 @@ Supported examples:
     def _llm_type(self) -> str:
         return "chat-google-generative-ai"
 
-    @property
-    def _is_geminiai(self) -> bool:
-        return self.model is not None and "gemini" in self.model
-
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
@@ -560,7 +546,11 @@ Supported examples:
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -575,7 +565,11 @@ Supported examples:
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
             content=message,
             **params,
@@ -590,7 +584,11 @@ Supported examples:
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -598,17 +596,11 @@ Supported examples:
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
-            yield gen
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
+            yield gen
 
     async def _astream(
         self,
@@ -617,24 +609,22 @@ Supported examples:
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         async for chunk in await _achat_with_retry(
             content=message,
             **params,
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
-            yield gen
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
+            yield gen
 
     def _prepare_chat(
         self,
@@ -642,11 +632,37 @@ Supported examples:
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        client = self.client
+        functions = kwargs.pop("functions", None)
+        if functions:
+            tools = convert_to_genai_function_declarations(functions)
+            client = genai.GenerativeModel(model_name=self.model, tools=tools)
+
         params = self._prepare_params(stop, **kwargs)
         history = _parse_chat_history(
            messages,
            convert_system_message_to_human=self.convert_system_message_to_human,
        )
         message = history.pop()
-        chat = self.client.start_chat(history=history)
+        chat = client.start_chat(history=history)
         return params, chat, message
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
langchain_google_genai/llms.py
CHANGED

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from enum import Enum, auto
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
 import google.api_core
@@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
 
+class GoogleModelFamily(str, Enum):
+    GEMINI = auto()
+    PALM = auto()
+
+    @classmethod
+    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
+        if "gemini" in value.lower():
+            return GoogleModelFamily.GEMINI
+        elif "text-bison" in value.lower():
+            return GoogleModelFamily.PALM
+        return None
+
+
 def _create_retry_decorator(
     llm: BaseLLM,
     *,
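
The _missing_ hook means the enum can be constructed directly from a model-name string. A small sketch of the mapping; the model names are illustrative.

# Sketch of how GoogleModelFamily resolves arbitrary model names.
from langchain_google_genai.llms import GoogleModelFamily

assert GoogleModelFamily("gemini-pro") is GoogleModelFamily.GEMINI
assert GoogleModelFamily("models/text-bison-001") is GoogleModelFamily.PALM

# Names matching neither pattern raise ValueError, because _missing_
# returns None for them.
try:
    GoogleModelFamily("embedding-gecko-001")
except ValueError:
    print("unknown model family")
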
@@ -56,21 +70,25 @@ def _completion_with_retry(
     prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
 ) -> Any:
     generation_config = kwargs.get("generation_config", {})
-    ...
+    error_msg = (
+        "Your location is not supported by google-generativeai at the moment. "
+        "Try to use VertexAI LLM from langchain_google_vertexai"
+    )
+    try:
+        if is_gemini:
+            return llm.client.generate_content(
+                contents=prompt, stream=stream, generation_config=generation_config
+            )
+        return llm.client.generate_text(prompt=prompt, **kwargs)
+    except google.api_core.exceptions.FailedPrecondition as exc:
+        if "location is not supported" in exc.message:
+            raise ValueError(error_msg)
 
     return _completion_with_retry(
         prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
     )
 
 
-def _is_gemini_model(model_name: str) -> bool:
-    return "gemini" in model_name
-
-
 def _strip_erroneous_leading_spaces(text: str) -> str:
     """Strip erroneous leading spaces from text.
 
@@ -84,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
     return text
 
 
-class GoogleGenerativeAI(BaseLLM, BaseModel):
-    """Google GenerativeAI models.
-
-    Example:
-        .. code-block:: python
+class _BaseGoogleGenerativeAI(BaseModel):
+    """Base class for Google Generative AI LLMs"""
 
-            from langchain_google_genai import GoogleGenerativeAI
-            llm = GoogleGenerativeAI(model="gemini-pro")
-    """
-
-    client: Any  #: :meta private:
     model: str = Field(
         ...,
         description="""The name of the model to use.
@@ -133,15 +143,39 @@ Supported examples:
         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
     )
 
-    @property
-    def is_gemini(self) -> bool:
-        """Returns whether a model is belongs to a Gemini family or not."""
-        return _is_gemini_model(self.model)
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}
 
+    @property
+    def _model_family(self) -> str:
+        return GoogleModelFamily(self.model)
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model": self.model,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "max_output_tokens": self.max_output_tokens,
+            "candidate_count": self.n,
+        }
+
+
+class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+    """Google GenerativeAI models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_google_genai import GoogleGenerativeAI
+            llm = GoogleGenerativeAI(model="gemini-pro")
+    """
+
+    client: Any  #: :meta private:
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
@@ -159,7 +193,7 @@ Supported examples:
             client_options=values.get("client_options"),
         )
 
-        if _is_gemini_model(model_name):
+        if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
             values["client"] = genai.GenerativeModel(model_name=model_name)
         else:
             values["client"] = genai
@@ -195,7 +229,7 @@ Supported examples:
             "candidate_count": self.n,
         }
         for prompt in prompts:
-            if self.is_gemini:
+            if self._model_family == GoogleModelFamily.GEMINI:
                 res = _completion_with_retry(
                     self,
                     prompt=prompt,
@@ -271,7 +305,11 @@ Supported examples:
         Returns:
             The integer number of tokens in the text.
         """
-        if self.is_gemini:
-            ...
-        ...
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
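
For the plain LLM entry point, behaviour is largely unchanged apart from the new family dispatch. A short usage sketch; it assumes GOOGLE_API_KEY is set in the environment.

# Sketch: the LLM class after the split into _BaseGoogleGenerativeAI + GoogleGenerativeAI.
# Assumes GOOGLE_API_KEY is set.
from langchain_google_genai import GoogleGenerativeAI

llm = GoogleGenerativeAI(model="gemini-pro", temperature=0.2)
print(llm.invoke("Write a haiku about version diffs."))

# Token counting now dispatches on the model family: Gemini models call
# client.count_tokens(...), PaLM-style models call count_text_tokens(...).
print(llm.get_num_tokens("Write a haiku about version diffs."))
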
{langchain_google_genai-0.0.6.dist-info → langchain_google_genai-0.0.8.dist-info}/METADATA
CHANGED

@@ -1,13 +1,16 @@
 Metadata-Version: 2.1
 Name: langchain-google-genai
-Version: 0.0.6
+Version: 0.0.8
 Summary: An integration package connecting Google's genai package and LangChain
 Home-page: https://github.com/langchain-ai/langchain
+License: MIT
 Requires-Python: >=3.9,<4.0
+Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Provides-Extra: images
 Requires-Dist: google-generativeai (>=0.3.1,<0.4.0)
 Requires-Dist: langchain-core (>=0.1,<0.2)
langchain_google_genai-0.0.8.dist-info/RECORD
ADDED

@@ -0,0 +1,11 @@
+langchain_google_genai/__init__.py,sha256=cDMb1xbsenQtYBACNP0dYPwA7Rt015MT7HC_XP3X-4Y,2304
+langchain_google_genai/_common.py,sha256=1r0VrrBSTZfGprmICZ5OV-W5SK31jKRFFCNE3vJ3jmk,136
+langchain_google_genai/_function_utils.py,sha256=9IVMPQq5lQB8F_whG3mrGOM_tjmP-TMEY3URHfJnjgI,3640
+langchain_google_genai/chat_models.py,sha256=flq1xYC2OYoijOtU9qJy12CZwJ8sdY3va3x0004pl1M,23208
+langchain_google_genai/embeddings.py,sha256=EMa-sDGXUpAPMSyjA2-YXF_TGrlSlqljNeqysAh574s,3951
+langchain_google_genai/llms.py,sha256=Kk7fCrWbbfR1tpiFRYBX6fgEhgAYq7HQVSNvR6UvWqY,10990
+langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+langchain_google_genai-0.0.8.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
+langchain_google_genai-0.0.8.dist-info/METADATA,sha256=mwV8mJuT-jVGcq-HQa7IiWjd6HDvVOIVa01lSTe1R-A,2851
+langchain_google_genai-0.0.8.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+langchain_google_genai-0.0.8.dist-info/RECORD,,
langchain_google_genai-0.0.6.dist-info/RECORD
REMOVED

@@ -1,10 +0,0 @@
-langchain_google_genai/__init__.py,sha256=cDMb1xbsenQtYBACNP0dYPwA7Rt015MT7HC_XP3X-4Y,2304
-langchain_google_genai/_common.py,sha256=1r0VrrBSTZfGprmICZ5OV-W5SK31jKRFFCNE3vJ3jmk,136
-langchain_google_genai/chat_models.py,sha256=451I60uXY24h7bpTxkF3HNccdY0vIRm6wly1PL9URXk,22894
-langchain_google_genai/embeddings.py,sha256=EMa-sDGXUpAPMSyjA2-YXF_TGrlSlqljNeqysAh574s,3951
-langchain_google_genai/llms.py,sha256=xoypoTqhw-p3_1Htk8yURoTpRtjVjFvsqP1WPWwk8eg,9707
-langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-langchain_google_genai-0.0.6.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
-langchain_google_genai-0.0.6.dist-info/METADATA,sha256=9IOHkgYT2g3XbCU0kdxETvOJaK5z4FjckdV6I1hmn2w,2736
-langchain_google_genai-0.0.6.dist-info/WHEEL,sha256=d2fvjOD7sXsVzChCqf0Ty0JbHKBaLYwDbGQDwQTnJ50,88
-langchain_google_genai-0.0.6.dist-info/RECORD,,

{langchain_google_genai-0.0.6.dist-info → langchain_google_genai-0.0.8.dist-info}/LICENSE
File without changes