ragit 0.8.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ragit/config.py CHANGED
@@ -3,9 +3,10 @@
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
  #
5
5
  """
6
- Ragit configuration management.
6
+ Ragit configuration management with Pydantic validation.
7
7
 
8
8
  Loads configuration from environment variables and .env files.
9
+ Validates all configuration values at startup.
9
10
 
10
11
  Note: As of v0.8.0, ragit no longer has default LLM or embedding models.
11
12
  Users must explicitly configure providers.
@@ -15,6 +16,10 @@ import os
15
16
  from pathlib import Path
16
17
 
17
18
  from dotenv import load_dotenv
19
+ from pydantic import BaseModel, Field, field_validator
20
+
21
+ # Note: We define ConfigValidationError locally to avoid circular imports,
22
+ # but ragit.exceptions.ConfigurationError can be used elsewhere
18
23
 
19
24
  # Load .env file from current working directory or project root
20
25
  _env_path = Path.cwd() / ".env"
@@ -29,32 +34,170 @@ else:
29
34
  break
30
35
 
31
36
 
32
- class Config:
33
- """Ragit configuration loaded from environment variables.
37
class ConfigValidationError(Exception):
    """Signal that ragit configuration could not be loaded or validated."""
41
+
42
+
43
class RagitConfig(BaseModel):
    """Validated ragit configuration.

    Every field is validated when the model is constructed. An invalid
    value makes pydantic raise its own validation error from the
    constructor; unknown keyword arguments are rejected as well
    (``extra="forbid"``).

    Attributes
    ----------
    ollama_base_url : str
        Ollama server URL (default: http://localhost:11434).
    ollama_embedding_url : str | None
        Embedding API URL. ``None`` means "not configured"; the
        ``OLLAMA_EMBEDDING_URL`` property then falls back to
        ``ollama_base_url``.
    ollama_api_key : str | None
        API key for authentication. Blank values are stored as ``None``.
    ollama_timeout : int
        Request timeout in seconds (1-600).
    default_llm_model : str | None
        Default LLM model name.
    default_embedding_model : str | None
        Default embedding model name.
    log_level : str
        Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
    """

    ollama_base_url: str = Field(default="http://localhost:11434")
    ollama_embedding_url: str | None = None
    ollama_api_key: str | None = None
    ollama_timeout: int = Field(default=120, gt=0, le=600)
    default_llm_model: str | None = None
    default_embedding_model: str | None = None
    log_level: str = Field(default="INFO")

    @field_validator("ollama_base_url", "ollama_embedding_url", mode="before")
    @classmethod
    def validate_url(cls, v: str | None) -> str | None:
        """Normalise a URL and check that it is an http(s) URL with a host.

        ``None`` and blank values are returned as ``None``. Surrounding
        whitespace and trailing slashes are removed from accepted URLs.

        Raises
        ------
        ValueError
            If the scheme is not http/https or the URL has no host part.
        """
        from urllib.parse import urlsplit

        if v is None:
            return v
        candidate = str(v).strip()
        url = candidate.rstrip("/")
        if not url:
            return None
        # The scheme is checked before trailing slashes are removed, so a bare
        # "http://" is reported as a missing host rather than a bad scheme.
        if not candidate.startswith(("http://", "https://")):
            raise ValueError(f"URL must start with http:// or https://: {url}")
        if not urlsplit(candidate).netloc:
            raise ValueError(f"URL must include a host: {candidate}")
        return url

    @field_validator("log_level", mode="before")
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Upper-case the log level and check it is a standard logging level.

        Raises
        ------
        ValueError
            If the value is not one of the five accepted level names.
        """
        # A tuple keeps the error message ordering stable between runs.
        valid_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
        level = str(v).upper().strip()
        if level not in valid_levels:
            raise ValueError(f"Invalid log level: {level}. Must be one of {', '.join(valid_levels)}")
        return level

    @field_validator("ollama_api_key", mode="before")
    @classmethod
    def validate_api_key(cls, v: str | None) -> str | None:
        """Treat a blank API key as ``None`` and trim surrounding whitespace."""
        if v is None or not str(v).strip():
            return None
        # Non-string input is passed through so pydantic's own type check applies.
        return v.strip() if isinstance(v, str) else v

    @field_validator("ollama_timeout", mode="before")
    @classmethod
    def validate_timeout(cls, v: int | str) -> int:
        """Parse the timeout into an integer.

        The range check (1-600) is done by the field constraints afterwards.

        Raises
        ------
        ValueError
            If the value cannot be interpreted as a whole number.
        """
        # int() would silently truncate 1.9 to 1 and raise OverflowError for
        # infinity; both are rejected explicitly instead.
        if isinstance(v, float) and not v.is_integer():
            raise ValueError(f"Invalid timeout value '{v}': must be an integer")
        try:
            timeout = int(v)
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid timeout value '{v}': must be an integer") from e
        return timeout

    model_config = {"extra": "forbid"}

    # Uppercase aliases for backwards compatibility
    @property
    def OLLAMA_BASE_URL(self) -> str:
        return self.ollama_base_url

    @property
    def OLLAMA_EMBEDDING_URL(self) -> str:
        # Falls back to the base URL when no embedding URL is configured.
        return self.ollama_embedding_url or self.ollama_base_url

    @property
    def OLLAMA_API_KEY(self) -> str | None:
        return self.ollama_api_key

    @property
    def OLLAMA_TIMEOUT(self) -> int:
        return self.ollama_timeout

    @property
    def DEFAULT_LLM_MODEL(self) -> str | None:
        return self.default_llm_model

    @property
    def DEFAULT_EMBEDDING_MODEL(self) -> str | None:
        return self.default_embedding_model

    @property
    def LOG_LEVEL(self) -> str:
        return self.log_level
146
+
147
+
148
+ def _safe_get_env(key: str, default: str | None = None) -> str | None:
149
+ """Get environment variable, returning None for empty strings."""
150
+ value = os.getenv(key, default)
151
+ if value is not None and not value.strip():
152
+ return default
153
+ return value
154
+
155
+
156
+ def _safe_get_int_env(key: str, default: int) -> int:
157
+ """Get environment variable as int, raising on invalid values."""
158
+ value = os.getenv(key)
159
+ if value is None:
160
+ return default
161
+ try:
162
+ return int(value)
163
+ except ValueError:
164
+ raise ConfigValidationError(f"Invalid integer value for {key}: {value!r}") from None
165
+
166
+
167
def load_config() -> RagitConfig:
    """Build a validated ``RagitConfig`` from environment variables.

    Returns
    -------
    RagitConfig
        Validated configuration object.

    Raises
    ------
    ConfigValidationError
        If reading or validating any setting fails. Every failure raised
        while building the configuration is wrapped with a
        "Configuration error: " prefix.
    """
    fallback_url = "http://localhost:11434"
    fallback_level = "INFO"
    try:
        # Values are read inside the try block so that a bad OLLAMA_TIMEOUT
        # is reported through the same wrapper as a pydantic failure.
        base_url = _safe_get_env("OLLAMA_BASE_URL", fallback_url) or fallback_url
        embedding_url = _safe_get_env("OLLAMA_EMBEDDING_URL") or _safe_get_env("OLLAMA_BASE_URL")
        api_key = _safe_get_env("OLLAMA_API_KEY")
        timeout = _safe_get_int_env("OLLAMA_TIMEOUT", 120)
        llm_model = _safe_get_env("RAGIT_DEFAULT_LLM_MODEL")
        embedding_model = _safe_get_env("RAGIT_DEFAULT_EMBEDDING_MODEL")
        level = _safe_get_env("RAGIT_LOG_LEVEL", fallback_level) or fallback_level
        return RagitConfig(
            ollama_base_url=base_url,
            ollama_embedding_url=embedding_url,
            ollama_api_key=api_key,
            ollama_timeout=timeout,
            default_llm_model=llm_model,
            default_embedding_model=embedding_model,
            log_level=level,
        )
    except Exception as e:
        raise ConfigValidationError(f"Configuration error: {e}") from e
49
192
 
50
- # Model settings (only used if explicitly requested, no defaults)
51
- # These can still be set via environment variables for convenience
52
- DEFAULT_LLM_MODEL: str | None = os.getenv("RAGIT_DEFAULT_LLM_MODEL")
53
- DEFAULT_EMBEDDING_MODEL: str | None = os.getenv("RAGIT_DEFAULT_EMBEDDING_MODEL")
54
193
 
55
- # Logging
56
- LOG_LEVEL: str = os.getenv("RAGIT_LOG_LEVEL", "INFO")
194
# Singleton instance - validates configuration at import time.
# load_config() already raises ConfigValidationError with a descriptive
# message, so the error is left to propagate unchanged; wrapping it in a
# second identical exception would only duplicate the traceback.
config = load_config()


# Backwards compatibility alias.
# NOTE(review): the previous ``Config`` exposed its settings as class
# attributes; on ``RagitConfig`` the uppercase names are instance properties,
# so ``Config.OLLAMA_BASE_URL`` no longer yields a string. Callers should read
# from the ``config`` instance.
Config = RagitConfig
@@ -45,7 +45,13 @@ class Document:
45
45
 
46
46
  @dataclass
47
47
  class Chunk:
48
- """A document chunk."""
48
+ """A document chunk with optional rich metadata.
49
+
50
+ Metadata can include:
51
+ - document_id: SHA256 hash for deduplication and window search
52
+ - sequence_number: Order within the document
53
+ - chunk_start/chunk_end: Character positions in original text
54
+ """
49
55
 
50
56
  content: str
51
57
  doc_id: str
ragit/exceptions.py ADDED
@@ -0,0 +1,271 @@
1
+ #
2
+ # Copyright RODMENA LIMITED 2025
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ """
6
+ Custom exception hierarchy for ragit.
7
+
8
+ Provides structured exceptions for different failure types,
9
+ enabling better error handling and debugging.
10
+
11
+ Pattern inspired by ai4rag exception_handler.py.
12
+ """
13
+
14
+ from typing import Any
15
+
16
+
17
+ class RagitError(Exception):
18
+ """Base exception for all ragit errors.
19
+
20
+ All ragit-specific exceptions inherit from this class,
21
+ making it easy to catch all ragit errors with a single handler.
22
+
23
+ Parameters
24
+ ----------
25
+ message : str
26
+ Human-readable error message.
27
+ original_exception : BaseException, optional
28
+ The underlying exception that caused this error.
29
+
30
+ Examples
31
+ --------
32
+ >>> try:
33
+ ... provider.embed("text", "model")
34
+ ... except RagitError as e:
35
+ ... print(f"Ragit error: {e}")
36
+ ... if e.original_exception:
37
+ ... print(f"Caused by: {e.original_exception}")
38
+ """
39
+
40
+ def __init__(self, message: str, original_exception: BaseException | None = None):
41
+ self.message = message
42
+ self.original_exception = original_exception
43
+ super().__init__(self._format_message())
44
+
45
+ def _format_message(self) -> str:
46
+ """Format the error message, including original exception if present."""
47
+ if self.original_exception:
48
+ return f"{self.message}: {self.original_exception}"
49
+ return self.message
50
+
51
+
52
class ConfigurationError(RagitError):
    """Loading or validating configuration failed.

    Typical causes:
    - an environment variable holds an invalid value
    - a required setting is missing
    - a URL is malformed
    """
62
+
63
+
64
class ProviderError(RagitError):
    """Talking to a provider, or an operation on it, failed.

    Typical causes:
    - the network connection to the provider fails
    - the provider answers with an error response
    - the provider times out
    """
74
+
75
+
76
class IndexingError(RagitError):
    """Indexing or embedding documents failed.

    Typical causes:
    - embedding generation fails
    - document chunking fails
    - index building fails
    """
86
+
87
+
88
class RetrievalError(RagitError):
    """A retrieval operation failed.

    Typical causes:
    - embedding the query fails
    - the search operation fails
    - no results can be retrieved
    """
98
+
99
+
100
class GenerationError(RagitError):
    """Generating text with the LLM failed.

    Typical causes:
    - the LLM call fails
    - the response cannot be parsed
    - the context exceeds the model's limits
    """
110
+
111
+
112
class EvaluationError(RagitError):
    """Evaluation or scoring failed.

    Typical causes:
    - a metric cannot be calculated
    - benchmark validation fails
    - a score cannot be extracted
    """
122
+
123
+
124
class ExceptionAggregator:
    """Collect and report exceptions during batch operations.

    Useful for operations that should continue even when some
    items fail, then report all failures at the end.

    Pattern from ai4rag exception_handler.py.

    Examples
    --------
    >>> aggregator = ExceptionAggregator()
    >>> for doc in documents:
    ...     try:
    ...         process(doc)
    ...     except Exception as e:
    ...         aggregator.record(f"doc:{doc.id}", e)
    >>> if aggregator.has_errors:
    ...     print(aggregator.get_summary())
    """

    def __init__(self) -> None:
        # (context, exception) pairs in the order they were recorded.
        self._exceptions: list[tuple[str, Exception]] = []

    def record(self, context: str, exception: Exception) -> None:
        """Record an exception with context.

        Parameters
        ----------
        context : str
            Description of where/why the exception occurred.
        exception : Exception
            The exception that was raised.
        """
        self._exceptions.append((context, exception))

    @property
    def has_errors(self) -> bool:
        """Check if any errors have been recorded."""
        return bool(self._exceptions)

    @property
    def error_count(self) -> int:
        """Get the number of recorded errors."""
        return len(self._exceptions)

    @property
    def exceptions(self) -> list[tuple[str, Exception]]:
        """Get a copy of all recorded exceptions with their contexts."""
        return list(self._exceptions)

    def get_by_type(self, exc_type: type[Exception]) -> list[tuple[str, Exception]]:
        """Get exceptions of a specific type (subclasses included).

        Parameters
        ----------
        exc_type : type
            The exception type to filter by.

        Returns
        -------
        list[tuple[str, Exception]]
            Exceptions matching the type with their contexts.
        """
        return [(ctx, exc) for ctx, exc in self._exceptions if isinstance(exc, exc_type)]

    def get_summary(self) -> str:
        """Get a one-line summary of all recorded errors.

        Returns
        -------
        str
            Total count, per-type counts ordered by frequency, and the most
            common type; "No errors recorded" when nothing was recorded.
        """
        if not self._exceptions:
            return "No errors recorded"

        # Count occurrences per exception class name.
        counts: dict[str, int] = {}
        for _, exc in self._exceptions:
            name = type(exc).__name__
            counts[name] = counts.get(name, 0) + 1

        # sorted() is stable, so ties keep first-recorded order and the first
        # entry is the most common type.
        ranked = sorted(counts.items(), key=lambda item: -item[1])
        top_name, top_count = ranked[0]
        type_summary = ", ".join(f"{name}:{count}" for name, count in ranked)
        noun = "error" if self.error_count == 1 else "errors"

        return f"{self.error_count} {noun} ({type_summary}). Most common: {top_name} ({top_count}x)"

    def get_details(self) -> str:
        """Get detailed information about all errors.

        Returns
        -------
        str
            One numbered line per recorded error, with its context.
        """
        if not self._exceptions:
            return "No errors recorded"

        lines = [f"Total errors: {self.error_count}", ""]
        lines.extend(
            f"{i}. [{context}] {type(exc).__name__}: {exc}"
            for i, (context, exc) in enumerate(self._exceptions, 1)
        )
        return "\n".join(lines)

    def raise_if_errors(self, message: str = "Operation failed") -> None:
        """Raise RagitError if any errors were recorded.

        The first recorded exception is attached as the cause so its
        traceback is not lost.

        Parameters
        ----------
        message : str
            Base message for the raised error.

        Raises
        ------
        RagitError
            If any errors were recorded.
        """
        if self.has_errors:
            first_exception = self._exceptions[0][1]
            raise RagitError(f"{message}: {self.get_summary()}") from first_exception

    def clear(self) -> None:
        """Clear all recorded exceptions."""
        self._exceptions.clear()

    def merge_from(self, other: "ExceptionAggregator") -> None:
        """Append the exceptions recorded by another aggregator.

        Parameters
        ----------
        other : ExceptionAggregator
            Another aggregator to merge from; it is left unchanged.
        """
        self._exceptions.extend(other._exceptions)

    def to_dict(self) -> dict[str, Any]:
        """Export as dictionary for JSON serialization.

        Returns
        -------
        dict
            ``error_count`` plus an ``errors`` list holding the context,
            type name, and message of each recorded exception.
        """
        return {
            "error_count": self.error_count,
            "errors": [
                {"context": ctx, "type": type(exc).__name__, "message": str(exc)} for ctx, exc in self._exceptions
            ],
        }