cortex-llm 1.0.0 (cortex_llm-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. cortex/__init__.py +73 -0
  2. cortex/__main__.py +83 -0
  3. cortex/config.py +329 -0
  4. cortex/conversation_manager.py +468 -0
  5. cortex/fine_tuning/__init__.py +8 -0
  6. cortex/fine_tuning/dataset.py +332 -0
  7. cortex/fine_tuning/mlx_lora_trainer.py +502 -0
  8. cortex/fine_tuning/trainer.py +957 -0
  9. cortex/fine_tuning/wizard.py +707 -0
  10. cortex/gpu_validator.py +467 -0
  11. cortex/inference_engine.py +727 -0
  12. cortex/metal/__init__.py +275 -0
  13. cortex/metal/gpu_validator.py +177 -0
  14. cortex/metal/memory_pool.py +886 -0
  15. cortex/metal/mlx_accelerator.py +678 -0
  16. cortex/metal/mlx_converter.py +638 -0
  17. cortex/metal/mps_optimizer.py +417 -0
  18. cortex/metal/optimizer.py +665 -0
  19. cortex/metal/performance_profiler.py +364 -0
  20. cortex/model_downloader.py +130 -0
  21. cortex/model_manager.py +2187 -0
  22. cortex/quantization/__init__.py +5 -0
  23. cortex/quantization/dynamic_quantizer.py +736 -0
  24. cortex/template_registry/__init__.py +15 -0
  25. cortex/template_registry/auto_detector.py +144 -0
  26. cortex/template_registry/config_manager.py +234 -0
  27. cortex/template_registry/interactive.py +260 -0
  28. cortex/template_registry/registry.py +347 -0
  29. cortex/template_registry/template_profiles/__init__.py +5 -0
  30. cortex/template_registry/template_profiles/base.py +142 -0
  31. cortex/template_registry/template_profiles/complex/__init__.py +5 -0
  32. cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
  33. cortex/template_registry/template_profiles/standard/__init__.py +9 -0
  34. cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
  35. cortex/template_registry/template_profiles/standard/chatml.py +82 -0
  36. cortex/template_registry/template_profiles/standard/gemma.py +103 -0
  37. cortex/template_registry/template_profiles/standard/llama.py +87 -0
  38. cortex/template_registry/template_profiles/standard/simple.py +65 -0
  39. cortex/ui/__init__.py +120 -0
  40. cortex/ui/cli.py +1685 -0
  41. cortex/ui/markdown_render.py +185 -0
  42. cortex/ui/terminal_app.py +534 -0
  43. cortex_llm-1.0.0.dist-info/METADATA +275 -0
  44. cortex_llm-1.0.0.dist-info/RECORD +48 -0
  45. cortex_llm-1.0.0.dist-info/WHEEL +5 -0
  46. cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
  47. cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
  48. cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
cortex/metal/optimizer.py
@@ -0,0 +1,665 @@
+ """Unified Metal optimization interface for Apple Silicon LLM inference.
+
+ This module provides a simple, effective interface for accelerating LLM inference
+ on Apple Silicon using the most appropriate backend (MLX or MPS).
+ """
+
+ import os
+ import sys
+ from typing import Dict, Any, Optional, Union, Tuple, Callable
+ from dataclasses import dataclass
+ from enum import Enum
+ import logging
+ import warnings
+
+ import torch
+ import numpy as np
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Add parent directory to path for imports
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from gpu_validator import GPUValidator
+
+
+ class Backend(Enum):
+     """Available acceleration backends."""
+     MLX = "mlx"
+     MPS = "mps"
+     CPU = "cpu"
+
+
+ @dataclass
+ class OptimizationConfig:
+     """Configuration for Metal optimization."""
+     backend: Backend = Backend.MLX
+     dtype: str = "auto" # auto, float32, float16, bfloat16
+     batch_size: int = 1
+     max_memory_gb: Optional[float] = None # None = auto-detect
+     use_quantization: bool = True
+     quantization_bits: int = 8 # 4 or 8
+     compile_model: bool = True
+     use_kv_cache: bool = True
+     kv_cache_size: int = 2048
+     enable_profiling: bool = False
+     fallback_to_cpu: bool = True
+
+     def validate(self) -> bool:
+         """Validate configuration settings."""
+         if self.quantization_bits not in [4, 8]:
+             raise ValueError(f"Quantization bits must be 4 or 8, got {self.quantization_bits}")
+
+         if self.batch_size < 1:
+             raise ValueError(f"Batch size must be >= 1, got {self.batch_size}")
+
+         if self.kv_cache_size < 0:
+             raise ValueError(f"KV cache size must be >= 0, got {self.kv_cache_size}")
+
+         return True
+
+
+ class MetalOptimizer:
+     """Unified Metal optimizer for LLM inference on Apple Silicon."""
+
+     def __init__(self, config: Optional[OptimizationConfig] = None):
+         """Initialize the Metal optimizer.
+
+         Args:
+             config: Optimization configuration. Uses defaults if None.
+         """
+         self.config = config or OptimizationConfig()
+         self.config.validate()
+
+         # Initialize GPU validator
+         self.gpu_validator = GPUValidator()
+         self.gpu_validator.validate()
+
+         # Detect best backend
+         self.backend = self._select_backend()
+         self.device = self._get_device()
+
+         # Initialize backend-specific components
+         self._backend_optimizer = None
+         self._initialize_backend()
+
+         logger.info(f"MetalOptimizer initialized with backend: {self.backend.value}")
+         logger.info(f"GPU: {self.gpu_validator.get_gpu_family()}, "
+                     f"bfloat16: {self.gpu_validator.check_bfloat16_support()}")
+
+     def _select_backend(self) -> Backend:
+         """Select the best available backend.
+
+         Returns:
+             Selected backend based on availability and configuration.
+         """
+         if self.config.backend != Backend.MLX:
+             # User explicitly selected a backend
+             return self._validate_backend(self.config.backend)
+
+         # Auto-select best backend
+         try:
+             import mlx.core as mx
+             # MLX is available and preferred
+             return Backend.MLX
+         except ImportError:
+             logger.warning("MLX not available, falling back to MPS")
+
+         if torch.backends.mps.is_available():
+             return Backend.MPS
+
+         if self.config.fallback_to_cpu:
+             logger.warning("No GPU acceleration available, using CPU")
+             return Backend.CPU
+
+         raise RuntimeError("No suitable backend available and CPU fallback disabled")
+
+     def _validate_backend(self, backend: Backend) -> Backend:
+         """Validate that the requested backend is available.
+
+         Args:
+             backend: Requested backend
+
+         Returns:
+             The backend if available
+
+         Raises:
+             RuntimeError: If backend not available
+         """
+         if backend == Backend.MLX:
+             try:
+                 import mlx.core as mx
+                 return backend
+             except ImportError:
+                 raise RuntimeError("MLX backend requested but not installed")
+
+         elif backend == Backend.MPS:
+             if not torch.backends.mps.is_available():
+                 raise RuntimeError("MPS backend requested but not available")
+             return backend
+
+         elif backend == Backend.CPU:
+             return backend
+
+         raise ValueError(f"Unknown backend: {backend}")
+
+     def _get_device(self) -> Union[str, torch.device]:
+         """Get the appropriate device for the selected backend.
+
+         Returns:
+             Device object or string
+         """
+         if self.backend == Backend.MPS:
+             return torch.device("mps")
+         elif self.backend == Backend.CPU:
+             return torch.device("cpu")
+         else:
+             return "gpu" # MLX uses string device names
+
+     def _initialize_backend(self) -> None:
+         """Initialize the backend-specific optimizer."""
+         if self.backend == Backend.MLX:
+             self._initialize_mlx()
+         elif self.backend == Backend.MPS:
+             self._initialize_mps()
+         else:
+             # CPU backend needs no special initialization
+             pass
+
+     def _initialize_mlx(self) -> None:
+         """Initialize MLX backend."""
+         try:
+             from mlx_accelerator import MLXAccelerator, MLXConfig
+
+             mlx_config = MLXConfig(
+                 compile_model=self.config.compile_model,
+                 batch_size=self.config.batch_size,
+                 dtype=self._get_mlx_dtype()
+             )
+
+             self._backend_optimizer = MLXAccelerator(mlx_config)
+
+         except ImportError as e:
+             logger.error(f"Failed to initialize MLX backend: {e}")
+             raise RuntimeError("MLX initialization failed")
+
+     def _initialize_mps(self) -> None:
+         """Initialize MPS backend."""
+         try:
+             from mps_optimizer import MPSOptimizer, MPSConfig
+
+             mps_config = MPSConfig(
+                 use_fp16=(self.config.dtype in ["float16", "auto"]),
+                 use_bfloat16=self._should_use_bfloat16(),
+                 max_batch_size=self.config.batch_size,
+                 optimize_memory=True
+             )
+
+             self._backend_optimizer = MPSOptimizer(mps_config)
+
+         except ImportError as e:
+             logger.error(f"Failed to initialize MPS backend: {e}")
+             raise RuntimeError("MPS initialization failed")
+
+     def _get_mlx_dtype(self):
+         """Get MLX dtype based on configuration and hardware."""
+         if self.config.dtype == "auto":
+             # Auto-select based on hardware
+             if self.gpu_validator.check_bfloat16_support():
+                 import mlx.core as mx
+                 return mx.bfloat16
+             else:
+                 import mlx.core as mx
+                 return mx.float16
+
+         elif self.config.dtype == "float32":
+             import mlx.core as mx
+             return mx.float32
+         elif self.config.dtype == "float16":
+             import mlx.core as mx
+             return mx.float16
+         elif self.config.dtype == "bfloat16":
+             import mlx.core as mx
+             return mx.bfloat16
+         else:
+             raise ValueError(f"Unknown dtype: {self.config.dtype}")
+
+     def _should_use_bfloat16(self) -> bool:
+         """Determine if bfloat16 should be used."""
+         if self.config.dtype == "bfloat16":
+             return True
+
+         if self.config.dtype == "auto":
+             return self.gpu_validator.check_bfloat16_support()
+
+         return False
+
+     def optimize_model(
+         self,
+         model: Any,
+         model_type: str = "auto"
+     ) -> Tuple[Any, Dict[str, Any]]:
+         """Optimize a model for inference.
+
+         Args:
+             model: Model to optimize (PyTorch, MLX, or Hugging Face)
+             model_type: Type of model ("pytorch", "mlx", "transformers", "auto")
+
+         Returns:
+             Tuple of (optimized_model, optimization_info)
+         """
+         if model_type == "auto":
+             model_type = self._detect_model_type(model)
+
+         logger.info(f"Optimizing {model_type} model with {self.backend.value} backend")
+
+         optimization_info = {
+             "backend": self.backend.value,
+             "device": str(self.device),
+             "dtype": self.config.dtype,
+             "quantization": self.config.use_quantization,
+             "quantization_bits": self.config.quantization_bits if self.config.use_quantization else None,
+             "gpu_family": self.gpu_validator.get_gpu_family(),
+             "optimizations_applied": []
+         }
+
+         if self.backend == Backend.MLX:
+             optimized_model = self._optimize_mlx_model(model, model_type)
+             optimization_info["optimizations_applied"].extend([
+                 "mlx_compilation",
+                 "dtype_optimization",
+                 "memory_layout"
+             ])
+
+         elif self.backend == Backend.MPS:
+             optimized_model = self._optimize_mps_model(model, model_type)
+             optimization_info["optimizations_applied"].extend([
+                 "mps_optimization",
+                 "dtype_conversion",
+                 "memory_optimization"
+             ])
+
+         else:
+             # CPU - minimal optimization
+             optimized_model = model
+             optimization_info["optimizations_applied"].append("none")
+
+         # Apply quantization if requested
+         if self.config.use_quantization:
+             optimized_model = self._apply_quantization(
+                 optimized_model,
+                 self.config.quantization_bits
+             )
+             optimization_info["optimizations_applied"].append(
+                 f"int{self.config.quantization_bits}_quantization"
+             )
+
+         return optimized_model, optimization_info
+
+     def _detect_model_type(self, model: Any) -> str:
+         """Detect the type of model.
+
+         Args:
+             model: Model to detect
+
+         Returns:
+             Model type string
+         """
+         # Check for PyTorch model
+         if hasattr(model, "parameters") and hasattr(model, "forward"):
+             return "pytorch"
+
+         # Check for MLX model
+         if hasattr(model, "apply_to_parameters"):
+             return "mlx"
+
+         # Check for Hugging Face transformers
+         if hasattr(model, "config") and hasattr(model, "forward"):
+             return "transformers"
+
+         return "unknown"
+
+     def _optimize_mlx_model(self, model: Any, model_type: str) -> Any:
+         """Optimize model using MLX backend.
+
+         Args:
+             model: Model to optimize
+             model_type: Type of model
+
+         Returns:
+             Optimized model
+         """
+         if not self._backend_optimizer:
+             raise RuntimeError("MLX backend not initialized")
+
+         if model_type == "pytorch":
+             # Convert PyTorch model to MLX
+             logger.warning("PyTorch to MLX conversion not yet implemented")
+             return model
+
+         elif model_type == "mlx":
+             # Already MLX model, just optimize
+             return self._backend_optimizer.optimize_model(model)
+
+         elif model_type == "transformers":
+             # Convert Hugging Face model to MLX
+             logger.warning("Transformers to MLX conversion not yet implemented")
+             return model
+
+         return model
+
+     def _optimize_mps_model(self, model: Any, model_type: str) -> Any:
+         """Optimize model using MPS backend.
+
+         Args:
+             model: Model to optimize
+             model_type: Type of model
+
+         Returns:
+             Optimized model
+         """
+         if not self._backend_optimizer:
+             raise RuntimeError("MPS backend not initialized")
+
+         if model_type in ["pytorch", "transformers"]:
+             # MPS works with PyTorch models
+             return self._backend_optimizer.optimize_model(model)
+
+         elif model_type == "mlx":
+             # Cannot use MLX model with MPS
+             logger.error("Cannot use MLX model with MPS backend")
+             raise ValueError("MLX models not compatible with MPS backend")
+
+         return model
+
+     def _apply_quantization(self, model: Any, bits: int) -> Any:
+         """Apply quantization to the model.
+
+         Args:
+             model: Model to quantize
+             bits: Quantization bits (4 or 8)
+
+         Returns:
+             Quantized model
+         """
+         if self.backend == Backend.MLX:
+             # Use MLX quantization
+             if hasattr(self._backend_optimizer, "quantize_model"):
+                 return self._backend_optimizer.quantize_model(model, bits)
+
+         elif self.backend == Backend.MPS:
+             # Use custom quantization for PyTorch
+             try:
+                 from quantization.dynamic_quantizer import DynamicQuantizer, QuantizationConfig
+
+                 quantizer = DynamicQuantizer(
+                     QuantizationConfig(
+                         mode="int8" if bits == 8 else "int4",
+                         device=self.device
+                     )
+                 )
+
+                 quantized_model, _ = quantizer.quantize_model(model)
+                 return quantized_model
+
+             except ImportError:
+                 logger.warning("Quantization module not available")
+                 return model
+
+         return model
+
+     def create_inference_session(
+         self,
+         model: Any,
+         tokenizer: Optional[Any] = None
+     ) -> 'InferenceSession':
+         """Create an optimized inference session.
+
+         Args:
+             model: Model for inference
+             tokenizer: Optional tokenizer
+
+         Returns:
+             InferenceSession object
+         """
+         optimized_model, info = self.optimize_model(model)
+
+         return InferenceSession(
+             model=optimized_model,
+             tokenizer=tokenizer,
+             optimizer=self,
+             optimization_info=info
+         )
+
+     def get_memory_usage(self) -> Dict[str, float]:
+         """Get current memory usage statistics.
+
+         Returns:
+             Dictionary with memory statistics in GB
+         """
+         import psutil
+
+         vm = psutil.virtual_memory()
+         stats = {
+             "total_gb": vm.total / (1024**3),
+             "available_gb": vm.available / (1024**3),
+             "used_gb": vm.used / (1024**3),
+             "percent_used": vm.percent
+         }
+
+         if self.backend == Backend.MPS:
+             if hasattr(torch.mps, "current_allocated_memory"):
+                 stats["mps_allocated_gb"] = torch.mps.current_allocated_memory() / (1024**3)
+
+         return stats
+
+     def profile_inference(
+         self,
+         model: Any,
+         input_shape: Tuple[int, ...],
+         num_iterations: int = 100
+     ) -> Dict[str, Any]:
+         """Profile model inference performance.
+
+         Args:
+             model: Model to profile
+             input_shape: Shape of input tensor
+             num_iterations: Number of iterations for profiling
+
+         Returns:
+             Profiling results
+         """
+         if self.backend == Backend.MLX and self._backend_optimizer:
+             return self._backend_optimizer.profile_model(
+                 model, input_shape, num_iterations
+             )
+
+         elif self.backend == Backend.MPS and self._backend_optimizer:
+             return self._backend_optimizer.profile_model(
+                 model, input_shape, num_iterations
+             )
+
+         # Basic CPU profiling
+         import time
+
+         if self.backend == Backend.CPU:
+             dummy_input = torch.randn(input_shape)
+         else:
+             dummy_input = torch.randn(input_shape).to(self.device)
+
+         # Warmup
+         for _ in range(10):
+             with torch.no_grad():
+                 _ = model(dummy_input)
+
+         # Profile
+         start_time = time.perf_counter()
+         for _ in range(num_iterations):
+             with torch.no_grad():
+                 _ = model(dummy_input)
+
+         end_time = time.perf_counter()
+         avg_time = (end_time - start_time) / num_iterations
+
+         return {
+             "backend": self.backend.value,
+             "avg_inference_time": avg_time,
+             "throughput": input_shape[0] / avg_time,
+             "device": str(self.device),
+             "iterations": num_iterations
+         }
+
+     def cleanup(self) -> None:
+         """Clean up resources."""
+         if self.backend == Backend.MPS:
+             torch.mps.empty_cache()
+
+         if hasattr(self._backend_optimizer, "cleanup"):
+             self._backend_optimizer.cleanup()
+
+
+ class InferenceSession:
+     """Optimized inference session for LLM models."""
+
+     def __init__(
+         self,
+         model: Any,
+         tokenizer: Optional[Any],
+         optimizer: MetalOptimizer,
+         optimization_info: Dict[str, Any]
+     ):
+         """Initialize inference session.
+
+         Args:
+             model: Optimized model
+             tokenizer: Optional tokenizer
+             optimizer: MetalOptimizer instance
+             optimization_info: Optimization information
+         """
+         self.model = model
+         self.tokenizer = tokenizer
+         self.optimizer = optimizer
+         self.optimization_info = optimization_info
+         self.generation_config = {}
+
+     def generate(
+         self,
+         prompt: Union[str, torch.Tensor],
+         max_tokens: int = 100,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         **kwargs
+     ) -> Union[str, torch.Tensor]:
+         """Generate text from prompt.
+
+         Args:
+             prompt: Input prompt (string or tensor)
+             max_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             **kwargs: Additional generation parameters
+
+         Returns:
+             Generated text or tokens
+         """
+         # This is a simplified interface - actual implementation
+         # would depend on the model type and backend
+
+         if isinstance(prompt, str) and self.tokenizer:
+             # Encode prompt
+             input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+
+             if self.optimizer.backend == Backend.MPS:
+                 input_ids = input_ids.to(self.optimizer.device)
+         else:
+             input_ids = prompt
+
+         # Generate based on backend
+         if self.optimizer.backend == Backend.MLX:
+             # MLX generation (simplified)
+             output = self._generate_mlx(input_ids, max_tokens, temperature, top_p)
+         else:
+             # PyTorch generation
+             output = self._generate_pytorch(input_ids, max_tokens, temperature, top_p, **kwargs)
+
+         # Decode if tokenizer available
+         if self.tokenizer and not isinstance(output, str):
+             return self.tokenizer.decode(output[0], skip_special_tokens=True)
+
+         return output
+
+     def _generate_mlx(
+         self,
+         input_ids: Any,
+         max_tokens: int,
+         temperature: float,
+         top_p: float
+     ) -> Any:
+         """Generate using MLX backend."""
+         # Simplified MLX generation
+         # Actual implementation would use MLX-specific generation
+         return input_ids
+
+     def _generate_pytorch(
+         self,
+         input_ids: torch.Tensor,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         **kwargs
+     ) -> torch.Tensor:
+         """Generate using PyTorch backend."""
+         # Use model's generate method if available
+         if hasattr(self.model, "generate"):
+             return self.model.generate(
+                 input_ids,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=temperature > 0,
+                 **kwargs
+             )
+
+         # Simple generation loop (fallback)
+         generated = input_ids
+         for _ in range(max_tokens):
+             with torch.no_grad():
+                 outputs = self.model(generated)
+                 logits = outputs.logits if hasattr(outputs, "logits") else outputs
+                 next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
+                 generated = torch.cat([generated, next_token], dim=1)
+
+         return generated
+
+     def stream_generate(
+         self,
+         prompt: str,
+         max_tokens: int = 100,
+         callback: Optional[Callable] = None,
+         **kwargs
+     ):
+         """Stream generation token by token.
+
+         Args:
+             prompt: Input prompt
+             max_tokens: Maximum tokens to generate
+             callback: Optional callback for each token
+             **kwargs: Additional parameters
+
+         Yields:
+             Generated tokens
+         """
+         # This would implement streaming generation
+         # For now, just yield the final result
+         result = self.generate(prompt, max_tokens, **kwargs)
+         yield result
+
+     def get_info(self) -> Dict[str, Any]:
+         """Get session information.
+
+         Returns:
+             Session information dictionary
+         """
+         info = self.optimization_info.copy()
+         info["memory_usage"] = self.optimizer.get_memory_usage()
+         return info
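
For orientation, below is a minimal usage sketch of the API this file defines (OptimizationConfig, MetalOptimizer, InferenceSession). It is illustrative only and is not part of the published package: the import path cortex.metal.optimizer, the gpt2 model and tokenizer names, and the presence of the optional torch, transformers, mlx, and psutil dependencies are assumptions, and the GPU validation and backend selection shown above may behave differently on hardware other than Apple Silicon.

    # Illustrative sketch only; not shipped in the wheel. Assumes the module is
    # importable as cortex.metal.optimizer and that torch/transformers are
    # installed; "gpt2" is a placeholder model name.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from cortex.metal.optimizer import Backend, MetalOptimizer, OptimizationConfig

    # Prefer MLX and fall back to MPS or CPU; let dtype follow the hardware and
    # apply 8-bit weight quantization.
    config = OptimizationConfig(
        backend=Backend.MLX,
        dtype="auto",
        use_quantization=True,
        quantization_bits=8,
        fallback_to_cpu=True,
    )
    optimizer = MetalOptimizer(config)

    # Wrap a placeholder Hugging Face model in an optimized inference session.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    session = optimizer.create_inference_session(model, tokenizer)

    # Inspect which optimizations were applied, then generate text.
    print(session.get_info()["optimizations_applied"])
    print(session.generate("Explain Metal acceleration in one sentence.",
                           max_tokens=64, temperature=0.7))

    optimizer.cleanup()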