langchain-google-genai 2.1.10__tar.gz → 2.1.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai might be problematic.
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/PKG-INFO +6 -15
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_function_utils.py +12 -3
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/chat_models.py +53 -28
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/embeddings.py +51 -9
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/llms.py +1 -1
- langchain_google_genai-2.1.11/pyproject.toml +84 -0
- langchain_google_genai-2.1.11/tests/__init__.py +0 -0
- langchain_google_genai-2.1.11/tests/conftest.py +64 -0
- langchain_google_genai-2.1.11/tests/integration_tests/.env.example +1 -0
- langchain_google_genai-2.1.11/tests/integration_tests/__init__.py +0 -0
- langchain_google_genai-2.1.11/tests/integration_tests/terraform/main.tf +12 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_callbacks.py +31 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_chat_models.py +887 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_compile.py +7 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_embeddings.py +145 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_function_call.py +90 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_llms.py +100 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_standard.py +141 -0
- langchain_google_genai-2.1.11/tests/integration_tests/test_tools.py +37 -0
- langchain_google_genai-2.1.11/tests/unit_tests/__init__.py +0 -0
- langchain_google_genai-2.1.11/tests/unit_tests/__snapshots__/test_standard.ambr +63 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_chat_models.py +916 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_chat_models_protobuf_fix.py +132 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_common.py +31 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_embeddings.py +158 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_function_utils.py +1406 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_genai_aqa.py +95 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_google_vector_store.py +440 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_imports.py +20 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_llms.py +47 -0
- langchain_google_genai-2.1.11/tests/unit_tests/test_standard.py +42 -0
- langchain_google_genai-2.1.10/pyproject.toml +0 -109
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/LICENSE +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/README.md +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/__init__.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_common.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_enums.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_genai_extension.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_image_utils.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/genai_aqa.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/google_vector_store.py +0 -0
- {langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/py.typed +0 -0
{langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/PKG-INFO
RENAMED
@@ -1,22 +1,14 @@
 Metadata-Version: 2.1
 Name: langchain-google-genai
-Version: 2.1.10
+Version: 2.1.11
 Summary: An integration package connecting Google's genai package and LangChain
-Home-page: https://github.com/langchain-ai/langchain-google
 License: MIT
-Requires-Python: >=3.9,<4.0
-Classifier: License :: OSI Approved :: MIT License
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
-Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: filetype (>=1.2.0,<2.0.0)
-Requires-Dist: google-ai-generativelanguage (>=0.6.18,<0.7.0)
-Requires-Dist: langchain-core (>=0.3.75,<0.4.0)
-Requires-Dist: pydantic (>=2,<3)
-Project-URL: Repository, https://github.com/langchain-ai/langchain-google
 Project-URL: Source Code, https://github.com/langchain-ai/langchain-google/tree/main/libs/genai
+Requires-Python: >=3.9
+Requires-Dist: langchain-core>=0.3.75
+Requires-Dist: google-ai-generativelanguage<1,>=0.7
+Requires-Dist: pydantic<3,>=2
+Requires-Dist: filetype<2,>=1.2
 Description-Content-Type: text/markdown

 # langchain-google-genai
@@ -258,4 +250,3 @@ print("Answerable probability:", response.answerable_probability)
 - [LangChain Documentation](https://docs.langchain.com/)
 - [Google Generative AI SDK](https://googleapis.github.io/python-genai/)
 - [Gemini Model Documentation](https://ai.google.dev/)
-
{langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/_function_utils.py
RENAMED
@@ -331,6 +331,10 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
             logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}")
             continue
         properties_item: Dict[str, Union[str, int, Dict, List]] = {}
+
+        # Get description from original schema before any modifications
+        description = v.get("description")
+
         if v.get("anyOf") and all(
             anyOf_type.get("type") != "null" for anyOf_type in v.get("anyOf", [])
         ):
@@ -338,6 +342,9 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
                 _format_json_schema_to_gapic(anyOf_type)
                 for anyOf_type in v.get("anyOf", [])
             ]
+            # For non-nullable anyOf, we still need to set a type
+            item_type_ = _get_type_from_schema(v)
+            properties_item["type_"] = item_type_
         elif v.get("type") or v.get("anyOf") or v.get("type_"):
             item_type_ = _get_type_from_schema(v)
             properties_item["type_"] = item_type_
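The three added lines matter for schemas like `{"anyOf": [{"type": "string"}, {"type": "integer"}]}`: previously only `any_of` was populated, leaving `type_` unset. A hedged sketch of the input this branch handles, assuming the private helper keeps its current import path:

```python
# Hedged sketch: _get_properties_from_schema is a private helper; its import
# path and exact output shape may change between releases.
from langchain_google_genai._function_utils import _get_properties_from_schema

schema = {
    "properties": {
        # Non-nullable anyOf: no branch is {"type": "null"}, so the new code
        # derives a concrete type_ in addition to the converted any_of list.
        "value": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
    }
}

properties = _get_properties_from_schema(schema)
print(properties["value"])  # expected to contain both "any_of" and "type_"
```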
@@ -354,7 +361,6 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
         if v.get("enum"):
             properties_item["enum"] = v["enum"]

-        description = v.get("description")
         if description and isinstance(description, str):
             properties_item["description"] = description

@@ -377,8 +383,9 @@ def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
                 properties_item["required"] = [
                     k for k, v in v_properties.items() if "default" not in v
                 ]
-
-            #
+            elif not v.get("additionalProperties"):
+                # Only provide dummy type for object without properties AND without
+                # additionalProperties
                 properties_item["type_"] = glm.Type.STRING

         if k == "title" and "description" not in properties_item:
@@ -414,6 +421,8 @@ def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
             items["nullable"] = True
         if "required" in schema:
             items["required"] = schema["required"]
+        if "enum" in schema:
+            items["enum"] = schema["enum"]
     else:
         # str
         items["type_"] = _get_type_from_schema({"type": schema})
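The new branch stops dropping `enum` values on array item schemas. A hedged sketch, again assuming the private helper's current import path:

```python
# Hedged sketch: _get_items_from_schema is private API.
from langchain_google_genai._function_utils import _get_items_from_schema

# Item schema of an array constrained to an enum, e.g. from a JSON schema
# like {"type": "array", "items": {"type": "string", "enum": [...]}}.
items = _get_items_from_schema({"type": "string", "enum": ["celsius", "fahrenheit"]})
print(items)  # per the hunk above, "enum" is now carried through
```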
{langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/chat_models.py
RENAMED
@@ -92,13 +92,7 @@ from langchain_core.utils.function_calling import (
 )
 from langchain_core.utils.pydantic import is_basemodel_subclass
 from langchain_core.utils.utils import _build_model_kwargs
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    SecretStr,
-    model_validator,
-)
+from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from pydantic.v1 import BaseModel as BaseModelV1
 from tenacity import (
     before_sleep_log,
@@ -496,7 +490,12 @@ def _parse_chat_history(
     messages: List[Content] = []

     if convert_system_message_to_human:
-        warnings.warn("Convert_system_message_to_human will be deprecated!")
+        warnings.warn(
+            "The 'convert_system_message_to_human' parameter is deprecated and will be "
+            "removed in a future version. Use system instructions instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )

     system_instruction: Optional[Content] = None
     messages_without_tool_messages = [
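The warning's suggested replacement is ordinary system messages, which this integration sends to Gemini as native system instructions; a minimal sketch (the model name is only an example):

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_google_genai import ChatGoogleGenerativeAI

# No convert_system_message_to_human flag: the SystemMessage is delivered
# to the API as a system instruction rather than being folded into the
# first human turn.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
response = llm.invoke(
    [
        SystemMessage(content="Answer in exactly one sentence."),
        HumanMessage(content="What is LangChain?"),
    ]
)
print(response.content)
```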
@@ -659,9 +658,15 @@ def _parse_response_candidate(
             function_call = {"name": part.function_call.name}
             # dump to match other function calling llm for now
             function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
-            function_call["arguments"] = json.dumps(
-                {k: function_call_args_dict[k] for k in function_call_args_dict}
-            )
+
+            # Fix: Correct integer-like floats from protobuf conversion
+            # The protobuf library sometimes converts integers to floats
+            corrected_args = {
+                k: int(v) if isinstance(v, float) and v.is_integer() else v
+                for k, v in function_call_args_dict.items()
+            }
+
+            function_call["arguments"] = json.dumps(corrected_args)
             additional_kwargs["function_call"] = function_call

         if streaming:
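Background for the fix: proto3's struct/JSON mapping types generic numbers as doubles, so a tool call emitted with `{"n": 2}` round-trips through `proto.Message.to_dict` as `{"n": 2.0}`. The comprehension is self-contained and can be exercised on its own:

```python
import json

# What proto.Message.to_dict tends to hand back: integers become floats.
function_call_args_dict = {"n": 2.0, "temperature": 0.5, "label": "x"}

corrected_args = {
    k: int(v) if isinstance(v, float) and v.is_integer() else v
    for k, v in function_call_args_dict.items()
}

# Integer-like floats are narrowed; true fractions and non-floats pass through.
assert json.dumps(corrected_args) == '{"n": 2, "temperature": 0.5, "label": "x"}'

# Caveat inherent to the heuristic: a genuine float that happens to be whole
# (e.g. 7.0) is also narrowed to 7.
```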
@@ -819,7 +824,15 @@ def _response_to_result(
     if stream:
         generations = [
             ChatGenerationChunk(
-                message=AIMessageChunk(content=""), generation_info={}
+                message=AIMessageChunk(
+                    content="",
+                    response_metadata={
+                        "prompt_feedback": proto.Message.to_dict(
+                            response.prompt_feedback
+                        )
+                    },
+                ),
+                generation_info={},
             )
         ]
     else:
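For consumers, the prompt feedback is now reachable on streamed chunks instead of being dropped. A hedged usage sketch (whether a given stream carries feedback depends on the response; the model name is an example):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")

for chunk in llm.stream("Tell me a short joke"):
    # Populated from proto.Message.to_dict(response.prompt_feedback) per the
    # hunk above; absent on chunks that don't carry feedback.
    feedback = chunk.response_metadata.get("prompt_feedback")
    if feedback:
        print("prompt feedback:", feedback)
```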
@@ -1319,6 +1332,12 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     ``cachedContents/{cachedContent}``.
     """

+    stop: Optional[List[str]] = None
+    """Stop sequences for the model."""
+
+    streaming: Optional[bool] = None
+    """Whether to stream responses from the model."""
+
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any unexpected initialization parameters."""

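Declaring `stop` and `streaming` as real fields means they no longer fall through to `model_kwargs` and its unexpected-parameter handling. A hedged construction example (field names per the hunk above; their downstream handling is not shown in this diff):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",  # example model name
    stop=["\n\n"],
    streaming=True,
)
```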
@@ -1530,18 +1549,21 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
                 "response_modalities": self.response_modalities,
                 "thinking_config": (
                     (
-                        {"thinking_budget": self.thinking_budget}
-                        if self.thinking_budget is not None
-                        else {}
-                    )
-                    | (
-                        {"include_thoughts": self.include_thoughts}
-                        if self.include_thoughts is not None
-                        else {}
+                        (
+                            {"thinking_budget": self.thinking_budget}
+                            if self.thinking_budget is not None
+                            else {}
+                        )
+                        | (
+                            {"include_thoughts": self.include_thoughts}
+                            if self.include_thoughts is not None
+                            else {}
+                        )
                     )
-                    if self.thinking_budget is not None
-                    else None
-                ),
+                    if self.thinking_budget is not None
+                    or self.include_thoughts is not None
+                    else None
+                ),
             }.items()
             if v is not None
         }
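The rewritten condition emits a `thinking_config` when either field is set, where the removed branch appears to have keyed only off `thinking_budget`, so `include_thoughts` alone was previously dropped. The merge itself is plain dict union and can be sketched standalone (Python 3.9+ for `|`):

```python
from typing import Optional


def build_thinking_config(
    thinking_budget: Optional[int], include_thoughts: Optional[bool]
) -> Optional[dict]:
    # Mirrors the hunk above: each field contributes a one-key dict when set,
    # and the union is emitted if at least one field is set.
    return (
        ({"thinking_budget": thinking_budget} if thinking_budget is not None else {})
        | (
            {"include_thoughts": include_thoughts}
            if include_thoughts is not None
            else {}
        )
        if thinking_budget is not None or include_thoughts is not None
        else None
    )


assert build_thinking_config(None, None) is None
assert build_thinking_config(None, True) == {"include_thoughts": True}
assert build_thinking_config(1024, True) == {
    "thinking_budget": 1024,
    "include_thoughts": True,
}
```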
@@ -1783,7 +1805,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         generation_config: Optional[Dict[str, Any]] = None,
         cached_content: Optional[str] = None,
         **kwargs: Any,
-    ) ->
+    ) -> GenerateContentRequest:
         if tool_choice and tool_config:
             raise ValueError(
                 "Must specify at most one of tool_choice and tool_config, received "
@@ -1809,10 +1831,13 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             filtered_messages.append(message)
         messages = filtered_messages

-        system_instruction, history = _parse_chat_history(
-            messages,
-            convert_system_message_to_human=self.convert_system_message_to_human,
-        )
+        if self.convert_system_message_to_human:
+            system_instruction, history = _parse_chat_history(
+                messages,
+                convert_system_message_to_human=self.convert_system_message_to_human,
+            )
+        else:
+            system_instruction, history = _parse_chat_history(messages)
         if tool_choice:
             if not formatted_tools:
                 msg = (
{langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/embeddings.py
RENAMED
@@ -1,3 +1,4 @@
+import asyncio
 import re
 import string
 from typing import Any, Dict, List, Optional
@@ -26,6 +27,14 @@ _MAX_TOKENS_PER_BATCH = 20000
 _DEFAULT_BATCH_SIZE = 100


+def _is_event_loop_running() -> bool:
+    try:
+        asyncio.get_running_loop()
+        return True
+    except RuntimeError:
+        return False
+
+
 class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
     """`Google Generative AI Embeddings`.

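The helper is standard-library only and easy to sanity-check in isolation:

```python
import asyncio


def _is_event_loop_running() -> bool:
    # Same body as the hunk above: get_running_loop raises RuntimeError
    # when called outside a running event loop.
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        return False


assert _is_event_loop_running() is False  # plain synchronous context


async def main() -> None:
    assert _is_event_loop_running() is True  # inside asyncio.run


asyncio.run(main())
```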
@@ -107,15 +116,48 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
             client_options=self.client_options,
             transport=self.transport,
         )
-        self.async_client = build_generative_async_service(
-            credentials=self.credentials,
-            api_key=google_api_key,
-            client_info=client_info,
-            client_options=self.client_options,
-            transport=self.transport,
-        )
+        # Only initialize async client if there's an event loop running
+        # to avoid RuntimeError during synchronous initialization
+        if _is_event_loop_running():
+            # async clients don't support "rest" transport
+            transport = self.transport
+            if transport == "rest":
+                transport = "grpc_asyncio"
+            self.async_client = build_generative_async_service(
+                credentials=self.credentials,
+                api_key=google_api_key,
+                client_info=client_info,
+                client_options=self.client_options,
+                transport=transport,
+            )
+        else:
+            self.async_client = None
         return self

+    @property
+    def _async_client(self) -> Any:
+        """Get or create the async client when needed."""
+        if self.async_client is None:
+            if isinstance(self.google_api_key, SecretStr):
+                google_api_key: Optional[str] = self.google_api_key.get_secret_value()
+            else:
+                google_api_key = self.google_api_key
+
+            client_info = get_client_info("GoogleGenerativeAIEmbeddings")
+            # async clients don't support "rest" transport
+            transport = self.transport
+            if transport == "rest":
+                transport = "grpc_asyncio"
+
+            self.async_client = build_generative_async_service(
+                credentials=self.credentials,
+                api_key=google_api_key,
+                client_info=client_info,
+                client_options=self.client_options,
+                transport=transport,
+            )
+        return self.async_client
+
     @staticmethod
     def _split_by_punctuation(text: str) -> List[str]:
         """Splits a string by punctuation and whitespace characters."""
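Net effect: constructing the embeddings object from synchronous code no longer tries to build a gRPC asyncio client (the source of the RuntimeError this release fixes); the `_async_client` property builds it on first use. A hedged usage sketch (assumes GOOGLE_API_KEY is set; the model name is an example):

```python
import asyncio

from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Safe in a plain script now: no event loop is running at construction time,
# so async_client stays None until it is actually needed.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")


async def main() -> None:
    # The first awaited call goes through the _async_client property, which
    # lazily builds the grpc_asyncio client.
    vectors = await embeddings.aembed_documents(["hello", "world"])
    print(len(vectors), len(vectors[0]))


asyncio.run(main())
```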
@@ -328,7 +370,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
         ]

         try:
-            result = await self.async_client.batch_embed_contents(
+            result = await self._async_client.batch_embed_contents(
                 BatchEmbedContentsRequest(requests=requests, model=self.model)
             )
         except Exception as e:
@@ -366,7 +408,7 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
                 title=title,
                 output_dimensionality=output_dimensionality,
             )
-            result: EmbedContentResponse = await self.async_client.embed_content(
+            result: EmbedContentResponse = await self._async_client.embed_content(
                 request
             )
         except Exception as e:
{langchain_google_genai-2.1.10 → langchain_google_genai-2.1.11}/langchain_google_genai/llms.py
RENAMED
@@ -41,7 +41,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
         """Needed for arg validation."""
         # Get all valid field names, including aliases
         valid_fields = set()
-        for field_name, field_info in self.model_fields.items():
+        for field_name, field_info in self.__class__.model_fields.items():
             valid_fields.add(field_name)
             if hasattr(field_info, "alias") and field_info.alias is not None:
                 valid_fields.add(field_info.alias)
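This small change tracks a Pydantic deprecation: accessing `model_fields` on an instance warns as of Pydantic 2.11 and is slated for removal, so the lookup moves to the class. A standalone reproduction of the pattern (the example model is hypothetical):

```python
from pydantic import BaseModel, Field


class ExampleModel(BaseModel):
    # Hypothetical stand-in for the LLM's fields; the alias is what the
    # loop collects in addition to the field name.
    model_name: str = Field(default="gemini-pro", alias="model")


m = ExampleModel()

# m.model_fields          -> deprecated instance access (warns on 2.11+)
# type(m).model_fields    -> supported, equivalent to self.__class__ above
valid_fields = set()
for field_name, field_info in type(m).model_fields.items():
    valid_fields.add(field_name)
    if field_info.alias is not None:
        valid_fields.add(field_info.alias)

assert valid_fields == {"model_name", "model"}
```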
langchain_google_genai-2.1.11/pyproject.toml
ADDED
@@ -0,0 +1,84 @@
+[build-system]
+requires = [
+    "pdm-backend",
+]
+build-backend = "pdm.backend"
+
+[project]
+name = "langchain-google-genai"
+version = "2.1.11"
+description = "An integration package connecting Google's genai package and LangChain"
+authors = []
+requires-python = ">=3.9"
+readme = "README.md"
+repository = "https://github.com/langchain-ai/langchain-google"
+dependencies = [
+    "langchain-core>=0.3.75",
+    "google-ai-generativelanguage>=0.7,<1",
+    "pydantic>=2,<3",
+    "filetype>=1.2,<2",
+]
+
+[project.license]
+text = "MIT"
+
+[project.urls]
+"Source Code" = "https://github.com/langchain-ai/langchain-google/tree/main/libs/genai"
+
+[dependency-groups]
+test = [
+    "pytest>=8.4,<9",
+    "freezegun>=1.5,<2",
+    "pytest-mock>=3.14,<4",
+    "syrupy>=4.9,<5",
+    "pytest-watcher>=0.4,<1",
+    "pytest-asyncio>=0.21,<1",
+    "pytest-retry>=1.7,<2",
+    "pytest-socket>=0.7,<1",
+    "numpy>=1.26.4; python_version<'3.13'",
+    "numpy>=2.1.0; python_version>='3.13'",
+    "langchain-tests>=0.3,<1",
+]
+test_integration = [
+    "pytest>=8.4,<9",
+]
+lint = [
+    "ruff>=0.12.10,<1",
+]
+typing = [
+    "mypy>=1.17.1,<2",
+    "types-requests>=2.31,<3",
+    "types-google-cloud-ndb>=2.2.0.1,<3",
+    "types-protobuf>=4.24.0.20240302,<5",
+    "numpy>=1.26.2",
+]
+dev = [
+    "types-requests>=2.31.0,<3",
+    "types-google-cloud-ndb>=2.2.0.1,<3",
+]
+
+[tool.ruff]
+fix = true
+
+[tool.ruff.lint]
+select = [
+    "E",
+    "F",
+    "I",
+]
+
+[tool.mypy]
+disallow_untyped_defs = "True"
+
+[tool.coverage.run]
+omit = [
+    "tests/*",
+]
+
+[tool.pytest.ini_options]
+markers = [
+    "requires: mark tests as requiring a specific library",
+    "asyncio: mark tests as requiring asyncio",
+    "compile: mark placeholder test used to compile integration tests without running them",
+]
+asyncio_mode = "auto"
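The markers registered under `[tool.pytest.ini_options]` match how the test suite gates work; for example, `tests/integration_tests/test_compile.py` (added in this release per the file list) presumably carries the `compile` marker. A hypothetical sketch of how such a marker is used and selected:

```python
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Placeholder used to compile (import) integration tests without
    actually calling the live API."""


# Select only these tests from the command line:
#     pytest -m compile tests/integration_tests
```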
langchain_google_genai-2.1.11/tests/__init__.py
File without changes
langchain_google_genai-2.1.11/tests/conftest.py
ADDED
@@ -0,0 +1,64 @@
+"""
+Tests configuration to be executed before tests execution.
+"""
+
+from typing import List
+
+import pytest
+
+_RELEASE_FLAG = "release"
+_GPU_FLAG = "gpu"
+_LONG_FLAG = "long"
+_EXTENDED_FLAG = "extended"
+
+_PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG]
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    """
+    Add flags accepted by pytest CLI.
+
+    Args:
+        parser: The pytest parser object.
+
+    Returns:
+
+    """
+    for flag in _PYTEST_FLAGS:
+        parser.addoption(
+            f"--{flag}", action="store_true", default=False, help=f"run {flag} tests"
+        )
+
+
+def pytest_configure(config: pytest.Config) -> None:
+    """
+    Add pytest custom configuration.
+
+    Args:
+        config: The pytest config object.
+
+    Returns:
+    """
+    for flag in _PYTEST_FLAGS:
+        config.addinivalue_line(
+            "markers", f"{flag}: mark test to run as {flag} only test"
+        )
+
+
+def pytest_collection_modifyitems(
+    config: pytest.Config, items: List[pytest.Item]
+) -> None:
+    """
+    Skip tests with a marker from our list that were not explicitly invoked.
+
+    Args:
+        config: The pytest config object.
+        items: The list of tests to be executed.
+
+    Returns:
+    """
+    for item in items:
+        keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS))
+        if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)):
+            skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run")
+            item.add_marker(skip)
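With this conftest, any test carrying one of the four custom markers is collected but skipped unless the matching command-line flag is supplied. A hypothetical example test:

```python
import pytest


@pytest.mark.extended
def test_long_running_sweep() -> None:
    # Skipped by pytest_collection_modifyitems with reason
    # "need --extended option to run" unless invoked as:
    #     pytest --extended
    ...
```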
langchain_google_genai-2.1.11/tests/integration_tests/.env.example
ADDED
@@ -0,0 +1 @@
+PROJECT_ID=project_id

langchain_google_genai-2.1.11/tests/integration_tests/__init__.py
File without changes
langchain_google_genai-2.1.11/tests/integration_tests/test_callbacks.py
ADDED
@@ -0,0 +1,31 @@
+from typing import Any, List
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+from langchain_core.prompts import PromptTemplate
+
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+
+class StreamingLLMCallbackHandler(BaseCallbackHandler):
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        self.tokens: List[Any] = []
+        self.generations: List[Any] = []
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        self.tokens.append(token)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
+        self.generations.append(response.generations[0][0].text)
+
+
+def test_streaming_callback() -> None:
+    prompt_template = "Tell me details about the Company {name} with 2 bullet point?"
+    cb = StreamingLLMCallbackHandler()
+    llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001", callbacks=[cb])
+    llm_chain = PromptTemplate.from_template(prompt_template) | llm
+    for t in llm_chain.stream({"name": "Google"}):
+        pass
+    assert len(cb.tokens) > 1
+    assert len(cb.generations) == 1