ragaai-catalyst 2.1.5b22__py3-none-any.whl → 2.1.5b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,7 @@ from .dataset import Dataset
5
5
  from .prompt_manager import PromptManager
6
6
  from .evaluation import Evaluation
7
7
  from .synthetic_data_generation import SyntheticDataGeneration
8
+ from .redteaming import RedTeaming
8
9
  from .guardrails_manager import GuardrailsManager
9
10
  from .guard_executor import GuardExecutor
10
11
  from .tracers import Tracer, init_tracing, trace_agent, trace_llm, trace_tool, current_span, trace_custom
@@ -18,7 +19,8 @@ __all__ = [
18
19
  "Tracer",
19
20
  "PromptManager",
20
21
  "Evaluation",
21
- "SyntheticDataGeneration",
22
+ "SyntheticDataGeneration",
23
+ "RedTeaming",
22
24
  "GuardrailsManager",
23
25
  "GuardExecutor",
24
26
  "init_tracing",
@@ -1,5 +1,7 @@
1
1
  import os
2
+ import csv
2
3
  import json
4
+ import tempfile
3
5
  import requests
4
6
  from .utils import response_checker
5
7
  from typing import Union
@@ -653,4 +655,50 @@ class Dataset:
653
655
  return JOB_STATUS_FAILED
654
656
  except Exception as e:
655
657
  logger.error(f"An unexpected error occurred: {e}")
656
- return JOB_STATUS_FAILED
658
+ return JOB_STATUS_FAILED
659
+
660
def _jsonl_to_csv(self, jsonl_file, csv_file):
    """Convert a JSONL file (one JSON object per line) to a CSV file.

    Blank lines are skipped. The CSV header is the union of the keys of all
    records, in first-seen order, so records with extra or missing keys no
    longer abort the conversion; missing cells are written empty. Nested
    list/dict values are written as JSON text instead of Python ``repr``.

    If the file holds no records, a message is printed and no CSV is written.

    :param jsonl_file: Path of the JSONL file to read (UTF-8).
    :param csv_file: Path of the CSV file to write (UTF-8); overwritten if present.
    :raises json.JSONDecodeError: If a non-blank line is not valid JSON.
    :raises ValueError: If a record is not a JSON object.
    """
    with open(jsonl_file, 'r', encoding='utf-8') as infile:
        # Skip whitespace-only lines (e.g. a trailing empty line) instead of
        # letting json.loads fail on them.
        data = [json.loads(line) for line in infile if line.strip()]

    if not data:
        print("Empty JSONL file.")
        return

    # Ordered union of keys; a dict keeps first-insertion order.
    fieldnames = {}
    for index, record in enumerate(data, start=1):
        if not isinstance(record, dict):
            raise ValueError(
                f"Record {index} in {jsonl_file} is not a JSON object: {record!r}"
            )
        fieldnames.update(dict.fromkeys(record))

    def _cell(value):
        # Keep nested structures machine-readable (valid JSON) in the CSV cell.
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)
        return value

    with open(csv_file, 'w', newline='', encoding='utf-8') as outfile:
        writer = csv.DictWriter(outfile, fieldnames=list(fieldnames))
        writer.writeheader()
        writer.writerows(
            {key: _cell(value) for key, value in record.items()} for record in data
        )

    print(f"Converted {jsonl_file} to {csv_file}")
675
+
676
def create_from_jsonl(self, jsonl_path, dataset_name, schema_mapping):
    """Create a dataset from a JSONL file by converting it to CSV first.

    The JSONL file is converted to ``<dataset_name>.csv`` inside a freshly
    created private temporary directory, passed to ``create_from_csv``, and
    removed afterwards. A private directory means concurrent calls, even with
    the same dataset name, cannot overwrite each other's file, and no
    predictable path in the shared temp dir is written.

    :param jsonl_path: Path to the source JSONL file.
    :param dataset_name: Name of the dataset to create. Also used, with path
        separators replaced, as the temporary CSV's file name.
    :param schema_mapping: Column mapping forwarded unchanged to ``create_from_csv``.
    :raises OSError, UnicodeError, ValueError: If the JSONL file cannot be read
        or parsed (logged, then re-raised). Errors from ``create_from_csv``
        propagate unchanged.
    """
    import shutil

    # Keep the dataset name as the file name, but never let it act as a path
    # (e.g. "../x" or "a/b").
    file_stem = str(dataset_name)
    for separator in filter(None, (os.sep, os.altsep)):
        file_stem = file_stem.replace(separator, "_")

    tmp_dir = tempfile.mkdtemp(prefix="jsonl_to_csv_")
    tmp_csv_path = os.path.join(tmp_dir, f"{file_stem}.csv")
    try:
        try:
            self._jsonl_to_csv(jsonl_path, tmp_csv_path)
        except (OSError, UnicodeError, ValueError) as e:
            # Only conversion failures are reported as conversion errors;
            # failures inside create_from_csv are no longer mislabelled.
            logger.error(f"Error converting JSONL to CSV: {e}")
            raise
        self.create_from_csv(tmp_csv_path, dataset_name, schema_mapping)
    finally:
        try:
            shutil.rmtree(tmp_dir)
        except Exception as e:
            logger.error(f"Error removing temporary CSV file: {e}")
