rbbt-dm 1.2.6 → 1.2.9

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.
Files changed (34)
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +57 -38
  10. data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
  11. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  12. data/lib/rbbt/vector/model/spaCy.rb +8 -14
  13. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  14. data/lib/rbbt/vector/model/torch.rb +37 -0
  15. data/lib/rbbt/vector/model/util.rb +18 -0
  16. data/lib/rbbt/vector/model.rb +100 -56
  17. data/python/rbbt_dm/__init__.py +48 -1
  18. data/python/rbbt_dm/atcold/__init__.py +0 -0
  19. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  20. data/python/rbbt_dm/atcold/spiral.py +27 -0
  21. data/python/rbbt_dm/huggingface.py +57 -26
  22. data/python/rbbt_dm/language_model.py +70 -0
  23. data/python/rbbt_dm/util.py +30 -0
  24. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  25. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  26. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  27. data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
  28. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  29. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  30. data/test/rbbt/vector/test_model.rb +25 -26
  31. data/test/test_helper.rb +13 -0
  32. metadata +26 -16
  33. data/lib/rbbt/tensorflow.rb +0 -43
  34. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
@@ -1,16 +1,25 @@
1
1
  require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
+ require 'rbbt/util/python'
4
+
5
+ RbbtPython.add_path Rbbt.python.find(:lib)
6
+ RbbtPython.init_rbbt
3
7
 
4
8
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
9
+ attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
6
10
  attr_accessor :features, :names, :labels, :factor_levels
7
- attr_accessor :model_options
11
+ attr_accessor :model, :model_options
8
12
 
9
13
  def extract_features(&block)
10
14
  @extract_features = block if block_given?
11
15
  @extract_features
12
16
  end
13
17
 
18
+ def init_model(&block)
19
+ @init_model = block if block_given?
20
+ @init_model
21
+ end
22
+
14
23
  def train_model(&block)
15
24
  @train_model = block if block_given?
16
25
  @train_model
@@ -21,13 +30,17 @@ class VectorModel
21
30
  @eval_model
22
31
  end
23
32
 
33
+ def init
34
+ @model ||= self.instance_exec &@init_model
35
+ end
36
+
24
37
  def post_process(&block)
25
38
  @post_process = block if block_given?
26
39
  @post_process
27
40
  end
28
41
 
29
42
 
30
- def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
43
+ def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
31
44
  TmpFile.with_file do |feature_file|
32
45
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
33
46
  Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
54
67
  end
55
68
  end
56
69
 
57
- def self.R_train(model_file, features, labels, code, names = nil, factor_levels = nil)
70
+ def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
58
71
  TmpFile.with_file do |feature_file|
59
72
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
60
73
  Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -82,13 +95,13 @@ for (c in names(features)){
82
95
  if (is.factor(features[[c]]))
83
96
  factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
84
97
  }
85
- rbbt.tsv.write("#{model_file}.factor_levels", factor_levels, names=c('Levels'), type='flat')
86
- save(model, file='#{model_file}')
98
+ rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
99
+ save(model, file='#{model_path}')
87
100
  EOF
88
101
  end
89
102
  end
90
103
 
91
- def self.R_eval(model_file, features, list, code, names = nil, factor_levels = nil)
104
+ def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
92
105
  TmpFile.with_file do |feature_file|
93
106
  if list
94
107
  Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
105
118
  #{ factor_levels.collect do |name,levels|
106
119
  "features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
107
120
  end * "\n" if factor_levels }
108
- load(file="#{model_file}");
121
+ load(file="#{model_path}");
109
122
  #{code}
110
123
  cat(paste(label, sep="\\n", collapse="\\n"));
111
124
  EOF
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
127
140
  instance_eval code, file
128
141
  end
129
142
 
130
- def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
143
+ def initialize(directory = nil, model_options = {})
131
144
  @directory = directory
145
+ @model_options = IndiferentHash.setup(model_options)
146
+
132
147
  if @directory
