gliner 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +21 -0
- data/README.md +145 -0
- data/bin/console +81 -0
- data/gliner.gemspec +26 -0
- data/lib/gliner/classifier.rb +68 -0
- data/lib/gliner/config/classification_task.rb +74 -0
- data/lib/gliner/config/entity_types.rb +74 -0
- data/lib/gliner/config/field_spec.rb +87 -0
- data/lib/gliner/config_parser.rb +37 -0
- data/lib/gliner/inference/session_validator.rb +67 -0
- data/lib/gliner/inference.rb +124 -0
- data/lib/gliner/input_builder.rb +117 -0
- data/lib/gliner/model.rb +142 -0
- data/lib/gliner/pipeline.rb +64 -0
- data/lib/gliner/position_iteration.rb +21 -0
- data/lib/gliner/runners/classification_runner.rb +26 -0
- data/lib/gliner/runners/entity_runner.rb +19 -0
- data/lib/gliner/runners/prepared_task.rb +55 -0
- data/lib/gliner/runners/structured_runner.rb +36 -0
- data/lib/gliner/span_extractor.rb +117 -0
- data/lib/gliner/structured_extractor.rb +94 -0
- data/lib/gliner/task.rb +29 -0
- data/lib/gliner/tasks/classification.rb +101 -0
- data/lib/gliner/tasks/entity_extraction.rb +72 -0
- data/lib/gliner/tasks/json_extraction.rb +91 -0
- data/lib/gliner/text_processor.rb +42 -0
- data/lib/gliner/version.rb +5 -0
- data/lib/gliner.rb +97 -0
- metadata +150 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: edbbf72af6499d1823db172793f3dc183f90e445d4ae45587dc1a0cd4005c15e
|
|
4
|
+
data.tar.gz: 04d0881c074e9e84617591768c1b53767d2f509bbb7497c6bc0920bf3bac96b3
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: 0ea1480d83e383534c50d14320f5cbfcc7c7a890101f9e0492cbd486dc36e05fbebaedaeea31ccf4f8b5efb3f553349cee1b25ad890457734443054d7e1dfb91
|
|
7
|
+
data.tar.gz: da1f798fd89bcaad28a41e67ec2f07e779e8a7d2681ac3d4790f2e27b2adfcfcc02f5c42e7b02c8eb7bc519fea4178f16bad6050576543dde241e8becd18dbec
|
data/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 elcuervo
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
data/README.md
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# Gliner
|
|
2
|
+
|
|
3
|
+

