sdg-hub 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +25 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +310 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +980 -0
- sdg_hub/core/flow/metadata.py +344 -0
- sdg_hub/core/flow/migration.py +187 -0
- sdg_hub/core/flow/registry.py +330 -0
- sdg_hub/core/flow/validation.py +265 -0
- sdg_hub/{utils → core/utils}/__init__.py +6 -4
- sdg_hub/{utils → core/utils}/datautils.py +1 -3
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +191 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.0.dist-info/METADATA +218 -0
- sdg_hub-0.2.0.dist-info/RECORD +63 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,188 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Filter by value block for dataset filtering operations.
|
3
|
+
|
4
|
+
This module provides a block for filtering datasets based on column values
|
5
|
+
using various operations with optional data type conversion.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any, Optional, Union
|
10
|
+
import operator
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
from pydantic import Field, field_validator
|
15
|
+
|
16
|
+
# Local
|
17
|
+
from ...utils.logger_config import setup_logger
|
18
|
+
from ..base import BaseBlock
|
19
|
+
from ..registry import BlockRegistry
|
20
|
+
|
21
|
+
# Module-level logger; ColumnValueFilterBlock uses it to report dtype-conversion failures.
logger = setup_logger(__name__)
|
22
|
+
|
23
|
+
# Binary operations selectable by name in ColumnValueFilterBlock. Each entry is
# invoked as ``func(column_value, filter_value)``.
# The six rich comparisons share their names with the ``operator`` module.
OPERATION_MAP = {
    name: getattr(operator, name) for name in ("eq", "ne", "lt", "le", "gt", "ge")
}
# "contains": true when the column value contains the filter value.
OPERATION_MAP["contains"] = operator.contains
# "in": true when the column value is a member of the filter value
# (the argument roles are swapped relative to "contains").
OPERATION_MAP["in"] = lambda x, y: x in y

