hamtaa-texttools 1.1.12__py3-none-any.whl → 1.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,181 @@
1
+ from typing import Type, Any, Literal
2
+
3
+ from pydantic import BaseModel, Field, create_model
4
+
5
+
6
class ToolOutput(BaseModel):
    """Generic result envelope returned by every tool.

    Attributes:
        result: Tool-specific payload; its type depends on the tool.
        analysis: Free-text analysis produced alongside the result, if any.
        logprobs: Token log-probability records from the LLM, if requested.
        errors: Error messages accumulated while running the tool.
    """

    result: Any = None
    analysis: str = ""
    # default_factory is the pydantic idiom for mutable defaults — each
    # instance gets its own fresh list.
    logprobs: list[dict[str, Any]] = Field(default_factory=list)
    errors: list[str] = Field(default_factory=list)

    def __repr__(self) -> str:
        # BUG FIX: the original f-string was missing the closing ')'.
        return (
            f"ToolOutput(result_type='{type(self.result)}', result='{self.result}', "
            f"analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}')"
        )
14
+
15
+
16
class StrOutput(BaseModel):
    """Output model carrying a single string result."""

    result: str = Field(
        ...,
        description="The output string",
    )
18
+
19
+
20
class BoolOutput(BaseModel):
    """Output model carrying a single boolean result."""

    result: bool = Field(..., description="Boolean indicating the output state", example=True)
24
+
25
+
26
class ListStrOutput(BaseModel):
    """Output model carrying a list of strings."""

    result: list[str] = Field(..., description="The output list of strings", example=["text_1", "text_2"])
30
+
31
+
32
class ListDictStrStrOutput(BaseModel):
    """Output model carrying a list of flat str→str dictionaries."""

    result: list[dict[str, str]] = Field(
        ...,
        description="List of dictionaries containing string key-value pairs",
        example=[{"text": "Mohammad", "type": "PER"}],
    )
38
+
39
+
40
class ReasonListStrOutput(BaseModel):
    """Output model pairing a reasoning trace with a list-of-strings result."""

    # The model's chain of reasoning comes first so it is generated before
    # the final answer in structured-output decoding.
    reason: str = Field(..., description="Thinking process that led to the output")
    result: list[str] = Field(..., description="The output list of strings")
43
+
44
+
45
class Node(BaseModel):
    """A single category node inside a CategoryTree.

    node_id 0 is reserved for the synthetic root; parent_id is None only
    for that root node.
    """

    node_id: int
    name: str
    level: int
    parent_id: int | None
    description: str = "No description provided"
51
+
52
+
53
class CategoryTree:
    """A mutable tree of named category nodes.

    A synthetic root node (id 0, level 0) is created on construction;
    user categories hang below it starting at level 1. Node names are
    unique across the whole tree and double as lookup keys.
    """

    def __init__(self, tree_name):
        self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
        self.node_list: list[Node] = [self.root]
        self.new_id = 1

    def add_node(
        self,
        node_name: str,
        parent_name: str | None = None,
        description: str | None = None,
    ) -> None:
        """Add a category under *parent_name*, or under the root when omitted.

        Raises ValueError when the name is already taken or the parent
        does not exist.
        """
        if self.find_node(node_name):
            raise ValueError(f"{node_name} has been chosen for another category before")

        if parent_name:
            parent = self.find_node(parent_name)
            if parent is None:
                raise ValueError(f"Parent category '{parent_name}' not found")
            parent_id, level = parent.node_id, parent.level + 1
        else:
            # No parent given: attach directly beneath the root.
            parent_id, level = 0, 1

        fields = {
            "node_id": self.new_id,
            "name": node_name,
            "level": level,
            "parent_id": parent_id,
        }
        # Omit the key entirely so Node's class-level default applies.
        if description is not None:
            fields["description"] = description

        self.node_list.append(Node(**fields))
        self.new_id += 1

    def get_nodes(self) -> list[Node]:
        """Return the flat list of all nodes, root included."""
        return self.node_list

    def find_node(self, identifier: int | str) -> Node | None:
        """Look a node up by name (str) or by id (int); None when absent."""
        if isinstance(identifier, str):
            return next((n for n in self.get_nodes() if n.name == identifier), None)
        if isinstance(identifier, int):
            return next((n for n in self.get_nodes() if n.node_id == identifier), None)
        return None

    def find_children(self, parent_node: Node) -> list[Node] | None:
        """Return the direct children of *parent_node*, or None when it has none."""
        children = [n for n in self.get_nodes() if n.parent_id == parent_node.node_id]
        return children or None

    def remove_node(self, identifier: int | str) -> None:
        """Remove a node and, recursively, its entire subtree.

        Raises ValueError when no node matches *identifier*.
        """
        node = self.find_node(identifier)
        if node is None:
            raise ValueError(f"Node with identifier: '{identifier}' not found.")

        # Depth-first: delete the subtree before the node itself.
        children = self.find_children(node)
        if children is not None:
            for child in children:
                self.remove_node(child.name)
        self.node_list.remove(node)

    def dump_tree(self) -> dict:
        """Serialize the tree to nested dicts; the synthetic root is omitted."""

        def serialize(node: Node) -> dict:
            return {
                "node_id": node.node_id,
                "name": node.name,
                "level": node.level,
                "parent_id": node.parent_id,
                "children": [
                    serialize(child)
                    for child in self.node_list
                    if child.parent_id == node.node_id
                ],
            }

        return {"category_tree": serialize(self.root)["children"]}

    def level_count(self) -> int:
        """Return the depth of the deepest node (the root sits at level 0)."""
        return max(node.level for node in self.node_list)
