gliner 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +21 -0
- data/README.md +145 -0
- data/bin/console +81 -0
- data/gliner.gemspec +26 -0
- data/lib/gliner/classifier.rb +68 -0
- data/lib/gliner/config/classification_task.rb +74 -0
- data/lib/gliner/config/entity_types.rb +74 -0
- data/lib/gliner/config/field_spec.rb +87 -0
- data/lib/gliner/config_parser.rb +37 -0
- data/lib/gliner/inference/session_validator.rb +67 -0
- data/lib/gliner/inference.rb +124 -0
- data/lib/gliner/input_builder.rb +117 -0
- data/lib/gliner/model.rb +142 -0
- data/lib/gliner/pipeline.rb +64 -0
- data/lib/gliner/position_iteration.rb +21 -0
- data/lib/gliner/runners/classification_runner.rb +26 -0
- data/lib/gliner/runners/entity_runner.rb +19 -0
- data/lib/gliner/runners/prepared_task.rb +55 -0
- data/lib/gliner/runners/structured_runner.rb +36 -0
- data/lib/gliner/span_extractor.rb +117 -0
- data/lib/gliner/structured_extractor.rb +94 -0
- data/lib/gliner/task.rb +29 -0
- data/lib/gliner/tasks/classification.rb +101 -0
- data/lib/gliner/tasks/entity_extraction.rb +72 -0
- data/lib/gliner/tasks/json_extraction.rb +91 -0
- data/lib/gliner/text_processor.rb +42 -0
- data/lib/gliner/version.rb +5 -0
- data/lib/gliner.rb +97 -0
- metadata +150 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'gliner/inference/session_validator'
|
|
4
|
+
|
|
5
|
+
module Gliner
  # Low-level wrapper around an ONNX Runtime session for GLiNER models.
  #
  # Translates a Request (pre-tokenized tensors) into the feed dict the
  # session expects, picks which output tensors to fetch, and exposes the
  # math helpers (sigmoid/softmax) and label-position lookups used by the
  # span/classification extractors.
  class Inference
    # Task-type ids encoded into the `task_type` model input.
    TASK_TYPE_ENTITIES = 0
    TASK_TYPE_CLASSIFICATION = 1
    TASK_TYPE_JSON = 2

    # Schema layout: labels start after a 4-token prefix and occupy every
    # second combined-token position (prefix token + label token).
    SCHEMA_PREFIX_LENGTH = 4
    LABEL_SPACING = 2

    # Immutable bundle of everything needed for one forward pass.
    Request = Data.define(
      :input_ids,
      :attention_mask,
      :words_mask,
      :text_lengths,
      :task_type,
      :label_positions,
      :label_mask,
      :want_cls
    )

    IOValidation = Data.define(:input_names, :output_name, :label_index_mode, :has_cls_logits)

    attr_reader :label_index_mode, :has_cls_logits

    # @param session [Object] an already-loaded ONNX Runtime session; its
    #   declared inputs/outputs are inspected via SessionValidator.
    def initialize(session)
      @session = session

      validation = SessionValidator[session]

      @input_names = validation.input_names
      @output_name = validation.output_name
      @label_index_mode = validation.label_index_mode
      @has_cls_logits = validation.has_cls_logits
    end

    # Runs the model for one request.
    #
    # @param request [Request]
    # @return [Object] the raw logits tensor, or a { logits:, cls_logits: }
    #   hash when classification logits were requested and are available.
    def run(request)
      outputs = output_names_for(request)
      out = @session.run(outputs, build_inputs(request))
      format_outputs(out, outputs)
    end

    # Maps each label index to its token position within `word_ids`.
    #
    # @raise [Error] when a label position cannot be located (e.g. the
    #   schema prefix was truncated away).
    def label_positions_for(word_ids, label_count)
      label_count.times.map do |i|
        combined_idx = SCHEMA_PREFIX_LENGTH + (i * LABEL_SPACING)
        pos = word_ids.index(combined_idx)

        raise Error, "Could not locate label position at combined index #{combined_idx}" if pos.nil?

        pos
      end
    end

    # Reads a single span logit. Depending on the exported model, the label
    # axis is indexed either by token position of the label (span_logits
    # outputs) or by the plain label index.
    def label_logit(logits, pos, width, label_index, label_positions)
      if @label_index_mode == :label_position
        raise Error, 'Label positions required for span_logits output' if label_positions.nil?

        label_pos = label_positions.fetch(label_index)
        logits[0][pos][width][label_pos]
      else
        logits[0][pos][width][label_index]
      end
    end

    # Logistic function; maps any real value into (0, 1).
    def sigmoid(value)
      1.0 / (1.0 + Math.exp(-value))
    end

    # Numerically stable softmax (shifts by the max before exponentiating).
    # Returns [] for empty input instead of raising on the nil max.
    def softmax(values)
      return [] if values.empty?

      max_value = values.max
      exps = values.map { |value| Math.exp(value - max_value) }
      sum = exps.sum
      exps.map { |value| value / sum }
    end

    private

    # Assembles the feed dict: base tensors, plus token_type_ids when the
    # model declares that input, minus anything the session does not accept.
    def build_inputs(request)
      inputs = base_inputs(request)
      add_token_type_ids(inputs, request)
      filter_inputs(inputs)
    end

    def base_inputs(request)
      {
        input_ids: [request.input_ids],
        attention_mask: [request.attention_mask],
        words_mask: [request.words_mask],
        text_lengths: Array(request.text_lengths).flatten,
        task_type: [request.task_type],
        label_positions: [request.label_positions],
        label_mask: [request.label_mask]
      }
    end

    # Mutates and returns `inputs`; token_type_ids is all zeros (single segment).
    def add_token_type_ids(inputs, request)
      return inputs unless @input_names&.include?('token_type_ids')

      inputs[:token_type_ids] = [Array.new(request.input_ids.length, 0)]
      inputs
    end

    # Keeps only inputs the session actually declares (no-op when the
    # validator reported no input names).
    def filter_inputs(inputs)
      return inputs unless @input_names

      inputs.select { |name, _| @input_names.include?(name.to_s) }
    end

    def output_names_for(request)
      output_names = [@output_name]
      output_names << 'cls_logits' if request.want_cls && @has_cls_logits
      output_names
    end

    def format_outputs(out, output_names)
      return { logits: out.fetch(0), cls_logits: out.fetch(1) } if output_names.length > 1

      out.fetch(0)
    end
  end
