auditi 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
auditi/decorators.py ADDED
@@ -0,0 +1,1441 @@
1
+ """
2
+ Tracing decorators for instrumenting AI agents, tools, and LLM calls.
3
+
4
+ This version uses the provider abstraction layer for robust multi-provider support.
5
+ """
6
+
7
+ import functools
8
+ import json
9
+ import os
10
+ from datetime import datetime
11
+ from uuid import uuid4
12
+ import asyncio
13
+ import inspect
14
+ from typing import Optional, Callable, Any
15
+
16
+ from .types import TraceInput, SpanInput
17
+ from .context import (
18
+ get_current_trace,
19
+ set_current_trace,
20
+ get_current_span,
21
+ push_span,
22
+ pop_span,
23
+ get_context,
24
+ )
25
+ from .client import get_client
26
+ from .evaluator import BaseEvaluator
27
+ from .providers import detect_provider
28
+ from .events import EventType, INTERNAL_EVENTS, CONTENT_EVENTS
29
+
30
+ # Debug flag - set via environment variable
31
+ DEBUG = os.getenv("AUDITI_DEBUG", "false").lower() in ("true", "1", "yes")
32
+
33
+
34
+ def _debug_log(message: str, data: Any = None) -> None:
35
+ """Helper function to conditionally log debug information."""
36
+ if DEBUG:
37
+ print(f"[Auditi Debug] {message}")
38
+ if data is not None:
39
+ try:
40
+ print(json.dumps(data, indent=2, default=str))
41
+ except Exception as e:
42
+ print(f"[Auditi Debug] Could not serialize data: {e}")
43
+ print(f"Raw data: {data}")
44
+
45
+
46
def _apply_usage_to_span(span: SpanInput, usage: Any, response: Any = None) -> None:
    """
    Apply usage metrics to a span using provider abstraction.

    Args:
        span: The span to update
        usage: Raw usage object/dict from API response
        response: Optional full response object for provider detection
    """
    # Resolve the provider from the span's model and/or the raw response.
    prov = detect_provider(model=span.model, response=response)

    _debug_log(
        f"Detected provider '{prov.name}' for span '{span.name}'",
        {"model": span.model, "provider": prov.name},
    )

    # Provider-specific parsing of the raw usage payload.
    tokens_in, tokens_out, tokens_total = prov.extract_usage(usage)

    _debug_log(
        f"Extracted usage for span '{span.name}':",
        {
            "input_tokens": tokens_in,
            "output_tokens": tokens_out,
            "total_tokens": tokens_total,
            "provider": prov.name,
        },
    )

    # Copy whichever metrics were actually reported onto the span.
    for attr, value in (
        ("input_tokens", tokens_in),
        ("output_tokens", tokens_out),
        ("tokens", tokens_total),
    ):
        if value is not None:
            setattr(span, attr, value)

    # Cost comes from provider-specific pricing for this model.
    span.cost = prov.calculate_cost(span.model, tokens_in, tokens_out)

    _debug_log(
        f"Calculated cost for span '{span.name}':",
        {"cost": span.cost, "model": span.model, "provider": prov.name},
    )
91
+
92
+
93
def _apply_usage_to_trace(trace: TraceInput, usage: Any, model: Optional[str] = None) -> None:
    """
    Apply usage metrics to a trace using provider abstraction.

    Tokens and cost are accumulated (not overwritten) so that multiple
    LLM calls within a single trace add up.

    Args:
        trace: The trace to update
        usage: Raw usage object/dict from API response
        model: Optional model name for provider detection
    """
    # Detect provider
    provider = detect_provider(model=model)

    # Extract usage
    input_tokens, output_tokens, total_tokens = provider.extract_usage(usage)

    _debug_log(
        f"Applying usage to trace '{trace.name}':",
        {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
            "current_trace_tokens": trace.total_tokens,
            "provider": provider.name,
        },
    )

    if total_tokens is None:
        # FIX: some providers report only input/output tokens without a total.
        # Previously such usage was dropped entirely (tokens AND cost); derive
        # the total from the parts instead, and bail only when nothing was reported.
        if input_tokens is None and output_tokens is None:
            return
        total_tokens = (input_tokens or 0) + (output_tokens or 0)

    # Accumulate tokens
    trace.total_tokens = (trace.total_tokens or 0) + total_tokens

    # Calculate incremental cost
    incremental_cost = provider.calculate_cost(model, input_tokens, output_tokens)
    trace.cost = (trace.cost or 0.0) + incremental_cost
