hamtaa_texttools-1.1.13-py3-none-any.whl → hamtaa_texttools-1.1.14-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hamtaa_texttools-1.1.13.dist-info → hamtaa_texttools-1.1.14.dist-info}/METADATA +8 -6
- {hamtaa_texttools-1.1.13.dist-info → hamtaa_texttools-1.1.14.dist-info}/RECORD +16 -15
- texttools/__init__.py +2 -1
- texttools/batch/batch_config.py +1 -1
- texttools/batch/batch_runner.py +1 -1
- texttools/prompts/categorize.yaml +77 -0
- texttools/prompts/detect_entity.yaml +22 -0
- texttools/prompts/extract_keywords.yaml +68 -18
- texttools/tools/async_tools.py +206 -41
- texttools/tools/internals/async_operator.py +8 -4
- texttools/tools/internals/models.py +181 -0
- texttools/tools/internals/sync_operator.py +9 -4
- texttools/tools/sync_tools.py +206 -41
- texttools/prompts/categorizer.yaml +0 -28
- texttools/tools/internals/output_models.py +0 -62
- {hamtaa_texttools-1.1.13.dist-info → hamtaa_texttools-1.1.14.dist-info}/WHEEL +0 -0
- {hamtaa_texttools-1.1.13.dist-info → hamtaa_texttools-1.1.14.dist-info}/licenses/LICENSE +0 -0
- {hamtaa_texttools-1.1.13.dist-info → hamtaa_texttools-1.1.14.dist-info}/top_level.txt +0 -0
texttools/tools/internals/models.py (new file)
@@ -0,0 +1,181 @@
+from typing import Type, Any, Literal
+
+from pydantic import BaseModel, Field, create_model
+
+
+class ToolOutput(BaseModel):
+    result: Any = None
+    analysis: str = ""
+    logprobs: list[dict[str, Any]] = []
+    errors: list[str] = []
+
+    def __repr__(self) -> str:
+        return f"ToolOutput(result_type='{type(self.result)}', result='{self.result}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}'"
+
+
+class StrOutput(BaseModel):
+    result: str = Field(..., description="The output string")
+
+
+class BoolOutput(BaseModel):
+    result: bool = Field(
+        ..., description="Boolean indicating the output state", example=True
+    )
+
+
+class ListStrOutput(BaseModel):
+    result: list[str] = Field(
+        ..., description="The output list of strings", example=["text_1", "text_2"]
+    )
+
+
+class ListDictStrStrOutput(BaseModel):
+    result: list[dict[str, str]] = Field(
+        ...,
+        description="List of dictionaries containing string key-value pairs",
+        example=[{"text": "Mohammad", "type": "PER"}],
+    )
+
+
+class ReasonListStrOutput(BaseModel):
+    reason: str = Field(..., description="Thinking process that led to the output")
+    result: list[str] = Field(..., description="The output list of strings")
+
+
+class Node(BaseModel):
+    node_id: int
+    name: str
+    level: int
+    parent_id: int | None
+    description: str = "No description provided"
+
+
+class CategoryTree:
+    def __init__(self, tree_name):
+        self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
+        self.node_list: list[Node] = [self.root]
+        self.new_id = 1
+
+    def add_node(
+        self,
+        node_name: str,
+        parent_name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        if self.find_node(node_name):
+            raise ValueError(f"{node_name} has been chosen for another category before")
+
+        if parent_name:
+            parent_node = self.find_node(parent_name)
+            if parent_node is None:
+                raise ValueError(f"Parent category '{parent_name}' not found")
+            parent_id = parent_node.node_id
+            level = parent_node.level + 1
+        else:
+            level = 1
+            parent_id = 0
+
+        node_data = {
+            "node_id": self.new_id,
+            "name": node_name,
+            "level": level,
+            "parent_id": parent_id,
+        }
+
+        if description is not None:
+            node_data["description"] = description
+
+        self.node_list.append(Node(**node_data))
+        self.new_id += 1
+
+    def get_nodes(self) -> list[Node]:
+        return self.node_list
+
+    def find_node(self, identifier: int | str) -> Node | None:
+        if isinstance(identifier, str):
+            for node in self.get_nodes():
+                if node.name == identifier:
+                    return node
+            return None
+        elif isinstance(identifier, int):
+            for node in self.get_nodes():
+                if node.node_id == identifier:
+                    return node
+            return None
+        else:
+            return None
+
+    def find_children(self, parent_node: Node) -> list[Node] | None:
+        children = []
+        for node in self.get_nodes():
+            if parent_node.node_id == node.parent_id:
+                children.append(node)
+
+        return children if children else None
+
+    def remove_node(self, identifier: int | str) -> None:
+        node = self.find_node(identifier)
+
+        if node is not None:
+            # Remove node's children recursively
+            children = self.find_children(node)
+
+            # Ending condition
+            if children is None:
+                self.node_list.remove(node)
+                return
+
+            for child in children:
+                self.remove_node(child.name)
+
+            # Remove the node from tree
+            self.node_list.remove(node)
+        else:
+            raise ValueError(f"Node with identifier: '{identifier}' not found.")
+
+    def dump_tree(self) -> dict:
+        def build_dict(node: Node) -> dict:
+            children = [
+                build_dict(child)
+                for child in self.node_list
+                if child.parent_id == node.node_id
+            ]
+            return {
+                "node_id": node.node_id,
+                "name": node.name,
+                "level": node.level,
+                "parent_id": node.parent_id,
+                "children": children,
+            }
+
+        return {"category_tree": build_dict(self.root)["children"]}
+
+    def level_count(self) -> int:
+        return max([item.level for item in self.node_list])
+
+
+# This function is needed to create CategorizerOutput with dynamic categories
+def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
+    literal_type = Literal[*allowed_values]
+
+    CategorizerOutput = create_model(
+        "CategorizerOutput",
+        reason=(
+            str,
+            Field(
+                ..., description="Explanation of why the input belongs to the category"
+            ),
+        ),
+        result=(literal_type, Field(..., description="Predicted category label")),
+    )
+
+    return CategorizerOutput
+
+
+class Entity(BaseModel):
+    text: str = Field(description="The exact text of the entity")
+    type: str = Field(description="The type of the entity")
+
+
+class EntityDetectorOutput(BaseModel):
+    result: list[Entity] = Field(description="List of all extracted entities")
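For context, a minimal usage sketch of the CategoryTree helper and the create_dynamic_model factory introduced in this new file. The import path and method names come from the diff above; the category names and field values are made up for illustration.

# Illustrative only: exercises the new models module added in 1.1.14.
from texttools.tools.internals.models import CategoryTree, create_dynamic_model

tree = CategoryTree("topics")
tree.add_node("science", description="Scientific topics")
tree.add_node("physics", parent_name="science")

# dump_tree() nests children under their parents and omits descriptions.
print(tree.dump_tree())
# {'category_tree': [{'node_id': 1, 'name': 'science', 'level': 1, 'parent_id': 0,
#   'children': [{'node_id': 2, 'name': 'physics', 'level': 2, 'parent_id': 1, 'children': []}]}]}

# Build a CategorizerOutput model whose `result` field is constrained to the
# top-level category names (note Literal[*values] requires Python 3.11+).
CategorizerOutput = create_dynamic_model(
    [node.name for node in tree.get_nodes() if node.level == 1]
)
out = CategorizerOutput(reason="The text discusses quantum mechanics.", result="science")
print(out.result)  # -> science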
texttools/tools/internals/sync_operator.py
@@ -1,10 +1,11 @@
-from typing import Any, TypeVar, Type
+from typing import Any, TypeVar, Type
+from collections.abc import Callable
 import logging

 from openai import OpenAI
 from pydantic import BaseModel

-from texttools.tools.internals.
+from texttools.tools.internals.models import ToolOutput
 from texttools.tools.internals.operator_utils import OperatorUtils
 from texttools.tools.internals.formatters import Formatter
 from texttools.tools.internals.prompt_loader import PromptLoader
@@ -51,6 +52,7 @@ class Operator:
         temperature: float,
         logprobs: bool = False,
         top_logprobs: int = 3,
+        priority: int | None = 0,
     ) -> tuple[T, Any]:
         """
         Parses a chat completion using OpenAI's structured output format.
@@ -67,6 +69,9 @@ class Operator:
             request_kwargs["logprobs"] = True
             request_kwargs["top_logprobs"] = top_logprobs

+        if priority:
+            request_kwargs["extra_body"] = {"priority": priority}
+
         completion = self._client.beta.chat.completions.parse(**request_kwargs)
         parsed = completion.choices[0].message.parsed
         return parsed, completion
@@ -87,6 +92,7 @@ class Operator:
         prompt_file: str,
         output_model: Type[T],
         mode: str | None,
+        priority: int | None = 0,
         **extra_kwargs,
     ) -> ToolOutput:
         """
@@ -95,7 +101,6 @@ class Operator:
         prompt_loader = PromptLoader()
         formatter = Formatter()
         output = ToolOutput()
-
         try:
             # Prompt configs contain two keys: main_template and analyze template, both are string
             prompt_configs = prompt_loader.load(
@@ -136,7 +141,7 @@ class Operator:
         messages = formatter.user_merge_format(messages)

         parsed, completion = self._parse_completion(
-            messages, output_model, temperature, logprobs, top_logprobs
+            messages, output_model, temperature, logprobs, top_logprobs, priority
         )

         output.result = parsed.result
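The new priority plumbing is a thin pass-through: the operator forwards the integer to the OpenAI client via extra_body, which the SDK serializes into the request body unchanged. A rough sketch of the resulting call is below; the model name and message list are placeholders, and whether a "priority" field has any effect depends on the serving backend (the upstream OpenAI API does not define it, but some self-hosted OpenAI-compatible servers with priority scheduling may honor it).

# Sketch of the request the operator ends up making (placeholder values).
request_kwargs = {
    "model": "placeholder-model",
    "messages": messages,
    "response_format": output_model,  # a Pydantic model for structured output
    "temperature": temperature,
}

# Note that `if priority:` is falsy for the default of 0, so no extra_body
# is attached unless a non-zero priority is passed explicitly.
if priority:
    request_kwargs["extra_body"] = {"priority": priority}

completion = client.beta.chat.completions.parse(**request_kwargs)
parsed = completion.choices[0].message.parsed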