dtx-models 0.18.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtx_models/prompts.py ADDED
@@ -0,0 +1,460 @@
1
+ import hashlib
2
+ import json
3
+ from enum import Enum
4
+ from typing import Any, Dict, Iterator, List, Optional, Union
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer
7
+
8
+ from .exceptions import EntityNotFound
9
+
10
+ from .evaluator import (
11
+ AnyKeywordBasedPromptEvaluation,
12
+ ModelBasedPromptEvaluation,
13
+ )
14
+
15
+ # Define allowed roles
16
+ # RoleType = Literal["USER", "ASSISTANT", "SYSTEM"]
17
+
18
+
19
class RoleType(str, Enum):
    # Speaker roles a conversation turn may carry. Members are also plain
    # strings because of the ``str`` mixin.
    USER = "USER"
    ASSISTANT = "ASSISTANT"
    SYSTEM = "SYSTEM"

    def __str__(self):
        # Render the bare value ("USER") instead of "RoleType.USER";
        # this ensures correct YAML serialization.
        return self.value

    @classmethod
    def values(cls):
        """Return the string value of every role, in declaration order."""
        return [role.value for role in cls]
30
+
31
+
32
class SupportedFormat(str, Enum):
    # Identifiers for the output formats a conversation can be rendered to.
    TEXT = "text"
    OPENAI = "openai"
    ALPACA = "alpaca"
    CHATML = "chatml"
    VICUNA = "vicuna"
38
+
39
+
40
class BaseTestPrompt(BaseModel):
    # Common base class for the prompt models in this module; declares no fields.
    pass
42
+
43
+
44
class BaseTestStrPrompt(BaseTestPrompt):
    # Prompt expressed as a single string; required (no default is given).
    prompt: str = Field(
        description="Prompt to achieve the goal. "
        "Prompt can have variable templates that can be replaced with values. "
        "The templates should use curly brackets to specify template variables."
    )
50
+
51
+
52
class Turn(BaseModel):
    # One message in a conversation together with the role that produced it.
    role: RoleType = Field(
        ..., description="The role in the conversation (USER, ASSISTANT, SYSTEM)."
    )
    message: Union[str, Any] = Field(
        ..., min_length=1, description="The message content."
    )

    @staticmethod
    def validate_message(value: str) -> str:
        """Strip a string message and reject it when nothing is left.

        Non-string values are handed back untouched.

        Raises:
            ValueError: if *value* is a string that is empty after stripping.
        """
        # NOTE(review): this is a plain staticmethod; nothing in this class
        # registers it as a pydantic validator, so it only runs when called
        # explicitly.
        if not isinstance(value, str):
            return value
        stripped = value.strip()
        if not stripped:
            raise ValueError("Message cannot be empty or only whitespace.")
        return stripped

    @field_serializer("role")
    def serialize_role(self, role: RoleType) -> str:
        """Emit the role as its plain string form when the model is dumped."""
        return str(role)
