gliner 0.2.3 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 88bb40bab466141ca3e1852059a0ba526a1928cc40f7a6adc5117d0da8e6c70a
4
- data.tar.gz: e901cb0b6b544a359816d8c0eb7abd50fc20cd342ffd4c20f0a2bf52728e97c8
3
+ metadata.gz: eac979a64c4acb302685c0390322c688dcdf35097741ef5b201c83827013f6ce
4
+ data.tar.gz: cc9fd10929e4dffe94ca1e4ce49ec0cea8e7569b22c7c9956cfad8d1bd0d5afc
5
5
  SHA512:
6
- metadata.gz: b1e6bb6047a668641389f5e1559e4a61712223a5a4a3169e3ce4102a0502642bd53b0e7b7f263aed4a3fb813d6a507f5a8a34dff92da8474de583d047e8f109a
7
- data.tar.gz: 9f9595a03255085f666ae045081a99b5293e3165e6214cd71f21e78e3338072d551ce25d753a8d6778084340e9535a2b95de6387678c7673a8fe4d7b16cfd844
6
+ metadata.gz: 4b75a4a0610a52d9364d988ba94277b1d55f0a116691bb6994ea571f9e2ab415aa006f22f2118c4268ef7a365e04614bde7ff8353b01ac4b5413d3545295d5bb
7
+ data.tar.gz: 1498c299f0d522974a0bea27c4ae71a6fee01cd93b832105b197fe227a0a33c6f751220efb128982e96e934c5feea382fe46450c3ffa4f3c6973dfd2f73e3ca4
data/README.md CHANGED
@@ -29,9 +29,19 @@ text = "Apple CEO Tim Cook announced iPhone 15 in Cupertino yesterday."
29
29
  labels = ["company", "person", "product", "location"]
30
30
 
31
31
  model = Gliner[labels]
32
- pp model[text]
32
+ entities = model[text]
33
33
 
34
- # => {"company"=>["Apple"], "person"=>["Tim Cook"], "product"=>["iPhone 15"], "location"=>["Cupertino"]}
34
+ pp entities["person"]
35
+ # => [#<data Gliner::Entity ...>]
36
+
37
+ entities["person"].first.text
38
+ # => "Tim Cook"
39
+
40
+ entities["person"].first.probability
41
+ # => 92.4
42
+
43
+ entities["person"].first.offsets
44
+ # => [10, 18]
35
45
  ```
36
46
 
37
47
  You can also pass per-entity configs:
@@ -43,9 +53,13 @@ labels = {
43
53
  }
44
54
 
45
55
  model = Gliner[labels]
46
- pp model["Email John Doe at john@example.com.", threshold: 0.5]
56
+ entities = model["Email John Doe at john@example.com.", threshold: 0.5]
47
57
 
48
- # => {"email"=>["john@example.com"], "person"=>"John Doe"}
58
+ entities["person"].text
59
+ # => "John Doe"
60
+
61
+ entities["email"].map(&:text)
62
+ # => ["john@example.com"]
49
63
  ```
50
64
 
51
65
  ### Classification
@@ -59,7 +73,30 @@ result = model["This laptop has amazing performance but terrible battery life!"]
59
73
 
60
74
  pp result
61
75
 