|
|
4
|
+
|
|
5
|
+
Minimal Ruby inference wrapper for the **GLiNER2** ONNX model, built on the `onnxruntime` and `tokenizers` gems.
|
|
6
|
+
|
|
7
|
+
## Install
|
|
8
|
+
|
|
9
|
+
```ruby
|
|
10
|
+
gem "gliner"
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## Usage
|
|
14
|
+
### entities
|
|
15
|
+
|
|
16
|
+
```ruby
|
|
17
|
+
require "gliner"
|
|
18
|
+
|
|
19
|
+
Gliner.load("path/to/gliner2-multi-v1")
|
|
20
|
+
|
|
21
|
+
text = "Apple CEO Tim Cook announced iPhone 15 in Cupertino yesterday."
|
|
22
|
+
labels = ["company", "person", "product", "location"]
|
|
23
|
+
|
|
24
|
+
model = Gliner[labels]
|
|
25
|
+
pp model[text]
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
Expected shape:
|
|
29
|
+
|
|
30
|
+
```ruby
|
|
31
|
+
{"entities"=>{"company"=>["Apple"], "person"=>["Tim Cook"], "product"=>["iPhone 15"], "location"=>["Cupertino"]}}
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
You can also pass per-entity configs:
|
|
35
|
+
|
|
36
|
+
```ruby
|
|
37
|
+
labels = {
|
|
38
|
+
"email" => { "description" => "Email addresses", "dtype" => "list", "threshold" => 0.9 },
|
|
39
|
+
"person" => { "description" => "Person names", "dtype" => "str" }
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
model = Gliner[labels]
|
|
43
|
+
pp model["Email John Doe at john@example.com.", threshold: 0.5]
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
### classification
|
|
47
|
+
|
|
48
|
+
```ruby
|
|
49
|
+
model = Gliner.classify[
|
|
50
|
+
{ "sentiment" => %w[positive negative neutral] }
|
|
51
|
+
]
|
|
52
|
+
|
|
53
|
+
result = model["This laptop has amazing performance but terrible battery life!"]
|
|
54
|
+
|
|
55
|
+
pp result
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Expected shape:
|
|
59
|
+
|
|
60
|
+
```ruby
|
|
61
|
+
{"sentiment"=>"negative"}
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
### structured extraction
|
|
65
|
+
|
|
66
|
+
```ruby
|
|
67
|
+
text = "iPhone 15 Pro Max with 256GB storage, A17 Pro chip, priced at $1199."
|
|
68
|
+
|
|
69
|
+
structure = {
|
|
70
|
+
"product" => [
|
|
71
|
+
"name::str::Full product name and model",
|
|
72
|
+
"storage::str::Storage capacity",
|
|
73
|
+
"processor::str::Chip or processor information",
|
|
74
|
+
"price::str::Product price with currency"
|
|
75
|
+
]
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
result = Gliner[structure][text]
|
|
79
|
+
|
|
80
|
+
pp result
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Expected shape:
|
|
84
|
+
|
|
85
|
+
```ruby
|
|
86
|
+
{"product"=>[{"name"=>"iPhone 15 Pro Max", "storage"=>"256GB", "processor"=>"A17 Pro chip", "price"=>"$1199"}]}
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
Choices can be included in field specs:
|
|
90
|
+
|
|
91
|
+
```ruby
|
|
92
|
+
result = Gliner[{ "order" => ["status::[pending|processing|shipped]::str"] }]["Status: shipped"]
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
## Model files
|
|
96
|
+
|
|
97
|
+
This implementation expects a directory containing:
|
|
98
|
+
|
|
99
|
+
- `tokenizer.json`
|
|
100
|
+
- `model.onnx` or `model_int8.onnx`
|
|
101
|
+
- (optional) `config.json` with `max_width` and `max_seq_len`
|
|
102
|
+
|
|
103
|
+
One publicly available ONNX export is `cuerbot/gliner2-multi-v1` on Hugging Face.
|
|
104
|
+
|
|
105
|
+
## Integration test
|
|
106
|
+
|
|
107
|
+
Downloads a public ONNX export and runs a real inference:
|
|
108
|
+
|
|
109
|
+
```bash
|
|
110
|
+
rake test:integration
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
To download the model separately (for console testing, etc):
|
|
114
|
+
|
|
115
|
+
```bash
|
|
116
|
+
rake model:pull
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
To reuse an existing local download:
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
GLINER_MODEL_DIR=/path/to/model_dir rake test:integration
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
## Console
|
|
126
|
+
|
|
127
|
+
Start an IRB session with the gem loaded:
|
|
128
|
+
|
|
129
|
+
```bash
|
|
130
|
+
rake console MODEL_DIR=/path/to/model_dir
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
If you omit `MODEL_DIR`, the console auto-downloads a public test model (configurable):
|
|
134
|
+
|
|
135
|
+
```bash
|
|
136
|
+
rake console
|
|
137
|
+
# or:
|
|
138
|
+
GLINER_REPO_ID=cuerbot/gliner2-multi-v1 GLINER_MODEL_FILE=model_int8.onnx rake console
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
Or:
|
|
142
|
+
|
|
143
|
+
```bash
|
|
144
|
+
ruby -Ilib bin/console /path/to/model_dir
|
|
145
|
+
```
|
data/bin/console
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
#!/usr/bin/env ruby
|
|
2
|
+
# frozen_string_literal: true
|
|
3
|
+
|
|
4
|
+
begin
|
|
5
|
+
require "bundler/setup"
|
|
6
|
+
rescue LoadError
|
|
7
|
+
end
|
|
8
|
+
|
|
9
|
+
require "gliner"
|
|
10
|
+
require "fileutils"
|
|
11
|
+
require "httpx"
|
|
12
|
+
require "irb"
|
|
13
|
+
|
|
14
|
+
DEFAULT_REPO_ID = "cuerbot/gliner2-multi-v1"
|
|
15
|
+
DEFAULT_MODEL_FILE = "model_int8.onnx"
|
|
16
|
+
|
|
17
|
+
# Makes sure the tokenizer, config and model file for +repo_id+ exist in a local
# cache directory under ../tmp/models (relative to this script), downloading any
# file that is missing or empty from Hugging Face.
#
# Returns the path of the cache directory.
def ensure_model_dir!(repo_id:, model_file:)
  # NOTE: String#tr maps character-for-character, so each '/' becomes a single '_'.
  cache_dir = File.expand_path("../tmp/models/#{repo_id.tr('/', '__')}", __dir__)
  FileUtils.mkdir_p(cache_dir)

  remote_root = "https://huggingface.co/#{repo_id}/resolve/main"

  ["tokenizer.json", "config.json", model_file].each do |name|
    target = File.join(cache_dir, name)
    # File.size? is nil for both a missing file and a zero-byte file.
    download("#{remote_root}/#{name}", target) unless File.size?(target)
  end

  cache_dir
end
|
|
32
|
+
|
|
33
|
+
# Downloads +url+ to the file at +dest+.
#
# Redirects are followed, because Hugging Face "resolve" URLs redirect large
# (LFS-backed) files to a CDN host and plain HTTPX.get does not follow them.
# The payload is written to a temporary "<dest>.part" file and renamed into
# place only once complete, so an interrupted download never leaves a
# truncated file that a later run would mistake for a finished one.
#
# Raises RuntimeError when the request fails at the transport level or the
# final response status is not 2xx.
def download(url, dest)
  response = HTTPX.plugin(:follow_redirects).get(url)

  # Connection/timeout failures are returned as an ErrorResponse rather than raised.
  if response.is_a?(HTTPX::ErrorResponse)
    raise "Download failed: #{url} (#{response.error.class}: #{response.error.message})"
  end

  raise "Download failed: #{url} (status: #{response.status})" unless response.status.between?(200, 299)

  partial = "#{dest}.part"
  File.binwrite(partial, response.body.to_s)
  File.rename(partial, dest)
end
|
|
39
|
+
|
|
40
|
+
# Resolve where the model comes from: an explicit directory (first CLI argument
# or GLINER_MODEL_DIR) wins; otherwise try to auto-download the configured
# Hugging Face export. A failed auto-download only warns, so the console still
# starts without a model.
model_dir = ARGV[0] || ENV["GLINER_MODEL_DIR"]
repo_id = ENV["GLINER_REPO_ID"] || DEFAULT_REPO_ID
model_file = ENV["GLINER_MODEL_FILE"] || DEFAULT_MODEL_FILE

if model_dir && !model_dir.empty?
  $gliner_model = Gliner.load(model_dir, file: model_file)
else
  begin
    # "fileutils" is already required at the top of this script.
    model_dir = ensure_model_dir!(repo_id: repo_id, model_file: model_file)
    $gliner_model = Gliner.load(model_dir, file: model_file)
  rescue => e
    warn "No model loaded (auto-download failed: #{e.class}: #{e.message})"
    warn "Set GLINER_MODEL_DIR to a local model dir, or set GLINER_REPO_ID/GLINER_MODEL_FILE for auto-download."
  end
end
|
|
56
|
+
|
|
57
|
+
# Console helper: runs extraction for +labels+ over +text+, forwarding +opts+.
# Raises RuntimeError when no model has been loaded into $gliner_model.
def gliner_extract(text, labels, **opts)
  return Gliner[labels][text, **opts] if $gliner_model

  raise "No model loaded (set GLINER_MODEL_DIR or pass a directory arg)"
end
|
|
61
|
+
|
|
62
|
+
# Console helper: runs the classification +tasks+ over +text+, forwarding +opts+.
# Raises RuntimeError when no model has been loaded into $gliner_model.
def gliner_classify(text, tasks, **opts)
  return Gliner.classify[tasks][text, **opts] if $gliner_model

  raise "No model loaded (set GLINER_MODEL_DIR or pass a directory arg)"
end
|
|
66
|
+
|
|
67
|
+
# Console helper: runs structured extraction for +structures+ over +text+,
# forwarding +opts+. Raises RuntimeError when no model has been loaded into
# $gliner_model.
def gliner_extract_json(text, structures, **opts)
  return Gliner[structures][text, **opts] if $gliner_model

  raise "No model loaded (set GLINER_MODEL_DIR or pass a directory arg)"
end
|
|
71
|
+
|
|
72
|
+
# Print a short usage banner, then hand control to IRB.
banner = [
  "Gliner console",
  "- model: #{$gliner_model ? "loaded" : "not loaded"}",
  "- helper: gliner_extract(text, labels, threshold: 0.5)",
  "- helper: gliner_classify(text, tasks)",
  "- helper: gliner_extract_json(text, structures)",
  "- model variable: $gliner_model",
  "- model dir: #{model_dir.inspect}"
]

# The auto-download settings are only shown when no model could be loaded.
unless $gliner_model
  banner << "- auto-download env: GLINER_REPO_ID=#{repo_id.inspect} GLINER_MODEL_FILE=#{model_file.inspect}"
end

puts banner

IRB.start(__FILE__)
|
data/gliner.gemspec
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative 'lib/gliner/version'
|
|
4
|
+
|
|
5
|
+
Gem::Specification.new do |gem|
  # Identity
  gem.name = 'gliner'
  gem.version = Gliner::VERSION
  gem.authors = ['elcuervo']

  # Description and licensing
  gem.summary = 'Schema-based information extraction (GLiNER2) via ONNX Runtime'
  gem.description = 'Basic Ruby inference wrapper for the GLiNER2 ONNX model.'
  gem.homepage = 'https://github.com/elcuervo/gliner'
  gem.license = 'MIT'
  gem.required_ruby_version = '>= 3.2'

  # Packaged files: everything under lib/ and bin/ plus the top-level docs.
  packaged = %w[lib/**/* bin/*].flat_map { |pattern| Dir.glob(pattern) }
  gem.files = packaged + %w[README.md LICENSE gliner.gemspec]
  gem.require_paths = ['lib']

  # Runtime dependencies
  { 'onnxruntime' => '~> 0.10', 'tokenizers' => '~> 0.6' }.each do |name, requirement|
    gem.add_dependency name, requirement
  end

  # Development dependencies
  {
    'httpx' => '~> 1.0',
    'rake' => '~> 13.0',
    'rspec' => '~> 3.13',
    'rubocop' => '~> 1.50'
  }.each do |name, requirement|
    gem.add_development_dependency name, requirement
  end
end
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'gliner/position_iteration'
|
|
4
|
+
|
|
5
|
+
module Gliner
  # Converts model logits into per-label scores and formats them as
  # single-label or multi-label classification results.
  class Classifier
    include PositionIteration

    def initialize(inference, max_width:)
      @inference = inference
      @max_width = max_width
    end

    # Returns one score per entry of +labels+ (same order): the highest
    # sigmoid-activated logit found for that label across the iterated
    # positions and widths.
    def classification_scores(logits, labels, label_positions, prepared)
      labels.each_index.map { |index| max_label_score(logits, index, label_positions, prepared) }
    end

    # Formats +scores+ for +labels+. In multi-label mode every label scoring at
    # least +cls_threshold+ is returned (or the best one when none qualifies);
    # otherwise only the best-scoring label is returned. With
    # +include_confidence+ each label is a Hash with 'label' and 'confidence'.
    def format_classification(scores, labels:, multi_label:, include_confidence:, cls_threshold:)
      ranked = sorted_label_scores(scores, labels)

      if multi_label
        format_multi_label(ranked, cls_threshold, include_confidence)
      else
        format_single_label(ranked.first, include_confidence)
      end
    end

    private

    def max_label_score(logits, label_index, label_positions, prepared)
      positions = each_position_width(logits[0].length, prepared, @max_width)

      probabilities = positions.map do |pos, _start_word, width|
        @inference.sigmoid(@inference.label_logit(logits, pos, width, label_index, label_positions))
      end

      # No candidate positions at all: rank this label below every real score.
      probabilities.max || -Float::INFINITY
    end

    # Pairs each score with its label and orders the pairs by descending score.
    def sorted_label_scores(scores, labels)
      pairs = scores.each_with_index.map { |score, position| [labels.fetch(position), score] }
      pairs.sort_by { |pair| -pair.last }
    end

    def format_multi_label(label_scores, cls_threshold, include_confidence)
      labels_above_threshold(label_scores, cls_threshold).map do |label, score|
        format_label(label, score, include_confidence)
      end
    end

    def labels_above_threshold(label_scores, threshold)
      passing = label_scores.select { |(_label, score)| score >= threshold }
      return passing unless passing.empty?

      # Nothing cleared the threshold: fall back to the top-ranked label, if any.
      best = label_scores.first
      best ? [best] : passing
    end

    def format_single_label(label_score, include_confidence)
      label, score = label_score
      format_label(label, score, include_confidence)
    end

    def format_label(label, score, include_confidence)
      return label unless include_confidence

      { 'label' => label, 'confidence' => score }
    end
  end
end
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Config
    # Normalizes a user-supplied classification task definition into a Hash with
    # the keys :labels, :multi_label, :cls_threshold and :label_descs.
    class ClassificationTask
      DEFAULT_THRESHOLD = 0.5

      class << self
        # Accepts either an Array of labels or a Hash. A Hash with a "labels"
        # key is a full task config; any other Hash maps label => description.
        # Raises Error for every other kind of +config+.
        def parse(task_name, config)
          return from_labels(config) if config.is_a?(Array)
          return from_hash(task_name, config) if config.is_a?(Hash)

          raise Error, "classification task #{task_name.inspect} must be an Array or Hash"
        end

        private

        def from_labels(labels)
          build_task(labels.map(&:to_s), {})
        end

        def from_hash(task_name, config)
          stringified = config.transform_keys(&:to_s)
          return from_described_labels(task_name, stringified) if stringified.key?('labels')

          build_task(config.keys.map(&:to_s), stringified.transform_values(&:to_s))
        end

        def from_described_labels(task_name, config_hash)
          labels, label_descs = parse_labels(task_name, config_hash['labels'])

          build_task(
            labels,
            label_descs,
            multi_label: !!config_hash['multi_label'],
            cls_threshold: threshold(config_hash['cls_threshold'])
          )
        end

        # Assembles the normalized task Hash shared by every parse path.
        def build_task(labels, label_descs, multi_label: false, cls_threshold: DEFAULT_THRESHOLD)
          {
            labels: labels,
            multi_label: multi_label,
            cls_threshold: cls_threshold,
            label_descs: label_descs
          }
        end

        # Returns [label names, label => description Hash] for an Array or Hash.
        def parse_labels(task_name, raw_labels)
          if raw_labels.is_a?(Array)
            [raw_labels.map(&:to_s), {}]
          elsif raw_labels.is_a?(Hash)
            names = raw_labels.keys.map(&:to_s)
            descriptions = raw_labels.transform_keys(&:to_s).transform_values(&:to_s)
            [names, descriptions]
          else
            raise Error, "classification task #{task_name.inspect} must include labels"
          end
        end

        # nil and false both mean "use the default threshold".
        def threshold(value)
          value.nil? || value == false ? DEFAULT_THRESHOLD : Float(value)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Config
    # Normalizes user-supplied entity type definitions into a Hash with the
    # keys :labels (Array of names), :descriptions, :dtypes and :thresholds
    # (each a Hash keyed by label name).
    class EntityTypes
      class << self
        # Accepts a single String/Symbol label, an Array of labels, or a Hash
        # mapping each label to nil, a description, or a per-label config Hash
        # ("description", "dtype", "threshold").
        #
        # Raises Error for any other input. A threshold that cannot be
        # converted by Float() raises ArgumentError/TypeError.
        def parse(entity_types)
          case entity_types
          when Array
            list_config(entity_types)
          when String, Symbol
            list_config([entity_types])
          when Hash
            hash_config(entity_types)
          else
            raise Error, 'labels must be a String, Array, or Hash'
          end
        end

        private

        def list_config(entity_types)
          {
            labels: entity_types.map(&:to_s),
            descriptions: {},
            dtypes: {},
            thresholds: {}
          }
        end

        def hash_config(entity_types)
          state = { labels: [], descriptions: {}, dtypes: {}, thresholds: {} }
          entity_types.each { |label, config| apply_config(state, label, config) }
          state
        end

        # Registers +label+ and folds its optional +config+ into +state+.
        def apply_config(state, label, config)
          name = label.to_s
          state[:labels] << name

          case config
          when nil
            nil
          when Hash
            apply_hash_config(state, name, config)
          else
            # Strings are used as-is; anything else is stringified.
            apply_description(state, name, config.to_s)
          end
        end

        def apply_hash_config(state, name, config)
          config_hash = config.transform_keys(&:to_s)
          description, dtype, threshold = config_hash.values_at('description', 'dtype', 'threshold')

          apply_description(state, name, description) if description
          apply_dtype(state, name, dtype) if dtype
          # A nil/false threshold means "not set"; previously it reached
          # Float(nil) and raised TypeError.
          apply_threshold(state, name, threshold) unless threshold.nil? || threshold == false
        end

        def apply_description(state, name, description)
          state[:descriptions][name] = description
        end

        # Any dtype other than "str" is treated as :list.
        def apply_dtype(state, name, dtype)
          state[:dtypes][name] = dtype.to_s == 'str' ? :str : :list
        end

        def apply_threshold(state, name, threshold)
          state[:thresholds][name] = Float(threshold)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Config
    # Parses "name::part::part" field specifications used for structured
    # extraction. Each part after the name is a dtype ("str" or "list"), a
    # bracketed choice list ("[a|b|c]"), or free-text description.
    class FieldSpec
      class << self
        # Returns a Hash with :name, :dtype (:str or :list, default :list),
        # :description (String or nil) and :choices (Array or nil).
        def parse(spec)
          name, *modifiers = spec.split('::')

          field = modifiers.each_with_object(build_field(name)) do |modifier, acc|
            apply_part(acc, modifier.to_s)
          end

          # :dtype_explicit is parsing bookkeeping only; it is not part of the result.
          field.tap { |parsed| parsed.delete(:dtype_explicit) }
        end

        # Maps field name => prompt description for every parsed field that
        # has a description and/or choices.
        def build_descriptions(parsed_fields)
          described = parsed_fields.filter_map do |field|
            text = description_for(field)
            [field[:name], text] if text
          end

          described.to_h
        end

        private

        def build_field(name)
          { name: name.to_s, dtype: :list, description: nil, choices: nil, dtype_explicit: false }
        end

        def apply_part(field, part)
          if dtype_part?(part)
            apply_dtype_part(field, part)
          elsif bracketed_list?(part)
            apply_choice_part(field, part)
          else
            append_description(field, part)
          end
        end

        def dtype_part?(part)
          part == 'str' || part == 'list'
        end

        def apply_dtype_part(field, part)
          dtype = part == 'str' ? :str : :list
          set_dtype(field, dtype)
        end

        # Choices imply a single string value unless a dtype was given explicitly.
        def apply_choice_part(field, part)
          field[:choices] = parse_choices(part)
          return if field[:dtype_explicit]

          field[:dtype] = :str
        end

        def set_dtype(field, dtype)
          field[:dtype] = dtype
          field[:dtype_explicit] = true
        end

        # Multiple description parts are re-joined with the '::' separator.
        def append_description(field, part)
          existing = field[:description]
          field[:description] = existing.nil? ? part : [existing, part].join('::')
        end

        def bracketed_list?(part)
          part.start_with?('[') && part.end_with?(']')
        end

        def parse_choices(part)
          part[1...-1].split('|').filter_map do |choice|
            trimmed = choice.strip
            trimmed unless trimmed.empty?
          end
        end

        def description_for(field)
          base = field[:description].to_s
          options = field[:choices]

          unless options&.any?
            return base.empty? ? nil : base
          end

          joined = options.join('|')
          base.empty? ? "Choices: #{joined}" : "#{base} (choices: #{joined})"
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'gliner/config/entity_types'
|
|
4
|
+
require 'gliner/config/classification_task'
|
|
5
|
+
require 'gliner/config/field_spec'
|
|
6
|
+
|
|
7
|
+
module Gliner
  # Thin facade over the Config::* parsers, plus prompt assembly.
  class ConfigParser
    def parse_entity_types(entity_types) = Config::EntityTypes.parse(entity_types)

    def parse_classification_task(task_name, config) = Config::ClassificationTask.parse(task_name, config)

    def parse_field_spec(spec) = Config::FieldSpec.parse(spec)

    def build_field_descriptions(parsed_fields) = Config::FieldSpec.build_descriptions(parsed_fields)

    # Appends a " [DESCRIPTION] label: description" segment to +base+ for every
    # label whose description is non-empty, in iteration order.
    def build_prompt(base, label_descriptions)
      label_descriptions.to_h.reduce(base.to_s) do |prompt, (label, description)|
        description.to_s.empty? ? prompt : "#{prompt} [DESCRIPTION] #{label}: #{description}"
      end
    end
  end
end
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  class Inference
    # Inspects a session's declared inputs and outputs and reports which
    # output to read and how labels are indexed for it.
    class SessionValidator
      EXPECTED_INPUTS_LOGITS = %w[
        input_ids attention_mask words_mask text_lengths task_type label_positions label_mask
      ].freeze

      EXPECTED_INPUTS_SPAN_LOGITS = %w[input_ids attention_mask].freeze

      class << self
        # Shorthand for .call.
        def [](session)
          call(session)
        end

        # Returns an IOValidation describing +session+. Raises Error when the
        # session exposes neither a "logits" nor a "span_logits" output, or is
        # missing an input required by the output it does expose.
        def call(session)
          input_names = session.inputs.map { |input| input[:name] }
          output_names = session.outputs.map { |output| output[:name] }

          validation = validation_for_outputs(output_names, input_names)
          output_name, label_index_mode = validation.fetch_values(:output_name, :label_index_mode)

          IOValidation.new(
            input_names: input_names,
            output_name: output_name,
            label_index_mode: label_index_mode,
            has_cls_logits: output_names.include?('cls_logits')
          )
        end

        private

        # "logits" takes precedence over "span_logits" when both are present.
        def validation_for_outputs(output_names, input_names)
          if output_names.include?('logits')
            validation_for_logits(input_names)
          elsif output_names.include?('span_logits')
            validation_for_span_logits(input_names)
          else
            raise Error, 'Model missing output: logits or span_logits'
          end
        end

        def validation_for_logits(input_names)
          ensure_expected_inputs!(EXPECTED_INPUTS_LOGITS, input_names)
          { output_name: 'logits', label_index_mode: :label_index }
        end

        def validation_for_span_logits(input_names)
          ensure_expected_inputs!(EXPECTED_INPUTS_SPAN_LOGITS, input_names)
          { output_name: 'span_logits', label_index_mode: :label_position }
        end

        def ensure_expected_inputs!(expected_inputs, input_names)
          missing = expected_inputs - input_names
          return if missing.empty?

          raise Error, "Model missing inputs: #{missing.join(', ')}"
        end
      end
    end
  end
end
|