loom-agent 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of loom-agent might be problematic. Click here for more details.

Files changed: 89
  1. loom/__init__.py +77 -0
  2. loom/agent.py +217 -0
  3. loom/agents/__init__.py +10 -0
  4. loom/agents/refs.py +28 -0
  5. loom/agents/registry.py +50 -0
  6. loom/builtin/compression/__init__.py +4 -0
  7. loom/builtin/compression/structured.py +79 -0
  8. loom/builtin/embeddings/__init__.py +9 -0
  9. loom/builtin/embeddings/openai_embedding.py +135 -0
  10. loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
  11. loom/builtin/llms/__init__.py +8 -0
  12. loom/builtin/llms/mock.py +34 -0
  13. loom/builtin/llms/openai.py +168 -0
  14. loom/builtin/llms/rule.py +102 -0
  15. loom/builtin/memory/__init__.py +5 -0
  16. loom/builtin/memory/in_memory.py +21 -0
  17. loom/builtin/memory/persistent_memory.py +278 -0
  18. loom/builtin/retriever/__init__.py +9 -0
  19. loom/builtin/retriever/chroma_store.py +265 -0
  20. loom/builtin/retriever/in_memory.py +106 -0
  21. loom/builtin/retriever/milvus_store.py +307 -0
  22. loom/builtin/retriever/pinecone_store.py +237 -0
  23. loom/builtin/retriever/qdrant_store.py +274 -0
  24. loom/builtin/retriever/vector_store.py +128 -0
  25. loom/builtin/retriever/vector_store_config.py +217 -0
  26. loom/builtin/tools/__init__.py +32 -0
  27. loom/builtin/tools/calculator.py +49 -0
  28. loom/builtin/tools/document_search.py +111 -0
  29. loom/builtin/tools/glob.py +27 -0
  30. loom/builtin/tools/grep.py +56 -0
  31. loom/builtin/tools/http_request.py +86 -0
  32. loom/builtin/tools/python_repl.py +73 -0
  33. loom/builtin/tools/read_file.py +32 -0
  34. loom/builtin/tools/task.py +158 -0
  35. loom/builtin/tools/web_search.py +64 -0
  36. loom/builtin/tools/write_file.py +31 -0
  37. loom/callbacks/base.py +9 -0
  38. loom/callbacks/logging.py +12 -0
  39. loom/callbacks/metrics.py +27 -0
  40. loom/callbacks/observability.py +248 -0
  41. loom/components/agent.py +107 -0
  42. loom/core/agent_executor.py +450 -0
  43. loom/core/circuit_breaker.py +178 -0
  44. loom/core/compression_manager.py +329 -0
  45. loom/core/context_retriever.py +185 -0
  46. loom/core/error_classifier.py +193 -0
  47. loom/core/errors.py +66 -0
  48. loom/core/message_queue.py +167 -0
  49. loom/core/permission_store.py +62 -0
  50. loom/core/permissions.py +69 -0
  51. loom/core/scheduler.py +125 -0
  52. loom/core/steering_control.py +47 -0
  53. loom/core/structured_logger.py +279 -0
  54. loom/core/subagent_pool.py +232 -0
  55. loom/core/system_prompt.py +141 -0
  56. loom/core/system_reminders.py +283 -0
  57. loom/core/tool_pipeline.py +113 -0
  58. loom/core/types.py +269 -0
  59. loom/interfaces/compressor.py +59 -0
  60. loom/interfaces/embedding.py +51 -0
  61. loom/interfaces/llm.py +33 -0
  62. loom/interfaces/memory.py +29 -0
  63. loom/interfaces/retriever.py +179 -0
  64. loom/interfaces/tool.py +27 -0
  65. loom/interfaces/vector_store.py +80 -0
  66. loom/llm/__init__.py +14 -0
  67. loom/llm/config.py +228 -0
  68. loom/llm/factory.py +111 -0
  69. loom/llm/model_health.py +235 -0
  70. loom/llm/model_pool_advanced.py +305 -0
  71. loom/llm/pool.py +170 -0
  72. loom/llm/registry.py +201 -0
  73. loom/mcp/__init__.py +4 -0
  74. loom/mcp/client.py +86 -0
  75. loom/mcp/registry.py +58 -0
  76. loom/mcp/tool_adapter.py +48 -0
  77. loom/observability/__init__.py +5 -0
  78. loom/patterns/__init__.py +5 -0
  79. loom/patterns/multi_agent.py +123 -0
  80. loom/patterns/rag.py +262 -0
  81. loom/plugins/registry.py +55 -0
  82. loom/resilience/__init__.py +5 -0
  83. loom/tooling.py +72 -0
  84. loom/utils/agent_loader.py +218 -0
  85. loom/utils/token_counter.py +19 -0
  86. loom_agent-0.0.1.dist-info/METADATA +457 -0
  87. loom_agent-0.0.1.dist-info/RECORD +89 -0
  88. loom_agent-0.0.1.dist-info/WHEEL +4 -0
  89. loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Tuple
