dtx-models 0.18.2 (dtx_models-0.18.2-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtx_models/__init__.py +0 -0
- dtx_models/analysis.py +322 -0
- dtx_models/base.py +0 -0
- dtx_models/evaluator.py +273 -0
- dtx_models/exceptions.py +2 -0
- dtx_models/prompts.py +460 -0
- dtx_models/providers/__init__.py +0 -0
- dtx_models/providers/base.py +20 -0
- dtx_models/providers/gradio.py +171 -0
- dtx_models/providers/groq.py +27 -0
- dtx_models/providers/hf.py +161 -0
- dtx_models/providers/http.py +152 -0
- dtx_models/providers/litellm.py +21 -0
- dtx_models/providers/models_spec.py +229 -0
- dtx_models/providers/ollama.py +107 -0
- dtx_models/providers/openai.py +139 -0
- dtx_models/results.py +124 -0
- dtx_models/scope.py +208 -0
- dtx_models/tactic.py +52 -0
- dtx_models/target.py +255 -0
- dtx_models/template/__init__.py +0 -0
- dtx_models/template/prompts/__init__.py +0 -0
- dtx_models/template/prompts/base.py +49 -0
- dtx_models/template/prompts/langhub.py +79 -0
- dtx_models/utils/__init__.py +0 -0
- dtx_models/utils/urls.py +26 -0
- dtx_models-0.18.2.dist-info/METADATA +57 -0
- dtx_models-0.18.2.dist-info/RECORD +29 -0
- dtx_models-0.18.2.dist-info/WHEEL +4 -0
dtx_models/providers/hf.py
@@ -0,0 +1,161 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field, HttpUrl, field_serializer, model_validator

from ..evaluator import EvaluatorInScope
from ..prompts import SupportedFormat


class HuggingFaceTask(str, Enum):
    TEXT_GENERATION = "text-generation"
    TEXT2TEXT_GENERATION = "text2text-generation"
    TEXT_CLASSIFICATION = "text-classification"
    TOKEN_CLASSIFICATION = "token-classification"
    FEATURE_EXTRACTION = "feature-extraction"
    SENTENCE_SIMILARITY = "sentence-similarity"
    FILL_MASK = "fill-mask"


class DType(str, Enum):
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
    FLOAT32 = "float32"


class HFEndpoint(BaseModel):
    api_endpoint: Optional[HttpUrl] = Field(
        None, description="Custom API endpoint for the model."
    )
    api_key: Optional[str] = Field(None, description="Your HuggingFace API key.")


class HuggingFaceProviderParams(BaseModel):
    temperature: Optional[float] = Field(
        None, ge=0, le=1, description="Controls randomness in generation."
    )
    top_k: Optional[int] = Field(
        None, ge=1, description="Controls diversity via the top-k sampling strategy."
    )
    top_p: Optional[float] = Field(
        None, ge=0, le=1, description="Controls diversity via nucleus sampling."
    )
    repetition_penalty: Optional[float] = Field(
        None, ge=0, description="Penalty for repetition."
    )
    max_new_tokens: Optional[int] = Field(
        None, ge=1, description="The maximum number of new tokens to generate."
    )
    max_time: Optional[float] = Field(
        None, ge=0, description="The maximum time in seconds for the model to respond."
    )
    return_full_text: Optional[bool] = Field(
        None, description="Whether to return the full text or just the new text."
    )
    num_return_sequences: Optional[int] = Field(
        None, ge=1, description="The number of sequences to return."
    )
    do_sample: Optional[bool] = Field(None, description="Whether to sample the output.")
    use_cache: Optional[bool] = Field(None, description="Whether to use caching.")
    wait_for_model: Optional[bool] = Field(
        None, description="Whether to wait for the model to be ready."
    )
    extra_params: Optional[Dict[str, Any]] = Field(
        None, description="Additional parameters to pass to the HuggingFace API."
    )
    dtype: Optional[DType] = Field(
        DType.BFLOAT16, description="Data type for model computation."
    )

    @field_serializer("dtype")
    def serialize_dtype(self, dtype: Optional[DType]) -> Optional[str]:
        """Serialize the enum to a string (None-safe)."""
        return dtype.value if dtype else None


class HuggingFaceProviderConfig(BaseModel):
    model: str = Field(..., description="HuggingFace model name.")
    task: HuggingFaceTask = Field(
        ..., description="Task type for the HuggingFace model."
    )
    params: Optional[HuggingFaceProviderParams] = Field(
        None, description="Configuration parameters for the HuggingFace model."
    )
    endpoint: Optional[HFEndpoint] = Field(
        None, description="HuggingFace API endpoint configuration."
    )

    support_multi_turn: bool = Field(
        default=False,
        description="Whether the provider supports multi-turn conversations.",
    )

    supported_input_format: Optional[SupportedFormat] = Field(
        default=SupportedFormat.TEXT,
        description="Supported input format.",
    )

    id: Optional[str] = Field(
        None,
        description="HuggingFace model identifier computed as 'huggingface:<task>:<model>'",
    )

    preferred_evaluator: Optional[EvaluatorInScope] = Field(
        None,
        description="Preferred evaluator for the provider.",
    )

    @model_validator(mode="after")
    def compute_id(cls, values):
        if not values.id:
            values.id = f"huggingface:{values.task.value}:{values.model}"
        return values

    @field_serializer("task")
    def serialize_task(self, task: HuggingFaceTask) -> str:
        """Serialize the enum to a string."""
        return task.value

    @field_serializer("supported_input_format")
    def serialize_supported_format(
        self, supported_input_format: SupportedFormat
    ) -> str:
        """Serialize the enum to a string."""
        return str(supported_input_format)

    def get_name(self) -> str:
        """
        Returns the model name as the provider's name.
        """
        return self.model


class HuggingFaceGuardModelsProviderConfig(HuggingFaceProviderConfig):
    safe_value: str = Field(
        "",
        description="JSONPath expression indicating the condition for input to be safe.",
    )

    """
    Example Usage:
    ----------------
    provider = HuggingFaceGuardModelsProviderConfig(safe_value="$.score[?(@ < 0.6)]")

    This JSONPath expression checks that all scores in the input JSON are below 0.6.
    """


class HFProvider(BaseModel):
    provider: Literal["huggingface"] = Field(
        "huggingface", description="Provider ID, always set to 'huggingface'."
    )
    config: HuggingFaceProviderConfig


class HFModels(BaseModel):
    huggingface: List[HuggingFaceProviderConfig | HuggingFaceGuardModelsProviderConfig] = Field(
        default_factory=list, description="List of predefined models."
    )
dtx_models/providers/http.py
@@ -0,0 +1,152 @@
from enum import Enum
from typing import Any, Callable, Dict, Literal, Optional, Union

from dtx_models.utils.urls import url_2_name
from pydantic import (
    BaseModel,
    Field,
    HttpUrl,
    field_serializer,
    field_validator,
)


class HttpMethod(str, Enum):
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"

    def __str__(self):
        return self.value  # Ensures correct YAML serialization

    @classmethod
    def values(cls):
        return [member.value for member in cls]


class BaseHttpProvider(BaseModel):
    """Base class for common HTTP provider attributes."""

    max_retries: int = Field(
        4, ge=0, description="Maximum number of retries for failed requests."
    )
    validate_response: Optional[Union[str, Callable[[int], bool]]] = Field(
        None, description="Validation function for HTTP response status."
    )
    transform_request: Optional[Union[str, Dict[str, Any]]] = Field(
        None, description="Template or mapping to modify request before sending."
    )
    transform_response: Optional[Union[str, Dict[str, Any]]] = Field(
        None, description="Transformation logic for processing API responses."
    )

    example_response: Optional[str] = Field(
        None, description="Example response dumped as sample if available"
    )


class StructuredHttpProviderConfig(BaseHttpProvider):
    """Defines an HTTP request provider with structured fields (URL, method, headers, body)."""

    url: str = Field(..., description="The HTTP endpoint URL.")
    method: HttpMethod = Field(..., description="HTTP method.")
    headers: Optional[Dict[str, str]] = Field(
        default_factory=dict, description="HTTP headers."
    )
    body: Optional[Union[str, Dict[str, Any]]] = Field(
        None, description="HTTP request body as JSON or a form-urlencoded string."
    )

    @field_serializer("method")
    def serialize_http_method(self, method: HttpMethod) -> Optional[str]:
        """Serialize the HTTP method to a string."""
        return str(method.value) if method else None

    @field_validator("url", mode="before")
    @classmethod
    def validate_base_url(cls, value: str) -> str:
        """Validate that the URL is well-formed."""
        try:
            HttpUrl(value)  # Validate the URL format
        except ValueError as e:
            raise ValueError(f"Invalid base_url: {value}") from e
        return value

    @field_validator("body", mode="before")
    @classmethod
    def validate_body(
        cls, value: Optional[Union[str, Dict[str, Any]]]
    ) -> Optional[Union[str, Dict[str, Any]]]:
        """Ensure body is either a JSON dictionary or a form-urlencoded string."""
        if value is not None and not isinstance(value, (str, dict)):
            raise ValueError(
                "Body must be a dictionary (JSON) or a string (form-urlencoded)."
            )
        return value

    def get_name(self) -> str:
        """Generate a name from the URL using scheme:host:port:/path"""
        return url_2_name(self.url, level=3)


class RawHttpProviderConfig(BaseHttpProvider):
    """Defines an HTTP request provider using raw_request instead of structured fields."""

    raw_request: str = Field(..., description="Full raw HTTP request in text format.")
    use_https: bool = Field(default=False, description="Whether to use HTTPS.")

    @field_validator("raw_request", mode="before")
    @classmethod
    def validate_raw_request(cls, value: str) -> str:
        """Ensure raw_request is a valid string."""
        if not isinstance(value, str):
            raise ValueError("Raw HTTP request must be a string.")
        return value

    def get_name(self) -> str:
        """
        Extract the HTTP method and path from the first line of the raw request.
        Returns something like 'GET /api/v1/resource'.
        """
        try:
            first_line = self.raw_request.strip().splitlines()[0]
            parts = first_line.split()
            method = parts[0] if len(parts) > 0 else "UNKNOWN"
            path = parts[1] if len(parts) > 1 else "/"
            return f"{method} {path}"
        except Exception:
            return "INVALID RAW REQUEST"


class HttpProvider(BaseModel):
    provider: Literal["http"] = Field(
        "http", description="Provider ID, always set to 'http'."
    )
    config: StructuredHttpProviderConfig | RawHttpProviderConfig


class BaseHttpProviderResponse(BaseModel):
    """
    Represents a standardized response format for HTTP providers.

    - `mutations`: Stores any changes or replacements applied to the request.
    - `request`: Stores the raw request sent (either as a string or a structured dictionary).
    - `response`: Stores the raw response received (either as a string or a structured dictionary).
    """

    mutations: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Replacements performed in the request"
    )
    # request: Optional[Union[str, Dict[str, Any]]] = Field(
    #     default_factory=dict, description="Request Sent"
    # )
    response: Optional[Union[str, Dict[str, Any]]] = Field(
        default_factory=dict, description="Response Received"
    )


