rbbt-dm 1.2.3 → 1.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2ff72107967b0f7c654697f3a7b3c0ef10f7a5264d775117f12f74084a2819b2
4
- data.tar.gz: 6b9a58b5a2723c095332f79a37d9c1c7f4bc1431410f23a55beeed1c3b52f7ad
3
+ metadata.gz: abaea1fff82b5e14a84dc9afc966fc8dde6482d50769d196854c1d619adebaf3
4
+ data.tar.gz: 561b8864fc2c0ba271a2a658da0d3492c7481a2368b40c3b91fe6edb4ebca4cd
5
5
  SHA512:
6
- metadata.gz: a0fb4198cb0be3aa5253df0f655ee230621dd26a31956a774fffe95eac35f4c8b558a41c0e340c25c6eef463760ff6230b967f09eb09671b2078a50066067384
7
- data.tar.gz: 0bd6c3667a8ec26ed092c54e78176671807ed0634e136e497303e34a17b2740e5e023041bc06389ac187de39b54942b9b1c5cd77abbc067c89250424654b6974
6
+ metadata.gz: f26f6b27f1beb2554fa78369d1d618cc13175e0c9bb0e789b9490dcae0f7f6df4449a3c72d183ae22c96324d4e2f1ab0352bde8068c1c18871d52c5f5b53c235
7
+ data.tar.gz: bb33d93cbe24ea974beedb0530f9af317dec06c7e76f32c37d724322ba05f241c6b79a706a88f1bbe703ac4bc78c53c220f28c3f38cf7939477274b8747c436e
@@ -1,11 +1,12 @@
1
1
  require 'rbbt/vector/model'
2
2
  require 'rbbt/util/python'
3
3
 
4
+ RbbtPython.add_path Rbbt.python.find(:lib)
4
5
  RbbtPython.init_rbbt
5
6
 
6
7
  class HuggingfaceModel < VectorModel
7
8
 
8
- attr_accessor :checkpoint, :task, :locate_tokens, :class_labels
9
+ attr_accessor :checkpoint, :task, :locate_tokens, :class_labels, :class_weights
9
10
 
10
11
  def tsv_dataset(tsv_dataset_file, elements, labels = nil)
11
12
 
@@ -48,7 +49,7 @@ class HuggingfaceModel < VectorModel
48
49
 
49
50
  if labels
50
51
  training_args = call_method(:training_args, output_dir)
51
- call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels))
52
+ call_method(:train_model, @model, @tokenizer, training_args, tsv_dataset(tsv_file, elements, labels), @class_weights)
52
53
  else
53
54
  if Array === elements
54
55
  training_args = call_method(:training_args, output_dir)
@@ -135,6 +136,3 @@ class HuggingfaceModel < VectorModel
135
136
  end
136
137
  end
137
138
 
138
- if __FILE__ == $0
139
-
140
- end
@@ -0,0 +1 @@
1
+ # Keep
@@ -51,17 +51,40 @@ def training_args(*args, **kwargs):
51
51
  return training_args
52
52
 
53
53
 
54
- def train_model(model, tokenizer, training_args, tsv_file):
54
+ def train_model(model, tokenizer, training_args, tsv_file, class_weights=None):
55
55
  from transformers import Trainer
56
56
 
57
57
  tokenized_dataset = tsv_dataset(tokenizer, tsv_file)
58
58
 
59
- trainer = Trainer(
60
- model,
61
- training_args,
62
- train_dataset = tokenized_dataset["train"],
63
- tokenizer = tokenizer
64
- )
59
+ if (not class_weights == None):
60
+ import torch
61
+ from torch import nn
62
+
63
+ class WeightTrainer(Trainer):
64
+ def compute_loss(self, model, inputs, return_outputs=False):
65
+ labels = inputs.get("labels")
66
+ # forward pass
67
+ outputs = model(**inputs)
68
+ logits = outputs.get('logits')
69
+ # compute custom loss
70
+ loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(model.device))
71
+ loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
72
+ return (loss, outputs) if return_outputs else loss
73
+
74
+ trainer = WeightTrainer(
75
+ model,
76
+ training_args,
77
+ train_dataset = tokenized_dataset["train"],
78
+ tokenizer = tokenizer
79
+ )
80
+ else:
81
+
82
+ trainer = Trainer(
83
+ model,
84
+ training_args,
85
+ train_dataset = tokenized_dataset["train"],
86
+ tokenizer = tokenizer
87
+ )
65
88
 
