rxnn 0.1.83__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/.DS_Store +0 -0
- rxnn/experimental/attention.py +5 -0
- rxnn/memory/attention.py +42 -0
- rxnn/memory/stm.py +53 -12
- rxnn/rxt/models.py +71 -0
- rxnn/training/bml.py +2 -59
- rxnn/training/callbacks.py +302 -39
- rxnn/training/dataset.py +344 -1
- rxnn/training/models.py +142 -0
- rxnn/training/mrl.py +808 -0
- rxnn/training/reward.py +111 -0
- rxnn/training/rl.py +69 -0
- rxnn/training/utils.py +148 -0
- rxnn/transformers/attention.py +10 -0
- rxnn/transformers/layers.py +6 -0
- rxnn/transformers/models.py +16 -4
- rxnn/transformers/positional.py +7 -0
- rxnn/transformers/sampler.py +283 -9
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/METADATA +11 -9
- rxnn-0.2.0.dist-info/RECORD +38 -0
- rxnn-0.1.83.dist-info/RECORD +0 -31
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/LICENSE +0 -0
- {rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/WHEEL +0 -0
rxnn/transformers/sampler.py
CHANGED
@@ -1,15 +1,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Iterator, Union
+from typing import Iterator, Union, Optional
 from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
 
 
 def sample(
     logits: torch.Tensor,
     temperature: float = 1.0,
-    top_k: int = None,
-    top_p: float = None,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
 ) -> torch.Tensor:
     if temperature <= 0:
         raise ValueError("Temperature must be > 0")
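The switch to `Optional[...]` above is more than cosmetic: strict type checkers reject implicit-`Optional` defaults (recent mypy enables `no_implicit_optional` by default). A minimal sketch of the difference, using illustrative function names that are not part of the package:

    from typing import Optional

    def sample_old(top_k: int = None) -> None:  # flagged by mypy's no_implicit_optional
        ...

    def sample_new(top_k: Optional[int] = None) -> None:  # explicit, accepted
        ...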
@@ -88,8 +88,8 @@ class Sampler:
         self,
         initial_tokens: torch.Tensor,
         temperature: float = 1.0,
-        top_k: int = None,
-        top_p: float = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
         max_seq_len: int = 50,
         attention_mask: torch.Tensor = None,
         no_grad: bool = True,
@@ -120,10 +120,10 @@ class SampleDecoder:
         self.tokenizer = tokenizer
         self.device = self.sampler.device
 
-    def tokenize_input(self, text: str):
+    def tokenize_input(self, text: str, max_seq_len: int = 256):
         tokenized = self.tokenizer(
             text,
-            max_length=
+            max_length=max_seq_len,
             truncation=True,
             padding=False,
             return_tensors='pt',
@@ -135,7 +135,7 @@ class SampleDecoder:
         return tokenized
 
     def ids_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
-        tokenized = self.tokenize_input(text)
+        tokenized = self.tokenize_input(text, max_seq_len=max_seq_len)
         return self.sampler(
             tokenized['input_ids'],
             temperature=temperature,
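Together, the two hunks above thread `max_seq_len` from `ids_iter` down to the tokenizer call, so the truncation length is now caller-controlled instead of a hard-coded literal. A short sketch of the resulting behavior, assuming `decoder` is an already-constructed `SampleDecoder`:

    # `decoder` is assumed to be an existing SampleDecoder instance.
    tokenized = decoder.tokenize_input('a very long prompt ...', max_seq_len=64)
    # The prompt is truncated to at most 64 tokens before sampling starts.
    assert tokenized['input_ids'].shape[1] <= 64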
@@ -166,4 +166,278 @@ class SampleDecoder:
         if print_stream:
             return self.print_stream(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
         else:
-            return self.txt(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
+            return self.txt(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
+
+class InteractionSampler(SampleDecoder):
+    def __init__(self, sampler: Sampler, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
+        super(InteractionSampler, self).__init__(sampler, tokenizer)
+
+    def txt(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len: int = 256, special_token_spaces: bool = True):
+        txt = '[Q]' + text + '[A]'
+        start_txt = '[Q] ' + text + ' [A] ' if special_token_spaces else txt
+        return start_txt + ''.join(self.txt_iter(txt, temperature, top_p, max_seq_len))
+
+    def print_stream(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len: int = 256, special_token_spaces: bool = True):
+        txt = '[Q]' + text + '[A]'
+        start_txt = '[Q] ' + text + ' [A] ' if special_token_spaces else txt
+        print(start_txt, end='')
+        resp = start_txt
+        for txt_token in self.txt_iter(txt, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len):
+            print(txt_token, end='')
+            resp += txt_token
+        return resp
+
+    def __call__(self, text: str, print_stream: bool = False, temperature: float = 0.1, top_p: float = 0.9,
+                 max_seq_len: int = 256, special_token_spaces: bool = True):
+        if print_stream:
+            return self.print_stream(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len, special_token_spaces=special_token_spaces)
+        else:
+            return self.txt(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len, special_token_spaces=special_token_spaces)
+
+
+def sample_batch(
+    logits: torch.Tensor,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    top_p: Optional[float] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Returns (sampled_tokens, log_probs)"""
+    if temperature <= 0:
+        raise ValueError("Temperature must be > 0")
+
+    # Store original dtype and device
+    original_dtype = logits.dtype
+    device = logits.device
+
+    # Convert to float32 for numerical stability
+    logits = logits.float()
+
+    # Apply temperature
+    logits = logits / temperature
+
+    # Apply top-k filtering
+    if top_k is not None and top_k > 0:
+        topk_values, _ = torch.topk(logits, top_k, dim=-1)
+        min_topk = topk_values[:, -1].unsqueeze(-1)
+        logits = torch.where(logits < min_topk, torch.tensor(-float('inf'), device=device), logits)
+
+    # Apply top-p filtering
+    if top_p is not None and 0 < top_p <= 1.0:
+        # Sort logits in descending order
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+
+        # Calculate cumulative probabilities
+        sorted_probs = F.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+        # Create mask to filter tokens
+        sorted_mask = cumulative_probs <= top_p
+        sorted_mask[..., 0] = True  # Ensure at least one token is kept
+
+        # Create mask for original indices
+        mask = torch.zeros_like(logits, dtype=torch.bool)
+        mask.scatter_(dim=-1, index=sorted_indices, src=sorted_mask)
+
+        # Apply mask
+        logits = torch.where(mask, logits, torch.tensor(-float('inf'), device=device))
+
+    # At this point ensure at least one token is available per batch element
+    alive = torch.isfinite(logits).any(dim=-1)
+    if not alive.all():
+        # Force keep the most probable token for dead rows
+        max_indices = logits.argmax(dim=-1)
+        logits[~alive] = -float('inf')
+        logits.scatter_(dim=-1, index=max_indices.unsqueeze(-1), value=0)
+
+    # Calculate log probabilities
+    log_probs = F.log_softmax(logits, dim=-1)
+
+    # Convert to probabilities
+    probs = torch.exp(log_probs)
+
+    # Ensure numerical stability for sampling
+    probs = probs.clamp(min=1e-12)
+
+    # Sample tokens
+    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+    # Gather log probabilities
+    selected_log_probs = log_probs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)
+
+    # Convert back to original dtype
+    return next_tokens.to(original_dtype), selected_log_probs.to(torch.float32)
+
+
+class BatchSampler:
+    def __init__(self, model: nn.Module, device: torch.device, end_token_id: int):
+        self.model = model.to(device)
+        self.device = device
+        self.end_token_id = end_token_id
+
+    def __call__(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_gen_len: int = 256,
+        no_grad: bool = True,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        batch_size, max_seq_len = input_ids.shape
+        initial_lens = attention_mask.sum(dim=1)
+        current_lens = initial_lens.clone()
+        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
+        log_probs = torch.zeros((batch_size, max_gen_len), dtype=torch.float32, device=self.device)
+        working_ids = input_ids.clone()
+        working_mask = attention_mask.clone()
+
+        for step in range(max_gen_len):
+            active = (~finished) & (current_lens < max_seq_len)
+            if not active.any():
+                break
+
+            active_indices = active.nonzero(as_tuple=True)[0]
+            active_current_lens = current_lens[active]
+            max_len = active_current_lens.max().item()
+
+            with torch.set_grad_enabled(not no_grad):
+                # Slice input and mask up to the current max length among active sequences
+                inputs = working_ids[active, :max_len]
+                masks = working_mask[active, :max_len]
+                logits = self.model(inputs, attention_mask=masks)
+
+                # Get the last valid token index for each active sequence
+                indices = (active_current_lens - 1).to(self.device)
+                last_logits = logits[torch.arange(len(active_indices), device=self.device), indices]
+
+                # Sample next tokens and log probs
+                next_tokens, step_log_probs = sample_batch(
+                    last_logits, temperature=temperature, top_k=top_k, top_p=top_p
+                )
+
+            # Update working tensors
+            for i, idx in enumerate(active_indices):
+                if current_lens[idx] >= max_seq_len:
+                    continue
+                pos = current_lens[idx].item()
+                working_ids[idx, pos] = next_tokens[i]
+                working_mask[idx, pos] = 1
+                log_probs[idx, step] = step_log_probs[i]
+                current_lens[idx] += 1
+                if next_tokens[i] == self.end_token_id:
+                    finished[idx] = True
+
+        # Extract generated tokens
+        generated_ids = torch.zeros((batch_size, max_gen_len), dtype=torch.long, device=self.device)
+        generated_mask = torch.zeros((batch_size, max_gen_len), dtype=torch.bool, device=self.device)
+        for i in range(batch_size):
+            start = initial_lens[i].item()
+            end = current_lens[i].item()
+            gen_len = min(end - start, max_gen_len)
+            if gen_len > 0:
+                generated_ids[i, :gen_len] = working_ids[i, start:end]
+                generated_mask[i, :gen_len] = working_mask[i, start:end]
+
+        return generated_ids, generated_mask, log_probs
+
+
+class BatchSampleDecoder:
+    def __init__(
+        self,
+        sampler: BatchSampler,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        bos_token: str = '[BOS]',
+        query_token: str = '[Q]',
+    ):
+        self.sampler = sampler
+        self.tokenizer = tokenizer
+        self.device = self.sampler.device
+        self.bos_token = bos_token
+        self.query_token = query_token
+
+    def tokenize_batch(self, texts: list[str], max_seq_len: int = 256):
+        tokenized = self.tokenizer(
+            [f'{self.bos_token}{self.query_token}{txt}' for txt in texts],
+            max_length=max_seq_len,
+            truncation=True,
+            padding='max_length',
+            return_tensors='pt',
+            return_attention_mask=True,
+            add_special_tokens=False
+        )
+        return {
+            'input_ids': tokenized['input_ids'].to(self.device),
+            'attention_mask': tokenized['attention_mask'].to(self.device)
+        }
+
+    def generate(
+        self,
+        texts: list[str],
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        max_seq_len: int = 256,
+        no_grad: bool = True,
+    ) -> list[str]:
+        tokenized = self.tokenize_batch(texts, max_seq_len)
+        generated_ids, _, _ = self.sampler(
+            input_ids=tokenized['input_ids'],
+            attention_mask=tokenized['attention_mask'],
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_gen_len=max_seq_len,
+            no_grad=no_grad,
+        )
+
+        decoded = []
+        for seq in generated_ids:
+            # Trim after end token
+            end_pos = (seq == self.sampler.end_token_id).nonzero()
+            if end_pos.size(0) > 0:
+                seq = seq[:end_pos[0] + 1]
+            decoded.append(self.tokenizer.decode(seq).replace('Ċ', '\n').replace('Ġ', ' '))
+
+        return decoded
+
+    def generate_with_log_probs(
+        self,
+        texts: list[str],
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        max_seq_len: int = 256,
+        no_grad: bool = True,
+    ) -> tuple[list[str], torch.Tensor]:
+        tokenized = self.tokenize_batch(texts, max_seq_len)
+        generated_ids, _, log_probs = self.sampler(
+            input_ids=tokenized['input_ids'],
+            attention_mask=tokenized['attention_mask'],
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_gen_len=max_seq_len,
+            no_grad=no_grad,
+        )
+
+        decoded = []
+        for i, seq in enumerate(generated_ids):
+            # Trim after end token
+            end_pos = (seq == self.sampler.end_token_id).nonzero()
+            if end_pos.size(0) > 0:
+                seq = seq[:end_pos[0] + 1]
+            decoded.append(self.tokenizer.decode(seq).replace('Ċ', '\n').replace('Ġ', ' '))
+
+        return decoded, log_probs
+
+    def __call__(
+        self,
+        texts: list[str],
+        temperature: float = 1.0,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        max_seq_len: int = 256,
+        no_grad: bool = True,
+    ) -> list[str]:
+        return self.generate(texts, temperature, top_p, top_k, max_seq_len, no_grad)
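The block above adds a complete batch-generation path: `sample_batch` filters and samples per-row logits while returning their log-probabilities, `BatchSampler` runs step-wise decoding over a padded batch, and `BatchSampleDecoder` wraps tokenization and decoding. A possible usage sketch, under stated assumptions: `model` and `tokenizer` stand in for a real RxNN decoder and its HuggingFace tokenizer, and the `[A]` end-token choice is illustrative, not taken from the diff:

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Assumption: '[A]' marks the end of an answer in this tokenizer's vocab.
    end_token_id = tokenizer.convert_tokens_to_ids('[A]')

    sampler = BatchSampler(model, device, end_token_id=end_token_id)
    decoder = BatchSampleDecoder(sampler, tokenizer)

    # Plain generation: one decoded string per input text.
    answers = decoder(
        ['What is a reactive transformer?', 'What is short-term memory?'],
        temperature=0.7, top_p=0.9, max_seq_len=256,
    )

    # Generation with per-step log-probabilities, presumably for the RL-style
    # trainers added elsewhere in this release (rxnn/training/mrl.py, rl.py);
    # no_grad=False leaves autograd enabled inside the sampling loop.
    answers, log_probs = decoder.generate_with_log_probs(
        ['What is MRL?'], temperature=1.0, top_k=50, no_grad=False,
    )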
{rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.83
+Version: 0.2.0
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -23,8 +23,10 @@ Project-URL: Homepage, https://rxai.dev/rxnn
 Project-URL: Repository, https://github.com/RxAI-dev/rxnn/python
 Description-Content-Type: text/markdown
 
-<
-<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/
+<span>
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai_v2.png" width="400" />
+<img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn_v2.png" width="400" />
+</span>
 
 # Reactive AI - RxNN
 ## Reactive Neural Networks Platform
@@ -61,8 +63,8 @@ We are working on three new reactive architectures, that progressively advance f
 
 Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
 released with next versions of **RxNN** framework:
-- 0.1.x: Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
-- 0.2.x: Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
+- 0.1.x (Released): Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
+- 0.2.x (Released): Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
 - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
 Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
 - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
@@ -126,7 +128,7 @@ Submodules:
 - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
 - `rxnn.transformer.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
 - `rxnn.transformer.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
-- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler` & `
+- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler`, `SampleDecoder`, `BatchSampler` & `BatchSampleDecoder`
 
 In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
 to be compatible with HuggingFace **JSON** config. In example:
@@ -211,7 +213,7 @@ include **Long-Term Memory**.
 
 The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.
 
-
+> 0.2.x Memory modules docs in progress - will be released soon
 
 #### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.
@@ -233,9 +235,9 @@ Submodules:
 - `rxnn.training.callbacks` contain Trainer callbacks, for different kind of utils (more info below)
 - `rxnn.training.scheduler` includes learning rate scheduler for training
 - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
-- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
+- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
 - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
-- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x
+- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)
 
 ##### Base Model Learning
 Docs in progress
rxnn-0.2.0.dist-info/RECORD
ADDED

@@ -0,0 +1,38 @@
+rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
+rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-cs,23445
+rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
+rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
+rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
+rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
+rxnn/memory/stm.py,sha256=AoBgtmAKeAQ7U1OD3Zb2oObo27celvWyfJSUQjYw4Jc,4081
+rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rxnn/rxt/models.py,sha256=zNrf6mn-s2vJyauHwNgYm_e-gFI1clmXp_JyCKGQD3E,12083
+rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
+rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
+rxnn/training/callbacks.py,sha256=aqi8CfXUWnjMDbELYC5BPBbYyq0YiMicyVaTIr778DY,35053
+rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
+rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
+rxnn/training/mrl.py,sha256=KcGvBWlBcFJ5GSwd4lx3pUXKlcyeNgJYPZAk3DRMH48,39179
+rxnn/training/reward.py,sha256=bjm8ya-HFIRA56JvQgnhtotKEpt8yw6yacVTV_SDpm4,5564
+rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
+rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
+rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
+rxnn/training/utils.py,sha256=c-6aBaLnKeGfMW6Sp29z3FPLj5hdV3pyGJ2rZMcKs2s,5775
+rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
+rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
+rxnn/transformers/layers.py,sha256=MbOIX4PurbTbYxcXSavyFsNpTHCm26K_Ssk_VUCzKIE,7469
+rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
+rxnn/transformers/models.py,sha256=VvP7r7E6tj7OWsYKlJLCy2vsQ3xSSnlNez6QxR-jBAA,8276
+rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
+rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
+rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
+rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
+rxnn-0.2.0.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.0.dist-info/METADATA,sha256=EPKUh8u9f4ce4h9J4MO8wiLslO04Wd9VsSSlgrOqxUU,25959
+rxnn-0.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.0.dist-info/RECORD,,
rxnn-0.1.83.dist-info/RECORD
DELETED
@@ -1,31 +0,0 @@
-rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=bpZQiRXdQ8gJPwYRp3LBr2oELmrysB6-SWiD2F7UQrk,23127
-rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
-rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
-rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
-rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=iUlSvdXrD1NVzZFmdC55qp4_3xoJj31FC40BGgYlf4Q,8763
-rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
-rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
-rxnn/training/callbacks.py,sha256=xcU3W6_OsIEDTFTbN7S3uIWyGqLulbUWZMpW0aIXmF4,22699
-rxnn/training/dataset.py,sha256=XEDmOwD8v0c9u0QCk7I3xZShKaMtBDwYlfK1ofu6A1E,35789
-rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
-rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
-rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
-rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=xbnn3FTNZFhaqq9A0XEM12ie_WL_58pPeq0qFXIgve0,7656
-rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
-rxnn/transformers/positional.py,sha256=ge-kaS6WnWnPGnWVp25ZK5bVkmhBUNCaELaN2rN_fSY,4097
-rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
-rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.1.83.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.83.dist-info/METADATA,sha256=AhGTqWM9mvBzDRWliKeTRySDAL2cXXTYefRL_HGJN_Q,25930
-rxnn-0.1.83.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.1.83.dist-info/RECORD,,
{rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/LICENSE
File without changes

{rxnn-0.1.83.dist-info → rxnn-0.2.0.dist-info}/WHEEL
File without changes