mirage-benchmark 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mirage-benchmark might be problematic. Click here for more details.
- mirage/__init__.py +83 -0
- mirage/cli.py +150 -0
- mirage/core/__init__.py +52 -0
- mirage/core/config.py +248 -0
- mirage/core/llm.py +1745 -0
- mirage/core/prompts.py +884 -0
- mirage/embeddings/__init__.py +31 -0
- mirage/embeddings/models.py +512 -0
- mirage/embeddings/rerankers_multimodal.py +766 -0
- mirage/embeddings/rerankers_text.py +149 -0
- mirage/evaluation/__init__.py +26 -0
- mirage/evaluation/metrics.py +2223 -0
- mirage/evaluation/metrics_optimized.py +2172 -0
- mirage/pipeline/__init__.py +45 -0
- mirage/pipeline/chunker.py +545 -0
- mirage/pipeline/context.py +1003 -0
- mirage/pipeline/deduplication.py +491 -0
- mirage/pipeline/domain.py +514 -0
- mirage/pipeline/pdf_processor.py +598 -0
- mirage/pipeline/qa_generator.py +798 -0
- mirage/utils/__init__.py +31 -0
- mirage/utils/ablation.py +360 -0
- mirage/utils/preflight.py +663 -0
- mirage/utils/stats.py +626 -0
- mirage_benchmark-1.0.4.dist-info/METADATA +490 -0
- mirage_benchmark-1.0.4.dist-info/RECORD +30 -0
- mirage_benchmark-1.0.4.dist-info/WHEEL +5 -0
- mirage_benchmark-1.0.4.dist-info/entry_points.txt +3 -0
- mirage_benchmark-1.0.4.dist-info/licenses/LICENSE +190 -0
- mirage_benchmark-1.0.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Preflight Check Module - Validates all services before pipeline execution.
|
|
3
|
+
|
|
4
|
+
Checks:
|
|
5
|
+
1. LLM API connectivity (text generation)
|
|
6
|
+
2. VLM API connectivity (vision + text)
|
|
7
|
+
3. Embedding model loading & inference
|
|
8
|
+
4. Reranker model loading & inference
|
|
9
|
+
5. API key availability
|
|
10
|
+
6. Required directories and files
|
|
11
|
+
|
|
12
|
+
Run standalone: python preflight_check.py
|
|
13
|
+
Or import: from preflight_check import run_preflight_checks
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import time
|
|
19
|
+
from typing import Dict, List, Tuple, Optional, Any
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from enum import Enum
|
|
23
|
+
|
|
24
|
+
class CheckStatus(Enum):
    """Outcome of a single preflight check.

    Each value is a display label: an emoji, a space, then the status word.
    The summary table shows only the text before the first space.
    """

    PASS = "✅ PASS"  # the check succeeded
    FAIL = "❌ FAIL"  # counted as a failure; makes the overall preflight fail
    WARN = "⚠️ WARN"  # non-fatal problem, reported but not counted as failure
    SKIP = "⏭️ SKIP"  # the check was not run (e.g. --skip-expensive)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class CheckResult:
    """Result of one preflight check, as collected and printed by the runner.

    ``details`` holds free-form extra data for the report; every instance gets
    its own fresh dict. ``duration_ms`` stays 0.0 unless a timing decorator
    fills it in.
    """

    name: str  # label shown in the summary table
    status: CheckStatus  # PASS / FAIL / WARN / SKIP
    message: str  # human-readable explanation (listed for failed checks)
    details: Dict[str, Any] = field(default_factory=dict)  # per-instance dict
    duration_ms: float = 0.0  # wall-clock runtime in milliseconds
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _timer(func):
|
|
41
|
+
"""Decorator to time check functions."""
|
|
42
|
+
def wrapper(*args, **kwargs):
|
|
43
|
+
start = time.time()
|
|
44
|
+
result = func(*args, **kwargs)
|
|
45
|
+
result.duration_ms = (time.time() - start) * 1000
|
|
46
|
+
return result
|
|
47
|
+
return wrapper
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# ============================================================================
|
|
51
|
+
# INDIVIDUAL CHECK FUNCTIONS
|
|
52
|
+
# ============================================================================
|
|
53
|
+
|
|
54
|
+
@_timer
def check_config() -> CheckResult:
    """Verify that config.yaml loads and summarize the active settings.

    Returns:
        CheckResult named "Configuration". On PASS, ``details`` carries the
        backend name, LLM/VLM model names, embedding model and output
        directory ('UNKNOWN' for any key the config lacks). FAIL if a
        FileNotFoundError is raised (config file missing) or if any other
        exception occurs while importing or calling the config loaders.
    """
    try:
        from config_loader import (
            load_config,
            get_backend_config,
            get_paths_config,
            get_embedding_config,
        )

        # Called first so that a missing/broken config surfaces here; the
        # returned dict itself is not needed by this check.
        load_config()
        backend_cfg = get_backend_config()
        paths_cfg = get_paths_config()
        embedding_cfg = get_embedding_config()

        summary = {
            "backend": backend_cfg.get('name', 'UNKNOWN'),
            "llm_model": backend_cfg.get('llm_model', 'UNKNOWN'),
            "vlm_model": backend_cfg.get('vlm_model', 'UNKNOWN'),
            "embedding_model": embedding_cfg.get('model', 'UNKNOWN'),
            "output_dir": paths_cfg.get('output_dir', 'UNKNOWN'),
        }
        return CheckResult(
            name="Configuration",
            status=CheckStatus.PASS,
            message="config.yaml loaded successfully",
            details=summary,
        )
    except FileNotFoundError as missing:
        return CheckResult(
            name="Configuration",
            status=CheckStatus.FAIL,
            message=f"config.yaml not found: {missing}",
        )
    except Exception as err:
        return CheckResult(
            name="Configuration",
            status=CheckStatus.FAIL,
            message=f"Failed to load config: {err}",
        )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@_timer