73
+
74
+
75
class BaseMultiTurnConversation(BaseTestPrompt):
    # Ordered list of turns plus helpers to render the conversation in several
    # text/chat formats and to query it by role.
    turns: List[Turn]

    def _filter_turns(self, multi_turn: bool) -> List[Turn]:
        """
        Select the turns to render.

        When ``multi_turn`` is true the stored list itself is returned.
        Otherwise the result is every SYSTEM turn (in order) followed by the
        first USER turn, if there is one.
        """
        if multi_turn:
            return self.turns

        selected = [turn for turn in self.turns if turn.role == RoleType.SYSTEM]
        first_user_turn = next(
            (turn for turn in self.turns if turn.role == RoleType.USER), None
        )
        if first_user_turn:
            selected.append(first_user_turn)
        return selected

    def to_openai_format(self, multi_turn: bool = True) -> List[Dict[str, str]]:
        """
        Convert the conversation turns into a dictionary format compatible with OpenAI API.

        Each turn becomes ``{"role": <lower-cased role>, "content": <message>}``.
        """
        turns = self._filter_turns(multi_turn)
        return [{"role": turn.role.lower(), "content": turn.message} for turn in turns]

    def to_alpaca_format(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation to Alpaca format (instruction-based structure).

        Every turn is rendered as a ``### ROLE`` header line followed by the
        message; turns are joined with newlines.
        """
        turns = self._filter_turns(multi_turn)
        return "\n".join(f"### {turn.role}\n{turn.message}" for turn in turns)

    def to_chatml_format(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation to ChatML format (OpenAI ChatML used in models like GPT-4).

        Every turn is rendered as ``<|role|>message<|end|>`` with a lower-cased
        role; turns are joined with newlines.
        """
        turns = self._filter_turns(multi_turn)
        return "\n".join(
            f"<|{turn.role.lower()}|>{turn.message}<|end|>" for turn in turns
        )

    def to_vicuna_format(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation to Vicuna format (similar to Alpaca but slightly different style).

        USER turns become ``USER: <message>\\nASSISTANT: ``; every other turn
        contributes its message followed by a newline. The concatenated text
        is stripped of leading/trailing whitespace.
        """
        turns = self._filter_turns(multi_turn)
        return "".join(
            f"USER: {turn.message}\nASSISTANT: "
            if turn.role == RoleType.USER
            else f"{turn.message}\n"
            for turn in turns
        ).strip()

    def to_text(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation turns into a plain text format.

        One ``ROLE: message`` line per turn.
        """
        turns = self._filter_turns(multi_turn)
        return "\n".join(f"{turn.role}: {turn.message}" for turn in turns)

    def to_format(self, supported_format: SupportedFormat, multi_turn: bool = True):
        """
        Convert the conversation to the requested format.

        Returns a list of dicts for the OpenAI format and a string for all
        other formats.

        Raises:
            ValueError: if ``supported_format`` matches no known format.
        """
        # Kept as an equality chain rather than a dict lookup: members compare
        # equal to their plain string values, so callers passing e.g. "text"
        # keep working.
        if supported_format == SupportedFormat.TEXT:
            return self.to_text(multi_turn)
        elif supported_format == SupportedFormat.OPENAI:
            return self.to_openai_format(multi_turn)
        elif supported_format == SupportedFormat.ALPACA:
            return self.to_alpaca_format(multi_turn)
        elif supported_format == SupportedFormat.CHATML:
            return self.to_chatml_format(multi_turn)
        elif supported_format == SupportedFormat.VICUNA:
            return self.to_vicuna_format(multi_turn)
        else:
            raise ValueError(f"Unsupported format: {supported_format}")

    def last_user_prompt(self) -> str:
        """
        Returns the last USER message from the conversation.
        Raises EntityNotFound if no USER message is found.
        """
        for turn in reversed(self.turns):
            if turn.role == RoleType.USER:
                return turn.message
        raise EntityNotFound("No USER message found in the conversation.")

    def first_user_prompt(self) -> str:
        """
        Returns the first USER message from the conversation.
        Raises EntityNotFound if no USER message is found.
        """
        for turn in self.turns:
            if turn.role == RoleType.USER:
                return turn.message
        raise EntityNotFound("No USER message found in the conversation.")

    def first_system_prompt(self) -> Optional[str]:
        """
        Returns the first SYSTEM message from the conversation,
        or None if there is no SYSTEM turn (nothing is raised).
        """
        for turn in self.turns:
            if turn.role == RoleType.SYSTEM:
                return turn.message
        return None

    def last_assistant_response(self) -> str:
        """
        Returns the last ASSISTANT message from the conversation.
        Raises EntityNotFound if no ASSISTANT message is found.
        """
        for turn in reversed(self.turns):
            if turn.role == RoleType.ASSISTANT:
                return turn.message
        raise EntityNotFound("No ASSISTANT message found in the conversation.")

    def has_last_assistant_response(self) -> bool:
        """
        Returns True if the conversation contains at least one ASSISTANT turn,
        False otherwise. Never raises.
        """
        return any(turn.role == RoleType.ASSISTANT for turn in self.turns)

    def has_system_turn(self) -> bool:
        """
        Returns True if the conversation contains a SYSTEM turn.
        """
        return self.get_system_turn() is not None

    def get_system_turn(self) -> Optional[Turn]:
        """
        Returns the first SYSTEM turn, or None if the conversation has none.
        """
        return next(
            (turn for turn in self.turns if turn.role == RoleType.SYSTEM), None
        )

    def get_user_turns(self) -> Iterator[Turn]:
        """
        Returns an iterator over the turns with the USER role, in order.
        """
        for turn in self.turns:
            if turn.role == RoleType.USER:
                yield turn

    def get_complete_turns(self) -> Iterator[Turn]:
        """
        Yields the conversation with every USER turn followed by an ASSISTANT turn.

        - SYSTEM turns are yielded as they are encountered.
        - Each USER turn is yielded, followed by the ASSISTANT turn that comes
          directly after it; if there is none, a placeholder ASSISTANT turn
          with the message "No Response" is yielded instead.
        - Any other turn (e.g. an ASSISTANT turn that does not directly follow
          a USER turn) is skipped.
        """
        index = 0
        while index < len(self.turns):
            turn = self.turns[index]
            index += 1

            if turn.role == RoleType.SYSTEM:
                yield turn
                # Only a USER turn directly after SYSTEM is handled together
                # with it; anything else is re-examined on the next pass.
                if (
                    index >= len(self.turns)
                    or self.turns[index].role != RoleType.USER
                ):
                    continue
                turn = self.turns[index]
                index += 1
            elif turn.role != RoleType.USER:
                continue

            # At this point ``turn`` is a USER turn.
            yield turn
            if (
                index < len(self.turns)
                and self.turns[index].role == RoleType.ASSISTANT
            ):
                yield self.turns[index]
                index += 1
            else:
                yield Turn(role=RoleType.ASSISTANT, message="No Response")

    def add_turn(self, turn: Turn):
        """
        Appends a single turn to the conversation.
        """
        self.turns.append(turn)

    def add_turns(self, turns: List[Turn]):
        """
        Appends multiple turns to the conversation.
        """
        self.turns.extend(turns)
271
+
272
+
273
class MultiTurnTestPrompt(BaseMultiTurnConversation):
    # Multi-turn test prompt with its evaluation method and generation
    # metadata; ``id`` is derived from the content at construction time.
    id: Optional[str] = Field(
        default=None,
        description="Unique ID of the prompt, auto-generated based on content.",
    )
    evaluation_method: Union[
        ModelBasedPromptEvaluation, AnyKeywordBasedPromptEvaluation
    ] = Field(description="Evaluation method for the prompt.")
    module_name: str = Field(description="Module that has generated the prompt")
    policy: str = Field(default="")
    goal: str = Field(default="")
    strategy: str = Field(
        default="", description="strategy used to generate the prompt"
    )
    base_prompt: str = Field(
        default="",
        description="Base prompt in its most simplistic form that need to be answered by AI Agent. Generally it is the harmful.",
    )

    model_config = ConfigDict(frozen=True)  # Make fields immutable

    def __init__(self, **data):
        """Validate the fields, then stamp the content-derived ID.

        The ID is always recomputed here, so any ``id`` supplied in ``data``
        is overwritten.
        """
        super().__init__(**data)
        computed_id = self.compute_unique_id()
        # The model is frozen, so the assignment goes through
        # object.__setattr__ instead of normal attribute assignment.
        object.__setattr__(self, "id", computed_id)

    def compute_unique_id(self) -> str:
        """Return the SHA-1 hex digest of strategy, goal and the OpenAI-format turns."""
        rendered_turns = str(self.to_openai_format())
        fingerprint = f"{self.strategy}-{self.goal}-{rendered_turns}"
        return hashlib.sha1(fingerprint.encode()).hexdigest()
305
+
306
+
307
class MultiTurnConversation(MultiTurnTestPrompt):
    # Subclass of MultiTurnTestPrompt that adds no fields or behaviour.
    pass
309
+
310
+
311
class MultiturnTestPrompts(BaseModel):
    # A named risk together with its multi-turn test prompts.
    risk_name: str
    # Defaults to a fresh empty list per instance.
    test_prompts: List[MultiTurnTestPrompt] = Field(default_factory=list)
314
+
315
+
316
class BaseMultiTurnResponse(BaseMultiTurnConversation):
    """Multi Turn conversation as part of conversation with an agent or LLM"""

    # Adds no fields of its own; it is the common base of the response models below.
318
+
319
+
320
class MultiTurnResponse(BaseMultiTurnResponse):
    """Multi Turn conversation as part of conversation with an agent or LLM"""

    # Adds nothing beyond BaseMultiTurnResponse.
    pass
324
+
325
+
326
class BaseMultiTurnAgentResponse(BaseMultiTurnResponse):
    """Multi Turn conversation as part of conversation with an agent or LLM"""

    # Final response payload; may be a string, list, dict or None.
    # Defaults to a fresh empty dict.
    response: Optional[Union[str, List[Any], Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Final Agent Response",
    )

    # Label -> score mapping; defaults to a fresh empty dict.
    scores: Optional[Dict[str, float]] = Field(
        default_factory=dict,
        description="Optional classification labels with confidence scores.",
    )

    policy: Optional[str] = Field(
        default=None, description="Policy name that will be targeted by the goal"
    )
    goal: Optional[str] = Field(
        default=None, description="Goal that need to be achieved"
    )
345
+
346
+
347
class BaseMultiTurnClassificationResponse(BaseMultiTurnResponse):
    """Multi Turn conversation as part of conversation with an agent or LLM"""

    # Label -> score mapping; defaults to a fresh empty dict.
    labels: Dict[str, float] = Field(
        default_factory=dict,
        description="Optional classification labels with confidence scores.",
    )
354
+
355
+
356
class BaseMultiTurnResponseBuilder:
    """
    Fluent builder that accumulates turns and metadata and produces a
    BaseMultiTurnAgentResponse via ``build()``.

    Every ``add_*`` method returns the builder so calls can be chained.
    """

    def __init__(self):
        self.turns = []
        self.response = {}
        self.scores = {}
        self.goal = None
        self.policy = None
        # Set by add_prompt_attributes(); initialised here so the attribute
        # always exists.
        self._prompt = None

    def add_turn(self, turn: Turn):
        """Append a single turn."""
        self.turns.append(turn)
        return self

    def add_turns(self, turns: List[Turn]):
        """Append several turns, preserving their order."""
        self.turns.extend(turns)
        return self

    def add_prompt(self, prompt: str, system_prompt: Optional[str] = None):
        """
        Append a USER turn for ``prompt``, preceded by a SYSTEM turn when a
        non-empty ``system_prompt`` is given.
        """
        if system_prompt:
            self.turns.append(Turn(role=RoleType.SYSTEM, message=system_prompt))
        self.turns.append(Turn(role=RoleType.USER, message=prompt))
        return self

    def add_parsed_response(self, response):
        """Set the final response payload passed to the built model."""
        self.response = response
        return self

    def add_prompt_and_response(self, prompt: str, response: str):
        """Append a USER turn followed by the matching ASSISTANT turn."""
        self.turns.extend(
            [
                Turn(role=RoleType.USER, message=prompt),
                Turn(role=RoleType.ASSISTANT, message=response),
            ]
        )
        return self

    def add_turn_response(self, response: str):
        """Append a single ASSISTANT turn."""
        self.turns.append(Turn(role=RoleType.ASSISTANT, message=response))
        return self

    def validate_sequence(self):
        """
        Validate that, after an optional leading SYSTEM turn, the roles
        alternate USER, ASSISTANT, USER, ... starting with USER.

        An empty turn list is accepted, and the sequence is not required to
        end on an ASSISTANT turn.

        Raises:
            ValueError: at the first turn whose role breaks the alternation.
        """
        if not self.turns:
            return

        # A leading SYSTEM turn is allowed and skipped; alternation always
        # starts with USER.
        start_index = 1 if self.turns[0].role == RoleType.SYSTEM else 0
        expected_role = RoleType.USER

        for i, turn in enumerate(self.turns[start_index:], start=start_index):
            if turn.role != expected_role:
                raise ValueError(
                    f"Invalid conversation sequence at turn {i}: Expected {expected_role}, got {turn.role}"
                )
            expected_role = (
                RoleType.ASSISTANT if expected_role == RoleType.USER else RoleType.USER
            )

    def _get_response_from_last_turn(self):
        """
        Return the content of the last turn if it is from the assistant.

        The message is parsed as JSON when possible; otherwise it is returned
        unchanged. Returns None when there are no turns or the last turn is
        not an ASSISTANT turn. Nothing on the builder is modified.
        """
        if not self.turns:
            return None

        last_turn = self.turns[-1]
        if last_turn.role != RoleType.ASSISTANT:
            return None

        content = last_turn.message
        try:
            return json.loads(content)
        except (json.JSONDecodeError, TypeError):
            # Not valid JSON, or not a str/bytes value: hand it back unchanged.
            return content

    def add_prompt_attributes(self, prompt: MultiTurnConversation):
        """
        Remember ``prompt`` and copy its ``policy`` and ``goal`` onto the
        builder. A falsy ``prompt`` leaves policy and goal untouched.
        """
        self._prompt = prompt
        if prompt:
            self.policy = getattr(prompt, "policy", None)
            self.goal = getattr(prompt, "goal", None)
        return self

    def build(self) -> BaseMultiTurnAgentResponse:
        """Create the response model from the accumulated state."""
        return BaseMultiTurnAgentResponse(
            turns=self.turns,
            response=self.response,
            # Previously omitted, which discarded any scores set on the builder.
            scores=self.scores,
            policy=self.policy,
            goal=self.goal,
        )
File without changes
@@ -0,0 +1,20 @@
1
+ from enum import Enum
2
+
3
+
4
class ProviderType(str, Enum):
    # Identifiers of the supported providers. Members are also plain strings
    # because of the ``str`` mixin.
    ECHO = "echo"
    ELIZA = "eliza"
    HF = "huggingface"
    HTTP = "http"
    GRADIO = "gradio"
    OLLAMA = "ollama"
    OPENAI = "openai"
    GROQ = "groq"
    LITE_LLM = "litellm"

    def __str__(self):
        # Render the bare value ("gradio") instead of "ProviderType.GRADIO";
        # this ensures correct YAML serialization.
        return self.value

    @classmethod
    def values(cls):
        """Return the string value of every provider, in declaration order."""
        return [provider.value for provider in cls]
@@ -0,0 +1,171 @@
1
+ from typing import Any, Dict, List, Literal, Optional, Union
2
+ from urllib.parse import urlparse
3
+
4
+ from pydantic import BaseModel, Field, field_validator
5
+
6
+ from ..utils.urls import url_2_name
7
+
8
+
9
+ from .base import ProviderType
10
+
11
+
12
class GradioApiSignatureParam(BaseModel):
    """
    Represents a single parameter for a Gradio API signature.
    """

    name: str = Field(..., description="Name of the parameter.")
    has_default_value: bool = Field(
        ..., description="Indicates if the parameter has a default value."
    )
    default_value: Optional[Union[str, int, bool, float, list]] = Field(
        None,
        description="The default value of the parameter, which can be a string, integer, boolean, or list.",
    )
    python_type: str = Field(
        ...,
        description="The Python type of the parameter (e.g., str, int, bool, or a Literal).",
    )

    @field_validator("default_value", mode="before")
    @classmethod
    def validate_default_value(cls, value):
        """
        Collapse a list default to a single value before field validation.

        - A non-empty list is replaced by its first element.
        - An empty list becomes `None`.
        - Anything else passes through unchanged.
        """
        if not isinstance(value, list):
            return value
        if not value:
            return None
        return value[0]
43
+
44
+
45
class GradioApiSignatureParams(BaseModel):
    """
    A wrapper model to hold multiple API signature parameters.
    """

    # Required: no default is provided.
    params: List[GradioApiSignatureParam] = Field(
        ..., description="List of API signature parameters."
    )
53
+
54
+
55
class GradioApiSpec(BaseModel):
    """
    Represents a Gradio API specification, including the API name and its parameters.
    """

    api_name: str = Field(..., description="The endpoint path of the API.")
    # Defaults to a fresh empty list per instance.
    params: List[GradioApiSignatureParam] = Field(
        default_factory=list, description="List of parameters required by the API."
    )
    # Disabled field, kept as found in the source:
    # response: Dict[str, Any] = Field(
    #     default_factory=dict, description="The expected response structure."
    # )
67
+
68
+
69
class GradioApiSpecs(BaseModel):
    """
    Contains multiple Gradio API specifications.
    """

    # Defaults to a fresh empty list per instance.
    apis: List[GradioApiSpec] = Field(
        default_factory=list, description="List of API specifications."
    )
77
+
78
+
79
class GradioProviderApiParam(BaseModel):
    """
    Represents a parameter to be used in a request to a Gradio API.
    """

    name: str = Field(..., description="The name of the parameter.")
    # Optional; accepts scalars as well as list/tuple/dict values. Defaults to None.
    value: Optional[Union[str, int, bool, float, list, tuple, dict]] = Field(
        None, description="The value of the parameter to be sent in the API request."
    )
88
+
89
+
90
class GradioResponseParserSignature(BaseModel):
    """
    Signature based parser response
    """

    # Fixed to the literal "signature".
    parser_type: Literal["signature"] = Field("signature", description="")
    content_type: Optional[str] = Field(
        default="text",
        description="Content type str, array etc.",
        examples=["json", "jsonl", "text", "array"],
    )
    # Path elements may be strings or integers; defaults to None.
    location: Optional[List[Union[str, int]]] = Field(
        None, description="Location of the response as sequence of integers"
    )
104
+
105
+
106
class GradioProviderApi(BaseModel):
    """
    Represents a Gradio API request, including its path and parameters.
    """

    path: str = Field(..., description="The API endpoint path.")
    params: Optional[List[GradioProviderApiParam]] = Field(
        None, description="Optional list of parameters for the API request."
    )
    # May be a string, a GradioResponseParserSignature, or a free-form dict;
    # how each form is interpreted is decided by the consumer (not shown here).
    transform_response: Optional[
        Union[str, GradioResponseParserSignature, Dict[str, Any]]
    ] = Field(
        None, description="Logic to extract Assistant Response from Gradio Response"
    )
120
+
121
+
122
class GradioProviderConfig(BaseModel):
    """
    Configuration model for a Gradio API provider.
    Supports both standard URLs and Hugging Face Space identifiers.
    """

    url: str = Field(
        ...,
        description="The base URL of the Gradio API provider or a Hugging Face Space ID.",
    )

    apis: Optional[List[GradioProviderApi]] = Field(
        default_factory=list,
        description="Optional list of APIs available for testing."
    )

    @field_validator("url", mode="before")
    @classmethod
    def validate_url_or_hf_space(cls, value: str) -> str:
        """
        Accept ``value`` unchanged if ``urlparse`` can process it.

        Raises:
            ValueError: if ``urlparse`` raises any exception for ``value``.
        """
        # NOTE(review): urlparse rarely raises for string input, so in
        # practice this accepts almost any string — confirm whether stricter
        # validation of the URL / 'org/app' form was intended.
        try:
            urlparse(value)
        except Exception:
            raise ValueError(
                f"Invalid URL or Hugging Face Space identifier: '{value}'. "
                "Provide a valid URL (https://example.com) or a Hugging Face Space in 'org/app' format."
            )
        return value

    def get_name(self) -> str:
        """
        Returns the result of ``url_2_name`` for the configured URL with ``level=3``.
        """
        return url_2_name(self.url, level=3)
155
+
156
+
157
class GradioProvider(BaseModel):
    """
    Represents a Gradio API provider, including its configuration.
    """

    # Fixed to the literal "gradio"; the default is taken from ProviderType.GRADIO.
    provider: Literal["gradio"] = Field(
        ProviderType.GRADIO.value, description="Provider ID, always set to 'gradio'."
    )
    config: GradioProviderConfig = Field(
        ..., description="Configuration details for the Gradio API provider."
    )
168
+
169
+
170
class GradioProviders(BaseModel):
    # Container for a list of Gradio providers; the list is required.
    providers: List[GradioProvider]
@@ -0,0 +1,27 @@
1
+ from typing import Literal, Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from ..providers.openai import BaseProviderConfig
6
+
7
+
8
class GroqProviderConfig(BaseProviderConfig):
    # Groq settings layered on BaseProviderConfig (imported from ..providers.openai).
    # NOTE(review): Groq documents its OpenAI-compatible base URL as
    # https://api.groq.com/openai/v1 — confirm this default against the
    # client that consumes `endpoint`.
    endpoint: Optional[str] = Field(
        default="https://api.groq.com/v1",
        description="Base URL of the Groq server or proxy endpoint.",
    )

    def get_name(self) -> str:
        """
        Returns the model name as the provider's name.
        """
        # `model` is not declared in this class; presumably inherited from
        # BaseProviderConfig — verify.
        return self.model
19
+
20
+
21
class GroqProvider(BaseModel):
    """Wrapper for Groq provider configuration."""

    # Fixed to the literal "groq".
    provider: Literal["groq"] = Field(
        "groq", description="Provider ID, always set to 'groq'."
    )
    # Required: no default is provided.
    config: GroqProviderConfig