codegnipy 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codegnipy/providers.py ADDED
@@ -0,0 +1,1160 @@
1
+ """
2
+ Codegnipy 多提供商支持模块
3
+
4
+ 支持多种 LLM 提供商:OpenAI、Anthropic、本地模型等。
5
+ """
6
+
7
+ import asyncio
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass, field
10
+ from typing import Optional, List, Dict, Any, AsyncIterator, Iterator, TYPE_CHECKING
11
+ from enum import Enum
12
+ import json
13
+
14
+ from .streaming import StreamChunk, StreamStatus
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase
18
+
19
+
class ProviderType(Enum):
    """Identifiers for the LLM back-ends this module knows about.

    Each member's value is the lowercase name of the back-end.
    """

    # Hosted chat-completion APIs.
    OPENAI = "openai"
    ANTHROPIC = "anthropic"

    # Models running locally (or on a self-hosted server).
    OLLAMA = "ollama"
    HUGGINGFACE = "huggingface"

    # Placeholder for user-supplied provider implementations.
    CUSTOM = "custom"
27
+
28
+
@dataclass
class ProviderConfig:
    """Settings handed to a provider when it is constructed.

    Providers read these values as defaults; the request methods let a
    caller override ``model``, ``temperature`` and ``max_tokens`` per call.
    """

    # Which back-end the configuration is meant for.
    provider_type: ProviderType = ProviderType.OPENAI
    # Credential for hosted APIs; left as None when no key is supplied.
    api_key: Optional[str] = None
    # Model identifier passed through to the back-end.
    model: str = ""
    # Endpoint override; None lets the provider pick its own default.
    base_url: Optional[str] = None
    # Default sampling temperature.
    temperature: float = 0.7
    # Default upper bound on generated tokens.
    max_tokens: int = 1024
    # Free-form, provider-specific options (fresh dict per instance).
    extra_params: dict = field(default_factory=dict)
