langchain-google-genai 1.0.6__tar.gz → 1.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of langchain-google-genai might be problematic.

Files changed (17)
  1. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/PKG-INFO +3 -3
  2. langchain_google_genai-1.0.8/langchain_google_genai/_function_utils.py +329 -0
  3. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/chat_models.py +307 -61
  4. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/embeddings.py +8 -6
  5. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/llms.py +14 -7
  6. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/pyproject.toml +6 -4
  7. langchain_google_genai-1.0.6/langchain_google_genai/_function_utils.py +0 -237
  8. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/LICENSE +0 -0
  9. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/README.md +0 -0
  10. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/__init__.py +0 -0
  11. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/_common.py +0 -0
  12. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/_enums.py +0 -0
  13. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/_genai_extension.py +0 -0
  14. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/_image_utils.py +0 -0
  15. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/genai_aqa.py +0 -0
  16. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/google_vector_store.py +0 -0
  17. {langchain_google_genai-1.0.6 → langchain_google_genai-1.0.8}/langchain_google_genai/py.typed +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: langchain-google-genai
- Version: 1.0.6
+ Version: 1.0.8
  Summary: An integration package connecting Google's genai package and LangChain
  Home-page: https://github.com/langchain-ai/langchain-google
  License: MIT
@@ -12,8 +12,8 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Provides-Extra: images
- Requires-Dist: google-generativeai (>=0.5.2,<0.6.0)
- Requires-Dist: langchain-core (>=0.2.2,<0.3)
+ Requires-Dist: google-generativeai (>=0.7.0,<0.8.0)
+ Requires-Dist: langchain-core (>=0.2.17,<0.3)
  Requires-Dist: pillow (>=10.1.0,<11.0.0) ; extra == "images"
  Project-URL: Repository, https://github.com/langchain-ai/langchain-google
  Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
@@ -0,0 +1,329 @@
+ from __future__ import annotations
+
+ import collections
+ import json
+ import logging
+ from typing import (
+     Any,
+     Callable,
+     Collection,
+     Dict,
+     List,
+     Literal,
+     Optional,
+     Sequence,
+     Type,
+     TypedDict,
+     Union,
+     cast,
+ )
+
+ import google.ai.generativelanguage as glm
+ import google.ai.generativelanguage_v1beta.types as gapic
+ import proto  # type: ignore[import]
+ from google.generativeai.types.content_types import ToolDict  # type: ignore[import]
+ from langchain_core.pydantic_v1 import BaseModel
+ from langchain_core.tools import BaseTool
+ from langchain_core.tools import tool as callable_as_lc_tool
+ from langchain_core.utils.function_calling import (
+     FunctionDescription,
+     convert_to_openai_tool,
+ )
+ from langchain_core.utils.json_schema import dereference_refs
+
+ logger = logging.getLogger(__name__)
+
+
+ TYPE_ENUM = {
+     "string": glm.Type.STRING,
+     "number": glm.Type.NUMBER,
+     "integer": glm.Type.INTEGER,
+     "boolean": glm.Type.BOOLEAN,
+     "array": glm.Type.ARRAY,
+     "object": glm.Type.OBJECT,
+ }
+
+ TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}
+ _ALLOWED_SCHEMA_FIELDS = []
+ _ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields])
+ _ALLOWED_SCHEMA_FIELDS.extend(
+     [
+         f
+         for f in gapic.Schema.to_dict(
+             gapic.Schema(), preserving_proto_field_name=False
+         ).keys()
+     ]
+ )
+ _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)
+
+
+ class _ToolDictLike(TypedDict):
+     function_declarations: _FunctionDeclarationLikeList
+
+
+ class _FunctionDeclarationDict(TypedDict):
+     name: str
+     description: str
+     parameters: Dict[str, Collection[str]]
+
+
+ class _ToolDict(TypedDict):
+     function_declarations: Sequence[_FunctionDeclarationDict]
+
+
+ # Info: This is a FunctionDeclaration(=fc).
+ _FunctionDeclarationLike = Union[
+     BaseTool, Type[BaseModel], gapic.FunctionDeclaration, Callable, Dict[str, Any]
+ ]
+
+ # Info: This means one tool.
+ _FunctionDeclarationLikeList = Sequence[_FunctionDeclarationLike]
+
+
+ # Info: This means one tool = a sequence of FunctionDeclarations.
+ # The dict should be gapic.Tool-like: {"function_declarations": [{"name": ...}]}.
+ # An OpenAI-style dict is not accepted: {'type': 'function', 'function': {'name': ...}}.
+ _ToolsType = Union[
+     gapic.Tool,
+     ToolDict,
+     _ToolDictLike,
+     _FunctionDeclarationLikeList,
+     _FunctionDeclarationLike,
+ ]
+
+
+ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]:
+     converted_schema: Dict[str, Any] = {}
+     for key, value in schema.items():
+         if key == "definitions":
+             continue
+         elif key == "items":
+             converted_schema["items"] = _format_json_schema_to_gapic(value)
+         elif key == "properties":
+             if "properties" not in converted_schema:
+                 converted_schema["properties"] = {}
+             for pkey, pvalue in value.items():
+                 converted_schema["properties"][pkey] = _format_json_schema_to_gapic(
+                     pvalue
+                 )
+             continue
+         elif key in ["type", "_type"]:
+             converted_schema["type"] = str(value).upper()
+         elif key not in _ALLOWED_SCHEMA_FIELDS_SET:
+             logger.warning(f"Key '{key}' is not supported in schema, ignoring")
+         else:
+             converted_schema[key] = value
+     return converted_schema
+
+
+ def _dict_to_gapic_schema(schema: Dict[str, Any]) -> gapic.Schema:
+     dereferenced_schema = dereference_refs(schema)
+     formatted_schema = _format_json_schema_to_gapic(dereferenced_schema)
+     json_schema = json.dumps(formatted_schema)
+     return gapic.Schema.from_json(json_schema)
+
+
+ def _format_dict_to_function_declaration(
+     tool: Union[FunctionDescription, Dict[str, Any]],
+ ) -> gapic.FunctionDeclaration:
+     return gapic.FunctionDeclaration(
+         name=tool.get("name"),
+         description=tool.get("description"),
+         parameters=_dict_to_gapic_schema(tool.get("parameters", {})),
+     )
+
+
+ # Info: gapic.Tool means function_declarations and proto.Message.
+ def convert_to_genai_function_declarations(
+     tools: Sequence[_ToolsType],
+ ) -> gapic.Tool:
+     if not isinstance(tools, collections.abc.Sequence):
+         logger.warning(
+             "convert_to_genai_function_declarations expects a Sequence "
+             "and not a single tool."
+         )
+         tools = [tools]
+     gapic_tool = gapic.Tool()
+     for tool in tools:
+         if isinstance(tool, gapic.Tool):
+             gapic_tool.function_declarations.extend(tool.function_declarations)
+         elif isinstance(tool, dict):
+             if "function_declarations" not in tool:
+                 fd = _format_to_gapic_function_declaration(tool)
+                 gapic_tool.function_declarations.append(fd)
+                 continue
+             tool = cast(_ToolDictLike, tool)
+             function_declarations = tool["function_declarations"]
+             if not isinstance(function_declarations, collections.abc.Sequence):
+                 raise ValueError(
+                     "function_declarations should be a list, "
+                     f"got '{type(function_declarations)}'"
+                 )
+             if function_declarations:
+                 fds = [
+                     _format_to_gapic_function_declaration(fd)
+                     for fd in function_declarations
+                 ]
+                 gapic_tool.function_declarations.extend(fds)
+         else:
+             fd = _format_to_gapic_function_declaration(tool)
+             gapic_tool.function_declarations.append(fd)
+     return gapic_tool
+
+
+ def tool_to_dict(tool: gapic.Tool) -> _ToolDict:
+     def _traverse_values(raw: Any) -> Any:
+         if isinstance(raw, list):
+             return [_traverse_values(v) for v in raw]
+         if isinstance(raw, dict):
+             return {k: _traverse_values(v) for k, v in raw.items()}
+         if isinstance(raw, proto.Message):
+             return _traverse_values(type(raw).to_dict(raw))
+         return raw
+
+     return _traverse_values(type(tool).to_dict(tool))
+
+
+ def _format_to_gapic_function_declaration(
+     tool: _FunctionDeclarationLike,
+ ) -> gapic.FunctionDeclaration:
+     if isinstance(tool, BaseTool):
+         return _format_base_tool_to_function_declaration(tool)
+     elif isinstance(tool, type) and issubclass(tool, BaseModel):
+         return _convert_pydantic_to_genai_function(tool)
+     elif isinstance(tool, dict):
+         if all(k in tool for k in ("name", "description")) and "parameters" not in tool:
+             function = cast(dict, tool)
+             function["parameters"] = {}
+         else:
+             function = convert_to_openai_tool(cast(dict, tool))["function"]
+         return _format_dict_to_function_declaration(cast(FunctionDescription, function))
+     elif callable(tool):
+         return _format_base_tool_to_function_declaration(callable_as_lc_tool()(tool))
+     raise ValueError(f"Unsupported tool type {tool}")
+
+
+ def _format_base_tool_to_function_declaration(
+     tool: BaseTool,
+ ) -> gapic.FunctionDeclaration:
+     if not tool.args_schema:
+         return gapic.FunctionDeclaration(
+             name=tool.name,
+             description=tool.description,
+             parameters=gapic.Schema(
+                 type=gapic.Type.OBJECT,
+                 properties={
+                     "__arg1": gapic.Schema(type=gapic.Type.STRING),
+                 },
+                 required=["__arg1"],
+             ),
+         )
+
+     schema = tool.args_schema.schema()
+     parameters = _dict_to_gapic_schema(schema)
+
+     return gapic.FunctionDeclaration(
+         name=tool.name or schema.get("title"),
+         description=tool.description or schema.get("description"),
+         parameters=parameters,
+     )
+
+
+ def _convert_pydantic_to_genai_function(
+     pydantic_model: Type[BaseModel],
+     tool_name: Optional[str] = None,
+     tool_description: Optional[str] = None,
+ ) -> gapic.FunctionDeclaration:
+     schema = dereference_refs(pydantic_model.schema())
+     schema.pop("definitions", None)
+     function_declaration = gapic.FunctionDeclaration(
+         name=tool_name if tool_name else schema.get("title"),
+         description=tool_description if tool_description else schema.get("description"),
+         parameters={
+             "properties": {
+                 k: {
+                     "type_": _get_type_from_schema(v),
+                     "description": v.get("description"),
+                 }
+                 for k, v in schema["properties"].items()
+             },
+             "required": schema.get("required", []),
+             "type_": TYPE_ENUM[schema["type"]],
+         },
+     )
+     return function_declaration
+
+
+ def _get_type_from_schema(schema: Dict[str, Any]) -> int:
+     if "anyOf" in schema:
+         types = [_get_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]]
+         types = [t for t in types if t is not None]  # Remove None values
+         if types:
+             return types[-1]  # TODO: update FunctionDeclaration and pass all types?
+         else:
+             pass
+     elif "type" in schema:
+         stype = str(schema["type"])
+         if stype in TYPE_ENUM:
+             return TYPE_ENUM[stype]
+         else:
+             pass
+     else:
+         pass
+     return TYPE_ENUM["string"]  # Default to string if no valid types found
+
+
+ _ToolChoiceType = Union[
+     dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
+ ]
+
+
+ class _FunctionCallingConfigDict(TypedDict):
+     mode: Union[gapic.FunctionCallingConfig.Mode, str]
+     allowed_function_names: Optional[List[str]]
+
+
+ class _ToolConfigDict(TypedDict):
+     function_calling_config: _FunctionCallingConfigDict
+
+
+ def _tool_choice_to_tool_config(
+     tool_choice: _ToolChoiceType,
+     all_names: List[str],
+ ) -> _ToolConfigDict:
+     allowed_function_names: Optional[List[str]] = None
+     if tool_choice is True or tool_choice == "any":
+         mode = "ANY"
+         allowed_function_names = all_names
+     elif tool_choice == "auto":
+         mode = "AUTO"
+     elif tool_choice == "none":
+         mode = "NONE"
+     elif isinstance(tool_choice, str):
+         mode = "ANY"
+         allowed_function_names = [tool_choice]
+     elif isinstance(tool_choice, list):
+         mode = "ANY"
+         allowed_function_names = tool_choice
+     elif isinstance(tool_choice, dict):
+         if "mode" in tool_choice:
+             mode = tool_choice["mode"]
+             allowed_function_names = tool_choice.get("allowed_function_names")
+         elif "function_calling_config" in tool_choice:
+             mode = tool_choice["function_calling_config"]["mode"]
+             allowed_function_names = tool_choice["function_calling_config"].get(
+                 "allowed_function_names"
+             )
+         else:
+             raise ValueError(
+                 f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
+                 f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
+             )
+     else:
+         raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
+     return _ToolConfigDict(
+         function_calling_config={
+             "mode": mode.upper(),
+             "allowed_function_names": allowed_function_names,
+         }
+     )
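Taken together, the new module inverts the 1.0.6 contract: convert_to_genai_function_declarations now takes a whole sequence of tool-like objects and merges every declaration into a single gapic.Tool, and _tool_choice_to_tool_config normalizes modes to uppercase. A minimal sketch of how the pieces compose (these are private helpers, imported here purely for illustration; the GetWeather model is a made-up example):

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_google_genai._function_utils import (
    _tool_choice_to_tool_config,
    convert_to_genai_function_declarations,
    tool_to_dict,
)


class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="City and state, e.g. San Francisco, CA")


# 1.0.8 accepts a *sequence* and merges all declarations into one gapic.Tool.
gapic_tool = convert_to_genai_function_declarations([GetWeather])
print(tool_to_dict(gapic_tool))
# -> {'function_declarations': [{'name': 'GetWeather', ...}]}

# Naming a single tool normalizes to uppercase mode ANY, restricted to that function.
print(_tool_choice_to_tool_config("GetWeather", ["GetWeather"]))
# -> {'function_calling_config': {'mode': 'ANY',
#                                 'allowed_function_names': ['GetWeather']}}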
@@ -8,6 +8,7 @@ import os
  import uuid
  import warnings
  from io import BytesIO
+ from operator import itemgetter
  from typing import (
      Any,
      AsyncIterator,
@@ -19,6 +20,7 @@ from typing import (
      Optional,
      Sequence,
      Tuple,
+     Type,
      Union,
      cast,
  )
@@ -58,17 +60,20 @@ from langchain_core.messages import (
      BaseMessage,
      FunctionMessage,
      HumanMessage,
-     InvalidToolCall,
      SystemMessage,
-     ToolCall,
-     ToolCallChunk,
      ToolMessage,
  )
  from langchain_core.messages.ai import UsageMetadata
- from langchain_core.output_parsers.openai_tools import parse_tool_calls
+ from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
+ from langchain_core.output_parsers.base import OutputParserLike
+ from langchain_core.output_parsers.openai_tools import (
+     JsonOutputToolsParser,
+     PydanticToolsParser,
+     parse_tool_calls,
+ )
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
- from langchain_core.runnables import Runnable
+ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
+ from langchain_core.runnables import Runnable, RunnablePassthrough
  from langchain_core.utils import get_from_dict_or_env
  from tenacity import (
      before_sleep_log,
@@ -133,7 +138,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
      multiplier = 2
      min_seconds = 1
      max_seconds = 60
-     max_retries = 10
+     max_retries = 2

      return retry(
          reraise=True,
@@ -459,9 +464,6 @@ def _parse_response_candidate(
              raise Exception("Unexpected content type")

          if part.function_call:
-             # TODO: support multiple function calls
-             if "function_call" in additional_kwargs:
-                 raise Exception("Multiple function calls are not currently supported")
              function_call = {"name": part.function_call.name}
              # dump to match other function calling llm for now
              function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
@@ -472,7 +474,7 @@

              if streaming:
                  tool_call_chunks.append(
-                     ToolCallChunk(
+                     tool_call_chunk(
                          name=function_call.get("name"),
                          args=function_call.get("arguments"),
                          id=function_call.get("id", str(uuid.uuid4())),
@@ -481,27 +483,27 @@
                  )
              else:
                  try:
-                     tool_calls_dicts = parse_tool_calls(
+                     tool_call_dict = parse_tool_calls(
                          [{"function": function_call}],
                          return_id=False,
-                     )
-                     tool_calls = [
-                         ToolCall(
-                             name=tool_call["name"],
-                             args=tool_call["args"],
-                             id=tool_call.get("id", str(uuid.uuid4())),
-                         )
-                         for tool_call in tool_calls_dicts
-                     ]
+                     )[0]
                  except Exception as e:
-                     invalid_tool_calls = [
-                         InvalidToolCall(
+                     invalid_tool_calls.append(
+                         invalid_tool_call(
                              name=function_call.get("name"),
                              args=function_call.get("arguments"),
                              id=function_call.get("id", str(uuid.uuid4())),
                              error=str(e),
                          )
-                     ]
+                     )
+                 else:
+                     tool_calls.append(
+                         tool_call(
+                             name=tool_call_dict["name"],
+                             args=tool_call_dict["args"],
+                             id=tool_call_dict.get("id", str(uuid.uuid4())),
+                         )
+                     )
      if content is None:
          content = ""

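These hunks drop the old single-function-call guard: each function_call part is now parsed independently and routed into tool_calls or invalid_tool_calls via try/except/else. A standalone sketch of that accumulation pattern (the inputs are fabricated; parse_tool_calls is the langchain-core helper the diff itself calls):

import json

from langchain_core.output_parsers.openai_tools import parse_tool_calls

tool_calls, invalid_tool_calls = [], []
for function_call in [
    {"name": "GetWeather", "arguments": json.dumps({"location": "LA"})},
    {"name": "GetWeather", "arguments": "{not valid json"},  # malformed on purpose
]:
    try:
        # Same call shape as the diff: wrap as an OpenAI-style function entry.
        parsed = parse_tool_calls([{"function": function_call}], return_id=False)[0]
    except Exception as e:
        invalid_tool_calls.append({**function_call, "error": str(e)})
    else:
        tool_calls.append(parsed)

print(len(tool_calls), len(invalid_tool_calls))  # 1 1 - both calls survive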
@@ -587,25 +589,207 @@ def _is_event_loop_running() -> bool:


  class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
-     """`Google Generative AI` Chat models API.
+     """`Google AI` chat models integration.

-     To use, you must have either:
+     Instantiation:
+         To use, you must have either:

-         1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
-         2. Pass your API key using the google_api_key kwarg to the ChatGoogle
-         constructor.
+             1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
+             2. Pass your API key using the google_api_key kwarg to the ChatGoogle
+                constructor.

-     Example:
          .. code-block:: python

              from langchain_google_genai import ChatGoogleGenerativeAI
-             chat = ChatGoogleGenerativeAI(model="gemini-pro")
-             chat.invoke("Write me a ballad about LangChain")

-     """
+             llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro")
+             llm.invoke("Write me a ballad about LangChain")
+
+     Invoke:
+         .. code-block:: python
+
+             messages = [
+                 ("system", "Translate the user sentence to French."),
+                 ("human", "I love programming."),
+             ]
+             llm.invoke(messages)
+
+         .. code-block:: python
+
+             AIMessage(
+                 content="J'adore programmer. \\n",
+                 response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
+                 id='run-56cecc34-2e54-4b52-a974-337e47008ad2-0',
+                 usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
+             )
+
+     Stream:
+         .. code-block:: python
+
+             for chunk in llm.stream(messages):
+                 print(chunk)
+
+         .. code-block:: python
+
+             AIMessageChunk(content='J', response_metadata={'finish_reason': 'STOP', 'safety_ratings': []}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 1, 'total_tokens': 19})
+             AIMessageChunk(content="'adore programmer. \n", response_metadata={'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23})
+
+         .. code-block:: python
+
+             stream = llm.stream(messages)
+             full = next(stream)
+             for chunk in stream:
+                 full += chunk
+             full
+
+         .. code-block:: python
+
+             AIMessageChunk(
+                 content="J'adore programmer. \\n",
+                 response_metadata={'finish_reason': 'STOPSTOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
+                 id='run-3ce13a42-cd30-4ad7-a684-f1f0b37cdeec',
+                 usage_metadata={'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}
+             )
+
+     Async:
+         .. code-block:: python
+
+             await llm.ainvoke(messages)
+
+             # stream:
+             # async for chunk in (await llm.astream(messages))
+
+             # batch:
+             # await llm.abatch([messages])
+
+     Tool calling:
+         .. code-block:: python
+
+             from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+             class GetWeather(BaseModel):
+                 '''Get the current weather in a given location'''
+
+                 location: str = Field(
+                     ..., description="The city and state, e.g. San Francisco, CA"
+                 )
+
+
+             class GetPopulation(BaseModel):
+                 '''Get the current population in a given location'''
+
+                 location: str = Field(
+                     ..., description="The city and state, e.g. San Francisco, CA"
+                 )
+
+
+             llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
+             ai_msg = llm_with_tools.invoke(
+                 "Which city is hotter today and which is bigger: LA or NY?"
+             )
+             ai_msg.tool_calls
+
+         .. code-block:: python
+
+             [{'name': 'GetWeather',
+               'args': {'location': 'Los Angeles, CA'},
+               'id': 'c186c99f-f137-4d52-947f-9e3deabba6f6'},
+              {'name': 'GetWeather',
+               'args': {'location': 'New York City, NY'},
+               'id': 'cebd4a5d-e800-4fa5-babd-4aa286af4f31'},
+              {'name': 'GetPopulation',
+               'args': {'location': 'Los Angeles, CA'},
+               'id': '4f92d897-f5e4-4d34-a3bc-93062c92591e'},
+              {'name': 'GetPopulation',
+               'args': {'location': 'New York City, NY'},
+               'id': '634582de-5186-4e4b-968b-f192f0a93678'}]
+
+     Structured output:
+         .. code-block:: python
+
+             from typing import Optional
+
+             from langchain_core.pydantic_v1 import BaseModel, Field
+
+
+             class Joke(BaseModel):
+                 '''Joke to tell user.'''
+
+                 setup: str = Field(description="The setup of the joke")
+                 punchline: str = Field(description="The punchline to the joke")
+                 rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
+
+
+             structured_llm = llm.with_structured_output(Joke)
+             structured_llm.invoke("Tell me a joke about cats")
+
+         .. code-block:: python

-     client: Any  #: :meta private:
-     async_client: Any  #: :meta private:
+             Joke(
+                 setup='Why are cats so good at video games?',
+                 punchline='They have nine lives on the internet',
+                 rating=None
+             )
+
+     Image input:
+         .. code-block:: python
+
+             import base64
+             import httpx
+             from langchain_core.messages import HumanMessage
+
+             image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+             image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
+             message = HumanMessage(
+                 content=[
+                     {"type": "text", "text": "describe the weather in this image"},
+                     {
+                         "type": "image_url",
+                         "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
+                     },
+                 ]
+             )
+             ai_msg = llm.invoke([message])
+             ai_msg.content
+
+         .. code-block:: python
+
+             'The weather in this image appears to be sunny and pleasant. The sky is a bright blue with scattered white clouds, suggesting fair weather. The lush green grass and trees indicate a warm and possibly slightly breezy day. There are no signs of rain or storms. \n'
+
+     Token usage:
+         .. code-block:: python
+
+             ai_msg = llm.invoke(messages)
+             ai_msg.usage_metadata
+
+         .. code-block:: python
+
+             {'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
+
+     Response metadata:
+         .. code-block:: python
+
+             ai_msg = llm.invoke(messages)
+             ai_msg.response_metadata
+
+         .. code-block:: python
+
+             {
+                 'prompt_feedback': {'block_reason': 0, 'safety_ratings': []},
+                 'finish_reason': 'STOP',
+                 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
+             }
+
+     """  # noqa: E501
+
+     client: Any = None  #: :meta private:
+     async_client: Any = None  #: :meta private:
+     google_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+     """Google AI API key.
+
+     If not specified will be read from env var ``GOOGLE_API_KEY``."""
      default_metadata: Sequence[Tuple[str, str]] = Field(
          default_factory=list
      )  #: :meta private:
@@ -779,9 +963,18 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          **kwargs: Any,
      ) -> ChatResult:
          if not self.async_client:
-             raise RuntimeError(
-                 "Initialize ChatGoogleGenerativeAI with a running event loop "
-                 "to use async methods."
+             updated_kwargs = {
+                 **kwargs,
+                 **{
+                     "tools": tools,
+                     "functions": functions,
+                     "safety_settings": safety_settings,
+                     "tool_config": tool_config,
+                     "generation_config": generation_config,
+                 },
+             }
+             return await super()._agenerate(
+                 messages, stop, run_manager, **updated_kwargs
              )

          request = self._prepare_request(
@@ -850,27 +1043,43 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          generation_config: Optional[Dict[str, Any]] = None,
          **kwargs: Any,
      ) -> AsyncIterator[ChatGenerationChunk]:
-         request = self._prepare_request(
-             messages,
-             stop=stop,
-             tools=tools,
-             functions=functions,
-             safety_settings=safety_settings,
-             tool_config=tool_config,
-             generation_config=generation_config,
-         )
-         async for chunk in await _achat_with_retry(
-             request=request,
-             generation_method=self.async_client.stream_generate_content,
-             **kwargs,
-             metadata=self.default_metadata,
-         ):
-             _chat_result = _response_to_result(chunk, stream=True)
-             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
+         if not self.async_client:
+             updated_kwargs = {
+                 **kwargs,
+                 **{
+                     "tools": tools,
+                     "functions": functions,
+                     "safety_settings": safety_settings,
+                     "tool_config": tool_config,
+                     "generation_config": generation_config,
+                 },
+             }
+             async for value in super()._astream(
+                 messages, stop, run_manager, **updated_kwargs
+             ):
+                 yield value
+         else:
+             request = self._prepare_request(
+                 messages,
+                 stop=stop,
+                 tools=tools,
+                 functions=functions,
+                 safety_settings=safety_settings,
+                 tool_config=tool_config,
+                 generation_config=generation_config,
+             )
+             async for chunk in await _achat_with_retry(
+                 request=request,
+                 generation_method=self.async_client.stream_generate_content,
+                 **kwargs,
+                 metadata=self.default_metadata,
+             ):
+                 _chat_result = _response_to_result(chunk, stream=True)
+                 gen = cast(ChatGenerationChunk, _chat_result.generations[0])

-             if run_manager:
-                 await run_manager.on_llm_new_token(gen.text)
-             yield gen
+                 if run_manager:
+                     await run_manager.on_llm_new_token(gen.text)
+                 yield gen

      def _prepare_request(
          self,
  def _prepare_request(
876
1085
  self,
@@ -885,9 +1094,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
885
1094
  ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
886
1095
  formatted_tools = None
887
1096
  if tools:
888
- formatted_tools = [
889
- convert_to_genai_function_declarations(tool) for tool in tools
890
- ]
1097
+ formatted_tools = [convert_to_genai_function_declarations(tools)]
891
1098
  elif functions:
892
1099
  formatted_tools = [convert_to_genai_function_declarations(functions)]
893
1100
 
@@ -937,6 +1144,34 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          )
          return result.total_tokens

+     def with_structured_output(
+         self,
+         schema: Union[Dict, Type[BaseModel]],
+         *,
+         include_raw: bool = False,
+         **kwargs: Any,
+     ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+         if kwargs:
+             raise ValueError(f"Received unsupported arguments {kwargs}")
+         if isinstance(schema, type) and issubclass(schema, BaseModel):
+             parser: OutputParserLike = PydanticToolsParser(
+                 tools=[schema], first_tool_only=True
+             )
+         else:
+             parser = JsonOutputToolsParser()
+         tool_choice = _get_tool_name(schema) if self._supports_tool_choice else None
+         llm = self.bind_tools([schema], tool_choice=tool_choice)
+         if include_raw:
+             parser_with_fallback = RunnablePassthrough.assign(
+                 parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
+             ).with_fallbacks(
+                 [RunnablePassthrough.assign(parsed=lambda _: None)],
+                 exception_key="parsing_error",
+             )
+             return {"raw": llm} | parser_with_fallback
+         else:
+             return llm | parser
+
      def bind_tools(
          self,
          tools: Sequence[Union[ToolDict, GoogleTool]],
@@ -972,3 +1207,14 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
          ]
          tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
          return self.bind(tools=genai_tools, tool_config=tool_config, **kwargs)
+
+     @property
+     def _supports_tool_choice(self) -> bool:
+         return "gemini-1.5-pro" in self.model
+
+
+ def _get_tool_name(
+     tool: Union[ToolDict, GoogleTool],
+ ) -> str:
+     genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
+     return [f["name"] for f in genai_tool["function_declarations"]][0]  # type: ignore[index]
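The docstring above demonstrates the plain with_structured_output path; the include_raw=True branch added in this diff returns a dict instead, wiring the parser through RunnablePassthrough with a fallback that captures parsing failures. A sketch under the assumption that GOOGLE_API_KEY is set (the Joke model mirrors the docstring example):

from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import ChatGoogleGenerativeAI


class Joke(BaseModel):
    """Joke to tell user."""

    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")


llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro")
structured_llm = llm.with_structured_output(Joke, include_raw=True)
result = structured_llm.invoke("Tell me a joke about cats")
# Per the RunnablePassthrough wiring above, result is a dict:
#   {"raw": AIMessage(...),        # the unparsed model response
#    "parsed": Joke(...) or None,  # None if parsing failed
#    "parsing_error": None or the exception raised while parsing}
print(result["parsed"])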
@@ -39,20 +39,20 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
          embeddings.embed_query("What's our Q1 revenue?")
      """

-     client: Any  #: :meta private:
+     client: Any = None  #: :meta private:
      model: str = Field(
          ...,
          description="The name of the embedding model to use. "
          "Example: models/embedding-001",
      )
      task_type: Optional[str] = Field(
-         None,
+         default=None,
          description="The task type. Valid options include: "
          "task_type_unspecified, retrieval_query, retrieval_document, "
          "semantic_similarity, classification, and clustering",
      )
      google_api_key: Optional[SecretStr] = Field(
-         None,
+         default=None,
          description="The Google API key to use. If not provided, "
          "the GOOGLE_API_KEY environment variable will be used.",
      )
@@ -64,18 +64,18 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
          "provided, credentials will be ascertained from the GOOGLE_API_KEY envvar",
      )
      client_options: Optional[Dict] = Field(
-         None,
+         default=None,
          description=(
              "A dictionary of client options to pass to the Google API client, "
              "such as `api_endpoint`."
          ),
      )
      transport: Optional[str] = Field(
-         None,
+         default=None,
          description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
      )
      request_options: Optional[Dict] = Field(
-         None,
+         default=None,
          description="A dictionary of request options to pass to the Google API client."
          "Example: `{'timeout': 10}`",
      )
@@ -86,6 +86,8 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
          google_api_key = get_from_dict_or_env(
              values, "google_api_key", "GOOGLE_API_KEY"
          )
+         if isinstance(google_api_key, SecretStr):
+             google_api_key = google_api_key.get_secret_value()
          client_info = get_client_info("GoogleGenerativeAIEmbeddings")

          values["client"] = build_generative_service(
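The two added lines matter because pydantic's SecretStr masks its value when coerced to a string, so the key must be unwrapped before being handed to the API client. A standalone illustration (not the module's actual validator; the key is fake):

from langchain_core.pydantic_v1 import SecretStr

google_api_key = SecretStr("my-api-key")  # hypothetical key

print(str(google_api_key))  # '**********' - masked if passed through as-is
if isinstance(google_api_key, SecretStr):
    google_api_key = google_api_key.get_secret_value()
print(google_api_key)  # 'my-api-key' - the plain str the client needs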
@@ -149,18 +149,18 @@ Supported examples:
      """The maximum number of seconds to wait for a response."""

      client_options: Optional[Dict] = Field(
-         None,
+         default=None,
          description=(
              "A dictionary of client options to pass to the Google API client, "
              "such as `api_endpoint`."
          ),
      )
      transport: Optional[str] = Field(
-         None,
+         default=None,
          description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
      )
      additional_headers: Optional[Dict[str, str]] = Field(
-         None,
+         default=None,
          description=(
              "A key-value dictionary representing additional headers for the model call"
          ),
@@ -212,7 +212,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
          llm = GoogleGenerativeAI(model="gemini-pro")
      """

-     client: Any  #: :meta private:
+     client: Any = None  #: :meta private:

      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
@@ -325,9 +325,16 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
          run_manager: Optional[CallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> Iterator[GenerationChunk]:
-         generation_config = kwargs.get("generation_config", {})
-         if stop:
-             generation_config["stop_sequences"] = stop
+         generation_config = {
+             "stop_sequences": stop,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "max_output_tokens": self.max_output_tokens,
+             "candidate_count": self.n,
+         }
+         generation_config = generation_config | kwargs.get("generation_config", {})
+
          for stream_resp in _completion_with_retry(
              self,
              prompt,
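The rewritten _stream now seeds generation_config from the model's own fields and merges any per-call generation_config on top with the dict-union operator (PEP 584, Python 3.9+, matching the python constraint in pyproject.toml below); right-hand values win, so caller overrides beat model defaults. A small sketch of the merge semantics (values are illustrative):

# Model-level defaults, as built from self.* in the new _stream.
defaults = {
    "stop_sequences": None,
    "temperature": 0.7,
    "top_p": None,
    "top_k": None,
    "max_output_tokens": None,
    "candidate_count": 1,
}
# What a caller passes via kwargs["generation_config"].
per_call = {"temperature": 0.1, "max_output_tokens": 256}

merged = defaults | per_call
print(merged["temperature"])      # 0.1 - the per-call override wins
print(merged["candidate_count"])  # 1 - untouched defaults survive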
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "langchain-google-genai"
- version = "1.0.6"
+ version = "1.0.8"
  description = "An integration package connecting Google's genai package and LangChain"
  authors = []
  readme = "README.md"
@@ -12,8 +12,8 @@ license = "MIT"

  [tool.poetry.dependencies]
  python = ">=3.9,<4.0"
- langchain-core = ">=0.2.2,<0.3"
- google-generativeai = "^0.5.2"
+ langchain-core = ">=0.2.17,<0.3"
+ google-generativeai = "^0.7.0"
  pillow = { version = "^10.1.0", optional = true }

  [tool.poetry.extras]
@@ -31,6 +31,7 @@ pytest-watcher = "^0.3.4"
  pytest-asyncio = "^0.21.1"
  numpy = "^1.26.2"
  langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
+ langchain-standard-tests = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests" }

  [tool.codespell]
  ignore-words-list = "rouge"
@@ -61,6 +62,7 @@ types-google-cloud-ndb = "^2.2.0.1"
  types-pillow = "^10.1.0.2"
  types-protobuf = "^4.24.0.20240302"
  langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
+ numpy = "^1.26.2"

  [tool.poetry.group.dev]
  optional = true
@@ -72,7 +74,7 @@ types-pillow = "^10.1.0.2"
  types-google-cloud-ndb = "^2.2.0.1"
  langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }

- [tool.ruff]
+ [tool.ruff.lint]
  select = [
      "E",  # pycodestyle
      "F",  # pyflakes
@@ -1,237 +0,0 @@
- from __future__ import annotations
-
- from typing import (
-     Any,
-     Callable,
-     Dict,
-     List,
-     Literal,
-     Optional,
-     Sequence,
-     Type,
-     TypedDict,
-     Union,
-     cast,
- )
-
- import google.ai.generativelanguage as glm
- from google.ai.generativelanguage import (
-     FunctionCallingConfig,
-     FunctionDeclaration,
- )
- from google.ai.generativelanguage import (
-     Tool as GoogleTool,
- )
- from langchain_core.pydantic_v1 import BaseModel
- from langchain_core.tools import BaseTool
- from langchain_core.tools import tool as callable_as_lc_tool
- from langchain_core.utils.json_schema import dereference_refs
-
- TYPE_ENUM = {
-     "string": glm.Type.STRING,
-     "number": glm.Type.NUMBER,
-     "integer": glm.Type.INTEGER,
-     "boolean": glm.Type.BOOLEAN,
-     "array": glm.Type.ARRAY,
-     "object": glm.Type.OBJECT,
- }
-
- TYPE_ENUM_REVERSE = {v: k for k, v in TYPE_ENUM.items()}
-
- _FunctionDeclarationLike = Union[
-     BaseTool, Type[BaseModel], dict, Callable, FunctionDeclaration
- ]
-
-
- class _ToolDict(TypedDict):
-     function_declarations: Sequence[_FunctionDeclarationLike]
-
-
- def convert_to_genai_function_declarations(
-     tool: Union[
-         GoogleTool,
-         _ToolDict,
-         _FunctionDeclarationLike,
-         Sequence[_FunctionDeclarationLike],
-     ],
- ) -> GoogleTool:
-     if isinstance(tool, GoogleTool):
-         return cast(GoogleTool, tool)
-     if isinstance(tool, type) and issubclass(tool, BaseModel):
-         return GoogleTool(function_declarations=[_convert_to_genai_function(tool)])
-     if callable(tool):
-         return _convert_tool_to_genai_function(callable_as_lc_tool()(tool))
-     if isinstance(tool, list):
-         return convert_to_genai_function_declarations({"function_declarations": tool})
-     if isinstance(tool, dict) and "function_declarations" in tool:
-         return GoogleTool(
-             function_declarations=[
-                 _convert_to_genai_function(fc) for fc in tool["function_declarations"]
-             ],
-         )
-     return GoogleTool(function_declarations=[_convert_to_genai_function(tool)])  # type: ignore[arg-type]
-
-
- def tool_to_dict(tool: GoogleTool) -> _ToolDict:
-     function_declarations = []
-     for function_declaration_proto in tool.function_declarations:
-         properties: Dict[str, Any] = {}
-         for property in function_declaration_proto.parameters.properties:
-             property_type = function_declaration_proto.parameters.properties[
-                 property
-             ].type
-             property_dict = {"type": TYPE_ENUM_REVERSE[property_type]}
-             property_description = function_declaration_proto.parameters.properties[
-                 property
-             ].description
-             if property_description:
-                 property_dict["description"] = property_description
-             properties[property] = property_dict
-         function_declaration = {
-             "name": function_declaration_proto.name,
-             "description": function_declaration_proto.description,
-             "parameters": {"type": "object", "properties": properties},
-         }
-         if function_declaration_proto.parameters.required:
-             function_declaration["parameters"][  # type: ignore[index]
-                 "required"
-             ] = function_declaration_proto.parameters.required
-         function_declarations.append(function_declaration)
-     return {"function_declarations": function_declarations}
-
-
- def _convert_to_genai_function(fc: _FunctionDeclarationLike) -> FunctionDeclaration:
-     if isinstance(fc, BaseTool):
-         return _convert_tool_to_genai_function(fc)
-     elif isinstance(fc, type) and issubclass(fc, BaseModel):
-         return _convert_pydantic_to_genai_function(fc)
-     elif callable(fc):
-         return _convert_tool_to_genai_function(callable_as_lc_tool()(fc))
-     elif isinstance(fc, dict):
-         formatted_fc = {"name": fc["name"], "description": fc.get("description")}
-         if "parameters" in fc:
-             formatted_fc["parameters"] = {
-                 "properties": {
-                     k: {
-                         "type_": TYPE_ENUM[v["type"]],
-                         "description": v.get("description"),
-                     }
-                     for k, v in fc["parameters"]["properties"].items()
-                 },
-                 "required": fc.get("parameters", []).get("required", []),
-                 "type_": TYPE_ENUM[fc["parameters"]["type"]],
-             }
-         return FunctionDeclaration(**formatted_fc)
-     else:
-         raise ValueError(f"Unsupported function call type {fc}")
-
-
- def _convert_tool_to_genai_function(tool: BaseTool) -> FunctionDeclaration:
-     if tool.args_schema:
-         schema = dereference_refs(tool.args_schema.schema())
-         schema.pop("definitions", None)
-         return FunctionDeclaration(
-             name=tool.name or schema["title"],
-             description=tool.description or schema["description"],
-             parameters={
-                 "properties": {
-                     k: {
-                         "type_": TYPE_ENUM[v["type"]],
-                         "description": v.get("description"),
-                     }
-                     for k, v in schema["properties"].items()
-                 },
-                 "required": schema.get("required", []),
-                 "type_": TYPE_ENUM[schema["type"]],
-             },
-         )
-     else:
-         return FunctionDeclaration(
-             name=tool.name,
-             description=tool.description,
-             parameters={
-                 "properties": {
-                     "__arg1": {"type_": TYPE_ENUM["string"]},
-                 },
-                 "required": ["__arg1"],
-                 "type_": TYPE_ENUM["object"],
-             },
-         )
-
-
- def _convert_pydantic_to_genai_function(
-     pydantic_model: Type[BaseModel],
- ) -> FunctionDeclaration:
-     schema = dereference_refs(pydantic_model.schema())
-     schema.pop("definitions", None)
-     return FunctionDeclaration(
-         name=schema["title"],
-         description=schema.get("description", ""),
-         parameters={
-             "properties": {
-                 k: {
-                     "type_": TYPE_ENUM[v["type"]],
-                     "description": v.get("description"),
-                 }
-                 for k, v in schema["properties"].items()
-             },
-             "required": schema["required"],
-             "type_": TYPE_ENUM[schema["type"]],
-         },
-     )
-
-
- _ToolChoiceType = Union[
-     dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
- ]
-
-
- class _FunctionCallingConfigDict(TypedDict):
-     mode: Union[FunctionCallingConfig.Mode, str]
-     allowed_function_names: Optional[List[str]]
-
-
- class _ToolConfigDict(TypedDict):
-     function_calling_config: _FunctionCallingConfigDict
-
-
- def _tool_choice_to_tool_config(
-     tool_choice: _ToolChoiceType,
-     all_names: List[str],
- ) -> _ToolConfigDict:
-     allowed_function_names: Optional[List[str]] = None
-     if tool_choice is True or tool_choice == "any":
-         mode = "any"
-         allowed_function_names = all_names
-     elif tool_choice == "auto":
-         mode = "auto"
-     elif tool_choice == "none":
-         mode = "none"
-     elif isinstance(tool_choice, str):
-         mode = "any"
-         allowed_function_names = [tool_choice]
-     elif isinstance(tool_choice, list):
-         mode = "any"
-         allowed_function_names = tool_choice
-     elif isinstance(tool_choice, dict):
-         if "mode" in tool_choice:
-             mode = tool_choice["mode"]
-             allowed_function_names = tool_choice.get("allowed_function_names")
-         elif "function_calling_config" in tool_choice:
-             mode = tool_choice["function_calling_config"]["mode"]
-             allowed_function_names = tool_choice["function_calling_config"].get(
-                 "allowed_function_names"
-             )
-         else:
-             raise ValueError(
-                 f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
-                 f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
-             )
-     else:
-         raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
-     return _ToolConfigDict(
-         function_calling_config={
-             "mode": mode,
-             "allowed_function_names": allowed_function_names,
-         }
-     )