openllava 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openllava/__init__.py +96 -0
- openllava/api/__init__.py +31 -0
- openllava/api/callbacks.py +463 -0
- openllava/api/config.py +328 -0
- openllava/api/fast_model.py +406 -0
- openllava/api/loggers.py +501 -0
- openllava/api/strategies.py +433 -0
- openllava/api/trainer.py +519 -0
- openllava/backends/__init__.py +90 -0
- openllava/backends/cpu_simd.py +272 -0
- openllava/backends/gguf_v2.py +470 -0
- openllava/backends/mlx_backend.py +144 -0
- openllava/backends/onnx_export.py +272 -0
- openllava/backends/rocm_backend.py +190 -0
- openllava/backends/safetensors_io.py +137 -0
- openllava/backends/tpu_dataloader.py +364 -0
- openllava/backends/tpu_spmd.py +373 -0
- openllava/backends/tpu_xla.py +326 -0
- openllava/backends/xpu_backend.py +175 -0
- openllava/cli/__init__.py +18 -0
- openllava/cli/commands/__init__.py +9 -0
- openllava/cli/commands/benchmark.py +381 -0
- openllava/cli/commands/export.py +423 -0
- openllava/cli/commands/serve.py +313 -0
- openllava/cli/commands/train.py +214 -0
- openllava/cli/main.py +122 -0
- openllava/core/__init__.py +1 -0
- openllava/core/backend.py +485 -0
- openllava/core/model.py +939 -0
- openllava/core/patcher.py +529 -0
- openllava/data/__init__.py +109 -0
- openllava/data/collator.py +473 -0
- openllava/data/gpu_augmentation.py +303 -0
- openllava/data/pipeline.py +198 -0
- openllava/data/preprocessing.py +204 -0
- openllava/data/smart_batching.py +251 -0
- openllava/data/streaming.py +297 -0
- openllava/data/templates.py +159 -0
- openllava/distributed/__init__.py +223 -0
- openllava/distributed/auto_parallel.py +438 -0
- openllava/distributed/cluster_config.py +444 -0
- openllava/distributed/deepspeed.py +449 -0
- openllava/distributed/device_mesh.py +474 -0
- openllava/distributed/dtensor_ops.py +423 -0
- openllava/distributed/expert_parallel.py +492 -0
- openllava/distributed/fsdp.py +372 -0
- openllava/distributed/heterogeneous.py +357 -0
- openllava/distributed/init_process.py +414 -0
- openllava/distributed/parallel_4d.py +851 -0
- openllava/distributed/pipeline_parallel.py +481 -0
- openllava/distributed/placement.py +364 -0
- openllava/distributed/ring_attention.py +224 -0
- openllava/distributed/tensor_parallel.py +430 -0
- openllava/distributed/topology.py +406 -0
- openllava/distributed/zero_hpz.py +379 -0
- openllava/distributed/zero_offload.py +469 -0
- openllava/eval/__init__.py +1 -0
- openllava/eval/mmbench.py +127 -0
- openllava/eval/runner.py +312 -0
- openllava/eval/scienceqa.py +120 -0
- openllava/eval/textvqa.py +119 -0
- openllava/experts/__init__.py +50 -0
- openllava/experts/moe_layers.py +424 -0
- openllava/experts/moe_lora.py +509 -0
- openllava/experts/moe_trainer.py +588 -0
- openllava/experts/ocr.py +86 -0
- openllava/experts/router.py +192 -0
- openllava/experts/visual.py +194 -0
- openllava/export/__init__.py +24 -0
- openllava/export/gguf.py +251 -0
- openllava/export/hub.py +165 -0
- openllava/export/merge.py +117 -0
- openllava/export/quantize.py +152 -0
- openllava/inference/__init__.py +31 -0
- openllava/inference/continuous_batching.py +519 -0
- openllava/inference/engine.py +633 -0
- openllava/inference/paged_attention.py +617 -0
- openllava/inference/speculative_decoding.py +813 -0
- openllava/kernels/__init__.py +32 -0
- openllava/kernels/cuda_graphs/__init__.py +3 -0
- openllava/kernels/cuda_graphs/graph_trainer.py +161 -0
- openllava/kernels/streams.py +143 -0
- openllava/kernels/triton/__init__.py +91 -0
- openllava/kernels/triton/bitnet_gemm.py +280 -0
- openllava/kernels/triton/blindsight.py +348 -0
- openllava/kernels/triton/flash_attention.py +364 -0
- openllava/kernels/triton/flex_attention.py +248 -0
- openllava/kernels/triton/fused_attention.py +543 -0
- openllava/kernels/triton/fused_cross_entropy.py +286 -0
- openllava/kernels/triton/fused_projector.py +384 -0
- openllava/kernels/triton/fused_rmsnorm.py +129 -0
- openllava/kernels/triton/fused_rope.py +278 -0
- openllava/kernels/triton/fused_swiglu.py +406 -0
- openllava/kernels/triton/grouped_gemm.py +339 -0
- openllava/kernels/triton/online_softmax.py +249 -0
- openllava/kernels/triton/sparse_attention.py +421 -0
- openllava/optimizations/__init__.py +342 -0
- openllava/optimizations/async_io.py +253 -0
- openllava/optimizations/bitnet.py +342 -0
- openllava/optimizations/bitnet_a48.py +336 -0
- openllava/optimizations/chunked_prefill.py +367 -0
- openllava/optimizations/cpu_offload.py +153 -0
- openllava/optimizations/curriculum.py +319 -0
- openllava/optimizations/eagle_draft.py +430 -0
- openllava/optimizations/ema.py +179 -0
- openllava/optimizations/fast_nf4.py +312 -0
- openllava/optimizations/fp4_quant.py +589 -0
- openllava/optimizations/fp8_training.py +411 -0
- openllava/optimizations/full_finetune.py +373 -0
- openllava/optimizations/galore.py +384 -0
- openllava/optimizations/gptq_awq.py +320 -0
- openllava/optimizations/kv_compression.py +482 -0
- openllava/optimizations/kv_eviction.py +452 -0
- openllava/optimizations/kv_quantization.py +448 -0
- openllava/optimizations/medusa_heads.py +356 -0
- openllava/optimizations/memory_pool.py +90 -0
- openllava/optimizations/mixed_precision_quant.py +628 -0
- openllava/optimizations/mxfp8_moe.py +418 -0
- openllava/optimizations/ngram_draft.py +456 -0
- openllava/optimizations/packing.py +142 -0
- openllava/optimizations/padding_free.py +116 -0
- openllava/optimizations/qat.py +516 -0
- openllava/optimizations/schedulers.py +298 -0
- openllava/optimizations/selective_checkpoint.py +94 -0
- openllava/optimizations/sparse_attn_selector.py +385 -0
- openllava/optimizations/split_lora.py +623 -0
- openllava/optimizations/torch_compile.py +259 -0
- openllava/optimizations/torchao_integration.py +253 -0
- openllava/optimizations/tree_verification.py +431 -0
- openllava/optimizations/yadis_cross_attn.py +496 -0
- openllava/optimizations/yadis_moe_adaptive.py +459 -0
- openllava/optimizations/yadis_vq_ema.py +649 -0
- openllava/rl/__init__.py +61 -0
- openllava/rl/dpo.py +561 -0
- openllava/rl/grpo.py +524 -0
- openllava/rl/orpo.py +448 -0
- openllava/rl/ppo.py +674 -0
- openllava/rl/rewards.py +486 -0
- openllava/rl/vllm_integration.py +405 -0
- openllava/serve/__init__.py +47 -0
- openllava/serve/batch_manager.py +523 -0
- openllava/serve/metrics.py +387 -0
- openllava/serve/middleware.py +421 -0
- openllava/serve/openai_api.py +302 -0
- openllava/serve/server.py +499 -0
- openllava/training/__init__.py +54 -0
- openllava/training/bitnet_trainer.py +487 -0
- openllava/training/checkpointing.py +342 -0
- openllava/training/dora.py +366 -0
- openllava/training/lora.py +226 -0
- openllava/training/lora_fa.py +369 -0
- openllava/training/lora_ga.py +372 -0
- openllava/training/lora_plus.py +391 -0
- openllava/training/lora_registry.py +393 -0
- openllava/training/memory.py +309 -0
- openllava/training/trainer.py +614 -0
- openllava/utils/__init__.py +76 -0
- openllava/utils/auto_detect.py +388 -0
- openllava/utils/benchmark.py +433 -0
- openllava/utils/hardware_detect.py +393 -0
- openllava/utils/hub.py +436 -0
- openllava/utils/model_card.py +340 -0
- openllava/utils/profiler.py +478 -0
- openllava/utils/registry.py +431 -0
- openllava-3.0.0.dist-info/METADATA +1299 -0
- openllava-3.0.0.dist-info/RECORD +170 -0
- openllava-3.0.0.dist-info/WHEEL +5 -0
- openllava-3.0.0.dist-info/entry_points.txt +2 -0
- openllava-3.0.0.dist-info/licenses/LICENSE +201 -0
- openllava-3.0.0.dist-info/top_level.txt +1 -0
openllava/__init__.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OpenLLaVA — Open-Source Multimodal Vision Injection Framework.
|
|
3
|
+
|
|
4
|
+
Inject vision into any language model. Architecture-agnostic, multi-backend.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from openllava import OpenLLaVA, Backend, experts
|
|
8
|
+
|
|
9
|
+
model = OpenLLaVA(
|
|
10
|
+
llm="meta-llama/Llama-3-8B",
|
|
11
|
+
vision_encoder="google/siglip2-so400m-patch14-384",
|
|
12
|
+
backend=Backend.CUDA,
|
|
13
|
+
)
|
|
14
|
+
model.lora(r=64, alpha=128)
|
|
15
|
+
model.train(phase1=dict(dataset="liuhaotian/LLaVA-Pretrain", samples=100_000))
|
|
16
|
+
model.push("my-org/my-model")
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from . import (
|
|
22
|
+
api,
|
|
23
|
+
backends,
|
|
24
|
+
cli,
|
|
25
|
+
data,
|
|
26
|
+
distributed,
|
|
27
|
+
eval,
|
|
28
|
+
experts,
|
|
29
|
+
export,
|
|
30
|
+
inference,
|
|
31
|
+
kernels,
|
|
32
|
+
optimizations,
|
|
33
|
+
rl,
|
|
34
|
+
serve,
|
|
35
|
+
training,
|
|
36
|
+
utils,
|
|
37
|
+
)
|
|
38
|
+
from .api import FastLanguageModel, FastVisionModel, OpenLLaVATrainer, TrainingConfig
|
|
39
|
+
from .core.backend import Backend, BackendManager, get_backend, is_cuda_available
|
|
40
|
+
from .core.model import OpenLLaVA
|
|
41
|
+
from .core.patcher import AnyResProcessor, ModelPatcher, YakiModel, YakiProjector
|
|
42
|
+
from .kernels import cuda_graphs
|
|
43
|
+
from .kernels import triton as triton_kernels
|
|
44
|
+
from .optimizations import (
|
|
45
|
+
EMAModel,
|
|
46
|
+
awq_quantize,
|
|
47
|
+
compile_model,
|
|
48
|
+
enable_fp8_training,
|
|
49
|
+
gptq_quantize,
|
|
50
|
+
)
|
|
51
|
+
from .utils import HardwareDetector, HardwareInfo, auto_configure, profile_model
|
|
52
|
+
|
|
53
|
+
__version__ = "3.0.0"
|
|
54
|
+
__author__ = "OpceanAI Research Team"
|
|
55
|
+
|
|
56
|
+
__all__ = [
|
|
57
|
+
"OpenLLaVA",
|
|
58
|
+
"Backend",
|
|
59
|
+
"BackendManager",
|
|
60
|
+
"get_backend",
|
|
61
|
+
"is_cuda_available",
|
|
62
|
+
"YakiProjector",
|
|
63
|
+
"YakiModel",
|
|
64
|
+
"AnyResProcessor",
|
|
65
|
+
"ModelPatcher",
|
|
66
|
+
"experts",
|
|
67
|
+
"training",
|
|
68
|
+
"data",
|
|
69
|
+
"rl",
|
|
70
|
+
"export",
|
|
71
|
+
"eval",
|
|
72
|
+
"kernels",
|
|
73
|
+
"optimizations",
|
|
74
|
+
"inference",
|
|
75
|
+
"distributed",
|
|
76
|
+
"serve",
|
|
77
|
+
"backends",
|
|
78
|
+
"api",
|
|
79
|
+
"cli",
|
|
80
|
+
"utils",
|
|
81
|
+
"triton_kernels",
|
|
82
|
+
"cuda_graphs",
|
|
83
|
+
"compile_model",
|
|
84
|
+
"enable_fp8_training",
|
|
85
|
+
"EMAModel",
|
|
86
|
+
"gptq_quantize",
|
|
87
|
+
"awq_quantize",
|
|
88
|
+
"FastVisionModel",
|
|
89
|
+
"FastLanguageModel",
|
|
90
|
+
"OpenLLaVATrainer",
|
|
91
|
+
"TrainingConfig",
|
|
92
|
+
"auto_configure",
|
|
93
|
+
"HardwareDetector",
|
|
94
|
+
"HardwareInfo",
|
|
95
|
+
"profile_model",
|
|
96
|
+
]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""OpenLLaVA High-Level API — Unsloth-compatible fast model loading and training.
|
|
2
|
+
|
|
3
|
+
Provides drop-in replacements for common training frameworks with
|
|
4
|
+
auto-configuration, multi-backend logging, and callback orchestration.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from openllava.api import FastVisionModel, OpenLLaVATrainer, TrainingConfig
|
|
8
|
+
|
|
9
|
+
model, tokenizer = FastVisionModel.from_pretrained("openllava/yaki-8b")
|
|
10
|
+
config = TrainingConfig(mode="lora", output_dir="./output")
|
|
11
|
+
trainer = OpenLLaVATrainer(model=model, tokenizer=tokenizer, args=config)
|
|
12
|
+
trainer.train()
|
|
13
|
+
trainer.save_model("./final")
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from .config import TrainingConfig
|
|
19
|
+
from .fast_model import FastLanguageModel, FastVisionModel
|
|
20
|
+
from .strategies import auto_configure, get_peft_model, load_dataset
|
|
21
|
+
from .trainer import OpenLLaVATrainer
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"FastVisionModel",
|
|
25
|
+
"FastLanguageModel",
|
|
26
|
+
"OpenLLaVATrainer",
|
|
27
|
+
"TrainingConfig",
|
|
28
|
+
"get_peft_model",
|
|
29
|
+
"load_dataset",
|
|
30
|
+
"auto_configure",
|
|
31
|
+
]
|
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
"""OpenLLaVA Callback System.
|
|
2
|
+
|
|
3
|
+
Complete training callback infrastructure with hooks for monitoring, checkpointing,
|
|
4
|
+
early stopping, and memory profiling. Designed for composition via CallbackList.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from openllava.api.callbacks import EarlyStoppingCallback, CallbackList
|
|
8
|
+
|
|
9
|
+
callbacks = CallbackList([
|
|
10
|
+
EarlyStoppingCallback(monitor="loss", patience=3),
|
|
11
|
+
ModelCheckpointCallback(dirpath="./checkpoints", monitor="loss"),
|
|
12
|
+
])
|
|
13
|
+
callbacks.on_train_begin()
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import math
|
|
19
|
+
import os
|
|
20
|
+
import time
|
|
21
|
+
from abc import ABC
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
|
|
26
|
+
# Canonical list of callback hook names; mirrors the hook methods defined
# on BaseCallback and CallbackList below.
# NOTE(review): not referenced by any code visible in this module —
# presumably consumed elsewhere for dynamic hook dispatch; confirm before
# removing.
_HOOKS = [
    "on_train_begin", "on_train_end",
    "on_epoch_begin", "on_epoch_end",
    "on_step_begin", "on_step_end",
    "on_save", "on_log", "on_evaluate",
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class BaseCallback(ABC):
    """Common no-op base class for training callbacks.

    Subclasses override whichever hooks they need. Every hook accepts
    only ``**kwargs`` so new keyword arguments can be introduced later
    without breaking existing callbacks.
    """

    def on_train_begin(self, **kwargs) -> None:
        """Hook dispatched once before training starts."""

    def on_train_end(self, **kwargs) -> None:
        """Hook dispatched once after training finishes."""

    def on_epoch_begin(self, **kwargs) -> None:
        """Hook dispatched at the start of each epoch."""

    def on_epoch_end(self, **kwargs) -> None:
        """Hook dispatched at the end of each epoch."""

    def on_step_begin(self, **kwargs) -> None:
        """Hook dispatched before each training step."""

    def on_step_end(self, **kwargs) -> None:
        """Hook dispatched after each training step."""

    def on_save(self, **kwargs) -> None:
        """Hook dispatched for the save event."""

    def on_log(self, **kwargs) -> None:
        """Hook dispatched for the log event."""

    def on_evaluate(self, **kwargs) -> None:
        """Hook dispatched for the evaluate event."""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class CallbackList:
    """Composite that fans each hook call out to every registered callback.

    A failure inside one callback is caught and printed so that a single
    faulty callback cannot abort training.

    Usage:
        cbl = CallbackList([EarlyStoppingCallback(), ModelCheckpointCallback()])
        cbl.on_train_begin(trainer=trainer)
    """

    def __init__(self, callbacks: list[BaseCallback] | None = None):
        self._callbacks: list[BaseCallback] = list(callbacks) if callbacks else []

    def append(self, callback: BaseCallback) -> None:
        """Register one additional callback."""
        self._callbacks.append(callback)

    def extend(self, callbacks: list[BaseCallback]) -> None:
        """Register several callbacks at once."""
        self._callbacks.extend(callbacks)

    def dispatch(self, hook: str, **kwargs) -> None:
        """Invoke *hook* on every callback, swallowing per-callback errors."""
        for callback in self._callbacks:
            try:
                getattr(callback, hook)(**kwargs)
            except Exception as exc:
                print(f"[CallbackList] Error in {callback.__class__.__name__}.{hook}: {exc}")

    # Explicit per-hook forwarders (rather than generated methods) so the
    # full hook surface is visible to static analysis and auto-complete.
    def on_train_begin(self, **kwargs) -> None:
        self.dispatch("on_train_begin", **kwargs)

    def on_train_end(self, **kwargs) -> None:
        self.dispatch("on_train_end", **kwargs)

    def on_epoch_begin(self, **kwargs) -> None:
        self.dispatch("on_epoch_begin", **kwargs)

    def on_epoch_end(self, **kwargs) -> None:
        self.dispatch("on_epoch_end", **kwargs)

    def on_step_begin(self, **kwargs) -> None:
        self.dispatch("on_step_begin", **kwargs)

    def on_step_end(self, **kwargs) -> None:
        self.dispatch("on_step_end", **kwargs)

    def on_save(self, **kwargs) -> None:
        self.dispatch("on_save", **kwargs)

    def on_log(self, **kwargs) -> None:
        self.dispatch("on_log", **kwargs)

    def on_evaluate(self, **kwargs) -> None:
        self.dispatch("on_evaluate", **kwargs)

    def __len__(self) -> int:
        return len(self._callbacks)

    def __iter__(self):
        return iter(self._callbacks)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class EarlyStoppingCallback(BaseCallback):
    """Raise a stop flag once a monitored metric stops improving.

    Args:
        monitor: Key looked up in the evaluation ``metrics`` dict.
        patience: Consecutive non-improving evaluations tolerated before
            the stop flag is raised.
        min_delta: Smallest change that still counts as an improvement.
        mode: "min" when lower metric values are better, "max" otherwise.
    """

    def __init__(
        self,
        monitor: str = "loss",
        patience: int = 3,
        min_delta: float = 0.0,
        mode: str = "min",
    ):
        if patience < 1:
            raise ValueError(f"patience must be >= 1, got {patience}")
        if mode not in {"min", "max"}:
            raise ValueError(f"mode must be 'min' or 'max', got '{mode}'")
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self._best: float | None = None       # best metric value seen so far
        self._counter: int = 0                # evaluations since last improvement
        self._should_stop: bool = False

    def _is_improvement(self, current: float) -> bool:
        """True when *current* beats the best value by at least min_delta."""
        best = self._best
        if best is None:
            # First observation always counts as an improvement.
            return True
        if self.mode == "min":
            return current < best - self.min_delta
        return current > best + self.min_delta

    def on_train_begin(self, **kwargs) -> None:
        # Fresh run: drop any state from a previous training session.
        self._best, self._counter, self._should_stop = None, 0, False

    def on_evaluate(self, **kwargs) -> None:
        metrics = kwargs.get("metrics", {})
        value = metrics.get(self.monitor, None)
        if value is None:
            return
        if not self._is_improvement(value):
            self._counter += 1
            if self._counter >= self.patience:
                self._should_stop = True
                print(f"[EarlyStopping] Stopping after {self._counter} "
                      f"checks without improvement in {self.monitor}")
            return
        self._best = value
        self._counter = 0

    @property
    def should_stop(self) -> bool:
        return self._should_stop

    @property
    def best_value(self) -> float | None:
        return self._best
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class ModelCheckpointCallback(BaseCallback):
    """Persist model checkpoints during training, retaining the best k.

    Args:
        dirpath: Directory checkpoints are written into (created eagerly).
        monitor: Metric key used to rank checkpoints.
        save_top_k: Number of best checkpoints to keep (-1 keeps all).
        mode: "min" or "max" — direction in which the metric improves.
        save_last: Also write a final "last" checkpoint at train end.
    """

    def __init__(
        self,
        dirpath: str = "./checkpoints",
        monitor: str = "loss",
        save_top_k: int = 1,
        mode: str = "min",
        save_last: bool = True,
    ):
        self.dirpath = dirpath
        self.monitor = monitor
        self.save_top_k = save_top_k
        self.mode = mode
        self.save_last = save_last
        # (score, path) pairs, kept sorted best-first.
        self._best_scores: list[tuple[float, str]] = []
        os.makedirs(dirpath, exist_ok=True)

    def on_evaluate(self, **kwargs) -> None:
        metrics = kwargs.get("metrics", {})
        value = metrics.get(self.monitor)
        model = kwargs.get("model")
        if value is None or model is None:
            return

        ckpt_path = os.path.join(self.dirpath, f"checkpoint-{kwargs.get('step', 0)}")
        self._save_model(model, ckpt_path)

        if self.save_top_k <= 0:
            return  # non-positive top-k keeps every checkpoint
        self._best_scores.append((float(value), ckpt_path))
        # Best-first ordering: ascending for "min", descending for "max".
        self._best_scores.sort(key=lambda pair: pair[0], reverse=(self.mode == "max"))
        while len(self._best_scores) > self.save_top_k:
            _, evicted = self._best_scores.pop()  # worst is last
            self._remove_checkpoint(evicted)

    def on_train_end(self, **kwargs) -> None:
        if not self.save_last:
            return
        model = kwargs.get("model")
        if model is not None:
            self._save_model(model, os.path.join(self.dirpath, "last"))

    def _save_model(self, model, path: str) -> None:
        """Best-effort save: HF-style save_pretrained, else raw state_dict."""
        try:
            os.makedirs(path, exist_ok=True)
            if hasattr(model, "save_pretrained"):
                model.save_pretrained(path)
            else:
                torch.save(model.state_dict(), os.path.join(path, "pytorch_model.bin"))
        except Exception as e:
            print(f"[ModelCheckpoint] Failed to save to {path}: {e}")

    @staticmethod
    def _remove_checkpoint(path: str) -> None:
        import shutil
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class GradientAccumulationCallback(BaseCallback):
    """Dynamically grow the gradient-accumulation factor on bad steps.

    When a NaN loss is observed, the accumulation factor is doubled
    (capped at ``max_accum``), reducing per-micro-batch pressure.

    Args:
        target_effective_batch: Desired effective batch size (stored; not
            acted upon by the logic visible here).
        min_accum: Lower bound for the accumulation factor.
        max_accum: Upper bound for the accumulation factor.
    """

    def __init__(self, target_effective_batch: int = 32, min_accum: int = 1, max_accum: int = 128):
        self.target = target_effective_batch
        self.min_accum = min_accum
        self.max_accum = max_accum
        self.current_accum: int = 1
        self._oom_count: int = 0  # number of bad steps observed

    def on_train_begin(self, **kwargs) -> None:
        # Seed from the trainer's configured value when available; the dict
        # fallback means getattr's default of 1 is used when no trainer is
        # passed.
        self.current_accum = getattr(
            kwargs.get("trainer", {}), "gradient_accumulation_steps", 1
        )

    def on_step_end(self, **kwargs) -> None:
        loss = kwargs.get("loss", None)
        if loss is None:
            return
        # BUG FIX: torch.isnan() accepts only tensors, so a plain float loss
        # (common when trainers log loss.item()) raised TypeError here.
        # Normalize to float and use math.isnan, mirroring NaNMonitorCallback.
        loss_val = loss.item() if hasattr(loss, "item") else float(loss)
        if math.isnan(loss_val):
            self.current_accum = min(self.current_accum * 2, self.max_accum)
            self._oom_count += 1
            print(f"[GradientAccum] NaN loss detected, increasing accumulation to {self.current_accum}")
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class LearningRateMonitorCallback(BaseCallback):
    """Record the learning rate of the optimizer's first param group per step.

    The history is available via get_history() while training runs; note
    that it is cleared again when training ends.
    """

    def __init__(self):
        self._lrs: list[tuple[int, float]] = []  # recorded (step, lr) samples

    def on_step_end(self, **kwargs) -> None:
        step = kwargs.get("step", len(self._lrs))
        optimizer = kwargs.get("optimizer")
        if optimizer is None or not optimizer.param_groups:
            return
        self._lrs.append((step, optimizer.param_groups[0]["lr"]))

    def get_history(self) -> list[tuple[int, float]]:
        """Return the recorded (step, lr) pairs."""
        return self._lrs

    def on_train_end(self, **kwargs) -> None:
        # History is deliberately dropped once training completes.
        self._lrs.clear()
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class MemoryProfilerCallback(BaseCallback):
    """Periodically print GPU and host memory usage.

    Prefers NVML (pynvml) for device-wide numbers; falls back to torch's
    allocator statistics when pynvml is unavailable. Host RAM is reported
    via psutil when installed.

    Args:
        log_every_n_steps: Print a summary every N training steps.
    """

    def __init__(self, log_every_n_steps: int = 100):
        self.log_every = log_every_n_steps
        self._step: int = 0
        self._peak_gpu_mb: float = 0.0  # peak observed GPU usage, in MiB

    def on_step_end(self, **kwargs) -> None:
        self._step += 1
        if self._step % self.log_every != 0:
            return
        info_parts = []
        if torch.cuda.is_available():
            try:
                import pynvml
                pynvml.nvmlInit()
                for i in range(torch.cuda.device_count()):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                    gb = info.used / (1024 ** 3)
                    self._peak_gpu_mb = max(self._peak_gpu_mb, info.used / (1024 ** 2))
                    info_parts.append(f"GPU{i}: {gb:.1f}GB")
            except ImportError:
                # pynvml unavailable: fall back to torch allocator stats.
                alloc = torch.cuda.memory_allocated() / (1024 ** 3)
                reserved = torch.cuda.memory_reserved() / (1024 ** 3)
                info_parts.append(f"GPU alloc: {alloc:.1f}GB, reserved: {reserved:.1f}GB")
        # BUG FIX: `import psutil` previously sat *outside* the try block
        # that catches ImportError, so a missing psutil crashed the hook.
        try:
            import psutil
            ram = psutil.virtual_memory()
            info_parts.append(f"RAM: {ram.percent:.0f}%")
        except ImportError:
            pass
        if info_parts:
            print(f"[Memory] Step {self._step}: {' | '.join(info_parts)}")
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class TimingCallback(BaseCallback):
    """Record wall-clock durations for epochs and training steps."""

    def __init__(self):
        # Start timestamps for the in-flight epoch/step (perf_counter units).
        self._epoch_start: float = 0.0
        self._step_start: float = 0.0
        # Completed durations, in seconds.
        self.epoch_times: list[float] = []
        self.step_times: list[float] = []

    def on_train_begin(self, **kwargs) -> None:
        # Reset so a reused callback does not mix runs together.
        del self.epoch_times[:]
        del self.step_times[:]

    def on_epoch_begin(self, **kwargs) -> None:
        self._epoch_start = time.perf_counter()

    def on_epoch_end(self, **kwargs) -> None:
        self.epoch_times.append(time.perf_counter() - self._epoch_start)

    def on_step_begin(self, **kwargs) -> None:
        self._step_start = time.perf_counter()

    def on_step_end(self, **kwargs) -> None:
        self.step_times.append(time.perf_counter() - self._step_start)

    def summary(self) -> dict[str, float]:
        """Return aggregate timing stats; empty dict before any step ends."""
        if not self.step_times:
            return {}
        epoch_mean = (
            sum(self.epoch_times) / len(self.epoch_times) if self.epoch_times else 0.0
        )
        return {
            "mean_step_s": sum(self.step_times) / len(self.step_times),
            "min_step_s": min(self.step_times),
            "max_step_s": max(self.step_times),
            "mean_epoch_s": epoch_mean,
        }
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
class ProgressBarCallback(BaseCallback):
    """Progress bar display for training (rich when requested, else tqdm).

    Args:
        total_steps: Expected total number of training steps.
        use_rich: Prefer a ``rich`` progress bar; falls back to tqdm when
            rich is not installed, and to no display when neither is.
    """

    def __init__(self, total_steps: int = 0, use_rich: bool = False):
        self.total_steps = total_steps
        self._current: int = 0
        self._pbar: Any = None
        self._use_rich = use_rich

    def on_train_begin(self, **kwargs) -> None:
        if self._use_rich:
            try:
                from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
                self._rich_progress = Progress(
                    TextColumn("[progress.description]{task.description}"),
                    BarColumn(),
                    TextColumn("{task.completed}/{task.total}"),
                    TimeRemainingColumn(),
                )
                self._rich_progress.start()
                self._task = self._rich_progress.add_task("[cyan]Training...", total=self.total_steps)
            except ImportError:
                self._use_rich = False
        if not self._use_rich:
            # BUG FIX: the tqdm import was unguarded (unlike the rich path),
            # so a missing tqdm crashed training. Degrade to no progress
            # display instead.
            try:
                from tqdm import tqdm
            except ImportError:
                self._pbar = None
                return
            self._pbar = tqdm(total=self.total_steps, desc="Training", unit="step")

    def on_step_end(self, **kwargs) -> None:
        self._current += 1
        loss = kwargs.get("loss", None)
        postfix = {}
        if loss is not None:
            postfix["loss"] = f"{loss:.4f}" if isinstance(loss, (int, float)) else str(loss)
        if self._use_rich and hasattr(self, "_rich_progress"):
            self._rich_progress.update(self._task, advance=1)
        elif self._pbar is not None:
            if postfix:
                self._pbar.set_postfix(postfix)
            self._pbar.update(1)

    def on_train_end(self, **kwargs) -> None:
        if self._use_rich and hasattr(self, "_rich_progress"):
            self._rich_progress.stop()
        elif self._pbar is not None:
            self._pbar.close()
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
class NaNMonitorCallback(BaseCallback):
    """Raise a stop flag when NaN/Inf keeps appearing in loss or grad norm.

    Args:
        max_nan_count: Number of NaN/Inf observations tolerated before
            ``should_stop`` becomes True.
    """

    def __init__(self, max_nan_count: int = 5):
        self.max_nan_count = max_nan_count
        self._nan_count: int = 0
        self._should_stop: bool = False

    def on_train_begin(self, **kwargs) -> None:
        # Fresh run: reset counters from any previous session.
        self._nan_count = 0
        self._should_stop = False

    def on_step_end(self, **kwargs) -> None:
        loss = kwargs.get("loss", None)
        if loss is None:
            return
        # Accept both tensors (via .item()) and plain numbers.
        loss_val = loss.item() if hasattr(loss, "item") else float(loss)
        is_nan = math.isnan(loss_val) or math.isinf(loss_val)
        grad_norm = kwargs.get("grad_norm", None)
        # BUG FIX: the old guard required grad_norm to have an .item()
        # method, so a plain-float NaN/Inf grad norm was silently ignored
        # even though the conversion below already handles floats.
        if grad_norm is not None:
            gn = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm)
            is_nan = is_nan or math.isnan(gn) or math.isinf(gn)
        if is_nan:
            self._nan_count += 1
            print(f"[NaNMonitor] NaN/Inf detected ({self._nan_count}/{self.max_nan_count})")
            if self._nan_count >= self.max_nan_count:
                self._should_stop = True

    @property
    def should_stop(self) -> bool:
        return self._should_stop
|