langchain-b12 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-b12
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: A reusable collection of tools and implementations for Langchain
5
5
  Author-email: Vincent Min <vincent.min@b12-consulting.com>
6
6
  Requires-Python: >=3.11
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "langchain-b12"
3
- version = "0.1.3"
3
+ version = "0.1.4"
4
4
  description = "A reusable collection of tools and implementations for Langchain"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -1,10 +1,13 @@
1
1
  import re
2
2
  from collections.abc import Sequence
3
- from typing import Literal, TypedDict
3
+ from typing import Any, Literal, TypedDict
4
+ from uuid import UUID
4
5
 
5
6
  from fuzzysearch import find_near_matches
7
+ from langchain_core.callbacks import Callbacks
6
8
  from langchain_core.language_models import BaseChatModel
7
9
  from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
10
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
8
11
  from langchain_core.runnables import Runnable
9
12
  from langgraph.utils.runnable import RunnableCallable
10
13
  from pydantic import BaseModel, Field
@@ -163,7 +166,11 @@ def validate_citations(
163
166
  if citation.sentence_index < 0 or citation.sentence_index >= n_sentences:
164
167
  # discard citations that refer to non-existing sentences
165
168
  continue
166
- matches = find_near_matches(citation.cited_text, all_text, max_l_dist=5)
169
+ # Allow for 10% error distance
170
+ max_l_dist = max(1, len(citation.cited_text) // 10)
171
+ matches = find_near_matches(
172
+ citation.cited_text, all_text, max_l_dist=max_l_dist
173
+ )
167
174
  if not matches:
168
175
  citations_with_matches.append((citation, None))
169
176
  else:
@@ -187,6 +194,7 @@ async def add_citations(
187
194
  messages: Sequence[BaseMessage],
188
195
  message: AIMessage,
189
196
  system_prompt: str,
197
+ **kwargs: Any,
190
198
  ) -> AIMessage:
191
199
  """Add citations to the message."""
192
200
  if not message.content:
@@ -214,7 +222,9 @@ async def add_citations(
214
222
  system_message = SystemMessage(system_prompt)
215
223
  _messages = [system_message, *messages, numbered_message]
216
224
 
217
- citations = await model.with_structured_output(Citations).ainvoke(_messages)
225
+ citations = await model.with_structured_output(Citations).ainvoke(
226
+ _messages, **kwargs
227
+ )
218
228
  assert isinstance(
219
229
  citations, Citations
220
230
  ), f"Expected Citations from model invocation but got {type(citations)}"
@@ -272,3 +282,70 @@ def create_citation_model(
272
282
  func=None, # TODO: Implement a sync version if needed
273
283
  afunc=ainvoke_with_citations,
274
284
  )
285
+
286
+
287
+ class CitationMixin(BaseChatModel):
288
+ """Mixin class to add citation functionality to a runnable.
289
+
290
+ Example usage:
291
+ ```
292
+ from langchain_b12.genai.genai import ChatGenAI
293
+ from langchain_b12.citations.citations import CitationMixin
294
+
295
+ class CitationModel(ChatGenAI, CitationMixin):
296
+ pass
297
+ ```
298
+ """
299
+
300
+ async def agenerate(
301
+ self,
302
+ messages: list[list[BaseMessage]],
303
+ stop: list[str] | None = None,
304
+ callbacks: Callbacks = None,
305
+ *,
306
+ tags: list[str] | None = None,
307
+ metadata: dict[str, Any] | None = None,
308
+ run_name: str | None = None,
309
+ run_id: UUID | None = None,
310
+ **kwargs: Any,
311
+ ) -> LLMResult:
312
+ # Check if we should generate citations and remove it from kwargs
313
+ generate_citations = kwargs.pop("generate_citations", True)
314
+
315
+ llm_result = await super().agenerate(
316
+ messages,
317
+ stop,
318
+ callbacks,
319
+ tags=tags,
320
+ metadata=metadata,
321
+ run_name=run_name,
322
+ run_id=run_id,
323
+ **kwargs,
324
+ )
325
+
326
+ # Prevent recursion when extracting citations
327
+ if not generate_citations:
328
+ # Below we call `add_citations` which will call `agenerate` again
329
+ # This will lead to an infinite loop if we don't stop here.
330
+ # We explicitly pass `generate_citations=False` below to stop this recursion.
331
+ return llm_result
332
+
333
+ # overwrite each generation with a version that has citations added
334
+ for _messages, generations in zip(messages, llm_result.generations):
335
+ for generation in generations:
336
+ assert isinstance(generation, ChatGeneration) and not isinstance(
337
+ generation, ChatGenerationChunk
338
+ ), f"Expected ChatGeneration; received {type(generation)}"
339
+ assert isinstance(
340
+ generation.message, AIMessage
341
+ ), f"Expected AIMessage; received {type(generation.message)}"
342
+ message_with_citations = await add_citations(
343
+ self,
344
+ _messages,
345
+ generation.message,
346
+ SYSTEM_PROMPT,
347
+ generate_citations=False,
348
+ )
349
+ generation.message = message_with_citations
350
+
351
+ return llm_result
@@ -0,0 +1,460 @@
1
+ """
2
+ Comprehensive tests for the CitationMixin class.
3
+ """
4
+
5
+ from collections.abc import Sequence
6
+ from typing import Any
7
+ from unittest.mock import patch
8
+
9
+ import pytest
10
+ from langchain_b12.citations.citations import Citation, CitationMixin, Citations
11
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolCall
12
+ from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
13
+ from langchain_core.tools import BaseTool
14
+ from pydantic import BaseModel, Field
15
+
16
+
17
+ class MockChatModel(CitationMixin):
18
+ response_content: str = Field(default="This is a mock response.")
19
+ tool_calls: list[ToolCall] = Field(default_factory=list)
20
+
21
+ @property
22
+ def _llm_type(self) -> str:
23
+ return "simple"
24
+
25
+ def _generate(self, messages, stop=None, run_manager=None, **kwargs):
26
+
27
+ ai_message = AIMessage(
28
+ content=self.response_content, tool_calls=self.tool_calls
29
+ )
30
+ generation = ChatGeneration(message=ai_message)
31
+ return ChatResult(generations=[generation])
32
+
33
+
34
+ class MockChatModelWithStructuredOutput(MockChatModel):
35
+ structured_response: BaseModel = Field(...)
36
+
37
+ def bind_tools(
38
+ self,
39
+ tools: Sequence[BaseTool],
40
+ *,
41
+ tool_choice: str | None = None,
42
+ **kwargs: Any,
43
+ ):
44
+ return MockChatModel(
45
+ response_content="",
46
+ tool_calls=[
47
+ ToolCall(
48
+ name=self.structured_response.__class__.__name__,
49
+ args=self.structured_response.model_dump(),
50
+ id="structured_abc",
51
+ )
52
+ ],
53
+ )
54
+
55
+
56
+ class TestCitationMixin:
57
+ """Test the CitationMixin class functionality."""
58
+
59
+ @pytest.mark.asyncio
60
+ async def test_end_to_end(self):
61
+ """Test that context tags are processed correctly."""
62
+
63
+ citations = Citations(
64
+ values=[Citation(sentence_index=0, key="abc", cited_text="bar")]
65
+ )
66
+ model = MockChatModelWithStructuredOutput(
67
+ response_content="foo", structured_response=citations
68
+ )
69
+
70
+ # Create a message with context tags
71
+ context_message = HumanMessage(
72
+ content="Question about <context key='abc'>\nbaz\n</context>"
73
+ )
74
+ messages: list[BaseMessage] = [context_message]
75
+
76
+ # Simulate structured content after citation processing
77
+ expected_content = [
78
+ {
79
+ "text": "foo",
80
+ "citations": [
81
+ {
82
+ "cited_text": "baz",
83
+ "generated_cited_text": "bar",
84
+ "key": "abc",
85
+ "dist": 1,
86
+ }
87
+ ],
88
+ "type": "text",
89
+ }
90
+ ]
91
+
92
+ result = await model.ainvoke(messages)
93
+
94
+ assert result.content == expected_content
95
+
96
+ @pytest.mark.asyncio
97
+ async def test_citation_mixin_basic_functionality_without_context(self):
98
+ """Test basic functionality when no context tags are present."""
99
+
100
+ model = MockChatModel(response_content="Test response.")
101
+ messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
102
+
103
+ result = await model.agenerate(messages)
104
+
105
+ assert len(result.generations) == 1
106
+ assert len(result.generations[0]) == 1
107
+ generation = result.generations[0][0]
108
+ assert isinstance(generation, ChatGeneration)
109
+ # When no context tags, content remains string
110
+ assert generation.message.content == "Test response."
111
+
112
+ @pytest.mark.asyncio
113
+ async def test_citation_mixin_basic_functionality_without_context_invoke(self):
114
+ """Test basic functionality when no context tags are present."""
115
+
116
+ model = MockChatModel(response_content="Test response.")
117
+ messages: list[BaseMessage] = [HumanMessage(content="Test message")]
118
+
119
+ result = await model.ainvoke(messages)
120
+
121
+ assert isinstance(result, AIMessage)
122
+ # When no context tags, content remains string
123
+ assert result.content == "Test response."
124
+
125
+ @pytest.mark.asyncio
126
+ async def test_citation_mixin_recursion_prevention(self):
127
+ """Test that CitationMixin prevents recursion when _adding_citations is True."""
128
+
129
+ model = MockChatModel()
130
+
131
+ messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
132
+
133
+ # Mock the parent's agenerate method
134
+ with patch.object(
135
+ CitationMixin.__bases__[0], "agenerate"
136
+ ) as mock_parent_agenerate:
137
+ mock_parent_agenerate.return_value = LLMResult(
138
+ generations=[
139
+ [ChatGeneration(message=AIMessage(content="Test response"))]
140
+ ]
141
+ )
142
+
143
+ result = await model.agenerate(messages)
144
+
145
+ mock_parent_agenerate.assert_called_once()
146
+ call_args, _ = mock_parent_agenerate.call_args
147
+ assert call_args[0] == messages
148
+ assert len(result.generations) == 1
149
+ assert len(result.generations[0]) == 1
150
+
151
+ @pytest.mark.asyncio
152
+ async def test_citation_mixin_exception_handling(self):
153
+ """Test that exceptions are handled properly."""
154
+
155
+ class ErrorModel(CitationMixin):
156
+ @property
157
+ def _llm_type(self) -> str:
158
+ return "error"
159
+
160
+ def _generate(self, messages, stop=None, run_manager=None, **kwargs):
161
+ raise ValueError("Test error")
162
+
163
+ model = ErrorModel()
164
+ messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
165
+
166
+ with pytest.raises(ValueError, match="Test error"):
167
+ await model.agenerate(messages)
168
+
169
+ @pytest.mark.asyncio
170
+ async def test_citation_mixin_context_tag_processing(self):
171
+ """Test that context tags are processed correctly."""
172
+
173
+ model = MockChatModel(response_content="Response with context")
174
+
175
+ # Create a message with context tags
176
+ context_message = HumanMessage(
177
+ content="Question about <context key='test'>\nSome context\n</context>"
178
+ )
179
+ messages: list[list[BaseMessage]] = [[context_message]]
180
+
181
+ with patch(
182
+ "langchain_b12.citations.citations.add_citations"
183
+ ) as mock_add_citations:
184
+
185
+ # Simulate structured content after citation processing
186
+ cited_message = AIMessage(
187
+ content=[
188
+ {
189
+ "text": "Response with context",
190
+ "citations": [
191
+ {
192
+ "cited_text": "Some context",
193
+ "generated_cited_text": "Some context",
194
+ "key": "test",
195
+ "dist": 0,
196
+ }
197
+ ],
198
+ "type": "text",
199
+ }
200
+ ]
201
+ )
202
+ mock_add_citations.return_value = cited_message
203
+
204
+ result = await model.agenerate(messages)
205
+
206
+ # Should call add_citations with the context and response
207
+ mock_add_citations.assert_called_once()
208
+
209
+ # Content should be structured after processing
210
+ generation = result.generations[0][0]
211
+ assert generation.message == cited_message
212
+
213
+ @pytest.mark.asyncio
214
+ async def test_citation_mixin_context_tag_processing_invoke(self):
215
+ """Test that context tags are processed correctly."""
216
+
217
+ model = MockChatModel(response_content="Response with context")
218
+
219
+ # Create a message with context tags
220
+ context_message = HumanMessage(
221
+ content="Question about <context key='test'>\nSome context\n</context>"
222
+ )
223
+ messages: list[BaseMessage] = [context_message]
224
+
225
+ with patch(
226
+ "langchain_b12.citations.citations.add_citations"
227
+ ) as mock_add_citations:
228
+
229
+ # Simulate structured content after citation processing
230
+ cited_message = AIMessage(
231
+ content=[
232
+ {
233
+ "text": "Response with context",
234
+ "citations": [
235
+ {
236
+ "cited_text": "Some context",
237
+ "generated_cited_text": "Some context",
238
+ "key": "test",
239
+ "dist": 0,
240
+ }
241
+ ],
242
+ "type": "text",
243
+ }
244
+ ]
245
+ )
246
+ mock_add_citations.return_value = cited_message
247
+
248
+ result = await model.ainvoke(messages)
249
+
250
+ # Should call add_citations with the context and response
251
+ mock_add_citations.assert_called_once()
252
+
253
+ # Content should be structured after processing
254
+ assert result == cited_message
255
+
256
+ @pytest.mark.asyncio
257
+ async def test_citation_mixin_kwargs_preservation(self):
258
+ """Test that kwargs are properly passed through."""
259
+
260
+ class KwargsTestModel(CitationMixin):
261
+ @property
262
+ def _llm_type(self) -> str:
263
+ return "kwargstest"
264
+
265
+ def _generate(self, messages, stop=None, run_manager=None, **kwargs):
266
+ from langchain_core.outputs import ChatResult
267
+
268
+ ai_message = AIMessage(content="Kwargs test")
269
+ generation = ChatGeneration(message=ai_message)
270
+ return ChatResult(generations=[generation])
271
+
272
+ model = KwargsTestModel()
273
+ messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
274
+
275
+ test_kwargs = {
276
+ "temperature": 0.7,
277
+ "max_tokens": 100,
278
+ "custom_param": "test_value",
279
+ }
280
+
281
+ with patch.object(
282
+ CitationMixin.__bases__[0], "agenerate"
283
+ ) as mock_parent_agenerate:
284
+
285
+ mock_parent_agenerate.return_value = LLMResult(
286
+ generations=[[ChatGeneration(message=AIMessage(content="Kwargs test"))]]
287
+ )
288
+
289
+ await model.agenerate(messages, **test_kwargs)
290
+
291
+ # Check that kwargs were passed to parent's agenerate
292
+ mock_parent_agenerate.assert_called_once()
293
+ _, call_kwargs = mock_parent_agenerate.call_args
294
+
295
+ # Verify all test kwargs are present
296
+ for key, value in test_kwargs.items():
297
+ assert key in call_kwargs
298
+ assert call_kwargs[key] == value
299
+
300
+ @pytest.mark.asyncio
301
+ async def test_citation_mixin_multiple_message_batches(self):
302
+ """Test handling of multiple message batches."""
303
+
304
+ model = MockChatModel()
305
+ messages: list[list[BaseMessage]] = [
306
+ [HumanMessage(content="First batch")],
307
+ [HumanMessage(content="Second batch")],
308
+ ]
309
+
310
+ with patch.object(
311
+ CitationMixin.__bases__[0], "agenerate"
312
+ ) as mock_parent_agenerate:
313
+
314
+ mock_parent_agenerate.return_value = LLMResult(
315
+ generations=[
316
+ [ChatGeneration(message=AIMessage(content="First response"))],
317
+ [ChatGeneration(message=AIMessage(content="Second response"))],
318
+ ]
319
+ )
320
+
321
+ result = await model.agenerate(messages)
322
+
323
+ # Should process all batches
324
+ assert len(result.generations) == 2
325
+ assert len(result.generations[0]) == 1
326
+ assert len(result.generations[1]) == 1
327
+
328
+ @pytest.mark.asyncio
329
+ async def test_citation_mixin_realistic_workflow(self):
330
+ """Test a realistic workflow with context and citations."""
331
+
332
+ model = MockChatModel()
333
+
334
+ # Message with context tags
335
+ messages: list[list[BaseMessage]] = [
336
+ [
337
+ HumanMessage(
338
+ content="What is the capital of France? "
339
+ "<context key='france'>\nFrance is a country in Europe. "
340
+ "Paris is the capital city.\n</context>"
341
+ )
342
+ ]
343
+ ]
344
+
345
+ with (
346
+ patch.object(
347
+ CitationMixin.__bases__[0], "agenerate"
348
+ ) as mock_parent_agenerate,
349
+ patch(
350
+ "langchain_b12.citations.citations.add_citations"
351
+ ) as mock_add_citations,
352
+ ):
353
+
354
+ mock_parent_agenerate.return_value = LLMResult(
355
+ generations=[
356
+ [
357
+ ChatGeneration(
358
+ message=AIMessage(content="The capital of France is Paris.")
359
+ )
360
+ ]
361
+ ]
362
+ )
363
+
364
+ # Simulate citation processing result
365
+ cited_message = AIMessage(
366
+ content=[
367
+ {
368
+ "text": "The capital of France is Paris.",
369
+ "citations": [
370
+ {
371
+ "cited_text": "Paris is the capital city",
372
+ "generated_cited_text": "Paris is the capital city",
373
+ "key": "france",
374
+ "dist": 0,
375
+ }
376
+ ],
377
+ "type": "text",
378
+ }
379
+ ]
380
+ )
381
+ mock_add_citations.return_value = cited_message
382
+
383
+ result = await model.agenerate(messages)
384
+
385
+ # Verify the workflow executed correctly
386
+ mock_parent_agenerate.assert_called_once()
387
+ mock_add_citations.assert_called_once()
388
+
389
+ # Check that the result has the expected structure
390
+ assert len(result.generations) == 1
391
+ assert len(result.generations[0]) == 1
392
+
393
+ # Verify structured content after citation processing
394
+ generation = result.generations[0][0]
395
+ assert isinstance(generation, ChatGeneration)
396
+ assert isinstance(generation.message.content, list)
397
+ # Basic validation that citation processing occurred
398
+ assert len(generation.message.content) > 0
399
+
400
+ @pytest.mark.asyncio
401
+ async def test_citation_mixin_tool_call_handling(self):
402
+ """Test handling of AIMessage with tool calls (no string content)."""
403
+
404
+ messages: list[list[BaseMessage]] = [[HumanMessage(content="Use a tool")]]
405
+ tool_calls: list[ToolCall] = [
406
+ {
407
+ "name": "test_tool",
408
+ "args": {"param": "value"},
409
+ "id": "call_123",
410
+ "type": "tool_call",
411
+ }
412
+ ]
413
+ model = MockChatModel(response_content="", tool_calls=tool_calls)
414
+
415
+ result = await model.agenerate(messages)
416
+
417
+ # Should handle tool calls without issues
418
+ assert len(result.generations) == 1
419
+ generation = result.generations[0][0]
420
+ assert isinstance(generation, ChatGeneration)
421
+ assert isinstance(generation.message, AIMessage)
422
+ # Content remains empty for tool calls
423
+ assert generation.message.content == ""
424
+ # Tool calls should be preserved
425
+ assert generation.message.tool_calls == tool_calls
426
+
427
+ @pytest.mark.asyncio
428
+ async def test_citation_mixin_error_in_citation_processing(self):
429
+ """Test that errors in citation processing are handled properly."""
430
+
431
+ model = MockChatModel()
432
+ messages: list[list[BaseMessage]] = [
433
+ [
434
+ HumanMessage(
435
+ content="Test question <context key='test'>Some context</context>"
436
+ )
437
+ ]
438
+ ]
439
+
440
+ with (
441
+ patch.object(
442
+ CitationMixin.__bases__[0], "agenerate"
443
+ ) as mock_parent_agenerate,
444
+ patch(
445
+ "langchain_b12.citations.citations.add_citations"
446
+ ) as mock_add_citations,
447
+ ):
448
+
449
+ mock_parent_agenerate.return_value = LLMResult(
450
+ generations=[
451
+ [ChatGeneration(message=AIMessage(content="Test response"))]
452
+ ]
453
+ )
454
+
455
+ # Simulate an error in citation processing
456
+ mock_add_citations.side_effect = RuntimeError("Citation error")
457
+
458
+ # Should propagate the error from citation processing
459
+ with pytest.raises(RuntimeError, match="Citation error"):
460
+ await model.agenerate(messages)
@@ -470,8 +470,7 @@ class TestEdgeCases:
470
470
  citation, match = result[0]
471
471
  assert citation.cited_text == "test"
472
472
  assert citation.key == "key"
473
- assert match is not None
474
- assert match["dist"] >= 0
473
+ assert match is None
475
474
 
476
475
  # Empty sentences
477
476
  citations = Citations(
@@ -295,7 +295,7 @@ wheels = [
295
295
 
296
296
  [[package]]
297
297
  name = "langchain-b12"
298
- version = "0.1.2"
298
+ version = "0.1.4"
299
299
  source = { editable = "." }
300
300
  dependencies = [
301
301
  { name = "langchain-core" },
File without changes
File without changes
File without changes