chunksmith-agent 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chunksmith_agent/__init__.py +13 -0
- chunksmith_agent/agent.py +59 -0
- chunksmith_agent/element_retrieval.py +164 -0
- chunksmith_agent/index_builder.py +325 -0
- chunksmith_agent/index_context.py +310 -0
- chunksmith_agent/langchain_runtime.py +101 -0
- chunksmith_agent/models.py +60 -0
- chunksmith_agent/retrieval.py +80 -0
- chunksmith_agent/session.py +44 -0
- chunksmith_agent/settings.py +68 -0
- chunksmith_agent/tool_agent.py +264 -0
- chunksmith_agent-0.4.0.dist-info/METADATA +82 -0
- chunksmith_agent-0.4.0.dist-info/RECORD +15 -0
- chunksmith_agent-0.4.0.dist-info/WHEEL +5 -0
- chunksmith_agent-0.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""Shared outline traversal and context assembly for retrieval backends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from difflib import get_close_matches
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from chunksmith_agent.models import DocumentIndex
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def flatten_structure(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return every outline node in depth-first pre-order.

    Each node is emitted before its children; children are read from the
    node's ``"nodes"`` key and only followed when that value is a list.
    The returned list holds the original node dicts (no copies).
    """

    def _visit(level: list[dict[str, Any]]):
        for entry in level:
            yield entry
            kids = entry.get("nodes")
            if isinstance(kids, list):
                yield from _visit(kids)

    return list(_visit(nodes))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def normalize_node_id(raw: Any) -> str | None:
    """Canonicalize a node id: stringify and strip whitespace.

    ``None`` and values that are blank after stripping map to ``None``.
    """
    text = "" if raw is None else str(raw).strip()
    return text if text else None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def fallback_nodes_lexical(
    flat_nodes: list[dict[str, Any]],
    query: str,
    max_nodes: int,
) -> list[str]:
    """Rank outline nodes by word overlap with *query*, without an LLM.

    Query tokens are the pieces of the lower-cased query split on non-word
    characters, keeping those of length >= 3. For every token, a node scores:

    * 4 if the token is a substring of its title or summary, else
    * 2 if difflib finds a close match (ratio >= 0.72) among the 4+-letter
      words of its title/summary.

    Nodes with a positive score are returned best-first (stable sort, so
    document order breaks ties), de-duplicated, at most *max_nodes* ids.
    A non-positive *max_nodes* or a query without usable tokens gives ``[]``.
    """
    if max_nodes <= 0:
        # Nothing can be returned; guarding here keeps the loop below from
        # emitting an id before the cap is checked.
        return []
    q = (query or "").lower()
    q_tokens = [t for t in re.split(r"\W+", q) if len(t) >= 3]
    if not q_tokens:
        return []

    scored: list[tuple[int, str]] = []
    for node in flat_nodes:
        nid = normalize_node_id(node.get("node_id"))
        if not nid:
            continue
        # str() tolerates non-string titles/summaries (e.g. numbers from JSON).
        title = str(node.get("title") or "").lower()
        summary = str(node.get("summary") or "").lower()
        # Built once per node instead of once per query token.
        vocabulary = list(set(re.findall(r"[a-z]{4,}", f"{title} {summary}")))
        score = 0
        for t in q_tokens:
            if t in title or t in summary:
                score += 4
            elif get_close_matches(t, vocabulary, n=1, cutoff=0.72):
                score += 2
        if score > 0:
            scored.append((score, nid))

    scored.sort(key=lambda x: -x[0])
    seen: set[str] = set()
    result: list[str] = []
    for _, nid in scored:
        if nid in seen:
            continue
        seen.add(nid)
        result.append(nid)
        if len(result) >= max_nodes:
            break
    return result
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def outline_prompt_lines(index: DocumentIndex, *, limit: int = 128) -> str:
    """Render the outline as prompt text, one node per line.

    Each line is ``"<node_id>: <title> || <summary>"`` with newlines in the
    title/summary replaced by spaces; summaries longer than 400 characters
    are cut to 400 and suffixed with ``"..."``. Only the first *limit* nodes
    (depth-first pre-order) are included.
    """

    def _one_line(value: Any) -> str:
        return (value or "").replace("\n", " ")

    rendered: list[str] = []
    for node in flatten_structure(index.structure)[:limit]:
        node_id = normalize_node_id(node.get("node_id")) or ""
        title = _one_line(node.get("title"))
        summary = _one_line(node.get("summary"))
        if len(summary) > 400:
            summary = f"{summary[:400]}..."
        rendered.append(f"{node_id}: {title} || {summary}")
    return "\n".join(rendered)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def collect_context(index: DocumentIndex, node_ids: list[str]) -> tuple[str, str]:
    """Assemble prompt context for the selected nodes.

    Returns ``(context, tables)``:

    * ``context`` — one ``Title/Summary/Text`` fragment per selected node, in
      outline (document) order, joined by blank lines. Text comes from the
      node's media record, falling back to the node's own ``"text"``.
    * ``tables`` — the HTML of every table attached to the selected nodes, in
      *node_ids* order. Repeated ids are used once, and tables without usable
      HTML are skipped (matching ``select_tables``).
    """
    # Set lookup: the walk tests membership once per outline node.
    wanted = set(node_ids)
    fragments: list[str] = []

    def _walk(nodes: list[dict[str, Any]]) -> None:
        for node in nodes:
            nid = normalize_node_id(node.get("node_id"))
            if nid and nid in wanted:
                title = node.get("title") or ""
                summary = node.get("summary") or ""
                # media_by_node values may be None; treat them as "no media".
                media = index.media_by_node.get(nid) or {}
                text = media.get("text") or node.get("text") or ""
                fragments.append(f"Title: {title}\nSummary: {summary}\nText:\n{text}")
            ch = node.get("nodes")
            if isinstance(ch, list):
                _walk(ch)

    _walk(index.structure)
    context = "\n\n".join(fragments).strip()

    table_parts: list[str] = []
    # dict.fromkeys de-duplicates while keeping the caller's order.
    for nid in dict.fromkeys(node_ids):
        media = index.media_by_node.get(nid) or {}
        for tbl in media.get("tables") or []:
            if not isinstance(tbl, dict):
                continue
            html = tbl.get("html")
            if not html or not str(html).strip():
                continue
            table_parts.append(f"\nTable (node {nid}, page {tbl.get('page_number')}):\n{html}\n")
    return context, "".join(table_parts)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
_FIGURE_MENTION = re.compile(r"Figure\s+\d+\s*:[^\n]+", re.IGNORECASE)
|
|
115
|
+
_TABLE_MENTION = re.compile(r"Table\s+\d+[^\n]*", re.IGNORECASE)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def index_media_counts(index: DocumentIndex) -> tuple[int, int]:
    """Return (table_count, image_count) stored in the index.

    Counts are summed over every ``media_by_node`` record; a missing (None)
    record contributes zero instead of raising.
    """
    tables = 0
    images = 0
    for media in index.media_by_node.values():
        # Consistent with the other helpers, which treat None records as empty.
        media = media or {}
        tables += len(media.get("tables") or [])
        images += len(media.get("images") or [])
    return tables, images
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def index_media_inventory(
    index: DocumentIndex,
    *,
    max_list: int = 20,
) -> tuple[list[str], list[str]]:
    """Human-readable lines: where tables and figures were loaded (by node/page).

    Returns ``(table_lines, figure_lines)``, each capped at *max_list*.
    Nodes are visited in id order: purely numeric ids first (numerically),
    then all other ids (lexicographically). The title comes from the outline
    node with the same id, else ``"node <id>"``, and is passed through
    ``escape_display``. Figures are labelled with their image file name, or
    ``"figure <n>"`` when no path is stored.
    """
    table_lines: list[str] = []
    figure_lines: list[str] = []
    flat = {normalize_node_id(n.get("node_id")): n for n in flatten_structure(index.structure)}

    def _id_order(nid: str) -> tuple[int, int | str]:
        # Comparing int with str raises TypeError, so numeric and textual ids go
        # into separate groups that are never compared with each other.
        # isdecimal() (unlike isdigit()) guarantees int() succeeds, e.g. for "²".
        return (0, int(nid)) if nid.isdecimal() else (1, nid)

    for nid in sorted(index.media_by_node.keys(), key=_id_order):
        media = index.media_by_node.get(nid) or {}
        node = flat.get(nid) or {}
        title = str(node.get("title") or f"node {nid}")
        for i, tbl in enumerate(media.get("tables") or [], start=1):
            if not isinstance(tbl, dict):
                continue
            pg = tbl.get("page_number", "?")
            table_lines.append(f"• {escape_display(title)} (node {nid}, page {pg}, table {i})")
        for i, img in enumerate(media.get("images") or [], start=1):
            if not isinstance(img, dict):
                continue
            pg = img.get("page_number", "?")
            path = str(img.get("image_path") or "").strip()
            name = Path(path).name if path else f"figure {i}"
            figure_lines.append(f"• {escape_display(title)} (node {nid}, page {pg}) — {name}")

    return table_lines[:max_list], figure_lines[:max_list]
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def escape_display(text: str) -> str:
    """Backslash-escape every ``[`` so the text is not read as markup.

    Falsy input (``None``, ``""``) yields ``""``. Presumably this targets a
    bracket-based console markup such as Rich's — NOTE(review): confirm.
    """
    if not text:
        return ""
    return "\\[".join(text.split("["))
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def node_text(index: DocumentIndex, node_id: str) -> str:
    """Return the stripped body text for *node_id*.

    Prefers the text stored in ``index.media_by_node``. Otherwise it uses the
    ``"text"`` of the first outline node whose normalized id matches, and
    returns ``""`` when no such node exists.
    """
    stored = str((index.media_by_node.get(node_id) or {}).get("text") or "").strip()
    if stored:
        return stored
    match = next(
        (
            candidate
            for candidate in flatten_structure(index.structure)
            if normalize_node_id(candidate.get("node_id")) == node_id
        ),
        None,
    )
    if match is None:
        return ""
    return str(match.get("text") or "").strip()
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def text_media_mentions(
    index: DocumentIndex,
    node_ids: list[str],
    *,
    max_each: int = 8,
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Figure/table captions mentioned in section text (no image/HTML files).

    Scans ``node_text`` of each node in *node_ids* order with
    ``_FIGURE_MENTION`` / ``_TABLE_MENTION``. Returns ``(figures, tables)``;
    each item is ``{"node_id": ..., "caption": ...}``. Captions are de-duplicated
    across nodes, and each list holds at most *max_each* items.
    """
    figures: list[dict[str, str]] = []
    tables: list[dict[str, str]] = []
    if max_each <= 0:
        return figures, tables
    seen_f: set[str] = set()
    seen_t: set[str] = set()
    for nid in node_ids:
        if len(figures) >= max_each and len(tables) >= max_each:
            break  # both lists are full; no need to read further text
        text = node_text(index, nid)
        if not text:
            continue
        _collect_mentions(_FIGURE_MENTION, text, nid, seen_f, figures, max_each)
        _collect_mentions(_TABLE_MENTION, text, nid, seen_t, tables, max_each)
    return figures, tables


def _collect_mentions(
    pattern: re.Pattern[str],
    text: str,
    nid: str,
    seen: set[str],
    out: list[dict[str, str]],
    limit: int,
) -> None:
    """Append unseen *pattern* matches in *text* to *out* until it holds *limit* items."""
    for match in pattern.findall(text):
        # Checked before appending so the cap also holds across nodes.
        if len(out) >= limit:
            return
        line = match.strip()
        if line in seen:
            continue
        seen.add(line)
        out.append({"node_id": nid, "caption": line})
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# Whole-word, case-insensitive hints that a query is about tabular data.
_QUERY_WANTS_TABLES = re.compile(
    r"\b(table|tables|tabular|chart|matrix|grid|spreadsheet|"
    r"compare|comparison|statistics|stats|metrics|numbers|data)\b",
    flags=re.IGNORECASE,
)
# Whole-word, case-insensitive hints that a query is about figures/images.
_QUERY_WANTS_FIGURES = re.compile(
    r"\b(figure|figures|image|images|diagram|plot|visual|picture|illustration)\b",
    flags=re.IGNORECASE,
)


def is_substantive_table_html(html: str) -> bool:
    """True when HTML looks like a real table (not TOC prose tagged as Table).

    Requires a ``<tr`` tag and at least two ``<td``/``<th`` openings; the
    check is case-insensitive (``<th`` also matches ``<thead``).
    """
    lowered = (html or "").lower()
    cell_tags = sum(lowered.count(tag) for tag in ("<td", "<th"))
    return "<tr" in lowered and cell_tags >= 2


def query_wants_table_display(query: str) -> bool:
    """Whether *query* contains a table/statistics keyword (``None`` -> False)."""
    return _QUERY_WANTS_TABLES.search(query or "") is not None


def query_wants_figure_display(query: str) -> bool:
    """Whether *query* contains a figure/image keyword (``None`` -> False)."""
    return _QUERY_WANTS_FIGURES.search(query or "") is not None
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def select_tables(
    index: DocumentIndex,
    node_ids: list[str],
    *,
    max_tables: int = 6,
    substantive_only: bool = False,
) -> list[dict[str, Any]]:
    """Collect table records attached to *node_ids*, in that order.

    Each result is a shallow copy of the stored table dict plus ``"node_id"``.
    Tables with blank HTML are skipped; with *substantive_only*, tables failing
    ``is_substantive_table_html`` are skipped too. Repeated node ids are used
    once, and at most *max_tables* rows are returned (none when <= 0).
    """
    chosen: list[dict[str, Any]] = []
    if max_tables <= 0:
        return chosen
    # dict.fromkeys de-duplicates node ids while preserving order, so a repeated
    # id does not repeat its tables.
    for nid in dict.fromkeys(node_ids):
        media = index.media_by_node.get(nid) or {}
        for tbl in media.get("tables") or []:
            if not isinstance(tbl, dict):
                continue
            html = tbl.get("html")
            if not html or not str(html).strip():
                continue
            if substantive_only and not is_substantive_table_html(str(html)):
                continue
            row = dict(tbl)  # copy so the index itself is not mutated
            row["node_id"] = nid
            chosen.append(row)
            if len(chosen) >= max_tables:
                return chosen
    return chosen
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def select_tables_for_display(
    index: DocumentIndex,
    node_ids: list[str],
    query: str,
    *,
    max_tables: int = 3,
) -> list[dict[str, Any]]:
    """Substantive tables in retrieved sections (skip TOC-like blocks).

    Returns ``[]`` when ``CHUNKSMITH_CLI_TABLES_MODE`` is one of
    0/false/no/off/never (case-insensitive, surrounding whitespace ignored).
    Any other value, or an unset variable, behaves like ``"auto"``.
    """
    del query  # accepted for interface compatibility; not consulted
    setting = os.environ.get("CHUNKSMITH_CLI_TABLES_MODE", "auto")
    disabled = ("0", "false", "no", "off", "never")
    if setting.strip().lower() in disabled:
        return []
    return select_tables(index, node_ids, max_tables=max_tables, substantive_only=True)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def select_images_for_display(
    index: DocumentIndex,
    node_ids: list[str],
    query: str,
    *,
    max_images: int = 3,
) -> list[dict[str, Any]]:
    """Figures from retrieved sections (capped so the terminal stays readable).

    Returns ``[]`` when ``CHUNKSMITH_CLI_IMAGES_MODE`` is one of
    0/false/no/off/never (case-insensitive, surrounding whitespace ignored).
    Any other value, or an unset variable, behaves like ``"auto"``.
    """
    del query  # accepted for interface compatibility; not consulted
    setting = os.environ.get("CHUNKSMITH_CLI_IMAGES_MODE", "auto")
    disabled = ("0", "false", "no", "off", "never")
    if setting.strip().lower() in disabled:
        return []
    return select_images(index, node_ids, max_images=max_images)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def select_images(index: DocumentIndex, node_ids: list[str], *, max_images: int = 8) -> list[dict[str, Any]]:
    """Collect image records attached to *node_ids*, in that order.

    Each result is a shallow copy of the stored image dict plus ``"node_id"``.
    Images are de-duplicated by ``image_path``, then ``element_id``, then
    object identity. At most *max_images* rows are returned (none when <= 0).
    """
    chosen: list[dict[str, Any]] = []
    if max_images <= 0:
        return chosen
    seen: set[str] = set()
    for nid in node_ids:
        # media_by_node values may be None; treat them as "no media".
        media = index.media_by_node.get(nid) or {}
        for img in media.get("images") or []:
            if not isinstance(img, dict):
                continue
            path = str(img.get("image_path") or "")
            fp = path or str(img.get("element_id") or id(img))
            if fp in seen:
                continue
            seen.add(fp)
            row = dict(img)  # copy so the index itself is not mutated
            row["node_id"] = nid
            chosen.append(row)
            if len(chosen) >= max_images:
                return chosen
    return chosen
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""LangChain helpers for the tool-calling agent."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from chunksmith_agent.settings import AgentSettings
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
from chunksmith_agent.index_context import (
|
|
11
|
+
fallback_nodes_lexical,
|
|
12
|
+
flatten_structure,
|
|
13
|
+
normalize_node_id,
|
|
14
|
+
outline_prompt_lines,
|
|
15
|
+
)
|
|
16
|
+
from chunksmith_agent.models import DocumentIndex
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def build_chat_model(settings: AgentSettings, *, temperature: float = 0.0):
    """Build the LangChain chat model described by *settings*.

    The model name is ``settings.llm_model``, falling back to
    ``settings.pageindex_model``. ``load_settings`` already folds the
    ``CHUNKSMITH_LLM_MODEL`` / ``LLM_MODEL`` overrides into ``llm_model``.

    * If the name starts with ``azure/`` or an ``api_base`` is configured,
      returns ``AzureChatOpenAI``. The deployment is ``azure_deployment`` from
      the kwargs, else the model name with any ``azure/`` prefix removed.
    * Otherwise returns ``ChatOpenAI`` for that model name.
    """
    from langchain_openai import AzureChatOpenAI, ChatOpenAI

    kwargs = dict(settings.litellm_kwargs or {})
    # Use the resolved model for both the routing decision and the client, so an
    # LLM-model override is not silently replaced by pageindex_model.
    model = str(settings.llm_model or settings.pageindex_model)
    if model.startswith("azure/") or kwargs.get("api_base"):
        deployment = str(kwargs.get("azure_deployment") or model)
        return AzureChatOpenAI(
            azure_endpoint=kwargs.get("api_base") or kwargs.get("azure_endpoint"),
            api_key=kwargs.get("api_key") or settings.openai_api_key,
            api_version=kwargs.get("api_version") or "2024-02-15-preview",
            azure_deployment=deployment.removeprefix("azure/"),
            temperature=temperature,
        )
    return ChatOpenAI(
        model=model,
        api_key=settings.openai_api_key,
        temperature=temperature,
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class NodeSelection(BaseModel):
    # Structured-output schema for select_relevant_nodes. No class docstring on
    # purpose: with function calling, the model docstring and Field descriptions
    # presumably become part of the tool schema sent to the LLM — confirm before
    # editing either.
    # Free-text reasoning; select_relevant_nodes strips it and returns it alongside the ids.
    thinking: str = Field(description="Brief reasoning")
    # Candidate node ids. select_relevant_nodes normalizes, de-duplicates and caps
    # them at max_nodes. NOTE(review): the description hard-codes 8 even though
    # max_nodes is configurable.
    node_list: list[str] = Field(description="Up to 8 node_id strings")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _structured_output(llm: Any, schema: type[BaseModel]) -> Any:
|
|
46
|
+
"""Structured LLM output without OpenAI ``parsed`` serialization warnings."""
|
|
47
|
+
return llm.with_structured_output(schema, method="function_calling")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def select_relevant_nodes(
    index: DocumentIndex,
    query: str,
    settings: AgentSettings,
    *,
    max_nodes: int = 8,
) -> tuple[list[str], str]:
    """Ask the chat model which outline nodes are likely to answer *query*.

    Returns ``(node_ids, thinking)``. The model's ids are processed as follows:

    * normalized and de-duplicated;
    * restricted to ids that exist in ``index.structure``;
    * capped at *max_nodes*.

    If no usable id remains, ``fallback_nodes_lexical`` is tried and the
    thinking text gets a ``[fallback: lexical title match]`` tag. A
    non-positive *max_nodes* returns ``([], "")`` without calling the model.
    """
    from langchain_core.prompts import ChatPromptTemplate

    if max_nodes <= 0:
        return [], ""

    llm = build_chat_model(settings)
    # NOTE(review): the system prompt refers to pages "marked [table] or [figure]",
    # but outline_prompt_lines emits no such markers — confirm intent.
    chain = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "Pick outline node_ids most likely to answer the user question. "
                f"Return at most {max_nodes} ids. "
                "Prefer pages marked [table] or [figure] when the question is about "
                "statistics, inflation, prices, economic data, or charts.",
            ),
            (
                "human",
                "Question:\n{query}\n\nOutline (node_id: title || summary):\n{nodes}",
            ),
        ]
    ) | _structured_output(llm, NodeSelection)

    out = chain.invoke({"query": query, "nodes": outline_prompt_lines(index)})
    # The structured-output runnable may return None (e.g. when the model makes no
    # tool call); treat that as "no selection" so the lexical fallback runs
    # instead of raising AttributeError.
    raw_ids = getattr(out, "node_list", None) or []
    thinking = str(getattr(out, "thinking", None) or "").strip()

    flat = flatten_structure(index.structure)
    # Ids the model invents are useless downstream; dropping them lets the
    # lexical fallback take over when nothing real was chosen.
    known_ids = {nid for nid in (normalize_node_id(n.get("node_id")) for n in flat) if nid}

    result: list[str] = []
    seen: set[str] = set()
    for raw_id in raw_ids:
        nid = normalize_node_id(raw_id)
        if not nid or nid in seen or nid not in known_ids:
            continue
        seen.add(nid)
        result.append(nid)
        if len(result) >= max_nodes:
            break

    if not result:
        result = fallback_nodes_lexical(flat, query, max_nodes)
        if result:
            thinking = (thinking + " [fallback: lexical title match]").strip()
    return result, thinking
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def chunk_content(chunk: Any) -> str:
    """Extract text from a LangChain stream chunk.

    ``chunk.content`` may be a plain string or a list of content blocks. In a
    list, string items and ``{"type": "text", "text": <str>}`` dicts are
    concatenated in order; other blocks (e.g. reasoning) are ignored. A
    missing or otherwise-typed ``content`` yields ``""``.
    """
    content = getattr(chunk, "content", None)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text" and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "".join(parts)
    return ""
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Agent index and answer types."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class DocumentIndex:
    """Searchable document: nested outline + per-node media.

    ``structure`` is a list of outline node dicts whose children sit under
    ``"nodes"``. ``media_by_node`` maps a node id to that node's media record
    (text, tables, images).
    """

    doc_name: str
    structure: list[dict[str, Any]]
    media_by_node: dict[str, dict[str, Any]] = field(default_factory=dict)
    canonical_bundle: dict[str, Any] | None = None
    coded_formate: str | None = None
    image_dir: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a plain dict (values are shared, not copied)."""
        names = ("doc_name", "structure", "media_by_node", "canonical_bundle", "coded_formate", "image_dir")
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> DocumentIndex:
        """Rebuild an index from a dict, replacing mistyped fields with defaults.

        A string ``coded_formate`` inside a non-empty ``canonical_bundle`` takes
        precedence over the top-level ``coded_formate``. ``doc_name`` defaults
        to ``"document"``.
        """

        def _if_type(value: Any, expected: type, fallback: Any) -> Any:
            return value if isinstance(value, expected) else fallback

        bundle = _if_type(raw.get("canonical_bundle"), dict, None)
        coded = raw.get("coded_formate")
        if bundle:
            bundled = bundle.get("coded_formate")
            if isinstance(bundled, str):
                coded = bundled
        return cls(
            doc_name=str(raw.get("doc_name") or "document"),
            structure=_if_type(raw.get("structure"), list, []),
            media_by_node=_if_type(raw.get("media_by_node"), dict, {}),
            canonical_bundle=bundle,
            coded_formate=_if_type(coded, str, None),
            image_dir=_if_type(raw.get("image_dir"), str, None),
        )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class AgentAnswer:
    """Result of answering one question: the answer text plus its sources."""

    # Final answer text (answer_question concatenates the streamed tokens).
    answer: str
    # Node ids reported by the agent for this answer.
    nodes_used: list[str]
    # Node-selection reasoning; answer_question currently always passes "".
    selection_thinking: str
    # Image records (dicts) for the used nodes, as returned by select_images.
    images: list[dict[str, Any]]
    # Free-form diagnostics, e.g. {"node_ids": [...], "mode": "tools"}.
    raw_context: dict[str, Any] = field(default_factory=dict)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Agent Q&A entry points (delegates to the LangChain tool agent)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable, Iterator
|
|
6
|
+
|
|
7
|
+
from chunksmith_agent.index_context import select_images as _select_images
|
|
8
|
+
from chunksmith_agent.models import AgentAnswer, DocumentIndex
|
|
9
|
+
from chunksmith_agent.session import AgentConversation
|
|
10
|
+
from chunksmith_agent.settings import AgentSettings
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def answer_question(
    index: DocumentIndex,
    query: str,
    settings: AgentSettings,
    *,
    conversation: AgentConversation | None = None,
    on_thinking_delta: Callable[[str], None] | None = None,
    on_token: Callable[[str], None] | None = None,
) -> AgentAnswer:
    """Run the tool agent to completion and package its result.

    Drains ``iter_answer_events`` (with image events disabled). While draining:

    * ``agent:token`` chunks are accumulated into the answer and forwarded to
      *on_token*;
    * ``agent:thinking`` text is forwarded to *on_thinking_delta*;
    * the node ids come from the ``agent:complete`` payload.

    A fresh ``AgentConversation`` is used when *conversation* is None.
    """
    session = AgentConversation() if conversation is None else conversation
    state: dict[str, Any] = {"answer": "", "node_ids": []}

    def _sink(name: str, payload: dict[str, Any]) -> None:
        if name == "agent:complete":
            state["node_ids"] = list(payload.get("node_ids") or [])
        elif name == "agent:thinking":
            thought = str(payload.get("text") or "")
            if on_thinking_delta and thought:
                on_thinking_delta(thought)
        elif name == "agent:token":
            piece = payload.get("content") or ""
            state["answer"] += piece
            if on_token and piece:
                on_token(piece)

    # Only the sink's side effects matter; the yielded events themselves are discarded.
    for _event in iter_answer_events(
        index,
        query,
        settings,
        event_sink=_sink,
        emit_image_events=False,
        conversation=session,
    ):
        continue

    node_ids = state["node_ids"]
    return AgentAnswer(
        answer=state["answer"],
        nodes_used=node_ids,
        selection_thinking="",
        images=_select_images(index, node_ids),
        raw_context={"node_ids": node_ids, "mode": "tools"},
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def iter_answer_events(
    index: DocumentIndex,
    query: str,
    settings: AgentSettings,
    *,
    event_sink: Callable[[str, dict[str, Any]], None] | None = None,
    emit_image_events: bool = True,
    emit_table_events: bool = True,
    conversation: AgentConversation | None = None,
) -> Iterator[tuple[str, dict[str, Any]]]:
    """Event stream for CLI — LangChain tool-calling agent.

    A thin generator over ``tool_agent.iter_tool_agent_events``. It is imported
    lazily, on the first ``next()``. A fresh ``AgentConversation`` is used when
    *conversation* is None.
    """
    from chunksmith_agent.tool_agent import iter_tool_agent_events

    session = AgentConversation() if conversation is None else conversation
    yield from iter_tool_agent_events(
        index,
        query,
        settings,
        conversation=session,
        event_sink=event_sink,
        emit_image_events=emit_image_events,
        emit_table_events=emit_table_events,
    )
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Conversation memory for multi-turn agent chat."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ChatTurn:
    """One recorded question/answer exchange and the node ids it used."""

    query: str
    answer: str
    node_ids: list[str]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class AgentConversation:
    """In-session memory for one loaded document.

    ``turns`` holds every recorded exchange in order. ``last_node_ids`` holds
    the node ids of the most recent turn that reported any.
    """

    turns: list[ChatTurn] = field(default_factory=list)
    last_node_ids: list[str] = field(default_factory=list)

    def record(self, query: str, answer: str, node_ids: list[str]) -> None:
        """Append a turn; update ``last_node_ids`` only when *node_ids* is non-empty."""
        ids = list(node_ids)
        self.turns.append(ChatTurn(query=query, answer=answer, node_ids=ids))
        if ids:
            self.last_node_ids = ids

    def _recent_turns(self, max_turns: int) -> list[ChatTurn]:
        """Return the last *max_turns* turns; a non-positive limit means none."""
        # turns[-0:] would return the whole history, not an empty list.
        return self.turns[-max_turns:] if max_turns > 0 else []

    def chat_messages(self, *, max_turns: int = 6) -> list[dict[str, str]]:
        """Recent turns as role/content chat messages.

        Assistant messages are omitted when the answer is blank.
        """
        out: list[dict[str, str]] = []
        for t in self._recent_turns(max_turns):
            out.append({"role": "user", "content": t.query})
            if t.answer.strip():
                out.append({"role": "assistant", "content": t.answer})
        return out

    def recent_context(self, *, max_turns: int = 3) -> str:
        """Recent turns as ``User:``/``Assistant:`` lines.

        Each answer is truncated to 400 characters. Returns
        ``"(no prior turns)"`` when there is nothing to show.
        """
        recent = self._recent_turns(max_turns)
        if not recent:
            return "(no prior turns)"
        lines: list[str] = []
        for t in recent:
            lines.append(f"User: {t.query}")
            lines.append(f"Assistant: {t.answer[:400]}")
        return "\n".join(lines)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Agent LLM settings (env-only; no chunksmith pipeline imports)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from dotenv import load_dotenv
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _clean_env(name: str) -> str | None:
|
|
13
|
+
raw = os.getenv(name)
|
|
14
|
+
if raw is None:
|
|
15
|
+
return None
|
|
16
|
+
value = str(raw).split("#", 1)[0].strip()
|
|
17
|
+
return value or None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class AgentSettings:
    """Resolved LLM configuration for the agent, as built by load_settings."""

    # Model after env overrides (CHUNKSMITH_LLM_MODEL / LLM_MODEL, else pageindex_model);
    # an "azure/" prefix selects the Azure configuration.
    llm_model: str
    # OPENAI_API_KEY or CHATGPT_API_KEY, if set.
    openai_api_key: str | None
    # Base model name (argument, PAGEINDEX_MODEL, LLM_MODEL, or the built-in default).
    pageindex_model: str
    # Provider kwargs such as api_key / api_base / api_version. frozen=True does not
    # make this dict immutable; consumers should copy before modifying.
    litellm_kwargs: dict[str, Any] = field(default_factory=dict)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _resolve_litellm(*, pageindex_model: str) -> tuple[str, dict[str, Any]]:
    """Resolve the agent model name and its provider kwargs from the environment.

    The model is ``CHUNKSMITH_LLM_MODEL``, else ``LLM_MODEL``, else
    *pageindex_model* (whitespace stripped).

    * For ``azure/...`` models, ``api_key``, ``api_base`` (trailing ``/``
      removed) and ``api_version`` are taken from the AZURE_* / AZURE_OPENAI_*
      variables when present.
    * Otherwise ``api_key`` is the OpenAI key, when present.
    """
    model = (_clean_env("CHUNKSMITH_LLM_MODEL") or _clean_env("LLM_MODEL") or pageindex_model).strip()
    kwargs: dict[str, Any] = {}

    if not model.startswith("azure/"):
        openai_key = _clean_env("OPENAI_API_KEY") or _clean_env("CHATGPT_API_KEY")
        if openai_key:
            kwargs["api_key"] = openai_key
        return model, kwargs

    azure_sources = (
        ("api_key", "AZURE_API_KEY", "AZURE_OPENAI_API_KEY"),
        ("api_base", "AZURE_API_BASE", "AZURE_OPENAI_ENDPOINT"),
        ("api_version", "AZURE_API_VERSION", "AZURE_OPENAI_API_VERSION"),
    )
    for kwarg_name, primary, secondary in azure_sources:
        found = _clean_env(primary) or _clean_env(secondary)
        if not found:
            continue
        kwargs[kwarg_name] = found.rstrip("/") if kwarg_name == "api_base" else found
    return model, kwargs
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def load_settings(*, pageindex_model: str | None = None) -> AgentSettings:
    """Load agent LLM settings from environment (``.env`` supported).

    The base model is *pageindex_model*, else ``PAGEINDEX_MODEL``, else
    ``LLM_MODEL``, else ``"gpt-4o-2024-11-20"``.

    Raises:
        ValueError: an ``azure/`` model lacks an API key or base URL, or a
            non-Azure model has no OpenAI key.
    """
    load_dotenv()
    base_model = str(
        pageindex_model or _clean_env("PAGEINDEX_MODEL") or _clean_env("LLM_MODEL") or "gpt-4o-2024-11-20"
    ).strip()
    llm_model, litellm_kwargs = _resolve_litellm(pageindex_model=base_model)
    openai_key = _clean_env("OPENAI_API_KEY") or _clean_env("CHATGPT_API_KEY")

    if llm_model.startswith("azure/"):
        azure_ready = bool(litellm_kwargs.get("api_key")) and bool(litellm_kwargs.get("api_base"))
        if not azure_ready:
            raise ValueError(f"Azure model {llm_model!r} requires AZURE_API_KEY and AZURE_API_BASE in .env")
    elif not (openai_key or litellm_kwargs.get("api_key")):
        raise ValueError("Missing OPENAI_API_KEY in .env")

    return AgentSettings(
        llm_model=llm_model,
        openai_api_key=openai_key,
        pageindex_model=base_model,
        litellm_kwargs=litellm_kwargs,
    )
|