sdg-hub 0.1.4__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sdg_hub/__init__.py +28 -1
- sdg_hub/_version.py +2 -2
- sdg_hub/core/__init__.py +22 -0
- sdg_hub/core/blocks/__init__.py +58 -0
- sdg_hub/core/blocks/base.py +313 -0
- sdg_hub/core/blocks/deprecated_blocks/__init__.py +29 -0
- sdg_hub/core/blocks/deprecated_blocks/combine_columns.py +93 -0
- sdg_hub/core/blocks/deprecated_blocks/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/filter_by_value.py +103 -0
- sdg_hub/core/blocks/deprecated_blocks/flatten_columns.py +94 -0
- sdg_hub/core/blocks/deprecated_blocks/llmblock.py +479 -0
- sdg_hub/core/blocks/deprecated_blocks/rename_columns.py +88 -0
- sdg_hub/core/blocks/deprecated_blocks/sample_populator.py +58 -0
- sdg_hub/core/blocks/deprecated_blocks/selector.py +97 -0
- sdg_hub/core/blocks/deprecated_blocks/set_to_majority_value.py +88 -0
- sdg_hub/core/blocks/evaluation/__init__.py +9 -0
- sdg_hub/core/blocks/evaluation/evaluate_faithfulness_block.py +564 -0
- sdg_hub/core/blocks/evaluation/evaluate_relevancy_block.py +564 -0
- sdg_hub/core/blocks/evaluation/verify_question_block.py +564 -0
- sdg_hub/core/blocks/filtering/__init__.py +12 -0
- sdg_hub/core/blocks/filtering/column_value_filter.py +188 -0
- sdg_hub/core/blocks/llm/__init__.py +27 -0
- sdg_hub/core/blocks/llm/client_manager.py +398 -0
- sdg_hub/core/blocks/llm/config.py +336 -0
- sdg_hub/core/blocks/llm/error_handler.py +368 -0
- sdg_hub/core/blocks/llm/llm_chat_block.py +542 -0
- sdg_hub/core/blocks/llm/llm_chat_with_parsing_retry_block.py +491 -0
- sdg_hub/core/blocks/llm/prompt_builder_block.py +368 -0
- sdg_hub/core/blocks/llm/text_parser_block.py +357 -0
- sdg_hub/core/blocks/registry.py +331 -0
- sdg_hub/core/blocks/transform/__init__.py +23 -0
- sdg_hub/core/blocks/transform/duplicate_columns.py +88 -0
- sdg_hub/core/blocks/transform/index_based_mapper.py +225 -0
- sdg_hub/core/blocks/transform/melt_columns.py +126 -0
- sdg_hub/core/blocks/transform/rename_columns.py +69 -0
- sdg_hub/core/blocks/transform/text_concat.py +102 -0
- sdg_hub/core/blocks/transform/uniform_col_val_setter.py +101 -0
- sdg_hub/core/flow/__init__.py +20 -0
- sdg_hub/core/flow/base.py +1209 -0
- sdg_hub/core/flow/checkpointer.py +333 -0
- sdg_hub/core/flow/metadata.py +389 -0
- sdg_hub/core/flow/migration.py +198 -0
- sdg_hub/core/flow/registry.py +393 -0
- sdg_hub/core/flow/validation.py +277 -0
- sdg_hub/{utils → core/utils}/__init__.py +7 -4
- sdg_hub/core/utils/datautils.py +63 -0
- sdg_hub/core/utils/error_handling.py +208 -0
- sdg_hub/core/utils/flow_id_words.yaml +231 -0
- sdg_hub/core/utils/flow_identifier.py +94 -0
- sdg_hub/{utils → core/utils}/path_resolution.py +2 -2
- sdg_hub/core/utils/yaml_utils.py +59 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/atomic_facts.yaml +40 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/detailed_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_faithfulness.yaml +64 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_question.yaml +29 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/evaluate_relevancy.yaml +81 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/extractive_summary.yaml +13 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/flow.yaml +192 -0
- sdg_hub/flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/generate_questions_responses.yaml +54 -0
- sdg_hub-0.2.1.dist-info/METADATA +221 -0
- sdg_hub-0.2.1.dist-info/RECORD +68 -0
- sdg_hub/blocks/__init__.py +0 -42
- sdg_hub/blocks/block.py +0 -96
- sdg_hub/blocks/llmblock.py +0 -375
- sdg_hub/blocks/openaichatblock.py +0 -556
- sdg_hub/blocks/utilblocks.py +0 -597
- sdg_hub/checkpointer.py +0 -139
- sdg_hub/configs/annotations/cot_reflection.yaml +0 -34
- sdg_hub/configs/annotations/detailed_annotations.yaml +0 -28
- sdg_hub/configs/annotations/detailed_description.yaml +0 -10
- sdg_hub/configs/annotations/detailed_description_icl.yaml +0 -32
- sdg_hub/configs/annotations/simple_annotations.yaml +0 -9
- sdg_hub/configs/knowledge/__init__.py +0 -0
- sdg_hub/configs/knowledge/atomic_facts.yaml +0 -46
- sdg_hub/configs/knowledge/auxilary_instructions.yaml +0 -35
- sdg_hub/configs/knowledge/detailed_summary.yaml +0 -18
- sdg_hub/configs/knowledge/evaluate_faithfulness.yaml +0 -68
- sdg_hub/configs/knowledge/evaluate_question.yaml +0 -38
- sdg_hub/configs/knowledge/evaluate_relevancy.yaml +0 -84
- sdg_hub/configs/knowledge/extractive_summary.yaml +0 -18
- sdg_hub/configs/knowledge/generate_code_questions_responses.yaml +0 -39
- sdg_hub/configs/knowledge/generate_questions.yaml +0 -82
- sdg_hub/configs/knowledge/generate_questions_responses.yaml +0 -56
- sdg_hub/configs/knowledge/generate_responses.yaml +0 -86
- sdg_hub/configs/knowledge/mcq_generation.yaml +0 -83
- sdg_hub/configs/knowledge/router.yaml +0 -12
- sdg_hub/configs/knowledge/simple_generate_qa.yaml +0 -34
- sdg_hub/configs/reasoning/__init__.py +0 -0
- sdg_hub/configs/reasoning/dynamic_cot.yaml +0 -40
- sdg_hub/configs/skills/__init__.py +0 -0
- sdg_hub/configs/skills/analyzer.yaml +0 -48
- sdg_hub/configs/skills/annotation.yaml +0 -36
- sdg_hub/configs/skills/contexts.yaml +0 -28
- sdg_hub/configs/skills/critic.yaml +0 -60
- sdg_hub/configs/skills/evaluate_freeform_pair.yaml +0 -111
- sdg_hub/configs/skills/evaluate_freeform_questions.yaml +0 -78
- sdg_hub/configs/skills/evaluate_grounded_pair.yaml +0 -119
- sdg_hub/configs/skills/evaluate_grounded_questions.yaml +0 -51
- sdg_hub/configs/skills/freeform_questions.yaml +0 -34
- sdg_hub/configs/skills/freeform_responses.yaml +0 -39
- sdg_hub/configs/skills/grounded_questions.yaml +0 -38
- sdg_hub/configs/skills/grounded_responses.yaml +0 -59
- sdg_hub/configs/skills/icl_examples/STEM.yaml +0 -56
- sdg_hub/configs/skills/icl_examples/__init__.py +0 -0
- sdg_hub/configs/skills/icl_examples/coding.yaml +0 -97
- sdg_hub/configs/skills/icl_examples/extraction.yaml +0 -36
- sdg_hub/configs/skills/icl_examples/humanities.yaml +0 -71
- sdg_hub/configs/skills/icl_examples/math.yaml +0 -85
- sdg_hub/configs/skills/icl_examples/reasoning.yaml +0 -30
- sdg_hub/configs/skills/icl_examples/roleplay.yaml +0 -45
- sdg_hub/configs/skills/icl_examples/writing.yaml +0 -80
- sdg_hub/configs/skills/judge.yaml +0 -53
- sdg_hub/configs/skills/planner.yaml +0 -67
- sdg_hub/configs/skills/respond.yaml +0 -8
- sdg_hub/configs/skills/revised_responder.yaml +0 -78
- sdg_hub/configs/skills/router.yaml +0 -59
- sdg_hub/configs/skills/simple_generate_qa_freeform.yaml +0 -27
- sdg_hub/configs/skills/simple_generate_qa_grounded.yaml +0 -31
- sdg_hub/flow.py +0 -477
- sdg_hub/flow_runner.py +0 -450
- sdg_hub/flows/generation/knowledge/mmlu_bench.yaml +0 -13
- sdg_hub/flows/generation/knowledge/simple_knowledge.yaml +0 -12
- sdg_hub/flows/generation/knowledge/synth_knowledge.yaml +0 -89
- sdg_hub/flows/generation/knowledge/synth_knowledge1.5.yaml +0 -136
- sdg_hub/flows/generation/skills/improve_responses.yaml +0 -103
- sdg_hub/flows/generation/skills/simple_freeform_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/simple_grounded_skill.yaml +0 -12
- sdg_hub/flows/generation/skills/synth_grounded_skills.yaml +0 -80
- sdg_hub/flows/generation/skills/synth_skills.yaml +0 -59
- sdg_hub/pipeline.py +0 -121
- sdg_hub/prompts.py +0 -80
- sdg_hub/registry.py +0 -122
- sdg_hub/sdg.py +0 -206
- sdg_hub/utils/config_validation.py +0 -91
- sdg_hub/utils/datautils.py +0 -14
- sdg_hub/utils/error_handling.py +0 -94
- sdg_hub/utils/validation_result.py +0 -10
- sdg_hub-0.1.4.dist-info/METADATA +0 -190
- sdg_hub-0.1.4.dist-info/RECORD +0 -89
- sdg_hub/{logger_config.py → core/utils/logger_config.py} +1 -1
- /sdg_hub/{configs/__init__.py → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab/README.md} +0 -0
- /sdg_hub/{configs/annotations → flows/qa_generation/document_grounded_qa/multi_summary_qa/instructlab}/__init__.py +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/WHEEL +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {sdg_hub-0.1.4.dist-info → sdg_hub-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,103 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Deprecated FilterByValueBlock for backwards compatibility.
|
3
|
+
|
4
|
+
This module provides a deprecated wrapper around ColumnValueFilterBlock
|
5
|
+
to maintain backwards compatibility with existing code and configurations.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any, Callable, Optional, Union
|
10
|
+
import warnings
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
|
15
|
+
# Local
|
16
|
+
from ...utils.logger_config import setup_logger
|
17
|
+
from ..base import BaseBlock
|
18
|
+
from ..filtering import ColumnValueFilterBlock
|
19
|
+
from ..registry import BlockRegistry
|
20
|
+
|
21
|
+
logger = setup_logger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
@BlockRegistry.register(
    "FilterByValueBlock",
    "deprecated",
    "DEPRECATED: Use ColumnValueFilterBlock instead. Filters datasets based on column values using various comparison operations",
)
class FilterByValueBlock(BaseBlock):
    """DEPRECATED compatibility shim that forwards to ColumnValueFilterBlock.

    The class keeps the old ``FilterByValueBlock`` constructor signature
    alive. Construction emits a ``DeprecationWarning`` and builds an internal
    ``ColumnValueFilterBlock``; ``generate`` simply forwards to that block.
    New code should use ColumnValueFilterBlock directly.
    """

    def __init__(
        self,
        block_name: str,
        filter_column: str,
        filter_value: Union[Any, list[Any]],
        operation: Callable[[Any, Any], bool],
        convert_dtype: Optional[Union[type[float], type[int]]] = None,
        **batch_kwargs: dict[str, Any],
    ) -> None:
        """Warn about the deprecation and build the wrapped filter block.

        Parameters
        ----------
        block_name : str
            Name of the block; reused unchanged for the wrapped block.
        filter_column : str
            Column to filter on; passed as the single entry of ``input_cols``.
        filter_value : Union[Any, list[Any]]
            Value(s) to filter by, forwarded to the wrapped block.
        operation : Callable[[Any, Any], bool]
            Binary comparison callable, forwarded to the wrapped block.
        convert_dtype : Optional[Union[type[float], type[int]]], optional
            Optional conversion type, forwarded to the wrapped block.
        **batch_kwargs : dict[str, Any]
            Accepted so old call sites keep working; not used by this wrapper.
        """
        deprecation_message = (
            "FilterByValueBlock is deprecated and will be removed in a future version. "
            "Please use ColumnValueFilterBlock instead."
        )
        # stacklevel=2 attributes the warning to the code instantiating the block.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

        # Translate the legacy arguments into the input/output column form.
        super().__init__(
            block_name=block_name,
            input_cols=[filter_column],
            output_cols=[],
        )

        # Filter-specific options are handed to the replacement block untouched.
        delegate_options = {
            "filter_value": filter_value,
            "operation": operation,
            "convert_dtype": convert_dtype,
        }
        self._new_block = ColumnValueFilterBlock(
            block_name=block_name,
            input_cols=[filter_column],
            output_cols=[],
            **delegate_options,
        )

    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
        """Filter ``samples`` by delegating to the wrapped ColumnValueFilterBlock.

        Parameters
        ----------
        samples : Dataset
            The input dataset to filter.
        **kwargs : Any
            Forwarded unchanged to the wrapped block's ``generate``.

        Returns
        -------
        Dataset
            Whatever the wrapped block returns for ``samples``.
        """
        delegate = self._new_block
        return delegate.generate(samples, **kwargs)
|
@@ -0,0 +1,94 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""Deprecated FlattenColumnsBlock for backwards compatibility.
|
3
|
+
|
4
|
+
This module provides a deprecated wrapper around MeltColumnsBlock
|
5
|
+
to maintain backwards compatibility with existing code and configurations.
|
6
|
+
"""
|
7
|
+
|
8
|
+
# Standard
|
9
|
+
from typing import Any
|
10
|
+
import warnings
|
11
|
+
|
12
|
+
# Third Party
|
13
|
+
from datasets import Dataset
|
14
|
+
|
15
|
+
# Local
|
16
|
+
from ...utils.logger_config import setup_logger
|
17
|
+
from ..base import BaseBlock
|
18
|
+
from ..registry import BlockRegistry
|
19
|
+
from ..transform import MeltColumnsBlock
|
20
|
+
|
21
|
+
logger = setup_logger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
@BlockRegistry.register(
    "FlattenColumnsBlock",
    "deprecated",
    "DEPRECATED: Use MeltColumnsBlock instead. Transforms wide dataset format into long format by melting columns into rows",
)
class FlattenColumnsBlock(BaseBlock):
    """DEPRECATED compatibility shim that forwards to MeltColumnsBlock.

    The class keeps the old ``FlattenColumnsBlock`` constructor signature
    alive. Construction emits a ``DeprecationWarning`` and builds an internal
    ``MeltColumnsBlock``; ``generate`` simply forwards to that block.
    New code should use MeltColumnsBlock directly.
    """

    def __init__(
        self,
        block_name: str,
        var_cols: list[str],
        value_name: str,
        var_name: str,
    ) -> None:
        """Warn about the deprecation and build the wrapped melt block.

        Parameters
        ----------
        block_name : str
            Name of the block; reused unchanged for the wrapped block.
        var_cols : list[str]
            Columns to melt; passed as ``input_cols``.
        value_name : str
            Name of the value column; first entry of ``output_cols``.
        var_name : str
            Name of the variable column; second entry of ``output_cols``.
        """
        deprecation_message = (
            "FlattenColumnsBlock is deprecated and will be removed in a future version. "
            "Please use MeltColumnsBlock instead."
        )
        # stacklevel=2 attributes the warning to the code instantiating the block.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

        # Translate the legacy arguments into the input/output column form.
        super().__init__(
            block_name=block_name,
            input_cols=var_cols,
            output_cols=[value_name, var_name],
        )

        # The replacement block receives the same mapping; order of the
        # output columns (value first, variable second) is preserved.
        self._new_block = MeltColumnsBlock(
            block_name=block_name,
            input_cols=var_cols,
            output_cols=[value_name, var_name],
        )

    def generate(self, samples: Dataset, **kwargs: Any) -> Dataset:
        """Flatten ``samples`` by delegating to the wrapped MeltColumnsBlock.

        Parameters
        ----------
        samples : Dataset
            The input dataset to flatten.
        **kwargs : Any
            Forwarded unchanged to the wrapped block's ``generate``.

        Returns
        -------
        Dataset
            Whatever the wrapped block returns for ``samples``.
        """
        delegate = self._new_block
        return delegate.generate(samples, **kwargs)
|
@@ -0,0 +1,479 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
"""DEPRECATED: LLM-based blocks for text generation and processing.
|
3
|
+
|
4
|
+
This module provides backwards-compatible blocks for interacting with language models.
|
5
|
+
|
6
|
+
DEPRECATED: The LLMBlock is deprecated and will be removed in a future version.
|
7
|
+
Use the new modular approach with PromptBuilderBlock, LLMChatBlock, and TextParserBlock instead.
|
8
|
+
"""
|
9
|
+
|
10
|
+
# Standard
|
11
|
+
from typing import Any, Optional
|
12
|
+
import os
|
13
|
+
import tempfile
|
14
|
+
import warnings
|
15
|
+
|
16
|
+
# Third Party
|
17
|
+
from datasets import Dataset
|
18
|
+
from jinja2 import Environment, meta
|
19
|
+
import openai
|
20
|
+
import yaml
|
21
|
+
|
22
|
+
# Local
|
23
|
+
from ...utils.logger_config import setup_logger
|
24
|
+
from ..base import BaseBlock
|
25
|
+
from ..llm.llm_chat_block import LLMChatBlock
|
26
|
+
from ..llm.prompt_builder_block import PromptBuilderBlock
|
27
|
+
from ..llm.text_parser_block import TextParserBlock
|
28
|
+
from ..registry import BlockRegistry
|
29
|
+
|
30
|
+
logger = setup_logger(__name__)
|
31
|
+
|
32
|
+
|
33
|
+
def server_supports_batched(client: Any, model_id: str) -> bool:
    """Report whether the server accepts batched prompts together with ``n``.

    The answer is cached on the client as the attribute
    ``server_supports_batched``; when that attribute is already set (not
    ``None``) it is returned without contacting the server. Otherwise one
    probe completion request is sent and its outcome is stored on the client.

    Parameters
    ----------
    client : openai.OpenAI
        Client used for the probe request and as the cache holder.
    model_id : str
        Model ID used for the probe request.

    Returns
    -------
    bool
        The cached value if present; otherwise True when the probe returned
        exactly six choices, False when it returned another count or raised
        ``openai.InternalServerError``.
    """
    result = getattr(client, "server_supports_batched", None)
    if result is None:
        try:
            # Two prompts with n=3 must yield 2 * 3 = 6 choices if the server
            # honours both multi-prompt requests and the n parameter.
            probe = client.completions.create(
                model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
            )
            result = len(probe.choices) == 6
        except openai.InternalServerError:
            result = False

        # Remember the outcome so later calls skip the probe request.
        client.server_supports_batched = result
        stored_flag = getattr(client, "server_supports_batched", False)
        logger.info(f"LLM server supports batched inputs: {stored_flag}")
    return result
|
63
|
+
|
64
|
+
|
65
|
+
@BlockRegistry.register(
    block_name="LLMBlock",
    category="deprecated",
    description="DEPRECATED: Use the new modular approach with PromptBuilderBlock, LLMChatBlock, and TextParserBlock instead",
)
class LLMBlock(BaseBlock):
    """DEPRECATED: Block for generating text using language models.

    This block maintains backwards compatibility with the old LLMBlock interface
    by internally chaining the new modular blocks: PromptBuilderBlock,
    LLMChatBlock, and (when parsing is configured) TextParserBlock.

    Parameters
    ----------
    block_name : str
        Name of the block.
    config_path : str
        Path to the YAML configuration file.
    client : openai.OpenAI
        OpenAI client instance; its ``base_url``/``api_key`` attributes are
        read and, when no ``model_id`` is given, its model list is queried.
    output_cols : list[str]
        List of output column names.
    parser_kwargs : Optional[dict[str, Any]], optional
        Keyword arguments for the parser, by default None (treated as {}).
    model_prompt : str, optional
        Template string for model prompt, by default "{prompt}". Stored on the
        instance only; not otherwise used inside this class.
    model_id : Optional[str], optional
        Model ID to use, by default None (first model listed by the client).
    **batch_kwargs : dict[str, Any]
        Only the entry under the key ``"batch_kwargs"`` is used; it supplies
        generation parameters that override the defaults.
    """

    def __init__(
        self,
        block_name: str,
        config_path: str,
        client: Any,
        output_cols: list[str],
        parser_kwargs: Optional[dict[str, Any]] = None,
        model_prompt: str = "{prompt}",
        model_id: Optional[str] = None,
        **batch_kwargs: dict[str, Any],
    ) -> None:
        if parser_kwargs is None:
            parser_kwargs = {}

        # Issue deprecation warning (attributed to the caller via stacklevel=2)
        warnings.warn(
            "LLMBlock is deprecated and will be removed in a future version. "
            "Use the new modular approach with PromptBuilderBlock, LLMChatBlock, and TextParserBlock instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        # Load config and extract input columns before calling super().__init__(),
        # which is why the static helper variants are used here.
        block_config = self._load_config_static(config_path)
        input_cols = self._extract_template_variables_static(block_config)

        super().__init__(
            block_name=block_name, input_cols=input_cols, output_cols=output_cols
        )

        # Now we can set instance attributes
        self.config_path = config_path
        self.block_config = block_config

        # Store original parameters for compatibility
        self.client = client
        self.parser_kwargs = parser_kwargs or {}
        self.model_prompt = model_prompt
        # Generation parameters are expected nested under the "batch_kwargs" key.
        self.batch_kwargs = batch_kwargs.get("batch_kwargs", {})

        # Set model
        if model_id:
            self.model = model_id
        else:
            # get the default model id from client
            self.model = self.client.models.list().data[0].id

        # Create temporary config file for new prompt builder
        self._temp_prompt_config = self._create_prompt_config()

        # Initialize the internal blocks
        self._setup_internal_blocks()

    def _load_config(self, config_path: str) -> dict[str, Any]:
        """Load configuration from YAML file."""
        return self._load_config_static(config_path)

    @staticmethod
    def _load_config_static(config_path: str) -> dict[str, Any]:
        """Load configuration from YAML file (static version).

        Best-effort: any failure is logged and an empty dict is returned; an
        empty file also yields an empty dict.
        """
        try:
            with open(config_path, encoding="utf-8") as file:
                return yaml.safe_load(file) or {}
        except Exception as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            return {}

    def _extract_template_variables(self) -> list[str]:
        """Extract Jinja2 template variables from all config fields."""
        return self._extract_template_variables_static(self.block_config)

    @staticmethod
    def _extract_template_variables_static(block_config: dict[str, Any]) -> list[str]:
        """Extract Jinja2 template variables from all config fields (static version).

        Only non-blank string values of the fields "system", "introduction",
        "principles", "examples" and "generation" are inspected. Fields that
        fail to parse are skipped with a debug log entry.
        """
        variables: set[str] = set()
        env = Environment()

        # Extract variables from all string fields in config
        for field in ["system", "introduction", "principles", "examples", "generation"]:
            field_content = block_config.get(field, "")
            if isinstance(field_content, str) and field_content.strip():
                try:
                    ast = env.parse(field_content)
                    field_vars = meta.find_undeclared_variables(ast)
                    variables.update(field_vars)
                except Exception as e:
                    logger.debug(
                        f"Could not parse template variables from {field}: {e}"
                    )

        return list(variables)

    def _create_prompt_config(self) -> str:
        """Create a temporary YAML config file for the new PromptBuilderBlock format.

        Returns
        -------
        str
            Path of the written temporary file. The file is created with
            ``delete=False``; ``__del__`` attempts to remove it.
        """
        # Convert old config format to new message-based format
        messages = []

        # Create user message with the structured prompt (matching old format)
        # Build prompt using the original structure: {system}\n{introduction}\n{principles}\n{examples}\n{generation}
        prompt_parts = []
        for field in ["system", "introduction", "principles", "examples", "generation"]:
            field_content = self.block_config.get(field, "")
            # A key present with an empty YAML value loads as None, which
            # str.join cannot handle; treat it like a missing field.
            prompt_parts.append("" if field_content is None else field_content)

        # Join with single newlines to match original prompt_struct
        user_content = "\n".join(prompt_parts)

        if user_content.strip():
            messages.append({"role": "user", "content": user_content})

        # Write to temporary file; the with-block flushes and closes the handle
        # so it is not leaked, while delete=False keeps the file on disk.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml.j2", delete=False
        ) as temp_file:
            yaml.safe_dump(messages, temp_file, default_flow_style=False)
        return temp_file.name

    def _setup_internal_blocks(self) -> None:
        """Initialize the internal blocks (prompt builder, chat, optional parser)."""
        # 1. PromptBuilderBlock
        self.prompt_builder = PromptBuilderBlock(
            block_name=f"{self.block_name}_prompt_builder",
            input_cols=self.input_cols,  # Pass through original input columns for template access
            output_cols=["messages"],
            prompt_config_path=self._temp_prompt_config,
            format_as_messages=True,
        )

        # 2. LLMChatBlock
        # Convert client to LiteLLM format - support OpenAI and hosted_vllm
        if self.model.startswith(("openai/", "hosted_vllm/")):
            model_name = self.model
        else:
            # Local/hosted model
            model_name = f"hosted_vllm/{self.model}"

        # Extract generation parameters from batch_kwargs and defaults;
        # caller-supplied values win over the defaults.
        defaults = {
            "temperature": 0,
            "max_tokens": 4096,
        }
        gen_params = {**defaults, **self.batch_kwargs}

        # Convert URL to string if needed and handle mock objects
        api_base = getattr(self.client, "base_url", None)
        if api_base is not None:
            api_base_str = str(api_base)
            # Skip mock objects
            api_base = (
                api_base_str if not api_base_str.startswith("<MagicMock") else None
            )

        # Handle api_key - convert to string or set to None for mocks
        api_key = getattr(self.client, "api_key", None)
        if api_key is not None:
            api_key_str = str(api_key)
            # Skip mock objects
            api_key = api_key_str if not api_key_str.startswith("<MagicMock") else None

        self.llm_chat = LLMChatBlock(
            block_name=f"{self.block_name}_llm_chat",
            input_cols=["messages"],
            output_cols=["raw_response"],
            model=model_name,
            api_key=api_key,
            api_base=api_base,
            **gen_params,
        )

        # 3. TextParserBlock
        parser_config = {}

        # Handle parsing configuration
        parser_name = self.parser_kwargs.get("parser_name")
        if parser_name == "custom":
            parsing_pattern = self.parser_kwargs.get("parsing_pattern")
            cleanup_tags = self.parser_kwargs.get("parser_cleanup_tags")
            if parsing_pattern:
                parser_config["parsing_pattern"] = parsing_pattern
            if cleanup_tags:
                parser_config["parser_cleanup_tags"] = cleanup_tags
        else:
            # Use start/end tags from config
            start_tags = self.block_config.get("start_tags", [])
            end_tags = self.block_config.get("end_tags", [])
            if start_tags or end_tags:
                parser_config["start_tags"] = start_tags
                parser_config["end_tags"] = end_tags

        # Only create parser if we have parsing configuration
        if parser_config:
            self.text_parser: Optional[TextParserBlock] = TextParserBlock(
                block_name=f"{self.block_name}_text_parser",
                input_cols=["raw_response"],
                output_cols=self.output_cols,
                **parser_config,
            )
        else:
            self.text_parser = None

    def generate(self, samples: Dataset, **gen_kwargs: dict[str, Any]) -> Dataset:
        """Generate the output from the block.

        This method maintains backwards compatibility by internally running
        the prompt builder, the chat block and (if configured) the text parser.

        Parameters
        ----------
        samples : Dataset
            Input dataset. A ``num_samples`` column is added when the config
            defines ``num_samples`` and the dataset lacks that column.
        **gen_kwargs : dict[str, Any]
            Forwarded to the chat block's ``generate``. The ``n`` entry
            (default 1) selects between the multi-response and the
            single-response merge logic.

        Returns
        -------
        Dataset
            The input rows combined with the output columns. When parsing
            changes the row count in the ``n == 1`` case, the parsed rows are
            returned instead, without the ``raw_response``/``messages``
            columns. If any step raises, the error is logged and the input
            rows are returned with every output column set to "".
        """
        logger.debug(
            f"Generating outputs for {len(samples)} samples using deprecated LLMBlock"
        )

        # Validate num_samples handling
        num_samples = self.block_config.get("num_samples")
        if (num_samples is not None) and ("num_samples" not in samples.column_names):
            samples = samples.add_column("num_samples", [num_samples] * len(samples))

        try:
            # Step 1: Format prompts using PromptBuilderBlock
            # Pass the original dataset directly so template variables can be accessed
            prompt_result = self.prompt_builder.generate(samples)

            # Step 2: Generate responses using LLMChatBlock
            chat_result = self.llm_chat.generate(prompt_result, **gen_kwargs)

            # Step 3: Handle n parameter before parsing
            num_parallel_samples = gen_kwargs.get("n", 1)

            if num_parallel_samples > 1:
                # When n > 1, we need to expand the list responses before parsing
                # TextParserBlock expects individual strings, not lists
                expanded_chat_data = []

                for sample in chat_result:
                    raw_responses = sample["raw_response"]
                    if isinstance(raw_responses, list):
                        # Create one row per response
                        for response in raw_responses:
                            expanded_sample = {**sample}
                            expanded_sample["raw_response"] = response
                            expanded_chat_data.append(expanded_sample)
                    else:
                        # Single response (fallback)
                        expanded_chat_data.append(sample)

                expanded_chat_result = Dataset.from_list(expanded_chat_data)

                # Step 4: Parse the expanded responses using TextParserBlock (if configured)
                if self.text_parser:
                    final_result = self.text_parser.generate(expanded_chat_result)
                else:
                    # If no parser, just rename the raw_response column to the first output column
                    if self.output_cols:
                        final_result = expanded_chat_result.rename_column(
                            "raw_response", self.output_cols[0]
                        )
                    else:
                        final_result = expanded_chat_result

                # Step 5: Merge with original samples (each original sample maps to n result samples)
                merged_data = []
                result_idx = 0

                for orig_sample in samples:
                    # Each original sample should have n corresponding results
                    for _ in range(num_parallel_samples):
                        if result_idx < len(final_result):
                            result_sample = final_result[result_idx]
                            merged_sample = {**orig_sample}
                            for output_col in self.output_cols:
                                if output_col in result_sample:
                                    merged_sample[output_col] = result_sample[
                                        output_col
                                    ]
                                else:
                                    merged_sample[output_col] = ""
                            merged_data.append(merged_sample)
                            result_idx += 1
                        else:
                            # Missing result - create empty
                            merged_sample = {**orig_sample}
                            for output_col in self.output_cols:
                                merged_sample[output_col] = ""
                            merged_data.append(merged_sample)

                return Dataset.from_list(merged_data)

            else:
                # Step 4: Parse responses using TextParserBlock (if configured) - n=1 case
                if self.text_parser:
                    logger.info(
                        f"DEPRECATED LLMBlock '{self.block_name}' before parsing (n=1): {len(chat_result)} samples"
                    )
                    final_result = self.text_parser.generate(chat_result)
                    logger.info(
                        f"DEPRECATED LLMBlock '{self.block_name}' after parsing (n=1): {len(final_result)} samples"
                    )

                else:
                    # If no parser, just rename the raw_response column to the first output column
                    if self.output_cols:
                        final_result = chat_result.rename_column(
                            "raw_response", self.output_cols[0]
                        )
                    else:
                        final_result = chat_result

                # Step 5: Merge with original samples for n=1 case
                # Handle different parsing outcomes: expansion, contraction, or 1:1
                if len(final_result) != len(samples):
                    # Row count changed - parsing found different number of results than inputs
                    if len(final_result) > len(samples):
                        logger.info(
                            f"DEPRECATED LLMBlock '{self.block_name}' detected row expansion: {len(samples)} -> {len(final_result)}"
                        )
                    else:
                        logger.info(
                            f"DEPRECATED LLMBlock '{self.block_name}' detected row contraction: {len(samples)} -> {len(final_result)}"
                        )

                    # For both expansion and contraction, return parsed results
                    # Keep only the expected output columns plus any preserved input columns
                    # Remove intermediate processing columns to avoid duplicates
                    desired_columns = set(self.output_cols)  # Required output columns
                    available_columns = set(final_result.column_names)

                    # Add input columns that were preserved (excluding processing columns like raw_response, messages)
                    processing_columns = {
                        "raw_response",
                        "messages",
                    }  # Common intermediate columns
                    for col in available_columns:
                        if col not in processing_columns and col not in desired_columns:
                            # This is likely a preserved input column
                            desired_columns.add(col)

                    # Filter to only the columns we want (in the result's own column order)
                    columns_to_keep = [
                        col
                        for col in final_result.column_names
                        if col in desired_columns
                    ]
                    final_dataset = final_result.select_columns(columns_to_keep)

                else:
                    # Normal 1:1 case - merge with original samples to preserve all input columns
                    merged_data = []
                    for orig_sample, result_sample in zip(samples, final_result):
                        merged_sample = {**orig_sample}
                        for output_col in self.output_cols:
                            if output_col in result_sample:
                                response = result_sample[output_col]
                                # Handle case where response might still be a list with 1 item
                                if isinstance(response, list) and len(response) == 1:
                                    merged_sample[output_col] = response[0]
                                elif isinstance(response, list):
                                    # Multiple responses but n=1 - take first one
                                    merged_sample[output_col] = (
                                        response[0] if response else ""
                                    )
                                else:
                                    merged_sample[output_col] = response
                            else:
                                merged_sample[output_col] = ""
                        merged_data.append(merged_sample)
                    final_dataset = Dataset.from_list(merged_data)

                return final_dataset

        except Exception as e:
            logger.error(f"Error in deprecated LLMBlock generation: {e}")
            # Fall back to empty dataset with proper structure
            empty_data = []
            for sample in samples:
                empty_sample = {**sample}
                for output_col in self.output_cols:
                    empty_sample[output_col] = ""
                empty_data.append(empty_sample)
            return Dataset.from_list(empty_data)

    def __del__(self):
        """Clean up temporary files."""
        # Deliberately best-effort: a finalizer must never raise, so any
        # failure to remove the temporary prompt config is ignored.
        try:
            if hasattr(self, "_temp_prompt_config"):
                os.unlink(self._temp_prompt_config)
        except Exception:
            pass
|