pirag 0.1.7__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {pirag-0.1.7 → pirag-0.2.2}/PKG-INFO +6 -3
  2. pirag-0.2.2/app/main.py +63 -0
  3. pirag-0.2.2/app/rag/agent/services.py +11 -0
  4. pirag-0.2.2/app/rag/api.py +40 -0
  5. pirag-0.2.2/app/rag/cli.py +54 -0
  6. pirag-0.2.2/app/rag/config.py +144 -0
  7. pirag-0.2.2/app/rag/embedding/client.py +70 -0
  8. pirag-0.2.2/app/rag/embedding/services.py +26 -0
  9. pirag-0.2.2/app/rag/llm/client.py +128 -0
  10. pirag-0.2.2/app/rag/llm/services.py +26 -0
  11. pirag-0.2.2/app/rag/llm/utilities.py +40 -0
  12. pirag-0.2.2/app/rag/models.py +19 -0
  13. pirag-0.2.2/app/rag/routers.py +41 -0
  14. pirag-0.2.2/app/rag/utilities.py +15 -0
  15. pirag-0.2.2/app/rag/v1/routers.py +7 -0
  16. pirag-0.2.2/app/rag/vector_store/client.py +84 -0
  17. pirag-0.2.2/app/rag/vector_store/services.py +56 -0
  18. pirag-0.2.2/app/requirements.txt +12 -0
  19. {pirag-0.1.7 → pirag-0.2.2}/app/setup.py +2 -3
  20. {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/PKG-INFO +6 -3
  21. pirag-0.2.2/pirag.egg-info/SOURCES.txt +30 -0
  22. pirag-0.2.2/pirag.egg-info/requires.txt +9 -0
  23. {pirag-0.1.7 → pirag-0.2.2}/pyproject.toml +1 -1
  24. pirag-0.1.7/app/main.py +0 -81
  25. pirag-0.1.7/app/rag/agent.py +0 -64
  26. pirag-0.1.7/app/rag/ask/__init__.py +0 -2
  27. pirag-0.1.7/app/rag/ask/config.py +0 -9
  28. pirag-0.1.7/app/rag/ask/router.py +0 -4
  29. pirag-0.1.7/app/rag/config.py +0 -151
  30. pirag-0.1.7/app/rag/doctor/__init__.py +0 -2
  31. pirag-0.1.7/app/rag/doctor/config.py +0 -24
  32. pirag-0.1.7/app/rag/doctor/router.py +0 -4
  33. pirag-0.1.7/app/rag/test/__init__.py +0 -2
  34. pirag-0.1.7/app/rag/test/config.py +0 -9
  35. pirag-0.1.7/app/rag/test/router.py +0 -4
  36. pirag-0.1.7/app/rag/train/__init__.py +0 -2
  37. pirag-0.1.7/app/rag/train/config.py +0 -20
  38. pirag-0.1.7/app/rag/train/router.py +0 -4
  39. pirag-0.1.7/app/rag/train/service.py +0 -0
  40. pirag-0.1.7/app/requirements.txt +0 -11
  41. pirag-0.1.7/pirag.egg-info/SOURCES.txt +0 -30
  42. pirag-0.1.7/pirag.egg-info/requires.txt +0 -6
  43. {pirag-0.1.7 → pirag-0.2.2}/LICENSE +0 -0
  44. {pirag-0.1.7 → pirag-0.2.2}/README.md +0 -0
  45. /pirag-0.1.7/app/rag/ask/service.py → /pirag-0.2.2/app/rag/test/client.py +0 -0
  46. /pirag-0.1.7/app/rag/doctor/service.py → /pirag-0.2.2/app/rag/train/client.py +0 -0
  47. /pirag-0.1.7/app/rag/test/service.py → /pirag-0.2.2/app/rag/v1/services.py +0 -0
  48. {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/dependency_links.txt +0 -0
  49. {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/entry_points.txt +0 -0
  50. {pirag-0.1.7 → pirag-0.2.2}/pirag.egg-info/top_level.txt +0 -0
  51. {pirag-0.1.7 → pirag-0.2.2}/setup.cfg +0 -0

{pirag-0.1.7 → pirag-0.2.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pirag
- Version: 0.1.7
+ Version: 0.2.2
  Summary: CLI Projects of On-Premise RAG. You can use your own LLM and vector DB. Or just add remote LLM servers and vector DB.
  Author-email: semir4in <semir4in@gmail.com>, jyje <jyjeon@outlook.com>
  Project-URL: Homepage, https://github.com/jyje/pilot-onpremise-rag
@@ -9,12 +9,15 @@ Project-URL: Issue, https://github.com/jyje/pilot-onpremise-rag/issues
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: python-dotenv<1.2
+ Requires-Dist: dynaconf<3.3
  Requires-Dist: loguru<0.8
  Requires-Dist: pytest<8.4
- Requires-Dist: black<25.2
+ Requires-Dist: fastapi<0.116
+ Requires-Dist: uvicorn<0.35
  Requires-Dist: ragas<0.3
  Requires-Dist: pymilvus<2.6
+ Requires-Dist: langchain-openai<0.4
+ Requires-Dist: langchain-ollama<0.4
  Dynamic: license-file

  <div align="center">

pirag-0.2.2/app/main.py
@@ -0,0 +1,63 @@
+ from loguru import logger
+ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+
+ import app.rag.config as cfn
+ import app.rag.api as api
+ import app.rag.cli as cli
+
+ # Main parser
+ parser = ArgumentParser(
+     formatter_class = ArgumentDefaultsHelpFormatter,
+     description = """
+         Pilot of On-Premise RAG.
+     """,
+     parents = [cfn.top_parser, cfn.common_parser],
+     add_help = False,
+ )
+
+ # Command definitions
+ commands = {
+     # name: help, description, function, extra_parsers
+     "serve" : ("Start the RAG server", "Run a FastAPI-based RAG server", api.serve, []),
+     "chat" : ("Chat with the RAG system", "Run an interactive chat with the RAG system", cli.chat, [cfn.chat_parser]),
+     "train" : ("Train the RAG system", "Run a pipeline to train the RAG system", cli.train, []),
+     "test" : ("Test the RAG system", "Run a pipeline to test the RAG system", cli.test, []),
+     "doctor" : ("Diagnose the RAG system", "Run a pipeline to diagnose the RAG system", cli.doctor, [cfn.doctor_parser]),
+ }
+
+ # Add command parsers
+ subparsers = parser.add_subparsers(title="commands", dest="command")
+ for name, (help, description, _, extra_parsers) in commands.items():
+     subparsers.add_parser(
+         name = name,
+         help = help,
+         description = description,
+         parents = [cfn.common_parser] + extra_parsers,
+         add_help = False,
+     )
+
+ def main():
+     args = parser.parse_args()
+     cfn.setup_logger(cfn.LOG_LEVEL, cfn.LOG_SAVE, cfn.LOG_DIR)
+     logger.debug(f"Parsed arguments: {args}")
+
+     if command_info := commands.get(args.command):
+         func, extra_parsers = command_info[2], command_info[3]
+
+         # Create parser options dict from extra_parsers
+         extra_options = {}
+         if extra_parsers:
+             for parser_obj in extra_parsers:
+                 for action in parser_obj._actions:
+                     if action.dest == 'help':
+                         continue
+                     if hasattr(args, action.dest) and getattr(args, action.dest) != action.default:
+                         extra_options[action.dest] = getattr(args, action.dest)
+
+         # Run the command with the extra parser options
+         func(extra_options)
+     else:
+         parser.print_help()
+
+ if __name__ == "__main__":
+     main()
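
A side note on the dispatch pattern in app/main.py above (subcommands declared in a dict, shared parent parsers, handlers called with a plain options dict): the same idea reduces to the short, self-contained sketch below. The command name and handler are illustrative only, not part of pirag.

import argparse

def chat(options: dict):
    # Illustrative handler; pirag's real handlers live in app/rag/cli.py
    print(f"chat called with {options}")

# name -> (help text, handler), mirroring the `commands` dict above
commands = {"chat": ("Chat with the system", chat)}

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
for name, (help_text, _) in commands.items():
    sub = subparsers.add_parser(name, help=help_text)
    sub.add_argument("-n", "--no-rag", action="store_true")

args = parser.parse_args(["chat", "--no-rag"])
commands[args.command][1]({"no_rag": args.no_rag})  # prints: chat called with {'no_rag': True}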

pirag-0.2.2/app/rag/agent/services.py
@@ -0,0 +1,11 @@
+ from app.rag.llm.client import client as llm_client
+
+ def chat_only_llm():
+     response = llm_client.generate_with_metrics("Hello, how are you?")
+     print(response)
+
+
+ def chat_with_rag():
+     pass
+
+

pirag-0.2.2/app/rag/api.py
@@ -0,0 +1,40 @@
+ import uvicorn
+ from fastapi import FastAPI, APIRouter, Request, Depends, HTTPException, Query
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from loguru import logger
+ import app.rag.config as cfn
+
+ from app.rag.routers import system_router
+ from app.rag.v1.routers import router as v1_router
+
+ # Initialize FastAPI app
+ api = FastAPI(
+     title = "RAG API",
+     description = "API for Retrieval-Augmented Generation",
+     version = cfn.__version__,
+ )
+
+ # Add CORS middleware
+ api.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"], # Adjust in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ api.include_router(router=system_router, prefix="", tags=["System"])
+ api.include_router(router=v1_router, prefix="/v1")
+
+ def serve(parser_options=None):
+     print("Serving the RAG API...")
+     if parser_options:
+         logger.debug(f"Serve parser options: {parser_options}")
+
+     uvicorn.run(
+         app = "app.rag.api:api",
+         host = cfn.API_HOST,
+         port = cfn.API_PORT,
+         reload = cfn.API_RELOAD,
+     )

pirag-0.2.2/app/rag/cli.py
@@ -0,0 +1,54 @@
+ import app.rag.config as cfn
+ from loguru import logger
+
+ from app.rag.llm.services import doctor as doctor_llm
+ from app.rag.embedding.services import doctor as doctor_embedding
+ from app.rag.vector_store.services import doctor as doctor_vector_store
+ from app.rag.agent.services import chat_only_llm, chat_with_rag
+
+ def chat(options: dict):
+     logger.debug(f"Chat parser options: {options}")
+     no_rag = options.get('no_rag', False)
+
+     # -- Chat
+     if no_rag:
+         logger.info("💬 Chatting with the LLM system directly...")
+         chat_only_llm()
+     else:
+         logger.info("💬 Chatting with the RAG system...")
+         chat_with_rag()
+
+
+ def train(options: dict):
+     print("Training the RAG system...")
+     logger.debug(f"Train parser options: {options}")
+
+
+ def test(options: dict):
+     print("Testing the RAG system...")
+     logger.debug(f"Test parser options: {options}")
+
+
+ def doctor(options: dict):
+     logger.info("💚 Doctoring the RAG system...")
+
+     logger.debug(f"Doctor parser options: {options}")
+     # Check if resolve option is present
+     resolve = options.get('resolve', False)
+     if resolve:
+         logger.info("🔧 Resolving issues is enabled")
+
+     # -- LLM Server
+     logger.info("🔍 Checking the LLM server (OpenAI-compatible)...")
+     doctor_llm(resolve)
+
+     # -- Embedding Server
+     logger.info("🔍 Checking the embedding server (OpenAI-compatible)...")
+     doctor_embedding(resolve)
+
+     # -- Vector Store
+     logger.info("🔍 Checking the vector store server (Milvus)...")
+     doctor_vector_store(resolve)
+
+     if resolve:
+         logger.info(f"🔧 Resolving issue completed. To make sure the issues are resolved, please try doctoring again.")

pirag-0.2.2/app/rag/config.py
@@ -0,0 +1,144 @@
+ import argparse, sys, pathlib
+ from loguru import logger
+ from dynaconf import Dynaconf
+ from importlib.metadata import version, PackageNotFoundError
+ try:
+     __version__ = version("pirag")
+ except PackageNotFoundError:
+     __version__ = "0.0.0"
+
+
+ # -- Load configuration
+ settings = Dynaconf(
+     settings_files = ["settings.yaml"],
+     envvar_prefix = False,
+     load_dotenv = False,
+ )
+
+
+ # -- Logging
+ LOG_LEVEL: str = settings.get("LOG.LEVEL", "INFO").upper()
+ if LOG_LEVEL not in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]:
+     raise ValueError(f"Invalid log level: {LOG_LEVEL}. Must be one of: INFO, DEBUG, WARNING, ERROR, CRITICAL")
+
+ LOG_SAVE: bool = settings.get("LOG.SAVE", False)
+ LOG_DIR: str = settings.get("LOG.DIR", ".pirag/logs")
+
+ LOG_TIME_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS!UTC}Z"
+ LOG_FILE_FORMAT = f"{LOG_TIME_FORMAT} | {{level: <8}} | {{name}}:{{function}}:{{line}} - {{message}}"
+ LOG_CONSOLE_FORMAT_FULL = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <cyan>{{name}}</cyan>:<cyan>{{function}}</cyan>:<cyan>{{line}}</cyan> - <level>{{message}}</level>\n"
+ LOG_CONSOLE_FORMAT_SIMPLE = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <level>{{message}}</level>\n"
+
+
+ # -- Serving API
+ API_HOST: str = settings.get("API.HOST", "0.0.0.0")
+ API_PORT: int = settings.get("API.PORT", 8000)
+ API_RELOAD: bool = settings.get("API.RELOAD", True)
+
+
+ # -- LLM Server
+ LLM_BASE_URL: str = settings.get("LLM.BASE_URL", "http://localhost:11434")
+ LLM_API_KEY: str = settings.get("LLM.API_KEY", "llm_api_key")
+ LLM_MODEL: str = settings.get("LLM.MODEL", "gemma3:4b")
+ LLM_SERVER_TYPE: str = settings.get("LLM.SERVER_TYPE", "openai")
+
+
+ # -- Embedding Server
+ EMBEDDING_BASE_URL: str = settings.get("EMBEDDING.BASE_URL", "http://localhost:11434")
+ EMBEDDING_API_KEY: str = settings.get("EMBEDDING.API_KEY", "embedding_api_key")
+ EMBEDDING_MODEL: str = settings.get("EMBEDDING.MODEL", "nomic-embed-text:latest")
+ EMBEDDING_SERVER_TYPE: str = settings.get("EMBEDDING.SERVER_TYPE", "openai")
+ EMBEDDING_DIMENSION: int = settings.get("EMBEDDING.DIMENSION", 768)
+
+
+ # -- Data Warehouse
+ MINIO_BASE_URL: str = settings.get("MINIO.BASE_URL", "http://localhost:9000")
+ MINIO_ACCESS_KEY: str = settings.get("MINIO.ACCESS_KEY", "minioadmin")
+ MINIO_SECRET_KEY: str = settings.get("MINIO.SECRET_KEY", "minioadmin")
+ MINIO_BUCKET: str = settings.get("MINIO.BUCKET", "pirag")
+ MINIO_REGION: str = settings.get("MINIO.REGION", "us-east-1")
+
+
+ # -- Vector Store
+ MILVUS_BASE_URL: str = settings.get("MILVUS.BASE_URL", "http://localhost:19530")
+ MILVUS_USER: str = settings.get("MILVUS.USER", "milvus")
+ MILVUS_PASSWORD: str = settings.get("MILVUS.PASSWORD", "milvus")
+ MILVUS_DATABASE: str = settings.get("MILVUS.DATABASE", "milvus_database")
+ MILVUS_COLLECTION: str = settings.get("MILVUS.COLLECTION", "milvus_collection")
+ MILVUS_METRIC_TYPE: str = settings.get("MILVUS.METRIC_TYPE", "IP")
+
+
+ # -- Monitoring
+ LANGFUSE_BASE_URL: str = settings.get("LANGFUSE.BASE_URL", "http://localhost:8000")
+ LANGFUSE_API_KEY: str = settings.get("LANGFUSE.API_KEY", "langfuse_api_key")
+ LANGFUSE_PROJECT_ID: str = settings.get("LANGFUSE.PROJECT_ID", "langfuse_project_id")
+
+
+ def setup_logger(log_level: str, log_save: bool, log_dir: str):
+     """Configure logger with specified level and outputs"""
+
+     logger.remove()
+
+     # Console handler
+     logger.add(
+         sink = sys.stderr,
+         level = log_level,
+         format = lambda record: LOG_CONSOLE_FORMAT_SIMPLE if record["level"].name == "INFO" else LOG_CONSOLE_FORMAT_FULL,
+         colorize = True
+     )
+
+     if log_save:
+         log_dir = pathlib.Path(log_dir)
+         log_dir.mkdir(exist_ok=True, parents=True)
+
+         # File handler
+         logger.add(
+             sink = log_dir / "{time:YYYYMMDD-HHmmss!UTC}Z.log",
+             level = log_level,
+             rotation = "100 MB",
+             retention = 0,
+             format = LOG_FILE_FORMAT,
+             serialize = False,
+             enqueue = True,
+             backtrace = True,
+             diagnose = True,
+             catch = True
+         )
+
+
+ # Top-level parser
+ top_parser = argparse.ArgumentParser(add_help=False)
+ top_parser.add_argument(
+     "-v", "--version",
+     help = "Show the `pirag` application's version and exit",
+     action = "version",
+     version = f"{__version__}",
+ )
+
+
+ # Common parser
+ common_parser = argparse.ArgumentParser(add_help=False)
+ common_parser.add_argument(
+     "-h", "--help",
+     help = "Show help message and exit",
+     default = argparse.SUPPRESS,
+     action = "help",
+ )
+
+
+ # Chat parser
+ chat_parser = argparse.ArgumentParser(add_help=False)
+ chat_parser.add_argument(
+     "-n", "--no-rag",
+     help = "Do not use RAG to answer the question. Just use the LLM to answer the question.",
+     action = "store_true",
+ )
+
+
+ # Doctor parser
+ doctor_parser = argparse.ArgumentParser(add_help=False)
+ doctor_parser.add_argument(
+     "-r", "--resolve",
+     help = "Resolve the issue",
+     action = "store_true",
+ )
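
For reference, the Dynaconf lookup pattern that config.py relies on can be exercised on its own. A minimal sketch, assuming the dynaconf package is installed; when settings.yaml is absent, every get() simply falls back to its default, which is exactly how the module-level constants above behave.

from dynaconf import Dynaconf

# Same construction as config.py: values come from settings.yaml when present,
# otherwise the default passed to get() is used.
settings = Dynaconf(settings_files=["settings.yaml"], envvar_prefix=False, load_dotenv=False)

log_level = settings.get("LOG.LEVEL", "INFO").upper()  # dotted path into nested YAML keys
api_port = settings.get("API.PORT", 8000)
print(log_level, api_port)  # -> INFO 8000 when no settings.yaml exists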

pirag-0.2.2/app/rag/embedding/client.py
@@ -0,0 +1,70 @@
+ import requests
+ from langchain_openai.embeddings import OpenAIEmbeddings
+
+ import app.rag.config as cfn
+ from app.rag.utilities import connection_check
+
+
+ class EmbeddingClient:
+     def __init__(self, base_url: str, api_key: str, model: str):
+         self.base_url = base_url
+         self.api_key = api_key
+         self.model = model
+         self._is_connected = True
+         self._client = None
+
+         if self.check_connection():
+             try:
+                 self._client = OpenAIEmbeddings(
+                     base_url = base_url,
+                     api_key = api_key,
+                     model = model
+                 )
+             except Exception as e:
+                 self._is_connected = False
+
+     def check_connection(self) -> bool:
+         """Check if the embedding server is accessible"""
+         try:
+             requests.head(url=self.base_url, timeout=5)
+         except requests.exceptions.ConnectionError:
+             self._is_connected = False
+             return False
+         self._is_connected = True
+         return True
+
+     @connection_check
+     def generate(self, prompt: str) -> str:
+         """Generate text from prompt"""
+         if not self._is_connected or self._client is None:
+             return ""
+         return self._client.embed_query(prompt)
+
+     @connection_check
+     def list_models(self) -> list:
+         """List available models"""
+         if not self._is_connected:
+             return []
+         try:
+             response = requests.get(
+                 f"{self.base_url}/models",
+                 headers={"Authorization": f"Bearer {self.api_key}"}
+             )
+             if response.status_code == 200:
+                 return [model['id'] for model in response.json()['data']]
+             return []
+         except Exception:
+             return []
+
+     @connection_check
+     def has_model(self, model: str) -> bool:
+         """Check if model exists"""
+         if not self._is_connected:
+             return False
+         return model in self.list_models()
+
+ client = EmbeddingClient(
+     base_url = cfn.EMBEDDING_BASE_URL,
+     api_key = cfn.EMBEDDING_API_KEY,
+     model = cfn.EMBEDDING_MODEL,
+ )

pirag-0.2.2/app/rag/embedding/services.py
@@ -0,0 +1,26 @@
+ from loguru import logger
+
+ import app.rag.config as cfn
+ from .client import client
+
+ def doctor(resolve: bool):
+     # Check connection
+     is_connected = client.check_connection()
+     if not is_connected:
+         logger.error(f"- ❌ FAILED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
+     else:
+         logger.info(f"- ✅ PASSED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
+
+     # Check model availability
+     try:
+         if not is_connected:
+             logger.warning(f"- ⏭️ SKIPPED: Embedding model (Server is not accessible)")
+         else:
+             # List models
+             models = client.list_models()
+             if cfn.EMBEDDING_MODEL not in models:
+                 logger.error(f"- ❌ FAILED: Embedding model not found ({cfn.EMBEDDING_MODEL})")
+             else:
+                 logger.info(f"- ✅ PASSED: Embedding model found (Model `{cfn.EMBEDDING_MODEL}` exists)")
+     except Exception as e:
+         logger.error(f"- ❌ FAILED: Embedding model check ({str(e)})")

pirag-0.2.2/app/rag/llm/client.py
@@ -0,0 +1,128 @@
+ import requests
+ import time
+ from langchain_openai.llms import OpenAI
+ from typing import Dict, Tuple, Any, List, Optional
+
+ import app.rag.config as cfn
+ from app.rag.utilities import connection_check
+ from .utilities import MetricCallbackHandler
+
+ class LLMClient:
+     def __init__(self, base_url: str, api_key: str, model: str):
+         self.base_url = base_url
+         self.api_key = api_key
+         self.model = model
+         self._is_connected = True
+         self._client = None
+
+         if self.check_connection():
+             try:
+                 self._client = OpenAI(
+                     base_url = base_url,
+                     api_key = api_key,
+                     model = model
+                 )
+             except Exception as e:
+                 self._is_connected = False
+
+     def check_connection(self) -> bool:
+         """Check if the LLM server is accessible"""
+         try:
+             requests.head(url=self.base_url, timeout=5)
+         except requests.exceptions.ConnectionError:
+             self._is_connected = False
+             return False
+         self._is_connected = True
+         return True
+
+     @connection_check
+     def generate(self, prompt: str) -> tuple:
+         """Generate text from prompt and return usage information
+
+         Returns:
+             tuple: (generated_text, usage_info)
+         """
+         if not self._is_connected or self._client is None:
+             return "", {}
+
+         response = self._client.generate([prompt])
+         return response.generations[0][0].text, response.llm_output
+
+     @connection_check
+     def generate_with_metrics(self, prompt: str) -> Tuple[str, Dict[str, Any]]:
+         """Generate text with timing and usage metrics
+
+         Returns:
+             tuple: (generated_text, metrics_info)
+         """
+         if not self._is_connected or self._client is None:
+             return "", {"error": "LLM client not connected"}
+
+         handler = MetricCallbackHandler()
+
+         # Create streaming client with callback
+         streaming_client = OpenAI(
+             base_url=self.base_url,
+             api_key=self.api_key,
+             model=self.model,
+             streaming=True,
+             callbacks=[handler]
+         )
+
+         # Make a single request
+         response = streaming_client.generate([prompt], callbacks=[handler])
+
+         # Get base metrics from response
+         metrics = {}
+
+         # Extract token usage from response
+         llm_output = response.llm_output if hasattr(response, 'llm_output') else {}
+
+         # Check if token_usage exists in the response
+         token_usage = llm_output.get('token_usage', {})
+         if token_usage:
+             # If token_usage is available, copy it to our metrics
+             metrics.update(token_usage)
+
+         # Add model name if available
+         if 'model_name' in llm_output:
+             metrics['model'] = llm_output['model_name']
+         else:
+             metrics['model'] = self.model
+
+         # Calculate and add timing metrics
+         metrics['ttft'] = handler.ttft or 0.0
+         metrics['total_time'] = (handler.end_time or time.time()) - handler.start_time
+         metrics['tokens_per_second'] = handler.calculate_tokens_per_second()
+         metrics['completion_tokens'] = handler.token_count
+
+         return handler.result, metrics
+
+     @connection_check
+     def list_models(self) -> list:
+         """List available models"""
+         if not self._is_connected:
+             return []
+         try:
+             response = requests.get(
+                 f"{self.base_url}/models",
+                 headers={"Authorization": f"Bearer {self.api_key}"}
+             )
+             if response.status_code == 200:
+                 return [model['id'] for model in response.json()['data']]
+             return []
+         except Exception:
+             return []
+
+     @connection_check
+     def has_model(self, model: str) -> bool:
+         """Check if model exists"""
+         if not self._is_connected:
+             return False
+         return model in self.list_models()
+
+ client = LLMClient(
+     base_url = cfn.LLM_BASE_URL,
+     api_key = cfn.LLM_API_KEY,
+     model = cfn.LLM_MODEL,
+ )

pirag-0.2.2/app/rag/llm/services.py
@@ -0,0 +1,26 @@
+ from loguru import logger
+
+ import app.rag.config as cfn
+ from .client import client
+
+ def doctor(resolve: bool):
+     # Check connection
+     is_connected = client.check_connection()
+     if not is_connected:
+         logger.error(f"- ❌ FAILED: LLM connection ({cfn.LLM_BASE_URL})")
+     else:
+         logger.info(f"- ✅ PASSED: LLM connection ({cfn.LLM_BASE_URL})")
+
+     # Check model availability
+     try:
+         if not is_connected:
+             logger.warning(f"- ⏭️ SKIPPED: LLM model (Server is not accessible)")
+         else:
+             # List models
+             models = client.list_models()
+             if cfn.LLM_MODEL not in models:
+                 logger.error(f"- ❌ FAILED: LLM model not found ({cfn.LLM_MODEL})")
+             else:
+                 logger.info(f"- ✅ PASSED: LLM model found (Model `{cfn.LLM_MODEL}` exists)")
+     except Exception as e:
+         logger.error(f"- ❌ FAILED: LLM model check ({str(e)})")

pirag-0.2.2/app/rag/llm/utilities.py
@@ -0,0 +1,40 @@
+ import time
+ from langchain.callbacks.base import BaseCallbackHandler
+
+ class MetricCallbackHandler(BaseCallbackHandler):
+     def __init__(self):
+         self.start_time = time.time()
+         self.ttft = None
+         self.first_token_time = None
+         self.result = ""
+         self.end_time = None
+         self.token_count = 0
+         self.token_timestamps = []
+
+     def on_llm_new_token(self, token: str, **kwargs):
+         current_time = time.time()
+         self.token_count += 1
+         self.token_timestamps.append(current_time)
+
+         if self.ttft is None:
+             self.ttft = current_time - self.start_time
+             self.first_token_time = current_time
+
+         self.result += token
+
+     def on_llm_end(self, *args, **kwargs):
+         self.end_time = time.time()
+
+     def calculate_tokens_per_second(self):
+         """Calculate tokens per second after the first token"""
+         if self.token_count <= 1 or self.first_token_time is None or self.end_time is None:
+             return 0.0
+
+         # Calculate time from first token to completion (exclude TTFT)
+         generation_time = self.end_time - self.first_token_time
+         if generation_time <= 0:
+             return 0.0
+
+         # Exclude the first token from the count since we're measuring from after it arrived
+         tokens_after_first = self.token_count - 1
+         return tokens_after_first / generation_time
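
A quick worked example of the throughput arithmetic in calculate_tokens_per_second above, using illustrative timings rather than a real token stream:

import time

start = time.time()
first_token_time = start + 0.5  # first token arrives after 0.5 s (this is the TTFT)
end_time = start + 2.5          # stream finishes 2.0 s after the first token
token_count = 41                # tokens received in total

generation_time = end_time - first_token_time             # 2.0 s
tokens_per_second = (token_count - 1) / generation_time   # first token excluded -> 20.0
print(f"ttft={first_token_time - start:.1f}s tokens/s={tokens_per_second:.1f}")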

pirag-0.2.2/app/rag/models.py
@@ -0,0 +1,19 @@
+ from pydantic import BaseModel
+
+ class SystemStatusResponse(BaseModel):
+     """
+     Response model for the system status endpoint.
+     """
+     status: int
+     message: str
+
+     model_config = {
+         "json_schema_extra": {
+             "examples": [
+                 {
+                     "status": 200,
+                     "message": "System is running normally"
+                 }
+             ]
+         }
+     }
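
Finally, a minimal sketch of how a Pydantic v2 response model such as SystemStatusResponse serializes; the class is redefined locally so the snippet runs standalone, and the payload mirrors the json_schema_extra example above.

from pydantic import BaseModel

class SystemStatusResponse(BaseModel):
    status: int
    message: str

resp = SystemStatusResponse(status=200, message="System is running normally")
print(resp.model_dump_json())  # {"status":200,"message":"System is running normally"}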