langchain-b12 0.1.1__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain-b12
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: A reusable collection of tools and implementations for Langchain
5
5
  Author-email: Vincent Min <vincent.min@b12-consulting.com>
6
6
  Requires-Python: >=3.11
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "langchain-b12"
3
- version = "0.1.1"
3
+ version = "0.1.3"
4
4
  description = "A reusable collection of tools and implementations for Langchain"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -17,6 +17,11 @@ google = [
17
17
  ]
18
18
  dev = [
19
19
  "pytest>=8.3.5",
20
+ "pytest-asyncio>=1.0.0",
21
+ ]
22
+ citations = [
23
+ "fuzzysearch>=0.8.0",
24
+ "langgraph>=0.4.7",
20
25
  ]
21
26
 
22
27
  [build-system]
@@ -59,3 +64,8 @@ reportUnknownParameterType = false
59
64
  reportUnknownMemberType = false
60
65
  reportUnknownArgumentType = false
61
66
 
67
+ # Add pytest configuration
68
+ [tool.pytest.ini_options]
69
+ asyncio_default_fixture_loop_scope = "function"
70
+ asyncio_mode = "auto"
71
+
@@ -0,0 +1,274 @@
1
+ import re
2
+ from collections.abc import Sequence
3
+ from typing import Literal, TypedDict
4
+
5
+ from fuzzysearch import find_near_matches
6
+ from langchain_core.language_models import BaseChatModel
7
+ from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
8
+ from langchain_core.runnables import Runnable
9
+ from langgraph.utils.runnable import RunnableCallable
10
+ from pydantic import BaseModel, Field
11
+
12
# Default system prompt for the citation model. The example output at the end
# must be valid JSON (no trailing comma) since the model is asked to mimic it.
SYSTEM_PROMPT = """
You are an expert at identifying and adding citations to text.
Your task is to identify, for each sentence in the final message, which citations were used to generate it.

You will receive a numbered zero-indexed list of sentences in the final message, e.g.
```
0: Grass is green.
1: The sky is blue and the sun is shining.
```
The rest of the conversation may contain contexts enclosed in xml tags, e.g.
```
<context key="abc">
Today is a sunny day and the color of the grass is green.
</context>
```
Each sentence may have zero, one, or multiple citations from the contexts.
Each citation may be used for zero, one or multiple sentences.
A context may be cited zero, one, or multiple times.

The final message will be based on the contexts, but may not mention them explicitly.
You must identify which contexts and which parts of the contexts were used to generate each sentence.
For each such case, you must return a citation with a "sentence_index", "cited_text" and "key" property.
The "sentence_index" is the index of the sentence in the final message.
The "cited_text" must be a substring of the full context that was used to generate the sentence.
The "key" must be the key of the context that was used to generate the sentence.
Make sure that you copy the "cited_text" verbatim from the context, or it will not be considered valid.

For the example above, the output should look like this:
[
    {
        "sentence_index": 0,
        "cited_text": "the color of the grass is green",
        "key": "abc"
    },
    {
        "sentence_index": 1,
        "cited_text": "Today is a sunny day",
        "key": "abc"
    }
]
""".strip()  # noqa: E501
53
+
54
+
55
# Plain-dict copy of one fuzzy-search hit: start/end offsets into the
# searched text, the match distance, and the matched substring itself.
Match = TypedDict(
    "Match",
    {
        "start": int,
        "end": int,
        "dist": int,
        "matched": str,
    },
)
60
+
61
+
62
+ class CitationType(TypedDict):
63
+
64
+ cited_text: str | None
65
+ generated_cited_text: str
66
+ key: str
67
+ dist: int | None
68
+
69
+
70
# One text content block of a cited AIMessage: a sentence plus the citations
# pointing at it (None when the sentence has no citations).
ContentType = TypedDict(
    "ContentType",
    {
        "citations": list[CitationType] | None,
        "text": str,
        "type": Literal["text"],
    },
)
75
+
76
+
77
class Citation(BaseModel):
    # A single citation returned by the citation model. Deliberately no class
    # docstring: pydantic would put it into the JSON schema that is used for
    # structured output.

    # Zero-based index into the numbered sentence list shown to the model.
    sentence_index: int = Field(
        description=(
            "The index of the sentence from your answer that "
            "this citation refers to."
        ),
    )
    # Quote taken from a context; later checked by fuzzy matching.
    cited_text: str = Field(
        description=(
            "The text that is cited from the document. Make sure you "
            "cite it verbatim!"
        ),
    )
    # The `key` attribute of the cited <context> tag.
    key: str = Field(description="The key of the document you are citing.")