5
+
6
+ from loom.core.types import Message, CompressionMetadata
7
+
8
+
9
class BaseCompressor(ABC):
    """Abstract interface for context compression (US2-compatible).

    Updated in v4.0.0 to return compression metadata alongside the
    compressed messages.

    Migration Note:
        Old interface (v3.0.1):
            async def compress(self, messages) -> List[Message]

        New interface (v4.0.0):
            async def compress(self, messages) -> Tuple[List[Message], CompressionMetadata]

        Custom compressor implementations must update compress() to return
        a (compressed_messages, metadata) tuple.
    """

    @abstractmethod
    async def compress(self, messages: List[Message]) -> Tuple[List[Message], CompressionMetadata]:
        """Compress *messages* and report what was done.

        Args:
            messages: Conversation history to compress.

        Returns:
            A ``(compressed_messages, compression_metadata)`` tuple.

        Example:
            compressed_msgs, metadata = await compressor.compress(history)
            print(f"Reduced tokens: {metadata.original_tokens} → {metadata.compressed_tokens}")
        """
        raise NotImplementedError

    @abstractmethod
    def should_compress(self, token_count: int, max_tokens: int) -> bool:
        """Decide whether compression should run for the current context.

        Args:
            token_count: Tokens currently in the context.
            max_tokens: Hard limit on context tokens.

        Returns:
            True when compression should be triggered (typically at a
            92% usage threshold).

        Example:
            if compressor.should_compress(15000, 16000):  # 93.75% usage
                compressed, metadata = await compressor.compress(history)
        """
        raise NotImplementedError
@@ -0,0 +1,51 @@
1
+ """Embedding 接口定义"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import List
7
+
8
+
9
class BaseEmbedding(ABC):
    """Base class for text-embedding backends.

    Unified interface for turning text into vectors, supporting:
    - single-text embedding (queries)
    - batch embedding (documents)
    """

    @abstractmethod
    async def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Parameters:
            text: The query text.

        Returns:
            The embedding as a list of floats.
        """
        pass

    @abstractmethod
    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents.

        Parameters:
            texts: The texts to embed.

        Returns:
            One vector per input text, in the same order as *texts*.
        """
        pass

    def get_dimension(self) -> int:
        """Return the embedding dimensionality (e.g. 384, 768, 1536)."""
        raise NotImplementedError("Subclasses should implement get_dimension()")
loom/interfaces/llm.py ADDED
@@ -0,0 +1,33 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import AsyncGenerator, Dict, List
5
+
6
+
7
class BaseLLM(ABC):
    """Base LLM interface — every LLM provider must implement this."""

    @abstractmethod
    async def generate(self, messages: List[Dict]) -> str:
        """Produce one complete (non-streaming) response."""
        raise NotImplementedError

    @abstractmethod
    async def stream(self, messages: List[Dict]) -> AsyncGenerator[str, None]:
        """Yield incremental chunks of the response content."""
        raise NotImplementedError

    @abstractmethod
    async def generate_with_tools(self, messages: List[Dict], tools: List[Dict]) -> Dict:
        """Generate with tool calling enabled (result may contain tool_calls etc.)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def model_name(self) -> str:
        raise NotImplementedError

    @property
    def supports_tools(self) -> bool:
        # Conservative default; providers with native tool support override this.
        return False
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional
5
+
6
+ from loom.core.types import Message
7
+
8
+
9
class BaseMemory(ABC):
    """Interface for conversation/state memory backends."""

    @abstractmethod
    async def add_message(self, message: Message) -> None:
        raise NotImplementedError

    @abstractmethod
    async def get_messages(self, limit: Optional[int] = None) -> List[Message]:
        raise NotImplementedError

    @abstractmethod
    async def clear(self) -> None:
        raise NotImplementedError

    async def save(self, path: str) -> None:  # optional persistence hook
        return None

    async def load(self, path: str) -> None:  # optional persistence hook
        return None
