langchain-b12 0.1.2__tar.gz → 0.1.4__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/PKG-INFO +1 -1
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/pyproject.toml +11 -1
- langchain_b12-0.1.4/src/langchain_b12/citations/citations.py +351 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/src/langchain_b12/genai/genai_utils.py +1 -1
- langchain_b12-0.1.4/tests/test_citation_mixin.py +460 -0
- langchain_b12-0.1.4/tests/test_citations.py +483 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/tests/test_genai.py +1 -1
- langchain_b12-0.1.4/tests/test_genai_utils.py +521 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/uv.lock +203 -2
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/.gitignore +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/.python-version +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/.vscode/extensions.json +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/Makefile +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/README.md +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/src/langchain_b12/__init__.py +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/src/langchain_b12/genai/embeddings.py +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/src/langchain_b12/genai/genai.py +0 -0
- {langchain_b12-0.1.2 → langchain_b12-0.1.4}/src/langchain_b12/py.typed +0 -0
{langchain_b12-0.1.2 → langchain_b12-0.1.4}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "langchain-b12"
-version = "0.1.2"
+version = "0.1.4"
 description = "A reusable collection of tools and implementations for Langchain"
 readme = "README.md"
 authors = [
@@ -17,6 +17,11 @@ google = [
 ]
 dev = [
     "pytest>=8.3.5",
+    "pytest-asyncio>=1.0.0",
+]
+citations = [
+    "fuzzysearch>=0.8.0",
+    "langgraph>=0.4.7",
 ]
 
 [build-system]
@@ -59,3 +64,8 @@ reportUnknownParameterType = false
 reportUnknownMemberType = false
 reportUnknownArgumentType = false
 
+# Add pytest configuration
+[tool.pytest.ini_options]
+asyncio_default_fixture_loop_scope = "function"
+asyncio_mode = "auto"
+
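The new `[tool.pytest.ini_options]` block turns on pytest-asyncio's auto mode, so the `async def` tests added in this release are collected and run without per-test `@pytest.mark.asyncio` markers. A minimal sketch of a test this configuration enables (hypothetical, not part of the package):

```
# test_sketch.py -- hypothetical; any `async def test_*` function is
# picked up automatically under `asyncio_mode = "auto"`.
async def test_addition() -> None:
    assert 1 + 1 == 2
```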
langchain_b12-0.1.4/src/langchain_b12/citations/citations.py

@@ -0,0 +1,351 @@
+import re
+from collections.abc import Sequence
+from typing import Any, Literal, TypedDict
+from uuid import UUID
+
+from fuzzysearch import find_near_matches
+from langchain_core.callbacks import Callbacks
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
+from langchain_core.runnables import Runnable
+from langgraph.utils.runnable import RunnableCallable
+from pydantic import BaseModel, Field
+
+SYSTEM_PROMPT = """
+You are an expert at identifying and adding citations to text.
+Your task is to identify, for each sentence in the final message, which citations were used to generate it.
+
+You will receive a numbered zero-indexed list of sentences in the final message, e.g.
+```
+0: Grass is green.
+1: The sky is blue and the sun is shining.
+```
+The rest of the conversation may contain contexts enclosed in xml tags, e.g.
+```
+<context key="abc">
+Today is a sunny day and the color of the grass is green.
+</context>
+```
+Each sentence may have zero, one, or multiple citations from the contexts.
+Each citation may be used for zero, one or multiple sentences.
+A context may be cited zero, one, or multiple times.
+
+The final message will be based on the contexts, but may not mention them explicitly.
+You must identify which contexts and which parts of the contexts were used to generate each sentence.
+For each such case, you must return a citation with a "sentence_index", "cited_text" and "key" property.
+The "sentence_index" is the index of the sentence in the final message.
+The "cited_text" must be a substring of the full context that was used to generate the sentence.
+The "key" must be the key of the context that was used to generate the sentence.
+Make sure that you copy the "cited_text" verbatim from the context, or it will not be considered valid.
+
+For the example above, the output should look like this:
+[
+    {
+        "sentence_index": 0,
+        "cited_text": "the color of the grass is green",
+        "key": "abc"
+    },
+    {
+        "sentence_index": 1,
+        "cited_text": "Today is a sunny day",
+        "key": "abc"
+    },
+]
+""".strip()  # noqa: E501
+
+
+class Match(TypedDict):
+    start: int
+    end: int
+    dist: int
+    matched: str
+
+
+class CitationType(TypedDict):
+
+    cited_text: str | None
+    generated_cited_text: str
+    key: str
+    dist: int | None
+
+
+class ContentType(TypedDict):
+
+    citations: list[CitationType] | None
+    text: str
+    type: Literal["text"]
+
+
+class Citation(BaseModel):
+
+    sentence_index: int = Field(
+        ...,
+        description="The index of the sentence from your answer "
+        "that this citation refers to.",
+    )
+    cited_text: str = Field(
+        ...,
+        description="The text that is cited from the document. "
+        "Make sure you cite it verbatim!",
+    )
+    key: str = Field(..., description="The key of the document you are citing.")
+
+
+class Citations(BaseModel):
+
+    values: list[Citation] = Field(..., description="List of citations")
+
+
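`Citations` is the schema the citation model is asked to fill via structured output; note that it wraps the bare list shown in `SYSTEM_PROMPT` in a top-level `values` field. A minimal sketch of that mapping, assuming Pydantic v2 (`model_validate`), with values copied from the prompt's example:

```
citations = Citations.model_validate(
    {
        "values": [
            {
                "sentence_index": 0,
                "cited_text": "the color of the grass is green",
                "key": "abc",
            }
        ]
    }
)
assert citations.values[0].key == "abc"
```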
+def split_into_sentences(text: str) -> list[str]:
+    """Split text into sentences on punctuation marks and newlines."""
+    if not text:
+        return [text]
+
+    # Split after punctuation followed by spaces, or on newlines
+    # Use capturing groups to preserve delimiters (spaces and newlines)
+    parts = re.split(r"((?<=[.!?])(?= +)|\n+)", text)
+
+    # Filter out empty strings that can result from splitting
+    return [part for part in parts if part]
+
+
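The split points after `.`, `!`, and `?` are zero-width lookarounds, so no characters are lost: leading spaces stay attached to the following sentence, and captured newline runs survive as parts of their own, which lets the original text be reassembled later. A quick illustration with an invented input:

```
split_into_sentences("Grass is green. The sky is blue!\nDone.")
# -> ['Grass is green.', ' The sky is blue!', '\n', 'Done.']
```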
+def contains_context_tags(text: str) -> bool:
+    """Check if the text contains context tags."""
+    return bool(re.search(r"<context\s+key=[^>]+>.*?</context>", text, re.DOTALL))
+
+
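A hedged probe of this guard (inputs invented): `add_citations` skips citation extraction entirely unless at least one `<context key=...>` tag appears somewhere in the conversation.

```
assert contains_context_tags('<context key="abc">Grass is green.</context>')
assert not contains_context_tags("No tagged sources here.")
```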
+def merge_citations(
+    sentences: list[str], citations: list[tuple[Citation, Match | None]]
+) -> list[ContentType]:
+    """Merge citations into sentences."""
+    content: list[ContentType] = []
+    for sentence_index, sentence in enumerate(sentences):
+        _citations: list[CitationType] = []
+        for citation, match in citations:
+            if citation.sentence_index == sentence_index:
+                if match is None:
+                    _citations.append(
+                        {
+                            "cited_text": None,
+                            "generated_cited_text": citation.cited_text,
+                            "key": citation.key,
+                            "dist": None,
+                        }
+                    )
+                else:
+                    _citations.append(
+                        {
+                            "cited_text": match["matched"],
+                            "generated_cited_text": citation.cited_text,
+                            "key": citation.key,
+                            "dist": match["dist"],
+                        }
+                    )
+        content.append(
+            {"text": sentence, "citations": _citations or None, "type": "text"}
+        )
+
+    return content
+
+
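The merged content keeps the model's quote under `generated_cited_text` even when no fuzzy match was found, with `cited_text` and `dist` set to `None`, so consumers can tell verified quotes from unverified ones. A sketch with made-up values:

```
citation = Citation(sentence_index=0, cited_text="grass is green", key="abc")
merge_citations(["Grass is green."], [(citation, None)])
# -> [{"text": "Grass is green.",
#      "citations": [{"cited_text": None,
#                     "generated_cited_text": "grass is green",
#                     "key": "abc",
#                     "dist": None}],
#      "type": "text"}]
```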
+def validate_citations(
+    citations: Citations,
+    messages: Sequence[BaseMessage],
+    sentences: list[str],
+) -> list[tuple[Citation, Match | None]]:
+    """Validate the citations. Invalid citations are dropped."""
+    n_sentences = len(sentences)
+
+    all_text = "\n".join(
+        str(msg.content) for msg in messages if isinstance(msg.content, str)
+    )
+
+    citations_with_matches: list[tuple[Citation, Match | None]] = []
+    for citation in citations.values:
+        if citation.sentence_index < 0 or citation.sentence_index >= n_sentences:
+            # discard citations that refer to non-existing sentences
+            continue
+        # Allow for 10% error distance
+        max_l_dist = max(1, len(citation.cited_text) // 10)
+        matches = find_near_matches(
+            citation.cited_text, all_text, max_l_dist=max_l_dist
+        )
+        if not matches:
+            citations_with_matches.append((citation, None))
+        else:
+            match = matches[0]
+            citations_with_matches.append(
+                (
+                    citation,
+                    Match(
+                        start=match.start,
+                        end=match.end,
+                        dist=match.dist,
+                        matched=match.matched,
+                    ),
+                )
+            )
+    return citations_with_matches
+
+
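The `max_l_dist` heuristic tolerates roughly one edit per ten characters of quoted text. A standalone sketch of that tolerance using `fuzzysearch` directly, with invented strings:

```
from fuzzysearch import find_near_matches

context = "Today is a sunny day and the color of the grass is green."
# The model mis-quoted "colour"; for a 23-character quote the code above
# allows max_l_dist = 2, so this one-edit match is still found.
matches = find_near_matches("the colour of the grass", context, max_l_dist=2)
print(matches[0].matched, matches[0].dist)  # ≈ "the color of the grass", 1
```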
+async def add_citations(
+    model: BaseChatModel,
+    messages: Sequence[BaseMessage],
+    message: AIMessage,
+    system_prompt: str,
+    **kwargs: Any,
+) -> AIMessage:
+    """Add citations to the message."""
+    if not message.content:
+        # Nothing to be done, for example in case of a tool call
+        return message
+
+    assert isinstance(
+        message.content, str
+    ), "Citation agent currently only supports string content."
+
+    if not contains_context_tags("\n".join(str(msg.content) for msg in messages)):
+        # No context tags, nothing to do
+        return message
+
+    sentences = split_into_sentences(message.content)
+
+    num_width = len(str(len(sentences)))
+    numbered_message = AIMessage(
+        content="\n".join(
+            f"{str(i).rjust(num_width)}: {sentence.strip()}"
+            for i, sentence in enumerate(sentences)
+        ),
+        name=message.name,
+    )
+    system_message = SystemMessage(system_prompt)
+    _messages = [system_message, *messages, numbered_message]
+
+    citations = await model.with_structured_output(Citations).ainvoke(
+        _messages, **kwargs
+    )
+    assert isinstance(
+        citations, Citations
+    ), f"Expected Citations from model invocation but got {type(citations)}"
+    citations = validate_citations(citations, messages, sentences)
+
+    message.content = merge_citations(sentences, citations)  # type: ignore[assignment]
+    return message
+
+
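Before the structured-output call, the draft answer is resent as the zero-indexed, width-padded sentence list that `SYSTEM_PROMPT` describes. A tiny sketch of that formatting with invented sentences:

```
sentences = ["Grass is green. ", "The sky is blue."]
num_width = len(str(len(sentences)))
print("\n".join(
    f"{str(i).rjust(num_width)}: {s.strip()}" for i, s in enumerate(sentences)
))
# 0: Grass is green.
# 1: The sky is blue.
```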
+def create_citation_model(
+    model: BaseChatModel,
+    citation_model: BaseChatModel | None = None,
+    system_prompt: str | None = None,
+) -> Runnable[Sequence[BaseMessage], AIMessage]:
+    """Take a base chat model and wrap it such that it adds citations to the messages.
+    Any contexts to be cited should be provided in the messages as XML tags,
+    e.g. `<context key="abc">Today is a sunny day</context>`.
+    The returned AIMessage will have the following structure:
+    AIMessage(
+        content=[{
+            "citations": [
+                {
+                    "cited_text": "the color of the grass is green",
+                    "generated_cited_text": "the color of the grass is green",
+                    "key": "abc",
+                    "dist": 0,
+                }
+            ],
+            "text": "The grass is green",
+            "type": "text",
+        }],
+    )
+
+    Args:
+        model: The base chat model to wrap.
+        citation_model: The model to use for extracting citations.
+            If None, the base model is used.
+        system_prompt: The system prompt to use for the citation model.
+            If None, a default prompt is used.
+    """
+    citation_model = citation_model or model
+    system_prompt = system_prompt or SYSTEM_PROMPT
+
+    async def ainvoke_with_citations(
+        messages: Sequence[BaseMessage],
+    ) -> AIMessage:
+        """Invoke the model and add citations to the AIMessage."""
+        ai_message = await model.ainvoke(messages)
+        assert isinstance(
+            ai_message, AIMessage
+        ), f"Expected AIMessage from model invocation but got {type(ai_message)}"
+        return await add_citations(citation_model, messages, ai_message, system_prompt)
+
+    return RunnableCallable(
+        func=None,  # TODO: Implement a sync version if needed
+        afunc=ainvoke_with_citations,
+    )
+
+
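A usage sketch of the wrapper; `chat_model` is a placeholder for whatever `BaseChatModel` you already have configured, and only the async path is wired up (`func=None` above):

```
from langchain_core.messages import HumanMessage

from langchain_b12.citations.citations import create_citation_model


async def answer_with_citations(chat_model):
    # chat_model: any configured BaseChatModel instance.
    citing_model = create_citation_model(chat_model)
    # The returned message's .content is a list of
    # {"text", "citations", "type"} blocks as documented above.
    return await citing_model.ainvoke(
        [
            HumanMessage(
                '<context key="abc">Today is a sunny day and the color '
                "of the grass is green.</context>\nWhat color is the grass?"
            )
        ]
    )
```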
+class CitationMixin(BaseChatModel):
+    """Mixin class to add citation functionality to a runnable.
+
+    Example usage:
+    ```
+    from langchain_b12.genai.genai import ChatGenAI
+    from langchain_b12.citations.citations import CitationMixin
+
+    class CitationModel(ChatGenAI, CitationMixin):
+        pass
+    ```
+    """
+
+    async def agenerate(
+        self,
+        messages: list[list[BaseMessage]],
+        stop: list[str] | None = None,
+        callbacks: Callbacks = None,
+        *,
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
+        run_name: str | None = None,
+        run_id: UUID | None = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        # Check if we should generate citations and remove it from kwargs
+        generate_citations = kwargs.pop("generate_citations", True)
+
+        llm_result = await super().agenerate(
+            messages,
+            stop,
+            callbacks,
+            tags=tags,
+            metadata=metadata,
+            run_name=run_name,
+            run_id=run_id,
+            **kwargs,
+        )
+
+        # Prevent recursion when extracting citations
+        if not generate_citations:
+            # Below we call `add_citations`, which will call `agenerate` again.
+            # This will lead to an infinite loop if we don't stop here.
+            # We explicitly pass `generate_citations=False` below to stop this recursion.
+            return llm_result
+
+        # Overwrite each generation with a version that has citations added
+        for _messages, generations in zip(messages, llm_result.generations):
+            for generation in generations:
+                assert isinstance(generation, ChatGeneration) and not isinstance(
+                    generation, ChatGenerationChunk
+                ), f"Expected ChatGeneration; received {type(generation)}"
+                assert isinstance(
+                    generation.message, AIMessage
+                ), f"Expected AIMessage; received {type(generation.message)}"
+                message_with_citations = await add_citations(
+                    self,
+                    _messages,
+                    generation.message,
+                    SYSTEM_PROMPT,
+                    generate_citations=False,
+                )
+                generation.message = message_with_citations
+
+        return llm_result
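The recursion guard also doubles as an opt-out: `add_citations` forwards its `**kwargs` into the structured-output call, so the inner `agenerate` sees `generate_citations=False` and returns before re-entering the citation pass. Callers can use the same keyword, since extra chat-model kwargs flow through to `agenerate` in langchain-core. A hedged sketch, inside async code, with `model` being a `CitationModel` as in the docstring above:

```
cited = await model.ainvoke(messages)  # default: citations are extracted
plain = await model.ainvoke(messages, generate_citations=False)  # skip the extra pass
```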
{langchain_b12-0.1.2 → langchain_b12-0.1.4}/src/langchain_b12/genai/genai_utils.py

@@ -34,7 +34,7 @@ def multi_content_to_part(
                 "url": f"data:{mime_type};base64,{encoded_artifact}"
             },
         },
-        { # Image content
+        { # Image content from base64 encoded string with LangChain format
             "type": "image",
             "source_type": "base64",
             "data": "<base64 string>",
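The reworded comment separates this block from the data-URL form above it: it is LangChain's standard base64 image content block. A hedged sketch of passing one (field names follow LangChain's multimodal content-block convention; the base64 payload is left as a placeholder, as in the docstring):

```
from langchain_core.messages import HumanMessage

message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {
            "type": "image",
            "source_type": "base64",
            "data": "<base64 string>",  # placeholder payload
            "mime_type": "image/png",
        },
    ]
)
```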