codegnipy 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codegnipy/__init__.py +190 -0
- codegnipy/cli.py +153 -0
- codegnipy/decorator.py +151 -0
- codegnipy/determinism.py +631 -0
- codegnipy/memory.py +276 -0
- codegnipy/providers.py +1160 -0
- codegnipy/reflection.py +244 -0
- codegnipy/runtime.py +197 -0
- codegnipy/scheduler.py +498 -0
- codegnipy/streaming.py +387 -0
- codegnipy/tools.py +481 -0
- codegnipy/transformer.py +155 -0
- codegnipy/validation.py +961 -0
- codegnipy-0.0.1.dist-info/METADATA +417 -0
- codegnipy-0.0.1.dist-info/RECORD +19 -0
- codegnipy-0.0.1.dist-info/WHEEL +5 -0
- codegnipy-0.0.1.dist-info/entry_points.txt +2 -0
- codegnipy-0.0.1.dist-info/licenses/LICENSE +21 -0
- codegnipy-0.0.1.dist-info/top_level.txt +1 -0
codegnipy/providers.py
ADDED
|
@@ -0,0 +1,1160 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Codegnipy 多提供商支持模块
|
|
3
|
+
|
|
4
|
+
支持多种 LLM 提供商:OpenAI、Anthropic、本地模型等。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Optional, List, Dict, Any, AsyncIterator, Iterator, TYPE_CHECKING
|
|
11
|
+
from enum import Enum
|
|
12
|
+
import json
|
|
13
|
+
|
|
14
|
+
from .streaming import StreamChunk, StreamStatus
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ProviderType(Enum):
    """Identifier of the LLM backend a :class:`ProviderConfig` targets.

    Each member's value is the lowercase backend name.
    """

    OPENAI = "openai"            # OpenAI Chat Completions API
    ANTHROPIC = "anthropic"      # Anthropic Messages API
    OLLAMA = "ollama"            # Ollama HTTP server (local or remote)
    HUGGINGFACE = "huggingface"  # Local HuggingFace transformers models
    CUSTOM = "custom"            # User-supplied provider
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ProviderConfig:
    """Settings shared by every provider implementation.

    Attributes:
        provider_type: Backend this configuration is meant for.
        api_key: API key handed to the hosted SDK clients (OpenAI/Anthropic).
        model: Model identifier; a per-call ``model`` keyword overrides it.
        base_url: Optional endpoint override. The Ollama provider falls back
            to ``http://localhost:11434`` when this is unset.
        temperature: Sampling temperature used when a call does not pass one.
        max_tokens: Generation length limit used when a call does not pass one.
        extra_params: Provider-specific extras merged into the request
            parameters (for Ollama, into ``options``). The transformers
            provider reads a ``device`` entry from it.
    """

    provider_type: ProviderType = ProviderType.OPENAI
    api_key: Optional[str] = None
    model: str = ""
    base_url: Optional[str] = None
    temperature: float = 0.7
    max_tokens: int = 1024
    # default_factory gives every instance its own dict (no shared mutable default).
    extra_params: Dict[str, Any] = field(default_factory=dict)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BaseProvider(ABC):
    """Abstract interface implemented by every LLM provider.

    A provider is built from a :class:`ProviderConfig` and exposes blocking,
    asynchronous, streaming and tool-calling entry points. Messages are
    OpenAI-style dicts with ``role`` and ``content`` keys.
    """

    def __init__(self, config: ProviderConfig):
        self.config = config

    @abstractmethod
    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking completion and return the generated text."""

    @abstractmethod
    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Run a blocking completion, yielding :class:`StreamChunk` objects as text arrives."""

    @abstractmethod
    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Coroutine counterpart of :meth:`call`."""

    @abstractmethod
    def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Return an async iterator of :class:`StreamChunk` objects.

        Declared as a plain ``def``: implementations are expected to be
        ``async def`` generators, whose call already returns an async iterator.
        """

    @abstractmethod
    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Run a completion that may request function/tool calls."""
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class OpenAIProvider(BaseProvider):
    """Provider backed by the OpenAI Chat Completions API (``openai`` SDK).

    The sync and async SDK clients are created lazily on first use, so the
    ``openai`` package is only needed once a request is actually made.
    ``config.base_url`` is forwarded to the client to allow a custom endpoint.
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._client = None
        self._async_client = None

    def _get_client(self):
        """Return the cached ``openai.OpenAI`` client, creating it on first use.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                import openai
            except ImportError as err:
                raise ImportError("需要安装 openai 包。运行: pip install openai") from err
            self._client = openai.OpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            )
        return self._client

    def _get_async_client(self):
        """Return the cached ``openai.AsyncOpenAI`` client, creating it on first use.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._async_client is None:
            try:
                from openai import AsyncOpenAI
            except ImportError as err:
                raise ImportError("需要安装 openai 包。运行: pip install openai") from err
            self._async_client = AsyncOpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            )
        return self._async_client

    def _build_params(
        self,
        messages: List[Dict[str, str]],
        kwargs: Dict[str, Any],
        **forced: Any
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``.

        Precedence, lowest to highest:
          1. per-call ``model``/``temperature``/``max_tokens`` (falling back to config),
          2. ``config.extra_params`` (may add e.g. ``top_p`` or override a default),
          3. ``forced`` keys the calling method depends on (``stream``, ``tools``...).

        Building one dict avoids the ``TypeError: got multiple values for
        keyword argument`` raised when ``**extra_params`` was passed next to
        explicit keywords that it also contained.
        """
        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.config.model),
            "messages": messages,
            "temperature": kwargs.get("temperature", self.config.temperature),
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
        }
        params.update(self.config.extra_params)
        params.update(forced)
        return params

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking chat completion and return the reply text.

        Returns ``""`` when the API returns no text content (``content`` is
        nullable, e.g. for a tool-call-only reply).
        """
        client = self._get_client()
        response = client.chat.completions.create(**self._build_params(messages, kwargs))
        return response.choices[0].message.content or ""

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Stream a chat completion.

        Yields a STARTED chunk, one STREAMING chunk per non-empty text delta
        (with the running ``accumulated`` text), then a COMPLETED chunk.
        """
        client = self._get_client()
        response = client.chat.completions.create(
            **self._build_params(messages, kwargs, stream=True)
        )

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        for chunk in response:
            # Chunks without choices or without text content are skipped.
            if chunk.choices and chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                accumulated += content
                yield StreamChunk(
                    content=content,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Async counterpart of :meth:`call`; returns ``""`` when there is no text content."""
        client = self._get_async_client()
        response = await client.chat.completions.create(**self._build_params(messages, kwargs))
        return response.choices[0].message.content or ""

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Async counterpart of :meth:`stream` with the same chunk sequence."""
        client = self._get_async_client()
        response = await client.chat.completions.create(
            **self._build_params(messages, kwargs, stream=True)
        )

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        async for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                accumulated += content
                yield StreamChunk(
                    content=content,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Run a chat completion that may request tool calls.

        Args:
            messages: OpenAI-style chat messages.
            tools: OpenAI ``tools`` definitions; may be empty.
            **kwargs: ``model``, ``temperature``, ``max_tokens`` and
                ``tool_choice`` (default ``"auto"``) overrides.

        Returns:
            ``{"content": <str or None>, "tool_calls": <SDK tool calls or None>,
            "message": <raw SDK message>}``
        """
        client = self._get_client()

        forced: Dict[str, Any] = {}
        # The API rejects an empty ``tools`` array and ``tool_choice`` without
        # tools, so both are only sent when tools are actually provided.
        if tools:
            forced["tools"] = tools
            forced["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = client.chat.completions.create(
            **self._build_params(messages, kwargs, **forced)
        )
        message = response.choices[0].message

        return {
            "content": message.content,
            "tool_calls": message.tool_calls,
            "message": message
        }
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class AnthropicProvider(BaseProvider):
    """Provider backed by the Anthropic Messages API (``anthropic`` SDK).

    OpenAI-style messages are converted: ``system`` messages go into the
    top-level ``system`` parameter, ``user``/``assistant`` turns are
    forwarded, and any other role is dropped. SDK clients are created lazily.
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._client = None
        self._async_client = None

    def _get_client(self):
        """Return the cached ``anthropic.Anthropic`` client, creating it on first use.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        if self._client is None:
            try:
                import anthropic
            except ImportError as err:
                raise ImportError(
                    "需要安装 anthropic 包。运行: pip install anthropic"
                ) from err
            # base_url was previously ignored; None keeps the SDK default endpoint.
            self._client = anthropic.Anthropic(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            )
        return self._client

    def _get_async_client(self):
        """Return the cached ``anthropic.AsyncAnthropic`` client, creating it on first use.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        if self._async_client is None:
            try:
                import anthropic
            except ImportError as err:
                raise ImportError(
                    "需要安装 anthropic 包。运行: pip install anthropic"
                ) from err
            self._async_client = anthropic.AsyncAnthropic(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            )
        return self._async_client

    def _convert_messages(
        self,
        messages: List[Dict[str, str]]
    ) -> tuple[str, List[Dict[str, str]]]:
        """Split messages into ``(system_prompt, anthropic_messages)``.

        All non-empty system messages are joined with blank lines (previously
        only the last one survived). Only ``user``/``assistant`` turns are kept.
        """
        system_parts: List[str] = []
        converted: List[Dict[str, str]] = []

        for msg in messages:
            role = msg["role"]
            if role == "system":
                if msg["content"]:
                    system_parts.append(msg["content"])
            elif role in ("user", "assistant"):
                converted.append({
                    "role": role,
                    "content": msg["content"]
                })

        return "\n\n".join(system_parts), converted

    def _build_params(
        self,
        messages: List[Dict[str, str]],
        kwargs: Dict[str, Any],
        **additional: Any
    ) -> Dict[str, Any]:
        """Assemble keyword arguments for ``messages.create`` / ``messages.stream``.

        ``additional`` entries (e.g. ``tools``) are added before
        ``config.extra_params`` is merged, so extra_params can still override them.
        """
        system, converted = self._convert_messages(messages)

        params: Dict[str, Any] = {
            "model": kwargs.get("model", self.config.model),
            "max_tokens": kwargs.get("max_tokens", self.config.max_tokens),
            "messages": converted,
        }
        if system:
            params["system"] = system

        # 0.0 is a meaningful (greedy) setting: the old truthiness check dropped
        # it and silently fell back to the server default. Only omit when None.
        temperature = kwargs.get("temperature", self.config.temperature)
        if temperature is not None:
            params["temperature"] = temperature

        params.update(additional)
        params.update(self.config.extra_params)
        return params

    @staticmethod
    def _extract_text(blocks: Any) -> str:
        """Concatenate the text of every content block that has a ``text`` attribute.

        Replies can be split across several text blocks; taking only one
        block truncated the answer.
        """
        return "".join(block.text for block in blocks if hasattr(block, "text"))

    @staticmethod
    def _convert_tool(func: Dict[str, Any]) -> Dict[str, Any]:
        """Translate an OpenAI ``function`` definition into an Anthropic tool."""
        return {
            "name": func["name"],
            "description": func.get("description", ""),
            # Anthropic requires a JSON-schema object; an empty {} is rejected.
            "input_schema": func.get("parameters") or {"type": "object", "properties": {}},
        }

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking completion and return all reply text ("" if none)."""
        client = self._get_client()
        response = client.messages.create(**self._build_params(messages, kwargs))
        return self._extract_text(response.content)

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Stream a completion: STARTED, one STREAMING chunk per text delta, then COMPLETED."""
        client = self._get_client()
        params = self._build_params(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        with client.messages.stream(**params) as events:
            for text in events.text_stream:
                accumulated += text
                yield StreamChunk(
                    content=text,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Async counterpart of :meth:`call`."""
        client = self._get_async_client()
        response = await client.messages.create(**self._build_params(messages, kwargs))
        return self._extract_text(response.content)

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Async counterpart of :meth:`stream` with the same chunk sequence."""
        client = self._get_async_client()
        params = self._build_params(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        async with client.messages.stream(**params) as events:
            async for text in events.text_stream:
                accumulated += text
                yield StreamChunk(
                    content=text,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Run a completion with tools, returning OpenAI-shaped tool calls.

        Args:
            messages: OpenAI-style chat messages.
            tools: OpenAI ``tools`` definitions; only ``type == "function"`` entries are used.
            **kwargs: ``model``, ``temperature`` and ``max_tokens`` overrides.

        Returns:
            ``{"content": <all reply text>, "tool_calls": <list of
            {"id", "type": "function", "function": {"name", "arguments": <JSON str>}}
            or None>, "message": <raw SDK response>}``
        """
        client = self._get_client()

        anthropic_tools = [
            self._convert_tool(tool["function"])
            for tool in tools
            if tool.get("type") == "function"
        ]
        additional: Dict[str, Any] = {}
        if anthropic_tools:
            additional["tools"] = anthropic_tools

        response = client.messages.create(**self._build_params(messages, kwargs, **additional))

        tool_calls = [
            {
                "id": block.id,
                "type": "function",
                "function": {
                    "name": block.name,
                    "arguments": json.dumps(block.input)
                }
            }
            for block in response.content
            if not hasattr(block, "text") and block.type == "tool_use"
        ]

        return {
            "content": self._extract_text(response.content),
            "tool_calls": tool_calls if tool_calls else None,
            "message": response
        }
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
class OllamaProvider(BaseProvider):
    """Provider for models served by an Ollama server (local or remote).

    Supports local models run by Ollama such as llama2, mistral or codellama.
    Requests go to the ``/api/generate`` endpoint. The sync methods use
    ``urllib`` and the async methods use ``aiohttp`` (imported lazily). Chat
    messages are flattened into a single ``prompt`` string.

    Example::

        config = ProviderConfig(
            provider_type=ProviderType.OLLAMA,
            model="llama2",
            base_url="http://localhost:11434"
        )
        provider = OllamaProvider(config)
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._base_url = config.base_url or "http://localhost:11434"

    def _get_client(self):
        """Return the ``urllib.request`` module used as the HTTP client."""
        try:
            import urllib.request
            return urllib.request
        except ImportError:
            raise ImportError("需要 urllib 支持")

    def _url(self, endpoint: str) -> str:
        """Join ``endpoint`` onto the base URL with exactly one slash between them."""
        return f"{self._base_url.rstrip('/')}/{endpoint.lstrip('/')}"

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the ``/api/generate`` request body.

        Per-call ``model``/``temperature``/``max_tokens`` override the config;
        ``config.extra_params`` is merged into ``options``.
        """
        data: Dict[str, Any] = {
            "model": kwargs.get("model", self.config.model),
            "prompt": self._convert_messages(messages),
            "options": {
                "temperature": kwargs.get("temperature", self.config.temperature),
                "num_predict": kwargs.get("max_tokens", self.config.max_tokens),
            }
        }
        if self.config.extra_params:
            data["options"].update(self.config.extra_params)
        return data

    def _make_request(
        self,
        endpoint: str,
        data: Dict[str, Any],
        stream: bool = False
    ) -> Any:
        """POST ``data`` as JSON and return the single decoded response object.

        ``"stream": false`` is always sent. Ollama's generate endpoint streams
        newline-delimited JSON by default, which a single ``json.loads`` cannot
        parse ("Extra data"). The ``stream`` argument is kept only for
        backward compatibility and is ignored; use :meth:`_make_stream_request`
        for streaming.

        Raises:
            urllib.error.HTTPError: on a non-2xx response.
        """
        import urllib.request

        req = urllib.request.Request(
            self._url(endpoint),
            data=json.dumps({**data, "stream": False}).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST"
        )

        with urllib.request.urlopen(req) as response:
            return json.loads(response.read().decode("utf-8"))

    def _make_stream_request(
        self,
        endpoint: str,
        data: Dict[str, Any]
    ) -> Iterator[Dict[str, Any]]:
        """POST ``data`` with ``"stream": true`` and yield each NDJSON line as a dict."""
        import urllib.request

        req = urllib.request.Request(
            self._url(endpoint),
            data=json.dumps({**data, "stream": True}).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST"
        )

        with urllib.request.urlopen(req) as response:
            for line in response:
                line = line.decode("utf-8").strip()
                if line:
                    yield json.loads(line)

    async def _make_async_request(
        self,
        endpoint: str,
        data: Dict[str, Any]
    ) -> Any:
        """Async, non-streaming POST; returns the decoded JSON response.

        Raises:
            aiohttp.ClientResponseError: on a non-2xx response. aiohttp does not
                raise by itself, so the error body used to be read as a normal
                result and turned into an empty reply.
        """
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._url(endpoint),
                json={**data, "stream": False},
                headers={"Content-Type": "application/json"}
            ) as response:
                response.raise_for_status()
                return await response.json()

    async def _make_async_stream_request(
        self,
        endpoint: str,
        data: Dict[str, Any]
    ) -> AsyncIterator[Dict[str, Any]]:
        """Async streaming POST; yields each NDJSON line as a dict.

        Raises:
            aiohttp.ClientResponseError: on a non-2xx response.
        """
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._url(endpoint),
                json={**data, "stream": True},
                headers={"Content-Type": "application/json"}
            ) as response:
                response.raise_for_status()
                async for line in response.content:
                    line = line.decode("utf-8").strip()
                    if line:
                        yield json.loads(line)

    def _convert_messages(self, messages: List[Dict[str, str]]) -> str:
        """Flatten chat messages into a ``Role: content`` prompt ending with ``Assistant:``.

        Messages with roles other than system/user/assistant are dropped.
        """
        prompt_parts = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                prompt_parts.append(f"System: {content}")
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        prompt_parts.append("Assistant:")
        return "\n\n".join(prompt_parts)

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking generation and return the ``response`` text."""
        result = self._make_request("/api/generate", self._build_payload(messages, kwargs))
        return result.get("response", "")

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Stream a generation: STARTED, STREAMING per ``response`` piece, then COMPLETED.

        Iteration stops at the first chunk with ``done: true``.
        """
        data = self._build_payload(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        for chunk in self._make_stream_request("/api/generate", data):
            if "response" in chunk:
                content = chunk["response"]
                accumulated += content
                yield StreamChunk(
                    content=content,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated
                )

            if chunk.get("done", False):
                break

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Async counterpart of :meth:`call`."""
        result = await self._make_async_request(
            "/api/generate", self._build_payload(messages, kwargs)
        )
        return result.get("response", "")

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Async counterpart of :meth:`stream` with the same chunk sequence."""
        data = self._build_payload(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        async for chunk in self._make_async_stream_request("/api/generate", data):
            if "response" in chunk:
                content = chunk["response"]
                accumulated += content
                yield StreamChunk(
                    content=content,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated
                )

            if chunk.get("done", False):
                break

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Emulate tool calling through prompting.

        This provider talks to ``/api/generate`` without native tool support.
        Tool descriptions are appended as an extra system message. If the
        reply contains a JSON object with ``name`` or ``function``, it is
        returned as a single OpenAI-shaped tool call; otherwise ``tool_calls``
        is None.
        """
        import re

        tool_descriptions = []
        for tool in tools:
            if tool.get("type") == "function":
                func = tool["function"]
                desc = f"- {func['name']}: {func.get('description', 'No description')}"
                if "parameters" in func:
                    desc += f"\n Parameters: {json.dumps(func['parameters'])}"
                tool_descriptions.append(desc)

        tool_prompt = "\n".join(tool_descriptions)
        enhanced_messages = messages + [{
            "role": "system",
            "content": f"\n\nAvailable tools:\n{tool_prompt}\n\nTo use a tool, respond with a JSON object."
        }]

        response = self.call(enhanced_messages, **kwargs)

        # Best effort: a reply that is not valid JSON simply yields no tool call.
        tool_calls = None
        try:
            json_match = re.search(r'\{[\s\S]*\}', response)
            if json_match:
                parsed = json.loads(json_match.group())
                if "name" in parsed or "function" in parsed:
                    tool_calls = [{
                        "id": f"call_{hash(response) % 10000}",
                        "type": "function",
                        "function": {
                            "name": parsed.get("name") or parsed.get("function"),
                            "arguments": json.dumps(parsed.get("arguments", parsed))
                        }
                    }]
        except (json.JSONDecodeError, KeyError):
            pass

        return {
            "content": response,
            "tool_calls": tool_calls,
            "message": response
        }

    def list_models(self) -> List[str]:
        """Return the model names reported by ``/api/tags``, or [] on any failure."""
        import urllib.request

        try:
            req = urllib.request.Request(self._url("/api/tags"), method="GET")
            with urllib.request.urlopen(req) as response:
                data = json.loads(response.read().decode("utf-8"))
                return [model["name"] for model in data.get("models", [])]
        except Exception:
            # Deliberate best effort: an unreachable server just means "no models".
            return []
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
class TransformersProvider(BaseProvider):
|
|
807
|
+
"""HuggingFace Transformers 本地模型提供商
|
|
808
|
+
|
|
809
|
+
使用 transformers 库在本地运行模型。支持各种 HuggingFace 模型。
|
|
810
|
+
|
|
811
|
+
使用示例:
|
|
812
|
+
config = ProviderConfig(
|
|
813
|
+
provider_type=ProviderType.HUGGINGFACE,
|
|
814
|
+
model="microsoft/DialoGPT-medium",
|
|
815
|
+
extra_params={"device": "cuda"} # 或 "cpu", "mps"
|
|
816
|
+
)
|
|
817
|
+
provider = TransformersProvider(config)
|
|
818
|
+
"""
|
|
819
|
+
|
|
820
|
+
def __init__(self, config: ProviderConfig):
    """Store the config; the model, tokenizer and pipeline are loaded lazily on first use."""
    super().__init__(config)
    # Heavy objects stay unset until _load_model / _load_pipeline runs.
    self._model: Optional["PreTrainedModel"] = None
    self._tokenizer: Optional["PreTrainedTokenizerBase"] = None
    # Target device comes from extra_params["device"]; "auto" when absent.
    self._device = config.extra_params.get("device", "auto")
    self._pipeline = None
|
|
826
|
+
|
|
827
|
+
def _load_model(self):
    """Lazily load the tokenizer and causal-LM weights for ``config.model``.

    Does nothing once a model is loaded. The tokenizer and model are only
    stored after both load successfully, so a failed attempt leaves no
    half-initialised state behind and the next call retries cleanly.

    Raises:
        ImportError: if ``transformers`` is not installed.
        RuntimeError: if loading the tokenizer or model fails (original error chained).
    """
    if self._model is not None:
        return

    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as err:
        raise ImportError(
            "需要安装 transformers 和 torch。运行: pip install transformers torch"
        ) from err

    model_name = self.config.model
    # "auto" maps to device_map=None, i.e. from_pretrained's default loading.
    device_map = self._device if self._device != "auto" else None

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device_map
        )
        if self._device != "auto" and hasattr(model, "to"):
            model = model.to(self._device)
    except Exception as e:
        raise RuntimeError(f"加载模型失败: {e}") from e

    self._tokenizer = tokenizer
    self._model = model
|
|
856
|
+
|
|
857
|
+
def _load_pipeline(self):
    """Lazily build a ``text-generation`` pipeline for ``config.model``.

    This is a simpler alternative to :meth:`_load_model`. It does nothing if
    the pipeline already exists.

    Raises:
        ImportError: if ``transformers`` is not installed.
    """
    if self._pipeline is None:
        try:
            from transformers import pipeline
        except ImportError:
            raise ImportError(
                "需要安装 transformers。运行: pip install transformers"
            )

        # With no explicit device configured ("auto"), pass -1 (the pipeline API's CPU index).
        target_device = -1 if self._device == "auto" else self._device
        self._pipeline = pipeline(
            "text-generation",
            model=self.config.model,
            device=target_device
        )
|
|
874
|
+
|
|
875
|
+
def _convert_messages(self, messages: List[Dict[str, str]]) -> str:
|
|
876
|
+
"""将消息转换为提示词"""
|
|
877
|
+
prompt_parts = []
|
|
878
|
+
|
|
879
|
+
for msg in messages:
|
|
880
|
+
role = msg["role"]
|
|
881
|
+
content = msg["content"]
|
|
882
|
+
|
|
883
|
+
if role == "system":
|
|
884
|
+
prompt_parts.append(f"[INST] <<SYS>>\n{content}\n<</SYS>>\n\n[/INST]")
|
|
885
|
+
elif role == "user":
|
|
886
|
+
prompt_parts.append(f"[INST] {content} [/INST]")
|
|
887
|
+
elif role == "assistant":
|
|
888
|
+
prompt_parts.append(content)
|
|
889
|
+
|
|
890
|
+
return "".join(prompt_parts)
|
|
891
|
+
|
|
892
|
+
def call(
    self,
    messages: List[Dict[str, str]],
    **kwargs
) -> str:
    """Generate a reply to *messages* with the local model (blocking).

    Loads the model on first use, renders *messages* with
    ``_convert_messages``, then runs ``generate``. Sampling is enabled only
    when the effective temperature is positive; otherwise decoding is greedy
    and no temperature is passed.

    Keyword Args:
        max_tokens: overrides ``config.max_tokens`` (sent as ``max_new_tokens``).
        temperature: overrides ``config.temperature``.

    Returns:
        Only the newly generated text (prompt tokens are sliced off), decoded
        with special tokens skipped.
    """
    self._load_model()

    # _load_model() populates these on success; the asserts narrow the types.
    assert self._model is not None
    assert self._tokenizer is not None
    model, tokenizer = self._model, self._tokenizer

    encoded = tokenizer(self._convert_messages(messages), return_tensors="pt")

    # Move the encoded tensors next to the model unless running on CPU.
    if self._device != "cpu" and hasattr(encoded, "to"):
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    temperature = kwargs.get("temperature", self.config.temperature)
    options = {
        "max_new_tokens": kwargs.get("max_tokens", self.config.max_tokens),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        options["temperature"] = temperature
        options["do_sample"] = True
    else:
        # Greedy decoding: temperature is deliberately not forwarded.
        options["do_sample"] = False

    outputs = model.generate(**encoded, **options)

    prompt_length = encoded["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
|
|
938
|
+
|
|
939
|
+
def stream(
    self,
    messages: List[Dict[str, str]],
    **kwargs
) -> Iterator[StreamChunk]:
    """Stream a reply to *messages* from the local model.

    Yields a STARTED chunk, then one STREAMING chunk per piece of text that
    ``TextIteratorStreamer`` produces (``accumulated`` holds the text so far),
    then a COMPLETED chunk. If ``TextIteratorStreamer`` cannot be imported,
    falls back to :meth:`call` and emits the whole result as one STREAMING
    chunk.

    Keyword Args:
        max_tokens: overrides ``config.max_tokens`` (sent as ``max_new_tokens``).
        temperature: overrides ``config.temperature``; sampling only if > 0.

    Raises:
        Exception: whatever ``generate()`` raised on the worker thread is
            re-raised here after the stream is drained, instead of leaving
            the consumer blocked forever.
    """
    # Local-model streaming needs transformers' TextIteratorStreamer.
    try:
        from transformers import TextIteratorStreamer
        import threading
    except ImportError:
        # Streaming unsupported: return the complete result in one chunk.
        result = self.call(messages, **kwargs)
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")
        yield StreamChunk(content=result, status=StreamStatus.STREAMING, accumulated=result)
        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=result)
        return

    self._load_model()

    # _load_model() populates these on success; the asserts narrow the types.
    assert self._model is not None
    assert self._tokenizer is not None
    model = self._model

    prompt = self._convert_messages(messages)
    inputs = self._tokenizer(prompt, return_tensors="pt")

    # Same device handling as call().
    if self._device != "cpu" and hasattr(inputs, "to"):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens)
    temperature = kwargs.get("temperature", self.config.temperature)

    streamer = TextIteratorStreamer(
        self._tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "streamer": streamer,
        "pad_token_id": self._tokenizer.eos_token_id
    }

    if temperature > 0:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["do_sample"] = True
    else:
        generation_kwargs["do_sample"] = False

    # Exceptions raised on the worker thread, handed back to the consumer.
    errors: List[Exception] = []

    def _generate() -> None:
        """Run generate() on the worker thread, capturing any failure."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # A failed generate() never signals end-of-stream, so the
            # `for text in streamer` loop below would block forever.
            streamer.end()

    # Generation runs in a background thread; the streamer is fed from it.
    thread = threading.Thread(target=_generate)
    thread.start()

    accumulated = ""
    yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

    for text in streamer:
        accumulated += text
        yield StreamChunk(
            content=text,
            status=StreamStatus.STREAMING,
            accumulated=accumulated
        )

    thread.join()
    if errors:
        raise errors[0]
    yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)
|
|
1010
|
+
|
|
1011
|
+
async def call_async(
    self,
    messages: List[Dict[str, str]],
    **kwargs
) -> str:
    """Asynchronous wrapper around :meth:`call`.

    ``call`` blocks while the model generates, so it runs in the event
    loop's default executor. That keeps the loop free for other tasks.

    Returns:
        The generated text, exactly as returned by ``call``.
    """
    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine (get_event_loop() is discouraged there).
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        lambda: self.call(messages, **kwargs)
    )
