gliner 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +21 -0
- data/README.md +145 -0
- data/bin/console +81 -0
- data/gliner.gemspec +26 -0
- data/lib/gliner/classifier.rb +68 -0
- data/lib/gliner/config/classification_task.rb +74 -0
- data/lib/gliner/config/entity_types.rb +74 -0
- data/lib/gliner/config/field_spec.rb +87 -0
- data/lib/gliner/config_parser.rb +37 -0
- data/lib/gliner/inference/session_validator.rb +67 -0
- data/lib/gliner/inference.rb +124 -0
- data/lib/gliner/input_builder.rb +117 -0
- data/lib/gliner/model.rb +142 -0
- data/lib/gliner/pipeline.rb +64 -0
- data/lib/gliner/position_iteration.rb +21 -0
- data/lib/gliner/runners/classification_runner.rb +26 -0
- data/lib/gliner/runners/entity_runner.rb +19 -0
- data/lib/gliner/runners/prepared_task.rb +55 -0
- data/lib/gliner/runners/structured_runner.rb +36 -0
- data/lib/gliner/span_extractor.rb +117 -0
- data/lib/gliner/structured_extractor.rb +94 -0
- data/lib/gliner/task.rb +29 -0
- data/lib/gliner/tasks/classification.rb +101 -0
- data/lib/gliner/tasks/entity_extraction.rb +72 -0
- data/lib/gliner/tasks/json_extraction.rb +91 -0
- data/lib/gliner/text_processor.rb +42 -0
- data/lib/gliner/version.rb +5 -0
- data/lib/gliner.rb +97 -0
- metadata +150 -0
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'gliner/position_iteration'
|
|
4
|
+
|
|
5
|
+
module Gliner
  # Converts per-position label logits into scored character spans and
  # formats them for output.
  class SpanExtractor
    include PositionIteration

    # Candidates whose score is within this distance of the top score are
    # treated as ties by #choose_best_span, which then prefers the narrowest.
    SCORE_SIMILARITY_THRESHOLD = 0.02

    # @param inference [Object] scoring backend; must respond to #label_logit and #sigmoid
    # @param max_width [Integer] forwarded to PositionIteration#each_position_width
    def initialize(inference, max_width:)
      @inference = inference
      @max_width = max_width
    end

    # Finds candidate spans for every label.
    #
    # @param labels [Array] labels in logit order; a label's index selects its logit
    # @param threshold [Float] default minimum score (inclusive)
    # @param thresholds_by_label [Hash{String=>Float}, nil] per-label overrides,
    #   looked up by +label.to_s+
    # @return [Hash{String=>Array<Span>}] candidate spans keyed by +label.to_s+
    def extract_spans_by_label(logits, labels, label_positions, prepared, threshold: 0.5, thresholds_by_label: nil)
      labels.each_with_index.with_object({}) do |(label, label_index), out|
        out[label.to_s] = find_spans_for_label(
          logits: logits,
          label_index: label_index,
          label_positions: label_positions,
          prepared: prepared,
          threshold: threshold_for(label, threshold, thresholds_by_label)
        )
      end
    end

    # Scores every (pos, start_word, width) candidate yielded by
    # PositionIteration#each_position_width for one label.
    #
    # @return [Array<Span>] candidates scoring >= +threshold+ that map to
    #   non-blank text
    def find_spans_for_label(logits:, label_index:, label_positions:, prepared:, threshold:)
      seq_len = logits.first.length

      each_position_width(seq_len, prepared, @max_width).filter_map do |pos, start_word, width|
        score = calculate_span_score(logits, pos, width, label_index, label_positions)
        next if score < threshold

        build_span(prepared, start_word, start_word + width, score)
      end
    end

    # Picks a single span for a single-valued field.
    #
    # Among spans scoring within SCORE_SIMILARITY_THRESHOLD of the best score,
    # the narrowest (end - start) wins. Ties are broken by higher score, then
    # by shorter text.
    #
    # @param spans [Array<Span>]
    # @return [Span, nil] nil when +spans+ is empty
    def choose_best_span(spans)
      return nil if spans.empty?

      sorted = spans.sort_by { |s| [-s.score, (s.end - s.start), s.text.length] }
      best = sorted.first
      best_score = best.score
      near = spans_within_threshold(sorted, best_score)

      near.min_by { |s| [(s.end - s.start), -s.score, s.text.length] } || best
    end

    # Formats one (possibly nil) span; see #format_span for the output shape.
    def format_single_span(span, opts)
      format_span(span, opts)
    end

    # Greedily selects non-overlapping spans in descending score order and
    # formats them.
    #
    # @return [Array] formatted values, highest score first
    def format_spans(spans, opts)
      return [] if spans.empty?

      sorted = spans.sort_by { |s| -s.score }
      selected = []

      sorted.each do |span|
        overlaps = selected.any? { |s| span.overlaps?(s) }
        next if overlaps

        selected << span
      end

      selected.map { |span| format_span(span, opts) }
    end

    private

    # Sigmoid of the model logit for one candidate span and label.
    def calculate_span_score(logits, pos, width, label_index, label_positions)
      logit = @inference.label_logit(logits, pos, width, label_index, label_positions)
      @inference.sigmoid(logit)
    end

    # Leading run of +sorted_spans+ (sorted by descending score) whose score
    # is within SCORE_SIMILARITY_THRESHOLD of +best_score+.
    def spans_within_threshold(sorted_spans, best_score)
      sorted_spans.take_while { |span| (best_score - span.score) <= SCORE_SIMILARITY_THRESHOLD }
    end

    # Per-label threshold override (keyed by the label's string form) or the default.
    def threshold_for(label, default_threshold, thresholds_by_label)
      return default_threshold unless thresholds_by_label&.key?(label.to_s)

      thresholds_by_label.fetch(label.to_s)
    end

    # Builds a Span from word indices.
    #
    # NOTE(review): prepared.start_map / end_map are presumably word index ->
    # character offset lookups into prepared.original_text; confirm against
    # InputBuilder.
    #
    # @return [Span, nil] nil when either offset is missing or the text is blank
    def build_span(prepared, start_word, end_word, score)
      char_start = prepared.start_map[start_word]
      char_end = prepared.end_map[end_word]

      return nil if char_start.nil? || char_end.nil?

      raw = prepared.original_text[char_start...char_end].to_s
      text_span = raw.strip

      return nil if text_span.empty?

      # Shift the offsets past any whitespace removed by #strip so that
      # original_text[start...end] == text. The stripped text begins with a
      # non-whitespace character, so its first occurrence in +raw+ is the
      # post-strip position.
      span_start = char_start + raw.index(text_span)

      Span.new(text: text_span, score: score, start: span_start, end: span_start + text_span.length)
    end

    # Returns the bare text unless confidence or offsets are requested.
    # Otherwise returns a Hash with 'text' plus the requested 'confidence' and
    # 'start'/'end' keys. A nil span yields nil.
    def format_span(span, opts)
      return nil if span.nil?

      format_opts = FormatOptions.from(opts)
      return span.text unless format_opts.include_confidence || format_opts.include_spans

      result = { 'text' => span.text }
      result['confidence'] = span.score if format_opts.include_confidence

      if format_opts.include_spans
        result['start'] = span.start
        result['end'] = span.end
      end

      result
    end
  end
