rbbt-dm 1.2.7 → 1.2.10

Sign up to get free protection for your applications and to get access to all the features.
Files changed (40) hide show
  1. checksums.yaml +4 -4
  2. data/lib/rbbt/matrix/barcode.rb +2 -2
  3. data/lib/rbbt/matrix/differential.rb +3 -3
  4. data/lib/rbbt/matrix/knowledge_base.rb +1 -1
  5. data/lib/rbbt/plots/bar.rb +1 -1
  6. data/lib/rbbt/stan.rb +1 -1
  7. data/lib/rbbt/statistics/hypergeometric.rb +2 -1
  8. data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
  9. data/lib/rbbt/vector/model/huggingface.rb +39 -52
  10. data/lib/rbbt/vector/model/python.rb +33 -0
  11. data/lib/rbbt/vector/model/pytorch_lightning.rb +31 -0
  12. data/lib/rbbt/vector/model/random_forest.rb +1 -1
  13. data/lib/rbbt/vector/model/spaCy.rb +8 -6
  14. data/lib/rbbt/vector/model/tensorflow.rb +6 -5
  15. data/lib/rbbt/vector/model/torch/dataloader.rb +58 -0
  16. data/lib/rbbt/vector/model/torch/helpers.rb +52 -0
  17. data/lib/rbbt/vector/model/torch/introspection.rb +31 -0
  18. data/lib/rbbt/vector/model/torch/load_and_save.rb +30 -0
  19. data/lib/rbbt/vector/model/torch.rb +71 -0
  20. data/lib/rbbt/vector/model.rb +84 -54
  21. data/python/rbbt_dm/__init__.py +31 -1
  22. data/python/rbbt_dm/atcold/__init__.py +0 -0
  23. data/python/rbbt_dm/atcold/plot_lib.py +141 -0
  24. data/python/rbbt_dm/atcold/spiral.py +27 -0
  25. data/python/rbbt_dm/huggingface.py +64 -28
  26. data/python/rbbt_dm/language_model.py +70 -0
  27. data/python/rbbt_dm/util.py +32 -0
  28. data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
  29. data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
  30. data/test/rbbt/vector/model/test_huggingface.rb +258 -27
  31. data/test/rbbt/vector/model/test_python.rb +31 -0
  32. data/test/rbbt/vector/model/test_pytorch_lightning.rb +97 -0
  33. data/test/rbbt/vector/model/test_spaCy.rb +1 -1
  34. data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
  35. data/test/rbbt/vector/model/test_torch.rb +61 -0
  36. data/test/rbbt/vector/test_model.rb +25 -26
  37. data/test/test_helper.rb +13 -0
  38. metadata +35 -16
  39. data/lib/rbbt/tensorflow.rb +0 -43
  40. data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
@@ -0,0 +1,71 @@
1
+ require_relative 'python'
2
+
3
+ class TorchModel < PythonModel
4
+
5
+ attr_accessor :model, :criterion, :optimizer, :training_args
6
+
7
+ def initialize(...)
8
+ TorchModel.init_python
9
+ super(...)
10
+ @training_args = model_options[:training_args] || {}
11
+
12
+ init_model do
13
+ model = TorchModel.load_architecture(model_path)
14
+ if model.nil?
15
+ RbbtPython.add_path @directory
16
+ RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
17
+ else
18
+ TorchModel.load_state(model, model_path)
19
+ end
20
+ end
21
+
22
+ eval_model do |features,list=false|
23
+ init
24
+ @device ||= TorchModel.device(model_options)
25
+ @dtype ||= TorchModel.dtype(model_options)
26
+ model.to(@device)
27
+
28
+ tensor = list ? TorchModel.tensor(features, @device, @dtype) : TorchModel.tensor([features], @device, @dtype)
29
+
30
+ loss, res = model.call(tensor)
31
+
32
+ res = loss if res.nil?
33
+
34
+ res = TorchModel::Tensor.setup(list ? res : res[0])
35
+
36
+ res
37
+ end
38
+
39
+ train_model do |features,labels|
40
+ init
41
+ @device ||= TorchModel.device(model_options)
42
+ @dtype ||= TorchModel.dtype(model_options)
43
+ model.to(@device)
44
+ @optimizer ||= TorchModel.optimizer(model, training_args)
45
+ epochs = training_args[:epochs] || 3
46
+
47
+ inputs = TorchModel.tensor(features, @device, @dtype)
48
+ #target = TorchModel.tensor(labels.collect{|v| [v] }, @device, @dtype)
49
+ target = TorchModel.tensor(labels, @device, @dtype)
50
+
51
+ Log::ProgressBar.with_bar epochs, :desc => "Training" do |bar|
52
+ epochs.times do |i|
53
+ @optimizer.zero_grad()
54
+ outputs = model.call(inputs)
55
+ outputs = outputs.squeeze() if target.dim() == 1
56
+ loss = criterion.call(outputs, target)
57
+ loss.backward()
58
+ @optimizer.step
59
+ Log.debug "Epoch #{i}, loss #{loss}"
60
+ bar.tick
61
+ end
62
+ end
63
+ TorchModel.save_architecture(model, model_path) if @directory
64
+ TorchModel.save_state(model, model_path) if @directory
65
+ end
66
+ end
67
+ end
68
+ require_relative 'torch/helpers'
69
+ require_relative 'torch/dataloader'
70
+ require_relative 'torch/introspection'
71
+ require_relative 'torch/load_and_save'
@@ -1,16 +1,25 @@
1
1
  require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