|
|
1022
|
+
|
|
1023
|
+
async def stream_async(
    self,
    messages: List[Dict[str, str]],
    **kwargs
) -> "AsyncIterator[StreamChunk]":
    """Asynchronous version of :meth:`stream`.

    Each step of the synchronous ``stream`` iterator blocks until the model
    produces more text. Every ``next()`` therefore runs in the default
    executor rather than on the event-loop thread, so other tasks keep
    running during generation. Chunks arrive in the same order as
    ``stream`` yields them.
    """
    loop = asyncio.get_running_loop()
    chunks = iter(self.stream(messages, **kwargs))
    done = object()  # private sentinel marking exhaustion of `chunks`
    while True:
        chunk = await loop.run_in_executor(None, next, chunks, done)
        if chunk is done:
            break
        yield chunk
|
|
1031
|
+
|
|
1032
|
+
def call_with_tools(
    self,
    messages: List[Dict[str, str]],
    tools: List[dict],
    **kwargs
) -> Dict[str, Any]:
    """Emulate tool calling through the prompt (same approach as the Ollama provider).

    Function tools are listed in an extra system message appended after
    *messages*. The reply is then scanned for the outermost ``{...}`` span.
    If that span parses as JSON and has a ``"name"`` or ``"function"`` key,
    it becomes a single OpenAI-style tool call.

    Returns:
        ``{"content": reply, "tool_calls": [call] or None, "message": reply}``.
        Unparseable JSON leaves ``tool_calls`` as ``None``.
    """
    import re

    def describe(spec: dict) -> str:
        """Format one function spec as a bullet for the tool listing."""
        line = f"- {spec['name']}: {spec.get('description', 'No description')}"
        if "parameters" in spec:
            line += f"\n Parameters: {json.dumps(spec['parameters'])}"
        return line

    listing = "\n".join(
        describe(tool["function"])
        for tool in tools
        if tool.get("type") == "function"
    )
    tool_message = {
        "role": "system",
        "content": f"\n\nAvailable tools:\n{listing}\n\nTo use a tool, respond with a JSON object."
    }
    reply = self.call(messages + [tool_message], **kwargs)

    def extract_calls(text: str):
        """Return a one-element tool-call list parsed from *text*, or None."""
        # Greedy: spans from the first "{" to the last "}" in the reply.
        match = re.search(r'\{[\s\S]*\}', text)
        if not match:
            return None
        payload = json.loads(match.group())
        if "name" not in payload and "function" not in payload:
            return None
        return [{
            "id": f"call_{hash(text) % 10000}",
            "type": "function",
            "function": {
                "name": payload.get("name") or payload.get("function"),
                "arguments": json.dumps(payload.get("arguments", payload))
            }
        }]

    try:
        tool_calls = extract_calls(reply)
    except (json.JSONDecodeError, KeyError):
        # Best effort: a malformed reply just means "no tool call".
        tool_calls = None

    return {
        "content": reply,
        "tool_calls": tool_calls,
        "message": reply
    }
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
class ProviderFactory:
    """Factory that maps a :class:`ProviderType` to the provider class serving it.

    The registry is the class attribute ``_providers``. :meth:`register`
    mutates that dict in place, so registrations are shared by every caller
    in the process.
    """

    # Built-in registrations. ProviderType.CUSTOM has no entry here and must
    # be added via register() before create() can build it.
    _providers: Dict[ProviderType, type] = {
        ProviderType.OPENAI: OpenAIProvider,
        ProviderType.ANTHROPIC: AnthropicProvider,
        ProviderType.OLLAMA: OllamaProvider,
        ProviderType.HUGGINGFACE: TransformersProvider,
    }

    @classmethod
    def create(cls, config: ProviderConfig) -> BaseProvider:
        """Instantiate the provider registered for ``config.provider_type``.

        Args:
            config: configuration passed unchanged to the provider constructor.

        Returns:
            A new provider instance.

        Raises:
            ValueError: if no provider class is registered for the type.
        """
        wanted = config.provider_type
        provider_cls = cls._providers.get(wanted)
        if provider_cls is not None:
            return provider_cls(config)
        raise ValueError(f"Unknown provider type: {wanted}")

    @classmethod
    def register(cls, provider_type: ProviderType, provider_class: type) -> None:
        """Register (or replace) the provider class used for *provider_type*.

        Overwrites any existing registration for the same type, including the
        built-in ones.
        """
        cls._providers.update({provider_type: provider_class})
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
def create_provider(
    provider_type: str = "openai",
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    base_url: Optional[str] = None,
    **kwargs
) -> BaseProvider:
    """
    Create a provider instance.

    Args:
        provider_type: provider name, case-insensitive ("openai", "anthropic",
            "ollama", "huggingface", "custom"). A ``ProviderType`` member is
            also accepted.
        api_key: API key (required by OpenAI/Anthropic).
        model: model name; ``None`` is stored as ``""``.
        base_url: API base URL (Ollama defaults to http://localhost:11434).
        **kwargs: any other ``ProviderConfig`` fields.

    Returns:
        The provider instance built by ``ProviderFactory.create``.

    Raises:
        ValueError: if *provider_type* is not a known provider name, or no
            provider class is registered for it.

    Examples:
        # OpenAI
        provider = create_provider("openai", api_key="sk-...", model="gpt-4")

        # Anthropic
        provider = create_provider("anthropic", api_key="sk-ant-...", model="claude-3-opus")

        # Ollama local model
        provider = create_provider("ollama", model="llama2", base_url="http://localhost:11434")

        # HuggingFace Transformers
        provider = create_provider("huggingface", model="microsoft/DialoGPT-medium")
    """
    pt: Optional[ProviderType]
    if isinstance(provider_type, ProviderType):
        pt = provider_type
    else:
        # Derive the name lookup from the enum itself so it can never drift
        # out of sync with ProviderType's values.
        type_map = {member.value: member for member in ProviderType}
        pt = type_map.get(provider_type.lower())
    if pt is None:
        raise ValueError(f"Unknown provider type: {provider_type}")

    config = ProviderConfig(
        provider_type=pt,
        api_key=api_key,
        model=model or "",
        base_url=base_url,
        **kwargs
    )

    return ProviderFactory.create(config)
|