end
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  # Groups extracted spans into structured records ("instances") for JSON
  # extraction. One field acts as the anchor; each anchor span starts a record.
  class StructuredExtractor
    # @param span_extractor [Object] responds to #choose_best_span,
    #   #format_single_span and #format_spans
    def initialize(span_extractor)
      @span_extractor = span_extractor
    end

    # Returns a copy of +spans_by_label+ in which every field with a non-empty
    # :choices list is narrowed by #filter_spans_by_choices.
    #
    # @param spans_by_label [Hash{String=>Array<Span>}]
    # @param parsed_fields [Array<Hash>] field specs with :name and optional :choices
    # @return [Hash{String=>Array<Span>}] a new Hash; the input arrays are not mutated
    def apply_choice_filters(spans_by_label, parsed_fields)
      filtered = spans_by_label.transform_values(&:dup)

      parsed_fields.each do |field|
        next unless field[:choices]&.any?

        label = field[:name]
        spans = filtered.fetch(label, [])
        filtered[label] = filter_spans_by_choices(spans, field[:choices])
      end

      filtered
    end

    # Keeps spans whose text equals one of +choices+ after stripping and
    # downcasing both sides.
    #
    # @return [Array<Span>] the matches, or +spans+ unchanged when nothing matches
    def filter_spans_by_choices(spans, choices)
      return spans if spans.empty? || choices.nil? || choices.empty?

      normalized_choices = choices.map { |choice| normalize_choice(choice) }
      matched = spans.select { |span| normalized_choices.include?(normalize_choice(span.text)) }

      # Fall back to the unfiltered candidates rather than returning nothing.
      return spans if matched.empty?

      matched
    end

    # Builds one formatted object per anchor span.
    #
    # @return [Array<Hash>] [{}] when there are no fields; a single object
    #   holding all spans when the anchor field has no spans
    def build_structure_instances(parsed_fields, spans_by_label, opts)
      format_opts = FormatOptions.from(opts)
      anchor_field = anchor_field_for(parsed_fields)
      return [{}] unless anchor_field

      anchors = spans_by_label.fetch(anchor_field[:name], [])
      return [format_structure_object(parsed_fields, spans_by_label, format_opts)] if anchors.empty?

      instance_spans = build_instance_spans(anchors, spans_by_label)
      format_instances(parsed_fields, instance_spans, format_opts)
    end

    # Formats one record.
    #
    # :str fields get the single best span (or nil). All other dtypes get the
    # list produced by the span extractor's #format_spans.
    def format_structure_object(parsed_fields, spans_by_label, opts)
      obj = {}

      parsed_fields.each do |field|
        key = field[:name]
        spans = spans_by_label.fetch(key, [])

        if field[:dtype] == :str
          best = @span_extractor.choose_best_span(spans)
          obj[key] = @span_extractor.format_single_span(best, opts)
        else
          obj[key] = @span_extractor.format_spans(spans, opts)
        end
      end

      obj
    end

    private

    # First :str field, or the first field when none is :str.
    def anchor_field_for(parsed_fields)
      parsed_fields.find { |field| field[:dtype] == :str } || parsed_fields.first
    end

    # Distributes every span to the record of the last anchor starting at or
    # before it. Spans before the first anchor go to record 0.
    #
    # Overlapping anchor candidates (e.g. "John" and "John Smith") are first
    # collapsed to the highest-scoring one, so a single mention does not
    # spawn several records.
    def build_instance_spans(anchors, spans_by_label)
      anchors_sorted = non_overlapping_anchors(anchors).sort_by(&:start)
      instance_spans = anchors_sorted.map { Hash.new { |hash, key| hash[key] = [] } }

      spans_by_label.each do |label, spans|
        spans.each do |span|
          anchor_index = anchors_sorted.rindex { |anchor| anchor.start <= span.start } || 0
          instance_spans[anchor_index][label] << span
        end
      end

      instance_spans
    end

    # Greedy non-overlapping selection by descending score, mirroring
    # SpanExtractor#format_spans. Never empty when +anchors+ is non-empty.
    def non_overlapping_anchors(anchors)
      anchors.sort_by { |span| -span.score }.each_with_object([]) do |span, kept|
        kept << span unless kept.any? { |other| span.overlaps?(other) }
      end
    end

    # Formats each grouped record.
    def format_instances(parsed_fields, instance_spans, opts)
      instance_spans.map do |field_spans|
        format_structure_object(parsed_fields, field_spans, opts)
      end
    end

    # Comparison key for choice matching: stripped and downcased.
    def normalize_choice(value)
      value.to_s.strip.downcase
    end
  end