690
+
691
def add_rows_from_jsonl(self, jsonl_path, dataset_name):
    """Append the records of a JSONL file to an existing dataset.

    The JSONL file is converted to ``<dataset_name>.csv`` inside a freshly
    created private temporary directory, passed to ``add_rows``, and removed
    afterwards. A private directory means concurrent calls, even with the
    same dataset name, cannot overwrite each other's file, and no predictable
    path in the shared temp dir is written.

    :param jsonl_path: Path to the source JSONL file.
    :param dataset_name: Name of the target dataset. Also used, with path
        separators replaced, as the temporary CSV's file name.
    :raises OSError, UnicodeError, ValueError: If the JSONL file cannot be read
        or parsed (logged, then re-raised). Errors from ``add_rows`` propagate
        unchanged.
    """
    import shutil

    # Keep the dataset name as the file name, but never let it act as a path
    # (e.g. "../x" or "a/b").
    file_stem = str(dataset_name)
    for separator in filter(None, (os.sep, os.altsep)):
        file_stem = file_stem.replace(separator, "_")

    tmp_dir = tempfile.mkdtemp(prefix="jsonl_to_csv_")
    tmp_csv_path = os.path.join(tmp_dir, f"{file_stem}.csv")
    try:
        try:
            self._jsonl_to_csv(jsonl_path, tmp_csv_path)
        except (OSError, UnicodeError, ValueError) as e:
            # Only conversion failures are reported as conversion errors;
            # failures inside add_rows are no longer mislabelled.
            logger.error(f"Error converting JSONL to CSV: {e}")
            raise
        self.add_rows(tmp_csv_path, dataset_name)
    finally:
        try:
            shutil.rmtree(tmp_dir)
        except Exception as e:
            logger.error(f"Error removing temporary CSV file: {e}")
@@ -0,0 +1,171 @@
1
+ import logging
2
+ import os
3
+ from typing import Callable, Optional
4
+
5
+ import giskard as scanner
6
+ import pandas as pd
7
+
8
# Disable these Giskard loggers at import time (presumably to keep scan
# output quiet for callers of this module).
for _giskard_logger_name in (
    'giskard.core',
    'giskard.scanner.logger',
    'giskard.models.automodel',
    'giskard.datasets.base',
    'giskard.utils.logging_utils',
):
    logging.getLogger(_giskard_logger_name).disabled = True
