langchain-google-genai 2.1.9.tar.gz → 2.1.11.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic. See the package's registry page for more details.

Files changed (42)
  1. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/PKG-INFO +6 -15
  2. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/_function_utils.py +12 -3
  3. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/chat_models.py +58 -31
  4. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/embeddings.py +51 -9
  5. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/llms.py +1 -1
  6. langchain_google_genai-2.1.11/pyproject.toml +84 -0
  7. langchain_google_genai-2.1.11/tests/__init__.py +0 -0
  8. langchain_google_genai-2.1.11/tests/conftest.py +64 -0
  9. langchain_google_genai-2.1.11/tests/integration_tests/.env.example +1 -0
  10. langchain_google_genai-2.1.11/tests/integration_tests/__init__.py +0 -0
  11. langchain_google_genai-2.1.11/tests/integration_tests/terraform/main.tf +12 -0
  12. langchain_google_genai-2.1.11/tests/integration_tests/test_callbacks.py +31 -0
  13. langchain_google_genai-2.1.11/tests/integration_tests/test_chat_models.py +887 -0
  14. langchain_google_genai-2.1.11/tests/integration_tests/test_compile.py +7 -0
  15. langchain_google_genai-2.1.11/tests/integration_tests/test_embeddings.py +145 -0
  16. langchain_google_genai-2.1.11/tests/integration_tests/test_function_call.py +90 -0
  17. langchain_google_genai-2.1.11/tests/integration_tests/test_llms.py +100 -0
  18. langchain_google_genai-2.1.11/tests/integration_tests/test_standard.py +141 -0
  19. langchain_google_genai-2.1.11/tests/integration_tests/test_tools.py +37 -0
  20. langchain_google_genai-2.1.11/tests/unit_tests/__init__.py +0 -0
  21. langchain_google_genai-2.1.11/tests/unit_tests/__snapshots__/test_standard.ambr +63 -0
  22. langchain_google_genai-2.1.11/tests/unit_tests/test_chat_models.py +916 -0
  23. langchain_google_genai-2.1.11/tests/unit_tests/test_chat_models_protobuf_fix.py +132 -0
  24. langchain_google_genai-2.1.11/tests/unit_tests/test_common.py +31 -0
  25. langchain_google_genai-2.1.11/tests/unit_tests/test_embeddings.py +158 -0
  26. langchain_google_genai-2.1.11/tests/unit_tests/test_function_utils.py +1406 -0
  27. langchain_google_genai-2.1.11/tests/unit_tests/test_genai_aqa.py +95 -0
  28. langchain_google_genai-2.1.11/tests/unit_tests/test_google_vector_store.py +440 -0
  29. langchain_google_genai-2.1.11/tests/unit_tests/test_imports.py +20 -0
  30. langchain_google_genai-2.1.11/tests/unit_tests/test_llms.py +47 -0
  31. langchain_google_genai-2.1.11/tests/unit_tests/test_standard.py +42 -0
  32. langchain_google_genai-2.1.9/pyproject.toml +0 -109
  33. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/LICENSE +0 -0
  34. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/README.md +0 -0
  35. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/__init__.py +0 -0
  36. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/_common.py +0 -0
  37. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/_enums.py +0 -0
  38. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/_genai_extension.py +0 -0
  39. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/_image_utils.py +0 -0
  40. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/genai_aqa.py +0 -0
  41. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/google_vector_store.py +0 -0
  42. {langchain_google_genai-2.1.9 → langchain_google_genai-2.1.11}/langchain_google_genai/py.typed +0 -0
@@ -1,22 +1,14 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langchain-google-genai
3
- Version: 2.1.9
3
+ Version: 2.1.11
4
4
  Summary: An integration package connecting Google's genai package and LangChain
5
- Home-page: https://github.com/langchain-ai/langchain-google
6
5
  License: MIT
