ragaai-catalyst 2.1.3__py3-none-any.whl → 2.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +37 -11
  2. ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +240 -81
  3. ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +632 -114
  4. ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +316 -0
  5. ragaai_catalyst/tracers/agentic_tracing/tracers/langgraph_tracer.py +0 -0
  6. ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +229 -82
  7. ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +214 -59
  8. ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +16 -14
  9. ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +147 -28
  10. ragaai_catalyst/tracers/agentic_tracing/tracers/user_interaction_tracer.py +88 -2
  11. ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +9 -51
  12. ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +83 -0
  13. ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +26 -0
  14. ragaai_catalyst/tracers/agentic_tracing/utils/get_user_trace_metrics.py +28 -0
  15. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +45 -15
  16. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +2520 -2152
  17. ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +59 -0
  18. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +23 -0
  19. ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +284 -15
  20. ragaai_catalyst/tracers/llamaindex_callback.py +5 -5
  21. ragaai_catalyst/tracers/tracer.py +83 -10
  22. ragaai_catalyst/tracers/upload_traces.py +1 -1
  23. ragaai_catalyst-2.1.4.dist-info/METADATA +431 -0
  24. {ragaai_catalyst-2.1.3.dist-info → ragaai_catalyst-2.1.4.dist-info}/RECORD +26 -20
  25. ragaai_catalyst-2.1.3.dist-info/METADATA +0 -43
  26. {ragaai_catalyst-2.1.3.dist-info → ragaai_catalyst-2.1.4.dist-info}/WHEEL +0 -0
  27. {ragaai_catalyst-2.1.3.dist-info → ragaai_catalyst-2.1.4.dist-info}/top_level.txt +0 -0
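
The hunks below are from ragaai_catalyst/tracers/agentic_tracing/tracers/base.py (item 3 above), the largest code change in this release: BaseTracer gains background resource sampling, dataset-schema creation, trace-metric upload, and interaction formatting.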
@@ -5,22 +5,48 @@ import psutil
 import pkg_resources
 from datetime import datetime
 from pathlib import Path
-from typing import List
+from typing import List, Any
 import uuid
 import sys
 import tempfile
+import threading
+import time
 from ....ragaai_catalyst import RagaAICatalyst
 from ..data.data_structure import (
-    Trace, Metadata, SystemInfo, OSInfo, EnvironmentInfo,
-    Resources, CPUResource, MemoryResource, DiskResource, NetworkResource,
-    ResourceInfo, MemoryInfo, DiskInfo, NetworkInfo,
-    Component,
+    Trace,
+    Metadata,
+    SystemInfo,
+    OSInfo,
+    EnvironmentInfo,
+    Resources,
+    CPUResource,
+    MemoryResource,
+    DiskResource,
+    NetworkResource,
+    ResourceInfo,
+    MemoryInfo,
+    DiskInfo,
+    NetworkInfo,
+    Component,
 )

 from ..upload.upload_agentic_traces import UploadAgenticTraces
 from ..upload.upload_code import upload_code
+from ..upload.upload_trace_metric import upload_trace_metric
 from ..utils.file_name_tracker import TrackName
 from ..utils.zip_list_of_unique_files import zip_list_of_unique_files
+from ..utils.span_attributes import SpanAttributes
+from ..utils.create_dataset_schema import create_dataset_schema_with_trace
+
+
+# Configure logging to show debug messages (which includes info messages as well)
+import logging
+
+logger = logging.getLogger(__name__)
+logging_level = (
+    logger.setLevel(logging.DEBUG) if os.getenv("DEBUG") == "1" else logging.INFO
+)
+

 class TracerJSONEncoder(json.JSONEncoder):
     def default(self, obj):
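
The logging block added above keys the tracer's verbosity off a DEBUG environment variable. A minimal sketch of the intended switch (the variable name and value come from the diff; the script name is illustrative):

    # DEBUG=1 enables debug-level tracer logs; anything else leaves the logger at INFO.
    #   DEBUG=1 python run_agent.py    # verbose: logger.debug(...) messages appear
    #   python run_agent.py            # default: logger.info(...) and above only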
@@ -28,227 +54,339 @@ class TracerJSONEncoder(json.JSONEncoder):
             return obj.isoformat()
         if isinstance(obj, bytes):
             try:
-                return obj.decode('utf-8')
+                return obj.decode("utf-8")
             except UnicodeDecodeError:
                 return str(obj)  # Fallback to string representation
-        if hasattr(obj, 'to_dict'):  # Handle objects with to_dict method
+        if hasattr(obj, "to_dict"):  # Handle objects with to_dict method
             return obj.to_dict()
-        if hasattr(obj, '__dict__'):
+        if hasattr(obj, "__dict__"):
             # Filter out None values and handle nested serialization
-            return {k: v for k, v in obj.__dict__.items()
-                    if v is not None and not k.startswith('_')}
+            return {
+                k: v
+                for k, v in obj.__dict__.items()
+                if v is not None and not k.startswith("_")
+            }
         try:
             # Try to convert to a basic type
             return str(obj)
         except:
             return None  # Last resort: return None instead of failing

+
 class BaseTracer:
     def __init__(self, user_details):
         self.user_details = user_details
-        self.project_name = self.user_details['project_name']  # Access the project_name
-        self.dataset_name = self.user_details['dataset_name']  # Access the dataset_name
-        self.project_id = self.user_details['project_id']  # Access the project_id
-
+        self.project_name = self.user_details["project_name"]  # Access the project_name
+        self.dataset_name = self.user_details["dataset_name"]  # Access the dataset_name
+        self.project_id = self.user_details["project_id"]  # Access the project_id
+        self.trace_name = self.user_details["trace_name"]  # Access the trace_name
+        self.visited_metrics = []
+
         # Initialize trace data
         self.trace_id = None
-        self.start_time = None
+        self.start_time = None
         self.components: List[Component] = []
         self.file_tracker = TrackName()
-
+        self.span_attributes_dict = {}
+
+        self.interval_time = self.user_details['interval_time']
+        self.memory_usage_list = []
+        self.cpu_usage_list = []
+        self.disk_usage_list = []
+        self.network_usage_list = []
+        self.tracking_thread = None
+        self.tracking = False
+
     def _get_system_info(self) -> SystemInfo:
         # Get OS info
         os_info = OSInfo(
             name=platform.system(),
             version=platform.version(),
             platform=platform.machine(),
-            kernel_version=platform.release()
+            kernel_version=platform.release(),
         )
-
+
         # Get Python environment info
-        installed_packages = [f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]
+        installed_packages = [
+            f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set
+        ]
         env_info = EnvironmentInfo(
             name="Python",
             version=platform.python_version(),
             packages=installed_packages,
             env_path=sys.prefix,
-            command_to_run=f"python {sys.argv[0]}"
+            command_to_run=f"python {sys.argv[0]}",
        )
-
+
         return SystemInfo(
             id=f"sys_{self.trace_id}",
             os=os_info,
             environment=env_info,
-            source_code="Path to the source code .zip file in format hashid.zip"  # TODO: Implement source code archiving
+            source_code="Path to the source code .zip file in format hashid.zip",  # TODO: Implement source code archiving
         )
-
+
     def _get_resources(self) -> Resources:
         # CPU info
         cpu_info = ResourceInfo(
             name=platform.processor(),
             cores=psutil.cpu_count(logical=False),
-            threads=psutil.cpu_count(logical=True)
-        )
-        cpu = CPUResource(
-            info=cpu_info,
-            interval="5s",
-            values=[psutil.cpu_percent()]
+            threads=psutil.cpu_count(logical=True),
         )
-
+        cpu = CPUResource(info=cpu_info, interval="5s", values=[psutil.cpu_percent()])
+
         # Memory info
         memory = psutil.virtual_memory()
         mem_info = MemoryInfo(
             total=memory.total / (1024**3),  # Convert to GB
-            free=memory.available / (1024**3)
-        )
-        mem = MemoryResource(
-            info=mem_info,
-            interval="5s",
-            values=[memory.percent]
+            free=memory.available / (1024**3),
         )
-
+        mem = MemoryResource(info=mem_info, interval="5s", values=[memory.percent])
+
         # Disk info
-        disk = psutil.disk_usage('/')
-        disk_info = DiskInfo(
-            total=disk.total / (1024**3),
-            free=disk.free / (1024**3)
-        )
+        disk = psutil.disk_usage("/")
+        disk_info = DiskInfo(total=disk.total / (1024**3), free=disk.free / (1024**3))
         disk_io = psutil.disk_io_counters()
         disk_resource = DiskResource(
             info=disk_info,
             interval="5s",
             read=[disk_io.read_bytes / (1024**2)],  # MB
-            write=[disk_io.write_bytes / (1024**2)]
+            write=[disk_io.write_bytes / (1024**2)],
         )
-
+
         # Network info
         net_io = psutil.net_io_counters()
         net_info = NetworkInfo(
             upload_speed=net_io.bytes_sent / (1024**2),  # MB
-            download_speed=net_io.bytes_recv / (1024**2)
+            download_speed=net_io.bytes_recv / (1024**2),
         )
         net = NetworkResource(
             info=net_info,
             interval="5s",
             uploads=[net_io.bytes_sent / (1024**2)],
-            downloads=[net_io.bytes_recv / (1024**2)]
+            downloads=[net_io.bytes_recv / (1024**2)],
         )
-
+
         return Resources(cpu=cpu, memory=mem, disk=disk_resource, network=net)
-
+
+    def _track_memory_usage(self):
+        self.memory_usage_list = []
+        while self.tracking:
+            memory_usage = psutil.Process().memory_info().rss
+            self.memory_usage_list.append(memory_usage / (1024 * 1024))  # Convert to MB and append to the list
+            time.sleep(self.interval_time)
+
+    def _track_cpu_usage(self):
+        self.cpu_usage_list = []
+        while self.tracking:
+            cpu_usage = psutil.cpu_percent(interval=self.interval_time)
+            self.cpu_usage_list.append(cpu_usage)
+            time.sleep(self.interval_time)
+
+    def _track_disk_usage(self):
+        self.disk_usage_list = []
+        while self.tracking:
+            disk_io = psutil.disk_io_counters()
+            self.disk_usage_list.append({
+                'disk_read': disk_io.read_bytes / (1024 * 1024),  # Convert to MB
+                'disk_write': disk_io.write_bytes / (1024 * 1024)  # Convert to MB
+            })
+            time.sleep(self.interval_time)
+
+    def _track_network_usage(self):
+        self.network_usage_list = []
+        while self.tracking:
+            net_io = psutil.net_io_counters()
+            self.network_usage_list.append({
+                'uploads': net_io.bytes_sent / (1024 * 1024),  # Convert to MB
+                'downloads': net_io.bytes_recv / (1024 * 1024)  # Convert to MB
+            })
+            time.sleep(self.interval_time)
+
     def start(self):
         """Initialize a new trace"""
+        self.tracking = True
+        self.tracking_thread = threading.Thread(target=self._track_memory_usage)
+        self.tracking_thread.start()
+        threading.Thread(target=self._track_cpu_usage).start()
+        threading.Thread(target=self._track_disk_usage).start()
+        threading.Thread(target=self._track_network_usage).start()
+
         metadata = Metadata(
             cost={},
             tokens={},
             system_info=self._get_system_info(),
-            resources=self._get_resources()
+            resources=self._get_resources(),
         )

         # Generate a unique trace ID, when trace starts
-        self.trace_id = str(uuid.uuid4())
-
+        self.trace_id = str(uuid.uuid4())
+
         # Get the start time
-        self.start_time = datetime.now().isoformat()
-
-        self.data_key = [{"start_time": datetime.now().isoformat(),
-                        "end_time": "",
-                        "spans": self.components
-                        }]
-
+        self.start_time = datetime.now().astimezone().isoformat()
+
+        self.data_key = [
+            {
+                "start_time": datetime.now().astimezone().isoformat(),
+                "end_time": "",
+                "spans": self.components,
+            }
+        ]
+
         self.trace = Trace(
             id=self.trace_id,
+            trace_name=self.trace_name,
             project_name=self.project_name,
-            start_time=datetime.now().isoformat(),
+            start_time=datetime.now().astimezone().isoformat(),
             end_time="",  # Will be set when trace is stopped
             metadata=metadata,
             data=self.data_key,
-            replays={"source": None}
+            replays={"source": None},
         )
-
+
     def stop(self):
         """Stop the trace and save to JSON file"""
-        if hasattr(self, 'trace'):
-            self.trace.data[0]["end_time"] = datetime.now().isoformat()
-            self.trace.end_time = datetime.now().isoformat()
+        if hasattr(self, "trace"):
+            self.trace.data[0]["end_time"] = datetime.now().astimezone().isoformat()
+            self.trace.end_time = datetime.now().astimezone().isoformat()
+
+        #track memory usage
+        self.tracking = False
+        if self.tracking_thread is not None:
+            self.tracking_thread.join()
+        self.trace.metadata.resources.memory.values = self.memory_usage_list
+
+        #track cpu usage
+        self.trace.metadata.resources.cpu.values = self.cpu_usage_list
+
+        #track network and disk usage
+        network_upoloads, network_downloads = 0, 0
+        disk_read, disk_write = 0, 0
+        for network_usage, disk_usage in zip(self.network_usage_list, self.disk_usage_list):
+            network_upoloads += network_usage['uploads']
+            network_downloads += network_usage['downloads']
+            disk_read += disk_usage['disk_read']
+            disk_write += disk_usage['disk_write']
+
+        #track disk usage
+        self.trace.metadata.resources.disk.read = [disk_read / len(self.disk_usage_list)]
+        self.trace.metadata.resources.disk.write = [disk_write / len(self.disk_usage_list)]
+
+        #track network usage
+        self.trace.metadata.resources.network.uploads = [network_upoloads / len(self.network_usage_list)]
+        self.trace.metadata.resources.network.downloads = [network_downloads / len(self.network_usage_list)]
+
+        # update interval time
+        self.trace.metadata.resources.cpu.interval = float(self.interval_time)
+        self.trace.metadata.resources.memory.interval = float(self.interval_time)
+        self.trace.metadata.resources.disk.interval = float(self.interval_time)
+        self.trace.metadata.resources.network.interval = float(self.interval_time)

         # Change span ids to int
         self.trace = self._change_span_ids_to_int(self.trace)
         self.trace = self._change_agent_input_output(self.trace)
         self.trace = self._extract_cost_tokens(self.trace)
-
+
         # Create traces directory if it doesn't exist
         self.traces_dir = tempfile.gettempdir()
         filename = self.trace.id + ".json"
         filepath = f"{self.traces_dir}/{filename}"

-        #get unique files and zip it. Generate a unique hash ID for the contents of the files
+        # get unique files and zip it. Generate a unique hash ID for the contents of the files
         list_of_unique_files = self.file_tracker.get_unique_files()
-        hash_id, zip_path = zip_list_of_unique_files(list_of_unique_files, output_dir=self.traces_dir)
+        hash_id, zip_path = zip_list_of_unique_files(
+            list_of_unique_files, output_dir=self.traces_dir
+        )

-        #replace source code with zip_path
+        # replace source code with zip_path
         self.trace.metadata.system_info.source_code = hash_id

         # Clean up trace_data before saving
         trace_data = self.trace.__dict__
         cleaned_trace_data = self._clean_trace(trace_data)

-        with open(filepath, 'w') as f:
+        # Format interactions and add to trace
+        interactions = self.format_interactions()
+        self.trace.workflow = interactions["workflow"]
+
+        with open(filepath, "w") as f:
             json.dump(cleaned_trace_data, f, cls=TracerJSONEncoder, indent=2)
-
-        print(f"Trace saved to {filepath}")
+
+        logger.info(" Traces saved successfully.")
+        logger.debug(f"Trace saved to {filepath}")
         # Upload traces
+
         json_file_path = str(filepath)
         project_name = self.project_name
-        project_id = self.project_id
+        project_id = self.project_id
         dataset_name = self.dataset_name
         user_detail = self.user_details
         base_url = RagaAICatalyst.BASE_URL
+
+        ## create dataset schema
+        response = create_dataset_schema_with_trace(
+            dataset_name=dataset_name, project_name=project_name
+        )
+
+        ##Upload trace metrics
+        response = upload_trace_metric(
+            json_file_path=json_file_path,
+            dataset_name=self.dataset_name,
+            project_name=self.project_name,
+        )
+
         upload_traces = UploadAgenticTraces(
             json_file_path=json_file_path,
             project_name=project_name,
             project_id=project_id,
             dataset_name=dataset_name,
             user_detail=user_detail,
-            base_url=base_url
+            base_url=base_url,
         )
         upload_traces.upload_agentic_traces()

-        #Upload Codehash
+        # Upload Codehash
         response = upload_code(
             hash_id=hash_id,
             zip_path=zip_path,
             project_name=project_name,
-            dataset_name=dataset_name
+            dataset_name=dataset_name,
         )
         print(response)
-
+
         # Cleanup
         self.components = []
         self.file_tracker.reset()
-
+
     def add_component(self, component: Component):
         """Add a component to the trace"""
         self.components.append(component)
-
+
     def __enter__(self):
         self.start()
         return self
-
+
     def __exit__(self, exc_type, exc_value, traceback):
         self.stop()

+    def _process_children(self, children_list, parent_id, current_id):
+        """Helper function to process children recursively."""
+        for child in children_list:
+            child["id"] = current_id
+            child["parent_id"] = parent_id
+            current_id += 1
+            # Recursively process nested children if they exist
+            if "children" in child["data"]:
+                current_id = self._process_children(child["data"]["children"], child["id"], current_id)
+        return current_id
+
     def _change_span_ids_to_int(self, trace):
         id, parent_id = 1, 0
         for span in trace.data[0]["spans"]:
             span.id = id
             span.parent_id = parent_id
             id += 1
-            if span.type=="agent":
-                for children in span.data["children"]:
-                    children["id"] = id
-                    children["parent_id"] = span.id
-                    id += 1
+            if span.type == "agent" and "children" in span.data:
+                id = self._process_children(span.data["children"], span.id, id)
         return trace

     def _change_agent_input_output(self, trace):
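
A small worked example of the averaging that the stop() hunk above applies: three sampled 'uploads' readings of 10, 12, and 14 MB reduce to network.uploads = [(10 + 12 + 14) / 3] = [12.0], and the same mean is taken for downloads and for disk reads and writes.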
@@ -265,20 +403,24 @@ class BaseTracer:
                 input_data = child["data"].get("input")

                 if input_data:
-                    span.data["input"] = input_data['args'] if hasattr(input_data, 'args') else input_data
+                    span.data["input"] = (
+                        input_data["args"]
+                        if hasattr(input_data, "args")
+                        else input_data
+                    )
                     break
-
+
             # Find first non-null output going backward
             for child in reversed(childrens):
                 if "data" not in child:
                     continue
                 output_data = child["data"].get("output")
-
+
                 if output_data and output_data != "" and output_data != "None":
                     span.data["output"] = output_data
                     break
         return trace
-
+
     def _extract_cost_tokens(self, trace):
         cost = {}
         tokens = {}
@@ -286,30 +428,30 @@ class BaseTracer:
             if span.type == "llm":
                 info = span.info
                 if isinstance(info, dict):
-                    cost_info = info.get('cost', {})
+                    cost_info = info.get("cost", {})
                     for key, value in cost_info.items():
                         if key not in cost:
-                            cost[key] = 0
+                            cost[key] = 0
                         cost[key] += value
-                    token_info = info.get('tokens', {})
+                    token_info = info.get("tokens", {})
                     for key, value in token_info.items():
                         if key not in tokens:
                             tokens[key] = 0
                         tokens[key] += value
             if span.type == "agent":
                 for children in span.data["children"]:
-                    if 'type' not in children:
+                    if "type" not in children:
                         continue
                     if children["type"] != "llm":
                         continue
                     info = children["info"]
                     if isinstance(info, dict):
-                        cost_info = info.get('cost', {})
+                        cost_info = info.get("cost", {})
                         for key, value in cost_info.items():
                             if key not in cost:
-                                cost[key] = 0
+                                cost[key] = 0
                             cost[key] += value
-                        token_info = info.get('tokens', {})
+                        token_info = info.get("tokens", {})
                         for key, value in token_info.items():
                             if key not in tokens:
                                 tokens[key] = 0
@@ -321,51 +463,427 @@ class BaseTracer:
     def _clean_trace(self, trace):
         # Convert span to dict if it has to_dict method
         def _to_dict_if_needed(obj):
-            if hasattr(obj, 'to_dict'):
+            if hasattr(obj, "to_dict"):
                 return obj.to_dict()
             return obj

         def deduplicate_spans(spans):
             seen_llm_spans = {}  # Dictionary to track unique LLM spans
             unique_spans = []
-
+
             for span in spans:
                 # Convert span to dictionary if needed
                 span_dict = _to_dict_if_needed(span)
-
+
                 # Skip spans without hash_id
-                if 'hash_id' not in span_dict:
+                if "hash_id" not in span_dict:
                     continue
-
-                if span_dict.get('type') == 'llm':
+
+                if span_dict.get("type") == "llm":
                     # Create a unique key based on hash_id, input, and output
                     span_key = (
-                        span_dict.get('hash_id'),
-                        str(span_dict.get('data', {}).get('input')),
-                        str(span_dict.get('data', {}).get('output'))
+                        span_dict.get("hash_id"),
+                        str(span_dict.get("data", {}).get("input")),
+                        str(span_dict.get("data", {}).get("output")),
                     )
-
+
+                    # Check if we've seen this span before
                     if span_key not in seen_llm_spans:
                         seen_llm_spans[span_key] = True
                         unique_spans.append(span)
+                    else:
+                        # If we have interactions in the current span, replace the existing one
+                        current_interactions = span_dict.get("interactions", [])
+                        if current_interactions:
+                            # Find and replace the existing span with this one that has interactions
+                            for i, existing_span in enumerate(unique_spans):
+                                existing_dict = (
+                                    existing_span
+                                    if isinstance(existing_span, dict)
+                                    else existing_span.__dict__
+                                )
+                                if (
+                                    existing_dict.get("hash_id")
+                                    == span_dict.get("hash_id")
+                                    and str(existing_dict.get("data", {}).get("input"))
+                                    == str(span_dict.get("data", {}).get("input"))
+                                    and str(existing_dict.get("data", {}).get("output"))
+                                    == str(span_dict.get("data", {}).get("output"))
+                                ):
+                                    unique_spans[i] = span
+                                    break
                 else:
                     # For non-LLM spans, process their children if they exist
-                    if 'data' in span_dict and 'children' in span_dict['data']:
-                        children = span_dict['data']['children']
+                    if "data" in span_dict and "children" in span_dict["data"]:
+                        children = span_dict["data"]["children"]
                         # Filter and deduplicate children
                         filtered_children = deduplicate_spans(children)
                         if isinstance(span, dict):
-                            span['data']['children'] = filtered_children
+                            span["data"]["children"] = filtered_children
                         else:
-                            span.data['children'] = filtered_children
+                            span.data["children"] = filtered_children
                     unique_spans.append(span)
-
+
             return unique_spans

         # Remove any spans without hash ids
-        for data in trace.get('data', []):
-            if 'spans' in data:
+        for data in trace.get("data", []):
+            if "spans" in data:
                 # First filter out spans without hash_ids, then deduplicate
-                data['spans'] = deduplicate_spans(data['spans'])
+                data["spans"] = deduplicate_spans(data["spans"])
+
+        return trace
+
+    def add_tags(self, tags: List[str]):
+        raise NotImplementedError
+
+    def _process_child_interactions(self, child, interaction_id, interactions):
+        """
+        Helper method to process child interactions recursively.
+
+        Args:
+            child (dict): The child span to process
+            interaction_id (int): Current interaction ID
+            interactions (list): List of interactions to append to
+
+        Returns:
+            int: Next interaction ID to use
+        """
+        child_type = child.get("type")

-        return trace
+        if child_type == "tool":
+            # Tool call start
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": "tool_call_start",
+                    "name": child.get("name"),
+                    "content": {
+                        "parameters": [
+                            child.get("data", {}).get("input", {}).get("args"),
+                            child.get("data", {}).get("input", {}).get("kwargs"),
+                        ]
+                    },
+                    "timestamp": child.get("start_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+            # Tool call end
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": "tool_call_end",
+                    "name": child.get("name"),
+                    "content": {
+                        "returns": child.get("data", {}).get("output"),
+                    },
+                    "timestamp": child.get("end_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+        elif child_type == "llm":
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": "llm_call_start",
+                    "name": child.get("name"),
+                    "content": {
+                        "prompt": child.get("data", {}).get("input"),
+                    },
+                    "timestamp": child.get("start_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": "llm_call_end",
+                    "name": child.get("name"),
+                    "content": {
+                        "response": child.get("data", {}).get("output")
+                    },
+                    "timestamp": child.get("end_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+        elif child_type == "agent":
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": "agent_call_start",
+                    "name": child.get("name"),
+                    "content": None,
+                    "timestamp": child.get("start_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+            # Process nested children recursively
+            if "children" in child.get("data", {}):
+                for nested_child in child["data"]["children"]:
+                    interaction_id = self._process_child_interactions(
+                        nested_child, interaction_id, interactions
+                    )
+
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": "agent_call_end",
+                    "name": child.get("name"),
+                    "content": child.get("data", {}).get("output"),
+                    "timestamp": child.get("end_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+        else:
+            interactions.append(
+                {
+                    "id": str(interaction_id),
+                    "span_id": child.get("id"),
+                    "interaction_type": child_type,
+                    "name": child.get("name"),
+                    "content": child.get("data", {}),
+                    "timestamp": child.get("start_time"),
+                    "error": child.get("error"),
+                }
+            )
+            interaction_id += 1
+
+        # Process additional interactions and network calls
+        if "interactions" in child:
+            for interaction in child["interactions"]:
+                interaction["id"] = str(interaction_id)
+                interaction["span_id"] = child.get("id")
+                interaction["error"] = None
+                interactions.append(interaction)
+                interaction_id += 1
+
+        if "network_calls" in child:
+            for child_network_call in child["network_calls"]:
+                network_call = {}
+                network_call["id"] = str(interaction_id)
+                network_call["span_id"] = child.get("id")
+                network_call["interaction_type"] = "network_call"
+                network_call["name"] = None
+                network_call["content"] = {
+                    "request": {
+                        "url": child_network_call.get("url"),
+                        "method": child_network_call.get("method"),
+                        "headers": child_network_call.get("headers"),
+                    },
+                    "response": {
+                        "status_code": child_network_call.get("status_code"),
+                        "headers": child_network_call.get("response_headers"),
+                        "body": child_network_call.get("response_body"),
+                    },
+                }
+                network_call["timestamp"] = child_network_call.get("start_time")
+                network_call["error"] = child_network_call.get("error")
+                interactions.append(network_call)
+                interaction_id += 1
+
+        return interaction_id
+
+    def format_interactions(self) -> dict:
+        """
+        Format interactions from trace data into a standardized format.
+        Returns a dictionary containing formatted interactions based on trace data.
+
+        The function processes spans from self.trace and formats them into interactions
+        of various types including: agent_start, agent_end, input, output, tool_call_start,
+        tool_call_end, llm_call, file_read, file_write, network_call.
+
+        Returns:
+            dict: A dictionary with "workflow" key containing a list of interactions
+                sorted by timestamp.
+        """
+        interactions = []
+        interaction_id = 1
+
+        if not hasattr(self, "trace") or not self.trace.data:
+            return {"workflow": []}
+
+        for span in self.trace.data[0]["spans"]:
+            # Process agent spans
+            if span.type == "agent":
+                # Add agent_start interaction
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": "agent_call_start",
+                        "name": span.name,
+                        "content": None,
+                        "timestamp": span.start_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+                # Process children of agent recursively
+                if "children" in span.data:
+                    for child in span.data["children"]:
+                        interaction_id = self._process_child_interactions(
+                            child, interaction_id, interactions
+                        )
+
+                # Add agent_end interaction
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": "agent_call_end",
+                        "name": span.name,
+                        "content": span.data.get("output"),
+                        "timestamp": span.end_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+            elif span.type == "tool":
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": "tool_call_start",
+                        "name": span.name,
+                        "content": {
+                            "prompt": span.data.get("input"),
+                            "response": span.data.get("output"),
+                        },
+                        "timestamp": span.start_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": "tool_call_end",
+                        "name": span.name,
+                        "content": {
+                            "prompt": span.data.get("input"),
+                            "response": span.data.get("output"),
+                        },
+                        "timestamp": span.end_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+            elif span.type == "llm":
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": "llm_call_start",
+                        "name": span.name,
+                        "content": {
+                            "prompt": span.data.get("input"),
+                        },
+                        "timestamp": span.start_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": "llm_call_end",
+                        "name": span.name,
+                        "content": {"response": span.data.get("output")},
+                        "timestamp": span.end_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+            else:
+                interactions.append(
+                    {
+                        "id": str(interaction_id),
+                        "span_id": span.id,
+                        "interaction_type": span.type,
+                        "name": span.name,
+                        "content": span.data,
+                        "timestamp": span.start_time,
+                        "error": span.error,
+                    }
+                )
+                interaction_id += 1
+
+            # Process interactions from span.data if they exist
+            if span.interactions:
+                for span_interaction in span.interactions:
+                    interaction = {}
+                    interaction["id"] = str(interaction_id)
+                    interaction["span_id"] = span.id
+                    interaction["interaction_type"] = span_interaction.type
+                    interaction["content"] = span_interaction.content
+                    interaction["timestamp"] = span_interaction.timestamp
+                    interaction["error"] = span.error
+                    interactions.append(interaction)
+                    interaction_id += 1
+
+            if span.network_calls:
+                for span_network_call in span.network_calls:
+                    network_call = {}
+                    network_call["id"] = str(interaction_id)
+                    network_call["span_id"] = span.id
+                    network_call["interaction_type"] = "network_call"
+                    network_call["name"] = None
+                    network_call["content"] = {
+                        "request": {
+                            "url": span_network_call.get("url"),
+                            "method": span_network_call.get("method"),
+                            "headers": span_network_call.get("headers"),
+                        },
+                        "response": {
+                            "status_code": span_network_call.get("status_code"),
+                            "headers": span_network_call.get("response_headers"),
+                            "body": span_network_call.get("response_body"),
+                        },
+                    }
+                    network_call["timestamp"] = span_network_call.get("timestamp")
+                    network_call["error"] = span_network_call.get("error")
+                    interactions.append(network_call)
+                    interaction_id += 1
+
+        # Sort interactions by timestamp
+        sorted_interactions = sorted(
+            interactions, key=lambda x: x["timestamp"] if x["timestamp"] else ""
+        )
+
+        # Reassign IDs to maintain sequential order after sorting
+        for idx, interaction in enumerate(sorted_interactions, 1):
+            interaction["id"] = str(idx)
+
+        return {"workflow": sorted_interactions}
+
+    def span(self, span_name):
+        if span_name not in self.span_attributes_dict:
+            self.span_attributes_dict[span_name] = SpanAttributes(span_name)
+        return self.span_attributes_dict[span_name]
+
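
Taken together, the base.py changes turn start()/stop() into a sampled-resource lifecycle. A minimal sketch of how a caller might drive it, assuming the user_details keys read in __init__ above; the values and the run_agent() workload are illustrative:

    # Illustrative only -- everything except BaseTracer's own attributes is invented.
    tracer = BaseTracer(
        user_details={
            "project_name": "demo-project",
            "dataset_name": "demo-dataset",
            "project_id": "proj-123",
            "trace_name": "demo-trace",
            "interval_time": 5,  # seconds between psutil samples
        }
    )

    with tracer:      # __enter__ -> start(): spawns the four sampling threads
        run_agent()   # hypothetical traced workload

    # __exit__ -> stop(): joins the memory thread, averages the disk/network
    # samples, formats interactions into trace.workflow, writes <trace_id>.json
    # to the temp directory, then uploads the dataset schema, trace metrics,
    # agentic traces, and the code zip.

Each entry in trace.workflow is an interaction dict with id, span_id, interaction_type, name, content, timestamp, and error keys, sorted by timestamp and renumbered after the sort.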