bm25f 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (3) hide show
  1. checksums.yaml +7 -0
  2. data/lib/bm25f.rb +141 -0
  3. metadata +60 -0
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: bc3b5ddbf5d62479a1a0afafccef1d377db2a149ac9db87f8d207b0c16b41550
4
+ data.tar.gz: 295be306f71a84ae399cffaf7c1faff3a768ad7e1c00095230144d564cc9af59
5
+ SHA512:
6
+ metadata.gz: 1c06abf6c6d53a66e151378c610bf5ecfbe92ced01692bbe126d7bfda77bd85a2c2733934460052a6d8129e61881e067a1e35264f7acb197ff51f2336bb31e19
7
+ data.tar.gz: 85ad4ba68f49009bbe667ccf993d498671ced59f9aeff4394435778c53369037f7dee346b7111c3cc6679fc72b9c8e62798139d3dce7a56e098795844e59b85e
data/lib/bm25f.rb ADDED
@@ -0,0 +1,141 @@
1
+ require 'treat'
2
+
3
class BM25F
  include Treat::Core::DSL

  # Scores documents made of named text fields against a query, using a
  # BM25-style saturation formula applied per field and summed with a
  # per-field weight.
  #
  # NOTE(review): the "IDF" used here is computed per field (from the share of
  # documents whose field is non-empty), not per term, and lengths are measured
  # in characters of the stemmed text. Both are kept as in the original
  # implementation so existing rankings only change where a bug was fixed.

  # Initializes a BM25F model.
  #
  # @param term_freq_weight [Float] Term-frequency saturation weight (plays
  #   the role of k1 in the scoring formula).
  # @param doc_length_weight [Float] Strength of the length normalization
  #   (plays the role of b in the scoring formula).
  def initialize(term_freq_weight: 1.33, doc_length_weight: 0.8)
    @term_freq_weight = term_freq_weight
    @doc_length_weight = doc_length_weight
  end

  # Fits the model to a set of documents.
  #
  # Neither +documents+ nor +field_weights+ is modified; the model keeps its
  # own preprocessed copies.
  #
  # @param documents [Array<Hash>] One Hash per document, mapping field names
  #   to String values. +nil+ values are treated as empty text.
  # @param field_weights [Hash] A weight for each field of the documents.
  #   Fields without an entry get a weight of 1.
  # @return [Hash] The per-field IDF values computed for the corpus.
  # @raise [ArgumentError] If +documents+ is not a collection of Hashes.
  def fit(documents, field_weights = {})
    documents = preprocess_documents(documents)

    # Work on a copy so the caller's hash is not silently extended.
    weights = field_weights.dup
    documents.flat_map(&:keys).uniq.each do |key|
      weights[key] = 1 unless weights.key?(key)
    end

    @field_weights = weights
    @documents = documents
    @avg_doc_length = calculate_average_document_length(documents)
    @doc_lengths = calculate_document_lengths(documents)
    @term_counts = calculate_term_counts(documents)
    @total_docs = documents.length
    @idf = calculate_idf
  end

  # Calculates the score of each document using the query.
  #
  # @param query [String] The query to score with.
  # @return [Hash] A hash mapping document IDs (positions in the Array passed
  #   to #fit) to their scores.
  # @raise [RuntimeError] If the model has not been fitted yet.
  def score(query)
    raise 'BM25F#fit must be called before BM25F#score' unless @documents

    query_terms = preprocess_query(query)
    (0...@total_docs).to_h do |doc_id|
      [doc_id, calculate_document_score(doc_id, query_terms)]
    end
  end

  private

  # Preprocesses documents by tokenizing and stemming every String field.
  #
  # Returns new Hashes instead of editing the input. Each String value is
  # replaced by its stems joined with single spaces, +nil+ becomes an empty
  # String, and any other value is passed through untouched.
  #
  # @param documents [Array<Hash>] The documents to preprocess.
  # @return [Array<Hash>] The preprocessed documents.
  # @raise [ArgumentError] If any element of +documents+ is not a Hash.
  def preprocess_documents(documents)
    raise ArgumentError, 'documents must be an Array of Hashes' unless documents.all?(Hash)

    documents.map do |doc|
      doc.transform_values do |value|
        case value
        when String then stem_terms(value).join(' ')
        when nil then ''
        else value
        end
      end
    end
  end

  # Tokenizes and stems a piece of text. Documents and queries both go
  # through this method so that their terms are comparable.
  #
  # @param text [String] The text to process.
  # @return [Array<String>] The stemmed terms; empty for blank text.
  def stem_terms(text)
    # Blank text has no terms; skip building an entity for it.
    return [] if text.to_s.strip.empty?

    sentence(text).tokenize.map(&:stem)
  end

  # Calculates the average document length, where a document's length is the
  # sum of the lengths of all of its field values.
  #
  # @param documents [Array<Hash>] The preprocessed documents.
  # @return [Float] The average document length (0.0 for an empty corpus).
  def calculate_average_document_length(documents)
    return 0.0 if documents.empty?

    total_length = documents.sum { |doc| doc.values.sum(&:length) }
    total_length / documents.length.to_f
  end

  # Calculates the lengths of each field in every document.
  #
  # @param documents [Array<Hash>] The preprocessed documents.
  # @return [Hash] Document ID => { field => length }.
  def calculate_document_lengths(documents)
    documents.each_with_index.to_h do |doc, doc_id|
      [doc_id, doc.transform_values(&:length)]
    end
  end

  # Counts how often each whitespace-separated term occurs in each field.
  #
  # @param documents [Array<Hash>] The preprocessed documents.
  # @return [Array<Hash>] Indexed by document ID: { field => { term => count } }.
  def calculate_term_counts(documents)
    documents.map do |doc|
      doc.transform_values { |text| text.split.tally }
    end
  end

  # Calculates the IDF for each weighted field, based on the number of
  # documents in which that field is present and non-empty.
  #
  # @return [Hash] A hash of IDF values for each field.
  def calculate_idf
    @field_weights.keys.to_h do |field|
      field_doc_count = @documents.count { |doc| doc.key?(field) && !doc[field].empty? }
      ratio = (@total_docs - field_doc_count + 0.5) / (field_doc_count + 0.5)
      [field, Math.log(ratio + 1.0)]
    end
  end

  # Preprocesses a query by tokenizing and stemming it.
  #
  # @param query [String] The query to preprocess.
  # @return [Array<String>] An array of preprocessed query terms.
  def preprocess_query(query)
    stem_terms(query)
  end

  # Calculates the score of a document using an array of query terms.
  #
  # @param doc_id [Integer] The document ID.
  # @param query_terms [Array<String>] The query terms.
  # @return [Float] The document score.
  def calculate_document_score(doc_id, query_terms)
    doc_score = 0.0
    @field_weights.each do |field, weight|
      # Neither value depends on the term, so compute them once per field.
      idf = @idf[field]
      length_norm = field_length_norm(field, doc_id)

      query_terms.each do |term|
        tf = field_term_frequency(field, term, doc_id)
        # An absent term contributes nothing; skipping it also avoids 0/0
        # when the normalization factor is zero.
        next if tf.zero?

        saturation = (tf * (@term_freq_weight + 1)) / (tf + @term_freq_weight * length_norm)
        doc_score += weight * (saturation * idf)
      end
    end
    doc_score
  end

  # Calculates the term frequency in a field of a document. Only whole terms
  # are counted; a term occurring inside a longer word does not match.
  #
  # @param field [Symbol] The field name.
  # @param term [String] The term to calculate frequency for.
  # @param doc_id [Integer] The document ID.
  # @return [Integer] The term frequency (0 if the document lacks the field).
  def field_term_frequency(field, term, doc_id)
    @term_counts.dig(doc_id, field, term) || 0
  end

  # Calculates the field length normalization factor of a document.
  #
  # @param field [Symbol] The field name.
  # @param doc_id [Integer] The document ID.
  # @return [Float] The field length normalization factor.
  def field_length_norm(field, doc_id)
    field_length = @doc_lengths[doc_id].fetch(field, 0)
    # With an all-empty corpus the average is zero; treat the relative
    # length as zero rather than dividing by it.
    relative_length = @avg_doc_length.zero? ? 0.0 : field_length / @avg_doc_length
    1.0 - @doc_length_weight + @doc_length_weight * relative_length
  end
