ollamadiffuser-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
New file: package entry point (defines the main() mode dispatcher)
@@ -0,0 +1,50 @@
+ #!/usr/bin/env python3
+ import sys
+ import argparse
+ from .cli.main import cli
+ from .api.server import run_server
+ from .ui.web import create_ui_app
+ from .core.config.settings import settings
+
+ def main():
+     """Main entry function"""
+     # Check if the first argument is the mode flag (the '--mode=...' form is not detected and falls through to the CLI)
+     if len(sys.argv) > 1 and sys.argv[1] == '--mode':
+         # Use argparse for mode selection
+         parser = argparse.ArgumentParser(
+             description='OllamaDiffuser - Image generation model management tool'
+         )
+         parser.add_argument(
+             '--mode',
+             choices=['cli', 'api', 'ui'],
+             required=True,
+             help='Running mode: cli (command line), api (API server), ui (Web interface)'
+         )
+         parser.add_argument('--host', default=None, help='Server host address')
+         parser.add_argument('--port', type=int, default=None, help='Server port')
+         parser.add_argument('--verbose', '-v', action='store_true', help='Verbose output')
+
+         args, unknown = parser.parse_known_args()
+
+         if args.mode == 'cli':
+             # Command line mode: forward the remaining arguments to the CLI
+             sys.argv = [sys.argv[0]] + unknown
+             cli()
+         elif args.mode == 'api':
+             # API server mode
+             print("Starting OllamaDiffuser API server...")
+             run_server(host=args.host, port=args.port)
+         elif args.mode == 'ui':
+             # Web UI mode
+             print("Starting OllamaDiffuser Web UI...")
+             import uvicorn
+             app = create_ui_app()
+             host = args.host or settings.server.host
+             port = args.port or (settings.server.port + 1)  # Web UI uses a different port
+             uvicorn.run(app, host=host, port=port)
+     else:
+         # Default to CLI mode for direct command usage
+         cli()
+
+ if __name__ == '__main__':
+     main()
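For orientation, here is a minimal sketch of driving this dispatcher programmatically. The module path and the port are assumptions for illustration; the diff does not show file names or configured defaults.

import sys

# Hypothetical import path: assumes this entry file is the package's __main__ module.
from ollamadiffuser.__main__ import main

sys.argv = ["ollamadiffuser", "--mode", "api", "--port", "8000"]  # port 8000 is an arbitrary example
main()  # parse_known_args() picks up --mode/--port, then calls run_server(host=None, port=8000)

Equivalently from a shell, something like `python -m ollamadiffuser --mode ui` would start the web interface, again assuming the package installs this file as its __main__ module.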
New file: API server module (defines run_server, imported above as .api.server)
@@ -0,0 +1,297 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import Response
+ from fastapi.middleware.cors import CORSMiddleware
+ import uvicorn
+ import io
+ import logging
+ from typing import Dict, Any, Optional
+ from pydantic import BaseModel
+
+ from ..core.models.manager import model_manager
+ from ..core.config.settings import settings
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # API request models
+ class GenerateRequest(BaseModel):
+     prompt: str
+     negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution"
+     num_inference_steps: Optional[int] = None
+     guidance_scale: Optional[float] = None
+     width: int = 1024
+     height: int = 1024
+
+ class LoadModelRequest(BaseModel):
+     model_name: str
+
+ class LoadLoRARequest(BaseModel):
+     lora_name: str
+     repo_id: str
+     weight_name: Optional[str] = None
+     scale: float = 1.0
+
+ def create_app() -> FastAPI:
+     """Create FastAPI application"""
+     app = FastAPI(
+         title="OllamaDiffuser API",
+         description="Image generation model management and inference API",
+         version="1.0.0"
+     )
+
+     # Add CORS middleware
+     if settings.server.enable_cors:
+         app.add_middleware(
+             CORSMiddleware,
+             allow_origins=["*"],
+             allow_credentials=True,
+             allow_methods=["*"],
+             allow_headers=["*"],
+         )
+
+     # Root endpoint
+     @app.get("/")
+     async def root():
+         """Root endpoint with API information"""
+         return {
+             "name": "OllamaDiffuser API",
+             "version": "1.0.0",
+             "description": "Image generation model management and inference API",
+             "status": "running",
+             "endpoints": {
+                 "documentation": "/docs",
+                 "openapi_schema": "/openapi.json",
+                 "health_check": "/api/health",
+                 "models": "/api/models",
+                 "generate": "/api/generate"
+             },
+             "usage": {
+                 "web_ui": "Use 'ollamadiffuser --mode ui' to start the web interface",
+                 "cli": "Use 'ollamadiffuser --help' for command line options",
+                 "api_docs": "Visit /docs for interactive API documentation"
+             }
+         }
+
+     # Model management endpoints
+     @app.get("/api/models")
+     async def list_models():
+         """List all models"""
+         return {
+             "available": model_manager.list_available_models(),
+             "installed": model_manager.list_installed_models(),
+             "current": model_manager.get_current_model()
+         }
+
+     @app.get("/api/models/running")
+     async def get_running_model():
+         """Get currently running model"""
+         if model_manager.is_model_loaded():
+             engine = model_manager.loaded_model
+             return {
+                 "model": model_manager.get_current_model(),
+                 "info": engine.get_model_info(),
+                 "loaded": True
+             }
+         else:
+             return {"loaded": False}
+
+     @app.get("/api/models/{model_name}")
+     async def get_model_info(model_name: str):
+         """Get detailed model information"""
+         info = model_manager.get_model_info(model_name)
+         if info is None:
+             raise HTTPException(status_code=404, detail="Model does not exist")
+         return info
+
+     @app.post("/api/models/pull")
+     async def pull_model(model_name: str):
+         """Download model (model_name is expected as a query parameter)"""
+         if model_manager.pull_model(model_name):
+             return {"message": f"Model {model_name} downloaded successfully"}
+         else:
+             raise HTTPException(status_code=400, detail=f"Failed to download model {model_name}")
+
+     @app.post("/api/models/load")
+     async def load_model(request: LoadModelRequest):
+         """Load model"""
+         if model_manager.load_model(request.model_name):
+             return {"message": f"Model {request.model_name} loaded successfully"}
+         else:
+             raise HTTPException(status_code=400, detail=f"Failed to load model {request.model_name}")
+
+     @app.post("/api/models/unload")
+     async def unload_model():
+         """Unload current model"""
+         model_manager.unload_model()
+         return {"message": "Model unloaded"}
+
+     @app.delete("/api/models/{model_name}")
+     async def remove_model(model_name: str):
+         """Remove model"""
+         if model_manager.remove_model(model_name):
+             return {"message": f"Model {model_name} removed successfully"}
+         else:
+             raise HTTPException(status_code=400, detail=f"Failed to remove model {model_name}")
+
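A sketch of how a client might drive the model-management endpoints above; the base URL and model name are placeholders, not documented defaults. One detail worth noting: pull_model declares a bare str parameter, so FastAPI expects model_name as a query parameter, whereas load_model reads it from the JSON body.

import requests

BASE = "http://127.0.0.1:8000"  # assumed server address

# Query parameter for /api/models/pull, JSON body for /api/models/load
requests.post(f"{BASE}/api/models/pull", params={"model_name": "some-model"}).raise_for_status()
requests.post(f"{BASE}/api/models/load", json={"model_name": "some-model"}).raise_for_status()

print(requests.get(f"{BASE}/api/models/running").json())  # e.g. {"model": "some-model", "info": {...}, "loaded": true}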
+     # LoRA management endpoints
+     @app.post("/api/lora/load")
+     async def load_lora(request: LoadLoRARequest):
+         """Load LoRA weights into current model"""
+         # Check if model is loaded
+         if not model_manager.is_model_loaded():
+             raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
+
+         try:
+             # Get current loaded inference engine
+             engine = model_manager.loaded_model
+
+             # Load LoRA weights
+             success = engine.load_lora_runtime(
+                 repo_id=request.repo_id,
+                 weight_name=request.weight_name,
+                 scale=request.scale
+             )
+         except Exception as e:
+             logger.error(f"LoRA loading failed: {e}")
+             raise HTTPException(status_code=500, detail=f"LoRA loading failed: {str(e)}")
+
+         # Raised outside the try block so a failure 400 is not re-wrapped as a 500
+         if success:
+             return {"message": f"LoRA {request.lora_name} loaded successfully with scale {request.scale}"}
+         else:
+             raise HTTPException(status_code=400, detail=f"Failed to load LoRA {request.lora_name}")
+
+     @app.post("/api/lora/unload")
+     async def unload_lora():
+         """Unload current LoRA weights"""
+         # Check if model is loaded
+         if not model_manager.is_model_loaded():
+             raise HTTPException(status_code=400, detail="No model loaded")
+
+         try:
+             # Get current loaded inference engine
+             engine = model_manager.loaded_model
+
+             # Unload LoRA weights
+             success = engine.unload_lora()
+         except Exception as e:
+             logger.error(f"LoRA unloading failed: {e}")
+             raise HTTPException(status_code=500, detail=f"LoRA unloading failed: {str(e)}")
+
+         if success:
+             return {"message": "LoRA weights unloaded successfully"}
+         else:
+             raise HTTPException(status_code=400, detail="Failed to unload LoRA weights")
+
+     @app.get("/api/lora/status")
+     async def get_lora_status():
+         """Get current LoRA status"""
+         # Check if model is loaded
+         if not model_manager.is_model_loaded():
+             return {"loaded": False, "message": "No model loaded"}
+
+         try:
+             # Get current loaded inference engine
+             engine = model_manager.loaded_model
+
+             # Check tracked LoRA state
+             if hasattr(engine, 'current_lora') and engine.current_lora:
+                 lora_info = engine.current_lora.copy()
+                 return {
+                     "loaded": True,
+                     "info": lora_info,
+                     "message": "LoRA loaded"
+                 }
+             else:
+                 return {
+                     "loaded": False,
+                     "info": None,
+                     "message": "No LoRA loaded"
+                 }
+
+         except Exception as e:
+             logger.error(f"Failed to get LoRA status: {e}")
+             raise HTTPException(status_code=500, detail=f"Failed to get LoRA status: {str(e)}")
+
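The LoRA endpoints follow the same request/response pattern; a hedged client sketch, where the repo id and scale are placeholders and weight_name is omitted so the engine picks a weight file:

import requests

BASE = "http://127.0.0.1:8000"  # assumed server address

requests.post(f"{BASE}/api/lora/load", json={
    "lora_name": "my-style",           # display name echoed back in status messages
    "repo_id": "some-user/some-lora",  # placeholder Hugging Face repo id
    "scale": 0.8,
}).raise_for_status()

print(requests.get(f"{BASE}/api/lora/status").json())  # {"loaded": true, "info": {...}, ...}
requests.post(f"{BASE}/api/lora/unload")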
+     # Image generation endpoints
+     @app.post("/api/generate")
+     async def generate_image(request: GenerateRequest):
+         """Generate image"""
+         # Check if model is loaded
+         if not model_manager.is_model_loaded():
+             raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
+
+         try:
+             # Get current loaded inference engine
+             engine = model_manager.loaded_model
+
+             # Generate image
+             image = engine.generate_image(
+                 prompt=request.prompt,
+                 negative_prompt=request.negative_prompt,
+                 num_inference_steps=request.num_inference_steps,
+                 guidance_scale=request.guidance_scale,
+                 width=request.width,
+                 height=request.height
+             )
+
+             # Convert PIL image to bytes
+             img_byte_arr = io.BytesIO()
+             image.save(img_byte_arr, format='PNG')
+             img_byte_arr = img_byte_arr.getvalue()
+
+             return Response(content=img_byte_arr, media_type="image/png")
+
+         except Exception as e:
+             logger.error(f"Image generation failed: {e}")
+             raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
+
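Because /api/generate returns raw PNG bytes rather than JSON, a client writes the response body straight to disk. Only prompt is required; everything else falls back to the GenerateRequest defaults. The address is again an assumption:

import requests

resp = requests.post(
    "http://127.0.0.1:8000/api/generate",  # assumed server address
    json={"prompt": "a lighthouse at dawn", "width": 768, "height": 768},
)
resp.raise_for_status()

with open("output.png", "wb") as f:
    f.write(resp.content)  # the body is the PNG produced by image.save(..., format='PNG')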
+     # Health check endpoints
+     @app.get("/api/health")
+     async def health_check():
+         """Health check"""
+         return {
+             "status": "healthy",
+             "model_loaded": model_manager.is_model_loaded(),
+             "current_model": model_manager.get_current_model()
+         }
+
+     # Server management endpoints
+     @app.post("/api/shutdown")
+     async def shutdown_server():
+         """Gracefully shut down the server"""
+         import os
+         import signal
+         import asyncio
+
+         # Unload model first
+         model_manager.unload_model()
+
+         # Schedule server shutdown
+         def shutdown():
+             os.kill(os.getpid(), signal.SIGTERM)
+
+         # Delay shutdown so the HTTP response can be sent first
+         asyncio.get_running_loop().call_later(0.5, shutdown)
+
+         return {"message": "Server shutting down..."}
+
+     return app
+
+ def run_server(host: Optional[str] = None, port: Optional[int] = None):
+     """Start server"""
+     # Use default values from configuration
+     host = host or settings.server.host
+     port = port or settings.server.port
+
+     # Create FastAPI application
+     app = create_app()
+
+     # Run server with uvicorn
+     logger.info(f"Starting server: http://{host}:{port}")
+     uvicorn.run(app, host=host, port=port, log_level="info")
+
+ if __name__ == "__main__":
+     run_server()
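Finally, a sketch of an in-process smoke test against this server. The address is an assumption, and the thread-based start relies on uvicorn skipping signal-handler setup when it is not on the main thread (current releases do). POSTing /api/shutdown is deliberately omitted: it SIGTERMs the whole process, which here would include the test script itself.

import threading
import time

import requests

# Module path inferred from `from .api.server import run_server` in the entry point above
from ollamadiffuser.api.server import run_server

threading.Thread(
    target=run_server,
    kwargs={"host": "127.0.0.1", "port": 8000},
    daemon=True,
).start()

# Poll /api/health until the server responds (give up after ~10 s)
for _ in range(50):
    try:
        print(requests.get("http://127.0.0.1:8000/api/health", timeout=1).json())
        break
    except requests.exceptions.ConnectionError:
        time.sleep(0.2)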