chunksmith-agent 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,310 @@
1
+ """Shared outline traversal and context assembly for retrieval backends."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import re
7
+ from difflib import get_close_matches
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from chunksmith_agent.models import DocumentIndex
12
+
13
+
14
def flatten_structure(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return every outline node in depth-first, parent-before-children order.

    Children are read from each node's ``"nodes"`` key and are only descended
    into when that value is a list. The result holds the same dict objects as
    the input tree; nothing is copied.
    """

    def _iter_nodes(level: list[dict[str, Any]]):
        for entry in level:
            yield entry
            children = entry.get("nodes")
            if isinstance(children, list):
                yield from _iter_nodes(children)

    return list(_iter_nodes(nodes))
26
+
27
+
28
def normalize_node_id(raw: Any) -> str | None:
    """Coerce a raw node id to a stripped string.

    Returns ``None`` when ``raw`` is ``None`` or its string form is blank.
    """
    cleaned = "" if raw is None else str(raw).strip()
    return cleaned if cleaned else None
33
+
34
+
35
def fallback_nodes_lexical(
    flat_nodes: list[dict[str, Any]],
    query: str,
    max_nodes: int,
) -> list[str]:
    """Rank outline nodes by lexical overlap between ``query`` and title/summary.

    Args:
        flat_nodes: Flattened outline nodes (dicts with ``node_id``, ``title``
            and ``summary`` keys; missing values are treated as empty).
        query: User question. It is lower-cased and split on non-word
            characters; tokens shorter than three characters are ignored.
        max_nodes: Maximum number of ids to return. Values ``<= 0`` yield an
            empty list.

    Returns:
        Unique node ids, highest score first. Nodes with equal scores keep
        their order from ``flat_nodes``. Nodes scoring zero are omitted.
    """
    if max_nodes <= 0:
        # The cap used to be checked only after the first append, so a
        # non-positive cap still returned one id.
        return []

    q_tokens = [t for t in re.split(r"\W+", (query or "").lower()) if len(t) >= 3]
    if not q_tokens:
        return []

    scored: list[tuple[int, str]] = []
    for node in flat_nodes:
        nid = normalize_node_id(node.get("node_id"))
        if not nid:
            continue
        title = (node.get("title") or "").lower()
        summary = (node.get("summary") or "").lower()
        # Candidate words for fuzzy matching, built once per node instead of
        # once per query token.
        vocabulary = list(dict.fromkeys(re.findall(r"[a-z]{4,}", f"{title} {summary}")))
        score = 0
        for token in q_tokens:
            if token in title or token in summary:
                score += 4  # exact substring hit
            elif get_close_matches(token, vocabulary, n=1, cutoff=0.72):
                score += 2  # near-miss (typo / inflection) hit
        if score > 0:
            scored.append((score, nid))

    # Stable sort: ties stay in outline order.
    scored.sort(key=lambda pair: -pair[0])

    result: list[str] = []
    seen: set[str] = set()
    for _, nid in scored:
        if nid in seen:
            continue
        seen.add(nid)
        result.append(nid)
        if len(result) >= max_nodes:
            break
    return result
73
+
74
+
75
def outline_prompt_lines(index: DocumentIndex, *, limit: int = 128) -> str:
    """Render the first ``limit`` flattened outline nodes as prompt lines.

    Each line has the form ``node_id: title || summary``. Newlines inside the
    title and summary are replaced by spaces, and summaries longer than 400
    characters are cut to 400 characters followed by ``...``.
    """

    def _single_line(value: Any) -> str:
        return (value or "").replace("\n", " ")

    rendered: list[str] = []
    for entry in flatten_structure(index.structure)[:limit]:
        node_id = normalize_node_id(entry.get("node_id")) or ""
        heading = _single_line(entry.get("title"))
        blurb = _single_line(entry.get("summary"))
        if len(blurb) > 400:
            blurb = blurb[:400] + "..."
        rendered.append(f"{node_id}: {heading} || {blurb}")
    return "\n".join(rendered)