155
+
156
+
157
# This function is needed to create CategorizerOutput with dynamic categories
def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
    """Build a CategorizerOutput model whose `result` field is restricted
    to *allowed_values*.

    Args:
        allowed_values: Non-empty list of category labels the classifier
            may predict.

    Returns:
        A dynamically created pydantic model class named "CategorizerOutput"
        with a `reason` (str) field and a `result` field typed as a Literal
        over the given labels.

    Raises:
        ValueError: If *allowed_values* is empty — Literal requires at
            least one parameter, and an empty category set is meaningless.
    """
    if not allowed_values:
        raise ValueError("allowed_values must contain at least one category label")

    # Subscripting Literal with a tuple is equivalent to Literal[v1, v2, ...]
    # and, unlike the starred form `Literal[*allowed_values]`, also works on
    # Python 3.10 (the rest of this module only needs 3.10's `X | None`).
    literal_type = Literal[tuple(allowed_values)]

    return create_model(
        "CategorizerOutput",
        reason=(
            str,
            Field(
                ..., description="Explanation of why the input belongs to the category"
            ),
        ),
        result=(literal_type, Field(..., description="Predicted category label")),
    )
173
+
174
+
175
class Entity(BaseModel):
    """One extracted entity: its surface text and its type label."""

    text: str = Field(description="The exact text of the entity")
    type: str = Field(description="The type of the entity")
178
+
179
+
180
class EntityDetectorOutput(BaseModel):
    """Output model for the entity detector: every entity found in the input."""

    result: list[Entity] = Field(
        description="List of all extracted entities",
    )
@@ -1,11 +1,13 @@
1
- from typing import Any, TypeVar, Type, Callable
1
+ from typing import Any, TypeVar, Type
2
+ from collections.abc import Callable
2
3
  import logging
3
4
 
4
5
  from openai import OpenAI
5
6
  from pydantic import BaseModel
6
7
 
7
- from texttools.tools.internals.output_models import ToolOutput
8
+ from texttools.tools.internals.models import ToolOutput
8
9
  from texttools.tools.internals.operator_utils import OperatorUtils
10
+ from texttools.tools.internals.formatters import Formatter
9
11
  from texttools.tools.internals.prompt_loader import PromptLoader
10
12
 
11
13
  # Base Model type for output models
@@ -50,6 +52,7 @@ class Operator:
50
52
  temperature: float,
51
53
  logprobs: bool = False,
52
54
  top_logprobs: int = 3,
55
+ priority: int | None = 0,
53
56
  ) -> tuple[T, Any]:
54
57
  """
55
58
  Parses a chat completion using OpenAI's structured output format.
@@ -66,6 +69,9 @@ class Operator:
66
69
  request_kwargs["logprobs"] = True
67
70
  request_kwargs["top_logprobs"] = top_logprobs
68
71
 
72
+ if priority:
73
+ request_kwargs["extra_body"] = {"priority": priority}
74
+
69
75
  completion = self._client.beta.chat.completions.parse(**request_kwargs)
70
76
  parsed = completion.choices[0].message.parsed
71
77
  return parsed, completion
@@ -86,14 +92,15 @@ class Operator:
86
92
  prompt_file: str,
87
93
  output_model: Type[T],
88
94
  mode: str | None,
95
+ priority: int | None = 0,
89
96
  **extra_kwargs,
90
97
  ) -> ToolOutput:
91
98
  """
92
99
  Execute the LLM pipeline with the given input text.
93
100
  """
94
101
  prompt_loader = PromptLoader()
102
+ formatter = Formatter()
95
103
  output = ToolOutput()
96
-
97
104
  try:
98
105
  # Prompt configs contain two keys: main_template and analyze template, both are string
99
106
  prompt_configs = prompt_loader.load(
@@ -131,8 +138,10 @@ class Operator:
131
138
  OperatorUtils.build_user_message(prompt_configs["main_template"])
132
139
  )
133
140
 
141
+ messages = formatter.user_merge_format(messages)
142
+
134
143
  parsed, completion = self._parse_completion(
135
- messages, output_model, temperature, logprobs, top_logprobs
144
+ messages, output_model, temperature, logprobs, top_logprobs, priority
136
145
  )
137
146
 
138
147
  output.result = parsed.result