hamtaa-texttools 0.1.43 (hamtaa_texttools-0.1.43-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of hamtaa-texttools has been flagged as potentially problematic.
- hamtaa_texttools-0.1.43.dist-info/METADATA +60 -0
- hamtaa_texttools-0.1.43.dist-info/RECORD +60 -0
- hamtaa_texttools-0.1.43.dist-info/WHEEL +5 -0
- hamtaa_texttools-0.1.43.dist-info/top_level.txt +1 -0
- texttools/__init__.py +26 -0
- texttools/base/__init__.py +3 -0
- texttools/base/base_categorizer.py +40 -0
- texttools/base/base_keyword_extractor.py +35 -0
- texttools/base/base_ner_extractor.py +61 -0
- texttools/base/base_question_detector.py +35 -0
- texttools/base/base_question_generator.py +99 -0
- texttools/base/base_question_merger.py +59 -0
- texttools/base/base_question_rewriter.py +61 -0
- texttools/base/base_router.py +33 -0
- texttools/base/base_summarizer.py +55 -0
- texttools/base/base_task_performer.py +53 -0
- texttools/base/base_translator.py +38 -0
- texttools/batch_manager/__init__.py +2 -0
- texttools/batch_manager/batch_manager.py +241 -0
- texttools/batch_manager/batch_runner.py +207 -0
- texttools/formatter/__init__.py +1 -0
- texttools/formatter/base.py +26 -0
- texttools/formatter/gemma3_formatter.py +51 -0
- texttools/handlers/__init__.py +6 -0
- texttools/handlers/categorizer/__init__.py +6 -0
- texttools/handlers/categorizer/categorizer.py +61 -0
- texttools/handlers/handlers.py +88 -0
- texttools/tools/__init__.py +33 -0
- texttools/tools/categorizer/__init__.py +2 -0
- texttools/tools/categorizer/encoder_model/__init__.py +1 -0
- texttools/tools/categorizer/encoder_model/encoder_vectorizer.py +51 -0
- texttools/tools/categorizer/llm/__init__.py +2 -0
- texttools/tools/categorizer/llm/gemma_categorizer.py +169 -0
- texttools/tools/categorizer/llm/openai_categorizer.py +80 -0
- texttools/tools/keyword_extractor/__init__.py +1 -0
- texttools/tools/keyword_extractor/gemma_extractor.py +138 -0
- texttools/tools/merger/__init__.py +2 -0
- texttools/tools/merger/gemma_question_merger.py +214 -0
- texttools/tools/ner/__init__.py +1 -0
- texttools/tools/ner/gemma_ner_extractor.py +157 -0
- texttools/tools/question_detector/__init__.py +2 -0
- texttools/tools/question_detector/gemma_detector.py +130 -0
- texttools/tools/question_detector/llm_detector.py +112 -0
- texttools/tools/question_generator/__init__.py +1 -0
- texttools/tools/question_generator/gemma_question_generator.py +198 -0
- texttools/tools/reranker/__init__.py +3 -0
- texttools/tools/reranker/reranker.py +137 -0
- texttools/tools/reranker/scorer.py +216 -0
- texttools/tools/reranker/sorter.py +278 -0
- texttools/tools/rewriter/__init__.py +2 -0
- texttools/tools/rewriter/gemma_question_rewriter.py +213 -0
- texttools/tools/router/__init__.py +0 -0
- texttools/tools/router/gemma_router.py +169 -0
- texttools/tools/subject_to_question/__init__.py +1 -0
- texttools/tools/subject_to_question/gemma_question_generator.py +224 -0
- texttools/tools/summarizer/__init__.py +2 -0
- texttools/tools/summarizer/gemma_summarizer.py +140 -0
- texttools/tools/summarizer/llm_summerizer.py +108 -0
- texttools/tools/translator/__init__.py +1 -0
- texttools/tools/translator/gemma_translator.py +202 -0
@@ -0,0 +1,61 @@ texttools/handlers/categorizer/categorizer.py

```python
from abc import ABC, abstractmethod
from enum import Enum

from elasticsearch import Elasticsearch, helpers


class ResultHandler(ABC):
    """
    Abstract base class for all result handlers.
    Implement the handle() method to define custom handling logic.
    """

    @abstractmethod
    def handle(self, results: dict[str, Enum]) -> None:
        """
        Process the categorization results.

        Args:
            results (dict[str, Enum]): A dictionary mapping text (or IDs) to categories.
        """
        pass


class NoOpResultHandler(ResultHandler):
    """
    A result handler that does nothing.
    Useful as a default when no other handler is provided.
    """

    def handle(self, results) -> None:
        pass


class PrintResultHandler(ResultHandler):
    """
    A simple handler that prints results to the console.
    Useful for debugging or local tests.
    """

    def handle(self, results) -> None:
        for key, value in results.items():
            print(f"Text ID: {key}, Category: {value.name}")


class SaveToElasticResultHandler(ResultHandler):
    """
    A simple handler that saves results to an Elasticsearch index.
    """

    def __init__(self, es_client: Elasticsearch, index_name: str):
        self.es_client = es_client
        self.index_name = index_name

    def handle(self, results):
        documents = [
            {"TextID": key, "Category": value.name} for key, value in results.items()
        ]

        actions = [{"_index": self.index_name, "_source": doc} for doc in documents]

        helpers.bulk(self.es_client, actions)
```
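As a usage sketch for the Elasticsearch handler: the host URL, index name, and `Topic` enum below are illustrative placeholders, not part of the package.

```python
from enum import Enum

from elasticsearch import Elasticsearch

from texttools.handlers.categorizer.categorizer import SaveToElasticResultHandler


class Topic(Enum):  # hypothetical category enum; the handler only reads .name
    SPORTS = "sports"
    POLITICS = "politics"


es = Elasticsearch("http://localhost:9200")  # placeholder host
handler = SaveToElasticResultHandler(es, index_name="categorized-texts")  # placeholder index

# Bulk-indexes one {"TextID": ..., "Category": ...} document per entry.
handler.handle({"doc-1": Topic.SPORTS, "doc-2": Topic.POLITICS})
```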
@@ -0,0 +1,88 @@ texttools/handlers/handlers.py

```python
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any

from pydantic import BaseModel


class ResultHandler(ABC):
    """
    Abstract base class for all result handlers.
    Implement the handle() method to define custom handling logic.
    """

    @abstractmethod
    def handle(self, results: dict[str, Enum]) -> None:
        """
        Process the categorization results.

        Args:
            results (dict[str, Enum]): A dictionary mapping text (or IDs) to categories.
        """
        pass


class NoOpResultHandler(ResultHandler):
    """
    A result handler that does nothing.
    Useful as a default when no other handler is provided.
    """

    def handle(self, results: dict[str, Enum]) -> None:
        pass


class PrintResultHandler(ResultHandler):
    """
    A simple handler that prints results to the console.
    Useful for debugging or local tests.
    """

    def handle(self, results: dict[str, Enum]) -> None:
        for key, value in results.items():
            print(f"Text ID: {key}, Category: {value.name}")


class SaveToFileResultHandler(ResultHandler):
    """
    A handler that saves each question + result pair to a CSV-like file,
    serializing whatever the result object is.
    """

    def __init__(self, file_path: str):
        self.file_path = file_path

    def handle(self, results: dict[str, Any]) -> None:
        """
        Expects `results` to be a dict with at least:
          - "question": the original input text
          - "result": the classification output (bool, BaseModel, dict, str, etc.)

        Appends one line per call:
            question_text,serialized_result
        """

        # Helper to turn anything into a JSON/string representation
        def serialize(val: Any) -> str:
            if isinstance(val, BaseModel):
                return val.model_dump_json()
            try:
                return json.dumps(val)
            except (TypeError, ValueError):
                return str(val)

        q = results.get("question", "")
        r = results.get("result", results)
        line = f"{q},{serialize(r)}\n"

        with open(self.file_path, "a", encoding="utf-8") as f:
            f.write(line)


# You can add more handlers here as needed:
# - ElasticSearchResultHandler
# - DatabaseResultHandler
# - KafkaResultHandler
# - NATSResultHandler
# etc.
```
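A minimal sketch of the file handler; the `Detection` model and output path are hypothetical. Note that `handle()` serializes Pydantic models via `model_dump_json()` and everything else via `json.dumps` with a `str()` fallback.

```python
from pydantic import BaseModel

from texttools.handlers.handlers import SaveToFileResultHandler


class Detection(BaseModel):  # hypothetical result model
    is_question: bool
    confidence: float


handler = SaveToFileResultHandler("results.csv")  # placeholder path

# Appends one line: Is water wet?,{"is_question":true,"confidence":0.98}
handler.handle({
    "question": "Is water wet?",
    "result": Detection(is_question=True, confidence=0.98),
})
```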
@@ -0,0 +1,33 @@ texttools/tools/__init__.py

```python
from .categorizer import EmbeddingCategorizer, GemmaCategorizer, LLMCategorizer
from .keyword_extractor import GemmaKeywordExtractor
from .ner import GemmaNERExtractor
from .question_detector import GemmaQuestionDetector, LLMQuestionDetector
from .question_generator import GemmaQuestionGenerator
from .reranker import GemmaReranker, GemmaScorer, GemmaSorter
from .rewriter import GemmaQuestionRewriter, RewriteMode
from .merger import GemmaQuestionMerger, MergingMode
from .subject_to_question import GemmaQuestionGeneratorFromSubject
from .summarizer import GemmaSummarizer, LLMSummarizer
from .translator import GemmaTranslator

__all__ = [
    "EmbeddingCategorizer",
    "GemmaCategorizer",
    "LLMCategorizer",
    "GemmaTranslator",
    "GemmaSummarizer",
    "LLMSummarizer",
    "GemmaNERExtractor",
    "GemmaQuestionDetector",
    "LLMQuestionDetector",
    "GemmaQuestionGenerator",
    "GemmaScorer",
    "GemmaSorter",
    "GemmaReranker",
    "GemmaQuestionRewriter",
    "RewriteMode",
    "GemmaKeywordExtractor",
    "GemmaQuestionGeneratorFromSubject",
    "GemmaQuestionMerger",
    "MergingMode",
]
```
@@ -0,0 +1 @@ texttools/tools/categorizer/encoder_model/__init__.py

```python
from .encoder_vectorizer import EmbeddingCategorizer
```
@@ -0,0 +1,51 @@ texttools/tools/categorizer/encoder_model/encoder_vectorizer.py

```python
from enum import Enum
from typing import Any, Optional

import numpy as np

from texttools.base import BaseCategorizer
from texttools.handlers import ResultHandler


class EmbeddingCategorizer(BaseCategorizer):
    """
    Uses pre-stored embeddings on each Enum member.
    """

    def __init__(
        self,
        categories: Enum,
        embedding_model: Any,
        handlers: Optional[list[ResultHandler]] = None,
    ):
        """
        :param categories: your Enum class, whose members have `.embeddings`
        :param embedding_model: something with `.encode(text: str) -> list[float]`
        """
        super().__init__(categories, handlers)
        self.embedding_model = embedding_model

    def categorize(self, text: str) -> Enum:
        # 1. Preprocess
        text = self.preprocess(text)

        # 2. Encode the text
        vec = np.array(self.embedding_model.encode(text), dtype=float)

        # 3. Find best category
        best_cat = None
        best_score = -1.0

        for cat in self.categories:
            for proto in cat.embeddings:
                score = self._cosine_similarity(vec, proto)
                if score > best_score:
                    best_score = score
                    best_cat = cat

        self._dispatch({text: best_cat})
        return best_cat

    @staticmethod
    def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        return float(a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```
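A toy sketch of the embedding categorizer. It assumes `BaseCategorizer` (whose source is not shown in this diff) accepts `(categories, handlers)` and provides the `preprocess`/`_dispatch` used above; the enum, prototype vectors, and encoder are all made up.

```python
from enum import Enum

import numpy as np

from texttools.tools import EmbeddingCategorizer


class Topic(Enum):  # hypothetical categories
    SPORTS = "sports"
    POLITICS = "politics"


# Each member must carry prototype vectors under `.embeddings`,
# as the constructor docstring requires. Toy 3-d vectors here.
Topic.SPORTS.embeddings = [np.array([1.0, 0.0, 0.0])]
Topic.POLITICS.embeddings = [np.array([0.0, 1.0, 0.0])]


class ToyEncoder:
    """Stand-in for a real sentence-embedding model exposing .encode()."""

    def encode(self, text: str) -> list[float]:
        return [1.0, 0.0, 0.0] if "goal" in text else [0.0, 1.0, 0.0]


categorizer = EmbeddingCategorizer(Topic, ToyEncoder())
print(categorizer.categorize("a late goal won the match"))  # expected: Topic.SPORTS
```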
@@ -0,0 +1,169 @@ texttools/tools/categorizer/llm/gemma_categorizer.py

```python
from typing import Any, Literal, Optional

from openai import OpenAI
from pydantic import BaseModel

from texttools.base.base_categorizer import BaseCategorizer
from texttools.formatter.gemma3_formatter import Gemma3Formatter
from texttools.handlers import ResultHandler


class Output(BaseModel):
    reason: str
    main_tag: Literal[
        "باورهای دینی",  # religious beliefs
        "اخلاق اسلامی",  # Islamic ethics
        "احکام و فقه",  # rulings and jurisprudence
        "تاریخ اسلام و شخصیت ها",  # Islamic history and figures
        "منابع دینی",  # religious sources
        "دین و جامعه/سیاست",  # religion and society/politics
        "عرفان و معنویت",  # mysticism and spirituality
        "هیچکدام",  # none of the above
    ] = None  # None default tolerated; pydantic does not validate defaults


class GemmaCategorizer(BaseCategorizer):
    """
    Categorizer for Gemma-style models. The allowed categories are hard-coded
    in the `Output` schema above; `categorize` returns a dict of the form
    {"main_tag": "<category>"}.

    Allows optional extra instructions via `prompt_template`.
    """

    def __init__(
        self,
        client: OpenAI,
        *,
        model: str,
        output_structure: type[BaseModel] = Output,
        chat_formatter: Optional[Any] = None,
        use_reason: bool = False,
        temperature: float = 0.0,
        prompt_template: Optional[str] = None,
        handlers: Optional[list[ResultHandler]] = None,
        **client_kwargs: Any,
    ):
        super().__init__(handlers=handlers)
        self.client = client
        self.model = model
        self.temperature = temperature
        self.client_kwargs = client_kwargs
        self.output_structure = output_structure
        self.chat_formatter = chat_formatter or Gemma3Formatter()

        self.use_reason = use_reason
        self.prompt_template = prompt_template

    def _build_messages(
        self, text: str, reason: Optional[str] = None
    ) -> list[dict[str, str]]:
        """
        Builds the message list for the LLM API call for categorization.
        """
        clean_text = self.preprocess(text)

        messages: list[dict[str, str]] = []

        if self.prompt_template:
            messages.append({"role": "user", "content": self.prompt_template})

        if reason:
            messages.append(
                {"role": "user", "content": f"Based on this analysis: {reason}"}
            )

        # Persian instructions; roughly: "You are an expert in religious
        # studies. As the user I will give you a text, and I want you to
        # classify it into one of the categories below. [...] The requested
        # output contains a `reason` field; briefly state there why you chose
        # the category. The text you must classify:"
        messages.append(
            {
                "role": "user",
                "content": """
تو یک متخصص علوم دینی هستی
من به عنوان کاربر یک متن به تو میدم و از تو میخوام که
اون متن رو در یکی از دسته بندی های زیر طبقه بندی کنی

"باورهای دینی",
"اخلاق اسلامی",
"احکام و فقه",
"تاریخ اسلام و شخصیت ها",
"منابع دینی",
"دین و جامعه/سیاست",
"عرفان و معنویت",
"هیچکدام",

در خروجی که از تو خواسته شده بخشی با عنوان reason وجود دارد
در اون بخش، دلیل انتخاب دسته بندی رو به صورت خلاصه بیاور

متنی که باید طبقه بندی کنی:
""",
            }
        )
        messages.append({"role": "user", "content": clean_text})
        restructured = self.chat_formatter.format(messages=messages)

        return restructured

    def _reason(self, text: str) -> str:
        """
        Internal reasoning step to help the model analyze the text for categorization.
        """
        # Persian; roughly: "Our goal is text classification. Read the text
        # and give its main idea and a short analysis. Keep your output very
        # brief, 20 words at most."
        messages = [
            {
                "role": "user",
                "content": """
هدف ما طبقه بندی متن هست
متن رو بخون و ایده اصلی و آنالیزی کوتاه از اون رو ارائه بده

بسیار خلاصه باشه خروجی تو
نهایتا 20 کلمه
""",
            },
            {
                "role": "user",
                "content": f"""
{text}
""",
            },
        ]

        restructured = self.chat_formatter.format(messages=messages)

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=restructured,
            temperature=self.temperature,
            **self.client_kwargs,
        )

        reason_summary = resp.choices[0].message.content.strip()
        return reason_summary

    def categorize(self, text: str):
        """
        Categorizes `text` into one of the predefined categories.
        Optionally uses an internal reasoning step for better accuracy.
        """
        reason_summary = None
        if self.use_reason:
            reason_summary = self._reason(text)

        messages = self._build_messages(text, reason_summary)
        completion = self.client.beta.chat.completions.parse(
            model=self.model,
            messages=messages,
            response_format=self.output_structure,
            temperature=self.temperature,
            extra_body=dict(guided_decoding_backend="auto"),
            **self.client_kwargs,
        )
        message = completion.choices[0].message

        category_name = message.parsed.main_tag

        # dispatch and return - Note: _dispatch expects a dict
        self._dispatch(results={"main_tag": category_name})
        return {"main_tag": category_name}
```
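A usage sketch. The endpoint, API key, and model name are placeholders; the `guided_decoding_backend` extra body suggests an OpenAI-compatible inference server such as vLLM, though that is an inference, not something the package states.

```python
from openai import OpenAI

from texttools.tools import GemmaCategorizer

# Placeholder endpoint and credentials for an OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

categorizer = GemmaCategorizer(
    client,
    model="google/gemma-3-27b-it",  # hypothetical deployment name
    use_reason=True,  # run the short analysis pass before classifying
)

result = categorizer.categorize("متنی درباره احکام روزه")  # "a text about fasting rulings"
print(result["main_tag"])
```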
@@ -0,0 +1,80 @@ texttools/tools/categorizer/llm/openai_categorizer.py

```python
from enum import Enum
from typing import Any, Optional

from openai import OpenAI
from pydantic import BaseModel, create_model

from texttools.base import BaseCategorizer
from texttools.handlers import ResultHandler


class LLMCategorizer(BaseCategorizer):
    """
    LLM-based categorizer using OpenAI's client.responses.parse
    for Structured Outputs (Pydantic models).
    """

    def __init__(
        self,
        client: OpenAI,
        categories: Enum,
        *,
        model: str,
        temperature: float = 0.0,
        prompt_template: Optional[str] = None,
        handlers: Optional[list[ResultHandler]] = None,
        **client_kwargs: Any,
    ):
        """
        :param client: an instantiated OpenAI client
        :param categories: an Enum class of allowed categories
        :param model: the model name (e.g. "gpt-4o-2024-08-06")
        :param temperature: sampling temperature
        :param prompt_template: override default prompt instructions
        :param handlers: list of handler instances to process the output
        :param client_kwargs: any other OpenAI kwargs (e.g. `max_tokens`, `top_p`, etc.)
        """
        super().__init__(categories, handlers=handlers)
        self.client = client
        self.model = model
        self.temperature = temperature
        self.client_kwargs = client_kwargs

        self.prompt_template = prompt_template or (
            "You are a text classifier. Choose exactly one category from the list."
        )

        # Dynamic Pydantic model with a single `category` field constrained
        # to the Enum's members.
        self._OutputModel = create_model(
            "CategorizationOutput",
            category=(self.categories, ...),
        )

    def _build_messages(self, text: str) -> list[dict[str, str]]:
        """
        Builds the message list for the OpenAI API based on the input text.
        """
        clean = self.preprocess(text)
        return [
            {"role": "system", "content": self.prompt_template},
            {"role": "user", "content": clean},
        ]

    def categorize(self, text: str) -> Enum:
        """
        Categorizes the input text using the OpenAI API and processes it using handlers.
        """
        msgs = self._build_messages(text)

        resp = self.client.responses.parse(
            model=self.model,
            input=msgs,
            text_format=self._OutputModel,
            temperature=self.temperature,
            **self.client_kwargs,
        )

        output: BaseModel = resp.output_parsed

        self._dispatch({"text": text, "category": output.category})

        return output.category
```
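A sketch against the hosted OpenAI API; the `Topic` enum is hypothetical and the model name is taken from the docstring example.

```python
from enum import Enum

from openai import OpenAI

from texttools.tools import LLMCategorizer


class Topic(Enum):  # hypothetical categories
    SPORTS = "sports"
    POLITICS = "politics"


client = OpenAI()  # reads OPENAI_API_KEY from the environment

categorizer = LLMCategorizer(client, Topic, model="gpt-4o-2024-08-06")

print(categorizer.categorize("The election results were announced today."))
# expected: Topic.POLITICS, an Enum member per the return annotation
```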
@@ -0,0 +1 @@ texttools/tools/keyword_extractor/__init__.py

```python
from texttools.tools.keyword_extractor.gemma_extractor import GemmaKeywordExtractor
```
@@ -0,0 +1,138 @@ texttools/tools/keyword_extractor/gemma_extractor.py

```python
import json
from typing import Any, Optional

from openai import OpenAI
from pydantic import BaseModel

from texttools.base.base_keyword_extractor import BaseKeywordExtractor
from texttools.formatter import Gemma3Formatter


class Output(BaseModel):
    keywords: list[str]


class GemmaKeywordExtractor(BaseKeywordExtractor):
    """
    Keyword extractor for Gemma-style models with optional reasoning step.
    Outputs JSON with a single array field: {"keywords": ["keyword1", "keyword2", ...]}.

    Allows optional extra instructions via `prompt_template`.
    """

    def __init__(
        self,
        client: OpenAI,
        *,
        model: str,
        use_reason: bool = False,
        chat_formatter: Optional[Any] = None,
        temperature: float = 0.0,
        prompt_template: Optional[str] = None,
        handlers: Optional[list[Any]] = None,
        **client_kwargs: Any,
    ):
        super().__init__(handlers)
        self.client = client
        self.model = model
        self.temperature = temperature
        self.client_kwargs = client_kwargs
        self.chat_formatter = chat_formatter or Gemma3Formatter()
        self.use_reason = use_reason
        self.prompt_template = prompt_template

        self.output = Output

    def _build_messages(
        self, text: str, reason: Optional[str] = None
    ) -> list[dict[str, str]]:
        clean_text = self.preprocess(text)

        messages: list[dict[str, str]] = []

        if self.prompt_template:
            messages.append({"role": "user", "content": self.prompt_template})

        if reason:  # Include the reason if available
            messages.append(
                {"role": "user", "content": f"Based on this analysis: {reason}"}
            )

        messages.append(
            {
                "role": "user",
                "content": "Extract the most relevant keywords from the following text. Provide them as a list of strings.",
            }
        )
        messages.append({"role": "user", "content": clean_text})

        # Ensure the schema is dumped as a valid JSON string
        schema_instr = (
            "Respond only in JSON format matching this schema: "
            f"{json.dumps(self.output.model_json_schema())}"
        )
        messages.append({"role": "user", "content": schema_instr})

        # Deprecated
        # messages.append(
        #     {"role": "assistant", "content": "{"}
        # )  # Start with '{' to hint JSON

        messages = self.chat_formatter.format(messages=messages)

        return messages

    def _reason(self, text: str) -> str:
        """
        Internal reasoning step to help the model identify potential keywords.
        """
        messages = [
            {
                "role": "user",
                "content": """
                Analyze the following text to identify its main topics, concepts, and important terms.
                Provide a concise summary of your findings that will help in extracting relevant keywords.
                """,
            },
            {
                "role": "user",
                "content": f"""
                {text}
                """,
            },
        ]
        messages = self.chat_formatter.format(messages=messages)

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            **self.client_kwargs,
        )

        reason_summary = resp.choices[0].message.content.strip()
        return reason_summary

    def extract_keywords(self, text: str) -> list[str]:
        """
        Extracts keywords from `text`.
        Optionally uses an internal reasoning step for better accuracy.
        """
        reason_summary = None
        if self.use_reason:
            reason_summary = self._reason(text)

        messages = self._build_messages(text, reason_summary)

        completion = self.client.beta.chat.completions.parse(
            model=self.model,
            messages=messages,
            response_format=Output,
            temperature=self.temperature,
            extra_body=dict(guided_decoding_backend="auto"),
            **self.client_kwargs,
        )

        message = completion.choices[0].message

        keywords = message.parsed.keywords

        # dispatch and return
        self._dispatch({"original_text": text, "keywords": keywords})
        return keywords
```
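Finally, a sketch of the keyword extractor under the same OpenAI-compatible-server assumption as above; the endpoint and model name are placeholders.

```python
from openai import OpenAI

from texttools.tools import GemmaKeywordExtractor

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

extractor = GemmaKeywordExtractor(
    client,
    model="google/gemma-3-27b-it",  # hypothetical deployment name
    use_reason=True,  # first summarize the text, then extract keywords
)

keywords = extractor.extract_keywords(
    "Transformers use self-attention to weigh the relevance of each token."
)
print(keywords)  # e.g. ["transformers", "self-attention", "tokens"]
```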