rbbt-dm 1.2.6 → 1.2.9

Files changed (34)
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +57 -38
  10. data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
  11. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  12. data/lib/rbbt/vector/model/spaCy.rb +8 -14
  13. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  14. data/lib/rbbt/vector/model/torch.rb +37 -0
  15. data/lib/rbbt/vector/model/util.rb +18 -0
  16. data/lib/rbbt/vector/model.rb +100 -56
  17. data/python/rbbt_dm/__init__.py +48 -1
  18. data/python/rbbt_dm/atcold/__init__.py +0 -0
  19. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  20. data/python/rbbt_dm/atcold/spiral.py +27 -0
  21. data/python/rbbt_dm/huggingface.py +57 -26
  22. data/python/rbbt_dm/language_model.py +70 -0
  23. data/python/rbbt_dm/util.py +30 -0
  24. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  25. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  26. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  27. data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
  28. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  29. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  30. data/test/rbbt/vector/test_model.rb +25 -26
  31. data/test/test_helper.rb +13 -0
  32. metadata +26 -16
  33. data/lib/rbbt/tensorflow.rb +0 -43
  34. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/lib/rbbt/vector/model.rb
@@ -1,16 +1,25 @@
 require 'rbbt/util/R'
 require 'rbbt/vector/model/util'
+require 'rbbt/util/python'
+
+RbbtPython.add_path Rbbt.python.find(:lib)
+RbbtPython.init_rbbt
 
 class VectorModel
-  attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process
+  attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
   attr_accessor :features, :names, :labels, :factor_levels
-  attr_accessor :model_options
+  attr_accessor :model, :model_options
 
   def extract_features(&block)
     @extract_features = block if block_given?
     @extract_features
   end
 
+  def init_model(&block)
+    @init_model = block if block_given?
+    @init_model
+  end
+
   def train_model(&block)
     @train_model = block if block_given?
     @train_model
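The new init_model accessor mirrors extract_features and train_model: it registers a block that builds the underlying model object, and the init method added in the next hunk runs it lazily and caches the result in @model. A minimal sketch of how a caller might use it, assuming the gem is installed; the block body is a hypothetical placeholder, not code from this release:

    require 'rbbt/vector/model'

    model = VectorModel.new "/tmp/example_model"

    model.init_model do
      { weights: [] }   # hypothetical stand-in for a real model object (e.g. a torch module)
    end

    model.init    # runs the block once, memoizing the result
    model.model   # => { weights: [] }
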
@@ -21,13 +30,17 @@ class VectorModel
     @eval_model
   end
 
+  def init
+    @model ||= self.instance_exec &@init_model
+  end
+
   def post_process(&block)
     @post_process = block if block_given?
     @post_process
   end
 
 
-  def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
+  def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
     TmpFile.with_file do |feature_file|
       Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
       Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
     end
   end
 
-  def self.R_train(model_file, features, labels, code, names = nil, factor_levels = nil)
+  def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
     TmpFile.with_file do |feature_file|
       Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
       Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -82,13 +95,13 @@ for (c in names(features)){
   if (is.factor(features[[c]]))
     factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
 }
-rbbt.tsv.write("#{model_file}.factor_levels", factor_levels, names=c('Levels'), type='flat')
-save(model, file='#{model_file}')
+rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
+save(model, file='#{model_path}')
 EOF
     end
   end
 
-  def self.R_eval(model_file, features, list, code, names = nil, factor_levels = nil)
+  def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
     TmpFile.with_file do |feature_file|
       if list
         Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
 #{ factor_levels.collect do |name,levels|
     "features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
   end * "\n" if factor_levels }
-load(file="#{model_file}");
+load(file="#{model_path}");
 #{code}
 cat(paste(label, sep="\\n", collapse="\\n"));
 EOF
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     instance_eval code, file
   end
 
-  def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
+  def initialize(directory = nil, model_options = {})
     @directory = directory
+    @model_options = IndiferentHash.setup(model_options)
+
     if @directory
-      FileUtils.mkdir_p @directory unless File.exists?(@directory)
+      FileUtils.mkdir_p @directory unless File.exist?(@directory)
+
+      @model_path = File.join(@directory, "model")
 
-      @model_file = File.join(@directory, "model")
       @extract_features_file = File.join(@directory, "features")
-      @train_model_file = File.join(@directory, "train_model")
-      @eval_model_file = File.join(@directory, "eval_model")
-      @post_process_file = File.join(@directory, "post_process")
-      @train_model_file_R = File.join(@directory, "train_model.R")
-      @eval_model_file_R = File.join(@directory, "eval_model.R")
-      @post_process_file_R = File.join(@directory, "post_process.R")
-      @names_file = File.join(@directory, "feature_names")
-      @levels_file = File.join(@directory, "levels")
-      @options_file = File.join(@directory, "options.json")
-
-      if File.exists?(@options_file)
-        @model_options = JSON.parse(Open.read(@options_file))
+      @init_model_path = File.join(@directory, "init_model")
+
+      @train_model_path = File.join(@directory, "train_model")
+      @train_model_path_R = File.join(@directory, "train_model.R")
+
+      @eval_model_path = File.join(@directory, "eval_model")
+      @eval_model_path_R = File.join(@directory, "eval_model.R")
+
+      @post_process_file = File.join(@directory, "post_process")
+      @post_process_file_R = File.join(@directory, "post_process.R")
+
+      @names_file = File.join(@directory, "feature_names")
+      @levels_file = File.join(@directory, "levels")
+      @options_file = File.join(@directory, "options.json")
+
+      if File.exist?(@options_file)
+        @model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
        IndiferentHash.setup(@model_options)
       end
     end
-
+
     if extract_features.nil?
-      if @extract_features_file && File.exists?(@extract_features_file)
+      if @extract_features_file && File.exist?(@extract_features_file)
        @extract_features = __load_method @extract_features_file
       end
     else
       @extract_features = extract_features
     end
 
+    if init_model.nil?
+      if @init_model_path && File.exist?(@init_model_path)
+        @init_model = __load_method @init_model_path
+      end
+    else
+      @init_model = init_model
+    end
+
     if train_model.nil?
-      if @train_model_file && File.exists?(@train_model_file)
-        @train_model = __load_method @train_model_file
-      elsif @train_model_file_R && File.exists?(@train_model_file_R)
-        @train_model = Open.read(@train_model_file_R)
+      if @train_model_path && File.exist?(@train_model_path)
+        @train_model = __load_method @train_model_path
+      elsif @train_model_path_R && File.exist?(@train_model_path_R)
+        @train_model = Open.read(@train_model_path_R)
       end
     else
       @train_model = train_model
     end
 
     if eval_model.nil?
-      if @eval_model_file && File.exists?(@eval_model_file)
-        @eval_model = __load_method @eval_model_file
-      elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
-        @eval_model = Open.read(@eval_model_file_R)
+      if @eval_model_path && File.exist?(@eval_model_path)
+        @eval_model = __load_method @eval_model_path
+      elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
+        @eval_model = Open.read(@eval_model_path_R)
       end
     else
       @eval_model = eval_model
     end
 
     if post_process.nil?
-      if @post_process_file && File.exists?(@post_process_file)
+      if @post_process_file && File.exist?(@post_process_file)
         @post_process = __load_method @post_process_file
-      elsif @post_process_file_R && File.exists?(@post_process_file_R)
+      elsif @post_process_file_R && File.exist?(@post_process_file_R)
         @post_process = Open.read(@post_process_file_R)
       end
     else
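The constructor now takes just a directory and a model_options hash (merged over any options.json previously saved in that directory), instead of a long list of positional procs; the blocks are attached afterwards through the accessor methods. A hedged sketch of the new construction style; the option key and block bodies are illustrative only:

    model = VectorModel.new "/tmp/example_model", max_features: 100   # option key is hypothetical

    model.extract_features do |sample|
      sample.split(";")                  # turn a raw element into a feature vector
    end

    model.train_model do |features, labels|
      # fit something with features and labels; the model path is no longer passed as an argument
    end

    model.eval_model do |features, list|
      list ? features.collect{|f| f.length } : features.length   # toy prediction
    end
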
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
 
     if names.nil?
-      if @names_file && File.exists?(@names_file)
+      if @names_file && File.exist?(@names_file)
         @names = Open.read(@names_file).split("\n")
       end
     else
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     end
 
     if factor_levels.nil?
-      if @levels_file && File.exists?(@levels_file)
+      if @levels_file && File.exist?(@levels_file)
         @factor_levels = YAML.load(Open.read(@levels_file))
       end
-      if @model_file && File.exists?(@model_file + '.factor_levels')
-        @factor_levels = TSV.open(@model_file + '.factor_levels')
+      if @model_path && File.exist?(@model_path + '.factor_levels')
+        @factor_levels = TSV.open(@model_path + '.factor_levels')
       end
     else
       @factor_levels = factor_levels
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     case
     when Proc === train_model
       begin
-        Open.write(@train_model_file, train_model.source)
+        Open.write(@train_model_path, train_model.source)
       rescue
       end
     when String === train_model
-      Open.write(@train_model_file_R, @train_model)
+      Open.write(@train_model_path_R, @train_model)
     end
 
     Open.write(@extract_features_file, @extract_features.source) if @extract_features
+    Open.write(@init_model_path, @init_model.source) if @init_model
 
     case
     when Proc === eval_model
       begin
-        Open.write(@eval_model_file, eval_model.source)
+        Open.write(@eval_model_path, eval_model.source)
       rescue
       end
     when String === eval_model
-      Open.write(@eval_model_file_R, eval_model)
+      Open.write(@eval_model_path_R, eval_model)
     end
 
     case
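save_models now also serializes the init_model block next to the others, so a model directory built with the blocks from the sketch above would typically contain files along these lines (file names come from the constructor; exact contents depend on which blocks and options were defined):

    Dir.glob("/tmp/example_model/*")
    # => ["/tmp/example_model/features",
    #     "/tmp/example_model/init_model",
    #     "/tmp/example_model/train_model",
    #     "/tmp/example_model/eval_model",
    #     "/tmp/example_model/options.json"]
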
@@ -270,24 +300,37 @@ cat(paste(label, sep="\\n", collapse="\\n"));
       Open.write(@post_process_file_R, post_process)
     end
 
-
     Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
     Open.write(@names_file, @names * "\n" + "\n") if @names
     Open.write(@options_file, @model_options.to_json) if @model_options
   end
 
   def train
-    case
-    when Proc === @train_model
-      self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
-    when String === @train_model
-      VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
+    begin
+      if @balance
+        @original_features = @features
+        @original_labels = @labels
+        self.balance_labels
+      end
+
+      case
+      when Proc === @train_model
+        self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
+      when String === @train_model
+        VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
+      end
+    ensure
+      if @balance
+        @features = @original_features
+        @labels = @original_labels
+      end
     end
+
     save_models if @directory
   end
 
   def run(code)
-    VectorModel.R_run(@model_file, @features, @labels, code, @names, @factor_levels)
+    VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
   end
 
   def eval(element)
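train now supports label balancing: when balance is set, the original features and labels are stashed, balance_labels (presumably one of the new helpers in vector/model/util.rb, also touched in this release) rewrites the training copies, and the ensure clause restores the originals so later evaluation sees the full data. A hedged sketch:

    model.balance  = true          # opt in to balancing before each training run
    model.features = feature_rows  # e.g. [["1", "2"], ["3", "4"], ...]
    model.labels   = label_list    # e.g. ["a", "a", "b", ...]
    model.train                    # the balanced copy is only used inside this call
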
@@ -295,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
     result = case
              when Proc === @eval_model
-               self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
+               self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
              when String === @eval_model
-               VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
+               VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
              else
               raise "No @eval_model function or R script"
              end
 
-    result = self.instance_exec(result, &@post_process) if Proc === @post_process
+    result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
 
     result
   end
@@ -321,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
     result = case
              when Proc === eval_model
-               self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
+               self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
              when String === eval_model
-               VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
+               VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
              end
 
-    result = self.instance_exec(result, &@post_process) if Proc === @post_process
+    result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
 
     result
   end
@@ -438,6 +481,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
       @features = orig_features
       @labels = orig_labels
     end unless folds == -1
+
     self.reset_model if self.respond_to? :reset_model
     self.train unless folds == 1
     res
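Both eval and the list variant above now pass the list flag as the second argument to the post_process block, so one block can post-process a single prediction or a batch. Continuing the earlier sketch, and assuming the existing add and eval_list helpers of VectorModel, a toy end-to-end flow might look like:

    model.post_process do |result, list|
      list ? result.collect{|v| v.to_f.round(2) } : result.to_f.round(2)
    end

    model.add "1;2;3", "positive"              # accumulate training examples
    model.train
    single = model.eval "4;5;6"                # post_process receives (result, false)
    batch  = model.eval_list ["4;5;6", "7;8"]  # post_process receives (result, true)
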
data/python/rbbt_dm/__init__.py
@@ -1 +1,48 @@
-# Keep
+from torch.utils.data import Dataset, DataLoader
+
+class TSVDataset(Dataset):
+    def __init__(self, tsv):
+        self.tsv = tsv
+
+    def __getitem__(self, key):
+        if (type(key) == int):
+            row = self.tsv.iloc[key]
+        else:
+            row = self.tsv.loc[key]
+
+        row = row.to_numpy()
+        features = row[:-1]
+        label = row[-1]
+
+        return features, label
+
+    def __len__(self):
+        return len(self.tsv)
+
+def tsv_dataset(filename, *args, **kwargs):
+    import rbbt
+    return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
+
+def tsv(*args, **kwargs):
+    return tsv_dataset(*args, **kwargs)
+
+def data_dir():
+    import rbbt
+    return rbbt.path('var/rbbt_dm/data')
+
+if __name__ == "__main__":
+    import rbbt
+
+    filename = "/home/miki/test/numeric.tsv"
+    ds = tsv(filename)
+
+    dl = DataLoader(ds, batch_size=1)
+
+    for f, l in iter(dl):
+        print(".")
+        print(f[0,:])
+        print(l[0])
+
+
+
+
data/python/rbbt_dm/atcold/__init__.py — file without changes
data/python/rbbt_dm/atcold/plot_lib.py
@@ -0,0 +1,141 @@
+from matplotlib import pyplot as plt
+import numpy as np
+import torch
+from IPython.display import HTML, display
+
+
+def set_default(figsize=(10, 10), dpi=100):
+    plt.style.use(['dark_background', 'bmh'])
+    plt.rc('axes', facecolor='k')
+    plt.rc('figure', facecolor='k')
+    plt.rc('figure', figsize=figsize, dpi=dpi)
+
+
+def plot_data(X, y, d=0, auto=False, zoom=1):
+    X = X.cpu()
+    y = y.cpu()
+    plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
+    plt.axis('square')
+    plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
+    if auto is True: plt.axis('equal')
+    plt.axis('off')
+
+    _m, _c = 0, '.15'
+    plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
+    plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
+
+
+def plot_model(X, y, model):
+    model.cpu()
+    mesh = np.arange(-1.1, 1.1, 0.01)
+    xx, yy = np.meshgrid(mesh, mesh)
+    with torch.no_grad():
+        data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
+        Z = model(data).detach()
+    Z = np.argmax(Z, axis=1).reshape(xx.shape)
+    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
+    plot_data(X, y)
+
+
+def show_scatterplot(X, colors, title=''):
+    colors = colors.cpu().numpy()
+    X = X.cpu().numpy()
+    plt.figure()
+    plt.axis('equal')
+    plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
+    # plt.grid(True)
+    plt.title(title)
+    plt.axis('off')
+
+
+def plot_bases(bases, width=0.04):
+    bases = bases.cpu()
+    bases[2:] -= bases[:2]
+    plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
+    plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
+
+
+def show_mat(mat, vect, prod, threshold=-1):
+    # Subplot grid definition
+    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
+                                        gridspec_kw={'width_ratios':[5,1,1]})
+    # Plot matrices
+    cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
+    ax2.matshow(vect.numpy(), clim=(-1, 1))
+    cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
+
+    # Set titles
+    ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
+    ax2.set_title(f'a^(i): {vect.numel()}')
+    ax3.set_title(f'p: {prod.numel()}')
+
+    # Remove xticks for vectors
+    ax2.set_xticks(tuple())
+    ax3.set_xticks(tuple())
+
+    # Plot colourbars
+    fig.colorbar(cax1, ax=ax2)
+    fig.colorbar(cax3, ax=ax3)
+
+    # Fix y-axis limits
+    ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
+
+
+colors = dict(
+    aqua='#8dd3c7',
+    yellow='#ffffb3',
+    lavender='#bebada',
+    red='#fb8072',
+    blue='#80b1d3',
+    orange='#fdb462',
+    green='#b3de69',
+    pink='#fccde5',
+    grey='#d9d9d9',
+    violet='#bc80bd',
+    unk1='#ccebc5',
+    unk2='#ffed6f',
+)
+
+
+def _cstr(s, color='black'):
+    if s == ' ':
+        return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
+    else:
+        return f'<text style=color:#000;background-color:{color}>{s} </text>'
+
+# print html
+def _print_color(t):
+    display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
+
+# get appropriate color for value
+def _get_clr(value):
+    colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
+              '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
+              '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
+              '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
+    value = int((value * 100) / 5)
+    if value == len(colors): value -= 1  # fixing bugs...
+    return colors[value]
+
+def _visualise_values(output_values, result_list):
+    text_colours = []
+    for i in range(len(output_values)):
+        text = (result_list[i], _get_clr(output_values[i]))
+        text_colours.append(text)
+    _print_color(text_colours)
+
+def print_colourbar():
+    color_range = torch.linspace(-2.5, 2.5, 20)
+    to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
+    _print_color(to_print)
+
+
+# Let's only focus on the last time step for now
+# First, the cell state (Long term memory)
+def plot_state(data, state, b, decoder):
+    actual_data = decoder(data[b, :, :].numpy())
+    seq_len = len(actual_data)
+    seq_len_w_pad = len(state)
+    for s in range(state.size(2)):
+        states = torch.sigmoid(state[:, b, s])
+        _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
data/python/rbbt_dm/atcold/spiral.py
@@ -0,0 +1,27 @@
+import torch
+import math
+def spiral_data(N=1000, D=2, C=3):
+    X = torch.zeros(N * C, D)
+    y = torch.zeros(N * C, dtype=torch.long)
+    for c in range(C):
+        index = 0
+        t = torch.linspace(0, 1, N)
+        # When c = 0 and t = 0: start of linspace
+        # When c = 0 and t = 1: end of linpace
+        # This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
+        inner_var = torch.linspace(
+            # When t = 0
+            (2 * math.pi / C) * (c),
+            # When t = 1
+            (2 * math.pi / C) * (2 + c),
+            N
+        ) + torch.randn(N) * 0.2
+
+        for ix in range(N * c, N * (c + 1)):
+            X[ix] = t[index] * torch.FloatTensor((
+                math.sin(inner_var[index]), math.cos(inner_var[index])
+            ))
+            y[ix] = c
+            index += 1
+
+    return (X, y)