cortex-llm 1.0.0 (cortex_llm-1.0.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/inference_engine.py
@@ -0,0 +1,727 @@
"""GPU-only inference engine for Cortex."""

import sys
import time
import asyncio
import logging
from typing import Dict, Any, Optional, List, Generator, AsyncGenerator, Tuple
from dataclasses import dataclass
from enum import Enum
import threading
from queue import Queue
import numpy as np

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import torch
import mlx.core as mx
import mlx.nn as nn

# Import MLX LM functions safely
try:
    from mlx_lm import generate as mlx_generate, stream_generate as mlx_stream_generate
except ImportError:
    mlx_generate = None
    mlx_stream_generate = None

from cortex.config import Config
from cortex.model_manager import ModelManager, ModelFormat
from cortex.metal.memory_pool import MemoryPool, AllocationStrategy
from cortex.metal.mps_optimizer import MPSOptimizer, MPSConfig
from cortex.metal.mlx_accelerator import MLXAccelerator, MLXConfig
from cortex.metal.performance_profiler import PerformanceProfiler

class InferenceStatus(Enum):
    """Status of inference operation."""
    IDLE = "idle"
    LOADING = "loading"
    GENERATING = "generating"
    COMPLETED = "completed"
    ERROR = "error"
    CANCELLED = "cancelled"

@dataclass
class GenerationMetrics:
    """Metrics for generation performance."""
    tokens_generated: int
    time_elapsed: float
    tokens_per_second: float
    gpu_utilization: float
    memory_used_gb: float
    first_token_latency: float

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            'tokens_generated': self.tokens_generated,
            'time_elapsed': self.time_elapsed,
            'tokens_per_second': self.tokens_per_second,
            'gpu_utilization': self.gpu_utilization,
            'memory_used_gb': self.memory_used_gb,
            'first_token_latency': self.first_token_latency
        }

@dataclass
class GenerationRequest:
    """Request for text generation."""
    prompt: str
    max_tokens: int = 2048
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repetition_penalty: float = 1.1
    stop_sequences: Optional[List[str]] = None
    stream: bool = True
    seed: Optional[int] = None

    def __post_init__(self):
        if self.stop_sequences is None:
            self.stop_sequences = []

class InferenceEngine:
    """GPU-accelerated inference engine."""

    def __init__(self, config: Config, model_manager: ModelManager):
        """Initialize inference engine."""
        self.config = config
        self.model_manager = model_manager
        self.status = InferenceStatus.IDLE
        self.current_metrics: Optional[GenerationMetrics] = None
        self._cancel_event = threading.Event()
        self._generation_lock = threading.Lock()

        # Initialize Metal optimizations
        self.memory_pool: Optional[MemoryPool] = None
        self.mps_optimizer: Optional[MPSOptimizer] = None
        self.mlx_accelerator: Optional[MLXAccelerator] = None
        self.profiler = PerformanceProfiler(sample_interval=0.1)

        self._ensure_gpu_backend()
        self._initialize_metal_optimizations()

    def _ensure_gpu_backend(self) -> None:
        """Ensure GPU backend is available."""
        if not torch.backends.mps.is_available():
            print("❌ MPS backend not available. GPU acceleration required.")
            sys.exit(1)

        try:
            mx.default_device()
        except Exception as e:
            print(f"❌ MLX not available: {e}")
            print("GPU acceleration via MLX is required.")
            sys.exit(1)

    def _initialize_metal_optimizations(self) -> None:
        """Initialize Metal-specific optimizations."""
        # Initialize shared memory pool with auto-sizing
        if self.config.gpu.force_gpu and self.memory_pool is None:
            # Create a single shared memory pool to avoid duplication
            self.memory_pool = MemoryPool(
                pool_size=None,  # Will auto-size based on available memory
                strategy=AllocationStrategy.UNIFIED,
                device="mps" if torch.backends.mps.is_available() else "mlx",
                auto_size=True  # Enable auto-sizing
            )

            # Share the pool with model manager to avoid duplication
            if hasattr(self.model_manager, 'memory_pool') and self.model_manager.memory_pool is None:
                self.model_manager.memory_pool = self.memory_pool

        # Initialize MPS optimizer
        if torch.backends.mps.is_available():
            mps_config = MPSConfig(
                use_fp16=True,
                use_channels_last=True,
                optimize_memory=True,
                max_batch_size=self.config.performance.batch_size
            )
            self.mps_optimizer = MPSOptimizer(mps_config)

        # Initialize MLX accelerator with AMX and advanced features
        try:
            mlx_config = MLXConfig(
                compile_model=True,
                use_graph=True,
                batch_size=self.config.performance.batch_size,
                dtype=mx.bfloat16 if self._supports_bfloat16() else mx.float16,
                use_amx=True,
                fuse_operations=True,
                lazy_evaluation=True,
                rotating_kv_cache=True,
                kv_cache_size=self.config.model.context_length if hasattr(self.config.model, 'context_length') else 4096,
                quantization_bits=4
            )
            self.mlx_accelerator = MLXAccelerator(mlx_config)
            print("✓ MLX accelerator initialized with AMX support")
        except Exception as e:
            print(f"Warning: MLX accelerator initialization failed: {e}")
            self.mlx_accelerator = None

        # GPU acceleration handled by MLX and MPS backends

    def generate(
        self,
        request: GenerationRequest
    ) -> Generator[str, None, GenerationMetrics]:
        """
        Generate text using GPU-accelerated inference.

        Args:
            request: Generation request parameters

        Yields:
            Generated text tokens

        Returns:
            Generation metrics
        """
        with self._generation_lock:
            if self.status == InferenceStatus.GENERATING:
                raise RuntimeError("Generation already in progress")

            self.status = InferenceStatus.GENERATING
            self._cancel_event.clear()

        try:
            model_info = self.model_manager.get_current_model()
            if not model_info:
                raise RuntimeError("No model loaded")

            model = self.model_manager.model_cache.get(model_info.name)
            tokenizer = self.model_manager.tokenizers.get(model_info.name)

            if not model or not tokenizer:
                raise RuntimeError(f"Model '{model_info.name}' not properly loaded")

            if model_info.format == ModelFormat.MLX:
                yield from self._generate_mlx(model, tokenizer, request)
            elif model_info.format == ModelFormat.PYTORCH:
                yield from self._generate_pytorch(model, tokenizer, request)
            elif model_info.format == ModelFormat.SAFETENSORS:
                yield from self._generate_safetensors(model, tokenizer, request)
            elif model_info.format == ModelFormat.GGUF:
                yield from self._generate_gguf(model, tokenizer, request)
            else:
                raise RuntimeError(f"Unsupported format: {model_info.format}")

            return self.current_metrics

        except Exception as e:
            self.status = InferenceStatus.ERROR
            raise e
        finally:
            if self.status != InferenceStatus.CANCELLED:
                self.status = InferenceStatus.COMPLETED

    def _generate_mlx(
        self,
        model: Any,
        tokenizer: Any,
        request: GenerationRequest
    ) -> Generator[str, None, None]:
        """Generate using MLX model on GPU with Metal optimizations."""
        # Apply MLX optimizations if available
        if self.mlx_accelerator:
            logger.info("Applying MLX accelerator optimizations to model")
            model = self.mlx_accelerator.optimize_model(model)

        # Start profiling
        self.profiler.start_profiling("mlx_generation", {
            "model_type": "mlx",
            "max_tokens": request.max_tokens
        })

        start_time = time.time()
        tokens_generated = 0
        first_token_time = None
        last_metrics_update = time.time()

        try:
            # Use MLX accelerator's optimized generation if available
            if self.mlx_accelerator and request.stream:
                logger.info("Using MLX accelerator optimized generation with AMX")
                for token in self.mlx_accelerator.generate_optimized(
                    model,
                    tokenizer,
                    request.prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,
                    stream=True,
                    stop_sequences=request.stop_sequences
                ):
                    if self._cancel_event.is_set():
                        self.status = InferenceStatus.CANCELLED
                        break

                    if first_token_time is None:
                        first_token_time = time.time() - start_time

                    tokens_generated += 1

                    # Update metrics less frequently
                    current_time = time.time()
                    if current_time - last_metrics_update > 1.0 or tokens_generated % 50 == 0:
                        elapsed_time = current_time - start_time

                        self.current_metrics = GenerationMetrics(
                            tokens_generated=tokens_generated,
                            time_elapsed=elapsed_time,
                            tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
                            gpu_utilization=0.0,
                            memory_used_gb=0.0,
                            first_token_latency=first_token_time or 0
                        )
                        last_metrics_update = current_time

                    # Token is already a string from generate_optimized
                    yield token

                    if any(stop in token for stop in request.stop_sequences):
                        break
            elif mlx_stream_generate:
                # Fallback to standard MLX streaming generation
                logger.info("Using standard MLX generation")

                # Import sample_utils for creating sampler
                try:
                    from mlx_lm.sample_utils import make_sampler
                    # Create sampler with temperature and top_p
                    sampler = make_sampler(request.temperature, top_p=request.top_p)
                    logger.debug(f"Created sampler with temp={request.temperature}, top_p={request.top_p}")
                except ImportError:
                    sampler = None
                    logger.warning("mlx_lm.sample_utils not available, using default sampler")

                # Build generation kwargs
                generation_kwargs = {
                    'prompt': request.prompt,
                    'max_tokens': request.max_tokens,
                }

                if sampler is not None:
                    generation_kwargs['sampler'] = sampler

                if request.seed is not None and request.seed >= 0:
                    mx.random.seed(request.seed)

                # stream_generate yields one decoded segment per token
                # (a GenerationResponse object in newer mlx-lm releases)
                for response in mlx_stream_generate(
                    model,
                    tokenizer,
                    **generation_kwargs
                ):
                    if self._cancel_event.is_set():
                        self.status = InferenceStatus.CANCELLED
                        break

                    # Extract text from GenerationResponse
                    if hasattr(response, 'text'):
                        token = response.text
                    else:
                        token = str(response)

                    if first_token_time is None:
                        first_token_time = time.time() - start_time

                    tokens_generated += 1

                    # Update metrics less frequently to reduce overhead
                    # Only update every 50 tokens or 1 second for better performance
                    current_time = time.time()
                    if current_time - last_metrics_update > 1.0 or tokens_generated % 50 == 0:
                        elapsed_time = current_time - start_time

                        # Skip expensive GPU queries during generation for better performance
                        # These will be calculated once at the end
                        self.current_metrics = GenerationMetrics(
                            tokens_generated=tokens_generated,
                            time_elapsed=elapsed_time,
                            tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
                            gpu_utilization=0.0,  # Skip during generation
                            memory_used_gb=0.0,  # Skip during generation
                            first_token_latency=first_token_time or 0
                        )
                        last_metrics_update = current_time

                    yield token

                    if any(stop in token for stop in request.stop_sequences):
                        break
            else:
                # No MLX generation available
                logger.error("MLX generation functions not available")
                raise RuntimeError("MLX generation not available. Please install mlx-lm.")

            elapsed_time = time.time() - start_time

            # Stop profiling and get final results
            profile_result = self.profiler.stop_profiling()

            # Update final metrics
            self.current_metrics = GenerationMetrics(
                tokens_generated=tokens_generated,
                time_elapsed=elapsed_time,
                tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
                gpu_utilization=profile_result.gpu_utilization,
                memory_used_gb=profile_result.memory_used_mb / 1024,
                first_token_latency=first_token_time or 0
            )

        except Exception as e:
            self.status = InferenceStatus.ERROR
            self.profiler.stop_profiling()
            raise e

    def _generate_pytorch(
        self,
        model: Any,
        tokenizer: Any,
        request: GenerationRequest
    ) -> Generator[str, None, None]:
        """Generate using PyTorch model on MPS with Metal optimizations."""
        # Apply MPS optimizations if available
        if self.mps_optimizer:
            model = self.mps_optimizer.optimize_model(model)

        # Start profiling
        self.profiler.start_profiling("pytorch_generation", {
            "model_type": "pytorch",
            "max_tokens": request.max_tokens
        })

        start_time = time.time()
        tokens_generated = 0
        first_token_time = None
        last_metrics_update = time.time()

        try:
            device = torch.device("mps")

            inputs = tokenizer(request.prompt, return_tensors="pt").to(device)

            generation_config = {
                'max_new_tokens': request.max_tokens,
                'temperature': request.temperature,
                'top_p': request.top_p,
                'top_k': request.top_k,
                'repetition_penalty': request.repetition_penalty,
                'do_sample': request.temperature > 0,
                'pad_token_id': tokenizer.pad_token_id,
                'eos_token_id': tokenizer.eos_token_id,
            }

            if request.seed is not None and request.seed >= 0:
                torch.manual_seed(request.seed)

            with torch.no_grad():
                if request.stream:
                    from transformers import TextIteratorStreamer

                    streamer = TextIteratorStreamer(
                        tokenizer,
                        skip_prompt=True,
                        skip_special_tokens=True
                    )

                    generation_kwargs = dict(
                        inputs,
                        streamer=streamer,
                        **generation_config
                    )

                    thread = threading.Thread(
                        target=model.generate,
                        kwargs=generation_kwargs
                    )
                    thread.start()

                    for token in streamer:
                        if self._cancel_event.is_set():
                            self.status = InferenceStatus.CANCELLED
                            break

                        if first_token_time is None:
                            first_token_time = time.time() - start_time

                        tokens_generated += 1

                        # Update metrics less frequently to reduce overhead
                        # Only update every 50 tokens or 1 second for better performance
                        current_time = time.time()
                        if current_time - last_metrics_update > 1.0 or tokens_generated % 50 == 0:
                            elapsed_time = current_time - start_time

                            # Skip expensive GPU queries during generation for better performance
                            # These will be calculated once at the end
                            self.current_metrics = GenerationMetrics(
                                tokens_generated=tokens_generated,
                                time_elapsed=elapsed_time,
                                tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
                                gpu_utilization=0.0,  # Skip during generation
                                memory_used_gb=0.0,  # Skip during generation
                                first_token_latency=first_token_time or 0
                            )
                            last_metrics_update = current_time

                        yield token

                        if any(stop in token for stop in request.stop_sequences):
                            break

                    thread.join()

                else:
                    outputs = model.generate(
                        **inputs,
                        **generation_config
                    )

                    generated_text = tokenizer.decode(
                        outputs[0][inputs['input_ids'].shape[1]:],
                        skip_special_tokens=True
                    )

                    tokens_generated = outputs.shape[1] - inputs['input_ids'].shape[1]
                    first_token_time = (time.time() - start_time) / tokens_generated if tokens_generated > 0 else 0

                    yield generated_text

            elapsed_time = time.time() - start_time

            # Stop profiling and get final results
            profile_result = self.profiler.stop_profiling()

            # Update final metrics
            self.current_metrics = GenerationMetrics(
                tokens_generated=tokens_generated,
                time_elapsed=elapsed_time,
                tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
                gpu_utilization=profile_result.gpu_utilization,
                memory_used_gb=profile_result.memory_used_mb / 1024,
                first_token_latency=first_token_time or 0
            )

        except Exception as e:
            self.status = InferenceStatus.ERROR
            self.profiler.stop_profiling()
            raise e

    def _generate_safetensors(
        self,
        model: Any,
        tokenizer: Any,
        request: GenerationRequest
    ) -> Generator[str, None, None]:
        """Generate using SafeTensors model (loaded as PyTorch) on MPS."""
        # SafeTensors models are loaded as PyTorch models, so use the same generation logic
        yield from self._generate_pytorch(model, tokenizer, request)

    def _generate_gguf(
        self,
        model: Any,
        tokenizer: Any,
        request: GenerationRequest
    ) -> Generator[str, None, None]:
        """Generate using GGUF model with llama-cpp-python."""
        start_time = time.time()
        first_token_time = None
        tokens_generated = 0

        try:
            # GGUF models use llama-cpp-python which has its own generation method
            # The model is a Llama object from llama-cpp-python

            # Generate response using llama-cpp's native method
            response = model(
                request.prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                repeat_penalty=request.repetition_penalty,
                stream=request.stream
            )

            if request.stream:
                # Stream tokens
                for chunk in response:
                    if self._cancel_event.is_set():
                        break

                    if 'choices' in chunk and len(chunk['choices']) > 0:
                        token = chunk['choices'][0].get('text', '')
                        if token:
                            if first_token_time is None:
                                first_token_time = time.time()
                            tokens_generated += 1
                            yield token
            else:
                # Return full response
                if 'choices' in response and len(response['choices']) > 0:
                    text = response['choices'][0].get('text', '')
                    tokens_generated = len(text.split())  # Rough estimate
                    yield text

            # Calculate metrics
            end_time = time.time()
            time_elapsed = end_time - start_time

            self.current_metrics = GenerationMetrics(
                tokens_generated=tokens_generated,
                time_elapsed=time_elapsed,
                tokens_per_second=tokens_generated / time_elapsed if time_elapsed > 0 else 0,
                gpu_utilization=0.0,  # GGUF doesn't provide GPU metrics directly
                memory_used_gb=self.model_manager.get_memory_status().get('model_memory_gb', 0),
                first_token_latency=first_token_time - start_time if first_token_time else 0
            )

        except Exception as e:
            self.status = InferenceStatus.ERROR
            raise e

    async def generate_async(
        self,
        request: GenerationRequest
    ) -> AsyncGenerator[str, None]:
        """
        Async generator for text generation.

        Args:
            request: Generation request parameters

        Yields:
            Generated text tokens
        """
        loop = asyncio.get_event_loop()
        queue = Queue()

        def generate_worker():
            try:
                for token in self.generate(request):
                    queue.put(token)
            except Exception as e:
                queue.put(e)
            finally:
                queue.put(None)

        thread = threading.Thread(target=generate_worker)
        thread.start()

        while True:
            result = await loop.run_in_executor(None, queue.get)

            if result is None:
                break
            elif isinstance(result, Exception):
                raise result
            else:
                yield result

        thread.join()

    def _supports_bfloat16(self) -> bool:
        """Check if system supports bfloat16."""
        try:
            test = mx.array([1.0], dtype=mx.bfloat16)
            mx.eval(test)
            return True
        except Exception:
            return False

    def cancel_generation(self) -> None:
        """Cancel ongoing generation."""
        self._cancel_event.set()
        self.status = InferenceStatus.CANCELLED

    def _get_gpu_utilization(self) -> float:
        """Get current GPU utilization percentage."""
        try:
            import psutil
            process = psutil.Process()
            # Heuristic proxy: approximate GPU load from process CPU usage
            return min(process.cpu_percent() * 2, 100.0)
        except Exception:
            return 0.0

    def _get_memory_usage(self) -> float:
        """Get current GPU memory usage in GB."""
        try:
            if torch.backends.mps.is_available():
                allocated = torch.mps.current_allocated_memory()
                return allocated / (1024**3)
            return 0.0
        except Exception:
            return 0.0

    def get_status(self) -> Dict[str, Any]:
        """Get current inference status with MLX details."""
        status = {
            'status': self.status.value,
            'model': self.model_manager.current_model,
            'metrics': self.current_metrics.to_dict() if self.current_metrics else None,
            'gpu_memory_gb': self._get_memory_usage(),
            'gpu_utilization': self._get_gpu_utilization()
        }

        # Add MLX accelerator status
        if self.mlx_accelerator:
            status['mlx_accelerator'] = {
                'enabled': True,
                'amx': self.mlx_accelerator.config.use_amx,
                'operation_fusion': self.mlx_accelerator.config.fuse_operations,
                'lazy_evaluation': self.mlx_accelerator.config.lazy_evaluation,
                'kv_cache': self.mlx_accelerator.config.rotating_kv_cache,
                'kv_cache_size': self.mlx_accelerator.config.kv_cache_size,
                'quantization_bits': self.mlx_accelerator.config.quantization_bits
            }
        else:
            status['mlx_accelerator'] = {'enabled': False}

        return status

    def benchmark(
        self,
        prompt: str = "Once upon a time",
        num_tokens: int = 100
    ) -> GenerationMetrics:
        """
        Run a benchmark test.

        Args:
            prompt: Prompt to use for benchmark
            num_tokens: Number of tokens to generate

        Returns:
            Benchmark metrics
        """
        request = GenerationRequest(
            prompt=prompt,
            max_tokens=num_tokens,
            temperature=0.7,
            stream=True
        )

        tokens = []
        for token in self.generate(request):
            tokens.append(token)

        return self.current_metrics

    def warmup(self) -> None:
        """Warm up the GPU with a small generation."""
        try:
            request = GenerationRequest(
                prompt="Hello",
                max_tokens=1,
                temperature=0.0,
                stream=False
            )

            for _ in self.generate(request):
                pass

        except Exception as e:
            print(f"Warning: GPU warmup failed: {e}")
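
For orientation, here is a minimal usage sketch of the engine defined above. It is an illustration only: the Config() and ModelManager(config) constructors and the step that loads a model are assumptions (those APIs live in cortex/config.py and cortex/model_manager.py, which are not part of this hunk); GenerationRequest, generate(), generate_async(), warmup(), and current_metrics are taken from the file itself.

import asyncio

from cortex.config import Config
from cortex.inference_engine import InferenceEngine, GenerationRequest
from cortex.model_manager import ModelManager

# Assumed setup: build the config and model manager, and load a model through
# the model manager before constructing the engine (not shown in this diff).
config = Config()
manager = ModelManager(config)
# manager.load_model("<model-name>")  # hypothetical call, API not shown here

engine = InferenceEngine(config, manager)
engine.warmup()

# Synchronous streaming: generate() yields text chunks and updates current_metrics.
request = GenerationRequest(prompt="Explain unified memory on Apple Silicon.", max_tokens=256)
for chunk in engine.generate(request):
    print(chunk, end="", flush=True)
print()
print(engine.current_metrics.to_dict())

# Asynchronous streaming: generate_async() runs the same generator in a worker
# thread and hands tokens back to the event loop through a queue.
async def main() -> None:
    async for chunk in engine.generate_async(GenerationRequest(prompt="Hello")):
        print(chunk, end="", flush=True)

asyncio.run(main())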