end
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'set'
|
|
4
|
+
|
|
5
|
+
module Gliner
  # Builds model-ready token inputs from raw text plus a schema-token prefix.
  class InputBuilder
    # @param text_processor [TextProcessor] handles normalization/tokenization
    # @param max_seq_len [Integer] hard cap on the encoded sequence length
    def initialize(text_processor, max_seq_len:)
      @text_processor = text_processor
      @max_seq_len = max_seq_len
    end

    # Encodes `text` behind `schema_tokens` and returns a PreparedInput
    # carrying every tensor and map needed for inference and span decoding.
    def prepare(text, schema_tokens, already_normalized: false)
      clean_text = already_normalized ? text.to_s : @text_processor.normalize_text(text)
      words, start_map, end_map = @text_processor.split_words(clean_text)

      ids, word_ids = encode_tokens(schema_tokens, words)
      ids, word_ids = truncate_inputs(ids, word_ids, max_len: @max_seq_len)

      # Text words begin right after the schema tokens and the separator.
      first_text_word = schema_tokens.length + 1

      assemble_prepared_input(
        input_ids: ids,
        word_ids: word_ids,
        text_start_index: first_text_word,
        start_map: start_map,
        end_map: end_map,
        original_text: clean_text,
        text_len: infer_effective_text_len(word_ids, first_text_word, words.length)
      )
    end

    # Builds the schema prefix: "( [P] <prompt> ( <prefix> <label> ... ) )".
    def schema_tokens_for(prompt:, labels:, label_prefix:)
      header = ['(', '[P]', prompt.to_s, '(']
      labelled = labels.flat_map { |label| [label_prefix, label.to_s] }
      header + labelled + [')', ')']
    end

    private

    # Concatenates schema tokens, the separator, and text words, then encodes
    # the whole thing as a pre-tokenized sequence.
    def encode_tokens(schema_tokens, words)
      encoded = @text_processor.encode_pretokenized(schema_tokens + ['[SEP_TEXT]'] + words)
      encoded.values_at(:ids, :word_ids)
    end

    # Caps both parallel arrays at max_len tokens.
    def truncate_inputs(input_ids, word_ids, max_len:)
      if input_ids.length > max_len
        [input_ids.first(max_len), word_ids.first(max_len)]
      else
        [input_ids, word_ids]
      end
    end

    def assemble_prepared_input(input_ids:, word_ids:, text_start_index:, start_map:, end_map:, original_text:, text_len:)
      analysis = analyze_words(word_ids, text_start_index)

      PreparedInput.new(
        input_ids: input_ids,
        word_ids: word_ids,
        attention_mask: Array.new(input_ids.length, 1),
        words_mask: analysis[:mask],
        pos_to_word_index: analysis[:index_map],
        start_map: start_map,
        end_map: end_map,
        original_text: original_text,
        text_len: text_len
      )
    end

    # Single pass over word_ids producing the word-boundary mask and the
    # token-position -> text-word-index map.
    def analyze_words(word_ids, text_start_index)
      mask = Array.new(word_ids.length, 0)
      index_map = Array.new(word_ids.length)
      previous_word = nil
      visited = Set.new

      word_ids.each_with_index do |wid, pos|
        next if wid.nil?

        # Mark the first token of each text word (schema words stay 0).
        if wid != previous_word && wid >= text_start_index
          mask[pos] = 1
          previous_word = wid
        end

        # Only the first occurrence of each word id gets an index entry.
        unless visited.include?(wid)
          visited << wid
          index_map[pos] = wid - text_start_index if wid >= text_start_index
        end
      end

      { mask: mask, index_map: index_map }
    end

    # Number of text words that survived truncation, never exceeding the
    # full (pre-truncation) word count.
    def infer_effective_text_len(word_ids, text_start_index, full_text_len)
      text_word_ids = word_ids.compact.select { |wid| wid >= text_start_index }
      return full_text_len if text_word_ids.empty?

      [(text_word_ids.max - text_start_index) + 1, full_text_len].min
    end
  end
