langchain-b12 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,13 @@
1
1
  import re
2
2
  from collections.abc import Sequence
3
- from typing import Literal, TypedDict
3
+ from typing import Any, Literal, TypedDict
4
+ from uuid import UUID
4
5
 
5
6
  from fuzzysearch import find_near_matches
7
+ from langchain_core.callbacks import Callbacks
6
8
  from langchain_core.language_models import BaseChatModel
7
9
  from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
10
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
8
11
  from langchain_core.runnables import Runnable
9
12
  from langgraph.utils.runnable import RunnableCallable
10
13
  from pydantic import BaseModel, Field
@@ -163,7 +166,11 @@ def validate_citations(
163
166
  if citation.sentence_index < 0 or citation.sentence_index >= n_sentences:
164
167
  # discard citations that refer to non-existing sentences
165
168
  continue
166
- matches = find_near_matches(citation.cited_text, all_text, max_l_dist=5)
169
+ # Allow for 10% error distance
170
+ max_l_dist = max(1, len(citation.cited_text) // 10)
171
+ matches = find_near_matches(
172
+ citation.cited_text, all_text, max_l_dist=max_l_dist
173
+ )
167
174
  if not matches:
168
175
  citations_with_matches.append((citation, None))
169
176
  else:
@@ -187,6 +194,7 @@ async def add_citations(
187
194
  messages: Sequence[BaseMessage],
188
195
  message: AIMessage,
189
196
  system_prompt: str,
197
+ **kwargs: Any,
190
198
  ) -> AIMessage:
191
199
  """Add citations to the message."""
192
200
  if not message.content:
@@ -214,7 +222,9 @@ async def add_citations(
214
222
  system_message = SystemMessage(system_prompt)
215
223
  _messages = [system_message, *messages, numbered_message]
216
224
 
217
- citations = await model.with_structured_output(Citations).ainvoke(_messages)
225
+ citations = await model.with_structured_output(Citations).ainvoke(
226
+ _messages, **kwargs
227
+ )
218
228
  assert isinstance(
219
229
  citations, Citations
220
230
  ), f"Expected Citations from model invocation but got {type(citations)}"
@@ -272,3 +282,70 @@ def create_citation_model(
272
282
  func=None, # TODO: Implement a sync version if needed
273
283
  afunc=ainvoke_with_citations,
274
284
  )
285
+
286
+
287
+ class CitationMixin(BaseChatModel):
288
+ """Mixin class to add citation functionality to a runnable.
289
+
290
+ Example usage:
291
+ ```
292
+ from langchain_b12.genai.genai import ChatGenAI
293
+ from langchain_b12.citations.citations import CitationMixin
294
+
295
+ class CitationModel(ChatGenAI, CitationMixin):
296
+ pass
297
+ ```
298
+ """
299
+
300
+ async def agenerate(
301
+ self,
302
+ messages: list[list[BaseMessage]],
303
+ stop: list[str] | None = None,
304
+ callbacks: Callbacks = None,
305
+ *,
306
+ tags: list[str] | None = None,
307
+ metadata: dict[str, Any] | None = None,
308
+ run_name: str | None = None,
309
+ run_id: UUID | None = None,
310
+ **kwargs: Any,
311
+ ) -> LLMResult:
312
+ # Check if we should generate citations and remove it from kwargs
313
+ generate_citations = kwargs.pop("generate_citations", True)
314
+
315
+ llm_result = await super().agenerate(
316
+ messages,
317
+ stop,
318
+ callbacks,
319
+ tags=tags,
320
+ metadata=metadata,
321
+ run_name=run_name,
322
+ run_id=run_id,
323
+ **kwargs,
324
+ )
325
+
326
+ # Prevent recursion when extracting citations
327
+ if not generate_citations:
328
+ # Below we are call `add_citations` which will call `agenerate` again
329
+ # This will lead to an infinite loop if we don't stop here.
330
+ # We explicitly pass `generate_citations=False` below to sto this recursion.
331
+ return llm_result
332
+
333
+ # overwrite each generation with a version that has citations added
334
+ for _messages, generations in zip(messages, llm_result.generations):
335
+ for generation in generations:
336
+ assert isinstance(generation, ChatGeneration) and not isinstance(
337
+ generation, ChatGenerationChunk
338
+ ), f"Expected ChatGeneration; received {type(generation)}"
339
+ assert isinstance(
340
+ generation.message, AIMessage
341
+ ), f"Expected AIMessage; received {type(generation.message)}"
342
+ message_with_citations = await add_citations(
343
+ self,
344
+ _messages,
345
+ generation.message,
346
+ SYSTEM_PROMPT,
347
+ generate_citations=False,
348
+ )
349
+ generation.message = message_with_citations
350
+
351
+ return llm_result
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-b12
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: A reusable collection of tools and implementations for Langchain
5
5
  Author-email: Vincent Min <vincent.min@b12-consulting.com>
6
6
  Requires-Python: >=3.11
@@ -1,9 +1,9 @@
1
1
  langchain_b12/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  langchain_b12/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- langchain_b12/citations/citations.py,sha256=6HYKjyp6MaAWiLWZp-azQ5mM-drgt-Xytgarl7YwxhM,9321
3
+ langchain_b12/citations/citations.py,sha256=FO9ErybQws082JvV-MtTj81fVzdUWxrALcS81ElRsMw,12023
4
4
  langchain_b12/genai/embeddings.py,sha256=od2bVIgt7v9aNAHG0PVypVF1H_XgHto2nTd8vwfvyN8,3355
5
5
  langchain_b12/genai/genai.py,sha256=gzkgtvs3wNjcslS_KFZYCajUZIsJkVN2Tq2Q1RMIPyc,15910
6
6
  langchain_b12/genai/genai_utils.py,sha256=tA6UiJURK25-11vtaX4768UV47jDCYwVKIIWydD4Egw,10736
7
- langchain_b12-0.1.3.dist-info/METADATA,sha256=gvKeYszVVVT37bI2RN8T3vOIafWAn48Pe9KTaDUeNd4,1204
8
- langchain_b12-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
- langchain_b12-0.1.3.dist-info/RECORD,,
7
+ langchain_b12-0.1.4.dist-info/METADATA,sha256=x659l7J9-4XSfjgZgGvR-cVoiCtTqq7cIaujV9JsTrE,1204
8
+ langchain_b12-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
+ langchain_b12-0.1.4.dist-info/RECORD,,