openai-sdk-helpers 0.6.1__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openai_sdk_helpers/agent/__init__.py +2 -0
- openai_sdk_helpers/agent/base.py +75 -7
- openai_sdk_helpers/agent/classifier.py +284 -102
- openai_sdk_helpers/agent/configuration.py +42 -0
- openai_sdk_helpers/agent/files.py +120 -0
- openai_sdk_helpers/agent/runner.py +9 -9
- openai_sdk_helpers/agent/translator.py +2 -2
- openai_sdk_helpers/files_api.py +46 -1
- openai_sdk_helpers/prompt/classifier.jinja +25 -10
- openai_sdk_helpers/structure/__init__.py +8 -2
- openai_sdk_helpers/structure/classification.py +240 -85
- {openai_sdk_helpers-0.6.1.dist-info → openai_sdk_helpers-0.6.4.dist-info}/METADATA +1 -1
- {openai_sdk_helpers-0.6.1.dist-info → openai_sdk_helpers-0.6.4.dist-info}/RECORD +16 -15
- {openai_sdk_helpers-0.6.1.dist-info → openai_sdk_helpers-0.6.4.dist-info}/WHEEL +0 -0
- {openai_sdk_helpers-0.6.1.dist-info → openai_sdk_helpers-0.6.4.dist-info}/entry_points.txt +0 -0
- {openai_sdk_helpers-0.6.1.dist-info → openai_sdk_helpers-0.6.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,19 +3,26 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import asyncio
|
|
6
|
+
import threading
|
|
6
7
|
import re
|
|
7
8
|
from dataclasses import dataclass, field
|
|
8
9
|
from enum import Enum
|
|
9
10
|
from pathlib import Path
|
|
10
11
|
from typing import Any, Awaitable, Dict, Iterable, Optional, Sequence, cast
|
|
11
12
|
|
|
13
|
+
from agents.model_settings import ModelSettings
|
|
14
|
+
|
|
12
15
|
from ..structure import (
|
|
13
16
|
ClassificationResult,
|
|
14
17
|
ClassificationStep,
|
|
15
18
|
ClassificationStopReason,
|
|
16
19
|
StructureBase,
|
|
20
|
+
Taxonomy,
|
|
17
21
|
TaxonomyNode,
|
|
22
|
+
format_path_identifier,
|
|
23
|
+
split_path_identifier,
|
|
18
24
|
)
|
|
25
|
+
from ..utils import ensure_list
|
|
19
26
|
from .base import AgentBase
|
|
20
27
|
from .configuration import AgentConfiguration
|
|
21
28
|
|
|
@@ -29,11 +36,15 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
29
36
|
Optional template file path for prompt rendering.
|
|
30
37
|
model : str | None, default=None
|
|
31
38
|
Model identifier to use for classification.
|
|
39
|
+
model_settings : ModelSettings | None, default=None
|
|
40
|
+
Optional model settings to apply to the classifier agent.
|
|
32
41
|
|
|
33
42
|
Methods
|
|
34
43
|
-------
|
|
35
|
-
|
|
36
|
-
Classify text
|
|
44
|
+
run_async(input, context, max_depth, confidence_threshold)
|
|
45
|
+
Classify text asynchronously using taxonomy traversal.
|
|
46
|
+
run_sync(input, context, max_depth, confidence_threshold)
|
|
47
|
+
Classify text synchronously using taxonomy traversal.
|
|
37
48
|
|
|
38
49
|
Examples
|
|
39
50
|
--------
|
|
@@ -51,6 +62,7 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
51
62
|
*,
|
|
52
63
|
template_path: Path | str | None = None,
|
|
53
64
|
model: str | None = None,
|
|
65
|
+
model_settings: ModelSettings | None = None,
|
|
54
66
|
taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
|
|
55
67
|
) -> None:
|
|
56
68
|
"""Initialize the taxonomy classifier agent configuration.
|
|
@@ -61,6 +73,8 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
61
73
|
Optional template file path for prompt rendering.
|
|
62
74
|
model : str | None, default=None
|
|
63
75
|
Model identifier to use for classification.
|
|
76
|
+
model_settings : ModelSettings | None, default=None
|
|
77
|
+
Optional model settings to apply to the classifier agent.
|
|
64
78
|
taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
|
|
65
79
|
Root taxonomy node or list of root nodes.
|
|
66
80
|
|
|
@@ -85,17 +99,19 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
85
99
|
template_path=resolved_template_path,
|
|
86
100
|
output_structure=ClassificationStep,
|
|
87
101
|
model=model,
|
|
102
|
+
model_settings=model_settings,
|
|
88
103
|
)
|
|
89
104
|
super().__init__(configuration=configuration)
|
|
90
105
|
|
|
91
|
-
async def
|
|
106
|
+
async def _run_agent(
|
|
92
107
|
self,
|
|
93
108
|
text: str,
|
|
94
109
|
*,
|
|
95
110
|
context: Optional[Dict[str, Any]] = None,
|
|
111
|
+
file_ids: str | Sequence[str] | None = None,
|
|
96
112
|
max_depth: Optional[int] = None,
|
|
97
113
|
confidence_threshold: float | None = None,
|
|
98
|
-
|
|
114
|
+
session: Optional[Any] = None,
|
|
99
115
|
) -> ClassificationResult:
|
|
100
116
|
"""Classify ``text`` by recursively walking taxonomy levels.
|
|
101
117
|
|
|
@@ -105,12 +121,14 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
105
121
|
Source text to classify.
|
|
106
122
|
context : dict or None, default=None
|
|
107
123
|
Additional context values to merge into the prompt.
|
|
124
|
+
file_ids : str or Sequence[str] or None, default=None
|
|
125
|
+
Optional file IDs to attach to each classification step.
|
|
108
126
|
max_depth : int or None, default=None
|
|
109
127
|
Maximum depth to traverse before stopping.
|
|
110
128
|
confidence_threshold : float or None, default=None
|
|
111
129
|
Minimum confidence required to accept a classification step.
|
|
112
|
-
|
|
113
|
-
|
|
130
|
+
session : Session or None, default=None
|
|
131
|
+
Optional session for maintaining conversation history across runs.
|
|
114
132
|
|
|
115
133
|
Returns
|
|
116
134
|
-------
|
|
@@ -125,61 +143,222 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
125
143
|
True
|
|
126
144
|
"""
|
|
127
145
|
state = _TraversalState()
|
|
146
|
+
input_payload = _build_input_payload(text, file_ids)
|
|
128
147
|
await self._classify_nodes(
|
|
129
|
-
|
|
148
|
+
input_payload=input_payload,
|
|
130
149
|
nodes=list(self._root_nodes),
|
|
131
150
|
depth=0,
|
|
132
151
|
parent_path=[],
|
|
133
152
|
context=context,
|
|
153
|
+
file_ids=file_ids,
|
|
134
154
|
max_depth=max_depth,
|
|
135
155
|
confidence_threshold=confidence_threshold,
|
|
136
|
-
|
|
156
|
+
session=session,
|
|
137
157
|
state=state,
|
|
138
158
|
)
|
|
139
159
|
|
|
140
160
|
final_nodes_value = state.final_nodes or None
|
|
141
|
-
final_node = state.final_nodes[0] if state.final_nodes else None
|
|
142
161
|
stop_reason = _resolve_stop_reason(state)
|
|
143
162
|
return ClassificationResult(
|
|
144
|
-
final_node=final_node,
|
|
145
163
|
final_nodes=final_nodes_value,
|
|
146
164
|
confidence=state.best_confidence,
|
|
147
165
|
stop_reason=stop_reason,
|
|
148
|
-
|
|
149
|
-
|
|
166
|
+
steps=state.steps,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
async def run_async(
|
|
170
|
+
self,
|
|
171
|
+
input: str | list[dict[str, Any]],
|
|
172
|
+
*,
|
|
173
|
+
context: Optional[Dict[str, Any]] = None,
|
|
174
|
+
output_structure: Optional[type[StructureBase]] = None,
|
|
175
|
+
session: Optional[Any] = None,
|
|
176
|
+
file_ids: str | Sequence[str] | None = None,
|
|
177
|
+
max_depth: Optional[int] = None,
|
|
178
|
+
confidence_threshold: float | None = None,
|
|
179
|
+
) -> ClassificationResult:
|
|
180
|
+
"""Classify ``input`` asynchronously with taxonomy traversal.
|
|
181
|
+
|
|
182
|
+
Parameters
|
|
183
|
+
----------
|
|
184
|
+
input : str or list[dict[str, Any]]
|
|
185
|
+
Source text to classify.
|
|
186
|
+
context : dict or None, default=None
|
|
187
|
+
Additional context values to merge into the prompt.
|
|
188
|
+
output_structure : type[StructureBase] or None, default=None
|
|
189
|
+
Unused in taxonomy traversal. Present for API compatibility.
|
|
190
|
+
session : Session or None, default=None
|
|
191
|
+
Optional session for maintaining conversation history across runs.
|
|
192
|
+
file_ids : str or Sequence[str] or None, default=None
|
|
193
|
+
Optional file IDs to attach to each classification step.
|
|
194
|
+
max_depth : int or None, default=None
|
|
195
|
+
Maximum depth to traverse before stopping.
|
|
196
|
+
confidence_threshold : float or None, default=None
|
|
197
|
+
Minimum confidence required to accept a classification step.
|
|
198
|
+
|
|
199
|
+
Returns
|
|
200
|
+
-------
|
|
201
|
+
ClassificationResult
|
|
202
|
+
Structured classification result describing the traversal.
|
|
203
|
+
"""
|
|
204
|
+
_ = output_structure
|
|
205
|
+
if not isinstance(input, str):
|
|
206
|
+
msg = "TaxonomyClassifierAgent run_async requires text input."
|
|
207
|
+
raise TypeError(msg)
|
|
208
|
+
kwargs: Dict[str, Any] = {
|
|
209
|
+
"context": context,
|
|
210
|
+
"file_ids": file_ids,
|
|
211
|
+
"max_depth": max_depth,
|
|
212
|
+
"confidence_threshold": confidence_threshold,
|
|
213
|
+
}
|
|
214
|
+
if session is not None:
|
|
215
|
+
kwargs["session"] = session
|
|
216
|
+
return await self._run_agent(input, **kwargs)
|
|
217
|
+
|
|
218
|
+
def run_sync(
|
|
219
|
+
self,
|
|
220
|
+
input: str | list[dict[str, Any]],
|
|
221
|
+
*,
|
|
222
|
+
context: Optional[Dict[str, Any]] = None,
|
|
223
|
+
output_structure: Optional[type[StructureBase]] = None,
|
|
224
|
+
session: Optional[Any] = None,
|
|
225
|
+
file_ids: str | Sequence[str] | None = None,
|
|
226
|
+
max_depth: Optional[int] = None,
|
|
227
|
+
confidence_threshold: float | None = None,
|
|
228
|
+
) -> ClassificationResult:
|
|
229
|
+
"""Classify ``input`` synchronously with taxonomy traversal.
|
|
230
|
+
|
|
231
|
+
Parameters
|
|
232
|
+
----------
|
|
233
|
+
input : str or list[dict[str, Any]]
|
|
234
|
+
Source text to classify.
|
|
235
|
+
context : dict or None, default=None
|
|
236
|
+
Additional context values to merge into the prompt.
|
|
237
|
+
output_structure : type[StructureBase] or None, default=None
|
|
238
|
+
Unused in taxonomy traversal. Present for API compatibility.
|
|
239
|
+
session : Session or None, default=None
|
|
240
|
+
Optional session for maintaining conversation history across runs.
|
|
241
|
+
file_ids : str or Sequence[str] or None, default=None
|
|
242
|
+
Optional file IDs to attach to each classification step.
|
|
243
|
+
max_depth : int or None, default=None
|
|
244
|
+
Maximum depth to traverse before stopping.
|
|
245
|
+
confidence_threshold : float or None, default=None
|
|
246
|
+
Minimum confidence required to accept a classification step.
|
|
247
|
+
|
|
248
|
+
Returns
|
|
249
|
+
-------
|
|
250
|
+
ClassificationResult
|
|
251
|
+
Structured classification result describing the traversal.
|
|
252
|
+
"""
|
|
253
|
+
_ = output_structure
|
|
254
|
+
if not isinstance(input, str):
|
|
255
|
+
msg = "TaxonomyClassifierAgent run_sync requires text input."
|
|
256
|
+
raise TypeError(msg)
|
|
257
|
+
kwargs: Dict[str, Any] = {
|
|
258
|
+
"context": context,
|
|
259
|
+
"file_ids": file_ids,
|
|
260
|
+
"max_depth": max_depth,
|
|
261
|
+
"confidence_threshold": confidence_threshold,
|
|
262
|
+
}
|
|
263
|
+
if session is not None:
|
|
264
|
+
kwargs["session"] = session
|
|
265
|
+
|
|
266
|
+
async def runner() -> ClassificationResult:
|
|
267
|
+
return await self._run_agent(input, **kwargs)
|
|
268
|
+
|
|
269
|
+
try:
|
|
270
|
+
asyncio.get_running_loop()
|
|
271
|
+
except RuntimeError:
|
|
272
|
+
return asyncio.run(runner())
|
|
273
|
+
|
|
274
|
+
result: ClassificationResult | None = None
|
|
275
|
+
error: Exception | None = None
|
|
276
|
+
|
|
277
|
+
def _thread_func() -> None:
|
|
278
|
+
nonlocal error, result
|
|
279
|
+
try:
|
|
280
|
+
result = asyncio.run(runner())
|
|
281
|
+
except Exception as exc:
|
|
282
|
+
error = exc
|
|
283
|
+
|
|
284
|
+
thread = threading.Thread(target=_thread_func)
|
|
285
|
+
thread.start()
|
|
286
|
+
thread.join()
|
|
287
|
+
|
|
288
|
+
if error is not None:
|
|
289
|
+
raise error
|
|
290
|
+
if result is None:
|
|
291
|
+
msg = "Classification did not return a result"
|
|
292
|
+
raise RuntimeError(msg)
|
|
293
|
+
return result
|
|
294
|
+
|
|
295
|
+
async def _run_step_async(
|
|
296
|
+
self,
|
|
297
|
+
*,
|
|
298
|
+
input: str | list[dict[str, Any]],
|
|
299
|
+
context: Optional[Dict[str, Any]] = None,
|
|
300
|
+
output_structure: Optional[type[StructureBase]] = None,
|
|
301
|
+
session: Optional[Any] = None,
|
|
302
|
+
) -> StructureBase:
|
|
303
|
+
"""Execute a single classification step asynchronously.
|
|
304
|
+
|
|
305
|
+
Parameters
|
|
306
|
+
----------
|
|
307
|
+
input : str or list[dict[str, Any]]
|
|
308
|
+
Prompt or structured input for the agent.
|
|
309
|
+
context : dict or None, default=None
|
|
310
|
+
Optional dictionary passed to the agent.
|
|
311
|
+
output_structure : type[StructureBase] or None, default=None
|
|
312
|
+
Optional type used to cast the final output.
|
|
313
|
+
session : Session or None, default=None
|
|
314
|
+
Optional session for maintaining conversation history across runs.
|
|
315
|
+
|
|
316
|
+
Returns
|
|
317
|
+
-------
|
|
318
|
+
StructureBase
|
|
319
|
+
Parsed result for the classification step.
|
|
320
|
+
"""
|
|
321
|
+
return await super().run_async(
|
|
322
|
+
input=input,
|
|
323
|
+
context=context,
|
|
324
|
+
output_structure=output_structure,
|
|
325
|
+
session=session,
|
|
150
326
|
)
|
|
151
327
|
|
|
152
328
|
async def _classify_nodes(
|
|
153
329
|
self,
|
|
154
330
|
*,
|
|
155
|
-
|
|
331
|
+
input_payload: str | list[dict[str, Any]],
|
|
156
332
|
nodes: list[TaxonomyNode],
|
|
157
333
|
depth: int,
|
|
158
334
|
parent_path: list[str],
|
|
159
335
|
context: Optional[Dict[str, Any]],
|
|
336
|
+
file_ids: str | Sequence[str] | None,
|
|
160
337
|
max_depth: Optional[int],
|
|
161
338
|
confidence_threshold: float | None,
|
|
162
|
-
|
|
339
|
+
session: Optional[Any],
|
|
163
340
|
state: "_TraversalState",
|
|
164
341
|
) -> None:
|
|
165
342
|
"""Classify a taxonomy level and recursively traverse children.
|
|
166
343
|
|
|
167
344
|
Parameters
|
|
168
345
|
----------
|
|
169
|
-
|
|
170
|
-
|
|
346
|
+
input_payload : str or list[dict[str, Any]]
|
|
347
|
+
Input payload used to prompt the agent.
|
|
171
348
|
nodes : list[TaxonomyNode]
|
|
172
349
|
Candidate taxonomy nodes for the current level.
|
|
173
350
|
depth : int
|
|
174
351
|
Current traversal depth.
|
|
175
352
|
context : dict or None
|
|
176
353
|
Additional context values to merge into the prompt.
|
|
354
|
+
file_ids : str or Sequence[str] or None
|
|
355
|
+
Optional file IDs attached to each classification step.
|
|
177
356
|
max_depth : int or None
|
|
178
357
|
Maximum traversal depth before stopping.
|
|
179
358
|
confidence_threshold : float or None
|
|
180
359
|
Minimum confidence required to accept a classification step.
|
|
181
|
-
|
|
182
|
-
|
|
360
|
+
session : Session or None
|
|
361
|
+
Optional session for maintaining conversation history across runs.
|
|
183
362
|
state : _TraversalState
|
|
184
363
|
Aggregated traversal state.
|
|
185
364
|
"""
|
|
@@ -192,18 +371,19 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
192
371
|
node_paths = _build_node_path_map(nodes, parent_path)
|
|
193
372
|
template_context = _build_context(
|
|
194
373
|
node_descriptors=_build_node_descriptors(node_paths),
|
|
195
|
-
|
|
374
|
+
steps=state.steps,
|
|
196
375
|
depth=depth,
|
|
197
376
|
context=context,
|
|
198
377
|
)
|
|
199
378
|
step_structure = _build_step_structure(list(node_paths.keys()))
|
|
200
|
-
raw_step = await self.
|
|
201
|
-
input=
|
|
379
|
+
raw_step = await self._run_step_async(
|
|
380
|
+
input=input_payload,
|
|
202
381
|
context=template_context,
|
|
203
382
|
output_structure=step_structure,
|
|
383
|
+
session=session,
|
|
204
384
|
)
|
|
205
385
|
step = _normalize_step_output(raw_step, step_structure)
|
|
206
|
-
state.
|
|
386
|
+
state.steps.append(step)
|
|
207
387
|
|
|
208
388
|
if (
|
|
209
389
|
confidence_threshold is not None
|
|
@@ -213,10 +393,6 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
213
393
|
return
|
|
214
394
|
|
|
215
395
|
resolved_nodes = _resolve_nodes(node_paths, step)
|
|
216
|
-
if resolved_nodes:
|
|
217
|
-
if single_class:
|
|
218
|
-
resolved_nodes = resolved_nodes[:1]
|
|
219
|
-
state.path_nodes.extend(resolved_nodes)
|
|
220
396
|
|
|
221
397
|
if step.stop_reason.is_terminal:
|
|
222
398
|
if resolved_nodes:
|
|
@@ -230,8 +406,7 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
230
406
|
if not resolved_nodes:
|
|
231
407
|
return
|
|
232
408
|
|
|
233
|
-
|
|
234
|
-
base_path_nodes_len = len(state.path_nodes)
|
|
409
|
+
base_steps_len = len(state.steps)
|
|
235
410
|
child_tasks: list[tuple[Awaitable["_TraversalState"], int]] = []
|
|
236
411
|
for node in resolved_nodes:
|
|
237
412
|
if node.children:
|
|
@@ -242,14 +417,15 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
242
417
|
(
|
|
243
418
|
self._classify_subtree(
|
|
244
419
|
sub_agent=sub_agent,
|
|
245
|
-
|
|
420
|
+
input_payload=input_payload,
|
|
246
421
|
nodes=list(node.children),
|
|
247
422
|
depth=depth + 1,
|
|
248
423
|
parent_path=[*parent_path, node.label],
|
|
249
424
|
context=context,
|
|
425
|
+
file_ids=file_ids,
|
|
250
426
|
max_depth=max_depth,
|
|
251
427
|
confidence_threshold=confidence_threshold,
|
|
252
|
-
|
|
428
|
+
session=session,
|
|
253
429
|
state=sub_state,
|
|
254
430
|
),
|
|
255
431
|
base_final_nodes_len,
|
|
@@ -268,8 +444,7 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
268
444
|
for child_state, (_, base_final_nodes_len) in zip(
|
|
269
445
|
child_states, child_tasks, strict=True
|
|
270
446
|
):
|
|
271
|
-
state.
|
|
272
|
-
state.path_nodes.extend(child_state.path_nodes[base_path_nodes_len:])
|
|
447
|
+
state.steps.extend(child_state.steps[base_steps_len:])
|
|
273
448
|
state.final_nodes.extend(child_state.final_nodes[base_final_nodes_len:])
|
|
274
449
|
state.best_confidence = _max_confidence(
|
|
275
450
|
state.best_confidence, child_state.best_confidence
|
|
@@ -323,23 +498,25 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
323
498
|
sub_agent = TaxonomyClassifierAgent(
|
|
324
499
|
template_path=self._template_path,
|
|
325
500
|
model=self._model,
|
|
501
|
+
model_settings=self._model_settings,
|
|
326
502
|
taxonomy=list(nodes),
|
|
327
503
|
)
|
|
328
|
-
sub_agent.
|
|
504
|
+
sub_agent._run_step_async = self._run_step_async
|
|
329
505
|
return sub_agent
|
|
330
506
|
|
|
331
507
|
async def _classify_subtree(
|
|
332
508
|
self,
|
|
333
509
|
*,
|
|
334
510
|
sub_agent: "TaxonomyClassifierAgent",
|
|
335
|
-
|
|
511
|
+
input_payload: str | list[dict[str, Any]],
|
|
336
512
|
nodes: list[TaxonomyNode],
|
|
337
513
|
depth: int,
|
|
338
514
|
parent_path: list[str],
|
|
339
515
|
context: Optional[Dict[str, Any]],
|
|
516
|
+
file_ids: str | Sequence[str] | None,
|
|
340
517
|
max_depth: Optional[int],
|
|
341
518
|
confidence_threshold: float | None,
|
|
342
|
-
|
|
519
|
+
session: Optional[Any],
|
|
343
520
|
state: "_TraversalState",
|
|
344
521
|
) -> "_TraversalState":
|
|
345
522
|
"""Classify a taxonomy subtree and return the traversal state.
|
|
@@ -348,8 +525,8 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
348
525
|
----------
|
|
349
526
|
sub_agent : TaxonomyClassifierAgent
|
|
350
527
|
Sub-agent configured for the subtree traversal.
|
|
351
|
-
|
|
352
|
-
|
|
528
|
+
input_payload : str or list[dict[str, Any]]
|
|
529
|
+
Input payload used to prompt the agent.
|
|
353
530
|
nodes : list[TaxonomyNode]
|
|
354
531
|
Candidate taxonomy nodes for the subtree.
|
|
355
532
|
depth : int
|
|
@@ -358,12 +535,14 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
358
535
|
Path segments leading to the current subtree.
|
|
359
536
|
context : dict or None
|
|
360
537
|
Additional context values to merge into the prompt.
|
|
538
|
+
file_ids : str or Sequence[str] or None
|
|
539
|
+
Optional file IDs attached to each classification step.
|
|
361
540
|
max_depth : int or None
|
|
362
541
|
Maximum traversal depth before stopping.
|
|
363
542
|
confidence_threshold : float or None
|
|
364
543
|
Minimum confidence required to accept a classification step.
|
|
365
|
-
|
|
366
|
-
|
|
544
|
+
session : Session or None
|
|
545
|
+
Optional session for maintaining conversation history across runs.
|
|
367
546
|
state : _TraversalState
|
|
368
547
|
Traversal state to populate for the subtree.
|
|
369
548
|
|
|
@@ -373,14 +552,15 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
373
552
|
Populated traversal state for the subtree.
|
|
374
553
|
"""
|
|
375
554
|
await sub_agent._classify_nodes(
|
|
376
|
-
|
|
555
|
+
input_payload=input_payload,
|
|
377
556
|
nodes=nodes,
|
|
378
557
|
depth=depth,
|
|
379
558
|
parent_path=parent_path,
|
|
380
559
|
context=context,
|
|
560
|
+
file_ids=file_ids,
|
|
381
561
|
max_depth=max_depth,
|
|
382
562
|
confidence_threshold=confidence_threshold,
|
|
383
|
-
|
|
563
|
+
session=session,
|
|
384
564
|
state=state,
|
|
385
565
|
)
|
|
386
566
|
return state
|
|
@@ -390,8 +570,7 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
390
570
|
class _TraversalState:
|
|
391
571
|
"""Track recursive traversal state."""
|
|
392
572
|
|
|
393
|
-
|
|
394
|
-
path_nodes: list[TaxonomyNode] = field(default_factory=list)
|
|
573
|
+
steps: list[ClassificationStep] = field(default_factory=list)
|
|
395
574
|
final_nodes: list[TaxonomyNode] = field(default_factory=list)
|
|
396
575
|
best_confidence: float | None = None
|
|
397
576
|
saw_max_depth: bool = False
|
|
@@ -413,8 +592,7 @@ def _copy_traversal_state(state: _TraversalState) -> _TraversalState:
|
|
|
413
592
|
Cloned traversal state with copied collections.
|
|
414
593
|
"""
|
|
415
594
|
return _TraversalState(
|
|
416
|
-
|
|
417
|
-
path_nodes=list(state.path_nodes),
|
|
595
|
+
steps=list(state.steps),
|
|
418
596
|
final_nodes=list(state.final_nodes),
|
|
419
597
|
best_confidence=state.best_confidence,
|
|
420
598
|
saw_max_depth=state.saw_max_depth,
|
|
@@ -464,6 +642,8 @@ def _normalize_roots(
|
|
|
464
642
|
list[TaxonomyNode]
|
|
465
643
|
Normalized list of root nodes.
|
|
466
644
|
"""
|
|
645
|
+
if isinstance(taxonomy, Taxonomy):
|
|
646
|
+
return [node for node in taxonomy.children if node is not None]
|
|
467
647
|
if isinstance(taxonomy, TaxonomyNode):
|
|
468
648
|
return [taxonomy]
|
|
469
649
|
return [node for node in taxonomy if node is not None]
|
|
@@ -483,7 +663,7 @@ def _default_template_path() -> Path:
|
|
|
483
663
|
def _build_context(
|
|
484
664
|
*,
|
|
485
665
|
node_descriptors: Iterable[dict[str, Any]],
|
|
486
|
-
|
|
666
|
+
steps: Sequence[ClassificationStep],
|
|
487
667
|
depth: int,
|
|
488
668
|
context: Optional[Dict[str, Any]],
|
|
489
669
|
) -> Dict[str, Any]:
|
|
@@ -493,7 +673,7 @@ def _build_context(
|
|
|
493
673
|
----------
|
|
494
674
|
node_descriptors : Iterable[dict[str, Any]]
|
|
495
675
|
Node descriptors available at the current taxonomy level.
|
|
496
|
-
|
|
676
|
+
steps : Sequence[ClassificationStep]
|
|
497
677
|
Steps recorded so far in the traversal.
|
|
498
678
|
depth : int
|
|
499
679
|
Current traversal depth.
|
|
@@ -505,9 +685,14 @@ def _build_context(
|
|
|
505
685
|
dict[str, Any]
|
|
506
686
|
Context dictionary for prompt rendering.
|
|
507
687
|
"""
|
|
688
|
+
summarized_steps = [
|
|
689
|
+
step.as_summary()
|
|
690
|
+
for step in steps
|
|
691
|
+
if step.selected_nodes and any(node is not None for node in step.selected_nodes)
|
|
692
|
+
]
|
|
508
693
|
template_context: Dict[str, Any] = {
|
|
509
694
|
"taxonomy_nodes": list(node_descriptors),
|
|
510
|
-
"
|
|
695
|
+
"steps": summarized_steps,
|
|
511
696
|
"depth": depth,
|
|
512
697
|
}
|
|
513
698
|
if context:
|
|
@@ -555,7 +740,7 @@ def _build_node_path_map(
|
|
|
555
740
|
path_map: dict[str, TaxonomyNode] = {}
|
|
556
741
|
seen: dict[str, int] = {}
|
|
557
742
|
for node in nodes:
|
|
558
|
-
base_path =
|
|
743
|
+
base_path = format_path_identifier([*parent_path, node.label])
|
|
559
744
|
count = seen.get(base_path, 0) + 1
|
|
560
745
|
seen[base_path] = count
|
|
561
746
|
path = f"{base_path} ({count})" if count > 1 else base_path
|
|
@@ -584,33 +769,12 @@ def _build_node_descriptors(
|
|
|
584
769
|
{
|
|
585
770
|
"identifier": path_id,
|
|
586
771
|
"label": node.label,
|
|
587
|
-
"
|
|
772
|
+
"computed_description": node.computed_description,
|
|
588
773
|
}
|
|
589
774
|
)
|
|
590
775
|
return descriptors
|
|
591
776
|
|
|
592
777
|
|
|
593
|
-
def _format_path_identifier(path_segments: Sequence[str]) -> str:
|
|
594
|
-
"""Format path segments into a safe identifier string.
|
|
595
|
-
|
|
596
|
-
Parameters
|
|
597
|
-
----------
|
|
598
|
-
path_segments : Sequence[str]
|
|
599
|
-
Path segments to format.
|
|
600
|
-
|
|
601
|
-
Returns
|
|
602
|
-
-------
|
|
603
|
-
str
|
|
604
|
-
Escaped path identifier string.
|
|
605
|
-
"""
|
|
606
|
-
delimiter = " > "
|
|
607
|
-
escape_token = "\\>"
|
|
608
|
-
escaped_segments = [
|
|
609
|
-
segment.replace(delimiter, escape_token) for segment in path_segments
|
|
610
|
-
]
|
|
611
|
-
return delimiter.join(escaped_segments)
|
|
612
|
-
|
|
613
|
-
|
|
614
778
|
def _build_taxonomy_enum(name: str, values: Sequence[str]) -> type[Enum]:
|
|
615
779
|
"""Build a safe Enum from taxonomy node values.
|
|
616
780
|
|
|
@@ -635,25 +799,6 @@ def _build_taxonomy_enum(name: str, values: Sequence[str]) -> type[Enum]:
|
|
|
635
799
|
return cast(type[Enum], Enum(name, members))
|
|
636
800
|
|
|
637
801
|
|
|
638
|
-
def _split_taxonomy_path(value: str) -> list[str]:
|
|
639
|
-
"""Split a taxonomy identifier into its path segments.
|
|
640
|
-
|
|
641
|
-
Parameters
|
|
642
|
-
----------
|
|
643
|
-
value : str
|
|
644
|
-
Taxonomy path identifier to split.
|
|
645
|
-
|
|
646
|
-
Returns
|
|
647
|
-
-------
|
|
648
|
-
list[str]
|
|
649
|
-
Path segments with escaped delimiters restored.
|
|
650
|
-
"""
|
|
651
|
-
delimiter = " > "
|
|
652
|
-
escape_token = "\\>"
|
|
653
|
-
segments = value.split(delimiter)
|
|
654
|
-
return [segment.replace(escape_token, delimiter) for segment in segments]
|
|
655
|
-
|
|
656
|
-
|
|
657
802
|
def _sanitize_enum_member(
|
|
658
803
|
value: str,
|
|
659
804
|
index: int,
|
|
@@ -676,7 +821,7 @@ def _sanitize_enum_member(
|
|
|
676
821
|
Sanitized enum member name.
|
|
677
822
|
"""
|
|
678
823
|
normalized_segments: list[str] = []
|
|
679
|
-
for segment in
|
|
824
|
+
for segment in split_path_identifier(value):
|
|
680
825
|
normalized = re.sub(r"[^0-9a-zA-Z]+", "_", segment).strip("_").upper()
|
|
681
826
|
if not normalized:
|
|
682
827
|
normalized = "VALUE"
|
|
@@ -716,6 +861,40 @@ def _normalize_step_output(
|
|
|
716
861
|
return ClassificationStep.from_json(payload)
|
|
717
862
|
|
|
718
863
|
|
|
864
|
+
def _build_input_payload(
|
|
865
|
+
text: str,
|
|
866
|
+
file_ids: str | Sequence[str] | None,
|
|
867
|
+
) -> str | list[dict[str, Any]]:
|
|
868
|
+
"""Build input payloads with optional file attachments.
|
|
869
|
+
|
|
870
|
+
Parameters
|
|
871
|
+
----------
|
|
872
|
+
text : str
|
|
873
|
+
Prompt text to send to the agent.
|
|
874
|
+
file_ids : str or Sequence[str] or None
|
|
875
|
+
Optional file IDs to include as ``input_file`` attachments.
|
|
876
|
+
|
|
877
|
+
Returns
|
|
878
|
+
-------
|
|
879
|
+
str or list[dict[str, Any]]
|
|
880
|
+
Input payload suitable for the Agents SDK.
|
|
881
|
+
"""
|
|
882
|
+
normalized_file_ids = [
|
|
883
|
+
file_id for file_id in dict.fromkeys(ensure_list(file_ids)) if file_id
|
|
884
|
+
]
|
|
885
|
+
if not normalized_file_ids:
|
|
886
|
+
return text
|
|
887
|
+
attachments = [
|
|
888
|
+
{"type": "input_file", "file_id": file_id} for file_id in normalized_file_ids
|
|
889
|
+
]
|
|
890
|
+
return [
|
|
891
|
+
{
|
|
892
|
+
"role": "user",
|
|
893
|
+
"content": [{"type": "input_text", "text": text}, *attachments],
|
|
894
|
+
}
|
|
895
|
+
]
|
|
896
|
+
|
|
897
|
+
|
|
719
898
|
def _extract_enum_fields(
|
|
720
899
|
step_structure: type[StructureBase],
|
|
721
900
|
) -> dict[str, type[Enum]]:
|
|
@@ -807,17 +986,20 @@ def _selected_nodes(step: ClassificationStep) -> list[str]:
|
|
|
807
986
|
list[str]
|
|
808
987
|
Selected identifiers in priority order.
|
|
809
988
|
"""
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
989
|
+
enum_cls: type[Enum] | None = None
|
|
990
|
+
step_cls = step.__class__
|
|
991
|
+
if hasattr(step_cls, "model_fields"):
|
|
992
|
+
field = step_cls.model_fields.get("selected_nodes")
|
|
993
|
+
if field is not None:
|
|
994
|
+
enum_cls = step_cls._extract_enum_class(field.annotation)
|
|
995
|
+
if enum_cls is None:
|
|
996
|
+
enum_cls = Enum
|
|
997
|
+
selected_nodes = [
|
|
998
|
+
str(_normalize_enum_value(selected_node, enum_cls))
|
|
999
|
+
for selected_node in step.selected_nodes or []
|
|
1000
|
+
if selected_node
|
|
1001
|
+
]
|
|
1002
|
+
return selected_nodes
|
|
821
1003
|
|
|
822
1004
|
|
|
823
1005
|
def _max_confidence(
|