ragaai-catalyst 2.1.4b0__py3-none-any.whl → 2.1.4b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +36 -10
  2. ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +226 -76
  3. ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +568 -107
  4. ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +325 -0
  5. ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +218 -81
  6. ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +210 -58
  7. ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +2 -0
  8. ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +137 -28
  9. ragaai_catalyst/tracers/agentic_tracing/tracers/user_interaction_tracer.py +86 -0
  10. ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +9 -51
  11. ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +83 -0
  12. ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +26 -0
  13. ragaai_catalyst/tracers/agentic_tracing/utils/get_user_trace_metrics.py +28 -0
  14. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +45 -15
  15. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +2476 -2122
  16. ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +59 -0
  17. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +23 -0
  18. ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +284 -15
  19. ragaai_catalyst/tracers/tracer.py +77 -8
  20. {ragaai_catalyst-2.1.4b0.dist-info → ragaai_catalyst-2.1.4b2.dist-info}/METADATA +2 -1
  21. {ragaai_catalyst-2.1.4b0.dist-info → ragaai_catalyst-2.1.4b2.dist-info}/RECORD +23 -18
  22. {ragaai_catalyst-2.1.4b0.dist-info → ragaai_catalyst-2.1.4b2.dist-info}/WHEEL +0 -0
  23. {ragaai_catalyst-2.1.4b0.dist-info → ragaai_catalyst-2.1.4b2.dist-info}/top_level.txt +0 -0
@@ -5,22 +5,46 @@ import psutil
5
5
  import pkg_resources
6
6
  from datetime import datetime
7
7
  from pathlib import Path
8
- from typing import List
8
+ from typing import List, Any
9
9
  import uuid
10
10
  import sys
11
11
  import tempfile
12
12
  from ....ragaai_catalyst import RagaAICatalyst
13
13
  from ..data.data_structure import (
14
- Trace, Metadata, SystemInfo, OSInfo, EnvironmentInfo,
15
- Resources, CPUResource, MemoryResource, DiskResource, NetworkResource,
16
- ResourceInfo, MemoryInfo, DiskInfo, NetworkInfo,
17
- Component,
14
+ Trace,
15
+ Metadata,
16
+ SystemInfo,
17
+ OSInfo,
18
+ EnvironmentInfo,
19
+ Resources,
20
+ CPUResource,
21
+ MemoryResource,
22
+ DiskResource,
23
+ NetworkResource,
24
+ ResourceInfo,
25
+ MemoryInfo,
26
+ DiskInfo,
27
+ NetworkInfo,
28
+ Component,
18
29
  )
19
30
 
20
31
  from ..upload.upload_agentic_traces import UploadAgenticTraces
21
32
  from ..upload.upload_code import upload_code
33
+ from ..upload.upload_trace_metric import upload_trace_metric
22
34
  from ..utils.file_name_tracker import TrackName
23
35
  from ..utils.zip_list_of_unique_files import zip_list_of_unique_files
36
+ from ..utils.span_attributes import SpanAttributes
37
+ from ..utils.create_dataset_schema import create_dataset_schema_with_trace
38
+
39
+
40
+ # Configure logging to show debug messages (which includes info messages as well)
41
+ import logging
42
+
43
+ logger = logging.getLogger(__name__)
44
+ logging_level = (
45
+ logger.setLevel(logging.DEBUG) if os.getenv("DEBUG") == "1" else logging.INFO
46
+ )
47
+
24
48
 
25
49
  class TracerJSONEncoder(json.JSONEncoder):
26
50
  def default(self, obj):
@@ -28,147 +52,148 @@ class TracerJSONEncoder(json.JSONEncoder):
28
52
  return obj.isoformat()
29
53
  if isinstance(obj, bytes):
30
54
  try:
31
- return obj.decode('utf-8')
55
+ return obj.decode("utf-8")
32
56
  except UnicodeDecodeError:
33
57
  return str(obj) # Fallback to string representation
34
- if hasattr(obj, 'to_dict'): # Handle objects with to_dict method
58
+ if hasattr(obj, "to_dict"): # Handle objects with to_dict method
35
59
  return obj.to_dict()
36
- if hasattr(obj, '__dict__'):
60
+ if hasattr(obj, "__dict__"):
37
61
  # Filter out None values and handle nested serialization