128
+
129
+
130
def _execute_as_standalone_trace(
    func: Callable,
    args: tuple,
    kwargs: dict,
    span_type: str,
    name: Optional[str],
    model: Optional[str],
) -> Any:
    """
    Execute a function as a standalone trace (no parent trace required).

    This creates a simple trace with the function result as the only content.
    Used when @trace_llm or @trace_tool is called with standalone=True
    outside of a @trace_agent context.

    Args:
        func: The wrapped callable to execute.
        args: Positional arguments forwarded to ``func``.
        kwargs: Keyword arguments forwarded to ``func``.
        span_type: Type label for the single span ("llm", "tool", ...).
        name: Optional trace/span name; defaults to ``func.__name__``.
        model: Optional model name placed on the span up front.

    Returns:
        Whatever ``func`` returns (re-raising any exception it raised).
    """
    client = get_client()
    trace_id = uuid4()
    span_id = uuid4()
    start_time = datetime.utcnow()
    func_name = name or func.__name__

    # Extract user input from args/kwargs: first positional wins, with
    # type-specific handling; falls through to common kwarg names below.
    user_input = ""
    if args:
        first_arg = args[0]
        if isinstance(first_arg, str):
            user_input = first_arg
        elif isinstance(first_arg, list):
            # Could be messages array
            user_input = str(first_arg)
        elif isinstance(first_arg, dict):
            user_input = first_arg.get("content") or first_arg.get("message") or str(first_arg)
        else:
            user_input = str(first_arg)

    if not user_input:
        user_input = kwargs.get("prompt") or kwargs.get("message") or kwargs.get("query") or ""

    # Create trace (tagged so the backend can distinguish standalone captures)
    trace = TraceInput(
        id=trace_id,
        start_time=start_time,
        user_input=user_input,
        name=func_name,
        tags=["standalone", span_type],
    )
    set_current_trace(trace)

    # Create span — the single child of this standalone trace.
    span = SpanInput(
        id=span_id,
        trace_id=trace_id,
        name=func_name,
        span_type=span_type,
        start_time=start_time,
        inputs={"prompt": user_input} if user_input else {},
        model=model,
    )
    push_span(span)

    result = None
    error_msg = None

    try:
        result = func(*args, **kwargs)
        print(f"[Auditi] Standalone {span_type} trace captured.")

        # Extract model from response if not set explicitly by the caller.
        if not span.model:
            provider = detect_provider(response=result)
            extracted_model = provider.extract_model(result)
            if extracted_model:
                span.model = extracted_model

        # Extract output and usage based on response type; branches are ordered
        # from most specific (SDK response objects) to most generic (str()).
        if hasattr(result, "choices") and result.choices:
            # OpenAI-style response
            try:
                choice = result.choices[0]
                if hasattr(choice, "message") and hasattr(choice.message, "content"):
                    output = str(choice.message.content)
                elif hasattr(choice, "text"):
                    output = str(choice.text)
                else:
                    output = str(result)
            except (IndexError, AttributeError):
                # Defensive: malformed/empty choices — fall back to raw repr.
                output = str(result)

            trace.assistant_output = output
            span.outputs = output

            if hasattr(result, "usage") and result.usage:
                _apply_usage_to_span(span, result.usage, response=result)

        elif hasattr(result, "data"):
            # Embedding response (OpenAI embeddings have .data)
            if hasattr(result, "usage"):
                _apply_usage_to_span(span, result.usage, response=result)
            span.outputs = f"Generated {len(result.data)} embeddings"
            trace.assistant_output = span.outputs

        elif isinstance(result, str):
            trace.assistant_output = result
            span.outputs = result

        elif isinstance(result, dict):
            content = (
                result.get("content") or result.get("text") or result.get("response") or str(result)
            )
            trace.assistant_output = str(content)
            span.outputs = str(content)
            if "usage" in result:
                _apply_usage_to_span(span, result["usage"], response=result)

        elif hasattr(result, "content"):
            # Anthropic-style / generic object exposing .content — TODO confirm
            trace.assistant_output = str(result.content)
            span.outputs = str(result.content)
            if hasattr(result, "usage"):
                _apply_usage_to_span(span, result.usage, response=result)

        else:
            trace.assistant_output = str(result) if result else ""
            span.outputs = str(result) if result else ""

        span.status = "ok"

    except Exception as e:
        # Record the failure on both trace and span, then re-raise so the
        # caller still sees the original exception.
        error_msg = str(e)
        trace.assistant_output = f"Error: {e}"
        trace.error = error_msg
        span.error = error_msg
        span.status = "error"
        raise

    finally:
        # Runs on both success and failure: close out timing, detach the span,
        # copy its metrics up to the trace, and ship the payload.
        end_time = datetime.utcnow()
        span.end_time = end_time
        trace.end_time = end_time

        pop_span()
        trace.spans.append(span)

        # Aggregate metrics from span to trace (single-span trace, so copy).
        if span.tokens:
            trace.total_tokens = span.tokens
        if span.cost:
            trace.cost = span.cost

        # Send trace
        trace_payload = trace.model_dump(mode="json")
        _debug_log(f"Sending standalone trace payload for '{func_name}':", trace_payload)
        client.transport.send_trace(trace_payload)

        # Clear context — local import avoids a potential circular import at
        # module load time; presumably mirrors the package layout — TODO confirm.
        from .context import clear_current_trace

        clear_current_trace()

    return result
