langchain-b12 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-b12
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: A reusable collection of tools and implementations for Langchain
5
5
  Author-email: Vincent Min <vincent.min@b12-consulting.com>
6
6
  Requires-Python: >=3.11
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "langchain-b12"
3
- version = "0.1.2"
3
+ version = "0.1.4"
4
4
  description = "A reusable collection of tools and implementations for Langchain"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -17,6 +17,11 @@ google = [
17
17
  ]
18
18
  dev = [
19
19
  "pytest>=8.3.5",
20
+ "pytest-asyncio>=1.0.0",
21
+ ]
22
+ citations = [
23
+ "fuzzysearch>=0.8.0",
24
+ "langgraph>=0.4.7",
20
25
  ]
21
26
 
22
27
  [build-system]
@@ -59,3 +64,8 @@ reportUnknownParameterType = false
59
64
  reportUnknownMemberType = false
60
65
  reportUnknownArgumentType = false
61
66
 
67
+ # Add pytest configuration
68
+ [tool.pytest.ini_options]
69
+ asyncio_default_fixture_loop_scope = "function"
70
+ asyncio_mode = "auto"
71
+
@@ -0,0 +1,351 @@
1
+ import re
2
+ from collections.abc import Sequence
3
+ from typing import Any, Literal, TypedDict
4
+ from uuid import UUID
5
+
6
+ from fuzzysearch import find_near_matches
7
+ from langchain_core.callbacks import Callbacks
8
+ from langchain_core.language_models import BaseChatModel
9
+ from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
10
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
11
+ from langchain_core.runnables import Runnable
12
+ from langgraph.utils.runnable import RunnableCallable
13
+ from pydantic import BaseModel, Field
14
+
15
+ SYSTEM_PROMPT = """
16
+ You are an expert at identifying and adding citations to text.
17
+ Your task is to identify, for each sentence in the final message, which citations were used to generate it.
18
+
19
+ You will receive a numbered zero-indexed list of sentences in the final message, e.g.
20
+ ```
21
+ 0: Grass is green.
22
+ 1: The sky is blue and the sun is shining.
23
+ ```
24
+ The rest of the conversation may contain contexts enclosed in xml tags, e.g.
25
+ ```
26
+ <context key="abc">
27
+ Today is a sunny day and the color of the grass is green.
28
+ </context>
29
+ ```
30
+ Each sentence may have zero, one, or multiple citations from the contexts.
31
+ Each citation may be used for zero, one or multiple sentences.
32
+ A context may be cited zero, one, or multiple times.
33
+
34
+ The final message will be based on the contexts, but may not mention them explicitly.
35
+ You must identify which contexts and which parts of the contexts were used to generate each sentence.
36
+ For each such case, you must return a citation with a "sentence_index", "cited_text" and "key" property.
37
+ The "sentence_index" is the index of the sentence in the final message.
38
+ The "cited_text" must be a substring of the full context that was used to generate the sentence.
39
+ The "key" must be the key of the context that was used to generate the sentence.
40
+ Make sure that you copy the "cited_text" verbatim from the context, or it will not be considered valid.
41
+
42
+ For the example above, the output should look like this:
43
+ [
44
+ {
45
+ "sentence_index": 0,
46
+ "cited_text": "the color of the grass is green",
47
+ "key": "abc"
48
+ },
49
+ {
50
+ "sentence_index": 1,
51
+ "cited_text": "Today is a sunny day",
52
+ "key": "abc"
53
+ },
54
+ ]
55
+ """.strip() # noqa: E501
56
+
57
+
58
class Match(TypedDict):
    """Position and quality of a fuzzy match located in the searched text."""

    # Start offset of the match, as reported by the fuzzy search.
    start: int
    # End offset of the match, as reported by the fuzzy search.
    end: int
    # Distance between the searched-for text and the matched text.
    dist: int
    # The piece of the searched text that was matched.
    matched: str
