hamtaa-texttools 0.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of hamtaa-texttools has been flagged as potentially problematic.
- hamtaa_texttools-0.1.43.dist-info/METADATA +60 -0
- hamtaa_texttools-0.1.43.dist-info/RECORD +60 -0
- hamtaa_texttools-0.1.43.dist-info/WHEEL +5 -0
- hamtaa_texttools-0.1.43.dist-info/top_level.txt +1 -0
- texttools/__init__.py +26 -0
- texttools/base/__init__.py +3 -0
- texttools/base/base_categorizer.py +40 -0
- texttools/base/base_keyword_extractor.py +35 -0
- texttools/base/base_ner_extractor.py +61 -0
- texttools/base/base_question_detector.py +35 -0
- texttools/base/base_question_generator.py +99 -0
- texttools/base/base_question_merger.py +59 -0
- texttools/base/base_question_rewriter.py +61 -0
- texttools/base/base_router.py +33 -0
- texttools/base/base_summarizer.py +55 -0
- texttools/base/base_task_performer.py +53 -0
- texttools/base/base_translator.py +38 -0
- texttools/batch_manager/__init__.py +2 -0
- texttools/batch_manager/batch_manager.py +241 -0
- texttools/batch_manager/batch_runner.py +207 -0
- texttools/formatter/__init__.py +1 -0
- texttools/formatter/base.py +26 -0
- texttools/formatter/gemma3_formatter.py +51 -0
- texttools/handlers/__init__.py +6 -0
- texttools/handlers/categorizer/__init__.py +6 -0
- texttools/handlers/categorizer/categorizer.py +61 -0
- texttools/handlers/handlers.py +88 -0
- texttools/tools/__init__.py +33 -0
- texttools/tools/categorizer/__init__.py +2 -0
- texttools/tools/categorizer/encoder_model/__init__.py +1 -0
- texttools/tools/categorizer/encoder_model/encoder_vectorizer.py +51 -0
- texttools/tools/categorizer/llm/__init__.py +2 -0
- texttools/tools/categorizer/llm/gemma_categorizer.py +169 -0
- texttools/tools/categorizer/llm/openai_categorizer.py +80 -0
- texttools/tools/keyword_extractor/__init__.py +1 -0
- texttools/tools/keyword_extractor/gemma_extractor.py +138 -0
- texttools/tools/merger/__init__.py +2 -0
- texttools/tools/merger/gemma_question_merger.py +214 -0
- texttools/tools/ner/__init__.py +1 -0
- texttools/tools/ner/gemma_ner_extractor.py +157 -0
- texttools/tools/question_detector/__init__.py +2 -0
- texttools/tools/question_detector/gemma_detector.py +130 -0
- texttools/tools/question_detector/llm_detector.py +112 -0
- texttools/tools/question_generator/__init__.py +1 -0
- texttools/tools/question_generator/gemma_question_generator.py +198 -0
- texttools/tools/reranker/__init__.py +3 -0
- texttools/tools/reranker/reranker.py +137 -0
- texttools/tools/reranker/scorer.py +216 -0
- texttools/tools/reranker/sorter.py +278 -0
- texttools/tools/rewriter/__init__.py +2 -0
- texttools/tools/rewriter/gemma_question_rewriter.py +213 -0
- texttools/tools/router/__init__.py +0 -0
- texttools/tools/router/gemma_router.py +169 -0
- texttools/tools/subject_to_question/__init__.py +1 -0
- texttools/tools/subject_to_question/gemma_question_generator.py +224 -0
- texttools/tools/summarizer/__init__.py +2 -0
- texttools/tools/summarizer/gemma_summarizer.py +140 -0
- texttools/tools/summarizer/llm_summerizer.py +108 -0
- texttools/tools/translator/__init__.py +1 -0
- texttools/tools/translator/gemma_translator.py +202 -0
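
For context, every tool shown in the diffs below follows the same construction pattern: it wraps an OpenAI-compatible client plus a model name. A minimal setup sketch (not part of the package diff) that the usage examples after each hunk reuse; the base URL, API key, and model name are placeholder assumptions, not values shipped in the package:

    # Editor's sketch, not part of the diff: shared client setup for the
    # usage examples below. Endpoint and model are placeholders.
    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8000/v1",  # assumed OpenAI-compatible server
        api_key="EMPTY",                      # placeholder credential
    )
    MODEL = "gemma-7b-it"  # example model name taken from the GemmaSorter docstring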
texttools/tools/reranker/sorter.py
@@ -0,0 +1,278 @@
import json
import logging
from typing import Any, Optional

from openai import OpenAI

from texttools.base.base_task_performer import BaseTaskPerformer


class GemmaSorter(BaseTaskPerformer):
    """
    A sorter component utilizing Gemma-style LLMs to order a list of
    pre-scored results based on a query, handling ties semantically.
    """

    def __init__(
        self,
        client: OpenAI,
        *,
        model: str,
        temperature: float = 0.0,
        prompt_template: Optional[str] = None,
        use_reason: bool = False,
        handlers: Optional[list[Any]] = None,
        **client_kwargs: Any,
    ):
        """
        Initializes the GemmaSorter.

        :param client: An initialized OpenAI client (or compatible).
        :param model: The name of the LLM model to use for sorting (e.g., "gemma-7b-it").
        :param temperature: The sampling temperature for LLM generation (0.0 for deterministic).
        :param prompt_template: An optional initial system-level prompt for the LLM.
        :param use_reason: If True, the sorter will perform an internal reasoning step
            and include it in the sorting prompt.
        :param handlers: Optional list of handlers for dispatching sorting results.
        :param client_kwargs: Additional keyword arguments for the OpenAI client.
        """
        super().__init__(handlers)
        self.client = client
        self.model = model
        self.temperature = temperature
        self.client_kwargs = client_kwargs
        self.prompt_template = prompt_template
        self.use_reason = use_reason

        # Defines the expected JSON schema for the LLM's ordered-IDs output.
        self.sort_schema = {"ordered_ids": ["string"]}

    def _build_sorting_messages(
        self,
        query: str,
        scored_results: list[dict[str, Any]],
        reason: Optional[str] = None,
    ) -> list[dict[str, str]]:
        """
        Constructs the messages payload for the LLM API call to sort results.

        :param query: The original search query.
        :param scored_results: A list of dictionaries, where each dict has '_internal_id', 'text', and 'score'.
        :param reason: An optional reasoning summary to provide context to the LLM.
        :return: A list of message dictionaries formatted for the LLM API.
        """
        clean_query = self._preprocess(query)
        messages: list[dict[str, str]] = []

        if self.prompt_template:
            messages.append({"role": "user", "content": self.prompt_template})

        if self.use_reason and reason:
            messages.append(
                {
                    "role": "user",
                    "content": f"Based on this analysis: {reason}",
                }
            )

        messages.append({"role": "user", "content": f"Original Query: {clean_query}"})

        # Format the scored results for the LLM.
        scored_results_presentation = []
        for res_dict in scored_results:
            scored_results_presentation.append(
                f"ID: {res_dict['_internal_id']}\nScore: {res_dict.get('score', 'N/A')}\nText: {res_dict.get('text', 'N/A')}"
            )
        messages.append(
            {
                "role": "user",
                "content": "Here are the results with their assigned scores:\n"
                + "\n---\n".join(scored_results_presentation),
            }
        )

        sorting_instruction = (
            "Based on the provided scores, sort the result 'id's into an ordered list from most relevant to least relevant. "
            "If multiple results have the same score, use semantic similarity to the original query as a tie-breaker. "
            "Return only a JSON list of the 'id's in the final sorted order."
        )
        messages.append({"role": "user", "content": sorting_instruction})

        schema_instr = f"Respond only in JSON format: {json.dumps(self.sort_schema)}"
        messages.append({"role": "user", "content": schema_instr})
        messages.append({"role": "assistant", "content": "{"})
        return messages

    def _reason(self, query: str, results: list[dict[str, Any]]) -> str:
        """
        Generates an internal reasoning summary to help the LLM with sorting,
        especially for tie-breaking. This summary is based on the query and initial results.

        :param query: The original search query.
        :param results: A list of results, potentially including scores and IDs.
        :return: A string containing the reasoning summary.
        """
        clean_query = self._preprocess(query)

        # Truncate results for the reasoning prompt to avoid exceeding token limits.
        results_for_reasoning_display = []
        for res in results:
            text_snippet = res.get("text", "")
            if len(text_snippet) > 100:
                text_snippet = text_snippet[:100] + "..."
            results_for_reasoning_display.append(text_snippet)

        reason_prompt = f"""
        Analyze the original query: "{clean_query}"
        And consider these initial result snippets (with their scores if available): {results_for_reasoning_display}

        Formulate a brief analysis focusing on how to best order these results, especially considering tie-breaking rules based on semantic similarity to the query.
        """

        messages = [
            {"role": "user", "content": reason_prompt},
        ]

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            **self.client_kwargs,
        )
        return resp.choices[0].message.content.strip()

    def perform(
        self, query: str, scored_results: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Sorts a list of results (each with an assigned score and an '_internal_id')
        based on the query using the configured LLM. This is the main public method for the sorter.

        :param query: The original search query string.
        :param scored_results: A list of dictionaries, where each dict has at least
            '_internal_id', 'text', and 'score' keys.
        :return: A list of result dictionaries representing the final sorted order.
            Each dictionary will have the '_internal_id' removed.
        :raises ValueError: If scored_results is empty, LLM output is malformed, or IDs are invalid.
        """
        if not scored_results:
            logging.info("Received empty list of scored results for sorting.")
            self._dispatch(
                {"query": query, "scored_results": scored_results, "ordered_ids": []}
            )
            return []

        # Prepare a map for quick lookup of full result objects by their internal ID.
        id_to_full_result_map: dict[str, dict[str, Any]] = {}
        for i, res in enumerate(scored_results):
            if "_internal_id" not in res:
                # Assign a temporary internal ID if missing; important for LLM interaction.
                res["_internal_id"] = res.get("id", f"gen_id_{i}")
            if "score" not in res:
                logging.warning(
                    f"Result ID '{res['_internal_id']}' missing score for sorting. Defaulting to 0."
                )
                res["score"] = 0
            id_to_full_result_map[res["_internal_id"]] = res.copy()  # Store a copy

        reason_summary = None
        if self.use_reason:
            reason_summary = self._reason(query, list(id_to_full_result_map.values()))

        messages_sort = self._build_sorting_messages(
            query, list(id_to_full_result_map.values()), reason_summary
        )

        resp_sort = self.client.chat.completions.create(
            model=self.model,
            messages=messages_sort,
            temperature=self.temperature,
            **self.client_kwargs,
        )
        raw_sort_output = resp_sort.choices[0].message.content.strip()

        # Robustly extract JSON if the LLM adds a preamble or omits the primed brace.
        if not raw_sort_output.startswith("{"):
            raw_sort_output = "{" + raw_sort_output

        try:
            parsed_sort = json.loads(raw_sort_output)
        except json.JSONDecodeError as e:
            logging.error(
                f"Failed to parse JSON for sorting step: {e}\nRaw output: {raw_sort_output}"
            )
            self._dispatch(
                {
                    "query": query,
                    "scored_results": scored_results,
                    "ordered_ids_from_llm": [],
                    "final_sorted_results": [],
                    "error": f"JSON parsing failed: {e}",
                }
            )
            raise ValueError(
                f"LLM output is not valid JSON for sorting: {raw_sort_output}"
            ) from e

        llm_ordered_ids = parsed_sort.get("ordered_ids")
        if not isinstance(llm_ordered_ids, list) or not all(
            isinstance(item, str) for item in llm_ordered_ids
        ):
            logging.warning(
                f"LLM returned invalid sort schema. Expected 'ordered_ids' as a list of strings, got: {parsed_sort}"
            )
            self._dispatch(
                {
                    "query": query,
                    "scored_results": scored_results,
                    "ordered_ids_from_llm": llm_ordered_ids,
                    "final_sorted_results": [],
                    "warning": "Invalid sort format",
                }
            )
            raise ValueError(
                f"LLM returned invalid 'ordered_ids' format: {llm_ordered_ids}"
            )

        final_sorted_results: list[dict[str, Any]] = []
        ids_placed_by_llm_set = set()

        for internal_id in llm_ordered_ids:
            if (
                internal_id in id_to_full_result_map
                and internal_id not in ids_placed_by_llm_set
            ):
                result_to_add = id_to_full_result_map[internal_id].copy()
                result_to_add.pop("_internal_id", None)
                final_sorted_results.append(result_to_add)
                ids_placed_by_llm_set.add(internal_id)
            elif internal_id not in id_to_full_result_map:
                logging.warning(
                    f"LLM ordered ID '{internal_id}' not found in original results. Skipping."
                )

        unranked_results = []
        for res in scored_results:
            if res["_internal_id"] not in ids_placed_by_llm_set:
                result_to_add = res.copy()
                result_to_add.pop("_internal_id", None)
                unranked_results.append(result_to_add)

        if unranked_results:
            logging.warning(
                f"The LLM did not explicitly rank {len(unranked_results)} result(s). "
                "These will be appended to the end of the sorted list, ordered by their original score."
            )

        unranked_results.sort(key=lambda x: x.get("score", -1), reverse=True)
        final_sorted_results.extend(unranked_results)

        self._dispatch(
            {
                "query": query,
                "original_scored_results": scored_results,
                "llm_ordered_ids": llm_ordered_ids,
                "final_sorted_results": final_sorted_results,
            }
        )
        return final_sorted_results
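
A minimal usage sketch for GemmaSorter, assuming the placeholder client and MODEL defined after the file list; the scored results here are invented sample data:

    from texttools.tools.reranker.sorter import GemmaSorter

    sorter = GemmaSorter(client, model=MODEL, use_reason=True)

    scored = [
        {"_internal_id": "a", "text": "LLM rerankers order results by relevance.", "score": 0.9},
        {"_internal_id": "b", "text": "Rerankers based on LLMs score documents.", "score": 0.9},
        {"_internal_id": "c", "text": "A recipe for lentil soup.", "score": 0.1},
    ]

    # Ties between "a" and "b" are broken by semantic similarity to the query;
    # the returned dictionaries no longer carry "_internal_id".
    ordered = sorter.perform("How do LLM-based rerankers work?", scored)

Note the prompt design in the hunk above: the final assistant message is primed with "{", and perform() prepends the brace back onto the reply before json.loads, so the model only has to emit the remainder of the JSON object.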
texttools/tools/rewriter/gemma_question_rewriter.py
@@ -0,0 +1,213 @@
from typing import Any, Optional

from openai import OpenAI

from texttools.base.base_question_rewriter import BaseQuestionRewriter, RewriteMode
from texttools.formatter import Gemma3Formatter

# class QuestionGeneration(BaseModel):
#     generated_question: str


class GemmaQuestionRewriter(BaseQuestionRewriter):
    """
    Question rewriter for Gemma-style models with two modes:
    1. Rewrite with the same meaning but different wording.
    2. Rewrite with a different meaning but similar wording.
    Outputs JSON with a single string field: {"rewritten_question": "..."}.

    Allows optional extra instructions via `prompt_template`.
    """

    def __init__(
        self,
        client: OpenAI,
        *,
        model: str,
        chat_formatter: Optional[Any] = None,
        use_reason: bool = False,
        temperature: float = 0.0,
        prompt_template: Optional[str] = None,
        handlers: Optional[list[Any]] = None,
        **client_kwargs: Any,
    ):
        super().__init__(handlers)
        self.client = client
        self.model = model
        self.temperature = temperature
        self.client_kwargs = client_kwargs

        self.chat_formatter = chat_formatter or Gemma3Formatter()

        self.use_reason = use_reason
        self.reason_summary = None
        self.prompt_template = prompt_template

        self.json_schema = {"rewritten_question": "string"}

    def _build_messages(
        self,
        question: str,
        mode: RewriteMode,
    ) -> list[dict[str, str]]:
        """
        Builds the message list for the LLM API call for question rewriting,
        adapting the prompt based on the chosen mode.
        """
        clean_question = self.preprocess(question)
        messages: list[dict[str, str]] = []

        if self.prompt_template:
            messages.append({"role": "user", "content": self.prompt_template})

        if self.reason_summary:
            messages.append(
                {
                    "role": "user",
                    "content": f"Based on this analysis: {self.reason_summary}",
                }
            )

        if mode == RewriteMode.SAME_MEANING_DIFFERENT_WORDING:
            instruction = (
                "Rewrite the following question using completely different wording and phrasing, "
                "ensuring its original meaning is perfectly preserved. The rewritten question "
                "should be distinct from the original but convey the exact same inquiry. "
                "**Respond in the language of the question.**"
            )
        elif mode == RewriteMode.DIFFERENT_MEANING_SIMILAR_WORDING:
            instruction = (
                "Rewrite the following question using *very similar wording and phrasing* "
                "to the original, but ensure the rewritten question has a *completely different meaning*. "
                "Focus on subtle changes that drastically alter the intent or subject of the question. "
                "**Respond in the language of the question.**"
            )
        else:
            raise ValueError(f"Unsupported rewrite mode: {mode}")

        messages.append({"role": "user", "content": instruction})
        messages.append(
            {"role": "user", "content": f"Here is the question: {clean_question}"}
        )

        # schema_instr = f"Respond only in JSON format: {json.dumps(self.json_schema)}"
        messages.append(
            {
                "role": "user",
                "content": """
                Respond only with the new generated question, without any additional information.
                **The generated question must be in the language of the user's input.**
                """,
            }
        )

        # messages.append({"role": "assistant", "content": "{"})
        # Deprecated approach for structured output.

        # Restructure the messages according to the provided formatter;
        # some models require custom settings.
        restructured = self.chat_formatter.format(messages=messages)

        return restructured

    def _reason(self, question: str, mode: RewriteMode) -> None:
        """
        Internal reasoning step to help the model understand the core meaning
        or structure of the question, depending on the mode.
        """
        if mode == RewriteMode.SAME_MEANING_DIFFERENT_WORDING:
            reason_prompt = """
            Analyze the following question to identify its core intent, key concepts,
            and the specific information it is seeking.
            Provide a brief, summarized understanding of the question's meaning that
            will help in rephrasing it accurately without changing its intent.

            **Respond in the language of the question.**
            """
        elif mode == RewriteMode.DIFFERENT_MEANING_SIMILAR_WORDING:
            reason_prompt = """
            Analyze the following question to identify its exact wording, phrasing,
            and the literal meaning it conveys.
            Provide a brief, summarized analysis of its linguistic structure and current meaning,
            which will then be used to create a new question with similar words but a different meaning.

            **Respond in the language of the question.**
            """
        else:
            raise ValueError(f"Unsupported rewrite mode for reason: {mode}")

        messages = [
            {"role": "user", "content": reason_prompt},
            {"role": "user", "content": f"Here is the question: {question}"},
        ]

        restructured = self.chat_formatter.format(messages=messages)

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=restructured,
            temperature=self.temperature,
            **self.client_kwargs,
        )

        reason_summary = resp.choices[0].message.content.strip()
        self.reason_summary = reason_summary

    def rewrite_question(
        self,
        question: str,
        mode: RewriteMode = RewriteMode.SAME_MEANING_DIFFERENT_WORDING,
        reason_summary: Optional[str] = None,
    ) -> str:
        """
        Rewrites the input `question` based on the specified `mode`.
        Optionally uses an internal reasoning step for better accuracy.
        """
        if self.use_reason and not reason_summary:
            self._reason(question, mode)
        elif reason_summary:
            self.reason_summary = reason_summary

        messages = self._build_messages(question, mode)

        # An earlier structured-output path, kept for reference; the current
        # code obtains the result without structured output:
        # completion = self.client.beta.chat.completions.parse(
        #     model=self.model,
        #     messages=messages,
        #     response_format=QuestionGeneration,
        #     temperature=self.temperature,
        #     extra_body=dict(guided_decoding_backend="outlines"),
        #     **self.client_kwargs,
        # )
        # message = completion.choices[0].message
        # if message.parsed:
        #     result = message.parsed.generated_question
        # else:
        #     raise ValueError(f"Failed to parse the response. Raw content: {message.content}")

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            **self.client_kwargs,
        )

        result = resp.choices[0].message.content.strip()

        # Dispatch and return.
        self._dispatch(
            {
                "original_question": question,
                "rewritten_question": result,
                "mode": mode.value,
            }
        )
        return result

    def get_reason(self):
        return self.reason_summary
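
A usage sketch for GemmaQuestionRewriter, reusing the placeholder client and MODEL from above; the sample question is invented:

    from texttools.base.base_question_rewriter import RewriteMode
    from texttools.tools.rewriter.gemma_question_rewriter import GemmaQuestionRewriter

    rewriter = GemmaQuestionRewriter(client, model=MODEL, use_reason=True)

    paraphrase = rewriter.rewrite_question(
        "What time does the library open?",
        mode=RewriteMode.SAME_MEANING_DIFFERENT_WORDING,
    )
    print(paraphrase)
    print(rewriter.get_reason())  # reasoning summary cached by the last call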
File without changes: texttools/tools/router/__init__.py (+0 -0).
texttools/tools/categorizer/llm/gemma_categorizer.py
@@ -0,0 +1,169 @@
from typing import Any, Literal, Optional

from openai import OpenAI
from pydantic import BaseModel

from texttools.base.base_categorizer import BaseCategorizer
from texttools.formatter import Gemma3Formatter
from texttools.handlers import ResultHandler


class Output(BaseModel):
    reason: str
    # Persian category labels. In English, roughly: religious beliefs,
    # Islamic ethics, rulings and jurisprudence (fiqh), Islamic history and
    # figures, religious sources, religion and society/politics, mysticism
    # and spirituality, none of the above.
    main_tag: Literal[
        "باورهای دینی",
        "اخلاق اسلامی",
        "احکام و فقه",
        "تاریخ اسلام و شخصیت ها",
        "منابع دینی",
        "دین و جامعه/سیاست",
        "عرفان و معنویت",
        "هیچکدام",
    ] = None


class GemmaCategorizer(BaseCategorizer):
    """
    Categorizer for Gemma-style models. It requires a predefined Enum of categories
    to choose from and returns an Enum member.
    Outputs JSON with a single string field: {"category": "..."}.

    Allows optional extra instructions via `prompt_template`.
    """

    def __init__(
        self,
        client: OpenAI,
        *,
        model: str,
        output_structure: BaseModel = Output,
        chat_formatter: Optional[Any] = None,
        use_reason: bool = False,
        temperature: float = 0.0,
        prompt_template: Optional[str] = None,
        handlers: Optional[list[ResultHandler]] = None,
        **client_kwargs: Any,
    ):
        super().__init__(handlers=handlers)
        self.client = client
        self.model = model
        self.temperature = temperature
        self.client_kwargs = client_kwargs
        self.output_structure = output_structure
        self.chat_formatter = chat_formatter or Gemma3Formatter()

        self.use_reason = use_reason
        self.prompt_template = prompt_template

    def _build_messages(
        self, text: str, reason: Optional[str] = None
    ) -> list[dict[str, str]]:
        """
        Builds the message list for the LLM API call for categorization.
        """
        clean_text = self.preprocess(text)

        messages: list[dict[str, str]] = []

        if self.prompt_template:
            messages.append({"role": "user", "content": self.prompt_template})

        if reason:
            messages.append(
                {"role": "user", "content": f"Based on this analysis: {reason}"}
            )

        # Persian instruction prompt. In English, roughly: "You are an expert
        # in religious studies. As the user, I will give you a text and ask
        # you to classify it into one of the categories below. The requested
        # output includes a 'reason' field; briefly state there why you chose
        # that category. The text you must classify:"
        messages.append(
            {
                "role": "user",
                "content": """
                تو یک متخصص علوم دینی هستی
                من به عنوان کاربر یک متن به تو میدم و از تو میخوام که
                اون متن رو در یکی از دسته بندی های زیر طبقه بندی کنی

                "باورهای دینی",
                "اخلاق اسلامی",
                "احکام و فقه",
                "تاریخ اسلام و شخصیت ها",
                "منابع دینی",
                "دین و جامعه/سیاست",
                "عرفان و معنویت",
                "هیچکدام",

                در خروجی که از تو خواسته شده بخشی با عنوان reason وجود دارد
                در اون بخش، دلیل انتخاب دسته بندی رو به صورت خلاصه بیاور

                متنی که باید طبقه بندی کنی:
                """,
            }
        )
        messages.append({"role": "user", "content": clean_text})
        restructured = self.chat_formatter.format(messages=messages)

        return restructured

    def _reason(self, text: str) -> str:
        """
        Internal reasoning step to help the model analyze the text for categorization.
        """
        # Persian instruction prompt. In English, roughly: "Our goal is text
        # classification. Read the text and give its main idea and a short
        # analysis. Keep your output very brief, at most 20 words."
        messages = [
            {
                "role": "user",
                "content": """
                هدف ما طبقه بندی متن هست
                متن رو بخون و ایده اصلی و آنالیزی کوتاه از اون رو ارائه بده

                بسیار خلاصه باشه خروجی تو
                نهایتا 20 کلمه
                """,
            },
            {
                "role": "user",
                "content": f"""
                {text}
                """,
            },
        ]

        restructured = self.chat_formatter.format(messages=messages)

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=restructured,
            temperature=self.temperature,
            **self.client_kwargs,
        )

        reason_summary = resp.choices[0].message.content.strip()
        return reason_summary

    def categorize(self, text: str):
        """
        Categorizes `text` by selecting an appropriate member from the predefined Enum.
        Optionally uses an internal reasoning step for better accuracy.
        """
        reason_summary = None
        if self.use_reason:
            reason_summary = self._reason(text)

        messages = self._build_messages(text, reason_summary)
        completion = self.client.beta.chat.completions.parse(
            model=self.model,
            messages=messages,
            response_format=Output,
            temperature=self.temperature,
            extra_body=dict(guided_decoding_backend="auto"),
            **self.client_kwargs,
        )
        message = completion.choices[0].message

        category_name = message.parsed.main_tag

        # Dispatch and return; note that _dispatch expects a dict.
        self._dispatch(results={"main_tag": category_name})
        return {"main_tag": category_name}
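
A usage sketch for GemmaCategorizer, again assuming the placeholder client and MODEL. Note that categorize() relies on client.beta.chat.completions.parse with a guided_decoding_backend extra, which suggests the endpoint must support structured output (a vLLM-style server); the sample input is invented:

    from texttools.tools.categorizer.llm.gemma_categorizer import GemmaCategorizer

    categorizer = GemmaCategorizer(client, model=MODEL, use_reason=True)

    # The prompt and category labels are Persian, so Persian input is the
    # expected case; the return value is a dict carrying the chosen label.
    result = categorizer.categorize("متنی درباره تاریخ صدر اسلام")  # "a text about early Islamic history"
    print(result)  # e.g. {"main_tag": "تاریخ اسلام و شخصیت ها"}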
texttools/tools/subject_to_question/__init__.py
@@ -0,0 +1 @@
from .gemma_question_generator import GemmaQuestionGeneratorFromSubject