289
+
290
+
291
def _extract_agent_input(args: tuple, kwargs: dict, capture_input: bool) -> str:
    """Pull the user's input string out of call arguments, skipping a likely self/cls."""
    user_input = ""
    if capture_input and args:
        # Determine starting index (skip 'self' or 'cls'): a first argument
        # that is not a basic builtin type is assumed to be a bound instance.
        start_idx = 0
        first_arg = args[0]
        if not isinstance(first_arg, (str, int, float, bool, list, dict, tuple, type(None))):
            start_idx = 1
            _debug_log(f"Skipping first argument (detected as self/cls): {type(first_arg)}")

        # Get the actual user input from the correct position.
        if len(args) > start_idx:
            candidate = args[start_idx]
            if isinstance(candidate, str):
                user_input = candidate
            elif isinstance(candidate, dict) and "message" in candidate:
                user_input = candidate["message"]
            elif isinstance(candidate, dict) and "content" in candidate:
                user_input = candidate["content"]
            else:
                user_input = str(candidate)

    # Fall back to well-known keyword argument names.
    if not user_input:
        user_input = (
            kwargs.get("user_input")
            or kwargs.get("message")
            or kwargs.get("query")
            or kwargs.get("prompt")
            or ""
        )
    return user_input


def _begin_agent_trace(
    trace_name: str,
    user_id: Optional[str],
    capture_input: bool,
    args: tuple,
    kwargs: dict,
    log_label: str = "trace",
) -> TraceInput:
    """Create a TraceInput for an agent invocation and register it as the current trace."""
    trace_id = uuid4()
    start_time = datetime.utcnow()

    # Session/user context: call kwargs win, then the global context.
    ctx = get_context()
    session_id = (
        kwargs.get("session_id")
        or kwargs.get("conversation_id")
        or ctx.get("session_id")
    )
    resolved_user_id = kwargs.get("user_id") or user_id or ctx.get("user_id")

    _debug_log(
        f"Starting {log_label} '{trace_name}':",
        {
            "trace_id": str(trace_id),
            "user_id": resolved_user_id,
            "session_id": session_id,
            "args": [str(arg)[:100] for arg in args],
            "kwargs_keys": list(kwargs.keys()),
        },
    )

    user_input = _extract_agent_input(args, kwargs, capture_input)
    _debug_log(f"Captured user input:", {"user_input": user_input[:200]})

    trace = TraceInput(
        id=trace_id,
        user_id=resolved_user_id,
        conversation_id=session_id,
        start_time=start_time,
        user_input=user_input,
        name=trace_name,
        tags=kwargs.get("tags", []),
    )
    set_current_trace(trace)
    return trace


def _capture_agent_output(trace: TraceInput, result: Any) -> None:
    """Populate trace.assistant_output (and usage metrics when present) from a result."""
    if isinstance(result, str):
        trace.assistant_output = result
    elif isinstance(result, dict):
        trace.assistant_output = (
            result.get("content")
            or result.get("message")
            or result.get("response")
            or str(result)
        )
        # Extract model for provider detection.
        model = result.get("model")
        if "usage" in result:
            _apply_usage_to_trace(trace, result["usage"], model=model)
    elif hasattr(result, "content"):
        trace.assistant_output = str(result.content)
        model = getattr(result, "model", None)
        if model:
            model = str(model)
        if hasattr(result, "usage"):
            try:
                _apply_usage_to_trace(trace, result.usage, model=model)
            except Exception as e:
                _debug_log(f"Failed to extract usage from result object: {e}")
    else:
        trace.assistant_output = str(result) if result else ""
        model = getattr(result, "model", None)
        if model:
            model = str(model)
        if hasattr(result, "usage"):
            try:
                _apply_usage_to_trace(trace, result.usage, model=model)
            except Exception as e:
                _debug_log(f"Failed to extract usage: {e}")

    _debug_log(
        f"Captured assistant output:", {"output": str(trace.assistant_output)[:200]}
    )


def _finalize_agent_trace(
    trace: TraceInput,
    trace_name: str,
    evaluator: Optional[BaseEvaluator],
    error_msg: Optional[str],
    client: Any,
) -> None:
    """Stamp end time, aggregate span metrics, run the evaluator, and send the trace."""
    trace.end_time = datetime.utcnow()

    # Aggregate metrics from spans if not set on trace.
    if (trace.total_tokens is None or trace.total_tokens == 0) and trace.spans:
        calculated_tokens = 0
        calculated_cost = 0.0
        for s in trace.spans:
            if s.tokens:
                calculated_tokens += s.tokens
            if s.cost:
                calculated_cost += s.cost

        if calculated_tokens > 0:
            trace.total_tokens = calculated_tokens
            if trace.cost is None or trace.cost == 0.0:
                trace.cost = calculated_cost

    # Run evaluator only on successful runs.
    if evaluator and not error_msg:
        try:
            eval_result = evaluator.evaluate(trace)
            # Map evaluation result to TraceInput fields.
            trace.status = eval_result.status
            trace.score = eval_result.score
            if eval_result.reason:
                trace.eval_reason = eval_result.reason
            if eval_result.failure_mode:
                trace.failure_mode = eval_result.failure_mode

            _debug_log("Evaluation result:", eval_result)
        except Exception as e:
            print(f"[Auditi] Evaluator failed: {e}")

    # Prepare, log, and send the payload.
    trace_payload = trace.model_dump(mode="json")
    _debug_log(f"Sending trace payload for '{trace_name}':", trace_payload)
    client.transport.send_trace(trace_payload)


