transformers-rb 0.1.4 → 0.1.5

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 16b409e954f6bcc45fd4f3f2db94dc92c87d47c4c06936162978cc7d7e54fc09
-  data.tar.gz: c40ea58b7531e89a041ce2782dea89536f85227903e2dc9a60113afe041bb9f7
+  metadata.gz: 815a531b2876ca2f66d9df65e13a08f42ab0b4a98d0d1ef459961fc3069faaa7
+  data.tar.gz: 842158dd34cd785efff7a6df7524acd019afc6ead7738cfd9ddf6efadb89d97d
 SHA512:
-  metadata.gz: 0576500ca9fe9379aae4c2cc050aa34c90eea7b2d5251b6139c48d88b5107086e6197aad5210bc74d7969a5bfa0458d31ce7faf561df69b2ae9b2a6400280ce0
-  data.tar.gz: b5b0a865acdd37bcde11571e365a1d39b473f3dcbbe737a082f37002f886fe38f0f57c77f44236f859939d1d3b1df9659bb524a483cd0b9048dbc4cd472a3355
+  metadata.gz: 2a4fa39c9d32b4ba83eeafe128d10af0873a910807ad67870359e8bffcdf94bdeeb18dea8bc0b2372926ded909e8223a2603388297aad357bca26ea0c6eb0baa
+  data.tar.gz: bfac7d5c91e91bbbf1770ea71a9a9c69bca71c1e0ded3b0ca696156c99c331a5988879a3cf48731145db7be1d1c1b14dc7e573b728e1d1a123faa251ab56a560
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+## 0.1.5 (2024-11-01)
+
+- Fixed error with pipelines when called more than 10 times
+- Fixed `device` option for pipelines
+- Fixed error with `reranking` pipeline
+
 ## 0.1.4 (2024-10-22)
 
 - Added `BertForSequenceClassification`
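
Note: a minimal usage sketch of the fixes above, assuming the pipeline factory forwards the `device` keyword; the task name and input text are illustrative, not from this diff:

    require "transformers-rb"

    # `device` passed as a string is now parsed into a Torch device
    classifier = Transformers.pipeline("sentiment-analysis", device: "cpu")

    # pipelines no longer raise after the 10th call
    20.times { classifier.("This release fixes a pipeline bug.") }
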
@@ -56,6 +56,10 @@ module Transformers
   end
 
   module ModuleUtilsMixin
+    def device
+      get_parameter_device(self)
+    end
+
     def get_extended_attention_mask(
       attention_mask,
       input_shape,
@@ -112,6 +116,13 @@ module Transformers
 
       head_mask
     end
+
+    private
+
+    def get_parameter_device(parameter)
+      # TODO return Torch::Device in Torch.rb
+      Torch.device(parameter.parameters[0].device)
+    end
   end
 
   class PreTrainedModel < Torch::NN::Module
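
Note: the new `device` accessor reports the device of the model's first parameter, so it assumes all parameters live on a single device. A sketch of its use; the model id is illustrative:

    model = Transformers::AutoModel.from_pretrained("distilbert-base-uncased")
    model.device # => Torch device of the first parameter, e.g. CPU
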
@@ -1213,4 +1213,6 @@ module Transformers
       end
     end
   end
+
+  XLMRobertaForSequenceClassification = XlmRoberta::XLMRobertaForSequenceClassification
 end
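
Note: the added constant aliases the namespaced class at the `Transformers` top level, so both names resolve to the same class object:

    Transformers::XLMRobertaForSequenceClassification.equal?(
      Transformers::XlmRoberta::XLMRobertaForSequenceClassification
    ) # => true
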
@@ -123,6 +123,22 @@ module Transformers
         end
       end
 
+      if @framework == "pt"
+        if device.is_a?(String)
+          @device = Torch.device(device)
+        else
+          # TODO update default in 0.2.0
+          @device = @model.device
+        end
+      else
+        raise Todo
+      end
+
+      # TODO Fix eql? for Torch::Device in Torch.rb
+      if @device.type != @model.device.type || @device.index != @model.device.index
+        @model.to(@device)
+      end
+
       @call_count = 0
       @batch_size = kwargs.delete(:batch_size)
       @num_workers = kwargs.delete(:num_workers)
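
Note: from the caller's side the two branches behave roughly as below, assuming the pipeline factory forwards `device` (task name illustrative):

    # string device: parsed with Torch.device, model moved if it differs
    Transformers.pipeline("sentiment-analysis", device: "cpu")

    # no device: the pipeline adopts the model's current device
    # (the TODO notes this default may change in 0.2.0)
    Transformers.pipeline("sentiment-analysis")
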
@@ -133,6 +149,24 @@ module Transformers
       @model.dtype
     end
 
+    def _ensure_tensor_on_device(inputs, device)
+      # TODO move
+      inputs = inputs.to_h if inputs.is_a?(BatchEncoding)
+
+      if inputs.is_a?(ModelOutput)
+        inputs.instance_variable_get(:@data).transform_values! { |tensor| _ensure_tensor_on_device(tensor, device) }
+        inputs
+      elsif inputs.is_a?(Hash)
+        inputs.transform_values { |tensor| _ensure_tensor_on_device(tensor, device) }
+      elsif inputs.is_a?(Array)
+        inputs.map { |item| _ensure_tensor_on_device(item, device) }
+      elsif inputs.is_a?(Torch::Tensor)
+        inputs.to(device)
+      else
+        inputs
+      end
+    end
+
     def check_model_type(supported_models)
       if !supported_models.is_a?(Array)
         supported_models_names = []
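
Note: `_ensure_tensor_on_device` recurses through `ModelOutput`, `Hash`, and `Array` containers and moves only `Torch::Tensor` leaves; everything else passes through unchanged. A sketch of the hash case, with illustrative values:

    # given `pipeline`, an instance of a Pipeline subclass
    inputs = {
      input_ids: Torch.tensor([[101, 2023, 102]]),
      attention_mask: Torch.tensor([[1, 1, 1]]),
      meta: "not a tensor" # returned as-is
    }
    pipeline._ensure_tensor_on_device(inputs, Torch.device("cpu"))
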
@@ -245,7 +279,10 @@ module Transformers
     end
 
     def forward(model_inputs, **forward_params)
-      _forward(model_inputs, **forward_params)
+      model_inputs = _ensure_tensor_on_device(model_inputs, @device)
+      model_outputs = _forward(model_inputs, **forward_params)
+      model_outputs = _ensure_tensor_on_device(model_outputs, Torch.device("cpu"))
+      model_outputs
     end
 
     def run_single(inputs, preprocess_params, forward_params, postprocess_params)
@@ -8,7 +8,8 @@ module Transformers
       @tokenizer.(
         [inputs[:query]] * inputs[:documents].length,
         text_pair: inputs[:documents],
-        return_tensors: @framework
+        return_tensors: @framework,
+        padding: true
       )
     end
 
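Note: one query is tokenized against N documents as a single batch, so documents of different lengths produce ragged sequences unless `padding: true` pads them to a common length; this is the `reranking` error fixed in this release. Usage sketch, assuming the call signature mirrors the `query`/`documents` inputs above:

    ranker = Transformers.pipeline("reranking")
    ranker.(
      "How many people live in London?",
      ["Around 9 million people live in London.", "London has a large financial district."]
    )
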
@@ -1,3 +1,3 @@
 module Transformers
-  VERSION = "0.1.4"
+  VERSION = "0.1.5"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: transformers-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.4
+  version: 0.1.5
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-10-22 00:00:00.000000000 Z
+date: 2024-11-01 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray