langchain-google-genai 0.0.5__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/PKG-INFO +7 -3
- langchain_google_genai-0.0.7/langchain_google_genai/_function_utils.py +135 -0
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/chat_models.py +121 -59
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/embeddings.py +18 -2
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/llms.py +73 -32
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/pyproject.toml +6 -2
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/LICENSE +0 -0
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/README.md +0 -0
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/__init__.py +0 -0
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/_common.py +0 -0
- {langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/py.typed +0 -0
|
@@ -1,18 +1,22 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: langchain-google-genai
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: An integration package connecting Google's genai package and LangChain
|
|
5
|
-
Home-page: https://github.com/langchain-ai/langchain
|
|
5
|
+
Home-page: https://github.com/langchain-ai/langchain
|
|
6
|
+
License: MIT
|
|
6
7
|
Requires-Python: >=3.9,<4.0
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
7
9
|
Classifier: Programming Language :: Python :: 3
|
|
8
10
|
Classifier: Programming Language :: Python :: 3.9
|
|
9
11
|
Classifier: Programming Language :: Python :: 3.10
|
|
10
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
14
|
Provides-Extra: images
|
|
12
15
|
Requires-Dist: google-generativeai (>=0.3.1,<0.4.0)
|
|
13
16
|
Requires-Dist: langchain-core (>=0.1,<0.2)
|
|
14
17
|
Requires-Dist: pillow (>=10.1.0,<11.0.0) ; extra == "images"
|
|
15
|
-
Project-URL: Repository, https://github.com/langchain-ai/langchain
|
|
18
|
+
Project-URL: Repository, https://github.com/langchain-ai/langchain
|
|
19
|
+
Project-URL: Source Code, https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-genai
|
|
16
20
|
Description-Content-Type: text/markdown
|
|
17
21
|
|
|
18
22
|
# langchain-google-genai
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import (
|
|
4
|
+
Dict,
|
|
5
|
+
List,
|
|
6
|
+
Type,
|
|
7
|
+
Union,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
from langchain_core.pydantic_v1 import BaseModel
|
|
11
|
+
from langchain_core.tools import BaseTool
|
|
12
|
+
from langchain_core.utils.json_schema import dereference_refs
|
|
13
|
+
|
|
14
|
+
FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]
|
|
15
|
+
|
|
16
|
+
TYPE_ENUM = {
|
|
17
|
+
"string": 1,
|
|
18
|
+
"number": 2,
|
|
19
|
+
"integer": 3,
|
|
20
|
+
"boolean": 4,
|
|
21
|
+
"array": 5,
|
|
22
|
+
"object": 6,
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def convert_to_genai_function_declarations(
|
|
27
|
+
function_calls: List[FunctionCallType],
|
|
28
|
+
) -> Dict:
|
|
29
|
+
function_declarations = []
|
|
30
|
+
for fc in function_calls:
|
|
31
|
+
function_declarations.append(_convert_to_genai_function(fc))
|
|
32
|
+
return {
|
|
33
|
+
"function_declarations": function_declarations,
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _convert_to_genai_function(fc: FunctionCallType) -> Dict:
|
|
38
|
+
"""
|
|
39
|
+
Produce
|
|
40
|
+
|
|
41
|
+
{
|
|
42
|
+
"name": "get_weather",
|
|
43
|
+
"description": "Determine weather in my location",
|
|
44
|
+
"parameters": {
|
|
45
|
+
"properties": {
|
|
46
|
+
"location": {
|
|
47
|
+
"description": "The city and state e.g. San Francisco, CA",
|
|
48
|
+
"type_": 1
|
|
49
|
+
},
|
|
50
|
+
"unit": { "enum": ["c", "f"], "type_": 1 }
|
|
51
|
+
},
|
|
52
|
+
"required": ["location"],
|
|
53
|
+
"type_": 6
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
"""
|
|
58
|
+
if isinstance(fc, BaseTool):
|
|
59
|
+
return _convert_tool_to_genai_function(fc)
|
|
60
|
+
elif isinstance(fc, type) and issubclass(fc, BaseModel):
|
|
61
|
+
return _convert_pydantic_to_genai_function(fc)
|
|
62
|
+
elif isinstance(fc, dict):
|
|
63
|
+
return {
|
|
64
|
+
**fc,
|
|
65
|
+
"parameters": {
|
|
66
|
+
"properties": {
|
|
67
|
+
k: {
|
|
68
|
+
"type_": TYPE_ENUM[v["type"]],
|
|
69
|
+
"description": v.get("description"),
|
|
70
|
+
}
|
|
71
|
+
for k, v in fc["parameters"]["properties"].items()
|
|
72
|
+
},
|
|
73
|
+
"required": fc["parameters"].get("required", []),
|
|
74
|
+
"type_": TYPE_ENUM[fc["parameters"]["type"]],
|
|
75
|
+
},
|
|
76
|
+
}
|
|
77
|
+
else:
|
|
78
|
+
raise ValueError(f"Unsupported function call type {fc}")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _convert_tool_to_genai_function(tool: BaseTool) -> Dict:
|
|
82
|
+
if tool.args_schema:
|
|
83
|
+
schema = dereference_refs(tool.args_schema.schema())
|
|
84
|
+
schema.pop("definitions", None)
|
|
85
|
+
|
|
86
|
+
return {
|
|
87
|
+
"name": tool.name or schema["title"],
|
|
88
|
+
"description": tool.description or schema["description"],
|
|
89
|
+
"parameters": {
|
|
90
|
+
"properties": {
|
|
91
|
+
k: {
|
|
92
|
+
"type_": TYPE_ENUM[v["type"]],
|
|
93
|
+
"description": v.get("description"),
|
|
94
|
+
}
|
|
95
|
+
for k, v in schema["properties"].items()
|
|
96
|
+
},
|
|
97
|
+
"required": schema["required"],
|
|
98
|
+
"type_": TYPE_ENUM[schema["type"]],
|
|
99
|
+
},
|
|
100
|
+
}
|
|
101
|
+
else:
|
|
102
|
+
return {
|
|
103
|
+
"name": tool.name,
|
|
104
|
+
"description": tool.description,
|
|
105
|
+
"parameters": {
|
|
106
|
+
"properties": {
|
|
107
|
+
"__arg1": {"type": "string"},
|
|
108
|
+
},
|
|
109
|
+
"required": ["__arg1"],
|
|
110
|
+
"type_": TYPE_ENUM["object"],
|
|
111
|
+
},
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _convert_pydantic_to_genai_function(
|
|
116
|
+
pydantic_model: Type[BaseModel],
|
|
117
|
+
) -> Dict:
|
|
118
|
+
schema = dereference_refs(pydantic_model.schema())
|
|
119
|
+
schema.pop("definitions", None)
|
|
120
|
+
|
|
121
|
+
return {
|
|
122
|
+
"name": schema["title"],
|
|
123
|
+
"description": schema.get("description", ""),
|
|
124
|
+
"parameters": {
|
|
125
|
+
"properties": {
|
|
126
|
+
k: {
|
|
127
|
+
"type_": TYPE_ENUM[v["type"]],
|
|
128
|
+
"description": v.get("description"),
|
|
129
|
+
}
|
|
130
|
+
for k, v in schema["properties"].items()
|
|
131
|
+
},
|
|
132
|
+
"required": schema["required"],
|
|
133
|
+
"type_": TYPE_ENUM[schema["type"]],
|
|
134
|
+
},
|
|
135
|
+
}
|
{langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/chat_models.py
RENAMED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
+
import json
|
|
4
5
|
import logging
|
|
5
6
|
import os
|
|
6
7
|
from io import BytesIO
|
|
@@ -21,9 +22,12 @@ from typing import (
|
|
|
21
22
|
)
|
|
22
23
|
from urllib.parse import urlparse
|
|
23
24
|
|
|
25
|
+
import google.api_core
|
|
26
|
+
|
|
24
27
|
# TODO: remove ignore once the google package is published with types
|
|
25
28
|
import google.generativeai as genai # type: ignore[import]
|
|
26
29
|
import requests
|
|
30
|
+
from google.ai.generativelanguage_v1beta import FunctionCall
|
|
27
31
|
from langchain_core.callbacks.manager import (
|
|
28
32
|
AsyncCallbackManagerForLLMRun,
|
|
29
33
|
CallbackManagerForLLMRun,
|
|
@@ -35,12 +39,13 @@ from langchain_core.messages import (
|
|
|
35
39
|
BaseMessage,
|
|
36
40
|
ChatMessage,
|
|
37
41
|
ChatMessageChunk,
|
|
42
|
+
FunctionMessage,
|
|
38
43
|
HumanMessage,
|
|
39
44
|
HumanMessageChunk,
|
|
40
45
|
SystemMessage,
|
|
41
46
|
)
|
|
42
47
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
43
|
-
from langchain_core.pydantic_v1 import
|
|
48
|
+
from langchain_core.pydantic_v1 import SecretStr, root_validator
|
|
44
49
|
from langchain_core.utils import get_from_dict_or_env
|
|
45
50
|
from tenacity import (
|
|
46
51
|
before_sleep_log,
|
|
@@ -51,6 +56,10 @@ from tenacity import (
|
|
|
51
56
|
)
|
|
52
57
|
|
|
53
58
|
from langchain_google_genai._common import GoogleGenerativeAIError
|
|
59
|
+
from langchain_google_genai._function_utils import (
|
|
60
|
+
convert_to_genai_function_declarations,
|
|
61
|
+
)
|
|
62
|
+
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
|
|
54
63
|
|
|
55
64
|
IMAGE_TYPES: Tuple = ()
|
|
56
65
|
try:
|
|
@@ -87,8 +96,6 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
|
|
|
87
96
|
Callable[[Any], Any]: A retry decorator configured for handling specific
|
|
88
97
|
Google API exceptions.
|
|
89
98
|
"""
|
|
90
|
-
import google.api_core.exceptions
|
|
91
|
-
|
|
92
99
|
multiplier = 2
|
|
93
100
|
min_seconds = 1
|
|
94
101
|
max_seconds = 60
|
|
@@ -123,14 +130,22 @@ def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
|
|
|
123
130
|
Any: The result from the chat generation method.
|
|
124
131
|
"""
|
|
125
132
|
retry_decorator = _create_retry_decorator()
|
|
126
|
-
from google.api_core.exceptions import InvalidArgument # type: ignore
|
|
127
133
|
|
|
128
134
|
@retry_decorator
|
|
129
135
|
def _chat_with_retry(**kwargs: Any) -> Any:
|
|
130
136
|
try:
|
|
131
137
|
return generation_method(**kwargs)
|
|
132
|
-
|
|
133
|
-
|
|
138
|
+
# Do not retry for these errors.
|
|
139
|
+
except google.api_core.exceptions.FailedPrecondition as exc:
|
|
140
|
+
if "location is not supported" in exc.message:
|
|
141
|
+
error_msg = (
|
|
142
|
+
"Your location is not supported by google-generativeai "
|
|
143
|
+
"at the moment. Try to use ChatVertexAI LLM from "
|
|
144
|
+
"langchain_google_vertexai."
|
|
145
|
+
)
|
|
146
|
+
raise ValueError(error_msg)
|
|
147
|
+
|
|
148
|
+
except google.api_core.exceptions.InvalidArgument as e:
|
|
134
149
|
raise ChatGoogleGenerativeAIError(
|
|
135
150
|
f"Invalid argument provided to Gemini: {e}"
|
|
136
151
|
) from e
|
|
@@ -312,14 +327,20 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
|
|
|
312
327
|
continue
|
|
313
328
|
elif isinstance(message, AIMessage):
|
|
314
329
|
role = "model"
|
|
330
|
+
# TODO: Handle AImessage with function call
|
|
331
|
+
parts = _convert_to_parts(message.content)
|
|
315
332
|
elif isinstance(message, HumanMessage):
|
|
316
333
|
role = "user"
|
|
334
|
+
parts = _convert_to_parts(message.content)
|
|
335
|
+
elif isinstance(message, FunctionMessage):
|
|
336
|
+
role = "user"
|
|
337
|
+
# TODO: Handle FunctionMessage
|
|
338
|
+
parts = _convert_to_parts(message.content)
|
|
317
339
|
else:
|
|
318
340
|
raise ValueError(
|
|
319
341
|
f"Unexpected message with type {type(message)} at the position {i}."
|
|
320
342
|
)
|
|
321
343
|
|
|
322
|
-
parts = _convert_to_parts(message.content)
|
|
323
344
|
if raw_system_message:
|
|
324
345
|
if role == "model":
|
|
325
346
|
raise ValueError(
|
|
@@ -332,15 +353,36 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
|
|
|
332
353
|
return messages
|
|
333
354
|
|
|
334
355
|
|
|
335
|
-
def
|
|
356
|
+
def _retrieve_function_call_response(
|
|
357
|
+
parts: List[genai.types.PartType],
|
|
358
|
+
) -> Optional[Dict]:
|
|
359
|
+
for idx, part in enumerate(parts):
|
|
360
|
+
if part.function_call and part.function_call.name:
|
|
361
|
+
fc: FunctionCall = part.function_call
|
|
362
|
+
return {
|
|
363
|
+
"function_call": {
|
|
364
|
+
"name": fc.name,
|
|
365
|
+
"arguments": json.dumps(
|
|
366
|
+
dict(fc.args.items())
|
|
367
|
+
), # dump to match other function calling llms for now
|
|
368
|
+
}
|
|
369
|
+
}
|
|
370
|
+
return None
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def _parts_to_content(
|
|
374
|
+
parts: List[genai.types.PartType],
|
|
375
|
+
) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
|
|
336
376
|
"""Converts a list of Gemini API Part objects into a list of LangChain messages."""
|
|
377
|
+
function_call_resp = _retrieve_function_call_response(parts)
|
|
378
|
+
|
|
337
379
|
if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
|
|
338
380
|
# Simple text response. The typical response
|
|
339
|
-
return parts[0].text
|
|
381
|
+
return parts[0].text, function_call_resp
|
|
340
382
|
elif not parts:
|
|
341
383
|
logger.warning("Gemini produced an empty response.")
|
|
342
|
-
return ""
|
|
343
|
-
messages = []
|
|
384
|
+
return "", function_call_resp
|
|
385
|
+
messages: List[Union[Dict[Any, Any], str]] = []
|
|
344
386
|
for part in parts:
|
|
345
387
|
if part.text is not None:
|
|
346
388
|
messages.append(
|
|
@@ -352,7 +394,7 @@ def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], st
|
|
|
352
394
|
else:
|
|
353
395
|
# TODO: Handle inline_data if that's a thing?
|
|
354
396
|
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
|
|
355
|
-
return messages
|
|
397
|
+
return messages, function_call_resp
|
|
356
398
|
|
|
357
399
|
|
|
358
400
|
def _response_to_result(
|
|
@@ -379,16 +421,24 @@ def _response_to_result(
|
|
|
379
421
|
"model": ai_msg_t,
|
|
380
422
|
"user": human_msg_t,
|
|
381
423
|
}
|
|
424
|
+
|
|
382
425
|
for candidate in response.candidates:
|
|
383
426
|
content = candidate.content
|
|
384
|
-
parts_content = _parts_to_content(content.parts)
|
|
427
|
+
parts_content, additional_kwargs = _parts_to_content(content.parts)
|
|
385
428
|
if content.role not in role_map:
|
|
386
429
|
logger.warning(
|
|
387
430
|
f"Unrecognized role: {content.role}. Treating as a ChatMessage."
|
|
388
431
|
)
|
|
389
|
-
msg = chat_msg_t(
|
|
432
|
+
msg = chat_msg_t(
|
|
433
|
+
content=parts_content,
|
|
434
|
+
role=content.role,
|
|
435
|
+
additional_kwargs=additional_kwargs or {},
|
|
436
|
+
)
|
|
390
437
|
else:
|
|
391
|
-
msg = role_map[content.role](
|
|
438
|
+
msg = role_map[content.role](
|
|
439
|
+
content=parts_content,
|
|
440
|
+
additional_kwargs=additional_kwargs or {},
|
|
441
|
+
)
|
|
392
442
|
generation_info = {}
|
|
393
443
|
if candidate.finish_reason:
|
|
394
444
|
generation_info["finish_reason"] = candidate.finish_reason.name
|
|
@@ -409,7 +459,7 @@ def _response_to_result(
|
|
|
409
459
|
return ChatResult(generations=generations, llm_output=llm_output)
|
|
410
460
|
|
|
411
461
|
|
|
412
|
-
class ChatGoogleGenerativeAI(BaseChatModel):
|
|
462
|
+
class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
413
463
|
"""`Google Generative AI` Chat models API.
|
|
414
464
|
|
|
415
465
|
To use, you must have either:
|
|
@@ -427,40 +477,8 @@ class ChatGoogleGenerativeAI(BaseChatModel):
|
|
|
427
477
|
|
|
428
478
|
"""
|
|
429
479
|
|
|
430
|
-
model: str = Field(
|
|
431
|
-
...,
|
|
432
|
-
description="""The name of the model to use.
|
|
433
|
-
Supported examples:
|
|
434
|
-
- gemini-pro""",
|
|
435
|
-
)
|
|
436
|
-
max_output_tokens: int = Field(default=None, description="Max output tokens")
|
|
437
|
-
|
|
438
480
|
client: Any #: :meta private:
|
|
439
|
-
|
|
440
|
-
temperature: Optional[float] = None
|
|
441
|
-
"""Run inference with this temperature. Must by in the closed
|
|
442
|
-
interval [0.0, 1.0]."""
|
|
443
|
-
top_k: Optional[int] = None
|
|
444
|
-
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
|
445
|
-
Must be positive."""
|
|
446
|
-
top_p: Optional[int] = None
|
|
447
|
-
"""The maximum cumulative probability of tokens to consider when sampling.
|
|
448
|
-
|
|
449
|
-
The model uses combined Top-k and nucleus sampling.
|
|
450
|
-
|
|
451
|
-
Tokens are sorted based on their assigned probabilities so
|
|
452
|
-
that only the most likely tokens are considered. Top-k
|
|
453
|
-
sampling directly limits the maximum number of tokens to
|
|
454
|
-
consider, while Nucleus sampling limits number of tokens
|
|
455
|
-
based on the cumulative probability.
|
|
456
|
-
|
|
457
|
-
Note: The default value varies by model, see the
|
|
458
|
-
`Model.top_p` attribute of the `Model` returned the
|
|
459
|
-
`genai.get_model` function.
|
|
460
|
-
"""
|
|
461
|
-
n: int = Field(default=1, alias="candidate_count")
|
|
462
|
-
"""Number of chat completions to generate for each prompt. Note that the API may
|
|
463
|
-
not return the full n completions if duplicates are generated."""
|
|
481
|
+
|
|
464
482
|
convert_system_message_to_human: bool = False
|
|
465
483
|
"""Whether to merge any leading SystemMessage into the following HumanMessage.
|
|
466
484
|
|
|
@@ -478,22 +496,24 @@ Supported examples:
|
|
|
478
496
|
def _llm_type(self) -> str:
|
|
479
497
|
return "chat-google-generative-ai"
|
|
480
498
|
|
|
481
|
-
@property
|
|
482
|
-
def _is_geminiai(self) -> bool:
|
|
483
|
-
return self.model is not None and "gemini" in self.model
|
|
484
|
-
|
|
485
499
|
@classmethod
|
|
486
500
|
def is_lc_serializable(self) -> bool:
|
|
487
501
|
return True
|
|
488
502
|
|
|
489
503
|
@root_validator()
|
|
490
504
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
505
|
+
"""Validates params and passes them to google-generativeai package."""
|
|
491
506
|
google_api_key = get_from_dict_or_env(
|
|
492
507
|
values, "google_api_key", "GOOGLE_API_KEY"
|
|
493
508
|
)
|
|
494
509
|
if isinstance(google_api_key, SecretStr):
|
|
495
510
|
google_api_key = google_api_key.get_secret_value()
|
|
496
|
-
|
|
511
|
+
|
|
512
|
+
genai.configure(
|
|
513
|
+
api_key=google_api_key,
|
|
514
|
+
transport=values.get("transport"),
|
|
515
|
+
client_options=values.get("client_options"),
|
|
516
|
+
)
|
|
497
517
|
if (
|
|
498
518
|
values.get("temperature") is not None
|
|
499
519
|
and not 0 <= values["temperature"] <= 1
|
|
@@ -546,7 +566,11 @@ Supported examples:
|
|
|
546
566
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
547
567
|
**kwargs: Any,
|
|
548
568
|
) -> ChatResult:
|
|
549
|
-
params, chat, message = self._prepare_chat(
|
|
569
|
+
params, chat, message = self._prepare_chat(
|
|
570
|
+
messages,
|
|
571
|
+
stop=stop,
|
|
572
|
+
functions=kwargs.get("functions"),
|
|
573
|
+
)
|
|
550
574
|
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
|
551
575
|
content=message,
|
|
552
576
|
**params,
|
|
@@ -561,7 +585,11 @@ Supported examples:
|
|
|
561
585
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
562
586
|
**kwargs: Any,
|
|
563
587
|
) -> ChatResult:
|
|
564
|
-
params, chat, message = self._prepare_chat(
|
|
588
|
+
params, chat, message = self._prepare_chat(
|
|
589
|
+
messages,
|
|
590
|
+
stop=stop,
|
|
591
|
+
functions=kwargs.get("functions"),
|
|
592
|
+
)
|
|
565
593
|
response: genai.types.GenerateContentResponse = await _achat_with_retry(
|
|
566
594
|
content=message,
|
|
567
595
|
**params,
|
|
@@ -576,7 +604,11 @@ Supported examples:
|
|
|
576
604
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
577
605
|
**kwargs: Any,
|
|
578
606
|
) -> Iterator[ChatGenerationChunk]:
|
|
579
|
-
params, chat, message = self._prepare_chat(
|
|
607
|
+
params, chat, message = self._prepare_chat(
|
|
608
|
+
messages,
|
|
609
|
+
stop=stop,
|
|
610
|
+
functions=kwargs.get("functions"),
|
|
611
|
+
)
|
|
580
612
|
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
|
581
613
|
content=message,
|
|
582
614
|
**params,
|
|
@@ -603,7 +635,11 @@ Supported examples:
|
|
|
603
635
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
604
636
|
**kwargs: Any,
|
|
605
637
|
) -> AsyncIterator[ChatGenerationChunk]:
|
|
606
|
-
params, chat, message = self._prepare_chat(
|
|
638
|
+
params, chat, message = self._prepare_chat(
|
|
639
|
+
messages,
|
|
640
|
+
stop=stop,
|
|
641
|
+
functions=kwargs.get("functions"),
|
|
642
|
+
)
|
|
607
643
|
async for chunk in await _achat_with_retry(
|
|
608
644
|
content=message,
|
|
609
645
|
**params,
|
|
@@ -628,11 +664,37 @@ Supported examples:
|
|
|
628
664
|
stop: Optional[List[str]] = None,
|
|
629
665
|
**kwargs: Any,
|
|
630
666
|
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
|
|
667
|
+
client = self.client
|
|
668
|
+
functions = kwargs.pop("functions", None)
|
|
669
|
+
if functions:
|
|
670
|
+
tools = convert_to_genai_function_declarations(functions)
|
|
671
|
+
client = genai.GenerativeModel(model_name=self.model, tools=tools)
|
|
672
|
+
|
|
631
673
|
params = self._prepare_params(stop, **kwargs)
|
|
632
674
|
history = _parse_chat_history(
|
|
633
675
|
messages,
|
|
634
676
|
convert_system_message_to_human=self.convert_system_message_to_human,
|
|
635
677
|
)
|
|
636
678
|
message = history.pop()
|
|
637
|
-
chat =
|
|
679
|
+
chat = client.start_chat(history=history)
|
|
638
680
|
return params, chat, message
|
|
681
|
+
|
|
682
|
+
def get_num_tokens(self, text: str) -> int:
|
|
683
|
+
"""Get the number of tokens present in the text.
|
|
684
|
+
|
|
685
|
+
Useful for checking if an input will fit in a model's context window.
|
|
686
|
+
|
|
687
|
+
Args:
|
|
688
|
+
text: The string input to tokenize.
|
|
689
|
+
|
|
690
|
+
Returns:
|
|
691
|
+
The integer number of tokens in the text.
|
|
692
|
+
"""
|
|
693
|
+
if self._model_family == GoogleModelFamily.GEMINI:
|
|
694
|
+
result = self.client.count_tokens(text)
|
|
695
|
+
token_count = result.total_tokens
|
|
696
|
+
else:
|
|
697
|
+
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
|
698
|
+
token_count = result["token_count"]
|
|
699
|
+
|
|
700
|
+
return token_count
|
{langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/embeddings.py
RENAMED
|
@@ -43,16 +43,32 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
|
|
|
43
43
|
description="The Google API key to use. If not provided, "
|
|
44
44
|
"the GOOGLE_API_KEY environment variable will be used.",
|
|
45
45
|
)
|
|
46
|
+
client_options: Optional[Dict] = Field(
|
|
47
|
+
None,
|
|
48
|
+
description=(
|
|
49
|
+
"A dictionary of client options to pass to the Google API client, "
|
|
50
|
+
"such as `api_endpoint`."
|
|
51
|
+
),
|
|
52
|
+
)
|
|
53
|
+
transport: Optional[str] = Field(
|
|
54
|
+
None,
|
|
55
|
+
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
|
|
56
|
+
)
|
|
46
57
|
|
|
47
58
|
@root_validator()
|
|
48
59
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
49
|
-
"""Validates
|
|
60
|
+
"""Validates params and passes them to google-generativeai package."""
|
|
50
61
|
google_api_key = get_from_dict_or_env(
|
|
51
62
|
values, "google_api_key", "GOOGLE_API_KEY"
|
|
52
63
|
)
|
|
53
64
|
if isinstance(google_api_key, SecretStr):
|
|
54
65
|
google_api_key = google_api_key.get_secret_value()
|
|
55
|
-
|
|
66
|
+
|
|
67
|
+
genai.configure(
|
|
68
|
+
api_key=google_api_key,
|
|
69
|
+
transport=values.get("transport"),
|
|
70
|
+
client_options=values.get("client_options"),
|
|
71
|
+
)
|
|
56
72
|
return values
|
|
57
73
|
|
|
58
74
|
def _embed(
|
{langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/llms.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from enum import Enum, auto
|
|
3
4
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
|
4
5
|
|
|
5
6
|
import google.api_core
|
|
@@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validat
|
|
|
15
16
|
from langchain_core.utils import get_from_dict_or_env
|
|
16
17
|
|
|
17
18
|
|
|
19
|
+
class GoogleModelFamily(str, Enum):
|
|
20
|
+
GEMINI = auto()
|
|
21
|
+
PALM = auto()
|
|
22
|
+
|
|
23
|
+
@classmethod
|
|
24
|
+
def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
|
|
25
|
+
if "gemini" in value.lower():
|
|
26
|
+
return GoogleModelFamily.GEMINI
|
|
27
|
+
elif "text-bison" in value.lower():
|
|
28
|
+
return GoogleModelFamily.PALM
|
|
29
|
+
return None
|
|
30
|
+
|
|
31
|
+
|
|
18
32
|
def _create_retry_decorator(
|
|
19
33
|
llm: BaseLLM,
|
|
20
34
|
*,
|
|
@@ -56,21 +70,25 @@ def _completion_with_retry(
|
|
|
56
70
|
prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
|
|
57
71
|
) -> Any:
|
|
58
72
|
generation_config = kwargs.get("generation_config", {})
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
73
|
+
error_msg = (
|
|
74
|
+
"Your location is not supported by google-generativeai at the moment. "
|
|
75
|
+
"Try to use VertexAI LLM from langchain_google_vertexai"
|
|
76
|
+
)
|
|
77
|
+
try:
|
|
78
|
+
if is_gemini:
|
|
79
|
+
return llm.client.generate_content(
|
|
80
|
+
contents=prompt, stream=stream, generation_config=generation_config
|
|
81
|
+
)
|
|
82
|
+
return llm.client.generate_text(prompt=prompt, **kwargs)
|
|
83
|
+
except google.api_core.exceptions.FailedPrecondition as exc:
|
|
84
|
+
if "location is not supported" in exc.message:
|
|
85
|
+
raise ValueError(error_msg)
|
|
64
86
|
|
|
65
87
|
return _completion_with_retry(
|
|
66
88
|
prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
|
|
67
89
|
)
|
|
68
90
|
|
|
69
91
|
|
|
70
|
-
def _is_gemini_model(model_name: str) -> bool:
|
|
71
|
-
return "gemini" in model_name
|
|
72
|
-
|
|
73
|
-
|
|
74
92
|
def _strip_erroneous_leading_spaces(text: str) -> str:
|
|
75
93
|
"""Strip erroneous leading spaces from text.
|
|
76
94
|
|
|
@@ -84,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
|
|
|
84
102
|
return text
|
|
85
103
|
|
|
86
104
|
|
|
87
|
-
class
|
|
88
|
-
"""Google
|
|
105
|
+
class _BaseGoogleGenerativeAI(BaseModel):
|
|
106
|
+
"""Base class for Google Generative AI LLMs"""
|
|
89
107
|
|
|
90
|
-
Example:
|
|
91
|
-
.. code-block:: python
|
|
92
|
-
|
|
93
|
-
from langchain_google_genai import GoogleGenerativeAI
|
|
94
|
-
llm = GoogleGenerativeAI(model="gemini-pro")
|
|
95
|
-
"""
|
|
96
|
-
|
|
97
|
-
client: Any #: :meta private:
|
|
98
108
|
model: str = Field(
|
|
99
109
|
...,
|
|
100
110
|
description="""The name of the model to use.
|
|
@@ -121,19 +131,42 @@ Supported examples:
|
|
|
121
131
|
not return the full n completions if duplicates are generated."""
|
|
122
132
|
max_retries: int = 6
|
|
123
133
|
"""The maximum number of retries to make when generating."""
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
134
|
+
client_options: Optional[Dict] = Field(
|
|
135
|
+
None,
|
|
136
|
+
description=(
|
|
137
|
+
"A dictionary of client options to pass to the Google API client, "
|
|
138
|
+
"such as `api_endpoint`."
|
|
139
|
+
),
|
|
140
|
+
)
|
|
141
|
+
transport: Optional[str] = Field(
|
|
142
|
+
None,
|
|
143
|
+
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
|
|
144
|
+
)
|
|
129
145
|
|
|
130
146
|
@property
|
|
131
147
|
def lc_secrets(self) -> Dict[str, str]:
|
|
132
148
|
return {"google_api_key": "GOOGLE_API_KEY"}
|
|
133
149
|
|
|
150
|
+
@property
|
|
151
|
+
def _model_family(self) -> str:
|
|
152
|
+
return GoogleModelFamily(self.model)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
|
156
|
+
"""Google GenerativeAI models.
|
|
157
|
+
|
|
158
|
+
Example:
|
|
159
|
+
.. code-block:: python
|
|
160
|
+
|
|
161
|
+
from langchain_google_genai import GoogleGenerativeAI
|
|
162
|
+
llm = GoogleGenerativeAI(model="gemini-pro")
|
|
163
|
+
"""
|
|
164
|
+
|
|
165
|
+
client: Any #: :meta private:
|
|
166
|
+
|
|
134
167
|
@root_validator()
|
|
135
168
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
136
|
-
"""
|
|
169
|
+
"""Validates params and passes them to google-generativeai package."""
|
|
137
170
|
google_api_key = get_from_dict_or_env(
|
|
138
171
|
values, "google_api_key", "GOOGLE_API_KEY"
|
|
139
172
|
)
|
|
@@ -142,9 +175,13 @@ Supported examples:
|
|
|
142
175
|
if isinstance(google_api_key, SecretStr):
|
|
143
176
|
google_api_key = google_api_key.get_secret_value()
|
|
144
177
|
|
|
145
|
-
genai.configure(
|
|
178
|
+
genai.configure(
|
|
179
|
+
api_key=google_api_key,
|
|
180
|
+
transport=values.get("transport"),
|
|
181
|
+
client_options=values.get("client_options"),
|
|
182
|
+
)
|
|
146
183
|
|
|
147
|
-
if
|
|
184
|
+
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
|
|
148
185
|
values["client"] = genai.GenerativeModel(model_name=model_name)
|
|
149
186
|
else:
|
|
150
187
|
values["client"] = genai
|
|
@@ -180,7 +217,7 @@ Supported examples:
|
|
|
180
217
|
"candidate_count": self.n,
|
|
181
218
|
}
|
|
182
219
|
for prompt in prompts:
|
|
183
|
-
if self.
|
|
220
|
+
if self._model_family == GoogleModelFamily.GEMINI:
|
|
184
221
|
res = _completion_with_retry(
|
|
185
222
|
self,
|
|
186
223
|
prompt=prompt,
|
|
@@ -256,7 +293,11 @@ Supported examples:
|
|
|
256
293
|
Returns:
|
|
257
294
|
The integer number of tokens in the text.
|
|
258
295
|
"""
|
|
259
|
-
if self.
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
296
|
+
if self._model_family == GoogleModelFamily.GEMINI:
|
|
297
|
+
result = self.client.count_tokens(text)
|
|
298
|
+
token_count = result.total_tokens
|
|
299
|
+
else:
|
|
300
|
+
result = self.client.count_text_tokens(model=self.model, prompt=text)
|
|
301
|
+
token_count = result["token_count"]
|
|
302
|
+
|
|
303
|
+
return token_count
|
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "langchain-google-genai"
|
|
3
|
-
version = "0.0.
|
|
3
|
+
version = "0.0.7"
|
|
4
4
|
description = "An integration package connecting Google's genai package and LangChain"
|
|
5
5
|
authors = []
|
|
6
6
|
readme = "README.md"
|
|
7
|
-
repository = "https://github.com/langchain-ai/langchain
|
|
7
|
+
repository = "https://github.com/langchain-ai/langchain"
|
|
8
|
+
license = "MIT"
|
|
9
|
+
|
|
10
|
+
[tool.poetry.urls]
|
|
11
|
+
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-genai"
|
|
8
12
|
|
|
9
13
|
[tool.poetry.dependencies]
|
|
10
14
|
python = ">=3.9,<4.0"
|
|
File without changes
|
|
File without changes
|
{langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/__init__.py
RENAMED
|
File without changes
|
{langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/_common.py
RENAMED
|
File without changes
|
{langchain_google_genai-0.0.5 → langchain_google_genai-0.0.7}/langchain_google_genai/py.typed
RENAMED
|
File without changes
|