38
- return {k: v for k, v in obj.__dict__.items()
39
- if v is not None and not k.startswith('_')}
62
+ return {
63
+ k: v
64
+ for k, v in obj.__dict__.items()
65
+ if v is not None and not k.startswith("_")
66
+ }
40
67
  try:
41
68
  # Try to convert to a basic type
42
69
  return str(obj)
43
70
  except:
44
71
  return None # Last resort: return None instead of failing
45
72
 
73
+
46
74
  class BaseTracer:
47
75
  def __init__(self, user_details):
48
76
  self.user_details = user_details
49
- self.project_name = self.user_details['project_name'] # Access the project_name
50
- self.dataset_name = self.user_details['dataset_name'] # Access the dataset_name
51
- self.project_id = self.user_details['project_id'] # Access the project_id
52
-
77
+ self.project_name = self.user_details["project_name"] # Access the project_name
78
+ self.dataset_name = self.user_details["dataset_name"] # Access the dataset_name
79
+ self.project_id = self.user_details["project_id"] # Access the project_id
80
+ self.trace_name = self.user_details["trace_name"] # Access the trace_name
81
+
53
82
  # Initialize trace data
54
83
  self.trace_id = None
55
- self.start_time = None
84
+ self.start_time = None
56
85
  self.components: List[Component] = []
57
86
  self.file_tracker = TrackName()
58
-
87
+ self.span_attributes_dict = {}
88
+
59
89
  def _get_system_info(self) -> SystemInfo:
60
90
  # Get OS info
61
91
  os_info = OSInfo(
62
92
  name=platform.system(),
63
93
  version=platform.version(),
64
94
  platform=platform.machine(),
65
- kernel_version=platform.release()
95
+ kernel_version=platform.release(),
66
96
  )
67
-
97
+
68
98
  # Get Python environment info
69
- installed_packages = [f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]
99
+ installed_packages = [
100
+ f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set
101
+ ]
70
102
  env_info = EnvironmentInfo(
71
103
  name="Python",
72
104
  version=platform.python_version(),
73
105
  packages=installed_packages,
74
106
  env_path=sys.prefix,
75
- command_to_run=f"python {sys.argv[0]}"
107
+ command_to_run=f"python {sys.argv[0]}",
76
108
  )
77
-
109
+
78
110
  return SystemInfo(
79
111
  id=f"sys_{self.trace_id}",
80
112
  os=os_info,
81
113
  environment=env_info,
82
- source_code="Path to the source code .zip file in format hashid.zip" # TODO: Implement source code archiving
114
+ source_code="Path to the source code .zip file in format hashid.zip", # TODO: Implement source code archiving
83
115
  )
84
-
116
+
85
117
  def _get_resources(self) -> Resources:
86
118
  # CPU info
87
119
  cpu_info = ResourceInfo(
88
120
  name=platform.processor(),
89
121
  cores=psutil.cpu_count(logical=False),
90
- threads=psutil.cpu_count(logical=True)
122
+ threads=psutil.cpu_count(logical=True),
91
123
  )
92
- cpu = CPUResource(
93
- info=cpu_info,
94
- interval="5s",
95
- values=[psutil.cpu_percent()]
96
- )
97
-
124
+ cpu = CPUResource(info=cpu_info, interval="5s", values=[psutil.cpu_percent()])
125
+
98
126
  # Memory info
99
127
  memory = psutil.virtual_memory()
100
128
  mem_info = MemoryInfo(
101
129
  total=memory.total / (1024**3), # Convert to GB
102
- free=memory.available / (1024**3)
130
+ free=memory.available / (1024**3),
103
131
  )
104
- mem = MemoryResource(
105
- info=mem_info,
106
- interval="5s",
107
- values=[memory.percent]
108
- )
109
-
132
+ mem = MemoryResource(info=mem_info, interval="5s", values=[memory.percent])
133
+
110
134
  # Disk info
111
- disk = psutil.disk_usage('/')
112
- disk_info = DiskInfo(
113
- total=disk.total / (1024**3),
114
- free=disk.free / (1024**3)
115
- )
135
+ disk = psutil.disk_usage("/")
136
+ disk_info = DiskInfo(total=disk.total / (1024**3), free=disk.free / (1024**3))
116
137
  disk_io = psutil.disk_io_counters()
117
138
  disk_resource = DiskResource(
118
139
  info=disk_info,
119
140
  interval="5s",
120
141
  read=[disk_io.read_bytes / (1024**2)], # MB
121
- write=[disk_io.write_bytes / (1024**2)]
142
+ write=[disk_io.write_bytes / (1024**2)],
122
143
  )
