rxnn 0.1.83__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,111 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from enum import Enum
+ from typing import Optional
+ from .utils import TokenizedDict
+
+
+ class MrlRewardMode(Enum):
+     STANDARD = 1
+     NEGATIVE = 2
+     LONG_RANGE = 3
+
+ class MrlRewardModel:
+     def __init__(
+         self,
+         shared_embedding: nn.Embedding,
+         device: torch.device,
+         bleu_with_saved_data: bool = False,
+         bleu_factor: float = 0.5,
+         cos_factor: float = 0.5,
+         cos_ref_factor: float = 0.5,
+         cos_saved_factor: float = 0.5,
+         neg_bleu_factor: Optional[float] = None,
+         neg_cos_factor: Optional[float] = None,
+         neg_cos_ref_factor: Optional[float] = None,
+         neg_cos_saved_factor: Optional[float] = None,
+         neg_bleu_ref_factor: float = 0.5,
+         neg_bleu_saved_factor: float = 0.5,
+         allow_not_summing_factors: bool = False,
+     ):
+         self.shared_embedding = shared_embedding.to(device)
+         self.device = device
+         self.bleu_with_saved_data = bleu_with_saved_data
+
+         if not allow_not_summing_factors:
+             assert bleu_factor + cos_factor == 1.0
+             assert cos_ref_factor + cos_saved_factor == 1.0
+             # Negative factors default to their positive counterparts (see below),
+             # so validate them only when they are explicitly provided.
+             if neg_bleu_factor is not None and neg_cos_factor is not None:
+                 assert neg_bleu_factor + neg_cos_factor == 1.0
+             if neg_cos_ref_factor is not None and neg_cos_saved_factor is not None:
+                 assert neg_cos_ref_factor + neg_cos_saved_factor == 1.0
+             assert neg_bleu_ref_factor + neg_bleu_saved_factor == 1.0
+
+         self.bleu_factor = bleu_factor
+         self.cos_factor = cos_factor
+         self.cos_ref_factor = cos_ref_factor
+         self.cos_saved_factor = cos_saved_factor
+         self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
+         self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
+         self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
+         self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
+         self.neg_bleu_ref_factor = neg_bleu_ref_factor
+         self.neg_bleu_saved_factor = neg_bleu_saved_factor
+
+     def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
+         from nltk.translate.bleu_score import sentence_bleu
+         # NLTK expects token lists, not tensors
+         generated, reference, saved_data = generated.tolist(), reference.tolist(), saved_data.tolist()
+         refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
+         return sentence_bleu(refs, generated, weights=(0.25, 0.25, 0.25, 0.25))
+
+     def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
+         from nltk.translate.bleu_score import sentence_bleu
+         generated, reference, saved_data = generated.tolist(), reference.tolist(), saved_data.tolist()
+
+         if self.bleu_with_saved_data:
+             ref_bleu = sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
+             saved_bleu = sentence_bleu([saved_data], generated, weights=(0.25, 0.25, 0.25, 0.25))
+             saved_bleu = 1 - saved_bleu
+
+             return (self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu) / 2
+         else:
+             return sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
+
+     def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
+         batch_size = generated.size(0)
+         return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
+
+     def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
+         embedding = self.shared_embedding(sequence.to(self.device))
+         return embedding.mean(dim=1)
+
+     def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
+         generated_emb = self._sequence_embedding(generated)
+
+         gen_and_saved = F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data))
+         gen_and_ref = F.cosine_similarity(generated_emb, self._sequence_embedding(reference))
+         return gen_and_saved, gen_and_ref
+
+     def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+         return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+
+     def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
+
+     def __call__(
+         self,
+         generated: TokenizedDict,
+         reference: TokenizedDict,
+         saved_data: TokenizedDict,
+         mode: MrlRewardMode = MrlRewardMode.STANDARD
+     ) -> list[float]:
+         if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
+             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+             cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+             # keep the BLEU tensor on the same device as the cosine scores
+             return (self.bleu_factor * torch.tensor(bleu, device=cosine.device) + self.cos_factor * cosine).tolist()
+         else:
+             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+             return (self.neg_bleu_factor * torch.tensor(bleu, device=cosine.device) + self.neg_cos_factor * cosine).tolist()
+
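A minimal usage sketch for the reward model above (the embedding size, vocab, and tokenized batches are illustrative assumptions, not values from the package):

```python
import torch
import torch.nn as nn

# toy shared embedding: vocab of 100 tokens, 32-dim vectors
embedding = nn.Embedding(100, 32)
reward_model = MrlRewardModel(embedding, device=torch.device('cpu'))

def fake_batch(batch_size=2, seq_len=8):
    # stand-in for real tokenizer output (TokenizedDict)
    ids = torch.randint(0, 100, (batch_size, seq_len))
    return {'input_ids': ids, 'attention_mask': torch.ones_like(ids)}

rewards = reward_model(fake_batch(), fake_batch(), fake_batch(), mode=MrlRewardMode.STANDARD)
print(rewards)  # list of 2 floats: 0.5 * BLEU + 0.5 * cosine per sequence
```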
rxnn/training/rl.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from abc import ABC, abstractmethod
+ from typing import TypedDict
+
+
+ class RlAlgorithm(ABC):
+     def __init__(self):
+         super(RlAlgorithm, self).__init__()
+         # keep the loss module under a private name so it doesn't shadow
+         # the critic_loss() method below
+         self._critic_loss_fn = nn.MSELoss()
+
+     @abstractmethod
+     def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+         pass
+
+     @abstractmethod
+     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+         pass
+
+     def critic_loss(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+         return self._critic_loss_fn(values, rewards)
+
+ class PPOConfig(TypedDict, total=False):
+     # all keys optional - defaults are applied in PPOAlgorithm.__init__
+     gae_gamma: float
+     gae_lambda: float
+     clip_eps: float
+
+ class PPOAlgorithm(RlAlgorithm):
+     def __init__(self, config: PPOConfig):
+         super(PPOAlgorithm, self).__init__()
+
+         # PPO Config
+         self.gae_gamma = config.get('gae_gamma', 0.99)
+         self.gae_lambda = config.get('gae_lambda', 0.95)
+         self.clip_eps = config.get('clip_eps', 0.2)
+
+     def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+         # a) Get new log probs
+         new_probs = F.log_softmax(logits, dim=-1)
+         new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+
+         # b) Calculate ratio
+         ratio = (new_log_probs - old_log_probs).exp()
+
+         # c) Clipped surrogate loss
+         surr1 = ratio * advantages
+         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
+         policy_loss = -torch.min(surr1, surr2).mean()
+
+         # d) Entropy bonus
+         entropy = -torch.sum(new_probs * new_probs.exp(), dim=-1).mean()
+         policy_loss -= 0.01 * entropy
+
+         return policy_loss
+
+     def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor, next_value: torch.Tensor) -> torch.Tensor:
+         advantages = torch.zeros_like(rewards, device=values.device)
+         last_advantage = 0
+         next_val = next_value
+         for t in reversed(range(rewards.size(0))):
+             # bootstrap each step from the value of the following step,
+             # not only from the terminal next_value
+             delta = rewards[t] + self.gae_gamma * next_val - values[t]
+             advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+             last_advantage = advantages[t]
+             next_val = values[t]
+         return advantages
+
+     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
+         advantages = self._compute_gae(rewards, values[:-1], values[-1])
+         normalized_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+         return normalized_advantages
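A short sketch of how the PPO pieces fit together; the trajectory length, vocab size, and reward/value layout are assumptions for illustration:

```python
import torch

algo = PPOAlgorithm({'gae_gamma': 0.99, 'gae_lambda': 0.95, 'clip_eps': 0.2})

T, vocab = 5, 10
rewards = torch.randn(T)       # one reward per step
values = torch.randn(T + 1)    # critic values, incl. bootstrap value for the last state
advantages = algo.calculate_advantages(rewards, values)

input_ids = torch.randint(0, vocab, (T,))
logits = torch.randn(T, vocab, requires_grad=True)
old_log_probs = torch.log_softmax(logits.detach(), dim=-1).gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)

loss = algo.policy_loss(input_ids, logits, old_log_probs, advantages)
loss.backward()  # gradients flow through the clipped surrogate + entropy bonus
```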
rxnn/training/utils.py ADDED
@@ -0,0 +1,148 @@
+ import torch
+ from typing import TypedDict
+
+ class SpecialTokenIds(TypedDict):
+     bos: int
+     eos: int
+     pad: int
+
+ class TokenizedDict(TypedDict):
+     input_ids: torch.Tensor
+     attention_mask: torch.Tensor
+
+ def smart_concat_critic_states(
+     prev_query: TokenizedDict,
+     prev_answer: TokenizedDict,
+     next_query: TokenizedDict,
+     max_length: int,
+     pad_token_id: int
+ ) -> TokenizedDict:
+     """
+     Smart vectorized concatenation of MRL critic states: the previous interaction (query and answer)
+     followed by the next query. Builds a batch of critic input sequences from the previous query,
+     previous answer and next query batches. Used in MRL to assemble critic states in the correct format.
+
+     All concatenated sequences (batches) are padded to the same max length, but the result should use
+     a max length twice as long: a single max length fits one query and answer, and here the next query
+     is appended as well, so a 2x longer sequence is used for safety.
+
+     Args:
+         prev_query (TokenizedDict): Batch of tokenized queries with attention masks from the previous interaction
+         prev_answer (TokenizedDict): Batch of tokenized answers with attention masks from the previous interaction
+         next_query (TokenizedDict): Batch of tokenized queries with attention masks from the next interaction
+         max_length (int): Max length of the result sequence.
+         pad_token_id (int): Index of the padding token
+     """
+     device = prev_query['input_ids'].device
+     batch_size = prev_query['input_ids'].size(0)
+
+     # Get input dimensions
+     query_max_len = prev_query['input_ids'].size(1)
+     answer_max_len = prev_answer['input_ids'].size(1)
+     next_q_max_len = next_query['input_ids'].size(1)
+
+     # Get actual lengths using attention masks
+     query_lens = prev_query['attention_mask'].sum(dim=1)
+     answer_lens = prev_answer['attention_mask'].sum(dim=1)
+     next_query_lens = next_query['attention_mask'].sum(dim=1)
+
+     # Calculate positions and boundaries
+     positions = torch.arange(max_length, device=device).expand(batch_size, -1)
+     section1_end = query_lens.unsqueeze(1)
+     section2_end = section1_end + answer_lens.unsqueeze(1)
+     section3_end = section2_end + next_query_lens.unsqueeze(1)
+
+     # Create masks for each section
+     mask_prev = positions < section1_end
+     mask_answer = (positions >= section1_end) & (positions < section2_end)
+     mask_next = (positions >= section2_end) & (positions < section3_end)
+
+     # Build combined tensor
+     combined_ids = torch.full((batch_size, max_length), pad_token_id, device=device)
+
+     # 1. Fill previous query section (with input length clamping)
+     query_indices = positions.clamp(max=query_max_len - 1)
+     combined_ids = torch.where(
+         mask_prev,
+         prev_query['input_ids'].gather(1, query_indices),
+         combined_ids
+     )
+
+     # 2. Fill answer section (with answer length clamping)
+     answer_pos = (positions - section1_end).clamp(min=0, max=answer_max_len - 1)
+     combined_ids = torch.where(
+         mask_answer,
+         prev_answer['input_ids'].gather(1, answer_pos),
+         combined_ids
+     )
+
+     # 3. Fill next query section (with next query length clamping)
+     next_q_pos = (positions - section2_end).clamp(min=0, max=next_q_max_len - 1)
+     combined_ids = torch.where(
+         mask_next,
+         next_query['input_ids'].gather(1, next_q_pos),
+         combined_ids
+     )
+
+     # Create attention mask
+     combined_mask = (positions < section3_end).long()
+
+     return {
+         'input_ids': combined_ids,
+         'attention_mask': combined_mask
+     }
+
+ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, pad_token_id: int) -> TokenizedDict:
+     """
+     Smart vectorized concatenation of interaction parts: query and answer. It creates a
+     batch of interactions from query and answer batches. Used in MRL to concatenate data
+     to encode and update memory.
+
+     Query and answer sequences are padded to the same max length, and the result also has
+     the same length.
+
+     Args:
+         query (TokenizedDict): Batch of tokenized queries with attention masks
+         answer (TokenizedDict): Batch of tokenized answers with attention masks
+         max_length (int): Max length of each sequence - query, answer and result.
+         pad_token_id (int): Index of the padding token
+     """
+     device = query['input_ids'].device
+     batch_size = query['input_ids'].size(0)
+
+     # Get actual lengths from attention masks
+     query_lens = query['attention_mask'].sum(dim=1)
+     answer_lens = answer['attention_mask'].sum(dim=1)
+
+     # Create combined length tensor
+     combined_lens = torch.minimum(query_lens + answer_lens,
+                                   torch.full_like(query_lens, max_length))
+
+     # Create position indices [batch_size, max_length]
+     positions = torch.arange(max_length, device=device).expand(batch_size, -1)
+
+     # Create mask for query/answer parts
+     query_mask = positions < query_lens.unsqueeze(1)
+     answer_mask = (positions >= query_lens.unsqueeze(1)) & (positions < combined_lens.unsqueeze(1))
+
+     # Calculate answer positions with overflow protection
+     answer_pos = (positions - query_lens.unsqueeze(1)).clamp(min=0)
+
+     # Build combined_ids using vectorized where
+     combined_ids = torch.where(
+         query_mask,
+         query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)),
+         torch.where(
+             answer_mask,
+             answer['input_ids'].gather(1, answer_pos),
+             query['input_ids'].new_full((1,), pad_token_id)
+         )
+     )
+
+     # Build attention mask
+     combined_mask = (positions < combined_lens.unsqueeze(1)).long()
+
+     return {
+         'input_ids': combined_ids,
+         'attention_mask': combined_mask
+     }
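A small worked example of smart_concat, under its documented assumption that query, answer and result share the same max length (token ids are arbitrary):

```python
import torch

query = {
    'input_ids': torch.tensor([[5, 6, 0, 0]]),      # 2 real tokens + padding
    'attention_mask': torch.tensor([[1, 1, 0, 0]]),
}
answer = {
    'input_ids': torch.tensor([[7, 8, 9, 0]]),      # 3 real tokens + padding
    'attention_mask': torch.tensor([[1, 1, 1, 0]]),
}

out = smart_concat(query, answer, max_length=4, pad_token_id=0)
print(out['input_ids'])       # tensor([[5, 6, 7, 8]]) - answer truncated to fit
print(out['attention_mask'])  # tensor([[1, 1, 1, 1]])
```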
@@ -16,6 +16,7 @@ class MultiHeadAttention(nn.Module):
          dropout: float = 0.0,
          rope: RotaryPositionalEmbedding = None,
          rope_only_for_query: bool = False,
+         rope_only_for_keys: bool = False,
          use_relative_embeddings: bool = False,
          max_seq_len: int = 1024,
          use_flash_attention: bool = True,
@@ -37,10 +38,12 @@ class MultiHeadAttention(nn.Module):
              self.rel_embed = RelativePositionalEmbedding(max_seq_len, embed_dim // num_heads)
              self.rope = None
              self.rope_only_for_query = False
+             self.rope_only_for_keys = False
          else:
              self.rel_embed = None
              self.rope = rope
              self.rope_only_for_query = rope_only_for_query
+             self.rope_only_for_keys = rope_only_for_keys
          self.dropout = nn.Dropout(dropout)
          self._init_q(embed_dim)
          self._init_kv(embed_dim)
@@ -70,6 +73,8 @@ class MultiHeadAttention(nn.Module):
          if self.rope is not None:
              if self.rope_only_for_query:
                  q = self.rope.forward_one(q)
+             elif self.rope_only_for_keys:
+                 k = self.rope.forward_one(k)
              else:
                  q, k = self.rope(q, k)
          return q, k
@@ -192,6 +197,7 @@ class GroupedQueryAttention(MultiHeadAttention):
              k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
              v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
          else:
+             # The relative embedding version does not work without this strange mapping - it will be removed in future versions
              group_heads = self.num_heads // self.num_groups

              # Process Q
@@ -289,6 +295,7 @@ def init_attention(
      dropout: float = 0.0,
      rope: RotaryPositionalEmbedding = None,
      rope_only_for_query: bool = False,
+     rope_only_for_keys: bool = False,
      use_relative_embeddings: bool = False,
      max_seq_len: int = 1024,
      use_flash_attention: bool = False,
@@ -308,6 +315,7 @@ def init_attention(
          use_relative_embeddings=use_relative_embeddings,
          max_seq_len=max_seq_len,
          rope_only_for_query=rope_only_for_query,
+         rope_only_for_keys=rope_only_for_keys,
          use_flash_attention=use_flash_attention,
          is_causal=is_causal,
          use_bias=use_bias,
@@ -321,6 +329,7 @@ def init_attention(
          use_relative_embeddings=use_relative_embeddings,
          max_seq_len=max_seq_len,
          rope_only_for_query=rope_only_for_query,
+         rope_only_for_keys=rope_only_for_keys,
          use_flash_attention=use_flash_attention,
          is_causal=is_causal,
          use_bias=use_bias,
@@ -334,6 +343,7 @@ def init_attention(
          use_relative_embeddings=use_relative_embeddings,
          max_seq_len=max_seq_len,
          rope_only_for_query=rope_only_for_query,
+         rope_only_for_keys=rope_only_for_keys,
          use_flash_attention=use_flash_attention,
          is_causal=is_causal,
          use_bias=use_bias,
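The three RoPE application modes are mutually exclusive per attention module. A standalone sketch of the dispatch this diff adds, with a stand-in rope object (only forward_one and the two-tensor call are relied on above; DummyRope and apply_rope here are hypothetical illustrations, not package APIs):

```python
import torch

class DummyRope:
    # stand-in exposing the same two entry points the attention layer uses
    def forward_one(self, x):
        return x + 1.0  # pretend rotation

    def __call__(self, q, k):
        return q + 1.0, k + 1.0

def apply_rope(rope, q, k, only_query=False, only_keys=False):
    # mirrors the branch added to the attention module above
    if rope is not None:
        if only_query:
            q = rope.forward_one(q)
        elif only_keys:
            k = rope.forward_one(k)
        else:
            q, k = rope(q, k)
    return q, k

q, k = torch.zeros(3), torch.zeros(3)
q2, k2 = apply_rope(DummyRope(), q, k, only_keys=True)
print(q2, k2)  # q untouched, only k "rotated"
```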
@@ -61,6 +61,12 @@ class ReactiveTransformerLayer(nn.Module):
          for param in self.memory_cross_attention.parameters():
              param.requires_grad_(is_trainable)

+     def update_max_len(self, max_seq_len: int):
+         if self.attention.rope is not None:
+             self.attention.rope.update_max_len(max_seq_len)
+         if self.memory_cross_attention.rope is not None:
+             self.memory_cross_attention.rope.update_max_len(max_seq_len)
+
      def moe_router_loss(self):
          ff_router_loss = self.ff.router_loss() if self.use_moe else None
          att_router_loss = None
@@ -72,11 +72,17 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
          # Process shared layers
          if self.shared_layers is not None:
              for i in range(self.num_shared_layers):
-                 layer_stm = self.stm(i).expand(x.size(0), -1, -1)
+                 layer_stm = self.stm(i)
+                 # expand layer STM to batch size, if it's not in batch mode
+                 if layer_stm.size(0) == 1:
+                     layer_stm = layer_stm.expand(x.size(0), -1, -1)
                  x = self.shared_layers[i](x, layer_stm, mask=mask)
          # Process own layers
          for i in range(self.num_own_layers):
-             layer_stm = self.stm(i).expand(x.size(0), -1, -1)
+             layer_stm = self.stm(i)
+             # expand layer STM to batch size, if it's not in batch mode
+             if layer_stm.size(0) == 1:
+                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
              x = self.layers[i](x, layer_stm, mask=mask)
          return self.head(x)

@@ -93,12 +99,18 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
          # Process shared layers
          if self.shared_layers is not None:
              for i in range(self.num_shared_layers):
-                 layer_stm = self.stm(i).expand(x.size(0), -1, -1)
+                 layer_stm = self.stm(i)
+                 # expand layer STM to batch size, if it's not in batch mode
+                 if layer_stm.size(0) == 1:
+                     layer_stm = layer_stm.expand(x.size(0), -1, -1)
                  x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
                  hidden_states.append(x)
          # Process own layers
          for i in range(self.num_own_layers):
-             layer_stm = self.stm(i).expand(x.size(0), -1, -1)
+             layer_stm = self.stm(i)
+             # expand layer STM to batch size, if it's not in batch mode
+             if layer_stm.size(0) == 1:
+                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
              x = self.layers[i](x, layer_stm, mask=attention_mask)
              hidden_states.append(x)
          return x, torch.stack(hidden_states)
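The expansion rule above in isolation; the shapes are assumptions for illustration (a single STM slot [1, stm_size, dim] is broadcast to the input batch, while an already-batched STM passes through unchanged):

```python
import torch

layer_stm = torch.zeros(1, 8, 64)   # single (non-batch) STM slot: [1, stm_size, dim]
x = torch.randn(4, 16, 64)          # input batch: [batch, seq_len, dim]

# expand only when the STM is not already in batch mode
if layer_stm.size(0) == 1:
    layer_stm = layer_stm.expand(x.size(0), -1, -1)

assert layer_stm.size(0) == x.size(0)  # now [4, 8, 64]; expand is a view, no copy
```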
@@ -18,6 +18,11 @@ class RotaryPositionalEmbedding(nn.Module):
          freqs = torch.einsum('i,j->ij', t, self.inv_freq)
          self.register_buffer('cache', freqs)

+     def update_max_len(self, max_seq_len: int):
+         self.max_seq_len = max_seq_len
+         t = torch.arange(max_seq_len).type_as(self.inv_freq)
+         freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+         self.cache = freqs

      def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
          seq_len = q.size(-2)
@@ -42,6 +47,8 @@ class RotaryPositionalEmbedding(nn.Module):
          return q_embed

      def _prepare_freqs(self, seq_len: int) -> torch.Tensor:
+         if seq_len > self.max_seq_len:
+             self.update_max_len(seq_len)
          return self.cache[:seq_len][None, None, :, :]

      def _rotate(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
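A hedged sketch of the on-demand cache extension this adds; the constructor arguments (dim, max_seq_len) are assumed for illustration, and only update_max_len and _prepare_freqs come from the diff:

```python
import torch

# assumed construction: 64-dim heads, frequency cache for 1024 positions
rope = RotaryPositionalEmbedding(dim=64, max_seq_len=1024)

q = torch.randn(1, 8, 2048, 64)          # sequence longer than the cached range
freqs = rope._prepare_freqs(q.size(-2))  # triggers update_max_len(2048) internally
print(rope.max_seq_len)                  # 2048 - cache now covers the longer sequence
```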