62
- # => {"sentiment"=>"negative"}
76
+ # => { sentiment: #<data Gliner::Label ...> }
77
+
78
+ result["sentiment"].label
79
+ # => "negative"
80
+
81
+ result["sentiment"].probability
82
+ # => 87.1
83
+ ```
84
+
85
+ Multiple classification tasks:
86
+
87
+ ```ruby
88
+ text = "Breaking: Tech giant announces major layoffs amid market downturn"
89
+
90
+ tasks = {
91
+ sentiment: %w[positive negative neutral],
92
+ urgency: %w[high medium low],
93
+ category: { labels: %w[tech finance politics sports], multi_label: false }
94
+ }
95
+
96
+ results = Gliner.classify[tasks][text]
97
+
98
+ results.transform_values { |value| value.label }
99
+ # => { sentiment: "negative", urgency: "high", category: "tech" }
63
100
  ```
64
101
 
65
102
  ### Structured extraction
@@ -77,10 +114,21 @@ structure = {
77
114
  }
78
115
 
79
116
  result = Gliner[structure][text]
117
+ product = result.fetch("product").first
80
118
 
81
119
  pp result
82
120
 
83
- # => {"product"=>[{"name"=>"iPhone 15 Pro Max", "storage"=>"256GB", "processor"=>"A17 Pro", "price"=>"1199"}]}
121
+ product["name"].text
122
+ # => "iPhone 15 Pro Max"
123
+
124
+ product["storage"].text
125
+ # => "256GB"
126
+
127
+ product["processor"].text
128
+ # => "A17 Pro"
129
+
130
+ product["price"].text
131
+ # => "1199"
84
132
  ```
85
133
 
86
134
  Choices can be included in field specs:
@@ -88,7 +136,8 @@ Choices can be included in field specs:
88
136
  ```ruby
89
137
  result = Gliner[{ order: ["status::[pending|processing|shipped]::str"] }]["Status: shipped"]
90
138
 
91
- # => {"order"=>[{"status"=>"shipped"}]}
139
+ result.fetch("order").first["status"].text
140
+ # shipped
92
141
  ```
93
142
 
94
143
  ## Model files
@@ -17,12 +17,12 @@ module Gliner
17
17
  end
18
18
  end
19
19
 
20
- def format_classification(scores, labels:, multi_label:, include_confidence:, cls_threshold:)
20
+ def format_classification(scores, labels:, multi_label:, include_probability:, cls_threshold:)
21
21
  label_scores = sorted_label_scores(scores, labels)
22
22
 
23
- return format_multi_label(label_scores, cls_threshold, include_confidence) if multi_label
23
+ return format_multi_label(label_scores, cls_threshold, include_probability) if multi_label
24
24
 
25
- format_single_label(label_scores.first, include_confidence)
25
+ format_single_label(label_scores.first, include_probability)
26
26
  end
27
27
 
28
28
  private
@@ -44,10 +44,12 @@ module Gliner
44
44
  .sort_by { |(_label, score)| -score }
45
45
  end
46
46
 
47
- def format_multi_label(label_scores, cls_threshold, include_confidence)
47
+ def format_multi_label(label_scores, cls_threshold, include_probability)
48
48
  chosen = labels_above_threshold(label_scores, cls_threshold)
49
49
 
50
- chosen.map { |label, score| format_label(label, score, include_confidence) }
50
+ chosen
51
+ .sort_by { |(_label, score)| -score }
52
+ .map { |label, score| format_label(label, score, include_probability) }
51
53
  end
52
54
 
53
55
  def labels_above_threshold(label_scores, threshold)
@@ -55,14 +57,14 @@ module Gliner
55
57
  above.empty? && label_scores.first ? [label_scores.first] : above
56
58
  end
57
59
 
58
- def format_single_label(label_score, include_confidence)
60
+ def format_single_label(label_score, include_probability)
59
61
  label, score = label_score
60
62
 
61
- format_label(label, score, include_confidence)
63
+ format_label(label, score, include_probability)
62
64
  end
63
65
 
64
- def format_label(label, score, include_confidence)
65
- include_confidence ? { 'label' => label, 'confidence' => score } : label
66
+ def format_label(label, score, _include_probability)
67
+ Gliner::Label.new(label: label, probability: score * 100.0)
66
68
  end
67
69
  end
68
70
  end
data/lib/gliner/model.rb CHANGED
@@ -72,7 +72,7 @@ module Gliner
72
72
  end
73
73
 
74
74
  def entity_task