85
+
86
+
87
def collect_context(index: DocumentIndex, node_ids: list[str]) -> tuple[str, str]:
    """Assemble the text context and the table appendix for selected nodes.

    Args:
        index: Document index providing ``structure`` and ``media_by_node``.
        node_ids: Ids of the outline nodes to include.

    Returns:
        ``(context, tables)``. ``context`` holds one
        ``Title / Summary / Text`` fragment per matching outline node, in
        outline order, separated by blank lines. ``tables`` concatenates one
        entry per dict-valued table stored under each id of ``node_ids`` (in
        ``node_ids`` order) and is ``""`` when there are none.
    """
    # Membership is tested for every outline node, so look ids up in a set
    # instead of scanning the list each time.
    wanted = frozenset(node_ids)
    fragments: list[str] = []

    def _walk(nodes: list[dict[str, Any]]) -> None:
        for node in nodes:
            nid = normalize_node_id(node.get("node_id"))
            if nid and nid in wanted:
                title = node.get("title") or ""
                summary = node.get("summary") or ""
                media = index.media_by_node.get(nid, {})
                # Text stored with the node's media takes precedence over the
                # text embedded in the outline node itself.
                text = media.get("text") or node.get("text") or ""
                fragments.append(f"Title: {title}\nSummary: {summary}\nText:\n{text}")
            children = node.get("nodes")
            if isinstance(children, list):
                _walk(children)

    _walk(index.structure)
    context = "\n\n".join(fragments).strip()

    # Collect the table entries and join once rather than growing a string.
    table_parts: list[str] = []
    for nid in node_ids:
        media = index.media_by_node.get(nid, {})
        for tbl in media.get("tables") or []:
            if isinstance(tbl, dict):
                table_parts.append(
                    f"\nTable (node {nid}, page {tbl.get('page_number')}):\n{tbl.get('html')}\n"
                )
    return context, "".join(table_parts)
112
+
113
+
114
# Caption-like mentions inside section text, matched case-insensitively and
# never across a newline:
#   "Figure <n>:" followed by at least one more character on the same line;
#   "Table <n>" followed by whatever remains on the same line (possibly nothing).
_FIGURE_MENTION = re.compile(r"Figure\s+\d+\s*:[^\n]+", flags=re.IGNORECASE)
_TABLE_MENTION = re.compile(r"Table\s+\d+[^\n]*", flags=re.IGNORECASE)
116
+
117
+
118
def index_media_counts(index: DocumentIndex) -> tuple[int, int]:
    """Return (table_count, image_count) stored in the index.

    Counts the entries of every node's ``"tables"`` and ``"images"`` lists in
    ``index.media_by_node``; a missing or empty value counts as zero.
    """
    per_node = list(index.media_by_node.values())
    table_total = sum(len(media.get("tables") or []) for media in per_node)
    image_total = sum(len(media.get("images") or []) for media in per_node)
    return table_total, image_total
126
+
127
+
128
def index_media_inventory(
    index: DocumentIndex,
    *,
    max_list: int = 20,
) -> tuple[list[str], list[str]]:
    """Human-readable lines: where tables and figures were loaded (by node/page).

    Args:
        index: Document index providing ``structure`` and ``media_by_node``.
        max_list: Maximum number of lines returned in each list.

    Returns:
        ``(table_lines, figure_lines)``. Nodes are listed with purely numeric
        ids first (in numeric order), followed by all other ids in string
        order. Non-dict table/image entries are skipped but still advance the
        per-node counter used in the label.
    """

    def _id_order(key: str) -> tuple[int, int, str]:
        # Comparing int keys with str keys raises TypeError, so every key is
        # mapped to a uniform tuple. isdecimal() is used because it only
        # accepts characters that int() can parse.
        if key.isdecimal():
            return (0, int(key), "")
        return (1, 0, key)

    table_lines: list[str] = []
    figure_lines: list[str] = []
    flat = {normalize_node_id(n.get("node_id")): n for n in flatten_structure(index.structure)}

    for nid in sorted(index.media_by_node, key=_id_order):
        media = index.media_by_node.get(nid) or {}
        node = flat.get(nid) or {}
        # Fall back to a generic label when the outline has no title for the id.
        title = str(node.get("title") or f"node {nid}")
        for i, tbl in enumerate(media.get("tables") or [], start=1):
            if not isinstance(tbl, dict):
                continue
            pg = tbl.get("page_number", "?")
            table_lines.append(f"• {escape_display(title)} (node {nid}, page {pg}, table {i})")
        for i, img in enumerate(media.get("images") or [], start=1):
            if not isinstance(img, dict):
                continue
            pg = img.get("page_number", "?")
            path = str(img.get("image_path") or "").strip()
            # Show only the file name; fall back to a positional label.
            name = Path(path).name if path else f"figure {i}"
            figure_lines.append(f"• {escape_display(title)} (node {nid}, page {pg}) — {name}")

    return table_lines[:max_list], figure_lines[:max_list]
