rxnn 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/utils.py +1 -1
- rxnn/transformers/layers.py +3 -1
- rxnn/transformers/models.py +2 -1
- rxnn/transformers/sampler.py +5 -0
- {rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/METADATA +1 -1
- {rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/RECORD +8 -8
- {rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/LICENSE +0 -0
- {rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/WHEEL +0 -0
rxnn/training/utils.py
CHANGED
```diff
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
     # Build combined_ids using vectorized where
     combined_ids = torch.where(
         query_mask,
-        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1).to(torch.int64)),
         torch.where(
             answer_mask,
             answer['input_ids'].gather(1, answer_pos),
```
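The change casts the index tensor to `int64` before it is passed to `gather`, since `torch.gather` only accepts int64 (LongTensor) indices. A minimal, self-contained sketch of that requirement, using made-up tensors rather than anything from the package:

```python
import torch

# Hypothetical stand-ins for the tensors used in smart_concat
input_ids = torch.tensor([[10, 11, 12, 13]])
positions = torch.arange(4, dtype=torch.int32).unsqueeze(0)  # int32 on purpose
query_lens = torch.tensor([3], dtype=torch.int32)

idx = torch.minimum(positions, query_lens.unsqueeze(1) - 1)  # dtype is still int32
# input_ids.gather(1, idx) would fail here: gather expects an int64 index tensor
gathered = input_ids.gather(1, idx.to(torch.int64))          # the cast makes it valid
print(gathered)  # tensor([[10, 11, 12, 12]])
```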
rxnn/transformers/layers.py
CHANGED
```diff
@@ -87,11 +87,12 @@ class ReactiveTransformerLayer(nn.Module):
         else:
             return None
 
-    def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None, logs: bool = False) -> torch.Tensor:
         # First step, self-attention
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
+        if logs: print('att', x.size())
         x = self.attention(x, x, x, mask=mask)
         x = residual + x
         if self.use_post_norm:
@@ -100,6 +101,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
+        if logs: print('mxatt', x.size())
         x = self.memory_cross_attention(x, stm, stm)
         x = residual + x
         if self.use_post_norm:
```
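Both prints are gated behind the new `logs` flag, which defaults to `False`, so existing callers keep the previous behaviour. A toy sketch of the same gated-print pattern; `ToyLayer` is a made-up module, not the rxnn layer itself:

```python
import torch
from torch import nn

class ToyLayer(nn.Module):
    """Illustrates an optional `logs` keyword that traces tensor shapes in forward."""
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, logs: bool = False) -> torch.Tensor:
        residual = x
        x = self.norm(x)
        if logs: print('att', x.size())  # same style of gated print as in the diff
        x = self.proj(x)
        return residual + x

layer = ToyLayer(16)
out = layer(torch.randn(2, 8, 16), logs=True)  # prints: att torch.Size([2, 8, 16])
```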
rxnn/transformers/models.py
CHANGED
```diff
@@ -83,7 +83,8 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
             # expand layer STM to batch size, if it's not in batch mode
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-            x = self.layers[i](x, layer_stm, mask=mask)
+            print(f'Layer {i} before: {x.size()}')
+            x = self.layers[i](x, layer_stm, mask=mask, logs=True)
         return self.head(x)
 
 
```
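The decoder now prints each layer's input shape and forwards `logs=True` so the layer prints its internal shapes as well. A stand-alone illustration of that per-layer tracing loop, with plain `nn.Linear` stand-ins instead of the real reactive layers (which also take `layer_stm` and `mask`):

```python
import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])  # stand-in layers
x = torch.randn(2, 8, 16)
for i, layer in enumerate(layers):
    print(f'Layer {i} before: {x.size()}')  # shape trace before each layer call
    x = layer(x)  # the real call is self.layers[i](x, layer_stm, mask=mask, logs=True)
```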
rxnn/transformers/sampler.py
CHANGED
```diff
@@ -293,6 +293,7 @@ class BatchSampler:
         working_mask = attention_mask.clone()
 
         for step in range(max_gen_len):
+            print('Sampler step', step)
             active = (~finished) & (current_lens < max_seq_len)
             if not active.any():
                 break
@@ -302,9 +303,13 @@ class BatchSampler:
             max_len = active_current_lens.max().item()
 
             with torch.set_grad_enabled(not no_grad):
+                print('Active size', active.size())
                 # Slice input and mask up to the current max length among active sequences
                 inputs = working_ids[active, :max_len]
                 masks = working_mask[active, :max_len]
+                print('Model input sizes')
+                print(inputs.size())
+                print(masks.size())
                 logits = self.model(inputs, attention_mask=masks)
 
                 # Get the last valid token index for each active sequence
```
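The added prints trace the generation loop: the step index, the size of the `active` mask, and the shapes of the sliced inputs and masks fed to the model. A self-contained sketch, with made-up shapes rather than the sampler's real state, of what those quantities look like:

```python
import torch

batch, max_seq_len = 4, 10
working_ids = torch.zeros(batch, max_seq_len, dtype=torch.long)
working_mask = torch.ones(batch, max_seq_len, dtype=torch.long)
current_lens = torch.tensor([3, 5, 2, 4])
finished = torch.tensor([False, True, False, False])

# Sequences that are neither finished nor at the length limit keep generating
active = (~finished) & (current_lens < max_seq_len)
max_len = current_lens[active].max().item()

# Slice rows of still-active sequences up to the longest active length
inputs = working_ids[active, :max_len]
masks = working_mask[active, :max_len]
print('Active size', active.size())  # torch.Size([4]), one flag per batch element
print(inputs.size(), masks.size())   # torch.Size([3, 4]) for both
```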
{rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/RECORD
CHANGED
```diff
@@ -21,18 +21,18 @@ rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
 rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=u_48ocn_JtdJ905hLDIxO12qQgday_6GuAyplBKKXz0,7579
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=kZIZxH4MbgQjQeeKIXGi4tuzXWAfmXjgUCfEzyyeOxA,8338
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
-rxnn/transformers/sampler.py,sha256=
+rxnn/transformers/sampler.py,sha256=GK5aYALWW-DegI7sIjiO531EWAfKl71k_6CbVuYXwfw,17637
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.8.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.8.dist-info/METADATA,sha256=mOklYF8wm6kn9R70EWls0IYatIxEyKrmpNYspsnMQoM,25959
+rxnn-0.2.8.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.8.dist-info/RECORD,,
```
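For reference, each RECORD line has the form `path,sha256=<digest>,<size>`, where the digest is the unpadded urlsafe base64 encoding of the file's SHA-256 hash (the standard wheel RECORD format). A small sketch of how such an entry can be recomputed; `record_entry` is a hypothetical helper, not part of rxnn:

```python
import base64
import hashlib
import pathlib

def record_entry(path: str) -> str:
    """Build a wheel RECORD line (path,sha256=<digest>,<size>) for a file on disk."""
    data = pathlib.Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# e.g. record_entry('rxnn/transformers/layers.py') run inside an unpacked 0.2.8 wheel
# should reproduce the corresponding line above
```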
{rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/LICENSE
File without changes
{rxnn-0.2.7.dist-info → rxnn-0.2.8.dist-info}/WHEEL
File without changes