ragxo 0.1.3__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragxo/__init__.py +1 -1
- ragxo/client.py +271 -0
- {ragxo-0.1.3.dist-info → ragxo-0.1.6.dist-info}/METADATA +79 -67
- ragxo-0.1.6.dist-info/RECORD +5 -0
- ragxo/ragx.py +0 -145
- ragxo-0.1.3.dist-info/RECORD +0 -5
- {ragxo-0.1.3.dist-info → ragxo-0.1.6.dist-info}/WHEEL +0 -0
ragxo/__init__.py
CHANGED
ragxo/client.py
ADDED
@@ -0,0 +1,271 @@
|
|
1
|
+
import time
|
2
|
+
from typing import Self, Callable
|
3
|
+
from pymilvus import MilvusClient
|
4
|
+
from pydantic import BaseModel
|
5
|
+
import boto3
|
6
|
+
import dill
|
7
|
+
import os
|
8
|
+
import shutil
|
9
|
+
import logging
|
10
|
+
import tempfile
|
11
|
+
from botocore.exceptions import ClientError
|
12
|
+
import openai
|
13
|
+
from openai import ChatCompletion
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
class Document(BaseModel):
    """One piece of content to be indexed by ``Ragxo.index``.

    ``Ragxo.index`` stores all three fields in the vector collection next to
    the embedding computed from ``text``.
    """

    # Raw text; the embedding is computed from its preprocessed form, but
    # this original value is what gets stored.
    text: str
    # Free-form metadata stored alongside the text.
    metadata: dict
    # Caller-supplied integer identifier, written to the row's "id" field.
    id: int
