hamtaa-texttools 1.1.1__py3-none-any.whl → 1.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/METADATA +98 -26
  2. hamtaa_texttools-1.1.16.dist-info/RECORD +31 -0
  3. texttools/__init__.py +6 -8
  4. texttools/batch/batch_config.py +26 -0
  5. texttools/batch/batch_runner.py +105 -151
  6. texttools/batch/{batch_manager.py → internals/batch_manager.py} +39 -40
  7. texttools/batch/internals/utils.py +16 -0
  8. texttools/prompts/README.md +4 -4
  9. texttools/prompts/categorize.yaml +77 -0
  10. texttools/prompts/detect_entity.yaml +22 -0
  11. texttools/prompts/extract_keywords.yaml +68 -18
  12. texttools/tools/async_tools.py +804 -0
  13. texttools/tools/internals/async_operator.py +90 -69
  14. texttools/tools/internals/models.py +183 -0
  15. texttools/tools/internals/operator_utils.py +54 -0
  16. texttools/tools/internals/prompt_loader.py +13 -14
  17. texttools/tools/internals/sync_operator.py +201 -0
  18. texttools/tools/sync_tools.py +804 -0
  19. hamtaa_texttools-1.1.1.dist-info/RECORD +0 -30
  20. texttools/batch/__init__.py +0 -4
  21. texttools/prompts/categorizer.yaml +0 -28
  22. texttools/tools/__init__.py +0 -4
  23. texttools/tools/async_the_tool.py +0 -414
  24. texttools/tools/internals/base_operator.py +0 -91
  25. texttools/tools/internals/operator.py +0 -179
  26. texttools/tools/internals/output_models.py +0 -59
  27. texttools/tools/the_tool.py +0 -412
  28. {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/WHEEL +0 -0
  29. {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/licenses/LICENSE +0 -0
  30. {hamtaa_texttools-1.1.1.dist-info → hamtaa_texttools-1.1.16.dist-info}/top_level.txt +0 -0
texttools/tools/internals/async_operator.py
@@ -1,23 +1,22 @@
-from typing import Any, TypeVar, Type, Literal
+from typing import Any, TypeVar, Type
+from collections.abc import Callable
 import logging

 from openai import AsyncOpenAI
 from pydantic import BaseModel

-from texttools.tools.internals.output_models import ToolOutput
-from texttools.tools.internals.base_operator import BaseOperator
+from texttools.tools.internals.models import ToolOutput
+from texttools.tools.internals.operator_utils import OperatorUtils
 from texttools.tools.internals.formatters import Formatter
 from texttools.tools.internals.prompt_loader import PromptLoader

 # Base Model type for output models
 T = TypeVar("T", bound=BaseModel)

-# Configure logger
-logger = logging.getLogger("async_operator")
-logger.setLevel(logging.INFO)
+logger = logging.getLogger("texttools.async_operator")


-class AsyncOperator(BaseOperator):
+class AsyncOperator:
     """
     Core engine for running text-processing operations with an LLM (Async).

@@ -28,14 +27,18 @@ class AsyncOperator(BaseOperator):
     """

     def __init__(self, client: AsyncOpenAI, model: str):
-        self.client = client
-        self.model = model
+        self._client = client
+        self._model = model

     async def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
+        """
+        Calls OpenAI API for analysis using the configured prompt template.
+        Returns the analyzed content as a string.
+        """
         analyze_prompt = prompt_configs["analyze_template"]
-        analyze_message = [self._build_user_message(analyze_prompt)]
-        completion = await self.client.chat.completions.create(
-            model=self.model,
+        analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
+        completion = await self._client.chat.completions.create(
+            model=self._model,
             messages=analyze_message,
             temperature=temperature,
         )
@@ -49,9 +52,14 @@ class AsyncOperator(BaseOperator):
         temperature: float,
         logprobs: bool = False,
         top_logprobs: int = 3,
-    ) -> tuple[Type[T], Any]:
+        priority: int | None = 0,
+    ) -> tuple[T, Any]:
+        """
+        Parses a chat completion using OpenAI's structured output format.
+        Returns both the parsed object and the raw completion for logprobs.
+        """
         request_kwargs = {
-            "model": self.model,
+            "model": self._model,
             "messages": message,
             "response_format": output_model,
             "temperature": temperature,
@@ -60,40 +68,12 @@ class AsyncOperator(BaseOperator):
         if logprobs:
             request_kwargs["logprobs"] = True
             request_kwargs["top_logprobs"] = top_logprobs
-
-        completion = await self.client.beta.chat.completions.parse(**request_kwargs)
+        if priority:
+            request_kwargs["extra_body"] = {"priority": priority}
+        completion = await self._client.beta.chat.completions.parse(**request_kwargs)
         parsed = completion.choices[0].message.parsed
         return parsed, completion

-    async def _vllm_completion(
-        self,
-        message: list[dict[str, str]],
-        output_model: Type[T],
-        temperature: float,
-        logprobs: bool = False,
-        top_logprobs: int = 3,
-    ) -> tuple[Type[T], Any]:
-        json_schema = output_model.model_json_schema()
-
-        # Build kwargs dynamically
-        request_kwargs = {
-            "model": self.model,
-            "messages": message,
-            "extra_body": {"guided_json": json_schema},
-            "temperature": temperature,
-        }
-
-        if logprobs:
-            request_kwargs["logprobs"] = True
-            request_kwargs["top_logprobs"] = top_logprobs
-
-        completion = await self.client.chat.completions.create(**request_kwargs)
-        response = completion.choices[0].message.content
-
-        # Convert the string response to output model
-        parsed = self._convert_to_output_model(response, output_model)
-        return parsed, completion
-
     async def run(
         self,
         # User parameters
@@ -104,20 +84,24 @@ class AsyncOperator(BaseOperator):
         temperature: float,
         logprobs: bool,
         top_logprobs: int | None,
+        validator: Callable[[Any], bool] | None,
+        max_validation_retries: int | None,
         # Internal parameters
         prompt_file: str,
         output_model: Type[T],
-        resp_format: Literal["vllm", "parse"],
         mode: str | None,
+        priority: int | None = 0,
         **extra_kwargs,
-    ) -> dict[str, Any]:
+    ) -> ToolOutput:
         """
         Execute the async LLM pipeline with the given input text. (Async)
         """
         prompt_loader = PromptLoader()
         formatter = Formatter()
+        output = ToolOutput()

         try:
+            # Prompt configs contain two keys: main_template and analyze template, both are string
             prompt_configs = prompt_loader.load(
                 prompt_file=prompt_file,
                 text=text.strip(),
@@ -125,55 +109,92 @@ class AsyncOperator(BaseOperator):
                 **extra_kwargs,
             )

-            messages: list[dict[str, str]] = []
+            messages = []

             if with_analysis:
                 analysis = await self._analyze(prompt_configs, temperature)
                 messages.append(
-                    self._build_user_message(f"Based on this analysis: {analysis}")
+                    OperatorUtils.build_user_message(
+                        f"Based on this analysis: {analysis}"
+                    )
                 )

             if output_lang:
                 messages.append(
-                    self._build_user_message(
+                    OperatorUtils.build_user_message(
                         f"Respond only in the {output_lang} language."
                     )
                 )

             if user_prompt:
                 messages.append(
-                    self._build_user_message(f"Consider this instruction {user_prompt}")
+                    OperatorUtils.build_user_message(
+                        f"Consider this instruction {user_prompt}"
+                    )
                 )

-            messages.append(self._build_user_message(prompt_configs["main_template"]))
+            messages.append(
+                OperatorUtils.build_user_message(prompt_configs["main_template"])
+            )
+
             messages = formatter.user_merge_format(messages)

-            if resp_format == "vllm":
-                parsed, completion = await self._vllm_completion(
-                    messages, output_model, temperature, logprobs, top_logprobs
-                )
-            elif resp_format == "parse":
-                parsed, completion = await self._parse_completion(
-                    messages, output_model, temperature, logprobs, top_logprobs
-                )
+            parsed, completion = await self._parse_completion(
+                messages, output_model, temperature, logprobs, top_logprobs, priority
+            )

-            # Ensure output_model has a `result` field
-            if not hasattr(parsed, "result"):
-                logger.error(
-                    "The provided output_model must define a field named 'result'"
-                )
+            output.result = parsed.result

-            output = ToolOutput(result="", analysis="", logprobs=[], errors=[])
+            # Retry logic if validation fails
+            if validator and not validator(output.result):
+                for attempt in range(max_validation_retries):
+                    logger.warning(
+                        f"Validation failed, retrying for the {attempt + 1} time."
+                    )

-            output.result = parsed.result
+                    # Generate new temperature for retry
+                    retry_temperature = OperatorUtils.get_retry_temp(temperature)
+                    try:
+                        parsed, completion = await self._parse_completion(
+                            messages,
+                            output_model,
+                            retry_temperature,
+                            logprobs,
+                            top_logprobs,
+                        )
+
+                        output.result = parsed.result
+
+                        # Check if retry was successful
+                        if validator(output.result):
+                            logger.info(
+                                f"Validation passed on retry attempt {attempt + 1}"
+                            )
+                            break
+                        else:
+                            logger.warning(
+                                f"Validation still failing after retry attempt {attempt + 1}"
+                            )
+
+                    except Exception as e:
+                        logger.error(f"Retry attempt {attempt + 1} failed: {e}")
+                        # Continue to next retry attempt if this one fails
+
+            # Final check after all retries
+            if validator and not validator(output.result):
+                output.errors.append("Validation failed after all retry attempts")

             if logprobs:
-                output.logprobs = self._extract_logprobs(completion)
+                output.logprobs = OperatorUtils.extract_logprobs(completion)

             if with_analysis:
                 output.analysis = analysis

+            output.process = prompt_file[:-5]
+
             return output
+
         except Exception as e:
             logger.error(f"AsyncTheTool failed: {e}")
-            return ToolOutput(result="", analysis="", logprobs=[], errors=[str(e)])
+            output.errors.append(str(e))
+            return output
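
For orientation, a minimal sketch of how the reworked AsyncOperator.run might be driven after this change. The leading parameters (text, with_analysis, output_lang, user_prompt) and the endpoint/model values are assumptions inferred from the body above, not shown in this diff; in the released package the public wrappers in texttools/tools/async_tools.py call run rather than user code.

import asyncio

from openai import AsyncOpenAI

from texttools.tools.internals.async_operator import AsyncOperator
from texttools.tools.internals.models import ListStrOutput

# Placeholder endpoint/model values; substitute your own deployment.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
operator = AsyncOperator(client=client, model="my-model")


async def main() -> None:
    output = await operator.run(
        text="Large language models can extract keywords from text.",
        with_analysis=False,
        output_lang=None,
        user_prompt=None,
        temperature=0.2,
        logprobs=False,
        top_logprobs=None,
        # New in 1.1.16: re-ask the model when a caller-supplied check rejects the result.
        validator=lambda result: isinstance(result, list) and len(result) > 0,
        max_validation_retries=2,
        prompt_file="extract_keywords.yaml",
        output_model=ListStrOutput,
        mode=None,
        priority=0,
    )
    print(output.result, output.errors)


asyncio.run(main())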
texttools/tools/internals/models.py
@@ -0,0 +1,183 @@
+from datetime import datetime
+from typing import Type, Any, Literal
+
+from pydantic import BaseModel, Field, create_model
+
+
+class ToolOutput(BaseModel):
+    result: Any = None
+    analysis: str = ""
+    logprobs: list[dict[str, Any]] = []
+    process: str = ""
+    processed_at: datetime = datetime.now()
+    execution_time: float = -1.0
+    errors: list[str] = []
+
+    def __repr__(self) -> str:
+        return f"ToolOutput(process='{self.process}', result_type='{type(self.result)}', result='{self.result}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}', processed_at='{self.processed_at}', execution_time='{self.execution_time}'"
+
+
+class StrOutput(BaseModel):
+    result: str = Field(..., description="The output string")
+
+
+class BoolOutput(BaseModel):
+    result: bool = Field(
+        ..., description="Boolean indicating the output state", example=True
+    )
+
+
+class ListStrOutput(BaseModel):
+    result: list[str] = Field(
+        ..., description="The output list of strings", example=["text_1", "text_2"]
+    )
+
+
+class ListDictStrStrOutput(BaseModel):
+    result: list[dict[str, str]] = Field(
+        ...,
+        description="List of dictionaries containing string key-value pairs",
+        example=[{"text": "Mohammad", "type": "PER"}],
+    )
+
+
+class ReasonListStrOutput(BaseModel):
+    reason: str = Field(..., description="Thinking process that led to the output")
+    result: list[str] = Field(..., description="The output list of strings")
+
+
+class Node(BaseModel):
+    node_id: int
+    name: str
+    level: int
+    parent_id: int | None
+    description: str = "No description provided"
+
+
+class CategoryTree:
+    def __init__(self, tree_name):
+        self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
+        self.all_nodes: list[Node] = [self.root]
+        self.new_id = 1
+
+    def add_node(
+        self,
+        node_name: str,
+        parent_name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        if self.find_node(node_name):
+            raise ValueError(f"{node_name} has been chosen for another category before")
+
+        if parent_name:
+            parent_node = self.find_node(parent_name)
+            if parent_node is None:
+                raise ValueError(f"Parent category '{parent_name}' not found")
+            parent_id = parent_node.node_id
+            level = parent_node.level + 1
+        else:
+            level = 1
+            parent_id = 0
+
+        node_data = {
+            "node_id": self.new_id,
+            "name": node_name,
+            "level": level,
+            "parent_id": parent_id,
+        }
+
+        if description is not None:
+            node_data["description"] = description
+
+        self.all_nodes.append(Node(**node_data))
+        self.new_id += 1
+
+    def get_nodes(self) -> list[Node]:
+        return self.all_nodes
+
+    def get_level_count(self) -> int:
+        return max([item.level for item in self.all_nodes])
+
+    def find_node(self, identifier: int | str) -> Node | None:
+        if isinstance(identifier, str):
+            for node in self.get_nodes():
+                if node.name == identifier:
+                    return node
+            return None
+        elif isinstance(identifier, int):
+            for node in self.get_nodes():
+                if node.node_id == identifier:
+                    return node
+            return None
+        else:
+            return None
+
+    def find_children(self, parent_node: Node) -> list[Node] | None:
+        children = [
+            node for node in self.get_nodes() if parent_node.node_id == node.parent_id
+        ]
+        return children if children else None
+
+    def remove_node(self, identifier: int | str) -> None:
+        node = self.find_node(identifier)
+
+        if node is not None:
+            # Remove node's children recursively
+            children = self.find_children(node)
+
+            # Ending condition
+            if children is None:
+                self.all_nodes.remove(node)
+                return
+
+            for child in children:
+                self.remove_node(child.name)
+
+            # Remove the node from tree
+            self.all_nodes.remove(node)
+        else:
+            raise ValueError(f"Node with identifier: '{identifier}' not found.")
+
+    def dump_tree(self) -> dict:
+        def build_dict(node: Node) -> dict:
+            children = [
+                build_dict(child)
+                for child in self.all_nodes
+                if child.parent_id == node.node_id
+            ]
+            return {
+                "node_id": node.node_id,
+                "name": node.name,
+                "level": node.level,
+                "parent_id": node.parent_id,
+                "children": children,
+            }
+
+        return {"category_tree": build_dict(self.root)["children"]}
+
+
+# This function is needed to create CategorizerOutput with dynamic categories
+def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
+    literal_type = Literal[*allowed_values]
+
+    CategorizerOutput = create_model(
+        "CategorizerOutput",
+        reason=(
+            str,
+            Field(
+                ..., description="Explanation of why the input belongs to the category"
+            ),
+        ),
+        result=(literal_type, Field(..., description="Predicted category label")),
+    )
+
+    return CategorizerOutput
+
+
+class Entity(BaseModel):
+    text: str = Field(description="The exact text of the entity")
+    entity_type: str = Field(description="The type of the entity")
+
+
+class EntityDetectorOutput(BaseModel):
+    result: list[Entity] = Field(description="List of all extracted entities")
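
As a quick orientation to the new module, a small usage sketch of CategoryTree and create_dynamic_model based solely on the code above; the category names are illustrative, not part of the package.

from texttools.tools.internals.models import CategoryTree, create_dynamic_model

# Build a two-level tree under a root named "news" (names are made up).
tree = CategoryTree("news")
tree.add_node("sports", description="Sports coverage")
tree.add_node("football", parent_name="sports")
tree.add_node("politics")

print(tree.get_level_count())  # 2
print(tree.dump_tree()["category_tree"][0]["name"])  # "sports"

# Constrain a categorizer's result to the current node names via a Literal type.
labels = [node.name for node in tree.get_nodes() if node.level > 0]
CategorizerOutput = create_dynamic_model(labels)
print(CategorizerOutput(reason="mentions a match", result="football"))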
texttools/tools/internals/operator_utils.py
@@ -0,0 +1,54 @@
+import re
+import math
+import random
+
+
+class OperatorUtils:
+    @staticmethod
+    def build_user_message(prompt: str) -> dict[str, str]:
+        return {"role": "user", "content": prompt}
+
+    @staticmethod
+    def extract_logprobs(completion: dict) -> list[dict]:
+        """
+        Extracts and filters token probabilities from completion logprobs.
+        Skips punctuation and structural tokens, returns cleaned probability data.
+        """
+        logprobs_data = []
+
+        ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
+
+        for choice in completion.choices:
+            if not getattr(choice, "logprobs", None):
+                return []
+
+            for logprob_item in choice.logprobs.content:
+                if ignore_pattern.match(logprob_item.token):
+                    continue
+                token_entry = {
+                    "token": logprob_item.token,
+                    "prob": round(math.exp(logprob_item.logprob), 8),
+                    "top_alternatives": [],
+                }
+                for alt in logprob_item.top_logprobs:
+                    if ignore_pattern.match(alt.token):
+                        continue
+                    token_entry["top_alternatives"].append(
+                        {
+                            "token": alt.token,
+                            "prob": round(math.exp(alt.logprob), 8),
+                        }
+                    )
+                logprobs_data.append(token_entry)
+
+        return logprobs_data
+
+    @staticmethod
+    def get_retry_temp(base_temp: float) -> float:
+        """
+        Calculate temperature for retry attempts.
+        """
+        delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
+        new_temp = base_temp + delta_temp
+
+        return max(0.0, min(new_temp, 1.5))
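
A brief sketch of the two helpers that are easy to exercise in isolation (extract_logprobs needs a real completion object, so it is omitted here):

from texttools.tools.internals.operator_utils import OperatorUtils

# Chat messages are plain role/content dicts.
print(OperatorUtils.build_user_message("Summarize this."))
# {'role': 'user', 'content': 'Summarize this.'}

# Retry temperature nudges the base value by a random ±0.1–0.9 and clamps to [0.0, 1.5].
for _ in range(3):
    assert 0.0 <= OperatorUtils.get_retry_temp(0.2) <= 1.5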
texttools/tools/internals/prompt_loader.py
@@ -11,19 +11,25 @@ class PromptLoader:
     - Load and parse YAML prompt definitions.
     - Select the right template (by mode, if applicable).
     - Inject variables (`{input}`, plus any extra kwargs) into the templates.
-    - Return a dict with:
-        {
-            "main_template": "...",
-            "analyze_template": "..." | None
-        }
     """

-    MAIN_TEMPLATE: str = "main_template"
-    ANALYZE_TEMPLATE: str = "analyze_template"
+    MAIN_TEMPLATE = "main_template"
+    ANALYZE_TEMPLATE = "analyze_template"
+
+    @staticmethod
+    def _build_format_args(text: str, **extra_kwargs) -> dict[str, str]:
+        # Base formatting args
+        format_args = {"input": text}
+        # Merge extras
+        format_args.update(extra_kwargs)
+        return format_args

     # Use lru_cache to load each file once
     @lru_cache(maxsize=32)
     def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
+        """
+        Loads prompt templates from YAML file with optional mode selection.
+        """
         base_dir = Path(__file__).parent.parent.parent / Path("prompts")
         prompt_path = base_dir / prompt_file
         data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))
@@ -37,13 +43,6 @@ class PromptLoader:
             else data.get(self.ANALYZE_TEMPLATE),
         }

-    def _build_format_args(self, text: str, **extra_kwargs) -> dict[str, str]:
-        # Base formatting args
-        format_args = {"input": text}
-        # Merge extras
-        format_args.update(extra_kwargs)
-        return format_args
-
     def load(
         self, prompt_file: str, text: str, mode: str, **extra_kwargs
     ) -> dict[str, str]:
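
Finally, a rough sketch of what the loader hands back to the operators, following the dict shape described in the docstring this diff removes. Whether extract_keywords.yaml needs extra template kwargs or a specific mode is not visible in these hunks, so treat the call as hypothetical.

from texttools.tools.internals.prompt_loader import PromptLoader

loader = PromptLoader()
# Hypothetical call: the prompt file ships with the package, but its required
# template variables and modes are defined in the YAML, not shown in this diff.
configs = loader.load(prompt_file="extract_keywords.yaml", text="Some input text", mode=None)
# configs["main_template"]    -> the formatted main prompt string
# configs["analyze_template"] -> a second template for the analysis pass, or None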