ragaai-catalyst 2.1.4.1b0__py3-none-any.whl → 2.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. ragaai_catalyst/__init__.py +23 -2
  2. ragaai_catalyst/dataset.py +462 -1
  3. ragaai_catalyst/evaluation.py +76 -7
  4. ragaai_catalyst/ragaai_catalyst.py +52 -10
  5. ragaai_catalyst/redteaming/__init__.py +7 -0
  6. ragaai_catalyst/redteaming/config/detectors.toml +13 -0
  7. ragaai_catalyst/redteaming/data_generator/scenario_generator.py +95 -0
  8. ragaai_catalyst/redteaming/data_generator/test_case_generator.py +120 -0
  9. ragaai_catalyst/redteaming/evaluator.py +125 -0
  10. ragaai_catalyst/redteaming/llm_generator.py +136 -0
  11. ragaai_catalyst/redteaming/llm_generator_old.py +83 -0
  12. ragaai_catalyst/redteaming/red_teaming.py +331 -0
  13. ragaai_catalyst/redteaming/requirements.txt +4 -0
  14. ragaai_catalyst/redteaming/tests/grok.ipynb +97 -0
  15. ragaai_catalyst/redteaming/tests/stereotype.ipynb +2258 -0
  16. ragaai_catalyst/redteaming/upload_result.py +38 -0
  17. ragaai_catalyst/redteaming/utils/issue_description.py +114 -0
  18. ragaai_catalyst/redteaming/utils/rt.png +0 -0
  19. ragaai_catalyst/redteaming_old.py +171 -0
  20. ragaai_catalyst/synthetic_data_generation.py +400 -22
  21. ragaai_catalyst/tracers/__init__.py +17 -1
  22. ragaai_catalyst/tracers/agentic_tracing/data/data_structure.py +4 -2
  23. ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +212 -148
  24. ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +657 -247
  25. ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +50 -19
  26. ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +588 -177
  27. ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +99 -100
  28. ragaai_catalyst/tracers/agentic_tracing/tracers/network_tracer.py +3 -3
  29. ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +230 -29
  30. ragaai_catalyst/tracers/agentic_tracing/upload/trace_uploader.py +358 -0
  31. ragaai_catalyst/tracers/agentic_tracing/upload/upload_agentic_traces.py +75 -20
  32. ragaai_catalyst/tracers/agentic_tracing/upload/upload_code.py +55 -11
  33. ragaai_catalyst/tracers/agentic_tracing/upload/upload_local_metric.py +74 -0
  34. ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +47 -16
  35. ragaai_catalyst/tracers/agentic_tracing/utils/create_dataset_schema.py +4 -2
  36. ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +26 -3
  37. ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +182 -17
  38. ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1233 -497
  39. ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +81 -10
  40. ragaai_catalyst/tracers/agentic_tracing/utils/supported_llm_provider.toml +34 -0
  41. ragaai_catalyst/tracers/agentic_tracing/utils/system_monitor.py +215 -0
  42. ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
  43. ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py +3 -1
  44. ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py +73 -47
  45. ragaai_catalyst/tracers/distributed.py +300 -0
  46. ragaai_catalyst/tracers/exporters/__init__.py +3 -1
  47. ragaai_catalyst/tracers/exporters/dynamic_trace_exporter.py +160 -0
  48. ragaai_catalyst/tracers/exporters/ragaai_trace_exporter.py +129 -0
  49. ragaai_catalyst/tracers/langchain_callback.py +809 -0
  50. ragaai_catalyst/tracers/llamaindex_instrumentation.py +424 -0
  51. ragaai_catalyst/tracers/tracer.py +301 -55
  52. ragaai_catalyst/tracers/upload_traces.py +24 -7
  53. ragaai_catalyst/tracers/utils/convert_langchain_callbacks_output.py +61 -0
  54. ragaai_catalyst/tracers/utils/convert_llama_instru_callback.py +69 -0
  55. ragaai_catalyst/tracers/utils/extraction_logic_llama_index.py +74 -0
  56. ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +82 -0
  57. ragaai_catalyst/tracers/utils/model_prices_and_context_window_backup.json +9365 -0
  58. ragaai_catalyst/tracers/utils/trace_json_converter.py +269 -0
  59. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/METADATA +367 -45
  60. ragaai_catalyst-2.1.5.dist-info/RECORD +97 -0
  61. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/WHEEL +1 -1
  62. ragaai_catalyst-2.1.4.1b0.dist-info/RECORD +0 -67
  63. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/LICENSE +0 -0
  64. {ragaai_catalyst-2.1.4.1b0.dist-info → ragaai_catalyst-2.1.5.dist-info}/top_level.txt +0 -0
ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py
@@ -1,5 +1,5 @@
  import os
- from typing import List, Dict, Any
+ from typing import List, Dict, Any, Optional
  import logging

  logger = logging.getLogger(__name__)
@@ -11,13 +11,17 @@ logging_level = (


  class SpanAttributes:
- def __init__(self, name):
+ def __init__(self, name, project_id: Optional[int] = None):
  self.name = name
  self.tags = []
  self.metadata = {}
  self.metrics = []
+ self.local_metrics = []
  self.feedback = None
+ self.project_id = project_id
  self.trace_attributes = ["tags", "metadata", "metrics"]
+ self.gt = None
+ self.context = None

  def add_tags(self, tags: str | List[str]):
  if isinstance(tags, str):
@@ -30,14 +34,14 @@ class SpanAttributes:
  logger.debug(f"Added metadata: {metadata}")

  def add_metrics(
- self,
- name: str,
- score: float | int,
- reasoning: str = "",
- cost: float = None,
- latency: float = None,
- metadata: Dict[str, Any] = {},
- config: Dict[str, Any] = {},
+ self,
+ name: str,
+ score: float | int,
+ reasoning: str = "",
+ cost: float = None,
+ latency: float = None,
+ metadata: Dict[str, Any] = {},
+ config: Dict[str, Any] = {},
  ):
  self.metrics.append(
  {
@@ -57,3 +61,70 @@ class SpanAttributes:
  def add_feedback(self, feedback: Any):
  self.feedback = feedback
  logger.debug(f"Added feedback: {self.feedback}")
+
+ # TODO: Add validation to check if all the required parameters are present
+ def execute_metrics(self, **kwargs: Any):
+ name = kwargs.get("name")
+ model = kwargs.get("model")
+ provider = kwargs.get("provider")
+ display_name = kwargs.get("display_name", None)
+ mapping = kwargs.get("mapping", None)
+
+ if isinstance(name, str):
+ metrics = [{
+ "name": name
+ }]
+ else:
+ metrics = name if isinstance(name, list) else [name] if isinstance(name, dict) else []
+
+ for metric in metrics:
+ if not isinstance(metric, dict):
+ raise ValueError(f"Expected dict, got {type(metric)}")
+
+ if "name" not in metric:
+ raise ValueError("Metric must contain 'name'")
+
+ metric_name = metric["name"]
+ if metric_name in self.local_metrics:
+ count = sum(1 for m in self.local_metrics if m.startswith(metric_name))
+ metric_name = f"{metric_name}_{count + 1}"
+
+ prompt =None
+ context = None
+ response = None
+ # if mapping is not None:
+ # prompt = mapping['prompt']
+ # context = mapping['context']
+ # response = mapping['response']
+ new_metric = {
+ "name": metric_name,
+ "model": model,
+ "provider": provider,
+ "project_id": self.project_id,
+ # "prompt": prompt,
+ # "context": context,
+ # "response": response,
+ "displayName": display_name,
+ "mapping": mapping
+ }
+ self.local_metrics.append(new_metric)
+
+ def add_gt(self, gt: Any):
+ if not isinstance(gt, (str, int, float, bool, list, dict)):
+ raise TypeError(f"Unsupported type for gt: {type(gt)}")
+ if self.gt:
+ logger.warning(f"GT already exists: {self.gt} \n Overwriting...")
+ self.gt = gt
+ logger.debug(f"Added gt: {self.gt}")
+
+ def add_context(self, context: Any):
+ if isinstance(context, str):
+ if not context.strip():
+ logger.warning("Empty or whitespace-only context string provided")
+ self.context = str(context)
+ else:
+ try:
+ self.context = str(context)
+ except Exception as e:
+ logger.warning('Cannot cast the context to string... Skipping')
+ logger.debug(f"Added context: {self.context}")
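The hunks above extend SpanAttributes with per-span ground truth, context, and locally executed metrics. A minimal usage sketch, assuming the class is imported from the path listed in the file list; the span name, project id, metric name, model, and provider values below are illustrative placeholders, not values taken from the package:

```python
# Hedged sketch of the extended SpanAttributes API; the metric/model/provider
# values and the project_id are placeholders.
from ragaai_catalyst.tracers.agentic_tracing.utils.span_attributes import SpanAttributes

span = SpanAttributes("rag_answer_span", project_id=123)
span.add_tags(["experiment-a"])
span.add_gt("Paris is the capital of France.")        # type-checked ground truth
span.add_context("France is a country in Europe. Its capital is Paris.")

# Queue a span-level metric for local execution; it is appended to span.local_metrics.
span.execute_metrics(
    name="hallucination",       # illustrative metric name
    model="gpt-4o-mini",        # illustrative model
    provider="openai",
    display_name="hallucination_v1",
)
print(span.local_metrics)
```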
ragaai_catalyst/tracers/agentic_tracing/utils/supported_llm_provider.toml
@@ -0,0 +1,34 @@
+ # List of all supported LLM method calls
+
+ supported_llm_calls = [
+ # OpenAI
+ "OpenAI.chat.completions.create()",
+ "AsyncOpenAI.chat.completions.create()",
+
+ # OpenAI Beta
+ "OpenAI.beta.threads.create()",
+ "OpenAI.beta.threads.messages.create()",
+ "OpenAI.beta.threads.runs.create()",
+
+ # Anthropic
+ "Anthropic.messages.create()",
+ "Anthropic.messages.acreate()",
+
+ # Google VertexAI/PaLM
+ "GenerativeModel.generate_content()",
+ "GenerativeModel.generate_content_async()",
+ "ChatVertexAI._generate()",
+ "ChatVertexAI._agenerate()",
+ "ChatVertexAI.complete()",
+ "ChatVertexAI.acomplete()",
+
+ # Google GenerativeAI
+ "ChatGoogleGenerativeAI._generate()",
+ "ChatGoogleGenerativeAI._agenerate()",
+ "ChatGoogleGenerativeAI.complete()",
+ "ChatGoogleGenerativeAI.acomplete()",
+
+ # LiteLLM
+ "litellm.completion()",
+ "litellm.acompletion()"
+ ]
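The new TOML file is plain data; how the tracer consumes it is not shown in this diff. One way to read it, assuming Python 3.11+ for the stdlib tomllib and the packaged path from the file list:

```python
# Hypothetical reader for the shipped TOML list; the tracer's own loading code
# is not part of this diff.
import tomllib  # stdlib in Python 3.11+
from importlib import resources

toml_file = resources.files("ragaai_catalyst.tracers.agentic_tracing.utils") \
    .joinpath("supported_llm_provider.toml")
with toml_file.open("rb") as f:
    data = tomllib.load(f)

print(data["supported_llm_calls"][:3])  # first few supported call signatures
```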
ragaai_catalyst/tracers/agentic_tracing/utils/system_monitor.py
@@ -0,0 +1,215 @@
+ import platform
+ import psutil
+ import sys
+ import pkg_resources
+ import logging
+ from typing import Dict, List, Optional
+ from ..data.data_structure import (
+ SystemInfo,
+ OSInfo,
+ EnvironmentInfo,
+ Resources,
+ CPUResource,
+ MemoryResource,
+ DiskResource,
+ NetworkResource,
+ ResourceInfo,
+ MemoryInfo,
+ DiskInfo,
+ NetworkInfo,
+ )
+
+ logger = logging.getLogger(__name__)
+
+ class SystemMonitor:
+ def __init__(self, trace_id: str):
+ self.trace_id = trace_id
+
+ def get_system_info(self) -> SystemInfo:
+ # Initialize with None values
+ os_info = OSInfo(
+ name=None,
+ version=None,
+ platform=None,
+ kernel_version=None,
+ )
+ env_info = EnvironmentInfo(
+ name=None,
+ version=None,
+ packages=[],
+ env_path=None,
+ command_to_run=None,
+ )
+
+ try:
+ # Get OS info
+ os_info = OSInfo(
+ name=platform.system(),
+ version=platform.version(),
+ platform=platform.machine(),
+ kernel_version=platform.release(),
+ )
+ except Exception as e:
+ logger.warning(f"Failed to get OS info: {str(e)}")
+
+ try:
+ # Get Python environment info
+ installed_packages = [
+ f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set
+ ]
+ env_info = EnvironmentInfo(
+ name="Python",
+ version=platform.python_version(),
+ packages=installed_packages,
+ env_path=sys.prefix,
+ command_to_run=f"python {sys.argv[0]}",
+ )
+ except Exception as e:
+ logger.warning(f"Failed to get environment info: {str(e)}")
+
+
+ # Always return a valid SystemInfo object
+ return SystemInfo(
+ id=f"sys_{self.trace_id}",
+ os=os_info,
+ environment=env_info,
+ source_code="",
+ )
+
+ def get_resources(self) -> Resources:
+ # Initialize with None values
+ cpu_info = ResourceInfo(
+ name=None,
+ cores=None,
+ threads=None,
+ )
+ cpu = CPUResource(info=cpu_info, interval="5s", values=[])
+
+ mem_info = MemoryInfo(
+ total=None,
+ free=None,
+ )
+ mem = MemoryResource(info=mem_info, interval="5s", values=[])
+
+ disk_info = DiskInfo(
+ total=None,
+ free=None,
+ )
+ disk_resource = DiskResource(
+ info=disk_info,
+ interval="5s",
+ read=[],
+ write=[],
+ )
+
+ net_info = NetworkInfo(
+ upload_speed=None,
+ download_speed=None,
+ )
+ net = NetworkResource(
+ info=net_info,
+ interval="5s",
+ uploads=[],
+ downloads=[],
+ )
+
+ try:
+ # CPU info
+ cpu_info = ResourceInfo(
+ name=platform.processor(),
+ cores=psutil.cpu_count(logical=False),
+ threads=psutil.cpu_count(logical=True),
+ )
+ cpu = CPUResource(info=cpu_info, interval="5s", values=[psutil.cpu_percent()])
+ except Exception as e:
+ logger.warning(f"Failed to get CPU info: {str(e)}")
+
+
+ try:
+ # Memory info
+ memory = psutil.virtual_memory()
+ mem_info = MemoryInfo(
+ total=memory.total / (1024**3), # Convert to GB
+ free=memory.available / (1024**3),
+ )
+ mem = MemoryResource(info=mem_info, interval="5s", values=[memory.percent])
+ except Exception as e:
+ logger.warning(f"Failed to get memory info: {str(e)}")
+
+
+ try:
+ # Disk info
+ disk = psutil.disk_usage("/")
+ disk_info = DiskInfo(total=disk.total / (1024**3), free=disk.free / (1024**3))
+ disk_io = psutil.disk_io_counters()
+ disk_resource = DiskResource(
+ info=disk_info,
+ interval="5s",
+ read=[disk_io.read_bytes / (1024**2)], # MB
+ write=[disk_io.write_bytes / (1024**2)],
+ )
+ except Exception as e:
+ logger.warning(f"Failed to get disk info: {str(e)}")
+
+ try:
+ # Network info
+ net_io = psutil.net_io_counters()
+ net_info = NetworkInfo(
+ upload_speed=net_io.bytes_sent / (1024**2), # MB
+ download_speed=net_io.bytes_recv / (1024**2),
+ )
+ net = NetworkResource(
+ info=net_info,
+ interval="5s",
+ uploads=[net_io.bytes_sent / (1024**2)],
+ downloads=[net_io.bytes_recv / (1024**2)],
+ )
+ except Exception as e:
+ logger.warning(f"Failed to get network info: {str(e)}")
+
+
+ # Always return a valid Resources object
+ return Resources(cpu=cpu, memory=mem, disk=disk_resource, network=net)
+
+ def track_memory_usage(self) -> Optional[float]:
+ """Track memory usage in MB"""
+ try:
+ memory_usage = psutil.Process().memory_info().rss
+ return memory_usage / (1024 * 1024) # Convert to MB
+ except Exception as e:
+ logger.warning(f"Failed to track memory usage: {str(e)}")
+ return None
+
+ def track_cpu_usage(self, interval: float) -> Optional[float]:
+ """Track CPU usage percentage"""
+ try:
+ return psutil.cpu_percent(interval=interval)
+ except Exception as e:
+ logger.warning(f"Failed to track CPU usage: {str(e)}")
+ return None
+
+ def track_disk_usage(self) -> Dict[str, Optional[float]]:
+ """Track disk I/O in MB"""
+ default_response = {'disk_read': None, 'disk_write': None}
+ try:
+ disk_io = psutil.disk_io_counters()
+ return {
+ 'disk_read': disk_io.read_bytes / (1024 * 1024), # Convert to MB
+ 'disk_write': disk_io.write_bytes / (1024 * 1024) # Convert to MB
+ }
+ except Exception as e:
+ logger.warning(f"Failed to track disk usage: {str(e)}")
+ return default_response
+
+ def track_network_usage(self) -> Dict[str, Optional[float]]:
+ """Track network I/O in MB"""
+ default_response = {'uploads': None, 'downloads': None}
+ try:
+ net_io = psutil.net_io_counters()
+ return {
+ 'uploads': net_io.bytes_sent / (1024 * 1024), # Convert to MB
+ 'downloads': net_io.bytes_recv / (1024 * 1024) # Convert to MB
+ }
+ except Exception as e:
+ logger.warning(f"Failed to track network usage: {str(e)}")
+ return default_response
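A short sketch of how the new SystemMonitor might be driven, based only on the methods shown above; "trace_123" is a placeholder trace id, and psutil and pkg_resources must be installed:

```python
# Hedged usage sketch of SystemMonitor; each tracking call returns None (or a
# dict of Nones) on failure, so callers never crash on a missing psutil counter.
from ragaai_catalyst.tracers.agentic_tracing.utils.system_monitor import SystemMonitor

monitor = SystemMonitor(trace_id="trace_123")   # placeholder trace id

print(monitor.track_memory_usage())             # resident memory in MB, or None
print(monitor.track_cpu_usage(interval=0.5))    # CPU percent over 0.5 s, or None
print(monitor.track_disk_usage())               # {'disk_read': MB, 'disk_write': MB}
print(monitor.track_network_usage())            # {'uploads': MB, 'downloads': MB}

resources = monitor.get_resources()             # CPU/memory/disk/network snapshot
system_info = monitor.get_system_info()         # OS and Python environment details
```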
ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py
@@ -59,38 +59,6 @@ def calculate_cost(
  "total": total_cost,
  }

-
- def load_model_costs():
- try:
- current_dir = os.path.dirname(os.path.abspath(__file__))
- model_costs_path = os.path.join(current_dir, "model_costs.json")
- with open(model_costs_path, "r") as file:
- return json.load(file)
- except FileNotFoundError:
- with resources.open_text("utils", "model_costs.json") as file:
- return json.load(file)
-
-
- def update_model_costs_from_github():
- """Updates the model_costs.json file with latest costs from GitHub."""
- try:
- logger.debug("loading the latest model costs.")
- response = requests.get(
- "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
- )
- if response.status_code == 200:
- current_dir = os.path.dirname(os.path.abspath(__file__))
- model_costs_path = os.path.join(current_dir, "model_costs.json")
- with open(model_costs_path, "w") as file:
- json.dump(response.json(), file, indent=4)
- logger.debug("Model costs updated successfully.")
- return True
- return False
- except Exception as e:
- logger.error(f"Failed to update model costs from GitHub: {e}")
- return False
-
-
  def log_event(event_data, log_file_path):
  event_data = asdict(event_data)
  with open(log_file_path, "a") as f:
ragaai_catalyst/tracers/agentic_tracing/utils/unique_decorator.py
@@ -56,7 +56,9 @@ def generate_unique_hash(func, *args, **kwargs):
  return '_'.join(f"{normalize_arg(k)}:{normalize_arg(v)}"
  for k, v in sorted(arg.items()))
  elif callable(arg):
- return arg.__name__
+ if hasattr(arg, "__name__"):
+ return arg.__name__
+ return str(type(arg).__name__)
  else:
  return str(type(arg).__name__)

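The extra branch covers callables that have no __name__ attribute, such as functools.partial objects, which would previously raise AttributeError while hashing arguments. A small standalone illustration of the fallback:

```python
# Standalone illustration of the case the new hasattr() branch handles.
import functools

def greet(name, punctuation="!"):
    return f"Hello, {name}{punctuation}"

bound = functools.partial(greet, punctuation="?")
print(hasattr(bound, "__name__"))   # False -> falls through to the fallback
print(type(bound).__name__)         # "partial", the value now used for the hash key
```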
ragaai_catalyst/tracers/agentic_tracing/utils/zip_list_of_unique_files.py
@@ -1,13 +1,14 @@
  import os
+ import sys
+ import importlib
  import hashlib
  import zipfile
  import re
  import ast
  import importlib.util
  import json
- import astor
  import ipynbname
- import sys
+ from copy import deepcopy

  from pathlib import Path
  from IPython import get_ipython
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
  logging_level = logger.setLevel(logging.DEBUG) if os.getenv("DEBUG") == "1" else logging.INFO


- # Define the PackageUsageRemover class
+ # PackageUsageRemover class
  class PackageUsageRemover(ast.NodeTransformer):
  def __init__(self, package_name):
  self.package_name = package_name
@@ -49,7 +50,12 @@ class PackageUsageRemover(ast.NodeTransformer):
  return node

  def visit_Assign(self, node):
- if self._uses_package(node.value):
+ if isinstance(node.value, ast.Expr):
+ node_value = node.value.body
+ else:
+ node_value = node.value
+
+ if self._uses_package(node_value):
  return None
  return node

@@ -60,8 +66,10 @@
  if isinstance(node.func.value, ast.Name) and node.func.value.id in self.imported_names:
  return None
  return node
-
+
  def _uses_package(self, node):
+ if isinstance(node, ast.Expr):
+ return self._uses_package(node.body)
  if isinstance(node, ast.Name) and node.id in self.imported_names:
  return True
  if isinstance(node, ast.Call):
@@ -70,16 +78,19 @@
  return self._uses_package(node.value)
  return False

- # Define the function to remove package code from a source code string
+
+ # Remove package code from a source code string
  def remove_package_code(source_code: str, package_name: str) -> str:
  try:
  tree = ast.parse(source_code)
- transformer = PackageUsageRemover(package_name)
- modified_tree = transformer.visit(tree)
- modified_code = astor.to_source(modified_tree)
+ # remover = PackageUsageRemover(package_name)
+ # modified_tree = remover.visit(tree)
+ modified_code = ast.unparse(tree)
+
  return modified_code
  except Exception as e:
- raise Exception(f"Error processing source code: {str(e)}")
+ logger.error(f"Error in remove_package_code: {e}")
+ return source_code

  class JupyterNotebookHandler:
  @staticmethod
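With astor dropped and the PackageUsageRemover pass commented out, remove_package_code now only round-trips the source through the standard library (ast.unparse, available since Python 3.9), normalizing formatting without stripping any code, and it returns the original source on failure instead of raising. A minimal illustration of that round-trip:

```python
# Minimal illustration of the ast.parse / ast.unparse round-trip the new code relies on.
import ast

source = "import os\nx = os.getcwd()   # note: comments are dropped by unparse\n"
print(ast.unparse(ast.parse(source)))
```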
@@ -118,25 +129,12 @@ class JupyterNotebookHandler:
  # Check if running in Colab
  if JupyterNotebookHandler.is_running_in_colab():
  try:
- from google.colab import drive
- if not os.path.exists('/content/drive'):
- drive.mount('/content/drive')
- # logger.info("Google Drive mounted successfully")
-
  # Look for notebooks in /content first
  ipynb_files = list(Path('/content').glob('*.ipynb'))
  if ipynb_files:
  current_nb = max(ipynb_files, key=os.path.getmtime)
  # logger.info(f"Found current Colab notebook: {current_nb}")
  return str(current_nb)
-
- # Then check Drive if mounted
- if os.path.exists('/content/drive'):
- drive_ipynb_files = list(Path('/content/drive').rglob('*.ipynb'))
- if drive_ipynb_files:
- current_nb = max(drive_ipynb_files, key=os.path.getmtime)
- # logger.info(f"Found Colab notebook in Drive: {current_nb}")
- return str(current_nb)
  except Exception as e:
  logger.warning(f"Error in Colab notebook detection: {str(e)}")

@@ -201,7 +199,6 @@ def comment_magic_commands(script_content: str) -> str:
  class TraceDependencyTracker:
  def __init__(self, output_dir=None):
  self.tracked_files = set()
- self.python_imports = set()
  self.notebook_path = None
  self.colab_content = None

@@ -292,7 +289,7 @@
  except (UnicodeDecodeError, IOError):
  pass

- def analyze_python_imports(self, filepath):
+ def analyze_python_imports(self, filepath, ignored_locations):
  try:
  with open(filepath, 'r', encoding='utf-8') as file:
  tree = ast.parse(file.read(), filename=filepath)
@@ -305,48 +302,75 @@ class TraceDependencyTracker:
  module_name = name.name.split('.')[0]
  try:
  spec = importlib.util.find_spec(module_name)
- if spec and spec.origin and not spec.origin.startswith(os.path.dirname(importlib.__file__)):
- self.python_imports.add(spec.origin)
+ if spec and spec.origin:
+ if not (any(spec.origin.startswith(location) for location in ignored_locations) or (spec.origin in ['built-in', 'frozen'])):
+ self.tracked_files.add(spec.origin)
+ self.analyze_python_imports(spec.origin, ignored_locations)
  except (ImportError, AttributeError):
  pass
  except Exception as e:
  pass

+ def get_env_location(self):
+ return sys.prefix
+
+ def get_catalyst_location(self):
+ try:
+ imported_module = importlib.import_module("ragaai_catalyst")
+ return os.path.dirname(os.path.abspath(imported_module.__file__))
+ except ImportError:
+ logger.error("Error getting Catalyst location")
+ return 'ragaai_catalyst'
+
+ def should_ignore_path(self, path, main_filepaths):
+ if any(os.path.abspath(path) in os.path.abspath(main_filepath) for main_filepath in main_filepaths):
+ return False
+ if path in ['', os.path.abspath('')]:
+ return False
+ return True
+
  def create_zip(self, filepaths):
  self.track_jupyter_notebook()
- # logger.info("Tracked Jupyter notebook and its dependencies")

  # Ensure output directory exists
  os.makedirs(self.output_dir, exist_ok=True)
- # logger.info(f"Using output directory: {self.output_dir}")

  # Special handling for Colab
  if self.jupyter_handler.is_running_in_colab():
- # logger.info("Running in Google Colab environment")
- # Try to get the Colab notebook path
+ # Get the Colab notebook path
  colab_notebook = self.jupyter_handler.get_notebook_path()
  if colab_notebook:
  self.tracked_files.add(os.path.abspath(colab_notebook))
- # logger.info(f"Added Colab notebook to tracked files: {colab_notebook}")

  # Get current cell content
  self.check_environment_and_save()

+ env_location = self.get_env_location()
+ catalyst_location = self.get_catalyst_location()
+
  # Process all files (existing code)
+ ignored_locations = [env_location, catalyst_location] + [path for path in sys.path if self.should_ignore_path(path, filepaths)]
  for filepath in filepaths:
  abs_path = os.path.abspath(filepath)
  self.track_file_access(abs_path)
  try:
- with open(abs_path, 'r', encoding='utf-8') as file:
+ if filepath.endswith('.py'):
+ self.analyze_python_imports(abs_path, ignored_locations)
+ except Exception as e:
+ pass
+
+ curr_tracked_files = deepcopy(self.tracked_files)
+ for filepath in curr_tracked_files:
+ try:
+ with open(filepath, 'r', encoding='utf-8') as file:
  content = file.read()
  # Comment out magic commands before processing
  content = comment_magic_commands(content)
- self.find_config_files(content, abs_path)
- if filepath.endswith('.py'):
- self.analyze_python_imports(abs_path)
+ self.find_config_files(content, filepath)
  except Exception as e:
  pass

+
  notebook_content_str = None
  if self.notebook_path and os.path.exists(self.notebook_path):
  try:
@@ -370,13 +394,12 @@
  pass

  # Calculate hash and create zip
- self.tracked_files.update(self.python_imports)
  hash_contents = []

  for filepath in sorted(self.tracked_files):
  if not filepath.endswith('.py'):
  continue
- elif '/envs' in filepath or '__init__' in filepath:
+ elif env_location in filepath or '__init__' in filepath:
  continue
  try:
  with open(filepath, 'rb') as file:
@@ -409,11 +432,15 @@

  with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
  for filepath in sorted(self.tracked_files):
- if 'env' in filepath or 'ragaai_catalyst' in filepath:
+ if env_location in filepath or catalyst_location in filepath:
  continue
  try:
  relative_path = os.path.relpath(filepath, base_path)
- zipf.write(filepath, relative_path)
+ if relative_path in ['', '.']:
+ zipf.write(filepath, os.path.basename(filepath))
+ else:
+ zipf.write(filepath, relative_path)
+
  logger.debug(f"Added python script to zip: {relative_path}")
  except Exception as e:
  pass
@@ -446,10 +473,9 @@ def zip_list_of_unique_files(filepaths, output_dir=None):
  return tracker.create_zip(filepaths)


- # Example usage
- if __name__ == "__main__":
- filepaths = ["script1.py", "script2.py"]
- hash_id, zip_path = zip_list_of_unique_files(filepaths)
- print(f"Created zip file: {zip_path}")
- print(f"Hash ID: {hash_id}")
-
+ # # Example usage
+ # if __name__ == "__main__":
+ # filepaths = ["script1.py", "script2.py"]
+ # hash_id, zip_path = zip_list_of_unique_files(filepaths)
+ # print(f"Created zip file: {zip_path}")
+ # print(f"Hash ID: {hash_id}")
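The module's own example block is now commented out. As a hedged sketch based on that commented example and the function signature above ("script1.py" and "script2.py" are placeholder paths, and output_dir is optional), calling the helper directly still looks like this:

```python
# Hedged sketch based on the commented-out example above; the script paths and
# output directory are placeholders.
from ragaai_catalyst.tracers.agentic_tracing.utils.zip_list_of_unique_files import (
    zip_list_of_unique_files,
)

hash_id, zip_path = zip_list_of_unique_files(["script1.py", "script2.py"], output_dir="./zips")
print(f"Created zip file: {zip_path}")
print(f"Hash ID: {hash_id}")
```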