# class Providers(BaseModel):
#     providers: List[HttpProvider]
dtx_models/providers/litellm.py
@@ -0,0 +1,21 @@
from typing import Literal, Optional

from pydantic import BaseModel, Field

from .openai import OpenaiProviderConfig


class LitellmProviderConfig(OpenaiProviderConfig):
    endpoint: Optional[str] = Field(
        default=None,
        description="Base URL of the OpenAI-compatible server or proxy endpoint.",
    )


class LitellmProvider(BaseModel):
    """Wrapper for LiteLLM provider configuration."""

    provider: Literal["litellm"] = Field(
        "litellm", description="Provider ID, always set to 'litellm'."
    )
    config: LitellmProviderConfig
dtx_models/providers/models_spec.py
@@ -0,0 +1,229 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field

from ..providers.base import ProviderType

# --- Define Modalities ---


class Modality(str, Enum):
    TEXT = "text"
    IMAGE = "image"
    CODE = "code"


# --- Define Task Types ---


class ModelTaskType(str, Enum):
    GENERATION = "generation"
    CLASSIFICATION = "classification"
    EMBEDDING = "embedding"


# --- Define Model Schema ---


class ModelSpec(BaseModel):
    name: str = Field(..., description="Model name, e.g., gpt-4o, llama3.")
    task: ModelTaskType = Field(..., description="Primary task type of the model.")
    modalities: List[Modality] = Field(
        ..., description="Supported modalities (text, image, code, etc.)."
    )
    description: Optional[str] = Field(
        None, description="Optional description of the model."
    )
    context_window: Optional[int] = Field(None, description="Context window in tokens.")
    max_completion_tokens: Optional[int] = Field(
        None, description="Maximum completion tokens."
    )