end
|
data/lib/gliner/task.rb
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  # Abstract base class for the inference tasks.
  #
  # Concrete tasks override the hook methods below. Each hook raises
  # NotImplementedError here, except the two predicate flags, which default
  # to false.
  class Task
    attr_reader :config_parser, :inference, :input_builder

    # Stores the collaborators shared by every task.
    def initialize(config_parser:, inference:, input_builder:)
      @input_builder = input_builder
      @inference = inference
      @config_parser = config_parser
    end

    # Turns user-supplied configuration into the task's parsed form.
    def parse_config(input)
      raise NotImplementedError
    end

    # Identifier of the task kind passed to inference.
    def task_type
      raise NotImplementedError
    end

    # Marker token prepended to each label in the prompt.
    def label_prefix
      raise NotImplementedError
    end

    # Prompt built from the parsed configuration.
    def build_prompt(parsed)
      raise NotImplementedError
    end

    # Label list extracted from the parsed configuration.
    def labels(parsed)
      raise NotImplementedError
    end

    # Converts raw model output into the task's result structure.
    def process_output(logits, parsed, prepared, options)
      raise NotImplementedError
    end

    # Whether input text should be normalized before encoding (default: no).
    def normalize_text?
      false
    end

    # Whether the task consumes classification-head logits (default: no).
    def needs_cls_logits?
      false
    end
  end
