dtx-models 0.18.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtx_models/__init__.py +0 -0
- dtx_models/analysis.py +322 -0
- dtx_models/base.py +0 -0
- dtx_models/evaluator.py +273 -0
- dtx_models/exceptions.py +2 -0
- dtx_models/prompts.py +460 -0
- dtx_models/providers/__init__.py +0 -0
- dtx_models/providers/base.py +20 -0
- dtx_models/providers/gradio.py +171 -0
- dtx_models/providers/groq.py +27 -0
- dtx_models/providers/hf.py +161 -0
- dtx_models/providers/http.py +152 -0
- dtx_models/providers/litellm.py +21 -0
- dtx_models/providers/models_spec.py +229 -0
- dtx_models/providers/ollama.py +107 -0
- dtx_models/providers/openai.py +139 -0
- dtx_models/results.py +124 -0
- dtx_models/scope.py +208 -0
- dtx_models/tactic.py +52 -0
- dtx_models/target.py +255 -0
- dtx_models/template/__init__.py +0 -0
- dtx_models/template/prompts/__init__.py +0 -0
- dtx_models/template/prompts/base.py +49 -0
- dtx_models/template/prompts/langhub.py +79 -0
- dtx_models/utils/__init__.py +0 -0
- dtx_models/utils/urls.py +26 -0
- dtx_models-0.18.2.dist-info/METADATA +57 -0
- dtx_models-0.18.2.dist-info/RECORD +29 -0
- dtx_models-0.18.2.dist-info/WHEEL +4 -0
dtx_models/prompts.py
ADDED
@@ -0,0 +1,460 @@
import hashlib
import json
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_serializer

from .exceptions import EntityNotFound

from .evaluator import (
    AnyKeywordBasedPromptEvaluation,
    ModelBasedPromptEvaluation,
)

# Define allowed roles
# RoleType = Literal["USER", "ASSISTANT", "SYSTEM"]


class RoleType(str, Enum):
    USER = "USER"
    ASSISTANT = "ASSISTANT"
    SYSTEM = "SYSTEM"

    def __str__(self):
        return self.value  # Ensures correct YAML serialization

    @classmethod
    def values(cls):
        return [member.value for member in cls]


class SupportedFormat(str, Enum):
    TEXT = "text"
    OPENAI = "openai"
    ALPACA = "alpaca"
    CHATML = "chatml"
    VICUNA = "vicuna"


class BaseTestPrompt(BaseModel):
    pass


class BaseTestStrPrompt(BaseTestPrompt):
    prompt: str = Field(
        description="Prompt to achieve the goal. "
        "Prompt can have variable templates that can be replaced with values. "
        "The templates should use curly brackets to specify template variables."
    )


class Turn(BaseModel):
    role: RoleType = Field(
        ..., description="The role in the conversation (USER, ASSISTANT, SYSTEM)."
    )
    message: Union[str, Any] = Field(
        ..., min_length=1, description="The message content."
    )

    @staticmethod
    def validate_message(value: str) -> str:
        """Ensure the message is not empty or just whitespace."""
        if isinstance(value, str):
            value = value.strip()
            if not value:
                raise ValueError("Message cannot be empty or only whitespace.")
        return value

    @field_serializer("role")
    def serialize_role(self, role: RoleType) -> str:
        """Serialize the role enum to a string."""
        return str(role)


class BaseMultiTurnConversation(BaseTestPrompt):
    turns: List[Turn]

    def _filter_turns(self, multi_turn: bool) -> List[Turn]:
        if multi_turn:
            return self.turns
        else:
            filtered_turns = [
                turn for turn in self.turns if turn.role == RoleType.SYSTEM
            ]
            first_user_turn = next(
                (turn for turn in self.turns if turn.role == RoleType.USER), None
            )
            if first_user_turn:
                filtered_turns.append(first_user_turn)
            return filtered_turns

    def to_openai_format(self, multi_turn: bool = True) -> List[Dict[str, str]]:
        """
        Convert the conversation turns into a dictionary format compatible with the OpenAI API.
        """
        turns = self._filter_turns(multi_turn)
        return [{"role": turn.role.lower(), "content": turn.message} for turn in turns]

    def to_alpaca_format(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation to Alpaca format (instruction-based structure).
        """
        turns = self._filter_turns(multi_turn)
        return "\n".join(f"### {turn.role}\n{turn.message}" for turn in turns)

    def to_chatml_format(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation to ChatML format (OpenAI ChatML used in models like GPT-4).
        """
        turns = self._filter_turns(multi_turn)
        return "\n".join(
            f"<|{turn.role.lower()}|>{turn.message}<|end|>" for turn in turns
        )

    def to_vicuna_format(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation to Vicuna format (similar to Alpaca but slightly different style).
        """
        turns = self._filter_turns(multi_turn)
        return "".join(
            f"USER: {turn.message}\nASSISTANT: "
            if turn.role == RoleType.USER
            else f"{turn.message}\n"
            for turn in turns
        ).strip()

    def to_text(self, multi_turn: bool = True) -> str:
        """
        Convert the conversation turns into a plain text format.
        """
        turns = self._filter_turns(multi_turn)
        return "\n".join(f"{turn.role}: {turn.message}" for turn in turns)

    def to_format(self, supported_format: SupportedFormat, multi_turn: bool = True):
        """
        Convert the conversation to the requested format.
        """
        if supported_format == SupportedFormat.TEXT:
            return self.to_text(multi_turn)
        elif supported_format == SupportedFormat.OPENAI:
            return self.to_openai_format(multi_turn)
        elif supported_format == SupportedFormat.ALPACA:
            return self.to_alpaca_format(multi_turn)
        elif supported_format == SupportedFormat.CHATML:
            return self.to_chatml_format(multi_turn)
        elif supported_format == SupportedFormat.VICUNA:
            return self.to_vicuna_format(multi_turn)
        else:
            raise ValueError(f"Unsupported format: {supported_format}")

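# --- Usage sketch (editor's illustration; not part of the wheel contents) ----
# Shows the converters defined above on a hypothetical conversation.
# multi_turn=False keeps SYSTEM turns plus only the first USER turn,
# as implemented by _filter_turns.
from dtx_models.prompts import (
    BaseMultiTurnConversation,
    RoleType,
    SupportedFormat,
    Turn,
)

conv = BaseMultiTurnConversation(
    turns=[
        Turn(role=RoleType.SYSTEM, message="You are a helpful assistant."),
        Turn(role=RoleType.USER, message="Hi"),
        Turn(role=RoleType.ASSISTANT, message="Hello!"),
        Turn(role=RoleType.USER, message="Tell me a joke."),
    ]
)

print(conv.to_openai_format())
# [{'role': 'system', ...}, {'role': 'user', 'content': 'Hi'}, ...]

print(conv.to_text(multi_turn=False))
# SYSTEM: You are a helpful assistant.
# USER: Hi

print(conv.to_format(SupportedFormat.CHATML))
# <|system|>You are a helpful assistant.<|end|> ...
# --- End of usage sketch ------------------------------------------------------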
    def last_user_prompt(self) -> str:
        """
        Returns the last USER message from the conversation.
        Raises EntityNotFound if no USER message is found.
        """
        for turn in reversed(self.turns):
            if turn.role == RoleType.USER:
                return turn.message
        raise EntityNotFound("No USER message found in the conversation.")

    def first_user_prompt(self) -> str:
        """
        Returns the first USER message from the conversation.
        Raises EntityNotFound if no USER message is found.
        """
        for turn in self.turns:
            if turn.role == RoleType.USER:
                return turn.message
        raise EntityNotFound("No USER message found in the conversation.")

    def first_system_prompt(self) -> Optional[str]:
        """
        Returns the first SYSTEM message from the conversation,
        or None if no SYSTEM message is found.
        """
        for turn in self.turns:
            if turn.role == RoleType.SYSTEM:
                return turn.message
        return None

    def last_assistant_response(self) -> str:
        """
        Returns the last ASSISTANT message from the conversation.
        Raises EntityNotFound if no ASSISTANT message is found.
        """
        for turn in reversed(self.turns):
            if turn.role == RoleType.ASSISTANT:
                return turn.message
        raise EntityNotFound("No ASSISTANT message found in the conversation.")

    def has_last_assistant_response(self) -> bool:
        """
        Returns True if the conversation contains an ASSISTANT message.
        """
        for turn in reversed(self.turns):
            if turn.role == RoleType.ASSISTANT:
                return True
        return False

    def has_system_turn(self) -> bool:
        """
        Returns True if the conversation contains a SYSTEM turn.
        """
        return self.get_system_turn() is not None

    def get_system_turn(self) -> Optional[Turn]:
        """
        Returns the first SYSTEM turn, or None if there is none.
        """
        iterator = iter(self.turns)
        for turn in iterator:
            if turn.role in {RoleType.SYSTEM}:
                return turn
        return None

    def get_user_turns(self) -> Iterator[Turn]:
        """
        Returns an iterator over the turns with the USER role.
        """
        iterator = iter(self.turns)
        for turn in iterator:
            if turn.role in {RoleType.USER}:
                yield turn

    def get_complete_turns(self) -> Iterator[Turn]:
        """
        Yields USER and SYSTEM turns, each followed by an ASSISTANT turn.
        If the ASSISTANT response is missing after a USER turn, a placeholder is added.
        A SYSTEM turn is always yielded together with the next USER turn.
        """
        i = 0
        while i < len(self.turns):
            current_turn = self.turns[i]

            if current_turn.role == RoleType.SYSTEM:
                yield current_turn
                i += 1
                if i < len(self.turns) and self.turns[i].role == RoleType.USER:
                    yield self.turns[i]
                    i += 1
                    if i >= len(self.turns) or self.turns[i].role != RoleType.ASSISTANT:
                        yield Turn(role=RoleType.ASSISTANT, message="No Response")
                    else:
                        yield self.turns[i]
                        i += 1
            elif current_turn.role == RoleType.USER:
                yield current_turn
                i += 1
                if i >= len(self.turns) or self.turns[i].role != RoleType.ASSISTANT:
                    yield Turn(role=RoleType.ASSISTANT, message="No Response")
                else:
                    yield self.turns[i]
                    i += 1
            else:
                i += 1

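# --- Usage sketch (editor's illustration; not part of the wheel contents) ----
# get_complete_turns() pads an unanswered USER turn with a placeholder
# ASSISTANT turn; the messages below are hypothetical.
from dtx_models.prompts import BaseMultiTurnConversation, RoleType, Turn

conv = BaseMultiTurnConversation(
    turns=[
        Turn(role=RoleType.USER, message="First question"),
        Turn(role=RoleType.ASSISTANT, message="First answer"),
        Turn(role=RoleType.USER, message="Second question"),  # no recorded answer
    ]
)
for turn in conv.get_complete_turns():
    print(turn.role, "->", turn.message)
# USER -> First question
# ASSISTANT -> First answer
# USER -> Second question
# ASSISTANT -> No Response
# --- End of usage sketch ------------------------------------------------------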
    def add_turn(self, turn: Turn):
        """
        Appends a single turn to the conversation.
        """
        self.turns.append(turn)

    def add_turns(self, turns: List[Turn]):
        """
        Appends multiple turns to the conversation.
        """
        self.turns.extend(turns)


class MultiTurnTestPrompt(BaseMultiTurnConversation):
    id: Optional[str] = Field(
        default=None,
        description="Unique ID of the prompt, auto-generated based on content.",
    )
    evaluation_method: Union[
        ModelBasedPromptEvaluation, AnyKeywordBasedPromptEvaluation
    ] = Field(description="Evaluation method for the prompt.")
    module_name: str = Field(description="Module that has generated the prompt")
    policy: str = Field(default="")
    goal: str = Field(default="")
    strategy: str = Field(
        default="", description="Strategy used to generate the prompt"
    )
    base_prompt: str = Field(
        default="",
        description="Base prompt in its most simplistic form that needs to be answered by the AI agent. Generally it is the harmful prompt.",
    )

    model_config = ConfigDict(frozen=True)  # Make fields immutable

    def __init__(self, **data):
        """Override init to auto-generate the unique ID."""
        super().__init__(**data)
        object.__setattr__(self, "id", self.compute_unique_id())

    def compute_unique_id(self) -> str:
        """Computes the SHA-1 hash of the prompt as the ID."""
        prompt = str(self.to_openai_format())
        return hashlib.sha1(
            f"{self.strategy}-{self.goal}-{prompt}".encode()
        ).hexdigest()


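# --- Usage sketch (editor's illustration; not part of the wheel contents) ----
# The auto-generated ID is deterministic: it is the SHA-1 of
# "{strategy}-{goal}-{str(openai_format)}", so it can be recomputed by hand.
# The strategy, goal, and prompt values below are hypothetical.
import hashlib

strategy, goal = "jailbreak", "extract the system prompt"
openai_format = str([{"role": "user", "content": "Hi"}])
expected_id = hashlib.sha1(f"{strategy}-{goal}-{openai_format}".encode()).hexdigest()
print(expected_id)  # stable across runs and processes
# --- End of usage sketch ------------------------------------------------------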
class MultiTurnConversation(MultiTurnTestPrompt):
    pass


class MultiturnTestPrompts(BaseModel):
    risk_name: str
    test_prompts: List[MultiTurnTestPrompt] = Field(default_factory=list)


class BaseMultiTurnResponse(BaseMultiTurnConversation):
    """Multi-turn conversation produced while conversing with an agent or LLM."""


class MultiTurnResponse(BaseMultiTurnResponse):
    """Multi-turn conversation produced while conversing with an agent or LLM."""

    pass


class BaseMultiTurnAgentResponse(BaseMultiTurnResponse):
    """Multi-turn conversation produced while conversing with an agent or LLM."""

    response: Optional[Union[str, List[Any], Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Final agent response",
    )

    scores: Optional[Dict[str, float]] = Field(
        default_factory=dict,
        description="Optional classification labels with confidence scores.",
    )

    policy: Optional[str] = Field(
        default=None, description="Policy name that will be targeted by the goal"
    )
    goal: Optional[str] = Field(
        default=None, description="Goal that needs to be achieved"
    )


class BaseMultiTurnClassificationResponse(BaseMultiTurnResponse):
    """Multi-turn conversation produced while conversing with an agent or LLM."""

    labels: Dict[str, float] = Field(
        default_factory=dict,
        description="Optional classification labels with confidence scores.",
    )


class BaseMultiTurnResponseBuilder:
    def __init__(self):
        self.turns = []
        self.response = {}
        self.scores = {}
        self.goal = None
        self.policy = None

    def add_turn(self, turn: Turn):
        self.turns.append(turn)
        return self

    def add_turns(self, turns: List[Turn]):
        self.turns.extend(turns)
        return self

    def add_prompt(self, prompt: str, system_prompt: str = None):
        if system_prompt:
            self.turns.append(Turn(role=RoleType.SYSTEM, message=system_prompt))
        self.turns.append(Turn(role=RoleType.USER, message=prompt))
        return self

    def add_parsed_response(self, response):
        self.response = response
        return self  # return self for chaining, consistent with the other builder methods

    def add_prompt_and_response(self, prompt: str, response: str):
        self.turns.extend(
            [
                Turn(role=RoleType.USER, message=prompt),
                Turn(role=RoleType.ASSISTANT, message=response),
            ]
        )
        return self

    def add_turn_response(self, response: str):
        self.turns.extend(
            [
                Turn(role=RoleType.ASSISTANT, message=response),
            ]
        )
        return self

    def validate_sequence(self):
        """
        Validate that the conversation follows the sequence: [SYSTEM] (USER - ASSISTANT)+
        """
        if not self.turns:
            return

        if self.turns[0].role == RoleType.SYSTEM:
            expected_role = RoleType.USER
            start_index = 1
        else:
            expected_role = RoleType.USER
            start_index = 0

        for i in range(start_index, len(self.turns)):
            turn = self.turns[i]
            if turn.role != expected_role:
                raise ValueError(
                    f"Invalid conversation sequence at turn {i}: Expected {expected_role}, got {turn.role}"
                )
            expected_role = (
                RoleType.ASSISTANT if expected_role == RoleType.USER else RoleType.USER
            )

    def _get_response_from_last_turn(self):
        """
        If the last turn is from the assistant, return its content as the response.
        Tries to parse the message as JSON, otherwise falls back to the plain string.
        """
        if not self.turns:
            return None

        last_turn = self.turns[-1]

        if last_turn.role == RoleType.ASSISTANT:
            content = last_turn.message
            try:
                return json.loads(content)
            except (json.JSONDecodeError, TypeError):
                return content
        return None

    def add_prompt_attributes(self, prompt: MultiTurnConversation):
        self._prompt = prompt
        if self._prompt:
            # self._prompt is always set at this point, so read its attributes directly
            self.policy = getattr(self._prompt, "policy", None)
            self.goal = getattr(self._prompt, "goal", None)

    def build(self) -> BaseMultiTurnAgentResponse:
        return BaseMultiTurnAgentResponse(
            turns=self.turns,
            response=self.response,
            policy=self.policy,
            goal=self.goal,
        )
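# --- Usage sketch (editor's illustration; not part of the wheel contents) ----
# Chained use of the builder above; the messages are hypothetical.
# validate_sequence() raises if the turns do not alternate USER/ASSISTANT
# after an optional leading SYSTEM turn.
from dtx_models.prompts import BaseMultiTurnResponseBuilder

builder = (
    BaseMultiTurnResponseBuilder()
    .add_prompt("What is 2 + 2?", system_prompt="You are a calculator.")
    .add_turn_response('{"answer": 4}')
)
builder.validate_sequence()  # SYSTEM, USER, ASSISTANT -> valid, no exception
builder.response = builder._get_response_from_last_turn()  # parsed JSON: {'answer': 4}
result = builder.build()
print(result.response)  # {'answer': 4}
# --- End of usage sketch ------------------------------------------------------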
dtx_models/providers/__init__.py
File without changes

dtx_models/providers/base.py
ADDED
@@ -0,0 +1,20 @@
from enum import Enum


class ProviderType(str, Enum):
    ECHO = "echo"
    ELIZA = "eliza"
    HF = "huggingface"
    HTTP = "http"
    GRADIO = "gradio"
    OLLAMA = "ollama"
    OPENAI = "openai"
    GROQ = "groq"
    LITE_LLM = "litellm"

    def __str__(self):
        return self.value  # Ensures correct YAML serialization

    @classmethod
    def values(cls):
        return [member.value for member in cls]
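# --- Usage sketch (editor's illustration; not part of the wheel contents) ----
from dtx_models.providers.base import ProviderType

print(str(ProviderType.OLLAMA))  # "ollama"
print(ProviderType.values())     # ['echo', 'eliza', 'huggingface', 'http', ...]
# --- End of usage sketch ------------------------------------------------------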
dtx_models/providers/gradio.py
ADDED
@@ -0,0 +1,171 @@
from typing import Any, Dict, List, Literal, Optional, Union
from urllib.parse import urlparse

from pydantic import BaseModel, Field, field_validator

from ..utils.urls import url_2_name


from .base import ProviderType


class GradioApiSignatureParam(BaseModel):
    """
    Represents a single parameter for a Gradio API signature.
    """

    name: str = Field(..., description="Name of the parameter.")
    has_default_value: bool = Field(
        ..., description="Indicates if the parameter has a default value."
    )
    default_value: Optional[Union[str, int, bool, float, list]] = Field(
        None,
        description="The default value of the parameter, which can be a string, integer, boolean, float, or list.",
    )
    python_type: str = Field(
        ...,
        description="The Python type of the parameter (e.g., str, int, bool, or a Literal).",
    )

    @field_validator("default_value", mode="before")
    @classmethod
    def validate_default_value(cls, value):
        """
        Ensures the default_value is of the correct type.
        - If it's a list, pick the first element if possible.
        - If it's an empty list, return `None`.
        """
        if isinstance(value, list):
            return (
                value[0] if value else None
            )  # Take the first element, or None if empty
        return value  # Otherwise, keep it as it is


class GradioApiSignatureParams(BaseModel):
    """
    A wrapper model to hold multiple API signature parameters.
    """

    params: List[GradioApiSignatureParam] = Field(
        ..., description="List of API signature parameters."
    )


class GradioApiSpec(BaseModel):
    """
    Represents a Gradio API specification, including the API name and its parameters.
    """

    api_name: str = Field(..., description="The endpoint path of the API.")
    params: List[GradioApiSignatureParam] = Field(
        default_factory=list, description="List of parameters required by the API."
    )
    # response: Dict[str, Any] = Field(
    #     default_factory=dict, description="The expected response structure."
    # )


class GradioApiSpecs(BaseModel):
    """
    Contains multiple Gradio API specifications.
    """

    apis: List[GradioApiSpec] = Field(
        default_factory=list, description="List of API specifications."
    )


class GradioProviderApiParam(BaseModel):
    """
    Represents a parameter to be used in a request to a Gradio API.
    """

    name: str = Field(..., description="The name of the parameter.")
    value: Optional[Union[str, int, bool, float, list, tuple, dict]] = Field(
        None, description="The value of the parameter to be sent in the API request."
    )


class GradioResponseParserSignature(BaseModel):
    """
    Signature-based response parser.
    """

    parser_type: Literal["signature"] = Field(
        "signature", description="Parser type; always 'signature'."
    )
    content_type: Optional[str] = Field(
        default="text",
        description="Content type: str, array, etc.",
        examples=["json", "jsonl", "text", "array"],
    )
    location: Optional[List[Union[str, int]]] = Field(
        None, description="Location of the response as a sequence of keys and indexes"
    )


class GradioProviderApi(BaseModel):
    """
    Represents a Gradio API request, including its path and parameters.
    """

    path: str = Field(..., description="The API endpoint path.")
    params: Optional[List[GradioProviderApiParam]] = Field(
        None, description="Optional list of parameters for the API request."
    )
    transform_response: Optional[
        Union[str, GradioResponseParserSignature, Dict[str, Any]]
    ] = Field(
        None, description="Logic to extract the assistant response from the Gradio response"
    )


class GradioProviderConfig(BaseModel):
    """
    Configuration model for a Gradio API provider.
    Supports both standard URLs and Hugging Face Space identifiers.
    """

    url: str = Field(
        ...,
        description="The base URL of the Gradio API provider or a Hugging Face Space ID.",
    )

    apis: Optional[List[GradioProviderApi]] = Field(
        default_factory=list,
        description="Optional list of APIs available for testing."
    )

    @field_validator("url", mode="before")
    @classmethod
    def validate_url_or_hf_space(cls, value: str) -> str:
        try:
            _ = urlparse(value)
            return value
        except Exception:
            raise ValueError(
                f"Invalid URL or Hugging Face Space identifier: '{value}'. "
                "Provide a valid URL (https://example.com) or a Hugging Face Space in 'org/app' format."
            )

    def get_name(self) -> str:
        """
        Returns a name derived from the URL, limited to 3 path levels.
        """
        return url_2_name(self.url, level=3)


class GradioProvider(BaseModel):
    """
    Represents a Gradio API provider, including its configuration.
    """

    provider: Literal["gradio"] = Field(
        ProviderType.GRADIO.value, description="Provider ID, always set to 'gradio'."
    )
    config: GradioProviderConfig = Field(
        ..., description="Configuration details for the Gradio API provider."
    )


class GradioProviders(BaseModel):
    providers: List[GradioProvider]
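# --- Usage sketch (editor's illustration; not part of the wheel contents) ----
# Building a provider entry for a hypothetical Hugging Face Space ("org/chat-app")
# with one API whose output is parsed by the signature parser defined above.
from dtx_models.providers.gradio import (
    GradioProvider,
    GradioProviderApi,
    GradioProviderApiParam,
    GradioProviderConfig,
    GradioResponseParserSignature,
)

provider = GradioProvider(
    config=GradioProviderConfig(
        url="org/chat-app",  # a Space ID or a full URL such as https://example.com
        apis=[
            GradioProviderApi(
                path="/chat",
                params=[GradioProviderApiParam(name="message", value="Hello")],
                transform_response=GradioResponseParserSignature(
                    content_type="text",
                    location=[0],  # take the first element of the raw response
                ),
            )
        ],
    )
)
print(provider.provider)  # "gradio"
# --- End of usage sketch ------------------------------------------------------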
dtx_models/providers/groq.py
ADDED
@@ -0,0 +1,27 @@
from typing import Literal, Optional

from pydantic import BaseModel, Field

from ..providers.openai import BaseProviderConfig


class GroqProviderConfig(BaseProviderConfig):
    endpoint: Optional[str] = Field(
        default="https://api.groq.com/v1",
        description="Base URL of the Groq server or proxy endpoint.",
    )

    def get_name(self) -> str:
        """
        Returns the model name as the provider's name.
        """
        return self.model


class GroqProvider(BaseModel):
    """Wrapper for Groq provider configuration."""

    provider: Literal["groq"] = Field(
        "groq", description="Provider ID, always set to 'groq'."
    )
    config: GroqProviderConfig