# Target types selectable by name for the optional filter-column conversion;
# each key is the type's own name ("float", "int").
DTYPE_MAP = {target.__name__: target for target in (float, int)}
|
40
|
+
|
41
|
+
|
42
|
+
@BlockRegistry.register(
    "ColumnValueFilterBlock",
    "filtering",
    "Filters datasets based on column values using various comparison operations",
)
class ColumnValueFilterBlock(BaseBlock):
    """A block for filtering datasets based on column values.

    Keeps the rows for which a binary operation (e.g., equals, contains)
    between the filter column's value and at least one configured filter value
    is true. The filter column can optionally be converted to a numeric type
    before comparison.

    Attributes
    ----------
    block_name : str
        Name of the block.
    input_cols : Union[str, list[str]]
        Input column name(s). The first column will be used for filtering.
    filter_value : Union[Any, list[Any]]
        The value(s) to filter by. A non-list value is treated as a
        one-element list; a row is kept if the operation is true for any value.
    operation : str
        Name of the binary operation, called as ``op(column_value, filter_value)``.
        Supported operations: "eq", "ne", "lt", "le", "gt", "ge", "contains", "in".
        "contains" keeps rows whose column value contains the filter value;
        "in" keeps rows whose column value is contained in the filter value.
    convert_dtype : Optional[str], optional
        Type to convert the filter column to before comparison: "float" or "int".
        Values that cannot be converted (including None) become None and their
        rows are dropped. If None, no conversion is performed.
    """

    filter_value: Union[Any, list[Any]] = Field(
        ..., description="The value(s) to filter by"
    )
    operation: str = Field(
        ...,
        description="String name of binary operator for comparison (e.g., 'eq', 'contains')",
    )
    convert_dtype: Optional[str] = Field(
        None,
        description="String name of type to convert filter column to ('float' or 'int')",
    )

    @field_validator("operation")
    @classmethod
    def validate_operation(cls, v: str) -> str:
        """Validate that operation is a supported operation string."""
        if v not in OPERATION_MAP:
            raise ValueError(
                f"Unsupported operation '{v}'. Supported operations: {list(OPERATION_MAP.keys())}"
            )
        return v

    @field_validator("convert_dtype")
    @classmethod
    def validate_convert_dtype(cls, v: Optional[str]) -> Optional[str]:
        """Validate that convert_dtype is None or a supported type string."""
        if v is not None and v not in DTYPE_MAP:
            raise ValueError(
                f"Unsupported dtype '{v}'. Supported dtypes: {list(DTYPE_MAP.keys())}"
            )
        return v

    @field_validator("input_cols", mode="after")
    @classmethod
    def validate_input_cols_not_empty(cls, v: Any) -> Any:
        """Validate that we have at least one input column."""
        if not v:
            raise ValueError(
                "ColumnValueFilterBlock requires at least one input column"
            )
        return v

    def model_post_init(self, __context: Any) -> None:
        """Initialize derived attributes after Pydantic validation."""
        if hasattr(super(), "model_post_init"):
            super().model_post_init(__context)

        # Filtering only removes rows; it never creates new columns.
        if self.output_cols is None:
            self.output_cols = []

        # Normalize filter_value to a list so generate() can test "any value matches".
        self.value = (
            self.filter_value
            if isinstance(self.filter_value, list)
            else [self.filter_value]
        )
        # Only the first input column is used for filtering.
        self.column_name = self.input_cols[0]

        # Resolve the validated operation name to its callable.
        self._operation_func = OPERATION_MAP[self.operation]

        # Resolve the optional dtype name to a conversion callable (None = no conversion).
        self._convert_dtype_func = (
            DTYPE_MAP[self.convert_dtype] if self.convert_dtype else None
        )

    def _convert_dtype(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Convert the data type of the filter column.

        Values that cannot be converted are replaced with None so that
        ``generate`` drops those rows instead of aborting the whole map.

        Parameters
        ----------
        sample : dict[str, Any]
            The sample dictionary containing the column to convert.

        Returns
        -------
        dict[str, Any]
            The sample with converted column value (or None on failure).
        """
        try:
            sample[self.column_name] = self._convert_dtype_func(
                sample[self.column_name]
            )
        except (ValueError, TypeError, OverflowError) as e:
            # ValueError: unparsable strings (e.g. int("abc"));
            # TypeError: None or other non-numeric objects (e.g. int(None));
            # OverflowError: non-finite floats converted to int (int(float("inf"))).
            logger.error(
                "Error converting dtype: %s, filling with None to be filtered later", e
            )
            sample[self.column_name] = None
        return sample

    def generate(self, samples: Dataset, **_kwargs: Any) -> Dataset:
        """Generate filtered dataset based on specified conditions.

        When ``convert_dtype`` is set, the filter column is converted first and
        rows whose value is None afterwards are removed. A row is then kept if
        the operation is true for at least one filter value.

        Parameters
        ----------
        samples : Dataset
            The input dataset to filter.

        Returns
        -------
        Dataset
            The filtered dataset.
        """
        if self._convert_dtype_func:
            samples = samples.map(self._convert_dtype)

            # Drop rows whose value could not be converted (or was None already).
            samples = samples.filter(
                lambda x: x[self.column_name] is not None,
            )

        # Apply filter operation: keep the row if any filter value matches.
        samples = samples.filter(
            lambda x: any(
                self._operation_func(x[self.column_name], value) for value in self.value
            )
        )

        return samples
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""LLM blocks for provider-agnostic text generation.
|
3
|
+
|
4
|
+
This module provides blocks for interacting with language models through
|
5
|
+
LiteLLM, supporting 100+ providers including OpenAI, Anthropic, Google,
|
6
|
+
local models (vLLM, Ollama), and more.
|
7
|
+
"""
|
8
|
+
|
9
|
+
# Local
|
10
|
+
from .client_manager import LLMClientManager
|
11
|
+
from .config import LLMConfig
|
12
|
+
from .error_handler import ErrorCategory, LLMErrorHandler
|
13
|
+
from .llm_chat_block import LLMChatBlock
|
14
|
+
from .prompt_builder_block import PromptBuilderBlock
|
15
|
+
from .text_parser_block import TextParserBlock
|
16
|
+
|
17
|
+
# Public API re-exported by this package: the configuration, client and
# error-handling helpers, followed by the LLM-related blocks.
__all__ = [
    "LLMConfig",
    "LLMClientManager",
    "LLMErrorHandler",
    "ErrorCategory",
    "LLMChatBlock",
    "PromptBuilderBlock",
    "TextParserBlock",
]
|
@@ -0,0 +1,398 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Client manager for LLM operations supporting all providers via LiteLLM."""
|
3
|
+
|
4
|
+
# Standard
|
5
|
+
from typing import Any, Optional, Union
|
6
|
+
import asyncio
|
7
|
+
|
8
|
+
# Third Party
|
9
|
+
from litellm import acompletion, completion
|
10
|
+
import litellm
|
11
|
+
|
12
|
+
# Local
|
13
|
+
from ...utils.logger_config import setup_logger
|
14
|
+
from .config import LLMConfig
|
15
|
+
from .error_handler import LLMErrorHandler
|
16
|
+
|
17
|
+
# Module-level logger used by LLMClientManager for load/unload and per-call debug logs.
logger = setup_logger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class LLMClientManager:
    """Client manager for LLM operations using LiteLLM.

    This class provides a unified interface for calling any LLM provider
    supported by LiteLLM. Retry and error handling are delegated to an
    ``LLMErrorHandler``. The client is loaded lazily on the first completion
    call if ``load()`` was not called explicitly, and the class can be used as
    a context manager (``load()`` on enter, ``unload()`` on exit).

    Parameters
    ----------
    config : LLMConfig
        Configuration for the LLM client.
    error_handler : Optional[LLMErrorHandler], optional
        Custom error handler. If None, a default one is created with
        ``max_retries=config.max_retries``.
    """

    def __init__(
        self, config: LLMConfig, error_handler: Optional[LLMErrorHandler] = None
    ) -> None:
        self.config = config
        self.error_handler = error_handler or LLMErrorHandler(
            max_retries=config.max_retries
        )
        self._is_loaded = False

    def load(self) -> None:
        """Load and configure the LLM client.

        Applies LiteLLM configuration and runs setup validation. Repeated calls
        are no-ops until ``unload()`` resets the loaded flag.
        """
        if self._is_loaded:
            return

        # Configure LiteLLM
        self._configure_litellm()

        # Test the configuration
        self._validate_setup()

        self._is_loaded = True

        # Only log when model is actually configured
        if self.config.model:
            logger.info(
                f"Loaded LLM client for model '{self.config.model}'",
                extra={
                    "model": self.config.model,
                    "provider": self.config.get_provider(),
                    "is_local": self.config.is_local_model(),
                    "api_base": self.config.api_base,
                },
            )

    def unload(self) -> None:
        """Mark the client as unloaded.

        No connections are closed here; this only resets the loaded flag so
        that the next completion call runs ``load()`` again.
        """
        self._is_loaded = False
        try:
            logger.info(f"Unloaded LLM client for model '{self.config.model}'")
        except Exception:
            # Best-effort: ignore logging errors during cleanup (e.g. when called
            # from __exit__ during interpreter shutdown).
            pass

    def _configure_litellm(self) -> None:
        """Configure LiteLLM settings.

        ``litellm.request_timeout`` is a module-level LiteLLM attribute, so the
        value set here is global rather than scoped to this client.
        """
        litellm.request_timeout = self.config.timeout

        # API keys are passed directly in each completion call (see
        # _build_completion_kwargs) instead of modifying environment variables,
        # for thread-safety.

    def _validate_setup(self) -> None:
        """Validate that the LLM setup is usable.

        No request is made yet; problems surface on the first real completion
        call. Any error raised here is logged as a warning, not propagated.
        """
        try:
            # For testing/development, skip validation if using dummy API key
            if self.config.api_key == "test-key":
                logger.debug(
                    f"Skipping validation for model '{self.config.model}' (test mode)"
                )
                return

            # TODO: Skip validation for now to avoid API calls during initialization;
            # we might want to make a minimal test call.
            logger.debug(
                f"Setup configured for model '{self.config.model}'. "
                "Validation will occur on first actual call."
            )

        except Exception as e:
            logger.warning(
                f"Could not validate setup for model '{self.config.model}': {e}"
            )

    def _prepare_completion(
        self,
        messages: list[dict[str, Any]],
        overrides: dict[str, Any],
        call: Any,
    ) -> tuple[LLMConfig, dict[str, Any], Any]:
        """Shared setup for the sync and async completion paths.

        Loads the client if needed, merges runtime overrides into the config,
        builds the LiteLLM kwargs and wraps ``call`` with the error handler.

        Parameters
        ----------
        messages : list[dict[str, Any]]
            Messages in OpenAI format.
        overrides : dict[str, Any]
            Runtime parameter overrides.
        call : Any
            The raw LiteLLM call to wrap (sync or async bound method).

        Returns
        -------
        tuple[LLMConfig, dict[str, Any], Any]
            ``(final_config, kwargs, completion_func)``. ``completion_func`` is
            invoked with ``kwargs`` (and awaited on the async path).
        """
        if not self._is_loaded:
            self.load()

        # Merge configuration with overrides
        final_config = self.config.merge_overrides(**overrides)
        kwargs = self._build_completion_kwargs(messages, final_config)

        # Context passed to the error handler for logging/diagnostics
        context = {
            "model": final_config.model,
            "provider": final_config.get_provider(),
            "message_count": len(messages),
        }

        completion_func = self.error_handler.wrap_completion(call, context=context)
        return final_config, kwargs, completion_func

    @staticmethod
    def _extract_content(response: Any, n: Optional[int]) -> Union[str, list[str]]:
        """Extract message text from a LiteLLM response.

        Parameters
        ----------
        response : Any
            LiteLLM completion response.
        n : Optional[int]
            Requested number of choices; None is treated as 1.

        Returns
        -------
        Union[str, list[str]]
            A list with the content of every choice when ``n`` > 1, otherwise
            the content of the first choice.
        """
        if (n or 1) > 1:
            return [choice.message.content for choice in response.choices]
        return response.choices[0].message.content

    def create_completion(
        self, messages: list[dict[str, Any]], **overrides: Any
    ) -> Union[str, list[str]]:
        """Create a completion using LiteLLM.

        Parameters
        ----------
        messages : list[dict[str, Any]]
            Messages in OpenAI format.
        **overrides : Any
            Runtime parameter overrides.

        Returns
        -------
        Union[str, list[str]]
            The completion text(s). Returns a single string when n=1 or n is None,
            returns a list of strings when n>1.

        Raises
        ------
        Exception
            If the completion fails after all retries.
        """
        final_config, kwargs, completion_func = self._prepare_completion(
            messages, overrides, self._call_litellm_completion
        )
        response = completion_func(kwargs)
        return self._extract_content(response, final_config.n)

    async def acreate_completion(
        self, messages: list[dict[str, Any]], **overrides: Any
    ) -> Union[str, list[str]]:
        """Create an async completion using LiteLLM.

        Parameters
        ----------
        messages : list[dict[str, Any]]
            Messages in OpenAI format.
        **overrides : Any
            Runtime parameter overrides.

        Returns
        -------
        Union[str, list[str]]
            The completion text(s). Returns a single string when n=1 or n is None,
            returns a list of strings when n>1.

        Raises
        ------
        Exception
            If the completion fails after all retries.
        """
        final_config, kwargs, completion_func = self._prepare_completion(
            messages, overrides, self._call_litellm_acompletion
        )
        response = await completion_func(kwargs)
        return self._extract_content(response, final_config.n)

    def create_completions_batch(
        self, messages_list: list[list[dict[str, Any]]], **overrides: Any
    ) -> list[Union[str, list[str]]]:
        """Create multiple completions sequentially.

        Parameters
        ----------
        messages_list : list[list[dict[str, Any]]]
            List of message lists to process.
        **overrides : Any
            Runtime parameter overrides applied to every call.

        Returns
        -------
        list[Union[str, list[str]]]
            List of completion texts in input order. Each element is a single
            string when n=1 or n is None, or a list of strings when n>1.
        """
        results = []
        for messages in messages_list:
            result = self.create_completion(messages, **overrides)
            results.append(result)
        return results

    async def acreate_completions_batch(
        self, messages_list: list[list[dict[str, Any]]], **overrides: Any
    ) -> list[Union[str, list[str]]]:
        """Create multiple completions concurrently via ``asyncio.gather``.

        All requests are scheduled at once (no concurrency limit is applied here).

        Parameters
        ----------
        messages_list : list[list[dict[str, Any]]]
            List of message lists to process.
        **overrides : Any
            Runtime parameter overrides applied to every call.

        Returns
        -------
        list[Union[str, list[str]]]
            List of completion texts in input order. Each element is a single
            string when n=1 or n is None, or a list of strings when n>1.
        """
        tasks = [
            self.acreate_completion(messages, **overrides) for messages in messages_list
        ]
        return await asyncio.gather(*tasks)

    def _build_completion_kwargs(
        self, messages: list[dict[str, Any]], config: LLMConfig
    ) -> dict[str, Any]:
        """Build kwargs for LiteLLM completion call.

        Parameters
        ----------
        messages : list[dict[str, Any]]
            Messages in OpenAI format.
        config : LLMConfig
            Final configuration after merging overrides.

        Returns
        -------
        dict[str, Any]
            Kwargs for litellm.completion() / litellm.acompletion().
        """
        kwargs = {
            "model": config.model,
            "messages": messages,
        }

        # Add API configuration only when set, so LiteLLM falls back to its defaults.
        if config.api_key:
            kwargs["api_key"] = config.api_key

        if config.api_base:
            kwargs["api_base"] = config.api_base

        # Add generation parameters
        generation_kwargs = config.get_generation_kwargs()
        kwargs.update(generation_kwargs)

        return kwargs

    def _call_litellm_completion(self, kwargs: dict[str, Any]) -> Any:
        """Call LiteLLM completion, logging before and after the request.

        Parameters
        ----------
        kwargs : dict[str, Any]
            Arguments for litellm.completion().

        Returns
        -------
        Any
            LiteLLM completion response.
        """
        logger.debug(
            f"Calling LiteLLM completion for model '{kwargs['model']}'",
            extra={
                "model": kwargs["model"],
                "message_count": len(kwargs["messages"]),
                "generation_params": {
                    k: v
                    for k, v in kwargs.items()
                    if k in ["temperature", "max_tokens", "top_p", "n"]
                },
            },
        )

        response = completion(**kwargs)

        logger.debug(
            f"LiteLLM completion successful for model '{kwargs['model']}'",
            extra={
                "model": kwargs["model"],
                "choices_count": len(response.choices),
            },
        )

        return response

    async def _call_litellm_acompletion(self, kwargs: dict[str, Any]) -> Any:
        """Call LiteLLM async completion, logging before and after the request.

        Parameters
        ----------
        kwargs : dict[str, Any]
            Arguments for litellm.acompletion().

        Returns
        -------
        Any
            LiteLLM completion response.
        """
        logger.debug(
            f"Calling LiteLLM async completion for model '{kwargs['model']}'",
            extra={
                "model": kwargs["model"],
                "message_count": len(kwargs["messages"]),
            },
        )

        response = await acompletion(**kwargs)

        logger.debug(
            f"LiteLLM async completion successful for model '{kwargs['model']}'",
            extra={
                "model": kwargs["model"],
                "choices_count": len(response.choices),
            },
        )

        return response

    def get_model_info(self) -> dict[str, Any]:
        """Get information about the configured model.

        Returns
        -------
        dict[str, Any]
            Model identifier, provider, bare model name, locality, API base and
            whether the client is currently loaded.
        """
        return {
            "model": self.config.model,
            "provider": self.config.get_provider(),
            "model_name": self.config.get_model_name(),
            "is_local": self.config.is_local_model(),
            "api_base": self.config.api_base,
            "is_loaded": self._is_loaded,
        }

    def __enter__(self):
        """Context manager entry: load the client and return it."""
        self.load()
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb) -> None:
        """Context manager exit: unload the client; exceptions are not suppressed."""
        self.unload()

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"LLMClientManager(model='{self.config.model}', "
            f"provider='{self.config.get_provider()}', loaded={self._is_loaded})"
        )
|