langchain-google-genai 2.1.10__tar.gz → 2.1.11__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langchain-google-genai might be problematic. Click here for more details.

Files changed (42)
  1. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/PKG-INFO +6 -15
  2. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_function_utils.py +12 -3
  3. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/chat_models.py +53 -28
  4. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/embeddings.py +51 -9
  5. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/llms.py +1 -1
  6. langchain_google_genai-2.1.11/pyproject.toml +84 -0
  7. langchain_google_genai-2.1.11/tests/__init__.py +0 -0
  8. langchain_google_genai-2.1.11/tests/conftest.py +64 -0
  9. langchain_google_genai-2.1.11/tests/integration_tests/.env.example +1 -0
  10. langchain_google_genai-2.1.11/tests/integration_tests/__init__.py +0 -0
  11. langchain_google_genai-2.1.11/tests/integration_tests/terraform/main.tf +12 -0
  12. langchain_google_genai-2.1.11/tests/integration_tests/test_callbacks.py +31 -0
  13. langchain_google_genai-2.1.11/tests/integration_tests/test_chat_models.py +887 -0
  14. langchain_google_genai-2.1.11/tests/integration_tests/test_compile.py +7 -0
  15. langchain_google_genai-2.1.11/tests/integration_tests/test_embeddings.py +145 -0
  16. langchain_google_genai-2.1.11/tests/integration_tests/test_function_call.py +90 -0
  17. langchain_google_genai-2.1.11/tests/integration_tests/test_llms.py +100 -0
  18. langchain_google_genai-2.1.11/tests/integration_tests/test_standard.py +141 -0
  19. langchain_google_genai-2.1.11/tests/integration_tests/test_tools.py +37 -0
  20. langchain_google_genai-2.1.11/tests/unit_tests/__init__.py +0 -0
  21. langchain_google_genai-2.1.11/tests/unit_tests/__snapshots__/test_standard.ambr +63 -0
  22. langchain_google_genai-2.1.11/tests/unit_tests/test_chat_models.py +916 -0
  23. langchain_google_genai-2.1.11/tests/unit_tests/test_chat_models_protobuf_fix.py +132 -0
  24. langchain_google_genai-2.1.11/tests/unit_tests/test_common.py +31 -0
  25. langchain_google_genai-2.1.11/tests/unit_tests/test_embeddings.py +158 -0
  26. langchain_google_genai-2.1.11/tests/unit_tests/test_function_utils.py +1406 -0
  27. langchain_google_genai-2.1.11/tests/unit_tests/test_genai_aqa.py +95 -0
  28. langchain_google_genai-2.1.11/tests/unit_tests/test_google_vector_store.py +440 -0
  29. langchain_google_genai-2.1.11/tests/unit_tests/test_imports.py +20 -0
  30. langchain_google_genai-2.1.11/tests/unit_tests/test_llms.py +47 -0
  31. langchain_google_genai-2.1.11/tests/unit_tests/test_standard.py +42 -0
  32. langchain_google_genai-2.1.10/pyproject.toml +0 -109
  33. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/LICENSE +0 -0
  34. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/README.md +0 -0
  35. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/__init__.py +0 -0
  36. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_common.py +0 -0
  37. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_enums.py +0 -0
  38. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_genai_extension.py +0 -0
  39. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_image_utils.py +0 -0
  40. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/genai_aqa.py +0 -0
  41. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/google_vector_store.py +0 -0
  42. {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/py.typed +0 -0
@@ -1,22 +1,14 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langchain-google-genai
3
- Version: 2.1.10
3
+ Version: 2.1.11
4
4
  Summary: An integration package connecting Google's genai package and LangChain
5
- Home-page: https://github.com/langchain-ai/langchain-google
6
5
  License: MIT
7
- Requires-Python: >=3.9,<4.0
8
- Classifier: License :: OSI Approved :: MIT License
9
- Classifier: Programming Language :: Python :: 3
10
- Classifier: Programming Language :: Python :: 3.9
11
- Classifier: Programming Language :: Python :: 3.10
12
- Classifier: Programming Language :: Python :: 3.11
13
- Classifier: Programming Language :: Python :: 3.12
14
- Requires-Dist: filetype (>=1.2.0,<2.0.0)
15
- Requires-Dist: google-ai-generativelanguage (>=0.6.18,<0.7.0)
16
- Requires-Dist: langchain-core (>=0.3.75,<0.4.0)
17
- Requires-Dist: pydantic (>=2,<3)
18
- Project-URL: Repository, https://github.com/langchain-ai/langchain-google
19
6
  Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
7
+ Requires-Python: >=3.9
8
+ Requires-Dist: langchain-core>=0.3.75
9
+ Requires-Dist: google-ai-generativelanguage<1,>=0.7
10
+ Requires-Dist: pydantic<3,>=2
11
+ Requires-Dist: filetype<2,>=1.2
20
12
  Description-Content-Type: text/markdown
21
13
 
22
14
  # langchain-google-genai
@@ -258,4 +250,3 @@ print("Answerable probability:", response.answerable_probability)
258
250
  - [LangChain Documentation](https://docs.langchain.com/)
259
251
  - [Google Generative AI SDK](https://googleapis.github.io/python-genai/)
260
252
  - [Gemini Model Documentation](https://ai.google.dev/)
261
-
@@ -331,6 +331,10 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
331
331
  logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}")
332
332
  continue
333
333
  properties_item: Dict[str, Union[str, int, Dict, List]] = {}
334
+
335
+ # Get description from original schema before any modifications
336
+ description = v.get("description")
337
+
334
338
  if v.get("anyOf") and all(
335
339
  anyOf_type.get("type") != "null" for anyOf_type in v.get("anyOf", [])
336
340
  ):
@@ -338,6 +342,9 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
338
342
  _format_json_schema_to_gapic(anyOf_type)
339
343
  for anyOf_type in v.get("anyOf", [])
340
344
  ]
345
+ # For non-nullable anyOf, we still need to set a type
346
+ item_type_ = _get_type_from_schema(v)
347
+ properties_item["type_"] = item_type_
341
348
  elif v.get("type") or v.get("anyOf") or v.get("type_"):
342
349
  item_type_ = _get_type_from_schema(v)
343
350
  properties_item["type_"] = item_type_
@@ -354,7 +361,6 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
354
361
  if v.get("enum"):
355
362
  properties_item["enum"] = v["enum"]
356
363
 
357
- description = v.get("description")
358
364
  if description and isinstance(description, str):
359
365
  properties_item["description"] = description
360
366
 
@@ -377,8 +383,9 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
377
383
  properties_item["required"] = [
378
384
  k for k, v in v_properties.items() if "default" not in v
379
385
  ]
380
- else:
381
- # Providing dummy type for object without properties
386
+ elif not v.get("additionalProperties"):
387
+ # Only provide dummy type for object without properties AND without
388
+ # additionalProperties
382
389
  properties_item["type_"] = glm.Type.STRING
383
390
 
384
391
  if k == "title" and "description" not in properties_item:
@@ -414,6 +421,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
414
421
  items["nullable"] = True
415
422
  if "required" in schema:
416
423
  items["required"] = schema["required"]
424
+ if "enum" in schema:
425
+ items["enum"] = schema["enum"]
417
426
  else:
418
427
  # str
419
428
  items["type_"] = _get_type_from_schema({"type": schema})
@@ -92,13 +92,7 @@ from langchain_core.utils.function_calling import (
92
92
  )
93
93
  from langchain_core.utils.pydantic import is_basemodel_subclass
94
94
  from langchain_core.utils.utils import _build_model_kwargs
95
- from pydantic import (
96
- BaseModel,
97
- ConfigDict,
98
- Field,
99
- SecretStr,
100
- model_validator,
101
- )
95
+ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
102
96
  from pydantic.v1 import BaseModel as BaseModelV1
103
97
  from tenacity import (
104
98
  before_sleep_log,
@@ -496,7 +490,12 @@ def _parse_chat_history(
496
490
  messages: List[Content] = []
497
491
 
498
492
  if convert_system_message_to_human:
499
- warnings.warn("Convert_system_message_to_human will be deprecated!")
493
+ warnings.warn(
494
+ "The 'convert_system_message_to_human' parameter is deprecated and will be "
495
+ "removed in a future version. Use system instructions instead.",
496
+ DeprecationWarning,
497
+ stacklevel=2,
498
+ )
500
499
 
501
500
  system_instruction: Optional[Content] = None
502
501
  messages_without_tool_messages = [
@@ -659,9 +658,15 @@ def _parse_response_candidate(
659
658
  function_call = {"name": part.function_call.name}
660
659
  # dump to match other function calling llm for now
661
660
  function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
662
- function_call["arguments"] = json.dumps(
663
- {k: function_call_args_dict[k] for k in function_call_args_dict}
664
- )
661
+
662
+ # Fix: Correct integer-like floats from protobuf conversion
663
+ # The protobuf library sometimes converts integers to floats
664
+ corrected_args = {
665
+ k: int(v) if isinstance(v, float) and v.is_integer() else v
666
+ for k, v in function_call_args_dict.items()
667
+ }
668
+
669
+ function_call["arguments"] = json.dumps(corrected_args)
665
670
  additional_kwargs["function_call"] = function_call
666
671
 
667
672
  if streaming:
@@ -819,7 +824,15 @@ def _response_to_result(
819
824
  if stream:
820
825
  generations = [
821
826
  ChatGenerationChunk(
822
- message=AIMessageChunk(content=""), generation_info={}
827
+ message=AIMessageChunk(
828
+ content="",
829
+ response_metadata={
830
+ "prompt_feedback": proto.Message.to_dict(
831
+ response.prompt_feedback
832
+ )
833
+ },
834
+ ),
835
+ generation_info={},
823
836
  )
824
837
  ]
825
838
  else:
@@ -1319,6 +1332,12 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1319
1332
  ``cachedContents/{cachedContent}``.
1320
1333
  """
1321
1334
 
1335
+ stop: Optional[List[str]] = None
1336
+ """Stop sequences for the model."""
1337
+
1338
+ streaming: Optional[bool] = None
1339
+ """Whether to stream responses from the model."""
1340
+
1322
1341
  model_kwargs: dict[str, Any] = Field(default_factory=dict)
1323
1342
  """Holds any unexpected initialization parameters."""
1324
1343
 
@@ -1530,18 +1549,21 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1530
1549
  "response_modalities": self.response_modalities,
1531
1550
  "thinking_config": (
1532
1551
  (
1533
- {"thinking_budget": self.thinking_budget}
1534
- if self.thinking_budget is not None
1535
- else {}
1536
- )
1537
- | (
1538
- {"include_thoughts": self.include_thoughts}
1539
- if self.include_thoughts is not None
1540
- else {}
1552
+ (
1553
+ {"thinking_budget": self.thinking_budget}
1554
+ if self.thinking_budget is not None
1555
+ else {}
1556
+ )
1557
+ | (
1558
+ {"include_thoughts": self.include_thoughts}
1559
+ if self.include_thoughts is not None
1560
+ else {}
1561
+ )
1541
1562
  )
1542
- )
1543
- if self.thinking_budget is not None or self.include_thoughts is not None
1544
- else None,
1563
+ if self.thinking_budget is not None
1564
+ or self.include_thoughts is not None
1565
+ else None
1566
+ ),
1545
1567
  }.items()
1546
1568
  if v is not None
1547
1569
  }
@@ -1783,7 +1805,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1783
1805
  generation_config: Optional[Dict[str, Any]] = None,
1784
1806
  cached_content: Optional[str] = None,
1785
1807
  **kwargs: Any,
1786
- ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
1808
+ ) -> GenerateContentRequest:
1787
1809
  if tool_choice and tool_config:
1788
1810
  raise ValueError(
1789
1811
  "Must specify at most one of tool_choice and tool_config, received "
@@ -1809,10 +1831,13 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
1809
1831
  filtered_messages.append(message)
1810
1832
  messages = filtered_messages
1811
1833
 
1812
- system_instruction, history = _parse_chat_history(
1813
- messages,
1814
- convert_system_message_to_human=self.convert_system_message_to_human,
1815
- )
1834
+ if self.convert_system_message_to_human:
1835
+ system_instruction, history = _parse_chat_history(
1836
+ messages,
1837
+ convert_system_message_to_human=self.convert_system_message_to_human,
1838
+ )
1839
+ else:
1840
+ system_instruction, history = _parse_chat_history(messages)
1816
1841
  if tool_choice:
1817
1842
  if not formatted_tools:
1818
1843
  msg = (
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import re
2
3
  import string
3
4
  from typing import Any, Dict, List, Optional
@@ -26,6 +27,14 @@ _MAX_TOKENS_PER_BATCH = 20000
26
27
  _DEFAULT_BATCH_SIZE = 100
27
28
 
28
29
 
30
+ def _is_event_loop_running() -> bool:
31
+ try:
32
+ asyncio.get_running_loop()
33
+ return True
34
+ except RuntimeError:
35
+ return False
36
+
37
+
29
38
  class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
30
39
  """`Google Generative AI Embeddings`.
31
40
 
@@ -107,15 +116,48 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
107
116
  client_options=self.client_options,
108
117
  transport=self.transport,
109
118
  )
110
- self.async_client = build_generative_async_service(
111
- credentials=self.credentials,
112
- api_key=google_api_key,
113
- client_info=client_info,
114
- client_options=self.client_options,
115
- transport=self.transport,
116
- )
119
+ # Only initialize async client if there's an event loop running
120
+ # to avoid RuntimeError during synchronous initialization
121
+ if _is_event_loop_running():
122
+ # async clients don't support "rest" transport
123
+ transport = self.transport
124
+ if transport == "rest":
125
+ transport = "grpc_asyncio"
126
+ self.async_client = build_generative_async_service(
127
+ credentials=self.credentials,
128
+ api_key=google_api_key,
129
+ client_info=client_info,
130
+ client_options=self.client_options,
131
+ transport=transport,
132
+ )
133
+ else:
134
+ self.async_client = None
117
135
  return self
118
136
 
137
+ @property
138
+ def _async_client(self) -> Any:
139
+ """Get or create the async client when needed."""
140
+ if self.async_client is None:
141
+ if isinstance(self.google_api_key, SecretStr):
142
+ google_api_key: Optional[str] = self.google_api_key.get_secret_value()
143
+ else:
144
+ google_api_key = self.google_api_key
145
+
146
+ client_info = get_client_info("GoogleGenerativeAIEmbeddings")
147
+ # async clients don't support "rest" transport
148
+ transport = self.transport
149
+ if transport == "rest":
150
+ transport = "grpc_asyncio"
151
+
152
+ self.async_client = build_generative_async_service(
153
+ credentials=self.credentials,
154
+ api_key=google_api_key,
155
+ client_info=client_info,
156
+ client_options=self.client_options,
157
+ transport=transport,
158
+ )
159
+ return self.async_client
160
+
119
161
  @staticmethod
120
162
  def _split_by_punctuation(text: str) -> List[str]:
121
163
  """Splits a string by punctuation and whitespace characters."""
@@ -328,7 +370,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
328
370
  ]
329
371
 
330
372
  try:
331
- result = await self.async_client.batch_embed_contents(
373
+ result = await self._async_client.batch_embed_contents(
332
374
  BatchEmbedContentsRequest(requests=requests, model=self.model)
333
375
  )
334
376
  except Exception as e:
@@ -366,7 +408,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
366
408
  title=title,
367
409
  output_dimensionality=output_dimensionality,
368
410
  )
369
- result: EmbedContentResponse = await self.async_client.embed_content(
411
+ result: EmbedContentResponse = await self._async_client.embed_content(
370
412
  request
371
413
  )
372
414
  except Exception as e:
@@ -41,7 +41,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
41
41
  """Needed for arg validation."""
42
42
  # Get all valid field names, including aliases
43
43
  valid_fields = set()
44
- for field_name, field_info in self.model_fields.items():
44
+ for field_name, field_info in self.__class__.model_fields.items():
45
45
  valid_fields.add(field_name)
46
46
  if hasattr(field_info, "alias") and field_info.alias is not None:
47
47
  valid_fields.add(field_info.alias)
@@ -0,0 +1,84 @@
1
+ [build-system]
2
+ requires = [
3
+ "pdm-backend",
4
+ ]
5
+ build-backend = "pdm.backend"
6
+
7
+ [project]
8
+ name = "langchain-google-genai"
9
+ version = "2.1.11"
10
+ description = "An integration package connecting Google's genai package and LangChain"
11
+ authors = []
12
+ requires-python = ">=3.9"
13
+ readme = "README.md"
14
+ repository = "https://github.com/langchain-ai/langchain-google"
15
+ dependencies = [
16
+ "langchain-core>=0.3.75",
17
+ "google-ai-generativelanguage>=0.7,<1",
18
+ "pydantic>=2,<3",
19
+ "filetype>=1.2,<2",
20
+ ]
21
+
22
+ [project.license]
23
+ text = "MIT"
24
+
25
+ [project.urls]
26
+ "Source Code" = "https://github.com/langchain-ai/langchain-google/tree/main/libs/genai"
27
+
28
+ [dependency-groups]
29
+ test = [
30
+ "pytest>=8.4,<9",
31
+ "freezegun>=1.5,<2",
32
+ "pytest-mock>=3.14,<4",
33
+ "syrupy>=4.9,<5",
34
+ "pytest-watcher>=0.4,<1",
35
+ "pytest-asyncio>=0.21,<1",
36
+ "pytest-retry>=1.7,<2",
37
+ "pytest-socket>=0.7,<1",
38
+ "numpy>=1.26.4; python_version<'3.13'",
39
+ "numpy>=2.1.0; python_version>='3.13'",
40
+ "langchain-tests>=0.3,<1",
41
+ ]
42
+ test_integration = [
43
+ "pytest>=8.4,<9",
44
+ ]
45
+ lint = [
46
+ "ruff>=0.12.10,<1",
47
+ ]
48
+ typing = [
49
+ "mypy>=1.17.1,<2",
50
+ "types-requests>=2.31,<3",
51
+ "types-google-cloud-ndb>=2.2.0.1,<3",
52
+ "types-protobuf>=4.24.0.20240302,<5",
53
+ "numpy>=1.26.2",
54
+ ]
55
+ dev = [
56
+ "types-requests>=2.31.0,<3",
57
+ "types-google-cloud-ndb>=2.2.0.1,<3",
58
+ ]
59
+
60
+ [tool.ruff]
61
+ fix = true
62
+
63
+ [tool.ruff.lint]
64
+ select = [
65
+ "E",
66
+ "F",
67
+ "I",
68
+ ]
69
+
70
+ [tool.mypy]
71
+ disallow_untyped_defs = "True"
72
+
73
+ [tool.coverage.run]
74
+ omit = [
75
+ "tests/*",
76
+ ]
77
+
78
+ [tool.pytest.ini_options]
79
+ markers = [
80
+ "requires: mark tests as requiring a specific library",
81
+ "asyncio: mark tests as requiring asyncio",
82
+ "compile: mark placeholder test used to compile integration tests without running them",
83
+ ]
84
+ asyncio_mode = "auto"
File without changes
@@ -0,0 +1,64 @@
1
+ """
2
+ Tests configuration to be executed before tests execution.
3
+ """
4
+
5
+ from typing import List
6
+
7
+ import pytest
8
+
9
+ _RELEASE_FLAG = "release"
10
+ _GPU_FLAG = "gpu"
11
+ _LONG_FLAG = "long"
12
+ _EXTENDED_FLAG = "extended"
13
+
14
+ _PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG]
15
+
16
+
17
+ def pytest_addoption(parser: pytest.Parser) -> None:
18
+ """
19
+ Add flags accepted by pytest CLI.
20
+
21
+ Args:
22
+ parser: The pytest parser object.
23
+
24
+ Returns:
25
+
26
+ """
27
+ for flag in _PYTEST_FLAGS:
28
+ parser.addoption(
29
+ f"--{flag}", action="store_true", default=False, help=f"run {flag} tests"
30
+ )
31
+
32
+
33
+ def pytest_configure(config: pytest.Config) -> None:
34
+ """
35
+ Add pytest custom configuration.
36
+
37
+ Args:
38
+ config: The pytest config object.
39
+
40
+ Returns:
41
+ """
42
+ for flag in _PYTEST_FLAGS:
43
+ config.addinivalue_line(
44
+ "markers", f"{flag}: mark test to run as {flag} only test"
45
+ )
46
+
47
+
48
+ def pytest_collection_modifyitems(
49
+ config: pytest.Config, items: List[pytest.Item]
50
+ ) -> None:
51
+ """
52
+ Skip tests with a marker from our list that were not explicitly invoked.
53
+
54
+ Args:
55
+ config: The pytest config object.
56
+ items: The list of tests to be executed.
57
+
58
+ Returns:
59
+ """
60
+ for item in items:
61
+ keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS))
62
+ if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)):
63
+ skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run")
64
+ item.add_marker(skip)
@@ -0,0 +1 @@
1
+ PROJECT_ID=project_id
@@ -0,0 +1,12 @@
1
+ module "cloudbuild" {
2
+ source = "./../../../../../terraform/cloudbuild"
3
+
4
+ library = "genai"
5
+ project_id = ""
6
+ cloudbuildv2_repository_id = ""
7
+ cloudbuild_env_vars = {
8
+ }
9
+ cloudbuild_secret_vars = {
10
+ GOOGLE_API_KEY = ""
11
+ }
12
+ }
@@ -0,0 +1,31 @@
1
+ from typing import Any, List
2
+
3
+ from langchain_core.callbacks import BaseCallbackHandler
4
+ from langchain_core.outputs import LLMResult
5
+ from langchain_core.prompts import PromptTemplate
6
+
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+
9
+
10
+ class StreamingLLMCallbackHandler(BaseCallbackHandler):
11
+ def __init__(self, **kwargs: Any):
12
+ super().__init__(**kwargs)
13
+ self.tokens: List[Any] = []
14
+ self.generations: List[Any] = []
15
+
16
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
17
+ self.tokens.append(token)
18
+
19
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
20
+ self.generations.append(response.generations[0][0].text)
21
+
22
+
23
+ def test_streaming_callback() -> None:
24
+ prompt_template = "Tell me details about the Company {name} with 2 bullet point?"
25
+ cb = StreamingLLMCallbackHandler()
26
+ llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001", callbacks=[cb])
27
+ llm_chain = PromptTemplate.from_template(prompt_template) | llm
28
+ for t in llm_chain.stream({"name": "Google"}):
29
+ pass
30
+ assert len(cb.tokens) > 1
31
+ assert len(cb.generations) == 1