langchain-google-genai 2.1.11__py3-none-any.whl → 2.1.12__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic. See this release's page on the package registry for more details.

@@ -1,4 +1,4 @@
1
- """**LangChain Google Generative AI Integration**
1
+ """**LangChain Google Generative AI Integration**.
2
2
 
3
3
  This module integrates Google's Generative AI models, specifically the Gemini series, with the LangChain framework. It provides classes for interacting with chat models and generating embeddings, leveraging Google's advanced AI capabilities.
4
4
 
@@ -76,12 +76,12 @@ __all__ = [
76
76
  "AqaOutput",
77
77
  "ChatGoogleGenerativeAI",
78
78
  "DoesNotExistsException",
79
+ "DoesNotExistsException",
79
80
  "GenAIAqa",
80
- "GoogleGenerativeAIEmbeddings",
81
81
  "GoogleGenerativeAI",
82
+ "GoogleGenerativeAIEmbeddings",
82
83
  "GoogleVectorStore",
83
84
  "HarmBlockThreshold",
84
85
  "HarmCategory",
85
86
  "Modality",
86
- "DoesNotExistsException",
87
87
  ]
@@ -13,19 +13,17 @@ _TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID"
13
13
 
14
14
 
15
15
  class GoogleGenerativeAIError(Exception):
16
- """
17
- Custom exception class for errors associated with the `Google GenAI` API.
18
- """
16
+ """Custom exception class for errors associated with the `Google GenAI` API."""
19
17
 
20
18
 
21
19
  class _BaseGoogleGenerativeAI(BaseModel):
22
- """Base class for Google Generative AI LLMs"""
20
+ """Base class for Google Generative AI LLMs."""
23
21
 
24
22
  model: str = Field(
25
23
  ...,
26
24
  description="""The name of the model to use.
27
25
  Examples:
28
- - gemini-2.5-pro
26
+ - gemini-2.5-flash
29
27
  - models/text-bison-001""",
30
28
  )
31
29
  """Model name to use."""
@@ -34,28 +32,37 @@ Examples:
34
32
  )
35
33
  """Google AI API key.
36
34
  If not specified will be read from env var ``GOOGLE_API_KEY``."""
35
+
37
36
  credentials: Any = None
38
37
  "The default custom credentials (google.auth.credentials.Credentials) to use "
39
38
  "when making API calls. If not provided, credentials will be ascertained from "
40
39
  "the GOOGLE_API_KEY envvar"
40
+
41
41
  temperature: float = 0.7
42
- """Run inference with this temperature. Must be within ``[0.0, 2.0]``."""
42
+ """Run inference with this temperature. Must be within ``[0.0, 2.0]``. If unset,
43
+ will default to ``0.7``."""
44
+
43
45
  top_p: Optional[float] = None
44
46
  """Decode using nucleus sampling: consider the smallest set of tokens whose
45
- probability sum is at least ``top_p``. Must be within ``[0.0, 1.0]``."""
47
+ probability sum is at least ``top_p``. Must be within ``[0.0, 1.0]``."""
48
+
46
49
  top_k: Optional[int] = None
47
50
  """Decode using top-k sampling: consider the set of ``top_k`` most probable tokens.
48
- Must be positive."""
51
+ Must be positive."""
52
+
49
53
  max_output_tokens: Optional[int] = Field(default=None, alias="max_tokens")
50
54
  """Maximum number of tokens to include in a candidate. Must be greater than zero.
51
- If unset, will default to ``64``."""
55
+ If unset, will default to ``64``."""
56
+
52
57
  n: int = 1
53
58
  """Number of chat completions to generate for each prompt. Note that the API may
54
- not return the full ``n`` completions if duplicates are generated."""
55
- max_retries: int = 6
56
- """The maximum number of retries to make when generating."""
59
+ not return the full ``n`` completions if duplicates are generated."""
60
+
61
+ max_retries: int = Field(default=6, alias="retries")
62
+ """The maximum number of retries to make when generating. If unset, will default
63
+ to ``6``."""
57
64
 
58
- timeout: Optional[float] = None
65
+ timeout: Optional[float] = Field(default=None, alias="request_timeout")
59
66
  """The maximum number of seconds to wait for a response."""
60
67
 