63
+
64
+
65
+ class CitationType(TypedDict):
66
+
67
+ cited_text: str | None
68
+ generated_cited_text: str
69
+ key: str
70
+ dist: int | None
71
+
72
+
73
class ContentType(TypedDict):
    """One text block of the message content together with its citations."""

    # Citations attached to this block; None when there are none.
    citations: list[CitationType] | None
    # The text of the block (one sentence-like chunk of the answer).
    text: str
    # Content block discriminator; always "text".
    type: Literal["text"]
78
+
79
+
80
class Citation(BaseModel):
    # Schema for one citation returned by the citation model.
    # NOTE(review): documented with `#` comments on purpose -- a class docstring
    # would presumably end up in the schema description sent to the model;
    # confirm before converting these comments into a docstring.

    # Index into the numbered sentence list of the final message.
    sentence_index: int = Field(
        ...,
        description=(
            "The index of the sentence from your answer "
            "that this citation refers to."
        ),
    )
    # Text the model quotes from a context; expected to be copied verbatim.
    cited_text: str = Field(
        ...,
        description=(
            "The text that is cited from the document. "
            "Make sure you cite it verbatim!"
        ),
    )
    # Key of the context the quote was taken from.
    key: str = Field(..., description="The key of the document you are citing.")
93
+
94
+
95
class Citations(BaseModel):
    # Top-level structured-output schema: the full list of citations.
    # NOTE(review): `#` comments instead of a docstring, since a docstring on a
    # pydantic model presumably changes the generated schema description.

    values: list[Citation] = Field(..., description="List of citations")
98
+
99
+
100
def split_into_sentences(text: str) -> list[str]:
    """Break ``text`` into sentence-like chunks without dropping any character.

    A boundary is placed right after ``.``, ``!`` or ``?`` when at least one
    space follows, and around every run of newlines. Runs of newlines are
    returned as chunks of their own and the spaces after punctuation stay
    attached to the following chunk, so ``"".join(result)`` gives back ``text``.

    Args:
        text: The text to split.

    Returns:
        The non-empty chunks in order; ``[text]`` when ``text`` is empty.
    """
    if not text:
        return [text]

    # The whole pattern is one capturing group, so `re.split` also returns the
    # delimiters: an empty string for the zero-width punctuation boundary and
    # the newline run itself for the second alternative.
    boundary = r"((?<=[.!?])(?= +)|\n+)"
    pieces = re.split(boundary, text)

    # Drop the empty strings produced by the zero-width boundaries.
    return list(filter(None, pieces))
111
+
112
+
113
def contains_context_tags(text: str) -> bool:
    """Report whether ``text`` holds a ``<context key=...>...</context>`` block.

    The block may span several lines; an opening tag without a ``key``
    attribute or without a closing tag does not count.
    """
    tag_pattern = r"<context\s+key=[^>]+>.*?</context>"
    # DOTALL lets `.` cross newlines so multi-line contexts are detected.
    return re.search(tag_pattern, text, flags=re.DOTALL) is not None
116
+
117
+
118
def merge_citations(
    sentences: list[str], citations: list[tuple[Citation, Match | None]]
) -> list[ContentType]:
    """Attach each citation to the sentence it refers to.

    Args:
        sentences: The sentences of the final message, in order.
        citations: Pairs of a citation and the match found for it in the
            conversation, or ``None`` when no match was found.

    Returns:
        One text content block per sentence. ``"citations"`` holds the
        citations whose ``sentence_index`` equals the sentence position, in
        their original order, or ``None`` when the sentence has none.
    """
    blocks: list[ContentType] = []
    for index, sentence in enumerate(sentences):
        attached: list[CitationType] = [
            {
                # Unmatched citations keep the generated text but carry no
                # verified quote and no distance.
                "cited_text": None if match is None else match["matched"],
                "generated_cited_text": citation.cited_text,
                "key": citation.key,
                "dist": None if match is None else match["dist"],
            }
            for citation, match in citations
            if citation.sentence_index == index
        ]
        blocks.append({"text": sentence, "citations": attached or None, "type": "text"})

    return blocks
