lollms-client 0.33.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic. Click here for more details.
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/azure_openai/__init__.py +6 -10
- lollms_client/llm_bindings/claude/__init__.py +4 -7
- lollms_client/llm_bindings/gemini/__init__.py +3 -7
- lollms_client/llm_bindings/grok/__init__.py +3 -7
- lollms_client/llm_bindings/groq/__init__.py +4 -6
- lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +4 -6
- lollms_client/llm_bindings/litellm/__init__.py +15 -6
- lollms_client/llm_bindings/llamacpp/__init__.py +27 -9
- lollms_client/llm_bindings/lollms/__init__.py +24 -14
- lollms_client/llm_bindings/lollms_webui/__init__.py +6 -12
- lollms_client/llm_bindings/mistral/__init__.py +3 -5
- lollms_client/llm_bindings/ollama/__init__.py +6 -11
- lollms_client/llm_bindings/open_router/__init__.py +4 -6
- lollms_client/llm_bindings/openai/__init__.py +7 -14
- lollms_client/llm_bindings/openllm/__init__.py +12 -12
- lollms_client/llm_bindings/pythonllamacpp/__init__.py +1 -1
- lollms_client/llm_bindings/tensor_rt/__init__.py +8 -13
- lollms_client/llm_bindings/transformers/__init__.py +14 -6
- lollms_client/llm_bindings/vllm/__init__.py +16 -12
- lollms_client/lollms_core.py +296 -487
- lollms_client/lollms_discussion.py +431 -78
- lollms_client/lollms_llm_binding.py +191 -380
- lollms_client/lollms_mcp_binding.py +33 -2
- lollms_client/mcp_bindings/local_mcp/__init__.py +3 -2
- lollms_client/mcp_bindings/remote_mcp/__init__.py +6 -5
- lollms_client/mcp_bindings/standard_mcp/__init__.py +3 -5
- lollms_client/stt_bindings/lollms/__init__.py +6 -8
- lollms_client/stt_bindings/whisper/__init__.py +2 -4
- lollms_client/stt_bindings/whispercpp/__init__.py +15 -16
- lollms_client/tti_bindings/dalle/__init__.py +29 -28
- lollms_client/tti_bindings/diffusers/__init__.py +25 -21
- lollms_client/tti_bindings/gemini/__init__.py +215 -0
- lollms_client/tti_bindings/lollms/__init__.py +8 -9
- lollms_client-1.0.0.dist-info/METADATA +1214 -0
- lollms_client-1.0.0.dist-info/RECORD +69 -0
- {lollms_client-0.33.0.dist-info → lollms_client-1.0.0.dist-info}/top_level.txt +0 -2
- examples/article_summary/article_summary.py +0 -58
- examples/console_discussion/console_app.py +0 -266
- examples/console_discussion.py +0 -448
- examples/deep_analyze/deep_analyse.py +0 -30
- examples/deep_analyze/deep_analyze_multiple_files.py +0 -32
- examples/function_calling_with_local_custom_mcp.py +0 -250
- examples/generate_a_benchmark_for_safe_store.py +0 -89
- examples/generate_and_speak/generate_and_speak.py +0 -251
- examples/generate_game_sfx/generate_game_fx.py +0 -240
- examples/generate_text_with_multihop_rag_example.py +0 -210
- examples/gradio_chat_app.py +0 -228
- examples/gradio_lollms_chat.py +0 -259
- examples/internet_search_with_rag.py +0 -226
- examples/lollms_chat/calculator.py +0 -59
- examples/lollms_chat/derivative.py +0 -48
- examples/lollms_chat/test_openai_compatible_with_lollms_chat.py +0 -12
- examples/lollms_discussions_test.py +0 -155
- examples/mcp_examples/external_mcp.py +0 -267
- examples/mcp_examples/local_mcp.py +0 -171
- examples/mcp_examples/openai_mcp.py +0 -203
- examples/mcp_examples/run_remote_mcp_example_v2.py +0 -290
- examples/mcp_examples/run_standard_mcp_example.py +0 -204
- examples/simple_text_gen_test.py +0 -173
- examples/simple_text_gen_with_image_test.py +0 -178
- examples/test_local_models/local_chat.py +0 -9
- examples/text_2_audio.py +0 -77
- examples/text_2_image.py +0 -144
- examples/text_2_image_diffusers.py +0 -274
- examples/text_and_image_2_audio.py +0 -59
- examples/text_gen.py +0 -30
- examples/text_gen_system_prompt.py +0 -29
- lollms_client-0.33.0.dist-info/METADATA +0 -854
- lollms_client-0.33.0.dist-info/RECORD +0 -101
- test/test_lollms_discussion.py +0 -368
- {lollms_client-0.33.0.dist-info → lollms_client-1.0.0.dist-info}/WHEEL +0 -0
- {lollms_client-0.33.0.dist-info → lollms_client-1.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,15 +21,16 @@ class LocalMCPBinding(LollmsMCPBinding):
|
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
23
|
def __init__(self,
|
|
24
|
-
|
|
24
|
+
**kwargs: Any
|
|
25
|
+
):
|
|
25
26
|
"""
|
|
26
27
|
Initialize the LocalMCPBinding.
|
|
27
28
|
|
|
28
29
|
Args:
|
|
29
|
-
binding_name (str): The name of this binding.
|
|
30
30
|
tools_folder_path (str|Path) a folder where to find tools
|
|
31
31
|
"""
|
|
32
32
|
super().__init__(binding_name="LocalMCP")
|
|
33
|
+
tools_folder_path = kwargs.get("tools_folder_path")
|
|
33
34
|
if tools_folder_path:
|
|
34
35
|
try:
|
|
35
36
|
self.tools_folder_path: Optional[Path] = Path(tools_folder_path)
|
|
@@ -27,8 +27,8 @@ class RemoteMCPBinding(LollmsMCPBinding):
|
|
|
27
27
|
Tools from all connected servers are aggregated and prefixed with the server's alias.
|
|
28
28
|
"""
|
|
29
29
|
def __init__(self,
|
|
30
|
-
|
|
31
|
-
|
|
30
|
+
**kwargs: Any
|
|
31
|
+
):
|
|
32
32
|
"""
|
|
33
33
|
Initializes the binding to connect to multiple MCP servers.
|
|
34
34
|
|
|
@@ -41,10 +41,11 @@ class RemoteMCPBinding(LollmsMCPBinding):
|
|
|
41
41
|
"main_server": {"server_url": "http://localhost:8787", "auth_config": {}},
|
|
42
42
|
"experimental_server": {"server_url": "http://test.server:9000"}
|
|
43
43
|
}
|
|
44
|
-
**
|
|
44
|
+
**kwargs (Any): Additional configuration parameters.
|
|
45
45
|
"""
|
|
46
46
|
super().__init__(binding_name="remote_mcp")
|
|
47
47
|
# initialization in case no servers are present
|
|
48
|
+
servers_infos: Dict[str, Dict[str, Any]] = kwargs.get("servers_infos", {})
|
|
48
49
|
self.servers = None
|
|
49
50
|
if not MCP_LIBRARY_AVAILABLE:
|
|
50
51
|
ASCIIColors.error(f"{self.binding_name}: MCP library not available. This binding will be disabled.")
|
|
@@ -56,8 +57,8 @@ class RemoteMCPBinding(LollmsMCPBinding):
|
|
|
56
57
|
|
|
57
58
|
### NEW: Store the overall configuration
|
|
58
59
|
self.config = {
|
|
59
|
-
"servers_infos": servers_infos,
|
|
60
|
-
**
|
|
60
|
+
"servers_infos": kwargs.get("servers_infos"),
|
|
61
|
+
**kwargs
|
|
61
62
|
}
|
|
62
63
|
|
|
63
64
|
### NEW: State management for multiple servers.
|
|
@@ -48,12 +48,10 @@ class StandardMCPBinding(LollmsMCPBinding):
|
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
50
|
def __init__(self,
|
|
51
|
-
|
|
52
|
-
**other_config_params: Any):
|
|
51
|
+
**kwargs: Any):
|
|
53
52
|
super().__init__(binding_name="standard_mcp")
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
self.config.update(other_config_params)
|
|
53
|
+
self.config = kwargs
|
|
54
|
+
initial_servers = kwargs.get("initial_servers", {})
|
|
57
55
|
|
|
58
56
|
self._server_configs: Dict[str, Dict[str, Any]] = {}
|
|
59
57
|
# Type hint with ClientSession, actual obj if MCP_LIBRARY_AVAILABLE
|
|
@@ -14,10 +14,8 @@ class LollmsSTTBinding_Impl(LollmsSTTBinding):
|
|
|
14
14
|
"""Concrete implementation of the LollmsSTTBinding for the standard LOLLMS server."""
|
|
15
15
|
|
|
16
16
|
def __init__(self,
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
service_key: Optional[str] = None,
|
|
20
|
-
verify_ssl_certificate: bool = True):
|
|
17
|
+
**kwargs
|
|
18
|
+
):
|
|
21
19
|
"""
|
|
22
20
|
Initialize the LOLLMS STT binding.
|
|
23
21
|
|
|
@@ -28,10 +26,10 @@ class LollmsSTTBinding_Impl(LollmsSTTBinding):
|
|
|
28
26
|
verify_ssl_certificate (bool): Whether to verify SSL certificates.
|
|
29
27
|
"""
|
|
30
28
|
super().__init__("lollms")
|
|
31
|
-
self.host_address=host_address
|
|
32
|
-
self.model_name=model_name
|
|
33
|
-
self.service_key=service_key
|
|
34
|
-
self.verify_ssl_certificate=verify_ssl_certificate
|
|
29
|
+
self.host_address=kwargs.get("host_address")
|
|
30
|
+
self.model_name=kwargs.get("model_name")
|
|
31
|
+
self.service_key=kwargs.get("service_key")
|
|
32
|
+
self.verify_ssl_certificate=kwargs.get("verify_ssl_certificate")
|
|
35
33
|
|
|
36
34
|
def transcribe_audio(self, audio_path: Union[str, Path], model: Optional[str] = None, **kwargs) -> str:
|
|
37
35
|
"""
|
|
@@ -70,8 +70,6 @@ class WhisperSTTBinding(LollmsSTTBinding):
|
|
|
70
70
|
WHISPER_MODEL_SIZES = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2", "large-v3"]
|
|
71
71
|
|
|
72
72
|
def __init__(self,
|
|
73
|
-
model_name: str = "base", # Default Whisper model size
|
|
74
|
-
device: Optional[str] = None, # "cpu", "cuda", "mps", or None for auto
|
|
75
73
|
**kwargs # To catch any other LollmsSTTBinding standard args
|
|
76
74
|
):
|
|
77
75
|
"""
|
|
@@ -88,7 +86,7 @@ class WhisperSTTBinding(LollmsSTTBinding):
|
|
|
88
86
|
if not _whisper_installed:
|
|
89
87
|
raise ImportError(f"Whisper STT binding dependencies not met. Please ensure 'openai-whisper' and 'torch' are installed. Error: {_whisper_installation_error}")
|
|
90
88
|
|
|
91
|
-
self.device = device
|
|
89
|
+
self.device = kwargs.get("device",None)
|
|
92
90
|
if self.device is None: # Auto-detect if not specified
|
|
93
91
|
if torch.cuda.is_available():
|
|
94
92
|
self.device = "cuda"
|
|
@@ -101,7 +99,7 @@ class WhisperSTTBinding(LollmsSTTBinding):
|
|
|
101
99
|
|
|
102
100
|
self.loaded_model_name = None
|
|
103
101
|
self.model = None
|
|
104
|
-
self._load_whisper_model(model_name)
|
|
102
|
+
self._load_whisper_model(kwargs.get("model_name", "base")) # Default to "base" if not specified
|
|
105
103
|
|
|
106
104
|
|
|
107
105
|
def _load_whisper_model(self, model_name_to_load: str):
|
|
@@ -18,20 +18,19 @@ DEFAULT_WHISPERCPP_EXE_NAMES = ["main", "whisper-cli", "whisper"] # Common names
|
|
|
18
18
|
|
|
19
19
|
class WhisperCppSTTBinding(LollmsSTTBinding):
|
|
20
20
|
def __init__(self,
|
|
21
|
-
model_path: Union[str, Path], # Path to the GGUF Whisper model
|
|
22
|
-
whispercpp_exe_path: Optional[Union[str, Path]] = None, # Path to whisper.cpp executable
|
|
23
|
-
ffmpeg_path: Optional[Union[str, Path]] = None, # Path to ffmpeg executable (if not in PATH)
|
|
24
|
-
models_search_path: Optional[Union[str, Path]] = None, # Optional dir to scan for more models
|
|
25
|
-
default_language: str = "auto",
|
|
26
|
-
n_threads: int = 4,
|
|
27
|
-
# Catch LollmsSTTBinding standard args even if not directly used by this local binding
|
|
28
|
-
host_address: Optional[str] = None, # Not used for local binding
|
|
29
|
-
service_key: Optional[str] = None, # Not used for local binding
|
|
30
|
-
verify_ssl_certificate: bool = True, # Not used for local binding
|
|
31
21
|
**kwargs): # Catch-all for future compatibility or specific whisper.cpp params
|
|
32
22
|
|
|
33
|
-
super().__init__(binding_name="whispercpp")
|
|
34
|
-
|
|
23
|
+
super().__init__(binding_name="whispercpp")
|
|
24
|
+
|
|
25
|
+
# --- Extract values from kwargs with defaults ---
|
|
26
|
+
model_path = kwargs.get("model_path")
|
|
27
|
+
whispercpp_exe_path = kwargs.get("whispercpp_exe_path")
|
|
28
|
+
ffmpeg_path = kwargs.get("ffmpeg_path")
|
|
29
|
+
models_search_path = kwargs.get("models_search_path")
|
|
30
|
+
default_language = kwargs.get("default_language", "auto")
|
|
31
|
+
n_threads = kwargs.get("n_threads", 4)
|
|
32
|
+
extra_whisper_args = kwargs.get("extra_whisper_args", []) # e.g. ["--no-timestamps"]
|
|
33
|
+
|
|
35
34
|
# --- Validate FFMPEG ---
|
|
36
35
|
self.ffmpeg_exe = None
|
|
37
36
|
if ffmpeg_path:
|
|
@@ -42,7 +41,7 @@ class WhisperCppSTTBinding(LollmsSTTBinding):
|
|
|
42
41
|
raise FileNotFoundError(f"Provided ffmpeg_path '{ffmpeg_path}' not found or not executable.")
|
|
43
42
|
else:
|
|
44
43
|
self.ffmpeg_exe = shutil.which("ffmpeg")
|
|
45
|
-
|
|
44
|
+
|
|
46
45
|
if not self.ffmpeg_exe:
|
|
47
46
|
ASCIIColors.warning("ffmpeg not found in PATH or explicitly provided. Audio conversion will not be possible for non-WAV files or incompatible WAV files.")
|
|
48
47
|
ASCIIColors.warning("Please install ffmpeg and ensure it's in your system's PATH, or provide ffmpeg_path argument.")
|
|
@@ -63,7 +62,7 @@ class WhisperCppSTTBinding(LollmsSTTBinding):
|
|
|
63
62
|
self.whispercpp_exe = found_path
|
|
64
63
|
ASCIIColors.info(f"Found whisper.cpp executable via PATH: {self.whispercpp_exe}")
|
|
65
64
|
break
|
|
66
|
-
|
|
65
|
+
|
|
67
66
|
if not self.whispercpp_exe:
|
|
68
67
|
raise FileNotFoundError(
|
|
69
68
|
f"Whisper.cpp executable (tried: {', '.join(DEFAULT_WHISPERCPP_EXE_NAMES)}) not found in PATH or explicitly provided. "
|
|
@@ -79,11 +78,11 @@ class WhisperCppSTTBinding(LollmsSTTBinding):
|
|
|
79
78
|
self.model_path = Path(models_search_path, self.model_path).resolve()
|
|
80
79
|
else:
|
|
81
80
|
raise FileNotFoundError(f"Whisper GGUF model file not found at '{self.model_path}'. Also checked in models_search_path if applicable.")
|
|
82
|
-
|
|
81
|
+
|
|
83
82
|
self.models_search_path = Path(models_search_path).resolve() if models_search_path else None
|
|
84
83
|
self.default_language = default_language
|
|
85
84
|
self.n_threads = n_threads
|
|
86
|
-
self.extra_whisper_args =
|
|
85
|
+
self.extra_whisper_args = extra_whisper_args
|
|
87
86
|
|
|
88
87
|
ASCIIColors.green(f"WhisperCppSTTBinding initialized with model: {self.model_path}")
|
|
89
88
|
|
|
@@ -35,22 +35,11 @@ DALLE_MODELS = {
|
|
|
35
35
|
"max_prompt_length": 4000 # Characters
|
|
36
36
|
}
|
|
37
37
|
}
|
|
38
|
-
|
|
39
38
|
class DalleTTIBinding_Impl(LollmsTTIBinding):
|
|
40
39
|
"""
|
|
41
40
|
Concrete implementation of LollmsTTIBinding for OpenAI's DALL-E API.
|
|
42
41
|
"""
|
|
43
|
-
|
|
44
|
-
def __init__(self,
|
|
45
|
-
api_key: Optional[str] = None, # Can be None to check env var
|
|
46
|
-
model_name: str = "dall-e-3", # Default to DALL-E 3
|
|
47
|
-
default_size: Optional[str] = None, # e.g. "1024x1024"
|
|
48
|
-
default_quality: Optional[str] = None, # "standard" or "hd" (DALL-E 3)
|
|
49
|
-
default_style: Optional[str] = None, # "vivid" or "natural" (DALL-E 3)
|
|
50
|
-
host_address: str = DALLE_API_HOST, # OpenAI API host
|
|
51
|
-
verify_ssl_certificate: bool = True,
|
|
52
|
-
**kwargs # To catch any other lollms_client specific params like service_key/client_id
|
|
53
|
-
):
|
|
42
|
+
def __init__(self, **kwargs):
|
|
54
43
|
"""
|
|
55
44
|
Initialize the DALL-E TTI binding.
|
|
56
45
|
|
|
@@ -70,44 +59,56 @@ class DalleTTIBinding_Impl(LollmsTTIBinding):
|
|
|
70
59
|
"""
|
|
71
60
|
super().__init__(binding_name="dalle")
|
|
72
61
|
|
|
73
|
-
|
|
62
|
+
# Extract parameters from kwargs, providing defaults
|
|
63
|
+
self.api_key = kwargs.get("api_key")
|
|
64
|
+
self.model_name = kwargs.get("model_name")
|
|
65
|
+
self.default_size = kwargs.get("default_size")
|
|
66
|
+
self.default_quality = kwargs.get("default_quality")
|
|
67
|
+
self.default_style = kwargs.get("default_style")
|
|
68
|
+
self.host_address = kwargs.get("host_address", DALLE_API_HOST) # Provide default
|
|
69
|
+
self.verify_ssl_certificate = kwargs.get("verify_ssl_certificate", True) # Provide default
|
|
70
|
+
|
|
71
|
+
# Resolve API key from kwargs or environment variable
|
|
72
|
+
resolved_api_key = self.api_key
|
|
74
73
|
if not resolved_api_key:
|
|
75
74
|
ASCIIColors.info(f"API key not provided directly, checking environment variable '{OPENAI_API_KEY_ENV_VAR}'...")
|
|
76
75
|
resolved_api_key = os.environ.get(OPENAI_API_KEY_ENV_VAR)
|
|
77
76
|
|
|
78
77
|
if not resolved_api_key:
|
|
79
78
|
raise ValueError(f"OpenAI API key is required. Provide it directly or set the '{OPENAI_API_KEY_ENV_VAR}' environment variable.")
|
|
80
|
-
|
|
79
|
+
|
|
81
80
|
self.api_key = resolved_api_key
|
|
82
|
-
self.host_address = host_address
|
|
83
|
-
self.verify_ssl_certificate = verify_ssl_certificate
|
|
84
81
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
82
|
+
# Model name validation
|
|
83
|
+
if not self.model_name:
|
|
84
|
+
raise ValueError("Model name is required.")
|
|
85
|
+
if self.model_name not in DALLE_MODELS:
|
|
86
|
+
raise ValueError(f"Unsupported DALL-E model: {self.model_name}. Supported models: {list(DALLE_MODELS.keys())}")
|
|
87
|
+
|
|
89
88
|
model_props = DALLE_MODELS[self.model_name]
|
|
90
89
|
|
|
91
|
-
#
|
|
92
|
-
self.current_size = default_size or model_props["default_size"]
|
|
90
|
+
# Size
|
|
91
|
+
self.current_size = self.default_size or model_props["default_size"]
|
|
93
92
|
if self.current_size not in model_props["sizes"]:
|
|
94
93
|
raise ValueError(f"Unsupported size '{self.current_size}' for model '{self.model_name}'. Supported sizes: {model_props['sizes']}")
|
|
95
94
|
|
|
95
|
+
# Quality
|
|
96
96
|
if model_props["supports_quality"]:
|
|
97
|
-
self.current_quality = default_quality or model_props["default_quality"]
|
|
97
|
+
self.current_quality = self.default_quality or model_props["default_quality"]
|
|
98
98
|
if self.current_quality not in model_props["qualities"]:
|
|
99
99
|
raise ValueError(f"Unsupported quality '{self.current_quality}' for model '{self.model_name}'. Supported qualities: {model_props['qualities']}")
|
|
100
100
|
else:
|
|
101
|
-
self.current_quality = None
|
|
101
|
+
self.current_quality = None # Explicitly None if not supported
|
|
102
102
|
|
|
103
|
+
# Style
|
|
103
104
|
if model_props["supports_style"]:
|
|
104
|
-
self.current_style = default_style or model_props["default_style"]
|
|
105
|
+
self.current_style = self.default_style or model_props["default_style"]
|
|
105
106
|
if self.current_style not in model_props["styles"]:
|
|
106
107
|
raise ValueError(f"Unsupported style '{self.current_style}' for model '{self.model_name}'. Supported styles: {model_props['styles']}")
|
|
107
108
|
else:
|
|
108
|
-
self.current_style = None
|
|
109
|
-
|
|
110
|
-
#
|
|
109
|
+
self.current_style = None # Explicitly None if not supported
|
|
110
|
+
|
|
111
|
+
# Client ID
|
|
111
112
|
self.client_id = kwargs.get("service_key", kwargs.get("client_id", "dalle_client_user"))
|
|
112
113
|
|
|
113
114
|
|
|
@@ -126,25 +126,24 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
|
|
|
126
126
|
"torch_dtype_str": "auto", # "auto", "float16", "bfloat16", "float32"
|
|
127
127
|
"use_safetensors": True,
|
|
128
128
|
"scheduler_name": "default",
|
|
129
|
-
"safety_checker_on": True,
|
|
129
|
+
"safety_checker_on": True, # Note: Diffusers default is ON
|
|
130
130
|
"num_inference_steps": 25,
|
|
131
131
|
"guidance_scale": 7.5,
|
|
132
|
-
"default_width": 768,
|
|
133
|
-
"default_height": 768,
|
|
132
|
+
"default_width": 768, # Default for SD 2.1 base
|
|
133
|
+
"default_height": 768, # Default for SD 2.1 base
|
|
134
134
|
"seed": -1, # -1 for random on each call
|
|
135
135
|
"enable_cpu_offload": False,
|
|
136
136
|
"enable_sequential_cpu_offload": False,
|
|
137
|
-
"enable_xformers": False,
|
|
137
|
+
"enable_xformers": False, # Explicit opt-in for xformers
|
|
138
138
|
"hf_variant": None, # e.g., "fp16"
|
|
139
139
|
"hf_token": None,
|
|
140
140
|
"local_files_only": False,
|
|
141
141
|
}
|
|
142
142
|
|
|
143
|
-
|
|
144
143
|
def __init__(self,
|
|
145
144
|
config: Optional[Dict[str, Any]] = None,
|
|
146
145
|
lollms_paths: Optional[Dict[str, Union[str, Path]]] = None,
|
|
147
|
-
**kwargs
|
|
146
|
+
**kwargs # Catches other potential parameters like 'service_key' or 'client_id'
|
|
148
147
|
):
|
|
149
148
|
"""
|
|
150
149
|
Initialize the Diffusers TTI binding.
|
|
@@ -157,7 +156,7 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
|
|
|
157
156
|
**kwargs: Catches other parameters (e.g. service_key).
|
|
158
157
|
"""
|
|
159
158
|
super().__init__(binding_name="diffusers")
|
|
160
|
-
|
|
159
|
+
|
|
161
160
|
if not DIFFUSERS_AVAILABLE:
|
|
162
161
|
ASCIIColors.error("Diffusers library or its dependencies (torch, Pillow, transformers) are not installed or failed to import.")
|
|
163
162
|
ASCIIColors.info("Attempting to install/verify packages...")
|
|
@@ -171,7 +170,7 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
|
|
|
171
170
|
globals()['AutoPipelineForText2Image'] = _AutoPipelineForText2Image
|
|
172
171
|
globals()['DiffusionPipeline'] = _DiffusionPipeline
|
|
173
172
|
globals()['Image'] = _Image
|
|
174
|
-
|
|
173
|
+
|
|
175
174
|
# Re-populate torch dtype maps if torch was just loaded
|
|
176
175
|
global TORCH_DTYPE_MAP_STR_TO_OBJ, TORCH_DTYPE_MAP_OBJ_TO_STR
|
|
177
176
|
TORCH_DTYPE_MAP_STR_TO_OBJ = {
|
|
@@ -189,27 +188,32 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
|
|
|
189
188
|
f"Error: {e}"
|
|
190
189
|
) from e
|
|
191
190
|
|
|
191
|
+
# Merge configs, lollms_paths, and kwargs
|
|
192
192
|
self.config = {**self.DEFAULT_CONFIG, **(config or {}), **kwargs}
|
|
193
|
-
self.lollms_paths = {k: Path(v) for k, v in lollms_paths.items()} if lollms_paths else {}
|
|
194
|
-
|
|
193
|
+
self.lollms_paths = {k: Path(v) for k, v in (lollms_paths or {}).items()} if lollms_paths else {}
|
|
194
|
+
|
|
195
195
|
self.pipeline: Optional[DiffusionPipeline] = None
|
|
196
|
-
self.current_model_id_or_path = None
|
|
196
|
+
self.current_model_id_or_path = None # To track if model needs reload
|
|
197
197
|
|
|
198
198
|
# Resolve auto settings for device and dtype
|
|
199
199
|
if self.config["device"].lower() == "auto":
|
|
200
|
-
if torch.cuda.is_available():
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
200
|
+
if torch.cuda.is_available():
|
|
201
|
+
self.config["device"] = "cuda"
|
|
202
|
+
elif torch.backends.mps.is_available():
|
|
203
|
+
self.config["device"] = "mps"
|
|
204
|
+
else:
|
|
205
|
+
self.config["device"] = "cpu"
|
|
206
|
+
|
|
204
207
|
if self.config["torch_dtype_str"].lower() == "auto":
|
|
205
|
-
if self.config["device"] == "cpu":
|
|
206
|
-
|
|
208
|
+
if self.config["device"] == "cpu":
|
|
209
|
+
self.config["torch_dtype_str"] = "float32" # CPU usually float32
|
|
210
|
+
else:
|
|
211
|
+
self.config["torch_dtype_str"] = "float16" # Common default for GPU
|
|
207
212
|
|
|
208
213
|
self.torch_dtype = TORCH_DTYPE_MAP_STR_TO_OBJ.get(self.config["torch_dtype_str"].lower(), torch.float32)
|
|
209
|
-
if self.torch_dtype == "auto":
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
214
|
+
if self.torch_dtype == "auto": # Should have been resolved above
|
|
215
|
+
self.torch_dtype = torch.float16 if self.config["device"] != "cpu" else torch.float32
|
|
216
|
+
self.config["torch_dtype_str"] = TORCH_DTYPE_MAP_OBJ_TO_STR.get(self.torch_dtype, "float32")
|
|
213
217
|
|
|
214
218
|
# For potential lollms client specific features
|
|
215
219
|
self.client_id = kwargs.get("service_key", kwargs.get("client_id", "diffusers_client_user"))
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
# lollms_client/tti_bindings/gemini/__init__.py
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Optional, List, Dict, Any, Union
|
|
4
|
+
|
|
5
|
+
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
6
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
7
|
+
import json
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import pipmaster as pm
|
|
11
|
+
# google-cloud-aiplatform is the main dependency for Vertex AI
|
|
12
|
+
pm.ensure_packages(['google-cloud-aiplatform', 'Pillow'])
|
|
13
|
+
import vertexai
|
|
14
|
+
from vertexai.preview.vision_models import ImageGenerationModel
|
|
15
|
+
from google.api_core import exceptions as google_exceptions
|
|
16
|
+
GEMINI_AVAILABLE = True
|
|
17
|
+
except ImportError as e:
|
|
18
|
+
GEMINI_AVAILABLE = False
|
|
19
|
+
_gemini_installation_error = e
|
|
20
|
+
|
|
21
|
+
# Defines the binding name for the manager
|
|
22
|
+
BindingName = "GeminiTTIBinding_Impl"
|
|
23
|
+
|
|
24
|
+
# Known Imagen models on Vertex AI
|
|
25
|
+
IMAGEN_MODELS = ["imagegeneration@006", "imagegeneration@005", "imagegeneration@002"]
|
|
26
|
+
|
|
27
|
+
class GeminiTTIBinding_Impl(LollmsTTIBinding):
|
|
28
|
+
"""
|
|
29
|
+
Concrete implementation of LollmsTTIBinding for Google's Imagen models via Vertex AI.
|
|
30
|
+
"""
|
|
31
|
+
DEFAULT_CONFIG = {
|
|
32
|
+
"project_id": None,
|
|
33
|
+
"location": "us-central1",
|
|
34
|
+
"model_name": IMAGEN_MODELS[0],
|
|
35
|
+
"seed": -1, # -1 for random
|
|
36
|
+
"guidance_scale": 7.5,
|
|
37
|
+
"number_of_images": 1
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
def __init__(self, config: Optional[Dict[str, Any]] = None, **kwargs):
|
|
41
|
+
"""
|
|
42
|
+
Initialize the Gemini (Vertex AI Imagen) TTI binding.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
config (Optional[Dict[str, Any]]): Configuration dictionary. Overrides DEFAULT_CONFIG.
|
|
46
|
+
**kwargs: Catches other potential parameters.
|
|
47
|
+
"""
|
|
48
|
+
super().__init__(binding_name="gemini")
|
|
49
|
+
|
|
50
|
+
if not GEMINI_AVAILABLE:
|
|
51
|
+
raise ImportError(
|
|
52
|
+
"Gemini (Vertex AI) binding dependencies are not met. "
|
|
53
|
+
"Please ensure 'google-cloud-aiplatform' is installed. "
|
|
54
|
+
f"Error: {_gemini_installation_error}"
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
self.config = {**self.DEFAULT_CONFIG, **(config or {}), **kwargs}
|
|
58
|
+
self.model: Optional[ImageGenerationModel] = None
|
|
59
|
+
|
|
60
|
+
self._initialize_client()
|
|
61
|
+
|
|
62
|
+
def _initialize_client(self):
|
|
63
|
+
"""Initializes the Vertex AI client and loads the model."""
|
|
64
|
+
project_id = self.config.get("project_id")
|
|
65
|
+
location = self.config.get("location")
|
|
66
|
+
model_name = self.config.get("model_name")
|
|
67
|
+
|
|
68
|
+
if not project_id:
|
|
69
|
+
raise ValueError("Google Cloud 'project_id' is required for the Gemini (Vertex AI) binding.")
|
|
70
|
+
|
|
71
|
+
ASCIIColors.info("Initializing Vertex AI client...")
|
|
72
|
+
try:
|
|
73
|
+
vertexai.init(project=project_id, location=location)
|
|
74
|
+
self.model = ImageGenerationModel.from_pretrained(model_name)
|
|
75
|
+
ASCIIColors.green(f"Vertex AI initialized successfully. Loaded model: {model_name}")
|
|
76
|
+
except google_exceptions.PermissionDenied as e:
|
|
77
|
+
trace_exception(e)
|
|
78
|
+
raise Exception(
|
|
79
|
+
"Authentication failed. Ensure you have run 'gcloud auth application-default login' "
|
|
80
|
+
"and that the Vertex AI API is enabled for your project."
|
|
81
|
+
) from e
|
|
82
|
+
except Exception as e:
|
|
83
|
+
trace_exception(e)
|
|
84
|
+
raise Exception(f"Failed to initialize Vertex AI client: {e}") from e
|
|
85
|
+
|
|
86
|
+
def _validate_dimensions(self, width: int, height: int) -> None:
|
|
87
|
+
"""Validates image dimensions against Imagen 2 constraints."""
|
|
88
|
+
if not (256 <= width <= 1536 and width % 64 == 0):
|
|
89
|
+
raise ValueError(f"Invalid width: {width}. Must be between 256 and 1536 and a multiple of 64.")
|
|
90
|
+
if not (256 <= height <= 1536 and height % 64 == 0):
|
|
91
|
+
raise ValueError(f"Invalid height: {height}. Must be between 256 and 1536 and a multiple of 64.")
|
|
92
|
+
if width * height > 1536 * 1536: # Max pixels might be more constrained, 1536*1536 is a safe upper bound.
|
|
93
|
+
raise ValueError(f"Invalid dimensions: {width}x{height}. The total number of pixels cannot exceed 1536*1536.")
|
|
94
|
+
|
|
95
|
+
def generate_image(self,
|
|
96
|
+
prompt: str,
|
|
97
|
+
negative_prompt: Optional[str] = "",
|
|
98
|
+
width: int = 1024,
|
|
99
|
+
height: int = 1024,
|
|
100
|
+
**kwargs) -> bytes:
|
|
101
|
+
"""
|
|
102
|
+
Generates image data using the Vertex AI Imagen model.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
prompt (str): The positive text prompt.
|
|
106
|
+
negative_prompt (Optional[str]): The negative prompt.
|
|
107
|
+
width (int): Image width. Must be 256-1536 and a multiple of 64.
|
|
108
|
+
height (int): Image height. Must be 256-1536 and a multiple of 64.
|
|
109
|
+
**kwargs: Additional parameters:
|
|
110
|
+
- seed (int)
|
|
111
|
+
- guidance_scale (float)
|
|
112
|
+
Returns:
|
|
113
|
+
bytes: The generated image data (PNG format).
|
|
114
|
+
Raises:
|
|
115
|
+
Exception: If the request fails or image generation fails.
|
|
116
|
+
"""
|
|
117
|
+
if not self.model:
|
|
118
|
+
raise RuntimeError("Vertex AI model is not loaded. Cannot generate image.")
|
|
119
|
+
|
|
120
|
+
self._validate_dimensions(width, height)
|
|
121
|
+
|
|
122
|
+
seed = kwargs.get("seed", self.config["seed"])
|
|
123
|
+
guidance_scale = kwargs.get("guidance_scale", self.config["guidance_scale"])
|
|
124
|
+
|
|
125
|
+
# Use -1 for random seed, otherwise pass the integer value.
|
|
126
|
+
gen_seed = seed if seed != -1 else None
|
|
127
|
+
|
|
128
|
+
gen_params = {
|
|
129
|
+
"prompt": prompt,
|
|
130
|
+
"negative_prompt": negative_prompt,
|
|
131
|
+
"number_of_images": 1, # This binding returns one image
|
|
132
|
+
"width": width,
|
|
133
|
+
"height": height,
|
|
134
|
+
"guidance_scale": guidance_scale,
|
|
135
|
+
}
|
|
136
|
+
if gen_seed is not None:
|
|
137
|
+
gen_params["seed"] = gen_seed
|
|
138
|
+
|
|
139
|
+
ASCIIColors.info(f"Generating image with prompt: '{prompt[:100]}...'")
|
|
140
|
+
ASCIIColors.debug(f"Imagen generation parameters: {gen_params}")
|
|
141
|
+
|
|
142
|
+
try:
|
|
143
|
+
response = self.model.generate_images(**gen_params)
|
|
144
|
+
|
|
145
|
+
if not response.images:
|
|
146
|
+
raise Exception("Image generation resulted in no images. This may be due to safety filters.")
|
|
147
|
+
|
|
148
|
+
img_bytes = response.images[0]._image_bytes
|
|
149
|
+
return img_bytes
|
|
150
|
+
|
|
151
|
+
except google_exceptions.InvalidArgument as e:
|
|
152
|
+
trace_exception(e)
|
|
153
|
+
raise ValueError(f"Invalid argument sent to Vertex AI API: {e.message}") from e
|
|
154
|
+
except google_exceptions.GoogleAPICallError as e:
|
|
155
|
+
trace_exception(e)
|
|
156
|
+
raise Exception(f"A Google API call error occurred: {e.message}") from e
|
|
157
|
+
except Exception as e:
|
|
158
|
+
trace_exception(e)
|
|
159
|
+
raise Exception(f"Imagen image generation failed: {e}") from e
|
|
160
|
+
|
|
161
|
+
def list_services(self, **kwargs) -> List[Dict[str, str]]:
|
|
162
|
+
"""
|
|
163
|
+
Lists available Imagen models supported by this binding.
|
|
164
|
+
"""
|
|
165
|
+
services = []
|
|
166
|
+
for model_name in IMAGEN_MODELS:
|
|
167
|
+
services.append({
|
|
168
|
+
"name": model_name,
|
|
169
|
+
"caption": f"Google Imagen 2 ({model_name})",
|
|
170
|
+
"help": "High-quality text-to-image model from Google, available on Vertex AI."
|
|
171
|
+
})
|
|
172
|
+
return services
|
|
173
|
+
|
|
174
|
+
def get_settings(self, **kwargs) -> List[Dict[str, Any]]:
|
|
175
|
+
"""
|
|
176
|
+
Retrieves the current configurable settings for the binding.
|
|
177
|
+
"""
|
|
178
|
+
return [
|
|
179
|
+
{"name": "project_id", "type": "str", "value": self.config["project_id"], "description": "Your Google Cloud project ID."},
|
|
180
|
+
{"name": "location", "type": "str", "value": self.config["location"], "description": "Google Cloud region for the project (e.g., 'us-central1')."},
|
|
181
|
+
{"name": "model_name", "type": "str", "value": self.config["model_name"], "description": "The Imagen model version to use.", "options": IMAGEN_MODELS},
|
|
182
|
+
{"name": "seed", "type": "int", "value": self.config["seed"], "description": "Default seed for generation (-1 for random)."},
|
|
183
|
+
{"name": "guidance_scale", "type": "float", "value": self.config["guidance_scale"], "description": "Default guidance scale (CFG). Higher values follow the prompt more strictly."},
|
|
184
|
+
]
|
|
185
|
+
|
|
186
|
+
def set_settings(self, settings: Union[Dict[str, Any], List[Dict[str, Any]]], **kwargs) -> bool:
|
|
187
|
+
"""
|
|
188
|
+
Applies new settings to the binding. Re-initializes the client if needed.
|
|
189
|
+
"""
|
|
190
|
+
if isinstance(settings, list):
|
|
191
|
+
parsed_settings = {item["name"]: item["value"] for item in settings if "name" in item and "value" in item}
|
|
192
|
+
elif isinstance(settings, dict):
|
|
193
|
+
parsed_settings = settings
|
|
194
|
+
else:
|
|
195
|
+
ASCIIColors.error("Invalid settings format. Expected a dictionary or list of dictionaries.")
|
|
196
|
+
return False
|
|
197
|
+
|
|
198
|
+
needs_reinit = False
|
|
199
|
+
for key, value in parsed_settings.items():
|
|
200
|
+
if key in self.config and self.config[key] != value:
|
|
201
|
+
self.config[key] = value
|
|
202
|
+
ASCIIColors.info(f"Setting '{key}' changed to: {value}")
|
|
203
|
+
if key in ["project_id", "location", "model_name"]:
|
|
204
|
+
needs_reinit = True
|
|
205
|
+
|
|
206
|
+
if needs_reinit:
|
|
207
|
+
try:
|
|
208
|
+
self._initialize_client()
|
|
209
|
+
ASCIIColors.green("Vertex AI client re-initialized successfully with new settings.")
|
|
210
|
+
except Exception as e:
|
|
211
|
+
ASCIIColors.error(f"Failed to re-initialize client with new settings: {e}")
|
|
212
|
+
# Optionally, revert to old config here to maintain a working state
|
|
213
|
+
return False
|
|
214
|
+
|
|
215
|
+
return True
|
|
@@ -9,15 +9,12 @@ from ascii_colors import trace_exception, ASCIIColors
|
|
|
9
9
|
import json # Added for potential error parsing
|
|
10
10
|
|
|
11
11
|
# Defines the binding name for the manager
|
|
12
|
-
BindingName = "
|
|
13
|
-
|
|
14
|
-
class LollmsTTIBinding_Impl(LollmsTTIBinding):
|
|
12
|
+
BindingName = "LollmsWebuiTTIBinding_Impl"
|
|
13
|
+
class LollmsWebuiTTIBinding_Impl(LollmsTTIBinding):
|
|
15
14
|
"""Concrete implementation of the LollmsTTIBinding for the standard LOLLMS server."""
|
|
16
15
|
|
|
17
16
|
def __init__(self,
|
|
18
|
-
|
|
19
|
-
service_key: Optional[str] = None,
|
|
20
|
-
verify_ssl_certificate: bool = True):
|
|
17
|
+
**kwargs):
|
|
21
18
|
"""
|
|
22
19
|
Initialize the LOLLMS TTI binding.
|
|
23
20
|
|
|
@@ -27,12 +24,14 @@ class LollmsTTIBinding_Impl(LollmsTTIBinding):
|
|
|
27
24
|
verify_ssl_certificate (bool): Whether to verify SSL certificates.
|
|
28
25
|
"""
|
|
29
26
|
super().__init__(binding_name="lollms")
|
|
30
|
-
|
|
31
|
-
|
|
27
|
+
|
|
28
|
+
# Extract parameters from kwargs, providing defaults
|
|
29
|
+
self.host_address = kwargs.get("host_address", "http://localhost:9600") # Default LOLLMS host
|
|
30
|
+
self.verify_ssl_certificate = kwargs.get("verify_ssl_certificate", True)
|
|
32
31
|
|
|
33
32
|
# The 'service_key' here will act as the 'client_id' for TTI requests if provided.
|
|
34
33
|
# This assumes the client library user provides their LOLLMS client_id here.
|
|
35
|
-
self.client_id = service_key
|
|
34
|
+
self.client_id = kwargs.get("service_key", None) # Use service_key or None
|
|
36
35
|
|
|
37
36
|
def _get_client_id(self, **kwargs) -> str:
|
|
38
37
|
"""Helper to get client_id, prioritizing kwargs then instance default."""
|