90
+
91
+
92
class Citations(BaseModel):
    # Top-level structured-output schema requested from the citation model.
    # Deliberately no class docstring: pydantic would add it to the schema.

    # All citations, in the order the model produced them.
    values: list[Citation] = Field(description="List of citations")
95
+
96
+
97
def split_into_sentences(text: str) -> list[str]:
    """Split ``text`` into sentence-like pieces without losing any characters.

    A split happens right after ``.``, ``!`` or ``?`` when spaces follow
    (the spaces stay at the start of the next piece), and around every run
    of newlines (the run becomes its own piece). Joining the result gives
    back ``text``. Empty input yields ``[text]``.
    """
    if not text:
        return [text]

    # One capturing group: either a zero-width point after terminal
    # punctuation that is followed by spaces, or a run of newlines.
    delimiter = r"((?<=[.!?])(?= +)|\n+)"
    # Zero-width splits leave empty strings behind; drop them.
    return list(filter(None, re.split(delimiter, text)))
108
+
109
+
110
def contains_context_tags(text: str) -> bool:
    """Return True if ``text`` has at least one ``<context key=...>...</context>`` pair.

    The body between the tags may span multiple lines.
    """
    found = re.search(
        r"<context\s+key=[^>]+>.*?</context>",
        text,
        flags=re.DOTALL,
    )
    return found is not None
113
+
114
+
115
def merge_citations(
    sentences: list[str], citations: list[tuple[Citation, Match | None]]
) -> list[ContentType]:
    """Build one text content block per sentence, carrying its citations.

    Citations are attached to the sentence whose position equals their
    ``sentence_index``, keeping their input order; indices that match no
    sentence are ignored. A citation without a match gets ``None`` for
    ``cited_text`` and ``dist``. Sentences without citations get
    ``"citations": None``.
    """
    # Bucket the (citation, match) pairs by the sentence they point at.
    pairs_by_index: dict[int, list[tuple[Citation, Match | None]]] = {}
    for pair in citations:
        pairs_by_index.setdefault(pair[0].sentence_index, []).append(pair)

    merged: list[ContentType] = []
    for index, sentence in enumerate(sentences):
        entries: list[CitationType] = [
            {
                "cited_text": None if match is None else match["matched"],
                "generated_cited_text": citation.cited_text,
                "key": citation.key,
                "dist": None if match is None else match["dist"],
            }
            for citation, match in pairs_by_index.get(index, [])
        ]
        merged.append(
            {"text": sentence, "citations": entries or None, "type": "text"}
        )
    return merged
147
+
148
+
149
def _searchable_text(message: BaseMessage) -> str:
    """Return the plain text of ``message`` that citations are matched against.

    String content is returned as-is. For list content, string items and the
    ``"text"`` value of dict items are joined with newlines; any other items
    contribute nothing.
    """
    content = message.content
    if isinstance(content, str):
        return content
    pieces: list[str] = []
    for item in content:
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, dict) and isinstance(item.get("text"), str):
            pieces.append(item["text"])
    return "\n".join(pieces)


def _closest_match(cited_text: str, text: str, max_l_dist: int) -> Match | None:
    """Return the closest fuzzy occurrence of ``cited_text`` in ``text``, or None."""
    if not cited_text:
        # An empty quote cannot be located (and fuzzy search may reject an
        # empty pattern); treat it as unmatched instead of failing the call.
        return None
    # A distance >= len(cited_text) would allow deleting the whole pattern,
    # i.e. "matching" an empty span anywhere, so cap it below the length.
    allowed = max(0, min(max_l_dist, len(cited_text) - 1))
    matches = find_near_matches(cited_text, text, max_l_dist=allowed)
    if not matches:
        return None
    # Prefer the smallest distance; min() keeps the earliest hit on ties.
    best = min(matches, key=lambda m: m.dist)
    return Match(
        start=best.start,
        end=best.end,
        dist=best.dist,
        matched=best.matched,
    )


def validate_citations(
    citations: Citations,
    messages: Sequence[BaseMessage],
    sentences: list[str],
    *,
    max_l_dist: int = 5,
) -> list[tuple[Citation, Match | None]]:
    """Check model-generated citations against the conversation text.

    Citations whose ``sentence_index`` does not refer to an existing sentence
    are dropped. Every other citation is kept and paired with the closest
    fuzzy match of its ``cited_text`` in the newline-joined text of
    ``messages``, or with ``None`` when nothing is found within the allowed
    distance.

    Args:
        citations: Structured output of the citation model.
        messages: Conversation messages to search. Text inside list-form
            content (string items and ``"text"`` entries) is searched too.
        sentences: Sentences of the message being cited.
        max_l_dist: Maximum Levenshtein distance between ``cited_text`` and
            the matched text. Capped at ``len(cited_text) - 1`` per citation.

    Returns:
        ``(citation, match_or_None)`` pairs in the order of ``citations.values``.
    """
    n_sentences = len(sentences)
    all_text = "\n".join(_searchable_text(msg) for msg in messages)

    citations_with_matches: list[tuple[Citation, Match | None]] = []
    for citation in citations.values:
        if not 0 <= citation.sentence_index < n_sentences:
            # Discard citations that refer to non-existing sentences.
            continue
        match = _closest_match(citation.cited_text, all_text, max_l_dist)
        citations_with_matches.append((citation, match))
    return citations_with_matches
