pyhalos 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halo/__init__.py +61 -0
- halo/attention/__init__.py +15 -0
- halo/attention/global_attention.py +215 -0
- halo/attention/graph.py +111 -0
- halo/attention/halo_attention.py +118 -0
- halo/core/__init__.py +7 -0
- halo/core/config.py +72 -0
- halo/core/logging.py +115 -0
- halo/datasets/__init__.py +8 -0
- halo/datasets/jsonl.py +215 -0
- halo/datasets/streaming.py +132 -0
- halo/datasets/synthetic.py +55 -0
- halo/datasets/text.py +29 -0
- halo/generation/__init__.py +7 -0
- halo/generation/samplers.py +44 -0
- halo/models/__init__.py +8 -0
- halo/models/baseline_model.py +94 -0
- halo/models/halo_model.py +198 -0
- halo/nn/__init__.py +14 -0
- halo/nn/feed_forward.py +17 -0
- halo/nn/halo_block.py +50 -0
- halo/nn/rope.py +48 -0
- halo/tokenizers/__init__.py +15 -0
- halo/tokenizers/base.py +12 -0
- halo/tokenizers/char.py +26 -0
- halo/tokenizers/sentencepiece.py +119 -0
- halo/tokenizers/word.py +164 -0
- halo/training/__init__.py +7 -0
- halo/training/trainer.py +318 -0
- halo/utils/__init__.py +12 -0
- halo/utils/benchmarks.py +305 -0
- halo/utils/metrics.py +24 -0
- halo/utils/random.py +14 -0
- pyhalos-1.0.3.dist-info/METADATA +980 -0
- pyhalos-1.0.3.dist-info/RECORD +38 -0
- pyhalos-1.0.3.dist-info/WHEEL +5 -0
- pyhalos-1.0.3.dist-info/licenses/LICENSE +46 -0
- pyhalos-1.0.3.dist-info/top_level.txt +1 -0
halo/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HALO-S: Framework de Atención Dispersa para Modelos de Lenguaje
|
|
3
|
+
================================================================
|
|
4
|
+
|
|
5
|
+
Arquitectura transformer con complejidad O(N×K) basada en neighbor lists,
|
|
6
|
+
global tokens, conexiones dilatadas y Grouped Query Attention (GQA).
|
|
7
|
+
|
|
8
|
+
Uso básico:
|
|
9
|
+
>>> from halo import HaloConfig, HaloSModel, set_seed
|
|
10
|
+
>>> set_seed(42)
|
|
11
|
+
>>> config = HaloConfig(vocab_size=256, hidden_size=512)
|
|
12
|
+
>>> model = HaloSModel(config)
|
|
13
|
+
|
|
14
|
+
Autor: BUEORM (dalusx64@gmail.com)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
__version__ = "1.0.3"
|
|
18
|
+
|
|
19
|
+
# --- Core ---
|
|
20
|
+
from halo.core.config import HaloConfig
|
|
21
|
+
|
|
22
|
+
# --- Models ---
|
|
23
|
+
from halo.models.halo_model import HaloSModel
|
|
24
|
+
from halo.models.baseline_model import BaselineModel
|
|
25
|
+
|
|
26
|
+
# --- Training ---
|
|
27
|
+
from halo.training.trainer import Trainer
|
|
28
|
+
|
|
29
|
+
# --- Tokenizers ---
|
|
30
|
+
from halo.tokenizers.char import CharacterTokenizer
|
|
31
|
+
|
|
32
|
+
# WordTokenizer se importa condicionalmente porque se crea en una tarea posterior
|
|
33
|
+
try:
|
|
34
|
+
from halo.tokenizers.word import WordTokenizer
|
|
35
|
+
except ImportError:
|
|
36
|
+
WordTokenizer = None
|
|
37
|
+
|
|
38
|
+
# --- Generation ---
|
|
39
|
+
from halo.generation.samplers import generate
|
|
40
|
+
|
|
41
|
+
# --- Utils ---
|
|
42
|
+
from halo.utils.random import set_seed
|
|
43
|
+
from halo.utils.metrics import count_parameters
|
|
44
|
+
|
|
45
|
+
__all__ = [
|
|
46
|
+
# Core
|
|
47
|
+
"HaloConfig",
|
|
48
|
+
# Models
|
|
49
|
+
"HaloSModel",
|
|
50
|
+
"BaselineModel",
|
|
51
|
+
# Training
|
|
52
|
+
"Trainer",
|
|
53
|
+
# Tokenizers
|
|
54
|
+
"CharacterTokenizer",
|
|
55
|
+
"WordTokenizer",
|
|
56
|
+
# Generation
|
|
57
|
+
"generate",
|
|
58
|
+
# Utils
|
|
59
|
+
"set_seed",
|
|
60
|
+
"count_parameters",
|
|
61
|
+
]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HALO-S Attention — Mecanismos de atención dispersa y global.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from halo.attention.halo_attention import HaloSparseAttention
|
|
6
|
+
from halo.attention.global_attention import GlobalFullAttention, _use_sdpa
|
|
7
|
+
from halo.attention.graph import generate_neighbor_lists, estimate_graph_stats
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"HaloSparseAttention",
|
|
11
|
+
"GlobalFullAttention",
|
|
12
|
+
"_use_sdpa",
|
|
13
|
+
"generate_neighbor_lists",
|
|
14
|
+
"estimate_graph_stats",
|
|
15
|
+
]
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Atención Densa para Global Tokens — HALO-S.
|
|
3
|
+
|
|
4
|
+
Los Global Tokens son posiciones especiales (0..G-1) que atienden a TODA la
|
|
5
|
+
secuencia (global + regular tokens) con atención densa estándar.
|
|
6
|
+
Complejidad: O(G × N) donde G = num_globals, N = seq_len total.
|
|
7
|
+
|
|
8
|
+
A diferencia de la atención dispersa (gather-based), aquí se computan scores
|
|
9
|
+
contra todas las posiciones, permitiendo que los globals actúen como memoria
|
|
10
|
+
compartida accesible por cualquier token del grafo disperso.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import math
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
import torch.nn.functional as F
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
from halo.core.config import HaloConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
23
|
+
"""Rota la segunda mitad de las dimensiones para RoPE."""
|
|
24
|
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
|
25
|
+
return torch.cat((-x2, x1), dim=-1)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply Rotary Positional Embeddings (RoPE) to a single tensor.

    Args:
        x: Tensor to rotate; callers in this module pass
            (batch, heads, seq_len, head_dim).
        cos, sin: Rotary tables with a shape broadcastable against ``x``.

    Returns:
        ``x * cos + rotate_half(x) * sin``.
    """
    unrotated_part = x * cos
    rotated_part = _rotate_half(x) * sin
    return unrotated_part + rotated_part
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _use_sdpa() -> bool:
|
|
42
|
+
"""Detecta si torch.nn.functional.scaled_dot_product_attention está disponible.
|
|
43
|
+
|
|
44
|
+
Retorna True solo cuando SDPA existe en el entorno y hay backend
|
|
45
|
+
eficiente (CUDA). En CPU no se beneficia significativamente.
|
|
46
|
+
"""
|
|
47
|
+
if not hasattr(F, 'scaled_dot_product_attention'):
|
|
48
|
+
return False
|
|
49
|
+
# Solo se beneficia con backends Flash/Memory-efficient en CUDA
|
|
50
|
+
if torch.cuda.is_available():
|
|
51
|
+
return True
|
|
52
|
+
return False
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class GlobalFullAttention(nn.Module):
    """Dense attention for the global tokens, with GQA and optional SDPA.

    Queries are projected from the global tokens only; keys and values are
    projected from the full sequence (globals + regular tokens). Grouped
    Query Attention is used: ``num_heads`` query heads share
    ``num_kv_heads`` key/value heads.

    Args:
        config: Model configuration. ``num_heads``, ``num_kv_heads``,
            ``head_dim``, ``hidden_size`` and ``dropout`` are read from it.
        use_flash: When True, route attention through
            ``F.scaled_dot_product_attention`` if ``_use_sdpa()`` reports it
            as usable at construction time; otherwise the manual path runs.
    """

    def __init__(self, config: HaloConfig, use_flash: bool = True):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        assert self.num_heads % self.num_kv_heads == 0, (
            "num_heads debe ser divisible por num_kv_heads"
        )
        # Number of query heads that share each key/value head.
        self.num_groups = self.num_heads // self.num_kv_heads

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim

        # Linear projections (Q from globals, K/V from the whole sequence).
        # Creation order q, k, v, o is kept so parameter initialisation
        # consumes the RNG in the same order as before.
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.dropout)

        # SDPA is used only if requested AND reported usable by _use_sdpa().
        self._use_flash = use_flash and _use_sdpa()

    def forward(
        self,
        globals_x: torch.Tensor,
        full_seq: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_causal: bool = True
    ) -> torch.Tensor:
        """Run dense attention from the global tokens over the full sequence.

        Args:
            globals_x: (batch, num_globals, hidden) — source of the queries.
            full_seq: (batch, total_seq_len, hidden) — source of keys/values.
            cos, sin: 4-D RoPE tables indexed by position along dim 2; they
                are sliced to the first ``num_globals`` positions for Q and
                the first ``total_seq_len`` positions for K.
            is_causal: When True, global ``i`` may only attend to key
                positions ``0..i``.

        Returns:
            (batch, num_globals, hidden) — updated global representations.
        """
        batch, n_glob, _ = globals_x.shape
        _, n_keys, _ = full_seq.shape

        def split_heads(projected: torch.Tensor, length: int, heads: int) -> torch.Tensor:
            # (batch, length, heads * head_dim) -> (batch, heads, length, head_dim)
            return projected.view(batch, length, heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(globals_x), n_glob, self.num_heads)
        k = split_heads(self.k_proj(full_seq), n_keys, self.num_kv_heads)
        v = split_heads(self.v_proj(full_seq), n_keys, self.num_kv_heads)

        # RoPE is applied separately because Q and K have different lengths:
        # Q uses positions [0, n_glob), K uses positions [0, n_keys).
        q = _apply_rope(q, cos[:, :, :n_glob, :], sin[:, :, :n_glob, :])
        k = _apply_rope(k, cos[:, :, :n_keys, :], sin[:, :, :n_keys, :])

        # GQA: replicate every K/V head for the query heads that share it.
        if self.num_groups > 1:
            k = k.repeat_interleave(self.num_groups, dim=1)
            v = v.repeat_interleave(self.num_groups, dim=1)

        # Boolean (n_glob, n_keys) mask, True where the key lies in the
        # future of the query and must be blocked.
        attn_mask = None
        if is_causal:
            query_pos = torch.arange(n_glob, device=q.device)[:, None]
            key_pos = torch.arange(n_keys, device=q.device)[None, :]
            attn_mask = key_pos > query_pos

        attend = self._forward_sdpa if self._use_flash else self._forward_manual
        out = attend(q, k, v, attn_mask)

        # (batch, num_heads, n_glob, head_dim) -> (batch, n_glob, hidden)
        merged = out.transpose(1, 2).contiguous().view(batch, n_glob, -1)
        return self.o_proj(merged)

    def _forward_sdpa(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Accelerated path built on ``F.scaled_dot_product_attention``.

        The rectangular (num_globals x seq_len) mask is not the standard
        square causal triangle, so it is passed explicitly as an additive
        float mask (0 where attention is allowed, -inf where it is blocked)
        and ``is_causal`` is left False.

        Args:
            q: (batch, num_heads, num_globals, head_dim).
            k, v: (batch, num_heads, seq_len, head_dim).
            causal_mask: Optional bool (num_globals, seq_len) mask, True
                where attention must be blocked.
        """
        sdpa_mask = None
        if causal_mask is not None:
            sdpa_mask = torch.zeros_like(causal_mask, dtype=q.dtype)
            sdpa_mask.masked_fill_(causal_mask, float('-inf'))
            # Add leading batch/head axes so it broadcasts: (1, 1, G, N).
            sdpa_mask = sdpa_mask[None, None]

        # Attention dropout is only active in training mode.
        drop_prob = self.dropout.p if self.training else 0.0

        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=sdpa_mask,
            dropout_p=drop_prob,
            is_causal=False  # the custom mask above is used instead
        )

    def _forward_manual(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        causal_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Fallback path: matmul scores, mask, softmax, dropout, weighted sum.

        Used when SDPA is unavailable or ``use_flash=False``. Arguments have
        the same meaning as in ``_forward_sdpa``.
        """
        # Scaled scores: (batch, num_heads, num_globals, seq_len).
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        if causal_mask is not None:
            # (G, N) -> (1, 1, G, N) so it broadcasts over batch and heads.
            scores.masked_fill_(causal_mask[None, None], float('-inf'))

        # Normalise over the key dimension, then apply attention dropout.
        weights = self.dropout(torch.softmax(scores, dim=-1))

        # (B, H, G, N) @ (B, H, N, head_dim) -> (B, H, G, head_dim)
        return torch.matmul(weights, v)
|
halo/attention/graph.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HALO-S Phase 0: Graph Prototype
|
|
3
|
+
Módulo para la generación de la conectividad dispersa (sparse) del modelo HALO-S.
|
|
4
|
+
Utiliza Neighbor Lists (listas de adyacencia) para evitar la instanciación de matrices densas NxN.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
def generate_neighbor_lists(
    seq_len: int,
    local_window: int = 64,
    num_globals: int = 2,
    dilated_offsets: "list[int] | None" = None,
    num_random: int = 2,
    layer_id: int = 0
) -> torch.Tensor:
    """Build a fixed-size neighbor list for every position in the sequence.

    Columns of the result are laid out, in order, as:
    global tokens, local window, dilated connections (``+offset`` then
    ``-offset`` for each offset), pseudo-random connections.

    Args:
        seq_len: Total sequence length (including the global tokens).
        local_window: Size of the local window; offsets run from
            ``-(local_window // 2)`` up to ``local_window // 2 - 1`` (one
            extra forward offset when ``local_window`` is odd).
        num_globals: Number of global tokens at the start of the sequence.
        dilated_offsets: Offsets for the dilated connections. ``None``
            (the default) means ``[1, 2, 4, 8]``.
        num_random: Number of deterministic pseudo-random connections.
        layer_id: Layer identifier mixed into the pseudo-random hash.

    Returns:
        int64 CPU tensor of shape ``(seq_len, num_neighbors)`` where
        ``num_neighbors = num_globals + local_window
        + 2 * len(dilated_offsets) + num_random``.

    Raises:
        ValueError: If ``seq_len <= num_globals``.
    """
    if seq_len <= num_globals:
        raise ValueError("Sequence length must be greater than num_globals")

    # A None sentinel replaces the former mutable list default, which was
    # shared across calls; the effective default is unchanged.
    if dilated_offsets is None:
        dilated_offsets = [1, 2, 4, 8]

    local_half = local_window // 2

    # Indices are built on CPU; callers move the result where they need it.
    device = torch.device('cpu')
    positions = torch.arange(seq_len, device=device)

    # 1. Global tokens: every row lists indices 0 .. num_globals - 1.
    global_idx = torch.arange(num_globals, device=device).unsqueeze(0).expand(seq_len, -1)

    # 2. Local window: exactly `local_window` offsets around each position.
    local_offsets = torch.arange(
        -local_half, local_half + (local_window % 2), device=device
    ).unsqueeze(0)
    local_idx = positions.unsqueeze(1) + local_offsets

    # 3. Dilated connections: (+offset, -offset) pairs, in offset order.
    dilated_list = []
    for offset in dilated_offsets:
        dilated_list.append(positions + offset)
        dilated_list.append(positions - offset)

    if dilated_list:
        dilated_idx = torch.stack(dilated_list, dim=1)
    else:
        dilated_idx = torch.empty((seq_len, 0), dtype=torch.long, device=device)

    # 4. Pseudo-random connections: a vectorised integer hash of
    #    (position, layer_id, slot), so the graph is reproducible.
    pos_expanded = positions.unsqueeze(1).expand(-1, num_random)
    rand_offsets = torch.arange(num_random, device=device).unsqueeze(0).expand(seq_len, -1)

    hash_val = (pos_expanded * 2654435761 + layer_id * 805459861 + rand_offsets * 3266489917)
    # Keep the low 32 bits so the value is non-negative before the modulo.
    hash_val = hash_val.to(torch.int64) & 0xFFFFFFFF

    # Map the hash onto the non-global part of the sequence.
    valid_range = max(1, seq_len - num_globals)
    random_idx = num_globals + (hash_val % valid_range)

    all_idx = torch.cat([global_idx, local_idx, dilated_idx, random_idx], dim=1)

    # Out-of-range indices are redirected to the token itself. This keeps
    # the list a fixed width; the duplicate only adds weight to the token's
    # own entry in attention.
    out_of_bounds = (all_idx < 0) | (all_idx >= seq_len)
    all_idx = torch.where(out_of_bounds, positions.unsqueeze(1), all_idx)

    return all_idx
|
|
84
|
+
|
|
85
|
+
def estimate_graph_stats(neighbor_lists: torch.Tensor, seq_len: int) -> dict:
    """Compute size statistics of a neighbor-list graph versus a dense mask.

    Args:
        neighbor_lists: Index tensor of shape (rows, num_neighbors).
        seq_len: Sequence length used for the dense comparison and for the
            total connection count.

    Returns:
        Dict with keys ``seq_len``, ``num_neighbors``, ``sparse_memory_mb``,
        ``dense_memory_mb``, ``compression_ratio`` (dense bytes / sparse
        bytes, ``inf`` when the sparse tensor is empty),
        ``total_connections`` and ``dense_connections``.
    """
    num_neighbors = neighbor_lists.shape[1]

    # Sparse footprint uses the tensor's real element width (8 bytes for
    # int64 lists) instead of assuming int64.
    sparse_memory_bytes = neighbor_lists.numel() * neighbor_lists.element_size()

    # A dense boolean mask needs one byte per (query, key) pair.
    dense_connections = seq_len * seq_len
    dense_memory_bytes = dense_connections * 1

    if sparse_memory_bytes > 0:
        compression_ratio = dense_memory_bytes / sparse_memory_bytes
    else:
        compression_ratio = float('inf')

    bytes_per_mb = 1024 * 1024
    return {
        "seq_len": seq_len,
        "num_neighbors": num_neighbors,
        "sparse_memory_mb": sparse_memory_bytes / bytes_per_mb,
        "dense_memory_mb": dense_memory_bytes / bytes_per_mb,
        "compression_ratio": compression_ratio,
        "total_connections": seq_len * num_neighbors,
        "dense_connections": dense_connections
    }
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from halo.core.config import HaloConfig
|
|
6
|
+
from halo.attention.graph import generate_neighbor_lists
|
|
7
|
+
from halo.nn.rope import apply_rotary_pos_emb
|
|
8
|
+
|
|
9
|
+
class HaloSparseAttention(nn.Module):
    """Sparse HALO-S attention with a gather-based backend and GQA.

    Every query position attends only to the fixed-size neighbor list
    produced by ``generate_neighbor_lists``, so the score tensor is
    (batch, heads, seq_len, 1, num_neighbors) instead of seq_len x seq_len.

    Args:
        config: Model configuration (head counts, sizes, graph parameters,
            dropout).
        layer_id: Layer index, forwarded to the neighbor-list generator so
            each layer gets its own pseudo-random connections.
    """

    def __init__(self, config: HaloConfig, layer_id: int):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim

        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        # Number of query heads that share each key/value head.
        self.num_groups = self.num_heads // self.num_kv_heads

        # Linear projections.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(config.dropout)

        # Neighbor-list cache, reused while seq_len stays the same.
        self._cached_neighbors = None
        self._cached_seq_len = -1

    def _get_neighbors(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return the (seq_len, num_neighbors) index tensor on ``device``.

        The list is regenerated only when ``seq_len`` changes. The cached
        tensor is also moved when the requested device differs from the one
        it lives on: previously a module that ran once and was then moved
        to another device kept returning the stale tensor, which fails as
        soon as it is compared with positions on the new device.
        """
        neighbors = self._cached_neighbors
        if neighbors is None or self._cached_seq_len != seq_len:
            neighbors = generate_neighbor_lists(
                seq_len=seq_len,
                local_window=self.config.local_window,
                num_globals=self.config.num_globals,
                dilated_offsets=self.config.dilated_offsets,
                num_random=self.config.num_random,
                layer_id=self.layer_id
            )
            self._cached_seq_len = seq_len

        # `.to` returns the same tensor when it is already on `device`.
        neighbors = neighbors.to(device)
        self._cached_neighbors = neighbors
        return neighbors

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_causal: bool = True
    ) -> torch.Tensor:
        """Run sparse attention over ``x``.

        Args:
            x: (batch, seq_len, hidden) input.
            cos, sin: RoPE tables passed through to ``apply_rotary_pos_emb``.
            is_causal: When True, neighbors at positions greater than the
                query position are masked out.

        Returns:
            (batch, seq_len, hidden) attention output after ``o_proj``.
        """
        batch_size, seq_len, _ = x.shape

        # Project and split heads: (batch, heads, seq_len, head_dim).
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE.
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Sparse index matrix, shape (seq_len, num_neighbors).
        neighbors = self._get_neighbors(seq_len, x.device)

        # Gather neighbor keys/values with advanced indexing:
        # (batch, num_kv_heads, seq_len, head_dim)
        #   -> (batch, num_kv_heads, seq_len, num_neighbors, head_dim)
        k_gathered = k[:, :, neighbors, :]
        v_gathered = v[:, :, neighbors, :]

        # GQA: replicate K/V heads so they line up with num_heads.
        if self.num_groups > 1:
            k_gathered = k_gathered.repeat_interleave(self.num_groups, dim=1)
            v_gathered = v_gathered.repeat_interleave(self.num_groups, dim=1)

        # (batch, num_heads, seq_len, 1, head_dim)
        q_expanded = q.unsqueeze(3)

        # (..., 1, head_dim) x (..., head_dim, num_neighbors) -> (..., 1, num_neighbors)
        scores = torch.matmul(q_expanded, k_gathered.transpose(-2, -1))
        scores = scores / math.sqrt(self.head_dim)

        # Causal (autoregressive) mask: no token may attend to a neighbor
        # whose position is greater than its own.
        if is_causal:
            positions = torch.arange(seq_len, device=x.device).unsqueeze(1)  # (seq_len, 1)
            causal_mask = neighbors > positions  # True where the neighbor is in the future
            causal_mask = causal_mask.view(1, 1, seq_len, 1, -1)
            scores.masked_fill_(causal_mask, float('-inf'))

        # Softmax over the neighbor axis, then attention dropout.
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # (..., 1, num_neighbors) x (..., num_neighbors, head_dim) -> (..., 1, head_dim)
        out = torch.matmul(attn, v_gathered)
        out = out.squeeze(3)  # (batch, num_heads, seq_len, head_dim)

        # Merge heads back: (batch, seq_len, num_heads * head_dim).
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.o_proj(out)
|
halo/core/__init__.py
ADDED
halo/core/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from dataclasses import dataclass, field, asdict
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
@dataclass
class HaloConfig:
    """Main configuration for the HALO-S model.

    Field values are validated in ``__post_init__``; an invalid combination
    raises ``AssertionError``.
    """
    vocab_size: int = 256  # default sized for CharTokenizer
    hidden_size: int = 512
    num_layers: int = 6
    num_heads: int = 8
    num_kv_heads: int = 2  # Grouped Query Attention (GQA)

    # Sparse-graph parameters
    num_globals: int = 2
    local_window: int = 64
    dilated_offsets: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
    num_random: int = 2

    # Regularisation and inference hyperparameters
    dropout: float = 0.1
    max_seq_len: int = 4096

    def __post_init__(self):
        """Validate the parameters when the instance is created.

        The checks raise ``AssertionError`` explicitly rather than using
        ``assert`` statements, so they still run under ``python -O`` while
        keeping the exception type and messages callers already see.

        Raises:
            AssertionError: If any constraint below is violated.
        """
        if not (self.hidden_size % self.num_heads == 0):
            raise AssertionError(
                f"hidden_size ({self.hidden_size}) debe ser divisible por "
                f"num_heads ({self.num_heads})"
            )
        if not (self.num_heads % self.num_kv_heads == 0):
            raise AssertionError(
                f"num_heads ({self.num_heads}) debe ser divisible por "
                f"num_kv_heads ({self.num_kv_heads})"
            )
        if not (self.num_globals >= 1):
            raise AssertionError(
                f"Se requiere al menos 1 global token, se recibió num_globals={self.num_globals}"
            )
        if not (self.local_window > 0):
            raise AssertionError(
                f"local_window debe ser > 0, se recibió {self.local_window}"
            )
        if not (self.max_seq_len > self.num_globals):
            raise AssertionError(
                f"max_seq_len ({self.max_seq_len}) debe ser mayor que "
                f"num_globals ({self.num_globals})"
            )
        if not (0 <= self.dropout < 1):
            raise AssertionError(
                f"dropout debe estar en [0, 1), se recibió {self.dropout}"
            )

    @property
    def head_dim(self) -> int:
        """Dimension of each attention head (``hidden_size // num_heads``)."""
        return self.hidden_size // self.num_heads

    @property
    def num_neighbors(self) -> int:
        """Total number of neighbors per token in the sparse graph.

        Sum of: globals + local window + dilated connections (forward and
        backward for each offset) + random connections.
        """
        return (
            self.num_globals
            + self.local_window
            + 2 * len(self.dilated_offsets)
            + self.num_random
        )

    def to_dict(self) -> dict:
        """Serialise the configuration to a dict containing every field."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "HaloConfig":
        """Rebuild a ``HaloConfig`` from a dict of field values."""
        return cls(**d)
|