langchain-b12 0.1.9__tar.gz → 0.1.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/PKG-INFO +1 -2
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/pyproject.toml +1 -2
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/genai/genai.py +29 -53
- langchain_b12-0.1.11/tests/test_genai.py +210 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/uv.lock +2 -6
- langchain_b12-0.1.9/tests/test_genai.py +0 -124
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/.gitignore +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/.python-version +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/.vscode/extensions.json +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/Makefile +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/README.md +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/__init__.py +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/citations/citations.py +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/genai/embeddings.py +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/genai/genai_utils.py +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/py.typed +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/tests/test_citation_mixin.py +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/tests/test_citations.py +0 -0
- {langchain_b12-0.1.9 → langchain_b12-0.1.11}/tests/test_genai_utils.py +0 -0
{langchain_b12-0.1.9 → langchain_b12-0.1.11}/PKG-INFO

@@ -1,11 +1,10 @@
 Metadata-Version: 2.4
 Name: langchain-b12
-Version: 0.1.9
+Version: 0.1.11
 Summary: A reusable collection of tools and implementations for Langchain
 Author-email: Vincent Min <vincent.min@b12-consulting.com>
 Requires-Python: >=3.11
 Requires-Dist: langchain-core>=0.3.60
-Requires-Dist: tenacity>=9.1.2
 Description-Content-Type: text/markdown
 
 # Langchain B12
{langchain_b12-0.1.9 → langchain_b12-0.1.11}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "langchain-b12"
-version = "0.1.9"
+version = "0.1.11"
 description = "A reusable collection of tools and implementations for Langchain"
 readme = "README.md"
 authors = [
@@ -9,7 +9,6 @@ authors = [
 requires-python = ">=3.11"
 dependencies = [
     "langchain-core>=0.3.60",
-    "tenacity>=9.1.2",
 ]
 
 [dependency-groups]
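The tenacity dependency can be dropped because, as the genai.py diff below shows, retrying moves from a tenacity decorator to google-genai's HTTP-level retry options. A minimal sketch of such a policy, using only the `HttpRetryOptions` fields that appear in the new code and tests; the concrete values are illustrative, not taken from the release:

```python
from google.genai import types

# Sketch of the HTTP-level retry policy that replaces the old tenacity decorator.
# `attempts` is the field the new max_retries fallback populates; initial_delay
# and max_delay echo the old wait_exponential_jitter(initial=1, max=60) bounds
# (an assumption, not something the diff pins down).
retry_options = types.HttpRetryOptions(
    attempts=3,
    initial_delay=1.0,
    max_delay=60.0,
)
```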
{langchain_b12-0.1.9 → langchain_b12-0.1.11}/src/langchain_b12/genai/genai.py

@@ -35,14 +35,7 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import (
     convert_to_openai_tool,
 )
-from pydantic import BaseModel, ConfigDict, Field
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    stop_never,
-    wait_exponential_jitter,
-)
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 from langchain_b12.genai.genai_utils import (
     convert_messages_to_contents,
@@ -84,7 +77,9 @@ class ChatGenAI(BaseChatModel):
     seed: int | None = None
     """Random seed for the generation."""
     max_retries: int | None = Field(default=3)
-    """Maximum number of retries"""
+    """Maximum number of retries. Prefer `http_retry_options`, but this is kept for compatibility."""
+    http_retry_options: types.HttpRetryOptions | None = Field(default=None)
+    """HTTP retry options for API requests. If not set, max_retries will be used to create default options."""
     safety_settings: list[types.SafetySetting] | None = None
     """The default safety settings to use for all generations.
 
@@ -107,6 +102,13 @@ class ChatGenAI(BaseChatModel):
         arbitrary_types_allowed=True,
     )
 
+    @model_validator(mode="after")
+    def _setup_retry_options(self) -> "ChatGenAI":
+        """Convert max_retries to http_retry_options if not explicitly set."""
+        if self.http_retry_options is None and self.max_retries is not None:
+            self.http_retry_options = types.HttpRetryOptions(attempts=self.max_retries)
+        return self
+
     @property
     def _llm_type(self) -> str:
         return "vertexai"
@@ -183,29 +185,10 @@ class ChatGenAI(BaseChatModel):
         run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
-
-        @retry(
-            stop=stop_after_attempt(self.max_retries + 1)
-            if self.max_retries is not None
-            else stop_never,
-            wait=wait_exponential_jitter(initial=1, max=60),
-            retry=retry_if_exception_type(Exception),
-            before_sleep=lambda retry_state: logger.warning(
-                "ChatGenAI._generate failed (attempt %d/%s). "
-                "Retrying in %.2fs... Error: %s",
-                retry_state.attempt_number,
-                self.max_retries + 1 if self.max_retries is not None else "∞",
-                retry_state.next_action.sleep,
-                retry_state.outcome.exception(),
-            ),
+        stream_iter = self._stream(
+            messages, stop=stop, run_manager=run_manager, **kwargs
         )
-        def _generate_with_retry():
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return generate_from_stream(stream_iter)
-
-        return _generate_with_retry()
+        return generate_from_stream(stream_iter)
 
     async def _agenerate(
         self,
@@ -214,29 +197,10 @@ class ChatGenAI(BaseChatModel):
         run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
-
-        @retry(
-            stop=stop_after_attempt(self.max_retries + 1)
-            if self.max_retries is not None
-            else stop_never,
-            wait=wait_exponential_jitter(initial=1, max=60),
-            retry=retry_if_exception_type(Exception),
-            before_sleep=lambda retry_state: logger.warning(
-                "ChatGenAI._agenerate failed (attempt %d/%s). "
-                "Retrying in %.2fs... Error: %s",
-                retry_state.attempt_number,
-                self.max_retries + 1 if self.max_retries is not None else "∞",
-                retry_state.next_action.sleep,
-                retry_state.outcome.exception(),
-            ),
+        stream_iter = self._astream(
+            messages, stop=stop, run_manager=run_manager, **kwargs
         )
-        async def _agenerate_with_retry():
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return await agenerate_from_stream(stream_iter)
-
-        return await _agenerate_with_retry()
+        return await agenerate_from_stream(stream_iter)
 
     def _stream(
         self,
@@ -246,10 +210,16 @@ class ChatGenAI(BaseChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         system_message, contents = self._prepare_request(messages=messages)
+        http_options = (
+            types.HttpOptions(retry_options=self.http_retry_options)
+            if self.http_retry_options
+            else None
+        )
         response_iter = self.client.models.generate_content_stream(
             model=self.model_name,
             contents=contents,
             config=types.GenerateContentConfig(
+                http_options=http_options,
                 system_instruction=system_message,
                 temperature=self.temperature,
                 top_k=self.top_k,
@@ -282,10 +252,16 @@ class ChatGenAI(BaseChatModel):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         system_message, contents = self._prepare_request(messages=messages)
+        http_options = (
+            types.HttpOptions(retry_options=self.http_retry_options)
+            if self.http_retry_options
+            else None
+        )
         response_iter = self.client.aio.models.generate_content_stream(
             model=self.model_name,
             contents=contents,
             config=types.GenerateContentConfig(
+                http_options=http_options,
                 system_instruction=system_message,
                 temperature=self.temperature,
                 top_k=self.top_k,
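For callers, the new retry path is mostly transparent: `max_retries` still works and is converted into `HttpRetryOptions` by the `_setup_retry_options` validator, while an explicitly supplied `http_retry_options` takes precedence. A usage sketch mirroring the new tests below; the mocked client stands in for a configured `google.genai.Client`:

```python
from unittest.mock import MagicMock

from google.genai import Client, types
from langchain_b12.genai.genai import ChatGenAI

client = MagicMock(spec=Client)  # stand-in for a real, credentialed Client

# Compatibility path: max_retries alone is mapped to HttpRetryOptions(attempts=3)
# by the model validator.
chat = ChatGenAI(client=client, max_retries=3)
assert chat.http_retry_options is not None
assert chat.http_retry_options.attempts == 3

# Explicit path: a user-supplied policy overrides the max_retries fallback.
chat = ChatGenAI(
    client=client,
    max_retries=3,
    http_retry_options=types.HttpRetryOptions(attempts=10, initial_delay=2.0, max_delay=30.0),
)
assert chat.http_retry_options.attempts == 10
```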
langchain_b12-0.1.11/tests/test_genai.py (new file)

@@ -0,0 +1,210 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from google.genai import Client, types
+from langchain_b12.genai.genai import ChatGenAI
+from langchain_core.messages import HumanMessage
+
+
+def _make_response_chunk(text: str) -> types.GenerateContentResponse:
+    """Helper to create a response chunk."""
+    return types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(content=types.Content(parts=[types.Part(text=text)]))
+        ]
+    )
+
+
+def test_chatgenai():
+    client = MagicMock(spec=Client)
+    model = ChatGenAI(client=client, model="foo", temperature=1)
+    assert model.model_name == "foo"
+    assert model.temperature == 1
+    assert model.client == client
+
+
+def test_chatgenai_invocation():
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.return_value = iter(
+        (
+            _make_response_chunk("bar"),
+            _make_response_chunk("baz"),
+        )
+    )
+    model = ChatGenAI(client=client)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+    method: MagicMock = client.models.generate_content_stream
+    method.assert_called_once()
+    assert response.content == "barbaz"
+
+
+def test_max_retries_converts_to_http_retry_options():
+    """Test that max_retries is properly converted to HttpRetryOptions."""
+    client = MagicMock(spec=Client)
+    model = ChatGenAI(client=client, max_retries=5)
+
+    assert model.http_retry_options is not None
+    assert model.http_retry_options.attempts == 5
+
+
+def test_http_retry_options_passed_directly():
+    """Test that http_retry_options can be passed directly."""
+    client = MagicMock(spec=Client)
+    retry_options = types.HttpRetryOptions(
+        attempts=10,
+        initial_delay=2.0,
+        max_delay=30.0,
+    )
+    model = ChatGenAI(client=client, http_retry_options=retry_options)
+
+    assert model.http_retry_options == retry_options
+    assert model.http_retry_options.attempts == 10
+    assert model.http_retry_options.initial_delay == 2.0
+
+
+def test_http_retry_options_overrides_max_retries():
+    """Test that explicit http_retry_options overrides max_retries."""
+    client = MagicMock(spec=Client)
+    retry_options = types.HttpRetryOptions(attempts=7)
+    model = ChatGenAI(client=client, max_retries=3, http_retry_options=retry_options)
+
+    # http_retry_options should take precedence
+    assert model.http_retry_options == retry_options
+    assert model.http_retry_options.attempts == 7
+
+
+def test_retry_options_passed_in_stream_config():
+    """Test that retry options are passed to GenerateContentConfig."""
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.return_value = iter(
+        [_make_response_chunk("success")]
+    )
+
+    model = ChatGenAI(client=client, max_retries=5)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+
+    # Verify the config was called with http_options containing retry_options
+    call_args = client.models.generate_content_stream.call_args
+    config = call_args.kwargs["config"]
+    assert config.http_options is not None
+    assert config.http_options.retry_options is not None
+    assert config.http_options.retry_options.attempts == 5
+
+
+def test_no_retry_options_when_max_retries_none():
+    """Test that no http_retry_options are set when max_retries is None."""
+    client = MagicMock(spec=Client)
+    model = ChatGenAI(client=client, max_retries=None)
+
+    assert model.http_retry_options is None
+
+
+# --- Streaming behavior tests ---
+
+
+def test_stream_yields_chunks_immediately():
+    """Test that stream yields chunks as they arrive, not buffered."""
+    client: Client = MagicMock(spec=Client)
+    chunks_yielded: list[str] = []
+
+    def mock_stream():
+        for text in ["chunk1", "chunk2", "chunk3"]:
+            # Track when chunks are yielded from the source
+            chunks_yielded.append(f"source:{text}")
+            yield _make_response_chunk(text)
+
+    client.models.generate_content_stream.return_value = mock_stream()
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    received: list[str] = []
+    for chunk in model.stream(messages):
+        received.append(chunk.content)
+        # After receiving each chunk, check that source yielded it
+        assert len(received) == len([c for c in chunks_yielded if c.startswith("source:")])
+
+    assert received == ["chunk1", "chunk2", "chunk3"]
+
+
+def test_stream_error_propagates():
+    """Test that errors during streaming are propagated."""
+    client: Client = MagicMock(spec=Client)
+
+    def failing_stream():
+        yield _make_response_chunk("first")
+        raise Exception("Mid-stream error")
+
+    client.models.generate_content_stream.return_value = failing_stream()
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    chunks = []
+    with pytest.raises(Exception, match="Mid-stream error"):
+        for chunk in model.stream(messages):
+            chunks.append(chunk.content)
+
+    # First chunk was received before error
+    assert chunks == ["first"]
+
+
+# --- Async streaming tests ---
+
+
+async def _async_iter(items):
+    """Helper to create an async iterator from items."""
+    for item in items:
+        yield item
+
+
+@pytest.mark.asyncio
+async def test_astream_yields_chunks_immediately():
+    """Test that async stream yields chunks as they arrive."""
+    client: Client = MagicMock(spec=Client)
+
+    chunks = [
+        _make_response_chunk("async1"),
+        _make_response_chunk("async2"),
+        _make_response_chunk("async3"),
+    ]
+
+    # generate_content_stream returns a coroutine that resolves to async iterator
+    client.aio.models.generate_content_stream = AsyncMock(
+        return_value=_async_iter(chunks)
+    )
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    received: list[str] = []
+    async for chunk in model.astream(messages):
+        received.append(chunk.content)
+
+    assert received == ["async1", "async2", "async3"]
+
+
+@pytest.mark.asyncio
+async def test_astream_error_propagates():
+    """Test that errors during async streaming are propagated."""
+    client: Client = MagicMock(spec=Client)
+
+    async def failing_after_first():
+        yield _make_response_chunk("first")
+        raise Exception("Async mid-stream error")
+
+    client.aio.models.generate_content_stream = AsyncMock(
+        return_value=failing_after_first()
+    )
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    chunks = []
+    with pytest.raises(Exception, match="Async mid-stream error"):
+        async for chunk in model.astream(messages):
+            chunks.append(chunk.content)
+
+    assert chunks == ["first"]
{langchain_b12-0.1.9 → langchain_b12-0.1.11}/uv.lock

@@ -252,11 +252,10 @@ wheels = [
 
 [[package]]
 name = "langchain-b12"
-version = "0.1.
+version = "0.1.10"
 source = { editable = "." }
 dependencies = [
     { name = "langchain-core" },
-    { name = "tenacity" },
 ]
 
 [package.dev-dependencies]
@@ -273,10 +272,7 @@ google = [
 ]
 
 [package.metadata]
-requires-dist = [
-    { name = "langchain-core", specifier = ">=0.3.60" },
-    { name = "tenacity", specifier = ">=9.1.2" },
-]
+requires-dist = [{ name = "langchain-core", specifier = ">=0.3.60" }]
 
 [package.metadata.requires-dev]
 citations = [
langchain_b12-0.1.9/tests/test_genai.py (removed)

@@ -1,124 +0,0 @@
-from unittest.mock import MagicMock, patch
-
-import pytest
-from google.genai import Client, types
-from langchain_b12.genai.genai import ChatGenAI
-from langchain_core.messages import HumanMessage
-
-
-def test_chatgenai():
-    client = MagicMock(spec=Client)
-    model = ChatGenAI(client=client, model="foo", temperature=1)
-    assert model.model_name == "foo"
-    assert model.temperature == 1
-    assert model.client == client
-
-
-def test_chatgenai_invocation():
-    client: Client = MagicMock(spec=Client)
-    client.models.generate_content_stream.return_value = iter(
-        (
-            types.GenerateContentResponse(
-                candidates=[
-                    types.Candidate(
-                        content=types.Content(parts=[types.Part(text="bar")])
-                    ),
-                ]
-            ),
-            types.GenerateContentResponse(
-                candidates=[
-                    types.Candidate(
-                        content=types.Content(parts=[types.Part(text="baz")])
-                    ),
-                ]
-            ),
-        )
-    )
-    model = ChatGenAI(client=client)
-    messages = [HumanMessage(content="foo")]
-    response = model.invoke(messages)
-    method: MagicMock = client.models.generate_content_stream
-    method.assert_called_once()
-    assert response.content == "barbaz"
-
-
-def _make_success_response():
-    """Helper to create a successful streaming response."""
-    return iter(
-        [
-            types.GenerateContentResponse(
-                candidates=[
-                    types.Candidate(
-                        content=types.Content(parts=[types.Part(text="success")])
-                    ),
-                ]
-            ),
-        ]
-    )
-
-
-@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
-def test_chatgenai_retry_succeeds_after_failure(mock_wait):
-    """Test that retry logic succeeds after transient failures."""
-    client: Client = MagicMock(spec=Client)
-
-    # First two calls fail, third succeeds
-    client.models.generate_content_stream.side_effect = [
-        Exception("Transient error 1"),
-        Exception("Transient error 2"),
-        _make_success_response(),
-    ]
-
-    model = ChatGenAI(client=client, max_retries=3)
-    messages = [HumanMessage(content="foo")]
-    response = model.invoke(messages)
-
-    assert response.content == "success"
-    assert client.models.generate_content_stream.call_count == 3
-
-
-@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
-def test_chatgenai_retry_exhausted_raises(mock_wait):
-    """Test that exception is raised after all retries are exhausted."""
-    client: Client = MagicMock(spec=Client)
-
-    # All calls fail
-    client.models.generate_content_stream.side_effect = Exception("Persistent error")
-
-    model = ChatGenAI(client=client, max_retries=2)
-    messages = [HumanMessage(content="foo")]
-
-    with pytest.raises(Exception, match="Persistent error"):
-        model.invoke(messages)
-
-    # Initial attempt + 2 retries = 3 total calls
-    assert client.models.generate_content_stream.call_count == 3
-
-
-@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
-def test_chatgenai_no_retry_when_max_retries_zero(mock_wait):
-    """Test that no retries occur when max_retries=0."""
-    client: Client = MagicMock(spec=Client)
-    client.models.generate_content_stream.side_effect = Exception("Error")
-
-    model = ChatGenAI(client=client, max_retries=0)
-    messages = [HumanMessage(content="foo")]
-
-    with pytest.raises(Exception, match="Error"):
-        model.invoke(messages)
-
-    # Only 1 attempt, no retries
-    assert client.models.generate_content_stream.call_count == 1
-
-
-def test_chatgenai_no_retry_on_success():
-    """Test that no retries occur when first attempt succeeds."""
-    client: Client = MagicMock(spec=Client)
-    client.models.generate_content_stream.return_value = _make_success_response()
-
-    model = ChatGenAI(client=client, max_retries=3)
-    messages = [HumanMessage(content="foo")]
-    response = model.invoke(messages)
-
-    assert response.content == "success"
-    assert client.models.generate_content_stream.call_count == 1