transformers-rb 0.1.4 → 0.1.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +1 -1
- data/lib/transformers/modeling_utils.rb +11 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +2 -0
- data/lib/transformers/pipelines/base.rb +38 -1
- data/lib/transformers/pipelines/reranking.rb +2 -1
- data/lib/transformers/sentence_transformer.rb +1 -0
- data/lib/transformers/version.rb +1 -1
- metadata +17 -7
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 607afc7b4cc3bb62f8ae2aab693863d2cd1456d0c90b20dc0eb4f682fc788cfb
|
4
|
+
data.tar.gz: 3a6ad9f7f624e77e077bf4d7e1d945ee8f6e26d805e08a574d5d583e231500fb
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ab988a57c8af7883894163b149528172a8602d9031f9c6e6880d6bb32f42842cbfa566ac4f347168a230218565cc7550c99043286e5a31d0d104ba3b250b2483
|
7
|
+
data.tar.gz: 7f29ee6adcc0566ce1d1905e557288dec78b8befcc98f9a46c71c0fe8580dd14c914afcb1c1477df65aecc3b106cf799975c2b46ba73a4ec7d93ea34a908b688
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,14 @@
|
|
1
|
+
## 0.1.6 (2024-12-29)
|
2
|
+
|
3
|
+
- Fixed error with failed HTTP requests
|
4
|
+
- Fixed warning with Ruby 3.4
|
5
|
+
|
6
|
+
## 0.1.5 (2024-11-01)
|
7
|
+
|
8
|
+
- Fixed error with pipelines when called more than 10 times
|
9
|
+
- Fixed `device` option for pipelines
|
10
|
+
- Fixed error with `reranking` pipeline
|
11
|
+
|
1
12
|
## 0.1.4 (2024-10-22)
|
2
13
|
|
3
14
|
- Added `BertForSequenceClassification`
|
@@ -24,7 +24,7 @@ module Transformers
|
|
24
24
|
def hf_raise_for_status(response, endpoint_name: nil)
|
25
25
|
begin
|
26
26
|
response.value unless response.is_a?(Net::HTTPRedirection)
|
27
|
-
rescue
|
27
|
+
rescue => e
|
28
28
|
error_code = response["X-Error-Code"]
|
29
29
|
error_message = response["X-Error-Message"]
|
30
30
|
|
@@ -56,6 +56,10 @@ module Transformers
|
|
56
56
|
end
|
57
57
|
|
58
58
|
module ModuleUtilsMixin
|
59
|
+
def device
|
60
|
+
get_parameter_device(self)
|
61
|
+
end
|
62
|
+
|
59
63
|
def get_extended_attention_mask(
|
60
64
|
attention_mask,
|
61
65
|
input_shape,
|
@@ -112,6 +116,13 @@ module Transformers
|
|
112
116
|
|
113
117
|
head_mask
|
114
118
|
end
|
119
|
+
|
120
|
+
private
|
121
|
+
|
122
|
+
def get_parameter_device(parameter)
|
123
|
+
# TODO return Torch::Device in Torch.rb
|
124
|
+
Torch.device(parameter.parameters[0].device)
|
125
|
+
end
|
115
126
|
end
|
116
127
|
|
117
128
|
class PreTrainedModel < Torch::NN::Module
|
@@ -123,6 +123,22 @@ module Transformers
|
|
123
123
|
end
|
124
124
|
end
|
125
125
|
|
126
|
+
if @framework == "pt"
|
127
|
+
if device.is_a?(String)
|
128
|
+
@device = Torch.device(device)
|
129
|
+
else
|
130
|
+
# TODO update default in 0.2.0
|
131
|
+
@device = @model.device
|
132
|
+
end
|
133
|
+
else
|
134
|
+
raise Todo
|
135
|
+
end
|
136
|
+
|
137
|
+
# TODO Fix eql? for Torch::Device in Torch.rb
|
138
|
+
if @device.type != @model.device.type || @device.index != @model.device.index
|
139
|
+
@model.to(@device)
|
140
|
+
end
|
141
|
+
|
126
142
|
@call_count = 0
|
127
143
|
@batch_size = kwargs.delete(:batch_size)
|
128
144
|
@num_workers = kwargs.delete(:num_workers)
|
@@ -133,6 +149,24 @@ module Transformers
|
|
133
149
|
@model.dtype
|
134
150
|
end
|
135
151
|
|
152
|
+
def _ensure_tensor_on_device(inputs, device)
|
153
|
+
# TODO move
|
154
|
+
inputs = inputs.to_h if inputs.is_a?(BatchEncoding)
|
155
|
+
|
156
|
+
if inputs.is_a?(ModelOutput)
|
157
|
+
inputs.instance_variable_get(:@data).transform_values! { |tensor| _ensure_tensor_on_device(tensor, device) }
|
158
|
+
inputs
|
159
|
+
elsif inputs.is_a?(Hash)
|
160
|
+
inputs.transform_values { |tensor| _ensure_tensor_on_device(tensor, device) }
|
161
|
+
elsif inputs.is_a?(Array)
|
162
|
+
inputs.map { |item| _ensure_tensor_on_device(item, device) }
|
163
|
+
elsif inputs.is_a?(Torch::Tensor)
|
164
|
+
inputs.to(device)
|
165
|
+
else
|
166
|
+
inputs
|
167
|
+
end
|
168
|
+
end
|
169
|
+
|
136
170
|
def check_model_type(supported_models)
|
137
171
|
if !supported_models.is_a?(Array)
|
138
172
|
supported_models_names = []
|
@@ -245,7 +279,10 @@ module Transformers
|
|
245
279
|
end
|
246
280
|
|
247
281
|
def forward(model_inputs, **forward_params)
|
248
|
-
|
282
|
+
model_inputs = _ensure_tensor_on_device(model_inputs, @device)
|
283
|
+
model_outputs = _forward(model_inputs, **forward_params)
|
284
|
+
model_outputs = _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))
|
285
|
+
model_outputs
|
249
286
|
end
|
250
287
|
|
251
288
|
def run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
data/lib/transformers/version.rb
CHANGED
metadata
CHANGED
@@ -1,15 +1,28 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: transformers-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.4
|
4
|
+
version: 0.1.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
|
-
autorequire:
|
9
8
|
bindir: bin
|
10
9
|
cert_chain: []
|
11
|
-
date: 2024-10-22 00:00:00.000000000 Z
|
10
|
+
date: 2024-12-29 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
12
|
+
- !ruby/object:Gem::Dependency
|
13
|
+
name: logger
|
14
|
+
requirement: !ruby/object:Gem::Requirement
|
15
|
+
requirements:
|
16
|
+
- - ">="
|
17
|
+
- !ruby/object:Gem::Version
|
18
|
+
version: '0'
|
19
|
+
type: :runtime
|
20
|
+
prerelease: false
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
22
|
+
requirements:
|
23
|
+
- - ">="
|
24
|
+
- !ruby/object:Gem::Version
|
25
|
+
version: '0'
|
13
26
|
- !ruby/object:Gem::Dependency
|
14
27
|
name: numo-narray
|
15
28
|
requirement: !ruby/object:Gem::Requirement
|
@@ -66,7 +79,6 @@ dependencies:
|
|
66
79
|
- - ">="
|
67
80
|
- !ruby/object:Gem::Version
|
68
81
|
version: 0.17.1
|
69
|
-
description:
|
70
82
|
email: andrew@ankane.org
|
71
83
|
executables: []
|
72
84
|
extensions: []
|
@@ -150,7 +162,6 @@ homepage: https://github.com/ankane/transformers-ruby
|
|
150
162
|
licenses:
|
151
163
|
- Apache-2.0
|
152
164
|
metadata: {}
|
153
|
-
post_install_message:
|
154
165
|
rdoc_options: []
|
155
166
|
require_paths:
|
156
167
|
- lib
|
@@ -165,8 +176,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
165
176
|
- !ruby/object:Gem::Version
|
166
177
|
version: '0'
|
167
178
|
requirements: []
|
168
|
-
rubygems_version: 3.
|
169
|
-
signing_key:
|
179
|
+
rubygems_version: 3.6.2
|
170
180
|
specification_version: 4
|
171
181
|
summary: State-of-the-art transformers for Ruby
|
172
182
|
test_files: []
|