rbbt-dm 1.2.7 → 1.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,25 @@
1
1
  require 'rbbt/util/R'
2
2
  require 'rbbt/vector/model/util'
3
+ require 'rbbt/util/python'
4
+
5
+ RbbtPython.add_path Rbbt.python.find(:lib)
6
+ RbbtPython.init_rbbt
3
7
 
4
8
  class VectorModel
5
- attr_accessor :directory, :model_file, :extract_features, :train_model, :eval_model, :post_process, :balance
9
+ attr_accessor :directory, :model_path, :extract_features, :init_model, :train_model, :eval_model, :post_process, :balance
6
10
  attr_accessor :features, :names, :labels, :factor_levels
7
- attr_accessor :model_options
11
+ attr_accessor :model, :model_options
8
12
 
9
13
  def extract_features(&block)
10
14
  @extract_features = block if block_given?
11
15
  @extract_features
12
16
  end
13
17
 
18
+ def init_model(&block)
19
+ @init_model = block if block_given?
20
+ @init_model
21
+ end
22
+
14
23
  def train_model(&block)
15
24
  @train_model = block if block_given?
16
25
  @train_model
@@ -21,13 +30,17 @@ class VectorModel
21
30
  @eval_model
22
31
  end
23
32
 
33
+ def init
34
+ @model ||= self.instance_exec &@init_model
35
+ end
36
+
24
37
  def post_process(&block)
25
38
  @post_process = block if block_given?
26
39
  @post_process
27
40
  end
28
41
 
29
42
 
30
- def self.R_run(model_file, features, labels, code, names = nil, factor_levels = nil)
43
+ def self.R_run(model_path, features, labels, code, names = nil, factor_levels = nil)
31
44
  TmpFile.with_file do |feature_file|
32
45
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
33
46
  Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -54,7 +67,7 @@ features = cbind(features, label = labels);
54
67
  end
55
68
  end
56
69
 
57
- def self.R_train(model_file, features, labels, code, names = nil, factor_levels = nil)
70
+ def self.R_train(model_path, features, labels, code, names = nil, factor_levels = nil)
58
71
  TmpFile.with_file do |feature_file|
59
72
  Open.write(feature_file, features.collect{|feats| feats * "\t"} * "\n")
60
73
  Open.write(feature_file + '.label', labels * "\n" + "\n")
@@ -82,13 +95,13 @@ for (c in names(features)){
82
95
  if (is.factor(features[[c]]))
83
96
  factor_levels[c] = paste(levels(features[[c]]), collapse="\t")
84
97
  }
85
- rbbt.tsv.write("#{model_file}.factor_levels", factor_levels, names=c('Levels'), type='flat')
86
- save(model, file='#{model_file}')
98
+ rbbt.tsv.write("#{model_path}.factor_levels", factor_levels, names=c('Levels'), type='flat')
99
+ save(model, file='#{model_path}')
87
100
  EOF
88
101
  end
89
102
  end
90
103
 
91
- def self.R_eval(model_file, features, list, code, names = nil, factor_levels = nil)
104
+ def self.R_eval(model_path, features, list, code, names = nil, factor_levels = nil)
92
105
  TmpFile.with_file do |feature_file|
93
106
  if list
94
107
  Open.write(feature_file, features.collect{|feat| feat * "\t"} * "\n" + "\n")
@@ -105,7 +118,7 @@ features = read.table("#{ feature_file }", sep ="\\t", stringsAsFactors=TRUE);
105
118
  #{ factor_levels.collect do |name,levels|
106
119
  "features[['#{name}']] = factor(features[['#{name}']], levels=#{R.ruby2R levels})"
107
120
  end * "\n" if factor_levels }
108
- load(file="#{model_file}");
121
+ load(file="#{model_path}");
109
122
  #{code}
110
123
  cat(paste(label, sep="\\n", collapse="\\n"));
111
124
  EOF
@@ -127,61 +140,77 @@ cat(paste(label, sep="\\n", collapse="\\n"));
127
140
  instance_eval code, file
128
141
  end
129
142
 
130
- def initialize(directory = nil, extract_features = nil, train_model = nil, eval_model = nil, post_process = nil, names = nil, factor_levels = nil)
143
+ def initialize(directory = nil, model_options = {})
131
144
  @directory = directory
145
+ @model_options = IndiferentHash.setup(model_options)
146
+
132
147
  if @directory
133
- FileUtils.mkdir_p @directory unless File.exists?(@directory)
148
+ FileUtils.mkdir_p @directory unless File.exist?(@directory)
149
+
150
+ @model_path = File.join(@directory, "model")
134
151
 
135
- @model_file = File.join(@directory, "model")
136
152
  @extract_features_file = File.join(@directory, "features")
137
- @train_model_file = File.join(@directory, "train_model")
138
- @eval_model_file = File.join(@directory, "eval_model")
139
- @post_process_file = File.join(@directory, "post_process")
140
- @train_model_file_R = File.join(@directory, "train_model.R")
141
- @eval_model_file_R = File.join(@directory, "eval_model.R")
142
- @post_process_file_R = File.join(@directory, "post_process.R")
143
- @names_file = File.join(@directory, "feature_names")
144
- @levels_file = File.join(@directory, "levels")
145
- @options_file = File.join(@directory, "options.json")
146
-
147
- if File.exists?(@options_file)
148
- @model_options = JSON.parse(Open.read(@options_file))
153
+ @init_model_path = File.join(@directory, "init_model")
154
+
155
+ @train_model_path = File.join(@directory, "train_model")
156
+ @train_model_path_R = File.join(@directory, "train_model.R")
157
+
158
+ @eval_model_path = File.join(@directory, "eval_model")
159
+ @eval_model_path_R = File.join(@directory, "eval_model.R")
160
+
161
+ @post_process_file = File.join(@directory, "post_process")
162
+ @post_process_file_R = File.join(@directory, "post_process.R")
163
+
164
+ @names_file = File.join(@directory, "feature_names")
165
+ @levels_file = File.join(@directory, "levels")
166
+ @options_file = File.join(@directory, "options.json")
167
+
168
+ if File.exist?(@options_file)
169
+ @model_options = JSON.parse(Open.read(@options_file)).merge(@model_options || {})
149
170
  IndiferentHash.setup(@model_options)
150
171
  end
151
172
  end
152
-
173
+
153
174
  if extract_features.nil?
154
- if @extract_features_file && File.exists?(@extract_features_file)
175
+ if @extract_features_file && File.exist?(@extract_features_file)
155
176
  @extract_features = __load_method @extract_features_file
156
177
  end
157
178
  else
158
179
  @extract_features = extract_features
159
180
  end
160
181
 
182
+ if init_model.nil?
183
+ if @init_model_path && File.exist?(@init_model_path)
184
+ @init_model = __load_method @init_model_path
185
+ end
186
+ else
187
+ @init_model = init_model
188
+ end
189
+
161
190
  if train_model.nil?
162
- if @train_model_file && File.exists?(@train_model_file)
163
- @train_model = __load_method @train_model_file
164
- elsif @train_model_file_R && File.exists?(@train_model_file_R)
165
- @train_model = Open.read(@train_model_file_R)
191
+ if @train_model_path && File.exist?(@train_model_path)
192
+ @train_model = __load_method @train_model_path
193
+ elsif @train_model_path_R && File.exist?(@train_model_path_R)
194
+ @train_model = Open.read(@train_model_path_R)
166
195
  end
167
196
  else
168
197
  @train_model = train_model
169
198
  end
170
199
 
171
200
  if eval_model.nil?
172
- if @eval_model_file && File.exists?(@eval_model_file)
173
- @eval_model = __load_method @eval_model_file
174
- elsif @eval_model_file_R && File.exists?(@eval_model_file_R)
175
- @eval_model = Open.read(@eval_model_file_R)
201
+ if @eval_model_path && File.exist?(@eval_model_path)
202
+ @eval_model = __load_method @eval_model_path
203
+ elsif @eval_model_path_R && File.exist?(@eval_model_path_R)
204
+ @eval_model = Open.read(@eval_model_path_R)
176
205
  end
177
206
  else
178
207
  @eval_model = eval_model
179
208
  end
180
209
 
181
210
  if post_process.nil?
182
- if @post_process_file && File.exists?(@post_process_file)
211
+ if @post_process_file && File.exist?(@post_process_file)
183
212
  @post_process = __load_method @post_process_file
184
- elsif @post_process_file_R && File.exists?(@post_process_file_R)
213
+ elsif @post_process_file_R && File.exist?(@post_process_file_R)
185
214
  @post_process = Open.read(@post_process_file_R)
186
215
  end
187
216
  else
@@ -190,7 +219,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
190
219
 
191
220
 
192
221
  if names.nil?
193
- if @names_file && File.exists?(@names_file)
222
+ if @names_file && File.exist?(@names_file)
194
223
  @names = Open.read(@names_file).split("\n")
195
224
  end
196
225
  else
@@ -198,11 +227,11 @@ cat(paste(label, sep="\\n", collapse="\\n"));
198
227
  end
199
228
 
200
229
  if factor_levels.nil?
201
- if @levels_file && File.exists?(@levels_file)
230
+ if @levels_file && File.exist?(@levels_file)
202
231
  @factor_levels = YAML.load(Open.read(@levels_file))
203
232
  end
204
- if @model_file && File.exists?(@model_file + '.factor_levels')
205
- @factor_levels = TSV.open(@model_file + '.factor_levels')
233
+ if @model_path && File.exist?(@model_path + '.factor_levels')
234
+ @factor_levels = TSV.open(@model_path + '.factor_levels')
206
235
  end
207
236
  else
208
237
  @factor_levels = factor_levels
@@ -241,23 +270,24 @@ cat(paste(label, sep="\\n", collapse="\\n"));
241
270
  case
242
271
  when Proc === train_model
243
272
  begin
244
- Open.write(@train_model_file, train_model.source)
273
+ Open.write(@train_model_path, train_model.source)
245
274
  rescue
246
275
  end
247
276
  when String === train_model
248
- Open.write(@train_model_file_R, @train_model)
277
+ Open.write(@train_model_path_R, @train_model)
249
278
  end
250
279
 
251
280
  Open.write(@extract_features_file, @extract_features.source) if @extract_features
281
+ Open.write(@init_model_path, @init_model.source) if @init_model
252
282
 
253
283
  case
254
284
  when Proc === eval_model
255
285
  begin
256
- Open.write(@eval_model_file, eval_model.source)
286
+ Open.write(@eval_model_path, eval_model.source)
257
287
  rescue
258
288
  end
259
289
  when String === eval_model
260
- Open.write(@eval_model_file_R, eval_model)
290
+ Open.write(@eval_model_path_R, eval_model)
261
291
  end
262
292
 
263
293
  case
@@ -285,9 +315,9 @@ cat(paste(label, sep="\\n", collapse="\\n"));
285
315
 
286
316
  case
287
317
  when Proc === @train_model
288
- self.instance_exec(@model_file, @features, @labels, @names, @factor_levels, &@train_model)
318
+ self.instance_exec(@features, @labels, @names, @factor_levels, &@train_model)
289
319
  when String === @train_model
290
- VectorModel.R_train(@model_file, @features, @labels, train_model, @names, @factor_levels)
320
+ VectorModel.R_train(@model_path, @features, @labels, train_model, @names, @factor_levels)
291
321
  end
292
322
  ensure
293
323
  if @balance
@@ -300,7 +330,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
300
330
  end
301
331
 
302
332
  def run(code)
303
- VectorModel.R_run(@model_file, @features, @labels, code, @names, @factor_levels)
333
+ VectorModel.R_run(@model_path, @features, @labels, code, @names, @factor_levels)
304
334
  end
305
335
 
306
336
  def eval(element)
@@ -308,14 +338,14 @@ cat(paste(label, sep="\\n", collapse="\\n"));
308
338
 
309
339
  result = case
310
340
  when Proc === @eval_model
311
- self.instance_exec(@model_file, features, false, nil, @names, @factor_levels, &@eval_model)
341
+ self.instance_exec(features, false, nil, @names, @factor_levels, &@eval_model)
312
342
  when String === @eval_model
313
- VectorModel.R_eval(@model_file, features, false, eval_model, @names, @factor_levels)
343
+ VectorModel.R_eval(@model_path, features, false, eval_model, @names, @factor_levels)
314
344
  else
315
345
  raise "No @eval_model function or R script"
316
346
  end
317
347
 
318
- result = self.instance_exec(result, &@post_process) if Proc === @post_process
348
+ result = self.instance_exec(result, false, &@post_process) if Proc === @post_process
319
349
 
320
350
  result
321
351
  end
@@ -334,12 +364,12 @@ cat(paste(label, sep="\\n", collapse="\\n"));
334
364
 
335
365
  result = case
336
366
  when Proc === eval_model
337
- self.instance_exec(@model_file, features, true, nil, @names, @factor_levels, &@eval_model)
367
+ self.instance_exec(features, true, nil, @names, @factor_levels, &@eval_model)
338
368
  when String === eval_model
339
- VectorModel.R_eval(@model_file, features, true, eval_model, @names, @factor_levels)
369
+ VectorModel.R_eval(@model_path, features, true, eval_model, @names, @factor_levels)
340
370
  end
341
371
 
342
- result = self.instance_exec(result, &@post_process) if Proc === @post_process
372
+ result = self.instance_exec(result, true, &@post_process) if Proc === @post_process
343
373
 
344
374
  result
345
375
  end