123
-
144
+
124
145
  # Network info
125
146
  net_io = psutil.net_io_counters()
126
147
  net_info = NetworkInfo(
127
148
  upload_speed=net_io.bytes_sent / (1024**2), # MB
128
- download_speed=net_io.bytes_recv / (1024**2)
149
+ download_speed=net_io.bytes_recv / (1024**2),
129
150
  )
130
151
  net = NetworkResource(
131
152
  info=net_info,
132
153
  interval="5s",
133
154
  uploads=[net_io.bytes_sent / (1024**2)],
134
- downloads=[net_io.bytes_recv / (1024**2)]
155
+ downloads=[net_io.bytes_recv / (1024**2)],
135
156
  )
136
-
157
+
137
158
  return Resources(cpu=cpu, memory=mem, disk=disk_resource, network=net)
138
-
159
+
139
160
  def start(self):
140
161
  """Initialize a new trace"""
141
162
  metadata = Metadata(
142
163
  cost={},
143
164
  tokens={},
144
165
  system_info=self._get_system_info(),
145
- resources=self._get_resources()
166
+ resources=self._get_resources(),
146
167
  )
147
168
 
148
169
  # Generate a unique trace ID, when trace starts
149
- self.trace_id = str(uuid.uuid4())
150
-
170
+ self.trace_id = str(uuid.uuid4())
171
+
151
172
  # Get the start time
152
173
  self.start_time = datetime.now().isoformat()
153
-
154
- self.data_key = [{"start_time": datetime.now().isoformat(),
155
- "end_time": "",
156
- "spans": self.components
157
- }]
158
-
174
+
175
+ self.data_key = [
176
+ {
177
+ "start_time": datetime.now().isoformat(),
178
+ "end_time": "",
179
+ "spans": self.components,
180
+ }
181
+ ]
182
+
159
183
  self.trace = Trace(
160
184
  id=self.trace_id,
185
+ trace_name=self.trace_name,
161
186
  project_name=self.project_name,
162
187
  start_time=datetime.now().isoformat(),
163
188
  end_time="", # Will be set when trace is stopped
164
189
  metadata=metadata,
165
190
  data=self.data_key,
166
- replays={"source": None}
191
+ replays={"source": None},
167
192
  )
168
-
193
+
169
194
  def stop(self):
170
195
  """Stop the trace and save to JSON file"""
171
- if hasattr(self, 'trace'):
196
+ if hasattr(self, "trace"):
172
197
  self.trace.data[0]["end_time"] = datetime.now().isoformat()
173
198
  self.trace.end_time = datetime.now().isoformat()
174
199
 
@@ -176,65 +201,86 @@ class BaseTracer:
176
201
  self.trace = self._change_span_ids_to_int(self.trace)
177
202
  self.trace = self._change_agent_input_output(self.trace)
178
203
  self.trace = self._extract_cost_tokens(self.trace)
179
-
204
+
180
205
  # Create traces directory if it doesn't exist
181
206
  self.traces_dir = tempfile.gettempdir()
182
207
  filename = self.trace.id + ".json"
183
208
  filepath = f"{self.traces_dir}/{filename}"
184
209
 
185
- #get unique files and zip it. Generate a unique hash ID for the contents of the files
210
+ # get unique files and zip it. Generate a unique hash ID for the contents of the files
186
211
  list_of_unique_files = self.file_tracker.get_unique_files()
187
- hash_id, zip_path = zip_list_of_unique_files(list_of_unique_files, output_dir=self.traces_dir)
212
+ hash_id, zip_path = zip_list_of_unique_files(
213
+ list_of_unique_files, output_dir=self.traces_dir
214
+ )
188
215
 
189
- #replace source code with zip_path
216
+ # replace source code with zip_path
190
217
  self.trace.metadata.system_info.source_code = hash_id
191
218
 
192
219
  # Clean up trace_data before saving
193
220
  trace_data = self.trace.__dict__
194
221
  cleaned_trace_data = self._clean_trace(trace_data)
195
222
 
196
- with open(filepath, 'w') as f:
223
+ # Format interactions and add to trace
224
+ interactions = self.format_interactions()
225
+ self.trace.workflow = interactions["workflow"]
226
+
227
+ with open(filepath, "w") as f:
197
228
  json.dump(cleaned_trace_data, f, cls=TracerJSONEncoder, indent=2)
