openai-sdk-helpers 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openai_sdk_helpers/agent/__init__.py +2 -0
- openai_sdk_helpers/agent/base.py +88 -12
- openai_sdk_helpers/agent/classifier.py +905 -94
- openai_sdk_helpers/agent/configuration.py +42 -0
- openai_sdk_helpers/agent/files.py +120 -0
- openai_sdk_helpers/agent/runner.py +9 -9
- openai_sdk_helpers/agent/translator.py +2 -2
- openai_sdk_helpers/files_api.py +46 -1
- openai_sdk_helpers/prompt/classifier.jinja +28 -7
- openai_sdk_helpers/settings.py +65 -0
- openai_sdk_helpers/structure/__init__.py +4 -0
- openai_sdk_helpers/structure/base.py +79 -55
- openai_sdk_helpers/structure/classification.py +265 -43
- openai_sdk_helpers/structure/plan/enum.py +4 -0
- {openai_sdk_helpers-0.6.0.dist-info → openai_sdk_helpers-0.6.2.dist-info}/METADATA +12 -1
- {openai_sdk_helpers-0.6.0.dist-info → openai_sdk_helpers-0.6.2.dist-info}/RECORD +19 -18
- {openai_sdk_helpers-0.6.0.dist-info → openai_sdk_helpers-0.6.2.dist-info}/WHEEL +0 -0
- {openai_sdk_helpers-0.6.0.dist-info → openai_sdk_helpers-0.6.2.dist-info}/entry_points.txt +0 -0
- {openai_sdk_helpers-0.6.0.dist-info → openai_sdk_helpers-0.6.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,22 +1,29 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""Recursive agent for taxonomy-driven text classification."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import asyncio
|
|
6
|
+
import threading
|
|
7
|
+
import re
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from enum import Enum
|
|
5
10
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Dict, Iterable, Optional, Sequence
|
|
11
|
+
from typing import Any, Awaitable, Dict, Iterable, Optional, Sequence, cast
|
|
7
12
|
|
|
8
13
|
from ..structure import (
|
|
9
14
|
ClassificationResult,
|
|
10
15
|
ClassificationStep,
|
|
11
16
|
ClassificationStopReason,
|
|
17
|
+
StructureBase,
|
|
12
18
|
TaxonomyNode,
|
|
13
19
|
)
|
|
20
|
+
from ..utils import ensure_list
|
|
14
21
|
from .base import AgentBase
|
|
15
22
|
from .configuration import AgentConfiguration
|
|
16
23
|
|
|
17
24
|
|
|
18
25
|
class TaxonomyClassifierAgent(AgentBase):
|
|
19
|
-
"""Classify text by traversing a taxonomy
|
|
26
|
+
"""Classify text by recursively traversing a taxonomy.
|
|
20
27
|
|
|
21
28
|
Parameters
|
|
22
29
|
----------
|
|
@@ -27,8 +34,22 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
27
34
|
|
|
28
35
|
Methods
|
|
29
36
|
-------
|
|
30
|
-
run_agent(text, taxonomy, context, max_depth)
|
|
31
|
-
Classify text by walking the taxonomy tree.
|
|
37
|
+
run_agent(text, taxonomy, context, max_depth, session)
|
|
38
|
+
Classify text by recursively walking the taxonomy tree.
|
|
39
|
+
run_async(input, context, max_depth, confidence_threshold, single_class)
|
|
40
|
+
Classify text asynchronously using taxonomy traversal.
|
|
41
|
+
run_sync(input, context, max_depth, confidence_threshold, single_class)
|
|
42
|
+
Classify text synchronously using taxonomy traversal.
|
|
43
|
+
|
|
44
|
+
Examples
|
|
45
|
+
--------
|
|
46
|
+
Create a classifier with a flat taxonomy:
|
|
47
|
+
|
|
48
|
+
>>> taxonomy = [
|
|
49
|
+
... TaxonomyNode(label="Billing"),
|
|
50
|
+
... TaxonomyNode(label="Support"),
|
|
51
|
+
... ]
|
|
52
|
+
>>> agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
|
|
32
53
|
"""
|
|
33
54
|
|
|
34
55
|
def __init__(
|
|
@@ -36,6 +57,7 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
36
57
|
*,
|
|
37
58
|
template_path: Path | str | None = None,
|
|
38
59
|
model: str | None = None,
|
|
60
|
+
taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
|
|
39
61
|
) -> None:
|
|
40
62
|
"""Initialize the taxonomy classifier agent configuration.
|
|
41
63
|
|
|
@@ -45,21 +67,27 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
45
67
|
Optional template file path for prompt rendering.
|
|
46
68
|
model : str | None, default=None
|
|
47
69
|
Model identifier to use for classification.
|
|
70
|
+
taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
|
|
71
|
+
Root taxonomy node or list of root nodes.
|
|
48
72
|
|
|
49
73
|
Raises
|
|
50
74
|
------
|
|
51
75
|
ValueError
|
|
52
|
-
If the
|
|
76
|
+
If the taxonomy is empty.
|
|
53
77
|
|
|
54
78
|
Examples
|
|
55
79
|
--------
|
|
56
|
-
>>> classifier = TaxonomyClassifierAgent(model="gpt-4o-mini")
|
|
80
|
+
>>> classifier = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=[])
|
|
57
81
|
"""
|
|
82
|
+
self._taxonomy = taxonomy
|
|
83
|
+
self._root_nodes = _normalize_roots(taxonomy)
|
|
84
|
+
if not self._root_nodes:
|
|
85
|
+
raise ValueError("taxonomy must include at least one node")
|
|
58
86
|
resolved_template_path = template_path or _default_template_path()
|
|
59
87
|
configuration = AgentConfiguration(
|
|
60
88
|
name="taxonomy_classifier",
|
|
61
89
|
instructions="Agent instructions",
|
|
62
|
-
description="Classify text by traversing taxonomy levels.",
|
|
90
|
+
description="Classify text by traversing taxonomy levels recursively.",
|
|
63
91
|
template_path=resolved_template_path,
|
|
64
92
|
output_structure=ClassificationStep,
|
|
65
93
|
model=model,
|
|
@@ -69,84 +97,555 @@ class TaxonomyClassifierAgent(AgentBase):
|
|
|
69
97
|
async def run_agent(
|
|
70
98
|
self,
|
|
71
99
|
text: str,
|
|
72
|
-
taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
|
|
73
100
|
*,
|
|
74
101
|
context: Optional[Dict[str, Any]] = None,
|
|
102
|
+
file_ids: str | Sequence[str] | None = None,
|
|
75
103
|
max_depth: Optional[int] = None,
|
|
104
|
+
confidence_threshold: float | None = None,
|
|
105
|
+
single_class: bool = False,
|
|
106
|
+
session: Optional[Any] = None,
|
|
76
107
|
) -> ClassificationResult:
|
|
77
|
-
"""Classify ``text`` by
|
|
108
|
+
"""Classify ``text`` by recursively walking taxonomy levels.
|
|
78
109
|
|
|
79
110
|
Parameters
|
|
80
111
|
----------
|
|
81
112
|
text : str
|
|
82
113
|
Source text to classify.
|
|
83
|
-
taxonomy : TaxonomyNode or Sequence[TaxonomyNode]
|
|
84
|
-
Root taxonomy node or list of root nodes to traverse.
|
|
85
114
|
context : dict or None, default=None
|
|
86
115
|
Additional context values to merge into the prompt.
|
|
116
|
+
file_ids : str or Sequence[str] or None, default=None
|
|
117
|
+
Optional file IDs to attach to each classification step.
|
|
87
118
|
max_depth : int or None, default=None
|
|
88
119
|
Maximum depth to traverse before stopping.
|
|
120
|
+
confidence_threshold : float or None, default=None
|
|
121
|
+
Minimum confidence required to accept a classification step.
|
|
122
|
+
single_class : bool, default=False
|
|
123
|
+
Whether to keep only the highest-priority selection per step.
|
|
124
|
+
session : Session or None, default=None
|
|
125
|
+
Optional session for maintaining conversation history across runs.
|
|
89
126
|
|
|
90
127
|
Returns
|
|
91
128
|
-------
|
|
92
129
|
ClassificationResult
|
|
93
130
|
Structured classification result describing the traversal.
|
|
94
131
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
132
|
+
Examples
|
|
133
|
+
--------
|
|
134
|
+
>>> taxonomy = TaxonomyNode(label="Finance")
|
|
135
|
+
>>> agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
|
|
136
|
+
>>> isinstance(agent.root_nodes, list)
|
|
137
|
+
True
|
|
99
138
|
"""
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
139
|
+
state = _TraversalState()
|
|
140
|
+
input_payload = _build_input_payload(text, file_ids)
|
|
141
|
+
await self._classify_nodes(
|
|
142
|
+
input_payload=input_payload,
|
|
143
|
+
nodes=list(self._root_nodes),
|
|
144
|
+
depth=0,
|
|
145
|
+
parent_path=[],
|
|
146
|
+
context=context,
|
|
147
|
+
file_ids=file_ids,
|
|
148
|
+
max_depth=max_depth,
|
|
149
|
+
confidence_threshold=confidence_threshold,
|
|
150
|
+
single_class=single_class,
|
|
151
|
+
session=session,
|
|
152
|
+
state=state,
|
|
153
|
+
)
|
|
103
154
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
stop_reason =
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
155
|
+
final_nodes_value = state.final_nodes or None
|
|
156
|
+
final_node = state.final_nodes[0] if state.final_nodes else None
|
|
157
|
+
stop_reason = _resolve_stop_reason(state)
|
|
158
|
+
return ClassificationResult(
|
|
159
|
+
final_node=final_node,
|
|
160
|
+
final_nodes=final_nodes_value,
|
|
161
|
+
confidence=state.best_confidence,
|
|
162
|
+
stop_reason=stop_reason,
|
|
163
|
+
path=state.path,
|
|
164
|
+
path_nodes=state.path_nodes,
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
async def run_async(
|
|
168
|
+
self,
|
|
169
|
+
input: str | list[dict[str, Any]],
|
|
170
|
+
*,
|
|
171
|
+
context: Optional[Dict[str, Any]] = None,
|
|
172
|
+
output_structure: Optional[type[StructureBase]] = None,
|
|
173
|
+
session: Optional[Any] = None,
|
|
174
|
+
file_ids: str | Sequence[str] | None = None,
|
|
175
|
+
max_depth: Optional[int] = None,
|
|
176
|
+
confidence_threshold: float | None = None,
|
|
177
|
+
single_class: bool = False,
|
|
178
|
+
) -> ClassificationResult:
|
|
179
|
+
"""Classify ``input`` asynchronously with taxonomy traversal.
|
|
180
|
+
|
|
181
|
+
Parameters
|
|
182
|
+
----------
|
|
183
|
+
input : str or list[dict[str, Any]]
|
|
184
|
+
Source text to classify.
|
|
185
|
+
context : dict or None, default=None
|
|
186
|
+
Additional context values to merge into the prompt.
|
|
187
|
+
output_structure : type[StructureBase] or None, default=None
|
|
188
|
+
Unused in taxonomy traversal. Present for API compatibility.
|
|
189
|
+
session : Session or None, default=None
|
|
190
|
+
Optional session for maintaining conversation history across runs.
|
|
191
|
+
file_ids : str or Sequence[str] or None, default=None
|
|
192
|
+
Optional file IDs to attach to each classification step.
|
|
193
|
+
max_depth : int or None, default=None
|
|
194
|
+
Maximum depth to traverse before stopping.
|
|
195
|
+
confidence_threshold : float or None, default=None
|
|
196
|
+
Minimum confidence required to accept a classification step.
|
|
197
|
+
single_class : bool, default=False
|
|
198
|
+
Whether to keep only the highest-priority selection per step.
|
|
199
|
+
|
|
200
|
+
Returns
|
|
201
|
+
-------
|
|
202
|
+
ClassificationResult
|
|
203
|
+
Structured classification result describing the traversal.
|
|
204
|
+
"""
|
|
205
|
+
_ = output_structure
|
|
206
|
+
if not isinstance(input, str):
|
|
207
|
+
msg = "TaxonomyClassifierAgent run_async requires text input."
|
|
208
|
+
raise TypeError(msg)
|
|
209
|
+
kwargs: Dict[str, Any] = {
|
|
210
|
+
"context": context,
|
|
211
|
+
"file_ids": file_ids,
|
|
212
|
+
"max_depth": max_depth,
|
|
213
|
+
"confidence_threshold": confidence_threshold,
|
|
214
|
+
"single_class": single_class,
|
|
215
|
+
}
|
|
216
|
+
if session is not None:
|
|
217
|
+
kwargs["session"] = session
|
|
218
|
+
return await self.run_agent(input, **kwargs)
|
|
219
|
+
|
|
220
|
+
def run_sync(
|
|
221
|
+
self,
|
|
222
|
+
input: str | list[dict[str, Any]],
|
|
223
|
+
*,
|
|
224
|
+
context: Optional[Dict[str, Any]] = None,
|
|
225
|
+
output_structure: Optional[type[StructureBase]] = None,
|
|
226
|
+
session: Optional[Any] = None,
|
|
227
|
+
file_ids: str | Sequence[str] | None = None,
|
|
228
|
+
max_depth: Optional[int] = None,
|
|
229
|
+
confidence_threshold: float | None = None,
|
|
230
|
+
single_class: bool = False,
|
|
231
|
+
) -> ClassificationResult:
|
|
232
|
+
"""Classify ``input`` synchronously with taxonomy traversal.
|
|
233
|
+
|
|
234
|
+
Parameters
|
|
235
|
+
----------
|
|
236
|
+
input : str or list[dict[str, Any]]
|
|
237
|
+
Source text to classify.
|
|
238
|
+
context : dict or None, default=None
|
|
239
|
+
Additional context values to merge into the prompt.
|
|
240
|
+
output_structure : type[StructureBase] or None, default=None
|
|
241
|
+
Unused in taxonomy traversal. Present for API compatibility.
|
|
242
|
+
session : Session or None, default=None
|
|
243
|
+
Optional session for maintaining conversation history across runs.
|
|
244
|
+
file_ids : str or Sequence[str] or None, default=None
|
|
245
|
+
Optional file IDs to attach to each classification step.
|
|
246
|
+
max_depth : int or None, default=None
|
|
247
|
+
Maximum depth to traverse before stopping.
|
|
248
|
+
confidence_threshold : float or None, default=None
|
|
249
|
+
Minimum confidence required to accept a classification step.
|
|
250
|
+
single_class : bool, default=False
|
|
251
|
+
Whether to keep only the highest-priority selection per step.
|
|
252
|
+
|
|
253
|
+
Returns
|
|
254
|
+
-------
|
|
255
|
+
ClassificationResult
|
|
256
|
+
Structured classification result describing the traversal.
|
|
257
|
+
"""
|
|
258
|
+
_ = output_structure
|
|
259
|
+
if not isinstance(input, str):
|
|
260
|
+
msg = "TaxonomyClassifierAgent run_sync requires text input."
|
|
261
|
+
raise TypeError(msg)
|
|
262
|
+
kwargs: Dict[str, Any] = {
|
|
263
|
+
"context": context,
|
|
264
|
+
"file_ids": file_ids,
|
|
265
|
+
"max_depth": max_depth,
|
|
266
|
+
"confidence_threshold": confidence_threshold,
|
|
267
|
+
"single_class": single_class,
|
|
268
|
+
}
|
|
269
|
+
if session is not None:
|
|
270
|
+
kwargs["session"] = session
|
|
271
|
+
|
|
272
|
+
async def runner() -> ClassificationResult:
|
|
273
|
+
return await self.run_agent(input, **kwargs)
|
|
274
|
+
|
|
275
|
+
try:
|
|
276
|
+
asyncio.get_running_loop()
|
|
277
|
+
except RuntimeError:
|
|
278
|
+
return asyncio.run(runner())
|
|
279
|
+
|
|
280
|
+
result: ClassificationResult | None = None
|
|
281
|
+
error: Exception | None = None
|
|
282
|
+
|
|
283
|
+
def _thread_func() -> None:
|
|
284
|
+
nonlocal error, result
|
|
285
|
+
try:
|
|
286
|
+
result = asyncio.run(runner())
|
|
287
|
+
except Exception as exc:
|
|
288
|
+
error = exc
|
|
289
|
+
|
|
290
|
+
thread = threading.Thread(target=_thread_func)
|
|
291
|
+
thread.start()
|
|
292
|
+
thread.join()
|
|
293
|
+
|
|
294
|
+
if error is not None:
|
|
295
|
+
raise error
|
|
296
|
+
if result is None:
|
|
297
|
+
msg = "Classification did not return a result"
|
|
298
|
+
raise RuntimeError(msg)
|
|
299
|
+
return result
|
|
300
|
+
|
|
301
|
+
async def _run_step_async(
|
|
302
|
+
self,
|
|
303
|
+
*,
|
|
304
|
+
input: str | list[dict[str, Any]],
|
|
305
|
+
context: Optional[Dict[str, Any]] = None,
|
|
306
|
+
output_structure: Optional[type[StructureBase]] = None,
|
|
307
|
+
session: Optional[Any] = None,
|
|
308
|
+
) -> StructureBase:
|
|
309
|
+
"""Execute a single classification step asynchronously.
|
|
310
|
+
|
|
311
|
+
Parameters
|
|
312
|
+
----------
|
|
313
|
+
input : str or list[dict[str, Any]]
|
|
314
|
+
Prompt or structured input for the agent.
|
|
315
|
+
context : dict or None, default=None
|
|
316
|
+
Optional dictionary passed to the agent.
|
|
317
|
+
output_structure : type[StructureBase] or None, default=None
|
|
318
|
+
Optional type used to cast the final output.
|
|
319
|
+
session : Session or None, default=None
|
|
320
|
+
Optional session for maintaining conversation history across runs.
|
|
321
|
+
|
|
322
|
+
Returns
|
|
323
|
+
-------
|
|
324
|
+
StructureBase
|
|
325
|
+
Parsed result for the classification step.
|
|
326
|
+
"""
|
|
327
|
+
return await super().run_async(
|
|
328
|
+
input=input,
|
|
329
|
+
context=context,
|
|
330
|
+
output_structure=output_structure,
|
|
331
|
+
session=session,
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
async def _classify_nodes(
|
|
335
|
+
self,
|
|
336
|
+
*,
|
|
337
|
+
input_payload: str | list[dict[str, Any]],
|
|
338
|
+
nodes: list[TaxonomyNode],
|
|
339
|
+
depth: int,
|
|
340
|
+
parent_path: list[str],
|
|
341
|
+
context: Optional[Dict[str, Any]],
|
|
342
|
+
file_ids: str | Sequence[str] | None,
|
|
343
|
+
max_depth: Optional[int],
|
|
344
|
+
confidence_threshold: float | None,
|
|
345
|
+
single_class: bool,
|
|
346
|
+
session: Optional[Any],
|
|
347
|
+
state: "_TraversalState",
|
|
348
|
+
) -> None:
|
|
349
|
+
"""Classify a taxonomy level and recursively traverse children.
|
|
350
|
+
|
|
351
|
+
Parameters
|
|
352
|
+
----------
|
|
353
|
+
input_payload : str or list[dict[str, Any]]
|
|
354
|
+
Input payload used to prompt the agent.
|
|
355
|
+
nodes : list[TaxonomyNode]
|
|
356
|
+
Candidate taxonomy nodes for the current level.
|
|
357
|
+
depth : int
|
|
358
|
+
Current traversal depth.
|
|
359
|
+
context : dict or None
|
|
360
|
+
Additional context values to merge into the prompt.
|
|
361
|
+
file_ids : str or Sequence[str] or None
|
|
362
|
+
Optional file IDs attached to each classification step.
|
|
363
|
+
max_depth : int or None
|
|
364
|
+
Maximum traversal depth before stopping.
|
|
365
|
+
confidence_threshold : float or None
|
|
366
|
+
Minimum confidence required to accept a classification step.
|
|
367
|
+
single_class : bool
|
|
368
|
+
Whether to keep only the highest-priority selection per step.
|
|
369
|
+
session : Session or None
|
|
370
|
+
Optional session for maintaining conversation history across runs.
|
|
371
|
+
state : _TraversalState
|
|
372
|
+
Aggregated traversal state.
|
|
373
|
+
"""
|
|
374
|
+
if max_depth is not None and depth >= max_depth:
|
|
375
|
+
state.saw_max_depth = True
|
|
376
|
+
return
|
|
377
|
+
if not nodes:
|
|
378
|
+
return
|
|
379
|
+
|
|
380
|
+
node_paths = _build_node_path_map(nodes, parent_path)
|
|
381
|
+
template_context = _build_context(
|
|
382
|
+
node_descriptors=_build_node_descriptors(node_paths),
|
|
383
|
+
path=state.path,
|
|
384
|
+
depth=depth,
|
|
385
|
+
context=context,
|
|
386
|
+
)
|
|
387
|
+
step_structure = _build_step_structure(list(node_paths.keys()))
|
|
388
|
+
raw_step = await self._run_step_async(
|
|
389
|
+
input=input_payload,
|
|
390
|
+
context=template_context,
|
|
391
|
+
output_structure=step_structure,
|
|
392
|
+
session=session,
|
|
393
|
+
)
|
|
394
|
+
step = _normalize_step_output(raw_step, step_structure)
|
|
395
|
+
state.path.append(step)
|
|
396
|
+
|
|
397
|
+
if (
|
|
398
|
+
confidence_threshold is not None
|
|
399
|
+
and step.confidence is not None
|
|
400
|
+
and step.confidence < confidence_threshold
|
|
401
|
+
):
|
|
402
|
+
return
|
|
403
|
+
|
|
404
|
+
resolved_nodes = _resolve_nodes(node_paths, step)
|
|
405
|
+
if resolved_nodes:
|
|
406
|
+
if single_class:
|
|
407
|
+
resolved_nodes = resolved_nodes[:1]
|
|
408
|
+
state.path_nodes.extend(resolved_nodes)
|
|
409
|
+
|
|
410
|
+
if step.stop_reason.is_terminal:
|
|
411
|
+
if resolved_nodes:
|
|
412
|
+
state.final_nodes.extend(resolved_nodes)
|
|
413
|
+
state.best_confidence = _max_confidence(
|
|
414
|
+
state.best_confidence, step.confidence
|
|
415
|
+
)
|
|
416
|
+
state.saw_terminal_stop = True
|
|
417
|
+
return
|
|
418
|
+
|
|
419
|
+
if not resolved_nodes:
|
|
420
|
+
return
|
|
421
|
+
|
|
422
|
+
base_path_len = len(state.path)
|
|
423
|
+
base_path_nodes_len = len(state.path_nodes)
|
|
424
|
+
child_tasks: list[tuple[Awaitable["_TraversalState"], int]] = []
|
|
425
|
+
for node in resolved_nodes:
|
|
426
|
+
if node.children:
|
|
427
|
+
sub_agent = self._build_sub_agent(list(node.children))
|
|
428
|
+
sub_state = _copy_traversal_state(state)
|
|
429
|
+
base_final_nodes_len = len(state.final_nodes)
|
|
430
|
+
child_tasks.append(
|
|
431
|
+
(
|
|
432
|
+
self._classify_subtree(
|
|
433
|
+
sub_agent=sub_agent,
|
|
434
|
+
input_payload=input_payload,
|
|
435
|
+
nodes=list(node.children),
|
|
436
|
+
depth=depth + 1,
|
|
437
|
+
parent_path=[*parent_path, node.label],
|
|
438
|
+
context=context,
|
|
439
|
+
file_ids=file_ids,
|
|
440
|
+
max_depth=max_depth,
|
|
441
|
+
confidence_threshold=confidence_threshold,
|
|
442
|
+
single_class=single_class,
|
|
443
|
+
session=session,
|
|
444
|
+
state=sub_state,
|
|
445
|
+
),
|
|
446
|
+
base_final_nodes_len,
|
|
447
|
+
)
|
|
448
|
+
)
|
|
449
|
+
else:
|
|
450
|
+
state.saw_no_children = True
|
|
451
|
+
state.final_nodes.append(node)
|
|
452
|
+
state.best_confidence = _max_confidence(
|
|
453
|
+
state.best_confidence, step.confidence
|
|
454
|
+
)
|
|
455
|
+
if child_tasks:
|
|
456
|
+
child_states = await asyncio.gather(
|
|
457
|
+
*(child_task for child_task, _ in child_tasks)
|
|
124
458
|
)
|
|
125
|
-
|
|
126
|
-
|
|
459
|
+
for child_state, (_, base_final_nodes_len) in zip(
|
|
460
|
+
child_states, child_tasks, strict=True
|
|
461
|
+
):
|
|
462
|
+
state.path.extend(child_state.path[base_path_len:])
|
|
463
|
+
state.path_nodes.extend(child_state.path_nodes[base_path_nodes_len:])
|
|
464
|
+
state.final_nodes.extend(child_state.final_nodes[base_final_nodes_len:])
|
|
465
|
+
state.best_confidence = _max_confidence(
|
|
466
|
+
state.best_confidence, child_state.best_confidence
|
|
467
|
+
)
|
|
468
|
+
state.saw_max_depth = state.saw_max_depth or child_state.saw_max_depth
|
|
469
|
+
state.saw_no_children = (
|
|
470
|
+
state.saw_no_children or child_state.saw_no_children
|
|
471
|
+
)
|
|
472
|
+
state.saw_terminal_stop = (
|
|
473
|
+
state.saw_terminal_stop or child_state.saw_terminal_stop
|
|
474
|
+
)
|
|
127
475
|
|
|
128
|
-
|
|
129
|
-
|
|
476
|
+
@property
|
|
477
|
+
def taxonomy(self) -> TaxonomyNode | Sequence[TaxonomyNode]:
|
|
478
|
+
"""Return the root taxonomy node(s).
|
|
130
479
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
break
|
|
480
|
+
Returns
|
|
481
|
+
-------
|
|
482
|
+
TaxonomyNode or Sequence[TaxonomyNode]
|
|
483
|
+
Root taxonomy node or list of root nodes.
|
|
484
|
+
"""
|
|
485
|
+
return self._taxonomy
|
|
138
486
|
|
|
139
|
-
|
|
140
|
-
|
|
487
|
+
@property
|
|
488
|
+
def root_nodes(self) -> list[TaxonomyNode]:
|
|
489
|
+
"""Return the list of root taxonomy nodes.
|
|
141
490
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
491
|
+
Returns
|
|
492
|
+
-------
|
|
493
|
+
list[TaxonomyNode]
|
|
494
|
+
List of root taxonomy nodes.
|
|
495
|
+
"""
|
|
496
|
+
return self._root_nodes
|
|
497
|
+
|
|
498
|
+
def _build_sub_agent(
|
|
499
|
+
self,
|
|
500
|
+
nodes: Sequence[TaxonomyNode],
|
|
501
|
+
) -> "TaxonomyClassifierAgent":
|
|
502
|
+
"""Build a classifier agent for a taxonomy subtree.
|
|
503
|
+
|
|
504
|
+
Parameters
|
|
505
|
+
----------
|
|
506
|
+
nodes : Sequence[TaxonomyNode]
|
|
507
|
+
Taxonomy nodes to use as the sub-agent's root taxonomy.
|
|
508
|
+
|
|
509
|
+
Returns
|
|
510
|
+
-------
|
|
511
|
+
TaxonomyClassifierAgent
|
|
512
|
+
Configured classifier agent for the taxonomy slice.
|
|
513
|
+
"""
|
|
514
|
+
sub_agent = TaxonomyClassifierAgent(
|
|
515
|
+
template_path=self._template_path,
|
|
516
|
+
model=self._model,
|
|
517
|
+
taxonomy=list(nodes),
|
|
518
|
+
)
|
|
519
|
+
sub_agent._run_step_async = self._run_step_async
|
|
520
|
+
return sub_agent
|
|
521
|
+
|
|
522
|
+
async def _classify_subtree(
|
|
523
|
+
self,
|
|
524
|
+
*,
|
|
525
|
+
sub_agent: "TaxonomyClassifierAgent",
|
|
526
|
+
input_payload: str | list[dict[str, Any]],
|
|
527
|
+
nodes: list[TaxonomyNode],
|
|
528
|
+
depth: int,
|
|
529
|
+
parent_path: list[str],
|
|
530
|
+
context: Optional[Dict[str, Any]],
|
|
531
|
+
file_ids: str | Sequence[str] | None,
|
|
532
|
+
max_depth: Optional[int],
|
|
533
|
+
confidence_threshold: float | None,
|
|
534
|
+
single_class: bool,
|
|
535
|
+
session: Optional[Any],
|
|
536
|
+
state: "_TraversalState",
|
|
537
|
+
) -> "_TraversalState":
|
|
538
|
+
"""Classify a taxonomy subtree and return the traversal state.
|
|
539
|
+
|
|
540
|
+
Parameters
|
|
541
|
+
----------
|
|
542
|
+
sub_agent : TaxonomyClassifierAgent
|
|
543
|
+
Sub-agent configured for the subtree traversal.
|
|
544
|
+
input_payload : str or list[dict[str, Any]]
|
|
545
|
+
Input payload used to prompt the agent.
|
|
546
|
+
nodes : list[TaxonomyNode]
|
|
547
|
+
Candidate taxonomy nodes for the subtree.
|
|
548
|
+
depth : int
|
|
549
|
+
Current traversal depth.
|
|
550
|
+
parent_path : list[str]
|
|
551
|
+
Path segments leading to the current subtree.
|
|
552
|
+
context : dict or None
|
|
553
|
+
Additional context values to merge into the prompt.
|
|
554
|
+
file_ids : str or Sequence[str] or None
|
|
555
|
+
Optional file IDs attached to each classification step.
|
|
556
|
+
max_depth : int or None
|
|
557
|
+
Maximum traversal depth before stopping.
|
|
558
|
+
confidence_threshold : float or None
|
|
559
|
+
Minimum confidence required to accept a classification step.
|
|
560
|
+
single_class : bool
|
|
561
|
+
Whether to keep only the highest-priority selection per step.
|
|
562
|
+
session : Session or None
|
|
563
|
+
Optional session for maintaining conversation history across runs.
|
|
564
|
+
state : _TraversalState
|
|
565
|
+
Traversal state to populate for the subtree.
|
|
566
|
+
|
|
567
|
+
Returns
|
|
568
|
+
-------
|
|
569
|
+
_TraversalState
|
|
570
|
+
Populated traversal state for the subtree.
|
|
571
|
+
"""
|
|
572
|
+
await sub_agent._classify_nodes(
|
|
573
|
+
input_payload=input_payload,
|
|
574
|
+
nodes=nodes,
|
|
575
|
+
depth=depth,
|
|
576
|
+
parent_path=parent_path,
|
|
577
|
+
context=context,
|
|
578
|
+
file_ids=file_ids,
|
|
579
|
+
max_depth=max_depth,
|
|
580
|
+
confidence_threshold=confidence_threshold,
|
|
581
|
+
single_class=single_class,
|
|
582
|
+
session=session,
|
|
583
|
+
state=state,
|
|
149
584
|
)
|
|
585
|
+
return state
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
@dataclass
|
|
589
|
+
class _TraversalState:
|
|
590
|
+
"""Track recursive traversal state."""
|
|
591
|
+
|
|
592
|
+
path: list[ClassificationStep] = field(default_factory=list)
|
|
593
|
+
path_nodes: list[TaxonomyNode] = field(default_factory=list)
|
|
594
|
+
final_nodes: list[TaxonomyNode] = field(default_factory=list)
|
|
595
|
+
best_confidence: float | None = None
|
|
596
|
+
saw_max_depth: bool = False
|
|
597
|
+
saw_no_children: bool = False
|
|
598
|
+
saw_terminal_stop: bool = False
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
def _copy_traversal_state(state: _TraversalState) -> _TraversalState:
|
|
602
|
+
"""Copy traversal state for parallel subtree execution.
|
|
603
|
+
|
|
604
|
+
Parameters
|
|
605
|
+
----------
|
|
606
|
+
state : _TraversalState
|
|
607
|
+
Traversal state to clone.
|
|
608
|
+
|
|
609
|
+
Returns
|
|
610
|
+
-------
|
|
611
|
+
_TraversalState
|
|
612
|
+
Cloned traversal state with copied collections.
|
|
613
|
+
"""
|
|
614
|
+
return _TraversalState(
|
|
615
|
+
path=list(state.path),
|
|
616
|
+
path_nodes=list(state.path_nodes),
|
|
617
|
+
final_nodes=list(state.final_nodes),
|
|
618
|
+
best_confidence=state.best_confidence,
|
|
619
|
+
saw_max_depth=state.saw_max_depth,
|
|
620
|
+
saw_no_children=state.saw_no_children,
|
|
621
|
+
saw_terminal_stop=state.saw_terminal_stop,
|
|
622
|
+
)
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def _resolve_stop_reason(state: _TraversalState) -> ClassificationStopReason:
|
|
626
|
+
"""Resolve the final stop reason based on traversal state.
|
|
627
|
+
|
|
628
|
+
Parameters
|
|
629
|
+
----------
|
|
630
|
+
state : _TraversalState
|
|
631
|
+
Traversal state to inspect.
|
|
632
|
+
|
|
633
|
+
Returns
|
|
634
|
+
-------
|
|
635
|
+
ClassificationStopReason
|
|
636
|
+
Resolved stop reason.
|
|
637
|
+
"""
|
|
638
|
+
if state.saw_terminal_stop:
|
|
639
|
+
return ClassificationStopReason.STOP
|
|
640
|
+
if state.final_nodes and state.saw_no_children:
|
|
641
|
+
return ClassificationStopReason.NO_CHILDREN
|
|
642
|
+
if state.final_nodes:
|
|
643
|
+
return ClassificationStopReason.STOP
|
|
644
|
+
if state.saw_max_depth:
|
|
645
|
+
return ClassificationStopReason.MAX_DEPTH
|
|
646
|
+
if state.saw_no_children:
|
|
647
|
+
return ClassificationStopReason.NO_CHILDREN
|
|
648
|
+
return ClassificationStopReason.NO_MATCH
|
|
150
649
|
|
|
151
650
|
|
|
152
651
|
def _normalize_roots(
|
|
@@ -156,7 +655,7 @@ def _normalize_roots(
|
|
|
156
655
|
|
|
157
656
|
Parameters
|
|
158
657
|
----------
|
|
159
|
-
taxonomy : TaxonomyNode
|
|
658
|
+
taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
|
|
160
659
|
Root taxonomy node or list of root nodes.
|
|
161
660
|
|
|
162
661
|
Returns
|
|
@@ -182,7 +681,7 @@ def _default_template_path() -> Path:
|
|
|
182
681
|
|
|
183
682
|
def _build_context(
|
|
184
683
|
*,
|
|
185
|
-
|
|
684
|
+
node_descriptors: Iterable[dict[str, Any]],
|
|
186
685
|
path: Sequence[ClassificationStep],
|
|
187
686
|
depth: int,
|
|
188
687
|
context: Optional[Dict[str, Any]],
|
|
@@ -191,8 +690,8 @@ def _build_context(
|
|
|
191
690
|
|
|
192
691
|
Parameters
|
|
193
692
|
----------
|
|
194
|
-
|
|
195
|
-
|
|
693
|
+
node_descriptors : Iterable[dict[str, Any]]
|
|
694
|
+
Node descriptors available at the current taxonomy level.
|
|
196
695
|
path : Sequence[ClassificationStep]
|
|
197
696
|
Steps recorded so far in the traversal.
|
|
198
697
|
depth : int
|
|
@@ -206,7 +705,7 @@ def _build_context(
|
|
|
206
705
|
Context dictionary for prompt rendering.
|
|
207
706
|
"""
|
|
208
707
|
template_context: Dict[str, Any] = {
|
|
209
|
-
"taxonomy_nodes": list(
|
|
708
|
+
"taxonomy_nodes": list(node_descriptors),
|
|
210
709
|
"path": [step.as_summary() for step in path],
|
|
211
710
|
"depth": depth,
|
|
212
711
|
}
|
|
@@ -215,54 +714,366 @@ def _build_context(
|
|
|
215
714
|
return template_context
|
|
216
715
|
|
|
217
716
|
|
|
218
|
-
def
|
|
717
|
+
def _build_step_structure(
|
|
718
|
+
path_identifiers: Sequence[str],
|
|
719
|
+
) -> type[ClassificationStep]:
|
|
720
|
+
"""Build a step output structure constrained to taxonomy paths.
|
|
721
|
+
|
|
722
|
+
Parameters
|
|
723
|
+
----------
|
|
724
|
+
path_identifiers : Sequence[str]
|
|
725
|
+
Path identifiers for nodes at the current classification step.
|
|
726
|
+
|
|
727
|
+
Returns
|
|
728
|
+
-------
|
|
729
|
+
type[ClassificationStep]
|
|
730
|
+
Dynamic structure class for the classification step output.
|
|
731
|
+
"""
|
|
732
|
+
node_enum = _build_taxonomy_enum("TaxonomyPath", path_identifiers)
|
|
733
|
+
return ClassificationStep.build_for_enum(node_enum)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
def _build_node_path_map(
|
|
219
737
|
nodes: Sequence[TaxonomyNode],
|
|
220
|
-
|
|
221
|
-
) ->
|
|
222
|
-
"""
|
|
738
|
+
parent_path: Sequence[str],
|
|
739
|
+
) -> dict[str, TaxonomyNode]:
|
|
740
|
+
"""Build a mapping of node path identifiers to taxonomy nodes.
|
|
223
741
|
|
|
224
742
|
Parameters
|
|
225
743
|
----------
|
|
226
744
|
nodes : Sequence[TaxonomyNode]
|
|
227
|
-
Candidate nodes at the current level.
|
|
745
|
+
Candidate nodes at the current taxonomy level.
|
|
746
|
+
parent_path : Sequence[str]
|
|
747
|
+
Path segments leading to the current taxonomy level.
|
|
748
|
+
|
|
749
|
+
Returns
|
|
750
|
+
-------
|
|
751
|
+
dict[str, TaxonomyNode]
|
|
752
|
+
Mapping of path identifiers to taxonomy nodes.
|
|
753
|
+
"""
|
|
754
|
+
path_map: dict[str, TaxonomyNode] = {}
|
|
755
|
+
seen: dict[str, int] = {}
|
|
756
|
+
for node in nodes:
|
|
757
|
+
base_path = _format_path_identifier([*parent_path, node.label])
|
|
758
|
+
count = seen.get(base_path, 0) + 1
|
|
759
|
+
seen[base_path] = count
|
|
760
|
+
path = f"{base_path} ({count})" if count > 1 else base_path
|
|
761
|
+
path_map[path] = node
|
|
762
|
+
return path_map
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
def _build_node_descriptors(
|
|
766
|
+
node_paths: dict[str, TaxonomyNode],
|
|
767
|
+
) -> list[dict[str, Any]]:
|
|
768
|
+
"""Build node descriptors for prompt rendering.
|
|
769
|
+
|
|
770
|
+
Parameters
|
|
771
|
+
----------
|
|
772
|
+
node_paths : dict[str, TaxonomyNode]
|
|
773
|
+
Mapping of path identifiers to taxonomy nodes.
|
|
774
|
+
|
|
775
|
+
Returns
|
|
776
|
+
-------
|
|
777
|
+
list[dict[str, Any]]
|
|
778
|
+
Node descriptor dictionaries for prompt rendering.
|
|
779
|
+
"""
|
|
780
|
+
descriptors: list[dict[str, Any]] = []
|
|
781
|
+
for path_id, node in node_paths.items():
|
|
782
|
+
descriptors.append(
|
|
783
|
+
{
|
|
784
|
+
"identifier": path_id,
|
|
785
|
+
"label": node.label,
|
|
786
|
+
"description": node.description,
|
|
787
|
+
}
|
|
788
|
+
)
|
|
789
|
+
return descriptors
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
def _format_path_identifier(path_segments: Sequence[str]) -> str:
|
|
793
|
+
"""Format path segments into a safe identifier string.
|
|
794
|
+
|
|
795
|
+
Parameters
|
|
796
|
+
----------
|
|
797
|
+
path_segments : Sequence[str]
|
|
798
|
+
Path segments to format.
|
|
799
|
+
|
|
800
|
+
Returns
|
|
801
|
+
-------
|
|
802
|
+
str
|
|
803
|
+
Escaped path identifier string.
|
|
804
|
+
"""
|
|
805
|
+
delimiter = " > "
|
|
806
|
+
escape_token = "\\>"
|
|
807
|
+
escaped_segments = [
|
|
808
|
+
segment.replace(delimiter, escape_token) for segment in path_segments
|
|
809
|
+
]
|
|
810
|
+
return delimiter.join(escaped_segments)
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
def _build_taxonomy_enum(name: str, values: Sequence[str]) -> type[Enum]:
    """Construct an ``Enum`` type whose members wrap taxonomy values.

    Parameters
    ----------
    name : str
        Name to use for the enum class.
    values : Sequence[str]
        Taxonomy node values to include as enum members.

    Returns
    -------
    type[Enum]
        Enum class with sanitized, collision-free member names; an empty
        input yields a single ``UNSPECIFIED`` member with value ``""``.
    """
    member_map: dict[str, str] = {}
    for position, raw_value in enumerate(values, start=1):
        safe_name = _sanitize_enum_member(raw_value, position, member_map)
        member_map[safe_name] = raw_value
    if not member_map:
        # Enum() rejects an empty member mapping, so provide a placeholder.
        member_map["UNSPECIFIED"] = ""
    return cast(type[Enum], Enum(name, member_map))
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
def _split_taxonomy_path(value: str) -> list[str]:
|
|
838
|
+
"""Split a taxonomy identifier into its path segments.
|
|
839
|
+
|
|
840
|
+
Parameters
|
|
841
|
+
----------
|
|
842
|
+
value : str
|
|
843
|
+
Taxonomy path identifier to split.
|
|
844
|
+
|
|
845
|
+
Returns
|
|
846
|
+
-------
|
|
847
|
+
list[str]
|
|
848
|
+
Path segments with escaped delimiters restored.
|
|
849
|
+
"""
|
|
850
|
+
delimiter = " > "
|
|
851
|
+
escape_token = "\\>"
|
|
852
|
+
segments = value.split(delimiter)
|
|
853
|
+
return [segment.replace(escape_token, delimiter) for segment in segments]
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
def _sanitize_enum_member(
    value: str,
    index: int,
    existing: dict[str, str],
) -> str:
    """Derive a valid, unique enum member name from a taxonomy value.

    Parameters
    ----------
    value : str
        Raw taxonomy value to sanitize.
    index : int
        Index of the value in the source list (used for fallback names).
    existing : dict[str, str]
        Enum members already assigned, used to avoid name collisions.

    Returns
    -------
    str
        Sanitized member name; collisions gain a ``__N`` suffix.
    """
    parts: list[str] = []
    for raw_segment in _split_taxonomy_path(value):
        # Collapse any non-alphanumeric run to "_" and uppercase the rest.
        cleaned = re.sub(r"[^0-9a-zA-Z]+", "_", raw_segment).strip("_").upper()
        if not cleaned:
            cleaned = "VALUE"
        elif cleaned[0].isdigit():
            # Identifiers cannot start with a digit.
            cleaned = f"VALUE_{cleaned}"
        parts.append(cleaned)
    base = "__".join(parts) or f"VALUE_{index}"
    candidate = base
    attempt = 1
    while candidate in existing:
        candidate = f"{base}__{attempt}"
        attempt += 1
    return candidate
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
def _normalize_step_output(
    step: StructureBase,
    step_structure: type[StructureBase],
) -> ClassificationStep:
    """Coerce dynamic step output into a ``ClassificationStep``.

    Parameters
    ----------
    step : StructureBase
        Raw step output returned by the agent.
    step_structure : type[StructureBase]
        Structure definition used to parse the agent output.

    Returns
    -------
    ClassificationStep
        The input unchanged when it already is a ``ClassificationStep``,
        otherwise a round-trip through its JSON representation.
    """
    if isinstance(step, ClassificationStep):
        return step
    # Dynamic structures share the JSON shape, so a round-trip converts them.
    return ClassificationStep.from_json(step.to_json())
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
def _build_input_payload(
    text: str,
    file_ids: str | Sequence[str] | None,
) -> str | list[dict[str, Any]]:
    """Build an agent input payload, attaching files when provided.

    Parameters
    ----------
    text : str
        Prompt text to send to the agent.
    file_ids : str or Sequence[str] or None
        Optional file IDs to include as ``input_file`` attachments;
        falsy entries are dropped.

    Returns
    -------
    str or list[dict[str, Any]]
        The bare text when no usable file IDs remain, otherwise a single
        user message combining the text with the file attachments.
    """
    attachments: list[dict[str, Any]] = []
    for candidate in ensure_list(file_ids):
        if candidate:
            attachments.append({"type": "input_file", "file_id": candidate})
    if not attachments:
        return text
    content: list[dict[str, Any]] = [{"type": "input_text", "text": text}]
    content.extend(attachments)
    return [{"role": "user", "content": content}]
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
def _extract_enum_fields(
    step_structure: type[StructureBase],
) -> dict[str, type[Enum]]:
    """Collect the enum-typed fields declared on a step structure.

    Parameters
    ----------
    step_structure : type[StructureBase]
        Structure definition to inspect.

    Returns
    -------
    dict[str, type[Enum]]
        Mapping of field names to the enum class each field carries.
    """
    # Delegate annotation unwrapping to the structure's own helper.
    return {
        field_name: enum_cls
        for field_name, model_field in step_structure.model_fields.items()
        if (enum_cls := step_structure._extract_enum_class(model_field.annotation))
        is not None
    }
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def _normalize_enum_value(value: Any, enum_cls: type[Enum]) -> Any:
|
|
974
|
+
"""Normalize enum values into raw primitives.
|
|
975
|
+
|
|
976
|
+
Parameters
|
|
977
|
+
----------
|
|
978
|
+
value : Any
|
|
979
|
+
Value to normalize.
|
|
980
|
+
enum_cls : type[Enum]
|
|
981
|
+
Enum type used for normalization.
|
|
982
|
+
|
|
983
|
+
Returns
|
|
984
|
+
-------
|
|
985
|
+
Any
|
|
986
|
+
Primitive value suitable for ``ClassificationStep``.
|
|
987
|
+
"""
|
|
988
|
+
if isinstance(value, Enum):
|
|
989
|
+
return value.value
|
|
990
|
+
if isinstance(value, list):
|
|
991
|
+
return [_normalize_enum_value(item, enum_cls) for item in value]
|
|
992
|
+
if isinstance(value, str):
|
|
993
|
+
if value in enum_cls._value2member_map_:
|
|
994
|
+
return enum_cls(value).value
|
|
995
|
+
if value in enum_cls.__members__:
|
|
996
|
+
return enum_cls.__members__[value].value
|
|
997
|
+
return value
|
|
998
|
+
|
|
999
|
+
|
|
1000
|
+
def _resolve_nodes(
    node_paths: dict[str, TaxonomyNode],
    step: ClassificationStep,
) -> list[TaxonomyNode]:
    """Look up the taxonomy nodes a classification step selected.

    Parameters
    ----------
    node_paths : dict[str, TaxonomyNode]
        Mapping of path identifiers to nodes at the current level.
    step : ClassificationStep
        Classification step output to resolve.

    Returns
    -------
    list[TaxonomyNode]
        Matching taxonomy nodes in the step's priority order; unknown
        identifiers are skipped.
    """
    matches: list[TaxonomyNode] = []
    for identifier in _selected_nodes(step):
        candidate = node_paths.get(identifier)
        # Keep the original truthiness check: falsy nodes are dropped too.
        if candidate:
            matches.append(candidate)
    return matches
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
def _selected_nodes(step: ClassificationStep) -> list[str]:
    """Extract the selected path identifiers from a classification step.

    Parameters
    ----------
    step : ClassificationStep
        Classification output to normalize.

    Returns
    -------
    list[str]
        Selected identifiers in priority order; the multi-select field
        wins over the single-select field, and an empty list means no
        selection was made.
    """
    multi = step.selected_nodes
    if multi is not None:
        identifiers: list[str] = []
        for entry in multi:
            if entry:
                identifiers.append(str(_normalize_enum_value(entry, Enum)))
        if identifiers:
            return identifiers
    single = step.selected_node
    if single:
        return [str(_normalize_enum_value(single, Enum))]
    return []
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
def _max_confidence(
|
|
1055
|
+
current: float | None,
|
|
1056
|
+
candidate: float | None,
|
|
1057
|
+
) -> float | None:
|
|
1058
|
+
"""Return the higher confidence value.
|
|
1059
|
+
|
|
1060
|
+
Parameters
|
|
1061
|
+
----------
|
|
1062
|
+
current : float or None
|
|
1063
|
+
Current best confidence value.
|
|
1064
|
+
candidate : float or None
|
|
1065
|
+
Candidate confidence value to compare.
|
|
256
1066
|
|
|
257
1067
|
Returns
|
|
258
1068
|
-------
|
|
259
|
-
|
|
260
|
-
|
|
1069
|
+
float or None
|
|
1070
|
+
Highest confidence value available.
|
|
261
1071
|
"""
|
|
262
|
-
if
|
|
263
|
-
return
|
|
264
|
-
|
|
265
|
-
|
|
1072
|
+
if current is None:
|
|
1073
|
+
return candidate
|
|
1074
|
+
if candidate is None:
|
|
1075
|
+
return current
|
|
1076
|
+
return max(current, candidate)
|
|
266
1077
|
|
|
267
1078
|
|
|
268
1079
|
# Only the agent class is public; the underscore-prefixed helpers above are
# module-private implementation details.
__all__ = ["TaxonomyClassifierAgent"]
|