rbbt-dm 1.2.3 → 1.2.4

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
4
- data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
3
+ metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
4
+ data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
5
5
  SHA512:
6
- metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
7
- data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
6
+ metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
7
+ data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
@@ -1,11 +1,12 @@
1
1
  require 'rbbt/vector/model'
2
2
  require 'rbbt/util/python'
3
3
 
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
4
5
  RbbtPython.init_rbbt
5
6
 
6
7
  class HuggingfaceModel < VectorModel
7
8
 
8
- attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
9
+ attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
9
10
 
10
11
  def tsv_dataset(tsv_dataset_file, elements, labels = nil)
11
12
 
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
48
49
 
49
50
  if labels
50
51
  training_args = call_method(:training_args, output_dir)
51
- call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
52
+ call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
52
53
  else
53
54
  if Array === elements
54
55
  training_args = call_method(:training_args, output_dir)
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
135
136
  end
136
137
  end
137
138
 
138
- if __FILE__ == $0
139
-
140
- end
@@ -0,0 +1 @@
1
+ # Keep
@@ -51,17 +51,40 @@ def training_args(*args, **kwargs):
51
51
  return training_args
52
52
 
53
53
 
54
- def train_model(model, tokenizer, training_args, tsv_file):
54
+ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
55
55
  from transformers import Trainer
56
56
 
57
57
  tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
58
58
 
59
- trainer = Trainer(
60
- model,
61
- training_args,
62
- train_dataset = tokenized_dataset["train"],
63
- tokenizer = tokenizer
64
- )
59
+ if (not class_weights == None):
60
+ import torch
61
+ from torch import nn
62
+
63
+ class WeightTrainer(Trainer):
64
+ def compute_loss(self, model, inputs, return_outputs=False):
65
+ labels = inputs.get("labels")
66
+ # forward pass
67
+ outputs = model(**inputs)
68
+ logits = outputs.get('logits')
69
+ # compute custom loss
70
+ loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
71
+ loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
72
+ return (loss, outputs) if return_outputs else loss
73
+
74
+ trainer = WeightTrainer(
75
+ model,
76
+ training_args,
77
+ train_dataset = tokenized_dataset["train"],
78
+ tokenizer = tokenizer
79
+ )
80
+ else:
81
+
82
+ trainer = Trainer(
83
+ model,
84
+ training_args,
85
+ train_dataset = tokenized_dataset["train"],
86
+ tokenizer = tokenizer
87
+ )
65
88
 
66
89
  trainer.train()
67
90
 
@@ -90,7 +113,6 @@ def find_tokens_in_input(dataset, token_ids):
90
113
  return position_rows
91
114
 
92
115
 
93
-
94
116
  def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
95
117
  from transformers import Trainer
96
118
 
@@ -110,3 +132,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
110
132
  else:
111
133
  return result
112
134
 
135
+
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
3
3
 
4
4
  class TestHuggingface < Test::Unit::TestCase
5
5
 
6
+ def test_pipeline
7
+ require 'rbbt/util/python'
8
+ model = VectorModel.new
9
+ model.post_process do |elements|
10
+ elements.collect{|e| e['label'] }
11
+ end
12
+ model.eval_model do |file, elements|
13
+ RbbtPython.run :transformers do
14
+ classifier ||= transformers.pipeline("sentiment-analysis")
15
+ classifier.call(elements)
16
+ end
17
+ end
18
+
19
+ assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
20
+ end
21
+
6
22
  def test_sst_eval
7
23
  TmpFile.with_file do |dir|
8
24
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
12
28
  model.class_labels = ["Bad", "Good"]
13
29
 
14
30
  assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
15
-
16
31
  end
17
32
  end
18
33
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.3
4
+ version: 1.2.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-02-04 00:00:00.000000000 Z
11
+ date: 2023-02-07 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rbbt-util
@@ -113,6 +113,7 @@ files:
113
113
  - lib/rbbt/vector/model/svm.rb
114
114
  - lib/rbbt/vector/model/tensorflow.rb
115
115
  - lib/rbbt/vector/model/util.rb
116
+ - python/rbbt_dm/__init__.py
116
117
  - python/rbbt_dm/huggingface.py
117
118
  - share/R/MA.R
118
119
  - share/R/barcode.R
@@ -142,7 +143,7 @@ files:
142
143
  homepage: http://github.com/mikisvaz/rbbt-phgx
143
144
  licenses: []
144
145
  metadata: {}
145
- post_install_message:
146
+ post_install_message:
146
147
  rdoc_options: []
147
148
  require_paths:
148
149
  - lib
@@ -157,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
157
158
  - !ruby/object:Gem::Version
158
159
  version: '0'
159
160
  requirements: []
160
- rubygems_version: 3.1.2
161
- signing_key:
161
+ rubygems_version: 3.1.6
162
+ signing_key:
162
163
  specification_version: 4
163
164
  summary: Data-mining and statistics
164
165
  test_files:
165
- - test/test_helper.rb
166
+ - test/rbbt/statistics/test_hypergeometric.rb
167
+ - test/rbbt/statistics/test_fisher.rb
168
+ - test/rbbt/statistics/test_fdr.rb
169
+ - test/rbbt/statistics/test_random_walk.rb
170
+ - test/rbbt/test_ml_task.rb
166
171
  - test/rbbt/vector/test_model.rb
167
- - test/rbbt/vector/model/test_huggingface.rb
168
172
  - test/rbbt/vector/model/test_tensorflow.rb
169
173
  - test/rbbt/vector/model/test_spaCy.rb
174
+ - test/rbbt/vector/model/test_huggingface.rb
170
175
  - test/rbbt/vector/model/test_svm.rb
171
- - test/rbbt/statistics/test_random_walk.rb
172
- - test/rbbt/statistics/test_fisher.rb
173
- - test/rbbt/statistics/test_fdr.rb
174
- - test/rbbt/statistics/test_hypergeometric.rb
175
- - test/rbbt/test_stan.rb
176
- - test/rbbt/matrix/test_barcode.rb
177
- - test/rbbt/test_ml_task.rb
178
176
  - test/rbbt/network/test_paths.rb
177
+ - test/rbbt/matrix/test_barcode.rb
178
+ - test/rbbt/test_stan.rb
179
+ - test/test_helper.rb