bm25f 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/lib/bm25f.rb +141 -0
- metadata +60 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: bc3b5ddbf5d62479a1a0afafccef1d377db2a149ac9db87f8d207b0c16b41550
|
4
|
+
data.tar.gz: 295be306f71a84ae399cffaf7c1faff3a768ad7e1c00095230144d564cc9af59
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 1c06abf6c6d53a66e151378c610bf5ecfbe92ced01692bbe126d7bfda77bd85a2c2733934460052a6d8129e61881e067a1e35264f7acb197ff51f2336bb31e19
|
7
|
+
data.tar.gz: 85ad4ba68f49009bbe667ccf993d498671ced59f9aeff4394435778c53369037f7dee346b7111c3cc6679fc72b9c8e62798139d3dce7a56e098795844e59b85e
|
data/lib/bm25f.rb
ADDED
@@ -0,0 +1,141 @@
|
|
1
|
+
require 'treat'
|
2
|
+
|
3
|
+
class BM25F
  include Treat::Core::DSL

  # Initializes a BM25F model.
  #
  # @param term_freq_weight [Float] Term-frequency saturation parameter (k1).
  #   Larger values let repeated occurrences of a term keep adding to the score.
  # @param doc_length_weight [Float] Field-length normalization strength (b),
  #   typically in 0..1: 0 disables length normalization, 1 applies it fully.
  def initialize(term_freq_weight: 1.33, doc_length_weight: 0.8)
    @term_freq_weight = term_freq_weight
    @doc_length_weight = doc_length_weight
  end

  # Fits the model to a collection of documents.
  #
  # Every String field value is tokenized and stemmed with the same pipeline
  # that is applied to queries. Any other value (nil, numbers, ...) contributes
  # no terms. The documents passed in are not modified.
  #
  # @param documents [Array<Hash>] The documents. Each one is a Hash mapping a
  #   field name to its text. A document's ID is its index in this array.
  # @param field_weights [Hash] Weight per field name. Fields that appear in the
  #   documents but not in this hash default to a weight of 1. The hash passed
  #   in is copied, not modified.
  # @return [BM25F] self, so calls can be chained.
  def fit(documents, field_weights = {})
    analyzed = preprocess_documents(documents)

    # Work on a copy so the caller's hash is left untouched.
    @field_weights = field_weights.dup
    analyzed.flat_map(&:keys).uniq.each do |field|
      @field_weights[field] = 1 unless @field_weights.key?(field)
    end

    @total_docs = analyzed.length
    @term_counts = analyzed.map { |doc| doc.transform_values(&:tally) }
    @doc_lengths = calculate_document_lengths(analyzed)
    @avg_field_lengths = calculate_average_field_lengths(@doc_lengths)
    @doc_freqs = calculate_document_frequencies(analyzed)
    self
  end

  # Calculates the score of each fitted document for the query.
  #
  # @param query [String] The query to score with.
  # @return [Hash{Integer => Float}] Document IDs (indices into the array given
  #   to {#fit}) mapped to their scores. Higher means more relevant.
  # @raise [RuntimeError] if called before {#fit}.
  def score(query)
    raise 'BM25F#score called before #fit' if @total_docs.nil?

    query_terms = preprocess_query(query)
    # IDF depends only on the term, so compute it once per distinct term
    # instead of once per (document, term) pair.
    idf = query_terms.uniq.to_h { |term| [term, inverse_document_frequency(term)] }
    (0...@total_docs).to_h { |doc_id| [doc_id, calculate_document_score(doc_id, query_terms, idf)] }
  end

  private

  # Tokenizes and stems a piece of text.
  #
  # Documents and queries both go through this method, so their terms are
  # directly comparable.
  #
  # @param text [String] The text to analyze.
  # @return [Array<String>] The stemmed terms, in order of appearance.
  def analyze(text)
    return [] if text.strip.empty?

    # to_s guarantees plain String terms for counting and hash lookups.
    sentence(text).tokenize.map { |token| token.stem.to_s }
  end

  # Preprocesses documents by tokenizing and stemming their String fields.
  #
  # @param documents [Array<Hash>] The raw documents.
  # @return [Array<Hash{Object => Array<String>}>] New hashes mapping each field
  #   to its terms. Non-String values map to an empty term list.
  def preprocess_documents(documents)
    documents.map do |doc|
      doc.transform_values { |value| value.is_a?(String) ? analyze(value) : [] }
    end
  end

  # Preprocesses a query with the same analysis applied to documents.
  #
  # @param query [String] The query to preprocess. nil is treated as empty.
  # @return [Array<String>] The preprocessed query terms. Duplicates are kept,
  #   so a repeated query term counts once per occurrence.
  def preprocess_query(query)
    analyze(query.to_s)
  end

  # Calculates the length, in terms, of every field of every document.
  #
  # @param documents [Array<Hash{Object => Array<String>}>] Analyzed documents.
  # @return [Array<Hash{Object => Integer}>] Field lengths, indexed by doc ID.
  def calculate_document_lengths(documents)
    documents.map { |doc| doc.transform_values(&:length) }
  end

  # Calculates the average length of each field across the whole collection.
  #
  # BM25F normalizes each field against that field's own average length.
  # Documents lacking a field count as length 0 for it.
  #
  # @param doc_lengths [Array<Hash{Object => Integer}>] Per-document field lengths.
  # @return [Hash{Object => Float}] Average length per field. Empty when there
  #   are no documents.
  def calculate_average_field_lengths(doc_lengths)
    return {} if doc_lengths.empty?

    totals = doc_lengths.each_with_object(Hash.new(0)) do |lengths, sums|
      lengths.each { |field, length| sums[field] += length }
    end
    totals.transform_values { |total| total.fdiv(doc_lengths.length) }
  end

  # Counts, for each term, how many documents contain it in any field.
  #
  # @param documents [Array<Hash{Object => Array<String>}>] Analyzed documents.
  # @return [Hash{String => Integer}] Document frequency per term. Missing
  #   terms read as 0.
  def calculate_document_frequencies(documents)
    documents.each_with_object(Hash.new(0)) do |doc, freqs|
      doc.values.flatten.uniq.each { |term| freqs[term] += 1 }
    end
  end

  # Calculates the inverse document frequency of a term.
  #
  # The "+ 1.0" inside the log keeps the value positive even for terms that
  # occur in more than half of the documents.
  #
  # @param term [String] The term.
  # @return [Float] The IDF value.
  def inverse_document_frequency(term)
    doc_freq = @doc_freqs[term]
    Math.log((@total_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0)
  end

  # Calculates the BM25F score of a document for an array of query terms.
  #
  # @param doc_id [Integer] The document ID.
  # @param query_terms [Array<String>] The query terms.
  # @param idf [Hash{String => Float}] IDF for every distinct query term.
  # @return [Float] The document score.
  def calculate_document_score(doc_id, query_terms, idf)
    query_terms.sum(0.0) do |term|
      tf = weighted_term_frequency(doc_id, term)
      # Absent terms contribute nothing. This also avoids 0/0 when k1 is 0.
      next 0.0 if tf.zero?

      # Saturate once, on the combined multi-field frequency (this is what
      # distinguishes BM25F from summing independent per-field BM25 scores).
      idf.fetch(term) * (tf * (@term_freq_weight + 1)) / (tf + @term_freq_weight)
    end
  end

  # Calculates the BM25F pseudo term frequency of a term in a document.
  #
  # Each field's raw count is divided by that field's length normalization and
  # scaled by the field weight. The results are summed over all fields.
  #
  # @param doc_id [Integer] The document ID.
  # @param term [String] The term.
  # @return [Float] The weighted, length-normalized term frequency.
  def weighted_term_frequency(doc_id, term)
    @field_weights.sum(0.0) do |field, weight|
      tf = field_term_frequency(field, term, doc_id)
      # Skipping absent terms avoids 0/0 when a field is empty and b == 1.
      tf.zero? ? 0.0 : weight * tf / field_length_norm(field, doc_id)
    end
  end

  # Calculates how many times a term occurs, as a whole term, in a field of a
  # document.
  #
  # @param field [Object] The field name.
  # @param term [String] The term to calculate frequency for.
  # @param doc_id [Integer] The document ID.
  # @return [Integer] The term frequency (0 if the field or term is absent).
  def field_term_frequency(field, term, doc_id)
    @term_counts.dig(doc_id, field, term) || 0
  end

  # Calculates the field length normalization factor of a document.
  #
  # @param field [Object] The field name.
  # @param doc_id [Integer] The document ID.
  # @return [Float] (1 - b) + b * (field length / average field length).
  def field_length_norm(field, doc_id)
    average = @avg_field_lengths.fetch(field, 0.0)
    # A field that is empty in every document has no meaningful average, so
    # apply no length normalization to it.
    return 1.0 if average.zero?

    length = @doc_lengths[doc_id].fetch(field, 0)
    1.0 - @doc_length_weight + @doc_length_weight * (length / average)
  end
end
|
metadata
ADDED
@@ -0,0 +1,60 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: bm25f
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- catflip
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2023-09-09 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: treat
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '2.1'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '2.1'
|
27
|
+
description: A fast implementation of the BM25F ranking algorithm for information
|
28
|
+
retrieval systems, written in Ruby.
|
29
|
+
email:
|
30
|
+
executables: []
|
31
|
+
extensions: []
|
32
|
+
extra_rdoc_files: []
|
33
|
+
files:
|
34
|
+
- lib/bm25f.rb
|
35
|
+
homepage: https://github.com/catflip/bm25f-ruby
|
36
|
+
licenses:
|
37
|
+
- AGPL-3.0
|
38
|
+
metadata:
|
39
|
+
homepage_uri: https://github.com/catflip/bm25f-ruby
|
40
|
+
source_code_uri: https://github.com/catflip/bm25f-ruby
|
41
|
+
post_install_message:
|
42
|
+
rdoc_options: []
|
43
|
+
require_paths:
|
44
|
+
- lib
|
45
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
46
|
+
requirements:
|
47
|
+
- - ">="
|
48
|
+
- !ruby/object:Gem::Version
|
49
|
+
version: 3.0.0
|
50
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
requirements: []
|
56
|
+
rubygems_version: 3.3.26
|
57
|
+
signing_key:
|
58
|
+
specification_version: 4
|
59
|
+
summary: BM25F ranking function in Ruby.
|
60
|
+
test_files: []
|