ragaai-catalyst 2.1.5b22__py3-none-any.whl → 2.1.5b24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,7 @@ from .dataset import Dataset
 from .prompt_manager import PromptManager
 from .evaluation import Evaluation
 from .synthetic_data_generation import SyntheticDataGeneration
+from .redteaming import RedTeaming
 from .guardrails_manager import GuardrailsManager
 from .guard_executor import GuardExecutor
 from .tracers import Tracer, init_tracing, trace_agent, trace_llm, trace_tool, current_span, trace_custom
@@ -18,7 +19,8 @@ __all__ = [
     "Tracer",
     "PromptManager",
     "Evaluation",
-    "SyntheticDataGeneration",
+    "SyntheticDataGeneration",
+    "RedTeaming",
     "GuardrailsManager",
     "GuardExecutor",
     "init_tracing",
@@ -1,5 +1,7 @@
 import os
+import csv
 import json
+import tempfile
 import requests
 from .utils import response_checker
 from typing import Union
@@ -653,4 +655,50 @@ class Dataset:
                 return JOB_STATUS_FAILED
         except Exception as e:
             logger.error(f"An unexpected error occurred: {e}")
-            return JOB_STATUS_FAILED
+            return JOB_STATUS_FAILED
+
+    def _jsonl_to_csv(self, jsonl_file, csv_file):
+        """Convert a JSONL file to a CSV file."""
+        with open(jsonl_file, 'r', encoding='utf-8') as infile:
+            data = [json.loads(line) for line in infile]
+
+        if not data:
+            print("Empty JSONL file.")
+            return
+
+        with open(csv_file, 'w', newline='', encoding='utf-8') as outfile:
+            writer = csv.DictWriter(outfile, fieldnames=data[0].keys())
+            writer.writeheader()
+            writer.writerows(data)
+
+        print(f"Converted {jsonl_file} to {csv_file}")
+
+    def create_from_jsonl(self, jsonl_path, dataset_name, schema_mapping):
+        tmp_csv_path = os.path.join(tempfile.gettempdir(), f"{dataset_name}.csv")
+        try:
+            self._jsonl_to_csv(jsonl_path, tmp_csv_path)
+            self.create_from_csv(tmp_csv_path, dataset_name, schema_mapping)
+        except (IOError, UnicodeError) as e:
+            logger.error(f"Error converting JSONL to CSV: {e}")
+            raise
+        finally:
+            if os.path.exists(tmp_csv_path):
+                try:
+                    os.remove(tmp_csv_path)
+                except Exception as e:
+                    logger.error(f"Error removing temporary CSV file: {e}")
+
+    def add_rows_from_jsonl(self, jsonl_path, dataset_name):
+        tmp_csv_path = os.path.join(tempfile.gettempdir(), f"{dataset_name}.csv")
+        try:
+            self._jsonl_to_csv(jsonl_path, tmp_csv_path)
+            self.add_rows(tmp_csv_path, dataset_name)
+        except (IOError, UnicodeError) as e:
+            logger.error(f"Error converting JSONL to CSV: {e}")
+            raise
+        finally:
+            if os.path.exists(tmp_csv_path):
+                try:
+                    os.remove(tmp_csv_path)
+                except Exception as e:
+                    logger.error(f"Error removing temporary CSV file: {e}")
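
The two new public methods funnel JSONL input through a temporary CSV in tempfile.gettempdir() and then delegate to the existing create_from_csv / add_rows code paths, removing the temporary file in a finally block either way. A hypothetical usage sketch (the Dataset constructor arguments and the schema_mapping values shown here are illustrative assumptions, not taken from this diff):

    # hypothetical usage sketch
    dataset = Dataset(project_name="my-project")          # constructor args assumed
    dataset.create_from_jsonl(
        jsonl_path="qa_pairs.jsonl",
        dataset_name="qa_eval_set",
        schema_mapping={"question": "prompt", "answer": "response"},  # illustrative mapping
    )
    # append additional records to the same dataset later
    dataset.add_rows_from_jsonl(jsonl_path="more_pairs.jsonl", dataset_name="qa_eval_set")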
@@ -0,0 +1,171 @@
+import logging
+import os
+from typing import Callable, Optional
+
+import giskard as scanner
+import pandas as pd
+
+logging.getLogger('giskard.core').disabled = True
+logging.getLogger('giskard.scanner.logger').disabled = True
+logging.getLogger('giskard.models.automodel').disabled = True
+logging.getLogger('giskard.datasets.base').disabled = True
+logging.getLogger('giskard.utils.logging_utils').disabled = True
+
+
+class RedTeaming:
+
+    def __init__(self,
+                 provider: Optional[str] = "openai",
+                 model: Optional[str] = None,
+                 api_key: Optional[str] = None,
+                 api_base: Optional[str] = None,
+                 api_version: Optional[str] = None):
+        self.provider = provider.lower()
+        self.model = model
+        if not self.provider:
+            raise ValueError("Model configuration must be provided with a valid provider and model.")
+        if self.provider == "openai":
+            if api_key is not None:
+                os.environ["OPENAI_API_KEY"] = api_key
+            if os.getenv("OPENAI_API_KEY") is None:
+                raise ValueError("API key must be provided for OpenAI.")
+        elif self.provider == "gemini":
+            if api_key is not None:
+                os.environ["GEMINI_API_KEY"] = api_key
+            if os.getenv("GEMINI_API_KEY") is None:
+                raise ValueError("API key must be provided for Gemini.")
+        elif self.provider == "azure":
+            if api_key is not None:
+                os.environ["AZURE_API_KEY"] = api_key
+            if api_base is not None:
+                os.environ["AZURE_API_BASE"] = api_base
+            if api_version is not None:
+                os.environ["AZURE_API_VERSION"] = api_version
+            if os.getenv("AZURE_API_KEY") is None:
+                raise ValueError("API key must be provided for Azure.")
+            if os.getenv("AZURE_API_BASE") is None:
+                raise ValueError("API base must be provided for Azure.")
+            if os.getenv("AZURE_API_VERSION") is None:
+                raise ValueError("API version must be provided for Azure.")
+        else:
+            raise ValueError(f"Provider is not recognized.")
+
+    def run_scan(
+            self,
+            model: Callable,
+            evaluators: Optional[list] = None,
+            save_report: bool = True
+    ) -> pd.DataFrame:
+        """
+        Runs red teaming on the provided model and returns a DataFrame of the results.
+
+        :param model: The model function provided by the user (can be sync or async).
+        :param evaluators: Optional list of scan metrics to run.
+        :param save_report: Boolean flag indicating whether to save the scan report as a CSV file.
+        :return: A DataFrame containing the scan report.
+        """
+        import asyncio
+        import inspect
+
+        self.set_scanning_model(self.provider, self.model)
+
+        supported_evaluators = self.get_supported_evaluators()
+        if evaluators:
+            if isinstance(evaluators, str):
+                evaluators = [evaluators]
+            invalid_evaluators = [evaluator for evaluator in evaluators if evaluator not in supported_evaluators]
+            if invalid_evaluators:
+                raise ValueError(f"Invalid evaluators: {invalid_evaluators}. "
+                                 f"Allowed evaluators: {supported_evaluators}.")
+
+        # Handle async model functions by wrapping them in a sync function
+        if inspect.iscoroutinefunction(model):
+            def sync_wrapper(*args, **kwargs):
+                try:
+                    # Try to get the current event loop
+                    loop = asyncio.get_event_loop()
+                except RuntimeError:
+                    # If no event loop exists (e.g., in Jupyter), create a new one
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+
+                try:
+                    # Handle both IPython and regular Python environments
+                    import nest_asyncio
+                    nest_asyncio.apply()
+                except ImportError:
+                    pass # nest_asyncio not available, continue without it
+
+                return loop.run_until_complete(model(*args, **kwargs))
+            wrapped_model = sync_wrapper
+        else:
+            wrapped_model = model
+
+        model_instance = scanner.Model(
+            model=wrapped_model,
+            model_type="text_generation",
+            name="RagaAI's Scan",
+            description="RagaAI's RedTeaming Scan",
+            feature_names=["question"],
+        )
+
+        try:
+            report = scanner.scan(model_instance, only=evaluators, raise_exceptions=True) if evaluators \
+                else scanner.scan(model_instance, raise_exceptions=True)
+        except Exception as e:
+            raise RuntimeError(f"Error occurred during model scan: {str(e)}")
+
+        report_df = report.to_dataframe()
+
+        if save_report:
+            report_df.to_csv("raga-ai_red-teaming_scan.csv", index=False)
+
+        return report_df
+
+    def get_supported_evaluators(self):
+        """Contains tags corresponding to the 'llm' and 'robustness' directories in the giskard > scanner library"""
+        return {'control_chars_injection',
+                'discrimination',
+                'ethical_bias',
+                'ethics',
+                'faithfulness',
+                'generative',
+                'hallucination',
+                'harmfulness',
+                'implausible_output',
+                'information_disclosure',
+                'jailbreak',
+                'llm',
+                'llm_harmful_content',
+                'llm_stereotypes_detector',
+                'misinformation',
+                'output_formatting',
+                'prompt_injection',
+                'robustness',
+                'stereotypes',
+                'sycophancy',
+                'text_generation',
+                'text_perturbation'}
+
+    def set_scanning_model(self, provider, model=None):
+        """
+        Sets the LLM model for Giskard based on the provider.
+
+        :param provider: The LLM provider (e.g., "openai", "gemini", "azure").
+        :param model: The specific model name to use (optional).
+        :raises ValueError: If the provider is "azure" and no model is provided.
+        """
+        default_models = {
+            "openai": "gpt-4o",
+            "gemini": "gemini-1.5-pro"
+        }
+
+        if provider == "azure" and model is None:
+            raise ValueError("Model must be provided for Azure.")
+
+        selected_model = model if model is not None else default_models.get(provider)
+
+        if selected_model is None:
+            raise ValueError(f"Unsupported provider: {provider}")
+
+        scanner.llm.set_llm_model(selected_model)
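
RedTeaming.run_scan wraps the user's callable in a giskard Model with feature_names=["question"] and runs scanner.scan, optionally restricted to the evaluator tags returned by get_supported_evaluators(). A hypothetical usage sketch, assuming Giskard's text_generation contract of a pandas DataFrame in and one generated string per row out; the answer_question helper is a stand-in for the application under test and is not part of this diff:

    # hypothetical usage sketch
    import pandas as pd
    from ragaai_catalyst import RedTeaming

    def answer_question(question: str) -> str:
        # placeholder for the real application (RAG pipeline, LLM call, ...)
        return f"Answer to: {question}"

    def app_under_test(df: pd.DataFrame) -> list:
        # Giskard text_generation models receive the declared feature columns
        # ("question" here) and return one output string per row
        return [answer_question(q) for q in df["question"]]

    rt = RedTeaming(provider="openai", model="gpt-4o", api_key="sk-...")
    report_df = rt.run_scan(
        model=app_under_test,
        evaluators=["prompt_injection", "harmfulness"],  # must appear in get_supported_evaluators()
        save_report=True,  # also writes raga-ai_red-teaming_scan.csv
    )
    print(report_df.head())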
@@ -475,7 +475,7 @@ class SyntheticDataGeneration:
         Returns:
             list: A list of supported AI providers.
         """
-        return ['gemini', 'openai']
+        return ['gemini', 'openai','azure']

 # Usage:
 # from synthetic_data_generation import SyntheticDataGeneration
@@ -9,6 +9,7 @@ import contextvars
 import asyncio
 from ..utils.file_name_tracker import TrackName
 from ..utils.span_attributes import SpanAttributes
+from .base import BaseTracer
 import logging

 logger = logging.getLogger(__name__)
@@ -48,15 +49,15 @@ class AgentTracerMixin:
         self.auto_instrument_network = False

     def trace_agent(
-        self,
-        name: str,
-        agent_type: str = None,
-        version: str = None,
-        capabilities: List[str] = None,
-        tags: List[str] = [],
-        metadata: Dict[str, Any] = {},
-        metrics: List[Dict[str, Any]] = [],
-        feedback: Optional[Any] = None,
+            self,
+            name: str,
+            agent_type: str = None,
+            version: str = None,
+            capabilities: List[str] = None,
+            tags: List[str] = [],
+            metadata: Dict[str, Any] = {},
+            metrics: List[Dict[str, Any]] = [],
+            feedback: Optional[Any] = None,
     ):
         if name not in self.span_attributes_dict:
             self.span_attributes_dict[name] = SpanAttributes(name)
@@ -199,8 +200,8 @@ class AgentTracerMixin:
                     children = tracer.agent_children.get()
                     if children:
                         if (
-                            "children"
-                            not in component["data"]
+                                "children"
+                                not in component["data"]
                         ):
                             component["data"][
                                 "children"
@@ -263,7 +264,7 @@ class AgentTracerMixin:
         return decorator

     def _trace_sync_agent_execution(
-        self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
+            self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
     ):
         hash_id = top_level_hash_id

@@ -281,7 +282,7 @@ class AgentTracerMixin:
         component_id = str(uuid.uuid4())

         # Extract ground truth if present
-        ground_truth = kwargs.pop("gt") if kwargs else None
+        ground_truth = kwargs.pop("gt", None) if kwargs else None
        if ground_truth is not None:
             span = self.span(name)
             span.add_gt(ground_truth)
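
The change from kwargs.pop("gt") to kwargs.pop("gt", None) matters when a traced agent is called with keyword arguments but without a "gt" (ground-truth) value: the old form raised KeyError, while the new form returns None and the add_gt path is simply skipped. A minimal illustration of the behaviour difference (values are hypothetical):

    # hypothetical illustration
    kwargs = {"query": "What is RAG?"}              # caller passed kwargs, but no "gt"
    # old: kwargs.pop("gt") if kwargs else None     -> raises KeyError: 'gt'
    # new: kwargs.pop("gt", None) if kwargs else None -> returns None, tracing continues
    ground_truth = kwargs.pop("gt", None) if kwargs else None
    assert ground_truth is None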
@@ -390,7 +391,7 @@ class AgentTracerMixin:
             self.agent_children.reset(children_token)

     async def _trace_agent_execution(
-        self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
+            self, func, name, agent_type, version, capabilities, hash_id, *args, **kwargs
     ):
         """Asynchronous version of agent tracing"""
         if not self.is_active:
@@ -404,7 +405,7 @@ class AgentTracerMixin:
         component_id = str(uuid.uuid4())

         # Extract ground truth if present
-        ground_truth = kwargs.pop("gt") if kwargs else None
+        ground_truth = kwargs.pop("gt", None) if kwargs else None
         if ground_truth is not None:
             span = self.span(name)
             span.add_gt(ground_truth)
@@ -522,7 +523,7 @@ class AgentTracerMixin:
            for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
                if interaction["interaction_type"] in ["input", "output"]:
                    input_output_interactions.append(interaction)
-            interactions.extend(input_output_interactions)
+            interactions.extend(input_output_interactions)
        if self.auto_instrument_file_io:
            file_io_interactions = []
            for interaction in self.component_user_interaction.get(kwargs["component_id"], []):
@@ -551,9 +552,14 @@ class AgentTracerMixin:
                 counter = sum(1 for x in self.visited_metrics if x.startswith(base_metric_name))
                 metric_name = f'{base_metric_name}_{counter}' if counter > 0 else base_metric_name
                 self.visited_metrics.append(metric_name)
-                metric["name"] = metric_name
+                metric["name"] = metric_name
                 metrics.append(metric)

+        # TODO agent_trace execute metric
+        formatted_metrics = BaseTracer.get_formatted_metric(self.span_attributes_dict, self.project_id, name)
+        if formatted_metrics:
+            metrics.extend(formatted_metrics)
+
         component = {
             "id": kwargs["component_id"],
             "hash_id": kwargs["hash_id"],
@@ -609,22 +615,22 @@ class AgentTracerMixin:
         self.component_network_calls.set(component_network_calls)

     def _sanitize_input(self, args: tuple, kwargs: dict) -> dict:
-        """Sanitize and format input data, including handling of nested lists and dictionaries."""
-
-        def sanitize_value(value):
-            if isinstance(value, (int, float, bool, str)):
-                return value
-            elif isinstance(value, list):
-                return [sanitize_value(item) for item in value]
-            elif isinstance(value, dict):
-                return {key: sanitize_value(val) for key, val in value.items()}
-            else:
-                return str(value) # Convert non-standard types to string
+        """Sanitize and format input data, including handling of nested lists and dictionaries."""
+
+        def sanitize_value(value):
+            if isinstance(value, (int, float, bool, str)):
+                return value
+            elif isinstance(value, list):
+                return [sanitize_value(item) for item in value]
+            elif isinstance(value, dict):
+                return {key: sanitize_value(val) for key, val in value.items()}
+            else:
+                return str(value) # Convert non-standard types to string

-        return {
-            "args": [sanitize_value(arg) for arg in args],
-            "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
-        }
+        return {
+            "args": [sanitize_value(arg) for arg in args],
+            "kwargs": {key: sanitize_value(val) for key, val in kwargs.items()},
+        }

     def _sanitize_output(self, output: Any) -> Any:
         """Sanitize and format output data"""
@@ -640,6 +646,6 @@ class AgentTracerMixin:

     def instrument_network_calls(self):
         self.auto_instrument_network = True
-
+
     def instrument_file_io_calls(self):
         self.auto_instrument_file_io = True