end
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Tasks
    # Text classification task: scores a configured label set for the input.
    class Classification < Task
      # @param classifier [Object] responds to #classification_scores and
      #   #format_classification
      def initialize(config_parser:, inference:, input_builder:, classifier:)
        super(config_parser: config_parser, inference: inference, input_builder: input_builder)

        @classifier = classifier
      end

      # Accepts either { name: ..., config: ... } (symbol or string keys) or a
      # single-entry { task_name => task_config } Hash.
      #
      # @return [Hash] the config parser's result merged with name: (String)
      # @raise [Error] if +input+ is not a Hash or matches neither shape
      def parse_config(input)
        raise Error, 'classification config must be a Hash' unless input.is_a?(Hash)

        name, task_config = extract_task_config(input)
        parsed = config_parser.parse_classification_task(name, task_config)
        parsed.merge(name: name.to_s)
      end

      def task_type
        Inference::TASK_TYPE_CLASSIFICATION
      end

      def label_prefix
        '[L]'
      end

      def build_prompt(parsed)
        config_parser.build_prompt(parsed[:name], parsed[:label_descs])
      end

      def labels(parsed)
        parsed[:labels]
      end

      # Delegates to the inference object's +has_cls_logits+ flag.
      def needs_cls_logits?
        inference.has_cls_logits
      end

      # Scores the labels and formats them with the classifier.
      #
      # @param options [Hash] :include_confidence (default false); :threshold
      #   (nil means use parsed[:cls_threshold]); optional :label_positions
      def process_output(logits, parsed, prepared, options)
        include_confidence = options.fetch(:include_confidence, false)
        threshold_override = options[:threshold]
        cls_threshold = threshold_override.nil? ? parsed[:cls_threshold] : threshold_override

        scores = classification_scores(logits, parsed, prepared, options)
        @classifier.format_classification(
          scores,
          labels: parsed[:labels],
          multi_label: parsed[:multi_label],
          include_confidence: include_confidence,
          cls_threshold: cls_threshold
        )
      end

      # Runs every { task_name => task_config } entry through +pipeline+.
      #
      # @return [Hash{String=>Object}] results keyed by task_name.to_s
      # @raise [Error] if +tasks_config+ is not a Hash
      def execute_all(pipeline, text, tasks_config, **options)
        raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)

        tasks_config.each_with_object({}) do |(task_name, task_config), results|
          parsed_config = { name: task_name, config: task_config }
          results[task_name.to_s] = pipeline.execute(self, text, parsed_config, **options)
        end
      end

      private

      # Returns [name, config] from either accepted config shape.
      def extract_task_config(input)
        name = input[:name] || input['name']
        task_config = input[:config] || input['config']

        return [name, task_config] if name && task_config
        return input.first if name.nil? && task_config.nil? && input.length == 1

        raise Error, 'classification config must include :name and :config'
      end

      # Uses classification-head logits when present. Otherwise falls back to
      # per-label scoring by the classifier.
      def classification_scores(logits, parsed, prepared, options)
        return cls_scores(logits, parsed) if cls_logits?(logits)

        # `||` rather than `fetch`: an explicit `label_positions: nil` must
        # still fall back to computing positions (matches EntityExtraction).
        label_positions = options[:label_positions] ||
                          inference.label_positions_for(prepared.word_ids, parsed[:labels].length)

        @classifier.classification_scores(
          logits,
          parsed[:labels],
          label_positions,
          prepared
        )
      end

      def cls_logits?(logits)
        logits.is_a?(Hash) && logits.key?(:cls_logits)
      end

      # Applies sigmoid to each logit for multi-label tasks, softmax otherwise.
      def cls_scores(logits, parsed)
        cls_logits = Array(logits.fetch(:cls_logits).fetch(0))
        parsed[:multi_label] ? cls_logits.map { |value| inference.sigmoid(value) } : inference.softmax(cls_logits)
      end
    end
  end
