unique_toolkit 1.42.9__py3-none-any.whl → 1.43.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/_common/experimental/write_up_agent/README.md +848 -0
- unique_toolkit/_common/experimental/write_up_agent/__init__.py +22 -0
- unique_toolkit/_common/experimental/write_up_agent/agent.py +170 -0
- unique_toolkit/_common/experimental/write_up_agent/config.py +42 -0
- unique_toolkit/_common/experimental/write_up_agent/examples/data.csv +13 -0
- unique_toolkit/_common/experimental/write_up_agent/examples/example_usage.py +78 -0
- unique_toolkit/_common/experimental/write_up_agent/schemas.py +36 -0
- unique_toolkit/_common/experimental/write_up_agent/services/__init__.py +13 -0
- unique_toolkit/_common/experimental/write_up_agent/services/dataframe_handler/__init__.py +19 -0
- unique_toolkit/_common/experimental/write_up_agent/services/dataframe_handler/exceptions.py +29 -0
- unique_toolkit/_common/experimental/write_up_agent/services/dataframe_handler/service.py +150 -0
- unique_toolkit/_common/experimental/write_up_agent/services/dataframe_handler/utils.py +130 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/__init__.py +27 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/config.py +56 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/exceptions.py +79 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/prompts/config.py +34 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/prompts/system_prompt.j2 +15 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/prompts/user_prompt.j2 +21 -0
- unique_toolkit/_common/experimental/write_up_agent/services/generation_handler/service.py +369 -0
- unique_toolkit/_common/experimental/write_up_agent/services/template_handler/__init__.py +29 -0
- unique_toolkit/_common/experimental/write_up_agent/services/template_handler/default_template.j2 +37 -0
- unique_toolkit/_common/experimental/write_up_agent/services/template_handler/exceptions.py +39 -0
- unique_toolkit/_common/experimental/write_up_agent/services/template_handler/service.py +191 -0
- unique_toolkit/_common/experimental/write_up_agent/services/template_handler/utils.py +182 -0
- unique_toolkit/_common/experimental/write_up_agent/utils.py +24 -0
- {unique_toolkit-1.42.9.dist-info → unique_toolkit-1.43.1.dist-info}/METADATA +7 -1
- {unique_toolkit-1.42.9.dist-info → unique_toolkit-1.43.1.dist-info}/RECORD +29 -4
- {unique_toolkit-1.42.9.dist-info → unique_toolkit-1.43.1.dist-info}/LICENSE +0 -0
- {unique_toolkit-1.42.9.dist-info → unique_toolkit-1.43.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
"""Generation handler service for LLM-based summarization."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
from jinja2 import Template
|
|
7
|
+
from tiktoken import get_encoding
|
|
8
|
+
|
|
9
|
+
from unique_toolkit._common.experimental.write_up_agent.schemas import (
|
|
10
|
+
GroupData,
|
|
11
|
+
ProcessedGroup,
|
|
12
|
+
)
|
|
13
|
+
from unique_toolkit._common.experimental.write_up_agent.services.dataframe_handler.utils import (
|
|
14
|
+
from_snake_case_to_display_name,
|
|
15
|
+
)
|
|
16
|
+
from unique_toolkit._common.experimental.write_up_agent.services.generation_handler.config import (
|
|
17
|
+
GenerationHandlerConfig,
|
|
18
|
+
)
|
|
19
|
+
from unique_toolkit._common.experimental.write_up_agent.services.generation_handler.exceptions import (
|
|
20
|
+
BatchCreationError,
|
|
21
|
+
LLMCallError,
|
|
22
|
+
PromptBuildError,
|
|
23
|
+
TokenLimitError,
|
|
24
|
+
)
|
|
25
|
+
from unique_toolkit.language_model import LanguageModelService
|
|
26
|
+
from unique_toolkit.language_model.builder import MessagesBuilder
|
|
27
|
+
|
|
28
|
+
_LOGGER = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class GenerationHandler:
    """
    Handles LLM-based generation with adaptive batching and iterative aggregation.

    This service:
    - Splits groups into batches based on token/row limits
    - Builds prompts from Jinja templates
    - Calls LLM for each batch
    - Aggregates results iteratively: each batch prompt receives the summary
      produced for the previous batch, and the last summary is the result
    """

    def __init__(
        self,
        config: GenerationHandlerConfig,
        renderer: Callable[[GroupData], str],
    ):
        """
        Initialize generation handler.

        Args:
            config: Configuration for generation
            renderer: Function to render group content (injected from template handler)
                     Signature: renderer(group_data: GroupData) -> str
        """
        self._config = config
        self._renderer = renderer

        # TODO [UN-16142]: Use token counter from toolkit
        try:
            encoder = get_encoding(self._config.language_model.encoder_name)
        except Exception as e:
            # Best effort: any failure to resolve the configured encoder falls
            # back to the cl100k_base encoding instead of aborting construction.
            _LOGGER.warning(
                "Failed to get encoder for model %s: %s",
                self._config.language_model.name,
                e,
            )
            encoder = get_encoding("cl100k_base")

        def token_counter(text: str) -> int:
            return len(encoder.encode(text))

        # Token counter bound to the encoder resolved above (configured
        # encoder, or cl100k_base when the lookup failed).
        self._token_counter = token_counter

    def _default_token_counter(self, text: str) -> int:
        """Default token counter using tiktoken encoding with cl100k_base.

        Not referenced by ``__init__``, which installs its own closure as
        ``self._token_counter``.
        """
        default_encoder = get_encoding("cl100k_base")
        return len(default_encoder.encode(text))

    def process_groups(
        self,
        groups: list[GroupData],
        grouping_column: str,
        llm_service: LanguageModelService,
    ) -> list[ProcessedGroup]:
        """
        Process all groups with LLM generation.

        Args:
            groups: List of GroupData instances
            grouping_column: The column name used for grouping (e.g., 'section')
            llm_service: LanguageModelService instance to use for LLM calls

        Returns:
            List of ProcessedGroup instances with llm_response added, in the
            same order as ``groups``

        Raises:
            Exception: Any error raised while processing a group is logged and
                re-raised unchanged; processing stops at the first failing group.
        """
        processed_groups: list[ProcessedGroup] = []

        for group in groups:
            group_key_string = group.group_key

            _LOGGER.info("Processing group: %s", group_key_string)

            # Get group-specific instruction using the documented format: "column:value"
            # e.g., "section:introduction" for a group with key "introduction" in column "section"
            lookup_key = f"{grouping_column}:{group_key_string}"
            group_instruction = self._config.group_specific_instructions.get(lookup_key)

            try:
                # Process group with batching
                llm_response = self._process_group_with_batching(
                    group, group_instruction, llm_service
                )

                processed_groups.append(
                    ProcessedGroup(
                        group_key=group.group_key,
                        rows=group.rows,
                        llm_response=llm_response,
                    )
                )

                # Tokenizing the response is only worth it when the message
                # will actually be emitted.
                if _LOGGER.isEnabledFor(logging.INFO):
                    _LOGGER.info(
                        "Successfully processed group: %s (response length: %s tokens)",
                        group_key_string,
                        self._token_counter(llm_response),
                    )

            except Exception as e:
                _LOGGER.error("Error processing group %s: %s", group_key_string, e)
                # Re-raise to allow caller to handle
                raise

        return processed_groups

    def _process_group_with_batching(
        self,
        group: GroupData,
        group_instruction: str | None,
        llm_service: LanguageModelService,
    ) -> str:
        """
        Process a single group with adaptive batching.

        Batches are processed in order. Each batch prompt is given the summary
        of the previous batch, so only one previous summary is kept at a time.

        Args:
            group: GroupData instance
            group_instruction: Optional group-specific instruction
            llm_service: LanguageModelService instance to use for LLM calls

        Returns:
            Summary produced for the last batch, or "" if that summary is empty

        Raises:
            BatchCreationError: If batching fails
            LLMCallError: If the LLM call fails, or if rendering or prompt
                building fails for a batch (those errors are wrapped)
        """
        group_key = group.group_key
        rows = group.rows

        # Create batches adaptively
        try:
            batches = self._create_batches(rows, group_key)
        except Exception as e:
            raise BatchCreationError(
                f"Failed to create batches for group {group_key}: {e}",
                group_key=str(group_key),
                row_count=len(rows),
            ) from e

        batch_count = len(batches)
        _LOGGER.info(
            "Created %s batches for group %s (%s total rows)",
            batch_count,
            group_key,
            len(rows),
        )

        # Process each batch iteratively, keeping only one previous summary at a time
        previous_summary: str | None = None

        # TODO [UN-16142]: Improve error handling logic for LLMCallError
        for batch_index, batch_group in enumerate(batches, start=1):
            try:
                # Render content for this batch
                content = self._renderer(batch_group)

                # Convert snake_case group_key to Title Case for display in prompts
                display_section_name = from_snake_case_to_display_name(group_key)

                # Build prompts with section name and at most one previous summary
                system_prompt, user_prompt = self._build_prompts(
                    section_name=display_section_name,
                    content=content,
                    group_instruction=group_instruction,
                    previous_summary=previous_summary,
                )

                # Call LLM
                batch_summary = self._call_llm(system_prompt, user_prompt, llm_service)

                # Keep only this summary for the next iteration
                previous_summary = batch_summary

                # Avoid re-tokenizing every summary unless debug output is on.
                if _LOGGER.isEnabledFor(logging.DEBUG):
                    _LOGGER.debug(
                        "Batch %s/%s processed (summary length: %s tokens)",
                        batch_index,
                        batch_count,
                        self._token_counter(batch_summary),
                    )

            except LLMCallError:
                # Already the documented error type; propagate unchanged.
                raise
            except Exception as e:
                raise LLMCallError(
                    f"Error processing batch {batch_index} for group {group_key}: {e}",
                    group_key=str(group_key),
                    batch_index=batch_index,
                    error_details=str(e),
                ) from e

        # Return final summary (last batch's result)
        return previous_summary or ""

    def _create_batches(
        self, rows: list[dict[str, Any]], group_key: str
    ) -> list[GroupData]:
        """
        Create batches adaptively based on token and row limits.

        Fits as many rows as possible per batch while staying under limits.
        Token usage is estimated from ``str(row)``, not from the rendered
        content. A single row that exceeds the token limit on its own is still
        placed in a batch by itself rather than rejected.

        Args:
            rows: List of row dicts
            group_key: Group identifier for creating GroupData instances

        Returns:
            List of GroupData instances (each representing a batch). An empty
            ``rows`` yields a single batch with no rows.

        Raises:
            TokenLimitError: If token counting fails
        """
        if not rows:
            return [GroupData(group_key=group_key, rows=[])]

        max_tokens = self._config.max_tokens_per_batch
        max_rows = self._config.max_rows_per_batch

        batches: list[GroupData] = []
        current_batch: list[dict[str, Any]] = []
        current_batch_tokens = 0

        for row in rows:
            # Estimate tokens for this row (rough approximation)
            try:
                row_tokens = self._token_counter(str(row))
            except Exception as e:
                raise TokenLimitError(
                    f"Failed to count tokens for row: {e}",
                    estimated_tokens=0,
                    max_tokens=max_tokens,
                ) from e

            # Check if adding this row would exceed limits
            would_exceed_tokens = current_batch_tokens + row_tokens > max_tokens
            would_exceed_rows = len(current_batch) >= max_rows

            if current_batch and (would_exceed_tokens or would_exceed_rows):
                # Close the current batch and start a new one with this row
                batches.append(GroupData(group_key=group_key, rows=current_batch))
                current_batch = [row]
                current_batch_tokens = row_tokens
            else:
                # Add to current batch
                current_batch.append(row)
                current_batch_tokens += row_tokens

        # Add final batch
        if current_batch:
            batches.append(GroupData(group_key=group_key, rows=current_batch))

        return batches

    def _build_prompts(
        self,
        section_name: str,
        content: str,
        group_instruction: str | None,
        previous_summary: str | None,
    ) -> tuple[str, str]:
        """
        Build system and user prompts from templates.

        Args:
            section_name: Display name of the section being processed
            content: Rendered content to summarize
            group_instruction: Optional group-specific instruction
            previous_summary: Optional previous batch summary for context

        Returns:
            Tuple of (system_prompt, user_prompt), each stripped of
            leading/trailing whitespace

        Raises:
            PromptBuildError: If prompt building fails
        """
        try:
            prompts_config = self._config.prompts_config

            # Build system prompt
            system_prompt = Template(prompts_config.system_prompt_template).render(
                common_instruction=self._config.common_instruction,
            )

            # Build user prompt with section name
            user_prompt = Template(prompts_config.user_prompt_template).render(
                section_name=section_name,
                content=content,
                group_instruction=group_instruction,
                previous_summary=previous_summary,
            )

            return system_prompt.strip(), user_prompt.strip()

        except Exception as e:
            raise PromptBuildError(
                f"Failed to build prompts: {e}",
                context={
                    "section_name": section_name,
                    "has_group_instruction": group_instruction is not None,
                    "has_previous_summary": previous_summary is not None,
                },
            ) from e

    def _call_llm(
        self, system_prompt: str, user_prompt: str, llm_service: LanguageModelService
    ) -> str:
        """
        Call LLM with prompts.

        Args:
            system_prompt: System prompt
            user_prompt: User prompt
            llm_service: LanguageModelService instance to use for LLM calls

        Returns:
            LLM response text (content of the first choice)

        Raises:
            LLMCallError: If the LLM call fails or the response content is not
                a string
        """
        messages = (
            MessagesBuilder()
            .system_message_append(system_prompt)
            .user_message_append(user_prompt)
            .build()
        )
        try:
            # Call the language model using the configured LMI
            response = llm_service.complete(
                messages=messages,
                model_name=self._config.language_model.name,
            )

            response_text = response.choices[0].message.content

            # Explicit check instead of `assert`: assertions are removed under
            # `python -O`, which would let a non-string content slip through.
            if not isinstance(response_text, str):
                raise TypeError("Response must be a string")
            return response_text

        except Exception as e:
            raise LLMCallError(
                f"LLM call failed: {e}",
                error_details=str(e),
            ) from e
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Template handler module."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from unique_toolkit._common.experimental.write_up_agent.services.template_handler.exceptions import (
|
|
6
|
+
ColumnExtractionError,
|
|
7
|
+
TemplateHandlerError,
|
|
8
|
+
TemplateParsingError,
|
|
9
|
+
TemplateRenderingError,
|
|
10
|
+
TemplateStructureError,
|
|
11
|
+
)
|
|
12
|
+
from unique_toolkit._common.experimental.write_up_agent.services.template_handler.service import (
|
|
13
|
+
TemplateHandler,
|
|
14
|
+
)
|
|
15
|
+
from unique_toolkit._common.experimental.write_up_agent.utils import template_loader
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def default_jinja_template_loader():
    """Load the bundled ``default_template.j2`` through ``template_loader``.

    Returns:
        Whatever ``template_loader`` returns for ``default_template.j2`` in
        this module's directory (presumably the template text; confirm
        against ``template_loader``).
    """
    template_dir = Path(__file__).parent
    return template_loader(template_dir, "default_template.j2")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"TemplateHandler",
|
|
24
|
+
"TemplateHandlerError",
|
|
25
|
+
"TemplateParsingError",
|
|
26
|
+
"TemplateStructureError",
|
|
27
|
+
"TemplateRenderingError",
|
|
28
|
+
"ColumnExtractionError",
|
|
29
|
+
]
|
unique_toolkit/_common/experimental/write_up_agent/services/template_handler/default_template.j2
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
{#
|
|
2
|
+
TODO [UN-16142]: Simplify template logic
|
|
3
|
+
Default Write-Up Agent Template
|
|
4
|
+
|
|
5
|
+
This template works in two phases:
|
|
6
|
+
1. PHASE 1 (LLM Input): When g.llm_response is not provided, renders all Q&A data
|
|
7
|
+
- Used to send structured data to the LLM for summarization
|
|
8
|
+
2. PHASE 2 (Final Report): When g.llm_response is provided, renders the LLM output
|
|
9
|
+
- Used to generate the final markdown report with LLM summaries
|
|
10
|
+
|
|
11
|
+
Template Variables:
|
|
12
|
+
- groups: List of data groups
|
|
13
|
+
- g.<column>: Access grouping columns (e.g., g.section)
|
|
14
|
+
- g.rows: List of rows in this group
|
|
15
|
+
- g.llm_response: (Optional) LLM-generated summary to replace row data
|
|
16
|
+
- row.<column>: Access data columns (e.g., row.question, row.answer)
|
|
17
|
+
#}
|
|
18
|
+
{% for g in groups %}
|
|
19
|
+
# {{ g.section }}
|
|
20
|
+
|
|
21
|
+
{% if g.llm_response %}
|
|
22
|
+
{# Phase 2: Render LLM-generated summary #}
|
|
23
|
+
{{ g.llm_response }}
|
|
24
|
+
{% else %}
|
|
25
|
+
{# Phase 1: Render detailed Q&A for LLM to process #}
|
|
26
|
+
{% for row in g.rows %}
|
|
27
|
+
**Q: {{ row.question }}**
|
|
28
|
+
|
|
29
|
+
A: {{ row.answer }}
|
|
30
|
+
|
|
31
|
+
{% endfor %}
|
|
32
|
+
{% endif %}
|
|
33
|
+
|
|
34
|
+
---
|
|
35
|
+
|
|
36
|
+
{% endfor %}
|
|
37
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Exceptions for template handler operations."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class TemplateHandlerError(Exception):
    """Root of the template handler exception hierarchy.

    Catch this type to handle every error defined for template handling.
    """
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TemplateParsingError(TemplateHandlerError):
    """Signals that a Jinja template could not be parsed.

    Attributes:
        template_snippet: Optional snippet of the template supplied by the
            raiser, or None when none was given.
    """

    def __init__(self, message: str, template_snippet: str | None = None):
        self.template_snippet = template_snippet
        super().__init__(message)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TemplateStructureError(TemplateHandlerError):
    """Signals that a template lacks the structure the handler requires.

    Attributes:
        expected_structure: Optional description of the required structure,
            or None when none was given.
    """

    def __init__(self, message: str, expected_structure: str | None = None):
        self.expected_structure = expected_structure
        super().__init__(message)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TemplateRenderingError(TemplateHandlerError):
    """Signals that rendering a template failed.

    Attributes:
        context_keys: Context keys reported by the raiser; an empty list when
            none (or an empty list) was given.
    """

    def __init__(self, message: str, context_keys: list[str] | None = None):
        super().__init__(message)
        if context_keys:
            self.context_keys = context_keys
        else:
            self.context_keys = []
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ColumnExtractionError(TemplateHandlerError):
    """Signals that columns could not be extracted from a template.

    Attributes:
        detected_columns: Columns reported by the raiser; an empty list when
            none (or an empty list) was given.
    """

    def __init__(self, message: str, detected_columns: list[str] | None = None):
        super().__init__(message)
        if detected_columns:
            self.detected_columns = detected_columns
        else:
            self.detected_columns = []
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Template handler service."""
|
|
2
|
+
|
|
3
|
+
from jinja2 import Template, TemplateError
|
|
4
|
+
|
|
5
|
+
from unique_toolkit._common.experimental.write_up_agent.schemas import (
|
|
6
|
+
GroupData,
|
|
7
|
+
ProcessedGroup,
|
|
8
|
+
)
|
|
9
|
+
from unique_toolkit._common.experimental.write_up_agent.services.dataframe_handler.utils import (
|
|
10
|
+
from_snake_case_to_display_name,
|
|
11
|
+
)
|
|
12
|
+
from unique_toolkit._common.experimental.write_up_agent.services.template_handler.exceptions import (
|
|
13
|
+
ColumnExtractionError,
|
|
14
|
+
TemplateParsingError,
|
|
15
|
+
TemplateRenderingError,
|
|
16
|
+
TemplateStructureError,
|
|
17
|
+
)
|
|
18
|
+
from unique_toolkit._common.experimental.write_up_agent.services.template_handler.utils import (
|
|
19
|
+
parse_template,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# TODO [UN-16142]: Simplify template logic
|
|
24
|
+
class TemplateHandler:
    """
    Handles all template operations.

    Responsibilities:
    - Extract grouping column (single only)
    - Extract selected columns
    - Render template for groups
    """

    def __init__(self, template: str):
        """
        Initialize template handler.

        Args:
            template: Jinja template string

        Raises:
            TemplateParsingError: If template cannot be parsed
        """
        self._template = template

        try:
            self._jinja_template = Template(template, lstrip_blocks=True)
        except TemplateError as e:
            # Attach at most the first 100 characters of the template.
            snippet = template[:100] + "..." if len(template) > 100 else template
            raise TemplateParsingError(
                f"Failed to parse Jinja template: {e}", template_snippet=snippet
            ) from e

        # Filled lazily by _get_parsed_info() on first use.
        self._parsed_info = None

    def _get_parsed_info(self):
        """
        Lazy parse template.

        The result of ``parse_template`` is cached on the instance, so the
        template structure is analysed at most once per successful parse.

        Raises:
            TemplateParsingError: If template structure cannot be parsed
        """
        if self._parsed_info is None:
            try:
                self._parsed_info = parse_template(self._template)
            except Exception as e:
                raise TemplateParsingError(
                    f"Failed to parse template structure: {e}"
                ) from e
        return self._parsed_info

    def get_grouping_column(self) -> str:
        """
        Extract the single grouping column.

        Returns:
            Column name to group by

        Raises:
            TemplateParsingError: If template structure cannot be parsed
            TemplateStructureError: If template structure is invalid
            ColumnExtractionError: If zero or more than one grouping column
                is detected
        """
        info = self._get_parsed_info()

        if not info.expects_groups:
            raise TemplateStructureError(
                "Template must use grouping pattern: {% for g in groups %}",
                expected_structure="{% for g in groups %}",
            )

        grouping_columns = info.grouping_columns
        column_count = len(grouping_columns)

        if column_count == 0:
            raise ColumnExtractionError(
                "No grouping column detected in template. Use {{ g.column_name }} to reference grouping columns."
            )

        if column_count > 1:
            raise ColumnExtractionError(
                f"Single grouping column required. Found {column_count}: {grouping_columns}",
                detected_columns=grouping_columns,
            )

        return grouping_columns[0]

    def get_selected_columns(self) -> list[str]:
        """
        Extract columns referenced in template.

        Returns:
            List of column names from {{ row.column }} patterns

        Raises:
            TemplateParsingError: If template structure cannot be parsed
        """
        return self._get_parsed_info().row_columns

    @staticmethod
    def _build_group_item(
        grouping_column: str,
        group_key: str,
        rows: object,
        llm_response: str | None,
    ) -> dict[str, object]:
        """Build the per-group mapping exposed to the template as ``g``."""
        return {
            # Convert snake_case group_key to a display name for rendering
            grouping_column: from_snake_case_to_display_name(group_key),
            "rows": rows,
            "llm_response": llm_response,
        }

    def render_group(
        self, group_data: GroupData, llm_response: str | None = None
    ) -> str:
        """
        Render template for a single group.

        This method supports two rendering modes:
        1. Without llm_response: Renders the full row data (for LLM input)
        2. With llm_response: Renders the LLM output instead of row data (for final report)

        Args:
            group_data: GroupData instance with group_key and rows
            llm_response: Optional LLM-generated output. It is passed to the
                         template as ``g.llm_response``; which branch is
                         rendered is decided by the template itself.

        Returns:
            Rendered template string

        Raises:
            TemplateRenderingError: If rendering fails
            TemplateParsingError, TemplateStructureError, ColumnExtractionError:
                Propagated unchanged from get_grouping_column()
        """
        # Keys available on each group item; reported on failure. The grouping
        # column is prepended once it is known.
        context_keys = ["rows", "llm_response"]
        try:
            grouping_column = self.get_grouping_column()
            context_keys.insert(0, grouping_column)

            group_item = self._build_group_item(
                grouping_column,
                group_data.group_key,
                group_data.rows,
                llm_response,
            )

            # Render with groups list (template expects {% for g in groups %})
            return self._jinja_template.render(groups=[group_item])
        except (TemplateError, KeyError) as e:
            raise TemplateRenderingError(
                f"Failed to render template: {e}", context_keys=context_keys
            ) from e

    def render_all_groups(self, processed_groups: list[ProcessedGroup]) -> str:
        """
        Render template for all groups at once.

        Takes advantage of the template's {% for g in groups %} loop
        to render all groups in a single pass.

        Args:
            processed_groups: List of ProcessedGroup instances with group_key, rows, and llm_response

        Returns:
            Rendered template string with all groups

        Raises:
            TemplateRenderingError: If rendering fails
            TemplateParsingError, TemplateStructureError, ColumnExtractionError:
                Propagated unchanged from get_grouping_column()
        """
        # Same reporting convention as render_group().
        context_keys = ["rows", "llm_response"]
        try:
            grouping_column = self.get_grouping_column()
            context_keys.insert(0, grouping_column)

            # Prepare all groups for rendering
            groups_data = [
                self._build_group_item(
                    grouping_column,
                    processed.group_key,
                    processed.rows,
                    processed.llm_response,
                )
                for processed in processed_groups
            ]

            # Render all groups at once using template's loop
            return self._jinja_template.render(groups=groups_data)

        except (TemplateError, KeyError) as e:
            raise TemplateRenderingError(
                f"Failed to render all groups: {e}", context_keys=context_keys
            ) from e
|