bm25f 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/lib/bm25f.rb +141 -0
- metadata +60 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: bc3b5ddbf5d62479a1a0afafccef1d377db2a149ac9db87f8d207b0c16b41550
|
4
|
+
data.tar.gz: 295be306f71a84ae399cffaf7c1faff3a768ad7e1c00095230144d564cc9af59
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: 1c06abf6c6d53a66e151378c610bf5ecfbe92ced01692bbe126d7bfda77bd85a2c2733934460052a6d8129e61881e067a1e35264f7acb197ff51f2336bb31e19
|
7
|
+
data.tar.gz: 85ad4ba68f49009bbe667ccf993d498671ced59f9aeff4394435778c53369037f7dee346b7111c3cc6679fc72b9c8e62798139d3dce7a56e098795844e59b85e
|
data/lib/bm25f.rb
ADDED
@@ -0,0 +1,141 @@
|
|
1
|
+
require 'treat'
|
2
|
+
|
3
|
+
class BM25F
|
4
|
+
include Treat::Core::DSL
|
5
|
+
|
6
|
+
# Builds a BM25F scorer and stores its two tuning parameters.
#
# @param term_freq_weight [Float] Weight applied to term frequency when scoring.
# @param doc_length_weight [Float] Weight applied to the length normalization.
def initialize(term_freq_weight: 1.33, doc_length_weight: 0.8)
  @term_freq_weight, @doc_length_weight = term_freq_weight, doc_length_weight
end
|
14
|
+
|
15
|
+
# Fits the model to a set of documents.
#
# Stores the preprocessed documents, the per-field weights and the corpus
# statistics (field lengths, average length, document count, IDF) on the
# instance.
#
# @param documents [Array<Hash>] The documents to fit the model to; each
#   document is a hash keyed by field name.
# @param field_weights [Hash] A weight for each key of the documents. Keys
#   that have no entry get a weight of 1. The hash passed in is not modified.
# @return [Hash] The IDF value computed for each field.
def fit(documents, field_weights = {})
  documents = preprocess_documents(documents)

  # Work on a copy so the caller's hash (which may be frozen or reused for
  # another model) is left untouched; explicit weights keep their position
  # and missing fields are appended with a weight of 1.
  weights = field_weights.dup
  documents.flat_map(&:keys).uniq.each do |key|
    weights[key] = 1 unless weights.key?(key)
  end

  @field_weights = weights
  @documents = documents
  @avg_doc_length = calculate_average_document_length(documents)
  @doc_lengths = calculate_document_lengths(documents)
  @total_docs = documents.length
  # Must come last: calculate_idf reads the instance variables set above.
  @idf = calculate_idf
end
|
36
|
+
|
37
|
+
# Scores every fitted document against the query.
#
# @param query [String] The query to score with.
# @return [Hash] Scores keyed by document ID, where the IDs are the integers
#   0...@total_docs.
def score(query)
  terms = preprocess_query(query)
  (0...@total_docs).each_with_object({}) do |doc_id, result|
    result[doc_id] = calculate_document_score(doc_id, terms)
  end
end
|
49
|
+
|
50
|
+
private
|
51
|
+
|
52
|
+
# Preprocesses documents by tokenizing and stemming every String field.
#
# @param documents [Array<Hash>] The documents to preprocess.
# @return [Array<Hash>] New document hashes in which each String value is
#   replaced by its stemmed tokens joined with single spaces. Values that are
#   not Strings are kept as they are. The input documents are not modified.
def preprocess_documents(documents)
  # Iterate documents and then their fields: iterating the outer array with
  # |key, value| would bind value to nil and skip every document.
  documents.map do |doc|
    doc.transform_values do |value|
      next value unless value.is_a?(String)

      # Same tokenize-then-stem pipeline as preprocess_query, so document text
      # and query terms are normalized the same way.
      sentence(value).tokenize.map(&:stem).join(' ')
    end
  end
end
|
63
|
+
|
64
|
+
# Calculates the average document length.
#
# A document's length is the sum of `length` over all of its field values.
#
# @param documents [Array<Hash>] The documents.
# @return [Float] The average document length, or 0.0 when there are no
#   documents.
def calculate_average_document_length(documents)
  # An empty corpus would otherwise evaluate 0 / 0.0, which is NaN.
  return 0.0 if documents.empty?

  total_length = documents.sum { |doc| doc.values.sum(&:length) }
  total_length / documents.length.to_f
end
|
72
|
+
|
73
|
+
# Measures the length of every field of every document.
#
# @param documents [Array<Hash>] The documents.
# @return [Hash] Maps each document's index to a hash of field name => the
#   `length` of that field's value.
def calculate_document_lengths(documents)
  indexed_lengths = documents.each_with_index.map do |doc, index|
    [index, doc.transform_values(&:length)]
  end
  indexed_lengths.to_h
