gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,57 @@
+ """Base class for inference engines."""
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+ import torch
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseEngine(ABC):
+     """Base class for inference engines - every engine must inherit from this class."""
+
+     def __init__(self, config: Dict[str, Any]):
+         self.config = config
+         self.model = None
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.loaded = False
+
+     @abstractmethod
+     def load_model(self) -> None:
+         """Load the model into memory."""
+         pass
+
+     @abstractmethod
+     def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         """Run inference."""
+         pass
+
+     @abstractmethod
+     def unload_model(self) -> None:
+         """Unload the model and free memory."""
+         pass
+
+     def get_status(self) -> Dict[str, Any]:
+         """Return the engine status."""
+         status = {
+             "loaded": self.loaded,
+             "device": self.device,
+         }
+
+         if torch.cuda.is_available():
+             status["gpu"] = {
+                 "name": torch.cuda.get_device_name(0),
+                 "memory_used_gb": torch.cuda.memory_allocated() / 1024**3,
+                 "memory_total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
+             }
+
+         return status
+
+     def _get_gpu_memory(self) -> Optional[Dict[str, float]]:
+         """Return GPU memory information."""
+         if torch.cuda.is_available():
+             return {
+                 "used_gb": torch.cuda.memory_allocated() / 1024**3,
+                 "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
+             }
+         return None
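
To show how the rest of the package builds on this contract, here is a minimal sketch of a concrete subclass; EchoEngine and its echo behaviour are hypothetical and exist only to illustrate the BaseEngine interface:

from typing import Any, Dict

class EchoEngine(BaseEngine):
    """Hypothetical engine used only to illustrate the BaseEngine contract."""

    def load_model(self) -> None:
        self.model = object()      # stand-in for a real model handle
        self.loaded = True

    def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
        # A real engine would run its model here; this one just echoes the input.
        return {"echo": params}

    def unload_model(self) -> None:
        self.model = None
        self.loaded = False

engine = EchoEngine({"model_id": "none"})
engine.load_model()
print(engine.get_status())         # reports "loaded", "device", and GPU details when CUDA is available
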
@@ -0,0 +1,83 @@
+ """Image generation inference engine."""
+ from typing import Dict, Any
+ import torch
+ import io
+ import base64
+ import logging
+
+ from .base import BaseEngine
+
+ logger = logging.getLogger(__name__)
+
+
+ class ImageGenEngine(BaseEngine):
+     """Image generation inference engine."""
+
+     def load_model(self) -> None:
+         """Load the image generation model."""
+         from diffusers import DiffusionPipeline
+
+         model_id = self.config.get("model_id", "Zhihu-ai/Z-Image-Turbo")
+         logger.info(f"Loading image generation model: {model_id}")
+
+         # Load the pipeline
+         self.pipe = DiffusionPipeline.from_pretrained(
+             model_id,
+             torch_dtype=torch.bfloat16
+         )
+
+         # Memory optimization
+         if self.config.get("enable_cpu_offload", True):
+             self.pipe.enable_sequential_cpu_offload()
+         else:
+             self.pipe = self.pipe.to(self.device)
+
+         self.loaded = True
+         logger.info("Image generation model loaded successfully")
+
+     def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         """Run image generation."""
+         prompt = params.get("prompt", "")
+         negative_prompt = params.get("negative_prompt", "")
+         width = params.get("width", 1024)
+         height = params.get("height", 1024)
+         steps = params.get("steps", 4)
+         seed = params.get("seed", None)
+
+         # Set the random seed
+         generator = None
+         if seed is not None:
+             generator = torch.Generator(device="cpu").manual_seed(seed)
+
+         # Generate the image
+         result = self.pipe(
+             prompt=prompt,
+             negative_prompt=negative_prompt if negative_prompt else None,
+             width=width,
+             height=height,
+             num_inference_steps=steps,
+             generator=generator
+         )
+
+         image = result.images[0]
+
+         # Convert to base64
+         buffered = io.BytesIO()
+         image.save(buffered, format="PNG")
+         image_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+         return {
+             "image_base64": image_base64,
+             "width": width,
+             "height": height,
+             "seed": seed,
+             "format": "png"
+         }
+
+     def unload_model(self) -> None:
+         """Unload the model."""
+         if hasattr(self, "pipe"):
+             del self.pipe
+         torch.cuda.empty_cache()
+         self.loaded = False
+         logger.info("Image generation model unloaded")
package/engines/llm.py ADDED
@@ -0,0 +1,97 @@
+ """LLM inference engine."""
+ from typing import Dict, Any, List
+ import torch
+ import logging
+
+ from .base import BaseEngine
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMEngine(BaseEngine):
+     """Large language model inference engine."""
+
+     def load_model(self) -> None:
+         """Load the LLM model."""
+         from transformers import AutoModelForCausalLM, AutoTokenizer
+
+         model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
+         logger.info(f"Loading LLM model: {model_id}")
+
+         # Load the tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             trust_remote_code=True
+         )
+
+         # Load the model
+         load_kwargs = {
+             "torch_dtype": torch.bfloat16,
+             "trust_remote_code": True,
+         }
+
+         if self.config.get("enable_cpu_offload", True):
+             load_kwargs["device_map"] = "auto"
+         else:
+             load_kwargs["device_map"] = {"": self.device}
+
+         self.model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
+
+         self.loaded = True
+         logger.info("LLM model loaded successfully")
+
+     def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         """Run LLM inference."""
+         messages = params.get("messages", [])
+         max_tokens = params.get("max_tokens", 2048)
+         temperature = params.get("temperature", 0.7)
+         top_p = params.get("top_p", 0.9)
+
+         # Format the input
+         text = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
+         input_length = inputs.input_ids.shape[1]
+
+         # Generate
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=temperature > 0,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         # Decode the response
+         response = self.tokenizer.decode(
+             outputs[0][input_length:],
+             skip_special_tokens=True
+         )
+
+         output_length = outputs.shape[1] - input_length
+
+         return {
+             "response": response,
+             "usage": {
+                 "prompt_tokens": input_length,
+                 "completion_tokens": output_length,
+                 "total_tokens": input_length + output_length
+             }
+         }
+
+     def unload_model(self) -> None:
+         """Unload the model."""
+         if self.model:
+             del self.model
+             self.model = None
+         if hasattr(self, "tokenizer"):
+             del self.tokenizer
+         torch.cuda.empty_cache()
+         self.loaded = False
+         logger.info("LLM model unloaded")
@@ -0,0 +1,216 @@
+ """Extended LLM engine base class - supports high-performance inference backends."""
+ from abc import abstractmethod
+ from typing import Dict, Any, List, Optional, AsyncIterator
+ from dataclasses import dataclass
+ from enum import Enum
+ import asyncio
+ import threading
+ from concurrent.futures import Future
+ import logging
+
+ from .base import BaseEngine
+
+ logger = logging.getLogger(__name__)
+
+
+ class LLMBackend(Enum):
+     """LLM inference backend type."""
+     NATIVE = "native"    # Native Transformers
+     SGLANG = "sglang"    # SGLang (recommended)
+     VLLM = "vllm"        # vLLM
+
+
+ @dataclass
+ class GenerationConfig:
+     """Generation configuration."""
+     max_tokens: int = 2048
+     temperature: float = 0.7
+     top_p: float = 0.9
+     top_k: int = 50
+     stop_sequences: Optional[List[str]] = None
+     stream: bool = False
+
+
+ @dataclass
+ class GenerationResult:
+     """Generation result."""
+     text: str
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+     finish_reason: str = "stop"
+     cached_tokens: int = 0  # Number of tokens served from the prefix cache
+
+
+ class LLMBaseEngine(BaseEngine):
+     """
+     Extended LLM engine base class.
+
+     Provides a unified interface over high-performance inference backends, supporting:
+     - asynchronous generation
+     - batch processing
+     - streaming output
+     - prefix caching
+     """
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.backend_type: LLMBackend = LLMBackend.NATIVE
+         self.tokenizer = None
+         self._batch_processor = None
+
+     @abstractmethod
+     async def generate_async(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> GenerationResult:
+         """
+         Asynchronous generation interface.
+
+         Args:
+             messages: List of chat messages [{"role": "user", "content": "..."}]
+             config: Generation configuration
+
+         Returns:
+             The generation result
+         """
+         pass
+
+     @abstractmethod
+     async def batch_generate(
+         self,
+         batch_messages: List[List[Dict[str, str]]],
+         config: Optional[GenerationConfig] = None
+     ) -> List[GenerationResult]:
+         """
+         Batch generation interface.
+
+         Args:
+             batch_messages: A batch of chat message lists
+             config: Generation configuration
+
+         Returns:
+             A list of generation results
+         """
+         pass
+
+     async def stream_generate(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> AsyncIterator[str]:
+         """
+         Streaming generation interface (default implementation: yields the full result once).
+
+         Args:
+             messages: List of chat messages
+             config: Generation configuration
+
+         Yields:
+             Generated text fragments
+         """
+         result = await self.generate_async(messages, config)
+         yield result.text
+
+     @staticmethod
+     def _run_coroutine_in_new_thread(coro):
+         future: Future = Future()
+
+         def runner() -> None:
+             try:
+                 future.set_result(asyncio.run(coro))
+             except BaseException as exc:
+                 future.set_exception(exc)
+
+         threading.Thread(target=runner, daemon=True).start()
+         return future.result()
+
+     def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Synchronous inference interface (compatible with BaseEngine).
+
+         Wraps the async interface for synchronous callers.
+         """
+         messages = params.get("messages", [])
+         config = GenerationConfig(
+             max_tokens=params.get("max_tokens", 2048),
+             temperature=params.get("temperature", 0.7),
+             top_p=params.get("top_p", 0.9),
+             top_k=params.get("top_k", 50),
+             stop_sequences=params.get("stop", None),
+             stream=params.get("stream", False)
+         )
+
+         try:
+             asyncio.get_running_loop()
+         except RuntimeError:
+             result = asyncio.run(self.generate_async(messages, config))
+         else:
+             result = self._run_coroutine_in_new_thread(self.generate_async(messages, config))
+
+         return {
+             "response": result.text,
+             "usage": {
+                 "prompt_tokens": result.prompt_tokens,
+                 "completion_tokens": result.completion_tokens,
+                 "total_tokens": result.total_tokens,
+                 "cached_tokens": result.cached_tokens
+             },
+             "finish_reason": result.finish_reason
+         }
+
+     def supports_streaming(self) -> bool:
+         """Whether streaming output is supported."""
+         return False
+
+     def supports_prefix_caching(self) -> bool:
+         """Whether prefix caching is supported."""
+         return False
+
+     def supports_batch_inference(self) -> bool:
+         """Whether batch inference is supported."""
+         return False
+
+     def get_backend_info(self) -> Dict[str, Any]:
+         """Return backend information."""
+         return {
+             "backend": self.backend_type.value,
+             "supports_streaming": self.supports_streaming(),
+             "supports_prefix_caching": self.supports_prefix_caching(),
+             "supports_batch_inference": self.supports_batch_inference()
+         }
+
+     def get_status(self) -> Dict[str, Any]:
+         """Return the engine status (extends the parent class)."""
+         status = super().get_status()
+         status["backend_info"] = self.get_backend_info()
+         return status
+
+
+ def create_llm_engine(config: Dict[str, Any]) -> LLMBaseEngine:
+     """
+     LLM engine factory function.
+
+     Creates the matching LLM engine instance based on the configuration.
+
+     Args:
+         config: Engine configuration
+
+     Returns:
+         An LLM engine instance
+     """
+     backend = config.get("backend", "native").lower()
+
+     if backend == "sglang":
+         from .llm_sglang import SGLangEngine
+         return SGLangEngine(config)
+
+     elif backend == "vllm":
+         from .llm_vllm import VLLMEngine
+         return VLLMEngine(config)
+
+     else:
+         # Default to native Transformers
+         from .llm import LLMEngine
+         return LLMEngine(config)
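
To close the loop, a hedged sketch of selecting a backend through the factory; the config values are illustrative, while "backend" and the other keys are the ones read by create_llm_engine and the engines above.

config = {
    "backend": "native",   # or "sglang" / "vllm" when those backends are installed
    "model_id": "Qwen/Qwen2.5-7B-Instruct",
}

engine = create_llm_engine(config)
engine.load_model()

reply = engine.inference({
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 128,
})
print(reply["response"])
print(engine.get_status())   # LLMBaseEngine subclasses also report backend_info here
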