transformers-rb 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 16b409e954f6bcc45fd4f3f2db94dc92c87d47c4c06936162978cc7d7e54fc09
4
- data.tar.gz: c40ea58b7531e89a041ce2782dea89536f85227903e2dc9a60113afe041bb9f7
3
+ metadata.gz: 607afc7b4cc3bb62f8ae2aab693863d2cd1456d0c90b20dc0eb4f682fc788cfb
4
+ data.tar.gz: 3a6ad9f7f624e77e077bf4d7e1d945ee8f6e26d805e08a574d5d583e231500fb
5
5
  SHA512:
6
- metadata.gz: 0576500ca9fe9379aae4c2cc050aa34c90eea7b2d5251b6139c48d88b5107086e6197aad5210bc74d7969a5bfa0458d31ce7faf561df69b2ae9b2a6400280ce0
7
- data.tar.gz: b5b0a865acdd37bcde11571e365a1d39b473f3dcbbe737a082f37002f886fe38f0f57c77f44236f859939d1d3b1df9659bb524a483cd0b9048dbc4cd472a3355
6
+ metadata.gz: ab988a57c8af7883894163b149528172a8602d9031f9c6e6880d6bb32f42842cbfa566ac4f347168a230218565cc7550c99043286e5a31d0d104ba3b250b2483
7
+ data.tar.gz: 7f29ee6adcc0566ce1d1905e557288dec78b8befcc98f9a46c71c0fe8580dd14c914afcb1c1477df65aecc3b106cf799975c2b46ba73a4ec7d93ea34a908b688
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
1
+ ## 0.1.6 (2024-12-29)
2
+
3
+ - Fixed error with failed HTTP requests
4
+ - Fixed warning with Ruby 3.4
5
+
6
+ ## 0.1.5 (2024-11-01)
7
+
8
+ - Fixed error with pipelines when called more than 10 times
9
+ - Fixed `device` option for pipelines
10
+ - Fixed error with `reranking` pipeline
11
+
1
12
  ## 0.1.4 (2024-10-22)
2
13
 
3
14
  - Added `BertForSequenceClassification`
@@ -24,7 +24,7 @@ module Transformers
24
24
  def hf_raise_for_status(response, endpoint_name: nil)
25
25
  begin
26
26
  response.value unless response.is_a?(Net::HTTPRedirection)
27
- rescue
27
+ rescue => e
28
28
  error_code = response["X-Error-Code"]
29
29
  error_message = response["X-Error-Message"]
30
30
 
@@ -56,6 +56,10 @@ module Transformers
56
56
  end
57
57
 
58
58
  module ModuleUtilsMixin
59
+ def device
60
+ get_parameter_device(self)
61
+ end
62
+
59
63
  def get_extended_attention_mask(
60
64
  attention_mask,
61
65
  input_shape,
@@ -112,6 +116,13 @@ module Transformers
112
116
 
113
117
  head_mask
114
118
  end
119
+
120
+ private
121
+
122
+ def get_parameter_device(parameter)
123
+ # TODO return Torch::Device in Torch.rb
124
+ Torch.device(parameter.parameters[0].device)
125
+ end
115
126
  end
116
127
 
117
128
  class PreTrainedModel < Torch::NN::Module
@@ -1213,4 +1213,6 @@ module Transformers
1213
1213
  end
1214
1214
  end
1215
1215
  end
1216
+
1217
+ XLMRobertaForSequenceClassification = XlmRoberta::XLMRobertaForSequenceClassification
1216
1218
  end
@@ -123,6 +123,22 @@ module Transformers
123
123
  end
124
124
  end
125
125
 
126
+ if @framework == "pt"
127
+ if device.is_a?(String)
128
+ @device = Torch.device(device)
129
+ else
130
+ # TODO update default in 0.2.0
131
+ @device = @model.device
132
+ end
133
+ else
134
+ raise Todo
135
+ end
136
+
137
+ # TODO Fix eql? for Torch::Device in Torch.rb
138
+ if @device.type != @model.device.type || @device.index != @model.device.index
139
+ @model.to(@device)
140
+ end
141
+
126
142
  @call_count = 0
127
143
  @batch_size = kwargs.delete(:batch_size)
128
144
  @num_workers = kwargs.delete(:num_workers)
@@ -133,6 +149,24 @@ module Transformers
133
149
  @model.dtype
134
150
  end
135
151
 
152
+ def _ensure_tensor_on_device(inputs, device)
153
+ # TODO move
154
+ inputs = inputs.to_h if inputs.is_a?(BatchEncoding)
155
+
156
+ if inputs.is_a?(ModelOutput)
157
+ inputs.instance_variable_get(:@data).transform_values! { |tensor| _ensure_tensor_on_device(tensor, device) }
158
+ inputs
159
+ elsif inputs.is_a?(Hash)
160
+ inputs.transform_values { |tensor| _ensure_tensor_on_device(tensor, device) }
161
+ elsif inputs.is_a?(Array)
162
+ inputs.map { |item| _ensure_tensor_on_device(item, device) }
163
+ elsif inputs.is_a?(Torch::Tensor)
164
+ inputs.to(device)
165
+ else
166
+ inputs
167
+ end
168
+ end
169
+
136
170
  def check_model_type(supported_models)
137
171
  if !supported_models.is_a?(Array)
138
172
  supported_models_names = []
@@ -245,7 +279,10 @@ module Transformers
245
279
  end
246
280
 
247
281
  def forward(model_inputs, **forward_params)
248
- _forward(model_inputs, **forward_params)
282
+ model_inputs = _ensure_tensor_on_device(model_inputs, @device)
283
+ model_outputs = _forward(model_inputs, **forward_params)
284
+ model_outputs = _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))
285
+ model_outputs
249
286
  end
