hamtaa-texttools 1.1.16-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. hamtaa_texttools-1.2.0.dist-info/METADATA +212 -0
  2. hamtaa_texttools-1.2.0.dist-info/RECORD +34 -0
  3. texttools/__init__.py +5 -5
  4. texttools/batch/__init__.py +0 -0
  5. texttools/batch/{batch_config.py → config.py} +16 -2
  6. texttools/batch/{internals/batch_manager.py → manager.py} +2 -2
  7. texttools/batch/{batch_runner.py → runner.py} +80 -69
  8. texttools/core/__init__.py +0 -0
  9. texttools/core/engine.py +254 -0
  10. texttools/core/exceptions.py +22 -0
  11. texttools/core/internal_models.py +58 -0
  12. texttools/core/operators/async_operator.py +194 -0
  13. texttools/core/operators/sync_operator.py +192 -0
  14. texttools/models.py +88 -0
  15. texttools/prompts/categorize.yaml +36 -77
  16. texttools/prompts/check_fact.yaml +24 -0
  17. texttools/prompts/extract_entities.yaml +7 -3
  18. texttools/prompts/extract_keywords.yaml +21 -9
  19. texttools/prompts/is_question.yaml +6 -2
  20. texttools/prompts/merge_questions.yaml +12 -5
  21. texttools/prompts/propositionize.yaml +24 -0
  22. texttools/prompts/rewrite.yaml +9 -10
  23. texttools/prompts/run_custom.yaml +2 -2
  24. texttools/prompts/subject_to_question.yaml +7 -3
  25. texttools/prompts/summarize.yaml +6 -2
  26. texttools/prompts/text_to_question.yaml +12 -6
  27. texttools/prompts/translate.yaml +7 -2
  28. texttools/py.typed +0 -0
  29. texttools/tools/__init__.py +0 -0
  30. texttools/tools/async_tools.py +778 -489
  31. texttools/tools/sync_tools.py +775 -487
  32. hamtaa_texttools-1.1.16.dist-info/METADATA +0 -255
  33. hamtaa_texttools-1.1.16.dist-info/RECORD +0 -31
  34. texttools/batch/internals/utils.py +0 -16
  35. texttools/prompts/README.md +0 -35
  36. texttools/prompts/detect_entity.yaml +0 -22
  37. texttools/tools/internals/async_operator.py +0 -200
  38. texttools/tools/internals/formatters.py +0 -24
  39. texttools/tools/internals/models.py +0 -183
  40. texttools/tools/internals/operator_utils.py +0 -54
  41. texttools/tools/internals/prompt_loader.py +0 -56
  42. texttools/tools/internals/sync_operator.py +0 -201
  43. {hamtaa_texttools-1.1.16.dist-info → hamtaa_texttools-1.2.0.dist-info}/WHEEL +0 -0
  44. {hamtaa_texttools-1.1.16.dist-info → hamtaa_texttools-1.2.0.dist-info}/licenses/LICENSE +0 -0
  45. {hamtaa_texttools-1.1.16.dist-info → hamtaa_texttools-1.2.0.dist-info}/top_level.txt +0 -0
texttools/tools/internals/models.py
@@ -1,183 +0,0 @@
- from datetime import datetime
- from typing import Type, Any, Literal
-
- from pydantic import BaseModel, Field, create_model
-
-
- class ToolOutput(BaseModel):
-     result: Any = None
-     analysis: str = ""
-     logprobs: list[dict[str, Any]] = []
-     process: str = ""
-     processed_at: datetime = datetime.now()
-     execution_time: float = -1.0
-     errors: list[str] = []
-
-     def __repr__(self) -> str:
-         return f"ToolOutput(process='{self.process}', result_type='{type(self.result)}', result='{self.result}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}', processed_at='{self.processed_at}', execution_time='{self.execution_time}'"
-
-
- class StrOutput(BaseModel):
-     result: str = Field(..., description="The output string")
-
-
- class BoolOutput(BaseModel):
-     result: bool = Field(
-         ..., description="Boolean indicating the output state", example=True
-     )
-
-
- class ListStrOutput(BaseModel):
-     result: list[str] = Field(
-         ..., description="The output list of strings", example=["text_1", "text_2"]
-     )
-
-
- class ListDictStrStrOutput(BaseModel):
-     result: list[dict[str, str]] = Field(
-         ...,
-         description="List of dictionaries containing string key-value pairs",
-         example=[{"text": "Mohammad", "type": "PER"}],
-     )
-
-
- class ReasonListStrOutput(BaseModel):
-     reason: str = Field(..., description="Thinking process that led to the output")
-     result: list[str] = Field(..., description="The output list of strings")
-
-
- class Node(BaseModel):
-     node_id: int
-     name: str
-     level: int
-     parent_id: int | None
-     description: str = "No description provided"
-
-
- class CategoryTree:
-     def __init__(self, tree_name):
-         self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
-         self.all_nodes: list[Node] = [self.root]
-         self.new_id = 1
-
-     def add_node(
-         self,
-         node_name: str,
-         parent_name: str | None = None,
-         description: str | None = None,
-     ) -> None:
-         if self.find_node(node_name):
-             raise ValueError(f"{node_name} has been chosen for another category before")
-
-         if parent_name:
-             parent_node = self.find_node(parent_name)
-             if parent_node is None:
-                 raise ValueError(f"Parent category '{parent_name}' not found")
-             parent_id = parent_node.node_id
-             level = parent_node.level + 1
-         else:
-             level = 1
-             parent_id = 0
-
-         node_data = {
-             "node_id": self.new_id,
-             "name": node_name,
-             "level": level,
-             "parent_id": parent_id,
-         }
-
-         if description is not None:
-             node_data["description"] = description
-
-         self.all_nodes.append(Node(**node_data))
-         self.new_id += 1
-
-     def get_nodes(self) -> list[Node]:
-         return self.all_nodes
-
-     def get_level_count(self) -> int:
-         return max([item.level for item in self.all_nodes])
-
-     def find_node(self, identifier: int | str) -> Node | None:
-         if isinstance(identifier, str):
-             for node in self.get_nodes():
-                 if node.name == identifier:
-                     return node
-             return None
-         elif isinstance(identifier, int):
-             for node in self.get_nodes():
-                 if node.node_id == identifier:
-                     return node
-             return None
-         else:
-             return None
-
-     def find_children(self, parent_node: Node) -> list[Node] | None:
-         children = [
-             node for node in self.get_nodes() if parent_node.node_id == node.parent_id
-         ]
-         return children if children else None
-
-     def remove_node(self, identifier: int | str) -> None:
-         node = self.find_node(identifier)
-
-         if node is not None:
-             # Remove node's children recursively
-             children = self.find_children(node)
-
-             # Ending condition
-             if children is None:
-                 self.all_nodes.remove(node)
-                 return
-
-             for child in children:
-                 self.remove_node(child.name)
-
-             # Remove the node from tree
-             self.all_nodes.remove(node)
-         else:
-             raise ValueError(f"Node with identifier: '{identifier}' not found.")
-
-     def dump_tree(self) -> dict:
-         def build_dict(node: Node) -> dict:
-             children = [
-                 build_dict(child)
-                 for child in self.all_nodes
-                 if child.parent_id == node.node_id
-             ]
-             return {
-                 "node_id": node.node_id,
-                 "name": node.name,
-                 "level": node.level,
-                 "parent_id": node.parent_id,
-                 "children": children,
-             }
-
-         return {"category_tree": build_dict(self.root)["children"]}
-
-
- # This function is needed to create CategorizerOutput with dynamic categories
- def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
-     literal_type = Literal[*allowed_values]
-
-     CategorizerOutput = create_model(
-         "CategorizerOutput",
-         reason=(
-             str,
-             Field(
-                 ..., description="Explanation of why the input belongs to the category"
-             ),
-         ),
-         result=(literal_type, Field(..., description="Predicted category label")),
-     )
-
-     return CategorizerOutput
-
-
- class Entity(BaseModel):
-     text: str = Field(description="The exact text of the entity")
-     entity_type: str = Field(description="The type of the entity")
-
-
- class EntityDetectorOutput(BaseModel):
-     result: list[Entity] = Field(description="List of all extracted entities")
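
The tree and dynamic-model helpers removed above were only reachable through the 1.1.16 import path; the new texttools/models.py and texttools/core/internal_models.py appear to take over this role in 1.2.0. A minimal sketch of how the removed code behaved, assuming a 1.1.16 install:

from texttools.tools.internals.models import CategoryTree, create_dynamic_model

# Build a small two-level tree; node_id 0 is the implicit root named after the tree.
tree = CategoryTree("topics")
tree.add_node("science", description="Scientific subjects")
tree.add_node("physics", parent_name="science")
tree.add_node("sports")

print(tree.get_level_count())  # 2 -> "physics" sits on level 2
print(tree.dump_tree())        # nested dict of the root's children

# create_dynamic_model constrains `result` to the given labels via Literal
CategorizerOutput = create_dynamic_model(["science", "sports"])
out = CategorizerOutput(reason="mentions football scores", result="sports")
# result="physics" would fail validation, since it is not in allowed_values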
texttools/tools/internals/operator_utils.py
@@ -1,54 +0,0 @@
- import re
- import math
- import random
-
-
- class OperatorUtils:
-     @staticmethod
-     def build_user_message(prompt: str) -> dict[str, str]:
-         return {"role": "user", "content": prompt}
-
-     @staticmethod
-     def extract_logprobs(completion: dict) -> list[dict]:
-         """
-         Extracts and filters token probabilities from completion logprobs.
-         Skips punctuation and structural tokens, returns cleaned probability data.
-         """
-         logprobs_data = []
-
-         ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
-
-         for choice in completion.choices:
-             if not getattr(choice, "logprobs", None):
-                 return []
-
-             for logprob_item in choice.logprobs.content:
-                 if ignore_pattern.match(logprob_item.token):
-                     continue
-                 token_entry = {
-                     "token": logprob_item.token,
-                     "prob": round(math.exp(logprob_item.logprob), 8),
-                     "top_alternatives": [],
-                 }
-                 for alt in logprob_item.top_logprobs:
-                     if ignore_pattern.match(alt.token):
-                         continue
-                     token_entry["top_alternatives"].append(
-                         {
-                             "token": alt.token,
-                             "prob": round(math.exp(alt.logprob), 8),
-                         }
-                     )
-                 logprobs_data.append(token_entry)
-
-         return logprobs_data
-
-     @staticmethod
-     def get_retry_temp(base_temp: float) -> float:
-         """
-         Calculate temperature for retry attempts.
-         """
-         delta_temp = random.choice([-1, 1]) * random.uniform(0.1, 0.9)
-         new_temp = base_temp + delta_temp
-
-         return max(0.0, min(new_temp, 1.5))
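
The retry-temperature helper removed above nudges the sampling temperature by a random step of ±[0.1, 0.9] and clamps the result to [0.0, 1.5]. A quick sketch of that behaviour, runnable only against 1.1.16 where this module still exists:

from texttools.tools.internals.operator_utils import OperatorUtils

for _ in range(5):
    t = OperatorUtils.get_retry_temp(0.2)
    assert 0.0 <= t <= 1.5  # always clamped into this range
    print(round(t, 2))      # jittered value, different on every call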
texttools/tools/internals/prompt_loader.py
@@ -1,56 +0,0 @@
- from functools import lru_cache
- from pathlib import Path
- import yaml
-
-
- class PromptLoader:
-     """
-     Utility for loading and formatting YAML prompt templates.
-
-     Responsibilities:
-     - Load and parse YAML prompt definitions.
-     - Select the right template (by mode, if applicable).
-     - Inject variables (`{input}`, plus any extra kwargs) into the templates.
-     """
-
-     MAIN_TEMPLATE = "main_template"
-     ANALYZE_TEMPLATE = "analyze_template"
-
-     @staticmethod
-     def _build_format_args(text: str, **extra_kwargs) -> dict[str, str]:
-         # Base formatting args
-         format_args = {"input": text}
-         # Merge extras
-         format_args.update(extra_kwargs)
-         return format_args
-
-     # Use lru_cache to load each file once
-     @lru_cache(maxsize=32)
-     def _load_templates(self, prompt_file: str, mode: str | None) -> dict[str, str]:
-         """
-         Loads prompt templates from YAML file with optional mode selection.
-         """
-         base_dir = Path(__file__).parent.parent.parent / Path("prompts")
-         prompt_path = base_dir / prompt_file
-         data = yaml.safe_load(prompt_path.read_text(encoding="utf-8"))
-
-         return {
-             self.MAIN_TEMPLATE: data[self.MAIN_TEMPLATE][mode]
-             if mode
-             else data[self.MAIN_TEMPLATE],
-             self.ANALYZE_TEMPLATE: data.get(self.ANALYZE_TEMPLATE)[mode]
-             if mode
-             else data.get(self.ANALYZE_TEMPLATE),
-         }
-
-     def load(
-         self, prompt_file: str, text: str, mode: str, **extra_kwargs
-     ) -> dict[str, str]:
-         template_configs = self._load_templates(prompt_file, mode)
-         format_args = self._build_format_args(text, **extra_kwargs)
-
-         # Inject variables inside each template
-         for key in template_configs.keys():
-             template_configs[key] = template_configs[key].format(**format_args)
-
-         return template_configs
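
PromptLoader.load returned both templates with `{input}` (and any extra kwargs) already substituted, resolving file names relative to texttools/prompts/. A sketch of the call shape under 1.1.16; the file name and mode=None below are illustrative, since the available modes depend on each YAML file:

from texttools.tools.internals.prompt_loader import PromptLoader

loader = PromptLoader()
configs = loader.load(
    prompt_file="summarize.yaml",  # looked up in texttools/prompts/
    text="Gravity bends light.",
    mode=None,                     # some prompts select a sub-template by mode
)
print(configs["main_template"])     # main prompt with the input text injected
print(configs["analyze_template"])  # companion analysis prompt, if the YAML defines one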
texttools/tools/internals/sync_operator.py
@@ -1,201 +0,0 @@
- from typing import Any, TypeVar, Type
- from collections.abc import Callable
- import logging
-
- from openai import OpenAI
- from pydantic import BaseModel
-
- from texttools.tools.internals.models import ToolOutput
- from texttools.tools.internals.operator_utils import OperatorUtils
- from texttools.tools.internals.formatters import Formatter
- from texttools.tools.internals.prompt_loader import PromptLoader
-
- # Base Model type for output models
- T = TypeVar("T", bound=BaseModel)
-
- logger = logging.getLogger("texttools.operator")
-
-
- class Operator:
-     """
-     Core engine for running text-processing operations with an LLM (Sync).
-
-     It wires together:
-     - `PromptLoader` → loads YAML prompt templates.
-     - `UserMergeFormatter` → applies formatting to messages (e.g., merging).
-     - OpenAI client → executes completions/parsed completions.
-     """
-
-     def __init__(self, client: OpenAI, model: str):
-         self._client = client
-         self._model = model
-
-     def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
-         """
-         Calls OpenAI API for analysis using the configured prompt template.
-         Returns the analyzed content as a string.
-         """
-         analyze_prompt = prompt_configs["analyze_template"]
-         analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
-         completion = self._client.chat.completions.create(
-             model=self._model,
-             messages=analyze_message,
-             temperature=temperature,
-         )
-         analysis = completion.choices[0].message.content.strip()
-         return analysis
-
-     def _parse_completion(
-         self,
-         message: list[dict[str, str]],
-         output_model: Type[T],
-         temperature: float,
-         logprobs: bool = False,
-         top_logprobs: int = 3,
-         priority: int | None = 0,
-     ) -> tuple[T, Any]:
-         """
-         Parses a chat completion using OpenAI's structured output format.
-         Returns both the parsed object and the raw completion for logprobs.
-         """
-         request_kwargs = {
-             "model": self._model,
-             "messages": message,
-             "response_format": output_model,
-             "temperature": temperature,
-         }
-
-         if logprobs:
-             request_kwargs["logprobs"] = True
-             request_kwargs["top_logprobs"] = top_logprobs
-
-         if priority:
-             request_kwargs["extra_body"] = {"priority": priority}
-
-         completion = self._client.beta.chat.completions.parse(**request_kwargs)
-         parsed = completion.choices[0].message.parsed
-         return parsed, completion
-
-     def run(
-         self,
-         # User parameters
-         text: str,
-         with_analysis: bool,
-         output_lang: str | None,
-         user_prompt: str | None,
-         temperature: float,
-         logprobs: bool,
-         top_logprobs: int | None,
-         validator: Callable[[Any], bool] | None,
-         max_validation_retries: int | None,
-         # Internal parameters
-         prompt_file: str,
-         output_model: Type[T],
-         mode: str | None,
-         priority: int | None = 0,
-         **extra_kwargs,
-     ) -> ToolOutput:
-         """
-         Execute the LLM pipeline with the given input text.
-         """
-         prompt_loader = PromptLoader()
-         formatter = Formatter()
-         output = ToolOutput()
-         try:
-             # Prompt configs contain two keys: main_template and analyze template, both are string
-             prompt_configs = prompt_loader.load(
-                 prompt_file=prompt_file,
-                 text=text.strip(),
-                 mode=mode,
-                 **extra_kwargs,
-             )
-
-             messages = []
-
-             if with_analysis:
-                 analysis = self._analyze(prompt_configs, temperature)
-                 messages.append(
-                     OperatorUtils.build_user_message(
-                         f"Based on this analysis: {analysis}"
-                     )
-                 )
-
-             if output_lang:
-                 messages.append(
-                     OperatorUtils.build_user_message(
-                         f"Respond only in the {output_lang} language."
-                     )
-                 )
-
-             if user_prompt:
-                 messages.append(
-                     OperatorUtils.build_user_message(
-                         f"Consider this instruction {user_prompt}"
-                     )
-                 )
-
-             messages.append(
-                 OperatorUtils.build_user_message(prompt_configs["main_template"])
-             )
-
-             messages = formatter.user_merge_format(messages)
-
-             parsed, completion = self._parse_completion(
-                 messages, output_model, temperature, logprobs, top_logprobs, priority
-             )
-
-             output.result = parsed.result
-
-             # Retry logic if validation fails
-             if validator and not validator(output.result):
-                 for attempt in range(max_validation_retries):
-                     logger.warning(
-                         f"Validation failed, retrying for the {attempt + 1} time."
-                     )
-
-                     # Generate new temperature for retry
-                     retry_temperature = OperatorUtils.get_retry_temp(temperature)
-                     try:
-                         parsed, completion = self._parse_completion(
-                             messages,
-                             output_model,
-                             retry_temperature,
-                             logprobs,
-                             top_logprobs,
-                         )
-
-                         output.result = parsed.result
-
-                         # Check if retry was successful
-                         if validator(output.result):
-                             logger.info(
-                                 f"Validation passed on retry attempt {attempt + 1}"
-                             )
-                             break
-                         else:
-                             logger.warning(
-                                 f"Validation still failing after retry attempt {attempt + 1}"
-                             )
-
-                     except Exception as e:
-                         logger.error(f"Retry attempt {attempt + 1} failed: {e}")
-                         # Continue to next retry attempt if this one fails
-
-             # Final check after all retries
-             if validator and not validator(output.result):
-                 output.errors.append("Validation failed after all retry attempts")
-
-             if logprobs:
-                 output.logprobs = OperatorUtils.extract_logprobs(completion)
-
-             if with_analysis:
-                 output.analysis = analysis
-
-             output.process = prompt_file[:-5]
-
-             return output
-
-         except Exception as e:
-             logger.error(f"TheTool failed: {e}")
-             output.errors.append(str(e))
-             return output
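
The removed sync Operator catches its own exceptions and reports failures on the returned ToolOutput rather than raising. A hedged sketch of driving it directly under 1.1.16, with a placeholder endpoint and model name (extract_keywords.yaml ships with the package; ListStrOutput comes from the removed models module):

from openai import OpenAI

from texttools.tools.internals.models import ListStrOutput
from texttools.tools.internals.sync_operator import Operator

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder server
op = Operator(client=client, model="my-model")  # placeholder model name

output = op.run(
    text="Tehran is the capital of Iran.",
    with_analysis=False,
    output_lang=None,
    user_prompt=None,
    temperature=0.2,
    logprobs=False,
    top_logprobs=None,
    validator=lambda result: isinstance(result, list) and len(result) > 0,
    max_validation_retries=2,
    prompt_file="extract_keywords.yaml",
    output_model=ListStrOutput,
    mode=None,  # whether this prompt defines modes depends on the YAML
)
print(output.result)  # parsed keyword list, or None if every attempt failed
print(output.errors)  # failures are collected here instead of being raised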