# --- Define Provider Schema ---


class ProviderModels(BaseModel):
    provider: ProviderType = Field(
        ..., description="Name of the provider, e.g., openai, groq, etc."
    )
    models: List[ModelSpec] = Field(
        ..., description="List of models available under this provider."
    )


# --- Example Repository ---

models_repository: List[ProviderModels] = [
    ProviderModels(
        provider=ProviderType.OPENAI,
        models=[
            ModelSpec(
                name="gpt-4.5-preview",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT, Modality.CODE],
                description="Largest and most capable GPT model.",
                context_window=128000,
                max_completion_tokens=4096,
            ),
            ModelSpec(
                name="gpt-4o",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT, Modality.CODE],
                description="Fast, intelligent, flexible GPT model.",
                context_window=128000,
                max_completion_tokens=4096,
            ),
            ModelSpec(
                name="gpt-4o-mini",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT, Modality.CODE],
                description="Fast, affordable small model for focused tasks.",
                context_window=128000,
                max_completion_tokens=4096,
            ),
            ModelSpec(
                name="gpt-4-turbo",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT, Modality.CODE],
                description="An older high-intelligence GPT model.",
                context_window=128000,
                max_completion_tokens=4096,
            ),
            ModelSpec(
                name="gpt-3.5-turbo",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT],
                description="Legacy GPT model for cheaper chat and non-chat tasks.",
                context_window=16384,
                max_completion_tokens=4096,
            ),
            ModelSpec(
                name="text-embedding-3-large",
                task=ModelTaskType.EMBEDDING,
                modalities=[Modality.TEXT],
                description="Most capable embedding model.",
                context_window=None,
                max_completion_tokens=None,
            ),
            ModelSpec(
                name="omni-moderation-latest",
                task=ModelTaskType.CLASSIFICATION,
                modalities=[Modality.TEXT, Modality.IMAGE],
                description="Identify potentially harmful content in text and images.",
                context_window=None,
                max_completion_tokens=None,
            ),
        ],
    ),
    ProviderModels(
        provider=ProviderType.GROQ,
        models=[
            ModelSpec(
                name="llama-3.3-70b-versatile",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT],
                description="Meta's versatile 70B parameter model.",
                context_window=128000,
                max_completion_tokens=32768,
            ),
            ModelSpec(
                name="llama-3.1-8b-instant",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT],
                description="Meta's 8B parameter instant model.",
                context_window=128000,
                max_completion_tokens=8192,
            ),
            ModelSpec(
                name="llama-guard-3-8b",
                task=ModelTaskType.CLASSIFICATION,
                modalities=[Modality.TEXT],
                description="Meta's guard model for moderation and classification.",
                context_window=8192,
                max_completion_tokens=None,
            ),
            ModelSpec(
                name="llama3-70b-8192",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT],
                description="Meta's 70B model with 8k context window.",
                context_window=8192,
                max_completion_tokens=None,
            ),
            ModelSpec(
                name="llama3-8b-8192",
                task=ModelTaskType.GENERATION,
                modalities=[Modality.TEXT],
                description="Meta's 8B model with 8k context window.",
                context_window=8192,
                max_completion_tokens=None,
            ),
            ModelSpec(
                name="whisper-large-v3",
                task=ModelTaskType.CLASSIFICATION,
                modalities=[Modality.TEXT],
                description="OpenAI's general-purpose speech recognition model.",
                context_window=None,
                max_completion_tokens=None,
            ),
            ModelSpec(
                name="whisper-large-v3-turbo",
                task=ModelTaskType.CLASSIFICATION,
                modalities=[Modality.TEXT],
                description="Turbo version of OpenAI's Whisper speech recognition model.",
                context_window=None,
                max_completion_tokens=None,
            ),
        ],
    ),
]