150
+
151
+
152
def validate_citations(
    citations: Citations,
    messages: Sequence[BaseMessage],
    sentences: list[str],
) -> list[tuple[Citation, Match | None]]:
    """Check the generated citations against the conversation text.

    Citations that point to a non-existing sentence are dropped. Every other
    citation is kept and paired with the closest fuzzy match of its
    ``cited_text`` in the string contents of ``messages``, or with ``None``
    when the text is empty or cannot be found.

    Args:
        citations: The citations produced by the citation model.
        messages: The conversation to search; messages whose content is not a
            plain string are ignored.
        sentences: The sentences of the final message; only their count is
            used, to validate ``sentence_index``.

    Returns:
        The surviving citations, in their original order, each paired with
        its match or ``None``.
    """
    n_sentences = len(sentences)

    all_text = "\n".join(
        msg.content for msg in messages if isinstance(msg.content, str)
    )

    citations_with_matches: list[tuple[Citation, Match | None]] = []
    for citation in citations.values:
        if not 0 <= citation.sentence_index < n_sentences:
            # discard citations that refer to non-existing sentences
            continue

        if not citation.cited_text:
            # The fuzzy search rejects an empty pattern with a ValueError, so
            # an empty quote is recorded as unmatched instead of crashing.
            citations_with_matches.append((citation, None))
            continue

        # Allow for 10% error distance
        max_l_dist = max(1, len(citation.cited_text) // 10)
        matches = find_near_matches(
            citation.cited_text, all_text, max_l_dist=max_l_dist
        )
        if not matches:
            citations_with_matches.append((citation, None))
            continue

        # Prefer the closest match over the first one in text order; `min`
        # keeps the earliest match when several share the same distance.
        best = min(matches, key=lambda candidate: candidate.dist)
        citations_with_matches.append(
            (
                citation,
                Match(
                    start=best.start,
                    end=best.end,
                    dist=best.dist,
                    matched=best.matched,
                ),
            )
        )
    return citations_with_matches
190
+
191
+
192
async def add_citations(
    model: BaseChatModel,
    messages: Sequence[BaseMessage],
    message: AIMessage,
    system_prompt: str,
    **kwargs: Any,
) -> AIMessage:
    """Annotate ``message`` with citations extracted by ``model``.

    The message is returned untouched when it has no content (for example a
    tool call) or when the conversation contains no context tags. Otherwise
    its string content is replaced, in place, by a list of text blocks with
    their citations.

    Args:
        model: The chat model used to extract the citations.
        messages: The conversation that led to ``message``.
        message: The answer to annotate; it is modified in place.
        system_prompt: The instructions given to the citation model.
        **kwargs: Forwarded to the structured-output invocation.

    Returns:
        The same ``message`` object.
    """
    answer = message.content
    if not answer:
        # Nothing to be done, for example in case of a tool call
        return message

    assert isinstance(
        answer, str
    ), "Citation agent currently only supports string content."

    conversation_text = "\n".join(str(msg.content) for msg in messages)
    if not contains_context_tags(conversation_text):
        # No context tags, nothing to do
        return message

    sentences = split_into_sentences(answer)

    # Right-align the indices so the numbered list lines up for the model.
    index_width = len(str(len(sentences)))
    numbered_lines = [
        f"{str(index).rjust(index_width)}: {sentence.strip()}"
        for index, sentence in enumerate(sentences)
    ]
    numbered_message = AIMessage(content="\n".join(numbered_lines), name=message.name)
    prompt_messages = [SystemMessage(system_prompt), *messages, numbered_message]

    citation_extractor = model.with_structured_output(Citations)
    raw_citations = await citation_extractor.ainvoke(prompt_messages, **kwargs)
    assert isinstance(
        raw_citations, Citations
    ), f"Expected Citations from model invocation but got {type(raw_citations)}"
    checked_citations = validate_citations(raw_citations, messages, sentences)

    message.content = merge_citations(sentences, checked_citations)  # type: ignore[assignment]
    return message
235
+
236
+
237
def create_citation_model(
    model: BaseChatModel,
    citation_model: BaseChatModel | None = None,
    system_prompt: str | None = None,
) -> Runnable[Sequence[BaseMessage], AIMessage]:
    """Wrap a chat model so that its answers come back with citations.

    Any contexts to be cited should be provided in the messages as XML tags,
    e.g. `<context key="abc">Today is a sunny day</context>`.

    When citations are added, the content of the returned AIMessage is a list
    with one text block per sentence:

        AIMessage(
            content=[
                {
                    "citations": [
                        {
                            "cited_text": "the color of the grass is green",
                            "generated_cited_text": "the color of the grass is green",
                            "key": "abc",
                            "dist": 0,
                        }
                    ],
                    "text": "The grass is green",
                    "type": "text",
                },
            ],
        )

    Args:
        model: The base chat model to wrap.
        citation_model: The model to use for extracting citations.
            If None, the base model is used.
        system_prompt: The system prompt to use for the citation model.
            If None, a default prompt is used.

    Returns:
        A runnable that only supports asynchronous invocation.
    """
    citation_model = citation_model or model
    system_prompt = system_prompt or SYSTEM_PROMPT

    async def ainvoke_with_citations(
        messages: Sequence[BaseMessage],
    ) -> AIMessage:
        """Generate an answer with the base model, then add its citations."""
        answer = await model.ainvoke(messages)
        assert isinstance(
            answer, AIMessage
        ), f"Expected AIMessage from model invocation but got {type(answer)}"
        return await add_citations(citation_model, messages, answer, system_prompt)

    # TODO: Implement a sync version if needed
    return RunnableCallable(func=None, afunc=ainvoke_with_citations)
285
+
286
+
287
class CitationMixin(BaseChatModel):
    """Mixin class to add citation functionality to a runnable.

    Example usage:
    ```
    from langchain_b12.genai.genai import ChatGenAI
    from langchain_b12.citations.citations import CitationMixin

    class CitationModel(ChatGenAI, CitationMixin):
        pass
    ```
    """

    async def agenerate(
        self,
        messages: list[list[BaseMessage]],
        stop: list[str] | None = None,
        callbacks: Callbacks = None,
        *,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        run_name: str | None = None,
        run_id: UUID | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate as the parent class does, then add citations to each answer.

        Accepts one extra keyword argument, ``generate_citations`` (default
        ``True``). It is removed from ``kwargs`` before the parent is called;
        when false, the parent's result is returned unchanged.
        """
        # Check if we should generate citations and remove it from kwargs
        generate_citations = kwargs.pop("generate_citations", True)

        llm_result = await super().agenerate(
            messages,
            stop,
            callbacks,
            tags=tags,
            metadata=metadata,
            run_name=run_name,
            run_id=run_id,
            **kwargs,
        )

        # Prevent recursion when extracting citations
        if not generate_citations:
            # Below we call `add_citations`, which will call `agenerate` again.
            # This would lead to an infinite loop if we did not stop here.
            # We explicitly pass `generate_citations=False` below to stop this
            # recursion.
            return llm_result

        # overwrite each generation with a version that has citations added
        for prompt_messages, candidates in zip(messages, llm_result.generations):
            for candidate in candidates:
                assert isinstance(candidate, ChatGeneration) and not isinstance(
                    candidate, ChatGenerationChunk
                ), f"Expected ChatGeneration; received {type(candidate)}"
                assert isinstance(
                    candidate.message, AIMessage
                ), f"Expected AIMessage; received {type(candidate.message)}"
                candidate.message = await add_citations(
                    self,
                    prompt_messages,
                    candidate.message,
                    SYSTEM_PROMPT,
                    generate_citations=False,
                )

        return llm_result
@@ -34,7 +34,7 @@ def multi_content_to_part(
34
34
  "url": f"data:{mime_type};base64,{encoded_artifact}"
35
35
  },
36
36
  },
37
- { # Image content fro base64 encoded string with LangChain format
37
+ { # Image content from base64 encoded string with LangChain format
38
38
  "type": "image",
39
39
  "source_type": "base64",
40
40
  "data": "<base64 string>",