langchain-google-genai 2.1.11-py3-none-any.whl → 3.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai has been flagged as potentially problematic; see the package's release page on the registry for more details.
- langchain_google_genai/__init__.py +3 -3
- langchain_google_genai/_common.py +29 -17
- langchain_google_genai/_compat.py +286 -0
- langchain_google_genai/_function_utils.py +77 -59
- langchain_google_genai/_genai_extension.py +60 -27
- langchain_google_genai/_image_utils.py +10 -9
- langchain_google_genai/chat_models.py +803 -297
- langchain_google_genai/embeddings.py +17 -24
- langchain_google_genai/genai_aqa.py +29 -18
- langchain_google_genai/google_vector_store.py +45 -25
- langchain_google_genai/llms.py +8 -7
- {langchain_google_genai-2.1.11.dist-info → langchain_google_genai-3.0.0.dist-info}/METADATA +43 -30
- langchain_google_genai-3.0.0.dist-info/RECORD +18 -0
- langchain_google_genai-2.1.11.dist-info/RECORD +0 -17
- {langchain_google_genai-2.1.11.dist-info → langchain_google_genai-3.0.0.dist-info}/WHEEL +0 -0
- {langchain_google_genai-2.1.11.dist-info → langchain_google_genai-3.0.0.dist-info}/entry_points.txt +0 -0
- {langchain_google_genai-2.1.11.dist-info → langchain_google_genai-3.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,6 +4,7 @@ import collections
|
|
|
4
4
|
import importlib
|
|
5
5
|
import json
|
|
6
6
|
import logging
|
|
7
|
+
from collections.abc import Sequence
|
|
7
8
|
from typing import (
|
|
8
9
|
Any,
|
|
9
10
|
Callable,
|
|
@@ -11,7 +12,6 @@ from typing import (
|
|
|
11
12
|
List,
|
|
12
13
|
Literal,
|
|
13
14
|
Optional,
|
|
14
|
-
Sequence,
|
|
15
15
|
Type,
|
|
16
16
|
TypedDict,
|
|
17
17
|
Union,
|
|
@@ -20,7 +20,7 @@ from typing import (
|
|
|
20
20
|
|
|
21
21
|
import google.ai.generativelanguage as glm
|
|
22
22
|
import google.ai.generativelanguage_v1beta.types as gapic
|
|
23
|
-
import proto # type: ignore[import]
|
|
23
|
+
import proto # type: ignore[import-untyped]
|
|
24
24
|
from langchain_core.tools import BaseTool
|
|
25
25
|
from langchain_core.tools import tool as callable_as_lc_tool
|
|
26
26
|
from langchain_core.utils.function_calling import (
|
|
@@ -48,12 +48,7 @@ TYPE_ENUM = {
|
|
|
48
48
|
_ALLOWED_SCHEMA_FIELDS = []
|
|
49
49
|
_ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields])
|
|
50
50
|
_ALLOWED_SCHEMA_FIELDS.extend(
|
|
51
|
-
|
|
52
|
-
f
|
|
53
|
-
for f in gapic.Schema.to_dict(
|
|
54
|
-
gapic.Schema(), preserving_proto_field_name=False
|
|
55
|
-
).keys()
|
|
56
|
-
]
|
|
51
|
+
list(gapic.Schema.to_dict(gapic.Schema(), preserving_proto_field_name=False).keys())
|
|
57
52
|
)
|
|
58
53
|
_ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)
|
|
59
54
|
|
|
@@ -89,7 +84,7 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
89
84
|
for key, value in schema.items():
|
|
90
85
|
if key == "definitions":
|
|
91
86
|
continue
|
|
92
|
-
|
|
87
|
+
if key == "items":
|
|
93
88
|
converted_schema["items"] = _format_json_schema_to_gapic(value)
|
|
94
89
|
elif key == "properties":
|
|
95
90
|
converted_schema["properties"] = _get_properties_from_schema(value)
|
|
@@ -142,10 +137,11 @@ def convert_to_genai_function_declarations(
|
|
|
142
137
|
gapic_tool = gapic.Tool()
|
|
143
138
|
for tool in tools:
|
|
144
139
|
if any(f in gapic_tool for f in ["google_search_retrieval"]):
|
|
145
|
-
|
|
140
|
+
msg = (
|
|
146
141
|
"Providing multiple google_search_retrieval"
|
|
147
142
|
" or mixing with function_declarations is not supported"
|
|
148
143
|
)
|
|
144
|
+
raise ValueError(msg)
|
|
149
145
|
if isinstance(tool, (gapic.Tool)):
|
|
150
146
|
rt: gapic.Tool = (
|
|
151
147
|
tool if isinstance(tool, gapic.Tool) else tool._raw_tool # type: ignore
|
|
@@ -171,16 +167,17 @@ def convert_to_genai_function_declarations(
|
|
|
171
167
|
gapic_tool.function_declarations.append(fd)
|
|
172
168
|
continue
|
|
173
169
|
# _ToolDictLike
|
|
174
|
-
tool = cast(_ToolDict, tool)
|
|
170
|
+
tool = cast("_ToolDict", tool)
|
|
175
171
|
if "function_declarations" in tool:
|
|
176
172
|
function_declarations = tool["function_declarations"]
|
|
177
173
|
if not isinstance(
|
|
178
174
|
tool["function_declarations"], collections.abc.Sequence
|
|
179
175
|
):
|
|
180
|
-
|
|
176
|
+
msg = (
|
|
181
177
|
"function_declarations should be a list"
|
|
182
178
|
f"got '{type(function_declarations)}'"
|
|
183
179
|
)
|
|
180
|
+
raise ValueError(msg)
|
|
184
181
|
if function_declarations:
|
|
185
182
|
fds = [
|
|
186
183
|
_format_to_gapic_function_declaration(fd)
|
|
@@ -198,7 +195,7 @@ def convert_to_genai_function_declarations(
|
|
|
198
195
|
if "code_execution" in tool:
|
|
199
196
|
gapic_tool.code_execution = gapic.CodeExecution(tool["code_execution"])
|
|
200
197
|
else:
|
|
201
|
-
fd = _format_to_gapic_function_declaration(tool)
|
|
198
|
+
fd = _format_to_gapic_function_declaration(tool)
|
|
202
199
|
gapic_tool.function_declarations.append(fd)
|
|
203
200
|
return gapic_tool
|
|
204
201
|
|
|
@@ -221,30 +218,32 @@ def _format_to_gapic_function_declaration(
|
|
|
221
218
|
) -> gapic.FunctionDeclaration:
|
|
222
219
|
if isinstance(tool, BaseTool):
|
|
223
220
|
return _format_base_tool_to_function_declaration(tool)
|
|
224
|
-
|
|
221
|
+
if isinstance(tool, type) and is_basemodel_subclass_safe(tool):
|
|
225
222
|
return _convert_pydantic_to_genai_function(tool)
|
|
226
|
-
|
|
223
|
+
if isinstance(tool, dict):
|
|
227
224
|
if all(k in tool for k in ("type", "function")) and tool["type"] == "function":
|
|
228
225
|
function = tool["function"]
|
|
229
226
|
elif (
|
|
230
227
|
all(k in tool for k in ("name", "description")) and "parameters" not in tool
|
|
231
228
|
):
|
|
232
|
-
function = cast(dict, tool)
|
|
229
|
+
function = cast("dict", tool)
|
|
230
|
+
elif (
|
|
231
|
+
"parameters" in tool and tool["parameters"].get("properties") # type: ignore[index]
|
|
232
|
+
):
|
|
233
|
+
function = convert_to_openai_tool(cast("dict", tool))["function"]
|
|
233
234
|
else:
|
|
234
|
-
|
|
235
|
-
"parameters" in tool and tool["parameters"].get("properties") # type: ignore[index]
|
|
236
|
-
):
|
|
237
|
-
function = convert_to_openai_tool(cast(dict, tool))["function"]
|
|
238
|
-
else:
|
|
239
|
-
function = cast(dict, tool)
|
|
235
|
+
function = cast("dict", tool)
|
|
240
236
|
function["parameters"] = function.get("parameters") or {}
|
|
241
237
|
# Empty 'properties' field not supported.
|
|
242
238
|
if not function["parameters"].get("properties"):
|
|
243
239
|
function["parameters"] = {}
|
|
244
|
-
return _format_dict_to_function_declaration(
|
|
245
|
-
|
|
240
|
+
return _format_dict_to_function_declaration(
|
|
241
|
+
cast("FunctionDescription", function)
|
|
242
|
+
)
|
|
243
|
+
if callable(tool):
|
|
246
244
|
return _format_base_tool_to_function_declaration(callable_as_lc_tool()(tool))
|
|
247
|
-
|
|
245
|
+
msg = f"Unsupported tool type {tool}"
|
|
246
|
+
raise ValueError(msg)
|
|
248
247
|
|
|
249
248
|
|
|
250
249
|
def _format_base_tool_to_function_declaration(
|
|
@@ -270,10 +269,11 @@ def _format_base_tool_to_function_declaration(
|
|
|
270
269
|
elif issubclass(tool.args_schema, BaseModelV1):
|
|
271
270
|
schema = tool.args_schema.schema()
|
|
272
271
|
else:
|
|
273
|
-
|
|
272
|
+
msg = (
|
|
274
273
|
"args_schema must be a Pydantic BaseModel or JSON schema, "
|
|
275
274
|
f"got {tool.args_schema}."
|
|
276
275
|
)
|
|
276
|
+
raise NotImplementedError(msg)
|
|
277
277
|
parameters = _dict_to_gapic_schema(schema)
|
|
278
278
|
|
|
279
279
|
return gapic.FunctionDeclaration(
|
|
@@ -293,12 +293,11 @@ def _convert_pydantic_to_genai_function(
|
|
|
293
293
|
elif issubclass(pydantic_model, BaseModelV1):
|
|
294
294
|
schema = pydantic_model.schema()
|
|
295
295
|
else:
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
)
|
|
296
|
+
msg = f"pydantic_model must be a Pydantic BaseModel, got {pydantic_model}"
|
|
297
|
+
raise NotImplementedError(msg)
|
|
299
298
|
schema = dereference_refs(schema)
|
|
300
299
|
schema.pop("definitions", None)
|
|
301
|
-
|
|
300
|
+
return gapic.FunctionDeclaration(
|
|
302
301
|
name=tool_name if tool_name else schema.get("title"),
|
|
303
302
|
description=tool_description if tool_description else schema.get("description"),
|
|
304
303
|
parameters={
|
|
@@ -312,7 +311,6 @@ def _convert_pydantic_to_genai_function(
|
|
|
312
311
|
"type_": TYPE_ENUM[schema["type"]],
|
|
313
312
|
},
|
|
314
313
|
)
|
|
315
|
-
return function_declaration
|
|
316
314
|
|
|
317
315
|
|
|
318
316
|
def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]:
|
|
@@ -332,8 +330,10 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
|
|
|
332
330
|
continue
|
|
333
331
|
properties_item: Dict[str, Union[str, int, Dict, List]] = {}
|
|
334
332
|
|
|
335
|
-
#
|
|
336
|
-
|
|
333
|
+
# Preserve description and other schema properties before manipulation
|
|
334
|
+
original_description = v.get("description")
|
|
335
|
+
original_enum = v.get("enum")
|
|
336
|
+
original_items = v.get("items")
|
|
337
337
|
|
|
338
338
|
if v.get("anyOf") and all(
|
|
339
339
|
anyOf_type.get("type") != "null" for anyOf_type in v.get("anyOf", [])
|
|
@@ -356,11 +356,34 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
|
|
|
356
356
|
if any_of_types and item_type_ in [glm.Type.ARRAY, glm.Type.OBJECT]:
|
|
357
357
|
json_type_ = "array" if item_type_ == glm.Type.ARRAY else "object"
|
|
358
358
|
# Use Index -1 for consistency with `_get_nullable_type_from_schema`
|
|
359
|
-
|
|
359
|
+
filtered_schema = [
|
|
360
|
+
val for val in any_of_types if val.get("type") == json_type_
|
|
361
|
+
][-1]
|
|
362
|
+
# Merge filtered schema with original properties to preserve enum/items
|
|
363
|
+
v = filtered_schema.copy()
|
|
364
|
+
if original_enum and not v.get("enum"):
|
|
365
|
+
v["enum"] = original_enum
|
|
366
|
+
if original_items and not v.get("items"):
|
|
367
|
+
v["items"] = original_items
|
|
368
|
+
elif any_of_types:
|
|
369
|
+
# For other types (like strings with enums), find the non-null schema
|
|
370
|
+
# and preserve enum/items from the original anyOf structure
|
|
371
|
+
non_null_schemas = [
|
|
372
|
+
val for val in any_of_types if val.get("type") != "null"
|
|
373
|
+
]
|
|
374
|
+
if non_null_schemas:
|
|
375
|
+
filtered_schema = non_null_schemas[-1]
|
|
376
|
+
v = filtered_schema.copy()
|
|
377
|
+
if original_enum and not v.get("enum"):
|
|
378
|
+
v["enum"] = original_enum
|
|
379
|
+
if original_items and not v.get("items"):
|
|
380
|
+
v["items"] = original_items
|
|
360
381
|
|
|
361
382
|
if v.get("enum"):
|
|
362
383
|
properties_item["enum"] = v["enum"]
|
|
363
384
|
|
|
385
|
+
# Prefer description from the filtered schema, fall back to original
|
|
386
|
+
description = v.get("description") or original_description
|
|
364
387
|
if description and isinstance(description, str):
|
|
365
388
|
properties_item["description"] = description
|
|
366
389
|
|
|
@@ -417,6 +440,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
|
|
|
417
440
|
items["description"] = (
|
|
418
441
|
schema.get("description") or schema.get("title") or ""
|
|
419
442
|
)
|
|
443
|
+
if "enum" in schema:
|
|
444
|
+
items["enum"] = schema["enum"]
|
|
420
445
|
if _is_nullable_schema(schema):
|
|
421
446
|
items["nullable"] = True
|
|
422
447
|
if "required" in schema:
|
|
@@ -444,8 +469,6 @@ def _get_nullable_type_from_schema(schema: Dict[str, Any]) -> Optional[int]:
|
|
|
444
469
|
types = [t for t in types if t is not None] # Remove None values
|
|
445
470
|
if types:
|
|
446
471
|
return types[-1] # TODO: update FunctionDeclaration and pass all types?
|
|
447
|
-
else:
|
|
448
|
-
pass
|
|
449
472
|
elif "type" in schema or "type_" in schema:
|
|
450
473
|
type_ = schema["type"] if "type" in schema else schema["type_"]
|
|
451
474
|
if isinstance(type_, int):
|
|
@@ -463,20 +486,16 @@ def _is_nullable_schema(schema: Dict[str, Any]) -> bool:
|
|
|
463
486
|
_get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]
|
|
464
487
|
]
|
|
465
488
|
return any(t is None for t in types)
|
|
466
|
-
|
|
489
|
+
if "type" in schema or "type_" in schema:
|
|
467
490
|
type_ = schema["type"] if "type" in schema else schema["type_"]
|
|
468
491
|
if isinstance(type_, int):
|
|
469
492
|
return False
|
|
470
493
|
stype = str(schema["type"]) if "type" in schema else str(schema["type_"])
|
|
471
494
|
return TYPE_ENUM.get(stype, glm.Type.STRING) is None
|
|
472
|
-
else:
|
|
473
|
-
pass
|
|
474
495
|
return False
|
|
475
496
|
|
|
476
497
|
|
|
477
|
-
_ToolChoiceType = Union[
|
|
478
|
-
dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
|
|
479
|
-
]
|
|
498
|
+
_ToolChoiceType = Union[Literal["auto", "none", "any", True], dict, List[str], str]
|
|
480
499
|
|
|
481
500
|
|
|
482
501
|
class _FunctionCallingConfigDict(TypedDict):
|
|
@@ -516,12 +535,14 @@ def _tool_choice_to_tool_config(
|
|
|
516
535
|
"allowed_function_names"
|
|
517
536
|
)
|
|
518
537
|
else:
|
|
519
|
-
|
|
538
|
+
msg = (
|
|
520
539
|
f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
|
|
521
540
|
f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
|
|
522
541
|
)
|
|
542
|
+
raise ValueError(msg)
|
|
523
543
|
else:
|
|
524
|
-
|
|
544
|
+
msg = f"Unrecognized tool choice format:\n\n{tool_choice=}"
|
|
545
|
+
raise ValueError(msg)
|
|
525
546
|
return _ToolConfigDict(
|
|
526
547
|
function_calling_config={
|
|
527
548
|
"mode": mode.upper(),
|
|
@@ -533,12 +554,11 @@ def _tool_choice_to_tool_config(
|
|
|
533
554
|
def is_basemodel_subclass_safe(tool: Type) -> bool:
|
|
534
555
|
if safe_import("langchain_core.utils.pydantic", "is_basemodel_subclass"):
|
|
535
556
|
from langchain_core.utils.pydantic import (
|
|
536
|
-
is_basemodel_subclass,
|
|
557
|
+
is_basemodel_subclass,
|
|
537
558
|
)
|
|
538
559
|
|
|
539
560
|
return is_basemodel_subclass(tool)
|
|
540
|
-
|
|
541
|
-
return issubclass(tool, BaseModel)
|
|
561
|
+
return issubclass(tool, BaseModel)
|
|
542
562
|
|
|
543
563
|
|
|
544
564
|
def safe_import(module_name: str, attribute_name: str = "") -> bool:
|
|
@@ -562,7 +582,6 @@ def replace_defs_in_schema(original_schema: dict, defs: Optional[dict] = None) -
|
|
|
562
582
|
Returns:
|
|
563
583
|
Schema with refs replaced.
|
|
564
584
|
"""
|
|
565
|
-
|
|
566
585
|
new_defs = defs or original_schema.get("$defs")
|
|
567
586
|
|
|
568
587
|
if new_defs is None or not isinstance(new_defs, dict):
|
|
@@ -576,20 +595,19 @@ def replace_defs_in_schema(original_schema: dict, defs: Optional[dict] = None) -
|
|
|
576
595
|
|
|
577
596
|
if not isinstance(value, dict):
|
|
578
597
|
resulting_schema[key] = value
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
new_value = value.copy()
|
|
598
|
+
elif "$ref" in value:
|
|
599
|
+
new_value = value.copy()
|
|
582
600
|
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
601
|
+
path = new_value.pop("$ref")
|
|
602
|
+
def_key = _get_def_key_from_schema_path(path)
|
|
603
|
+
new_item = new_defs.get(def_key)
|
|
586
604
|
|
|
587
|
-
|
|
588
|
-
|
|
605
|
+
assert isinstance(new_item, dict)
|
|
606
|
+
new_value.update(new_item)
|
|
589
607
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
608
|
+
resulting_schema[key] = replace_defs_in_schema(new_value, defs=new_defs)
|
|
609
|
+
else:
|
|
610
|
+
resulting_schema[key] = replace_defs_in_schema(value, defs=new_defs)
|
|
593
611
|
|
|
594
612
|
return resulting_schema
|
|
595
613
|
|
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
"""Temporary high-level library of the Google GenerativeAI API.
|
|
2
2
|
|
|
3
|
-
The content of this file should eventually go into the Python package
|
|
4
|
-
google.generativeai.
|
|
3
|
+
(The content of this file should eventually go into the Python package
|
|
4
|
+
google.generativeai.)
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import datetime
|
|
8
8
|
import logging
|
|
9
9
|
import re
|
|
10
|
+
from collections.abc import Iterator, MutableSequence
|
|
10
11
|
from dataclasses import dataclass
|
|
11
|
-
from typing import Any, Dict,
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
12
13
|
|
|
13
14
|
import google.ai.generativelanguage as genai
|
|
14
15
|
import langchain_core
|
|
@@ -21,7 +22,7 @@ from google.ai.generativelanguage_v1beta import (
|
|
|
21
22
|
from google.api_core import client_options as client_options_lib
|
|
22
23
|
from google.api_core import exceptions as gapi_exception
|
|
23
24
|
from google.api_core import gapic_v1
|
|
24
|
-
from google.auth import credentials, exceptions
|
|
25
|
+
from google.auth import credentials, exceptions
|
|
25
26
|
from google.protobuf import timestamp_pb2
|
|
26
27
|
|
|
27
28
|
_logger = logging.getLogger(__name__)
|
|
@@ -41,13 +42,15 @@ class EntityName:
|
|
|
41
42
|
|
|
42
43
|
def __post_init__(self) -> None:
|
|
43
44
|
if self.chunk_id is not None and self.document_id is None:
|
|
44
|
-
|
|
45
|
+
msg = f"Chunk must have document ID but found {self}"
|
|
46
|
+
raise ValueError(msg)
|
|
45
47
|
|
|
46
48
|
@classmethod
|
|
47
49
|
def from_str(cls, encoded: str) -> "EntityName":
|
|
48
50
|
matched = _NAME_REGEX.match(encoded)
|
|
49
51
|
if not matched:
|
|
50
|
-
|
|
52
|
+
msg = f"Invalid entity name: {encoded}"
|
|
53
|
+
raise ValueError(msg)
|
|
51
54
|
|
|
52
55
|
return cls(
|
|
53
56
|
corpus_id=matched.group(1),
|
|
@@ -186,7 +189,9 @@ class TestCredentials(credentials.Credentials):
|
|
|
186
189
|
"""Raises :class:``InvalidOperation``, test credentials cannot be
|
|
187
190
|
refreshed.
|
|
188
191
|
"""
|
|
189
|
-
|
|
192
|
+
msg = "Test credentials cannot be refreshed."
|
|
193
|
+
# TODO: remove ignore once google-auth has types.
|
|
194
|
+
raise exceptions.InvalidOperation(msg) # type: ignore[no-untyped-call]
|
|
190
195
|
|
|
191
196
|
def apply(self, headers: Any, token: Any = None) -> None:
|
|
192
197
|
"""Anonymous credentials do nothing to the request.
|
|
@@ -197,7 +202,9 @@ class TestCredentials(credentials.Credentials):
|
|
|
197
202
|
google.auth.exceptions.InvalidValue: If a token was specified.
|
|
198
203
|
"""
|
|
199
204
|
if token is not None:
|
|
200
|
-
|
|
205
|
+
msg = "Test credentials don't support tokens."
|
|
206
|
+
# TODO: remove ignore once google-auth has types.
|
|
207
|
+
raise exceptions.InvalidValue(msg) # type: ignore[no-untyped-call]
|
|
201
208
|
|
|
202
209
|
def before_request(self, request: Any, method: Any, url: Any, headers: Any) -> None:
|
|
203
210
|
"""Test credentials do nothing to the request."""
|
|
@@ -214,8 +221,9 @@ def _get_credentials() -> Optional[credentials.Credentials]:
|
|
|
214
221
|
inferred by the rules specified in google.auth package.
|
|
215
222
|
"""
|
|
216
223
|
if _config.testing:
|
|
217
|
-
|
|
218
|
-
|
|
224
|
+
# TODO: remove ignore once google-auth has types.
|
|
225
|
+
return TestCredentials() # type: ignore[no-untyped-call]
|
|
226
|
+
if _config.auth_credentials:
|
|
219
227
|
return _config.auth_credentials
|
|
220
228
|
return None
|
|
221
229
|
|
|
@@ -224,7 +232,8 @@ def build_semantic_retriever() -> genai.RetrieverServiceClient:
|
|
|
224
232
|
credentials = _get_credentials()
|
|
225
233
|
return genai.RetrieverServiceClient(
|
|
226
234
|
credentials=credentials,
|
|
227
|
-
|
|
235
|
+
# TODO: remove ignore once google-auth has types.
|
|
236
|
+
client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), # type: ignore[no-untyped-call]
|
|
228
237
|
client_options=client_options_lib.ClientOptions(
|
|
229
238
|
api_endpoint=_config.api_endpoint
|
|
230
239
|
),
|
|
@@ -248,7 +257,8 @@ def _prepare_config(
|
|
|
248
257
|
client_info = (
|
|
249
258
|
client_info
|
|
250
259
|
if client_info
|
|
251
|
-
|
|
260
|
+
# TODO: remove ignore once google-auth has types.
|
|
261
|
+
else gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT) # type: ignore[no-untyped-call]
|
|
252
262
|
)
|
|
253
263
|
config = {
|
|
254
264
|
"credentials": credentials,
|
|
@@ -328,10 +338,7 @@ def create_corpus(
|
|
|
328
338
|
client: genai.RetrieverServiceClient,
|
|
329
339
|
) -> Corpus:
|
|
330
340
|
name: Optional[str]
|
|
331
|
-
if corpus_id is not None
|
|
332
|
-
name = str(EntityName(corpus_id=corpus_id))
|
|
333
|
-
else:
|
|
334
|
-
name = None
|
|
341
|
+
name = str(EntityName(corpus_id=corpus_id)) if corpus_id is not None else None
|
|
335
342
|
|
|
336
343
|
new_display_name = display_name or f"Untitled {datetime.datetime.now()}"
|
|
337
344
|
|
|
@@ -441,10 +448,11 @@ def batch_create_chunk(
|
|
|
441
448
|
if metadatas is None:
|
|
442
449
|
metadatas = [{} for _ in texts]
|
|
443
450
|
if len(texts) != len(metadatas):
|
|
444
|
-
|
|
451
|
+
msg = (
|
|
445
452
|
f"metadatas's length {len(metadatas)} "
|
|
446
453
|
f"and texts's length {len(texts)} are mismatched"
|
|
447
454
|
)
|
|
455
|
+
raise ValueError(msg)
|
|
448
456
|
|
|
449
457
|
doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id))
|
|
450
458
|
|
|
@@ -571,12 +579,14 @@ def generate_answer(
|
|
|
571
579
|
prompt: str,
|
|
572
580
|
passages: List[str],
|
|
573
581
|
answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE,
|
|
574
|
-
safety_settings: List[genai.SafetySetting] =
|
|
582
|
+
safety_settings: Optional[List[genai.SafetySetting]] = None,
|
|
575
583
|
temperature: Optional[float] = None,
|
|
576
584
|
client: genai.GenerativeServiceClient,
|
|
577
585
|
) -> GroundedAnswer:
|
|
578
586
|
# TODO: Consider passing in the corpus ID instead of the actual
|
|
579
587
|
# passages.
|
|
588
|
+
if safety_settings is None:
|
|
589
|
+
safety_settings = []
|
|
580
590
|
response = client.generate_answer(
|
|
581
591
|
genai.GenerateAnswerRequest(
|
|
582
592
|
contents=[
|
|
@@ -622,20 +632,41 @@ def generate_answer(
|
|
|
622
632
|
)
|
|
623
633
|
|
|
624
634
|
|
|
625
|
-
# TODO: Use candidate.finish_message when that field is launched.
|
|
626
|
-
# For now, we derive this message from other existing fields.
|
|
627
635
|
def _get_finish_message(candidate: genai.Candidate) -> str:
|
|
636
|
+
"""Get a human-readable finish message from the candidate.
|
|
637
|
+
|
|
638
|
+
Uses the official finish_message field if available, otherwise falls back
|
|
639
|
+
to a manual mapping of finish reasons to descriptive messages.
|
|
640
|
+
"""
|
|
641
|
+
# Use the official field when available
|
|
642
|
+
if hasattr(candidate, "finish_message") and candidate.finish_message:
|
|
643
|
+
return candidate.finish_message
|
|
644
|
+
|
|
645
|
+
# Fallback to manual mapping for all known finish reasons
|
|
628
646
|
finish_messages: Dict[int, str] = {
|
|
629
|
-
genai.Candidate.FinishReason.
|
|
647
|
+
genai.Candidate.FinishReason.STOP: "Generation completed successfully",
|
|
648
|
+
genai.Candidate.FinishReason.MAX_TOKENS: (
|
|
649
|
+
"Maximum token in context window reached"
|
|
650
|
+
),
|
|
630
651
|
genai.Candidate.FinishReason.SAFETY: "Blocked because of safety",
|
|
631
652
|
genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation",
|
|
653
|
+
genai.Candidate.FinishReason.LANGUAGE: "Unsupported language detected",
|
|
654
|
+
genai.Candidate.FinishReason.BLOCKLIST: "Content hit forbidden terms",
|
|
655
|
+
genai.Candidate.FinishReason.PROHIBITED_CONTENT: (
|
|
656
|
+
"Inappropriate content detected"
|
|
657
|
+
),
|
|
658
|
+
genai.Candidate.FinishReason.SPII: "Sensitive personal information detected",
|
|
659
|
+
genai.Candidate.FinishReason.IMAGE_SAFETY: "Image safety violation",
|
|
660
|
+
genai.Candidate.FinishReason.MALFORMED_FUNCTION_CALL: "Malformed function call",
|
|
661
|
+
genai.Candidate.FinishReason.UNEXPECTED_TOOL_CALL: "Unexpected tool call",
|
|
662
|
+
genai.Candidate.FinishReason.OTHER: "Other generation issue",
|
|
663
|
+
genai.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED: (
|
|
664
|
+
"Unspecified finish reason"
|
|
665
|
+
),
|
|
632
666
|
}
|
|
633
667
|
|
|
634
668
|
finish_reason = candidate.finish_reason
|
|
635
|
-
|
|
636
|
-
return "Unexpected generation error"
|
|
637
|
-
|
|
638
|
-
return finish_messages[finish_reason]
|
|
669
|
+
return finish_messages.get(finish_reason, "Unexpected generation error")
|
|
639
670
|
|
|
640
671
|
|
|
641
672
|
def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]:
|
|
@@ -646,7 +677,8 @@ def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]
|
|
|
646
677
|
elif isinstance(value, (float, int)):
|
|
647
678
|
c = genai.CustomMetadata(key=key, numeric_value=value)
|
|
648
679
|
else:
|
|
649
|
-
|
|
680
|
+
msg = f"Metadata value {value} is not supported"
|
|
681
|
+
raise ValueError(msg)
|
|
650
682
|
|
|
651
683
|
cs.append(c)
|
|
652
684
|
return cs
|
|
@@ -668,7 +700,8 @@ def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[genai.MetadataFilter]:
|
|
|
668
700
|
operation=genai.Condition.Operator.EQUAL, numeric_value=value
|
|
669
701
|
)
|
|
670
702
|
else:
|
|
671
|
-
|
|
703
|
+
msg = f"Filter value {value} is not supported"
|
|
704
|
+
raise ValueError(msg)
|
|
672
705
|
|
|
673
706
|
filters.append(genai.MetadataFilter(key=key, conditions=[condition]))
|
|
674
707
|
|
|
@@ -8,13 +8,13 @@ from enum import Enum
|
|
|
8
8
|
from typing import Any, Dict
|
|
9
9
|
from urllib.parse import urlparse
|
|
10
10
|
|
|
11
|
-
import filetype # type: ignore[import]
|
|
11
|
+
import filetype # type: ignore[import-untyped]
|
|
12
12
|
import requests
|
|
13
13
|
from google.ai.generativelanguage_v1beta.types import Part
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class Route(Enum):
|
|
17
|
-
"""Image Loading Route"""
|
|
17
|
+
"""Image Loading Route."""
|
|
18
18
|
|
|
19
19
|
BASE64 = 1
|
|
20
20
|
LOCAL_FILE = 2
|
|
@@ -40,7 +40,6 @@ class ImageBytesLoader:
|
|
|
40
40
|
Returns:
|
|
41
41
|
Image bytes.
|
|
42
42
|
"""
|
|
43
|
-
|
|
44
43
|
route = self._route(image_string)
|
|
45
44
|
|
|
46
45
|
if route == Route.BASE64:
|
|
@@ -50,18 +49,20 @@ class ImageBytesLoader:
|
|
|
50
49
|
return self._bytes_from_url(image_string)
|
|
51
50
|
|
|
52
51
|
if route == Route.LOCAL_FILE:
|
|
53
|
-
|
|
52
|
+
msg = (
|
|
54
53
|
"Loading from local files is no longer supported for security reasons. "
|
|
55
54
|
"Please pass in images as Google Cloud Storage URI, "
|
|
56
55
|
"b64 encoded image string (data:image/...), or valid image url."
|
|
57
56
|
)
|
|
57
|
+
raise ValueError(msg)
|
|
58
58
|
return self._bytes_from_file(image_string)
|
|
59
59
|
|
|
60
|
-
|
|
60
|
+
msg = (
|
|
61
61
|
"Image string must be one of: Google Cloud Storage URI, "
|
|
62
62
|
"b64 encoded image string (data:image/...), or valid image url."
|
|
63
63
|
f"Instead got '{image_string}'."
|
|
64
64
|
)
|
|
65
|
+
raise ValueError(msg)
|
|
65
66
|
|
|
66
67
|
def load_part(self, image_string: str) -> Part:
|
|
67
68
|
"""Gets Part for loading from Gemini.
|
|
@@ -110,11 +111,12 @@ class ImageBytesLoader:
|
|
|
110
111
|
if os.path.exists(image_string):
|
|
111
112
|
return Route.LOCAL_FILE
|
|
112
113
|
|
|
113
|
-
|
|
114
|
+
msg = (
|
|
114
115
|
"Image string must be one of: "
|
|
115
116
|
"b64 encoded image string (data:image/...) or valid image url."
|
|
116
117
|
f" Instead got '{image_string}'."
|
|
117
118
|
)
|
|
119
|
+
raise ValueError(msg)
|
|
118
120
|
|
|
119
121
|
def _bytes_from_b64(self, base64_image: str) -> bytes:
|
|
120
122
|
"""Gets image bytes from a base64 encoded string.
|
|
@@ -125,7 +127,6 @@ class ImageBytesLoader:
|
|
|
125
127
|
Returns:
|
|
126
128
|
Image bytes
|
|
127
129
|
"""
|
|
128
|
-
|
|
129
130
|
pattern = r"data:image/\w{2,4};base64,(.*)"
|
|
130
131
|
match = re.search(pattern, base64_image)
|
|
131
132
|
|
|
@@ -133,7 +134,8 @@ class ImageBytesLoader:
|
|
|
133
134
|
encoded_string = match.group(1)
|
|
134
135
|
return base64.b64decode(encoded_string)
|
|
135
136
|
|
|
136
|
-
|
|
137
|
+
msg = f"Error in b64 encoded image. Must follow pattern: {pattern}"
|
|
138
|
+
raise ValueError(msg)
|
|
137
139
|
|
|
138
140
|
def _bytes_from_url(self, url: str) -> bytes:
|
|
139
141
|
"""Gets image bytes from a public url.
|
|
@@ -147,7 +149,6 @@ class ImageBytesLoader:
|
|
|
147
149
|
Returns:
|
|
148
150
|
Image bytes
|
|
149
151
|
"""
|
|
150
|
-
|
|
151
152
|
response = requests.get(url)
|
|
152
153
|
|
|
153
154
|
if not response.ok:
|