hamtaa-texttools 1.1.1__py3-none-any.whl → 1.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/METADATA +98 -26
- hamtaa_texttools-1.1.16.dist-info/RECORD +31 -0
- texttools/__init__.py +6 -8
- texttools/batch/batch_config.py +26 -0
- texttools/batch/batch_runner.py +105 -151
- texttools/batch/{batch_manager.py → internals/batch_manager.py} +39 -40
- texttools/batch/internals/utils.py +16 -0
- texttools/prompts/README.md +4 -4
- texttools/prompts/categorize.yaml +77 -0
- texttools/prompts/detect_entity.yaml +22 -0
- texttools/prompts/extract_keywords.yaml +68 -18
- texttools/tools/async_tools.py +804 -0
- texttools/tools/internals/async_operator.py +90 -69
- texttools/tools/internals/models.py +183 -0
- texttools/tools/internals/operator_utils.py +54 -0
- texttools/tools/internals/prompt_loader.py +13 -14
- texttools/tools/internals/sync_operator.py +201 -0
- texttools/tools/sync_tools.py +804 -0
- hamtaa_texttools-1.1.1.dist-info/RECORD +0 -30
- texttools/batch/__init__.py +0 -4
- texttools/prompts/categorizer.yaml +0 -28
- texttools/tools/__init__.py +0 -4
- texttools/tools/async_the_tool.py +0 -414
- texttools/tools/internals/base_operator.py +0 -91
- texttools/tools/internals/operator.py +0 -179
- texttools/tools/internals/output_models.py +0 -59
- texttools/tools/the_tool.py +0 -412
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/top_level.txt +0 -0
texttools/tools/internals/async_operator.py

@@ -1,23 +1,22 @@
-from typing import Any, TypeVar, Type
+from typing import Any, TypeVar, Type
+from collections.abc import Callable
 import logging
 
 from openai import AsyncOpenAI
 from pydantic import BaseModel
 
-from texttools.tools.internals.
-from texttools.tools.internals.
+from texttools.tools.internals.models import ToolOutput
+from texttools.tools.internals.operator_utils import OperatorUtils
 from texttools.tools.internals.formatters import Formatter
 from texttools.tools.internals.prompt_loader import PromptLoader
 
 # Base Model type for output models
 T = TypeVar("T", bound=BaseModel)
 
-
-logger = logging.getLogger("async_operator")
-logger.setLevel(logging.INFO)
+logger = logging.getLogger("texttools.async_operator")
 
 
-class AsyncOperator
+class AsyncOperator:
     """
     Core engine for running text-processing operations with an LLM (Async).
 
@@ -28,14 +27,18 @@ class AsyncOperator(BaseOperator):
     """
 
     def __init__(self, client: AsyncOpenAI, model: str):
-        self.
-        self.
+        self._client = client
+        self._model = model
 
     async def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
+        """
+        Calls OpenAI API for analysis using the configured prompt template.
+        Returns the analyzed content as a string.
+        """
         analyze_prompt = prompt_configs["analyze_template"]
-        analyze_message = [
-        completion = await self.
-        model=self.
+        analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
+        completion = await self._client.chat.completions.create(
+            model=self._model,
             messages=analyze_message,
             temperature=temperature,
         )
@@ -49,9 +52,14 @@ class AsyncOperator(BaseOperator):
         temperature: float,
         logprobs: bool = False,
         top_logprobs: int = 3,
-
+        priority: int | None = 0,
+    ) -> tuple[T, Any]:
+        """
+        Parses a chat completion using OpenAI's structured output format.
+        Returns both the parsed object and the raw completion for logprobs.
+        """
         request_kwargs = {
-            "model": self.
+            "model": self._model,
             "messages": message,
             "response_format": output_model,
             "temperature": temperature,
@@ -60,40 +68,12 @@ class AsyncOperator(BaseOperator):
         if logprobs:
             request_kwargs["logprobs"] = True
             request_kwargs["top_logprobs"] = top_logprobs
-
-
+        if priority:
+            request_kwargs["extra_body"] = {"priority": priority}
+        completion = await self._client.beta.chat.completions.parse(**request_kwargs)
         parsed = completion.choices[0].message.parsed
         return parsed, completion
 
-    async def _vllm_completion(
-        self,
-        message: list[dict[str, str]],
-        output_model: Type[T],
-        temperature: float,
-        logprobs: bool = False,
-        top_logprobs: int = 3,
-    ) -> tuple[Type[T], Any]:
-        json_schema = output_model.model_json_schema()
-
-        # Build kwargs dynamically
-        request_kwargs = {
-            "model": self.model,
-            "messages": message,
-            "extra_body": {"guided_json": json_schema},
-            "temperature": temperature,
-        }
-
-        if logprobs:
-            request_kwargs["logprobs"] = True
-            request_kwargs["top_logprobs"] = top_logprobs
-
-        completion = await self.client.chat.completions.create(**request_kwargs)
-        response = completion.choices[0].message.content
-
-        # Convert the string response to output model
-        parsed = self._convert_to_output_model(response, output_model)
-        return parsed, completion
-
     async def run(
         self,
         # User parameters
@@ -104,20 +84,24 @@ class AsyncOperator(BaseOperator):
         temperature: float,
         logprobs: bool,
         top_logprobs: int | None,
+        validator: Callable[[Any], bool] | None,
+        max_validation_retries: int | None,
         # Internal parameters
         prompt_file: str,
        output_model: Type[T],
-        resp_format: Literal["vllm", "parse"],
         mode: str | None,
+        priority: int | None = 0,
         **extra_kwargs,
-    ) ->
+    ) -> ToolOutput:
         """
         Execute the async LLM pipeline with the given input text. (Async)
         """
         prompt_loader = PromptLoader()
         formatter = Formatter()
+        output = ToolOutput()
 
         try:
+            # Prompt configs contain two keys: main_template and analyze template, both are string
             prompt_configs = prompt_loader.load(
                 prompt_file=prompt_file,
                 text=text.strip(),
@@ -125,55 +109,92 @@ class AsyncOperator(BaseOperator):
                 **extra_kwargs,
             )
 
-            messages
+            messages = []
 
             if with_analysis:
                 analysis = await self._analyze(prompt_configs, temperature)
                 messages.append(
-
+                    OperatorUtils.build_user_message(
+                        f"Based on this analysis: {analysis}"
+                    )
                 )
 
             if output_lang:
                 messages.append(
-
+                    OperatorUtils.build_user_message(
                         f"Respond only in the {output_lang} language."
                     )
                 )
 
             if user_prompt:
                 messages.append(
-
+                    OperatorUtils.build_user_message(
+                        f"Consider this instruction {user_prompt}"
+                    )
                 )
 
-            messages.append(
+            messages.append(
+                OperatorUtils.build_user_message(prompt_configs["main_template"])
+            )
+
             messages = formatter.user_merge_format(messages)
 
-
-
-
-            )
-            elif resp_format == "parse":
-                parsed, completion = await self._parse_completion(
-                    messages, output_model, temperature, logprobs, top_logprobs
-                )
+            parsed, completion = await self._parse_completion(
+                messages, output_model, temperature, logprobs, top_logprobs, priority
+            )
 
-
-            if not hasattr(parsed, "result"):
-                logger.error(
-                    "The provided output_model must define a field named 'result'"
-                )
+            output.result = parsed.result
 
-
+            # Retry logic if validation fails
+            if validator and not validator(output.result):
+                for attempt in range(max_validation_retries):
+                    logger.warning(
+                        f"Validation failed, retrying for the {attempt + 1} time."
+                    )
 
-
+                    # Generate new temperature for retry
+                    retry_temperature = OperatorUtils.get_retry_temp(temperature)
+                    try:
+                        parsed, completion = await self._parse_completion(
+                            messages,
+                            output_model,
+                            retry_temperature,
+                            logprobs,
+                            top_logprobs,
+                        )
+
+                        output.result = parsed.result
+
+                        # Check if retry was successful
+                        if validator(output.result):
+                            logger.info(
+                                f"Validation passed on retry attempt {attempt + 1}"
+                            )
+                            break
+                        else:
+                            logger.warning(
+                                f"Validation still failing after retry attempt {attempt + 1}"
+                            )
+
+                    except Exception as e:
+                        logger.error(f"Retry attempt {attempt + 1} failed: {e}")
+                        # Continue to next retry attempt if this one fails
+
+            # Final check after all retries
+            if validator and not validator(output.result):
+                output.errors.append("Validation failed after all retry attempts")
 
             if logprobs:
-                output.logprobs =
+                output.logprobs = OperatorUtils.extract_logprobs(completion)
 
             if with_analysis:
                 output.analysis = analysis
 
+            output.process = prompt_file[:-5]
+
             return output
+
         except Exception as e:
             logger.error(f"AsyncTheTool failed: {e}")
-
+            output.errors.append(str(e))
+            return output
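For orientation, a minimal sketch of how the reworked AsyncOperator.run() could be driven after this change. It assumes an AsyncOpenAI client, the extract_keywords.yaml prompt shipped in this release, and the ListStrOutput model from the new models.py; the model name and validator are illustrative, the user-facing argument names are inferred from the body of run() shown above, and most callers would go through the public sync_tools/async_tools wrappers instead:

import asyncio

from openai import AsyncOpenAI

from texttools.tools.internals.async_operator import AsyncOperator
from texttools.tools.internals.models import ListStrOutput


async def main() -> None:
    # Model name is illustrative, not prescribed by the package
    operator = AsyncOperator(client=AsyncOpenAI(), model="gpt-4o-mini")
    output = await operator.run(
        # User-facing parameter names below are inferred from run()'s body in the diff above
        text="Large language models extract keywords from text.",
        with_analysis=False,
        user_prompt=None,
        output_lang=None,
        temperature=0.2,
        logprobs=False,
        top_logprobs=None,
        validator=lambda result: len(result) > 0,  # retried with a fresh temperature when False
        max_validation_retries=2,
        prompt_file="extract_keywords.yaml",
        output_model=ListStrOutput,
        mode=None,
    )
    print(output.result, output.errors)


asyncio.run(main())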
texttools/tools/internals/models.py

@@ -0,0 +1,183 @@
+from datetime import datetime
+from typing import Type, Any, Literal
+
+from pydantic import BaseModel, Field, create_model
+
+
+class ToolOutput(BaseModel):
+    result: Any = None
+    analysis: str = ""
+    logprobs: list[dict[str, Any]] = []
+    process: str = ""
+    processed_at: datetime = datetime.now()
+    execution_time: float = -1.0
+    errors: list[str] = []
+
+    def __repr__(self) -> str:
+        return f"ToolOutput(process='{self.process}', result_type='{type(self.result)}', result='{self.result}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}', processed_at='{self.processed_at}', execution_time='{self.execution_time}'"
+
+
+class StrOutput(BaseModel):
+    result: str = Field(..., description="The output string")
+
+
+class BoolOutput(BaseModel):
+    result: bool = Field(
+        ..., description="Boolean indicating the output state", example=True
+    )
+
+
+class ListStrOutput(BaseModel):
+    result: list[str] = Field(
+        ..., description="The output list of strings", example=["text_1", "text_2"]
+    )
+
+
+class ListDictStrStrOutput(BaseModel):
+    result: list[dict[str, str]] = Field(
+        ...,
+        description="List of dictionaries containing string key-value pairs",
+        example=[{"text": "Mohammad", "type": "PER"}],
+    )
+
+
+class ReasonListStrOutput(BaseModel):
+    reason: str = Field(..., description="Thinking process that led to the output")
+    result: list[str] = Field(..., description="The output list of strings")
+
+
+class Node(BaseModel):
+    node_id: int
+    name: str
+    level: int
+    parent_id: int | None
+    description: str = "No description provided"
+
+
+class CategoryTree:
+    def __init__(self, tree_name):
+        self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
+        self.all_nodes: list[Node] = [self.root]
+        self.new_id = 1
+
+    def add_node(
+        self,
+        node_name: str,
+        parent_name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        if self.find_node(node_name):
+            raise ValueError(f"{node_name} has been chosen for another category before")
+
+        if parent_name:
+            parent_node = self.find_node(parent_name)
+            if parent_node is None:
+                raise ValueError(f"Parent category '{parent_name}' not found")
+            parent_id = parent_node.node_id
+            level = parent_node.level + 1
+        else:
+            level = 1
+            parent_id = 0
+
+        node_data = {
+            "node_id": self.new_id,
+            "name": node_name,
+            "level": level,
+            "parent_id": parent_id,
+        }
+
+        if description is not None:
+            node_data["description"] = description
+
+        self.all_nodes.append(Node(**node_data))
+        self.new_id += 1
+
+    def get_nodes(self) -> list[Node]:
+        return self.all_nodes
+
+    def get_level_count(self) -> int:
+        return max([item.level for item in self.all_nodes])
+
+    def find_node(self, identifier: int | str) -> Node | None:
+        if isinstance(identifier, str):
+            for node in self.get_nodes():
+                if node.name == identifier:
+                    return node
+            return None
+        elif isinstance(identifier, int):
+            for node in self.get_nodes():
+                if node.node_id == identifier:
+                    return node
+            return None
+        else:
+            return None
+
+    def find_children(self, parent_node: Node) -> list[Node] | None:
+        children = [
+            node for node in self.get_nodes() if parent_node.node_id == node.parent_id
+        ]
+        return children if children else None
+
+    def remove_node(self, identifier: int | str) -> None:
+        node = self.find_node(identifier)
+
+        if node is not None:
+            # Remove node's children recursively
+            children = self.find_children(node)
+
+            # Ending condition
+            if children is None:
+                self.all_nodes.remove(node)
+                return
+
+            for child in children:
+                self.remove_node(child.name)
+
+            # Remove the node from tree
+            self.all_nodes.remove(node)
+        else:
+            raise ValueError(f"Node with identifier: '{identifier}' not found.")
+
+    def dump_tree(self) -> dict:
+        def build_dict(node: Node) -> dict:
+            children = [
+                build_dict(child)
+                for child in self.all_nodes
+                if child.parent_id == node.node_id
+            ]
+            return {
+                "node_id": node.node_id,
+                "name": node.name,
+                "level": node.level,
+                "parent_id": node.parent_id,
+                "children": children,
+            }
+
+        return {"category_tree": build_dict(self.root)["children"]}
+
+
+# This function is needed to create CategorizerOutput with dynamic categories
+def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
+    literal_type = Literal[*allowed_values]
+
+    CategorizerOutput = create_model(
+        "CategorizerOutput",
+        reason=(
+            str,
+            Field(
+                ..., description="Explanation of why the input belongs to the category"
+            ),
+        ),
+        result=(literal_type, Field(..., description="Predicted category label")),
+    )
+
+    return CategorizerOutput
+
+
+class Entity(BaseModel):
+    text: str = Field(description="The exact text of the entity")
+    entity_type: str = Field(description="The type of the entity")
+
+
+class EntityDetectorOutput(BaseModel):
+    result: list[Entity] = Field(description="List of all extracted entities")
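A short usage sketch of the CategoryTree and create_dynamic_model helpers added above; the category names are made up for illustration:

from texttools.tools.internals.models import CategoryTree, create_dynamic_model

tree = CategoryTree("topics")                    # node_id 0 becomes the implicit root
tree.add_node("sports")                          # level-1 node attached to the root
tree.add_node("football", parent_name="sports")  # level-2 node
print(tree.get_level_count())                    # -> 2
print(tree.dump_tree())                          # nested dict under the "category_tree" key

# Response model whose `result` field is constrained to the allowed labels
CategorizerOutput = create_dynamic_model(["sports", "politics"])
print(CategorizerOutput(reason="talks about a match", result="sports"))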
texttools/tools/internals/operator_utils.py

@@ -0,0 +1,54 @@
+import re
+import math
+import random
+
+
+class OperatorUtils:
+    @staticmethod
+    def build_user_message(prompt: str) -> dict[str, str]:
+        return {"role": "user", "content": prompt}
+
+    @staticmethod
+    def extract_logprobs(completion: dict) -> list[dict]:
+        """
+        Extracts and filters token probabilities from completion logprobs.
+        Skips punctuation and structural tokens, returns cleaned probability data.
+        """
+        logprobs_data = []
+
+        ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
+
+        for choice in completion.choices:
+            if not getattr(choice, "logprobs", None):
+                return []
+
+            for logprob_item in choice.logprobs.content:
+                if ignore_pattern.match(logprob_item.token):
+                    continue
+                token_entry = {
+                    "token": logprob_item.token,
+                    "prob": round(math.exp(logprob_item.logprob), 8),
+                    "top_alternatives": [],
+                }
+                for alt in logprob_item.top_logprobs:
+                    if ignore_pattern.match(alt.token):
+                        continue
+                    token_entry["top_alternatives"].append(
+                        {
+                            "token": alt.token,
+                            "prob": round(math.exp(alt.logprob), 8),
+                        }
+                    )
+                logprobs_data.append(token_entry)
+
+        return logprobs_data
+
+    @staticmethod
+    def get_retry_temp(base_temp: float) -> float:
+        """
+        Calculate temperature for retry attempts.
+        """
+        delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
+        new_temp = base_temp + delta_temp
+
+        return max(0.0, min(new_temp, 1.5))
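A brief sketch of the two OperatorUtils helpers the retry path in async_operator.py relies on; the printed values are only examples, since get_retry_temp is randomized:

from texttools.tools.internals.operator_utils import OperatorUtils

# Message construction used throughout AsyncOperator.run()
print(OperatorUtils.build_user_message("Summarize this paragraph."))
# {'role': 'user', 'content': 'Summarize this paragraph.'}

# Retry temperature: base value nudged by a random +/-0.1..0.9 and clamped to [0.0, 1.5]
for _ in range(3):
    print(OperatorUtils.get_retry_temp(0.2))  # e.g. 0.31, 0.0, 1.05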
texttools/tools/internals/prompt_loader.py

@@ -11,19 +11,25 @@ class PromptLoader:
     - Load and parse YAML prompt definitions.
     - Select the right template (by mode, if applicable).
     - Inject variables (`{input}`, plus any extra kwargs) into the templates.
-    - Return a dict with:
-        {
-            "main_template": "...",
-            "analyze_template": "..." | None
-        }
     """
 
-    MAIN_TEMPLATE
-    ANALYZE_TEMPLATE
+    MAIN_TEMPLATE = "main_template"
+    ANALYZE_TEMPLATE = "analyze_template"
+
+    @staticmethod
+    def _build_format_args(text: str, **extra_kwargs) -> dict[str, str]:
+        # Base formatting args
+        format_args = {"input": text}
+        # Merge extras
+        format_args.update(extra_kwargs)
+        return format_args
 
     # Use lru_cache to load each file once
     @lru_cache(maxsize=32)
     def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
+        """
+        Loads prompt templates from YAML file with optional mode selection.
+        """
         base_dir = Path(__file__).parent.parent.parent / Path("prompts")
         prompt_path = base_dir / prompt_file
         data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))
@@ -37,13 +43,6 @@ class PromptLoader:
            else data.get(self.ANALYZE_TEMPLATE),
         }
 
-    def _build_format_args(self, text: str, **extra_kwargs) -> dict[str, str]:
-        # Base formatting args
-        format_args = {"input": text}
-        # Merge extras
-        format_args.update(extra_kwargs)
-        return format_args
-
     def load(
         self, prompt_file: str, text: str, mode: str, **extra_kwargs
     ) -> dict[str, str]: