loom-agent 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of loom-agent might be problematic. Click here for more details.

Files changed (89) hide show
  1. loom/__init__.py +77 -0
  2. loom/agent.py +217 -0
  3. loom/agents/__init__.py +10 -0
  4. loom/agents/refs.py +28 -0
  5. loom/agents/registry.py +50 -0
  6. loom/builtin/compression/__init__.py +4 -0
  7. loom/builtin/compression/structured.py +79 -0
  8. loom/builtin/embeddings/__init__.py +9 -0
  9. loom/builtin/embeddings/openai_embedding.py +135 -0
  10. loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
  11. loom/builtin/llms/__init__.py +8 -0
  12. loom/builtin/llms/mock.py +34 -0
  13. loom/builtin/llms/openai.py +168 -0
  14. loom/builtin/llms/rule.py +102 -0
  15. loom/builtin/memory/__init__.py +5 -0
  16. loom/builtin/memory/in_memory.py +21 -0
  17. loom/builtin/memory/persistent_memory.py +278 -0
  18. loom/builtin/retriever/__init__.py +9 -0
  19. loom/builtin/retriever/chroma_store.py +265 -0
  20. loom/builtin/retriever/in_memory.py +106 -0
  21. loom/builtin/retriever/milvus_store.py +307 -0
  22. loom/builtin/retriever/pinecone_store.py +237 -0
  23. loom/builtin/retriever/qdrant_store.py +274 -0
  24. loom/builtin/retriever/vector_store.py +128 -0
  25. loom/builtin/retriever/vector_store_config.py +217 -0
  26. loom/builtin/tools/__init__.py +32 -0
  27. loom/builtin/tools/calculator.py +49 -0
  28. loom/builtin/tools/document_search.py +111 -0
  29. loom/builtin/tools/glob.py +27 -0
  30. loom/builtin/tools/grep.py +56 -0
  31. loom/builtin/tools/http_request.py +86 -0
  32. loom/builtin/tools/python_repl.py +73 -0
  33. loom/builtin/tools/read_file.py +32 -0
  34. loom/builtin/tools/task.py +158 -0
  35. loom/builtin/tools/web_search.py +64 -0
  36. loom/builtin/tools/write_file.py +31 -0
  37. loom/callbacks/base.py +9 -0
  38. loom/callbacks/logging.py +12 -0
  39. loom/callbacks/metrics.py +27 -0
  40. loom/callbacks/observability.py +248 -0
  41. loom/components/agent.py +107 -0
  42. loom/core/agent_executor.py +450 -0
  43. loom/core/circuit_breaker.py +178 -0
  44. loom/core/compression_manager.py +329 -0
  45. loom/core/context_retriever.py +185 -0
  46. loom/core/error_classifier.py +193 -0
  47. loom/core/errors.py +66 -0
  48. loom/core/message_queue.py +167 -0
  49. loom/core/permission_store.py +62 -0
  50. loom/core/permissions.py +69 -0
  51. loom/core/scheduler.py +125 -0
  52. loom/core/steering_control.py +47 -0
  53. loom/core/structured_logger.py +279 -0
  54. loom/core/subagent_pool.py +232 -0
  55. loom/core/system_prompt.py +141 -0
  56. loom/core/system_reminders.py +283 -0
  57. loom/core/tool_pipeline.py +113 -0
  58. loom/core/types.py +269 -0
  59. loom/interfaces/compressor.py +59 -0
  60. loom/interfaces/embedding.py +51 -0
  61. loom/interfaces/llm.py +33 -0
  62. loom/interfaces/memory.py +29 -0
  63. loom/interfaces/retriever.py +179 -0
  64. loom/interfaces/tool.py +27 -0
  65. loom/interfaces/vector_store.py +80 -0
  66. loom/llm/__init__.py +14 -0
  67. loom/llm/config.py +228 -0
  68. loom/llm/factory.py +111 -0
  69. loom/llm/model_health.py +235 -0
  70. loom/llm/model_pool_advanced.py +305 -0
  71. loom/llm/pool.py +170 -0
  72. loom/llm/registry.py +201 -0
  73. loom/mcp/__init__.py +4 -0
  74. loom/mcp/client.py +86 -0
  75. loom/mcp/registry.py +58 -0
  76. loom/mcp/tool_adapter.py +48 -0
  77. loom/observability/__init__.py +5 -0
  78. loom/patterns/__init__.py +5 -0
  79. loom/patterns/multi_agent.py +123 -0
  80. loom/patterns/rag.py +262 -0
  81. loom/plugins/registry.py +55 -0
  82. loom/resilience/__init__.py +5 -0
  83. loom/tooling.py +72 -0
  84. loom/utils/agent_loader.py +218 -0
  85. loom/utils/token_counter.py +19 -0
  86. loom_agent-0.0.1.dist-info/METADATA +457 -0
  87. loom_agent-0.0.1.dist-info/RECORD +89 -0
  88. loom_agent-0.0.1.dist-info/WHEEL +4 -0
  89. loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
