ebm4subjects 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ebm4subjects/__init__.py +0 -0
- ebm4subjects/analyzer.py +57 -0
- ebm4subjects/chunker.py +173 -0
- ebm4subjects/duckdb_client.py +329 -0
- ebm4subjects/ebm_logging.py +203 -0
- ebm4subjects/ebm_model.py +715 -0
- ebm4subjects/embedding_generator.py +63 -0
- ebm4subjects/prepare_data.py +82 -0
- ebm4subjects-0.4.1.dist-info/METADATA +134 -0
- ebm4subjects-0.4.1.dist-info/RECORD +12 -0
- ebm4subjects-0.4.1.dist-info/WHEEL +4 -0
- ebm4subjects-0.4.1.dist-info/licenses/LICENSE +287 -0
ebm4subjects/__init__.py
ADDED
|
File without changes
|
ebm4subjects/analyzer.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
import nltk.data
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class EbmAnalyzer:
    """
    Sentence splitter backed by an NLTK tokenizer resource.

    Attributes:
        tokenizer (nltk.tokenize.TokenizerI): The NLTK tokenizer loaded from
            the resource name given at construction time.

    Methods:
        - tokenize_sentences: Splits a text into a list of sentences

    Raises:
        LookupError: If the requested tokenizer resource cannot be found.
    """

    # Runs of four or more periods (e.g. dot leaders in tables of contents)
    # confuse sentence splitting; they are collapsed before tokenizing.
    _DOT_RUN = re.compile(r"\.{4,}")

    def __init__(self, tokenizer_name: str) -> None:
        """
        Locates (downloading if necessary) and loads an NLTK tokenizer.

        Args:
            tokenizer_name (str): Name of the NLTK tokenizer resource; the same
                string is passed to `nltk.data.find`, `nltk.download` and
                `nltk.data.load`.

        Raises:
            LookupError: If the lookup fails for a reason unrelated to
                `tokenizer_name`, or if loading fails after the download.
        """
        try:
            nltk.data.find(tokenizer_name)
        except LookupError as lookup_error:
            # Only a "this resource is missing" error triggers a download;
            # any other lookup failure is propagated unchanged.
            if tokenizer_name not in str(lookup_error):
                raise
            nltk.download(tokenizer_name)

        self.tokenizer = nltk.data.load(tokenizer_name)

    def tokenize_sentences(self, text: str) -> list[str]:
        """
        Splits the given text into sentences with the loaded tokenizer.

        Args:
            text (str): Text to split; non-string values are converted with
                `str()` first.

        Returns:
            list[str]: The sentences produced by the tokenizer.
        """
        normalized = self._DOT_RUN.sub(". ", str(text))
        return self.tokenizer.tokenize(normalized)