7
- Requires-Python: >=3.9,<4.0
8
- Classifier: License :: OSI Approved :: MIT License
9
- Classifier: Programming Language :: Python :: 3
10
- Classifier: Programming Language :: Python :: 3.9
11
- Classifier: Programming Language :: Python :: 3.10
12
- Classifier: Programming Language :: Python :: 3.11
13
- Classifier: Programming Language :: Python :: 3.12
14
- Requires-Dist: filetype (>=1.2.0,<2.0.0)
15
- Requires-Dist: google-ai-generativelanguage (>=0.6.18,<0.7.0)
16
- Requires-Dist: langchain-core (>=0.3.68,<0.4.0)
17
- Requires-Dist: pydantic (>=2,<3)
18
- Project-URL: Repository, https://github.com/langchain-ai/langchain-google
19
6
  Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
7
+ Requires-Python: >=3.9
8
+ Requires-Dist: langchain-core>=0.3.75
9
+ Requires-Dist: google-ai-generativelanguage<1,>=0.7
10
+ Requires-Dist: pydantic<3,>=2
11
+ Requires-Dist: filetype<2,>=1.2
20
12
  Description-Content-Type: text/markdown
21
13
 
22
14
  # langchain-google-genai
@@ -258,4 +250,3 @@ print("Answerable probability:", response.answerable_probability)
258
250
  - [LangChain Documentation](https://docs.langchain.com/)
259
251
  - [Google Generative AI SDK](https://googleapis.github.io/python-genai/)
260
252
  - [Gemini Model Documentation](https://ai.google.dev/)
261
-
@@ -331,6 +331,10 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
331
331
  logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}")
332
332
  continue
333
333
  properties_item: Dict[str, Union[str, int, Dict, List]] = {}
334
+
335
+ # Get description from original schema before any modifications
336
+ description = v.get("description")
337
+
334
338
  if v.get("anyOf") and all(
335
339
  anyOf_type.get("type") != "null" for anyOf_type in v.get("anyOf", [])
336
340
  ):
@@ -338,6 +342,9 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
338
342
  _format_json_schema_to_gapic(anyOf_type)
339
343
  for anyOf_type in v.get("anyOf", [])
340
344
  ]
345
+ # For non-nullable anyOf, we still need to set a type
346
+ item_type_ = _get_type_from_schema(v)
347
+ properties_item["type_"] = item_type_
341
348
  elif v.get("type") or v.get("anyOf") or v.get("type_"):
342
349
  item_type_ = _get_type_from_schema(v)
343
350
  properties_item["type_"] = item_type_
@@ -354,7 +361,6 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
354
361
  if v.get("enum"):
355
362
  properties_item["enum"] = v["enum"]
356
363
 
357
- description = v.get("description")
358
364
  if description and isinstance(description, str):
359
365
  properties_item["description"] = description
360
366
 
@@ -377,8 +383,9 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
377
383
  properties_item["required"] = [
378
384
  k for k, v in v_properties.items() if "default" not in v
379
385
  ]
380
- else:
381
- # Providing dummy type for object without properties
386
+ elif not v.get("additionalProperties"):
387
+ # Only provide dummy type for object without properties AND without
388
+ # additionalProperties
382
389
  properties_item["type_"] = glm.Type.STRING
383
390
 
384
391
  if k == "title" and "description" not in properties_item:
@@ -414,6 +421,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
414
421
  items["nullable"] = True
415
422
  if "required" in schema:
416
423
  items["required"] = schema["required"]
424
+ if "enum" in schema:
425
+ items["enum"] = schema["enum"]
417
426
  else:
418
427
  # str
419
428
  items["type_"] = _get_type_from_schema({"type": schema})
@@ -92,13 +92,7 @@ from langchain_core.utils.function_calling import (
92
92
  )
93
93
  from langchain_core.utils.pydantic import is_basemodel_subclass
94
94
  from langchain_core.utils.utils import _build_model_kwargs
95
- from pydantic import (
96
- BaseModel,
97
- ConfigDict,
98
- Field,
99
- SecretStr,
100
- model_validator,
101
- )
95
+ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
102
96
  from pydantic.v1 import BaseModel as BaseModelV1
103
97
  from tenacity import (
104
98
  before_sleep_log,
@@ -496,7 +490,12 @@ def _parse_chat_history(
496
490
  messages: List[Content] = []
497
491
 
498
492
  if convert_system_message_to_human:
499
- warnings.warn("Convert_system_message_to_human will be deprecated!")
493
+ warnings.warn(
494
+ "The 'convert_system_message_to_human' parameter is deprecated and will be "
495
+ "removed in a future version. Use system instructions instead.",
496
+ DeprecationWarning,
497
+ stacklevel=2,
498
+ )
500
499
 
501
500
  system_instruction: Optional[Content] = None
502
501
  messages_without_tool_messages = [
@@ -659,9 +658,15 @@ def _parse_response_candidate(
659
658
  function_call = {"name": part.function_call.name}
660
659
  # dump to match other function calling llm for now
661
660
  function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
662
- function_call["arguments"] = json.dumps(
663
- {k: function_call_args_dict[k] for k in function_call_args_dict}
664
- )
661
+
662
+ # Fix: Correct integer-like floats from protobuf conversion
663
+ # The protobuf library sometimes converts integers to floats
664
+ corrected_args = {
665
+ k: int(v) if isinstance(v, float) and v.is_integer() else v
666
+ for k, v in function_call_args_dict.items()
667
+ }
668
+
669
+ function_call["arguments"] = json.dumps(corrected_args)
665
670
  additional_kwargs["function_call"] = function_call
666
671
 
667
672
  if streaming:
@@ -698,7 +703,9 @@ def _parse_response_candidate(
698
703
  )
699
704
  if content is None:
700
705
  content = ""
701
- if any(isinstance(item, dict) and "executable_code" in item for item in content):
706
+ if isinstance(content, list) and any(
707
+ isinstance(item, dict) and "executable_code" in item for item in content
708
+ ):
702
709
  warnings.warn(
703
710
  """
704
711
  ⚠️ Warning: Output may vary each run.
@@ -817,7 +824,15 @@ def _response_to_result(
817
824
  if stream:
818
825
  generations = [
819
826
  ChatGenerationChunk(
820
- message=AIMessageChunk(content=""), generation_info={}
827
+ message=AIMessageChunk(
828
+ content="",
829
+ response_metadata={
830
+ "prompt_feedback": proto.Message.to_dict(
831
+ response.prompt_feedback
832
+ )
833
+ },
834
+ ),
835
+ generation_info={},
821
836
  )
822
837
  ]
823
838
  else:
@@ -1281,8 +1296,8 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1281
1296
 
1282
1297
  client: Any = Field(default=None, exclude=True) #: :meta private:
1283
1298
  async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
1284
- default_metadata: Sequence[Tuple[str, str]] = Field(
1285
- default_factory=list
1299
+ default_metadata: Optional[Sequence[Tuple[str, str]]] = Field(
1300
+ default=None, alias="default_metadata_input"
1286
1301
  ) #: :meta private:
1287
1302
 
1288
1303
  convert_system_message_to_human: bool = False
@@ -1317,6 +1332,12 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1317
1332
  ``cachedContents/{cachedContent}``.
1318
1333
  """
1319
1334
 
1335
+ stop: Optional[List[str]] = None
1336
+ """Stop sequences for the model."""
1337
+
1338
+ streaming: Optional[bool] = None
1339
+ """Whether to stream responses from the model."""
1340
+
1320
1341
  model_kwargs: dict[str, Any] = Field(default_factory=dict)
1321
1342
  """Holds any unexpected initialization parameters."""
1322
1343
 
@@ -1528,18 +1549,21 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1528
1549
  "response_modalities": self.response_modalities,
1529
1550
  "thinking_config": (
1530
1551
  (
1531
- {"thinking_budget": self.thinking_budget}
1532
- if self.thinking_budget is not None
1533
- else {}
1534
- )
1535
- | (
1536
- {"include_thoughts": self.include_thoughts}
1537
- if self.include_thoughts is not None
1538
- else {}
1552
+ (
1553
+ {"thinking_budget": self.thinking_budget}
1554
+ if self.thinking_budget is not None
1555
+ else {}
1556
+ )
1557
+ | (
1558
+ {"include_thoughts": self.include_thoughts}
1559
+ if self.include_thoughts is not None
1560
+ else {}
1561
+ )
1539
1562
  )
1540
- )
1541
- if self.thinking_budget is not None or self.include_thoughts is not None
1542
- else None,
1563
+ if self.thinking_budget is not None
1564
+ or self.include_thoughts is not None
1565
+ else None
1566
+ ),
1543
1567
  }.items()
1544
1568
  if v is not None
1545
1569
  }
@@ -1781,7 +1805,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1781
1805
  generation_config: Optional[Dict[str, Any]] = None,
1782
1806
  cached_content: Optional[str] = None,
1783
1807
  **kwargs: Any,
1784
- ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
1808
+ ) -> GenerateContentRequest:
1785
1809
  if tool_choice and tool_config:
1786
1810
  raise ValueError(
1787
1811
  "Must specify at most one of tool_choice and tool_config, received "
@@ -1807,10 +1831,13 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1807
1831
  filtered_messages.append(message)
1808
1832
  messages = filtered_messages
1809
1833
 
1810
- system_instruction, history = _parse_chat_history(
1811
- messages,
1812
- convert_system_message_to_human=self.convert_system_message_to_human,
1813
- )
1834
+ if self.convert_system_message_to_human:
1835
+ system_instruction, history = _parse_chat_history(
1836
+ messages,
1837
+ convert_system_message_to_human=self.convert_system_message_to_human,
1838
+ )
1839
+ else:
1840
+ system_instruction, history = _parse_chat_history(messages)
1814
1841
  if tool_choice:
1815
1842
  if not formatted_tools:
1816
1843
  msg = (
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import re
2
3
  import string
3
4
  from typing import Any, Dict, List, Optional
@@ -26,6 +27,14 @@ _MAX_TOKENS_PER_BATCH = 20000
26
27
  _DEFAULT_BATCH_SIZE = 100
27
28
 
28
29
 
30
+ def _is_event_loop_running() -> bool:
31
+ try:
32
+ asyncio.get_running_loop()
33
+ return True
34
+ except RuntimeError:
35
+ return False
36
+
37
+
29
38
  class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
30
39
  """`Google Generative AI Embeddings`.
31
40
 
@@ -107,15 +116,48 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
107
116
  client_options=self.client_options,
108
117
  transport=self.transport,
109
118
  )
110
- self.async_client = build_generative_async_service(
111
- credentials=self.credentials,
112
- api_key=google_api_key,
113
- client_info=client_info,
114
- client_options=self.client_options,
115
- transport=self.transport,
116
- )
119
+ # Only initialize async client if there's an event loop running
120
+ # to avoid RuntimeError during synchronous initialization
121
+ if _is_event_loop_running():
122
+ # async clients don't support "rest" transport
123
+ transport = self.transport
124
+ if transport == "rest":
125
+ transport = "grpc_asyncio"
126
+ self.async_client = build_generative_async_service(
127
+ credentials=self.credentials,
128
+ api_key=google_api_key,
129
+ client_info=client_info,
130
+ client_options=self.client_options,
131
+ transport=transport,
132
+ )
133
+ else:
134
+ self.async_client = None
117
135
  return self
118
136
 
137
+ @property
138
+ def _async_client(self) -> Any:
139
+ """Get or create the async client when needed."""
140
+ if self.async_client is None:
141
+ if isinstance(self.google_api_key, SecretStr):
142
+ google_api_key: Optional[str] = self.google_api_key.get_secret_value()
143
+ else:
144
+ google_api_key = self.google_api_key
145
+
146
+ client_info = get_client_info("GoogleGenerativeAIEmbeddings")
147
+ # async clients don't support "rest" transport
148
+ transport = self.transport
149
+ if transport == "rest":
150
+ transport = "grpc_asyncio"
151
+
152
+ self.async_client = build_generative_async_service(
153
+ credentials=self.credentials,
154
+ api_key=google_api_key,
155
+ client_info=client_info,
156
+ client_options=self.client_options,
157
+ transport=transport,
158
+ )
159
+ return self.async_client
160
+
119
161
  @staticmethod
120
162
  def _split_by_punctuation(text: str) -> List[str]:
121
163
  """Splits a string by punctuation and whitespace characters."""
@@ -328,7 +370,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
328
370
  ]
329
371
 
330
372
  try:
331
- result = await self.async_client.batch_embed_contents(
373
+ result = await self._async_client.batch_embed_contents(
332
374
  BatchEmbedContentsRequest(requests=requests, model=self.model)
333
375
  )
334
376
  except Exception as e:
@@ -366,7 +408,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
366
408
  title=title,
367
409
  output_dimensionality=output_dimensionality,
368
410
  )
369
- result: EmbedContentResponse = await self.async_client.embed_content(
411
+ result: EmbedContentResponse = await self._async_client.embed_content(
370
412
  request
371
413
  )
372
414
  except Exception as e:
@@ -41,7 +41,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
41
41
  """Needed for arg validation."""
42
42
  # Get all valid field names, including aliases
43
43
  valid_fields = set()
44
- for field_name, field_info in self.model_fields.items():
44
+ for field_name, field_info in self.__class__.model_fields.items():
45
45
  valid_fields.add(field_name)
46
46
  if hasattr(field_info, "alias") and field_info.alias is not None:
47
47
  valid_fields.add(field_info.alias)
@@ -0,0 +1,84 @@
1
+ [build-system]
2
+ requires = [
3
+ "pdm-backend",
4
+ ]
5
+ build-backend = "pdm.backend"
6
+
7
+ [project]
8
+ name = "langchain-google-genai"
9
+ version = "2.1.11"
10
+ description = "An integration package connecting Google's genai package and LangChain"
11
+ authors = []
12
+ requires-python = ">=3.9"
13
+ readme = "README.md"
14
+ repository = "https://github.com/langchain-ai/langchain-google"
15
+ dependencies = [
16
+ "langchain-core>=0.3.75",
17
+ "google-ai-generativelanguage>=0.7,<1",
18
+ "pydantic>=2,<3",
19
+ "filetype>=1.2,<2",
20
+ ]
21
+
22
+ [project.license]
23
+ text = "MIT"
24
+
25
+ [project.urls]
26
+ "Source Code" = "https://github.com/langchain-ai/langchain-google/tree/main/libs/genai"
27
+
28
+ [dependency-groups]
29
+ test = [
30
+ "pytest>=8.4,<9",
31
+ "freezegun>=1.5,<2",
32
+ "pytest-mock>=3.14,<4",
33
+ "syrupy>=4.9,<5",
34
+ "pytest-watcher>=0.4,<1",
35
+ "pytest-asyncio>=0.21,<1",
36
+ "pytest-retry>=1.7,<2",
37
+ "pytest-socket>=0.7,<1",
38
+ "numpy>=1.26.4; python_version<'3.13'",
39
+ "numpy>=2.1.0; python_version>='3.13'",
40
+ "langchain-tests>=0.3,<1",
41
+ ]
42
+ test_integration = [
43
+ "pytest>=8.4,<9",
44
+ ]
45
+ lint = [
46
+ "ruff>=0.12.10,<1",
47
+ ]
48
+ typing = [
49
+ "mypy>=1.17.1,<2",
50
+ "types-requests>=2.31,<3",
51
+ "types-google-cloud-ndb>=2.2.0.1,<3",
52
+ "types-protobuf>=4.24.0.20240302,<5",
53
+ "numpy>=1.26.2",
54
+ ]
55
+ dev = [
56
+ "types-requests>=2.31.0,<3",
57
+ "types-google-cloud-ndb>=2.2.0.1,<3",
58
+ ]
59
+
60
+ [tool.ruff]
61
+ fix = true
62
+
63
+ [tool.ruff.lint]
64
+ select = [
65
+ "E",
66
+ "F",
67
+ "I",
68
+ ]
69
+
70
+ [tool.mypy]
71
+ disallow_untyped_defs = "True"
72
+
73
+ [tool.coverage.run]
74
+ omit = [
75
+ "tests/*",
76
+ ]
77
+
78
+ [tool.pytest.ini_options]
79
+ markers = [
80
+ "requires: mark tests as requiring a specific library",
81
+ "asyncio: mark tests as requiring asyncio",
82
+ "compile: mark placeholder test used to compile integration tests without running them",
83
+ ]
84
+ asyncio_mode = "auto"
File without changes
@@ -0,0 +1,64 @@
1
+ """
2
+ Tests configuration to be executed before tests execution.
3
+ """
4
+
5
+ from typing import List
6
+
7
+ import pytest
8
+
9
+ _RELEASE_FLAG = "release"
10
+ _GPU_FLAG = "gpu"
11
+ _LONG_FLAG = "long"
12
+ _EXTENDED_FLAG = "extended"
13
+
14
+ _PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG]
15
+
16
+
17
+ def pytest_addoption(parser: pytest.Parser) -> None:
18
+ """
19
+ Add flags accepted by pytest CLI.
20
+
21
+ Args:
22
+ parser: The pytest parser object.
23
+
24
+ Returns:
25
+
26
+ """
27
+ for flag in _PYTEST_FLAGS:
28
+ parser.addoption(
29
+ f"--{flag}", action="store_true", default=False, help=f"run {flag} tests"
30
+ )
31
+
32
+
33
+ def pytest_configure(config: pytest.Config) -> None:
34
+ """
35
+ Add pytest custom configuration.
36
+
37
+ Args:
38
+ config: The pytest config object.
39
+
40
+ Returns:
41
+ """
42
+ for flag in _PYTEST_FLAGS:
43
+ config.addinivalue_line(
44
+ "markers", f"{flag}: mark test to run as {flag} only test"
45
+ )
46
+
47
+
48
+ def pytest_collection_modifyitems(
49
+ config: pytest.Config, items: List[pytest.Item]
50
+ ) -> None:
51
+ """
52
+ Skip tests with a marker from our list that were not explicitly invoked.
53
+
54
+ Args:
55
+ config: The pytest config object.
56
+ items: The list of tests to be executed.
57
+
58
+ Returns:
59
+ """
60
+ for item in items:
61
+ keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS))
62
+ if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)):
63
+ skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run")
64
+ item.add_marker(skip)
@@ -0,0 +1 @@
1
+ PROJECT_ID=project_id
@@ -0,0 +1,12 @@
1
+ module "cloudbuild" {
2
+ source = "./../../../../../terraform/cloudbuild"
3
+
4
+ library = "genai"
5
+ project_id = ""
6
+ cloudbuildv2_repository_id = ""
7
+ cloudbuild_env_vars = {
8
+ }
9
+ cloudbuild_secret_vars = {
10
+ GOOGLE_API_KEY = ""
11
+ }
12
+ }
@@ -0,0 +1,31 @@
1
+ from typing import Any, List
2
+
3
+ from langchain_core.callbacks import BaseCallbackHandler
4
+ from langchain_core.outputs import LLMResult
5
+ from langchain_core.prompts import PromptTemplate
6
+
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+
9
+
10
+ class StreamingLLMCallbackHandler(BaseCallbackHandler):
11
+ def __init__(self, **kwargs: Any):
12
+ super().__init__(**kwargs)
13
+ self.tokens: List[Any] = []
14
+ self.generations: List[Any] = []
15
+
16
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
17
+ self.tokens.append(token)
18
+
19
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
20
+ self.generations.append(response.generations[0][0].text)
21
+
22
+
23
+ def test_streaming_callback() -> None:
24
+ prompt_template = "Tell me details about the Company {name} with 2 bullet point?"
25
+ cb = StreamingLLMCallbackHandler()
26
+ llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001", callbacks=[cb])
27
+ llm_chain = PromptTemplate.from_template(prompt_template) | llm
28
+ for t in llm_chain.stream({"name": "Google"}):
29
+ pass
30
+ assert len(cb.tokens) > 1
31
+ assert len(cb.generations) == 1