loom/__init__.py ADDED
@@ -0,0 +1,77 @@
1
# Public package surface: re-export the most commonly used names so users
# can write ``import loom`` instead of reaching into deep module paths.
from .components.agent import Agent
from .core.subagent_pool import SubAgentPool
from .llm import (
    LLMConfig,
    LLMProvider,
    LLMCapabilities,
    LLMFactory,
    ModelPool,
    ModelRegistry,
)
from .agent import agent, agent_from_env
from .tooling import tool
from .agents import AgentSpec, register_agent, list_agent_types, get_agent_by_type
from .agents.refs import AgentRef, ModelRef, agent_ref, model_ref

# P2 Features - Production Ready
from .builtin.memory import InMemoryMemory, PersistentMemory
from .core.error_classifier import ErrorClassifier, RetryPolicy
from .core.circuit_breaker import CircuitBreaker, CircuitBreakerConfig, CircuitState

# P3 Features - Optimization
from .core.structured_logger import StructuredLogger, get_logger, set_correlation_id
from .core.system_reminders import SystemReminderManager, get_reminder_manager
from .callbacks.observability import ObservabilityCallback, MetricsAggregator
from .llm.model_health import ModelHealthChecker, HealthStatus
from .llm.model_pool_advanced import ModelPoolLLM, ModelConfig, FallbackChain

# Resolve the installed distribution's version from package metadata;
# fall back to "0" when unavailable (e.g. running from a source checkout).
try:
    from importlib.metadata import version as _pkg_version

    __version__ = _pkg_version("loom-agent")
except Exception:  # pragma: no cover - best-effort
    __version__ = "0"

