informers 0.2.0 → 1.0.0

@@ -0,0 +1,240 @@
+ module Informers
+   module Utils
+     module Hub
+       class FileResponse
+         attr_reader :exists, :status
+
+         def initialize(file_path)
+           @file_path = file_path
+
+           @exists = File.exist?(file_path)
+           if @exists
+             @status = ["200", "OK"]
+           else
+             @status = ["404", "Not Found"]
+           end
+         end
+
+         def read
+           File.binread(@file_path)
+         end
+       end
+
+       def self.is_valid_url(string, protocols = nil, valid_hosts = nil)
+         begin
+           url = URI.parse(string)
+         rescue
+           return false
+         end
+         if protocols && !protocols.include?(url.scheme)
+           return false
+         end
+         if valid_hosts && !valid_hosts.include?(url.host)
+           return false
+         end
+         true
+       end
+
+       def self.get_file(url_or_path, progress_callback = nil, progress_info = {})
+         if !is_valid_url(url_or_path, ["http", "https"])
+           raise Error, "Invalid url"
+         else
+           headers = {}
+           headers["User-Agent"] = "informers/#{VERSION};"
+
+           # Check whether we are making a request to the Hugging Face Hub.
+           is_hfurl = is_valid_url(url_or_path, ["http", "https"], ["huggingface.co", "hf.co"])
+           if is_hfurl
+             # If an access token is present in the environment variables,
+             # we add it to the request headers.
+             token = ENV["HF_TOKEN"]
+             if token
+               headers["Authorization"] = "Bearer #{token}"
+             end
+           end
+           options = {}
+           if progress_callback
+             total_size = nil
+             options[:content_length_proc] = lambda do |size|
+               total_size = size
+               Utils.dispatch_callback(progress_callback, {status: "download"}.merge(progress_info).merge(total_size: size))
+             end
+             options[:progress_proc] = lambda do |size|
+               Utils.dispatch_callback(progress_callback, {status: "progress"}.merge(progress_info).merge(size: size, total_size: total_size))
+             end
+           end
+           URI.parse(url_or_path).open(**headers, **options)
+         end
+       end
+
+       class FileCache
+         attr_reader :path
+
+         def initialize(path)
+           @path = path
+         end
+
+         def match(request)
+           file_path = resolve_path(request)
+           file = FileResponse.new(file_path)
+
+           file if file.exists
+         end
+
+         def put(request, buffer)
+           output_path = resolve_path(request)
+
+           begin
+             FileUtils.mkdir_p(File.dirname(output_path))
+             File.binwrite(output_path, buffer)
+           rescue => e
+             warn "An error occurred while writing the file to cache: #{e}"
+           end
+         end
+
+         def resolve_path(request)
+           File.join(@path, request)
+         end
+       end
+
+       def self.try_cache(cache, *names)
+         names.each do |name|
+           begin
+             result = cache.match(name)
+             return result if result
+           rescue
+             next
+           end
+         end
+         nil
+       end
+
+       def self.get_model_file(path_or_repo_id, filename, fatal = true, **options)
+         # Initiate file retrieval
+         Utils.dispatch_callback(options[:progress_callback], {
+           status: "initiate",
+           name: path_or_repo_id,
+           file: filename
+         })
+
+         # If `cache_dir` is not specified, use the default cache directory
+         cache = FileCache.new(options[:cache_dir] || Informers.cache_dir)
+
+         revision = options[:revision] || "main"
+
+         request_url = path_join(path_or_repo_id, filename)
+
+         remote_url = path_join(
+           Informers.remote_host,
+           Informers.remote_path_template
+             .gsub("{model}", path_or_repo_id)
+             .gsub("{revision}", URI.encode_www_form_component(revision)),
+           filename
+         )
+
+         # Choose cache key for filesystem cache
+         # When using the main revision (default), we use the request URL as the cache key.
+         # If a specific revision is requested, we account for this in the cache key.
+         fs_cache_key = revision == "main" ? request_url : path_join(path_or_repo_id, revision, filename)
+
+         proposed_cache_key = fs_cache_key
+
+         resolved_path = cache.resolve_path(proposed_cache_key)
+
+         # Whether to cache the final response in the end.
+         to_cache_response = false
+
+         # A caching system is available, so we try to get the file from it.
+         response = try_cache(cache, proposed_cache_key)
+
+         cache_hit = !response.nil?
+
+         if response.nil?
+           # File is not cached, so we perform the request
+
+           if response.nil? || response.status[0] == "404"
+             # File not found locally. This means either:
+             # - The user has disabled local file access (`Informers.allow_local_models = false`)
+             # - the path is a valid HTTP url (`response.nil?`)
+             # - the path is not a valid HTTP url and the file is not present on the file system or local server (`response.status[0] == "404"`)
+
+             if options[:local_files_only] || !Informers.allow_remote_models
+               # User requested local files only, but the file is not found locally.
+               if fatal
+                 raise Error, "`local_files_only: true` or `Informers.allow_remote_models = false` and file was not found locally at #{resolved_path.inspect}."
+               else
+                 # File not found, but this file is optional.
+                 # TODO in future, cache the response?
+                 return nil
+               end
+             end
+
+             progress_info = {
+               name: path_or_repo_id,
+               file: filename
+             }
+
+             # File not found locally, so we try to download it from the remote server
+             response = get_file(remote_url, options[:progress_callback], progress_info)
+
+             if response.status[0] != "200"
+               # should not happen
+               raise Todo
+             end
+
+             # Success! We use the proposed cache key from earlier
+             cache_key = proposed_cache_key
+           end
+
+           to_cache_response = cache && !response.is_a?(FileResponse) && response.status[0] == "200"
+         end
+
+         buffer = response.read
+
+         if to_cache_response && cache_key && cache.match(cache_key).nil?
+           cache.put(cache_key, buffer)
+         end
+
+         Utils.dispatch_callback(options[:progress_callback], {
+           status: "done",
+           name: path_or_repo_id,
+           file: filename,
+           cache_hit: cache_hit
+         })
+
+         resolved_path
+       end
+
+       def self.get_model_json(model_path, file_name, fatal = true, **options)
+         buffer = get_model_file(model_path, file_name, fatal, **options)
+         if buffer.nil?
+           # Return empty object
+           return {}
+         end
+
+         JSON.load_file(buffer)
+       end
+
+       def self.path_join(*parts)
+         parts = parts.map.with_index do |part, index|
+           if index != 0
+             part = part.delete_prefix("/")
+           end
+           if index != parts.length - 1
+             part = part.delete_suffix("/")
+           end
+           part
+         end
+         parts.join("/")
+       end
+
+       def self.display_progress(filename, width, size, expected_size)
+         bar_width = width - (filename.length + 3)
+         progress = size / expected_size.to_f
+         done = (progress * bar_width).round
+         not_done = bar_width - done
+         "#{filename} |#{"█" * done}#{" " * not_done}|"
+       end
+     end
+   end
+ end
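A minimal usage sketch of the new Hub helpers above (not part of the diff; the repo id is illustrative, and it assumes `Informers.cache_dir`, `Informers.remote_host`, and `Informers.remote_path_template` are configured elsewhere in the gem):

    # Fetch a JSON file for a model repo, downloading it on a cache miss
    # and reading it from the on-disk cache on later calls.
    progress = ->(data) { puts data.inspect }  # receives "initiate"/"download"/"progress"/"done" events
    config = Informers::Utils::Hub.get_model_json(
      "example-org/example-model",  # hypothetical repo id
      "config.json",
      true,
      progress_callback: progress
    )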
@@ -0,0 +1,44 @@
+ module Informers
+   module Utils
+     def self.softmax(arr)
+       # Compute the maximum value in the array
+       max_val = arr.max
+
+       # Compute the exponentials of the array values
+       exps = arr.map { |x| Math.exp(x - max_val) }
+
+       # Compute the sum of the exponentials
+       sum_exps = exps.sum
+
+       # Compute the softmax values
+       softmax_arr = exps.map { |x| x / sum_exps }
+
+       softmax_arr
+     end
+
+     def self.sigmoid(arr)
+       arr.map { |v| 1 / (1 + Math.exp(-v)) }
+     end
+
+     def self.get_top_items(items, top_k = 0)
+       # if top == 0, return all
+
+       items = items
+         .map.with_index { |x, i| [i, x] } # Get indices ([index, score])
+         .sort_by { |v| -v[1] } # Sort by log probabilities
+
+       if !top_k.nil? && top_k > 0
+         items = items.slice(0, top_k) # Get top k items
+       end
+
+       items
+     end
+
+     def self.max(arr)
+       if arr.length == 0
+         raise Error, "Array must not be empty"
+       end
+       arr.map.with_index.max_by { |v, _| v }
+     end
+   end
+ end
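A quick worked example of the math helpers above (a sketch; the input scores are arbitrary):

    scores = [1.0, 2.0, 3.0]
    probs = Informers::Utils.softmax(scores)
    # => [0.0900, 0.2447, 0.6652] (approximately; sums to 1)
    Informers::Utils.sigmoid(scores)
    # => [0.7311, 0.8808, 0.9526] (approximately)
    Informers::Utils.get_top_items(probs, 2)
    # => [[2, 0.6652], [1, 0.2447]] ([index, score] pairs, highest first)
    Informers::Utils.max(scores)
    # => [3.0, 2] (maximum value and its index)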
@@ -0,0 +1,26 @@
+ module Informers
+   module Utils
+     def self.mean_pooling(last_hidden_state, attention_mask)
+       last_hidden_state.zip(attention_mask).map do |state, mask|
+         state[0].size.times.map do |k|
+           sum = 0.0
+           count = 0
+
+           state.zip(mask) do |s, m|
+             count += m
+             sum += s[k] * m
+           end
+
+           sum / count
+         end
+       end
+     end
+
+     def self.normalize(result)
+       result.map do |row|
+         norm = Math.sqrt(row.sum { |v| v * v })
+         row.map { |v| v / norm }
+       end
+     end
+   end
+ end
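A small sketch of the pooling helpers on made-up data (one sequence of two tokens with hidden size 2, where the second token is padding):

    last_hidden_state = [[[1.0, 2.0], [3.0, 4.0]]]  # [batch, tokens, hidden]
    attention_mask    = [[1, 0]]                    # masked token is ignored
    pooled = Informers::Utils.mean_pooling(last_hidden_state, attention_mask)
    # => [[1.0, 2.0]]
    Informers::Utils.normalize(pooled)
    # => [[0.4472, 0.8944]] (approximately; each row scaled to unit length)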
@@ -1,3 +1,3 @@
  module Informers
-   VERSION = "0.2.0"
+   VERSION = "1.0.0"
  end
data/lib/informers.rb CHANGED
@@ -1,13 +1,32 @@
  # dependencies
- require "blingfire"
- require "numo/narray"
  require "onnxruntime"
+ require "tokenizers"
+
+ # stdlib
+ require "io/console"
+ require "json"
+ require "open-uri"
+ require "stringio"
+ require "uri"

  # modules
- require "informers/feature_extraction"
- require "informers/fill_mask"
- require "informers/ner"
- require "informers/question_answering"
- require "informers/sentiment_analysis"
- require "informers/text_generation"
- require "informers/version"
+ require_relative "informers/utils/core"
+ require_relative "informers/utils/hub"
+ require_relative "informers/utils/math"
+ require_relative "informers/utils/tensor"
+ require_relative "informers/configs"
+ require_relative "informers/env"
+ require_relative "informers/model"
+ require_relative "informers/models"
+ require_relative "informers/tokenizers"
+ require_relative "informers/pipelines"
+
+ module Informers
+   class Error < StandardError; end
+
+   class Todo < Error
+     def message
+       "not implemented yet"
+     end
+   end
+ end
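The new top-level error classes can be rescued like any StandardError; a minimal sketch (the repo id is illustrative):

    begin
      Informers::Utils::Hub.get_model_file("example-org/example-model", "model.onnx", true, local_files_only: true)
    rescue Informers::Error => e
      # Raised, for example, when local_files_only is set and the file is not cached.
      warn e.message
    end
    # Informers::Todo is raised by code paths that are not implemented yet;
    # its #message always returns "not implemented yet".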
metadata CHANGED
@@ -1,57 +1,43 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: informers
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 1.0.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-09-06 00:00:00.000000000 Z
11
+ date: 2024-08-26 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
- name: blingfire
15
- requirement: !ruby/object:Gem::Requirement
16
- requirements:
17
- - - ">="
18
- - !ruby/object:Gem::Version
19
- version: 0.1.7
20
- type: :runtime
21
- prerelease: false
22
- version_requirements: !ruby/object:Gem::Requirement
23
- requirements:
24
- - - ">="
25
- - !ruby/object:Gem::Version
26
- version: 0.1.7
27
- - !ruby/object:Gem::Dependency
28
- name: numo-narray
14
+ name: onnxruntime
29
15
  requirement: !ruby/object:Gem::Requirement
30
16
  requirements:
31
17
  - - ">="
32
18
  - !ruby/object:Gem::Version
33
- version: '0'
19
+ version: '0.9'
34
20
  type: :runtime
35
21
  prerelease: false
36
22
  version_requirements: !ruby/object:Gem::Requirement
37
23
  requirements:
38
24
  - - ">="
39
25
  - !ruby/object:Gem::Version
40
- version: '0'
26
+ version: '0.9'
41
27
  - !ruby/object:Gem::Dependency
42
- name: onnxruntime
28
+ name: tokenizers
43
29
  requirement: !ruby/object:Gem::Requirement
44
30
  requirements:
45
31
  - - ">="
46
32
  - !ruby/object:Gem::Version
47
- version: 0.5.1
33
+ version: 0.5.2
48
34
  type: :runtime
49
35
  prerelease: false
50
36
  version_requirements: !ruby/object:Gem::Requirement
51
37
  requirements:
52
38
  - - ">="
53
39
  - !ruby/object:Gem::Version
54
- version: 0.5.1
40
+ version: 0.5.2
55
41
  description:
56
42
  email: andrew@ankane.org
57
43
  executables: []
@@ -62,23 +48,17 @@ files:
62
48
  - LICENSE.txt
63
49
  - README.md
64
50
  - lib/informers.rb
65
- - lib/informers/feature_extraction.rb
66
- - lib/informers/fill_mask.rb
67
- - lib/informers/ner.rb
68
- - lib/informers/question_answering.rb
69
- - lib/informers/sentiment_analysis.rb
70
- - lib/informers/text_generation.rb
51
+ - lib/informers/configs.rb
52
+ - lib/informers/env.rb
53
+ - lib/informers/model.rb
54
+ - lib/informers/models.rb
55
+ - lib/informers/pipelines.rb
56
+ - lib/informers/tokenizers.rb
57
+ - lib/informers/utils/core.rb
58
+ - lib/informers/utils/hub.rb
59
+ - lib/informers/utils/math.rb
60
+ - lib/informers/utils/tensor.rb
71
61
  - lib/informers/version.rb
72
- - vendor/LICENSE-bert.txt
73
- - vendor/LICENSE-blingfire.txt
74
- - vendor/LICENSE-gpt2.txt
75
- - vendor/LICENSE-roberta.txt
76
- - vendor/bert_base_cased_tok.bin
77
- - vendor/bert_base_tok.bin
78
- - vendor/gpt2.bin
79
- - vendor/gpt2.i2w
80
- - vendor/roberta.bin
81
- - vendor/roberta.i2w
82
62
  homepage: https://github.com/ankane/informers
83
63
  licenses:
84
64
  - Apache-2.0
@@ -91,15 +71,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
91
71
  requirements:
92
72
  - - ">="
93
73
  - !ruby/object:Gem::Version
94
- version: '2.7'
74
+ version: '3.1'
95
75
  required_rubygems_version: !ruby/object:Gem::Requirement
96
76
  requirements:
97
77
  - - ">="
98
78
  - !ruby/object:Gem::Version
99
79
  version: '0'
100
80
  requirements: []
101
- rubygems_version: 3.3.7
81
+ rubygems_version: 3.5.11
102
82
  signing_key:
103
83
  specification_version: 4
104
- summary: State-of-the-art natural language processing for Ruby
84
+ summary: Fast transformer inference for Ruby
105
85
  test_files: []
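For reference, the updated runtime dependencies above correspond to gemspec declarations roughly like these (a sketch; informers.gemspec itself is not part of this diff):

    spec.add_dependency "onnxruntime", ">= 0.9"
    spec.add_dependency "tokenizers", ">= 0.5.2"
    spec.required_ruby_version = ">= 3.1"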
@@ -1,59 +0,0 @@
- # Copyright 2018 The HuggingFace Inc. team.
- # Copyright 2020 Andrew Kane.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- module Informers
-   class FeatureExtraction
-     def initialize(model_path)
-       tokenizer_path = File.expand_path("../../vendor/bert_base_cased_tok.bin", __dir__)
-       @tokenizer = BlingFire.load_model(tokenizer_path)
-       @model = OnnxRuntime::Model.new(model_path)
-     end
-
-     def predict(texts)
-       singular = !texts.is_a?(Array)
-       texts = [texts] if singular
-
-       # tokenize
-       input_ids =
-         texts.map do |text|
-           tokens = @tokenizer.text_to_ids(text, nil, 100) # unk token
-           tokens.unshift(101) # cls token
-           tokens << 102 # sep token
-           tokens
-         end
-
-       max_tokens = input_ids.map(&:size).max
-       attention_mask = []
-       input_ids.each do |ids|
-         zeros = [0] * (max_tokens - ids.size)
-
-         mask = ([1] * ids.size) + zeros
-         attention_mask << mask
-
-         ids.concat(zeros)
-       end
-
-       # infer
-       input = {
-         input_ids: input_ids,
-         attention_mask: attention_mask
-       }
-       output = @model.predict(input)
-       scores = output["output_0"] || output["last_hidden_state"]
-
-       singular ? scores.first : scores
-     end
-   end
- end
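For comparison, the removed 0.2.0 API was called with a local ONNX model path, roughly like this (a sketch based on the deleted class; the path is illustrative):

    model = Informers::FeatureExtraction.new("path/to/model.onnx")
    embeddings = model.predict("This is a sentence")  # also accepts an array of strings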
@@ -1,109 +0,0 @@
- # Copyright 2018 The HuggingFace Inc. team.
- # Copyright 2021 Andrew Kane.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- module Informers
-   class FillMask
-     def initialize(model_path)
-       encoder_path = File.expand_path("../../vendor/roberta.bin", __dir__)
-       @encoder = BlingFire.load_model(encoder_path, prefix: false)
-
-       decoder_path = File.expand_path("../../vendor/roberta.i2w", __dir__)
-       @decoder = BlingFire.load_model(decoder_path)
-
-       @model = OnnxRuntime::Model.new(model_path)
-     end
-
-     def predict(texts)
-       singular = !texts.is_a?(Array)
-       texts = [texts] if singular
-
-       mask_token = 50264
-
-       # tokenize
-       input_ids =
-         texts.map do |text|
-           tokens = @encoder.text_to_ids(text, nil, 3) # unk token
-
-           # add mask token
-           mask_sequence = [28696, 43776, 15698]
-           masks = []
-           (tokens.size - 2).times do |i|
-             masks << i if tokens[i..(i + 2)] == mask_sequence
-           end
-           masks.reverse.each do |mask|
-             tokens = tokens[0...mask] + [mask_token] + tokens[(mask + 3)..-1]
-           end
-
-           tokens.unshift(0) # cls token
-           tokens << 2 # sep token
-
-           tokens
-         end
-
-       max_tokens = input_ids.map(&:size).max
-       attention_mask = []
-       input_ids.each do |ids|
-         zeros = [0] * (max_tokens - ids.size)
-
-         mask = ([1] * ids.size) + zeros
-         attention_mask << mask
-
-         ids.concat(zeros)
-       end
-
-       input = {
-         input_ids: input_ids,
-         attention_mask: attention_mask
-       }
-
-       masked_index = input_ids.map { |v| v.each_index.select { |i| v[i] == mask_token } }
-       masked_index.each do |v|
-         raise "No mask_token (<mask>) found on the input" if v.size < 1
-         raise "More than one mask_token (<mask>) is not supported" if v.size > 1
-       end
-
-       res = @model.predict(input)
-       outputs = res["output_0"] || res["logits"]
-       batch_size = outputs.size
-
-       results = []
-       batch_size.times do |i|
-         result = []
-
-         logits = outputs[i][masked_index[i][0]]
-         values = logits.map { |v| Math.exp(v) }
-         sum = values.sum
-         probs = values.map { |v| v / sum }
-         res = probs.each_with_index.sort_by { |v| -v[0] }.first(5)
-
-         res.each do |(v, p)|
-           tokens = input[:input_ids][i].dup
-           tokens[masked_index[i][0]] = p
-           result << {
-             sequence: @decoder.ids_to_text(tokens),
-             score: v,
-             token: p,
-             # TODO figure out prefix space
-             token_str: @decoder.ids_to_text([p], skip_special_tokens: false)
-           }
-         end
-
-         results += [result]
-       end
-
-       singular ? results.first : results
-     end
-   end
- end
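Likewise, the removed FillMask class expected exactly one <mask> token per input (a sketch based on the deleted code; the path is illustrative):

    model = Informers::FillMask.new("path/to/model.onnx")
    model.predict("The capital of France is <mask>.")
    # => up to 5 hashes with :sequence, :score, :token, and :token_str, highest score first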