@@ -1 +1,48 @@
1
- # Keep
1
+ from torch.utils.data import Dataset, DataLoader
2
+
3
+ class TSVDataset(Dataset):
4
+ def __init__(self, tsv):
5
+ self.tsv = tsv
6
+
7
+ def __getitem__(self, key):
8
+ if (type(key) == int):
9
+ row = self.tsv.iloc[key]
10
+ else:
11
+ row = self.tsv.loc[key]
12
+
13
+ row = row.to_numpy()
14
+ features = row[:-1]
15
+ label = row[-1]
16
+
17
+ return features, label
18
+
19
+ def __len__(self):
20
+ return len(self.tsv)
21
+
22
+ def tsv_dataset(filename, *args, **kwargs):
23
+ import rbbt
24
+ return TSVDataset(rbbt.tsv(filename, *args, **kwargs))
25
+
26
+ def tsv(*args, **kwargs):
27
+ return tsv_dataset(*args, **kwargs)
28
+
29
+ def data_dir():
30
+ import rbbt
31
+ return rbbt.path('var/rbbt_dm/data')
32
+
33
+ if __name__ == "__main__":
34
+ import rbbt
35
+
36
+ filename = "/home/miki/test/numeric.tsv"
37
+ ds = tsv(filename)
38
+
39
+ dl = DataLoader(ds, batch_size=1)
40
+
41
+ for f, l in iter(dl):
42
+ print(".")
43
+ print(f[0,:])
44
+ print(l[0])
45
+
46
+
47
+
48
+
File without changes
@@ -0,0 +1,141 @@
1
+ from matplotlib import pyplot as plt
2
+ import numpy as np
3
+ import torch
4
+ from IPython.display import HTML, display
5
+
6
+
7
+ def set_default(figsize=(10, 10), dpi=100):
8
+ plt.style.use(['dark_background', 'bmh'])
9
+ plt.rc('axes', facecolor='k')
10
+ plt.rc('figure', facecolor='k')
11
+ plt.rc('figure', figsize=figsize, dpi=dpi)
12
+
13
+
14
+ def plot_data(X, y, d=0, auto=False, zoom=1):
15
+ X = X.cpu()
16
+ y = y.cpu()
17
+ plt.scatter(X.numpy()[:, 0], X.numpy()[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
18
+ plt.axis('square')
19
+ plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
20
+ if auto is True: plt.axis('equal')
21
+ plt.axis('off')
22
+
23
+ _m, _c = 0, '.15'
24
+ plt.axvline(0, ymin=_m, color=_c, lw=1, zorder=0)
25
+ plt.axhline(0, xmin=_m, color=_c, lw=1, zorder=0)
26
+
27
+
28
+ def plot_model(X, y, model):
29
+ model.cpu()
30
+ mesh = np.arange(-1.1, 1.1, 0.01)
31
+ xx, yy = np.meshgrid(mesh, mesh)
32
+ with torch.no_grad():
33
+ data = torch.from_numpy(np.vstack((xx.reshape(-1), yy.reshape(-1))).T).float()
34
+ Z = model(data).detach()
35
+ Z = np.argmax(Z, axis=1).reshape(xx.shape)
36
+ plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
37
+ plot_data(X, y)
38
+
39
+
40
+ def show_scatterplot(X, colors, title=''):
41
+ colors = colors.cpu().numpy()
42
+ X = X.cpu().numpy()
43
+ plt.figure()
44
+ plt.axis('equal')
45
+ plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
46
+ # plt.grid(True)
47
+ plt.title(title)
48
+ plt.axis('off')
49
+
50
+
51
+ def plot_bases(bases, width=0.04):
52
+ bases = bases.cpu()
53
+ bases[2:] -= bases[:2]
54
+ plt.arrow(*bases[0], *bases[2], width=width, color=(1,0,0), zorder=10, alpha=1., length_includes_head=True)
55
+ plt.arrow(*bases[1], *bases[3], width=width, color=(0,1,0), zorder=10, alpha=1., length_includes_head=True)
56
+
57
+
58
+ def show_mat(mat, vect, prod, threshold=-1):
59
+ # Subplot grid definition
60
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=False, sharey=True,
61
+ gridspec_kw={'width_ratios':[5,1,1]})
62
+ # Plot matrices
63
+ cax1 = ax1.matshow(mat.numpy(), clim=(-1, 1))
64
+ ax2.matshow(vect.numpy(), clim=(-1, 1))
65
+ cax3 = ax3.matshow(prod.numpy(), clim=(threshold, 1))
66
+
67
+ # Set titles
68
+ ax1.set_title(f'A: {mat.size(0)} \u00D7 {mat.size(1)}')
69
+ ax2.set_title(f'a^(i): {vect.numel()}')
70
+ ax3.set_title(f'p: {prod.numel()}')
71
+
72
+ # Remove xticks for vectors
73
+ ax2.set_xticks(tuple())
74
+ ax3.set_xticks(tuple())
75
+
76
+ # Plot colourbars
77
+ fig.colorbar(cax1, ax=ax2)
78
+ fig.colorbar(cax3, ax=ax3)
79
+
80
+ # Fix y-axis limits
81
+ ax1.set_ylim(bottom=max(len(prod), len(vect)) - 0.5)
82
+
83
+
84
+ colors = dict(
85
+ aqua='#8dd3c7',
86
+ yellow='#ffffb3',
87
+ lavender='#bebada',
88
+ red='#fb8072',
89
+ blue='#80b1d3',
90
+ orange='#fdb462',
91
+ green='#b3de69',
92
+ pink='#fccde5',
93
+ grey='#d9d9d9',
94
+ violet='#bc80bd',
95
+ unk1='#ccebc5',
96
+ unk2='#ffed6f',
97
+ )
98
+
99
+
100
+ def _cstr(s, color='black'):
101
+ if s == ' ':
102
+ return f'<text style=color:#000;padding-left:10px;background-color:{color}> </text>'
103
+ else:
104
+ return f'<text style=color:#000;background-color:{color}>{s} </text>'
105
+
106
+ # print html
107
+ def _print_color(t):
108
+ display(HTML(''.join([_cstr(ti, color=ci) for ti, ci in t])))
109
+
110
+ # get appropriate color for value
111
+ def _get_clr(value):
112
+ colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
113
+ '#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
114
+ '#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
115
+ '#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
116
+ value = int((value * 100) / 5)
117
+ if value == len(colors): value -= 1 # fixing bugs...
118
+ return colors[value]
119
+
120
+ def _visualise_values(output_values, result_list):
121
+ text_colours = []
122
+ for i in range(len(output_values)):
123
+ text = (result_list[i], _get_clr(output_values[i]))
124
+ text_colours.append(text)
125
+ _print_color(text_colours)
126
+
127
+ def print_colourbar():
128
+ color_range = torch.linspace(-2.5, 2.5, 20)
129
+ to_print = [(f'{x:.2f}', _get_clr((x+2.5)/5)) for x in color_range]
130
+ _print_color(to_print)
131
+
132
+
133
+ # Let's only focus on the last time step for now
134
+ # First, the cell state (Long term memory)
135
+ def plot_state(data, state, b, decoder):
136
+ actual_data = decoder(data[b, :, :].numpy())
137
+ seq_len = len(actual_data)
138
+ seq_len_w_pad = len(state)
139
+ for s in range(state.size(2)):
140
+ states = torch.sigmoid(state[:, b, s])
141
+ _visualise_values(states[seq_len_w_pad - seq_len:], list(actual_data))
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import math
3
+ def spiral_data(N=1000, D=2, C=3):
4
+ X = torch.zeros(N * C, D)
5
+ y = torch.zeros(N * C, dtype=torch.long)
6
+ for c in range(C):
7
+ index = 0
8
+ t = torch.linspace(0, 1, N)
9
+ # When c = 0 and t = 0: start of linspace
10
+ # When c = 0 and t = 1: end of linpace
11
+ # This inner_var is for the formula inside sin() and cos() like sin(inner_var) and cos(inner_Var)
12
+ inner_var = torch.linspace(
13
+ # When t = 0
14
+ (2 * math.pi / C) * (c),
15
+ # When t = 1
16
+ (2 * math.pi / C) * (2 + c),
17
+ N
18
+ ) + torch.randn(N) * 0.2
19
+
20
+ for ix in range(N * c, N * (c + 1)):
21
+ X[ix] = t[index] * torch.FloatTensor((
22
+ math.sin(inner_var[index]), math.cos(inner_var[index])
23
+ ))
24
+ y[ix] = c
25
+ index += 1
26
+
27
+ return (X, y)
@@ -1,32 +1,41 @@
1
1
  #{{{ LOAD MODEL
2
2
 
3
3
  def import_module_class(module, class_name):
4
- exec(f"from {module} import {class_name}")
4
+ if (not module == None):
5
+ exec(f"from {module} import {class_name}")
5
6
  return eval(class_name)
6
7
 
7
- def load_model(task, checkpoint):
8
- class_name = 'AutoModelFor' + task
9
- return import_module_class('transformers', class_name).from_pretrained(checkpoint)
8
+ def load_model(task, checkpoint, **kwargs):
9
+ if (":" in task):
10
+ module, class_name = task.split(":")
11
+ if (task == None):
12
+ module, class_name = None, module
13
+ return import_module_class(module, class_name).from_pretrained(checkpoint, **kwargs)
14
+ else:
15
+ class_name = 'AutoModelFor' + task
16
+ return import_module_class('transformers', class_name).from_pretrained(checkpoint)
10
17
 
11
- def load_tokenizer(task, checkpoint):
18
+ def load_tokenizer(task, checkpoint, **kwargs):
12
19
  class_name = 'AutoTokenizer'
13
- return import_module_class('transformers', class_name).from_pretrained(checkpoint)
20
+ return import_module_class('transformers', class_name).from_pretrained(checkpoint, **kwargs)
14
21
 
15
22
  def load_model_and_tokenizer(task, checkpoint):
16
23
  model = load_model(task, checkpoint)
17
24
  tokenizer = load_tokenizer(task, checkpoint)
18
25
  return model, tokenizer
19
26
 
20
- def load_model_and_tokenizer_from_directory(directory):
21
- import os
22
- import json
23
- options_file = os.path.join(directory, 'options.json')
24
- f = open(options_file, "r")
25
- options = json.load(f.read())
26
- f.close()
27
- task = options["task"]
28
- checkpoint = options["checkpoint"]
29
- return load_model_and_tokenizer(task, checkpoint)
27
+ # Not used
28
+
29
+ #def load_model_and_tokenizer_from_directory(directory):
30
+ # import os
31
+ # import json
32
+ # options_file = os.path.join(directory, 'options.json')
33
+ # f = open(options_file, "r")
34
+ # options = json.load(f.read())
35
+ # f.close()
36
+ # task = options["task"]
37
+ # checkpoint = options["checkpoint"]
38
+ # return load_model_and_tokenizer(task, checkpoint)
30
39
 
31
40
  #{{{ SIMPLE EVALUATE
32
41
 
@@ -51,21 +60,36 @@ def load_tsv(tsv_file):
51
60
  from datasets import load_dataset
52
61
  return load_dataset('csv', data_files=[tsv_file], sep="\t")
53
62
 
63
+ def load_json(json_file):
64
+ from datasets import load_dataset
65
+ return load_dataset('json', data_files=[json_file])
66
+
67
+ def tokenize_dataset(tokenizer, dataset):
68
+ return dataset.map(lambda subset: subset if ("input_ids" in subset.keys()) else tokenizer(subset["text"], truncation=True), batched=True)
69
+
54
70
  def tsv_dataset(tokenizer, tsv_file):
55
71
  dataset = load_tsv(tsv_file)
56
- tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True, max_length=512) , batched=True)
57
- return tokenized_dataset
72
+ return tokenize_dataset(tokenizer, dataset)
73
+
74
+ def json_dataset(tokenizer, json_file):
75
+ dataset = load_json(json_file)
76
+ return tokenize_dataset(tokenizer, dataset)
58
77
 