+ require 'rbbt/util/python'
4
+
5
+ RbbtPython.add_path Rbbt.python.find(:lib)
6
+ RbbtPython.init_rbbt
3
7
 
4
8
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
9
+ attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
6
10
  attr_accessor :features, :names, :labels, :factor_levels
7
- attr_accessor :model_options
11
+ attr_accessor :model, :model_options
8
12
 
9
13
  def extract_features(&block)
10
14
  @extract_features = block if block_given?
11
15
  @extract_features
12
16
  end
13
17
 
18
+ def init_model(&block)
19
+ @init_model = block if block_given?
20
+ @init_model
21
+ end
22
+
14
23
  def train_model(&block)
15
24
  @train_model = block if block_given?
16
25
  @train_model
@@ -21,13 +30,17 @@ class VectorModel
21
30
  @eval_model
22
31
  end
23
32
 
33
+ def init
34
+ @model ||= self.instance_exec &@init_model
35
+ end
36
+
24
37
  def post_process(&block)
25
38
  @post_process = block if block_given?
26
39
  @post_process
27
40
  end
28
41
 
29
42
 
30
- def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
43
+ def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
31
44
  TmpFile.with_file do |feature_file|
32
45
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
33
46
  Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
54
67
  end
55
68
  end
56
69
 
57
- def self.R_train(model_file, features, labels, code, names = nil, factor_levels = nil)
70
+ def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
58
71
  TmpFile.with_file do |feature_file|
59
72
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
60
73
  Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -82,13 +95,13 @@ for (c in names(features)){
82
95
  if (is.factor(features[[c]]))
83
96
  factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
84
97
  }
85
- rbbt.tsv.write("#{model_file}.factor_levels", factor_levels, names=c('Levels'), type='flat')
86
- save(model, file='#{model_file}')
98
+ rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
99
+ save(model, file='#{model_path}')
87
100
  EOF
88
101
  end
89
102
  end
90
103
 
91
- def self.R_eval(model_file, features, list, code, names = nil, factor_levels = nil)
104
+ def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
92
105
  TmpFile.with_file do |feature_file|
93
106
  if list
94
107
  Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
105
118
  #{ factor_levels.collect do |name,levels|
106
119
  "features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
107
120
  end * "\n" if factor_levels }
108
- load(file="#{model_file}");
121
+ load(file="#{model_path}");
109
122
  #{code}
110
123
  cat(paste(label, sep="\\n", collapse="\\n"));
111
124
  EOF
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
127
140
  instance_eval code, file
128
141
  end
129
142
 
130
- def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
143
+ def initialize(directory = nil, model_options = {})
131
144
  @directory = directory
145
+ @model_options = IndiferentHash.setup(model_options)
146
+
132
147
  if @directory
133
- FileUtils.mkdir_p @directory unless File.exists?(@directory)
148
+ FileUtils.mkdir_p @directory unless File.exist?(@directory)
149
+
150
+ @model_path = File.join(@directory, "model")
134
151
 
135
- @model_file = File.join(@directory, "model")
136
152
  @extract_features_file = File.join(@directory, "features")
