code-graph-builder 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_graph_builder/__init__.py +82 -0
- code_graph_builder/builder.py +366 -0
- code_graph_builder/cgb_cli.py +32 -0
- code_graph_builder/cli.py +564 -0
- code_graph_builder/commands_cli.py +1288 -0
- code_graph_builder/config.py +340 -0
- code_graph_builder/constants.py +708 -0
- code_graph_builder/embeddings/__init__.py +40 -0
- code_graph_builder/embeddings/qwen3_embedder.py +573 -0
- code_graph_builder/embeddings/vector_store.py +584 -0
- code_graph_builder/examples/__init__.py +0 -0
- code_graph_builder/examples/example_configuration.py +276 -0
- code_graph_builder/examples/example_kuzu_usage.py +109 -0
- code_graph_builder/examples/example_semantic_search_full.py +347 -0
- code_graph_builder/examples/generate_wiki.py +915 -0
- code_graph_builder/examples/graph_export_example.py +100 -0
- code_graph_builder/examples/rag_example.py +206 -0
- code_graph_builder/examples/test_cli_demo.py +129 -0
- code_graph_builder/examples/test_embedding_api.py +153 -0
- code_graph_builder/examples/test_kuzu_local.py +190 -0
- code_graph_builder/examples/test_rag_redis.py +390 -0
- code_graph_builder/graph_updater.py +605 -0
- code_graph_builder/guidance/__init__.py +1 -0
- code_graph_builder/guidance/agent.py +123 -0
- code_graph_builder/guidance/prompts.py +74 -0
- code_graph_builder/guidance/toolset.py +264 -0
- code_graph_builder/language_spec.py +536 -0
- code_graph_builder/mcp/__init__.py +21 -0
- code_graph_builder/mcp/api_doc_generator.py +764 -0
- code_graph_builder/mcp/file_editor.py +207 -0
- code_graph_builder/mcp/pipeline.py +777 -0
- code_graph_builder/mcp/server.py +161 -0
- code_graph_builder/mcp/tools.py +1800 -0
- code_graph_builder/models.py +115 -0
- code_graph_builder/parser_loader.py +344 -0
- code_graph_builder/parsers/__init__.py +7 -0
- code_graph_builder/parsers/call_processor.py +306 -0
- code_graph_builder/parsers/call_resolver.py +139 -0
- code_graph_builder/parsers/definition_processor.py +796 -0
- code_graph_builder/parsers/factory.py +119 -0
- code_graph_builder/parsers/import_processor.py +293 -0
- code_graph_builder/parsers/structure_processor.py +145 -0
- code_graph_builder/parsers/type_inference.py +143 -0
- code_graph_builder/parsers/utils.py +134 -0
- code_graph_builder/rag/__init__.py +68 -0
- code_graph_builder/rag/camel_agent.py +429 -0
- code_graph_builder/rag/client.py +298 -0
- code_graph_builder/rag/config.py +239 -0
- code_graph_builder/rag/cypher_generator.py +67 -0
- code_graph_builder/rag/llm_backend.py +210 -0
- code_graph_builder/rag/markdown_generator.py +352 -0
- code_graph_builder/rag/prompt_templates.py +440 -0
- code_graph_builder/rag/rag_engine.py +640 -0
- code_graph_builder/rag/review_report.md +172 -0
- code_graph_builder/rag/tests/__init__.py +3 -0
- code_graph_builder/rag/tests/test_camel_agent.py +313 -0
- code_graph_builder/rag/tests/test_client.py +221 -0
- code_graph_builder/rag/tests/test_config.py +177 -0
- code_graph_builder/rag/tests/test_markdown_generator.py +240 -0
- code_graph_builder/rag/tests/test_prompt_templates.py +160 -0
- code_graph_builder/services/__init__.py +39 -0
- code_graph_builder/services/graph_service.py +465 -0
- code_graph_builder/services/kuzu_service.py +665 -0
- code_graph_builder/services/memory_service.py +171 -0
- code_graph_builder/settings.py +75 -0
- code_graph_builder/tests/ACCEPTANCE_CRITERIA_PHASE2.md +401 -0
- code_graph_builder/tests/__init__.py +1 -0
- code_graph_builder/tests/run_acceptance_check.py +378 -0
- code_graph_builder/tests/test_api_find.py +231 -0
- code_graph_builder/tests/test_api_find_integration.py +226 -0
- code_graph_builder/tests/test_basic.py +78 -0
- code_graph_builder/tests/test_c_api_extraction.py +388 -0
- code_graph_builder/tests/test_call_resolution_scenarios.py +504 -0
- code_graph_builder/tests/test_embedder.py +411 -0
- code_graph_builder/tests/test_integration_semantic.py +434 -0
- code_graph_builder/tests/test_mcp_protocol.py +298 -0
- code_graph_builder/tests/test_mcp_user_flow.py +190 -0
- code_graph_builder/tests/test_rag.py +404 -0
- code_graph_builder/tests/test_settings.py +135 -0
- code_graph_builder/tests/test_step1_graph_build.py +264 -0
- code_graph_builder/tests/test_step2_api_docs.py +323 -0
- code_graph_builder/tests/test_step3_embedding.py +278 -0
- code_graph_builder/tests/test_vector_store.py +552 -0
- code_graph_builder/tools/__init__.py +40 -0
- code_graph_builder/tools/graph_query.py +495 -0
- code_graph_builder/tools/semantic_search.py +387 -0
- code_graph_builder/types.py +333 -0
- code_graph_builder/utils/__init__.py +0 -0
- code_graph_builder/utils/path_utils.py +30 -0
- code_graph_builder-0.2.0.dist-info/METADATA +321 -0
- code_graph_builder-0.2.0.dist-info/RECORD +93 -0
- code_graph_builder-0.2.0.dist-info/WHEEL +4 -0
- code_graph_builder-0.2.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Embeddings module for code semantic search.
|
|
2
|
+
|
|
3
|
+
This module provides embedding functionality for code using Qwen3 models.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from .qwen3_embedder import (
|
|
9
|
+
BaseEmbedder,
|
|
10
|
+
DummyEmbedder,
|
|
11
|
+
Qwen3Embedder,
|
|
12
|
+
create_embedder,
|
|
13
|
+
last_token_pool,
|
|
14
|
+
)
|
|
15
|
+
from .vector_store import (
|
|
16
|
+
MemoryVectorStore,
|
|
17
|
+
QdrantVectorStore,
|
|
18
|
+
SearchResult,
|
|
19
|
+
VectorRecord,
|
|
20
|
+
VectorStore,
|
|
21
|
+
cosine_similarity,
|
|
22
|
+
create_vector_store,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
# Embedders
|
|
27
|
+
"BaseEmbedder",
|
|
28
|
+
"DummyEmbedder",
|
|
29
|
+
"Qwen3Embedder",
|
|
30
|
+
"create_embedder",
|
|
31
|
+
"last_token_pool",
|
|
32
|
+
# Vector stores
|
|
33
|
+
"VectorStore",
|
|
34
|
+
"MemoryVectorStore",
|
|
35
|
+
"QdrantVectorStore",
|
|
36
|
+
"VectorRecord",
|
|
37
|
+
"SearchResult",
|
|
38
|
+
"create_vector_store",
|
|
39
|
+
"cosine_similarity",
|
|
40
|
+
]
|
|
@@ -0,0 +1,573 @@
|
|
|
1
|
+
"""Qwen3 Embedder for code semantic embeddings via Alibaba Cloud Bailian API.
|
|
2
|
+
|
|
3
|
+
This module provides the Qwen3Embedder class for generating code embeddings
|
|
4
|
+
using the Qwen3 embedding models via Alibaba Cloud Bailian API.
|
|
5
|
+
|
|
6
|
+
Required environment variables:
|
|
7
|
+
- DASHSCOPE_API_KEY: Your Alibaba Cloud DashScope API key
|
|
8
|
+
- DASHSCOPE_BASE_URL: API base URL (default: https://dashscope.aliyuncs.com/api/v1)
|
|
9
|
+
|
|
10
|
+
Example:
|
|
11
|
+
export DASHSCOPE_API_KEY="sk-xxxxxxxx"
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import requests
|
|
21
|
+
from loguru import logger
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseEmbedder(ABC):
    """Common interface that every code embedder implements.

    Subclasses provide the three primitives (single text, batch, dimension);
    the query/document convenience wrappers are derived from those and may be
    overridden for retrieval-specific behaviour.
    """

    @abstractmethod
    def embed_code(self, text: str) -> list[float]:
        """Embed one code snippet.

        Args:
            text: Code text to embed

        Returns:
            The embedding vector as a list of floats
        """
        ...

    @abstractmethod
    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed several code snippets at once.

        Args:
            texts: List of code texts to embed

        Returns:
            One embedding vector per input text
        """
        ...

    @abstractmethod
    def get_embedding_dimension(self) -> int:
        """Return the length of the vectors this embedder produces."""
        ...

    def embed_query(self, query: str) -> list[float]:
        """Embed a search query.

        The default treats the query like any other snippet; subclasses may
        override this to prepend task instructions for better retrieval.
        """
        return self.embed_code(query)

    def embed_documents(self, documents: list[str], show_progress: bool = True) -> list[list[float]]:
        """Embed documents (code snippets for indexing) via the batch path."""
        return self.embed_batch(documents)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Qwen3Embedder(BaseEmbedder):
    """Qwen3 embedding model wrapper using Alibaba Cloud Bailian API.

    Uses DashScope API to call text-embedding-v4 (Qwen3 Embedding) models.
    No local model download required.

    Args:
        api_key: DashScope API key (or from DASHSCOPE_API_KEY env var)
        model: Model name (default: text-embedding-v4)
        base_url: API base URL
        batch_size: Batch size for embedding generation (max 25 for API)
        max_retries: Maximum number of retries for failed requests
    """

    DEFAULT_MODEL = "text-embedding-v4"
    DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
    DEFAULT_BATCH_SIZE = 25  # API limit
    # Hard cap: __init__ clamps any caller-supplied batch_size to this value.
    MAX_BATCH_SIZE = 25
    CODE_RETRIEVAL_TASK = "Given a code query, retrieve relevant code snippets"
    EMBEDDING_DIMENSION = 1536  # text-embedding-v4 output dimension

    def __init__(
        self,
        api_key: str | None = None,
        model: str = DEFAULT_MODEL,
        base_url: str | None = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        max_retries: int = 3,
    ):
        # Explicit argument wins over the environment variable.
        self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
        if not self.api_key:
            raise ValueError(
                "DashScope API key required. Set DASHSCOPE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.model = model
        self.base_url = base_url or os.getenv(
            "DASHSCOPE_BASE_URL", self.DEFAULT_BASE_URL
        )
        # Clamp to the API's per-request limit.
        self.batch_size = min(batch_size, self.MAX_BATCH_SIZE)
        self.max_retries = max_retries

        # Validate API key format (soft check only — DashScope keys start with "sk-").
        if not self.api_key.startswith("sk-"):
            logger.warning("API key format may be invalid. Expected to start with 'sk-'")

        logger.info(f"Initialized Qwen3Embedder with model: {self.model}")

    def _get_headers(self) -> dict[str, str]:
        """Get API request headers (bearer auth + JSON content type)."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _make_request(
        self,
        texts: list[str],
        text_type: str = "document",
        dimensions: int | None = None,
    ) -> dict[str, Any]:
        """Make API request to get embeddings.

        Retries up to ``max_retries`` times; 429 responses back off
        exponentially (2**attempt seconds), other HTTP errors and transport
        errors retry immediately.

        Args:
            texts: List of texts to embed
            text_type: Type of text ("document" or "query")
            dimensions: Optional dimension reduction (not supported by all models)

        Returns:
            API response JSON

        Raises:
            RuntimeError: When all retries are exhausted or a non-retryable
                failure occurs on the final attempt.
        """
        # DashScope text-embedding endpoint path appended to the configured base URL.
        url = f"{self.base_url}/services/embeddings/text-embedding/text-embedding"

        payload: dict[str, Any] = {
            "model": self.model,
            "input": {
                "texts": texts,
            },
            "parameters": {
                "text_type": text_type,
            },
        }

        if dimensions is not None:
            payload["parameters"]["dimensions"] = dimensions

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    url,
                    headers=self._get_headers(),
                    json=payload,
                    timeout=60,
                )

                if response.status_code == 200:
                    return response.json()

                # Handle rate limiting with exponential backoff.
                # NOTE: a 429 on the final attempt still sleeps, then falls
                # through the loop to the generic "All retries failed" error.
                if response.status_code == 429:
                    import time

                    wait_time = 2 ** attempt
                    logger.warning(f"Rate limited. Waiting {wait_time}s...")
                    time.sleep(wait_time)
                    continue

                # Handle other errors: build a descriptive message from the
                # JSON error body when possible, otherwise the raw text.
                error_msg = f"API request failed: {response.status_code}"
                try:
                    error_data = response.json()
                    error_msg += f" - {error_data.get('message', '')}"
                except Exception:
                    error_msg += f" - {response.text[:200]}"

                if attempt < self.max_retries - 1:
                    logger.warning(f"{error_msg}, retrying...")
                    continue

                raise RuntimeError(error_msg)

            except requests.exceptions.Timeout:
                if attempt < self.max_retries - 1:
                    logger.warning(f"Request timeout, retrying... ({attempt + 1}/{self.max_retries})")
                    continue
                raise RuntimeError("API request timeout after all retries")

            except requests.exceptions.RequestException as e:
                # Covers connection errors and other transport-level failures.
                if attempt < self.max_retries - 1:
                    logger.warning(f"Request error: {e}, retrying...")
                    continue
                raise RuntimeError(f"API request failed: {e}")

        # Only reachable when the last attempt ended in a `continue` (e.g. 429).
        raise RuntimeError("All retries failed")

    def _extract_embeddings(self, response: dict[str, Any]) -> list[list[float]]:
        """Extract embeddings from API response.

        Args:
            response: API response JSON

        Returns:
            List of embedding vectors

        Raises:
            RuntimeError: If the response lacks the expected
                ``output.embeddings`` structure.
        """
        if "output" not in response or "embeddings" not in response["output"]:
            raise RuntimeError(f"Unexpected API response format: {response.keys()}")

        embeddings = response["output"]["embeddings"]
        return [item["embedding"] for item in embeddings]

    def embed_code(
        self,
        text: str,
        use_instruction: bool = False,
    ) -> list[float]:
        """Generate embedding for a single code snippet.

        Args:
            text: Code text to embed
            use_instruction: Whether to prepend instruction for queries

        Returns:
            Embedding vector as list of floats
        """
        if use_instruction:
            text = self._get_detailed_instruct(self.CODE_RETRIEVAL_TASK, text)

        # NOTE(review): even with use_instruction=True the request is sent as
        # text_type="document"; the instruction is conveyed purely via the
        # prepended prompt text. Confirm whether text_type="query" was intended.
        try:
            response = self._make_request([text], text_type="document")
            embeddings = self._extract_embeddings(response)
            # Empty response degrades to an empty vector rather than raising.
            return embeddings[0] if embeddings else []
        except Exception as e:
            logger.error(f"Failed to embed code: {e}")
            raise

    def embed_batch(
        self,
        texts: list[str],
        use_instruction: bool = False,
        show_progress: bool = False,
    ) -> list[list[float]]:
        """Generate embeddings for multiple code snippets.

        Args:
            texts: List of code texts to embed
            use_instruction: Whether to prepend instruction (for queries)
            show_progress: Whether to show progress bar

        Returns:
            List of embedding vectors

        Raises:
            RuntimeError: If any batch fails; the message records how many
                texts were embedded before the failure.
        """
        if not texts:
            return []

        if use_instruction:
            texts = [
                self._get_detailed_instruct(self.CODE_RETRIEVAL_TASK, t)
                for t in texts
            ]

        all_embeddings: list[list[float]] = []

        # Process in batches of at most self.batch_size texts per request.
        iterator = range(0, len(texts), self.batch_size)
        if show_progress:
            try:
                from tqdm import tqdm

                iterator = tqdm(
                    iterator,
                    desc="Generating embeddings",
                    total=(len(texts) + self.batch_size - 1) // self.batch_size,
                )
            except ImportError:
                # tqdm is optional — silently proceed without a progress bar.
                pass

        for i in iterator:
            batch_texts = texts[i : i + self.batch_size]

            try:
                response = self._make_request(batch_texts, text_type="document")
                batch_embeddings = self._extract_embeddings(response)
                all_embeddings.extend(batch_embeddings)
            except Exception as e:
                # Fail fast: report which batch failed and how much succeeded.
                batch_num = i // self.batch_size + 1
                total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
                logger.error(
                    f"Embedding batch {batch_num}/{total_batches} failed: {e}"
                )
                raise RuntimeError(
                    f"Embedding API call failed at batch {batch_num}/{total_batches}: {e}. "
                    f"Successfully embedded {len(all_embeddings)}/{len(texts)} texts before failure."
                ) from e

        return all_embeddings

    def embed_documents(self, documents: list[str], show_progress: bool = True) -> list[list[float]]:
        """Generate embeddings for documents (code snippets).

        This is for indexing documents (no instruction needed).

        Args:
            documents: List of document texts
            show_progress: Whether to show progress bar

        Returns:
            List of embedding vectors
        """
        return self.embed_batch(
            documents,
            use_instruction=False,
            show_progress=show_progress,
        )

    def embed_query(self, query: str) -> list[float]:
        """Generate embedding for a query.

        This is for search queries (with instruction for better retrieval).

        Args:
            query: Query text

        Returns:
            Embedding vector as list of floats
        """
        return self.embed_code(query, use_instruction=True)

    def _get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Format query with instruction for better retrieval performance.

        Args:
            task_description: Task description
            query: Query text

        Returns:
            Formatted query with instruction
        """
        return f"Instruct: {task_description}\nQuery: {query}"

    def get_embedding_dimension(self) -> int:
        """Get the embedding dimension for this model.

        Returns:
            Embedding dimension size
        """
        # Fixed constant — not probed from the API.
        return self.EMBEDDING_DIMENSION

    def health_check(self) -> bool:
        """Check if API is accessible and API key is valid.

        Returns:
            True if healthy, False otherwise
        """
        try:
            # Make a simple request; any exception means unhealthy.
            test_text = "hello"
            self.embed_code(test_text)
            return True
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return False
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class OpenAIEmbedder(BaseEmbedder):
    """OpenAI-compatible embedding client.

    Works with OpenAI, Azure OpenAI, and any API implementing the
    ``/v1/embeddings`` endpoint (e.g. local ollama, vLLM, LiteLLM).

    Env vars (fallback order):
        EMBEDDING_API_KEY / OPENAI_API_KEY / LLM_API_KEY
        EMBEDDING_BASE_URL / OPENAI_BASE_URL / LLM_BASE_URL (default: https://api.openai.com/v1)
        EMBEDDING_MODEL (default: text-embedding-3-small)
    """

    DEFAULT_MODEL = "text-embedding-3-small"
    DEFAULT_BASE_URL = "https://api.openai.com/v1"
    # text-embedding-3-small = 1536, text-embedding-3-large = 3072
    _KNOWN_DIMS: dict[str, int] = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
        batch_size: int = 20,
        max_retries: int = 3,
        dimension: int | None = None,
    ):
        # Explicit argument wins, then env vars in documented fallback order.
        self.api_key = api_key or os.getenv("EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY")
        if not self.api_key:
            raise ValueError(
                "OpenAI API key required. Set EMBEDDING_API_KEY, OPENAI_API_KEY, "
                "or LLM_API_KEY environment variable."
            )

        self.model = model or os.getenv("EMBEDDING_MODEL", self.DEFAULT_MODEL)
        self.base_url = (
            base_url
            or os.getenv("EMBEDDING_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or self.DEFAULT_BASE_URL
        ).rstrip("/")  # strip trailing slash so f"{base_url}/embeddings" stays clean
        self.batch_size = batch_size
        self.max_retries = max_retries
        # NOTE(review): unknown models fall back to 1536, which may be wrong
        # for non-OpenAI backends — pass `dimension` explicitly in that case.
        self._dimension = dimension or self._KNOWN_DIMS.get(self.model, 1536)

        logger.info(f"Initialized OpenAIEmbedder with model: {self.model}")

    def _make_request(self, texts: list[str]) -> list[list[float]]:
        """POST one /embeddings request and return the vectors in input order.

        Retries up to ``max_retries`` times; 429 backs off exponentially,
        other failures retry immediately. Raises RuntimeError when exhausted.
        """
        url = f"{self.base_url}/embeddings"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload: dict[str, Any] = {
            "model": self.model,
            "input": texts,
        }

        for attempt in range(self.max_retries):
            try:
                response = requests.post(url, headers=headers, json=payload, timeout=60)

                if response.status_code == 200:
                    data = response.json()
                    # The API may return items out of order; sort by "index"
                    # so output aligns with the input texts.
                    sorted_items = sorted(data["data"], key=lambda x: x["index"])
                    return [item["embedding"] for item in sorted_items]

                if response.status_code == 429:
                    import time
                    wait_time = 2 ** attempt
                    logger.warning(f"Rate limited. Waiting {wait_time}s...")
                    time.sleep(wait_time)
                    continue

                # Build a descriptive error from the JSON body when possible.
                error_msg = f"OpenAI embeddings API error: {response.status_code}"
                try:
                    err = response.json()
                    error_msg += f" - {err.get('error', {}).get('message', response.text[:200])}"
                except Exception:
                    error_msg += f" - {response.text[:200]}"

                if attempt < self.max_retries - 1:
                    logger.warning(f"{error_msg}, retrying...")
                    continue
                raise RuntimeError(error_msg)

            except requests.exceptions.Timeout:
                if attempt < self.max_retries - 1:
                    logger.warning(f"Request timeout, retrying ({attempt + 1}/{self.max_retries})...")
                    continue
                raise RuntimeError("OpenAI embeddings API timeout after all retries")
            except requests.exceptions.RequestException as e:
                # Connection and other transport-level errors.
                if attempt < self.max_retries - 1:
                    logger.warning(f"Request error: {e}, retrying...")
                    continue
                raise RuntimeError(f"OpenAI embeddings API request failed: {e}")

        # Only reachable when the final attempt ended in a `continue` (e.g. 429).
        raise RuntimeError("All retries failed")

    def embed_code(self, text: str) -> list[float]:
        """Embed one snippet; returns [] if the API yields no vectors."""
        results = self._make_request([text])
        return results[0] if results else []

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed texts in chunks of ``batch_size``, preserving input order."""
        if not texts:
            return []
        all_embeddings: list[list[float]] = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]
            all_embeddings.extend(self._make_request(batch))
        return all_embeddings

    def get_embedding_dimension(self) -> int:
        """Return the (configured or looked-up) embedding dimension."""
        return self._dimension
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
class DummyEmbedder(BaseEmbedder):
    """Test double that produces all-zero vectors without any API calls."""

    def __init__(self, dimension: int = 1536):
        # Length of every vector this embedder emits.
        self.dimension = dimension

    def embed_code(self, text: str) -> list[float]:
        """Return a zero vector of the configured dimension."""
        return [0.0 for _ in range(self.dimension)]

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return one independent zero vector per input text."""
        result: list[list[float]] = []
        for _ in texts:
            result.append([0.0] * self.dimension)
        return result

    def get_embedding_dimension(self) -> int:
        """Report the configured vector dimension."""
        return self.dimension
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def create_embedder(
    api_key: str | None = None,
    model: str | None = None,
    use_dummy: bool = False,
    provider: str | None = None,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create an embedder.

    Provider detection order:
    1. Explicit ``provider`` argument (``"qwen3"``, ``"openai"``, ``"dummy"``).
    2. ``EMBEDDING_PROVIDER`` env var.
    3. Auto-detect: if ``DASHSCOPE_API_KEY`` is set → Qwen3,
       elif ``EMBEDDING_API_KEY`` or ``OPENAI_API_KEY`` or ``LLM_API_KEY`` → OpenAI-compatible,
       else → DummyEmbedder (with a warning).

    Args:
        api_key: API key override (passed to chosen embedder).
        model: Model name override.
        use_dummy: Force dummy embedder (for tests).
        provider: Explicit provider name.
        **kwargs: Extra arguments forwarded to the embedder constructor.

    Returns:
        BaseEmbedder instance.

    Raises:
        ValueError: If an unknown provider name is given.
    """
    if use_dummy:
        return DummyEmbedder()

    chosen = (provider or os.getenv("EMBEDDING_PROVIDER", "")).lower()

    if not chosen:
        # Auto-detect the provider from whichever API key is configured.
        if os.getenv("DASHSCOPE_API_KEY"):
            chosen = "qwen3"
        elif os.getenv("EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY"):
            chosen = "openai"
        else:
            logger.warning("No embedding API key found. Using DummyEmbedder (zero vectors).")
            return DummyEmbedder()

    # Bug fix: "dummy" is advertised as a valid provider (docstring and the
    # error message below), but previously fell through to the ValueError.
    if chosen == "dummy":
        return DummyEmbedder()

    # Only forward overrides the caller actually supplied, so the chosen
    # embedder's own env-var/default resolution still applies otherwise.
    embedder_kwargs: dict[str, Any] = {}
    if api_key:
        embedder_kwargs["api_key"] = api_key
    if model:
        embedder_kwargs["model"] = model
    embedder_kwargs.update(kwargs)

    if chosen == "qwen3":
        return Qwen3Embedder(**embedder_kwargs)
    elif chosen == "openai":
        return OpenAIEmbedder(**embedder_kwargs)
    else:
        raise ValueError(f"Unknown embedding provider: {chosen!r}. Use 'qwen3', 'openai', or 'dummy'.")
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
# Backward-compatibility shim from the pre-API (local model) era.
def last_token_pool(last_hidden_states: Any, attention_mask: Any) -> Any:
    """Deprecated pooling stub retained so old imports keep working.

    Returns ``last_hidden_states`` unchanged; ``attention_mask`` is ignored.
    Not used when embeddings come from the remote API.
    """
    logger.warning("last_token_pool is deprecated when using API mode")
    return last_hidden_states
|