end
metadata ADDED
@@ -0,0 +1,60 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: bm25f
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - catflip
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2023-09-09 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: treat
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - "~>"
18
+ - !ruby/object:Gem::Version
19
+ version: '2.1'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - "~>"
25
+ - !ruby/object:Gem::Version
26
+ version: '2.1'
27
+ description: A fast implementation of the BM25F ranking algorithm for information
28
+ retrieval systems, written in Ruby.
29
+ email:
30
+ executables: []
31
+ extensions: []
32
+ extra_rdoc_files: []
33
+ files:
34
+ - lib/bm25f.rb
35
+ homepage: https://github.com/catflip/bm25f-ruby
36
+ licenses:
37
+ - AGPL-3.0
38
+ metadata:
39
+ homepage_uri: https://github.com/catflip/bm25f-ruby
40
+ source_code_uri: https://github.com/catflip/bm25f-ruby
41
+ post_install_message:
42
+ rdoc_options: []
43
+ require_paths:
44
+ - lib
45
+ required_ruby_version: !ruby/object:Gem::Requirement
46
+ requirements:
47
+ - - ">="
48
+ - !ruby/object:Gem::Version
49
+ version: 3.0.0
50
+ required_rubygems_version: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: '0'
55
+ requirements: []
56
+ rubygems_version: 3.3.26
57
+ signing_key:
58
+ specification_version: 4
59
+ summary: BM25F ranking function in Ruby.
60
+ test_files: []