hamtaa-texttools 1.1.13__py3-none-any.whl → 1.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,10 @@
1
- from typing import Literal, Any, Callable
1
+ from typing import Literal, Any
2
+ from collections.abc import Callable
2
3
 
3
4
  from openai import AsyncOpenAI
4
5
 
5
6
  from texttools.tools.internals.async_operator import AsyncOperator
6
- import texttools.tools.internals.output_models as OM
7
+ import texttools.tools.internals.models as Models
7
8
 
8
9
 
9
10
  class AsyncTheTool:
@@ -29,19 +30,23 @@ class AsyncTheTool:
29
30
  async def categorize(
30
31
  self,
31
32
  text: str,
33
+ categories: list[str] | Models.CategoryTree,
32
34
  with_analysis: bool = False,
33
35
  user_prompt: str | None = None,
34
36
  temperature: float | None = 0.0,
35
37
  logprobs: bool = False,
36
38
  top_logprobs: int | None = None,
39
+ mode: Literal["category_list", "category_tree"] = "category_list",
37
40
  validator: Callable[[Any], bool] | None = None,
38
41
  max_validation_retries: int | None = None,
39
- ) -> OM.ToolOutput:
42
+ priority: int | None = 0,
43
+ ) -> Models.ToolOutput:
40
44
  """
41
- Categorize a text into a single Islamic studies domain category.
45
+ Categorize a text into a category / category tree.
42
46
 
43
47
  Arguments:
44
48
  text: The input text to categorize
49
+ categories: The flat list of category names, or the CategoryTree, to give to the LLM
45
50
  with_analysis: Whether to include detailed reasoning analysis
46
51
  user_prompt: Additional instructions for the categorization
47
52
  temperature: Controls randomness (0.0 = deterministic, 1.0 = creative)
@@ -49,30 +54,104 @@ class AsyncTheTool:
49
54
  top_logprobs: Number of top token alternatives to return if logprobs enabled
50
55
  validator: Custom validation function to validate the output
51
56
  max_validation_retries: Maximum number of retry attempts if validation fails
57
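+ priority: Task execution priority (if enabled by vLLM and model). NOTE(review): this method does not forward priority to the operator call, so the value currently has no effect here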
52
58
 
53
59
  Returns:
54
60
  ToolOutput: Object containing:
55
- - result (str): The assigned Islamic studies category
61
+ - result (str | list[str]): The assigned category; in "category_tree" mode, the list of chosen category names, one per tree level
56
62
  - logprobs (list | None): Probability data if logprobs enabled
57
63
  - analysis (str | None): Detailed reasoning if with_analysis enabled
58
64
  - errors (list(str) | None): Errors occurred during tool call
59
65
  """
60
- return await self._operator.run(
61
- # User parameters
62
- text=text,
63
- with_analysis=with_analysis,
64
- user_prompt=user_prompt,
65
- temperature=temperature,
66
- logprobs=logprobs,
67
- top_logprobs=top_logprobs,
68
- validator=validator,
69
- max_validation_retries=max_validation_retries,
70
- # Internal parameters
71
- prompt_file="categorizer.yaml",
72
- output_model=OM.CategorizerOutput,
73
- mode=None,
74
- output_lang=None,
75
- )
66
+ if mode == "category_tree":
67
+ # Initializations
68
+ output = Models.ToolOutput()
69
+ levels = categories.level_count()
70
+ parent_id = 0
71
+ final_output = []
72
+
73
+ for _ in range(levels):
74
+ # Get child nodes for current parent
75
+ parent_node = categories.find_node(parent_id)
76
+ children = categories.find_children(parent_node)
77
+
78
+ # Check if child nodes exist
79
+ if not children:
80
+ output.errors.append(
81
+ f"No categories found for parent_id {parent_id} in the tree"
82
+ )
83
+ return output
84
+
85
+ # Extract category names and descriptions
86
+ category_list = [
87
+ f"Category Name: {node.name}, Description: {node.description}"
88
+ for node in children
89
+ ]
90
+ category_names = [node.name for node in children]
91
+
92
+ # Run categorization for this level
93
+ level_output = await self._operator.run(
94
+ # User parameters
95
+ text=text,
96
+ category_list=category_list,
97
+ with_analysis=with_analysis,
98
+ user_prompt=user_prompt,
99
+ temperature=temperature,
100
+ logprobs=logprobs,
101
+ top_logprobs=top_logprobs,
102
+ mode=mode,
103
+ validator=validator,
104
+ max_validation_retries=max_validation_retries,
105
+ # Internal parameters
106
+ prompt_file="categorize.yaml",
107
+ output_model=Models.create_dynamic_model(category_names),
108
+ output_lang=None,
109
+ )
110
+
111
+ # Check for errors from operator
112
+ if level_output.errors:
113
+ output.errors.extend(level_output.errors)
114
+ return output
115
+
116
+ # Get the chosen category
117
+ chosen_category = level_output.result
118
+
119
+ # Find the corresponding node
120
+ parent_node = categories.find_node(chosen_category)
121
+ if parent_node is None:
122
+ output.errors.append(
123
+ f"Category '{chosen_category}' not found in tree after selection"
124
+ )
125
+ return output
126
+
127
+ parent_id = parent_node.node_id
128
+ final_output.append(parent_node.name)
129
+
130
+ # Copy analysis/logprobs from the last level's output
131
+ output.analysis = level_output.analysis
132
+ output.logprobs = level_output.logprobs
133
+
134
+ output.result = final_output
135
+ return output
136
+
137
+ else:
138
+ return await self._operator.run(
139
+ # User parameters
140
+ text=text,
141
+ category_list=categories,
142
+ with_analysis=with_analysis,
143
+ user_prompt=user_prompt,
144
+ temperature=temperature,
145
+ logprobs=logprobs,
146
+ top_logprobs=top_logprobs,
147
+ mode=mode,
148
+ validator=validator,
149
+ max_validation_retries=max_validation_retries,
150
+ # Internal parameters
151
+ prompt_file="categorize.yaml",
152
+ output_model=Models.create_dynamic_model(categories),
153
+ output_lang=None,
154
+ )
76
155
 
77
156
  async def extract_keywords(
78
157
  self,
@@ -83,9 +162,12 @@ class AsyncTheTool:
83
162
  temperature: float | None = 0.0,
84
163
  logprobs: bool = False,
85
164
  top_logprobs: int | None = None,
165
+ mode: Literal["auto", "threshold", "count"] = "auto",
166
+ number_of_keywords: int | None = None,
86
167
  validator: Callable[[Any], bool] | None = None,
87
168
  max_validation_retries: int | None = None,
88
- ) -> OM.ToolOutput:
169
+ priority: int | None = 0,
170
+ ) -> Models.ToolOutput:
89
171
  """
90
172
  Extract salient keywords from text.
91
173
 
@@ -99,6 +181,7 @@ class AsyncTheTool:
99
181
  top_logprobs: Number of top token alternatives to return if logprobs enabled
100
182
  validator: Custom validation function to validate the output
101
183
  max_validation_retries: Maximum number of retry attempts if validation fails
184
+ priority: Task execution priority (if enabled by vLLM and model)
102
185
 
103
186
  Returns:
104
187
  ToolOutput: Object containing:
@@ -116,12 +199,14 @@ class AsyncTheTool:
116
199
  temperature=temperature,
117
200
  logprobs=logprobs,
118
201
  top_logprobs=top_logprobs,
202
+ mode=mode,
203
+ number_of_keywords=number_of_keywords,
119
204
  validator=validator,
120
205
  max_validation_retries=max_validation_retries,
121
206
  # Internal parameters
122
207
  prompt_file="extract_keywords.yaml",
123
- output_model=OM.ListStrOutput,
124
- mode=None,
208
+ output_model=Models.ListStrOutput,
209
+ priority=priority,
125
210
  )
126
211
 
127
212
  async def extract_entities(
@@ -135,7 +220,8 @@ class AsyncTheTool:
135
220
  top_logprobs: int | None = None,
136
221
  validator: Callable[[Any], bool] | None = None,
137
222
  max_validation_retries: int | None = None,
138
- ) -> OM.ToolOutput:
223
+ priority: int | None = 0,
224
+ ) -> Models.ToolOutput:
139
225
  """
140
226
  Perform Named Entity Recognition (NER) over the input text.
141
227
 
@@ -149,6 +235,7 @@ class AsyncTheTool:
149
235
  top_logprobs: Number of top token alternatives to return if logprobs enabled
150
236
  validator: Custom validation function to validate the output
151
237
  max_validation_retries: Maximum number of retry attempts if validation fails
238
+ priority: Task execution priority (if enabled by vLLM and model)
152
239
 
153
240
  Returns:
154
241
  ToolOutput: Object containing:
@@ -170,8 +257,9 @@ class AsyncTheTool:
170
257
  max_validation_retries=max_validation_retries,
171
258
  # Internal parameters
172
259
  prompt_file="extract_entities.yaml",
173
- output_model=OM.ListDictStrStrOutput,
260
+ output_model=Models.ListDictStrStrOutput,
174
261
  mode=None,
262
+ priority=priority,
175
263
  )
176
264
 
177
265
  async def is_question(
@@ -184,7 +272,8 @@ class AsyncTheTool:
184
272
  top_logprobs: int | None = None,
185
273
  validator: Callable[[Any], bool] | None = None,
186
274
  max_validation_retries: int | None = None,
187
- ) -> OM.ToolOutput:
275
+ priority: int | None = 0,
276
+ ) -> Models.ToolOutput:
188
277
  """
189
278
  Detect if the input is phrased as a question.
190
279
 
@@ -197,6 +286,7 @@ class AsyncTheTool:
197
286
  top_logprobs: Number of top token alternatives to return if logprobs enabled
198
287
  validator: Custom validation function to validate the output
199
288
  max_validation_retries: Maximum number of retry attempts if validation fails
289
+ priority: Task execution priority (if enabled by vLLM and model)
200
290
 
201
291
  Returns:
202
292
  ToolOutput: Object containing:
@@ -217,9 +307,10 @@ class AsyncTheTool:
217
307
  max_validation_retries=max_validation_retries,
218
308
  # Internal parameters
219
309
  prompt_file="is_question.yaml",
220
- output_model=OM.BoolOutput,
310
+ output_model=Models.BoolOutput,
221
311
  mode=None,
222
312
  output_lang=None,
313
+ priority=priority,
223
314
  )
224
315
 
225
316
  async def text_to_question(
@@ -233,7 +324,8 @@ class AsyncTheTool:
233
324
  top_logprobs: int | None = None,
234
325
  validator: Callable[[Any], bool] | None = None,
235
326
  max_validation_retries: int | None = None,
236
- ) -> OM.ToolOutput:
327
+ priority: int | None = 0,
328
+ ) -> Models.ToolOutput:
237
329
  """
238
330
  Generate a single question from the given text.
239
331
 
@@ -247,6 +339,7 @@ class AsyncTheTool:
247
339
  top_logprobs: Number of top token alternatives to return if logprobs enabled
248
340
  validator: Custom validation function to validate the output
249
341
  max_validation_retries: Maximum number of retry attempts if validation fails
342
+ priority: Task execution priority (if enabled by vLLM and model)
250
343
 
251
344
  Returns:
252
345
  ToolOutput: Object containing:
@@ -268,8 +361,9 @@ class AsyncTheTool:
268
361
  max_validation_retries=max_validation_retries,
269
362
  # Internal parameters
270
363
  prompt_file="text_to_question.yaml",
271
- output_model=OM.StrOutput,
364
+ output_model=Models.StrOutput,
272
365
  mode=None,
366
+ priority=priority,
273
367
  )
274
368
 
275
369
  async def merge_questions(
@@ -284,7 +378,8 @@ class AsyncTheTool:
284
378
  mode: Literal["default", "reason"] = "default",
285
379
  validator: Callable[[Any], bool] | None = None,
286
380
  max_validation_retries: int | None = None,
287
- ) -> OM.ToolOutput:
381
+ priority: int | None = 0,
382
+ ) -> Models.ToolOutput:
288
383
  """
289
384
  Merge multiple questions into a single unified question.
290
385
 
@@ -299,6 +394,7 @@ class AsyncTheTool:
299
394
  mode: Merging strategy - 'default' for direct merge, 'reason' for reasoned merge
300
395
  validator: Custom validation function to validate the output
301
396
  max_validation_retries: Maximum number of retry attempts if validation fails
397
+ priority: Task execution priority (if enabled by vLLM and model)
302
398
 
303
399
  Returns:
304
400
  ToolOutput: Object containing:
@@ -321,8 +417,9 @@ class AsyncTheTool:
321
417
  max_validation_retries=max_validation_retries,
322
418
  # Internal parameters
323
419
  prompt_file="merge_questions.yaml",
324
- output_model=OM.StrOutput,
420
+ output_model=Models.StrOutput,
325
421
  mode=mode,
422
+ priority=priority,
326
423
  )
327
424
 
328
425
  async def rewrite(
@@ -337,7 +434,8 @@ class AsyncTheTool:
337
434
  mode: Literal["positive", "negative", "hard_negative"] = "positive",
338
435
  validator: Callable[[Any], bool] | None = None,
339
436
  max_validation_retries: int | None = None,
340
- ) -> OM.ToolOutput:
437
+ priority: int | None = 0,
438
+ ) -> Models.ToolOutput:
341
439
  """
342
440
  Rewrite a text with different modes.
343
441
 
@@ -352,6 +450,7 @@ class AsyncTheTool:
352
450
  mode: Rewriting mode - 'positive', 'negative', or 'hard_negative'
353
451
  validator: Custom validation function to validate the output
354
452
  max_validation_retries: Maximum number of retry attempts if validation fails
453
+ priority: Task execution priority (if enabled by vLLM and model)
355
454
 
356
455
  Returns:
357
456
  ToolOutput: Object containing:
@@ -373,8 +472,9 @@ class AsyncTheTool:
373
472
  max_validation_retries=max_validation_retries,
374
473
  # Internal parameters
375
474
  prompt_file="rewrite.yaml",
376
- output_model=OM.StrOutput,
475
+ output_model=Models.StrOutput,
377
476
  mode=mode,
477
+ priority=priority,
378
478
  )
379
479
 
380
480
  async def subject_to_question(
@@ -389,7 +489,8 @@ class AsyncTheTool:
389
489
  top_logprobs: int | None = None,
390
490
  validator: Callable[[Any], bool] | None = None,
391
491
  max_validation_retries: int | None = None,
392
- ) -> OM.ToolOutput:
492
+ priority: int | None = 0,
493
+ ) -> Models.ToolOutput:
393
494
  """
394
495
  Generate a list of questions about a subject.
395
496
 
@@ -404,6 +505,7 @@ class AsyncTheTool:
404
505
  top_logprobs: Number of top token alternatives to return if logprobs enabled
405
506
  validator: Custom validation function to validate the output
406
507
  max_validation_retries: Maximum number of retry attempts if validation fails
508
+ priority: Task execution priority (if enabled by vLLM and model)
407
509
 
408
510
  Returns:
409
511
  ToolOutput: Object containing:
@@ -426,8 +528,9 @@ class AsyncTheTool:
426
528
  max_validation_retries=max_validation_retries,
427
529
  # Internal parameters
428
530
  prompt_file="subject_to_question.yaml",
429
- output_model=OM.ReasonListStrOutput,
531
+ output_model=Models.ReasonListStrOutput,
430
532
  mode=None,
533
+ priority=priority,
431
534
  )
432
535
 
433
536
  async def summarize(
@@ -441,7 +544,8 @@ class AsyncTheTool:
441
544
  top_logprobs: int | None = None,
442
545
  validator: Callable[[Any], bool] | None = None,
443
546
  max_validation_retries: int | None = None,
444
- ) -> OM.ToolOutput:
547
+ priority: int | None = 0,
548
+ ) -> Models.ToolOutput:
445
549
  """
446
550
  Summarize the given subject text.
447
551
 
@@ -455,6 +559,7 @@ class AsyncTheTool:
455
559
  top_logprobs: Number of top token alternatives to return if logprobs enabled
456
560
  validator: Custom validation function to validate the output
457
561
  max_validation_retries: Maximum number of retry attempts if validation fails
562
+ priority: Task execution priority (if enabled by vLLM and model)
458
563
 
459
564
  Returns:
460
565
  ToolOutput: Object containing:
@@ -476,8 +581,9 @@ class AsyncTheTool:
476
581
  max_validation_retries=max_validation_retries,
477
582
  # Internal parameters
478
583
  prompt_file="summarize.yaml",
479
- output_model=OM.StrOutput,
584
+ output_model=Models.StrOutput,
480
585
  mode=None,
586
+ priority=priority,
481
587
  )
482
588
 
483
589
  async def translate(
@@ -491,7 +597,8 @@ class AsyncTheTool:
491
597
  top_logprobs: int | None = None,
492
598
  validator: Callable[[Any], bool] | None = None,
493
599
  max_validation_retries: int | None = None,
494
- ) -> OM.ToolOutput:
600
+ priority: int | None = 0,
601
+ ) -> Models.ToolOutput:
495
602
  """
496
603
  Translate text between languages.
497
604
 
@@ -505,6 +612,7 @@ class AsyncTheTool:
505
612
  top_logprobs: Number of top token alternatives to return if logprobs enabled
506
613
  validator: Custom validation function to validate the output
507
614
  max_validation_retries: Maximum number of retry attempts if validation fails
615
+ priority: Task execution priority (if enabled by vLLM and model)
508
616
 
509
617
  Returns:
510
618
  ToolOutput: Object containing:
@@ -526,9 +634,63 @@ class AsyncTheTool:
526
634
  max_validation_retries=max_validation_retries,
527
635
  # Internal parameters
528
636
  prompt_file="translate.yaml",
529
- output_model=OM.StrOutput,
637
+ output_model=Models.StrOutput,
530
638
  mode=None,
531
639
  output_lang=None,
640
+ priority=priority,
641
+ )
642
+
643
+ async def detect_entity(
644
+ self,
645
+ text: str,
646
+ with_analysis: bool = False,
647
+ output_lang: str | None = None,
648
+ user_prompt: str | None = None,
649
+ temperature: float | None = 0.0,
650
+ logprobs: bool = False,
651
+ top_logprobs: int | None = None,
652
+ validator: Callable[[Any], bool] | None = None,
653
+ max_validation_retries: int | None = None,
654
+ priority: int | None = 0,
655
+ ) -> Models.ToolOutput:
656
+ """
657
+ Detects entities in a given text based on the detect_entity.yaml prompt.
658
+
659
+ Arguments:
660
+ text: The input text
661
+ with_analysis: Whether to include detailed reasoning analysis
662
+ output_lang: Language for the output summary
663
+ user_prompt: Additional instructions for summarization
664
+ temperature: Controls randomness (0.0 = deterministic, 1.0 = creative)
665
+ logprobs: Whether to return token probability information
666
+ top_logprobs: Number of top token alternatives to return if logprobs enabled
667
+ validator: Custom validation function to validate the output
668
+ max_validation_retries: Maximum number of retry attempts if validation fails
669
+ priority: Task execution priority (if enabled by vLLM and model)
670
+
671
+ Returns:
672
+ ToolOutput: Object containing:
673
+ - result (list[Entity]): The entities
674
+ - logprobs (list | None): Probability data if logprobs enabled
675
+ - analysis (str | None): Detailed reasoning if with_analysis enabled
676
+ - errors (list(str) | None): Errors occurred during tool call
677
+ """
678
+ return await self._operator.run(
679
+ # User parameters
680
+ text=text,
681
+ with_analysis=with_analysis,
682
+ output_lang=output_lang,
683
+ user_prompt=user_prompt,
684
+ temperature=temperature,
685
+ logprobs=logprobs,
686
+ top_logprobs=top_logprobs,
687
+ validator=validator,
688
+ max_validation_retries=max_validation_retries,
689
+ # Internal parameters
690
+ prompt_file="detect_entity.yaml",
691
+ output_model=Models.EntityDetectorOutput,
692
+ mode=None,
693
+ priority=priority,
532
694
  )
533
695
 
534
696
  async def run_custom(
@@ -541,7 +703,8 @@ class AsyncTheTool:
541
703
  top_logprobs: int | None = None,
542
704
  validator: Callable[[Any], bool] | None = None,
543
705
  max_validation_retries: int | None = None,
544
- ) -> OM.ToolOutput:
706
+ priority: int | None = 0,
707
+ ) -> Models.ToolOutput:
545
708
  """
546
709
  Custom tool that can do almost anything!
547
710
 
@@ -553,6 +716,7 @@ class AsyncTheTool:
553
716
  top_logprobs: Number of top token alternatives to return if logprobs enabled
554
717
  validator: Custom validation function to validate the output
555
718
  max_validation_retries: Maximum number of retry attempts if validation fails
719
+ priority: Task execution priority (if enabled by vLLM and model)
556
720
 
557
721
  Returns:
558
722
  ToolOutput: Object containing:
@@ -577,4 +741,5 @@ class AsyncTheTool:
577
741
  user_prompt=None,
578
742
  with_analysis=False,
579
743
  mode=None,
744
+ priority=priority,
580
745
  )
@@ -1,10 +1,11 @@
1
- from typing import Any, TypeVar, Type, Callable
1
+ from typing import Any, TypeVar, Type
2
+ from collections.abc import Callable
2
3
  import logging
3
4
 
4
5
  from openai import AsyncOpenAI
5
6
  from pydantic import BaseModel
6
7
 
7
- from texttools.tools.internals.output_models import ToolOutput
8
+ from texttools.tools.internals.models import ToolOutput
8
9
  from texttools.tools.internals.operator_utils import OperatorUtils
9
10
  from texttools.tools.internals.formatters import Formatter
10
11
  from texttools.tools.internals.prompt_loader import PromptLoader
@@ -51,6 +52,7 @@ class AsyncOperator:
51
52
  temperature: float,
52
53
  logprobs: bool = False,
53
54
  top_logprobs: int = 3,
55
+ priority: int | None = 0,
54
56
  ) -> tuple[T, Any]:
55
57
  """
56
58
  Parses a chat completion using OpenAI's structured output format.
@@ -66,7 +68,8 @@ class AsyncOperator:
66
68
  if logprobs:
67
69
  request_kwargs["logprobs"] = True
68
70
  request_kwargs["top_logprobs"] = top_logprobs
69
-
71
+ if priority:
72
+ request_kwargs["extra_body"] = {"priority": priority}
70
73
  completion = await self._client.beta.chat.completions.parse(**request_kwargs)
71
74
  parsed = completion.choices[0].message.parsed
72
75
  return parsed, completion
@@ -87,6 +90,7 @@ class AsyncOperator:
87
90
  prompt_file: str,
88
91
  output_model: Type[T],
89
92
  mode: str | None,
93
+ priority: int | None = 0,
90
94
  **extra_kwargs,
91
95
  ) -> ToolOutput:
92
96
  """
@@ -136,7 +140,7 @@ class AsyncOperator:
136
140
  messages = formatter.user_merge_format(messages)
137
141
 
138
142
  parsed, completion = await self._parse_completion(
139
- messages, output_model, temperature, logprobs, top_logprobs
143
+ messages, output_model, temperature, logprobs, top_logprobs, priority
140
144
  )
141
145
 
142
146
  output.result = parsed.result