198
-
199
- print(f"Trace saved to {filepath}")
229
+
230
+ logger.info(" Traces saved successfully.")
231
+ logger.debug(f"Trace saved to {filepath}")
200
232
  # Upload traces
233
+
201
234
  json_file_path = str(filepath)
202
235
  project_name = self.project_name
203
- project_id = self.project_id
236
+ project_id = self.project_id
204
237
  dataset_name = self.dataset_name
205
238
  user_detail = self.user_details
206
239
  base_url = RagaAICatalyst.BASE_URL
240
+
241
+ ## create dataset schema
242
+ response = create_dataset_schema_with_trace(
243
+ dataset_name=dataset_name, project_name=project_name
244
+ )
245
+
246
+ ##Upload trace metrics
247
+ response = upload_trace_metric(
248
+ json_file_path=json_file_path,
249
+ dataset_name=self.dataset_name,
250
+ project_name=self.project_name,
251
+ )
252
+
207
253
  upload_traces = UploadAgenticTraces(
208
254
  json_file_path=json_file_path,
209
255
  project_name=project_name,
210
256
  project_id=project_id,
211
257
  dataset_name=dataset_name,
212
258
  user_detail=user_detail,
213
- base_url=base_url
259
+ base_url=base_url,
214
260
  )
215
261
  upload_traces.upload_agentic_traces()
216
262
 
217
- #Upload Codehash
263
+ # Upload Codehash
218
264
  response = upload_code(
219
265
  hash_id=hash_id,
220
266
  zip_path=zip_path,
221
267
  project_name=project_name,
222
- dataset_name=dataset_name
268
+ dataset_name=dataset_name,
223
269
  )
224
270
  print(response)
225
-
271
+
226
272
  # Cleanup
227
273
  self.components = []
228
274
  self.file_tracker.reset()
229
-
275
+
230
276
  def add_component(self, component: Component):
231
277
  """Add a component to the trace"""
232
278
  self.components.append(component)
233
-
279
+
234
280
  def __enter__(self):
235
281
  self.start()
236
282
  return self
237
-
283
+
238
284
  def __exit__(self, exc_type, exc_value, traceback):
239
285
  self.stop()
240
286
 
@@ -244,7 +290,7 @@ class BaseTracer:
244
290
  span.id = id
245
291
  span.parent_id = parent_id
246
292
  id += 1
247
- if span.type=="agent":
293
+ if span.type == "agent":
248
294
  for children in span.data["children"]:
249
295
  children["id"] = id
250
296
  children["parent_id"] = span.id
@@ -265,20 +311,24 @@ class BaseTracer:
265
311
  input_data = child["data"].get("input")
266
312
 
267
313
  if input_data:
268
- span.data["input"] = input_data['args'] if hasattr(input_data, 'args') else input_data
314
+ span.data["input"] = (
315
+ input_data["args"]
316
+ if hasattr(input_data, "args")
317
+ else input_data
318
+ )
269
319
  break
270
-
320
+
271
321
  # Find first non-null output going backward
272
322
  for child in reversed(childrens):
273
323
  if "data" not in child:
274
324
  continue
275
325
  output_data = child["data"].get("output")
276
-
326
+
277
327
  if output_data and output_data != "" and output_data != "None":
278
328
  span.data["output"] = output_data
279
329
  break
280
330
  return trace
281
-
331
+
282
332
  def _extract_cost_tokens(self, trace):
283
333
  cost = {}
284
334
  tokens = {}
@@ -286,30 +336,30 @@ class BaseTracer:
286
336
  if span.type == "llm":
287
337
  info = span.info
288
338
  if isinstance(info, dict):
289
- cost_info = info.get('cost', {})
339
+ cost_info = info.get("cost", {})
290
340
  for key, value in cost_info.items():
291
341
  if key not in cost:
292
- cost[key] = 0
342
+ cost[key] = 0
293
343
  cost[key] += value
294
- token_info = info.get('tokens', {})
344
+ token_info = info.get("tokens", {})
295
345
  for key, value in token_info.items():
296
346
  if key not in tokens:
297
347
  tokens[key] = 0
298
348
  tokens[key] += value
299
349
  if span.type == "agent":
300
350
  for children in span.data["children"]:
301
- if 'type' not in children:
351
+ if "type" not in children:
302
352
  continue
303
353
  if children["type"] != "llm":
304
354
  continue
305
355
  info = children["info"]
306
356
  if isinstance(info, dict):