@@ -0,0 +1,179 @@
1
+ """检索器接口 - RAG 系统的基础抽象"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+
11
class Document(BaseModel):
    """Data model for a document fragment retrieved from a knowledge base."""

    content: str = Field(description="文档内容")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="文档元数据")
    score: Optional[float] = Field(default=None, description="相关性分数 (0-1)")
    doc_id: Optional[str] = Field(default=None, description="文档唯一标识")

    def __str__(self) -> str:
        return (
            f"Document(source={self.metadata.get('source', 'Unknown')}, "
            f"score={self.score}, len={len(self.content)})"
        )

    def format(self, max_length: int = 500) -> str:
        """Render the document for display, truncating content past *max_length*."""
        parts = []
        src = self.metadata.get("source")
        if src:
            parts.append(f"**Source**: {src}")
        if self.score is not None:
            parts.append(f"**Relevance**: {self.score:.2%}")

        body = self.content
        if len(body) > max_length:
            body = body[:max_length] + "..."
        parts.append(body)

        return "\n".join(parts)
42
+
43
class BaseRetriever(ABC):
    """Base interface that every retriever implements.

    Known implementations:
    - VectorStoreRetriever (vector search)
    - BM25Retriever (keyword search)
    - HybridRetriever (hybrid search)
    """

    @abstractmethod
    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Fetch documents relevant to *query*.

        Parameters:
            query: The query string.
            top_k: Number of documents to return.
            filters: Filter conditions (e.g. {"source": "doc.pdf"}).

        Returns:
            Documents sorted by relevance.
        """
        pass

    @abstractmethod
    async def add_documents(
        self,
        documents: List[Document],
    ) -> None:
        """Index *documents* into the retrieval system.

        Parameters:
            documents: Documents to add.
        """
        pass

    async def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Convenience helper: wrap raw texts in Document objects and index them.

        Parameters:
            texts: Texts to add.
            metadatas: Per-text metadata, parallel to *texts* (defaults to empty dicts).
        """
        metas = metadatas if metadatas is not None else [{} for _ in texts]
        docs = [Document(content=t, metadata=m) for t, m in zip(texts, metas)]
        await self.add_documents(docs)
108
+
109
class BaseVectorStore(ABC):
    """Base interface for the underlying vector database.

    Abstracts over stores such as:
    - ChromaDB
    - FAISS
    - Pinecone
    - Weaviate
    """

    @abstractmethod
    async def similarity_search(
        self,
        query: str,
        k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """Search by vector similarity.

        Parameters:
            query: Query string (embedded automatically by the store).
            k: Number of results to return.
            filters: Metadata filter conditions.

        Returns:
            Documents sorted by similarity.
        """
        pass

    @abstractmethod
    async def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> List[str]:
        """Add texts to the vector store.

        Parameters:
            texts: Texts to add.
            metadatas: Per-text metadata, parallel to *texts*.

        Returns:
            The IDs assigned to the stored documents.
        """
        pass

    @abstractmethod
    async def delete(self, ids: List[str]) -> None:
        """Delete the documents with the given IDs."""
        pass
163
+
164
class BaseEmbedding(ABC):
    """Embedding-model interface: converts text into vectors."""

    @abstractmethod
    async def embed_query(self, text: str) -> List[float]:
        """Embed one query string."""
        pass

    @abstractmethod
    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents, one vector per text."""
        pass
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel
7
+
8
+
9
class BaseTool(ABC):
    """Base interface for tools."""

    # Subclasses declare these as class attributes:
    name: str                      # tool identifier
    description: str               # human/LLM-readable description
    args_schema: type[BaseModel]   # pydantic model describing run() kwargs

    @abstractmethod
    async def run(self, **kwargs) -> Any:
        raise NotImplementedError

    @property
    def is_async(self) -> bool:
        # Default: tools are asynchronous.
        return True

    @property
    def is_concurrency_safe(self) -> bool:
        # Default: safe to run concurrently; override to False when not.
        return True