del _giskard_logger_name
13
+
14
+
15
class RedTeaming:
    """Run Giskard vulnerability ("red teaming") scans against a user-supplied model callable.

    The constructor validates the judge-LLM provider and exports its credentials
    as environment variables (``OPENAI_API_KEY``, ``GEMINI_API_KEY``, or
    ``AZURE_API_KEY`` / ``AZURE_API_BASE`` / ``AZURE_API_VERSION``), which is
    how Giskard's LLM client is configured per the Giskard docs.
    """

    def __init__(self,
                 provider: Optional[str] = "openai",
                 model: Optional[str] = None,
                 api_key: Optional[str] = None,
                 api_base: Optional[str] = None,
                 api_version: Optional[str] = None):
        """Validate the provider and make its credentials available through the environment.

        :param provider: Judge-LLM provider: "openai", "gemini" or "azure" (case-insensitive).
        :param model: Judge model name. Defaults per provider; required for Azure
            (checked when the scan configures the model).
        :param api_key: API key. If omitted, the provider's key variable must already be set.
        :param api_base: Azure endpoint. Falls back to ``AZURE_API_BASE``.
        :param api_version: Azure API version. Falls back to ``AZURE_API_VERSION``.
        :raises ValueError: If the provider is missing or unknown, or a required
            credential is absent or empty.
        """
        # Check before calling .lower(): previously provider=None raised
        # AttributeError instead of this ValueError.
        if not provider:
            raise ValueError("Model configuration must be provided with a valid provider and model.")
        self.provider = provider.lower()
        self.model = model

        # (environment variable, explicitly passed value, error if still unset)
        credentials = {
            "openai": [("OPENAI_API_KEY", api_key, "API key must be provided for OpenAI.")],
            "gemini": [("GEMINI_API_KEY", api_key, "API key must be provided for Gemini.")],
            "azure": [
                ("AZURE_API_KEY", api_key, "API key must be provided for Azure."),
                ("AZURE_API_BASE", api_base, "API base must be provided for Azure."),
                ("AZURE_API_VERSION", api_version, "API version must be provided for Azure."),
            ],
        }.get(self.provider)
        if credentials is None:
            raise ValueError(
                f"Provider '{provider}' is not recognized. "
                f"Supported providers: azure, gemini, openai."
            )

        # Export every explicitly passed value first, then validate (same order
        # as before, so partially supplied Azure settings are still recorded).
        for env_var, value, _ in credentials:
            if value is not None:
                os.environ[env_var] = value
        for env_var, _, message in credentials:
            # An empty string is as unusable as a missing variable.
            if not os.getenv(env_var):
                raise ValueError(message)

    def run_scan(
        self,
        model: Callable,
        evaluators: Optional[list] = None,
        save_report: bool = True,
        report_path: str = "raga-ai_red-teaming_scan.csv",
    ) -> pd.DataFrame:
        """
        Run red teaming on the provided model and return a DataFrame of the results.

        :param model: The model function provided by the user (can be sync or async).
            It is wrapped as a Giskard ``text_generation`` model with a single
            feature, ``question``.
        :param evaluators: Optional scan tag, or list of tags, to restrict the
            scan to. Each must be in ``get_supported_evaluators()``.
        :param save_report: Whether to save the scan report as a CSV file.
        :param report_path: Destination of the CSV report when ``save_report`` is True.
        :return: A DataFrame containing the scan report.
        :raises ValueError: If an evaluator is unsupported or the scanning model
            cannot be configured.
        :raises RuntimeError: If the Giskard scan itself fails (original error chained).
        """
        supported_evaluators = self.get_supported_evaluators()
        if evaluators:
            if isinstance(evaluators, str):
                evaluators = [evaluators]
            invalid_evaluators = [name for name in evaluators if name not in supported_evaluators]
            if invalid_evaluators:
                raise ValueError(f"Invalid evaluators: {invalid_evaluators}. "
                                 f"Allowed evaluators: {sorted(supported_evaluators)}.")

        # Configure Giskard's (global) judge LLM only once the inputs are known to be valid.
        self.set_scanning_model(self.provider, self.model)

        model_instance = scanner.Model(
            model=self._to_sync(model),
            model_type="text_generation",
            name="RagaAI's Scan",
            description="RagaAI's RedTeaming Scan",
            feature_names=["question"],
        )

        scan_kwargs = {"raise_exceptions": True}
        if evaluators:
            scan_kwargs["only"] = evaluators
        try:
            report = scanner.scan(model_instance, **scan_kwargs)
        except Exception as e:
            raise RuntimeError(f"Error occurred during model scan: {e}") from e

        report_df = report.to_dataframe()

        if save_report:
            report_df.to_csv(report_path, index=False)

        return report_df

    @staticmethod
    def _to_sync(model):
        """Return ``model`` unchanged if it is synchronous, else a blocking wrapper.

        The wrapper runs the coroutine on the thread's current event loop,
        creating one if there is none or it is closed. ``nest_asyncio`` is
        applied when available, which also allows use where a loop is already
        running (e.g. Jupyter).
        """
        import asyncio
        import inspect

        if not inspect.iscoroutinefunction(model):
            return model

        def sync_wrapper(*args, **kwargs):
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # No current loop in this thread (or newer Python versions that
                # no longer create one implicitly).
                loop = None
            if loop is None or loop.is_closed():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            try:
                import nest_asyncio
            except ImportError:
                nest_asyncio = None
            if nest_asyncio is not None:
                nest_asyncio.apply(loop)
            elif loop.is_running():
                # Fail before creating the coroutine, with an actionable message,
                # instead of "event loop is already running" plus an un-awaited coroutine.
                raise RuntimeError(
                    "An asyncio event loop is already running (e.g. in Jupyter); "
                    "install 'nest_asyncio' to scan async model functions here."
                )

            return loop.run_until_complete(model(*args, **kwargs))

        return sync_wrapper

    def get_supported_evaluators(self):
        """Return the scan tags accepted by ``run_scan``.

        They correspond to the 'llm' and 'robustness' detector tags of the
        Giskard scanner library.
        """
        return {'control_chars_injection',
                'discrimination',
                'ethical_bias',
                'ethics',
                'faithfulness',
                'generative',
                'hallucination',
                'harmfulness',
                'implausible_output',
                'information_disclosure',
                'jailbreak',
                'llm',
                'llm_harmful_content',
                'llm_stereotypes_detector',
                'misinformation',
                'output_formatting',
                'prompt_injection',
                'robustness',
                'stereotypes',
                'sycophancy',
                'text_generation',
                'text_perturbation'}

    @staticmethod
    def _resolve_scanning_model(provider, model=None):
        """Return the model identifier Giskard should use for ``provider``.

        :raises ValueError: If the provider is "azure" and no model is provided,
            or the provider has no default model.
        """
        default_models = {
            "openai": "gpt-4o",
            "gemini": "gemini-1.5-pro",
        }

        if provider == "azure" and model is None:
            raise ValueError("Model must be provided for Azure.")

        selected_model = model if model is not None else default_models.get(provider)

        if selected_model is None:
            raise ValueError(f"Unsupported provider: {provider}")

        # Giskard (via LiteLLM) addresses Gemini AI-Studio models and Azure
        # deployments as "<provider>/<name>". Without the prefix a bare
        # "gemini-*" name is not routed to the GEMINI_API_KEY backend. Names
        # that already carry a route (contain "/") are left untouched.
        if provider in ("gemini", "azure") and "/" not in selected_model:
            selected_model = f"{provider}/{selected_model}"
        return selected_model

    def set_scanning_model(self, provider, model=None):
        """
        Set the LLM model for Giskard based on the provider.

        :param provider: The LLM provider (e.g., "openai", "gemini", "azure").
        :param model: The specific model name to use (optional; required for Azure).
        :raises ValueError: If the provider is "azure" and no model is provided,
            or the provider is unsupported.
        """
        scanner.llm.set_llm_model(self._resolve_scanning_model(provider, model))
