lollms-client 0.33.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of lollms-client might be problematic.

Files changed (73)
  1. lollms_client/__init__.py +1 -1
  2. lollms_client/llm_bindings/azure_openai/__init__.py +6 -10
  3. lollms_client/llm_bindings/claude/__init__.py +4 -7
  4. lollms_client/llm_bindings/gemini/__init__.py +3 -7
  5. lollms_client/llm_bindings/grok/__init__.py +3 -7
  6. lollms_client/llm_bindings/groq/__init__.py +4 -6
  7. lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +4 -6
  8. lollms_client/llm_bindings/litellm/__init__.py +15 -6
  9. lollms_client/llm_bindings/llamacpp/__init__.py +27 -9
  10. lollms_client/llm_bindings/lollms/__init__.py +24 -14
  11. lollms_client/llm_bindings/lollms_webui/__init__.py +6 -12
  12. lollms_client/llm_bindings/mistral/__init__.py +3 -5
  13. lollms_client/llm_bindings/ollama/__init__.py +6 -11
  14. lollms_client/llm_bindings/open_router/__init__.py +4 -6
  15. lollms_client/llm_bindings/openai/__init__.py +7 -14
  16. lollms_client/llm_bindings/openllm/__init__.py +12 -12
  17. lollms_client/llm_bindings/pythonllamacpp/__init__.py +1 -1
  18. lollms_client/llm_bindings/tensor_rt/__init__.py +8 -13
  19. lollms_client/llm_bindings/transformers/__init__.py +14 -6
  20. lollms_client/llm_bindings/vllm/__init__.py +16 -12
  21. lollms_client/lollms_core.py +296 -487
  22. lollms_client/lollms_discussion.py +431 -78
  23. lollms_client/lollms_llm_binding.py +191 -380
  24. lollms_client/lollms_mcp_binding.py +33 -2
  25. lollms_client/mcp_bindings/local_mcp/__init__.py +3 -2
  26. lollms_client/mcp_bindings/remote_mcp/__init__.py +6 -5
  27. lollms_client/mcp_bindings/standard_mcp/__init__.py +3 -5
  28. lollms_client/stt_bindings/lollms/__init__.py +6 -8
  29. lollms_client/stt_bindings/whisper/__init__.py +2 -4
  30. lollms_client/stt_bindings/whispercpp/__init__.py +15 -16
  31. lollms_client/tti_bindings/dalle/__init__.py +29 -28
  32. lollms_client/tti_bindings/diffusers/__init__.py +25 -21
  33. lollms_client/tti_bindings/gemini/__init__.py +215 -0
  34. lollms_client/tti_bindings/lollms/__init__.py +8 -9
  35. lollms_client-1.0.0.dist-info/METADATA +1214 -0
  36. lollms_client-1.0.0.dist-info/RECORD +69 -0
  37. {lollms_client-0.33.0.dist-info → lollms_client-1.0.0.dist-info}/top_level.txt +0 -2
  38. examples/article_summary/article_summary.py +0 -58
  39. examples/console_discussion/console_app.py +0 -266
  40. examples/console_discussion.py +0 -448
  41. examples/deep_analyze/deep_analyse.py +0 -30
  42. examples/deep_analyze/deep_analyze_multiple_files.py +0 -32
  43. examples/function_calling_with_local_custom_mcp.py +0 -250
  44. examples/generate_a_benchmark_for_safe_store.py +0 -89
  45. examples/generate_and_speak/generate_and_speak.py +0 -251
  46. examples/generate_game_sfx/generate_game_fx.py +0 -240
  47. examples/generate_text_with_multihop_rag_example.py +0 -210
  48. examples/gradio_chat_app.py +0 -228
  49. examples/gradio_lollms_chat.py +0 -259
  50. examples/internet_search_with_rag.py +0 -226
  51. examples/lollms_chat/calculator.py +0 -59
  52. examples/lollms_chat/derivative.py +0 -48
  53. examples/lollms_chat/test_openai_compatible_with_lollms_chat.py +0 -12
  54. examples/lollms_discussions_test.py +0 -155
  55. examples/mcp_examples/external_mcp.py +0 -267
  56. examples/mcp_examples/local_mcp.py +0 -171
  57. examples/mcp_examples/openai_mcp.py +0 -203
  58. examples/mcp_examples/run_remote_mcp_example_v2.py +0 -290
  59. examples/mcp_examples/run_standard_mcp_example.py +0 -204
  60. examples/simple_text_gen_test.py +0 -173
  61. examples/simple_text_gen_with_image_test.py +0 -178
  62. examples/test_local_models/local_chat.py +0 -9
  63. examples/text_2_audio.py +0 -77
  64. examples/text_2_image.py +0 -144
  65. examples/text_2_image_diffusers.py +0 -274
  66. examples/text_and_image_2_audio.py +0 -59
  67. examples/text_gen.py +0 -30
  68. examples/text_gen_system_prompt.py +0 -29
  69. lollms_client-0.33.0.dist-info/METADATA +0 -854
  70. lollms_client-0.33.0.dist-info/RECORD +0 -101
  71. test/test_lollms_discussion.py +0 -368
  72. {lollms_client-0.33.0.dist-info → lollms_client-1.0.0.dist-info}/WHEEL +0 -0
  73. {lollms_client-0.33.0.dist-info → lollms_client-1.0.0.dist-info}/licenses/LICENSE +0 -0
lollms_client/mcp_bindings/local_mcp/__init__.py

@@ -21,15 +21,16 @@ class LocalMCPBinding(LollmsMCPBinding):
     """
 
     def __init__(self,
-                 tools_folder_path: str|Path|None = None):
+                 **kwargs: Any
+                 ):
         """
         Initialize the LocalMCPBinding.
 
         Args:
-            binding_name (str): The name of this binding.
            tools_folder_path (str|Path) a folder where to find tools
         """
         super().__init__(binding_name="LocalMCP")
+        tools_folder_path = kwargs.get("tools_folder_path")
         if tools_folder_path:
             try:
                 self.tools_folder_path: Optional[Path] = Path(tools_folder_path)
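
Migration note: the constructor above now accepts only keyword arguments, so positional calls from 0.33.0 break. A minimal before/after sketch (the import path is an assumption based on the file list above):

```python
# Hypothetical migration sketch for the kwargs refactor shown above.
from lollms_client.mcp_bindings.local_mcp import LocalMCPBinding

binding = LocalMCPBinding(tools_folder_path="my_tools")  # works in 0.33.0 and 1.0.0
# LocalMCPBinding("my_tools")  # 0.33.0 only; 1.0.0 raises TypeError (no positional args)
```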
lollms_client/mcp_bindings/remote_mcp/__init__.py

@@ -27,8 +27,8 @@ class RemoteMCPBinding(LollmsMCPBinding):
     Tools from all connected servers are aggregated and prefixed with the server's alias.
     """
     def __init__(self,
-                 servers_infos: Dict[str, Dict[str, Any]],
-                 **other_config_params: Any):
+                 **kwargs: Any
+                 ):
         """
         Initializes the binding to connect to multiple MCP servers.
 
@@ -41,10 +41,11 @@ class RemoteMCPBinding(LollmsMCPBinding):
                 "main_server": {"server_url": "http://localhost:8787", "auth_config": {}},
                 "experimental_server": {"server_url": "http://test.server:9000"}
             }
-            **other_config_params (Any): Additional configuration parameters.
+            **kwargs (Any): Additional configuration parameters.
         """
         super().__init__(binding_name="remote_mcp")
         # initialization in case no servers are present
+        servers_infos: Dict[str, Dict[str, Any]] = kwargs.get("servers_infos", {})
         self.servers = None
         if not MCP_LIBRARY_AVAILABLE:
             ASCIIColors.error(f"{self.binding_name}: MCP library not available. This binding will be disabled.")
@@ -56,8 +57,8 @@ class RemoteMCPBinding(LollmsMCPBinding):
 
         ### NEW: Store the overall configuration
         self.config = {
-            "servers_infos": servers_infos,
-            **other_config_params
+            "servers_infos": kwargs.get("servers_infos"),
+            **kwargs
         }
 
         ### NEW: State management for multiple servers.
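
The same kwargs pattern applies here, with one subtlety: `servers_infos` is read via `kwargs.get`, so omitting it no longer raises a TypeError at call time. A sketch of the 1.0.0 calling convention (import path assumed from the file list):

```python
from lollms_client.mcp_bindings.remote_mcp import RemoteMCPBinding

binding = RemoteMCPBinding(
    servers_infos={
        "main_server": {"server_url": "http://localhost:8787", "auth_config": {}},
        "experimental_server": {"server_url": "http://test.server:9000"},
    }
)
# RemoteMCPBinding() now constructs successfully with an empty server map,
# where 0.33.0 raised a TypeError for the missing servers_infos argument.
```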
lollms_client/mcp_bindings/standard_mcp/__init__.py

@@ -48,12 +48,10 @@ class StandardMCPBinding(LollmsMCPBinding):
     """
 
     def __init__(self,
-                 initial_servers: Optional[Dict[str, Dict[str, Any]]] = None,
-                 **other_config_params: Any):
+                 **kwargs: Any):
         super().__init__(binding_name="standard_mcp")
-
-        self.config = {"initial_servers": initial_servers if initial_servers else {}}
-        self.config.update(other_config_params)
+        self.config = kwargs
+        initial_servers = kwargs.get("initial_servers", {})
 
         self._server_configs: Dict[str, Dict[str, Any]] = {}
         # Type hint with ClientSession, actual obj if MCP_LIBRARY_AVAILABLE
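
One observable consequence of `self.config = kwargs`: the config dict no longer always carries an `initial_servers` key. Illustrated with plain dicts rather than library calls:

```python
kwargs = {"timeout": 30}                             # caller omitted initial_servers
config = kwargs                                      # 1.0.0: config is the raw kwargs
initial_servers = kwargs.get("initial_servers", {})  # local default is still {}
assert "initial_servers" not in config
assert initial_servers == {}
# 0.33.0 instead stored {"initial_servers": {}, "timeout": 30}.
```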
lollms_client/stt_bindings/lollms/__init__.py

@@ -14,10 +14,8 @@ class LollmsSTTBinding_Impl(LollmsSTTBinding):
     """Concrete implementation of the LollmsSTTBinding for the standard LOLLMS server."""
 
     def __init__(self,
-                 host_address: Optional[str] = "http://localhost:9600", # Default LOLLMS host
-                 model_name: Optional[str] = None, # Default model (server decides if None)
-                 service_key: Optional[str] = None,
-                 verify_ssl_certificate: bool = True):
+                 **kwargs
+                 ):
         """
         Initialize the LOLLMS STT binding.
 
@@ -28,10 +26,10 @@ class LollmsSTTBinding_Impl(LollmsSTTBinding):
             verify_ssl_certificate (bool): Whether to verify SSL certificates.
         """
         super().__init__("lollms")
-        self.host_address=host_address
-        self.model_name=model_name
-        self.service_key=service_key
-        self.verify_ssl_certificate=verify_ssl_certificate
+        self.host_address=kwargs.get("host_address")
+        self.model_name=kwargs.get("model_name")
+        self.service_key=kwargs.get("service_key")
+        self.verify_ssl_certificate=kwargs.get("verify_ssl_certificate")
 
     def transcribe_audio(self, audio_path: Union[str, Path], model: Optional[str] = None, **kwargs) -> str:
         """
lollms_client/stt_bindings/whisper/__init__.py

@@ -70,8 +70,6 @@ class WhisperSTTBinding(LollmsSTTBinding):
     WHISPER_MODEL_SIZES = ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2", "large-v3"]
 
     def __init__(self,
-                 model_name: str = "base", # Default Whisper model size
-                 device: Optional[str] = None, # "cpu", "cuda", "mps", or None for auto
                  **kwargs # To catch any other LollmsSTTBinding standard args
                  ):
         """
@@ -88,7 +86,7 @@ class WhisperSTTBinding(LollmsSTTBinding):
         if not _whisper_installed:
             raise ImportError(f"Whisper STT binding dependencies not met. Please ensure 'openai-whisper' and 'torch' are installed. Error: {_whisper_installation_error}")
 
-        self.device = device
+        self.device = kwargs.get("device",None)
         if self.device is None: # Auto-detect if not specified
             if torch.cuda.is_available():
                 self.device = "cuda"
@@ -101,7 +99,7 @@ class WhisperSTTBinding(LollmsSTTBinding):
 
         self.loaded_model_name = None
         self.model = None
-        self._load_whisper_model(model_name)
+        self._load_whisper_model(kwargs.get("model_name", "base")) # Default to "base" if not specified
 
 
     def _load_whisper_model(self, model_name_to_load: str):
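
Here the old defaults survive inside `kwargs.get(...)`, so existing keyword calls behave the same. A usage sketch (import path assumed from the file list; constructing the binding downloads/loads the Whisper model, and the audio file name is a placeholder):

```python
from lollms_client.stt_bindings.whisper import WhisperSTTBinding

stt = WhisperSTTBinding()                                        # loads "base", auto-detects device
stt_small = WhisperSTTBinding(model_name="small", device="cuda")
text = stt.transcribe_audio("speech.wav")                        # hypothetical audio file
```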
lollms_client/stt_bindings/whispercpp/__init__.py

@@ -18,20 +18,19 @@ DEFAULT_WHISPERCPP_EXE_NAMES = ["main", "whisper-cli", "whisper"] # Common names
 
 class WhisperCppSTTBinding(LollmsSTTBinding):
     def __init__(self,
-                 model_path: Union[str, Path], # Path to the GGUF Whisper model
-                 whispercpp_exe_path: Optional[Union[str, Path]] = None, # Path to whisper.cpp executable
-                 ffmpeg_path: Optional[Union[str, Path]] = None, # Path to ffmpeg executable (if not in PATH)
-                 models_search_path: Optional[Union[str, Path]] = None, # Optional dir to scan for more models
-                 default_language: str = "auto",
-                 n_threads: int = 4,
-                 # Catch LollmsSTTBinding standard args even if not directly used by this local binding
-                 host_address: Optional[str] = None, # Not used for local binding
-                 service_key: Optional[str] = None, # Not used for local binding
-                 verify_ssl_certificate: bool = True, # Not used for local binding
                  **kwargs): # Catch-all for future compatibility or specific whisper.cpp params
 
-        super().__init__(binding_name="whispercpp")
-
+        super().__init__(binding_name="whispercpp")
+
+        # --- Extract values from kwargs with defaults ---
+        model_path = kwargs.get("model_path")
+        whispercpp_exe_path = kwargs.get("whispercpp_exe_path")
+        ffmpeg_path = kwargs.get("ffmpeg_path")
+        models_search_path = kwargs.get("models_search_path")
+        default_language = kwargs.get("default_language", "auto")
+        n_threads = kwargs.get("n_threads", 4)
+        extra_whisper_args = kwargs.get("extra_whisper_args", []) # e.g. ["--no-timestamps"]
+
        # --- Validate FFMPEG ---
         self.ffmpeg_exe = None
         if ffmpeg_path:
@@ -42,7 +41,7 @@ class WhisperCppSTTBinding(LollmsSTTBinding):
                 raise FileNotFoundError(f"Provided ffmpeg_path '{ffmpeg_path}' not found or not executable.")
         else:
             self.ffmpeg_exe = shutil.which("ffmpeg")
-
+
         if not self.ffmpeg_exe:
             ASCIIColors.warning("ffmpeg not found in PATH or explicitly provided. Audio conversion will not be possible for non-WAV files or incompatible WAV files.")
             ASCIIColors.warning("Please install ffmpeg and ensure it's in your system's PATH, or provide ffmpeg_path argument.")
@@ -63,7 +62,7 @@ class WhisperCppSTTBinding(LollmsSTTBinding):
                     self.whispercpp_exe = found_path
                     ASCIIColors.info(f"Found whisper.cpp executable via PATH: {self.whispercpp_exe}")
                     break
-
+
         if not self.whispercpp_exe:
             raise FileNotFoundError(
                 f"Whisper.cpp executable (tried: {', '.join(DEFAULT_WHISPERCPP_EXE_NAMES)}) not found in PATH or explicitly provided. "
@@ -79,11 +78,11 @@ class WhisperCppSTTBinding(LollmsSTTBinding):
                 self.model_path = Path(models_search_path, self.model_path).resolve()
             else:
                 raise FileNotFoundError(f"Whisper GGUF model file not found at '{self.model_path}'. Also checked in models_search_path if applicable.")
-
+
         self.models_search_path = Path(models_search_path).resolve() if models_search_path else None
         self.default_language = default_language
         self.n_threads = n_threads
-        self.extra_whisper_args = kwargs.get("extra_whisper_args", []) # e.g. ["--no-timestamps"]
+        self.extra_whisper_args = extra_whisper_args
 
         ASCIIColors.green(f"WhisperCppSTTBinding initialized with model: {self.model_path}")
 
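
`model_path` was a required positional parameter in 0.33.0; in 1.0.0 it must be passed by keyword, and omitting it fails later (during path resolution) instead of at call time. A sketch assuming whisper.cpp and ffmpeg are on PATH, with a hypothetical GGUF path:

```python
from lollms_client.stt_bindings.whispercpp import WhisperCppSTTBinding

stt = WhisperCppSTTBinding(
    model_path="models/ggml-base.bin",       # hypothetical local GGUF model
    n_threads=8,
    extra_whisper_args=["--no-timestamps"],
)
```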
lollms_client/tti_bindings/dalle/__init__.py

@@ -35,22 +35,11 @@ DALLE_MODELS = {
         "max_prompt_length": 4000 # Characters
     }
 }
-
 class DalleTTIBinding_Impl(LollmsTTIBinding):
     """
     Concrete implementation of LollmsTTIBinding for OpenAI's DALL-E API.
     """
-
-    def __init__(self,
-                 api_key: Optional[str] = None, # Can be None to check env var
-                 model_name: str = "dall-e-3", # Default to DALL-E 3
-                 default_size: Optional[str] = None, # e.g. "1024x1024"
-                 default_quality: Optional[str] = None, # "standard" or "hd" (DALL-E 3)
-                 default_style: Optional[str] = None, # "vivid" or "natural" (DALL-E 3)
-                 host_address: str = DALLE_API_HOST, # OpenAI API host
-                 verify_ssl_certificate: bool = True,
-                 **kwargs # To catch any other lollms_client specific params like service_key/client_id
-                 ):
+    def __init__(self, **kwargs):
         """
         Initialize the DALL-E TTI binding.
 
@@ -70,44 +59,56 @@ class DalleTTIBinding_Impl(LollmsTTIBinding):
         """
         super().__init__(binding_name="dalle")
 
-        resolved_api_key = api_key
+        # Extract parameters from kwargs, providing defaults
+        self.api_key = kwargs.get("api_key")
+        self.model_name = kwargs.get("model_name")
+        self.default_size = kwargs.get("default_size")
+        self.default_quality = kwargs.get("default_quality")
+        self.default_style = kwargs.get("default_style")
+        self.host_address = kwargs.get("host_address", DALLE_API_HOST) # Provide default
+        self.verify_ssl_certificate = kwargs.get("verify_ssl_certificate", True) # Provide default
+
+        # Resolve API key from kwargs or environment variable
+        resolved_api_key = self.api_key
         if not resolved_api_key:
             ASCIIColors.info(f"API key not provided directly, checking environment variable '{OPENAI_API_KEY_ENV_VAR}'...")
             resolved_api_key = os.environ.get(OPENAI_API_KEY_ENV_VAR)
 
         if not resolved_api_key:
             raise ValueError(f"OpenAI API key is required. Provide it directly or set the '{OPENAI_API_KEY_ENV_VAR}' environment variable.")
-
+
         self.api_key = resolved_api_key
-        self.host_address = host_address
-        self.verify_ssl_certificate = verify_ssl_certificate
 
-        if model_name not in DALLE_MODELS:
-            raise ValueError(f"Unsupported DALL-E model: {model_name}. Supported models: {list(DALLE_MODELS.keys())}")
-        self.model_name = model_name
-
+        # Model name validation
+        if not self.model_name:
+            raise ValueError("Model name is required.")
+        if self.model_name not in DALLE_MODELS:
+            raise ValueError(f"Unsupported DALL-E model: {self.model_name}. Supported models: {list(DALLE_MODELS.keys())}")
+
         model_props = DALLE_MODELS[self.model_name]
 
-        # Set defaults from model_props, overridden by user-provided defaults
-        self.current_size = default_size or model_props["default_size"]
+        # Size
+        self.current_size = self.default_size or model_props["default_size"]
         if self.current_size not in model_props["sizes"]:
             raise ValueError(f"Unsupported size '{self.current_size}' for model '{self.model_name}'. Supported sizes: {model_props['sizes']}")
 
+        # Quality
         if model_props["supports_quality"]:
-            self.current_quality = default_quality or model_props["default_quality"]
+            self.current_quality = self.default_quality or model_props["default_quality"]
             if self.current_quality not in model_props["qualities"]:
                 raise ValueError(f"Unsupported quality '{self.current_quality}' for model '{self.model_name}'. Supported qualities: {model_props['qualities']}")
         else:
-            self.current_quality = None # Explicitly None if not supported
+            self.current_quality = None # Explicitly None if not supported
 
+        # Style
         if model_props["supports_style"]:
-            self.current_style = default_style or model_props["default_style"]
+            self.current_style = self.default_style or model_props["default_style"]
             if self.current_style not in model_props["styles"]:
                 raise ValueError(f"Unsupported style '{self.current_style}' for model '{self.model_name}'. Supported styles: {model_props['styles']}")
         else:
-            self.current_style = None # Explicitly None if not supported
-
-        # For potential lollms client specific features, if `service_key` is passed as `client_id`
+            self.current_style = None # Explicitly None if not supported
+
+        # Client ID
         self.client_id = kwargs.get("service_key", kwargs.get("client_id", "dalle_client_user"))
 
 
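
Breaking change worth flagging: 0.33.0 defaulted `model_name` to "dall-e-3", while 1.0.0 raises a ValueError when it is omitted. A sketch (import path assumed; the key is a placeholder):

```python
from lollms_client.tti_bindings.dalle import DalleTTIBinding_Impl

tti = DalleTTIBinding_Impl(api_key="sk-...", model_name="dall-e-3")  # model_name now required
# DalleTTIBinding_Impl(api_key="sk-...")  # fine in 0.33.0, now raises
# ValueError: Model name is required.
```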
lollms_client/tti_bindings/diffusers/__init__.py

@@ -126,25 +126,24 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
         "torch_dtype_str": "auto", # "auto", "float16", "bfloat16", "float32"
         "use_safetensors": True,
         "scheduler_name": "default",
-        "safety_checker_on": True, # Note: Diffusers default is ON
+        "safety_checker_on": True, # Note: Diffusers default is ON
         "num_inference_steps": 25,
         "guidance_scale": 7.5,
-        "default_width": 768, # Default for SD 2.1 base
-        "default_height": 768, # Default for SD 2.1 base
+        "default_width": 768, # Default for SD 2.1 base
+        "default_height": 768, # Default for SD 2.1 base
         "seed": -1, # -1 for random on each call
         "enable_cpu_offload": False,
         "enable_sequential_cpu_offload": False,
-        "enable_xformers": False, # Explicit opt-in for xformers
+        "enable_xformers": False, # Explicit opt-in for xformers
         "hf_variant": None, # e.g., "fp16"
         "hf_token": None,
         "local_files_only": False,
     }
 
-
     def __init__(self,
                  config: Optional[Dict[str, Any]] = None,
                  lollms_paths: Optional[Dict[str, Union[str, Path]]] = None,
-                 **kwargs # Catches other potential parameters like 'service_key' or 'client_id'
+                 **kwargs # Catches other potential parameters like 'service_key' or 'client_id'
                  ):
         """
         Initialize the Diffusers TTI binding.
@@ -157,7 +156,7 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
             **kwargs: Catches other parameters (e.g. service_key).
         """
         super().__init__(binding_name="diffusers")
-
+
         if not DIFFUSERS_AVAILABLE:
             ASCIIColors.error("Diffusers library or its dependencies (torch, Pillow, transformers) are not installed or failed to import.")
             ASCIIColors.info("Attempting to install/verify packages...")
@@ -171,7 +170,7 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
             globals()['AutoPipelineForText2Image'] = _AutoPipelineForText2Image
             globals()['DiffusionPipeline'] = _DiffusionPipeline
             globals()['Image'] = _Image
-
+
             # Re-populate torch dtype maps if torch was just loaded
             global TORCH_DTYPE_MAP_STR_TO_OBJ, TORCH_DTYPE_MAP_OBJ_TO_STR
             TORCH_DTYPE_MAP_STR_TO_OBJ = {
@@ -189,27 +188,32 @@ class DiffusersTTIBinding_Impl(LollmsTTIBinding):
                 f"Error: {e}"
             ) from e
 
+        # Merge configs, lollms_paths, and kwargs
         self.config = {**self.DEFAULT_CONFIG, **(config or {}), **kwargs}
-        self.lollms_paths = {k: Path(v) for k, v in lollms_paths.items()} if lollms_paths else {}
-
+        self.lollms_paths = {k: Path(v) for k, v in (lollms_paths or {}).items()} if lollms_paths else {}
+
         self.pipeline: Optional[DiffusionPipeline] = None
-        self.current_model_id_or_path = None # To track if model needs reload
+        self.current_model_id_or_path = None # To track if model needs reload
 
         # Resolve auto settings for device and dtype
         if self.config["device"].lower() == "auto":
-            if torch.cuda.is_available(): self.config["device"] = "cuda"
-            elif torch.backends.mps.is_available(): self.config["device"] = "mps"
-            else: self.config["device"] = "cpu"
-
+            if torch.cuda.is_available():
+                self.config["device"] = "cuda"
+            elif torch.backends.mps.is_available():
+                self.config["device"] = "mps"
+            else:
+                self.config["device"] = "cpu"
+
         if self.config["torch_dtype_str"].lower() == "auto":
-            if self.config["device"] == "cpu": self.config["torch_dtype_str"] = "float32" # CPU usually float32
-            else: self.config["torch_dtype_str"] = "float16" # Common default for GPU
+            if self.config["device"] == "cpu":
+                self.config["torch_dtype_str"] = "float32" # CPU usually float32
+            else:
+                self.config["torch_dtype_str"] = "float16" # Common default for GPU
 
         self.torch_dtype = TORCH_DTYPE_MAP_STR_TO_OBJ.get(self.config["torch_dtype_str"].lower(), torch.float32)
-        if self.torch_dtype == "auto": # Should have been resolved above
-            self.torch_dtype = torch.float16 if self.config["device"] != "cpu" else torch.float32
-            self.config["torch_dtype_str"] = TORCH_DTYPE_MAP_OBJ_TO_STR.get(self.torch_dtype, "float32")
-
+        if self.torch_dtype == "auto": # Should have been resolved above
+            self.torch_dtype = torch.float16 if self.config["device"] != "cpu" else torch.float32
+            self.config["torch_dtype_str"] = TORCH_DTYPE_MAP_OBJ_TO_STR.get(self.torch_dtype, "float32")
 
         # For potential lollms client specific features
         self.client_id = kwargs.get("service_key", kwargs.get("client_id", "diffusers_client_user"))
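
The merge order in `self.config = {**self.DEFAULT_CONFIG, **(config or {}), **kwargs}` means kwargs beat the config dict, which beats the class defaults. A plain-dict illustration of the precedence:

```python
DEFAULT_CONFIG = {"device": "auto", "seed": -1}
config = {"device": "cuda"}          # explicit config dict
kwargs = {"seed": 42}                # extra keyword arguments
merged = {**DEFAULT_CONFIG, **(config or {}), **kwargs}
assert merged == {"device": "cuda", "seed": 42}
```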
lollms_client/tti_bindings/gemini/__init__.py (new file)

@@ -0,0 +1,215 @@
+# lollms_client/tti_bindings/gemini/__init__.py
+import sys
+from typing import Optional, List, Dict, Any, Union
+
+from lollms_client.lollms_tti_binding import LollmsTTIBinding
+from ascii_colors import trace_exception, ASCIIColors
+import json
+
+try:
+    import pipmaster as pm
+    # google-cloud-aiplatform is the main dependency for Vertex AI
+    pm.ensure_packages(['google-cloud-aiplatform', 'Pillow'])
+    import vertexai
+    from vertexai.preview.vision_models import ImageGenerationModel
+    from google.api_core import exceptions as google_exceptions
+    GEMINI_AVAILABLE = True
+except ImportError as e:
+    GEMINI_AVAILABLE = False
+    _gemini_installation_error = e
+
+# Defines the binding name for the manager
+BindingName = "GeminiTTIBinding_Impl"
+
+# Known Imagen models on Vertex AI
+IMAGEN_MODELS = ["imagegeneration@006", "imagegeneration@005", "imagegeneration@002"]
+
+class GeminiTTIBinding_Impl(LollmsTTIBinding):
+    """
+    Concrete implementation of LollmsTTIBinding for Google's Imagen models via Vertex AI.
+    """
+    DEFAULT_CONFIG = {
+        "project_id": None,
+        "location": "us-central1",
+        "model_name": IMAGEN_MODELS[0],
+        "seed": -1, # -1 for random
+        "guidance_scale": 7.5,
+        "number_of_images": 1
+    }
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None, **kwargs):
+        """
+        Initialize the Gemini (Vertex AI Imagen) TTI binding.
+
+        Args:
+            config (Optional[Dict[str, Any]]): Configuration dictionary. Overrides DEFAULT_CONFIG.
+            **kwargs: Catches other potential parameters.
+        """
+        super().__init__(binding_name="gemini")
+
+        if not GEMINI_AVAILABLE:
+            raise ImportError(
+                "Gemini (Vertex AI) binding dependencies are not met. "
+                "Please ensure 'google-cloud-aiplatform' is installed. "
+                f"Error: {_gemini_installation_error}"
+            )
+
+        self.config = {**self.DEFAULT_CONFIG, **(config or {}), **kwargs}
+        self.model: Optional[ImageGenerationModel] = None
+
+        self._initialize_client()
+
+    def _initialize_client(self):
+        """Initializes the Vertex AI client and loads the model."""
+        project_id = self.config.get("project_id")
+        location = self.config.get("location")
+        model_name = self.config.get("model_name")
+
+        if not project_id:
+            raise ValueError("Google Cloud 'project_id' is required for the Gemini (Vertex AI) binding.")
+
+        ASCIIColors.info("Initializing Vertex AI client...")
+        try:
+            vertexai.init(project=project_id, location=location)
+            self.model = ImageGenerationModel.from_pretrained(model_name)
+            ASCIIColors.green(f"Vertex AI initialized successfully. Loaded model: {model_name}")
+        except google_exceptions.PermissionDenied as e:
+            trace_exception(e)
+            raise Exception(
+                "Authentication failed. Ensure you have run 'gcloud auth application-default login' "
+                "and that the Vertex AI API is enabled for your project."
+            ) from e
+        except Exception as e:
+            trace_exception(e)
+            raise Exception(f"Failed to initialize Vertex AI client: {e}") from e
+
+    def _validate_dimensions(self, width: int, height: int) -> None:
+        """Validates image dimensions against Imagen 2 constraints."""
+        if not (256 <= width <= 1536 and width % 64 == 0):
+            raise ValueError(f"Invalid width: {width}. Must be between 256 and 1536 and a multiple of 64.")
+        if not (256 <= height <= 1536 and height % 64 == 0):
+            raise ValueError(f"Invalid height: {height}. Must be between 256 and 1536 and a multiple of 64.")
+        if width * height > 1536 * 1536: # Max pixels might be more constrained, 1536*1536 is a safe upper bound.
+            raise ValueError(f"Invalid dimensions: {width}x{height}. The total number of pixels cannot exceed 1536*1536.")
+
+    def generate_image(self,
+                       prompt: str,
+                       negative_prompt: Optional[str] = "",
+                       width: int = 1024,
+                       height: int = 1024,
+                       **kwargs) -> bytes:
+        """
+        Generates image data using the Vertex AI Imagen model.
+
+        Args:
+            prompt (str): The positive text prompt.
+            negative_prompt (Optional[str]): The negative prompt.
+            width (int): Image width. Must be 256-1536 and a multiple of 64.
+            height (int): Image height. Must be 256-1536 and a multiple of 64.
+            **kwargs: Additional parameters:
+                      - seed (int)
+                      - guidance_scale (float)
+        Returns:
+            bytes: The generated image data (PNG format).
+        Raises:
+            Exception: If the request fails or image generation fails.
+        """
+        if not self.model:
+            raise RuntimeError("Vertex AI model is not loaded. Cannot generate image.")
+
+        self._validate_dimensions(width, height)
+
+        seed = kwargs.get("seed", self.config["seed"])
+        guidance_scale = kwargs.get("guidance_scale", self.config["guidance_scale"])
+
+        # Use -1 for random seed, otherwise pass the integer value.
+        gen_seed = seed if seed != -1 else None
+
+        gen_params = {
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "number_of_images": 1, # This binding returns one image
+            "width": width,
+            "height": height,
+            "guidance_scale": guidance_scale,
+        }
+        if gen_seed is not None:
+            gen_params["seed"] = gen_seed
+
+        ASCIIColors.info(f"Generating image with prompt: '{prompt[:100]}...'")
+        ASCIIColors.debug(f"Imagen generation parameters: {gen_params}")
+
+        try:
+            response = self.model.generate_images(**gen_params)
+
+            if not response.images:
+                raise Exception("Image generation resulted in no images. This may be due to safety filters.")
+
+            img_bytes = response.images[0]._image_bytes
+            return img_bytes
+
+        except google_exceptions.InvalidArgument as e:
+            trace_exception(e)
+            raise ValueError(f"Invalid argument sent to Vertex AI API: {e.message}") from e
+        except google_exceptions.GoogleAPICallError as e:
+            trace_exception(e)
+            raise Exception(f"A Google API call error occurred: {e.message}") from e
+        except Exception as e:
+            trace_exception(e)
+            raise Exception(f"Imagen image generation failed: {e}") from e
+
+    def list_services(self, **kwargs) -> List[Dict[str, str]]:
+        """
+        Lists available Imagen models supported by this binding.
+        """
+        services = []
+        for model_name in IMAGEN_MODELS:
+            services.append({
+                "name": model_name,
+                "caption": f"Google Imagen 2 ({model_name})",
+                "help": "High-quality text-to-image model from Google, available on Vertex AI."
+            })
+        return services
+
+    def get_settings(self, **kwargs) -> List[Dict[str, Any]]:
+        """
+        Retrieves the current configurable settings for the binding.
+        """
+        return [
+            {"name": "project_id", "type": "str", "value": self.config["project_id"], "description": "Your Google Cloud project ID."},
+            {"name": "location", "type": "str", "value": self.config["location"], "description": "Google Cloud region for the project (e.g., 'us-central1')."},
+            {"name": "model_name", "type": "str", "value": self.config["model_name"], "description": "The Imagen model version to use.", "options": IMAGEN_MODELS},
+            {"name": "seed", "type": "int", "value": self.config["seed"], "description": "Default seed for generation (-1 for random)."},
+            {"name": "guidance_scale", "type": "float", "value": self.config["guidance_scale"], "description": "Default guidance scale (CFG). Higher values follow the prompt more strictly."},
+        ]
+
+    def set_settings(self, settings: Union[Dict[str, Any], List[Dict[str, Any]]], **kwargs) -> bool:
+        """
+        Applies new settings to the binding. Re-initializes the client if needed.
+        """
+        if isinstance(settings, list):
+            parsed_settings = {item["name"]: item["value"] for item in settings if "name" in item and "value" in item}
+        elif isinstance(settings, dict):
+            parsed_settings = settings
+        else:
+            ASCIIColors.error("Invalid settings format. Expected a dictionary or list of dictionaries.")
+            return False
+
+        needs_reinit = False
+        for key, value in parsed_settings.items():
+            if key in self.config and self.config[key] != value:
+                self.config[key] = value
+                ASCIIColors.info(f"Setting '{key}' changed to: {value}")
+                if key in ["project_id", "location", "model_name"]:
+                    needs_reinit = True
+
+        if needs_reinit:
+            try:
+                self._initialize_client()
+                ASCIIColors.green("Vertex AI client re-initialized successfully with new settings.")
+            except Exception as e:
+                ASCIIColors.error(f"Failed to re-initialize client with new settings: {e}")
+                # Optionally, revert to old config here to maintain a working state
+                return False
+
+        return True
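
A usage sketch for the new binding, based on the code above. The import path is assumed from the file list, the project ID is a placeholder, and Vertex AI application-default credentials must already be configured:

```python
from lollms_client.tti_bindings.gemini import GeminiTTIBinding_Impl

tti = GeminiTTIBinding_Impl(config={"project_id": "my-gcp-project"})  # hypothetical project
png_bytes = tti.generate_image(
    prompt="a watercolor lighthouse at dawn",
    width=1024,
    height=768,   # both dimensions are multiples of 64 within 256-1536
    seed=42,
)
with open("lighthouse.png", "wb") as f:
    f.write(png_bytes)
```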
lollms_client/tti_bindings/lollms/__init__.py

@@ -9,15 +9,12 @@ from ascii_colors import trace_exception, ASCIIColors
 import json # Added for potential error parsing
 
 # Defines the binding name for the manager
-BindingName = "LollmsTTIBinding_Impl"
-
-class LollmsTTIBinding_Impl(LollmsTTIBinding):
+BindingName = "LollmsWebuiTTIBinding_Impl"
+class LollmsWebuiTTIBinding_Impl(LollmsTTIBinding):
     """Concrete implementation of the LollmsTTIBinding for the standard LOLLMS server."""
 
     def __init__(self,
-                 host_address: Optional[str] = "http://localhost:9600", # Default LOLLMS host
-                 service_key: Optional[str] = None,
-                 verify_ssl_certificate: bool = True):
+                 **kwargs):
         """
         Initialize the LOLLMS TTI binding.
 
@@ -27,12 +24,14 @@ class LollmsTTIBinding_Impl(LollmsTTIBinding):
             verify_ssl_certificate (bool): Whether to verify SSL certificates.
         """
         super().__init__(binding_name="lollms")
-        self.host_address=host_address
-        self.verify_ssl_certificate = verify_ssl_certificate
+
+        # Extract parameters from kwargs, providing defaults
+        self.host_address = kwargs.get("host_address", "http://localhost:9600") # Default LOLLMS host
+        self.verify_ssl_certificate = kwargs.get("verify_ssl_certificate", True)
 
         # The 'service_key' here will act as the 'client_id' for TTI requests if provided.
         # This assumes the client library user provides their LOLLMS client_id here.
-        self.client_id = service_key
+        self.client_id = kwargs.get("service_key", None) # Use service_key or None
 
     def _get_client_id(self, **kwargs) -> str:
         """Helper to get client_id, prioritizing kwargs then instance default."""