@@ -0,0 +1,80 @@
1
+ """向量存储接口定义"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from loom.interfaces.retriever import Document
9
+
10
+
11
class BaseVectorStore(ABC):
    """Unified interface over vector databases.

    Supports:
    - adding and deleting vectors
    - vector-similarity search
    - metadata filtering
    """

    @abstractmethod
    async def initialize(self) -> None:
        """Initialize the store connection.

        Intended for opening database connections, creating
        collections/indexes, and similar setup work.
        """
        pass

    @abstractmethod
    async def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document]
    ) -> None:
        """Store *vectors* alongside their documents.

        Parameters:
            vectors: One embedding (list of floats) per document.
            documents: Documents aligned one-to-one with *vectors*.
        """
        pass

    @abstractmethod
    async def search(
        self,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[Document, float]]:
        """Find vectors most similar to *query_vector*.

        Parameters:
            query_vector: The embedding to search with.
            top_k: Number of results to return.
            filters: Optional metadata filter conditions.

        Returns:
            ``[(Document, score), ...]`` sorted by similarity, highest first.
        """
        pass

    async def delete(self, doc_ids: List[str]) -> None:
        """Delete the given document IDs. Default implementation is a no-op.

        Parameters:
            doc_ids: IDs of the documents to remove.
        """
        pass

    async def clear(self) -> None:
        """Remove every vector. Default implementation is a no-op."""
        pass

    async def close(self) -> None:
        """Release connections/resources. Default implementation is a no-op."""
        pass
loom/llm/__init__.py ADDED
@@ -0,0 +1,14 @@
1
# Public surface of the loom.llm package: configuration types, the LLM
# factory, model pooling, and the model registry.
from .config import LLMConfig, LLMProvider, LLMCapabilities
from .factory import LLMFactory
from .pool import ModelPool
from .registry import ModelRegistry

# Explicit public API for `from loom.llm import *`.
__all__ = [
    "LLMConfig",
    "LLMProvider",
    "LLMCapabilities",
    "LLMFactory",
    "ModelPool",
    "ModelRegistry",
]
loom/llm/config.py ADDED
@@ -0,0 +1,228 @@
1
+ """LLM 统一配置系统
2
+
3
+ 提供统一的接口来配置各种 LLM 提供商,支持:
4
+ - OpenAI (GPT-4, GPT-3.5)
5
+ - Anthropic (Claude)
6
+ - Google (Gemini)
7
+ - Cohere
8
+ - Azure OpenAI
9
+ - Ollama (本地模型)
10
+ - 自定义模型
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from typing import Any, Dict, Optional
16
+ from pydantic import BaseModel, Field
17
+ from enum import Enum
18
+
19
+
20
class LLMProvider(str, Enum):
    """Enumeration of supported LLM providers.

    Inherits from ``str`` so members compare equal to their string values.
    """
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    COHERE = "cohere"
    AZURE_OPENAI = "azure_openai"
    OLLAMA = "ollama"
    # Arbitrary user-supplied endpoint (see LLMConfig.custom()).
    CUSTOM = "custom"
30
+
31
class LLMCapabilities(BaseModel):
    """Describes what a given model can do (tool use, vision, streaming, limits)."""
    supports_tools: bool = Field(default=False, description="是否支持工具调用/Function Calling")
    supports_vision: bool = Field(default=False, description="是否支持视觉输入")
    supports_streaming: bool = Field(default=True, description="是否支持流式输出")
    supports_json_mode: bool = Field(default=False, description="是否支持 JSON 模式")
    supports_system_message: bool = Field(default=True, description="是否支持系统消息")
    max_tokens: int = Field(default=4096, description="最大输出 token 数")
    context_window: int = Field(default=8192, description="上下文窗口大小")
