isa-model 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  12. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  13. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  14. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  15. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  16. isa_model/inference/__init__.py +11 -0
  17. isa_model/inference/adapter/unified_api.py +248 -0
  18. isa_model/inference/ai_factory.py +359 -0
  19. isa_model/inference/base.py +46 -0
  20. isa_model/inference/providers/__init__.py +19 -0
  21. isa_model/inference/providers/base_provider.py +30 -0
  22. isa_model/inference/providers/model_cache_manager.py +341 -0
  23. isa_model/inference/providers/ollama_provider.py +73 -0
  24. isa_model/inference/providers/openai_provider.py +101 -0
  25. isa_model/inference/providers/replicate_provider.py +107 -0
  26. isa_model/inference/providers/triton_provider.py +439 -0
  27. isa_model/inference/services/__init__.py +14 -0
  28. isa_model/inference/services/audio/base_stt_service.py +91 -0
  29. isa_model/inference/services/audio/base_tts_service.py +136 -0
  30. isa_model/inference/services/audio/openai_tts_service.py +71 -0
  31. isa_model/inference/services/base_service.py +106 -0
  32. isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
  33. isa_model/inference/services/embedding/openai_embed_service.py +0 -0
  34. isa_model/inference/services/llm/__init__.py +12 -0
  35. isa_model/inference/services/llm/base_llm_service.py +134 -0
  36. isa_model/inference/services/llm/ollama_llm_service.py +99 -0
  37. isa_model/inference/services/llm/openai_llm_service.py +138 -0
  38. isa_model/inference/services/others/table_transformer_service.py +61 -0
  39. isa_model/inference/services/vision/__init__.py +12 -0
  40. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  41. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  42. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  43. isa_model/inference/services/vision/openai_vision_service.py +80 -0
  44. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  45. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  46. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  47. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  48. isa_model/scripts/inference_tracker.py +283 -0
  49. isa_model/scripts/mlflow_manager.py +379 -0
  50. isa_model/scripts/model_registry.py +465 -0
  51. isa_model/scripts/start_mlflow.py +95 -0
  52. isa_model/scripts/training_tracker.py +257 -0
  53. isa_model/training/engine/llama_factory/__init__.py +39 -0
  54. isa_model/training/engine/llama_factory/config.py +115 -0
  55. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  56. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  57. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  58. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  59. isa_model/training/engine/llama_factory/factory.py +331 -0
  60. isa_model/training/engine/llama_factory/rl.py +254 -0
  61. isa_model/training/engine/llama_factory/trainer.py +171 -0
  62. isa_model/training/image_model/configs/create_config.py +37 -0
  63. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  64. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  65. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  66. isa_model/training/image_model/prepare_upload.py +17 -0
  67. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  68. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  69. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  70. isa_model/training/image_model/train/train.py +42 -0
  71. isa_model/training/image_model/train/train_flux.py +41 -0
  72. isa_model/training/image_model/train/train_lora.py +57 -0
  73. isa_model/training/image_model/train_main.py +25 -0
  74. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  75. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  76. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  77. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  78. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  79. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  80. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  81. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  82. isa_model-0.0.1.dist-info/METADATA +327 -0
  83. isa_model-0.0.1.dist-info/RECORD +86 -0
  84. isa_model-0.0.1.dist-info/WHEEL +5 -0
  85. isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
  86. isa_model-0.0.1.dist-info/top_level.txt +1 -0

isa_model/training/image_model/raw_data/pre_processing.py
@@ -0,0 +1,200 @@
+ import os
+ import cv2
+ import logging
+ from PIL import Image
+ from pathlib import Path
+ from tqdm import tqdm
+ from concurrent.futures import ProcessPoolExecutor
+ import shutil
+ from ultralytics import YOLO
+ import numpy
+ from pillow_heif import register_heif_opener
+
+ # Configure logging and PIL settings
+ logging.basicConfig(level=logging.INFO)
+ Image.MAX_IMAGE_PIXELS = None
+ register_heif_opener()  # This enables HEIC support in PIL
+
+ COCO_CLASSES = {
+     'person': 0, 'bicycle': 1, 'car': 2, 'motorcycle': 3, 'airplane': 4, 'bus': 5, 'train': 6,
+     'truck': 7, 'boat': 8, 'traffic light': 9, 'fire hydrant': 10, 'stop sign': 11,
+     'parking meter': 12, 'bench': 13, 'bird': 14, 'cat': 15, 'dog': 16, 'horse': 17,
+     'sheep': 18, 'cow': 19, 'elephant': 20, 'bear': 21, 'zebra': 22, 'giraffe': 23,
+     'backpack': 24, 'umbrella': 25, 'handbag': 26, 'tie': 27, 'suitcase': 28, 'frisbee': 29,
+     'skis': 30, 'snowboard': 31, 'sports ball': 32, 'kite': 33, 'baseball bat': 34,
+     'baseball glove': 35, 'skateboard': 36, 'surfboard': 37, 'tennis racket': 38,
+     'bottle': 39, 'wine glass': 40, 'cup': 41, 'fork': 42, 'knife': 43, 'spoon': 44,
+     'bowl': 45, 'banana': 46, 'apple': 47, 'sandwich': 48, 'orange': 49, 'broccoli': 50,
+     'carrot': 51, 'hot dog': 52, 'pizza': 53, 'donut': 54, 'cake': 55, 'chair': 56,
+     'couch': 57, 'potted plant': 58, 'bed': 59, 'dining table': 60, 'toilet': 61,
+     'tv': 62, 'laptop': 63, 'mouse': 64, 'remote': 65, 'keyboard': 66, 'cell phone': 67,
+     'microwave': 68, 'oven': 69, 'toaster': 70, 'sink': 71, 'refrigerator': 72,
+     'book': 73, 'clock': 74, 'vase': 75, 'scissors': 76, 'teddy bear': 77,
+     'hair drier': 78, 'toothbrush': 79
+ }
+
+ class ImagePreProcessor:
+     def __init__(self, input_dir: str, output_dir: str, target_size: tuple = (512, 512),
+                  padding: float = 0.3):
+         """
+         Initialize the image preprocessor
+         """
+         self.input_dir = Path(input_dir)
+         self.output_dir = Path(output_dir)
+         self.target_size = target_size
+         self.padding = padding
+         self.supported_formats = {'.jpg', '.jpeg', '.heic', '.png'}
+
+         # Load YOLO face detection model
+         try:
+             logging.info("Loading YOLO face detection model...")
+             current_dir = Path(__file__).parent  # Get the directory where this script is located
+             model_path = current_dir / "models" / "yolov8n-face.pt"
+
+             if not os.path.exists(model_path):
+                 raise FileNotFoundError(f"Model file not found at {model_path}")
+             self.model = YOLO(str(model_path))  # Convert Path to string for YOLO
+             logging.info("Successfully loaded YOLO face detection model")
+         except Exception as e:
+             logging.error(f"Failed to load YOLO model: {str(e)}")
+             raise
+
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+
+     def detect_and_crop_face(self, img) -> tuple:
+         """
+         Detect face in image and return cropped region
+         """
+         cv2_img = cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR)
+         results = self.model(cv2_img)
+
+         # Get all face detections
+         detections = results[0].boxes
+
+         if len(detections) == 0:
+             return False, None
+
+         # Get coordinates of the first detected face
+         box = detections[0]
+         x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
+
+         # Add padding
+         width = x2 - x1
+         height = y2 - y1
+         padding_x = int(width * self.padding)
+         padding_y = int(height * self.padding)
+
+         x1 = max(0, x1 - padding_x)
+         y1 = max(0, y1 - padding_y)
+         x2 = min(img.width, x2 + padding_x)
+         y2 = min(img.height, y2 + padding_y)
+
+         cropped_img = img.crop((x1, y1, x2, y2))
+         return True, cropped_img
+
+     def process_image(self, image_path: Path) -> tuple:
+         """
+         Process a single image
+
+         Args:
+             image_path (Path): Path to input image
+
+         Returns:
+             tuple: (success, message)
+         """
+         try:
+             # Handle HEIC/HEIF files
+             if image_path.suffix.lower() in {'.heic', '.heif'}:
+                 try:
+                     with Image.open(image_path) as img:
+                         # Convert HEIC to RGB mode
+                         img = img.convert('RGB')
+                         detected, cropped_img = self.detect_and_crop_face(img)
+                         if not detected:
+                             return False, f"No face detected in {image_path.name}"
+                 except Exception as e:
+                     return False, f"Error processing HEIC file {image_path.name}: {str(e)}"
+             else:
+                 # Handle other image formats
+                 with Image.open(image_path) as img:
+                     if img.mode != 'RGB':
+                         img = img.convert('RGB')
+                     detected, cropped_img = self.detect_and_crop_face(img)
+                     if not detected:
+                         return False, f"No face detected in {image_path.name}"
+
+             # Process the cropped image
+             aspect_ratio = cropped_img.width / cropped_img.height
+             if aspect_ratio > 1:
+                 new_width = self.target_size[0]
+                 new_height = int(self.target_size[0] / aspect_ratio)
+             else:
+                 new_height = self.target_size[1]
+                 new_width = int(self.target_size[1] * aspect_ratio)
+
+             cropped_img = cropped_img.resize((new_width, new_height), Image.LANCZOS)
+
+             new_img = Image.new('RGB', self.target_size, (0, 0, 0))
+             paste_x = (self.target_size[0] - new_width) // 2
+             paste_y = (self.target_size[1] - new_height) // 2
+             new_img.paste(cropped_img, (paste_x, paste_y))
+
+             output_path = self.output_dir / f"{image_path.stem}.jpg"
+             new_img.save(output_path, 'JPEG', quality=95)
+
+             return True, f"Successfully processed {image_path.name}"
+
+         except Exception as e:
+             return False, f"Error processing {image_path.name}: {str(e)}"
+
+     def process_directory(self, num_workers: int = 4):
+         """
+         Process all images in the input directory
+
+         Args:
+             num_workers (int): Number of worker processes to use
+         """
+         # Get list of all images
+         image_files = [
+             f for f in self.input_dir.iterdir()
+             if f.is_file() and f.suffix.lower() in self.supported_formats
+         ]
+
+         if not image_files:
+             logging.warning("No supported image files found in input directory")
+             return
+
+         logging.info(f"Found {len(image_files)} images to process")
+
+         # Process images using multiple workers
+         with ProcessPoolExecutor(max_workers=num_workers) as executor:
+             with tqdm(total=len(image_files), desc="Processing images") as pbar:
+                 futures = []
+                 for image_path in image_files:
+                     future = executor.submit(self.process_image, image_path)
+                     future.add_done_callback(lambda p: pbar.update(1))
+                     futures.append(future)
+
+                 # Process results
+                 for future in futures:
+                     success, message = future.result()
+                     if not success:
+                         logging.error(message)
+
+ def main():
+     # Update paths to use project-relative directories
+     current_dir = Path(__file__).parent  # Get the directory where this script is located
+     input_dir = current_dir / "data" / "training_images"
+     output_dir = current_dir / "data" / "training_images_processed"
+
+     processor = ImagePreProcessor(
+         input_dir=input_dir,
+         output_dir=output_dir,
+         target_size=(512, 512),  # Good size for Kohya training
+         padding=0.3,  # 30% padding around faces
+     )
+
+     processor.process_directory(num_workers=4)
+
+ if __name__ == "__main__":
+     main()
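
Editor's note: the `main()` entry point hard-codes its directories. As a minimal sketch of driving the class directly, assuming `yolov8n-face.pt` already sits in a `models/` folder beside the script (paths and the import form below are illustrative, not part of the package):

    from pathlib import Path
    from pre_processing import ImagePreProcessor  # hypothetical import; depends on how the module is placed

    # Directories are illustrative; any folder of .jpg/.jpeg/.png/.heic files works.
    processor = ImagePreProcessor(
        input_dir="raw_shots",
        output_dir="cropped_512",
        target_size=(512, 512),  # output canvas; faces are letterboxed onto black
        padding=0.3,             # 30% margin around the detected face box
    )

    # Either process one file and inspect the (success, message) tuple...
    ok, msg = processor.process_image(Path("raw_shots/portrait.heic"))
    print(ok, msg)

    # ...or fan out over the whole input directory with 4 worker processes.
    processor.process_directory(num_workers=4)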

isa_model/training/image_model/train/train.py
@@ -0,0 +1,42 @@
+ import json
+ import subprocess
+ from pathlib import Path
+
+ def train_lora():
+     # Load your config
+     with open('training_config.json', 'r') as f:
+         config = json.load(f)
+
+     # Construct the training command
+     cmd = [
+         "accelerate", "launch",
+         "--num_cpu_threads_per_process", str(config["num_cpu_threads_per_process"]),
+         "train_network.py",
+         "--pretrained_model_name_or_path", config["pretrained_model_name_or_path"],
+         "--train_data_dir", config["train_data_dir"],
+         "--output_dir", config["output_dir"],
+         "--output_name", config["output_name"],
+         "--save_model_as", config["save_model_as"],
+         "--learning_rate", str(config["learning_rate"]),
+         "--train_batch_size", str(config["train_batch_size"]),
+         "--epoch", str(config["epoch"]),
+         "--save_every_n_epochs", str(config["save_every_n_epochs"]),
+         "--mixed_precision", config["mixed_precision"],
+         "--cache_latents",
+         "--gradient_checkpointing"
+     ]
+
+     # Add FLUX specific parameters
+     if config.get("flux1_checkbox"):
+         cmd.extend([
+             "--flux1_t5xxl", config["flux1_t5xxl"],
+             "--flux1_clip_l", config["flux1_clip_l"],
+             "--flux1_cache_text_encoder_outputs",
+             "--flux1_cache_text_encoder_outputs_to_disk"
+         ])
+
+     # Execute the training
+     subprocess.run(cmd, check=True)
+
+ if __name__ == "__main__":
+     train_lora()
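
Editor's note: `train_lora()` expects a `training_config.json` in the working directory, but no sample config ships with the package. A hypothetical minimal config, containing exactly the keys the command builder above reads (all values are placeholders), could be generated like this:

    import json

    # Hypothetical config for train.py; train_flux.py and train_lora.py below
    # read the same base keys from flux_config.json / training_config.json.
    config = {
        "num_cpu_threads_per_process": 2,
        "pretrained_model_name_or_path": "models/base-model.safetensors",  # placeholder
        "train_data_dir": "data/training_images_processed",
        "output_dir": "output",
        "output_name": "my_lora",
        "save_model_as": "safetensors",
        "learning_rate": 1e-4,
        "train_batch_size": 1,
        "epoch": 10,
        "save_every_n_epochs": 2,
        "mixed_precision": "bf16",
        # FLUX-only keys, consumed when flux1_checkbox is true:
        "flux1_checkbox": False,
        "flux1_t5xxl": "models/t5xxl_fp16.safetensors",  # placeholder path
        "flux1_clip_l": "models/clip_l.safetensors",     # placeholder path
    }

    with open("training_config.json", "w") as f:
        json.dump(config, f, indent=2)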

isa_model/training/image_model/train/train_flux.py
@@ -0,0 +1,41 @@
+ import json
+ import subprocess
+ from pathlib import Path
+
+ def train_flux():
+     # Load your config
+     with open('flux_config.json', 'r') as f:
+         config = json.load(f)
+
+     # Construct the training command for Flux finetuning
+     cmd = [
+         "accelerate", "launch",
+         "--num_cpu_threads_per_process", str(config["num_cpu_threads_per_process"]),
+         "train_db.py",
+         "--pretrained_model_name_or_path", config["pretrained_model_name_or_path"],
+         "--train_data_dir", config["train_data_dir"],
+         "--output_dir", config["output_dir"],
+         "--output_name", config["output_name"],
+         "--train_batch_size", str(config["train_batch_size"]),
+         "--save_every_n_epochs", str(config["save_every_n_epochs"]),
+         "--learning_rate", str(config["learning_rate"]),
+         "--max_train_epochs", str(config["epoch"]),
+         "--mixed_precision", config["mixed_precision"],
+         "--save_model_as", config["save_model_as"],
+         "--cache_latents",
+         "--cache_latents_to_disk",
+         "--gradient_checkpointing",
+         "--optimizer_type", "Adafactor",
+         "--optimizer_args", "scale_parameter=False relative_step=False warmup_init=False weight_decay=0.01",
+         "--max_resolution", "1024,1024",
+         "--full_bf16",
+         "--flux1_checkbox",
+         "--flux1_t5xxl", config["flux1_t5xxl"],
+         "--flux1_clip_l", config["flux1_clip_l"],
+         "--flux1_cache_text_encoder_outputs",
+         "--flux1_cache_text_encoder_outputs_to_disk",
+         "--flux_fused_backward_pass"
+     ]
+
+     # Execute the training
+     subprocess.run(cmd, check=True)

isa_model/training/image_model/train/train_lora.py
@@ -0,0 +1,57 @@
+ import json
+ import subprocess
+ from pathlib import Path
+
+ def train_lora():
+     # Load your config
+     with open('training_config.json', 'r') as f:
+         config = json.load(f)
+
+     # Construct the training command for LoRA
+     cmd = [
+         "accelerate", "launch",
+         "--num_cpu_threads_per_process", str(config["num_cpu_threads_per_process"]),
+         "sdxl_train_network.py",  # Use the SDXL LoRA training script
+         "--network_module", "networks.lora",  # Specify LoRA network
+         "--pretrained_model_name_or_path", config["pretrained_model_name_or_path"],
+         "--train_data_dir", config["train_data_dir"],
+         "--output_dir", config["output_dir"],
+         "--output_name", config["output_name"],
+         "--save_model_as", config["save_model_as"],
+         "--network_alpha", "1",  # LoRA alpha parameter
+         "--network_dim", "32",  # LoRA dimension
+         "--learning_rate", str(config["learning_rate"]),
+         "--train_batch_size", str(config["train_batch_size"]),
+         "--max_train_epochs", str(config["epoch"]),
+         "--save_every_n_epochs", str(config["save_every_n_epochs"]),
+         "--mixed_precision", config["mixed_precision"],
+         "--cache_latents",
+         "--gradient_checkpointing",
+         "--network_args", "conv_dim=32", "conv_alpha=1",  # LoRA network arguments
+         "--noise_offset", "0.1",
+         "--adaptive_noise_scale", "0.01",
+         "--max_resolution", "1024,1024",
+         "--min_bucket_reso", "256",
+         "--max_bucket_reso", "1024",
+         "--xformers",
+         "--bucket_reso_steps", "64",
+         "--caption_extension", ".txt",
+         "--optimizer_type", "AdaFactor",
+         "--optimizer_args", "scale_parameter=False", "relative_step=False", "warmup_init=False",
+         "--lr_scheduler", "constant"
+     ]
+
+     # Add FLUX specific parameters for LoRA
+     if config.get("flux1_checkbox"):
+         cmd.extend([
+             "--flux1_t5xxl", config["flux1_t5xxl"],
+             "--flux1_clip_l", config["flux1_clip_l"],
+             "--flux1_cache_text_encoder_outputs",
+             "--flux1_cache_text_encoder_outputs_to_disk"
+         ])
+
+     # Execute the training
+     subprocess.run(cmd, check=True)
+
+ if __name__ == "__main__":
+     train_lora()

isa_model/training/image_model/train_main.py
@@ -0,0 +1,25 @@
+ import os
+ from pathlib import Path
+ import shutil
+ from isa_model.training.image_model.raw_data.create_lora_captions import create_lora_captions
+ from isa_model.training.image_model.train.train_flux import train_flux
+
+ def main():
+     # Setup paths
+     project_root = Path(__file__).parent
+     processed_images_dir = project_root / "raw_data/training_images_processed"
+
+     # 1. Generate captions for all processed images
+     print("Creating captions for processed images...")
+     create_lora_captions(processed_images_dir)
+
+     # 2. Create Flux config
+     print("Creating Flux configuration...")
+     os.system(f"python {project_root}/configs/create_flux_config.py")
+
+     # 3. Run Flux training
+     print("Starting Flux training...")
+     train_flux()
+
+ if __name__ == "__main__":
+     main()

isa_model/training/llm_model/annotation/annotation_schema.py
@@ -0,0 +1,47 @@
+ # isa_model/training/llm_model/annotation/annotation_schema.py
+ from enum import Enum
+ from pydantic import BaseModel, Field
+ from typing import Dict, Any, List, Optional
+ from datetime import datetime
+
+ class AnnotationType(str, Enum):
+     ACCURACY = "accuracy"
+     HELPFULNESS = "helpfulness"
+     TOXICITY = "toxicity"
+     CUSTOM = "custom"
+
+ class RatingScale(int, Enum):
+     POOR = 1
+     FAIR = 2
+     GOOD = 3
+     EXCELLENT = 4
+
+ class AnnotationAspects(BaseModel):
+     factually_correct: bool = True
+     relevant: bool = True
+     harmful: bool = False
+     biased: bool = False
+     complete: bool = True
+     efficient: bool = True
+
+ class BetterResponse(BaseModel):
+     content: str
+     reason: Optional[str]
+     metadata: Optional[Dict[str, Any]] = {}
+
+ class AnnotationFeedback(BaseModel):
+     rating: RatingScale
+     category: AnnotationType
+     aspects: AnnotationAspects
+     better_response: Optional[BetterResponse]
+     comment: Optional[str]
+     metadata: Optional[Dict[str, Any]] = {}
+     is_selected_for_training: bool = False
+
+ class ItemAnnotation(BaseModel):
+     item_id: str
+     feedback: Optional[AnnotationFeedback]
+     status: str = "pending"
+     annotated_at: Optional[datetime]
+     annotator_id: Optional[str]
+     training_status: Optional[str] = None
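
Editor's note: to make the schema concrete, here is a hypothetical feedback object for a response that an annotator corrected. It exercises most fields, including the preference data that later feeds the RLHF pipeline; all values are invented:

    # Illustrative only; values are made up.
    feedback = AnnotationFeedback(
        rating=RatingScale.FAIR,
        category=AnnotationType.ACCURACY,
        aspects=AnnotationAspects(factually_correct=False),  # other aspects keep their defaults
        better_response=BetterResponse(
            content="The corrected answer goes here.",
            reason="The original response stated the wrong release year.",
        ),
        comment="Factual error; otherwise well structured.",
        is_selected_for_training=True,
    )

    print(feedback.json())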

isa_model/training/llm_model/annotation/processors/annotation_processor.py
@@ -0,0 +1,126 @@
+ from typing import Dict, Any, List
+ from datetime import datetime
+ from app.config.config_manager import config_manager
+ from isa_model.training.llm_model.annotation.annotation_schema import AnnotationFeedback, RatingScale, AnnotationAspects
+ from bson.objectid import ObjectId
+ from isa_model.training.llm_model.annotation.storage.dataset_manager import DatasetManager
+
+ class AnnotationProcessor:
+     def __init__(self):
+         self.logger = config_manager.get_logger(__name__)
+         self.dataset_manager = DatasetManager()
+         self.batch_size = 1000  # Configure as needed
+
+     async def process_queue(self) -> None:
+         """Process pending items and create datasets"""
+         db = await config_manager.get_db('mongodb')
+         queue = db['training_queue']
+
+         # Process SFT items
+         sft_items = await self._get_pending_items("sft")
+         if len(sft_items) >= self.batch_size:
+             await self._create_sft_dataset(sft_items)
+
+         # Process RLHF items
+         rlhf_items = await self._get_pending_items("rlhf")
+         if len(rlhf_items) >= self.batch_size:
+             await self._create_rlhf_dataset(rlhf_items)
+
+     async def _create_sft_dataset(self, items: List[Dict[str, Any]]):
+         """Create and upload SFT dataset"""
+         dataset = await self.dataset_manager.create_dataset(
+             name=f"sft_dataset_v{datetime.now().strftime('%Y%m%d')}",
+             type="sft",
+             version=datetime.now().strftime("%Y%m%d"),
+             source_annotations=[item["annotation_id"] for item in items]
+         )
+
+         formatted_data = [
+             await self._process_sft_item(item)
+             for item in items
+         ]
+
+         await self.dataset_manager.upload_dataset_file(
+             dataset.id,
+             formatted_data
+         )
+
+     async def _process_sft_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+         """Process item for SFT dataset generation
+         Format follows HF conversation format for SFT training
+         """
+         db = await config_manager.get_db('mongodb')
+         annotations = db['annotations']
+
+         # Get full annotation context
+         annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
+         target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
+
+         # Format as conversation
+         messages = [
+             {
+                 "role": "system",
+                 "content": "You are a helpful AI assistant that provides accurate and relevant information."
+             },
+             {
+                 "role": "user",
+                 "content": target_item["input"]["messages"][0]["content"]
+             },
+             {
+                 "role": "assistant",
+                 "content": target_item["output"]["content"]
+             }
+         ]
+
+         return {
+             "messages": messages,
+             "metadata": {
+                 "rating": item["feedback"]["rating"],
+                 "aspects": item["feedback"]["aspects"],
+                 "category": item["feedback"]["category"]
+             }
+         }
+
+     async def _process_rlhf_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
+         """Process item for RLHF dataset generation
+         Format follows preference pairs structure for RLHF training
+         """
+         db = await config_manager.get_db('mongodb')
+         annotations = db['annotations']
+
+         # Get full annotation context
+         annotation = await annotations.find_one({"_id": ObjectId(item["annotation_id"])})
+         target_item = next(i for i in annotation["items"] if i["item_id"] == item["item_id"])
+
+         # Format as preference pairs
+         return {
+             "prompt": target_item["input"]["messages"][0]["content"],
+             "chosen": item["feedback"]["better_response"]["content"],
+             "rejected": target_item["output"]["content"],
+             "metadata": {
+                 "reason": item["feedback"]["better_response"]["reason"],
+                 "category": item["feedback"]["category"]
+             }
+         }
+
+     async def get_training_data(
+         self,
+         data_type: str,
+         limit: int = 1000
+     ) -> List[Dict[str, Any]]:
+         """Retrieve formatted training data"""
+         db = await config_manager.get_db('mongodb')
+         training_data = db['training_data']
+
+         data = await training_data.find(
+             {"type": data_type}
+         ).limit(limit).to_list(length=limit)
+
+         if data_type == "sft":
+             return [item["data"]["messages"] for item in data]
+         else:  # rlhf
+             return [{
+                 "prompt": item["data"]["prompt"],
+                 "chosen": item["data"]["chosen"],
+                 "rejected": item["data"]["rejected"]
+             } for item in data]
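
Editor's note: `process_queue` calls `_get_pending_items` and `_create_rlhf_dataset`, neither of which is defined in this file, so the queue path cannot run as released. For reference, the two record shapes the formatters above emit look like this (contents invented for illustration):

    # One SFT record: a system/user/assistant conversation plus annotation metadata.
    sft_record = {
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant that provides accurate and relevant information."},
            {"role": "user", "content": "What does isa_model do?"},
            {"role": "assistant", "content": "It wraps model inference, training, and deployment utilities."},
        ],
        "metadata": {"rating": 4, "aspects": {"factually_correct": True}, "category": "helpfulness"},
    }

    # One RLHF record: a preference pair where the annotator's better_response is "chosen"
    # and the model's original output is "rejected".
    rlhf_record = {
        "prompt": "What does isa_model do?",
        "chosen": "It wraps model inference, training, and deployment utilities.",
        "rejected": "It is a web framework.",
        "metadata": {"reason": "Original answer was off-topic.", "category": "accuracy"},
    }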

isa_model/training/llm_model/annotation/storage/dataset_manager.py
@@ -0,0 +1,131 @@
+ # isa_model/training/llm_model/annotation/storage/dataset_manager.py
+ from typing import Dict, Any, List, Optional
+ from datetime import datetime
+ import json
+ import io
+ from app.config.config_manager import config_manager
+ from .dataset_schema import Dataset, DatasetType, DatasetStatus, DatasetFiles, DatasetStats
+ from bson import ObjectId
+
+ class DatasetManager:
+     def __init__(self):
+         self.logger = config_manager.get_logger(__name__)
+         self.minio_client = None
+         self.bucket_name = "training-datasets"
+
+     async def _ensure_minio_client(self):
+         if not self.minio_client:
+             self.minio_client = await config_manager.get_storage_client()
+
+     async def create_dataset(
+         self,
+         name: str,
+         type: DatasetType,
+         version: str,
+         source_annotations: List[str]
+     ) -> Dataset:
+         """Create a new dataset record"""
+         db = await config_manager.get_db('mongodb')
+         collection = db['training_datasets']
+
+         dataset = Dataset(
+             name=name,
+             type=type,
+             version=version,
+             storage_path=f"datasets/{type.value}/{version}",
+             files=DatasetFiles(
+                 train="train.jsonl",
+                 eval=None,
+                 test=None
+             ),
+             stats=DatasetStats(
+                 total_examples=0,
+                 avg_length=0.0,
+                 num_conversations=0,
+                 additional_metrics={}
+             ),
+             source_annotations=source_annotations,
+             created_at=datetime.utcnow(),
+             status=DatasetStatus.PENDING,
+             metadata={}
+         )
+
+         result = await collection.insert_one(dataset.dict(exclude={'id'}))
+         return Dataset(**{**dataset.dict(), '_id': result.inserted_id})
+
+     async def upload_dataset_file(
+         self,
+         dataset_id: str,
+         data: List[Dict[str, Any]],
+         file_type: str = "train"
+     ) -> bool:
+         """Upload dataset to MinIO"""
+         try:
+             await self._ensure_minio_client()
+             db = await config_manager.get_db('mongodb')
+
+             object_id = ObjectId(dataset_id)
+             dataset = await db['training_datasets'].find_one({"_id": object_id})
+
+             if not dataset:
+                 self.logger.error(f"Dataset not found with id: {dataset_id}")
+                 return False
+
+             # Convert to JSONL
+             buffer = io.StringIO()
+             for item in data:
+                 buffer.write(json.dumps(item) + "\n")
+
+             storage_path = dataset['storage_path'].rstrip('/')
+             file_path = f"{storage_path}/{file_type}.jsonl"
+
+             buffer_value = buffer.getvalue().encode()
+
+             self.logger.debug(f"Uploading to MinIO path: {file_path}")
+
+             self.minio_client.put_object(
+                 self.bucket_name,
+                 file_path,
+                 io.BytesIO(buffer_value),
+                 len(buffer_value)
+             )
+
+             avg_length = sum(len(str(item)) for item in data) / len(data) if data else 0
+
+             await db['training_datasets'].update_one(
+                 {"_id": object_id},
+                 {
+                     "$set": {
+                         f"files.{file_type}": f"{file_type}.jsonl",
+                         "stats.total_examples": len(data),
+                         "stats.avg_length": avg_length,
+                         "stats.num_conversations": len(data),
+                         "status": DatasetStatus.READY
+                     }
+                 }
+             )
+
+             return True
+
+         except Exception as e:
+             self.logger.error(f"Failed to upload dataset: {e}")
+             return False
+
+     async def get_dataset_info(self, dataset_id: str) -> Optional[Dict[str, Any]]:
+         """Get dataset information"""
+         try:
+             db = await config_manager.get_db('mongodb')
+             object_id = ObjectId(dataset_id)  # Convert string ID to ObjectId
+             dataset = await db['training_datasets'].find_one({"_id": object_id})
+
+             if not dataset:
+                 self.logger.error(f"Dataset not found with id: {dataset_id}")
+                 return None
+
+             # Convert ObjectId to string for JSON serialization
+             dataset['_id'] = str(dataset['_id'])
+             return dataset
+
+         except Exception as e:
+             self.logger.error(f"Failed to get dataset info: {e}")
+             return None
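
Editor's note: a minimal usage sketch, assuming `config_manager` is wired to MongoDB and a MinIO client, and assuming `dataset_schema.py` (not shown in this diff) exposes an `SFT` member on `DatasetType` and an `id` field on `Dataset`; IDs and records below are illustrative:

    import asyncio

    from isa_model.training.llm_model.annotation.storage.dataset_manager import DatasetManager
    from isa_model.training.llm_model.annotation.storage.dataset_schema import DatasetType  # assumed export

    async def demo():
        manager = DatasetManager()

        # Create the metadata record first...
        dataset = await manager.create_dataset(
            name="sft_dataset_v20250101",
            type=DatasetType.SFT,  # assumed enum member; see dataset_schema.py
            version="20250101",
            source_annotations=["665f1c2e9b1e8a23d4a1b0c7"],  # illustrative annotation id
        )

        # ...then upload the JSONL payload, which also flips the status to READY.
        ok = await manager.upload_dataset_file(
            str(dataset.id),  # assumes Dataset exposes the Mongo _id as `id`
            [{"messages": [{"role": "user", "content": "hi"}]}],
            file_type="train",
        )
        print("uploaded:", ok)

    asyncio.run(demo())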