x-transformers 2.4.5__py3-none-any.whl → 2.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ from functools import partial
6
6
  from random import randrange, uniform
7
7
 
8
8
  import torch
9
- from torch import nn, cat, randperm
9
+ from torch import nn, cat, tensor, randperm
10
10
  from torch.nn import LSTM, Module
11
11
 
12
12
  from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
76
76
 
77
77
  self.apply(self.init_)
78
78
 
79
+ def reset_(self):
80
+ for m in self.modules():
81
+ if hasattr(m, 'reset_parameters'):
82
+ m.reset_parameters()
83
+
84
+ self.apply(self.init_)
85
+
79
86
  @torch.no_grad()
80
87
  def init_(self, m):
81
88
  if isinstance(m, nn.Linear):
@@ -138,7 +145,8 @@ class UniversalPretrainWrapper(Module):
138
145
  num_reset = 20,
139
146
  batch_size = 32,
140
147
  seq_len = 512,
141
- seed_length = 8
148
+ seed_length = 8,
149
+ reset_turing_machine_every = 0
142
150
  ):
143
151
  super().__init__()
144
152
 
@@ -157,6 +165,8 @@ class UniversalPretrainWrapper(Module):
157
165
  max_seq_len = seq_len
158
166
  )
159
167
 
168
+ self.reset_turing_machine_every = reset_turing_machine_every
169
+
160
170
  self.seq_len = seq_len
161
171
  self.data_generator = data_generator
162
172
 
@@ -176,6 +186,7 @@ class UniversalPretrainWrapper(Module):
176
186
  init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
177
187
 
178
188
  self.register_buffer('synth_data_buffer', init_data_buffer)
189
+ self.register_buffer('step', tensor(0))
179
190
 
180
191
  @property
181
192
  def device(self):
@@ -210,6 +221,13 @@ class UniversalPretrainWrapper(Module):
210
221
  seed = seeds
211
222
  )
212
223
 
224
+ self.step.add_(1)
225
+
226
+ # maybe reset turing machine
227
+
228
+ if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
229
+ self.data_generator.reset_()
230
+
213
231
  # reset
214
232
 
215
233
  if self.num_reset > 0:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.4.5
3
+ Version: 2.4.6
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -8,11 +8,11 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
- x_transformers/up_wrapper.py,sha256=RW-0EmheBqJz7g4j-M37XxWMttz-QuPhmnRWMrD7_-0,6103
11
+ x_transformers/up_wrapper.py,sha256=mUTbuTPrN_Z1j2qZ0NGX96djioANNGVdfCdIlg8nCsM,6664
12
12
  x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
13
13
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
14
14
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
15
- x_transformers-2.4.5.dist-info/METADATA,sha256=J7sgvMp5pLgUWZ37qdDF4EBmECBqfgEZ0rSC7w12AAI,90223
16
- x_transformers-2.4.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- x_transformers-2.4.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
- x_transformers-2.4.5.dist-info/RECORD,,
15
+ x_transformers-2.4.6.dist-info/METADATA,sha256=RlLYhhqc-kLKG4PRDHqbPeVmac4oVQjnqAANxq_nXrg,90223
16
+ x_transformers-2.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ x_transformers-2.4.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
+ x_transformers-2.4.6.dist-info/RECORD,,