156
+
157
+
158
def escape_display(text: str) -> str:
    """Prefix every ``[`` in ``text`` with a backslash; falsy input gives ``""``.

    NOTE(review): presumably protects titles from being parsed as console
    markup tags — confirm against the CLI renderer.
    """
    if not text:
        return ""
    return text.replace("[", "\\[")
160
+
161
+
162
def node_text(index: DocumentIndex, node_id: str) -> str:
    """Return the stripped text for ``node_id``.

    The ``"text"`` stored in ``index.media_by_node`` wins when it is non-blank;
    otherwise the ``"text"`` of the first outline node whose normalized id
    equals ``node_id`` is used. Returns ``""`` when neither exists.
    """
    stored = str((index.media_by_node.get(node_id) or {}).get("text") or "").strip()
    if stored:
        return stored
    outline_match = next(
        (
            entry
            for entry in flatten_structure(index.structure)
            if normalize_node_id(entry.get("node_id")) == node_id
        ),
        None,
    )
    if outline_match is None:
        return ""
    return str(outline_match.get("text") or "").strip()
171
+
172
+
173
def text_media_mentions(
    index: DocumentIndex,
    node_ids: list[str],
    *,
    max_each: int = 8,
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Figure/table captions mentioned in section text (no image/HTML files).

    Args:
        index: Document index whose node text is scanned.
        node_ids: Nodes to scan, in priority order.
        max_each: Upper bound on the number of figure captions and, separately,
            on the number of table captions. Values ``<= 0`` yield empty lists.

    Returns:
        ``(figures, tables)``; each item is
        ``{"node_id": ..., "caption": ...}``. Captions are de-duplicated across
        nodes by their stripped text, keeping the first occurrence.
    """
    figures: list[dict[str, str]] = []
    tables: list[dict[str, str]] = []
    seen_f: set[str] = set()
    seen_t: set[str] = set()

    def _harvest(pattern: re.Pattern[str], text: str, nid: str,
                 bucket: list[dict[str, str]], seen: set[str]) -> None:
        for match in pattern.findall(text):
            # Check the cap before appending. A break after the append only
            # left the inner loop, so each later node added one more caption
            # beyond the limit.
            if len(bucket) >= max_each:
                return
            caption = match.strip()
            if caption in seen:
                continue
            seen.add(caption)
            bucket.append({"node_id": nid, "caption": caption})

    for nid in node_ids:
        if len(figures) >= max_each and len(tables) >= max_each:
            break  # both lists are full; no need to read more text
        text = node_text(index, nid)
        if not text:
            continue
        _harvest(_FIGURE_MENTION, text, nid, figures, seen_f)
        _harvest(_TABLE_MENTION, text, nid, tables, seen_t)
    return figures, tables
205
+
206
+
207
# Whole-word, case-insensitive keyword detectors applied to the user query.
# The first matches vocabulary suggesting tabular/numeric output, the second
# vocabulary suggesting figures or images.
_QUERY_WANTS_TABLES = re.compile(
    r"\b("
    r"table|tables|tabular|chart|matrix|grid|spreadsheet|"
    r"compare|comparison|statistics|stats|metrics|numbers|data"
    r")\b",
    flags=re.IGNORECASE,
)
_QUERY_WANTS_FIGURES = re.compile(
    r"\b(figure|figures|image|images|diagram|plot|visual|picture|illustration)\b",
    flags=re.IGNORECASE,
)
218
+
219
+
220
def is_substantive_table_html(html: str) -> bool:
    """True when HTML looks like a real table (not TOC prose tagged as Table).

    The check is textual and case-insensitive: the markup must contain
    ``<tr`` and at least two occurrences of ``<td`` / ``<th`` combined.
    """
    lowered = (html or "").lower()
    cell_tags = lowered.count("<td") + lowered.count("<th")
    return "<tr" in lowered and cell_tags >= 2
226
+
227
+
228
def query_wants_table_display(query: str) -> bool:
    """Return True when ``query`` contains a keyword matched by ``_QUERY_WANTS_TABLES``."""
    return _QUERY_WANTS_TABLES.search(query or "") is not None
230
+
231
+
232
def query_wants_figure_display(query: str) -> bool:
    """Return True when ``query`` contains a keyword matched by ``_QUERY_WANTS_FIGURES``."""
    return _QUERY_WANTS_FIGURES.search(query or "") is not None
234
+
235
+
236
def select_tables(
    index: DocumentIndex,
    node_ids: list[str],
    *,
    max_tables: int = 6,
    substantive_only: bool = False,
) -> list[dict[str, Any]]:
    """Collect tables stored for ``node_ids``, in order, up to ``max_tables``.

    Entries that are not dicts or whose ``"html"`` is missing or blank are
    skipped; with ``substantive_only`` the HTML must also pass
    ``is_substantive_table_html``. Each result is a shallow copy of the stored
    table with ``"node_id"`` set to the owning node.
    """

    def _eligible(table: Any) -> bool:
        if not isinstance(table, dict):
            return False
        markup = table.get("html")
        if not markup or not str(markup).strip():
            return False
        return not substantive_only or is_substantive_table_html(str(markup))

    picked: list[dict[str, Any]] = []
    for node_id in node_ids:
        stored = index.media_by_node.get(node_id, {}).get("tables") or []
        for table in stored:
            if not _eligible(table):
                continue
            picked.append({**table, "node_id": node_id})
            if len(picked) >= max_tables:
                return picked
    return picked
260
+
261
+
262
def select_tables_for_display(
    index: DocumentIndex,
    node_ids: list[str],
    query: str,
    *,
    max_tables: int = 3,
) -> list[dict[str, Any]]:
    """Substantive tables in retrieved sections (skip TOC-like blocks).

    ``query`` is accepted but not used. Setting the environment variable
    ``CHUNKSMITH_CLI_TABLES_MODE`` to ``0``, ``false``, ``no``, ``off`` or
    ``never`` (case-insensitive) disables table output and returns ``[]``.
    """
    del query  # part of the signature only; selection ignores the query
    setting = os.environ.get("CHUNKSMITH_CLI_TABLES_MODE", "auto").strip().lower()
    if setting in {"0", "false", "no", "off", "never"}:
        return []
    return select_tables(index, node_ids, max_tables=max_tables, substantive_only=True)
275
+
276
+
277
def select_images_for_display(
    index: DocumentIndex,
    node_ids: list[str],
    query: str,
    *,
    max_images: int = 3,
) -> list[dict[str, Any]]:
    """Figures from retrieved sections (capped so the terminal stays readable).

    ``query`` is accepted but not used. Setting the environment variable
    ``CHUNKSMITH_CLI_IMAGES_MODE`` to ``0``, ``false``, ``no``, ``off`` or
    ``never`` (case-insensitive) disables image output and returns ``[]``.
    """
    del query  # part of the signature only; selection ignores the query
    setting = os.environ.get("CHUNKSMITH_CLI_IMAGES_MODE", "auto").strip().lower()
    if setting in {"0", "false", "no", "off", "never"}:
        return []
    return select_images(index, node_ids, max_images=max_images)
290
+
291
+
292
def select_images(index: DocumentIndex, node_ids: list[str], *, max_images: int = 8) -> list[dict[str, Any]]:
    """Collect de-duplicated images stored for ``node_ids``, up to ``max_images``.

    Images are identified by ``image_path``; when that is empty the
    ``element_id`` is used, and failing that the identity of the stored dict.
    Non-dict entries are skipped. Each result is a shallow copy of the stored
    image with ``"node_id"`` set to the owning node.
    """
    picked: list[dict[str, Any]] = []
    fingerprints: set[str] = set()
    for node_id in node_ids:
        stored = index.media_by_node.get(node_id, {}).get("images") or []
        for image in stored:
            if not isinstance(image, dict):
                continue
            fingerprint = str(image.get("image_path") or "")
            if not fingerprint:
                fingerprint = str(image.get("element_id") or id(image))
            if fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
            picked.append({**image, "node_id": node_id})
            if len(picked) >= max_images:
                return picked
    return picked
@@ -0,0 +1,101 @@
1
+ """LangChain helpers for the tool-calling agent."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from chunksmith_agent.settings import AgentSettings
8
+ from pydantic import BaseModel, Field
9
+
10
+ from chunksmith_agent.index_context import (
11
+ fallback_nodes_lexical,
12
+ flatten_structure,
13
+ normalize_node_id,
14
+ outline_prompt_lines,
15
+ )
16
+ from chunksmith_agent.models import DocumentIndex
17
+
18
+
19
def build_chat_model(settings: AgentSettings, *, temperature: float = 0.0):
    """Build the LangChain chat model described by ``settings``.

    An Azure client is created when the resolved model name starts with
    ``azure/`` or an ``api_base`` is present in ``settings.litellm_kwargs``;
    otherwise a plain OpenAI client is created.

    Args:
        settings: Agent settings; ``llm_model`` is preferred and
            ``pageindex_model`` is the fallback model name.
        temperature: Sampling temperature passed to the client.

    Returns:
        An ``AzureChatOpenAI`` or ``ChatOpenAI`` instance.
    """
    # Imported lazily so importing this module does not require langchain_openai.
    from langchain_openai import AzureChatOpenAI, ChatOpenAI

    kwargs = dict(settings.litellm_kwargs or {})
    # The resolved name is used both for the branch decision and for the
    # request itself. Previously only the branch used it, and the request
    # always fell back to ``pageindex_model``, ignoring an ``llm_model``
    # override.
    model = str(settings.llm_model or settings.pageindex_model)
    if model.startswith("azure/") or kwargs.get("api_base"):
        deployment = kwargs.get("azure_deployment") or model
        return AzureChatOpenAI(
            azure_endpoint=kwargs.get("api_base") or kwargs.get("azure_endpoint"),
            api_key=kwargs.get("api_key") or settings.openai_api_key,
            api_version=kwargs.get("api_version") or "2024-02-15-preview",
            # Azure expects the bare deployment name, without the routing prefix.
            azure_deployment=str(deployment).replace("azure/", ""),
            temperature=temperature,
        )
    return ChatOpenAI(
        model=model,
        api_key=settings.openai_api_key,
        temperature=temperature,
    )
38
+
39
+
40
class NodeSelection(BaseModel):
    # Structured-output schema for the node-selection LLM call; it is the
    # schema handed to ``with_structured_output`` when selecting nodes.
    # NOTE(review): documented with comments instead of a class docstring
    # because a docstring on a schema model may be forwarded to the LLM as the
    # function description — confirm before adding one.

    # Free-form rationale returned by the model.
    thinking: str = Field(description="Brief reasoning")
    # Selected outline ids. "Up to 8" is only a hint in the field description;
    # nothing in this class enforces a length limit.
    node_list: list[str] = Field(description="Up to 8 node_id strings")
43
+
44
+
45
def _structured_output(llm: Any, schema: type[BaseModel]) -> Any:
    """Structured LLM output without OpenAI ``parsed`` serialization warnings.

    Wraps ``llm`` so its output follows ``schema``, always requesting the
    ``function_calling`` method.
    """
    structured_llm = llm.with_structured_output(schema, method="function_calling")
    return structured_llm
48
+
49
+
50
def select_relevant_nodes(
    index: DocumentIndex,
    query: str,
    settings: AgentSettings,
    *,
    max_nodes: int = 8,
) -> tuple[list[str], str]:
    """Ask the LLM which outline nodes are most likely to answer ``query``.

    Returns ``(node_ids, thinking)``. Ids are normalized and de-duplicated in
    the order the model returned them. When the model yields no usable id,
    ``fallback_nodes_lexical`` is tried and, if it finds anything, a fallback
    marker is appended to ``thinking``.
    """
    from langchain_core.prompts import ChatPromptTemplate

    llm = build_chat_model(settings)
    system_text = (
        "Pick outline node_ids most likely to answer the user question. "
        f"Return at most {max_nodes} ids. "
        "Prefer pages marked [table] or [figure] when the question is about "
        "statistics, inflation, prices, economic data, or charts."
    )
    human_text = "Question:\n{query}\n\nOutline (node_id: title || summary):\n{nodes}"
    prompt = ChatPromptTemplate.from_messages([("system", system_text), ("human", human_text)])
    chain = prompt | _structured_output(llm, NodeSelection)

    selection: NodeSelection = chain.invoke({"query": query, "nodes": outline_prompt_lines(index)})

    picked: list[str] = []
    already: set[str] = set()
    for raw_id in selection.node_list:
        node_id = normalize_node_id(raw_id)
        if node_id and node_id not in already:
            already.add(node_id)
            picked.append(node_id)
            if len(picked) >= max_nodes:
                break

    thinking = (selection.thinking or "").strip()
    if picked:
        return picked, thinking

    # Nothing usable from the model: fall back to lexical matching.
    picked = fallback_nodes_lexical(flatten_structure(index.structure), query, max_nodes)
    if picked:
        thinking = (thinking + " [fallback: lexical title match]").strip()
    return picked, thinking
94
+
95
+
96
def chunk_content(chunk: Any) -> str:
    """Extract text from a LangChain stream chunk.

    Returns ``chunk.content`` when it is a string; otherwise (attribute
    missing or not a string) returns ``""``.
    """
    payload = getattr(chunk, "content", None)
    return payload if isinstance(payload, str) else ""
@@ -0,0 +1,60 @@
1
+ """Agent index and answer types."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+
9
@dataclass
class DocumentIndex:
    """Searchable document: nested outline + per-node media.

    ``to_dict`` / ``from_dict`` convert to and from a plain dict with the same
    six keys as the fields below.
    """

    doc_name: str
    structure: list[dict[str, Any]]
    media_by_node: dict[str, dict[str, Any]] = field(default_factory=dict)
    canonical_bundle: dict[str, Any] | None = None
    coded_formate: str | None = None
    image_dir: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a dict (values are shared, not copied)."""
        names = (
            "doc_name",
            "structure",
            "media_by_node",
            "canonical_bundle",
            "coded_formate",
            "image_dir",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> DocumentIndex:
        """Build an index from ``raw``, replacing wrongly typed values by defaults.

        A string ``coded_formate`` inside a non-empty ``canonical_bundle``
        takes precedence over the top-level ``coded_formate``.
        """
        bundle = raw.get("canonical_bundle")
        bundle = bundle if isinstance(bundle, dict) else None

        structure = raw.get("structure")
        structure = structure if isinstance(structure, list) else []

        media = raw.get("media_by_node")
        media = media if isinstance(media, dict) else {}

        coded = raw.get("coded_formate")
        if bundle and isinstance(bundle.get("coded_formate"), str):
            coded = bundle.get("coded_formate")
        if not isinstance(coded, str):
            coded = None

        image_dir = raw.get("image_dir")
        if not isinstance(image_dir, str):
            image_dir = None

        return cls(
            doc_name=str(raw.get("doc_name") or "document"),
            structure=structure,
            media_by_node=media,
            canonical_bundle=bundle,
            coded_formate=coded,
            image_dir=image_dir,
        )
52
+
53
+
54
@dataclass
class AgentAnswer:
    """Result of one agent question: answer text plus supporting metadata."""

    # Final answer text.
    answer: str
    # Ids of the outline nodes the answer drew on.
    nodes_used: list[str]
    # Reasoning text attached to node selection (may be empty).
    selection_thinking: str
    # Image records associated with the nodes used.
    images: list[dict[str, Any]]
    # Extra diagnostic data; each instance gets its own dict.
    raw_context: dict[str, Any] = field(default_factory=dict)
@@ -0,0 +1,80 @@
1
+ """Agent Q&A entry points (delegates to the LangChain tool agent)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable, Iterator
6
+
7
+ from chunksmith_agent.index_context import select_images as _select_images
8
+ from chunksmith_agent.models import AgentAnswer, DocumentIndex
9
+ from chunksmith_agent.session import AgentConversation
10
+ from chunksmith_agent.settings import AgentSettings
11
+
12
+
13
def answer_question(
    index: DocumentIndex,
    query: str,
    settings: AgentSettings,
    *,
    conversation: AgentConversation | None = None,
    on_thinking_delta: Callable[[str], None] | None = None,
    on_token: Callable[[str], None] | None = None,
) -> AgentAnswer:
    """Run the agent to completion and return the collected answer.

    The event stream from ``iter_answer_events`` is drained (image events
    disabled). Token events are concatenated into the answer and forwarded to
    ``on_token``; thinking events are forwarded to ``on_thinking_delta``; the
    completion event supplies the node ids. ``selection_thinking`` in the
    result is always ``""``.
    """
    conv = AgentConversation() if conversation is None else conversation
    answer = ""
    node_ids: list[str] = []

    def _sink(name: str, payload: dict[str, Any]) -> None:
        nonlocal answer, node_ids
        if name == "agent:token":
            piece = payload.get("content") or ""
            answer += piece
            if on_token and piece:
                on_token(piece)
            return
        if name == "agent:thinking":
            thought = str(payload.get("text") or "")
            if on_thinking_delta and thought:
                on_thinking_delta(thought)
            return
        if name == "agent:complete":
            node_ids = list(payload.get("node_ids") or [])

    events = iter_answer_events(
        index,
        query,
        settings,
        event_sink=_sink,
        emit_image_events=False,
        conversation=conv,
    )
    # Only the sink's side effects matter; the yielded events are discarded.
    for _event in events:
        pass

    return AgentAnswer(
        answer=answer,
        nodes_used=node_ids,
        selection_thinking="",
        images=_select_images(index, node_ids),
        raw_context={"node_ids": node_ids, "mode": "tools"},
    )
56
+
57
+
58
def iter_answer_events(
    index: DocumentIndex,
    query: str,
    settings: AgentSettings,
    *,
    event_sink: Callable[[str, dict[str, Any]], None] | None = None,
    emit_image_events: bool = True,
    emit_table_events: bool = True,
    conversation: AgentConversation | None = None,
) -> Iterator[tuple[str, dict[str, Any]]]:
    """Event stream for CLI — LangChain tool-calling agent.

    Delegates to ``iter_tool_agent_events``, creating a fresh
    ``AgentConversation`` when none is supplied.
    """
    conv = AgentConversation() if conversation is None else conversation
    # Imported on first use rather than at module import time.
    from chunksmith_agent.tool_agent import iter_tool_agent_events

    options: dict[str, Any] = {
        "conversation": conv,
        "event_sink": event_sink,
        "emit_image_events": emit_image_events,
        "emit_table_events": emit_table_events,
    }
    yield from iter_tool_agent_events(index, query, settings, **options)
@@ -0,0 +1,44 @@
1
+ """Conversation memory for multi-turn agent chat."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+
8
@dataclass
class ChatTurn:
    """One question/answer exchange together with the node ids recorded for it."""

    # The user's question.
    query: str
    # The assistant's answer text.
    answer: str
    # Node ids recorded for this turn.
    node_ids: list[str]
13
+
14
+
15
@dataclass
class AgentConversation:
    """In-session memory for one loaded document."""

    # Completed turns, oldest first.
    turns: list[ChatTurn] = field(default_factory=list)
    # Node ids of the most recent turn that had any.
    last_node_ids: list[str] = field(default_factory=list)

    def record(self, query: str, answer: str, node_ids: list[str]) -> None:
        """Append a turn; ``last_node_ids`` is updated only when ids are given."""
        ids = list(node_ids)  # copy so later caller mutations do not leak in
        self.turns.append(ChatTurn(query=query, answer=answer, node_ids=ids))
        if ids:
            self.last_node_ids = ids

    def _latest(self, max_turns: int) -> list[ChatTurn]:
        """Return the last ``max_turns`` turns, or ``[]`` when ``max_turns <= 0``."""
        # ``turns[-0:]`` is the whole list, so zero must be handled explicitly.
        if max_turns <= 0:
            return []
        return self.turns[-max_turns:]

    def chat_messages(self, *, max_turns: int = 6) -> list[dict[str, str]]:
        """Return recent turns as ``role``/``content`` chat messages.

        Each turn contributes a ``user`` message and, when its answer is not
        blank, an ``assistant`` message.
        """
        out: list[dict[str, str]] = []
        for turn in self._latest(max_turns):
            out.append({"role": "user", "content": turn.query})
            if turn.answer.strip():
                out.append({"role": "assistant", "content": turn.answer})
        return out

    def recent_context(self, *, max_turns: int = 3) -> str:
        """Return recent turns as ``User:`` / ``Assistant:`` lines.

        Answers are truncated to 400 characters. Returns
        ``"(no prior turns)"`` when there is nothing to show.
        """
        recent = self._latest(max_turns)
        if not recent:
            return "(no prior turns)"
        lines: list[str] = []
        for turn in recent:
            lines.append(f"User: {turn.query}")
            lines.append(f"Assistant: {turn.answer[:400]}")
        return "\n".join(lines)
@@ -0,0 +1,68 @@
1
+ """Agent LLM settings (env-only; no chunksmith pipeline imports)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ from typing import Any
8
+
9
+ from dotenv import load_dotenv
10
+
11
+
12
def _clean_env(name: str) -> str | None:
    """Read environment variable ``name`` with any ``#`` suffix removed.

    Everything from the first ``#`` onwards is dropped and the remainder is
    stripped. Returns ``None`` when the variable is unset or nothing is left.
    """
    raw = os.getenv(name)
    if raw is None:
        return None
    before_hash, _, _ = str(raw).partition("#")
    return before_hash.strip() or None
18
+
19
+
20
@dataclass(frozen=True)
class AgentSettings:
    """Immutable LLM configuration for the agent."""

    # Resolved model name for the agent (may carry an ``azure/`` prefix).
    llm_model: str
    # OpenAI API key, or ``None`` when none was found.
    openai_api_key: str | None
    # Base model name that ``llm_model`` falls back to.
    pageindex_model: str
    # Provider keyword arguments (``api_key``, ``api_base``, ``api_version``).
    litellm_kwargs: dict[str, Any] = field(default_factory=dict)
26
+
27
+
28
def _resolve_litellm(*, pageindex_model: str) -> tuple[str, dict[str, Any]]:
    """Resolve the agent model name and provider kwargs from the environment.

    The model is the first set value among ``CHUNKSMITH_LLM_MODEL``,
    ``LLM_MODEL`` and ``pageindex_model``. For ``azure/`` models the Azure
    key, base URL (trailing ``/`` removed) and API version are added when
    present; otherwise only the OpenAI key is added when present.
    """

    def _first_env(*names: str) -> str | None:
        for env_name in names:
            found = _clean_env(env_name)
            if found:
                return found
        return None

    model = (_first_env("CHUNKSMITH_LLM_MODEL", "LLM_MODEL") or pageindex_model).strip()
    kwargs: dict[str, Any] = {}

    if not model.startswith("azure/"):
        openai_key = _first_env("OPENAI_API_KEY", "CHATGPT_API_KEY")
        if openai_key:
            kwargs["api_key"] = openai_key
        return model, kwargs

    azure_key = _first_env("AZURE_API_KEY", "AZURE_OPENAI_API_KEY")
    azure_base = _first_env("AZURE_API_BASE", "AZURE_OPENAI_ENDPOINT")
    azure_version = _first_env("AZURE_API_VERSION", "AZURE_OPENAI_API_VERSION")
    if azure_key:
        kwargs["api_key"] = azure_key
    if azure_base:
        kwargs["api_base"] = azure_base.rstrip("/")
    if azure_version:
        kwargs["api_version"] = azure_version
    return model, kwargs
48
+
49
+
50
def load_settings(*, pageindex_model: str | None = None) -> AgentSettings:
    """Load agent LLM settings from environment (``.env`` supported).

    The base model is the first set value among the ``pageindex_model``
    argument, ``PAGEINDEX_MODEL``, ``LLM_MODEL`` and a built-in default.

    Raises:
        ValueError: when an ``azure/`` model lacks an API key or base URL, or
            when a non-Azure model has no OpenAI API key.
    """
    load_dotenv()
    chosen = (
        pageindex_model
        or _clean_env("PAGEINDEX_MODEL")
        or _clean_env("LLM_MODEL")
        or "gpt-4o-2024-11-20"
    )
    base_model = str(chosen).strip()
    llm_model, litellm_kwargs = _resolve_litellm(pageindex_model=base_model)
    openai_key = _clean_env("OPENAI_API_KEY") or _clean_env("CHATGPT_API_KEY")

    if llm_model.startswith("azure/"):
        has_azure_credentials = litellm_kwargs.get("api_key") and litellm_kwargs.get("api_base")
        if not has_azure_credentials:
            raise ValueError(f"Azure model {llm_model!r} requires AZURE_API_KEY and AZURE_API_BASE in .env")
    elif not (openai_key or litellm_kwargs.get("api_key")):
        raise ValueError("Missing OPENAI_API_KEY in .env")

    return AgentSettings(
        llm_model=llm_model,
        openai_api_key=openai_key,
        pageindex_model=base_model,
        litellm_kwargs=litellm_kwargs,
    )