cortex-llm 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex/__init__.py +73 -0
- cortex/__main__.py +83 -0
- cortex/config.py +329 -0
- cortex/conversation_manager.py +468 -0
- cortex/fine_tuning/__init__.py +8 -0
- cortex/fine_tuning/dataset.py +332 -0
- cortex/fine_tuning/mlx_lora_trainer.py +502 -0
- cortex/fine_tuning/trainer.py +957 -0
- cortex/fine_tuning/wizard.py +707 -0
- cortex/gpu_validator.py +467 -0
- cortex/inference_engine.py +727 -0
- cortex/metal/__init__.py +275 -0
- cortex/metal/gpu_validator.py +177 -0
- cortex/metal/memory_pool.py +886 -0
- cortex/metal/mlx_accelerator.py +678 -0
- cortex/metal/mlx_converter.py +638 -0
- cortex/metal/mps_optimizer.py +417 -0
- cortex/metal/optimizer.py +665 -0
- cortex/metal/performance_profiler.py +364 -0
- cortex/model_downloader.py +130 -0
- cortex/model_manager.py +2187 -0
- cortex/quantization/__init__.py +5 -0
- cortex/quantization/dynamic_quantizer.py +736 -0
- cortex/template_registry/__init__.py +15 -0
- cortex/template_registry/auto_detector.py +144 -0
- cortex/template_registry/config_manager.py +234 -0
- cortex/template_registry/interactive.py +260 -0
- cortex/template_registry/registry.py +347 -0
- cortex/template_registry/template_profiles/__init__.py +5 -0
- cortex/template_registry/template_profiles/base.py +142 -0
- cortex/template_registry/template_profiles/complex/__init__.py +5 -0
- cortex/template_registry/template_profiles/complex/reasoning.py +263 -0
- cortex/template_registry/template_profiles/standard/__init__.py +9 -0
- cortex/template_registry/template_profiles/standard/alpaca.py +73 -0
- cortex/template_registry/template_profiles/standard/chatml.py +82 -0
- cortex/template_registry/template_profiles/standard/gemma.py +103 -0
- cortex/template_registry/template_profiles/standard/llama.py +87 -0
- cortex/template_registry/template_profiles/standard/simple.py +65 -0
- cortex/ui/__init__.py +120 -0
- cortex/ui/cli.py +1685 -0
- cortex/ui/markdown_render.py +185 -0
- cortex/ui/terminal_app.py +534 -0
- cortex_llm-1.0.0.dist-info/METADATA +275 -0
- cortex_llm-1.0.0.dist-info/RECORD +48 -0
- cortex_llm-1.0.0.dist-info/WHEEL +5 -0
- cortex_llm-1.0.0.dist-info/entry_points.txt +2 -0
- cortex_llm-1.0.0.dist-info/licenses/LICENSE +21 -0
- cortex_llm-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
"""MPS backend optimization for PyTorch models on Metal."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from typing import Dict, Any, Optional, List, Tuple, Callable
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
import functools
|
|
8
|
+
import warnings
|
|
9
|
+
import sys
|
|
10
|
+
import os
|
|
11
|
+
|
|
12
|
+
# Add parent directory to path for imports
|
|
13
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
14
|
+
from gpu_validator import GPUValidator
|
|
15
|
+
|
|
16
|
+
@dataclass
class MPSConfig:
    """Settings that select which MPS optimization passes are applied.

    ``use_fp16`` and ``use_bfloat16`` start as ``None`` ("undecided"); the
    optimizer resolves them from the hardware when ``auto_detect_dtype`` is
    set, and otherwise falls back to fp16.
    """

    # --- memory layout and numeric precision ---
    use_channels_last: bool = True
    use_fp16: Optional[bool] = None  # None -> decided from hardware
    use_bfloat16: Optional[bool] = None  # None -> decided from hardware

    # --- execution / graph passes ---
    use_jit: bool = False  # JIT is not fully supported on MPS yet
    use_graph_mode: bool = True
    fuse_operations: bool = True
    optimize_memory: bool = True

    # --- data loading ---
    max_batch_size: int = 8
    prefetch_factor: int = 2
    num_workers: int = 0  # MPS performs best loading on the main thread

    # --- dtype selection ---
    auto_detect_dtype: bool = True  # choose the best dtype for the hardware
|
|
30
|
+
|
|
31
|
+
class MPSOptimizer:
    """Optimize PyTorch models for the Metal Performance Shaders (MPS) backend.

    ``optimize_model`` moves a model to the ``mps`` device and then applies the
    passes enabled in the ``MPSConfig``:

    - channels_last layout;
    - reduced-precision dtype conversion;
    - simple module "fusion";
    - parameter contiguity;
    - a warm-up forward pass.
    """

    # Layer sequences recognised by ``_try_fuse_sequential``. Each entry lists
    # ``torch.nn`` class names; a Sequential matches when its *leading*
    # children are instances of these classes, in this order.
    FUSED_OPERATIONS = {
        "conv_bn_relu": ["Conv2d", "BatchNorm2d", "ReLU"],
        "linear_relu": ["Linear", "ReLU"],
        "linear_gelu": ["Linear", "GELU"],
        "layer_norm_linear": ["LayerNorm", "Linear"],
    }

    def __init__(self, config: "Optional[MPSConfig]" = None):
        """Initialize the optimizer and resolve the compute dtype.

        Args:
            config: Optimization settings. A default ``MPSConfig`` is created
                when omitted. The object is used as-is, so the resolved
                ``use_fp16`` / ``use_bfloat16`` values are written back into it.

        Raises:
            RuntimeError: If PyTorch was built without MPS support, or the MPS
                device is not available at runtime.
        """
        self.config = config if config is not None else MPSConfig()

        # Check the build first: a build without MPS also reports "not
        # available", which would otherwise hide the more specific message.
        if not torch.backends.mps.is_built():
            raise RuntimeError("PyTorch not built with MPS support")

        if not torch.backends.mps.is_available():
            raise RuntimeError("MPS backend not available")

        self.device = torch.device("mps")

        # Initialize GPU validator for hardware detection
        self.gpu_validator = GPUValidator()
        self.gpu_validator.validate()

        dtype_unset = self.config.use_bfloat16 is None and self.config.use_fp16 is None
        if dtype_unset and self.config.auto_detect_dtype:
            self._auto_detect_dtype()
        elif dtype_unset:
            # Default to fp16 if auto-detect is disabled but no dtype is specified
            self.config.use_bfloat16 = False
            self.config.use_fp16 = True

    def _auto_detect_dtype(self) -> None:
        """Set bfloat16 if the hardware and PyTorch both support it, else fp16."""
        use_bf16 = False
        if self.gpu_validator.check_bfloat16_support():
            try:
                # Probe PyTorch: allocating a bfloat16 tensor on MPS raises when
                # this PyTorch build lacks bfloat16 support on MPS.
                torch.tensor([1.0], dtype=torch.bfloat16, device=self.device)
                use_bf16 = True
            except (RuntimeError, TypeError):
                use_bf16 = False
        self.config.use_bfloat16 = use_bf16
        self.config.use_fp16 = not use_bf16  # Prefer bfloat16 over fp16

    def optimize_model(
        self,
        model: nn.Module,
        example_input: Optional[torch.Tensor] = None
    ) -> nn.Module:
        """
        Optimize a PyTorch model for MPS backend.

        Args:
            model: PyTorch model to optimize. It is modified in place.
            example_input: Example input used for a warm-up forward pass when
                graph mode is enabled; skipped if None.

        Returns:
            The optimized model, in eval mode.
        """
        model = model.to(self.device)

        if self.config.use_channels_last:
            model = self._convert_to_channels_last(model)

        # Convert to optimal dtype (bfloat16 or fp16)
        if self.config.use_bfloat16 or self.config.use_fp16:
            model = self._convert_dtype(model)

        if self.config.fuse_operations:
            model = self._fuse_operations(model)

        if self.config.optimize_memory:
            model = self._optimize_memory_layout(model)

        if self.config.use_graph_mode and example_input is not None:
            model = self._enable_graph_mode(model, example_input)

        model.eval()

        return model

    def _convert_to_channels_last(self, model: nn.Module) -> nn.Module:
        """Convert Conv2d/BatchNorm2d modules to the channels_last memory format."""
        def convert_layer(module):
            if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
                module = module.to(memory_format=torch.channels_last)
            return module

        model.apply(convert_layer)
        return model

    def _convert_dtype(self, model: nn.Module) -> nn.Module:
        """
        Convert model to optimal dtype (bfloat16 or fp16) for faster computation.

        A module is converted when it has a float32 ``weight``. Its own
        floating-point parameters and buffers are then cast to the target dtype.
        BatchNorm and LayerNorm modules are skipped so they stay in float32.

        Args:
            model: Model to convert (modified in place).

        Returns:
            Converted model
        """
        target_dtype = torch.bfloat16 if self.config.use_bfloat16 else torch.float16
        # These layers should stay in float32 for numerical stability
        keep_fp32 = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)

        for module in model.modules():
            if isinstance(module, keep_fp32):
                continue
            weight = getattr(module, "weight", None)
            if not isinstance(weight, torch.Tensor) or weight.dtype != torch.float32:
                continue
            # Only this module's own tensors (recurse=False).
            # Module.half()/.to() would recurse into children and also downcast
            # the normalization layers excluded above.
            for param in module.parameters(recurse=False):
                if param.is_floating_point():
                    # .data keeps the Parameter object (and any tied references).
                    param.data = param.data.to(target_dtype)
            for buf_name, buf in list(module.named_buffers(recurse=False)):
                if buf.is_floating_point():
                    setattr(module, buf_name, buf.to(target_dtype))

        return model

    def get_optimal_dtype(self) -> torch.dtype:
        """
        Get the optimal dtype for current hardware.

        Returns:
            torch.bfloat16, torch.float16, or torch.float32
        """
        if self.config.use_bfloat16:
            return torch.bfloat16
        elif self.config.use_fp16:
            return torch.float16
        else:
            return torch.float32

    def _fuse_operations(self, model: nn.Module) -> nn.Module:
        """Replace direct ``nn.Sequential`` children that start with a known pattern.

        Only the model's immediate children are inspected; there is no recursion.
        The "fused" modules are Python-level wrappers around the original layers.
        """
        # Note: Conv-BN fusion requires specific module pairs, not whole model
        # Skipping automatic fusion as it's model-specific
        for name, module in list(model.named_children()):
            if isinstance(module, nn.Sequential):
                fused = self._try_fuse_sequential(module)
                if fused is not None:
                    setattr(model, name, fused)

        return model

    def _try_fuse_sequential(self, sequential: nn.Sequential) -> Optional[nn.Module]:
        """Try to fuse the leading layers of a sequential module.

        Returns None when no pattern matches. Layers after the matched prefix
        are kept, so the replacement computes the same function as the input.
        """
        layers = list(sequential.children())

        if len(layers) < 2:
            return None

        for pattern_name, pattern in self.FUSED_OPERATIONS.items():
            if self._matches_pattern(layers, pattern):
                prefix_len = len(pattern)
                fused = self._create_fused_module(layers[:prefix_len], pattern_name)
                remaining = layers[prefix_len:]
                if not remaining:
                    return fused
                return nn.Sequential(fused, *remaining)

        return None

    def _matches_pattern(self, layers: List[nn.Module], pattern: List[str]) -> bool:
        """Check if the leading layers are instances of the pattern's nn classes."""
        if len(layers) < len(pattern):
            return False

        for layer, type_name in zip(layers, pattern):
            expected_class = getattr(nn, type_name, None)
            if expected_class is None or not isinstance(layer, expected_class):
                return False

        return True

    def _create_fused_module(self, layers: List[nn.Module], pattern_name: str) -> nn.Module:
        """Create a fused module from the layers matched by ``pattern_name``.

        The original activation module is reused rather than a fresh default,
        so settings such as ``GELU(approximate="tanh")`` are preserved.
        """
        if pattern_name == "conv_bn_relu":
            fused = ConvBNReLU(layers[0], layers[1])
            fused.relu = layers[2]
            return fused
        if pattern_name == "linear_relu":
            fused = LinearReLU(layers[0])
            fused.relu = layers[1]
            return fused
        if pattern_name == "linear_gelu":
            fused = LinearGELU(layers[0])
            fused.gelu = layers[1]
            return fused
        return nn.Sequential(*layers)

    def _optimize_memory_layout(self, model: nn.Module) -> nn.Module:
        """Make non-contiguous ``weight``/``bias`` tensors contiguous, in place.

        4-D/5-D tensors already in channels_last(_3d) layout are left alone, so
        this pass does not undo ``_convert_to_channels_last``.
        """
        def optimize_layer(module):
            for attr in ("weight", "bias"):
                tensor = getattr(module, attr, None)
                if not isinstance(tensor, torch.Tensor) or tensor.is_contiguous():
                    continue
                if tensor.dim() == 4 and tensor.is_contiguous(memory_format=torch.channels_last):
                    continue
                if tensor.dim() == 5 and tensor.is_contiguous(memory_format=torch.channels_last_3d):
                    continue
                # Assign through .data: assigning a plain Tensor to a Parameter
                # attribute raises TypeError, and .data keeps the Parameter object.
                tensor.data = tensor.data.contiguous()
            return module

        model.apply(optimize_layer)
        return model

    def _enable_graph_mode(
        self,
        model: nn.Module,
        example_input: torch.Tensor
    ) -> nn.Module:
        """Run one warm-up forward pass with ``example_input`` (experimental).

        Failures are reported with a printed warning and the model is returned
        unchanged.
        """
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                example_input = example_input.to(self.device)

                # Cast only floating inputs: integer inputs (e.g. token ids fed
                # to an embedding) must keep their integer dtype.
                if example_input.is_floating_point():
                    if self.config.use_bfloat16:
                        example_input = example_input.to(dtype=torch.bfloat16)
                    elif self.config.use_fp16:
                        example_input = example_input.half()

                with torch.no_grad():
                    _ = model(example_input)

        except Exception as e:
            print(f"Warning: Graph mode optimization failed: {e}")

        return model

    def optimize_dataloader(
        self,
        dataloader: torch.utils.data.DataLoader
    ) -> torch.utils.data.DataLoader:
        """Rebuild a DataLoader with MPS-friendly settings.

        The batch size is capped at ``config.max_batch_size``, the worker
        settings come from the config, and pinned memory is disabled. The
        original sampler, ``collate_fn`` and ``drop_last`` are kept. A loader
        with ``batch_size=None`` (custom ``batch_sampler`` or batching disabled)
        is returned unchanged, because it cannot be re-batched safely.
        """
        data = torch.utils.data
        if dataloader.batch_size is None:
            return dataloader

        kwargs: Dict[str, Any] = {
            "batch_size": min(dataloader.batch_size, self.config.max_batch_size),
            "num_workers": self.config.num_workers,
            "collate_fn": dataloader.collate_fn,
            "drop_last": dataloader.drop_last,
            "pin_memory": False,  # Not needed for unified memory
        }
        if not isinstance(dataloader.dataset, data.IterableDataset):
            # DataLoader has no `shuffle` attribute; shuffling lives in its
            # sampler (RandomSampler), so reuse the sampler itself.
            kwargs["sampler"] = dataloader.sampler
        if self.config.num_workers > 0:
            # Only meaningful with worker processes. PyTorch < 2.0 raises if
            # prefetch_factor is passed explicitly (even as None) with 0 workers.
            kwargs["prefetch_factor"] = self.config.prefetch_factor

        return data.DataLoader(dataloader.dataset, **kwargs)

    def profile_model(
        self,
        model: nn.Module,
        input_shape: Tuple[int, ...],
        num_iterations: int = 100
    ) -> Dict[str, Any]:
        """Time inference of ``model`` on random input on MPS.

        Args:
            model: Model to profile. It is put in eval mode and is expected to
                already live on the MPS device with the configured dtype.
            input_shape: Shape of the random input; ``input_shape[0]`` is treated
                as the batch size for the throughput figure.
            num_iterations: Number of timed forward passes (after 10 warm-ups).

        Returns:
            Dict with average time per pass (seconds), throughput (samples/s),
            allocated memory, and dtype/batch details.

        Raises:
            ValueError: If ``num_iterations`` is not positive.
        """
        import time

        if num_iterations <= 0:
            raise ValueError("num_iterations must be a positive integer")

        model.eval()

        dummy_input = torch.randn(input_shape).to(self.device)

        # Convert to optimal dtype
        if self.config.use_bfloat16:
            dummy_input = dummy_input.to(dtype=torch.bfloat16)
        elif self.config.use_fp16:
            dummy_input = dummy_input.half()

        torch.mps.synchronize()

        warmup_iterations = 10
        with torch.no_grad():
            for _ in range(warmup_iterations):
                _ = model(dummy_input)

        torch.mps.synchronize()

        start_time = time.perf_counter()
        with torch.no_grad():
            for _ in range(num_iterations):
                _ = model(dummy_input)

        # Wait for queued GPU work before stopping the clock.
        torch.mps.synchronize()
        end_time = time.perf_counter()

        avg_time = (end_time - start_time) / num_iterations
        throughput = input_shape[0] / avg_time if avg_time > 0 else 0

        memory_allocated = torch.mps.current_allocated_memory() if hasattr(torch.mps, 'current_allocated_memory') else 0

        return {
            "avg_inference_time": avg_time,
            "throughput": throughput,
            "memory_allocated": memory_allocated,
            "device": "mps",
            "dtype": str(self.get_optimal_dtype()),
            "fp16": self.config.use_fp16,
            "bfloat16": self.config.use_bfloat16,
            "batch_size": input_shape[0]
        }

    @staticmethod
    def get_mps_info() -> Dict[str, Any]:
        """Get MPS backend availability, memory counters and bfloat16 support."""
        info = {
            "available": torch.backends.mps.is_available(),
            "built": torch.backends.mps.is_built()
        }

        if info["available"]:
            info["current_allocated_memory"] = torch.mps.current_allocated_memory() if hasattr(torch.mps, 'current_allocated_memory') else 0
            info["driver_allocated_memory"] = torch.mps.driver_allocated_memory() if hasattr(torch.mps, 'driver_allocated_memory') else 0

            # Check bfloat16 support by attempting an allocation on MPS
            try:
                torch.tensor([1.0], dtype=torch.bfloat16, device="mps")
                info["bfloat16_supported"] = True
            except (RuntimeError, TypeError):
                info["bfloat16_supported"] = False

        return info

    def get_optimization_summary(self) -> Dict[str, Any]:
        """
        Get summary of optimizations applied.

        Returns:
            Dictionary with optimization details
        """
        optimal_dtype = self.get_optimal_dtype()
        is_half_precision = optimal_dtype in [torch.float16, torch.bfloat16]

        return {
            "device": str(self.device),
            "optimal_dtype": str(optimal_dtype),
            "dtype_bits": 16 if is_half_precision else 32,
            "memory_reduction": "50%" if is_half_precision else "0%",
            "hardware_features": {
                "bfloat16": self.config.use_bfloat16,
                "channels_last": self.config.use_channels_last,
                "graph_mode": self.config.use_graph_mode,
                "fused_operations": self.config.fuse_operations,
                "memory_optimized": self.config.optimize_memory
            },
            "gpu_family": self.gpu_validator.gpu_info.gpu_family if self.gpu_validator.gpu_info else "unknown",
            "expected_speedup": "10-15%" if self.config.use_bfloat16 else "baseline"
        }
|
|
384
|
+
|
|
385
|
+
class ConvBNReLU(nn.Module):
    """Apply Conv2d, then BatchNorm2d, then an in-place ReLU.

    Computes ``relu(bn(conv(x)))`` using the given modules. This composes the
    layers in one module; it does not fold the BatchNorm into the convolution.
    """

    def __init__(self, conv: nn.Conv2d, bn: nn.BatchNorm2d):
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.bn(features)
        return self.relu(normalized)
|
|
396
|
+
|
|
397
|
+
class LinearReLU(nn.Module):
    """Apply a Linear layer followed by an in-place ReLU: ``relu(linear(x))``."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        projected = self.linear(x)
        return self.relu(projected)
|
|
407
|
+
|
|
408
|
+
class LinearGELU(nn.Module):
    """Apply a Linear layer followed by a default ``nn.GELU``: ``gelu(linear(x))``."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.gelu = nn.GELU()

    def forward(self, x):
        projected = self.linear(x)
        return self.gelu(projected)
|