@@ -475,7 +475,7 @@ class SyntheticDataGeneration:
475
475
  Returns:
476
476
  list: A list of supported AI providers.
477
477
  """
478
- return ['gemini', 'openai']
478
+ return ['gemini', 'openai','azure']
479
479
 
480
480
  # Usage:
481
481
  # from synthetic_data_generation import SyntheticDataGeneration
@@ -48,15 +48,15 @@ class AgentTracerMixin:
48
48
  self.auto_instrument_network = False
49
49
 
50
50
  def trace_agent(
51
- self,
52
- name: str,
53
- agent_type: str = None,
54
- version: str = None,
55
- capabilities: List[str] = None,
56
- tags: List[str] = [],
57
- metadata: Dict[str, Any] = {},
58
- metrics: List[Dict[str, Any]] = [],
59
- feedback: Optional[Any] = None,
51
+ self,
52
+ name: str,
53
+ agent_type: str = None,
54
+ version: str = None,
55
+ capabilities: List[str] = None,
56
+ tags: List[str] = [],
57
+ metadata: Dict[str, Any] = {},
58
+ metrics: List[Dict[str, Any]] = [],
59
+ feedback: Optional[Any] = None,
60
60
  ):
61
61
  if name not in self.span_attributes_dict:
62
62
  self.span_attributes_dict[name] = SpanAttributes(name)
@@ -199,8 +199,8 @@ class AgentTracerMixin:
199
199
  children = tracer.agent_children.get()
200
200
  if children:
201
201
  if (
202
- "children"
203
- not in component["data"]
202
+ "children"
203
+ not in component["data"]
204
204
  ):
205
205
  component["data"][
206
206
  "children"
@@ -263,7 +263,7 @@ class AgentTracerMixin:
263
263
  return decorator
264
264
 
265
265
  def _trace_sync_agent_execution(
266
- self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
266
+ self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
267
267
  ):
268
268
  hash_id = top_level_hash_id
269
269
 
@@ -281,7 +281,7 @@ class AgentTracerMixin:
281
281
  component_id = str(uuid.uuid4())
282
282
 
283
283
  # Extract ground truth if present
284
- ground_truth = kwargs.pop("gt") if kwargs else None
284
+ ground_truth = kwargs.pop("gt", None) if kwargs else None
285
285
  if ground_truth is not None:
286
286
  span = self.span(name)
287
287
  span.add_gt(ground_truth)
@@ -390,7 +390,7 @@ class AgentTracerMixin:
390
390
  self.agent_children.reset(children_token)
391
391
 
392
392
  async def _trace_agent_execution(
393
- self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
393
+ self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
394
394
  ):
395
395
  """Asynchronous version of agent tracing"""
396
396
  if not self.is_active:
@@ -404,7 +404,7 @@ class AgentTracerMixin:
404
404
  component_id = str(uuid.uuid4())
405
405
 
406
406
  # Extract ground truth if present
407
- ground_truth = kwargs.pop("gt") if kwargs else None
407
+ ground_truth = kwargs.pop("gt", None) if kwargs else None
408
408
  if ground_truth is not None:
409
409
  span = self.span(name)
410
410
  span.add_gt(ground_truth)
@@ -522,7 +522,7 @@ class AgentTracerMixin:
522
522
  for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
523
523
  if interaction["interaction_type"] in ["input", "output"]:
524
524
  input_output_interactions.append(interaction)
525
- interactions.extend(input_output_interactions)
525
+ interactions.extend(input_output_interactions)
526
526
  if self.auto_instrument_file_io:
527
527
  file_io_interactions = []
528
528
  for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
@@ -551,9 +551,10 @@ class AgentTracerMixin:
551
551
  counter = sum(1 for x in self.visited_metrics if x.startswith(base_metric_name))
552
552
  metric_name = f'{base_metric_name}_{counter}' if counter > 0 else base_metric_name
553
553
  self.visited_metrics.append(metric_name)
554
- metric["name"] = metric_name
554
+ metric["name"] = metric_name
555
555
  metrics.append(metric)
556
556
 
557
+ # TODO agent_trace execute metric
557
558
  component = {
558
559
  "id": kwargs["component_id"],
559
560
  "hash_id": kwargs["hash_id"],
@@ -609,22 +610,22 @@ class AgentTracerMixin:
609
610
  self.component_network_calls.set(component_network_calls)
610
611
 
611
612
  def _sanitize_input(self, args: tuple, kwargs: dict) -> dict:
612
- """Sanitize and format input data, including handling of nested lists and dictionaries."""
613
-
614
- def sanitize_value(value):
615
- if isinstance(value, (int, float, bool, str)):
616
- return value
617
- elif isinstance(value, list):
618
- return [sanitize_value(item) for item in value]
619
- elif isinstance(value, dict):
620
- return {key: sanitize_value(val) for key, val in value.items()}
621
- else:
622
- return str(value) # Convert non-standard types to string
613
+ """Sanitize and format input data, including handling of nested lists and dictionaries."""
614
+
615
+ def sanitize_value(value):
616
+ if isinstance(value, (int, float, bool, str)):
617
+ return value
618
+ elif isinstance(value, list):
619
+ return [sanitize_value(item) for item in value]
620
+ elif isinstance(value, dict):
621
+ return {key: sanitize_value(val) for key, val in value.items()}
622
+ else:
623
+ return str(value) # Convert non-standard types to string
623
624
 
624
- return {
625
- "args": [sanitize_value(arg) for arg in args],
626
- "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
627
- }
625
+ return {
626
+ "args": [sanitize_value(arg) for arg in args],
627
+ "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
628
+ }
628
629
 
629
630
  def _sanitize_output(self, output: Any) -> Any:
630
631
  """Sanitize and format output data"""
@@ -640,6 +641,6 @@ class AgentTracerMixin:
640
641
 
641
642
  def instrument_network_calls(self):
642
643
  self.auto_instrument_network = True
643
-
644
+
644
645
  def instrument_file_io_calls(self):
645
646
  self.auto_instrument_file_io = True