emdash_core-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. emdash_core/__init__.py +3 -0
  2. emdash_core/agent/__init__.py +37 -0
  3. emdash_core/agent/agents.py +225 -0
  4. emdash_core/agent/code_reviewer.py +476 -0
  5. emdash_core/agent/compaction.py +143 -0
  6. emdash_core/agent/context_manager.py +140 -0
  7. emdash_core/agent/events.py +338 -0
  8. emdash_core/agent/handlers.py +224 -0
  9. emdash_core/agent/inprocess_subagent.py +377 -0
  10. emdash_core/agent/mcp/__init__.py +50 -0
  11. emdash_core/agent/mcp/client.py +346 -0
  12. emdash_core/agent/mcp/config.py +302 -0
  13. emdash_core/agent/mcp/manager.py +496 -0
  14. emdash_core/agent/mcp/tool_factory.py +213 -0
  15. emdash_core/agent/prompts/__init__.py +38 -0
  16. emdash_core/agent/prompts/main_agent.py +104 -0
  17. emdash_core/agent/prompts/subagents.py +131 -0
  18. emdash_core/agent/prompts/workflow.py +136 -0
  19. emdash_core/agent/providers/__init__.py +34 -0
  20. emdash_core/agent/providers/base.py +143 -0
  21. emdash_core/agent/providers/factory.py +80 -0
  22. emdash_core/agent/providers/models.py +220 -0
  23. emdash_core/agent/providers/openai_provider.py +463 -0
  24. emdash_core/agent/providers/transformers_provider.py +217 -0
  25. emdash_core/agent/research/__init__.py +81 -0
  26. emdash_core/agent/research/agent.py +143 -0
  27. emdash_core/agent/research/controller.py +254 -0
  28. emdash_core/agent/research/critic.py +428 -0
  29. emdash_core/agent/research/macros.py +469 -0
  30. emdash_core/agent/research/planner.py +449 -0
  31. emdash_core/agent/research/researcher.py +436 -0
  32. emdash_core/agent/research/state.py +523 -0
  33. emdash_core/agent/research/synthesizer.py +594 -0
  34. emdash_core/agent/reviewer_profile.py +475 -0
  35. emdash_core/agent/rules.py +123 -0
  36. emdash_core/agent/runner.py +601 -0
  37. emdash_core/agent/session.py +262 -0
  38. emdash_core/agent/spec_schema.py +66 -0
  39. emdash_core/agent/specification.py +479 -0
  40. emdash_core/agent/subagent.py +397 -0
  41. emdash_core/agent/subagent_prompts.py +13 -0
  42. emdash_core/agent/toolkit.py +482 -0
  43. emdash_core/agent/toolkits/__init__.py +64 -0
  44. emdash_core/agent/toolkits/base.py +96 -0
  45. emdash_core/agent/toolkits/explore.py +47 -0
  46. emdash_core/agent/toolkits/plan.py +55 -0
  47. emdash_core/agent/tools/__init__.py +141 -0
  48. emdash_core/agent/tools/analytics.py +436 -0
  49. emdash_core/agent/tools/base.py +131 -0
  50. emdash_core/agent/tools/coding.py +484 -0
  51. emdash_core/agent/tools/github_mcp.py +592 -0
  52. emdash_core/agent/tools/history.py +13 -0
  53. emdash_core/agent/tools/modes.py +153 -0
  54. emdash_core/agent/tools/plan.py +206 -0
  55. emdash_core/agent/tools/plan_write.py +135 -0
  56. emdash_core/agent/tools/search.py +412 -0
  57. emdash_core/agent/tools/spec.py +341 -0
  58. emdash_core/agent/tools/task.py +262 -0
  59. emdash_core/agent/tools/task_output.py +204 -0
  60. emdash_core/agent/tools/tasks.py +454 -0
  61. emdash_core/agent/tools/traversal.py +588 -0
  62. emdash_core/agent/tools/web.py +179 -0
  63. emdash_core/analytics/__init__.py +5 -0
  64. emdash_core/analytics/engine.py +1286 -0
  65. emdash_core/api/__init__.py +5 -0
  66. emdash_core/api/agent.py +308 -0
  67. emdash_core/api/agents.py +154 -0
  68. emdash_core/api/analyze.py +264 -0
  69. emdash_core/api/auth.py +173 -0
  70. emdash_core/api/context.py +77 -0
  71. emdash_core/api/db.py +121 -0
  72. emdash_core/api/embed.py +131 -0
  73. emdash_core/api/feature.py +143 -0
  74. emdash_core/api/health.py +93 -0
  75. emdash_core/api/index.py +162 -0
  76. emdash_core/api/plan.py +110 -0
  77. emdash_core/api/projectmd.py +210 -0
  78. emdash_core/api/query.py +320 -0
  79. emdash_core/api/research.py +122 -0
  80. emdash_core/api/review.py +161 -0
  81. emdash_core/api/router.py +76 -0
  82. emdash_core/api/rules.py +116 -0
  83. emdash_core/api/search.py +119 -0
  84. emdash_core/api/spec.py +99 -0
  85. emdash_core/api/swarm.py +223 -0
  86. emdash_core/api/tasks.py +109 -0
  87. emdash_core/api/team.py +120 -0
  88. emdash_core/auth/__init__.py +17 -0
  89. emdash_core/auth/github.py +389 -0
  90. emdash_core/config.py +74 -0
  91. emdash_core/context/__init__.py +52 -0
  92. emdash_core/context/models.py +50 -0
  93. emdash_core/context/providers/__init__.py +11 -0
  94. emdash_core/context/providers/base.py +74 -0
  95. emdash_core/context/providers/explored_areas.py +183 -0
  96. emdash_core/context/providers/touched_areas.py +360 -0
  97. emdash_core/context/registry.py +73 -0
  98. emdash_core/context/reranker.py +199 -0
  99. emdash_core/context/service.py +260 -0
  100. emdash_core/context/session.py +352 -0
  101. emdash_core/core/__init__.py +104 -0
  102. emdash_core/core/config.py +454 -0
  103. emdash_core/core/exceptions.py +55 -0
  104. emdash_core/core/models.py +265 -0
  105. emdash_core/core/review_config.py +57 -0
  106. emdash_core/db/__init__.py +67 -0
  107. emdash_core/db/auth.py +134 -0
  108. emdash_core/db/models.py +91 -0
  109. emdash_core/db/provider.py +222 -0
  110. emdash_core/db/providers/__init__.py +5 -0
  111. emdash_core/db/providers/supabase.py +452 -0
  112. emdash_core/embeddings/__init__.py +24 -0
  113. emdash_core/embeddings/indexer.py +534 -0
  114. emdash_core/embeddings/models.py +192 -0
  115. emdash_core/embeddings/providers/__init__.py +7 -0
  116. emdash_core/embeddings/providers/base.py +112 -0
  117. emdash_core/embeddings/providers/fireworks.py +141 -0
  118. emdash_core/embeddings/providers/openai.py +104 -0
  119. emdash_core/embeddings/registry.py +146 -0
  120. emdash_core/embeddings/service.py +215 -0
  121. emdash_core/graph/__init__.py +26 -0
  122. emdash_core/graph/builder.py +134 -0
  123. emdash_core/graph/connection.py +692 -0
  124. emdash_core/graph/schema.py +416 -0
  125. emdash_core/graph/writer.py +667 -0
  126. emdash_core/ingestion/__init__.py +7 -0
  127. emdash_core/ingestion/change_detector.py +150 -0
  128. emdash_core/ingestion/git/__init__.py +5 -0
  129. emdash_core/ingestion/git/commit_analyzer.py +196 -0
  130. emdash_core/ingestion/github/__init__.py +6 -0
  131. emdash_core/ingestion/github/pr_fetcher.py +296 -0
  132. emdash_core/ingestion/github/task_extractor.py +100 -0
  133. emdash_core/ingestion/orchestrator.py +540 -0
  134. emdash_core/ingestion/parsers/__init__.py +10 -0
  135. emdash_core/ingestion/parsers/base_parser.py +66 -0
  136. emdash_core/ingestion/parsers/call_graph_builder.py +121 -0
  137. emdash_core/ingestion/parsers/class_extractor.py +154 -0
  138. emdash_core/ingestion/parsers/function_extractor.py +202 -0
  139. emdash_core/ingestion/parsers/import_analyzer.py +119 -0
  140. emdash_core/ingestion/parsers/python_parser.py +123 -0
  141. emdash_core/ingestion/parsers/registry.py +72 -0
  142. emdash_core/ingestion/parsers/ts_ast_parser.js +313 -0
  143. emdash_core/ingestion/parsers/typescript_parser.py +278 -0
  144. emdash_core/ingestion/repository.py +346 -0
  145. emdash_core/models/__init__.py +38 -0
  146. emdash_core/models/agent.py +68 -0
  147. emdash_core/models/index.py +77 -0
  148. emdash_core/models/query.py +113 -0
  149. emdash_core/planning/__init__.py +7 -0
  150. emdash_core/planning/agent_api.py +413 -0
  151. emdash_core/planning/context_builder.py +265 -0
  152. emdash_core/planning/feature_context.py +232 -0
  153. emdash_core/planning/feature_expander.py +646 -0
  154. emdash_core/planning/llm_explainer.py +198 -0
  155. emdash_core/planning/similarity.py +509 -0
  156. emdash_core/planning/team_focus.py +821 -0
  157. emdash_core/server.py +153 -0
  158. emdash_core/sse/__init__.py +5 -0
  159. emdash_core/sse/stream.py +196 -0
  160. emdash_core/swarm/__init__.py +17 -0
  161. emdash_core/swarm/merge_agent.py +383 -0
  162. emdash_core/swarm/session_manager.py +274 -0
  163. emdash_core/swarm/swarm_runner.py +226 -0
  164. emdash_core/swarm/task_definition.py +137 -0
  165. emdash_core/swarm/worker_spawner.py +319 -0
  166. emdash_core/swarm/worktree_manager.py +278 -0
  167. emdash_core/templates/__init__.py +10 -0
  168. emdash_core/templates/defaults/agent-builder.md.template +82 -0
  169. emdash_core/templates/defaults/focus.md.template +115 -0
  170. emdash_core/templates/defaults/pr-review-enhanced.md.template +309 -0
  171. emdash_core/templates/defaults/pr-review.md.template +80 -0
  172. emdash_core/templates/defaults/project.md.template +85 -0
  173. emdash_core/templates/defaults/research_critic.md.template +112 -0
  174. emdash_core/templates/defaults/research_planner.md.template +85 -0
  175. emdash_core/templates/defaults/research_synthesizer.md.template +128 -0
  176. emdash_core/templates/defaults/reviewer.md.template +81 -0
  177. emdash_core/templates/defaults/spec.md.template +41 -0
  178. emdash_core/templates/defaults/tasks.md.template +78 -0
  179. emdash_core/templates/loader.py +296 -0
  180. emdash_core/utils/__init__.py +45 -0
  181. emdash_core/utils/git.py +84 -0
  182. emdash_core/utils/image.py +502 -0
  183. emdash_core/utils/logger.py +51 -0
  184. emdash_core-0.1.7.dist-info/METADATA +35 -0
  185. emdash_core-0.1.7.dist-info/RECORD +187 -0
  186. emdash_core-0.1.7.dist-info/WHEEL +4 -0
  187. emdash_core-0.1.7.dist-info/entry_points.txt +3 -0
emdash_core/embeddings/models.py
@@ -0,0 +1,192 @@
"""Embedding models enum - single source of truth for all supported models."""

from enum import Enum
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ModelSpec:
    """Specification for an embedding model."""

    provider: str  # "openai", "fireworks"
    model_id: str  # The actual model identifier for the API
    dimensions: int  # Output embedding dimensions
    max_tokens: int  # Max input tokens
    batch_size: int  # Recommended batch size
    description: str  # Human-readable description


class EmbeddingModel(Enum):
    """
    All supported embedding models.

    Format: PROVIDER_MODEL_NAME

    Usage:
        model = EmbeddingModel.OPENAI_TEXT_3_SMALL
        print(model.spec.dimensions)  # 1536
        print(model.spec.provider)  # "openai"
    """

    # ═══════════════════════════════════════════════════════════════════════════
    # OpenAI Models
    # ═══════════════════════════════════════════════════════════════════════════

    OPENAI_TEXT_3_SMALL = ModelSpec(
        provider="openai",
        model_id="text-embedding-3-small",
        dimensions=1536,
        max_tokens=8191,
        batch_size=100,
        description="OpenAI's small, fast embedding model (best value)",
    )

    OPENAI_TEXT_3_LARGE = ModelSpec(
        provider="openai",
        model_id="text-embedding-3-large",
        dimensions=3072,
        max_tokens=8191,
        batch_size=50,
        description="OpenAI's large, high-quality embedding model",
    )

    OPENAI_ADA_002 = ModelSpec(
        provider="openai",
        model_id="text-embedding-ada-002",
        dimensions=1536,
        max_tokens=8191,
        batch_size=100,
        description="OpenAI's legacy Ada model (deprecated, use text-3-small)",
    )

    # ═══════════════════════════════════════════════════════════════════════════
    # Fireworks AI Models
    # ═══════════════════════════════════════════════════════════════════════════

    FIREWORKS_NOMIC_EMBED_V1_5 = ModelSpec(
        provider="fireworks",
        model_id="nomic-ai/nomic-embed-text-v1.5",
        dimensions=768,
        max_tokens=8192,
        batch_size=100,
        description="Nomic's open-source embedding model (fast, good quality)",
    )

    FIREWORKS_E5_MISTRAL_7B = ModelSpec(
        provider="fireworks",
        model_id="intfloat/e5-mistral-7b-instruct",
        dimensions=4096,
        max_tokens=4096,
        batch_size=20,
        description="E5-Mistral 7B (highest quality, slower)",
    )

    FIREWORKS_UAE_LARGE_V1 = ModelSpec(
        provider="fireworks",
        model_id="WhereIsAI/UAE-Large-V1",
        dimensions=1024,
        max_tokens=512,
        batch_size=50,
        description="UAE-Large-V1 (good balance of speed/quality)",
    )

    FIREWORKS_GTE_LARGE = ModelSpec(
        provider="fireworks",
        model_id="thenlper/gte-large",
        dimensions=1024,
        max_tokens=512,
        batch_size=50,
        description="GTE-Large (Alibaba's efficient embedding model)",
    )

    FIREWORKS_BGE_LARGE_EN = ModelSpec(
        provider="fireworks",
        model_id="BAAI/bge-large-en-v1.5",
        dimensions=1024,
        max_tokens=512,
        batch_size=50,
        description="BGE-Large-EN (BAAI's high-quality English model)",
    )

    # ═══════════════════════════════════════════════════════════════════════════

    @property
    def spec(self) -> ModelSpec:
        """Get the model specification."""
        return self.value

    @property
    def provider(self) -> str:
        """Shortcut to get provider name."""
        return self.value.provider

    @property
    def model_id(self) -> str:
        """Shortcut to get the API model ID."""
        return self.value.model_id

    @property
    def dimensions(self) -> int:
        """Shortcut to get embedding dimensions."""
        return self.value.dimensions

    @classmethod
    def get_default(cls) -> "EmbeddingModel":
        """Get the default embedding model."""
        return cls.OPENAI_TEXT_3_SMALL

    @classmethod
    def from_string(cls, value: str) -> Optional["EmbeddingModel"]:
        """
        Parse model from string.

        Accepts:
        - Enum name: "OPENAI_TEXT_3_SMALL"
        - Provider:model: "openai:text-embedding-3-small"
        - Just model_id: "text-embedding-3-small"
        """
        value = value.strip()

        # Try enum name first
        try:
            return cls[value.upper().replace("-", "_").replace(":", "_")]
        except KeyError:
            pass

        # Try provider:model format
        if ":" in value:
            provider, model_id = value.split(":", 1)
            for model in cls:
                if model.provider == provider and model.model_id == model_id:
                    return model

        # Try just model_id
        for model in cls:
            if model.model_id == value:
                return model

        return None

    @classmethod
    def list_by_provider(cls, provider: str) -> list["EmbeddingModel"]:
        """List all models for a specific provider."""
        return [m for m in cls if m.provider == provider]

    @classmethod
    def list_all(cls) -> list[dict]:
        """List all models with their specs for display."""
        return [
            {
                "name": m.name,
                "provider": m.provider,
                "model_id": m.model_id,
                "dimensions": m.dimensions,
                "description": m.spec.description,
            }
            for m in cls
        ]

    def __str__(self) -> str:
        """String representation as provider:model_id."""
        return f"{self.provider}:{self.model_id}"
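A minimal usage sketch for the enum above, mirroring the Usage block in its docstring (it assumes the module imports as emdash_core.embeddings.models per the wheel layout; the printed values come from the specs defined in this file):

from emdash_core.embeddings.models import EmbeddingModel

# from_string() accepts the enum name, "provider:model_id", or a bare model_id
model = EmbeddingModel.from_string("openai:text-embedding-3-small")
assert model is EmbeddingModel.OPENAI_TEXT_3_SMALL

# Shortcut properties defined on the enum
print(model.provider)    # "openai"
print(model.dimensions)  # 1536
print(str(model))        # "openai:text-embedding-3-small"

# Enumerate every Fireworks-hosted model with its description
for m in EmbeddingModel.list_by_provider("fireworks"):
    print(m.name, m.spec.description)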
emdash_core/embeddings/providers/__init__.py
@@ -0,0 +1,7 @@
"""Embedding providers package."""

from .base import EmbeddingProvider
from .openai import OpenAIProvider
from .fireworks import FireworksProvider

__all__ = ["EmbeddingProvider", "OpenAIProvider", "FireworksProvider"]
emdash_core/embeddings/providers/base.py
@@ -0,0 +1,112 @@
"""Base class for embedding providers."""

from abc import ABC, abstractmethod
from typing import Optional

from ..models import EmbeddingModel


class EmbeddingProvider(ABC):
    """
    Abstract base class for embedding providers.

    Each provider (OpenAI, Fireworks, etc.) implements this interface.
    The registry uses this to provide a unified embedding API.
    """

    def __init__(self, model: EmbeddingModel):
        """
        Initialize provider with a specific model.

        Args:
            model: The embedding model to use
        """
        self._model = model

    @property
    def model(self) -> EmbeddingModel:
        """Get the embedding model."""
        return self._model

    @property
    def dimensions(self) -> int:
        """Get embedding dimensions for the current model."""
        return self._model.dimensions

    @property
    @abstractmethod
    def is_available(self) -> bool:
        """Check if the provider is available (API key configured, etc.)."""
        pass

    @abstractmethod
    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors. None for failed embeddings.
        """
        pass

    def embed_text(self, text: str) -> Optional[list[float]]:
        """
        Generate embedding for a single text.

        Args:
            text: Text string to embed

        Returns:
            Embedding vector or None if failed
        """
        if not text:
            return None
        embeddings = self.embed_texts([text])
        return embeddings[0] if embeddings else None

    def embed_query(self, query: str) -> Optional[list[float]]:
        """
        Generate embedding for a search query.

        Some models treat queries differently from documents.
        Override this method if the model requires special query handling.

        Args:
            query: Search query string

        Returns:
            Embedding vector or None if failed
        """
        return self.embed_text(query)

    def _truncate_text(self, text: str, max_chars: int = 8000) -> str:
        """
        Truncate text to avoid token limits.

        Args:
            text: Text to truncate
            max_chars: Maximum character length (roughly 4 chars per token)

        Returns:
            Truncated text
        """
        if text and len(text) > max_chars:
            return text[:max_chars]
        return text or ""

    def _clean_batch(self, texts: list[str]) -> list[str]:
        """
        Clean and truncate a batch of texts.

        Args:
            texts: List of texts to clean

        Returns:
            Cleaned texts
        """
        # Calculate max chars based on model's max tokens (roughly 4 chars per token)
        max_chars = min(self._model.spec.max_tokens * 4, 32000)
        return [self._truncate_text(t, max_chars) for t in texts]
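Only is_available and embed_texts are abstract; the single-text, query, and truncation helpers are inherited. A sketch of the contract using a hypothetical FakeProvider (not part of the package, shown only to illustrate what a subclass must supply):

from typing import Optional

from emdash_core.embeddings.models import EmbeddingModel
from emdash_core.embeddings.providers.base import EmbeddingProvider


class FakeProvider(EmbeddingProvider):
    """Hypothetical provider returning zero vectors, for illustration only."""

    @property
    def is_available(self) -> bool:
        return True  # a fake backend needs no credentials

    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        cleaned = self._clean_batch(texts)  # inherited truncation helper
        return [[0.0] * self.dimensions for _ in cleaned]


provider = FakeProvider(EmbeddingModel.OPENAI_TEXT_3_SMALL)
vector = provider.embed_text("hello world")  # base-class wrapper around embed_texts()
print(len(vector))  # 1536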
emdash_core/embeddings/providers/fireworks.py
@@ -0,0 +1,141 @@
"""Fireworks AI embedding provider."""

from typing import Optional

from ..models import EmbeddingModel
from .base import EmbeddingProvider
from ...core.config import get_config
from ...utils.logger import log


class FireworksProvider(EmbeddingProvider):
    """
    Fireworks AI embedding provider.

    Uses the Fireworks API (OpenAI-compatible) to generate embeddings.
    Requires FIREWORKS_API_KEY environment variable.

    API docs: https://docs.fireworks.ai/guides/querying-embeddings-models
    """

    # Fireworks API base URL
    BASE_URL = "https://api.fireworks.ai/inference/v1"

    # Models this provider handles
    SUPPORTED_MODELS = {
        EmbeddingModel.FIREWORKS_NOMIC_EMBED_V1_5,
        EmbeddingModel.FIREWORKS_E5_MISTRAL_7B,
        EmbeddingModel.FIREWORKS_UAE_LARGE_V1,
        EmbeddingModel.FIREWORKS_GTE_LARGE,
        EmbeddingModel.FIREWORKS_BGE_LARGE_EN,
    }

    def __init__(self, model: EmbeddingModel):
        """
        Initialize Fireworks provider.

        Args:
            model: The embedding model to use (must be a Fireworks model)
        """
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not supported by FireworksProvider")
        super().__init__(model)
        self._client = None

    @property
    def _api_key(self) -> Optional[str]:
        """Get Fireworks API key from config."""
        config = get_config()
        # Check if fireworks config exists
        if hasattr(config, "fireworks") and config.fireworks.api_key:
            return config.fireworks.api_key
        return None

    @property
    def is_available(self) -> bool:
        """Check if Fireworks API key is configured."""
        return self._api_key is not None and len(self._api_key) > 0

    @property
    def _client_instance(self):
        """Lazy-load OpenAI client configured for Fireworks."""
        if self._client is None:
            if not self.is_available:
                raise RuntimeError(
                    "Fireworks API key not configured. Set FIREWORKS_API_KEY environment variable."
                )
            try:
                from openai import OpenAI

                # Fireworks uses OpenAI-compatible API
                self._client = OpenAI(
                    api_key=self._api_key,
                    base_url=self.BASE_URL,
                )
            except ImportError:
                raise RuntimeError(
                    "OpenAI library not installed. Run: pip install openai"
                )
        return self._client

    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        """
        Generate embeddings using Fireworks API.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors. None for failed embeddings.
        """
        if not texts:
            return []

        all_embeddings = []
        batch_size = self._model.spec.batch_size

        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            cleaned_batch = self._clean_batch(batch)

            try:
                # Fireworks requires accounts/ prefix for model IDs
                model_id = f"accounts/fireworks/models/{self._model.model_id}"

                response = self._client_instance.embeddings.create(
                    model=model_id,
                    input=cleaned_batch,
                )
                batch_embeddings = [item.embedding for item in response.data]
                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                log.error(f"Fireworks embedding error: {e}")
                all_embeddings.extend([None] * len(cleaned_batch))

        return all_embeddings

    def embed_query(self, query: str) -> Optional[list[float]]:
        """
        Generate embedding for a search query.

        Some Fireworks models (like Nomic, E5) benefit from query prefixes.

        Args:
            query: Search query string

        Returns:
            Embedding vector or None if failed
        """
        if not query:
            return None

        # E5 models expect "query: " prefix for queries
        if self._model == EmbeddingModel.FIREWORKS_E5_MISTRAL_7B:
            query = f"query: {query}"

        # Nomic models can optionally use "search_query: " prefix
        elif self._model == EmbeddingModel.FIREWORKS_NOMIC_EMBED_V1_5:
            query = f"search_query: {query}"

        return self.embed_text(query)
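A hedged usage sketch for this provider; it assumes FIREWORKS_API_KEY is configured as the docstring requires and that the openai client library is installed:

from emdash_core.embeddings.models import EmbeddingModel
from emdash_core.embeddings.providers.fireworks import FireworksProvider

provider = FireworksProvider(EmbeddingModel.FIREWORKS_NOMIC_EMBED_V1_5)

if provider.is_available:  # False unless FIREWORKS_API_KEY is set
    docs = provider.embed_texts(["def add(a, b): return a + b"])
    query = provider.embed_query("function that adds two numbers")  # gets the "search_query: " prefix
    if docs and docs[0] is not None and query is not None:
        print(len(docs[0]), len(query))  # 768 768 for nomic-embed-text-v1.5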
emdash_core/embeddings/providers/openai.py
@@ -0,0 +1,104 @@
"""OpenAI embedding provider."""

from typing import Optional

from ..models import EmbeddingModel
from .base import EmbeddingProvider
from ...core.config import get_config
from ...utils.logger import log


class OpenAIProvider(EmbeddingProvider):
    """
    OpenAI embedding provider.

    Uses the OpenAI API to generate embeddings.
    Requires OPENAI_API_KEY environment variable.
    """

    # Models this provider handles
    SUPPORTED_MODELS = {
        EmbeddingModel.OPENAI_TEXT_3_SMALL,
        EmbeddingModel.OPENAI_TEXT_3_LARGE,
        EmbeddingModel.OPENAI_ADA_002,
    }

    def __init__(self, model: EmbeddingModel):
        """
        Initialize OpenAI provider.

        Args:
            model: The embedding model to use (must be an OpenAI model)
        """
        if model not in self.SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not supported by OpenAIProvider")
        super().__init__(model)
        self._client = None
        self._config = get_config().openai

    @property
    def is_available(self) -> bool:
        """Check if OpenAI API key is configured."""
        return self._config.is_available

    @property
    def _client_instance(self):
        """Lazy-load OpenAI client."""
        if self._client is None:
            if not self.is_available:
                raise RuntimeError(
                    "OpenAI API key not configured. Set OPENAI_API_KEY environment variable."
                )
            try:
                from openai import OpenAI

                self._client = OpenAI(api_key=self._config.api_key)
            except ImportError:
                raise RuntimeError(
                    "OpenAI library not installed. Run: pip install openai"
                )
        return self._client

    def embed_texts(self, texts: list[str]) -> list[Optional[list[float]]]:
        """
        Generate embeddings using OpenAI API.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors. None for failed embeddings.
        """
        if not texts:
            return []

        all_embeddings = []
        batch_size = self._model.spec.batch_size

        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            cleaned_batch = self._clean_batch(batch)

            try:
                # Use dimensions parameter for text-embedding-3 models
                kwargs = {
                    "model": self._model.model_id,
                    "input": cleaned_batch,
                }

                # text-embedding-3 models support custom dimensions
                if self._model in {
                    EmbeddingModel.OPENAI_TEXT_3_SMALL,
                    EmbeddingModel.OPENAI_TEXT_3_LARGE,
                }:
                    kwargs["dimensions"] = self._model.dimensions

                response = self._client_instance.embeddings.create(**kwargs)
                batch_embeddings = [item.embedding for item in response.data]
                all_embeddings.extend(batch_embeddings)

            except Exception as e:
                log.error(f"OpenAI embedding error: {e}")
                all_embeddings.extend([None] * len(cleaned_batch))

        return all_embeddings
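The OpenAI provider follows the same pattern; a small sketch, assuming OPENAI_API_KEY is configured and the openai package is installed:

from emdash_core.embeddings.models import EmbeddingModel
from emdash_core.embeddings.providers.openai import OpenAIProvider

provider = OpenAIProvider(EmbeddingModel.OPENAI_TEXT_3_SMALL)

if provider.is_available:  # False unless OPENAI_API_KEY is set
    vectors = provider.embed_texts(["first snippet", "second snippet"])
    # Failed batches come back as None entries instead of raising
    ok = [v for v in vectors if v is not None]
    print(f"embedded {len(ok)}/{len(vectors)} texts at {provider.dimensions} dimensions")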
emdash_core/embeddings/registry.py
@@ -0,0 +1,146 @@
"""Provider registry for embedding models."""

from typing import Type, Optional

from .models import EmbeddingModel
from .providers.base import EmbeddingProvider
from .providers.openai import OpenAIProvider
from .providers.fireworks import FireworksProvider


class ProviderRegistry:
    """
    Registry for embedding providers.

    Maps provider names to provider classes. No if-else chains needed.
    Just register your provider once and it's available everywhere.

    Usage:
        registry = ProviderRegistry()
        provider = registry.get_provider(EmbeddingModel.OPENAI_TEXT_3_SMALL)
        embeddings = provider.embed_texts(["hello world"])
    """

    # Provider class registry: provider_name -> provider_class
    _providers: dict[str, Type[EmbeddingProvider]] = {}

    @classmethod
    def register(cls, provider_name: str, provider_class: Type[EmbeddingProvider]):
        """
        Register a provider class.

        Args:
            provider_name: Name of the provider (e.g., "openai", "fireworks")
            provider_class: The provider class to register
        """
        cls._providers[provider_name] = provider_class

    @classmethod
    def get_provider_class(cls, provider_name: str) -> Optional[Type[EmbeddingProvider]]:
        """
        Get the provider class for a provider name.

        Args:
            provider_name: Name of the provider

        Returns:
            Provider class or None if not registered
        """
        return cls._providers.get(provider_name)

    @classmethod
    def get_provider(cls, model: EmbeddingModel) -> EmbeddingProvider:
        """
        Get an instantiated provider for a model.

        Args:
            model: The embedding model

        Returns:
            Instantiated provider for the model

        Raises:
            ValueError: If no provider is registered for the model's provider
        """
        provider_class = cls._providers.get(model.provider)
        if provider_class is None:
            raise ValueError(
                f"No provider registered for '{model.provider}'. "
                f"Available providers: {list(cls._providers.keys())}"
            )
        return provider_class(model)

    @classmethod
    def list_providers(cls) -> list[str]:
        """List all registered provider names."""
        return list(cls._providers.keys())

    @classmethod
    def is_provider_available(cls, provider_name: str) -> bool:
        """
        Check if a provider is available (registered and configured).

        Args:
            provider_name: Name of the provider

        Returns:
            True if provider is registered and has valid credentials
        """
        provider_class = cls._providers.get(provider_name)
        if provider_class is None:
            return False

        # Get any model for this provider to check availability
        models = EmbeddingModel.list_by_provider(provider_name)
        if not models:
            return False

        try:
            provider = provider_class(models[0])
            return provider.is_available
        except Exception:
            return False


# ═══════════════════════════════════════════════════════════════════════════════
# Register all providers
# ═══════════════════════════════════════════════════════════════════════════════

ProviderRegistry.register("openai", OpenAIProvider)
ProviderRegistry.register("fireworks", FireworksProvider)


# ═══════════════════════════════════════════════════════════════════════════════
# Convenience functions
# ═══════════════════════════════════════════════════════════════════════════════


def get_provider(model: EmbeddingModel) -> EmbeddingProvider:
    """Get an instantiated provider for a model."""
    return ProviderRegistry.get_provider(model)


def get_default_provider() -> EmbeddingProvider:
    """Get the default embedding provider (OpenAI text-embedding-3-small)."""
    return ProviderRegistry.get_provider(EmbeddingModel.get_default())


def get_available_model() -> Optional[EmbeddingModel]:
    """
    Get the first available model (has valid API credentials).

    Checks OpenAI first, then Fireworks.

    Returns:
        First available model or None if no providers are configured
    """
    # Priority order
    priority = ["openai", "fireworks"]

    for provider_name in priority:
        if ProviderRegistry.is_provider_available(provider_name):
            models = EmbeddingModel.list_by_provider(provider_name)
            if models:
                return models[0]

    return None
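A sketch of the registry-driven flow end to end; it assumes the module imports as emdash_core.embeddings.registry per the wheel layout and that at least one provider has credentials configured:

from emdash_core.embeddings import registry

# Pick whichever provider has credentials, checking OpenAI before Fireworks
model = registry.get_available_model()
if model is None:
    raise SystemExit("no embedding provider configured")

provider = registry.get_provider(model)
embedding = provider.embed_text("semantic search over a code graph")
print(model, "->", provider.dimensions, "dimensions")

Adding another backend only takes one ProviderRegistry.register(name, ProviderClass) call at import time; the lookup inside get_provider() needs no changes.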