def trace_agent(
    name: Optional[str] = None,
    user_id: Optional[str] = None,
    evaluator: Optional[BaseEvaluator] = None,
    capture_input: bool = True,
) -> Callable:
    """
    Decorator to trace an entire agent interaction.

    Supports sync functions, async functions, and async generators, and
    handles instance methods (a leading self/cls is skipped when extracting
    user input). The previously triplicated begin/capture/finalize logic is
    factored into the private module helpers above.

    Args:
        name: Optional trace name; defaults to the wrapped function's name.
        user_id: Optional user id; call kwargs and global context take precedence.
        evaluator: Optional evaluator run on the finished trace (skipped on error).
        capture_input: When True, capture the user input from the arguments.

    Returns:
        A decorator producing a wrapper that records and sends a trace.
    """

    def decorator(func: Callable) -> Callable:
        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs) -> Any:
                client = get_client()
                trace_name = name or func.__name__
                trace = _begin_agent_trace(trace_name, user_id, capture_input, args, kwargs)

                result = None
                error_msg = None
                try:
                    result = await func(*args, **kwargs)
                    print("[Auditi] Trace captured.")
                    if DEBUG:
                        print(f"result type: {type(result)}")
                        print(f"result: {str(result)[:200]}")
                    _capture_agent_output(trace, result)
                except Exception as e:
                    error_msg = str(e)
                    trace.assistant_output = f"Error: {e}"
                    trace.error = error_msg
                    _debug_log(f"Error in trace '{trace_name}':", {"error": error_msg})
                    raise
                finally:
                    # Always ship the trace, even when the agent raised.
                    _finalize_agent_trace(trace, trace_name, evaluator, error_msg, client)

                return result

            return async_wrapper

        elif inspect.isasyncgenfunction(func):

            @functools.wraps(func)
            async def async_gen_wrapper(*args, **kwargs) -> Any:
                client = get_client()
                trace_name = name or func.__name__
                trace = _begin_agent_trace(
                    trace_name,
                    user_id,
                    capture_input,
                    args,
                    kwargs,
                    log_label="async generator trace",
                )

                error_msg = None
                # Accumulators for the streamed response.
                accumulated_content = []
                final_content = None
                accumulated_model = None
                accumulated_usage = None

                try:
                    gen = func(*args, **kwargs)
                    async for item in gen:
                        yield item

                        # Fold streamed items into the trace output.
                        if isinstance(item, str):
                            accumulated_content.append(item)
                        elif isinstance(item, dict):
                            evt_type = item.get("type")
                            if evt_type == "token" and "content" in item:
                                accumulated_content.append(str(item["content"]))
                            elif evt_type == "complete":
                                # Complete events often hold the final answer/usage.
                                if "content" in item:
                                    final_content = str(item["content"])
                                if "usage" in item:
                                    accumulated_usage = item["usage"]
                                if "model" in item:
                                    accumulated_model = item["model"]
                        elif hasattr(item, "content"):
                            accumulated_content.append(str(item.content))

                    # Prefer an explicit final answer over the joined token stream.
                    if final_content is not None:
                        trace.assistant_output = final_content
                    else:
                        trace.assistant_output = "".join(accumulated_content)

                    if accumulated_usage:
                        try:
                            _apply_usage_to_trace(
                                trace, accumulated_usage, model=accumulated_model
                            )
                        except Exception:
                            pass

                    print("[Auditi] Async generator trace captured.")
                    _debug_log(
                        f"Captured assistant output:",
                        {"output": str(trace.assistant_output)[:200]},
                    )
                except Exception as e:
                    error_msg = str(e)
                    trace.assistant_output = f"Error: {e}"
                    trace.error = error_msg
                    _debug_log(f"Error in trace '{trace_name}':", {"error": error_msg})
                    raise
                finally:
                    _finalize_agent_trace(trace, trace_name, evaluator, error_msg, client)

            return async_gen_wrapper

        else:

            @functools.wraps(func)
            def wrapper(*args, **kwargs) -> Any:
                client = get_client()
                trace_name = name or func.__name__
                trace = _begin_agent_trace(trace_name, user_id, capture_input, args, kwargs)

                result = None
                error_msg = None
                try:
                    result = func(*args, **kwargs)
                    print("[Auditi] Trace captured.")
                    if DEBUG:
                        print(f"result: {result}")
                    _capture_agent_output(trace, result)
                except Exception as e:
                    error_msg = str(e)
                    trace.assistant_output = f"Error: {e}"
                    trace.error = error_msg
                    _debug_log(f"Error in trace '{trace_name}':", {"error": error_msg})
                    raise
                finally:
                    _finalize_agent_trace(trace, trace_name, evaluator, error_msg, client)

                return result

            return wrapper

    return decorator
