gliner 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,117 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'gliner/position_iteration'
4
+
5
module Gliner
  # Converts per-span model logits into scored, character-level Span objects
  # and formats them for the public result structures.
  class SpanExtractor
    include PositionIteration

    # Spans whose score is within this margin of the top span are treated as
    # ties by #choose_best_span, which then prefers the narrowest of them.
    SCORE_SIMILARITY_THRESHOLD = 0.02

    # @param inference [Object] must respond to #label_logit and #sigmoid
    # @param max_width [Integer] forwarded to #each_position_width as the span width limit
    def initialize(inference, max_width:)
      @inference = inference
      @max_width = max_width
    end

    # Extracts candidate spans for every label.
    #
    # @param threshold [Numeric] default score cut-off for labels without an override
    # @param thresholds_by_label [Hash, nil] per-label cut-offs, keyed by String or Symbol label
    # @return [Hash{String=>Array<Span>}] spans per label, keyed by label.to_s
    def extract_spans_by_label(logits, labels, label_positions, prepared, threshold: 0.5, thresholds_by_label: nil)
      labels.each_with_index.with_object({}) do |(label, label_index), out|
        out[label.to_s] = find_spans_for_label(
          logits: logits,
          label_index: label_index,
          label_positions: label_positions,
          prepared: prepared,
          threshold: threshold_for(label, threshold, thresholds_by_label)
        )
      end
    end

    # Scores every (position, width) candidate for one label and keeps those
    # scoring at least +threshold+ that map to non-empty text.
    #
    # @return [Array<Span>]
    def find_spans_for_label(logits:, label_index:, label_positions:, prepared:, threshold:)
      # NOTE(review): assumes logits.first is indexable by sequence position — confirm layout
      seq_len = logits.first.length

      each_position_width(seq_len, prepared, @max_width).filter_map do |pos, start_word, width|
        score = calculate_span_score(logits, pos, width, label_index, label_positions)
        next if score < threshold

        build_span(prepared, start_word, start_word + width, score)
      end
    end

    # Picks a single representative span.
    #
    # The highest score wins. Among all spans within SCORE_SIMILARITY_THRESHOLD
    # of it, the narrowest character range is preferred, then the higher score,
    # then the shorter text.
    #
    # @return [Span, nil] nil when +spans+ is empty
    def choose_best_span(spans)
      return nil if spans.empty?

      # The input index is the final key: Ruby's sort_by is not stable, so
      # otherwise fully-tied spans could come back in arbitrary order.
      sorted = spans.each_with_index
                    .sort_by { |span, index| [-span.score, span.end - span.start, span.text.length, index] }
                    .map(&:first)
      best = sorted.first
      near = spans_within_threshold(sorted, best.score)

      near.min_by { |span| [span.end - span.start, -span.score, span.text.length] } || best
    end

    # Formats one span (or nil) according to +opts+.
    def format_single_span(span, opts)
      format_span(span, opts)
    end

    # Greedily keeps non-overlapping spans in descending score order and
    # formats them. Equal scores keep their input order.
    #
    # @return [Array] formatted spans, highest score first
    def format_spans(spans, opts)
      return [] if spans.empty?

      # Index tie-break makes the greedy selection deterministic (sort_by is unstable).
      ranked = spans.each_with_index.sort_by { |span, index| [-span.score, index] }.map(&:first)
      selected = ranked.each_with_object([]) do |span, kept|
        kept << span unless kept.any? { |other| span.overlaps?(other) }
      end

      selected.map { |span| format_span(span, opts) }
    end

    private

    # Converts the raw label logit for a span into a probability via sigmoid.
    def calculate_span_score(logits, pos, width, label_index, label_positions)
      logit = @inference.label_logit(logits, pos, width, label_index, label_positions)
      @inference.sigmoid(logit)
    end

    # Leading run of a score-descending list whose scores are within
    # SCORE_SIMILARITY_THRESHOLD of +best_score+.
    def spans_within_threshold(sorted_spans, best_score)
      sorted_spans.take_while { |span| (best_score - span.score) <= SCORE_SIMILARITY_THRESHOLD }
    end

    # Per-label threshold lookup. String keys are checked first, then Symbol
    # keys, then the default is used.
    def threshold_for(label, default_threshold, thresholds_by_label)
      return default_threshold unless thresholds_by_label

      key = label.to_s
      return thresholds_by_label.fetch(key) if thresholds_by_label.key?(key)
      return thresholds_by_label.fetch(key.to_sym) if thresholds_by_label.key?(key.to_sym)

      default_threshold
    end

    # Maps word indices to character offsets and builds a Span.
    # Returns nil when either offset is unmapped or the stripped text is empty.
    # NOTE(review): end_word is used directly as an end_map index, i.e. it
    # presumably denotes the last covered word — confirm against PositionIteration.
    def build_span(prepared, start_word, end_word, score)
      char_start = prepared.start_map[start_word]
      char_end = prepared.end_map[end_word]

      return nil if char_start.nil? || char_end.nil?

      text_span = prepared.original_text[char_start...char_end].to_s.strip

      return nil if text_span.empty?

      Span.new(text: text_span, score: score, start: char_start, end: char_end)
    end

    # Returns the bare text when neither confidence nor offsets are requested.
    # Otherwise returns a Hash with 'text', plus 'confidence' and/or 'start'/'end'.
    def format_span(span, opts)
      return nil if span.nil?

      format_opts = FormatOptions.from(opts)
      return span.text unless format_opts.include_confidence || format_opts.include_spans

      result = { 'text' => span.text }
      result['confidence'] = span.score if format_opts.include_confidence

      if format_opts.include_spans
        result['start'] = span.start
        result['end'] = span.end
      end

      result
    end
  end
end
@@ -0,0 +1,94 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  # Turns flat per-label span lists into structured records. Each record is
  # built around one "anchor" span. Also applies per-field choice filtering.
  class StructuredExtractor
    # @param span_extractor [Object] responds to #choose_best_span,
    #   #format_single_span and #format_spans
    def initialize(span_extractor)
      @span_extractor = span_extractor
    end

    # Narrows the spans of each field that declares non-empty :choices to
    # those matching a choice (case- and whitespace-insensitive).
    #
    # @param spans_by_label [Hash] label => Array<Span>
    # @param parsed_fields [Array<Hash>] field specs with :name and optional :choices
    # @return [Hash] a fresh hash; the input span arrays are left untouched
    def apply_choice_filters(spans_by_label, parsed_fields)
      result = spans_by_label.transform_values { |spans| spans.dup }

      parsed_fields.each_with_object(result) do |field, acc|
        choices = field[:choices]
        next if choices.nil? || choices.none?

        name = field[:name]
        acc[name] = filter_spans_by_choices(acc.fetch(name, []), choices)
      end
    end

    # Keeps spans whose normalized text equals a normalized choice. Falls back
    # to the unfiltered list when nothing matches or there is nothing to filter.
    def filter_spans_by_choices(spans, choices)
      return spans if spans.empty? || choices.nil? || choices.empty?

      allowed = choices.map { |choice| normalize_choice(choice) }
      hits = spans.select { |span| allowed.include?(normalize_choice(span.text)) }
      hits.empty? ? spans : hits
    end

    # Builds one formatted record per anchor span.
    #
    # The anchor field is the first :str field, or else the first field.
    # Returns [{}] when there are no fields. Returns a single record built
    # from all spans when the anchor field has no spans.
    def build_structure_instances(parsed_fields, spans_by_label, opts)
      format_opts = FormatOptions.from(opts)
      anchor = anchor_field_for(parsed_fields)
      return [{}] unless anchor

      anchor_spans = spans_by_label.fetch(anchor[:name], [])
      if anchor_spans.empty?
        [format_structure_object(parsed_fields, spans_by_label, format_opts)]
      else
        format_instances(parsed_fields, build_instance_spans(anchor_spans, spans_by_label), format_opts)
      end
    end

    # Formats one record. :str fields get the single best span; every other
    # dtype gets the list of non-overlapping spans.
    def format_structure_object(parsed_fields, spans_by_label, opts)
      parsed_fields.each_with_object({}) do |field, obj|
        name = field[:name]
        candidates = spans_by_label.fetch(name, [])

        obj[name] =
          if field[:dtype] == :str
            @span_extractor.format_single_span(@span_extractor.choose_best_span(candidates), opts)
          else
            @span_extractor.format_spans(candidates, opts)
          end
      end
    end

    private

    # First :str field, falling back to the first field (nil when there are none).
    def anchor_field_for(parsed_fields)
      string_field = parsed_fields.find { |field| field[:dtype] == :str }
      string_field || parsed_fields.first
    end

    # Assigns every span to the last anchor (ordered by start) that begins at
    # or before it. Spans preceding all anchors go to the first anchor.
    def build_instance_spans(anchors, spans_by_label)
      ordered = anchors.sort_by(&:start)
      buckets = Array.new(ordered.length) { {} }

      spans_by_label.each do |label, spans|
        spans.each do |span|
          owner = ordered.rindex { |anchor| anchor.start <= span.start } || 0
          (buckets[owner][label] ||= []) << span
        end
      end

      buckets
    end

    # Formats each per-anchor bucket as a record.
    def format_instances(parsed_fields, instance_spans, opts)
      instance_spans.map { |bucket| format_structure_object(parsed_fields, bucket, opts) }
    end

    # Canonical form used when comparing choices with span text.
    def normalize_choice(value)
      value.to_s.strip.downcase
    end
  end
end
@@ -0,0 +1,29 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  # Abstract base class for a GLiNER task. Subclasses must override every
  # method below that delegates to #abstract_method!.
  class Task
    attr_reader :config_parser, :inference, :input_builder

    # @param config_parser [Object] parser collaborator used by subclasses
    # @param inference [Object] inference collaborator used by subclasses
    # @param input_builder [Object] stored and exposed via #input_builder
    def initialize(config_parser:, inference:, input_builder:)
      @config_parser = config_parser
      @inference = inference
      @input_builder = input_builder
    end

    # Parses the user-supplied task configuration. Abstract.
    def parse_config(input) = abstract_method!(__method__)

    # Identifier of the task type. Abstract.
    def task_type = abstract_method!(__method__)

    # Prompt prefix string for this task's labels. Abstract.
    def label_prefix = abstract_method!(__method__)

    # Builds the prompt from a parsed config. Abstract.
    def build_prompt(parsed) = abstract_method!(__method__)

    # Label list extracted from a parsed config. Abstract.
    def labels(parsed) = abstract_method!(__method__)

    # Converts model logits into the task's result structure. Abstract.
    def process_output(logits, parsed, prepared, options) = abstract_method!(__method__)

    # Hook: subclasses return true to request text normalization (default false).
    def normalize_text? = false

    # Hook: subclasses return true when they want CLS logits (default false).
    def needs_cls_logits? = false

    private

    # Raises NotImplementedError naming the concrete class and the missing
    # method, so the failing subclass is obvious from the error message.
    def abstract_method!(name)
      raise NotImplementedError, "#{self.class.name} must implement ##{name}"
    end
  end
end
@@ -0,0 +1,101 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  module Tasks
    # Text classification task. Scores are taken from CLS logits when the
    # model output provides them, otherwise they are computed by the classifier.
    class Classification < Task
      # @param classifier [Object] responds to #classification_scores and #format_classification
      def initialize(config_parser:, inference:, input_builder:, classifier:)
        super(config_parser: config_parser, inference: inference, input_builder: input_builder)

        @classifier = classifier
      end

      # Accepts either { name:, config: } (Symbol or String keys) or a
      # single-entry { task_name => config } hash.
      #
      # @raise [Error] when +input+ is not a Hash or has neither shape
      def parse_config(input)
        raise Error, 'classification config must be a Hash' unless input.is_a?(Hash)

        name, task_config = extract_task_config(input)
        parsed = config_parser.parse_classification_task(name, task_config)
        parsed.merge(name: name.to_s)
      end

      def task_type
        Inference::TASK_TYPE_CLASSIFICATION
      end

      def label_prefix
        '[L]'
      end

      def build_prompt(parsed)
        config_parser.build_prompt(parsed[:name], parsed[:label_descs])
      end

      def labels(parsed)
        parsed[:labels]
      end

      def needs_cls_logits?
        inference.has_cls_logits
      end

      # Formats scores through the classifier. A non-nil options[:threshold]
      # overrides the configured :cls_threshold.
      def process_output(logits, parsed, prepared, options)
        include_confidence = options.fetch(:include_confidence, false)
        threshold_override = options[:threshold]
        cls_threshold = threshold_override.nil? ? parsed[:cls_threshold] : threshold_override

        scores = classification_scores(logits, parsed, prepared, options)
        @classifier.format_classification(
          scores,
          labels: parsed[:labels],
          multi_label: parsed[:multi_label],
          include_confidence: include_confidence,
          cls_threshold: cls_threshold
        )
      end

      # Runs every task in +tasks_config+ through +pipeline+.
      #
      # @return [Hash{String=>Object}] results keyed by task name
      # @raise [Error] when +tasks_config+ is not a Hash
      def execute_all(pipeline, text, tasks_config, **options)
        raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)

        tasks_config.each_with_object({}) do |(task_name, task_config), results|
          parsed_config = { name: task_name, config: task_config }
          results[task_name.to_s] = pipeline.execute(self, text, parsed_config, **options)
        end
      end

      private

      # Returns [name, config] from either accepted config shape.
      def extract_task_config(input)
        name = input[:name] || input['name']
        task_config = input[:config] || input['config']

        return [name, task_config] if name && task_config
        return input.first if name.nil? && task_config.nil? && input.length == 1

        raise Error, 'classification config must include :name and :config'
      end

      # Uses CLS logits when present; otherwise delegates to the classifier
      # with label positions taken from options or computed from word ids.
      def classification_scores(logits, parsed, prepared, options)
        return cls_scores(logits, parsed) if cls_logits?(logits)

        # `||` rather than `fetch`: an explicit `label_positions: nil` must still
        # fall back to computing positions (matches EntityExtraction).
        label_positions = options[:label_positions] ||
                          inference.label_positions_for(prepared.word_ids, parsed[:labels].length)

        @classifier.classification_scores(
          logits,
          parsed[:labels],
          label_positions,
          prepared
        )
      end

      def cls_logits?(logits)
        logits.is_a?(Hash) && logits.key?(:cls_logits)
      end

      # Sigmoid per label for multi-label tasks, softmax across labels otherwise.
      def cls_scores(logits, parsed)
        cls_logits = Array(logits.fetch(:cls_logits).fetch(0))
        parsed[:multi_label] ? cls_logits.map { |value| inference.sigmoid(value) } : inference.softmax(cls_logits)
      end
    end
  end
end
@@ -0,0 +1,72 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  module Tasks
    # Named-entity extraction task. Produces
    # { 'entities' => { label => formatted value(s) } }.
    class EntityExtraction < Task
      # @param span_extractor [SpanExtractor] used to score and format spans
      def initialize(config_parser:, inference:, input_builder:, span_extractor:)
        super(config_parser: config_parser, inference: inference, input_builder: input_builder)
        @span_extractor = span_extractor
      end

      def parse_config(input)
        config_parser.parse_entity_types(input)
      end

      def task_type
        Inference::TASK_TYPE_ENTITIES
      end

      def label_prefix
        '[E]'
      end

      def build_prompt(parsed)
        config_parser.build_prompt('entities', parsed[:descriptions])
      end

      def labels(parsed)
        parsed[:labels]
      end

      # @param options [Hash] :threshold (default 0.5), :label_positions,
      #   :include_confidence, :include_spans
      def process_output(logits, parsed, prepared, options)
        # `||` so an explicit `threshold: nil` falls back to the default instead
        # of reaching the `score < threshold` comparison as nil.
        threshold = options[:threshold] || 0.5
        format_opts = FormatOptions.from(options)
        label_positions = options[:label_positions] ||
                          inference.label_positions_for(prepared.word_ids, parsed[:labels].length)

        spans_by_label = extract_spans(logits, parsed, prepared, label_positions, threshold)

        { 'entities' => format_entities(parsed, spans_by_label, format_opts) }
      end

      private

      # Candidate spans per label, honouring per-label thresholds from parsed[:thresholds].
      def extract_spans(logits, parsed, prepared, label_positions, threshold)
        @span_extractor.extract_spans_by_label(
          logits,
          parsed[:labels],
          label_positions,
          prepared,
          threshold: threshold,
          thresholds_by_label: parsed[:thresholds]
        )
      end

      # One entry per label. The label's dtype (default :list) selects a
      # single-value or list output.
      def format_entities(parsed, spans_by_label, format_opts)
        dtypes = parsed[:dtypes] || {}

        parsed[:labels].each_with_object({}) do |label, entities|
          # SpanExtractor keys its result by label.to_s; look up the same way.
          spans = spans_by_label.fetch(label.to_s)
          entities[label] = format_entity_value(spans, dtypes.fetch(label, :list), format_opts)
        end
      end

      def format_entity_value(spans, dtype, format_opts)
        if dtype == :str
          @span_extractor.format_single_span(@span_extractor.choose_best_span(spans), format_opts)
        else
          @span_extractor.format_spans(spans, format_opts)
        end
      end
    end
  end