end
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Tasks
    # Named-entity extraction task.
    class EntityExtraction < Task
      # @param span_extractor [SpanExtractor] turns logits into scored spans
      def initialize(config_parser:, inference:, input_builder:, span_extractor:)
        super(config_parser: config_parser, inference: inference, input_builder: input_builder)
        @span_extractor = span_extractor
      end

      def parse_config(input)
        config_parser.parse_entity_types(input)
      end

      def task_type
        Inference::TASK_TYPE_ENTITIES
      end

      def label_prefix
        '[E]'
      end

      def build_prompt(parsed)
        config_parser.build_prompt('entities', parsed[:descriptions])
      end

      def labels(parsed)
        parsed[:labels]
      end

      # Extracts spans per label and formats them.
      #
      # @param options [Hash] :threshold (nil or absent means 0.5),
      #   :label_positions, plus the FormatOptions flags
      # @return [Hash] { 'entities' => { label => value } }
      def process_output(logits, parsed, prepared, options)
        # `||` so an explicit `threshold: nil` falls back to the default instead
        # of failing the `score < threshold` comparison in SpanExtractor.
        threshold = options[:threshold] || 0.5
        format_opts = FormatOptions.from(options)
        label_positions = options[:label_positions] ||
                          inference.label_positions_for(prepared.word_ids, parsed[:labels].length)

        spans_by_label = extract_spans(logits, parsed, prepared, label_positions, threshold)

        { 'entities' => format_entities(parsed, spans_by_label, format_opts) }
      end

      private

      def extract_spans(logits, parsed, prepared, label_positions, threshold)
        @span_extractor.extract_spans_by_label(
          logits,
          parsed[:labels],
          label_positions,
          prepared,
          threshold: threshold,
          thresholds_by_label: parsed[:thresholds]
        )
      end

      # Builds { label => value }. The dtype defaults to :list when a label
      # has no entry in parsed[:dtypes].
      def format_entities(parsed, spans_by_label, format_opts)
        dtypes = parsed[:dtypes] || {}

        parsed[:labels].each_with_object({}) do |label, entities|
          # SpanExtractor#extract_spans_by_label keys its result by label.to_s.
          spans = spans_by_label.fetch(label.to_s)
          entities[label] = format_entity_value(spans, dtypes.fetch(label, :list), format_opts)
        end
      end

      # :str labels yield the single best span; all other dtypes a list.
      def format_entity_value(spans, dtype, format_opts)
        if dtype == :str
          @span_extractor.format_single_span(@span_extractor.choose_best_span(spans), format_opts)
        else
          @span_extractor.format_spans(spans, format_opts)
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  module Tasks
    # Structured (JSON-like) extraction: fills named fields of one or more records.
    class JsonExtraction < Task
      # @param span_extractor [SpanExtractor] turns logits into scored spans
      # @param structured_extractor [StructuredExtractor] groups spans into records
      def initialize(config_parser:, inference:, input_builder:, span_extractor:, structured_extractor:)
        super(config_parser: config_parser, inference: inference, input_builder: input_builder)

        @span_extractor = span_extractor
        @structured_extractor = structured_extractor
      end

      # Accepts { name: ..., fields: [...] } (symbol or string keys) or a
      # single-entry { name => fields } Hash. Each field spec is stringified
      # before parsing.
      #
      # @raise [Error] if +input+ is not a Hash or matches neither shape
      def parse_config(input)
        raise Error, 'structure config must be a Hash' unless input.is_a?(Hash)

        name, fields = extract_structure_config(input)
        parsed_fields = Array(fields).map { |spec| config_parser.parse_field_spec(spec.to_s) }

        {
          name: name.to_s,
          parsed_fields: parsed_fields,
          labels: parsed_fields.map { |field| field[:name] },
          descriptions: config_parser.build_field_descriptions(parsed_fields)
        }
      end

      def task_type
        Inference::TASK_TYPE_JSON
      end

      def label_prefix
        '[C]'
      end

      def normalize_text?
        true
      end

      def build_prompt(parsed)
        config_parser.build_prompt(parsed[:name], parsed[:descriptions])
      end

      def labels(parsed)
        parsed[:labels]
      end

      # Extracts spans, applies per-field choice filters, and groups them into records.
      def process_output(logits, parsed, prepared, options)
        spans_by_label = extract_spans(logits, parsed, prepared, options)
        filtered_spans = @structured_extractor.apply_choice_filters(spans_by_label, parsed[:parsed_fields])
        format_opts = FormatOptions.from(options)

        @structured_extractor.build_structure_instances(parsed[:parsed_fields], filtered_spans, format_opts)
      end

      # Runs every { parent => fields } entry through +pipeline+.
      #
      # @return [Hash{String=>Object}] results keyed by parent.to_s
      # @raise [Error] if +structures_config+ is not a Hash
      def execute_all(pipeline, text, structures_config, **options)
        raise Error, 'structures must be a Hash' unless structures_config.is_a?(Hash)

        structures_config.each_with_object({}) do |(parent, fields), results|
          parsed_config = { name: parent, fields: fields }
          results[parent.to_s] = pipeline.execute(self, text, parsed_config, **options)
        end
      end

      private

      # Returns [name, fields] from either accepted config shape.
      def extract_structure_config(input)
        name = input[:name] || input['name']
        fields = input[:fields] || input['fields']

        return [name, fields] if name && fields
        return input.first if name.nil? && fields.nil? && input.length == 1

        raise Error, 'structure config must include :name and :fields'
      end

      # `||` (not `fetch`) for both options so that explicit nil values fall
      # back to the computed positions / 0.5 default, as EntityExtraction does
      # for label positions.
      def extract_spans(logits, parsed, prepared, options)
        label_positions = options[:label_positions] ||
                          inference.label_positions_for(prepared.word_ids, parsed[:labels].length)

        @span_extractor.extract_spans_by_label(
          logits,
          parsed[:labels],
          label_positions,
          prepared,
          threshold: options[:threshold] || 0.5
        )
      end
    end
  end
