kubrick-cli 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kubrick_cli/__init__.py +47 -0
- kubrick_cli/agent_loop.py +274 -0
- kubrick_cli/classifier.py +194 -0
- kubrick_cli/config.py +247 -0
- kubrick_cli/display.py +154 -0
- kubrick_cli/execution_strategy.py +195 -0
- kubrick_cli/main.py +806 -0
- kubrick_cli/planning.py +319 -0
- kubrick_cli/progress.py +162 -0
- kubrick_cli/providers/__init__.py +6 -0
- kubrick_cli/providers/anthropic_provider.py +209 -0
- kubrick_cli/providers/base.py +136 -0
- kubrick_cli/providers/factory.py +161 -0
- kubrick_cli/providers/openai_provider.py +181 -0
- kubrick_cli/providers/triton_provider.py +96 -0
- kubrick_cli/safety.py +204 -0
- kubrick_cli/scheduler.py +183 -0
- kubrick_cli/setup_wizard.py +161 -0
- kubrick_cli/tools.py +400 -0
- kubrick_cli/triton_client.py +177 -0
- kubrick_cli-0.1.4.dist-info/METADATA +137 -0
- kubrick_cli-0.1.4.dist-info/RECORD +26 -0
- kubrick_cli-0.1.4.dist-info/WHEEL +5 -0
- kubrick_cli-0.1.4.dist-info/entry_points.txt +2 -0
- kubrick_cli-0.1.4.dist-info/licenses/LICENSE +21 -0
- kubrick_cli-0.1.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
"""Base provider adapter interface."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Dict, Iterator, List, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ProviderMetadata:
    """
    Self-describing configuration metadata for a provider.

    Carries everything the setup wizard and provider discovery need to
    present and configure a provider: an internal name, a human-readable
    display name, a short description, and the configuration fields the
    provider requires.
    """

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        config_fields: List[Dict[str, str]],
    ):
        """
        Record the provider's descriptive attributes.

        Args:
            name: Internal provider name (e.g., "triton", "openai")
            display_name: Human-readable name (e.g., "Triton", "OpenAI")
            description: Brief description shown in setup wizard
            config_fields: List of configuration field definitions.
                Each field should be a dict with:
                - key: Config key name (e.g., "openai_api_key")
                - label: Display label (e.g., "OpenAI API key")
                - type: Field type ("text", "password", "url")
                - default: Default value (optional)
                - help_text: Help text or URL (optional)
        """
        # Plain attribute storage; values are kept exactly as passed in.
        (self.name, self.display_name, self.description, self.config_fields) = (
            name,
            display_name,
            description,
            config_fields,
        )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ProviderAdapter(ABC):
    """
    Abstract base class for LLM provider adapters.

    All provider implementations must inherit from this class and implement
    the required methods for streaming and non-streaming generation.

    To create a truly plug-and-play provider:
    1. Inherit from ProviderAdapter
    2. Define METADATA as a class attribute with ProviderMetadata
    3. Implement all abstract methods
    4. Place the file in kubrick_cli/providers/ directory

    The provider will be automatically discovered and available in the setup wizard.
    """

    # Concrete subclasses override this with their own ProviderMetadata;
    # discovery skips classes that leave it as None. The annotation is a
    # string forward reference so it needs no eager evaluation.
    METADATA: Optional["ProviderMetadata"] = None

    @abstractmethod
    def generate_streaming(
        self, messages: List[Dict[str, str]], stream_options: Optional[Dict] = None
    ) -> Iterator[str]:
        """
        Generate streaming response from LLM.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional streaming parameters

        Yields:
            Text chunks as they arrive
        """
        pass

    @abstractmethod
    def generate(
        self, messages: List[Dict[str, str]], stream_options: Optional[Dict] = None
    ) -> str:
        """
        Generate non-streaming response from LLM.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional parameters

        Returns:
            Complete response text
        """
        pass

    @abstractmethod
    def is_healthy(self) -> bool:
        """
        Check if the provider is healthy and accessible.

        Returns:
            True if healthy, False otherwise
        """
        pass

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """
        Get the name of this provider.

        Returns:
            Provider name (e.g., "triton", "openai", "anthropic")
        """
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """
        Get the model name being used.

        Returns:
            Model name/identifier
        """
        pass

    def set_model(self, model_name: str):
        """
        Set the model name dynamically (optional, for model switching).

        Args:
            model_name: New model name

        Note:
            Default implementation does nothing. Override in subclasses
            that support model switching.
        """
        pass
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""Provider factory with automatic provider discovery."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import inspect
|
|
5
|
+
import pkgutil
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Type
|
|
8
|
+
|
|
9
|
+
from .base import ProviderAdapter, ProviderMetadata
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ProviderFactory:
    """
    Factory for creating provider adapter instances with automatic discovery.

    This factory automatically discovers all provider classes in the providers
    directory that inherit from ProviderAdapter and have METADATA defined.
    """

    # Maps lowercase provider name -> provider class. Class-level shared
    # state: discovery runs at most once per process and is cached here.
    _provider_registry: Dict[str, Type[ProviderAdapter]] = {}
    # Guards against re-scanning the package on every factory call.
    _discovered = False

    @classmethod
    def _discover_providers(cls) -> None:
        """
        Discover all provider classes in the providers directory.

        This method scans the providers package for classes that:
        1. Inherit from ProviderAdapter
        2. Are not the base ProviderAdapter class itself
        3. Have a METADATA attribute defined
        """
        if cls._discovered:
            return

        # Scan the package directory this file lives in for provider modules.
        providers_dir = Path(__file__).parent

        for module_info in pkgutil.iter_modules([str(providers_dir)]):
            module_name = module_info.name

            # Infrastructure modules never contain concrete providers.
            if module_name in ("base", "factory", "__init__"):
                continue

            try:
                # NOTE(review): import path is hard-coded to this package name;
                # renaming the package would silently break discovery.
                module = importlib.import_module(f"kubrick_cli.providers.{module_name}")

                for name, obj in inspect.getmembers(module, inspect.isclass):
                    if (
                        issubclass(obj, ProviderAdapter)
                        and obj is not ProviderAdapter
                        and hasattr(obj, "METADATA")
                        and obj.METADATA is not None
                    ):
                        # Register under the lowercased metadata name so
                        # later lookups are case-insensitive.
                        provider_name = obj.METADATA.name.lower()
                        cls._provider_registry[provider_name] = obj

            except Exception as e:
                # Best-effort discovery: a broken provider module is skipped
                # with a warning rather than breaking the whole CLI.
                print(f"Warning: Failed to load provider from {module_name}: {e}")
                continue

        # Set even when some modules failed to import: discovery is one-shot
        # and is not retried within this process.
        cls._discovered = True

    @classmethod
    def create_provider(cls, config: Dict) -> ProviderAdapter:
        """
        Create a provider instance based on configuration.

        Args:
            config: Configuration dictionary containing provider settings

        Returns:
            ProviderAdapter instance

        Raises:
            ValueError: If provider is invalid or required credentials are missing
        """
        cls._discover_providers()

        # "triton" is the implicit default provider when none is configured.
        provider_name = config.get("provider", "triton").lower()

        if provider_name not in cls._provider_registry:
            available = ", ".join(cls._provider_registry.keys())
            raise ValueError(
                f"Unknown provider: {provider_name}. "
                f"Available providers: {available}"
            )

        provider_class = cls._provider_registry[provider_name]
        metadata = provider_class.METADATA

        # Resolve each declared config field from `config`, falling back to
        # the field's declared default; a missing field with no default is
        # a hard configuration error.
        provider_config = {}
        for field in metadata.config_fields:
            key = field["key"]
            value = config.get(key)

            if value is None and "default" not in field:
                raise ValueError(
                    f"{metadata.display_name} configuration missing required field: '{key}'. "
                    f"Please run setup wizard or add '{key}' to config."
                )

            provider_config[key] = value if value is not None else field.get("default")

        # Map resolved config values onto the provider's __init__ parameters
        # by inspecting its signature, so providers need no factory-specific
        # glue code.
        init_signature = inspect.signature(provider_class.__init__)
        init_params = {}

        for param_name, param in init_signature.parameters.items():
            if param_name == "self":
                continue

            # A config key matches a parameter either exactly or by suffix,
            # e.g. key "triton_url" feeds a parameter named "url".
            for field in metadata.config_fields:
                if field["key"] == param_name or field["key"].endswith(
                    f"_{param_name}"
                ):
                    init_params[param_name] = provider_config[field["key"]]
                    break

        return provider_class(**init_params)

    @classmethod
    def list_available_providers(cls) -> List[ProviderMetadata]:
        """
        Get list of available providers with their metadata.

        Returns:
            List of ProviderMetadata objects
        """
        cls._discover_providers()

        providers = []
        for provider_class in cls._provider_registry.values():
            if provider_class.METADATA:
                providers.append(provider_class.METADATA)

        # Pin the default provider ("triton") to the front; everything else
        # is sorted alphabetically by name.
        def sort_key(p):
            return (0, "") if p.name == "triton" else (1, p.name)

        providers.sort(key=sort_key)
        return providers

    @classmethod
    def get_provider_metadata(cls, provider_name: str) -> ProviderMetadata:
        """
        Get metadata for a specific provider.

        Args:
            provider_name: Name of the provider

        Returns:
            ProviderMetadata object

        Raises:
            ValueError: If provider not found
        """
        cls._discover_providers()

        # Registry keys are lowercased at registration time, so normalize
        # the lookup the same way.
        provider_name = provider_name.lower()
        if provider_name not in cls._provider_registry:
            raise ValueError(f"Provider not found: {provider_name}")

        return cls._provider_registry[provider_name].METADATA
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""OpenAI provider adapter."""
|
|
2
|
+
|
|
3
|
+
import http.client
|
|
4
|
+
import json
|
|
5
|
+
import ssl
|
|
6
|
+
from typing import Dict, Iterator, List
|
|
7
|
+
|
|
8
|
+
from .base import ProviderAdapter, ProviderMetadata
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAIProvider(ProviderAdapter):
    """Provider adapter for OpenAI API."""

    METADATA = ProviderMetadata(
        name="openai",
        display_name="OpenAI",
        description="OpenAI API (GPT-4, GPT-3.5, etc.)",
        config_fields=[
            {
                "key": "openai_api_key",
                "label": "OpenAI API key",
                "type": "password",
                "help_text": "Get your API key from: https://platform.openai.com/api-keys",
            },
            {
                "key": "openai_model",
                "label": "Model name",
                "type": "text",
                "default": "gpt-4",
            },
        ],
    )

    def __init__(self, openai_api_key: str, openai_model: str = "gpt-4"):
        """
        Initialize OpenAI provider.

        Args:
            openai_api_key: OpenAI API key
            openai_model: Model name (default: gpt-4)
        """
        self.api_key = openai_api_key
        self._model_name = openai_model
        self.base_url = "api.openai.com"
        # Generous request timeout (seconds); long generations can be slow.
        self.timeout = 600

    @staticmethod
    def _iter_sse_payloads(response) -> Iterator[str]:
        """
        Yield decoded SSE payload lines from an HTTP response body.

        Buffers raw bytes and splits on newlines BEFORE decoding: a
        fixed-size read can end in the middle of a multi-byte UTF-8
        character, so decoding each network chunk independently could
        raise UnicodeDecodeError. Empty lines are skipped and a leading
        "data: " prefix is stripped. A final line without a trailing
        newline is flushed once the stream ends.
        """
        buffer = b""
        while True:
            chunk = response.read(1024)
            if not chunk:
                break

            buffer += chunk

            # Emit every complete (newline-terminated) line in the buffer.
            while b"\n" in buffer:
                raw_line, buffer = buffer.split(b"\n", 1)
                line = raw_line.decode("utf-8").strip()

                if not line:
                    continue

                if line.startswith("data: "):
                    line = line[6:]

                yield line

        # Flush a trailing line that was never newline-terminated; the old
        # implementation silently dropped it.
        tail = buffer.decode("utf-8").strip()
        if tail:
            if tail.startswith("data: "):
                tail = tail[6:]
            yield tail

    def generate_streaming(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> Iterator[str]:
        """
        Generate streaming response from OpenAI.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional streaming parameters; recognized keys
                are "temperature" and "max_tokens"

        Yields:
            Text chunks as they arrive

        Raises:
            Exception: If the API returns a non-200 status
        """
        payload = {
            "model": self._model_name,
            "messages": messages,
            "stream": True,
        }

        if stream_options:
            if "temperature" in stream_options:
                payload["temperature"] = stream_options["temperature"]
            if "max_tokens" in stream_options:
                payload["max_tokens"] = stream_options["max_tokens"]

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        body = json.dumps(payload).encode("utf-8")

        context = ssl.create_default_context()
        conn = http.client.HTTPSConnection(
            self.base_url, 443, timeout=self.timeout, context=context
        )

        try:
            conn.request("POST", "/v1/chat/completions", body=body, headers=headers)
            response = conn.getresponse()

            if response.status != 200:
                error_body = response.read().decode("utf-8")
                raise Exception(f"OpenAI API error {response.status}: {error_body}")

            for data_line in self._iter_sse_payloads(response):
                if data_line == "[DONE]":
                    return

                try:
                    data = json.loads(data_line)
                except json.JSONDecodeError:
                    # Non-JSON lines (e.g. SSE comments/other fields) are ignored.
                    continue

                if "choices" in data and len(data["choices"]) > 0:
                    delta = data["choices"][0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content

        finally:
            # Always release the connection, even on error or early return.
            conn.close()

    def generate(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> str:
        """
        Generate non-streaming response from OpenAI.

        Implemented by draining the streaming endpoint and joining the chunks.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional parameters

        Returns:
            Complete response text
        """
        return "".join(self.generate_streaming(messages, stream_options))

    def is_healthy(self) -> bool:
        """
        Check if OpenAI API is accessible.

        Performs a quick authenticated GET /v1/models with a short timeout.

        Returns:
            True if healthy, False otherwise
        """
        try:
            context = ssl.create_default_context()
            conn = http.client.HTTPSConnection(
                self.base_url, 443, timeout=10, context=context
            )
            headers = {
                "Authorization": f"Bearer {self.api_key}",
            }
            conn.request("GET", "/v1/models", headers=headers)
            response = conn.getresponse()
            conn.close()
            return response.status == 200
        except Exception:
            # Health probe is best-effort; any failure means "not healthy".
            return False

    @property
    def provider_name(self) -> str:
        """Get provider name."""
        return "openai"

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    def set_model(self, model_name: str):
        """Set model name dynamically."""
        self._model_name = model_name
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Triton provider adapter."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Iterator, List
|
|
4
|
+
|
|
5
|
+
from ..triton_client import TritonLLMClient
|
|
6
|
+
from .base import ProviderAdapter, ProviderMetadata
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TritonProvider(ProviderAdapter):
    """Adapter that exposes a Triton Inference Server through the provider interface."""

    METADATA = ProviderMetadata(
        name="triton",
        display_name="Triton",
        description="Self-hosted Triton Inference Server (default)",
        config_fields=[
            {
                "key": "triton_url",
                "label": "Triton server URL",
                "type": "url",
                "default": "localhost:8000",
            },
            {
                "key": "triton_model",
                "label": "Triton model name",
                "type": "text",
                "default": "llm_decoupled",
            },
        ],
    )

    def __init__(self, triton_url: str = "localhost:8000", triton_model: str = "llm_decoupled"):
        """
        Build the adapter and the Triton client it delegates to.

        Args:
            triton_url: Triton server URL (default: localhost:8000)
            triton_model: Triton model name (default: llm_decoupled)
        """
        # Remember connection details for the read-only properties below.
        self._url = triton_url
        self._model_name = triton_model
        self.client = TritonLLMClient(url=triton_url, model_name=triton_model)

    def generate_streaming(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> Iterator[str]:
        """
        Stream a response from Triton.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional streaming parameters

        Yields:
            Text chunks as they arrive from the underlying client
        """
        # Pure delegation: re-yield each chunk the client produces.
        for piece in self.client.generate_streaming(messages, stream_options):
            yield piece

    def generate(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> str:
        """
        Produce a complete (non-streaming) response from Triton.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional parameters

        Returns:
            The full response text
        """
        full_text = self.client.generate(messages, stream_options)
        return full_text

    def is_healthy(self) -> bool:
        """
        Report whether the Triton server answers its health check.

        Returns:
            True if healthy, False otherwise
        """
        return self.client.is_healthy()

    @property
    def provider_name(self) -> str:
        """Get provider name."""
        return "triton"

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    @property
    def url(self) -> str:
        """Get Triton server URL."""
        return self._url
|