pirag 0.1.7__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as published.
- app/main.py +47 -65
- app/rag/agent/services.py +11 -0
- app/rag/api.py +40 -0
- app/rag/cli.py +54 -0
- app/rag/config.py +115 -122
- app/rag/embedding/client.py +70 -0
- app/rag/embedding/services.py +26 -0
- app/rag/llm/client.py +128 -0
- app/rag/llm/services.py +26 -0
- app/rag/llm/utilities.py +40 -0
- app/rag/models.py +19 -0
- app/rag/routers.py +41 -0
- app/rag/utilities.py +15 -0
- app/rag/v1/routers.py +7 -0
- app/rag/vector_store/client.py +84 -0
- app/rag/vector_store/services.py +56 -0
- app/requirements.txt +6 -5
- app/setup.py +2 -3
- {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/METADATA +6 -3
- pirag-0.2.2.dist-info/RECORD +27 -0
- {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/WHEEL +1 -1
- app/rag/agent.py +0 -64
- app/rag/ask/__init__.py +0 -2
- app/rag/ask/config.py +0 -9
- app/rag/ask/router.py +0 -4
- app/rag/doctor/__init__.py +0 -2
- app/rag/doctor/config.py +0 -24
- app/rag/doctor/router.py +0 -4
- app/rag/test/__init__.py +0 -2
- app/rag/test/config.py +0 -9
- app/rag/test/router.py +0 -4
- app/rag/train/__init__.py +0 -2
- app/rag/train/config.py +0 -20
- app/rag/train/router.py +0 -4
- app/rag/train/service.py +0 -0
- pirag-0.1.7.dist-info/RECORD +0 -27
- /app/rag/{ask/service.py → test/client.py} +0 -0
- /app/rag/{doctor/service.py → train/client.py} +0 -0
- /app/rag/{test/service.py → v1/services.py} +0 -0
- {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/entry_points.txt +0 -0
- {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/top_level.txt +0 -0
app/main.py
CHANGED
@@ -1,81 +1,63 @@
-import argparse, os
-from dotenv import load_dotenv
 from loguru import logger
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 
-
-
-
-from app.rag.doctor import help as doctor_help, parser as doctor_parser, route as doctor_route
-from app.rag.train import help as train_help, parser as train_parser, route as train_route
-from app.rag.ask import help as ask_help, parser as ask_parser, route as ask_route
-from app.rag.test import help as test_help, parser as test_parser, route as test_route
+import app.rag.config as cfn
+import app.rag.api as api
+import app.rag.cli as cli
 
 # Main parser
-parser =
-    formatter_class =
-    description =
-
-
-
-
-# Commands
-subparsers = parser.add_subparsers(
-    title = 'commands',
-    dest = 'command',
-)
-
-subparsers.add_parser(
-    'doctor',
-    help = doctor_help,
-    description = doctor_parser.description,
-    parents = [top_parser, common_parser, doctor_parser],
-    add_help = False,
-)
-
-subparsers.add_parser(
-    'train',
-    help = train_help,
-    description = train_parser.description,
-    parents = [top_parser, common_parser, train_parser],
-    add_help = False,
-)
-
-subparsers.add_parser(
-    'test',
-    help = test_help,
-    description = test_parser.description,
-    parents = [top_parser, common_parser, test_parser],
+parser = ArgumentParser(
+    formatter_class = ArgumentDefaultsHelpFormatter,
+    description = """
+        Pilot of On-Premise RAG.
+    """,
+    parents = [cfn.top_parser, cfn.common_parser],
     add_help = False,
 )
 
-
-
-    help
-
-
-
-)
+# Command definitions
+commands = {
+    # name: help, description, function, extra_parsers
+    "serve"  : ("Start the RAG server", "Run a FastAPI-based RAG server", api.serve, []),
+    "chat"   : ("Chat with the RAG system", "Run an interactive chat with the RAG system", cli.chat, [cfn.chat_parser]),
+    "train"  : ("Train the RAG system", "Run a pipeline to train the RAG system", cli.train, []),
+    "test"   : ("Test the RAG system", "Run a pipeline to test the RAG system", cli.test, []),
+    "doctor" : ("Diagnose the RAG system", "Run a pipeline to diagnose the RAG system", cli.doctor, [cfn.doctor_parser]),
+}
+
+# Add command parsers
+subparsers = parser.add_subparsers(title="commands", dest="command")
+for name, (help, description, _, extra_parsers) in commands.items():
+    subparsers.add_parser(
+        name = name,
+        help = help,
+        description = description,
+        parents = [cfn.common_parser] + extra_parsers,
+        add_help = False,
+    )
 
 def main():
     args = parser.parse_args()
-    setup_logger(
-        log_level = args.log_level,
-        log_dir = args.log_dir,
-    )
-    command_message = f"with command: {args.command}" if args.command else ""
-    logger.info(f"RAG Started {command_message}")
+    cfn.setup_logger(cfn.LOG_LEVEL, cfn.LOG_SAVE, cfn.LOG_DIR)
     logger.debug(f"Parsed arguments: {args}")
 
-    if args.command
-
-
-
-
-
-
-
+    if command_info := commands.get(args.command):
+        func, extra_parsers = command_info[2], command_info[3]
+
+        # Create parser options dict from extra_parsers
+        extra_options = {}
+        if extra_parsers:
+            for parser_obj in extra_parsers:
+                for action in parser_obj._actions:
+                    if action.dest == 'help':
+                        continue
+                    if hasattr(args, action.dest) and getattr(args, action.dest) != action.default:
+                        extra_options[action.dest] = getattr(args, action.dest)
+
+        # Run the command with the extra parser options
+        func(extra_options)
     else:
         parser.print_help()
 
-if __name__ ==
+if __name__ == "__main__":
     main()
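The rewritten entry point replaces the hand-written per-command subparsers with a `commands` table and a generic dispatch loop. As a rough standalone sketch of that flow (not code from the package: the stub `doctor` function and the parser below only stand in for `cli.doctor` and `cfn.doctor_parser`), parsing `pirag doctor --resolve` collects the options that differ from their parser defaults and hands them to the command function:

    from argparse import ArgumentParser

    def doctor(options: dict):
        print(f"doctor called with {options}")   # stands in for cli.doctor

    doctor_parser = ArgumentParser(add_help=False)          # mimics cfn.doctor_parser
    doctor_parser.add_argument("-r", "--resolve", action="store_true")

    parser = ArgumentParser()
    subparsers = parser.add_subparsers(title="commands", dest="command")
    subparsers.add_parser("doctor", parents=[doctor_parser], add_help=False)

    args = parser.parse_args(["doctor", "--resolve"])

    # Same extraction loop as in app/main.py: keep only options changed from their defaults.
    extra_options = {}
    for action in doctor_parser._actions:
        if action.dest == "help":
            continue
        if hasattr(args, action.dest) and getattr(args, action.dest) != action.default:
            extra_options[action.dest] = getattr(args, action.dest)

    doctor(extra_options)   # -> doctor called with {'resolve': True}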
app/rag/api.py
ADDED
@@ -0,0 +1,40 @@
+import uvicorn
+from fastapi import FastAPI, APIRouter, Request, Depends, HTTPException, Query
+from fastapi.middleware.cors import CORSMiddleware
+
+from loguru import logger
+import app.rag.config as cfn
+
+from app.rag.routers import system_router
+from app.rag.v1.routers import router as v1_router
+
+# Initialize FastAPI app
+api = FastAPI(
+    title = "RAG API",
+    description = "API for Retrieval-Augmented Generation",
+    version = cfn.__version__,
+)
+
+# Add CORS middleware
+api.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Adjust in production
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+api.include_router(router=system_router, prefix="", tags=["System"])
+api.include_router(router=v1_router, prefix="/v1")
+
+def serve(parser_options=None):
+    print("Serving the RAG API...")
+    if parser_options:
+        logger.debug(f"Serve parser options: {parser_options}")
+
+    uvicorn.run(
+        app = "app.rag.api:api",
+        host = cfn.API_HOST,
+        port = cfn.API_PORT,
+        reload = cfn.API_RELOAD,
+    )
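The `serve` entry point hands uvicorn the import string "app.rag.api:api" rather than the app object; uvicorn needs the string form when `reload` is enabled, and `API_RELOAD` defaults to True in the new config. A minimal in-process smoke test, assuming the package and its dependencies are importable (the concrete endpoint paths live in app/rag/routers.py and app/rag/v1/routers.py, which this diff does not reproduce):

    from fastapi.testclient import TestClient  # requires httpx to be installed
    from app.rag.api import api  # the FastAPI instance defined above

    client = TestClient(api)
    # FastAPI serves its generated schema at /openapi.json by default.
    assert client.get("/openapi.json").status_code == 200
    print(sorted(route.path for route in api.routes))  # routes contributed by the included routers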
app/rag/cli.py
ADDED
@@ -0,0 +1,54 @@
+import app.rag.config as cfn
+from loguru import logger
+
+from app.rag.llm.services import doctor as doctor_llm
+from app.rag.embedding.services import doctor as doctor_embedding
+from app.rag.vector_store.services import doctor as doctor_vector_store
+from app.rag.agent.services import chat_only_llm, chat_with_rag
+
+def chat(options: dict):
+    logger.debug(f"Chat parser options: {options}")
+    no_rag = options.get('no_rag', False)
+
+    # -- Chat
+    if no_rag:
+        logger.info("💬 Chatting with the LLM system directly...")
+        chat_only_llm()
+    else:
+        logger.info("💬 Chatting with the RAG system...")
+        chat_with_rag()
+
+
+def train(options: dict):
+    print("Training the RAG system...")
+    logger.debug(f"Train parser options: {options}")
+
+
+def test(options: dict):
+    print("Testing the RAG system...")
+    logger.debug(f"Test parser options: {options}")
+
+
+def doctor(options: dict):
+    logger.info("💚 Doctoring the RAG system...")
+
+    logger.debug(f"Doctor parser options: {options}")
+    # Check if resolve option is present
+    resolve = options.get('resolve', False)
+    if resolve:
+        logger.info("🔧 Resolving issues is enabled")
+
+    # -- LLM Server
+    logger.info("🔍 Checking the LLM server (OpenAI-compatible)...")
+    doctor_llm(resolve)
+
+    # -- Embedding Server
+    logger.info("🔍 Checking the embedding server (OpenAI-compatible)...")
+    doctor_embedding(resolve)
+
+    # -- Vector Store
+    logger.info("🔍 Checking the vector store server (Milvus)...")
+    doctor_vector_store(resolve)
+
+    if resolve:
+        logger.info(f"🔧 Resolving issue completed. To make sure the issues are resolved, please try doctoring again.")
app/rag/config.py
CHANGED
@@ -1,35 +1,84 @@
-import argparse,
-from pathlib import Path
+import argparse, sys, pathlib
 from loguru import logger
+from dynaconf import Dynaconf
+from importlib.metadata import version, PackageNotFoundError
+try:
+    __version__ = version("pirag")
+except PackageNotFoundError:
+    __version__ = "0.0.0"
+
+
+# -- Load configuration
+settings = Dynaconf(
+    settings_files = ["settings.yaml"],
+    envvar_prefix = False,
+    load_dotenv = False,
+)
+
+
+# -- Loging
+LOG_LEVEL: str = settings.get("LOG.LEVEL", "INFO").upper()
+if LOG_LEVEL not in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]:
+    raise ValueError(f"Invalid log level: {LOG_LEVEL}. Must be one of: INFO, DEBUG, WARNING, ERROR, CRITICAL")
+
+LOG_SAVE: bool = settings.get("LOG.SAVE", False)
+LOG_DIR: str = settings.get("LOG.DIR", ".pirag/logs")
 
-# Logger format constants
 LOG_TIME_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS!UTC}Z"
 LOG_FILE_FORMAT = f"{LOG_TIME_FORMAT} | {{level: <8}} | {{name}}:{{function}}:{{line}} - {{message}}"
 LOG_CONSOLE_FORMAT_FULL = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <cyan>{{name}}</cyan>:<cyan>{{function}}</cyan>:<cyan>{{line}}</cyan> - <level>{{message}}</level>\n"
 LOG_CONSOLE_FORMAT_SIMPLE = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <level>{{message}}</level>\n"
 
-
+
+# -- Serving API
+API_HOST: str = settings.get("API.HOST", "0.0.0.0")
+API_PORT: int = settings.get("API.PORT", 8000)
+API_RELOAD: bool = settings.get("API.RELOAD", True)
+
+
+# -- LLM Server
+LLM_BASE_URL: str = settings.get("LLM.BASE_URL", "http://localhost:11434")
+LLM_API_KEY: str = settings.get("LLM.API_KEY", "llm_api_key")
+LLM_MODEL: str = settings.get("LLM.MODEL", "gemma3:4b")
+LLM_SERVER_TYPE: str = settings.get("LLM.SERVER_TYPE", "openai")
+
+
+# -- Embedding Server
+EMBEDDING_BASE_URL: str = settings.get("EMBEDDING.BASE_URL", "http://localhost:11434")
+EMBEDDING_API_KEY: str = settings.get("EMBEDDING.API_KEY", "embedding_api_key")
+EMBEDDING_MODEL: str = settings.get("EMBEDDING.MODEL", "nomic-embed-text:latest")
+EMBEDDING_SERVER_TYPE: str = settings.get("EMBEDDING.SERVER_TYPE", "openai")
+EMBEDDING_DIMENSION: int = settings.get("EMBEDDING.DIMENSION", 768)
+
+
+# -- Data Warehouse
+MINIO_BASE_URL: str = settings.get("MINIO.BASE_URL", "http://localhost:9000")
+MINIO_ACCESS_KEY: str = settings.get("MINIO.ACCESS_KEY", "minioadmin")
+MINIO_SECRET_KEY: str = settings.get("MINIO.SECRET_KEY", "minioadmin")
+MINIO_BUCKET: str = settings.get("MINIO.BUCKET", "pirag")
+MINIO_REGION: str = settings.get("MINIO.REGION", "us-east-1")
+
+
+# -- Vector Store
+MILVUS_BASE_URL: str = settings.get("MILVUS.BASE_URL", "http://localhost:19530")
+MILVUS_USER: str = settings.get("MILVUS.USER", "milvus")
+MILVUS_PASSWORD: str = settings.get("MILVUS.PASSWORD", "milvus")
+MILVUS_DATABASE: str = settings.get("MILVUS.DATABASE", "milvus_database")
+MILVUS_COLLECTION: str = settings.get("MILVUS.COLLECTION", "milvus_collection")
+MILVUS_METRIC_TYPE: str = settings.get("MILVUS.METRIC_TYPE", "IP")
+
+
+# -- Monitoring
+LANGFUSE_BASE_URL: str = settings.get("LANGFUSE.BASE_URL", "http://localhost:8000")
+LANGFUSE_API_KEY: str = settings.get("LANGFUSE.API_KEY", "langfuse_api_key")
+LANGFUSE_PROJECT_ID: str = settings.get("LANGFUSE.PROJECT_ID", "langfuse_project_id")
+
+
+def setup_logger(log_level: str, log_save: bool, log_dir: str):
     """Configure logger with specified level and outputs"""
-
-    log_dir = Path(log_dir)
-    log_dir.mkdir(exist_ok=True, parents=True)
-
+
     logger.remove()
-
-    # File handler
-    logger.add(
-        sink = log_dir / "{time:YYYYMMDD-HHmmss!UTC}Z.log",
-        level = log_level,
-        rotation = "100 MB",
-        retention = 0,
-        format = LOG_FILE_FORMAT,
-        serialize = False,
-        enqueue = True,
-        backtrace = True,
-        diagnose = True,
-        catch = True
-    )
-
+
     # Console handler
     logger.add(
         sink = sys.stderr,
@@ -38,114 +87,58 @@ def setup_logger(log_level: str, log_dir: str):
         colorize = True
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        default (Any, optional): Default value if environment variable is not set. Defaults to None.
-        **kwargs: Additional arguments passed to argparse.Action
-
-    Example:
-        ```python
-        parser.add_argument(
-            '--log-level',
-            envvar='LOG_LEVEL',
-            help='Logging level',
-            default='INFO',
-            action=EnvDefault
+    if log_save:
+        log_dir = pathlib.Path(log_dir)
+        log_dir.mkdir(exist_ok=True, parents=True)
+
+        # File handler
+        logger.add(
+            sink = log_dir / "{time:YYYYMMDD-HHmmss!UTC}Z.log",
+            level = log_level,
+            rotation = "100 MB",
+            retention = 0,
+            format = LOG_FILE_FORMAT,
+            serialize = False,
+            enqueue = True,
+            backtrace = True,
+            diagnose = True,
+            catch = True
        )
-        ```
-
-    Note:
-        The help text is automatically updated to include the environment variable name.
-    """
-    def __init__(self, envvar, required=True, default=None, **kwargs):
-        if envvar and envvar in os.environ:
-            env_value = os.environ[envvar]
-            # Convert string environment variable to boolean
-            if kwargs.get('nargs') is None and kwargs.get('const') is not None: # store_true/store_false case
-                default = env_value.lower() in ('true', '1', 'yes', 'on')
-            else:
-                default = env_value
-            logger.debug(f"Using {envvar}={default} from environment")
-
-        if envvar:
-            kwargs["help"] += f" (envvar: {envvar})"
-
-        if required and default:
-            required = False
-
-        super(EnvDefault, self).__init__(default=default, required=required, **kwargs)
-        self.envvar = envvar
-
-    def __call__(self, parser, namespace, values, option_string=None):
-        setattr(namespace, self.dest, values if values is not None else self.default)
-
-
-# Top-level parser with common options
-top_parser = argparse.ArgumentParser(add_help=False)
 
-top_parser.add_argument(
-    '-h', '--help',
-    help = 'Show help message and exit',
-    default = argparse.SUPPRESS,
-    action = 'help',
-)
 
+# Top-level parser
+top_parser = argparse.ArgumentParser(add_help=False)
 top_parser.add_argument(
-
-
-
-
-    type = str,
-    action = EnvDefault,
+    "-v", "--version",
+    help = "Show the `pirag` application's version and exit",
+    action = "version",
+    version = f"{__version__}",
 )
 
-top_parser.add_argument(
-    '--log-level',
-    envvar = 'LOG_LEVEL',
-    help = 'Logging level',
-    default = 'INFO',
-    type = lambda x: x.upper(),
-    choices = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
-    required = False,
-    action = EnvDefault,
-)
 
-
-
-
-    help
-
-
-
-    action = EnvDefault,
+# Common parser
+common_parser = argparse.ArgumentParser(add_help=False)
+common_parser.add_argument(
+    "-h", "--help",
+    help = "Show help message and exit",
+    default = argparse.SUPPRESS,
+    action = "help",
 )
 
-
-
-
-
-
-
-
-    type = bool,
-    required = False,
-    action = EnvDefault,
+
+# Chat parser
+chat_parser = argparse.ArgumentParser(add_help=False)
+chat_parser.add_argument(
+    "-n", "--no-rag",
+    help = "Do not use RAG to answer the question. Just use the LLM to answer the question.",
+    action = "store_true",
 )
 
-
-
+
+# Doctor parser
+doctor_parser = argparse.ArgumentParser(add_help=False)
+doctor_parser.add_argument(
+    "-r", "--resolve",
+    help = "Resolve the issue",
+    action = "store_true",
 )
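Configuration moves from the removed EnvDefault/argparse plumbing to Dynaconf reading a settings.yaml, with a hard-coded fallback per key. A small sketch of how the dotted lookups above resolve; the file contents here are an assumption for illustration, only the key names follow from the settings.get() calls:

    import pathlib, tempfile
    from dynaconf import Dynaconf

    with tempfile.TemporaryDirectory() as tmp:
        settings_file = pathlib.Path(tmp) / "settings.yaml"
        settings_file.write_text(
            "LOG:\n"
            "  LEVEL: debug\n"
            "  SAVE: true\n"
            "API:\n"
            "  PORT: 9000\n"
        )
        settings = Dynaconf(settings_files=[str(settings_file)], envvar_prefix=False, load_dotenv=False)
        print(settings.get("LOG.LEVEL", "INFO").upper())  # "DEBUG" -- nested YAML keys via dotted lookup
        print(settings.get("LOG.SAVE", False))            # True (YAML booleans parse as bool)
        print(settings.get("API.PORT", 8000))             # 9000
        print(settings.get("API.HOST", "0.0.0.0"))        # "0.0.0.0" -- falls back to the default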
app/rag/embedding/client.py
ADDED
@@ -0,0 +1,70 @@
+import requests
+from langchain_openai.embeddings import OpenAIEmbeddings
+
+import app.rag.config as cfn
+from app.rag.utilities import connection_check
+
+
+class EmbeddingClient:
+    def __init__(self, base_url: str, api_key: str, model: str):
+        self.base_url = base_url
+        self.api_key = api_key
+        self.model = model
+        self._is_connected = True
+        self._client = None
+
+        if self.check_connection():
+            try:
+                self._client = OpenAIEmbeddings(
+                    base_url = base_url,
+                    api_key = api_key,
+                    model = model
+                )
+            except Exception as e:
+                self._is_connected = False
+
+    def check_connection(self) -> bool:
+        """Check if the embedding server is accessible"""
+        try:
+            requests.head(url=self.base_url, timeout=5)
+        except requests.exceptions.ConnectionError:
+            self._is_connected = False
+            return False
+        self._is_connected = True
+        return True
+
+    @connection_check
+    def generate(self, prompt: str) -> str:
+        """Generate text from prompt"""
+        if not self._is_connected or self._client is None:
+            return ""
+        return self._client.embed_query(prompt)
+
+    @connection_check
+    def list_models(self) -> list:
+        """List available models"""
+        if not self._is_connected:
+            return []
+        try:
+            response = requests.get(
+                f"{self.base_url}/models",
+                headers={"Authorization": f"Bearer {self.api_key}"}
+            )
+            if response.status_code == 200:
+                return [model['id'] for model in response.json()['data']]
+            return []
+        except Exception:
+            return []
+
+    @connection_check
+    def has_model(self, model: str) -> bool:
+        """Check if model exists"""
+        if not self._is_connected:
+            return False
+        return model in self.list_models()
+
+client = EmbeddingClient(
+    base_url = cfn.EMBEDDING_BASE_URL,
+    api_key = cfn.EMBEDDING_API_KEY,
+    model = cfn.EMBEDDING_MODEL,
+)
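EmbeddingClient decorates its calls with `connection_check` from app/rag/utilities.py (+15 lines, not shown in this diff). Its real implementation is not visible here; the following is only a guess at the general shape, re-checking reachability before each call so that a server that goes down after construction degrades to the methods' safe defaults instead of raising:

    import functools

    def connection_check(method):
        """Hypothetical sketch only -- the actual app/rag/utilities.py may differ."""
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            # Refresh self._is_connected so the decorated method can bail out
            # with its safe default ("" / [] / False) when the server is unreachable.
            self.check_connection()
            return method(self, *args, **kwargs)
        return wrapper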
app/rag/embedding/services.py
ADDED
@@ -0,0 +1,26 @@
+from loguru import logger
+
+import app.rag.config as cfn
+from .client import client
+
+def doctor(resolve: bool):
+    # Check connection
+    is_connected = client.check_connection()
+    if not is_connected:
+        logger.error(f"- ❌ FAILED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
+    else:
+        logger.info(f"- ✅ PASSED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
+
+    # Check model availability
+    try:
+        if not is_connected:
+            logger.warning(f"- ⏭️ SKIPPED: Embedding model (Server is not accessible)")
+        else:
+            # List models
+            models = client.list_models()
+            if cfn.EMBEDDING_MODEL not in models:
+                logger.error(f"- ❌ FAILED: Embedding model not found ({cfn.EMBEDDING_MODEL})")
+            else:
+                logger.info(f"- ✅ PASSED: Embedding model found (Model `{cfn.EMBEDDING_MODEL}` exists)")
+    except Exception as e:
+        logger.error(f"- ❌ FAILED: Embedding model check ({str(e)})")