rbbt-dm 1.2.7 → 1.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/lib/rbbt/matrix/barcode.rb +2 -2
- data/lib/rbbt/matrix/differential.rb +3 -3
- data/lib/rbbt/matrix/knowledge_base.rb +1 -1
- data/lib/rbbt/plots/bar.rb +1 -1
- data/lib/rbbt/stan.rb +1 -1
- data/lib/rbbt/statistics/hypergeometric.rb +2 -1
- data/lib/rbbt/vector/model/huggingface/masked_lm.rb +50 -0
- data/lib/rbbt/vector/model/huggingface.rb +57 -38
- data/lib/rbbt/vector/model/pytorch_lightning.rb +35 -0
- data/lib/rbbt/vector/model/random_forest.rb +1 -1
- data/lib/rbbt/vector/model/spaCy.rb +8 -6
- data/lib/rbbt/vector/model/tensorflow.rb +6 -5
- data/lib/rbbt/vector/model/torch.rb +37 -0
- data/lib/rbbt/vector/model.rb +82 -52
- data/python/rbbt_dm/__init__.py +48 -1
- data/python/rbbt_dm/atcold/__init__.py +0 -0
- data/python/rbbt_dm/atcold/plot_lib.py +141 -0
- data/python/rbbt_dm/atcold/spiral.py +27 -0
- data/python/rbbt_dm/huggingface.py +57 -26
- data/python/rbbt_dm/language_model.py +70 -0
- data/python/rbbt_dm/util.py +30 -0
- data/share/spaCy/gpu/textcat_accuracy.conf +2 -1
- data/test/rbbt/vector/model/huggingface/test_masked_lm.rb +41 -0
- data/test/rbbt/vector/model/test_huggingface.rb +258 -27
- data/test/rbbt/vector/model/test_pytorch_lightning.rb +83 -0
- data/test/rbbt/vector/model/test_spaCy.rb +1 -1
- data/test/rbbt/vector/model/test_tensorflow.rb +3 -0
- data/test/rbbt/vector/test_model.rb +25 -26
- data/test/test_helper.rb +13 -0
- metadata +26 -16
- data/lib/rbbt/tensorflow.rb +0 -43
- data/lib/rbbt/vector/model/huggingface.old.rb +0 -160
data/lib/rbbt/vector/model.rb
CHANGED
@@ -1,16 +1,25 @@
|
|
1
1
|
require 'rbbt/util/R'
|
2
2
|
require 'rbbt/vector/model/util'
|
3
|
+
require 'rbbt/util/python'
|
4
|
+
|
5
|
+
RbbtPython.add_path Rbbt.python.find(:lib)
|
6
|
+
RbbtPython.init_rbbt
|
3
7
|
|
4
8
|
class VectorModel
|
5
|
-
attr_accessor :directory, :
|
9
|
+
attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
|
6
10
|
attr_accessor :features, :names, :labels, :factor_levels
|
7
|
-
attr_accessor :model_options
|
11
|
+
attr_accessor :model, :model_options
|
8
12
|
|
9
13
|
def extract_features(&block)
|
10
14
|
@extract_features = block if block_given?
|
11
15
|
@extract_features
|
12
16
|
end
|
13
17
|
|
18
|
+
def init_model(&block)
|
19
|
+
@init_model = block if block_given?
|
20
|
+
@init_model
|
21
|
+
end
|
22
|
+
|
14
23
|
def train_model(&block)
|
15
24
|
@train_model = block if block_given?
|
16
25
|
@train_model
|
@@ -21,13 +30,17 @@ class VectorModel
|
|
21
30
|
@eval_model
|
22
31
|
end
|
23
32
|
|
33
|
+
def init
|
34
|
+
@model ||= self.instance_exec &@init_model
|
35
|
+
end
|
36
|
+
|
24
37
|
def post_process(&block)
|
25
38
|
@post_process = block if block_given?
|
26
39
|
@post_process
|
27
40
|
end
|
28
41
|
|
29
42
|
|
30
|
-
def self.R_run(
|
43
|
+
def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
|
31
44
|
TmpFile.with_file do |feature_file|
|
32
45
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
33
46
|
Open.write(feature_file + '.label', labels * "\n" + "\n")
|
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
|
|
54
67
|
end
|
55
68
|
end
|
56
69
|
|
57
|
-
def self.R_train(
|
70
|
+
def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
|
58
71
|
TmpFile.with_file do |feature_file|
|
59
72
|
Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
|
60
73
|
Open.write(feature_file + '.label', labels * "\n" + "\n")
|
@@ -82,13 +95,13 @@ for (c in names(features)){
|
|
82
95
|
if (is.factor(features[[c]]))
|
83
96
|
factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
|
84
97
|
}
|
85
|
-
rbbt.tsv.write("#{
|
86
|
-
save(model, file='#{
|
98
|
+
rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
|
99
|
+
save(model, file='#{model_path}')
|
87
100
|
EOF
|
88
101
|
end
|
89
102
|
end
|
90
103
|
|
91
|
-
def self.R_eval(
|
104
|
+
def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
|
92
105
|
TmpFile.with_file do |feature_file|
|
93
106
|
if list
|
94
107
|
Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
|
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
|
|
105
118
|
#{ factor_levels.collect do |name,levels|
|
106
119
|
"features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
|
107
120
|
end * "\n" if factor_levels }
|
108
|
-
load(file="#{
|
121
|
+
load(file="#{model_path}");
|
109
122
|
#{code}
|
110
123
|
cat(paste(label, sep="\\n", collapse="\\n"));
|
111
124
|
EOF
|
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
127
140
|
instance_eval code, file
|
128
141
|
end
|
129
142
|
|
130
|
-
def initialize(directory = nil,
|
143
|
+
def initialize(directory = nil, model_options = {})
|
131
144
|
@directory = directory
|
145
|
+
@model_options = IndiferentHash.setup(model_options)
|
146
|
+
|
132
147
|
if @directory
|
133
|
-
FileUtils.mkdir_p @directory unless File.
|
148
|
+
FileUtils.mkdir_p @directory unless File.exist?(@directory)
|
149
|
+
|
150
|
+
@model_path = File.join(@directory, "model")
|
134
151
|
|
135
|
-
@model_file = File.join(@directory, "model")
|
136
152
|
@extract_features_file = File.join(@directory, "features")
|
137
|
-
@
|
138
|
-
|
139
|
-
@
|
140
|
-
@
|
141
|
-
|
142
|
-
@
|
143
|
-
@
|
144
|
-
|
145
|
-
@
|
146
|
-
|
147
|
-
|
148
|
-
|
153
|
+
@init_model_path = File.join(@directory, "init_model")
|
154
|
+
|
155
|
+
@train_model_path = File.join(@directory, "train_model")
|
156
|
+
@train_model_path_R = File.join(@directory, "train_model.R")
|
157
|
+
|
158
|
+
@eval_model_path = File.join(@directory, "eval_model")
|
159
|
+
@eval_model_path_R = File.join(@directory, "eval_model.R")
|
160
|
+
|
161
|
+
@post_process_file = File.join(@directory, "post_process")
|
162
|
+
@post_process_file_R = File.join(@directory, "post_process.R")
|
163
|
+
|
164
|
+
@names_file = File.join(@directory, "feature_names")
|
165
|
+
@levels_file = File.join(@directory, "levels")
|
166
|
+
@options_file = File.join(@directory, "options.json")
|
167
|
+
|
168
|
+
if File.exist?(@options_file)
|
169
|
+
@model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
|
149
170
|
IndiferentHash.setup(@model_options)
|
150
171
|
end
|
151
172
|
end
|
152
|
-
|
173
|
+
|
153
174
|
if extract_features.nil?
|
154
|
-
if @extract_features_file && File.
|
175
|
+
if @extract_features_file && File.exist?(@extract_features_file)
|
155
176
|
@extract_features = __load_method @extract_features_file
|
156
177
|
end
|
157
178
|
else
|
158
179
|
@extract_features = extract_features
|
159
180
|
end
|
160
181
|
|
182
|
+
if init_model.nil?
|
183
|
+
if @init_model_path && File.exist?(@init_model_path)
|
184
|
+
@init_model = __load_method @init_model_path
|
185
|
+
end
|
186
|
+
else
|
187
|
+
@init_model = init_model
|
188
|
+
end
|
189
|
+
|
161
190
|
if train_model.nil?
|
162
|
-
if @
|
163
|
-
@train_model = __load_method @
|
164
|
-
elsif @
|
165
|
-
@train_model = Open.read(@
|
191
|
+
if @train_model_path && File.exist?(@train_model_path)
|
192
|
+
@train_model = __load_method @train_model_path
|
193
|
+
elsif @train_model_path_R && File.exist?(@train_model_path_R)
|
194
|
+
@train_model = Open.read(@train_model_path_R)
|
166
195
|
end
|
167
196
|
else
|
168
197
|
@train_model = train_model
|
169
198
|
end
|
170
199
|
|
171
200
|
if eval_model.nil?
|
172
|
-
if @
|
173
|
-
@eval_model = __load_method @
|
174
|
-
elsif @
|
175
|
-
@eval_model = Open.read(@
|
201
|
+
if @eval_model_path && File.exist?(@eval_model_path)
|
202
|
+
@eval_model = __load_method @eval_model_path
|
203
|
+
elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
|
204
|
+
@eval_model = Open.read(@eval_model_path_R)
|
176
205
|
end
|
177
206
|
else
|
178
207
|
@eval_model = eval_model
|
179
208
|
end
|
180
209
|
|
181
210
|
if post_process.nil?
|
182
|
-
if @post_process_file && File.
|
211
|
+
if @post_process_file && File.exist?(@post_process_file)
|
183
212
|
@post_process = __load_method @post_process_file
|
184
|
-
elsif @post_process_file_R && File.
|
213
|
+
elsif @post_process_file_R && File.exist?(@post_process_file_R)
|
185
214
|
@post_process = Open.read(@post_process_file_R)
|
186
215
|
end
|
187
216
|
else
|
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
190
219
|
|
191
220
|
|
192
221
|
if names.nil?
|
193
|
-
if @names_file && File.
|
222
|
+
if @names_file && File.exist?(@names_file)
|
194
223
|
@names = Open.read(@names_file).split("\n")
|
195
224
|
end
|
196
225
|
else
|
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
198
227
|
end
|
199
228
|
|
200
229
|
if factor_levels.nil?
|
201
|
-
if @levels_file && File.
|
230
|
+
if @levels_file && File.exist?(@levels_file)
|
202
231
|
@factor_levels = YAML.load(Open.read(@levels_file))
|
203
232
|
end
|
204
|
-
if @
|
205
|
-
@factor_levels = TSV.open(@
|
233
|
+
if @model_path && File.exist?(@model_path + '.factor_levels')
|
234
|
+
@factor_levels = TSV.open(@model_path + '.factor_levels')
|
206
235
|
end
|
207
236
|
else
|
208
237
|
@factor_levels = factor_levels
|
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
241
270
|
case
|
242
271
|
when Proc === train_model
|
243
272
|
begin
|
244
|
-
Open.write(@
|
273
|
+
Open.write(@train_model_path, train_model.source)
|
245
274
|
rescue
|
246
275
|
end
|
247
276
|
when String === train_model
|
248
|
-
Open.write(@
|
277
|
+
Open.write(@train_model_path_R, @train_model)
|
249
278
|
end
|
250
279
|
|
251
280
|
Open.write(@extract_features_file, @extract_features.source) if @extract_features
|
281
|
+
Open.write(@init_model_path, @init_model.source) if @init_model
|
252
282
|
|
253
283
|
case
|
254
284
|
when Proc === eval_model
|
255
285
|
begin
|
256
|
-
Open.write(@
|
286
|
+
Open.write(@eval_model_path, eval_model.source)
|
257
287
|
rescue
|
258
288
|
end
|
259
289
|
when String === eval_model
|
260
|
-
Open.write(@
|
290
|
+
Open.write(@eval_model_path_R, eval_model)
|
261
291
|
end
|
262
292
|
|
263
293
|
case
|
@@ -285,9 +315,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
285
315
|
|
286
316
|
case
|
287
317
|
when Proc === @train_model
|
288
|
-
self.instance_exec(@
|
318
|
+
self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
|
289
319
|
when String === @train_model
|
290
|
-
VectorModel.R_train(@
|
320
|
+
VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
|
291
321
|
end
|
292
322
|
ensure
|
293
323
|
if @balance
|
@@ -300,7 +330,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
300
330
|
end
|
301
331
|
|
302
332
|
def run(code)
|
303
|
-
VectorModel.R_run(@
|
333
|
+
VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
|
304
334
|
end
|
305
335
|
|
306
336
|
def eval(element)
|
@@ -308,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
308
338
|
|
309
339
|
result = case
|
310
340
|
when Proc === @eval_model
|
311
|
-
self.instance_exec(
|
341
|
+
self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
|
312
342
|
when String === @eval_model
|
313
|
-
VectorModel.R_eval(@
|
343
|
+
VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
|
314
344
|
else
|
315
345
|
raise "No @eval_model function or R script"
|
316
346
|
end
|
317
347
|
|
318
|
-
result = self.instance_exec(result, &@post_process) if Proc === @post_process
|
348
|
+
result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
|
319
349
|
|
320
350
|
result
|
321
351
|
end
|
@@ -334,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
|
|
334
364
|
|
335
365
|
result = case
|
336
366
|
when Proc === eval_model
|
337
|
-
self.instance_exec(
|
367
|
+
self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
|
338
368
|
when String === eval_model
|
339
|
-
VectorModel.R_eval(@
|
369
|
+
VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
|
340
370
|
end
|
341
371
|
|
342
|
-
result = self.instance_exec(result, &@post_process) if Proc === @post_process
|
372
|
+
result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
|
343
373
|
|
344
374
|
result
|
345
375
|
end
|
data/python/rbbt_dm/__init__.py
CHANGED
@@ -1 +1,48 @@
|
|
1
|
-
|
1
|
+
from torch.utils.data import Dataset, DataLoader
|
2
|
+
|
3
|
+
class TSVDataset(Dataset):
|
4
|
+
def __init__(self, tsv):
|
5
|
+
self.tsv = tsv
|
6
|
+
|
7
|
+
def __getitem__(self, key):
|
8
|
+
if (type(key) == int):
|
9
|
+
row = self.tsv.iloc[key]
|
10
|
+
else:
|
11
|
+
row = self.tsv.loc[key]
|
12
|
+
|
13
|
+
row = row.to_numpy()
|
14
|
+
features = row[:-1]
|
15
|
+
label = row[-1]
|
16
|
+
|
17
|
+
return features, label
|
18
|
+
|
19
|
+
def __len__(self):
|
20
|
+
return len(self.tsv)
|
21
|
+
|
22
|
+
def tsv_dataset(filename, *args, **kwargs):
|
23
|
+
import rbbt
|
24
|
+
return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
|
25
|
+
|
26
|
+
def tsv(*args, **kwargs):
|
27
|
+
return tsv_dataset(*args, **kwargs)
|
28
|
+
|
29
|
+
def data_dir():
|
30
|
+
import rbbt
|
31
|
+
return rbbt.path('var/rbbt_dm/data')
|
32
|
+
|
33
|
+
if __name__ == "__main__":
|
34
|
+
import rbbt
|
35
|
+
|
36
|
+
filename = "/home/miki/test/numeric.tsv"
|
37
|
+
ds = tsv(filename)
|
38
|
+
|
39
|
+
dl = DataLoader(ds, batch_size=1)
|
40
|
+
|
41
|
+
for f, l in iter(dl):
|
42
|
+
print(".")
|
43
|
+
print(f[0,:])
|
44
|
+
print(l[0])
|
45
|
+
|
46
|
+
|
47
|
+
|
48
|
+
|
File without changes
|
@@ -0,0 +1,141 @@
|
|
1
|
+
from matplotlib import pyplot as plt
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
from IPython.display import HTML, display
|
5
|
+
|
6
|
+
|
7
|
+
def set_default(figsize=(10, 10), dpi=100):
|
8
|
+
plt.style.use(['dark_background', 'bmh'])
|
9
|
+
plt.rc('axes', facecolor='k')
|
10
|
+
plt.rc('figure', facecolor='k')
|
11
|
+
plt.rc('figure', figsize=figsize, dpi=dpi)
|
12
|
+
|
13
|
+
|
14
|
+
def plot_data(X, y, d=0, auto=False, zoom=1):
|
15
|
+
X = X.cpu()
|
16
|
+
y = y.cpu()
|
17
|
+
plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
|
18
|
+
plt.axis('square')
|
19
|
+
plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
|
20
|
+
if auto is True: plt.axis('equal')
|
21
|
+
plt.axis('off')
|
22
|
+
|
23
|
+
_m, _c = 0, '.15'
|
24
|
+
plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
|
25
|
+
plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
|
26
|
+
|
27
|
+
|
28
|
+
def plot_model(X, y, model):
|
29
|
+
model.cpu()
|
30
|
+
mesh = np.arange(-1.1, 1.1, 0.01)
|
31
|
+
xx, yy = np.meshgrid(mesh, mesh)
|
32
|
+
with torch.no_grad():
|
33
|
+
data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
|
34
|
+
Z = model(data).detach()
|
35
|
+
Z = np.argmax(Z, axis=1).reshape(xx.shape)
|
36
|
+
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
|
37
|
+
plot_data(X, y)
|
38
|
+
|
39
|
+
|
40
|
+
def show_scatterplot(X, colors, title=''):
|
41
|
+
colors = colors.cpu().numpy()
|
42
|
+
X = X.cpu().numpy()
|
43
|
+
plt.figure()
|
44
|
+
plt.axis('equal')
|
45
|
+
plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
|
46
|
+
# plt.grid(True)
|
47
|
+
plt.title(title)
|
48
|
+
plt.axis('off')
|
49
|
+
|
50
|
+
|
51
|
+
def plot_bases(bases, width=0.04):
|
52
|
+
bases = bases.cpu()
|
53
|
+
bases[2:] -= bases[:2]
|
54
|
+
plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
|
55
|
+
plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
|
56
|
+
|
57
|
+
|
58
|
+
def show_mat(mat, vect, prod, threshold=-1):
|
59
|
+
# Subplot grid definition
|
60
|
+
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
|
61
|
+
gridspec_kw={'width_ratios':[5,1,1]})
|
62
|
+
# Plot matrices
|
63
|
+
cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
|
64
|
+
ax2.matshow(vect.numpy(), clim=(-1, 1))
|
65
|
+
cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
|
66
|
+
|
67
|
+
# Set titles
|
68
|
+
ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
|
69
|
+
ax2.set_title(f'a^(i): {vect.numel()}')
|
70
|
+
ax3.set_title(f'p: {prod.numel()}')
|
71
|
+
|
72
|
+
# Remove xticks for vectors
|
73
|
+
ax2.set_xticks(tuple())
|
74
|
+
ax3.set_xticks(tuple())
|
75
|
+
|
76
|
+
# Plot colourbars
|
77
|
+
fig.colorbar(cax1, ax=ax2)
|
78
|
+
fig.colorbar(cax3, ax=ax3)
|
79
|
+
|
80
|
+
# Fix y-axis limits
|
81
|
+
ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
|
82
|
+
|
83
|
+
|
84
|
+
colors = dict(
|
85
|
+
aqua='#8dd3c7',
|
86
|
+
yellow='#ffffb3',
|
87
|
+
lavender='#bebada',
|
88
|
+
red='#fb8072',
|
89
|
+
blue='#80b1d3',
|
90
|
+
orange='#fdb462',
|
91
|
+
green='#b3de69',
|
92
|
+
pink='#fccde5',
|
93
|
+
grey='#d9d9d9',
|
94
|
+
violet='#bc80bd',
|
95
|
+
unk1='#ccebc5',
|
96
|
+
unk2='#ffed6f',
|
97
|
+
)
|
98
|
+
|
99
|
+
|
100
|
+
def _cstr(s, color='black'):
|
101
|
+
if s == ' ':
|
102
|
+
return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
|
103
|
+
else:
|
104
|
+
return f'<text style=color:#000;background-color:{color}>{s} </text>'
|
105
|
+
|
106
|
+
# print html
|
107
|
+
def _print_color(t):
|
108
|
+
display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
|
109
|
+
|
110
|
+
# get appropriate color for value
|
111
|
+
def _get_clr(value):
|
112
|
+
colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
|
113
|
+
'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
|
114
|
+
'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
|
115
|
+
'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
|
116
|
+
value = int((value * 100) / 5)
|
117
|
+
if value == len(colors): value -= 1 # fixing bugs...
|
118
|
+
return colors[value]
|
119
|
+
|
120
|
+
def _visualise_values(output_values, result_list):
|
121
|
+
text_colours = []
|
122
|
+
for i in range(len(output_values)):
|
123
|
+
text = (result_list[i], _get_clr(output_values[i]))
|
124
|
+
text_colours.append(text)
|
125
|
+
_print_color(text_colours)
|
126
|
+
|
127
|
+
def print_colourbar():
|
128
|
+
color_range = torch.linspace(-2.5, 2.5, 20)
|
129
|
+
to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
|
130
|
+
_print_color(to_print)
|
131
|
+
|
132
|
+
|
133
|
+
# Let's only focus on the last time step for now
|
134
|
+
# First, the cell state (Long term memory)
|
135
|
+
def plot_state(data, state, b, decoder):
|
136
|
+
actual_data = decoder(data[b, :, :].numpy())
|
137
|
+
seq_len = len(actual_data)
|
138
|
+
seq_len_w_pad = len(state)
|
139
|
+
for s in range(state.size(2)):
|
140
|
+
states = torch.sigmoid(state[:, b, s])
|
141
|
+
_visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
|
@@ -0,0 +1,27 @@
|
|
1
|
+
import torch
|
2
|
+
import math
|
3
|
+
def spiral_data(N=1000, D=2, C=3):
|
4
|
+
X = torch.zeros(N * C, D)
|
5
|
+
y = torch.zeros(N * C, dtype=torch.long)
|
6
|
+
for c in range(C):
|
7
|
+
index = 0
|
8
|
+
t = torch.linspace(0, 1, N)
|
9
|
+
# When c = 0 and t = 0: start of linspace
|
10
|
+
# When c = 0 and t = 1: end of linpace
|
11
|
+
# This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
|
12
|
+
inner_var = torch.linspace(
|
13
|
+
# When t = 0
|
14
|
+
(2 * math.pi / C) * (c),
|
15
|
+
# When t = 1
|
16
|
+
(2 * math.pi / C) * (2 + c),
|
17
|
+
N
|
18
|
+
) + torch.randn(N) * 0.2
|
19
|
+
|
20
|
+
for ix in range(N * c, N * (c + 1)):
|
21
|
+
X[ix] = t[index] * torch.FloatTensor((
|
22
|
+
math.sin(inner_var[index]), math.cos(inner_var[index])
|
23
|
+
))
|
24
|
+
y[ix] = c
|
25
|
+
index += 1
|
26
|
+
|
27
|
+
return (X, y)
|
@@ -1,32 +1,41 @@
|
|
1
1
|
#{{{ LOAD MODEL
|
2
2
|
|
3
3
|
def import_module_class(module, class_name):
|
4
|
-
|
4
|
+
if (not module == None):
|
5
|
+
exec(f"from {module} import {class_name}")
|
5
6
|
return eval(class_name)
|
6
7
|
|
7
|
-
def load_model(task, checkpoint):
|
8
|
-
|
9
|
-
|
8
|
+
def load_model(task, checkpoint, **kwargs):
|
9
|
+
if (":" in task):
|
10
|
+
module, class_name = task.split(":")
|
11
|
+
if (task == None):
|
12
|
+
module, class_name = None, module
|
13
|
+
return import_module_class(module, class_name).from_pretrained(checkpoint, **kwargs)
|
14
|
+
else:
|
15
|
+
class_name = 'AutoModelFor' + task
|
16
|
+
return import_module_class('transformers', class_name).from_pretrained(checkpoint)
|
10
17
|
|
11
|
-
def load_tokenizer(task, checkpoint):
|
18
|
+
def load_tokenizer(task, checkpoint, **kwargs):
|
12
19
|
class_name = 'AutoTokenizer'
|
13
|
-
return import_module_class('transformers', class_name).from_pretrained(checkpoint)
|
20
|
+
return import_module_class('transformers', class_name).from_pretrained(checkpoint, **kwargs)
|
14
21
|
|
15
22
|
def load_model_and_tokenizer(task, checkpoint):
|
16
23
|
model = load_model(task, checkpoint)
|
17
24
|
tokenizer = load_tokenizer(task, checkpoint)
|
18
25
|
return model, tokenizer
|
19
26
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
f
|
27
|
-
|
28
|
-
|
29
|
-
|
27
|
+
# Not used
|
28
|
+
|
29
|
+
#def load_model_and_tokenizer_from_directory(directory):
|
30
|
+
# import os
|
31
|
+
# import json
|
32
|
+
# options_file = os.path.join(directory, 'options.json')
|
33
|
+
# f = open(options_file, "r")
|
34
|
+
# options = json.load(f.read())
|
35
|
+
# f.close()
|
36
|
+
# task = options["task"]
|
37
|
+
# checkpoint = options["checkpoint"]
|
38
|
+
# return load_model_and_tokenizer(task, checkpoint)
|
30
39
|
|
31
40
|
#{{{ SIMPLE EVALUATE
|
32
41
|
|
@@ -51,21 +60,36 @@ def load_tsv(tsv_file):
|
|
51
60
|
from datasets import load_dataset
|
52
61
|
return load_dataset('csv', data_files=[tsv_file], sep="\t")
|
53
62
|
|
63
|
+
def load_json(json_file):
|
64
|
+
from datasets import load_dataset
|
65
|
+
return load_dataset('json', data_files=[json_file])
|
66
|
+
|
67
|
+
def tokenize_dataset(tokenizer, dataset):
|
68
|
+
return dataset.map(lambda subset: subset if ("input_ids" in subset.keys()) else tokenizer(subset["text"], truncation=True), batched=True)
|
69
|
+
|
54
70
|
def tsv_dataset(tokenizer, tsv_file):
|
55
71
|
dataset = load_tsv(tsv_file)
|
56
|
-
|
57
|
-
|
72
|
+
return tokenize_dataset(tokenizer, dataset)
|
73
|
+
|
74
|
+
def json_dataset(tokenizer, json_file):
|
75
|
+
dataset = load_json(json_file)
|
76
|
+
return tokenize_dataset(tokenizer, dataset)
|
58
77
|
|
59
78
|
def training_args(*args, **kwargs):
|
60
79
|
from transformers import TrainingArguments
|
61
80
|
training_args = TrainingArguments(*args, **kwargs)
|
62
81
|
return training_args
|
63
82
|
|
64
|
-
|
65
|
-
def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
83
|
+
def train_model(model, tokenizer, training_args, dataset, class_weights=None, **kwargs):
|
66
84
|
from transformers import Trainer
|
67
85
|
|
68
|
-
|
86
|
+
if (isinstance(dataset, str)):
|
87
|
+
if (dataset.endswith('.json')):
|
88
|
+
tokenized_dataset = json_dataset(tokenizer, dataset)
|
89
|
+
else:
|
90
|
+
tokenized_dataset = tsv_dataset(tokenizer, dataset)
|
91
|
+
else:
|
92
|
+
tokenized_dataset = tokenize_dataset(tokenizer, dataset)
|
69
93
|
|
70
94
|
if (not class_weights == None):
|
71
95
|
import torch
|
@@ -86,7 +110,8 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
|
86
110
|
model,
|
87
111
|
training_args,
|
88
112
|
train_dataset = tokenized_dataset["train"],
|
89
|
-
tokenizer = tokenizer
|
113
|
+
tokenizer = tokenizer,
|
114
|
+
**kwargs
|
90
115
|
)
|
91
116
|
else:
|
92
117
|
|
@@ -94,7 +119,8 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
|
|
94
119
|
model,
|
95
120
|
training_args,
|
96
121
|
train_dataset = tokenized_dataset["train"],
|
97
|
-
tokenizer = tokenizer
|
122
|
+
tokenizer = tokenizer,
|
123
|
+
**kwargs
|
98
124
|
)
|
99
125
|
|
100
126
|
trainer.train()
|
@@ -124,10 +150,16 @@ def find_tokens_in_input(dataset, token_ids):
|
|
124
150
|
return position_rows
|
125
151
|
|
126
152
|
|
127
|
-
def predict_model(model, tokenizer, training_args,
|
153
|
+
def predict_model(model, tokenizer, training_args, dataset, locate_tokens = None):
|
128
154
|
from transformers import Trainer
|
129
155
|
|
130
|
-
|
156
|
+
if (isinstance(dataset, str)):
|
157
|
+
if (dataset.endswith('.json')):
|
158
|
+
tokenized_dataset = json_dataset(tokenizer, dataset)
|
159
|
+
else:
|
160
|
+
tokenized_dataset = tsv_dataset(tokenizer, dataset)
|
161
|
+
else:
|
162
|
+
tokenized_dataset = tokenize_dataset(tokenizer, dataset)
|
131
163
|
|
132
164
|
trainer = Trainer(
|
133
165
|
model,
|
@@ -143,4 +175,3 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
|
|
143
175
|
else:
|
144
176
|
return result
|
145
177
|
|
146
|
-
|