end
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Gliner
  # Text helpers around a tokenizer: sentence-final normalization, word
  # splitting with character offsets, and encoding of pre-split words.
  class TextProcessor
    # @param tokenizer [Object] responds to #encode(tokens, is_pretokenized:, add_special_tokens:)
    def initialize(tokenizer)
      @tokenizer = tokenizer
      @word_pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new
    end

    # Ensures the text is non-empty and ends in '.', '!' or '?'.
    #
    # @param text [#to_s]
    # @return [String] '.' for empty input; the text itself when it already
    #   ends in sentence punctuation; otherwise the text with '.' appended
    def normalize_text(text)
      normalized = text.to_s
      return '.' if normalized.empty?
      return normalized if normalized.end_with?('.', '!', '?')

      "#{normalized}."
    end

    # Splits text into lowercase words using the pre-tokenizer, skipping
    # empty pieces.
    #
    # @return [Array(Array<String>, Array, Array)] words, start offsets and
    #   end offsets as reported by the pre-tokenizer
    def split_words(text)
      words = []
      word_starts = []
      word_ends = []

      @word_pre_tokenizer.pre_tokenize_str(text.to_s).each do |piece, offsets|
        word = piece.to_s.downcase
        next if word.empty?

        start_pos, end_pos = offsets
        words << word
        word_starts << start_pos
        word_ends << end_pos
      end

      [words, word_starts, word_ends]
    end

    # Encodes already-split words without adding special tokens.
    #
    # @return [Hash] { ids:, word_ids: } taken from the tokenizer's encoding
    def encode_pretokenized(tokens)
      encoding = @tokenizer.encode(tokens, is_pretokenized: true, add_special_tokens: false)
      { ids: encoding.ids, word_ids: encoding.word_ids }
    end
  end