137
- @train_model_file = File.join(@directory, "train_model")
138
- @eval_model_file = File.join(@directory, "eval_model")
139
- @post_process_file = File.join(@directory, "post_process")
140
- @train_model_file_R = File.join(@directory, "train_model.R")
141
- @eval_model_file_R = File.join(@directory, "eval_model.R")
142
- @post_process_file_R = File.join(@directory, "post_process.R")
143
- @names_file = File.join(@directory, "feature_names")
144
- @levels_file = File.join(@directory, "levels")
145
- @options_file = File.join(@directory, "options.json")
146
-
147
- if File.exists?(@options_file)
148
- @model_options = JSON.parse(Open.read(@options_file))
153
+ @init_model_path = File.join(@directory, "init_model")
154
+
155
+ @train_model_path = File.join(@directory, "train_model")
156
+ @train_model_path_R = File.join(@directory, "train_model.R")
157
+
158
+ @eval_model_path = File.join(@directory, "eval_model")
159
+ @eval_model_path_R = File.join(@directory, "eval_model.R")
160
+
161
+ @post_process_file = File.join(@directory, "post_process")
162
+ @post_process_file_R = File.join(@directory, "post_process.R")
163
+
164
+ @names_file = File.join(@directory, "feature_names")
165
+ @levels_file = File.join(@directory, "levels")
166
+ @options_file = File.join(@directory, "options.json")
167
+
168
+ if File.exist?(@options_file)
169
+ @model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
149
170
  IndiferentHash.setup(@model_options)
150
171
  end
151
172
  end
152
-
173
+
153
174
  if extract_features.nil?
154
- if @extract_features_file && File.exists?(@extract_features_file)
175
+ if @extract_features_file && File.exist?(@extract_features_file)
155
176
  @extract_features = __load_method @extract_features_file
156
177
  end
157
178
  else
158
179
  @extract_features = extract_features
159
180
  end
160
181
 
182
+ if init_model.nil?
183
+ if @init_model_path && File.exist?(@init_model_path)
184
+ @init_model = __load_method @init_model_path
185
+ end
186
+ else
187
+ @init_model = init_model
188
+ end
189
+
161
190
  if train_model.nil?
162
- if @train_model_file && File.exists?(@train_model_file)
163
- @train_model = __load_method @train_model_file
164
- elsif @train_model_file_R && File.exists?(@train_model_file_R)
165
- @train_model = Open.read(@train_model_file_R)
191
+ if @train_model_path && File.exist?(@train_model_path)
192
+ @train_model = __load_method @train_model_path
193
+ elsif @train_model_path_R && File.exist?(@train_model_path_R)
194
+ @train_model = Open.read(@train_model_path_R)
166
195
  end
167
196
  else
168
197
  @train_model = train_model
169
198
  end
170
199
 
171
200
  if eval_model.nil?
172
- if @eval_model_file && File.exists?(@eval_model_file)
173
- @eval_model = __load_method @eval_model_file
174
- elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
175
- @eval_model = Open.read(@eval_model_file_R)
201
+ if @eval_model_path && File.exist?(@eval_model_path)
202
+ @eval_model = __load_method @eval_model_path
203
+ elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
204
+ @eval_model = Open.read(@eval_model_path_R)
176
205
  end
177
206
  else
178
207
  @eval_model = eval_model
179
208
  end
180
209
 
181
210
  if post_process.nil?
182
- if @post_process_file && File.exists?(@post_process_file)
211
+ if @post_process_file && File.exist?(@post_process_file)
183
212
  @post_process = __load_method @post_process_file
184
- elsif @post_process_file_R && File.exists?(@post_process_file_R)
213
+ elsif @post_process_file_R && File.exist?(@post_process_file_R)
185
214
  @post_process = Open.read(@post_process_file_R)
186
215
  end
187
216
  else
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
190
219
 
191
220
 
192
221
  if names.nil?
193
- if @names_file && File.exists?(@names_file)
222
+ if @names_file && File.exist?(@names_file)
194
223
  @names = Open.read(@names_file).split("\n")
195
224
  end
196
225
  else
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
198
227
  end
199
228
 
200
229
  if factor_levels.nil?
201
- if @levels_file && File.exists?(@levels_file)
230
+ if @levels_file && File.exist?(@levels_file)
202
231
  @factor_levels = YAML.load(Open.read(@levels_file))
203
232
  end
204
- if @model_file && File.exists?(@model_file + '.factor_levels')
205
- @factor_levels = TSV.open(@model_file + '.factor_levels')
233
+ if @model_path && File.exist?(@model_path + '.factor_levels')
234
+ @factor_levels = TSV.open(@model_path + '.factor_levels')
206
235
  end
207
236
  else
208
237
  @factor_levels = factor_levels
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
241
270
  case
242
271
  when Proc === train_model
243
272
  begin
244
- Open.write(@train_model_file, train_model.source)
273
+ Open.write(@train_model_path, train_model.source)
245
274
  rescue
246
275
  end
247
276
  when String === train_model
248
- Open.write(@train_model_file_R, @train_model)
277
+ Open.write(@train_model_path_R, @train_model)
249
278
  end
250
279
 
251
280
  Open.write(@extract_features_file, @extract_features.source) if @extract_features
