claude-self-reflect 2.8.10 → 3.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/mcp-server/pyproject.toml +1 -0
- package/package.json +2 -1
- package/scripts/importer/__init__.py +25 -0
- package/scripts/importer/__main__.py +14 -0
- package/scripts/importer/core/__init__.py +25 -0
- package/scripts/importer/core/config.py +120 -0
- package/scripts/importer/core/exceptions.py +52 -0
- package/scripts/importer/core/models.py +184 -0
- package/scripts/importer/embeddings/__init__.py +22 -0
- package/scripts/importer/embeddings/base.py +141 -0
- package/scripts/importer/embeddings/fastembed_provider.py +164 -0
- package/scripts/importer/embeddings/validator.py +136 -0
- package/scripts/importer/embeddings/voyage_provider.py +251 -0
- package/scripts/importer/main.py +393 -0
- package/scripts/importer/processors/__init__.py +15 -0
- package/scripts/importer/processors/ast_extractor.py +197 -0
- package/scripts/importer/processors/chunker.py +157 -0
- package/scripts/importer/processors/concept_extractor.py +109 -0
- package/scripts/importer/processors/conversation_parser.py +181 -0
- package/scripts/importer/processors/tool_extractor.py +165 -0
- package/scripts/importer/state/__init__.py +5 -0
- package/scripts/importer/state/state_manager.py +190 -0
- package/scripts/importer/storage/__init__.py +5 -0
- package/scripts/importer/storage/qdrant_storage.py +250 -0
- package/scripts/importer/utils/__init__.py +9 -0
- package/scripts/importer/utils/logger.py +87 -0
- package/scripts/importer/utils/project_normalizer.py +120 -0
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "claude-self-reflect",
|
|
3
|
-
"version": "
|
|
3
|
+
"version": "3.0.1",
|
|
4
4
|
"description": "Give Claude perfect memory of all your conversations - Installation wizard for Python MCP server",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"claude",
|
|
@@ -41,6 +41,7 @@
|
|
|
41
41
|
"mcp-server/run-mcp-clean.sh",
|
|
42
42
|
"mcp-server/run-mcp-docker.sh",
|
|
43
43
|
"scripts/import-*.py",
|
|
44
|
+
"scripts/importer/**/*.py",
|
|
44
45
|
"scripts/delta-metadata-update-safe.py",
|
|
45
46
|
"scripts/force-metadata-recovery.py",
|
|
46
47
|
".claude/agents/*.md",
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""
Claude Self-Reflect Modular Import System
==========================================

A pristine, modular conversation import system following SOLID principles
and clean architecture patterns.

