langchain-b12 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_b12/citations/citations.py +101 -16
- langchain_b12/genai/genai.py +16 -0
- {langchain_b12-0.1.3.dist-info → langchain_b12-0.1.5.dist-info}/METADATA +1 -1
- langchain_b12-0.1.5.dist-info/RECORD +9 -0
- langchain_b12-0.1.3.dist-info/RECORD +0 -9
- {langchain_b12-0.1.3.dist-info → langchain_b12-0.1.5.dist-info}/WHEEL +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import re
|
|
2
2
|
from collections.abc import Sequence
|
|
3
|
-
from typing import Literal, TypedDict
|
|
3
|
+
from typing import Any, Literal, TypedDict
|
|
4
|
+
from uuid import UUID
|
|
4
5
|
|
|
5
|
-
from
|
|
6
|
+
from langchain_core.callbacks import Callbacks
|
|
6
7
|
from langchain_core.language_models import BaseChatModel
|
|
7
8
|
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
|
|
9
|
+
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
|
|
8
10
|
from langchain_core.runnables import Runnable
|
|
9
11
|
from langgraph.utils.runnable import RunnableCallable
|
|
10
12
|
from pydantic import BaseModel, Field
|
|
@@ -152,6 +154,8 @@ def validate_citations(
|
|
|
152
154
|
sentences: list[str],
|
|
153
155
|
) -> list[tuple[Citation, Match | None]]:
|
|
154
156
|
"""Validate the citations. Invalid citations are dropped."""
|
|
157
|
+
from fuzzysearch import find_near_matches
|
|
158
|
+
|
|
155
159
|
n_sentences = len(sentences)
|
|
156
160
|
|
|
157
161
|
all_text = "\n".join(
|
|
@@ -163,7 +167,11 @@ def validate_citations(
|
|
|
163
167
|
if citation.sentence_index < 0 or citation.sentence_index >= n_sentences:
|
|
164
168
|
# discard citations that refer to non-existing sentences
|
|
165
169
|
continue
|
|
166
|
-
|
|
170
|
+
# Allow for 10% error distance
|
|
171
|
+
max_l_dist = max(1, len(citation.cited_text) // 10)
|
|
172
|
+
matches = find_near_matches(
|
|
173
|
+
citation.cited_text, all_text, max_l_dist=max_l_dist
|
|
174
|
+
)
|
|
167
175
|
if not matches:
|
|
168
176
|
citations_with_matches.append((citation, None))
|
|
169
177
|
else:
|
|
@@ -187,6 +195,7 @@ async def add_citations(
|
|
|
187
195
|
messages: Sequence[BaseMessage],
|
|
188
196
|
message: AIMessage,
|
|
189
197
|
system_prompt: str,
|
|
198
|
+
**kwargs: Any,
|
|
190
199
|
) -> AIMessage:
|
|
191
200
|
"""Add citations to the message."""
|
|
192
201
|
if not message.content:
|
|
@@ -214,7 +223,9 @@ async def add_citations(
|
|
|
214
223
|
system_message = SystemMessage(system_prompt)
|
|
215
224
|
_messages = [system_message, *messages, numbered_message]
|
|
216
225
|
|
|
217
|
-
citations = await model.with_structured_output(Citations).ainvoke(
|
|
226
|
+
citations = await model.with_structured_output(Citations).ainvoke(
|
|
227
|
+
_messages, **kwargs
|
|
228
|
+
)
|
|
218
229
|
assert isinstance(
|
|
219
230
|
citations, Citations
|
|
220
231
|
), f"Expected Citations from model invocation but got {type(citations)}"
|
|
@@ -234,18 +245,25 @@ def create_citation_model(
|
|
|
234
245
|
e.g. `<context key="abc">Today is a sunny day</context>`.
|
|
235
246
|
The returned AIMessage will have the following structure:
|
|
236
247
|
AIMessage(
|
|
237
|
-
content=
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
248
|
+
content=[
|
|
249
|
+
{
|
|
250
|
+
"citations": [
|
|
251
|
+
{
|
|
252
|
+
"cited_text": "the color of the grass is green",
|
|
253
|
+
"generated_cited_text": "the color of the grass is green",
|
|
254
|
+
"key": "abc",
|
|
255
|
+
"dist": 0,
|
|
256
|
+
}
|
|
257
|
+
],
|
|
258
|
+
"text": "The grass is green",
|
|
259
|
+
"type": "text",
|
|
260
|
+
},
|
|
261
|
+
{
|
|
262
|
+
"citations": None,
|
|
263
|
+
"text": "Is there anything else I can help you with?",
|
|
264
|
+
"type": "text",
|
|
265
|
+
}
|
|
266
|
+
]
|
|
249
267
|
)
|
|
250
268
|
|
|
251
269
|
Args:
|
|
@@ -272,3 +290,70 @@ def create_citation_model(
|
|
|
272
290
|
func=None, # TODO: Implement a sync version if needed
|
|
273
291
|
afunc=ainvoke_with_citations,
|
|
274
292
|
)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class CitationMixin(BaseChatModel):
|
|
296
|
+
"""Mixin class to add citation functionality to a runnable.
|
|
297
|
+
|
|
298
|
+
Example usage:
|
|
299
|
+
```
|
|
300
|
+
from langchain_b12.genai.genai import ChatGenAI
|
|
301
|
+
from langchain_b12.citations.citations import CitationMixin
|
|
302
|
+
|
|
303
|
+
class CitationModel(ChatGenAI, CitationMixin):
|
|
304
|
+
pass
|
|
305
|
+
```
|
|
306
|
+
"""
|
|
307
|
+
|
|
308
|
+
async def agenerate(
|
|
309
|
+
self,
|
|
310
|
+
messages: list[list[BaseMessage]],
|
|
311
|
+
stop: list[str] | None = None,
|
|
312
|
+
callbacks: Callbacks = None,
|
|
313
|
+
*,
|
|
314
|
+
tags: list[str] | None = None,
|
|
315
|
+
metadata: dict[str, Any] | None = None,
|
|
316
|
+
run_name: str | None = None,
|
|
317
|
+
run_id: UUID | None = None,
|
|
318
|
+
**kwargs: Any,
|
|
319
|
+
) -> LLMResult:
|
|
320
|
+
# Check if we should generate citations and remove it from kwargs
|
|
321
|
+
generate_citations = kwargs.pop("generate_citations", True)
|
|
322
|
+
|
|
323
|
+
llm_result = await super().agenerate(
|
|
324
|
+
messages,
|
|
325
|
+
stop,
|
|
326
|
+
callbacks,
|
|
327
|
+
tags=tags,
|
|
328
|
+
metadata=metadata,
|
|
329
|
+
run_name=run_name,
|
|
330
|
+
run_id=run_id,
|
|
331
|
+
**kwargs,
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
# Prevent recursion when extracting citations
|
|
335
|
+
if not generate_citations:
|
|
336
|
+
# Below we call `add_citations`, which will call `agenerate` again.
|
|
337
|
+
# This will lead to an infinite loop if we don't stop here.
|
|
338
|
+
# We explicitly pass `generate_citations=False` below to stop this recursion.
|
|
339
|
+
return llm_result
|
|
340
|
+
|
|
341
|
+
# overwrite each generation with a version that has citations added
|
|
342
|
+
for _messages, generations in zip(messages, llm_result.generations):
|
|
343
|
+
for generation in generations:
|
|
344
|
+
assert isinstance(generation, ChatGeneration) and not isinstance(
|
|
345
|
+
generation, ChatGenerationChunk
|
|
346
|
+
), f"Expected ChatGeneration; received {type(generation)}"
|
|
347
|
+
assert isinstance(
|
|
348
|
+
generation.message, AIMessage
|
|
349
|
+
), f"Expected AIMessage; received {type(generation.message)}"
|
|
350
|
+
message_with_citations = await add_citations(
|
|
351
|
+
self,
|
|
352
|
+
_messages,
|
|
353
|
+
generation.message,
|
|
354
|
+
SYSTEM_PROMPT,
|
|
355
|
+
generate_citations=False,
|
|
356
|
+
)
|
|
357
|
+
generation.message = message_with_citations
|
|
358
|
+
|
|
359
|
+
return llm_result
|
langchain_b12/genai/genai.py
CHANGED
|
@@ -90,6 +90,8 @@ class ChatGenAI(BaseChatModel):
|
|
|
90
90
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
|
91
91
|
}
|
|
92
92
|
""" # noqa: E501
|
|
93
|
+
thinking_config: types.ThinkingConfig | None = None
|
|
94
|
+
"The thinking configuration to use for the model."
|
|
93
95
|
|
|
94
96
|
model_config = ConfigDict(
|
|
95
97
|
arbitrary_types_allowed=True,
|
|
@@ -208,6 +210,10 @@ class ChatGenAI(BaseChatModel):
|
|
|
208
210
|
candidate_count=self.n,
|
|
209
211
|
stop_sequences=stop or self.stop,
|
|
210
212
|
safety_settings=self.safety_settings,
|
|
213
|
+
thinking_config=self.thinking_config,
|
|
214
|
+
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
|
215
|
+
disable=True,
|
|
216
|
+
),
|
|
211
217
|
**kwargs,
|
|
212
218
|
),
|
|
213
219
|
)
|
|
@@ -240,6 +246,10 @@ class ChatGenAI(BaseChatModel):
|
|
|
240
246
|
candidate_count=self.n,
|
|
241
247
|
stop_sequences=stop or self.stop,
|
|
242
248
|
safety_settings=self.safety_settings,
|
|
249
|
+
thinking_config=self.thinking_config,
|
|
250
|
+
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
|
251
|
+
disable=True,
|
|
252
|
+
),
|
|
243
253
|
**kwargs,
|
|
244
254
|
),
|
|
245
255
|
)
|
|
@@ -362,6 +372,12 @@ class ChatGenAI(BaseChatModel):
|
|
|
362
372
|
input_tokens=usage_metadata.prompt_token_count or 0,
|
|
363
373
|
output_tokens=usage_metadata.candidates_token_count or 0,
|
|
364
374
|
total_tokens=usage_metadata.total_token_count or 0,
|
|
375
|
+
input_token_details={
|
|
376
|
+
"cache_read": usage_metadata.cached_content_token_count or 0
|
|
377
|
+
},
|
|
378
|
+
output_token_details={
|
|
379
|
+
"reasoning": usage_metadata.thoughts_token_count or 0
|
|
380
|
+
},
|
|
365
381
|
)
|
|
366
382
|
|
|
367
383
|
total_lc_usage: UsageMetadata | None = (
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
langchain_b12/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
langchain_b12/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
langchain_b12/citations/citations.py,sha256=ZQvYayjQXIUaRosJ0qwL3Nc7kC8sBzmaIkE-BOslaVI,12261
|
|
4
|
+
langchain_b12/genai/embeddings.py,sha256=od2bVIgt7v9aNAHG0PVypVF1H_XgHto2nTd8vwfvyN8,3355
|
|
5
|
+
langchain_b12/genai/genai.py,sha256=7X7nDt76Icc5woV5b7FX_uza9YgFpFp1_PcYtXPriqE,16667
|
|
6
|
+
langchain_b12/genai/genai_utils.py,sha256=tA6UiJURK25-11vtaX4768UV47jDCYwVKIIWydD4Egw,10736
|
|
7
|
+
langchain_b12-0.1.5.dist-info/METADATA,sha256=unv3NxdFU_VrlPmIuTmDB2dHRi9go44B-q83kQgLUqI,1204
|
|
8
|
+
langchain_b12-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
+
langchain_b12-0.1.5.dist-info/RECORD,,
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
langchain_b12/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
langchain_b12/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
langchain_b12/citations/citations.py,sha256=6HYKjyp6MaAWiLWZp-azQ5mM-drgt-Xytgarl7YwxhM,9321
|
|
4
|
-
langchain_b12/genai/embeddings.py,sha256=od2bVIgt7v9aNAHG0PVypVF1H_XgHto2nTd8vwfvyN8,3355
|
|
5
|
-
langchain_b12/genai/genai.py,sha256=gzkgtvs3wNjcslS_KFZYCajUZIsJkVN2Tq2Q1RMIPyc,15910
|
|
6
|
-
langchain_b12/genai/genai_utils.py,sha256=tA6UiJURK25-11vtaX4768UV47jDCYwVKIIWydD4Egw,10736
|
|
7
|
-
langchain_b12-0.1.3.dist-info/METADATA,sha256=gvKeYszVVVT37bI2RN8T3vOIafWAn48Pe9KTaDUeNd4,1204
|
|
8
|
-
langchain_b12-0.1.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
-
langchain_b12-0.1.3.dist-info/RECORD,,
|
|
File without changes
|