rbbt-dm 1.2.7 → 1.2.9

@@ -1,16 +1,25 @@
 require 'rbbt/util/R'
 require 'rbbt/vector/model/util'
+require 'rbbt/util/python'
+
+RbbtPython.add_path Rbbt.python.find(:lib)
+RbbtPython.init_rbbt
 
 class VectorModel
-  attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
+  attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
   attr_accessor :features, :names, :labels, :factor_levels
-  attr_accessor :model_options
+  attr_accessor :model, :model_options
 
   def extract_features(&block)
     @extract_features = block if block_given?
     @extract_features
   end
 
+  def init_model(&block)
+    @init_model = block if block_given?
+    @init_model
+  end
+
   def train_model(&block)
     @train_model = block if block_given?
     @train_model
@@ -21,13 +30,17 @@ class VectorModel
     @eval_model
   end
 
+  def init
+    @model ||= self.instance_exec &@init_model
+  end
+
   def post_process(&block)
     @post_process = block if block_given?
     @post_process
   end
 
 
-  def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
+  def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
     TmpFile.with_file do |feature_file|
       Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
       Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
     end
   end
 
-  def self.R_train(model_file, features, labels, code, names = nil, factor_levels = nil)
+  def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
     TmpFile.with_file do |feature_file|
       Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
       Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -82,13 +95,13 @@ for (c in names(features)){
   if (is.factor(features[[c]]))
     factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
 }
-rbbt.tsv.write("#{model_file}.factor_levels", factor_levels, names=c('Levels'), type='flat')
-save(model, file='#{model_file}')
+rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
+save(model, file='#{model_path}')
 EOF
     end
   end
 
-  def self.R_eval(model_file, features, list, code, names = nil, factor_levels = nil)
+  def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
     TmpFile.with_file do |feature_file|
       if list
         Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
 #{ factor_levels.collect do |name,levels|
     "features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
   end * "\n" if factor_levels }
-load(file="#{model_file}");
+load(file="#{model_path}");
 #{code}
 cat(paste(label, sep="\\n", collapse="\\n"));
 EOF
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     instance_eval code, file
   end
 
-  def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
+  def initialize(directory = nil, model_options = {})
     @directory = directory
+    @model_options = IndiferentHash.setup(model_options)
+
     if @directory
-      FileUtils.mkdir_p @directory unless File.exists?(@directory)
+      FileUtils.mkdir_p @directory unless File.exist?(@directory)
+
+      @model_path = File.join(@directory, "model")
 
-      @model_file = File.join(@directory, "model")
       @extract_features_file = File.join(@directory, "features")
-      @train_model_file = File.join(@directory, "train_model")
-      @eval_model_file = File.join(@directory, "eval_model")
-      @post_process_file = File.join(@directory, "post_process")
-      @train_model_file_R = File.join(@directory, "train_model.R")
-      @eval_model_file_R = File.join(@directory, "eval_model.R")
-      @post_process_file_R = File.join(@directory, "post_process.R")
-      @names_file = File.join(@directory, "feature_names")
-      @levels_file = File.join(@directory, "levels")
-      @options_file = File.join(@directory, "options.json")
-
-      if File.exists?(@options_file)
-        @model_options = JSON.parse(Open.read(@options_file))
+      @init_model_path = File.join(@directory, "init_model")
+
+      @train_model_path = File.join(@directory, "train_model")
+      @train_model_path_R = File.join(@directory, "train_model.R")
+
+      @eval_model_path = File.join(@directory, "eval_model")
+      @eval_model_path_R = File.join(@directory, "eval_model.R")
+
+      @post_process_file = File.join(@directory, "post_process")
+      @post_process_file_R = File.join(@directory, "post_process.R")
+
+      @names_file = File.join(@directory, "feature_names")
+      @levels_file = File.join(@directory, "levels")
+      @options_file = File.join(@directory, "options.json")
+
+      if File.exist?(@options_file)
+        @model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
        IndiferentHash.setup(@model_options)
       end
     end
-
+
     if extract_features.nil?
-      if @extract_features_file && File.exists?(@extract_features_file)
+      if @extract_features_file && File.exist?(@extract_features_file)
         @extract_features = __load_method @extract_features_file
       end
     else
       @extract_features = extract_features
     end
 
+    if init_model.nil?
+      if @init_model_path && File.exist?(@init_model_path)
+        @init_model = __load_method @init_model_path
+      end
+    else
+      @init_model = init_model
+    end
+
     if train_model.nil?
-      if @train_model_file && File.exists?(@train_model_file)
-        @train_model = __load_method @train_model_file
-      elsif @train_model_file_R && File.exists?(@train_model_file_R)
-        @train_model = Open.read(@train_model_file_R)
+      if @train_model_path && File.exist?(@train_model_path)
+        @train_model = __load_method @train_model_path
+      elsif @train_model_path_R && File.exist?(@train_model_path_R)
+        @train_model = Open.read(@train_model_path_R)
       end
     else
       @train_model = train_model
     end
 
     if eval_model.nil?
-      if @eval_model_file && File.exists?(@eval_model_file)
-        @eval_model = __load_method @eval_model_file
-      elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
-        @eval_model = Open.read(@eval_model_file_R)
+      if @eval_model_path && File.exist?(@eval_model_path)
+        @eval_model = __load_method @eval_model_path
+      elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
+        @eval_model = Open.read(@eval_model_path_R)
       end
     else
       @eval_model = eval_model
     end
 
     if post_process.nil?
-      if @post_process_file && File.exists?(@post_process_file)
+      if @post_process_file && File.exist?(@post_process_file)
         @post_process = __load_method @post_process_file
-      elsif @post_process_file_R && File.exists?(@post_process_file_R)
+      elsif @post_process_file_R && File.exist?(@post_process_file_R)
         @post_process = Open.read(@post_process_file_R)
       end
     else
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
 
     if names.nil?
-      if @names_file && File.exists?(@names_file)
+      if @names_file && File.exist?(@names_file)
         @names = Open.read(@names_file).split("\n")
       end
     else
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     end
 
     if factor_levels.nil?
-      if @levels_file && File.exists?(@levels_file)
+      if @levels_file && File.exist?(@levels_file)
         @factor_levels = YAML.load(Open.read(@levels_file))
       end
-      if @model_file && File.exists?(@model_file + '.factor_levels')
-        @factor_levels = TSV.open(@model_file + '.factor_levels')
+      if @model_path && File.exist?(@model_path + '.factor_levels')
+        @factor_levels = TSV.open(@model_path + '.factor_levels')
       end
     else
       @factor_levels = factor_levels
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     case
     when Proc === train_model
       begin
-        Open.write(@train_model_file, train_model.source)
+        Open.write(@train_model_path, train_model.source)
       rescue
       end
     when String === train_model
-      Open.write(@train_model_file_R, @train_model)
+      Open.write(@train_model_path_R, @train_model)
     end
 
     Open.write(@extract_features_file, @extract_features.source) if @extract_features
+    Open.write(@init_model_path, @init_model.source) if @init_model
 
     case
     when Proc === eval_model
       begin
-        Open.write(@eval_model_file, eval_model.source)
+        Open.write(@eval_model_path, eval_model.source)
       rescue
       end
     when String === eval_model
-      Open.write(@eval_model_file_R, eval_model)
+      Open.write(@eval_model_path_R, eval_model)
     end
 
     case
@@ -285,9 +315,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
     case
     when Proc === @train_model
-      self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
+      self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
     when String === @train_model
-      VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
+      VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
     end
   ensure
     if @balance
@@ -300,7 +330,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
   end
 
   def run(code)
-    VectorModel.R_run(@model_file, @features, @labels, code, @names, @factor_levels)
+    VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
   end
 
   def eval(element)
@@ -308,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
     result = case
              when Proc === @eval_model
-               self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
+               self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
              when String === @eval_model
-               VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
+               VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
              else
                raise "No @eval_model function or R script"
              end
 
-    result = self.instance_exec(result, &@post_process) if Proc === @post_process
+    result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
 
     result
   end
@@ -334,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
 
     result = case
              when Proc === eval_model
-               self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
+               self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
              when String === eval_model
-               VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
+               VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
              end
 
-    result = self.instance_exec(result, &@post_process) if Proc === @post_process
+    result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
 
     result
   end
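
The reworked initializer takes just (directory, model_options), hooks are registered with blocks, and the train/eval blocks no longer receive the model file as their first argument. A minimal usage sketch of the new API (the directory, features, and block bodies are hypothetical):

    model = VectorModel.new "/tmp/example_model"

    model.extract_features do |element|
      element.split(",")                 # hypothetical: map an element to its feature vector
    end

    model.init_model do
      {}                                 # hypothetical: build and return the backing model object
    end

    model.train_model do |features, labels|
      m = self.init                      # init memoizes @model via the init_model block
      # fit m on features/labels here (hypothetical)
    end

    model.eval_model do |features, list|
      m = self.init
      # return one label, or a list of labels when list is true (hypothetical)
    end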
@@ -1 +1,48 @@
-# Keep
+from torch.utils.data import Dataset, DataLoader
+
+class TSVDataset(Dataset):
+    def __init__(self, tsv):
+        self.tsv = tsv
+
+    def __getitem__(self, key):
+        if (type(key) == int):
+            row = self.tsv.iloc[key]
+        else:
+            row = self.tsv.loc[key]
+
+        row = row.to_numpy()
+        features = row[:-1]
+        label = row[-1]
+
+        return features, label
+
+    def __len__(self):
+        return len(self.tsv)
+
+def tsv_dataset(filename, *args, **kwargs):
+    import rbbt
+    return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
+
+def tsv(*args, **kwargs):
+    return tsv_dataset(*args, **kwargs)
+
+def data_dir():
+    import rbbt
+    return rbbt.path('var/rbbt_dm/data')
+
+if __name__ == "__main__":
+    import rbbt
+
+    filename = "/home/miki/test/numeric.tsv"
+    ds = tsv(filename)
+
+    dl = DataLoader(ds, batch_size=1)
+
+    for f, l in iter(dl):
+        print(".")
+        print(f[0,:])
+        print(l[0])
+
+
+
+
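
The __main__ block above depends on rbbt and a hard-coded local file, but TSVDataset itself only needs an object supporting iloc, loc, and len() with the label in the last column, so a plain pandas DataFrame exercises it too. A self-contained sketch:

    import pandas as pd
    from torch.utils.data import DataLoader

    df = pd.DataFrame({"f1": [0.1, 0.2], "f2": [1.0, 2.0], "label": [0.0, 1.0]})
    ds = TSVDataset(df)                        # last column is treated as the label
    for features, label in DataLoader(ds, batch_size=1):
        print(features, label)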
@@ -0,0 +1,141 @@
+from matplotlib import pyplot as plt
+import numpy as np
+import torch
+from IPython.display import HTML, display
+
+
+def set_default(figsize=(10, 10), dpi=100):
+    plt.style.use(['dark_background', 'bmh'])
+    plt.rc('axes', facecolor='k')
+    plt.rc('figure', facecolor='k')
+    plt.rc('figure', figsize=figsize, dpi=dpi)
+
+
+def plot_data(X, y, d=0, auto=False, zoom=1):
+    X = X.cpu()
+    y = y.cpu()
+    plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
+    plt.axis('square')
+    plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
+    if auto is True: plt.axis('equal')
+    plt.axis('off')
+
+    _m, _c = 0, '.15'
+    plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
+    plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
+
+
+def plot_model(X, y, model):
+    model.cpu()
+    mesh = np.arange(-1.1, 1.1, 0.01)
+    xx, yy = np.meshgrid(mesh, mesh)
+    with torch.no_grad():
+        data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
+        Z = model(data).detach()
+    Z = np.argmax(Z, axis=1).reshape(xx.shape)
+    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
+    plot_data(X, y)
+
+
+def show_scatterplot(X, colors, title=''):
+    colors = colors.cpu().numpy()
+    X = X.cpu().numpy()
+    plt.figure()
+    plt.axis('equal')
+    plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
+    # plt.grid(True)
+    plt.title(title)
+    plt.axis('off')
+
+
+def plot_bases(bases, width=0.04):
+    bases = bases.cpu()
+    bases[2:] -= bases[:2]
+    plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
+    plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
+
+
+def show_mat(mat, vect, prod, threshold=-1):
+    # Subplot grid definition
+    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
+                                        gridspec_kw={'width_ratios':[5,1,1]})
+    # Plot matrices
+    cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
+    ax2.matshow(vect.numpy(), clim=(-1, 1))
+    cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
+
+    # Set titles
+    ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
+    ax2.set_title(f'a^(i): {vect.numel()}')
+    ax3.set_title(f'p: {prod.numel()}')
+
+    # Remove xticks for vectors
+    ax2.set_xticks(tuple())
+    ax3.set_xticks(tuple())
+
+    # Plot colourbars
+    fig.colorbar(cax1, ax=ax2)
+    fig.colorbar(cax3, ax=ax3)
+
+    # Fix y-axis limits
+    ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
+
+
+colors = dict(
+    aqua='#8dd3c7',
+    yellow='#ffffb3',
+    lavender='#bebada',
+    red='#fb8072',
+    blue='#80b1d3',
+    orange='#fdb462',
+    green='#b3de69',
+    pink='#fccde5',
+    grey='#d9d9d9',
+    violet='#bc80bd',
+    unk1='#ccebc5',
+    unk2='#ffed6f',
+)
+
+
+def _cstr(s, color='black'):
+    if s == ' ':
+        return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
+    else:
+        return f'<text style=color:#000;background-color:{color}>{s} </text>'
+
+# print html
+def _print_color(t):
+    display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
+
+# get appropriate color for value
+def _get_clr(value):
+    colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
+              '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
+              '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
+              '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
+    value = int((value * 100) / 5)
+    if value == len(colors): value -= 1  # fixing bugs...
+    return colors[value]
+
+def _visualise_values(output_values, result_list):
+    text_colours = []
+    for i in range(len(output_values)):
+        text = (result_list[i], _get_clr(output_values[i]))
+        text_colours.append(text)
+    _print_color(text_colours)
+
+def print_colourbar():
+    color_range = torch.linspace(-2.5, 2.5, 20)
+    to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
+    _print_color(to_print)
+
+
+# Let's only focus on the last time step for now
+# First, the cell state (Long term memory)
+def plot_state(data, state, b, decoder):
+    actual_data = decoder(data[b, :, :].numpy())
+    seq_len = len(actual_data)
+    seq_len_w_pad = len(state)
+    for s in range(state.size(2)):
+        states = torch.sigmoid(state[:, b, s])
+        _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
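
A quick way to exercise show_mat is to visualise a small matrix-vector product; a sketch (the shapes are arbitrary, and the vector is kept 2-D so matshow accepts it):

    A = torch.randn(4, 3).clamp(-1, 1)         # matrix, clipped to matshow's (-1, 1) colour range
    v = torch.randn(3, 1).clamp(-1, 1)         # column vector
    show_mat(A, v, A @ v)
    plt.show()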
@@ -0,0 +1,27 @@
+import torch
+import math
+def spiral_data(N=1000, D=2, C=3):
+    X = torch.zeros(N * C, D)
+    y = torch.zeros(N * C, dtype=torch.long)
+    for c in range(C):
+        index = 0
+        t = torch.linspace(0, 1, N)
+        # When c = 0 and t = 0: start of linspace
+        # When c = 0 and t = 1: end of linspace
+        # This inner_var is the argument passed to sin() and cos(), as in sin(inner_var) and cos(inner_var)
+        inner_var = torch.linspace(
+            # When t = 0
+            (2 * math.pi / C) * (c),
+            # When t = 1
+            (2 * math.pi / C) * (2 + c),
+            N
+        ) + torch.randn(N) * 0.2
+
+        for ix in range(N * c, N * (c + 1)):
+            X[ix] = t[index] * torch.FloatTensor((
+                math.sin(inner_var[index]), math.cos(inner_var[index])
+            ))
+            y[ix] = c
+            index += 1
+
+    return (X, y)
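
spiral_data pairs naturally with the plotting helpers above; assuming both files are imported into the same session, a sketch:

    X, y = spiral_data(N=500, C=3)             # three interleaved spiral arms
    set_default()                              # dark-background matplotlib defaults
    plot_data(X, y)
    plt.show()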
@@ -1,32 +1,41 @@
 #{{{ LOAD MODEL
 
 def import_module_class(module, class_name):
-    exec(f"from {module} import {class_name}")
+    if (not module == None):
+        exec(f"from {module} import {class_name}")
     return eval(class_name)
 
-def load_model(task, checkpoint):
-    class_name = 'AutoModelFor' + task
-    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+def load_model(task, checkpoint, **kwargs):
+    if (":" in task):
+        module, class_name = task.split(":")
+        if (task == None):
+            module, class_name = None, module
+        return import_module_class(module, class_name).from_pretrained(checkpoint, **kwargs)
+    else:
+        class_name = 'AutoModelFor' + task
+        return import_module_class('transformers', class_name).from_pretrained(checkpoint)
 
-def load_tokenizer(task, checkpoint):
+def load_tokenizer(task, checkpoint, **kwargs):
     class_name = 'AutoTokenizer'
-    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint, **kwargs)
 
 def load_model_and_tokenizer(task, checkpoint):
     model = load_model(task, checkpoint)
     tokenizer = load_tokenizer(task, checkpoint)
     return model, tokenizer
 
-def load_model_and_tokenizer_from_directory(directory):
-    import os
-    import json
-    options_file = os.path.join(directory, 'options.json')
-    f = open(options_file, "r")
-    options = json.load(f.read())
-    f.close()
-    task = options["task"]
-    checkpoint = options["checkpoint"]
-    return load_model_and_tokenizer(task, checkpoint)
+# Not used
+
+#def load_model_and_tokenizer_from_directory(directory):
+#    import os
+#    import json
+#    options_file = os.path.join(directory, 'options.json')
+#    f = open(options_file, "r")
+#    options = json.load(f.read())
+#    f.close()
+#    task = options["task"]
+#    checkpoint = options["checkpoint"]
+#    return load_model_and_tokenizer(task, checkpoint)
 
 #{{{ SIMPLE EVALUATE
 
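
With the colon syntax above, the task string can name any importable class rather than a transformers Auto class; a sketch (the module, class, and checkpoint names are hypothetical):

    # Auto lookup, as before: resolves to transformers.AutoModelForSequenceClassification
    model = load_model("SequenceClassification", "bert-base-uncased")

    # New: load a custom class from an arbitrary module
    model = load_model("my_package.modeling:MyClassifier", "path/to/checkpoint")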
@@ -51,21 +60,36 @@ def load_tsv(tsv_file):
     from datasets import load_dataset
     return load_dataset('csv', data_files=[tsv_file], sep="\t")
 
+def load_json(json_file):
+    from datasets import load_dataset
+    return load_dataset('json', data_files=[json_file])
+
+def tokenize_dataset(tokenizer, dataset):
+    return dataset.map(lambda subset: subset if ("input_ids" in subset.keys()) else tokenizer(subset["text"], truncation=True), batched=True)
+
 def tsv_dataset(tokenizer, tsv_file):
     dataset = load_tsv(tsv_file)
-    tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
-    return tokenized_dataset
+    return tokenize_dataset(tokenizer, dataset)
+
+def json_dataset(tokenizer, json_file):
+    dataset = load_json(json_file)
+    return tokenize_dataset(tokenizer, dataset)
 
 def training_args(*args, **kwargs):
     from transformers import TrainingArguments
     training_args = TrainingArguments(*args, **kwargs)
     return training_args
 
-
-def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
+def train_model(model, tokenizer, training_args, dataset, class_weights=None, **kwargs):
     from transformers import Trainer
 
-    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+    if (isinstance(dataset, str)):
+        if (dataset.endswith('.json')):
+            tokenized_dataset = json_dataset(tokenizer, dataset)
+        else:
+            tokenized_dataset = tsv_dataset(tokenizer, dataset)
+    else:
+        tokenized_dataset = tokenize_dataset(tokenizer, dataset)
 
     if (not class_weights == None):
         import torch
@@ -86,7 +110,8 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
             model,
             training_args,
             train_dataset = tokenized_dataset["train"],
-            tokenizer = tokenizer
+            tokenizer = tokenizer,
+            **kwargs
         )
     else:
 
@@ -94,7 +119,8 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
             model,
             training_args,
             train_dataset = tokenized_dataset["train"],
-            tokenizer = tokenizer
+            tokenizer = tokenizer,
+            **kwargs
         )
 
     trainer.train()
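
train_model now accepts either a file path, routed to json_dataset or tsv_dataset by extension, or an already-loaded dataset. A sketch (the checkpoint, file, and output directory are hypothetical):

    model, tokenizer = load_model_and_tokenizer("SequenceClassification", "bert-base-uncased")
    args = training_args("output_dir")
    train_model(model, tokenizer, args, "train.json")   # .json routes through json_dataset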
@@ -124,10 +150,16 @@ def find_tokens_in_input(dataset, token_ids):
     return position_rows
 
 
-def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
+def predict_model(model, tokenizer, training_args, dataset, locate_tokens = None):
     from transformers import Trainer
 
-    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+    if (isinstance(dataset, str)):
+        if (dataset.endswith('.json')):
+            tokenized_dataset = json_dataset(tokenizer, dataset)
+        else:
+            tokenized_dataset = tsv_dataset(tokenizer, dataset)
+    else:
+        tokenized_dataset = tokenize_dataset(tokenizer, dataset)
 
     trainer = Trainer(
         model,
@@ -143,4 +175,3 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
     else:
         return result
 
-
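
predict_model dispatches on its dataset argument the same way; a sketch reusing the objects above (the eval file is hypothetical):

    predictions = predict_model(model, tokenizer, args, "eval.tsv")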