rbbt-dm 1.2.6 → 1.2.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -14
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +100 -56
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/lib/rbbt/vector/model.rb
CHANGED
@@ -1,16 +1,25 @@
|
|
1
1
|
require 'rbbt/util/R'
|
2
2
|
require 'rbbt/vector/model/util'
|
3
|
+
require 'rbbt/util/python'
|
4
|
+
|
5
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
6
|
+
RbbtPython.init_rbbt
|
3
7
|
|
4
8
|
class VectorModel
|
5
|
-
attr_accessor :directory, :
|
9
|
+
attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
|
6
10
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
|
-
attr_accessor :model_options
|
11
|
+
attr_accessor :model, :model_options
|
8
12
|
|
9
13
|
def extract_features(&block)
|
10
14
|
@extract_features = block if block_given?
|
11
15
|
@extract_features
|
12
16
|
end
|
13
17
|
|
18
|
+
def init_model(&block)
|
19
|
+
@init_model = block if block_given?
|
20
|
+
@init_model
|
21
|
+
end
|
22
|
+
|
14
23
|
def train_model(&block)
|
15
24
|
@train_model = block if block_given?
|
16
25
|
@train_model
|
@@ -21,13 +30,17 @@ class VectorModel
|
|
21
30
|
@eval_model
|
22
31
|
end
|
23
32
|
|
33
|
+
def init
|
34
|
+
@model ||= self.instance_exec &@init_model
|
35
|
+
end
|
36
|
+
|
24
37
|
def post_process(&block)
|
25
38
|
@post_process = block if block_given?
|
26
39
|
@post_process
|
27
40
|
end
|
28
41
|
|
29
42
|
|
30
|
-
def self.R_run(
|
43
|
+
def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
|
31
44
|
TmpFile.with_file do |feature_file|
|
32
45
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
33
46
|
Open.write(feature_file + '.label', labels * "\n" + "\n")
|
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
|
|
54
67
|
end
|
55
68
|
end
|
56
69
|
|
57
|
-
def self.R_train(
|
70
|
+
def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
|
58
71
|
TmpFile.with_file do |feature_file|
|
59
72
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
60
73
|
Open.write(feature_file + '.label', labels * "\n" + "\n")
|
@@ -82,13 +95,13 @@ for (c in names(features)){
|
|
82
95
|
if (is.factor(features[[c]]))
|
83
96
|
factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
|
84
97
|
}
|
85
|
-
rbbt.tsv.write("#{
|
86
|
-
save(model, file='#{
|
98
|
+
rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
|
99
|
+
save(model, file='#{model_path}')
|
87
100
|
EOF
|
88
101
|
end
|
89
102
|
end
|
90
103
|
|
91
|
-
def self.R_eval(
|
104
|
+
def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
|
92
105
|
TmpFile.with_file do |feature_file|
|
93
106
|
if list
|
94
107
|
Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
|
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
|
|
105
118
|
#{ factor_levels.collect do |name,levels|
|
106
119
|
"features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
|
107
120
|
end * "\n" if factor_levels }
|
108
|
-
load(file="#{
|
121
|
+
load(file="#{model_path}");
|
109
122
|
#{code}
|
110
123
|
cat(paste(label, sep="\\n", collapse="\\n"));
|
111
124
|
EOF
|
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
127
140
|
instance_eval code, file
|
128
141
|
end
|
129
142
|
|
130
|
-
def initialize(directory = nil,
|
143
|
+
def initialize(directory = nil, model_options = {})
|
131
144
|
@directory = directory
|
145
|
+
@model_options = IndiferentHash.setup(model_options)
|
146
|
+
|
132
147
|
if @directory
|
133
|
-
FileUtils.mkdir_p @directory unless File.
|
148
|
+
FileUtils.mkdir_p @directory unless File.exist?(@directory)
|
149
|
+
|
150
|
+
@model_path = File.join(@directory, "model")
|
134
151
|
|
135
|
-
@model_file = File.join(@directory, "model")
|
136
152
|
@extract_features_file = File.join(@directory, "features")
|
137
|
-
@
|
138
|
-
|
139
|
-
@
|
140
|
-
@
|
141
|
-
|
142
|
-
@
|
143
|
-
@
|
144
|
-
|
145
|
-
@
|
146
|
-
|
147
|
-
|
148
|
-
|
153
|
+
@init_model_path = File.join(@directory, "init_model")
|
154
|
+
|
155
|
+
@train_model_path = File.join(@directory, "train_model")
|
156
|
+
@train_model_path_R = File.join(@directory, "train_model.R")
|
157
|
+
|
158
|
+
@eval_model_path = File.join(@directory, "eval_model")
|
159
|
+
@eval_model_path_R = File.join(@directory, "eval_model.R")
|
160
|
+
|
161
|
+
@post_process_file = File.join(@directory, "post_process")
|
162
|
+
@post_process_file_R = File.join(@directory, "post_process.R")
|
163
|
+
|
164
|
+
@names_file = File.join(@directory, "feature_names")
|
165
|
+
@levels_file = File.join(@directory, "levels")
|
166
|
+
@options_file = File.join(@directory, "options.json")
|
167
|
+
|
168
|
+
if File.exist?(@options_file)
|
169
|
+
@model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
|
149
170
|
IndiferentHash.setup(@model_options)
|
150
171
|
end
|
151
172
|
end
|
152
|
-
|
173
|
+
|
153
174
|
if extract_features.nil?
|
154
|
-
if @extract_features_file && File.
|
175
|
+
if @extract_features_file && File.exist?(@extract_features_file)
|
155
176
|
@extract_features = __load_method @extract_features_file
|
156
177
|
end
|
157
178
|
else
|
158
179
|
@extract_features = extract_features
|
159
180
|
end
|
160
181
|
|
182
|
+
if init_model.nil?
|
183
|
+
if @init_model_path && File.exist?(@init_model_path)
|
184
|
+
@init_model = __load_method @init_model_path
|
185
|
+
end
|
186
|
+
else
|
187
|
+
@init_model = init_model
|
188
|
+
end
|
189
|
+
|
161
190
|
if train_model.nil?
|
162
|
-
if @
|
163
|
-
@train_model = __load_method @
|
164
|
-
elsif @
|
165
|
-
@train_model = Open.read(@
|
191
|
+
if @train_model_path && File.exist?(@train_model_path)
|
192
|
+
@train_model = __load_method @train_model_path
|
193
|
+
elsif @train_model_path_R && File.exist?(@train_model_path_R)
|
194
|
+
@train_model = Open.read(@train_model_path_R)
|
166
195
|
end
|
167
196
|
else
|
168
197
|
@train_model = train_model
|
169
198
|
end
|
170
199
|
|
171
200
|
if eval_model.nil?
|
172
|
-
if @
|
173
|
-
@eval_model = __load_method @
|
174
|
-
elsif @
|
175
|
-
@eval_model = Open.read(@
|
201
|
+
if @eval_model_path && File.exist?(@eval_model_path)
|
202
|
+
@eval_model = __load_method @eval_model_path
|
203
|
+
elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
|
204
|
+
@eval_model = Open.read(@eval_model_path_R)
|
176
205
|
end
|
177
206
|
else
|
178
207
|
@eval_model = eval_model
|
179
208
|
end
|
180
209
|
|
181
210
|
if post_process.nil?
|
182
|
-
if @post_process_file && File.
|
211
|
+
if @post_process_file && File.exist?(@post_process_file)
|
183
212
|
@post_process = __load_method @post_process_file
|
184
|
-
elsif @post_process_file_R && File.
|
213
|
+
elsif @post_process_file_R && File.exist?(@post_process_file_R)
|
185
214
|
@post_process = Open.read(@post_process_file_R)
|
186
215
|
end
|
187
216
|
else
|
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
190
219
|
|
191
220
|
|
192
221
|
if names.nil?
|
193
|
-
if @names_file && File.
|
222
|
+
if @names_file && File.exist?(@names_file)
|
194
223
|
@names = Open.read(@names_file).split("\n")
|
195
224
|
end
|
196
225
|
else
|
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
198
227
|
end
|
199
228
|
|
200
229
|
if factor_levels.nil?
|
201
|
-
if @levels_file && File.
|
230
|
+
if @levels_file && File.exist?(@levels_file)
|
202
231
|
@factor_levels = YAML.load(Open.read(@levels_file))
|
203
232
|
end
|
204
|
-
if @
|
205
|
-
@factor_levels = TSV.open(@
|
233
|
+
if @model_path && File.exist?(@model_path + '.factor_levels')
|
234
|
+
@factor_levels = TSV.open(@model_path + '.factor_levels')
|
206
235
|
end
|
207
236
|
else
|
208
237
|
@factor_levels = factor_levels
|
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
241
270
|
case
|
242
271
|
when Proc === train_model
|
243
272
|
begin
|
244
|
-
Open.write(@
|
273
|
+
Open.write(@train_model_path, train_model.source)
|
245
274
|
rescue
|
246
275
|
end
|
247
276
|
when String === train_model
|
248
|
-
Open.write(@
|
277
|
+
Open.write(@train_model_path_R, @train_model)
|
249
278
|
end
|
250
279
|
|
251
280
|
Open.write(@extract_features_file, @extract_features.source) if @extract_features
|
281
|
+
Open.write(@init_model_path, @init_model.source) if @init_model
|
252
282
|
|
253
283
|
case
|
254
284
|
when Proc === eval_model
|
255
285
|
begin
|
256
|
-
Open.write(@
|
286
|
+
Open.write(@eval_model_path, eval_model.source)
|
257
287
|
rescue
|
258
288
|
end
|
259
289
|
when String === eval_model
|
260
|
-
Open.write(@
|
290
|
+
Open.write(@eval_model_path_R, eval_model)
|
261
291
|
end
|
262
292
|
|
263
293
|
case
|
@@ -270,24 +300,37 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
270
300
|
Open.write(@post_process_file_R, post_process)
|
271
301
|
end
|
272
302
|
|
273
|
-
|
274
303
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
275
304
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
276
305
|
Open.write(@options_file, @model_options.to_json) if @model_options
|
277
306
|
end
|
278
307
|
|
279
308
|
def train
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
309
|
+
begin
|
310
|
+
if @balance
|
311
|
+
@original_features = @features
|
312
|
+
@original_labels = @labels
|
313
|
+
self.balance_labels
|
314
|
+
end
|
315
|
+
|
316
|
+
case
|
317
|
+
when Proc === @train_model
|
318
|
+
self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
|
319
|
+
when String === @train_model
|
320
|
+
VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
|
321
|
+
end
|
322
|
+
ensure
|
323
|
+
if @balance
|
324
|
+
@features = @original_features
|
325
|
+
@labels = @original_labels
|
326
|
+
end
|
285
327
|
end
|
328
|
+
|
286
329
|
save_models if @directory
|
287
330
|
end
|
288
331
|
|
289
332
|
def run(code)
|
290
|
-
VectorModel.R_run(@
|
333
|
+
VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
|
291
334
|
end
|
292
335
|
|
293
336
|
def eval(element)
|
@@ -295,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
295
338
|
|
296
339
|
result = case
|
297
340
|
when Proc === @eval_model
|
298
|
-
self.instance_exec(
|
341
|
+
self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
|
299
342
|
when String === @eval_model
|
300
|
-
VectorModel.R_eval(@
|
343
|
+
VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
|
301
344
|
else
|
302
345
|
raise "No @eval_model function or R script"
|
303
346
|
end
|
304
347
|
|
305
|
-
result = self.instance_exec(result, &@post_process) if Proc === @post_process
|
348
|
+
result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
|
306
349
|
|
307
350
|
result
|
308
351
|
end
|
@@ -321,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
321
364
|
|
322
365
|
result = case
|
323
366
|
when Proc === eval_model
|
324
|
-
self.instance_exec(
|
367
|
+
self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
|
325
368
|
when String === eval_model
|
326
|
-
VectorModel.R_eval(@
|
369
|
+
VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
|
327
370
|
end
|
328
371
|
|
329
|
-
result = self.instance_exec(result, &@post_process) if Proc === @post_process
|
372
|
+
result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
|
330
373
|
|
331
374
|
result
|
332
375
|
end
|
@@ -438,6 +481,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
438
481
|
@features = orig_features
|
439
482
|
@labels = orig_labels
|
440
483
|
end unless folds == -1
|
484
|
+
|
441
485
|
self.reset_model if self.respond_to? :reset_model
|
442
486
|
self.train unless folds == 1
|
443
487
|
res
|
data/python/rbbt_dm/__init__.py
CHANGED
@@ -1 +1,48 @@
|
|
1
|
-
|
1
|
+
from torch.utils.data import Dataset, DataLoader
|
2
|
+
|
3
|
+
class TSVDataset(Dataset):
|
4
|
+
def __init__(self, tsv):
|
5
|
+
self.tsv = tsv
|
6
|
+
|
7
|
+
def __getitem__(self, key):
|
8
|
+
if (type(key) == int):
|
9
|
+
row = self.tsv.iloc[key]
|
10
|
+
else:
|
11
|
+
row = self.tsv.loc[key]
|
12
|
+
|
13
|
+
row = row.to_numpy()
|
14
|
+
features = row[:-1]
|
15
|
+
label = row[-1]
|
16
|
+
|
17
|
+
return features, label
|
18
|
+
|
19
|
+
def __len__(self):
|
20
|
+
return len(self.tsv)
|
21
|
+
|
22
|
+
def tsv_dataset(filename, *args, **kwargs):
|
23
|
+
import rbbt
|
24
|
+
return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
|
25
|
+
|
26
|
+
def tsv(*args, **kwargs):
|
27
|
+
return tsv_dataset(*args, **kwargs)
|
28
|
+
|
29
|
+
def data_dir():
|
30
|
+
import rbbt
|
31
|
+
return rbbt.path('var/rbbt_dm/data')
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
import rbbt
|
35
|
+
|
36
|
+
filename = "/home/miki/test/numeric.tsv"
|
37
|
+
ds = tsv(filename)
|
38
|
+
|
39
|
+
dl = DataLoader(ds, batch_size=1)
|
40
|
+
|
41
|
+
for f, l in iter(dl):
|
42
|
+
print(".")
|
43
|
+
print(f[0,:])
|
44
|
+
print(l[0])
|
45
|
+
|
46
|
+
|
47
|
+
|
48
|
+
|
File without changes
|
@@ -0,0 +1,141 @@
|
|
1
|
+
from matplotlib import pyplot as plt
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
from IPython.display import HTML, display
|
5
|
+
|
6
|
+
|
7
|
+
def set_default(figsize=(10, 10), dpi=100):
|
8
|
+
plt.style.use(['dark_background', 'bmh'])
|
9
|
+
plt.rc('axes', facecolor='k')
|
10
|
+
plt.rc('figure', facecolor='k')
|
11
|
+
plt.rc('figure', figsize=figsize, dpi=dpi)
|
12
|
+
|
13
|
+
|
14
|
+
def plot_data(X, y, d=0, auto=False, zoom=1):
|
15
|
+
X = X.cpu()
|
16
|
+
y = y.cpu()
|
17
|
+
plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
|
18
|
+
plt.axis('square')
|
19
|
+
plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
|
20
|
+
if auto is True: plt.axis('equal')
|
21
|
+
plt.axis('off')
|
22
|
+
|
23
|
+
_m, _c = 0, '.15'
|
24
|
+
plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
|
25
|
+
plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
|
26
|
+
|
27
|
+
|
28
|
+
def plot_model(X, y, model):
|
29
|
+
model.cpu()
|
30
|
+
mesh = np.arange(-1.1, 1.1, 0.01)
|
31
|
+
xx, yy = np.meshgrid(mesh, mesh)
|
32
|
+
with torch.no_grad():
|
33
|
+
data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
|
34
|
+
Z = model(data).detach()
|
35
|
+
Z = np.argmax(Z, axis=1).reshape(xx.shape)
|
36
|
+
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
|
37
|
+
plot_data(X, y)
|
38
|
+
|
39
|
+
|
40
|
+
def show_scatterplot(X, colors, title=''):
|
41
|
+
colors = colors.cpu().numpy()
|
42
|
+
X = X.cpu().numpy()
|
43
|
+
plt.figure()
|
44
|
+
plt.axis('equal')
|
45
|
+
plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
|
46
|
+
# plt.grid(True)
|
47
|
+
plt.title(title)
|
48
|
+
plt.axis('off')
|
49
|
+
|
50
|
+
|
51
|
+
def plot_bases(bases, width=0.04):
|
52
|
+
bases = bases.cpu()
|
53
|
+
bases[2:] -= bases[:2]
|
54
|
+
plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
|
55
|
+
plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
|
56
|
+
|
57
|
+
|
58
|
+
def show_mat(mat, vect, prod, threshold=-1):
|
59
|
+
# Subplot grid definition
|
60
|
+
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
|
61
|
+
gridspec_kw={'width_ratios':[5,1,1]})
|
62
|
+
# Plot matrices
|
63
|
+
cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
|
64
|
+
ax2.matshow(vect.numpy(), clim=(-1, 1))
|
65
|
+
cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
|
66
|
+
|
67
|
+
# Set titles
|
68
|
+
ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
|
69
|
+
ax2.set_title(f'a^(i): {vect.numel()}')
|
70
|
+
ax3.set_title(f'p: {prod.numel()}')
|
71
|
+
|
72
|
+
# Remove xticks for vectors
|
73
|
+
ax2.set_xticks(tuple())
|
74
|
+
ax3.set_xticks(tuple())
|
75
|
+
|
76
|
+
# Plot colourbars
|
77
|
+
fig.colorbar(cax1, ax=ax2)
|
78
|
+
fig.colorbar(cax3, ax=ax3)
|
79
|
+
|
80
|
+
# Fix y-axis limits
|
81
|
+
ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
|
82
|
+
|
83
|
+
|
84
|
+
colors = dict(
|
85
|
+
aqua='#8dd3c7',
|
86
|
+
yellow='#ffffb3',
|
87
|
+
lavender='#bebada',
|
88
|
+
red='#fb8072',
|
89
|
+
blue='#80b1d3',
|
90
|
+
orange='#fdb462',
|
91
|
+
green='#b3de69',
|
92
|
+
pink='#fccde5',
|
93
|
+
grey='#d9d9d9',
|
94
|
+
violet='#bc80bd',
|
95
|
+
unk1='#ccebc5',
|
96
|
+
unk2='#ffed6f',
|
97
|
+
)
|
98
|
+
|
99
|
+
|
100
|
+
def _cstr(s, color='black'):
|
101
|
+
if s == ' ':
|
102
|
+
return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
|
103
|
+
else:
|
104
|
+
return f'<text style=color:#000;background-color:{color}>{s} </text>'
|
105
|
+
|
106
|
+
# print html
|
107
|
+
def _print_color(t):
|
108
|
+
display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
|
109
|
+
|
110
|
+
# get appropriate color for value
|
111
|
+
def _get_clr(value):
|
112
|
+
colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
|
113
|
+
'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
|
114
|
+
'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
|
115
|
+
'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
|
116
|
+
value = int((value * 100) / 5)
|
117
|
+
if value == len(colors): value -= 1 # fixing bugs...
|
118
|
+
return colors[value]
|
119
|
+
|
120
|
+
def _visualise_values(output_values, result_list):
|
121
|
+
text_colours = []
|
122
|
+
for i in range(len(output_values)):
|
123
|
+
text = (result_list[i], _get_clr(output_values[i]))
|
124
|
+
text_colours.append(text)
|
125
|
+
_print_color(text_colours)
|
126
|
+
|
127
|
+
def print_colourbar():
|
128
|
+
color_range = torch.linspace(-2.5, 2.5, 20)
|
129
|
+
to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
|
130
|
+
_print_color(to_print)
|
131
|
+
|
132
|
+
|
133
|
+
# Let's only focus on the last time step for now
|
134
|
+
# First, the cell state (Long term memory)
|
135
|
+
def plot_state(data, state, b, decoder):
|
136
|
+
actual_data = decoder(data[b, :, :].numpy())
|
137
|
+
seq_len = len(actual_data)
|
138
|
+
seq_len_w_pad = len(state)
|
139
|
+
for s in range(state.size(2)):
|
140
|
+
states = torch.sigmoid(state[:, b, s])
|
141
|
+
_visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
|
@@ -0,0 +1,27 @@
|
|
1
|
+
import torch
|
2
|
+
import math
|
3
|
+
def spiral_data(N=1000, D=2, C=3):
|
4
|
+
X = torch.zeros(N * C, D)
|
5
|
+
y = torch.zeros(N * C, dtype=torch.long)
|
6
|
+
for c in range(C):
|
7
|
+
index = 0
|
8
|
+
t = torch.linspace(0, 1, N)
|
9
|
+
# When c = 0 and t = 0: start of linspace
|
10
|
+
# When c = 0 and t = 1: end of linpace
|
11
|
+
# This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
|
12
|
+
inner_var = torch.linspace(
|
13
|
+
# When t = 0
|
14
|
+
(2 * math.pi / C) * (c),
|
15
|
+
# When t = 1
|
16
|
+
(2 * math.pi / C) * (2 + c),
|
17
|
+
N
|
18
|
+
) + torch.randn(N) * 0.2
|
19
|
+
|
20
|
+
for ix in range(N * c, N * (c + 1)):
|
21
|
+
X[ix] = t[index] * torch.FloatTensor((
|
22
|
+
math.sin(inner_var[index]), math.cos(inner_var[index])
|
23
|
+
))
|
24
|
+
y[ix] = c
|
25
|
+
index += 1
|
26
|
+
|
27
|
+
return (X, y)
|