ragaai-catalyst 2.1.4.1b1__py3-none-any.whl → 2.1.5b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
  from .experiment import Experiment
  from .ragaai_catalyst import RagaAICatalyst
- from .tracers import Tracer
  from .utils import response_checker
  from .dataset import Dataset
  from .prompt_manager import PromptManager
@@ -8,6 +7,24 @@ from .evaluation import Evaluation
  from .synthetic_data_generation import SyntheticDataGeneration
  from .guardrails_manager import GuardrailsManager
  from .guard_executor import GuardExecutor
+ from .tracers import Tracer, init_tracing, trace_agent, trace_llm, trace_tool, current_span, trace_custom
 
 
- __all__ = ["Experiment", "RagaAICatalyst", "Tracer", "PromptManager", "Evaluation","SyntheticDataGeneration", "GuardrailsManager"]
+
+
+ __all__ = [
+     "Experiment",
+     "RagaAICatalyst",
+     "Tracer",
+     "PromptManager",
+     "Evaluation",
+     "SyntheticDataGeneration",
+     "GuardrailsManager",
+     "GuardExecutor",
+     "init_tracing",
+     "trace_agent",
+     "trace_llm",
+     "trace_tool",
+     "current_span",
+     "trace_custom"
+ ]
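For orientation, the distributed-tracing helpers newly exported above are typically wired up once and then used as decorators. The snippet below is a minimal sketch under assumed names: the constructor arguments, decorator parameters, and placeholder values are illustrative and are not taken from this diff.

from ragaai_catalyst import RagaAICatalyst, Tracer, init_tracing, trace_llm, trace_tool, current_span

# Credentials and project/dataset names below are placeholders.
catalyst = RagaAICatalyst(access_key="...", secret_key="...")
tracer = Tracer(project_name="my_project", dataset_name="my_dataset")
init_tracing(catalyst=catalyst, tracer=tracer)

@trace_tool(name="vector_search")
def search(query):
    return ["doc-1", "doc-2"]

@trace_llm(name="generate_answer")
def generate_answer(prompt):
    span = current_span()  # handle to the active span, e.g. for attaching metadata
    return "answer"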
@@ -1,4 +1,5 @@
  import os
+ import json
  import requests
  from .utils import response_checker
  from typing import Union
@@ -271,3 +272,332 @@ class Dataset:
          except Exception as e:
              logger.error(f"Error in create_from_csv: {e}")
              raise
+
+     def add_rows(self, csv_path, dataset_name):
+         """
+         Add rows to an existing dataset from a CSV file.
+
+         Args:
+             csv_path (str): Path to the CSV file to be added
+             dataset_name (str): Name of the existing dataset to add rows to
+
+         Raises:
+             ValueError: If dataset does not exist or columns are incompatible
+         """
+         # Get existing dataset columns
+         existing_columns = self.get_dataset_columns(dataset_name)
+
+         # Read the CSV file to check columns
+         try:
+             import pandas as pd
+             df = pd.read_csv(csv_path)
+             csv_columns = df.columns.tolist()
+         except Exception as e:
+             logger.error(f"Failed to read CSV file: {e}")
+             raise ValueError(f"Unable to read CSV file: {e}")
+
+         # Check column compatibility
+         for column in existing_columns:
+             if column not in csv_columns:
+                 df[column] = None
+
+         # Get presigned URL for the CSV
+         def get_presignedUrl():
+             headers = {
+                 "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+                 "X-Project-Id": str(self.project_id),
+             }
+             try:
+                 response = requests.get(
+                     f"{Dataset.BASE_URL}/v2/llm/dataset/csv/presigned-url",
+                     headers=headers,
+                     timeout=Dataset.TIMEOUT,
+                 )
+                 response.raise_for_status()
+                 return response.json()
+             except requests.exceptions.RequestException as e:
+                 logger.error(f"Failed to get presigned URL: {e}")
+                 raise
+
+         try:
+             presignedUrl = get_presignedUrl()
+             if presignedUrl['success']:
+                 url = presignedUrl['data']['presignedUrl']
+                 filename = presignedUrl['data']['fileName']
+             else:
+                 raise ValueError('Unable to fetch presignedUrl')
+         except Exception as e:
+             logger.error(f"Error in get_presignedUrl: {e}")
+             raise
+
+         # Upload CSV to presigned URL
+         def put_csv_to_presignedUrl(url):
+             headers = {
+                 'Content-Type': 'text/csv',
+                 'x-ms-blob-type': 'BlockBlob',
+             }
+             try:
+                 with open(csv_path, 'rb') as file:
+                     response = requests.put(
+                         url,
+                         headers=headers,
+                         data=file,
+                         timeout=Dataset.TIMEOUT,
+                     )
+                     response.raise_for_status()
+                     return response
+             except requests.exceptions.RequestException as e:
+                 logger.error(f"Failed to put CSV to presigned URL: {e}")
+                 raise
+
+         try:
+             put_csv_response = put_csv_to_presignedUrl(url)
+             if put_csv_response.status_code not in (200, 201):
+                 raise ValueError('Unable to put csv to the presignedUrl')
+         except Exception as e:
+             logger.error(f"Error in put_csv_to_presignedUrl: {e}")
+             raise
+
+         # Prepare schema mapping (assuming same mapping as original dataset)
+         def generate_schema_mapping(dataset_name):
+             headers = {
+                 'Content-Type': 'application/json',
+                 "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+                 "X-Project-Id": str(self.project_id),
+             }
+             json_data = {
+                 "size": 12,
+                 "page": "0",
+                 "projectId": str(self.project_id),
+                 "search": ""
+             }
+             try:
+                 # First get dataset details
+                 response = requests.post(
+                     f"{Dataset.BASE_URL}/v2/llm/dataset",
+                     headers=headers,
+                     json=json_data,
+                     timeout=Dataset.TIMEOUT,
+                 )
+                 response.raise_for_status()
+                 datasets = response.json()["data"]["content"]
+                 dataset_id = [dataset["id"] for dataset in datasets if dataset["name"]==dataset_name][0]
+
+                 # Get dataset details to extract schema mapping
+                 response = requests.get(
+                     f"{Dataset.BASE_URL}/v2/llm/dataset/{dataset_id}?initialCols=0",
+                     headers=headers,
+                     timeout=Dataset.TIMEOUT,
+                 )
+                 response.raise_for_status()
+
+                 # Extract schema mapping
+                 schema_mapping = {}
+                 for col in response.json()["data"]["datasetColumnsResponses"]:
+                     schema_mapping[col["displayName"]] = {"columnType": col["columnType"]}
+
+                 return schema_mapping
+             except requests.exceptions.RequestException as e:
+                 logger.error(f"Failed to get schema mapping: {e}")
+                 raise
+
+         # Upload CSV to elastic
+         try:
+             schema_mapping = generate_schema_mapping(dataset_name)
+
+             data = {
+                 "projectId": str(self.project_id),
+                 "datasetName": dataset_name,
+                 "fileName": filename,
+                 "schemaMapping": schema_mapping,
+                 "opType": "update", # Use update for adding rows
+                 "description": "Adding new rows to dataset"
+             }
+
+             headers = {
+                 'Content-Type': 'application/json',
+                 'Authorization': f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+                 "X-Project-Id": str(self.project_id)
+             }
+
+             response = requests.post(
+                 f"{Dataset.BASE_URL}/v2/llm/dataset/csv",
+                 headers=headers,
+                 json=data,
+                 timeout=Dataset.TIMEOUT,
+             )
+
+             if response.status_code == 400:
+                 raise ValueError(response.json().get("message", "Failed to add rows"))
+
+             response.raise_for_status()
+
+             # Check response
+             response_data = response.json()
+             if not response_data.get('success', False):
+                 raise ValueError(response_data.get('message', 'Unknown error occurred'))
+
+             print(f"Successfully added rows to dataset {dataset_name}")
+             return response_data
+
+         except Exception as e:
+             logger.error(f"Error in add_rows_to_dataset: {e}")
+             raise
+
+     def add_columns(self,text_fields,dataset_name, column_name, provider, model,variables={}):
+         """
+         Add a column to a dataset with dynamically fetched model parameters
+
+         Args:
+             project_id (int): Project ID
+             dataset_id (int): Dataset ID
+             column_name (str): Name of the new column
+             provider (str): Name of the model provider
+             model (str): Name of the model
+         """
+         # First, get model parameters
+
+         # Validate text_fields input
+         if not isinstance(text_fields, list):
+             raise ValueError("text_fields must be a list of dictionaries")
+
+         for field in text_fields:
+             if not isinstance(field, dict) or 'role' not in field or 'content' not in field:
+                 raise ValueError("Each text field must be a dictionary with 'role' and 'content' keys")
+
+         # First, get the dataset ID
+         headers = {
+             'Content-Type': 'application/json',
+             "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+             "X-Project-Id": str(self.project_id),
+         }
+         json_data = {"size": 12, "page": "0", "projectId": str(self.project_id), "search": ""}
+
+         try:
+             # Get dataset list
+             response = requests.post(
+                 f"{Dataset.BASE_URL}/v2/llm/dataset",
+                 headers=headers,
+                 json=json_data,
+                 timeout=Dataset.TIMEOUT,
+             )
+             response.raise_for_status()
+             datasets = response.json()["data"]["content"]
+
+             # Find dataset ID
+             dataset_id = next((dataset["id"] for dataset in datasets if dataset["name"] == dataset_name), None)
+
+             if dataset_id is None:
+                 raise ValueError(f"Dataset {dataset_name} not found")
+
+
+
+             parameters_url= f"{Dataset.BASE_URL}/playground/providers/models/parameters/list"
+
+             headers = {
+                 'Content-Type': 'application/json',
+                 "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+                 "X-Project-Id": str(self.project_id),
+             }
+
+             # Fetch model parameters
+             parameters_payload = {
+                 "providerName": provider,
+                 "modelName": model
+             }
+
+             # Get model parameters
+             params_response = requests.post(
+                 parameters_url,
+                 headers=headers,
+                 json=parameters_payload,
+                 timeout=30
+             )
+             params_response.raise_for_status()
+
+             # Extract parameters
+             all_parameters = params_response.json().get('data', [])
+
+             # Filter and transform parameters for add-column API
+             formatted_parameters = []
+             for param in all_parameters:
+                 value = param.get('value')
+                 param_type = param.get('type')
+
+                 if value is None:
+                     formatted_param = {
+                         "name": param.get('name'),
+                         "value": None, # Pass None if the value is null
+                         "type": param.get('type')
+                     }
+                 else:
+                     # Improved type handling
+                     if param_type == "float":
+                         value = float(value) # Ensure value is converted to float
+                     elif param_type == "int":
+                         value = int(value) # Ensure value is converted to int
+                     elif param_type == "bool":
+                         value = bool(value) # Ensure value is converted to bool
+                     elif param_type == "string":
+                         value = str(value) # Ensure value is converted to string
+                     else:
+                         raise ValueError(f"Unsupported parameter type: {param_type}") # Handle unsupported types
+
+                     formatted_param = {
+                         "name": param.get('name'),
+                         "value": value,
+                         "type": param.get('type')
+                     }
+                 formatted_parameters.append(formatted_param)
+             dataset_id = next((dataset["id"] for dataset in datasets if dataset["name"] == dataset_name), None)
+
+             # Prepare payload for add column API
+             add_column_payload = {
+                 "rowFilterList": [],
+                 "columnName": column_name,
+                 "datasetId": dataset_id,
+                 "variables": variables,
+                 "promptTemplate": {
+                     "textFields": text_fields,
+                     "modelSpecs": {
+                         "model": f"{provider}/{model}",
+                         "parameters": formatted_parameters
+                     }
+                 }
+             }
+             if variables:
+                 variable_specs = []
+                 for key, values in variables.items():
+                     variable_specs.append({
+                         "name": key,
+                         "type": "string",
+                         "schema": values
+                     })
+                 add_column_payload["promptTemplate"]["variableSpecs"] = variable_specs
+
+             # Make API call to add column
+             add_column_url = f"{Dataset.BASE_URL}/v2/llm/dataset/add-column"
+
+             response = requests.post(
+                 add_column_url,
+                 headers={
+                     'Content-Type': 'application/json',
+                     'Authorization': f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
+                     "X-Project-Id": str(self.project_id)
+                 },
+                 json=add_column_payload,
+                 timeout=30
+             )
+
+             # Check response
+             response.raise_for_status()
+             response_data = response.json()
+
+             print("Column added successfully:")
+             print(json.dumps(response_data, indent=2))
+             return response_data
+
+         except requests.exceptions.RequestException as e:
+             print(f"Error adding column: {e}")
+             raise
+
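The two methods added above extend the public Dataset API. A hedged usage sketch follows: the method signatures come from the hunk above, while the Dataset constructor argument, provider/model values, and the shape of the variables mapping are assumptions, and authentication (RAGAAI_CATALYST_TOKEN via RagaAICatalyst) is assumed to be configured already.

from ragaai_catalyst import Dataset

dataset_manager = Dataset(project_name="my_project")  # constructor argument assumed

# Append rows from a local CSV to an existing dataset
dataset_manager.add_rows(csv_path="new_rows.csv", dataset_name="my_dataset")

# Add a model-generated column; text_fields are chat-style role/content dicts
dataset_manager.add_columns(
    text_fields=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarise the following: {{context}}"},
    ],
    dataset_name="my_dataset",
    column_name="summary",
    provider="openai",                 # placeholder provider
    model="gpt-4o-mini",               # placeholder model
    variables={"context": "context"},  # prompt-variable mapping; exact shape assumed
)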
@@ -1,3 +1,19 @@
  from .tracer import Tracer
+ from .distributed import (
+     init_tracing,
+     trace_agent,
+     trace_llm,
+     trace_tool,
+     current_span,
+     trace_custom,
+ )
 
- __all__ = ["Tracer"]
+ __all__ = [
+     "Tracer",
+     "init_tracing",
+     "trace_agent",
+     "trace_llm",
+     "trace_tool",
+     "current_span",
+     "trace_custom"
+ ]
@@ -44,6 +44,7 @@ class AgentTracerMixin:
          # Add auto instrument flags
          self.auto_instrument_agent = False
          self.auto_instrument_user_interaction = False
+         self.auto_instrument_file_io = False
          self.auto_instrument_network = False
 
      def trace_agent(
@@ -512,10 +513,22 @@ class AgentTracerMixin:
          network_calls = self.component_network_calls.get(kwargs["component_id"], [])
          interactions = []
          if self.auto_instrument_user_interaction:
-             interactions = self.component_user_interaction.get(
-                 kwargs["component_id"], []
-             )
-         start_time = kwargs["start_time"]
+             input_output_interactions = []
+             for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
+                 if interaction["interaction_type"] in ["input", "output"]:
+                     input_output_interactions.append(interaction)
+             interactions.extend(input_output_interactions)
+         if self.auto_instrument_file_io:
+             file_io_interactions = []
+             for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
+                 if interaction["interaction_type"] in ["file_read", "file_write"]:
+                     file_io_interactions.append(interaction)
+             interactions.extend(file_io_interactions)
+
+         # Get start time
+         start_time = None
+         if "start_time" in kwargs:
+             start_time = kwargs["start_time"]
 
          # Get tags, metrics
          name = kwargs["name"]
@@ -621,3 +634,6 @@ class AgentTracerMixin:
 
      def instrument_network_calls(self):
          self.auto_instrument_network = True
+
+     def instrument_file_io_calls(self):
+         self.auto_instrument_file_io = True
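The new auto_instrument_file_io flag mirrors the existing user-interaction flag: once instrument_file_io_calls() is invoked, file_read/file_write interactions are attached to the agent component alongside (or instead of) plain input/output interactions. The standalone snippet below restates that selection logic in simplified form; it is an illustration, not the package's code.

def select_interactions(all_interactions, include_io=True, include_file_io=False):
    # Pick interaction records by their "interaction_type", as the mixin above does
    selected = []
    if include_io:
        selected.extend(i for i in all_interactions if i["interaction_type"] in ["input", "output"])
    if include_file_io:
        selected.extend(i for i in all_interactions if i["interaction_type"] in ["file_read", "file_write"])
    return selected

sample = [
    {"interaction_type": "input", "content": "user query"},
    {"interaction_type": "file_read", "content": "config.yaml"},
]
print(select_interactions(sample, include_file_io=True))  # both records are kept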
@@ -308,15 +308,14 @@ class BaseTracer:
 
          # Add metrics to trace before saving
          trace_data = self.trace.to_dict()
-
          trace_data["metrics"] = self.trace_metrics
-
+
          # Clean up trace_data before saving
          cleaned_trace_data = self._clean_trace(trace_data)
 
          # Format interactions and add to trace
          interactions = self.format_interactions()
-         self.trace.workflow = interactions["workflow"]
+         trace_data["workflow"] = interactions["workflow"]
 
          with open(filepath, "w") as f:
              json.dump(cleaned_trace_data, f, cls=TracerJSONEncoder, indent=2)
@@ -434,38 +433,44 @@ class BaseTracer:
      def _extract_cost_tokens(self, trace):
          cost = {}
          tokens = {}
-         for span in trace.data[0]["spans"]:
-             if span.type == "llm":
-                 info = span.info
-                 if isinstance(info, dict):
-                     cost_info = info.get("cost", {})
-                     for key, value in cost_info.items():
-                         if key not in cost:
-                             cost[key] = 0
-                         cost[key] += value
-                     token_info = info.get("tokens", {})
-                     for key, value in token_info.items():
-                         if key not in tokens:
-                             tokens[key] = 0
-                         tokens[key] += value
-             if span.type == "agent":
-                 for children in span.data["children"]:
-                     if "type" not in children:
-                         continue
-                     if children["type"] != "llm":
-                         continue
-                     info = children["info"]
-                     if isinstance(info, dict):
-                         cost_info = info.get("cost", {})
-                         for key, value in cost_info.items():
-                             if key not in cost:
-                                 cost[key] = 0
-                             cost[key] += value
-                         token_info = info.get("tokens", {})
-                         for key, value in token_info.items():
-                             if key not in tokens:
-                                 tokens[key] = 0
-                             tokens[key] += value
+
+         def process_span_info(info):
+             if not isinstance(info, dict):
+                 return
+             cost_info = info.get("cost", {})
+             for key, value in cost_info.items():
+                 if key not in cost:
+                     cost[key] = 0
+                 cost[key] += value
+             token_info = info.get("tokens", {})
+             for key, value in token_info.items():
+                 if key not in tokens:
+                     tokens[key] = 0
+                 tokens[key] += value
+
+         def process_spans(spans):
+             for span in spans:
+                 # Get span type, handling both span objects and dictionaries
+                 span_type = span.type if hasattr(span, 'type') else span.get('type')
+                 span_info = span.info if hasattr(span, 'info') else span.get('info', {})
+                 span_data = span.data if hasattr(span, 'data') else span.get('data', {})
+
+                 # Process direct LLM spans
+                 if span_type == "llm":
+                     process_span_info(span_info)
+                 # Process agent spans recursively
+                 elif span_type == "agent":
+                     # Process LLM children in the current agent span
+                     children = span_data.get("children", [])
+                     for child in children:
+                         child_type = child.get("type")
+                         if child_type == "llm":
+                             process_span_info(child.get("info", {}))
+                         # Recursively process nested agent spans
+                         elif child_type == "agent":
+                             process_spans([child])
+
+         process_spans(trace.data[0]["spans"])
          trace.metadata.cost = cost
          trace.metadata.tokens = tokens
          return trace
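The rewritten _extract_cost_tokens walks spans recursively, so LLM children of nested agent spans now contribute to the trace-level cost and token totals; the old flat loop only looked one level deep. A toy trace (plain dicts, values made up) illustrates the effect:

# Toy input for the recursive aggregation above; with the old single-level loop the
# inner LLM span would be missed, with the recursive walk both LLM spans are counted.
spans = [
    {
        "type": "agent",
        "data": {
            "children": [
                {"type": "llm", "info": {"cost": {"total_cost": 0.5}, "tokens": {"total_tokens": 120}}},
                {
                    "type": "agent",
                    "data": {
                        "children": [
                            {"type": "llm", "info": {"cost": {"total_cost": 0.25}, "tokens": {"total_tokens": 80}}},
                        ]
                    },
                },
            ]
        },
    }
]
# Expected aggregate: cost -> {"total_cost": 0.75}, tokens -> {"total_tokens": 200}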
@@ -665,7 +670,7 @@ class BaseTracer:
                      {
                          "id": str(interaction_id),
                          "span_id": child.get("id"),
-                         "interaction_type": child_type,
+                         "interaction_type": f"{child_type}_call_start",
                          "name": child.get("name"),
                          "content": child.get("data", {}),
                          "timestamp": child.get("start_time"),
@@ -673,15 +678,29 @@ class BaseTracer:
                      }
                  )
                  interaction_id += 1
+
+                 interactions.append(
+                     {
+                         "id": str(interaction_id),
+                         "span_id": child.get("id"),
+                         "interaction_type": f"{child_type}_call_end",
+                         "name": child.get("name"),
+                         "content": child.get("data", {}),
+                         "timestamp": child.get("end_time"),
+                         "error": child.get("error"),
+                     }
+                 )
+                 interaction_id += 1
 
                  # Process additional interactions and network calls
                  if "interactions" in child:
                      for interaction in child["interactions"]:
-                         interaction["id"] = str(interaction_id)
-                         interaction["span_id"] = child.get("id")
-                         interaction["error"] = None
-                         interactions.append(interaction)
-                         interaction_id += 1
+                         if interaction!=[]:
+                             interaction["id"] = str(interaction_id)
+                             interaction["span_id"] = child.get("id")
+                             interaction["error"] = None
+                             interactions.append(interaction)
+                             interaction_id += 1
 
 
                  if "network_calls" in child:
@@ -833,7 +852,7 @@ class BaseTracer:
                  {
                      "id": str(interaction_id),
                      "span_id": span.id,
-                     "interaction_type": span.type,
+                     "interaction_type": f"{span.type}_call_start",
                      "name": span.name,
                      "content": span.data,
                      "timestamp": span.start_time,
@@ -841,19 +860,33 @@ class BaseTracer:
                  }
              )
              interaction_id += 1
+
+             interactions.append(
+                 {
+                     "id": str(interaction_id),
+                     "span_id": span.id,
+                     "interaction_type": f"{span.type}_call_end",
+                     "name": span.name,
+                     "content": span.data,
+                     "timestamp": span.end_time,
+                     "error": span.error,
+                 }
+             )
+             interaction_id += 1
 
              # Process interactions from span.data if they exist
              if span.interactions:
                  for span_interaction in span.interactions:
-                     interaction = {}
-                     interaction["id"] = str(interaction_id)
-                     interaction["span_id"] = span.id
-                     interaction["interaction_type"] = span_interaction.type
-                     interaction["content"] = span_interaction.content
-                     interaction["timestamp"] = span_interaction.timestamp
-                     interaction["error"] = span.error
-                     interactions.append(interaction)
-                     interaction_id += 1
+                     if span_interaction != []:
+                         interaction = {}
+                         interaction["id"] = str(interaction_id)
+                         interaction["span_id"] = span.id
+                         interaction["interaction_type"] = span_interaction.type
+                         interaction["content"] = span_interaction.content
+                         interaction["timestamp"] = span_interaction.timestamp
+                         interaction["error"] = span.error
+                         interactions.append(interaction)
+                         interaction_id += 1
 
              if span.network_calls:
                  for span_network_call in span.network_calls:
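With the changes above, format_interactions now emits a paired start/end entry for every span and child span instead of a single entry typed by the bare span type, and empty interaction placeholders are skipped. For one tool span the resulting workflow rows look roughly like this (ids, names, timestamps, and content values are made up for illustration; the field names follow the diff):

# Illustrative shape of the paired rows emitted for a single tool span
[
    {
        "id": "1",
        "span_id": "span-123",
        "interaction_type": "tool_call_start",
        "name": "vector_search",
        "content": {},                      # span data payload
        "timestamp": "2025-01-01T00:00:00Z",
        "error": None,
    },
    {
        "id": "2",
        "span_id": "span-123",
        "interaction_type": "tool_call_end",
        "name": "vector_search",
        "content": {},
        "timestamp": "2025-01-01T00:00:02Z",
        "error": None,
    },
]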
@@ -954,7 +987,7 @@ class BaseTracer:
                  self.visited_metrics.append(metric_name)
 
                  formatted_metric = {
-                     "name": metric_name, # Use potentially modified name
+                     "name": metric_name,
                      "score": metric["score"],
                      "reason": metric.get("reasoning", ""),
                      "source": "user",
@@ -25,6 +25,7 @@ class CustomTracerMixin:
          self.auto_instrument_custom = False
          self.auto_instrument_user_interaction = False
          self.auto_instrument_network = False
+         self.auto_instrument_file_io = False
 
      def trace_custom(self, name: str = None, custom_type: str = "generic", version: str = "1.0.0", trace_variables: bool = True):
          def decorator(func):
@@ -246,8 +247,18 @@ class CustomTracerMixin:
 
          interactions = []
          if self.auto_instrument_user_interaction:
-             interactions = self.component_user_interaction.get(kwargs["component_id"], [])
-
+             input_output_interactions = []
+             for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
+                 if interaction["interaction_type"] in ["input", "output"]:
+                     input_output_interactions.append(interaction)
+             interactions.extend(input_output_interactions)
+         if self.auto_instrument_file_io:
+             file_io_interactions = []
+             for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
+                 if interaction["interaction_type"] in ["file_read", "file_write"]:
+                     file_io_interactions.append(interaction)
+             interactions.extend(file_io_interactions)
+
          component = {
              "id": kwargs["component_id"],
              "hash_id": kwargs["hash_id"],
@@ -314,3 +325,7 @@ class CustomTracerMixin:
      def instrument_network_calls(self):
          """Enable auto-instrumentation for network calls"""
          self.auto_instrument_network = True
+
+     def instrument_file_io_calls(self):
+         """Enable auto-instrumentation for file IO calls"""
+         self.auto_instrument_file_io = True