Author: Claude Self-Reflect Team
License: MIT
"""

from .core.config import ImportConfig
from .core.models import Message, ConversationChunk, ProcessedPoint
from .main import ConversationProcessor, ImporterContainer

# Keep in sync with the released package version (package.json "version").
# Was "3.0.0" while the package shipped as 3.0.1; the docstring no longer
# repeats the number so there is a single place to update.
__version__ = "3.0.1"
__all__ = [
    "ImportConfig",
    "Message",
    "ConversationChunk",
    "ProcessedPoint",
    "ConversationProcessor",
    "ImporterContainer"
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
#!/usr/bin/env python3
"""Entry point for running the importer as a module."""

import sys
from pathlib import Path

# Add parent directory to path so `importer` resolves when this file is
# executed directly rather than via `python -m importer`.
sys.path.insert(0, str(Path(__file__).parent.parent))

from importer.main import main

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    sys.exit(main())
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Core domain models and configuration."""
|
|
2
|
+
|
|
3
|
+
from .config import ImportConfig
|
|
4
|
+
from .models import Message, ConversationChunk, ProcessedPoint, ImportResult, ImportStats
|
|
5
|
+
from .exceptions import (
|
|
6
|
+
ImportError,
|
|
7
|
+
ValidationError,
|
|
8
|
+
EmbeddingError,
|
|
9
|
+
StorageError,
|
|
10
|
+
ParseError
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"ImportConfig",
|
|
15
|
+
"Message",
|
|
16
|
+
"ConversationChunk",
|
|
17
|
+
"ProcessedPoint",
|
|
18
|
+
"ImportResult",
|
|
19
|
+
"ImportStats",
|
|
20
|
+
"ImportError",
|
|
21
|
+
"ValidationError",
|
|
22
|
+
"EmbeddingError",
|
|
23
|
+
"StorageError",
|
|
24
|
+
"ParseError"
|
|
25
|
+
]
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Immutable configuration with validation."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class ImportConfig:
    """
    Immutable configuration for the import system.

    All validation happens in __post_init__ to ensure configuration
    is always in a valid state.

    Raises:
        ValueError: From __post_init__ when any setting is out of range
            or inconsistent.
    """

    # Qdrant settings
    qdrant_url: str = field(default="http://localhost:6333")
    qdrant_api_key: Optional[str] = field(default=None)

    # Embedding settings
    embedding_model: str = field(default="sentence-transformers/all-MiniLM-L6-v2")
    embedding_dimension: int = field(default=384)
    use_voyage: bool = field(default=False)
    voyage_api_key: Optional[str] = field(default=None)

    # Chunking settings
    chunk_size: int = field(default=3000)
    chunk_overlap: int = field(default=200)

    # Processing settings
    batch_size: int = field(default=10)
    max_ast_elements: int = field(default=100)
    max_workers: int = field(default=4)

    # State management
    state_file: str = field(default="~/.claude-self-reflect/config/imported-files.json")

    # Operational settings
    log_level: str = field(default="INFO")
    dry_run: bool = field(default=False)
    force_reimport: bool = field(default=False)

    # Limits
    file_limit: Optional[int] = field(default=None)

    def __post_init__(self):
        """Validate configuration on initialization."""
        # Validate chunk settings
        if self.chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {self.chunk_size}")

        if self.chunk_overlap < 0:
            raise ValueError(f"chunk_overlap cannot be negative, got {self.chunk_overlap}")

        if self.chunk_overlap >= self.chunk_size:
            raise ValueError(
                f"chunk_overlap ({self.chunk_overlap}) must be less than "
                f"chunk_size ({self.chunk_size})"
            )

        # Validate batch settings
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {self.batch_size}")

        if self.max_workers < 1:
            raise ValueError(f"max_workers must be at least 1, got {self.max_workers}")

        # Validate embedding settings
        if self.embedding_dimension <= 0:
            raise ValueError(f"embedding_dimension must be positive, got {self.embedding_dimension}")

        if self.use_voyage and not self.voyage_api_key:
            # Fall back to the environment. Previously VOYAGE_KEY was only
            # *checked* here and never stored, so voyage_api_key stayed None
            # even when the env var was set. object.__setattr__ is the
            # documented way to assign a frozen-dataclass field in
            # __post_init__.
            voyage_key = os.getenv("VOYAGE_KEY")
            if not voyage_key:
                raise ValueError(
                    "voyage_api_key must be provided at initialization when use_voyage=True. "
                    "Set VOYAGE_KEY environment variable before creating config."
                )
            object.__setattr__(self, "voyage_api_key", voyage_key)

        # Validate log level
        valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
        if self.log_level.upper() not in valid_levels:
            raise ValueError(f"log_level must be one of {valid_levels}, got {self.log_level}")

    @property
    def state_file_path(self) -> Path:
        """Get expanded state file path with fallback."""
        try:
            return Path(self.state_file).expanduser()
        except (RuntimeError, OSError):
            # expanduser can fail when the home directory cannot be resolved;
            # fall back to the current directory so imports can still run.
            return Path.cwd() / ".import-state.json"

    @classmethod
    def from_env(cls) -> "ImportConfig":
        """Create configuration from environment variables.

        NOTE(review): only a subset of fields is sourced from the
        environment (e.g. embedding_model, state_file, file_limit keep
        their defaults) — confirm this is intentional.
        """
        return cls(
            qdrant_url=os.getenv("QDRANT_URL", "http://localhost:6333"),
            qdrant_api_key=os.getenv("QDRANT_API_KEY"),
            use_voyage=os.getenv("USE_VOYAGE", "false").lower() == "true",
            voyage_api_key=os.getenv("VOYAGE_KEY"),
            chunk_size=int(os.getenv("CHUNK_SIZE", "3000")),
            chunk_overlap=int(os.getenv("CHUNK_OVERLAP", "200")),
            batch_size=int(os.getenv("BATCH_SIZE", "10")),
            max_workers=int(os.getenv("MAX_WORKERS", "4")),
            log_level=os.getenv("LOG_LEVEL", "INFO"),
            dry_run=os.getenv("DRY_RUN", "false").lower() == "true",
            force_reimport=os.getenv("FORCE_REIMPORT", "false").lower() == "true"
        )

    @classmethod
    def from_dict(cls, config_dict: dict) -> "ImportConfig":
        """Create configuration from dictionary, ignoring unknown keys."""
        known_fields = {f.name for f in cls.__dataclass_fields__.values()}
        filtered_dict = {k: v for k, v in config_dict.items() if k in known_fields}
        return cls(**filtered_dict)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Custom exception hierarchy for import system."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ImportError(Exception):
    """Base exception for all import-related errors.

    NOTE(review): this class shadows the builtin ImportError within this
    package — `except ImportError` in code importing it catches both;
    confirm this is intentional.
    """

    def __init__(self, message: str, details: Optional[dict] = None):
        """Store *message* on the exception and keep optional structured
        *details* for diagnostics."""
        super().__init__(message)
        # Normalized so callers can always read .details without a None guard.
        self.details = details or {}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ValidationError(ImportError):
    """Raised when input validation fails.

    Attributes:
        field: Name of the field that failed validation.
        value: The offending value, kept for diagnostics.
        reason: Human-readable explanation of the failure.
    """

    def __init__(self, field: str, value: Any, reason: str):
        super().__init__(f"Validation failed for {field}: {reason}")
        # Retain the raw inputs so callers can handle errors programmatically.
        self.field = field
        self.value = value
        self.reason = reason
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EmbeddingError(ImportError):
    """Raised when embedding generation or validation fails.

    Args:
        message: Description of the failure.
        provider: Optional name of the embedding provider that failed
            (e.g. a FastEmbed or Voyage provider class name).
    """

    def __init__(self, message: str, provider: Optional[str] = None):
        super().__init__(message)
        # Provider name distinguishes which backend failed in logs.
        self.provider = provider
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class StorageError(ImportError):
    """Raised when storage operations fail.

    Args:
        operation: The storage operation that failed.
        collection: Name of the target collection.
        reason: Human-readable failure description.
    """

    def __init__(self, operation: str, collection: str, reason: str):
        super().__init__(f"Storage {operation} failed for {collection}: {reason}")
        # Kept for programmatic retry/reporting logic.
        self.operation = operation
        self.collection = collection
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ParseError(ImportError):
    """Raised when parsing conversation files fails.

    Args:
        file_path: Path of the file that could not be parsed.
        line_number: Optional line at which parsing failed.
        reason: Optional human-readable failure description.
    """

    def __init__(self, file_path: str, line_number: Optional[int] = None, reason: str = ""):
        message = f"Failed to parse {file_path}"
        # Explicit None check: a line number of 0 is a legitimate value and
        # must still appear in the message (truthiness check dropped it).
        if line_number is not None:
            message += f" at line {line_number}"
        if reason:
            message += f": {reason}"
        super().__init__(message)
        self.file_path = file_path
        self.line_number = line_number
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Core domain models for the import system."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import List, Dict, Any, Optional, Set
|
|
6
|
+
from uuid import UUID, uuid4
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Message:
    """A single message in a conversation."""

    role: str                                  # speaker role, e.g. "user"/"assistant"
    content: str                               # raw message text
    timestamp: Optional[datetime] = None       # when the message was sent, if known
    message_index: int = 0                     # 0-based position in the conversation
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate message on creation.

        Raises:
            ValueError: If the role is empty or message_index is negative.
        """
        if not self.role:
            raise ValueError("Message role cannot be empty")
        # Unknown roles are deliberately tolerated (exports use several role
        # spellings); the previous membership check was a no-op and has been
        # removed to avoid implying validation that never happened.
        if self.message_index < 0:
            raise ValueError(f"Message index cannot be negative: {self.message_index}")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ConversationChunk:
    """A chunk of conversation text prepared for embedding."""

    text: str                                   # chunk body sent to the embedder
    metadata: Dict[str, Any] = field(default_factory=dict)
    message_indices: List[int] = field(default_factory=list)  # source messages
    chunk_index: int = 0                        # 0-based position among chunks
    total_chunks: int = 1                       # number of chunks in the conversation
    conversation_id: str = field(default_factory=lambda: str(uuid4()))

    def __post_init__(self):
        """Reject empty text and inconsistent chunk numbering."""
        if not self.text:
            raise ValueError("Chunk text cannot be empty")
        if self.chunk_index < 0:
            raise ValueError(f"Chunk index cannot be negative: {self.chunk_index}")
        if self.chunk_index >= self.total_chunks:
            raise ValueError(
                f"Chunk index ({self.chunk_index}) must be less than "
                f"total chunks ({self.total_chunks})"
            )

    @property
    def unique_id(self) -> str:
        """Stable per-chunk identifier: conversation id plus chunk index."""
        return f"{self.conversation_id}_{self.chunk_index}"

    def add_metadata(self, key: str, value: Any) -> None:
        """Set metadata under *key*, merging list values on conflict."""
        existing = self.metadata.get(key)
        if key in self.metadata and isinstance(existing, list) and isinstance(value, list):
            # Both sides are lists: union them (order is unspecified).
            self.metadata[key] = list(set(existing + value))
        else:
            # New key, or a conflict we resolve by overwriting.
            self.metadata[key] = value
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class ProcessedPoint:
    """A fully processed point ready for storage."""

    id: str                     # unique point identifier
    vector: List[float]         # embedding vector
    payload: Dict[str, Any]     # arbitrary metadata stored alongside the vector

    def __post_init__(self):
        """Reject malformed points at construction time.

        Raises:
            ValueError: For an empty id or empty vector.
            TypeError: For a non-list vector or non-numeric components.
        """
        if not self.id:
            raise ValueError("Point ID cannot be empty")
        if not self.vector:
            raise ValueError("Point vector cannot be empty")
        if not isinstance(self.vector, list):
            raise TypeError(f"Vector must be a list, got {type(self.vector)}")
        for component in self.vector:
            if not isinstance(component, (int, float)):
                raise TypeError("Vector must contain only numeric values")
        # An empty payload is unusual but explicitly permitted.

    @property
    def dimension(self) -> int:
        """Number of components in the vector."""
        return len(self.vector)

    def validate_dimension(self, expected: int) -> bool:
        """Return True when the vector length matches *expected*."""
        return self.dimension == expected
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@dataclass
class ImportResult:
    """Outcome of importing a single conversation file."""

    file_path: str                      # source file this result describes
    success: bool                       # whether the import completed
    points_created: int = 0             # points written to storage
    chunks_processed: int = 0           # chunks produced from the file
    error: Optional[str] = None         # failure description, if any
    duration_seconds: float = 0.0       # wall-clock time spent on the file
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def failed(self) -> bool:
        """Inverse of ``success``."""
        return not self.success

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        names = (
            "file_path", "success", "points_created", "chunks_processed",
            "error", "duration_seconds", "metadata",
        )
        return {name: getattr(self, name) for name in names}
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass
class ImportStats:
    """Aggregate statistics for import operations."""

    total_files: int = 0
    successful_files: int = 0
    failed_files: int = 0
    total_points: int = 0
    total_chunks: int = 0
    total_duration_seconds: float = 0.0
    errors: List[str] = field(default_factory=list)   # "<path>: <error>" entries

    def add_result(self, result: "ImportResult") -> None:
        """Fold one file's result into the running totals."""
        self.total_files += 1
        self.total_duration_seconds += result.duration_seconds
        if not result.success:
            self.failed_files += 1
            if result.error:
                self.errors.append(f"{result.file_path}: {result.error}")
            return
        self.successful_files += 1
        self.total_points += result.points_created
        self.total_chunks += result.chunks_processed

    @property
    def success_rate(self) -> float:
        """Percentage of files that imported successfully (0.0 when empty)."""
        if not self.total_files:
            return 0.0
        return self.successful_files / self.total_files * 100

    @property
    def average_duration(self) -> float:
        """Mean seconds spent per file (0.0 when no files were processed)."""
        if not self.total_files:
            return 0.0
        return self.total_duration_seconds / self.total_files

    def summary(self) -> str:
        """Generate a human-readable, multi-line summary."""
        rows = [
            "Import Statistics:",
            f"  Total Files: {self.total_files}",
            f"  Successful: {self.successful_files} ({self.success_rate:.1f}%)",
            f"  Failed: {self.failed_files}",
            f"  Total Points: {self.total_points}",
            f"  Total Chunks: {self.total_chunks}",
            f"  Total Duration: {self.total_duration_seconds:.2f}s",
            f"  Average Duration: {self.average_duration:.2f}s per file",
        ]
        return "\n".join(rows)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Embedding providers for text vectorization."""

from .base import EmbeddingProvider
from .fastembed_provider import FastEmbedProvider
from .validator import EmbeddingValidator

# Single source of truth for the public API; previously this list was
# duplicated in both branches below and the copies could drift.
__all__ = [
    "EmbeddingProvider",
    "FastEmbedProvider",
    "EmbeddingValidator"
]

# Conditional import for Voyage: its dependency may not be installed.
try:
    from .voyage_provider import VoyageEmbeddingProvider
    __all__.append("VoyageEmbeddingProvider")
except ImportError:
    # Voyage not available, continue without it
    pass
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Abstract base class for embedding providers."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import List, Optional, Dict, Any
|
|
5
|
+
from ..core.exceptions import EmbeddingError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EmbeddingProvider(ABC):
    """Abstract interface that every embedding backend must implement.

    Besides the core embed/dimension contract, the interface carries
    explicit error-handling hooks so callers can diagnose failures.
    """

    def __init__(self):
        # Diagnostic state shared by all concrete providers.
        self._last_error: Optional[Exception] = None
        self._initialized: bool = False

    @abstractmethod
    def initialize(self, config: Any) -> None:
        """Prepare the provider for use.

        Args:
            config: Configuration object with provider-specific settings.

        Raises:
            EmbeddingError: If initialization fails.
        """

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Generate one embedding vector per input text.

        Args:
            texts: Strings to vectorize.

        Returns:
            List of embedding vectors, in input order.

        Raises:
            EmbeddingError: If embedding generation fails.
        """

    @abstractmethod
    def get_dimension(self) -> int:
        """Return the dimensionality of embeddings this provider produces."""

    @abstractmethod
    def validate_embedding(self, embedding: List[float]) -> bool:
        """Return True when *embedding* is well-formed and not degenerate."""

    @abstractmethod
    def handle_initialization_error(self, error: Exception) -> None:
        """React to an exception raised during initialization."""

    def get_last_error(self) -> Optional[Exception]:
        """Return the most recent error for diagnostics, or None."""
        return self._last_error

    def is_initialized(self) -> bool:
        """Return True once the provider is ready to embed."""
        return self._initialized

    def get_provider_info(self) -> Dict[str, Any]:
        """Return metadata describing this provider instance."""
        return {
            "provider": self.__class__.__name__,
            "initialized": self._initialized,
            # Dimension is only queryable after successful initialization.
            "dimension": self.get_dimension() if self._initialized else None,
            "has_error": self._last_error is not None
        }

    def batch_embed_texts(
        self,
        texts: List[str],
        batch_size: int = 32
    ) -> List[List[float]]:
        """Embed *texts* in slices of *batch_size* for memory efficiency.

        Args:
            texts: Strings to vectorize.
            batch_size: Maximum number of texts per embed_texts call.

        Returns:
            One embedding vector per input text, in order.

        Raises:
            EmbeddingError: If the provider is not initialized or a
                batch fails.
        """
        if not self._initialized:
            raise EmbeddingError("Provider not initialized", provider=self.__class__.__name__)

        vectors: List[List[float]] = []
        start = 0
        while start < len(texts):
            vectors.extend(self.embed_texts(texts[start:start + batch_size]))
            start += batch_size
        return vectors
|