133
- FileUtils.mkdir_p @directory unless File.exists?(@directory)
148
+ FileUtils.mkdir_p @directory unless File.exist?(@directory)
149
+
150
+ @model_path = File.join(@directory, "model")
134
151
 
135
- @model_file = File.join(@directory, "model")
136
152
  @extract_features_file = File.join(@directory, "features")
137
- @train_model_file = File.join(@directory, "train_model")
138
- @eval_model_file = File.join(@directory, "eval_model")
139
- @post_process_file = File.join(@directory, "post_process")
140
- @train_model_file_R = File.join(@directory, "train_model.R")
141
- @eval_model_file_R = File.join(@directory, "eval_model.R")
142
- @post_process_file_R = File.join(@directory, "post_process.R")
143
- @names_file = File.join(@directory, "feature_names")
144
- @levels_file = File.join(@directory, "levels")
145
- @options_file = File.join(@directory, "options.json")
146
-
147
- if File.exists?(@options_file)
148
- @model_options = JSON.parse(Open.read(@options_file))
153
+ @init_model_path = File.join(@directory, "init_model")
154
+
155
+ @train_model_path = File.join(@directory, "train_model")
156
+ @train_model_path_R = File.join(@directory, "train_model.R")
157
+
158
+ @eval_model_path = File.join(@directory, "eval_model")
159
+ @eval_model_path_R = File.join(@directory, "eval_model.R")
160
+
161
+ @post_process_file = File.join(@directory, "post_process")
162
+ @post_process_file_R = File.join(@directory, "post_process.R")
163
+
164
+ @names_file = File.join(@directory, "feature_names")
165
+ @levels_file = File.join(@directory, "levels")
166
+ @options_file = File.join(@directory, "options.json")
167
+
168
+ if File.exist?(@options_file)
169
+ @model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
149
170
  IndiferentHash.setup(@model_options)
150
171
  end
151
172
  end
152
-
173
+
153
174
  if extract_features.nil?
154
- if @extract_features_file && File.exists?(@extract_features_file)
175
+ if @extract_features_file && File.exist?(@extract_features_file)
155
176
  @extract_features = __load_method @extract_features_file
156
177
  end
157
178
  else
158
179
  @extract_features = extract_features
159
180
  end
160
181
 
182
+ if init_model.nil?
183
+ if @init_model_path && File.exist?(@init_model_path)
184
+ @init_model = __load_method @init_model_path
185
+ end
186
+ else
187
+ @init_model = init_model
188
+ end
189
+
161
190
  if train_model.nil?
162
- if @train_model_file && File.exists?(@train_model_file)
163
- @train_model = __load_method @train_model_file
164
- elsif @train_model_file_R && File.exists?(@train_model_file_R)
165
- @train_model = Open.read(@train_model_file_R)
191
+ if @train_model_path && File.exist?(@train_model_path)
192
+ @train_model = __load_method @train_model_path
193
+ elsif @train_model_path_R && File.exist?(@train_model_path_R)
194
+ @train_model = Open.read(@train_model_path_R)
166
195
  end
167
196
  else
168
197
  @train_model = train_model
169
198
  end
170
199
 
171
200
  if eval_model.nil?
172
- if @eval_model_file && File.exists?(@eval_model_file)
173
- @eval_model = __load_method @eval_model_file
174
- elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
175
- @eval_model = Open.read(@eval_model_file_R)
201
+ if @eval_model_path && File.exist?(@eval_model_path)
202
+ @eval_model = __load_method @eval_model_path
203
+ elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
204
+ @eval_model = Open.read(@eval_model_path_R)
176
205
  end
177
206
  else
178
207
  @eval_model = eval_model
179
208
  end
180
209
 
181
210
  if post_process.nil?
182
- if @post_process_file && File.exists?(@post_process_file)
211
+ if @post_process_file && File.exist?(@post_process_file)
183
212
  @post_process = __load_method @post_process_file