|
ebm4subjects/chunker.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
2
|
+
from math import ceil
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
from ebm4subjects.analyzer import EbmAnalyzer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Chunker:
    """
    A class for chunking text into smaller sections based on various criteria.

    The Chunker class takes a tokenizer name and optional maximum chunk size,
    maximum number of chunks, and maximum number of sentences as input.
    It uses these parameters to chunk text into smaller sections.

    Attributes:
        tokenizer (EbmAnalyzer): The tokenizer used for tokenizing sentences.
        max_chunks (int | float): The maximum number of chunks to generate
            (float("inf") when no limit was given).
        max_chunk_size (int | float): Soft maximum size of each chunk in
            characters (float("inf") when no limit was given).
        max_sentences (int | float): The maximum number of sentences to
            consider (float("inf") when no limit was given).

    Methods:
        - chunk_text: Chunks a given text into smaller sections
        - chunk_batches: Chunks a list of texts into smaller sections in parallel
    """

    def __init__(
        self,
        tokenizer_name: str,
        max_chunks: int | None,
        max_chunk_size: int | None,
        max_sentences: int | None,
    ):
        """
        Initializes the Chunker.

        Args:
            tokenizer_name (str): The name of the tokenizer to use.
            max_chunks (int | None): The maximum number of chunks to generate.
            max_chunk_size (int | None): The maximum size of each chunk in characters.
            max_sentences (int | None): The maximum number of sentences to consider.

        Notes:
            A falsy limit (None or 0) means "no limit" and is stored as
            float("inf").
        """
        self.max_chunks = max_chunks if max_chunks else float("inf")
        self.max_chunk_size = max_chunk_size if max_chunk_size else float("inf")
        self.max_sentences = max_sentences if max_sentences else float("inf")

        self.tokenizer = EbmAnalyzer(tokenizer_name)

    def chunk_text(self, text: str) -> list[str]:
        """
        Chunks a given text into smaller sections based on the maximum chunk size.

        Sentences are added to the current chunk while the chunk built so far
        is shorter than `max_chunk_size`, so the limit is soft: a chunk may
        exceed it by its last sentence. At most `max_chunks` chunks are
        returned and only the first `max_sentences` sentences are used.

        Args:
            text (str): The text to be chunked.

        Returns:
            list[str]: A list of chunked text sections (sentences joined by a
                single space). Empty if the text yields no sentences.
        """
        chunks = []

        # Tokenize the text into sentences
        sentences = self.tokenizer.tokenize_sentences(text)
        # The "no limit" value float("inf") is not a valid slice index, so
        # only slice when a finite limit actually cuts the list.
        if len(sentences) > self.max_sentences:
            sentences = sentences[: int(self.max_sentences)]

        current_chunk = []
        # Length of " ".join(current_chunk), maintained incrementally instead
        # of re-joining the whole chunk for every sentence.
        current_length = 0

        for sentence in sentences:
            # If the current chunk is not full, add the sentence to it
            if current_length < self.max_chunk_size:
                # +1 accounts for the joining space once the chunk is non-empty
                current_length += len(sentence) + (1 if current_chunk else 0)
                current_chunk.append(sentence)
            # Otherwise, close the current chunk and start a new one
            else:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sentence]
                current_length = len(sentence)
                # Stop once the maximum number of chunks is reached
                if len(chunks) == self.max_chunks:
                    break

        # Keep the trailing partial chunk unless the chunk limit is reached
        if current_chunk and len(chunks) < self.max_chunks:
            chunks.append(" ".join(current_chunk))

        return chunks

    def chunk_batches(
        self, texts: list[str], doc_ids: list[str], chunking_jobs: int
    ) -> Tuple[list[str], list[pl.DataFrame]]:
        """
        Chunks a list of texts into smaller sections in parallel
        using multiple processes.

        Args:
            texts (list[str]): A list of texts to be chunked.
            doc_ids (list[str]): A list of document IDs corresponding to the texts.
            chunking_jobs (int): The maximum number of processes to use for
                chunking. No more processes than texts are started, and at
                least one is always used.

        Returns:
            tuple[list[str], list[pl.DataFrame]]: A tuple containing the list of
                chunked text sections and the list of chunk indices, both in
                input order.
        """
        text_chunks = []
        chunk_index = []

        # Nothing to do: avoid starting a process pool at all
        if not texts:
            return text_chunks, chunk_index

        # Idle workers only cost process start-up time, so cap the pool size
        n_workers = max(1, min(chunking_jobs, len(texts)))

        # Split the texts and document IDs into contiguous batches; together
        # with the order-preserving executor.map this keeps the input order.
        batch_size = ceil(len(texts) / n_workers)
        batch_args = [
            (
                doc_ids[i * batch_size : (i + 1) * batch_size],
                texts[i * batch_size : (i + 1) * batch_size],
            )
            for i in range(n_workers)
        ]

        # Use ProcessPoolExecutor to chunk the batches in parallel
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            results = list(executor.map(self._chunk_batch, batch_args))

        # Flatten the per-batch results
        for batch_chunks, batch_chunk_indices in results:
            text_chunks.extend(batch_chunks)
            chunk_index.extend(batch_chunk_indices)

        return text_chunks, chunk_index

    def _chunk_batch(self, args) -> Tuple[list[str], list[pl.DataFrame]]:
        """
        Chunks a batch of texts into smaller sections.

        Args:
            args (tuple[list[str], list[str]]): A tuple containing the batch
                of document IDs and the batch of texts.

        Returns:
            tuple[list[str], list[pl.DataFrame]]: A tuple containing the list
                of chunked text sections and one chunk-index DataFrame per
                document (columns: query_doc_id, chunk_position, n_chunks).
        """
        batch_doc_ids, batch_texts = args

        batch_chunks = []
        batch_chunk_indices = []

        for doc_id, text in zip(batch_doc_ids, batch_texts):
            new_chunks = self.chunk_text(text)
            n_chunks = len(new_chunks)

            # One row per chunk, linking it back to its document
            chunk_df = pl.DataFrame(
                {
                    "query_doc_id": [doc_id] * n_chunks,
                    "chunk_position": list(range(n_chunks)),
                    "n_chunks": [n_chunks] * n_chunks,
                }
            )

            batch_chunks.extend(new_chunks)
            batch_chunk_indices.append(chunk_df)

        return batch_chunks, batch_chunk_indices
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
from threading import Thread
|
|
2
|
+
|
|
3
|
+
import duckdb
|
|
4
|
+
import polars as pl
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Duckdb_client:
    """
    A class for interacting with a DuckDB database,
    specifically for creating and querying vector search indexes.

    Attributes:
        connection (duckdb.DuckDBPyConnection): The connection to the DuckDB database.
        hnsw_index_params (dict): Parameters for the HNSW index:
            'M' (maximum number of neighbours kept per graph node),
            'ef_construction' (number of candidates considered while building
            the index) and 'ef_search' (number of candidates considered
            while searching).

    Methods:
        - create_collection: Creates a new collection in the DuckDB database
        - vector_search: Performs a vector search on the specified collection
    """

    # Defaults applied when the caller does not pass a configuration. They
    # are copied per instance so no instance mutates a dict shared with others.
    _DEFAULT_CONFIG = {"hnsw_enable_experimental_persistence": True, "threads": 1}
    _DEFAULT_HNSW_INDEX_PARAMS = {"M": 32, "ef_construction": 256, "ef_search": 256}

    def __init__(
        self,
        db_path: str,
        config: dict | None = None,
        hnsw_index_params: dict | None = None,
    ) -> None:
        """
        Initializes the Duckdb_client.

        Args:
            db_path (str): The path to the DuckDB database.
            config (dict | None, optional): Configuration options for the DuckDB
                connection. None selects
                {"hnsw_enable_experimental_persistence": True, "threads": 1}.
            hnsw_index_params (dict | None, optional): Parameters for the HNSW
                index. Keys that are not given fall back to
                {"M": 32, "ef_construction": 256, "ef_search": 256}.

        Notes:
            'hnsw_enable_experimental_persistence' needs to be set to 'True' in
            order to store and query the index later.
        """
        if config is None:
            config = dict(self._DEFAULT_CONFIG)

        # Establish a connection to the DuckDB database
        self.connection = duckdb.connect(
            database=db_path,
            config=config,
        )

        # Install and load the vss extension for DuckDB
        self.connection.install_extension("vss")
        self.connection.load_extension("vss")

        # Merging onto the defaults avoids a KeyError later for partial dicts
        # and keeps the stored dict private to this instance.
        self.hnsw_index_params = {
            **self._DEFAULT_HNSW_INDEX_PARAMS,
            **(hnsw_index_params or {}),
        }

    def create_collection(
        self,
        collection_df: pl.DataFrame,
        collection_name: str = "my_collection",
        embedding_dimensions: int = 1024,
        hnsw_index_name: str = "hnsw_index",
        hnsw_metric: str = "cosine",
        force: bool = False,
    ):
        """
        Creates a new collection in the DuckDB database and indexes it
        using the HNSW algorithm.

        Args:
            collection_df (pl.DataFrame): The data to be inserted into the
                collection; columns are matched to the table by name (id,
                label_id, label_text, is_prefLabel, embeddings).
            collection_name (str, optional): The name of the collection
                (default: "my_collection").
            embedding_dimensions (int, optional): The number of dimensions for the
                vector embeddings (default: 1024).
            hnsw_index_name (str, optional): The name of the HNSW index
                (default: "hnsw_index").
            hnsw_metric (str, optional): The metric to be used for the HNSW index
                (default: "cosine")
            force (bool, optional): Whether to replace the existing collection if it
                already exists (default: False).

        Notes:
            If 'hnsw_metric' is changed in this function 'hnsw_metric_function' in
            the vector_search function needs to be changed accordingly in order
            for the index to work properly.
            Names are interpolated into the SQL text, so they must come from
            trusted configuration.
        """
        # Determine whether to replace the existing collection
        replace = "OR REPLACE " if force else ""

        # Create the collection table
        self.connection.execute(
            f"""CREATE {replace}TABLE {collection_name} (
            id INTEGER,
            label_id VARCHAR,
            label_text VARCHAR,
            is_prefLabel BOOLEAN,
            embeddings FLOAT[{embedding_dimensions}])"""
        )

        # Insert the data; DuckDB resolves 'collection_df' to the local
        # Polars DataFrame via a replacement scan.
        self.connection.execute(
            f"INSERT INTO {collection_name} BY NAME SELECT * FROM collection_df"
        )

        # Create the HNSW index
        if force:
            # Drop the existing index if it exists
            self.connection.execute(f"DROP INDEX IF EXISTS {hnsw_index_name}")
        self.connection.execute(
            f"""CREATE INDEX IF NOT EXISTS {hnsw_index_name}
            ON {collection_name}
            USING HNSW (embeddings)
            WITH (metric = '{hnsw_metric}', M = {self.hnsw_index_params["M"]}, ef_construction = {self.hnsw_index_params["ef_construction"]})"""
        )

    def vector_search(
        self,
        query_df: pl.DataFrame,
        collection_name: str,
        embedding_dimensions: int,
        n_jobs: int = 1,
        n_hits: int = 100,
        chunk_size: int = 2048,
        top_k: int = 10,
        hnsw_metric_function: str = "array_cosine_distance",
    ) -> pl.DataFrame:
        """
        Performs a vector search on the specified collection using the HNSW index.

        Args:
            query_df (pl.DataFrame): The data to be searched against the collection.
            collection_name (str): The name of the collection to search against.
            embedding_dimensions (int): The number of dimensions for the
                vector embeddings.
            n_jobs (int, optional): The maximum number of worker threads
                (default: 1).
            n_hits (int, optional): The number of hits to return per document
                (default: 100).
            chunk_size (int, optional): The number of query rows handled by one
                worker call (default: 2048).
            top_k (int, optional): The number of top-k suggestions to return
                per document (default: 10).
            hnsw_metric_function (str, optional): The metric function to use for
                the HNSW index (default: "array_cosine_distance").

        Returns:
            pl.DataFrame: One row per (doc_id, label_id) with aggregated scores,
                sorted by doc_id and label_id.

        Raises:
            Exception: Any error raised by a worker query is re-raised here
                instead of silently yielding incomplete results.

        Notes:
            If 'hnsw_metric_function' is changed in this function 'hnsw_metric' in
            the create_collection function needs to be changed accordingly in order
            for the index to work properly.
            The 'cosine_similarity' statistics are computed as 1 - distance and
            are therefore only cosine similarities for a cosine distance metric.
            The argument 'chunk_size' is already set to the optimal value for the
            query processing with DuckDB. Only change it if necessary.
        """
        # Local import keeps this fix self-contained within the class.
        from concurrent.futures import ThreadPoolExecutor

        # (Re)create the results table. It is a regular (non-TEMP) table,
        # presumably so every worker cursor can insert into it; it is
        # overwritten by the next search.
        self.connection.execute("""CREATE OR REPLACE TABLE results (
            id INTEGER,
            doc_id VARCHAR,
            chunk_position INTEGER,
            n_chunks INTEGER,
            label_id VARCHAR,
            is_prefLabel BOOLEAN,
            score FLOAT)""")

        # Split the query data into slices for parallel processing
        query_dfs = [
            query_df.slice(i, chunk_size) for i in range(0, query_df.height, chunk_size)
        ]

        # Run the slices on at most n_jobs threads. Calling result() on every
        # future re-raises worker exceptions, which bare threading.Thread
        # objects would only print, leaving results silently incomplete.
        with ThreadPoolExecutor(max_workers=max(1, n_jobs)) as executor:
            futures = [
                executor.submit(
                    self._vss_thread_query,
                    df,
                    collection_name,
                    embedding_dimensions,
                    hnsw_metric_function,
                    n_hits,
                )
                for df in query_dfs
            ]
            for future in futures:
                future.result()

        # Retrieve the search results
        result_df = self.connection.execute("SELECT * FROM results").pl()

        # Apply MinMax scaling to the 'score' column per 'id'
        # and keep n_hits results
        result_df = (
            result_df.group_by("id")
            .agg(
                doc_id=pl.col("doc_id").first(),
                chunk_position=pl.col("chunk_position").first(),
                n_chunks=pl.col("n_chunks").first(),
                label_id=pl.col("label_id"),
                is_prefLabel=pl.col("is_prefLabel"),
                cosine_similarity=pl.col("score"),
                max_score=pl.col("score").max(),
                min_score=pl.col("score").min(),
                score=pl.col("score"),
            )
            .explode(["label_id", "is_prefLabel", "cosine_similarity", "score"])
            .with_columns(
                [
                    (
                        (pl.col("score") - pl.col("min_score"))
                        / (pl.col("max_score") - pl.col("min_score") + 1e-9)
                    ).alias("score")
                ]
            )
            .drop("min_score", "max_score")
            .sort("score", descending=True)
            .group_by("id")
            .head(n_hits)
        )

        # If a label is hit more than once due to altLabels
        # keep only the best hit
        result_df = (
            result_df.sort("score", descending=True)
            .group_by(["id", "label_id", "doc_id"])
            .head(1)
        )

        # Across chunks (queries) aggregate statistics for
        # each tuple ('doc_id', 'label_id')
        result_df = result_df.group_by(["doc_id", "label_id"]).agg(
            score=pl.col("score").sum(),
            occurrences=pl.col("doc_id").count(),
            min_cosine_similarity=pl.col("cosine_similarity").min(),
            max_cosine_similarity=pl.col("cosine_similarity").max(),
            first_occurence=pl.col("chunk_position").min(),
            last_occurence=pl.col("chunk_position").max(),
            spread=(pl.col("chunk_position").max() - pl.col("chunk_position").min()),
            is_prefLabel=pl.col("is_prefLabel").first(),
            n_chunks=pl.col("n_chunks").first(),
        )

        # Keep only top_k suggestions per document
        result_df = (
            result_df.sort("score", descending=True).group_by("doc_id").head(top_k)
        )

        # Normalize the per-document statistics by the number of chunks
        return result_df.with_columns(
            (pl.col("score") / pl.col("n_chunks")),
            (pl.col("occurrences") / pl.col("n_chunks")),
            (pl.col("first_occurence") / pl.col("n_chunks")),
            (pl.col("last_occurence") / pl.col("n_chunks")),
            (pl.col("spread") / pl.col("n_chunks")),
        ).sort(["doc_id", "label_id"])

    def _vss_thread_query(
        self,
        queries_df: pl.DataFrame,
        collection_name: str,
        vector_dimensions: int,
        hnsw_metric_function: str = "array_cosine_distance",
        limit: int = 100,
    ):
        """
        Runs the vector search for one slice of queries on its own cursor and
        appends the hits to the shared 'results' table.

        Args:
            queries_df (pl.DataFrame): The data to be searched against the collection.
            collection_name (str): The name of the collection to search against.
            vector_dimensions (int): The number of dimensions for the
                vector embeddings.
            hnsw_metric_function (str, optional): The metric function to use for the
                HNSW index (default: "array_cosine_distance").
            limit (int, optional): The number of hits to return per query;
                values below 100 are raised to 100 (default: 100).
        """
        # Each worker gets its own cursor (a duplicate connection); it is
        # closed afterwards so repeated searches do not leak connections.
        thread_connection = self.connection.cursor()
        try:
            # Per-cursor TEMP table holding this slice's queries
            thread_connection.execute(
                f"""CREATE OR REPLACE TEMP TABLE queries (
                query_id INTEGER,
                query_doc_id VARCHAR,
                chunk_position INTEGER,
                n_chunks INTEGER,
                embeddings FLOAT[{vector_dimensions}])"""
            )

            # 'queries_df' is resolved to the local DataFrame by DuckDB
            thread_connection.execute(
                "INSERT INTO queries BY NAME SELECT * FROM queries_df"
            )

            # Apply oversearch to reduce sensitivity in MinMax scaling
            limit = max(limit, 100)

            # Set the HNSW index parameters for search
            thread_connection.execute(
                f"SET hnsw_ef_search = {self.hnsw_index_params['ef_search']}"
            )

            # Perform the vector search; score = 1 - distance
            thread_connection.execute(
                f"""INSERT INTO results
                SELECT queries.query_id,
                    queries.query_doc_id,
                    queries.chunk_position,
                    queries.n_chunks,
                    label_id,
                    is_prefLabel,
                    (1 - intermed_score) AS score,
                FROM queries, LATERAL (
                    SELECT {collection_name}.label_id,
                        {collection_name}.is_prefLabel,
                        {hnsw_metric_function}(queries.embeddings, {collection_name}.embeddings) AS intermed_score
                    FROM {collection_name}
                    ORDER BY intermed_score
                    LIMIT {limit}
                )"""
            )
        finally:
            thread_connection.close()
|