rxnn-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn-0.1.0.dist-info/LICENSE +201 -0
- rxnn-0.1.0.dist-info/METADATA +257 -0
- rxnn-0.1.0.dist-info/RECORD +23 -0
- rxnn-0.1.0.dist-info/WHEEL +4 -0
- src/experimental/attention.py +133 -0
- src/memory/norm.py +173 -0
- src/memory/stm.py +53 -0
- src/rxt/models.py +180 -0
- src/training/base.py +275 -0
- src/training/bml.py +345 -0
- src/training/callbacks.py +491 -0
- src/training/dataset.py +164 -0
- src/training/scheduler.py +19 -0
- src/training/tokenizer.py +208 -0
- src/transformers/attention.py +324 -0
- src/transformers/ff.py +72 -0
- src/transformers/layers.py +150 -0
- src/transformers/mask.py +10 -0
- src/transformers/models.py +168 -0
- src/transformers/moe.py +139 -0
- src/transformers/positional.py +105 -0
- src/transformers/sampler.py +109 -0
- src/utils.py +14 -0
src/memory/norm.py
ADDED
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn
from typing import TypedDict

class AdaptivePositionalMemoryNorm(nn.Module):
    def __init__(
        self,
        num_slots: int,
        dim: int,
        decay: float = 0.99,
        use_scale: bool = True,
        use_gate: bool = True,
        init_gate: float = -4.0
    ):
        super(AdaptivePositionalMemoryNorm, self).__init__()
        self.use_gate = use_gate
        self.num_slots = num_slots
        self.dim = dim
        self.decay = decay
        self.eps = 1e-6

        # Learnable parameters
        self.scale = nn.Parameter(torch.ones(num_slots, 1, dim)) if use_scale else None
        self.gate = nn.Parameter(torch.full((num_slots, 1, 1), init_gate)) if use_gate else None

        # EMA buffers
        self.register_buffer("ema_rms", torch.ones(num_slots, 1))

        # Initialize parameters
        if self.scale is not None:
            nn.init.normal_(self.scale, mean=1.0, std=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [batch_size, num_slots, dim]
        batch_size = x.size(0)

        # Calculate current RMS per slot
        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()  # [batch, slots, 1]
        slot_rms = current_rms.mean(dim=0)  # [slots, 1] (average over batch)

        # Update EMA during training
        if self.training:
            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()

        # Normalize using EMA statistics
        x_norm = x * torch.rsqrt(self.ema_rms + self.eps)

        # Apply learned scale per slot
        if self.scale is not None:
            x_norm = x_norm * self.scale

        # Apply gating mechanism
        if self.use_gate:
            gate = torch.sigmoid(self.gate)  # [slots, 1, 1]
            return gate * x_norm + (1 - gate) * x

        return x_norm

class AdaptiveRMSMemoryNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        use_gate: bool = True,
        decay: float = 0.99,
        init_scale: float = 1.0,
        init_gate: float = -4.0  # Start with gate closed (no normalization)
    ):
        super().__init__()
        self.use_gate = use_gate
        self.scale = nn.Parameter(torch.ones(dim) * init_scale)
        self.gate = nn.Parameter(torch.tensor([init_gate]))  # Scalar gate for this layer
        self.eps = 1e-6
        self.decay = decay
        self.register_buffer("ema_rms", torch.ones(1))  # Scalar EMA RMS for the entire layer's STM

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [batch_size, num_slots, dim]
        if self.training and hasattr(self, 'ema_rms'):
            # Compute current RMS across all slots and batch (scalar)
            current_rms = x.pow(2).mean(-1).mean().sqrt()
            self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
            rms = self.ema_rms
        else:
            # Compute RMS per slot (mean over dim)
            rms = x.pow(2).mean(-1, keepdim=True).sqrt()  # [batch_size, num_slots, 1]

        # Normalize each slot's embedding vector
        normalized = x * torch.rsqrt(rms + self.eps)
        normalized = normalized * self.scale  # Apply per-dimension scaling

        if self.use_gate:
            gate_factor = torch.sigmoid(self.gate)  # Scalar gate (0-1)
            return normalized * gate_factor + x * (1 - gate_factor)
        else:
            return normalized

class SimpleRMSMemoryNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        use_gate: bool = True,
        init_scale: float = 1.0,
        init_gate: float = -4.0
    ):
        super().__init__()
        self.use_gate = use_gate
        self.scale = nn.Parameter(torch.ones(dim) * init_scale)
        self.gate = nn.Parameter(torch.tensor([init_gate]))  # Scalar gate
        self.eps = 1e-6

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(-1, keepdim=True).sqrt()  # [batch_size, num_slots, 1]
        normalized = x * torch.rsqrt(rms + self.eps)
        normalized = normalized * self.scale  # Apply per-dimension scaling

        if self.use_gate:
            gate_factor = torch.sigmoid(self.gate)  # Scalar gate (0-1)
            return normalized * gate_factor + x * (1 - gate_factor)
        else:
            return normalized

class MemoryLayerNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        use_gate: bool = True,
        init_scale: float = 1.0,
        init_gate: float = -4.0  # Start with gate closed (no normalization)
    ):
        super().__init__()
        self.use_gate = use_gate
        self.norm = nn.LayerNorm(dim)  # Normalizes across embedding dimensions per slot
        self.gate = nn.Parameter(torch.tensor([init_gate]))  # Scalar gate for this layer
        self.scale = nn.Parameter(torch.ones(dim) * init_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.norm(x)  # Apply LayerNorm across embedding dimensions (per slot)
        normalized = normalized * self.scale  # Per-dimension scaling

        if self.use_gate:
            gate_factor = torch.sigmoid(self.gate)  # Scalar gate (0-1)
            return normalized * gate_factor + x * (1 - gate_factor)
        else:
            return normalized

class MemoryNormConfig(TypedDict):
    num_slots: int
    decay: float
    use_scale: bool
    use_gate: bool
    init_gate: float
    init_scale: float

def init_memory_norm(
    norm_type: str,
    dim: int,
    num_slots: int = None,
    decay: float = 0.99,
    use_scale: bool = True,
    use_gate: bool = True,
    init_gate: float = -4.0,
    init_scale: float = 1.0,
) -> nn.Module:
    assert norm_type in ["layer", "rms", "adaptive", "positional"]
    if norm_type == "layer":
        return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
    elif norm_type == "rms":
        return SimpleRMSMemoryNorm(dim, use_gate, init_scale, init_gate)
    elif norm_type == "adaptive":
        return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
    elif norm_type == "positional":
        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate)
    return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
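As a quick orientation (not part of the wheel's contents), here is a minimal usage sketch of the norm factory above. The import path follows the package's own internal `src.` imports as shipped in this wheel; the shapes and values are illustrative only and assume the `[batch_size, num_slots, dim]` convention documented in the module.

import torch
from src.memory.norm import init_memory_norm  # path as packaged in this wheel

# Build a gated RMS-style memory norm for 512-dim STM slots (illustrative sizes)
norm = init_memory_norm("rms", dim=512)
stm_state = torch.randn(8, 64, 512)   # [batch_size, num_slots, dim]
normed = norm(stm_state)              # same shape; gate starts near 0, so output ~ input
print(normed.shape)                   # torch.Size([8, 64, 512])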
src/memory/stm.py
ADDED
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn

class ShortTermMemory(nn.Module):
    """Short-term memory module for the Attention-based Memory System"""

    def __init__(self, num_layers: int, embed_dim: int, stm_size: int, init_type: str = 'normal',
                 is_trainable: bool = False, *args, **kwargs):
        super(ShortTermMemory, self).__init__(*args, **kwargs)
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.stm_size = stm_size
        self.is_trainable = is_trainable
        assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
            'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
        if init_type == 'normal':
            stm = torch.normal(0, 0.02, (num_layers, stm_size, embed_dim))
        elif init_type == 'standard':
            stm = torch.normal(0, 1, (num_layers, stm_size, embed_dim))
        elif init_type == 'uniform':
            stm = torch.rand(num_layers, stm_size, embed_dim) * 0.02
        elif init_type == 'ones':
            stm = torch.ones(num_layers, stm_size, embed_dim)
        else:
            stm = torch.zeros(num_layers, stm_size, embed_dim)

        if self.is_trainable:
            self.memory = nn.Parameter(stm)
        else:
            self.register_buffer('memory', stm)

    def forward(self, layer: int) -> torch.Tensor:
        return self.memory[layer].unsqueeze(0)

    def update_layer(self, layer: int, new_stm: torch.Tensor):
        self.memory[layer] = new_stm

    def update_all(self, new_stm: torch.Tensor):
        self.memory.copy_(new_stm)

    def make_trainable(self):
        if not self.is_trainable:
            self.is_trainable = True
            initial_stm = self.memory.clone()
            del self.memory
            self.memory = nn.Parameter(initial_stm)

    def freeze(self):
        if self.is_trainable:
            self.requires_grad_(False)
            trained_stm = self.memory.clone()
            del self.memory
            self.register_buffer('memory', trained_stm)
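For orientation only, a minimal sketch of the ShortTermMemory API shown above; again the import path mirrors the wheel's internal layout and the sizes are arbitrary examples.

import torch
from src.memory.stm import ShortTermMemory  # path as packaged in this wheel

stm = ShortTermMemory(num_layers=12, embed_dim=512, stm_size=1024)
layer_mem = stm(0)                                # [1, stm_size, embed_dim] view of layer 0
stm.update_layer(0, torch.zeros(1024, 512))       # overwrite one layer's slots in place
stm.make_trainable()                              # promote the buffer to an nn.Parameter
stm.freeze()                                      # back to a non-trainable buffer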
src/rxt/models.py
ADDED
@@ -0,0 +1,180 @@
import torch
from torch import nn
from typing import TypedDict, Union
from huggingface_hub import PyTorchModelHubMixin
from src.transformers.positional import RotaryPositionalEmbedding
from src.transformers.attention import init_attention
from src.transformers.layers import ReactiveTransformerLayer
from src.transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
from src.transformers.ff import get_activation_layer
from src.memory.stm import ShortTermMemory
from src.utils import get_model_size


class RxTAlphaComponentConfig(TypedDict):
    num_layers: int
    vocab_size: int
    embed_dim: int
    ff_dim: int
    att_heads: int
    seq_len: int
    stm_size: int
    use_flash_attention: bool
    use_gated: bool
    ff_activation: str
    ff_dropout: float
    att_dropout: float
    use_rms_norm: bool
    att_groups: int
    use_moe: bool
    num_experts: int
    moe_top_k: int
    self_att_type: str
    cross_att_type: str


class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
    """Base class for RxT-Alpha (Reactive Transformer) components (encoder and decoder)"""

    def __init__(
        self,
        is_causal: bool,
        num_layers: int = 12,
        vocab_size: int = 20000,
        embed_dim: int = 512,
        ff_dim: int = 1536,
        att_heads: int = 16,
        seq_len: int = 1024,
        stm_size: int = 1024,
        use_flash_attention: bool = True,
        use_gated: bool = True,
        ff_activation: str = "swish",
        ff_dropout: float = 0.0,
        att_dropout: float = 0.0,
        use_rms_norm: bool = True,
        att_groups: int = 1,
        use_moe: bool = False,
        num_experts: int = 1,
        moe_top_k: int = 1,
        self_att_type: str = 'gqa',
        cross_att_type: str = 'mqa',
        **kwargs
    ):
        super(RxTAlphaComponentBase, self).__init__(**kwargs)
        assert ff_activation in ['relu', 'gelu',
                                 'swish', 'silu', 'linear',
                                 'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
        assert self_att_type in ['mha', 'gqa', 'mqa'], 'Self-attention type could be "mha", "gqa", "mqa"'
        assert cross_att_type in ['mha', 'gqa', 'mqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa"'

        embedding = nn.Embedding(vocab_size, embed_dim)
        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
        stm = ShortTermMemory(num_layers, embed_dim, stm_size)

        ff_activation = get_activation_layer(ff_activation)

        layers = nn.ModuleList([
            ReactiveTransformerLayer(
                embed_dim,
                ff_dim,
                use_gated=use_gated,
                use_moe=use_moe,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                ff_activation=ff_activation,
                ff_dropout=ff_dropout,
                use_rms_norm=use_rms_norm,
                self_attention=init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
                                              max_seq_len=seq_len, is_causal=is_causal),
                memory_cross_attention=init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
                                                      use_flash_attention=use_flash_attention, dropout=att_dropout,
                                                      max_seq_len=seq_len, rope_only_for_query=True,
                                                      is_causal=is_causal)
            ) for _ in range(num_layers)
        ])
        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)

    def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
                    use_flash_attention: bool, embed_dim: int, vocab_size: int) -> ReactiveTransformerBase:
        pass

    def params_count(self):
        return get_model_size(self.model)

    def load_shared_embedding(self, embedding: nn.Embedding):
        self.model.embedding = embedding

    def load_shared_memory(self, stm: ShortTermMemory):
        self.model.stm = stm

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
            torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        return self.model(x, attention_mask=attention_mask)


class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="apache-2.0"):
    """RxT-Alpha (Reactive Transformer) encoder model"""

    def __init__(self, **kwargs: RxTAlphaComponentConfig):
        super(RxTAlphaEncoder, self).__init__(False, **kwargs)

    def _init_model(
        self,
        stm: ShortTermMemory,
        layers: nn.ModuleList,
        embedding: nn.Embedding,
        use_flash_attention: bool,
        embed_dim: int,
        vocab_size: int
    ) -> ReactiveTransformerEncoder:
        return ReactiveTransformerEncoder(
            stm=stm,
            embedding=embedding,
            own_layers=layers,
            use_flash_attention=use_flash_attention,
        )

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(x, attention_mask=attention_mask)


class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", license="apache-2.0"):
    """RxT-Alpha (Reactive Transformer) decoder model"""

    def __init__(self, **kwargs):
        super(RxTAlphaDecoder, self).__init__(True, **kwargs)

    def _init_model(
        self, stm: ShortTermMemory,
        layers: nn.ModuleList,
        embedding: nn.Embedding,
        use_flash_attention: bool,
        embed_dim: int,
        vocab_size: int
    ) -> ReactiveTransformerDecoder:
        return ReactiveTransformerDecoder(
            embed_dim,
            vocab_size,
            stm=stm,
            embedding=embedding,
            own_layers=layers,
            use_flash_attention=use_flash_attention,
        )

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(x, attention_mask=attention_mask)


def build_rxt_alpha_for_pretraining(
    encoder_config: RxTAlphaComponentConfig,
    decoder_config: RxTAlphaComponentConfig,
) -> tuple[RxTAlphaEncoder, RxTAlphaDecoder]:
    encoder = RxTAlphaEncoder(**encoder_config)
    decoder = RxTAlphaDecoder(**decoder_config)

    encoder.load_shared_memory(decoder.model.stm)
    encoder.load_shared_embedding(decoder.model.embedding)

    return encoder, decoder
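For orientation only, a minimal sketch of how the pretraining builder above could be called. It assumes the sibling src/transformers modules listed in this wheel are importable; the config values are illustrative and simply mirror the fields of RxTAlphaComponentConfig.

from src.rxt.models import build_rxt_alpha_for_pretraining  # path as packaged in this wheel

# Illustrative config; a real setup would size these for the target model.
cfg = dict(num_layers=6, vocab_size=20000, embed_dim=256, ff_dim=768, att_heads=8,
           seq_len=512, stm_size=512, use_flash_attention=False, use_gated=True,
           ff_activation='swish', ff_dropout=0.0, att_dropout=0.0, use_rms_norm=True,
           att_groups=1, use_moe=False, num_experts=1, moe_top_k=1,
           self_att_type='gqa', cross_att_type='mqa')

encoder, decoder = build_rxt_alpha_for_pretraining(cfg, cfg)
# After this call the encoder reuses the decoder's ShortTermMemory and embedding table.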
src/training/base.py
ADDED
@@ -0,0 +1,275 @@
import torch
import math
import os

from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from typing import Callable
from callbacks import TrainerCallback


class BaseTrainer(ABC):
    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        optimizer: torch.optim.Optimizer = None,
        dataset: torch.utils.data.Dataset = None,
        validation_dataset: torch.utils.data.Dataset = None,
        callbacks: list[TrainerCallback] = None,
        log_dir: str = None,
        use_ddp: bool = False,
        use_amp: bool = False,
        dtype: torch.dtype = None,
        target_field_name: str = 'labels',
        get_batch_size: Callable[[dict], int] = None,
        gradient_accumulation_steps: int = 1,
    ):
        if get_batch_size is None:
            self.get_batch_size = lambda batch: batch['attention_mask'].size(0)
        else:
            self.get_batch_size = get_batch_size
        if use_amp:
            self.model = model.to(device)
        else:
            self.model = model.to(device, dtype=dtype)
        self.device = device
        self.optimizer = optimizer
        self.dataset = dataset
        self.callbacks = callbacks or []
        self.writer = SummaryWriter(log_dir) if log_dir else None
        self.use_ddp = use_ddp
        self.use_amp = use_amp
        self.dtype = dtype
        self.is_running = False
        self.validation_dataset = validation_dataset
        self.best_val_loss = float('inf')
        self.validation_metrics = {}
        self.target_field_name = target_field_name
        self.total_tokens = 0
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.accumulated_loss = 0.0
        self.optimizer_step_count = 0

    @abstractmethod
    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        pass

    def train_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor:
        if self.use_amp:
            batch = {k: v.to(self.device) for k, v in batch.items()}
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                loss, _ = self.compute_loss(batch)
        else:
            batch = {k: v.to(self.device, dtype=self.dtype) for k, v in batch.items()}
            loss, _ = self.compute_loss(batch)
        return loss

    def __call__(
        self,
        epochs: int,
        batch_size: int,
        dataset: torch.utils.data.Dataset = None,
        optimizer: torch.optim.Optimizer = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler = None,
    ) -> None:
        self.is_running = True
        if dataset is None:
            assert self.dataset is not None, 'You have to specify a dataset for training'
            dataset = self.dataset
        if optimizer is None:
            assert self.optimizer is not None, 'You have to specify an optimizer for training'
            optimizer = self.optimizer

        if self.use_ddp:
            rank = int(os.environ['RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
            self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
            train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=train_sampler,
                pin_memory=True,
            )
        else:
            train_sampler = None
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                pin_memory=True
            )

        scaler = torch.amp.GradScaler() if self.use_amp else None

        self.model.train()
        for epoch in range(epochs):
            if self.is_running:
                if train_sampler is not None:
                    train_sampler.set_epoch(epoch)
                self._run_epoch(dataloader, epoch, optimizer, batch_size, scaler=scaler, scheduler=scheduler)
                if self.use_ddp:
                    dist.barrier()

        if self.use_ddp:
            dist.destroy_process_group()
        self.is_running = False
        self.on_training_end()

    def _run_epoch(
        self,
        dataloader: torch.utils.data.DataLoader,
        epoch: int,
        optimizer: torch.optim.Optimizer,
        batch_size: int,
        scaler: torch.cuda.amp.GradScaler = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler = None,
    ) -> None:
        for callback in self.callbacks:
            callback.on_epoch_start(self.model, epoch)

        self.accumulated_loss = 0.0
        self.optimizer_step_count = 0

        for batch_idx, batch in enumerate(dataloader):
            if self.is_running:
                for callback in self.callbacks:
                    callback.on_batch_start(self.model, batch_idx, batch)
                if self.get_batch_size(batch) == batch_size:
                    loss = self.train_step(batch, batch_idx)
                    self.accumulated_loss += loss.item()
                    loss = loss / self.gradient_accumulation_steps

                    if self.use_amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                    self.optimizer_step_count += 1
                    if self.optimizer_step_count % self.gradient_accumulation_steps == 0:
                        # Clip gradients after accumulation
                        if self.use_amp:
                            scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0, error_if_nonfinite=False)
                        if self.use_amp:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()

                        optimizer.zero_grad()

                        if scheduler is not None:
                            scheduler.step()

                        if self.writer:
                            loss_item = self.accumulated_loss / self.gradient_accumulation_steps
                            self.writer.add_scalar(
                                'Loss/train',
                                loss_item,
                                epoch * len(dataloader) + batch_idx
                            )
                            self.writer.add_scalar(
                                'Loss per epoch/train',
                                loss_item,
                                batch_idx
                            )
                            self.writer.add_scalar(
                                'Perplexity/train',
                                torch.exp(torch.tensor(loss_item)),
                                epoch * len(dataloader) + batch_idx
                            )
                        self.accumulated_loss = 0.0
                        self.optimizer_step_count = 0

                    if self.writer:
                        self.total_tokens += batch['attention_mask'].sum().item()
                        self.writer.add_scalar('Processed tokens', self.total_tokens,
                                               epoch * len(dataloader) + batch_idx)

                    for callback in self.callbacks:
                        should_stop = callback.on_batch_end(self.model, batch_idx, loss.item(), batch)
                        if should_stop:
                            self.is_running = False

        if self.validation_dataset:
            val_loss, val_metrics = self.validate(batch_size)
            val_loss_tensor = torch.tensor(val_loss).to(self.device)
            if self.use_ddp:
                dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
                val_loss = val_loss_tensor.item() / dist.get_world_size()
            self.validation_metrics[epoch] = val_metrics

            if self.writer:
                self._valid_writer(epoch, val_loss, val_metrics)

            for callback in self.callbacks:
                callback.on_validation_end(self.model, epoch, val_loss, val_metrics)

        for callback in self.callbacks:
            should_stop = callback.on_epoch_end(self.model, epoch)
            if should_stop:
                self.is_running = False

        if self.writer:
            self.writer.flush()

    def on_training_end(self):
        for callback in self.callbacks:
            callback.on_training_end(self.model)
        if self.writer:
            self.writer.close()

    def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
        self.writer.add_scalar('Loss/validation', val_loss, epoch)
        self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
        if val_metrics['accuracy']:
            self.writer.add_scalar('Accuracy/validation', val_metrics['accuracy'], epoch)

    def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        if self.use_amp:
            batch = {k: v.to(self.device) for k, v in batch.items()}
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                loss, outputs = self.compute_loss(batch)
        else:
            batch = {k: v.to(self.device, dtype=self.dtype) for k, v in batch.items()}
            loss, outputs = self.compute_loss(batch)
        return loss, outputs

    def _valid_loader(self, batch_size: int):
        val_dataset = self.validation_dataset
        if self.use_ddp:
            val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)
            return torch.utils.data.DataLoader(
                val_dataset,
                batch_size=batch_size,
                pin_memory=True,
                sampler=val_sampler,
            )
        else:
            return torch.utils.data.DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )

    def validate(self, batch_size: int) -> tuple[float, dict]:
        self.model.eval()
        val_dataloader = self._valid_loader(batch_size)
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_dataloader:
                if self.get_batch_size(batch) == batch_size:
                    loss, outputs = self.valid_step(batch)
                    val_loss += loss.item()

        avg_loss = val_loss / len(val_dataloader)
        metrics = {}
        self.model.train()
        return avg_loss, metrics
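For orientation only, a minimal sketch of subclassing the abstract BaseTrainer above. The subclass name and loss logic are hypothetical, not part of the package; the batch fields match the trainer's defaults (an 'attention_mask' for the batch-size helper and 'labels' for targets). Note that base.py imports TrainerCallback from a top-level 'callbacks' module, so the training directory would need to be on the import path.

import torch
from src.training.base import BaseTrainer  # path as packaged in this wheel

class ExampleTrainer(BaseTrainer):
    """Hypothetical subclass: compute_loss is the only abstract method to provide."""
    def compute_loss(self, batch):
        outputs = self.model(batch['input_ids'], attention_mask=batch['attention_mask'])
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs[0]
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), batch['labels'].view(-1)
        )
        return loss, logits

# trainer = ExampleTrainer(model, device=torch.device('cuda'), optimizer=opt, dataset=train_ds)
# trainer(epochs=1, batch_size=16)  # __call__ builds the DataLoader and runs the training loop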