isa_model-0.0.1-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +359 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +101 -0
- isa_model/inference/providers/replicate_provider.py +107 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/openai_tts_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
- isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model/inference/services/llm/__init__.py +12 -0
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +99 -0
- isa_model/inference/services/llm/openai_llm_service.py +138 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/openai_vision_service.py +80 -0
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.0.1.dist-info/METADATA +327 -0
- isa_model-0.0.1.dist-info/RECORD +86 -0
- isa_model-0.0.1.dist-info/WHEEL +5 -0
- isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
- isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/core/resource_manager.py
@@ -0,0 +1,202 @@
import time
import threading
import logging
from typing import Dict, List, Any, Optional
import psutil
try:
    import torch
    import nvidia_smi
    HAS_GPU = torch.cuda.is_available()
except ImportError:
    HAS_GPU = False

logger = logging.getLogger(__name__)

class ResourceManager:
    """
    Monitors system resources and manages model loading/unloading
    to prevent resource exhaustion.
    """

    def __init__(self, model_registry, monitoring_interval=30):
        self.registry = model_registry
        self.monitoring_interval = monitoring_interval  # seconds

        self.max_memory_percent = 90  # Maximum memory usage percentage
        self.max_gpu_memory_percent = 90  # Maximum GPU memory usage percentage

        self._stop_event = threading.Event()
        self._monitor_thread = None

        # Track resource usage over time
        self.resource_history = {
            "timestamps": [],
            "cpu_percent": [],
            "memory_percent": [],
            "gpu_utilization": [],
            "gpu_memory_percent": []
        }

        # Initialize GPU monitoring if available
        if HAS_GPU:
            try:
                nvidia_smi.nvmlInit()
                self.gpu_count = torch.cuda.device_count()
                logger.info(f"Initialized GPU monitoring with {self.gpu_count} devices")
            except Exception as e:
                logger.warning(f"Failed to initialize NVIDIA SMI: {str(e)}")
                self.gpu_count = 0
        else:
            self.gpu_count = 0

        logger.info("Initialized ResourceManager")

    def start_monitoring(self):
        """Start the resource monitoring thread"""
        if self._monitor_thread is not None and self._monitor_thread.is_alive():
            logger.warning("Resource monitoring already running")
            return

        self._stop_event.clear()
        self._monitor_thread = threading.Thread(
            target=self._monitor_resources,
            daemon=True
        )
        self._monitor_thread.start()
        logger.info("Started resource monitoring thread")

    def stop_monitoring(self):
        """Stop the resource monitoring thread"""
        if self._monitor_thread is not None:
            self._stop_event.set()
            self._monitor_thread.join(timeout=5)
            self._monitor_thread = None
            logger.info("Stopped resource monitoring thread")

    def _monitor_resources(self):
        """Monitor system resources in a loop"""
        while not self._stop_event.is_set():
            try:
                # Get current resource usage
                cpu_percent = psutil.cpu_percent(interval=1)
                memory = psutil.virtual_memory()
                memory_percent = memory.percent

                # GPU monitoring
                gpu_utilization = 0
                gpu_memory_percent = 0

                if self.gpu_count > 0:
                    gpu_util_sum = 0
                    gpu_mem_percent_sum = 0

                    for i in range(self.gpu_count):
                        try:
                            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
                            util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
                            mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)

                            gpu_util_sum += util.gpu
                            gpu_mem_percent = (mem_info.used / mem_info.total) * 100
                            gpu_mem_percent_sum += gpu_mem_percent
                        except Exception as e:
                            logger.error(f"Error getting GPU {i} stats: {str(e)}")

                if self.gpu_count > 0:
                    gpu_utilization = gpu_util_sum / self.gpu_count
                    gpu_memory_percent = gpu_mem_percent_sum / self.gpu_count

                # Record history (keep last 60 samples)
                now = time.time()
                self.resource_history["timestamps"].append(now)
                self.resource_history["cpu_percent"].append(cpu_percent)
                self.resource_history["memory_percent"].append(memory_percent)
                self.resource_history["gpu_utilization"].append(gpu_utilization)
                self.resource_history["gpu_memory_percent"].append(gpu_memory_percent)

                # Trim history to last 60 samples
                max_history = 60
                if len(self.resource_history["timestamps"]) > max_history:
                    for key in self.resource_history:
                        self.resource_history[key] = self.resource_history[key][-max_history:]

                # Check if we need to unload models
                self._check_resource_constraints(memory_percent, gpu_memory_percent)

                # Log current usage
                logger.debug(
                    f"Resource usage - CPU: {cpu_percent:.1f}%, Memory: {memory_percent:.1f}%, "
                    f"GPU: {gpu_utilization:.1f}%, GPU Memory: {gpu_memory_percent:.1f}%"
                )

                # Wait for next check
                self._stop_event.wait(self.monitoring_interval)

            except Exception as e:
                logger.error(f"Error in resource monitoring: {str(e)}")
                # Wait a bit before retrying
                self._stop_event.wait(5)

    def _check_resource_constraints(self, memory_percent, gpu_memory_percent):
        """Check if we need to unload models due to resource constraints"""
        # Check memory usage
        if memory_percent > self.max_memory_percent:
            logger.warning(
                f"Memory usage ({memory_percent:.1f}%) exceeds threshold ({self.max_memory_percent}%). "
                "Unloading least used model."
            )
            # This would trigger model unloading
            # self.registry._evict_least_used_model()

        # Check GPU memory usage
        if HAS_GPU and gpu_memory_percent > self.max_gpu_memory_percent:
            logger.warning(
                f"GPU memory usage ({gpu_memory_percent:.1f}%) exceeds threshold ({self.max_gpu_memory_percent}%). "
                "Unloading least used model."
            )
            # This would trigger model unloading
            # self.registry._evict_least_used_model()

    def get_resource_usage(self) -> Dict[str, Any]:
        """Get current resource usage stats"""
        try:
            cpu_percent = psutil.cpu_percent(interval=0.1)
            memory = psutil.virtual_memory()
            memory_percent = memory.percent

            result = {
                "cpu_percent": cpu_percent,
                "memory_total_gb": memory.total / (1024**3),
                "memory_available_gb": memory.available / (1024**3),
                "memory_percent": memory_percent,
                "gpus": []
            }

            # GPU stats
            if HAS_GPU:
                for i in range(self.gpu_count):
                    try:
                        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
                        name = nvidia_smi.nvmlDeviceGetName(handle)
                        util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
                        mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
                        temp = nvidia_smi.nvmlDeviceGetTemperature(
                            handle, nvidia_smi.NVML_TEMPERATURE_GPU
                        )

                        result["gpus"].append({
                            "index": i,
                            "name": name,
                            "utilization_percent": util.gpu,
                            "memory_total_gb": mem_info.total / (1024**3),
                            "memory_used_gb": mem_info.used / (1024**3),
                            "memory_percent": (mem_info.used / mem_info.total) * 100,
                            "temperature_c": temp
                        })
                    except Exception as e:
                        logger.error(f"Error getting GPU {i} stats: {str(e)}")

            return result
        except Exception as e:
            logger.error(f"Error getting resource usage: {str(e)}")
            return {"error": str(e)}
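The snippet below is a minimal usage sketch, not code from the package: it assumes ResourceManager is importable as defined above and substitutes a hypothetical DummyRegistry for isa_model's real model registry, since the eviction hook is only referenced in comments.

import time

class DummyRegistry:
    """Hypothetical stand-in for the package's model registry."""
    pass

manager = ResourceManager(model_registry=DummyRegistry(), monitoring_interval=5)
manager.start_monitoring()
time.sleep(6)                         # let the background thread take at least one sample
print(manager.get_resource_usage())   # {'cpu_percent': ..., 'memory_percent': ..., 'gpus': [...]}
manager.stop_monitoring()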
isa_model/core/storage/hf_storage.py
File without changes

isa_model/core/storage/local_storage.py
File without changes

isa_model/core/storage/minio_storage.py
File without changes
isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py
@@ -0,0 +1,120 @@
import json
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    def initialize(self, args):
        """Initialize the model."""
        self.model_config = json.loads(args['model_config'])

        # --- START: CORRECTED PATH LOGIC ---

        # model_repository is the parent directory, e.g., /models/deepseek_r1
        model_repository = args['model_repository']
        # model_version is the version number, e.g., '1'
        model_version = args['model_version']

        # Join them to get the exact path to the model files
        model_path = os.path.join(model_repository, model_version)

        print(f"Loading model from specific version path: {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,  # load from the correct version directory
            trust_remote_code=True
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,  # load from the correct version directory
            torch_dtype=torch.bfloat16,
            device_map="auto",  # "gpu" is not a valid device_map value; "auto" places weights on available GPUs
            trust_remote_code=True
        )

        # --- END: CORRECTED PATH LOGIC ---

        # ... (the rest of the code is unchanged) ...
        output_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT_TEXT"
        )
        self.output_dtype = pb_utils.triton_string_to_numpy(
            output_config['data_type']
        )

        self.generation_config = {
            'max_new_tokens': 512,
            'temperature': 0.7,
            'do_sample': True,
            'top_p': 0.9,
            'repetition_penalty': 1.1,
            'pad_token_id': self.tokenizer.eos_token_id
        }

        print("Model loaded successfully!")

    def execute(self, requests):
        """Run inference."""
        responses = []

        for request in requests:
            # Get the input text tensor
            input_text = pb_utils.get_input_tensor_by_name(
                request, "INPUT_TEXT"
            ).as_numpy()

            # Decode the input text
            input_texts = [text.decode('utf-8') for text in input_text.flatten()]

            # Batched inference
            output_texts = []
            for text in input_texts:
                try:
                    # Encode the input
                    inputs = self.tokenizer.encode(
                        text,
                        return_tensors="pt"
                    ).to(self.model.device)

                    # Generate a response
                    with torch.no_grad():
                        outputs = self.model.generate(
                            inputs,
                            **self.generation_config
                        )

                    # Decode only the newly generated tokens
                    response = self.tokenizer.decode(
                        outputs[0][inputs.shape[-1]:],
                        skip_special_tokens=True
                    )

                    output_texts.append(response)

                except Exception as e:
                    print(f"Error processing text: {e}")
                    output_texts.append(f"Error: {str(e)}")

            # Prepare the output tensor
            output_texts_np = np.array(
                [[text.encode('utf-8')] for text in output_texts],
                dtype=object
            )

            output_tensor = pb_utils.Tensor(
                "OUTPUT_TEXT",
                output_texts_np.astype(self.output_dtype)
            )

            response = pb_utils.InferenceResponse(
                output_tensors=[output_tensor]
            )
            responses.append(response)

        return responses

    def finalize(self):
        """Clean up resources."""
        print("Cleaning up...")
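For reference, a hedged client sketch against this Python backend using the official tritonclient package. The model name deepseek_r1 follows the repository directory above and port 8000 is Triton's HTTP default; the exact input shape depends on a config.pbtxt that is not part of this diff, and only the tensor names INPUT_TEXT and OUTPUT_TEXT come from the model code.

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# BYTES tensors are passed as numpy object arrays of encoded strings
text = np.array([["Hello, who are you?".encode("utf-8")]], dtype=object)
inp = httpclient.InferInput("INPUT_TEXT", [1, 1], "BYTES")
inp.set_data_from_numpy(text)

result = client.infer(
    model_name="deepseek_r1",  # assumption: the Triton model name equals the repository directory name
    inputs=[inp],
    outputs=[httpclient.InferRequestedOutput("OUTPUT_TEXT")],
)
print(result.as_numpy("OUTPUT_TEXT")[0][0].decode("utf-8"))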
isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py
@@ -0,0 +1,18 @@
from huggingface_hub import snapshot_download
import os

model_name = 'deepseek-ai/DeepSeek-R1-0528-Qwen3-8B'
# Version directory for this model inside the Triton model repository
local_model_path = os.path.join("models", "deepseek_r1", "1")

print(f"Downloading model '{model_name}' to '{local_model_path}'...")

# Use snapshot_download to fetch the entire model repository,
# including the .safetensors weight files.
snapshot_download(
    repo_id=model_name,
    local_dir=local_model_path,
    local_dir_use_symlinks=False,
)

print("All model files downloaded!")
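A hypothetical sanity check, not part of the package: it verifies that the download produced the tokenizer and config files that the Triton backend above loads from models/deepseek_r1/1.

from pathlib import Path

version_dir = Path("models") / "deepseek_r1" / "1"
expected = ["config.json", "tokenizer_config.json"]  # standard Hugging Face model files
missing = [name for name in expected if not (version_dir / name).exists()]
print("missing files:", missing or "none")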
isa_model/deployment/gpu_int8_ds8/app/server.py
@@ -0,0 +1,66 @@
import os
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
from pathlib import Path
from threading import Thread
from transformers import AutoTokenizer
from tensorrt_llm.runtime import ModelRunner

# --- Globals ---
ENGINE_PATH = "/app/built_engine/deepseek_engine"
TOKENIZER_PATH = "/app/hf_model"  # the tokenizer comes from the original HF model
runner = None
tokenizer = None

# --- FastAPI lifespan events ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global runner, tokenizer
    print("--- Loading model engine and tokenizer... ---")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    runner = ModelRunner.from_dir(engine_dir=ENGINE_PATH, rank=0, stream=True)
    print("--- ✅ Model loaded, service ready ---")
    yield
    print("--- Cleaning up resources... ---")
    runner = None
    tokenizer = None

app = FastAPI(lifespan=lifespan)

# --- API request and response models ---
class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7

class GenerateResponse(BaseModel):
    text: str

# --- API endpoints ---
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    print(f"Received request: {request.prompt}")

    # Prepare the input
    input_ids = tokenizer.encode(request.prompt, return_tensors="pt").to("cuda")

    # Run inference
    output_ids = runner.generate(
        input_ids,
        max_new_tokens=request.max_new_tokens,
        temperature=request.temperature,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Trim and decode the output
    # output_ids[0] has shape [beam_width, seq_length]
    generated_text = tokenizer.decode(output_ids[0, 0, len(input_ids[0]):], skip_special_tokens=True)

    print(f"Generated response: {generated_text}")
    return GenerateResponse(text=generated_text)

@app.get("/health")
async def health_check():
    return {"status": "ok" if runner is not None else "loading"}
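A minimal launch sketch, assuming the file above is saved as server.py and run inside the gpu_int8_ds8 container; the deployment's actual entrypoint is not shown in this diff. Port 8000 matches the test clients that follow.

import uvicorn

if __name__ == "__main__":
    # Serve the FastAPI app defined above; adjust the import string to the real module path.
    uvicorn.run("server:app", host="0.0.0.0", port=8000)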
isa_model/deployment/gpu_int8_ds8/scripts/test_client.py
@@ -0,0 +1,43 @@
import requests
import json

# --- Configuration ---
TRITON_SERVER_URL = "http://localhost:8000"
MODEL_NAME = "deepseek_trtllm"
PROMPT = "请给我讲一个关于人工智能的笑话。"  # "Tell me a joke about artificial intelligence."
MAX_TOKENS = 256
STREAM = False
# ----------------------------------------------------

def main():
    """Send a request to the Triton server and print the result."""
    url = f"{TRITON_SERVER_URL}/v2/models/{MODEL_NAME}/generate"
    payload = {
        "text_input": PROMPT,
        "max_new_tokens": MAX_TOKENS,
        "temperature": 0.7,
        "stream": STREAM
    }
    print(f"Sending request to: {url}")
    print(f"Payload: {json.dumps(payload, indent=2, ensure_ascii=False)}")
    print("-" * 30)

    try:
        response = requests.post(url, json=payload, headers={"Accept": "application/json"})
        response.raise_for_status()
        response_data = response.json()
        generated_text = response_data.get('text_output', 'Error: "text_output" key not found.')

        print("✅ Request successful!")
        print("-" * 30)
        print("Prompt:", PROMPT)
        print("\nGenerated Text:", generated_text)

    except requests.exceptions.RequestException as e:
        print(f"❌ Error making request to Triton server: {e}")
        if e.response is not None:  # a failed Response is falsy, so check against None explicitly
            print(f"Response Status Code: {e.response.status_code}")
            print(f"Response Body: {e.response.text}")

if __name__ == '__main__':
    main()
isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py
@@ -0,0 +1,35 @@
import requests
import json

PROMPT = "请给我讲一个关于人工智能的笑话。"  # "Tell me a joke about artificial intelligence."
API_URL = "http://localhost:8000/generate"

def main():
    payload = {
        "prompt": PROMPT,
        "max_new_tokens": 100
    }

    print(f"Sending request to: {API_URL}")
    print(f"Payload: {json.dumps(payload, ensure_ascii=False)}")
    print("-" * 30)

    try:
        response = requests.post(API_URL, json=payload)
        response.raise_for_status()

        response_data = response.json()
        generated_text = response_data.get('text')

        print("✅ Request successful!")
        print("-" * 30)
        print("Prompt:", PROMPT)
        print("\nGenerated Text:", generated_text)

    except requests.exceptions.RequestException as e:
        print(f"❌ Error making request: {e}")
        if e.response is not None:  # a failed Response is falsy, so check against None explicitly
            print(f"Response Body: {e.response.text}")

if __name__ == '__main__':
    main()
isa_model/inference/__init__.py
@@ -0,0 +1,11 @@
"""
Inference module for isA_Model

File: isa_model/inference/__init__.py
This module provides the main inference components for the IsA Model system.
"""

from .ai_factory import AIFactory
from .base import ModelType, Capability, RoutingStrategy

__all__ = ["AIFactory", "ModelType", "Capability", "RoutingStrategy"]
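An import-only sketch based solely on the exports declared above; AIFactory's constructor and methods live in ai_factory.py, which this hunk does not show, so nothing is instantiated here.

from isa_model.inference import AIFactory, ModelType, Capability, RoutingStrategy

# Confirms the public surface re-exported by the package's inference __init__.
print(AIFactory, ModelType, Capability, RoutingStrategy)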