59
78
  def training_args(*args, **kwargs):
60
79
  from transformers import TrainingArguments
61
80
  training_args = TrainingArguments(*args, **kwargs)
62
81
  return training_args
63
82
 
64
-
65
- def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
83
+ def train_model(model, tokenizer, training_args, dataset, class_weights=None, **kwargs):
66
84
  from transformers import Trainer
67
85
 
68
- tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
86
+ if (isinstance(dataset, str)):
87
+ if (dataset.endswith('.json')):
88
+ tokenized_dataset = json_dataset(tokenizer, dataset)
89
+ else:
90
+ tokenized_dataset = tsv_dataset(tokenizer, dataset)
91
+ else:
92
+ tokenized_dataset = tokenize_dataset(tokenizer, dataset)
69
93
 
70
94
  if (not class_weights == None):
71
95
  import torch
@@ -86,7 +110,8 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
86
110
  model,
87
111
  training_args,
88
112
  train_dataset = tokenized_dataset["train"],
89
- tokenizer = tokenizer
113
+ tokenizer = tokenizer,
114
+ **kwargs
90
115
  )
91
116
  else:
92
117
 
@@ -94,7 +119,8 @@ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
94
119
  model,
95
120
  training_args,
96
121
  train_dataset = tokenized_dataset["train"],
97
- tokenizer = tokenizer
122
+ tokenizer = tokenizer,
123
+ **kwargs
98
124
  )
99
125
 
100
126
  trainer.train()
@@ -124,10 +150,16 @@ def find_tokens_in_input(dataset, token_ids):
124
150
  return position_rows
125
151
 
126
152
 
127
- def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
153
+ def predict_model(model, tokenizer, training_args, dataset, locate_tokens = None):
128
154
  from transformers import Trainer
129
155
 
130
- tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
156
+ if (isinstance(dataset, str)):
157
+ if (dataset.endswith('.json')):
158
+ tokenized_dataset = json_dataset(tokenizer, dataset)
159
+ else:
160
+ tokenized_dataset = tsv_dataset(tokenizer, dataset)
161
+ else:
162
+ tokenized_dataset = tokenize_dataset(tokenizer, dataset)
131
163
 
132
164
  trainer = Trainer(
133
165
  model,
@@ -143,4 +175,3 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
143
175
  else:
144
176
  return result
145
177
 
146
-