184
- elsif @post_process_file_R && File.exists?(@post_process_file_R)
213
+ elsif @post_process_file_R && File.exist?(@post_process_file_R)
185
214
  @post_process = Open.read(@post_process_file_R)
186
215
  end
187
216
  else
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
190
219
 
191
220
 
192
221
  if names.nil?
193
- if @names_file && File.exists?(@names_file)
222
+ if @names_file && File.exist?(@names_file)
194
223
  @names = Open.read(@names_file).split("\n")
195
224
  end
196
225
  else
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
198
227
  end
199
228
 
200
229
  if factor_levels.nil?
201
- if @levels_file && File.exists?(@levels_file)
230
+ if @levels_file && File.exist?(@levels_file)
202
231
  @factor_levels = YAML.load(Open.read(@levels_file))
203
232
  end
204
- if @model_file && File.exists?(@model_file + '.factor_levels')
205
- @factor_levels = TSV.open(@model_file + '.factor_levels')
233
+ if @model_path && File.exist?(@model_path + '.factor_levels')
234
+ @factor_levels = TSV.open(@model_path + '.factor_levels')
206
235
  end
207
236
  else
208
237
  @factor_levels = factor_levels
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
241
270
  case
242
271
  when Proc === train_model
243
272
  begin
244
- Open.write(@train_model_file, train_model.source)
273
+ Open.write(@train_model_path, train_model.source)
245
274
  rescue
246
275
  end
247
276
  when String === train_model
248
- Open.write(@train_model_file_R, @train_model)
277
+ Open.write(@train_model_path_R, @train_model)
249
278
  end
250
279
 
251
280
  Open.write(@extract_features_file, @extract_features.source) if @extract_features
281
+ Open.write(@init_model_path, @init_model.source) if @init_model
252
282
 
253
283
  case
254
284
  when Proc === eval_model
255
285
  begin
256
- Open.write(@eval_model_file, eval_model.source)
286
+ Open.write(@eval_model_path, eval_model.source)
257
287
  rescue
258
288
  end
259
289
  when String === eval_model
260
- Open.write(@eval_model_file_R, eval_model)
290
+ Open.write(@eval_model_path_R, eval_model)
261
291
  end
262
292
 
263
293
  case
@@ -270,24 +300,37 @@ cat(paste(label, sep="\\n", collapse="\\n"));
270
300
  Open.write(@post_process_file_R, post_process)
271
301
  end
272
302
 
273
-
274
303
  Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
275
304
  Open.write(@names_file, @names * "\n" + "\n") if @names
276
305
  Open.write(@options_file, @model_options.to_json) if @model_options
277
306
  end
278
307
 
279
308
  def train
280
- case
281
- when Proc === @train_model
282
- self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
283
- when String === @train_model
284
- VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
309
+ begin
310
+ if @balance
311
+ @original_features = @features
312
+ @original_labels = @labels
313
+ self.balance_labels
314
+ end
315
+
316
+ case
317
+ when Proc === @train_model
318
+ self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
319
+ when String === @train_model
320
+ VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
321
+ end
322
+ ensure
323
+ if @balance
324
+ @features = @original_features
325
+ @labels = @original_labels
326
+ end
285
327
  end
328
+
286
329
  save_models if @directory
287
330
  end
288
331
 
289
332
  def run(code)
290
- VectorModel.R_run(@model_file, @features, @labels, code, @names, @factor_levels)
333
+ VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
291
334
  end
292
335
 
293
336
  def eval(element)
@@ -295,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
295
338
 
296
339
  result = case
297
340
  when Proc === @eval_model
298
- self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
341
+ self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
299
342
  when String === @eval_model
300
- VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
343
+ VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
301
344
  else
302
345
  raise "No @eval_model function or R script"
303
346
  end
304
347
 
305
- result = self.instance_exec(result, &@post_process) if Proc === @post_process
348
+ result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
306
349
 
307
350
  result
308
351
  end