# --- ModelSpecRepo class ---


class ModelSpecRepo:
    def __init__(self, repository: Optional[List[ProviderModels]] = None):
        self.repository = repository or models_repository

    def get_model_by_name(self, name: str) -> Optional[ModelSpec]:
        for provider in self.repository:
            for model in provider.models:
                if model.name == name:
                    return model
        return None

    def get_models_by_provider(self, provider_name: ProviderType) -> List[ModelSpec]:
        for provider in self.repository:
            if provider.provider == provider_name:
                return provider.models
        return []

    def list_all_models(self) -> List[ModelSpec]:
        all_models = []
        for provider in self.repository:
            all_models.extend(provider.models)
        return all_models

    def get_models_by_task(self, task_type: ModelTaskType) -> List[ModelSpec]:
        models = []
        for provider in self.repository:
            for model in provider.models:
                if model.task == task_type:
                    models.append(model)
        return models


# Example usage
if __name__ == "__main__":
    # Example: Use ModelSpecRepo
    repo = ModelSpecRepo()
    print("\nFetching model by name 'gpt-4o':")
    model = repo.get_model_by_name("gpt-4o")
    print(model)

    print("\nListing all generation models:")
    for model in repo.get_models_by_task(ModelTaskType.GENERATION):
        print(f"- {model.name}")
