isa_model-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
+++ isa_model/training/image_model/raw_data/pre_processing.py
@@ -0,0 +1,200 @@
+import os
+import cv2
+import logging
+from PIL import Image
+from pathlib import Path
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor
+import shutil
+from ultralytics import YOLO
+import numpy
+from pillow_heif import register_heif_opener
+
+# Configure logging and PIL settings
+logging.basicConfig(level=logging.INFO)
+Image.MAX_IMAGE_PIXELS = None
+register_heif_opener()  # This enables HEIC support in PIL
+
+COCO_CLASSES = {
+    'person': 0, 'bicycle': 1, 'car': 2, 'motorcycle': 3, 'airplane': 4, 'bus': 5, 'train': 6,
+    'truck': 7, 'boat': 8, 'traffic light': 9, 'fire hydrant': 10, 'stop sign': 11,
+    'parking meter': 12, 'bench': 13, 'bird': 14, 'cat': 15, 'dog': 16, 'horse': 17,
+    'sheep': 18, 'cow': 19, 'elephant': 20, 'bear': 21, 'zebra': 22, 'giraffe': 23,
+    'backpack': 24, 'umbrella': 25, 'handbag': 26, 'tie': 27, 'suitcase': 28, 'frisbee': 29,
+    'skis': 30, 'snowboard': 31, 'sports ball': 32, 'kite': 33, 'baseball bat': 34,
+    'baseball glove': 35, 'skateboard': 36, 'surfboard': 37, 'tennis racket': 38,
+    'bottle': 39, 'wine glass': 40, 'cup': 41, 'fork': 42, 'knife': 43, 'spoon': 44,
+    'bowl': 45, 'banana': 46, 'apple': 47, 'sandwich': 48, 'orange': 49, 'broccoli': 50,
+    'carrot': 51, 'hot dog': 52, 'pizza': 53, 'donut': 54, 'cake': 55, 'chair': 56,
+    'couch': 57, 'potted plant': 58, 'bed': 59, 'dining table': 60, 'toilet': 61,
+    'tv': 62, 'laptop': 63, 'mouse': 64, 'remote': 65, 'keyboard': 66, 'cell phone': 67,
+    'microwave': 68, 'oven': 69, 'toaster': 70, 'sink': 71, 'refrigerator': 72,
+    'book': 73, 'clock': 74, 'vase': 75, 'scissors': 76, 'teddy bear': 77,
+    'hair drier': 78, 'toothbrush': 79
+}
+
+class ImagePreProcessor:
+    def __init__(self, input_dir: str, output_dir: str, target_size: tuple = (512, 512),
+                 padding: float = 0.3):
+        """
+        Initialize the image preprocessor
+        """
+        self.input_dir = Path(input_dir)
+        self.output_dir = Path(output_dir)
+        self.target_size = target_size
+        self.padding = padding
+        self.supported_formats = {'.jpg', '.jpeg', '.heic', '.png'}
+
+        # Load YOLO face detection model
+        try:
+            logging.info("Loading YOLO face detection model...")
+            current_dir = Path(__file__).parent  # Get the directory where this script is located
+            model_path = current_dir / "models" / "yolov8n-face.pt"
+
+            if not os.path.exists(model_path):
+                raise FileNotFoundError(f"Model file not found at {model_path}")
+            self.model = YOLO(str(model_path))  # Convert Path to string for YOLO
+            logging.info("Successfully loaded YOLO face detection model")
+        except Exception as e:
+            logging.error(f"Failed to load YOLO model: {str(e)}")
+            raise
+
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def detect_and_crop_face(self, img) -> tuple:
+        """
+        Detect face in image and return cropped region
+        """
+        cv2_img = cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
+        results = self.model(cv2_img)
+
+        # Get all face detections
+        detections = results[0].boxes
+
+        if len(detections) == 0:
+            return False, None
+
+        # Get coordinates of the first detected face
+        box = detections[0]
+        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
+
+        # Add padding
+        width = x2 - x1
+        height = y2 - y1
+        padding_x = int(width * self.padding)
+        padding_y = int(height * self.padding)
+
+        x1 = max(0, x1 - padding_x)
+        y1 = max(0, y1 - padding_y)
+        x2 = min(img.width, x2 + padding_x)
+        y2 = min(img.height, y2 + padding_y)
+
+        cropped_img = img.crop((x1, y1, x2, y2))
+        return True, cropped_img
+
+    def process_image(self, image_path: Path) -> tuple:
+        """
+        Process a single image
+
+        Args:
+            image_path (Path): Path to input image
+
+        Returns:
+            tuple: (success, message)
+        """
+        try:
+            # Handle HEIC/HEIF files
+            if image_path.suffix.lower() in {'.heic', '.heif'}:
+                try:
+                    with Image.open(image_path) as img:
+                        # Convert HEIC to RGB mode
+                        img = img.convert('RGB')
+                        detected, cropped_img = self.detect_and_crop_face(img)
+                        if not detected:
+                            return False, f"No face detected in {image_path.name}"
+                except Exception as e:
+                    return False, f"Error processing HEIC file {image_path.name}: {str(e)}"
+            else:
+                # Handle other image formats
+                with Image.open(image_path) as img:
+                    if img.mode != 'RGB':
+                        img = img.convert('RGB')
+                    detected, cropped_img = self.detect_and_crop_face(img)
+                    if not detected:
+                        return False, f"No face detected in {image_path.name}"
+
+            # Process the cropped image
+            aspect_ratio = cropped_img.width / cropped_img.height
+            if aspect_ratio > 1:
+                new_width = self.target_size[0]
+                new_height = int(self.target_size[0] / aspect_ratio)
+            else:
+                new_height = self.target_size[1]
+                new_width = int(self.target_size[1] * aspect_ratio)
+
+            cropped_img = cropped_img.resize((new_width, new_height), Image.LANCZOS)
+
+            new_img = Image.new('RGB', self.target_size, (0, 0, 0))
+            paste_x = (self.target_size[0] - new_width) // 2
+            paste_y = (self.target_size[1] - new_height) // 2
+            new_img.paste(cropped_img, (paste_x, paste_y))
+
+            output_path = self.output_dir / f"{image_path.stem}.jpg"
+            new_img.save(output_path, 'JPEG', quality=95)
+
+            return True, f"Successfully processed {image_path.name}"
+
+        except Exception as e:
+            return False, f"Error processing {image_path.name}: {str(e)}"
+
+    def process_directory(self, num_workers: int = 4):
+        """
+        Process all images in the input directory
+
+        Args:
+            num_workers (int): Number of worker processes to use
+        """
+        # Get list of all images
+        image_files = [
+            f for f in self.input_dir.iterdir()
+            if f.is_file() and f.suffix.lower() in self.supported_formats
+        ]
+
+        if not image_files:
+            logging.warning("No supported image files found in input directory")
+            return
+
+        logging.info(f"Found {len(image_files)} images to process")
+
+        # Process images using multiple workers
+        with ProcessPoolExecutor(max_workers=num_workers) as executor:
+            with tqdm(total=len(image_files), desc="Processing images") as pbar:
+                futures = []
+                for image_path in image_files:
+                    future = executor.submit(self.process_image, image_path)
+                    future.add_done_callback(lambda p: pbar.update(1))
+                    futures.append(future)
+
+                # Process results
+                for future in futures:
+                    success, message = future.result()
+                    if not success:
+                        logging.error(message)
+
+def main():
+    # Update paths to use project-relative directories
+    current_dir = Path(__file__).parent  # Get the directory where this script is located
+    input_dir = current_dir / "data" / "training_images"
+    output_dir = current_dir / "data" / "training_images_processed"
+
+    processor = ImagePreProcessor(
+        input_dir=input_dir,
+        output_dir=output_dir,
+        target_size=(512, 512),  # Good size for Kohya training
+        padding=0.3,  # 30% padding around faces
+    )
+
+    processor.process_directory(num_workers=4)
+
+if __name__ == "__main__":
+    main()
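Note: the sketch below is an editor's illustration and not part of the diff. It shows how the preprocessor above could be driven directly, assuming the wheel's module layout; the input and output directories are hypothetical, and the models/yolov8n-face.pt weight file next to the module is still required.

# Minimal usage sketch (hypothetical directories; assumes the wheel's module layout).
from isa_model.training.image_model.raw_data.pre_processing import ImagePreProcessor

processor = ImagePreProcessor(
    input_dir="my_raw_images",         # hypothetical input directory
    output_dir="my_processed_images",  # hypothetical output directory
    target_size=(512, 512),            # same square canvas main() uses
    padding=0.3,                       # 30% padding around the detected face box
)
# Crops faces, letterboxes to 512x512, and writes JPEGs; failures are logged, not raised.
processor.process_directory(num_workers=2)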
+++ isa_model/training/image_model/train/train.py
@@ -0,0 +1,42 @@
+import json
+import subprocess
+from pathlib import Path
+
+def train_lora():
+    # Load your config
+    with open('training_config.json', 'r') as f:
+        config = json.load(f)
+
+    # Construct the training command
+    cmd = [
+        "accelerate", "launch",
+        "--num_cpu_threads_per_process", str(config["num_cpu_threads_per_process"]),
+        "train_network.py",
+        "--pretrained_model_name_or_path", config["pretrained_model_name_or_path"],
+        "--train_data_dir", config["train_data_dir"],
+        "--output_dir", config["output_dir"],
+        "--output_name", config["output_name"],
+        "--save_model_as", config["save_model_as"],
+        "--learning_rate", str(config["learning_rate"]),
+        "--train_batch_size", str(config["train_batch_size"]),
+        "--epoch", str(config["epoch"]),
+        "--save_every_n_epochs", str(config["save_every_n_epochs"]),
+        "--mixed_precision", config["mixed_precision"],
+        "--cache_latents",
+        "--gradient_checkpointing"
+    ]
+
+    # Add FLUX specific parameters
+    if config.get("flux1_checkbox"):
+        cmd.extend([
+            "--flux1_t5xxl", config["flux1_t5xxl"],
+            "--flux1_clip_l", config["flux1_clip_l"],
+            "--flux1_cache_text_encoder_outputs",
+            "--flux1_cache_text_encoder_outputs_to_disk"
+        ])
+
+    # Execute the training
+    subprocess.run(cmd, check=True)
+
+if __name__ == "__main__":
+    train_lora()
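Note: train.py reads training_config.json from the working directory, but no such file ships with the package. The editor's sketch below writes a config containing only the keys train_lora() actually reads; every value is an illustrative placeholder.

# Sketch: generate a training_config.json with the keys read by train_lora() above.
# All paths and hyperparameters are placeholders, not values from the package.
import json

config = {
    "num_cpu_threads_per_process": 2,
    "pretrained_model_name_or_path": "/models/base-model",
    "train_data_dir": "/data/train",
    "output_dir": "/output",
    "output_name": "my_lora",
    "save_model_as": "safetensors",
    "learning_rate": 1e-4,
    "train_batch_size": 1,
    "epoch": 10,
    "save_every_n_epochs": 1,
    "mixed_precision": "bf16",
    "flux1_checkbox": False,  # set True and add flux1_t5xxl / flux1_clip_l to enable the FLUX branch
}

with open("training_config.json", "w") as f:
    json.dump(config, f, indent=2)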
+++ isa_model/training/image_model/train/train_flux.py
@@ -0,0 +1,41 @@
+import json
+import subprocess
+from pathlib import Path
+
+def train_flux():
+    # Load your config
+    with open('flux_config.json', 'r') as f:
+        config = json.load(f)
+
+    # Construct the training command for Flux finetuning
+    cmd = [
+        "accelerate", "launch",
+        "--num_cpu_threads_per_process", str(config["num_cpu_threads_per_process"]),
+        "train_db.py",
+        "--pretrained_model_name_or_path", config["pretrained_model_name_or_path"],
+        "--train_data_dir", config["train_data_dir"],
+        "--output_dir", config["output_dir"],
+        "--output_name", config["output_name"],
+        "--train_batch_size", str(config["train_batch_size"]),
+        "--save_every_n_epochs", str(config["save_every_n_epochs"]),
+        "--learning_rate", str(config["learning_rate"]),
+        "--max_train_epochs", str(config["epoch"]),
+        "--mixed_precision", config["mixed_precision"],
+        "--save_model_as", config["save_model_as"],
+        "--cache_latents",
+        "--cache_latents_to_disk",
+        "--gradient_checkpointing",
+        "--optimizer_type", "Adafactor",
+        "--optimizer_args", "scale_parameter=False relative_step=False warmup_init=False weight_decay=0.01",
+        "--max_resolution", "1024,1024",
+        "--full_bf16",
+        "--flux1_checkbox",
+        "--flux1_t5xxl", config["flux1_t5xxl"],
+        "--flux1_clip_l", config["flux1_clip_l"],
+        "--flux1_cache_text_encoder_outputs",
+        "--flux1_cache_text_encoder_outputs_to_disk",
+        "--flux_fused_backward_pass"
+    ]
+
+    # Execute the training
+    subprocess.run(cmd, check=True)
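Note: unlike train.py, train_flux() appends the FLUX text-encoder arguments unconditionally, so flux1_t5xxl and flux1_clip_l must be present in flux_config.json. The editor's sketch below (not part of the diff) checks a config against the keys the function reads before launching.

# Sketch: validate flux_config.json against the keys accessed by train_flux().
import json

REQUIRED_KEYS = {
    "num_cpu_threads_per_process", "pretrained_model_name_or_path",
    "train_data_dir", "output_dir", "output_name", "train_batch_size",
    "save_every_n_epochs", "learning_rate", "epoch", "mixed_precision",
    "save_model_as", "flux1_t5xxl", "flux1_clip_l",
}

with open("flux_config.json") as f:
    cfg = json.load(f)

missing = REQUIRED_KEYS - cfg.keys()
if missing:
    raise KeyError(f"flux_config.json is missing keys: {sorted(missing)}")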
+++ isa_model/training/image_model/train/train_lora.py
@@ -0,0 +1,57 @@
+import json
+import subprocess
+from pathlib import Path
+
+def train_lora():
+    # Load your config
+    with open('training_config.json', 'r') as f:
+        config = json.load(f)
+
+    # Construct the training command for LoRA
+    cmd = [
+        "accelerate", "launch",
+        "--num_cpu_threads_per_process", str(config["num_cpu_threads_per_process"]),
+        "sdxl_train_network.py",  # Use the SDXL LoRA training script
+        "--network_module", "networks.lora",  # Specify LoRA network
+        "--pretrained_model_name_or_path", config["pretrained_model_name_or_path"],
+        "--train_data_dir", config["train_data_dir"],
+        "--output_dir", config["output_dir"],
+        "--output_name", config["output_name"],
+        "--save_model_as", config["save_model_as"],
+        "--network_alpha", "1",  # LoRA alpha parameter
+        "--network_dim", "32",  # LoRA dimension
+        "--learning_rate", str(config["learning_rate"]),
+        "--train_batch_size", str(config["train_batch_size"]),
+        "--max_train_epochs", str(config["epoch"]),
+        "--save_every_n_epochs", str(config["save_every_n_epochs"]),
+        "--mixed_precision", config["mixed_precision"],
+        "--cache_latents",
+        "--gradient_checkpointing",
+        "--network_args", "conv_dim=32", "conv_alpha=1",  # LoRA network arguments
+        "--noise_offset", "0.1",
+        "--adaptive_noise_scale", "0.01",
+        "--max_resolution", "1024,1024",
+        "--min_bucket_reso", "256",
+        "--max_bucket_reso", "1024",
+        "--xformers",
+        "--bucket_reso_steps", "64",
+        "--caption_extension", ".txt",
+        "--optimizer_type", "AdaFactor",
+        "--optimizer_args", "scale_parameter=False", "relative_step=False", "warmup_init=False",
+        "--lr_scheduler", "constant"
+    ]
+
+    # Add FLUX specific parameters for LoRA
+    if config.get("flux1_checkbox"):
+        cmd.extend([
+            "--flux1_t5xxl", config["flux1_t5xxl"],
+            "--flux1_clip_l", config["flux1_clip_l"],
+            "--flux1_cache_text_encoder_outputs",
+            "--flux1_cache_text_encoder_outputs_to_disk"
+        ])
+
+    # Execute the training
+    subprocess.run(cmd, check=True)
+
+if __name__ == "__main__":
+    train_lora()
+++ isa_model/training/image_model/train_main.py
@@ -0,0 +1,25 @@
+import os
+from pathlib import Path
+import shutil
+from app.services.training.image_model.raw_data.create_lora_captions import create_lora_captions
+from app.services.training.image_model.train.train_flux import train_flux
+
+def main():
+    # Setup paths
+    project_root = Path(__file__).parent
+    processed_images_dir = project_root / "raw_data/training_images_processed"
+
+    # 1. Generate captions for all processed images
+    print("Creating captions for processed images...")
+    create_lora_captions(processed_images_dir)
+
+    # 2. Create Flux config
+    print("Creating Flux configuration...")
+    os.system(f"python {project_root}/configs/create_flux_config.py")
+
+    # 3. Run Flux training
+    print("Starting Flux training...")
+    train_flux()
+
+if __name__ == "__main__":
+    main()
+++ isa_model/training/llm_model/annotation/annotation_schema.py
@@ -0,0 +1,47 @@
+# app/services/llm_model/tracing/annotation/annotation_schema.py
+from enum import Enum
+from pydantic import BaseModel, Field
+from typing import Dict, Any, List, Optional
+from datetime import datetime
+
+class AnnotationType(str, Enum):
+    ACCURACY = "accuracy"
+    HELPFULNESS = "helpfulness"
+    TOXICITY = "toxicity"
+    CUSTOM = "custom"
+
+class RatingScale(int, Enum):
+    POOR = 1
+    FAIR = 2
+    GOOD = 3
+    EXCELLENT = 4
+
+class AnnotationAspects(BaseModel):
+    factually_correct: bool = True
+    relevant: bool = True
+    harmful: bool = False
+    biased: bool = False
+    complete: bool = True
+    efficient: bool = True
+
+class BetterResponse(BaseModel):
+    content: str
+    reason: Optional[str]
+    metadata: Optional[Dict[str, Any]] = {}
+
+class AnnotationFeedback(BaseModel):
+    rating: RatingScale
+    category: AnnotationType
+    aspects: AnnotationAspects
+    better_response: Optional[BetterResponse]
+    comment: Optional[str]
+    metadata: Optional[Dict[str, Any]] = {}
+    is_selected_for_training: bool = False
+
+class ItemAnnotation(BaseModel):
+    item_id: str
+    feedback: Optional[AnnotationFeedback]
+    status: str = "pending"
+    annotated_at: Optional[datetime]
+    annotator_id: Optional[str]
+    training_status: Optional[str] = None
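Note: an editor's construction sketch for the models above, not part of the diff. The import path follows the wheel layout, the field values are illustrative, and .dict() assumes the pydantic v1 API that the rest of the package uses.

# Sketch: build a feedback record with a corrected response (illustrative values).
from isa_model.training.llm_model.annotation.annotation_schema import (
    AnnotationAspects, AnnotationFeedback, AnnotationType, BetterResponse, RatingScale,
)

feedback = AnnotationFeedback(
    rating=RatingScale.FAIR,
    category=AnnotationType.ACCURACY,
    aspects=AnnotationAspects(factually_correct=False, complete=False),
    better_response=BetterResponse(
        content="A corrected answer would go here.",
        reason="The original response contained a factual error.",
    ),
    comment="Candidate preference pair for RLHF.",
    is_selected_for_training=True,
)
print(feedback.dict())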
+++ isa_model/training/llm_model/annotation/processors/annotation_processor.py
@@ -0,0 +1,126 @@
+from typing import Dict, Any, List
+from datetime import datetime
+from app.config.config_manager import config_manager
+from app.services.training.llm_model.annotation.annotation_schema import AnnotationFeedback, RatingScale, AnnotationAspects
+from bson.objectid import ObjectId
+from app.services.training.llm_model.annotation.storage.dataset_manager import DatasetManager
+
+class AnnotationProcessor:
+    def __init__(self):
+        self.logger = config_manager.get_logger(__name__)
+        self.dataset_manager = DatasetManager()
+        self.batch_size = 1000  # Configure as needed
+
+    async def process_queue(self) -> None:
+        """Process pending items and create datasets"""
+        db = await config_manager.get_db('mongodb')
+        queue = db['training_queue']
+
+        # Process SFT items
+        sft_items = await self._get_pending_items("sft")
+        if len(sft_items) >= self.batch_size:
+            await self._create_sft_dataset(sft_items)
+
+        # Process RLHF items
+        rlhf_items = await self._get_pending_items("rlhf")
+        if len(rlhf_items) >= self.batch_size:
+            await self._create_rlhf_dataset(rlhf_items)
+
+    async def _create_sft_dataset(self, items: List[Dict[str, Any]]):
+        """Create and upload SFT dataset"""
+        dataset = await self.dataset_manager.create_dataset(
+            name=f"sft_dataset_v{datetime.now().strftime('%Y%m%d')}",
+            type="sft",
+            version=datetime.now().strftime("%Y%m%d"),
+            source_annotations=[item["annotation_id"] for item in items]
+        )
+
+        formatted_data = [
+            await self._process_sft_item(item)
+            for item in items
+        ]
+
+        await self.dataset_manager.upload_dataset_file(
+            dataset.id,
+            formatted_data
+        )
+
+    async def _process_sft_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process item for SFT dataset generation
+        Format follows HF conversation format for SFT training
+        """
+        db = await config_manager.get_db('mongodb')
+        annotations = db['annotations']
+
+        # Get full annotation context
+        annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
+        target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
+
+        # Format as conversation
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful AI assistant that provides accurate and relevant information."
+            },
+            {
+                "role": "user",
+                "content": target_item["input"]["messages"][0]["content"]
+            },
+            {
+                "role": "assistant",
+                "content": target_item["output"]["content"]
+            }
+        ]
+
+        return {
+            "messages": messages,
+            "metadata": {
+                "rating": item["feedback"]["rating"],
+                "aspects": item["feedback"]["aspects"],
+                "category": item["feedback"]["category"]
+            }
+        }
+
+    async def _process_rlhf_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+        """Process item for RLHF dataset generation
+        Format follows preference pairs structure for RLHF training
+        """
+        db = await config_manager.get_db('mongodb')
+        annotations = db['annotations']
+
+        # Get full annotation context
+        annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
+        target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
+
+        # Format as preference pairs
+        return {
+            "prompt": target_item["input"]["messages"][0]["content"],
+            "chosen": item["feedback"]["better_response"]["content"],
+            "rejected": target_item["output"]["content"],
+            "metadata": {
+                "reason": item["feedback"]["better_response"]["reason"],
+                "category": item["feedback"]["category"]
+            }
+        }
+
+    async def get_training_data(
+        self,
+        data_type: str,
+        limit: int = 1000
+    ) -> List[Dict[str, Any]]:
+        """Retrieve formatted training data"""
+        db = await config_manager.get_db('mongodb')
+        training_data = db['training_data']
+
+        data = await training_data.find(
+            {"type": data_type}
+        ).limit(limit).to_list(length=limit)
+
+        if data_type == "sft":
+            return [item["data"]["messages"] for item in data]
+        else:  # rlhf
+            return [{
+                "prompt": item["data"]["prompt"],
+                "chosen": item["data"]["chosen"],
+                "rejected": item["data"]["rejected"]
+            } for item in data]
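Note: reconstructed from the return statements above, the two record shapes the processor emits look as follows; angle-bracketed strings are placeholders, not values from the package.

# SFT record: HF-style conversation plus annotation metadata.
sft_record = {
    "messages": [
        {"role": "system", "content": "You are a helpful AI assistant that provides accurate and relevant information."},
        {"role": "user", "content": "<original user prompt>"},
        {"role": "assistant", "content": "<model response>"},
    ],
    "metadata": {"rating": 2, "aspects": {"factually_correct": False}, "category": "accuracy"},
}

# RLHF record: preference pair with the annotator's better_response as "chosen".
rlhf_record = {
    "prompt": "<original user prompt>",
    "chosen": "<better_response.content>",
    "rejected": "<original model output>",
    "metadata": {"reason": "<better_response.reason>", "category": "accuracy"},
}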
+++ isa_model/training/llm_model/annotation/storage/dataset_manager.py
@@ -0,0 +1,131 @@
+# app/services/llm_model/annotation/dataset/dataset_manager.py
+from typing import Dict, Any, List
+from datetime import datetime
+import json
+import io
+from app.config.config_manager import config_manager
+from .dataset_schema import Dataset, DatasetType, DatasetStatus, DatasetFiles, DatasetStats
+from bson import ObjectId
+
+class DatasetManager:
+    def __init__(self):
+        self.logger = config_manager.get_logger(__name__)
+        self.minio_client = None
+        self.bucket_name = "training-datasets"
+
+    async def _ensure_minio_client(self):
+        if not self.minio_client:
+            self.minio_client = await config_manager.get_storage_client()
+
+    async def create_dataset(
+        self,
+        name: str,
+        type: DatasetType,
+        version: str,
+        source_annotations: List[str]
+    ) -> Dataset:
+        """Create a new dataset record"""
+        db = await config_manager.get_db('mongodb')
+        collection = db['training_datasets']
+
+        dataset = Dataset(
+            name=name,
+            type=type,
+            version=version,
+            storage_path=f"datasets/{type.value}/{version}",
+            files=DatasetFiles(
+                train="train.jsonl",
+                eval=None,
+                test=None
+            ),
+            stats=DatasetStats(
+                total_examples=0,
+                avg_length=0.0,
+                num_conversations=0,
+                additional_metrics={}
+            ),
+            source_annotations=source_annotations,
+            created_at=datetime.utcnow(),
+            status=DatasetStatus.PENDING,
+            metadata={}
+        )
+
+        result = await collection.insert_one(dataset.dict(exclude={'id'}))
+        return Dataset(**{**dataset.dict(), '_id': result.inserted_id})
+
+    async def upload_dataset_file(
+        self,
+        dataset_id: str,
+        data: List[Dict[str, Any]],
+        file_type: str = "train"
+    ) -> bool:
+        """Upload dataset to MinIO"""
+        try:
+            await self._ensure_minio_client()
+            db = await config_manager.get_db('mongodb')
+
+            object_id = ObjectId(dataset_id)
+            dataset = await db['training_datasets'].find_one({"_id": object_id})
+
+            if not dataset:
+                self.logger.error(f"Dataset not found with id: {dataset_id}")
+                return False
+
+            # Convert to JSONL
+            buffer = io.StringIO()
+            for item in data:
+                buffer.write(json.dumps(item) + "\n")
+
+            storage_path = dataset['storage_path'].rstrip('/')
+            file_path = f"{storage_path}/{file_type}.jsonl"
+
+            buffer_value = buffer.getvalue().encode()
+
+            self.logger.debug(f"Uploading to MinIO path: {file_path}")
+
+            self.minio_client.put_object(
+                self.bucket_name,
+                file_path,
+                io.BytesIO(buffer_value),
+                len(buffer_value)
+            )
+
+            avg_length = sum(len(str(item)) for item in data) / len(data) if data else 0
+
+            await db['training_datasets'].update_one(
+                {"_id": object_id},
+                {
+                    "$set": {
+                        f"files.{file_type}": f"{file_type}.jsonl",
+                        "stats.total_examples": len(data),
+                        "stats.avg_length": avg_length,
+                        "stats.num_conversations": len(data),
+                        "status": DatasetStatus.READY
+                    }
+                }
+            )
+
+            return True
+
+        except Exception as e:
+            self.logger.error(f"Failed to upload dataset: {e}")
+            return False
+
+    async def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
+        """Get dataset information"""
+        try:
+            db = await config_manager.get_db('mongodb')
+            object_id = ObjectId(dataset_id)  # Convert string ID to ObjectId
+            dataset = await db['training_datasets'].find_one({"_id": object_id})
+
+            if not dataset:
+                self.logger.error(f"Dataset not found with id: {dataset_id}")
+                return None
+
+            # Convert ObjectId to string for JSON serialization
+            dataset['_id'] = str(dataset['_id'])
+            return dataset
+
+        except Exception as e:
+            self.logger.error(f"Failed to get dataset info: {e}")
+            return None
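Note: an editor's end-to-end usage sketch for the manager above, not part of the diff. It assumes config_manager is already wired to MongoDB and MinIO, that the training-datasets bucket exists, and that DatasetType exposes an SFT member in the dataset_schema module (not shown in this hunk); the annotation id and payload are placeholders.

# Sketch (assumptions noted above): create a dataset record, then upload a train split.
import asyncio
from isa_model.training.llm_model.annotation.storage.dataset_manager import DatasetManager
from isa_model.training.llm_model.annotation.storage.dataset_schema import DatasetType  # assumed export

async def demo():
    manager = DatasetManager()
    dataset = await manager.create_dataset(
        name="sft_dataset_demo",
        type=DatasetType.SFT,  # assumed enum member
        version="20240101",
        source_annotations=["<annotation id>"],  # placeholder
    )
    ok = await manager.upload_dataset_file(
        dataset_id=str(dataset.id),
        data=[{"messages": [{"role": "user", "content": "hi"}]}],
        file_type="train",
    )
    print("uploaded:", ok)

asyncio.run(demo())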