mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,824 @@
#!/usr/bin/env python3
"""
GEPA RAG Optimization Example with Multiple Vector Stores

This example demonstrates how to use GEPA to optimize a RAG system using various
vector stores, showcasing their unique capabilities and search methods.

Supported Vector Stores:
- ChromaDB: Local/persistent vector store with simple setup
- LanceDB: Developer-friendly serverless vector database
- Milvus: Cloud-native vector database with Lite mode
- Qdrant: High-performance vector database with advanced filtering
- Weaviate: Vector database with hybrid search capabilities

Usage:
    # ChromaDB (default, no external dependencies)
    python rag_optimization.py --vector-store chromadb

    # LanceDB (local, no Docker required)
    python rag_optimization.py --vector-store lancedb

    # Milvus Lite (local SQLite-based)
    python rag_optimization.py --vector-store milvus

    # Qdrant (in-memory or with Docker)
    python rag_optimization.py --vector-store qdrant

    # Weaviate (requires Docker)
    python rag_optimization.py --vector-store weaviate

    # With specific models
    python rag_optimization.py --vector-store chromadb --model ollama/llama3.1:8b

    # Full optimization run
    python rag_optimization.py --vector-store qdrant --max-iterations 20

Requirements:
    Base: pip install gepa[rag]
    ChromaDB: pip install chromadb
    LanceDB: pip install lancedb pyarrow sentence-transformers
    Milvus: pip install pymilvus sentence-transformers
    Qdrant: pip install qdrant-client
    Weaviate: pip install weaviate-client

Prerequisites:
- For Ollama: ollama pull qwen3:8b && ollama pull nomic-embed-text:latest
- For Weaviate: docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.26.1
- For Qdrant (optional): docker run -p 6333:6333 qdrant/qdrant
"""

import argparse
import os
import sys
import tempfile
import warnings
from pathlib import Path
from typing import Any

# Suppress all warnings for clean output
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

import mantisdk.algorithm.gepa.lib as gepa  # noqa: E402
from mantisdk.algorithm.gepa.lib.adapters.generic_rag_adapter import GenericRAGAdapter, RAGDataInst  # noqa: E402

# Vector store imports (lazy loaded)
_vector_stores = {}

def lazy_import_vector_store(store_name: str):
    """Lazy import vector store classes to avoid dependency issues."""
    global _vector_stores

    if store_name in _vector_stores:
        return _vector_stores[store_name]

    try:
        if store_name == "chromadb":
            from mantisdk.algorithm.gepa.lib.adapters.generic_rag_adapter import ChromaVectorStore

            _vector_stores[store_name] = ChromaVectorStore
            return ChromaVectorStore
        elif store_name == "lancedb":
            from mantisdk.algorithm.gepa.lib.adapters.generic_rag_adapter import LanceDBVectorStore

            _vector_stores[store_name] = LanceDBVectorStore
            return LanceDBVectorStore
        elif store_name == "milvus":
            from mantisdk.algorithm.gepa.lib.adapters.generic_rag_adapter import MilvusVectorStore

            _vector_stores[store_name] = MilvusVectorStore
            return MilvusVectorStore
        elif store_name == "qdrant":
            from mantisdk.algorithm.gepa.lib.adapters.generic_rag_adapter import QdrantVectorStore

            _vector_stores[store_name] = QdrantVectorStore
            return QdrantVectorStore
        elif store_name == "weaviate":
            from mantisdk.algorithm.gepa.lib.adapters.generic_rag_adapter import WeaviateVectorStore

            _vector_stores[store_name] = WeaviateVectorStore
            return WeaviateVectorStore
        else:
            raise ValueError(f"Unknown vector store: {store_name}")
    except ImportError as e:
        raise ImportError(
            f"Failed to import {store_name} dependencies: {e}\n"
            f"Install with: pip install {get_install_command(store_name)}"
        )


def get_install_command(store_name: str) -> str:
    """Get pip install command for vector store dependencies."""
    commands = {
        "chromadb": "chromadb",
        "lancedb": "lancedb pyarrow sentence-transformers",
        "milvus": "pymilvus sentence-transformers",
        "qdrant": "qdrant-client",
        "weaviate": "weaviate-client",
    }
    return commands.get(store_name, "unknown")

def create_llm_client(model_name: str):
    """Create LLM client supporting both Ollama and cloud models."""
    try:
        import litellm

        litellm.drop_params = True
        litellm.set_verbose = False
    except ImportError:
        raise ImportError("LiteLLM is required. Install with: pip install litellm")

    def llm_client(messages_or_prompt, **kwargs):
        try:
            # Handle both string prompts and message lists
            if isinstance(messages_or_prompt, str):
                messages = [{"role": "user", "content": messages_or_prompt}]
            else:
                messages = messages_or_prompt

            params = {
                "model": model_name,
                "messages": messages,
                "max_tokens": kwargs.get("max_tokens", 400),
                "temperature": kwargs.get("temperature", 0.1),
            }

            if "ollama/" in model_name:
                params["request_timeout"] = 120

            response = litellm.completion(**params)
            return response.choices[0].message.content.strip()

        except Exception as e:
            return f"Error: Unable to generate response ({e})"

    return llm_client


def create_embedding_function():
    """Create embedding function using sentence-transformers as fallback."""
    try:
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer("all-MiniLM-L6-v2")
        return lambda text: model.encode(text)
    except ImportError:
        # Fallback to litellm for embedding
        try:
            import litellm

            def embed_text(text: str):
                try:
                    response = litellm.embedding(model="ollama/nomic-embed-text:latest", input=text)
                    if hasattr(response, "data") and response.data:
                        if hasattr(response.data[0], "embedding"):
                            return response.data[0].embedding
                        elif isinstance(response.data[0], dict) and "embedding" in response.data[0]:
                            return response.data[0]["embedding"]
                    elif isinstance(response, dict):
                        if response.get("data"):
                            return response["data"][0]["embedding"]
                        elif "embedding" in response:
                            return response["embedding"]
                    raise ValueError(f"Unknown response format: {type(response)}")
                except Exception as e:
                    raise RuntimeError(
                        f"Embedding failed: {e}. Please check your embedding model setup (sentence-transformers or litellm) and ensure all dependencies are installed."
                    )

            return embed_text
        except ImportError:
            raise ImportError("Either sentence-transformers or litellm is required for embeddings")

def setup_chromadb_store():
    """Set up ChromaDB vector store with sample data."""
    print("🗄️ Setting up ChromaDB vector store...")

    try:
        from chromadb.utils import embedding_functions
    except ImportError:
        raise ImportError("ChromaDB is required. Install with: pip install chromadb")

    # Create temporary directory
    temp_dir = tempfile.mkdtemp()
    print(f" 📁 ChromaDB directory: {temp_dir}")

    # Initialize ChromaDB
    embedding_function = embedding_functions.DefaultEmbeddingFunction()
    chroma_vector_store = lazy_import_vector_store("chromadb")
    vector_store = chroma_vector_store.create_local(
        persist_directory=temp_dir, collection_name="ai_ml_knowledge", embedding_function=embedding_function
    )

    documents = get_sample_documents()
    vector_store.collection.add(
        documents=[doc["content"] for doc in documents],
        metadatas=[doc["metadata"] for doc in documents],
        ids=[doc["metadata"]["doc_id"] for doc in documents],
    )

    print(f" ✅ Created ChromaDB knowledge base with {len(documents)} articles")
    return vector_store

def setup_lancedb_store():
    """Set up LanceDB vector store with sample data."""
    print("🗄️ Setting up LanceDB vector store...")

    try:
        embedding_function = create_embedding_function()
        lancedb_vector_store = lazy_import_vector_store("lancedb")

        vector_store = lancedb_vector_store.create_local(
            table_name="rag_demo",
            embedding_function=embedding_function,
            db_path="./lancedb_demo",
            vector_size=384,
        )

        documents = get_sample_documents_simple()
        embeddings = [embedding_function(doc["content"]) for doc in documents]
        ids = vector_store.add_documents(documents, embeddings)

        print(f" ✅ Added {len(ids)} documents to LanceDB table")
        return vector_store

    except ImportError as e:
        raise ImportError(
            f"LanceDB dependencies missing: {e}\nInstall with: pip install lancedb pyarrow sentence-transformers"
        )


def setup_milvus_store():
    """Set up Milvus vector store with sample data."""
    print("🗄️ Setting up Milvus vector store...")

    try:
        embedding_function = create_embedding_function()
        milvus_vector_store = lazy_import_vector_store("milvus")

        vector_store = milvus_vector_store.create_local(
            collection_name="rag_demo",
            embedding_function=embedding_function,
            vector_size=384,
            uri="./milvus_demo.db",
        )

        documents = get_sample_documents_simple()
        embeddings = [embedding_function(doc["content"]) for doc in documents]
        ids = vector_store.add_documents(documents, embeddings)

        print(f" ✅ Added {len(ids)} documents to Milvus collection")
        return vector_store

    except ImportError as e:
        raise ImportError(f"Milvus dependencies missing: {e}\nInstall with: pip install pymilvus sentence-transformers")

def setup_qdrant_store():
    """Set up Qdrant vector store with sample data."""
    print("🗄️ Setting up Qdrant vector store...")

    try:
        from qdrant_client import QdrantClient
        from qdrant_client.http import models
    except ImportError:
        raise ImportError("Qdrant client required. Install with: pip install qdrant-client")

    # Connect to in-memory Qdrant
    client = QdrantClient(path=":memory:")
    print(" ✅ Connected to in-memory Qdrant")

    collection_name = "AIKnowledge"

    # Delete existing collection if it exists
    try:
        client.delete_collection(collection_name)
    except Exception:
        pass

    # Create embedding function and determine vector size
    embedding_fn = create_embedding_function()
    sample_vector = embedding_fn("test")
    vector_size = len(sample_vector)

    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=vector_size,
            distance=models.Distance.COSINE,
        ),
    )

    # Add documents
    documents = get_sample_documents_for_qdrant()

    points = []
    for i, doc in enumerate(documents):
        doc_vector = embedding_fn(doc["content"])
        payload = dict(doc)
        payload["original_id"] = f"doc_{i + 1}"

        point = models.PointStruct(
            id=i + 1,
            vector=doc_vector,
            payload=payload,
        )
        points.append(point)

    client.upsert(collection_name=collection_name, points=points, wait=True)
    print(f" ✅ Created Qdrant knowledge base with {len(documents)} articles")

    qdrant_vector_store = lazy_import_vector_store("qdrant")
    vector_store = qdrant_vector_store(client, collection_name, embedding_fn)
    return vector_store

def setup_weaviate_store():
    """Set up Weaviate vector store with sample data."""
    print("🗄️ Setting up Weaviate vector store...")

    try:
        import weaviate
        import weaviate.classes as wvc
    except ImportError:
        raise ImportError("Weaviate client required. Install with: pip install weaviate-client")

    # Connect to local Weaviate
    try:
        client = weaviate.connect_to_local()
        print(" ✅ Connected to local Weaviate")
    except Exception as e:
        print(f" ❌ Failed to connect to Weaviate: {e}")
        print(" 💡 Make sure Weaviate is running:")
        print(" docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.26.1")
        raise

    collection_name = "AIKnowledge"

    # Delete existing collection if it exists
    try:
        client.collections.delete(collection_name)
        print(f" 🗑️ Removed existing collection: {collection_name}")
    except Exception:
        pass

    # Create collection
    collection = client.collections.create(
        name=collection_name,
        properties=[
            wvc.config.Property(name="content", data_type=wvc.config.DataType.TEXT, description="Document content"),
            wvc.config.Property(name="topic", data_type=wvc.config.DataType.TEXT, description="Topic category"),
            wvc.config.Property(name="difficulty", data_type=wvc.config.DataType.TEXT, description="Difficulty level"),
        ],
        vectorizer_config=wvc.config.Configure.Vectorizer.none(),
        inverted_index_config=wvc.config.Configure.inverted_index(
            bm25_b=0.75,
            bm25_k1=1.2,
        ),
    )

    # Create embedding function and add documents
    embedding_fn = create_embedding_function()
    documents = get_sample_documents_for_weaviate()

    with collection.batch.dynamic() as batch:
        for doc in documents:
            doc_vector = embedding_fn(doc["content"])
            batch.add_object(properties=doc, vector=doc_vector)

    client.close()
    print(f" ✅ Created Weaviate knowledge base with {len(documents)} articles")

    # Reconnect and create vector store wrapper
    client_for_store = weaviate.connect_to_local()
    weaviate_vector_store = lazy_import_vector_store("weaviate")
    vector_store = weaviate_vector_store(client_for_store, collection_name, embedding_fn)
    return vector_store

def get_sample_documents() -> list[dict[str, Any]]:
    """Get sample documents for ChromaDB (with nested metadata structure)."""
    return [
        {
            "content": "Machine Learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed. It focuses on the development of computer programs that can access data and use it to learn for themselves.",
            "metadata": {"doc_id": "ml_basics", "topic": "machine_learning", "difficulty": "beginner"},
        },
        {
            "content": "Deep Learning is a subset of machine learning based on artificial neural networks with representation learning. It can learn from data that is unstructured or unlabeled. Deep learning models are inspired by information processing patterns found in biological neural networks.",
            "metadata": {"doc_id": "dl_basics", "topic": "deep_learning", "difficulty": "intermediate"},
        },
        {
            "content": "Natural Language Processing (NLP) is a branch of artificial intelligence that helps computers understand, interpret and manipulate human language. NLP draws from many disciplines, including computer science and computational linguistics.",
            "metadata": {"doc_id": "nlp_basics", "topic": "nlp", "difficulty": "intermediate"},
        },
        {
            "content": "Computer Vision is a field of artificial intelligence that trains computers to interpret and understand the visual world. Using digital images from cameras and videos and deep learning models, machines can accurately identify and classify objects.",
            "metadata": {"doc_id": "cv_basics", "topic": "computer_vision", "difficulty": "intermediate"},
        },
        {
            "content": "Reinforcement Learning is an area of machine learning where an agent learns to behave in an environment by performing actions and seeing the results. The agent receives rewards by performing correctly and penalties for performing incorrectly.",
            "metadata": {"doc_id": "rl_basics", "topic": "reinforcement_learning", "difficulty": "advanced"},
        },
        {
            "content": "Large Language Models (LLMs) are a type of artificial intelligence model designed to understand and generate human-like text. They are trained on vast amounts of text data and can perform various natural language tasks such as translation, summarization, and question answering.",
            "metadata": {"doc_id": "llm_basics", "topic": "large_language_models", "difficulty": "advanced"},
        },
    ]

def get_sample_documents_simple() -> list[dict[str, str]]:
    """Get sample documents for LanceDB/Milvus (flat structure)."""
    return [
        {"content": "Machine learning is a method of data analysis that automates analytical model building."},
        {"content": "It is a branch of artificial intelligence based on the idea that systems can learn from data."},
        {"content": "Machine learning algorithms build a model based on training data to make predictions."},
        {
            "content": "Deep learning is part of a broader family of machine learning methods based on artificial neural networks."
        },
        {"content": "It uses multiple layers to progressively extract higher-level features from raw input."},
        {
            "content": "Deep learning models can automatically learn representations of data with multiple levels of abstraction."
        },
        {
            "content": "Natural language processing (NLP) is a subfield of linguistics, computer science, and artificial intelligence."
        },
        {"content": "It deals with the interaction between computers and human language."},
        {"content": "NLP techniques enable computers to process and analyze large amounts of natural language data."},
    ]

def get_sample_documents_for_qdrant() -> list[dict[str, Any]]:
    """Get sample documents for Qdrant (with flat metadata)."""
    return [
        {
            "content": "Artificial Intelligence (AI) is the simulation of human intelligence in machines that are programmed to think and learn like humans. The term may also be applied to any machine that exhibits traits associated with a human mind such as learning and problem-solving.",
            "topic": "artificial_intelligence",
            "difficulty": "beginner",
            "category": "definition",
        },
        {
            "content": "Machine Learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.",
            "topic": "machine_learning",
            "difficulty": "beginner",
            "category": "definition",
        },
        {
            "content": "Deep Learning is part of a broader family of machine learning methods based on artificial neural networks with representation learning. Learning can be supervised, semi-supervised or unsupervised. Deep learning architectures such as deep neural networks have been applied to computer vision, speech recognition, and natural language processing.",
            "topic": "deep_learning",
            "difficulty": "intermediate",
            "category": "technical",
        },
        {
            "content": "Natural Language Processing (NLP) is a subfield of linguistics, computer science, and artificial intelligence concerned with the interactions between computers and human language. The goal is to program computers to process and analyze large amounts of natural language data.",
            "topic": "nlp",
            "difficulty": "intermediate",
            "category": "technical",
        },
        {
            "content": "Computer Vision is a field of artificial intelligence (AI) that enables computers and systems to derive meaningful information from digital images, videos and other visual inputs. It uses machine learning models to analyze and interpret visual data.",
            "topic": "computer_vision",
            "difficulty": "intermediate",
            "category": "application",
        },
        {
            "content": "Transformers are a deep learning architecture that has revolutionized natural language processing. They rely entirely on self-attention mechanisms to draw global dependencies between input and output, dispensing with recurrence and convolutions entirely.",
            "topic": "transformers",
            "difficulty": "advanced",
            "category": "architecture",
        },
    ]

def get_sample_documents_for_weaviate() -> list[dict[str, str]]:
    """Get sample documents for Weaviate (flat string properties)."""
    return [
        {
            "content": "Artificial Intelligence (AI) is the simulation of human intelligence in machines that are programmed to think and learn like humans. The term may also be applied to any machine that exhibits traits associated with a human mind such as learning and problem-solving.",
            "topic": "artificial_intelligence",
            "difficulty": "beginner",
        },
        {
            "content": "Machine Learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.",
            "topic": "machine_learning",
            "difficulty": "beginner",
        },
        {
            "content": "Deep Learning is part of a broader family of machine learning methods based on artificial neural networks with representation learning. Learning can be supervised, semi-supervised or unsupervised. Deep learning architectures such as deep neural networks have been applied to computer vision, speech recognition, and natural language processing.",
            "topic": "deep_learning",
            "difficulty": "intermediate",
        },
        {
            "content": "Natural Language Processing (NLP) is a subfield of linguistics, computer science, and artificial intelligence concerned with the interactions between computers and human language. The goal is to program computers to process and analyze large amounts of natural language data.",
            "topic": "nlp",
            "difficulty": "intermediate",
        },
        {
            "content": "Computer Vision is an interdisciplinary scientific field that deals with how computers can gain high-level understanding from digital images or videos. From an engineering perspective, it seeks to understand and automate tasks that the human visual system can do.",
            "topic": "computer_vision",
            "difficulty": "intermediate",
        },
        {
            "content": "Transformers are a deep learning architecture that has revolutionized natural language processing. They rely entirely on self-attention mechanisms to draw global dependencies between input and output, dispensing with recurrence and convolutions entirely.",
            "topic": "transformers",
            "difficulty": "advanced",
        },
    ]

def create_training_data() -> tuple[list[RAGDataInst], list[RAGDataInst]]:
    """Create training and validation datasets for RAG optimization."""
    # Training examples
    train_data = [
        RAGDataInst(
            query="What is machine learning?",
            ground_truth_answer="Machine Learning is a method of data analysis that automates analytical model building. It is a branch of artificial intelligence based on the idea that systems can learn from data, identify patterns and make decisions with minimal human intervention.",
            relevant_doc_ids=["ml_basics"],
            metadata={"category": "definition", "difficulty": "beginner"},
        ),
        RAGDataInst(
            query="How does deep learning work?",
            ground_truth_answer="Deep Learning is a subset of machine learning based on artificial neural networks with representation learning. It can learn from data that is unstructured or unlabeled. Deep learning models are inspired by information processing patterns found in biological neural networks.",
            relevant_doc_ids=["dl_basics"],
            metadata={"category": "explanation", "difficulty": "intermediate"},
        ),
        RAGDataInst(
            query="What is natural language processing?",
            ground_truth_answer="Natural Language Processing (NLP) is a branch of artificial intelligence that helps computers understand, interpret and manipulate human language. NLP draws from many disciplines, including computer science and computational linguistics.",
            relevant_doc_ids=["nlp_basics"],
            metadata={"category": "definition", "difficulty": "intermediate"},
        ),
    ]

    # Validation examples
    val_data = [
        RAGDataInst(
            query="Explain computer vision in AI",
            ground_truth_answer="Computer Vision is a field of artificial intelligence that trains computers to interpret and understand the visual world. Using digital images from cameras and videos and deep learning models, machines can accurately identify and classify objects.",
            relevant_doc_ids=["cv_basics"],
            metadata={"category": "explanation", "difficulty": "intermediate"},
        ),
        RAGDataInst(
            query="What are large language models?",
            ground_truth_answer="Large Language Models (LLMs) are a type of artificial intelligence model designed to understand and generate human-like text. They are trained on vast amounts of text data and can perform various natural language tasks such as translation, summarization, and question answering.",
            relevant_doc_ids=["llm_basics"],
            metadata={"category": "definition", "difficulty": "advanced"},
        ),
    ]

    return train_data, val_data

def clean_answer(answer: str) -> str:
    """Clean up LLM answer by removing thinking tokens and truncating appropriately."""
    import re

    cleaned = re.sub(r"<think>.*?</think>", "", answer, flags=re.DOTALL)
    cleaned = cleaned.strip()

    # If still empty or starts with <think> without closing tag, try to find content after
    if not cleaned or cleaned.startswith("<think>"):
        lines = answer.split("\n")
        content_lines = []
        skip_thinking = False

        for line in lines:
            if "<think>" in line:
                skip_thinking = True
                continue
            if "</think>" in line:
                skip_thinking = False
                continue
            if not skip_thinking and line.strip():
                content_lines.append(line.strip())

        cleaned = " ".join(content_lines)

    # Show more of the answer - increase limit significantly
    if len(cleaned) > 500:
        return cleaned[:500] + "..."
    return cleaned or answer[:500] + ("..." if len(answer) > 500 else "")

def create_initial_prompts() -> dict[str, str]:
    """Create initial prompt templates for optimization."""
    return {
        "answer_generation": """You are an AI expert providing accurate technical explanations.

Based on the retrieved context, provide a clear and informative answer to the user's question.

Guidelines:
- Use information from the provided context
- Be accurate and concise
- Include key technical details
- Structure your response clearly

Context: {context}

Question: {query}

Answer:"""
    }

def setup_vector_store(store_name: str):
    """Factory function to set up the specified vector store."""
    setup_functions = {
        "chromadb": setup_chromadb_store,
        "lancedb": setup_lancedb_store,
        "milvus": setup_milvus_store,
        "qdrant": setup_qdrant_store,
        "weaviate": setup_weaviate_store,
    }

    if store_name not in setup_functions:
        raise ValueError(f"Unknown vector store: {store_name}. Supported: {list(setup_functions.keys())}")

    return setup_functions[store_name]()


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="GEPA RAG Optimization Example with Multiple Vector Stores",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python rag_optimization.py --vector-store chromadb
  python rag_optimization.py --vector-store lancedb --model ollama/llama3.1:8b
  python rag_optimization.py --vector-store qdrant --max-iterations 10
  python rag_optimization.py --vector-store weaviate --model gpt-4o-mini

Supported Vector Stores:
  chromadb - Local/persistent, simple setup (default)
  lancedb - Serverless, no Docker required
  milvus - Cloud-native, uses Lite mode locally
  qdrant - High-performance, advanced filtering
  weaviate - Hybrid search capabilities (requires Docker)
        """,
    )

    parser.add_argument(
        "--vector-store",
        type=str,
        default="chromadb",
        choices=["chromadb", "lancedb", "milvus", "qdrant", "weaviate"],
        help="Vector store to use (default: chromadb)",
    )
    parser.add_argument("--model", type=str, default="ollama/qwen3:8b", help="LLM model (default: ollama/qwen3:8b)")
    parser.add_argument(
        "--embedding-model",
        type=str,
        default="ollama/nomic-embed-text:latest",
        help="Embedding model (default: ollama/nomic-embed-text:latest)",
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=5,
        help="GEPA optimization iterations (default: 5, use 0 to skip optimization)",
    )
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")

    return parser.parse_args()

def main():
    """Main function demonstrating RAG optimization with multiple vector stores."""
    args = parse_arguments()

    print("🚀 GEPA RAG Optimization with Multiple Vector Stores")
    print("=" * 60)
    print(f"🗄️ Vector Store: {args.vector_store}")
    print(f"🤖 Model: {args.model}")
    print(f"🔎 Embeddings: {args.embedding_model}")
    print(f"🔄 Max Iterations: {args.max_iterations}")

    try:
        # Step 1: Setup vector store
        print(f"\n1️⃣ Setting up {args.vector_store} vector store...")
        vector_store = setup_vector_store(args.vector_store)

        # Step 2: Create datasets
        print("\n2️⃣ Creating training and validation datasets...")
        train_data, val_data = create_training_data()
        print(f" 📊 Training examples: {len(train_data)}")
        print(f" 📊 Validation examples: {len(val_data)}")

        # Step 3: Initialize LLM client
        print(f"\n3️⃣ Initializing LLM client ({args.model})...")
        llm_client = create_llm_client(args.model)

        # Test LLM
        test_response = llm_client([{"role": "user", "content": "Say 'OK' only."}])
        if "Error:" not in test_response:
            print(f" ✅ LLM connected: {test_response[:30]}...")
        else:
            print(f" ⚠️ LLM issue: {test_response}")

        # Step 4: Initialize RAG adapter
        print("\n4️⃣ Initializing GenericRAGAdapter...")
        rag_config = {
            "retrieval_strategy": "similarity",
            "top_k": 3,
            "retrieval_weight": 0.3,
            "generation_weight": 0.7,
        }

        # Add hybrid search for Weaviate
        if args.vector_store == "weaviate":
            rag_config["retrieval_strategy"] = "hybrid"
            rag_config["hybrid_alpha"] = 0.7

        rag_adapter = GenericRAGAdapter(
            vector_store=vector_store,
            llm_model=llm_client,
            embedding_model=args.embedding_model,
            rag_config=rag_config,
        )

        # Step 5: Create initial prompts
        print("\n5️⃣ Creating initial prompts...")
        initial_prompts = create_initial_prompts()

        # Step 6: Test initial performance
        print("\n6️⃣ Testing initial performance...")
        eval_result = rag_adapter.evaluate(batch=val_data[:1], candidate=initial_prompts, capture_traces=True)

        initial_score = eval_result.scores[0]
        print(f" 📊 Initial score: {initial_score:.3f}")
        print(f" 💬 Sample answer: {clean_answer(eval_result.outputs[0]['final_answer'])}")

        # Step 7: Run GEPA optimization
        if args.max_iterations > 0:
            print(f"\n7️⃣ Running GEPA optimization ({args.max_iterations} iterations)...")

            result = gepa.optimize(
                seed_candidate=initial_prompts,
                trainset=train_data,
                valset=val_data,
                adapter=rag_adapter,
                reflection_lm=llm_client,
                max_metric_calls=args.max_iterations,
            )

            best_score = result.val_aggregate_scores[result.best_idx]
            print(" 🎉 Optimization complete!")
            print(f" 🏆 Best score: {best_score:.3f}")
            print(f" 📈 Improvement: {best_score - initial_score:+.3f}")
            print(f" 🔄 Total iterations: {result.total_metric_calls or 0}")

            # Test optimized prompts
            print("\n Testing optimized prompts...")
            optimized_result = rag_adapter.evaluate(
                batch=val_data[:1], candidate=result.best_candidate, capture_traces=False
            )
            print(f" 💬 Optimized answer: {clean_answer(optimized_result.outputs[0]['final_answer'])}")

        else:
            print("\n7️⃣ Skipping optimization (use --max-iterations > 0 to enable)")

        print(f"\n✅ {args.vector_store.title()} RAG optimization completed successfully!")

        # Clean up connections
        try:
            if hasattr(vector_store, "client") and hasattr(vector_store.client, "close"):
                vector_store.client.close()
        except Exception:
            pass

    except Exception as e:
        print(f"\n❌ Error: {e}")
        if args.verbose:
            import traceback

            traceback.print_exc()

        print("\n🔧 Troubleshooting tips:")
        if args.vector_store == "weaviate":
            print(" • Ensure Weaviate is running: curl http://localhost:8080/v1/meta")
            print(
                " • Start Weaviate: docker run -p 8080:8080 -p 50051:50051 cr.weaviate.io/semitechnologies/weaviate:1.26.1"
            )
        elif args.vector_store == "qdrant":
            print(" • For external Qdrant: docker run -p 6333:6333 qdrant/qdrant")

        print(" • Ensure Ollama is running: ollama list")
        print(" • Check models are available: ollama pull qwen3:8b")
        print(" • For cloud models: set API keys (OPENAI_API_KEY, etc.)")
        print(f" • Install dependencies: pip install {get_install_command(args.vector_store)}")

        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
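
For orientation, the control flow of the file above condenses to the short sketch below. It is paraphrased from the diff itself (the function names, GenericRAGAdapter arguments, and result fields all come from the file, not from separate mantisdk documentation), so treat it as a reading aid rather than an authoritative API reference:

    # Condensed sketch of rag_optimization.py's main() flow.
    vector_store = setup_vector_store("chromadb")      # any of the five backends
    llm_client = create_llm_client("ollama/qwen3:8b")  # LiteLLM-backed callable
    train_data, val_data = create_training_data()

    adapter = GenericRAGAdapter(
        vector_store=vector_store,
        llm_model=llm_client,
        embedding_model="ollama/nomic-embed-text:latest",
        rag_config={"retrieval_strategy": "similarity", "top_k": 3,
                    "retrieval_weight": 0.3, "generation_weight": 0.7},
    )

    result = gepa.optimize(
        seed_candidate=create_initial_prompts(),  # {"answer_generation": ...}
        trainset=train_data,
        valset=val_data,
        adapter=adapter,
        reflection_lm=llm_client,
        max_metric_calls=5,
    )
    print(result.best_candidate["answer_generation"])
    print(result.val_aggregate_scores[result.best_idx])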