896
+
897
+
898
+ def _capture_inputs(func, args, kwargs, model_hint=None):
899
+ """
900
+ Helper to capture and normalize inputs from function arguments.
901
+ Handles 'self' skipping, kwarg flattening, and fallback str conversion.
902
+ """
903
+ inputs = {}
904
+ effective_model = model_hint
905
+
906
+ try:
907
+ import inspect
908
+
909
+ sig = inspect.signature(func)
910
+ bound = sig.bind(*args, **kwargs)
911
+ bound.apply_defaults()
912
+
913
+ # Capture all arguments as inputs
914
+ for arg_name, value in bound.arguments.items():
915
+ # Skip 'self' or 'cls' typically found in methods
916
+ if arg_name in ("self", "cls"):
917
+ continue
918
+
919
+ # Flatten **kwargs if they exist and are a dict
920
+ param = sig.parameters.get(arg_name)
921
+ if param and param.kind == inspect.Parameter.VAR_KEYWORD and isinstance(value, dict):
922
+ inputs.update({k: str(v) for k, v in value.items()})
923
+ else:
924
+ inputs[arg_name] = str(value)
925
+
926
+ # Check for model in arguments
927
+ if not effective_model:
928
+ if "model" in inputs:
929
+ effective_model = inputs["model"]
930
+ elif "model_id" in inputs: # ADD THIS
931
+ effective_model = inputs["model_id"]
932
+ # Also check original bound args just in case normalization changed something
933
+ elif "model" in bound.arguments:
934
+ effective_model = str(bound.arguments["model"])
935
+ elif "model_id" in bound.arguments: # ADD THIS
936
+ effective_model = str(bound.arguments["model_id"])
937
+
938
+ except Exception:
939
+ # Binding failed, ignore
940
+ pass
941
+
942
+ # Fallback model detection
943
+ if not effective_model:
944
+ if "model" in kwargs:
945
+ effective_model = str(kwargs["model"])
946
+ elif "model_id" in kwargs: # ADD THIS
947
+ effective_model = str(kwargs["model_id"])
948
+ elif args and hasattr(args[0], "model"):
949
+ effective_model = str(args[0].model)
950
+ elif args and hasattr(args[0], "model_id"): # ADD THIS
951
+ effective_model = str(args[0].model_id)
952
+
953
+ # Fallback input detection
954
+ if not inputs:
955
+ if args:
956
+ first_arg = args[0]
957
+ # Skip likely self/cls in fallback mode too (basic heuristic check)
958
+ start_idx = 0
959
+ if not isinstance(first_arg, (str, int, float, bool, list, dict, type(None))):
960
+ start_idx = 1
961
+
962
+ if len(args) > start_idx:
963
+ val = args[start_idx]
964
+ if isinstance(val, str):
965
+ inputs["prompt"] = val
966
+ elif isinstance(val, (dict, list)):
967
+ inputs["data"] = val
968
+ else:
969
+ inputs["input"] = str(val)
970
+
971
+ if kwargs:
972
+ inputs.update({k: str(v) for k, v in kwargs.items()})
973
+
974
+ return inputs, effective_model
975
+
976
+
977
def _trace_span(
    span_type: str,
    name: Optional[str] = None,
    model: Optional[str] = None,
    standalone: bool = False,
) -> Callable:
    """
    Internal helper to create span tracing decorators.

    Builds a decorator that wraps a callable and records its execution as a
    span on the currently active trace (see ``get_current_trace``). Three
    wrapper flavors are produced depending on the wrapped callable:
    an async-generator wrapper (streams items through while accumulating
    output), a coroutine wrapper, and a plain sync wrapper.

    Args:
        span_type: Type of span ("llm", "tool", "embedding", etc.)
        name: Optional custom name for the span (defaults to ``func.__name__``)
        model: Optional model name
        standalone: If True, creates a standalone trace when no parent trace exists

    Returns:
        A decorator that instruments the wrapped function with span recording.
    """

    def decorator(func: Callable) -> Callable:
        # Check for async generator first — it needs a generator-shaped
        # wrapper (must `yield`, cannot `return` a value).
        if inspect.isasyncgenfunction(func):

            @functools.wraps(func)
            async def async_gen_span_wrapper(*args, **kwargs) -> Any:
                trace = get_current_trace()
                if not trace:
                    # No active trace: with no standalone support implemented
                    # for async generators, both branches currently just
                    # pass items through untraced.
                    if standalone:
                        # Create a standalone trace for this call
                        # Note: Simple detection for standalone async gen not fully implemented in _execute_as_standalone_trace yet
                        # For now, falling back to treating it as regular func (may need enhancement)
                        async for item in func(*args, **kwargs):
                            yield item
                    else:
                        async for item in func(*args, **kwargs):
                            yield item
                    return

                parent = get_current_span()
                span_id = uuid4()
                # NOTE(review): datetime.utcnow() yields a naive datetime and is
                # deprecated since Python 3.12 — consider datetime.now(timezone.utc).
                start_time = datetime.utcnow()
                span_name = name or func.__name__

                # Use robust input capture
                inputs, effective_model = _capture_inputs(func, args, kwargs, model)

                _debug_log(
                    f"Starting async generator span '{span_name}' (type: {span_type}):",
                    {
                        "span_id": str(span_id),
                        "parent_id": str(parent.id) if parent else None,
                        "model": effective_model,
                        "trace_id": str(trace.id),
                    },
                )

                span = SpanInput(
                    id=span_id,
                    trace_id=trace.id,
                    parent_id=parent.id if parent else None,
                    name=span_name,
                    span_type=span_type,
                    start_time=start_time,
                    inputs=inputs,
                    model=effective_model,
                )

                push_span(span)

                # Pieces of streamed output, joined into span.outputs at the end.
                accumulated_outputs = []

                try:
                    # Execute async generator
                    gen = func(*args, **kwargs)

                    async for item in gen:
                        # Forward the item to the caller first, then record it.
                        yield item
                        # Only accumulate relevant content
                        if isinstance(item, str):
                            accumulated_outputs.append(item)
                        elif isinstance(item, dict):
                            evt_type = item.get("type")

                            # Filter out internal events using standardized set
                            if evt_type in {e.value for e in INTERNAL_EVENTS}:
                                # Handle turn_metadata to extract usage before skipping
                                if evt_type == EventType.TURN_METADATA.value:
                                    if "usage" in item and item["usage"]:
                                        _apply_usage_to_span(span, item["usage"])
                                continue

                            # Accumulate content from content events
                            if evt_type == EventType.TOKEN.value and "content" in item:
                                accumulated_outputs.append(str(item["content"]))
                            elif evt_type == EventType.COMPLETE.value and "content" in item:
                                pass  # Complete event content handled elsewhere
                            elif "content" in item:
                                accumulated_outputs.append(str(item["content"]))

                        elif hasattr(item, "content"):
                            accumulated_outputs.append(str(item.content))
                        else:
                            accumulated_outputs.append(str(item))

                    # If usage was not found in stream but captured via other means (e.g. side channel), we rely on it being yields.

                    span.outputs = "".join(accumulated_outputs)
                    span.status = "ok"

                except Exception as e:
                    span.error = str(e)
                    span.status = "error"
                    _debug_log(f"Span '{span_name}' failed:", {"error": str(e)})
                    # Don't raise here if we are yielding? Actually we should raise to propagate error
                    raise
                finally:
                    # Always close out the span and attach it to the trace,
                    # even on error or early generator close.
                    span.end_time = datetime.utcnow()
                    pop_span()

                    span_payload = span.model_dump(mode="json")
                    _debug_log(f"Adding span '{span_name}' to trace:", span_payload)
                    trace.spans.append(span)

            return async_gen_span_wrapper

        # Check for regular async function
        elif asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_span_wrapper(*args, **kwargs) -> Any:
                trace = get_current_trace()
                if not trace:
                    if standalone:
                        # Standalone async trace
                        # NOTE(review): the helper's result is returned without
                        # `await` — confirm _execute_as_standalone_trace awaits
                        # (or returns) the coroutine itself, otherwise callers
                        # awaiting this wrapper may receive an unexpected value.
                        return _execute_as_standalone_trace(
                            func, args, kwargs, span_type, name, model
                        )
                    else:
                        return await func(*args, **kwargs)

                parent = get_current_span()
                span_id = uuid4()
                start_time = datetime.utcnow()
                span_name = name or func.__name__

                # Use robust input capture
                inputs, effective_model = _capture_inputs(func, args, kwargs, model)

                span = SpanInput(
                    id=span_id,
                    trace_id=trace.id,
                    parent_id=parent.id if parent else None,
                    name=span_name,
                    span_type=span_type,
                    start_time=start_time,
                    inputs=inputs,
                    model=effective_model,
                )

                push_span(span)

                try:
                    result = await func(*args, **kwargs)

                    # Capture output
                    if isinstance(result, str):
                        span.outputs = result
                    elif hasattr(result, "content"):
                        span.outputs = str(result.content)
                    else:
                        span.outputs = str(result)

                    span.status = "ok"
                    return result

                except Exception as e:
                    span.error = str(e)
                    span.status = "error"
                    raise
                finally:
                    span.end_time = datetime.utcnow()
                    pop_span()

                    span_payload = span.model_dump(mode="json")
                    _debug_log(f"Adding span '{span_name}' to trace:", span_payload)
                    trace.spans.append(span)

            return async_span_wrapper

        # Default: plain synchronous function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            trace = get_current_trace()
            if not trace:
                if standalone:
                    # Create a standalone trace for this call
                    return _execute_as_standalone_trace(func, args, kwargs, span_type, name, model)
                else:
                    # Original behavior: just run the function normally
                    return func(*args, **kwargs)

            parent = get_current_span()
            span_id = uuid4()
            start_time = datetime.utcnow()
            span_name = name or func.__name__

            # Auto-detect model and capture inputs using signature binding if available
            # NOTE(review): unlike the async wrappers, this path does not use
            # _capture_inputs — it inlines similar logic; consider unifying.
            effective_model = model
            inputs = {}

            try:
                # NOTE(review): redundant — `inspect` is already imported at
                # module level.
                import inspect

                sig = inspect.signature(func)
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()

                # Capture all arguments as inputs
                for arg_name, value in bound.arguments.items():
                    # Skip 'self' or 'cls' typically found in methods
                    if arg_name in ("self", "cls"):
                        continue
                    # Flatten **kwargs if they exist and are a dict
                    param = sig.parameters.get(arg_name)
                    if (
                        param
                        and param.kind == inspect.Parameter.VAR_KEYWORD
                        and isinstance(value, dict)
                    ):
                        inputs.update({k: str(v) for k, v in value.items()})
                    else:
                        inputs[arg_name] = str(value)

                # Check for model in arguments
                if not effective_model:
                    if "model" in kwargs:
                        effective_model = str(kwargs["model"])
                    elif "model_id" in kwargs:  # ADD THIS
                        effective_model = str(kwargs["model_id"])
                    elif args and hasattr(args[0], "model"):
                        effective_model = str(args[0].model)
                    elif args and hasattr(args[0], "model_id"):  # ADD THIS
                        effective_model = str(args[0].model_id)

            except Exception:
                # If binding fails, valid case for built-ins or certain wrappers
                pass

            # Fallback model detection (if signature extraction failed or didn't find model)
            if not effective_model:
                if "model" in kwargs:
                    effective_model = str(kwargs["model"])
                elif args and hasattr(args[0], "model"):
                    effective_model = str(args[0].model)

            # Fallback input detection (if signature extraction failed or resulted in empty inputs)
            if not inputs:
                if args:
                    first_arg = args[0]
                    if isinstance(first_arg, str):
                        inputs["prompt"] = first_arg
                    elif isinstance(first_arg, (dict, list)):
                        inputs["data"] = first_arg
                    else:
                        inputs["input"] = str(first_arg)
                if kwargs:
                    inputs.update({k: str(v) for k, v in kwargs.items()})

            _debug_log(
                f"Starting span '{span_name}' (type: {span_type}):",
                {
                    "span_id": str(span_id),
                    "parent_id": str(parent.id) if parent else None,
                    "model": effective_model,
                    "trace_id": str(trace.id),
                    "inputs_keys": list(inputs.keys()),
                },
            )

            span = SpanInput(
                id=span_id,
                trace_id=trace.id,
                parent_id=parent.id if parent else None,
                name=span_name,
                span_type=span_type,
                start_time=start_time,
                inputs=inputs,
                model=effective_model,
            )

            push_span(span)

            try:
                result = func(*args, **kwargs)

                # Use provider abstraction to extract model if not set
                if not span.model:
                    provider = detect_provider(response=result)
                    extracted_model = provider.extract_model(result)
                    if extracted_model:
                        span.model = extracted_model

                # Smart output capture for various LLM response types

                # OpenAI ChatCompletion object (has choices array)
                if hasattr(result, "choices") and result.choices:
                    try:
                        choice = result.choices[0]
                        if hasattr(choice, "message") and hasattr(choice.message, "content"):
                            span.outputs = str(choice.message.content)
                        elif hasattr(choice, "text"):  # Legacy completions API
                            span.outputs = str(choice.text)
                    except (IndexError, AttributeError):
                        span.outputs = str(result)

                    # Extract usage from response using provider abstraction
                    if hasattr(result, "usage") and result.usage:
                        _apply_usage_to_span(span, result.usage, response=result)

                # Simple string result
                elif isinstance(result, str):
                    span.outputs = result

                # Object with .content attribute (e.g., Anthropic)
                elif hasattr(result, "content"):
                    span.outputs = str(result.content)

                    # Extract usage
                    if hasattr(result, "usage"):
                        try:
                            _apply_usage_to_span(span, result.usage, response=result)
                        except Exception as e:
                            _debug_log(f"Failed to extract usage from content object: {e}")

                # Dict result (e.g., from custom LLM wrappers)
                elif isinstance(result, dict):
                    # Try to extract content from common keys
                    content = (
                        result.get("content")
                        or result.get("text")
                        or result.get("message")
                        or str(result)
                    )
                    span.outputs = str(content)

                    # Extract usage
                    if "usage" in result:
                        _apply_usage_to_span(span, result["usage"], response=result)

                # Fallback for unknown types
                else:
                    span.outputs = str(result)

                    # Try to extract usage
                    if hasattr(result, "usage"):
                        try:
                            _apply_usage_to_span(span, result.usage, response=result)
                        except Exception as e:
                            _debug_log(f"Failed to extract usage from unknown type: {e}")

                span.status = "ok"

                _debug_log(
                    f"Span '{span_name}' completed successfully:",
                    {
                        "output_length": len(str(span.outputs)) if span.outputs else 0,
                        "model": span.model,
                        "tokens": span.tokens if hasattr(span, "tokens") and span.tokens else None,
                        "cost": span.cost if hasattr(span, "cost") else None,
                    },
                )

                return result
            except Exception as e:
                span.error = str(e)
                span.status = "error"
                _debug_log(f"Span '{span_name}' failed:", {"error": str(e)})
                raise
            finally:
                # Always record the end time, pop the span stack, and attach
                # the span to the trace — on success and on error alike.
                span.end_time = datetime.utcnow()
                pop_span()

                # Log span payload before adding to trace
                span_payload = span.model_dump(mode="json")
                _debug_log(f"Adding span '{span_name}' to trace:", span_payload)

                trace.spans.append(span)

        return wrapper

    return decorator
1364
+
1365
+
1366
def trace_tool(name: Optional[str] = None, standalone: bool = False) -> Callable:
    """
    Decorator that records a tool/function call as a "tool" span.

    Args:
        name: Custom span name; falls back to the wrapped function's name.
        standalone: When True, the call opens its own trace if it runs
            outside of an active @trace_agent trace.

    Returns:
        A decorator produced by :func:`_trace_span`.

    Example:
        >>> @trace_tool("database_search")
        ... def search_db(query: str) -> list:
        ...     return db.search(query)

        >>> # Standalone tool call (creates its own trace)
        >>> @trace_tool(standalone=True)
        ... def standalone_tool(data: str) -> str:
        ...     return process(data)
    """
    tool_decorator = _trace_span("tool", name, standalone=standalone)
    return tool_decorator
1385
+
1386
+
1387
def trace_llm(
    name: Optional[str] = None, model: Optional[str] = None, standalone: bool = False
) -> Callable:
    """
    Decorator that records an LLM call as an "llm" span.

    Args:
        name: Custom span name; falls back to the wrapped function's name.
        model: Model name to attach to the span (auto-detected from the
            response when omitted).
        standalone: When True, the call opens its own trace if it runs
            outside of an active @trace_agent trace.

    Returns:
        A decorator produced by :func:`_trace_span`.

    Example:
        >>> @trace_llm("generate_response", model="gpt-4")
        ... def call_gpt(prompt: str) -> str:
        ...     return openai.chat(prompt)

        >>> # Standalone LLM call (creates its own trace)
        >>> @trace_llm(standalone=True)
        ... def simple_chat(prompt: str) -> str:
        ...     return openai.chat(prompt)
    """
    llm_decorator = _trace_span("llm", name, model, standalone=standalone)
    return llm_decorator
