emdash-core 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emdash_core/__init__.py +3 -0
- emdash_core/agent/__init__.py +37 -0
- emdash_core/agent/agents.py +225 -0
- emdash_core/agent/code_reviewer.py +476 -0
- emdash_core/agent/compaction.py +143 -0
- emdash_core/agent/context_manager.py +140 -0
- emdash_core/agent/events.py +338 -0
- emdash_core/agent/handlers.py +224 -0
- emdash_core/agent/inprocess_subagent.py +377 -0
- emdash_core/agent/mcp/__init__.py +50 -0
- emdash_core/agent/mcp/client.py +346 -0
- emdash_core/agent/mcp/config.py +302 -0
- emdash_core/agent/mcp/manager.py +496 -0
- emdash_core/agent/mcp/tool_factory.py +213 -0
- emdash_core/agent/prompts/__init__.py +38 -0
- emdash_core/agent/prompts/main_agent.py +104 -0
- emdash_core/agent/prompts/subagents.py +131 -0
- emdash_core/agent/prompts/workflow.py +136 -0
- emdash_core/agent/providers/__init__.py +34 -0
- emdash_core/agent/providers/base.py +143 -0
- emdash_core/agent/providers/factory.py +80 -0
- emdash_core/agent/providers/models.py +220 -0
- emdash_core/agent/providers/openai_provider.py +463 -0
- emdash_core/agent/providers/transformers_provider.py +217 -0
- emdash_core/agent/research/__init__.py +81 -0
- emdash_core/agent/research/agent.py +143 -0
- emdash_core/agent/research/controller.py +254 -0
- emdash_core/agent/research/critic.py +428 -0
- emdash_core/agent/research/macros.py +469 -0
- emdash_core/agent/research/planner.py +449 -0
- emdash_core/agent/research/researcher.py +436 -0
- emdash_core/agent/research/state.py +523 -0
- emdash_core/agent/research/synthesizer.py +594 -0
- emdash_core/agent/reviewer_profile.py +475 -0
- emdash_core/agent/rules.py +123 -0
- emdash_core/agent/runner.py +601 -0
- emdash_core/agent/session.py +262 -0
- emdash_core/agent/spec_schema.py +66 -0
- emdash_core/agent/specification.py +479 -0
- emdash_core/agent/subagent.py +397 -0
- emdash_core/agent/subagent_prompts.py +13 -0
- emdash_core/agent/toolkit.py +482 -0
- emdash_core/agent/toolkits/__init__.py +64 -0
- emdash_core/agent/toolkits/base.py +96 -0
- emdash_core/agent/toolkits/explore.py +47 -0
- emdash_core/agent/toolkits/plan.py +55 -0
- emdash_core/agent/tools/__init__.py +141 -0
- emdash_core/agent/tools/analytics.py +436 -0
- emdash_core/agent/tools/base.py +131 -0
- emdash_core/agent/tools/coding.py +484 -0
- emdash_core/agent/tools/github_mcp.py +592 -0
- emdash_core/agent/tools/history.py +13 -0
- emdash_core/agent/tools/modes.py +153 -0
- emdash_core/agent/tools/plan.py +206 -0
- emdash_core/agent/tools/plan_write.py +135 -0
- emdash_core/agent/tools/search.py +412 -0
- emdash_core/agent/tools/spec.py +341 -0
- emdash_core/agent/tools/task.py +262 -0
- emdash_core/agent/tools/task_output.py +204 -0
- emdash_core/agent/tools/tasks.py +454 -0
- emdash_core/agent/tools/traversal.py +588 -0
- emdash_core/agent/tools/web.py +179 -0
- emdash_core/analytics/__init__.py +5 -0
- emdash_core/analytics/engine.py +1286 -0
- emdash_core/api/__init__.py +5 -0
- emdash_core/api/agent.py +308 -0
- emdash_core/api/agents.py +154 -0
- emdash_core/api/analyze.py +264 -0
- emdash_core/api/auth.py +173 -0
- emdash_core/api/context.py +77 -0
- emdash_core/api/db.py +121 -0
- emdash_core/api/embed.py +131 -0
- emdash_core/api/feature.py +143 -0
- emdash_core/api/health.py +93 -0
- emdash_core/api/index.py +162 -0
- emdash_core/api/plan.py +110 -0
- emdash_core/api/projectmd.py +210 -0
- emdash_core/api/query.py +320 -0
- emdash_core/api/research.py +122 -0
- emdash_core/api/review.py +161 -0
- emdash_core/api/router.py +76 -0
- emdash_core/api/rules.py +116 -0
- emdash_core/api/search.py +119 -0
- emdash_core/api/spec.py +99 -0
- emdash_core/api/swarm.py +223 -0
- emdash_core/api/tasks.py +109 -0
- emdash_core/api/team.py +120 -0
- emdash_core/auth/__init__.py +17 -0
- emdash_core/auth/github.py +389 -0
- emdash_core/config.py +74 -0
- emdash_core/context/__init__.py +52 -0
- emdash_core/context/models.py +50 -0
- emdash_core/context/providers/__init__.py +11 -0
- emdash_core/context/providers/base.py +74 -0
- emdash_core/context/providers/explored_areas.py +183 -0
- emdash_core/context/providers/touched_areas.py +360 -0
- emdash_core/context/registry.py +73 -0
- emdash_core/context/reranker.py +199 -0
- emdash_core/context/service.py +260 -0
- emdash_core/context/session.py +352 -0
- emdash_core/core/__init__.py +104 -0
- emdash_core/core/config.py +454 -0
- emdash_core/core/exceptions.py +55 -0
- emdash_core/core/models.py +265 -0
- emdash_core/core/review_config.py +57 -0
- emdash_core/db/__init__.py +67 -0
- emdash_core/db/auth.py +134 -0
- emdash_core/db/models.py +91 -0
- emdash_core/db/provider.py +222 -0
- emdash_core/db/providers/__init__.py +5 -0
- emdash_core/db/providers/supabase.py +452 -0
- emdash_core/embeddings/__init__.py +24 -0
- emdash_core/embeddings/indexer.py +534 -0
- emdash_core/embeddings/models.py +192 -0
- emdash_core/embeddings/providers/__init__.py +7 -0
- emdash_core/embeddings/providers/base.py +112 -0
- emdash_core/embeddings/providers/fireworks.py +141 -0
- emdash_core/embeddings/providers/openai.py +104 -0
- emdash_core/embeddings/registry.py +146 -0
- emdash_core/embeddings/service.py +215 -0
- emdash_core/graph/__init__.py +26 -0
- emdash_core/graph/builder.py +134 -0
- emdash_core/graph/connection.py +692 -0
- emdash_core/graph/schema.py +416 -0
- emdash_core/graph/writer.py +667 -0
- emdash_core/ingestion/__init__.py +7 -0
- emdash_core/ingestion/change_detector.py +150 -0
- emdash_core/ingestion/git/__init__.py +5 -0
- emdash_core/ingestion/git/commit_analyzer.py +196 -0
- emdash_core/ingestion/github/__init__.py +6 -0
- emdash_core/ingestion/github/pr_fetcher.py +296 -0
- emdash_core/ingestion/github/task_extractor.py +100 -0
- emdash_core/ingestion/orchestrator.py +540 -0
- emdash_core/ingestion/parsers/__init__.py +10 -0
- emdash_core/ingestion/parsers/base_parser.py +66 -0
- emdash_core/ingestion/parsers/call_graph_builder.py +121 -0
- emdash_core/ingestion/parsers/class_extractor.py +154 -0
- emdash_core/ingestion/parsers/function_extractor.py +202 -0
- emdash_core/ingestion/parsers/import_analyzer.py +119 -0
- emdash_core/ingestion/parsers/python_parser.py +123 -0
- emdash_core/ingestion/parsers/registry.py +72 -0
- emdash_core/ingestion/parsers/ts_ast_parser.js +313 -0
- emdash_core/ingestion/parsers/typescript_parser.py +278 -0
- emdash_core/ingestion/repository.py +346 -0
- emdash_core/models/__init__.py +38 -0
- emdash_core/models/agent.py +68 -0
- emdash_core/models/index.py +77 -0
- emdash_core/models/query.py +113 -0
- emdash_core/planning/__init__.py +7 -0
- emdash_core/planning/agent_api.py +413 -0
- emdash_core/planning/context_builder.py +265 -0
- emdash_core/planning/feature_context.py +232 -0
- emdash_core/planning/feature_expander.py +646 -0
- emdash_core/planning/llm_explainer.py +198 -0
- emdash_core/planning/similarity.py +509 -0
- emdash_core/planning/team_focus.py +821 -0
- emdash_core/server.py +153 -0
- emdash_core/sse/__init__.py +5 -0
- emdash_core/sse/stream.py +196 -0
- emdash_core/swarm/__init__.py +17 -0
- emdash_core/swarm/merge_agent.py +383 -0
- emdash_core/swarm/session_manager.py +274 -0
- emdash_core/swarm/swarm_runner.py +226 -0
- emdash_core/swarm/task_definition.py +137 -0
- emdash_core/swarm/worker_spawner.py +319 -0
- emdash_core/swarm/worktree_manager.py +278 -0
- emdash_core/templates/__init__.py +10 -0
- emdash_core/templates/defaults/agent-builder.md.template +82 -0
- emdash_core/templates/defaults/focus.md.template +115 -0
- emdash_core/templates/defaults/pr-review-enhanced.md.template +309 -0
- emdash_core/templates/defaults/pr-review.md.template +80 -0
- emdash_core/templates/defaults/project.md.template +85 -0
- emdash_core/templates/defaults/research_critic.md.template +112 -0
- emdash_core/templates/defaults/research_planner.md.template +85 -0
- emdash_core/templates/defaults/research_synthesizer.md.template +128 -0
- emdash_core/templates/defaults/reviewer.md.template +81 -0
- emdash_core/templates/defaults/spec.md.template +41 -0
- emdash_core/templates/defaults/tasks.md.template +78 -0
- emdash_core/templates/loader.py +296 -0
- emdash_core/utils/__init__.py +45 -0
- emdash_core/utils/git.py +84 -0
- emdash_core/utils/image.py +502 -0
- emdash_core/utils/logger.py +51 -0
- emdash_core-0.1.7.dist-info/METADATA +35 -0
- emdash_core-0.1.7.dist-info/RECORD +187 -0
- emdash_core-0.1.7.dist-info/WHEEL +4 -0
- emdash_core-0.1.7.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""Embedding models enum - single source of truth for all supported models."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of a single embedding model.

    Attributes:
        provider: Provider key, e.g. "openai" or "fireworks".
        model_id: The actual model identifier sent to the provider's API.
        dimensions: Number of dimensions in the output embedding.
        max_tokens: Maximum number of input tokens.
        batch_size: Recommended number of texts per request.
        description: Human-readable description.
    """

    provider: str
    model_id: str
    dimensions: int
    max_tokens: int
    batch_size: int
    description: str
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EmbeddingModel(Enum):
    """
    All supported embedding models.

    Every member's value is a ``ModelSpec``. Member names follow the
    format PROVIDER_MODEL_NAME.

    Usage:
        model = EmbeddingModel.OPENAI_TEXT_3_SMALL
        print(model.spec.dimensions)  # 1536
        print(model.spec.provider)    # "openai"
    """

    # NOTE: Enum treats members with equal values as aliases of each other,
    # so every ModelSpec below must differ in at least one field.

    # ═══════════════════════════════════════════════════════════════════════════
    # OpenAI Models
    # ═══════════════════════════════════════════════════════════════════════════

    OPENAI_TEXT_3_SMALL = ModelSpec(
        provider="openai",
        model_id="text-embedding-3-small",
        dimensions=1536,
        max_tokens=8191,
        batch_size=100,
        description="OpenAI's small, fast embedding model (best value)",
    )

    OPENAI_TEXT_3_LARGE = ModelSpec(
        provider="openai",
        model_id="text-embedding-3-large",
        dimensions=3072,
        max_tokens=8191,
        batch_size=50,
        description="OpenAI's large, high-quality embedding model",
    )

    OPENAI_ADA_002 = ModelSpec(
        provider="openai",
        model_id="text-embedding-ada-002",
        dimensions=1536,
        max_tokens=8191,
        batch_size=100,
        description="OpenAI's legacy Ada model (deprecated, use text-3-small)",
    )

    # ═══════════════════════════════════════════════════════════════════════════
    # Fireworks AI Models
    # ═══════════════════════════════════════════════════════════════════════════

    FIREWORKS_NOMIC_EMBED_V1_5 = ModelSpec(
        provider="fireworks",
        model_id="nomic-ai/nomic-embed-text-v1.5",
        dimensions=768,
        max_tokens=8192,
        batch_size=100,
        description="Nomic's open-source embedding model (fast, good quality)",
    )

    FIREWORKS_E5_MISTRAL_7B = ModelSpec(
        provider="fireworks",
        model_id="intfloat/e5-mistral-7b-instruct",
        dimensions=4096,
        max_tokens=4096,
        batch_size=20,
        description="E5-Mistral 7B (highest quality, slower)",
    )

    FIREWORKS_UAE_LARGE_V1 = ModelSpec(
        provider="fireworks",
        model_id="WhereIsAI/UAE-Large-V1",
        dimensions=1024,
        max_tokens=512,
        batch_size=50,
        description="UAE-Large-V1 (good balance of speed/quality)",
    )

    FIREWORKS_GTE_LARGE = ModelSpec(
        provider="fireworks",
        model_id="thenlper/gte-large",
        dimensions=1024,
        max_tokens=512,
        batch_size=50,
        description="GTE-Large (Alibaba's efficient embedding model)",
    )

    FIREWORKS_BGE_LARGE_EN = ModelSpec(
        provider="fireworks",
        model_id="BAAI/bge-large-en-v1.5",
        dimensions=1024,
        max_tokens=512,
        batch_size=50,
        description="BGE-Large-EN (BAAI's high-quality English model)",
    )

    # ═══════════════════════════════════════════════════════════════════════════
    # Accessors and lookup helpers
    # ═══════════════════════════════════════════════════════════════════════════

    @property
    def spec(self) -> ModelSpec:
        """Get the model specification (the member's value)."""
        return self.value

    @property
    def provider(self) -> str:
        """Shortcut to get provider name."""
        return self.value.provider

    @property
    def model_id(self) -> str:
        """Shortcut to get the API model ID."""
        return self.value.model_id

    @property
    def dimensions(self) -> int:
        """Shortcut to get embedding dimensions."""
        return self.value.dimensions

    @classmethod
    def get_default(cls) -> "EmbeddingModel":
        """Get the default embedding model (OPENAI_TEXT_3_SMALL)."""
        return cls.OPENAI_TEXT_3_SMALL

    @classmethod
    def from_string(cls, value: str) -> Optional["EmbeddingModel"]:
        """
        Parse model from string.

        Accepts (surrounding whitespace is ignored, matching is
        case-insensitive):
            - Enum name: "OPENAI_TEXT_3_SMALL"
            - Provider:model: "openai:text-embedding-3-small"
            - Just model_id: "text-embedding-3-small"

        Args:
            value: The string to parse.

        Returns:
            The matching model, or None when nothing matches or ``value``
            is not a non-empty string.
        """
        # Unset settings commonly arrive as None; report "no match" rather
        # than failing with AttributeError on .strip().
        if not isinstance(value, str):
            return None
        value = value.strip()
        if not value:
            return None

        # Try enum name first ("openai-text-3-small" -> "OPENAI_TEXT_3_SMALL")
        try:
            return cls[value.upper().replace("-", "_").replace(":", "_")]
        except KeyError:
            pass

        # Try provider:model format
        if ":" in value:
            provider, model_id = (
                part.strip().casefold() for part in value.split(":", 1)
            )
            for model in cls:
                if (
                    model.provider.casefold() == provider
                    and model.model_id.casefold() == model_id
                ):
                    return model

        # Try just model_id
        wanted = value.casefold()
        for model in cls:
            if model.model_id.casefold() == wanted:
                return model

        return None

    @classmethod
    def list_by_provider(cls, provider: str) -> list["EmbeddingModel"]:
        """List all models for a specific provider, in definition order."""
        return [m for m in cls if m.provider == provider]

    @classmethod
    def list_all(cls) -> list[dict]:
        """List all models with their specs for display."""
        return [
            {
                "name": m.name,
                "provider": m.provider,
                "model_id": m.model_id,
                "dimensions": m.dimensions,
                "description": m.spec.description,
            }
            for m in cls
        ]

    def __str__(self) -> str:
        """String representation as provider:model_id."""
        return f"{self.provider}:{self.model_id}"
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Base class for embedding providers."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from ..models import EmbeddingModel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EmbeddingProvider(ABC):
    """
    Common interface shared by all embedding providers.

    Concrete providers (OpenAI, Fireworks, etc.) subclass this and supply
    ``is_available`` and ``embed_texts``; the single-text and query helpers
    are built on top of those.
    """

    def __init__(self, model: EmbeddingModel):
        """
        Bind the provider to one embedding model.

        Args:
            model: The embedding model to use
        """
        self._model = model

    @property
    def model(self) -> EmbeddingModel:
        """The embedding model this provider was created with."""
        return self._model

    @property
    def dimensions(self) -> int:
        """Embedding dimensions of the bound model."""
        return self._model.dimensions

    @property
    @abstractmethod
    def is_available(self) -> bool:
        """Whether the provider can be used (API key configured, etc.)."""
        pass

    @abstractmethod
    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        """
        Embed several texts at once.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors. None for failed embeddings.
        """
        pass

    def embed_text(self, text: str) -> Optional[list[float]]:
        """
        Embed one text by delegating to ``embed_texts``.

        Args:
            text: Text string to embed

        Returns:
            Embedding vector, or None when ``text`` is empty or embedding failed
        """
        if not text:
            return None
        vectors = self.embed_texts([text])
        if not vectors:
            return None
        return vectors[0]

    def embed_query(self, query: str) -> Optional[list[float]]:
        """
        Embed a search query.

        The default treats a query like any other text. Providers whose
        models handle queries differently from documents override this.

        Args:
            query: Search query string

        Returns:
            Embedding vector or None if failed
        """
        return self.embed_text(query)

    def _truncate_text(self, text: str, max_chars: int = 8000) -> str:
        """
        Cut ``text`` down to at most ``max_chars`` characters.

        Args:
            text: Text to truncate
            max_chars: Maximum character length (roughly 4 chars per token)

        Returns:
            The (possibly shortened) text; "" when ``text`` is empty or None
        """
        if not text:
            return ""
        return text[:max_chars] if len(text) > max_chars else text

    def _clean_batch(self, texts: list[str]) -> list[str]:
        """
        Truncate every text in a batch to the model's character budget.

        Args:
            texts: List of texts to clean

        Returns:
            Cleaned texts, in the same order as the input
        """
        # Roughly 4 characters per token, capped at 32000 characters.
        char_budget = self._model.spec.max_tokens * 4
        limit = min(char_budget, 32000)

        cleaned = []
        for entry in texts:
            cleaned.append(self._truncate_text(entry, limit))
        return cleaned
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""Fireworks AI embedding provider."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from ..models import EmbeddingModel
|
|
6
|
+
from .base import EmbeddingProvider
|
|
7
|
+
from ...core.config import get_config
|
|
8
|
+
from ...utils.logger import log
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FireworksProvider(EmbeddingProvider):
    """
    Fireworks AI embedding provider.

    Uses the Fireworks API (OpenAI-compatible) to generate embeddings.
    Requires FIREWORKS_API_KEY environment variable.

    API docs: https://docs.fireworks.ai/guides/querying-embeddings-models
    """

    # Fireworks API base URL
    BASE_URL = "https://api.fireworks.ai/inference/v1"

    # Models this provider handles
    SUPPORTED_MODELS = {
        EmbeddingModel.FIREWORKS_NOMIC_EMBED_V1_5,
        EmbeddingModel.FIREWORKS_E5_MISTRAL_7B,
        EmbeddingModel.FIREWORKS_UAE_LARGE_V1,
        EmbeddingModel.FIREWORKS_GTE_LARGE,
        EmbeddingModel.FIREWORKS_BGE_LARGE_EN,
    }

    def __init__(self, model: EmbeddingModel):
        """
        Initialize Fireworks provider.

        Args:
            model: The embedding model to use (must be a Fireworks model)

        Raises:
            ValueError: If ``model`` is not in SUPPORTED_MODELS
        """
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not supported by FireworksProvider")
        super().__init__(model)
        self._client = None  # created lazily by _client_instance

    @property
    def _api_key(self) -> Optional[str]:
        """Get Fireworks API key from config, or None when it is not set."""
        # The config object may have no "fireworks" section at all.
        fireworks_config = getattr(get_config(), "fireworks", None)
        if fireworks_config is None:
            return None
        return fireworks_config.api_key or None

    @property
    def is_available(self) -> bool:
        """Check if Fireworks API key is configured."""
        return bool(self._api_key)

    @property
    def _client_instance(self):
        """
        Lazy-load OpenAI client configured for Fireworks.

        Raises:
            RuntimeError: If no API key is configured or the ``openai``
                package is not installed
        """
        if self._client is None:
            # Read the key once so the check and the client use the same value.
            api_key = self._api_key
            if not api_key:
                raise RuntimeError(
                    "Fireworks API key not configured. Set FIREWORKS_API_KEY environment variable."
                )
            try:
                from openai import OpenAI
            except ImportError as err:
                raise RuntimeError(
                    "OpenAI library not installed. Run: pip install openai"
                ) from err

            # Fireworks uses OpenAI-compatible API
            self._client = OpenAI(
                api_key=api_key,
                base_url=self.BASE_URL,
            )
        return self._client

    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        """
        Generate embeddings using Fireworks API.

        Texts are sent in batches of the model's recommended batch size.
        A batch that fails is logged and contributes one None per text, so
        the result always has the same length and order as ``texts``.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors. None for failed embeddings.
        """
        if not texts:
            return []

        all_embeddings: list[Optional[list[float]]] = []
        batch_size = self._model.spec.batch_size

        # Fireworks requires accounts/ prefix for model IDs
        # NOTE(review): confirm this prefix is accepted for the org-qualified
        # IDs used here (e.g. "nomic-ai/..."); not verifiable from this file.
        model_id = f"accounts/fireworks/models/{self._model.model_id}"

        for start in range(0, len(texts), batch_size):
            cleaned_batch = self._clean_batch(texts[start : start + batch_size])

            try:
                response = self._client_instance.embeddings.create(
                    model=model_id,
                    input=cleaned_batch,
                )
                batch_embeddings = [item.embedding for item in response.data]
                # A short or long response would shift every later vector
                # onto the wrong text, so treat it as a failed batch.
                if len(batch_embeddings) != len(cleaned_batch):
                    raise ValueError(
                        f"expected {len(cleaned_batch)} embeddings, "
                        f"got {len(batch_embeddings)}"
                    )
            except Exception as e:
                # Best effort: one bad batch must not abort the whole run.
                log.error(f"Fireworks embedding error: {e}")
                batch_embeddings = [None] * len(cleaned_batch)

            all_embeddings.extend(batch_embeddings)

        return all_embeddings

    def embed_query(self, query: str) -> Optional[list[float]]:
        """
        Generate embedding for a search query.

        Some Fireworks models (like Nomic, E5) benefit from query prefixes.

        Args:
            query: Search query string

        Returns:
            Embedding vector or None if failed
        """
        if not query:
            return None

        # E5 models expect "query: " prefix for queries
        if self._model == EmbeddingModel.FIREWORKS_E5_MISTRAL_7B:
            query = f"query: {query}"

        # Nomic models can optionally use "search_query: " prefix
        elif self._model == EmbeddingModel.FIREWORKS_NOMIC_EMBED_V1_5:
            query = f"search_query: {query}"

        return self.embed_text(query)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""OpenAI embedding provider."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from ..models import EmbeddingModel
|
|
6
|
+
from .base import EmbeddingProvider
|
|
7
|
+
from ...core.config import get_config
|
|
8
|
+
from ...utils.logger import log
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAIProvider(EmbeddingProvider):
    """
    OpenAI embedding provider.

    Uses the OpenAI API to generate embeddings.
    Requires OPENAI_API_KEY environment variable.
    """

    # Models this provider handles
    SUPPORTED_MODELS = {
        EmbeddingModel.OPENAI_TEXT_3_SMALL,
        EmbeddingModel.OPENAI_TEXT_3_LARGE,
        EmbeddingModel.OPENAI_ADA_002,
    }

    def __init__(self, model: EmbeddingModel):
        """
        Initialize OpenAI provider.

        Args:
            model: The embedding model to use (must be an OpenAI model)

        Raises:
            ValueError: If ``model`` is not in SUPPORTED_MODELS
        """
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not supported by OpenAIProvider")
        super().__init__(model)
        self._client = None  # created lazily by _client_instance
        self._config = get_config().openai

    @property
    def is_available(self) -> bool:
        """Check if OpenAI API key is configured."""
        return self._config.is_available

    @property
    def _client_instance(self):
        """
        Lazy-load OpenAI client.

        Raises:
            RuntimeError: If no API key is configured or the ``openai``
                package is not installed
        """
        if self._client is None:
            if not self.is_available:
                raise RuntimeError(
                    "OpenAI API key not configured. Set OPENAI_API_KEY environment variable."
                )
            try:
                from openai import OpenAI
            except ImportError as err:
                raise RuntimeError(
                    "OpenAI library not installed. Run: pip install openai"
                ) from err

            self._client = OpenAI(api_key=self._config.api_key)
        return self._client

    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        """
        Generate embeddings using OpenAI API.

        Texts are sent in batches of the model's recommended batch size.
        Empty texts are never sent and yield None. A batch that fails is
        logged and yields None for each of its texts. The result always has
        the same length and order as ``texts``.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors. None for failed embeddings.
        """
        if not texts:
            return []

        all_embeddings: list[Optional[list[float]]] = []
        batch_size = self._model.spec.batch_size

        # Request options that do not change between batches.
        request_options = {"model": self._model.model_id}

        # text-embedding-3 models support custom dimensions
        if self._model in {
            EmbeddingModel.OPENAI_TEXT_3_SMALL,
            EmbeddingModel.OPENAI_TEXT_3_LARGE,
        }:
            request_options["dimensions"] = self._model.dimensions

        for start in range(0, len(texts), batch_size):
            cleaned_batch = self._clean_batch(texts[start : start + batch_size])
            batch_embeddings: list[Optional[list[float]]] = [None] * len(cleaned_batch)

            # The API rejects empty-string input, which would fail the whole
            # request. Skip empty texts and leave their slots as None.
            positions = [pos for pos, text in enumerate(cleaned_batch) if text]

            if positions:
                try:
                    response = self._client_instance.embeddings.create(
                        input=[cleaned_batch[pos] for pos in positions],
                        **request_options,
                    )
                    vectors = [item.embedding for item in response.data]
                    # A short or long response would pair vectors with the
                    # wrong texts, so treat it as a failed batch.
                    if len(vectors) != len(positions):
                        raise ValueError(
                            f"expected {len(positions)} embeddings, got {len(vectors)}"
                        )
                    for pos, vector in zip(positions, vectors):
                        batch_embeddings[pos] = vector
                except Exception as e:
                    # Best effort: one bad batch must not abort the whole run.
                    log.error(f"OpenAI embedding error: {e}")

            all_embeddings.extend(batch_embeddings)

        return all_embeddings
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Provider registry for embedding models."""
|
|
2
|
+
|
|
3
|
+
from typing import Type, Optional
|
|
4
|
+
|
|
5
|
+
from .models import EmbeddingModel
|
|
6
|
+
from .providers.base import EmbeddingProvider
|
|
7
|
+
from .providers.openai import OpenAIProvider
|
|
8
|
+
from .providers.fireworks import FireworksProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProviderRegistry:
    """
    Lookup table from provider name to provider class.

    A provider is registered once under its name; after that, any model
    whose ``provider`` matches that name can be turned into a provider
    instance without if-else chains.

    Usage:
        registry = ProviderRegistry()
        provider = registry.get_provider(EmbeddingModel.OPENAI_TEXT_3_SMALL)
        embeddings = provider.embed_texts(["hello world"])
    """

    # provider_name -> provider_class, shared by all users of the registry
    _providers: dict[str, Type[EmbeddingProvider]] = {}

    @classmethod
    def register(cls, provider_name: str, provider_class: Type[EmbeddingProvider]):
        """
        Store a provider class under a name, replacing any earlier entry.

        Args:
            provider_name: Name of the provider (e.g., "openai", "fireworks")
            provider_class: The provider class to register
        """
        cls._providers[provider_name] = provider_class

    @classmethod
    def get_provider_class(cls, provider_name: str) -> Optional[Type[EmbeddingProvider]]:
        """
        Look up the class registered under a provider name.

        Args:
            provider_name: Name of the provider

        Returns:
            Provider class or None if not registered
        """
        return cls._providers.get(provider_name)

    @classmethod
    def get_provider(cls, model: EmbeddingModel) -> EmbeddingProvider:
        """
        Build a provider instance for a model.

        Args:
            model: The embedding model

        Returns:
            Instantiated provider for the model

        Raises:
            ValueError: If no provider is registered for the model's provider
        """
        factory = cls._providers.get(model.provider)
        if factory is not None:
            return factory(model)

        raise ValueError(
            f"No provider registered for '{model.provider}'. "
            f"Available providers: {list(cls._providers.keys())}"
        )

    @classmethod
    def list_providers(cls) -> list[str]:
        """Return the names of all registered providers."""
        return list(cls._providers.keys())

    @classmethod
    def is_provider_available(cls, provider_name: str) -> bool:
        """
        Report whether a provider is both registered and usable.

        Args:
            provider_name: Name of the provider

        Returns:
            True if the provider is registered and an instance built for one
            of its models reports ``is_available``; False otherwise,
            including when building the instance raises
        """
        factory = cls._providers.get(provider_name)
        if factory is None:
            return False

        # Any model of this provider will do for the availability check.
        candidates = EmbeddingModel.list_by_provider(provider_name)
        if not candidates:
            return False

        try:
            return factory(candidates[0]).is_available
        except Exception:
            return False
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
106
|
+
# Register all providers
|
|
107
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
108
|
+
|
|
109
|
+
# Built-in providers, keyed by the name stored in each model's spec.
ProviderRegistry.register(provider_name="openai", provider_class=OpenAIProvider)
ProviderRegistry.register(provider_name="fireworks", provider_class=FireworksProvider)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
114
|
+
# Convenience functions
|
|
115
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def get_provider(model: EmbeddingModel) -> EmbeddingProvider:
    """Module-level shortcut for ``ProviderRegistry.get_provider``."""
    provider = ProviderRegistry.get_provider(model)
    return provider
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def get_default_provider() -> EmbeddingProvider:
    """Build a provider for the model returned by ``EmbeddingModel.get_default()``."""
    default_model = EmbeddingModel.get_default()
    return ProviderRegistry.get_provider(default_model)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def get_available_model() -> Optional[EmbeddingModel]:
    """
    Pick a model from the first provider that is available.

    Providers are tried in priority order: OpenAI first, then Fireworks.
    The first model listed for the chosen provider is returned.

    Returns:
        First available model or None if no providers are configured
    """
    for name in ("openai", "fireworks"):
        if not ProviderRegistry.is_provider_available(name):
            continue

        candidates = EmbeddingModel.list_by_provider(name)
        if candidates:
            return candidates[0]

    return None
|