|
21
|
+
|
22
|
+
class Ragxo:
|
23
|
+
def __init__(self, dimension: int) -> None:
|
24
|
+
self.dimension = dimension
|
25
|
+
self.collection_name = "ragx"
|
26
|
+
os.makedirs("ragx_artifacts", exist_ok=True)
|
27
|
+
|
28
|
+
self.db_path = f"ragx_artifacts/milvus_{int(time.time())}.db"
|
29
|
+
self.client = MilvusClient(self.db_path)
|
30
|
+
self.client.create_collection(self.collection_name, dimension=dimension)
|
31
|
+
self.processing_fn = []
|
32
|
+
self.embedding_fn = None
|
33
|
+
self.system_prompt = None
|
34
|
+
self.model = "gpt-4o-mini"
|
35
|
+
|
36
|
+
def add_preprocess(self, fn: Callable) -> Self:
|
37
|
+
self.processing_fn.append(fn)
|
38
|
+
return self
|
39
|
+
|
40
|
+
def add_llm_response_fn(self, fn: Callable) -> Self:
|
41
|
+
self.llm_response_fn = fn
|
42
|
+
return self
|
43
|
+
|
44
|
+
def add_embedding_fn(self, fn: Callable) -> Self:
|
45
|
+
if not fn:
|
46
|
+
raise ValueError("Embedding function cannot be None")
|
47
|
+
self.embedding_fn = fn
|
48
|
+
return self
|
49
|
+
|
50
|
+
def add_system_prompt(self, prompt: str) -> Self:
|
51
|
+
self.system_prompt = prompt
|
52
|
+
return self
|
53
|
+
|
54
|
+
def add_model(self, model: str) -> Self:
|
55
|
+
self.model = model
|
56
|
+
return self
|
57
|
+
|
58
|
+
def index(self, data: list[Document]) -> Self:
|
59
|
+
if not self.embedding_fn:
|
60
|
+
raise ValueError("Embedding function not set")
|
61
|
+
|
62
|
+
processed_text = []
|
63
|
+
for item in data:
|
64
|
+
current_text = item.text
|
65
|
+
for fn in self.processing_fn:
|
66
|
+
current_text = fn(current_text)
|
67
|
+
processed_text.append(current_text)
|
68
|
+
|
69
|
+
embeddings = [
|
70
|
+
self.embedding_fn(text)
|
71
|
+
for text in processed_text
|
72
|
+
]
|
73
|
+
|
74
|
+
self.client.insert(self.collection_name, [
|
75
|
+
{
|
76
|
+
"text": item.text,
|
77
|
+
"metadata": item.metadata,
|
78
|
+
"id": item.id,
|
79
|
+
"vector": embedding
|
80
|
+
}
|
81
|
+
for item, embedding in zip(data, embeddings)
|
82
|
+
])
|
83
|
+
return self
|
84
|
+
|
85
|
+
def query(self, query: str, output_fields: list[str] = ['text', 'metadata'], limit: int = 10) -> list[list[dict]]:
|
86
|
+
if not self.embedding_fn:
|
87
|
+
raise ValueError("Embedding function not set. Please call add_embedding_fn first.")
|
88
|
+
|
89
|
+
preprocessed_query = query
|
90
|
+
for fn in self.processing_fn:
|
91
|
+
preprocessed_query = fn(preprocessed_query)
|
92
|
+
|
93
|
+
embedding = self.embedding_fn(preprocessed_query)
|
94
|
+
|
95
|
+
return self.client.search(
|
96
|
+
collection_name=self.collection_name,
|
97
|
+
data=[embedding],
|
98
|
+
limit=limit,
|
99
|
+
output_fields=output_fields
|
100
|
+
)
|
101
|
+
|
102
|
+
def export(self, destination: str, s3_bucket: str = None) -> Self:
|
103
|
+
"""
|
104
|
+
Export the Ragx instance to either local filesystem or S3.
|
105
|
+
|
106
|
+
Args:
|
107
|
+
destination: str - Local path or S3 key prefix
|
108
|
+
s3_bucket: str, optional - S3 bucket name. If provided, export to S3
|
109
|
+
"""
|
110
|
+
try:
|
111
|
+
# If s3_bucket is provided, export to S3
|
112
|
+
if s3_bucket:
|
113
|
+
return self._export_to_s3(destination, s3_bucket)
|
114
|
+
|
115
|
+
# Otherwise, export to local filesystem
|
116
|
+
os.makedirs(destination, exist_ok=True)
|
117
|
+
|
118
|
+
# Save using dill
|
119
|
+
pickle_path = os.path.join(destination, "ragx.pkl")
|
120
|
+
with open(pickle_path, "wb") as f:
|
121
|
+
dill.dump(self, f)
|
122
|
+
|
123
|
+
# Copy database
|
124
|
+
db_dest = os.path.join(destination, "milvus.db")
|
125
|
+
shutil.copy(self.db_path, db_dest)
|
126
|
+
|
127
|
+
return self
|
128
|
+
|
129
|
+
except Exception as e:
|
130
|
+
logger.error(f"Error in export: {e}")
|
131
|
+
raise
|
132
|
+
|
133
|
+
def _export_to_s3(self, prefix: str, bucket: str) -> Self:
|
134
|
+
"""
|
135
|
+
Internal method to handle S3 export.
|
136
|
+
"""
|
137
|
+
try:
|
138
|
+
s3_client = boto3.client('s3')
|
139
|
+
|
140
|
+
# Create a temporary directory for the files
|
141
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
142
|
+
# Save pickle file
|
143
|
+
pickle_path = os.path.join(temp_dir, "ragx.pkl")
|
144
|
+
with open(pickle_path, "wb") as f:
|
145
|
+
dill.dump(self, f)
|
146
|
+
|
147
|
+
# Copy database
|
148
|
+
db_path = os.path.join(temp_dir, "milvus.db")
|
149
|
+
shutil.copy(self.db_path, db_path)
|
150
|
+
|
151
|
+
# Upload to S3
|
152
|
+
s3_client.upload_file(
|
153
|
+
pickle_path,
|
154
|
+
bucket,
|
155
|
+
f"{prefix}/ragx.pkl"
|
156
|
+
)
|
157
|
+
s3_client.upload_file(
|
158
|
+
db_path,
|
159
|
+
bucket,
|
160
|
+
f"{prefix}/milvus.db"
|
161
|
+
)
|
162
|
+
|
163
|
+
return self
|
164
|
+
|
165
|
+
except ClientError as e:
|
166
|
+
logger.error(f"Error uploading to S3: {e}")
|
167
|
+
raise
|
168
|
+
except Exception as e:
|
169
|
+
logger.error(f"Error in S3 export: {e}")
|
170
|
+
raise
|
171
|
+
|
172
|
+
@classmethod
|
173
|
+
def load(cls, source: str, s3_bucket: str = None) -> Self:
|
174
|
+
"""
|
175
|
+
Load a Ragx instance from either local filesystem or S3.
|
176
|
+
|
177
|
+
Args:
|
178
|
+
source: str - Local path or S3 key prefix
|
179
|
+
s3_bucket: str, optional - S3 bucket name. If provided, load from S3
|
180
|
+
"""
|
181
|
+
try:
|
182
|
+
# If s3_bucket is provided, load from S3
|
183
|
+
if s3_bucket:
|
184
|
+
return cls._load_from_s3(source, s3_bucket)
|
185
|
+
|
186
|
+
# Otherwise, load from local filesystem
|
187
|
+
pickle_path = os.path.join(source, "ragx.pkl")
|
188
|
+
|
189
|
+
with open(pickle_path, "rb") as f:
|
190
|
+
instance = dill.load(f)
|
191
|
+
|
192
|
+
# Restore client
|
193
|
+
instance.client = MilvusClient(os.path.join(source, "milvus.db"))
|
194
|
+
|
195
|
+
return instance
|
196
|
+
|
197
|
+
except Exception as e:
|
198
|
+
logger.error(f"Error in load: {e}")
|
199
|
+
raise
|
200
|
+
|
201
|
+
@classmethod
|
202
|
+
def _load_from_s3(cls, prefix: str, bucket: str) -> 'Ragx':
|
203
|
+
"""
|
204
|
+
Internal classmethod to handle S3 loading.
|
205
|
+
"""
|
206
|
+
try:
|
207
|
+
s3_client = boto3.client('s3')
|
208
|
+
|
209
|
+
# Create a temporary directory for the files
|
210
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
211
|
+
# Download files from S3
|
212
|
+
pickle_path = os.path.join(temp_dir, "ragx.pkl")
|
213
|
+
db_path = os.path.join(temp_dir, "milvus.db")
|
214
|
+
|
215
|
+
s3_client.download_file(
|
216
|
+
bucket,
|
217
|
+
f"{prefix}/ragx.pkl",
|
218
|
+
pickle_path
|
219
|
+
)
|
220
|
+
s3_client.download_file(
|
221
|
+
bucket,
|
222
|
+
f"{prefix}/milvus.db",
|
223
|
+
db_path
|
224
|
+
)
|
225
|
+
|
226
|
+
# Load the pickle file
|
227
|
+
with open(pickle_path, "rb") as f:
|
228
|
+
instance = dill.load(f)
|
229
|
+
|
230
|
+
# Restore client with the downloaded database
|
231
|
+
instance.client = MilvusClient(db_path)
|
232
|
+
|
233
|
+
return instance
|
234
|
+
|
235
|
+
except ClientError as e:
|
236
|
+
logger.error(f"Error downloading from S3: {e}")
|
237
|
+
raise
|
238
|
+
except Exception as e:
|
239
|
+
logger.error(f"Error in S3 load: {e}")
|
240
|
+
raise
|
241
|
+
|
242
|
+
def generate_llm_response(self,
|
243
|
+
query: str,
|
244
|
+
limit: int = 10,
|
245
|
+
data: list[dict] = None,
|
246
|
+
temperature: float = 0.5,
|
247
|
+
max_tokens: int = 1000,
|
248
|
+
top_p: float = 1.0,
|
249
|
+
frequency_penalty: float = 0.0,
|
250
|
+
presence_penalty: float = 0.0,
|
251
|
+
) -> ChatCompletion:
|
252
|
+
if data is None:
|
253
|
+
data = self.query(query, limit=limit)[0]
|
254
|
+
|
255
|
+
if not self.system_prompt:
|
256
|
+
raise ValueError("System prompt not set. Please call add_system_prompt first.")
|
257
|
+
|
258
|
+
response = openai.chat.completions.create(
|
259
|
+
model=self.model,
|
260
|
+
messages=[
|
261
|
+
{"role": "system", "content": self.system_prompt},
|
262
|
+
{"role": "user", "content": "query: {} data: {}".format(query, data)}
|
263
|
+
],
|
264
|
+
temperature=temperature,
|
265
|
+
max_tokens=max_tokens,
|
266
|
+
top_p=top_p,
|
267
|
+
frequency_penalty=frequency_penalty,
|
268
|
+
presence_penalty=presence_penalty,
|
269
|
+
)
|
270
|
+
|
271
|
+
return response
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ragxo
|
3
|
-
Version: 0.1.3
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: A RAG (Retrieval-Augmented Generation) toolkit with Milvus integration
|
5
5
|
Home-page: https://github.com/yourusername/ragx
|
6
6
|
License: MIT
|
@@ -17,6 +17,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.13
|
18
18
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
19
19
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
20
|
+
Requires-Dist: boto3 (>=1.36.14,<2.0.0)
|
20
21
|
Requires-Dist: dill (>=0.3.9,<0.4.0)
|
21
22
|
Requires-Dist: milvus (>=2.3.9,<3.0.0)
|
22
23
|
Requires-Dist: openai (>=1.61.1,<2.0.0)
|
@@ -25,7 +26,9 @@ Requires-Dist: pymilvus (>=2.5.4,<3.0.0)
|
|
25
26
|
Project-URL: Repository, https://github.com/yourusername/ragx
|
26
27
|
Description-Content-Type: text/markdown
|
27
28
|
|
28
|
-
# RagXO
|
29
|
+
# RagXO
|
30
|
+
|
31
|
+
Export, version and reuse your RAG pipeline everywhere 🚀
|
29
32
|
|
30
33
|
[![PyPI version](https://badge.fury.io/py/ragxo.svg)](https://badge.fury.io/py/ragxo)
|
31
34
|
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
|
@@ -48,65 +51,70 @@ RagXO extends the capabilities of traditional RAG (Retrieval-Augmented Generatio
|
|
48
51
|
pip install ragxo
|
49
52
|
```
|
50
53
|
|
51
|
-
##
|
54
|
+
## Quickstart 🚀
|
55
|
+
|
56
|
+
### Build a RAG pipeline
|
52
57
|
|
53
58
|
```python
|
54
59
|
from ragxo import Ragxo, Document
|
55
|
-
from openai import OpenAI
|
56
|
-
client = OpenAI()
|
57
60
|
|
58
|
-
def
|
59
|
-
response = client.embeddings.create(
|
60
|
-
input=text,
|
61
|
-
model="text-embedding-ada-002"
|
62
|
-
)
|
63
|
-
return response.data[0].embedding
|
64
|
-
|
65
|
-
def preprocess_text(text: str) -> str:
|
61
|
+
def preprocess_text_lower(text: str) -> str:
|
66
62
|
return text.lower()
|
67
63
|
|
68
|
-
|
69
|
-
|
70
|
-
ragxo.add_preprocess(preprocess_text)
|
71
|
-
ragxo.add_embedding_fn(get_openai_embeddings)
|
64
|
+
def preprocess_text_remove_special_chars(text: str) -> str:
|
65
|
+
return re.sub(r'[^a-zA-Z0-9\s]', '', text)
|
72
66
|
|
73
|
-
|
74
|
-
|
75
|
-
ragxo.add_model("gpt-4o-mini")
|
67
|
+
def get_embeddings(text: str) -> list[float]:
|
68
|
+
return openai.embeddings.create(input=text, model="text-embedding-ada-002").data[0].embedding
|
76
69
|
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
ragxo.index(documents)
|
70
|
+
ragxo_client = Ragxo(dimension=1536)  # text-embedding-ada-002 returns 1536-dimensional vectors
|
71
|
+
|
72
|
+
ragxo_client.add_preprocess(preprocess_text_lower)
|
73
|
+
ragxo_client.add_preprocess(preprocess_text_remove_special_chars)
|
74
|
+
ragxo_client.add_embedding_fn(get_embeddings)
|
75
|
+
|
76
|
+
ragxo_client.add_system_prompt("You are a helpful assistant that can answer questions about the data provided.")
|
77
|
+
ragxo_client.add_model("gpt-4o-mini")
|
86
78
|
|
87
|
-
|
88
|
-
|
79
|
+
ragxo_client.index([
|
80
|
+
Document(text="Capital of France is Paris", metadata={"source": "example"}, id=1),
|
81
|
+
Document(text="Capital of Germany is Berlin", metadata={"source": "example"}, id=2),
|
82
|
+
Document(text="Capital of Italy is Rome", metadata={"source": "example"}, id=3),
|
83
|
+
])
|
89
84
|
|
90
|
-
|
91
|
-
loaded_ragxo = Ragxo.load("my_rag_v1")
|
85
|
+
ragxo_client.export("my_rag_v1.0.0")
|
92
86
|
|
93
|
-
# Query and generate response
|
94
|
-
similar_docs = loaded_ragxo.query("sample query")
|
95
|
-
llm_response = loaded_ragxo.generate_llm_response("What can you tell me about the sample?")
|
96
87
|
```
|
97
88
|
|
89
|
+
|
90
|
+
### Load a RAG pipeline
|
91
|
+
|
92
|
+
```python
|
93
|
+
loaded_ragxo_client = Ragxo.load("my_rag_v1.0.0")
|
94
|
+
|
95
|
+
results = loaded_ragxo_client.query("What is the capital of France?")
|
96
|
+
|
97
|
+
llm_response = loaded_ragxo_client.generate_llm_response(
|
98
|
+
"What is the capital of France?",
|
99
|
+
limit=10,
|
100
|
+
temperature=0.5,
|
101
|
+
max_tokens=1000,
|
102
|
+
top_p=1.0,
|
103
|
+
frequency_penalty=0.0,
|
104
|
+
presence_penalty=0.0)
|
105
|
+
|
106
|
+
```
|
107
|
+
|
108
|
+
|
98
109
|
## Usage Guide 📚
|
99
110
|
|
100
|
-
###
|
111
|
+
### Import
|
101
112
|
|
102
113
|
```python
|
103
|
-
from ragxo import Document
|
114
|
+
from ragxo import Ragxo, Document
|
115
|
+
|
116
|
+
ragxo_client = Ragxo(dimension=768)
|
104
117
|
|
105
|
-
doc = Document(
|
106
|
-
text="Your document content here",
|
107
|
-
metadata={"source": "wiki", "category": "science"},
|
108
|
-
id=1
|
109
|
-
)
|
110
118
|
```
|
111
119
|
|
112
120
|
### Adding Preprocessing Steps
|
@@ -120,8 +128,8 @@ def remove_special_chars(text: str) -> str:
|
|
120
128
|
def lowercase(text: str) -> str:
|
121
129
|
return text.lower()
|
122
130
|
|
123
|
-
|
124
|
-
|
131
|
+
ragxo_client.add_preprocess(remove_special_chars)
|
132
|
+
ragxo_client.add_preprocess(lowercase)
|
125
133
|
```
|
126
134
|
|
127
135
|
### Custom Embedding Functions
|
@@ -150,27 +158,43 @@ def get_openai_embeddings(text: str) -> list[float]:
|
|
150
158
|
ragxo.add_embedding_fn(get_openai_embeddings)
|
151
159
|
```
|
152
160
|
|
161
|
+
|
162
|
+
### Creating Documents
|
163
|
+
|
164
|
+
```python
|
165
|
+
from ragxo import Document
|
166
|
+
|
167
|
+
doc = Document(
|
168
|
+
text="Your document content here",
|
169
|
+
metadata={"source": "wiki", "category": "science"},
|
170
|
+
id=1
|
171
|
+
)
|
172
|
+
|
173
|
+
ragxo_client.index([doc])
|
174
|
+
|
175
|
+
```
|
176
|
+
|
153
177
|
### LLM Configuration
|
154
178
|
|
155
179
|
```python
|
156
180
|
# Set system prompt
|
157
|
-
|
181
|
+
ragxo_client.add_system_prompt("""
|
158
182
|
You are a helpful assistant. Use the provided context to answer questions accurately.
|
159
183
|
If you're unsure about something, please say so.
|
160
184
|
""")
|
161
185
|
|
162
186
|
# Set LLM model
|
163
|
-
|
187
|
+
ragxo_client.add_model("gpt-4")
|
164
188
|
```
|
165
189
|
|
166
190
|
### Export and Load
|
167
191
|
|
168
192
|
```python
|
169
193
|
# Export your RAG pipeline
|
170
|
-
|
194
|
+
ragxo_client.export("rag_pipeline_v1")
|
171
195
|
|
172
196
|
# Load it elsewhere
|
173
|
-
|
197
|
+
loaded_ragxo_client = Ragxo.load("rag_pipeline_v1")
|
174
198
|
```
|
175
199
|
|
176
200
|
## Best Practices 💡
|
@@ -180,27 +204,15 @@ loaded_ragxo = Ragxo.load("rag_pipeline_v1")
|
|
180
204
|
ragxo.export("my_rag_v1.0.0")
|
181
205
|
```
|
182
206
|
|
183
|
-
2. **
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
similar_docs = loaded_ragxo.query("test query")
|
189
|
-
# Test LLM generation
|
190
|
-
llm_response = loaded_ragxo.generate_llm_response("test question")
|
191
|
-
print("Pipeline loaded successfully!")
|
192
|
-
except Exception as e:
|
193
|
-
print(f"Error loading pipeline: {e}")
|
207
|
+
2. **S3**: Use S3 to store your exports
|
208
|
+
|
209
|
+
```shell
|
210
|
+
export AWS_ACCESS_KEY_ID=your_access_key
|
211
|
+
export AWS_SECRET_ACCESS_KEY=your_secret_key
|
194
212
|
```
|
195
213
|
|
196
|
-
3. **Document Your Pipeline Configuration**: Keep track of your setup:
|
197
214
|
```python
|
198
|
-
|
199
|
-
"preprocessing_steps": ["remove_special_chars", "lowercase"],
|
200
|
-
"embedding_model": "all-MiniLM-L6-v2",
|
201
|
-
"llm_model": "gpt-4",
|
202
|
-
"dimension": 384
|
203
|
-
}
|
215
|
+
ragxo_client.export("my_rag_v1.0.0", s3_bucket="my_bucket")
|
204
216
|
```
|
205
217
|
|
206
218
|
## License 📝
|
@@ -0,0 +1,5 @@
|
|
1
|
+
ragxo/__init__.py,sha256=0VVe-z4XkkGQLQIG0hF0Hyf87_RgX0E4T9TRwwTkbmE,68
|
2
|
+
ragxo/client.py,sha256=M4777mj6oPdRIm9TvqIwXoQuJUMc7Ywczlykutd6c70,9068
|
3
|
+
ragxo-0.1.6.dist-info/METADATA,sha256=1W4vJeY0awkXbtM0o3dyhFRHEY8VLHwkXrm55KECbg4,6141
|
4
|
+
ragxo-0.1.6.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
5
|
+
ragxo-0.1.6.dist-info/RECORD,,
|
ragxo/ragx.py
DELETED
@@ -1,145 +0,0 @@
|
|
1
|
-
from typing import Self, Callable
|
2
|
-
from pymilvus import MilvusClient
|
3
|
-
from pydantic import BaseModel
|
4
|
-
import dill
|
5
|
-
import os
|
6
|
-
import shutil
|
7
|
-
import logging
|
8
|
-
import openai
|
9
|
-
from openai import ChatCompletion
|
10
|
-
|
11
|
-
logging.basicConfig(level=logging.DEBUG)
|
12
|
-
logger = logging.getLogger(__name__)
|
13
|
-
|
14
|
-
class Document(BaseModel):
|
15
|
-
text: str
|
16
|
-
metadata: dict
|
17
|
-
id: int
|
18
|
-
|
19
|
-
class Ragxo:
|
20
|
-
def __init__(self, dimension: int) -> None:
|
21
|
-
self.dimension = dimension
|
22
|
-
self.collection_name = "ragx"
|
23
|
-
self.db_path = "milvus.db"
|
24
|
-
self.client = MilvusClient(self.db_path)
|
25
|
-
self.client.create_collection(self.collection_name, dimension=dimension)
|
26
|
-
self.processing_fn = []
|
27
|
-
self.embedding_fn = None
|
28
|
-
self.system_prompt = None
|
29
|
-
self.model = "gpt-4o-mini"
|
30
|
-
|
31
|
-
def add_preprocess(self, fn: Callable) -> Self:
|
32
|
-
self.processing_fn.append(fn)
|
33
|
-
return self
|
34
|
-
|
35
|
-
def add_embedding_fn(self, fn: Callable) -> Self:
|
36
|
-
if not fn:
|
37
|
-
raise ValueError("Embedding function cannot be None")
|
38
|
-
self.embedding_fn = fn
|
39
|
-
return self
|
40
|
-
|
41
|
-
def add_system_prompt(self, prompt: str) -> Self:
|
42
|
-
self.system_prompt = prompt
|
43
|
-
return self
|
44
|
-
|
45
|
-
def add_model(self, model: str) -> Self:
|
46
|
-
self.model = model
|
47
|
-
return self
|
48
|
-
|
49
|
-
def index(self, data: list[Document]) -> Self:
|
50
|
-
if not self.embedding_fn:
|
51
|
-
raise ValueError("Embedding function not set")
|
52
|
-
|
53
|
-
processed_text = []
|
54
|
-
for item in data:
|
55
|
-
current_text = item.text
|
56
|
-
for fn in self.processing_fn:
|
57
|
-
current_text = fn(current_text)
|
58
|
-
processed_text.append(current_text)
|
59
|
-
|
60
|
-
embeddings = [
|
61
|
-
self.embedding_fn(text)
|
62
|
-
for text in processed_text
|
63
|
-
]
|
64
|
-
|
65
|
-
self.client.insert(self.collection_name, [
|
66
|
-
{
|
67
|
-
"text": item.text,
|
68
|
-
"metadata": item.metadata,
|
69
|
-
"id": item.id,
|
70
|
-
"vector": embedding
|
71
|
-
}
|
72
|
-
for item, embedding in zip(data, embeddings)
|
73
|
-
])
|
74
|
-
return self
|
75
|
-
|
76
|
-
def query(self, query: str, output_fields: list[str] = ['text', 'metadata']) -> list[list[dict]]:
|
77
|
-
if not self.embedding_fn:
|
78
|
-
raise ValueError("Embedding function not set. Please call add_embedding_fn first.")
|
79
|
-
|
80
|
-
preprocessed_query = query
|
81
|
-
for fn in self.processing_fn:
|
82
|
-
preprocessed_query = fn(preprocessed_query)
|
83
|
-
|
84
|
-
embedding = self.embedding_fn(preprocessed_query)
|
85
|
-
|
86
|
-
return self.client.search(
|
87
|
-
collection_name=self.collection_name,
|
88
|
-
data=[embedding],
|
89
|
-
limit=10,
|
90
|
-
output_fields=output_fields
|
91
|
-
)
|
92
|
-
|
93
|
-
def export(self, folder_path: str) -> Self:
|
94
|
-
try:
|
95
|
-
os.makedirs(folder_path, exist_ok=True)
|
96
|
-
|
97
|
-
# Save using dill
|
98
|
-
pickle_path = os.path.join(folder_path, "ragx.pkl")
|
99
|
-
with open(pickle_path, "wb") as f:
|
100
|
-
dill.dump(self, f)
|
101
|
-
|
102
|
-
# Copy database
|
103
|
-
db_dest = os.path.join(folder_path, "milvus.db")
|
104
|
-
shutil.copy(self.db_path, db_dest)
|
105
|
-
|
106
|
-
return self
|
107
|
-
|
108
|
-
except Exception as e:
|
109
|
-
logger.error(f"Error in export: {e}")
|
110
|
-
raise
|
111
|
-
|
112
|
-
@classmethod
|
113
|
-
def load(cls, folder_path: str) -> 'Ragx':
|
114
|
-
try:
|
115
|
-
pickle_path = os.path.join(folder_path, "ragx.pkl")
|
116
|
-
|
117
|
-
with open(pickle_path, "rb") as f:
|
118
|
-
instance = dill.load(f)
|
119
|
-
|
120
|
-
# Restore client
|
121
|
-
instance.client = MilvusClient(os.path.join(folder_path, "milvus.db"))
|
122
|
-
|
123
|
-
return instance
|
124
|
-
|
125
|
-
except Exception as e:
|
126
|
-
logger.error(f"Error in load: {e}")
|
127
|
-
raise
|
128
|
-
|
129
|
-
def generate_llm_response(self, query: str, data: list[dict] = None) -> ChatCompletion:
|
130
|
-
|
131
|
-
if data is None:
|
132
|
-
data = self.query(query)[0]
|
133
|
-
|
134
|
-
if not self.system_prompt:
|
135
|
-
raise ValueError("System prompt not set. Please call add_system_prompt first.")
|
136
|
-
|
137
|
-
response = openai.chat.completions.create(
|
138
|
-
model=self.model,
|
139
|
-
messages=[
|
140
|
-
{"role": "system", "content": self.system_prompt},
|
141
|
-
{"role": "user", "content": "query: {} data: {}".format(query, data)}
|
142
|
-
]
|
143
|
-
)
|
144
|
-
|
145
|
-
return response
|
ragxo-0.1.3.dist-info/RECORD
DELETED
@@ -1,5 +0,0 @@
|
|
1
|
-
ragxo/__init__.py,sha256=jI_6iulTUQk9JUDft-jM6NHESpZSmJVPIaVOmd4-jWw,65
|
2
|
-
ragxo/ragx.py,sha256=_HQCTth_iR2rxV9amMyA6qlOpdGji5_-rSDB5WWG2u4,4537
|
3
|
-
ragxo-0.1.3.dist-info/METADATA,sha256=FZmy-PL_SZMf9NuDWcniQUsleZna_GYsz5GLoJRbHcM,5960
|
4
|
-
ragxo-0.1.3.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
5
|
-
ragxo-0.1.3.dist-info/RECORD,,
|
File without changes
|