307
- cost_info = info.get('cost', {})
357
+ cost_info = info.get("cost", {})
308
358
  for key, value in cost_info.items():
309
359
  if key not in cost:
310
- cost[key] = 0
360
+ cost[key] = 0
311
361
  cost[key] += value
312
- token_info = info.get('tokens', {})
362
+ token_info = info.get("tokens", {})
313
363
  for key, value in token_info.items():
314
364
  if key not in tokens:
315
365
  tokens[key] = 0
@@ -321,51 +371,462 @@ class BaseTracer:
321
371
  def _clean_trace(self, trace):
322
372
  # Convert span to dict if it has to_dict method
323
373
  def _to_dict_if_needed(obj):
324
- if hasattr(obj, 'to_dict'):
374
+ if hasattr(obj, "to_dict"):
325
375
  return obj.to_dict()
326
376
  return obj
327
377
 
328
378
  def deduplicate_spans(spans):
329
379
  seen_llm_spans = {} # Dictionary to track unique LLM spans
330
380
  unique_spans = []
331
-
381
+
332
382
  for span in spans:
333
383
  # Convert span to dictionary if needed
334
384
  span_dict = _to_dict_if_needed(span)
335
-
385
+
336
386
  # Skip spans without hash_id
337
- if 'hash_id' not in span_dict:
387
+ if "hash_id" not in span_dict:
338
388
  continue
339
-
340
- if span_dict.get('type') == 'llm':
389
+
390
+ if span_dict.get("type") == "llm":
341
391
  # Create a unique key based on hash_id, input, and output
342
392
  span_key = (
343
- span_dict.get('hash_id'),
344
- str(span_dict.get('data', {}).get('input')),
345
- str(span_dict.get('data', {}).get('output'))
393
+ span_dict.get("hash_id"),
394
+ str(span_dict.get("data", {}).get("input")),
395
+ str(span_dict.get("data", {}).get("output")),
346
396
  )
347
-
397
+
398
+ # Check if we've seen this span before
348
399
  if span_key not in seen_llm_spans:
349
400
  seen_llm_spans[span_key] = True
350
401
  unique_spans.append(span)
402
+ else:
403
+ # If we have interactions in the current span, replace the existing one
404
+ current_interactions = span_dict.get("interactions", [])
405
+ if current_interactions:
406
+ # Find and replace the existing span with this one that has interactions
407
+ for i, existing_span in enumerate(unique_spans):
408
+ existing_dict = (
409
+ existing_span
410
+ if isinstance(existing_span, dict)
411
+ else existing_span.__dict__
412
+ )
413
+ if (
414
+ existing_dict.get("hash_id")
415
+ == span_dict.get("hash_id")
416
+ and str(existing_dict.get("data", {}).get("input"))
417
+ == str(span_dict.get("data", {}).get("input"))
418
+ and str(existing_dict.get("data", {}).get("output"))
419
+ == str(span_dict.get("data", {}).get("output"))
420
+ ):
421
+ unique_spans[i] = span
422
+ break
351
423
  else:
352
424
  # For non-LLM spans, process their children if they exist
353
- if 'data' in span_dict and 'children' in span_dict['data']:
354
- children = span_dict['data']['children']
425
+ if "data" in span_dict and "children" in span_dict["data"]:
426
+ children = span_dict["data"]["children"]
355
427
  # Filter and deduplicate children
356
428
  filtered_children = deduplicate_spans(children)
357
429
  if isinstance(span, dict):
358
- span['data']['children'] = filtered_children
430
+ span["data"]["children"] = filtered_children
359
431
  else:
360
- span.data['children'] = filtered_children
432
+ span.data["children"] = filtered_children
361
433
  unique_spans.append(span)
362
-
434
+
363
435
  return unique_spans
364
436
 
365
437
  # Remove any spans without hash ids
366
- for data in trace.get('data', []):
367
- if 'spans' in data:
438
+ for data in trace.get("data", []):
439
+ if "spans" in data:
368
440
  # First filter out spans without hash_ids, then deduplicate
