hamtaa-texttools 1.0.5__py3-none-any.whl → 1.1.16__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (43)
  1. hamtaa_texttools-1.1.16.dist-info/METADATA +255 -0
  2. hamtaa_texttools-1.1.16.dist-info/RECORD +31 -0
  3. texttools/__init__.py +6 -8
  4. texttools/batch/batch_config.py +26 -0
  5. texttools/batch/batch_runner.py +144 -139
  6. texttools/batch/{batch_manager.py → internals/batch_manager.py} +42 -54
  7. texttools/batch/internals/utils.py +16 -0
  8. texttools/prompts/README.md +8 -4
  9. texttools/prompts/categorize.yaml +77 -0
  10. texttools/prompts/detect_entity.yaml +22 -0
  11. texttools/prompts/extract_keywords.yaml +68 -0
  12. texttools/prompts/{question_merger.yaml → merge_questions.yaml} +5 -5
  13. texttools/tools/async_tools.py +804 -0
  14. texttools/tools/internals/async_operator.py +139 -236
  15. texttools/tools/internals/formatters.py +24 -0
  16. texttools/tools/internals/models.py +183 -0
  17. texttools/tools/internals/operator_utils.py +54 -0
  18. texttools/tools/internals/prompt_loader.py +23 -43
  19. texttools/tools/internals/sync_operator.py +201 -0
  20. texttools/tools/sync_tools.py +804 -0
  21. hamtaa_texttools-1.0.5.dist-info/METADATA +0 -192
  22. hamtaa_texttools-1.0.5.dist-info/RECORD +0 -30
  23. texttools/batch/__init__.py +0 -4
  24. texttools/formatters/base_formatter.py +0 -33
  25. texttools/formatters/user_merge_formatter.py +0 -30
  26. texttools/prompts/categorizer.yaml +0 -28
  27. texttools/prompts/keyword_extractor.yaml +0 -18
  28. texttools/tools/__init__.py +0 -4
  29. texttools/tools/async_the_tool.py +0 -277
  30. texttools/tools/internals/operator.py +0 -295
  31. texttools/tools/internals/output_models.py +0 -52
  32. texttools/tools/the_tool.py +0 -501
  33. {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/WHEEL +0 -0
  34. {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/licenses/LICENSE +0 -0
  35. {hamtaa_texttools-1.0.5.dist-info → hamtaa_texttools-1.1.16.dist-info}/top_level.txt +0 -0
  36. /texttools/prompts/{ner_extractor.yaml → extract_entities.yaml} +0 -0
  37. /texttools/prompts/{question_detector.yaml → is_question.yaml} +0 -0
  38. /texttools/prompts/{rewriter.yaml → rewrite.yaml} +0 -0
  39. /texttools/prompts/{custom_tool.yaml → run_custom.yaml} +0 -0
  40. /texttools/prompts/{subject_question_generator.yaml → subject_to_question.yaml} +0 -0
  41. /texttools/prompts/{summarizer.yaml → summarize.yaml} +0 -0
  42. /texttools/prompts/{question_generator.yaml → text_to_question.yaml} +0 -0
  43. /texttools/prompts/{translator.yaml → translate.yaml} +0 -0
texttools/tools/internals/async_operator.py
@@ -1,297 +1,200 @@
-from __future__ import annotations
-
-import json
-import math
-import re
-from typing import Any, Literal, Optional, TypeVar
+from typing import Any, TypeVar, Type
+from collections.abc import Callable
+import logging
 
 from openai import AsyncOpenAI
 from pydantic import BaseModel
 
-from texttools.formatters.user_merge_formatter import (
-    UserMergeFormatter,
-)
+from texttools.tools.internals.models import ToolOutput
+from texttools.tools.internals.operator_utils import OperatorUtils
+from texttools.tools.internals.formatters import Formatter
 from texttools.tools.internals.prompt_loader import PromptLoader
 
 # Base Model type for output models
 T = TypeVar("T", bound=BaseModel)
 
+logger = logging.getLogger("texttools.async_operator")
+
 
 class AsyncOperator:
     """
-    Async version of Operator.
+    Core engine for running text-processing operations with an LLM (Async).
 
-    Behaves like the synchronous Operator but uses AsyncOpenAI and async/await.
+    It wires together:
+    - `PromptLoader` → loads YAML prompt templates.
+    - `UserMergeFormatter` → applies formatting to messages (e.g., merging).
+    - AsyncOpenAI client → executes completions/parsed completions.
     """
 
-    def __init__(
-        self,
-        client: AsyncOpenAI,
-        *,
-        model: str,
-        temperature: float = 0.0,
-        **client_kwargs: Any,
-    ):
-        self.client: AsyncOpenAI = client
-        self.model = model
-        self.temperature = temperature
-        self.client_kwargs = client_kwargs
-
-    def _build_user_message(self, prompt: str) -> dict[str, str]:
-        return {"role": "user", "content": prompt}
-
-    async def _analysis_completion(self, analyze_message: list[dict[str, str]]) -> str:
-        try:
-            completion = await self.client.chat.completions.create(
-                model=self.model,
-                messages=analyze_message,
-                temperature=self.temperature,
-                **self.client_kwargs,
-            )
-            analysis = completion.choices[0].message.content.strip()
-            return analysis
+    def __init__(self, client: AsyncOpenAI, model: str):
+        self._client = client
+        self._model = model
 
-        except Exception as e:
-            print(f"[ERROR] Analysis failed: {e}")
-            raise
-
-    async def _analyze(self, prompt_configs: dict[str, str]) -> str:
+    async def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
+        """
+        Calls OpenAI API for analysis using the configured prompt template.
+        Returns the analyzed content as a string.
+        """
        analyze_prompt = prompt_configs["analyze_template"]
-        analyze_message = [self._build_user_message(analyze_prompt)]
-        analysis = await self._analysis_completion(analyze_message)
-
+        analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
+        completion = await self._client.chat.completions.create(
+            model=self._model,
+            messages=analyze_message,
+            temperature=temperature,
+        )
+        analysis = completion.choices[0].message.content.strip()
         return analysis
 
     async def _parse_completion(
         self,
         message: list[dict[str, str]],
-        output_model: T,
+        output_model: Type[T],
+        temperature: float,
         logprobs: bool = False,
         top_logprobs: int = 3,
-        max_tokens: int | None = None,
+        priority: int | None = 0,
     ) -> tuple[T, Any]:
-        try:
-            request_kwargs = {
-                "model": self.model,
-                "messages": message,
-                "response_format": output_model,
-                "temperature": self.temperature,
-                **self.client_kwargs,
-            }
-
-            if max_tokens is not None:
-                request_kwargs["max_tokens"] = max_tokens
-
-            if logprobs:
-                request_kwargs["logprobs"] = True
-                request_kwargs["top_logprobs"] = top_logprobs
-
-            completion = await self.client.beta.chat.completions.parse(**request_kwargs)
-            parsed = completion.choices[0].message.parsed
-            return parsed, completion
-
-        except Exception as e:
-            print(f"[ERROR] Failed to parse completion: {e}")
-            raise
-
-    def _clean_json_response(self, response: str) -> str:
         """
-        Clean JSON response by removing code block markers and whitespace.
-        Handles cases like:
-        - ```json{"result": "value"}```
+        Parses a chat completion using OpenAI's structured output format.
+        Returns both the parsed object and the raw completion for logprobs.
         """
-        cleaned = response.strip()
-
-        # Remove ```json marker
-        if cleaned.startswith("```json"):
-            cleaned = cleaned[7:]
-
-        # Remove trailing ```
-        if cleaned.endswith("```"):
-            cleaned = cleaned[:-3]
-
-        return cleaned.strip()
-
-    def _convert_to_output_model(self, response_string: str, output_model: T) -> T:
-        """
-        Convert a JSON response string to output model.
-
-        Args:
-            response_string: The JSON string (may contain code block markers)
-            output_model: Your Pydantic output model class (e.g., StrOutput, ListStrOutput)
-
-        Returns:
-            Instance of your output model
-        """
-        try:
-            # Clean the response string
-            cleaned_json = self._clean_json_response(response_string)
-
-            # Fix Python-style booleans
-            cleaned_json = cleaned_json.replace("False", "false").replace(
-                "True", "true"
-            )
-
-            # Convert string to Python dictionary
-            response_dict = json.loads(cleaned_json)
-
-            # Convert dictionary to output model
-            return output_model(**response_dict)
-
-        except json.JSONDecodeError as e:
-            raise ValueError(
-                f"Failed to parse JSON response: {e}\nResponse: {response_string}"
-            )
-        except Exception as e:
-            raise ValueError(f"Failed to convert to output model: {e}")
-
-    async def _vllm_completion(
-        self,
-        message: list[dict[str, str]],
-        output_model: T,
-        logprobs: bool = False,
-        top_logprobs: int = 3,
-        max_tokens: int | None = None,
-    ) -> tuple[T, Any]:
-        try:
-            json_schema = output_model.model_json_schema()
-
-            # Build kwargs dynamically
-            request_kwargs = {
-                "model": self.model,
-                "messages": message,
-                "extra_body": {"guided_json": json_schema},
-                "temperature": self.temperature,
-                **self.client_kwargs,
-            }
-
-            if max_tokens is not None:
-                request_kwargs["max_tokens"] = max_tokens
-
-            if logprobs:
-                request_kwargs["logprobs"] = True
-                request_kwargs["top_logprobs"] = top_logprobs
-
-            completion = await self.client.chat.completions.create(**request_kwargs)
-            response = completion.choices[0].message.content
-
-            # Convert the string response to output model
-            parsed = self._convert_to_output_model(response, output_model)
-
-            return parsed, completion
-
-        except Exception as e:
-            print(f"[ERROR] Failed to get vLLM structured output: {e}")
-            raise
-
-    def _extract_logprobs(self, completion: dict):
-        logprobs_data = []
-        ignore_pattern = re.compile(r'^(result|[\s\[\]\{\}",:]+)$')
-
-        for choice in completion.choices:
-            if not getattr(choice, "logprobs", None):
-                continue
-
-            for logprob_item in choice.logprobs.content:
-                if ignore_pattern.match(logprob_item.token):
-                    continue
-                token_entry = {
-                    "token": logprob_item.token,
-                    "prob": round(math.exp(logprob_item.logprob), 8),
-                    "top_alternatives": [],
-                }
-                for alt in logprob_item.top_logprobs:
-                    if ignore_pattern.match(alt.token):
-                        continue
-                    token_entry["top_alternatives"].append(
-                        {
-                            "token": alt.token,
-                            "prob": round(math.exp(alt.logprob), 8),
-                        }
-                    )
-                logprobs_data.append(token_entry)
-
-        return logprobs_data
+        request_kwargs = {
+            "model": self._model,
+            "messages": message,
+            "response_format": output_model,
+            "temperature": temperature,
+        }
+
+        if logprobs:
+            request_kwargs["logprobs"] = True
+            request_kwargs["top_logprobs"] = top_logprobs
+        if priority:
+            request_kwargs["extra_body"] = {"priority": priority}
+        completion = await self._client.beta.chat.completions.parse(**request_kwargs)
+        parsed = completion.choices[0].message.parsed
+        return parsed, completion
 
     async def run(
         self,
-        input_text: str,
+        # User parameters
+        text: str,
+        with_analysis: bool,
+        output_lang: str | None,
+        user_prompt: str | None,
+        temperature: float,
+        logprobs: bool,
+        top_logprobs: int | None,
+        validator: Callable[[Any], bool] | None,
+        max_validation_retries: int | None,
+        # Internal parameters
         prompt_file: str,
-        output_model: T,
-        with_analysis: bool = False,
-        use_modes: bool = False,
-        mode: str = "",
-        resp_format: Literal["vllm", "parse"] = "parse",
-        output_lang: str | None = None,
-        logprobs: bool = False,
-        top_logprobs: int = 3,
-        max_tokens: int | None = None,
+        output_model: Type[T],
+        mode: str | None,
+        priority: int | None = 0,
         **extra_kwargs,
-    ) -> dict[str, Any]:
+    ) -> ToolOutput:
         """
-        Execute the async LLM pipeline with the given input text.
+        Execute the async LLM pipeline with the given input text. (Async)
         """
         prompt_loader = PromptLoader()
-        formatter = UserMergeFormatter()
+        formatter = Formatter()
+        output = ToolOutput()
 
         try:
-            cleaned_text = input_text.strip()
-
-            # FIXED: Correct parameter order for load
+            # Prompt configs contain two keys, main_template and analyze_template, both strings
            prompt_configs = prompt_loader.load(
-                prompt_file=prompt_file,  # prompt_file
-                text=cleaned_text,  # text
-                mode=mode if use_modes else "",  # mode
+                prompt_file=prompt_file,
+                text=text.strip(),
+                mode=mode,
                 **extra_kwargs,
             )
 
-            messages: list[dict[str, str]] = []
+            messages = []
 
             if with_analysis:
-                analysis = await self._analyze(prompt_configs)
+                analysis = await self._analyze(prompt_configs, temperature)
                 messages.append(
-                    self._build_user_message(f"Based on this analysis: {analysis}")
+                    OperatorUtils.build_user_message(
+                        f"Based on this analysis: {analysis}"
+                    )
                 )
 
             if output_lang:
                 messages.append(
-                    self._build_user_message(
+                    OperatorUtils.build_user_message(
                         f"Respond only in the {output_lang} language."
                     )
                 )
 
-            messages.append(self._build_user_message(prompt_configs["main_template"]))
-            messages = formatter.format(messages)
-
-            if resp_format == "vllm":
-                parsed, completion = await self._vllm_completion(
-                    messages,
-                    output_model,
-                    logprobs,
-                    top_logprobs,
-                    max_tokens,  # Pass max_tokens
-                )
-            elif resp_format == "parse":
-                parsed, completion = await self._parse_completion(
-                    messages,
-                    output_model,
-                    logprobs,
-                    top_logprobs,
-                    max_tokens,  # Pass max_tokens
+            if user_prompt:
+                messages.append(
+                    OperatorUtils.build_user_message(
+                        f"Consider this instruction {user_prompt}"
+                    )
                 )
-            else:
-                raise ValueError(f"Unknown resp_format: {resp_format}")
 
-            results = {"result": parsed.result}
+            messages.append(
+                OperatorUtils.build_user_message(prompt_configs["main_template"])
+            )
+
+            messages = formatter.user_merge_format(messages)
+
+            parsed, completion = await self._parse_completion(
+                messages, output_model, temperature, logprobs, top_logprobs, priority
+            )
+
+            output.result = parsed.result
+
+            # Retry logic if validation fails
+            if validator and not validator(output.result):
+                for attempt in range(max_validation_retries):
+                    logger.warning(
+                        f"Validation failed, retrying for the {attempt + 1} time."
+                    )
+
+                    # Generate new temperature for retry
+                    retry_temperature = OperatorUtils.get_retry_temp(temperature)
+                    try:
+                        parsed, completion = await self._parse_completion(
+                            messages,
+                            output_model,
+                            retry_temperature,
+                            logprobs,
+                            top_logprobs,
+                        )
+
+                        output.result = parsed.result
+
+                        # Check if retry was successful
+                        if validator(output.result):
+                            logger.info(
+                                f"Validation passed on retry attempt {attempt + 1}"
+                            )
+                            break
+                        else:
+                            logger.warning(
+                                f"Validation still failing after retry attempt {attempt + 1}"
+                            )
+
+                    except Exception as e:
+                        logger.error(f"Retry attempt {attempt + 1} failed: {e}")
+                        # Continue to next retry attempt if this one fails
+
+            # Final check after all retries
+            if validator and not validator(output.result):
+                output.errors.append("Validation failed after all retry attempts")
 
             if logprobs:
-                results["logprobs"] = self._extract_logprobs(completion)
+                output.logprobs = OperatorUtils.extract_logprobs(completion)
 
             if with_analysis:
-                results["analysis"] = analysis
+                output.analysis = analysis
+
+            output.process = prompt_file[:-5]
 
-            return results
+            return output
 
         except Exception as e:
-            print(f"[ERROR] Async operation failed: {e}")
-            raise
+            logger.error(f"AsyncTheTool failed: {e}")
+            output.errors.append(str(e))
+            return output
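
To make the new `run` contract concrete, here is a minimal calling sketch. It assumes only what the diff above shows (`AsyncOperator(client, model)`, the `run` parameters, and `StrOutput` from the new `models.py`); the base URL, model name, and validator are hypothetical placeholders, and `summarize.yaml` is taken from the renamed prompt files listed above, assuming it needs no extra template fields:

import asyncio

from openai import AsyncOpenAI

from texttools.tools.internals.async_operator import AsyncOperator
from texttools.tools.internals.models import StrOutput


async def main() -> None:
    # Hypothetical endpoint and model name; substitute your own deployment.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    operator = AsyncOperator(client, model="my-model")

    output = await operator.run(
        text="Quantum computing uses qubits.",
        with_analysis=False,           # skip the extra analysis round-trip
        output_lang=None,
        user_prompt=None,
        temperature=0.0,
        logprobs=False,
        top_logprobs=None,
        validator=lambda r: bool(r),   # falsy results trigger the retry loop
        max_validation_retries=2,
        prompt_file="summarize.yaml",  # resolved by PromptLoader
        output_model=StrOutput,
        mode=None,
    )
    # run() no longer raises: failures are collected in output.errors.
    print(output.process, output.result, output.errors)


asyncio.run(main())

Note the error-handling change this implies: the old operator printed and re-raised, while the new one logs, appends to `ToolOutput.errors`, and still returns a `ToolOutput`, so callers must inspect `errors` rather than rely on exceptions.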
texttools/tools/internals/formatters.py
@@ -0,0 +1,24 @@
+class Formatter:
+    @staticmethod
+    def user_merge_format(messages: list[dict[str, str]]) -> list[dict[str, str]]:
+        """
+        Merges consecutive user messages into a single message, separated by newlines.
+
+        This is useful for condensing a multi-turn user input into a single
+        message for the LLM. Assistant and system messages are left unchanged and
+        act as separators between user message groups.
+        """
+        merged: list[dict[str, str]] = []
+
+        for message in messages:
+            role, content = message["role"], message["content"].strip()
+
+            # Merge with previous user turn
+            if merged and role == "user" and merged[-1]["role"] == "user":
+                merged[-1]["content"] += "\n" + content
+
+            # Otherwise, start a new turn
+            else:
+                merged.append({"role": role, "content": content})
+
+        return merged
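
The merge rule is easy to verify in isolation; a small sketch with invented message contents:

from texttools.tools.internals.formatters import Formatter

messages = [
    {"role": "user", "content": "Based on this analysis: ..."},
    {"role": "user", "content": "Respond only in the English language."},
    {"role": "user", "content": "Summarize the following text: ..."},
]

# Three consecutive user turns collapse into one newline-joined message,
# which is exactly how AsyncOperator.run() prepares its prompt stack.
merged = Formatter.user_merge_format(messages)
assert len(merged) == 1
assert merged[0]["content"].count("\n") == 2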
texttools/tools/internals/models.py
@@ -0,0 +1,183 @@
+from datetime import datetime
+from typing import Type, Any, Literal
+
+from pydantic import BaseModel, Field, create_model
+
+
+class ToolOutput(BaseModel):
+    result: Any = None
+    analysis: str = ""
+    logprobs: list[dict[str, Any]] = []
+    process: str = ""
+    processed_at: datetime = datetime.now()
+    execution_time: float = -1.0
+    errors: list[str] = []
+
+    def __repr__(self) -> str:
+        return f"ToolOutput(process='{self.process}', result_type='{type(self.result)}', result='{self.result}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}', processed_at='{self.processed_at}', execution_time='{self.execution_time}'"
+
+
+class StrOutput(BaseModel):
+    result: str = Field(..., description="The output string")
+
+
+class BoolOutput(BaseModel):
+    result: bool = Field(
+        ..., description="Boolean indicating the output state", example=True
+    )
+
+
+class ListStrOutput(BaseModel):
+    result: list[str] = Field(
+        ..., description="The output list of strings", example=["text_1", "text_2"]
+    )
+
+
+class ListDictStrStrOutput(BaseModel):
+    result: list[dict[str, str]] = Field(
+        ...,
+        description="List of dictionaries containing string key-value pairs",
+        example=[{"text": "Mohammad", "type": "PER"}],
+    )
+
+
+class ReasonListStrOutput(BaseModel):
+    reason: str = Field(..., description="Thinking process that led to the output")
+    result: list[str] = Field(..., description="The output list of strings")
+
+
+class Node(BaseModel):
+    node_id: int
+    name: str
+    level: int
+    parent_id: int | None
+    description: str = "No description provided"
+
+
+class CategoryTree:
+    def __init__(self, tree_name):
+        self.root = Node(node_id=0, name=tree_name, level=0, parent_id=None)
+        self.all_nodes: list[Node] = [self.root]
+        self.new_id = 1
+
+    def add_node(
+        self,
+        node_name: str,
+        parent_name: str | None = None,
+        description: str | None = None,
+    ) -> None:
+        if self.find_node(node_name):
+            raise ValueError(f"{node_name} has been chosen for another category before")
+
+        if parent_name:
+            parent_node = self.find_node(parent_name)
+            if parent_node is None:
+                raise ValueError(f"Parent category '{parent_name}' not found")
+            parent_id = parent_node.node_id
+            level = parent_node.level + 1
+        else:
+            level = 1
+            parent_id = 0
+
+        node_data = {
+            "node_id": self.new_id,
+            "name": node_name,
+            "level": level,
+            "parent_id": parent_id,
+        }
+
+        if description is not None:
+            node_data["description"] = description
+
+        self.all_nodes.append(Node(**node_data))
+        self.new_id += 1
+
+    def get_nodes(self) -> list[Node]:
+        return self.all_nodes
+
+    def get_level_count(self) -> int:
+        return max([item.level for item in self.all_nodes])
+
+    def find_node(self, identifier: int | str) -> Node | None:
+        if isinstance(identifier, str):
+            for node in self.get_nodes():
+                if node.name == identifier:
+                    return node
+            return None
+        elif isinstance(identifier, int):
+            for node in self.get_nodes():
+                if node.node_id == identifier:
+                    return node
+            return None
+        else:
+            return None
+
+    def find_children(self, parent_node: Node) -> list[Node] | None:
+        children = [
+            node for node in self.get_nodes() if parent_node.node_id == node.parent_id
+        ]
+        return children if children else None
+
+    def remove_node(self, identifier: int | str) -> None:
+        node = self.find_node(identifier)
+
+        if node is not None:
+            # Remove node's children recursively
+            children = self.find_children(node)
+
+            # Ending condition
+            if children is None:
+                self.all_nodes.remove(node)
+                return
+
+            for child in children:
+                self.remove_node(child.name)
+
+            # Remove the node from tree
+            self.all_nodes.remove(node)
+        else:
+            raise ValueError(f"Node with identifier: '{identifier}' not found.")
+
+    def dump_tree(self) -> dict:
+        def build_dict(node: Node) -> dict:
+            children = [
+                build_dict(child)
+                for child in self.all_nodes
+                if child.parent_id == node.node_id
+            ]
+            return {
+                "node_id": node.node_id,
+                "name": node.name,
+                "level": node.level,
+                "parent_id": node.parent_id,
+                "children": children,
+            }
+
+        return {"category_tree": build_dict(self.root)["children"]}
+
+
+# This function is needed to create CategorizerOutput with dynamic categories
+def create_dynamic_model(allowed_values: list[str]) -> Type[BaseModel]:
+    literal_type = Literal[*allowed_values]
+
+    CategorizerOutput = create_model(
+        "CategorizerOutput",
+        reason=(
+            str,
+            Field(
+                ..., description="Explanation of why the input belongs to the category"
+            ),
+        ),
+        result=(literal_type, Field(..., description="Predicted category label")),
+    )
+
+    return CategorizerOutput
+
+
+class Entity(BaseModel):
+    text: str = Field(description="The exact text of the entity")
+    entity_type: str = Field(description="The type of the entity")
+
+
+class EntityDetectorOutput(BaseModel):
+    result: list[Entity] = Field(description="List of all extracted entities")
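
A short sketch of how these models compose, using only names defined in this file; the category names are invented, and note that `Literal[*allowed_values]` requires Python 3.11 or newer:

from texttools.tools.internals.models import CategoryTree, create_dynamic_model

# Build a small two-level tree; duplicate names raise ValueError.
tree = CategoryTree("topics")
tree.add_node("science")
tree.add_node("physics", parent_name="science")
tree.add_node("sports")

print(tree.get_level_count())   # 2
print(tree.dump_tree())         # {"category_tree": [...]} nested by parent_id

# Constrain a categorizer's `result` field to an allowed label set.
CategorizerOutput = create_dynamic_model(["physics", "sports"])
out = CategorizerOutput(reason="mentions quarks", result="physics")
print(out.result)               # "physics"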