rbbt-dm 1.2.6 → 1.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -14
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model/util.rb +18 -0
- data/lib/rbbt/vector/model.rb +100 -56
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/lib/rbbt/vector/model.rb
CHANGED
@@ -1,16 +1,25 @@
|
|
1
1
|
require 'rbbt/util/R'
|
2
2
|
require 'rbbt/vector/model/util'
|
3
|
+
require 'rbbt/util/python'
|
4
|
+
|
5
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
6
|
+
RbbtPython.init_rbbt
|
3
7
|
|
4
8
|
class VectorModel
|
5
|
-
attr_accessor :directory, :
|
9
|
+
attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
|
6
10
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
|
-
attr_accessor :model_options
|
11
|
+
attr_accessor :model, :model_options
|
8
12
|
|
9
13
|
def extract_features(&block)
|
10
14
|
@extract_features = block if block_given?
|
11
15
|
@extract_features
|
12
16
|
end
|
13
17
|
|
18
|
+
def init_model(&block)
|
19
|
+
@init_model = block if block_given?
|
20
|
+
@init_model
|
21
|
+
end
|
22
|
+
|
14
23
|
def train_model(&block)
|
15
24
|
@train_model = block if block_given?
|
16
25
|
@train_model
|
@@ -21,13 +30,17 @@ class VectorModel
|
|
21
30
|
@eval_model
|
22
31
|
end
|
23
32
|
|
33
|
+
def init
|
34
|
+
@model ||= self.instance_exec &@init_model
|
35
|
+
end
|
36
|
+
|
24
37
|
def post_process(&block)
|
25
38
|
@post_process = block if block_given?
|
26
39
|
@post_process
|
27
40
|
end
|
28
41
|
|
29
42
|
|
30
|
-
def self.R_run(
|
43
|
+
def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
|
31
44
|
TmpFile.with_file do |feature_file|
|
32
45
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
33
46
|
Open.write(feature_file + '.label', labels * "\n" + "\n")
|
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
|
|
54
67
|
end
|
55
68
|
end
|
56
69
|
|
57
|
-
def self.R_train(
|
70
|
+
def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
|
58
71
|
TmpFile.with_file do |feature_file|
|
59
72
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
60
73
|
Open.write(feature_file + '.label', labels * "\n" + "\n")
|
@@ -82,13 +95,13 @@ for (c in names(features)){
|
|
82
95
|
if (is.factor(features[[c]]))
|
83
96
|
factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
|
84
97
|
}
|
85
|
-
rbbt.tsv.write("#{
|
86
|
-
save(model, file='#{
|
98
|
+
rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
|
99
|
+
save(model, file='#{model_path}')
|
87
100
|
EOF
|
88
101
|
end
|
89
102
|
end
|
90
103
|
|
91
|
-
def self.R_eval(
|
104
|
+
def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
|
92
105
|
TmpFile.with_file do |feature_file|
|
93
106
|
if list
|
94
107
|
Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
|
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
|
|
105
118
|
#{ factor_levels.collect do |name,levels|
|
106
119
|
"features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
|
107
120
|
end * "\n" if factor_levels }
|
108
|
-
load(file="#{
|
121
|
+
load(file="#{model_path}");
|
109
122
|
#{code}
|
110
123
|
cat(paste(label, sep="\\n", collapse="\\n"));
|
111
124
|
EOF
|
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
127
140
|
instance_eval code, file
|
128
141
|
end
|
129
142
|
|
130
|
-
def initialize(directory = nil,
|
143
|
+
def initialize(directory = nil, model_options = {})
|
131
144
|
@directory = directory
|
145
|
+
@model_options = IndiferentHash.setup(model_options)
|
146
|
+
|
132
147
|
if @directory
|
133
|
-
FileUtils.mkdir_p @directory unless File.
|
148
|
+
FileUtils.mkdir_p @directory unless File.exist?(@directory)
|
149
|
+
|
150
|
+
@model_path = File.join(@directory, "model")
|
134
151
|
|
135
|
-
@model_file = File.join(@directory, "model")
|
136
152
|
@extract_features_file = File.join(@directory, "features")
|
137
|
-
@
|
138
|
-
|
139
|
-
@
|
140
|
-
@
|
141
|
-
|
142
|
-
@
|
143
|
-
@
|
144
|
-
|
145
|
-
@
|
146
|
-
|
147
|
-
|
148
|
-
|
153
|
+
@init_model_path = File.join(@directory, "init_model")
|
154
|
+
|
155
|
+
@train_model_path = File.join(@directory, "train_model")
|
156
|
+
@train_model_path_R = File.join(@directory, "train_model.R")
|
157
|
+
|
158
|
+
@eval_model_path = File.join(@directory, "eval_model")
|
159
|
+
@eval_model_path_R = File.join(@directory, "eval_model.R")
|
160
|
+
|
161
|
+
@post_process_file = File.join(@directory, "post_process")
|
162
|
+
@post_process_file_R = File.join(@directory, "post_process.R")
|
163
|
+
|
164
|
+
@names_file = File.join(@directory, "feature_names")
|
165
|
+
@levels_file = File.join(@directory, "levels")
|
166
|
+
@options_file = File.join(@directory, "options.json")
|
167
|
+
|
168
|
+
if File.exist?(@options_file)
|
169
|
+
@model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
|
149
170
|
IndiferentHash.setup(@model_options)
|
150
171
|
end
|
151
172
|
end
|
152
|
-
|
173
|
+
|
153
174
|
if extract_features.nil?
|
154
|
-
if @extract_features_file && File.
|
175
|
+
if @extract_features_file && File.exist?(@extract_features_file)
|
155
176
|
@extract_features = __load_method @extract_features_file
|
156
177
|
end
|
157
178
|
else
|
158
179
|
@extract_features = extract_features
|
159
180
|
end
|
160
181
|
|
182
|
+
if init_model.nil?
|
183
|
+
if @init_model_path && File.exist?(@init_model_path)
|
184
|
+
@init_model = __load_method @init_model_path
|
185
|
+
end
|
186
|
+
else
|
187
|
+
@init_model = init_model
|
188
|
+
end
|
189
|
+
|
161
190
|
if train_model.nil?
|
162
|
-
if @
|
163
|
-
@train_model = __load_method @
|
164
|
-
elsif @
|
165
|
-
@train_model = Open.read(@
|
191
|
+
if @train_model_path && File.exist?(@train_model_path)
|
192
|
+
@train_model = __load_method @train_model_path
|
193
|
+
elsif @train_model_path_R && File.exist?(@train_model_path_R)
|
194
|
+
@train_model = Open.read(@train_model_path_R)
|
166
195
|
end
|
167
196
|
else
|
168
197
|
@train_model = train_model
|
169
198
|
end
|
170
199
|
|
171
200
|
if eval_model.nil?
|
172
|
-
if @
|
173
|
-
@eval_model = __load_method @
|
174
|
-
elsif @
|
175
|
-
@eval_model = Open.read(@
|
201
|
+
if @eval_model_path && File.exist?(@eval_model_path)
|
202
|
+
@eval_model = __load_method @eval_model_path
|
203
|
+
elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
|
204
|
+
@eval_model = Open.read(@eval_model_path_R)
|
176
205
|
end
|
177
206
|
else
|
178
207
|
@eval_model = eval_model
|
179
208
|
end
|
180
209
|
|
181
210
|
if post_process.nil?
|
182
|
-
if @post_process_file && File.
|
211
|
+
if @post_process_file && File.exist?(@post_process_file)
|
183
212
|
@post_process = __load_method @post_process_file
|
184
|
-
elsif @post_process_file_R && File.
|
213
|
+
elsif @post_process_file_R && File.exist?(@post_process_file_R)
|
185
214
|
@post_process = Open.read(@post_process_file_R)
|
186
215
|
end
|
187
216
|
else
|
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
190
219
|
|
191
220
|
|
192
221
|
if names.nil?
|
193
|
-
if @names_file && File.
|
222
|
+
if @names_file && File.exist?(@names_file)
|
194
223
|
@names = Open.read(@names_file).split("\n")
|
195
224
|
end
|
196
225
|
else
|
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
198
227
|
end
|
199
228
|
|
200
229
|
if factor_levels.nil?
|
201
|
-
if @levels_file && File.
|
230
|
+
if @levels_file && File.exist?(@levels_file)
|
202
231
|
@factor_levels = YAML.load(Open.read(@levels_file))
|
203
232
|
end
|
204
|
-
if @
|
205
|
-
@factor_levels = TSV.open(@
|
233
|
+
if @model_path && File.exist?(@model_path + '.factor_levels')
|
234
|
+
@factor_levels = TSV.open(@model_path + '.factor_levels')
|
206
235
|
end
|
207
236
|
else
|
208
237
|
@factor_levels = factor_levels
|
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
241
270
|
case
|
242
271
|
when Proc === train_model
|
243
272
|
begin
|
244
|
-
Open.write(@
|
273
|
+
Open.write(@train_model_path, train_model.source)
|
245
274
|
rescue
|
246
275
|
end
|
247
276
|
when String === train_model
|
248
|
-
Open.write(@
|
277
|
+
Open.write(@train_model_path_R, @train_model)
|
249
278
|
end
|
250
279
|
|
251
280
|
Open.write(@extract_features_file, @extract_features.source) if @extract_features
|
281
|
+
Open.write(@init_model_path, @init_model.source) if @init_model
|
252
282
|
|
253
283
|
case
|
254
284
|
when Proc === eval_model
|
255
285
|
begin
|
256
|
-
Open.write(@
|
286
|
+
Open.write(@eval_model_path, eval_model.source)
|
257
287
|
rescue
|
258
288
|
end
|
259
289
|
when String === eval_model
|
260
|
-
Open.write(@
|
290
|
+
Open.write(@eval_model_path_R, eval_model)
|
261
291
|
end
|
262
292
|
|
263
293
|
case
|
@@ -270,24 +300,37 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
270
300
|
Open.write(@post_process_file_R, post_process)
|
271
301
|
end
|
272
302
|
|
273
|
-
|
274
303
|
Open.write(@levels_file, @factor_levels.to_yaml) if @factor_levels
|
275
304
|
Open.write(@names_file, @names * "\n" + "\n") if @names
|
276
305
|
Open.write(@options_file, @model_options.to_json) if @model_options
|
277
306
|
end
|
278
307
|
|
279
308
|
def train
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
309
|
+
begin
|
310
|
+
if @balance
|
311
|
+
@original_features = @features
|
312
|
+
@original_labels = @labels
|
313
|
+
self.balance_labels
|
314
|
+
end
|
315
|
+
|
316
|
+
case
|
317
|
+
when Proc === @train_model
|
318
|
+
self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
|
319
|
+
when String === @train_model
|
320
|
+
VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
|
321
|
+
end
|
322
|
+
ensure
|
323
|
+
if @balance
|
324
|
+
@features = @original_features
|
325
|
+
@labels = @original_labels
|
326
|
+
end
|
285
327
|
end
|
328
|
+
|
286
329
|
save_models if @directory
|
287
330
|
end
|
288
331
|
|
289
332
|
def run(code)
|
290
|
-
VectorModel.R_run(@
|
333
|
+
VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
|
291
334
|
end
|
292
335
|
|
293
336
|
def eval(element)
|
@@ -295,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
295
338
|
|
296
339
|
result = case
|
297
340
|
when Proc === @eval_model
|
298
|
-
self.instance_exec(
|
341
|
+
self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
|
299
342
|
when String === @eval_model
|
300
|
-
VectorModel.R_eval(@
|
343
|
+
VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
|
301
344
|
else
|
302
345
|
raise "No @eval_model function or R script"
|
303
346
|
end
|
304
347
|
|
305
|
-
result = self.instance_exec(result, &@post_process) if Proc === @post_process
|
348
|
+
result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
|
306
349
|
|
307
350
|
result
|
308
351
|
end
|
@@ -321,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
321
364
|
|
322
365
|
result = case
|
323
366
|
when Proc === eval_model
|
324
|
-
self.instance_exec(
|
367
|
+
self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
|
325
368
|
when String === eval_model
|
326
|
-
VectorModel.R_eval(@
|
369
|
+
VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
|
327
370
|
end
|
328
371
|
|
329
|
-
result = self.instance_exec(result, &@post_process) if Proc === @post_process
|
372
|
+
result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
|
330
373
|
|
331
374
|
result
|
332
375
|
end
|
@@ -438,6 +481,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
438
481
|
@features = orig_features
|
439
482
|
@labels = orig_labels
|
440
483
|
end unless folds == -1
|
484
|
+
|
441
485
|
self.reset_model if self.respond_to? :reset_model
|
442
486
|
self.train unless folds == 1
|
443
487
|
res
|
data/python/rbbt_dm/__init__.py
CHANGED
@@ -1 +1,48 @@
|
|
1
|
-
|
1
|
+
from torch.utils.data import Dataset, DataLoader
|
2
|
+
|
3
|
+
class TSVDataset(Dataset):
|
4
|
+
def __init__(self, tsv):
|
5
|
+
self.tsv = tsv
|
6
|
+
|
7
|
+
def __getitem__(self, key):
|
8
|
+
if (type(key) == int):
|
9
|
+
row = self.tsv.iloc[key]
|
10
|
+
else:
|
11
|
+
row = self.tsv.loc[key]
|
12
|
+
|
13
|
+
row = row.to_numpy()
|
14
|
+
features = row[:-1]
|
15
|
+
label = row[-1]
|
16
|
+
|
17
|
+
return features, label
|
18
|
+
|
19
|
+
def __len__(self):
|
20
|
+
return len(self.tsv)
|
21
|
+
|
22
|
+
def tsv_dataset(filename, *args, **kwargs):
|
23
|
+
import rbbt
|
24
|
+
return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
|
25
|
+
|
26
|
+
def tsv(*args, **kwargs):
|
27
|
+
return tsv_dataset(*args, **kwargs)
|
28
|
+
|
29
|
+
def data_dir():
|
30
|
+
import rbbt
|
31
|
+
return rbbt.path('var/rbbt_dm/data')
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
import rbbt
|
35
|
+
|
36
|
+
filename = "/home/miki/test/numeric.tsv"
|
37
|
+
ds = tsv(filename)
|
38
|
+
|
39
|
+
dl = DataLoader(ds, batch_size=1)
|
40
|
+
|
41
|
+
for f, l in iter(dl):
|
42
|
+
print(".")
|
43
|
+
print(f[0,:])
|
44
|
+
print(l[0])
|
45
|
+
|
46
|
+
|
47
|
+
|
48
|
+
|
data/python/rbbt_dm/atcold/__init__.py
File without changes
|
data/python/rbbt_dm/atcold/plot_lib.py
@@ -0,0 +1,141 @@
|
|
1
|
+
from matplotlib import pyplot as plt
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
from IPython.display import HTML, display
|
5
|
+
|
6
|
+
|
7
|
+
def set_default(figsize=(10, 10), dpi=100):
|
8
|
+
plt.style.use(['dark_background', 'bmh'])
|
9
|
+
plt.rc('axes', facecolor='k')
|
10
|
+
plt.rc('figure', facecolor='k')
|
11
|
+
plt.rc('figure', figsize=figsize, dpi=dpi)
|
12
|
+
|
13
|
+
|
14
|
+
def plot_data(X, y, d=0, auto=False, zoom=1):
|
15
|
+
X = X.cpu()
|
16
|
+
y = y.cpu()
|
17
|
+
plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
|
18
|
+
plt.axis('square')
|
19
|
+
plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
|
20
|
+
if auto is True: plt.axis('equal')
|
21
|
+
plt.axis('off')
|
22
|
+
|
23
|
+
_m, _c = 0, '.15'
|
24
|
+
plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
|
25
|
+
plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
|
26
|
+
|
27
|
+
|
28
|
+
def plot_model(X, y, model):
|
29
|
+
model.cpu()
|
30
|
+
mesh = np.arange(-1.1, 1.1, 0.01)
|
31
|
+
xx, yy = np.meshgrid(mesh, mesh)
|
32
|
+
with torch.no_grad():
|
33
|
+
data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
|
34
|
+
Z = model(data).detach()
|
35
|
+
Z = np.argmax(Z, axis=1).reshape(xx.shape)
|
36
|
+
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
|
37
|
+
plot_data(X, y)
|
38
|
+
|
39
|
+
|
40
|
+
def show_scatterplot(X, colors, title=''):
|
41
|
+
colors = colors.cpu().numpy()
|
42
|
+
X = X.cpu().numpy()
|
43
|
+
plt.figure()
|
44
|
+
plt.axis('equal')
|
45
|
+
plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
|
46
|
+
# plt.grid(True)
|
47
|
+
plt.title(title)
|
48
|
+
plt.axis('off')
|
49
|
+
|
50
|
+
|
51
|
+
def plot_bases(bases, width=0.04):
|
52
|
+
bases = bases.cpu()
|
53
|
+
bases[2:] -= bases[:2]
|
54
|
+
plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
|
55
|
+
plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
|
56
|
+
|
57
|
+
|
58
|
+
def show_mat(mat, vect, prod, threshold=-1):
|
59
|
+
# Subplot grid definition
|
60
|
+
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
|
61
|
+
gridspec_kw={'width_ratios':[5,1,1]})
|
62
|
+
# Plot matrices
|
63
|
+
cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
|
64
|
+
ax2.matshow(vect.numpy(), clim=(-1, 1))
|
65
|
+
cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
|
66
|
+
|
67
|
+
# Set titles
|
68
|
+
ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
|
69
|
+
ax2.set_title(f'a^(i): {vect.numel()}')
|
70
|
+
ax3.set_title(f'p: {prod.numel()}')
|
71
|
+
|
72
|
+
# Remove xticks for vectors
|
73
|
+
ax2.set_xticks(tuple())
|
74
|
+
ax3.set_xticks(tuple())
|
75
|
+
|
76
|
+
# Plot colourbars
|
77
|
+
fig.colorbar(cax1, ax=ax2)
|
78
|
+
fig.colorbar(cax3, ax=ax3)
|
79
|
+
|
80
|
+
# Fix y-axis limits
|
81
|
+
ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
|
82
|
+
|
83
|
+
|
84
|
+
colors = dict(
|
85
|
+
aqua='#8dd3c7',
|
86
|
+
yellow='#ffffb3',
|
87
|
+
lavender='#bebada',
|
88
|
+
red='#fb8072',
|
89
|
+
blue='#80b1d3',
|
90
|
+
orange='#fdb462',
|
91
|
+
green='#b3de69',
|
92
|
+
pink='#fccde5',
|
93
|
+
grey='#d9d9d9',
|
94
|
+
violet='#bc80bd',
|
95
|
+
unk1='#ccebc5',
|
96
|
+
unk2='#ffed6f',
|
97
|
+
)
|
98
|
+
|
99
|
+
|
100
|
+
def _cstr(s, color='black'):
|
101
|
+
if s == ' ':
|
102
|
+
return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
|
103
|
+
else:
|
104
|
+
return f'<text style=color:#000;background-color:{color}>{s} </text>'
|
105
|
+
|
106
|
+
# print html
|
107
|
+
def _print_color(t):
|
108
|
+
display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
|
109
|
+
|
110
|
+
# get appropriate color for value
|
111
|
+
def _get_clr(value):
|
112
|
+
colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
|
113
|
+
'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
|
114
|
+
'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
|
115
|
+
'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
|
116
|
+
value = int((value * 100) / 5)
|
117
|
+
if value == len(colors): value -= 1 # fixing bugs...
|
118
|
+
return colors[value]
|
119
|
+
|
120
|
+
def _visualise_values(output_values, result_list):
|
121
|
+
text_colours = []
|
122
|
+
for i in range(len(output_values)):
|
123
|
+
text = (result_list[i], _get_clr(output_values[i]))
|
124
|
+
text_colours.append(text)
|
125
|
+
_print_color(text_colours)
|
126
|
+
|
127
|
+
def print_colourbar():
|
128
|
+
color_range = torch.linspace(-2.5, 2.5, 20)
|
129
|
+
to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
|
130
|
+
_print_color(to_print)
|
131
|
+
|
132
|
+
|
133
|
+
# Let's only focus on the last time step for now
|
134
|
+
# First, the cell state (Long term memory)
|
135
|
+
def plot_state(data, state, b, decoder):
|
136
|
+
actual_data = decoder(data[b, :, :].numpy())
|
137
|
+
seq_len = len(actual_data)
|
138
|
+
seq_len_w_pad = len(state)
|
139
|
+
for s in range(state.size(2)):
|
140
|
+
states = torch.sigmoid(state[:, b, s])
|
141
|
+
_visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
|
data/python/rbbt_dm/atcold/spiral.py
@@ -0,0 +1,27 @@
|
|
1
|
+
import torch
|
2
|
+
import math
|
3
|
+
def spiral_data(N=1000, D=2, C=3):
|
4
|
+
X = torch.zeros(N * C, D)
|
5
|
+
y = torch.zeros(N * C, dtype=torch.long)
|
6
|
+
for c in range(C):
|
7
|
+
index = 0
|
8
|
+
t = torch.linspace(0, 1, N)
|
9
|
+
# When c = 0 and t = 0: start of linspace
|
10
|
+
# When c = 0 and t = 1: end of linpace
|
11
|
+
# This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
|
12
|
+
inner_var = torch.linspace(
|
13
|
+
# When t = 0
|
14
|
+
(2 * math.pi / C) * (c),
|
15
|
+
# When t = 1
|
16
|
+
(2 * math.pi / C) * (2 + c),
|
17
|
+
N
|
18
|
+
) + torch.randn(N) * 0.2
|
19
|
+
|
20
|
+
for ix in range(N * c, N * (c + 1)):
|
21
|
+
X[ix] = t[index] * torch.FloatTensor((
|
22
|
+
math.sin(inner_var[index]), math.cos(inner_var[index])
|
23
|
+
))
|
24
|
+
y[ix] = c
|
25
|
+
index += 1
|
26
|
+
|
27
|
+
return (X, y)
|