ai-pipeline-core 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
- ai_pipeline_core/__init__.py +64 -158
- ai_pipeline_core/deployment/__init__.py +6 -18
- ai_pipeline_core/deployment/base.py +392 -212
- ai_pipeline_core/deployment/contract.py +6 -10
- ai_pipeline_core/{utils → deployment}/deploy.py +50 -69
- ai_pipeline_core/deployment/helpers.py +16 -17
- ai_pipeline_core/{progress.py → deployment/progress.py} +23 -24
- ai_pipeline_core/{utils/remote_deployment.py → deployment/remote.py} +11 -14
- ai_pipeline_core/docs_generator/__init__.py +54 -0
- ai_pipeline_core/docs_generator/__main__.py +5 -0
- ai_pipeline_core/docs_generator/cli.py +196 -0
- ai_pipeline_core/docs_generator/extractor.py +324 -0
- ai_pipeline_core/docs_generator/guide_builder.py +644 -0
- ai_pipeline_core/docs_generator/trimmer.py +35 -0
- ai_pipeline_core/docs_generator/validator.py +114 -0
- ai_pipeline_core/document_store/__init__.py +13 -0
- ai_pipeline_core/document_store/_summary.py +9 -0
- ai_pipeline_core/document_store/_summary_worker.py +170 -0
- ai_pipeline_core/document_store/clickhouse.py +492 -0
- ai_pipeline_core/document_store/factory.py +38 -0
- ai_pipeline_core/document_store/local.py +312 -0
- ai_pipeline_core/document_store/memory.py +85 -0
- ai_pipeline_core/document_store/protocol.py +68 -0
- ai_pipeline_core/documents/__init__.py +12 -14
- ai_pipeline_core/documents/_context_vars.py +85 -0
- ai_pipeline_core/documents/_hashing.py +52 -0
- ai_pipeline_core/documents/attachment.py +85 -0
- ai_pipeline_core/documents/context.py +128 -0
- ai_pipeline_core/documents/document.py +318 -1434
- ai_pipeline_core/documents/mime_type.py +11 -84
- ai_pipeline_core/documents/utils.py +4 -12
- ai_pipeline_core/exceptions.py +10 -62
- ai_pipeline_core/images/__init__.py +32 -85
- ai_pipeline_core/images/_processing.py +5 -11
- ai_pipeline_core/llm/__init__.py +6 -4
- ai_pipeline_core/llm/ai_messages.py +102 -90
- ai_pipeline_core/llm/client.py +229 -183
- ai_pipeline_core/llm/model_options.py +12 -84
- ai_pipeline_core/llm/model_response.py +53 -99
- ai_pipeline_core/llm/model_types.py +8 -23
- ai_pipeline_core/logging/__init__.py +2 -7
- ai_pipeline_core/logging/logging.yml +1 -1
- ai_pipeline_core/logging/logging_config.py +27 -37
- ai_pipeline_core/logging/logging_mixin.py +15 -41
- ai_pipeline_core/observability/__init__.py +32 -0
- ai_pipeline_core/observability/_debug/__init__.py +30 -0
- ai_pipeline_core/observability/_debug/_auto_summary.py +94 -0
- ai_pipeline_core/{debug/config.py → observability/_debug/_config.py} +11 -7
- ai_pipeline_core/{debug/content.py → observability/_debug/_content.py} +133 -75
- ai_pipeline_core/{debug/processor.py → observability/_debug/_processor.py} +16 -17
- ai_pipeline_core/{debug/summary.py → observability/_debug/_summary.py} +113 -37
- ai_pipeline_core/observability/_debug/_types.py +75 -0
- ai_pipeline_core/{debug/writer.py → observability/_debug/_writer.py} +126 -196
- ai_pipeline_core/observability/_document_tracking.py +146 -0
- ai_pipeline_core/observability/_initialization.py +194 -0
- ai_pipeline_core/observability/_logging_bridge.py +57 -0
- ai_pipeline_core/observability/_summary.py +81 -0
- ai_pipeline_core/observability/_tracking/__init__.py +6 -0
- ai_pipeline_core/observability/_tracking/_client.py +178 -0
- ai_pipeline_core/observability/_tracking/_internal.py +28 -0
- ai_pipeline_core/observability/_tracking/_models.py +138 -0
- ai_pipeline_core/observability/_tracking/_processor.py +158 -0
- ai_pipeline_core/observability/_tracking/_service.py +311 -0
- ai_pipeline_core/observability/_tracking/_writer.py +229 -0
- ai_pipeline_core/{tracing.py → observability/tracing.py} +139 -335
- ai_pipeline_core/pipeline/__init__.py +10 -0
- ai_pipeline_core/pipeline/decorators.py +915 -0
- ai_pipeline_core/pipeline/options.py +16 -0
- ai_pipeline_core/prompt_manager.py +16 -102
- ai_pipeline_core/settings.py +26 -31
- ai_pipeline_core/testing.py +9 -0
- ai_pipeline_core-0.4.0.dist-info/METADATA +807 -0
- ai_pipeline_core-0.4.0.dist-info/RECORD +76 -0
- ai_pipeline_core/debug/__init__.py +0 -26
- ai_pipeline_core/documents/document_list.py +0 -420
- ai_pipeline_core/documents/flow_document.py +0 -112
- ai_pipeline_core/documents/task_document.py +0 -117
- ai_pipeline_core/documents/temporary_document.py +0 -74
- ai_pipeline_core/flow/__init__.py +0 -9
- ai_pipeline_core/flow/config.py +0 -494
- ai_pipeline_core/flow/options.py +0 -75
- ai_pipeline_core/pipeline.py +0 -718
- ai_pipeline_core/prefect.py +0 -63
- ai_pipeline_core/prompt_builder/__init__.py +0 -5
- ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +0 -23
- ai_pipeline_core/prompt_builder/global_cache.py +0 -78
- ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +0 -6
- ai_pipeline_core/prompt_builder/prompt_builder.py +0 -253
- ai_pipeline_core/prompt_builder/system_prompt.jinja2 +0 -41
- ai_pipeline_core/storage/__init__.py +0 -8
- ai_pipeline_core/storage/storage.py +0 -628
- ai_pipeline_core/utils/__init__.py +0 -8
- ai_pipeline_core-0.3.4.dist-info/METADATA +0 -569
- ai_pipeline_core-0.3.4.dist-info/RECORD +0 -57
- {ai_pipeline_core-0.3.4.dist-info → ai_pipeline_core-0.4.0.dist-info}/WHEEL +0 -0
- {ai_pipeline_core-0.3.4.dist-info → ai_pipeline_core-0.4.0.dist-info}/licenses/LICENSE +0 -0
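The 0.4.0 layout regroups the former top-level modules into subpackages (deployment, document_store, observability, pipeline). A minimal orientation sketch of the new module paths, based only on the file list above; exported symbol names are not visible in this diff and are therefore omitted:

import ai_pipeline_core.deployment.progress      # was ai_pipeline_core/progress.py
import ai_pipeline_core.observability.tracing    # was ai_pipeline_core/tracing.py
import ai_pipeline_core.pipeline.decorators      # replaces the single ai_pipeline_core/pipeline.py module
import ai_pipeline_core.document_store           # new subpackage: local, memory, and clickhouse backends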
ai_pipeline_core/document_store/clickhouse.py (new file)

@@ -0,0 +1,492 @@
"""ClickHouse-backed document store for production use.

Two-table schema: document_content (deduplicated blobs) and document_index
(per-run document metadata). Uses ReplacingMergeTree for idempotent writes.
Circuit breaker buffers writes when ClickHouse is unavailable.
"""

import asyncio
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any

import clickhouse_connect

from ai_pipeline_core.document_store._summary import SummaryGenerator
from ai_pipeline_core.document_store._summary_worker import SummaryWorker
from ai_pipeline_core.documents._context_vars import suppress_registration
from ai_pipeline_core.documents._hashing import compute_content_sha256, compute_document_sha256
from ai_pipeline_core.documents.attachment import Attachment
from ai_pipeline_core.documents.document import Document
from ai_pipeline_core.logging import get_pipeline_logger

logger = get_pipeline_logger(__name__)

TABLE_DOCUMENT_CONTENT = "document_content"
TABLE_DOCUMENT_INDEX = "document_index"

_DDL_CONTENT = f"""
CREATE TABLE IF NOT EXISTS {TABLE_DOCUMENT_CONTENT}
(
    content_sha256 String,
    content String CODEC(ZSTD(3)),
    created_at DateTime64(3, 'UTC'),
    INDEX content_sha256_idx content_sha256 TYPE bloom_filter GRANULARITY 1
)
ENGINE = ReplacingMergeTree()
ORDER BY (content_sha256)
SETTINGS index_granularity = 8192
"""

_DDL_INDEX = f"""
CREATE TABLE IF NOT EXISTS {TABLE_DOCUMENT_INDEX}
(
    document_sha256 String,
    run_scope String,
    content_sha256 String,
    class_name LowCardinality(String),
    name String,
    description String DEFAULT '',
    mime_type LowCardinality(String),
    sources Array(String),
    origins Array(String),
    attachment_names Array(String),
    attachment_descriptions Array(String),
    attachment_sha256s Array(String),
    summary String DEFAULT '',
    stored_at DateTime64(3, 'UTC'),
    version UInt64 DEFAULT 1,
    INDEX doc_sha256_idx document_sha256 TYPE bloom_filter GRANULARITY 1
)
ENGINE = ReplacingMergeTree(version)
ORDER BY (run_scope, class_name, document_sha256)
SETTINGS index_granularity = 8192
"""

_MAX_BUFFER_SIZE = 10_000
_RECONNECT_INTERVAL_SEC = 60
_FAILURE_THRESHOLD = 3


@dataclass
class _BufferedWrite:
    """A pending write operation buffered during circuit breaker open state."""

    document: Document
    run_scope: str


class ClickHouseDocumentStore:
    """ClickHouse-backed document store with circuit breaker.

    All sync operations run on a single-thread executor (max_workers=1),
    so circuit breaker state needs no locking. Async methods dispatch to
    this executor via loop.run_in_executor().
    """

    def __init__(
        self,
        *,
        host: str,
        port: int = 8443,
        database: str = "default",
        username: str = "default",
        password: str = "",
        secure: bool = True,
        summary_generator: SummaryGenerator | None = None,
    ) -> None:
        self._params = {
            "host": host,
            "port": port,
            "database": database,
            "username": username,
            "password": password,
            "secure": secure,
        }
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="ch-docstore")
        self._client: Any = None
        self._tables_initialized = False

        # Circuit breaker state (accessed only from the single executor thread)
        self._consecutive_failures = 0
        self._circuit_open = False
        self._last_reconnect_attempt = 0.0
        self._buffer: deque[_BufferedWrite] = deque(maxlen=_MAX_BUFFER_SIZE)

        # Summary worker
        self._summary_worker: SummaryWorker | None = None
        if summary_generator:
            self._summary_worker = SummaryWorker(
                generator=summary_generator,
                update_fn=self.update_summary,
            )
            self._summary_worker.start()

    async def _run(self, fn: Any, *args: Any) -> Any:
        """Run a sync function on the dedicated executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, fn, *args)

    # --- Connection management (sync, executor thread only) ---

    def _connect(self) -> None:
        self._client = clickhouse_connect.get_client(  # pyright: ignore[reportUnknownMemberType]
            **self._params,  # pyright: ignore[reportArgumentType]
        )
        logger.info(f"Document store connected to ClickHouse at {self._params['host']}:{self._params['port']}")

    def _ensure_tables(self) -> None:
        if self._tables_initialized:
            return
        if self._client is None:
            self._connect()
        self._client.command(_DDL_CONTENT)
        self._client.command(_DDL_INDEX)
        self._tables_initialized = True
        logger.info("Document store tables verified/created")

    def _ensure_connected(self) -> bool:
        try:
            if self._client is None:
                self._connect()
            self._ensure_tables()
            return True
        except Exception as e:
            logger.warning(f"ClickHouse connection failed: {e}")
            self._client = None
            self._tables_initialized = False
            return False

    def _try_reconnect(self) -> bool:
        now = time.monotonic()
        if now - self._last_reconnect_attempt < _RECONNECT_INTERVAL_SEC:
            return False
        self._last_reconnect_attempt = now
        if self._ensure_connected():
            self._circuit_open = False
            self._consecutive_failures = 0
            logger.info("ClickHouse reconnected, flushing buffer")
            return True
        return False

    def _record_success(self) -> None:
        self._consecutive_failures = 0
        if self._circuit_open:
            self._circuit_open = False
            logger.info("Circuit breaker closed")

    def _record_failure(self) -> None:
        self._consecutive_failures += 1
        if self._consecutive_failures >= _FAILURE_THRESHOLD and not self._circuit_open:
            self._circuit_open = True
            self._client = None
            self._tables_initialized = False
            logger.warning(f"Circuit breaker opened after {self._consecutive_failures} failures")

    # --- Async public API ---

    async def save(self, document: Document, run_scope: str) -> None:
        """Save a document. Buffers writes when circuit breaker is open."""
        await self._run(self._save_sync, document, run_scope)
        if self._summary_worker and not self._circuit_open:
            self._summary_worker.schedule(run_scope, document)

    async def save_batch(self, documents: list[Document], run_scope: str) -> None:
        """Save multiple documents. Remaining docs are buffered on failure."""
        await self._run(self._save_batch_sync, documents, run_scope)
        if self._summary_worker and not self._circuit_open:
            for doc in documents:
                self._summary_worker.schedule(run_scope, doc)

    async def load(self, run_scope: str, document_types: list[type[Document]]) -> list[Document]:
        """Load documents via index JOIN content, then batch-fetch attachments."""
        return await self._run(self._load_sync, run_scope, document_types)

    async def has_documents(self, run_scope: str, document_type: type[Document]) -> bool:
        """Check if any documents of this type exist in the run scope."""
        return await self._run(self._has_documents_sync, run_scope, document_type)

    async def check_existing(self, sha256s: list[str]) -> set[str]:
        """Return the subset of sha256s that exist in the document index."""
        return await self._run(self._check_existing_sync, sha256s)

    async def update_summary(self, run_scope: str, document_sha256: str, summary: str) -> None:
        """Update summary column for a stored document via ALTER TABLE UPDATE."""
        await self._run(self._update_summary_sync, run_scope, document_sha256, summary)

    async def load_summaries(self, run_scope: str, document_sha256s: list[str]) -> dict[str, str]:
        """Load summaries by SHA256 from the document index."""
        return await self._run(self._load_summaries_sync, run_scope, document_sha256s)

    def flush(self) -> None:
        """Block until all pending document summaries are processed."""
        if self._summary_worker:
            self._summary_worker.flush()

    def shutdown(self) -> None:
        """Flush pending summaries, stop the summary worker, and release the executor."""
        if self._summary_worker:
            self._summary_worker.shutdown()
        self._executor.shutdown(wait=True)

    # --- Sync implementations (executor thread only) ---

    def _save_sync(self, document: Document, run_scope: str) -> None:
        if self._circuit_open:
            if not self._try_reconnect():
                self._buffer.append(_BufferedWrite(document=document, run_scope=run_scope))
                return
            self._flush_buffer()

        try:
            self._ensure_tables()
            self._insert_document(document, run_scope)
            self._record_success()
        except Exception as e:
            logger.warning(f"Failed to save document '{document.name}': {e}")
            self._record_failure()
            self._buffer.append(_BufferedWrite(document=document, run_scope=run_scope))

    def _save_batch_sync(self, documents: list[Document], run_scope: str) -> None:
        if self._circuit_open:
            if not self._try_reconnect():
                for doc in documents:
                    self._buffer.append(_BufferedWrite(document=doc, run_scope=run_scope))
                return
            self._flush_buffer()

        try:
            self._ensure_tables()
        except Exception as e:
            logger.warning(f"Failed to ensure tables for batch: {e}")
            self._record_failure()
            for doc in documents:
                self._buffer.append(_BufferedWrite(document=doc, run_scope=run_scope))
            return

        for i, doc in enumerate(documents):
            try:
                self._insert_document(doc, run_scope)
            except Exception as e:
                logger.warning(f"Failed to save document '{doc.name}' in batch: {e}")
                self._record_failure()
                for remaining in documents[i:]:
                    self._buffer.append(_BufferedWrite(document=remaining, run_scope=run_scope))
                return
        self._record_success()

    def _flush_buffer(self) -> None:
        while self._buffer:
            item = self._buffer.popleft()
            try:
                self._insert_document(item.document, item.run_scope)
                if self._summary_worker:
                    self._summary_worker.schedule(item.run_scope, item.document)
            except Exception as e:
                logger.warning(f"Failed to flush buffered document: {e}")
                self._buffer.appendleft(item)
                break

    def _insert_document(self, document: Document, run_scope: str) -> None:
        doc_sha256 = compute_document_sha256(document)
        content_sha256 = compute_content_sha256(document.content)

        # Insert content using insert() for binary-safe handling (idempotent via ReplacingMergeTree)
        now = datetime.now(UTC)
        content_rows: list[list[Any]] = [
            [content_sha256, document.content, now],
        ]

        # Collect attachment content
        att_names: list[str] = []
        att_descriptions: list[str] = []
        att_sha256s: list[str] = []
        for att in sorted(document.attachments, key=lambda a: a.name):
            att_sha = compute_content_sha256(att.content)
            att_names.append(att.name)
            att_descriptions.append(att.description or "")
            att_sha256s.append(att_sha)
            content_rows.append([att_sha, att.content, now])

        self._client.insert(
            TABLE_DOCUMENT_CONTENT,
            content_rows,
            column_names=["content_sha256", "content", "created_at"],
        )

        # Insert index entry
        self._client.command(
            f"INSERT INTO {TABLE_DOCUMENT_INDEX} "
            "(document_sha256, run_scope, content_sha256, class_name, name, description, "
            "mime_type, sources, origins, "
            "attachment_names, attachment_descriptions, attachment_sha256s, stored_at, version) "
            "VALUES ("
            "{doc_sha256:String}, {run_scope:String}, {content_sha256:String}, "
            "{class_name:String}, {name:String}, {description:String}, "
            "{mime_type:String}, "
            "{sources:Array(String)}, {origins:Array(String)}, "
            "{att_names:Array(String)}, {att_descs:Array(String)}, {att_sha256s:Array(String)}, "
            "now64(3), 1)",
            parameters={
                "doc_sha256": doc_sha256,
                "run_scope": run_scope,
                "content_sha256": content_sha256,
                "class_name": document.__class__.__name__,
                "name": document.name,
                "description": document.description or "",
                "mime_type": document.mime_type,
                "sources": list(document.sources),
                "origins": list(document.origins),
                "att_names": att_names,
                "att_descs": att_descriptions,
                "att_sha256s": att_sha256s,
            },
        )

    def _load_sync(self, run_scope: str, document_types: list[type[Document]]) -> list[Document]:
        """Two-query load: index JOIN content, then batch attachment fetch."""
        self._ensure_tables()

        type_by_name: dict[str, type[Document]] = {t.__name__: t for t in document_types}
        class_names = list(type_by_name.keys())

        # Query 1: index JOIN content for document bodies
        rows = self._client.query(
            f"SELECT di.class_name, di.name, di.description, di.sources, di.origins, "
            f"di.attachment_names, di.attachment_descriptions, di.attachment_sha256s, "
            f"dc.content "
            f"FROM {TABLE_DOCUMENT_INDEX} AS di FINAL "
            f"JOIN {TABLE_DOCUMENT_CONTENT} AS dc FINAL ON di.content_sha256 = dc.content_sha256 "
            f"WHERE di.run_scope = {{run_scope:String}} "
            f"AND di.class_name IN {{class_names:Array(String)}}",
            parameters={"run_scope": run_scope, "class_names": class_names},
        )

        # Collect all attachment SHA256s needed across all rows
        parsed_rows: list[tuple[type[Document], str, str | None, tuple[str, ...], tuple[str, ...], list[str], list[str], list[str], bytes]] = []
        all_att_sha256s: set[str] = set()

        for row in rows.result_rows:
            class_name = _decode(row[0])
            doc_type = type_by_name.get(class_name)
            if doc_type is None:
                continue

            att_sha256s = [_decode(s) for s in row[7]]
            all_att_sha256s.update(att_sha256s)
            content = row[8] if isinstance(row[8], bytes) else row[8].encode("utf-8")

            parsed_rows.append((
                doc_type,
                _decode(row[1]),  # name
                _decode(row[2]) or None,  # description
                tuple(_decode(s) for s in row[3]),  # sources
                tuple(_decode(o) for o in row[4]),  # origins
                [_decode(n) for n in row[5]],  # att_names
                [_decode(d) for d in row[6]],  # att_descs
                att_sha256s,
                content,
            ))

        # Query 2: batch fetch ALL attachment content in one query
        att_content_by_sha: dict[str, bytes] = {}
        if all_att_sha256s:
            att_rows = self._client.query(
                f"SELECT content_sha256, content FROM {TABLE_DOCUMENT_CONTENT} FINAL WHERE content_sha256 IN {{sha256s:Array(String)}}",
                parameters={"sha256s": list(all_att_sha256s)},
            )
            for att_row in att_rows.result_rows:
                sha = _decode(att_row[0])
                raw = att_row[1] if isinstance(att_row[1], bytes) else att_row[1].encode("utf-8")
                att_content_by_sha[sha] = raw

        # Reconstruct documents (suppress registration to avoid polluting TaskDocumentContext)
        documents: list[Document] = []
        with suppress_registration():
            for doc_type, name, description, sources, origins, att_names, att_descs, att_sha256s, content in parsed_rows:
                attachments: tuple[Attachment, ...] = ()
                if att_sha256s:
                    att_list: list[Attachment] = []
                    for a_name, a_desc, a_sha in zip(att_names, att_descs, att_sha256s, strict=False):
                        a_content = att_content_by_sha.get(a_sha)
                        if a_content is None:
                            logger.warning(f"Attachment content {a_sha[:12]}... not found for document '{name}'")
                            continue
                        att_list.append(
                            Attachment(
                                name=a_name,
                                content=a_content,
                                description=a_desc or None,
                            )
                        )
                    attachments = tuple(att_list)

                doc = doc_type(
                    name=name,
                    content=content,
                    description=description,
                    sources=sources,
                    origins=origins if origins else (),
                    attachments=attachments if attachments else None,
                )
                documents.append(doc)

        return documents

    def _has_documents_sync(self, run_scope: str, document_type: type[Document]) -> bool:
        self._ensure_tables()
        result = self._client.query(
            f"SELECT 1 FROM {TABLE_DOCUMENT_INDEX} FINAL WHERE run_scope = {{run_scope:String}} AND class_name = {{class_name:String}} LIMIT 1",
            parameters={"run_scope": run_scope, "class_name": document_type.__name__},
        )
        return len(result.result_rows) > 0

    def _check_existing_sync(self, sha256s: list[str]) -> set[str]:
        if not sha256s:
            return set()
        self._ensure_tables()
        result = self._client.query(
            f"SELECT document_sha256 FROM {TABLE_DOCUMENT_INDEX} FINAL WHERE document_sha256 IN {{sha256s:Array(String)}}",
            parameters={"sha256s": sha256s},
        )
        return {_decode(row[0]) for row in result.result_rows}

    def _update_summary_sync(self, run_scope: str, document_sha256: str, summary: str) -> None:
        """Update summary column via ALTER TABLE UPDATE mutation."""
        try:
            self._ensure_tables()
            self._client.command(
                f"ALTER TABLE {TABLE_DOCUMENT_INDEX} UPDATE summary = {{summary:String}} "
                f"WHERE document_sha256 = {{sha256:String}} AND run_scope = {{run_scope:String}}",
                parameters={"summary": summary, "sha256": document_sha256, "run_scope": run_scope},
            )
        except Exception as e:
            logger.warning(f"Failed to update summary for {document_sha256[:12]}...: {e}")

    def _load_summaries_sync(self, run_scope: str, document_sha256s: list[str]) -> dict[str, str]:
        """Query summaries from the document index."""
        if not document_sha256s:
            return {}
        try:
            self._ensure_tables()
            result = self._client.query(
                f"SELECT document_sha256, summary FROM {TABLE_DOCUMENT_INDEX} FINAL "
                f"WHERE run_scope = {{run_scope:String}} "
                f"AND document_sha256 IN {{sha256s:Array(String)}} "
                f"AND summary != ''",
                parameters={"run_scope": run_scope, "sha256s": document_sha256s},
            )
            return {_decode(row[0]): _decode(row[1]) for row in result.result_rows}
        except Exception as e:
            logger.warning(f"Failed to load summaries: {e}")
            return {}


def _decode(value: bytes | str) -> str:
    """Decode bytes to str if needed (strings_as_bytes=True mode)."""
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value
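Before the factory module, a minimal usage sketch of the store above (not taken from the package docs). ReportDocument is a hypothetical concrete Document subclass; whether Document can be subclassed this plainly is an assumption, and the constructor keywords mirror those used in _load_sync. Host and password values are placeholders:

import asyncio

from ai_pipeline_core.document_store.clickhouse import ClickHouseDocumentStore
from ai_pipeline_core.documents.document import Document


class ReportDocument(Document):  # hypothetical subclass, for illustration only
    pass


async def main() -> None:
    # port 8443 and secure=True are the constructor defaults
    store = ClickHouseDocumentStore(host="clickhouse.internal", password="secret")
    try:
        doc = ReportDocument(name="report.md", content=b"# Q3 findings")
        await store.save(doc, run_scope="run-2025-01-15")

        # Reads dispatch to the same single-thread executor as the write above
        loaded = await store.load("run-2025-01-15", [ReportDocument])
        print([d.name for d in loaded])
    finally:
        store.shutdown()  # flushes pending summaries (if any) and releases the executor


asyncio.run(main())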
ai_pipeline_core/document_store/factory.py (new file)

@@ -0,0 +1,38 @@
"""Factory function for creating document store instances based on settings."""

from ai_pipeline_core.document_store._summary import SummaryGenerator
from ai_pipeline_core.document_store.protocol import DocumentStore
from ai_pipeline_core.settings import Settings


def create_document_store(
    settings: Settings,
    *,
    summary_generator: SummaryGenerator | None = None,
) -> DocumentStore:
    """Create a DocumentStore based on settings.

    Selects ClickHouseDocumentStore when clickhouse_host is configured,
    otherwise falls back to LocalDocumentStore.

    Backends are imported lazily to avoid circular imports.
    """
    if not isinstance(settings, Settings):  # pyright: ignore[reportUnnecessaryIsInstance]
        raise TypeError(f"Expected Settings instance, got {type(settings).__name__}")  # pyright: ignore[reportUnreachable]

    if settings.clickhouse_host:
        from ai_pipeline_core.document_store.clickhouse import ClickHouseDocumentStore

        return ClickHouseDocumentStore(
            host=settings.clickhouse_host,
            port=settings.clickhouse_port,
            database=settings.clickhouse_database,
            username=settings.clickhouse_user,
            password=settings.clickhouse_password,
            secure=settings.clickhouse_secure,
            summary_generator=summary_generator,
        )

    from ai_pipeline_core.document_store.local import LocalDocumentStore

    return LocalDocumentStore(summary_generator=summary_generator)