sdg-hub 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139)
  1. sdg_hub/__init__.py +28 -1
  2. sdg_hub/_version.py +2 -2
  3. sdg_hub/core/__init__.py +22 -0
  4. sdg_hub/core/blocks/__init__.py +58 -0
  5. sdg_hub/core/blocks/base.py +313 -0
  6. sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
  7. sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
  8. sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
  9. sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
  10. sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
  11. sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
  12. sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
  13. sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
  14. sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
  15. sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
  16. sdg_hub/core/blocks/evaluation/__init__.py +9 -0
  17. sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
  18. sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
  19. sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
  20. sdg_hub/core/blocks/filtering/__init__.py +12 -0
  21. sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
  22. sdg_hub/core/blocks/llm/__init__.py +25 -0
  23. sdg_hub/core/blocks/llm/client_manager.py +398 -0
  24. sdg_hub/core/blocks/llm/config.py +336 -0
  25. sdg_hub/core/blocks/llm/error_handler.py +368 -0
  26. sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
  27. sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
  28. sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
  29. sdg_hub/core/blocks/registry.py +331 -0
  30. sdg_hub/core/blocks/transform/__init__.py +23 -0
  31. sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
  32. sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
  33. sdg_hub/core/blocks/transform/melt_columns.py +126 -0
  34. sdg_hub/core/blocks/transform/rename_columns.py +69 -0
  35. sdg_hub/core/blocks/transform/text_concat.py +102 -0
  36. sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
  37. sdg_hub/core/flow/__init__.py +20 -0
  38. sdg_hub/core/flow/base.py +980 -0
  39. sdg_hub/core/flow/metadata.py +344 -0
  40. sdg_hub/core/flow/migration.py +187 -0
  41. sdg_hub/core/flow/registry.py +330 -0
  42. sdg_hub/core/flow/validation.py +265 -0
  43. sdg_hub/{utils → core/utils}/__init__.py +6 -4
  44. sdg_hub/{utils → core/utils}/datautils.py +1 -3
  45. sdg_hub/core/utils/error_handling.py +208 -0
  46. sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
  47. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
  48. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
  49. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
  50. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
  51. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
  52. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
  53. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
  54. sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
  55. sdg_hub-0.2.0.dist-info/METADATA +218 -0
  56. sdg_hub-0.2.0.dist-info/RECORD +63 -0
  57. sdg_hub/blocks/__init__.py +0 -42
  58. sdg_hub/blocks/block.py +0 -96
  59. sdg_hub/blocks/llmblock.py +0 -375
  60. sdg_hub/blocks/openaichatblock.py +0 -556
  61. sdg_hub/blocks/utilblocks.py +0 -597
  62. sdg_hub/checkpointer.py +0 -139
  63. sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
  64. sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
  65. sdg_hub/configs/annotations/detailed_description.yaml +0 -10
  66. sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
  67. sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
  68. sdg_hub/configs/knowledge/__init__.py +0 -0
  69. sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
  70. sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
  71. sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
  72. sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
  73. sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
  74. sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
  75. sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
  76. sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
  77. sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
  78. sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
  79. sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
  80. sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
  81. sdg_hub/configs/knowledge/router.yaml +0 -12
  82. sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
  83. sdg_hub/configs/reasoning/__init__.py +0 -0
  84. sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
  85. sdg_hub/configs/skills/__init__.py +0 -0
  86. sdg_hub/configs/skills/analyzer.yaml +0 -48
  87. sdg_hub/configs/skills/annotation.yaml +0 -36
  88. sdg_hub/configs/skills/contexts.yaml +0 -28
  89. sdg_hub/configs/skills/critic.yaml +0 -60
  90. sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
  91. sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
  92. sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
  93. sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
  94. sdg_hub/configs/skills/freeform_questions.yaml +0 -34
  95. sdg_hub/configs/skills/freeform_responses.yaml +0 -39
  96. sdg_hub/configs/skills/grounded_questions.yaml +0 -38
  97. sdg_hub/configs/skills/grounded_responses.yaml +0 -59
  98. sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
  99. sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
  100. sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
  101. sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
  102. sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
  103. sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
  104. sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
  105. sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
  106. sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
  107. sdg_hub/configs/skills/judge.yaml +0 -53
  108. sdg_hub/configs/skills/planner.yaml +0 -67
  109. sdg_hub/configs/skills/respond.yaml +0 -8
  110. sdg_hub/configs/skills/revised_responder.yaml +0 -78
  111. sdg_hub/configs/skills/router.yaml +0 -59
  112. sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
  113. sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
  114. sdg_hub/flow.py +0 -477
  115. sdg_hub/flow_runner.py +0 -450
  116. sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
  117. sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
  118. sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
  119. sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
  120. sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
  121. sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
  122. sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
  123. sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
  124. sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
  125. sdg_hub/pipeline.py +0 -121
  126. sdg_hub/prompts.py +0 -80
  127. sdg_hub/registry.py +0 -122
  128. sdg_hub/sdg.py +0 -206
  129. sdg_hub/utils/config_validation.py +0 -91
  130. sdg_hub/utils/error_handling.py +0 -94
  131. sdg_hub/utils/validation_result.py +0 -10
  132. sdg_hub-0.1.4.dist-info/METADATA +0 -190
  133. sdg_hub-0.1.4.dist-info/RECORD +0 -89
  134. sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
  135. /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
  136. /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
  137. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
  138. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
  139. {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,336 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Configuration system for LLM blocks supporting all providers via LiteLLM."""
3
+
4
+ # Standard
5
+ from dataclasses import dataclass
6
+ from typing import Any, Optional, Union
7
+ import os
8
+
9
+
10
@dataclass
class LLMConfig:
    """Configuration for LLM blocks supporting all providers via LiteLLM.

    The configuration is validated on construction (see ``__post_init__``):
    the model identifier must contain a ``/`` and the numeric generation
    parameters must fall inside the ranges listed below.  If no ``api_key``
    is given, one is looked up from a provider-specific environment variable.

    Parameters
    ----------
    model : Optional[str], optional
        Model identifier in LiteLLM ``provider/model-name`` format.  May be
        ``None`` initially and set later.  For example ``"openai/gpt-4"``,
        ``"anthropic/claude-3-sonnet-20240229"``,
        ``"hosted_vllm/meta-llama/Llama-2-7b-chat-hf"`` or ``"ollama/llama2"``.
    api_key : Optional[str], optional
        API key for the provider.  When omitted and ``model`` is set, the key
        is read from the environment variable registered for the provider in
        ``_PROVIDER_ENV_VARS`` (e.g. ``OPENAI_API_KEY``, ``ANTHROPIC_API_KEY``,
        ``GOOGLE_API_KEY``).
    api_base : Optional[str], optional
        Base URL for the API.  Required for local models, e.g.
        ``"http://localhost:8000/v1"`` for local vLLM or
        ``"http://localhost:11434"`` for Ollama.
    timeout : float, optional
        Request timeout in seconds, by default 120.0.  Must be positive.
    max_retries : int, optional
        Maximum number of retry attempts, by default 6.  Must be >= 0.
    temperature : Optional[float], optional
        Sampling temperature (0.0 to 2.0), by default None.
    max_tokens : Optional[int], optional
        Maximum tokens to generate (must be positive), by default None.
    top_p : Optional[float], optional
        Nucleus sampling parameter (0.0 to 1.0), by default None.
    frequency_penalty : Optional[float], optional
        Frequency penalty (-2.0 to 2.0), by default None.
    presence_penalty : Optional[float], optional
        Presence penalty (-2.0 to 2.0), by default None.
    stop : Optional[Union[str, list[str]]], optional
        Stop sequences, by default None.
    seed : Optional[int], optional
        Random seed for reproducible outputs, by default None.
    response_format : Optional[dict[str, Any]], optional
        Response format specification (e.g., JSON mode), by default None.
    stream : Optional[bool], optional
        Whether to stream responses, by default None.
    n : Optional[int], optional
        Number of completions to generate (must be positive), by default None.
    logprobs : Optional[bool], optional
        Whether to return log probabilities, by default None.
    top_logprobs : Optional[int], optional
        Number of top log probabilities to return, by default None.
    user : Optional[str], optional
        End-user identifier, by default None.
    extra_headers : Optional[dict[str, str]], optional
        Additional headers to send with requests, by default None.
    extra_body : Optional[dict[str, Any]], optional
        Additional parameters for the request body, by default None.
    provider_specific : Optional[dict[str, Any]], optional
        Provider-specific parameters that don't map to standard OpenAI
        params, by default None.  These are merged last into the generation
        kwargs and therefore override same-named standard parameters.

    Raises
    ------
    ValueError
        If ``model`` lacks a ``/`` or a numeric parameter is out of range.
    """

    model: Optional[str] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    timeout: float = 120.0
    max_retries: int = 6

    # Generation parameters (OpenAI-compatible)
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list[str]]] = None
    seed: Optional[int] = None
    response_format: Optional[dict[str, Any]] = None
    stream: Optional[bool] = None
    n: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    user: Optional[str] = None

    # Additional parameters
    extra_headers: Optional[dict[str, str]] = None
    extra_body: Optional[dict[str, Any]] = None
    provider_specific: Optional[dict[str, Any]] = None

    # The constants below carry no annotation on purpose: ``@dataclass`` only
    # turns *annotated* class attributes into fields, so these stay plain
    # class-level lookup tables and do not change the constructor signature.

    # Lower-cased provider prefix -> environment variable consulted for the key.
    # NOTE(review): the "bedrock" and "vertex_ai" entries point at variables
    # that do not look like plain API keys - confirm this is intended.
    _PROVIDER_ENV_VARS = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
        "azure": "AZURE_API_KEY",
        "huggingface": "HUGGINGFACE_API_KEY",
        "cohere": "COHERE_API_KEY",
        "replicate": "REPLICATE_API_KEY",
        "together": "TOGETHER_API_KEY",
        "anyscale": "ANYSCALE_API_KEY",
        "perplexity": "PERPLEXITY_API_KEY",
        "groq": "GROQ_API_KEY",
        "mistral": "MISTRAL_API_KEY",
        "deepinfra": "DEEPINFRA_API_KEY",
        "ai21": "AI21_API_KEY",
        "nlp_cloud": "NLP_CLOUD_API_KEY",
        "aleph_alpha": "ALEPH_ALPHA_API_KEY",
        "bedrock": "AWS_ACCESS_KEY_ID",
        "vertex_ai": "GOOGLE_APPLICATION_CREDENTIALS",
    }

    # Fields forwarded verbatim by get_generation_kwargs() when not None.
    _GENERATION_PARAMS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
        "seed",
        "response_format",
        "stream",
        "n",
        "logprobs",
        "top_logprobs",
        "user",
    )

    # Provider prefixes that is_local_model() treats as local deployments.
    _LOCAL_PROVIDERS = frozenset({"hosted_vllm", "ollama", "local", "vllm"})

    def __post_init__(self) -> None:
        """Validate the configuration, then resolve the API key from the environment."""
        self._validate_model()
        self._validate_parameters()
        self._resolve_api_key()

    def _validate_model(self) -> None:
        """Check that ``model`` (when set) contains a ``/`` separator.

        Raises
        ------
        ValueError
            If ``model`` is set but has no ``/`` in it.
        """
        # Model is optional - it may be filled in after construction.
        if self.model is None:
            return

        if "/" not in self.model:
            raise ValueError(
                f"Model '{self.model}' should be in format 'provider/model-name'. "
                f"Examples: 'openai/gpt-4', 'anthropic/claude-3-sonnet-20240229', "
                f"'hosted_vllm/meta-llama/Llama-2-7b-chat-hf'"
            )

    def _validate_parameters(self) -> None:
        """Range-check the numeric parameters.

        Optional parameters are only checked when they are not ``None``.

        Raises
        ------
        ValueError
            On the first parameter found outside its allowed range.
        """
        if self.temperature is not None and not (0.0 <= self.temperature <= 2.0):
            raise ValueError(
                f"Temperature must be between 0.0 and 2.0, got {self.temperature}"
            )

        if self.max_tokens is not None and self.max_tokens <= 0:
            raise ValueError(f"max_tokens must be positive, got {self.max_tokens}")

        if self.top_p is not None and not (0.0 <= self.top_p <= 1.0):
            raise ValueError(f"top_p must be between 0.0 and 1.0, got {self.top_p}")

        if self.frequency_penalty is not None and not (
            -2.0 <= self.frequency_penalty <= 2.0
        ):
            raise ValueError(
                f"frequency_penalty must be between -2.0 and 2.0, got {self.frequency_penalty}"
            )

        if self.presence_penalty is not None and not (
            -2.0 <= self.presence_penalty <= 2.0
        ):
            raise ValueError(
                f"presence_penalty must be between -2.0 and 2.0, got {self.presence_penalty}"
            )

        if self.n is not None and self.n <= 0:
            raise ValueError(f"n must be positive, got {self.n}")

        if self.max_retries < 0:
            raise ValueError(
                f"max_retries must be non-negative, got {self.max_retries}"
            )

        if self.timeout <= 0:
            raise ValueError(f"timeout must be positive, got {self.timeout}")

    def _resolve_api_key(self) -> None:
        """Fill in ``api_key`` from the provider's environment variable.

        Nothing happens when a key was passed explicitly or when ``model`` is
        not set.  The environment is only read, never written.  The value that
        was taken from the environment is remembered so that
        ``merge_overrides`` can tell it apart from an explicitly supplied key.
        """
        # Plain instance attribute, not a dataclass field: it is invisible to
        # fields(), __eq__ and the constructor.
        self._env_resolved_api_key = None

        if self.api_key is not None or self.model is None:
            return

        provider = self.model.split("/")[0].lower()
        env_var = self._PROVIDER_ENV_VARS.get(provider)
        if env_var:
            self.api_key = os.getenv(env_var)
            self._env_resolved_api_key = self.api_key

    def get_generation_kwargs(self) -> dict[str, Any]:
        """Get generation parameters as kwargs for LiteLLM completion.

        Returns
        -------
        dict[str, Any]
            Every standard generation parameter that is not ``None``, plus
            ``extra_headers`` / ``extra_body`` when non-empty.  Entries of
            ``provider_specific`` are merged in last and win on key clashes.
        """
        kwargs: dict[str, Any] = {
            name: getattr(self, name)
            for name in self._GENERATION_PARAMS
            if getattr(self, name) is not None
        }

        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers

        if self.extra_body:
            kwargs["extra_body"] = self.extra_body

        if self.provider_specific:
            kwargs.update(self.provider_specific)

        return kwargs

    def merge_overrides(self, **overrides: Any) -> "LLMConfig":
        """Create a new config with runtime overrides.

        The new instance goes through the same validation as a freshly
        constructed one; ``self`` is left untouched.

        If ``model`` is overridden while the current ``api_key`` was taken
        from the environment (and ``api_key`` itself is not overridden), the
        key is resolved again for the new model instead of being copied, so a
        key belonging to one provider is not handed to another.

        Parameters
        ----------
        **overrides : Any
            Runtime parameter overrides, keyed by field name.

        Returns
        -------
        LLMConfig
            New configuration with overrides applied.

        Raises
        ------
        TypeError
            If an override does not name a configuration field.
        ValueError
            If the resulting configuration fails validation.
        """
        # Standard
        from dataclasses import replace

        changes = dict(overrides)

        env_key = getattr(self, "_env_resolved_api_key", None)
        if (
            "model" in changes
            and "api_key" not in changes
            and env_key is not None
            # Guard against the key having been reassigned after construction.
            and self.api_key == env_key
        ):
            # Let __post_init__ of the new instance look the key up afresh.
            changes["api_key"] = None

        return replace(self, **changes)

    def get_provider(self) -> Optional[str]:
        """Get the provider name from the model identifier.

        Returns
        -------
        Optional[str]
            The text before the first ``/`` (e.g., "openai", "anthropic",
            "hosted_vllm"), or None if model is not set.
        """
        if self.model is None:
            return None
        return self.model.split("/")[0]

    def get_model_name(self) -> Optional[str]:
        """Get the model name without provider prefix.

        Returns
        -------
        Optional[str]
            Everything after the first ``/`` (e.g., "gpt-4",
            "meta-llama/Llama-2-7b-chat-hf"), the whole string if it has no
            ``/``, or None if model is not set.
        """
        if self.model is None:
            return None
        _, separator, name = self.model.partition("/")
        return name if separator else self.model

    def is_local_model(self) -> bool:
        """Check if this is a local model deployment.

        Returns
        -------
        bool
            True if the lower-cased provider is one of ``_LOCAL_PROVIDERS``
            (vLLM, Ollama, etc.); False otherwise or when model is not set.
        """
        provider = self.get_provider()
        if provider is None:
            return False
        return provider.lower() in self._LOCAL_PROVIDERS

    def __str__(self) -> str:
        """Short representation: model and provider only."""
        return f"LLMConfig(model='{self.model}', provider='{self.get_provider()}')"

    def __repr__(self) -> str:
        """Detailed representation; does not include ``api_key``."""
        return (
            f"LLMConfig(model='{self.model}', provider='{self.get_provider()}', "
            f"api_base='{self.api_base}', timeout={self.timeout}, "
            f"max_retries={self.max_retries})"
        )
@@ -0,0 +1,368 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Error handling system for LLM blocks supporting multiple providers."""
3
+
4
+ # Standard
5
+ from enum import Enum
6
+ from typing import Any, Optional
7
+
8
+ # Third Party
9
+ from litellm import (
10
+ APIConnectionError,
11
+ AuthenticationError,
12
+ BadRequestError,
13
+ ContentPolicyViolationError,
14
+ ContextWindowExceededError,
15
+ InternalServerError,
16
+ InvalidRequestError,
17
+ NotFoundError,
18
+ RateLimitError,
19
+ ServiceUnavailableError,
20
+ UnprocessableEntityError,
21
+ )
22
+ from tenacity import (
23
+ retry,
24
+ retry_if_exception_type,
25
+ stop_after_attempt,
26
+ wait_exponential,
27
+ )
28
+
29
+ # Local
30
+ from ...utils.logger_config import setup_logger
31
+
32
+ logger = setup_logger(__name__)
33
+
34
+
35
class ErrorCategory(Enum):
    """Categories of errors for different retry strategies.

    Members are grouped by name prefix: ``RETRYABLE_*``, ``NON_RETRYABLE_*``
    and the catch-all ``UNKNOWN``.  Each member's value is a short string tag.
    """

    # --- retryable ---
    RETRYABLE_RATE_LIMIT = "rate_limit"
    RETRYABLE_TIMEOUT = "timeout"
    RETRYABLE_CONNECTION = "connection"
    RETRYABLE_SERVER = "server_error"
    RETRYABLE_CONTENT_FILTER = "content_filter"

    # --- non-retryable ---
    NON_RETRYABLE_AUTH = "auth_error"
    NON_RETRYABLE_PERMISSION = "permission"
    NON_RETRYABLE_BAD_REQUEST = "bad_request"
    NON_RETRYABLE_NOT_FOUND = "not_found"
    NON_RETRYABLE_CONTEXT_LENGTH = "context_length"

    # --- anything that fits none of the above ---
    UNKNOWN = "unknown"
51
+
52
+
53
class LLMErrorHandler:
    """Centralized error handling for LLM operations across all providers.

    This class handles errors from multiple LLM providers through LiteLLM,
    which maps provider-specific errors to OpenAI-compatible exceptions.
    Exceptions are classified into an ``ErrorCategory``; the category decides
    whether a call is retried and how long to wait before the next attempt.

    Parameters
    ----------
    max_retries : int, optional
        Upper bound on the number of attempts, by default 6.  The value is
        passed to ``stop_after_attempt`` and compared against the 1-based
        attempt number, so it counts the initial call as well.
    base_delay : float, optional
        Base delay between retries in seconds, by default 1.0
    max_delay : float, optional
        Maximum delay between retries in seconds, by default 60.0
    exponential_base : float, optional
        Base for exponential backoff, by default 2.0
    """

    def __init__(
        self,
        max_retries: int = 6,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
    ) -> None:
        # Third Party
        # Imported locally because the module-level litellm import list does
        # not include it; needed so RETRYABLE_TIMEOUT is actually reachable.
        from litellm import Timeout

        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base

        # Exception type -> category.  classify_error() walks the MRO of the
        # raised exception, so subclasses of these types are covered and the
        # most specific registered type wins.
        self.error_mappings = {
            # Rate limiting errors
            RateLimitError: ErrorCategory.RETRYABLE_RATE_LIMIT,
            # Timeout errors
            Timeout: ErrorCategory.RETRYABLE_TIMEOUT,
            # Connection errors
            APIConnectionError: ErrorCategory.RETRYABLE_CONNECTION,
            # Server errors (5xx)
            InternalServerError: ErrorCategory.RETRYABLE_SERVER,
            ServiceUnavailableError: ErrorCategory.RETRYABLE_SERVER,
            # Content filter errors (might be retryable with different input)
            ContentPolicyViolationError: ErrorCategory.RETRYABLE_CONTENT_FILTER,
            # Authentication errors (non-retryable)
            AuthenticationError: ErrorCategory.NON_RETRYABLE_AUTH,
            # Bad request errors (non-retryable)
            BadRequestError: ErrorCategory.NON_RETRYABLE_BAD_REQUEST,
            InvalidRequestError: ErrorCategory.NON_RETRYABLE_BAD_REQUEST,
            UnprocessableEntityError: ErrorCategory.NON_RETRYABLE_BAD_REQUEST,
            # Not found errors (non-retryable)
            NotFoundError: ErrorCategory.NON_RETRYABLE_NOT_FOUND,
            # Context length errors (non-retryable)
            ContextWindowExceededError: ErrorCategory.NON_RETRYABLE_CONTEXT_LENGTH,
        }

        # Categories for which should_retry() may return True.
        self.retryable_errors = {
            ErrorCategory.RETRYABLE_RATE_LIMIT,
            ErrorCategory.RETRYABLE_TIMEOUT,
            ErrorCategory.RETRYABLE_CONNECTION,
            ErrorCategory.RETRYABLE_SERVER,
            ErrorCategory.RETRYABLE_CONTENT_FILTER,
        }

    def classify_error(self, error: Exception) -> "ErrorCategory":
        """Classify an error into a category for retry logic.

        The exception's class hierarchy is searched from most to least
        specific, so an exception whose own type is not registered is still
        classified through its nearest registered base class.

        Parameters
        ----------
        error : Exception
            The error to classify.

        Returns
        -------
        ErrorCategory
            The category of the error, or ``ErrorCategory.UNKNOWN`` when no
            class in the hierarchy is registered.
        """
        for exc_type in type(error).__mro__:
            category = self.error_mappings.get(exc_type)
            if category is not None:
                return category
        return ErrorCategory.UNKNOWN

    def should_retry(self, error: Exception, attempt: int) -> bool:
        """Determine if an error should be retried.

        Parameters
        ----------
        error : Exception
            The error that occurred.
        attempt : int
            The current attempt number (1-based).

        Returns
        -------
        bool
            True if ``attempt`` is below ``max_retries`` and the error's
            category is retryable.
        """
        if attempt >= self.max_retries:
            return False

        return self.classify_error(error) in self.retryable_errors

    def calculate_delay(self, error: Exception, attempt: int) -> float:
        """Calculate the delay before the next retry.

        Parameters
        ----------
        error : Exception
            The error that occurred.
        attempt : int
            The current attempt number (1-based).

        Returns
        -------
        float
            Delay in seconds before the next retry.
        """
        category = self.classify_error(error)

        if category == ErrorCategory.RETRYABLE_RATE_LIMIT:
            # Longer delays for rate limiting: one extra backoff step, doubled,
            # and allowed to grow to twice the usual cap.
            return min(
                self.base_delay * (self.exponential_base**attempt) * 2,
                self.max_delay * 2,
            )

        # Standard exponential backoff for everything else...
        backoff = self.base_delay * (self.exponential_base ** (attempt - 1))

        if category == ErrorCategory.RETRYABLE_TIMEOUT:
            # ...but timeouts are capped at half the usual maximum.
            return min(backoff, self.max_delay * 0.5)

        return min(backoff, self.max_delay)

    def log_error_context(
        self, error: Exception, context: dict[str, Any], attempt: int = 1
    ) -> None:
        """Log error with context information.

        Emits a warning when the error will be retried and an error record
        otherwise.  The structured details are attached via ``extra``.

        Parameters
        ----------
        error : Exception
            The error that occurred.
        context : dict[str, Any]
            Context information about the error.  Its entries are merged
            last and therefore override the computed fields on key clashes.
        attempt : int, optional
            The current attempt number, by default 1.
        """
        category = self.classify_error(error)
        retryable = category in self.retryable_errors

        # NOTE(review): these keys are passed to the logger as ``extra``;
        # verify callers never put logging-reserved names into ``context``.
        log_data = {
            "error_type": type(error).__name__,
            "error_category": category.value,
            "error_message": str(error),
            "attempt": attempt,
            "max_retries": self.max_retries,
            "retryable": retryable,
            **context,
        }

        if retryable and attempt < self.max_retries:
            delay = self.calculate_delay(error, attempt)
            log_data["retry_delay"] = delay
            logger.warning(
                f"Retryable error occurred (attempt {attempt}/{self.max_retries}). "
                f"Retrying in {delay:.1f}s: {error}",
                extra=log_data,
            )
        else:
            logger.error(
                f"Non-retryable error or max retries exceeded: {error}", extra=log_data
            )

    def create_retry_decorator(self, context: Optional[dict[str, Any]] = None):
        """Create a retry decorator for LLM operations.

        The decorator logs every failure through ``log_error_context``,
        retries according to ``should_retry``, waits according to
        ``calculate_delay`` and re-raises the last exception when it gives up.

        Parameters
        ----------
        context : Optional[dict[str, Any]], optional
            Context information for logging, by default None.

        Returns
        -------
        Callable
            A retry decorator configured for LLM operations.
        """
        context = context or {}

        def retry_condition(retry_state):
            """Custom retry condition that logs errors."""
            if not retry_state.outcome.failed:
                return False
            error = retry_state.outcome.exception()
            self.log_error_context(error, context, retry_state.attempt_number)
            return self.should_retry(error, retry_state.attempt_number)

        def wait_strategy(retry_state):
            """Custom wait strategy based on error type."""
            if not retry_state.outcome.failed:
                return 0
            error = retry_state.outcome.exception()
            return self.calculate_delay(error, retry_state.attempt_number)

        return retry(
            retry=retry_condition,
            wait=wait_strategy,
            stop=stop_after_attempt(self.max_retries),
            reraise=True,
        )

    def create_simple_retry_decorator(self):
        """Create a simple retry decorator using tenacity's built-in strategies.

        This is a simpler alternative when you don't need custom error handling logic.
        The retried exception types are those registered in ``error_mappings``
        under a retryable category, so this decorator and ``classify_error``
        cannot drift apart.

        Returns
        -------
        Callable
            A simple retry decorator for LLM operations.
        """
        retryable_exceptions = tuple(
            exc_type
            for exc_type, category in self.error_mappings.items()
            if category in self.retryable_errors
        )

        return retry(
            retry=retry_if_exception_type(retryable_exceptions),
            wait=wait_exponential(
                multiplier=self.base_delay, min=self.base_delay, max=self.max_delay
            ),
            stop=stop_after_attempt(self.max_retries),
            reraise=True,
        )

    def wrap_completion(
        self, completion_func, context: Optional[dict[str, Any]] = None
    ):
        """Wrap a completion function with error handling and retry logic.

        Parameters
        ----------
        completion_func : Callable
            The completion function to wrap.
        context : Optional[dict[str, Any]], optional
            Context information for logging, by default None.

        Returns
        -------
        Callable
            The wrapped completion function with retry logic.
        """
        return self.create_retry_decorator(context)(completion_func)

    def get_error_summary(self, error: Exception) -> dict[str, Any]:
        """Get a summary of error information.

        Parameters
        ----------
        error : Exception
            The error to summarize.

        Returns
        -------
        dict[str, Any]
            Keys ``error_type``, ``error_category``, ``error_message``,
            ``retryable`` and ``provider_error`` (True when the exception
            carries a non-None ``response`` attribute).
        """
        category = self.classify_error(error)

        return {
            "error_type": type(error).__name__,
            "error_category": category.value,
            "error_message": str(error),
            "retryable": category in self.retryable_errors,
            "provider_error": getattr(error, "response", None) is not None,
        }

    def format_error_message(
        self, error: Exception, context: Optional[dict[str, Any]] = None
    ) -> str:
        """Format an error message for user display.

        Parameters
        ----------
        error : Exception
            The error to format.
        context : Optional[dict[str, Any]], optional
            Additional context for the error, by default None.  Only the
            ``"model"`` entry is read, for not-found errors.

        Returns
        -------
        str
            The base failure message, followed by a hint line for
            authentication, context-length, rate-limit and not-found errors.
        """
        category = self.classify_error(error)
        context = context or {}

        base_msg = f"LLM operation failed: {error}"

        if category == ErrorCategory.NON_RETRYABLE_AUTH:
            return f"{base_msg}\nCheck your API key configuration."
        if category == ErrorCategory.NON_RETRYABLE_CONTEXT_LENGTH:
            return f"{base_msg}\nInput text is too long for the model."
        if category == ErrorCategory.RETRYABLE_RATE_LIMIT:
            return f"{base_msg}\nRate limit exceeded. Consider using a different model or reducing request frequency."
        if category == ErrorCategory.NON_RETRYABLE_NOT_FOUND:
            model = context.get("model", "unknown")
            return f"{base_msg}\nModel '{model}' not found. Check the model identifier."
        return base_msg