ebm4subjects 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,57 @@
+ import re
+
+ import nltk.data
+
+
+ class EbmAnalyzer:
+     """
+     A class for tokenizing text using NLTK.
+
+     Attributes:
+         tokenizer (nltk.tokenize.TokenizerI): The loaded NLTK tokenizer.
+
+     Methods:
+     - tokenize_sentences: Tokenizes the input text into sentences
+
+     Raises:
+         LookupError: If the specified tokenizer is not found.
+     """
+
+     def __init__(self, tokenizer_name: str) -> None:
+         """
+         Initializes the EbmAnalyzer with the specified tokenizer.
+
+         Args:
+             tokenizer_name (str): The name of the NLTK tokenizer to use.
+
+         Raises:
+             LookupError: If the specified tokenizer is not found.
+         """
+         # Attempt to find the tokenizer
+         try:
+             nltk.data.find(tokenizer_name)
+         # If the tokenizer is not found, try to download it
+         except LookupError as error:
+             if tokenizer_name in str(error):
+                 nltk.download(tokenizer_name)
+             else:
+                 raise
+
+         # Load the tokenizer
+         self.tokenizer = nltk.data.load(tokenizer_name)
+
+     def tokenize_sentences(self, text: str) -> list[str]:
+         """
+         Tokenizes the input text into sentences using the loaded tokenizer.
+
+         Args:
+             text (str): The input text to tokenize.
+
+         Returns:
+             list[str]: A list of tokenized sentences.
+         """
+         # Replace runs of four or more periods with a single one.
+         # Necessary to work properly with some tables of contents.
+         text = re.sub(r"\.{4,}", ". ", str(text))
+         # Tokenize the text and return the sentences
+         return self.tokenizer.tokenize(text)
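
A minimal usage sketch for the tokenizer above. The resource identifier is an assumption (the classic Punkt pickle that nltk.data.load can resolve), not something the package fixes, and it is assumed to be installed locally already:

# Hypothetical usage; "tokenizers/punkt/english.pickle" is an assumed NLTK resource name.
from ebm4subjects.analyzer import EbmAnalyzer

analyzer = EbmAnalyzer("tokenizers/punkt/english.pickle")
sentences = analyzer.tokenize_sentences(
    "Contents........ 1 This is the first sentence. This is the second one."
)
# Each element of `sentences` is one sentence; the dotted leader is collapsed first.

Collapsing dotted leaders before tokenizing keeps table-of-contents lines from confusing the sentence splitter.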
@@ -0,0 +1,173 @@
+ from concurrent.futures import ProcessPoolExecutor
+ from math import ceil
+ from typing import Tuple
+
+ import polars as pl
+
+ from ebm4subjects.analyzer import EbmAnalyzer
+
+
+ class Chunker:
+     """
+     A class for chunking text into smaller sections based on various criteria.
+
+     The Chunker class takes a tokenizer name and optional maximum chunk size,
+     maximum number of chunks, and maximum number of sentences as input.
+     It uses these parameters to chunk text into smaller sections.
+
+     Attributes:
+         tokenizer (EbmAnalyzer): The tokenizer used for tokenizing sentences.
+         max_chunks (int): The maximum number of chunks to generate.
+         max_chunk_size (int): The maximum size of each chunk in characters.
+         max_sentences (int): The maximum number of sentences to consider.
+
+     Methods:
+     - chunk_text: Chunks a given text into smaller sections
+     - chunk_batches: Chunks a list of texts into smaller sections in parallel
+     """
+
+     def __init__(
+         self,
+         tokenizer_name: str,
+         max_chunks: int | None,
+         max_chunk_size: int | None,
+         max_sentences: int | None,
+     ):
+         """
+         Initializes the Chunker.
+
+         Args:
+             tokenizer_name (str): The name of the tokenizer to use.
+             max_chunks (int | None): The maximum number of chunks to generate.
+             max_chunk_size (int | None): The maximum size of each chunk in characters.
+             max_sentences (int | None): The maximum number of sentences to consider.
+         """
+         self.max_chunks = max_chunks if max_chunks else float("inf")
+         self.max_chunk_size = max_chunk_size if max_chunk_size else float("inf")
+         self.max_sentences = max_sentences if max_sentences else None  # None = no limit
+
+         self.tokenizer = EbmAnalyzer(tokenizer_name)
+
+     def chunk_text(self, text: str) -> list[str]:
+         """
+         Chunks a given text into smaller sections based on the maximum chunk size.
+
+         Args:
+             text (str): The text to be chunked.
+
+         Returns:
+             list[str]: A list of chunked text sections.
+         """
+         # Initialize an empty list to store the chunks
+         chunks = []
+
+         # Tokenize the text into sentences and truncate to max_sentences
+         sentences = self.tokenizer.tokenize_sentences(text)
+         sentences = sentences[: self.max_sentences]
+
+         # Initialize an empty list to store the current chunk
+         current_chunk = []
+
+         # Iterate over the sentences
+         for sentence in sentences:
+             # If the current chunk is not full, add the sentence to it
+             if len(" ".join(current_chunk)) < self.max_chunk_size:
+                 current_chunk.append(sentence)
+             # Otherwise, add the current chunk to the list of chunks
+             # and start a new chunk
+             else:
+                 chunks.append(" ".join(current_chunk))
+                 current_chunk = [sentence]
+                 if len(chunks) == self.max_chunks:
+                     break
+
+         # Add any remaining sentences as a final chunk if the limit allows
+         if current_chunk and len(chunks) < self.max_chunks:
+             chunks.append(" ".join(current_chunk))
+
+         # Return the chunked text
+         return chunks
+ return chunks
90
+
91
+ def chunk_batches(
92
+ self, texts: list[str], doc_ids: list[str], chunking_jobs: int
93
+ ) -> Tuple[list[str], list[pl.DataFrame]]:
94
+ """
95
+ Chunks a list of texts into smaller sections in parallel
96
+ using multiple processes.
97
+
98
+ Args:
99
+ texts (list[str]): A list of texts to be chunked.
100
+ doc_ids (list[str]): A list of document IDs corresponding to the texts.
101
+ chunking_jobs (int): The number of processes to use for chunking.
102
+
103
+ Returns:
104
+ tuple[list[str], list[pl.DataFrame]]: A tuple containing the list of
105
+ chunked text sections and the list of chunk indices.
106
+ """
107
+ # Initialize an empty lists to store the chunks and chunk indices
108
+ text_chunks = []
109
+ chunk_index = []
110
+
111
+ # Calculate the batch size for each process
112
+ chunking_batch_size = ceil(len(texts) / chunking_jobs)
113
+ # Split the texts and document IDs into batches
114
+ batch_args = [
115
+ (
116
+ doc_ids[i * chunking_batch_size : (i + 1) * chunking_batch_size],
117
+ texts[i * chunking_batch_size : (i + 1) * chunking_batch_size],
118
+ )
119
+ for i in range(chunking_jobs)
120
+ ]
121
+
122
+ # Use ProcessPoolExecutor to chunk the batches in parallel
123
+ with ProcessPoolExecutor(max_workers=chunking_jobs) as executor:
124
+ results = list(executor.map(self._chunk_batch, batch_args))
125
+
126
+ # Flatten the results into a single list of chunked text sections
127
+ # and a single list of chunk indices
128
+ for batch_chunks, batch_chunk_indices in results:
129
+ text_chunks.extend(batch_chunks)
130
+ chunk_index.extend(batch_chunk_indices)
131
+
132
+ # Return the chunked texts and coressponding chunk indices
133
+ return text_chunks, chunk_index
134
+
135
+ def _chunk_batch(self, args) -> Tuple[list[str], list[pl.DataFrame]]:
136
+ """
137
+ Chunks a batch of texts into smaller sections.
138
+
139
+ Args:
140
+ args (tuple[list[str], list[str]]): A tuple containing the batch
141
+ of document IDs and the batch of texts.
142
+
143
+ Returns:
144
+ tuple[list[str], list[pl.DataFrame]]: A tuple containing the list
145
+ of chunked text sections and the list of chunk indices.
146
+ """
147
+ batch_doc_ids, batch_texts = args
148
+
149
+ # Initialize empty lists to store the chunks and chunk indices
150
+ batch_chunks = []
151
+ batch_chunk_indices = []
152
+
153
+ # Iterate over the texts in the batch
154
+ for doc_id, text in zip(batch_doc_ids, batch_texts):
155
+ # Chunk the text into smaller sections
156
+ new_chunks = self.chunk_text(text)
157
+ n_chunks = len(new_chunks)
158
+
159
+ # Create a DataFrame to store the chunk indices
160
+ chunk_df = pl.DataFrame(
161
+ {
162
+ "query_doc_id": [doc_id] * n_chunks,
163
+ "chunk_position": list(range(n_chunks)),
164
+ "n_chunks": [n_chunks] * n_chunks,
165
+ }
166
+ )
167
+
168
+ # Add the chunked text sections and chunk indices to the lists
169
+ batch_chunks.extend(new_chunks)
170
+ batch_chunk_indices.append(chunk_df)
171
+
172
+ # Return the chunked texts and the list of chunk indices
173
+ return batch_chunks, batch_chunk_indices
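
A usage sketch for the Chunker, assuming the module path ebm4subjects.chunker and the same NLTK resource name as above; the limit values are illustrative, not package defaults:

# Hypothetical example; module path, tokenizer name and limits are assumptions.
from ebm4subjects.chunker import Chunker

chunker = Chunker(
    tokenizer_name="tokenizers/punkt/english.pickle",
    max_chunks=10,        # at most 10 chunks per document
    max_chunk_size=500,   # roughly 500 characters per chunk
    max_sentences=200,    # look at the first 200 sentences only
)

texts = ["First document. It has two sentences.", "Second document."]
doc_ids = ["doc-1", "doc-2"]

# Returns the flat list of chunk strings plus one polars DataFrame per document
# with the columns query_doc_id, chunk_position and n_chunks.
chunks, chunk_index = chunker.chunk_batches(texts, doc_ids, chunking_jobs=2)

Because chunk_batches spawns worker processes via ProcessPoolExecutor, the call should sit under an if __name__ == "__main__": guard when run as a script.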
@@ -0,0 +1,329 @@
+ from threading import Thread
+
+ import duckdb
+ import polars as pl
+
+
+ class Duckdb_client:
+     """
+     A class for interacting with a DuckDB database,
+     specifically for creating and querying vector search indexes.
+
+     Attributes:
+         connection (duckdb.Connection): The connection to the DuckDB database.
+         hnsw_index_params (dict): Parameters for the HNSW index,
+             including the maximum number of neighbors per node (M),
+             the size of the candidate list during construction (ef_construction),
+             and the size of the candidate list during search (ef_search).
+     Methods:
+     - create_collection: Creates a new collection in the DuckDB database
+     - vector_search: Performs a vector search on the specified collection
+     """
+
+     def __init__(
+         self,
+         db_path: str,
+         config: dict = {"hnsw_enable_experimental_persistence": True, "threads": 1},
+         hnsw_index_params: dict = {"M": 32, "ef_construction": 256, "ef_search": 256},
+     ) -> None:
+         """
+         Initializes the Duckdb_client.
+
+         Args:
+             db_path (str): The path to the DuckDB database.
+             config (dict, optional): Configuration options for the DuckDB connection
+                 (default: {"hnsw_enable_experimental_persistence": True, "threads": 1}).
+             hnsw_index_params (dict, optional): Parameters for the HNSW index
+                 (default: {"M": 32, "ef_construction": 256, "ef_search": 256}).
+
+         Notes:
+             'hnsw_enable_experimental_persistence' needs to be set to 'True' in order
+             to store and query the index later.
+         """
+         # Establish a connection to the DuckDB database
+         self.connection = duckdb.connect(
+             database=db_path,
+             config=config,
+         )
+
+         # Install and load the vss extension for DuckDB
+         self.connection.install_extension("vss")
+         self.connection.load_extension("vss")
+         self.hnsw_index_params = hnsw_index_params
+
+     def create_collection(
+         self,
+         collection_df: pl.DataFrame,
+         collection_name: str = "my_collection",
+         embedding_dimensions: int = 1024,
+         hnsw_index_name: str = "hnsw_index",
+         hnsw_metric: str = "cosine",
+         force: bool = False,
+     ):
+         """
+         Creates a new collection in the DuckDB database and indexes it
+         using the HNSW algorithm.
+
+         Args:
+             collection_df (pl.DataFrame): The data to be inserted into the collection.
+             collection_name (str, optional): The name of the collection
+                 (default: "my_collection").
+             embedding_dimensions (int, optional): The number of dimensions for the
+                 vector embeddings (default: 1024).
+             hnsw_index_name (str, optional): The name of the HNSW index
+                 (default: "hnsw_index").
+             hnsw_metric (str, optional): The metric to be used for the HNSW index
+                 (default: "cosine").
+             force (bool, optional): Whether to replace the existing collection if it
+                 already exists (default: False).
+
+         Notes:
+             If 'hnsw_metric' is changed in this function, 'hnsw_metric_function' in
+             the vector_search function needs to be changed accordingly in order
+             for the index to work properly.
+         """
+         # Determine whether to replace the existing collection
+         replace = ""
+         if force:
+             replace = "OR REPLACE "
+
+         # Create the collection table
+         self.connection.execute(
+             f"""CREATE {replace}TABLE {collection_name} (
+                 id INTEGER,
+                 label_id VARCHAR,
+                 label_text VARCHAR,
+                 is_prefLabel BOOLEAN,
+                 embeddings FLOAT[{embedding_dimensions}])"""
+         )
+
+         # Insert the data into the collection table
+         self.connection.execute(
+             f"INSERT INTO {collection_name} BY NAME SELECT * FROM collection_df"
+         )
+
+         # Create the HNSW index
+         if force:
+             # Drop the existing index if it exists
+             self.connection.execute(f"DROP INDEX IF EXISTS {hnsw_index_name}")
+         self.connection.execute(
+             f"""CREATE INDEX IF NOT EXISTS {hnsw_index_name}
+                 ON {collection_name}
+                 USING HNSW (embeddings)
+                 WITH (metric = '{hnsw_metric}', M = {self.hnsw_index_params["M"]}, ef_construction = {self.hnsw_index_params["ef_construction"]})"""
+         )
+
+     def vector_search(
+         self,
+         query_df: pl.DataFrame,
+         collection_name: str,
+         embedding_dimensions: int,
+         n_jobs: int = 1,
+         n_hits: int = 100,
+         chunk_size: int = 2048,
+         top_k: int = 10,
+         hnsw_metric_function: str = "array_cosine_distance",
+     ) -> pl.DataFrame:
+         """
+         Performs a vector search on the specified collection using the HNSW index.
+
+         Args:
+             query_df (pl.DataFrame): The data to be searched against the collection.
+             collection_name (str): The name of the collection to search against.
+             embedding_dimensions (int): The number of dimensions for the
+                 vector embeddings.
+             n_jobs (int, optional): The number of jobs to use for
+                 parallel processing (default: 1).
+             n_hits (int, optional): The number of hits to keep per query chunk
+                 (default: 100).
+             chunk_size (int, optional): The number of query rows processed per
+                 batch during parallel processing (default: 2048).
+             top_k (int, optional): The number of top-k suggestions to return
+                 per document (default: 10).
+             hnsw_metric_function (str, optional): The metric function to use for
+                 the HNSW index (default: "array_cosine_distance").
+
+         Returns:
+             pl.DataFrame: The result of the vector search.
+
+         Notes:
+             If 'hnsw_metric_function' is changed in this function, 'hnsw_metric' in
+             the create_collection function needs to be changed accordingly in order
+             for the index to work properly.
+             The argument 'chunk_size' is already set to the optimal value for the
+             query processing with DuckDB. Only change it if necessary.
+         """
+         # Create a table to store the search results
+         self.connection.execute("""CREATE OR REPLACE TABLE results (
+             id INTEGER,
+             doc_id VARCHAR,
+             chunk_position INTEGER,
+             n_chunks INTEGER,
+             label_id VARCHAR,
+             is_prefLabel BOOLEAN,
+             score FLOAT)""")
+
+         # Split the query data into chunks for parallel processing
+         query_dfs = [
+             query_df.slice(i, chunk_size) for i in range(0, query_df.height, chunk_size)
+         ]
+         batches = [query_dfs[i : i + n_jobs] for i in range(0, len(query_dfs), n_jobs)]
+
+         # Perform the vector search in parallel
+         for batch in batches:
+             threads = []
+             for df in batch:
+                 threads.append(
+                     Thread(
+                         target=self._vss_thread_query,
+                         args=(
+                             df,
+                             collection_name,
+                             embedding_dimensions,
+                             hnsw_metric_function,
+                             n_hits,
+                         ),
+                     )
+                 )
+
+             for thread in threads:
+                 thread.start()
+             for thread in threads:
+                 thread.join()
+
+         # Retrieve the search results
+         result_df = self.connection.execute("SELECT * FROM results").pl()
+
+         # Apply MinMax scaling to the 'score' column per 'id'
+         # and keep n_hits results
+         result_df = (
+             result_df.group_by("id")
+             .agg(
+                 doc_id=pl.col("doc_id").first(),
+                 chunk_position=pl.col("chunk_position").first(),
+                 n_chunks=pl.col("n_chunks").first(),
+                 label_id=pl.col("label_id"),
+                 is_prefLabel=pl.col("is_prefLabel"),
+                 cosine_similarity=pl.col("score"),
+                 max_score=pl.col("score").max(),
+                 min_score=pl.col("score").min(),
+                 score=pl.col("score"),
+             )
+             .explode(["label_id", "is_prefLabel", "cosine_similarity", "score"])
+             .with_columns(
+                 [
+                     (
+                         (pl.col("score") - pl.col("min_score"))
+                         / (pl.col("max_score") - pl.col("min_score") + 1e-9)
+                     ).alias("score")
+                 ]
+             )
+             .drop("min_score", "max_score")
+             .sort("score", descending=True)
+             .group_by("id")
+             .head(n_hits)
+         )
+
+         # If a label is hit more than once due to altLabels,
+         # keep only the best hit
+         result_df = (
+             result_df.sort("score", descending=True)
+             .group_by(["id", "label_id", "doc_id"])
+             .head(1)
+         )
+
+         # Across chunks (queries), aggregate statistics for
+         # each tuple ('doc_id', 'label_id')
+         result_df = result_df.group_by(["doc_id", "label_id"]).agg(
+             score=pl.col("score").sum(),
+             occurrences=pl.col("doc_id").count(),
+             min_cosine_similarity=pl.col("cosine_similarity").min(),
+             max_cosine_similarity=pl.col("cosine_similarity").max(),
+             first_occurence=pl.col("chunk_position").min(),
+             last_occurence=pl.col("chunk_position").max(),
+             spread=(pl.col("chunk_position").max() - pl.col("chunk_position").min()),
+             is_prefLabel=pl.col("is_prefLabel").first(),
+             n_chunks=pl.col("n_chunks").first(),
+         )
+
+         # Keep only top_k suggestions per document
+         result_df = (
+             result_df.sort("score", descending=True).group_by("doc_id").head(top_k)
+         )
+
+         # Scale the results by the number of chunks and return them
+         return result_df.with_columns(
+             (pl.col("score") / pl.col("n_chunks")),
+             (pl.col("occurrences") / pl.col("n_chunks")),
+             (pl.col("first_occurence") / pl.col("n_chunks")),
+             (pl.col("last_occurence") / pl.col("n_chunks")),
+             (pl.col("spread") / pl.col("n_chunks")),
+         ).sort(["doc_id", "label_id"])
+
+     def _vss_thread_query(
+         self,
+         queries_df: pl.DataFrame,
+         collection_name: str,
+         vector_dimensions: int,
+         hnsw_metric_function: str = "array_cosine_distance",
+         limit: int = 100,
+     ):
+         """
+         A helper function for performing the vector search in parallel.
+
+         Args:
+             queries_df (pl.DataFrame): The data to be searched against the collection.
+             collection_name (str): The name of the collection to search against.
+             vector_dimensions (int): The number of dimensions for the
+                 vector embeddings.
+             hnsw_metric_function (str, optional): The metric function to use for the
+                 HNSW index (default: "array_cosine_distance").
+             limit (int, optional): The number of hits to return per query
+                 (default: 100).
+         """
+         # Create a separate cursor for this thread
+         thread_connection = self.connection.cursor()
+
+         # Create a temporary table to hold this thread's queries
+         thread_connection.execute(
+             f"""CREATE OR REPLACE TEMP TABLE queries (
+                 query_id INTEGER,
+                 query_doc_id VARCHAR,
+                 chunk_position INTEGER,
+                 n_chunks INTEGER,
+                 embeddings FLOAT[{vector_dimensions}])"""
+         )
+
+         # Insert the data into the temporary table
+         thread_connection.execute(
+             "INSERT INTO queries BY NAME SELECT * FROM queries_df"
+         )
+
+         # Apply oversearch to reduce sensitivity in the MinMax scaling
+         if limit < 100:
+             limit = 100
+
+         # Set the HNSW index parameters for search
+         thread_connection.execute(
+             f"SET hnsw_ef_search = {self.hnsw_index_params['ef_search']}"
+         )
+
+         # Perform the vector search
+         thread_connection.execute(
+             f"""INSERT INTO results
+             SELECT queries.query_id,
+                 queries.query_doc_id,
+                 queries.chunk_position,
+                 queries.n_chunks,
+                 label_id,
+                 is_prefLabel,
+                 (1 - intermed_score) AS score,
+             FROM queries, LATERAL (
+                 SELECT {collection_name}.label_id,
+                     {collection_name}.is_prefLabel,
+                     {hnsw_metric_function}(queries.embeddings, {collection_name}.embeddings) AS intermed_score
+                 FROM {collection_name}
+                 ORDER BY intermed_score
+                 LIMIT {limit}
+             )"""
+         )
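
To show how the pieces fit together, here is a rough end-to-end sketch: index a small label collection and search it with per-chunk query embeddings. The module path for Duckdb_client, the database path, and the random vectors standing in for real embeddings are all assumptions; the package's actual embedding step is not part of this diff.

# Hypothetical wiring; module path, file path and random vectors are assumptions.
import random

import polars as pl

from ebm4subjects.duckdb_client import Duckdb_client  # assumed module path

dim = 4  # toy dimensionality; create_collection defaults to 1024


def rand_vec(n: int) -> list[float]:
    return [random.random() for _ in range(n)]


# Vocabulary side: one row per label form, matching the collection table schema.
collection_df = pl.DataFrame(
    {
        "id": [0, 1],
        "label_id": ["label-1", "label-2"],
        "label_text": ["Economics", "Economic policy"],
        "is_prefLabel": [True, True],
        "embeddings": [rand_vec(dim), rand_vec(dim)],
    }
).with_columns(pl.col("embeddings").cast(pl.Array(pl.Float32, dim)))

# Document side: one row per text chunk, matching the temporary 'queries' schema.
query_df = pl.DataFrame(
    {
        "query_id": [0, 1],
        "query_doc_id": ["doc-1", "doc-1"],
        "chunk_position": [0, 1],
        "n_chunks": [2, 2],
        "embeddings": [rand_vec(dim), rand_vec(dim)],
    }
).with_columns(pl.col("embeddings").cast(pl.Array(pl.Float32, dim)))

client = Duckdb_client("vocab.duckdb")
client.create_collection(collection_df, embedding_dimensions=dim, force=True)
result_df = client.vector_search(
    query_df,
    collection_name="my_collection",
    embedding_dimensions=dim,
    top_k=5,
)

The returned frame has one row per (doc_id, label_id) pair, with the chunk-normalized score, occurrence count, cosine-similarity range, and chunk-position statistics aggregated across chunks.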