183
+
184
+
185
async def add_citations(
    model: BaseChatModel,
    messages: Sequence[BaseMessage],
    message: AIMessage,
    system_prompt: str,
) -> AIMessage:
    """Attach sentence-level citations to ``message``.

    The text of ``message`` is split into sentences and shown to ``model``
    as a zero-indexed numbered list, after ``system_prompt`` and
    ``messages``. The model's structured ``Citations`` output is validated
    and merged back into per-sentence content blocks.

    Args:
        model: Chat model used to extract citations (via structured output).
        messages: Conversation that produced ``message``; contexts to cite
            appear in it as ``<context key=...>...</context>`` tags.
        message: The answer to annotate.
        system_prompt: System prompt for the citation model.

    Returns:
        ``message`` itself. It is returned untouched when its content is
        empty or when ``messages`` contain no context tags. Otherwise its
        content is replaced in place by a list of per-sentence text blocks.

    Raises:
        AssertionError: If the content is not a string, or the model does
            not return a ``Citations`` instance.
    """
    # Empty content (for example a tool-call-only message): nothing to cite.
    if not message.content:
        return message

    answer = message.content
    assert isinstance(
        answer, str
    ), "Citation agent currently only supports string content."

    conversation = "\n".join(str(msg.content) for msg in messages)
    if not contains_context_tags(conversation):
        return message

    sentences = split_into_sentences(answer)
    width = len(str(len(sentences)))
    numbered_lines = [
        f"{idx:>{width}}: {sentence.strip()}"
        for idx, sentence in enumerate(sentences)
    ]
    numbered_message = AIMessage(content="\n".join(numbered_lines), name=message.name)

    prompt_messages = [SystemMessage(system_prompt), *messages, numbered_message]
    extractor = model.with_structured_output(Citations)
    raw_citations = await extractor.ainvoke(prompt_messages)
    assert isinstance(
        raw_citations, Citations
    ), f"Expected Citations from model invocation but got {type(raw_citations)}"

    checked = validate_citations(raw_citations, messages, sentences)
    message.content = merge_citations(sentences, checked)  # type: ignore[assignment]
    return message
225
+
226
+
227
def create_citation_model(
    model: BaseChatModel,
    citation_model: BaseChatModel | None = None,
    system_prompt: str | None = None,
) -> Runnable[Sequence[BaseMessage], AIMessage]:
    """Take a base chat model and wrap it such that it adds citations to the messages.

    Any contexts to be cited should be provided in the messages as XML tags,
    e.g. `<context key="abc">Today is a sunny day</context>`.

    If the answer has text and the input messages contain context tags, the
    returned AIMessage's content is a *list* with one text block per sentence
    of the answer, e.g.:

        AIMessage(
            content=[
                {
                    "citations": [
                        {
                            "cited_text": "the color of the grass is green",
                            "generated_cited_text": "the color of the grass is green",
                            "key": "abc",
                            "dist": 0,
                        }
                    ],
                    "text": "The grass is green",
                    "type": "text",
                },
                ...
            ],
        )

    Sentences without citations have ``"citations": None``. For a citation
    whose quote could not be found in the messages, ``"cited_text"`` and
    ``"dist"`` are None. If the answer has empty content (e.g. a tool call)
    or the messages contain no context tags, the model's message is returned
    unchanged.

    Only asynchronous invocation (``ainvoke`` and friends) is supported; no
    synchronous function is provided to the returned runnable.

    Args:
        model: The base chat model to wrap.
        citation_model: The model to use for extracting citations.
            If None, the base model is used.
        system_prompt: The system prompt to use for the citation model.
            If None (or empty), a default prompt is used.
    """
    extraction_model = model if citation_model is None else citation_model
    prompt = system_prompt or SYSTEM_PROMPT

    async def ainvoke_with_citations(
        messages: Sequence[BaseMessage],
    ) -> AIMessage:
        """Invoke the model and add citations to the AIMessage."""
        ai_message = await model.ainvoke(messages)
        assert isinstance(
            ai_message, AIMessage
        ), f"Expected AIMessage from model invocation but got {type(ai_message)}"
        return await add_citations(extraction_model, messages, ai_message, prompt)

    return RunnableCallable(
        func=None,  # TODO: Implement a sync version if needed
        afunc=ainvoke_with_citations,
    )