dtx_models/providers/ollama.py
@@ -0,0 +1,107 @@
from enum import Enum
from typing import Any, Dict, Literal, Optional

from dtx_models.utils.urls import url_2_name
from pydantic import BaseModel, Field, field_serializer, model_validator

## Supported Yaml

# """
# ollama:
#   - model: llama3


# ollama:
#   - model: llama3
#     task: text-generation
#     endpoint: http://localhost:11434
#     params:
#       temperature: 1.0
#       top_k: 50
#       top_p: 1.0
#       repeat_penalty: 1.1
#       max_tokens: 512
#       num_return_sequences: 1
#       extra_params:
#         stop: ["###", "User:"]
# """


class OllamaTask(str, Enum):
    TEXT_GENERATION = "text-generation"
    TEXT_CLASSIFICATION = "text-classification"
    SAFETY_CLASSIFICATION = "safety-classification"

    def __str__(self):
        return self.value  # Ensures correct YAML serialization

    @classmethod
    def values(cls):
        return [member.value for member in cls]


class OllamaProviderParams(BaseModel):
    temperature: Optional[float] = Field(
        None, ge=0, le=1, description="Controls randomness in generation."
    )
    top_k: Optional[int] = Field(
        None, ge=1, description="Top-k sampling strategy for generation."
    )
    top_p: Optional[float] = Field(
        None, ge=0, le=1, description="Nucleus sampling (top-p)."
    )
    repeat_penalty: Optional[float] = Field(
        None, ge=0, description="Penalty for repeating tokens."
    )
    max_tokens: Optional[int] = Field(
        None, ge=1, description="Maximum number of tokens to generate."
    )
    num_return_sequences: Optional[int] = Field(
        None, ge=1, description="Number of sequences to return."
    )
    extra_params: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Additional parameters for Ollama model invocation.",
    )


class OllamaProviderConfig(BaseModel):
    model: str = Field(..., description="Ollama model name (e.g. llama3, mistral).")
    task: Optional[OllamaTask] = Field(
        default=None,
        description="Task type for the Ollama model. If not provided, will be inferred from the model name.",
    )
    params: Optional[OllamaProviderParams] = Field(
        None, description="Optional parameters for customizing model behavior."
    )
    endpoint: Optional[str] = Field(
        default="http://localhost:11434", description="Base URL of the Ollama server."
    )

    @model_validator(mode="after")
    def compute_fields(cls, values):
        if not values.task:
            if "guard" in values.model:
                values.task = OllamaTask.TEXT_CLASSIFICATION
            else:
                values.task = OllamaTask.TEXT_GENERATION

        return values

    @field_serializer("task")
    def serialize_task(self, task: OllamaTask) -> str:
        return task.value

    def get_name(self) -> str:
        """
        Returns a name like 'llama3:http:localhost:11434:/' by combining model and formatted endpoint.
        """
        return f"{self.model}:{url_2_name(self.endpoint, level=3)}"


class OllamaProvider(BaseModel):
    provider: Literal["ollama"] = Field(
        "ollama", description="Provider ID, always set to 'ollama'."
    )
    config: OllamaProviderConfig