langchain-b12 0.1.7__tar.gz → 0.1.9__tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/PKG-INFO +2 -1
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/pyproject.toml +2 -1
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/src/langchain_b12/genai/genai.py +58 -10
- langchain_b12-0.1.9/tests/test_genai.py +124 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/uv.lock +248 -260
- langchain_b12-0.1.7/tests/test_genai.py +0 -41
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/.gitignore +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/.python-version +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/.vscode/extensions.json +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/Makefile +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/README.md +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/src/langchain_b12/__init__.py +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/src/langchain_b12/citations/citations.py +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/src/langchain_b12/genai/embeddings.py +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/src/langchain_b12/genai/genai_utils.py +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/src/langchain_b12/py.typed +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/tests/test_citation_mixin.py +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/tests/test_citations.py +0 -0
- {langchain_b12-0.1.7 → langchain_b12-0.1.9}/tests/test_genai_utils.py +0 -0
--- langchain_b12-0.1.7/PKG-INFO
+++ langchain_b12-0.1.9/PKG-INFO
@@ -1,10 +1,11 @@
 Metadata-Version: 2.4
 Name: langchain-b12
-Version: 0.1.7
+Version: 0.1.9
 Summary: A reusable collection of tools and implementations for Langchain
 Author-email: Vincent Min <vincent.min@b12-consulting.com>
 Requires-Python: >=3.11
 Requires-Dist: langchain-core>=0.3.60
+Requires-Dist: tenacity>=9.1.2
 Description-Content-Type: text/markdown
 
 # Langchain B12
--- langchain_b12-0.1.7/pyproject.toml
+++ langchain_b12-0.1.9/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "langchain-b12"
-version = "0.1.7"
+version = "0.1.9"
 description = "A reusable collection of tools and implementations for Langchain"
 readme = "README.md"
 authors = [
@@ -9,6 +9,7 @@ authors = [
 requires-python = ">=3.11"
 dependencies = [
     "langchain-core>=0.3.60",
+    "tenacity>=9.1.2",
 ]
 
 [dependency-groups]
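The only new runtime dependency in 0.1.9 is `tenacity`, which supplies the retry primitives used in `genai.py` below. As a quick orientation, here is a minimal, self-contained sketch of the same policy combination; the `flaky` function is invented for illustration:

```python
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

attempts = {"n": 0}  # mutable counter so the decorated function can track calls

@retry(
    reraise=True,                                     # re-raise the last error once attempts run out
    stop=stop_after_attempt(3),                       # 3 attempts total: the initial call plus 2 retries
    wait=wait_exponential_jitter(initial=1, max=60),  # exponential backoff with jitter, capped at 60s
    retry=retry_if_exception_type(Exception),         # retry on any exception
)
def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError(f"transient failure {attempts['n']}")
    return "ok"

print(flaky())  # prints "ok" on the third attempt, after two backoff sleeps
```

`ChatGenAI` combines exactly these four policies inside `_generate` and `_agenerate`, as the hunks below show.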
--- langchain_b12-0.1.7/src/langchain_b12/genai/genai.py
+++ langchain_b12-0.1.9/src/langchain_b12/genai/genai.py
@@ -7,10 +7,6 @@ from typing import Any, Literal, cast
 from google import genai
 from google.genai import types
 from google.oauth2 import service_account
-from langchain_b12.genai.genai_utils import (
-    convert_messages_to_contents,
-    parse_response_candidate,
-)
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -40,6 +36,18 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_tool,
 )
 from pydantic import BaseModel, ConfigDict, Field
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    stop_never,
+    wait_exponential_jitter,
+)
+
+from langchain_b12.genai.genai_utils import (
+    convert_messages_to_contents,
+    parse_response_candidate,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -75,6 +83,8 @@ class ChatGenAI(BaseChatModel):
     """How many completions to generate for each prompt."""
     seed: int | None = None
     """Random seed for the generation."""
+    max_retries: int | None = Field(default=3)
+    """Maximum number of retries when generation fails. None means retry indefinitely."""
     safety_settings: list[types.SafetySetting] | None = None
     """The default safety settings to use for all generations.
 
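The new `max_retries` field is the user-facing knob for this behaviour. Judging from the field definition above, the retry wiring in the next hunks, and the constructor calls in the new tests, configuration looks like this (the `MagicMock` client is a stand-in for a real `google.genai.Client`):

```python
from unittest.mock import MagicMock

from google.genai import Client
from langchain_b12.genai.genai import ChatGenAI

client = MagicMock(spec=Client)  # stand-in; a real Client is passed the same way

default_model = ChatGenAI(client=client)                 # up to 3 retries (the default)
fail_fast = ChatGenAI(client=client, max_retries=0)      # single attempt, no retries
persistent = ChatGenAI(client=client, max_retries=None)  # retry indefinitely (stop_never)
```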
@@ -173,10 +183,29 @@ class ChatGenAI(BaseChatModel):
         run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
-        stream_iter = self._stream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries + 1)
+            if self.max_retries is not None
+            else stop_never,
+            wait=wait_exponential_jitter(initial=1, max=60),
+            retry=retry_if_exception_type(Exception),
+            before_sleep=lambda retry_state: logger.warning(
+                "ChatGenAI._generate failed (attempt %d/%s). "
+                "Retrying in %.2fs... Error: %s",
+                retry_state.attempt_number,
+                self.max_retries + 1 if self.max_retries is not None else "∞",
+                retry_state.next_action.sleep,
+                retry_state.outcome.exception(),
+            ),
         )
-        return generate_from_stream(stream_iter)
+        def _generate_with_retry() -> ChatResult:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
+        return _generate_with_retry()
 
     async def _agenerate(
         self,
@@ -185,10 +214,29 @@ class ChatGenAI(BaseChatModel):
         run_manager: AsyncCallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
-        stream_iter = self._astream(
-            messages, stop=stop, run_manager=run_manager, **kwargs
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries + 1)
+            if self.max_retries is not None
+            else stop_never,
+            wait=wait_exponential_jitter(initial=1, max=60),
+            retry=retry_if_exception_type(Exception),
+            before_sleep=lambda retry_state: logger.warning(
+                "ChatGenAI._agenerate failed (attempt %d/%s). "
+                "Retrying in %.2fs... Error: %s",
+                retry_state.attempt_number,
+                self.max_retries + 1 if self.max_retries is not None else "∞",
+                retry_state.next_action.sleep,
+                retry_state.outcome.exception(),
+            ),
         )
-        return await agenerate_from_stream(stream_iter)
+        async def _agenerate_with_retry() -> ChatResult:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        return await _agenerate_with_retry()
 
     def _stream(
         self,
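Both `_generate` and `_agenerate` construct the tenacity decorator inside the method body rather than at class-definition time, which lets the per-instance `max_retries` drive the stop condition: `stop_after_attempt` counts total attempts, hence the `+ 1`, and `None` maps to `stop_never`. Note also that the retry wraps the full stream consumption (`generate_from_stream`), so a failure mid-stream restarts the generation from scratch. A stripped-down sketch of the same pattern; the class and method names here are illustrative, not from the package:

```python
from tenacity import retry, stop_after_attempt, stop_never, wait_exponential_jitter

class RetryingCaller:
    """Illustrative stand-in for ChatGenAI's call-time retry construction."""

    def __init__(self, max_retries: int | None = 3):
        self.max_retries = max_retries

    def call(self) -> str:
        @retry(
            reraise=True,
            # max_retries counts retries, so total attempts = max_retries + 1;
            # None means no stop condition, i.e. keep retrying until success.
            stop=stop_after_attempt(self.max_retries + 1)
            if self.max_retries is not None
            else stop_never,
            wait=wait_exponential_jitter(initial=1, max=60),
        )
        def _attempt() -> str:
            return self._backend()  # hypothetical flaky operation

        return _attempt()

    def _backend(self) -> str:
        raise RuntimeError("transient failure")  # always fails in this sketch

# RetryingCaller(max_retries=2).call() makes 3 attempts, then re-raises RuntimeError.
```

Rebuilding the decorator on every call costs a little overhead, but it is what allows the retry policy to vary per instance instead of being fixed at import time.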
--- /dev/null
+++ langchain_b12-0.1.9/tests/test_genai.py
@@ -0,0 +1,124 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from google.genai import Client, types
+from langchain_b12.genai.genai import ChatGenAI
+from langchain_core.messages import HumanMessage
+
+
+def test_chatgenai():
+    client = MagicMock(spec=Client)
+    model = ChatGenAI(client=client, model="foo", temperature=1)
+    assert model.model_name == "foo"
+    assert model.temperature == 1
+    assert model.client == client
+
+
+def test_chatgenai_invocation():
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.return_value = iter(
+        (
+            types.GenerateContentResponse(
+                candidates=[
+                    types.Candidate(
+                        content=types.Content(parts=[types.Part(text="bar")])
+                    ),
+                ]
+            ),
+            types.GenerateContentResponse(
+                candidates=[
+                    types.Candidate(
+                        content=types.Content(parts=[types.Part(text="baz")])
+                    ),
+                ]
+            ),
+        )
+    )
+    model = ChatGenAI(client=client)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+    method: MagicMock = client.models.generate_content_stream
+    method.assert_called_once()
+    assert response.content == "barbaz"
+
+
+def _make_success_response():
+    """Helper to create a successful streaming response."""
+    return iter(
+        [
+            types.GenerateContentResponse(
+                candidates=[
+                    types.Candidate(
+                        content=types.Content(parts=[types.Part(text="success")])
+                    ),
+                ]
+            ),
+        ]
+    )
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_chatgenai_retry_succeeds_after_failure(mock_wait):
+    """Test that retry logic succeeds after transient failures."""
+    client: Client = MagicMock(spec=Client)
+
+    # First two calls fail, third succeeds
+    client.models.generate_content_stream.side_effect = [
+        Exception("Transient error 1"),
+        Exception("Transient error 2"),
+        _make_success_response(),
+    ]
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+
+    assert response.content == "success"
+    assert client.models.generate_content_stream.call_count == 3
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_chatgenai_retry_exhausted_raises(mock_wait):
+    """Test that exception is raised after all retries are exhausted."""
+    client: Client = MagicMock(spec=Client)
+
+    # All calls fail
+    client.models.generate_content_stream.side_effect = Exception("Persistent error")
+
+    model = ChatGenAI(client=client, max_retries=2)
+    messages = [HumanMessage(content="foo")]
+
+    with pytest.raises(Exception, match="Persistent error"):
+        model.invoke(messages)
+
+    # Initial attempt + 2 retries = 3 total calls
+    assert client.models.generate_content_stream.call_count == 3
+
+
+@patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+def test_chatgenai_no_retry_when_max_retries_zero(mock_wait):
+    """Test that no retries occur when max_retries=0."""
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.side_effect = Exception("Error")
+
+    model = ChatGenAI(client=client, max_retries=0)
+    messages = [HumanMessage(content="foo")]
+
+    with pytest.raises(Exception, match="Error"):
+        model.invoke(messages)
+
+    # Only 1 attempt, no retries
+    assert client.models.generate_content_stream.call_count == 1
+
+
+def test_chatgenai_no_retry_on_success():
+    """Test that no retries occur when first attempt succeeds."""
+    client: Client = MagicMock(spec=Client)
+    client.models.generate_content_stream.return_value = _make_success_response()
+
+    model = ChatGenAI(client=client, max_retries=3)
+    messages = [HumanMessage(content="foo")]
+    response = model.invoke(messages)
+
+    assert response.content == "success"
+    assert client.models.generate_content_stream.call_count == 1