gliner 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: edbbf72af6499d1823db172793f3dc183f90e445d4ae45587dc1a0cd4005c15e
4
+ data.tar.gz: 04d0881c074e9e84617591768c1b53767d2f509bbb7497c6bc0920bf3bac96b3
5
+ SHA512:
6
+ metadata.gz: 0ea1480d83e383534c50d14320f5cbfcc7c7a890101f9e0492cbd486dc36e05fbebaedaeea31ccf4f8b5efb3f553349cee1b25ad890457734443054d7e1dfb91
7
+ data.tar.gz: da1f798fd89bcaad28a41e67ec2f07e779e8a7d2681ac3d4790f2e27b2adfcfcc02f5c42e7b02c8eb7bc519fea4178f16bad6050576543dde241e8becd18dbec
data/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 elcuervo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,145 @@
1
+ # Gliner
2
+
3
+ ![](https://images.unsplash.com/photo-1625768376503-68d2495d78c5?q=80&w=2225&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D)
4
+
5
+ Minimal Ruby inference wrapper for the **GLiNER2** ONNX model, built on the `onnxruntime` and `tokenizers` gems.
6
+
7
+ ## Install
8
+
9
+ ```ruby
10
+ gem "gliner"
11
+ ```
12
+
13
+ ## Usage
14
+ ### entities
15
+
16
+ ```ruby
17
+ require "gliner"
18
+
19
+ Gliner.load("path/to/gliner2-multi-v1")
20
+
21
+ text = "Apple CEO Tim Cook announced iPhone 15 in Cupertino yesterday."
22
+ labels = ["company", "person", "product", "location"]
23
+
24
+ model = Gliner[labels]
25
+ pp model[text]
26
+ ```
27
+
28
+ Expected shape:
29
+
30
+ ```ruby
31
+ {"entities"=>{"company"=>["Apple"], "person"=>["Tim Cook"], "product"=>["iPhone 15"], "location"=>["Cupertino"]}}
32
+ ```
33
+
34
+ You can also pass per-entity configs:
35
+
36
+ ```ruby
37
+ labels = {
38
+ "email" => { "description" => "Email addresses", "dtype" => "list", "threshold" => 0.9 },
39
+ "person" => { "description" => "Person names", "dtype" => "str" }
40
+ }
41
+
42
+ model = Gliner[labels]
43
+ pp model["Email John Doe at john@example.com.", threshold: 0.5]
44
+ ```
45
+
46
+ ### classification
47
+
48
+ ```ruby
49
+ model = Gliner.classify[
50
+ { "sentiment" => %w[positive negative neutral] }
51
+ ]
52
+
53
+ result = model["This laptop has amazing performance but terrible battery life!"]
54
+
55
+ pp result
56
+ ```
57
+
58
+ Expected shape:
59
+
60
+ ```ruby
61
+ {"sentiment"=>"negative"}
62
+ ```
63
+
64
+ ### structured extraction
65
+
66
+ ```ruby
67
+ text = "iPhone 15 Pro Max with 256GB storage, A17 Pro chip, priced at $1199."
68
+
69
+ structure = {
70
+ "product" => [
71
+ "name::str::Full product name and model",
72
+ "storage::str::Storage capacity",
73
+ "processor::str::Chip or processor information",
74
+ "price::str::Product price with currency"
75
+ ]
76
+ }
77
+
78
+ result = Gliner[structure][text]
79
+
80
+ pp result
81
+ ```
82
+
83
+ Expected shape:
84
+
85
+ ```ruby
86
+ {"product"=>[{"name"=>"iPhone 15 Pro Max", "storage"=>"256GB", "processor"=>"A17 Pro chip", "price"=>"$1199"}]}
87
+ ```
88
+
89
+ Choices can be included in field specs:
90
+
91
+ ```ruby
92
+ result = Gliner[{ "order" => ["status::[pending|processing|shipped]::str"] }]["Status: shipped"]
93
+ ```
94
+
95
+ ## Model files
96
+
97
+ This implementation expects a directory containing:
98
+
99
+ - `tokenizer.json`
100
+ - `model.onnx` or `model_int8.onnx`
101
+ - (optional) `config.json` with `max_width` and `max_seq_len`
102
+
103
+ One publicly available ONNX export is `cuerbot/gliner2-multi-v1` on Hugging Face.
104
+
105
+ ## Integration test
106
+
107
+ Downloads a public ONNX export and runs real inference against it:
108
+
109
+ ```bash
110
+ rake test:integration
111
+ ```
112
+
113
+ To download the model separately (e.g. for console testing):
114
+
115
+ ```bash
116
+ rake model:pull
117
+ ```
118
+
119
+ To reuse an existing local download:
120
+
121
+ ```bash
122
+ GLINER_MODEL_DIR=/path/to/model_dir rake test:integration
123
+ ```
124
+
125
+ ## Console
126
+
127
+ Start an IRB session with the gem loaded:
128
+
129
+ ```bash
130
+ rake console MODEL_DIR=/path/to/model_dir
131
+ ```
132
+
133
+ If you omit `MODEL_DIR`, the console auto-downloads a public test model (configurable):
134
+
135
+ ```bash
136
+ rake console
137
+ # or:
138
+ GLINER_REPO_ID=cuerbot/gliner2-multi-v1 GLINER_MODEL_FILE=model_int8.onnx rake console
139
+ ```
140
+
141
+ Or:
142
+
143
+ ```bash
144
+ ruby -Ilib bin/console /path/to/model_dir
145
+ ```
data/bin/console ADDED
@@ -0,0 +1,81 @@
1
+ #!/usr/bin/env ruby
2
+ # frozen_string_literal: true
3
+
4
+ begin
5
+ require "bundler/setup"
6
+ rescue LoadError
7
+ end
8
+
9
+ require "gliner"
10
+ require "fileutils"
11
+ require "httpx"
12
+ require "irb"
13
+
14
# Hugging Face repository and ONNX weight file used for auto-download when no
# local model directory is supplied.
DEFAULT_REPO_ID = 'cuerbot/gliner2-multi-v1'
DEFAULT_MODEL_FILE = 'model_int8.onnx'

# Ensures tokenizer.json, config.json and +model_file+ from +repo_id+ are
# present under ../tmp/models/<repo_id with each "/" turned into "_">,
# downloading every file that is missing or zero-length.
#
# @param repo_id [String] Hugging Face repository id, e.g. "owner/name"
# @param model_file [String] ONNX file name inside the repository
# @return [String] absolute path of the local model directory
def ensure_model_dir!(repo_id:, model_file:)
  target_dir = File.expand_path("../tmp/models/#{repo_id.tr('/', '_')}", __dir__)
  FileUtils.mkdir_p(target_dir)

  base_url = "https://huggingface.co/#{repo_id}/resolve/main"

  ['tokenizer.json', 'config.json', model_file].each do |name|
    local_path = File.join(target_dir, name)
    # File.size? is nil for both a missing and an empty file.
    download("#{base_url}/#{name}", local_path) unless File.size?(local_path)
  end

  target_dir
end
32
+
33
# Fetches +url+ and stores the response body at +dest+.
#
# Hugging Face "resolve" URLs typically answer LFS-backed files (such as the
# .onnx weights) with a redirect to a CDN. Plain HTTPX does not follow
# redirects, so the follow_redirects plugin is enabled here.
#
# The body is written to a sibling ".part" file and then renamed into place.
# An interrupted write therefore never leaves a truncated, non-empty +dest+
# that a later run would mistake for a finished download.
#
# @param url [String] absolute URL to fetch
# @param dest [String] local file path to create or overwrite
# @raise [RuntimeError] when the final response status is not 2xx
# @raise [StandardError] the underlying transport error (DNS, TLS, timeout, ...)
def download(url, dest)
  response = HTTPX.plugin(:follow_redirects).get(url)
  # Transport-level failures come back as an ErrorResponse; re-raise the
  # real error instead of probing it for an HTTP status.
  response.raise_for_status if response.is_a?(HTTPX::ErrorResponse)
  raise "Download failed: #{url} (status: #{response.status})" unless response.status.between?(200, 299)

  partial = "#{dest}.part"
  File.binwrite(partial, response.body.to_s)
  File.rename(partial, dest)
end
39
+
40
# Resolve where the model comes from.
# An explicit directory wins: the first CLI argument, then GLINER_MODEL_DIR.
# Empty values are skipped, so an empty argument does not shadow the
# environment variable. Otherwise GLINER_REPO_ID / GLINER_MODEL_FILE
# (or the defaults) are auto-downloaded.
model_dir = [ARGV[0], ENV['GLINER_MODEL_DIR']].find { |dir| dir && !dir.empty? }
repo_id = ENV['GLINER_REPO_ID'] || DEFAULT_REPO_ID
model_file = ENV['GLINER_MODEL_FILE'] || DEFAULT_MODEL_FILE

if model_dir
  $gliner_model = Gliner.load(model_dir, file: model_file)
else
  # Auto-download is best effort: on failure the console still starts,
  # without a model, and prints how to point it at one.
  begin
    model_dir = ensure_model_dir!(repo_id: repo_id, model_file: model_file)
    $gliner_model = Gliner.load(model_dir, file: model_file)
  rescue StandardError => e
    warn "No model loaded (auto-download failed: #{e.class}: #{e.message})"
    warn 'Set GLINER_MODEL_DIR to a local model dir, or set GLINER_REPO_ID/GLINER_MODEL_FILE for auto-download.'
  end
end
56
+
57
# Raises a friendly error when the console started without a model, instead
# of letting the call fail somewhere inside Gliner.
def ensure_gliner_model!
  raise 'No model loaded (set GLINER_MODEL_DIR or pass a directory arg)' unless $gliner_model
end

# Entity extraction helper: gliner_extract(text, labels, threshold: 0.5)
# +labels+ is anything Gliner[...] accepts (Array, String, or Hash of configs).
def gliner_extract(text, labels, **opts)
  ensure_gliner_model!
  Gliner[labels][text, **opts]
end

# Classification helper: gliner_classify(text, { "task" => %w[a b] })
def gliner_classify(text, tasks, **opts)
  ensure_gliner_model!
  Gliner.classify[tasks][text, **opts]
end

# Structured-extraction helper: gliner_extract_json(text, { "product" => ["name::str"] })
def gliner_extract_json(text, structures, **opts)
  ensure_gliner_model!
  Gliner[structures][text, **opts]
end
71
+
72
# Startup banner: model status, available helpers and, when nothing was
# loaded, the auto-download settings that were tried.
banner = [
  'Gliner console',
  "- model: #{$gliner_model ? 'loaded' : 'not loaded'}",
  '- helper: gliner_extract(text, labels, threshold: 0.5)',
  '- helper: gliner_classify(text, tasks)',
  '- helper: gliner_extract_json(text, structures)',
  '- model variable: $gliner_model',
  "- model dir: #{model_dir.inspect}"
]
banner << "- auto-download env: GLINER_REPO_ID=#{repo_id.inspect} GLINER_MODEL_FILE=#{model_file.inspect}" unless $gliner_model
puts banner

# IRB.start parses ARGV for its own options and treats a bare argument as a
# script to run. The model directory we already consumed from ARGV[0] would
# otherwise be handed to IRB, so drop our arguments first.
ARGV.clear
IRB.start(__FILE__)
data/gliner.gemspec ADDED
@@ -0,0 +1,26 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative 'lib/gliner/version'
4
+
5
Gem::Specification.new do |s|
  s.name = 'gliner'
  s.version = Gliner::VERSION
  s.authors = ['elcuervo']

  s.summary = 'Schema-based information extraction (GLiNER2) via ONNX Runtime'
  s.description = 'Basic Ruby inference wrapper for the GLiNER2 ONNX model.'
  s.homepage = 'https://github.com/elcuervo/gliner'
  s.license = 'MIT'
  s.required_ruby_version = '>= 3.2'

  # Library sources, bin/ scripts and top-level docs.
  s.files = [*Dir.glob('lib/**/*'), *Dir.glob('bin/*'), 'README.md', 'LICENSE', 'gliner.gemspec']
  s.require_paths = ['lib']

  # Runtime: ONNX inference and tokenization.
  {
    'onnxruntime' => '~> 0.10',
    'tokenizers' => '~> 0.6'
  }.each { |gem_name, requirement| s.add_dependency(gem_name, requirement) }

  # Development only (bin/console uses httpx to fetch models).
  {
    'httpx' => '~> 1.0',
    'rake' => '~> 13.0',
    'rspec' => '~> 3.13',
    'rubocop' => '~> 1.50'
  }.each { |gem_name, requirement| s.add_development_dependency(gem_name, requirement) }
end
@@ -0,0 +1,68 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'gliner/position_iteration'
4
+
5
module Gliner
  # Reduces model logits to per-label classification scores and formats them
  # as a single-label or multi-label result.
  class Classifier
    include PositionIteration

    # @param inference [Object] helper providing #label_logit and #sigmoid
    # @param max_width [Integer] span-width limit forwarded to #each_position_width
    def initialize(inference, max_width:)
      @inference = inference
      @max_width = max_width
    end

    # Scores every label as the maximum sigmoid(label logit) over all
    # (position, width) pairs yielded by #each_position_width.
    #
    # @return [Array<Float>] one score per label, in label order
    #   (-Float::INFINITY for a label when no positions were yielded)
    def classification_scores(logits, labels, label_positions, prepared)
      labels.each_index.map do |label_index|
        max_label_score(logits, label_index, label_positions, prepared)
      end
    end

    # Formats the result in one of two ways:
    # - multi_label: every label scoring >= cls_threshold, or the single best
    #   label when none qualifies;
    # - otherwise: just the best label.
    # Labels become { 'label' => ..., 'confidence' => ... } hashes when
    # include_confidence is truthy.
    def format_classification(scores, labels:, multi_label:, include_confidence:, cls_threshold:)
      label_scores = sorted_label_scores(scores, labels)

      return format_multi_label(label_scores, cls_threshold, include_confidence) if multi_label

      format_single_label(label_scores.first, include_confidence)
    end

    private

    # NOTE(review): assumes logits[0] is indexed by sequence position so its
    # length is the sequence length; confirm against the model output.
    def max_label_score(logits, label_index, label_positions, prepared)
      seq_len = logits[0].length

      scores = each_position_width(seq_len, prepared, @max_width).map do |pos, _start_word, width|
        logit = @inference.label_logit(logits, pos, width, label_index, label_positions)
        @inference.sigmoid(logit)
      end

      # No candidate spans: use a sentinel that ranks below any real score.
      scores.max || -Float::INFINITY
    end

    # Returns [[label, score], ...] ordered by descending score.
    # Ruby's sort_by is not stable, so equal scores would otherwise come back
    # in arbitrary order and change which label wins. Ties are broken by the
    # label's original position.
    def sorted_label_scores(scores, labels)
      scores
        .each_with_index
        .sort_by { |score, index| [-score, index] }
        .map { |score, index| [labels.fetch(index), score] }
    end

    def format_multi_label(label_scores, cls_threshold, include_confidence)
      chosen = labels_above_threshold(label_scores, cls_threshold)

      chosen.map { |label, score| format_label(label, score, include_confidence) }
    end

    # Labels at or above +threshold+; falls back to the top-ranked label so a
    # multi-label result is never empty while any label exists.
    def labels_above_threshold(label_scores, threshold)
      above = label_scores.select { |_label, score| score >= threshold }
      above.empty? && label_scores.first ? [label_scores.first] : above
    end

    def format_single_label(label_score, include_confidence)
      label, score = label_score

      format_label(label, score, include_confidence)
    end

    def format_label(label, score, include_confidence)
      include_confidence ? { 'label' => label, 'confidence' => score } : label
    end
  end
end
@@ -0,0 +1,74 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  module Config
    # Normalizes one classification task definition into
    # { labels:, multi_label:, cls_threshold:, label_descs: }.
    #
    # Accepted forms:
    # - Array of labels:                    %w[positive negative]
    # - Hash of label => description:       { "positive" => "Good vibes" }
    # - Hash with a "labels" key (Array or Hash of label => description),
    #   plus optional "multi_label" and "cls_threshold" keys.
    class ClassificationTask
      DEFAULT_THRESHOLD = 0.5

      class << self
        # @param task_name [String, Symbol] task name, used in error messages
        # @param config [Array, Hash] task definition
        # @return [Hash] normalized task configuration
        # @raise [Gliner::Error] when the config is malformed, has no labels,
        #   or has a non-numeric cls_threshold
        def parse(task_name, config)
          case config
          when Array
            from_labels(task_name, config)
          when Hash
            from_hash(task_name, config)
          else
            raise Error, "classification task #{task_name.inspect} must be an Array or Hash"
          end
        end

        private

        def from_labels(task_name, labels)
          build_task(task_name, labels.map(&:to_s), {})
        end

        def from_hash(task_name, config)
          config_hash = config.transform_keys(&:to_s)

          return from_described_labels(task_name, config_hash) if config_hash.key?('labels')

          # Bare hash: keys are the labels, values their descriptions.
          build_task(task_name, config_hash.keys, config_hash.transform_values(&:to_s))
        end

        def from_described_labels(task_name, config_hash)
          labels, label_descs = parse_labels(task_name, config_hash['labels'])

          build_task(
            task_name, labels, label_descs,
            multi_label: config_hash['multi_label'] ? true : false,
            cls_threshold: threshold(task_name, config_hash['cls_threshold'])
          )
        end

        # A task without labels cannot produce a result, so fail fast here
        # rather than returning nil later from the classifier.
        def build_task(task_name, labels, label_descs, multi_label: false, cls_threshold: DEFAULT_THRESHOLD)
          raise Error, "classification task #{task_name.inspect} must include labels" if labels.empty?

          {
            labels: labels,
            multi_label: multi_label,
            cls_threshold: cls_threshold,
            label_descs: label_descs
          }
        end

        def parse_labels(task_name, raw_labels)
          case raw_labels
          when Array
            [raw_labels.map(&:to_s), {}]
          when Hash
            [raw_labels.keys.map(&:to_s), raw_labels.transform_keys(&:to_s).transform_values(&:to_s)]
          else
            raise Error, "classification task #{task_name.inspect} must include labels"
          end
        end

        # nil/false mean "use the default"; anything else must be numeric.
        def threshold(task_name, value)
          return DEFAULT_THRESHOLD if value.nil? || value == false

          Float(value)
        rescue ArgumentError, TypeError
          raise Error, "classification task #{task_name.inspect} has invalid cls_threshold #{value.inspect}"
        end
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  module Config
    # Normalizes an entity-type specification into
    # { labels:, descriptions:, dtypes:, thresholds: }, where the last three
    # are keyed by the stringified label.
    class EntityTypes
      class << self
        # @param entity_types [String, Symbol, Array, Hash]
        #   Array/String/Symbol give bare labels; a Hash maps each label to
        #   nil, a description, or a { description:, dtype:, threshold: } hash.
        # @return [Hash] normalized entity configuration
        # @raise [Gliner::Error] for any other input type
        def parse(entity_types)
          return list_config(entity_types) if entity_types.is_a?(Array)
          return list_config([entity_types]) if entity_types.is_a?(String) || entity_types.is_a?(Symbol)
          return hash_config(entity_types) if entity_types.is_a?(Hash)

          raise Error, 'labels must be a String, Array, or Hash'
        end

        private

        def empty_state(labels = [])
          { labels: labels, descriptions: {}, dtypes: {}, thresholds: {} }
        end

        def list_config(entity_types)
          empty_state(entity_types.map(&:to_s))
        end

        def hash_config(entity_types)
          entity_types.each_with_object(empty_state) do |(label, config), state|
            apply_config(state, label, config)
          end
        end

        # Registers +label+ and records whatever settings +config+ carries.
        # Non-Hash, non-nil values are treated as a description.
        def apply_config(state, label, config)
          name = label.to_s
          state[:labels] << name

          case config
          when nil then nil
          when String then state[:descriptions][name] = config
          when Hash then apply_hash_config(state, name, config)
          else state[:descriptions][name] = config.to_s
          end
        end

        # Any dtype other than "str" is treated as :list.
        def apply_hash_config(state, name, config)
          settings = config.transform_keys(&:to_s)
          description = settings['description']
          dtype = settings['dtype']

          state[:descriptions][name] = description if description
          state[:dtypes][name] = (dtype.to_s == 'str' ? :str : :list) if dtype
          state[:thresholds][name] = Float(settings['threshold']) if settings.key?('threshold')
        end
      end
    end
  end
end
@@ -0,0 +1,87 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  module Config
    # Parses compact structured-extraction field specs of the form
    # "name::part::part...". Each part after the name is one of:
    # - a dtype, "str" or "list" (the default is :list);
    # - a bracketed choice list, "[a|b|c]", which implies :str unless a dtype
    #   is given explicitly (in either order);
    # - free text, appended to the description (several text parts are
    #   re-joined with "::").
    # Whitespace around each "::"-separated part is ignored and empty parts
    # are skipped, so "tags :: list" and "x::::desc" parse as expected.
    class FieldSpec
      class << self
        # @param spec [String, Symbol] e.g. "price::str::Product price with currency"
        # @return [Hash] { name:, dtype:, description:, choices: }
        def parse(spec)
          name, *parts = spec.to_s.split('::').map(&:strip)
          field = build_field(name)

          parts.each { |part| apply_part(field, part) unless part.empty? }
          field.delete(:dtype_explicit)
          field
        end

        # Builds a { name => description } map from parsed fields, folding
        # choices into the text. Fields with neither a description nor
        # choices are omitted.
        # @param parsed_fields [Array<Hash>] hashes as returned by #parse
        # @return [Hash{String => String}]
        def build_descriptions(parsed_fields)
          parsed_fields.each_with_object({}) do |field, acc|
            description = description_for(field)
            acc[field[:name]] = description if description
          end
        end

        private

        def build_field(name)
          {
            name: name.to_s,
            dtype: :list,
            description: nil,
            choices: nil,
            dtype_explicit: false
          }
        end

        def apply_part(field, part)
          return apply_dtype_part(field, part) if dtype_part?(part)
          return apply_choice_part(field, part) if bracketed_list?(part)

          append_description(field, part)
        end

        def dtype_part?(part)
          %w[str list].include?(part)
        end

        def apply_dtype_part(field, part)
          set_dtype(field, part == 'str' ? :str : :list)
        end

        # Choices imply a single value unless the spec named a dtype itself.
        def apply_choice_part(field, part)
          field[:choices] = parse_choices(part)
          field[:dtype] = :str unless field[:dtype_explicit]
        end

        def set_dtype(field, dtype)
          field[:dtype] = dtype
          field[:dtype_explicit] = true
        end

        def append_description(field, part)
          field[:description] = [field[:description], part].compact.join('::')
        end

        def bracketed_list?(part)
          part.start_with?('[') && part.end_with?(']')
        end

        def parse_choices(part)
          part[1..-2].split('|').map(&:strip).reject(&:empty?)
        end

        def description_for(field)
          description = field[:description].to_s
          choices = Array(field[:choices])
          return (description.empty? ? nil : description) if choices.empty?

          choices_str = choices.join('|')
          description.empty? ? "Choices: #{choices_str}" : "#{description} (choices: #{choices_str})"
        end
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'gliner/config/entity_types'
4
+ require 'gliner/config/classification_task'
5
+ require 'gliner/config/field_spec'
6
+
7
module Gliner
  # Facade over the Config::* parsers, plus prompt assembly from a base
  # string and optional per-label descriptions.
  class ConfigParser
    def parse_entity_types(entity_types) = Config::EntityTypes.parse(entity_types)

    def parse_classification_task(task_name, config) = Config::ClassificationTask.parse(task_name, config)

    def parse_field_spec(spec) = Config::FieldSpec.parse(spec)

    def build_field_descriptions(parsed_fields) = Config::FieldSpec.build_descriptions(parsed_fields)

    # Appends " [DESCRIPTION] label: text" to +base+ for every label whose
    # description is non-empty, preserving the input order.
    #
    # @param base [#to_s] prompt prefix
    # @param label_descriptions [#to_h, nil] label => description pairs
    # @return [String]
    def build_prompt(base, label_descriptions)
      label_descriptions.to_h.reduce(base.to_s) do |prompt, (label, description)|
        next prompt if description.to_s.empty?

        "#{prompt} [DESCRIPTION] #{label}: #{description}"
      end
    end
  end
end
@@ -0,0 +1,67 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  class Inference
    # Checks an inference session's declared inputs/outputs and decides which
    # logits output to read and how labels are indexed in it.
    class SessionValidator
      EXPECTED_INPUTS_LOGITS = %w[
        input_ids
        attention_mask
        words_mask
        text_lengths
        task_type
        label_positions
        label_mask
      ].freeze

      EXPECTED_INPUTS_SPAN_LOGITS = %w[
        input_ids
        attention_mask
      ].freeze

      class << self
        def [](session) = call(session)

        # Prefers a "logits" output and falls back to "span_logits"; each
        # requires its own set of inputs.
        #
        # @param session [#inputs, #outputs] entries are hashes with a :name key
        # @return [IOValidation]
        # @raise [Gliner::Error] when no usable output exists or inputs are missing
        def call(session)
          inputs = session.inputs.map { |entry| entry[:name] }
          outputs = session.outputs.map { |entry| entry[:name] }

          output_name, mode = resolve_output(outputs, inputs)

          IOValidation.new(
            input_names: inputs,
            output_name: output_name,
            label_index_mode: mode,
            has_cls_logits: outputs.include?('cls_logits')
          )
        end

        private

        def resolve_output(output_names, input_names)
          if output_names.include?('logits')
            require_inputs!(EXPECTED_INPUTS_LOGITS, input_names)
            ['logits', :label_index]
          elsif output_names.include?('span_logits')
            require_inputs!(EXPECTED_INPUTS_SPAN_LOGITS, input_names)
            ['span_logits', :label_position]
          else
            raise Error, 'Model missing output: logits or span_logits'
          end
        end

        def require_inputs!(expected, actual)
          missing = expected.reject { |name| actual.include?(name) }
          return if missing.empty?

          raise Error, "Model missing inputs: #{missing.join(', ')}"
        end
      end
    end
  end
end