rbbt-dm 1.2.1 → 1.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: ab775c0224960820e5c62e294e6a183be49201da15710b66544762e1aaf97ebf
-  data.tar.gz: 8fffb47ba226f06d1f41a8893d085bdc12c33c021cf2f0152f4cc741db36e420
+  metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
+  data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
 SHA512:
-  metadata.gz: 8be084156063cd93c7fe905bc4b6248dd376bbfcff8e650cbb03a4cc5c28f29dbcdaa2801895a3663067d7660e8bc2cf96682829519ebd6511a7a74cec021da0
-  data.tar.gz: 9c8570722319caf5afe60c90778d0b8517e70064030e0081968d919a314cfe1af90f63d13a7b8ce8dd56da6644db89916a729a4005f1e20745c8d4b45d50394c
+  metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
+  data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
lib/rbbt/vector/model/huggingface.rb CHANGED
@@ -1,11 +1,12 @@
 require 'rbbt/vector/model'
 require 'rbbt/util/python'
 
+RbbtPython.add_path Rbbt.python.find(:lib)
 RbbtPython.init_rbbt
 
 class HuggingfaceModel < VectorModel
 
-  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
+  attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
 
   def tsv_dataset(tsv_dataset_file, elements, labels = nil)
 
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
 
     if labels
       training_args = call_method(:training_args, output_dir)
-      call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
+      call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
     else
       if Array === elements
         training_args = call_method(:training_args, output_dir)
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
   end
 end
 
-if __FILE__ == $0
-
-end
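
The two functional changes above are the new RbbtPython load path and the class_weights accessor, which is forwarded to the Python train_model that appears later in this diff. A minimal sketch of setting the new knob, assuming a HuggingfaceModel.new signature of (task, checkpoint, dir) — the constructor arguments and weight values are hypothetical; only the accessors are confirmed by this diff:

    require 'rbbt/vector/model/huggingface'

    # Hypothetical construction; the diff only confirms the accessors below
    model = HuggingfaceModel.new "SequenceClassification", "distilbert-base-uncased", dir
    model.class_labels  = ["Bad", "Good"]
    model.class_weights = [0.3, 0.7] # handed through to CrossEntropyLoss(weight=...)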
lib/rbbt/vector/model.rb CHANGED
@@ -436,7 +436,7 @@ cat(paste(label, sep="\\n", collapse="\\n"));
     ensure
       @features = orig_features
       @labels = orig_labels
-    end
+    end unless folds == -1
     self.train unless folds == 1
     res
   end
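
Read together with the two lines that follow it, the folds argument now appears to have three regimes; this is an interpretation of the diff, not documented API:

    model.cross_validation 5    # k-fold CV, then a final self.train on all data
    model.cross_validation 1    # the trailing self.train is skipped
    model.cross_validation(-1)  # as of 1.2.4, the begin/ensure block is skipped too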
python/rbbt_dm/__init__.py ADDED
@@ -0,0 +1 @@
+# Keep
python/rbbt_dm/huggingface.py ADDED
@@ -0,0 +1,135 @@
+#{{{ LOAD MODEL
+
+def import_module_class(module, class_name):
+    exec(f"from {module} import {class_name}")
+    return eval(class_name)
+
+def load_model(task, checkpoint):
+    class_name = 'AutoModelFor' + task
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_tokenizer(task, checkpoint):
+    class_name = 'AutoTokenizer'
+    return import_module_class('transformers', class_name).from_pretrained(checkpoint)
+
+def load_model_and_tokenizer(task, checkpoint):
+    model = load_model(task, checkpoint)
+    tokenizer = load_tokenizer(task, checkpoint)
+    return model, tokenizer
+
+#{{{ SIMPLE EVALUATE
+
+def forward(model, features):
+    return model(**features)
+
+def logits(predictions):
+    logits = predictions["logits"]
+    return [v.detach().cpu().numpy() for v in logits]
+
+def eval_model(model, tokenizer, texts, return_logits = True):
+    features = tokenizer(texts, return_tensors='pt', truncation=True).to(model.device)
+    predictions = forward(model, features)
+    if (return_logits):
+        return logits(predictions)
+    else:
+        return predictions
+
+#{{{ TRAIN AND PREDICT
+
+def load_tsv(tsv_file):
+    from datasets import load_dataset
+    return load_dataset('csv', data_files=[tsv_file], sep="\t")
+
+def tsv_dataset(tokenizer, tsv_file):
+    dataset = load_tsv(tsv_file)
+    tokenized_dataset = dataset.map(lambda example: tokenizer(example["text"], truncation=True), batched=True)
+    return tokenized_dataset
+
+def training_args(*args, **kwargs):
+    from transformers import TrainingArguments
+    training_args = TrainingArguments(*args, **kwargs)
+    return training_args
+
+
+def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    if (not class_weights == None):
+        import torch
+        from torch import nn
+
+        class WeightTrainer(Trainer):
+            def compute_loss(self, model, inputs, return_outputs=False):
+                labels = inputs.get("labels")
+                # forward pass
+                outputs = model(**inputs)
+                logits = outputs.get('logits')
+                # compute custom loss
+                loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
+                loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+                return (loss, outputs) if return_outputs else loss
+
+        trainer = WeightTrainer(
+            model,
+            training_args,
+            train_dataset = tokenized_dataset["train"],
+            tokenizer = tokenizer
+        )
+    else:
+
+        trainer = Trainer(
+            model,
+            training_args,
+            train_dataset = tokenized_dataset["train"],
+            tokenizer = tokenizer
+        )
+
+    trainer.train()
+
+def find_tokens_in_input(dataset, token_ids):
+    position_rows = []
+
+    for row in dataset:
+        input_ids = row["input_ids"]
+
+        if (not hasattr(token_ids, "__len__")):
+            token_ids = [token_ids]
+
+        positions = []
+        for token_id in token_ids:
+
+            item_positions = []
+            for i in range(len(input_ids)):
+                if input_ids[i] == token_id:
+                    item_positions.append(i)
+
+            positions.append(item_positions)
+
+
+        position_rows.append(positions)
+
+    return position_rows
+
+
+def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
+    from transformers import Trainer
+
+    tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
+
+    trainer = Trainer(
+        model,
+        training_args,
+        tokenizer = tokenizer
+    )
+
+    result = trainer.predict(test_dataset = tokenized_dataset["train"])
+    if (locate_tokens != None):
+        token_ids = tokenizer.convert_tokens_to_ids(locate_tokens)
+        token_positions = find_tokens_in_input(tokenized_dataset["train"], token_ids)
+        return dict(result=result, token_positions=token_positions)
+    else:
+        return result
+
+
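
Tying the two halves together: HuggingfaceModel#train (in the Ruby hunk above) hands @class_weights to train_model, which swaps Trainer for the WeightTrainer subclass so the weights reach CrossEntropyLoss; predict_model additionally locates tokens per input row when locate_tokens is given. A hedged Ruby sketch of the locate_tokens side, with a hypothetical mask token and input:

    model.locate_tokens = "[MASK]"
    # On the Python side this produces dict(result=..., token_positions=...);
    # how HuggingfaceModel#eval unpacks that dict is not shown in this diff.
    model.eval("This is [MASK] dog")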
test/rbbt/vector/model/test_huggingface.rb CHANGED
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
 
 class TestHuggingface < Test::Unit::TestCase
 
+  def test_pipeline
+    require 'rbbt/util/python'
+    model = VectorModel.new
+    model.post_process do |elements|
+      elements.collect{|e| e['label'] }
+    end
+    model.eval_model do |file, elements|
+      RbbtPython.run :transformers do
+        classifier ||= transformers.pipeline("sentiment-analysis")
+        classifier.call(elements)
+      end
+    end
+
+    assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
+  end
+
   def test_sst_eval
     TmpFile.with_file do |dir|
       checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
       model.class_labels = ["Bad", "Good"]
 
       assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
-
     end
   end
 
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: rbbt-dm
 version: !ruby/object:Gem::Version
-  version: 1.2.1
+  version: 1.2.4
 platform: ruby
 authors:
 - Miguel Vazquez
-autorequire: 
+autorequire:
 bindir: bin
 cert_chain: []
-date: 2023-02-04 00:00:00.000000000 Z
+date: 2023-02-07 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rbbt-util
@@ -113,6 +113,8 @@ files:
 - lib/rbbt/vector/model/svm.rb
 - lib/rbbt/vector/model/tensorflow.rb
 - lib/rbbt/vector/model/util.rb
+- python/rbbt_dm/__init__.py
+- python/rbbt_dm/huggingface.py
 - share/R/MA.R
 - share/R/barcode.R
 - share/R/heatmap.3.R
@@ -141,7 +143,7 @@ files:
 homepage: http://github.com/mikisvaz/rbbt-phgx
 licenses: []
 metadata: {}
-post_install_message: 
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -156,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.1.2
-signing_key: 
+rubygems_version: 3.1.6
+signing_key:
 specification_version: 4
 summary: Data-mining and statistics
 test_files:
-- test/test_helper.rb
+- test/rbbt/statistics/test_hypergeometric.rb
+- test/rbbt/statistics/test_fisher.rb
+- test/rbbt/statistics/test_fdr.rb
+- test/rbbt/statistics/test_random_walk.rb
+- test/rbbt/test_ml_task.rb
 - test/rbbt/vector/test_model.rb
-- test/rbbt/vector/model/test_huggingface.rb
 - test/rbbt/vector/model/test_tensorflow.rb
 - test/rbbt/vector/model/test_spaCy.rb
+- test/rbbt/vector/model/test_huggingface.rb
 - test/rbbt/vector/model/test_svm.rb
-- test/rbbt/statistics/test_random_walk.rb
-- test/rbbt/statistics/test_fisher.rb
-- test/rbbt/statistics/test_fdr.rb
-- test/rbbt/statistics/test_hypergeometric.rb
-- test/rbbt/test_stan.rb
-- test/rbbt/matrix/test_barcode.rb
-- test/rbbt/test_ml_task.rb
 - test/rbbt/network/test_paths.rb
+- test/rbbt/matrix/test_barcode.rb
+- test/rbbt/test_stan.rb
+- test/test_helper.rb