end
|
84
|
+
|
85
|
+
# Calculates the IDF for each field.
#
# For every field in @field_weights, counts the documents in which that field
# is present and non-empty, and turns the count into
# log((N - n + 0.5) / (n + 0.5) + 1) where N is @total_docs.
#
# NOTE(review): the value depends only on how many documents populate the
# field, not on any query term - confirm that a per-field IDF is intended.
#
# @return [Hash] A hash of IDF values for each field.
def calculate_idf
  @field_weights.keys.to_h do |field|
    # A document that lacks the field counts the same as one whose field is
    # empty, instead of raising on nil.
    field_doc_count = @documents.count do |doc|
      value = doc[field]
      !value.nil? && !value.empty?
    end
    [field, Math.log((@total_docs - field_doc_count + 0.5) / (field_doc_count + 0.5) + 1.0)]
  end
end
|
96
|
+
|
97
|
+
# Turns a query string into stemmed terms.
#
# @param query [String] The query to preprocess.
# @return [Array<String>] The stem of each token of the query, in order.
def preprocess_query(query)
  tokens = sentence(query).tokenize
  tokens.map { |token| token.stem }
end
|
104
|
+
|
105
|
+
# Calculates the score of a document using an array of query terms.
#
# For every weighted field and every query term this adds
#   weight * (tf * (k + 1)) / (tf + k * length_norm) * idf
# where k is @term_freq_weight, tf comes from field_term_frequency,
# length_norm from field_length_norm and idf from @idf[field].
#
# @param doc_id [Integer] The document ID.
# @param query_terms [Array<String>] The query terms.
# @return [Float] The document score; 0.0 when no term occurs in any field.
def calculate_document_score(doc_id, query_terms)
  doc_score = 0.0
  @field_weights.each do |field, weight|
    # Both values depend only on the field and the document, so compute them
    # once per field rather than once per query term.
    idf = @idf[field]
    length_norm = field_length_norm(field, doc_id)

    query_terms.each do |term|
      tf = field_term_frequency(field, term, doc_id)
      # A term that does not occur contributes nothing. Skipping it also avoids
      # 0 / 0 (NaN) when the length normalization factor is zero.
      next if tf.zero?

      doc_score += weight * ((tf * (@term_freq_weight + 1)) / (tf + @term_freq_weight * length_norm) * idf)
    end
  end
  doc_score
end
|
122
|
+
|
123
|
+
# Calculates the term frequency in a field of a document.
#
# NOTE(review): occurrences are counted by substring match, so "cat" is also
# counted inside "category" - confirm whether whole-token matching is wanted
# once documents are stored as stemmed tokens.
#
# @param field [Symbol] The field name.
# @param term [String] The term to calculate frequency for.
# @param doc_id [Integer] The document ID.
# @return [Integer] The term frequency; 0 when the document has no such field
#   or the term is empty.
def field_term_frequency(field, term, doc_id)
  text = @documents[doc_id][field]
  # A missing field has no occurrences, and an empty pattern would otherwise
  # match at every position of the text.
  return 0 if text.nil? || term.empty?

  text.scan(term).count
end
|
132
|
+
|
133
|
+
# Calculates the field length normalization factor of a document.
#
# The factor is (1 - b) + b * (field length / @avg_doc_length), where b is
# @doc_length_weight.
#
# @param field [Symbol] The field name.
# @param doc_id [Integer] The document ID.
# @return [Float] The field length normalization factor.
def field_length_norm(field, doc_id)
  # A field the document does not have is treated as having length 0.
  field_length = @doc_lengths[doc_id][field] || 0
  # When the average is zero (every field is empty) or NaN, use a ratio of 0
  # instead of dividing, so the factor never becomes NaN.
  length_ratio = @avg_doc_length.positive? ? field_length / @avg_doc_length : 0.0
  1.0 - @doc_length_weight + @doc_length_weight * length_ratio
end
|
141
|
+
end
|
metadata
ADDED
@@ -0,0 +1,60 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: bm25f
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.1.0
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- catflip
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2023-09-09 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: treat
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - "~>"
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '2.1'
|
20
|
+
type: :runtime
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - "~>"
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '2.1'
|
27
|
+
description: A fast implementation of the BM25F ranking algorithm for information
|
28
|
+
retrieval systems, written in Ruby.
|
29
|
+
email:
|
30
|
+
executables: []
|
31
|
+
extensions: []
|
32
|
+
extra_rdoc_files: []
|
33
|
+
files:
|
34
|
+
- lib/bm25f.rb
|
35
|
+
homepage: https://github.com/catflip/bm25f-ruby
|
36
|
+
licenses:
|
37
|
+
- AGPL-3.0
|
38
|
+
metadata:
|
39
|
+
homepage_uri: https://github.com/catflip/bm25f-ruby
|
40
|
+
source_code_uri: https://github.com/catflip/bm25f-ruby
|
41
|
+
post_install_message:
|
42
|
+
rdoc_options: []
|
43
|
+
require_paths:
|
44
|
+
- lib
|
45
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
46
|
+
requirements:
|
47
|
+
- - ">="
|
48
|
+
- !ruby/object:Gem::Version
|
49
|
+
version: 3.0.0
|
50
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
requirements: []
|
56
|
+
rubygems_version: 3.3.26
|
57
|
+
signing_key:
|
58
|
+
specification_version: 4
|
59
|
+
summary: BM25F ranking function in Ruby.
|
60
|
+
test_files: []
|