langchain-b12 0.1.8__tar.gz → 0.1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/PKG-INFO +3 -1
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/pyproject.toml +3 -1
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/src/langchain_b12/genai/genai.py +134 -78
- langchain_b12-0.1.10/tests/test_genai.py +279 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/uv.lock +21 -2
- langchain_b12-0.1.8/tests/test_genai.py +0 -41
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/.gitignore +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/.python-version +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/.vscode/extensions.json +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/Makefile +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/README.md +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/src/langchain_b12/__init__.py +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/src/langchain_b12/citations/citations.py +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/src/langchain_b12/genai/embeddings.py +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/src/langchain_b12/genai/genai_utils.py +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/src/langchain_b12/py.typed +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/tests/test_citation_mixin.py +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/tests/test_citations.py +0 -0
- {langchain_b12-0.1.8 → langchain_b12-0.1.10}/tests/test_genai_utils.py +0 -0
--- langchain_b12-0.1.8/PKG-INFO
+++ langchain_b12-0.1.10/PKG-INFO
@@ -1,10 +1,12 @@
 Metadata-Version: 2.4
 Name: langchain-b12
-Version: 0.1.8
+Version: 0.1.10
 Summary: A reusable collection of tools and implementations for Langchain
 Author-email: Vincent Min <vincent.min@b12-consulting.com>
 Requires-Python: >=3.11
 Requires-Dist: langchain-core>=0.3.60
+Requires-Dist: pytest-anyio>=0.0.0
+Requires-Dist: tenacity>=9.1.2
 Description-Content-Type: text/markdown
 
 # Langchain B12

--- langchain_b12-0.1.8/pyproject.toml
+++ langchain_b12-0.1.10/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "langchain-b12"
-version = "0.1.8"
+version = "0.1.10"
 description = "A reusable collection of tools and implementations for Langchain"
 readme = "README.md"
 authors = [
@@ -9,6 +9,8 @@ authors = [
 requires-python = ">=3.11"
 dependencies = [
     "langchain-core>=0.3.60",
+    "pytest-anyio>=0.0.0",
+    "tenacity>=9.1.2",
 ]
 
 [dependency-groups]

--- langchain_b12-0.1.8/src/langchain_b12/genai/genai.py
+++ langchain_b12-0.1.10/src/langchain_b12/genai/genai.py
@@ -7,10 +7,6 @@ from typing import Any, Literal, cast
 from google import genai
 from google.genai import types
 from google.oauth2 import service_account
-from langchain_b12.genai.genai_utils import (
-    convert_messages_to_contents,
-    parse_response_candidate,
-)
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -40,6 +36,18 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_tool,
 )
 from pydantic import BaseModel, ConfigDict, Field
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    stop_never,
+    wait_exponential_jitter,
+)
+
+from langchain_b12.genai.genai_utils import (
+    convert_messages_to_contents,
+    parse_response_candidate,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -76,7 +84,7 @@ class ChatGenAI(BaseChatModel):
     seed: int | None = None
     """Random seed for the generation."""
     max_retries: int | None = Field(default=3)
-    """Maximum number of retries when generation fails. None
+    """Maximum number of retries when generation fails. None retries indefinitely."""
     safety_settings: list[types.SafetySetting] | None = None
     """The default safety settings to use for all generations.
 
@@ -175,24 +183,10 @@ class ChatGenAI(BaseChatModel):
         run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
-
-
-
-
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return generate_from_stream(stream_iter)
-            except Exception as e:  # noqa: BLE001
-                if self.max_retries is None or attempts >= self.max_retries:
-                    raise
-                attempts += 1
-                logger.warning(
-                    "ChatGenAI._generate failed (attempt %d/%d). "
-                    "Retrying... Error: %s",
-                    attempts,
-                    self.max_retries,
-                    e,
-                )
+        stream_iter = self._stream(
+            messages, stop=stop, run_manager=run_manager, **kwargs
+        )
+        return generate_from_stream(stream_iter)
 
     async def _agenerate(
         self,
@@ -201,24 +195,10 @@ class ChatGenAI(BaseChatModel):
         run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
-
-
-
-
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-                return await agenerate_from_stream(stream_iter)
-            except Exception as e:  # noqa: BLE001
-                if self.max_retries is None or attempts >= self.max_retries:
-                    raise
-                attempts += 1
-                logger.warning(
-                    "ChatGenAI._agenerate failed (attempt %d/%d). "
-                    "Retrying... Error: %s",
-                    attempts,
-                    self.max_retries,
-                    e,
-                )
+        stream_iter = self._astream(
+            messages, stop=stop, run_manager=run_manager, **kwargs
+        )
+        return await agenerate_from_stream(stream_iter)
 
     def _stream(
         self,
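
With the retry loop removed, `_generate` and `_agenerate` become thin wrappers: they build the streaming iterator and hand it to langchain-core's `generate_from_stream` / `agenerate_from_stream`, which fold the chunks into a single `ChatResult`. Roughly, that folding works like the following simplified sketch (an illustration of the idea, not the langchain-core implementation):

```python
from collections.abc import Iterable

from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


def fold_stream(chunks: Iterable[ChatGenerationChunk]) -> ChatResult:
    """Simplified stand-in for langchain_core's generate_from_stream."""
    merged: ChatGenerationChunk | None = None
    for chunk in chunks:
        # ChatGenerationChunk supports "+", which concatenates message content
        # and merges metadata.
        merged = chunk if merged is None else merged + chunk
    if merged is None:
        raise ValueError("stream produced no chunks")
    return ChatResult(generations=[ChatGeneration(message=merged.message)])
```

Because of this delegation, all retry behavior now lives in `_stream` / `_astream` below.
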
@@ -228,26 +208,64 @@ class ChatGenAI(BaseChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         system_message, contents = self._prepare_request(messages=messages)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                ),
-                **kwargs,
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries + 1)
+            if self.max_retries is not None
+            else stop_never,
+            wait=wait_exponential_jitter(initial=1, max=60),
+            retry=retry_if_exception_type(Exception),
+            before_sleep=lambda retry_state: logger.warning(
+                "ChatGenAI._stream failed to start (attempt %d/%s). "
+                "Retrying in %.2fs... Error: %s",
+                retry_state.attempt_number,
+                self.max_retries + 1 if self.max_retries is not None else "∞",
+                retry_state.next_action.sleep,
+                retry_state.outcome.exception(),
             ),
         )
-
+        def _initiate_stream() -> tuple[
+            ChatGenerationChunk,
+            Iterator[types.GenerateContentResponse],
+            UsageMetadata | None,
+        ]:
+            """Initialize stream and fetch first chunk. Retries only apply here."""
+            response_iter = self.client.models.generate_content_stream(
+                model=self.model_name,
+                contents=contents,
+                config=types.GenerateContentConfig(
+                    system_instruction=system_message,
+                    temperature=self.temperature,
+                    top_k=self.top_k,
+                    top_p=self.top_p,
+                    max_output_tokens=self.max_output_tokens,
+                    candidate_count=self.n,
+                    stop_sequences=stop or self.stop,
+                    safety_settings=self.safety_settings,
+                    thinking_config=self.thinking_config,
+                    automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                        disable=True,
+                    ),
+                    **kwargs,
+                ),
+            )
+            # Fetch first chunk to ensure connection is established
+            first_response = next(iter(response_iter))
+            first_chunk, total_usage = self._gemini_chunk_to_generation_chunk(
+                first_response, prev_total_usage=None
+            )
+            return first_chunk, response_iter, total_usage
+
+        # Retry only covers stream initialization and first chunk
+        first_chunk, response_iter, total_lc_usage = _initiate_stream()
+
+        # Yield first chunk
+        if run_manager and isinstance(first_chunk.message.content, str):
+            run_manager.on_llm_new_token(first_chunk.message.content)
+        yield first_chunk
+
+        # Continue streaming without retry (retries during streaming are not well defined)
         for response_chunk in response_iter:
             chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
                 response_chunk, prev_total_usage=total_lc_usage
@@ -264,27 +282,65 @@ class ChatGenAI(BaseChatModel):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         system_message, contents = self._prepare_request(messages=messages)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                ),
-                **kwargs,
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries + 1)
+            if self.max_retries is not None
+            else stop_never,
+            wait=wait_exponential_jitter(initial=1, max=60),
+            retry=retry_if_exception_type(Exception),
+            before_sleep=lambda retry_state: logger.warning(
+                "ChatGenAI._astream failed to start (attempt %d/%s). "
+                "Retrying in %.2fs... Error: %s",
+                retry_state.attempt_number,
+                self.max_retries + 1 if self.max_retries is not None else "∞",
+                retry_state.next_action.sleep,
+                retry_state.outcome.exception(),
             ),
         )
-
-
+        async def _initiate_stream() -> tuple[
+            ChatGenerationChunk,
+            AsyncIterator[types.GenerateContentResponse],
+            UsageMetadata | None,
+        ]:
+            """Initialize stream and fetch first chunk. Retries only apply here."""
+            response_iter = await self.client.aio.models.generate_content_stream(
+                model=self.model_name,
+                contents=contents,
+                config=types.GenerateContentConfig(
+                    system_instruction=system_message,
+                    temperature=self.temperature,
+                    top_k=self.top_k,
+                    top_p=self.top_p,
+                    max_output_tokens=self.max_output_tokens,
+                    candidate_count=self.n,
+                    stop_sequences=stop or self.stop,
+                    safety_settings=self.safety_settings,
+                    thinking_config=self.thinking_config,
+                    automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                        disable=True,
+                    ),
+                    **kwargs,
+                ),
+            )
+            # Fetch first chunk to ensure connection is established
+            first_response = await response_iter.__anext__()
+            first_chunk, total_usage = self._gemini_chunk_to_generation_chunk(
+                first_response, prev_total_usage=None
+            )
+            return first_chunk, response_iter, total_usage
+
+        # Retry only covers stream initialization and first chunk
+        first_chunk, response_iter, total_lc_usage = await _initiate_stream()
+
+        # Yield first chunk
+        if run_manager and isinstance(first_chunk.message.content, str):
+            await run_manager.on_llm_new_token(first_chunk.message.content)
+        yield first_chunk
+
+        # Continue streaming without retry (retries during streaming are not well defined)
+        async for response_chunk in response_iter:
             chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
                 response_chunk, prev_total_usage=total_lc_usage
             )
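
The key point of the `_stream` / `_astream` rewrite above is that tenacity's `@retry` wraps only the nested `_initiate_stream` helper, i.e. opening the connection and pulling the first chunk; once the first chunk has been yielded, iteration continues outside the retried function, so mid-stream failures propagate instead of restarting the stream. A minimal self-contained sketch of the same pattern (the `flaky_stream` source and the retry parameters are illustrative, not taken from the package):

```python
from collections.abc import Iterator

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

calls = 0


def flaky_stream() -> Iterator[str]:
    """Illustrative source that fails to produce its first chunk on the first call."""
    global calls
    calls += 1
    if calls == 1:
        raise RuntimeError("transient error while opening the stream")
    yield from ("chunk1", "chunk2", "chunk3")


def stream_with_retry(max_retries: int = 3) -> Iterator[str]:
    @retry(
        reraise=True,
        stop=stop_after_attempt(max_retries + 1),
        wait=wait_exponential_jitter(initial=0.1, max=1),
        retry=retry_if_exception_type(Exception),
    )
    def _initiate() -> tuple[str, Iterator[str]]:
        # Only opening the stream and fetching the first chunk is retried.
        it = flaky_stream()
        return next(it), it

    first, rest = _initiate()
    yield first
    # From here on, errors propagate: mid-stream retries are not attempted.
    yield from rest


if __name__ == "__main__":
    print(list(stream_with_retry()))  # ['chunk1', 'chunk2', 'chunk3']
```
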
--- /dev/null
+++ langchain_b12-0.1.10/tests/test_genai.py
@@ -0,0 +1,279 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from google.genai import Client, types
+from langchain_b12.genai.genai import ChatGenAI
+from langchain_core.messages import HumanMessage
+
+
+def _make_response_chunk(text: str) -> types.GenerateContentResponse:
+    """Helper to create a response chunk."""
+    return types.GenerateContentResponse(
+        candidates=[
+            types.Candidate(content=types.Content(parts=[types.Part(text=text)]))
+        ]
+    )
+
+
+def test_chatgenai():
+    client = MagicMock(spec=Client)
+    model = ChatGenAI(client=client, model="foo", temperature=1)
+    assert model.model_name == "foo"
+    assert model.temperature == 1
+    assert model.client == client
+
+
+def test_chatgenai_invocation():
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.return_value = iter(
+        (
+            _make_response_chunk("bar"),
+            _make_response_chunk("baz"),
+        )
+    )
+    model = ChatGenAI(client=client)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+    method: MagicMock = client.models.generate_content_stream
+    method.assert_called_once()
+    assert response.content == "barbaz"
+
+
+def _make_success_iter():
+    """Helper to create a successful streaming iterator."""
+    return iter([_make_response_chunk("success")])
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_chatgenai_retry_succeeds_after_failure(mock_wait):
+    """Test that retry logic succeeds after transient failures."""
+    client: Client = MagicMock(spec=Client)
+
+    # First two calls fail, third succeeds
+    client.models.generate_content_stream.side_effect = [
+        Exception("Transient error 1"),
+        Exception("Transient error 2"),
+        _make_success_iter(),
+    ]
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+
+    assert response.content == "success"
+    assert client.models.generate_content_stream.call_count == 3
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_chatgenai_retry_exhausted_raises(mock_wait):
+    """Test that exception is raised after all retries are exhausted."""
+    client: Client = MagicMock(spec=Client)
+
+    # All calls fail
+    client.models.generate_content_stream.side_effect = Exception("Persistent error")
+
+    model = ChatGenAI(client=client, max_retries=2)
+    messages = [HumanMessage(content="foo")]
+
+    with pytest.raises(Exception, match="Persistent error"):
+        model.invoke(messages)
+
+    # Initial attempt + 2 retries = 3 total calls
+    assert client.models.generate_content_stream.call_count == 3
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_chatgenai_no_retry_when_max_retries_zero(mock_wait):
+    """Test that no retries occur when max_retries=0."""
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.side_effect = Exception("Error")
+
+    model = ChatGenAI(client=client, max_retries=0)
+    messages = [HumanMessage(content="foo")]
+
+    with pytest.raises(Exception, match="Error"):
+        model.invoke(messages)
+
+    # Only 1 attempt, no retries
+    assert client.models.generate_content_stream.call_count == 1
+
+
+def test_chatgenai_no_retry_on_success():
+    """Test that no retries occur when first attempt succeeds."""
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.return_value = _make_success_iter()
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+
+    assert response.content == "success"
+    assert client.models.generate_content_stream.call_count == 1
+
+
+# --- Streaming behavior tests ---
+
+
+def test_stream_yields_chunks_immediately():
+    """Test that stream yields chunks as they arrive, not buffered."""
+    client: Client = MagicMock(spec=Client)
+    chunks_yielded: list[str] = []
+
+    def mock_stream():
+        for text in ["chunk1", "chunk2", "chunk3"]:
+            # Track when chunks are yielded from the source
+            chunks_yielded.append(f"source:{text}")
+            yield _make_response_chunk(text)
+
+    client.models.generate_content_stream.return_value = mock_stream()
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    received: list[str] = []
+    for chunk in model.stream(messages):
+        received.append(chunk.content)
+        # After receiving each chunk, check that source yielded it
+        assert len(received) == len([c for c in chunks_yielded if c.startswith("source:")])
+
+    assert received == ["chunk1", "chunk2", "chunk3"]
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_stream_no_retry_after_first_chunk(mock_wait):
+    """Test that errors after first chunk are NOT retried."""
+    client: Client = MagicMock(spec=Client)
+
+    def failing_after_first():
+        yield _make_response_chunk("first")
+        raise Exception("Mid-stream error")
+
+    client.models.generate_content_stream.return_value = failing_after_first()
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    chunks = []
+    with pytest.raises(Exception, match="Mid-stream error"):
+        for chunk in model.stream(messages):
+            chunks.append(chunk.content)
+
+    # First chunk was received
+    assert chunks == ["first"]
+    # Only one call - no retry after first chunk
+    assert client.models.generate_content_stream.call_count == 1
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_stream_retry_on_first_chunk_failure(mock_wait):
+    """Test that failure on first chunk triggers retry."""
+    client: Client = MagicMock(spec=Client)
+
+    def fail_on_first_next():
+        raise Exception("First chunk error")
+        yield  # Make it a generator
+
+    def success_stream():
+        yield _make_response_chunk("success1")
+        yield _make_response_chunk("success2")
+
+    client.models.generate_content_stream.side_effect = [
+        fail_on_first_next(),
+        success_stream(),
+    ]
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    chunks = [chunk.content for chunk in model.stream(messages)]
+    assert chunks == ["success1", "success2"]
+    assert client.models.generate_content_stream.call_count == 2
+
+
+# --- Async streaming tests ---
+
+
+async def _async_iter(items):
+    """Helper to create an async iterator from items."""
+    for item in items:
+        yield item
+
+
+@pytest.mark.anyio
+async def test_astream_yields_chunks_immediately():
+    """Test that async stream yields chunks as they arrive."""
+    client: Client = MagicMock(spec=Client)
+
+    chunks = [
+        _make_response_chunk("async1"),
+        _make_response_chunk("async2"),
+        _make_response_chunk("async3"),
+    ]
+
+    # generate_content_stream returns a coroutine that resolves to async iterator
+    client.aio.models.generate_content_stream = AsyncMock(
+        return_value=_async_iter(chunks)
+    )
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    received: list[str] = []
+    async for chunk in model.astream(messages):
+        received.append(chunk.content)
+
+    assert received == ["async1", "async2", "async3"]
+
+
+@pytest.mark.anyio
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+async def test_astream_no_retry_after_first_chunk(mock_wait):
+    """Test that errors after first chunk are NOT retried in async."""
+    client: Client = MagicMock(spec=Client)
+
+    async def failing_after_first():
+        yield _make_response_chunk("first")
+        raise Exception("Async mid-stream error")
+
+    client.aio.models.generate_content_stream = AsyncMock(
+        return_value=failing_after_first()
+    )
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    chunks = []
+    with pytest.raises(Exception, match="Async mid-stream error"):
+        async for chunk in model.astream(messages):
+            chunks.append(chunk.content)
+
+    assert chunks == ["first"]
+    assert client.aio.models.generate_content_stream.call_count == 1
+
+
+@pytest.mark.anyio
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+async def test_astream_retry_succeeds_after_failure(mock_wait):
+    """Test that async retry logic works for initial failures."""
+    client: Client = MagicMock(spec=Client)
+
+    call_count = 0
+
+    async def side_effect_fn(*args, **kwargs):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise Exception("Async transient error")
+        return _async_iter([_make_response_chunk("async_success")])
+
+    client.aio.models.generate_content_stream = AsyncMock(side_effect=side_effect_fn)
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+
+    chunks = []
+    async for chunk in model.astream(messages):
+        chunks.append(chunk.content)
+
+    assert chunks == ["async_success"]
+    assert client.aio.models.generate_content_stream.call_count == 2
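
The tests above patch `wait_exponential_jitter` so retries do not sleep, and they drive the model through the standard LangChain entry points (`invoke`, `stream`, `astream`) against a mocked `google.genai` Client. For orientation, end-user code reaches the same retrying stream path roughly as follows (the client construction and model name are placeholders based on the google-genai docs, not something this diff prescribes):

```python
from google import genai
from langchain_core.messages import HumanMessage

from langchain_b12.genai.genai import ChatGenAI

# Placeholder client setup; see the google-genai documentation for auth options.
client = genai.Client(vertexai=True, project="my-project", location="europe-west1")

model = ChatGenAI(
    client=client,
    model="gemini-2.0-flash",  # placeholder model name
    max_retries=3,  # retries apply only to starting the stream
)

for chunk in model.stream([HumanMessage(content="Hello!")]):
    print(chunk.content, end="", flush=True)
```
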
--- langchain_b12-0.1.8/uv.lock
+++ langchain_b12-0.1.10/uv.lock
@@ -252,10 +252,12 @@ wheels = [
 
 [[package]]
 name = "langchain-b12"
-version = "0.1.8"
+version = "0.1.9"
 source = { editable = "." }
 dependencies = [
     { name = "langchain-core" },
+    { name = "pytest-anyio" },
+    { name = "tenacity" },
 ]
 
 [package.dev-dependencies]
@@ -272,7 +274,11 @@ google = [
 ]
 
 [package.metadata]
-requires-dist = [
+requires-dist = [
+    { name = "langchain-core", specifier = ">=0.3.60" },
+    { name = "pytest-anyio", specifier = ">=0.0.0" },
+    { name = "tenacity", specifier = ">=9.1.2" },
+]
 
 [package.metadata.requires-dev]
 citations = [
@@ -617,6 +623,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 },
 ]
 
+[[package]]
+name = "pytest-anyio"
+version = "0.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/00/44/a02e5877a671b0940f21a7a0d9704c22097b123ed5cdbcca9cab39f17acc/pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3", size = 1560 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/c6/25/bd6493ae85d0a281b6a0f248d0fdb1d9aa2b31f18bcd4a8800cf397d8209/pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746", size = 1999 },
+]
+
 [[package]]
 name = "pytest-asyncio"
 version = "1.1.0"

--- langchain_b12-0.1.8/tests/test_genai.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from unittest.mock import MagicMock
-
-from google.genai import Client, types
-from langchain_b12.genai.genai import ChatGenAI
-from langchain_core.messages import HumanMessage
-
-
-def test_chatgenai():
-    client = MagicMock(spec=Client)
-    model = ChatGenAI(client=client, model="foo", temperature=1)
-    assert model.model_name == "foo"
-    assert model.temperature == 1
-    assert model.client == client
-
-
-def test_chatgenai_invocation():
-    client: Client = MagicMock(spec=Client)
-    client.models.generate_content_stream.return_value = iter(
-        (
-            types.GenerateContentResponse(
-                candidates=[
-                    types.Candidate(
-                        content=types.Content(parts=[types.Part(text="bar")])
-                    ),
-                ]
-            ),
-            types.GenerateContentResponse(
-                candidates=[
-                    types.Candidate(
-                        content=types.Content(parts=[types.Part(text="baz")])
-                    ),
-                ]
-            ),
-        )
-    )
-    model = ChatGenAI(client=client)
-    messages = [HumanMessage(content="foo")]
-    response = model.invoke(messages)
-    method: MagicMock = client.models.generate_content_stream
-    method.assert_called_once()
-    assert response.content == "barbaz"