end
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  module Tasks
    # Structured (JSON-like) extraction task. Each configured field becomes a
    # label. Results are grouped into records by StructuredExtractor.
    class JsonExtraction < Task
      # @param span_extractor [SpanExtractor] scores spans per field label
      # @param structured_extractor [StructuredExtractor] filters and groups spans into records
      def initialize(config_parser:, inference:, input_builder:, span_extractor:, structured_extractor:)
        super(config_parser: config_parser, inference: inference, input_builder: input_builder)

        @span_extractor = span_extractor
        @structured_extractor = structured_extractor
      end

      # Accepts { name:, fields: } (Symbol or String keys) or a single-entry
      # { structure_name => fields } hash.
      #
      # @return [Hash] :name, :parsed_fields, :labels, :descriptions
      # @raise [Error] when +input+ is not a Hash or has neither shape
      def parse_config(input)
        raise Error, 'structure config must be a Hash' unless input.is_a?(Hash)

        name, fields = extract_structure_config(input)
        parsed_fields = Array(fields).map { |spec| config_parser.parse_field_spec(spec.to_s) }

        {
          name: name.to_s,
          parsed_fields: parsed_fields,
          labels: parsed_fields.map { |field| field[:name] },
          descriptions: config_parser.build_field_descriptions(parsed_fields)
        }
      end

      def task_type
        Inference::TASK_TYPE_JSON
      end

      def label_prefix
        '[C]'
      end

      def normalize_text?
        true
      end

      def build_prompt(parsed)
        config_parser.build_prompt(parsed[:name], parsed[:descriptions])
      end

      def labels(parsed)
        parsed[:labels]
      end

      # Extracts spans, applies field choice filters and builds records.
      def process_output(logits, parsed, prepared, options)
        spans_by_label = extract_spans(logits, parsed, prepared, options)
        filtered_spans = @structured_extractor.apply_choice_filters(spans_by_label, parsed[:parsed_fields])
        format_opts = FormatOptions.from(options)

        @structured_extractor.build_structure_instances(parsed[:parsed_fields], filtered_spans, format_opts)
      end

      # Runs every structure in +structures_config+ through +pipeline+.
      #
      # @return [Hash{String=>Object}] results keyed by structure name
      # @raise [Error] when +structures_config+ is not a Hash
      def execute_all(pipeline, text, structures_config, **options)
        raise Error, 'structures must be a Hash' unless structures_config.is_a?(Hash)

        structures_config.each_with_object({}) do |(parent, fields), results|
          parsed_config = { name: parent, fields: fields }
          results[parent.to_s] = pipeline.execute(self, text, parsed_config, **options)
        end
      end

      private

      # Returns [name, fields] from either accepted config shape.
      def extract_structure_config(input)
        name = input[:name] || input['name']
        fields = input[:fields] || input['fields']

        return [name, fields] if name && fields
        return input.first if name.nil? && fields.nil? && input.length == 1

        raise Error, 'structure config must include :name and :fields'
      end

      # Candidate spans per field label. An explicit nil :label_positions or
      # :threshold falls back to the computed positions / 0.5 default, matching
      # EntityExtraction instead of propagating nil.
      def extract_spans(logits, parsed, prepared, options)
        label_positions = options[:label_positions] ||
                          inference.label_positions_for(prepared.word_ids, parsed[:labels].length)

        @span_extractor.extract_spans_by_label(
          logits,
          parsed[:labels],
          label_positions,
          prepared,
          threshold: options[:threshold] || 0.5
        )
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  # Text utilities: sentence-terminal normalization, lower-cased word
  # splitting with character offsets, and encoding of pre-split words.
  class TextProcessor
    # @param tokenizer [Object] responds to #encode(tokens, is_pretokenized:, add_special_tokens:)
    def initialize(tokenizer)
      @tokenizer = tokenizer
      @word_pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new
    end

    # Guarantees a terminal '.', '!' or '?' by appending '.' when missing.
    # nil or empty input yields '.'.
    def normalize_text(text)
      value = text.to_s
      return '.' if value.empty?
      return value if value.end_with?('.', '!', '?')

      "#{value}."
    end

    # Splits +text+ into lower-cased words, dropping empty tokens.
    #
    # @return [Array(Array<String>, Array<Integer>, Array<Integer>)]
    #   parallel arrays of words, start offsets and end offsets
    def split_words(text)
      words = @word_pre_tokenizer.pre_tokenize_str(text.to_s).filter_map do |raw, (first, last)|
        lowered = raw.to_s.downcase
        [lowered, first, last] unless lowered.empty?
      end

      Array.new(3) { |column| words.map { |word| word[column] } }
    end

    # Encodes already-split words without special tokens.
    #
    # @return [Hash] :ids and :word_ids from the tokenizer encoding
    def encode_pretokenized(tokens)
      encoding = @tokenizer.encode(tokens, is_pretokenized: true, add_special_tokens: false)
      { ids: encoding.ids, word_ids: encoding.word_ids }
    end
  end