281
+ Open.write(@init_model_path, @init_model.source) if @init_model
252
282
 
253
283
  case
254
284
  when Proc === eval_model
255
285
  begin
256
- Open.write(@eval_model_file, eval_model.source)
286
+ Open.write(@eval_model_path, eval_model.source)
257
287
  rescue
258
288
  end
259
289
  when String === eval_model
260
- Open.write(@eval_model_file_R, eval_model)
290
+ Open.write(@eval_model_path_R, eval_model)
261
291
  end
262
292
 
263
293
  case
@@ -285,9 +315,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
285
315
 
286
316
  case
287
317
  when Proc === @train_model
288
- self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
318
+ self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
289
319
  when String === @train_model
290
- VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
320
+ VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
291
321
  end
292
322
  ensure
293
323
  if @balance
@@ -300,7 +330,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
300
330
  end
301
331
 
302
332
  def run(code)
303
- VectorModel.R_run(@model_file, @features, @labels, code, @names, @factor_levels)
333
+ VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
304
334
  end
305
335
 
306
336
  def eval(element)
@@ -308,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
308
338
 
309
339
  result = case
310
340
  when Proc === @eval_model
311
- self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
341
+ self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
312
342
  when String === @eval_model
313
- VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
343
+ VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
314
344
  else
315
345
  raise "No @eval_model function or R script"
316
346
  end
317
347
 
318
- result = self.instance_exec(result, &@post_process) if Proc === @post_process
348
+ result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
319
349
 
320
350
  result
321
351
  end
@@ -334,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
334
364
 
335
365
  result = case
336
366
  when Proc === eval_model
337
- self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
367
+ self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
338
368
  when String === eval_model
339
- VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
369
+ VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
340
370
  end
341
371
 
342
- result = self.instance_exec(result, &@post_process) if Proc === @post_process
372
+ result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
343
373
 
344
374
  result
345
375
  end
@@ -418,10 +448,10 @@ cat(paste(label, sep="\\n", collapse="\\n"));
418
448
  end
419
449
 
420
450
  test_set = feature_folds[fix]
421
- train_set = feature_folds.values_at(*rest).inject([]){|acc,e| acc += e; acc}
451
+ train_set = feature_folds.values_at(*rest).flatten(1)
422
452
 
423
453
  test_labels = labels_folds[fix]
424
- train_labels = labels_folds.values_at(*rest).flatten
454
+ train_labels = labels_folds.values_at(*rest).flatten(1)
425
455
 
426
456
  @features = train_set
427
457
  @labels = train_labels
