langchain-b12 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from typing import Literal, TypedDict
|
|
3
|
+
from typing import Any, Literal, TypedDict
|
|
4
|
+
from uuid import UUID
|
|
4
5
|
|
|
5
6
|
from fuzzysearch import find_near_matches
|
|
7
|
+
from langchain_core.callbacks import Callbacks
|
|
6
8
|
from langchain_core.language_models import BaseChatModel
|
|
7
9
|
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
|
10
|
+
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
|
|
8
11
|
from langchain_core.runnables import Runnable
|
|
9
12
|
from langgraph.utils.runnable import RunnableCallable
|
|
10
13
|
from pydantic import BaseModel, Field
|
|
@@ -163,7 +166,11 @@ def validate_citations(
|
|
|
163
166
|
if citation.sentence_index < 0 or citation.sentence_index >= n_sentences:
|
|
164
167
|
# discard citations that refer to non-existing sentences
|
|
165
168
|
continue
|
|
166
|
-
|
|
169
|
+
# Allow for 10% error distance
|
|
170
|
+
max_l_dist = max(1, len(citation.cited_text) // 10)
|
|
171
|
+
matches = find_near_matches(
|
|
172
|
+
citation.cited_text, all_text, max_l_dist=max_l_dist
|
|
173
|
+
)
|
|
167
174
|
if not matches:
|
|
168
175
|
citations_with_matches.append((citation, None))
|
|
169
176
|
else:
|
|
@@ -187,6 +194,7 @@ async def add_citations(
|
|
|
187
194
|
messages: Sequence[BaseMessage],
|
|
188
195
|
message: AIMessage,
|
|
189
196
|
system_prompt: str,
|
|
197
|
+
**kwargs: Any,
|
|
190
198
|
) -> AIMessage:
|
|
191
199
|
"""Add citations to the message."""
|
|
192
200
|
if not message.content:
|
|
@@ -214,7 +222,9 @@ async def add_citations(
|
|
|
214
222
|
system_message = SystemMessage(system_prompt)
|
|
215
223
|
_messages = [system_message, *messages, numbered_message]
|
|
216
224
|
|
|
217
|
-
citations = await model.with_structured_output(Citations).ainvoke(
|
|
225
|
+
citations = await model.with_structured_output(Citations).ainvoke(
|
|
226
|
+
_messages, **kwargs
|
|
227
|
+
)
|
|
218
228
|
assert isinstance(
|
|
219
229
|
citations, Citations
|
|
220
230
|
), f"Expected Citations from model invocation but got {type(citations)}"
|
|
@@ -272,3 +282,70 @@ def create_citation_model(
|
|
|
272
282
|
func=None, # TODO: Implement a sync version if needed
|
|
273
283
|
afunc=ainvoke_with_citations,
|
|
274
284
|
)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class CitationMixin(BaseChatModel):
|
|
288
|
+
"""Mixin class to add citation functionality to a runnable.
|
|
289
|
+
|
|
290
|
+
Example usage:
|
|
291
|
+
```
|
|
292
|
+
from langchain_b12.genai.genai import ChatGenAI
|
|
293
|
+
from langchain_b12.citations.citations import CitationMixin
|
|
294
|
+
|
|
295
|
+
class CitationModel(ChatGenAI, CitationMixin):
|
|
296
|
+
pass
|
|
297
|
+
```
|
|
298
|
+
"""
|
|
299
|
+
|
|
300
|
+
async def agenerate(
|
|
301
|
+
self,
|
|
302
|
+
messages: list[list[BaseMessage]],
|
|
303
|
+
stop: list[str] | None = None,
|
|
304
|
+
callbacks: Callbacks = None,
|
|
305
|
+
*,
|
|
306
|
+
tags: list[str] | None = None,
|
|
307
|
+
metadata: dict[str, Any] | None = None,
|
|
308
|
+
run_name: str | None = None,
|
|
309
|
+
run_id: UUID | None = None,
|
|
310
|
+
**kwargs: Any,
|
|
311
|
+
) -> LLMResult:
|
|
312
|
+
# Check whether citations should be generated, and remove the `generate_citations` flag from kwargs
|
|
313
|
+
generate_citations = kwargs.pop("generate_citations", True)
|
|
314
|
+
|
|
315
|
+
llm_result = await super().agenerate(
|
|
316
|
+
messages,
|
|
317
|
+
stop,
|
|
318
|
+
callbacks,
|
|
319
|
+
tags=tags,
|
|
320
|
+
metadata=metadata,
|
|
321
|
+
run_name=run_name,
|
|
322
|
+
run_id=run_id,
|
|
323
|
+
**kwargs,
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
# Prevent recursion when extracting citations
|
|
327
|
+
if not generate_citations:
|
|
328
|
+
# Below we call `add_citations`, which will call `agenerate` again.
|
|
329
|
+
# This will lead to an infinite loop if we don't stop here.
|
|
330
|
+
# We explicitly pass `generate_citations=False` below to stop this recursion.
|
|
331
|
+
return llm_result
|
|
332
|
+
|
|
333
|
+
# overwrite each generation with a version that has citations added
|
|
334
|
+
for _messages, generations in zip(messages, llm_result.generations):
|
|
335
|
+
for generation in generations:
|
|
336
|
+
assert isinstance(generation, ChatGeneration) and not isinstance(
|
|
337
|
+
generation, ChatGenerationChunk
|
|
338
|
+
), f"Expected ChatGeneration; received {type(generation)}"
|
|
339
|
+
assert isinstance(
|
|
340
|
+
generation.message, AIMessage
|
|
341
|
+
), f"Expected AIMessage; received {type(generation.message)}"
|
|
342
|
+
message_with_citations = await add_citations(
|
|
343
|
+
self,
|
|
344
|
+
_messages,
|
|
345
|
+
generation.message,
|
|
346
|
+
SYSTEM_PROMPT,
|
|
347
|
+
generate_citations=False,
|
|
348
|
+
)
|
|
349
|
+
generation.message = message_with_citations
|
|
350
|
+
|
|
351
|
+
return llm_result
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
langchain_b12/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
langchain_b12/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
langchain_b12/citations/citations.py,sha256=
|
|
3
|
+
langchain_b12/citations/citations.py,sha256=FO9ErybQws082JvV-MtTj81fVzdUWxrALcS81ElRsMw,12023
|
|
4
4
|
langchain_b12/genai/embeddings.py,sha256=od2bVIgt7v9aNAHG0PVypVF1H_XgHto2nTd8vwfvyN8,3355
|
|
5
5
|
langchain_b12/genai/genai.py,sha256=gzkgtvs3wNjcslS_KFZYCajUZIsJkVN2Tq2Q1RMIPyc,15910
|
|
6
6
|
langchain_b12/genai/genai_utils.py,sha256=tA6UiJURK25-11vtaX4768UV47jDCYwVKIIWydD4Egw,10736
|
|
7
|
-
langchain_b12-0.1.
|
|
8
|
-
langchain_b12-0.1.
|
|
9
|
-
langchain_b12-0.1.
|
|
7
|
+
langchain_b12-0.1.4.dist-info/METADATA,sha256=x659l7J9-4XSfjgZgGvR-cVoiCtTqq7cIaujV9JsTrE,1204
|
|
8
|
+
langchain_b12-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
+
langchain_b12-0.1.4.dist-info/RECORD,,
|
|
File without changes
|