1409
+
1410
+
1411
def trace_embedding(name: Optional[str] = None, model: Optional[str] = None) -> Callable:
    """
    Decorator that records an embedding call as an "embedding" span.

    Embeddings are typically standalone operations, so when no @trace_agent
    trace is active, a standalone trace is always created.

    Args:
        name: Custom span name; falls back to the wrapped function's name.
        model: Model name to attach to the span (auto-detected from the
            response when omitted).

    Returns:
        A decorator produced by :func:`_trace_span` with ``standalone=True``.

    Example:
        >>> @trace_embedding()
        ... def embed_text(text: str) -> list:
        ...     return openai.embeddings.create(input=text, model="text-embedding-3-small")
    """
    embedding_decorator = _trace_span("embedding", name, model, standalone=True)
    return embedding_decorator
1426
+
1427
+
1428
def trace_retrieval(name: Optional[str] = None) -> Callable:
    """
    Decorator that records a retrieval/search operation (e.g., vector DB
    search) as a "retrieval" span.

    When no @trace_agent trace is active, a standalone trace is always
    created.

    Args:
        name: Custom span name; falls back to the wrapped function's name.

    Returns:
        A decorator produced by :func:`_trace_span` with ``standalone=True``.

    Example:
        >>> @trace_retrieval("vector_search")
        ... def search_docs(query: str) -> list:
        ...     return vector_db.similarity_search(query)
    """
    retrieval_decorator = _trace_span("retrieval", name, standalone=True)
    return retrieval_decorator