hamtaa-texttools 1.0.1__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hamtaa-texttools might be problematic. Click here for more details.
- hamtaa_texttools-1.1.7.dist-info/METADATA +228 -0
- hamtaa_texttools-1.1.7.dist-info/RECORD +30 -0
- {hamtaa_texttools-1.0.1.dist-info → hamtaa_texttools-1.1.7.dist-info}/licenses/LICENSE +20 -20
- {hamtaa_texttools-1.0.1.dist-info → hamtaa_texttools-1.1.7.dist-info}/top_level.txt +0 -0
- texttools/__init__.py +4 -9
- texttools/batch/__init__.py +3 -0
- texttools/{utils/batch_manager → batch}/batch_manager.py +226 -240
- texttools/batch/batch_runner.py +254 -0
- texttools/prompts/README.md +35 -0
- texttools/prompts/categorizer.yaml +28 -0
- texttools/prompts/extract_entities.yaml +20 -0
- texttools/prompts/extract_keywords.yaml +18 -0
- texttools/prompts/is_question.yaml +14 -0
- texttools/prompts/merge_questions.yaml +46 -0
- texttools/prompts/rewrite.yaml +111 -0
- texttools/prompts/run_custom.yaml +7 -0
- texttools/prompts/subject_to_question.yaml +22 -0
- texttools/prompts/summarize.yaml +14 -0
- texttools/prompts/text_to_question.yaml +20 -0
- texttools/prompts/translate.yaml +15 -0
- texttools/tools/__init__.py +4 -3
- texttools/tools/async_the_tool.py +435 -0
- texttools/tools/internals/async_operator.py +242 -0
- texttools/tools/internals/base_operator.py +100 -0
- texttools/tools/internals/formatters.py +24 -0
- texttools/tools/internals/operator.py +242 -0
- texttools/tools/internals/output_models.py +62 -0
- texttools/tools/internals/prompt_loader.py +60 -0
- texttools/tools/the_tool.py +433 -291
- hamtaa_texttools-1.0.1.dist-info/METADATA +0 -129
- hamtaa_texttools-1.0.1.dist-info/RECORD +0 -18
- texttools/formatters/base_formatter.py +0 -33
- texttools/formatters/user_merge_formatter/user_merge_formatter.py +0 -47
- texttools/prompts/__init__.py +0 -0
- texttools/tools/operator.py +0 -236
- texttools/tools/output_models.py +0 -54
- texttools/tools/prompt_loader.py +0 -84
- texttools/utils/__init__.py +0 -4
- texttools/utils/batch_manager/__init__.py +0 -4
- texttools/utils/batch_manager/batch_runner.py +0 -212
- {hamtaa_texttools-1.0.1.dist-info → hamtaa_texttools-1.1.7.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from typing import TypeVar, Type, Any
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
import math
|
|
5
|
+
import logging
|
|
6
|
+
import random
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
from openai import OpenAI, AsyncOpenAI
|
|
10
|
+
|
|
11
|
+
# Type variable for output models: bound to pydantic's `BaseModel`, so it
# stands for any BaseModel subclass used as an output model.
T = TypeVar("T", bound=BaseModel)

# Logger used by the helpers in this module.
logger = logging.getLogger("texttools.base_operator")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseOperator:
    """
    Base class for the LLM operators: stores the client/model pair and provides
    shared helpers for building messages, parsing JSON replies into output
    models, extracting token probabilities, and jittering retry temperatures.
    """

    # Tokens that carry no content of their own: the literal `result` key and
    # runs of JSON punctuation/whitespace. Compiled once for all calls.
    _IGNORE_TOKEN_PATTERN = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')

    # Python-style booleans used as JSON *values* (preceded by `:`, `[` or `,`
    # and followed by `,`, `]` or `}`), so they can be lowered without touching
    # the same words inside string values.
    _PY_BOOL_PATTERN = re.compile(r"([:\[,]\s*)(True|False)(?=\s*[,\]}])")

    def __init__(self, client: OpenAI | AsyncOpenAI, model: str):
        self.client = client
        self.model = model

    def _build_user_message(self, prompt: str) -> dict[str, str]:
        """Wrap `prompt` in a chat message with the `user` role."""
        return {"role": "user", "content": prompt}

    def _clean_json_response(self, response: str) -> str:
        """
        Clean JSON response by removing code block markers and whitespace.

        Handles cases like:
        - ```json{"result": "value"}```
        - ```JSON ... ``` (the language tag is matched case-insensitively)
        """
        stripped = response.strip()
        cleaned = re.sub(r"^```(?:json)?\s*", "", stripped, flags=re.IGNORECASE)
        cleaned = re.sub(r"\s*```$", "", cleaned)
        return cleaned.strip()

    def _convert_to_output_model(self, response_string: str, output_model: Type[T]) -> T:
        """
        Convert a JSON response string into an instance of `output_model`.

        The string is first stripped of code fences and parsed as strict JSON.
        Only if that fails are Python-style `True`/`False` values lowered to
        JSON booleans. The earlier blanket replace also rewrote those words
        inside string values.

        Raises:
            json.JSONDecodeError: If the text is not valid JSON even after the
                boolean fix-up.
            Whatever `output_model(...)` raises for mismatched fields.
        """
        cleaned_json = self._clean_json_response(response_string)
        try:
            response_dict = json.loads(cleaned_json)
        except json.JSONDecodeError:
            fixed_json = self._PY_BOOL_PATTERN.sub(
                lambda m: m.group(1) + m.group(2).lower(), cleaned_json
            )
            response_dict = json.loads(fixed_json)
        return output_model(**response_dict)

    def _extract_logprobs(self, completion: Any) -> list[dict[str, Any]]:
        """
        Extract and filter token probabilities from a completion's logprobs.

        Tokens matching `_IGNORE_TOKEN_PATTERN` (the `result` key and JSON
        punctuation) are skipped, both as main tokens and as alternatives.

        Args:
            completion: A chat completion object exposing `.choices`, each
                with an optional `.logprobs.content` list of token entries.

        Returns:
            One dict per kept token: `token`, `prob` (exp(logprob) rounded to
            8 places) and `top_alternatives` (same shape, filtered). Returns
            [] and logs an error if any choice lacks logprobs.
        """
        logprobs_data: list[dict[str, Any]] = []

        for choice in completion.choices:
            choice_logprobs = getattr(choice, "logprobs", None)
            if not choice_logprobs:
                logger.error("logprobs is not avalible in the chosen model.")
                return []

            # `content` / `top_logprobs` can be None; treat as "no tokens".
            for logprob_item in choice_logprobs.content or []:
                if self._IGNORE_TOKEN_PATTERN.match(logprob_item.token):
                    continue
                alternatives = [
                    {"token": alt.token, "prob": round(math.exp(alt.logprob), 8)}
                    for alt in (logprob_item.top_logprobs or [])
                    if not self._IGNORE_TOKEN_PATTERN.match(alt.token)
                ]
                logprobs_data.append(
                    {
                        "token": logprob_item.token,
                        "prob": round(math.exp(logprob_item.logprob), 8),
                        "top_alternatives": alternatives,
                    }
                )

        return logprobs_data

    def _get_retry_temp(self, base_temp: float) -> float:
        """
        Calculate a jittered temperature for a retry attempt.

        Adds an offset with random sign and magnitude in [0.1, 0.9] to
        `base_temp`, then clamps the result to [0.0, 1.5].
        """
        delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
        new_temp = base_temp + delta_temp
        # Library code must not print to stdout; emit diagnostics via logging.
        logger.debug(
            "Base Temp: %s, Delta Temp: %s, New Temp: %s",
            base_temp,
            delta_temp,
            new_temp,
        )
        return max(0.0, min(new_temp, 1.5))
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
class Formatter:
    @staticmethod
    def user_merge_format(messages: list[dict[str, str]]) -> list[dict[str, str]]:
        """
        Collapse each run of consecutive user messages into a single message.

        Every message's content is whitespace-stripped. Within a run of user
        messages, the contents are joined with a newline. Messages with any
        other role (assistant, system, ...) are kept as separate entries and
        end the current user run. New dicts are built, so the caller's list
        and dicts are left untouched.

        Args:
            messages: Chat messages, each with "role" and "content" keys.

        Returns:
            A new list of {"role": ..., "content": ...} dicts.
        """
        collapsed: list[dict[str, str]] = []

        for entry in messages:
            speaker = entry["role"]
            body = entry["content"].strip()

            continues_user_run = (
                speaker == "user"
                and bool(collapsed)
                and collapsed[-1]["role"] == "user"
            )
            if not continues_user_run:
                # Start a fresh turn with its own dict.
                collapsed.append({"role": speaker, "content": body})
                continue

            # Extend the user turn we are already building.
            collapsed[-1]["content"] = f'{collapsed[-1]["content"]}\n{body}'

        return collapsed
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
from typing import Any, TypeVar, Type, Literal, Callable
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from openai import OpenAI
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from texttools.tools.internals.output_models import ToolOutput
|
|
8
|
+
from texttools.tools.internals.base_operator import BaseOperator
|
|
9
|
+
from texttools.tools.internals.formatters import Formatter
|
|
10
|
+
from texttools.tools.internals.prompt_loader import PromptLoader
|
|
11
|
+
|
|
12
|
+
# Type variable for output models: bound to pydantic's `BaseModel`, so it
# stands for any BaseModel subclass passed as `output_model`.
T = TypeVar("T", bound=BaseModel)

# Logger used by the synchronous operator in this module.
logger = logging.getLogger("texttools.operator")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Operator(BaseOperator):
    """
    Core engine for running text-processing operations with an LLM (Sync).

    It wires together:
    - `PromptLoader` → loads YAML prompt templates.
    - `Formatter` → applies formatting to messages (merges consecutive user turns).
    - OpenAI client → executes completions/parsed completions.
    """

    # Number of extra completions `run` requests when `validator` rejects a result.
    MAX_VALIDATION_RETRIES: int = 3

    def __init__(self, client: OpenAI, model: str):
        super().__init__(client, model)

    @staticmethod
    def _add_logprob_options(
        request_kwargs: dict[str, Any], logprobs: bool, top_logprobs: int | None
    ) -> None:
        """Enable token logprobs on `request_kwargs` in place when requested."""
        if not logprobs:
            return
        request_kwargs["logprobs"] = True
        # `run` may pass top_logprobs=None; only forward a concrete value.
        if top_logprobs is not None:
            request_kwargs["top_logprobs"] = top_logprobs

    def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
        """
        Run the prompt's analysis step and return the model's reply.

        Args:
            prompt_configs: Templates from `PromptLoader.load`; the
                "analyze_template" entry is sent as a single user message.
            temperature: Sampling temperature for the request.

        Returns:
            The reply text stripped of surrounding whitespace ("" if the
            response carried no content).

        Raises:
            ValueError: If the prompt defines no analyze template.
        """
        analyze_prompt = prompt_configs.get("analyze_template")
        if analyze_prompt is None:
            raise ValueError("This prompt does not define an 'analyze_template'")

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[self._build_user_message(analyze_prompt)],
            temperature=temperature,
        )
        # `message.content` may be None; treat that as an empty analysis.
        return (completion.choices[0].message.content or "").strip()

    def _parse_completion(
        self,
        message: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool = False,
        top_logprobs: int | None = 3,
    ) -> tuple[T, Any]:
        """
        Request a completion via OpenAI's structured-output `parse` endpoint.

        Returns:
            (parsed output model instance or None, raw completion).
        """
        request_kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": message,
            "response_format": output_model,
            "temperature": temperature,
        }
        self._add_logprob_options(request_kwargs, logprobs, top_logprobs)

        completion = self.client.beta.chat.completions.parse(**request_kwargs)
        parsed = completion.choices[0].message.parsed
        return parsed, completion

    def _vllm_completion(
        self,
        message: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool = False,
        top_logprobs: int | None = 3,
    ) -> tuple[T, Any]:
        """
        Request a completion from a vLLM server with JSON-schema guidance.

        The model's JSON schema is sent as `guided_json` in `extra_body`, and
        the reply text is converted to `output_model`.

        Returns:
            (parsed output model instance, raw completion).
        """
        request_kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": message,
            "extra_body": {"guided_json": output_model.model_json_schema()},
            "temperature": temperature,
        }
        self._add_logprob_options(request_kwargs, logprobs, top_logprobs)

        completion = self.client.chat.completions.create(**request_kwargs)
        response = completion.choices[0].message.content
        parsed = self._convert_to_output_model(response, output_model)
        return parsed, completion

    def _complete(
        self,
        resp_format: str,
        messages: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool,
        top_logprobs: int | None,
    ) -> tuple[T, Any]:
        """
        Dispatch to the completion backend selected by `resp_format`.

        Raises:
            ValueError: If `resp_format` is neither "vllm" nor "parse".
        """
        if resp_format == "vllm":
            return self._vllm_completion(
                messages, output_model, temperature, logprobs, top_logprobs
            )
        if resp_format == "parse":
            return self._parse_completion(
                messages, output_model, temperature, logprobs, top_logprobs
            )
        raise ValueError(
            f"Unsupported resp_format: {resp_format!r} (expected 'vllm' or 'parse')"
        )

    def _retry_until_valid(
        self,
        output: ToolOutput,
        validator: Callable[[Any], bool],
        resp_format: str,
        messages: list[dict[str, str]],
        output_model: Type[T],
        temperature: float,
        logprobs: bool,
        top_logprobs: int | None,
        completion: Any,
    ) -> Any:
        """
        Re-request the completion with jittered temperatures until `validator`
        accepts the result or `MAX_VALIDATION_RETRIES` attempts are used.

        `output.result` is updated in place. The returned completion is always
        the one that produced `output.result`, so logprobs stay consistent
        with the result. A failing attempt is logged and skipped.
        """
        for attempt in range(1, self.MAX_VALIDATION_RETRIES + 1):
            logger.warning(f"Validation failed, retrying for the {attempt} time.")
            retry_temperature = self._get_retry_temp(temperature)
            try:
                parsed, retry_completion = self._complete(
                    resp_format,
                    messages,
                    output_model,
                    retry_temperature,
                    logprobs,
                    top_logprobs,
                )
                # Read `result` before touching state so a malformed reply
                # leaves result and completion from the same response.
                new_result = parsed.result
                output.result = new_result
                completion = retry_completion

                if validator(output.result):
                    logger.info(f"Validation passed on retry attempt {attempt}")
                    break
                logger.warning(
                    f"Validation still failing after retry attempt {attempt}"
                )
            except Exception as e:
                # Best effort: move on to the next retry attempt.
                logger.error(f"Retry attempt {attempt} failed: {e}")
        return completion

    def run(
        self,
        # User parameters
        text: str,
        with_analysis: bool,
        output_lang: str | None,
        user_prompt: str | None,
        temperature: float,
        logprobs: bool,
        top_logprobs: int | None,
        validator: Callable[[Any], bool] | None,
        # Internal parameters
        prompt_file: str,
        output_model: Type[T],
        resp_format: Literal["vllm", "parse"],
        mode: str | None,
        **extra_kwargs,
    ) -> ToolOutput:
        """
        Execute the LLM pipeline with the given input text.

        Steps:
        1. Load the prompt templates and fill in the input and extra kwargs.
        2. Optionally run the analysis step.
        3. Build the user messages and merge them.
        4. Request a structured completion.
        5. If `validator` rejects the result, retry up to
           `MAX_VALIDATION_RETRIES` times.
        6. Attach logprobs and analysis when requested.

        This method does not raise. Any exception is logged and appended to
        `ToolOutput.errors`, and the (possibly partial) output is returned.
        """
        prompt_loader = PromptLoader()
        formatter = Formatter()
        output = ToolOutput()

        try:
            # Prompt configs hold "main_template" and "analyze_template" strings.
            prompt_configs = prompt_loader.load(
                prompt_file=prompt_file,
                text=text.strip(),
                mode=mode,
                **extra_kwargs,
            )

            analysis = ""
            messages: list[dict[str, str]] = []

            if with_analysis:
                analysis = self._analyze(prompt_configs, temperature)
                messages.append(
                    self._build_user_message(f"Based on this analysis: {analysis}")
                )

            if output_lang:
                messages.append(
                    self._build_user_message(
                        f"Respond only in the {output_lang} language."
                    )
                )

            if user_prompt:
                messages.append(
                    self._build_user_message(f"Consider this instruction {user_prompt}")
                )

            messages.append(self._build_user_message(prompt_configs["main_template"]))
            messages = formatter.user_merge_format(messages)

            parsed, completion = self._complete(
                resp_format, messages, output_model, temperature, logprobs, top_logprobs
            )

            # `parse` yields None when the model returns no structured output
            # (e.g. a refusal); report that instead of a misleading schema error.
            if parsed is None:
                error = "The model returned no parsed output"
                logger.error(error)
                output.errors.append(error)
                return output

            # Ensure output_model has a `result` field
            if not hasattr(parsed, "result"):
                error = "The provided output_model must define a field named 'result'"
                logger.error(error)
                output.errors.append(error)
                return output

            output.result = parsed.result

            if validator and not validator(output.result):
                completion = self._retry_until_valid(
                    output,
                    validator,
                    resp_format,
                    messages,
                    output_model,
                    temperature,
                    logprobs,
                    top_logprobs,
                    completion,
                )
                # Final check after all retries
                if not validator(output.result):
                    output.errors.append("Validation failed after all retry attempts")

            if logprobs:
                output.logprobs = self._extract_logprobs(completion)

            if with_analysis:
                output.analysis = analysis

            return output

        except Exception as e:
            logger.error(f"TheTool failed: {e}")
            output.errors.append(str(e))
            return output
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from typing import Literal, Any
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ToolOutput(BaseModel):
    """
    Result container returned by the operators' `run` method.

    Attributes:
        result: The `result` field of the parsed output model (None until set).
        analysis: Text from the optional analysis step ("" if not requested).
        logprobs: Per-token probability entries when logprobs were requested.
        errors: Error messages collected during the run.
    """

    result: Any = None
    analysis: str = ""
    # default_factory makes the "fresh list per instance" intent explicit.
    logprobs: list[dict[str, Any]] = Field(default_factory=list)
    errors: list[str] = Field(default_factory=list)

    def __repr__(self) -> str:
        # The previous version omitted the closing parenthesis.
        return (
            f"ToolOutput(result_type='{type(self.result)}', result='{self.result}', "
            f"analysis='{self.analysis}', logprobs='{self.logprobs}', "
            f"errors='{self.errors}')"
        )
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class StrOutput(BaseModel):
    # Output model whose `result` is a single string.
    # NOTE: use comments, not a docstring: pydantic would copy a class
    # docstring into the JSON schema's "description", which is sent to the LLM.
    result: str = Field(..., description="The output string")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BoolOutput(BaseModel):
    # Output model whose `result` is a single boolean.
    # NOTE: comments only; a class docstring would alter the generated JSON schema.
    result: bool = Field(
        ..., description="Boolean indicating the output state", example=True
    )
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ListStrOutput(BaseModel):
    # Output model whose `result` is a list of strings.
    # NOTE: comments only; a class docstring would alter the generated JSON schema.
    result: list[str] = Field(
        ..., description="The output list of strings", example=["text_1", "text_2"]
    )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ListDictStrStrOutput(BaseModel):
    # Output model whose `result` is a list of string-to-string dicts.
    # The example shows an entity-like record with "text" and "type" keys.
    # NOTE: comments only; a class docstring would alter the generated JSON schema.
    result: list[dict[str, str]] = Field(
        ...,
        description="List of dictionaries containing string key-value pairs",
        example=[{"text": "Mohammad", "type": "PER"}],
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ReasonListStrOutput(BaseModel):
    # Output model with a free-text `reason` plus a list-of-strings `result`.
    # `reason` is declared before `result`. Presumably this is so the model
    # writes its reasoning first; confirm before reordering fields.
    # NOTE: comments only; a class docstring would alter the generated JSON schema.
    reason: str = Field(..., description="Thinking process that led to the output")
    result: list[str] = Field(..., description="The output list of strings")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class CategorizerOutput(BaseModel):
    # Output model for categorization: a `reason` plus a `result` restricted
    # to a fixed set of Persian category labels (English glosses below).
    # NOTE: comments only; a class docstring would alter the generated JSON
    # schema. The label strings are the schema's allowed values; do not edit them.
    reason: str = Field(
        ..., description="Explanation of why the input belongs to the category"
    )
    result: Literal[
        "باورهای دینی",  # religious beliefs
        "اخلاق اسلامی",  # Islamic ethics
        "احکام و فقه",  # religious rulings and jurisprudence
        "تاریخ اسلام و شخصیت ها",  # Islamic history and figures
        "منابع دینی",  # religious sources
        "دین و جامعه/سیاست",  # religion and society/politics
        "عرفان و معنویت",  # mysticism and spirituality
        "هیچکدام",  # none of the above
    ] = Field(
        ...,
        description="Predicted category label",
        example="اخلاق اسلامی",
    )
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
import yaml
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PromptLoader:
|
|
7
|
+
"""
|
|
8
|
+
Utility for loading and formatting YAML prompt templates.
|
|
9
|
+
|
|
10
|
+
Responsibilities:
|
|
11
|
+
- Load and parse YAML prompt definitions.
|
|
12
|
+
- Select the right template (by mode, if applicable).
|
|
13
|
+
- Inject variables (`{input}`, plus any extra kwargs) into the templates.
|
|
14
|
+
- Return a dict with:
|
|
15
|
+
{
|
|
16
|
+
"main_template": "...",
|
|
17
|
+
"analyze_template": "..." | None
|
|
18
|
+
}
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
MAIN_TEMPLATE: str = "main_template"
|
|
22
|
+
ANALYZE_TEMPLATE: str = "analyze_template"
|
|
23
|
+
|
|
24
|
+
# Use lru_cache to load each file once
|
|
25
|
+
@lru_cache(maxsize=32)
|
|
26
|
+
def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
|
|
27
|
+
"""
|
|
28
|
+
Loads prompt templates from YAML file with optional mode selection.
|
|
29
|
+
"""
|
|
30
|
+
base_dir = Path(__file__).parent.parent.parent / Path("prompts")
|
|
31
|
+
prompt_path = base_dir / prompt_file
|
|
32
|
+
data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))
|
|
33
|
+
|
|
34
|
+
return {
|
|
35
|
+
self.MAIN_TEMPLATE: data[self.MAIN_TEMPLATE][mode]
|
|
36
|
+
if mode
|
|
37
|
+
else data[self.MAIN_TEMPLATE],
|
|
38
|
+
self.ANALYZE_TEMPLATE: data.get(self.ANALYZE_TEMPLATE)[mode]
|
|
39
|
+
if mode
|
|
40
|
+
else data.get(self.ANALYZE_TEMPLATE),
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
def _build_format_args(self, text: str, **extra_kwargs) -> dict[str, str]:
|
|
44
|
+
# Base formatting args
|
|
45
|
+
format_args = {"input": text}
|
|
46
|
+
# Merge extras
|
|
47
|
+
format_args.update(extra_kwargs)
|
|
48
|
+
return format_args
|
|
49
|
+
|
|
50
|
+
def load(
|
|
51
|
+
self, prompt_file: str, text: str, mode: str, **extra_kwargs
|
|
52
|
+
) -> dict[str, str]:
|
|
53
|
+
template_configs = self._load_templates(prompt_file, mode)
|
|
54
|
+
format_args = self._build_format_args(text, **extra_kwargs)
|
|
55
|
+
|
|
56
|
+
# Inject variables inside each template
|
|
57
|
+
for key in template_configs.keys():
|
|
58
|
+
template_configs[key] = template_configs[key].format(**format_args)
|
|
59
|
+
|
|
60
|
+
return template_configs
|