transformers-rb 0.1.4 → 0.1.5
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/transformers/modeling_utils.rb +11 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +2 -0
- data/lib/transformers/pipelines/base.rb +38 -1
- data/lib/transformers/pipelines/reranking.rb +2 -1
- data/lib/transformers/version.rb +1 -1
- metadata +2 -2
checksums.yaml CHANGED

```diff
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 815a531b2876ca2f66d9df65e13a08f42ab0b4a98d0d1ef459961fc3069faaa7
+  data.tar.gz: 842158dd34cd785efff7a6df7524acd019afc6ead7738cfd9ddf6efadb89d97d
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 2a4fa39c9d32b4ba83eeafe128d10af0873a910807ad67870359e8bffcdf94bdeeb18dea8bc0b2372926ded909e8223a2603388297aad357bca26ea0c6eb0baa
+  data.tar.gz: bfac7d5c91e91bbbf1770ea71a9a9c69bca71c1e0ded3b0ca696156c99c331a5988879a3cf48731145db7be1d1c1b14dc7e573b728e1d1a123faa251ab56a560
```
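These digests cover the two archives packed inside the published .gem file. A minimal sketch of recomputing them locally, assuming the gem has been downloaded to the current directory (the filename is an assumption, not part of this diff):

```ruby
require "digest"
require "rubygems/package"

# A .gem file is a tar archive; checksums.yaml records SHA256/SHA512 digests
# of its metadata.gz and data.tar.gz entries.
File.open("transformers-rb-0.1.5.gem", "rb") do |io|
  Gem::Package::TarReader.new(io).each do |entry|
    next unless ["metadata.gz", "data.tar.gz"].include?(entry.full_name)
    puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
  end
end
```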
data/CHANGELOG.md CHANGED
data/lib/transformers/modeling_utils.rb CHANGED

```diff
@@ -56,6 +56,10 @@ module Transformers
   end
 
   module ModuleUtilsMixin
+    def device
+      get_parameter_device(self)
+    end
+
     def get_extended_attention_mask(
       attention_mask,
       input_shape,
@@ -112,6 +116,13 @@ module Transformers
 
       head_mask
     end
+
+    private
+
+    def get_parameter_device(parameter)
+      # TODO return Torch::Device in Torch.rb
+      Torch.device(parameter.parameters[0].device)
+    end
   end
 
   class PreTrainedModel < Torch::NN::Module
```
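The new `ModuleUtilsMixin#device` reports where a model currently lives by looking at its first parameter via the private `get_parameter_device` helper. A rough usage sketch; the auto class and checkpoint name are assumptions for illustration, not taken from this diff:

```ruby
require "transformers-rb"

# Illustrative checkpoint; any model the gem can load behaves the same way.
model = Transformers::AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

device = model.device # device of the first parameter, wrapped with Torch.device
puts device.type      # e.g. "cpu"
```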
data/lib/transformers/pipelines/base.rb CHANGED

```diff
@@ -123,6 +123,22 @@ module Transformers
         end
       end
 
+      if @framework == "pt"
+        if device.is_a?(String)
+          @device = Torch.device(device)
+        else
+          # TODO update default in 0.2.0
+          @device = @model.device
+        end
+      else
+        raise Todo
+      end
+
+      # TODO Fix eql? for Torch::Device in Torch.rb
+      if @device.type != @model.device.type || @device.index != @model.device.index
+        @model.to(@device)
+      end
+
       @call_count = 0
       @batch_size = kwargs.delete(:batch_size)
       @num_workers = kwargs.delete(:num_workers)
```
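In effect a pipeline now honors an explicit device string, otherwise adopts the device the model's parameters are already on, and calls `@model.to` only when the two differ. A hedged sketch, assuming the pipeline constructor forwards a `device:` keyword as this hunk implies (task and model id are illustrative):

```ruby
require "transformers-rb"

# Explicit device as a string: wrapped with Torch.device by the code above.
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2", device: "cpu")

# No device given: the pipeline falls back to the model's own device rather
# than assuming CPU, so an already-moved model is left where it is.
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
```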
```diff
@@ -133,6 +149,24 @@ module Transformers
       @model.dtype
     end
 
+    def _ensure_tensor_on_device(inputs, device)
+      # TODO move
+      inputs = inputs.to_h if inputs.is_a?(BatchEncoding)
+
+      if inputs.is_a?(ModelOutput)
+        inputs.instance_variable_get(:@data).transform_values! { |tensor| _ensure_tensor_on_device(tensor, device) }
+        inputs
+      elsif inputs.is_a?(Hash)
+        inputs.transform_values { |tensor| _ensure_tensor_on_device(tensor, device) }
+      elsif inputs.is_a?(Array)
+        inputs.map { |item| _ensure_tensor_on_device(item, device) }
+      elsif inputs.is_a?(Torch::Tensor)
+        inputs.to(device)
+      else
+        inputs
+      end
+    end
+
     def check_model_type(supported_models)
       if !supported_models.is_a?(Array)
         supported_models_names = []
```
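The helper walks nested pipeline inputs and outputs and re-homes only `Torch::Tensor` leaves. An illustrative call; the tensors and the `pipeline` instance are hypothetical:

```ruby
inputs = {
  input_ids: Torch.tensor([[101, 7592, 102]]),
  attention_mask: Torch.tensor([[1, 1, 1]])
}

# Hash values are transformed recursively, each tensor is moved with #to,
# BatchEncoding is converted to a Hash first, and non-tensor values pass through.
moved = pipeline.send(:_ensure_tensor_on_device, inputs, Torch.device("cpu"))
```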
```diff
@@ -245,7 +279,10 @@ module Transformers
     end
 
     def forward(model_inputs, **forward_params)
-      _forward(model_inputs, **forward_params)
+      model_inputs = _ensure_tensor_on_device(model_inputs, @device)
+      model_outputs = _forward(model_inputs, **forward_params)
+      model_outputs = _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))
+      model_outputs
     end
 
     def run_single(inputs, preprocess_params, forward_params, postprocess_params)
```
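Taken together, every pipeline call now moves its preprocessed tensors to the pipeline's device before the model-specific `_forward` runs, then copies the outputs back to CPU so postprocessing always sees CPU tensors. A paraphrase of the new method with the flow annotated (not additional code in the gem):

```ruby
def forward(model_inputs, **forward_params)
  model_inputs = _ensure_tensor_on_device(model_inputs, @device)   # host -> pipeline device
  model_outputs = _forward(model_inputs, **forward_params)         # model-specific forward pass
  _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))     # device -> CPU for postprocessing
end
```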
data/lib/transformers/version.rb CHANGED

metadata CHANGED
```diff
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: transformers-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.4
+  version: 0.1.5
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-
+date: 2024-11-01 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
```
|