openai-sdk-helpers 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,19 @@
1
- """Agent for taxonomy-driven text classification."""
1
+ """Recursive agent for taxonomy-driven text classification."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import asyncio
6
+ import re
7
+ from dataclasses import dataclass, field
8
+ from enum import Enum
5
9
  from pathlib import Path
6
- from typing import Any, Dict, Iterable, Optional, Sequence
10
+ from typing import Any, Awaitable, Dict, Iterable, Optional, Sequence, cast
7
11
 
8
12
  from ..structure import (
9
13
  ClassificationResult,
10
14
  ClassificationStep,
11
15
  ClassificationStopReason,
16
+ StructureBase,
12
17
  TaxonomyNode,
13
18
  )
14
19
  from .base import AgentBase
@@ -16,7 +21,7 @@ from .configuration import AgentConfiguration
16
21
 
17
22
 
18
23
  class TaxonomyClassifierAgent(AgentBase):
19
- """Classify text by traversing a taxonomy level by level.
24
+ """Classify text by recursively traversing a taxonomy.
20
25
 
21
26
  Parameters
22
27
  ----------
@@ -28,7 +33,17 @@ class TaxonomyClassifierAgent(AgentBase):
28
33
  Methods
29
34
  -------
30
35
  run_agent(text, taxonomy, context, max_depth)
31
- Classify text by walking the taxonomy tree.
36
+ Classify text by recursively walking the taxonomy tree.
37
+
38
+ Examples
39
+ --------
40
+ Create a classifier with a flat taxonomy:
41
+
42
+ >>> taxonomy = [
43
+ ... TaxonomyNode(label="Billing"),
44
+ ... TaxonomyNode(label="Support"),
45
+ ... ]
46
+ >>> agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
32
47
  """
33
48
 
34
49
  def __init__(
@@ -36,6 +51,7 @@ class TaxonomyClassifierAgent(AgentBase):
36
51
  *,
37
52
  template_path: Path | str | None = None,
38
53
  model: str | None = None,
54
+ taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
39
55
  ) -> None:
40
56
  """Initialize the taxonomy classifier agent configuration.
41
57
 
@@ -45,21 +61,27 @@ class TaxonomyClassifierAgent(AgentBase):
45
61
  Optional template file path for prompt rendering.
46
62
  model : str | None, default=None
47
63
  Model identifier to use for classification.
64
+ taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
65
+ Root taxonomy node or list of root nodes.
48
66
 
49
67
  Raises
50
68
  ------
51
69
  ValueError
52
- If the model is not provided.
70
+ If the taxonomy is empty.
53
71
 
54
72
  Examples
55
73
  --------
56
- >>> classifier = TaxonomyClassifierAgent(model="gpt-4o-mini")
74
+ >>> classifier = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=[TaxonomyNode(label="Billing")])
57
75
  """
76
+ self._taxonomy = taxonomy
77
+ self._root_nodes = _normalize_roots(taxonomy)
78
+ if not self._root_nodes:
79
+ raise ValueError("taxonomy must include at least one node")
58
80
  resolved_template_path = template_path or _default_template_path()