end
|
data/lib/gliner/model.rb
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'json'
|
|
4
|
+
require 'onnxruntime'
|
|
5
|
+
require 'tokenizers'
|
|
6
|
+
|
|
7
|
+
require 'gliner/text_processor'
|
|
8
|
+
require 'gliner/config_parser'
|
|
9
|
+
require 'gliner/inference'
|
|
10
|
+
require 'gliner/input_builder'
|
|
11
|
+
require 'gliner/span_extractor'
|
|
12
|
+
require 'gliner/classifier'
|
|
13
|
+
require 'gliner/structured_extractor'
|
|
14
|
+
require 'gliner/task'
|
|
15
|
+
require 'gliner/pipeline'
|
|
16
|
+
require 'gliner/tasks/entity_extraction'
|
|
17
|
+
require 'gliner/tasks/classification'
|
|
18
|
+
require 'gliner/tasks/json_extraction'
|
|
19
|
+
|
|
20
|
+
module Gliner
  # Facade wiring the tokenizer, the ONNX session, and the task objects
  # together, and exposing the public extraction entry points.
  class Model
    DEFAULT_MAX_WIDTH = 8
    DEFAULT_MAX_SEQ_LEN = 512

    # Loads a model directory containing the ONNX file, tokenizer.json, and
    # an optional config.json supplying max_width / max_seq_len overrides.
    def self.from_dir(dir, file: 'model_int8.onnx')
      settings_path = File.join(dir, 'config.json')
      settings = File.exist?(settings_path) ? JSON.parse(File.read(settings_path)) : {}

      new(
        model_path: File.join(dir, file),
        tokenizer_path: File.join(dir, 'tokenizer.json'),
        max_width: settings.fetch('max_width', DEFAULT_MAX_WIDTH),
        max_seq_len: settings.fetch('max_seq_len', DEFAULT_MAX_SEQ_LEN)
      )
    end

    def initialize(model_path:, tokenizer_path:, max_width: DEFAULT_MAX_WIDTH, max_seq_len: DEFAULT_MAX_SEQ_LEN)
      @model_path = model_path
      @tokenizer_path = tokenizer_path
      @max_width = Integer(max_width)
      @max_seq_len = Integer(max_seq_len)

      tok = Tokenizers.from_file(@tokenizer_path)
      onnx_session = OnnxRuntime::InferenceSession.new(@model_path)

      @text_processor = TextProcessor.new(tok)
      @inference = Inference.new(onnx_session)
    end

    # --- lazily built collaborators (all memoized) ---

    def config_parser
      @config_parser ||= ConfigParser.new
    end

    def pipeline
      @pipeline ||= Pipeline.new(text_processor: @text_processor, inference: @inference)
    end

    def input_builder
      @input_builder ||= InputBuilder.new(@text_processor, max_seq_len: @max_seq_len)
    end

    def span_extractor
      @span_extractor ||= SpanExtractor.new(@inference, max_width: @max_width)
    end

    def structured_extractor
      @structured_extractor ||= StructuredExtractor.new(span_extractor)
    end

    def classifier
      @classifier ||= Classifier.new(@inference, max_width: @max_width)
    end

    def entity_task
      @entity_task ||= Tasks::EntityExtraction.new(
        config_parser: config_parser,
        inference: @inference,
        input_builder: input_builder,
        span_extractor: span_extractor
      )
    end

    def classification_task
      @classification_task ||= Tasks::Classification.new(
        config_parser: config_parser,
        inference: @inference,
        input_builder: input_builder,
        classifier: classifier
      )
    end

    def json_task
      @json_task ||= Tasks::JsonExtraction.new(
        config_parser: config_parser,
        inference: @inference,
        input_builder: input_builder,
        span_extractor: span_extractor,
        structured_extractor: structured_extractor
      )
    end

    # Extracts entities of the given types from `text`.
    # Options: threshold (0.5), include_confidence (false), include_spans (false).
    def extract_entities(text, entity_types, **options)
      pipeline.execute(entity_task, text, entity_types, **span_options(options))
    end

    # Runs one or more classification tasks over `text`.
    # Options: include_confidence (false), threshold (only forwarded when given).
    def classify_text(text, tasks, **options)
      task_options = { include_confidence: options.fetch(:include_confidence, false) }
      # A nil threshold means "use the task default" — do not forward it.
      task_options[:threshold] = options[:threshold] unless options[:threshold].nil?

      classification_task.execute_all(pipeline, text, tasks, **task_options)
    end

    # Extracts structured JSON records described by `structures` from `text`.
    # Options: threshold (0.5), include_confidence (false), include_spans (false).
    def extract_json(text, structures, **options)
      json_task.execute_all(pipeline, text, structures, **span_options(options))
    end

    private

    # Shared defaulting for the span-based extraction entry points.
    def span_options(options)
      {
        threshold: options.fetch(:threshold, 0.5),
        include_confidence: options.fetch(:include_confidence, false),
        include_spans: options.fetch(:include_spans, false)
      }
    end
  end