250
287
 
251
288
  def run_single(inputs, preprocess_params, forward_params, postprocess_params)
@@ -8,7 +8,8 @@ module Transformers
8
8
  @tokenizer.(
9
9
  [inputs[:query]] * inputs[:documents].length,
10
10
  text_pair: inputs[:documents],
11
- return_tensors: @framework
11
+ return_tensors: @framework,
12
+ padding: true
12
13
  )
13
14
  end
14
15
 
@@ -1,4 +1,5 @@
1
1
  module Transformers
2
+ # TODO remove in 0.2.0
2
3
  class SentenceTransformer
3
4
  def initialize(model_id)
4
5
  @model_id = model_id
@@ -1,3 +1,3 @@
1
1
  module Transformers
2
- VERSION = "0.1.4"
2
+ VERSION = "0.1.6"
3
3
  end
metadata CHANGED
@@ -1,15 +1,28 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: transformers-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.4
4
+ version: 0.1.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
9
8
  bindir: bin
10
9
  cert_chain: []
11
- date: 2024-10-22 00:00:00.000000000 Z
10
+ date: 2024-12-29 00:00:00.000000000 Z
12
11
  dependencies:
12
+ - !ruby/object:Gem::Dependency
13
+ name: logger
14
+ requirement: !ruby/object:Gem::Requirement
15
+ requirements:
16
+ - - ">="
17
+ - !ruby/object:Gem::Version
18
+ version: '0'
19
+ type: :runtime
20
+ prerelease: false
21
+ version_requirements: !ruby/object:Gem::Requirement
22
+ requirements:
23
+ - - ">="
24
+ - !ruby/object:Gem::Version
25
+ version: '0'
13
26
  - !ruby/object:Gem::Dependency
14
27
  name: numo-narray
15
28
  requirement: !ruby/object:Gem::Requirement
@@ -66,7 +79,6 @@ dependencies:
66
79
  - - ">="
67
80
  - !ruby/object:Gem::Version
68
81
  version: 0.17.1
69
- description:
70
82
  email: andrew@ankane.org
71
83
  executables: []
72
84
  extensions: []
@@ -150,7 +162,6 @@ homepage: https://github.com/ankane/transformers-ruby
150
162
  licenses:
151
163
  - Apache-2.0
152
164
  metadata: {}
153
- post_install_message:
154
165
  rdoc_options: []
155
166
  require_paths:
156
167
  - lib
@@ -165,8 +176,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
165
176
  - !ruby/object:Gem::Version
166
177
  version: '0'
167
178
  requirements: []
168
- rubygems_version: 3.5.16
169
- signing_key:
179
+ rubygems_version: 3.6.2
170
180
  specification_version: 4
171
181
  summary: State-of-the-art transformers for Ruby
172
182
  test_files: []