59
81
  configuration = AgentConfiguration(
60
82
  name="taxonomy_classifier",
61
83
  instructions="Agent instructions",
62
- description="Classify text by traversing taxonomy levels.",
84
+ description="Classify text by traversing taxonomy levels recursively.",
63
85
  template_path=resolved_template_path,
64
86
  output_structure=ClassificationStep,
65
87
  model=model,
@@ -69,84 +91,362 @@ class TaxonomyClassifierAgent(AgentBase):
69
91
  async def run_agent(
70
92
  self,
71
93
  text: str,
72
- taxonomy: TaxonomyNode | Sequence[TaxonomyNode],
73
94
  *,
74
95
  context: Optional[Dict[str, Any]] = None,
75
96
  max_depth: Optional[int] = None,
97
+ confidence_threshold: float | None = None,
98
+ single_class: bool = False,
76
99
  ) -> ClassificationResult:
77
- """Classify ``text`` by iterating over taxonomy levels.
100
+ """Classify ``text`` by recursively walking taxonomy levels.
78
101
 
79
102
  Parameters
80
103
  ----------
81
104
  text : str
82
105
  Source text to classify.
83
- taxonomy : TaxonomyNode or Sequence[TaxonomyNode]
84
- Root taxonomy node or list of root nodes to traverse.
85
106
  context : dict or None, default=None
86
107
  Additional context values to merge into the prompt.
87
108
  max_depth : int or None, default=None
88
109
  Maximum depth to traverse before stopping.
110
+ confidence_threshold : float or None, default=None
111
+ Minimum confidence required to accept a classification step.
112
+ single_class : bool, default=False
113
+ Whether to keep only the highest-priority selection per step.
89
114
 
90
115
  Returns
91
116
  -------
92
117
  ClassificationResult
93
118
  Structured classification result describing the traversal.
94
119
 
95
- Raises
96
- ------
97
- ValueError
98
- If ``taxonomy`` is empty.
120
+ Examples
121
+ --------
122
+ >>> taxonomy = TaxonomyNode(label="Finance")
123
+ >>> agent = TaxonomyClassifierAgent(model="gpt-4o-mini", taxonomy=taxonomy)
124
+ >>> isinstance(agent.root_nodes, list)
125
+ True
99
126
  """
100
- roots = _normalize_roots(taxonomy)
101
- if not roots:
102
- raise ValueError("taxonomy must include at least one node")
127
+ state = _TraversalState()
128
+ await self._classify_nodes(
129
+ text=text,
130
+ nodes=list(self._root_nodes),
131
+ depth=0,
132
+ parent_path=[],
133
+ context=context,
134
+ max_depth=max_depth,
135
+ confidence_threshold=confidence_threshold,
136
+ single_class=single_class,
137
+ state=state,
138
+ )
103
139
 
104
- path: list[ClassificationStep] = []
105
- depth = 0
106
- stop_reason = ClassificationStopReason.NO_MATCH
107
- current_nodes = list(roots)
108
-
109
- while current_nodes:
110
- if max_depth is not None and depth >= max_depth:
111
- stop_reason = ClassificationStopReason.MAX_DEPTH
112
- break
113
-
114
- template_context = _build_context(
115
- current_nodes=current_nodes,
116
- path=path,
117
- depth=depth,
118
- context=context,
119
- )
120
- step: ClassificationStep = await self.run_async(
121
- input=text,
122
- context=template_context,
123
- output_structure=ClassificationStep,
140
+ final_nodes_value = state.final_nodes or None
141
+ final_node = state.final_nodes[0] if state.final_nodes else None
142
+ stop_reason = _resolve_stop_reason(state)
143
+ return ClassificationResult(
144
+ final_node=final_node,
145
+ final_nodes=final_nodes_value,
146
+ confidence=state.best_confidence,
147
+ stop_reason=stop_reason,
148
+ path=state.path,
149
+ path_nodes=state.path_nodes,
150
+ )
151
+
152
+ async def _classify_nodes(
153
+ self,
154
+ *,
155
+ text: str,
156
+ nodes: list[TaxonomyNode],
157
+ depth: int,
158
+ parent_path: list[str],
159
+ context: Optional[Dict[str, Any]],
160
+ max_depth: Optional[int],
161
+ confidence_threshold: float | None,
162
+ single_class: bool,
163
+ state: "_TraversalState",
164
+ ) -> None:
165
+ """Classify a taxonomy level and recursively traverse children.
166
+
167
+ Parameters
168
+ ----------
169
+ text : str
170
+ Source text to classify.
171
+ nodes : list[TaxonomyNode]
172
+ Candidate taxonomy nodes for the current level.
173
+ depth : int
174
+ Current traversal depth.
+ parent_path : list[str]
+ Path segments leading to the current taxonomy level.
175
+ context : dict or None
176
+ Additional context values to merge into the prompt.
177
+ max_depth : int or None
178
+ Maximum traversal depth before stopping.
179
+ confidence_threshold : float or None
180
+ Minimum confidence required to accept a classification step.
181
+ single_class : bool
182
+ Whether to keep only the highest-priority selection per step.
183
+ state : _TraversalState
184
+ Aggregated traversal state.
185
+ """
186
+ if max_depth is not None and depth >= max_depth:
187
+ state.saw_max_depth = True
188
+ return
189
+ if not nodes:
190
+ return
191
+
192
+ node_paths = _build_node_path_map(nodes, parent_path)
193
+ template_context = _build_context(
194
+ node_descriptors=_build_node_descriptors(node_paths),
195
+ path=state.path,
196
+ depth=depth,
197
+ context=context,
198
+ )
199
+ step_structure = _build_step_structure(list(node_paths.keys()))
200
+ raw_step = await self.run_async(
201
+ input=text,
202
+ context=template_context,
203
+ output_structure=step_structure,
204
+ )
205
+ step = _normalize_step_output(raw_step, step_structure)
206
+ state.path.append(step)
207
+
208
+ if (
209
+ confidence_threshold is not None
210
+ and step.confidence is not None
211
+ and step.confidence < confidence_threshold
212
+ ):
213
+ return
214
+
215
+ resolved_nodes = _resolve_nodes(node_paths, step)
216
+ if resolved_nodes:
217
+ if single_class:
218
+ resolved_nodes = resolved_nodes[:1]
219
+ state.path_nodes.extend(resolved_nodes)
220
+
221
+ if step.stop_reason.is_terminal:
222
+ if resolved_nodes:
223
+ state.final_nodes.extend(resolved_nodes)
224
+ state.best_confidence = _max_confidence(
225
+ state.best_confidence, step.confidence
226
+ )
227
+ state.saw_terminal_stop = True
228
+ return
229
+
230
+ if not resolved_nodes:
231
+ return
232
+
233
+ base_path_len = len(state.path)
234
+ base_path_nodes_len = len(state.path_nodes)
235
+ child_tasks: list[tuple[Awaitable["_TraversalState"], int]] = []
236
+ for node in resolved_nodes:
237
+ if node.children:
238
+ sub_agent = self._build_sub_agent(list(node.children))
239
+ sub_state = _copy_traversal_state(state)
240
+ base_final_nodes_len = len(state.final_nodes)
241
+ child_tasks.append(
242
+ (
243
+ self._classify_subtree(
244
+ sub_agent=sub_agent,
245
+ text=text,
246
+ nodes=list(node.children),
247
+ depth=depth + 1,
248
+ parent_path=[*parent_path, node.label],
249
+ context=context,
250
+ max_depth=max_depth,
251
+ confidence_threshold=confidence_threshold,
252
+ single_class=single_class,
253
+ state=sub_state,
254
+ ),
255
+ base_final_nodes_len,
256
+ )
257
+ )
258
+ else:
259
+ state.saw_no_children = True
260
+ state.final_nodes.append(node)
261
+ state.best_confidence = _max_confidence(
262
+ state.best_confidence, step.confidence
263
+ )
264
+ if child_tasks:
265
+ child_states = await asyncio.gather(
266
+ *(child_task for child_task, _ in child_tasks)
124
267
  )
125
- path.append(step)
126
- stop_reason = step.stop_reason
268
+ for child_state, (_, base_final_nodes_len) in zip(
269
+ child_states, child_tasks, strict=True
270
+ ):
271
+ state.path.extend(child_state.path[base_path_len:])
272
+ state.path_nodes.extend(child_state.path_nodes[base_path_nodes_len:])
273
+ state.final_nodes.extend(child_state.final_nodes[base_final_nodes_len:])
274
+ state.best_confidence = _max_confidence(
275
+ state.best_confidence, child_state.best_confidence
276
+ )
277
+ state.saw_max_depth = state.saw_max_depth or child_state.saw_max_depth
278
+ state.saw_no_children = (
279
+ state.saw_no_children or child_state.saw_no_children
280
+ )
281
+ state.saw_terminal_stop = (
282
+ state.saw_terminal_stop or child_state.saw_terminal_stop
283
+ )
284
+
285
+ @property
286
+ def taxonomy(self) -> TaxonomyNode | Sequence[TaxonomyNode]:
287
+ """Return the root taxonomy node(s).
127
288
 
128
- if step.stop_reason.is_terminal:
129
- break
289
+ Returns
290
+ -------
291
+ TaxonomyNode or Sequence[TaxonomyNode]
292
+ Root taxonomy node or list of root nodes.
293
+ """
294
+ return self._taxonomy
130
295
 
131
- selected_node = _resolve_node(current_nodes, step)
132
- if selected_node is None:
133
- stop_reason = ClassificationStopReason.NO_MATCH
134
- break
135
- if not selected_node.children:
136
- stop_reason = ClassificationStopReason.NO_CHILDREN
137
- break
296
+ @property
297
+ def root_nodes(self) -> list[TaxonomyNode]:
298
+ """Return the list of root taxonomy nodes.
138
299
 
139
- current_nodes = list(selected_node.children)
140
- depth += 1
300
+ Returns
301
+ -------
302
+ list[TaxonomyNode]
303
+ List of root taxonomy nodes.
304
+ """
305
+ return self._root_nodes
141
306
 
142
- final_id, final_label, confidence = _final_values(path)
143
- return ClassificationResult(
144
- final_id=final_id,
145
- final_label=final_label,
146
- confidence=confidence,
147
- stop_reason=stop_reason,
148
- path=path,
307
+ def _build_sub_agent(
308
+ self,
309
+ nodes: Sequence[TaxonomyNode],
310
+ ) -> "TaxonomyClassifierAgent":
311
+ """Build a classifier agent for a taxonomy subtree.
312
+
313
+ Parameters
314
+ ----------
315
+ nodes : Sequence[TaxonomyNode]
316
+ Taxonomy nodes to use as the sub-agent's root taxonomy.
317
+
318
+ Returns
319
+ -------
320
+ TaxonomyClassifierAgent
321
+ Configured classifier agent for the taxonomy slice.
322
+ """
323
+ sub_agent = TaxonomyClassifierAgent(
324
+ template_path=self._template_path,
325
+ model=self._model,
326
+ taxonomy=list(nodes),
149
327
  )
328
+ sub_agent.run_async = self.run_async
329
+ return sub_agent
330
+
331
+ async def _classify_subtree(
332
+ self,
333
+ *,
334
+ sub_agent: "TaxonomyClassifierAgent",
335
+ text: str,
336
+ nodes: list[TaxonomyNode],
337
+ depth: int,
338
+ parent_path: list[str],
339
+ context: Optional[Dict[str, Any]],
340
+ max_depth: Optional[int],
341
+ confidence_threshold: float | None,
342
+ single_class: bool,
343
+ state: "_TraversalState",
344
+ ) -> "_TraversalState":
345
+ """Classify a taxonomy subtree and return the traversal state.
346
+
347
+ Parameters
348
+ ----------
349
+ sub_agent : TaxonomyClassifierAgent
350
+ Sub-agent configured for the subtree traversal.
351
+ text : str
352
+ Source text to classify.
353
+ nodes : list[TaxonomyNode]
354
+ Candidate taxonomy nodes for the subtree.
355
+ depth : int
356
+ Current traversal depth.
357
+ parent_path : list[str]
358
+ Path segments leading to the current subtree.
359
+ context : dict or None
360
+ Additional context values to merge into the prompt.
361
+ max_depth : int or None
362
+ Maximum traversal depth before stopping.
363
+ confidence_threshold : float or None
364
+ Minimum confidence required to accept a classification step.
365
+ single_class : bool
366
+ Whether to keep only the highest-priority selection per step.
367
+ state : _TraversalState
368
+ Traversal state to populate for the subtree.
369
+
370
+ Returns
371
+ -------
372
+ _TraversalState
373
+ Populated traversal state for the subtree.
374
+ """
375
+ await sub_agent._classify_nodes(
376
+ text=text,
377
+ nodes=nodes,
378
+ depth=depth,
379
+ parent_path=parent_path,
380
+ context=context,
381
+ max_depth=max_depth,
382
+ confidence_threshold=confidence_threshold,
383
+ single_class=single_class,
384
+ state=state,
385
+ )
386
+ return state
387
+
388
+
389
@dataclass
class _TraversalState:
    """Track recursive traversal state.

    One instance accumulates the outcome of a traversal; copies of it are
    handed to child subtrees and merged back by the caller.
    """

    # Classification steps recorded so far, in the order they were appended.
    path: list[ClassificationStep] = field(default_factory=list)
    # Taxonomy nodes resolved from the recorded steps.
    path_nodes: list[TaxonomyNode] = field(default_factory=list)
    # Nodes accepted as final results of the traversal.
    final_nodes: list[TaxonomyNode] = field(default_factory=list)
    # Highest step confidence associated with a final node, if any was seen.
    best_confidence: float | None = None
    # Set when a level was skipped because the depth limit was reached.
    saw_max_depth: bool = False
    # Set when a selected node had no children to descend into.
    saw_no_children: bool = False
    # Set when a step reported a terminal stop reason.
    saw_terminal_stop: bool = False
400
+
401
+
402
def _copy_traversal_state(state: _TraversalState) -> _TraversalState:
    """Clone a traversal state so a subtree can mutate it independently.

    Parameters
    ----------
    state : _TraversalState
        Traversal state to clone.

    Returns
    -------
    _TraversalState
        New state with the same scalar values and shallow copies of the
        ``path``, ``path_nodes`` and ``final_nodes`` lists (the list elements
        themselves are shared).
    """
    clone = _TraversalState(
        best_confidence=state.best_confidence,
        saw_max_depth=state.saw_max_depth,
        saw_no_children=state.saw_no_children,
        saw_terminal_stop=state.saw_terminal_stop,
    )
    # Fresh list objects keep appends in the clone from leaking back.
    clone.path = list(state.path)
    clone.path_nodes = list(state.path_nodes)
    clone.final_nodes = list(state.final_nodes)
    return clone
424
+
425
+
426
def _resolve_stop_reason(state: _TraversalState) -> ClassificationStopReason:
    """Derive the overall stop reason from the flags on a traversal state.

    Parameters
    ----------
    state : _TraversalState
        Traversal state to inspect.

    Returns
    -------
    ClassificationStopReason
        ``STOP`` when a terminal stop was seen; otherwise ``NO_CHILDREN`` or
        ``STOP`` when final nodes exist (depending on ``saw_no_children``);
        otherwise ``MAX_DEPTH``, ``NO_CHILDREN`` or ``NO_MATCH`` in that order
        of precedence.
    """
    if state.saw_terminal_stop:
        return ClassificationStopReason.STOP

    if state.final_nodes:
        if state.saw_no_children:
            return ClassificationStopReason.NO_CHILDREN
        return ClassificationStopReason.STOP

    # No final nodes were collected: report why the traversal ended empty.
    if state.saw_max_depth:
        return ClassificationStopReason.MAX_DEPTH
    if state.saw_no_children:
        return ClassificationStopReason.NO_CHILDREN
    return ClassificationStopReason.NO_MATCH
150
450
 
151
451
 
152
452
  def _normalize_roots(
@@ -156,7 +456,7 @@ def _normalize_roots(
156
456
 
157
457
  Parameters
158
458
  ----------
159
- taxonomy : TaxonomyNode or Sequence[TaxonomyNode]
459
+ taxonomy : TaxonomyNode | Sequence[TaxonomyNode]
160
460
  Root taxonomy node or list of root nodes.
161
461
 
162
462
  Returns
@@ -182,7 +482,7 @@ def _default_template_path() -> Path:
182
482
 
183
483
  def _build_context(
184
484
  *,
185
- current_nodes: Iterable[TaxonomyNode],
485
+ node_descriptors: Iterable[dict[str, Any]],
186
486
  path: Sequence[ClassificationStep],
187
487
  depth: int,
188
488
  context: Optional[Dict[str, Any]],
@@ -191,8 +491,8 @@ def _build_context(
191
491
 
192
492
  Parameters
193
493
  ----------
194
- current_nodes : Iterable[TaxonomyNode]
195
- Nodes available at the current taxonomy level.
494
+ node_descriptors : Iterable[dict[str, Any]]
495
+ Node descriptors available at the current taxonomy level.
196
496
  path : Sequence[ClassificationStep]
197
497
  Steps recorded so far in the traversal.
198
498
  depth : int
@@ -206,7 +506,7 @@ def _build_context(
206
506
  Context dictionary for prompt rendering.
207
507
  """
208
508
  template_context: Dict[str, Any] = {
209
- "taxonomy_nodes": list(current_nodes),
509
+ "taxonomy_nodes": list(node_descriptors),
210
510
  "path": [step.as_summary() for step in path],
211
511
  "depth": depth,
212
512
  }
@@ -215,54 +515,334 @@ def _build_context(
215
515
  return template_context
216
516
 
217
517
 
218
- def _resolve_node(
518
def _build_step_structure(
    path_identifiers: Sequence[str],
) -> type[ClassificationStep]:
    """Create a step output structure restricted to the given identifiers.

    Parameters
    ----------
    path_identifiers : Sequence[str]
        Path identifiers for nodes at the current classification step.

    Returns
    -------
    type[ClassificationStep]
        Structure class produced by ``ClassificationStep.build_for_enum`` for
        an enum named ``"TaxonomyPath"`` built from the identifiers.
    """
    return ClassificationStep.build_for_enum(
        _build_taxonomy_enum("TaxonomyPath", path_identifiers)
    )
535
+
536
+
537
def _build_node_path_map(
    nodes: Sequence[TaxonomyNode],
    parent_path: Sequence[str],
) -> dict[str, TaxonomyNode]:
    """Build a mapping of node path identifiers to taxonomy nodes.

    Each identifier is the formatted path ``[*parent_path, node.label]``.
    When several nodes share a formatted path, later ones get a `` (N)``
    suffix. Every node always receives its own key: if a suffixed identifier
    is already taken (for example by a node whose label literally ends in
    ``" (2)"``), the counter keeps advancing until a free identifier is found.

    Parameters
    ----------
    nodes : Sequence[TaxonomyNode]
        Candidate nodes at the current taxonomy level.
    parent_path : Sequence[str]
        Path segments leading to the current taxonomy level.

    Returns
    -------
    dict[str, TaxonomyNode]
        Mapping of unique path identifiers to taxonomy nodes, in input order.
    """
    path_map: dict[str, TaxonomyNode] = {}
    seen: dict[str, int] = {}
    for node in nodes:
        base_path = _format_path_identifier([*parent_path, node.label])
        count = seen.get(base_path, 0) + 1
        path = f"{base_path} ({count})" if count > 1 else base_path
        # Never overwrite an existing entry: a generated "X (2)" may clash
        # with a node that is genuinely labelled "X (2)".
        while path in path_map:
            count += 1
            path = f"{base_path} ({count})"
        seen[base_path] = count
        path_map[path] = node
    return path_map
564
+
565
+
566
def _build_node_descriptors(
    node_paths: dict[str, TaxonomyNode],
) -> list[dict[str, Any]]:
    """Turn the identifier-to-node mapping into prompt-ready dictionaries.

    Parameters
    ----------
    node_paths : dict[str, TaxonomyNode]
        Mapping of path identifiers to taxonomy nodes.

    Returns
    -------
    list[dict[str, Any]]
        One dictionary per mapping entry, in mapping order, with the keys
        ``"identifier"``, ``"label"`` and ``"description"``.
    """
    return [
        {
            "identifier": identifier,
            "label": taxonomy_node.label,
            "description": taxonomy_node.description,
        }
        for identifier, taxonomy_node in node_paths.items()
    ]
591
+
592
+
593
def _format_path_identifier(path_segments: Sequence[str]) -> str:
    """Join path segments into a single identifier string.

    Any occurrence of the delimiter inside a segment is replaced by an escape
    token before joining, so the delimiter only appears between segments.

    Parameters
    ----------
    path_segments : Sequence[str]
        Path segments to format.

    Returns
    -------
    str
        Segments joined by ``" > "``; an empty sequence yields ``""``.
    """
    separator = " > "
    return separator.join(
        part.replace(separator, "\\>") for part in path_segments
    )
612
+
613
+
614
def _build_taxonomy_enum(name: str, values: Sequence[str]) -> type[Enum]:
    """Create an Enum class whose member values are the given strings.

    Parameters
    ----------
    name : str
        Name to use for the enum class.
    values : Sequence[str]
        Taxonomy node values to include as enum members.

    Returns
    -------
    type[Enum]
        Enum class with sanitized member names. When ``values`` is empty the
        enum has a single ``UNSPECIFIED`` member whose value is ``""``.
    """
    named_values: dict[str, str] = {}
    for position, raw_value in enumerate(values, start=1):
        # The names chosen so far are passed in so collisions can be avoided.
        safe_name = _sanitize_enum_member(raw_value, position, named_values)
        named_values[safe_name] = raw_value

    if len(named_values) == 0:
        named_values["UNSPECIFIED"] = ""

    enum_cls = Enum(name, named_values)
    return cast(type[Enum], enum_cls)
636
+
637
+
638
def _split_taxonomy_path(value: str) -> list[str]:
    """Break a taxonomy identifier back into its path segments.

    Parameters
    ----------
    value : str
        Taxonomy path identifier to split.

    Returns
    -------
    list[str]
        Segments obtained by splitting on ``" > "``, with each escape token
        inside a segment turned back into the delimiter.
    """
    separator = " > "
    restored: list[str] = []
    for raw_segment in value.split(separator):
        restored.append(raw_segment.replace("\\>", separator))
    return restored
655
+
656
+
657
def _sanitize_enum_member(
    value: str,
    index: int,
    existing: dict[str, str],
) -> str:
    """Derive a unique, identifier-safe enum member name from a taxonomy value.

    Parameters
    ----------
    value : str
        Raw taxonomy value to sanitize.
    index : int
        Index of the value in the source list.
    existing : dict[str, str]
        Existing enum members to avoid collisions.

    Returns
    -------
    str
        Upper-case name built from the value's path segments joined by
        ``"__"``, with a numeric ``__N`` suffix when the name is already a key
        of ``existing``.
    """
    cleaned_parts: list[str] = []
    for part in _split_taxonomy_path(value):
        # Collapse every run of non-alphanumerics into one underscore.
        cleaned = re.sub(r"[^0-9a-zA-Z]+", "_", part).strip("_").upper() or "VALUE"
        if cleaned[0].isdigit():
            cleaned = f"VALUE_{cleaned}"
        cleaned_parts.append(cleaned)

    stem = "__".join(cleaned_parts) or f"VALUE_{index}"

    unique_name = stem
    attempt = 0
    while unique_name in existing:
        attempt += 1
        unique_name = f"{stem}__{attempt}"
    return unique_name
693
+
694
+
695
def _normalize_step_output(
    step: StructureBase,
    step_structure: type[StructureBase],
) -> ClassificationStep:
    """Coerce a raw step output into a plain ``ClassificationStep``.

    Parameters
    ----------
    step : StructureBase
        Raw step output returned by the agent.
    step_structure : type[StructureBase]
        Structure definition used to parse the agent output. It is accepted
        for interface compatibility and is not read here.

    Returns
    -------
    ClassificationStep
        ``step`` itself when it already is a ``ClassificationStep``; otherwise
        a new instance rebuilt from ``step.to_json()``.
    """
    return (
        step
        if isinstance(step, ClassificationStep)
        else ClassificationStep.from_json(step.to_json())
    )
717
+
718
+
719
def _extract_enum_fields(
    step_structure: type[StructureBase],
) -> dict[str, type[Enum]]:
    """Collect the enum-typed fields declared on a step structure.

    Parameters
    ----------
    step_structure : type[StructureBase]
        Structure definition to inspect.

    Returns
    -------
    dict[str, type[Enum]]
        Field name to enum class, for every entry of ``model_fields`` whose
        annotation ``_extract_enum_class`` resolves to a non-``None`` value.
    """
    return {
        field_name: enum_cls
        for field_name, model_field in step_structure.model_fields.items()
        if (enum_cls := step_structure._extract_enum_class(model_field.annotation))
        is not None
    }
740
+
741
+
742
def _normalize_enum_value(value: Any, enum_cls: type[Enum]) -> Any:
    """Reduce enum members (or their names) to raw primitive values.

    Parameters
    ----------
    value : Any
        Value to normalize.
    enum_cls : type[Enum]
        Enum type used for normalization.

    Returns
    -------
    Any
        The member's ``value`` for enum members; a list normalized element by
        element for lists; for strings, the matching member's value when the
        string is a member value or member name of ``enum_cls``; otherwise the
        input unchanged.
    """
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, list):
        return [_normalize_enum_value(element, enum_cls) for element in value]
    if not isinstance(value, str):
        return value

    # Strings are matched against member values first, then member names.
    if value in enum_cls._value2member_map_:
        return enum_cls(value).value
    named_member = enum_cls.__members__.get(value)
    return value if named_member is None else named_member.value
767
+
768
+
769
def _resolve_nodes(
    node_paths: dict[str, TaxonomyNode],
    step: ClassificationStep,
) -> list[TaxonomyNode]:
    """Resolve selected taxonomy nodes for a classification step.

    Identifiers selected by the step that are not keys of ``node_paths`` are
    ignored.

    Parameters
    ----------
    node_paths : dict[str, TaxonomyNode]
        Mapping of path identifiers to nodes at the current level.
    step : ClassificationStep
        Classification step output to resolve.

    Returns
    -------
    list[TaxonomyNode]
        Matching taxonomy nodes in priority order.
    """
    resolved: list[TaxonomyNode] = []
    for selected_identifier in _selected_nodes(step):
        node = node_paths.get(selected_identifier)
        # Compare against None explicitly: a mapped node must not be dropped
        # just because the object happens to evaluate as falsy.
        if node is not None:
            resolved.append(node)
    return resolved
795
+
796
+
797
def _selected_nodes(step: ClassificationStep) -> list[str]:
    """Extract the identifiers a classification step selected, as strings.

    The multi-selection field ``selected_nodes`` takes precedence; the
    single-selection field ``selected_node`` is used only when the former
    yields no identifiers.

    Parameters
    ----------
    step : ClassificationStep
        Classification output to normalize.

    Returns
    -------
    list[str]
        Selected identifiers in priority order; empty when nothing is selected.
    """
    if step.selected_nodes is not None:
        # Falsy entries are skipped before normalization.
        identifiers = [
            str(_normalize_enum_value(choice, Enum))
            for choice in step.selected_nodes
            if choice
        ]
        if identifiers:
            return identifiers

    if not step.selected_node:
        return []
    return [str(_normalize_enum_value(step.selected_node, Enum))]
821
+
822
+
823
def _max_confidence(
    current: float | None,
    candidate: float | None,
) -> float | None:
    """Pick the larger of two optional confidence values.

    Parameters
    ----------
    current : float or None
        Current best confidence value.
    candidate : float or None
        Candidate confidence value to compare.

    Returns
    -------
    float or None
        The larger value when both are given, the only given value when one
        is ``None``, and ``None`` when both are ``None``.
    """
    known = [score for score in (current, candidate) if score is not None]
    return max(known) if known else None
266
846
 
267
847
 
268
848
  __all__ = ["TaxonomyClassifierAgent"]