@@ -24,22 +24,37 @@ def multi_content_to_part(
24
24
  Args:
25
25
  contents: A sequence of dictionaries representing content. Examples:
26
26
  [
27
- {
27
+ { # Text content
28
28
  "type": "text",
29
29
  "text": "This is a text message"
30
30
  },
31
- {
31
+ { # Image content from base64 encoded string with OpenAI format
32
32
  "type": "image_url",
33
33
  "image_url": {
34
34
  "url": f"data:{mime_type};base64,{encoded_artifact}"
35
35
  },
36
36
  },
37
- {
37
+ { # Image content from base64 encoded string with LangChain format
38
+ "type": "image",
39
+ "source_type": "base64",
40
+ "data": "<base64 string>",
41
+ "mime_type": "image/jpeg",
42
+ },
43
+ { # Image content from URL
44
+ "type": "image",
45
+ "source_type": "url",
46
+ "url": "https://...",
47
+ },
48
+ { # File content from base64 encoded string
49
+ "type": "file",
50
+ "source_type": "base64",
51
+ "mime_type": "application/pdf",
52
+ "data": "<base64 data string>",
53
+ },
54
+ { # File content from URL
38
55
  "type": "file",
39
- "file": {
40
- "uri": f"gs://{bucket_name}/{file_name}",
41
- "mime_type": mime_type,
42
- }
56
+ "source_type": "url",
57
+ "url": "https://...",
43
58
  }
44
59
  ]
45
60
  """
@@ -60,15 +75,51 @@ def multi_content_to_part(
60
75
  mime_type = header.split(":", 1)[1].split(";", 1)[0]
61
76
  data = base64.b64decode(encoded_data)
62
77
  parts.append(types.Part.from_bytes(data=data, mime_type=mime_type))
78
+ elif content["type"] == "image":
79
+ if "data" in content:
80
+ assert isinstance(content["data"], str), "Expected str data"
81
+ assert "mime_type" in content, "Expected 'mime_type' in content"
82
+ assert isinstance(content["mime_type"], str), "Expected str mime_type"
83
+ data = base64.b64decode(content["data"])
84
+ parts.append(
85
+ types.Part.from_bytes(data=data, mime_type=content["mime_type"])
86
+ )
87
+ elif "url" in content:
88
+ assert isinstance(content["url"], str), "Expected str url"
89
+ mime_type = content.get("mime_type", None)
90
+ assert mime_type is None or isinstance(
91
+ mime_type, str
92
+ ), "Expected str mime_type"
93
+ parts.append(
94
+ types.Part.from_uri(file_uri=content["url"], mime_type=mime_type)
95
+ )
96
+ else:
97
+ raise ValueError(
98
+ "Expected either 'data' or 'url' in content for image type"
99
+ )
63
100
  elif content["type"] == "file":
64
- assert "file" in content, "Expected 'file' in content"
65
- file = content["file"]
66
- assert isinstance(file, dict), "Expected dict file"
67
- assert "uri" in file, "Expected 'uri' in content['file']"
68
- assert "mime_type" in file, "Expected 'mime_type' in content['file']"
69
- parts.append(
70
- types.Part.from_uri(file_uri=file["uri"], mime_type=file["mime_type"])
71
- )
101
+ if "data" in content:
102
+ assert isinstance(content["data"], str), "Expected str data"
103
+ assert "mime_type" in content, "Expected 'mime_type' in content"
104
+ assert isinstance(content["mime_type"], str), "Expected str mime_type"
105
+ data = base64.b64decode(content["data"])
106
+ parts.append(
107
+ types.Part.from_bytes(data=data, mime_type=content["mime_type"])
108
+ )
109
+ elif "url" in content:
110
+ assert isinstance(content["url"], str), "Expected str url"
111
+ assert content["url"], "File URI is required"
112
+ mime_type = content.get("mime_type", None)
113
+ assert mime_type is None or isinstance(
114
+ mime_type, str
115
+ ), "Expected str mime_type"
116
+ parts.append(
117
+ types.Part.from_uri(file_uri=content["url"], mime_type=mime_type)
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ "Expected either 'data' or 'url' in content for file type"
122
+ )
72
123
  else:
73
124
  raise ValueError(f"Unknown content type: {content['type']}")
74
125
  return parts