pirag 0.1.7__tar.gz → 0.2.2__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {pirag-0.1.7 → pirag-0.2.2}/PKG-INFO +6 -3
- pirag-0.2.2/app/main.py +63 -0
- pirag-0.2.2/app/rag/agent/services.py +11 -0
- pirag-0.2.2/app/rag/api.py +40 -0
- pirag-0.2.2/app/rag/cli.py +54 -0
- pirag-0.2.2/app/rag/config.py +144 -0
- pirag-0.2.2/app/rag/embedding/client.py +70 -0
- pirag-0.2.2/app/rag/embedding/services.py +26 -0
- pirag-0.2.2/app/rag/llm/client.py +128 -0
- pirag-0.2.2/app/rag/llm/services.py +26 -0
- pirag-0.2.2/app/rag/llm/utilities.py +40 -0
- pirag-0.2.2/app/rag/models.py +19 -0
- pirag-0.2.2/app/rag/routers.py +41 -0
- pirag-0.2.2/app/rag/utilities.py +15 -0
- pirag-0.2.2/app/rag/v1/routers.py +7 -0
- pirag-0.2.2/app/rag/vector_store/client.py +84 -0
- pirag-0.2.2/app/rag/vector_store/services.py +56 -0
- pirag-0.2.2/app/requirements.txt +12 -0
- {pirag-0.1.7 → pirag-0.2.2}/app/setup.py +2 -3
- {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/PKG-INFO +6 -3
- pirag-0.2.2/pirag.egg-info/SOURCES.txt +30 -0
- pirag-0.2.2/pirag.egg-info/requires.txt +9 -0
- {pirag-0.1.7 → pirag-0.2.2}/pyproject.toml +1 -1
- pirag-0.1.7/app/main.py +0 -81
- pirag-0.1.7/app/rag/agent.py +0 -64
- pirag-0.1.7/app/rag/ask/__init__.py +0 -2
- pirag-0.1.7/app/rag/ask/config.py +0 -9
- pirag-0.1.7/app/rag/ask/router.py +0 -4
- pirag-0.1.7/app/rag/config.py +0 -151
- pirag-0.1.7/app/rag/doctor/__init__.py +0 -2
- pirag-0.1.7/app/rag/doctor/config.py +0 -24
- pirag-0.1.7/app/rag/doctor/router.py +0 -4
- pirag-0.1.7/app/rag/test/__init__.py +0 -2
- pirag-0.1.7/app/rag/test/config.py +0 -9
- pirag-0.1.7/app/rag/test/router.py +0 -4
- pirag-0.1.7/app/rag/train/__init__.py +0 -2
- pirag-0.1.7/app/rag/train/config.py +0 -20
- pirag-0.1.7/app/rag/train/router.py +0 -4
- pirag-0.1.7/app/rag/train/service.py +0 -0
- pirag-0.1.7/app/requirements.txt +0 -11
- pirag-0.1.7/pirag.egg-info/SOURCES.txt +0 -30
- pirag-0.1.7/pirag.egg-info/requires.txt +0 -6
- {pirag-0.1.7 → pirag-0.2.2}/LICENSE +0 -0
- {pirag-0.1.7 → pirag-0.2.2}/README.md +0 -0
- /pirag-0.1.7/app/rag/ask/service.py → /pirag-0.2.2/app/rag/test/client.py +0 -0
- /pirag-0.1.7/app/rag/doctor/service.py → /pirag-0.2.2/app/rag/train/client.py +0 -0
- /pirag-0.1.7/app/rag/test/service.py → /pirag-0.2.2/app/rag/v1/services.py +0 -0
- {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/dependency_links.txt +0 -0
- {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/entry_points.txt +0 -0
- {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/top_level.txt +0 -0
- {pirag-0.1.7 → pirag-0.2.2}/setup.cfg +0 -0
{pirag-0.1.7 → pirag-0.2.2}/PKG-INFO CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pirag
-Version: 0.1.7
+Version: 0.2.2
 Summary: CLI Projects of On-Premise RAG. You can use your own LLM and vector DB. Or just add remote LLM servers and vector DB.
 Author-email: semir4in <semir4in@gmail.com>, jyje <jyjeon@outlook.com>
 Project-URL: Homepage, https://github.com/jyje/pilot-onpremise-rag
@@ -9,12 +9,15 @@ Project-URL: Issue, https://github.com/jyje/pilot-onpremise-rag/issues
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist:
+Requires-Dist: dynaconf<3.3
 Requires-Dist: loguru<0.8
 Requires-Dist: pytest<8.4
-Requires-Dist:
+Requires-Dist: fastapi<0.116
+Requires-Dist: uvicorn<0.35
 Requires-Dist: ragas<0.3
 Requires-Dist: pymilvus<2.6
+Requires-Dist: langchain-openai<0.4
+Requires-Dist: langchain-ollama<0.4
 Dynamic: license-file
 
 <div align="center">
pirag-0.2.2/app/main.py ADDED

@@ -0,0 +1,63 @@
+from loguru import logger
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+
+import app.rag.config as cfn
+import app.rag.api as api
+import app.rag.cli as cli
+
+# Main parser
+parser = ArgumentParser(
+    formatter_class = ArgumentDefaultsHelpFormatter,
+    description = """
+    Pilot of On-Premise RAG.
+    """,
+    parents = [cfn.top_parser, cfn.common_parser],
+    add_help = False,
+)
+
+# Command definitions
+commands = {
+    # name: help, description, function, extra_parsers
+    "serve"  : ("Start the RAG server", "Run a FastAPI-based RAG server", api.serve, []),
+    "chat"   : ("Chat with the RAG system", "Run an interactive chat with the RAG system", cli.chat, [cfn.chat_parser]),
+    "train"  : ("Train the RAG system", "Run a pipeline to train the RAG system", cli.train, []),
+    "test"   : ("Test the RAG system", "Run a pipeline to test the RAG system", cli.test, []),
+    "doctor" : ("Diagnose the RAG system", "Run a pipeline to diagnose the RAG system", cli.doctor, [cfn.doctor_parser]),
+}
+
+# Add command parsers
+subparsers = parser.add_subparsers(title="commands", dest="command")
+for name, (help, description, _, extra_parsers) in commands.items():
+    subparsers.add_parser(
+        name        = name,
+        help        = help,
+        description = description,
+        parents     = [cfn.common_parser] + extra_parsers,
+        add_help    = False,
+    )
+
+def main():
+    args = parser.parse_args()
+    cfn.setup_logger(cfn.LOG_LEVEL, cfn.LOG_SAVE, cfn.LOG_DIR)
+    logger.debug(f"Parsed arguments: {args}")
+
+    if command_info := commands.get(args.command):
+        func, extra_parsers = command_info[2], command_info[3]
+
+        # Create parser options dict from extra_parsers
+        extra_options = {}
+        if extra_parsers:
+            for parser_obj in extra_parsers:
+                for action in parser_obj._actions:
+                    if action.dest == 'help':
+                        continue
+                    if hasattr(args, action.dest) and getattr(args, action.dest) != action.default:
+                        extra_options[action.dest] = getattr(args, action.dest)
+
+        # Run the command with the extra parser options
+        func(extra_options)
+    else:
+        parser.print_help()
+
+if __name__ == "__main__":
+    main()
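The commands table is the whole dispatch mechanism: each subcommand shares the common parser, and its handler receives only the options contributed by its extra parsers. A self-contained sketch of the same pattern with illustrative handlers (not the package's real ones):

```python
from argparse import ArgumentParser

# Illustrative stand-ins for the handlers wired up in app/main.py
def serve(options: dict):
    print(f"serve called with {options}")

def chat(options: dict):
    print(f"chat called with {options}")

chat_parser = ArgumentParser(add_help=False)
chat_parser.add_argument("-n", "--no-rag", action="store_true")

commands = {
    # name: (handler, extra_parsers)
    "serve": (serve, []),
    "chat":  (chat, [chat_parser]),
}

parser = ArgumentParser(prog="pirag")
subparsers = parser.add_subparsers(dest="command")
for name, (_, extra) in commands.items():
    subparsers.add_parser(name, parents=extra)

args = parser.parse_args(["chat", "--no-rag"])
handler, extra = commands[args.command]
# Collect only the options defined by the command's extra parsers
options = {a.dest: getattr(args, a.dest)
           for p in extra for a in p._actions if a.dest != "help"}
handler(options)   # -> chat called with {'no_rag': True}
```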
pirag-0.2.2/app/rag/api.py ADDED

@@ -0,0 +1,40 @@
+import uvicorn
+from fastapi import FastAPI, APIRouter, Request, Depends, HTTPException, Query
+from fastapi.middleware.cors import CORSMiddleware
+
+from loguru import logger
+import app.rag.config as cfn
+
+from app.rag.routers import system_router
+from app.rag.v1.routers import router as v1_router
+
+# Initialize FastAPI app
+api = FastAPI(
+    title       = "RAG API",
+    description = "API for Retrieval-Augmented Generation",
+    version     = cfn.__version__,
+)
+
+# Add CORS middleware
+api.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Adjust in production
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+api.include_router(router=system_router, prefix="", tags=["System"])
+api.include_router(router=v1_router, prefix="/v1")
+
+def serve(parser_options=None):
+    print("Serving the RAG API...")
+    if parser_options:
+        logger.debug(f"Serve parser options: {parser_options}")
+
+    uvicorn.run(
+        app    = "app.rag.api:api",
+        host   = cfn.API_HOST,
+        port   = cfn.API_PORT,
+        reload = cfn.API_RELOAD,
+    )
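The routers wired in here (app/rag/routers.py and app/rag/v1/routers.py) are not reproduced in this diff, but the app's own OpenAPI document can serve as a smoke test. A sketch assuming the package is importable and httpx is installed for FastAPI's TestClient:

```python
# Assumes the pirag package is installed; the concrete /v1 routes live in
# modules not shown in this diff, so only FastAPI's built-in endpoint is used.
from fastapi.testclient import TestClient
from app.rag.api import api

client = TestClient(api)
response = client.get("/openapi.json")   # served by FastAPI itself
assert response.status_code == 200
assert response.json()["info"]["title"] == "RAG API"
```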
pirag-0.2.2/app/rag/cli.py ADDED

@@ -0,0 +1,54 @@
+import app.rag.config as cfn
+from loguru import logger
+
+from app.rag.llm.services import doctor as doctor_llm
+from app.rag.embedding.services import doctor as doctor_embedding
+from app.rag.vector_store.services import doctor as doctor_vector_store
+from app.rag.agent.services import chat_only_llm, chat_with_rag
+
+def chat(options: dict):
+    logger.debug(f"Chat parser options: {options}")
+    no_rag = options.get('no_rag', False)
+
+    # -- Chat
+    if no_rag:
+        logger.info("💬 Chatting with the LLM system directly...")
+        chat_only_llm()
+    else:
+        logger.info("💬 Chatting with the RAG system...")
+        chat_with_rag()
+
+
+def train(options: dict):
+    print("Training the RAG system...")
+    logger.debug(f"Train parser options: {options}")
+
+
+def test(options: dict):
+    print("Testing the RAG system...")
+    logger.debug(f"Test parser options: {options}")
+
+
+def doctor(options: dict):
+    logger.info("💚 Doctoring the RAG system...")
+
+    logger.debug(f"Doctor parser options: {options}")
+    # Check if resolve option is present
+    resolve = options.get('resolve', False)
+    if resolve:
+        logger.info("🔧 Resolving issues is enabled")
+
+    # -- LLM Server
+    logger.info("🔍 Checking the LLM server (OpenAI-compatible)...")
+    doctor_llm(resolve)
+
+    # -- Embedding Server
+    logger.info("🔍 Checking the embedding server (OpenAI-compatible)...")
+    doctor_embedding(resolve)
+
+    # -- Vector Store
+    logger.info("🔍 Checking the vector store server (Milvus)...")
+    doctor_vector_store(resolve)
+
+    if resolve:
+        logger.info(f"🔧 Resolving issue completed. To make sure the issues are resolved, please try doctoring again.")
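These handlers take the plain dict of options that main() assembles from the extra parsers. A hedged sketch of calling them directly, assuming the LLM, embedding, and Milvus endpoints from the configuration are reachable:

```python
# Direct calls with the same option dicts the CLI dispatcher would build.
from app.rag import cli

cli.doctor({"resolve": False})   # run the connectivity and model checks only
cli.chat({"no_rag": True})       # chat against the bare LLM, skipping retrieval
```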
pirag-0.2.2/app/rag/config.py ADDED

@@ -0,0 +1,144 @@
+import argparse, sys, pathlib
+from loguru import logger
+from dynaconf import Dynaconf
+from importlib.metadata import version, PackageNotFoundError
+try:
+    __version__ = version("pirag")
+except PackageNotFoundError:
+    __version__ = "0.0.0"
+
+
+# -- Load configuration
+settings = Dynaconf(
+    settings_files = ["settings.yaml"],
+    envvar_prefix  = False,
+    load_dotenv    = False,
+)
+
+
+# -- Loging
+LOG_LEVEL: str = settings.get("LOG.LEVEL", "INFO").upper()
+if LOG_LEVEL not in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]:
+    raise ValueError(f"Invalid log level: {LOG_LEVEL}. Must be one of: INFO, DEBUG, WARNING, ERROR, CRITICAL")
+
+LOG_SAVE: bool = settings.get("LOG.SAVE", False)
+LOG_DIR: str = settings.get("LOG.DIR", ".pirag/logs")
+
+LOG_TIME_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS!UTC}Z"
+LOG_FILE_FORMAT = f"{LOG_TIME_FORMAT} | {{level: <8}} | {{name}}:{{function}}:{{line}} - {{message}}"
+LOG_CONSOLE_FORMAT_FULL = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <cyan>{{name}}</cyan>:<cyan>{{function}}</cyan>:<cyan>{{line}}</cyan> - <level>{{message}}</level>\n"
+LOG_CONSOLE_FORMAT_SIMPLE = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <level>{{message}}</level>\n"
+
+
+# -- Serving API
+API_HOST: str = settings.get("API.HOST", "0.0.0.0")
+API_PORT: int = settings.get("API.PORT", 8000)
+API_RELOAD: bool = settings.get("API.RELOAD", True)
+
+
+# -- LLM Server
+LLM_BASE_URL: str = settings.get("LLM.BASE_URL", "http://localhost:11434")
+LLM_API_KEY: str = settings.get("LLM.API_KEY", "llm_api_key")
+LLM_MODEL: str = settings.get("LLM.MODEL", "gemma3:4b")
+LLM_SERVER_TYPE: str = settings.get("LLM.SERVER_TYPE", "openai")
+
+
+# -- Embedding Server
+EMBEDDING_BASE_URL: str = settings.get("EMBEDDING.BASE_URL", "http://localhost:11434")
+EMBEDDING_API_KEY: str = settings.get("EMBEDDING.API_KEY", "embedding_api_key")
+EMBEDDING_MODEL: str = settings.get("EMBEDDING.MODEL", "nomic-embed-text:latest")
+EMBEDDING_SERVER_TYPE: str = settings.get("EMBEDDING.SERVER_TYPE", "openai")
+EMBEDDING_DIMENSION: int = settings.get("EMBEDDING.DIMENSION", 768)
+
+
+# -- Data Warehouse
+MINIO_BASE_URL: str = settings.get("MINIO.BASE_URL", "http://localhost:9000")
+MINIO_ACCESS_KEY: str = settings.get("MINIO.ACCESS_KEY", "minioadmin")
+MINIO_SECRET_KEY: str = settings.get("MINIO.SECRET_KEY", "minioadmin")
+MINIO_BUCKET: str = settings.get("MINIO.BUCKET", "pirag")
+MINIO_REGION: str = settings.get("MINIO.REGION", "us-east-1")
+
+
+# -- Vector Store
+MILVUS_BASE_URL: str = settings.get("MILVUS.BASE_URL", "http://localhost:19530")
+MILVUS_USER: str = settings.get("MILVUS.USER", "milvus")
+MILVUS_PASSWORD: str = settings.get("MILVUS.PASSWORD", "milvus")
+MILVUS_DATABASE: str = settings.get("MILVUS.DATABASE", "milvus_database")
+MILVUS_COLLECTION: str = settings.get("MILVUS.COLLECTION", "milvus_collection")
+MILVUS_METRIC_TYPE: str = settings.get("MILVUS.METRIC_TYPE", "IP")
+
+
+# -- Monitoring
+LANGFUSE_BASE_URL: str = settings.get("LANGFUSE.BASE_URL", "http://localhost:8000")
+LANGFUSE_API_KEY: str = settings.get("LANGFUSE.API_KEY", "langfuse_api_key")
+LANGFUSE_PROJECT_ID: str = settings.get("LANGFUSE.PROJECT_ID", "langfuse_project_id")
+
+
+def setup_logger(log_level: str, log_save: bool, log_dir: str):
+    """Configure logger with specified level and outputs"""
+
+    logger.remove()
+
+    # Console handler
+    logger.add(
+        sink     = sys.stderr,
+        level    = log_level,
+        format   = lambda record: LOG_CONSOLE_FORMAT_SIMPLE if record["level"].name == "INFO" else LOG_CONSOLE_FORMAT_FULL,
+        colorize = True
+    )
+
+    if log_save:
+        log_dir = pathlib.Path(log_dir)
+        log_dir.mkdir(exist_ok=True, parents=True)
+
+        # File handler
+        logger.add(
+            sink      = log_dir / "{time:YYYYMMDD-HHmmss!UTC}Z.log",
+            level     = log_level,
+            rotation  = "100 MB",
+            retention = 0,
+            format    = LOG_FILE_FORMAT,
+            serialize = False,
+            enqueue   = True,
+            backtrace = True,
+            diagnose  = True,
+            catch     = True
+        )
+
+
+# Top-level parser
+top_parser = argparse.ArgumentParser(add_help=False)
+top_parser.add_argument(
+    "-v", "--version",
+    help    = "Show the `pirag` application's version and exit",
+    action  = "version",
+    version = f"{__version__}",
+)
+
+
+# Common parser
+common_parser = argparse.ArgumentParser(add_help=False)
+common_parser.add_argument(
+    "-h", "--help",
+    help    = "Show help message and exit",
+    default = argparse.SUPPRESS,
+    action  = "help",
+)
+
+
+# Chat parser
+chat_parser = argparse.ArgumentParser(add_help=False)
+chat_parser.add_argument(
+    "-n", "--no-rag",
+    help   = "Do not use RAG to answer the question. Just use the LLM to answer the question.",
+    action = "store_true",
+)
+
+
+# Doctor parser
+doctor_parser = argparse.ArgumentParser(add_help=False)
+doctor_parser.add_argument(
+    "-r", "--resolve",
+    help   = "Resolve the issue",
+    action = "store_true",
+)
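Every value is read with a dotted key and a default, so a settings.yaml only needs the sections being overridden. A small sketch of that lookup, assuming PyYAML is available for Dynaconf's YAML loader; the file contents below are illustrative, not shipped with the package:

```python
import pathlib, tempfile
from dynaconf import Dynaconf

# Hypothetical settings.yaml content; only LLM.MODEL is overridden.
yaml_text = """
LLM:
  MODEL: "llama3:8b"
"""

with tempfile.TemporaryDirectory() as tmp:
    settings_file = pathlib.Path(tmp) / "settings.yaml"
    settings_file.write_text(yaml_text)

    settings = Dynaconf(settings_files=[str(settings_file)],
                        envvar_prefix=False, load_dotenv=False)

    # Dotted lookup with a default, exactly as app/rag/config.py does
    print(settings.get("LLM.MODEL", "gemma3:4b"))                   # overridden value
    print(settings.get("LLM.BASE_URL", "http://localhost:11434"))   # falls back to default
```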
pirag-0.2.2/app/rag/embedding/client.py ADDED

@@ -0,0 +1,70 @@
+import requests
+from langchain_openai.embeddings import OpenAIEmbeddings
+
+import app.rag.config as cfn
+from app.rag.utilities import connection_check
+
+
+class EmbeddingClient:
+    def __init__(self, base_url: str, api_key: str, model: str):
+        self.base_url = base_url
+        self.api_key = api_key
+        self.model = model
+        self._is_connected = True
+        self._client = None
+
+        if self.check_connection():
+            try:
+                self._client = OpenAIEmbeddings(
+                    base_url = base_url,
+                    api_key  = api_key,
+                    model    = model
+                )
+            except Exception as e:
+                self._is_connected = False
+
+    def check_connection(self) -> bool:
+        """Check if the embedding server is accessible"""
+        try:
+            requests.head(url=self.base_url, timeout=5)
+        except requests.exceptions.ConnectionError:
+            self._is_connected = False
+            return False
+        self._is_connected = True
+        return True
+
+    @connection_check
+    def generate(self, prompt: str) -> str:
+        """Generate text from prompt"""
+        if not self._is_connected or self._client is None:
+            return ""
+        return self._client.embed_query(prompt)
+
+    @connection_check
+    def list_models(self) -> list:
+        """List available models"""
+        if not self._is_connected:
+            return []
+        try:
+            response = requests.get(
+                f"{self.base_url}/models",
+                headers={"Authorization": f"Bearer {self.api_key}"}
+            )
+            if response.status_code == 200:
+                return [model['id'] for model in response.json()['data']]
+            return []
+        except Exception:
+            return []
+
+    @connection_check
+    def has_model(self, model: str) -> bool:
+        """Check if model exists"""
+        if not self._is_connected:
+            return False
+        return model in self.list_models()
+
+client = EmbeddingClient(
+    base_url = cfn.EMBEDDING_BASE_URL,
+    api_key  = cfn.EMBEDDING_API_KEY,
+    model    = cfn.EMBEDDING_MODEL,
+)
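`connection_check` is imported from app/rag/utilities.py (+15 lines, not shown in this diff). The decorator below is only a guess at what that helper might look like, not the packaged implementation:

```python
import functools
from loguru import logger

def connection_check(func):
    """Hypothetical guard: re-verify the connection before each client call."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Refresh self._is_connected so the wrapped method's own checks see
        # the current server state; the method still returns its safe default.
        if not self.check_connection():
            logger.warning(f"{func.__name__} called while server is unreachable")
        return func(self, *args, **kwargs)
    return wrapper
```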
pirag-0.2.2/app/rag/embedding/services.py ADDED

@@ -0,0 +1,26 @@
+from loguru import logger
+
+import app.rag.config as cfn
+from .client import client
+
+def doctor(resolve: bool):
+    # Check connection
+    is_connected = client.check_connection()
+    if not is_connected:
+        logger.error(f"- ❌ FAILED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
+    else:
+        logger.info(f"- ✅ PASSED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
+
+    # Check model availability
+    try:
+        if not is_connected:
+            logger.warning(f"- ⏭️ SKIPPED: Embedding model (Server is not accessible)")
+        else:
+            # List models
+            models = client.list_models()
+            if cfn.EMBEDDING_MODEL not in models:
+                logger.error(f"- ❌ FAILED: Embedding model not found ({cfn.EMBEDDING_MODEL})")
+            else:
+                logger.info(f"- ✅ PASSED: Embedding model found (Model `{cfn.EMBEDDING_MODEL}` exists)")
+    except Exception as e:
+        logger.error(f"- ❌ FAILED: Embedding model check ({str(e)})")
pirag-0.2.2/app/rag/llm/client.py ADDED

@@ -0,0 +1,128 @@
+import requests
+import time
+from langchain_openai.llms import OpenAI
+from typing import Dict, Tuple, Any, List, Optional
+
+import app.rag.config as cfn
+from app.rag.utilities import connection_check
+from .utilities import MetricCallbackHandler
+
+class LLMClient:
+    def __init__(self, base_url: str, api_key: str, model: str):
+        self.base_url = base_url
+        self.api_key = api_key
+        self.model = model
+        self._is_connected = True
+        self._client = None
+
+        if self.check_connection():
+            try:
+                self._client = OpenAI(
+                    base_url = base_url,
+                    api_key  = api_key,
+                    model    = model
+                )
+            except Exception as e:
+                self._is_connected = False
+
+    def check_connection(self) -> bool:
+        """Check if the LLM server is accessible"""
+        try:
+            requests.head(url=self.base_url, timeout=5)
+        except requests.exceptions.ConnectionError:
+            self._is_connected = False
+            return False
+        self._is_connected = True
+        return True
+
+    @connection_check
+    def generate(self, prompt: str) -> tuple:
+        """Generate text from prompt and return usage information
+
+        Returns:
+            tuple: (generated_text, usage_info)
+        """
+        if not self._is_connected or self._client is None:
+            return "", {}
+
+        response = self._client.generate([prompt])
+        return response.generations[0][0].text, response.llm_output
+
+    @connection_check
+    def generate_with_metrics(self, prompt: str) -> Tuple[str, Dict[str, Any]]:
+        """Generate text with timing and usage metrics
+
+        Returns:
+            tuple: (generated_text, metrics_info)
+        """
+        if not self._is_connected or self._client is None:
+            return "", {"error": "LLM client not connected"}
+
+        handler = MetricCallbackHandler()
+
+        # Create streaming client with callback
+        streaming_client = OpenAI(
+            base_url=self.base_url,
+            api_key=self.api_key,
+            model=self.model,
+            streaming=True,
+            callbacks=[handler]
+        )
+
+        # Make a single request
+        response = streaming_client.generate([prompt], callbacks=[handler])
+
+        # Get base metrics from response
+        metrics = {}
+
+        # Extract token usage from response
+        llm_output = response.llm_output if hasattr(response, 'llm_output') else {}
+
+        # Check if token_usage exists in the response
+        token_usage = llm_output.get('token_usage', {})
+        if token_usage:
+            # If token_usage is available, copy it to our metrics
+            metrics.update(token_usage)
+
+        # Add model name if available
+        if 'model_name' in llm_output:
+            metrics['model'] = llm_output['model_name']
+        else:
+            metrics['model'] = self.model
+
+        # Calculate and add timing metrics
+        metrics['ttft'] = handler.ttft or 0.0
+        metrics['total_time'] = (handler.end_time or time.time()) - handler.start_time
+        metrics['tokens_per_second'] = handler.calculate_tokens_per_second()
+        metrics['completion_tokens'] = handler.token_count
+
+        return handler.result, metrics
+
+    @connection_check
+    def list_models(self) -> list:
+        """List available models"""
+        if not self._is_connected:
+            return []
+        try:
+            response = requests.get(
+                f"{self.base_url}/models",
+                headers={"Authorization": f"Bearer {self.api_key}"}
+            )
+            if response.status_code == 200:
+                return [model['id'] for model in response.json()['data']]
+            return []
+        except Exception:
+            return []
+
+    @connection_check
+    def has_model(self, model: str) -> bool:
+        """Check if model exists"""
+        if not self._is_connected:
+            return False
+        return model in self.list_models()
+
+client = LLMClient(
+    base_url = cfn.LLM_BASE_URL,
+    api_key  = cfn.LLM_API_KEY,
+    model    = cfn.LLM_MODEL,
+)
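A possible way to exercise the metrics path, assuming the configured LLM_BASE_URL points at a reachable OpenAI-compatible server:

```python
# The metric keys (ttft, tokens_per_second, model) are the ones populated
# by generate_with_metrics above.
from app.rag.llm.client import client

text, metrics = client.generate_with_metrics("Say hello in one short sentence.")
print(text)
print(f"TTFT: {metrics.get('ttft', 0):.3f}s, "
      f"{metrics.get('tokens_per_second', 0):.1f} tok/s, "
      f"model: {metrics.get('model')}")
```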
pirag-0.2.2/app/rag/llm/services.py ADDED

@@ -0,0 +1,26 @@
+from loguru import logger
+
+import app.rag.config as cfn
+from .client import client
+
+def doctor(resolve: bool):
+    # Check connection
+    is_connected = client.check_connection()
+    if not is_connected:
+        logger.error(f"- ❌ FAILED: LLM connection ({cfn.LLM_BASE_URL})")
+    else:
+        logger.info(f"- ✅ PASSED: LLM connection ({cfn.LLM_BASE_URL})")
+
+    # Check model availability
+    try:
+        if not is_connected:
+            logger.warning(f"- ⏭️ SKIPPED: LLM model (Server is not accessible)")
+        else:
+            # List models
+            models = client.list_models()
+            if cfn.LLM_MODEL not in models:
+                logger.error(f"- ❌ FAILED: LLM model not found ({cfn.LLM_MODEL})")
+            else:
+                logger.info(f"- ✅ PASSED: LLM model found (Model `{cfn.LLM_MODEL}` exists)")
+    except Exception as e:
+        logger.error(f"- ❌ FAILED: LLM model check ({str(e)})")
pirag-0.2.2/app/rag/llm/utilities.py ADDED

@@ -0,0 +1,40 @@
+import time
+from langchain.callbacks.base import BaseCallbackHandler
+
+class MetricCallbackHandler(BaseCallbackHandler):
+    def __init__(self):
+        self.start_time = time.time()
+        self.ttft = None
+        self.first_token_time = None
+        self.result = ""
+        self.end_time = None
+        self.token_count = 0
+        self.token_timestamps = []
+
+    def on_llm_new_token(self, token: str, **kwargs):
+        current_time = time.time()
+        self.token_count += 1
+        self.token_timestamps.append(current_time)
+
+        if self.ttft is None:
+            self.ttft = current_time - self.start_time
+            self.first_token_time = current_time
+
+        self.result += token
+
+    def on_llm_end(self, *args, **kwargs):
+        self.end_time = time.time()
+
+    def calculate_tokens_per_second(self):
+        """Calculate tokens per second after the first token"""
+        if self.token_count <= 1 or self.first_token_time is None or self.end_time is None:
+            return 0.0
+
+        # Calculate time from first token to completion (exclude TTFT)
+        generation_time = self.end_time - self.first_token_time
+        if generation_time <= 0:
+            return 0.0
+
+        # Exclude the first token from the count since we're measuring from after it arrived
+        tokens_after_first = self.token_count - 1
+        return tokens_after_first / generation_time
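In normal use LangChain invokes these callbacks during streaming; the sketch below drives the handler with synthetic tokens to show the metrics it derives:

```python
import time
from app.rag.llm.utilities import MetricCallbackHandler

# Feed the handler fake streaming events arriving roughly every 50 ms.
handler = MetricCallbackHandler()
for token in ["Hello", ",", " world", "!"]:
    time.sleep(0.05)
    handler.on_llm_new_token(token)
handler.on_llm_end()

print(handler.result)                                   # "Hello, world!"
print(f"TTFT ≈ {handler.ttft:.2f}s")
print(f"≈ {handler.calculate_tokens_per_second():.1f} tokens/s after the first token")
```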
pirag-0.2.2/app/rag/models.py ADDED

@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+class SystemStatusResponse(BaseModel):
+    """
+    Response model for the system status endpoint.
+    """
+    status: int
+    message: str
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "status": 200,
+                    "message": "System is running normally"
+                }
+            ]
+        }
+    }
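This model is presumably what the system router in app/rag/routers.py returns; that router is not shown in this diff. A hypothetical route using it (the /healthz path is an assumption, not the package's actual endpoint):

```python
from fastapi import FastAPI
from app.rag.models import SystemStatusResponse

app = FastAPI()

@app.get("/healthz", response_model=SystemStatusResponse)
def healthz() -> SystemStatusResponse:
    # Mirrors the example payload declared in the model's json_schema_extra
    return SystemStatusResponse(status=200, message="System is running normally")
```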