transformers-rb 0.1.4 → 0.1.6

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 16b409e954f6bcc45fd4f3f2db94dc92c87d47c4c06936162978cc7d7e54fc09
-   data.tar.gz: c40ea58b7531e89a041ce2782dea89536f85227903e2dc9a60113afe041bb9f7
+   metadata.gz: 607afc7b4cc3bb62f8ae2aab693863d2cd1456d0c90b20dc0eb4f682fc788cfb
+   data.tar.gz: 3a6ad9f7f624e77e077bf4d7e1d945ee8f6e26d805e08a574d5d583e231500fb
  SHA512:
-   metadata.gz: 0576500ca9fe9379aae4c2cc050aa34c90eea7b2d5251b6139c48d88b5107086e6197aad5210bc74d7969a5bfa0458d31ce7faf561df69b2ae9b2a6400280ce0
-   data.tar.gz: b5b0a865acdd37bcde11571e365a1d39b473f3dcbbe737a082f37002f886fe38f0f57c77f44236f859939d1d3b1df9659bb524a483cd0b9048dbc4cd472a3355
+   metadata.gz: ab988a57c8af7883894163b149528172a8602d9031f9c6e6880d6bb32f42842cbfa566ac4f347168a230218565cc7550c99043286e5a31d0d104ba3b250b2483
+   data.tar.gz: 7f29ee6adcc0566ce1d1905e557288dec78b8befcc98f9a46c71c0fe8580dd14c914afcb1c1477df65aecc3b106cf799975c2b46ba73a4ec7d93ea34a908b688
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
+ ## 0.1.6 (2024-12-29)
+
+ - Fixed error with failed HTTP requests
+ - Fixed warning with Ruby 3.4
+
+ ## 0.1.5 (2024-11-01)
+
+ - Fixed error with pipelines when called more than 10 times
+ - Fixed `device` option for pipelines
+ - Fixed error with `reranking` pipeline
+
  ## 0.1.4 (2024-10-22)

  - Added `BertForSequenceClassification`
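
A hedged example of the 0.1.5 pipeline fix (a model is downloaded on first use; the task name and text below are illustrative): one pipeline instance can now be called more than 10 times without raising.

    # Illustrative only: reuse a single pipeline instance across many calls.
    classifier = Transformers.pipeline("sentiment-analysis")
    20.times { classifier.("Works on every call now") }
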
@@ -24,7 +24,7 @@ module Transformers
  def hf_raise_for_status(response, endpoint_name: nil)
    begin
      response.value unless response.is_a?(Net::HTTPRedirection)
-   rescue
+   rescue => e
      error_code = response["X-Error-Code"]
      error_message = response["X-Error-Message"]
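
A plain-Ruby sketch of why this one-character change matters (the URL and handling are illustrative, not the gem's internals): binding the exception with `rescue => e` keeps the original Net::HTTP error available, so the handler can report or re-raise it instead of referencing an undefined `e`.

    require "net/http"

    begin
      # .value raises (e.g. Net::HTTPClientException) for non-2xx responses
      Net::HTTP.get_response(URI("https://huggingface.co/api/models/does-not-exist")).value
    rescue => e
      warn "request failed: #{e.class}: #{e.message}"
      raise
    end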
 
@@ -56,6 +56,10 @@ module Transformers
    end

    module ModuleUtilsMixin
+     def device
+       get_parameter_device(self)
+     end
+
      def get_extended_attention_mask(
        attention_mask,
        input_shape,
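
A hedged sketch of the new accessor (the model class and checkpoint below are examples only; any model that includes ModuleUtilsMixin should respond the same way): `device` reports the Torch device of the model's first parameter.

    # Illustrative: load a model, then ask where its weights live.
    model = Transformers::BertForSequenceClassification.from_pretrained("bert-base-uncased")
    model.device # => a Torch device, e.g. the CPU by default
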
@@ -112,6 +116,13 @@ module Transformers

        head_mask
      end
+
+     private
+
+     def get_parameter_device(parameter)
+       # TODO return Torch::Device in Torch.rb
+       Torch.device(parameter.parameters[0].device)
+     end
    end

    class PreTrainedModel < Torch::NN::Module
@@ -1213,4 +1213,6 @@ module Transformers
        end
      end
    end
+
+   XLMRobertaForSequenceClassification = XlmRoberta::XLMRobertaForSequenceClassification
  end
@@ -123,6 +123,22 @@ module Transformers
        end
      end

+     if @framework == "pt"
+       if device.is_a?(String)
+         @device = Torch.device(device)
+       else
+         # TODO update default in 0.2.0
+         @device = @model.device
+       end
+     else
+       raise Todo
+     end
+
+     # TODO Fix eql? for Torch::Device in Torch.rb
+     if @device.type != @model.device.type || @device.index != @model.device.index
+       @model.to(@device)
+     end
+
      @call_count = 0
      @batch_size = kwargs.delete(:batch_size)
      @num_workers = kwargs.delete(:num_workers)
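
A hedged usage example for the device handling above (assumes torch-rb was built with CUDA and the checkpoint can be downloaded; the model id and "cuda" are illustrative): the device: option accepts a string, which is converted with Torch.device, and omitting it keeps the model on its current device.

    # Illustrative: run an embedding pipeline on the GPU.
    embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2", device: "cuda")
    embed.("This sentence is embedded on the GPU")
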
@@ -133,6 +149,24 @@ module Transformers
        @model.dtype
      end

+     def _ensure_tensor_on_device(inputs, device)
+       # TODO move
+       inputs = inputs.to_h if inputs.is_a?(BatchEncoding)
+
+       if inputs.is_a?(ModelOutput)
+         inputs.instance_variable_get(:@data).transform_values! { |tensor| _ensure_tensor_on_device(tensor, device) }
+         inputs
+       elsif inputs.is_a?(Hash)
+         inputs.transform_values { |tensor| _ensure_tensor_on_device(tensor, device) }
+       elsif inputs.is_a?(Array)
+         inputs.map { |item| _ensure_tensor_on_device(item, device) }
+       elsif inputs.is_a?(Torch::Tensor)
+         inputs.to(device)
+       else
+         inputs
+       end
+     end
+
      def check_model_type(supported_models)
        if !supported_models.is_a?(Array)
          supported_models_names = []
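
The helper above bottoms out in Torch::Tensor#to, walking hashes, arrays, BatchEncoding and ModelOutput containers on the way down. A quick irb-style check of that primitive (CPU only here, so the move is a no-op):

    require "torch"

    t = Torch.tensor([1.0, 2.0, 3.0])
    t.to(Torch.device("cpu")).device # => the CPU device
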
@@ -245,7 +279,10 @@ module Transformers
      end

      def forward(model_inputs, **forward_params)
-       _forward(model_inputs, **forward_params)
+       model_inputs = _ensure_tensor_on_device(model_inputs, @device)
+       model_outputs = _forward(model_inputs, **forward_params)
+       model_outputs = _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))
+       model_outputs
      end

      def run_single(inputs, preprocess_params, forward_params, postprocess_params)
@@ -8,7 +8,8 @@ module Transformers
        @tokenizer.(
          [inputs[:query]] * inputs[:documents].length,
          text_pair: inputs[:documents],
-         return_tensors: @framework
+         return_tensors: @framework,
+         padding: true
        )
      end
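
A hedged usage example for the padding fix (a model download is required; the checkpoint is just a commonly used reranker): the pipeline tokenizes the query against all documents in one batch, so padding: true is what lets documents of different lengths be encoded together.

    # Illustrative: score two documents of different lengths against one query.
    ranker = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
    ranker.("How many people live in London?", [
      "Around 9 million people live in London.",
      "London is known for its financial district and museums."
    ])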
 
@@ -1,4 +1,5 @@
  module Transformers
+   # TODO remove in 0.2.0
    class SentenceTransformer
      def initialize(model_id)
        @model_id = model_id
@@ -1,3 +1,3 @@
  module Transformers
-   VERSION = "0.1.4"
+   VERSION = "0.1.6"
  end
metadata CHANGED
@@ -1,15 +1,28 @@
  --- !ruby/object:Gem::Specification
  name: transformers-rb
  version: !ruby/object:Gem::Version
-   version: 0.1.4
+   version: 0.1.6
  platform: ruby
  authors:
  - Andrew Kane
- autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-10-22 00:00:00.000000000 Z
+ date: 2024-12-29 00:00:00.000000000 Z
  dependencies:
+ - !ruby/object:Gem::Dependency
+   name: logger
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - ">="
+       - !ruby/object:Gem::Version
+         version: '0'
  - !ruby/object:Gem::Dependency
    name: numo-narray
    requirement: !ruby/object:Gem::Requirement
@@ -66,7 +79,6 @@ dependencies:
      - - ">="
        - !ruby/object:Gem::Version
          version: 0.17.1
- description:
  email: andrew@ankane.org
  executables: []
  extensions: []
@@ -150,7 +162,6 @@ homepage: https://github.com/ankane/transformers-ruby
  licenses:
  - Apache-2.0
  metadata: {}
- post_install_message:
  rdoc_options: []
  require_paths:
  - lib
@@ -165,8 +176,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
      - !ruby/object:Gem::Version
        version: '0'
  requirements: []
- rubygems_version: 3.5.16
- signing_key:
+ rubygems_version: 3.6.2
  specification_version: 4
  summary: State-of-the-art transformers for Ruby
  test_files: []