pirag 0.1.7__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. app/main.py +47 -65
  2. app/rag/agent/services.py +11 -0
  3. app/rag/api.py +40 -0
  4. app/rag/cli.py +54 -0
  5. app/rag/config.py +115 -122
  6. app/rag/embedding/client.py +70 -0
  7. app/rag/embedding/services.py +26 -0
  8. app/rag/llm/client.py +128 -0
  9. app/rag/llm/services.py +26 -0
  10. app/rag/llm/utilities.py +40 -0
  11. app/rag/models.py +19 -0
  12. app/rag/routers.py +41 -0
  13. app/rag/utilities.py +15 -0
  14. app/rag/v1/routers.py +7 -0
  15. app/rag/vector_store/client.py +84 -0
  16. app/rag/vector_store/services.py +56 -0
  17. app/requirements.txt +6 -5
  18. app/setup.py +2 -3
  19. {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/METADATA +6 -3
  20. pirag-0.2.2.dist-info/RECORD +27 -0
  21. {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/WHEEL +1 -1
  22. app/rag/agent.py +0 -64
  23. app/rag/ask/__init__.py +0 -2
  24. app/rag/ask/config.py +0 -9
  25. app/rag/ask/router.py +0 -4
  26. app/rag/doctor/__init__.py +0 -2
  27. app/rag/doctor/config.py +0 -24
  28. app/rag/doctor/router.py +0 -4
  29. app/rag/test/__init__.py +0 -2
  30. app/rag/test/config.py +0 -9
  31. app/rag/test/router.py +0 -4
  32. app/rag/train/__init__.py +0 -2
  33. app/rag/train/config.py +0 -20
  34. app/rag/train/router.py +0 -4
  35. app/rag/train/service.py +0 -0
  36. pirag-0.1.7.dist-info/RECORD +0 -27
  37. /app/rag/{ask/service.py → test/client.py} +0 -0
  38. /app/rag/{doctor/service.py → train/client.py} +0 -0
  39. /app/rag/{test/service.py → v1/services.py} +0 -0
  40. {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/entry_points.txt +0 -0
  41. {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {pirag-0.1.7.dist-info → pirag-0.2.2.dist-info}/top_level.txt +0 -0
app/main.py CHANGED
@@ -1,81 +1,63 @@
1
- import argparse, os
2
- from dotenv import load_dotenv
3
1
  from loguru import logger
2
+ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
4
3
 
5
- load_dotenv(dotenv_path=os.environ.get('ENV_FILE', '.env'), override=True)
6
-
7
- from app.rag.config import top_parser, common_parser, setup_logger
8
- from app.rag.doctor import help as doctor_help, parser as doctor_parser, route as doctor_route
9
- from app.rag.train import help as train_help, parser as train_parser, route as train_route
10
- from app.rag.ask import help as ask_help, parser as ask_parser, route as ask_route
11
- from app.rag.test import help as test_help, parser as test_parser, route as test_route
4
+ import app.rag.config as cfn
5
+ import app.rag.api as api
6
+ import app.rag.cli as cli
12
7
 
13
8
# Main parser
parser = ArgumentParser(
    formatter_class = ArgumentDefaultsHelpFormatter,
    description = """
    Pilot of On-Premise RAG.
    """,
    parents = [cfn.top_parser, cfn.common_parser],
    add_help = False,
)

# Command definitions.
# Each value is a 4-tuple: (help, description, function, extra_parsers).
# `main()` reads the function and the extra parsers by position (indexes 2 and 3),
# so keep this order when adding commands.
commands = {
    "serve"  : ("Start the RAG server", "Run a FastAPI-based RAG server", api.serve, []),
    "chat"   : ("Chat with the RAG system", "Run an interactive chat with the RAG system", cli.chat, [cfn.chat_parser]),
    "train"  : ("Train the RAG system", "Run a pipeline to train the RAG system", cli.train, []),
    "test"   : ("Test the RAG system", "Run a pipeline to test the RAG system", cli.test, []),
    "doctor" : ("Diagnose the RAG system", "Run a pipeline to diagnose the RAG system", cli.doctor, [cfn.doctor_parser]),
}

# Add command parsers.
subparsers = parser.add_subparsers(title="commands", dest="command")
# The loop runs at module scope, so its variables become module globals.
# Use `help_text` rather than `help` to avoid shadowing the builtin `help()`.
for name, (help_text, description, _, extra_parsers) in commands.items():
    subparsers.add_parser(
        name = name,
        help = help_text,
        description = description,
        # Sub-commands get -h/--help from common_parser plus their own options.
        parents = [cfn.common_parser] + extra_parsers,
        add_help = False,
    )
58
38
 
59
39
def main():
    """CLI entry point: parse arguments, configure logging, dispatch the sub-command.

    When no known sub-command was given, the top-level help is printed instead.
    The selected command function receives a dict holding only those
    command-specific options whose parsed value differs from the option's default.
    """
    args = parser.parse_args()
    cfn.setup_logger(cfn.LOG_LEVEL, cfn.LOG_SAVE, cfn.LOG_DIR)
    logger.debug(f"Parsed arguments: {args}")

    command_info = commands.get(args.command)
    if not command_info:
        parser.print_help()
        return

    func = command_info[2]
    extra_parsers = command_info[3]

    # Collect non-default values of the command's own options; `help` is never forwarded.
    # NOTE: relies on argparse's private `_actions` list to enumerate options.
    extra_options = {
        action.dest: getattr(args, action.dest)
        for option_parser in (extra_parsers or ())
        for action in option_parser._actions
        if action.dest != 'help'
        and hasattr(args, action.dest)
        and getattr(args, action.dest) != action.default
    }

    # Run the command with the extra parser options
    func(extra_options)


if __name__ == "__main__":
    main()
@@ -0,0 +1,11 @@
1
+ from app.rag.llm.client import client as llm_client
2
+
3
def chat_only_llm(prompt: str = "Hello, how are you?"):
    """Send one prompt directly to the LLM (no retrieval) and print the response.

    Args:
        prompt: Text passed to `llm_client.generate_with_metrics`. Defaults to
            the greeting that was previously hard-coded, so existing no-argument
            callers behave exactly as before.
    """
    response = llm_client.generate_with_metrics(prompt)
    print(response)
6
+
7
+
8
def chat_with_rag():
    """Placeholder for retrieval-augmented chat; it currently does nothing and returns None."""
10
+
11
+
app/rag/api.py ADDED
@@ -0,0 +1,40 @@
1
+ import uvicorn
2
+ from fastapi import FastAPI, APIRouter, Request, Depends, HTTPException, Query
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+
5
+ from loguru import logger
6
+ import app.rag.config as cfn
7
+
8
+ from app.rag.routers import system_router
9
+ from app.rag.v1.routers import router as v1_router
10
+
11
# FastAPI application object; `serve()` runs it via the import string "app.rag.api:api".
api = FastAPI(
    title="RAG API",
    description="API for Retrieval-Augmented Generation",
    version=cfn.__version__,
)

# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive for a credentialed API — restrict allow_origins before exposing it.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# System endpoints sit at the root; versioned endpoints are mounted under /v1.
api.include_router(router=system_router, prefix="", tags=["System"])
api.include_router(router=v1_router, prefix="/v1")
29
+
30
def serve(parser_options=None):
    """Run the RAG API with uvicorn, using host, port and reload from the config module.

    Args:
        parser_options: Optional dict of command-line options; it is only logged
            (at DEBUG level) when non-empty.
    """
    print("Serving the RAG API...")
    if parser_options:
        logger.debug(f"Serve parser options: {parser_options}")

    # The app is given as an import string ("module:attribute"), the form
    # uvicorn requires when auto-reload is enabled.
    uvicorn.run(
        app="app.rag.api:api",
        host=cfn.API_HOST,
        port=cfn.API_PORT,
        reload=cfn.API_RELOAD,
    )
app/rag/cli.py ADDED
@@ -0,0 +1,54 @@
1
+ import app.rag.config as cfn
2
+ from loguru import logger
3
+
4
+ from app.rag.llm.services import doctor as doctor_llm
5
+ from app.rag.embedding.services import doctor as doctor_embedding
6
+ from app.rag.vector_store.services import doctor as doctor_vector_store
7
+ from app.rag.agent.services import chat_only_llm, chat_with_rag
8
+
9
def chat(options: dict):
    """Handle the `chat` command.

    Args:
        options: Command-specific options. A truthy `no_rag` entry talks to the
            LLM directly; otherwise the RAG chat path is used.
    """
    logger.debug(f"Chat parser options: {options}")

    # -- Chat
    if options.get('no_rag', False):
        logger.info("💬 Chatting with the LLM system directly...")
        chat_only_llm()
        return

    logger.info("💬 Chatting with the RAG system...")
    chat_with_rag()
20
+
21
+
22
def train(options: dict):
    """Handle the `train` command (currently a stub that only announces itself).

    Args:
        options: Command-specific options, logged at DEBUG level.
    """
    banner = "Training the RAG system..."
    print(banner)
    logger.debug(f"Train parser options: {options}")
25
+
26
+
27
def test(options: dict):
    """Handle the `test` command (currently a stub that only announces itself).

    Args:
        options: Command-specific options, logged at DEBUG level.
    """
    banner = "Testing the RAG system..."
    print(banner)
    logger.debug(f"Test parser options: {options}")
30
+
31
+
32
def doctor(options: dict):
    """Handle the `doctor` command: check the LLM server, embedding server and vector store.

    Args:
        options: Command-specific options. A truthy `resolve` entry is logged and
            forwarded unchanged to each of the three checks.
    """
    logger.info("💚 Doctoring the RAG system...")

    logger.debug(f"Doctor parser options: {options}")
    # Check if resolve option is present
    resolve = options.get('resolve', False)
    if resolve:
        logger.info("🔧 Resolving issues is enabled")

    # Checks run in a fixed order: LLM server, embedding server, vector store.
    checks = (
        ("🔍 Checking the LLM server (OpenAI-compatible)...", doctor_llm),
        ("🔍 Checking the embedding server (OpenAI-compatible)...", doctor_embedding),
        ("🔍 Checking the vector store server (Milvus)...", doctor_vector_store),
    )
    for banner, check in checks:
        logger.info(banner)
        check(resolve)

    if resolve:
        # Plain string literal: the message has no placeholders, so no f-prefix.
        logger.info("🔧 Resolving issue completed. To make sure the issues are resolved, please try doctoring again.")
app/rag/config.py CHANGED
@@ -1,35 +1,84 @@
1
- import argparse, os, sys
2
- from pathlib import Path
1
+ import argparse, sys, pathlib
3
2
  from loguru import logger
3
+ from dynaconf import Dynaconf
4
+ from importlib.metadata import version, PackageNotFoundError
5
try:
    __version__ = version("pirag")
except PackageNotFoundError:
    # The `pirag` distribution is not installed (e.g. running from a source tree).
    __version__ = "0.0.0"


# -- Load configuration
# Values are read via Dynaconf from settings.yaml; every constant below falls
# back to a hard-coded default when its key is absent.
settings = Dynaconf(
    settings_files = ["settings.yaml"],
    envvar_prefix = False,
    load_dotenv = False,
)


# -- Logging
LOG_LEVELS = ("INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL")
# str() so a non-string value (e.g. a bare number in YAML) reaches the validation
# below and fails with a clear ValueError instead of an AttributeError on .upper().
LOG_LEVEL: str = str(settings.get("LOG.LEVEL", "INFO")).upper()
if LOG_LEVEL not in LOG_LEVELS:
    raise ValueError(f"Invalid log level: {LOG_LEVEL}. Must be one of: {', '.join(LOG_LEVELS)}")

LOG_SAVE: bool = settings.get("LOG.SAVE", False)
LOG_DIR: str = settings.get("LOG.DIR", ".pirag/logs")

LOG_TIME_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS!UTC}Z"
LOG_FILE_FORMAT = f"{LOG_TIME_FORMAT} | {{level: <8}} | {{name}}:{{function}}:{{line}} - {{message}}"
LOG_CONSOLE_FORMAT_FULL = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <cyan>{{name}}</cyan>:<cyan>{{function}}</cyan>:<cyan>{{line}}</cyan> - <level>{{message}}</level>\n"
LOG_CONSOLE_FORMAT_SIMPLE = f"<green>{LOG_TIME_FORMAT}</green> | <level>{{level: <8}}</level> | <level>{{message}}</level>\n"


# -- Serving API
API_HOST: str = settings.get("API.HOST", "0.0.0.0")
# int() so a quoted value in settings.yaml still yields an integer port.
API_PORT: int = int(settings.get("API.PORT", 8000))
API_RELOAD: bool = settings.get("API.RELOAD", True)


# -- LLM Server
LLM_BASE_URL: str = settings.get("LLM.BASE_URL", "http://localhost:11434")
LLM_API_KEY: str = settings.get("LLM.API_KEY", "llm_api_key")
LLM_MODEL: str = settings.get("LLM.MODEL", "gemma3:4b")
LLM_SERVER_TYPE: str = settings.get("LLM.SERVER_TYPE", "openai")


# -- Embedding Server
EMBEDDING_BASE_URL: str = settings.get("EMBEDDING.BASE_URL", "http://localhost:11434")
EMBEDDING_API_KEY: str = settings.get("EMBEDDING.API_KEY", "embedding_api_key")
EMBEDDING_MODEL: str = settings.get("EMBEDDING.MODEL", "nomic-embed-text:latest")
EMBEDDING_SERVER_TYPE: str = settings.get("EMBEDDING.SERVER_TYPE", "openai")
# int() so a quoted value in settings.yaml still yields an integer dimension.
EMBEDDING_DIMENSION: int = int(settings.get("EMBEDDING.DIMENSION", 768))


# -- Data Warehouse
MINIO_BASE_URL: str = settings.get("MINIO.BASE_URL", "http://localhost:9000")
MINIO_ACCESS_KEY: str = settings.get("MINIO.ACCESS_KEY", "minioadmin")
MINIO_SECRET_KEY: str = settings.get("MINIO.SECRET_KEY", "minioadmin")
MINIO_BUCKET: str = settings.get("MINIO.BUCKET", "pirag")
MINIO_REGION: str = settings.get("MINIO.REGION", "us-east-1")


# -- Vector Store
MILVUS_BASE_URL: str = settings.get("MILVUS.BASE_URL", "http://localhost:19530")
MILVUS_USER: str = settings.get("MILVUS.USER", "milvus")
MILVUS_PASSWORD: str = settings.get("MILVUS.PASSWORD", "milvus")
MILVUS_DATABASE: str = settings.get("MILVUS.DATABASE", "milvus_database")
MILVUS_COLLECTION: str = settings.get("MILVUS.COLLECTION", "milvus_collection")
MILVUS_METRIC_TYPE: str = settings.get("MILVUS.METRIC_TYPE", "IP")


# -- Monitoring
# NOTE(review): this default URL uses port 8000, the same as API_PORT's default;
# confirm the intended Langfuse address.
LANGFUSE_BASE_URL: str = settings.get("LANGFUSE.BASE_URL", "http://localhost:8000")
LANGFUSE_API_KEY: str = settings.get("LANGFUSE.API_KEY", "langfuse_api_key")
LANGFUSE_PROJECT_ID: str = settings.get("LANGFUSE.PROJECT_ID", "langfuse_project_id")
75
+
76
+
77
+ def setup_logger(log_level: str, log_save: bool, log_dir: str):
12
78
  """Configure logger with specified level and outputs"""
13
-
14
- log_dir = Path(log_dir)
15
- log_dir.mkdir(exist_ok=True, parents=True)
16
-
79
+
17
80
  logger.remove()
18
-
19
- # File handler
20
- logger.add(
21
- sink = log_dir / "{time:YYYYMMDD-HHmmss!UTC}Z.log",
22
- level = log_level,
23
- rotation = "100 MB",
24
- retention = 0,
25
- format = LOG_FILE_FORMAT,
26
- serialize = False,
27
- enqueue = True,
28
- backtrace = True,
29
- diagnose = True,
30
- catch = True
31
- )
32
-
81
+
33
82
  # Console handler
34
83
  logger.add(
35
84
  sink = sys.stderr,
@@ -38,114 +87,58 @@ def setup_logger(log_level: str, log_dir: str):
38
87
  colorize = True
39
88
  )
40
89
 
41
-
42
- class EnvDefault(argparse.Action):
43
- """Custom argparse action that uses environment variables as defaults.
44
-
45
- This action extends the standard argparse.Action to support reading default values
46
- from environment variables. If the specified environment variable exists, its value
47
- will be used as the default value for the argument.
48
-
49
- For boolean flags (store_true/store_false), the environment variable is interpreted
50
- as a boolean value where 'true', '1', 'yes', or 'on' (case-insensitive) are
51
- considered True.
52
-
53
- Args:
54
- envvar (str): Name of the environment variable to use as default
55
- required (bool, optional): Whether the argument is required. Defaults to True.
56
- Note: If a default value is found in environment variables, required is set to False.
57
- default (Any, optional): Default value if environment variable is not set. Defaults to None.
58
- **kwargs: Additional arguments passed to argparse.Action
59
-
60
- Example:
61
- ```python
62
- parser.add_argument(
63
- '--log-level',
64
- envvar='LOG_LEVEL',
65
- help='Logging level',
66
- default='INFO',
67
- action=EnvDefault
90
+ if log_save:
91
+ log_dir = pathlib.Path(log_dir)
92
+ log_dir.mkdir(exist_ok=True, parents=True)
93
+
94
+ # File handler
95
+ logger.add(
96
+ sink = log_dir / "{time:YYYYMMDD-HHmmss!UTC}Z.log",
97
+ level = log_level,
98
+ rotation = "100 MB",
99
+ retention = 0,
100
+ format = LOG_FILE_FORMAT,
101
+ serialize = False,
102
+ enqueue = True,
103
+ backtrace = True,
104
+ diagnose = True,
105
+ catch = True
68
106
  )
69
- ```
70
-
71
- Note:
72
- The help text is automatically updated to include the environment variable name.
73
- """
74
- def __init__(self, envvar, required=True, default=None, **kwargs):
75
- if envvar and envvar in os.environ:
76
- env_value = os.environ[envvar]
77
- # Convert string environment variable to boolean
78
- if kwargs.get('nargs') is None and kwargs.get('const') is not None: # store_true/store_false case
79
- default = env_value.lower() in ('true', '1', 'yes', 'on')
80
- else:
81
- default = env_value
82
- logger.debug(f"Using {envvar}={default} from environment")
83
-
84
- if envvar:
85
- kwargs["help"] += f" (envvar: {envvar})"
86
-
87
- if required and default:
88
- required = False
89
-
90
- super(EnvDefault, self).__init__(default=default, required=required, **kwargs)
91
- self.envvar = envvar
92
-
93
- def __call__(self, parser, namespace, values, option_string=None):
94
- setattr(namespace, self.dest, values if values is not None else self.default)
95
-
96
-
97
- # Top-level parser with common options
98
- top_parser = argparse.ArgumentParser(add_help=False)
99
107
 
100
- top_parser.add_argument(
101
- '-h', '--help',
102
- help = 'Show help message and exit',
103
- default = argparse.SUPPRESS,
104
- action = 'help',
105
- )
106
108
 
109
def _single_option_parser(*flags, **option_kwargs):
    """Return a help-less ArgumentParser defining exactly one option, for use as an argparse parent."""
    option_parser = argparse.ArgumentParser(add_help=False)
    option_parser.add_argument(*flags, **option_kwargs)
    return option_parser


# Top-level parser: the -v/--version flag.
top_parser = _single_option_parser(
    "-v", "--version",
    action="version",
    version=str(__version__),
    help="Show the `pirag` application's version and exit",
)

# Common parser: an explicit -h/--help (parsers here are built with add_help=False).
common_parser = _single_option_parser(
    "-h", "--help",
    action="help",
    default=argparse.SUPPRESS,
    help="Show help message and exit",
)

# Chat parser: -n/--no-rag boolean flag.
chat_parser = _single_option_parser(
    "-n", "--no-rag",
    action="store_true",
    help="Do not use RAG to answer the question. Just use the LLM to answer the question.",
)

# Doctor parser: -r/--resolve boolean flag.
doctor_parser = _single_option_parser(
    "-r", "--resolve",
    action="store_true",
    help="Resolve the issue",
)
@@ -0,0 +1,70 @@
1
+ import requests
2
+ from langchain_openai.embeddings import OpenAIEmbeddings
3
+
4
+ import app.rag.config as cfn
5
+ from app.rag.utilities import connection_check
6
+
7
+
8
class EmbeddingClient:
    """Wrapper around an OpenAI-compatible embedding server.

    Reachability is probed once in the constructor. If the server cannot be
    reached, or the LangChain client cannot be built, the instance is marked as
    disconnected and its methods return empty results instead of raising.
    """

    def __init__(self, base_url: str, api_key: str, model: str):
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self._is_connected = True
        self._client = None

        if self.check_connection():
            try:
                self._client = OpenAIEmbeddings(
                    base_url = base_url,
                    api_key = api_key,
                    model = model,
                )
            except Exception:
                # Best-effort: a client that cannot be built is treated as offline.
                self._is_connected = False

    def check_connection(self) -> bool:
        """Return True if the server answers an HTTP HEAD on `base_url` within 5 s.

        Any `requests` failure (refused connection, timeout, malformed or
        schemeless URL) counts as "not connected" rather than propagating.
        The result is also stored in `self._is_connected`.
        """
        try:
            requests.head(url=self.base_url, timeout=5)
        except requests.exceptions.RequestException:
            self._is_connected = False
            return False
        self._is_connected = True
        return True

    @connection_check
    def generate(self, prompt: str) -> list:
        """Return the embedding vector for `prompt`, or [] when no client is available."""
        if not self._is_connected or self._client is None:
            return []
        return self._client.embed_query(prompt)

    @connection_check
    def list_models(self) -> list:
        """Return the model ids from `GET {base_url}/models`, or [] on any failure."""
        if not self._is_connected:
            return []
        try:
            response = requests.get(
                # rstrip avoids a double slash when base_url ends with "/".
                f"{self.base_url.rstrip('/')}/models",
                headers={"Authorization": f"Bearer {self.api_key}"},
                # Without a timeout an unresponsive server would block forever.
                timeout=5,
            )
            if response.status_code != 200:
                return []
            return [model['id'] for model in response.json()['data']]
        except Exception:
            # Best-effort listing: network errors and unexpected payloads yield [].
            return []

    @connection_check
    def has_model(self, model: str) -> bool:
        """Return True if `model` appears in the server's model list."""
        if not self._is_connected:
            return False
        return model in self.list_models()
65
+
66
# Module-level instance configured from app.rag.config; constructing it probes
# the configured embedding server once.
client = EmbeddingClient(base_url=cfn.EMBEDDING_BASE_URL, api_key=cfn.EMBEDDING_API_KEY, model=cfn.EMBEDDING_MODEL)
@@ -0,0 +1,26 @@
1
+ from loguru import logger
2
+
3
+ import app.rag.config as cfn
4
+ from .client import client
5
+
6
def doctor(resolve: bool):
    """Diagnose the embedding server: connectivity first, then the configured model.

    Results are reported only through the logger; nothing is returned.

    Args:
        resolve: Accepted for the same call signature as the other doctor checks;
            currently unused here.
    """
    # Check connection
    is_connected = client.check_connection()
    if is_connected:
        logger.info(f"- ✅ PASSED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")
    else:
        logger.error(f"- ❌ FAILED: Embedding connection ({cfn.EMBEDDING_BASE_URL})")

    # Check model availability (meaningless without a reachable server).
    if not is_connected:
        logger.warning("- ⏭️ SKIPPED: Embedding model (Server is not accessible)")
        return

    try:
        models = client.list_models()
    except Exception as e:
        logger.error(f"- ❌ FAILED: Embedding model check ({e})")
        return

    if cfn.EMBEDDING_MODEL in models:
        logger.info(f"- ✅ PASSED: Embedding model found (Model `{cfn.EMBEDDING_MODEL}` exists)")
    else:
        logger.error(f"- ❌ FAILED: Embedding model not found ({cfn.EMBEDDING_MODEL})")