kubrick-cli 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,136 @@
1
+ """Base provider adapter interface."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, Iterator, List, Optional
5
+
6
+
7
class ProviderMetadata:
    """
    Metadata for a provider that describes its configuration needs.

    This allows providers to be self-describing and enables automatic
    setup wizard generation and provider discovery.
    """

    def __init__(
        self,
        name: str,
        display_name: str,
        description: str,
        config_fields: List[Dict[str, str]],
    ):
        """
        Initialize provider metadata.

        Args:
            name: Internal provider name (e.g., "triton", "openai")
            display_name: Human-readable name (e.g., "Triton", "OpenAI")
            description: Brief description shown in setup wizard
            config_fields: List of configuration field definitions.
                Each field should be a dict with:
                - key: Config key name (e.g., "openai_api_key")
                - label: Display label (e.g., "OpenAI API key")
                - type: Field type ("text", "password", "url")
                - default: Default value (optional)
                - help_text: Help text or URL (optional)
        """
        self.name = name
        self.display_name = display_name
        self.description = description
        # Stored by reference, not copied: mutating the caller's list after
        # construction also changes this metadata.
        self.config_fields = config_fields

    def __repr__(self) -> str:
        """Return a debug representation showing every stored attribute."""
        return (
            f"{type(self).__name__}("
            f"name={self.name!r}, "
            f"display_name={self.display_name!r}, "
            f"description={self.description!r}, "
            f"config_fields={self.config_fields!r})"
        )
41
+
42
+
43
class ProviderAdapter(ABC):
    """
    Abstract base class for LLM provider adapters.

    All provider implementations must inherit from this class and implement
    the required methods for streaming and non-streaming generation.

    To create a truly plug-and-play provider:
    1. Inherit from ProviderAdapter
    2. Define METADATA as a class attribute with ProviderMetadata
    3. Implement all abstract methods
    4. Place the file in kubrick_cli/providers/ directory

    The provider will be automatically discovered and available in the setup wizard.
    """

    # Left as None on the base class; concrete providers must override it
    # with a ProviderMetadata instance to be picked up by discovery.
    METADATA: Optional[ProviderMetadata] = None

    @abstractmethod
    def generate_streaming(
        self, messages: List[Dict[str, str]], stream_options: Optional[Dict] = None
    ) -> Iterator[str]:
        """
        Generate streaming response from LLM.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional streaming parameters

        Yields:
            Text chunks as they arrive
        """

    @abstractmethod
    def generate(
        self, messages: List[Dict[str, str]], stream_options: Optional[Dict] = None
    ) -> str:
        """
        Generate non-streaming response from LLM.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional parameters

        Returns:
            Complete response text
        """

    @abstractmethod
    def is_healthy(self) -> bool:
        """
        Check if the provider is healthy and accessible.

        Returns:
            True if healthy, False otherwise
        """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """
        Get the name of this provider.

        Returns:
            Provider name (e.g., "triton", "openai", "anthropic")
        """

    @property
    @abstractmethod
    def model_name(self) -> str:
        """
        Get the model name being used.

        Returns:
            Model name/identifier
        """

    def set_model(self, model_name: str) -> None:
        """
        Set the model name dynamically (optional, for model switching).

        Args:
            model_name: New model name

        Note:
            Default implementation does nothing. Override in subclasses
            that support model switching.
        """
@@ -0,0 +1,161 @@
1
+ """Provider factory with automatic provider discovery."""
2
+
3
import importlib
import inspect
import pkgutil
import sys
from pathlib import Path
from typing import Dict, List, Type

from .base import ProviderAdapter, ProviderMetadata
10
+
11
+
12
class ProviderFactory:
    """
    Factory for creating provider adapter instances with automatic discovery.

    This factory automatically discovers all concrete provider classes in the
    providers directory that inherit from ProviderAdapter and have METADATA
    defined.
    """

    # Lower-cased METADATA.name -> provider class. Class-level state, shared
    # by every caller for the lifetime of the process.
    _provider_registry: Dict[str, Type[ProviderAdapter]] = {}
    # Set after the first scan so discovery runs at most once.
    _discovered = False

    # Sibling modules that never contain concrete providers.
    _SKIPPED_MODULES = frozenset({"base", "factory", "__init__"})

    @classmethod
    def _discover_providers(cls):
        """
        Discover all provider classes in the providers directory.

        This method scans the modules next to this file for classes that:
        1. Inherit from ProviderAdapter
        2. Are not the base ProviderAdapter class itself
        3. Are concrete (no unimplemented abstract methods)
        4. Have a non-None METADATA attribute

        A module that fails to load is reported on stderr and skipped, so one
        broken provider does not hide the others.
        """
        if cls._discovered:
            return

        providers_dir = Path(__file__).parent

        for module_info in pkgutil.iter_modules([str(providers_dir)]):
            module_name = module_info.name
            if module_name in cls._SKIPPED_MODULES:
                continue

            try:
                # Resolve relative to this package rather than a hard-coded
                # "kubrick_cli.providers" so discovery survives a rename.
                module = importlib.import_module(f".{module_name}", __package__)
                for _, obj in inspect.getmembers(module, inspect.isclass):
                    if cls._is_registrable(obj):
                        cls._provider_registry[obj.METADATA.name.lower()] = obj
            except Exception as e:
                # Deliberate best-effort: warn (on stderr, not stdout) and
                # keep scanning the remaining modules.
                print(
                    f"Warning: Failed to load provider from {module_name}: {e}",
                    file=sys.stderr,
                )
                continue

        cls._discovered = True

    @staticmethod
    def _is_registrable(obj) -> bool:
        """Return True if the class *obj* is a concrete, self-describing provider."""
        return (
            issubclass(obj, ProviderAdapter)
            and obj is not ProviderAdapter
            # Abstract intermediates cannot be instantiated by create_provider.
            and not inspect.isabstract(obj)
            and getattr(obj, "METADATA", None) is not None
        )

    @staticmethod
    def _resolve_config_values(metadata: ProviderMetadata, config: Dict) -> Dict:
        """
        Collect a value for every metadata config field.

        Values come from ``config``, falling back to the field's ``default``.

        Raises:
            ValueError: If a field has neither a config value nor a default
        """
        provider_config = {}
        for field in metadata.config_fields:
            key = field["key"]
            value = config.get(key)
            if value is None:
                if "default" not in field:
                    raise ValueError(
                        f"{metadata.display_name} configuration missing required field: '{key}'. "
                        f"Please run setup wizard or add '{key}' to config."
                    )
                value = field["default"]
            provider_config[key] = value
        return provider_config

    @staticmethod
    def _map_init_params(
        provider_class: Type[ProviderAdapter], provider_config: Dict
    ) -> Dict:
        """
        Map resolved config values onto the provider's ``__init__`` parameters.

        An exact key match wins; otherwise a key ending in ``_<param>`` is
        used (e.g. ``openai_api_key`` for an ``api_key`` parameter).
        Parameters with no matching key are left to their defaults.
        """
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        init_params = {}
        for param_name, param in inspect.signature(provider_class.__init__).parameters.items():
            if param_name == "self" or param.kind in variadic:
                continue

            # Check for an exact match first so that an earlier suffix match
            # cannot shadow a later exact match.
            if param_name in provider_config:
                init_params[param_name] = provider_config[param_name]
                continue

            suffix = f"_{param_name}"
            for key, value in provider_config.items():
                if key.endswith(suffix):
                    init_params[param_name] = value
                    break
        return init_params

    @classmethod
    def create_provider(cls, config: Dict) -> ProviderAdapter:
        """
        Create a provider instance based on configuration.

        Args:
            config: Configuration dictionary containing provider settings.
                ``config["provider"]`` selects the provider (case-insensitive).
                A missing or None value means "triton".

        Returns:
            ProviderAdapter instance

        Raises:
            ValueError: If provider is invalid or required credentials are missing
        """
        cls._discover_providers()

        provider_name = config.get("provider")
        if provider_name is None:
            # An explicit None must not crash on None.lower().
            provider_name = "triton"
        provider_name = provider_name.lower()

        if provider_name not in cls._provider_registry:
            available = ", ".join(cls._provider_registry)
            raise ValueError(
                f"Unknown provider: {provider_name}. "
                f"Available providers: {available}"
            )

        provider_class = cls._provider_registry[provider_name]
        provider_config = cls._resolve_config_values(provider_class.METADATA, config)
        init_params = cls._map_init_params(provider_class, provider_config)
        return provider_class(**init_params)

    @classmethod
    def list_available_providers(cls) -> List[ProviderMetadata]:
        """
        Get list of available providers with their metadata.

        Returns:
            List of ProviderMetadata objects, with "triton" first and the
            rest sorted by name
        """
        cls._discover_providers()

        providers = [
            provider_class.METADATA
            for provider_class in cls._provider_registry.values()
            if provider_class.METADATA
        ]

        def _triton_first(p):
            return (0, "") if p.name == "triton" else (1, p.name)

        providers.sort(key=_triton_first)
        return providers

    @classmethod
    def get_provider_metadata(cls, provider_name: str) -> ProviderMetadata:
        """
        Get metadata for a specific provider.

        Args:
            provider_name: Name of the provider (case-insensitive)

        Returns:
            ProviderMetadata object

        Raises:
            ValueError: If provider not found
        """
        cls._discover_providers()

        provider_name = provider_name.lower()
        if provider_name not in cls._provider_registry:
            raise ValueError(f"Provider not found: {provider_name}")

        return cls._provider_registry[provider_name].METADATA
@@ -0,0 +1,181 @@
1
+ """OpenAI provider adapter."""
2
+
3
+ import http.client
4
+ import json
5
+ import ssl
6
+ from typing import Dict, Iterator, List
7
+
8
+ from .base import ProviderAdapter, ProviderMetadata
9
+
10
+
11
class OpenAIProvider(ProviderAdapter):
    """Provider adapter for OpenAI API."""

    METADATA = ProviderMetadata(
        name="openai",
        display_name="OpenAI",
        description="OpenAI API (GPT-4, GPT-3.5, etc.)",
        config_fields=[
            {
                "key": "openai_api_key",
                "label": "OpenAI API key",
                "type": "password",
                "help_text": "Get your API key from: https://platform.openai.com/api-keys",
            },
            {
                "key": "openai_model",
                "label": "Model name",
                "type": "text",
                "default": "gpt-4",
            },
        ],
    )

    def __init__(self, openai_api_key: str, openai_model: str = "gpt-4"):
        """
        Initialize OpenAI provider.

        Args:
            openai_api_key: OpenAI API key
            openai_model: Model name (default: gpt-4)
        """
        self.api_key = openai_api_key
        self._model_name = openai_model
        self.base_url = "api.openai.com"
        self.timeout = 600  # seconds; long completions can stream for minutes

    def _connect(self, timeout) -> http.client.HTTPSConnection:
        """Create an HTTPS connection to the API host using default TLS verification."""
        return http.client.HTTPSConnection(
            self.base_url, 443, timeout=timeout, context=ssl.create_default_context()
        )

    @staticmethod
    def _parse_sse_line(raw_line: bytes):
        """
        Parse one raw line of the chat-completions event stream.

        Args:
            raw_line: A complete line (including any line terminator) as bytes

        Returns:
            A ``(done, content)`` pair. ``done`` is True for the ``[DONE]``
            sentinel. ``content`` is the delta text in the line, or "" when
            the line carries none (blank lines, non-JSON lines, empty deltas).
        """
        line = raw_line.decode("utf-8").strip()
        if not line:
            return False, ""

        if line.startswith("data: "):
            line = line[6:]

        if line == "[DONE]":
            return True, ""

        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            return False, ""

        if not isinstance(data, dict):
            return False, ""
        choices = data.get("choices")
        if not choices:
            return False, ""
        delta = choices[0].get("delta", {})
        # The API may send "content": null in role/finish deltas.
        return False, delta.get("content", "") or ""

    def generate_streaming(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> Iterator[str]:
        """
        Generate streaming response from OpenAI.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional streaming parameters; only
                "temperature" and "max_tokens" are forwarded

        Yields:
            Text chunks as they arrive

        Raises:
            RuntimeError: If the API responds with a non-200 status
        """
        payload = {
            "model": self._model_name,
            "messages": messages,
            "stream": True,
        }

        if stream_options:
            for option in ("temperature", "max_tokens"):
                if option in stream_options:
                    payload[option] = stream_options[option]

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        body = json.dumps(payload).encode("utf-8")
        conn = self._connect(self.timeout)

        try:
            conn.request("POST", "/v1/chat/completions", body=body, headers=headers)
            response = conn.getresponse()

            if response.status != 200:
                error_body = response.read().decode("utf-8", errors="replace")
                raise RuntimeError(f"OpenAI API error {response.status}: {error_body}")

            # Read whole lines as bytes and decode each one separately, so a
            # multi-byte UTF-8 character is never split across two decodes.
            # readline() also hands over each event as soon as its line is
            # complete, including a final line with no trailing newline.
            while True:
                raw_line = response.readline()
                if not raw_line:
                    break
                done, content = self._parse_sse_line(raw_line)
                if done:
                    return
                if content:
                    yield content
        finally:
            conn.close()

    def generate(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> str:
        """
        Generate non-streaming response from OpenAI.

        Implemented by concatenating the streamed chunks.

        Args:
            messages: List of message dicts with 'role' and 'content'
            stream_options: Optional parameters

        Returns:
            Complete response text
        """
        return "".join(self.generate_streaming(messages, stream_options))

    def is_healthy(self) -> bool:
        """
        Check if OpenAI API is accessible.

        Sends an authenticated GET /v1/models request. Any failure (network,
        TLS, auth) is reported as unhealthy instead of raised.

        Returns:
            True if healthy, False otherwise
        """
        conn = None
        try:
            conn = self._connect(10)
            conn.request(
                "GET", "/v1/models", headers={"Authorization": f"Bearer {self.api_key}"}
            )
            return conn.getresponse().status == 200
        except Exception:
            return False
        finally:
            # Always release the socket, even when request/getresponse raise.
            if conn is not None:
                conn.close()

    @property
    def provider_name(self) -> str:
        """Get provider name."""
        return "openai"

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    def set_model(self, model_name: str):
        """Set model name dynamically; used by subsequent requests."""
        self._model_name = model_name
@@ -0,0 +1,96 @@
1
+ """Triton provider adapter."""
2
+
3
+ from typing import Dict, Iterator, List
4
+
5
+ from ..triton_client import TritonLLMClient
6
+ from .base import ProviderAdapter, ProviderMetadata
7
+
8
+
9
class TritonProvider(ProviderAdapter):
    """
    Adapter exposing a Triton Inference Server through ProviderAdapter.

    Generation and health checks are forwarded to a TritonLLMClient built
    from the configured server URL and model name.
    """

    METADATA = ProviderMetadata(
        name="triton",
        display_name="Triton",
        description="Self-hosted Triton Inference Server (default)",
        config_fields=[
            {
                "key": "triton_url",
                "label": "Triton server URL",
                "type": "url",
                "default": "localhost:8000",
            },
            {
                "key": "triton_model",
                "label": "Triton model name",
                "type": "text",
                "default": "llm_decoupled",
            },
        ],
    )

    def __init__(
        self,
        triton_url: str = "localhost:8000",
        triton_model: str = "llm_decoupled",
    ):
        """
        Build the underlying Triton client and remember its settings.

        Args:
            triton_url: Address of the Triton server (default: localhost:8000)
            triton_model: Model to query on that server (default: llm_decoupled)
        """
        client = TritonLLMClient(url=triton_url, model_name=triton_model)
        self.client = client
        self._url = triton_url
        self._model_name = triton_model

    def generate_streaming(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> Iterator[str]:
        """
        Stream a completion by forwarding to the Triton client.

        Args:
            messages: Chat messages, each with 'role' and 'content'
            stream_options: Extra streaming parameters passed through unchanged

        Yields:
            Each text chunk produced by the client
        """
        chunk_stream = self.client.generate_streaming(messages, stream_options)
        yield from chunk_stream

    def generate(
        self, messages: List[Dict[str, str]], stream_options: Dict = None
    ) -> str:
        """
        Produce a full completion by forwarding to the Triton client.

        Args:
            messages: Chat messages, each with 'role' and 'content'
            stream_options: Extra parameters passed through unchanged

        Returns:
            Whatever text the client returns
        """
        completion = self.client.generate(messages, stream_options)
        return completion

    def is_healthy(self) -> bool:
        """
        Report server health as determined by the Triton client.

        Returns:
            The client's health-check result
        """
        healthy = self.client.is_healthy()
        return healthy

    @property
    def provider_name(self) -> str:
        """Identifier of this provider ("triton")."""
        return "triton"

    @property
    def model_name(self) -> str:
        """Model name given at construction."""
        return self._model_name

    @property
    def url(self) -> str:
        """Triton server URL given at construction."""
        return self._url