369
- data['spans'] = deduplicate_spans(data['spans'])
370
-
371
- return trace
441
+ data["spans"] = deduplicate_spans(data["spans"])
442
+
443
+ return trace
444
+
445
+ def add_tags(self, tags: List[str]):
446
+ raise NotImplementedError
447
+
448
+ # def _add_span_attributes_to_trace(self, trace):
449
+ # if not hasattr(trace, 'data'):
450
+ # return trace
451
+ # for data in trace.data:
452
+ # for span in data.get('spans', []):
453
+ # if not hasattr(span, 'name'):
454
+ # continue
455
+ # span_name = span.name
456
+ # if span_name in self.span_attributes_dict:
457
+ # span_attributes = self.span_attributes_dict[span_name]
458
+ # span = self._add_span_attributes_to_span(span_attributes, span)
459
+ # if hasattr(span, 'type'):
460
+ # if span.type == 'agent':
461
+ # if hasattr(span, 'data'):
462
+ # if 'children' in span.data:
463
+ # span.data['children'] = self._add_span_attributes_to_children(span_attributes, span.data['children'])
464
+
465
+ # return trace
466
+
467
+ # def _add_span_attributes_to_children(self, span_attributes: SpanAttributes, children):
468
+ # attributed_children = []
469
+ # for child in children:
470
+ # if 'name' not in child:
471
+ # continue
472
+ # child_name = child['name']
473
+ # if child_name in self.span_attributes_dict:
474
+ # span_attributes = self.span_attributes_dict[child_name]
475
+ # child = self._add_span_attributes_to_span(span_attributes, child)
476
+ # if 'type' in child:
477
+ # if child['type'] == 'agent':
478
+ # if 'data' in child:
479
+ # if 'children' in child['data']:
480
+ # child['data']['children'] = self._add_span_attributes_to_children(span_attributes, child['data']['children'])
481
+ # attributed_children.append(child)
482
+ # return attributed_children
483
+
484
+ # def _add_span_attributes_to_span(self, span_attributes: SpanAttributes, span):
485
+ # metadata = {
486
+ # 'tags': span_attributes.tags,
487
+ # 'user_metadata': span_attributes.metadata
488
+ # }
489
+ # metrics = span_attributes.metrics
490
+ # feedback = span_attributes.feedback
491
+ # if isinstance(span, dict):
492
+ # span['metadata'] = metadata
493
+ # span['metrics'] = metrics
494
+ # span['feedback'] = feedback
495
+ # else:
496
+ # span.metadata = metadata
497
+ # span.metrics = metrics
498
+ # span.feedback = feedback
499
+ # return span
500
+
501
+ def span(self, span_name):
502
+ if span_name not in self.span_attributes_dict:
503
+ self.span_attributes_dict[span_name] = SpanAttributes(span_name)
504
+ return self.span_attributes_dict[span_name]
505
+
506
+ def format_interactions(self) -> dict:
507
+ """
508
+ Format interactions from trace data into a standardized format.
509
+ Returns a dictionary containing formatted interactions based on trace data.
510
+
511
+ The function processes spans from self.trace and formats them into interactions
512
+ of various types including: agent_start, agent_end, input, output, tool_call_start,
513
+ tool_call_end, llm_call, file_read, file_write, network_call.
514
+
515
+ Returns:
516
+ dict: A dictionary with "interactions" key containing a list of interactions
517
+ sorted by timestamp.
518
+ """
519
+ interactions = []
520
+ interaction_id = 1
521
+
522
+ if not hasattr(self, "trace") or not self.trace.data:
523
+ return {"interactions": []}
524
+
525
+ for span in self.trace.data[0]["spans"]:
526
+ # Process agent spans
527
+ if span.type == "agent":
528
+ # Add agent_start interaction
529
+ interactions.append(
530
+ {
531
+ "id": str(interaction_id),
532
+ "span_id": span.id,
533
+ "interaction_type": "agent_call_start",
534
+ "name": span.name,
535
+ "content": None,
536
+ "timestamp": span.start_time,
537
+ "error": span.error,
538
+ }
539
+ )
540
+ interaction_id += 1
541
+
542
+ # Process children of agent
543
+ if "children" in span.data:
544
+ for child in span.data["children"]:
545
+ child_type = child.get("type")
546
+ if child_type == "tool":
547
+ # Tool call start
548
+ interactions.append(
549
+ {
550
+ "id": str(interaction_id),
551
+ "span_id": child.get("id"),
552
+ "interaction_type": "tool_call_start",
553
+ "name": child.get("name"),
554
+ "content": {
555
+ "parameters": [
556
+ child.get("data", {})
557
+ .get("input")
558
+ .get("args"),
559
+ child.get("data", {})
560
+ .get("input")
561
+ .get("kwargs"),
562
+ ]
563
+ },
564
+ "timestamp": child.get("start_time"),
565
+ "error": child.get("error"),
566
+ }
567
+ )
568
+ interaction_id += 1
569
+
570
+ # Tool call end
571
+ interactions.append(
572
+ {
573
+ "id": str(interaction_id),
574
+ "span_id": child.get("id"),
575
+ "interaction_type": "tool_call_end",
576
+ "name": child.get("name"),
577
+ "content": {
578
+ "returns": child.get("data", {}).get("output"),
579
+ },
580
+ "timestamp": child.get("end_time"),
581
+ "error": child.get("error"),
582
+ }
583
+ )
584
+ interaction_id += 1
585
+
586
+ elif child_type == "llm":
587
+ interactions.append(
588
+ {
589
+ "id": str(interaction_id),
590
+ "span_id": child.get("id"),
591
+ "interaction_type": "llm_call_start",
592
+ "name": child.get("name"),
593
+ "content": {
594
+ "prompt": child.get("data", {}).get("input"),
595
+ },
596
+ "timestamp": child.get("start_time"),
597
+ "error": child.get("error"),
598
+ }
599
+ )
600
+ interaction_id += 1
601
+
602
+ interactions.append(
603
+ {
604
+ "id": str(interaction_id),
605
+ "span_id": child.get("id"),
606
+ "interaction_type": "llm_call_end",
607
+ "name": child.get("name"),
608
+ "content": {
609
+ "response": child.get("data", {}).get("output")
610
+ },
611
+ "timestamp": child.get("end_time"),
612
+ "error": child.get("error"),
613
+ }
614
+ )
615
+ interaction_id += 1
616
+
617
+ elif child_type == "agent":
618
+ interactions.append(
619
+ {
620
+ "id": str(interaction_id),
621
+ "span_id": child.get("id"),
622
+ "interaction_type": "agent_call_start",
623
+ "name": child.get("name"),
624
+ "content": None,
625
+ "timestamp": child.get("start_time"),
626
+ "error": child.get("error"),
627
+ }
628
+ )
629
+ interaction_id += 1
630
+
631
+ interactions.append(
632
+ {
633
+ "id": str(interaction_id),
634
+ "span_id": child.get("id"),
635
+ "interaction_type": "agent_call_end",
636
+ "name": child.get("name"),
637
+ "content": child.get("data", {}).get("output"),
638
+ "timestamp": child.get("end_time"),
639
+ "error": child.get("error"),
640
+ }
641
+ )
642
+ interaction_id += 1
643
+
644
+ else:
645
+ interactions.append(
646
+ {
647
+ "id": str(interaction_id),
648
+ "span_id": child.get("id"),
649
+ "interaction_type": child_type,
650
+ "name": child.get("name"),
651
+ "content": child.get("data", {}),
652
+ "timestamp": child.get("start_time"),
653
+ "error": child.get("error"),
654
+ }
655
+ )
656
+ interaction_id += 1
657
+
658
+ if "interactions" in child:
659
+ for interaction in child["interactions"]:
660
+ interaction["id"] = str(interaction_id)
661
+ interaction["span_id"] = child.get("id")
662
+ interaction["error"] = None
663
+ interactions.append(interaction)
664
+ interaction_id += 1
665
+
666
+ if "network_calls" in child:
667
+ for child_network_call in child["network_calls"]:
668
+ network_call = {}
669
+ network_call["id"] = str(interaction_id)
670
+ network_call["span_id"] = child.get("id")
671
+ network_call["interaction_type"] = "network_call"
672
+ network_call["name"] = None
673
+ network_call["content"] = {
674
+ "request": {
675
+ "url": child_network_call.get("url"),
676
+ "method": child_network_call.get("method"),
677
+ "headers": child_network_call.get("headers"),
678
+ },
679
+ "response": {
680
+ "status_code": child_network_call.get(
681
+ "status_code"
682
+ ),
683
+ "headers": child_network_call.get(
684
+ "response_headers"
685
+ ),
686
+ "body": child_network_call.get("response_body"),
687
+ },
688
+ }
689
+ network_call["timestamp"] = child_network_call[
690
+ "start_time"
691
+ ]
692
+ network_call["error"] = child_network_call.get("error")
693
+ interactions.append(network_call)
694
+ interaction_id += 1
695
+
696
+ # Add agent_end interaction
697
+ interactions.append(
698
+ {
699
+ "id": str(interaction_id),
700
+ "span_id": span.id,
701
+ "interaction_type": "agent_call_end",
702
+ "name": span.name,
703
+ "content": span.data.get("output"),
704
+ "timestamp": span.end_time,
705
+ "error": span.error,
706
+ }
707
+ )
708
+ interaction_id += 1
709
+
710
+ elif span.type == "tool":
711
+ interactions.append(
712
+ {
713
+ "id": str(interaction_id),
714
+ "span_id": span.id,
715
+ "interaction_type": "tool_call_start",
716
+ "name": span.name,
717
+ "content": {
718
+ "prompt": span.data.get("input"),
719
+ "response": span.data.get("output"),
720
+ },
721
+ "timestamp": span.start_time,
722
+ "error": span.error,
723
+ }
724
+ )
725
+ interaction_id += 1
726
+
727
+ interactions.append(
728
+ {
729
+ "id": str(interaction_id),
730
+ "span_id": span.id,
731
+ "interaction_type": "tool_call_end",
732
+ "name": span.name,
733
+ "content": {
734
+ "prompt": span.data.get("input"),
735
+ "response": span.data.get("output"),
736
+ },
737
+ "timestamp": span.end_time,
738
+ "error": span.error,
739
+ }
740
+ )
741
+ interaction_id += 1
742
+
743
+ elif span.type == "llm":
744
+ interactions.append(
745
+ {
746
+ "id": str(interaction_id),
747
+ "span_id": span.id,
748
+ "interaction_type": "llm_call_start",
749
+ "name": span.name,
750
+ "content": {
751
+ "prompt": span.data.get("input"),
752
+ },
753
+ "timestamp": span.start_time,
754
+ "error": span.error,
755
+ }
756
+ )
757
+ interaction_id += 1
758
+
759
+ interactions.append(
760
+ {
761
+ "id": str(interaction_id),
762
+ "span_id": span.id,
763
+ "interaction_type": "llm_call_end",
764
+ "name": span.name,
765
+ "content": {"response": span.data.get("output")},
766
+ "timestamp": span.end_time,
767
+ "error": span.error,
768
+ }
769
+ )
770
+ interaction_id += 1
771
+
772
+ else:
773
+ interactions.append(
774
+ {
775
+ "id": str(interaction_id),
776
+ "span_id": span.id,
777
+ "interaction_type": span.type,
778
+ "name": span.name,
779
+ "content": span.data,
780
+ "timestamp": span.start_time,
781
+ "error": span.error,
782
+ }
783
+ )
784
+ interaction_id += 1
785
+
786
+ # Process interactions from span.data if they exist
787
+ if span.interactions:
788
+ for span_interaction in span.interactions:
789
+ interaction = {}
790
+ interaction["id"] = str(interaction_id)
791
+ interaction["span_id"] = span.id
792
+ interaction["interaction_type"] = span_interaction.type
793
+ interaction["content"] = span_interaction.content
794
+ interaction["timestamp"] = span_interaction.timestamp
795
+ interaction["error"] = span.error
796
+ interactions.append(interaction)
797
+ interaction_id += 1
798
+
799
+ if span.network_calls:
800
+ for span_network_call in span.network_calls:
801
+ network_call = {}
802
+ network_call["id"] = str(interaction_id)
803
+ network_call["span_id"] = span.id
804
+ network_call["interaction_type"] = "network_call"
805
+ network_call["name"] = None
806
+ network_call["content"] = {
807
+ "request": {
808
+ "url": span_network_call.get("url"),
809
+ "method": span_network_call.get("method"),
810
+ "headers": span_network_call.get("headers"),
811
+ },
812
+ "response": {
813
+ "status_code": span_network_call.get("status_code"),
814
+ "headers": span_network_call.get("response_headers"),
815
+ "body": span_network_call.get("response_body"),
816
+ },
817
+ }
818
+ network_call["timestamp"] = span_network_call.get("timestamp")
819
+ network_call["error"] = span_network_call.get("error")
820
+ interactions.append(network_call)
821
+ interaction_id += 1
822
+
823
+ # Sort interactions by timestamp
824
+ sorted_interactions = sorted(
825
+ interactions, key=lambda x: x["timestamp"] if x["timestamp"] else ""
826
+ )
827
+
828
+ # Reassign IDs to maintain sequential order after sorting
829
+ for idx, interaction in enumerate(sorted_interactions, 1):
830
+ interaction["id"] = str(idx)
831
+
832
+ return {"workflow": sorted_interactions}