code-graph-rag 0.0.79__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codebase_rag/config.py ADDED
@@ -0,0 +1,370 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict, dataclass
4
+ from pathlib import Path
5
+ from typing import TypedDict, Unpack
6
+
7
+ from dotenv import load_dotenv
8
+ from loguru import logger
9
+ from pydantic import Field
10
+ from pydantic_settings import BaseSettings, SettingsConfigDict
11
+
12
+ from . import constants as cs
13
+ from . import exceptions as ex
14
+ from . import logs
15
+ from .types_defs import CgrignorePatterns, ModelConfigKwargs
16
+
17
+ load_dotenv()
18
+
19
+
20
# Shape of one API_KEY_INFO entry; all three keys are required strings.
#   env_var: name of the environment variable holding the provider's API key
#   url:     where the message tells the user to get a key
#   name:    provider name as printed in the missing-key message
ApiKeyInfoEntry = TypedDict(
    "ApiKeyInfoEntry",
    {
        "env_var": str,
        "url": str,
        "name": str,
    },
)
24
+
25
+
26
# Lookup table used by format_missing_api_key_errors: for each provider key,
# the API-key environment variable, the URL printed in the help text, and the
# display name. Rows are (provider, env_var, url, name).
API_KEY_INFO: dict[str, ApiKeyInfoEntry] = {
    provider_key: {"env_var": env_var, "url": url, "name": name}
    for provider_key, env_var, url, name in (
        (
            cs.Provider.OPENAI,
            "OPENAI_API_KEY",
            "https://platform.openai.com/api-keys",
            "OpenAI",
        ),
        (
            cs.Provider.ANTHROPIC,
            "ANTHROPIC_API_KEY",
            "https://console.anthropic.com/settings/keys",
            "Anthropic",
        ),
        (
            cs.Provider.GOOGLE,
            "GOOGLE_API_KEY",
            "https://console.cloud.google.com/apis/credentials",
            "Google AI",
        ),
        (
            cs.Provider.AZURE,
            "AZURE_API_KEY",
            "https://portal.azure.com/",
            "Azure OpenAI",
        ),
        (
            cs.Provider.COHERE,
            "COHERE_API_KEY",
            "https://dashboard.cohere.com/api-keys",
            "Cohere",
        ),
    )
}
53
+
54
+
55
def format_missing_api_key_errors(
    provider: str, role: str = cs.DEFAULT_MODEL_ROLE
) -> str:
    """Build the multi-line "API Key Missing" help message for *provider*.

    Args:
        provider: Provider name. It is lower-cased before being looked up in
            API_KEY_INFO; when no entry exists, the environment variable name
            is derived as ``<PROVIDER>_API_KEY`` and a generic URL phrase and
            capitalized provider name are used instead.
        role: Model role. When it differs from ``cs.DEFAULT_MODEL_ROLE`` the
            message says which role the key is needed for.

    Returns:
        The formatted message with surrounding whitespace stripped.
    """
    known = API_KEY_INFO.get(provider.lower())
    if known is None:
        # No table entry: fall back to values derived from the provider name.
        env_var, url, name = (
            f"{provider.upper()}_API_KEY",
            f"your {provider} provider's website",
            provider.capitalize(),
        )
    else:
        env_var, url, name = known["env_var"], known["url"], known["name"]

    if role == cs.DEFAULT_MODEL_ROLE:
        role_msg = ""
    else:
        role_msg = f" for {role}"

    return f"""
─── API Key Missing ───────────────────────────────────────────────

Error: {env_var} environment variable is not set.
This is required to use {name}{role_msg}.

To fix this:

1. Get your API key from:
{url}

2. Set it in your environment:
export {env_var}='your-key-here'

Or add it to your .env file in the project root:
{env_var}=your-key-here

3. Alternatively, you can use a local model with Ollama:
(No API key required)

───────────────────────────────────────────────────────────────────
""".strip()  # noqa: W293
95
+
96
+
97
# Providers for which ModelConfig.validate_api_key returns without checking
# for an API key.
LOCAL_PROVIDERS = frozenset(
    (
        cs.Provider.OLLAMA,
        cs.Provider.LOCAL,
        cs.Provider.VLLM,
    )
)
98
+
99
+
100
@dataclass
class ModelConfig:
    """Provider/model selection plus its optional connection parameters."""

    provider: str
    model_id: str
    api_key: str | None = None
    endpoint: str | None = None
    project_id: str | None = None
    region: str | None = None
    provider_type: str | None = None
    thinking_budget: int | None = None
    service_account_file: str | None = None

    def to_update_kwargs(self) -> ModelConfigKwargs:
        """Return every field except provider and model id as ModelConfigKwargs."""
        remaining = asdict(self)
        # pop() without a default raises KeyError on a missing key, like del.
        remaining.pop(cs.FIELD_PROVIDER)
        remaining.pop(cs.FIELD_MODEL_ID)
        return ModelConfigKwargs(**remaining)

    def validate_api_key(self, role: str = cs.DEFAULT_MODEL_ROLE) -> None:
        """Raise if this configuration needs an API key and has no usable one.

        Providers in LOCAL_PROVIDERS, and the Google provider when
        ``provider_type`` is the Vertex type, are accepted without a key.

        Args:
            role: Passed through to the error message formatter.

        Raises:
            ValueError: If ``api_key`` is None, empty, whitespace-only, or
                equal to ``cs.DEFAULT_API_KEY``.
        """
        normalized = self.provider.lower()
        if normalized in LOCAL_PROVIDERS:
            return

        uses_vertex = (
            normalized == cs.Provider.GOOGLE
            and self.provider_type == cs.GoogleProviderType.VERTEX
        )
        if uses_vertex:
            return

        key = self.api_key
        if key is None or not key.strip() or key == cs.DEFAULT_API_KEY:
            raise ValueError(format_missing_api_key_errors(self.provider, role))
