ragaai-catalyst 2.0.7.2b1__py3-none-any.whl → 2.1b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ragaai_catalyst/dataset.py +0 -3
  2. ragaai_catalyst/evaluation.py +1 -2
  3. ragaai_catalyst/tracers/__init__.py +1 -1
  4. ragaai_catalyst/tracers/agentic_tracing/agent_tracer.py +217 -106
  5. ragaai_catalyst/tracers/agentic_tracing/agentic_tracing.py +27 -41
  6. ragaai_catalyst/tracers/agentic_tracing/base.py +127 -21
  7. ragaai_catalyst/tracers/agentic_tracing/data_structure.py +88 -79
  8. ragaai_catalyst/tracers/agentic_tracing/examples/FinancialAnalysisSystem.ipynb +536 -0
  9. ragaai_catalyst/tracers/agentic_tracing/examples/GameActivityEventPlanner.ipynb +134 -0
  10. ragaai_catalyst/tracers/agentic_tracing/examples/TravelPlanner.ipynb +563 -0
  11. ragaai_catalyst/tracers/agentic_tracing/file_name_tracker.py +46 -0
  12. ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py +258 -356
  13. ragaai_catalyst/tracers/agentic_tracing/tool_tracer.py +31 -19
  14. ragaai_catalyst/tracers/agentic_tracing/unique_decorator.py +61 -117
  15. ragaai_catalyst/tracers/agentic_tracing/upload_agentic_traces.py +187 -0
  16. ragaai_catalyst/tracers/agentic_tracing/upload_code.py +115 -0
  17. ragaai_catalyst/tracers/agentic_tracing/user_interaction_tracer.py +35 -59
  18. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +0 -4
  19. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +2201 -324
  20. ragaai_catalyst/tracers/agentic_tracing/zip_list_of_unique_files.py +342 -0
  21. ragaai_catalyst/tracers/exporters/raga_exporter.py +1 -7
  22. ragaai_catalyst/tracers/llamaindex_callback.py +56 -60
  23. ragaai_catalyst/tracers/tracer.py +6 -2
  24. ragaai_catalyst/tracers/upload_traces.py +46 -57
  25. {ragaai_catalyst-2.0.7.2b1.dist-info → ragaai_catalyst-2.1b1.dist-info}/METADATA +6 -2
  26. {ragaai_catalyst-2.0.7.2b1.dist-info → ragaai_catalyst-2.1b1.dist-info}/RECORD +28 -22
  27. ragaai_catalyst/tracers/agentic_tracing/Untitled-1.json +0 -660
  28. {ragaai_catalyst-2.0.7.2b1.dist-info → ragaai_catalyst-2.1b1.dist-info}/WHEEL +0 -0
  29. {ragaai_catalyst-2.0.7.2b1.dist-info → ragaai_catalyst-2.1b1.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/llm_tracer.py

@@ -1,42 +1,41 @@
 from typing import Optional, Any, Dict, List
 import asyncio
 import psutil
-import json
 import wrapt
 import functools
 from datetime import datetime
 import uuid
-import os
 import contextvars
-import sys
-import gc
+import traceback

-from .unique_decorator import mydecorator
-from .utils.trace_utils import calculate_cost, load_model_costs
+from .unique_decorator import generate_unique_hash_simple
+from .utils.trace_utils import load_model_costs
 from .utils.llm_utils import extract_llm_output
+from .file_name_tracker import TrackName


 class LLMTracerMixin:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.file_tracker = TrackName()
         self.patches = []
         try:
             self.model_costs = load_model_costs()
         except Exception as e:
-            # If model costs can't be loaded, use default costs
             self.model_costs = {
+                # TODO: Default cost handling needs to be improved
                 "default": {
-                    "input_cost_per_token": 0.00002,
-                    "output_cost_per_token": 0.00002
+                    "input_cost_per_token": 0.0,
+                    "output_cost_per_token": 0.0
                 }
             }
         self.current_llm_call_name = contextvars.ContextVar("llm_call_name", default=None)
         self.component_network_calls = {}
+        self.component_user_interaction = {}
         self.current_component_id = None
         self.total_tokens = 0
         self.total_cost = 0.0
-        # Apply decorator to trace_llm_call method
-        self.trace_llm_call = mydecorator(self.trace_llm_call)
+        self.llm_data = {}

     def instrument_llm_calls(self):
         # Handle modules that are already imported
@@ -238,7 +237,7 @@ class LLMTracerMixin:
         setattr(obj, method_name, wrapped_method)
         self.patches.append((obj, method_name, original_method))

-    def _extract_model_name(self, kwargs):
+    def _extract_model_name(self, args, kwargs, result):
         """Extract model name from kwargs or result"""
         # First try direct model parameter
         model = kwargs.get("model", "")
@@ -254,6 +253,7 @@ class LLMTracerMixin:
         elif hasattr(instance, "model"):
             model = instance.model

+        # TODO: This way isn't scalable. The necessity for normalising model names needs to be fixed. We shouldn't have to do this
         # Normalize Google model names
         if model and isinstance(model, str):
             model = model.lower()
@@ -263,33 +263,48 @@ class LLMTracerMixin:
                 return "gemini-1.5-pro"
             if "gemini-pro" in model:
                 return "gemini-pro"
+
+        if 'to_dict' in dir(result):
+            result = result.to_dict()
+            if 'model_version' in result:
+                model = result['model_version']

         return model or "default"

-    def _extract_parameters(self, kwargs, result=None):
-        """Extract parameters from kwargs or result"""
-        params = {
-            "temperature": kwargs.get("temperature", getattr(result, "temperature", 0.7)),
-            "top_p": kwargs.get("top_p", getattr(result, "top_p", 1.0)),
-            "max_tokens": kwargs.get("max_tokens", getattr(result, "max_tokens", 512))
-        }
-
-        # Add Google AI specific parameters if available
-        if hasattr(kwargs.get("self", None), "generation_config"):
-            gen_config = kwargs["self"].generation_config
-            params.update({
-                "candidate_count": getattr(gen_config, "candidate_count", 1),
-                "stop_sequences": getattr(gen_config, "stop_sequences", []),
-                "top_k": getattr(gen_config, "top_k", 40)
-            })
-
-        return params
+    def _extract_parameters(self, kwargs):
+        """Extract all non-null parameters from kwargs"""
+        parameters = {k: v for k, v in kwargs.items() if v is not None}
+
+        # Remove contents key in parameters (Google LLM Response)
+        if 'contents' in parameters:
+            del parameters['contents']
+
+        # Remove messages key in parameters (OpenAI message)
+        if 'messages' in parameters:
+            del parameters['messages']
+
+        if 'generation_config' in parameters:
+            generation_config = parameters['generation_config']
+            # If generation_config is already a dict, use it directly
+            if isinstance(generation_config, dict):
+                config_dict = generation_config
+            else:
+                # Convert GenerationConfig to dictionary if it has a to_dict method, otherwise try to get its __dict__
+                config_dict = getattr(generation_config, 'to_dict', lambda: generation_config.__dict__)()
+            parameters.update(config_dict)
+            del parameters['generation_config']
+
+        return parameters

     def _extract_token_usage(self, result):
         """Extract token usage from result"""
         # Handle coroutines
         if asyncio.iscoroutine(result):
-            result = asyncio.run(result)
+            # Get the current event loop
+            loop = asyncio.get_event_loop()
+            # Run the coroutine in the current event loop
+            result = loop.run_until_complete(result)
+

         # Handle standard OpenAI/Anthropic format
         if hasattr(result, "usage"):
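
The reworked _extract_parameters keeps every non-null kwarg, drops the prompt payload keys ('messages' for OpenAI-style calls, 'contents' for Google), and flattens a generation_config dict or object into the top level. A minimal standalone sketch of that behaviour, written here purely for illustration (extract_parameters_sketch and its sample kwargs are hypothetical, not part of the package):

    def extract_parameters_sketch(kwargs):
        # Mirror of the logic added above: keep non-null kwargs, drop prompt payloads,
        # flatten generation_config into the top-level dict.
        parameters = {k: v for k, v in kwargs.items() if v is not None}
        parameters.pop('contents', None)
        parameters.pop('messages', None)
        generation_config = parameters.pop('generation_config', None)
        if generation_config is not None:
            config_dict = (generation_config if isinstance(generation_config, dict)
                           else getattr(generation_config, 'to_dict', lambda: generation_config.__dict__)())
            parameters.update(config_dict)
        return parameters

    extract_parameters_sketch({'model': 'gemini-1.5-pro', 'contents': 'hi', 'top_p': None,
                               'generation_config': {'temperature': 0.3}})
    # -> {'model': 'gemini-1.5-pro', 'temperature': 0.3}
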
@@ -317,79 +332,26 @@ class LLMTracerMixin:
                 # Try to get from raw response
                 total_tokens = getattr(result._raw_response, "token_count", 0)
             return {
+                # TODO: This implementation is incorrect. Vertex AI does provide this breakdown
                 "prompt_tokens": 0, # Vertex AI doesn't provide this breakdown
                 "completion_tokens": total_tokens,
                 "total_tokens": total_tokens
             }

-        return {
+        return { # TODO: Passing 0 in case of not recorded is not correct. This needs to be fixes. Discuss before making changes to this
             "prompt_tokens": 0,
             "completion_tokens": 0,
             "total_tokens": 0
         }

-    def _extract_input_data(self, kwargs, result):
-        """Extract input data from kwargs and result"""
-
-        # For Vertex AI GenerationResponse
-        if hasattr(result, 'candidates') and hasattr(result, 'usage_metadata'):
-            # Extract generation config
-            generation_config = kwargs.get('generation_config', {})
-            config_dict = {}
-            if hasattr(generation_config, 'temperature'):
-                config_dict['temperature'] = generation_config.temperature
-            if hasattr(generation_config, 'top_p'):
-                config_dict['top_p'] = generation_config.top_p
-            if hasattr(generation_config, 'max_output_tokens'):
-                config_dict['max_tokens'] = generation_config.max_output_tokens
-            if hasattr(generation_config, 'candidate_count'):
-                config_dict['n'] = generation_config.candidate_count
-
-            return {
-                "prompt": kwargs.get('contents', ''),
-                "model": "gemini-1.5-flash-002",
-                **config_dict
-            }
-
-        # For standard OpenAI format
-        messages = kwargs.get("messages", [])
-        if messages:
-            return {
-                "messages": messages,
-                "model": kwargs.get("model", "unknown"),
-                "temperature": kwargs.get("temperature", 0.7),
-                "max_tokens": kwargs.get("max_tokens", None),
-                "top_p": kwargs.get("top_p", None),
-                "frequency_penalty": kwargs.get("frequency_penalty", None),
-                "presence_penalty": kwargs.get("presence_penalty", None)
-            }
-
-        # For text completion format
-        if "prompt" in kwargs:
-            return {
-                "prompt": kwargs["prompt"],
-                "model": kwargs.get("model", "unknown"),
-                "temperature": kwargs.get("temperature", 0.7),
-                "max_tokens": kwargs.get("max_tokens", None),
-                "top_p": kwargs.get("top_p", None),
-                "frequency_penalty": kwargs.get("frequency_penalty", None),
-                "presence_penalty": kwargs.get("presence_penalty", None)
-            }
-
-        # For any other case, try to extract from kwargs
-        if "contents" in kwargs:
-            return {
-                "prompt": kwargs["contents"],
-                "model": kwargs.get("model", "unknown"),
-                "temperature": kwargs.get("temperature", 0.7),
-                "max_tokens": kwargs.get("max_tokens", None),
-                "top_p": kwargs.get("top_p", None)
-            }
-
-        print("No input data found")
-        return {}
+    def _extract_input_data(self, args, kwargs, result):
+        return {
+            'args': args,
+            'kwargs': kwargs
+        }

     def _calculate_cost(self, token_usage, model_name):
+        # TODO: Passing default cost is a faulty logic & implementation and should be fixed
         """Calculate cost based on token usage and model"""
         if not isinstance(token_usage, dict):
             token_usage = {
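
Note that the replacement _extract_input_data above no longer normalises provider-specific payloads; it simply records the raw call. Under that change, a hypothetical OpenAI-style chat call traced by the mixin would be captured roughly as:

    # Illustrative shape only; the model name and messages are made up.
    input_data = {
        'args': (),
        'kwargs': {
            'model': 'gpt-4o-mini',
            'messages': [{'role': 'user', 'content': 'Summarise this ticket.'}],
        },
    }
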
@@ -398,74 +360,60 @@ class LLMTracerMixin:
             "total_tokens": token_usage if isinstance(token_usage, (int, float)) else 0
         }

+        # TODO: This is a temporary fix. This needs to be fixed
+
         # Get model costs, defaulting to Vertex AI PaLM2 costs if unknown
         model_cost = self.model_costs.get(model_name, {
-            "input_cost_per_token": 0.0005, # $0.0005 per 1K input tokens
-            "output_cost_per_token": 0.0005 # $0.0005 per 1K output tokens
+            "input_cost_per_token": 0.0,
+            "output_cost_per_token": 0.0
         })

-        # Calculate costs per 1K tokens
-        input_cost = (token_usage.get("prompt_tokens", 0) / 1000.0) * model_cost.get("input_cost_per_token", 0.0005)
-        output_cost = (token_usage.get("completion_tokens", 0) / 1000.0) * model_cost.get("output_cost_per_token", 0.0005)
+        input_cost = (token_usage.get("prompt_tokens", 0)) * model_cost.get("input_cost_per_token", 0.0)
+        output_cost = (token_usage.get("completion_tokens", 0)) * model_cost.get("output_cost_per_token", 0.0)
         total_cost = input_cost + output_cost

+        # TODO: Return the value as it is, no need to round
         return {
-            "input_cost": round(input_cost, 6),
-            "output_cost": round(output_cost, 6),
-            "total_cost": round(total_cost, 6)
+            "input_cost": round(input_cost, 10),
+            "output_cost": round(output_cost, 10),
+            "total_cost": round(total_cost, 10)
         }

-    def create_llm_component(self, **kwargs):
-        """Create an LLM component according to the data structure"""
-        start_time = kwargs["start_time"]
-
-        # Ensure cost and usage are dictionaries
-        cost = kwargs.get("cost", {})
-        if not isinstance(cost, dict):
-            cost = {"total_cost": cost}
-
-        usage = kwargs.get("usage", {})
-        if not isinstance(usage, dict):
-            usage = {"total_tokens": usage}
-
+    def create_llm_component(self, component_id, hash_id, name, llm_type, version, memory_used, start_time, end_time, input_data, output_data, cost={}, usage={}, error=None, parameters={}):
+        # Update total metrics
+        self.total_tokens += usage.get("total_tokens", 0)
+        self.total_cost += cost.get("total_cost", 0)
+
         component = {
-            "id": kwargs["component_id"],
-            "hash_id": kwargs["hash_id"],
+            "id": component_id,
+            "hash_id": hash_id,
             "source_hash_id": None,
             "type": "llm",
-            "name": kwargs["name"],
+            "name": name,
             "start_time": start_time.isoformat(),
-            "end_time": kwargs["end_time"].isoformat(),
-            "error": kwargs.get("error"),
+            "end_time": end_time.isoformat(),
+            "error": error,
             "parent_id": self.current_agent_id.get(),
             "info": {
-                "llm_type": kwargs.get("llm_type", "unknown"),
-                "version": kwargs.get("version", "1.0.0"),
-                "memory_used": kwargs.get("memory_used", 0),
+                "model": llm_type,
+                "version": version,
+                "memory_used": memory_used,
                 "cost": cost,
-                "tokens": usage
+                "tokens": usage,
+                **parameters
             },
             "data": {
-                "input": kwargs.get("input_data"),
-                "output": kwargs.get("output_data"),
-                "memory_used": kwargs.get("memory_used", 0)
+                "input": input_data['args'] if hasattr(input_data, 'args') else input_data,
+                "output": output_data.output_response if output_data else None,
+                "memory_used": memory_used
             },
-            "network_calls": self.component_network_calls.get(kwargs["component_id"], []),
-            "interactions": [
-                {
-                    "id": f"int_{uuid.uuid4()}",
-                    "interaction_type": "input",
-                    "timestamp": start_time.isoformat(),
-                    "content": kwargs.get("input_data")
-                },
-                {
-                    "id": f"int_{uuid.uuid4()}",
-                    "interaction_type": "output",
-                    "timestamp": kwargs["end_time"].isoformat(),
-                    "content": kwargs.get("output_data")
-                }
-            ]
+            "network_calls": self.component_network_calls.get(component_id, []),
+            "interactions": self.component_user_interaction.get(component_id, [])
         }
+
+        if self.gt:
+            component["data"]["gt"] = self.gt
+
         return component

     def start_component(self, component_id):
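
The cost math above also changes from per-1K-token pricing to straight per-token pricing, with unknown models now falling back to zero rather than a PaLM2-style default. A worked example with assumed rates (the numbers are illustrative, not real model prices):

    # Assumed per-token rates, for illustration only.
    model_cost = {"input_cost_per_token": 2.5e-07, "output_cost_per_token": 1.0e-06}
    token_usage = {"prompt_tokens": 1200, "completion_tokens": 300}

    input_cost = token_usage["prompt_tokens"] * model_cost["input_cost_per_token"]        # 0.0003
    output_cost = token_usage["completion_tokens"] * model_cost["output_cost_per_token"]  # 0.0003
    total_cost = round(input_cost + output_cost, 10)                                      # 0.0006
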
@@ -481,29 +429,19 @@
     async def trace_llm_call(self, original_func, *args, **kwargs):
         """Trace an LLM API call"""
         if not self.is_active:
-            if asyncio.iscoroutinefunction(original_func):
-                return await original_func(*args, **kwargs)
-            return original_func(*args, **kwargs)
+            return await original_func(*args, **kwargs)

         start_time = datetime.now().astimezone()
         start_memory = psutil.Process().memory_info().rss
         component_id = str(uuid.uuid4())
-        hash_id = self.trace_llm_call.hash_id
+        hash_id = generate_unique_hash_simple(original_func)

         # Start tracking network calls for this component
         self.start_component(component_id)

         try:
             # Execute the LLM call
-            result = None
-            if asyncio.iscoroutinefunction(original_func):
-                result = await original_func(*args, **kwargs)
-            else:
-                result = original_func(*args, **kwargs)
-
-            # If result is a coroutine, await it
-            if asyncio.iscoroutine(result):
-                result = await result
+            result = await original_func(*args, **kwargs)

             # Calculate resource usage
             end_time = datetime.now().astimezone()
@@ -511,30 +449,40 @@
             memory_used = max(0, end_memory - start_memory)

             # Extract token usage and calculate cost
-            token_usage = await self._extract_token_usage(result)
-            model_name = self._extract_model_name(kwargs)
+            token_usage = self._extract_token_usage(result)
+            model_name = self._extract_model_name(args, kwargs, result)
             cost = self._calculate_cost(token_usage, model_name)
+            parameters = self._extract_parameters(kwargs)

             # End tracking network calls for this component
             self.end_component(component_id)

+            name = self.current_llm_call_name.get()
+            if name is None:
+                name = original_func.__name__
+
+            # Create input data with ground truth
+            input_data = self._extract_input_data(args, kwargs, result)
+
             # Create LLM component
             llm_component = self.create_llm_component(
                 component_id=component_id,
                 hash_id=hash_id,
-                name=self.current_llm_call_name.get(),
+                name=name,
                 llm_type=model_name,
                 version="1.0.0",
                 memory_used=memory_used,
                 start_time=start_time,
                 end_time=end_time,
-                input_data=self._extract_input_data(kwargs, result),
+                input_data=input_data,
                 output_data=extract_llm_output(result),
                 cost=cost,
-                usage=token_usage
+                usage=token_usage,
+                parameters=parameters
             )
-
-            self.add_component(llm_component)
+
+            # self.add_component(llm_component)
+            self.llm_data = llm_component
             return result

         except Exception as e:
@@ -549,67 +497,28 @@
             self.end_component(component_id)

             end_time = datetime.now().astimezone()
+
+            name = self.current_llm_call_name.get()
+            if name is None:
+                name = original_func.__name__

             llm_component = self.create_llm_component(
                 component_id=component_id,
                 hash_id=hash_id,
-                name=self.current_llm_call_name.get(),
+                name=name,
                 llm_type="unknown",
                 version="1.0.0",
                 memory_used=0,
                 start_time=start_time,
                 end_time=end_time,
-                input_data=self._extract_input_data(kwargs, None),
+                input_data=self._extract_input_data(args, kwargs, None),
                 output_data=None,
                 error=error_component
             )
-
+
             self.add_component(llm_component)
             raise

-    def _extract_token_usage_sync(self, result):
-        """Sync version of extract token usage"""
-        # Handle coroutines
-        if asyncio.iscoroutine(result):
-            result = asyncio.run(result)
-
-        # Handle standard OpenAI/Anthropic format
-        if hasattr(result, "usage"):
-            usage = result.usage
-            return {
-                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
-                "completion_tokens": getattr(usage, "completion_tokens", 0),
-                "total_tokens": getattr(usage, "total_tokens", 0)
-            }
-
-        # Handle Google GenerativeAI format with usage_metadata
-        if hasattr(result, "usage_metadata"):
-            metadata = result.usage_metadata
-            return {
-                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
-                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
-                "total_tokens": getattr(metadata, "total_token_count", 0)
-            }
-
-        # Handle Vertex AI format
-        if hasattr(result, "text"):
-            # For LangChain ChatVertexAI
-            total_tokens = getattr(result, "token_count", 0)
-            if not total_tokens and hasattr(result, "_raw_response"):
-                # Try to get from raw response
-                total_tokens = getattr(result._raw_response, "token_count", 0)
-            return {
-                "prompt_tokens": 0, # Vertex AI doesn't provide this breakdown
-                "completion_tokens": total_tokens,
-                "total_tokens": total_tokens
-            }
-
-        return {
-            "prompt_tokens": 0,
-            "completion_tokens": 0,
-            "total_tokens": 0
-        }
-
     def trace_llm_call_sync(self, original_func, *args, **kwargs):
         """Sync version of trace_llm_call"""
         if not self.is_active:
@@ -618,54 +527,59 @@
             return original_func(*args, **kwargs)

         start_time = datetime.now().astimezone()
-        start_memory = psutil.Process().memory_info().rss
         component_id = str(uuid.uuid4())
-        hash_id = self.trace_llm_call.hash_id
+        hash_id = generate_unique_hash_simple(original_func)

         # Start tracking network calls for this component
         self.start_component(component_id)

+        # Calculate resource usage
+        end_time = datetime.now().astimezone()
+        start_memory = psutil.Process().memory_info().rss
+
         try:
-            # Execute the LLM call
-            result = None
+            # Execute the function
             if asyncio.iscoroutinefunction(original_func):
                 result = asyncio.run(original_func(*args, **kwargs))
             else:
                 result = original_func(*args, **kwargs)

-            # If result is a coroutine, run it
-            if asyncio.iscoroutine(result):
-                result = asyncio.run(result)
-
-            # Calculate resource usage
-            end_time = datetime.now().astimezone()
             end_memory = psutil.Process().memory_info().rss
             memory_used = max(0, end_memory - start_memory)

             # Extract token usage and calculate cost
-            token_usage = self._extract_token_usage_sync(result)
-            model_name = self._extract_model_name(kwargs)
+            token_usage = self._extract_token_usage(result)
+            model_name = self._extract_model_name(args, kwargs, result)
             cost = self._calculate_cost(token_usage, model_name)
+            parameters = self._extract_parameters(kwargs)

             # End tracking network calls for this component
             self.end_component(component_id)

+            name = self.current_llm_call_name.get()
+            if name is None:
+                name = original_func.__name__
+
+            # Create input data with ground truth
+            input_data = self._extract_input_data(args, kwargs, result)
+
             # Create LLM component
             llm_component = self.create_llm_component(
                 component_id=component_id,
                 hash_id=hash_id,
-                name=self.current_llm_call_name.get(),
+                name=name,
                 llm_type=model_name,
                 version="1.0.0",
                 memory_used=memory_used,
                 start_time=start_time,
                 end_time=end_time,
-                input_data=self._extract_input_data(kwargs, result),
+                input_data=input_data,
                 output_data=extract_llm_output(result),
                 cost=cost,
-                usage=token_usage
+                usage=token_usage,
+                parameters=parameters
             )
-
+
             self.add_component(llm_component)
             return result

@@ -681,105 +595,135 @@
             self.end_component(component_id)

             end_time = datetime.now().astimezone()
+
+            name = self.current_llm_call_name.get()
+            if name is None:
+                name = original_func.__name__
+
+            end_memory = psutil.Process().memory_info().rss
+            memory_used = max(0, end_memory - start_memory)

             llm_component = self.create_llm_component(
                 component_id=component_id,
                 hash_id=hash_id,
-                name=self.current_llm_call_name.get(),
+                name=name,
                 llm_type="unknown",
                 version="1.0.0",
-                memory_used=0,
+                memory_used=memory_used,
                 start_time=start_time,
                 end_time=end_time,
-                input_data=self._extract_input_data(kwargs, None),
+                input_data=self._extract_input_data(args, kwargs, None),
                 output_data=None,
                 error=error_component
             )
-
+
             self.add_component(llm_component)
             raise

-    async def _extract_token_usage(self, result):
-        """Extract token usage from result"""
-        # Handle coroutines
-        if asyncio.iscoroutine(result):
-            result = await result
-
-        # Handle standard OpenAI/Anthropic format
-        if hasattr(result, "usage"):
-            usage = result.usage
-            return {
-                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
-                "completion_tokens": getattr(usage, "completion_tokens", 0),
-                "total_tokens": getattr(usage, "total_tokens", 0)
-            }
-
-        # Handle Google GenerativeAI format with usage_metadata
-        if hasattr(result, "usage_metadata"):
-            metadata = result.usage_metadata
-            return {
-                "prompt_tokens": getattr(metadata, "prompt_token_count", 0),
-                "completion_tokens": getattr(metadata, "candidates_token_count", 0),
-                "total_tokens": getattr(metadata, "total_token_count", 0)
-            }
-
-        # Handle Vertex AI format
-        if hasattr(result, "text"):
-            # For LangChain ChatVertexAI
-            total_tokens = getattr(result, "token_count", 0)
-            if not total_tokens and hasattr(result, "_raw_response"):
-                # Try to get from raw response
-                total_tokens = getattr(result._raw_response, "token_count", 0)
-            return {
-                "prompt_tokens": 0, # Vertex AI doesn't provide this breakdown
-                "completion_tokens": total_tokens,
-                "total_tokens": total_tokens
-            }
-
-        return {
-            "prompt_tokens": 0,
-            "completion_tokens": 0,
-            "total_tokens": 0
-        }
-
-    def trace_llm(self, name: str, tool_type: str = "llm", version: str = "1.0.0"):
-        def decorator(func_or_class):
-            if isinstance(func_or_class, type):
-                for attr_name, attr_value in func_or_class.__dict__.items():
-                    if callable(attr_value) and not attr_name.startswith("__"):
-                        setattr(
-                            func_or_class,
-                            attr_name,
-                            self.trace_llm(f"{name}.{attr_name}", tool_type, version)(attr_value),
-                        )
-                return func_or_class
-            else:
-                @functools.wraps(func_or_class)
-                async def async_wrapper(*args, **kwargs):
-                    token = self.current_llm_call_name.set(name)
-                    try:
-                        return await func_or_class(*args, **kwargs)
-                    finally:
-                        self.current_llm_call_name.reset(token)
-
-                @functools.wraps(func_or_class)
-                def sync_wrapper(*args, **kwargs):
-                    token = self.current_llm_call_name.set(name)
-                    try:
-                        return func_or_class(*args, **kwargs)
-                    finally:
-                        self.current_llm_call_name.reset(token)
-
-                return async_wrapper if asyncio.iscoroutinefunction(func_or_class) else sync_wrapper
-
+    def trace_llm(self, name: str = None):
+        def decorator(func):
+            @self.file_tracker.trace_decorator
+            @functools.wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                self.gt = kwargs.get('gt', None) if kwargs else None
+                if not self.is_active:
+                    return await func(*args, **kwargs)
+
+                hash_id = generate_unique_hash_simple(func)
+                component_id = str(uuid.uuid4())
+                parent_agent_id = self.current_agent_id.get()
+                self.start_component(component_id)
+
+                start_time = datetime.now()
+                error_info = None
+                result = None
+
+                try:
+                    result = await func(*args, **kwargs)
+                    return result
+                except Exception as e:
+                    error_info = {
+                        "error": {
+                            "type": type(e).__name__,
+                            "message": str(e),
+                            "traceback": traceback.format_exc(),
+                            "timestamp": datetime.now().isoformat()
+                        }
+                    }
+                    raise
+                finally:
+
+                    llm_component = self.llm_data
+
+                    if error_info:
+                        llm_component["error"] = error_info["error"]
+
+                    if parent_agent_id:
+                        children = self.agent_children.get()
+                        children.append(llm_component)
+                        self.agent_children.set(children)
+                    else:
+                        self.add_component(llm_component)
+
+                    self.end_component(component_id)
+
+            @self.file_tracker.trace_decorator
+            @functools.wraps(func)
+            def sync_wrapper(*args, **kwargs):
+                self.gt = kwargs.get('gt', None) if kwargs else None
+                if not self.is_active:
+                    return func(*args, **kwargs)
+
+                hash_id = generate_unique_hash_simple(func)
+
+                component_id = str(uuid.uuid4())
+                parent_agent_id = self.current_agent_id.get()
+                self.start_component(component_id)
+
+                start_time = datetime.now()
+                error_info = None
+                result = None
+
+                try:
+                    result = func(*args, **kwargs)
+                    return result
+                except Exception as e:
+                    error_info = {
+                        "error": {
+                            "type": type(e).__name__,
+                            "message": str(e),
+                            "traceback": traceback.format_exc(),
+                            "timestamp": datetime.now().isoformat()
+                        }
+                    }
+                    raise
+                finally:
+
+                    llm_component = self.llm_data
+
+                    if error_info:
+                        llm_component["error"] = error_info["error"]
+
+                    if parent_agent_id:
+                        children = self.agent_children.get()
+                        children.append(llm_component)
+                        self.agent_children.set(children)
+                    else:
+                        self.add_component(llm_component)
+
+                    self.end_component(component_id)
+
+            return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
         return decorator

     def unpatch_llm_calls(self):
-        """Remove all patches"""
+        # Remove all patches
         for obj, method_name, original_method in self.patches:
-            if hasattr(obj, method_name):
+            try:
                 setattr(obj, method_name, original_method)
-        self.patches.clear()
+            except Exception as e:
+                print(f"Error unpatching {method_name}: {str(e)}")
+        self.patches = []

     def _sanitize_api_keys(self, data):
         """Remove sensitive information from data"""
@@ -792,62 +736,20 @@ class LLMTracerMixin:
             return tuple(self._sanitize_api_keys(item) for item in data)
         return data

-    def _create_llm_component(self, component_id, hash_id, name, llm_type, version, memory_used, start_time, end_time, input_data, output_data, usage=None, error=None):
-        cost = None
-        tokens = None
+    def _sanitize_input(self, args, kwargs):
+        """Convert input arguments to text format.

-        if usage:
-            tokens = {
-                "prompt_tokens": usage.get("prompt_tokens", 0),
-                "completion_tokens": usage.get("completion_tokens", 0),
-                "total_tokens": usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0)
-            }
-            cost = calculate_cost(usage)
+        Args:
+            args: Input arguments that may contain nested dictionaries

-        # Update total metrics
-        self.total_tokens += tokens["total_tokens"]
-        self.total_cost += cost["total"]
-
-        component = {
-            "id": component_id,
-            "hash_id": hash_id,
-            "source_hash_id": None,
-            "type": "llm",
-            "name": name,
-            "start_time": start_time.isoformat(),
-            "end_time": end_time.isoformat(),
-            "error": error,
-            "parent_id": self.current_agent_id.get(),
-            "info": {
-                "llm_type": llm_type,
-                "version": version,
-                "memory_used": memory_used,
-                "cost": cost,
-                "tokens": tokens
-            },
-            "data": {
-                "input": input_data,
-                "output": output_data.output_response if output_data else None,
-                "memory_used": memory_used
-            },
-            "network_calls": self.component_network_calls.get(component_id, []),
-            "interactions": [
-                {
-                    "id": f"int_{uuid.uuid4()}",
-                    "interaction_type": "input",
-                    "timestamp": start_time.isoformat(),
-                    "content": input_data
-                },
-                {
-                    "id": f"int_{uuid.uuid4()}",
-                    "interaction_type": "output",
-                    "timestamp": end_time.isoformat(),
-                    "content": output_data.output_response if output_data else None
-                }
-            ]
-        }
-
-        return component
+        Returns:
+            str: Text representation of the input arguments
+        """
+        if isinstance(args, dict):
+            return str({k: self._sanitize_input(v, {}) for k, v in args.items()})
+        elif isinstance(args, (list, tuple)):
+            return str([self._sanitize_input(item, {}) for item in args])
+        return str(args)

 def extract_llm_output(result):
     """Extract output from LLM response"""