rxnn-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/memory/norm.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import torch.nn as nn
+ from typing import TypedDict, Optional
+
+ class AdaptivePositionalMemoryNorm(nn.Module):
+     def __init__(
+         self,
+         num_slots: int,
+         dim: int,
+         decay: float = 0.99,
+         use_scale: bool = True,
+         use_gate: bool = True,
+         init_gate: float = -4.0
+     ):
+         super(AdaptivePositionalMemoryNorm, self).__init__()
+         self.use_gate = use_gate
+         self.num_slots = num_slots
+         self.dim = dim
+         self.decay = decay
+         self.eps = 1e-6
+
+         # Learnable parameters, shaped to broadcast over the batch dimension of [batch, slots, dim] inputs
+         self.scale = nn.Parameter(torch.ones(1, num_slots, dim)) if use_scale else None
+         self.gate = nn.Parameter(torch.full((1, num_slots, 1), init_gate)) if use_gate else None
+
+         # EMA buffer with the running RMS per memory slot
+         self.register_buffer("ema_rms", torch.ones(num_slots, 1))
+
+         # Initialize scale around 1.0
+         if self.scale is not None:
+             nn.init.normal_(self.scale, mean=1.0, std=0.01)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x shape: [batch_size, num_slots, dim]
+         # Calculate current RMS per slot
+         current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()  # [batch, slots, 1]
+         slot_rms = current_rms.mean(dim=0)  # [slots, 1] (average over batch)
+
+         # Update EMA during training
+         if self.training:
+             self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
+
+         # Normalize using EMA statistics
+         x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
+
+         # Apply learned scale per slot
+         if self.scale is not None:
+             x_norm = x_norm * self.scale
+
+         # Apply gating mechanism: interpolate between normalized and raw memory
+         if self.use_gate:
+             gate = torch.sigmoid(self.gate)  # [1, slots, 1]
+             return gate * x_norm + (1 - gate) * x
+
+         return x_norm
+
+ class AdaptiveRMSMemoryNorm(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         use_gate: bool = True,
+         decay: float = 0.99,
+         init_scale: float = 1.0,
+         init_gate: float = -4.0  # Start with the gate almost closed (output ~ raw input)
+     ):
+         super().__init__()
+         self.use_gate = use_gate
+         self.scale = nn.Parameter(torch.ones(dim) * init_scale)
+         self.gate = nn.Parameter(torch.tensor([init_gate]))  # Scalar gate for this layer
+         self.eps = 1e-6
+         self.decay = decay
+         self.register_buffer("ema_rms", torch.ones(1))  # Scalar EMA RMS for the entire layer's STM
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x shape: [batch_size, num_slots, dim]
+         if self.training:
+             # Compute current RMS across all slots and batch (scalar) and update the EMA
+             current_rms = x.pow(2).mean(-1).mean().sqrt()
+             self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
+             rms = self.ema_rms
+         else:
+             # Compute RMS per slot (mean over dim)
+             rms = x.pow(2).mean(-1, keepdim=True).sqrt()  # [batch_size, num_slots, 1]
+
+         # Normalize each slot's embedding vector
+         normalized = x * torch.rsqrt(rms + self.eps)
+         normalized = normalized * self.scale  # Apply per-dimension scaling
+
+         if self.use_gate:
+             gate_factor = torch.sigmoid(self.gate)  # Scalar gate in (0, 1)
+             return normalized * gate_factor + x * (1 - gate_factor)
+         else:
+             return normalized
+
+ class SimpleRMSMemoryNorm(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         use_gate: bool = True,
+         init_scale: float = 1.0,
+         init_gate: float = -4.0
+     ):
+         super().__init__()
+         self.use_gate = use_gate
+         self.scale = nn.Parameter(torch.ones(dim) * init_scale)
+         self.gate = nn.Parameter(torch.tensor([init_gate]))  # Scalar gate
+         self.eps = 1e-6
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         rms = x.pow(2).mean(-1, keepdim=True).sqrt()  # [batch_size, num_slots, 1]
+         normalized = x * torch.rsqrt(rms + self.eps)
+         normalized = normalized * self.scale  # Apply per-dimension scaling
+
+         if self.use_gate:
+             gate_factor = torch.sigmoid(self.gate)  # Scalar gate in (0, 1)
+             return normalized * gate_factor + x * (1 - gate_factor)
+         else:
+             return normalized
+
+ class MemoryLayerNorm(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         use_gate: bool = True,
+         init_scale: float = 1.0,
+         init_gate: float = -4.0  # Start with the gate almost closed (output ~ raw input)
+     ):
+         super().__init__()
+         self.use_gate = use_gate
+         self.norm = nn.LayerNorm(dim)  # Normalizes across embedding dimensions per slot
+         self.gate = nn.Parameter(torch.tensor([init_gate]))  # Scalar gate for this layer
+         self.scale = nn.Parameter(torch.ones(dim) * init_scale)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         normalized = self.norm(x)  # Apply LayerNorm across embedding dimensions (per slot)
+         normalized = normalized * self.scale  # Per-dimension scaling
+
+         if self.use_gate:
+             gate_factor = torch.sigmoid(self.gate)  # Scalar gate in (0, 1)
+             return normalized * gate_factor + x * (1 - gate_factor)
+         else:
+             return normalized
+
+ class MemoryNormConfig(TypedDict):
+     num_slots: int
+     decay: float
+     use_scale: bool
+     use_gate: bool
+     init_gate: float
+     init_scale: float
+
+ def init_memory_norm(
+     norm_type: str,
+     dim: int,
+     num_slots: Optional[int] = None,
+     decay: float = 0.99,
+     use_scale: bool = True,
+     use_gate: bool = True,
+     init_gate: float = -4.0,
+     init_scale: float = 1.0,
+ ) -> nn.Module:
+     assert norm_type in ["layer", "rms", "adaptive", "positional"]
+     if norm_type == "layer":
+         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
+     elif norm_type == "rms":
+         return SimpleRMSMemoryNorm(dim, use_gate, init_scale, init_gate)
+     elif norm_type == "adaptive":
+         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
+     elif norm_type == "positional":
+         return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate)
+     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
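For orientation, a minimal usage sketch of the `init_memory_norm` factory above, assuming the module is importable under the `src.memory.norm` path shown in this diff; the batch, slot, and dimension sizes are arbitrary illustration values:

import torch
from src.memory.norm import init_memory_norm

stm_state = torch.randn(2, 8, 64)                    # [batch, num_slots, dim], illustrative sizes
norm = init_memory_norm("positional", dim=64, num_slots=8)
normalized = norm(stm_state)                         # gated mix of normalized and raw memory state
assert normalized.shape == stm_state.shape

The gate is initialized near zero (init_gate=-4.0), so a freshly built norm passes memory through almost unchanged and learns how much normalization to apply.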
src/memory/stm.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+
+ class ShortTermMemory(nn.Module):
+     """Short-term memory module for the Attention-based Memory System"""
+
+     def __init__(self, num_layers: int, embed_dim: int, stm_size: int, init_type: str = 'normal',
+                  is_trainable: bool = False, *args, **kwargs):
+         super(ShortTermMemory, self).__init__(*args, **kwargs)
+         self.num_layers = num_layers
+         self.embed_dim = embed_dim
+         self.stm_size = stm_size
+         self.is_trainable = is_trainable
+         assert init_type in ['normal', 'standard', 'uniform', 'ones', 'zeros'], \
+             'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
+         if init_type == 'normal':
+             stm = torch.normal(0, 0.02, (num_layers, stm_size, embed_dim))
+         elif init_type == 'standard':
+             stm = torch.normal(0, 1, (num_layers, stm_size, embed_dim))
+         elif init_type == 'uniform':
+             stm = torch.rand(num_layers, stm_size, embed_dim) * 0.02
+         elif init_type == 'ones':
+             stm = torch.ones(num_layers, stm_size, embed_dim)
+         else:
+             stm = torch.zeros(num_layers, stm_size, embed_dim)
+
+         if self.is_trainable:
+             self.memory = nn.Parameter(stm)
+         else:
+             self.register_buffer('memory', stm)
+
+     def forward(self, layer: int) -> torch.Tensor:
+         # Return the given layer's memory with a leading batch dimension: [1, stm_size, embed_dim]
+         return self.memory[layer].unsqueeze(0)
+
+     def update_layer(self, layer: int, new_stm: torch.Tensor):
+         self.memory[layer] = new_stm
+
+     def update_all(self, new_stm: torch.Tensor):
+         self.memory.copy_(new_stm)
+
+     def make_trainable(self):
+         if not self.is_trainable:
+             self.is_trainable = True
+             initial_stm = self.memory.clone()
+             del self.memory
+             self.memory = nn.Parameter(initial_stm)
+
+     def freeze(self):
+         if self.is_trainable:
+             self.is_trainable = False
+             self.requires_grad_(False)
+             trained_stm = self.memory.clone()
+             del self.memory
+             self.register_buffer('memory', trained_stm)
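A brief sketch of how the memory container above is driven, assuming the `src.memory.stm` import path from this diff; sizes are illustrative:

import torch
from src.memory.stm import ShortTermMemory

stm = ShortTermMemory(num_layers=4, embed_dim=64, stm_size=16)
layer_mem = stm(0)                               # [1, stm_size, embed_dim] view of layer 0
stm.update_layer(0, torch.zeros(16, 64))         # overwrite a single layer's memory slots
stm.make_trainable()                             # promote the buffer to an nn.Parameter
stm.freeze()                                     # clone it back into a frozen buffer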
src/rxt/models.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ from torch import nn
+ from typing import TypedDict, Union
+ from huggingface_hub import PyTorchModelHubMixin
+ from src.transformers.positional import RotaryPositionalEmbedding
+ from src.transformers.attention import init_attention
+ from src.transformers.layers import ReactiveTransformerLayer
+ from src.transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
+ from src.transformers.ff import get_activation_layer
+ from src.memory.stm import ShortTermMemory
+ from src.utils import get_model_size
+
+
+ class RxTAlphaComponentConfig(TypedDict):
+     num_layers: int
+     vocab_size: int
+     embed_dim: int
+     ff_dim: int
+     att_heads: int
+     seq_len: int
+     stm_size: int
+     use_flash_attention: bool
+     use_gated: bool
+     ff_activation: str
+     ff_dropout: float
+     att_dropout: float
+     use_rms_norm: bool
+     att_groups: int
+     use_moe: bool
+     num_experts: int
+     moe_top_k: int
+     self_att_type: str
+     cross_att_type: str
+
+
+ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
+     """Base class for RxT-Alpha (Reactive Transformer) components (encoder and decoder)"""
+
+     def __init__(
+         self,
+         is_causal: bool,
+         num_layers: int = 12,
+         vocab_size: int = 20000,
+         embed_dim: int = 512,
+         ff_dim: int = 1536,
+         att_heads: int = 16,
+         seq_len: int = 1024,
+         stm_size: int = 1024,
+         use_flash_attention: bool = True,
+         use_gated: bool = True,
+         ff_activation: str = "swish",
+         ff_dropout: float = 0.0,
+         att_dropout: float = 0.0,
+         use_rms_norm: bool = True,
+         att_groups: int = 1,
+         use_moe: bool = False,
+         num_experts: int = 1,
+         moe_top_k: int = 1,
+         self_att_type: str = 'gqa',
+         cross_att_type: str = 'mqa',
+         **kwargs
+     ):
+         super(RxTAlphaComponentBase, self).__init__(**kwargs)
+         assert ff_activation in ['relu', 'gelu', 'swish', 'silu', 'linear', 'sigmoid'], \
+             'Feed-forward activation must be one of: "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
+         assert self_att_type in ['mha', 'gqa', 'mqa'], 'Self-attention type must be one of: "mha", "gqa", "mqa"'
+         assert cross_att_type in ['mha', 'gqa', 'mqa'], 'Memory cross-attention type must be one of: "mha", "gqa", "mqa"'
+
+         embedding = nn.Embedding(vocab_size, embed_dim)
+         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+         stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+         ff_activation = get_activation_layer(ff_activation)
+
+         layers = nn.ModuleList([
+             ReactiveTransformerLayer(
+                 embed_dim,
+                 ff_dim,
+                 use_gated=use_gated,
+                 use_moe=use_moe,
+                 num_experts=num_experts,
+                 moe_top_k=moe_top_k,
+                 ff_activation=ff_activation,
+                 ff_dropout=ff_dropout,
+                 use_rms_norm=use_rms_norm,
+                 self_attention=init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                               use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                               max_seq_len=seq_len, is_causal=is_causal),
+                 memory_cross_attention=init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
+                                                       use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                       max_seq_len=seq_len, rope_only_for_query=True,
+                                                       is_causal=is_causal)
+             ) for _ in range(num_layers)
+         ])
+         self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
+
+     def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
+                     use_flash_attention: bool, embed_dim: int, vocab_size: int) -> ReactiveTransformerBase:
+         pass
+
+     def params_count(self):
+         return get_model_size(self.model)
+
+     def load_shared_embedding(self, embedding: nn.Embedding):
+         self.model.embedding = embedding
+
+     def load_shared_memory(self, stm: ShortTermMemory):
+         self.model.stm = stm
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
+             torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         return self.model(x, attention_mask=attention_mask)
+
+
+ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="apache-2.0"):
+     """RxT-Alpha (Reactive Transformer) encoder model"""
+
+     def __init__(self, **kwargs: RxTAlphaComponentConfig):
+         super(RxTAlphaEncoder, self).__init__(False, **kwargs)
+
+     def _init_model(
+         self,
+         stm: ShortTermMemory,
+         layers: nn.ModuleList,
+         embedding: nn.Embedding,
+         use_flash_attention: bool,
+         embed_dim: int,
+         vocab_size: int
+     ) -> ReactiveTransformerEncoder:
+         return ReactiveTransformerEncoder(
+             stm=stm,
+             embedding=embedding,
+             own_layers=layers,
+             use_flash_attention=use_flash_attention,
+         )
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+         return self.model(x, attention_mask=attention_mask)
+
+
+ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", license="apache-2.0"):
+     """RxT-Alpha (Reactive Transformer) decoder model"""
+
+     def __init__(self, **kwargs: RxTAlphaComponentConfig):
+         super(RxTAlphaDecoder, self).__init__(True, **kwargs)
+
+     def _init_model(
+         self, stm: ShortTermMemory,
+         layers: nn.ModuleList,
+         embedding: nn.Embedding,
+         use_flash_attention: bool,
+         embed_dim: int,
+         vocab_size: int
+     ) -> ReactiveTransformerDecoder:
+         return ReactiveTransformerDecoder(
+             embed_dim,
+             vocab_size,
+             stm=stm,
+             embedding=embedding,
+             own_layers=layers,
+             use_flash_attention=use_flash_attention,
+         )
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+         return self.model(x, attention_mask=attention_mask)
+
+
+ def build_rxt_alpha_for_pretraining(
+     encoder_config: RxTAlphaComponentConfig,
+     decoder_config: RxTAlphaComponentConfig,
+ ) -> tuple[RxTAlphaEncoder, RxTAlphaDecoder]:
+     encoder = RxTAlphaEncoder(**encoder_config)
+     decoder = RxTAlphaDecoder(**decoder_config)
+
+     # Share the decoder's short-term memory and token embedding with the encoder
+     encoder.load_shared_memory(decoder.model.stm)
+     encoder.load_shared_embedding(decoder.model.embedding)
+
+     return encoder, decoder
+
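A minimal sketch of building the shared encoder/decoder pair with `build_rxt_alpha_for_pretraining`, assuming the `src.*` modules referenced by the file's imports are importable; the configuration values are deliberately tiny and purely illustrative:

from src.rxt.models import RxTAlphaComponentConfig, build_rxt_alpha_for_pretraining

config: RxTAlphaComponentConfig = dict(
    num_layers=2, vocab_size=5000, embed_dim=128, ff_dim=384, att_heads=4,
    seq_len=256, stm_size=256, use_flash_attention=False, use_gated=True,
    ff_activation='swish', ff_dropout=0.0, att_dropout=0.0, use_rms_norm=True,
    att_groups=1, use_moe=False, num_experts=1, moe_top_k=1,
    self_att_type='gqa', cross_att_type='mqa',
)
encoder, decoder = build_rxt_alpha_for_pretraining(config, config)
print(encoder.params_count(), decoder.params_count())

After the call, both components point at the same ShortTermMemory and nn.Embedding instances, which is what the joint pretraining setup relies on.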
src/training/base.py ADDED
@@ -0,0 +1,275 @@
+ import torch
+ import math
+ import os
+
+ from abc import ABC, abstractmethod
+ from torch.utils.tensorboard import SummaryWriter
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+ from typing import Callable
+ from src.training.callbacks import TrainerCallback
+
+
+ class BaseTrainer(ABC):
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         device: torch.device,
+         optimizer: torch.optim.Optimizer = None,
+         dataset: torch.utils.data.Dataset = None,
+         validation_dataset: torch.utils.data.Dataset = None,
+         callbacks: list[TrainerCallback] = None,
+         log_dir: str = None,
+         use_ddp: bool = False,
+         use_amp: bool = False,
+         dtype: torch.dtype = None,
+         target_field_name: str = 'labels',
+         get_batch_size: Callable[[dict], int] = None,
+         gradient_accumulation_steps: int = 1,
+     ):
+         if get_batch_size is None:
+             self.get_batch_size = lambda batch: batch['attention_mask'].size(0)
+         else:
+             self.get_batch_size = get_batch_size
+         if use_amp:
+             self.model = model.to(device)
+         else:
+             self.model = model.to(device, dtype=dtype)
+         self.device = device
+         self.optimizer = optimizer
+         self.dataset = dataset
+         self.callbacks = callbacks or []
+         self.writer = SummaryWriter(log_dir) if log_dir else None
+         self.use_ddp = use_ddp
+         self.use_amp = use_amp
+         self.dtype = dtype
+         self.is_running = False
+         self.validation_dataset = validation_dataset
+         self.best_val_loss = float('inf')
+         self.validation_metrics = {}
+         self.target_field_name = target_field_name
+         self.total_tokens = 0
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+         self.accumulated_loss = 0.0
+         self.optimizer_step_count = 0
+
+     @abstractmethod
+     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         pass
+
+     def train_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor:
+         if self.use_amp:
+             batch = {k: v.to(self.device) for k, v in batch.items()}
+             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                 loss, _ = self.compute_loss(batch)
+         else:
+             batch = {k: v.to(self.device, dtype=self.dtype) for k, v in batch.items()}
+             loss, _ = self.compute_loss(batch)
+         return loss
+
+     def __call__(
+         self,
+         epochs: int,
+         batch_size: int,
+         dataset: torch.utils.data.Dataset = None,
+         optimizer: torch.optim.Optimizer = None,
+         scheduler: torch.optim.lr_scheduler.LRScheduler = None,
+     ) -> None:
+         self.is_running = True
+         if dataset is None:
+             assert self.dataset is not None, 'You have to specify a dataset for training'
+             dataset = self.dataset
+         if optimizer is None:
+             assert self.optimizer is not None, 'You have to specify an optimizer for training'
+             optimizer = self.optimizer
+
+         if self.use_ddp:
+             rank = int(os.environ['RANK'])
+             world_size = int(os.environ['WORLD_SIZE'])
+             dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
+             self.model = DistributedDataParallel(self.model, device_ids=[self.device.index])
+             train_sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
+             dataloader = torch.utils.data.DataLoader(
+                 dataset,
+                 batch_size=batch_size,
+                 sampler=train_sampler,
+                 pin_memory=True,
+             )
+         else:
+             train_sampler = None
+             dataloader = torch.utils.data.DataLoader(
+                 dataset,
+                 batch_size=batch_size,
+                 shuffle=True,
+                 pin_memory=True
+             )
+
+         scaler = torch.amp.GradScaler() if self.use_amp else None
+
+         self.model.train()
+         for epoch in range(epochs):
+             if self.is_running:
+                 if train_sampler is not None:
+                     train_sampler.set_epoch(epoch)
+                 self._run_epoch(dataloader, epoch, optimizer, batch_size, scaler=scaler, scheduler=scheduler)
+                 if self.use_ddp:
+                     dist.barrier()
+
+         if self.use_ddp:
+             dist.destroy_process_group()
+         self.is_running = False
+         self.on_training_end()
+
+     def _run_epoch(
+         self,
+         dataloader: torch.utils.data.DataLoader,
+         epoch: int,
+         optimizer: torch.optim.Optimizer,
+         batch_size: int,
+         scaler: torch.amp.GradScaler = None,
+         scheduler: torch.optim.lr_scheduler.LRScheduler = None,
+     ) -> None:
+         for callback in self.callbacks:
+             callback.on_epoch_start(self.model, epoch)
+
+         self.accumulated_loss = 0.0
+         self.optimizer_step_count = 0
+
+         for batch_idx, batch in enumerate(dataloader):
+             if self.is_running:
+                 for callback in self.callbacks:
+                     callback.on_batch_start(self.model, batch_idx, batch)
+                 if self.get_batch_size(batch) == batch_size:
+                     loss = self.train_step(batch, batch_idx)
+                     self.accumulated_loss += loss.item()
+                     loss = loss / self.gradient_accumulation_steps
+
+                     if self.use_amp:
+                         scaler.scale(loss).backward()
+                     else:
+                         loss.backward()
+
+                     self.optimizer_step_count += 1
+                     if self.optimizer_step_count % self.gradient_accumulation_steps == 0:
+                         # Clip gradients after accumulation
+                         if self.use_amp:
+                             scaler.unscale_(optimizer)
+                         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0, error_if_nonfinite=False)
+                         if self.use_amp:
+                             scaler.step(optimizer)
+                             scaler.update()
+                         else:
+                             optimizer.step()
+
+                         optimizer.zero_grad()
+
+                         if scheduler is not None:
+                             scheduler.step()
+
+                         if self.writer:
+                             loss_item = self.accumulated_loss / self.gradient_accumulation_steps
+                             self.writer.add_scalar(
+                                 'Loss/train',
+                                 loss_item,
+                                 epoch * len(dataloader) + batch_idx
+                             )
+                             self.writer.add_scalar(
+                                 'Loss per epoch/train',
+                                 loss_item,
+                                 batch_idx
+                             )
+                             self.writer.add_scalar(
+                                 'Perplexity/train',
+                                 torch.exp(torch.tensor(loss_item)),
+                                 epoch * len(dataloader) + batch_idx
+                             )
+                         self.accumulated_loss = 0.0
+                         self.optimizer_step_count = 0
+
+                     if self.writer:
+                         self.total_tokens += batch['attention_mask'].sum().item()
+                         self.writer.add_scalar('Processed tokens', self.total_tokens,
+                                                epoch * len(dataloader) + batch_idx)
+
+                     for callback in self.callbacks:
+                         should_stop = callback.on_batch_end(self.model, batch_idx, loss.item(), batch)
+                         if should_stop:
+                             self.is_running = False
+
+         if self.validation_dataset:
+             val_loss, val_metrics = self.validate(batch_size)
+             val_loss_tensor = torch.tensor(val_loss).to(self.device)
+             if self.use_ddp:
+                 dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
+                 val_loss = val_loss_tensor.item() / dist.get_world_size()
+             self.validation_metrics[epoch] = val_metrics
+
+             if self.writer:
+                 self._valid_writer(epoch, val_loss, val_metrics)
+
+             for callback in self.callbacks:
+                 callback.on_validation_end(self.model, epoch, val_loss, val_metrics)
+
+         for callback in self.callbacks:
+             should_stop = callback.on_epoch_end(self.model, epoch)
+             if should_stop:
+                 self.is_running = False
+
+         if self.writer:
+             self.writer.flush()
+
+     def on_training_end(self):
+         for callback in self.callbacks:
+             callback.on_training_end(self.model)
+         if self.writer:
+             self.writer.close()
+
+     def _valid_writer(self, epoch: int, val_loss: float, val_metrics: dict):
+         self.writer.add_scalar('Loss/validation', val_loss, epoch)
+         self.writer.add_scalar('Perplexity/validation', math.exp(val_loss), epoch)
+         if val_metrics.get('accuracy'):
+             self.writer.add_scalar('Accuracy/validation', val_metrics['accuracy'], epoch)
+
+     def valid_step(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         if self.use_amp:
+             batch = {k: v.to(self.device) for k, v in batch.items()}
+             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
+                 loss, outputs = self.compute_loss(batch)
+         else:
+             batch = {k: v.to(self.device, dtype=self.dtype) for k, v in batch.items()}
+             loss, outputs = self.compute_loss(batch)
+         return loss, outputs
+
+     def _valid_loader(self, batch_size: int):
+         val_dataset = self.validation_dataset
+         if self.use_ddp:
+             val_sampler = torch.utils.data.DistributedSampler(val_dataset, shuffle=False)
+             return torch.utils.data.DataLoader(
+                 val_dataset,
+                 batch_size=batch_size,
+                 pin_memory=True,
+                 sampler=val_sampler,
+             )
+         else:
+             return torch.utils.data.DataLoader(
+                 val_dataset,
+                 batch_size=batch_size,
+                 shuffle=False,
+                 pin_memory=True
+             )
+
+     def validate(self, batch_size: int) -> tuple[float, dict]:
+         self.model.eval()
+         val_dataloader = self._valid_loader(batch_size)
+         val_loss = 0.0
+
+         with torch.no_grad():
+             for batch in val_dataloader:
+                 if self.get_batch_size(batch) == batch_size:
+                     loss, outputs = self.valid_step(batch)
+                     val_loss += loss.item()
+
+         avg_loss = val_loss / len(val_dataloader)
+         metrics = {}
+         self.model.train()
+         return avg_loss, metrics
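BaseTrainer is abstract: concrete trainers only supply `compute_loss`, which must return a `(loss, outputs)` pair. Below is a hypothetical subclass for illustration only; the field names, model output shape, and ignore index are assumptions, not part of the package:

import torch
from src.training.base import BaseTrainer

class MLMTrainer(BaseTrainer):  # hypothetical subclass, not shipped in the package
    def compute_loss(self, batch):
        # Assumes the wrapped model maps (input_ids, attention_mask) to logits of shape [batch, seq_len, vocab]
        logits = self.model(batch['input_ids'], attention_mask=batch['attention_mask'])
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            batch[self.target_field_name].view(-1),   # defaults to the 'labels' field
            ignore_index=-100,                         # assumed masking convention
        )
        return loss, logits

# trainer = MLMTrainer(model, device, optimizer=opt, dataset=train_ds, validation_dataset=val_ds)
# trainer(epochs=1, batch_size=16)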