def check_api_key() -> CheckResult:
    """Confirm that an API key is configured for the active backend.

    The key and backend name are imported from ``call_llm`` (``API_KEY`` and
    ``BACKEND``). A key of 10 characters or fewer is treated as missing. Only
    a short masked preview of the key is placed in ``details``.

    Returns:
        CheckResult named "API Key": PASS when a plausible key is present;
        FAIL when the key is missing/too short or ``call_llm`` cannot be
        imported (including any error raised while importing it).
    """
    # Minimum key length before any characters are revealed: 8 leading + 4
    # trailing characters are then at most 12/32 of the secret. Shorter keys
    # are fully masked. (A threshold of 12 would expose 12 of 13 characters.)
    preview_min_len = 32

    try:
        from call_llm import API_KEY, BACKEND
    except Exception as e:
        # ImportError is the expected case, but any error raised while
        # importing call_llm is reported as a failed check rather than
        # crashing the whole preflight run.
        return CheckResult(
            name="API Key",
            status=CheckStatus.FAIL,
            message=f"Cannot import call_llm: {e}"
        )

    if API_KEY and len(API_KEY) > 10:
        if len(API_KEY) >= preview_min_len:
            masked = f"{API_KEY[:8]}...{API_KEY[-4:]}"
        else:
            masked = "***"
        return CheckResult(
            name="API Key",
            status=CheckStatus.PASS,
            message=f"API key loaded for {BACKEND}",
            details={"backend": BACKEND, "key_preview": masked}
        )

    return CheckResult(
        name="API Key",
        status=CheckStatus.FAIL,
        message=f"API key missing or too short for {BACKEND}",
        details={"backend": BACKEND}
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@_timer
def check_llm_call() -> CheckResult:
    """Send a tiny prompt through ``call_llm`` to prove text generation works.

    Returns:
        CheckResult named "LLM API": PASS with a preview of the reply (its
        first 50 characters plus "..." when longer); FAIL if the reply is
        empty/whitespace, or if importing or calling the LLM raises.
    """
    try:
        from call_llm import call_llm, BACKEND, LLM_MODEL_NAME

        reply = call_llm("Say 'OK' and nothing else.")

        # Guard clause: nothing (or only whitespace) came back.
        if not (reply and reply.strip()):
            return CheckResult(
                name="LLM API",
                status=CheckStatus.FAIL,
                message="LLM returned empty response",
                details={"backend": BACKEND, "model": LLM_MODEL_NAME},
            )

        preview = reply if len(reply) <= 50 else reply[:50] + "..."
        return CheckResult(
            name="LLM API",
            status=CheckStatus.PASS,
            message=f"LLM call successful ({BACKEND})",
            details={
                "backend": BACKEND,
                "model": LLM_MODEL_NAME,
                "response_preview": preview,
            },
        )
    except Exception as err:
        return CheckResult(
            name="LLM API",
            status=CheckStatus.FAIL,
            message=f"LLM call failed: {err}",
            details={"error": str(err)},
        )
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@_timer
def check_vlm_call() -> CheckResult:
    """Exercise the VLM endpoint with a text-only request.

    A single chunk with ``image_path`` set to None is passed to
    ``call_vlm_interweaved``; no image is needed just to confirm
    connectivity.

    Returns:
        CheckResult named "VLM API": PASS with a preview of the reply (its
        first 50 characters plus "..." when longer); FAIL if the reply is
        empty/whitespace, or if importing or calling the VLM raises.
    """
    try:
        from call_llm import call_vlm_interweaved, BACKEND, VLM_MODEL_NAME

        prompt = "Say 'VLM OK' and nothing else."
        text_only_chunks = [{"content": "Test content", "image_path": None}]
        reply = call_vlm_interweaved(prompt, text_only_chunks)

        # Guard clause: nothing (or only whitespace) came back.
        if not (reply and reply.strip()):
            return CheckResult(
                name="VLM API",
                status=CheckStatus.FAIL,
                message="VLM returned empty response",
                details={"backend": BACKEND, "model": VLM_MODEL_NAME},
            )

        preview = reply if len(reply) <= 50 else reply[:50] + "..."
        return CheckResult(
            name="VLM API",
            status=CheckStatus.PASS,
            message=f"VLM call successful ({BACKEND})",
            details={
                "backend": BACKEND,
                "model": VLM_MODEL_NAME,
                "response_preview": preview,
            },
        )
    except Exception as err:
        return CheckResult(
            name="VLM API",
            status=CheckStatus.FAIL,
            message=f"VLM call failed: {err}",
            details={"error": str(err)},
        )
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _embedding_dim(embedding: Any) -> int:
    """Return the vector length of *embedding* (the size of its last axis).

    Array-like results expose ``shape``. For those the last axis is used, so
    a batched result such as ``(1, dim)`` reports ``dim`` rather than ``1``.
    Plain sequences without ``shape`` fall back to ``len()``.
    """
    shape = getattr(embedding, "shape", None)
    if shape is not None and len(shape) > 0:
        return int(shape[-1])
    return len(embedding)


@_timer
def check_embedding_model() -> CheckResult:
    """Load the configured embedding model and embed one test sentence.

    The model comes from ``get_embedding_config()['model']`` (default
    ``'bge_m3'``):

    - ``nomic`` / ``nomic-ai/nomic-embed-multimodal-7b`` uses
      ``embed_models.NomicVLEmbed``.
    - Anything else is loaded with ``sentence_transformers.SentenceTransformer``.

    Returns:
        CheckResult named "Embedding Model": PASS with the embedding
        dimension; FAIL if a module cannot be imported, loading/inference
        raises, or the embedding is empty.
    """
    try:
        from config_loader import get_embedding_config
        embed_config = get_embedding_config()
        model_name = embed_config.get('model', 'bge_m3')

        test_text = "This is a test sentence for embedding."

        if model_name in ["nomic", "nomic-ai/nomic-embed-multimodal-7b"]:
            from embed_models import NomicVLEmbed
            gpus = embed_config.get('gpus', None)
            embedder = NomicVLEmbed(gpus=gpus)
            model_display = "Nomic Multimodal"
            # NomicVLEmbed exposes embed_text() rather than encode().
            embedding = embedder.embed_text(test_text)
            # Tensor results: move to CPU and upcast to float32 before
            # .numpy(), since NumPy has no bfloat16 dtype.
            if hasattr(embedding, 'cpu'):
                embedding = embedding.cpu().float().numpy()
        elif model_name in ["bge_m3", "BAAI/bge-m3"]:
            from sentence_transformers import SentenceTransformer
            embedder = SentenceTransformer("BAAI/bge-m3", trust_remote_code=True)
            model_display = "BAAI/bge-m3"
            embedding = embedder.encode(test_text, convert_to_numpy=True)
        else:
            from sentence_transformers import SentenceTransformer
            embedder = SentenceTransformer(model_name, trust_remote_code=True)
            model_display = model_name
            embedding = embedder.encode(test_text, convert_to_numpy=True)

        # Use the last axis rather than len(): some backends may return a
        # batch such as (1, dim), for which len() would report 1.
        dim = _embedding_dim(embedding) if embedding is not None else 0

        if dim > 0:
            return CheckResult(
                name="Embedding Model",
                status=CheckStatus.PASS,
                message="Embedding model loaded and working",
                details={
                    "model": model_display,
                    "config_name": model_name,
                    "embedding_dim": dim
                }
            )
        return CheckResult(
            name="Embedding Model",
            status=CheckStatus.FAIL,
            message="Embedding returned empty result",
            details={"model": model_display}
        )
    except ImportError as e:
        return CheckResult(
            name="Embedding Model",
            status=CheckStatus.FAIL,
            message=f"Cannot import embedding module: {e}",
            details={"error": str(e)}
        )
    except Exception as e:
        return CheckResult(
            name="Embedding Model",
            status=CheckStatus.FAIL,
            message=f"Embedding model failed: {e}",
            details={"error": str(e)}
        )
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@_timer
def check_reranker() -> CheckResult:
    """Check that the configured default reranker is usable.

    The reranker type is read from ``config['reranker']['default']``
    (default ``'gemini_vlm'``):

    - ``gemini_vlm`` is API-based and passes without loading anything.
    - ``monovlm`` / ``MonoVLM`` loads ``MonoVLMReranker`` and reranks one
      dummy chunk.
    - Any other value yields a WARN.

    Returns:
        CheckResult named "Reranker"; FAIL if importing, loading or reranking
        raises.
    """
    try:
        from config_loader import load_config

        reranker_type = load_config().get('reranker', {}).get('default', 'gemini_vlm')

        if reranker_type == "gemini_vlm":
            # API-based reranker: connectivity is exercised by the VLM check,
            # so nothing is loaded here.
            # NOTE(review): the model name below is hard-coded and may not
            # match the configured VLM model.
            return CheckResult(
                name="Reranker",
                status=CheckStatus.PASS,
                message="Using Gemini VLM reranker (API-based)",
                details={"type": "gemini_vlm", "model": "gemini-2.5-flash"},
            )

        if reranker_type not in ("monovlm", "MonoVLM"):
            return CheckResult(
                name="Reranker",
                status=CheckStatus.WARN,
                message=f"Unknown reranker type: {reranker_type}",
                details={"type": reranker_type},
            )

        from rerankers_multimodal import MonoVLMReranker

        # Smoke test: one query against a single dummy chunk. Only the absence
        # of an exception is checked; the ranking itself is not inspected.
        smoke_chunks = [{"content": "test content", "chunk_id": "1"}]
        MonoVLMReranker().rerank("test query", smoke_chunks, top_k=1)

        return CheckResult(
            name="Reranker",
            status=CheckStatus.PASS,
            message="MonoVLM reranker loaded and working",
            details={"type": "monovlm", "model": "lightonai/MonoQwen2-VL-v0.1"},
        )
    except Exception as err:
        return CheckResult(
            name="Reranker",
            status=CheckStatus.FAIL,
            message=f"Reranker check failed: {err}",
            details={"error": str(err)},
        )
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@_timer
def check_metrics_embeddings() -> CheckResult:
    """Verify an embedding backend for the answer_relevancy / semantic_diversity metrics.

    Backends are tried in order:

    1. Gemini API embeddings (``models/text-embedding-004``), when
       ``metrics_optimized.GEMINI_AVAILABLE`` is true and ``call_llm`` exposes
       a non-empty ``API_KEY``.
    2. Local BGE-M3 embeddings via ``LocalEmbeddingWrapper``, when
       sentence-transformers is available. This is used when Gemini is not
       available or its attempt does not succeed.

    Returns:
        CheckResult named "Metrics Embeddings": PASS naming the backend used
        and the embedding dimension; FAIL when no backend works. If Gemini
        was tried and did not succeed, the reason is kept in
        ``details['gemini_error']``.
    """
    try:
        from metrics_optimized import GEMINI_AVAILABLE, SENTENCE_TRANSFORMERS_AVAILABLE

        gemini_error = None

        if GEMINI_AVAILABLE:
            # Best-effort: any failure here (including importing call_llm)
            # falls back to local embeddings, but the reason is recorded
            # instead of being silently discarded.
            try:
                from call_llm import API_KEY
                if not API_KEY:
                    gemini_error = "no API key available"
                else:
                    from langchain_google_genai import GoogleGenerativeAIEmbeddings
                    embeddings = GoogleGenerativeAIEmbeddings(
                        model="models/text-embedding-004",
                        google_api_key=API_KEY
                    )
                    test_emb = embeddings.embed_query("test")
                    return CheckResult(
                        name="Metrics Embeddings",
                        status=CheckStatus.PASS,
                        message="Using Gemini API embeddings for metrics",
                        details={"backend": "gemini", "model": "text-embedding-004", "dim": len(test_emb)}
                    )
            except Exception as e:
                gemini_error = str(e)

        if SENTENCE_TRANSFORMERS_AVAILABLE:
            from metrics_optimized import LocalEmbeddingWrapper
            embeddings = LocalEmbeddingWrapper("BAAI/bge-m3")
            test_emb = embeddings.embed_query("test")
            details = {"backend": "sentence-transformers", "model": "BAAI/bge-m3", "dim": len(test_emb)}
            if gemini_error is not None:
                details["gemini_error"] = gemini_error
            return CheckResult(
                name="Metrics Embeddings",
                status=CheckStatus.PASS,
                message="Using local BGE-M3 embeddings for metrics",
                details=details
            )

        details = {"gemini_available": GEMINI_AVAILABLE, "st_available": SENTENCE_TRANSFORMERS_AVAILABLE}
        if gemini_error is not None:
            details["gemini_error"] = gemini_error
        return CheckResult(
            name="Metrics Embeddings",
            status=CheckStatus.FAIL,
            message="No embedding backend available for metrics (answer_relevancy will be 0)",
            details=details
        )

    except Exception as e:
        return CheckResult(
            name="Metrics Embeddings",
            status=CheckStatus.FAIL,
            message=f"Metrics embeddings check failed: {e}",
            details={"error": str(e)}
        )
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
@_timer
def check_gpu_availability() -> CheckResult:
    """Report the visible CUDA GPUs and their total memory.

    Returns:
        CheckResult named "GPU". PASS lists each device's id, name and total
        memory (``total_gb``: GiB, rounded to 1 decimal). WARN when CUDA is
        unavailable, PyTorch is not installed, or the query raises. This check
        never FAILs, because a CPU-only run is allowed.
    """
    try:
        import torch

        if not torch.cuda.is_available():
            return CheckResult(
                name="GPU",
                status=CheckStatus.WARN,
                message="No GPU available, will use CPU (slower)",
                details={}
            )

        gpu_count = torch.cuda.device_count()
        # Only static device properties are read here. memory_reserved() minus
        # memory_allocated() describes this process's caching allocator, not
        # free device memory, so it is not used as a "free memory" figure.
        gpus = []
        for i in range(gpu_count):
            props = torch.cuda.get_device_properties(i)
            gpus.append({
                "id": i,
                "name": props.name,
                "total_gb": round(props.total_memory / (1024**3), 1)
            })

        return CheckResult(
            name="GPU",
            status=CheckStatus.PASS,
            message=f"Found {gpu_count} GPU(s)",
            details={"gpus": gpus}
        )
    except ImportError:
        return CheckResult(
            name="GPU",
            status=CheckStatus.WARN,
            message="PyTorch not available for GPU check",
            details={}
        )
    except Exception as e:
        return CheckResult(
            name="GPU",
            status=CheckStatus.WARN,
            message=f"GPU check error: {e}",
            details={"error": str(e)}
        )
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
@_timer
def check_output_directory() -> CheckResult:
    """Ensure the configured output directory exists and is writable.

    Creates ``paths.output_dir`` (default ``'trials/results'``) if needed,
    then proves writability by writing to an anonymous temporary file inside
    it.

    Returns:
        CheckResult named "Output Directory": PASS with the path, or FAIL with
        the error if the config cannot be read or the directory cannot be
        created or written.
    """
    import tempfile

    try:
        from config_loader import get_paths_config
        paths = get_paths_config()
        output_dir = paths.get('output_dir', 'trials/results')

        # Create directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # A uniquely named temporary file cannot collide with another run
        # probing the same directory, and it is removed automatically even if
        # the write fails. A fixed filename would be left behind on error and
        # could be deleted by a concurrent run.
        with tempfile.TemporaryFile(dir=output_dir) as probe:
            probe.write(b"test")

        return CheckResult(
            name="Output Directory",
            status=CheckStatus.PASS,
            message="Output directory writable",
            details={"path": output_dir}
        )
    except Exception as e:
        return CheckResult(
            name="Output Directory",
            status=CheckStatus.FAIL,
            message=f"Cannot write to output directory: {e}",
            details={"error": str(e)}
        )
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
@_timer
def check_input_data() -> CheckResult:
    """Confirm there is input for the pipeline to process.

    Sources are checked in order:

    1. A pre-chunked JSON file (``paths.input_chunks_file``).
    2. ``*.pdf`` / ``*.html`` / ``*.htm`` files directly inside
       ``paths.input_pdf_dir`` (the directory is not searched recursively).

    Returns:
        CheckResult named "Input Data": PASS with the source and item count.
        FAIL when the directory holds no matching files, when the configured
        paths do not exist, when nothing is configured, or when loading
        raises.
    """
    try:
        from config_loader import get_paths_config
        paths = get_paths_config()

        input_pdf_dir = paths.get('input_pdf_dir')
        input_chunks_file = paths.get('input_chunks_file')

        # Pre-chunked data takes precedence over raw documents.
        if input_chunks_file and os.path.exists(input_chunks_file):
            import json
            # JSON text is UTF-8 by specification; do not depend on the
            # platform's default encoding (e.g. cp1252 on Windows).
            with open(input_chunks_file, 'r', encoding='utf-8') as f:
                chunks = json.load(f)
            return CheckResult(
                name="Input Data",
                status=CheckStatus.PASS,
                message=f"Found pre-chunked data with {len(chunks)} chunks",
                details={"source": "chunks_file", "path": input_chunks_file, "count": len(chunks)}
            )

        if input_pdf_dir and os.path.exists(input_pdf_dir):
            input_dir = Path(input_pdf_dir)
            pdf_files = list(input_dir.glob("*.pdf"))
            # Also check for HTML files (News domain uses HTML)
            html_files = list(input_dir.glob("*.html")) + list(input_dir.glob("*.htm"))
            total_files = len(pdf_files) + len(html_files)

            if total_files == 0:
                return CheckResult(
                    name="Input Data",
                    status=CheckStatus.FAIL,
                    message="No PDF or HTML files in input directory",
                    details={"path": input_pdf_dir}
                )

            file_types = []
            if pdf_files:
                file_types.append(f"{len(pdf_files)} PDF")
            if html_files:
                file_types.append(f"{len(html_files)} HTML")
            return CheckResult(
                name="Input Data",
                status=CheckStatus.PASS,
                message=f"Found {', '.join(file_types)} files",
                details={"source": "input_dir", "path": input_pdf_dir, "count": total_files}
            )

        # Nothing usable was found. Distinguish "configured but missing" from
        # "not configured at all" so the user knows which fix is needed.
        missing = [str(p) for p in (input_chunks_file, input_pdf_dir) if p]
        if missing:
            return CheckResult(
                name="Input Data",
                status=CheckStatus.FAIL,
                message=f"Configured input path(s) not found: {', '.join(missing)}",
                details={"missing_paths": missing}
            )

        return CheckResult(
            name="Input Data",
            status=CheckStatus.FAIL,
            message="No input data configured",
            details={}
        )
    except Exception as e:
        return CheckResult(
            name="Input Data",
            status=CheckStatus.FAIL,
            message=f"Input data check failed: {e}",
            details={"error": str(e)}
        )
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
# ============================================================================
|
|
507
|
+
# MAIN PREFLIGHT CHECK RUNNER
|
|
508
|
+
# ============================================================================
|
|
509
|
+
|
|
510
|
+
def _run_check(name: str, check_func) -> CheckResult:
    """Run one check, turning an unexpected exception into a FAIL result."""
    try:
        return check_func()
    except Exception as e:
        # A single crashing check must not abort the whole preflight run;
        # report it like any other failure instead.
        return CheckResult(
            name=name.upper(),
            status=CheckStatus.FAIL,
            message=f"Check raised unexpectedly: {e}",
            details={"error": str(e)}
        )


def _format_details(details: Dict[str, Any]) -> str:
    """Pick the single most informative detail value for the summary table.

    Preference order: model, backend, basename of path, item count. Values
    are coerced to ``str`` because they may be None (e.g. an unset model
    name), which cannot be formatted with a width spec such as ``:<30``.
    """
    if not details:
        return ""
    if 'model' in details:
        return str(details['model'])
    if 'backend' in details:
        return str(details['backend'])
    if 'path' in details:
        return os.path.basename(str(details['path']))
    if 'count' in details:
        return f"{details['count']} items"
    return ""


def _print_report(results: List[CheckResult], all_passed: bool) -> None:
    """Print the summary counts, per-check table, active config and verdict."""
    passed = sum(1 for r in results if r.status == CheckStatus.PASS)
    failed = sum(1 for r in results if r.status == CheckStatus.FAIL)
    warned = sum(1 for r in results if r.status == CheckStatus.WARN)
    skipped = sum(1 for r in results if r.status == CheckStatus.SKIP)

    print("\n" + "=" * 70)
    print("📋 PREFLIGHT CHECK SUMMARY")
    print("=" * 70)
    print(f"\n Results: {passed} passed, {failed} failed, {warned} warnings, {skipped} skipped\n")

    # Per-check table
    print(" " + "-" * 66)
    print(f" {'Service':<25} {'Status':<12} {'Details':<30}")
    print(" " + "-" * 66)
    for result in results:
        status_str = result.status.value.split()[0]  # just the emoji
        details_str = _format_details(result.details)
        print(f" {result.name:<25} {status_str:<12} {details_str:<30}")
    print(" " + "-" * 66)

    # Configuration summary (only when the config check itself passed)
    config_result = next((r for r in results if r.name == "Configuration"), None)
    if config_result and config_result.status == CheckStatus.PASS:
        d = config_result.details
        print("\n 📌 Active Configuration:")
        print(f" • Backend: {d.get('backend', 'N/A')}")
        print(f" • LLM Model: {d.get('llm_model', 'N/A')}")
        print(f" • VLM Model: {d.get('vlm_model', 'N/A')}")
        print(f" • Embeddings: {d.get('embedding_model', 'N/A')}")
        print(f" • Output Dir: {d.get('output_dir', 'N/A')}")

    print("\n" + "=" * 70)
    if all_passed:
        print("✅ ALL CHECKS PASSED - Ready to start pipeline execution")
    else:
        print("❌ PREFLIGHT CHECKS FAILED - Fix issues before proceeding")
        print("\n Failed checks:")
        for result in results:
            if result.status == CheckStatus.FAIL:
                print(f" • {result.name}: {result.message}")
    print("=" * 70 + "\n")


def run_preflight_checks(
    skip_expensive: bool = False,
    quiet: bool = False
) -> Tuple[bool, List[CheckResult]]:
    """
    Run all preflight checks.

    Args:
        skip_expensive: Skip the checks that call APIs or load models (LLM,
            VLM, embedding, reranker, metrics embeddings). They are reported
            with SKIP status.
        quiet: Suppress output

    Returns:
        Tuple of (all_passed: bool, results: List[CheckResult]). all_passed
        is True when no check has FAIL status; WARN and SKIP do not count as
        failures. A check that raises is reported as FAIL instead of
        propagating the exception.
    """
    # (name, check function, is_expensive), in order of importance
    checks = [
        ("config", check_config, False),
        ("api_key", check_api_key, False),
        ("input_data", check_input_data, False),
        ("output_dir", check_output_directory, False),
        ("gpu", check_gpu_availability, False),
        ("llm", check_llm_call, True),
        ("vlm", check_vlm_call, True),
        ("embedding", check_embedding_model, True),
        ("reranker", check_reranker, True),
        ("metrics_emb", check_metrics_embeddings, True),
    ]

    if not quiet:
        print("\n" + "=" * 70)
        print("🔍 PREFLIGHT CHECKS - Validating all services before execution")
        print("=" * 70 + "\n")

    results: List[CheckResult] = []
    for name, check_func, is_expensive in checks:
        if is_expensive and skip_expensive:
            result = CheckResult(
                name=name.upper(),
                status=CheckStatus.SKIP,
                message="Skipped (--skip-expensive)"
            )
        else:
            if not quiet:
                print(f" Checking {name}...", end=" ", flush=True)
            result = _run_check(name, check_func)
            if not quiet:
                print(f"{result.status.value} ({result.duration_ms:.0f}ms)")
        results.append(result)

    all_passed = not any(r.status == CheckStatus.FAIL for r in results)

    if not quiet:
        _print_report(results, all_passed)

    return all_passed, results
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def require_preflight_checks() -> bool:
    """Gate pipeline startup on the preflight checks.

    Runs every check with output enabled. On any failure, prints a stop
    notice and terminates the process via ``sys.exit(1)``.

    Returns:
        True when all checks passed.
    """
    ok, _ = run_preflight_checks()

    if not ok:
        stop_notice = (
            "\n🛑 STOPPING: Preflight checks failed. Fix the issues above before running the pipeline.",
            " This prevents wasted LLM API calls on a misconfigured system.\n",
        )
        for line in stop_notice:
            print(line)
        sys.exit(1)

    return True
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
# ============================================================================
|
|
638
|
+
# CLI ENTRY POINT
|
|
639
|
+
# ============================================================================
|
|
640
|
+
|
|
641
|
+
def main():
    """Command-line entry point: parse flags, run the checks, exit 0 or 1."""
    import argparse

    cli = argparse.ArgumentParser(description="Run preflight checks for QA pipeline")
    cli.add_argument(
        "--skip-expensive",
        action="store_true",
        help="Skip expensive checks (model loading, API calls)",
    )
    cli.add_argument("--quiet", action="store_true", help="Suppress output")
    opts = cli.parse_args()

    ok, _ = run_preflight_checks(skip_expensive=opts.skip_expensive, quiet=opts.quiet)

    # Exit status is the contract for shell scripts / CI: 0 = ready, 1 = failed.
    sys.exit(0 if ok else 1)
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()
|
|
663
|
+
|