mirage-benchmark 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of mirage-benchmark might be problematic.

@@ -0,0 +1,514 @@
+ """
+ Module to extract domain and expert role from semantic chunks using BERTopic and LLM.
+ """
+
+ import json
+ import os
+ import sys
+ import logging
+ from typing import List, Dict, Tuple, Optional, Union
+
+ import torch
+ import pandas as pd
+ import numpy as np
+ from umap import UMAP
+ from bertopic import BERTopic
+ from bertopic.representation import MaximalMarginalRelevance
+ from sklearn.feature_extraction.text import CountVectorizer
+ from PIL import Image
+ from sentence_transformers import SentenceTransformer
+ import datamapplot
+
+ # Import from mirage modules
+ from mirage.core.llm import call_llm_simple, setup_logging
+ from mirage.embeddings.models import NomicVLEmbed
+ from mirage.core.prompts import PROMPTS
+
+ #%% Setup
+
+ def save_domain_expert_to_env(domain: str, expert_persona: str):
+     """Save domain and expert persona as environment variables."""
+     os.environ['DATASET_DOMAIN'] = domain
+     os.environ['DATASET_EXPERT_PERSONA'] = expert_persona
+     print(f"💾 Saved to environment: DATASET_DOMAIN={domain}, DATASET_EXPERT_PERSONA={expert_persona}")
+
+ def load_domain_expert_from_env() -> Tuple[Optional[str], Optional[str]]:
+     """Load domain and expert persona from environment variables, or (None, None) if unset."""
+     domain = os.environ.get('DATASET_DOMAIN')
+     expert_persona = os.environ.get('DATASET_EXPERT_PERSONA')
+     if domain and expert_persona:
+         print(f"📥 Loaded from environment: domain={domain}, expert_persona={expert_persona}")
+         return domain, expert_persona
+     return None, None
+
+ # Configuration (override via config.yaml)
+ DEFAULT_CHUNKS_FILE = "output/results/chunks.json"
+ OUTPUT_DIR = "output/domain_analysis"
+ # Directory containing images referenced in chunks (update as needed)
+ IMAGE_BASE_DIR = "output/results/markdown"
+ # Directory containing pre-computed embeddings
+ EMBEDDINGS_DIR = "output/results/embeddings"
+
+ # Embedding mode configuration
+ # Set to False to use BGE-M3 (text only), True to use NomicVLEmbed (multimodal)
+ USE_MULTIMODAL_EMBEDDINGS = True
+
+ #%% Load Pre-computed Embeddings
+
+ def load_precomputed_embeddings(model_name: str = "nomic") -> Tuple[np.ndarray, List[str]]:
+     """
+     Load pre-computed embeddings from a .npz file and the corresponding chunk IDs.
+
+     Returns:
+         Tuple of (embeddings_array, chunk_ids_list)
+     """
+     embeddings_path = os.path.join(EMBEDDINGS_DIR, "embeddings_dict.npz")
+     chunk_ids_path = os.path.join(EMBEDDINGS_DIR, "chunk_ids.json")
+
+     if not os.path.exists(embeddings_path):
+         raise FileNotFoundError(f"Embeddings file not found: {embeddings_path}")
+     if not os.path.exists(chunk_ids_path):
+         raise FileNotFoundError(f"Chunk IDs file not found: {chunk_ids_path}")
+
+     # Load embeddings
+     print(f"📂 Loading pre-computed embeddings from {embeddings_path}...")
+     with np.load(embeddings_path) as data:
+         if model_name in data:
+             embeddings = data[model_name]
+             print(f"✅ Loaded embeddings for {model_name}: {embeddings.shape}")
+         else:
+             raise ValueError(f"Model {model_name} not found in embeddings file. Available: {data.files}")
+
+     # Load chunk IDs
+     print(f"📂 Loading chunk IDs from {chunk_ids_path}...")
+     with open(chunk_ids_path, 'r') as f:
+         chunk_ids_data = json.load(f)
+     if model_name in chunk_ids_data:
+         chunk_ids = chunk_ids_data[model_name]
+         print(f"✅ Loaded {len(chunk_ids)} chunk IDs for {model_name}")
+     else:
+         raise ValueError(f"Model {model_name} not found in chunk_ids file.")
+
+     if len(chunk_ids) != embeddings.shape[0]:
+         raise ValueError(f"Mismatch between chunk IDs count ({len(chunk_ids)}) and embeddings rows ({embeddings.shape[0]})")
+
+     return embeddings, chunk_ids
+
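+ # For reference, the on-disk layout this loader expects (illustrative sketch only,
+ # not code shipped in the package): embeddings_dict.npz maps a model name to an
+ # (N, D) float array, and chunk_ids.json maps the same name to its N chunk IDs.
+ #
+ #     np.savez(os.path.join(EMBEDDINGS_DIR, "embeddings_dict.npz"), nomic=emb_array)
+ #     with open(os.path.join(EMBEDDINGS_DIR, "chunk_ids.json"), "w") as f:
+ #         json.dump({"nomic": [str(c["chunk_id"]) for c in chunks]}, f)
+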
+ def align_chunks_with_embeddings(chunks: List[Dict], chunk_ids: List[str]) -> Tuple[List[Dict], List[int]]:
+     """
+     Filter and order chunks to match the sequence of pre-computed embeddings.
+
+     Returns:
+         aligned_chunks: List of chunks found
+         valid_indices: Indices of the embeddings that correspond to the found chunks
+     """
+     print("🔄 Aligning chunks with pre-computed embeddings...")
+
+     # Create map of chunk_id -> chunk for O(1) lookup
+     chunk_map = {}
+     for c in chunks:
+         c_id = str(c.get('chunk_id', ''))
+         chunk_map[c_id] = c
+
+     aligned_chunks = []
+     valid_indices = []
+     missing_ids = []
+
+     for idx, target_id in enumerate(chunk_ids):
+         target_id = str(target_id)
+         if target_id in chunk_map:
+             aligned_chunks.append(chunk_map[target_id])
+             valid_indices.append(idx)
+         else:
+             missing_ids.append(target_id)
+
+     if missing_ids:
+         print(f"⚠️ Warning: {len(missing_ids)} chunks from embeddings not found in source file.")
+         if len(missing_ids) < 10:
+             print(f"   Missing IDs: {missing_ids}")
+
+     print(f"✅ Aligned {len(aligned_chunks)} chunks.")
+     return aligned_chunks, valid_indices
+
+ def get_embeddings_multimodal(chunks: List[Dict], embed_model: NomicVLEmbed) -> np.ndarray:
+     """
+     Generate multimodal embeddings for chunks using NomicVLEmbed.
+     Handles text and optional images (from the 'artifact' field).
+     """
+     print(f"🖼️ Generating Multimodal Embeddings for {len(chunks)} chunks...")
+
+     embeddings = []
+
+     for i, chunk in enumerate(chunks):
+         text = chunk.get('content', '')
+         artifact = chunk.get('artifact', 'None')
+         image_path = None
+
+         # Try to parse an image path from a markdown image link in the artifact string
+         if artifact and artifact != 'None':
+             if "![" in artifact and "](" in artifact:
+                 start = artifact.find("](") + 2
+                 end = artifact.find(")", start)
+                 rel_path = artifact[start:end]
+                 image_path = os.path.join(IMAGE_BASE_DIR, rel_path.lstrip('/'))
+
+         # Discard the path if the image does not exist on disk
+         if image_path and not os.path.exists(image_path):
+             image_path = None
+
+         # Generate embedding; fall back to text-only, then to a zero vector
+         try:
+             emb = embed_model.embed_multimodal(text, image_path)
+             if isinstance(emb, torch.Tensor):
+                 emb = emb.cpu().float().numpy()
+             embeddings.append(emb)
+         except Exception as e:
+             print(f"❌ Error embedding chunk {i}: {e}")
+             try:
+                 emb = embed_model.encode(text, convert_to_numpy=True)
+                 if isinstance(emb, torch.Tensor):
+                     emb = emb.cpu().float().numpy()
+                 embeddings.append(emb)
+             except Exception:
+                 embeddings.append(np.zeros(768))  # assumes 768-dim embeddings
+
+         if (i + 1) % 10 == 0:
+             print(f"   Processed {i + 1}/{len(chunks)}")
+
+     return np.vstack(embeddings)
+
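+ # Illustrative 'artifact' value for the parsing above (assumed format, matching a
+ # markdown image link): "Figure 3 ![fig](images/page3_fig1.png)" resolves to
+ # IMAGE_BASE_DIR/images/page3_fig1.png, and is used only if that file exists.
+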
+ def get_embeddings_text_only(chunks: List[Dict], model_name: str = "BAAI/bge-m3") -> np.ndarray:
+     """
+     Generate text-only embeddings using SentenceTransformer (BGE-M3).
+     """
+     print(f"📝 Generating Text-Only Embeddings using {model_name}...")
+
+     docs = [c.get('content', '') for c in chunks]
+
+     print(f"   Loading model: {model_name}")
+     model = SentenceTransformer(model_name)
+
+     print(f"   Encoding {len(docs)} documents...")
+     embeddings = model.encode(docs, show_progress_bar=True)
+
+     return embeddings
+
+ def get_domain_model(chunks: List[Dict], embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None) -> Tuple[BERTopic, List[str], np.ndarray]:
+     """
+     Train a BERTopic model using embeddings (pre-calculated or computed on the fly).
+     Returns the fitted model, the documents, and the embeddings used.
+     """
+     print(f"🚀 Starting Topic Modeling on {len(chunks)} chunks...")
+
+     # Extract text content for BERTopic (it still needs docs for representation)
+     docs = [c.get('content', '') for c in chunks]
+
+     # 1. Get embeddings
+     if embeddings is not None:
+         # Convert torch.Tensor to numpy if needed (BERTopic requires numpy)
+         if isinstance(embeddings, torch.Tensor):
+             print("✅ Using pre-computed embeddings (converting GPU tensor to numpy for BERTopic)")
+             embeddings = embeddings.cpu().float().numpy()
+         else:
+             print("✅ Using pre-computed embeddings")
+     else:
+         if USE_MULTIMODAL_EMBEDDINGS:
+             print("   Mode: Multimodal (NomicVLEmbed) - COMPUTING ON THE FLY")
+             # Try to reuse the embedder cached by main.py, if available
+             try:
+                 if 'main' in sys.modules and hasattr(sys.modules['main'], '_MODEL_CACHE'):
+                     cache = sys.modules['main']._MODEL_CACHE
+                     if cache.get('nomic_embedder') is not None:
+                         embedder = cache['nomic_embedder']
+                         print("   Using cached NomicVLEmbed model (no reload needed)")
+                     else:
+                         embedder = NomicVLEmbed()
+                 else:
+                     embedder = NomicVLEmbed()
+             except Exception:
+                 embedder = NomicVLEmbed()
+             embeddings = get_embeddings_multimodal(chunks, embedder)
+         else:
+             print("   Mode: Text-Only (BGE-M3) - COMPUTING ON THE FLY")
+             embeddings = get_embeddings_text_only(chunks, model_name="BAAI/bge-m3")
+
+     print(f"✅ Embeddings shape: {embeddings.shape}")
+     print(f"✅ Docs length: {len(docs)}")
+
+     if embeddings.shape[0] != len(docs):
+         raise ValueError(f"Mismatch: Embeddings rows ({embeddings.shape[0]}) != Docs length ({len(docs)})")
+
+     # 2. Fix random_state to prevent stochastic behavior; reduce dimensionality to 5
+     umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', random_state=42)
+
+     # 3. Improve the default representation (bigrams, minimum document frequency)
+     vectorizer_model = CountVectorizer(stop_words="english", min_df=2, ngram_range=(1, 2))
+
+     # 4. Maximal Marginal Relevance (MMR) for better topic keyword diversity
+     representation_model = MaximalMarginalRelevance(diversity=0.5)
+
+     # Initialize BERTopic
+     topic_model = BERTopic(
+         umap_model=umap_model,
+         vectorizer_model=vectorizer_model,
+         representation_model=representation_model,
+         calculate_probabilities=False,
+         verbose=True
+     )
+
+     # Fit the model using the pre-calculated embeddings
+     topics, probs = topic_model.fit_transform(docs, embeddings)
+
+     print(f"✅ Topic Modeling Complete. Found {len(topic_model.get_topic_info()) - 1} topics.")
+     return topic_model, docs, embeddings
+
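+ # Illustrative call (names are placeholders, not package API): given `chunks`
+ # loaded from chunks.json and an aligned (N, D) numpy array `emb`:
+ #     topic_model, docs, emb = get_domain_model(chunks, embeddings=emb)
+ #     print(topic_model.get_topic_info().head())
+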
+ def query_llm_for_domain(topic_model: BERTopic) -> Tuple[str, str]:
+     """
+     Extract domain and expert role using an LLM, based on the discovered topics.
+     """
+     print("🤖 Querying LLM for Domain and Role...")
+
+     # Get top topics info
+     topic_info = topic_model.get_topic_info()
+     total_count = topic_info['Count'].sum()
+
+     print("\n📊 TOPIC SUMMARY:")
+     print(f"{'ID':<6} {'Count':<8} {'Freq':<8} {'Keywords'}")
+     print("-" * 100)
+
+     # Filter out the outlier topic (-1) and take the top 15
+     top_topics = topic_info[topic_info['Topic'] != -1].head(15)
+
+     # Format topics for the prompt
+     topic_list_str = ""
+     for _, row in top_topics.iterrows():
+         # Representation is a list of keywords
+         keywords = ", ".join(row['Representation'][:5])
+         topic_list_str += f"- Topic {row['Topic']} (Count: {row['Count']}): {keywords}\n"
+         print(f"{row['Topic']:<6} {row['Count']:<8} {row['Count']/total_count:<8.1%} {keywords}")
+
+     print("-" * 100 + "\n")
+
+     # Use the prompt template from mirage.core.prompts
+     prompt = PROMPTS["domain_and_expert_from_topics"].format(topic_list_str=topic_list_str)
+
+     # Call LLM
+     response = call_llm_simple(prompt)
+
+     # Parse response
+     domain = "Unknown"
+     role = "Expert"
+
+     # Expected delimiter-based format:
+     # <|#|>START<|#|>
+     # <|#|>Domain: <The Domain>
+     # <|#|>Expert Role: <The Expert Role>
+     # <|#|>END<|#|>
+
+     if '<|#|>' in response:
+         parts = response.split('<|#|>')
+         for part in parts:
+             part = part.strip()
+             if part.lower().startswith("domain:"):
+                 domain = part.split(":", 1)[1].strip()
+             elif part.lower().startswith("expert role:"):
+                 role = part.split(":", 1)[1].strip()
+     else:
+         # Fallback to a line-based format if the delimiters are not found
+         lines = [l.strip() for l in response.split('\n') if l.strip()]
+         for line in lines:
+             if line.lower().startswith("domain:"):
+                 domain = line.split(":", 1)[1].strip()
+             elif line.lower().startswith("expert role:"):
+                 role = line.split(":", 1)[1].strip()
+
+     # Clean up the domain string if it contains multiple lines (fix for double printing issue)
+     if "\n" in domain:
+         domain = domain.split("\n")[0].strip()
+
+     # Clean up the role string if it contains multiple lines or a repeated "Expert Role:" prefix
+     if "\n" in role:
+         role = role.split("\n")[0].strip()
+     if role.lower().startswith("expert role:"):
+         role = role.split(":", 1)[1].strip()
+
+     return domain, role
+
+ def visualize_results(topic_model: BERTopic, docs: List[str], output_dir: str, embeddings: Optional[np.ndarray] = None, generate_plots: bool = False):
+     """
+     Generate and save visualizations.
+     """
+     print(f"📊 Generating visualizations in {output_dir}...")
+     os.makedirs(output_dir, exist_ok=True)
+
+     if generate_plots:
+         # 1. Visualize topics (distance map)
+         try:
+             fig_topics = topic_model.visualize_topics()
+             fig_topics.write_html(os.path.join(output_dir, "topics_distance_map.html"))
+             print("   - Saved topics_distance_map.html")
+         except Exception as e:
+             print(f"   ⚠️ Could not generate topic visualization: {e}")
+
+         # 2. Visualize hierarchy
+         try:
+             fig_hierarchy = topic_model.visualize_hierarchy()
+             fig_hierarchy.write_html(os.path.join(output_dir, "topics_hierarchy.html"))
+             print("   - Saved topics_hierarchy.html")
+         except Exception as e:
+             print(f"   ⚠️ Could not generate hierarchy visualization: {e}")
+
+         # 3. Visualize document datamap (requires datamapplot and 2D embeddings)
+         if embeddings is not None:
+             try:
+                 print("   - Calculating 2D UMAP for Datamap...")
+                 # Reduce embeddings to 2D for visualization
+                 umap_2d = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine', random_state=42)
+                 reduced_embeddings = umap_2d.fit_transform(embeddings)
+
+                 print("   - Generating Datamap...")
+                 fig = topic_model.visualize_document_datamap(
+                     docs,
+                     reduced_embeddings=reduced_embeddings,
+                     interactive=True
+                 )
+
+                 # Handle datamapplot figure save
+                 try:
+                     fig.save(os.path.join(output_dir, "document_datamap.html"))
+                 except AttributeError:
+                     # Fallback for a Plotly figure if the return type changes
+                     fig.write_html(os.path.join(output_dir, "document_datamap.html"))
+
+                 print("   - Saved document_datamap.html")
+             except Exception as e:
+                 print(f"   ⚠️ Could not generate document datamap: {e}")
+     else:
+         print("   - Plot generation skipped (enable with visualization=True)")
+
+     # 4. Save topic info CSV (always saved)
+     topic_info = topic_model.get_topic_info()
+     topic_info.to_csv(os.path.join(output_dir, "topic_info.csv"), index=False)
+     print("   - Saved topic_info.csv")
+
+ def fetch_domain_and_role(chunks_file: str = DEFAULT_CHUNKS_FILE,
+                           embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None,
+                           chunk_ids: Optional[List[str]] = None) -> Tuple[str, str]:
+     """
+     Wrapper to load chunks, run the model, and return domain/role.
+
+     Args:
+         chunks_file: Path to the chunks JSON file
+         embeddings: Pre-computed embeddings (torch.Tensor on GPU or np.ndarray, optional)
+         chunk_ids: List of chunk IDs matching the embeddings (optional)
+     """
+     if not os.path.exists(chunks_file):
+         # Try the current directory as a fallback
+         local_file = os.path.basename(chunks_file)
+         if os.path.exists(local_file):
+             chunks_file = local_file
+         else:
+             print(f"❌ Chunks file not found: {chunks_file}")
+             return "Unknown", "Expert"
+
+     print(f"📂 Loading chunks from {chunks_file}...")
+     with open(chunks_file, 'r') as f:
+         chunks = json.load(f)
+
+     # Use provided embeddings if available (from the main.py pipeline)
+     if embeddings is not None and chunk_ids is not None:
+         print("✅ Using embeddings provided from main pipeline")
+         # Convert torch.Tensor to numpy if needed (BERTopic requires numpy)
+         if isinstance(embeddings, torch.Tensor):
+             print("   Converting GPU tensor to numpy for BERTopic")
+             embeddings = embeddings.cpu().float().numpy()
+         # Align chunks to embeddings
+         chunks, valid_indices = align_chunks_with_embeddings(chunks, chunk_ids)
+         # Filter embeddings to match the found chunks
+         if len(valid_indices) != embeddings.shape[0]:
+             print(f"⚠️ Filtering embeddings from {embeddings.shape[0]} to {len(valid_indices)} rows")
+             embeddings = embeddings[valid_indices]
+     else:
+         # Fallback: try to load pre-computed embeddings from .npz (legacy support)
+         try:
+             embeddings, chunk_ids = load_precomputed_embeddings()
+             # Align chunks to embeddings
+             chunks, valid_indices = align_chunks_with_embeddings(chunks, chunk_ids)
+             # Filter embeddings to match the found chunks
+             if len(valid_indices) != embeddings.shape[0]:
+                 print(f"⚠️ Filtering embeddings from {embeddings.shape[0]} to {len(valid_indices)} rows")
+                 embeddings = embeddings[valid_indices]
+         except Exception as e:
+             print(f"⚠️ Could not load/align pre-computed embeddings: {e}")
+             print("   Falling back to on-the-fly computation (slower)")
+             embeddings = None
+
+     topic_model, _, _ = get_domain_model(chunks, embeddings=embeddings)
+     domain, role = query_llm_for_domain(topic_model)
+
+     # Save to environment variables
+     save_domain_expert_to_env(domain, role)
+
+     return domain, role
+
+ def main(visualization: bool = False):
+     setup_logging()
+
+     # Load chunks
+     chunks_file = DEFAULT_CHUNKS_FILE
+     if not os.path.exists(chunks_file):
+         print(f"❌ Chunks file not found: {chunks_file}")
+         # Try the current directory as a fallback
+         local_file = "chunks.json"
+         if os.path.exists(local_file):
+             chunks_file = local_file
+         else:
+             return
+
+     print(f"📂 Loading chunks from {chunks_file}...")
+     with open(chunks_file, 'r') as f:
+         chunks = json.load(f)
+
+     # Try to load pre-computed embeddings
+     try:
+         embeddings, chunk_ids = load_precomputed_embeddings()
+         # Align chunks to embeddings
+         chunks, valid_indices = align_chunks_with_embeddings(chunks, chunk_ids)
+         # Filter embeddings to match the found chunks
+         if len(valid_indices) != embeddings.shape[0]:
+             print(f"⚠️ Filtering embeddings from {embeddings.shape[0]} to {len(valid_indices)} rows")
+             embeddings = embeddings[valid_indices]
+     except Exception as e:
+         print(f"⚠️ Could not load/align pre-computed embeddings: {e}")
+         import traceback
+         traceback.print_exc()
+         print("   Falling back to on-the-fly computation (slower)")
+         embeddings = None
+
+     # Run topic modeling
+     topic_model, docs, embeddings = get_domain_model(chunks, embeddings=embeddings)
+
+     # Get domain and role
+     domain, role = query_llm_for_domain(topic_model)
+
+     # Save to environment variables
+     save_domain_expert_to_env(domain, role)
+
+     print("\n" + "=" * 50)
+     print("🎯 RESULTS")
+     print(f"Domain: {domain}")
+     print(f"Expert Role: {role}")
+     print("=" * 50 + "\n")
+
+     # Visualization
+     visualize_results(topic_model, docs, OUTPUT_DIR, embeddings=embeddings, generate_plots=visualization)
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser(description="Extract domain and expert role from semantic chunks")
+     parser.add_argument("--vis", action="store_true", help="Enable visualization generation")
+     args = parser.parse_args()
+
+     main(visualization=args.vis)
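
For context, a minimal sketch of driving this module programmatically instead of through the CLI, assuming the default output paths above exist; the import path is hypothetical and should be adjusted to wherever this file lives inside the mirage package:

    # Module name is an assumption for illustration; not confirmed by the package.
    from domain_extractor import fetch_domain_and_role

    # Runs topic modeling on the chunks file, queries the LLM, and exports
    # DATASET_DOMAIN / DATASET_EXPERT_PERSONA to the process environment.
    domain, role = fetch_domain_and_role(chunks_file="output/results/chunks.json")
    print(domain, role)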