61
68
  client_options: Optional[Dict] = Field(
@@ -68,6 +75,7 @@ Examples:
68
75
  transport: Optional[str] = Field(
69
76
  default=None,
70
77
  description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
78
+ alias="api_transport",
71
79
  )
72
80
  additional_headers: Optional[Dict[str, str]] = Field(
73
81
  default=None,
@@ -89,9 +97,9 @@ Examples:
89
97
  )
90
98
 
91
99
  safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
92
- """The default safety settings to use for all generations.
93
-
94
- For example:
100
+ """The default safety settings to use for all generations.
101
+
102
+ For example:
95
103
 
96
104
  .. code-block:: python
97
105
  from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
@@ -127,6 +135,7 @@ def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]:
127
135
  Args:
128
136
  module (Optional[str]):
129
137
  Optional. The module for a custom user agent header.
138
+
130
139
  Returns:
131
140
  Tuple[str, str]
132
141
  """
@@ -148,11 +157,13 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
148
157
  Args:
149
158
  module (Optional[str]):
150
159
  Optional. The module for a custom user agent header.
160
+
151
161
  Returns:
152
162
  ``google.api_core.gapic_v1.client_info.ClientInfo``
153
163
  """
154
164
  client_library_version, user_agent = get_user_agent(module)
155
- return ClientInfo(
165
+ # TODO: remove ignore once google-auth has types.
166
+ return ClientInfo( # type: ignore[no-untyped-call]
156
167
  client_library_version=client_library_version,
157
168
  user_agent=user_agent,
158
169
  )
@@ -4,6 +4,7 @@ import collections
4
4
  import importlib
5
5
  import json
6
6
  import logging
7
+ from collections.abc import Sequence
7
8
  from typing import (
8
9
  Any,
9
10
  Callable,
@@ -11,7 +12,6 @@ from typing import (
11
12
  List,
12
13
  Literal,
13
14
  Optional,
14
- Sequence,
15
15
  Type,
16
16
  TypedDict,
17
17
  Union,
@@ -20,7 +20,7 @@ from typing import (
20
20
 
21
21
  import google.ai.generativelanguage as glm
22
22
  import google.ai.generativelanguage_v1beta.types as gapic
23
- import proto # type: ignore[import]
23
+ import proto # type: ignore[import-untyped]
24
24
  from langchain_core.tools import BaseTool
25
25
  from langchain_core.tools import tool as callable_as_lc_tool
26
26
  from langchain_core.utils.function_calling import (
@@ -48,12 +48,7 @@ TYPE_ENUM = {
48
48
  _ALLOWED_SCHEMA_FIELDS = []
49
49
  _ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields])
50
50
  _ALLOWED_SCHEMA_FIELDS.extend(
51
- [
52
- f
53
- for f in gapic.Schema.to_dict(
54
- gapic.Schema(), preserving_proto_field_name=False
55
- ).keys()
56
- ]
51
+ list(gapic.Schema.to_dict(gapic.Schema(), preserving_proto_field_name=False).keys())
57
52
  )
58
53
  _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS)
59
54
 
@@ -89,7 +84,7 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]:
89
84
  for key, value in schema.items():
90
85
  if key == "definitions":
91
86
  continue
92
- elif key == "items":
87
+ if key == "items":
93
88
  converted_schema["items"] = _format_json_schema_to_gapic(value)
94
89
  elif key == "properties":
95
90
  converted_schema["properties"] = _get_properties_from_schema(value)
@@ -142,10 +137,11 @@ def convert_to_genai_function_declarations(
142
137
  gapic_tool = gapic.Tool()
143
138
  for tool in tools:
144
139
  if any(f in gapic_tool for f in ["google_search_retrieval"]):
145
- raise ValueError(
140
+ msg = (
146
141
  "Providing multiple google_search_retrieval"
147
142
  " or mixing with function_declarations is not supported"
148
143
  )
144
+ raise ValueError(msg)
149
145
  if isinstance(tool, (gapic.Tool)):
150
146
  rt: gapic.Tool = (
151
147
  tool if isinstance(tool, gapic.Tool) else tool._raw_tool # type: ignore
@@ -171,16 +167,17 @@ def convert_to_genai_function_declarations(
171
167
  gapic_tool.function_declarations.append(fd)
172
168
  continue
173
169
  # _ToolDictLike
174
- tool = cast(_ToolDict, tool)
170
+ tool = cast("_ToolDict", tool)
175
171
  if "function_declarations" in tool:
176
172
  function_declarations = tool["function_declarations"]
177
173
  if not isinstance(
178
174
  tool["function_declarations"], collections.abc.Sequence
179
175
  ):
180
- raise ValueError(
176
+ msg = (
181
177
  "function_declarations should be a list"
182
178
  f"got '{type(function_declarations)}'"
183
179
  )
180
+ raise ValueError(msg)
184
181
  if function_declarations:
185
182
  fds = [
186
183
  _format_to_gapic_function_declaration(fd)
@@ -198,7 +195,7 @@ def convert_to_genai_function_declarations(
198
195
  if "code_execution" in tool:
199
196
  gapic_tool.code_execution = gapic.CodeExecution(tool["code_execution"])
200
197
  else:
201
- fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type]
198
+ fd = _format_to_gapic_function_declaration(tool)
202
199
  gapic_tool.function_declarations.append(fd)
203
200
  return gapic_tool
204
201
 
@@ -221,30 +218,32 @@ def _format_to_gapic_function_declaration(
221
218
  ) -> gapic.FunctionDeclaration:
222
219
  if isinstance(tool, BaseTool):
223
220
  return _format_base_tool_to_function_declaration(tool)
224
- elif isinstance(tool, type) and is_basemodel_subclass_safe(tool):
221
+ if isinstance(tool, type) and is_basemodel_subclass_safe(tool):
225
222
  return _convert_pydantic_to_genai_function(tool)
226
- elif isinstance(tool, dict):
223
+ if isinstance(tool, dict):
227
224
  if all(k in tool for k in ("type", "function")) and tool["type"] == "function":
228
225
  function = tool["function"]
229
226
  elif (
230
227
  all(k in tool for k in ("name", "description")) and "parameters" not in tool
231
228
  ):
232
- function = cast(dict, tool)
229
+ function = cast("dict", tool)
230
+ elif (
231
+ "parameters" in tool and tool["parameters"].get("properties") # type: ignore[index]
232
+ ):
233
+ function = convert_to_openai_tool(cast("dict", tool))["function"]
233
234
  else:
234
- if (
235
- "parameters" in tool and tool["parameters"].get("properties") # type: ignore[index]
236
- ):
237
- function = convert_to_openai_tool(cast(dict, tool))["function"]
238
- else:
239
- function = cast(dict, tool)
235
+ function = cast("dict", tool)
240
236
  function["parameters"] = function.get("parameters") or {}
241
237
  # Empty 'properties' field not supported.
242
238
  if not function["parameters"].get("properties"):
243
239
  function["parameters"] = {}
244
- return _format_dict_to_function_declaration(cast(FunctionDescription, function))
245
- elif callable(tool):
240
+ return _format_dict_to_function_declaration(
241
+ cast("FunctionDescription", function)
242
+ )
243
+ if callable(tool):
246
244
  return _format_base_tool_to_function_declaration(callable_as_lc_tool()(tool))
247
- raise ValueError(f"Unsupported tool type {tool}")
245
+ msg = f"Unsupported tool type {tool}"
246
+ raise ValueError(msg)
248
247
 
249
248
 
250
249
  def _format_base_tool_to_function_declaration(
@@ -270,10 +269,11 @@ def _format_base_tool_to_function_declaration(
270
269
  elif issubclass(tool.args_schema, BaseModelV1):
271
270
  schema = tool.args_schema.schema()
272
271
  else:
273
- raise NotImplementedError(
272
+ msg = (
274
273
  "args_schema must be a Pydantic BaseModel or JSON schema, "
275
274
  f"got {tool.args_schema}."
276
275
  )
276
+ raise NotImplementedError(msg)
277
277
  parameters = _dict_to_gapic_schema(schema)
278
278
 
279
279
  return gapic.FunctionDeclaration(
@@ -293,12 +293,11 @@ def _convert_pydantic_to_genai_function(
293
293
  elif issubclass(pydantic_model, BaseModelV1):
294
294
  schema = pydantic_model.schema()
295
295
  else:
296
- raise NotImplementedError(
297
- f"pydantic_model must be a Pydantic BaseModel, got {pydantic_model}"
298
- )
296
+ msg = f"pydantic_model must be a Pydantic BaseModel, got {pydantic_model}"
297
+ raise NotImplementedError(msg)
299
298
  schema = dereference_refs(schema)
300
299
  schema.pop("definitions", None)
301
- function_declaration = gapic.FunctionDeclaration(
300
+ return gapic.FunctionDeclaration(
302
301
  name=tool_name if tool_name else schema.get("title"),
303
302
  description=tool_description if tool_description else schema.get("description"),
304
303
  parameters={
@@ -312,7 +311,6 @@ def _convert_pydantic_to_genai_function(
312
311
  "type_": TYPE_ENUM[schema["type"]],
313
312
  },
314
313
  )
315
- return function_declaration
316
314
 
317
315
 
318
316
  def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]:
@@ -444,8 +442,6 @@ def _get_nullable_type_from_schema(schema: Dict[str, Any]) -> Optional[int]:
444
442
  types = [t for t in types if t is not None] # Remove None values
445
443
  if types:
446
444
  return types[-1] # TODO: update FunctionDeclaration and pass all types?
447
- else:
448
- pass
449
445
  elif "type" in schema or "type_" in schema:
450
446
  type_ = schema["type"] if "type" in schema else schema["type_"]
451
447
  if isinstance(type_, int):
@@ -463,20 +459,16 @@ def _is_nullable_schema(schema: Dict[str, Any]) -> bool:
463
459
  _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]
464
460
  ]
465
461
  return any(t is None for t in types)
466
- elif "type" in schema or "type_" in schema:
462
+ if "type" in schema or "type_" in schema:
467
463
  type_ = schema["type"] if "type" in schema else schema["type_"]
468
464
  if isinstance(type_, int):
469
465
  return False
470
466
  stype = str(schema["type"]) if "type" in schema else str(schema["type_"])
471
467
  return TYPE_ENUM.get(stype, glm.Type.STRING) is None
472
- else:
473
- pass
474
468
  return False
475
469
 
476
470
 
477
- _ToolChoiceType = Union[
478
- dict, List[str], str, Literal["auto", "none", "any"], Literal[True]
479
- ]
471
+ _ToolChoiceType = Union[Literal["auto", "none", "any", True], dict, List[str], str]
480
472
 
481
473
 
482
474
  class _FunctionCallingConfigDict(TypedDict):
@@ -516,12 +508,14 @@ def _tool_choice_to_tool_config(
516
508
  "allowed_function_names"
517
509
  )
518
510
  else:
519
- raise ValueError(
511
+ msg = (
520
512
  f"Unrecognized tool choice format:\n\n{tool_choice=}\n\nShould match "
521
513
  f"Google GenerativeAI ToolConfig or FunctionCallingConfig format."
522
514
  )
515
+ raise ValueError(msg)
523
516
  else:
524
- raise ValueError(f"Unrecognized tool choice format:\n\n{tool_choice=}")
517
+ msg = f"Unrecognized tool choice format:\n\n{tool_choice=}"
518
+ raise ValueError(msg)
525
519
  return _ToolConfigDict(
526
520
  function_calling_config={
527
521
  "mode": mode.upper(),
@@ -533,12 +527,11 @@ def _tool_choice_to_tool_config(
533
527
  def is_basemodel_subclass_safe(tool: Type) -> bool:
534
528
  if safe_import("langchain_core.utils.pydantic", "is_basemodel_subclass"):
535
529
  from langchain_core.utils.pydantic import (
536
- is_basemodel_subclass, # type: ignore[import]
530
+ is_basemodel_subclass,
537
531
  )
538
532
 
539
533
  return is_basemodel_subclass(tool)
540
- else:
541
- return issubclass(tool, BaseModel)
534
+ return issubclass(tool, BaseModel)
542
535
 
543
536
 
544
537
  def safe_import(module_name: str, attribute_name: str = "") -> bool:
@@ -562,7 +555,6 @@ def replace_defs_in_schema(original_schema: dict, defs: Optional[dict] = None) -
562
555
  Returns:
563
556
  Schema with refs replaced.
564
557
  """
565
-
566
558
  new_defs = defs or original_schema.get("$defs")
567
559
 
568
560
  if new_defs is None or not isinstance(new_defs, dict):
@@ -576,20 +568,19 @@ def replace_defs_in_schema(original_schema: dict, defs: Optional[dict] = None) -
576
568
 
577
569
  if not isinstance(value, dict):
578
570
  resulting_schema[key] = value
579
- else:
580
- if "$ref" in value:
581
- new_value = value.copy()
571
+ elif "$ref" in value:
572
+ new_value = value.copy()
582
573
 
583
- path = new_value.pop("$ref")
584
- def_key = _get_def_key_from_schema_path(path)
585
- new_item = new_defs.get(def_key)
574
+ path = new_value.pop("$ref")
575
+ def_key = _get_def_key_from_schema_path(path)
576
+ new_item = new_defs.get(def_key)
586
577
 
587
- assert isinstance(new_item, dict)
588
- new_value.update(new_item)
578
+ assert isinstance(new_item, dict)
579
+ new_value.update(new_item)
589
580
 
590
- resulting_schema[key] = replace_defs_in_schema(new_value, defs=new_defs)
591
- else:
592
- resulting_schema[key] = replace_defs_in_schema(value, defs=new_defs)
581
+ resulting_schema[key] = replace_defs_in_schema(new_value, defs=new_defs)
582
+ else:
583
+ resulting_schema[key] = replace_defs_in_schema(value, defs=new_defs)
593
584
 
594
585
  return resulting_schema
595
586
 
@@ -1,14 +1,15 @@
1
1
  """Temporary high-level library of the Google GenerativeAI API.
2
2
 
3
- The content of this file should eventually go into the Python package
4
- google.generativeai.
3
+ (The content of this file should eventually go into the Python package
4
+ google.generativeai.)
5
5
  """
6
6
 
7
7
  import datetime
8
8
  import logging
9
9
  import re
10
+ from collections.abc import Iterator, MutableSequence
10
11
  from dataclasses import dataclass
11
- from typing import Any, Dict, Iterator, List, MutableSequence, Optional
12
+ from typing import Any, Dict, List, Optional
12
13
 
13
14
  import google.ai.generativelanguage as genai
14
15
  import langchain_core
@@ -21,7 +22,7 @@ from google.ai.generativelanguage_v1beta import (
21
22
  from google.api_core import client_options as client_options_lib
22
23
  from google.api_core import exceptions as gapi_exception
23
24
  from google.api_core import gapic_v1
24
- from google.auth import credentials, exceptions # type: ignore
25
+ from google.auth import credentials, exceptions
25
26
  from google.protobuf import timestamp_pb2
26
27
 
27
28
  _logger = logging.getLogger(__name__)
@@ -41,13 +42,15 @@ class EntityName:
41
42
 
42
43
  def __post_init__(self) -> None:
43
44
  if self.chunk_id is not None and self.document_id is None:
44
- raise ValueError(f"Chunk must have document ID but found {self}")
45
+ msg = f"Chunk must have document ID but found {self}"
46
+ raise ValueError(msg)
45
47
 
46
48
  @classmethod
47
49
  def from_str(cls, encoded: str) -> "EntityName":
48
50
  matched = _NAME_REGEX.match(encoded)
49
51
  if not matched:
50
- raise ValueError(f"Invalid entity name: {encoded}")
52
+ msg = f"Invalid entity name: {encoded}"
53
+ raise ValueError(msg)
51
54
 
52
55
  return cls(
53
56
  corpus_id=matched.group(1),
@@ -186,7 +189,9 @@ class TestCredentials(credentials.Credentials):
186
189
  """Raises :class:``InvalidOperation``, test credentials cannot be
187
190
  refreshed.
188
191
  """
189
- raise exceptions.InvalidOperation("Test credentials cannot be refreshed.")
192
+ msg = "Test credentials cannot be refreshed."
193
+ # TODO: remove ignore once google-auth has types.
194
+ raise exceptions.InvalidOperation(msg) # type: ignore[no-untyped-call]
190
195
 
191
196
  def apply(self, headers: Any, token: Any = None) -> None:
192
197
  """Anonymous credentials do nothing to the request.
@@ -197,7 +202,9 @@ class TestCredentials(credentials.Credentials):
197
202
  google.auth.exceptions.InvalidValue: If a token was specified.
198
203
  """
199
204
  if token is not None:
200
- raise exceptions.InvalidValue("Test credentials don't support tokens.")
205
+ msg = "Test credentials don't support tokens."
206
+ # TODO: remove ignore once google-auth has types.
207
+ raise exceptions.InvalidValue(msg) # type: ignore[no-untyped-call]
201
208
 
202
209
  def before_request(self, request: Any, method: Any, url: Any, headers: Any) -> None:
203
210
  """Test credentials do nothing to the request."""
@@ -214,8 +221,9 @@ def _get_credentials() -> Optional[credentials.Credentials]:
214
221
  inferred by the rules specified in google.auth package.
215
222
  """
216
223
  if _config.testing:
217
- return TestCredentials()
218
- elif _config.auth_credentials:
224
+ # TODO: remove ignore once google-auth has types.
225
+ return TestCredentials() # type: ignore[no-untyped-call]
226
+ if _config.auth_credentials:
219
227
  return _config.auth_credentials
220
228
  return None
221
229
 
@@ -224,7 +232,8 @@ def build_semantic_retriever() -> genai.RetrieverServiceClient:
224
232
  credentials = _get_credentials()
225
233
  return genai.RetrieverServiceClient(
226
234
  credentials=credentials,
227
- client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT),
235
+ # TODO: remove ignore once google-auth has types.
236
+ client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), # type: ignore[no-untyped-call]
228
237
  client_options=client_options_lib.ClientOptions(
229
238
  api_endpoint=_config.api_endpoint
230
239
  ),
@@ -248,7 +257,8 @@ def _prepare_config(
248
257
  client_info = (
249
258
  client_info
250
259
  if client_info
251
- else gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT)
260
+ # TODO: remove ignore once google-auth has types.
261
+ else gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT) # type: ignore[no-untyped-call]
252
262
  )
253
263
  config = {
254
264
  "credentials": credentials,
@@ -328,10 +338,7 @@ def create_corpus(
328
338
  client: genai.RetrieverServiceClient,
329
339
  ) -> Corpus:
330
340
  name: Optional[str]
331
- if corpus_id is not None:
332
- name = str(EntityName(corpus_id=corpus_id))
333
- else:
334
- name = None
341
+ name = str(EntityName(corpus_id=corpus_id)) if corpus_id is not None else None
335
342
 
336
343
  new_display_name = display_name or f"Untitled {datetime.datetime.now()}"
337
344
 
@@ -441,10 +448,11 @@ def batch_create_chunk(
441
448
  if metadatas is None:
442
449
  metadatas = [{} for _ in texts]
443
450
  if len(texts) != len(metadatas):
444
- raise ValueError(
451
+ msg = (
445
452
  f"metadatas's length {len(metadatas)} "
446
453
  f"and texts's length {len(texts)} are mismatched"
447
454
  )
455
+ raise ValueError(msg)
448
456
 
449
457
  doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id))
450
458
 
@@ -571,12 +579,14 @@ def generate_answer(
571
579
  prompt: str,
572
580
  passages: List[str],
573
581
  answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE,
574
- safety_settings: List[genai.SafetySetting] = [],
582
+ safety_settings: Optional[List[genai.SafetySetting]] = None,
575
583
  temperature: Optional[float] = None,
576
584
  client: genai.GenerativeServiceClient,
577
585
  ) -> GroundedAnswer:
578
586
  # TODO: Consider passing in the corpus ID instead of the actual
579
587
  # passages.
588
+ if safety_settings is None:
589
+ safety_settings = []
580
590
  response = client.generate_answer(
581
591
  genai.GenerateAnswerRequest(
582
592
  contents=[
@@ -626,7 +636,9 @@ def generate_answer(
626
636
  # For now, we derive this message from other existing fields.
627
637
  def _get_finish_message(candidate: genai.Candidate) -> str:
628
638
  finish_messages: Dict[int, str] = {
629
- genai.Candidate.FinishReason.MAX_TOKENS: "Maximum token in context window reached", # noqa: E501
639
+ genai.Candidate.FinishReason.MAX_TOKENS: (
640
+ "Maximum token in context window reached"
641
+ ),
630
642
  genai.Candidate.FinishReason.SAFETY: "Blocked because of safety",
631
643
  genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation",
632
644
  }
@@ -646,7 +658,8 @@ def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]
646
658
  elif isinstance(value, (float, int)):
647
659
  c = genai.CustomMetadata(key=key, numeric_value=value)
648
660
  else:
649
- raise ValueError(f"Metadata value {value} is not supported")
661
+ msg = f"Metadata value {value} is not supported"
662
+ raise ValueError(msg)
650
663
 
651
664
  cs.append(c)
652
665
  return cs
@@ -668,7 +681,8 @@ def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[genai.MetadataFilter]:
668
681
  operation=genai.Condition.Operator.EQUAL, numeric_value=value
669
682
  )
670
683
  else:
671
- raise ValueError(f"Filter value {value} is not supported")
684
+ msg = f"Filter value {value} is not supported"
685
+ raise ValueError(msg)
672
686
 
673
687
  filters.append(genai.MetadataFilter(key=key, conditions=[condition]))
674
688
 
@@ -8,13 +8,13 @@ from enum import Enum
8
8
  from typing import Any, Dict
9
9
  from urllib.parse import urlparse
10
10
 
11
- import filetype # type: ignore[import]
11
+ import filetype # type: ignore[import-untyped]
12
12
  import requests
13
13
  from google.ai.generativelanguage_v1beta.types import Part
14
14
 
15
15
 
16
16
  class Route(Enum):
17
- """Image Loading Route"""
17
+ """Image Loading Route."""
18
18
 
19
19
  BASE64 = 1
20
20
  LOCAL_FILE = 2
@@ -40,7 +40,6 @@ class ImageBytesLoader:
40
40
  Returns:
41
41
  Image bytes.
42
42
  """
43
-
44
43
  route = self._route(image_string)
45
44
 
46
45
  if route == Route.BASE64:
@@ -50,18 +49,20 @@ class ImageBytesLoader:
50
49
  return self._bytes_from_url(image_string)
51
50
 
52
51
  if route == Route.LOCAL_FILE:
53
- raise ValueError(
52
+ msg = (
54
53
  "Loading from local files is no longer supported for security reasons. "
55
54
  "Please pass in images as Google Cloud Storage URI, "
56
55
  "b64 encoded image string (data:image/...), or valid image url."
57
56
  )
57
+ raise ValueError(msg)
58
58
  return self._bytes_from_file(image_string)
59
59
 
60
- raise ValueError(
60
+ msg = (
61
61
  "Image string must be one of: Google Cloud Storage URI, "
62
62
  "b64 encoded image string (data:image/...), or valid image url."
63
63
  f"Instead got '{image_string}'."
64
64
  )
65
+ raise ValueError(msg)
65
66
 
66
67
  def load_part(self, image_string: str) -> Part:
67
68
  """Gets Part for loading from Gemini.
@@ -110,11 +111,12 @@ class ImageBytesLoader:
110
111
  if os.path.exists(image_string):
111
112
  return Route.LOCAL_FILE
112
113
 
113
- raise ValueError(
114
+ msg = (
114
115
  "Image string must be one of: "
115
116
  "b64 encoded image string (data:image/...) or valid image url."
116
117
  f" Instead got '{image_string}'."
117
118
  )
119
+ raise ValueError(msg)
118
120
 
119
121
  def _bytes_from_b64(self, base64_image: str) -> bytes:
120
122
  """Gets image bytes from a base64 encoded string.
@@ -125,7 +127,6 @@ class ImageBytesLoader:
125
127
  Returns:
126
128
  Image bytes
127
129
  """
128
-
129
130
  pattern = r"data:image/\w{2,4};base64,(.*)"
130
131
  match = re.search(pattern, base64_image)
131
132
 
@@ -133,7 +134,8 @@ class ImageBytesLoader:
133
134
  encoded_string = match.group(1)
134
135
  return base64.b64decode(encoded_string)
135
136
 
136
- raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}")
137
+ msg = f"Error in b64 encoded image. Must follow pattern: {pattern}"
138
+ raise ValueError(msg)
137
139
 
138
140
  def _bytes_from_url(self, url: str) -> bytes:
139
141
  """Gets image bytes from a public url.
@@ -147,7 +149,6 @@ class ImageBytesLoader:
147
149
  Returns:
148
150
  Image bytes
149
151
  """
150
-
151
152
  response = requests.get(url)
152
153
 
153
154
  if not response.ok: