pyhalos 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
halo/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ HALO-S: Framework de Atención Dispersa para Modelos de Lenguaje
3
+ ================================================================
4
+
5
+ Arquitectura transformer con complejidad O(N×K) basada en neighbor lists,
6
+ global tokens, conexiones dilatadas y Grouped Query Attention (GQA).
7
+
8
+ Uso básico:
9
+ >>> from halo import HaloConfig, HaloSModel, set_seed
10
+ >>> set_seed(42)
11
+ >>> config = HaloConfig(vocab_size=256, hidden_size=512)
12
+ >>> model = HaloSModel(config)
13
+
14
+ Autor: BUEORM (dalusx64@gmail.com)
15
+ """
16
+
17
# Package version string (this module ships in the 1.0.3 wheel).
__version__ = '1.0.3'
18
+
19
+ # --- Core ---
20
+ from halo.core.config import HaloConfig
21
+
22
+ # --- Models ---
23
+ from halo.models.halo_model import HaloSModel
24
+ from halo.models.baseline_model import BaselineModel
25
+
26
+ # --- Training ---
27
+ from halo.training.trainer import Trainer
28
+
29
+ # --- Tokenizers ---
30
+ from halo.tokenizers.char import CharacterTokenizer
31
+
32
+ # WordTokenizer se importa condicionalmente porque se crea en una tarea posterior
33
+ try:
34
+ from halo.tokenizers.word import WordTokenizer
35
+ except ImportError:
36
+ WordTokenizer = None
37
+
38
+ # --- Generation ---
39
+ from halo.generation.samplers import generate
40
+
41
+ # --- Utils ---
42
+ from halo.utils.random import set_seed
43
+ from halo.utils.metrics import count_parameters
44
+
45
# Public API re-exported by the `halo` package, grouped by subpackage.
# NOTE(review): "WordTokenizer" is listed even when its optional import
# failed (the try/except above binds it to None in that case).
__all__ = [
    "HaloConfig",                           # core configuration
    "HaloSModel", "BaselineModel",          # model architectures
    "Trainer",                              # training loop
    "CharacterTokenizer", "WordTokenizer",  # tokenizers
    "generate",                             # text generation
    "set_seed", "count_parameters",         # utilities
]
@@ -0,0 +1,15 @@
1
+ """
2
+ HALO-S Attention — Mecanismos de atención dispersa y global.
3
+ """
4
+
5
+ from halo.attention.halo_attention import HaloSparseAttention
6
+ from halo.attention.global_attention import GlobalFullAttention, _use_sdpa
7
+ from halo.attention.graph import generate_neighbor_lists, estimate_graph_stats
8
+
9
# Public names of the attention subpackage, re-exported from its submodules.
__all__ = [
    "HaloSparseAttention",      # halo_attention
    "GlobalFullAttention",      # global_attention
    "_use_sdpa",                # global_attention (underscore-prefixed, but exported)
    "generate_neighbor_lists",  # graph
    "estimate_graph_stats",     # graph
]
@@ -0,0 +1,215 @@
1
+ """
2
+ Atención Densa para Global Tokens — HALO-S.
3
+
4
+ Los Global Tokens son posiciones especiales (0..G-1) que atienden a TODA la
5
+ secuencia (global + regular tokens) con atención densa estándar.
6
+ Complejidad: O(G × N) donde G = num_globals, N = seq_len total.
7
+
8
+ A diferencia de la atención dispersa (gather-based), aquí se computan scores
9
+ contra todas las posiciones, permitiendo que los globals actúen como memoria
10
+ compartida accesible por cualquier token del grafo disperso.
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from typing import Optional
18
+
19
+ from halo.core.config import HaloConfig
20
+
21
+
22
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension of ``x`` by half, as used by RoPE.

    The last dimension is split at ``d // 2`` into a front part ``a`` and a
    back part ``b`` and the result is ``concat(-b, a)``. When ``d`` is odd
    the back part is the longer one.
    """
    dim = x.shape[-1]
    half = dim // 2
    front, back = torch.split(x, [half, dim - half], dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
26
+
27
+
28
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply Rotary Positional Embeddings (RoPE) to a single tensor.

    Computes ``x * cos + rotate_half(x) * sin`` elementwise.

    Args:
        x: Tensor of shape (batch, heads, seq_len, head_dim).
        cos: Rotary cosine table, broadcastable to ``x``.
        sin: Rotary sine table, broadcastable to ``x``.

    Returns:
        Tensor with RoPE applied; same shape as ``x`` when cos/sin broadcast to it.
    """
    rotated = _rotate_half(x)
    return torch.add(x * cos, rotated * sin)
39
+
40
+
41
def _use_sdpa() -> bool:
    """Decide whether the SDPA attention path should be used.

    Returns True only when ``torch.nn.functional.scaled_dot_product_attention``
    exists in this PyTorch build AND CUDA is available. The fused
    Flash / memory-efficient backends are the reason to prefer SDPA, and on
    CPU they bring no significant benefit.
    """
    has_sdpa = hasattr(F, 'scaled_dot_product_attention')
    return has_sdpa and torch.cuda.is_available()
53
+
54
+
55
class GlobalFullAttention(nn.Module):
    """
    Dense attention for the Global Tokens, with GQA and an optional SDPA backend.

    The global tokens are the queries and attend over the whole sequence
    (globals + regular tokens), which supplies the keys and values.
    Grouped Query Attention: ``num_heads`` query heads share ``num_kv_heads``
    key/value heads, each K/V head repeated ``num_heads // num_kv_heads`` times.

    Args:
        config: Model HaloConfig; reads ``num_heads``, ``num_kv_heads``,
            ``head_dim``, ``hidden_size`` and ``dropout``.
        use_flash: If True, use ``F.scaled_dot_product_attention`` as the
            attention backend. It is disabled automatically (manual path)
            when ``_use_sdpa()`` returns False.
    """

    def __init__(self, config: HaloConfig, use_flash: bool = True):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        assert self.num_heads % self.num_kv_heads == 0, (
            "num_heads debe ser divisible por num_kv_heads"
        )
        self.num_groups = self.num_heads // self.num_kv_heads

        # Linear projections: Q from the globals only, K/V from the full sequence.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.dropout)

        # Decided once at construction: SDPA only if requested AND available.
        self._use_flash = use_flash and _use_sdpa()

    def forward(
        self,
        globals_x: torch.Tensor,
        full_seq: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_causal: bool = True
    ) -> torch.Tensor:
        """
        Dense attention forward pass for the global tokens.

        Args:
            globals_x: (batch, num_globals, hidden) — the queries (globals only).
            full_seq: (batch, total_seq_len, hidden) — keys/values over the whole sequence.
            cos, sin: RoPE tables of shape (1, 1, max_seq_len, head_dim).
            is_causal: If True, global[i] only attends to positions [0, i].

        Returns:
            (batch, num_globals, hidden) — updated global representations.
        """
        B, G, _ = globals_x.shape
        _, N, _ = full_seq.shape

        # q: (B, num_heads, G, head_dim); k, v: (B, num_kv_heads, N, head_dim)
        q = self.q_proj(globals_x).view(B, G, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(full_seq).view(B, N, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(full_seq).view(B, N, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE: query i uses position i (the globals occupy positions 0..G-1),
        # key j uses position j. Q and K have different lengths, so the tables
        # are sliced separately.
        q = _apply_rope(q, cos[:, :, :G, :], sin[:, :, :G, :])
        k = _apply_rope(k, cos[:, :, :N, :], sin[:, :, :N, :])

        # GQA: repeat each K/V head num_groups times -> (B, num_heads, N, head_dim).
        if self.num_groups > 1:
            k = k.repeat_interleave(self.num_groups, dim=1)
            v = v.repeat_interleave(self.num_groups, dim=1)

        # causal_mask[i, j] is True when key j lies in the future of global i
        # (j > i). In this module's convention True means "blocked".
        causal_mask = None
        if is_causal:
            positions_q = torch.arange(G, device=q.device).unsqueeze(1)  # (G, 1)
            positions_k = torch.arange(N, device=q.device).unsqueeze(0)  # (1, N)
            causal_mask = positions_k > positions_q  # (G, N)

        if self._use_flash:
            out = self._forward_sdpa(q, k, v, causal_mask)
        else:
            out = self._forward_manual(q, k, v, causal_mask)

        # (B, num_heads, G, head_dim) -> (B, G, num_heads * head_dim)
        out = out.transpose(1, 2).contiguous().view(B, G, -1)
        return self.o_proj(out)

    def _forward_sdpa(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Attention through ``F.scaled_dot_product_attention`` (SDPA).

        An explicit mask is passed instead of SDPA's ``is_causal`` flag so that
        the query/key alignment (global i sees keys 0..i) is stated exactly.

        Args:
            q: (B, num_heads, G, head_dim).
            k, v: (B, num_heads, N, head_dim), already GQA-expanded.
            causal_mask: (G, N) bool mask, True = blocked, or None.

        Returns:
            (B, num_heads, G, head_dim).
        """
        # SDPA's boolean attn_mask convention is the OPPOSITE of causal_mask:
        # True marks positions that TAKE PART in attention. Hence the inversion.
        # Shape (1, 1, G, N) broadcasts over (B, num_heads, G, N).
        sdpa_mask = None
        if causal_mask is not None:
            sdpa_mask = (~causal_mask).unsqueeze(0).unsqueeze(0)

        dropout_p = self.dropout.p if self.training else 0.0

        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=sdpa_mask,
            dropout_p=dropout_p,
            is_causal=False,  # the explicit mask above already encodes causality
        )

    def _forward_manual(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Manual fallback: scores, mask, softmax, dropout, weighted sum.

        Used when SDPA is unavailable or ``use_flash=False``.

        Args:
            q: (B, num_heads, G, head_dim).
            k, v: (B, num_heads, N, head_dim), already GQA-expanded.
            causal_mask: (G, N) bool mask, True = blocked, or None.

        Returns:
            (B, num_heads, G, head_dim).
        """
        # Scaled dot-product scores: (B, num_heads, G, N)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if causal_mask is not None:
            # (G, N) -> (1, 1, G, N); blocked positions get -inf before softmax.
            scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (B, num_heads, G, N) @ (B, num_heads, N, head_dim) -> (B, num_heads, G, head_dim)
        return torch.matmul(attn_weights, v)
@@ -0,0 +1,111 @@
1
+ """
2
+ HALO-S Phase 0: Graph Prototype
3
+ Módulo para la generación de la conectividad dispersa (sparse) del modelo HALO-S.
4
+ Utiliza Neighbor Lists (listas de adyacencia) para evitar la instanciación de matrices densas NxN.
5
+ """
6
+
7
from typing import Sequence

import torch
8
+
9
def generate_neighbor_lists(
    seq_len: int,
    local_window: int = 64,
    num_globals: int = 2,
    dilated_offsets: Sequence[int] = (1, 2, 4, 8),
    num_random: int = 2,
    layer_id: int = 0
) -> torch.Tensor:
    """
    Build a fixed-size neighbor list (adjacency list) for every position.

    Row ``i`` contains, in this order:
      1. the global token indices ``0 .. num_globals - 1``;
      2. a local window of exactly ``local_window`` positions, from
         ``i - local_window // 2`` to ``i + (local_window - 1) // 2``;
      3. ``i + d`` then ``i - d`` for every ``d`` in ``dilated_offsets``;
      4. ``num_random`` pseudo-random positions in ``[num_globals, seq_len)``,
         derived deterministically from ``i``, the slot index and ``layer_id``.
    Any index outside ``[0, seq_len)`` is replaced by ``i`` itself, so a row
    can contain repeated indices.

    Args:
        seq_len: Total sequence length (globals included).
        local_window: Total width of the local window.
        num_globals: Number of global tokens at the start of the sequence.
        dilated_offsets: Offsets for the dilated connections. A tuple default
            is used so the default is never shared mutable state.
        num_random: Number of pseudo-random connections per token.
        layer_id: Layer identifier mixed into the pseudo-random hash.

    Returns:
        int64 CPU tensor of shape ``(seq_len, num_neighbors)`` where
        ``num_neighbors = num_globals + local_window + 2 * len(dilated_offsets) + num_random``.

    Raises:
        ValueError: if ``seq_len <= num_globals``, or if ``local_window`` or
            ``num_random`` is negative.
    """
    if seq_len <= num_globals:
        raise ValueError("Sequence length must be greater than num_globals")
    if local_window < 0:
        raise ValueError(f"local_window must be >= 0, got {local_window}")
    if num_random < 0:
        raise ValueError(f"num_random must be >= 0, got {num_random}")

    local_half = local_window // 2

    device = torch.device('cpu')  # indices are built on CPU; callers move them
    positions = torch.arange(seq_len, device=device)

    # 1. Global tokens: every row links to positions 0 .. num_globals - 1.
    global_idx = torch.arange(num_globals, device=device).unsqueeze(0).expand(seq_len, -1)

    # 2. Local window: offsets [-local_half, local_half) for even widths and
    #    [-local_half, local_half] for odd widths -> exactly `local_window` columns.
    local_offsets = torch.arange(-local_half, local_half + (local_window % 2), device=device).unsqueeze(0)
    local_idx = positions.unsqueeze(1) + local_offsets

    # 3. Dilated connections: +offset then -offset for each offset.
    dilated_list = []
    for offset in dilated_offsets:
        dilated_list.append(positions + offset)
        dilated_list.append(positions - offset)

    if dilated_list:
        dilated_idx = torch.stack(dilated_list, dim=1)
    else:
        dilated_idx = torch.empty((seq_len, 0), dtype=torch.long, device=device)

    # 4. Pseudo-random connections: a vectorized integer hash of
    #    (position, slot, layer_id), truncated to 32 bits and mapped into the
    #    non-global range [num_globals, seq_len). Deterministic per layer.
    pos_expanded = positions.unsqueeze(1).expand(-1, num_random)
    rand_offsets = torch.arange(num_random, device=device).unsqueeze(0).expand(seq_len, -1)

    hash_val = (pos_expanded * 2654435761 + layer_id * 805459861 + rand_offsets * 3266489917)
    hash_val = hash_val.to(torch.int64) & 0xFFFFFFFF

    valid_range = max(1, seq_len - num_globals)
    random_idx = num_globals + (hash_val % valid_range)

    all_idx = torch.cat([global_idx, local_idx, dilated_idx, random_idx], dim=1)

    # Out-of-range slots (near the sequence edges) point back to the token
    # itself instead of being dropped, so every row keeps the same width.
    # Side effect: the token's own key gets extra weight in the softmax.
    out_of_bounds = (all_idx < 0) | (all_idx >= seq_len)
    return torch.where(out_of_bounds, positions.unsqueeze(1), all_idx)
84
+
85
def estimate_graph_stats(neighbor_lists: torch.Tensor, seq_len: int) -> dict:
    """
    Summarize a neighbor list and compare it with a dense attention mask.

    Args:
        neighbor_lists: 2-D index tensor of shape (rows, num_neighbors), e.g.
            the output of ``generate_neighbor_lists``.
        seq_len: Sequence length used for the dense ``seq_len x seq_len``
            comparison (one byte per entry, as for a bool mask).

    Returns:
        dict with keys:
            "seq_len", "num_neighbors",
            "sparse_memory_mb": size of ``neighbor_lists`` in MiB,
            "dense_memory_mb": size of a dense bool mask in MiB,
            "compression_ratio": dense bytes / sparse bytes (``inf`` when
                the neighbor list is empty),
            "total_connections": ``seq_len * num_neighbors``,
            "dense_connections": ``seq_len * seq_len``.
    """
    num_neighbors = neighbor_lists.shape[1]

    # Use the tensor's real element size rather than assuming int64 (8 bytes),
    # so int32 / int16 index tensors are measured correctly.
    sparse_memory_bytes = neighbor_lists.numel() * neighbor_lists.element_size()

    # Theoretical dense bool mask: 1 byte per (query, key) pair.
    dense_memory_bytes = seq_len * seq_len * 1

    compression_ratio = (
        dense_memory_bytes / sparse_memory_bytes if sparse_memory_bytes > 0 else float('inf')
    )

    bytes_per_mib = 1024 * 1024
    return {
        "seq_len": seq_len,
        "num_neighbors": num_neighbors,
        "sparse_memory_mb": sparse_memory_bytes / bytes_per_mib,
        "dense_memory_mb": dense_memory_bytes / bytes_per_mib,
        "compression_ratio": compression_ratio,
        "total_connections": seq_len * num_neighbors,
        "dense_connections": seq_len * seq_len
    }
@@ -0,0 +1,118 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional
5
+ from halo.core.config import HaloConfig
6
+ from halo.attention.graph import generate_neighbor_lists
7
+ from halo.nn.rope import apply_rotary_pos_emb
8
+
9
class HaloSparseAttention(nn.Module):
    """
    HALO-S sparse attention with a gather-based backend and GQA.

    Each query position attends only to the ``num_neighbors`` key/value
    positions listed for it by ``generate_neighbor_lists``. The score tensor
    is therefore (batch, num_heads, seq_len, 1, num_neighbors) instead of
    (batch, num_heads, seq_len, seq_len): O(N * num_neighbors) rather than O(N^2).

    Note: neighbor lists can contain repeated indices (overlapping connection
    types, and out-of-range slots redirected to the token itself). Repeated
    keys are not deduplicated, so they receive proportionally more softmax weight.
    """
    def __init__(self, config: HaloConfig, layer_id: int):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim

        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.num_groups = self.num_heads // self.num_kv_heads

        # Linear projections (K/V use fewer heads for GQA).
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.dropout)

        # Neighbor-list cache, reused while seq_len and device stay the same.
        # Stored as a plain attribute (not a buffer), so it is not part of state_dict.
        self._cached_neighbors = None
        self._cached_seq_len = -1

    def _get_neighbors(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return the (seq_len, num_neighbors) neighbor index tensor on ``device``.

        The result is cached per (seq_len, device). The device is part of the
        key because after the module is moved (e.g. ``model.to('cuda')``) a
        tensor cached on the old device cannot be used to index the new K/V.
        """
        cached = self._cached_neighbors
        if (
            cached is not None
            and self._cached_seq_len == seq_len
            and cached.device == torch.device(device)
        ):
            return cached

        neighbors = generate_neighbor_lists(
            seq_len=seq_len,
            local_window=self.config.local_window,
            num_globals=self.config.num_globals,
            dilated_offsets=self.config.dilated_offsets,
            num_random=self.config.num_random,
            layer_id=self.layer_id
        ).to(device)

        self._cached_neighbors = neighbors
        self._cached_seq_len = seq_len
        return neighbors

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_causal: bool = True
    ) -> torch.Tensor:
        """Sparse attention of every token over its neighbor list.

        Args:
            x: (batch, seq_len, hidden_size) input.
            cos, sin: RoPE tables forwarded unchanged to ``apply_rotary_pos_emb``
                (expected layout defined in halo.nn.rope — TODO confirm).
            is_causal: If True, neighbors whose index is greater than the
                query position are masked out.

        Returns:
            (batch, seq_len, hidden_size) tensor.
        """
        batch_size, seq_len, _ = x.shape

        # Project to (batch, heads, seq_len, head_dim).
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Sparse index matrix: (seq_len, num_neighbors)
        neighbors = self._get_neighbors(seq_len, x.device)

        # Gather each query's neighbor keys/values via advanced indexing:
        # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads, seq_len, num_neighbors, head_dim)
        k_gathered = k[:, :, neighbors, :]
        v_gathered = v[:, :, neighbors, :]

        # GQA: repeat each KV head num_groups times to match num_heads.
        if self.num_groups > 1:
            k_gathered = k_gathered.repeat_interleave(self.num_groups, dim=1)
            v_gathered = v_gathered.repeat_interleave(self.num_groups, dim=1)

        # One query row per position: (batch, num_heads, seq_len, 1, head_dim)
        q_expanded = q.unsqueeze(3)

        # (..., 1, head_dim) @ (..., head_dim, num_neighbors) -> (..., 1, num_neighbors)
        scores = torch.matmul(q_expanded, k_gathered.transpose(-2, -1))
        scores = scores / math.sqrt(self.head_dim)

        if is_causal:
            # Block neighbors in the future of the query position.
            # Assumes each row keeps at least one non-future index (the local
            # window includes offset 0 when local_window >= 1); a fully masked
            # row would turn into NaN after softmax.
            positions = torch.arange(seq_len, device=x.device).unsqueeze(1)  # (seq_len, 1)
            causal_mask = neighbors > positions  # True = future
            causal_mask = causal_mask.view(1, 1, seq_len, 1, -1)
            scores.masked_fill_(causal_mask, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # (..., 1, num_neighbors) @ (..., num_neighbors, head_dim) -> (..., 1, head_dim)
        out = torch.matmul(attn, v_gathered)
        out = out.squeeze(3)  # (batch, num_heads, seq_len, head_dim)

        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.o_proj(out)
halo/core/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ HALO-S Core — Configuración y utilidades fundamentales del framework.
3
+ """
4
+
5
+ from halo.core.config import HaloConfig
6
+
7
# Names exported by `from halo.core import *`.
__all__ = [
    "HaloConfig",
]
halo/core/config.py ADDED
@@ -0,0 +1,72 @@
1
+ from dataclasses import dataclass, field, asdict
2
+ from typing import List
3
+
4
class HaloConfigError(ValueError, AssertionError):
    """Raised when HaloConfig receives an invalid combination of parameters.

    Subclasses ValueError (the conventional type for bad arguments) and
    AssertionError (the type raised by the former ``assert``-based checks),
    so both ``except ValueError`` and ``except AssertionError`` handlers catch it.
    """


@dataclass
class HaloConfig:
    """Main configuration for the HALO-S model."""
    vocab_size: int = 256  # Default for CharTokenizer
    hidden_size: int = 512
    num_layers: int = 6
    num_heads: int = 8
    num_kv_heads: int = 2  # Fewer KV heads than query heads: Grouped Query Attention (GQA)

    # HALO-S sparse-graph parameters
    num_globals: int = 2
    local_window: int = 64
    dilated_offsets: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
    num_random: int = 2

    # Regularization and inference hyperparameters
    dropout: float = 0.1
    max_seq_len: int = 4096

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise HaloConfigError(message) unless ``condition`` holds."""
        if not condition:
            raise HaloConfigError(message)

    def __post_init__(self):
        """Validate the parameters right after construction.

        Uses explicit raises instead of ``assert`` so validation still runs
        under ``python -O``.

        Raises:
            HaloConfigError: (subclass of ValueError and AssertionError) when a
                parameter or combination of parameters is invalid.
        """
        # Checked first so the divisibility checks below never divide by zero.
        self._require(
            self.num_heads >= 1 and self.num_kv_heads >= 1,
            f"num_heads ({self.num_heads}) y num_kv_heads ({self.num_kv_heads}) "
            f"deben ser >= 1"
        )
        self._require(
            self.hidden_size % self.num_heads == 0,
            f"hidden_size ({self.hidden_size}) debe ser divisible por "
            f"num_heads ({self.num_heads})"
        )
        self._require(
            self.num_heads % self.num_kv_heads == 0,
            f"num_heads ({self.num_heads}) debe ser divisible por "
            f"num_kv_heads ({self.num_kv_heads})"
        )
        self._require(
            self.num_globals >= 1,
            f"Se requiere al menos 1 global token, se recibió num_globals={self.num_globals}"
        )
        self._require(
            self.local_window > 0,
            f"local_window debe ser > 0, se recibió {self.local_window}"
        )
        self._require(
            self.max_seq_len > self.num_globals,
            f"max_seq_len ({self.max_seq_len}) debe ser mayor que "
            f"num_globals ({self.num_globals})"
        )
        self._require(
            0 <= self.dropout < 1,
            f"dropout debe estar en [0, 1), se recibió {self.dropout}"
        )
        self._require(
            self.num_random >= 0,
            f"num_random debe ser >= 0, se recibió {self.num_random}"
        )

    @property
    def head_dim(self) -> int:
        """Per-head attention dimension (hidden_size // num_heads)."""
        return self.hidden_size // self.num_heads

    @property
    def num_neighbors(self) -> int:
        """Total number of neighbors per token in the sparse graph.

        Counts globals + local window + dilated connections (both directions)
        + random connections.
        """
        return (
            self.num_globals
            + self.local_window
            + 2 * len(self.dilated_offsets)
            + self.num_random
        )

    def to_dict(self) -> dict:
        """Serialize every field of the configuration into a dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "HaloConfig":
        """Rebuild a HaloConfig from a dict of field values (validated again)."""
        return cls(**d)