transformers-rb 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/lib/transformers/modeling_utils.rb +11 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +2 -0
- data/lib/transformers/pipelines/base.rb +38 -1
- data/lib/transformers/pipelines/reranking.rb +2 -1
- data/lib/transformers/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 815a531b2876ca2f66d9df65e13a08f42ab0b4a98d0d1ef459961fc3069faaa7
|
4
|
+
data.tar.gz: 842158dd34cd785efff7a6df7524acd019afc6ead7738cfd9ddf6efadb89d97d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 2a4fa39c9d32b4ba83eeafe128d10af0873a910807ad67870359e8bffcdf94bdeeb18dea8bc0b2372926ded909e8223a2603388297aad357bca26ea0c6eb0baa
|
7
|
+
data.tar.gz: bfac7d5c91e91bbbf1770ea71a9a9c69bca71c1e0ded3b0ca696156c99c331a5988879a3cf48731145db7be1d1c1b14dc7e573b728e1d1a123faa251ab56a560
|
data/CHANGELOG.md
CHANGED
@@ -56,6 +56,10 @@ module Transformers
|
|
56
56
|
end
|
57
57
|
|
58
58
|
module ModuleUtilsMixin
|
59
|
+
def device
|
60
|
+
get_parameter_device(self)
|
61
|
+
end
|
62
|
+
|
59
63
|
def get_extended_attention_mask(
|
60
64
|
attention_mask,
|
61
65
|
input_shape,
|
@@ -112,6 +116,13 @@ module Transformers
|
|
112
116
|
|
113
117
|
head_mask
|
114
118
|
end
|
119
|
+
|
120
|
+
private
|
121
|
+
|
122
|
+
def get_parameter_device(parameter)
|
123
|
+
# TODO return Torch::Device in Torch.rb
|
124
|
+
Torch.device(parameter.parameters[0].device)
|
125
|
+
end
|
115
126
|
end
|
116
127
|
|
117
128
|
class PreTrainedModel < Torch::NN::Module
|
@@ -123,6 +123,22 @@ module Transformers
|
|
123
123
|
end
|
124
124
|
end
|
125
125
|
|
126
|
+
if @framework == "pt"
|
127
|
+
if device.is_a?(String)
|
128
|
+
@device = Torch.device(device)
|
129
|
+
else
|
130
|
+
# TODO update default in 0.2.0
|
131
|
+
@device = @model.device
|
132
|
+
end
|
133
|
+
else
|
134
|
+
raise Todo
|
135
|
+
end
|
136
|
+
|
137
|
+
# TODO Fix eql? for Torch::Device in Torch.rb
|
138
|
+
if @device.type != @model.device.type || @device.index != @model.device.index
|
139
|
+
@model.to(@device)
|
140
|
+
end
|
141
|
+
|
126
142
|
@call_count = 0
|
127
143
|
@batch_size = kwargs.delete(:batch_size)
|
128
144
|
@num_workers = kwargs.delete(:num_workers)
|
@@ -133,6 +149,24 @@ module Transformers
|
|
133
149
|
@model.dtype
|
134
150
|
end
|
135
151
|
|
152
|
+
def _ensure_tensor_on_device(inputs, device)
|
153
|
+
# TODO move
|
154
|
+
inputs = inputs.to_h if inputs.is_a?(BatchEncoding)
|
155
|
+
|
156
|
+
if inputs.is_a?(ModelOutput)
|
157
|
+
inputs.instance_variable_get(:@data).transform_values! { |tensor| _ensure_tensor_on_device(tensor, device) }
|
158
|
+
inputs
|
159
|
+
elsif inputs.is_a?(Hash)
|
160
|
+
inputs.transform_values { |tensor| _ensure_tensor_on_device(tensor, device) }
|
161
|
+
elsif inputs.is_a?(Array)
|
162
|
+
inputs.map { |item| _ensure_tensor_on_device(item, device) }
|
163
|
+
elsif inputs.is_a?(Torch::Tensor)
|
164
|
+
inputs.to(device)
|
165
|
+
else
|
166
|
+
inputs
|
167
|
+
end
|
168
|
+
end
|
169
|
+
|
136
170
|
def check_model_type(supported_models)
|
137
171
|
if !supported_models.is_a?(Array)
|
138
172
|
supported_models_names = []
|
@@ -245,7 +279,10 @@ module Transformers
|
|
245
279
|
end
|
246
280
|
|
247
281
|
def forward(model_inputs, **forward_params)
|
248
|
-
|
282
|
+
model_inputs = _ensure_tensor_on_device(model_inputs, @device)
|
283
|
+
model_outputs = _forward(model_inputs, **forward_params)
|
284
|
+
model_outputs = _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))
|
285
|
+
model_outputs
|
249
286
|
end
|
250
287
|
|
251
288
|
def run_single(inputs, preprocess_params, forward_params, postprocess_params)
|
data/lib/transformers/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: transformers-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-11-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|