langchain-b12 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_b12/__init__.py +0 -0
- langchain_b12/genai/embeddings.py +99 -0
- langchain_b12/genai/genai.py +407 -0
- langchain_b12/genai/genai_utils.py +219 -0
- langchain_b12/py.typed +0 -0
- langchain_b12-0.1.0.dist-info/METADATA +39 -0
- langchain_b12-0.1.0.dist-info/RECORD +8 -0
- langchain_b12-0.1.0.dist-info/WHEEL +4 -0

langchain_b12/__init__.py ADDED
File without changes

langchain_b12/genai/embeddings.py ADDED
@@ -0,0 +1,99 @@
import os

from google.genai import Client
from google.oauth2 import service_account
from langchain_core.embeddings import Embeddings
from pydantic import BaseModel, ConfigDict, Field


class GenAIEmbeddings(Embeddings, BaseModel):
    """Embeddings implementation using `google-genai`."""

    model_name: str = Field(default="text-multilingual-embedding-002")
    client: Client = Field(
        default_factory=lambda: Client(
            credentials=service_account.Credentials.from_service_account_file(
                filename=os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        ),
        exclude=True,
    )
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of text strings using the Google GenAI API.

        Args:
            texts (list[str]): The text strings to embed.

        Returns:
            list[list[float]]: The embedding vectors.
        """
        embeddings = []
        for text in texts:
            response = self.client.models.embed_content(
                model=self.model_name,
                contents=[text],
            )
            assert (
                response.embeddings is not None
            ), "No embeddings found in the response."
            for embedding in response.embeddings:
                assert (
                    embedding.values is not None
                ), "No embedding values found in the response."
                embeddings.append(embedding.values)
        assert len(embeddings) == len(
            texts
        ), "The number of embeddings does not match the number of texts."
        return embeddings

    def embed_query(self, text: str) -> list[float]:
        """Embed a text string using the Google GenAI API.

        Args:
            text (str): The text to embed.

        Returns:
            list[float]: The embedding vector.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of text strings using the Google GenAI API asynchronously.

        Args:
            texts (list[str]): The text strings to embed.

        Returns:
            list[list[float]]: The embedding vectors.
        """
        embeddings = []
        response = await self.client.aio.models.embed_content(
            model=self.model_name,
            contents=texts,
        )
        assert response.embeddings is not None, "No embeddings found in the response."
        for embedding in response.embeddings:
            assert (
                embedding.values is not None
            ), "No embedding values found in the response."
            embeddings.append(embedding.values)
        assert len(embeddings) == len(
            texts
        ), "The number of embeddings does not match the number of texts."
        return embeddings

    async def aembed_query(self, text: str) -> list[float]:
        """Embed a text string using the Google GenAI API asynchronously.

        Args:
            text (str): The text to embed.

        Returns:
            list[float]: The embedding vector.
        """
        return (await self.aembed_documents([text]))[0]
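
A minimal, illustrative usage sketch of the `GenAIEmbeddings` class above (not part of the published wheel; it assumes `GOOGLE_APPLICATION_CREDENTIALS` points at a service-account key with Vertex AI access):

```python
# Illustrative sketch only, not shipped in langchain-b12.
from langchain_b12.genai.embeddings import GenAIEmbeddings

# Uses the default "text-multilingual-embedding-002" model; the google-genai
# Client is built from GOOGLE_APPLICATION_CREDENTIALS by the default_factory.
embedder = GenAIEmbeddings()

# embed_query delegates to embed_documents and returns a single vector.
query_vector = embedder.embed_query("What does langchain-b12 provide?")

# embed_documents returns one vector per input text, in the same order.
doc_vectors = embedder.embed_documents(["first document", "second document"])
assert len(doc_vectors) == 2
```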

langchain_b12/genai/genai.py ADDED
@@ -0,0 +1,407 @@
import logging
import os
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from operator import itemgetter
from typing import Any, Literal, cast

from google import genai
from google.genai import types
from google.oauth2 import service_account
from langchain_b12.genai.genai_utils import (
    convert_messages_to_contents,
    parse_response_candidate,
)
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LangSmithParams, LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import (
    convert_to_openai_tool,
)
from pydantic import BaseModel, ConfigDict, Field

logger = logging.getLogger(__name__)


class ChatGenAI(BaseChatModel):
    """Implementation of BaseChatModel using `google-genai`"""

    client: genai.Client = Field(
        default_factory=lambda: genai.Client(
            vertexai=True,
            credentials=service_account.Credentials.from_service_account_file(
                filename=os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            ),
        ),
        exclude=True,
    )
    model_name: str = Field(default="gemini-2.0-flash", alias="model")
    "Underlying model name."
    stop: list[str] | None = Field(default=None, alias="stop_sequences")
    "Optional list of stop words to use when generating."
    temperature: float | None = None
    "Sampling temperature, it controls the degree of randomness in token selection."
    max_output_tokens: int | None = Field(default=None, alias="max_tokens")
    "Token limit determines the maximum amount of text output from one prompt."
    top_p: float | None = None
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: int | None = None
    "How the model selects tokens for output, the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."
    n: int = 1
    """How many completions to generate for each prompt."""
    seed: int | None = None
    """Random seed for the generation."""
    safety_settings: list[types.SafetySetting] | None = None
    """The default safety settings to use for all generations.

    For example:

        from langchain_google_vertexai import HarmBlockThreshold, HarmCategory

        safety_settings = {
            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        }
    """  # noqa: E501

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @property
    def _llm_type(self) -> str:
        return "vertexai"

    @property
    def _default_params(self) -> dict[str, Any]:
        return {
            "temperature": self.temperature,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
            "seed": self.seed,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Gets the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object."""
        return ["langchain_b12", "genai", "genai"]

    def _get_ls_params(
        self, stop: list[str] | None = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = {**self._default_params, **kwargs}
        ls_params = LangSmithParams(
            ls_provider="google_vertexai",
            ls_model_name=self.model_name,
            ls_model_type="chat",
            ls_temperature=params.get("temperature", self.temperature),
        )
        if ls_max_tokens := params.get("max_output_tokens", self.max_output_tokens):
            ls_params["ls_max_tokens"] = ls_max_tokens
        if ls_stop := stop or params.get("stop", None) or self.stop:
            ls_params["ls_stop"] = ls_stop
        return ls_params

    def _prepare_request(
        self, messages: list[BaseMessage]
    ) -> tuple[str | None, types.ContentListUnion]:
        contents = convert_messages_to_contents(messages)
        if isinstance(messages[-1], SystemMessage):
            system_instruction = messages[-1].content
            assert isinstance(
                system_instruction, str
            ), "System message content must be a string"
        else:
            system_instruction = None
        return system_instruction, cast(types.ContentListUnion, contents)

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text."""
        contents = convert_messages_to_contents([HumanMessage(content=text)])
        response = self.client.models.count_tokens(
            model=self.model_name,
            contents=cast(types.ContentListUnion, contents),
        )
        return response.total_tokens or 0

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        stream_iter = self._stream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return generate_from_stream(stream_iter)

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        stream_iter = self._astream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await agenerate_from_stream(stream_iter)

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        system_message, contents = self._prepare_request(messages=messages)
        response_iter = self.client.models.generate_content_stream(
            model=self.model_name,
            contents=contents,
            config=types.GenerateContentConfig(
                system_instruction=system_message,
                temperature=self.temperature,
                top_k=self.top_k,
                top_p=self.top_p,
                max_output_tokens=self.max_output_tokens,
                candidate_count=self.n,
                stop_sequences=stop or self.stop,
                safety_settings=self.safety_settings,
                **kwargs,
            ),
        )
        total_lc_usage = None
        for response_chunk in response_iter:
            chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
                response_chunk, prev_total_usage=total_lc_usage
            )
            if run_manager and isinstance(chunk.message.content, str):
                run_manager.on_llm_new_token(chunk.message.content)
            yield chunk

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        system_message, contents = self._prepare_request(messages=messages)
        response_iter = self.client.aio.models.generate_content_stream(
            model=self.model_name,
            contents=contents,
            config=types.GenerateContentConfig(
                system_instruction=system_message,
                temperature=self.temperature,
                top_k=self.top_k,
                top_p=self.top_p,
                max_output_tokens=self.max_output_tokens,
                candidate_count=self.n,
                stop_sequences=stop or self.stop,
                safety_settings=self.safety_settings,
                **kwargs,
            ),
        )
        total_lc_usage = None
        async for response_chunk in await response_iter:
            chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
                response_chunk, prev_total_usage=total_lc_usage
            )
            if run_manager and isinstance(chunk.message.content, str):
                await run_manager.on_llm_new_token(chunk.message.content)
            yield chunk

    def with_structured_output(
        self,
        schema: dict | type,
        *,
        include_raw: bool = False,
        method: Literal["json_mode", "function_calling"] = "json_mode",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
        assert isinstance(schema, type) and issubclass(
            schema, BaseModel
        ), "Structured output is only supported for Pydantic models."
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")

        parser: OutputParserLike

        if method == "json_mode":
            parser = PydanticOutputParser(pydantic_object=schema)

            llm = self.bind(
                response_mime_type="application/json",
                response_schema=schema,
                ls_structured_output_format={
                    "kwargs": {"method": method},
                    "schema": schema,
                },
            )
        elif method == "function_calling":
            parser = PydanticToolsParser(tools=[schema], first_tool_only=True)
            tool_choice = schema.__name__
            llm = self.bind_tools(
                [schema],
                tool_choice=tool_choice,
                ls_structured_output_format={
                    "kwargs": {"method": "function_calling"},
                    "schema": convert_to_openai_tool(schema),
                },
            )
        else:
            raise ValueError("method must be either 'json_mode' or 'function_calling'.")

        if include_raw:
            parser_with_fallback = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
            ).with_fallbacks(
                [RunnablePassthrough.assign(parsed=lambda _: None)],
                exception_key="parsing_error",
            )
            return {"raw": llm} | parser_with_fallback
        else:
            return llm | parser

    def bind_tools(
        self,
        tools: Sequence[
            dict[str, Any] | type | Callable[..., Any] | BaseTool | types.Tool
        ],
        *,
        tool_choice: str | None = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with Vertex tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = []
        for tool in tools:
            if not isinstance(tool, types.Tool):
                openai_tool = convert_to_openai_tool(tool)
                if openai_tool["type"] != "function":
                    raise ValueError(
                        f"Tool {tool} is not a function tool. "
                        f"It is a {openai_tool['type']}. "
                        "Only function tools are supported."
                    )
                function = openai_tool["function"]
                formatted_tools.append(
                    types.Tool(
                        function_declarations=[types.FunctionDeclaration(**function)],
                    )
                )
            else:
                formatted_tools.append(tool)
        if tool_choice:
            kwargs["tool_config"] = types.FunctionCallingConfig(
                mode=types.FunctionCallingConfigMode.ANY,
                allowed_function_names=[tool_choice],
            )
        return self.bind(tools=formatted_tools, **kwargs)

    def _gemini_chunk_to_generation_chunk(
        self,
        response_chunk: types.GenerateContentResponse,
        prev_total_usage: UsageMetadata | None = None,
    ) -> tuple[ChatGenerationChunk, UsageMetadata | None]:
        def _parse_usage_metadata(
            usage_metadata: types.GenerateContentResponseUsageMetadata,
        ) -> UsageMetadata:
            return UsageMetadata(
                input_tokens=usage_metadata.prompt_token_count or 0,
                output_tokens=usage_metadata.candidates_token_count or 0,
                total_tokens=usage_metadata.total_token_count or 0,
            )

        total_lc_usage: UsageMetadata | None = (
            _parse_usage_metadata(response_chunk.usage_metadata)
            if response_chunk.usage_metadata
            else None
        )

        if total_lc_usage and prev_total_usage:
            lc_usage: UsageMetadata | None = UsageMetadata(
                input_tokens=total_lc_usage["input_tokens"]
                - prev_total_usage["input_tokens"],
                output_tokens=total_lc_usage["output_tokens"]
                - prev_total_usage["output_tokens"],
                total_tokens=total_lc_usage["total_tokens"]
                - prev_total_usage["total_tokens"],
            )
        else:
            lc_usage = total_lc_usage
        if not response_chunk.candidates:
            message = AIMessageChunk(content="")
            if lc_usage:
                message.usage_metadata = lc_usage
            generation_info = {}
        else:
            top_candidate = response_chunk.candidates[0]
            generation_info = {
                "finish_reason": top_candidate.finish_reason,
                "finish_message": top_candidate.finish_message,
            }
            message = parse_response_candidate(top_candidate)
            if lc_usage:
                message.usage_metadata = lc_usage
            # add model name if final chunk
            if top_candidate.finish_reason is not None:
                message.response_metadata["model_name"] = self.model_name

        return (
            ChatGenerationChunk(
                message=message,
                generation_info=generation_info,
            ),
            total_lc_usage,
        )
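
An illustrative sketch of how the `ChatGenAI` model above might be driven for plain chat, tool calling, and structured output (not part of the wheel; the `GetWeather` schema and prompts are made up for this example):

```python
# Illustrative sketch only, not shipped in langchain-b12.
from langchain_b12.genai.genai import ChatGenAI
from langchain_core.messages import HumanMessage
from pydantic import BaseModel


class GetWeather(BaseModel):
    """Look up the current weather for a city."""

    # Hypothetical tool/output schema used only for this example.
    city: str


llm = ChatGenAI(model="gemini-2.0-flash", temperature=0.0)

# Plain generation; _generate/_agenerate are routed through the streaming path.
reply = llm.invoke([HumanMessage(content="Say hello in one word.")])
print(reply.content)

# Tool calling: non-genai tools are converted to FunctionDeclarations, and
# tool_choice forces the named function via FunctionCallingConfigMode.ANY.
llm_with_tools = llm.bind_tools([GetWeather], tool_choice="GetWeather")
msg = llm_with_tools.invoke([HumanMessage(content="What's the weather in Ghent?")])
print(msg.tool_calls)

# Structured output (default method="json_mode"): the response is requested as
# JSON matching the schema and parsed back into the pydantic model.
structured_llm = llm.with_structured_output(GetWeather)
print(structured_llm.invoke("Weather request for Ghent, please."))
```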

langchain_b12/genai/genai_utils.py ADDED
@@ -0,0 +1,219 @@
import base64
import json
import uuid
from collections.abc import Sequence
from typing import Any, cast

from google.genai import types
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.messages.tool import tool_call_chunk


def multi_content_to_part(
    contents: Sequence[dict[str, str | dict[str, str]] | str],
) -> list[types.Part]:
    """Convert sequence content to a Part object.

    Args:
        contents: A sequence of dictionaries representing content. Examples:
            [
                {
                    "type": "text",
                    "text": "This is a text message"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{encoded_artifact}"
                    },
                },
                {
                    "type": "file",
                    "file": {
                        "uri": f"gs://{bucket_name}/{file_name}",
                        "mime_type": mime_type,
                    }
                }
            ]
    """
    parts = []
    for content in contents:
        assert isinstance(content, dict), "Expected dict content"
        assert "type" in content, "Received dict content without type"
        if content["type"] == "text":
            assert "text" in content, "Expected 'text' in content"
            if content["text"]:
                assert isinstance(content["text"], str), "Expected str content"
                parts.append(types.Part(text=content["text"]))
        elif content["type"] == "image_url":
            assert isinstance(content["image_url"], dict), "Expected dict image_url"
            assert "url" in content["image_url"], "Expected 'url' in content"
            split_url: tuple[str, str] = content["image_url"]["url"].split(",", 1)  # type: ignore
            header, encoded_data = split_url
            mime_type = header.split(":", 1)[1].split(";", 1)[0]
            data = base64.b64decode(encoded_data)
            parts.append(types.Part.from_bytes(data=data, mime_type=mime_type))
        elif content["type"] == "file":
            assert "file" in content, "Expected 'file' in content"
            file = content["file"]
            assert isinstance(file, dict), "Expected dict file"
            assert "uri" in file, "Expected 'uri' in content['file']"
            assert "mime_type" in file, "Expected 'mime_type' in content['file']"
            parts.append(
                types.Part.from_uri(file_uri=file["uri"], mime_type=file["mime_type"])
            )
        else:
            raise ValueError(f"Unknown content type: {content['type']}")
    return parts


def convert_base_message_to_parts(
    message: BaseMessage,
) -> list[types.Part]:
    """Convert a LangChain BaseMessage to Google GenAI Content object."""
    parts = []
    if isinstance(message.content, str):
        if message.content:
            parts.append(types.Part(text=message.content))
    elif isinstance(message.content, list):
        parts.extend(multi_content_to_part(message.content))
    else:
        raise ValueError(
            "Received unexpected content type, "
            f"expected str or list, but got {type(message.content)}"
        )
    return parts


def convert_messages_to_contents(
    messages: Sequence[BaseMessage],
) -> list[types.Content]:
    """Convert a sequence of LangChain messages to Google GenAI Content objects.

    Args:
        messages: A sequence of LangChain BaseMessage objects

    Returns:
        A list of Google GenAI Content objects
    """
    contents = []

    for message in messages:
        if isinstance(message, HumanMessage):
            parts = convert_base_message_to_parts(message)
            contents.append(types.UserContent(parts=parts))
        elif isinstance(message, AIMessage):
            text_parts = convert_base_message_to_parts(message)
            function_parts = []
            if message.tool_calls:
                # Example of tool_call
                # tool_call = {
                #     "name": "foo",
                #     "args": {"a": 1},
                #     "id": "123"
                # }
                for tool_call in message.tool_calls:
                    tool_id = tool_call["id"]
                    assert tool_id, "Tool call ID is required"
                    function_parts.append(
                        types.Part(
                            function_call=types.FunctionCall(
                                name=tool_call["name"],
                                args=tool_call["args"],
                                id=tool_id,
                            ),
                        )
                    )

            contents.append(
                types.ModelContent(
                    parts=[*text_parts, *function_parts],
                )
            )
        elif isinstance(message, ToolMessage):
            # Note: We tried combining function_call and function_response into one
            # part, but that throws a 4xx server error.
            assert isinstance(message.content, str), "Expected str content"
            assert message.name, "Tool name is required"
            tool_part = types.Part(
                function_response=types.FunctionResponse(
                    id=message.tool_call_id,
                    name=message.name,
                    response={"output": message.content},
                ),
            )

            # Ensure that all function_responses are in a single content
            last_content = contents[-1]
            last_content_part = last_content.parts[-1]
            if last_content_part.function_response:
                # Merge with the last content
                last_content.parts.append(tool_part)
            else:
                # Create a new content
                contents.append(types.UserContent(parts=[tool_part]))
        elif isinstance(message, SystemMessage):
            # There is no genai.types equivalent for SystemMessage
            pass
        else:
            raise ValueError(f"Invalid message type: {type(message)}")

    return contents


def parse_response_candidate(response_candidate: types.Candidate) -> AIMessageChunk:
    content: None | str | list[str] = None
    additional_kwargs = {}
    tool_call_chunks = []

    assert response_candidate.content, "Response candidate content is None"
    for part in response_candidate.content.parts or []:
        try:
            text: str | None = part.text
        except AttributeError:
            text = None

        if text:
            if not content:
                content = text
            elif isinstance(content, str):
                content = [content, text]
            elif isinstance(content, list):
                content.append(text)
            else:
                raise ValueError("Unexpected content type")

        if part.function_call:
            # For backward compatibility we store a function call in additional_kwargs,
            # but in general the full set of function calls is stored in tool_calls.
            function_call = {"name": part.function_call.name}
            # dump to match other function calling llm for now
            function_call_args_dict = part.function_call.args
            assert function_call_args_dict is not None
            function_call["arguments"] = json.dumps(function_call_args_dict)
            additional_kwargs["function_call"] = function_call

            index = function_call.get("index")
            tool_call_chunks.append(
                tool_call_chunk(
                    name=function_call.get("name"),
                    args=function_call.get("arguments"),
                    id=function_call.get("id", str(uuid.uuid4())),
                    index=int(index) if index else None,
                )
            )
    if content is None:
        content = ""

    return AIMessageChunk(
        content=cast(str | list[str | dict[Any, Any]], content),
        additional_kwargs=additional_kwargs,
        tool_call_chunks=tool_call_chunks,
    )
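
A small, illustrative example of what the helpers above produce for a short tool-calling conversation (not part of the wheel; the messages are made up):

```python
# Illustrative sketch only, not shipped in langchain-b12.
from langchain_b12.genai.genai_utils import convert_messages_to_contents
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

messages = [
    # SystemMessage is skipped: there is no genai.types equivalent.
    SystemMessage(content="You are a helpful assistant."),
    # HumanMessage becomes a types.UserContent with a text part.
    HumanMessage(content="What is 2 + 2?"),
    # AIMessage with tool_calls becomes a types.ModelContent with a function_call part.
    AIMessage(
        content="",
        tool_calls=[{"name": "add", "args": {"a": 2, "b": 2}, "id": "call-1"}],
    ),
    # ToolMessage becomes a function_response part wrapped in a types.UserContent.
    ToolMessage(content="4", name="add", tool_call_id="call-1"),
]

contents = convert_messages_to_contents(messages)
assert len(contents) == 3  # user, model (function call), user (function response)
```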

langchain_b12/py.typed ADDED
File without changes

langchain_b12-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,39 @@
Metadata-Version: 2.4
Name: langchain-b12
Version: 0.1.0
Summary: A reusable collection of tools and implementations for Langchain
Author-email: Vincent Min <vincent.min@b12-consulting.com>
Requires-Python: >=3.11
Requires-Dist: langchain-core>=0.3.60
Description-Content-Type: text/markdown

# Langchain B12

This repo hosts a collection of custom LangChain components.

- ChatGenAI converts a sequence of LangChain messages to Google GenAI Content objects and vice versa.

## Installation

To install this package, run
```bash
pip install langchain-b12
```

Some components rely on additional packages that may be installed as extras.
For example, to use the Google chat model `ChatGenAI`, you can run
```bash
pip install langchain-b12[google]
```

## Components

The repo contains these components:

- `ChatGenAI`: An implementation of `BaseChatModel` that uses the `google-genai` package. Note that `langchain-google-genai` and `langchain-google-vertexai` exist, but neither uses the latest and recommended `google-genai` package.
- `GenAIEmbeddings`: An implementation of `Embeddings` that uses the `google-genai` package.
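
As a quick, illustrative sketch (assuming the Google extra is installed and `GOOGLE_APPLICATION_CREDENTIALS` is configured; this snippet is an editorial addition, not an excerpt from the package docs):

```python
from langchain_b12.genai.embeddings import GenAIEmbeddings
from langchain_b12.genai.genai import ChatGenAI

llm = ChatGenAI(model="gemini-2.0-flash")
print(llm.invoke("Write a one-line greeting.").content)

embeddings = GenAIEmbeddings()
print(len(embeddings.embed_query("hello")))
```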

## Comments

This repo exists for easy reuse and extension of custom LangChain components.
When appropriate, we will create PRs to merge these components directly into LangChain.

langchain_b12-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
langchain_b12/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
langchain_b12/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
langchain_b12/genai/embeddings.py,sha256=od2bVIgt7v9aNAHG0PVypVF1H_XgHto2nTd8vwfvyN8,3355
langchain_b12/genai/genai.py,sha256=o2KLo2QlLDy0hijt6T2HaRmjSO60SV62dR_Cso6Ad-8,15796
langchain_b12/genai/genai_utils.py,sha256=2ojnSumovXyx3CxM7JwzyOkEdD7mQxfLnk50-usPbw8,8221
langchain_b12-0.1.0.dist-info/METADATA,sha256=72Ct-UG2KHnPQuyUltOQ_HGpvLbDrXglmPn-eY9blnc,1307
langchain_b12-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
langchain_b12-0.1.0.dist-info/RECORD,,