39
+
40
+
class BaseProvider(ABC):
    """Abstract interface implemented by every LLM provider.

    A provider is built from a ``ProviderConfig`` and exposes blocking,
    streaming, asynchronous and tool-calling entry points. All of them take
    chat messages as a list of ``{"role": ..., "content": ...}`` dicts and
    accept per-call overrides through ``**kwargs``.
    """

    def __init__(self, config: ProviderConfig):
        # Kept on the instance so subclasses can read defaults from it.
        self.config = config

    @abstractmethod
    def call(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """Run one blocking completion and return the reply text."""

    @abstractmethod
    def stream(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> Iterator[StreamChunk]:
        """Run a completion and yield the reply incrementally."""

    @abstractmethod
    async def call_async(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """Asynchronous counterpart of ``call``."""

    # Declared with a plain ``def``: the providers in this module implement it
    # as an async generator function, which hands back the async iterator
    # directly when called (no ``await`` needed by the caller).
    @abstractmethod
    def stream_async(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous counterpart of ``stream``."""

    @abstractmethod
    def call_with_tools(
        self, messages: List[Dict[str, str]], tools: List[dict], **kwargs
    ) -> Dict[str, Any]:
        """Run a completion that may request one of the supplied tools."""
92
+
93
+
class OpenAIProvider(BaseProvider):
    """Provider backed by the OpenAI chat-completions API.

    The ``openai`` package is imported lazily, so it only has to be installed
    once a request is actually made. Clients are created on first use and
    cached on the instance.
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._client = None
        self._async_client = None

    def _get_client(self):
        """Return the cached synchronous client, creating it on first use.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            # Only the import is guarded, so errors raised while constructing
            # the client are not misreported as a missing package.
            try:
                import openai
            except ImportError as exc:
                raise ImportError(
                    "需要安装 openai 包。运行: pip install openai"
                ) from exc
            self._client = openai.OpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url,
            )
        return self._client

    def _get_async_client(self):
        """Return the cached asynchronous client, creating it on first use.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._async_client is None:
            try:
                from openai import AsyncOpenAI
            except ImportError as exc:
                raise ImportError(
                    "需要安装 openai 包。运行: pip install openai"
                ) from exc
            self._async_client = AsyncOpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url,
            )
        return self._async_client

    def _build_params(
        self,
        messages: List[Dict[str, str]],
        overrides: Dict[str, Any],
        **fixed: Any
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``.

        Args:
            messages: Chat messages to send.
            overrides: The caller's ``**kwargs``; ``model``, ``temperature``
                and ``max_tokens`` take precedence over the config values.
            **fixed: Arguments a specific request type requires
                (``stream``, ``tools``, ``tool_choice``).

        Returns:
            A single dict, so a key that also appears in
            ``config.extra_params`` is merged instead of raising
            ``TypeError`` for a duplicate keyword argument.
        """
        params: Dict[str, Any] = {
            "model": overrides.get("model", self.config.model),
            "messages": messages,
            "temperature": overrides.get("temperature", self.config.temperature),
            "max_tokens": overrides.get("max_tokens", self.config.max_tokens),
        }
        params.update(self.config.extra_params)
        # Applied last: the response handling of each method depends on them.
        params.update(fixed)
        return params

    @staticmethod
    def _delta_text(chunk) -> str:
        """Return the text carried by a streamed chunk, or ``""`` if none."""
        if chunk.choices and chunk.choices[0].delta.content:
            return chunk.choices[0].delta.content
        return ""

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking completion and return the reply text.

        Returns an empty string when the reply has no text content.
        """
        client = self._get_client()
        response = client.chat.completions.create(
            **self._build_params(messages, kwargs)
        )
        # ``content`` can be None; keep the ``str`` return contract.
        return response.choices[0].message.content or ""

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Yield the reply incrementally.

        Emits one STARTED chunk, a STREAMING chunk per non-empty text delta
        and a final COMPLETED chunk carrying the full accumulated text.
        """
        client = self._get_client()
        response = client.chat.completions.create(
            **self._build_params(messages, kwargs, stream=True)
        )

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        for chunk in response:
            text = self._delta_text(chunk)
            if text:
                accumulated += text
                yield StreamChunk(
                    content=text,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated,
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Asynchronous counterpart of ``call``."""
        client = self._get_async_client()
        response = await client.chat.completions.create(
            **self._build_params(messages, kwargs)
        )
        return response.choices[0].message.content or ""

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous counterpart of ``stream``."""
        client = self._get_async_client()
        response = await client.chat.completions.create(
            **self._build_params(messages, kwargs, stream=True)
        )

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        async for chunk in response:
            text = self._delta_text(chunk)
            if text:
                accumulated += text
                yield StreamChunk(
                    content=text,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated,
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Run a completion that may request one of ``tools``.

        Args:
            messages: Chat messages to send.
            tools: Tool definitions, forwarded unchanged to the API.
            **kwargs: Per-call overrides; ``tool_choice`` defaults to
                ``"auto"``.

        Returns:
            A dict with ``content`` (reply text as returned by the API),
            ``tool_calls`` (as returned by the API) and ``message`` (the raw
            message object).
        """
        client = self._get_client()
        response = client.chat.completions.create(
            **self._build_params(
                messages,
                kwargs,
                tools=tools,
                tool_choice=kwargs.get("tool_choice", "auto"),
            )
        )

        message = response.choices[0].message
        return {
            "content": message.content,
            "tool_calls": message.tool_calls,
            "message": message,
        }
247
+
248
+
class AnthropicProvider(BaseProvider):
    """Provider backed by the Anthropic Messages API.

    The ``anthropic`` package is imported lazily. Incoming messages use the
    OpenAI-style role/content format and are converted before sending:
    system messages become the separate ``system`` parameter.
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._client = None
        self._async_client = None

    def _client_options(self) -> Dict[str, Any]:
        """Return constructor arguments shared by the sync and async client."""
        options: Dict[str, Any] = {"api_key": self.config.api_key}
        # Honour a configured endpoint override; omit it otherwise so the
        # SDK keeps its own default.
        if self.config.base_url:
            options["base_url"] = self.config.base_url
        return options

    def _get_client(self):
        """Return the cached synchronous client, creating it on first use.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        if self._client is None:
            try:
                import anthropic
            except ImportError as exc:
                raise ImportError(
                    "需要安装 anthropic 包。运行: pip install anthropic"
                ) from exc
            self._client = anthropic.Anthropic(**self._client_options())
        return self._client

    def _get_async_client(self):
        """Return the cached asynchronous client, creating it on first use.

        Raises:
            ImportError: if the ``anthropic`` package is not installed.
        """
        if self._async_client is None:
            try:
                import anthropic
            except ImportError as exc:
                raise ImportError(
                    "需要安装 anthropic 包。运行: pip install anthropic"
                ) from exc
            self._async_client = anthropic.AsyncAnthropic(**self._client_options())
        return self._async_client

    def _convert_messages(
        self,
        messages: List[Dict[str, str]]
    ) -> tuple[str, List[Dict[str, str]]]:
        """Split messages into a system prompt and the user/assistant turns.

        Returns:
            ``(system, turns)``. ``system`` is the text of all system
            messages joined by blank lines (empty string if there are none);
            ``turns`` keeps only ``user`` and ``assistant`` messages, in
            order. Messages with any other role are dropped.
        """
        system_parts = []
        converted = []

        for msg in messages:
            role = msg["role"]
            if role == "system":
                # Collect every system message instead of keeping only the
                # last one.
                system_parts.append(msg["content"])
            elif role in ("user", "assistant"):
                converted.append({"role": role, "content": msg["content"]})

        return "\n\n".join(system_parts), converted

    def _build_params(
        self,
        messages: List[Dict[str, str]],
        overrides: Dict[str, Any],
        **fixed: Any
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for a Messages API request.

        Args:
            messages: Chat messages in role/content format.
            overrides: The caller's ``**kwargs``; ``model``, ``max_tokens``
                and ``temperature`` take precedence over the config values.
            **fixed: Extra arguments for a specific request type (``tools``).

        ``config.extra_params`` is applied last and therefore overrides
        everything else.
        """
        system, converted = self._convert_messages(messages)

        params: Dict[str, Any] = {
            "model": overrides.get("model", self.config.model),
            "max_tokens": overrides.get("max_tokens", self.config.max_tokens),
            "messages": converted,
        }
        if system:
            params["system"] = system

        # Test against None rather than truthiness: a temperature of 0 is a
        # legitimate setting and must reach the API.
        temperature = overrides.get("temperature", self.config.temperature)
        if temperature is not None:
            params["temperature"] = temperature

        params.update(fixed)
        params.update(self.config.extra_params)
        return params

    @staticmethod
    def _response_text(response) -> str:
        """Concatenate the text of every text-bearing block in a response."""
        return "".join(
            block.text for block in response.content if hasattr(block, "text")
        )

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking completion and return the reply text.

        Returns an empty string when the reply contains no text block.
        """
        client = self._get_client()
        response = client.messages.create(**self._build_params(messages, kwargs))
        return self._response_text(response)

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Yield the reply incrementally.

        Emits one STARTED chunk, a STREAMING chunk per text fragment and a
        final COMPLETED chunk carrying the full accumulated text.
        """
        client = self._get_client()
        params = self._build_params(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        with client.messages.stream(**params) as stream:
            for text in stream.text_stream:
                accumulated += text
                yield StreamChunk(
                    content=text,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated,
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Asynchronous counterpart of ``call``."""
        client = self._get_async_client()
        response = await client.messages.create(
            **self._build_params(messages, kwargs)
        )
        return self._response_text(response)

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous counterpart of ``stream``."""
        client = self._get_async_client()
        params = self._build_params(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        async with client.messages.stream(**params) as stream:
            async for text in stream.text_stream:
                accumulated += text
                yield StreamChunk(
                    content=text,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated,
                )

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Run a completion that may request one of ``tools``.

        Args:
            messages: Chat messages in role/content format.
            tools: OpenAI-style tool definitions. Entries whose ``type`` is
                ``"function"`` are converted to Anthropic's
                ``name``/``description``/``input_schema`` form; other
                entries are ignored.
            **kwargs: Per-call overrides (``model``, ``max_tokens``,
                ``temperature``).

        Returns:
            A dict with ``content`` (all reply text), ``tool_calls``
            (OpenAI-style dicts with JSON-encoded ``arguments``, or None if
            the model requested no tool) and ``message`` (the raw response).
        """
        client = self._get_client()

        anthropic_tools = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            anthropic_tools.append({
                "name": func["name"],
                "description": func.get("description", ""),
                "input_schema": func.get("parameters", {}),
            })

        response = client.messages.create(
            **self._build_params(messages, kwargs, tools=anthropic_tools)
        )

        text_parts = []
        tool_calls = []
        for block in response.content:
            if hasattr(block, "text"):
                text_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append({
                    "id": block.id,
                    "type": "function",
                    "function": {
                        "name": block.name,
                        "arguments": json.dumps(block.input),
                    },
                })

        return {
            "content": "".join(text_parts),
            "tool_calls": tool_calls if tool_calls else None,
            "message": response,
        }
487
+
488
+
class OllamaProvider(BaseProvider):
    """Provider for models served by Ollama.

    Talks to an Ollama server over HTTP (``/api/generate``), so the server
    must be reachable at ``config.base_url``; when no URL is configured,
    ``http://localhost:11434`` is used. Synchronous requests use ``urllib``;
    asynchronous requests need the ``aiohttp`` package.

    Example:
        config = ProviderConfig(
            provider_type=ProviderType.OLLAMA,
            model="llama2",
            base_url="http://localhost:11434"
        )
        provider = OllamaProvider(config)
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._base_url = config.base_url or "http://localhost:11434"

    def _get_client(self):
        """Return the ``urllib.request`` module used for synchronous HTTP."""
        try:
            import urllib.request
            return urllib.request
        except ImportError:
            raise ImportError("需要 urllib 支持")

    def _url(self, endpoint: str) -> str:
        """Join the base URL and ``endpoint`` with exactly one slash."""
        return f"{self._base_url.rstrip('/')}/{endpoint.lstrip('/')}"

    def _make_request(
        self,
        endpoint: str,
        data: Dict[str, Any],
        stream: bool = False
    ) -> Any:
        """POST ``data`` as JSON to the Ollama server.

        Args:
            endpoint: API path, e.g. ``"/api/generate"``.
            data: Request payload.
            stream: When False (default) the server is asked for a single
                JSON object, which is returned decoded. When True the call
                is delegated to ``_make_stream_request`` and an iterator of
                decoded objects is returned.

        Raises:
            urllib.error.URLError: on connection or HTTP errors.
        """
        if stream:
            return self._make_stream_request(endpoint, data)

        import urllib.request

        # Ollama streams newline-delimited JSON unless told otherwise, and
        # that cannot be decoded as one document, so disable it explicitly.
        body = json.dumps({**data, "stream": False}).encode("utf-8")
        req = urllib.request.Request(
            self._url(endpoint),
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )

        with urllib.request.urlopen(req) as response:
            return json.loads(response.read().decode("utf-8"))

    def _make_stream_request(
        self,
        endpoint: str,
        data: Dict[str, Any]
    ) -> Iterator[Dict[str, Any]]:
        """POST ``data`` with streaming enabled and yield each JSON line."""
        import urllib.request

        body = json.dumps({**data, "stream": True}).encode("utf-8")
        req = urllib.request.Request(
            self._url(endpoint),
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )

        with urllib.request.urlopen(req) as response:
            for raw_line in response:
                line = raw_line.decode("utf-8").strip()
                if line:
                    yield json.loads(line)

    async def _make_async_request(
        self,
        endpoint: str,
        data: Dict[str, Any]
    ) -> Any:
        """Asynchronously POST ``data`` and return the decoded JSON reply.

        Raises:
            aiohttp.ClientResponseError: if the server answers with an HTTP
                error status (mirrors the synchronous path, which raises
                through ``urlopen``).
        """
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._url(endpoint),
                # Same reason as in _make_request: ask for a single object.
                json={**data, "stream": False},
                headers={"Content-Type": "application/json"}
            ) as response:
                response.raise_for_status()
                return await response.json()

    async def _make_async_stream_request(
        self,
        endpoint: str,
        data: Dict[str, Any]
    ) -> AsyncIterator[Dict[str, Any]]:
        """Asynchronously POST ``data`` with streaming and yield JSON lines.

        Raises:
            aiohttp.ClientResponseError: on an HTTP error status.
        """
        import aiohttp

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._url(endpoint),
                json={**data, "stream": True},
                headers={"Content-Type": "application/json"}
            ) as response:
                response.raise_for_status()
                async for raw_line in response.content:
                    line = raw_line.decode("utf-8").strip()
                    if line:
                        yield json.loads(line)

    def _convert_messages(self, messages: List[Dict[str, str]]) -> str:
        """Flatten chat messages into one plain-text prompt.

        Each system/user/assistant message becomes a ``"<Role>: <content>"``
        paragraph; messages with other roles are skipped. A trailing
        ``"Assistant:"`` cue is appended for the model to continue from.
        """
        labels = {"system": "System", "user": "User", "assistant": "Assistant"}
        prompt_parts = [
            f"{labels[msg['role']]}: {msg['content']}"
            for msg in messages
            if msg["role"] in labels
        ]
        prompt_parts.append("Assistant:")
        return "\n\n".join(prompt_parts)

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        overrides: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Build the ``/api/generate`` payload shared by all request methods.

        ``overrides`` is the caller's ``**kwargs``; ``max_tokens`` is sent as
        Ollama's ``num_predict`` option. ``config.extra_params`` is merged
        into ``options``.
        """
        options = {
            "temperature": overrides.get("temperature", self.config.temperature),
            "num_predict": overrides.get("max_tokens", self.config.max_tokens),
        }
        if self.config.extra_params:
            options.update(self.config.extra_params)

        return {
            "model": overrides.get("model", self.config.model),
            "prompt": self._convert_messages(messages),
            "options": options,
        }

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run one blocking completion and return the reply text."""
        result = self._make_request(
            "/api/generate", self._build_payload(messages, kwargs)
        )
        return result.get("response", "")

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Yield the reply incrementally.

        Emits one STARTED chunk, a STREAMING chunk per server object that
        has a ``response`` field, and a final COMPLETED chunk. Stops reading
        once the server marks an object as ``done``.
        """
        data = self._build_payload(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        for chunk in self._make_stream_request("/api/generate", data):
            if "response" in chunk:
                content = chunk["response"]
                accumulated += content
                yield StreamChunk(
                    content=content,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated,
                )

            if chunk.get("done", False):
                break

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Asynchronous counterpart of ``call`` (requires ``aiohttp``)."""
        result = await self._make_async_request(
            "/api/generate", self._build_payload(messages, kwargs)
        )
        return result.get("response", "")

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous counterpart of ``stream`` (requires ``aiohttp``)."""
        data = self._build_payload(messages, kwargs)

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        async for chunk in self._make_async_stream_request("/api/generate", data):
            if "response" in chunk:
                content = chunk["response"]
                accumulated += content
                yield StreamChunk(
                    content=content,
                    status=StreamStatus.STREAMING,
                    accumulated=accumulated,
                )

            if chunk.get("done", False):
                break

        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    @staticmethod
    def _parse_tool_call(response: str) -> Optional[List[Dict[str, Any]]]:
        """Extract an OpenAI-style tool call from free-form model output.

        Looks for the span from the first ``{`` to the last ``}`` and parses
        it as JSON. Returns a one-element list if the object has a ``name``
        or ``function`` key, otherwise None.
        """
        import re
        import uuid

        match = re.search(r'\{[\s\S]*\}', response)
        if match is None:
            return None
        try:
            parsed = json.loads(match.group())
        except json.JSONDecodeError:
            return None
        if "name" not in parsed and "function" not in parsed:
            return None

        return [{
            # A random id: hash() of a str is salted per process and a
            # 4-digit modulus collides easily.
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {
                "name": parsed.get("name") or parsed.get("function"),
                "arguments": json.dumps(parsed.get("arguments", parsed)),
            },
        }]

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Emulate tool calling through prompting.

        This provider uses the plain generate endpoint, so the tool
        descriptions are appended as an extra system message and the reply
        is scanned for a JSON object naming a tool.

        Args:
            messages: Chat messages; the list itself is not modified.
            tools: OpenAI-style tool definitions; only entries with
                ``type == "function"`` are described to the model.
            **kwargs: Per-call overrides forwarded to ``call``.

        Returns:
            A dict with ``content`` and ``message`` (both the raw reply
            text) and ``tool_calls`` (a one-element list, or None if no tool
            call could be parsed).
        """
        tool_descriptions = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            desc = f"- {func['name']}: {func.get('description', 'No description')}"
            if "parameters" in func:
                desc += f"\n  Parameters: {json.dumps(func['parameters'])}"
            tool_descriptions.append(desc)

        tool_prompt = "\n".join(tool_descriptions)
        enhanced_messages = messages + [{
            "role": "system",
            "content": f"\n\nAvailable tools:\n{tool_prompt}\n\nTo use a tool, respond with a JSON object."
        }]

        response = self.call(enhanced_messages, **kwargs)

        return {
            "content": response,
            "tool_calls": self._parse_tool_call(response),
            "message": response,
        }

    def list_models(self) -> List[str]:
        """Return the model names reported by ``/api/tags``.

        Best effort by design: any failure (server unreachable, unexpected
        payload) yields an empty list instead of an exception.
        """
        import urllib.request

        try:
            req = urllib.request.Request(self._url("/api/tags"), method="GET")
            with urllib.request.urlopen(req) as response:
                data = json.loads(response.read().decode("utf-8"))
            return [model["name"] for model in data.get("models", [])]
        except Exception:
            return []
804
+
805
+
class TransformersProvider(BaseProvider):
    """Provider that runs a HuggingFace ``transformers`` model in-process.

    The model and tokenizer are loaded lazily on the first request. The
    target device is read from ``config.extra_params["device"]`` and
    defaults to ``"auto"``.

    Example:
        config = ProviderConfig(
            provider_type=ProviderType.HUGGINGFACE,
            model="microsoft/DialoGPT-medium",
            extra_params={"device": "cuda"}  # or "cpu", "mps"
        )
        provider = TransformersProvider(config)
    """

    def __init__(self, config: ProviderConfig):
        super().__init__(config)
        self._model: Optional["PreTrainedModel"] = None
        self._tokenizer: Optional["PreTrainedTokenizerBase"] = None
        self._device = config.extra_params.get("device", "auto")
        self._pipeline = None

    def _load_model(self):
        """Load the tokenizer and model on first use.

        Raises:
            ImportError: if ``transformers`` cannot be imported.
            RuntimeError: if loading the tokenizer or model fails.
        """
        if self._model is not None:
            return

        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "需要安装 transformers 和 torch。运行: pip install transformers torch"
            ) from exc

        model_name = self.config.model

        try:
            self._tokenizer = AutoTokenizer.from_pretrained(model_name)

            self._model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=self._device if self._device != "auto" else None
            )

            if self._device != "auto" and hasattr(self._model, "to"):
                self._model = self._model.to(self._device)

        except Exception as e:
            raise RuntimeError(f"加载模型失败: {e}") from e

    def _load_pipeline(self):
        """Create a ``text-generation`` pipeline on first use.

        Raises:
            ImportError: if ``transformers`` cannot be imported.
        """
        if self._pipeline is not None:
            return

        try:
            from transformers import pipeline
        except ImportError as exc:
            raise ImportError(
                "需要安装 transformers。运行: pip install transformers"
            ) from exc

        self._pipeline = pipeline(
            "text-generation",
            model=self.config.model,
            device=self._device if self._device != "auto" else -1
        )

    def _convert_messages(self, messages: List[Dict[str, str]]) -> str:
        """Flatten chat messages into one ``[INST]``-tagged prompt string.

        System and user messages are wrapped in ``[INST] ... [/INST]``
        (system text additionally in ``<<SYS>>`` markers); assistant
        messages are inserted verbatim; other roles are skipped.
        """
        prompt_parts = []

        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                prompt_parts.append(f"[INST] <<SYS>>\n{content}\n<</SYS>>\n\n[/INST]")
            elif role == "user":
                prompt_parts.append(f"[INST] {content} [/INST]")
            elif role == "assistant":
                prompt_parts.append(content)

        return "".join(prompt_parts)

    def _prepare_generation(
        self,
        messages: List[Dict[str, str]],
        overrides: Dict[str, Any]
    ):
        """Tokenize ``messages`` and build the arguments for ``generate``.

        Shared by ``call`` and ``stream`` so both place inputs on the
        model's device and choose sampling settings the same way.

        Returns:
            ``(inputs, generation_kwargs)`` where ``inputs`` is the dict of
            tokenizer outputs moved to the model's device and
            ``generation_kwargs`` contains those inputs plus
            ``max_new_tokens``, ``pad_token_id`` and the sampling settings.
        """
        self._load_model()

        # For type checkers: _load_model either sets both or raises.
        assert self._model is not None
        assert self._tokenizer is not None

        prompt = self._convert_messages(messages)
        encoded = self._tokenizer(prompt, return_tensors="pt")
        # Moving to the model's own device works whichever device was
        # configured, including CPU.
        inputs = {
            name: tensor.to(self._model.device)
            for name, tensor in encoded.items()
        }

        generation_kwargs: Dict[str, Any] = {
            **inputs,
            "max_new_tokens": overrides.get("max_tokens", self.config.max_tokens),
            "pad_token_id": self._tokenizer.eos_token_id,
        }

        temperature = overrides.get("temperature", self.config.temperature)
        if temperature > 0:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["do_sample"] = True
        else:
            # A non-positive temperature selects deterministic decoding.
            generation_kwargs["do_sample"] = False

        return inputs, generation_kwargs

    def call(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Generate a reply and return only the newly generated text."""
        inputs, generation_kwargs = self._prepare_generation(messages, kwargs)
        assert self._model is not None
        assert self._tokenizer is not None

        outputs = self._model.generate(**generation_kwargs)

        # The output sequence starts with the prompt tokens; strip them.
        prompt_length = inputs["input_ids"].shape[1]
        return self._tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )

    def stream(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> Iterator[StreamChunk]:
        """Yield the reply incrementally.

        Generation runs in a background thread that feeds a
        ``TextIteratorStreamer``. If the streamer cannot be imported, the
        complete result of ``call`` is emitted as a single STREAMING chunk.

        Raises:
            Exception: whatever ``generate`` raised in the background
                thread, re-raised here after the already produced text has
                been yielded.
        """
        try:
            from transformers import TextIteratorStreamer
            import threading
        except ImportError:
            result = self.call(messages, **kwargs)
            yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")
            yield StreamChunk(content=result, status=StreamStatus.STREAMING, accumulated=result)
            yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=result)
            return

        _, generation_kwargs = self._prepare_generation(messages, kwargs)
        assert self._model is not None
        assert self._tokenizer is not None
        model = self._model

        streamer = TextIteratorStreamer(
            self._tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer

        failures = []

        def _generate():
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                # Without this the consumer loop below would wait forever:
                # the streamer only stops once it is told generation ended.
                failures.append(exc)
                streamer.end()

        thread = threading.Thread(target=_generate)
        thread.start()

        accumulated = ""
        yield StreamChunk(content="", status=StreamStatus.STARTED, accumulated="")

        for text in streamer:
            accumulated += text
            yield StreamChunk(
                content=text,
                status=StreamStatus.STREAMING,
                accumulated=accumulated
            )

        thread.join()
        if failures:
            raise failures[0]
        yield StreamChunk(content="", status=StreamStatus.COMPLETED, accumulated=accumulated)

    async def call_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run ``call`` in the default executor so the event loop stays free."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.call(messages, **kwargs)
        )

    async def stream_async(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> AsyncIterator[StreamChunk]:
        """Asynchronous counterpart of ``stream``.

        Every step of the synchronous generator (model loading and waiting
        for the next token) is executed in the default executor, so the
        event loop is not blocked while the model is working.
        """
        loop = asyncio.get_running_loop()
        chunks = self.stream(messages, **kwargs)
        # StopIteration cannot travel through a Future, so exhaustion is
        # signalled with a sentinel default for next().
        exhausted = object()

        while True:
            chunk = await loop.run_in_executor(None, next, chunks, exhausted)
            if chunk is exhausted:
                break
            yield chunk

    @staticmethod
    def _parse_tool_call(response: str) -> Optional[List[Dict[str, Any]]]:
        """Extract an OpenAI-style tool call from free-form model output.

        Looks for the span from the first ``{`` to the last ``}`` and parses
        it as JSON. Returns a one-element list if the object has a ``name``
        or ``function`` key, otherwise None.
        """
        import re
        import uuid

        match = re.search(r'\{[\s\S]*\}', response)
        if match is None:
            return None
        try:
            parsed = json.loads(match.group())
        except json.JSONDecodeError:
            return None
        if "name" not in parsed and "function" not in parsed:
            return None

        return [{
            # A random id: hash() of a str is salted per process and a
            # 4-digit modulus collides easily.
            "id": f"call_{uuid.uuid4().hex[:8]}",
            "type": "function",
            "function": {
                "name": parsed.get("name") or parsed.get("function"),
                "arguments": json.dumps(parsed.get("arguments", parsed)),
            },
        }]

    def call_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[dict],
        **kwargs
    ) -> Dict[str, Any]:
        """Emulate tool calling through prompting.

        The tool descriptions are appended as an extra system message and
        the reply is scanned for a JSON object naming a tool.

        Args:
            messages: Chat messages; the list itself is not modified.
            tools: OpenAI-style tool definitions; only entries with
                ``type == "function"`` are described to the model.
            **kwargs: Per-call overrides forwarded to ``call``.

        Returns:
            A dict with ``content`` and ``message`` (both the raw reply
            text) and ``tool_calls`` (a one-element list, or None if no tool
            call could be parsed).
        """
        tool_descriptions = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            desc = f"- {func['name']}: {func.get('description', 'No description')}"
            if "parameters" in func:
                desc += f"\n  Parameters: {json.dumps(func['parameters'])}"
            tool_descriptions.append(desc)

        tool_prompt = "\n".join(tool_descriptions)
        enhanced_messages = messages + [{
            "role": "system",
            "content": f"\n\nAvailable tools:\n{tool_prompt}\n\nTo use a tool, respond with a JSON object."
        }]

        response = self.call(enhanced_messages, **kwargs)

        return {
            "content": response,
            "tool_calls": self._parse_tool_call(response),
            "message": response,
        }
1079
+
1080
+
class ProviderFactory:
    """Registry that maps a ``ProviderType`` to its provider class."""

    # Built-in providers; ``register`` adds to or replaces entries here.
    _providers: Dict[ProviderType, type] = {
        ProviderType.OPENAI: OpenAIProvider,
        ProviderType.ANTHROPIC: AnthropicProvider,
        ProviderType.OLLAMA: OllamaProvider,
        ProviderType.HUGGINGFACE: TransformersProvider,
    }

    @classmethod
    def create(cls, config: ProviderConfig) -> BaseProvider:
        """Instantiate the provider registered for ``config.provider_type``.

        Raises:
            ValueError: if no provider class is registered for that type.
        """
        provider_class = cls._providers.get(config.provider_type)
        if provider_class is not None:
            return provider_class(config)
        raise ValueError(f"Unknown provider type: {config.provider_type}")

    @classmethod
    def register(cls, provider_type: ProviderType, provider_class: type) -> None:
        """Register ``provider_class`` for ``provider_type``.

        An existing registration for the same type is replaced.
        """
        cls._providers[provider_type] = provider_class
1105
+
1106
+
def create_provider(
    provider_type: str = "openai",
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    base_url: Optional[str] = None,
    **kwargs
) -> BaseProvider:
    """
    Build a provider from plain keyword arguments.

    Args:
        provider_type: Back-end name, case-insensitive ("openai",
            "anthropic", "ollama", "huggingface", "custom").
        api_key: API key (needed by OpenAI/Anthropic).
        model: Model name; None is stored as an empty string.
        base_url: API base URL (Ollama falls back to http://localhost:11434).
        **kwargs: Any further ``ProviderConfig`` fields.

    Returns:
        The provider instance created by ``ProviderFactory``.

    Raises:
        ValueError: if ``provider_type`` is not a known name, or no provider
            class is registered for it.

    Examples:
        # OpenAI
        provider = create_provider("openai", api_key="sk-...", model="gpt-4")

        # Anthropic
        provider = create_provider("anthropic", api_key="sk-ant-...", model="claude-3-opus")

        # Local Ollama model
        provider = create_provider("ollama", model="llama2", base_url="http://localhost:11434")

        # HuggingFace Transformers
        provider = create_provider("huggingface", model="microsoft/DialoGPT-medium")
    """
    # Every accepted name is exactly the value of one enum member.
    known_types = {member.value: member for member in ProviderType}

    resolved = known_types.get(provider_type.lower())
    if resolved is None:
        raise ValueError(f"Unknown provider type: {provider_type}")

    return ProviderFactory.create(
        ProviderConfig(
            provider_type=resolved,
            api_key=api_key,
            model=model or "",
            base_url=base_url,
            **kwargs
        )
    )