end
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  # Orchestrates a single task execution: config parsing, text preparation,
  # input building, label-position lookup, inference, and output processing.
  class Pipeline
    def initialize(text_processor:, inference:)
      @text_processor = text_processor
      @inference = inference
    end

    # Runs `task` over `text` with the raw `config`; any extra options are
    # forwarded (with label_positions merged in) to the task's output step.
    def execute(task, text, config, **options)
      parsed = task.parse_config(config)
      labels = task.labels(parsed)
      prepared = prepare_input(task, prepare_text(task, text), parsed, labels)
      positions = @inference.label_positions_for(prepared.word_ids, labels.length)
      logits = @inference.run(build_request(task, prepared, labels, positions))

      task.process_output(logits, parsed, prepared, options.merge(label_positions: positions))
    end

    private

    # Normalizes only for tasks that opt in; otherwise just stringifies.
    def prepare_text(task, text)
      task.normalize_text? ? @text_processor.normalize_text(text) : text.to_s
    end

    # Builds the schema token prefix and encodes it together with the text.
    def prepare_input(task, prepared_text, parsed, labels)
      builder = task.input_builder
      schema_tokens = builder.schema_tokens_for(
        prompt: task.build_prompt(parsed),
        labels: labels,
        label_prefix: task.label_prefix
      )

      builder.prepare(prepared_text, schema_tokens, already_normalized: task.normalize_text?)
    end

    def build_request(task, prepared, labels, label_positions)
      Inference::Request.new(
        input_ids: prepared.input_ids,
        attention_mask: prepared.attention_mask,
        words_mask: prepared.words_mask,
        text_lengths: [prepared.text_len],
        task_type: task.task_type,
        label_positions: label_positions,
        label_mask: Array.new(labels.length, 1),
        want_cls: task.needs_cls_logits?
      )
    end
  end