66
89
  trainer.train()
67
90
 
@@ -90,7 +113,6 @@ def find_tokens_in_input(dataset, token_ids):
90
113
  return position_rows
91
114
 
92
115
 
93
-
94
116
  def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = None):
95
117
  from transformers import Trainer
96
118
 
@@ -110,3 +132,4 @@ def predict_model(model, tokenizer, training_args, tsv_file, locate_tokens = Non
110
132
  else:
111
133
  return result
112
134
 
135
+
@@ -3,6 +3,22 @@ require 'rbbt/vector/model/huggingface'
3
3
 
4
4
  class TestHuggingface < Test::Unit::TestCase
5
5
 
6
+ def test_pipeline
7
+ require 'rbbt/util/python'
8
+ model = VectorModel.new
9
+ model.post_process do |elements|
10
+ elements.collect{|e| e['label'] }
11
+ end
12
+ model.eval_model do |file, elements|
13
+ RbbtPython.run :transformers do
14
+ classifier ||= transformers.pipeline("sentiment-analysis")
15
+ classifier.call(elements)
16
+ end
17
+ end
18
+
19
+ assert_equal ["POSITIVE"], model.eval("I've been waiting for a HuggingFace course my whole life.")
20
+ end
21
+
6
22
  def test_sst_eval
7
23
  TmpFile.with_file do |dir|
8
24
  checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
@@ -12,7 +28,6 @@ class TestHuggingface < Test::Unit::TestCase
12
28
  model.class_labels = ["Bad", "Good"]
13
29
 
14
30
  assert_equal ["Bad", "Good"], model.eval(["This is dog", "This is cat"])
15
-
16
31
  end
17
32
  end
18
33
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rbbt-dm
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.2.3
4
+ version: 1.2.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Miguel Vazquez
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-02-04 00:00:00.000000000 Z
11
+ date: 2023-02-07 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rbbt-util
@@ -113,6 +113,7 @@ files:
113
113
  - lib/rbbt/vector/model/svm.rb
114
114
  - lib/rbbt/vector/model/tensorflow.rb
115
115
  - lib/rbbt/vector/model/util.rb
116
+ - python/rbbt_dm/__init__.py
116
117
  - python/rbbt_dm/huggingface.py
117
118
  - share/R/MA.R
118
119
  - share/R/barcode.R
@@ -142,7 +143,7 @@ files:
142
143
  homepage: http://github.com/mikisvaz/rbbt-phgx
143
144
  licenses: []
144
145
  metadata: {}
145
- post_install_message:
146
+ post_install_message:
146
147
  rdoc_options: []
147
148
  require_paths:
148
149
  - lib
@@ -157,22 +158,22 @@ required_rubygems_version: !ruby/object:Gem::Requirement
157
158
  - !ruby/object:Gem::Version
158
159
  version: '0'
159
160
  requirements: []
160
- rubygems_version: 3.1.2
161
- signing_key:
161
+ rubygems_version: 3.1.6
162
+ signing_key:
162
163
  specification_version: 4
163
164
  summary: Data-mining and statistics
164
165
  test_files:
165
- - test/test_helper.rb
166
+ - test/rbbt/statistics/test_hypergeometric.rb
167
+ - test/rbbt/statistics/test_fisher.rb
168
+ - test/rbbt/statistics/test_fdr.rb
169
+ - test/rbbt/statistics/test_random_walk.rb
170
+ - test/rbbt/test_ml_task.rb
166
171
  - test/rbbt/vector/test_model.rb
167
- - test/rbbt/vector/model/test_huggingface.rb
168
172
  - test/rbbt/vector/model/test_tensorflow.rb
169
173
  - test/rbbt/vector/model/test_spaCy.rb
174
+ - test/rbbt/vector/model/test_huggingface.rb
170
175
  - test/rbbt/vector/model/test_svm.rb
171
- - test/rbbt/statistics/test_random_walk.rb
172
- - test/rbbt/statistics/test_fisher.rb
173
- - test/rbbt/statistics/test_fdr.rb
174
- - test/rbbt/statistics/test_hypergeometric.rb
175
- - test/rbbt/test_stan.rb
176
- - test/rbbt/matrix/test_barcode.rb
177
- - test/rbbt/test_ml_task.rb
178
176
  - test/rbbt/network/test_paths.rb
177
+ - test/rbbt/matrix/test_barcode.rb
178
+ - test/rbbt/test_stan.rb
179
+ - test/test_helper.rb