75
- @entity_task ||= Tasks::EntityExtraction.new(
75
+ @entity_task ||= Tasks::Entity.new(
76
76
  config_parser: config_parser,
77
77
  inference: @inference,
78
78
  input_builder: input_builder,
@@ -90,7 +90,7 @@ module Gliner
90
90
  end
91
91
 
92
92
  def json_task
93
- @json_task ||= Tasks::JsonExtraction.new(
93
+ @json_task ||= Tasks::Json.new(
94
94
  config_parser: config_parser,
95
95
  inference: @inference,
96
96
  input_builder: input_builder,
@@ -101,7 +101,7 @@ module Gliner
101
101
 
102
102
  def extract_entities(text, entity_types, **options)
103
103
  threshold = options.fetch(:threshold, Gliner.config.threshold)
104
- include_confidence = options.fetch(:include_confidence, false)
104
+ include_probability = options.fetch(:include_probability, false)
105
105
  include_spans = options.fetch(:include_spans, false)
106
106
 
107
107
  pipeline.execute(
@@ -109,16 +109,16 @@ module Gliner
109
109
  text,
110
110
  entity_types,
111
111
  threshold: threshold,
112
- include_confidence: include_confidence,
112
+ include_probability: include_probability,
113
113
  include_spans: include_spans
114
114
  )
115
115
  end
116
116
 
117
117
  def classify_text(text, tasks, **options)
118
- include_confidence = options.fetch(:include_confidence, false)
118
+ include_probability = options.fetch(:include_probability, false)
119
119
  threshold = options[:threshold]
120
120
 
121
- task_options = { include_confidence: include_confidence }
121
+ task_options = { include_probability: include_probability }
122
122
  task_options[:threshold] = threshold unless threshold.nil?
123
123
 
124
124
  classification_task.execute_all(pipeline, text, tasks, **task_options)
@@ -126,7 +126,7 @@ module Gliner
126
126
 
127
127
  def extract_json(text, structures, **options)
128
128
  threshold = options.fetch(:threshold, Gliner.config.threshold)
129
- include_confidence = options.fetch(:include_confidence, false)
129
+ include_probability = options.fetch(:include_probability, false)
130
130
  include_spans = options.fetch(:include_spans, false)
131
131
 
132
132
  json_task.execute_all(
@@ -134,7 +134,7 @@ module Gliner
134
134
  text,
135
135
  structures,
136
136
  threshold: threshold,
137
- include_confidence: include_confidence,
137
+ include_probability: include_probability,
138
138
  include_spans: include_spans
139
139
  )
140
140
  end
@@ -2,7 +2,7 @@
2
2
 
3
3
  module Gliner
4
4
  module Runners
5
- class ClassificationRunner
5
+ class Classification
6
6
  include Inspectable
7
7
 
8
8
  def self.[](tasks)
@@ -12,6 +12,7 @@ module Gliner
12
12
  def initialize(model, tasks_config)
13
13
  raise Error, 'tasks must be a Hash' unless tasks_config.is_a?(Hash)
14
14
 
15
+ @config = tasks_config
15
16
  @tasks = tasks_config.to_h do |name, config|
16
17
  parsed = model.classification_task.parse_config(name: name, config: config)
17
18
  [name.to_s, PreparedTask.new(model.classification_task, parsed)]
@@ -2,12 +2,13 @@
2
2
 
3
3
  module Gliner
4
4
  module Runners
5
- class EntityRunner
5
+ class Entity
6
6
  include Inspectable
7
7
 
8
8
  def initialize(model, config)
9
9
  parsed = model.entity_task.parse_config(config)
10
10
 
11
+ @config = config
11
12
  @labels = parsed[:labels]
12
13
  @task = PreparedTask.new(model.entity_task, parsed)
13
14
  end
@@ -3,10 +3,10 @@
3
3
  module Gliner
4
4
  module Runners
5
5
  module Inspectable
6
- def inspect
7
- items = Array(inspect_items).map(&:to_s)
6
+ attr_reader :config
8
7
 
9
- "#<Gliner(#{inspect_label}) input=#{items.inspect}>"
8
+ def inspect
9
+ "#<Gliner(#{inspect_label}) config=#{config.inspect}>"
10
10
  end
11
11
  end
12
12
  end
@@ -2,10 +2,11 @@
2
2
 
3
3
  module Gliner
4
4
  module Runners
5
- class StructuredRunner
5
+ class Structure
6
6
  include Inspectable
7
7
 
8
8
  def initialize(model, config)
9
+ @config = config
9
10
  @tasks = build_tasks(model, config)
10
11
  end
11
12
 
@@ -47,11 +47,13 @@ module Gliner
47
47
  near.min_by { |s| [(s.end - s.start), -s.score, s.text.length] } || best
48
48
  end
49
49
 
50
- def format_single_span(span, opts)
51
- format_span(span, opts)
50
+ def format_single_span(span, opts = nil)
51
+ label = extract_label(opts)
52
+ format_span(span, opts, label: label, index: 0)
52
53
  end
53
54
 
54
- def format_spans(spans, opts)
55
+ def format_spans(spans, opts = nil)
56
+ label = extract_label(opts)
55
57
  return [] if spans.empty?
56
58
 
57
59
  sorted = spans.sort_by { |s| -s.score }
@@ -64,7 +66,9 @@ module Gliner
64
66
  selected << span
65
67
  end
66
68
 
67
- selected.map { |span| format_span(span, opts) }
69
+ selected.each_with_index.map do |span, index|
70
+ format_span(span, opts, label: label, index: index)
71
+ end
68
72
  end
69
73
 
70
74
  private
@@ -97,21 +101,22 @@ module Gliner
97
101
  Span.new(text: text_span, score: score, start: char_start, end: char_end)
98
102
  end
99
103
 
100
- def format_span(span, opts)
104
+ def format_span(span, _opts, label:, index:)
101
105
  return nil if span.nil?
102
106
 
103
- format_opts = FormatOptions.from(opts)
104
- return span.text unless format_opts.include_confidence || format_opts.include_spans
105
-
106
- result = { 'text' => span.text }
107
- result['confidence'] = span.score if format_opts.include_confidence
107
+ Gliner::Entity.new(
108
+ index: index,
109
+ offsets: [span.start, span.end],
110
+ text: span.text,
111
+ name: label&.to_s,
112
+ probability: span.score * 100.0
113
+ )
114
+ end
108
115
 
109
- if format_opts.include_spans
110
- result['start'] = span.start
111
- result['end'] = span.end
112
- end
116
+ def extract_label(opts)
117
+ return nil unless opts.is_a?(Hash)
113
118
 
114
- result
119
+ opts[:label] || opts['label']
115
120
  end
116
121
  end
117
122
  end
@@ -1,6 +1,18 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Gliner
4
+ Structure = Data.define(:fields) do
5
+ include Enumerable
6
+
7
+ def [](key) = fields[key]
8
+ def fetch(key, ...) = fields.fetch(key, ...)
9
+ def to_h = fields
10
+ def to_hash = fields
11
+ def keys = fields.keys
12
+ def values = fields.values
13
+ def each(&block) = fields.each(&block)
14
+ end
15
+
4
16
  class StructuredExtractor
5
17
  def initialize(span_extractor)
6
18
  @span_extractor = span_extractor
@@ -32,18 +44,17 @@ module Gliner
32
44
  end
33
45
 
34
46
  def build_structure_instances(parsed_fields, spans_by_label, opts)
35
- format_opts = FormatOptions.from(opts)
36
47
  anchor_field = anchor_field_for(parsed_fields)
37
- return [{}] unless anchor_field
48
+ return [Gliner::Structure.new(fields: {})] unless anchor_field
38
49
 
39
50
  anchors = spans_by_label.fetch(anchor_field[:name], [])
40
- return [format_structure_object(parsed_fields, spans_by_label, format_opts)] if anchors.empty?
51
+ return [format_structure_object(parsed_fields, spans_by_label, opts)] if anchors.empty?
41
52
 
42
53
  instance_spans = build_instance_spans(anchors, spans_by_label)
43
- format_instances(parsed_fields, instance_spans, format_opts)
54
+ format_instances(parsed_fields, instance_spans, opts)
44
55
  end
45
56
 
46
- def format_structure_object(parsed_fields, spans_by_label, opts)
57
+ def format_structure_object(parsed_fields, spans_by_label, _opts)
47
58
  obj = {}
48
59
 
49
60
  parsed_fields.each do |field|
@@ -52,13 +63,13 @@ module Gliner
52
63
 
53
64
  if field[:dtype] == :str
54
65
  best = @span_extractor.choose_best_span(spans)
55
- obj[key] = @span_extractor.format_single_span(best, opts)
66
+ obj[key] = @span_extractor.format_single_span(best, label: key)
56
67
  else
57
- obj[key] = @span_extractor.format_spans(spans, opts)
68
+ obj[key] = @span_extractor.format_spans(spans, label: key)
58
69
  end
59
70
  end
60
71
 
61
- obj
72
+ Gliner::Structure.new(fields: obj)
62
73
  end
63
74
 
64
75
  private
@@ -38,7 +38,7 @@ module Gliner
38
38
  end
39
39
 
40
40
  def process_output(logits, parsed, prepared, options)
41
- include_confidence = options.fetch(:include_confidence, false)
41
+ include_probability = options.fetch(:include_probability, false)
42
42
  threshold_override = options[:threshold]
43
43
  cls_threshold = threshold_override.nil? ? parsed[:cls_threshold] : threshold_override
44
44
 
@@ -47,7 +47,7 @@ module Gliner
47
47
  scores,
48
48
  labels: parsed[:labels],
49
49
  multi_label: parsed[:multi_label],
50
- include_confidence: include_confidence,
50
+ include_probability: include_probability,
51
51
  cls_threshold: cls_threshold
52
52
  )
53
53
  end
@@ -2,7 +2,7 @@
2
2
 
3
3
  module Gliner
4
4
  module Tasks
5
- class EntityExtraction < Task
5
+ class Entity < Task
6
6
  def initialize(config_parser:, inference:, input_builder:, span_extractor:)
7
7
  super(config_parser: config_parser, inference: inference, input_builder: input_builder)
8
8
  @span_extractor = span_extractor
@@ -30,12 +30,11 @@ module Gliner
30
30
 
31
31
  def process_output(logits, parsed, prepared, options)
32
32
  threshold = options.fetch(:threshold, Gliner.config.threshold)
33
- format_opts = FormatOptions.from(options)
34
33
  label_positions = options[:label_positions] || inference.label_positions_for(prepared.word_ids, parsed[:labels].length)
35
34
 
36
35
  spans_by_label = extract_spans(logits, parsed, prepared, label_positions, threshold)
37
36
 
38
- { 'entities' => format_entities(parsed, spans_by_label, format_opts) }
37
+ { 'entities' => format_entities(parsed, spans_by_label) }
39
38
  end
40
39
 
41
40
  private
@@ -51,20 +50,28 @@ module Gliner
51
50
  )
52
51
  end
53
52
 
54
- def format_entities(parsed, spans_by_label, format_opts)
55
- parsed[:labels].each_with_object({}) do |label, entities|
53
+ def format_entities(parsed, spans_by_label)
54
+ entities = parsed[:labels].each_with_object({}) do |label, entries|
56
55
  spans = spans_by_label.fetch(label)
57
56
  dtype = parsed[:dtypes].fetch(label, :list)
58
57
 
59
- entities[label] = format_entity_value(spans, dtype, format_opts)
58
+ value = format_entity_value(label, spans, dtype)
59
+ next if value.is_a?(Array) && value.empty?
60
+
61
+ entries[label] = value
60
62
  end
63
+
64
+ Gliner::Entities.new(entities)
61
65
  end
62
66
 
63
- def format_entity_value(spans, dtype, format_opts)
67
+ def format_entity_value(label, spans, dtype)
64
68
  if dtype == :str
65
- @span_extractor.format_single_span(@span_extractor.choose_best_span(spans), format_opts)
69
+ @span_extractor.format_single_span(
70
+ @span_extractor.choose_best_span(spans),
71
+ label: label
72
+ )
66
73
  else
67
- @span_extractor.format_spans(spans, format_opts)
74
+ @span_extractor.format_spans(spans, label: label)
68
75
  end
69
76
  end
70
77
  end
@@ -2,7 +2,7 @@
2
2
 
3
3
  module Gliner
4
4
  module Tasks
5
- class JsonExtraction < Task
5
+ class Json < Task
6
6
  def initialize(config_parser:, inference:, input_builder:, span_extractor:, structured_extractor:)
7
7
  super(config_parser: config_parser, inference: inference, input_builder: input_builder)
8
8
 
@@ -47,9 +47,7 @@ module Gliner
47
47
  def process_output(logits, parsed, prepared, options)
48
48
  spans_by_label = extract_spans(logits, parsed, prepared, options)
49
49
  filtered_spans = @structured_extractor.apply_choice_filters(spans_by_label, parsed[:parsed_fields])
50
- format_opts = FormatOptions.from(options)
51
-
52
- @structured_extractor.build_structure_instances(parsed[:parsed_fields], filtered_spans, format_opts)
50
+ @structured_extractor.build_structure_instances(parsed[:parsed_fields], filtered_spans, options)
53
51
  end
54
52
 
55
53
  def execute_all(pipeline, text, structures_config, **options)
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Gliner
4
- VERSION = '0.2.3'
4
+ VERSION = '0.3.0'
5
5
  end
data/lib/gliner.rb CHANGED
@@ -1,5 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'delegate'
3
4
  require 'fileutils'
4
5
  require 'httpx'
5
6
  require 'gliner/version'
@@ -31,20 +32,25 @@ module Gliner
31
32
  :text_len
32
33
  )
33
34
 
34
- Span = Data.define(:text, :score, :start, :end) do
35
- def overlaps?(other)
36
- !(self.end <= other.start || start >= other.end)
35
+ Entity = Data.define(:index, :offsets, :text, :name, :probability) do
36
+ def to_s = text.to_s
37
+ def to_str = text.to_s
38
+ end
39
+
40
+ class Entities < SimpleDelegator
41
+ def list
42
+ __getobj__.values.flat_map { |value| Array(value) }
37
43
  end
38
44
  end
39
45
 
40
- FormatOptions = Data.define(:include_confidence, :include_spans) do
41
- def self.from(input)
42
- return input if input.is_a?(FormatOptions)
46
+ Label = Data.define(:label, :probability) do
47
+ def to_s = label.to_s
48
+ def to_str = label.to_s
49
+ end
43
50
 
44
- new(
45
- include_confidence: input.fetch(:include_confidence, false),
46
- include_spans: input.fetch(:include_spans, false)
47
- )
51
+ Span = Data.define(:text, :score, :start, :end) do
52
+ def overlaps?(other)
53
+ !(self.end <= other.start || start >= other.end)
48
54
  end
49
55
  end
50
56
 
@@ -80,7 +86,7 @@ module Gliner
80
86
  end
81
87
 
82
88
  def classify
83
- Runners::ClassificationRunner
89
+ Runners::Classification
84
90
  end
85
91
 
86
92
  def model!
@@ -111,9 +117,9 @@ module Gliner
111
117
  end
112
118
 
113
119
  def runner_for(config)
114
- return Runners::StructuredRunner if structured_config?(config)
120
+ return Runners::Structure if structured_config?(config)
115
121
 
116
- Runners::EntityRunner
122
+ Runners::Entity
117
123
  end
118
124
 
119
125
  def structured_config?(config)
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gliner
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.3
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - elcuervo