rbbt-dm 1.2.1 → 1.2.4

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: ab775c0224960820e5c62e294e6a183be49201da15710b66544762e1aaf97ebf
-  data.tar.gz: 8fffb47ba226f06d1f41a8893d085bdc12c33c021cf2f0152f4cc741db36e420
+  metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
+  data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
 SHA512:
-  metadata.gz: 8be084156063cd93c7fe905bc4b6248dd376bbfcff8e650cbb03a4cc5c28f29dbcdaa2801895a3663067d7660e8bc2cf96682829519ebd6511a7a74cec021da0
-  data.tar.gz: 9c8570722319caf5afe60c90778d0b8517e70064030e0081968d919a314cfe1af90f63d13a7b8ce8dd56da6644db89916a729a4005f1e20745c8d4b45d50394c
+  metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
+  data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -1,11 +1,12 @@
 require 'rbbt/vector/model'
 require 'rbbt/util/python'
 
+RbbtPython.add_path Rbbt.python.find(:lib)
 RbbtPython.init_rbbt
 
 class HuggingfaceModel < VectorModel
 
-  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
+  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
 
   def tsv_dataset(tsv_dataset_file, elements, labels = nil)
 
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
 
     if labels
       training_args = call_method(:training_args, output_dir)
-      call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
+      call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
     else
       if Array === elements
         training_args = call_method(:training_args, output_dir)
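
The new @class_weights attribute above is forwarded to the train_model helper in the added python/rbbt_dm/huggingface.py (further down), where it becomes a per-class weight on the cross-entropy loss. A minimal standalone sketch of that weighting, with invented logits, labels, and weights:

# Sketch of the per-class weighting applied by train_model below.
# All tensor values here are invented for illustration.
import torch
from torch import nn

class_weights = [1.0, 5.0]                          # up-weight the rarer class 1
logits = torch.tensor([[2.0, -1.0], [0.5, 0.3]])    # batch of 2 examples, 2 classes
labels = torch.tensor([0, 1])

loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights))
loss = loss_fct(logits.view(-1, 2), labels.view(-1))
print(loss.item())  # errors on class 1 now cost 5x more than on class 0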
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
   end
 end
 
-if __FILE__ == $0
-
-end
lib/rbbt/vector/model.rb CHANGED
@@ -436,7 +436,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     ensure
       @features = orig_features
       @labels = orig_labels
-    end
+    end unless folds == -1
     self.train unless folds == 1
     res
   end
python/rbbt_dm/__init__.py ADDED
@@ -0,0 +1 @@
+# Keep
python/rbbt_dm/huggingface.py ADDED
@@ -0,0 +1,135 @@
+#{{{ LOAD MODEL
+
+def import_module_class(module, class_name):
+    exec(f"from {module} import {class_name}")
+    return eval(class_name)
+
+def load_model(task, checkpoint):
+    class_name = 'AutoModelFor' + task
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_tokenizer(task, checkpoint):
+    class_name = 'AutoTokenizer'
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_model_and_tokenizer(task, checkpoint):
+    model = load_model(task, checkpoint)
+    tokenizer = load_tokenizer(task, checkpoint)
+    return model, tokenizer
+
+#{{{ SIMPLE EVALUATE
+
+def forward(model, features):
+    return model(**features)
+
+def logits(predictions):
+    logits = predictions["logits"]
+    return [v.detach().cpu().numpy() for v in logits]
+
+def eval_model(model, tokenizer, texts, return_logits = True):
+    features = tokenizer(texts, return_tensors='pt', truncation=True).to(model.device)
+    predictions = forward(model, features)
+    if (return_logits):
+        return logits(predictions)
+    else:
+        return predictions
+
+#{{{ TRAIN AND PREDICT
+
+def load_tsv(tsv_file):
+    from datasets import load_dataset
+    return load_dataset('csv', data_files=[tsv_file], sep="\t")
+
+def tsv_dataset(tokenizer, tsv_file):
+    dataset = load_tsv(tsv_file)
+    tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True), batched=True)
+    return tokenized_dataset
+
+def training_args(*args, **kwargs):
+    from transformers import TrainingArguments
+    training_args = TrainingArguments(*args, **kwargs)
+    return training_args
+
+
+def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    if (not class_weights == None):
+        import torch
+        from torch import nn
+
+        class WeightTrainer(Trainer):
+            def compute_loss(self, model, inputs, return_outputs=False):
+                labels = inputs.get("labels")
+                # forward pass
+                outputs = model(**inputs)
+                logits = outputs.get('logits')
+                # compute custom loss
+                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
+                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+                return (loss, outputs) if return_outputs else loss
+
+        trainer = WeightTrainer(
+            model,
+            training_args,
+            train_dataset = tokenized_dataset["train"],
+            tokenizer = tokenizer
+        )
+    else:
+
+        trainer = Trainer(
+            model,
+            training_args,
+            train_dataset = tokenized_dataset["train"],
+            tokenizer = tokenizer
+        )
+
+    trainer.train()
+
+def find_tokens_in_input(dataset, token_ids):
+    position_rows = []
+
+    for row in dataset:
+        input_ids = row["input_ids"]
+
+        if (not hasattr(token_ids, "__len__")):
+            token_ids = [token_ids]
+
+        positions = []
+        for token_id in token_ids:
+
+            item_positions = []
+            for i in range(len(input_ids)):
+                if input_ids[i] == token_id:
+                    item_positions.append(i)
+
+            positions.append(item_positions)
+
+
+        position_rows.append(positions)
+
+    return position_rows
+
+
+def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        tokenizer = tokenizer
+    )
+
+    result = trainer.predict(test_dataset = tokenized_dataset["train"])
+    if (locate_tokens != None):
+        token_ids = tokenizer.convert_tokens_to_ids(locate_tokens)
+        token_positions = find_tokens_in_input(tokenized_dataset["train"], token_ids)
+        return dict(result=result, token_positions=token_positions)
+    else:
+        return result
+
+
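The find_tokens_in_input helper above is plain Python, so its behavior can be checked without transformers. A toy example, assuming the function is in scope (e.g. imported from python/rbbt_dm/huggingface.py) and using invented token ids:

# Rows mimic tokenized examples; the ids are made up.
dataset = [
    {"input_ids": [101, 7, 8, 7, 102]},
    {"input_ids": [101, 9, 102]},
]

# One list of positions per requested token id, per row:
# row 0: id 7 at [1, 3], id 102 at [4]
# row 1: id 7 nowhere,   id 102 at [2]
print(find_tokens_in_input(dataset, [7, 102]))
# => [[[1, 3], [4]], [[], [2]]]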
test/rbbt/vector/model/test_huggingface.rb CHANGED
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
 
 class TestHuggingface < Test::Unit::TestCase
 
+  def test_pipeline
+    require 'rbbt/util/python'
+    model = VectorModel.new
+    model.post_process do |elements|
+      elements.collect{|e| e['label'] }
+    end
+    model.eval_model do |file, elements|
+      RbbtPython.run :transformers do
+        classifier ||= transformers.pipeline("sentiment-analysis")
+        classifier.call(elements)
+      end
+    end
+
+    assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
+  end
+
   def test_sst_eval
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
       model.class_labels = ["Bad", "Good"]
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
-
     end
   end
 
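The new test_pipeline test above drives the transformers sentiment-analysis pipeline through RbbtPython. The Python it wraps is roughly the following (transformers picks a default checkpoint when none is given):

# Roughly what test_pipeline invokes via RbbtPython.
from transformers import pipeline

classifier = pipeline("sentiment-analysis")  # downloads a default model on first use
result = classifier("I've been waiting for a HuggingFace course my whole life.")
print(result)  # e.g. [{'label': 'POSITIVE', 'score': 0.99...}]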
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.1
+  version: 1.2.4
 platform: ruby
 authors:
 - Miguel Vazquez
-autorequire:
+autorequire:
 bindir: bin
 cert_chain: []
-date: 2023-02-04 00:00:00.000000000 Z
+date: 2023-02-07 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rbbt-util
@@ -113,6 +113,8 @@ files:
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
 - lib/rbbt/vector/model/util.rb
+- python/rbbt_dm/__init__.py
+- python/rbbt_dm/huggingface.py
 - share/R/MA.R
 - share/R/barcode.R
 - share/R/heatmap.3.R
@@ -141,7 +143,7 @@ files:
 homepage: http://github.com/mikisvaz/rbbt-phgx
 licenses: []
 metadata: {}
-post_install_message:
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -156,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.1.2
-signing_key:
+rubygems_version: 3.1.6
+signing_key:
 specification_version: 4
 summary: Data-mining and statistics
 test_files:
-- test/test_helper.rb
+- test/rbbt/statistics/test_hypergeometric.rb
+- test/rbbt/statistics/test_fisher.rb
+- test/rbbt/statistics/test_fdr.rb
+- test/rbbt/statistics/test_random_walk.rb
+- test/rbbt/test_ml_task.rb
 - test/rbbt/vector/test_model.rb
-- test/rbbt/vector/model/test_huggingface.rb
 - test/rbbt/vector/model/test_tensorflow.rb
 - test/rbbt/vector/model/test_spaCy.rb
+- test/rbbt/vector/model/test_huggingface.rb
 - test/rbbt/vector/model/test_svm.rb
-- test/rbbt/statistics/test_random_walk.rb
-- test/rbbt/statistics/test_fisher.rb
-- test/rbbt/statistics/test_fdr.rb
-- test/rbbt/statistics/test_hypergeometric.rb
-- test/rbbt/test_stan.rb
-- test/rbbt/matrix/test_barcode.rb
-- test/rbbt/test_ml_task.rb
 - test/rbbt/network/test_paths.rb
+- test/rbbt/matrix/test_barcode.rb
+- test/rbbt/test_stan.rb
+- test/test_helper.rb