langchain-b12 0.1.9__tar.gz → 0.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
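For context on the behavioral change carried by this release: retries in ChatGenAI now cover only stream initialization up to the first chunk, and mid-stream errors propagate without retry. A minimal usage sketch, mirroring the new test module further down in this diff and assuming a google.genai Client configured through the environment (the model id below is only an illustrative placeholder, not taken from the diff):

    from google.genai import Client
    from langchain_b12.genai.genai import ChatGenAI
    from langchain_core.messages import HumanMessage

    client = Client()  # assumes API credentials/project are configured in the environment
    model = ChatGenAI(client=client, model="gemini-2.0-flash", max_retries=3)

    # Retries apply only until the first chunk arrives; later chunks stream without retry.
    for chunk in model.stream([HumanMessage(content="Hello")]):
        print(chunk.content, end="", flush=True)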
@@ -1,10 +1,11 @@
  Metadata-Version: 2.4
  Name: langchain-b12
- Version: 0.1.9
+ Version: 0.1.10
  Summary: A reusable collection of tools and implementations for Langchain
  Author-email: Vincent Min <vincent.min@b12-consulting.com>
  Requires-Python: >=3.11
  Requires-Dist: langchain-core>=0.3.60
+ Requires-Dist: pytest-anyio>=0.0.0
  Requires-Dist: tenacity>=9.1.2
  Description-Content-Type: text/markdown
 
@@ -1,6 +1,6 @@
  [project]
  name = "langchain-b12"
- version = "0.1.9"
+ version = "0.1.10"
  description = "A reusable collection of tools and implementations for Langchain"
  readme = "README.md"
  authors = [
@@ -9,6 +9,7 @@ authors = [
  requires-python = ">=3.11"
  dependencies = [
      "langchain-core>=0.3.60",
+     "pytest-anyio>=0.0.0",
      "tenacity>=9.1.2",
  ]
 
@@ -84,7 +84,7 @@ class ChatGenAI(BaseChatModel):
      seed: int | None = None
      """Random seed for the generation."""
      max_retries: int | None = Field(default=3)
-     """Maximum number of retries when generation fails. None disables retries."""
+     """Maximum number of retries when generation fails. None retries indefinitely."""
      safety_settings: list[types.SafetySetting] | None = None
      """The default safety settings to use for all generations.
 
@@ -183,29 +183,10 @@ class ChatGenAI(BaseChatModel):
          run_manager: CallbackManagerForLLMRun | None = None,
          **kwargs: Any,
      ) -> ChatResult:
-         @retry(
-             reraise=True,
-             stop=stop_after_attempt(self.max_retries + 1)
-             if self.max_retries is not None
-             else stop_never,
-             wait=wait_exponential_jitter(initial=1, max=60),
-             retry=retry_if_exception_type(Exception),
-             before_sleep=lambda retry_state: logger.warning(
-                 "ChatGenAI._generate failed (attempt %d/%s). "
-                 "Retrying in %.2fs... Error: %s",
-                 retry_state.attempt_number,
-                 self.max_retries + 1 if self.max_retries is not None else "∞",
-                 retry_state.next_action.sleep,
-                 retry_state.outcome.exception(),
-             ),
+         stream_iter = self._stream(
+             messages, stop=stop, run_manager=run_manager, **kwargs
          )
-         def _generate_with_retry() -> ChatResult:
-             stream_iter = self._stream(
-                 messages, stop=stop, run_manager=run_manager, **kwargs
-             )
-             return generate_from_stream(stream_iter)
-
-         return _generate_with_retry()
+         return generate_from_stream(stream_iter)
 
      async def _agenerate(
          self,
@@ -214,6 +195,20 @@ class ChatGenAI(BaseChatModel):
          run_manager: AsyncCallbackManagerForLLMRun | None = None,
          **kwargs: Any,
      ) -> ChatResult:
+         stream_iter = self._astream(
+             messages, stop=stop, run_manager=run_manager, **kwargs
+         )
+         return await agenerate_from_stream(stream_iter)
+
+     def _stream(
+         self,
+         messages: list[BaseMessage],
+         stop: list[str] | None = None,
+         run_manager: CallbackManagerForLLMRun | None = None,
+         **kwargs: Any,
+     ) -> Iterator[ChatGenerationChunk]:
+         system_message, contents = self._prepare_request(messages=messages)
+
          @retry(
              reraise=True,
              stop=stop_after_attempt(self.max_retries + 1)
@@ -222,7 +217,7 @@ class ChatGenAI(BaseChatModel):
              wait=wait_exponential_jitter(initial=1, max=60),
              retry=retry_if_exception_type(Exception),
              before_sleep=lambda retry_state: logger.warning(
-                 "ChatGenAI._agenerate failed (attempt %d/%s). "
+                 "ChatGenAI._stream failed to start (attempt %d/%s). "
                  "Retrying in %.2fs... Error: %s",
                  retry_state.attempt_number,
                  self.max_retries + 1 if self.max_retries is not None else "∞",
@@ -230,42 +225,47 @@ class ChatGenAI(BaseChatModel):
                  retry_state.outcome.exception(),
              ),
          )
-         async def _agenerate_with_retry() -> ChatResult:
-             stream_iter = self._astream(
-                 messages, stop=stop, run_manager=run_manager, **kwargs
+         def _initiate_stream() -> tuple[
+             ChatGenerationChunk,
+             Iterator[types.GenerateContentResponse],
+             UsageMetadata | None,
+         ]:
+             """Initialize stream and fetch first chunk. Retries only apply here."""
+             response_iter = self.client.models.generate_content_stream(
+                 model=self.model_name,
+                 contents=contents,
+                 config=types.GenerateContentConfig(
+                     system_instruction=system_message,
+                     temperature=self.temperature,
+                     top_k=self.top_k,
+                     top_p=self.top_p,
+                     max_output_tokens=self.max_output_tokens,
+                     candidate_count=self.n,
+                     stop_sequences=stop or self.stop,
+                     safety_settings=self.safety_settings,
+                     thinking_config=self.thinking_config,
+                     automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                         disable=True,
+                     ),
+                     **kwargs,
+                 ),
              )
-             return await agenerate_from_stream(stream_iter)
+             # Fetch first chunk to ensure connection is established
+             first_response = next(iter(response_iter))
+             first_chunk, total_usage = self._gemini_chunk_to_generation_chunk(
+                 first_response, prev_total_usage=None
+             )
+             return first_chunk, response_iter, total_usage
 
-         return await _agenerate_with_retry()
+         # Retry only covers stream initialization and first chunk
+         first_chunk, response_iter, total_lc_usage = _initiate_stream()
 
-     def _stream(
-         self,
-         messages: list[BaseMessage],
-         stop: list[str] | None = None,
-         run_manager: CallbackManagerForLLMRun | None = None,
-         **kwargs: Any,
-     ) -> Iterator[ChatGenerationChunk]:
-         system_message, contents = self._prepare_request(messages=messages)
-         response_iter = self.client.models.generate_content_stream(
-             model=self.model_name,
-             contents=contents,
-             config=types.GenerateContentConfig(
-                 system_instruction=system_message,
-                 temperature=self.temperature,
-                 top_k=self.top_k,
-                 top_p=self.top_p,
-                 max_output_tokens=self.max_output_tokens,
-                 candidate_count=self.n,
-                 stop_sequences=stop or self.stop,
-                 safety_settings=self.safety_settings,
-                 thinking_config=self.thinking_config,
-                 automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                     disable=True,
-                 ),
-                 **kwargs,
-             ),
-         )
-         total_lc_usage = None
+         # Yield first chunk
+         if run_manager and isinstance(first_chunk.message.content, str):
+             run_manager.on_llm_new_token(first_chunk.message.content)
+         yield first_chunk
+
+         # Continue streaming without retry (retries during streaming are not well defined)
          for response_chunk in response_iter:
              chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
                  response_chunk, prev_total_usage=total_lc_usage
@@ -282,27 +282,65 @@ class ChatGenAI(BaseChatModel):
          **kwargs: Any,
      ) -> AsyncIterator[ChatGenerationChunk]:
          system_message, contents = self._prepare_request(messages=messages)
-         response_iter = self.client.aio.models.generate_content_stream(
-             model=self.model_name,
-             contents=contents,
-             config=types.GenerateContentConfig(
-                 system_instruction=system_message,
-                 temperature=self.temperature,
-                 top_k=self.top_k,
-                 top_p=self.top_p,
-                 max_output_tokens=self.max_output_tokens,
-                 candidate_count=self.n,
-                 stop_sequences=stop or self.stop,
-                 safety_settings=self.safety_settings,
-                 thinking_config=self.thinking_config,
-                 automatic_function_calling=types.AutomaticFunctionCallingConfig(
-                     disable=True,
-                 ),
-                 **kwargs,
+
+         @retry(
+             reraise=True,
+             stop=stop_after_attempt(self.max_retries + 1)
+             if self.max_retries is not None
+             else stop_never,
+             wait=wait_exponential_jitter(initial=1, max=60),
+             retry=retry_if_exception_type(Exception),
+             before_sleep=lambda retry_state: logger.warning(
+                 "ChatGenAI._astream failed to start (attempt %d/%s). "
+                 "Retrying in %.2fs... Error: %s",
+                 retry_state.attempt_number,
+                 self.max_retries + 1 if self.max_retries is not None else "∞",
+                 retry_state.next_action.sleep,
+                 retry_state.outcome.exception(),
              ),
          )
-         total_lc_usage = None
-         async for response_chunk in await response_iter:
+         async def _initiate_stream() -> tuple[
+             ChatGenerationChunk,
+             AsyncIterator[types.GenerateContentResponse],
+             UsageMetadata | None,
+         ]:
+             """Initialize stream and fetch first chunk. Retries only apply here."""
+             response_iter = await self.client.aio.models.generate_content_stream(
+                 model=self.model_name,
+                 contents=contents,
+                 config=types.GenerateContentConfig(
+                     system_instruction=system_message,
+                     temperature=self.temperature,
+                     top_k=self.top_k,
+                     top_p=self.top_p,
+                     max_output_tokens=self.max_output_tokens,
+                     candidate_count=self.n,
+                     stop_sequences=stop or self.stop,
+                     safety_settings=self.safety_settings,
+                     thinking_config=self.thinking_config,
+                     automatic_function_calling=types.AutomaticFunctionCallingConfig(
+                         disable=True,
+                     ),
+                     **kwargs,
+                 ),
+             )
+             # Fetch first chunk to ensure connection is established
+             first_response = await response_iter.__anext__()
+             first_chunk, total_usage = self._gemini_chunk_to_generation_chunk(
+                 first_response, prev_total_usage=None
+             )
+             return first_chunk, response_iter, total_usage
+
+         # Retry only covers stream initialization and first chunk
+         first_chunk, response_iter, total_lc_usage = await _initiate_stream()
+
+         # Yield first chunk
+         if run_manager and isinstance(first_chunk.message.content, str):
+             await run_manager.on_llm_new_token(first_chunk.message.content)
+         yield first_chunk
+
+         # Continue streaming without retry (retries during streaming are not well defined)
+         async for response_chunk in response_iter:
              chunk, total_lc_usage = self._gemini_chunk_to_generation_chunk(
                  response_chunk, prev_total_usage=total_lc_usage
              )
@@ -0,0 +1,279 @@
+ from unittest.mock import AsyncMock, MagicMock, patch
+
+ import pytest
+ from google.genai import Client, types
+ from langchain_b12.genai.genai import ChatGenAI
+ from langchain_core.messages import HumanMessage
+
+
+ def _make_response_chunk(text: str) -> types.GenerateContentResponse:
+     """Helper to create a response chunk."""
+     return types.GenerateContentResponse(
+         candidates=[
+             types.Candidate(content=types.Content(parts=[types.Part(text=text)]))
+         ]
+     )
+
+
+ def test_chatgenai():
+     client = MagicMock(spec=Client)
+     model = ChatGenAI(client=client, model="foo", temperature=1)
+     assert model.model_name == "foo"
+     assert model.temperature == 1
+     assert model.client == client
+
+
+ def test_chatgenai_invocation():
+     client: Client = MagicMock(spec=Client)
+     client.models.generate_content_stream.return_value = iter(
+         (
+             _make_response_chunk("bar"),
+             _make_response_chunk("baz"),
+         )
+     )
+     model = ChatGenAI(client=client)
+     messages = [HumanMessage(content="foo")]
+     response = model.invoke(messages)
+     method: MagicMock = client.models.generate_content_stream
+     method.assert_called_once()
+     assert response.content == "barbaz"
+
+
+ def _make_success_iter():
+     """Helper to create a successful streaming iterator."""
+     return iter([_make_response_chunk("success")])
+
+
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ def test_chatgenai_retry_succeeds_after_failure(mock_wait):
+     """Test that retry logic succeeds after transient failures."""
+     client: Client = MagicMock(spec=Client)
+
+     # First two calls fail, third succeeds
+     client.models.generate_content_stream.side_effect = [
+         Exception("Transient error 1"),
+         Exception("Transient error 2"),
+         _make_success_iter(),
+     ]
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+     response = model.invoke(messages)
+
+     assert response.content == "success"
+     assert client.models.generate_content_stream.call_count == 3
+
+
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ def test_chatgenai_retry_exhausted_raises(mock_wait):
+     """Test that exception is raised after all retries are exhausted."""
+     client: Client = MagicMock(spec=Client)
+
+     # All calls fail
+     client.models.generate_content_stream.side_effect = Exception("Persistent error")
+
+     model = ChatGenAI(client=client, max_retries=2)
+     messages = [HumanMessage(content="foo")]
+
+     with pytest.raises(Exception, match="Persistent error"):
+         model.invoke(messages)
+
+     # Initial attempt + 2 retries = 3 total calls
+     assert client.models.generate_content_stream.call_count == 3
+
+
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ def test_chatgenai_no_retry_when_max_retries_zero(mock_wait):
+     """Test that no retries occur when max_retries=0."""
+     client: Client = MagicMock(spec=Client)
+     client.models.generate_content_stream.side_effect = Exception("Error")
+
+     model = ChatGenAI(client=client, max_retries=0)
+     messages = [HumanMessage(content="foo")]
+
+     with pytest.raises(Exception, match="Error"):
+         model.invoke(messages)
+
+     # Only 1 attempt, no retries
+     assert client.models.generate_content_stream.call_count == 1
+
+
+ def test_chatgenai_no_retry_on_success():
+     """Test that no retries occur when first attempt succeeds."""
+     client: Client = MagicMock(spec=Client)
+     client.models.generate_content_stream.return_value = _make_success_iter()
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+     response = model.invoke(messages)
+
+     assert response.content == "success"
+     assert client.models.generate_content_stream.call_count == 1
+
+
+ # --- Streaming behavior tests ---
+
+
+ def test_stream_yields_chunks_immediately():
+     """Test that stream yields chunks as they arrive, not buffered."""
+     client: Client = MagicMock(spec=Client)
+     chunks_yielded: list[str] = []
+
+     def mock_stream():
+         for text in ["chunk1", "chunk2", "chunk3"]:
+             # Track when chunks are yielded from the source
+             chunks_yielded.append(f"source:{text}")
+             yield _make_response_chunk(text)
+
+     client.models.generate_content_stream.return_value = mock_stream()
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+
+     received: list[str] = []
+     for chunk in model.stream(messages):
+         received.append(chunk.content)
+         # After receiving each chunk, check that source yielded it
+         assert len(received) == len([c for c in chunks_yielded if c.startswith("source:")])
+
+     assert received == ["chunk1", "chunk2", "chunk3"]
+
+
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ def test_stream_no_retry_after_first_chunk(mock_wait):
+     """Test that errors after first chunk are NOT retried."""
+     client: Client = MagicMock(spec=Client)
+
+     def failing_after_first():
+         yield _make_response_chunk("first")
+         raise Exception("Mid-stream error")
+
+     client.models.generate_content_stream.return_value = failing_after_first()
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+
+     chunks = []
+     with pytest.raises(Exception, match="Mid-stream error"):
+         for chunk in model.stream(messages):
+             chunks.append(chunk.content)
+
+     # First chunk was received
+     assert chunks == ["first"]
+     # Only one call - no retry after first chunk
+     assert client.models.generate_content_stream.call_count == 1
+
+
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ def test_stream_retry_on_first_chunk_failure(mock_wait):
+     """Test that failure on first chunk triggers retry."""
+     client: Client = MagicMock(spec=Client)
+
+     def fail_on_first_next():
+         raise Exception("First chunk error")
+         yield  # Make it a generator
+
+     def success_stream():
+         yield _make_response_chunk("success1")
+         yield _make_response_chunk("success2")
+
+     client.models.generate_content_stream.side_effect = [
+         fail_on_first_next(),
+         success_stream(),
+     ]
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+
+     chunks = [chunk.content for chunk in model.stream(messages)]
+     assert chunks == ["success1", "success2"]
+     assert client.models.generate_content_stream.call_count == 2
+
+
+ # --- Async streaming tests ---
+
+
+ async def _async_iter(items):
+     """Helper to create an async iterator from items."""
+     for item in items:
+         yield item
+
+
+ @pytest.mark.anyio
+ async def test_astream_yields_chunks_immediately():
+     """Test that async stream yields chunks as they arrive."""
+     client: Client = MagicMock(spec=Client)
+
+     chunks = [
+         _make_response_chunk("async1"),
+         _make_response_chunk("async2"),
+         _make_response_chunk("async3"),
+     ]
+
+     # generate_content_stream returns a coroutine that resolves to async iterator
+     client.aio.models.generate_content_stream = AsyncMock(
+         return_value=_async_iter(chunks)
+     )
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+
+     received: list[str] = []
+     async for chunk in model.astream(messages):
+         received.append(chunk.content)
+
+     assert received == ["async1", "async2", "async3"]
+
+
+ @pytest.mark.anyio
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ async def test_astream_no_retry_after_first_chunk(mock_wait):
+     """Test that errors after first chunk are NOT retried in async."""
+     client: Client = MagicMock(spec=Client)
+
+     async def failing_after_first():
+         yield _make_response_chunk("first")
+         raise Exception("Async mid-stream error")
+
+     client.aio.models.generate_content_stream = AsyncMock(
+         return_value=failing_after_first()
+     )
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+
+     chunks = []
+     with pytest.raises(Exception, match="Async mid-stream error"):
+         async for chunk in model.astream(messages):
+             chunks.append(chunk.content)
+
+     assert chunks == ["first"]
+     assert client.aio.models.generate_content_stream.call_count == 1
+
+
+ @pytest.mark.anyio
+ @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
+ async def test_astream_retry_succeeds_after_failure(mock_wait):
+     """Test that async retry logic works for initial failures."""
+     client: Client = MagicMock(spec=Client)
+
+     call_count = 0
+
+     async def side_effect_fn(*args, **kwargs):
+         nonlocal call_count
+         call_count += 1
+         if call_count == 1:
+             raise Exception("Async transient error")
+         return _async_iter([_make_response_chunk("async_success")])
+
+     client.aio.models.generate_content_stream = AsyncMock(side_effect=side_effect_fn)
+
+     model = ChatGenAI(client=client, max_retries=3)
+     messages = [HumanMessage(content="foo")]
+
+     chunks = []
+     async for chunk in model.astream(messages):
+         chunks.append(chunk.content)
+
+     assert chunks == ["async_success"]
+     assert client.aio.models.generate_content_stream.call_count == 2
@@ -252,10 +252,11 @@ wheels = [
 
  [[package]]
  name = "langchain-b12"
- version = "0.1.8"
+ version = "0.1.9"
  source = { editable = "." }
  dependencies = [
      { name = "langchain-core" },
+     { name = "pytest-anyio" },
      { name = "tenacity" },
  ]
 
@@ -275,6 +276,7 @@ google = [
  [package.metadata]
  requires-dist = [
      { name = "langchain-core", specifier = ">=0.3.60" },
+     { name = "pytest-anyio", specifier = ">=0.0.0" },
      { name = "tenacity", specifier = ">=9.1.2" },
  ]
 
@@ -621,6 +623,19 @@ wheels = [
      { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 },
  ]
 
+ [[package]]
+ name = "pytest-anyio"
+ version = "0.0.0"
+ source = { registry = "https://pypi.org/simple" }
+ dependencies = [
+     { name = "anyio" },
+     { name = "pytest" },
+ ]
+ sdist = { url = "https://files.pythonhosted.org/packages/00/44/a02e5877a671b0940f21a7a0d9704c22097b123ed5cdbcca9cab39f17acc/pytest-anyio-0.0.0.tar.gz", hash = "sha256:b41234e9e9ad7ea1dbfefcc1d6891b23d5ef7c9f07ccf804c13a9cc338571fd3", size = 1560 }
+ wheels = [
+     { url = "https://files.pythonhosted.org/packages/c6/25/bd6493ae85d0a281b6a0f248d0fdb1d9aa2b31f18bcd4a8800cf397d8209/pytest_anyio-0.0.0-py2.py3-none-any.whl", hash = "sha256:dc8b5c4741cb16ff90be37fddd585ca943ed12bbeb563de7ace6cd94441d8746", size = 1999 },
+ ]
+
  [[package]]
  name = "pytest-asyncio"
  version = "1.1.0"
@@ -1,124 +0,0 @@
- from unittest.mock import MagicMock, patch
-
- import pytest
- from google.genai import Client, types
- from langchain_b12.genai.genai import ChatGenAI
- from langchain_core.messages import HumanMessage
-
-
- def test_chatgenai():
-     client = MagicMock(spec=Client)
-     model = ChatGenAI(client=client, model="foo", temperature=1)
-     assert model.model_name == "foo"
-     assert model.temperature == 1
-     assert model.client == client
-
-
- def test_chatgenai_invocation():
-     client: Client = MagicMock(spec=Client)
-     client.models.generate_content_stream.return_value = iter(
-         (
-             types.GenerateContentResponse(
-                 candidates=[
-                     types.Candidate(
-                         content=types.Content(parts=[types.Part(text="bar")])
-                     ),
-                 ]
-             ),
-             types.GenerateContentResponse(
-                 candidates=[
-                     types.Candidate(
-                         content=types.Content(parts=[types.Part(text="baz")])
-                     ),
-                 ]
-             ),
-         )
-     )
-     model = ChatGenAI(client=client)
-     messages = [HumanMessage(content="foo")]
-     response = model.invoke(messages)
-     method: MagicMock = client.models.generate_content_stream
-     method.assert_called_once()
-     assert response.content == "barbaz"
-
-
- def _make_success_response():
-     """Helper to create a successful streaming response."""
-     return iter(
-         [
-             types.GenerateContentResponse(
-                 candidates=[
-                     types.Candidate(
-                         content=types.Content(parts=[types.Part(text="success")])
-                     ),
-                 ]
-             ),
-         ]
-     )
-
-
- @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
- def test_chatgenai_retry_succeeds_after_failure(mock_wait):
-     """Test that retry logic succeeds after transient failures."""
-     client: Client = MagicMock(spec=Client)
-
-     # First two calls fail, third succeeds
-     client.models.generate_content_stream.side_effect = [
-         Exception("Transient error 1"),
-         Exception("Transient error 2"),
-         _make_success_response(),
-     ]
-
-     model = ChatGenAI(client=client, max_retries=3)
-     messages = [HumanMessage(content="foo")]
-     response = model.invoke(messages)
-
-     assert response.content == "success"
-     assert client.models.generate_content_stream.call_count == 3
-
-
- @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
- def test_chatgenai_retry_exhausted_raises(mock_wait):
-     """Test that exception is raised after all retries are exhausted."""
-     client: Client = MagicMock(spec=Client)
-
-     # All calls fail
-     client.models.generate_content_stream.side_effect = Exception("Persistent error")
-
-     model = ChatGenAI(client=client, max_retries=2)
-     messages = [HumanMessage(content="foo")]
-
-     with pytest.raises(Exception, match="Persistent error"):
-         model.invoke(messages)
-
-     # Initial attempt + 2 retries = 3 total calls
-     assert client.models.generate_content_stream.call_count == 3
-
-
- @patch("langchain_b12.genai.genai.wait_exponential_jitter", return_value=lambda _: 0)
- def test_chatgenai_no_retry_when_max_retries_zero(mock_wait):
-     """Test that no retries occur when max_retries=0."""
-     client: Client = MagicMock(spec=Client)
-     client.models.generate_content_stream.side_effect = Exception("Error")
-
-     model = ChatGenAI(client=client, max_retries=0)
-     messages = [HumanMessage(content="foo")]
-
-     with pytest.raises(Exception, match="Error"):
-         model.invoke(messages)
-
-     # Only 1 attempt, no retries
-     assert client.models.generate_content_stream.call_count == 1
-
-
- def test_chatgenai_no_retry_on_success():
-     """Test that no retries occur when first attempt succeeds."""
-     client: Client = MagicMock(spec=Client)
-     client.models.generate_content_stream.return_value = _make_success_response()
-
-     model = ChatGenAI(client=client, max_retries=3)
-     messages = [HumanMessage(content="foo")]
-     response = model.invoke(messages)
-
-     assert response.content == "success"
-     assert client.models.generate_content_stream.call_count == 1
File without changes
File without changes
File without changes