ollamadiffuser 1.0.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +0 -0
- ollamadiffuser/__main__.py +50 -0
- ollamadiffuser/api/__init__.py +0 -0
- ollamadiffuser/api/server.py +297 -0
- ollamadiffuser/cli/__init__.py +0 -0
- ollamadiffuser/cli/main.py +597 -0
- ollamadiffuser/core/__init__.py +0 -0
- ollamadiffuser/core/config/__init__.py +0 -0
- ollamadiffuser/core/config/settings.py +137 -0
- ollamadiffuser/core/inference/__init__.py +0 -0
- ollamadiffuser/core/inference/engine.py +926 -0
- ollamadiffuser/core/models/__init__.py +0 -0
- ollamadiffuser/core/models/manager.py +436 -0
- ollamadiffuser/core/utils/__init__.py +3 -0
- ollamadiffuser/core/utils/download_utils.py +356 -0
- ollamadiffuser/core/utils/lora_manager.py +390 -0
- ollamadiffuser/ui/__init__.py +0 -0
- ollamadiffuser/ui/templates/index.html +496 -0
- ollamadiffuser/ui/web.py +278 -0
- ollamadiffuser/utils/__init__.py +0 -0
- ollamadiffuser-1.0.0.dist-info/METADATA +493 -0
- ollamadiffuser-1.0.0.dist-info/RECORD +26 -0
- ollamadiffuser-1.0.0.dist-info/WHEEL +5 -0
- ollamadiffuser-1.0.0.dist-info/entry_points.txt +2 -0
- ollamadiffuser-1.0.0.dist-info/licenses/LICENSE +21 -0
- ollamadiffuser-1.0.0.dist-info/top_level.txt +1 -0
ollamadiffuser/__init__.py (file without changes)

ollamadiffuser/__main__.py
@@ -0,0 +1,50 @@
```python
#!/usr/bin/env python3
import sys
import argparse
from .cli.main import cli
from .api.server import run_server
from .ui.web import create_ui_app
from .core.config.settings import settings

def main():
    """Main entry function"""
    # Check if first argument is a mode flag
    if len(sys.argv) > 1 and sys.argv[1] in ['--mode']:
        # Use argparse for mode selection
        parser = argparse.ArgumentParser(
            description='OllamaDiffuser - Image generation model management tool'
        )
        parser.add_argument(
            '--mode',
            choices=['cli', 'api', 'ui'],
            required=True,
            help='Running mode: cli (command line), api (API server), ui (Web interface)'
        )
        parser.add_argument('--host', default=None, help='Server host address')
        parser.add_argument('--port', type=int, default=None, help='Server port')
        parser.add_argument('--verbose', '-v', action='store_true', help='Verbose output')

        args, unknown = parser.parse_known_args()

        if args.mode == 'cli':
            # Command line mode
            sys.argv = [sys.argv[0]] + unknown
            cli()
        elif args.mode == 'api':
            # API server mode
            print("Starting OllamaDiffuser API server...")
            run_server(host=args.host, port=args.port)
        elif args.mode == 'ui':
            # Web UI mode
            print("Starting OllamaDiffuser Web UI...")
            import uvicorn
            app = create_ui_app()
            host = args.host or settings.server.host
            port = args.port or (settings.server.port + 1)  # Web UI uses different port
            uvicorn.run(app, host=host, port=port)
    else:
        # Default to CLI mode for direct command usage
        cli()

if __name__ == '__main__':
    main()
```
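Note that the dispatcher checks `sys.argv[1] in ['--mode']` literally, so only the space-separated form selects a mode; the equals form (`--mode=api`) fails that check and falls through to `cli()`. A minimal sketch of the resulting invocation behavior, assuming the wheel is installed so `python -m ollamadiffuser` resolves; the port value is illustrative:

```python
# Sketch only: demonstrates how __main__.py dispatches. Assumes the package
# is installed; port 8000 is an arbitrary illustrative value.
import subprocess
import sys

# '--mode' as the literal first argument takes the argparse branch and starts
# the API server (this call blocks while the server runs, hence commented out):
# subprocess.run([sys.executable, "-m", "ollamadiffuser", "--mode", "api", "--port", "8000"])

# Anything else, including the '--mode=api' equals form (which fails the
# sys.argv[1] in ['--mode'] check), is handed to cli() unchanged:
subprocess.run([sys.executable, "-m", "ollamadiffuser", "--help"])
```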
ollamadiffuser/api/__init__.py (file without changes)

ollamadiffuser/api/server.py
@@ -0,0 +1,297 @@
```python
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import io
import logging
from typing import Dict, Any, Optional
from pydantic import BaseModel

from ..core.models.manager import model_manager
from ..core.config.settings import settings

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# API request models
class GenerateRequest(BaseModel):
    prompt: str
    negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution"
    num_inference_steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    width: int = 1024
    height: int = 1024

class LoadModelRequest(BaseModel):
    model_name: str

class LoadLoRARequest(BaseModel):
    lora_name: str
    repo_id: str
    weight_name: Optional[str] = None
    scale: float = 1.0

def create_app() -> FastAPI:
    """Create FastAPI application"""
    app = FastAPI(
        title="OllamaDiffuser API",
        description="Image generation model management and inference API",
        version="1.0.0"
    )

    # Add CORS middleware
    if settings.server.enable_cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # Root endpoint
    @app.get("/")
    async def root():
        """Root endpoint with API information"""
        return {
            "name": "OllamaDiffuser API",
            "version": "1.0.0",
            "description": "Image generation model management and inference API",
            "status": "running",
            "endpoints": {
                "documentation": "/docs",
                "openapi_schema": "/openapi.json",
                "health_check": "/api/health",
                "models": "/api/models",
                "generate": "/api/generate"
            },
            "usage": {
                "web_ui": "Use 'ollamadiffuser --mode ui' to start the web interface",
                "cli": "Use 'ollamadiffuser --help' for command line options",
                "api_docs": "Visit /docs for interactive API documentation"
            }
        }

    # Model management endpoints
    @app.get("/api/models")
    async def list_models():
        """List all models"""
        return {
            "available": model_manager.list_available_models(),
            "installed": model_manager.list_installed_models(),
            "current": model_manager.get_current_model()
        }

    @app.get("/api/models/running")
    async def get_running_model():
        """Get currently running model"""
        if model_manager.is_model_loaded():
            engine = model_manager.loaded_model
            return {
                "model": model_manager.get_current_model(),
                "info": engine.get_model_info(),
                "loaded": True
            }
        else:
            return {"loaded": False}

    @app.get("/api/models/{model_name}")
    async def get_model_info(model_name: str):
        """Get model detailed information"""
        info = model_manager.get_model_info(model_name)
        if info is None:
            raise HTTPException(status_code=404, detail="Model does not exist")
        return info

    @app.post("/api/models/pull")
    async def pull_model(model_name: str):
        """Download model"""
        if model_manager.pull_model(model_name):
            return {"message": f"Model {model_name} downloaded successfully"}
        else:
            raise HTTPException(status_code=400, detail=f"Failed to download model {model_name}")

    @app.post("/api/models/load")
    async def load_model(request: LoadModelRequest):
        """Load model"""
        if model_manager.load_model(request.model_name):
            return {"message": f"Model {request.model_name} loaded successfully"}
        else:
            raise HTTPException(status_code=400, detail=f"Failed to load model {request.model_name}")

    @app.post("/api/models/unload")
    async def unload_model():
        """Unload current model"""
        model_manager.unload_model()
        return {"message": "Model unloaded"}

    @app.delete("/api/models/{model_name}")
    async def remove_model(model_name: str):
        """Remove model"""
        if model_manager.remove_model(model_name):
            return {"message": f"Model {model_name} removed successfully"}
        else:
            raise HTTPException(status_code=400, detail=f"Failed to remove model {model_name}")

    # LoRA management endpoints
    @app.post("/api/lora/load")
    async def load_lora(request: LoadLoRARequest):
        """Load LoRA weights into current model"""
        # Check if model is loaded
        if not model_manager.is_model_loaded():
            raise HTTPException(status_code=400, detail="No model loaded, please load a model first")

        try:
            # Get current loaded inference engine
            engine = model_manager.loaded_model

            # Load LoRA weights
            success = engine.load_lora_runtime(
                repo_id=request.repo_id,
                weight_name=request.weight_name,
                scale=request.scale
            )

            if success:
                return {"message": f"LoRA {request.lora_name} loaded successfully with scale {request.scale}"}
            else:
                raise HTTPException(status_code=400, detail=f"Failed to load LoRA {request.lora_name}")

        except Exception as e:
            logger.error(f"LoRA loading failed: {e}")
            raise HTTPException(status_code=500, detail=f"LoRA loading failed: {str(e)}")

    @app.post("/api/lora/unload")
    async def unload_lora():
        """Unload current LoRA weights"""
        # Check if model is loaded
        if not model_manager.is_model_loaded():
            raise HTTPException(status_code=400, detail="No model loaded")

        try:
            # Get current loaded inference engine
            engine = model_manager.loaded_model

            # Unload LoRA weights
            success = engine.unload_lora()

            if success:
                return {"message": "LoRA weights unloaded successfully"}
            else:
                raise HTTPException(status_code=400, detail="Failed to unload LoRA weights")

        except Exception as e:
            logger.error(f"LoRA unloading failed: {e}")
            raise HTTPException(status_code=500, detail=f"LoRA unloading failed: {str(e)}")

    @app.get("/api/lora/status")
    async def get_lora_status():
        """Get current LoRA status"""
        # Check if model is loaded
        if not model_manager.is_model_loaded():
            return {"loaded": False, "message": "No model loaded"}

        try:
            # Get current loaded inference engine
            engine = model_manager.loaded_model

            # Check tracked LoRA state
            if hasattr(engine, 'current_lora') and engine.current_lora:
                lora_info = engine.current_lora.copy()
                return {
                    "loaded": True,
                    "info": lora_info,
                    "message": "LoRA loaded"
                }
            else:
                return {
                    "loaded": False,
                    "info": None,
                    "message": "No LoRA loaded"
                }

        except Exception as e:
            logger.error(f"Failed to get LoRA status: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to get LoRA status: {str(e)}")

    # Image generation endpoints
    @app.post("/api/generate")
    async def generate_image(request: GenerateRequest):
        """Generate image"""
        # Check if model is loaded
        if not model_manager.is_model_loaded():
            raise HTTPException(status_code=400, detail="No model loaded, please load a model first")

        try:
            # Get current loaded inference engine
            engine = model_manager.loaded_model

            # Generate image
            image = engine.generate_image(
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                num_inference_steps=request.num_inference_steps,
                guidance_scale=request.guidance_scale,
                width=request.width,
                height=request.height
            )

            # Convert PIL image to bytes
            img_byte_arr = io.BytesIO()
            image.save(img_byte_arr, format='PNG')
            img_byte_arr = img_byte_arr.getvalue()

            return Response(content=img_byte_arr, media_type="image/png")

        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")

    # Health check endpoints
    @app.get("/api/health")
    async def health_check():
        """Health check"""
        return {
            "status": "healthy",
            "model_loaded": model_manager.is_model_loaded(),
            "current_model": model_manager.get_current_model()
        }

    # Server management endpoints
    @app.post("/api/shutdown")
    async def shutdown_server():
        """Gracefully shutdown the server"""
        import os
        import signal
        import asyncio

        # Unload model first
        model_manager.unload_model()

        # Schedule server shutdown
        def shutdown():
            os.kill(os.getpid(), signal.SIGTERM)

        # Delay shutdown to allow response to be sent
        asyncio.get_event_loop().call_later(0.5, shutdown)

        return {"message": "Server shutting down..."}

    return app

def run_server(host: str = None, port: int = None):
    """Start server"""
    # Use default values from configuration
    host = host or settings.server.host
    port = port or settings.server.port

    # Create FastAPI application
    app = create_app()

    # Run server with uvicorn
    logger.info(f"Starting server: http://{host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info")

if __name__ == "__main__":
    run_server()
```
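The `/api/generate` endpoint returns raw PNG bytes rather than JSON, so clients should read the response body directly. A minimal client sketch against these endpoints, assuming the server was started with `ollamadiffuser --mode api` and is reachable on port 8000 (the real default comes from `settings.server.port`); the model name is hypothetical, so pick one from `GET /api/models`:

```python
# Minimal client sketch for the API above. Port and model name are
# assumptions for illustration, not values from the package itself.
import requests

BASE = "http://127.0.0.1:8000"

# Confirm the server is up and see whether a model is already loaded.
print(requests.get(f"{BASE}/api/health").json())

# Load a model by name (it must already be installed/pulled).
r = requests.post(f"{BASE}/api/models/load", json={"model_name": "stable-diffusion-1.5"})
r.raise_for_status()

# Generate an image; the response body is the PNG itself.
r = requests.post(f"{BASE}/api/generate", json={
    "prompt": "a watercolor fox in an autumn forest",
    "num_inference_steps": 20,
    "width": 768,
    "height": 768,
})
r.raise_for_status()

with open("output.png", "wb") as f:
    f.write(r.content)
```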
ollamadiffuser/cli/__init__.py (file without changes)