bm25f 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3) hide show
  1. checksums.yaml +7 -0
  2. data/lib/bm25f.rb +141 -0
  3. metadata +60 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: bc3b5ddbf5d62479a1a0afafccef1d377db2a149ac9db87f8d207b0c16b41550
4
+ data.tar.gz: 295be306f71a84ae399cffaf7c1faff3a768ad7e1c00095230144d564cc9af59
5
+ SHA512:
6
+ metadata.gz: 1c06abf6c6d53a66e151378c610bf5ecfbe92ced01692bbe126d7bfda77bd85a2c2733934460052a6d8129e61881e067a1e35264f7acb197ff51f2336bb31e19
7
+ data.tar.gz: 85ad4ba68f49009bbe667ccf993d498671ced59f9aeff4394435778c53369037f7dee346b7111c3cc6679fc72b9c8e62798139d3dce7a56e098795844e59b85e
data/lib/bm25f.rb ADDED
@@ -0,0 +1,141 @@
1
+ require 'treat'
2
+
3
class BM25F
  include Treat::Core::DSL

  # Initializes a BM25F model.
  #
  # @param term_freq_weight [Float] Term-frequency saturation parameter; it
  #   plays the role usually called k1 in the BM25 formula used by
  #   {#calculate_document_score}.
  # @param doc_length_weight [Float] Strength of the length normalisation
  #   (usually called b); 0 disables it, 1 applies it fully.
  def initialize(term_freq_weight: 1.33, doc_length_weight: 0.8)
    @term_freq_weight = term_freq_weight
    @doc_length_weight = doc_length_weight
  end

  # Fits the model to a set of documents.
  #
  # Neither `documents` nor `field_weights` is modified; the model keeps its
  # own preprocessed copies.
  #
  # @param documents [Array<Hash>] The documents to fit the model to. Each
  #   document maps a field name to that field's text.
  # @param field_weights [Hash] A weight for each field (key) of the
  #   documents. Fields without an explicit weight default to 1.
  def fit(documents, field_weights = {})
    documents = preprocess_documents(documents)

    # Copy before filling in defaults so the caller's hash is left untouched
    # (and so a frozen hash can be passed in).
    weights = field_weights.dup
    documents.flat_map(&:keys).uniq.each do |key|
      weights[key] = 1 unless weights.key?(key)
    end

    @field_weights = weights
    @documents = documents
    @total_docs = documents.length
    @doc_lengths = calculate_document_lengths(documents)
    @avg_doc_length = calculate_average_document_length(documents)
    @idf = calculate_idf
  end

  # Calculates the score of each document using the query.
  #
  # @param query [String] The query to score with.
  # @return [Hash{Integer => Float}] Scores keyed by the document's index in
  #   the array that was passed to {#fit}.
  # @raise [RuntimeError] If the model has not been fitted yet.
  def score(query)
    # Without this guard `0...@total_docs` would be the endless range `0...nil`
    # and the failure would surface as an unrelated NoMethodError.
    raise 'BM25F#fit must be called before BM25F#score' unless @documents

    query_terms = preprocess_query(query)
    (0...@total_docs).to_h do |doc_id|
      [doc_id, calculate_document_score(doc_id, query_terms)]
    end
  end

  private

  # Preprocesses documents by tokenizing and stemming every String field.
  # Non-String values are passed through unchanged.
  #
  # @param documents [Array<Hash>] The documents to preprocess.
  # @return [Array<Hash>] New documents whose String fields have been replaced
  #   by their stemmed tokens joined with single spaces.
  def preprocess_documents(documents)
    # `documents` is an Array of Hashes, so each element must be handled as a
    # whole document; destructuring it as `|key, value|` leaves `value` nil and
    # silently skips all stemming.
    documents.map do |doc|
      doc.transform_values do |value|
        value.is_a?(String) ? stem_tokens(value).join(' ') : value
      end
    end
  end

  # Tokenizes a piece of text and stems its words. Shared by document and
  # query preprocessing so that both sides are normalised identically.
  #
  # @param text [String] The text to process.
  # @return [Array<String>] The processed tokens, in order.
  def stem_tokens(text)
    text = text.to_s
    return [] if text.strip.empty?

    sentence(text).tokenize.map do |token|
      # NOTE(review): assumes Treat only offers `stem` on word tokens, so
      # punctuation, numbers, etc. are kept verbatim instead of being stemmed
      # -- confirm against the Treat documentation.
      token.is_a?(Treat::Entities::Word) ? token.stem : token.to_s
    end
  end

  # Calculates the average document length, measured in characters summed
  # over all fields of a document.
  #
  # @param documents [Array<Hash>] The (preprocessed) documents.
  # @return [Float] The average document length; 0.0 for an empty corpus.
  def calculate_average_document_length(documents)
    return 0.0 if documents.empty?

    total_length = documents.sum do |doc|
      doc.values.sum { |value| value.to_s.length }
    end
    total_length / documents.length.to_f
  end

  # Calculates the length (in characters) of each field of each document.
  #
  # @param documents [Array<Hash>] The (preprocessed) documents.
  # @return [Hash{Integer => Hash}] Field lengths keyed by document index.
  def calculate_document_lengths(documents)
    documents.each_with_index.to_h do |doc, index|
      [index, doc.transform_values { |value| value.to_s.length }]
    end
  end

  # Calculates the IDF for each field.
  #
  # NOTE(review): the value depends only on how many documents have a
  # non-empty value for the field, not on any query term, so rare and common
  # terms are weighted alike. Textbook BM25/BM25F uses a per-term document
  # frequency here; left as-is because changing it would alter every score.
  #
  # @return [Hash] A hash of IDF values for each field.
  def calculate_idf
    @field_weights.each_key.to_h do |field|
      # `to_s` lets documents that lack this field count as empty instead of
      # raising on nil.
      field_doc_count = @documents.count { |doc| !doc[field].to_s.empty? }
      [field, Math.log((@total_docs - field_doc_count + 0.5) / (field_doc_count + 0.5) + 1.0)]
    end
  end

  # Preprocesses a query by tokenizing and stemming it.
  #
  # @param query [String] The query to preprocess.
  # @return [Array<String>] An array of preprocessed query terms.
  def preprocess_query(query)
    stem_tokens(query)
  end

  # Calculates the score of a document using an array of query terms.
  #
  # @param doc_id [Integer] The document ID.
  # @param query_terms [Array<String>] The query terms.
  # @return [Float] The document score.
  def calculate_document_score(doc_id, query_terms)
    doc_score = 0.0
    @field_weights.each do |field, weight|
      # Both values depend only on the field and document, not on the term.
      idf = @idf[field]
      length_norm = field_length_norm(field, doc_id)

      query_terms.each do |term|
        tf = field_term_frequency(field, term, doc_id)
        # A term that does not occur contributes nothing; skipping it also
        # avoids 0/0 (NaN) when the denominator below would be zero.
        next if tf.zero?

        doc_score += weight * ((tf * (@term_freq_weight + 1)) / (tf + @term_freq_weight * length_norm) * idf)
      end
    end
    doc_score
  end

  # Calculates the term frequency in a field of a document.
  #
  # Counts whole whitespace-separated tokens, so the term "cat" does not
  # match inside "concatenate". A missing field counts as empty text.
  #
  # @param field [Symbol] The field name.
  # @param term [String] The term to calculate frequency for.
  # @param doc_id [Integer] The document ID.
  # @return [Integer] The term frequency.
  def field_term_frequency(field, term, doc_id)
    @documents[doc_id][field].to_s.split.count(term)
  end

  # Calculates the field length normalization factor of a document.
  #
  # NOTE(review): the field's length is compared against the average length
  # of a whole document (all fields summed), not the average length of this
  # particular field -- confirm this is intended.
  #
  # @param field [Symbol] The field name.
  # @param doc_id [Integer] The document ID.
  # @return [Float] The field length normalization factor.
  def field_length_norm(field, doc_id)
    # With an empty corpus or only empty documents there is nothing to
    # normalise against; return the neutral factor instead of dividing by 0.
    return 1.0 unless @avg_doc_length.positive?

    field_length = @doc_lengths[doc_id].fetch(field, 0)
    1.0 - @doc_length_weight + @doc_length_weight * (field_length / @avg_doc_length)
  end
end
metadata ADDED
@@ -0,0 +1,60 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: bm25f
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - catflip
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2023-09-09 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: treat
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '2.1'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '2.1'
27
+ description: A fast implementation of the BM25F ranking algorithm for information
28
+ retrieval systems, written in Ruby.
29
+ email:
30
+ executables: []
31
+ extensions: []
32
+ extra_rdoc_files: []
33
+ files:
34
+ - lib/bm25f.rb
35
+ homepage: https://github.com/catflip/bm25f-ruby
36
+ licenses:
37
+ - AGPL-3.0
38
+ metadata:
39
+ homepage_uri: https://github.com/catflip/bm25f-ruby
40
+ source_code_uri: https://github.com/catflip/bm25f-ruby
41
+ post_install_message:
42
+ rdoc_options: []
43
+ require_paths:
44
+ - lib
45
+ required_ruby_version: !ruby/object:Gem::Requirement
46
+ requirements:
47
+ - - ">="
48
+ - !ruby/object:Gem::Version
49
+ version: 3.0.0
50
+ required_rubygems_version: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ requirements: []
56
+ rubygems_version: 3.3.26
57
+ signing_key:
58
+ specification_version: 4
59
+ summary: BM25F ranking function in Ruby.
60
+ test_files: []