rxnn 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py CHANGED
@@ -1,6 +1,253 @@
  import torch
  from torch import nn
- from rxnn.transformers.attention import MultiHeadAttention
+ from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
+ from ..transformers.positional import RotaryPositionalEmbedding
+ from ..transformers.moe import MoeRouter
+
+ # Created by Reactive AI
+
+ class GroupedMoeAttention(GroupedQueryAttention):
+     """
+     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+     Instead of mapping keys/values to static head groups, it dynamically selects expert head groups. It has the same
+     total number of key/value heads as query heads, but uses only a selected group for the attention calculation.
+     - with num_groups set to 1, it becomes an MoE variant of MultiQueryAttention
+
+     Compared to traditional GQA/MQA, it should provide better results, because a lot less data is lost with
+     this approach - we are training the full number of key/value heads, while using only a group per token.
+
+     In terms of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
+
+     Optionally, it could use even more expert heads than attention heads - for example:
+     - 512 dims divided into 16 heads of 32 dims each, using 4 head groups - may use e.g. 24 total expert heads - still
+       only 4 are used for the attention calculation, while 16 are used to split dimensions (in that case it has 16 query heads)
+     """
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = False,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_experts: int = None,
+             *args,
+             **kwargs,
+     ):
+         self.num_experts = num_experts if num_experts is not None else num_heads
+         super(GroupedMoeAttention, self).__init__(
+             embed_dim,
+             num_heads,
+             num_groups=num_groups,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             *args,
+             **kwargs,
+         )
+
+     def _init_kv(self, embed_dim: int):
+         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+         hidden_dim = embed_dim // (self.num_heads // self.num_groups)
+         self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+         self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+         self._init_experts()
+
+     def _init_experts(self):
+         torch.nn.init.xavier_uniform_(self.wk)
+         torch.nn.init.xavier_uniform_(self.wv)
+         if self.use_bias:
+             torch.nn.init.zeros_(self.bk)
+             torch.nn.init.zeros_(self.bv)
+
+     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+         # head_dim = d // self.num_heads
+         # group_heads = self.num_heads // self.num_groups
+         #
+         # # Process Query as in GQA
+         # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
+         #
+         # # Process Key and Value with MoE routing
+         # key_flat = key.view(-1, d)
+         # weights, indices = self.router(key_flat)
+         # weights = weights.view(b, key.size(1), self.num_groups, 1)
+         # indices = indices.view(b, key.size(1), self.num_groups)
+         #
+         # # Compute all experts' K and V projections
+         # # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
+         # k_all = torch.einsum(
+         #     'be, ehd -> bedh',
+         #     key_flat,
+         #     self.wk.view(self.num_experts, d, -1)
+         # ).view(b, key.size(1), self.num_experts, -1)
+         #
+         # v_all = torch.einsum(
+         #     'be, ehd -> bedh',
+         #     value.view(-1, d),
+         #     self.wv.view(self.num_experts, d, -1)
+         # ).view(b, value.size(1), self.num_experts, -1)
+         #
+         # # Select top_k experts and compute weighted sum
+         # selected_k = torch.gather(
+         #     k_all,
+         #     2,
+         #     indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+         # )
+         # selected_v = torch.gather(
+         #     v_all,
+         #     2,
+         #     indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
+         # )
+         #
+         # selected_k = (selected_k * weights).sum(dim=2)
+         # selected_v = (selected_v * weights).sum(dim=2)
+         # # Reshape to GQA format: (B, G, S, head_dim)
+         # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
+         # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
+         #
+         # if not self.use_flash_attention:
+         #     group_heads = self.num_heads // self.num_groups
+         #
+         #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+         #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+         #
+         #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+         #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+         #
+         # return q, k, v
+         head_dim = d // self.num_heads
+
+         # Process Query as in GQA
+         q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2) if not skip_query_processing else query
+
+         # Process Key and Value with MoE routing
+         key_flat = key.view(-1, d)  # (B*S, d)
+         value_flat = value.view(-1, d)  # (B*S, d)
+
+         # Get routing indices and weights for K
+         weights_k, indices_k = self.router(key_flat)
+         indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
+         weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
+
+         # Select and compute K projections for only the top_k experts
+         selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
+         k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
+         selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
+         selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
+
+         # Compute V using the same indices as K (since they share the same router)
+         selected_v_weights = self.v_experts[indices_k]
+         v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
+         selected_v = (v_proj * weights_k).sum(dim=1)
+         selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
+
+         # Reshape to GQA format: (B, G, S, head_dim)
+         k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
+         v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
+
+         if not self.use_flash_attention:
+             group_heads = self.num_heads // self.num_groups
+
+             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+             v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+         return q, k, v
+
+ class SparseMoeAttention(GroupedMoeAttention):
+     """
+     Sparse MoE Attention (SMA) - Grouped MoE Attention extended even further for sublinear computational efficiency.
+     In addition to using Mixture-of-Experts (MoE) routing for key/value head groups, SMA also uses dynamically selected
+     query heads - with that approach, each token can still attend to every other token, but only partially - only part of
+     the information from each token is used to identify related information in other tokens.
+
+     This solution could reduce the computational complexity of the attention operation to a sublinear level (< O(N))
+     """
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = False,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_experts: int = None,
+             num_query_experts: int = None,
+             num_active_query_heads: int = None,
+             *args,
+             **kwargs,
+     ):
+         self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+         self.num_active_query_heads = num_active_query_heads if num_active_query_heads is not None else num_groups
+         super(SparseMoeAttention, self).__init__(
+             embed_dim,
+             num_heads,
+             num_groups=num_groups,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             num_experts=num_experts,
+             *args,
+             **kwargs,
+         )
+
+     def _init_q(self, embed_dim: int):
+         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_active_query_heads)
+         hidden_dim = embed_dim // (self.num_heads // self.num_groups)
+         self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+         self._init_query_experts()
+
+     def _init_query_experts(self):
+         torch.nn.init.xavier_uniform_(self.wq)
+         if self.use_bias:
+             torch.nn.init.zeros_(self.bq)
+
+     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+         head_dim = d // self.num_heads
+
+         # Process Query with MoE routing
+         query_flat = query.view(-1, d)  # (B*T, d)
+         weights_q, indices_q = self.router_q(query_flat)
+         indices_q = indices_q.view(-1, self.top_k_q)  # (B*T, top_k_q)
+         weights_q = weights_q.view(-1, self.top_k_q, 1)  # (B*T, top_k_q, 1)
+
+         # Select and compute Q projections for top_k experts
+         selected_q_weights = self.q_experts[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
+         q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
+         selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
+         selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
+
+         q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
+
+         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
+
+ # Others

  class FlexAttention(MultiHeadAttention):
      def __init__(
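
A minimal usage sketch for the two new experimental attention classes, based only on the constructor signatures shown in the diff above. The forward-call convention (query, key, value, mask) is assumed from the parent MultiHeadAttention/GroupedQueryAttention classes and is not shown in this diff, so treat it as an assumption rather than a confirmed API.

    import torch
    from rxnn.experimental.attention import GroupedMoeAttention, SparseMoeAttention

    # Dimensions follow the docstring example: 512 dims, 16 heads, 4 active head groups,
    # and more expert K/V heads (24) than attention heads.
    gma = GroupedMoeAttention(
        embed_dim=512,
        num_heads=16,
        num_groups=4,
        num_experts=24,
        dropout=0.1,
        max_seq_len=1024,
        is_causal=True,
    )

    # SMA additionally routes query heads: num_query_experts in total,
    # num_active_query_heads selected per token.
    sma = SparseMoeAttention(
        embed_dim=512,
        num_heads=16,
        num_groups=4,
        num_experts=24,
        num_query_experts=24,
        num_active_query_heads=4,
    )

    x = torch.randn(2, 128, 512)  # (batch, seq_len, embed_dim)
    out = gma(x, x, x, mask=None)  # assumed self-attention call signature from the base attention class
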
rxnn/transformers/layers.py CHANGED
@@ -59,7 +59,7 @@ class ReactiveTransformerLayer(nn.Module):
          for param in self.memory_cross_attention.parameters():
              param.requires_grad_(is_trainable)

-     def moe_router_loss_(self):
+     def moe_router_loss(self):
          return self.ff.router_loss() if self.use_moe else None

      def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
@@ -135,7 +135,7 @@ class ClassicTransformerLayer(nn.Module):
          self.use_post_norm = use_post_norm
          self.use_moe = use_moe

-     def moe_router_loss_(self):
+     def moe_router_loss(self):
          return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)

      def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
rxnn/transformers/models.py CHANGED
@@ -37,7 +37,7 @@ class ReactiveTransformerBase(nn.Module):
          for i in range(self.num_own_layers):
              self.layers[i].trainable_cross_attention_(is_trainable)

-     def moe_router_loss_(self):
+     def moe_router_loss(self):
          return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
              self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()

@@ -123,7 +123,7 @@ class ClassicTransformerBase(nn.Module):
          self.layers = layers
          self.num_layers = len(layers) if layers else 0

-     def moe_router_loss_(self):
+     def moe_router_loss(self):
          return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
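
The layer and model hunks above only rename the router-loss accessor from moe_router_loss_ to moe_router_loss. A hedged sketch of how the renamed accessor might feed into a training objective; the cross-entropy loss, the batch layout, and the 0.01 scaling factor are illustrative assumptions, not part of the package.

    import torch.nn.functional as F

    def training_step(model, input_ids, labels, router_aux_scale: float = 0.01):
        logits = model(input_ids)
        ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        # moe_router_loss() averages the per-layer MoE router auxiliary losses (see the hunks above)
        aux_loss = model.moe_router_loss()
        return ce_loss + router_aux_scale * aux_loss
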
rxnn/transformers/moe.py CHANGED
@@ -77,29 +77,63 @@ class MoeFeedForward(nn.Module):
          return self.router.aux_loss

      def forward(self, x: torch.Tensor):
+         # orig_shape = x.shape
+         # x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+         #
+         # # Get routing weights and indices
+         # weights, indices = self.router(x)  # [batch*seq_len, top_k]
+         #
+         # # Create expert masks and combine it with masks
+         # mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
+         # weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+         #
+         # # Expert computation
+         # x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+         #
+         # # First linear layer
+         # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+         # h = self._activate(h)
+         # h = self.dropout(h)
+         #
+         # # Second linear layer (projection back to embed_dim)
+         # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+         #
+         # # Weighted sum of expert outputs
+         # out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+         #
+         # return out.view(*orig_shape)
          orig_shape = x.shape
          x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]

          # Get routing weights and indices
-         weights, indices = self.router(x)  # [batch*seq_len, top_k]
+         weights, indices = self.router(x)  # [batch*seq_len, top_k], [batch*seq_len, top_k]

-         # Create expert masks and combine it with masks
-         mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
-         weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+         # Flatten indices and weights
+         batch_size = x.size(0)
+         top_k = indices.size(1)
+         indices = indices.view(-1)  # [batch*seq_len * top_k]
+         weights = weights.view(-1, 1)  # [batch*seq_len * top_k, 1]

-         # Expert computation
-         x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+         # Select only the relevant experts for each token
+         selected_w1 = self.w1[indices]  # [batch*seq_len * top_k, embed_dim, hidden_dim]
+         selected_b1 = self.b1[indices]  # [batch*seq_len * top_k, hidden_dim]
+         selected_w2 = self.w2[indices]  # [batch*seq_len * top_k, hidden_dim, embed_dim]
+         selected_b2 = self.b2[indices]  # [batch*seq_len * top_k, embed_dim]

-         # First linear layer
-         h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+         # Reshape x for batched computation
+         x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim)  # [batch*seq_len * top_k, embed_dim]
+
+         # Compute only the selected experts
+         h = torch.einsum('be, beh -> bh', x_expanded, selected_w1) + selected_b1
          h = self._activate(h)
          h = self.dropout(h)

-         # Second linear layer (projection back to embed_dim)
-         out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+         out = torch.einsum('bh, bhe -> be', h, selected_w2) + selected_b2

-         # Weighted sum of expert outputs
-         out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+         # Reshape back and apply weights
+         out = out.view(batch_size, top_k, -1)  # [batch*seq_len, top_k, embed_dim]
+         weights = weights.view(batch_size, top_k, 1)  # [batch*seq_len, top_k, 1]
+         out = (out * weights).sum(dim=1)  # Weighted sum over top_k experts

          return out.view(*orig_shape)
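
The rewritten MoeFeedForward.forward above switches from a dense formulation (every expert applied to every token, then masked out) to a sparse one: the expert weight tensors are indexed with the router's top-k indices, so only the selected experts are computed per token. A standalone sketch of that dispatch pattern, with illustrative shapes and a random stand-in for the router output.

    import torch

    num_tokens, embed_dim, hidden_dim, num_experts, top_k = 6, 8, 16, 4, 2

    x = torch.randn(num_tokens, embed_dim)
    w1 = torch.randn(num_experts, embed_dim, hidden_dim)
    b1 = torch.zeros(num_experts, hidden_dim)
    w2 = torch.randn(num_experts, hidden_dim, embed_dim)
    b2 = torch.zeros(num_experts, embed_dim)

    # Stand-in for MoeRouter output: per-token expert indices and normalized weights
    weights = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)  # [tokens, top_k]
    indices = torch.randint(0, num_experts, (num_tokens, top_k))     # [tokens, top_k]

    flat_idx = indices.view(-1)                                      # [tokens * top_k]
    x_exp = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, embed_dim)   # [tokens * top_k, embed_dim]

    # Gather only the selected experts' weights and contract per token
    h = torch.einsum('be,beh->bh', x_exp, w1[flat_idx]) + b1[flat_idx]
    h = torch.relu(h)
    out = torch.einsum('bh,bhe->be', h, w2[flat_idx]) + b2[flat_idx]

    # Weighted mix of the top-k expert outputs per token
    out = out.view(num_tokens, top_k, embed_dim)
    out = (out * weights.unsqueeze(-1)).sum(dim=1)
    print(out.shape)  # torch.Size([6, 8])
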
 
rxnn/transformers/sampler.py CHANGED
@@ -1,13 +1,15 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from typing import Iterator
+ from typing import Iterator, Union
+ from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
+

  def sample(
-     logits: torch.Tensor,
-     temperature: float = 1.0,
-     top_k: int = None,
-     top_p: float = None,
+         logits: torch.Tensor,
+         temperature: float = 1.0,
+         top_k: int = None,
+         top_p: float = None,
  ) -> torch.Tensor:
      if temperature <= 0:
          raise ValueError("Temperature must be > 0")
@@ -45,6 +47,7 @@ def sample(
      # Sample from distribution
      return torch.multinomial(probs, num_samples=1)

+
  class Sampler:
      def __init__(self, model: nn.Module, device: torch.device, end_token_id: int):
          self.model = model.to(device)
@@ -52,12 +55,12 @@ class Sampler:
          self.end_token_id = end_token_id

      def _generate_token(
-         self,
-         input_ids: torch.Tensor,
-         temperature: float,
-         top_k: int,
-         top_p: float ,
-         attention_mask: torch.Tensor,
+             self,
+             input_ids: torch.Tensor,
+             temperature: float,
+             top_k: int,
+             top_p: float,
+             attention_mask: torch.Tensor,
      ) -> tuple[int, torch.Tensor, torch.Tensor]:
          # Forward pass to get next token logits
          outputs = self.model(input_ids, attention_mask=attention_mask)
@@ -82,14 +85,14 @@ class Sampler:
          )

      def __call__(
-         self,
-         initial_tokens: torch.Tensor,
-         temperature: float = 1.0,
-         top_k: int = None,
-         top_p: float = None,
-         max_seq_len: int = 50,
-         attention_mask: torch.Tensor = None,
-         no_grad: bool = True,
+             self,
+             initial_tokens: torch.Tensor,
+             temperature: float = 1.0,
+             top_k: int = None,
+             top_p: float = None,
+             max_seq_len: int = 50,
+             attention_mask: torch.Tensor = None,
+             no_grad: bool = True,
      ) -> Iterator[int]:
          # Convert initial tokens to tensor and move to device
          input_ids = initial_tokens
@@ -97,13 +100,70 @@ class Sampler:
          if no_grad:
              with torch.no_grad():
                  for _ in range(max_seq_len):
-                     next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
+                     next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                                                                                  attention_mask)
                      yield next_token
                      if next_token == self.end_token_id:
                          break
          else:
              for _ in range(max_seq_len):
-                 next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
+                 next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                                                                              attention_mask)
                  yield next_token
                  if next_token == self.end_token_id:
-                     break
+                     break
+
+
+ class SampleDecoder:
+     def __init__(self, sampler: Sampler, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
+         self.sampler = sampler
+         self.tokenizer = tokenizer
+         self.device = self.sampler.device
+
+     def tokenize_input(self, text: str):
+         tokenized = self.tokenizer(
+             text,
+             max_length=256,
+             truncation=True,
+             padding=False,
+             return_tensors='pt',
+             return_attention_mask=True
+         )
+         tokenized['input_ids'] = tokenized['input_ids'][:, :-1].to(self.device)
+         tokenized['attention_mask'] = tokenized['attention_mask'][:, :-1].to(self.device)
+         del tokenized['token_type_ids']
+         return tokenized
+
+     def ids_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         tokenized = self.tokenize_input(text)
+         return self.sampler(
+             tokenized['input_ids'],
+             temperature=temperature,
+             top_p=top_p,
+             max_seq_len=max_seq_len,
+             attention_mask=tokenized['attention_mask']
+         )
+
+     def txt_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         return map(
+             lambda x: self.tokenizer.decode([x]).replace('Ċ', '\n').replace('Ġ', ' '),
+             self.ids_iter(text, temperature, top_p, max_seq_len)
+         )
+
+     def txt(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         return text + ''.join(self.txt_iter(text, temperature, top_p, max_seq_len))
+
+     def print_stream(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+         print(text, end='')
+         resp = text
+         for txt_token in self.txt_iter(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len):
+             print(txt_token, end='')
+             resp += txt_token
+         return resp
+
+     def __call__(self, text: str, print_stream: bool = False, temperature: float = 0.1, top_p: float = 0.9,
+                  max_seq_len=256):
+         if print_stream:
+             return self.print_stream(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
+         else:
+             return self.txt(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
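
SampleDecoder is new in this release: it wraps a Sampler together with a Hugging Face tokenizer, so plain text goes in and generated text comes out, either streamed token by token or returned as a single string. A hedged setup sketch; the model and tokenizer are placeholders that must come from the user's own setup, and using the tokenizer's EOS id as end_token_id is an assumption.

    import torch
    from typing import Union
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
    from rxnn.transformers.sampler import Sampler, SampleDecoder

    def build_decoder(model: torch.nn.Module,
                      tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
                      device: torch.device) -> SampleDecoder:
        # end_token_id is assumed to be the tokenizer's EOS id; adjust to your vocabulary
        sampler = Sampler(model, device, end_token_id=tokenizer.eos_token_id)
        return SampleDecoder(sampler, tokenizer)

    # decoder = build_decoder(my_model, my_tokenizer, torch.device('cuda'))
    # decoder('Reactive neural networks are', print_stream=True, temperature=0.7, top_p=0.9, max_seq_len=128)
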
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.12
+ Version: 0.1.14
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/experimental/attention.py,sha256=M85p_GFU0fbUjfUhXdcwIGW-amrdzwKpU8qSABr7brQ,5634
+ rxnn/experimental/attention.py,sha256=HahcWU37FTfW8kwSTW8z_l7EtAVkJgvDDxLU8k3miHo,17101
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
  rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -16,14 +16,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
- rxnn/transformers/layers.py,sha256=xMocHzdSu7hcC_mPE_aG3-LQg2RXgunKSxcgNXYnOeo,5631
+ rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
- rxnn/transformers/models.py,sha256=PVhiTTSQ7VTDVdOcyRf-xGNvj6oOa_2fUV2mfthcE0Y,7171
- rxnn/transformers/moe.py,sha256=v21HDEhkDr10--If0P-XBjT5C7IlQJo0wGQlpDnVWEA,5020
+ rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
+ rxnn/transformers/moe.py,sha256=fFPTRcctCSc9OwHd0PhNb0nwHgNJY7dXfUtGreXtaho,6720
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
- rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
+ rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
- rxnn-0.1.12.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.1.12.dist-info/METADATA,sha256=mdoZLApjlSpC6GnprzoPuVpVhHpmVDejSjJABq_HKbk,14629
- rxnn-0.1.12.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
- rxnn-0.1.12.dist-info/RECORD,,
+ rxnn-0.1.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.1.14.dist-info/METADATA,sha256=YQDNMaHDrfVdOk44qEUczgLaNcrXApoqVmNX50yQDdM,14629
+ rxnn-0.1.14.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+ rxnn-0.1.14.dist-info/RECORD,,