@@ -321,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
321
364
 
322
365
  result = case
323
366
  when Proc === eval_model
324
- self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
367
+ self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
325
368
  when String === eval_model
326
- VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
369
+ VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
327
370
  end
328
371
 
329
- result = self.instance_exec(result, &@post_process) if Proc === @post_process
372
+ result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
330
373
 
331
374
  result
332
375
  end
@@ -438,6 +481,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
438
481
  @features = orig_features
439
482
  @labels = orig_labels
440
483
  end unless folds == -1
484
+
441
485
  self.reset_model if self.respond_to? :reset_model
442
486
  self.train unless folds == 1
443
487
  res
@@ -1 +1,48 @@
1
- # Keep
1
+ from torch.utils.data import Dataset, DataLoader
2
+
3
+ class TSVDataset(Dataset):
4
+ def __init__(self, tsv):
5
+ self.tsv = tsv
6
+
7
+ def __getitem__(self, key):
8
+ if (type(key) == int):
9
+ row = self.tsv.iloc[key]
10
+ else:
11
+ row = self.tsv.loc[key]
12
+
13
+ row = row.to_numpy()
14
+ features = row[:-1]
15
+ label = row[-1]
16
+
17
+ return features, label
18
+
19
+ def __len__(self):
20
+ return len(self.tsv)
21
+
22
+ def tsv_dataset(filename, *args, **kwargs):
23
+ import rbbt
24
+ return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
25
+
26
+ def tsv(*args, **kwargs):
27
+ return tsv_dataset(*args, **kwargs)
28
+
29
+ def data_dir():
30
+ import rbbt
31
+ return rbbt.path('var/rbbt_dm/data')
32
+
33
+ if __name__ == "__main__":
34
+ import rbbt
35
+
36
+ filename = "/home/miki/test/numeric.tsv"
37
+ ds = tsv(filename)
38
+
39
+ dl = DataLoader(ds, batch_size=1)
40
+
41
+ for f, l in iter(dl):
42
+ print(".")
43
+ print(f[0,:])
44
+ print(l[0])
45
+
46
+
47
+
48
+
File without changes
@@ -0,0 +1,141 @@
1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from IPython.display import HTML, display
5
+
6
+
7
+ def set_default(figsize=(10, 10), dpi=100):
8
+ plt.style.use(['dark_background', 'bmh'])
9
+ plt.rc('axes', facecolor='k')
10
+ plt.rc('figure', facecolor='k')
11
+ plt.rc('figure', figsize=figsize, dpi=dpi)
12
+
13
+
14
+ def plot_data(X, y, d=0, auto=False, zoom=1):
15
+ X = X.cpu()
16
+ y = y.cpu()
17
+ plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
18
+ plt.axis('square')
19
+ plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
20
+ if auto is True: plt.axis('equal')
21
+ plt.axis('off')
22
+
23
+ _m, _c = 0, '.15'
24
+ plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
25
+ plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
26
+
27
+
28
+ def plot_model(X, y, model):
29
+ model.cpu()
30
+ mesh = np.arange(-1.1, 1.1, 0.01)
31
+ xx, yy = np.meshgrid(mesh, mesh)
32
+ with torch.no_grad():
33
+ data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
34
+ Z = model(data).detach()
35
+ Z = np.argmax(Z, axis=1).reshape(xx.shape)
36
+ plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
37
+ plot_data(X, y)
38
+
39
+
40
+ def show_scatterplot(X, colors, title=''):
41
+ colors = colors.cpu().numpy()
42
+ X = X.cpu().numpy()
43
+ plt.figure()
44
+ plt.axis('equal')
45
+ plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
46
+ # plt.grid(True)
47
+ plt.title(title)
48
+ plt.axis('off')
49
+
50
+
51
+ def plot_bases(bases, width=0.04):
52
+ bases = bases.cpu()
53
+ bases[2:] -= bases[:2]
54
+ plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
55
+ plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
56
+
57
+
58
+ def show_mat(mat, vect, prod, threshold=-1):
59
+ # Subplot grid definition
60
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
61
+ gridspec_kw={'width_ratios':[5,1,1]})
62
+ # Plot matrices
63
+ cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
64
+ ax2.matshow(vect.numpy(), clim=(-1, 1))
65
+ cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
66
+
67
+ # Set titles
68
+ ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
69
+ ax2.set_title(f'a^(i): {vect.numel()}')
70
+ ax3.set_title(f'p: {prod.numel()}')
71
+
72
+ # Remove xticks for vectors
73
+ ax2.set_xticks(tuple())
74
+ ax3.set_xticks(tuple())
75
+
76
+ # Plot colourbars
77
+ fig.colorbar(cax1, ax=ax2)
78
+ fig.colorbar(cax3, ax=ax3)
79
+
80
+ # Fix y-axis limits
81
+ ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
82
+
83
+
84
+ colors = dict(
85
+ aqua='#8dd3c7',
86
+ yellow='#ffffb3',
87
+ lavender='#bebada',
88
+ red='#fb8072',
89
+ blue='#80b1d3',
90
+ orange='#fdb462',
91
+ green='#b3de69',
92
+ pink='#fccde5',
93
+ grey='#d9d9d9',
94
+ violet='#bc80bd',
95
+ unk1='#ccebc5',
96
+ unk2='#ffed6f',
97
+ )
98
+
99
+
100
+ def _cstr(s, color='black'):
101
+ if s == ' ':
102
+ return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
103
+ else:
104
+ return f'<text style=color:#000;background-color:{color}>{s} </text>'
105
+
106
+ # print html
107
+ def _print_color(t):
108
+ display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
109
+
110
+ # get appropriate color for value
111
+ def _get_clr(value):
112
+ colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
113
+ '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
114
+ '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
115
+ '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
116
+ value = int((value * 100) / 5)
117
+ if value == len(colors): value -= 1 # fixing bugs...
118
+ return colors[value]
119
+
120
+ def _visualise_values(output_values, result_list):
121
+ text_colours = []
122
+ for i in range(len(output_values)):
123
+ text = (result_list[i], _get_clr(output_values[i]))
124
+ text_colours.append(text)
125
+ _print_color(text_colours)
126
+
127
+ def print_colourbar():
128
+ color_range = torch.linspace(-2.5, 2.5, 20)
129
+ to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
130
+ _print_color(to_print)
131
+
132
+
133
+ # Let's only focus on the last time step for now
134
+ # First, the cell state (Long term memory)
135
+ def plot_state(data, state, b, decoder):
136
+ actual_data = decoder(data[b, :, :].numpy())
137
+ seq_len = len(actual_data)
138
+ seq_len_w_pad = len(state)
139
+ for s in range(state.size(2)):
140
+ states = torch.sigmoid(state[:, b, s])
141
+ _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import math
3
+ def spiral_data(N=1000, D=2, C=3):
4
+ X = torch.zeros(N * C, D)
5
+ y = torch.zeros(N * C, dtype=torch.long)
6
+ for c in range(C):
7
+ index = 0
8
+ t = torch.linspace(0, 1, N)
9
+ # When c = 0 and t = 0: start of linspace
10
+ # When c = 0 and t = 1: end of linpace
11
+ # This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
12
+ inner_var = torch.linspace(
13
+ # When t = 0
14
+ (2 * math.pi / C) * (c),
15
+ # When t = 1
16
+ (2 * math.pi / C) * (2 + c),
17
+ N
18
+ ) + torch.randn(N) * 0.2
19
+
20
+ for ix in range(N * c, N * (c + 1)):
21
+ X[ix] = t[index] * torch.FloatTensor((
22
+ math.sin(inner_var[index]), math.cos(inner_var[index])
23
+ ))
24
+ y[ix] = c
25
+ index += 1
26
+
27
+ return (X, y)