41
+
42
class LLMConfig(BaseModel):
    """Unified LLM configuration.

    Examples:
        # OpenAI
        config = LLMConfig.openai(api_key="sk-...", model="gpt-4")

        # Anthropic
        config = LLMConfig.anthropic(api_key="sk-ant-...", model="claude-3-5-sonnet-20241022")

        # Local model via Ollama
        config = LLMConfig.ollama(model="llama3", base_url="http://localhost:11434")

        # Custom endpoint
        config = LLMConfig.custom(
            model_name="my-model",
            base_url="https://my-api.com",
            api_key="..."
        )
    """

    # -- basic settings --
    provider: LLMProvider = Field(description="LLM 提供商")
    model_name: str = Field(description="模型名称")
    api_key: Optional[str] = Field(default=None, description="API Key(如果需要)")
    base_url: Optional[str] = Field(default=None, description="API 基础 URL(用于代理或自定义端点)")

    # -- generation parameters --
    temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="温度参数,控制随机性")
    max_tokens: Optional[int] = Field(default=None, description="最大输出 token 数")
    top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="核采样参数")
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="频率惩罚")
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚")

    # -- reliability settings --
    max_retries: int = Field(default=3, ge=0, description="最大重试次数")
    retry_delay: float = Field(default=1.0, ge=0.0, description="重试延迟(秒)")
    timeout: float = Field(default=60.0, ge=0.0, description="请求超时时间(秒)")

    # -- provider-specific extras --
    extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外的提供商特定参数")

    # -- optional capability override (auto-detected when None) --
    capabilities: Optional[LLMCapabilities] = Field(default=None, description="模型能力,不指定则自动检测")

    # Pydantic v2 configuration. The original used a nested v1-style
    # `class Config`, which is deprecated in v2 and inconsistent with the
    # v2 `model_dump()` call in to_dict() below; behavior is identical.
    model_config = {"use_enum_values": True}

    # ==================== quick-configuration constructors ====================

    @classmethod
    def openai(
        cls,
        api_key: str,
        model: str = "gpt-4",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        base_url: Optional[str] = None,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for OpenAI models."""
        return cls(
            provider=LLMProvider.OPENAI,
            model_name=model,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            base_url=base_url,
            **kwargs
        )

    @classmethod
    def anthropic(
        cls,
        api_key: str,
        model: str = "claude-3-5-sonnet-20241022",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for Anthropic models."""
        return cls(
            provider=LLMProvider.ANTHROPIC,
            model_name=model,
            api_key=api_key,
            temperature=temperature,
            # Fall back to 4096 output tokens when the caller does not specify.
            max_tokens=max_tokens or 4096,
            **kwargs
        )

    @classmethod
    def ollama(
        cls,
        model: str = "llama3",
        base_url: str = "http://localhost:11434",
        temperature: float = 0.7,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for a local Ollama server (no API key)."""
        return cls(
            provider=LLMProvider.OLLAMA,
            model_name=model,
            base_url=base_url,
            temperature=temperature,
            api_key=None,
            **kwargs
        )

    @classmethod
    def azure_openai(
        cls,
        api_key: str,
        deployment_name: str,
        endpoint: str,
        api_version: str = "2024-02-15-preview",
        temperature: float = 0.7,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for Azure OpenAI.

        The deployment name is stored as ``model_name`` and the API
        version travels in ``extra_params``.
        """
        return cls(
            provider=LLMProvider.AZURE_OPENAI,
            model_name=deployment_name,
            api_key=api_key,
            base_url=endpoint,
            temperature=temperature,
            extra_params={"api_version": api_version},
            **kwargs
        )

    @classmethod
    def google(
        cls,
        api_key: str,
        model: str = "gemini-pro",
        temperature: float = 0.7,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for Google (Gemini) models."""
        return cls(
            provider=LLMProvider.GOOGLE,
            model_name=model,
            api_key=api_key,
            temperature=temperature,
            **kwargs
        )

    @classmethod
    def cohere(
        cls,
        api_key: str,
        model: str = "command-r-plus",
        temperature: float = 0.7,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for Cohere models."""
        return cls(
            provider=LLMProvider.COHERE,
            model_name=model,
            api_key=api_key,
            temperature=temperature,
            **kwargs
        )

    @classmethod
    def custom(
        cls,
        model_name: str,
        base_url: str,
        api_key: Optional[str] = None,
        temperature: float = 0.7,
        capabilities: Optional[LLMCapabilities] = None,
        **kwargs
    ) -> "LLMConfig":
        """Build a configuration for an arbitrary OpenAI-compatible endpoint."""
        return cls(
            provider=LLMProvider.CUSTOM,
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            capabilities=capabilities,
            **kwargs
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the configuration to a plain dict (pydantic v2 dump)."""
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMConfig":
        """Rebuild a configuration from a dict produced by to_dict()."""
        return cls(**data)

    def __repr__(self) -> str:
        return f"LLMConfig(provider={self.provider}, model={self.model_name})"