hamtaa-texttools 1.1.20__py3-none-any.whl → 1.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/METADATA +8 -27
  2. hamtaa_texttools-1.1.21.dist-info/RECORD +32 -0
  3. texttools/batch/batch_config.py +14 -1
  4. texttools/batch/batch_runner.py +1 -1
  5. texttools/internals/async_operator.py +45 -79
  6. texttools/internals/models.py +74 -105
  7. texttools/internals/operator_utils.py +2 -26
  8. texttools/internals/prompt_loader.py +3 -20
  9. texttools/internals/sync_operator.py +44 -78
  10. texttools/prompts/README.md +2 -2
  11. texttools/prompts/categorize.yaml +35 -77
  12. texttools/prompts/check_fact.yaml +2 -2
  13. texttools/prompts/extract_entities.yaml +2 -2
  14. texttools/prompts/extract_keywords.yaml +6 -6
  15. texttools/prompts/is_question.yaml +2 -2
  16. texttools/prompts/merge_questions.yaml +4 -4
  17. texttools/prompts/propositionize.yaml +2 -2
  18. texttools/prompts/rewrite.yaml +6 -6
  19. texttools/prompts/run_custom.yaml +1 -1
  20. texttools/prompts/subject_to_question.yaml +2 -2
  21. texttools/prompts/summarize.yaml +2 -2
  22. texttools/prompts/text_to_question.yaml +2 -2
  23. texttools/prompts/translate.yaml +2 -2
  24. texttools/tools/async_tools.py +393 -485
  25. texttools/tools/sync_tools.py +394 -486
  26. hamtaa_texttools-1.1.20.dist-info/RECORD +0 -33
  27. texttools/batch/internals/utils.py +0 -13
  28. {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/WHEEL +0 -0
  29. {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/licenses/LICENSE +0 -0
  30. {hamtaa_texttools-1.1.20.dist-info → hamtaa_texttools-1.1.21.dist-info}/top_level.txt +0 -0
  31. /texttools/batch/{internals/batch_manager.py → batch_manager.py} +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hamtaa-texttools
3
- Version: 1.1.20
3
+ Version: 1.1.21
4
4
  Summary: A high-level NLP toolkit built on top of modern LLMs.
5
5
  Author-email: Tohidi <the.mohammad.tohidi@gmail.com>, Montazer <montazerh82@gmail.com>, Givechi <mohamad.m.givechi@gmail.com>, MoosaviNejad <erfanmoosavi84@gmail.com>, Zareshahi <a.zareshahi1377@gmail.com>
6
6
  License: MIT License
@@ -50,7 +50,7 @@ It provides ready-to-use utilities for **translation, question detection, keywor
50
50
  TextTools provides a rich collection of high-level NLP utilities,
51
51
  Each tool is designed to work with structured outputs (JSON / Pydantic).
52
52
 
53
- - **`categorize()`** - Classifies text into given categories (You have to create a category tree)
53
+ - **`categorize()`** - Classifies text into given categories
54
54
  - **`extract_keywords()`** - Extracts keywords from text
55
55
  - **`extract_entities()`** - Named Entity Recognition (NER) system
56
56
  - **`is_question()`** - Binary detection of whether input is a question
@@ -61,7 +61,7 @@ Each tool is designed to work with structured outputs (JSON / Pydantic).
61
61
  - **`summarize()`** - Text summarization
62
62
  - **`translate()`** - Text translation between languages
63
63
  - **`propositionize()`** - Convert text into atomic, independent, meaningful sentences
64
- - **`check_fact()`** - Check a statement is relevant to source text or not
64
+ - **`check_fact()`** - Check whether a statement is relevant to the source text
65
65
  - **`run_custom()`** - Allows users to define a custom tool with an arbitrary BaseModel
66
66
 
67
67
  ---
@@ -125,11 +125,12 @@ TextTools provides several optional flags to customize LLM behavior:
125
125
  Every tool of `TextTools` returns a `ToolOutput` object which is a BaseModel with attributes:
126
126
  - **`result: Any`** → The output of LLM
127
127
  - **`analysis: str`** → The reasoning step before generating the final output
128
- - **`logprobs: list`** → Token-level probabilities for the generated output
129
- - **`process: str`** → The tool name which processed the input
130
- - **`processed_at: datetime`** → The process time
131
- - **`execution_time: float`** → The execution time (seconds)
128
+ - **`logprobs: list`** → Token-level probabilities for the generated output
132
129
  - **`errors: list[str]`** → Any errors that occurred while calling the LLM
130
+ - **`metadata: ToolOutputMetadata`** → Metadata about the tool run, with attributes:
131
+ - **`tool_name: str`** → The tool name which processed the input
132
+ - **`processed_at: datetime`** → The process time
133
+ - **`execution_time: float`** → The execution time (seconds)
133
134
 
134
135
  **Note:** You can use `repr(ToolOutput)` to see details of your ToolOutput.
135
136
 
@@ -224,26 +225,6 @@ Use **TextTools** when you need to:
224
225
 
225
226
  ---
226
227
 
227
- ## 🔍 Logging
228
-
229
- TextTools uses Python's standard `logging` module. The library's default logger level is `WARNING`, so if you want to modify it, follow instructions:
230
-
231
-
232
- ```python
233
- import logging
234
-
235
- # Default: warnings and errors only
236
- logging.basicConfig(level=logging.WARNING)
237
-
238
- # Debug everything (verbose)
239
- logging.basicConfig(level=logging.DEBUG)
240
-
241
- # Complete silence
242
- logging.basicConfig(level=logging.CRITICAL)
243
- ```
244
-
245
- ---
246
-
247
228
  ## 📚 Batch Processing
248
229
 
249
230
  Process large datasets efficiently using OpenAI's batch API.
@@ -0,0 +1,32 @@
1
+ hamtaa_texttools-1.1.21.dist-info/licenses/LICENSE,sha256=Hb2YOBKy2MJQLnyLrX37B4ZVuac8eaIcE71SvVIMOLg,1082
2
+ texttools/__init__.py,sha256=CmCS9dEvO6061GiJ8A7gD3UAhCWHTkaID9q3Krlyq_o,311
3
+ texttools/batch/batch_config.py,sha256=scWYQBDuaTj8-b2x_a33Zu-zxm7eqEf5FFoquD-Sv94,1029
4
+ texttools/batch/batch_manager.py,sha256=6HfsexU0PHGGBH7HKReZ-CQxaQI9DXYKAPsFXxovb_I,8740
5
+ texttools/batch/batch_runner.py,sha256=fmoq7yxtEdvfLbEhcx95ma-lgrL-ZdI2EgxmEfVcKtE,10016
6
+ texttools/internals/async_operator.py,sha256=sKMYEy7jEcsXpwnBkA18PFubkM-TXZrBH3QwF7l-wSg,7054
7
+ texttools/internals/exceptions.py,sha256=h_yp_5i_5IfmqTBQ4S6ZOISrrliJBQ3HTEAjwJXrplk,495
8
+ texttools/internals/models.py,sha256=9uoCAe2TLrSzyS9lMJja5orPAYaCvVL1zoCb6FNdkfs,4541
9
+ texttools/internals/operator_utils.py,sha256=eLY2OjYQ3jT-50nx3I8gzuVzgGpMi52f5oB3cnFyxko,1864
10
+ texttools/internals/prompt_loader.py,sha256=yYXDD4YYG2zohGPAmvZwmv5f6xV_RSl5yOrObTh9w7I,3352
11
+ texttools/internals/sync_operator.py,sha256=IG3CXfGmv4PdFlAQ4AZcKuBAqPJdkIAK4mVw77zLbqI,6959
12
+ texttools/internals/text_to_chunks.py,sha256=vY3odhgCZK4E44k_SGlLoSiKkdN0ib6-lQAsPcplAHA,3843
13
+ texttools/prompts/README.md,sha256=ztajRJcmFLhyrUF0_qmOXaCwGsTGCFabfMjch2LAJG0,1375
14
+ texttools/prompts/categorize.yaml,sha256=016b1uGtbKXEwB8_2_bBgVuUelBlu_rgT85XK_c3Yv0,1219
15
+ texttools/prompts/check_fact.yaml,sha256=gQqacCXqUEx3u2FRwhFSZHvhyWGwsYuJd1nIJyhpu7Q,700
16
+ texttools/prompts/extract_entities.yaml,sha256=DN8lZjvzCjotODnHFkWIAxFvmVvoeSs-hDKdN1L6bec,608
17
+ texttools/prompts/extract_keywords.yaml,sha256=GoeApi9SUCLZgs18H2-2BxZiKQ3lHptMPesgq3cluqU,3171
18
+ texttools/prompts/is_question.yaml,sha256=w5qF-z05h62YVs-0x2b2ySlHDKIhukFC9pibnvNM0vc,469
19
+ texttools/prompts/merge_questions.yaml,sha256=f6bHEx54jJ8hnb8iDBUCxXeGdGwRFmuu7vOkVWdaIkM,1788
20
+ texttools/prompts/propositionize.yaml,sha256=agZKQY-NmeJD86DGjmd-paIuazf82bczIGadgzSP5Vs,1378
21
+ texttools/prompts/rewrite.yaml,sha256=h6x8aXcW8oRxEbp466eak0y-LCkUOKf-mJ-vNVp5j5M,5386
22
+ texttools/prompts/run_custom.yaml,sha256=IETY9H0wPGWIIzcnupfbwwKQblwZrbYAxB754W9MhgU,125
23
+ texttools/prompts/subject_to_question.yaml,sha256=TfVmZ6gDgaHRqJWCVkFlKpuJczpMvJTo4XLWPaq5zic,1145
24
+ texttools/prompts/summarize.yaml,sha256=CKx4vjhHbGus1TdjDz_oc0bNEQtq7zfHsZkV2WeYHDU,457
25
+ texttools/prompts/text_to_question.yaml,sha256=mnArBoYu7gpGHriaU2-Aw5SixB2ZIgoHMt99PnTPKD0,1003
26
+ texttools/prompts/translate.yaml,sha256=ew9RERAVSzg0cvxAinNwTSFIaOIjdwIsekbUsgAuNgo,632
27
+ texttools/tools/async_tools.py,sha256=VU3cqqCPILsyjRiG84w8kCw3iDSuFbI6S3VjExXZwFQ,44635
28
+ texttools/tools/sync_tools.py,sha256=2cqcosMYR6LHuYw32WFR-drvqQ-t7Q9_2rUBDOeYzho,44441
29
+ hamtaa_texttools-1.1.21.dist-info/METADATA,sha256=lExdE6uMFSs_wqUSElOyktjpHpZx4RY-cUH6azF-IYA,10183
30
+ hamtaa_texttools-1.1.21.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ hamtaa_texttools-1.1.21.dist-info/top_level.txt,sha256=5Mh0jIxxZ5rOXHGJ6Mp-JPKviywwN0MYuH0xk5bEWqE,10
32
+ hamtaa_texttools-1.1.21.dist-info/RECORD,,
@@ -1,7 +1,20 @@
1
1
  from dataclasses import dataclass
2
2
  from collections.abc import Callable
3
3
 
4
- from texttools.batch.internals.utils import import_data, export_data
4
+
5
+ def export_data(data) -> list[dict[str, str]]:
6
+ """
7
+ Produces a structure of the following form from an initial data structure:
8
+ [{"id": str, "text": str},...]
9
+ """
10
+ return data
11
+
12
+
13
+ def import_data(data) -> object:
14
+ """
15
+ Takes the output and adds and aggregates it to the original structure.
16
+ """
17
+ return data
5
18
 
6
19
 
7
20
  @dataclass
@@ -9,7 +9,7 @@ from dotenv import load_dotenv
9
9
  from openai import OpenAI
10
10
  from pydantic import BaseModel
11
11
 
12
- from texttools.batch.internals.batch_manager import BatchManager
12
+ from texttools.batch.batch_manager import BatchManager
13
13
  from texttools.batch.batch_config import BatchConfig
14
14
  from texttools.internals.models import Str
15
15
  from texttools.internals.exceptions import TextToolsError, ConfigurationError
@@ -1,11 +1,10 @@
1
1
  from typing import TypeVar, Type
2
2
  from collections.abc import Callable
3
- import logging
4
3
 
5
4
  from openai import AsyncOpenAI
6
5
  from pydantic import BaseModel
7
6
 
8
- from texttools.internals.models import ToolOutput
7
+ from texttools.internals.models import OperatorOutput
9
8
  from texttools.internals.operator_utils import OperatorUtils
10
9
  from texttools.internals.prompt_loader import PromptLoader
11
10
  from texttools.internals.exceptions import (
@@ -18,35 +17,23 @@ from texttools.internals.exceptions import (
18
17
  # Base Model type for output models
19
18
  T = TypeVar("T", bound=BaseModel)
20
19
 
21
- logger = logging.getLogger("texttools.async_operator")
22
-
23
20
 
24
21
  class AsyncOperator:
25
22
  """
26
- Core engine for running text-processing operations with an LLM (Async).
27
-
28
- It wires together:
29
- - `PromptLoader` → loads YAML prompt templates.
30
- - `UserMergeFormatter` → applies formatting to messages (e.g., merging).
31
- - AsyncOpenAI client → executes completions/parsed completions.
23
+ Core engine for running text-processing operations with an LLM.
32
24
  """
33
25
 
34
26
  def __init__(self, client: AsyncOpenAI, model: str):
35
27
  self._client = client
36
28
  self._model = model
37
29
 
38
- async def _analyze(self, prompt_configs: dict[str, str], temperature: float) -> str:
39
- """
40
- Calls OpenAI API for analysis using the configured prompt template.
41
- Returns the analyzed content as a string.
42
- """
30
+ async def _analyze_completion(self, analyze_prompt: str, temperature: float) -> str:
43
31
  try:
44
- analyze_prompt = prompt_configs["analyze_template"]
45
-
46
32
  if not analyze_prompt:
47
33
  raise PromptError("Analyze template is empty")
48
34
 
49
- analyze_message = [OperatorUtils.build_user_message(analyze_prompt)]
35
+ analyze_message = OperatorUtils.build_user_message(analyze_prompt)
36
+
50
37
  completion = await self._client.chat.completions.create(
51
38
  model=self._model,
52
39
  messages=analyze_message,
@@ -61,7 +48,7 @@ class AsyncOperator:
61
48
  if not analysis:
62
49
  raise LLMError("Empty analysis response")
63
50
 
64
- return analysis.strip()
51
+ return analysis
65
52
 
66
53
  except Exception as e:
67
54
  if isinstance(e, (PromptError, LLMError)):
@@ -70,21 +57,23 @@ class AsyncOperator:
70
57
 
71
58
  async def _parse_completion(
72
59
  self,
73
- message: list[dict[str, str]],
60
+ main_prompt: str,
74
61
  output_model: Type[T],
75
62
  temperature: float,
76
- logprobs: bool = False,
77
- top_logprobs: int = 3,
78
- priority: int | None = 0,
63
+ logprobs: bool,
64
+ top_logprobs: int,
65
+ priority: int,
79
66
  ) -> tuple[T, object]:
80
67
  """
81
68
  Parses a chat completion using OpenAI's structured output format.
82
69
  Returns both the parsed object and the raw completion for logprobs.
83
70
  """
84
71
  try:
72
+ main_message = OperatorUtils.build_user_message(main_prompt)
73
+
85
74
  request_kwargs = {
86
75
  "model": self._model,
87
- "messages": message,
76
+ "messages": main_message,
88
77
  "response_format": output_model,
89
78
  "temperature": temperature,
90
79
  }
@@ -92,8 +81,10 @@ class AsyncOperator:
92
81
  if logprobs:
93
82
  request_kwargs["logprobs"] = True
94
83
  request_kwargs["top_logprobs"] = top_logprobs
84
+
95
85
  if priority:
96
86
  request_kwargs["extra_body"] = {"priority": priority}
87
+
97
88
  completion = await self._client.beta.chat.completions.parse(
98
89
  **request_kwargs
99
90
  )
@@ -122,24 +113,22 @@ class AsyncOperator:
122
113
  user_prompt: str | None,
123
114
  temperature: float,
124
115
  logprobs: bool,
125
- top_logprobs: int | None,
116
+ top_logprobs: int,
126
117
  validator: Callable[[object], bool] | None,
127
118
  max_validation_retries: int | None,
119
+ priority: int,
128
120
  # Internal parameters
129
121
  prompt_file: str,
130
122
  output_model: Type[T],
131
123
  mode: str | None,
132
- priority: int | None = 0,
133
124
  **extra_kwargs,
134
- ) -> ToolOutput:
125
+ ) -> OperatorOutput:
135
126
  """
136
- Execute the LLM pipeline with the given input text. (Async)
127
+ Execute the LLM pipeline with the given input text. (Async)
137
128
  """
138
129
  try:
139
130
  prompt_loader = PromptLoader()
140
- output = ToolOutput()
141
131
 
142
- # Prompt configs contain two keys: main_template and analyze template, both are string
143
132
  prompt_configs = prompt_loader.load(
144
133
  prompt_file=prompt_file,
145
134
  text=text.strip(),
@@ -147,47 +136,32 @@ class AsyncOperator:
147
136
  **extra_kwargs,
148
137
  )
149
138
 
150
- messages = []
139
+ main_prompt = ""
140
+ analysis = ""
151
141
 
152
142
  if with_analysis:
153
- analysis = await self._analyze(prompt_configs, temperature)
154
- messages.append(
155
- OperatorUtils.build_user_message(
156
- f"Based on this analysis: {analysis}"
157
- )
143
+ analysis = await self._analyze_completion(
144
+ prompt_configs["analyze_template"], temperature
158
145
  )
146
+ main_prompt += f"Based on this analysis:\n{analysis}\n"
159
147
 
160
148
  if output_lang:
161
- messages.append(
162
- OperatorUtils.build_user_message(
163
- f"Respond only in the {output_lang} language."
164
- )
165
- )
149
+ main_prompt += f"Respond only in the {output_lang} language.\n"
166
150
 
167
151
  if user_prompt:
168
- messages.append(
169
- OperatorUtils.build_user_message(
170
- f"Consider this instruction {user_prompt}"
171
- )
172
- )
173
-
174
- messages.append(
175
- OperatorUtils.build_user_message(prompt_configs["main_template"])
176
- )
152
+ main_prompt += f"Consider this instruction {user_prompt}\n"
177
153
 
178
- messages = OperatorUtils.user_merge_format(messages)
154
+ main_prompt += prompt_configs["main_template"]
179
155
 
180
156
  if logprobs and (not isinstance(top_logprobs, int) or top_logprobs < 2):
181
157
  raise ValueError("top_logprobs should be an integer greater than 1")
182
158
 
183
159
  parsed, completion = await self._parse_completion(
184
- messages, output_model, temperature, logprobs, top_logprobs, priority
160
+ main_prompt, output_model, temperature, logprobs, top_logprobs, priority
185
161
  )
186
162
 
187
- output.result = parsed.result
188
-
189
163
  # Retry logic if validation fails
190
- if validator and not validator(output.result):
164
+ if validator and not validator(parsed.result):
191
165
  if (
192
166
  not isinstance(max_validation_retries, int)
193
167
  or max_validation_retries < 1
@@ -197,17 +171,13 @@ class AsyncOperator:
197
171
  )
198
172
 
199
173
  succeeded = False
200
- for attempt in range(max_validation_retries):
201
- logger.warning(
202
- f"Validation failed, retrying for the {attempt + 1} time."
203
- )
204
-
205
- # Generate new temperature for retry
174
+ for _ in range(max_validation_retries):
175
+ # Generate a new temperature to retry
206
176
  retry_temperature = OperatorUtils.get_retry_temp(temperature)
207
177
 
208
178
  try:
209
179
  parsed, completion = await self._parse_completion(
210
- messages,
180
+ main_prompt,
211
181
  output_model,
212
182
  retry_temperature,
213
183
  logprobs,
@@ -215,30 +185,26 @@ class AsyncOperator:
215
185
  priority=priority,
216
186
  )
217
187
 
218
- output.result = parsed.result
219
-
220
188
  # Check if retry was successful
221
- if validator(output.result):
189
+ if validator(parsed.result):
222
190
  succeeded = True
223
191
  break
224
192
 
225
- except LLMError as e:
226
- logger.error(f"Retry attempt {attempt + 1} failed: {e}")
193
+ except LLMError:
194
+ pass
227
195
 
228
196
  if not succeeded:
229
- raise ValidationError(
230
- f"Validation failed after {max_validation_retries} retries"
231
- )
232
-
233
- if logprobs:
234
- output.logprobs = OperatorUtils.extract_logprobs(completion)
235
-
236
- if with_analysis:
237
- output.analysis = analysis
238
-
239
- output.process = prompt_file[:-5]
197
+ raise ValidationError("Validation failed after all retries")
198
+
199
+ operator_output = OperatorOutput(
200
+ result=parsed.result,
201
+ analysis=analysis if with_analysis else None,
202
+ logprobs=OperatorUtils.extract_logprobs(completion)
203
+ if logprobs
204
+ else None,
205
+ )
240
206
 
241
- return output
207
+ return operator_output
242
208
 
243
209
  except (PromptError, LLMError, ValidationError):
244
210
  raise
@@ -1,25 +1,39 @@
1
+ from __future__ import annotations
2
+
1
3
  from datetime import datetime
2
- from typing import Type, Literal
4
+ from typing import Type, Literal, Any
3
5
 
4
6
  from pydantic import BaseModel, Field, create_model
5
7
 
6
8
 
7
- class ToolOutput(BaseModel):
8
- result: object = None
9
- logprobs: list[dict[str, object]] = []
10
- analysis: str = ""
11
- process: str | None = None
12
- processed_at: datetime = datetime.now()
9
+ class ToolOutputMetadata(BaseModel):
10
+ tool_name: str
11
+ processed_at: datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
13
12
  execution_time: float | None = None
13
+
14
+
15
+ class ToolOutput(BaseModel):
16
+ result: Any = None
17
+ analysis: str | None = None
18
+ logprobs: list[dict[str, Any]] | None = None
14
19
  errors: list[str] = []
20
+ metadata: ToolOutputMetadata | None = None
15
21
 
16
22
  def __repr__(self) -> str:
17
- return f"""
18
- ToolOutput(process='{self.process}', result_type='{type(self.result)}',
19
- result='{self.result}', analysis='{self.analysis}',
20
- logprobs='{self.logprobs}', errors='{self.errors}',
21
- processed_at='{self.processed_at}', execution_time='{self.execution_time}'
22
- """
23
+ base = f"""ToolOutput(result='{self.result}', result_type='{type(self.result)}', analysis='{self.analysis}', logprobs='{self.logprobs}', errors='{self.errors}'"""
24
+
25
+ if self.metadata:
26
+ base += f""", tool_name='{self.metadata.tool_name}',
27
+ processed_at='{self.metadata.processed_at}', execution_time='{self.metadata.execution_time}'
28
+ """
29
+
30
+ return base
31
+
32
+
33
+ class OperatorOutput(BaseModel):
34
+ result: Any
35
+ analysis: str | None
36
+ logprobs: list[dict[str, Any]] | None
23
37
 
24
38
 
25
39
  class Str(BaseModel):
@@ -53,114 +67,69 @@ class ReasonListStr(BaseModel):
53
67
  )
54
68
 
55
69
 
56
- class Node(BaseModel):
57
- node_id: int
58
- name: str
59
- level: int
60
- parent_id: int | None
61
- description: str
70
+ class Node:
71
+ def __init__(self, name: str, description: str, level: int, parent: Node | None):
72
+ self.name = name
73
+ self.description = description
74
+ self.level = level
75
+ self.parent = parent
76
+ self.children = {}
62
77
 
63
78
 
64
79
  class CategoryTree:
65
- def __init__(self, tree_name):
66
- self._root = Node(
67
- node_id=0, name=tree_name, level=0, parent_id=None, description="Root node"
68
- )
69
- self._all_nodes: list[Node] = [self._root]
70
- self._new_id = 1
71
-
72
- def get_all_nodes(self) -> list[Node]:
80
+ def __init__(self):
81
+ self._root = Node(name="root", description="root", level=0, parent=None)
82
+ self._all_nodes = {"root": self._root}
83
+
84
+ def get_all_nodes(self) -> dict[str, Node]:
73
85
  return self._all_nodes
74
86
 
75
87
  def get_level_count(self) -> int:
76
- return max([item.level for item in self._all_nodes])
77
-
78
- def get_node(self, identifier: int | str) -> Node | None:
79
- if isinstance(identifier, str):
80
- for node in self.get_all_nodes():
81
- if node.name == identifier:
82
- return node
83
- return None
84
- elif isinstance(identifier, int):
85
- for node in self.get_all_nodes():
86
- if node.node_id == identifier:
87
- return node
88
- return None
89
- else:
90
- return None
91
-
92
- def get_children(self, parent_node: Node) -> list[Node] | None:
93
- children = [
94
- node
95
- for node in self.get_all_nodes()
96
- if parent_node.node_id == node.parent_id
97
- ]
98
- return children if children else None
88
+ return max(node.level for node in self._all_nodes.values())
89
+
90
+ def get_node(self, name: str) -> Node | None:
91
+ return self._all_nodes.get(name)
99
92
 
100
93
  def add_node(
101
94
  self,
102
- node_name: str,
103
- parent_name: str | None = None,
95
+ name: str,
96
+ parent_name: str,
104
97
  description: str | None = None,
105
98
  ) -> None:
106
- if self.get_node(node_name):
107
- raise ValueError(f"{node_name} has been chosen for another category before")
108
-
109
- if parent_name:
110
- parent_node = self.get_node(parent_name)
111
- if not parent_node:
112
- raise ValueError(f"Parent category '{parent_name}' not found")
113
- parent_id = parent_node.node_id
114
- level = parent_node.level + 1
115
- else:
116
- level = 1
117
- parent_id = 0
99
+ if self.get_node(name):
100
+ raise ValueError(f"Cannot add {name} category twice")
101
+
102
+ parent = self.get_node(parent_name)
103
+
104
+ if not parent:
105
+ raise ValueError(f"Parent category '{parent_name}' not found")
118
106
 
119
107
  node_data = {
120
- "node_id": self._new_id,
121
- "name": node_name,
122
- "level": level,
123
- "parent_id": parent_id,
108
+ "name": name,
124
109
  "description": description if description else "No description provided",
110
+ "level": parent.level + 1,
111
+ "parent": parent,
125
112
  }
126
113
 
127
- self._all_nodes.append(Node(**node_data))
128
- self._new_id += 1
129
-
130
- def remove_node(self, identifier: int | str) -> None:
131
- node = self.get_node(identifier)
132
-
133
- if node:
134
- # Remove node's children recursively
135
- children = self.get_children(node)
136
-
137
- if not children:
138
- self._all_nodes.remove(node)
139
- return
140
-
141
- for child in children:
142
- self.remove_node(child.name)
143
-
144
- self._all_nodes.remove(node)
145
- else:
146
- raise ValueError(f"Node with identifier: '{identifier}' not found.")
147
-
148
- def dump_tree(self) -> dict:
149
- def build_dict(node: Node) -> dict:
150
- children = [
151
- build_dict(child)
152
- for child in self._all_nodes
153
- if child.parent_id == node.node_id
154
- ]
155
- return {
156
- "node_id": node.node_id,
157
- "name": node.name,
158
- "level": node.level,
159
- "parent_id": node.parent_id,
160
- "children": children,
161
- }
162
-
163
- return {"category_tree": build_dict(self._root)["children"]}
114
+ new_node = Node(**node_data)
115
+ parent.children[name] = new_node
116
+ self._all_nodes[name] = new_node
117
+
118
+ def remove_node(self, name: str) -> None:
119
+ if name == "root":
120
+ raise ValueError("Cannot remove the root node")
121
+
122
+ node = self.get_node(name)
123
+ if not node:
124
+ raise ValueError(f"Category: '{name}' not found")
125
+
126
+ for child_name in list(node.children.keys()):
127
+ self.remove_node(child_name)
128
+
129
+ if node.parent:
130
+ del node.parent.children[name]
131
+
132
+ del self._all_nodes[name]
164
133
 
165
134
 
166
135
  # This function is needed to create CategorizerOutput with dynamic categories
@@ -5,8 +5,8 @@ import random
5
5
 
6
6
  class OperatorUtils:
7
7
  @staticmethod
8
- def build_user_message(prompt: str) -> dict[str, str]:
9
- return {"role": "user", "content": prompt}
8
+ def build_user_message(prompt: str) -> list[dict[str, str]]:
9
+ return [{"role": "user", "content": prompt}]
10
10
 
11
11
  @staticmethod
12
12
  def extract_logprobs(completion: dict) -> list[dict]:
@@ -52,27 +52,3 @@ class OperatorUtils:
52
52
  new_temp = base_temp + delta_temp
53
53
 
54
54
  return max(0.0, min(new_temp, 1.5))
55
-
56
- @staticmethod
57
- def user_merge_format(messages: list[dict[str, str]]) -> list[dict[str, str]]:
58
- """
59
- Merges consecutive user messages into a single message, separated by newlines.
60
-
61
- This is useful for condensing a multi-turn user input into a single
62
- message for the LLM. Assistant and system messages are left unchanged and
63
- act as separators between user message groups.
64
- """
65
- merged = []
66
-
67
- for message in messages:
68
- role, content = message["role"], message["content"].strip()
69
-
70
- # Merge with previous user turn
71
- if merged and role == "user" and merged[-1]["role"] == "user":
72
- merged[-1]["content"] += "\n" + content
73
-
74
- # Otherwise, start a new turn
75
- else:
76
- merged.append({"role": role, "content": content})
77
-
78
- return merged