textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +53 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +797 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.1.dist-info/METADATA +109 -0
- textpolicy-0.1.1.dist-info/RECORD +66 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
- textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
# textpolicy/generation/lora.py
|
|
2
|
+
"""
|
|
3
|
+
Pure LoRA/QLoRA functions for MLX-LM integration.
|
|
4
|
+
|
|
5
|
+
Following TextPolicy design principles:
|
|
6
|
+
- Pure function composition
|
|
7
|
+
- Zero abstraction cost
|
|
8
|
+
- MLX compilation optimization
|
|
9
|
+
- Memory-efficient training
|
|
10
|
+
|
|
11
|
+
These functions integrate with our GRPO trainer for efficient
|
|
12
|
+
parameter updates using LoRA adapters.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import Dict, Tuple, Any
|
|
16
|
+
import mlx.core as mx # type: ignore
|
|
17
|
+
import mlx.nn as nn # type: ignore
|
|
18
|
+
|
|
19
|
+
# Import LoRA from MLX-LM
|
|
20
|
+
try:
|
|
21
|
+
from mlx_lm.lora import LoRALinear # type: ignore
|
|
22
|
+
except ImportError:
|
|
23
|
+
try:
|
|
24
|
+
from mlx_lm.tuner.lora import LoRALinear
|
|
25
|
+
except ImportError:
|
|
26
|
+
print("Warning: LoRA not available in this MLX-LM version")
|
|
27
|
+
LoRALinear = None # type: ignore
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def apply_lora(
    model: nn.Module,
    lora_layers: int = 8,
    lora_rank: int = 8,
    lora_scale: float = 20.0,
    lora_dropout: float = 0.0
) -> nn.Module:
    """
    Apply LoRA adapters to the attention projections of an MLX model.

    For each of the last ``lora_layers`` blocks in ``model.model.layers``,
    the ``self_attn.q_proj`` and ``self_attn.v_proj`` projections are wrapped
    with ``LoRALinear.from_base``. Projections whose class name already
    contains ``"LoRA"`` are left alone, so repeated calls do not stack
    adapters.

    The model is modified in place and the same object is returned; no copy
    is made.

    Args:
        model: MLX model exposing ``model.model.layers`` (presumably an
            MLX-LM causal LM — TODO confirm for other architectures)
        lora_layers: Number of layers to apply LoRA to (counted from the end)
        lora_rank: LoRA rank parameter (lower = fewer adapter parameters)
        lora_scale: LoRA scaling factor
        lora_dropout: LoRA dropout rate

    Returns:
        ``model`` with LoRA adapters applied, or ``model`` unchanged when
        ``LoRALinear`` could not be imported.
    """
    if LoRALinear is None:
        print("Warning: LoRA not available, returning original model")
        return model

    layers = model.model.layers
    first_layer = max(0, len(layers) - lora_layers)
    adapted_layers = 0

    for layer in layers[first_layer:]:
        self_attn = getattr(layer, "self_attn", None)
        if self_attn is None:
            continue

        converted_any = False
        for proj_name in ("q_proj", "v_proj"):
            base_proj = getattr(self_attn, proj_name, None)
            # Skip missing projections and ones that are already LoRA layers
            # (avoids wrapping an adapter inside another adapter).
            if base_proj is None or "LoRA" in type(base_proj).__name__:
                continue
            setattr(
                self_attn,
                proj_name,
                LoRALinear.from_base(
                    base_proj,
                    r=lora_rank,
                    scale=lora_scale,
                    dropout=lora_dropout,
                ),
            )
            converted_any = True

        if converted_any:
            adapted_layers += 1

    # Report the layers actually adapted, not the requested count: the model
    # may have fewer layers, or some may already carry adapters.
    print(f"Applied LoRA to {adapted_layers} layers (rank={lora_rank}, scale={lora_scale})")
    return model
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def freeze_base(model: nn.Module) -> nn.Module:
    """
    Freeze the base model so that only the LoRA adapter matrices train.

    The whole model is frozen with ``model.freeze()``. The adapter parameters
    ``lora_a`` / ``lora_b`` (the names used by MLX-LM's ``LoRALinear``) are
    then unfrozen again in every sub-module. Without that second step the
    freeze would also cover the adapters, and no parameter would receive
    gradients.

    Trainable/total parameter counts are printed for visibility. The model is
    modified in place and returned.

    Args:
        model: Model with LoRA adapters already applied

    Returns:
        The same model object, with only LoRA parameters trainable
    """
    def _count(tree: Any) -> int:
        # MLX returns parameters as nested dicts/lists of arrays, so walk the
        # tree and add up the element counts of the leaves.
        if isinstance(tree, dict):
            return sum(_count(value) for value in tree.values())
        if isinstance(tree, (list, tuple)):
            return sum(_count(value) for value in tree)
        size = getattr(tree, "size", 0)
        return 0 if callable(size) else int(size)

    model.freeze()
    if hasattr(model, "unfreeze"):
        # With recurse=True (the default) the key filter is applied to every
        # sub-module, re-enabling each adapter while base weights stay frozen.
        model.unfreeze(keys=["lora_a", "lora_b"], strict=False)

    trainable_params = (
        _count(model.trainable_parameters())
        if hasattr(model, "trainable_parameters")
        else 0
    )
    total_params = _count(model.parameters()) if hasattr(model, "parameters") else 0

    print(f"Frozen base model: {trainable_params:,} trainable / {total_params:,} total parameters")
    if total_params > 0:
        # Guarded: a model without parameters would otherwise divide by zero.
        print(f" Memory reduction: {(1 - trainable_params/total_params)*100:.1f}%")

    return model
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def extract_params(model: nn.Module) -> Dict[str, mx.array]:
    """
    Collect the LoRA adapter arrays of a model for saving.

    The model's parameter tree is flattened into dotted paths, e.g.
    ``"model.layers.0.self_attn.q_proj.lora_a"``. Every entry whose path
    contains ``"lora"`` or ``"adapter"`` (case-insensitive) is kept. Saving
    just this flat dict instead of the full model keeps checkpoints small.

    All parameters are scanned, not only trainable ones, so adapters are
    found regardless of their frozen state.

    Args:
        model: Model with LoRA adapters

    Returns:
        Flat mapping of dotted parameter path to array. The mapping is empty
        when the model exposes no parameters or contains no LoRA parameters;
        placeholder arrays are never fabricated.
    """
    def _flatten(tree: Any, prefix: str = ""):
        # Produce dotted keys for nested dicts / lists (the same layout as
        # mlx.utils.tree_flatten), yielding (path, leaf) pairs.
        if isinstance(tree, dict):
            for key, value in tree.items():
                yield from _flatten(value, f"{prefix}.{key}" if prefix else str(key))
        elif isinstance(tree, (list, tuple)):
            for index, value in enumerate(tree):
                yield from _flatten(value, f"{prefix}.{index}" if prefix else str(index))
        else:
            yield prefix, tree

    if hasattr(model, "parameters"):
        tree = model.parameters()
    elif hasattr(model, "trainable_parameters"):
        tree = model.trainable_parameters()
    else:
        return {}

    return {
        name: value
        for name, value in _flatten(tree)
        if "lora" in name.lower() or "adapter" in name.lower()
    }
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def merge_weights(model: nn.Module) -> nn.Module:
    """
    Fold trained LoRA adapters back into the attention projections.

    Walks ``model.model.layers[*].self_attn``. Each of ``q_proj``, ``k_proj``,
    ``v_proj`` and ``o_proj`` that exposes a callable ``fuse()`` method (as
    MLX-LM's ``LoRALinear`` does) is replaced with the module that ``fuse()``
    returns. This is useful for deployment, where adapter layers are no
    longer needed. Projections without ``fuse()`` are left untouched.

    The model is modified in place and returned.

    Args:
        model: Model with trained LoRA adapters

    Returns:
        The same model object with fused projections. The model is returned
        unchanged (with a warning) if it has no ``model.layers``.
    """
    layers = getattr(getattr(model, "model", None), "layers", None)
    if layers is None:
        print("Warning: model has no model.layers; LoRA weights were not merged")
        return model

    merged = 0
    for layer in layers:
        self_attn = getattr(layer, "self_attn", None)
        if self_attn is None:
            continue
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            fuse = getattr(getattr(self_attn, proj_name, None), "fuse", None)
            if callable(fuse):
                setattr(self_attn, proj_name, fuse())
                merged += 1

    print(f"Merged LoRA weights into {merged} projections")
    return model
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def compute_lora_memory_savings(
    model: nn.Module,
    lora_rank: int,
    lora_layers: int,
    hidden_size: int = 4096
) -> Dict[str, float]:
    """
    Estimate how much smaller the trained parameter set is with LoRA.

    The total is the real parameter count of ``model``, summed over the
    (possibly nested dict/list) tree returned by ``model.parameters()``.
    The LoRA count is a rough estimate that assumes one
    ``hidden_size x hidden_size`` projection is adapted per layer, i.e.
    ``rank * (in + out) = rank * 2 * hidden_size`` parameters per layer.

    Args:
        model: Model whose parameters are counted
        lora_rank: LoRA rank parameter
        lora_layers: Number of LoRA layers
        hidden_size: Assumed input/output width of each adapted projection
            (default 4096)

    Returns:
        Dict with ``total_parameters``, ``estimated_lora_parameters``,
        ``memory_ratio`` (LoRA / total) and ``memory_savings_percent``.
        If no parameters can be counted, ``total_parameters`` falls back to
        600,000,000 (a rough figure for a 0.6B model).
    """
    import math

    fallback_total = 600_000_000

    def _count(tree: Any) -> int:
        # Sum element counts over nested dicts/lists; leaves report their
        # size directly, or it is derived from their shape.
        if isinstance(tree, dict):
            return sum(_count(value) for value in tree.values())
        if isinstance(tree, (list, tuple)):
            return sum(_count(value) for value in tree)
        size = getattr(tree, "size", None)
        if size is not None and not callable(size):
            return int(size)
        shape = getattr(tree, "shape", None)
        if shape is not None:
            return int(math.prod(shape))
        return 0

    total_params = 0
    if hasattr(model, "parameters"):
        try:
            tree = model.parameters()
            if not isinstance(tree, (dict, list, tuple)):
                tree = list(tree)  # e.g. a generator of arrays
            total_params = _count(tree)
        except (AttributeError, TypeError, ValueError):
            total_params = 0  # best-effort estimate: fall back below

    if total_params == 0:
        total_params = fallback_total  # also prevents division by zero

    estimated_lora_params = lora_layers * lora_rank * 2 * hidden_size

    memory_ratio = estimated_lora_params / total_params
    memory_savings = (1 - memory_ratio) * 100

    return {
        "total_parameters": total_params,
        "estimated_lora_parameters": estimated_lora_params,
        "memory_ratio": memory_ratio,
        "memory_savings_percent": memory_savings
    }
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# Composed function for creating LoRA-enabled training setup
def create_lora_setup(
    model: nn.Module,
    lora_config: Dict[str, Any],
    auto_reload: bool = True,
    adapter_save_path: str = "./lora_adapters.safetensors"
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Build a LoRA training setup: add adapters, freeze the base, report stats.

    Steps, in order: ``apply_lora`` with the configured layers/rank/scale/
    dropout, then ``freeze_base``, then ``compute_lora_memory_savings``.
    When ``auto_reload`` is true, the model is tagged with
    ``_auto_reload_path`` and ``_is_auto_reload_lora = True`` so that the
    Trainer can detect it and manage adapter saving/reloading. Otherwise it
    is tagged ``_is_auto_reload_lora = False``.

    Args:
        model: Base MLX model
        lora_config: LoRA settings; reads ``lora_layers`` (default 8),
            ``lora_rank`` (default 8), ``lora_scale`` (default 20.0) and
            ``lora_dropout`` (default 0.0)
        auto_reload: Whether to enable automatic adapter management
        adapter_save_path: Where to save/load adapters

    Returns:
        (lora_model, memory_stats): LoRA-enabled model and memory statistics
    """
    rank = lora_config.get("lora_rank", 8)
    num_layers = lora_config.get("lora_layers", 8)

    adapted = apply_lora(
        model=model,
        lora_layers=num_layers,
        lora_rank=rank,
        lora_scale=lora_config.get("lora_scale", 20.0),
        lora_dropout=lora_config.get("lora_dropout", 0.0),
    )
    adapted = freeze_base(adapted)

    stats = compute_lora_memory_savings(
        model=model,
        lora_rank=rank,
        lora_layers=num_layers,
    )

    if not auto_reload:
        adapted._is_auto_reload_lora = False
        return adapted, stats

    # Metadata read by the Trainer to drive adapter save/reload.
    adapted._auto_reload_path = adapter_save_path
    adapted._is_auto_reload_lora = True
    print(f"LoRA auto-reload enabled: {adapter_save_path}")
    return adapted, stats
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
# Real quantization implementation using MLX-LM
def apply_quantization_to_model(
    model: nn.Module,
    config: dict,
    bits: int = 4,
    group_size: int = 64
) -> nn.Module:
    """
    Quantize a model's weights with MLX-LM so LoRA can train on top (QLoRA).

    Delegates to ``mlx_lm.utils.quantize_model``. Quantization is
    best-effort. If MLX-LM's quantizer cannot be imported, or anything in the
    quantization step raises, a warning is printed and the ``model`` object
    that was passed in is returned.

    Args:
        model: MLX model to quantize
        config: Model configuration dictionary passed through to the quantizer
        bits: Quantization bits (4, 6, or 8)
        group_size: Quantization group size

    Returns:
        The quantized model, or ``model`` if quantization was skipped or
        failed. The quantizer's updated config is discarded.
    """
    try:
        from mlx_lm.utils import quantize_model

        print(f"Applying real {bits}-bit quantization...")
        print(f" Group size: {group_size}")
        print(f" Expected memory reduction: ~{8/bits:.1f}x")

        # Group size and bit width are passed positionally. Their keyword
        # names have differed between MLX-LM releases (``q_group_size`` /
        # ``q_bits`` in older ones), while their positions after
        # ``model, config`` have not changed. A keyword mismatch would raise
        # TypeError and silently skip quantization via the handler below.
        # ``quant_predicate`` is omitted: None is its default where it exists.
        quantized_model, _updated_config = quantize_model(model, config, group_size, bits)

        print("Real quantization applied successfully")
        return quantized_model

    except ImportError:
        print("Warning: MLX-LM quantization not available, skipping quantization")
        return model
    except Exception as e:
        print(f"Warning: Quantization failed: {e}, using original model")
        return model
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# Complete QLoRA setup function
def create_qlora_setup(
    model: nn.Module,
    lora_config: Dict[str, Any],
    quantization_config: Dict[str, Any]
) -> Tuple[nn.Module, Dict[str, float]]:
    """
    Set up QLoRA (quantized base model + LoRA adapters) training.

    The base model is quantized first via ``apply_quantization_to_model``,
    and LoRA is then applied to the result via ``create_lora_setup``.

    Args:
        model: Base MLX model
        lora_config: LoRA configuration (forwarded to ``create_lora_setup``)
        quantization_config: Quantization configuration; reads ``"bits"``
            (default 4) and ``"group_size"`` (default 64)

    Returns:
        (qlora_model, memory_stats): the QLoRA-enabled model, and the
        statistics from ``create_lora_setup`` extended with
        ``quantization_factor`` (``8 / bits``) and ``total_memory_savings``
        (a percentage, at most 100 when the LoRA savings are at most 100).
    """
    bits = quantization_config.get("bits", 4)
    group_size = quantization_config.get("group_size", 64)

    # Generic placeholder config handed to the quantizer; the real model
    # config is not available here. NOTE(review): these values are defaults,
    # not read from the model — confirm the quantizer does not depend on them.
    model_config = {
        "model_type": "unknown",
        "vocab_size": 32000,  # Default vocab size
        "hidden_size": 4096,  # Default hidden size
    }

    # Apply quantization first, then LoRA on the quantized model
    quantized_model = apply_quantization_to_model(
        model=model,
        config=model_config,
        bits=bits,
        group_size=group_size
    )

    qlora_model, memory_stats = create_lora_setup(
        model=quantized_model,
        lora_config=lora_config
    )

    quantization_factor = 8 / bits
    memory_stats["quantization_factor"] = quantization_factor
    # Combine the two reductions multiplicatively: quantization shrinks
    # whatever remains after the LoRA reduction. Adding the two percentages
    # instead can report more than 100% savings.
    remaining_fraction = 1 - memory_stats["memory_savings_percent"] / 100
    memory_stats["total_memory_savings"] = (
        1 - remaining_fraction / quantization_factor
    ) * 100

    print(f"QLoRA setup complete - estimated {memory_stats['total_memory_savings']:.1f}% memory savings")

    return qlora_model, memory_stats
|