hamtaa-texttools 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hamtaa-texttools might be problematic.

Files changed (30)
  1. {hamtaa_texttools-1.0.2.dist-info → hamtaa_texttools-1.0.4.dist-info}/METADATA +18 -6
  2. hamtaa_texttools-1.0.4.dist-info/RECORD +29 -0
  3. texttools/__init__.py +3 -3
  4. texttools/{utils/batch_manager → batch}/batch_runner.py +1 -1
  5. texttools/formatters/user_merge_formatter/user_merge_formatter.py +0 -17
  6. texttools/prompts/README.md +5 -5
  7. texttools/prompts/categorizer.yaml +16 -10
  8. texttools/prompts/keyword_extractor.yaml +4 -1
  9. texttools/prompts/ner_extractor.yaml +4 -1
  10. texttools/prompts/question_detector.yaml +5 -2
  11. texttools/prompts/question_generator.yaml +4 -3
  12. texttools/prompts/question_merger.yaml +6 -4
  13. texttools/prompts/question_rewriter.yaml +6 -4
  14. texttools/prompts/subject_question_generator.yaml +3 -4
  15. texttools/prompts/summarizer.yaml +1 -0
  16. texttools/prompts/translator.yaml +1 -0
  17. texttools/tools/__init__.py +2 -1
  18. texttools/tools/async_the_tool.py +263 -0
  19. texttools/tools/internals/async_operator.py +288 -0
  20. texttools/tools/{operator.py → internals/operator.py} +133 -63
  21. texttools/tools/{output_models.py → internals/output_models.py} +8 -0
  22. texttools/tools/{prompt_loader.py → internals/prompt_loader.py} +16 -18
  23. texttools/tools/the_tool.py +181 -72
  24. hamtaa_texttools-1.0.2.dist-info/RECORD +0 -28
  25. texttools/utils/__init__.py +0 -4
  26. {hamtaa_texttools-1.0.2.dist-info → hamtaa_texttools-1.0.4.dist-info}/WHEEL +0 -0
  27. {hamtaa_texttools-1.0.2.dist-info → hamtaa_texttools-1.0.4.dist-info}/licenses/LICENSE +0 -0
  28. {hamtaa_texttools-1.0.2.dist-info → hamtaa_texttools-1.0.4.dist-info}/top_level.txt +0 -0
  29. /texttools/{utils/batch_manager → batch}/__init__.py +0 -0
  30. /texttools/{utils/batch_manager → batch}/batch_manager.py +0 -0
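
Note on the layout change: the batch helpers move from texttools/utils/batch_manager to texttools/batch, and the operator, output-model, and prompt-loader modules move under texttools/tools/internals, so code importing the old paths needs updating. A minimal sketch of the before/after imports, assuming these modules are imported directly (the BatchManager class name is an assumption; Operator and AsyncOperator appear in the diffs below):

# 1.0.2-style imports (old paths, shown for contrast)
# from texttools.utils.batch_manager.batch_manager import BatchManager   # hypothetical class name
# from texttools.tools.operator import Operator

# 1.0.4-style imports (new paths from the file list above)
from texttools.batch.batch_manager import BatchManager               # hypothetical class name
from texttools.tools.internals.operator import Operator
from texttools.tools.internals.async_operator import AsyncOperator   # new in 1.0.4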
texttools/tools/internals/async_operator.py
@@ -0,0 +1,288 @@
+ from __future__ import annotations
+
+ import json
+ import math
+ import re
+ from typing import Any, Literal, Optional, TypeVar
+
+ from openai import AsyncOpenAI
+ from pydantic import BaseModel
+
+ from texttools.formatters.user_merge_formatter.user_merge_formatter import (
+     UserMergeFormatter,
+ )
+ from texttools.tools.internals.prompt_loader import PromptLoader
+
+ # Base Model type for output models
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class AsyncOperator:
+     """
+     Async version of Operator.
+
+     Behaves like the synchronous Operator but uses AsyncOpenAI and async/await.
+     """
+
+     def __init__(
+         self,
+         client: AsyncOpenAI,
+         *,
+         model: str,
+         temperature: float = 0.0,
+         **client_kwargs: Any,
+     ):
+         self.client: AsyncOpenAI = client
+         self.model = model
+         self.temperature = temperature
+         self.client_kwargs = client_kwargs
+
+     def _build_user_message(self, prompt: str) -> dict[str, str]:
+         return {"role": "user", "content": prompt}
+
+     async def _analysis_completion(self, analyze_message: list[dict[str, str]]) -> str:
+         try:
+             completion = await self.client.chat.completions.create(
+                 model=self.model,
+                 messages=analyze_message,
+                 temperature=self.temperature,
+                 **self.client_kwargs,
+             )
+             analysis = completion.choices[0].message.content.strip()
+             return analysis
+
+         except Exception as e:
+             print(f"[ERROR] Analysis failed: {e}")
+             raise
+
+     async def _analyze(self, prompt_configs: dict[str, str]) -> str:
+         analyze_prompt = prompt_configs["analyze_template"]
+         analyze_message = [self._build_user_message(analyze_prompt)]
+         analysis = await self._analysis_completion(analyze_message)
+
+         return analysis
+
+     async def _parse_completion(
+         self,
+         message: list[dict[str, str]],
+         output_model: T,
+         logprobs: bool = False,
+         top_logprobs: int = 3,
+     ) -> tuple[T, Any]:
+         try:
+             request_kwargs = {
+                 "model": self.model,
+                 "messages": message,
+                 "response_format": output_model,
+                 "temperature": self.temperature,
+                 **self.client_kwargs,
+             }
+             if logprobs:
+                 request_kwargs["logprobs"] = True
+                 request_kwargs["top_logprobs"] = top_logprobs
+
+             completion = await self.client.beta.chat.completions.parse(**request_kwargs)
+             parsed = completion.choices[0].message.parsed
+             return parsed, completion
+
+         except Exception as e:
+             print(f"[ERROR] Failed to parse completion: {e}")
+             raise
+
+     def _clean_json_response(self, response: str) -> str:
+         """
+         Clean JSON response by removing code block markers and whitespace.
+         Handles cases like:
+         - ```json{"result": "value"}```
+         """
+         cleaned = response.strip()
+
+         # Remove ```json marker
+         if cleaned.startswith("```json"):
+             cleaned = cleaned[7:]
+
+         # Remove trailing ```
+         if cleaned.endswith("```"):
+             cleaned = cleaned[:-3]
+
+         return cleaned.strip()
+
+     def _convert_to_output_model(self, response_string: str, output_model: T) -> T:
+         """
+         Convert a JSON response string to output model.
+
+         Args:
+             response_string: The JSON string (may contain code block markers)
+             output_model: Your Pydantic output model class (e.g., StrOutput, ListStrOutput)
+
+         Returns:
+             Instance of your output model
+         """
+         try:
+             # Clean the response string
+             cleaned_json = self._clean_json_response(response_string)
+
+             # Fix Python-style booleans
+             cleaned_json = cleaned_json.replace("False", "false").replace(
+                 "True", "true"
+             )
+
+             # Convert string to Python dictionary
+             response_dict = json.loads(cleaned_json)
+
+             # Convert dictionary to output model
+             return output_model(**response_dict)
+
+         except json.JSONDecodeError as e:
+             raise ValueError(
+                 f"Failed to parse JSON response: {e}\nResponse: {response_string}"
+             )
+         except Exception as e:
+             raise ValueError(f"Failed to convert to output model: {e}")
+
+     async def _vllm_completion(
+         self,
+         message: list[dict[str, str]],
+         output_model: T,
+         logprobs: bool = False,
+         top_logprobs: int = 3,
+     ) -> tuple[T, Any]:
+         try:
+             json_schema = output_model.model_json_schema()
+
+             # Build kwargs dynamically
+             request_kwargs = {
+                 "model": self.model,
+                 "messages": message,
+                 "extra_body": {"guided_json": json_schema},
+                 "temperature": self.temperature,
+                 **self.client_kwargs,
+             }
+
+             if logprobs:
+                 request_kwargs["logprobs"] = True
+                 request_kwargs["top_logprobs"] = top_logprobs
+
+             completion = await self.client.chat.completions.create(**request_kwargs)
+             response = completion.choices[0].message.content
+
+             # Convert the string response to output model
+             parsed = self._convert_to_output_model(response, output_model)
+
+             return parsed, completion
+
+         except Exception as e:
+             print(f"[ERROR] Failed to get vLLM structured output: {e}")
+             raise
+
+     def _extract_logprobs(self, completion: dict):
+         logprobs_data = []
+         ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
+
+         for choice in completion.choices:
+             if not getattr(choice, "logprobs", None):
+                 continue
+
+             for logprob_item in choice.logprobs.content:
+                 if ignore_pattern.match(logprob_item.token):
+                     continue
+                 token_entry = {
+                     "token": logprob_item.token,
+                     "prob": round(math.exp(logprob_item.logprob), 8),
+                     "top_alternatives": [],
+                 }
+                 for alt in logprob_item.top_logprobs:
+                     if ignore_pattern.match(alt.token):
+                         continue
+                     token_entry["top_alternatives"].append(
+                         {
+                             "token": alt.token,
+                             "prob": round(math.exp(alt.logprob), 8),
+                         }
+                     )
+                 logprobs_data.append(token_entry)
+
+         return logprobs_data
+
+     async def run(
+         self,
+         input_text: str,
+         prompt_file: str,
+         output_model: T,
+         with_analysis: bool = False,
+         use_modes: bool = False,
+         mode: str = "",
+         resp_format: Literal["vllm", "parse"] = "parse",
+         output_lang: Optional[str] = None,
+         logprobs: bool = False,
+         top_logprobs: int = 3,
+         **extra_kwargs,
+     ) -> dict[str, Any]:
+         """
+         Execute the async LLM pipeline with the given input text.
+
+         Args:
+             input_text: The text to process (will be stripped of whitespace)
+             **extra_kwargs: Additional variables to inject into prompt templates
+
+         Returns:
+             Dictionary containing the parsed result and optional analysis
+         """
+         prompt_loader = PromptLoader()
+         formatter = UserMergeFormatter()
+
+         try:
+             cleaned_text = input_text.strip()
+
+             prompt_configs = prompt_loader.load_prompts(
+                 prompt_file,
+                 use_modes,
+                 mode,
+                 cleaned_text,
+                 **extra_kwargs,
+             )
+
+             messages: list[dict[str, str]] = []
+
+             if with_analysis:
+                 analysis = await self._analyze(prompt_configs)
+                 messages.append(
+                     self._build_user_message(f"Based on this analysis: {analysis}")
+                 )
+
+             if output_lang:
+                 messages.append(
+                     self._build_user_message(
+                         f"Respond only in the {output_lang} language."
+                     )
+                 )
+
+             messages.append(self._build_user_message(prompt_configs["main_template"]))
+
+             messages = formatter.format(messages)
+
+             if resp_format == "vllm":
+                 parsed, completion = await self._vllm_completion(
+                     messages, output_model, logprobs, top_logprobs
+                 )
+             elif resp_format == "parse":
+                 parsed, completion = await self._parse_completion(
+                     messages, output_model, logprobs, top_logprobs
+                 )
+             else:
+                 raise ValueError(f"Unknown resp_format: {resp_format}")
+
+             results = {"result": parsed.result}
+
+             if logprobs:
+                 results["logprobs"] = self._extract_logprobs(completion)
+
+             if with_analysis:
+                 results["analysis"] = analysis
+
+             return results
+
+         except Exception as e:
+             # Print error clearly and re-raise for the caller to handle
+             print(f"[ERROR] Async operation failed: {e}")
+             raise
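
For orientation, a minimal usage sketch of the AsyncOperator added above. The constructor and run() signature come from the diff; the model name, the prompt file choice (summarizer.yaml from texttools/prompts/), and the StrOutput import path are illustrative assumptions:

import asyncio

from openai import AsyncOpenAI

from texttools.tools.internals.async_operator import AsyncOperator
from texttools.tools.internals.output_models import StrOutput  # path assumed from the file list


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    op = AsyncOperator(client, model="gpt-4o-mini", temperature=0.0)  # model name is illustrative

    # run() loads the YAML prompt, optionally prepends an analysis step and an
    # output-language hint, then parses the completion into the Pydantic model.
    result = await op.run(
        "A long passage to summarize ...",
        prompt_file="summarizer.yaml",   # file name taken from texttools/prompts/
        output_model=StrOutput,
        resp_format="parse",             # "vllm" switches to guided_json decoding
        logprobs=True,                   # adds a "logprobs" key to the returned dict
    )
    print(result["result"])
    print(result.get("logprobs"))


asyncio.run(main())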
texttools/tools/{operator.py → internals/operator.py}
@@ -1,6 +1,8 @@
  from __future__ import annotations

- from typing import Any, TypeVar, Type, Literal
+ import math
+ import re
+ from typing import Any, TypeVar, Literal, Optional
  import json

  from openai import OpenAI
@@ -9,7 +11,7 @@ from pydantic import BaseModel
  from texttools.formatters.user_merge_formatter.user_merge_formatter import (
      UserMergeFormatter,
  )
- from texttools.tools.prompt_loader import PromptLoader
+ from texttools.tools.internals.prompt_loader import PromptLoader

  # Base Model type for output models
  T = TypeVar("T", bound=BaseModel)
@@ -42,13 +44,6 @@ class Operator:
      - RESP_FORMAT: str → "vllm" or "parse"
      """

-     PROMPT_FILE: str
-     OUTPUT_MODEL: Type[T]
-     WITH_ANALYSIS: bool = False
-     USE_MODES: bool
-     MODE: str = ""
-     RESP_FORMAT: Literal["vllm", "parse"] = "vllm"
-
      def __init__(
          self,
          client: OpenAI,
@@ -59,17 +54,12 @@ class Operator:
      ):
          self.client: OpenAI = client
          self.model = model
-         self.prompt_loader = PromptLoader()
-         self.formatter = UserMergeFormatter()
          self.temperature = temperature
          self.client_kwargs = client_kwargs

      def _build_user_message(self, prompt: str) -> dict[str, str]:
          return {"role": "user", "content": prompt}

-     def _apply_formatter(self, messages: list[dict[str, str]]) -> list[dict[str, str]]:
-         return self.formatter.format(messages)
-
      def _analysis_completion(self, analyze_message: list[dict[str, str]]) -> str:
          try:
              completion = self.client.chat.completions.create(
@@ -85,30 +75,35 @@ class Operator:
              print(f"[ERROR] Analysis failed: {e}")
              raise

-     def _analyze(self) -> str:
-         analyze_prompt = self.prompt_configs["analyze_template"]
+     def _analyze(self, prompt_configs: dict[str, str]) -> str:
+         analyze_prompt = prompt_configs["analyze_template"]
          analyze_message = [self._build_user_message(analyze_prompt)]
          analysis = self._analysis_completion(analyze_message)

          return analysis

-     def _build_main_message(self) -> list[dict[str, str]]:
-         main_prompt = self.prompt_configs["main_template"]
-         main_message = self._build_user_message(main_prompt)
-
-         return main_message
-
-     def _parse_completion(self, message: list[dict[str, str]]) -> T:
+     def _parse_completion(
+         self,
+         message: list[dict[str, str]],
+         output_model: T,
+         logprobs: bool = False,
+         top_logprobs: int = 3,
+     ) -> tuple[T, Any]:
          try:
-             completion = self.client.beta.chat.completions.parse(
-                 model=self.model,
-                 messages=message,
-                 response_format=self.OUTPUT_MODEL,
-                 temperature=self.temperature,
+             request_kwargs = {
+                 "model": self.model,
+                 "messages": message,
+                 "response_format": output_model,
+                 "temperature": self.temperature,
                  **self.client_kwargs,
-             )
+             }
+             if logprobs:
+                 request_kwargs["logprobs"] = True
+                 request_kwargs["top_logprobs"] = top_logprobs
+
+             completion = self.client.beta.chat.completions.parse(**request_kwargs)
              parsed = completion.choices[0].message.parsed
-             return parsed
+             return parsed, completion

          except Exception as e:
              print(f"[ERROR] Failed to parse completion: {e}")
@@ -119,24 +114,20 @@
          Clean JSON response by removing code block markers and whitespace.
          Handles cases like:
          - ```json{"result": "value"}```
-         - ```{"result": "value"}```
          """
-         # Remove code block markers
          cleaned = response.strip()

-         # Remove ```json and ``` markers
+         # Remove ```json marker
          if cleaned.startswith("```json"):
-             cleaned = cleaned[7:]  # Remove ```json
-         elif cleaned.startswith("```"):
-             cleaned = cleaned[3:]  # Remove ```
+             cleaned = cleaned[7:]

-         # Remove trailing ``` or '''
+         # Remove trailing ```
          if cleaned.endswith("```"):
              cleaned = cleaned[:-3]

          return cleaned.strip()

-     def _convert_to_output_model(self, response_string: str) -> T:
+     def _convert_to_output_model(self, response_string: str, output_model: T) -> T:
          """
          Convert a JSON response string to output model.

@@ -151,11 +142,16 @@
              # Clean the response string
              cleaned_json = self._clean_json_response(response_string)

+             # Fix Python-style booleans
+             cleaned_json = cleaned_json.replace("False", "false").replace(
+                 "True", "true"
+             )
+
              # Convert string to Python dictionary
              response_dict = json.loads(cleaned_json)

              # Convert dictionary to output model
-             return self.OUTPUT_MODEL(**response_dict)
+             return output_model(**response_dict)

          except json.JSONDecodeError as e:
              raise ValueError(
@@ -164,28 +160,84 @@
          except Exception as e:
              raise ValueError(f"Failed to convert to output model: {e}")

-     def _vllm_completion(self, message: list[dict[str, str]]) -> T:
+     def _vllm_completion(
+         self,
+         message: list[dict[str, str]],
+         output_model: T,
+         logprobs: bool = False,
+         top_logprobs: int = 3,
+     ) -> tuple[T, Any]:
          try:
-             json_schema = self.OUTPUT_MODEL.model_json_schema()
-             completion = self.client.chat.completions.create(
-                 model=self.model,
-                 messages=message,
-                 extra_body={"guided_json": json_schema},
-                 temperature=self.temperature,
+             json_schema = output_model.model_json_schema()
+
+             # Build kwargs dynamically
+             request_kwargs = {
+                 "model": self.model,
+                 "messages": message,
+                 "extra_body": {"guided_json": json_schema},
+                 "temperature": self.temperature,
                  **self.client_kwargs,
-             )
+             }
+
+             if logprobs:
+                 request_kwargs["logprobs"] = True
+                 request_kwargs["top_logprobs"] = top_logprobs
+
+             completion = self.client.chat.completions.create(**request_kwargs)
              response = completion.choices[0].message.content

              # Convert the string response to output model
-             parsed_response = self._convert_to_output_model(response)
+             parsed = self._convert_to_output_model(response, output_model)

-             return parsed_response
+             return parsed, completion

          except Exception as e:
              print(f"[ERROR] Failed to get vLLM structured output: {e}")
              raise

-     def run(self, input_text: str, **extra_kwargs) -> dict[str, Any]:
+     def _extract_logprobs(self, completion: dict):
+         logprobs_data = []
+         ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
+
+         for choice in completion.choices:
+             if not getattr(choice, "logprobs", None):
+                 continue
+
+             for logprob_item in choice.logprobs.content:
+                 if ignore_pattern.match(logprob_item.token):
+                     continue
+                 token_entry = {
+                     "token": logprob_item.token,
+                     "prob": round(math.exp(logprob_item.logprob), 8),
+                     "top_alternatives": [],
+                 }
+                 for alt in logprob_item.top_logprobs:
+                     if ignore_pattern.match(alt.token):
+                         continue
+                     token_entry["top_alternatives"].append(
+                         {
+                             "token": alt.token,
+                             "prob": round(math.exp(alt.logprob), 8),
+                         }
+                     )
+                 logprobs_data.append(token_entry)
+
+         return logprobs_data
+
+     def run(
+         self,
+         input_text: str,
+         prompt_file: str,
+         output_model: T,
+         with_analysis: bool = False,
+         use_modes: bool = False,
+         mode: str = "",
+         resp_format: Literal["vllm", "parse"] = "parse",
+         output_lang: Optional[str] = None,
+         logprobs: bool = False,
+         top_logprobs: int = 3,
+         **extra_kwargs,
+     ) -> dict[str, Any]:
          """
          Execute the LLM pipeline with the given input text.

@@ -196,36 +248,54 @@ class Operator:
          Returns:
              Dictionary containing the parsed result and optional analysis
          """
+         prompt_loader = PromptLoader()
+         formatter = UserMergeFormatter()
+
          try:
              cleaned_text = input_text.strip()

-             self.prompt_configs = self.prompt_loader.load_prompts(
-                 self.PROMPT_FILE,
-                 self.USE_MODES,
-                 self.MODE,
+             prompt_configs = prompt_loader.load_prompts(
+                 prompt_file,
+                 use_modes,
+                 mode,
                  cleaned_text,
                  **extra_kwargs,
              )

              messages: list[dict[str, str]] = []

-             if self.WITH_ANALYSIS:
-                 analysis = self._analyze()
+             if with_analysis:
+                 analysis = self._analyze(prompt_configs)
                  messages.append(
                      self._build_user_message(f"Based on this analysis: {analysis}")
                  )

-             messages.append(self._build_main_message())
-             messages = self.formatter.format(messages)
+             if output_lang:
+                 messages.append(
+                     self._build_user_message(
+                         f"Respond only in the {output_lang} language."
+                     )
+                 )
+
+             messages.append(self._build_user_message(prompt_configs["main_template"]))
+
+             messages = formatter.format(messages)

-             if self.RESP_FORMAT == "vllm":
-                 parsed = self._vllm_completion(messages)
-             elif self.RESP_FORMAT == "parse":
-                 parsed = self._parse_completion(messages)
+             if resp_format == "vllm":
+                 parsed, completion = self._vllm_completion(
+                     messages, output_model, logprobs, top_logprobs
+                 )
+             elif resp_format == "parse":
+                 parsed, completion = self._parse_completion(
+                     messages, output_model, logprobs, top_logprobs
+                 )

              results = {"result": parsed.result}

-             if self.WITH_ANALYSIS:
+             if logprobs:
+                 results["logprobs"] = self._extract_logprobs(completion)
+
+             if with_analysis:
                  results["analysis"] = analysis

              return results
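
The net effect of the Operator rework above: the per-tool subclass attributes (PROMPT_FILE, OUTPUT_MODEL, WITH_ANALYSIS, USE_MODES, MODE, RESP_FORMAT) are gone, and run() now takes the same options as arguments, plus the new output_lang and logprobs switches. A rough before/after sketch of a call site (SummarizeOperator is a hypothetical 1.0.2-style subclass; the model name and prompt pairing are illustrative):

from openai import OpenAI

from texttools.tools.internals.operator import Operator
from texttools.tools.internals.output_models import StrOutput  # path assumed from the file list

# 1.0.2 style (removed): configuration lived on a subclass
# class SummarizeOperator(Operator):
#     PROMPT_FILE = "summarizer.yaml"
#     OUTPUT_MODEL = StrOutput
#     RESP_FORMAT = "parse"
# result = SummarizeOperator(OpenAI(), model="...").run("text to summarize")

# 1.0.4 style: one generic Operator, everything passed per call
op = Operator(OpenAI(), model="gpt-4o-mini")  # model name is illustrative
result = op.run(
    "text to summarize",
    prompt_file="summarizer.yaml",
    output_model=StrOutput,
    resp_format="parse",
    output_lang="English",   # new: appends a "respond only in ..." instruction
    logprobs=True,           # new: token probabilities returned under "logprobs"
)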
texttools/tools/{output_models.py → internals/output_models.py}
@@ -11,6 +11,14 @@ class StrOutput(BaseModel):
      result: str


+ class BoolOutput(BaseModel):
+     """
+     Output model for a single boolean result.
+     """
+
+     result: bool
+
+
  class ListStrOutput(BaseModel):
      """
      Output model for a list of strings result.
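
The new BoolOutput model above gives tools a structured {"result": <bool>} response shape. A short, hypothetical sketch of using it through the reworked Operator (pairing it with question_detector.yaml is an assumption; any yes/no-style prompt in texttools/prompts/ would be used the same way):

from openai import OpenAI

from texttools.tools.internals.operator import Operator
from texttools.tools.internals.output_models import BoolOutput  # path assumed from the file list

op = Operator(OpenAI(), model="gpt-4o-mini")  # model name is illustrative

verdict = op.run(
    "Is this sentence a question?",
    prompt_file="question_detector.yaml",  # pairing is an assumption
    output_model=BoolOutput,
)
print(verdict["result"])  # True or False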
texttools/tools/{prompt_loader.py → internals/prompt_loader.py}
@@ -1,4 +1,3 @@
- from typing import Optional
  from pathlib import Path
  import yaml

@@ -25,16 +24,17 @@ class PromptLoader:
      MAIN_TEMPLATE: str = "main_template"
      ANALYZE_TEMPLATE: str = "analyze_template"

-     def __init__(self, prompts_dir: Optional[str] = None):
-         self.PROMPTS_DIR = prompts_dir or "prompts"
-
-     def _get_prompt_path(self, prompt_file: str) -> Path:
-         return Path(__file__).parent.parent / self.PROMPTS_DIR / prompt_file
+     def _get_prompt_path(self, prompt_file: str, prompts_dir: str) -> Path:
+         return Path(__file__).parent.parent.parent / prompts_dir / prompt_file

      def _load_templates(
-         self, prompt_file: str, use_modes: bool, mode: str
+         self,
+         prompts_dir: str,
+         prompt_file: str,
+         use_modes: bool,
+         mode: str,
      ) -> dict[str, str]:
-         prompt_path = self._get_prompt_path(prompt_file)
+         prompt_path = self._get_prompt_path(prompt_file, prompts_dir)

          if not prompt_path.exists():
              raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
@@ -45,18 +45,13 @@ class PromptLoader:
          except yaml.YAMLError as e:
              raise ValueError(f"Invalid YAML in {prompt_path}: {e}")

-         if self.MAIN_TEMPLATE not in data:
-             raise ValueError(
-                 f"Missing required '{self.MAIN_TEMPLATE}' in {prompt_file}"
-             )
-
          return {
-             self.MAIN_TEMPLATE: data[self.MAIN_TEMPLATE][mode]
+             "main_template": data["main_template"][mode]
              if use_modes
-             else data[self.MAIN_TEMPLATE],
-             self.ANALYZE_TEMPLATE: data.get(self.ANALYZE_TEMPLATE)[mode]
+             else data["main_template"],
+             "analyze_template": data.get("analyze_template")[mode]
              if use_modes
-             else data.get(self.ANALYZE_TEMPLATE),
+             else data.get("analyze_template"),
          }

      def _build_format_args(self, input_text: str, **extra_kwargs) -> dict[str, str]:
@@ -72,9 +67,12 @@ class PromptLoader:
          use_modes: bool,
          mode: str,
          input_text: str,
+         prompts_dir: str = "prompts",
          **extra_kwargs,
      ) -> dict[str, str]:
-         template_configs = self._load_templates(prompt_file, use_modes, mode)
+         template_configs = self._load_templates(
+             prompts_dir, prompt_file, use_modes, mode
+         )
          format_args = self._build_format_args(input_text, **extra_kwargs)

          # Inject variables inside each template
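
Finally, a hedged sketch of calling the reworked PromptLoader directly: prompts_dir is no longer a constructor argument but a per-call keyword that defaults to "prompts", and _get_prompt_path gains an extra .parent because the module now lives one level deeper under tools/internals/. The prompt file name and input text below are illustrative:

from texttools.tools.internals.prompt_loader import PromptLoader

loader = PromptLoader()  # no prompts_dir in the constructor any more

configs = loader.load_prompts(
    "categorizer.yaml",       # file name taken from texttools/prompts/
    False,                    # use_modes
    "",                       # mode (ignored when use_modes is False)
    "text to categorize",     # input_text, injected into the templates
)
print(configs["main_template"])
print(configs["analyze_template"])  # may be None if the YAML defines no analyze_template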