isa-model 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.2.0.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/top_level.txt +0 -0
isa_model/__init__.py
CHANGED
isa_model/core/model_registry.py
CHANGED
@@ -3,6 +3,9 @@ from enum import Enum
 import logging
 from pathlib import Path
 import json
+import sqlite3
+from datetime import datetime
+import threading
 
 logger = logging.getLogger(__name__)
 
@@ -29,27 +32,45 @@ class ModelType(str, Enum):
     VISION = "vision"
 
 class ModelRegistry:
-    """
+    """SQLite-based registry for model metadata and capabilities"""
 
-    def __init__(self,
-        self.
-        self.
-        self.
+    def __init__(self, db_path: str = "./models/model_registry.db"):
+        self.db_path = Path(db_path)
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.Lock()
+        self._initialize_database()
 
-    def
-    """
-
-
-
-
-
-
-
-
-
-
-
+    def _initialize_database(self):
+        """Initialize SQLite database with required tables"""
+        with sqlite3.connect(self.db_path) as conn:
+            conn.execute("""
+                CREATE TABLE IF NOT EXISTS models (
+                    model_id TEXT PRIMARY KEY,
+                    model_type TEXT NOT NULL,
+                    metadata TEXT,
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
+
+            conn.execute("""
+                CREATE TABLE IF NOT EXISTS model_capabilities (
+                    model_id TEXT,
+                    capability TEXT,
+                    PRIMARY KEY (model_id, capability),
+                    FOREIGN KEY (model_id) REFERENCES models(model_id) ON DELETE CASCADE
+                )
+            """)
+
+            conn.execute("""
+                CREATE INDEX IF NOT EXISTS idx_model_type ON models(model_type)
+            """)
+
+            conn.execute("""
+                CREATE INDEX IF NOT EXISTS idx_capability ON model_capabilities(capability)
+            """)
+
+            conn.commit()
 
     def register_model(self,
                        model_id: str,
@@ -58,14 +79,30 @@ class ModelRegistry:
                        metadata: Dict[str, Any]) -> bool:
         """Register a model with its capabilities and metadata"""
         try:
-            self.
-
-
-
-
-
+            with self._lock:
+                with sqlite3.connect(self.db_path) as conn:
+                    # Insert or update model
+                    conn.execute("""
+                        INSERT OR REPLACE INTO models
+                        (model_id, model_type, metadata, updated_at)
+                        VALUES (?, ?, ?, CURRENT_TIMESTAMP)
+                    """, (model_id, model_type.value, json.dumps(metadata)))
+
+                    # Clear existing capabilities
+                    conn.execute("DELETE FROM model_capabilities WHERE model_id = ?", (model_id,))
+
+                    # Insert new capabilities
+                    for capability in capabilities:
+                        conn.execute("""
+                            INSERT INTO model_capabilities (model_id, capability)
+                            VALUES (?, ?)
+                        """, (model_id, capability.value))
+
+                    conn.commit()
+
             logger.info(f"Registered model {model_id}")
             return True
+
         except Exception as e:
             logger.error(f"Failed to register model {model_id}: {e}")
             return False
@@ -73,43 +110,233 @@ class ModelRegistry:
 
     def unregister_model(self, model_id: str) -> bool:
         """Unregister a model"""
         try:
-
-
-
-
-
-
+            with self._lock:
+                with sqlite3.connect(self.db_path) as conn:
+                    cursor = conn.execute("DELETE FROM models WHERE model_id = ?", (model_id,))
+                    conn.commit()
+
+                    if cursor.rowcount > 0:
+                        logger.info(f"Unregistered model {model_id}")
+                        return True
+                    return False
+
         except Exception as e:
             logger.error(f"Failed to unregister model {model_id}: {e}")
             return False
 
     def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
         """Get model information"""
-
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+
+                # Get model info
+                model_row = conn.execute("""
+                    SELECT model_id, model_type, metadata, created_at, updated_at
+                    FROM models WHERE model_id = ?
+                """, (model_id,)).fetchone()
+
+                if not model_row:
+                    return None
+
+                # Get capabilities
+                capabilities = conn.execute("""
+                    SELECT capability FROM model_capabilities WHERE model_id = ?
+                """, (model_id,)).fetchall()
+
+                model_info = {
+                    "model_id": model_row["model_id"],
+                    "type": model_row["model_type"],
+                    "capabilities": [cap["capability"] for cap in capabilities],
+                    "metadata": json.loads(model_row["metadata"]) if model_row["metadata"] else {},
+                    "created_at": model_row["created_at"],
+                    "updated_at": model_row["updated_at"]
+                }
+
+                return model_info
+
+        except Exception as e:
+            logger.error(f"Failed to get model info for {model_id}: {e}")
+            return None
 
     def get_models_by_type(self, model_type: ModelType) -> Dict[str, Dict[str, Any]]:
         """Get all models of a specific type"""
-
-
-
-
-
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+
+                models = conn.execute("""
+                    SELECT model_id, model_type, metadata, created_at, updated_at
+                    FROM models WHERE model_type = ?
+                """, (model_type.value,)).fetchall()
+
+                result = {}
+                for model in models:
+                    model_id = model["model_id"]
+
+                    # Get capabilities for this model
+                    capabilities = conn.execute("""
+                        SELECT capability FROM model_capabilities WHERE model_id = ?
+                    """, (model_id,)).fetchall()
+
+                    result[model_id] = {
+                        "type": model["model_type"],
+                        "capabilities": [cap["capability"] for cap in capabilities],
+                        "metadata": json.loads(model["metadata"]) if model["metadata"] else {},
+                        "created_at": model["created_at"],
+                        "updated_at": model["updated_at"]
+                    }
+
+                return result
+
+        except Exception as e:
+            logger.error(f"Failed to get models by type {model_type}: {e}")
+            return {}
 
     def get_models_by_capability(self, capability: ModelCapability) -> Dict[str, Dict[str, Any]]:
         """Get all models with a specific capability"""
-
-
-
-
-
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+
+                models = conn.execute("""
+                    SELECT DISTINCT m.model_id, m.model_type, m.metadata, m.created_at, m.updated_at
+                    FROM models m
+                    JOIN model_capabilities mc ON m.model_id = mc.model_id
+                    WHERE mc.capability = ?
+                """, (capability.value,)).fetchall()
+
+                result = {}
+                for model in models:
+                    model_id = model["model_id"]
+
+                    # Get all capabilities for this model
+                    capabilities = conn.execute("""
+                        SELECT capability FROM model_capabilities WHERE model_id = ?
+                    """, (model_id,)).fetchall()
+
+                    result[model_id] = {
+                        "type": model["model_type"],
+                        "capabilities": [cap["capability"] for cap in capabilities],
+                        "metadata": json.loads(model["metadata"]) if model["metadata"] else {},
+                        "created_at": model["created_at"],
+                        "updated_at": model["updated_at"]
+                    }
+
+                return result
+
+        except Exception as e:
+            logger.error(f"Failed to get models by capability {capability}: {e}")
+            return {}
 
     def has_capability(self, model_id: str, capability: ModelCapability) -> bool:
         """Check if a model has a specific capability"""
-
-
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                result = conn.execute("""
+                    SELECT 1 FROM model_capabilities
+                    WHERE model_id = ? AND capability = ?
+                """, (model_id, capability.value)).fetchone()
+
+                return result is not None
+
+        except Exception as e:
+            logger.error(f"Failed to check capability for {model_id}: {e}")
             return False
-        return capability.value in model_info["capabilities"]
 
     def list_models(self) -> Dict[str, Dict[str, Any]]:
         """List all registered models"""
-
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+
+                models = conn.execute("""
+                    SELECT model_id, model_type, metadata, created_at, updated_at
+                    FROM models ORDER BY created_at DESC
+                """).fetchall()
+
+                result = {}
+                for model in models:
+                    model_id = model["model_id"]
+
+                    # Get capabilities for this model
+                    capabilities = conn.execute("""
+                        SELECT capability FROM model_capabilities WHERE model_id = ?
+                    """, (model_id,)).fetchall()
+
+                    result[model_id] = {
+                        "type": model["model_type"],
+                        "capabilities": [cap["capability"] for cap in capabilities],
+                        "metadata": json.loads(model["metadata"]) if model["metadata"] else {},
+                        "created_at": model["created_at"],
+                        "updated_at": model["updated_at"]
+                    }
+
+                return result
+
+        except Exception as e:
+            logger.error(f"Failed to list models: {e}")
+            return {}
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get registry statistics"""
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                # Count total models
+                total_models = conn.execute("SELECT COUNT(*) FROM models").fetchone()[0]
+
+                # Count by type
+                type_counts = dict(conn.execute("""
+                    SELECT model_type, COUNT(*) FROM models GROUP BY model_type
+                """).fetchall())
+
+                # Count by capability
+                capability_counts = dict(conn.execute("""
+                    SELECT capability, COUNT(*) FROM model_capabilities GROUP BY capability
+                """).fetchall())
+
+                return {
+                    "total_models": total_models,
+                    "models_by_type": type_counts,
+                    "models_by_capability": capability_counts
+                }
+
+        except Exception as e:
+            logger.error(f"Failed to get stats: {e}")
+            return {}
+
+    def search_models(self, query: str) -> Dict[str, Dict[str, Any]]:
+        """Search models by name or metadata"""
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                conn.row_factory = sqlite3.Row
+
+                models = conn.execute("""
+                    SELECT model_id, model_type, metadata, created_at, updated_at
+                    FROM models
+                    WHERE model_id LIKE ? OR metadata LIKE ?
+                    ORDER BY created_at DESC
+                """, (f"%{query}%", f"%{query}%")).fetchall()
+
+                result = {}
+                for model in models:
+                    model_id = model["model_id"]
+
+                    # Get capabilities for this model
+                    capabilities = conn.execute("""
+                        SELECT capability FROM model_capabilities WHERE model_id = ?
+                    """, (model_id,)).fetchall()
+
+                    result[model_id] = {
+                        "type": model["model_type"],
+                        "capabilities": [cap["capability"] for cap in capabilities],
+                        "metadata": json.loads(model["metadata"]) if model["metadata"] else {},
+                        "created_at": model["created_at"],
+                        "updated_at": model["updated_at"]
+                    }
+
+                return result
+
+        except Exception as e:
+            logger.error(f"Failed to search models with query '{query}': {e}")
+            return {}
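For reference, a minimal usage sketch of the new SQLite-backed registry API, using only the signatures visible in this diff. ModelType.LLM and ModelCapability.CHAT are assumed enum members (the diff only confirms ModelType.VISION), so treat those names as placeholders:

from isa_model.core.model_registry import ModelCapability, ModelRegistry, ModelType

# First use creates ./models/model_registry.db and its parent directory
registry = ModelRegistry(db_path="./models/model_registry.db")

registry.register_model(
    model_id="deepseek-r1",
    model_type=ModelType.LLM,              # assumed member name
    capabilities=[ModelCapability.CHAT],   # assumed member name
    metadata={"provider": "triton", "precision": "fp16"},
)

print(registry.get_model_info("deepseek-r1"))                        # record incl. capabilities
print(registry.has_capability("deepseek-r1", ModelCapability.CHAT))  # True
print(registry.get_stats())                                          # counts by type/capability
print(registry.search_models("deepseek"))                            # LIKE match on id or metadata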
isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py
ADDED
@@ -0,0 +1,120 @@
+import json
+import numpy as np
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import os
+import triton_python_backend_utils as pb_utils
+
+class TritonPythonModel:
+    def initialize(self, args):
+        """Initialize the model"""
+        self.model_config = json.loads(args['model_config'])
+
+        # --- START: CORRECTED PATH LOGIC ---
+
+        # model_repository is the parent directory, e.g., /models/deepseek_r1
+        model_repository = args['model_repository']
+        # model_version is the version number, e.g., '1'
+        model_version = args['model_version']
+
+        # Join them into the exact path to the model files
+        model_path = os.path.join(model_repository, model_version)
+
+        print(f"Loading model from specific version path: {model_path}")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,  # load from the correct version directory
+            trust_remote_code=True
+        )
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,  # load from the correct version directory
+            torch_dtype=torch.bfloat16,
+            device_map="gpu",
+            trust_remote_code=True
+        )
+
+        # --- END: CORRECTED PATH LOGIC ---
+
+        # ... (the rest of the code is unchanged) ...
+        output_config = pb_utils.get_output_config_by_name(
+            self.model_config, "OUTPUT_TEXT"
+        )
+        self.output_dtype = pb_utils.triton_string_to_numpy(
+            output_config['data_type']
+        )
+
+        self.generation_config = {
+            'max_new_tokens': 512,
+            'temperature': 0.7,
+            'do_sample': True,
+            'top_p': 0.9,
+            'repetition_penalty': 1.1,
+            'pad_token_id': self.tokenizer.eos_token_id
+        }
+
+        print("Model loaded successfully!")
+
+    def execute(self, requests):
+        """Run inference"""
+        responses = []
+
+        for request in requests:
+            # Get the input text tensor
+            input_text = pb_utils.get_input_tensor_by_name(
+                request, "INPUT_TEXT"
+            ).as_numpy()
+
+            # Decode the input texts
+            input_texts = [text.decode('utf-8') for text in input_text.flatten()]
+
+            # Batch inference
+            output_texts = []
+            for text in input_texts:
+                try:
+                    # Encode the input
+                    inputs = self.tokenizer.encode(
+                        text,
+                        return_tensors="pt"
+                    ).to(self.model.device)
+
+                    # Generate a response
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            inputs,
+                            **self.generation_config
+                        )
+
+                    # Decode the output
+                    response = self.tokenizer.decode(
+                        outputs[0][inputs.shape[-1]:],
+                        skip_special_tokens=True
+                    )
+
+                    output_texts.append(response)
+
+                except Exception as e:
+                    print(f"Error processing text: {e}")
+                    output_texts.append(f"Error: {str(e)}")
+
+            # Prepare the output
+            output_texts_np = np.array(
+                [[text.encode('utf-8')] for text in output_texts],
+                dtype=object
+            )
+
+            output_tensor = pb_utils.Tensor(
+                "OUTPUT_TEXT",
+                output_texts_np.astype(self.output_dtype)
+            )
+
+            response = pb_utils.InferenceResponse(
+                output_tensors=[output_tensor]
+            )
+            responses.append(response)
+
+        return responses
+
+    def finalize(self):
+        """Clean up resources"""
+        print("Cleaning up...")
isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py
ADDED
@@ -0,0 +1,18 @@
+from huggingface_hub import snapshot_download
+import os
+
+model_name = 'deepseek-ai/DeepSeek-R1-0528-Qwen3-8B'
+# The version path for this model inside the Triton model repository
+local_model_path = os.path.join("models", "deepseek_r1", "1")
+
+print(f"Starting download of model '{model_name}' to '{local_model_path}'...")
+
+# Use snapshot_download to fetch the entire model repository;
+# it downloads every file, including the .safetensors weight files
+snapshot_download(
+    repo_id=model_name,
+    local_dir=local_model_path,
+    local_dir_use_symlinks=False,
+)
+
+print("All model files downloaded!")
isa_model/deployment/gpu_int8_ds8/app/server.py
ADDED
@@ -0,0 +1,66 @@
+import os
+from fastapi import FastAPI
+from pydantic import BaseModel
+from contextlib import asynccontextmanager
+from pathlib import Path
+from threading import Thread
+from transformers import AutoTokenizer
+from tensorrt_llm.runtime import ModelRunner
+
+# --- Globals ---
+ENGINE_PATH = "/app/built_engine/deepseek_engine"
+TOKENIZER_PATH = "/app/hf_model"  # we need the tokenizer from the original HF model
+runner = None
+tokenizer = None
+
+# --- FastAPI lifespan events ---
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global runner, tokenizer
+    print("--- Loading model engine and tokenizer... ---")
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
+    runner = ModelRunner.from_dir(engine_dir=ENGINE_PATH, rank=0, stream=True)
+    print("--- ✅ Model loaded, service ready ---")
+    yield
+    print("--- Cleaning up resources... ---")
+    runner = None
+    tokenizer = None
+
+app = FastAPI(lifespan=lifespan)
+
+# --- API request/response models ---
+class GenerateRequest(BaseModel):
+    prompt: str
+    max_new_tokens: int = 256
+    temperature: float = 0.7
+
+class GenerateResponse(BaseModel):
+    text: str
+
+# --- API endpoints ---
+@app.post("/generate", response_model=GenerateResponse)
+async def generate(request: GenerateRequest):
+    print(f"Received request: {request.prompt}")
+
+    # Prepare the input
+    input_ids = tokenizer.encode(request.prompt, return_tensors="pt").to("cuda")
+
+    # Run inference
+    output_ids = runner.generate(
+        input_ids,
+        max_new_tokens=request.max_new_tokens,
+        temperature=request.temperature,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.pad_token_id,
+    )
+
+    # Trim and decode the output
+    # output_ids[0] has shape [beam_width, seq_length]
+    generated_text = tokenizer.decode(output_ids[0, 0, len(input_ids[0]):], skip_special_tokens=True)
+
+    print(f"Generated response: {generated_text}")
+    return GenerateResponse(text=generated_text)
+
+@app.get("/health")
+async def health_check():
+    return {"status": "ok" if runner is not None else "loading"}
isa_model/deployment/gpu_int8_ds8/scripts/test_client.py
ADDED
@@ -0,0 +1,43 @@
+import requests
+import json
+
+# --- Configuration ---
+TRITON_SERVER_URL = "http://localhost:8000"
+MODEL_NAME = "deepseek_trtllm"
+PROMPT = "Please tell me a joke about artificial intelligence."
+MAX_TOKENS = 256
+STREAM = False
+# ----------------------------------------------------
+
+def main():
+    """Send a request to the Triton server and print the result."""
+    url = f"{TRITON_SERVER_URL}/v2/models/{MODEL_NAME}/generate"
+    payload = {
+        "text_input": PROMPT,
+        "max_new_tokens": MAX_TOKENS,
+        "temperature": 0.7,
+        "stream": STREAM
+    }
+    print(f"Sending request to: {url}")
+    print(f"Payload: {json.dumps(payload, indent=2, ensure_ascii=False)}")
+    print("-" * 30)
+
+    try:
+        response = requests.post(url, json=payload, headers={"Accept": "application/json"})
+        response.raise_for_status()
+        response_data = response.json()
+        generated_text = response_data.get('text_output', 'Error: "text_output" key not found.')
+
+        print("✅ Request successful!")
+        print("-" * 30)
+        print("Prompt:", PROMPT)
+        print("\nGenerated Text:", generated_text)
+
+    except requests.exceptions.RequestException as e:
+        print(f"❌ Error making request to Triton server: {e}")
+        if e.response:
+            print(f"Response Status Code: {e.response.status_code}")
+            print(f"Response Body: {e.response.text}")
+
+if __name__ == '__main__':
+    main()
isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py
ADDED
@@ -0,0 +1,35 @@
+import requests
+import json
+
+PROMPT = "Please tell me a joke about artificial intelligence."
+API_URL = "http://localhost:8000/generate"
+
+def main():
+    payload = {
+        "prompt": PROMPT,
+        "max_new_tokens": 100
+    }
+
+    print(f"Sending request to: {API_URL}")
+    print(f"Payload: {json.dumps(payload, ensure_ascii=False)}")
+    print("-" * 30)
+
+    try:
+        response = requests.post(API_URL, json=payload)
+        response.raise_for_status()
+
+        response_data = response.json()
+        generated_text = response_data.get('text')
+
+        print("✅ Request successful!")
+        print("-" * 30)
+        print("Prompt:", PROMPT)
+        print("\nGenerated Text:", generated_text)
+
+    except requests.exceptions.RequestException as e:
+        print(f"❌ Error making request: {e}")
+        if e.response:
+            print(f"Response Body: {e.response.text}")
+
+if __name__ == '__main__':
+    main()