rxnn 0.1.83__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/.DS_Store +0 -0
- rxnn/experimental/attention.py +5 -0
- rxnn/memory/attention.py +42 -0
- rxnn/memory/stm.py +53 -12
- rxnn/rxt/models.py +71 -0
- rxnn/training/bml.py +2 -59
- rxnn/training/callbacks.py +302 -39
- rxnn/training/dataset.py +344 -1
- rxnn/training/models.py +142 -0
- rxnn/training/mrl.py +808 -0
- rxnn/training/reward.py +111 -0
- rxnn/training/rl.py +69 -0
- rxnn/training/utils.py +148 -0
- rxnn/transformers/attention.py +10 -0
- rxnn/transformers/layers.py +6 -0
- rxnn/transformers/models.py +16 -4
- rxnn/transformers/positional.py +7 -0
- rxnn/transformers/sampler.py +283 -9
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/METADATA +11 -9
- rxnn-0.2.0.dist-info/RECORD +38 -0
- rxnn-0.1.83.dist-info/RECORD +0 -31
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/LICENSE +0 -0
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/WHEEL +0 -0
rxnn/training/reward.py
ADDED
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from enum import Enum
+from typing import Optional
+from .utils import TokenizedDict
+
+
+class MrlRewardMode(Enum):
+    STANDARD = 1
+    NEGATIVE = 2
+    LONG_RANGE = 3
+
+class MrlRewardModel:
+    def __init__(
+        self,
+        shared_embedding: nn.Embedding,
+        device: torch.device,
+        bleu_with_saved_data: bool = False,
+        bleu_factor: float = 0.5,
+        cos_factor: float = 0.5,
+        cos_ref_factor: float = 0.5,
+        cos_saved_factor: float = 0.5,
+        neg_bleu_factor: Optional[float] = None,
+        neg_cos_factor: Optional[float] = None,
+        neg_cos_ref_factor: Optional[float] = None,
+        neg_cos_saved_factor: Optional[float] = None,
+        neg_bleu_ref_factor: float = 0.5,
+        neg_bleu_saved_factor: float = 0.5,
+        allow_not_summing_factors: bool = False,
+    ):
+        self.shared_embedding = shared_embedding.to(device)
+        self.device = device
+        self.bleu_with_saved_data = bleu_with_saved_data
+
+        if not allow_not_summing_factors:
+            assert bleu_factor + cos_factor == 1.0
+            assert cos_ref_factor + cos_saved_factor == 1.0
+            assert neg_bleu_factor + neg_cos_factor == 1.0
+            assert neg_cos_ref_factor + neg_cos_saved_factor == 1.0
+            assert neg_bleu_ref_factor + neg_bleu_saved_factor == 1.0
+
+        self.bleu_factor = bleu_factor
+        self.cos_factor = cos_factor
+        self.cos_ref_factor = cos_ref_factor
+        self.cos_saved_factor = cos_saved_factor
+        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
+        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
+        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
+        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
+        self.neg_bleu_ref_factor = neg_bleu_ref_factor
+        self.neg_bleu_saved_factor = neg_bleu_saved_factor
+
+    def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
+        from nltk.translate.bleu_score import sentence_bleu
+        refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
+        return sentence_bleu(refs, generated, weights=(0.25, 0.25, 0.25, 0.25))
+
+    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
+        from nltk.translate.bleu_score import sentence_bleu
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
+            saved_bleu = sentence_bleu([saved_data], generated, weights=(0.25, 0.25, 0.25))
+            saved_bleu = 1 - saved_bleu
+
+            return (self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu) / 2
+        else:
+            return sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
+
+    def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
+        batch_size = generated.size(0)
+        return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
+
+    def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
+        embedding = self.shared_embedding(sequence.to(self.device))
+        return embedding.mean(dim=1)
+
+    def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
+        generated_emb = self._sequence_embedding(generated)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data))
+        gen_and_ref = F.cosine_similarity(generated_emb, self._sequence_embedding(reference))
+        return gen_and_saved, gen_and_ref
+
+    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+        return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+
+    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
+
+    def __call__(
+        self,
+        generated: TokenizedDict,
+        reference: TokenizedDict,
+        saved_data: TokenizedDict,
+        mode: MrlRewardMode = MrlRewardMode.STANDARD
+    ) -> list[float]:
+        if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
+            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            return (self.bleu_factor * torch.tensor(bleu) + self.cos_factor * cosine).tolist()
+        else:
+            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            return (self.neg_bleu_factor * torch.tensor(bleu) + self.neg_cos_factor * cosine).tolist()
+
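MrlRewardModel combines a sentence-level BLEU score with embedding cosine similarity against both the reference answer and the previously saved interaction, and returns one scalar reward per sequence. Below is a minimal usage sketch, not part of the package diff: the vocabulary size, factor values and random token ids are illustrative assumptions, and NLTK must be installed for the BLEU part.

# Hypothetical usage sketch for MrlRewardModel (illustrative values, not from the rxnn sources)
import torch
import torch.nn as nn
from rxnn.training.reward import MrlRewardModel, MrlRewardMode

vocab_size, embed_dim = 32000, 256  # assumed model dimensions
reward_model = MrlRewardModel(
    shared_embedding=nn.Embedding(vocab_size, embed_dim),
    device=torch.device('cpu'),
    bleu_with_saved_data=True,
    bleu_factor=0.4, cos_factor=0.6,
    neg_bleu_factor=0.4, neg_cos_factor=0.6,
    neg_cos_ref_factor=0.5, neg_cos_saved_factor=0.5,
)

def fake_batch(batch_size: int = 2, seq_len: int = 16) -> dict:
    # TokenizedDict-style batch: token ids plus an attention mask
    return {'input_ids': torch.randint(0, vocab_size, (batch_size, seq_len)),
            'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long)}

generated, reference, saved_data = fake_batch(), fake_batch(), fake_batch()
rewards = reward_model(generated, reference, saved_data, mode=MrlRewardMode.STANDARD)
print(rewards)  # list of per-sample rewards combining the BLEU and cosine terms

Note that with the default allow_not_summing_factors=False each factor pair must sum exactly to 1.0, including the negative pair, which is why the neg_* factors are passed explicitly in this sketch.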
rxnn/training/rl.py
ADDED
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from abc import ABC, abstractmethod
+from typing import TypedDict
+
+
+class RlAlgorithm(ABC):
+    def __init__(self):
+        super(RlAlgorithm, self).__init__()
+        self.critic_loss = nn.MSELoss()
+
+    @abstractmethod
+    def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+        pass
+
+    def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(rewards, values)
+
+class PPOConfig(TypedDict):
+    gae_gamma: float
+    gae_lambda: float
+    clip_eps: float
+
+class PPOAlgorithm(RlAlgorithm):
+    def __init__(self, config: PPOConfig):
+        super(PPOAlgorithm, self).__init__()
+
+        # PPO Config
+        self.gae_gamma = config.get('gae_gamma', 0.99)
+        self.gae_lambda = config.get('gae_lambda', 0.95)
+        self.clip_eps = config.get('clip_eps', 0.2)
+
+    def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        # a) Get new log probs
+        new_probs = F.log_softmax(logits, dim=-1)
+        new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+
+        # b) Calculate ratio
+        ratio = (new_log_probs - old_log_probs).exp()
+
+        # c) Clipped surrogate loss
+        surr1 = ratio * advantages
+        surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
+        policy_loss = -torch.min(surr1, surr2).mean()
+
+        # d) Entropy bonus
+        entropy = -torch.sum(new_probs * new_probs.exp(), dim=-1).mean()
+        policy_loss -= 0.01 * entropy
+
+        return policy_loss
+
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor) -> torch.Tensor:
+        advantages = torch.zeros_like(rewards, device=values.device)
+        last_advantage = 0
+        for t in reversed(range(rewards.size(0))):
+            delta = rewards[t] + self.gae_gamma * next_value - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+            last_advantage = advantages[t]
+        return advantages
+
+    def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+        advantages = self._compute_gae(rewards, values[:-1], values[-1])
+        normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+        return normalized_advantages
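PPOAlgorithm supplies the two pieces an MRL trainer needs: GAE-based advantages computed from per-step rewards and value estimates, and the clipped PPO surrogate loss (plus a small entropy bonus) computed from fresh logits and stored log-probabilities. A minimal calling sketch follows; the shapes and values are illustrative and not taken from the rxnn trainer.

# Hypothetical usage sketch for PPOAlgorithm (illustrative shapes, not from the rxnn sources)
import torch
from rxnn.training.rl import PPOAlgorithm

algorithm = PPOAlgorithm({'gae_gamma': 0.99, 'gae_lambda': 0.95, 'clip_eps': 0.2})

# calculate_advantages slices values[:-1] / values[-1], so it expects one extra bootstrap value
rewards = torch.tensor([0.2, 0.5, 1.0])
values = torch.tensor([0.3, 0.4, 0.6, 0.7])
advantages = algorithm.calculate_advantages(rewards, values)  # normalized GAE, shape (3,)

# Clipped surrogate loss from new logits and the log-probs stored during generation
input_ids = torch.randint(0, 100, (3, 8))   # sampled token ids
logits = torch.randn(3, 8, 100)             # current policy logits
old_log_probs = torch.randn(3, 8)           # log-probs recorded by the sampling policy
loss = algorithm.policy_loss(input_ids, logits, old_log_probs, advantages.view(3, 1))  # one advantage broadcast per sequence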
rxnn/training/utils.py
ADDED
@@ -0,0 +1,148 @@
+import torch
+from typing import TypedDict
+
+class SpecialTokenIds(TypedDict):
+    bos: int
+    eos: int
+    pad: int
+
+class TokenizedDict(TypedDict):
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+
+def smart_concat_critic_states(
+    prev_query: TokenizedDict,
+    prev_answer: TokenizedDict,
+    next_query: TokenizedDict,
+    max_length: int,
+    pad_token_id: int
+) -> TokenizedDict:
+    """
+    Smart vectorized concatenation of MRL critic states - previous interaction (query and answer) and next query.
+    It creates a batch of critic input sequences from previous query, previous answer and next query batches.
+    Used in MRL to concatenate critic states in correct format.
+
+    All the concatenated sequences (batches) are padded to the same max length, but the result should have two times
+    longer max length. Single max length is made to fit single query and answer, but here we have additional next query,
+    so we are using 2x longer sequence for safety.
+
+    Args:
+        prev_query (TokenizedDict): Batch of tokenized queries with attention masks from previous interaction
+        prev_answer (TokenizedDict): Batch of tokenized answers with attention masks from previous interaction
+        next_query (TokenizedDict): Batch of tokenized queries with attention masks from next interaction
+        max_length (int): Max length of result sequence.
+        pad_token_id (int): Index of padding token
+    """
+    device = prev_query['input_ids'].device
+    batch_size = prev_query['input_ids'].size(0)
+
+    # Get input dimensions
+    query_max_len = prev_query['input_ids'].size(1)
+    answer_max_len = prev_answer['input_ids'].size(1)
+    next_q_max_len = next_query['input_ids'].size(1)
+
+    # Get actual lengths using attention masks
+    query_lens = prev_query['attention_mask'].sum(dim=1)
+    answer_lens = prev_answer['attention_mask'].sum(dim=1)
+    next_query_lens = next_query['attention_mask'].sum(dim=1)
+
+    # Calculate positions and boundaries
+    positions = torch.arange(max_length, device=device).expand(batch_size, -1)
+    section1_end = query_lens.unsqueeze(1)
+    section2_end = section1_end + answer_lens.unsqueeze(1)
+    section3_end = section2_end + next_query_lens.unsqueeze(1)
+
+    # Create masks for each section
+    mask_prev = positions < section1_end
+    mask_answer = (positions >= section1_end) & (positions < section2_end)
+    mask_next = (positions >= section2_end) & (positions < section3_end)
+
+    # Build combined tensor
+    combined_ids = torch.full((batch_size, max_length), pad_token_id, device=device)
+
+    # 1. Fill previous query section (with input length clamping)
+    query_indices = positions.clamp(max=query_max_len - 1)
+    combined_ids = torch.where(
+        mask_prev,
+        prev_query['input_ids'].gather(1, query_indices),
+        combined_ids
+    )
+
+    # 2. Fill answer section (with answer length clamping)
+    answer_pos = (positions - section1_end).clamp(min=0, max=answer_max_len - 1)
+    combined_ids = torch.where(
+        mask_answer,
+        prev_answer['input_ids'].gather(1, answer_pos),
+        combined_ids
+    )
+
+    # 3. Fill next query section (with next query length clamping)
+    next_q_pos = (positions - section2_end).clamp(min=0, max=next_q_max_len - 1)
+    combined_ids = torch.where(
+        mask_next,
+        next_query['input_ids'].gather(1, next_q_pos),
+        combined_ids
+    )
+
+    # Create attention mask
+    combined_mask = (positions < section3_end).long()
+
+    return {
+        'input_ids': combined_ids,
+        'attention_mask': combined_mask
+    }
+
+def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, pad_token_id: int) -> TokenizedDict:
+    """
+    Smart vectorized concatenation of interaction parts - query and answer. It creates
+    batch of interactions from query and answer batches. Used in MRL to concatenate data
+    to encode and update memory.
+
+    Query and answer sequences are padded to the same max length, and the result also has
+    the same length.
+
+    Args:
+        query (TokenizedDict): Batch of tokenized queries with attention masks
+        answer (TokenizedDict): Batch of tokenized answers with attention masks
+        max_length (int): Max length of each sequence - query, answer and result.
+        pad_token_id (int): Index of padding token
+    """
+    device = query['input_ids'].device
+    batch_size = query['input_ids'].size(0)
+
+    # Get actual lengths from attention masks
+    query_lens = query['attention_mask'].sum(dim=1)
+    answer_lens = answer['attention_mask'].sum(dim=1)
+
+    # Create combined length tensor
+    combined_lens = torch.minimum(query_lens + answer_lens,
+                                  torch.full_like(query_lens, max_length))
+
+    # Create position indices [batch_size, max_length]
+    positions = torch.arange(max_length, device=device).expand(batch_size, -1)
+
+    # Create mask for query/answer parts
+    query_mask = positions < query_lens.unsqueeze(1)
+    answer_mask = (positions >= query_lens.unsqueeze(1)) & (positions < combined_lens.unsqueeze(1))
+
+    # Calculate answer positions with overflow protection
+    answer_pos = (positions - query_lens.unsqueeze(1)).clamp(min=0)
+
+    # Build combined_ids using vectorized where
+    combined_ids = torch.where(
+        query_mask,
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)),
+        torch.where(
+            answer_mask,
+            answer['input_ids'].gather(1, answer_pos),
+            query['input_ids'].new_full((1,), pad_token_id)
+        )
+    )
+
+    # Build attention mask
+    combined_mask = (positions < combined_lens.unsqueeze(1)).long()
+
+    return {
+        'input_ids': combined_ids,
+        'attention_mask': combined_mask
+    }
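Both helpers assemble fixed-length batches directly from the attention masks, with no Python loop over individual sequences. A small sketch of smart_concat on two toy interactions follows; the token ids are arbitrary and not from the rxnn sources.

# Hypothetical usage sketch for smart_concat (illustrative token ids)
import torch
from rxnn.training.utils import smart_concat

query = {
    'input_ids': torch.tensor([[5, 6, 7, 0, 0, 0], [8, 9, 0, 0, 0, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 0, 0, 0, 0]]),
}
answer = {
    'input_ids': torch.tensor([[3, 4, 0, 0, 0, 0], [2, 2, 2, 0, 0, 0]]),
    'attention_mask': torch.tensor([[1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]]),
}

combined = smart_concat(query, answer, max_length=6, pad_token_id=0)
# combined['input_ids'][0] -> [5, 6, 7, 3, 4, 0]: answer tokens follow the real query tokens, padding moves to the end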
rxnn/transformers/attention.py
CHANGED
@@ -16,6 +16,7 @@ class MultiHeadAttention(nn.Module):
         dropout: float = 0.0,
         rope: RotaryPositionalEmbedding = None,
         rope_only_for_query: bool = False,
+        rope_only_for_keys: bool = False,
         use_relative_embeddings: bool = False,
         max_seq_len: int = 1024,
         use_flash_attention: bool = True,
@@ -37,10 +38,12 @@ class MultiHeadAttention(nn.Module):
             self.rel_embed = RelativePositionalEmbedding(max_seq_len, embed_dim // num_heads)
             self.rope = None
             self.rope_only_for_query = False
+            self.rope_only_for_keys = False
         else:
             self.rel_embed = None
             self.rope = rope
             self.rope_only_for_query = rope_only_for_query
+            self.rope_only_for_keys = rope_only_for_keys
         self.dropout = nn.Dropout(dropout)
         self._init_q(embed_dim)
         self._init_kv(embed_dim)
@@ -70,6 +73,8 @@ class MultiHeadAttention(nn.Module):
         if self.rope is not None:
             if self.rope_only_for_query:
                 q = self.rope.forward_one(q)
+            elif self.rope_only_for_keys:
+                k = self.rope.forward_one(k)
             else:
                 q, k = self.rope(q, k)
         return q, k
@@ -192,6 +197,7 @@ class GroupedQueryAttention(MultiHeadAttention):
             k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
         else:
+            # Relative embedding version is not working without this strange mapping - it will be removed in next versions
             group_heads = self.num_heads // self.num_groups

             # Process Q
@@ -289,6 +295,7 @@ def init_attention(
         dropout: float = 0.0,
         rope: RotaryPositionalEmbedding = None,
         rope_only_for_query: bool = False,
+        rope_only_for_keys: bool = False,
         use_relative_embeddings: bool = False,
         max_seq_len: int = 1024,
         use_flash_attention: bool = False,
@@ -308,6 +315,7 @@ def init_attention(
             use_relative_embeddings=use_relative_embeddings,
             max_seq_len=max_seq_len,
             rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
             use_flash_attention=use_flash_attention,
             is_causal=is_causal,
             use_bias=use_bias,
@@ -321,6 +329,7 @@ def init_attention(
             use_relative_embeddings=use_relative_embeddings,
             max_seq_len=max_seq_len,
             rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
             use_flash_attention=use_flash_attention,
             is_causal=is_causal,
             use_bias=use_bias,
@@ -334,6 +343,7 @@ def init_attention(
             use_relative_embeddings=use_relative_embeddings,
             max_seq_len=max_seq_len,
             rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
             use_flash_attention=use_flash_attention,
             is_causal=is_causal,
             use_bias=use_bias,
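The new rope_only_for_keys flag mirrors the existing rope_only_for_query option: rotary position information is applied to only one side of the attention, which is presumably useful when the other side (for example fixed memory slots in memory cross-attention) carries no positional order. The standalone sketch below shows what "RoPE on keys only" means using a toy rotation helper; it is not the library's RotaryPositionalEmbedding.

# Illustrative sketch of key-only rotary embedding (not the rxnn implementation)
import torch

def rope_rotate(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim), freqs: (seq_len, head_dim // 2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = freqs.cos(), freqs.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

batch, heads, seq_len, head_dim = 2, 4, 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum('i,j->ij', torch.arange(seq_len).float(), inv_freq)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
# rope_only_for_keys=True behaviour: only the keys carry positional phase, queries stay position-free
k = rope_rotate(k, freqs)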
rxnn/transformers/layers.py
CHANGED
@@ -61,6 +61,12 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.memory_cross_attention.parameters():
             param.requires_grad_(is_trainable)

+    def update_max_len(self, max_seq_len: int):
+        if self.attention.rope is not None:
+            self.attention.rope.update_max_len(max_seq_len)
+        if self.memory_cross_attention.rope is not None:
+            self.memory_cross_attention.rope.update_max_len(max_seq_len)
+
     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
         att_router_loss = None
rxnn/transformers/models.py
CHANGED
@@ -72,11 +72,17 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         # Process shared layers
         if self.shared_layers is not None:
             for i in range(self.num_shared_layers):
-                layer_stm = self.stm(i)
+                layer_stm = self.stm(i)
+                # expand layer STM to batch size, if it's not in batch mode
+                if layer_stm.size(0) == 1:
+                    layer_stm = layer_stm.expand(x.size(0), -1, -1)
                 x = self.shared_layers[i](x, layer_stm, mask=mask)
         # Process own layers
         for i in range(self.num_own_layers):
-            layer_stm = self.stm(i)
+            layer_stm = self.stm(i)
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
             x = self.layers[i](x, layer_stm, mask=mask)
         return self.head(x)

@@ -93,12 +99,18 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         # Process shared layers
         if self.shared_layers is not None:
             for i in range(self.num_shared_layers):
-                layer_stm = self.stm(i)
+                layer_stm = self.stm(i)
+                # expand layer STM to batch size, if it's not in batch mode
+                if layer_stm.size(0) == 1:
+                    layer_stm = layer_stm.expand(x.size(0), -1, -1)
                 x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
                 hidden_states.append(x)
         # Process own layers
         for i in range(self.num_own_layers):
-            layer_stm = self.stm(i)
+            layer_stm = self.stm(i)
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
             x = self.layers[i](x, layer_stm, mask=attention_mask)
             hidden_states.append(x)
         return x, torch.stack(hidden_states)
rxnn/transformers/positional.py
CHANGED
@@ -18,6 +18,11 @@ class RotaryPositionalEmbedding(nn.Module):
         freqs = torch.einsum('i,j->ij', t, self.inv_freq)
         self.register_buffer('cache', freqs)

+    def update_max_len(self, max_seq_len: int):
+        self.max_seq_len = max_seq_len
+        t = torch.arange(max_seq_len).type_as(self.inv_freq)
+        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        self.cache = freqs

     def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         seq_len = q.size(-2)
@@ -42,6 +47,8 @@ class RotaryPositionalEmbedding(nn.Module):
         return q_embed

     def _prepare_freqs(self, seq_len: int) -> torch.Tensor:
+        if seq_len > self.max_seq_len:
+            self.update_max_len(seq_len)
         return self.cache[:seq_len][None, None, :, :]

     def _rotate(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
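update_max_len rebuilds the cached rotary frequency table for a longer context, and _prepare_freqs now grows the cache on demand instead of indexing past it when a longer sequence arrives. Below is a self-contained stand-in mirroring that behaviour; it is illustrative only, since the real RotaryPositionalEmbedding constructor is not shown in this diff.

# Minimal stand-in for the lazy RoPE cache growth added above (not the rxnn class)
import torch
import torch.nn as nn

class TinyRopeCache(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 1024):
        super().__init__()
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.register_buffer('cache', torch.einsum('i,j->ij', torch.arange(max_seq_len).float(), inv_freq))

    def update_max_len(self, max_seq_len: int):
        # Rebuild the frequency table for the new, longer context
        self.max_seq_len = max_seq_len
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        self.cache = torch.einsum('i,j->ij', t, self.inv_freq)

    def prepare_freqs(self, seq_len: int) -> torch.Tensor:
        # Grow the cache transparently instead of failing on longer sequences
        if seq_len > self.max_seq_len:
            self.update_max_len(seq_len)
        return self.cache[:seq_len][None, None, :, :]

rope_cache = TinyRopeCache(dim=64, max_seq_len=8)
freqs = rope_cache.prepare_freqs(32)  # cache extended from 8 to 32 positions on demand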