langchain_b12-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langchain_b12/__init__.py ADDED
File without changes
langchain_b12/genai/embeddings.py ADDED
@@ -0,0 +1,99 @@
+ import os
+
+ from google.genai import Client
+ from google.oauth2 import service_account
+ from langchain_core.embeddings import Embeddings
+ from pydantic import BaseModel, ConfigDict, Field
+
+
+ class GenAIEmbeddings(Embeddings, BaseModel):
+     """Embeddings implementation using `google-genai`."""
+
+     model_name: str = Field(default="text-multilingual-embedding-002")
+     client: Client = Field(
+         default_factory=lambda: Client(
+             credentials=service_account.Credentials.from_service_account_file(
+                 filename=os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
+                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
+             )
+         ),
+         exclude=True,
+     )
+     model_config = ConfigDict(
+         arbitrary_types_allowed=True,
+     )
+
+     def embed_documents(self, texts: list[str]) -> list[list[float]]:
+         """Embed a list of text strings using the Google GenAI API.
+
+         Args:
+             texts (list[str]): The text strings to embed.
+
+         Returns:
+             list[list[float]]: The embedding vectors.
+         """
+         embeddings = []
+         for text in texts:
+             response = self.client.models.embed_content(
+                 model=self.model_name,
+                 contents=[text],
+             )
+             assert (
+                 response.embeddings is not None
+             ), "No embeddings found in the response."
+             for embedding in response.embeddings:
+                 assert (
+                     embedding.values is not None
+                 ), "No embedding values found in the response."
+                 embeddings.append(embedding.values)
+         assert len(embeddings) == len(
+             texts
+         ), "The number of embeddings does not match the number of texts."
+         return embeddings
+
+     def embed_query(self, text: str) -> list[float]:
+         """Embed a text string using the Google GenAI API.
+
+         Args:
+             text (str): The text to embed.
+
+         Returns:
+             list[float]: The embedding vector.
+         """
+         return self.embed_documents([text])[0]
+
+     async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
+         """Embed a list of text strings using the Google GenAI API asynchronously.
+
+         Args:
+             texts (list[str]): The text strings to embed.
+
+         Returns:
+             list[list[float]]: The embedding vectors.
+         """
+         embeddings = []
+         response = await self.client.aio.models.embed_content(
+             model=self.model_name,
+             contents=texts,
+         )
+         assert response.embeddings is not None, "No embeddings found in the response."
+         for embedding in response.embeddings:
+             assert (
+                 embedding.values is not None
+             ), "No embedding values found in the response."
+             embeddings.append(embedding.values)
+         assert len(embeddings) == len(
+             texts
+         ), "The number of embeddings does not match the number of texts."
+         return embeddings
+
+     async def aembed_query(self, text: str) -> list[float]:
+         """Embed a text string using the Google GenAI API asynchronously.
+
+         Args:
+             text (str): The text to embed.
+
+         Returns:
+             list[float]: The embedding vector.
+         """
+         return (await self.aembed_documents([text]))[0]
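
For orientation only (this sketch is not part of the packaged file above): a minimal way to exercise `GenAIEmbeddings`, assuming `GOOGLE_APPLICATION_CREDENTIALS` points at a service-account key with Vertex AI access, since the default client factory reads that variable.

```python
# Minimal usage sketch, not part of the package.
# Assumes GOOGLE_APPLICATION_CREDENTIALS points at a valid service-account JSON key.
from langchain_b12.genai.embeddings import GenAIEmbeddings

embedder = GenAIEmbeddings()  # defaults to "text-multilingual-embedding-002"

# embed_query delegates to embed_documents and returns the first vector.
query_vector = embedder.embed_query("What is the capital of France?")
print(len(query_vector))

# The synchronous path issues one embed_content call per text.
doc_vectors = embedder.embed_documents(["Paris is in France.", "Berlin is in Germany."])
print(len(doc_vectors), len(doc_vectors[0]))
```

Note the asymmetry visible in the code above: `embed_documents` loops with one request per text, while `aembed_documents` sends all texts in a single request.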
langchain_b12/genai/genai.py ADDED
@@ -0,0 +1,407 @@
+ import logging
+ import os
+ from collections.abc import AsyncIterator, Callable, Iterator, Sequence
+ from operator import itemgetter
+ from typing import Any, Literal, cast
+
+ from google import genai
+ from google.genai import types
+ from google.oauth2 import service_account
+ from langchain_b12.genai.genai_utils import (
+     convert_messages_to_contents,
+     parse_response_candidate,
+ )
+ from langchain_core.callbacks import (
+     AsyncCallbackManagerForLLMRun,
+     CallbackManagerForLLMRun,
+ )
+ from langchain_core.language_models import LangSmithParams, LanguageModelInput
+ from langchain_core.language_models.chat_models import (
+     BaseChatModel,
+     agenerate_from_stream,
+     generate_from_stream,
+ )
+ from langchain_core.messages import (
+     AIMessageChunk,
+     BaseMessage,
+     HumanMessage,
+     SystemMessage,
+ )
+ from langchain_core.messages.ai import UsageMetadata
+ from langchain_core.output_parsers import PydanticOutputParser
+ from langchain_core.output_parsers.base import OutputParserLike
+ from langchain_core.output_parsers.openai_tools import (
+     PydanticToolsParser,
+ )
+ from langchain_core.outputs import ChatGenerationChunk, ChatResult
+ from langchain_core.runnables import Runnable, RunnablePassthrough
+ from langchain_core.tools import BaseTool
+ from langchain_core.utils.function_calling import (
+     convert_to_openai_tool,
+ )
+ from pydantic import BaseModel, ConfigDict, Field
+
+ logger = logging.getLogger(__name__)
+
+
+ class ChatGenAI(BaseChatModel):
+     """Implementation of BaseChatModel using `google-genai`"""
+
+     client: genai.Client = Field(
+         default_factory=lambda: genai.Client(
+             vertexai=True,
+             credentials=service_account.Credentials.from_service_account_file(
+                 filename=os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
+                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
+             ),
+         ),
+         exclude=True,
+     )
+     model_name: str = Field(default="gemini-2.0-flash", alias="model")
+     "Underlying model name."
+     stop: list[str] | None = Field(default=None, alias="stop_sequences")
+     "Optional list of stop words to use when generating."
+     temperature: float | None = None
+     "Sampling temperature, it controls the degree of randomness in token selection."
+     max_output_tokens: int | None = Field(default=None, alias="max_tokens")
+     "Token limit determines the maximum amount of text output from one prompt."
+     top_p: float | None = None
+     "Tokens are selected from most probable to least until the sum of their "
+     "probabilities equals the top-p value. Top-p is ignored for Codey models."
+     top_k: int | None = None
+     "How the model selects tokens for output, the next token is selected from "
+     "among the top-k most probable tokens. Top-k is ignored for Codey models."
+     n: int = 1
+     """How many completions to generate for each prompt."""
+     seed: int | None = None
+     """Random seed for the generation."""
+     safety_settings: list[types.SafetySetting] | None = None
+     """The default safety settings to use for all generations.
+
+     For example:
+
+         from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
+
+         safety_settings = {
+             HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+             HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+         }
+     """  # noqa: E501
+
+     model_config = ConfigDict(
+         arbitrary_types_allowed=True,
+     )
+
+     @property
+     def _llm_type(self) -> str:
+         return "vertexai"
+
+     @property
+     def _default_params(self) -> dict[str, Any]:
+         return {
+             "temperature": self.temperature,
+             "max_output_tokens": self.max_output_tokens,
+             "candidate_count": self.n,
+             "seed": self.seed,
+             "top_k": self.top_k,
+             "top_p": self.top_p,
+         }
+
+     @property
+     def _identifying_params(self) -> dict[str, Any]:
+         """Gets the identifying parameters."""
+         return {**{"model_name": self.model_name}, **self._default_params}
+
+     @classmethod
+     def is_lc_serializable(cls) -> bool:
+         return True
+
+     @classmethod
+     def get_lc_namespace(cls) -> list[str]:
+         """Get the namespace of the langchain object."""
+         return ["langchain_b12", "genai", "genai"]
+
+     def _get_ls_params(
+         self, stop: list[str] | None = None, **kwargs: Any
+     ) -> LangSmithParams:
+         """Get standard params for tracing."""
+         params = {**self._default_params, **kwargs}
+         ls_params = LangSmithParams(
+             ls_provider="google_vertexai",
+             ls_model_name=self.model_name,
+             ls_model_type="chat",
+             ls_temperature=params.get("temperature", self.temperature),
+         )
+         if ls_max_tokens := params.get("max_output_tokens", self.max_output_tokens):
+             ls_params["ls_max_tokens"] = ls_max_tokens
+         if ls_stop := stop or params.get("stop", None) or self.stop:
+             ls_params["ls_stop"] = ls_stop
+         return ls_params
+
+     def _prepare_request(
+         self, messages: list[BaseMessage]
+     ) -> tuple[str | None, types.ContentListUnion]:
+         contents = convert_messages_to_contents(messages)
+         if isinstance(messages[-1], SystemMessage):
+             system_instruction = messages[-1].content
+             assert isinstance(
+                 system_instruction, str
+             ), "System message content must be a string"
+         else:
+             system_instruction = None
+         return system_instruction, cast(types.ContentListUnion, contents)
+
+     def get_num_tokens(self, text: str) -> int:
+         """Get the number of tokens present in the text."""
+         contents = convert_messages_to_contents([HumanMessage(content=text)])
+         response = self.client.models.count_tokens(
+             model=self.model_name,
+             contents=cast(types.ContentListUnion, contents),
+         )
+         return response.total_tokens or 0
+
+     def _generate(
+         self,
+         messages: list[BaseMessage],
+         stop: list[str] | None = None,
+         run_manager: CallbackManagerForLLMRun | None = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         stream_iter = self._stream(
+             messages, stop=stop, run_manager=run_manager, **kwargs
+         )
+         return generate_from_stream(stream_iter)
+
+     async def _agenerate(
+         self,
+         messages: list[BaseMessage],
+         stop: list[str] | None = None,
+         run_manager: AsyncCallbackManagerForLLMRun | None = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         stream_iter = self._astream(
+             messages, stop=stop, run_manager=run_manager, **kwargs
+         )
+         return await agenerate_from_stream(stream_iter)
+
+     def _stream(
+         self,
+         messages: list[BaseMessage],
+         stop: list[str] | None = None,
+         run_manager: CallbackManagerForLLMRun | None = None,
+         **kwargs: Any,
+     ) -> Iterator[ChatGenerationChunk]:
+         system_message, contents = self._prepare_request(messages=messages)
+         response_iter = self.client.models.generate_content_stream(
+             model=self.model_name,
+             contents=contents,
+             config=types.GenerateContentConfig(
+                 system_instruction=system_message,
+                 temperature=self.temperature,
+                 top_k=self.top_k,
+                 top_p=self.top_p,
+                 max_output_tokens=self.max_output_tokens,
+                 candidate_count=self.n,
+                 stop_sequences=stop or self.stop,
+                 safety_settings=self.safety_settings,
+                 **kwargs,
+             ),
+         )
+         total_lc_usage = None
+         for response_chunk in response_iter:
+             chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
+                 response_chunk, prev_total_usage=total_lc_usage
+             )
+             if run_manager and isinstance(chunk.message.content, str):
+                 run_manager.on_llm_new_token(chunk.message.content)
+             yield chunk
+
+     async def _astream(
+         self,
+         messages: list[BaseMessage],
+         stop: list[str] | None = None,
+         run_manager: AsyncCallbackManagerForLLMRun | None = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[ChatGenerationChunk]:
+         system_message, contents = self._prepare_request(messages=messages)
+         response_iter = self.client.aio.models.generate_content_stream(
+             model=self.model_name,
+             contents=contents,
+             config=types.GenerateContentConfig(
+                 system_instruction=system_message,
+                 temperature=self.temperature,
+                 top_k=self.top_k,
+                 top_p=self.top_p,
+                 max_output_tokens=self.max_output_tokens,
+                 candidate_count=self.n,
+                 stop_sequences=stop or self.stop,
+                 safety_settings=self.safety_settings,
+                 **kwargs,
+             ),
+         )
+         total_lc_usage = None
+         async for response_chunk in await response_iter:
+             chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
+                 response_chunk, prev_total_usage=total_lc_usage
+             )
+             if run_manager and isinstance(chunk.message.content, str):
+                 await run_manager.on_llm_new_token(chunk.message.content)
+             yield chunk
+
+     def with_structured_output(
+         self,
+         schema: dict | type,
+         *,
+         include_raw: bool = False,
+         method: Literal["json_mode", "function_calling"] = "json_mode",
+         **kwargs: Any,
+     ) -> Runnable[LanguageModelInput, dict | BaseModel]:
+         assert isinstance(schema, type) and issubclass(
+             schema, BaseModel
+         ), "Structured output is only supported for Pydantic models."
+         if kwargs:
+             raise ValueError(f"Received unsupported arguments {kwargs}")
+
+         parser: OutputParserLike
+
+         if method == "json_mode":
+             parser = PydanticOutputParser(pydantic_object=schema)
+
+             llm = self.bind(
+                 response_mime_type="application/json",
+                 response_schema=schema,
+                 ls_structured_output_format={
+                     "kwargs": {"method": method},
+                     "schema": schema,
+                 },
+             )
+         elif method == "function_calling":
+             parser = PydanticToolsParser(tools=[schema], first_tool_only=True)
+             tool_choice = schema.__name__
+             llm = self.bind_tools(
+                 [schema],
+                 tool_choice=tool_choice,
+                 ls_structured_output_format={
+                     "kwargs": {"method": "function_calling"},
+                     "schema": convert_to_openai_tool(schema),
+                 },
+             )
+         else:
+             raise ValueError("method must be either 'json_mode' or 'function_calling'.")
+
+         if include_raw:
+             parser_with_fallback = RunnablePassthrough.assign(
+                 parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
+             ).with_fallbacks(
+                 [RunnablePassthrough.assign(parsed=lambda _: None)],
+                 exception_key="parsing_error",
+             )
+             return {"raw": llm} | parser_with_fallback
+         else:
+             return llm | parser
+
+     def bind_tools(
+         self,
+         tools: Sequence[
+             dict[str, Any] | type | Callable[..., Any] | BaseTool | types.Tool
+         ],
+         *,
+         tool_choice: str | None = None,
+         **kwargs: Any,
+     ) -> Runnable[LanguageModelInput, BaseMessage]:
+         """Bind tool-like objects to this chat model.
+
+         Assumes model is compatible with Vertex tool-calling API.
+
+         Args:
+             tools: A list of tool definitions to bind to this chat model.
+                 Can be a pydantic model, callable, or BaseTool. Pydantic
+                 models, callables, and BaseTools will be automatically converted to
+                 their schema dictionary representation.
+             **kwargs: Any additional parameters to pass to the
+                 :class:`~langchain.runnable.Runnable` constructor.
+         """
+         formatted_tools = []
+         for tool in tools:
+             if not isinstance(tool, types.Tool):
+                 openai_tool = convert_to_openai_tool(tool)
+                 if openai_tool["type"] != "function":
+                     raise ValueError(
+                         f"Tool {tool} is not a function tool. "
+                         f"It is a {openai_tool['type']}. "
+                         "Only function tools are supported."
+                     )
+                 function = openai_tool["function"]
+                 formatted_tools.append(
+                     types.Tool(
+                         function_declarations=[types.FunctionDeclaration(**function)],
+                     )
+                 )
+             else:
+                 formatted_tools.append(tool)
+         if tool_choice:
+             kwargs["tool_config"] = types.FunctionCallingConfig(
+                 mode=types.FunctionCallingConfigMode.ANY,
+                 allowed_function_names=[tool_choice],
+             )
+         return self.bind(tools=formatted_tools, **kwargs)
+
+     def _gemini_chunk_to_generation_chunk(
+         self,
+         response_chunk: types.GenerateContentResponse,
+         prev_total_usage: UsageMetadata | None = None,
+     ) -> tuple[ChatGenerationChunk, UsageMetadata | None]:
+         def _parse_usage_metadata(
+             usage_metadata: types.GenerateContentResponseUsageMetadata,
+         ) -> UsageMetadata:
+             return UsageMetadata(
+                 input_tokens=usage_metadata.prompt_token_count or 0,
+                 output_tokens=usage_metadata.candidates_token_count or 0,
+                 total_tokens=usage_metadata.total_token_count or 0,
+             )
+
+         total_lc_usage: UsageMetadata | None = (
+             _parse_usage_metadata(response_chunk.usage_metadata)
+             if response_chunk.usage_metadata
+             else None
+         )
+
+         if total_lc_usage and prev_total_usage:
+             lc_usage: UsageMetadata | None = UsageMetadata(
+                 input_tokens=total_lc_usage["input_tokens"]
+                 - prev_total_usage["input_tokens"],
+                 output_tokens=total_lc_usage["output_tokens"]
+                 - prev_total_usage["output_tokens"],
+                 total_tokens=total_lc_usage["total_tokens"]
+                 - prev_total_usage["total_tokens"],
+             )
+         else:
+             lc_usage = total_lc_usage
+         if not response_chunk.candidates:
+             message = AIMessageChunk(content="")
+             if lc_usage:
+                 message.usage_metadata = lc_usage
+             generation_info = {}
+         else:
+             top_candidate = response_chunk.candidates[0]
+             generation_info = {
+                 "finish_reason": top_candidate.finish_reason,
+                 "finish_message": top_candidate.finish_message,
+             }
+             message = parse_response_candidate(top_candidate)
+             if lc_usage:
+                 message.usage_metadata = lc_usage
+             # add model name if final chunk
+             if top_candidate.finish_reason is not None:
+                 message.response_metadata["model_name"] = self.model_name
+
+         return (
+             ChatGenerationChunk(
+                 message=message,
+                 generation_info=generation_info,
+             ),
+             total_lc_usage,
+         )
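
For orientation only (not part of the packaged file above): a minimal `ChatGenAI` sketch under the same `GOOGLE_APPLICATION_CREDENTIALS` assumption as the default client factory.

```python
# Minimal usage sketch, not part of the package.
from langchain_b12.genai.genai import ChatGenAI
from langchain_core.messages import HumanMessage

llm = ChatGenAI(model="gemini-2.0-flash", temperature=0.2)

# invoke() is served from the streaming path: _generate wraps _stream via
# generate_from_stream.
result = llm.invoke([HumanMessage(content="Why is the sky blue? Answer in one sentence.")])
print(result.content)

# Token-by-token streaming; usage deltas are attached per chunk by
# _gemini_chunk_to_generation_chunk.
for chunk in llm.stream([HumanMessage(content="Count to five.")]):
    print(chunk.content, end="", flush=True)
```

Note that `_prepare_request` only promotes a `SystemMessage` to `system_instruction` when it is the last message in the list; system messages elsewhere are dropped by `convert_messages_to_contents`.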
langchain_b12/genai/genai_utils.py ADDED
@@ -0,0 +1,219 @@
+ import base64
+ import json
+ import uuid
+ from collections.abc import Sequence
+ from typing import Any, cast
+
+ from google.genai import types
+ from langchain_core.messages import (
+     AIMessage,
+     AIMessageChunk,
+     BaseMessage,
+     HumanMessage,
+     SystemMessage,
+     ToolMessage,
+ )
+ from langchain_core.messages.tool import tool_call_chunk
+
+
+ def multi_content_to_part(
+     contents: Sequence[dict[str, str | dict[str, str]] | str],
+ ) -> list[types.Part]:
+     """Convert sequence content to a Part object.
+
+     Args:
+         contents: A sequence of dictionaries representing content. Examples:
+             [
+                 {
+                     "type": "text",
+                     "text": "This is a text message"
+                 },
+                 {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": f"data:{mime_type};base64,{encoded_artifact}"
+                     },
+                 },
+                 {
+                     "type": "file",
+                     "file": {
+                         "uri": f"gs://{bucket_name}/{file_name}",
+                         "mime_type": mime_type,
+                     }
+                 }
+             ]
+     """
+     parts = []
+     for content in contents:
+         assert isinstance(content, dict), "Expected dict content"
+         assert "type" in content, "Received dict content without type"
+         if content["type"] == "text":
+             assert "text" in content, "Expected 'text' in content"
+             if content["text"]:
+                 assert isinstance(content["text"], str), "Expected str content"
+                 parts.append(types.Part(text=content["text"]))
+         elif content["type"] == "image_url":
+             assert isinstance(content["image_url"], dict), "Expected dict image_url"
+             assert "url" in content["image_url"], "Expected 'url' in content"
+             split_url: tuple[str, str] = content["image_url"]["url"].split(",", 1)  # type: ignore
+             header, encoded_data = split_url
+             mime_type = header.split(":", 1)[1].split(";", 1)[0]
+             data = base64.b64decode(encoded_data)
+             parts.append(types.Part.from_bytes(data=data, mime_type=mime_type))
+         elif content["type"] == "file":
+             assert "file" in content, "Expected 'file' in content"
+             file = content["file"]
+             assert isinstance(file, dict), "Expected dict file"
+             assert "uri" in file, "Expected 'uri' in content['file']"
+             assert "mime_type" in file, "Expected 'mime_type' in content['file']"
+             parts.append(
+                 types.Part.from_uri(file_uri=file["uri"], mime_type=file["mime_type"])
+             )
+         else:
+             raise ValueError(f"Unknown content type: {content['type']}")
+     return parts
+
+
+ def convert_base_message_to_parts(
+     message: BaseMessage,
+ ) -> list[types.Part]:
+     """Convert a LangChain BaseMessage to Google GenAI Content object."""
+     parts = []
+     if isinstance(message.content, str):
+         if message.content:
+             parts.append(types.Part(text=message.content))
+     elif isinstance(message.content, list):
+         parts.extend(multi_content_to_part(message.content))
+     else:
+         raise ValueError(
+             "Received unexpected content type, "
+             f"expected str or list, but got {type(message.content)}"
+         )
+     return parts
+
+
+ def convert_messages_to_contents(
+     messages: Sequence[BaseMessage],
+ ) -> list[types.Content]:
+     """Convert a sequence of LangChain messages to Google GenAI Content objects.
+
+     Args:
+         messages: A sequence of LangChain BaseMessage objects
+
+     Returns:
+         A list of Google GenAI Content objects
+     """
+     contents = []
+
+     for message in messages:
+         if isinstance(message, HumanMessage):
+             parts = convert_base_message_to_parts(message)
+             contents.append(types.UserContent(parts=parts))
+         elif isinstance(message, AIMessage):
+             text_parts = convert_base_message_to_parts(message)
+             function_parts = []
+             if message.tool_calls:
+                 # Example of tool_call
+                 # tool_call = {
+                 #     "name": "foo",
+                 #     "args": {"a": 1},
+                 #     "id": "123"
+                 # }
+                 for tool_call in message.tool_calls:
+                     tool_id = tool_call["id"]
+                     assert tool_id, "Tool call ID is required"
+                     function_parts.append(
+                         types.Part(
+                             function_call=types.FunctionCall(
+                                 name=tool_call["name"],
+                                 args=tool_call["args"],
+                                 id=tool_id,
+                             ),
+                         )
+                     )
+
+             contents.append(
+                 types.ModelContent(
+                     parts=[*text_parts, *function_parts],
+                 )
+             )
+         elif isinstance(message, ToolMessage):
+             # Note: We tried combining function_call and function_response into one
+             # part, but that throws a 4xx server error.
+             assert isinstance(message.content, str), "Expected str content"
+             assert message.name, "Tool name is required"
+             tool_part = types.Part(
+                 function_response=types.FunctionResponse(
+                     id=message.tool_call_id,
+                     name=message.name,
+                     response={"output": message.content},
+                 ),
+             )
+
+             # Ensure that all function_responses are in a single content
+             last_content = contents[-1]
+             last_content_part = last_content.parts[-1]
+             if last_content_part.function_response:
+                 # Merge with the last content
+                 last_content.parts.append(tool_part)
+             else:
+                 # Create a new content
+                 contents.append(types.UserContent(parts=[tool_part]))
+         elif isinstance(message, SystemMessage):
+             # There is no genai.types equivalent for SystemMessage
+             pass
+         else:
+             raise ValueError(f"Invalid message type: {type(message)}")
+
+     return contents
+
+
+ def parse_response_candidate(response_candidate: types.Candidate) -> AIMessageChunk:
+     content: None | str | list[str] = None
+     additional_kwargs = {}
+     tool_call_chunks = []
+
+     assert response_candidate.content, "Response candidate content is None"
+     for part in response_candidate.content.parts or []:
+         try:
+             text: str | None = part.text
+         except AttributeError:
+             text = None
+
+         if text:
+             if not content:
+                 content = text
+             elif isinstance(content, str):
+                 content = [content, text]
+             elif isinstance(content, list):
+                 content.append(text)
+             else:
+                 raise ValueError("Unexpected content type")
+
+         if part.function_call:
+             # For backward compatibility we store a function call in additional_kwargs,
+             # but in general the full set of function calls is stored in tool_calls.
+             function_call = {"name": part.function_call.name}
+             # dump to match other function calling llm for now
+             function_call_args_dict = part.function_call.args
+             assert function_call_args_dict is not None
+             function_call["arguments"] = json.dumps(function_call_args_dict)
+             additional_kwargs["function_call"] = function_call
+
+             index = function_call.get("index")
+             tool_call_chunks.append(
+                 tool_call_chunk(
+                     name=function_call.get("name"),
+                     args=function_call.get("arguments"),
+                     id=function_call.get("id", str(uuid.uuid4())),
+                     index=int(index) if index else None,
+                 )
+             )
+     if content is None:
+         content = ""
+
+     return AIMessageChunk(
+         content=cast(str | list[str | dict[Any, Any]], content),
+         additional_kwargs=additional_kwargs,
+         tool_call_chunks=tool_call_chunks,
+     )
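
To make the message-to-`Content` mapping concrete (sketch only, not part of the packaged file above; the `add` tool name and call id are illustrative placeholders):

```python
# Sketch, not part of the package: how a tool-calling exchange converts.
from langchain_b12.genai.genai_utils import convert_messages_to_contents
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

messages = [
    HumanMessage(content="What is 2 + 2?"),
    AIMessage(
        content="",
        tool_calls=[{"name": "add", "args": {"a": 2, "b": 2}, "id": "call-1"}],
    ),
    ToolMessage(content="4", name="add", tool_call_id="call-1"),
]

contents = convert_messages_to_contents(messages)
# Expected shape: UserContent(text part), ModelContent(function_call part),
# UserContent(function_response part); consecutive tool results are merged
# into a single UserContent.
for content in contents:
    print(content.role, len(content.parts or []))
```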
langchain_b12/py.typed ADDED
File without changes
langchain_b12-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,39 @@
+ Metadata-Version: 2.4
+ Name: langchain-b12
+ Version: 0.1.0
+ Summary: A reusable collection of tools and implementations for Langchain
+ Author-email: Vincent Min <vincent.min@b12-consulting.com>
+ Requires-Python: >=3.11
+ Requires-Dist: langchain-core>=0.3.60
+ Description-Content-Type: text/markdown
+
+ # Langchain B12
+
+ This repo hosts a collection of custom LangChain components.
+
+ - ChatGenAI converts a sequence of LangChain messages to Google GenAI Content objects and vice versa.
+
+ ## Installation
+
+ To install this package, run
+ ```bash
+ pip install langchain-b12
+ ```
+
+ Some components rely on additional packages that may be installed as extras.
+ For example, to use the Google chat model `ChatGenAI`, you can run
+ ```bash
+ pip install langchain-b12[google]
+ ```
+
+ ## Components
+
+ The repo contains these components:
+
+ - `ChatGenAI`: An implementation of `BaseChatModel` that uses the `google-genai` package. Note that `langchain-google-genai` and `langchain-google-vertexai` exist, but neither uses the latest and recommended `google-genai` package.
+ - `GenAIEmbeddings`: An implementation of `Embeddings` that uses the `google-genai` package.
+
+ ## Comments
+
+ This repo exists for easy reuse and extension of custom LangChain components.
+ When appropriate, we will create PRs to merge these components directly into LangChain.
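
As a complement to the README above (sketch only, not part of the published METADATA): `ChatGenAI` also accepts a pre-configured `google-genai` client instead of relying on the credentials-file default; the project and location values below are placeholders.

```python
# Sketch only: inject an explicitly configured client into ChatGenAI.
# "my-gcp-project" and "europe-west1" are placeholder values.
from google import genai

from langchain_b12.genai.genai import ChatGenAI

client = genai.Client(vertexai=True, project="my-gcp-project", location="europe-west1")
llm = ChatGenAI(model="gemini-2.0-flash", client=client)
print(llm.invoke("Say hello in French.").content)
```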
langchain_b12-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ langchain_b12/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ langchain_b12/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ langchain_b12/genai/embeddings.py,sha256=od2bVIgt7v9aNAHG0PVypVF1H_XgHto2nTd8vwfvyN8,3355
+ langchain_b12/genai/genai.py,sha256=o2KLo2QlLDy0hijt6T2HaRmjSO60SV62dR_Cso6Ad-8,15796
+ langchain_b12/genai/genai_utils.py,sha256=2ojnSumovXyx3CxM7JwzyOkEdD7mQxfLnk50-usPbw8,8221
+ langchain_b12-0.1.0.dist-info/METADATA,sha256=72Ct-UG2KHnPQuyUltOQ_HGpvLbDrXglmPn-eY9blnc,1307
+ langchain_b12-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ langchain_b12-0.1.0.dist-info/RECORD,,
langchain_b12-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any