poma 0.0.0 (poma-0.0.0-py3-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- poma/__init__.py +15 -0
- poma/client.py +353 -0
- poma/exceptions.py +20 -0
- poma/integrations/__init__.py +20 -0
- poma/integrations/langchain_poma.py +358 -0
- poma/integrations/llamaindex_poma.py +361 -0
- poma/retrieval.py +176 -0
- poma-0.0.0.dist-info/METADATA +66 -0
- poma-0.0.0.dist-info/RECORD +12 -0
- poma-0.0.0.dist-info/WHEEL +5 -0
- poma-0.0.0.dist-info/licenses/LICENSE +177 -0
- poma-0.0.0.dist-info/top_level.txt +1 -0
poma/integrations/langchain_poma.py
@@ -0,0 +1,358 @@
+# ---------------------------------------------------------------------
+# POMA integration for LangChain
+# ---------------------------------------------------------------------
+
+import os
+import hashlib
+from typing import Any
+from pathlib import Path
+from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+from langchain.schema.retriever import BaseRetriever
+from langchain_core.vectorstores import VectorStore
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain_text_splitters import TextSplitter
+from pydantic import Field, PrivateAttr
+
+from poma import Poma
+from poma.client import ALLOWED_FILE_EXTENSIONS
+from poma.exceptions import InvalidInputError
+from poma.retrieval import _cheatsheets_from_chunks
+
+__all__ = ["PomaFileLoader", "PomaChunksetSplitter", "PomaCheatsheetRetrieverLC"]
+
+
+# ------------------------------------------------------------------ #
+# Load from Path → LC Documents                                       #
+# ------------------------------------------------------------------ #
+
+
+class PomaFileLoader(BaseLoader):
+
+    def __init__(self, input_path: str | Path):
+        """Initialize with a file or directory path."""
+        self.input_path = Path(input_path).expanduser().resolve()
+
+    def load(self) -> list[Document]:
+        """
+        Load files from the input path (file or directory) into LangChain Documents.
+        Only files with allowed extensions are processed; others are skipped.
+        """
+        path = self.input_path
+        if not path.exists():
+            raise FileNotFoundError(f"No such path: {path}")
+
+        documents: list[Document] = []
+        skipped: int = 0
+
+        def _process_file(file_path: Path):
+            nonlocal skipped, documents
+            if not file_path.is_file():
+                return
+            file_extension = file_path.suffix.lower()
+            if not file_extension or file_extension not in ALLOWED_FILE_EXTENSIONS:
+                skipped += 1
+                return
+            file_bytes = file_path.read_bytes()
+            file_hash = hashlib.md5(file_bytes).hexdigest()
+            if file_path.suffix.lower() == ".pdf":
+                page_content: str = ""  # LangChain requires str
+            else:
+                try:
+                    page_content = file_bytes.decode("utf-8")
+                except UnicodeDecodeError:
+                    skipped += 1
+                    return
+            documents.append(
+                Document(
+                    page_content=page_content,
+                    metadata={
+                        "source_path": str(file_path),
+                        "doc_id": f"{file_hash}",
+                    },
+                )
+            )
+
+        if path.is_file():
+            _process_file(path)
+        elif path.is_dir():
+            for path_in_dir in sorted(path.rglob("*")):
+                _process_file(path_in_dir)
+        else:
+            raise FileNotFoundError(f"Unsupported path type (not file/dir): {path}")
+
+        allowed = ", ".join(sorted(ALLOWED_FILE_EXTENSIONS))
+        if not documents:
+            raise InvalidInputError(f"No supported files found. Allowed: {allowed}")
+        if skipped > 0:
+            print(
+                f"Skipped {skipped} file(s) due to unsupported or unreadable type. Allowed: {allowed}"
+            )
+        return documents
+
+
+# ------------------------------------------------------------------ #
+# Generate Chunksets                                                  #
+# ------------------------------------------------------------------ #
+
+
+class PomaChunksetSplitter(TextSplitter):
+
+    _client: Poma = PrivateAttr()
+    _show_progress: bool = PrivateAttr(default=False)
+
+    def __init__(self, client: Poma, *, verbose: bool = False, **kwargs):
+        """Initialize with a Poma client and optional verbosity."""
+        super().__init__(**kwargs)
+        self._client = client
+        self._show_progress = bool(verbose)
+
+    def split_text(self, text: str) -> list[str]:
+        """Not implemented, use split_documents()."""
+        raise NotImplementedError("Not implemented, use split_documents().")
+
+    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
+        """
+        Split LangChain Documents into chunkset Documents via POMA API.
+        Each output Document corresponds to a chunkset, with associated chunks in metadata.
+        """
+        documents = list(documents)
+        if not documents:
+            raise InvalidInputError("No documents provided to split.")
+
+        total_docs = len(documents)
+        chunked_docs: list[Document] = []
+        failed_paths: list[str] = []
+
+        def _safe_int(value: object) -> int | None:
+            if isinstance(value, bool):
+                return None
+            if isinstance(value, int):
+                return value
+            if isinstance(value, str):
+                try:
+                    return int(value.strip())
+                except Exception:
+                    return None
+            try:
+                return int(value)  # type: ignore[arg-type]
+            except Exception:
+                return None
+
+        def _doc_id_and_src(doc: Document) -> tuple[str, str]:
+            src_path = doc.metadata.get("source_path", "in-memory-text")
+            doc_id = doc.metadata.get("doc_id") or Path(src_path).stem or "unknown-doc"
+            return doc_id, src_path
+
+        def _process_one(
+            poma_doc: Document, doc_idx: int
+        ) -> tuple[list[Document], str | None]:
+            """Process a single document via POMA API, return chunked Documents or failed source path."""
+            try:
+                doc_id, src_path = _doc_id_and_src(poma_doc)
+                path_obj = None
+                if src_path and src_path.strip() and isinstance(src_path, str):
+                    try:
+                        path = Path(src_path).resolve()
+                        if path.exists():
+                            path_obj = path
+                    except Exception:
+                        path_obj = None
+                if not path_obj:
+                    raise InvalidInputError(
+                        "No valid source_path found in document metadata."
+                    )
+                start_result = self._client.start_chunk_file(path_obj, base_url=None)
+                job_id = start_result.get("job_id")
+                if not job_id:
+                    raise RuntimeError("Failed to receive job ID from server.")
+                if self._show_progress:
+                    print(
+                        f"[{doc_idx}/{total_docs}] ⏳ Job {job_id} started for: {src_path}. Polling for results..."
+                    )
+                result = self._client.get_chunk_result(
+                    str(job_id), show_progress=self._show_progress
+                )
+                chunks: list[dict] = result.get("chunks", [])
+                chunksets: list[dict] = result.get("chunksets", [])
+            except Exception as exception:
+                print(
+                    f"[{doc_idx}/{total_docs}] ❌ Exception chunking document: {exception}"
+                )
+                src_path = poma_doc.metadata.get("source_path", "in-memory-text")
+                return [], src_path
+
+            file_docs: list[Document] = []
+            try:
+                chunks_by_index: dict[int, dict] = {}
+                for chunk in chunks:
+                    idx = _safe_int(chunk.get("chunk_index"))
+                    if idx is not None:
+                        chunks_by_index[idx] = chunk
+                for cs in chunksets:
+                    chunkset_index = cs.get("chunkset_index")
+                    chunks_indices = cs.get("chunks", []) or []
+                    normalized_indices: list[int] = []
+                    for chunk_index in chunks_indices:
+                        idx = _safe_int(chunk_index)
+                        if idx is not None:
+                            normalized_indices.append(idx)
+                    relevant_chunks = [
+                        chunks_by_index[idx]
+                        for idx in normalized_indices
+                        if idx in chunks_by_index
+                    ]
+                    file_docs.append(
+                        Document(
+                            page_content=cs.get("contents", ""),
+                            metadata={
+                                "doc_id": doc_id,
+                                "chunkset_index": chunkset_index,
+                                "chunkset": cs,
+                                "chunks": relevant_chunks,
+                                "source_path": src_path,
+                            },
+                        )
+                    )
+            except Exception as exception:
+                print(
+                    f"[{doc_idx}/{total_docs}] ❌ Exception processing chunking result: {exception}"
+                )
+                src_path = poma_doc.metadata.get("source_path", "in-memory-text")
+                return [], src_path
+            return file_docs, None
+
+        # parallel processing of documents
+        cores = os.cpu_count() or 1
+        group_size = 5 if cores >= 5 else cores
+        for start in range(0, total_docs, group_size):
+            batch = list(
+                enumerate(documents[start : start + group_size], start=start + 1)
+            )
+            with ThreadPoolExecutor(max_workers=group_size) as executor:
+                futures = {
+                    executor.submit(_process_one, doc, idx): (idx, doc)
+                    for idx, doc in batch
+                }
+                for future in as_completed(futures):
+                    idx, doc = futures[future]
+                    try:
+                        doc_as_chunk, failed_src = future.result()
+                        if failed_src is None:
+                            chunked_docs.extend(doc_as_chunk)
+                            if self._show_progress:
+                                src_path = doc.metadata.get(
+                                    "source_path", "in-memory-text"
+                                )
+                                print(
+                                    f"[{idx}/{total_docs}] ✅ Done: {src_path} (+{len(doc_as_chunk)} doc-chunks)"
+                                )
+                        else:
+                            failed_paths.append(failed_src)
+                            if self._show_progress:
+                                print(f"[{idx}/{total_docs}] ❌ Failed: {failed_src}")
+                    except Exception as error:
+                        failed_paths.append(
+                            doc.metadata.get("source_path", "in-memory-text")
+                        )
+                        if self._show_progress:
+                            print(
+                                f"[{idx}/{total_docs}] ❌ Failed with unexpected error: {error}"
+                            )
+
+        if failed_paths:
+            print("The following files failed to process:")
+            for path in failed_paths:
+                print(f" - {path}")
+
+        if not chunked_docs:
+            raise InvalidInputError("No documents could be split successfully.")
+
+        return chunked_docs
+
+
+# ------------------------------------------------------------------ #
+# Cheatsheet retriever                                                 #
+# ------------------------------------------------------------------ #
+
+
+class PomaCheatsheetRetrieverLC(BaseRetriever):
+
+    tags: list[str] | None = Field(default=None)
+    metadata: dict[str, Any] | None = Field(default=None)
+
+    _vector_store: VectorStore = PrivateAttr()
+    _top_k: int = PrivateAttr()
+
+    def __init__(
+        self,
+        vector_store: VectorStore,
+        *,
+        top_k: int = 6,
+        **kwargs,
+    ):
+        """Initialize with a VectorStore and number of top_k results to retrieve."""
+        super().__init__(**kwargs)
+        self._vector_store = vector_store
+        self._top_k = top_k
+
+    def _retrieve(self, query: str) -> list[Document]:
+        """Retrieve chunkset documents and generate cheatsheets for the given query."""
+        hits = self._vector_store.similarity_search(query, k=self._top_k)
+        if not hits:
+            return []
+        grouped: dict[str, list[Document]] = {}
+        for doc in hits:
+            doc_id = doc.metadata["doc_id"]
+            grouped.setdefault(doc_id, []).append(doc)
+        cheatsheet_docs: list[Document] = []
+        for doc_id, chunked_docs in grouped.items():
+            cheatsheet = self._create_cheatsheet_langchain(chunked_docs)
+            cheatsheet_docs.append(
+                Document(page_content=cheatsheet, metadata={"doc_id": doc_id})
+            )
+        return cheatsheet_docs
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> list[Document]:
+        """Retrieve relevant documents with callback management."""
+        try:
+            documents = self._retrieve(query)
+            run_manager.on_retriever_end(documents)
+            return documents
+        except Exception as exception:
+            run_manager.on_retriever_error(exception)
+            raise
+
+    def _create_cheatsheet_langchain(self, chunked_docs: list[Document]) -> str:
+        """Generate a single deduplicated cheatsheet from chunked documents."""
+        all_chunks = []
+        seen = set()
+        for doc in chunked_docs:
+            doc_id = doc.metadata.get("doc_id", "unknown_doc")
+            for chunk in doc.metadata.get("chunks", []):
+                if not isinstance(chunk, dict):
+                    continue
+                chunk_index = chunk["chunk_index"]
+                if chunk_index not in seen:
+                    seen.add(chunk_index)
+                    chunk["tag"] = doc_id
+                    all_chunks.append(chunk)
+        sorted_chunks = sorted(
+            all_chunks, key=lambda chunk: (chunk["tag"], chunk["chunk_index"])
+        )
+        cheatsheets = _cheatsheets_from_chunks(sorted_chunks)
+        if (
+            not cheatsheets
+            or not isinstance(cheatsheets, list)
+            or len(cheatsheets) == 0
+            or "content" not in cheatsheets[0]
+        ):
+            raise Exception(
+                "Unknown error; cheatsheet could not be created from input chunks."
+            )
+        return cheatsheets[0]["content"]