@@ -1 +1,31 @@
1
- # Keep
1
+ import rbbt
2
+ import torch
3
+ from .util import *
4
+
5
+ class TSVDataset(torch.utils.data.Dataset):
6
+ def __init__(self, tsv):
7
+ self.tsv = tsv
8
+
9
+ def __getitem__(self, key):
10
+ if (type(key) == int):
11
+ row = self.tsv.iloc[key]
12
+ else:
13
+ row = self.tsv.loc[key]
14
+
15
+ row = row.to_numpy()
16
+ features = row[:-1]
17
+ label = row[-1]
18
+
19
+ return features, label
20
+
21
+ def __len__(self):
22
+ return len(self.tsv)
23
+
24
+ def tsv_dataset(filename, *args, **kwargs):
25
+ return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
26
+
27
+ def tsv(*args, **kwargs):
28
+ return tsv_dataset(*args, **kwargs)
29
+
30
+ def data_dir():
31
+ return rbbt.path('var/rbbt_dm/data')
File without changes
@@ -0,0 +1,141 @@
1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from IPython.display import HTML, display
5
+
6
+
7
+ def set_default(figsize=(10, 10), dpi=100):
8
+ plt.style.use(['dark_background', 'bmh'])
9
+ plt.rc('axes', facecolor='k')
10
+ plt.rc('figure', facecolor='k')
11
+ plt.rc('figure', figsize=figsize, dpi=dpi)
12
+
13
+
14
+ def plot_data(X, y, d=0, auto=False, zoom=1):
15
+ X = X.cpu()
16
+ y = y.cpu()
17
+ plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
18
+ plt.axis('square')
19
+ plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
20
+ if auto is True: plt.axis('equal')
21
+ plt.axis('off')
22
+
23
+ _m, _c = 0, '.15'
24
+ plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
25
+ plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
26
+
27
+
28
+ def plot_model(X, y, model):
29
+ model.cpu()
30
+ mesh = np.arange(-1.1, 1.1, 0.01)
31
+ xx, yy = np.meshgrid(mesh, mesh)
32
+ with torch.no_grad():
33
+ data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
34
+ Z = model(data).detach()
35
+ Z = np.argmax(Z, axis=1).reshape(xx.shape)
36
+ plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
37
+ plot_data(X, y)
38
+
39
+
40
+ def show_scatterplot(X, colors, title=''):
41
+ colors = colors.cpu().numpy()
42
+ X = X.cpu().numpy()
43
+ plt.figure()
44
+ plt.axis('equal')
45
+ plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
46
+ # plt.grid(True)
47
+ plt.title(title)
48
+ plt.axis('off')
49
+
50
+
51
+ def plot_bases(bases, width=0.04):
52
+ bases = bases.cpu()
53
+ bases[2:] -= bases[:2]
54
+ plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
55
+ plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
56
+
57
+
58
+ def show_mat(mat, vect, prod, threshold=-1):
59
+ # Subplot grid definition
60
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
61
+ gridspec_kw={'width_ratios':[5,1,1]})
62
+ # Plot matrices
63
+ cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
64
+ ax2.matshow(vect.numpy(), clim=(-1, 1))
65
+ cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
66
+
67
+ # Set titles
68
+ ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
69
+ ax2.set_title(f'a^(i): {vect.numel()}')
70
+ ax3.set_title(f'p: {prod.numel()}')
71
+
72
+ # Remove xticks for vectors
73
+ ax2.set_xticks(tuple())
74
+ ax3.set_xticks(tuple())
75
+
76
+ # Plot colourbars
77
+ fig.colorbar(cax1, ax=ax2)
78
+ fig.colorbar(cax3, ax=ax3)
79
+
80
+ # Fix y-axis limits
81
+ ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
82
+
83
+
84
+ colors = dict(
85
+ aqua='#8dd3c7',
86
+ yellow='#ffffb3',
87
+ lavender='#bebada',
88
+ red='#fb8072',
89
+ blue='#80b1d3',
90
+ orange='#fdb462',
91
+ green='#b3de69',
92
+ pink='#fccde5',
93
+ grey='#d9d9d9',
94
+ violet='#bc80bd',
95
+ unk1='#ccebc5',
96
+ unk2='#ffed6f',
97
+ )
98
+
99
+
100
+ def _cstr(s, color='black'):
101
+ if s == ' ':
102
+ return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
103
+ else:
104
+ return f'<text style=color:#000;background-color:{color}>{s} </text>'
105
+
106
+ # print html
107
+ def _print_color(t):
108
+ display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
109
+
110
+ # get appropriate color for value
111
+ def _get_clr(value):
112
+ colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
113
+ '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
114
+ '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
115
+ '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
116
+ value = int((value * 100) / 5)
117
+ if value == len(colors): value -= 1 # fixing bugs...
118
+ return colors[value]
119
+
120
+ def _visualise_values(output_values, result_list):
121
+ text_colours = []
122
+ for i in range(len(output_values)):
123
+ text = (result_list[i], _get_clr(output_values[i]))
124
+ text_colours.append(text)
125
+ _print_color(text_colours)
126
+
127
+ def print_colourbar():
128
+ color_range = torch.linspace(-2.5, 2.5, 20)
129
+ to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
130
+ _print_color(to_print)
131
+
132
+
133
+ # Let's only focus on the last time step for now
134
+ # First, the cell state (Long term memory)
135
+ def plot_state(data, state, b, decoder):
136
+ actual_data = decoder(data[b, :, :].numpy())
137
+ seq_len = len(actual_data)
138
+ seq_len_w_pad = len(state)
139
+ for s in range(state.size(2)):
140
+ states = torch.sigmoid(state[:, b, s])
141
+ _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import math
3
+ def spiral_data(N=1000, D=2, C=3):
4
+ X = torch.zeros(N * C, D)
5
+ y = torch.zeros(N * C, dtype=torch.long)
6
+ for c in range(C):
7
+ index = 0
8
+ t = torch.linspace(0, 1, N)
9
+ # When c = 0 and t = 0: start of linspace
10
+ # When c = 0 and t = 1: end of linpace
11
+ # This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
12
+ inner_var = torch.linspace(
13
+ # When t = 0
14
+ (2 * math.pi / C) * (c),
15
+ # When t = 1
16
+ (2 * math.pi / C) * (2 + c),
17
+ N
18
+ ) + torch.randn(N) * 0.2
19
+
20
+ for ix in range(N * c, N * (c + 1)):
21
+ X[ix] = t[index] * torch.FloatTensor((
22
+ math.sin(inner_var[index]), math.cos(inner_var[index])
23
+ ))
24
+ y[ix] = c
25
+ index += 1
26
+
27
+ return (X, y)