langchain-b12 0.1.3__tar.gz → 0.1.4__tar.gz
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two published versions.
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/PKG-INFO +1 -1
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/pyproject.toml +1 -1
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/src/langchain_b12/citations/citations.py +80 -3
- langchain_b12-0.1.4/tests/test_citation_mixin.py +460 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/tests/test_citations.py +1 -2
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/uv.lock +1 -1
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/.gitignore +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/.python-version +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/.vscode/extensions.json +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/Makefile +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/README.md +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/src/langchain_b12/__init__.py +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/src/langchain_b12/genai/embeddings.py +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/src/langchain_b12/genai/genai.py +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/src/langchain_b12/genai/genai_utils.py +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/src/langchain_b12/py.typed +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/tests/test_genai.py +0 -0
- {langchain_b12-0.1.3 → langchain_b12-0.1.4}/tests/test_genai_utils.py +0 -0
src/langchain_b12/citations/citations.py:

```diff
@@ -1,10 +1,13 @@
 import re
 from collections.abc import Sequence
-from typing import Literal, TypedDict
+from typing import Any, Literal, TypedDict
+from uuid import UUID
 
 from fuzzysearch import find_near_matches
+from langchain_core.callbacks import Callbacks
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
 from langchain_core.runnables import Runnable
 from langgraph.utils.runnable import RunnableCallable
 from pydantic import BaseModel, Field
```
```diff
@@ -163,7 +166,11 @@ def validate_citations(
         if citation.sentence_index < 0 or citation.sentence_index >= n_sentences:
             # discard citations that refer to non-existing sentences
             continue
-
+        # Allow for 10% error distance
+        max_l_dist = max(1, len(citation.cited_text) // 10)
+        matches = find_near_matches(
+            citation.cited_text, all_text, max_l_dist=max_l_dist
+        )
         if not matches:
             citations_with_matches.append((citation, None))
         else:
```
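The change above replaces a fixed edit budget with one that scales at 10% of the citation length. A minimal sketch of the effect, with invented strings (`fuzzysearch` is the library the module already uses):

```python
from fuzzysearch import find_near_matches

cited_text = "Paris is the capital city"
source_text = "France is a country in Europe. Paris is the capitol city."

# 10% of the citation length, but never fewer than 1 allowed edit.
max_l_dist = max(1, len(cited_text) // 10)  # 25 chars -> 2 allowed edits
matches = find_near_matches(cited_text, source_text, max_l_dist=max_l_dist)

# "capitol" is one substitution away from "capital", so the citation still
# matches with dist == 1; three or more edits would push it past the budget.
assert matches and min(m.dist for m in matches) == 1
```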
```diff
@@ -187,6 +194,7 @@ async def add_citations(
     messages: Sequence[BaseMessage],
     message: AIMessage,
     system_prompt: str,
+    **kwargs: Any,
 ) -> AIMessage:
     """Add citations to the message."""
     if not message.content:
```
```diff
@@ -214,7 +222,9 @@ async def add_citations(
     system_message = SystemMessage(system_prompt)
     _messages = [system_message, *messages, numbered_message]
 
-    citations = await model.with_structured_output(Citations).ainvoke(_messages)
+    citations = await model.with_structured_output(Citations).ainvoke(
+        _messages, **kwargs
+    )
     assert isinstance(
         citations, Citations
     ), f"Expected Citations from model invocation but got {type(citations)}"
```
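The new `**kwargs` on `add_citations` lets callers thread per-call options down to the structured-output invocation. A minimal sketch of a call site, assuming `SYSTEM_PROMPT` and `add_citations` are importable from the module as shown (the message content is invented; the flag usage mirrors what `CitationMixin` does below):

```python
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage

from langchain_b12.citations.citations import SYSTEM_PROMPT, add_citations


async def annotate(model: BaseChatModel, answer: AIMessage) -> AIMessage:
    # Extra keyword arguments now flow through to
    # model.with_structured_output(Citations).ainvoke(...).
    return await add_citations(
        model,
        [HumanMessage("Question <context key='k'>\nsource text\n</context>")],
        answer,
        SYSTEM_PROMPT,
        generate_citations=False,  # the flag CitationMixin forwards to avoid recursion
    )
```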
````diff
@@ -272,3 +282,70 @@ def create_citation_model(
         func=None,  # TODO: Implement a sync version if needed
         afunc=ainvoke_with_citations,
     )
+
+
+class CitationMixin(BaseChatModel):
+    """Mixin class to add citation functionality to a runnable.
+
+    Example usage:
+    ```
+    from langchain_b12.genai.genai import ChatGenAI
+    from langchain_b12.citations.citations import CitationMixin
+
+    class CitationModel(ChatGenAI, CitationMixin):
+        pass
+    ```
+    """
+
+    async def agenerate(
+        self,
+        messages: list[list[BaseMessage]],
+        stop: list[str] | None = None,
+        callbacks: Callbacks = None,
+        *,
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
+        run_name: str | None = None,
+        run_id: UUID | None = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        # Check if we should generate citations and remove it from kwargs
+        generate_citations = kwargs.pop("generate_citations", True)
+
+        llm_result = await super().agenerate(
+            messages,
+            stop,
+            callbacks,
+            tags=tags,
+            metadata=metadata,
+            run_name=run_name,
+            run_id=run_id,
+            **kwargs,
+        )
+
+        # Prevent recursion when extracting citations
+        if not generate_citations:
+            # Below we call `add_citations`, which calls `agenerate` again.
+            # This would lead to an infinite loop if we did not stop here.
+            # We explicitly pass `generate_citations=False` below to stop this recursion.
+            return llm_result
+
+        # overwrite each generation with a version that has citations added
+        for _messages, generations in zip(messages, llm_result.generations):
+            for generation in generations:
+                assert isinstance(generation, ChatGeneration) and not isinstance(
+                    generation, ChatGenerationChunk
+                ), f"Expected ChatGeneration; received {type(generation)}"
+                assert isinstance(
+                    generation.message, AIMessage
+                ), f"Expected AIMessage; received {type(generation.message)}"
+                message_with_citations = await add_citations(
+                    self,
+                    _messages,
+                    generation.message,
+                    SYSTEM_PROMPT,
+                    generate_citations=False,
+                )
+                generation.message = message_with_citations
+
+        return llm_result
````
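Composing the new mixin follows the docstring example; a sketch under the assumption that `ChatGenAI` accepts a `model` keyword (the model name and messages here are invented):

```python
import asyncio

from langchain_core.messages import HumanMessage

from langchain_b12.citations.citations import CitationMixin
from langchain_b12.genai.genai import ChatGenAI


class CitationModel(ChatGenAI, CitationMixin):
    pass


async def main() -> None:
    model = CitationModel(model="gemini-2.0-flash")  # assumed constructor kwarg
    answer = await model.ainvoke(
        [HumanMessage("What is B12? <context key='doc1'>\nB12 is a vitamin.\n</context>")]
    )
    # Each text block now carries a "citations" list; pass
    # generate_citations=False to skip the second citation pass entirely.
    print(answer.content)


asyncio.run(main())
```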
tests/test_citation_mixin.py (new file):

```diff
@@ -0,0 +1,460 @@
+"""
+Comprehensive tests for the CitationMixin class.
+"""
+
+from collections.abc import Sequence
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+from langchain_b12.citations.citations import Citation, CitationMixin, Citations
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolCall
+from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, Field
+
+
+class MockChatModel(CitationMixin):
+    response_content: str = Field(default="This is a mock response.")
+    tool_calls: list[ToolCall] = Field(default_factory=list)
+
+    @property
+    def _llm_type(self) -> str:
+        return "simple"
+
+    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
+
+        ai_message = AIMessage(
+            content=self.response_content, tool_calls=self.tool_calls
+        )
+        generation = ChatGeneration(message=ai_message)
+        return ChatResult(generations=[generation])
+
+
+class MockChatModelWithStructuredOutput(MockChatModel):
+    structured_response: BaseModel = Field(...)
+
+    def bind_tools(
+        self,
+        tools: Sequence[BaseTool],
+        *,
+        tool_choice: str | None = None,
+        **kwargs: Any,
+    ):
+        return MockChatModel(
+            response_content="",
+            tool_calls=[
+                ToolCall(
+                    name=self.structured_response.__class__.__name__,
+                    args=self.structured_response.model_dump(),
+                    id="structured_abc",
+                )
+            ],
+        )
+
+
+class TestCitationMixin:
+    """Test the CitationMixin class functionality."""
+
+    @pytest.mark.asyncio
+    async def test_end_to_end(self):
+        """Test the full citation flow end to end."""
+
+        citations = Citations(
+            values=[Citation(sentence_index=0, key="abc", cited_text="bar")]
+        )
+        model = MockChatModelWithStructuredOutput(
+            response_content="foo", structured_response=citations
+        )
+
+        # Create a message with context tags
+        context_message = HumanMessage(
+            content="Question about <context key='abc'>\nbaz\n</context>"
+        )
+        messages: list[BaseMessage] = [context_message]
+
+        # Expected structured content after citation processing
+        expected_content = [
+            {
+                "text": "foo",
+                "citations": [
+                    {
+                        "cited_text": "baz",
+                        "generated_cited_text": "bar",
+                        "key": "abc",
+                        "dist": 1,
+                    }
+                ],
+                "type": "text",
+            }
+        ]
+
+        result = await model.ainvoke(messages)
+
+        assert result.content == expected_content
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_basic_functionality_without_context(self):
+        """Test basic functionality when no context tags are present."""
+
+        model = MockChatModel(response_content="Test response.")
+        messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
+
+        result = await model.agenerate(messages)
+
+        assert len(result.generations) == 1
+        assert len(result.generations[0]) == 1
+        generation = result.generations[0][0]
+        assert isinstance(generation, ChatGeneration)
+        # With no context tags, the content remains a string
+        assert generation.message.content == "Test response."
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_basic_functionality_without_context_invoke(self):
+        """Test basic functionality when no context tags are present."""
+
+        model = MockChatModel(response_content="Test response.")
+        messages: list[BaseMessage] = [HumanMessage(content="Test message")]
+
+        result = await model.ainvoke(messages)
+
+        assert isinstance(result, AIMessage)
+        # With no context tags, the content remains a string
+        assert result.content == "Test response."
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_recursion_prevention(self):
+        """Test that CitationMixin prevents recursion via the generate_citations flag."""
+
+        model = MockChatModel()
+
+        messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
+
+        # Mock the parent's agenerate method
+        with patch.object(
+            CitationMixin.__bases__[0], "agenerate"
+        ) as mock_parent_agenerate:
+            mock_parent_agenerate.return_value = LLMResult(
+                generations=[
+                    [ChatGeneration(message=AIMessage(content="Test response"))]
+                ]
+            )
+
+            result = await model.agenerate(messages)
+
+            mock_parent_agenerate.assert_called_once()
+            call_args, _ = mock_parent_agenerate.call_args
+            assert call_args[0] == messages
+            assert len(result.generations) == 1
+            assert len(result.generations[0]) == 1
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_exception_handling(self):
+        """Test that exceptions are handled properly."""
+
+        class ErrorModel(CitationMixin):
+            @property
+            def _llm_type(self) -> str:
+                return "error"
+
+            def _generate(self, messages, stop=None, run_manager=None, **kwargs):
+                raise ValueError("Test error")
+
+        model = ErrorModel()
+        messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
+
+        with pytest.raises(ValueError, match="Test error"):
+            await model.agenerate(messages)
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_context_tag_processing(self):
+        """Test that context tags are processed correctly."""
+
+        model = MockChatModel(response_content="Response with context")
+
+        # Create a message with context tags
+        context_message = HumanMessage(
+            content="Question about <context key='test'>\nSome context\n</context>"
+        )
+        messages: list[list[BaseMessage]] = [[context_message]]
+
+        with patch(
+            "langchain_b12.citations.citations.add_citations"
+        ) as mock_add_citations:
+
+            # Simulate structured content after citation processing
+            cited_message = AIMessage(
+                content=[
+                    {
+                        "text": "Response with context",
+                        "citations": [
+                            {
+                                "cited_text": "Some context",
+                                "generated_cited_text": "Some context",
+                                "key": "test",
+                                "dist": 0,
+                            }
+                        ],
+                        "type": "text",
+                    }
+                ]
+            )
+            mock_add_citations.return_value = cited_message
+
+            result = await model.agenerate(messages)
+
+            # Should call add_citations with the context and response
+            mock_add_citations.assert_called_once()
+
+            # Content should be structured after processing
+            generation = result.generations[0][0]
+            assert generation.message == cited_message
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_context_tag_processing_invoke(self):
+        """Test that context tags are processed correctly when using ainvoke."""
+
+        model = MockChatModel(response_content="Response with context")
+
+        # Create a message with context tags
+        context_message = HumanMessage(
+            content="Question about <context key='test'>\nSome context\n</context>"
+        )
+        messages: list[BaseMessage] = [context_message]
+
+        with patch(
+            "langchain_b12.citations.citations.add_citations"
+        ) as mock_add_citations:
+
+            # Simulate structured content after citation processing
+            cited_message = AIMessage(
+                content=[
+                    {
+                        "text": "Response with context",
+                        "citations": [
+                            {
+                                "cited_text": "Some context",
+                                "generated_cited_text": "Some context",
+                                "key": "test",
+                                "dist": 0,
+                            }
+                        ],
+                        "type": "text",
+                    }
+                ]
+            )
+            mock_add_citations.return_value = cited_message
+
+            result = await model.ainvoke(messages)
+
+            # Should call add_citations with the context and response
+            mock_add_citations.assert_called_once()
+
+            # Content should be structured after processing
+            assert result == cited_message
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_kwargs_preservation(self):
+        """Test that kwargs are properly passed through."""
+
+        class KwargsTestModel(CitationMixin):
+            @property
+            def _llm_type(self) -> str:
+                return "kwargstest"
+
+            def _generate(self, messages, stop=None, run_manager=None, **kwargs):
+                from langchain_core.outputs import ChatResult
+
+                ai_message = AIMessage(content="Kwargs test")
+                generation = ChatGeneration(message=ai_message)
+                return ChatResult(generations=[generation])
+
+        model = KwargsTestModel()
+        messages: list[list[BaseMessage]] = [[HumanMessage(content="Test message")]]
+
+        test_kwargs = {
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "custom_param": "test_value",
+        }
+
+        with patch.object(
+            CitationMixin.__bases__[0], "agenerate"
+        ) as mock_parent_agenerate:
+
+            mock_parent_agenerate.return_value = LLMResult(
+                generations=[[ChatGeneration(message=AIMessage(content="Kwargs test"))]]
+            )
+
+            await model.agenerate(messages, **test_kwargs)
+
+            # Check that kwargs were passed to parent's agenerate
+            mock_parent_agenerate.assert_called_once()
+            _, call_kwargs = mock_parent_agenerate.call_args
+
+            # Verify all test kwargs are present
+            for key, value in test_kwargs.items():
+                assert key in call_kwargs
+                assert call_kwargs[key] == value
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_multiple_message_batches(self):
+        """Test handling of multiple message batches."""
+
+        model = MockChatModel()
+        messages: list[list[BaseMessage]] = [
+            [HumanMessage(content="First batch")],
+            [HumanMessage(content="Second batch")],
+        ]
+
+        with patch.object(
+            CitationMixin.__bases__[0], "agenerate"
+        ) as mock_parent_agenerate:
+
+            mock_parent_agenerate.return_value = LLMResult(
+                generations=[
+                    [ChatGeneration(message=AIMessage(content="First response"))],
+                    [ChatGeneration(message=AIMessage(content="Second response"))],
+                ]
+            )
+
+            result = await model.agenerate(messages)
+
+            # Should process all batches
+            assert len(result.generations) == 2
+            assert len(result.generations[0]) == 1
+            assert len(result.generations[1]) == 1
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_realistic_workflow(self):
+        """Test a realistic workflow with context and citations."""
+
+        model = MockChatModel()
+
+        # Message with context tags
+        messages: list[list[BaseMessage]] = [
+            [
+                HumanMessage(
+                    content="What is the capital of France? "
+                    "<context key='france'>\nFrance is a country in Europe. "
+                    "Paris is the capital city.\n</context>"
+                )
+            ]
+        ]
+
+        with (
+            patch.object(
+                CitationMixin.__bases__[0], "agenerate"
+            ) as mock_parent_agenerate,
+            patch(
+                "langchain_b12.citations.citations.add_citations"
+            ) as mock_add_citations,
+        ):
+
+            mock_parent_agenerate.return_value = LLMResult(
+                generations=[
+                    [
+                        ChatGeneration(
+                            message=AIMessage(content="The capital of France is Paris.")
+                        )
+                    ]
+                ]
+            )
+
+            # Simulate citation processing result
+            cited_message = AIMessage(
+                content=[
+                    {
+                        "text": "The capital of France is Paris.",
+                        "citations": [
+                            {
+                                "cited_text": "Paris is the capital city",
+                                "generated_cited_text": "Paris is the capital city",
+                                "key": "france",
+                                "dist": 0,
+                            }
+                        ],
+                        "type": "text",
+                    }
+                ]
+            )
+            mock_add_citations.return_value = cited_message
+
+            result = await model.agenerate(messages)
+
+            # Verify the workflow executed correctly
+            mock_parent_agenerate.assert_called_once()
+            mock_add_citations.assert_called_once()
+
+            # Check that the result has the expected structure
+            assert len(result.generations) == 1
+            assert len(result.generations[0]) == 1
+
+            # Verify structured content after citation processing
+            generation = result.generations[0][0]
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.message.content, list)
+            # Basic validation that citation processing occurred
+            assert len(generation.message.content) > 0
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_tool_call_handling(self):
+        """Test handling of AIMessage with tool calls (no string content)."""
+
+        messages: list[list[BaseMessage]] = [[HumanMessage(content="Use a tool")]]
+        tool_calls: list[ToolCall] = [
+            {
+                "name": "test_tool",
+                "args": {"param": "value"},
+                "id": "call_123",
+                "type": "tool_call",
+            }
+        ]
+        model = MockChatModel(response_content="", tool_calls=tool_calls)
+
+        result = await model.agenerate(messages)
+
+        # Should handle tool calls without issues
+        assert len(result.generations) == 1
+        generation = result.generations[0][0]
+        assert isinstance(generation, ChatGeneration)
+        assert isinstance(generation.message, AIMessage)
+        # Content remains empty for tool calls
+        assert generation.message.content == ""
+        # Tool calls should be preserved
+        assert generation.message.tool_calls == tool_calls
+
+    @pytest.mark.asyncio
+    async def test_citation_mixin_error_in_citation_processing(self):
+        """Test that errors in citation processing are handled properly."""
+
+        model = MockChatModel()
+        messages: list[list[BaseMessage]] = [
+            [
+                HumanMessage(
+                    content="Test question <context key='test'>Some context</context>"
+                )
+            ]
+        ]
+
+        with (
+            patch.object(
+                CitationMixin.__bases__[0], "agenerate"
+            ) as mock_parent_agenerate,
+            patch(
+                "langchain_b12.citations.citations.add_citations"
+            ) as mock_add_citations,
+        ):
+
+            mock_parent_agenerate.return_value = LLMResult(
+                generations=[
+                    [ChatGeneration(message=AIMessage(content="Test response"))]
+                ]
+            )
+
+            # Simulate an error in citation processing
+            mock_add_citations.side_effect = RuntimeError("Citation error")
+
+            # Should propagate the error from citation processing
+            with pytest.raises(RuntimeError, match="Citation error"):
+                await model.agenerate(messages)
```
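The mocks above are also handy for quick manual checks outside pytest; a minimal sketch, assuming `tests/` is importable as a package:

```python
import asyncio

from langchain_core.messages import HumanMessage

from tests.test_citation_mixin import MockChatModel


async def main() -> None:
    model = MockChatModel(response_content="Test response.")
    message = await model.ainvoke([HumanMessage("Test message")])
    # Without <context> tags the mixin leaves the string content untouched.
    assert message.content == "Test response."


asyncio.run(main())
```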
tests/test_citations.py:

```diff
@@ -470,8 +470,7 @@ class TestEdgeCases:
         citation, match = result[0]
         assert citation.cited_text == "test"
         assert citation.key == "key"
-        assert match is not None
-        assert match["dist"] >= 0
+        assert match is None
 
         # Empty sentences
         citations = Citations(
```