end
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  # Shared iteration over candidate (token position, span width) pairs.
  module PositionIteration
    # Yields (pos, start_word, width) for every token position that maps to
    # a text word, for every width whose span end stays inside the text.
    # Returns an Enumerator when called without a block.
    def each_position_width(seq_len, prepared, max_width)
      return enum_for(:each_position_width, seq_len, prepared, max_width) unless block_given?

      seq_len.times do |pos|
        first_word = prepared.pos_to_word_index[pos]
        next if first_word.nil?

        max_width.times do |width|
          # Skip spans that would run past the effective text length.
          next if first_word + width >= prepared.text_len

          yield pos, first_word, width
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Runners
    # Runs a fixed set of named classification tasks against arbitrary texts.
    class ClassificationRunner
      # Convenience constructor using the globally configured model.
      def self.[](tasks)
        new(Gliner.model!, tasks)
      end

      # @param model [Gliner::Model]
      # @param tasks_config [Hash] task name => task configuration
      # @raise [Error] when tasks_config is not a Hash
      def initialize(model, tasks_config)
        raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)

        @tasks = tasks_config.each_with_object({}) do |(name, config), acc|
          classification = model.classification_task
          parsed = classification.parse_config(name: name, config: config)
          acc[name.to_s] = PreparedTask.new(classification, parsed)
        end
      end

      # Classifies `text` with every configured task; returns name => result.
      def [](text, **options)
        @tasks.each_with_object({}) do |(name, prepared), results|
          results[name] = prepared.call(text, **options)
        end
      end

      alias call []
    end
  end
end
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Runners
    # Wraps a single pre-parsed entity-extraction task for repeated calls.
    class EntityRunner
      def initialize(model, config)
        entity_task = model.entity_task
        @task = PreparedTask.new(entity_task, entity_task.parse_config(config))
      end

      # Extracts entities from `text`; returns only the 'entities' payload.
      def [](text, **options)
        @task.call(text, **options).fetch('entities')
      end

      alias call []
    end
  end
end
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Runners
    # Caches a task's parsed config, schema tokens, label mask, and label
    # positions so one task can be executed repeatedly with minimal per-call
    # work.
    class PreparedTask
      def initialize(task, parsed)
        @task = task
        @parsed = parsed
        @labels = task.labels(parsed)

        @schema_tokens = task.input_builder.schema_tokens_for(
          prompt: task.build_prompt(parsed),
          labels: @labels,
          label_prefix: task.label_prefix
        )

        @label_mask = Array.new(@labels.length, 1)
        @label_positions_template = precompute_label_positions
      end

      # Runs the task against `text`; extra options are forwarded (with
      # label_positions merged in) to the task's output processing.
      def call(text, **options)
        prepared = @task.input_builder.prepare(text, @schema_tokens)
        positions = label_positions_for(prepared)
        logits = @task.inference.run(request_for(prepared, positions))

        @task.process_output(logits, @parsed, prepared, options.merge(label_positions: positions))
      end

      private

      # Reuses the cached template unless it is unusable for this input
      # (missing entries or positions beyond the encoded length); then falls
      # back to a fresh lookup.
      def label_positions_for(prepared)
        template = @label_positions_template
        return template unless template.any? { |pos| pos.nil? || pos >= prepared.input_ids.length }

        @task.inference.label_positions_for(prepared.word_ids, @labels.length)
      end

      def request_for(prepared, label_positions)
        Inference::Request.new(
          input_ids: prepared.input_ids,
          attention_mask: prepared.attention_mask,
          words_mask: prepared.words_mask,
          text_lengths: [prepared.text_len],
          task_type: @task.task_type,
          label_positions: label_positions,
          label_mask: @label_mask,
          want_cls: @task.needs_cls_logits?
        )
      end

      # Label positions computed against a minimal '.' input; they stay
      # valid as long as the schema prefix is never truncated.
      def precompute_label_positions
        return [] if @labels.empty?

        probe = @task.input_builder.prepare('.', @schema_tokens)
        @task.inference.label_positions_for(probe.word_ids, @labels.length)
      end
    end
  end
end
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Runners
    # Runs one or more JSON/structured extraction tasks over texts.
    class StructuredRunner
      def initialize(model, config)
        @tasks = build_tasks(model, config)
      end

      # Extracts every configured structure from `text`; name => result.
      def [](text, **options)
        @tasks.each_with_object({}) do |(name, task), results|
          results[name] = task.call(text, **options)
        end
      end

      alias call []

      private

      # Accepts either a single structure config (detected by a :name/'name'
      # key) or a mapping of structure name => field list.
      def build_tasks(model, config)
        raise Error, 'structures must be a Hash' unless config.is_a?(Hash)

        single = config.key?(:name) || config.key?('name')
        return multiple_tasks(model, config) unless single

        parsed = model.json_task.parse_config(config)
        { parsed[:name].to_s => PreparedTask.new(model.json_task, parsed) }
      end

      def multiple_tasks(model, config)
        config.each_with_object({}) do |(name, fields), tasks|
          parsed = model.json_task.parse_config(name: name, fields: fields)
          tasks[name.to_s] = PreparedTask.new(model.json_task, parsed)
        end
      end
    end
  end
end
|