isa-model 0.0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_manager.py +69 -4
  3. isa_model/core/model_registry.py +273 -46
  4. isa_model/core/storage/hf_storage.py +419 -0
  5. isa_model/deployment/__init__.py +52 -0
  6. isa_model/deployment/core/__init__.py +34 -0
  7. isa_model/deployment/core/deployment_config.py +356 -0
  8. isa_model/deployment/core/deployment_manager.py +549 -0
  9. isa_model/deployment/core/isa_deployment_service.py +401 -0
  10. isa_model/eval/factory.py +381 -140
  11. isa_model/inference/ai_factory.py +427 -236
  12. isa_model/inference/billing_tracker.py +406 -0
  13. isa_model/inference/providers/base_provider.py +51 -4
  14. isa_model/inference/providers/ml_provider.py +50 -0
  15. isa_model/inference/providers/ollama_provider.py +37 -18
  16. isa_model/inference/providers/openai_provider.py +65 -36
  17. isa_model/inference/providers/replicate_provider.py +42 -30
  18. isa_model/inference/services/audio/base_stt_service.py +21 -2
  19. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  20. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  21. isa_model/inference/services/audio/openai_tts_service.py +149 -9
  22. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  23. isa_model/inference/services/base_service.py +36 -1
  24. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  25. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  26. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  27. isa_model/inference/services/llm/__init__.py +2 -0
  28. isa_model/inference/services/llm/base_llm_service.py +158 -86
  29. isa_model/inference/services/llm/llm_adapter.py +414 -0
  30. isa_model/inference/services/llm/ollama_llm_service.py +252 -63
  31. isa_model/inference/services/llm/openai_llm_service.py +231 -93
  32. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  33. isa_model/inference/services/ml/base_ml_service.py +78 -0
  34. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  35. isa_model/inference/services/vision/__init__.py +3 -3
  36. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  37. isa_model/inference/services/vision/base_vision_service.py +177 -0
  38. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  39. isa_model/inference/services/vision/ollama_vision_service.py +151 -17
  40. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  41. isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
  42. isa_model/training/__init__.py +62 -32
  43. isa_model/training/cloud/__init__.py +22 -0
  44. isa_model/training/cloud/job_orchestrator.py +402 -0
  45. isa_model/training/cloud/runpod_trainer.py +454 -0
  46. isa_model/training/cloud/storage_manager.py +482 -0
  47. isa_model/training/core/__init__.py +23 -0
  48. isa_model/training/core/config.py +181 -0
  49. isa_model/training/core/dataset.py +222 -0
  50. isa_model/training/core/trainer.py +720 -0
  51. isa_model/training/core/utils.py +213 -0
  52. isa_model/training/factory.py +229 -198
  53. isa_model-0.3.1.dist-info/METADATA +465 -0
  54. isa_model-0.3.1.dist-info/RECORD +91 -0
  55. isa_model/core/model_router.py +0 -226
  56. isa_model/core/model_version.py +0 -0
  57. isa_model/core/resource_manager.py +0 -202
  58. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  59. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  60. isa_model/training/engine/llama_factory/__init__.py +0 -39
  61. isa_model/training/engine/llama_factory/config.py +0 -115
  62. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  63. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  64. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  65. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  66. isa_model/training/engine/llama_factory/factory.py +0 -331
  67. isa_model/training/engine/llama_factory/rl.py +0 -254
  68. isa_model/training/engine/llama_factory/trainer.py +0 -171
  69. isa_model/training/image_model/configs/create_config.py +0 -37
  70. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  71. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  72. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  73. isa_model/training/image_model/prepare_upload.py +0 -17
  74. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  75. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  76. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  77. isa_model/training/image_model/train/train.py +0 -42
  78. isa_model/training/image_model/train/train_flux.py +0 -41
  79. isa_model/training/image_model/train/train_lora.py +0 -57
  80. isa_model/training/image_model/train_main.py +0 -25
  81. isa_model-0.0.2.dist-info/METADATA +0 -327
  82. isa_model-0.0.2.dist-info/RECORD +0 -92
  83. isa_model-0.0.2.dist-info/licenses/LICENSE +0 -21
  84. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  85. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  86. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  87. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  88. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  89. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  90. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  91. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  92. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
  93. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
isa_model/core/model_router.py
@@ -1,226 +0,0 @@
- import random
- import time
- from typing import Dict, List, Any, Optional, Callable
- import threading
-
- class ModelRouter:
-     """
-     Routes requests to appropriate model instances based on different strategies:
-     - Weighted round-robin
-     - Least connections
-     - Least response time
-     - Dynamic load balancing
-     """
-
-     def __init__(self, registry):
-         self.registry = registry
-         self.model_stats = {}  # Track performance metrics for each model
-         self.lock = threading.RLock()
-
-         # Maps model_type -> list of model_ids of that type
-         self.model_type_mapping = {}
-
-         # Maps routing_strategy_name -> routing_function
-         self.routing_strategies = {
-             "round_robin": self._route_round_robin,
-             "weighted_random": self._route_weighted_random,
-             "least_connections": self._route_least_connections,
-             "least_response_time": self._route_least_response_time,
-             "dynamic_load": self._route_dynamic_load
-         }
-
-         # Round-robin counters for each model type
-         self.rr_counters = {}
-
-     def register_model_type(self, model_type: str, model_ids: List[str], weights: Optional[List[float]] = None):
-         """Register models of a specific type with optional weights"""
-         with self.lock:
-             self.model_type_mapping[model_type] = model_ids
-
-             # Initialize stats for each model
-             for i, model_id in enumerate(model_ids):
-                 weight = weights[i] if weights and i < len(weights) else 1.0
-
-                 if model_id not in self.model_stats:
-                     self.model_stats[model_id] = {
-                         "active_connections": 0,
-                         "total_requests": 0,
-                         "avg_response_time": 0,
-                         "weight": weight,
-                         "last_used": 0
-                     }
-                 else:
-                     # Update weight if model already exists
-                     self.model_stats[model_id]["weight"] = weight
-
-             # Initialize round-robin counter
-             self.rr_counters[model_type] = 0
-
-     def route_request(self, model_type: str, routing_strategy: str = "round_robin") -> Optional[str]:
-         """
-         Route a request to an appropriate model of the given type
-
-         Args:
-             model_type: Type of model needed
-             routing_strategy: Strategy to use for routing
-
-         Returns:
-             model_id: ID of the model to use, or None if no models available
-         """
-         with self.lock:
-             if model_type not in self.model_type_mapping:
-                 return None
-
-             if not self.model_type_mapping[model_type]:
-                 return None
-
-             # Get the routing function
-             routing_func = self.routing_strategies.get(routing_strategy, self._route_round_robin)
-
-             # Route the request
-             model_id = routing_func(model_type)
-
-             if model_id:
-                 # Update stats
-                 self.model_stats[model_id]["active_connections"] += 1
-                 self.model_stats[model_id]["total_requests"] += 1
-                 self.model_stats[model_id]["last_used"] = time.time()
-
-             return model_id
-
-     def release_connection(self, model_id: str, response_time: float = None):
-         """Release a connection and update stats"""
-         with self.lock:
-             if model_id in self.model_stats:
-                 stats = self.model_stats[model_id]
-                 stats["active_connections"] = max(0, stats["active_connections"] - 1)
-
-                 # Update average response time
-                 if response_time is not None:
-                     old_avg = stats["avg_response_time"]
-                     total_req = stats["total_requests"]
-
-                     if total_req > 0:
-                         # Weighted average
-                         stats["avg_response_time"] = (old_avg * (total_req - 1) + response_time) / total_req
-
-     # Routing strategies
-     def _route_round_robin(self, model_type: str) -> Optional[str]:
-         """Simple round-robin routing"""
-         models = self.model_type_mapping.get(model_type, [])
-         if not models:
-             return None
-
-         # Get and increment counter
-         counter = self.rr_counters[model_type]
-         self.rr_counters[model_type] = (counter + 1) % len(models)
-
-         return models[counter]
-
-     def _route_weighted_random(self, model_type: str) -> Optional[str]:
-         """Weighted random selection based on configured weights"""
-         models = self.model_type_mapping.get(model_type, [])
-         if not models:
-             return None
-
-         # Get weights
-         weights = [self.model_stats[model_id]["weight"] for model_id in models]
-
-         # Weighted random selection
-         total = sum(weights)
-         r = random.uniform(0, total)
-         upto = 0
-
-         for i, w in enumerate(weights):
-             upto += w
-             if upto >= r:
-                 return models[i]
-
-         # Fallback
-         return models[-1]
-
-     def _route_least_connections(self, model_type: str) -> Optional[str]:
-         """Route to the model with the fewest active connections"""
-         models = self.model_type_mapping.get(model_type, [])
-         if not models:
-             return None
-
-         # Find model with least connections
-         min_connections = float('inf')
-         selected_model = None
-
-         for model_id in models:
-             connections = self.model_stats[model_id]["active_connections"]
-             if connections < min_connections:
-                 min_connections = connections
-                 selected_model = model_id
-
-         return selected_model
-
-     def _route_least_response_time(self, model_type: str) -> Optional[str]:
-         """Route to the model with the lowest average response time"""
-         models = self.model_type_mapping.get(model_type, [])
-         if not models:
-             return None
-
-         # Find model with lowest response time
-         min_response_time = float('inf')
-         selected_model = None
-
-         for model_id in models:
-             response_time = self.model_stats[model_id]["avg_response_time"]
-             # Skip models with no data yet
-             if response_time == 0:
-                 continue
-
-             if response_time < min_response_time:
-                 min_response_time = response_time
-                 selected_model = model_id
-
-         # If no model has response time data, fall back to least connections
-         if selected_model is None:
-             return self._route_least_connections(model_type)
-
-         return selected_model
-
-     def _route_dynamic_load(self, model_type: str) -> Optional[str]:
-         """
-         Dynamic load balancing based on a combination of:
-         - Connection count
-         - Response time
-         - Recent usage
-         """
-         models = self.model_type_mapping.get(model_type, [])
-         if not models:
-             return None
-
-         # Calculate a score for each model (lower is better)
-         best_score = float('inf')
-         selected_model = None
-         now = time.time()
-
-         for model_id in models:
-             stats = self.model_stats[model_id]
-
-             # Normalize each factor between 0 and 1
-             connections = stats["active_connections"]
-             conn_score = connections / (connections + 1)  # Approaches 1 as connections increase
-
-             resp_time = stats["avg_response_time"]
-             # Max expected response time (adjust as needed)
-             max_resp_time = 5.0
-             resp_score = min(1.0, resp_time / max_resp_time)
-
-             # Time since last use (for distributing load)
-             recency = now - stats["last_used"] if stats["last_used"] > 0 else 60
-             recency_score = 1.0 - min(1.0, recency / 60.0)  # Unused for 60s approaches 0
-
-             # Combined score (lower is better)
-             # Weights can be adjusted based on importance
-             score = (0.4 * conn_score) + (0.4 * resp_score) + (0.2 * recency_score)
-
-             if score < best_score:
-                 best_score = score
-                 selected_model = model_id
-
-         return selected_model
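
For orientation, the removed ModelRouter exposed a three-call API: register the model IDs of a type, request a route, and release the connection with an optional latency sample so the least_response_time and dynamic_load strategies have data to work with. The following is a minimal usage sketch based only on the code above, not an example shipped with the package; the registry argument is stored but never consulted by the routing logic, and the model IDs and timing are illustrative.

    import time

    from isa_model.core.model_router import ModelRouter  # import path valid in 0.0.2, removed in 0.3.1

    router = ModelRouter(registry=None)  # registry is kept but unused by the routing paths
    router.register_model_type("llm", ["llm-a", "llm-b"], weights=[2.0, 1.0])

    model_id = router.route_request("llm", routing_strategy="weighted_random")
    start = time.time()
    # ... invoke the selected model instance here ...
    router.release_connection(model_id, response_time=time.time() - start)
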
isa_model/core/model_version.py (file without changes)
isa_model/core/resource_manager.py
@@ -1,202 +0,0 @@
- import time
- import threading
- import logging
- from typing import Dict, List, Any, Optional
- import psutil
- try:
-     import torch
-     import nvidia_smi
-     HAS_GPU = torch.cuda.is_available()
- except ImportError:
-     HAS_GPU = False
-
- logger = logging.getLogger(__name__)
-
- class ResourceManager:
-     """
-     Monitors system resources and manages model loading/unloading
-     to prevent resource exhaustion.
-     """
-
-     def __init__(self, model_registry, monitoring_interval=30):
-         self.registry = model_registry
-         self.monitoring_interval = monitoring_interval  # seconds
-
-         self.max_memory_percent = 90  # Maximum memory usage percentage
-         self.max_gpu_memory_percent = 90  # Maximum GPU memory usage percentage
-
-         self._stop_event = threading.Event()
-         self._monitor_thread = None
-
-         # Track resource usage over time
-         self.resource_history = {
-             "timestamps": [],
-             "cpu_percent": [],
-             "memory_percent": [],
-             "gpu_utilization": [],
-             "gpu_memory_percent": []
-         }
-
-         # Initialize GPU monitoring if available
-         if HAS_GPU:
-             try:
-                 nvidia_smi.nvmlInit()
-                 self.gpu_count = torch.cuda.device_count()
-                 logger.info(f"Initialized GPU monitoring with {self.gpu_count} devices")
-             except Exception as e:
-                 logger.warning(f"Failed to initialize NVIDIA SMI: {str(e)}")
-                 self.gpu_count = 0
-         else:
-             self.gpu_count = 0
-
-         logger.info("Initialized ResourceManager")
-
-     def start_monitoring(self):
-         """Start the resource monitoring thread"""
-         if self._monitor_thread is not None and self._monitor_thread.is_alive():
-             logger.warning("Resource monitoring already running")
-             return
-
-         self._stop_event.clear()
-         self._monitor_thread = threading.Thread(
-             target=self._monitor_resources,
-             daemon=True
-         )
-         self._monitor_thread.start()
-         logger.info("Started resource monitoring thread")
-
-     def stop_monitoring(self):
-         """Stop the resource monitoring thread"""
-         if self._monitor_thread is not None:
-             self._stop_event.set()
-             self._monitor_thread.join(timeout=5)
-             self._monitor_thread = None
-             logger.info("Stopped resource monitoring thread")
-
-     def _monitor_resources(self):
-         """Monitor system resources in a loop"""
-         while not self._stop_event.is_set():
-             try:
-                 # Get current resource usage
-                 cpu_percent = psutil.cpu_percent(interval=1)
-                 memory = psutil.virtual_memory()
-                 memory_percent = memory.percent
-
-                 # GPU monitoring
-                 gpu_utilization = 0
-                 gpu_memory_percent = 0
-
-                 if self.gpu_count > 0:
-                     gpu_util_sum = 0
-                     gpu_mem_percent_sum = 0
-
-                     for i in range(self.gpu_count):
-                         try:
-                             handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
-                             util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
-                             mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
-
-                             gpu_util_sum += util.gpu
-                             gpu_mem_percent = (mem_info.used / mem_info.total) * 100
-                             gpu_mem_percent_sum += gpu_mem_percent
-                         except Exception as e:
-                             logger.error(f"Error getting GPU {i} stats: {str(e)}")
-
-                     if self.gpu_count > 0:
-                         gpu_utilization = gpu_util_sum / self.gpu_count
-                         gpu_memory_percent = gpu_mem_percent_sum / self.gpu_count
-
-                 # Record history (keep last 60 samples)
-                 now = time.time()
-                 self.resource_history["timestamps"].append(now)
-                 self.resource_history["cpu_percent"].append(cpu_percent)
-                 self.resource_history["memory_percent"].append(memory_percent)
-                 self.resource_history["gpu_utilization"].append(gpu_utilization)
-                 self.resource_history["gpu_memory_percent"].append(gpu_memory_percent)
-
-                 # Trim history to last 60 samples
-                 max_history = 60
-                 if len(self.resource_history["timestamps"]) > max_history:
-                     for key in self.resource_history:
-                         self.resource_history[key] = self.resource_history[key][-max_history:]
-
-                 # Check if we need to unload models
-                 self._check_resource_constraints(memory_percent, gpu_memory_percent)
-
-                 # Log current usage
-                 logger.debug(
-                     f"Resource usage - CPU: {cpu_percent:.1f}%, Memory: {memory_percent:.1f}%, "
-                     f"GPU: {gpu_utilization:.1f}%, GPU Memory: {gpu_memory_percent:.1f}%"
-                 )
-
-                 # Wait for next check
-                 self._stop_event.wait(self.monitoring_interval)
-
-             except Exception as e:
-                 logger.error(f"Error in resource monitoring: {str(e)}")
-                 # Wait a bit before retrying
-                 self._stop_event.wait(5)
-
-     def _check_resource_constraints(self, memory_percent, gpu_memory_percent):
-         """Check if we need to unload models due to resource constraints"""
-         # Check memory usage
-         if memory_percent > self.max_memory_percent:
-             logger.warning(
-                 f"Memory usage ({memory_percent:.1f}%) exceeds threshold ({self.max_memory_percent}%). "
-                 "Unloading least used model."
-             )
-             # This would trigger model unloading
-             # self.registry._evict_least_used_model()
-
-         # Check GPU memory usage
-         if HAS_GPU and gpu_memory_percent > self.max_gpu_memory_percent:
-             logger.warning(
-                 f"GPU memory usage ({gpu_memory_percent:.1f}%) exceeds threshold ({self.max_gpu_memory_percent}%). "
-                 "Unloading least used model."
-             )
-             # This would trigger model unloading
-             # self.registry._evict_least_used_model()
-
-     def get_resource_usage(self) -> Dict[str, Any]:
-         """Get current resource usage stats"""
-         try:
-             cpu_percent = psutil.cpu_percent(interval=0.1)
-             memory = psutil.virtual_memory()
-             memory_percent = memory.percent
-
-             result = {
-                 "cpu_percent": cpu_percent,
-                 "memory_total_gb": memory.total / (1024**3),
-                 "memory_available_gb": memory.available / (1024**3),
-                 "memory_percent": memory_percent,
-                 "gpus": []
-             }
-
-             # GPU stats
-             if HAS_GPU:
-                 for i in range(self.gpu_count):
-                     try:
-                         handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
-                         name = nvidia_smi.nvmlDeviceGetName(handle)
-                         util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
-                         mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
-                         temp = nvidia_smi.nvmlDeviceGetTemperature(
-                             handle, nvidia_smi.NVML_TEMPERATURE_GPU
-                         )
-
-                         result["gpus"].append({
-                             "index": i,
-                             "name": name,
-                             "utilization_percent": util.gpu,
-                             "memory_total_gb": mem_info.total / (1024**3),
-                             "memory_used_gb": mem_info.used / (1024**3),
-                             "memory_percent": (mem_info.used / mem_info.total) * 100,
-                             "temperature_c": temp
-                         })
-                     except Exception as e:
-                         logger.error(f"Error getting GPU {i} stats: {str(e)}")
-
-             return result
-         except Exception as e:
-             logger.error(f"Error getting resource usage: {str(e)}")
-             return {"error": str(e)}
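
Similarly, the removed ResourceManager was driven through a small lifecycle API: start the background sampling thread, read snapshots on demand, and stop the thread on shutdown. A hedged sketch, assuming the 0.0.2 module is importable and psutil is installed; the model_registry argument is only referenced by the commented-out eviction hook, so None stands in here.

    import time

    from isa_model.core.resource_manager import ResourceManager  # import path valid in 0.0.2, removed in 0.3.1

    manager = ResourceManager(model_registry=None, monitoring_interval=10)
    manager.start_monitoring()               # daemon thread samples CPU/RAM/GPU every 10 s
    time.sleep(30)                           # let a few samples accumulate
    snapshot = manager.get_resource_usage()  # on-demand reading, independent of the history buffer
    print(snapshot["cpu_percent"], snapshot["memory_percent"], len(snapshot["gpus"]))
    manager.stop_monitoring()
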
isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py
@@ -1,120 +0,0 @@
- import json
- import numpy as np
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import os
- import triton_python_backend_utils as pb_utils
-
- class TritonPythonModel:
-     def initialize(self, args):
-         """Initialize the model"""
-         self.model_config = json.loads(args['model_config'])
-
-         # --- START: CORRECTED PATH LOGIC ---
-
-         # model_repository is the parent directory, e.g., /models/deepseek_r1
-         model_repository = args['model_repository']
-         # model_version is the version number, e.g., '1'
-         model_version = args['model_version']
-
-         # Join them into the exact path to the model files
-         model_path = os.path.join(model_repository, model_version)
-
-         print(f"Loading model from specific version path: {model_path}")
-
-         self.tokenizer = AutoTokenizer.from_pretrained(
-             model_path,  # load from the correct version directory
-             trust_remote_code=True
-         )
-
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_path,  # load from the correct version directory
-             torch_dtype=torch.bfloat16,
-             device_map="gpu",
-             trust_remote_code=True
-         )
-
-         # --- END: CORRECTED PATH LOGIC ---
-
-         # ... (the rest of the code remains unchanged) ...
-         output_config = pb_utils.get_output_config_by_name(
-             self.model_config, "OUTPUT_TEXT"
-         )
-         self.output_dtype = pb_utils.triton_string_to_numpy(
-             output_config['data_type']
-         )
-
-         self.generation_config = {
-             'max_new_tokens': 512,
-             'temperature': 0.7,
-             'do_sample': True,
-             'top_p': 0.9,
-             'repetition_penalty': 1.1,
-             'pad_token_id': self.tokenizer.eos_token_id
-         }
-
-         print("Model loaded successfully!")
-
-     def execute(self, requests):
-         """Run inference"""
-         responses = []
-
-         for request in requests:
-             # Get the input text
-             input_text = pb_utils.get_input_tensor_by_name(
-                 request, "INPUT_TEXT"
-             ).as_numpy()
-
-             # Decode the input text
-             input_texts = [text.decode('utf-8') for text in input_text.flatten()]
-
-             # Batch inference
-             output_texts = []
-             for text in input_texts:
-                 try:
-                     # Encode the input
-                     inputs = self.tokenizer.encode(
-                         text,
-                         return_tensors="pt"
-                     ).to(self.model.device)
-
-                     # Generate a response
-                     with torch.no_grad():
-                         outputs = self.model.generate(
-                             inputs,
-                             **self.generation_config
-                         )
-
-                     # Decode the output
-                     response = self.tokenizer.decode(
-                         outputs[0][inputs.shape[-1]:],
-                         skip_special_tokens=True
-                     )
-
-                     output_texts.append(response)
-
-                 except Exception as e:
-                     print(f"Error processing text: {e}")
-                     output_texts.append(f"Error: {str(e)}")
-
-             # Prepare the output
-             output_texts_np = np.array(
-                 [[text.encode('utf-8')] for text in output_texts],
-                 dtype=object
-             )
-
-             output_tensor = pb_utils.Tensor(
-                 "OUTPUT_TEXT",
-                 output_texts_np.astype(self.output_dtype)
-             )
-
-             response = pb_utils.InferenceResponse(
-                 output_tensors=[output_tensor]
-             )
-             responses.append(response)
-
-         return responses
-
-     def finalize(self):
-         """Clean up resources"""
-         print("Cleaning up...")
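
The Triton Python backend above consumes an INPUT_TEXT string tensor and produces OUTPUT_TEXT. A hedged client-side sketch using the standard tritonclient HTTP API: the model name "deepseek_r1" follows the repository path in this diff, while the exact tensor shape depends on the config.pbtxt, which is not part of this wheel.

    import numpy as np
    import tritonclient.http as httpclient

    client = httpclient.InferenceServerClient(url="localhost:8000")

    # One UTF-8 string; BYTES is Triton's wire type for string tensors
    text = np.array([["What is 2 + 2?"]], dtype=object)
    inp = httpclient.InferInput("INPUT_TEXT", list(text.shape), "BYTES")
    inp.set_data_from_numpy(text)

    result = client.infer(model_name="deepseek_r1", inputs=[inp])
    print(result.as_numpy("OUTPUT_TEXT")[0][0].decode("utf-8"))
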
isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py
@@ -1,18 +0,0 @@
- from huggingface_hub import snapshot_download
- import os
-
- model_name = 'deepseek-ai/DeepSeek-R1-0528-Qwen3-8B'
- # Define this model's version path inside the Triton model repository
- local_model_path = os.path.join("models", "deepseek_r1", "1")
-
- print(f"Starting download of model '{model_name}' to '{local_model_path}'...")
-
- # Use snapshot_download to fetch the entire model repository
- # It downloads every file, including the .safetensors weight files
- snapshot_download(
-     repo_id=model_name,
-     local_dir=local_model_path,
-     local_dir_use_symlinks=False,
- )
-
- print("All model files downloaded!")
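
The download target models/deepseek_r1/1 mirrors Triton's model-repository layout of <model-name>/<version>/, which is why the backend's initialize() above joins model_repository and model_version to locate the weights; a config.pbtxt beside the version directory (not included in this wheel) would normally declare the INPUT_TEXT and OUTPUT_TEXT tensors it serves.
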
isa_model/training/engine/llama_factory/__init__.py
@@ -1,39 +0,0 @@
- """
- LlamaFactory engine for fine-tuning and reinforcement learning.
-
- This package provides interfaces for using LlamaFactory to:
- - Fine-tune models with various datasets
- - Perform reinforcement learning from human feedback (RLHF)
- - Support instruction tuning and preference optimization
- """
-
- from .config import (
-     LlamaFactoryConfig,
-     SFTConfig,
-     RLConfig,
-     DPOConfig,
-     TrainingStrategy,
-     DatasetFormat,
-     create_default_config
- )
- from .trainer import LlamaFactoryTrainer
- from .rl import LlamaFactoryRL
- from .factory import LlamaFactory
- from .data_adapter import DataAdapter, AlpacaAdapter, ShareGPTAdapter, DataAdapterFactory
-
- __all__ = [
-     "LlamaFactoryTrainer",
-     "LlamaFactoryRL",
-     "LlamaFactoryConfig",
-     "SFTConfig",
-     "RLConfig",
-     "DPOConfig",
-     "TrainingStrategy",
-     "DatasetFormat",
-     "create_default_config",
-     "LlamaFactory",
-     "DataAdapter",
-     "AlpacaAdapter",
-     "ShareGPTAdapter",
-     "DataAdapterFactory"
- ]