__all__ = [
    "Agent",
    "SubAgentPool",
    "LLMConfig",
    "LLMProvider",
    "LLMCapabilities",
    "LLMFactory",
    "ModelPool",
    "ModelRegistry",
    "agent",
    "tool",
    "agent_from_env",
    "AgentSpec",
    "register_agent",
    "list_agent_types",
    "get_agent_by_type",
    "AgentRef",
    "ModelRef",
    "agent_ref",
    "model_ref",
    # P2 exports
    "InMemoryMemory",
    "PersistentMemory",
    "ErrorClassifier",
    "RetryPolicy",
    "CircuitBreaker",
    "CircuitBreakerConfig",
    "CircuitState",
    # P3 exports
    "StructuredLogger",
    "get_logger",
    "set_correlation_id",
    "SystemReminderManager",
    "get_reminder_manager",
    "ObservabilityCallback",
    "MetricsAggregator",
    "ModelHealthChecker",
    "HealthStatus",
    "ModelPoolLLM",
    "ModelConfig",
    "FallbackChain",
    "__version__",
]
loom/agent.py ADDED
@@ -0,0 +1,217 @@
1
+ """Convenience builder for Agent
2
+
3
+ Allows simple usage patterns like:
4
+
5
+ import loom
6
+ agent = loom.agent(provider="openai", model="gpt-4o", api_key="...")
7
+ text = await agent.ainvoke("Hello")
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import os
13
+ from typing import Dict, List, Optional, Union
14
+
15
+ from .components.agent import Agent as _Agent
16
+ from .interfaces.llm import BaseLLM
17
+ from .interfaces.memory import BaseMemory
18
+ from .interfaces.tool import BaseTool
19
+ from .interfaces.compressor import BaseCompressor
20
+ from .llm.config import LLMConfig, LLMProvider
21
+ from .llm.factory import LLMFactory
22
+ from .callbacks.base import BaseCallback
23
+ from .callbacks.metrics import MetricsCollector
24
+ from .core.steering_control import SteeringControl
25
+
26
+
27
def agent(
    *,
    # Provide one of: llm | config | (provider+model)
    llm: Optional[BaseLLM] = None,
    config: Optional[LLMConfig] = None,
    provider: Optional[Union[str, LLMProvider]] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    # Agent options
    tools: Optional[List[BaseTool]] = None,
    memory: Optional[BaseMemory] = None,
    compressor: Optional[BaseCompressor] = None,
    max_iterations: int = 50,
    max_context_tokens: int = 16000,
    permission_policy: Optional[Dict[str, str]] = None,
    ask_handler=None,
    safe_mode: bool = False,
    permission_store=None,
    # Extra LLM config overrides
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    # Advanced
    context_retriever=None,
    system_instructions: Optional[str] = None,
    callbacks: Optional[list[BaseCallback]] = None,
    steering_control: Optional[SteeringControl] = None,
    metrics: Optional[MetricsCollector] = None,
) -> _Agent:
    """Create an Agent with minimal parameters.

    The underlying LLM is resolved in priority order:
    1) an explicit ``llm`` instance,
    2) an explicit ``config`` built via :class:`LLMFactory`,
    3) ``provider`` + ``model`` (plus optional api_key/base_url/overrides).

    Raises:
        ValueError: when none of the three LLM sources is supplied.
    """
    resolved_llm = llm
    if resolved_llm is None:
        if config is not None:
            # An explicit config wins over loose provider/model arguments.
            resolved_llm = LLMFactory.create(config)
        elif provider is not None and model is not None:
            built_cfg = _build_config_from_inputs(
                provider=provider,
                model=model,
                api_key=api_key,
                base_url=base_url,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            resolved_llm = LLMFactory.create(built_cfg)
        else:
            raise ValueError("Please provide `llm`, or `config`, or `provider` + `model`.")

    return _Agent(
        llm=resolved_llm,
        tools=tools,
        memory=memory,
        compressor=compressor,
        max_iterations=max_iterations,
        max_context_tokens=max_context_tokens,
        permission_policy=permission_policy,
        ask_handler=ask_handler,
        safe_mode=safe_mode,
        permission_store=permission_store,
        context_retriever=context_retriever,
        system_instructions=system_instructions,
        callbacks=callbacks,
        steering_control=steering_control,
        metrics=metrics,
    )
97
+
98
+
99
def agent_from_env(
    *,
    provider: Optional[Union[str, LLMProvider]] = None,
    model: Optional[str] = None,
    # Agent options
    tools: Optional[List[BaseTool]] = None,
    memory: Optional[BaseMemory] = None,
    compressor: Optional[BaseCompressor] = None,
    max_iterations: int = 50,
    max_context_tokens: int = 16000,
    permission_policy: Optional[Dict[str, str]] = None,
    ask_handler=None,
    safe_mode: bool = False,
    permission_store=None,
    # Advanced
    context_retriever=None,
    system_instructions: Optional[str] = None,
    callbacks: Optional[list[BaseCallback]] = None,
    steering_control: Optional[SteeringControl] = None,
    metrics: Optional[MetricsCollector] = None,
) -> _Agent:
    """Construct an Agent using provider/model resolved from environment.

    Environment variables:
    - LOOM_PROVIDER (fallback if provider not given)
    - LOOM_MODEL (fallback if model not given)
    - Provider-specific: OPENAI_API_KEY, OPENAI_BASE_URL, ANTHROPIC_API_KEY, COHERE_API_KEY, AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT

    Raises:
        ValueError: when neither arguments nor env supply provider and model.
    """
    # Explicit arguments take precedence over the environment.
    use_provider = provider if provider else os.getenv("LOOM_PROVIDER")
    use_model = model if model else os.getenv("LOOM_MODEL")
    if not (use_provider and use_model):
        raise ValueError("agent_from_env requires provider/model or LOOM_PROVIDER/LOOM_MODEL env")

    # Delegate to agent(); provider-specific credentials are resolved there.
    return agent(
        provider=use_provider,
        model=use_model,
        tools=tools,
        memory=memory,
        compressor=compressor,
        max_iterations=max_iterations,
        max_context_tokens=max_context_tokens,
        permission_policy=permission_policy,
        ask_handler=ask_handler,
        safe_mode=safe_mode,
        permission_store=permission_store,
        context_retriever=context_retriever,
        system_instructions=system_instructions,
        callbacks=callbacks,
        steering_control=steering_control,
        metrics=metrics,
    )
152
+
153
+
154
def _build_config_from_inputs(
    *,
    provider: Union[str, LLMProvider],
    model: str,
    api_key: Optional[str],
    base_url: Optional[str],
    temperature: Optional[float],
    max_tokens: Optional[int],
) -> LLMConfig:
    """Normalize provider/model/credentials into an :class:`LLMConfig`.

    Missing ``api_key``/``base_url`` values are filled from conventional
    environment variables where possible.

    Raises:
        ValueError: for an unknown provider, or a provider whose required
            api_key is neither passed nor found in the environment.
    """
    prov = provider.value if isinstance(provider, LLMProvider) else str(provider).lower()

    # fill missing api_key from environment if possible
    if api_key is None:
        if prov in {"openai", "custom", "azure_openai"}:
            api_key = os.getenv("OPENAI_API_KEY") or os.getenv("AZURE_OPENAI_API_KEY")
        elif prov == "anthropic":
            api_key = os.getenv("ANTHROPIC_API_KEY")
        elif prov == "cohere":
            api_key = os.getenv("COHERE_API_KEY")
        elif prov == "google":
            # Fix: google previously had no env fallback even though its
            # error message refers to one; mirror the other providers.
            api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

    # default base_url for OpenAI-compatible providers if provided via env
    if base_url is None and prov in {"openai", "custom", "azure_openai"}:
        base_url = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_API_BASE")

    # Only forward overrides the caller actually supplied.
    kwargs = {}
    if temperature is not None:
        kwargs["temperature"] = temperature
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens
    if base_url is not None:
        kwargs["base_url"] = base_url

    if prov == "openai":
        if not api_key:
            raise ValueError("OPENAI provider requires api_key or OPENAI_API_KEY env")
        return LLMConfig.openai(api_key=api_key, model=model, **kwargs)
    if prov == "azure_openai":
        # Treat like OpenAI-compatible; users can pass endpoint via base_url
        if not api_key:
            raise ValueError("AZURE_OPENAI provider requires api_key or AZURE_OPENAI_API_KEY env")
        return LLMConfig.azure_openai(
            api_key=api_key,
            deployment_name=model,
            # base_url doubles as the Azure endpoint; fall back to the env var.
            endpoint=kwargs.pop("base_url", os.getenv("AZURE_OPENAI_ENDPOINT", "")),
            **kwargs,
        )
    if prov == "anthropic":
        if not api_key:
            raise ValueError("ANTHROPIC provider requires api_key or ANTHROPIC_API_KEY env")
        return LLMConfig.anthropic(api_key=api_key, model=model, **kwargs)
    if prov == "cohere":
        if not api_key:
            raise ValueError("COHERE provider requires api_key or COHERE_API_KEY env")
        return LLMConfig.cohere(api_key=api_key, model=model, **kwargs)
    if prov == "google":
        if not api_key:
            raise ValueError("GOOGLE provider requires api_key env")
        return LLMConfig.google(api_key=api_key, model=model, **kwargs)
    if prov == "ollama":
        # Local server: no key needed; default to the standard Ollama port.
        return LLMConfig.ollama(model=model, base_url=base_url or "http://localhost:11434", **kwargs)
    if prov == "custom":
        return LLMConfig.custom(model_name=model, base_url=base_url or "", api_key=api_key, **kwargs)

    raise ValueError(f"Unknown provider: {provider}")
@@ -0,0 +1,10 @@
1
# Re-export the in-memory agent registry API at package level.
from .registry import AgentSpec, register_agent, get_agent_by_type, list_agent_types, clear_agents

__all__ = [
    "AgentSpec",
    "register_agent",
    "get_agent_by_type",
    "list_agent_types",
    "clear_agents",
]
loom/agents/refs.py ADDED
@@ -0,0 +1,28 @@
1
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class AgentRef:
    """Immutable, hashable reference to a registered agent type."""

    agent_type: str


@dataclass(frozen=True)
class ModelRef:
    """Immutable, hashable reference to a model by name."""

    name: str


def agent_ref(agent_type: str) -> AgentRef:
    """Convenience factory for :class:`AgentRef`."""
    return AgentRef(agent_type)


def model_ref(name: str) -> ModelRef:
    """Convenience factory for :class:`ModelRef`."""
    return ModelRef(name)


# Back-compat typing helpers: APIs may accept a plain string or a typed ref.
AgentReferenceLike = Union[str, AgentRef]
ModelReferenceLike = Union[str, ModelRef]
@@ -0,0 +1,50 @@
1
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class AgentSpec:
    """Programmatic agent definition for Loom.

    Enables framework-first agent configuration without on-disk files.

    Fields:
    - agent_type: identifier used by Task tool (subagent_type)
    - description: when to use this agent (for docs/UX; not enforced)
    - system_instructions: system prompt injected when this agent is used
    - tools: '*' or list of tool names that this agent may use
    - model_name: optional model override for this agent
    """

    agent_type: str
    description: str
    system_instructions: str
    tools: Union[List[str], str] = "*"
    model_name: Optional[str] = None


# Process-wide registry, keyed by agent_type. Insertion order is preserved.
_AGENTS: Dict[str, AgentSpec] = {}


def register_agent(spec: AgentSpec) -> AgentSpec:
    """Register or override an agent spec in memory.

    The most recent registration wins for a given agent_type.
    """
    _AGENTS[spec.agent_type] = spec
    return spec


def get_agent_by_type(agent_type: str) -> Optional[AgentSpec]:
    """Look up a registered spec; ``None`` when the type is unknown."""
    return _AGENTS.get(agent_type)


def list_agent_types() -> List[str]:
    """Return all registered agent types in registration order."""
    return [*_AGENTS]


def clear_agents() -> None:
    """Remove every registered spec (useful for test isolation)."""
    _AGENTS.clear()
@@ -0,0 +1,4 @@
1
# Expose the structured compressor and its config as the package API.
from .structured import StructuredCompressor, CompressionConfig

__all__ = ["StructuredCompressor", "CompressionConfig"]
@@ -0,0 +1,79 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from datetime import datetime
5
+ from typing import List
6
+
7
+ from loom.core.types import Message
8
+ from loom.interfaces.compressor import BaseCompressor
9
+ from loom.utils.token_counter import count_messages_tokens
10
+
11
+
12
@dataclass
class CompressionConfig:
    """Tuning knobs for history compression.

    NOTE(review): only ``threshold`` is read by StructuredCompressor in this
    file; the remaining fields look intended for callers/managers — confirm.
    """

    # Context-usage ratio (tokens/max_tokens) at which compression triggers.
    threshold: float = 0.92
    # Ratio at which a warning could be surfaced before compressing.
    warning_threshold: float = 0.80
    # Desired post-compression size relative to the token budget.
    target_ratio: float = 0.75
    # Cap on tokens kept per summary section.
    max_tokens_per_section: int = 512
18
+
19
+
20
class StructuredCompressor(BaseCompressor):
    """Simplified 8-section structured compressor.

    Does not depend on an LLM: it summarizes recent message snippets
    directly, produces a single ``system`` summary message, and keeps a
    window of the most recent messages verbatim.
    """

    def __init__(self, config: CompressionConfig | None = None, keep_recent: int = 6) -> None:
        # keep_recent: number of trailing messages preserved unmodified.
        self.config = config or CompressionConfig()
        self.keep_recent = keep_recent

    async def compress(self, messages: List[Message]) -> List[Message]:
        """Collapse *messages* into [summary system message, *recent window*]."""
        recent = messages[-self.keep_recent :] if self.keep_recent > 0 else []
        # Rough key-point extraction: trailing user/assistant/tool snippets,
        # each truncated to 200 chars below.
        # NOTE(review): assumes ``m.content`` is always a str — confirm
        # Message.content can never be None for these roles.
        user_snippets = [m.content for m in messages if m.role == "user"][-3:]
        assistant_snippets = [m.content for m in messages if m.role == "assistant"][-3:]
        tool_snippets = [m.content for m in messages if m.role == "tool"][-5:]

        # Fixed 8-section summary skeleton (headers/content intentionally
        # match the project's Chinese-language summary format).
        summary = [
            "# 对话历史压缩摘要",
            f"时间: {datetime.now().isoformat(timespec='seconds')}",
            "",
            "## background_context",
            "- 最近用户/助手对话被压缩为摘要,保留关键近端消息窗口。",
            "",
            "## key_decisions",
            "- 见 assistant 近端结论片段(如有)。",
            "",
            "## tool_usage_log",
            *[f"- {t[:200]}" for t in tool_snippets],
            "",
            "## user_intent_evolution",
            *[f"- {u[:200]}" for u in user_snippets],
            "",
            "## execution_results",
            *[f"- {a[:200]}" for a in assistant_snippets],
            "",
            "## errors_and_solutions",
            "- (占位)如有错误会在此归档。",
            "",
            "## open_issues",
            "- (占位)后续待解问题列表。",
            "",
            "## future_plans",
            "- (占位)下一步行动建议。",
        ]

        compressed_msg = Message(
            role="system",
            content="\n".join(summary),
            metadata={"compressed": True, "compression_time": datetime.now().isoformat()},
        )

        return [compressed_msg, *recent]

    def should_compress(self, token_count: int, max_tokens: int) -> bool:
        """Return True when usage reaches ``config.threshold`` of the budget."""
        if max_tokens <= 0:
            # Guard against division by zero / meaningless budgets.
            return False
        ratio = token_count / max_tokens
        return ratio >= self.config.threshold
79
+
@@ -0,0 +1,9 @@
1
+ """内置 Embedding 实现"""
2
+
3
+ from loom.builtin.embeddings.openai_embedding import OpenAIEmbedding
4
+
5
+ try:
6
+ from loom.builtin.embeddings.sentence_transformers_embedding import SentenceTransformersEmbedding
7
+ __all__ = ["OpenAIEmbedding", "SentenceTransformersEmbedding"]
8
+ except ImportError:
9
+ __all__ = ["OpenAIEmbedding"]
@@ -0,0 +1,135 @@
1
+ """OpenAI Embedding 适配器"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List
6
+
7
+ from loom.interfaces.embedding import BaseEmbedding
8
+
9
+ try:
10
+ from openai import AsyncOpenAI
11
+ OPENAI_AVAILABLE = True
12
+ except ImportError:
13
+ OPENAI_AVAILABLE = False
14
+
15
+
16
+ class OpenAIEmbedding(BaseEmbedding):
17
+ """
18
+ OpenAI Embedding 适配器
19
+
20
+ 支持模型:
21
+ - text-embedding-3-small (1536 维, 最便宜)
22
+ - text-embedding-3-large (3072 维, 最强)
23
+ - text-embedding-ada-002 (1536 维, 旧版)
24
+
25
+ 示例:
26
+ from loom.builtin.embeddings import OpenAIEmbedding
27
+
28
+ embedding = OpenAIEmbedding(
29
+ api_key="your-api-key",
30
+ model="text-embedding-3-small"
31
+ )
32
+
33
+ # 单个文本
34
+ vector = await embedding.embed_query("Hello world")
35
+
36
+ # 批量文本
37
+ vectors = await embedding.embed_documents([
38
+ "Hello world",
39
+ "Loom framework"
40
+ ])
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ api_key: str,
46
+ model: str = "text-embedding-3-small",
47
+ dimensions: int | None = None,
48
+ base_url: str | None = None,
49
+ timeout: float = 30.0,
50
+ ):
51
+ """
52
+ Parameters:
53
+ api_key: OpenAI API Key
54
+ model: 模型名称
55
+ dimensions: 向量维度(可选,text-embedding-3-* 支持)
56
+ base_url: API 基础 URL(可选,用于代理)
57
+ timeout: 请求超时时间(秒)
58
+ """
59
+ if not OPENAI_AVAILABLE:
60
+ raise ImportError(
61
+ "OpenAI client is not installed. "
62
+ "Install with: pip install openai"
63
+ )
64
+
65
+ self.model = model
66
+ self.dimensions = dimensions
67
+ self.client = AsyncOpenAI(
68
+ api_key=api_key,
69
+ base_url=base_url,
70
+ timeout=timeout
71
+ )
72
+
73
+ async def embed_query(self, text: str) -> List[float]:
74
+ """
75
+ 对单个查询文本进行向量化
76
+
77
+ Parameters:
78
+ text: 查询文本
79
+
80
+ Returns:
81
+ 向量(列表)
82
+ """
83
+ vectors = await self.embed_documents([text])
84
+ return vectors[0]
85
+
86
+ async def embed_documents(self, texts: List[str]) -> List[List[float]]:
87
+ """
88
+ 批量向量化文档
89
+
90
+ Parameters:
91
+ texts: 文本列表
92
+
93
+ Returns:
94
+ 向量列表
95
+ """
96
+ # 过滤空文本
97
+ non_empty_texts = [t for t in texts if t.strip()]
98
+
99
+ if not non_empty_texts:
100
+ return [[0.0] * (self.dimensions or 1536)] * len(texts)
101
+
102
+ # 调用 OpenAI Embedding API
103
+ kwargs = {"input": non_empty_texts, "model": self.model}
104
+ if self.dimensions:
105
+ kwargs["dimensions"] = self.dimensions
106
+
107
+ response = await self.client.embeddings.create(**kwargs)
108
+
109
+ # 提取向量
110
+ vectors = [item.embedding for item in response.data]
111
+
112
+ # 处理空文本位置
113
+ result = []
114
+ non_empty_idx = 0
115
+ for text in texts:
116
+ if text.strip():
117
+ result.append(vectors[non_empty_idx])
118
+ non_empty_idx += 1
119
+ else:
120
+ result.append([0.0] * len(vectors[0]))
121
+
122
+ return result
123
+
124
+ def get_dimension(self) -> int:
125
+ """返回向量维度"""
126
+ if self.dimensions:
127
+ return self.dimensions
128
+
129
+ # 默认维度
130
+ dimension_map = {
131
+ "text-embedding-3-small": 1536,
132
+ "text-embedding-3-large": 3072,
133
+ "text-embedding-ada-002": 1536,
134
+ }
135
+ return dimension_map.get(self.model, 1536)