132
+
133
+
134
class AppConfig(BaseSettings):
    """
    (H) All settings are loaded from environment variables or a .env file.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    MEMGRAPH_HOST: str = "localhost"
    MEMGRAPH_PORT: int = 7687
    MEMGRAPH_HTTP_PORT: int = 7444
    LAB_PORT: int = 3000
    # Fallback used by resolve_batch_size when the caller passes None.
    MEMGRAPH_BATCH_SIZE: int = 1000
    AGENT_RETRIES: int = 3
    ORCHESTRATOR_OUTPUT_RETRIES: int = 100

    # The <ROLE>_* groups below are read by name (getattr) in
    # _get_default_config.
    ORCHESTRATOR_PROVIDER: str = ""
    ORCHESTRATOR_MODEL: str = ""
    ORCHESTRATOR_API_KEY: str | None = None
    ORCHESTRATOR_ENDPOINT: str | None = None
    ORCHESTRATOR_PROJECT_ID: str | None = None
    ORCHESTRATOR_REGION: str = cs.DEFAULT_REGION
    ORCHESTRATOR_PROVIDER_TYPE: cs.GoogleProviderType | None = None
    ORCHESTRATOR_THINKING_BUDGET: int | None = None
    ORCHESTRATOR_SERVICE_ACCOUNT_FILE: str | None = None

    CYPHER_PROVIDER: str = ""
    CYPHER_MODEL: str = ""
    CYPHER_API_KEY: str | None = None
    CYPHER_ENDPOINT: str | None = None
    CYPHER_PROJECT_ID: str | None = None
    CYPHER_REGION: str = cs.DEFAULT_REGION
    CYPHER_PROVIDER_TYPE: cs.GoogleProviderType | None = None
    CYPHER_THINKING_BUDGET: int | None = None
    CYPHER_SERVICE_ACCOUNT_FILE: str | None = None

    OLLAMA_BASE_URL: str = "http://localhost:11434"

    TARGET_REPO_PATH: str = "."
    SHELL_COMMAND_TIMEOUT: int = 30
    SHELL_COMMAND_ALLOWLIST: frozenset[str] = frozenset(
        {
            "ls",
            "rg",
            "cat",
            "git",
            "echo",
            "pwd",
            "pytest",
            "mypy",
            "ruff",
            "uv",
            "find",
            "pre-commit",
            "rm",
            "cp",
            "mv",
            "mkdir",
            "rmdir",
            "wc",
            "head",
            "tail",
            "sort",
            "uniq",
            "cut",
            "tr",
            "xargs",
            "awk",
            "sed",
            "tee",
        }
    )
    SHELL_READ_ONLY_COMMANDS: frozenset[str] = frozenset(
        {
            "ls",
            "cat",
            "find",
            "pwd",
            "rg",
            "echo",
            "wc",
            "head",
            "tail",
            "sort",
            "uniq",
            "cut",
            "tr",
        }
    )
    SHELL_SAFE_GIT_SUBCOMMANDS: frozenset[str] = frozenset(
        {
            "status",
            "log",
            "diff",
            "show",
            "ls-files",
            "remote",
            "config",
            "branch",
        }
    )

    QDRANT_DB_PATH: str = "./.qdrant_code_embeddings"
    QDRANT_COLLECTION_NAME: str = "code_embeddings"
    QDRANT_VECTOR_DIM: int = 768
    QDRANT_TOP_K: int = 5
    EMBEDDING_MAX_LENGTH: int = 512
    EMBEDDING_PROGRESS_INTERVAL: int = 10

    CACHE_MAX_ENTRIES: int = 1000
    CACHE_MAX_MEMORY_MB: int = 500
    CACHE_EVICTION_DIVISOR: int = 10
    CACHE_MEMORY_THRESHOLD_RATIO: float = 0.8

    OLLAMA_HEALTH_TIMEOUT: float = 5.0

    # Runtime overrides assigned by set_orchestrator / set_cypher; while unset,
    # the active_* properties fall back to the defaults built from the fields.
    _active_orchestrator: ModelConfig | None = None
    _active_cypher: ModelConfig | None = None

    QUIET: bool = Field(False, validation_alias="CGR_QUIET")

    @property
    def ollama_endpoint(self) -> str:
        """OLLAMA_BASE_URL without trailing slashes, followed by ``/v1``."""
        base_url = self.OLLAMA_BASE_URL.rstrip("/")
        return f"{base_url}/v1"

    @property
    def active_orchestrator_config(self) -> ModelConfig:
        """The orchestrator override if one was set, else the default config."""
        override = self._active_orchestrator
        return override if override else self._get_default_orchestrator_config()

    @property
    def active_cypher_config(self) -> ModelConfig:
        """The cypher override if one was set, else the default config."""
        override = self._active_cypher
        return override if override else self._get_default_cypher_config()

    def _get_default_config(self, role: str) -> ModelConfig:
        """Build a ModelConfig for *role* from the ``<ROLE>_*`` settings.

        When either ``<ROLE>_PROVIDER`` or ``<ROLE>_MODEL`` is missing or
        empty, the Ollama default (default model, ``ollama_endpoint``,
        ``cs.DEFAULT_API_KEY``) is returned instead.
        """
        prefix = role.upper()
        provider = getattr(self, f"{prefix}_PROVIDER", None)
        model = getattr(self, f"{prefix}_MODEL", None)

        if not (provider and model):
            return ModelConfig(
                provider=cs.Provider.OLLAMA,
                model_id=cs.DEFAULT_MODEL,
                endpoint=self.ollama_endpoint,
                api_key=cs.DEFAULT_API_KEY,
            )

        return ModelConfig(
            provider=provider.lower(),
            model_id=model,
            api_key=getattr(self, f"{prefix}_API_KEY", None),
            endpoint=getattr(self, f"{prefix}_ENDPOINT", None),
            project_id=getattr(self, f"{prefix}_PROJECT_ID", None),
            region=getattr(self, f"{prefix}_REGION", cs.DEFAULT_REGION),
            provider_type=getattr(self, f"{prefix}_PROVIDER_TYPE", None),
            thinking_budget=getattr(self, f"{prefix}_THINKING_BUDGET", None),
            service_account_file=getattr(
                self, f"{prefix}_SERVICE_ACCOUNT_FILE", None
            ),
        )

    def _get_default_orchestrator_config(self) -> ModelConfig:
        """Default config for the orchestrator role."""
        return self._get_default_config(cs.ModelRole.ORCHESTRATOR)

    def _get_default_cypher_config(self) -> ModelConfig:
        """Default config for the cypher role."""
        return self._get_default_config(cs.ModelRole.CYPHER)

    def set_orchestrator(
        self, provider: str, model: str, **kwargs: Unpack[ModelConfigKwargs]
    ) -> None:
        """Override the orchestrator config; *provider* is stored lower-cased."""
        self._active_orchestrator = ModelConfig(
            provider=provider.lower(), model_id=model, **kwargs
        )

    def set_cypher(
        self, provider: str, model: str, **kwargs: Unpack[ModelConfigKwargs]
    ) -> None:
        """Override the cypher config; *provider* is stored lower-cased."""
        self._active_cypher = ModelConfig(
            provider=provider.lower(), model_id=model, **kwargs
        )

    def parse_model_string(self, model_string: str) -> tuple[str, str]:
        """Split ``"provider:model"`` into ``(provider, model)``.

        A string without ``:`` is treated as an Ollama model name. Only the
        first ``:`` separates; later colons stay in the model part.

        Raises:
            ValueError: If the part before the first ``:`` is empty.
        """
        provider, separator, model = model_string.partition(":")
        if not separator:
            return cs.Provider.OLLAMA, model_string
        if not provider:
            raise ValueError(ex.PROVIDER_EMPTY)
        return provider.lower(), model

    def resolve_batch_size(self, batch_size: int | None) -> int:
        """Return *batch_size*, or MEMGRAPH_BATCH_SIZE when it is None.

        Raises:
            ValueError: If the resolved value is less than 1.
        """
        if batch_size is None:
            batch_size = self.MEMGRAPH_BATCH_SIZE
        if batch_size < 1:
            raise ValueError(ex.BATCH_SIZE_POSITIVE)
        return batch_size
329
+
330
+
331
# Module-level settings instance, constructed once when this module is imported.
settings = AppConfig()

# File name that load_cgrignore_patterns looks for directly under the repo path.
CGRIGNORE_FILENAME = ".cgrignore"


# Result returned by load_cgrignore_patterns when the ignore file is missing
# or cannot be read.
EMPTY_CGRIGNORE = CgrignorePatterns(exclude=frozenset(), unignore=frozenset())
337
+
338
+
339
def load_cgrignore_patterns(repo_path: Path) -> CgrignorePatterns:
    """Read the ignore file under *repo_path* into exclude/unignore pattern sets.

    Each line is stripped. Blank lines and lines starting with ``#`` are
    skipped. A line starting with ``!`` contributes the rest of the line
    (stripped) to ``unignore``; any other line goes to ``exclude``. A bare
    ``!`` carries no pattern and is ignored.

    Args:
        repo_path: Directory expected to contain ``CGRIGNORE_FILENAME``.

    Returns:
        The parsed patterns, or ``EMPTY_CGRIGNORE`` when the file does not
        exist or cannot be read or decoded as UTF-8 (a warning is logged in
        the failure case).
    """
    ignore_file = repo_path / CGRIGNORE_FILENAME
    if not ignore_file.is_file():
        return EMPTY_CGRIGNORE

    exclude: set[str] = set()
    unignore: set[str] = set()
    try:
        with ignore_file.open(encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                if line.startswith("!"):
                    pattern = line[1:].strip()
                    # Skip a bare "!" so the empty string never becomes a pattern.
                    if pattern:
                        unignore.add(pattern)
                else:
                    exclude.add(line)
    except (OSError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is a ValueError, not an OSError; without it a
        # badly encoded ignore file would crash the caller instead of being
        # treated like any other unreadable file.
        logger.warning(logs.CGRIGNORE_READ_FAILED.format(path=ignore_file, error=e))
        return EMPTY_CGRIGNORE

    if exclude or unignore:
        logger.info(
            logs.CGRIGNORE_LOADED.format(
                exclude_count=len(exclude),
                unignore_count=len(unignore),
                path=ignore_file,
            )
        )
    return CgrignorePatterns(
        exclude=frozenset(exclude),
        unignore=frozenset(unignore),
    )