end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module Gliner
  # Released version of the gliner gem.
  VERSION = '0.1.0'
end
data/lib/gliner.rb ADDED
@@ -0,0 +1,97 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'gliner/version'
4
+ require 'gliner/model'
5
+ require 'gliner/runners/prepared_task'
6
+ require 'gliner/runners/entity_runner'
7
+ require 'gliner/runners/structured_runner'
8
+ require 'gliner/runners/classification_runner'
9
+
10
module Gliner
  # Base class for errors raised by this gem.
  Error = Class.new(StandardError)

  # Weights file used inside a model directory when none is specified.
  DEFAULT_MODEL_FILE = 'model_int8.onnx'

  PreparedInput = Data.define(
    :input_ids,
    :word_ids,
    :attention_mask,
    :words_mask,
    :pos_to_word_index,
    :start_map,
    :end_map,
    :original_text,
    :text_len
  )

  Span = Data.define(:text, :score, :start, :end) do
    # True when the half-open character ranges [start, end) of the two spans intersect.
    def overlaps?(other)
      !(self.end <= other.start || start >= other.end)
    end
  end

  FormatOptions = Data.define(:include_confidence, :include_spans) do
    # Coerces a FormatOptions, an options Hash (Symbol keys) or nil into a
    # FormatOptions. Missing keys and nil default to false.
    def self.from(input)
      return input if input.is_a?(FormatOptions)
      return new(include_confidence: false, include_spans: false) if input.nil?

      new(
        include_confidence: input.fetch(:include_confidence, false),
        include_spans: input.fetch(:include_spans, false)
      )
    end
  end

  class << self
    attr_writer :model

    # Loads a model from +dir+ and makes it the global default.
    def load(dir, file: DEFAULT_MODEL_FILE)
      self.model = Model.from_dir(dir, file: file)
    end

    # The global model. Lazily loaded from GLINER_MODEL_DIR (and optionally
    # GLINER_MODEL_FILE) when not set explicitly; nil if neither applies.
    def model
      @model ||= model_from_env
    end

    # Like #model, but raises Error when no model is available.
    def model!
      fetch_model!
    end

    # Builds a runner for +config+ (structured or entity) bound to the global model.
    def [](config)
      runner_for(config).new(fetch_model!, config)
    end

    def classify
      Runners::ClassificationRunner
    end

    private

    # Empty environment values are treated as unset for both variables.
    def model_from_env
      dir = env_value('GLINER_MODEL_DIR')
      return nil unless dir

      Model.from_dir(dir, file: env_value('GLINER_MODEL_FILE') || DEFAULT_MODEL_FILE)
    end

    # ENV[name], or nil when the variable is unset or empty.
    def env_value(name)
      value = ENV.fetch(name, nil)
      value unless value.nil? || value.empty?
    end

    def fetch_model!
      model = self.model
      return model if model

      raise Error, 'No model loaded. Call Gliner.load("/path/to/model") or set GLINER_MODEL_DIR.'
    end

    def runner_for(config)
      return Runners::StructuredRunner if structured_config?(config)

      Runners::EntityRunner
    end

    # A Hash is structured when it has both 'name' and 'fields' keys (String or
    # Symbol), or when every value is an Array.
    def structured_config?(config)
      return false unless config.is_a?(Hash)

      keys = config.transform_keys(&:to_s)
      return true if keys.key?('name') && keys.key?('fields')

      config.values.all? { |value| value.is_a?(Array) }
    end
  end
end