poma 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
# ---------------------------------------------------------------------
# POMA integration for LangChain
# ---------------------------------------------------------------------

import os
import hashlib
from typing import Any
from pathlib import Path
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor, as_completed

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.schema.retriever import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_text_splitters import TextSplitter
from pydantic import Field, PrivateAttr

from poma import Poma
from poma.client import ALLOWED_FILE_EXTENSIONS
from poma.exceptions import InvalidInputError
from poma.retrieval import _cheatsheets_from_chunks

__all__ = ["PomaFileLoader", "PomaChunksetSplitter", "PomaCheatsheetRetrieverLC"]


# ------------------------------------------------------------------ #
#   Load from Path → LC Documents                                    #
# ------------------------------------------------------------------ #


class PomaFileLoader(BaseLoader):

    def __init__(self, input_path: str | Path):
        """Initialize with a file or directory path."""
        self.input_path = Path(input_path).expanduser().resolve()

    def load(self) -> list[Document]:
        """
        Load files from the input path (file or directory) into LangChain Documents.
        Only files with allowed extensions are processed; others are skipped.
        """
        path = self.input_path
        if not path.exists():
            raise FileNotFoundError(f"No such path: {path}")

        documents: list[Document] = []
        skipped: int = 0

        def _process_file(file_path: Path):
            nonlocal skipped, documents
            if not file_path.is_file():
                return
            file_extension = file_path.suffix.lower()
            if not file_extension or file_extension not in ALLOWED_FILE_EXTENSIONS:
                skipped += 1
                return
            file_bytes = file_path.read_bytes()
            file_hash = hashlib.md5(file_bytes).hexdigest()
            if file_path.suffix.lower() == ".pdf":
                page_content: str = ""  # LangChain requires str
            else:
                try:
                    page_content = file_bytes.decode("utf-8")
                except UnicodeDecodeError:
                    skipped += 1
                    return
            documents.append(
                Document(
                    page_content=page_content,
                    metadata={
                        "source_path": str(file_path),
                        "doc_id": f"{file_hash}",
                    },
                )
            )

        if path.is_file():
            _process_file(path)
        elif path.is_dir():
            for path_in_dir in sorted(path.rglob("*")):
                _process_file(path_in_dir)
        else:
            raise FileNotFoundError(f"Unsupported path type (not file/dir): {path}")

        allowed = ", ".join(sorted(ALLOWED_FILE_EXTENSIONS))
        if not documents:
            raise InvalidInputError(f"No supported files found. Allowed: {allowed}")
        if skipped > 0:
            print(
                f"Skipped {skipped} file(s) due to unsupported or unreadable type. Allowed: {allowed}"
            )
        return documents
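

# --- Example usage (illustrative sketch) ----------------------------------- #
# A minimal sketch of how the loader above might be called. The "./docs"
# folder and the helper name are assumptions for illustration only; they are
# not part of the released package.
def _example_load_documents() -> list[Document]:
    loader = PomaFileLoader("./docs")  # assumed sample directory
    docs = loader.load()
    for doc in docs[:3]:
        # Each Document carries the source path and an MD5-derived doc_id.
        print(doc.metadata["source_path"], doc.metadata["doc_id"])
    return docs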


# ------------------------------------------------------------------ #
#   Generate Chunksets                                               #
# ------------------------------------------------------------------ #


class PomaChunksetSplitter(TextSplitter):

    _client: Poma = PrivateAttr()
    _show_progress: bool = PrivateAttr(default=False)

    def __init__(self, client: Poma, *, verbose: bool = False, **kwargs):
        """Initialize with a Poma client and optional verbosity."""
        super().__init__(**kwargs)
        self._client = client
        self._show_progress = bool(verbose)

    def split_text(self, text: str) -> list[str]:
        """Not implemented, use split_documents()."""
        raise NotImplementedError("Not implemented, use split_documents().")

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """
        Split LangChain Documents into chunkset Documents via POMA API.
        Each output Document corresponds to a chunkset, with associated chunks in metadata.
        """
        documents = list(documents)
        if not documents:
            raise InvalidInputError("No documents provided to split.")

        total_docs = len(documents)
        chunked_docs: list[Document] = []
        failed_paths: list[str] = []

        def _safe_int(value: object) -> int | None:
            if isinstance(value, bool):
                return None
            if isinstance(value, int):
                return value
            if isinstance(value, str):
                try:
                    return int(value.strip())
                except Exception:
                    return None
            try:
                return int(value)  # type: ignore[arg-type]
            except Exception:
                return None

        def _doc_id_and_src(doc: Document) -> tuple[str, str]:
            src_path = doc.metadata.get("source_path", "in-memory-text")
            doc_id = doc.metadata.get("doc_id") or Path(src_path).stem or "unknown-doc"
            return doc_id, src_path

        def _process_one(
            poma_doc: Document, doc_idx: int
        ) -> tuple[list[Document], str | None]:
            """Process a single document via POMA API, return chunked Documents or failed source path."""
            try:
                doc_id, src_path = _doc_id_and_src(poma_doc)
                path_obj = None
                if isinstance(src_path, str) and src_path.strip():
                    try:
                        path = Path(src_path).resolve()
                        if path.exists():
                            path_obj = path
                    except Exception:
                        path_obj = None
                if not path_obj:
                    raise InvalidInputError(
                        "No valid source_path found in document metadata."
                    )
                start_result = self._client.start_chunk_file(path_obj, base_url=None)
                job_id = start_result.get("job_id")
                if not job_id:
                    raise RuntimeError("Failed to receive job ID from server.")
                if self._show_progress:
                    print(
                        f"[{doc_idx}/{total_docs}] ⏳ Job {job_id} started for: {src_path}. Polling for results..."
                    )
                result = self._client.get_chunk_result(
                    str(job_id), show_progress=self._show_progress
                )
                chunks: list[dict] = result.get("chunks", [])
                chunksets: list[dict] = result.get("chunksets", [])
            except Exception as exception:
                print(
                    f"[{doc_idx}/{total_docs}] ❌ Exception chunking document: {exception}"
                )
                src_path = poma_doc.metadata.get("source_path", "in-memory-text")
                return [], src_path

            file_docs: list[Document] = []
            try:
                chunks_by_index: dict[int, dict] = {}
                for chunk in chunks:
                    idx = _safe_int(chunk.get("chunk_index"))
                    if idx is not None:
                        chunks_by_index[idx] = chunk
                for cs in chunksets:
                    chunkset_index = cs.get("chunkset_index")
                    chunks_indices = cs.get("chunks", []) or []
                    normalized_indices: list[int] = []
                    for chunk_index in chunks_indices:
                        idx = _safe_int(chunk_index)
                        if idx is not None:
                            normalized_indices.append(idx)
                    relevant_chunks = [
                        chunks_by_index[idx]
                        for idx in normalized_indices
                        if idx in chunks_by_index
                    ]
                    file_docs.append(
                        Document(
                            page_content=cs.get("contents", ""),
                            metadata={
                                "doc_id": doc_id,
                                "chunkset_index": chunkset_index,
                                "chunkset": cs,
                                "chunks": relevant_chunks,
                                "source_path": src_path,
                            },
                        )
                    )
            except Exception as exception:
                print(
                    f"[{doc_idx}/{total_docs}] ❌ Exception processing chunking result: {exception}"
                )
                src_path = poma_doc.metadata.get("source_path", "in-memory-text")
                return [], src_path
            return file_docs, None

        # parallel processing of documents
        cores = os.cpu_count() or 1
        group_size = 5 if cores >= 5 else cores
        for start in range(0, total_docs, group_size):
            batch = list(
                enumerate(documents[start : start + group_size], start=start + 1)
            )
            with ThreadPoolExecutor(max_workers=group_size) as executor:
                futures = {
                    executor.submit(_process_one, doc, idx): (idx, doc)
                    for idx, doc in batch
                }
                for future in as_completed(futures):
                    idx, doc = futures[future]
                    try:
                        doc_as_chunk, failed_src = future.result()
                        if failed_src is None:
                            chunked_docs.extend(doc_as_chunk)
                            if self._show_progress:
                                src_path = doc.metadata.get(
                                    "source_path", "in-memory-text"
                                )
                                print(
                                    f"[{idx}/{total_docs}] ✅ Done: {src_path} (+{len(doc_as_chunk)} doc-chunks)"
                                )
                        else:
                            failed_paths.append(failed_src)
                            if self._show_progress:
                                print(f"[{idx}/{total_docs}] ❌ Failed: {failed_src}")
                    except Exception as error:
                        failed_paths.append(
                            doc.metadata.get("source_path", "in-memory-text")
                        )
                        if self._show_progress:
                            print(
                                f"[{idx}/{total_docs}] ❌ Failed with unexpected error: {error}"
                            )

        if failed_paths:
            print("The following files failed to process:")
            for path in failed_paths:
                print(f" - {path}")

        if not chunked_docs:
            raise InvalidInputError("No documents could be split successfully.")

        return chunked_docs
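

# --- Example usage (illustrative sketch) ----------------------------------- #
# Feeds loader output through the splitter above. How the Poma client is
# constructed is an assumption here (an API key read from the environment);
# consult the poma client documentation for the actual setup.
def _example_split_documents() -> list[Document]:
    client = Poma(api_key=os.environ["POMA_API_KEY"])  # assumed constructor
    splitter = PomaChunksetSplitter(client, verbose=True)
    chunkset_docs = splitter.split_documents(PomaFileLoader("./docs").load())
    for doc in chunkset_docs[:2]:
        # Each output Document is one chunkset; its chunks ride along in metadata.
        print(
            doc.metadata["doc_id"],
            doc.metadata["chunkset_index"],
            len(doc.metadata["chunks"]),
        )
    return chunkset_docs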


# ------------------------------------------------------------------ #
#   Cheatsheet retriever                                             #
# ------------------------------------------------------------------ #


class PomaCheatsheetRetrieverLC(BaseRetriever):

    tags: list[str] | None = Field(default=None)
    metadata: dict[str, Any] | None = Field(default=None)

    _vector_store: VectorStore = PrivateAttr()
    _top_k: int = PrivateAttr()

    def __init__(
        self,
        vector_store: VectorStore,
        *,
        top_k: int = 6,
        **kwargs,
    ):
        """Initialize with a VectorStore and number of top_k results to retrieve."""
        super().__init__(**kwargs)
        self._vector_store = vector_store
        self._top_k = top_k

    def _retrieve(self, query: str) -> list[Document]:
        """Retrieve chunkset documents and generate cheatsheets for the given query."""
        hits = self._vector_store.similarity_search(query, k=self._top_k)
        if not hits:
            return []
        grouped: dict[str, list[Document]] = {}
        for doc in hits:
            doc_id = doc.metadata["doc_id"]
            grouped.setdefault(doc_id, []).append(doc)
        cheatsheet_docs: list[Document] = []
        for doc_id, chunked_docs in grouped.items():
            cheatsheet = self._create_cheatsheet_langchain(chunked_docs)
            cheatsheet_docs.append(
                Document(page_content=cheatsheet, metadata={"doc_id": doc_id})
            )
        return cheatsheet_docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Retrieve relevant documents with callback management."""
        try:
            documents = self._retrieve(query)
            run_manager.on_retriever_end(documents)
            return documents
        except Exception as exception:
            run_manager.on_retriever_error(exception)
            raise

    def _create_cheatsheet_langchain(self, chunked_docs: list[Document]) -> str:
        """Generate a single deduplicated cheatsheet from chunked documents."""
        all_chunks = []
        seen = set()
        for doc in chunked_docs:
            doc_id = doc.metadata.get("doc_id", "unknown_doc")
            for chunk in doc.metadata.get("chunks", []):
                if not isinstance(chunk, dict):
                    continue
                chunk_index = chunk["chunk_index"]
                if chunk_index not in seen:
                    seen.add(chunk_index)
                    chunk["tag"] = doc_id
                    all_chunks.append(chunk)
        sorted_chunks = sorted(
            all_chunks, key=lambda chunk: (chunk["tag"], chunk["chunk_index"])
        )
        cheatsheets = _cheatsheets_from_chunks(sorted_chunks)
        if (
            not cheatsheets
            or not isinstance(cheatsheets, list)
            or len(cheatsheets) == 0
            or "content" not in cheatsheets[0]
        ):
            raise Exception(
                "Unknown error; cheatsheet could not be created from input chunks."
            )
        return cheatsheets[0]["content"]
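

# --- Example: end-to-end retrieval (illustrative sketch) -------------------- #
# Embeds chunkset Documents in a vector store and asks the retriever above for
# cheatsheets. InMemoryVectorStore, OpenAIEmbeddings (requires the
# langchain-openai package and an OpenAI API key) and the query string are
# assumptions for illustration; any LangChain VectorStore can be substituted.
def _example_retrieve_cheatsheets(chunkset_docs: list[Document]) -> list[Document]:
    from langchain_core.vectorstores import InMemoryVectorStore
    from langchain_openai import OpenAIEmbeddings

    vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    vector_store.add_documents(chunkset_docs)
    retriever = PomaCheatsheetRetrieverLC(vector_store, top_k=6)
    cheatsheets = retriever.invoke("How do I configure the service?")
    for doc in cheatsheets:
        # One cheatsheet Document per source doc_id found among the hits.
        print(doc.metadata["doc_id"], doc.page_content[:200])
    return cheatsheets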