mirage-benchmark 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mirage-benchmark might be problematic.
- mirage/__init__.py +83 -0
- mirage/cli.py +150 -0
- mirage/core/__init__.py +52 -0
- mirage/core/config.py +248 -0
- mirage/core/llm.py +1745 -0
- mirage/core/prompts.py +884 -0
- mirage/embeddings/__init__.py +31 -0
- mirage/embeddings/models.py +512 -0
- mirage/embeddings/rerankers_multimodal.py +766 -0
- mirage/embeddings/rerankers_text.py +149 -0
- mirage/evaluation/__init__.py +26 -0
- mirage/evaluation/metrics.py +2223 -0
- mirage/evaluation/metrics_optimized.py +2172 -0
- mirage/pipeline/__init__.py +45 -0
- mirage/pipeline/chunker.py +545 -0
- mirage/pipeline/context.py +1003 -0
- mirage/pipeline/deduplication.py +491 -0
- mirage/pipeline/domain.py +514 -0
- mirage/pipeline/pdf_processor.py +598 -0
- mirage/pipeline/qa_generator.py +798 -0
- mirage/utils/__init__.py +31 -0
- mirage/utils/ablation.py +360 -0
- mirage/utils/preflight.py +663 -0
- mirage/utils/stats.py +626 -0
- mirage_benchmark-1.0.4.dist-info/METADATA +490 -0
- mirage_benchmark-1.0.4.dist-info/RECORD +30 -0
- mirage_benchmark-1.0.4.dist-info/WHEEL +5 -0
- mirage_benchmark-1.0.4.dist-info/entry_points.txt +3 -0
- mirage_benchmark-1.0.4.dist-info/licenses/LICENSE +190 -0
- mirage_benchmark-1.0.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,514 @@
"""
Module to extract domain and expert role from semantic chunks using BERTopic and LLM.
"""

import json
import os
import sys
import logging
from typing import List, Dict, Tuple, Optional, Union

import torch
import pandas as pd
import numpy as np
from umap import UMAP
from bertopic import BERTopic
from bertopic.representation import MaximalMarginalRelevance
from sklearn.feature_extraction.text import CountVectorizer
from PIL import Image
from sentence_transformers import SentenceTransformer
import datamapplot

# Import from mirage modules
from mirage.core.llm import call_llm_simple, setup_logging
from mirage.embeddings.models import NomicVLEmbed
from mirage.core.prompts import PROMPTS

#%% Setup

def save_domain_expert_to_env(domain: str, expert_persona: str):
    """Save domain and expert persona as environment variables"""
    os.environ['DATASET_DOMAIN'] = domain
    os.environ['DATASET_EXPERT_PERSONA'] = expert_persona
    print(f"💾 Saved to environment: DATASET_DOMAIN={domain}, DATASET_EXPERT_PERSONA={expert_persona}")

def load_domain_expert_from_env() -> Tuple[Optional[str], Optional[str]]:
    """Load domain and expert persona from environment variables"""
    domain = os.environ.get('DATASET_DOMAIN')
    expert_persona = os.environ.get('DATASET_EXPERT_PERSONA')
    if domain and expert_persona:
        print(f"📥 Loaded from environment: domain={domain}, expert_persona={expert_persona}")
        return domain, expert_persona
    return None, None
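
# Example round-trip (hypothetical values). Note that os.environ only
# persists for this process and its children:
#
#   save_domain_expert_to_env("Cardiology", "Clinical Cardiologist")
#   domain, persona = load_domain_expert_from_env()
#   # -> ("Cardiology", "Clinical Cardiologist"); (None, None) if either is unset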

# Configuration (override via config.yaml)
DEFAULT_CHUNKS_FILE = "output/results/chunks.json"
OUTPUT_DIR = "output/domain_analysis"
# Directory containing images referenced in chunks (update as needed)
IMAGE_BASE_DIR = "output/results/markdown"
# Directory containing pre-computed embeddings
EMBEDDINGS_DIR = "output/results/embeddings"

# Embedding Mode Configuration
# Set to False to use BGE-M3 (Text Only), True to use NomicVLEmbed (Multimodal)
USE_MULTIMODAL_EMBEDDINGS = True

#%% Load Pre-computed Embeddings

def load_precomputed_embeddings(model_name: str = "nomic") -> Tuple[np.ndarray, List[str]]:
    """
    Load pre-computed embeddings from .npz file and corresponding chunk IDs.

    Returns:
        Tuple of (embeddings_array, chunk_ids_list)
    """
    embeddings_path = os.path.join(EMBEDDINGS_DIR, "embeddings_dict.npz")
    chunk_ids_path = os.path.join(EMBEDDINGS_DIR, "chunk_ids.json")

    if not os.path.exists(embeddings_path):
        raise FileNotFoundError(f"Embeddings file not found: {embeddings_path}")
    if not os.path.exists(chunk_ids_path):
        raise FileNotFoundError(f"Chunk IDs file not found: {chunk_ids_path}")

    # Load embeddings
    print(f"📂 Loading pre-computed embeddings from {embeddings_path}...")
    with np.load(embeddings_path) as data:
        if model_name in data:
            embeddings = data[model_name]
            print(f"✅ Loaded embeddings for {model_name}: {embeddings.shape}")
        else:
            raise ValueError(f"Model {model_name} not found in embeddings file. Available: {data.files}")

    # Load chunk IDs
    print(f"📂 Loading chunk IDs from {chunk_ids_path}...")
    with open(chunk_ids_path, 'r') as f:
        chunk_ids_data = json.load(f)
        if model_name in chunk_ids_data:
            chunk_ids = chunk_ids_data[model_name]
            print(f"✅ Loaded {len(chunk_ids)} chunk IDs for {model_name}")
        else:
            raise ValueError(f"Model {model_name} not found in chunk_ids file.")

    if len(chunk_ids) != embeddings.shape[0]:
        raise ValueError(f"Mismatch between chunk IDs count ({len(chunk_ids)}) and embeddings rows ({embeddings.shape[0]})")

    return embeddings, chunk_ids
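
# A minimal sketch of the expected on-disk layout, assuming one array per
# model key in the .npz and a parallel ID list in chunk_ids.json:
#
#   embeddings_dict.npz  -> {"nomic": np.ndarray of shape (n_chunks, dim)}
#   chunk_ids.json       -> {"nomic": ["0", "1", ..., "n-1"]}
#
#   emb, ids = load_precomputed_embeddings("nomic")
#   assert emb.shape[0] == len(ids)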

def align_chunks_with_embeddings(chunks: List[Dict], chunk_ids: List[str]) -> Tuple[List[Dict], List[int]]:
    """
    Filter and order chunks to match the sequence of pre-computed embeddings.
    Returns:
        aligned_chunks: List of chunks found
        valid_indices: Indices of the embeddings that correspond to the found chunks
    """
    print("🔄 Aligning chunks with pre-computed embeddings...")

    # Create map of chunk_id -> chunk for O(1) lookup
    chunk_map = {}
    for c in chunks:
        c_id = str(c.get('chunk_id', ''))
        chunk_map[c_id] = c

    aligned_chunks = []
    valid_indices = []
    missing_ids = []

    for idx, target_id in enumerate(chunk_ids):
        target_id = str(target_id)
        if target_id in chunk_map:
            aligned_chunks.append(chunk_map[target_id])
            valid_indices.append(idx)
        else:
            missing_ids.append(target_id)

    if missing_ids:
        print(f"⚠️ Warning: {len(missing_ids)} chunks from embeddings not found in source file.")
        if len(missing_ids) < 10:
            print(f"   Missing IDs: {missing_ids}")

    print(f"✅ Aligned {len(aligned_chunks)} chunks.")
    return aligned_chunks, valid_indices
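
# Example (hypothetical data): IDs missing from the source file are dropped,
# and valid_indices selects the matching embedding rows.
#
#   chunks = [{"chunk_id": "a", "content": "..."}, {"chunk_id": "c", "content": "..."}]
#   aligned, idx = align_chunks_with_embeddings(chunks, ["a", "b", "c"])
#   # aligned -> chunks "a" and "c"; idx -> [0, 2]; "b" reported as missing,
#   # so embeddings[idx] stays row-aligned with `aligned`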

def get_embeddings_multimodal(chunks: List[Dict], embed_model: NomicVLEmbed) -> np.ndarray:
    """
    Generate multimodal embeddings for chunks using NomicVLEmbed.
    Handles text and optional images (from 'artifact' field).
    """
    print(f"🖼️ Generating Multimodal Embeddings for {len(chunks)} chunks...")

    embeddings = []

    for i, chunk in enumerate(chunks):
        text = chunk.get('content', '')
        artifact = chunk.get('artifact', 'None')
        image_path = None

        # Try to parse image path from artifact string if present
        if artifact and artifact != 'None':
            if "](" in artifact:
                start = artifact.find("](") + 2
                end = artifact.find(")", start)
                rel_path = artifact[start:end]
                image_path = os.path.join(IMAGE_BASE_DIR, rel_path.lstrip('/'))
            else:
                pass

        # Verify image existence
        if image_path and not os.path.exists(image_path):
            image_path = None

        # Generate embedding
        try:
            emb = embed_model.embed_multimodal(text, image_path)
            if isinstance(emb, torch.Tensor):
                emb = emb.cpu().float().numpy()
            embeddings.append(emb)
        except Exception as e:
            print(f"❌ Error embedding chunk {i}: {e}")
            # Fall back to text-only encoding; as a last resort use a zero vector
            try:
                emb = embed_model.encode(text, convert_to_numpy=True)
                if isinstance(emb, torch.Tensor):
                    emb = emb.cpu().float().numpy()
                embeddings.append(emb)
            except Exception:
                embeddings.append(np.zeros(768))

        if (i + 1) % 10 == 0:
            print(f"   Processed {i+1}/{len(chunks)}")

    return np.vstack(embeddings)
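
# Example of the artifact parsing above, assuming the 'artifact' field holds
# a Markdown image reference (this format is inferred from the "](" split):
#
#   artifact = "![Figure 3](figures/fig3.png)"
#   start = artifact.find("](") + 2
#   rel_path = artifact[start:artifact.find(")", start)]
#   # rel_path -> "figures/fig3.png", then joined under IMAGE_BASE_DIR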

def get_embeddings_text_only(chunks: List[Dict], model_name: str = "BAAI/bge-m3") -> np.ndarray:
    """
    Generate text-only embeddings using SentenceTransformer (BGE-M3).
    """
    print(f"📝 Generating Text-Only Embeddings using {model_name}...")

    docs = [c.get('content', '') for c in chunks]

    print(f"   Loading model: {model_name}")
    model = SentenceTransformer(model_name)

    print(f"   Encoding {len(docs)} documents...")
    embeddings = model.encode(docs, show_progress_bar=True)

    return embeddings
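
# Equivalent one-liner for quick experiments (same SentenceTransformer API;
# the smaller model name here is a stand-in, not used by this module):
#
#   vecs = SentenceTransformer("all-MiniLM-L6-v2").encode(docs, show_progress_bar=True)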

def get_domain_model(chunks: List[Dict], embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[BERTopic, List[str], np.ndarray]:
    """
    Train BERTopic model using embeddings (pre-calculated or computed on-the-fly).
    Returns the fitted model, the document texts, and the embeddings used.
    """
    print(f"🚀 Starting Topic Modeling on {len(chunks)} chunks...")

    # Extract text content for BERTopic (it still needs docs for representation)
    docs = [c.get('content', '') for c in chunks]

    # 1. Get Embeddings
    if embeddings is not None:
        # Convert torch.Tensor to numpy if needed (BERTopic requires numpy)
        if isinstance(embeddings, torch.Tensor):
            print("✅ Using pre-computed embeddings (converting GPU tensor to numpy for BERTopic)")
            embeddings = embeddings.cpu().float().numpy()
        else:
            print("✅ Using pre-computed embeddings")
    else:
        if USE_MULTIMODAL_EMBEDDINGS:
            print("   Mode: Multimodal (NomicVLEmbed) - COMPUTING ON THE FLY")
            # Try to use cached embedder from main.py if available
            try:
                import sys
                if 'main' in sys.modules and hasattr(sys.modules['main'], '_MODEL_CACHE'):
                    cache = sys.modules['main']._MODEL_CACHE
                    if cache.get('nomic_embedder') is not None:
                        embedder = cache['nomic_embedder']
                        print("   Using cached NomicVLEmbed model (no reload needed)")
                    else:
                        embedder = NomicVLEmbed()
                else:
                    embedder = NomicVLEmbed()
            except Exception:
                embedder = NomicVLEmbed()
            embeddings = get_embeddings_multimodal(chunks, embedder)
        else:
            print("   Mode: Text-Only (BGE-M3) - COMPUTING ON THE FLY")
            embeddings = get_embeddings_text_only(chunks, model_name="BAAI/bge-m3")

    print(f"✅ Embeddings shape: {embeddings.shape}")
    print(f"✅ Docs length: {len(docs)}")

    if embeddings.shape[0] != len(docs):
        raise ValueError(f"Mismatch: Embeddings rows ({embeddings.shape[0]}) != Docs length ({len(docs)})")

    # 2. Prevent Stochastic Behavior & Dimensionality Reduction
    umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', random_state=42)

    # 3. Improve Default Representation
    vectorizer_model = CountVectorizer(stop_words="english", min_df=2, ngram_range=(1, 2))

    # 4. Maximal Marginal Relevance (MMR) for better topic diversity
    representation_model = MaximalMarginalRelevance(diversity=0.5)

    # Initialize BERTopic
    topic_model = BERTopic(
        umap_model=umap_model,
        vectorizer_model=vectorizer_model,
        representation_model=representation_model,
        calculate_probabilities=False,
        verbose=True
    )

    # Fit the model using pre-calculated embeddings
    topics, probs = topic_model.fit_transform(docs, embeddings)

    print(f"✅ Topic Modeling Complete. Found {len(topic_model.get_topic_info()) - 1} topics.")
    return topic_model, docs, embeddings
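
# Minimal standalone sketch of the same pattern, assuming precomputed
# vectors: BERTopic skips its internal embedding step when an array is
# passed to fit_transform alongside the raw documents.
#
#   docs = ["chunk one ...", "chunk two ...", "chunk three ..."]
#   vecs = np.random.rand(len(docs), 768).astype(np.float32)  # placeholder
#   model = BERTopic(umap_model=UMAP(n_components=5, random_state=42))
#   topics, _ = model.fit_transform(docs, vecs)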

def query_llm_for_domain(topic_model: BERTopic) -> Tuple[str, str]:
    """
    Extract domain and expert role using LLM based on topics.
    """
    print("🤖 Querying LLM for Domain and Role...")

    # Get top topics info
    topic_info = topic_model.get_topic_info()
    total_count = topic_info['Count'].sum()

    print("\n📊 TOPIC SUMMARY:")
    print(f"{'ID':<6} {'Count':<8} {'Freq':<8} {'Keywords'}")
    print("-" * 100)

    # Filter out the outlier topic (-1) and take the top 15
    top_topics = topic_info[topic_info['Topic'] != -1].head(15)

    # Format topics for prompt
    topic_list_str = ""
    for _, row in top_topics.iterrows():
        # Representation is a list of keywords
        keywords = ", ".join(row['Representation'][:5])
        topic_list_str += f"- Topic {row['Topic']} (Count: {row['Count']}): {keywords}\n"
        print(f"{row['Topic']:<6} {row['Count']:<8} {row['Count']/total_count:<8.1%} {keywords}")

    print("-" * 100 + "\n")

    # Use prompt from prompts.py
    prompt = PROMPTS["domain_and_expert_from_topics"].format(topic_list_str=topic_list_str)

    # Call LLM
    response = call_llm_simple(prompt)

    # Parse response
    domain = "Unknown"
    role = "Expert"

    # Parse delimiter-based format:
    # <|#|>START<|#|>
    # <|#|>Domain: <The Domain>
    # <|#|>Expert Role: <The Expert Role>
    # <|#|>END<|#|>

    if '<|#|>' in response:
        parts = response.split('<|#|>')
        for part in parts:
            part = part.strip()
            if part.lower().startswith("domain:"):
                domain = part.split(":", 1)[1].strip()
            elif part.lower().startswith("expert role:"):
                role = part.split(":", 1)[1].strip()
    else:
        # Fallback for line-based format if delimiters not found
        lines = [l.strip() for l in response.split('\n') if l.strip()]
        for line in lines:
            if line.lower().startswith("domain:"):
                domain = line.split(":", 1)[1].strip()
            elif line.lower().startswith("expert role:"):
                role = line.split(":", 1)[1].strip()

    # Clean up domain string if it contains multiple lines (fix for double printing issue)
    if "\n" in domain:
        domain = domain.split("\n")[0].strip()

    # Clean up role string if it contains multiple lines or a repeated "Expert Role:" prefix
    if "\n" in role:
        role = role.split("\n")[0].strip()
    if role.lower().startswith("expert role:"):
        role = role.split(":", 1)[1].strip()

    return domain, role
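
# Example of a response in the delimiter format parsed above
# (hypothetical LLM output):
#
#   response = (
#       "<|#|>START<|#|>"
#       "<|#|>Domain: Renewable Energy Systems"
#       "<|#|>Expert Role: Power Electronics Engineer"
#       "<|#|>END<|#|>"
#   )
#   # split on '<|#|>' yields parts like 'Domain: Renewable Energy Systems',
#   # which the startswith("domain:") branch strips to the value.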

def visualize_results(topic_model: BERTopic, docs: List[str], output_dir: str, embeddings: Optional[np.ndarray] = None, generate_plots: bool = False):
    """
    Generate and save visualizations.
    """
    print(f"📊 Generating visualizations in {output_dir}...")
    os.makedirs(output_dir, exist_ok=True)

    if generate_plots:
        # 1. Visualize Topics (Distance Map)
        try:
            fig_topics = topic_model.visualize_topics()
            fig_topics.write_html(os.path.join(output_dir, "topics_distance_map.html"))
            print("   - Saved topics_distance_map.html")
        except Exception as e:
            print(f"   ⚠️ Could not generate topic visualization: {e}")

        # 2. Visualize Hierarchy
        try:
            fig_hierarchy = topic_model.visualize_hierarchy()
            fig_hierarchy.write_html(os.path.join(output_dir, "topics_hierarchy.html"))
            print("   - Saved topics_hierarchy.html")
        except Exception as e:
            print(f"   ⚠️ Could not generate hierarchy visualization: {e}")

        # 3. Visualize Document Datamap (Requires datamapplot and 2D embeddings)
        if embeddings is not None:
            try:
                print("   - Calculating 2D UMAP for Datamap...")
                # Reduce dimensionality of embeddings to 2D for visualization
                # User-specified params: n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine'
                umap_2d = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine', random_state=42)
                reduced_embeddings = umap_2d.fit_transform(embeddings)

                print("   - Generating Datamap...")
                fig = topic_model.visualize_document_datamap(
                    docs,
                    reduced_embeddings=reduced_embeddings,
                    interactive=True
                )

                # Handle datamapplot figure save
                try:
                    fig.save(os.path.join(output_dir, "document_datamap.html"))
                except AttributeError:
                    # Fallback for Plotly figure if return type changes
                    fig.write_html(os.path.join(output_dir, "document_datamap.html"))

                print("   - Saved document_datamap.html")
            except Exception as e:
                print(f"   ⚠️ Could not generate document datamap: {e}")
    else:
        print("   - Plot generation skipped (enable with visualization=True)")

    # 4. Save Topic Info CSV (always saved)
    topic_info = topic_model.get_topic_info()
    topic_info.to_csv(os.path.join(output_dir, "topic_info.csv"), index=False)
    print("   - Saved topic_info.csv")

def fetch_domain_and_role(chunks_file: str = DEFAULT_CHUNKS_FILE,
                          embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None,
                          chunk_ids: Optional[List[str]] = None) -> Tuple[str, str]:
    """
    Wrapper to load chunks, run model, and return domain/role.

    Args:
        chunks_file: Path to chunks JSON file
        embeddings: Pre-computed embeddings (torch.Tensor on GPU or np.ndarray, optional)
        chunk_ids: List of chunk IDs matching embeddings (optional)
    """
    if not os.path.exists(chunks_file):
        # Try to look in current dir
        local_file = os.path.basename(chunks_file)
        if os.path.exists(local_file):
            chunks_file = local_file
        else:
            print(f"❌ Chunks file not found: {chunks_file}")
            return "Unknown", "Expert"

    print(f"📂 Loading chunks from {chunks_file}...")
    with open(chunks_file, 'r') as f:
        chunks = json.load(f)

    # Use provided embeddings if available (from main.py pipeline)
    if embeddings is not None and chunk_ids is not None:
        print("✅ Using embeddings provided from main pipeline")
        # Convert torch.Tensor to numpy if needed (BERTopic requires numpy)
        if isinstance(embeddings, torch.Tensor):
            print("   Converting GPU tensor to numpy for BERTopic")
            embeddings = embeddings.cpu().float().numpy()
        # Align chunks to embeddings
        chunks, valid_indices = align_chunks_with_embeddings(chunks, chunk_ids)
        # Filter embeddings to match found chunks
        if len(valid_indices) != embeddings.shape[0]:
            print(f"⚠️ Filtering embeddings from {embeddings.shape[0]} to {len(valid_indices)} rows")
            embeddings = embeddings[valid_indices]
    else:
        # Fallback: Try to load pre-computed embeddings from .npz (legacy support)
        try:
            embeddings, chunk_ids = load_precomputed_embeddings()
            # Align chunks to embeddings
            chunks, valid_indices = align_chunks_with_embeddings(chunks, chunk_ids)
            # Filter embeddings to match found chunks
            if len(valid_indices) != embeddings.shape[0]:
                print(f"⚠️ Filtering embeddings from {embeddings.shape[0]} to {len(valid_indices)} rows")
                embeddings = embeddings[valid_indices]
        except Exception as e:
            print(f"⚠️ Could not load/align pre-computed embeddings: {e}")
            print("   Falling back to on-the-fly computation (slower)")
            embeddings = None

    topic_model, _, _ = get_domain_model(chunks, embeddings=embeddings)
    domain, role = query_llm_for_domain(topic_model)

    # Save to environment variables
    save_domain_expert_to_env(domain, role)

    return domain, role
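
# Typical call from the pipeline (hypothetical variables): a GPU tensor is
# accepted and converted internally before BERTopic sees it.
#
#   domain, role = fetch_domain_and_role(
#       chunks_file="output/results/chunks.json",
#       embeddings=gpu_tensor,   # torch.Tensor or np.ndarray
#       chunk_ids=ids,           # same order as the tensor rows
#   )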

def main(visualization: bool = False):
    setup_logging()

    # Load Chunks
    chunks_file = DEFAULT_CHUNKS_FILE
    if not os.path.exists(chunks_file):
        print(f"❌ Chunks file not found: {chunks_file}")
        # Try to look in current dir
        local_file = "chunks.json"
        if os.path.exists(local_file):
            chunks_file = local_file
        else:
            return

    print(f"📂 Loading chunks from {chunks_file}...")
    with open(chunks_file, 'r') as f:
        chunks = json.load(f)

    # Try to load pre-computed embeddings
    try:
        embeddings, chunk_ids = load_precomputed_embeddings()
        # Align chunks to embeddings
        chunks, valid_indices = align_chunks_with_embeddings(chunks, chunk_ids)
        # Filter embeddings to match found chunks
        if len(valid_indices) != embeddings.shape[0]:
            print(f"⚠️ Filtering embeddings from {embeddings.shape[0]} to {len(valid_indices)} rows")
            embeddings = embeddings[valid_indices]
    except Exception as e:
        print(f"⚠️ Could not load/align pre-computed embeddings: {e}")
        import traceback
        traceback.print_exc()
        print("   Falling back to on-the-fly computation (slower)")
        embeddings = None

    # Run Topic Modeling
    topic_model, docs, embeddings = get_domain_model(chunks, embeddings=embeddings)

    # Get Domain and Role
    domain, role = query_llm_for_domain(topic_model)

    # Save to environment variables
    save_domain_expert_to_env(domain, role)

    print("\n" + "=" * 50)
    print("🎯 RESULTS")
    print(f"Domain: {domain}")
    print(f"Expert Role: {role}")
    print("=" * 50 + "\n")

    # Visualization
    visualize_results(topic_model, docs, OUTPUT_DIR, embeddings=embeddings, generate_plots=visualization)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Extract domain and expert role from semantic chunks")
    parser.add_argument("--vis", action="store_true", help="Enable visualization generation")
    args = parser.parse_args()

    main(visualization=args.vis)
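
# Example invocation (assuming the package is importable on PYTHONPATH):
#   python -m mirage.pipeline.domain --vis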