cortex_llm-1.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/inference_engine.py
@@ -0,0 +1,727 @@
+"""GPU-only inference engine for Cortex."""
+
+import sys
+import time
+import asyncio
+import logging
+from typing import Dict, Any, Optional, List, Generator, AsyncGenerator, Tuple
+from dataclasses import dataclass
+from enum import Enum
+import threading
+from queue import Queue
+import numpy as np
+
+# Configure logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+import torch
+import mlx.core as mx
+import mlx.nn as nn
+
+# Import MLX LM functions safely
+try:
+    from mlx_lm import generate as mlx_generate, stream_generate as mlx_stream_generate
+except ImportError:
+    mlx_generate = None
+    mlx_stream_generate = None
+
+from cortex.config import Config
+from cortex.model_manager import ModelManager, ModelFormat
+from cortex.metal.memory_pool import MemoryPool, AllocationStrategy
+from cortex.metal.mps_optimizer import MPSOptimizer, MPSConfig
+from cortex.metal.mlx_accelerator import MLXAccelerator, MLXConfig
+from cortex.metal.performance_profiler import PerformanceProfiler
+
+class InferenceStatus(Enum):
+    """Status of inference operation."""
+    IDLE = "idle"
+    LOADING = "loading"
+    GENERATING = "generating"
+    COMPLETED = "completed"
+    ERROR = "error"
+    CANCELLED = "cancelled"
+
+@dataclass
+class GenerationMetrics:
+    """Metrics for generation performance."""
+    tokens_generated: int
+    time_elapsed: float
+    tokens_per_second: float
+    gpu_utilization: float
+    memory_used_gb: float
+    first_token_latency: float
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            'tokens_generated': self.tokens_generated,
+            'time_elapsed': self.time_elapsed,
+            'tokens_per_second': self.tokens_per_second,
+            'gpu_utilization': self.gpu_utilization,
+            'memory_used_gb': self.memory_used_gb,
+            'first_token_latency': self.first_token_latency
+        }
+
+@dataclass
+class GenerationRequest:
+    """Request for text generation."""
+    prompt: str
+    max_tokens: int = 2048
+    temperature: float = 0.7
+    top_p: float = 0.95
+    top_k: int = 40
+    repetition_penalty: float = 1.1
+    stop_sequences: List[str] = None
+    stream: bool = True
+    seed: Optional[int] = None
+
+    def __post_init__(self):
+        if self.stop_sequences is None:
+            self.stop_sequences = []
+
+class InferenceEngine:
+    """GPU-accelerated inference engine."""
+
+    def __init__(self, config: Config, model_manager: ModelManager):
+        """Initialize inference engine."""
+        self.config = config
+        self.model_manager = model_manager
+        self.status = InferenceStatus.IDLE
+        self.current_metrics: Optional[GenerationMetrics] = None
+        self._cancel_event = threading.Event()
+        self._generation_lock = threading.Lock()
+
+        # Initialize Metal optimizations
+        self.memory_pool: Optional[MemoryPool] = None
+        self.mps_optimizer: Optional[MPSOptimizer] = None
+        self.mlx_accelerator: Optional[MLXAccelerator] = None
+        self.profiler = PerformanceProfiler(sample_interval=0.1)
+
+        self._ensure_gpu_backend()
+        self._initialize_metal_optimizations()
+
+    def _ensure_gpu_backend(self) -> None:
+        """Ensure GPU backend is available."""
+        if not torch.backends.mps.is_available():
+            print("❌ MPS backend not available. GPU acceleration required.")
+            sys.exit(1)
+
+        try:
+            mx.default_device()
+        except Exception as e:
+            print(f"❌ MLX not available: {e}")
+            print("GPU acceleration via MLX is required.")
+            sys.exit(1)
+
+    def _initialize_metal_optimizations(self) -> None:
+        """Initialize Metal-specific optimizations."""
+        # Initialize shared memory pool with auto-sizing
+        if self.config.gpu.force_gpu and self.memory_pool is None:
+            # Create a single shared memory pool to avoid duplication
+            self.memory_pool = MemoryPool(
+                pool_size=None, # Will auto-size based on available memory
+                strategy=AllocationStrategy.UNIFIED,
+                device="mps" if torch.backends.mps.is_available() else "mlx",
+                auto_size=True # Enable auto-sizing
+            )
+
+            # Share the pool with model manager to avoid duplication
+            if hasattr(self.model_manager, 'memory_pool') and self.model_manager.memory_pool is None:
+                self.model_manager.memory_pool = self.memory_pool
+
+        # Initialize MPS optimizer
+        if torch.backends.mps.is_available():
+            mps_config = MPSConfig(
+                use_fp16=True,
+                use_channels_last=True,
+                optimize_memory=True,
+                max_batch_size=self.config.performance.batch_size
+            )
+            self.mps_optimizer = MPSOptimizer(mps_config)
+
+        # Initialize MLX accelerator with AMX and advanced features
+        try:
+            mlx_config = MLXConfig(
+                compile_model=True,
+                use_graph=True,
+                batch_size=self.config.performance.batch_size,
+                dtype=mx.bfloat16 if self._supports_bfloat16() else mx.float16,
+                use_amx=True,
+                fuse_operations=True,
+                lazy_evaluation=True,
+                rotating_kv_cache=True,
+                kv_cache_size=self.config.model.context_length if hasattr(self.config.model, 'context_length') else 4096,
+                quantization_bits=4
+            )
+            self.mlx_accelerator = MLXAccelerator(mlx_config)
+            print("✓ MLX accelerator initialized with AMX support")
+        except Exception as e:
+            print(f"Warning: MLX accelerator initialization failed: {e}")
+            self.mlx_accelerator = None
+
+        # GPU acceleration handled by MLX and MPS backends
+
+    def generate(
+        self,
+        request: GenerationRequest
+    ) -> Generator[str, None, GenerationMetrics]:
+        """
+        Generate text using GPU-accelerated inference.
+
+        Args:
+            request: Generation request parameters
+
+        Yields:
+            Generated text tokens
+
+        Returns:
+            Generation metrics
+        """
+        with self._generation_lock:
+            if self.status == InferenceStatus.GENERATING:
+                raise RuntimeError("Generation already in progress")
+
+            self.status = InferenceStatus.GENERATING
+            self._cancel_event.clear()
+
+            try:
+                model_info = self.model_manager.get_current_model()
+                if not model_info:
+                    raise RuntimeError("No model loaded")
+
+                model = self.model_manager.model_cache.get(model_info.name)
+                tokenizer = self.model_manager.tokenizers.get(model_info.name)
+
+                if not model or not tokenizer:
+                    raise RuntimeError(f"Model '{model_info.name}' not properly loaded")
+
+                if model_info.format == ModelFormat.MLX:
+                    yield from self._generate_mlx(model, tokenizer, request)
+                elif model_info.format == ModelFormat.PYTORCH:
+                    yield from self._generate_pytorch(model, tokenizer, request)
+                elif model_info.format == ModelFormat.SAFETENSORS:
+                    yield from self._generate_safetensors(model, tokenizer, request)
+                elif model_info.format == ModelFormat.GGUF:
+                    yield from self._generate_gguf(model, tokenizer, request)
+                else:
+                    raise RuntimeError(f"Unsupported format: {model_info.format}")
+
+                return self.current_metrics
+
+            except Exception as e:
+                self.status = InferenceStatus.ERROR
+                raise e
+            finally:
+                if self.status != InferenceStatus.CANCELLED:
+                    self.status = InferenceStatus.COMPLETED
+
+    def _generate_mlx(
+        self,
+        model: Any,
+        tokenizer: Any,
+        request: GenerationRequest
+    ) -> Generator[str, None, None]:
+        """Generate using MLX model on GPU with Metal optimizations."""
+        # Apply MLX optimizations if available
+        if self.mlx_accelerator:
+            logger.info("Applying MLX accelerator optimizations to model")
+            model = self.mlx_accelerator.optimize_model(model)
+
+        # Start profiling
+        self.profiler.start_profiling("mlx_generation", {
+            "model_type": "mlx",
+            "max_tokens": request.max_tokens
+        })
+
+        start_time = time.time()
+        tokens_generated = 0
+        first_token_time = None
+        last_metrics_update = time.time()
+
+        try:
+            # Use MLX accelerator's optimized generation if available
+            if self.mlx_accelerator and request.stream:
+                logger.info("Using MLX accelerator optimized generation with AMX")
+                for token in self.mlx_accelerator.generate_optimized(
+                    model,
+                    tokenizer,
+                    request.prompt,
+                    max_tokens=request.max_tokens,
+                    temperature=request.temperature,
+                    top_p=request.top_p,
+                    repetition_penalty=request.repetition_penalty,
+                    stream=True,
+                    stop_sequences=request.stop_sequences
+                ):
+                    if self._cancel_event.is_set():
+                        self.status = InferenceStatus.CANCELLED
+                        break
+
+                    if first_token_time is None:
+                        first_token_time = time.time() - start_time
+
+                    tokens_generated += 1
+
+                    # Update metrics less frequently
+                    current_time = time.time()
+                    if current_time - last_metrics_update > 1.0 or tokens_generated % 50 == 0:
+                        elapsed_time = current_time - start_time
+
+                        self.current_metrics = GenerationMetrics(
+                            tokens_generated=tokens_generated,
+                            time_elapsed=elapsed_time,
+                            tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
+                            gpu_utilization=0.0,
+                            memory_used_gb=0.0,
+                            first_token_latency=first_token_time or 0
+                        )
+                        last_metrics_update = current_time
+
+                    # Token is already a string from generate_optimized
+                    yield token
+
+                    if any(stop in token for stop in request.stop_sequences):
+                        break
+            elif mlx_generate:
+                # Fallback to standard MLX generation
+                logger.info("Using standard MLX generation")
+
+                # Import sample_utils for creating sampler
+                try:
+                    from mlx_lm.sample_utils import make_sampler
+                    # Create sampler with temperature and top_p
+                    sampler = make_sampler(request.temperature, top_p=request.top_p)
+                    logger.debug(f"Created sampler with temp={request.temperature}, top_p={request.top_p}")
+                except ImportError:
+                    sampler = None
+                    logger.warning("mlx_lm.sample_utils not available, using default sampler")
+
+                # Build generation kwargs
+                generation_kwargs = {
+                    'prompt': request.prompt,
+                    'max_tokens': request.max_tokens,
+                }
+
+                if sampler is not None:
+                    generation_kwargs['sampler'] = sampler
+
+                if request.seed is not None and request.seed >= 0:
+                    mx.random.seed(request.seed)
+
+                for response in mlx_generate(
+                    model,
+                    tokenizer,
+                    **generation_kwargs
+                ):
+                    if self._cancel_event.is_set():
+                        self.status = InferenceStatus.CANCELLED
+                        break
+
+                    # Extract text from GenerationResponse
+                    if hasattr(response, 'text'):
+                        token = response.text
+                    else:
+                        token = str(response)
+
+                    if first_token_time is None:
+                        first_token_time = time.time() - start_time
+
+                    tokens_generated += 1
+
+                    # Update metrics less frequently to reduce overhead
+                    # Only update every 50 tokens or 1 second for better performance
+                    current_time = time.time()
+                    if current_time - last_metrics_update > 1.0 or tokens_generated % 50 == 0:
+                        elapsed_time = current_time - start_time
+
+                        # Skip expensive GPU queries during generation for better performance
+                        # These will be calculated once at the end
+                        self.current_metrics = GenerationMetrics(
+                            tokens_generated=tokens_generated,
+                            time_elapsed=elapsed_time,
+                            tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
+                            gpu_utilization=0.0, # Skip during generation
+                            memory_used_gb=0.0, # Skip during generation
+                            first_token_latency=first_token_time or 0
+                        )
+                        last_metrics_update = current_time
+
+                    yield token
+
+                    if any(stop in token for stop in request.stop_sequences):
+                        break
+            else:
+                # No MLX generation available
+                logger.error("MLX generation functions not available")
+                raise RuntimeError("MLX generation not available. Please install mlx-lm.")
+
+            elapsed_time = time.time() - start_time
+
+            # Stop profiling and get final results
+            profile_result = self.profiler.stop_profiling()
+
+            # Update final metrics
+            self.current_metrics = GenerationMetrics(
+                tokens_generated=tokens_generated,
+                time_elapsed=elapsed_time,
+                tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
+                gpu_utilization=profile_result.gpu_utilization,
+                memory_used_gb=profile_result.memory_used_mb / 1024,
+                first_token_latency=first_token_time or 0
+            )
+
+        except Exception as e:
+            self.status = InferenceStatus.ERROR
+            self.profiler.stop_profiling()
+            raise e
+
+    def _generate_pytorch(
+        self,
+        model: Any,
+        tokenizer: Any,
+        request: GenerationRequest
+    ) -> Generator[str, None, None]:
+        """Generate using PyTorch model on MPS with Metal optimizations."""
+        # Apply MPS optimizations if available
+        if self.mps_optimizer:
+            model = self.mps_optimizer.optimize_model(model)
+
+        # Start profiling
+        self.profiler.start_profiling("pytorch_generation", {
+            "model_type": "pytorch",
+            "max_tokens": request.max_tokens
+        })
+
+        start_time = time.time()
+        tokens_generated = 0
+        first_token_time = None
+        last_metrics_update = time.time()
+
+        try:
+            device = torch.device("mps")
+
+            inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
+
+            generation_config = {
+                'max_new_tokens': request.max_tokens,
+                'temperature': request.temperature,
+                'top_p': request.top_p,
+                'top_k': request.top_k,
+                'repetition_penalty': request.repetition_penalty,
+                'do_sample': request.temperature > 0,
+                'pad_token_id': tokenizer.pad_token_id,
+                'eos_token_id': tokenizer.eos_token_id,
+            }
+
+            if request.seed is not None and request.seed >= 0:
+                torch.manual_seed(request.seed)
+
+            with torch.no_grad():
+                if request.stream:
+                    from transformers import TextIteratorStreamer
+
+                    streamer = TextIteratorStreamer(
+                        tokenizer,
+                        skip_prompt=True,
+                        skip_special_tokens=True
+                    )
+
+                    generation_kwargs = dict(
+                        inputs,
+                        streamer=streamer,
+                        **generation_config
+                    )
+
+                    thread = threading.Thread(
+                        target=model.generate,
+                        kwargs=generation_kwargs
+                    )
+                    thread.start()
+
+                    for token in streamer:
+                        if self._cancel_event.is_set():
+                            self.status = InferenceStatus.CANCELLED
+                            break
+
+                        if first_token_time is None:
+                            first_token_time = time.time() - start_time
+
+                        tokens_generated += 1
+
+                        # Update metrics less frequently to reduce overhead
+                        # Only update every 50 tokens or 1 second for better performance
+                        current_time = time.time()
+                        if current_time - last_metrics_update > 1.0 or tokens_generated % 50 == 0:
+                            elapsed_time = current_time - start_time
+
+                            # Skip expensive GPU queries during generation for better performance
+                            # These will be calculated once at the end
+                            self.current_metrics = GenerationMetrics(
+                                tokens_generated=tokens_generated,
+                                time_elapsed=elapsed_time,
+                                tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
+                                gpu_utilization=0.0, # Skip during generation
+                                memory_used_gb=0.0, # Skip during generation
+                                first_token_latency=first_token_time or 0
+                            )
+                            last_metrics_update = current_time
+
+                        yield token
+
+                        if any(stop in token for stop in request.stop_sequences):
+                            break
+
+                    thread.join()
+
+                else:
+                    outputs = model.generate(
+                        **inputs,
+                        **generation_config
+                    )
+
+                    generated_text = tokenizer.decode(
+                        outputs[0][inputs['input_ids'].shape[1]:],
+                        skip_special_tokens=True
+                    )
+
+                    tokens_generated = outputs.shape[1] - inputs['input_ids'].shape[1]
+                    first_token_time = (time.time() - start_time) / tokens_generated if tokens_generated > 0 else 0
+
+                    yield generated_text
+
+            elapsed_time = time.time() - start_time
+
+            # Stop profiling and get final results
+            profile_result = self.profiler.stop_profiling()
+
+            # Update final metrics
+            self.current_metrics = GenerationMetrics(
+                tokens_generated=tokens_generated,
+                time_elapsed=elapsed_time,
+                tokens_per_second=tokens_generated / elapsed_time if elapsed_time > 0 else 0,
+                gpu_utilization=profile_result.gpu_utilization,
+                memory_used_gb=profile_result.memory_used_mb / 1024,
+                first_token_latency=first_token_time or 0
+            )
+
+        except Exception as e:
+            self.status = InferenceStatus.ERROR
+            self.profiler.stop_profiling()
+            raise e
+
+    def _generate_safetensors(
+        self,
+        model: Any,
+        tokenizer: Any,
+        request: GenerationRequest
+    ) -> Generator[str, None, None]:
+        """Generate using SafeTensors model (loaded as PyTorch) on MPS."""
+        # SafeTensors models are loaded as PyTorch models, so use the same generation logic
+        yield from self._generate_pytorch(model, tokenizer, request)
+
+    def _generate_gguf(
+        self,
+        model: Any,
+        tokenizer: Any,
+        request: GenerationRequest
+    ) -> Generator[str, None, None]:
+        """Generate using GGUF model with llama-cpp-python."""
+        start_time = time.time()
+        first_token_time = None
+        tokens_generated = 0
+
+        try:
+            # GGUF models use llama-cpp-python which has its own generation method
+            # The model is a Llama object from llama-cpp-python
+
+            # Generate response using llama-cpp's native method
+            response = model(
+                request.prompt,
+                max_tokens=request.max_tokens,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                top_k=request.top_k,
+                repeat_penalty=request.repetition_penalty,
+                stream=request.stream
+            )
+
+            if request.stream:
+                # Stream tokens
+                for chunk in response:
+                    if self._cancel_event.is_set():
+                        break
+
+                    if 'choices' in chunk and len(chunk['choices']) > 0:
+                        token = chunk['choices'][0].get('text', '')
+                        if token:
+                            if first_token_time is None:
+                                first_token_time = time.time()
+                            tokens_generated += 1
+                            yield token
+            else:
+                # Return full response
+                if 'choices' in response and len(response['choices']) > 0:
+                    text = response['choices'][0].get('text', '')
+                    tokens_generated = len(text.split()) # Rough estimate
+                    yield text
+
+            # Calculate metrics
+            end_time = time.time()
+            time_elapsed = end_time - start_time
+
+            self.current_metrics = GenerationMetrics(
+                tokens_generated=tokens_generated,
+                time_elapsed=time_elapsed,
+                tokens_per_second=tokens_generated / time_elapsed if time_elapsed > 0 else 0,
+                gpu_utilization=0.0, # GGUF doesn't provide GPU metrics directly
+                memory_used_gb=self.model_manager.get_memory_status().get('model_memory_gb', 0),
+                first_token_latency=first_token_time - start_time if first_token_time else 0
+            )
+
+        except Exception as e:
+            self.status = InferenceStatus.ERROR
+            raise e
+
+    async def generate_async(
+        self,
+        request: GenerationRequest
+    ) -> AsyncGenerator[str, None]:
+        """
+        Async generator for text generation.
+
+        Args:
+            request: Generation request parameters
+
+        Yields:
+            Generated text tokens
+        """
+        loop = asyncio.get_event_loop()
+        queue = Queue()
+
+        def generate_worker():
+            try:
+                for token in self.generate(request):
+                    queue.put(token)
+            except Exception as e:
+                queue.put(e)
+            finally:
+                queue.put(None)
+
+        thread = threading.Thread(target=generate_worker)
+        thread.start()
+
+        while True:
+            result = await loop.run_in_executor(None, queue.get)
+
+            if result is None:
+                break
+            elif isinstance(result, Exception):
+                raise result
+            else:
+                yield result
+
+        thread.join()
+
+    def _supports_bfloat16(self) -> bool:
+        """Check if system supports bfloat16."""
+        try:
+            test = mx.array([1.0], dtype=mx.bfloat16)
+            mx.eval(test)
+            return True
+        except:
+            return False
+
+    def cancel_generation(self) -> None:
+        """Cancel ongoing generation."""
+        self._cancel_event.set()
+        self.status = InferenceStatus.CANCELLED
+
+    def _get_gpu_utilization(self) -> float:
+        """Get current GPU utilization percentage."""
+        try:
+            import psutil
+            process = psutil.Process()
+            return min(process.cpu_percent() * 2, 100.0)
+        except:
+            return 0.0
+
+    def _get_memory_usage(self) -> float:
+        """Get current GPU memory usage in GB."""
+        try:
+            if torch.backends.mps.is_available():
+                allocated = torch.mps.current_allocated_memory()
+                return allocated / (1024**3)
+            return 0.0
+        except:
+            return 0.0
+
+    def get_status(self) -> Dict[str, Any]:
+        """Get current inference status with MLX details."""
+        status = {
+            'status': self.status.value,
+            'model': self.model_manager.current_model,
+            'metrics': self.current_metrics.to_dict() if self.current_metrics else None,
+            'gpu_memory_gb': self._get_memory_usage(),
+            'gpu_utilization': self._get_gpu_utilization()
+        }
+
+        # Add MLX accelerator status
+        if self.mlx_accelerator:
+            status['mlx_accelerator'] = {
+                'enabled': True,
+                'amx': self.mlx_accelerator.config.use_amx,
+                'operation_fusion': self.mlx_accelerator.config.fuse_operations,
+                'lazy_evaluation': self.mlx_accelerator.config.lazy_evaluation,
+                'kv_cache': self.mlx_accelerator.config.rotating_kv_cache,
+                'kv_cache_size': self.mlx_accelerator.config.kv_cache_size,
+                'quantization_bits': self.mlx_accelerator.config.quantization_bits
+            }
+        else:
+            status['mlx_accelerator'] = {'enabled': False}
+
+        return status
+
+    def benchmark(
+        self,
+        prompt: str = "Once upon a time",
+        num_tokens: int = 100
+    ) -> GenerationMetrics:
+        """
+        Run a benchmark test.
+
+        Args:
+            prompt: Prompt to use for benchmark
+            num_tokens: Number of tokens to generate
+
+        Returns:
+            Benchmark metrics
+        """
+        request = GenerationRequest(
+            prompt=prompt,
+            max_tokens=num_tokens,
+            temperature=0.7,
+            stream=True
+        )
+
+        tokens = []
+        for token in self.generate(request):
+            tokens.append(token)
+
+        return self.current_metrics
+
+    def warmup(self) -> None:
+        """Warm up the GPU with a small generation."""
+        try:
+            request = GenerationRequest(
+                prompt="Hello",
+                max_tokens=1,
+                temperature=0.0,
+                stream=False
+            )
+
+            for _ in self.generate(request):
+                pass
+
+        except Exception as e:
+            print(f"Warning: GPU warmup failed: {e}")
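Usage sketch (not part of the package)

The short example below shows how the InferenceEngine defined in this file could be driven, based only on the classes and signatures visible in the diff above (Config, ModelManager, GenerationRequest, InferenceEngine.generate, warmup, current_metrics). It is an illustrative sketch: the Config() and ModelManager(config) constructors and the load_model(...) call are assumptions about the rest of the package, not APIs confirmed by this diff.

# Hypothetical usage sketch; only InferenceEngine, GenerationRequest,
# generate(), warmup() and current_metrics are taken from the diff above.
# Config(), ModelManager(config) and load_model() are assumed APIs.
from cortex.config import Config
from cortex.model_manager import ModelManager
from cortex.inference_engine import InferenceEngine, GenerationRequest

config = Config()                           # assumed default constructor
manager = ModelManager(config)              # assumed constructor signature
manager.load_model("example-model")         # assumed model-loading API

engine = InferenceEngine(config, manager)
engine.warmup()                             # one-token generation to prime the GPU

request = GenerationRequest(
    prompt="Explain unified memory on Apple Silicon.",
    max_tokens=256,
    temperature=0.7,
    stream=True,
)

# generate() yields text chunks and updates current_metrics as it runs
for chunk in engine.generate(request):
    print(chunk, end="", flush=True)

print()
print(engine.current_metrics.to_dict())     # tokens/sec, first-token latency, etc.

The async path (generate_async) wraps the same generator in a worker thread and a Queue, so an event loop can consume the identical GenerationRequest with "async for" instead.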