gliner 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,124 @@
+ # frozen_string_literal: true
+
+ require 'gliner/inference/session_validator'
+
+ module Gliner
+   class Inference
+     TASK_TYPE_ENTITIES = 0
+     TASK_TYPE_CLASSIFICATION = 1
+     TASK_TYPE_JSON = 2
+
+     SCHEMA_PREFIX_LENGTH = 4
+     LABEL_SPACING = 2
+
+     Request = Data.define(
+       :input_ids,
+       :attention_mask,
+       :words_mask,
+       :text_lengths,
+       :task_type,
+       :label_positions,
+       :label_mask,
+       :want_cls
+     )
+
+     IOValidation = Data.define(:input_names, :output_name, :label_index_mode, :has_cls_logits)
+
+     attr_reader :label_index_mode, :has_cls_logits
+
+     def initialize(session)
+       @session = session
+
+       validation = SessionValidator[session]
+
+       @input_names = validation.input_names
+       @output_name = validation.output_name
+       @label_index_mode = validation.label_index_mode
+       @has_cls_logits = validation.has_cls_logits
+     end
+
+     def run(request)
+       outputs = output_names_for(request)
+       out = @session.run(outputs, build_inputs(request))
+       format_outputs(out, outputs)
+     end
+
+     def label_positions_for(word_ids, label_count)
+       label_count.times.map do |i|
+         combined_idx = SCHEMA_PREFIX_LENGTH + (i * LABEL_SPACING)
+         pos = word_ids.index(combined_idx)
+
+         raise Error, "Could not locate label position at combined index #{combined_idx}" if pos.nil?
+
+         pos
+       end
+     end
+
+     def label_logit(logits, pos, width, label_index, label_positions)
+       if @label_index_mode == :label_position
+         raise Error, 'Label positions required for span_logits output' if label_positions.nil?
+
+         label_pos = label_positions.fetch(label_index)
+         logits[0][pos][width][label_pos]
+       else
+         logits[0][pos][width][label_index]
+       end
+     end
+
+     def sigmoid(value)
+       1.0 / (1.0 + Math.exp(-value))
+     end
+
+     def softmax(values)
+       max_value = values.max
+       exps = values.map { |value| Math.exp(value - max_value) }
+       sum = exps.sum
+       exps.map { |value| value / sum }
+     end
+
+     private
+
+     def build_inputs(request)
+       inputs = base_inputs(request)
+       add_token_type_ids(inputs, request)
+       filter_inputs(inputs)
+     end
+
+     def base_inputs(request)
+       {
+         input_ids: [request.input_ids],
+         attention_mask: [request.attention_mask],
+         words_mask: [request.words_mask],
+         text_lengths: Array(request.text_lengths).flatten,
+         task_type: [request.task_type],
+         label_positions: [request.label_positions],
+         label_mask: [request.label_mask]
+       }
+     end
+
+     def add_token_type_ids(inputs, request)
+       return inputs unless @input_names&.include?('token_type_ids')
+
+       inputs[:token_type_ids] = [Array.new(request.input_ids.length, 0)]
+       inputs
+     end
+
+     def filter_inputs(inputs)
+       return inputs unless @input_names
+
+       inputs.select { |name, _| @input_names.include?(name.to_s) }
+     end
+
+     def output_names_for(request)
+       output_names = [@output_name]
+       output_names << 'cls_logits' if request.want_cls && @has_cls_logits
+       output_names
+     end
+
+     def format_outputs(out, output_names)
+       return { logits: out.fetch(0), cls_logits: out.fetch(1) } if output_names.length > 1
+
+       out.fetch(0)
+     end
+   end
+ end
@@ -0,0 +1,117 @@
+ # frozen_string_literal: true
+
+ require 'set'
+
+ module Gliner
+   class InputBuilder
+     def initialize(text_processor, max_seq_len:)
+       @text_processor = text_processor
+       @max_seq_len = max_seq_len
+     end
+
+     def prepare(text, schema_tokens, already_normalized: false)
+       normalized_text = normalize_text(text, already_normalized: already_normalized)
+       words, start_map, end_map = @text_processor.split_words(normalized_text)
+       input_ids, word_ids = encode_tokens(schema_tokens, words)
+       input_ids, word_ids = truncate_inputs(input_ids, word_ids, max_len: @max_seq_len)
+
+       text_start_index = schema_tokens.length + 1
+       text_len = infer_effective_text_len(word_ids, text_start_index, words.length)
+
+       context = {
+         input_ids: input_ids,
+         word_ids: word_ids,
+         text_start_index: text_start_index,
+         start_map: start_map,
+         end_map: end_map,
+         original_text: normalized_text,
+         text_len: text_len
+       }
+
+       build_prepared_input(context)
+     end
+
+     def schema_tokens_for(prompt:, labels:, label_prefix:)
+       tokens = ['(', '[P]', prompt.to_s, '(']
+
+       labels.each do |label|
+         tokens << label_prefix
+         tokens << label.to_s
+       end
+
+       tokens.push(')', ')')
+       tokens
+     end
+
+     private
+
+     def normalize_text(text, already_normalized:)
+       already_normalized ? text.to_s : @text_processor.normalize_text(text)
+     end
+
+     def encode_tokens(schema_tokens, words)
+       combined_tokens = schema_tokens + ['[SEP_TEXT]'] + words
+       encoded = @text_processor.encode_pretokenized(combined_tokens)
+       [encoded[:ids], encoded[:word_ids]]
+     end
+
+     def truncate_inputs(input_ids, word_ids, max_len:)
+       return [input_ids, word_ids] if input_ids.length <= max_len
+
+       [input_ids.take(max_len), word_ids.take(max_len)]
+     end
+
+     def build_prepared_input(context)
+       input_ids = context.fetch(:input_ids)
+       word_ids = context.fetch(:word_ids)
+       text_start_index = context.fetch(:text_start_index)
+
+       word_analysis = analyze_words(word_ids, text_start_index)
+
+       PreparedInput.new(
+         input_ids: input_ids,
+         word_ids: word_ids,
+         attention_mask: Array.new(input_ids.length, 1),
+         words_mask: word_analysis[:mask],
+         pos_to_word_index: word_analysis[:index_map],
+         start_map: context.fetch(:start_map),
+         end_map: context.fetch(:end_map),
+         original_text: context.fetch(:original_text),
+         text_len: context.fetch(:text_len)
+       )
+     end
+
+     def analyze_words(word_ids, text_start_index)
+       mask = Array.new(word_ids.length, 0)
+       index_map = Array.new(word_ids.length)
+       last_word_id = nil
+       seen = Set.new
+
+       word_ids.each_with_index do |word_id, i|
+         next unless word_id
+
+         # Build mask (word boundaries)
+         if word_id != last_word_id && word_id >= text_start_index
+           mask[i] = 1
+           last_word_id = word_id
+         end
+
+         # Build index map (first occurrence)
+         unless seen.include?(word_id)
+           seen << word_id
+           index_map[i] = word_id - text_start_index if word_id >= text_start_index
+         end
+       end
+
+       { mask: mask, index_map: index_map }
+     end
+
+     def infer_effective_text_len(word_ids, text_start_index, full_text_len)
+       max_text_word_id = word_ids.compact.select { |word_id| word_id >= text_start_index }.max
+       return full_text_len if max_text_word_id.nil?
+
+       present = (max_text_word_id - text_start_index) + 1
+       [present, full_text_len].min
+     end
+   end
+ end
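
For orientation, here is what schema_tokens_for produces and how it lines up with the constants in Inference. The prompt, labels, and the '[E]' prefix below are hypothetical placeholders; each task supplies its own label_prefix, which this diff does not show.

builder = Gliner::InputBuilder.new(text_processor, max_seq_len: 512)  # text_processor as wired up in Model

builder.schema_tokens_for(prompt: 'extract entities', labels: %w[person location], label_prefix: '[E]')
# => ['(', '[P]', 'extract entities', '(', '[E]', 'person', '[E]', 'location', ')', ')']
#
# Indices 0..3 are the fixed prefix ('(', '[P]', prompt, '('); the prefix marker for label i
# sits at index 4 + 2 * i, i.e. SCHEMA_PREFIX_LENGTH + i * LABEL_SPACING, which is exactly the
# combined index that Inference#label_positions_for searches for in word_ids after encode_tokens
# glues schema_tokens, '[SEP_TEXT]' and the text words together.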
@@ -0,0 +1,142 @@
+ # frozen_string_literal: true
+
+ require 'json'
+ require 'onnxruntime'
+ require 'tokenizers'
+
+ require 'gliner/text_processor'
+ require 'gliner/config_parser'
+ require 'gliner/inference'
+ require 'gliner/input_builder'
+ require 'gliner/span_extractor'
+ require 'gliner/classifier'
+ require 'gliner/structured_extractor'
+ require 'gliner/task'
+ require 'gliner/pipeline'
+ require 'gliner/tasks/entity_extraction'
+ require 'gliner/tasks/classification'
+ require 'gliner/tasks/json_extraction'
+
+ module Gliner
+   class Model
+     DEFAULT_MAX_WIDTH = 8
+     DEFAULT_MAX_SEQ_LEN = 512
+
+     def self.from_dir(dir, file: 'model_int8.onnx')
+       config_path = File.join(dir, 'config.json')
+       config = File.exist?(config_path) ? JSON.parse(File.read(config_path)) : {}
+
+       new(
+         model_path: File.join(dir, file),
+         tokenizer_path: File.join(dir, 'tokenizer.json'),
+         max_width: config.fetch('max_width', DEFAULT_MAX_WIDTH),
+         max_seq_len: config.fetch('max_seq_len', DEFAULT_MAX_SEQ_LEN)
+       )
+     end
+
+     def initialize(model_path:, tokenizer_path:, max_width: DEFAULT_MAX_WIDTH, max_seq_len: DEFAULT_MAX_SEQ_LEN)
+       @model_path = model_path
+       @tokenizer_path = tokenizer_path
+       @max_width = Integer(max_width)
+       @max_seq_len = Integer(max_seq_len)
+
+       tokenizer = Tokenizers.from_file(@tokenizer_path)
+       session = OnnxRuntime::InferenceSession.new(@model_path)
+
+       @text_processor = TextProcessor.new(tokenizer)
+       @inference = Inference.new(session)
+     end
+
+     def config_parser
+       @config_parser ||= ConfigParser.new
+     end
+
+     def pipeline
+       @pipeline ||= Pipeline.new(text_processor: @text_processor, inference: @inference)
+     end
+
+     def input_builder
+       @input_builder ||= InputBuilder.new(@text_processor, max_seq_len: @max_seq_len)
+     end
+
+     def span_extractor
+       @span_extractor ||= SpanExtractor.new(@inference, max_width: @max_width)
+     end
+
+     def structured_extractor
+       @structured_extractor ||= StructuredExtractor.new(span_extractor)
+     end
+
+     def classifier
+       @classifier ||= Classifier.new(@inference, max_width: @max_width)
+     end
+
+     def entity_task
+       @entity_task ||= Tasks::EntityExtraction.new(
+         config_parser: config_parser,
+         inference: @inference,
+         input_builder: input_builder,
+         span_extractor: span_extractor
+       )
+     end
+
+     def classification_task
+       @classification_task ||= Tasks::Classification.new(
+         config_parser: config_parser,
+         inference: @inference,
+         input_builder: input_builder,
+         classifier: classifier
+       )
+     end
+
+     def json_task
+       @json_task ||= Tasks::JsonExtraction.new(
+         config_parser: config_parser,
+         inference: @inference,
+         input_builder: input_builder,
+         span_extractor: span_extractor,
+         structured_extractor: structured_extractor
+       )
+     end
+
+     def extract_entities(text, entity_types, **options)
+       threshold = options.fetch(:threshold, 0.5)
+       include_confidence = options.fetch(:include_confidence, false)
+       include_spans = options.fetch(:include_spans, false)
+
+       pipeline.execute(
+         entity_task,
+         text,
+         entity_types,
+         threshold: threshold,
+         include_confidence: include_confidence,
+         include_spans: include_spans
+       )
+     end
+
+     def classify_text(text, tasks, **options)
+       include_confidence = options.fetch(:include_confidence, false)
+       threshold = options[:threshold]
+
+       task_options = { include_confidence: include_confidence }
+       task_options[:threshold] = threshold unless threshold.nil?
+
+       classification_task.execute_all(pipeline, text, tasks, **task_options)
+     end
+
+     def extract_json(text, structures, **options)
+       threshold = options.fetch(:threshold, 0.5)
+       include_confidence = options.fetch(:include_confidence, false)
+       include_spans = options.fetch(:include_spans, false)
+
+       json_task.execute_all(
+         pipeline,
+         text,
+         structures,
+         threshold: threshold,
+         include_confidence: include_confidence,
+         include_spans: include_spans
+       )
+     end
+   end
+ end
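
A minimal usage sketch of the surface above. The directory layout (model_int8.onnx, tokenizer.json, optional config.json) is what Model.from_dir reads; the path, text, and entity types are placeholders, and the shape of the returned results is determined by the task classes, which are not part of this diff.

require 'gliner'  # assumed gem entrypoint; the individual files are required at the top of the class above

model = Gliner::Model.from_dir('/path/to/gliner-model')  # picks up model_int8.onnx by default

model.extract_entities(
  'Ada Lovelace was born in London.',  # placeholder text
  ['person', 'location'],              # entity types to look for
  threshold: 0.5,                      # defaults shown; all three options may be omitted
  include_confidence: true,
  include_spans: false
)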
@@ -0,0 +1,64 @@
+ # frozen_string_literal: true
+
+ module Gliner
+   class Pipeline
+     def initialize(text_processor:, inference:)
+       @text_processor = text_processor
+       @inference = inference
+     end
+
+     def execute(task, text, config, **options)
+       parsed = task.parse_config(config)
+       prepared_text = prepare_text(task, text)
+       labels = task.labels(parsed)
+       prepared = prepare_input(task, prepared_text, parsed, labels)
+       label_positions = label_positions_for(prepared, labels.length)
+       logits = run_inference(task, prepared, labels, label_positions)
+
+       task.process_output(logits, parsed, prepared, options.merge(label_positions: label_positions))
+     end
+
+     private
+
+     def prepare_text(task, text)
+       return @text_processor.normalize_text(text) if task.normalize_text?
+
+       text.to_s
+     end
+
+     def prepare_input(task, prepared_text, parsed, labels)
+       schema_tokens = task.input_builder.schema_tokens_for(
+         prompt: task.build_prompt(parsed),
+         labels: labels,
+         label_prefix: task.label_prefix
+       )
+
+       task.input_builder.prepare(
+         prepared_text,
+         schema_tokens,
+         already_normalized: task.normalize_text?
+       )
+     end
+
+     def label_positions_for(prepared, label_count)
+       @inference.label_positions_for(prepared.word_ids, label_count)
+     end
+
+     def run_inference(task, prepared, labels, label_positions)
+       @inference.run(build_request(task, prepared, labels, label_positions))
+     end
+
+     def build_request(task, prepared, labels, label_positions)
+       Inference::Request.new(
+         input_ids: prepared.input_ids,
+         attention_mask: prepared.attention_mask,
+         words_mask: prepared.words_mask,
+         text_lengths: [prepared.text_len],
+         task_type: task.task_type,
+         label_positions: label_positions,
+         label_mask: Array.new(labels.length, 1),
+         want_cls: task.needs_cls_logits?
+       )
+     end
+   end
+ end
@@ -0,0 +1,21 @@
+ # frozen_string_literal: true
+
+ module Gliner
+   module PositionIteration
+     def each_position_width(seq_len, prepared, max_width)
+       return enum_for(:each_position_width, seq_len, prepared, max_width) unless block_given?
+
+       (0...seq_len).each do |pos|
+         start_word = prepared.pos_to_word_index[pos]
+         next unless start_word
+
+         (0...max_width).each do |width|
+           end_word = start_word + width
+           next if end_word >= prepared.text_len
+
+           yield pos, start_word, width
+         end
+       end
+     end
+   end
+ end
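
To make the iteration order concrete, here is a small sketch with a hand-built stand-in for the prepared input (only the two attributes the module reads are stubbed); the values are hypothetical.

PreparedStub = Struct.new(:pos_to_word_index, :text_len)
prepared = PreparedStub.new([nil, 0, 1, 2, nil], 3)  # token positions 1..3 start words 0..2; 3 text words

include Gliner::PositionIteration

each_position_width(prepared.pos_to_word_index.length, prepared, 3).to_a
# => [[1, 0, 0], [1, 0, 1], [1, 0, 2], [2, 1, 0], [2, 1, 1], [3, 2, 0]]
#
# Each triple is [token position, starting word index, width]; candidate spans whose end word
# (start_word + width) would reach text_len are skipped.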
@@ -0,0 +1,26 @@
+ # frozen_string_literal: true
+
+ module Gliner
+   module Runners
+     class ClassificationRunner
+       def self.[](tasks)
+         new(Gliner.model!, tasks)
+       end
+
+       def initialize(model, tasks_config)
+         raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)
+
+         @tasks = tasks_config.to_h do |name, config|
+           parsed = model.classification_task.parse_config(name: name, config: config)
+           [name.to_s, PreparedTask.new(model.classification_task, parsed)]
+         end
+       end
+
+       def [](text, **options)
+         @tasks.transform_values { |task| task.call(text, **options) }
+       end
+
+       alias call []
+     end
+   end
+ end
@@ -0,0 +1,19 @@
+ # frozen_string_literal: true
+
+ module Gliner
+   module Runners
+     class EntityRunner
+       def initialize(model, config)
+         parsed = model.entity_task.parse_config(config)
+         @task = PreparedTask.new(model.entity_task, parsed)
+       end
+
+       def [](text, **options)
+         result = @task.call(text, **options)
+         result.fetch('entities')
+       end
+
+       alias call []
+     end
+   end
+ end
@@ -0,0 +1,55 @@
+ # frozen_string_literal: true
+
+ module Gliner
+   module Runners
+     class PreparedTask
+       def initialize(task, parsed)
+         @task = task
+         @parsed = parsed
+         @labels = task.labels(parsed)
+
+         @schema_tokens = task.input_builder.schema_tokens_for(
+           prompt: task.build_prompt(parsed),
+           labels: @labels,
+           label_prefix: task.label_prefix
+         )
+
+         @label_mask = Array.new(@labels.length, 1)
+         @label_positions_template = precompute_label_positions
+       end
+
+       def call(text, **options)
+         prepared = @task.input_builder.prepare(text, @schema_tokens)
+         label_positions = @label_positions_template
+
+         if label_positions.any? { |pos| pos.nil? || pos >= prepared.input_ids.length }
+           label_positions = @task.inference.label_positions_for(prepared.word_ids, @labels.length)
+         end
+
+         logits = @task.inference.run(
+           Inference::Request.new(
+             input_ids: prepared.input_ids,
+             attention_mask: prepared.attention_mask,
+             words_mask: prepared.words_mask,
+             text_lengths: [prepared.text_len],
+             task_type: @task.task_type,
+             label_positions: label_positions,
+             label_mask: @label_mask,
+             want_cls: @task.needs_cls_logits?
+           )
+         )
+
+         @task.process_output(logits, @parsed, prepared, options.merge(label_positions: label_positions))
+       end
+
+       private
+
+       def precompute_label_positions
+         return [] if @labels.empty?
+
+         prepared = @task.input_builder.prepare('.', @schema_tokens)
+         @task.inference.label_positions_for(prepared.word_ids, @labels.length)
+       end
+     end
+   end
+ end
@@ -0,0 +1,36 @@
+ # frozen_string_literal: true
+
+ module Gliner
+   module Runners
+     class StructuredRunner
+       def initialize(model, config)
+         @tasks = build_tasks(model, config)
+       end
+
+       def [](text, **options)
+         @tasks.transform_values do |task|
+           task.call(text, **options)
+         end
+       end
+
+       alias call []
+
+       private
+
+       def build_tasks(model, config)
+         raise Error, 'structures must be a Hash' unless config.is_a?(Hash)
+
+         if config.key?(:name) || config.key?('name')
+           parsed = model.json_task.parse_config(config)
+
+           { parsed[:name].to_s => PreparedTask.new(model.json_task, parsed) }
+         else
+           config.each_with_object({}) do |(name, fields), tasks|
+             parsed = model.json_task.parse_config(name: name, fields: fields)
+             tasks[name.to_s] = PreparedTask.new(model.json_task, parsed)
+           end
+         end
+       end
+     end
+   end
+ end
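
Finally, a sketch of the runner above in use. The hash of structure name to field list mirrors the (name, fields) destructuring in build_tasks; the exact field format accepted by JsonExtraction#parse_config is not shown in this diff, so the array of field names and the threshold option are assumptions.

model  = Gliner::Model.from_dir('/path/to/gliner-model')
runner = Gliner::Runners::StructuredRunner.new(model, { 'invoice' => %w[number total due_date] })

runner.call('Invoice 1042 for $300 is due on 2024-07-01.', threshold: 0.5)  # options assumed to mirror extract_json's
# => results keyed by structure name, e.g. { 'invoice' => ... }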