# claude-self-reflect 3.0.0 → 3.0.2
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- package/.claude/agents/claude-self-reflect-test.md +110 -66
- package/README.md +1 -1
- package/installer/setup-wizard.js +4 -2
- package/mcp-server/pyproject.toml +1 -0
- package/mcp-server/src/server.py +84 -0
- package/package.json +2 -1
- package/scripts/import-conversations-unified.py +225 -44
- package/scripts/importer/__init__.py +25 -0
- package/scripts/importer/__main__.py +14 -0
- package/scripts/importer/core/__init__.py +25 -0
- package/scripts/importer/core/config.py +120 -0
- package/scripts/importer/core/exceptions.py +52 -0
- package/scripts/importer/core/models.py +184 -0
- package/scripts/importer/embeddings/__init__.py +22 -0
- package/scripts/importer/embeddings/base.py +141 -0
- package/scripts/importer/embeddings/fastembed_provider.py +164 -0
- package/scripts/importer/embeddings/validator.py +136 -0
- package/scripts/importer/embeddings/voyage_provider.py +251 -0
- package/scripts/importer/main.py +393 -0
- package/scripts/importer/processors/__init__.py +15 -0
- package/scripts/importer/processors/ast_extractor.py +197 -0
- package/scripts/importer/processors/chunker.py +157 -0
- package/scripts/importer/processors/concept_extractor.py +109 -0
- package/scripts/importer/processors/conversation_parser.py +181 -0
- package/scripts/importer/processors/tool_extractor.py +165 -0
- package/scripts/importer/state/__init__.py +5 -0
- package/scripts/importer/state/state_manager.py +190 -0
- package/scripts/importer/storage/__init__.py +5 -0
- package/scripts/importer/storage/qdrant_storage.py +250 -0
- package/scripts/importer/utils/__init__.py +9 -0
- package/scripts/importer/utils/logger.py +87 -0
- package/scripts/importer/utils/project_normalizer.py +120 -0
package/scripts/importer/core/models.py:

```diff
@@ -0,0 +1,184 @@
+"""Core domain models for the import system."""
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import List, Dict, Any, Optional, Set
+from uuid import UUID, uuid4
+
+
+@dataclass
+class Message:
+    """A single message in a conversation."""
+
+    role: str
+    content: str
+    timestamp: Optional[datetime] = None
+    message_index: int = 0
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def __post_init__(self):
+        """Validate message on creation."""
+        if not self.role:
+            raise ValueError("Message role cannot be empty")
+        if self.role not in {"user", "assistant", "system", "human"}:
+            # Allow common variations
+            pass  # Log warning but don't fail
+        if self.message_index < 0:
+            raise ValueError(f"Message index cannot be negative: {self.message_index}")
+
+
+@dataclass
+class ConversationChunk:
+    """A chunk of conversation ready for embedding."""
+
+    text: str
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    message_indices: List[int] = field(default_factory=list)
+    chunk_index: int = 0
+    total_chunks: int = 1
+    conversation_id: str = field(default_factory=lambda: str(uuid4()))
+
+    def __post_init__(self):
+        """Validate chunk on creation."""
+        if not self.text:
+            raise ValueError("Chunk text cannot be empty")
+        if self.chunk_index < 0:
+            raise ValueError(f"Chunk index cannot be negative: {self.chunk_index}")
+        if self.chunk_index >= self.total_chunks:
+            raise ValueError(
+                f"Chunk index ({self.chunk_index}) must be less than "
+                f"total chunks ({self.total_chunks})"
+            )
+
+    @property
+    def unique_id(self) -> str:
+        """Generate unique ID for this chunk."""
+        return f"{self.conversation_id}_{self.chunk_index}"
+
+    def add_metadata(self, key: str, value: Any) -> None:
+        """Add metadata with conflict detection."""
+        if key in self.metadata:
+            # Handle conflict - could merge, replace, or raise
+            if isinstance(self.metadata[key], list) and isinstance(value, list):
+                # Merge lists
+                self.metadata[key] = list(set(self.metadata[key] + value))
+            else:
+                # Replace value (log warning in production)
+                self.metadata[key] = value
+        else:
+            self.metadata[key] = value
+
+
+@dataclass
+class ProcessedPoint:
+    """A fully processed point ready for storage."""
+
+    id: str
+    vector: List[float]
+    payload: Dict[str, Any]
+
+    def __post_init__(self):
+        """Validate point on creation."""
+        if not self.id:
+            raise ValueError("Point ID cannot be empty")
+        if not self.vector:
+            raise ValueError("Point vector cannot be empty")
+        if not isinstance(self.vector, list):
+            raise TypeError(f"Vector must be a list, got {type(self.vector)}")
+        if not all(isinstance(x, (int, float)) for x in self.vector):
+            raise TypeError("Vector must contain only numeric values")
+        if not self.payload:
+            # Empty payload is allowed but unusual
+            pass
+
+    @property
+    def dimension(self) -> int:
+        """Get vector dimension."""
+        return len(self.vector)
+
+    def validate_dimension(self, expected: int) -> bool:
+        """Check if vector has expected dimension."""
+        return self.dimension == expected
+
+
+@dataclass
+class ImportResult:
+    """Result of an import operation."""
+
+    file_path: str
+    success: bool
+    points_created: int = 0
+    chunks_processed: int = 0
+    error: Optional[str] = None
+    duration_seconds: float = 0.0
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def failed(self) -> bool:
+        """Check if import failed."""
+        return not self.success
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "file_path": self.file_path,
+            "success": self.success,
+            "points_created": self.points_created,
+            "chunks_processed": self.chunks_processed,
+            "error": self.error,
+            "duration_seconds": self.duration_seconds,
+            "metadata": self.metadata
+        }
+
+
+@dataclass
+class ImportStats:
+    """Aggregate statistics for import operations."""
+
+    total_files: int = 0
+    successful_files: int = 0
+    failed_files: int = 0
+    total_points: int = 0
+    total_chunks: int = 0
+    total_duration_seconds: float = 0.0
+    errors: List[str] = field(default_factory=list)
+
+    def add_result(self, result: ImportResult) -> None:
+        """Add a result to the statistics."""
+        self.total_files += 1
+        if result.success:
+            self.successful_files += 1
+            self.total_points += result.points_created
+            self.total_chunks += result.chunks_processed
+        else:
+            self.failed_files += 1
+            if result.error:
+                self.errors.append(f"{result.file_path}: {result.error}")
+        self.total_duration_seconds += result.duration_seconds
+
+    @property
+    def success_rate(self) -> float:
+        """Calculate success rate."""
+        if self.total_files == 0:
+            return 0.0
+        return self.successful_files / self.total_files * 100
+
+    @property
+    def average_duration(self) -> float:
+        """Calculate average import duration."""
+        if self.total_files == 0:
+            return 0.0
+        return self.total_duration_seconds / self.total_files
+
+    def summary(self) -> str:
+        """Generate summary string."""
+        return (
+            f"Import Statistics:\n"
+            f"  Total Files: {self.total_files}\n"
+            f"  Successful: {self.successful_files} ({self.success_rate:.1f}%)\n"
+            f"  Failed: {self.failed_files}\n"
+            f"  Total Points: {self.total_points}\n"
+            f"  Total Chunks: {self.total_chunks}\n"
+            f"  Total Duration: {self.total_duration_seconds:.2f}s\n"
+            f"  Average Duration: {self.average_duration:.2f}s per file"
+        )
```
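The models compose in one direction: the importer produces one `ImportResult` per file and folds it into `ImportStats`. A minimal sketch of that flow, assuming `scripts/` is on `sys.path` so the package imports as `importer` (file names and values are illustrative only):

```python
from importer.core.models import ImportResult, ImportStats  # assumed import path

stats = ImportStats()
stats.add_result(ImportResult(
    file_path="a.jsonl", success=True,
    points_created=12, chunks_processed=3, duration_seconds=1.5,
))
stats.add_result(ImportResult(
    file_path="b.jsonl", success=False,
    error="parse error", duration_seconds=0.2,
))

# success_rate -> 50.0, errors -> ["b.jsonl: parse error"]
print(stats.summary())
```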
package/scripts/importer/embeddings/__init__.py:

```diff
@@ -0,0 +1,22 @@
+"""Embedding providers for text vectorization."""
+
+from .base import EmbeddingProvider
+from .fastembed_provider import FastEmbedProvider
+from .validator import EmbeddingValidator
+
+# Conditional import for Voyage
+try:
+    from .voyage_provider import VoyageEmbeddingProvider
+    __all__ = [
+        "EmbeddingProvider",
+        "FastEmbedProvider",
+        "VoyageEmbeddingProvider",
+        "EmbeddingValidator"
+    ]
+except ImportError:
+    # Voyage not available, continue without it
+    __all__ = [
+        "EmbeddingProvider",
+        "FastEmbedProvider",
+        "EmbeddingValidator"
+    ]
```
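Because the try/except decides what lands in `__all__`, downstream code can treat `__all__` as the availability flag instead of repeating the import dance; a small sketch, assuming the same `importer` package path as above:

```python
import importer.embeddings as emb  # assumed import path

# Pick the Voyage provider when its optional dependency imported cleanly,
# otherwise fall back to the local FastEmbed provider.
provider_cls = (
    emb.VoyageEmbeddingProvider
    if "VoyageEmbeddingProvider" in emb.__all__
    else emb.FastEmbedProvider
)
```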
package/scripts/importer/embeddings/base.py:

```diff
@@ -0,0 +1,141 @@
+"""Abstract base class for embedding providers."""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Dict, Any
+from ..core.exceptions import EmbeddingError
+
+
+class EmbeddingProvider(ABC):
+    """
+    Abstract interface for embedding providers.
+
+    Defines the contract that all embedding providers must implement,
+    including error handling methods as recommended by code review.
+    """
+
+    def __init__(self):
+        self._last_error: Optional[Exception] = None
+        self._initialized: bool = False
+
+    @abstractmethod
+    def initialize(self, config: Any) -> None:
+        """
+        Initialize the embedding provider.
+
+        Args:
+            config: Configuration object with provider-specific settings
+
+        Raises:
+            EmbeddingError: If initialization fails
+        """
+        pass
+
+    @abstractmethod
+    def embed_texts(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate embeddings for a list of texts.
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            EmbeddingError: If embedding generation fails
+        """
+        pass
+
+    @abstractmethod
+    def get_dimension(self) -> int:
+        """
+        Get the dimension of embeddings produced by this provider.
+
+        Returns:
+            Embedding dimension
+        """
+        pass
+
+    @abstractmethod
+    def validate_embedding(self, embedding: List[float]) -> bool:
+        """
+        Validate that an embedding is well-formed and not degenerate.
+
+        Args:
+            embedding: Embedding vector to validate
+
+        Returns:
+            True if embedding is valid, False otherwise
+        """
+        pass
+
+    @abstractmethod
+    def handle_initialization_error(self, error: Exception) -> None:
+        """
+        Handle initialization failures.
+
+        Args:
+            error: The exception that occurred during initialization
+        """
+        pass
+
+    def get_last_error(self) -> Optional[Exception]:
+        """
+        Retrieve the last error for diagnostics.
+
+        Returns:
+            Last exception that occurred, or None
+        """
+        return self._last_error
+
+    def is_initialized(self) -> bool:
+        """
+        Check if provider is initialized and ready.
+
+        Returns:
+            True if initialized, False otherwise
+        """
+        return self._initialized
+
+    def get_provider_info(self) -> Dict[str, Any]:
+        """
+        Get information about this embedding provider.
+
+        Returns:
+            Dictionary with provider metadata
+        """
+        return {
+            "provider": self.__class__.__name__,
+            "initialized": self._initialized,
+            "dimension": self.get_dimension() if self._initialized else None,
+            "has_error": self._last_error is not None
+        }
+
+    def batch_embed_texts(
+        self,
+        texts: List[str],
+        batch_size: int = 32
+    ) -> List[List[float]]:
+        """
+        Generate embeddings in batches for memory efficiency.
+
+        Args:
+            texts: List of texts to embed
+            batch_size: Number of texts to process at once
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            EmbeddingError: If embedding generation fails
+        """
+        if not self._initialized:
+            raise EmbeddingError("Provider not initialized", provider=self.__class__.__name__)
+
+        embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+            batch_embeddings = self.embed_texts(batch)
+            embeddings.extend(batch_embeddings)
+
+        return embeddings
```
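Everything except the five `@abstractmethod`s comes for free, including `batch_embed_texts`, which slices the input and delegates to `embed_texts`. A toy subclass, purely illustrative and assuming the package imports as `importer`, shows the minimum a provider has to supply:

```python
from typing import Any, List

from importer.embeddings.base import EmbeddingProvider  # assumed import path


class ToyProvider(EmbeddingProvider):
    """Embeds each text as a 3-dim dummy vector (not a real model)."""

    DIM = 3

    def initialize(self, config: Any) -> None:
        self._initialized = True  # nothing external to load

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        # Length-based dummy vectors, so values vary per text.
        return [[float(len(t)), 1.0, 0.5] for t in texts]

    def get_dimension(self) -> int:
        return self.DIM

    def validate_embedding(self, embedding: List[float]) -> bool:
        return len(embedding) == self.DIM

    def handle_initialization_error(self, error: Exception) -> None:
        self._last_error = error


p = ToyProvider()
p.initialize(config=None)
vectors = p.batch_embed_texts(["a", "bb", "ccc"], batch_size=2)  # two batches
assert len(vectors) == 3 and p.get_provider_info()["dimension"] == 3
```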
package/scripts/importer/embeddings/fastembed_provider.py:

```diff
@@ -0,0 +1,164 @@
+"""FastEmbed provider for local embeddings."""
+
+from typing import List, Any
+import logging
+import statistics
+from .base import EmbeddingProvider
+from ..core.exceptions import EmbeddingError
+
+logger = logging.getLogger(__name__)
+
+
+class FastEmbedProvider(EmbeddingProvider):
+    """
+    FastEmbed provider for generating embeddings locally.
+
+    Uses sentence-transformers/all-MiniLM-L6-v2 model by default.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.model = None
+        self.model_name = None
+        self.dimension = None
+
+    def initialize(self, config: Any) -> None:
+        """Initialize FastEmbed with the specified model."""
+        try:
+            from fastembed import TextEmbedding
+
+            # CRITICAL: Use the correct model name
+            self.model_name = config.embedding_model
+
+            # FastEmbed uses specific model names
+            if self.model_name == "sentence-transformers/all-MiniLM-L6-v2":
+                # This is the correct model we must use
+                fastembed_model = "sentence-transformers/all-MiniLM-L6-v2"
+            else:
+                fastembed_model = self.model_name
+
+            logger.info(f"Initializing FastEmbed with model: {fastembed_model}")
+
+            self.model = TextEmbedding(model_name=fastembed_model)
+            self.dimension = config.embedding_dimension
+            self._initialized = True
+
+            logger.info(f"FastEmbed initialized successfully with dimension {self.dimension}")
+
+        except ImportError as e:
+            error = EmbeddingError(
+                "FastEmbed not installed. Install with: pip install fastembed",
+                provider="FastEmbed"
+            )
+            self.handle_initialization_error(error)
+            raise error
+        except Exception as e:
+            error = EmbeddingError(
+                f"Failed to initialize FastEmbed: {str(e)}",
+                provider="FastEmbed"
+            )
+            self.handle_initialization_error(error)
+            raise error
+
+    def embed(self, texts: List[str]) -> List[List[float]]:
+        """Generate embeddings for texts using FastEmbed."""
+        if not self._initialized:
+            raise EmbeddingError("FastEmbed not initialized", provider="FastEmbed")
+
+        try:
+            # FastEmbed returns a generator, convert to list
+            embeddings = list(self.model.embed(texts))
+
+            # Convert to regular Python lists with safe indexing
+            result = []
+            for i, embedding in enumerate(embeddings):
+                # Convert numpy array or similar to list
+                if hasattr(embedding, 'tolist'):
+                    emb_list = embedding.tolist()
+                else:
+                    emb_list = list(embedding)
+
+                # Validate each embedding
+                if not self.validate_embedding(emb_list):
+                    # Safe indexing - use i which is guaranteed to be valid
+                    text_len = len(texts[i]) if i < len(texts) else 0
+                    raise EmbeddingError(
+                        f"Invalid embedding generated for text {i} of length {text_len}",
+                        provider="FastEmbed"
+                    )
+
+                result.append(emb_list)
+
+            return result
+
+        except Exception as e:
+            if not isinstance(e, EmbeddingError):
+                e = EmbeddingError(
+                    f"Failed to generate embeddings: {str(e)}",
+                    provider="FastEmbed"
+                )
+            self._last_error = e
+            raise e
+
+    def get_dimension(self) -> int:
+        """Get embedding dimension."""
+        if not self._initialized:
+            raise EmbeddingError("FastEmbed not initialized", provider="FastEmbed")
+        return self.dimension
+
+    def validate_embedding(self, embedding: List[float]) -> bool:
+        """
+        Validate embedding quality.
+
+        Checks:
+        1. Non-empty
+        2. Correct dimension
+        3. Not degenerate (all same values)
+        4. Has reasonable variance
+        """
+        if not embedding:
+            logger.error("Empty embedding detected")
+            return False
+
+        # Check dimension
+        if len(embedding) != self.dimension:
+            logger.error(
+                f"Dimension mismatch: expected {self.dimension}, got {len(embedding)}"
+            )
+            return False
+
+        # Check for degenerate embedding (all values identical)
+        unique_values = len(set(embedding))
+        if unique_values == 1:
+            logger.error(f"Degenerate embedding detected (all values are {embedding[0]})")
+            return False
+
+        # Check variance is above threshold
+        try:
+            variance = statistics.variance(embedding)
+            if variance < 1e-6:
+                logger.warning(f"Low variance embedding detected: {variance}")
+                # Don't fail on low variance, just warn
+        except statistics.StatisticsError:
+            # Less than 2 data points
+            pass
+
+        # Check for NaN or Inf values
+        if any(not isinstance(x, (int, float)) or x != x or abs(x) == float('inf')
+               for x in embedding):
+            logger.error("Embedding contains NaN or Inf values")
+            return False
+
+        return True
+
+    def handle_initialization_error(self, error: Exception) -> None:
+        """Handle and log initialization errors."""
+        self._last_error = error
+        self._initialized = False
+        logger.error(f"FastEmbed initialization failed: {error}")
+
+        # Could implement retry logic or fallback here
+        if "not installed" in str(error):
+            logger.info("Try: pip install fastembed")
+        elif "model" in str(error).lower():
+            logger.info(f"Model {self.model_name} may need to be downloaded first")
```
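One wrinkle worth flagging: `base.py` declares the abstract method `embed_texts`, but this provider implements `embed`, so as published the ABC machinery would refuse to instantiate `FastEmbedProvider`, and the inherited `batch_embed_texts` (which calls `self.embed_texts`) would have nothing to dispatch to. A usage sketch with a hypothetical one-line bridge and a stand-in config exposing the two attributes `initialize` reads (import path assumed as before):

```python
from dataclasses import dataclass

from importer.embeddings.fastembed_provider import FastEmbedProvider  # assumed path


@dataclass
class StubConfig:
    # Hypothetical stand-in for the real config in importer.core.config.
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dimension: int = 384  # all-MiniLM-L6-v2 outputs 384-dim vectors


class PatchedFastEmbed(FastEmbedProvider):
    # Bridge the embed/embed_texts naming mismatch noted above.
    def embed_texts(self, texts):
        return self.embed(texts)


provider = PatchedFastEmbed()
provider.initialize(StubConfig())  # downloads/loads the model on first use
vectors = provider.batch_embed_texts(["hello", "world"])
assert all(len(v) == provider.get_dimension() for v in vectors)
```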
package/scripts/importer/embeddings/validator.py:

```diff
@@ -0,0 +1,136 @@
+"""Embedding validation utilities."""
+
+import statistics
+from typing import List, Tuple, Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingValidator:
+    """
+    Comprehensive embedding validation.
+
+    Performs multiple checks to ensure embedding quality.
+    """
+
+    def __init__(
+        self,
+        expected_dimension: int,
+        min_variance: float = 1e-6,
+        max_magnitude: float = 100.0
+    ):
+        self.expected_dimension = expected_dimension
+        self.min_variance = min_variance
+        self.max_magnitude = max_magnitude
+
+    def validate(self, embedding: List[float]) -> Tuple[bool, Optional[str]]:
+        """
+        Validate an embedding vector.
+
+        Returns:
+            Tuple of (is_valid, error_message)
+        """
+        # Check empty
+        if not embedding:
+            return False, "Empty embedding"
+
+        # Check dimension
+        if len(embedding) != self.expected_dimension:
+            return False, f"Dimension mismatch: expected {self.expected_dimension}, got {len(embedding)}"
+
+        # Check for NaN/Inf
+        for i, val in enumerate(embedding):
+            if not isinstance(val, (int, float)):
+                return False, f"Non-numeric value at index {i}: {type(val)}"
+            if val != val:  # NaN check
+                return False, f"NaN value at index {i}"
+            if abs(val) == float('inf'):
+                return False, f"Infinite value at index {i}"
+
+        # Check for degenerate (all same)
+        unique_count = len(set(embedding))
+        if unique_count == 1:
+            return False, f"Degenerate embedding (all values are {embedding[0]})"
+
+        # Check variance
+        if len(embedding) > 1:
+            try:
+                variance = statistics.variance(embedding)
+                if variance < self.min_variance:
+                    # Warning, not error
+                    logger.warning(f"Low variance: {variance}")
+            except Exception as e:
+                logger.warning(f"Could not calculate variance: {e}")
+
+        # Check magnitude
+        max_val = max(abs(v) for v in embedding)
+        if max_val > self.max_magnitude:
+            return False, f"Value exceeds maximum magnitude: {max_val}"
+
+        # Check for mostly zeros
+        zero_count = sum(1 for v in embedding if abs(v) < 1e-10)
+        if zero_count > len(embedding) * 0.9:
+            return False, f"Embedding is mostly zeros ({zero_count}/{len(embedding)})"
+
+        return True, None
+
+    def validate_batch(
+        self,
+        embeddings: List[List[float]]
+    ) -> List[Tuple[int, str]]:
+        """
+        Validate a batch of embeddings.
+
+        Returns:
+            List of (index, error_message) for invalid embeddings
+        """
+        errors = []
+        for i, embedding in enumerate(embeddings):
+            valid, error = self.validate(embedding)
+            if not valid:
+                errors.append((i, error))
+        return errors
+
+    def check_similarity(
+        self,
+        embeddings: List[List[float]]
+    ) -> bool:
+        """
+        Check if embeddings in a batch are too similar.
+
+        This can indicate a problem with the embedding model.
+        """
+        if len(embeddings) < 2:
+            return True
+
+        # Calculate pairwise cosine similarities
+        from math import sqrt
+
+        def cosine_similarity(a: List[float], b: List[float]) -> float:
+            dot_product = sum(x * y for x, y in zip(a, b))
+            norm_a = sqrt(sum(x * x for x in a))
+            norm_b = sqrt(sum(y * y for y in b))
+            if norm_a == 0 or norm_b == 0:
+                return 0
+            return dot_product / (norm_a * norm_b)
+
+        # Check if all embeddings are too similar
+        high_similarity_count = 0
+        total_pairs = 0
+
+        for i in range(len(embeddings)):
+            for j in range(i + 1, min(i + 5, len(embeddings))):  # Check first 5 pairs
+                similarity = cosine_similarity(embeddings[i], embeddings[j])
+                if similarity > 0.99:  # Nearly identical
+                    high_similarity_count += 1
+                total_pairs += 1
+
+        if total_pairs > 0 and high_similarity_count / total_pairs > 0.8:
+            logger.warning(
+                f"High similarity detected: {high_similarity_count}/{total_pairs} "
+                f"pairs have >0.99 similarity"
+            )
+            return False
+
+        return True
```
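The validator returns `(ok, reason)` pairs rather than raising, so callers can aggregate failures; `validate_batch` collects only the offenders. A quick sketch of the expected behavior, assuming the same `importer` package path as above:

```python
from importer.embeddings.validator import EmbeddingValidator  # assumed path

validator = EmbeddingValidator(expected_dimension=3)

good = [0.1, -0.2, 0.3]
bad = [0.1, float("nan"), 0.3]

assert validator.validate(good) == (True, None)
ok, reason = validator.validate(bad)
assert not ok and "NaN" in reason  # "NaN value at index 1"

# Batch form returns (index, message) pairs for the failures only.
assert validator.validate_batch([good, bad]) == [(1, "NaN value at index 1")]
```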