hamtaa-texttools 1.1.20__py3-none-any.whl → 1.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.22.dist-info}/METADATA +49 -109
- hamtaa_texttools-1.1.22.dist-info/RECORD +32 -0
- texttools/__init__.py +3 -3
- texttools/batch/batch_config.py +14 -1
- texttools/batch/batch_runner.py +2 -2
- texttools/internals/async_operator.py +49 -92
- texttools/internals/models.py +74 -105
- texttools/internals/operator_utils.py +25 -27
- texttools/internals/prompt_loader.py +3 -20
- texttools/internals/sync_operator.py +49 -92
- texttools/prompts/README.md +2 -2
- texttools/prompts/categorize.yaml +35 -77
- texttools/prompts/check_fact.yaml +2 -2
- texttools/prompts/extract_entities.yaml +2 -2
- texttools/prompts/extract_keywords.yaml +6 -6
- texttools/prompts/is_question.yaml +2 -2
- texttools/prompts/merge_questions.yaml +4 -4
- texttools/prompts/propositionize.yaml +2 -2
- texttools/prompts/rewrite.yaml +6 -6
- texttools/prompts/run_custom.yaml +1 -1
- texttools/prompts/subject_to_question.yaml +2 -2
- texttools/prompts/summarize.yaml +2 -2
- texttools/prompts/text_to_question.yaml +2 -2
- texttools/prompts/translate.yaml +2 -2
- texttools/tools/async_tools.py +393 -487
- texttools/tools/sync_tools.py +394 -488
- hamtaa_texttools-1.1.20.dist-info/RECORD +0 -33
- texttools/batch/internals/utils.py +0 -13
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.22.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.22.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.22.dist-info}/top_level.txt +0 -0
- /texttools/batch/{internals/batch_manager.py → batch_manager.py} +0 -0
texttools/internals/models.py
CHANGED
|
@@ -1,25 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from datetime import datetime
|
|
2
|
-
from typing import Type, Literal
|
|
4
|
+
from typing import Type, Literal, Any
|
|
3
5
|
|
|
4
6
|
from pydantic import BaseModel, Field, create_model
|
|
5
7
|
|
|
6
8
|
|
|
7
|
-
class
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
analysis: str = ""
|
|
11
|
-
process: str | None = None
|
|
12
|
-
processed_at: datetime = datetime.now()
|
|
9
|
+
class ToolOutputMetadata(BaseModel):
|
|
10
|
+
tool_name: str
|
|
11
|
+
processed_at: datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
13
12
|
execution_time: float | None = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ToolOutput(BaseModel):
|
|
16
|
+
result: Any = None
|
|
17
|
+
analysis: str | None = None
|
|
18
|
+
logprobs: list[dict[str, Any]] | None = None
|
|
14
19
|
errors: list[str] = []
|
|
20
|
+
metadata: ToolOutputMetadata | None = None
|
|
15
21
|
|
|
16
22
|
def __repr__(self) -> str:
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
+
base = f"""ToolOutput(result='{self.result}', result_type='{type(self.result)}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}'"""
|
|
24
|
+
|
|
25
|
+
if self.metadata:
|
|
26
|
+
base += f""", tool_name='{self.metadata.tool_name}',
|
|
27
|
+
processed_at='{self.metadata.processed_at}', execution_time='{self.metadata.execution_time}'
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
return base
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OperatorOutput(BaseModel):
|
|
34
|
+
result: Any
|
|
35
|
+
analysis: str | None
|
|
36
|
+
logprobs: list[dict[str, Any]] | None
|
|
23
37
|
|
|
24
38
|
|
|
25
39
|
class Str(BaseModel):
|
|
@@ -53,114 +67,69 @@ class ReasonListStr(BaseModel):
|
|
|
53
67
|
)
|
|
54
68
|
|
|
55
69
|
|
|
56
|
-
class Node
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
70
|
+
class Node:
|
|
71
|
+
def __init__(self, name: str, description: str, level: int, parent: Node | None):
|
|
72
|
+
self.name = name
|
|
73
|
+
self.description = description
|
|
74
|
+
self.level = level
|
|
75
|
+
self.parent = parent
|
|
76
|
+
self.children = {}
|
|
62
77
|
|
|
63
78
|
|
|
64
79
|
class CategoryTree:
|
|
65
|
-
def __init__(self
|
|
66
|
-
self._root = Node(
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
self._new_id = 1
|
|
71
|
-
|
|
72
|
-
def get_all_nodes(self) -> list[Node]:
|
|
80
|
+
def __init__(self):
|
|
81
|
+
self._root = Node(name="root", description="root", level=0, parent=None)
|
|
82
|
+
self._all_nodes = {"root": self._root}
|
|
83
|
+
|
|
84
|
+
def get_all_nodes(self) -> dict[str, Node]:
|
|
73
85
|
return self._all_nodes
|
|
74
86
|
|
|
75
87
|
def get_level_count(self) -> int:
|
|
76
|
-
return max(
|
|
77
|
-
|
|
78
|
-
def get_node(self,
|
|
79
|
-
|
|
80
|
-
for node in self.get_all_nodes():
|
|
81
|
-
if node.name == identifier:
|
|
82
|
-
return node
|
|
83
|
-
return None
|
|
84
|
-
elif isinstance(identifier, int):
|
|
85
|
-
for node in self.get_all_nodes():
|
|
86
|
-
if node.node_id == identifier:
|
|
87
|
-
return node
|
|
88
|
-
return None
|
|
89
|
-
else:
|
|
90
|
-
return None
|
|
91
|
-
|
|
92
|
-
def get_children(self, parent_node: Node) -> list[Node] | None:
|
|
93
|
-
children = [
|
|
94
|
-
node
|
|
95
|
-
for node in self.get_all_nodes()
|
|
96
|
-
if parent_node.node_id == node.parent_id
|
|
97
|
-
]
|
|
98
|
-
return children if children else None
|
|
88
|
+
return max(node.level for node in self._all_nodes.values())
|
|
89
|
+
|
|
90
|
+
def get_node(self, name: str) -> Node | None:
|
|
91
|
+
return self._all_nodes.get(name)
|
|
99
92
|
|
|
100
93
|
def add_node(
|
|
101
94
|
self,
|
|
102
|
-
|
|
103
|
-
parent_name: str
|
|
95
|
+
name: str,
|
|
96
|
+
parent_name: str,
|
|
104
97
|
description: str | None = None,
|
|
105
98
|
) -> None:
|
|
106
|
-
if self.get_node(
|
|
107
|
-
raise ValueError(f"{
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
parent_id = parent_node.node_id
|
|
114
|
-
level = parent_node.level + 1
|
|
115
|
-
else:
|
|
116
|
-
level = 1
|
|
117
|
-
parent_id = 0
|
|
99
|
+
if self.get_node(name):
|
|
100
|
+
raise ValueError(f"Cannot add {name} category twice")
|
|
101
|
+
|
|
102
|
+
parent = self.get_node(parent_name)
|
|
103
|
+
|
|
104
|
+
if not parent:
|
|
105
|
+
raise ValueError(f"Parent category '{parent_name}' not found")
|
|
118
106
|
|
|
119
107
|
node_data = {
|
|
120
|
-
"
|
|
121
|
-
"name": node_name,
|
|
122
|
-
"level": level,
|
|
123
|
-
"parent_id": parent_id,
|
|
108
|
+
"name": name,
|
|
124
109
|
"description": description if description else "No description provided",
|
|
110
|
+
"level": parent.level + 1,
|
|
111
|
+
"parent": parent,
|
|
125
112
|
}
|
|
126
113
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
raise ValueError(f"Node with identifier: '{identifier}' not found.")
|
|
147
|
-
|
|
148
|
-
def dump_tree(self) -> dict:
|
|
149
|
-
def build_dict(node: Node) -> dict:
|
|
150
|
-
children = [
|
|
151
|
-
build_dict(child)
|
|
152
|
-
for child in self._all_nodes
|
|
153
|
-
if child.parent_id == node.node_id
|
|
154
|
-
]
|
|
155
|
-
return {
|
|
156
|
-
"node_id": node.node_id,
|
|
157
|
-
"name": node.name,
|
|
158
|
-
"level": node.level,
|
|
159
|
-
"parent_id": node.parent_id,
|
|
160
|
-
"children": children,
|
|
161
|
-
}
|
|
162
|
-
|
|
163
|
-
return {"category_tree": build_dict(self._root)["children"]}
|
|
114
|
+
new_node = Node(**node_data)
|
|
115
|
+
parent.children[name] = new_node
|
|
116
|
+
self._all_nodes[name] = new_node
|
|
117
|
+
|
|
118
|
+
def remove_node(self, name: str) -> None:
|
|
119
|
+
if name == "root":
|
|
120
|
+
raise ValueError("Cannot remove the root node")
|
|
121
|
+
|
|
122
|
+
node = self.get_node(name)
|
|
123
|
+
if not node:
|
|
124
|
+
raise ValueError(f"Category: '{name}' not found")
|
|
125
|
+
|
|
126
|
+
for child_name in list(node.children.keys()):
|
|
127
|
+
self.remove_node(child_name)
|
|
128
|
+
|
|
129
|
+
if node.parent:
|
|
130
|
+
del node.parent.children[name]
|
|
131
|
+
|
|
132
|
+
del self._all_nodes[name]
|
|
164
133
|
|
|
165
134
|
|
|
166
135
|
# This function is needed to create CategorizerOutput with dynamic categories
|
|
@@ -5,8 +5,30 @@ import random
|
|
|
5
5
|
|
|
6
6
|
class OperatorUtils:
|
|
7
7
|
@staticmethod
|
|
8
|
-
def
|
|
9
|
-
|
|
8
|
+
def build_main_prompt(
|
|
9
|
+
main_template: str,
|
|
10
|
+
analysis: str | None,
|
|
11
|
+
output_lang: str | None,
|
|
12
|
+
user_prompt: str | None,
|
|
13
|
+
) -> str:
|
|
14
|
+
main_prompt = ""
|
|
15
|
+
|
|
16
|
+
if analysis:
|
|
17
|
+
main_prompt += f"Based on this analysis:\n{analysis}\n"
|
|
18
|
+
|
|
19
|
+
if output_lang:
|
|
20
|
+
main_prompt += f"Respond only in the {output_lang} language.\n"
|
|
21
|
+
|
|
22
|
+
if user_prompt:
|
|
23
|
+
main_prompt += f"Consider this instruction {user_prompt}\n"
|
|
24
|
+
|
|
25
|
+
main_prompt += main_template
|
|
26
|
+
|
|
27
|
+
return main_prompt
|
|
28
|
+
|
|
29
|
+
@staticmethod
|
|
30
|
+
def build_message(prompt: str) -> list[dict[str, str]]:
|
|
31
|
+
return [{"role": "user", "content": prompt}]
|
|
10
32
|
|
|
11
33
|
@staticmethod
|
|
12
34
|
def extract_logprobs(completion: dict) -> list[dict]:
|
|
@@ -20,7 +42,7 @@ class OperatorUtils:
|
|
|
20
42
|
|
|
21
43
|
for choice in completion.choices:
|
|
22
44
|
if not getattr(choice, "logprobs", None):
|
|
23
|
-
|
|
45
|
+
raise ValueError("Your model does not support logprobs")
|
|
24
46
|
|
|
25
47
|
for logprob_item in choice.logprobs.content:
|
|
26
48
|
if ignore_pattern.match(logprob_item.token):
|
|
@@ -52,27 +74,3 @@ class OperatorUtils:
|
|
|
52
74
|
new_temp = base_temp + delta_temp
|
|
53
75
|
|
|
54
76
|
return max(0.0, min(new_temp, 1.5))
|
|
55
|
-
|
|
56
|
-
@staticmethod
|
|
57
|
-
def user_merge_format(messages: list[dict[str, str]]) -> list[dict[str, str]]:
|
|
58
|
-
"""
|
|
59
|
-
Merges consecutive user messages into a single message, separated by newlines.
|
|
60
|
-
|
|
61
|
-
This is useful for condensing a multi-turn user input into a single
|
|
62
|
-
message for the LLM. Assistant and system messages are left unchanged and
|
|
63
|
-
act as separators between user message groups.
|
|
64
|
-
"""
|
|
65
|
-
merged = []
|
|
66
|
-
|
|
67
|
-
for message in messages:
|
|
68
|
-
role, content = message["role"], message["content"].strip()
|
|
69
|
-
|
|
70
|
-
# Merge with previous user turn
|
|
71
|
-
if merged and role == "user" and merged[-1]["role"] == "user":
|
|
72
|
-
merged[-1]["content"] += "\n" + content
|
|
73
|
-
|
|
74
|
-
# Otherwise, start a new turn
|
|
75
|
-
else:
|
|
76
|
-
merged.append({"role": role, "content": content})
|
|
77
|
-
|
|
78
|
-
return merged
|
|
@@ -12,20 +12,12 @@ class PromptLoader:
|
|
|
12
12
|
Responsibilities:
|
|
13
13
|
- Load and parse YAML prompt definitions.
|
|
14
14
|
- Select the right template (by mode, if applicable).
|
|
15
|
-
- Inject variables (`{
|
|
15
|
+
- Inject variables (`{text}`, plus any extra kwargs) into the templates.
|
|
16
16
|
"""
|
|
17
17
|
|
|
18
18
|
MAIN_TEMPLATE = "main_template"
|
|
19
19
|
ANALYZE_TEMPLATE = "analyze_template"
|
|
20
20
|
|
|
21
|
-
@staticmethod
|
|
22
|
-
def _build_format_args(text: str, **extra_kwargs) -> dict[str, str]:
|
|
23
|
-
# Base formatting args
|
|
24
|
-
format_args = {"input": text}
|
|
25
|
-
# Merge extras
|
|
26
|
-
format_args.update(extra_kwargs)
|
|
27
|
-
return format_args
|
|
28
|
-
|
|
29
21
|
# Use lru_cache to load each file once
|
|
30
22
|
@lru_cache(maxsize=32)
|
|
31
23
|
def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
|
|
@@ -69,16 +61,6 @@ class PromptLoader:
|
|
|
69
61
|
+ (f" for mode '{mode}'" if mode else "")
|
|
70
62
|
)
|
|
71
63
|
|
|
72
|
-
if (
|
|
73
|
-
not analyze_template
|
|
74
|
-
or not analyze_template.strip()
|
|
75
|
-
or analyze_template.strip() in ["{analyze_template}", "{}"]
|
|
76
|
-
):
|
|
77
|
-
raise PromptError(
|
|
78
|
-
"analyze_template cannot be empty"
|
|
79
|
-
+ (f" for mode '{mode}'" if mode else "")
|
|
80
|
-
)
|
|
81
|
-
|
|
82
64
|
return {
|
|
83
65
|
self.MAIN_TEMPLATE: main_template,
|
|
84
66
|
self.ANALYZE_TEMPLATE: analyze_template,
|
|
@@ -94,7 +76,8 @@ class PromptLoader:
|
|
|
94
76
|
) -> dict[str, str]:
|
|
95
77
|
try:
|
|
96
78
|
template_configs = self._load_templates(prompt_file, mode)
|
|
97
|
-
format_args =
|
|
79
|
+
format_args = {"text": text}
|
|
80
|
+
format_args.update(extra_kwargs)
|
|
98
81
|
|
|
99
82
|
# Inject variables inside each template
|
|
100
83
|
for key in template_configs.keys():
|
|
@@ -1,11 +1,10 @@
|
|
|
1
1
|
from typing import TypeVar, Type
|
|
2
2
|
from collections.abc import Callable
|
|
3
|
-
import logging
|
|
4
3
|
|
|
5
4
|
from openai import OpenAI
|
|
6
5
|
from pydantic import BaseModel
|
|
7
6
|
|
|
8
|
-
from texttools.internals.models import
|
|
7
|
+
from texttools.internals.models import OperatorOutput
|
|
9
8
|
from texttools.internals.operator_utils import OperatorUtils
|
|
10
9
|
from texttools.internals.prompt_loader import PromptLoader
|
|
11
10
|
from texttools.internals.exceptions import (
|
|
@@ -18,39 +17,21 @@ from texttools.internals.exceptions import (
|
|
|
18
17
|
# Base Model type for output models
|
|
19
18
|
T = TypeVar("T", bound=BaseModel)
|
|
20
19
|
|
|
21
|
-
logger = logging.getLogger("texttools.sync_operator")
|
|
22
|
-
|
|
23
20
|
|
|
24
21
|
class Operator:
|
|
25
22
|
"""
|
|
26
|
-
Core engine for running text-processing operations with an LLM
|
|
27
|
-
|
|
28
|
-
It wires together:
|
|
29
|
-
- `PromptLoader` → loads YAML prompt templates.
|
|
30
|
-
- `UserMergeFormatter` → applies formatting to messages (e.g., merging).
|
|
31
|
-
- OpenAI client → executes completions/parsed completions.
|
|
23
|
+
Core engine for running text-processing operations with an LLM.
|
|
32
24
|
"""
|
|
33
25
|
|
|
34
26
|
def __init__(self, client: OpenAI, model: str):
|
|
35
27
|
self._client = client
|
|
36
28
|
self._model = model
|
|
37
29
|
|
|
38
|
-
def
|
|
39
|
-
"""
|
|
40
|
-
Calls OpenAI API for analysis using the configured prompt template.
|
|
41
|
-
Returns the analyzed content as a string.
|
|
42
|
-
"""
|
|
30
|
+
def _analyze_completion(self, analyze_message: list[dict[str, str]]) -> str:
|
|
43
31
|
try:
|
|
44
|
-
analyze_prompt = prompt_configs["analyze_template"]
|
|
45
|
-
|
|
46
|
-
if not analyze_prompt:
|
|
47
|
-
raise PromptError("Analyze template is empty")
|
|
48
|
-
|
|
49
|
-
analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
|
|
50
32
|
completion = self._client.chat.completions.create(
|
|
51
33
|
model=self._model,
|
|
52
34
|
messages=analyze_message,
|
|
53
|
-
temperature=temperature,
|
|
54
35
|
)
|
|
55
36
|
|
|
56
37
|
if not completion.choices:
|
|
@@ -61,7 +42,7 @@ class Operator:
|
|
|
61
42
|
if not analysis:
|
|
62
43
|
raise LLMError("Empty analysis response")
|
|
63
44
|
|
|
64
|
-
return analysis
|
|
45
|
+
return analysis
|
|
65
46
|
|
|
66
47
|
except Exception as e:
|
|
67
48
|
if isinstance(e, (PromptError, LLMError)):
|
|
@@ -70,12 +51,12 @@ class Operator:
|
|
|
70
51
|
|
|
71
52
|
def _parse_completion(
|
|
72
53
|
self,
|
|
73
|
-
|
|
54
|
+
main_message: list[dict[str, str]],
|
|
74
55
|
output_model: Type[T],
|
|
75
56
|
temperature: float,
|
|
76
|
-
logprobs: bool
|
|
77
|
-
top_logprobs: int
|
|
78
|
-
priority: int
|
|
57
|
+
logprobs: bool,
|
|
58
|
+
top_logprobs: int,
|
|
59
|
+
priority: int,
|
|
79
60
|
) -> tuple[T, object]:
|
|
80
61
|
"""
|
|
81
62
|
Parses a chat completion using OpenAI's structured output format.
|
|
@@ -84,7 +65,7 @@ class Operator:
|
|
|
84
65
|
try:
|
|
85
66
|
request_kwargs = {
|
|
86
67
|
"model": self._model,
|
|
87
|
-
"messages":
|
|
68
|
+
"messages": main_message,
|
|
88
69
|
"response_format": output_model,
|
|
89
70
|
"temperature": temperature,
|
|
90
71
|
}
|
|
@@ -92,8 +73,10 @@ class Operator:
|
|
|
92
73
|
if logprobs:
|
|
93
74
|
request_kwargs["logprobs"] = True
|
|
94
75
|
request_kwargs["top_logprobs"] = top_logprobs
|
|
76
|
+
|
|
95
77
|
if priority:
|
|
96
78
|
request_kwargs["extra_body"] = {"priority": priority}
|
|
79
|
+
|
|
97
80
|
completion = self._client.beta.chat.completions.parse(**request_kwargs)
|
|
98
81
|
|
|
99
82
|
if not completion.choices:
|
|
@@ -120,24 +103,24 @@ class Operator:
|
|
|
120
103
|
user_prompt: str | None,
|
|
121
104
|
temperature: float,
|
|
122
105
|
logprobs: bool,
|
|
123
|
-
top_logprobs: int
|
|
106
|
+
top_logprobs: int,
|
|
124
107
|
validator: Callable[[object], bool] | None,
|
|
125
108
|
max_validation_retries: int | None,
|
|
109
|
+
priority: int,
|
|
126
110
|
# Internal parameters
|
|
127
111
|
prompt_file: str,
|
|
128
112
|
output_model: Type[T],
|
|
129
113
|
mode: str | None,
|
|
130
|
-
priority: int | None = 0,
|
|
131
114
|
**extra_kwargs,
|
|
132
|
-
) ->
|
|
115
|
+
) -> OperatorOutput:
|
|
133
116
|
"""
|
|
134
|
-
Execute the LLM pipeline with the given input text.
|
|
117
|
+
Execute the LLM pipeline with the given input text.
|
|
135
118
|
"""
|
|
136
119
|
try:
|
|
137
|
-
|
|
138
|
-
|
|
120
|
+
if logprobs and (not isinstance(top_logprobs, int) or top_logprobs < 2):
|
|
121
|
+
raise ValueError("top_logprobs should be an int greater than 1")
|
|
139
122
|
|
|
140
|
-
|
|
123
|
+
prompt_loader = PromptLoader()
|
|
141
124
|
prompt_configs = prompt_loader.load(
|
|
142
125
|
prompt_file=prompt_file,
|
|
143
126
|
text=text.strip(),
|
|
@@ -145,67 +128,45 @@ class Operator:
|
|
|
145
128
|
**extra_kwargs,
|
|
146
129
|
)
|
|
147
130
|
|
|
148
|
-
|
|
131
|
+
analysis: str | None = None
|
|
149
132
|
|
|
150
133
|
if with_analysis:
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
OperatorUtils.build_user_message(
|
|
154
|
-
f"Based on this analysis: {analysis}"
|
|
155
|
-
)
|
|
156
|
-
)
|
|
157
|
-
|
|
158
|
-
if output_lang:
|
|
159
|
-
messages.append(
|
|
160
|
-
OperatorUtils.build_user_message(
|
|
161
|
-
f"Respond only in the {output_lang} language."
|
|
162
|
-
)
|
|
134
|
+
analyze_message = OperatorUtils.build_message(
|
|
135
|
+
prompt_configs["analyze_template"]
|
|
163
136
|
)
|
|
137
|
+
analysis = self._analyze_completion(analyze_message)
|
|
164
138
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
f"Consider this instruction {user_prompt}"
|
|
169
|
-
)
|
|
139
|
+
main_message = OperatorUtils.build_message(
|
|
140
|
+
OperatorUtils.build_main_prompt(
|
|
141
|
+
prompt_configs["main_template"], analysis, output_lang, user_prompt
|
|
170
142
|
)
|
|
171
|
-
|
|
172
|
-
messages.append(
|
|
173
|
-
OperatorUtils.build_user_message(prompt_configs["main_template"])
|
|
174
143
|
)
|
|
175
144
|
|
|
176
|
-
messages = OperatorUtils.user_merge_format(messages)
|
|
177
|
-
|
|
178
|
-
if logprobs and (not isinstance(top_logprobs, int) or top_logprobs < 2):
|
|
179
|
-
raise ValueError("top_logprobs should be an integer greater than 1")
|
|
180
|
-
|
|
181
145
|
parsed, completion = self._parse_completion(
|
|
182
|
-
|
|
146
|
+
main_message,
|
|
147
|
+
output_model,
|
|
148
|
+
temperature,
|
|
149
|
+
logprobs,
|
|
150
|
+
top_logprobs,
|
|
151
|
+
priority,
|
|
183
152
|
)
|
|
184
153
|
|
|
185
|
-
output.result = parsed.result
|
|
186
|
-
|
|
187
154
|
# Retry logic if validation fails
|
|
188
|
-
if validator and not validator(
|
|
155
|
+
if validator and not validator(parsed.result):
|
|
189
156
|
if (
|
|
190
157
|
not isinstance(max_validation_retries, int)
|
|
191
158
|
or max_validation_retries < 1
|
|
192
159
|
):
|
|
193
|
-
raise ValueError(
|
|
194
|
-
"max_validation_retries should be a positive integer"
|
|
195
|
-
)
|
|
160
|
+
raise ValueError("max_validation_retries should be a positive int")
|
|
196
161
|
|
|
197
162
|
succeeded = False
|
|
198
|
-
for
|
|
199
|
-
|
|
200
|
-
f"Validation failed, retrying for the {attempt + 1} time."
|
|
201
|
-
)
|
|
202
|
-
|
|
203
|
-
# Generate new temperature for retry
|
|
163
|
+
for _ in range(max_validation_retries):
|
|
164
|
+
# Generate a new temperature to retry
|
|
204
165
|
retry_temperature = OperatorUtils.get_retry_temp(temperature)
|
|
205
166
|
|
|
206
167
|
try:
|
|
207
168
|
parsed, completion = self._parse_completion(
|
|
208
|
-
|
|
169
|
+
main_message,
|
|
209
170
|
output_model,
|
|
210
171
|
retry_temperature,
|
|
211
172
|
logprobs,
|
|
@@ -213,30 +174,26 @@ class Operator:
|
|
|
213
174
|
priority=priority,
|
|
214
175
|
)
|
|
215
176
|
|
|
216
|
-
output.result = parsed.result
|
|
217
|
-
|
|
218
177
|
# Check if retry was successful
|
|
219
|
-
if validator(
|
|
178
|
+
if validator(parsed.result):
|
|
220
179
|
succeeded = True
|
|
221
180
|
break
|
|
222
181
|
|
|
223
|
-
except LLMError
|
|
224
|
-
|
|
182
|
+
except LLMError:
|
|
183
|
+
pass
|
|
225
184
|
|
|
226
185
|
if not succeeded:
|
|
227
|
-
raise ValidationError(
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
output.process = prompt_file[:-5]
|
|
186
|
+
raise ValidationError("Validation failed after all retries")
|
|
187
|
+
|
|
188
|
+
operator_output = OperatorOutput(
|
|
189
|
+
result=parsed.result,
|
|
190
|
+
analysis=analysis if with_analysis else None,
|
|
191
|
+
logprobs=OperatorUtils.extract_logprobs(completion)
|
|
192
|
+
if logprobs
|
|
193
|
+
else None,
|
|
194
|
+
)
|
|
238
195
|
|
|
239
|
-
return
|
|
196
|
+
return operator_output
|
|
240
197
|
|
|
241
198
|
except (PromptError, LLMError, ValidationError):
|
|
242
199
|
raise
|
texttools/prompts/README.md
CHANGED
|
@@ -15,7 +15,7 @@ This folder contains YAML files for all prompts used in the project. Each file r
|
|
|
15
15
|
```yaml
|
|
16
16
|
main_template:
|
|
17
17
|
mode_1: |
|
|
18
|
-
Your main instructions here with placeholders like {
|
|
18
|
+
Your main instructions here with placeholders like {text}.
|
|
19
19
|
mode_2: |
|
|
20
20
|
Optional reasoning instructions here.
|
|
21
21
|
|
|
@@ -30,6 +30,6 @@ analyze_template:
|
|
|
30
30
|
|
|
31
31
|
## Guidelines
|
|
32
32
|
1. **Naming**: Use descriptive names for each YAML file corresponding to the tool or task it serves.
|
|
33
|
-
2. **Placeholders**: Use `{
|
|
33
|
+
2. **Placeholders**: Use `{text}` or other relevant placeholders to dynamically inject data.
|
|
34
34
|
3. **Modes**: If using modes, ensure both `main_template` and `analyze_template` contain the corresponding keys.
|
|
35
35
|
4. **Consistency**: Keep formatting consistent across files for easier parsing by scripts.
|