ragit-0.1-py3-none-any.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- ragit/__init__.py +2 -0
- ragit/main.py +384 -0
- ragit-0.1.dist-info/METADATA +10 -0
- ragit-0.1.dist-info/RECORD +6 -0
- ragit-0.1.dist-info/WHEEL +5 -0
- ragit-0.1.dist-info/top_level.txt +1 -0
ragit/__init__.py
ADDED
ragit/main.py
ADDED
@@ -0,0 +1,384 @@
+import chromadb
+import pandas as pd
+import logging
+from sentence_transformers import SentenceTransformer
+from typing import List, Dict, Optional, Union
+import os
+
+
+class VectorDBManager:
+    def __init__(
+        self,
+        persist_directory: str = "./vector_db",
+        provider: str = "sentence_transformer",
+        model_name: str = "all-mpnet-base-v2",
+    ):
+        """
+        Initialize the Vector Database Manager.
+
+        Args:
+            persist_directory (str): Directory to persist the database
+        """
+        self.persist_directory = persist_directory
+        self.client = chromadb.PersistentClient(path=persist_directory)
+        if provider == "sentence_transformer":
+            self.model = SentenceTransformer(model_name)
+
+        logging.basicConfig(
+            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+        )
+        self.logger = logging.getLogger(__name__)
+
+    def create_database(
+        self,
+        csv_path: str,
+        collection_name: str,
+        distance_metric: str = "l2",
+        collection_metadata: Dict = None,
+    ) -> bool:
+        """
+        Create a new database from a CSV file.
+
+        Args:
+            csv_path (str): Path to the CSV file containing 'id' and 'text' columns
+            collection_name (str): Name of the collection to create
+            distance_metric (str): Distance metric (l2, cosine, ip)
+            collection_metadata (Dict, optional): Additional metadata for the collection
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+
+            df = pd.read_csv(csv_path)
+
+            if not {"id", "text"}.issubset(df.columns):
+                self.logger.error("CSV must contain 'id' and 'text' columns")
+                return False
+
+            collection_meta = {
+                "hnsw:space": distance_metric,
+                "description": f"Collection created from {csv_path}",
+            }
+
+            if collection_metadata:
+                collection_meta.update(collection_metadata)
+
+            collection = self.client.create_collection(
+                name=collection_name, metadata=collection_meta
+            )
+
+            embeddings = self.model.encode(df["text"].tolist()).tolist()
+
+            collection.add(
+                ids=[str(id_) for id_ in df["id"]],
+                documents=df["text"].tolist(),
+                embeddings=embeddings,
+            )
+
+            self.logger.info(f"Successfully created collection '{collection_name}'")
+            return True
+
+        except Exception as e:
+            self.logger.error(f"Error creating database: {str(e)}")
+            return False
+
+    def add_values_from_csv(
+        self, csv_path: str, collection_name: str
+    ) -> Dict[str, int]:
+        """
+        Add values from CSV file to existing collection, skipping existing IDs.
+
+        Args:
+            csv_path (str): Path to the CSV file
+            collection_name (str): Name of the target collection
+
+        Returns:
+            Dict[str, int]: Statistics about the operation
+        """
+        try:
+
+            df = pd.read_csv(csv_path)
+
+            collection = self.client.get_collection(collection_name)
+
+            existing_ids = set(collection.get()["ids"])
+
+            new_df = df[~df["id"].astype(str).isin(existing_ids)]
+
+            if not new_df.empty:
+
+                embeddings = self.model.encode(new_df["text"].tolist()).tolist()
+
+                collection.add(
+                    ids=[str(id_) for id_ in new_df["id"]],
+                    documents=new_df["text"].tolist(),
+                    embeddings=embeddings,
+                )
+
+            stats = {
+                "total_entries": len(df),
+                "new_entries_added": len(new_df),
+                "skipped_entries": len(df) - len(new_df),
+            }
+
+            self.logger.info(
+                f"Added {stats['new_entries_added']} new entries to '{collection_name}'"
+            )
+            return stats
+
+        except Exception as e:
+            self.logger.error(f"Error adding values from CSV: {str(e)}")
+            return {"error": str(e)}
+
+    def add_single_row(self, id_: str, text: str, collection_name: str) -> bool:
+        """
+        Add a single entry to the collection.
+
+        Args:
+            id_ (str): ID for the new entry
+            text (str): Text content
+            collection_name (str): Target collection name
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            collection = self.client.get_collection(collection_name)
+
+            if str(id_) in collection.get()["ids"]:
+                self.logger.warning(f"ID {id_} already exists in collection")
+                return False
+
+            embedding = self.model.encode([text]).tolist()
+
+            collection.add(ids=[str(id_)], documents=[text], embeddings=embedding)
+
+            self.logger.info(f"Successfully added entry with ID {id_}")
+            return True
+
+        except Exception as e:
+            self.logger.error(f"Error adding single row: {str(e)}")
+            return False
+
+    def delete_entry_by_id(self, id_: str, collection_name: str) -> bool:
+        """
+        Delete an entry by its ID.
+
+        Args:
+            id_ (str): ID of the entry to delete
+            collection_name (str): Collection name
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            collection = self.client.get_collection(collection_name)
+
+            if str(id_) not in collection.get()["ids"]:
+                self.logger.warning(f"ID {id_} not found in collection")
+                return False
+
+            collection.delete(ids=[str(id_)])
+
+            self.logger.info(f"Successfully deleted entry with ID {id_}")
+            return True
+
+        except Exception as e:
+            self.logger.error(f"Error deleting entry: {str(e)}")
+            return False
+
+    def find_nearby_texts(
+        self,
+        text: str,
+        collection_name: str,
+        search_string: Optional[str] = None,
+        k: int = 5,
+    ) -> List[Dict[str, Union[str, float]]]:
+        """
+        Find nearby texts using similarity search with scores.
+
+        Args:
+            text (str): Query text
+            collection_name (str): Collection to search in
+            k (int): Number of results to return
+
+        Returns:
+            List[Dict[str, Union[str, float]]]: List of nearby texts with their IDs and similarity scores
+        """
+        try:
+            collection = self.client.get_collection(collection_name)
+            print("Metadata:", collection.metadata)
+
+            distance_metric = collection.metadata["hnsw:space"]
+
+            query_embedding = self.model.encode([text]).tolist()
+
+            if search_string:
+                results = collection.query(
+                    query_embeddings=query_embedding,
+                    n_results=k,
+                    include=["documents", "distances", "metadatas"],
+                    where_document={"$contains": search_string},
+                )
+            else:
+                results = collection.query(
+                    query_embeddings=query_embedding,
+                    n_results=k,
+                    include=["documents", "distances", "metadatas"],
+                )
+
+            distances = results["distances"][0]
+            if not distances:
+                return []
+
+            similarities = []
+            for dist in distances:
+                if distance_metric == "cosine":
+
+                    similarity = 1 - dist
+                elif distance_metric == "ip":
+
+                    min_dist = min(distances)
+                    max_dist = max(distances)
+                    similarity = (
+                        (dist - min_dist) / (max_dist - min_dist)
+                        if max_dist > min_dist
+                        else 1.0
+                    )
+                elif distance_metric == "l1":
+
+                    max_dist = max(distances)
+                    similarity = 1 - (dist / max_dist) if max_dist > 0 else 1.0
+                elif distance_metric == "l2":
+
+                    max_dist = max(distances)
+                    similarity = 1 - (dist / max_dist) if max_dist > 0 else 1.0
+
+                similarities.append(similarity)
+
+            nearby_texts = [
+                {
+                    "id": id_,
+                    "text": text_,
+                    "similarity": round(similarity * 100, 4),
+                    "raw_distance": dist,
+                    "metric": distance_metric,
+                }
+                for id_, text_, similarity, dist in zip(
+                    results["ids"][0], results["documents"][0], similarities, distances
+                )
+            ]
+
+            return nearby_texts
+
+        except Exception as e:
+            self.logger.error(f"Error finding nearby texts: {str(e)}")
+            return []
+
+    def delete_collection(self, collection_name: str, confirmation: str = "no") -> bool:
+        """
+        Delete an entire collection.
+
+        Args:
+            collection_name (str): Name of collection to delete
+            confirmation (str): Must be 'yes' to proceed
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            if confirmation.lower() != "yes":
+                self.logger.warning("Deletion cancelled - confirmation not provided")
+                return False
+
+            self.client.delete_collection(collection_name)
+            self.logger.info(f"Successfully deleted collection '{collection_name}'")
+            return True
+
+        except Exception as e:
+            self.logger.error(f"Error deleting collection: {str(e)}")
+            return False
+
+    def get_collection_info(self, collection_name: str) -> Dict:
+        """
+        Get information about a collection.
+
+        Args:
+            collection_name (str): Name of the collection
+
+        Returns:
+            Dict: Collection information and statistics
+        """
+        try:
+            collection = self.client.get_collection(collection_name)
+            collection_data = collection.get()
+
+            info = {
+                "name": collection_name,
+                "count": len(collection_data["ids"]),
+                "metadata": collection.metadata,
+            }
+
+            return info
+
+        except Exception as e:
+            self.logger.error(f"Error getting collection info: {str(e)}")
+            return {"error": str(e)}
+
+    def get_by_ids(self, ids: List[str], collection_name: str) -> Dict[str, str]:
+        """
+        Get texts for given IDs in batch.
+
+        Args:
+            ids (List[str]): List of IDs to fetch
+            collection_name (str): Name of the collection
+
+        Returns:
+            Dict[str, str]: Dictionary mapping IDs to their corresponding texts
+        """
+        try:
+            collection = self.client.get_collection(collection_name)
+
+            str_ids = [str(id_) for id_ in ids]
+
+            results = collection.get(ids=str_ids, include=["documents"])
+
+            id_to_text = {
+                id_: text for id_, text in zip(results["ids"], results["documents"])
+            }
+
+            return id_to_text
+
+        except Exception as e:
+            self.logger.error(f"Error getting texts by IDs: {str(e)}")
+            return {}
+
+    def get_by_texts(self, texts: List[str], collection_name: str) -> Dict[str, str]:
+        """
+        Get IDs for given texts in batch.
+        Note: For exact text matching. For similar texts, use find_nearby_texts.
+
+        Args:
+            texts (List[str]): List of texts to fetch
+            collection_name (str): Name of the collection
+
+        Returns:
+            Dict[str, str]: Dictionary mapping texts to their corresponding IDs
+        """
+        try:
+            collection = self.client.get_collection(collection_name)
+
+            all_data = collection.get()
+
+            text_to_id = {
+                text: id_
+                for text, id_ in zip(all_data["documents"], all_data["ids"])
+                if text in texts
+            }
+
+            return text_to_id
+
+        except Exception as e:
+            self.logger.error(f"Error getting IDs by texts: {str(e)}")
+            return {}
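For orientation, here is a minimal usage sketch of the VectorDBManager API added above. It is illustrative only and not part of the package: the CSV file, collection name, and query text are hypothetical. Cosine distance is chosen because find_nearby_texts then scores each result simply as 1 - distance.

# Hypothetical end-to-end use of ragit 0.1; all file and collection
# names below are made up for illustration.
import pandas as pd

from ragit.main import VectorDBManager

# Write a toy corpus in the 'id'/'text' shape that create_database expects.
pd.DataFrame(
    {
        "id": [1, 2, 3],
        "text": [
            "chromadb persists embedding collections",
            "pandas reads tabular CSV data",
            "sentence-transformers encodes text into vectors",
        ],
    }
).to_csv("corpus.csv", index=False)

manager = VectorDBManager(persist_directory="./vector_db")
manager.create_database("corpus.csv", "docs", distance_metric="cosine")

# Top-2 semantic neighbours of a free-text query.
for hit in manager.find_nearby_texts("loading csv files", "docs", k=2):
    print(hit["id"], hit["similarity"], hit["text"])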
ragit-0.1.dist-info/METADATA
ADDED
@@ -0,0 +1,10 @@
+Metadata-Version: 2.2
+Name: ragit
+Version: 0.1
+Requires-Dist: sentence-transformers>=3.4.1
+Requires-Dist: pandas>=2.2.3
+Requires-Dist: chromadb>=0.6.3
+Requires-Dist: setuptools>=75.8.0
+Requires-Dist: wheel>=0.45.1
+Requires-Dist: twine>=6.1.0
+Dynamic: requires-dist
ragit-0.1.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+ragit/__init__.py,sha256=GECJxYFL_0PMy6tbcVFpW9Fhe1JiI2uXH4iJWhUHpKs,48
+ragit/main.py,sha256=L5jId4Cwrd99tV10QzokvF3XWarOj3vPEOWLzPrTMN0,12721
+ragit-0.1.dist-info/METADATA,sha256=DygiOq0ZZ8aFK7R1Xl-ije86Zgbu9UaZ3eFpnNLabYU,265
+ragit-0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ragit-0.1.dist-info/top_level.txt,sha256=pkPbG7yrw61wt9_y_xcLE2vq2a55fzockASD0yq0g4s,6
+ragit-0.1.dist-info/RECORD,,
ragit-0.1.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+ragit