end
|
data/lib/gliner.rb
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'gliner/version'
|
|
4
|
+
require 'gliner/model'
|
|
5
|
+
require 'gliner/runners/prepared_task'
|
|
6
|
+
require 'gliner/runners/entity_runner'
|
|
7
|
+
require 'gliner/runners/structured_runner'
|
|
8
|
+
require 'gliner/runners/classification_runner'
|
|
9
|
+
|
|
10
|
+
module Gliner
  Error = Class.new(StandardError)

  # Model file loaded from a model directory when none is specified.
  DEFAULT_MODEL_FILE = 'model_int8.onnx'

  PreparedInput = Data.define(
    :input_ids,
    :word_ids,
    :attention_mask,
    :words_mask,
    :pos_to_word_index,
    :start_map,
    :end_map,
    :original_text,
    :text_len
  )

  # An extracted text span with its score and [start, end) offsets.
  Span = Data.define(:text, :score, :start, :end) do
    # True when the half-open ranges [start, end) intersect; touching spans do not overlap.
    def overlaps?(other)
      !(self.end <= other.start || start >= other.end)
    end
  end

  # Output formatting flags shared by the tasks.
  FormatOptions = Data.define(:include_confidence, :include_spans) do
    # Builds options from an options Hash (symbol keys), passes an existing
    # FormatOptions through, and treats nil as "all flags off".
    def self.from(input)
      return input if input.is_a?(FormatOptions)

      input ||= {}
      new(
        include_confidence: input.fetch(:include_confidence, false),
        include_spans: input.fetch(:include_spans, false)
      )
    end
  end

  class << self
    attr_writer :model

    # Loads a model from +dir+ and makes it the current model.
    def load(dir, file: DEFAULT_MODEL_FILE)
      self.model = Model.from_dir(dir, file: file)
    end

    # Current model. If none is set, lazily loads one from GLINER_MODEL_DIR
    # (nil when that variable is unset or empty).
    def model
      @model ||= model_from_env
    end

    # Like #model but raises Error when no model is available.
    def model!
      fetch_model!
    end

    # Builds a runner for +config+: structured configs get a StructuredRunner,
    # everything else an EntityRunner.
    def [](config)
      runner_for(config).new(fetch_model!, config)
    end

    def classify
      Runners::ClassificationRunner
    end

    private

    # Loads the model named by GLINER_MODEL_DIR / GLINER_MODEL_FILE. Empty
    # values are treated as unset.
    def model_from_env
      dir = ENV.fetch('GLINER_MODEL_DIR', nil)
      return nil if dir.nil? || dir.empty?

      file = ENV.fetch('GLINER_MODEL_FILE', nil)
      file = DEFAULT_MODEL_FILE if file.nil? || file.empty?
      Model.from_dir(dir, file: file)
    end

    def fetch_model!
      model = self.model
      return model if model

      raise Error, 'No model loaded. Call Gliner.load("/path/to/model") or set GLINER_MODEL_DIR.'
    end

    def runner_for(config)
      return Runners::StructuredRunner if structured_config?(config)

      Runners::EntityRunner
    end

    # A Hash is structured when it has both name and fields keys (symbol or
    # string), or when every value is an Array. An empty Hash counts as
    # structured because `all?` holds vacuously.
    def structured_config?(config)
      return false unless config.is_a?(Hash)

      keys = config.transform_keys(&:to_s)
      return true if keys.key?('name') && keys.key?('fields')

      config.values.all? { |value| value.is_a?(Array) }
    end
  end
end
|