openai-sdk-helpers 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
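The headline change in this diff: TaxonomyClassifierAgent now takes its taxonomy at construction time, traverses it recursively (fanning out across sibling branches) instead of walking one level at a time, and gains run_sync/run_async entry points plus confidence_threshold, single_class, and file_ids options. A minimal usage sketch, assuming the module paths below and a TaxonomyNode that accepts label, description, and children keyword arguments (the diff only shows those attributes being read):

    from openai_sdk_helpers.agents import TaxonomyClassifierAgent  # import paths assumed
    from openai_sdk_helpers.structure import TaxonomyNode

    taxonomy = [
        TaxonomyNode(
            label="Billing",
            description="Payments, invoices, refunds",
            children=[TaxonomyNode(label="Refunds")],  # children kwarg assumed
        ),
        TaxonomyNode(label="Support"),
    ]
    agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
    result = agent.run_sync(
        "I was charged twice and want my money back.",
        max_depth=3,
        confidence_threshold=0.5,
        single_class=True,
    )
    print(result.final_nodes, result.stop_reason)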
@@ -1,22 +1,29 @@
-"""Agent for taxonomy-driven text classification."""
+"""Recursive agent for taxonomy-driven text classification."""
 
 from __future__ import annotations
 
+import asyncio
+import threading
+import re
+from dataclasses import dataclass, field
+from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Iterable, Optional, Sequence
+from typing import Any, Awaitable, Dict, Iterable, Optional, Sequence, cast
 
 from ..structure import (
     ClassificationResult,
     ClassificationStep,
     ClassificationStopReason,
+    StructureBase,
     TaxonomyNode,
 )
+from ..utils import ensure_list
 from .base import AgentBase
 from .configuration import AgentConfiguration
 
 
 class TaxonomyClassifierAgent(AgentBase):
-    """Classify text by traversing a taxonomy level by level.
+    """Classify text by recursively traversing a taxonomy.
 
     Parameters
     ----------
@@ -27,8 +34,22 @@ class TaxonomyClassifierAgent(AgentBase):
 
     Methods
     -------
-    run_agent(text, taxonomy, context, max_depth)
-        Classify text by walking the taxonomy tree.
+    run_agent(text, context, file_ids, max_depth, confidence_threshold, single_class, session)
+        Classify text by recursively walking the taxonomy tree.
+    run_async(input, context, max_depth, confidence_threshold, single_class)
+        Classify text asynchronously using taxonomy traversal.
+    run_sync(input, context, max_depth, confidence_threshold, single_class)
+        Classify text synchronously using taxonomy traversal.
+
+    Examples
+    --------
+    Create a classifier with a flat taxonomy:
+
+    >>> taxonomy = [
+    ...     TaxonomyNode(label="Billing"),
+    ...     TaxonomyNode(label="Support"),
+    ... ]
+    >>> agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
     """
 
     def __init__(
@@ -36,6 +57,7 @@ class TaxonomyClassifierAgent(AgentBase):
         *,
         template_path: Path | str | None = None,
         model: str | None = None,
+        taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
     ) -> None:
         """Initialize the taxonomy classifier agent configuration.
 
@@ -45,21 +67,27 @@ class TaxonomyClassifierAgent(AgentBase):
             Optional template file path for prompt rendering.
         model : str | None, default=None
             Model identifier to use for classification.
+        taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
+            Root taxonomy node or list of root nodes.
 
         Raises
         ------
         ValueError
-            If the model is not provided.
+            If the taxonomy is empty.
 
         Examples
         --------
-        >>> classifier = TaxonomyClassifierAgent(model="gpt-4o-mini")
+        >>> classifier = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=[TaxonomyNode(label="Billing")])
         """
+        self._taxonomy = taxonomy
+        self._root_nodes = _normalize_roots(taxonomy)
+        if not self._root_nodes:
+            raise ValueError("taxonomy must include at least one node")
         resolved_template_path = template_path or _default_template_path()
         configuration = AgentConfiguration(
             name="taxonomy_classifier",
             instructions="Agent instructions",
-            description="Classify text by traversing taxonomy levels.",
+            description="Classify text by traversing taxonomy levels recursively.",
             template_path=resolved_template_path,
             output_structure=ClassificationStep,
             model=model,
@@ -69,84 +97,555 @@ class TaxonomyClassifierAgent(AgentBase):
     async def run_agent(
         self,
         text: str,
-        taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
         *,
         context: Optional[Dict[str, Any]] = None,
+        file_ids: str | Sequence[str] | None = None,
         max_depth: Optional[int] = None,
+        confidence_threshold: float | None = None,
+        single_class: bool = False,
+        session: Optional[Any] = None,
     ) -> ClassificationResult:
-        """Classify ``text`` by iterating over taxonomy levels.
+        """Classify ``text`` by recursively walking taxonomy levels.
 
         Parameters
         ----------
         text : str
             Source text to classify.
-        taxonomy : TaxonomyNode or Sequence[TaxonomyNode]
-            Root taxonomy node or list of root nodes to traverse.
         context : dict or None, default=None
             Additional context values to merge into the prompt.
+        file_ids : str or Sequence[str] or None, default=None
+            Optional file IDs to attach to each classification step.
         max_depth : int or None, default=None
             Maximum depth to traverse before stopping.
+        confidence_threshold : float or None, default=None
+            Minimum confidence required to accept a classification step.
+        single_class : bool, default=False
+            Whether to keep only the highest-priority selection per step.
+        session : Session or None, default=None
+            Optional session for maintaining conversation history across runs.
 
         Returns
         -------
         ClassificationResult
             Structured classification result describing the traversal.
 
-        Raises
-        ------
-        ValueError
-            If ``taxonomy`` is empty.
+        Examples
+        --------
+        >>> taxonomy = TaxonomyNode(label="Finance")
+        >>> agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
+        >>> isinstance(agent.root_nodes, list)
+        True
         """
-        roots = _normalize_roots(taxonomy)
-        if not roots:
-            raise ValueError("taxonomy must include at least one node")
+        state = _TraversalState()
+        input_payload = _build_input_payload(text, file_ids)
+        await self._classify_nodes(
+            input_payload=input_payload,
+            nodes=list(self._root_nodes),
+            depth=0,
+            parent_path=[],
+            context=context,
+            file_ids=file_ids,
+            max_depth=max_depth,
+            confidence_threshold=confidence_threshold,
+            single_class=single_class,
+            session=session,
+            state=state,
+        )
 
-        path: list[ClassificationStep] = []
-        depth = 0
-        stop_reason = ClassificationStopReason.NO_MATCH
-        current_nodes = list(roots)
-
-        while current_nodes:
-            if max_depth is not None and depth >= max_depth:
-                stop_reason = ClassificationStopReason.MAX_DEPTH
-                break
-
-            template_context = _build_context(
-                current_nodes=current_nodes,
-                path=path,
-                depth=depth,
-                context=context,
-            )
-            step: ClassificationStep = await self.run_async(
-                input=text,
-                context=template_context,
-                output_structure=ClassificationStep,
+        final_nodes_value = state.final_nodes or None
+        final_node = state.final_nodes[0] if state.final_nodes else None
+        stop_reason = _resolve_stop_reason(state)
+        return ClassificationResult(
+            final_node=final_node,
+            final_nodes=final_nodes_value,
+            confidence=state.best_confidence,
+            stop_reason=stop_reason,
+            path=state.path,
+            path_nodes=state.path_nodes,
+        )
+
+    async def run_async(
+        self,
+        input: str | list[dict[str, Any]],
+        *,
+        context: Optional[Dict[str, Any]] = None,
+        output_structure: Optional[type[StructureBase]] = None,
+        session: Optional[Any] = None,
+        file_ids: str | Sequence[str] | None = None,
+        max_depth: Optional[int] = None,
+        confidence_threshold: float | None = None,
+        single_class: bool = False,
+    ) -> ClassificationResult:
+        """Classify ``input`` asynchronously with taxonomy traversal.
+
+        Parameters
+        ----------
+        input : str or list[dict[str, Any]]
+            Source text to classify.
+        context : dict or None, default=None
+            Additional context values to merge into the prompt.
+        output_structure : type[StructureBase] or None, default=None
+            Unused in taxonomy traversal. Present for API compatibility.
+        session : Session or None, default=None
+            Optional session for maintaining conversation history across runs.
+        file_ids : str or Sequence[str] or None, default=None
+            Optional file IDs to attach to each classification step.
+        max_depth : int or None, default=None
+            Maximum depth to traverse before stopping.
+        confidence_threshold : float or None, default=None
+            Minimum confidence required to accept a classification step.
+        single_class : bool, default=False
+            Whether to keep only the highest-priority selection per step.
+
+        Returns
+        -------
+        ClassificationResult
+            Structured classification result describing the traversal.
+        """
+        _ = output_structure
+        if not isinstance(input, str):
+            msg = "TaxonomyClassifierAgent run_async requires text input."
+            raise TypeError(msg)
+        kwargs: Dict[str, Any] = {
+            "context": context,
+            "file_ids": file_ids,
+            "max_depth": max_depth,
+            "confidence_threshold": confidence_threshold,
+            "single_class": single_class,
+        }
+        if session is not None:
+            kwargs["session"] = session
+        return await self.run_agent(input, **kwargs)
+
+    def run_sync(
+        self,
+        input: str | list[dict[str, Any]],
+        *,
+        context: Optional[Dict[str, Any]] = None,
+        output_structure: Optional[type[StructureBase]] = None,
+        session: Optional[Any] = None,
+        file_ids: str | Sequence[str] | None = None,
+        max_depth: Optional[int] = None,
+        confidence_threshold: float | None = None,
+        single_class: bool = False,
+    ) -> ClassificationResult:
+        """Classify ``input`` synchronously with taxonomy traversal.
+
+        Parameters
+        ----------
+        input : str or list[dict[str, Any]]
+            Source text to classify.
+        context : dict or None, default=None
+            Additional context values to merge into the prompt.
+        output_structure : type[StructureBase] or None, default=None
+            Unused in taxonomy traversal. Present for API compatibility.
+        session : Session or None, default=None
+            Optional session for maintaining conversation history across runs.
+        file_ids : str or Sequence[str] or None, default=None
+            Optional file IDs to attach to each classification step.
+        max_depth : int or None, default=None
+            Maximum depth to traverse before stopping.
+        confidence_threshold : float or None, default=None
+            Minimum confidence required to accept a classification step.
+        single_class : bool, default=False
+            Whether to keep only the highest-priority selection per step.
+
+        Returns
+        -------
+        ClassificationResult
+            Structured classification result describing the traversal.
+        """
+        _ = output_structure
+        if not isinstance(input, str):
+            msg = "TaxonomyClassifierAgent run_sync requires text input."
+            raise TypeError(msg)
+        kwargs: Dict[str, Any] = {
+            "context": context,
+            "file_ids": file_ids,
+            "max_depth": max_depth,
+            "confidence_threshold": confidence_threshold,
+            "single_class": single_class,
+        }
+        if session is not None:
+            kwargs["session"] = session
+
+        async def runner() -> ClassificationResult:
+            return await self.run_agent(input, **kwargs)
+
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            return asyncio.run(runner())
+
+        result: ClassificationResult | None = None
+        error: Exception | None = None
+
+        def _thread_func() -> None:
+            nonlocal error, result
+            try:
+                result = asyncio.run(runner())
+            except Exception as exc:
+                error = exc
+
+        thread = threading.Thread(target=_thread_func)
+        thread.start()
+        thread.join()
+
+        if error is not None:
+            raise error
+        if result is None:
+            msg = "Classification did not return a result"
+            raise RuntimeError(msg)
+        return result
+
+    async def _run_step_async(
+        self,
+        *,
+        input: str | list[dict[str, Any]],
+        context: Optional[Dict[str, Any]] = None,
+        output_structure: Optional[type[StructureBase]] = None,
+        session: Optional[Any] = None,
+    ) -> StructureBase:
+        """Execute a single classification step asynchronously.
+
+        Parameters
+        ----------
+        input : str or list[dict[str, Any]]
+            Prompt or structured input for the agent.
+        context : dict or None, default=None
+            Optional dictionary passed to the agent.
+        output_structure : type[StructureBase] or None, default=None
+            Optional type used to cast the final output.
+        session : Session or None, default=None
+            Optional session for maintaining conversation history across runs.
+
+        Returns
+        -------
+        StructureBase
+            Parsed result for the classification step.
+        """
+        return await super().run_async(
+            input=input,
+            context=context,
+            output_structure=output_structure,
+            session=session,
+        )
+
+    async def _classify_nodes(
+        self,
+        *,
+        input_payload: str | list[dict[str, Any]],
+        nodes: list[TaxonomyNode],
+        depth: int,
+        parent_path: list[str],
+        context: Optional[Dict[str, Any]],
+        file_ids: str | Sequence[str] | None,
+        max_depth: Optional[int],
+        confidence_threshold: float | None,
+        single_class: bool,
+        session: Optional[Any],
+        state: "_TraversalState",
+    ) -> None:
+        """Classify a taxonomy level and recursively traverse children.
+
+        Parameters
+        ----------
+        input_payload : str or list[dict[str, Any]]
+            Input payload used to prompt the agent.
+        nodes : list[TaxonomyNode]
+            Candidate taxonomy nodes for the current level.
+        depth : int
+            Current traversal depth.
+        context : dict or None
+            Additional context values to merge into the prompt.
+        file_ids : str or Sequence[str] or None
+            Optional file IDs attached to each classification step.
+        max_depth : int or None
+            Maximum traversal depth before stopping.
+        confidence_threshold : float or None
+            Minimum confidence required to accept a classification step.
+        single_class : bool
+            Whether to keep only the highest-priority selection per step.
+        session : Session or None
+            Optional session for maintaining conversation history across runs.
+        state : _TraversalState
+            Aggregated traversal state.
+        """
+        if max_depth is not None and depth >= max_depth:
+            state.saw_max_depth = True
+            return
+        if not nodes:
+            return
+
+        node_paths = _build_node_path_map(nodes, parent_path)
+        template_context = _build_context(
+            node_descriptors=_build_node_descriptors(node_paths),
+            path=state.path,
+            depth=depth,
+            context=context,
+        )
+        step_structure = _build_step_structure(list(node_paths.keys()))
+        raw_step = await self._run_step_async(
+            input=input_payload,
+            context=template_context,
+            output_structure=step_structure,
+            session=session,
+        )
+        step = _normalize_step_output(raw_step, step_structure)
+        state.path.append(step)
+
+        if (
+            confidence_threshold is not None
+            and step.confidence is not None
+            and step.confidence < confidence_threshold
+        ):
+            return
+
+        resolved_nodes = _resolve_nodes(node_paths, step)
+        if resolved_nodes:
+            if single_class:
+                resolved_nodes = resolved_nodes[:1]
+            state.path_nodes.extend(resolved_nodes)
+
+        if step.stop_reason.is_terminal:
+            if resolved_nodes:
+                state.final_nodes.extend(resolved_nodes)
+                state.best_confidence = _max_confidence(
+                    state.best_confidence, step.confidence
+                )
+            state.saw_terminal_stop = True
+            return
+
+        if not resolved_nodes:
+            return
+
+        base_path_len = len(state.path)
+        base_path_nodes_len = len(state.path_nodes)
+        child_tasks: list[tuple[Awaitable["_TraversalState"], int]] = []
+        for node in resolved_nodes:
+            if node.children:
+                sub_agent = self._build_sub_agent(list(node.children))
+                sub_state = _copy_traversal_state(state)
+                base_final_nodes_len = len(state.final_nodes)
+                child_tasks.append(
+                    (
+                        self._classify_subtree(
+                            sub_agent=sub_agent,
+                            input_payload=input_payload,
+                            nodes=list(node.children),
+                            depth=depth + 1,
+                            parent_path=[*parent_path, node.label],
+                            context=context,
+                            file_ids=file_ids,
+                            max_depth=max_depth,
+                            confidence_threshold=confidence_threshold,
+                            single_class=single_class,
+                            session=session,
+                            state=sub_state,
+                        ),
+                        base_final_nodes_len,
+                    )
+                )
+            else:
+                state.saw_no_children = True
+                state.final_nodes.append(node)
+                state.best_confidence = _max_confidence(
+                    state.best_confidence, step.confidence
+                )
+        if child_tasks:
+            child_states = await asyncio.gather(
+                *(child_task for child_task, _ in child_tasks)
             )
-            path.append(step)
-            stop_reason = step.stop_reason
+            for child_state, (_, base_final_nodes_len) in zip(
+                child_states, child_tasks, strict=True
+            ):
+                state.path.extend(child_state.path[base_path_len:])
+                state.path_nodes.extend(child_state.path_nodes[base_path_nodes_len:])
+                state.final_nodes.extend(child_state.final_nodes[base_final_nodes_len:])
+                state.best_confidence = _max_confidence(
+                    state.best_confidence, child_state.best_confidence
+                )
+                state.saw_max_depth = state.saw_max_depth or child_state.saw_max_depth
+                state.saw_no_children = (
+                    state.saw_no_children or child_state.saw_no_children
+                )
+                state.saw_terminal_stop = (
+                    state.saw_terminal_stop or child_state.saw_terminal_stop
+                )
 
-            if step.stop_reason.is_terminal:
-                break
+    @property
+    def taxonomy(self) -> TaxonomyNode | Sequence[TaxonomyNode]:
+        """Return the root taxonomy node(s).
 
-            selected_node = _resolve_node(current_nodes, step)
-            if selected_node is None:
-                stop_reason = ClassificationStopReason.NO_MATCH
-                break
-            if not selected_node.children:
-                stop_reason = ClassificationStopReason.NO_CHILDREN
-                break
+        Returns
+        -------
+        TaxonomyNode or Sequence[TaxonomyNode]
+            Root taxonomy node or list of root nodes.
+        """
+        return self._taxonomy
 
-            current_nodes = list(selected_node.children)
-            depth += 1
+    @property
+    def root_nodes(self) -> list[TaxonomyNode]:
+        """Return the list of root taxonomy nodes.
 
-        final_id, final_label, confidence = _final_values(path)
-        return ClassificationResult(
-            final_id=final_id,
-            final_label=final_label,
-            confidence=confidence,
-            stop_reason=stop_reason,
-            path=path,
+        Returns
+        -------
+        list[TaxonomyNode]
+            List of root taxonomy nodes.
+        """
+        return self._root_nodes
+
+    def _build_sub_agent(
+        self,
+        nodes: Sequence[TaxonomyNode],
+    ) -> "TaxonomyClassifierAgent":
+        """Build a classifier agent for a taxonomy subtree.
+
+        Parameters
+        ----------
+        nodes : Sequence[TaxonomyNode]
+            Taxonomy nodes to use as the sub-agent's root taxonomy.
+
+        Returns
+        -------
+        TaxonomyClassifierAgent
+            Configured classifier agent for the taxonomy slice.
+        """
+        sub_agent = TaxonomyClassifierAgent(
+            template_path=self._template_path,
+            model=self._model,
+            taxonomy=list(nodes),
+        )
+        sub_agent._run_step_async = self._run_step_async
+        return sub_agent
+
+    async def _classify_subtree(
+        self,
+        *,
+        sub_agent: "TaxonomyClassifierAgent",
+        input_payload: str | list[dict[str, Any]],
+        nodes: list[TaxonomyNode],
+        depth: int,
+        parent_path: list[str],
+        context: Optional[Dict[str, Any]],
+        file_ids: str | Sequence[str] | None,
+        max_depth: Optional[int],
+        confidence_threshold: float | None,
+        single_class: bool,
+        session: Optional[Any],
+        state: "_TraversalState",
+    ) -> "_TraversalState":
+        """Classify a taxonomy subtree and return the traversal state.
+
+        Parameters
+        ----------
+        sub_agent : TaxonomyClassifierAgent
+            Sub-agent configured for the subtree traversal.
+        input_payload : str or list[dict[str, Any]]
+            Input payload used to prompt the agent.
+        nodes : list[TaxonomyNode]
+            Candidate taxonomy nodes for the subtree.
+        depth : int
+            Current traversal depth.
+        parent_path : list[str]
+            Path segments leading to the current subtree.
+        context : dict or None
+            Additional context values to merge into the prompt.
+        file_ids : str or Sequence[str] or None
+            Optional file IDs attached to each classification step.
+        max_depth : int or None
+            Maximum traversal depth before stopping.
+        confidence_threshold : float or None
+            Minimum confidence required to accept a classification step.
+        single_class : bool
+            Whether to keep only the highest-priority selection per step.
+        session : Session or None
+            Optional session for maintaining conversation history across runs.
+        state : _TraversalState
+            Traversal state to populate for the subtree.
+
+        Returns
+        -------
+        _TraversalState
+            Populated traversal state for the subtree.
+        """
+        await sub_agent._classify_nodes(
+            input_payload=input_payload,
+            nodes=nodes,
+            depth=depth,
+            parent_path=parent_path,
+            context=context,
+            file_ids=file_ids,
+            max_depth=max_depth,
+            confidence_threshold=confidence_threshold,
+            single_class=single_class,
+            session=session,
+            state=state,
         )
+        return state
+
+
+@dataclass
+class _TraversalState:
+    """Track recursive traversal state."""
+
+    path: list[ClassificationStep] = field(default_factory=list)
+    path_nodes: list[TaxonomyNode] = field(default_factory=list)
+    final_nodes: list[TaxonomyNode] = field(default_factory=list)
+    best_confidence: float | None = None
+    saw_max_depth: bool = False
+    saw_no_children: bool = False
+    saw_terminal_stop: bool = False
+
+
+def _copy_traversal_state(state: _TraversalState) -> _TraversalState:
+    """Copy traversal state for parallel subtree execution.
+
+    Parameters
+    ----------
+    state : _TraversalState
+        Traversal state to clone.
+
+    Returns
+    -------
+    _TraversalState
+        Cloned traversal state with copied collections.
+    """
+    return _TraversalState(
+        path=list(state.path),
+        path_nodes=list(state.path_nodes),
+        final_nodes=list(state.final_nodes),
+        best_confidence=state.best_confidence,
+        saw_max_depth=state.saw_max_depth,
+        saw_no_children=state.saw_no_children,
+        saw_terminal_stop=state.saw_terminal_stop,
+    )
+
+
+def _resolve_stop_reason(state: _TraversalState) -> ClassificationStopReason:
+    """Resolve the final stop reason based on traversal state.
+
+    Parameters
+    ----------
+    state : _TraversalState
+        Traversal state to inspect.
+
+    Returns
+    -------
+    ClassificationStopReason
+        Resolved stop reason.
+    """
+    if state.saw_terminal_stop:
+        return ClassificationStopReason.STOP
+    if state.final_nodes and state.saw_no_children:
+        return ClassificationStopReason.NO_CHILDREN
+    if state.final_nodes:
+        return ClassificationStopReason.STOP
+    if state.saw_max_depth:
+        return ClassificationStopReason.MAX_DEPTH
+    if state.saw_no_children:
+        return ClassificationStopReason.NO_CHILDREN
+    return ClassificationStopReason.NO_MATCH
 
 
 def _normalize_roots(
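The fan-out in _classify_nodes above follows a copy/gather/merge pattern: each selected child gets a snapshot of the shared _TraversalState, the subtrees run concurrently under asyncio.gather, and only the suffix a child appended beyond the recorded base lengths is merged back. A self-contained sketch of the same pattern with toy types (not the package's classes):

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class State:
        path: list[str] = field(default_factory=list)

    async def explore(label: str, state: State) -> State:
        # Each branch appends to its own copy of the snapshot.
        state.path.append(label)
        return state

    async def fan_out(parent: State, labels: list[str]) -> None:
        base_len = len(parent.path)
        snapshots = [State(path=list(parent.path)) for _ in labels]
        children = await asyncio.gather(
            *(explore(label, snap) for label, snap in zip(labels, snapshots))
        )
        for child in children:
            # Merge back only what the child added beyond the snapshot.
            parent.path.extend(child.path[base_len:])

    state = State(path=["root"])
    asyncio.run(fan_out(state, ["a", "b"]))
    print(state.path)  # ['root', 'a', 'b']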
@@ -156,7 +655,7 @@ def _normalize_roots(
 
    Parameters
    ----------
-    taxonomy : TaxonomyNode or Sequence[TaxonomyNode]
+    taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
        Root taxonomy node or list of root nodes.
 
    Returns
@@ -182,7 +681,7 @@ def _default_template_path() -> Path:
 
 def _build_context(
     *,
-    current_nodes: Iterable[TaxonomyNode],
+    node_descriptors: Iterable[dict[str, Any]],
     path: Sequence[ClassificationStep],
     depth: int,
     context: Optional[Dict[str, Any]],
@@ -191,8 +690,8 @@ def _build_context(
 
    Parameters
    ----------
-    current_nodes : Iterable[TaxonomyNode]
-        Nodes available at the current taxonomy level.
+    node_descriptors : Iterable[dict[str, Any]]
+        Node descriptors available at the current taxonomy level.
    path : Sequence[ClassificationStep]
        Steps recorded so far in the traversal.
    depth : int
@@ -206,7 +705,7 @@ def _build_context(
        Context dictionary for prompt rendering.
    """
    template_context: Dict[str, Any] = {
-        "taxonomy_nodes": list(current_nodes),
+        "taxonomy_nodes": list(node_descriptors),
        "path": [step.as_summary() for step in path],
        "depth": depth,
    }
@@ -215,54 +714,366 @@
    return template_context
 
 
-def _resolve_node(
+def _build_step_structure(
+    path_identifiers: Sequence[str],
+) -> type[ClassificationStep]:
+    """Build a step output structure constrained to taxonomy paths.
+
+    Parameters
+    ----------
+    path_identifiers : Sequence[str]
+        Path identifiers for nodes at the current classification step.
+
+    Returns
+    -------
+    type[ClassificationStep]
+        Dynamic structure class for the classification step output.
+    """
+    node_enum = _build_taxonomy_enum("TaxonomyPath", path_identifiers)
+    return ClassificationStep.build_for_enum(node_enum)
+
+
+def _build_node_path_map(
    nodes: Sequence[TaxonomyNode],
-    step: ClassificationStep,
-) -> Optional[TaxonomyNode]:
-    """Resolve the selected node for a classification step.
+    parent_path: Sequence[str],
+) -> dict[str, TaxonomyNode]:
+    """Build a mapping of node path identifiers to taxonomy nodes.
 
    Parameters
    ----------
    nodes : Sequence[TaxonomyNode]
-        Candidate nodes at the current level.
+        Candidate nodes at the current taxonomy level.
+    parent_path : Sequence[str]
+        Path segments leading to the current taxonomy level.
+
+    Returns
+    -------
+    dict[str, TaxonomyNode]
+        Mapping of path identifiers to taxonomy nodes.
+    """
+    path_map: dict[str, TaxonomyNode] = {}
+    seen: dict[str, int] = {}
+    for node in nodes:
+        base_path = _format_path_identifier([*parent_path, node.label])
+        count = seen.get(base_path, 0) + 1
+        seen[base_path] = count
+        path = f"{base_path} ({count})" if count > 1 else base_path
+        path_map[path] = node
+    return path_map
+
+
+def _build_node_descriptors(
+    node_paths: dict[str, TaxonomyNode],
+) -> list[dict[str, Any]]:
+    """Build node descriptors for prompt rendering.
+
+    Parameters
+    ----------
+    node_paths : dict[str, TaxonomyNode]
+        Mapping of path identifiers to taxonomy nodes.
+
+    Returns
+    -------
+    list[dict[str, Any]]
+        Node descriptor dictionaries for prompt rendering.
+    """
+    descriptors: list[dict[str, Any]] = []
+    for path_id, node in node_paths.items():
+        descriptors.append(
+            {
+                "identifier": path_id,
+                "label": node.label,
+                "description": node.description,
+            }
+        )
+    return descriptors
+
+
+def _format_path_identifier(path_segments: Sequence[str]) -> str:
+    """Format path segments into a safe identifier string.
+
+    Parameters
+    ----------
+    path_segments : Sequence[str]
+        Path segments to format.
+
+    Returns
+    -------
+    str
+        Escaped path identifier string.
+    """
+    delimiter = " > "
+    escape_token = "\\>"
+    escaped_segments = [
+        segment.replace(delimiter, escape_token) for segment in path_segments
+    ]
+    return delimiter.join(escaped_segments)
+
+
+def _build_taxonomy_enum(name: str, values: Sequence[str]) -> type[Enum]:
+    """Build a safe Enum from taxonomy node values.
+
+    Parameters
+    ----------
+    name : str
+        Name to use for the enum class.
+    values : Sequence[str]
+        Taxonomy node values to include as enum members.
+
+    Returns
+    -------
+    type[Enum]
+        Enum class with sanitized member names.
+    """
+    members: dict[str, str] = {}
+    for index, value in enumerate(values, start=1):
+        member_name = _sanitize_enum_member(value, index, members)
+        members[member_name] = value
+    if not members:
+        members["UNSPECIFIED"] = ""
+    return cast(type[Enum], Enum(name, members))
+
+
+def _split_taxonomy_path(value: str) -> list[str]:
+    """Split a taxonomy identifier into its path segments.
+
+    Parameters
+    ----------
+    value : str
+        Taxonomy path identifier to split.
+
+    Returns
+    -------
+    list[str]
+        Path segments with escaped delimiters restored.
+    """
+    delimiter = " > "
+    escape_token = "\\>"
+    segments = value.split(delimiter)
+    return [segment.replace(escape_token, delimiter) for segment in segments]
+
+
+def _sanitize_enum_member(
+    value: str,
+    index: int,
+    existing: dict[str, str],
+) -> str:
+    """Return a valid enum member name for a taxonomy value.
+
+    Parameters
+    ----------
+    value : str
+        Raw taxonomy value to sanitize.
+    index : int
+        Index of the value in the source list.
+    existing : dict[str, str]
+        Existing enum members to avoid collisions.
+
+    Returns
+    -------
+    str
+        Sanitized enum member name.
+    """
+    normalized_segments: list[str] = []
+    for segment in _split_taxonomy_path(value):
+        normalized = re.sub(r"[^0-9a-zA-Z]+", "_", segment).strip("_").upper()
+        if not normalized:
+            normalized = "VALUE"
+        if normalized[0].isdigit():
+            normalized = f"VALUE_{normalized}"
+        normalized_segments.append(normalized)
+    normalized_path = "__".join(normalized_segments) or f"VALUE_{index}"
+    candidate = normalized_path
+    suffix = 1
+    while candidate in existing:
+        candidate = f"{normalized_path}__{suffix}"
+        suffix += 1
+    return candidate
+
+
+def _normalize_step_output(
+    step: StructureBase,
+    step_structure: type[StructureBase],
+) -> ClassificationStep:
+    """Normalize dynamic step output into a ClassificationStep.
+
+    Parameters
+    ----------
+    step : StructureBase
+        Raw step output returned by the agent.
+    step_structure : type[StructureBase]
+        Structure definition used to parse the agent output.
+
+    Returns
+    -------
+    ClassificationStep
+        Normalized classification step instance.
+    """
+    if isinstance(step, ClassificationStep):
+        return step
+    payload = step.to_json()
+    return ClassificationStep.from_json(payload)
+
+
+def _build_input_payload(
+    text: str,
+    file_ids: str | Sequence[str] | None,
+) -> str | list[dict[str, Any]]:
+    """Build input payloads with optional file attachments.
+
+    Parameters
+    ----------
+    text : str
+        Prompt text to send to the agent.
+    file_ids : str or Sequence[str] or None
+        Optional file IDs to include as ``input_file`` attachments.
+
+    Returns
+    -------
+    str or list[dict[str, Any]]
+        Input payload suitable for the Agents SDK.
+    """
+    normalized_file_ids = [file_id for file_id in ensure_list(file_ids) if file_id]
+    if not normalized_file_ids:
+        return text
+    attachments = [
+        {"type": "input_file", "file_id": file_id} for file_id in normalized_file_ids
+    ]
+    return [
+        {
+            "role": "user",
+            "content": [{"type": "input_text", "text": text}, *attachments],
+        }
+    ]
+
+
+def _extract_enum_fields(
+    step_structure: type[StructureBase],
+) -> dict[str, type[Enum]]:
+    """Return the enum field mapping for a step structure.
+
+    Parameters
+    ----------
+    step_structure : type[StructureBase]
+        Structure definition to inspect.
+
+    Returns
+    -------
+    dict[str, type[Enum]]
+        Mapping of field names to enum classes.
+    """
+    enum_fields: dict[str, type[Enum]] = {}
+    for field_name, model_field in step_structure.model_fields.items():
+        enum_cls = step_structure._extract_enum_class(model_field.annotation)
+        if enum_cls is not None:
+            enum_fields[field_name] = enum_cls
+    return enum_fields
+
+
+def _normalize_enum_value(value: Any, enum_cls: type[Enum]) -> Any:
+    """Normalize enum values into raw primitives.
+
+    Parameters
+    ----------
+    value : Any
+        Value to normalize.
+    enum_cls : type[Enum]
+        Enum type used for normalization.
+
+    Returns
+    -------
+    Any
+        Primitive value suitable for ``ClassificationStep``.
+    """
+    if isinstance(value, Enum):
+        return value.value
+    if isinstance(value, list):
+        return [_normalize_enum_value(item, enum_cls) for item in value]
+    if isinstance(value, str):
+        if value in enum_cls._value2member_map_:
+            return enum_cls(value).value
+        if value in enum_cls.__members__:
+            return enum_cls.__members__[value].value
+    return value
+
+
+def _resolve_nodes(
+    node_paths: dict[str, TaxonomyNode],
+    step: ClassificationStep,
+) -> list[TaxonomyNode]:
+    """Resolve selected taxonomy nodes for a classification step.
+
+    Parameters
+    ----------
+    node_paths : dict[str, TaxonomyNode]
+        Mapping of path identifiers to nodes at the current level.
    step : ClassificationStep
        Classification step output to resolve.
 
    Returns
    -------
-    TaxonomyNode or None
-        Matching taxonomy node if found.
+    list[TaxonomyNode]
+        Matching taxonomy nodes in priority order.
    """
-    if step.selected_id:
-        for node in nodes:
-            if node.id == step.selected_id:
-                return node
-    if step.selected_label:
-        for node in nodes:
-            if node.label == step.selected_label:
-                return node
-    return None
-
-
-def _final_values(
-    path: Sequence[ClassificationStep],
-) -> tuple[Optional[str], Optional[str], Optional[float]]:
-    """Return the final selection values from the path.
+    resolved: list[TaxonomyNode] = []
+    selected_nodes = _selected_nodes(step)
+    if selected_nodes:
+        for selected_node in selected_nodes:
+            node = node_paths.get(selected_node)
+            if node:
+                resolved.append(node)
+    return resolved
+
+
+def _selected_nodes(step: ClassificationStep) -> list[str]:
+    """Return selected identifiers for a classification step.
 
    Parameters
    ----------
-    path : Sequence[ClassificationStep]
-        Recorded classification steps.
+    step : ClassificationStep
+        Classification output to normalize.
+
+    Returns
+    -------
+    list[str]
+        Selected identifiers in priority order.
+    """
+    if step.selected_nodes is not None:
+        selected_nodes = [
+            str(_normalize_enum_value(selected_node, Enum))
+            for selected_node in step.selected_nodes
+            if selected_node
+        ]
+        if selected_nodes:
+            return selected_nodes
+    if step.selected_node:
+        return [str(_normalize_enum_value(step.selected_node, Enum))]
+    return []
+
+
+def _max_confidence(
+    current: float | None,
+    candidate: float | None,
+) -> float | None:
+    """Return the higher confidence value.
+
+    Parameters
+    ----------
+    current : float or None
+        Current best confidence value.
+    candidate : float or None
+        Candidate confidence value to compare.
 
    Returns
    -------
-    tuple[str or None, str or None, float or None]
-        Final identifier, label, and confidence.
+    float or None
+        Highest confidence value available.
    """
-    if not path:
-        return None, None, None
-    last_step = path[-1]
-    return last_step.selected_id, last_step.selected_label, last_step.confidence
+    if current is None:
+        return candidate
+    if candidate is None:
+        return current
+    return max(current, candidate)
 
 
 __all__ = ["TaxonomyClassifierAgent"]
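Note how the new helpers address nodes by full path rather than by id: _format_path_identifier joins labels with " > " and escapes literal delimiters, _build_node_path_map disambiguates duplicate siblings with a "(n)" suffix, and _sanitize_enum_member converts each path into a valid Enum member name so the step's output schema can constrain the model's choice. A standalone illustration of that round trip (reimplemented here for illustration, not imported from the package, and omitting the duplicate and leading-digit handling):

    import re

    def format_path(segments: list[str]) -> str:
        # Escape any literal " > " inside a label before joining.
        return " > ".join(s.replace(" > ", "\\>") for s in segments)

    def member_name(path: str) -> str:
        # Collapse each segment to A-Z/0-9/_ and join with a double underscore.
        parts = path.split(" > ")
        normalized = []
        for part in parts:
            name = re.sub(r"[^0-9a-zA-Z]+", "_", part).strip("_").upper()
            normalized.append(name or "VALUE")
        return "__".join(normalized)

    path = format_path(["Billing", "Refunds & Credits"])
    print(path)               # Billing > Refunds & Credits
    print(member_name(path))  # BILLING__REFUNDS_CREDITS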