simpletuner 2.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simpletuner/__init__.py +3 -0
- simpletuner/cli.py +626 -0
- simpletuner/configure.py +5649 -0
- simpletuner/examples/README.md +53 -0
- simpletuner/examples/auraflow.peft-controlnet-lora/config.json +46 -0
- simpletuner/examples/auraflow.peft-lora/config.json +55 -0
- simpletuner/examples/bria.lycoris-lokr/config.json +58 -0
- simpletuner/examples/bria.lycoris-lokr/lycoris_config.json +20 -0
- simpletuner/examples/cosmos2image.lycoris-lokr/config.json +43 -0
- simpletuner/examples/cosmos2image.lycoris-lokr/lycoris_config.json +20 -0
- simpletuner/examples/flux.peft-controlnet-lora/config.json +47 -0
- simpletuner/examples/flux.peft-lora/config.json +46 -0
- simpletuner/examples/flux.peft-lora+TREAD/config.json +55 -0
- simpletuner/examples/hidream.peft-controlnet-lora/config.env +1 -0
- simpletuner/examples/hidream.peft-controlnet-lora/config.json +53 -0
- simpletuner/examples/hidream.peft-lora/config.env +1 -0
- simpletuner/examples/hidream.peft-lora/config.json +70 -0
- simpletuner/examples/kontext.peft-lora/config.env +1 -0
- simpletuner/examples/kontext.peft-lora/config.json +50 -0
- simpletuner/examples/lumina2.peft-lora/config.json +41 -0
- simpletuner/examples/omnigen.lycoris-lokr/config.env +4 -0
- simpletuner/examples/omnigen.lycoris-lokr/config.json +61 -0
- simpletuner/examples/omnigen.lycoris-lokr/lycoris_config.json +22 -0
- simpletuner/examples/pixart.lycoris-lokr/config.json +44 -0
- simpletuner/examples/pixart.lycoris-lokr/lycoris_config.json +22 -0
- simpletuner/examples/pixart.peft-controlnet-lora/config.json +45 -0
- simpletuner/examples/qwen_image.peft-lora/config.json +57 -0
- simpletuner/examples/sana.lycoris-lokr/config.json +63 -0
- simpletuner/examples/sana.lycoris-lokr/lycoris_config.json +22 -0
- simpletuner/examples/sd3.peft-controlnet-lora/config.json +46 -0
- simpletuner/examples/sd3.peft-lora/config.json +48 -0
- simpletuner/examples/sdxl.lycoris-lokr/config.json +39 -0
- simpletuner/examples/sdxl.lycoris-lokr/lycoris_config.json +18 -0
- simpletuner/examples/sdxl.peft-controlnet-lora/config.json +42 -0
- simpletuner/examples/wan-1.3b.peft-lora+TREAD/config.json +69 -0
- simpletuner/helpers/caching/memory.py +13 -0
- simpletuner/helpers/caching/text_embeds.py +512 -0
- simpletuner/helpers/caching/vae.py +1391 -0
- simpletuner/helpers/configuration/cmd_args.py +2869 -0
- simpletuner/helpers/configuration/env_file.py +195 -0
- simpletuner/helpers/configuration/json_file.py +68 -0
- simpletuner/helpers/configuration/loader.py +65 -0
- simpletuner/helpers/configuration/toml_file.py +76 -0
- simpletuner/helpers/data_backend/aws.py +451 -0
- simpletuner/helpers/data_backend/base.py +151 -0
- simpletuner/helpers/data_backend/csv_url_list.py +338 -0
- simpletuner/helpers/data_backend/factory.py +2221 -0
- simpletuner/helpers/data_backend/huggingface.py +651 -0
- simpletuner/helpers/data_backend/local.py +347 -0
- simpletuner/helpers/data_generation/__init__.py +2 -0
- simpletuner/helpers/data_generation/conditioning.py +715 -0
- simpletuner/helpers/data_generation/sample_generator.py +1186 -0
- simpletuner/helpers/distillation/common.py +153 -0
- simpletuner/helpers/distillation/dcm/discriminator/hunyuan_video.py +171 -0
- simpletuner/helpers/distillation/dcm/discriminator/wan.py +502 -0
- simpletuner/helpers/distillation/dcm/distiller.py +393 -0
- simpletuner/helpers/distillation/dcm/loss.py +98 -0
- simpletuner/helpers/distillation/dcm/solver.py +679 -0
- simpletuner/helpers/distillation/dmd/distiller.py +377 -0
- simpletuner/helpers/distillation/factory.py +339 -0
- simpletuner/helpers/distillation/lcm/__init__.py +418 -0
- simpletuner/helpers/image_manipulation/batched_training_samples.py +398 -0
- simpletuner/helpers/image_manipulation/brightness.py +35 -0
- simpletuner/helpers/image_manipulation/cropping.py +309 -0
- simpletuner/helpers/image_manipulation/load.py +216 -0
- simpletuner/helpers/image_manipulation/training_sample.py +1071 -0
- simpletuner/helpers/legacy/pipeline.py +1202 -0
- simpletuner/helpers/log_format.py +133 -0
- simpletuner/helpers/metadata/__init__.py +0 -0
- simpletuner/helpers/metadata/backends/base.py +1165 -0
- simpletuner/helpers/metadata/backends/discovery.py +299 -0
- simpletuner/helpers/metadata/backends/huggingface.py +825 -0
- simpletuner/helpers/metadata/backends/parquet.py +660 -0
- simpletuner/helpers/metadata/utils/__init__.py +2 -0
- simpletuner/helpers/metadata/utils/duplicator.py +209 -0
- simpletuner/helpers/models/all.py +58 -0
- simpletuner/helpers/models/auraflow/controlnet.py +567 -0
- simpletuner/helpers/models/auraflow/model.py +459 -0
- simpletuner/helpers/models/auraflow/pipeline.py +1570 -0
- simpletuner/helpers/models/auraflow/pipeline_controlnet.py +1157 -0
- simpletuner/helpers/models/auraflow/transformer.py +782 -0
- simpletuner/helpers/models/common.py +1800 -0
- simpletuner/helpers/models/cosmos/model.py +335 -0
- simpletuner/helpers/models/cosmos/pipeline.py +708 -0
- simpletuner/helpers/models/cosmos/transformer.py +703 -0
- simpletuner/helpers/models/deepfloyd/model.py +219 -0
- simpletuner/helpers/models/flux/__init__.py +188 -0
- simpletuner/helpers/models/flux/attention.py +639 -0
- simpletuner/helpers/models/flux/model.py +1083 -0
- simpletuner/helpers/models/flux/pipeline.py +2761 -0
- simpletuner/helpers/models/flux/pipeline_controlnet.py +2117 -0
- simpletuner/helpers/models/flux/transformer.py +1192 -0
- simpletuner/helpers/models/hidream/controlnet.py +1699 -0
- simpletuner/helpers/models/hidream/model.py +697 -0
- simpletuner/helpers/models/hidream/pipeline.py +1933 -0
- simpletuner/helpers/models/hidream/schedule.py +802 -0
- simpletuner/helpers/models/hidream/transformer.py +1618 -0
- simpletuner/helpers/models/kolors/controlnet.py +1009 -0
- simpletuner/helpers/models/kolors/model.py +261 -0
- simpletuner/helpers/models/kolors/pipeline.py +2211 -0
- simpletuner/helpers/models/kolors/pipeline_controlnet.py +1555 -0
- simpletuner/helpers/models/ltxvideo/__init__.py +211 -0
- simpletuner/helpers/models/ltxvideo/model.py +309 -0
- simpletuner/helpers/models/lumina2/model.py +303 -0
- simpletuner/helpers/models/omnigen/collator.py +61 -0
- simpletuner/helpers/models/omnigen/model.py +173 -0
- simpletuner/helpers/models/pixart/controlnet.py +345 -0
- simpletuner/helpers/models/pixart/model.py +451 -0
- simpletuner/helpers/models/pixart/model_card_templates.py +50 -0
- simpletuner/helpers/models/pixart/pipeline.py +3116 -0
- simpletuner/helpers/models/pixart/transformer.py +477 -0
- simpletuner/helpers/models/qwen_image/model.py +351 -0
- simpletuner/helpers/models/sana/__init__.py +0 -0
- simpletuner/helpers/models/sana/model.py +195 -0
- simpletuner/helpers/models/sana/pipeline.py +964 -0
- simpletuner/helpers/models/sana/transformer.py +545 -0
- simpletuner/helpers/models/sd1x/model.py +309 -0
- simpletuner/helpers/models/sd1x/pipeline.py +4146 -0
- simpletuner/helpers/models/sd3/__init__.py +0 -0
- simpletuner/helpers/models/sd3/controlnet.py +1410 -0
- simpletuner/helpers/models/sd3/expanded.py +740 -0
- simpletuner/helpers/models/sd3/model.py +593 -0
- simpletuner/helpers/models/sd3/pipeline.py +3023 -0
- simpletuner/helpers/models/sd3/transformer.py +478 -0
- simpletuner/helpers/models/sdxl/controlnet.py +29 -0
- simpletuner/helpers/models/sdxl/model.py +347 -0
- simpletuner/helpers/models/sdxl/pipeline.py +5138 -0
- simpletuner/helpers/models/wan/__init__.py +94 -0
- simpletuner/helpers/models/wan/model.py +303 -0
- simpletuner/helpers/models/wan/pipeline.py +653 -0
- simpletuner/helpers/models/wan/transformer.py +698 -0
- simpletuner/helpers/multiaspect/dataset.py +90 -0
- simpletuner/helpers/multiaspect/image.py +327 -0
- simpletuner/helpers/multiaspect/sampler.py +889 -0
- simpletuner/helpers/multiaspect/state.py +67 -0
- simpletuner/helpers/multiaspect/video.py +25 -0
- simpletuner/helpers/prompt_expander/__init__.py +276 -0
- simpletuner/helpers/prompts.py +678 -0
- simpletuner/helpers/publishing/huggingface.py +279 -0
- simpletuner/helpers/publishing/metadata.py +593 -0
- simpletuner/helpers/training/__init__.py +181 -0
- simpletuner/helpers/training/adapter.py +142 -0
- simpletuner/helpers/training/collate.py +661 -0
- simpletuner/helpers/training/custom_schedule.py +1606 -0
- simpletuner/helpers/training/deepspeed.py +134 -0
- simpletuner/helpers/training/default_settings/__init__.py +15 -0
- simpletuner/helpers/training/default_settings/safety_check.py +163 -0
- simpletuner/helpers/training/diffusers_overrides.py +406 -0
- simpletuner/helpers/training/diffusion_model.py +4 -0
- simpletuner/helpers/training/ema.py +473 -0
- simpletuner/helpers/training/error_handling.py +31 -0
- simpletuner/helpers/training/evaluation.py +64 -0
- simpletuner/helpers/training/exceptions.py +2 -0
- simpletuner/helpers/training/gradient_checkpointing_interval.py +42 -0
- simpletuner/helpers/training/min_snr_gamma.py +47 -0
- simpletuner/helpers/training/model_freeze.py +180 -0
- simpletuner/helpers/training/multi_process.py +20 -0
- simpletuner/helpers/training/optimizer_param.py +827 -0
- simpletuner/helpers/training/optimizers/adamw_bfloat16/__init__.py +167 -0
- simpletuner/helpers/training/optimizers/adamw_bfloat16/stochastic/__init__.py +124 -0
- simpletuner/helpers/training/optimizers/adamw_schedulefree/__init__.py +149 -0
- simpletuner/helpers/training/optimizers/soap/__init__.py +479 -0
- simpletuner/helpers/training/peft_init.py +25 -0
- simpletuner/helpers/training/quantisation/__init__.py +371 -0
- simpletuner/helpers/training/quantisation/peft_workarounds.py +423 -0
- simpletuner/helpers/training/quantisation/quanto_workarounds.py +114 -0
- simpletuner/helpers/training/quantisation/torchao_workarounds.py +146 -0
- simpletuner/helpers/training/save_hooks.py +408 -0
- simpletuner/helpers/training/state_tracker.py +657 -0
- simpletuner/helpers/training/trainer.py +2718 -0
- simpletuner/helpers/training/tread.py +162 -0
- simpletuner/helpers/training/validation.py +2635 -0
- simpletuner/helpers/training/wrappers.py +32 -0
- simpletuner/helpers/webhooks/config.py +51 -0
- simpletuner/helpers/webhooks/handler.py +247 -0
- simpletuner/helpers/webhooks/mixin.py +31 -0
- simpletuner/inference.py +160 -0
- simpletuner/inference_comparison.py +163 -0
- simpletuner/service_worker.py +110 -0
- simpletuner/simpletuner_sdk/api_state.py +87 -0
- simpletuner/simpletuner_sdk/configuration.py +183 -0
- simpletuner/simpletuner_sdk/interface.py +409 -0
- simpletuner/simpletuner_sdk/thread_keeper/__init__.py +67 -0
- simpletuner/simpletuner_sdk/training_host.py +75 -0
- simpletuner/train.py +92 -0
- simpletuner-2.2.0.dist-info/METADATA +274 -0
- simpletuner-2.2.0.dist-info/RECORD +191 -0
- simpletuner-2.2.0.dist-info/WHEEL +5 -0
- simpletuner-2.2.0.dist-info/entry_points.txt +5 -0
- simpletuner-2.2.0.dist-info/licenses/LICENSE +208 -0
- simpletuner-2.2.0.dist-info/top_level.txt +1 -0
simpletuner/__init__.py
ADDED
simpletuner/cli.py
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
SimpleTuner CLI - Command-line interface for SimpleTuner
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import subprocess
|
|
11
|
+
import shutil
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import List, Optional
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def find_config_file() -> Optional[str]:
    """Locate the first training config file, searching two places.

    The current working directory is searched first, then a ``config/``
    subdirectory. Within each place the candidates are tried in the order
    ``config.json``, ``config.toml``, ``config.env``.

    Returns:
        The first match as a string (a bare file name when found in the
        working directory, ``config/<name>`` when found in the subdirectory),
        or ``None`` when no candidate exists.
    """
    candidates = ("config.json", "config.toml", "config.env")

    # Matches in the working directory are returned exactly as checked.
    top_level_hit = next((name for name in candidates if os.path.exists(name)), None)
    if top_level_hit is not None:
        return top_level_hit

    subdir = Path("config")
    if not subdir.exists():
        return None

    for name in candidates:
        candidate = subdir / name
        if candidate.exists():
            return str(candidate)
    return None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_examples_dir() -> Path:
    """Return the ``examples`` directory bundled inside the installed package.

    The location is derived from the ``simpletuner`` package's ``__file__``,
    so it follows wherever the package is installed. Existence of the
    directory is not checked here.
    """
    # Deferred import: resolve the package only when the path is needed.
    import simpletuner as package

    package_root = Path(package.__file__).parent
    return package_root.joinpath("examples")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def list_examples() -> List[str]:
    """Return the sorted names of the example directories shipped with the package.

    Only directories are listed; plain files and hidden entries (names that
    start with ``.``) are skipped. An empty list is returned when the
    examples directory does not exist.
    """
    root = get_examples_dir()
    if not root.exists():
        return []

    names = [
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    ]
    names.sort()
    return names
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def find_referenced_files(config_path: Path) -> List[str]:
    """Collect the file names of example JSON files referenced by a config.

    Only a fixed set of top-level keys is inspected. A value counts as a
    reference when it is a string that contains ``examples/`` and ends in
    ``.json``; only its final path component (the bare file name) is kept.

    Args:
        config_path: Path to a JSON config file.

    Returns:
        The referenced file names in first-seen order, without duplicates.
        An empty list is returned when the file cannot be read or parsed
        (a warning is printed) or when it does not contain a JSON object.
    """
    fields_to_check = (
        "data_backend_config",
        "validation_prompt_library",
        "controlnet_config",
        "reference_config",
        "lycoris_config",
    )

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers a missing file as well as permission problems.
        print(f"Warning: Could not parse config file {config_path}: {e}")
        return []

    if not isinstance(config, dict):
        # A top-level JSON array or scalar has no keys to inspect; indexing it
        # by field name would raise TypeError.
        return []

    referenced_files: List[str] = []
    for field in fields_to_check:
        value = config.get(field)
        # Check whether the value references a file in an examples directory.
        if isinstance(value, str) and "examples/" in value and value.endswith(".json"):
            # Keep just the file name from paths like "config/examples/file.json".
            filename = Path(value).name
            # Several keys may point at the same file; report (and copy) it once.
            if filename not in referenced_files:
                referenced_files.append(filename)

    return referenced_files
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def copy_example(example_name: str, dest: Optional[str] = None) -> bool:
    """Copy a bundled example into a destination directory.

    A directory example is copied to ``<dest>/<example_name>``. A file
    example is copied to ``<dest>/<file name>``. Existing targets are never
    overwritten.

    For directory examples, JSON files that the copied ``config.json``
    references under an ``examples/`` path are copied in from the top level
    of the bundled examples directory. The config is then rewritten to point
    at the local copies. Only references whose file actually exists in the
    copied example afterwards are rewritten, so the config never ends up
    pointing at a file that is not there.

    Args:
        example_name: Name of an entry inside the bundled examples directory.
        dest: Destination directory; defaults to the current directory.

    Returns:
        ``True`` on success. ``False`` (after printing an error) if the
        example does not exist, the target already exists, or copying fails.
    """
    examples_dir = get_examples_dir()
    example_path = examples_dir / example_name

    if not example_path.exists():
        print(f"Error: Example '{example_name}' not found.")
        print(f"Available examples: {', '.join(list_examples())}")
        return False

    # Determine destination
    dest_path = Path("." if dest is None else dest)

    try:
        if example_path.is_dir():
            # Copy directory
            dest_example = dest_path / example_name
            if dest_example.exists():
                print(f"Error: Destination '{dest_example}' already exists.")
                return False
            shutil.copytree(example_path, dest_example)
            print(f"Copied example directory '{example_name}' to '{dest_example}'")

            # Check for referenced files in config.json
            config_json = dest_example / "config.json"
            if config_json.exists():
                referenced_files = find_referenced_files(config_json)
                if referenced_files:
                    print(
                        f"Found {len(referenced_files)} referenced file(s) to copy..."
                    )

                    available_files = []
                    for ref_file in referenced_files:
                        source_file = examples_dir / ref_file
                        dest_file = dest_example / ref_file
                        if source_file.exists():
                            shutil.copy2(source_file, dest_file)
                            print(f" Copied referenced file: {ref_file}")
                            available_files.append(ref_file)
                        elif dest_file.exists():
                            # The file came along with the copytree above, so a
                            # local reference to it is already valid.
                            available_files.append(ref_file)
                        else:
                            print(f" Warning: Referenced file not found: {ref_file}")

                    # Update config.json to use local paths, but only for files
                    # that are actually present; repointing a missing file would
                    # leave a dangling reference.
                    if available_files:
                        update_config_paths(config_json, available_files)

        else:
            # Copy file
            dest_file = dest_path / example_path.name
            if dest_file.exists():
                print(f"Error: Destination '{dest_file}' already exists.")
                return False
            shutil.copy2(example_path, dest_file)
            print(f"Copied example file '{example_name}' to '{dest_file}'")

        return True
    except Exception as e:
        # Top-level CLI boundary: report any copy failure instead of a traceback.
        print(f"Error copying example: {e}")
        return False
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def update_config_paths(config_path: Path, referenced_files: List[str]) -> None:
    """Rewrite config entries that point into ``examples/`` to bare file names.

    For each inspected key whose string value contains ``examples/`` and
    whose final path component is in *referenced_files*, the value is
    replaced by that bare file name. The file is rewritten (4-space indented
    JSON) only when at least one value changed. Read, parse and write
    problems, and a top level that is not a JSON object, are reported as
    warnings instead of being raised.

    Args:
        config_path: Path to the JSON config to update in place.
        referenced_files: File names that are available next to the config.
    """
    fields_to_check = (
        "data_backend_config",
        "validation_prompt_library",
        "controlnet_config",
        "reference_config",
        "lycoris_config",
    )
    wanted = set(referenced_files)

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        if not isinstance(config, dict):
            # Indexing a JSON array/string by field name would raise TypeError.
            print(
                f"Warning: Could not update config paths: "
                f"{config_path} does not contain a JSON object"
            )
            return

        updated = False
        for field in fields_to_check:
            value = config.get(field)
            if isinstance(value, str) and "examples/" in value:
                filename = Path(value).name
                if filename in wanted:
                    # Update to use local path
                    config[field] = filename
                    updated = True
                    print(f" Updated {field}: {value} -> {filename}")

        if updated:
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=4)
            print(" Updated config.json with local file paths")

    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not update config paths: {e}")
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def setup_environment_from_example(example_name: str) -> dict:
    """Build the environment for launching training with a bundled example.

    Starts from a copy of ``os.environ`` and sets ``ENV``, ``CONFIG_BACKEND``
    and ``CONFIG_PATH`` for the example, plus launch defaults: TQDM settings,
    ``TOKENIZERS_PARALLELISM``, ``PLATFORM``, process/machine counts, mixed
    precision and dynamo backend. Existing values for the defaults are kept,
    with two exceptions: ``TOKENIZERS_PARALLELISM`` is always set to
    ``false``, and ``MIXED_PRECISION`` is forced to ``no`` on macOS (Darwin).

    Args:
        example_name: Name of an entry in the bundled examples directory.
            This is either a directory containing ``config.json``,
            ``config.toml`` or ``config.env`` (checked in that order), or a
            standalone ``.json``/``.toml`` file.

    Returns:
        A new environment dict; ``os.environ`` itself is not modified.

    Raises:
        ValueError: If the example is missing, a directory example has no
            config file, or a standalone file has an unsupported suffix.
    """
    import platform as platform_module

    env = os.environ.copy()

    # Set ENV variable to point to the example
    env["ENV"] = f"examples/{example_name}"

    example_path = get_examples_dir() / example_name

    if example_path.is_dir():
        # First existing config wins; CONFIG_PATH is given without extension.
        for backend in ("json", "toml", "env"):
            if (example_path / f"config.{backend}").exists():
                env["CONFIG_BACKEND"] = backend
                env["CONFIG_PATH"] = str(example_path / "config")
                break
        else:
            raise ValueError(f"No config file found in example {example_name}")
    elif example_path.is_file():
        # Standalone config file
        if example_path.suffix not in (".json", ".toml"):
            raise ValueError(f"Unsupported config file type: {example_path.suffix}")
        env["CONFIG_BACKEND"] = example_path.suffix[1:]
        env["CONFIG_PATH"] = str(example_path.with_suffix(""))
    else:
        raise ValueError(f"Example {example_name} not found")

    # Set other default environment variables (similar to train.sh)
    env.setdefault("TQDM_NCOLS", "125")
    env.setdefault("TQDM_LEAVE", "false")
    env["TOKENIZERS_PARALLELISM"] = "false"

    # Platform-specific settings. platform.system() reports the same name as
    # os.uname().sysname on POSIX ("Darwin", "Linux"), but unlike os.uname()
    # it also exists on Windows.
    system_name = platform_module.system()
    env["PLATFORM"] = system_name
    if system_name == "Darwin":
        env["MIXED_PRECISION"] = "no"

    # Training defaults
    env.setdefault("TRAINING_NUM_PROCESSES", "1")
    env.setdefault("TRAINING_NUM_MACHINES", "1")
    env.setdefault("MIXED_PRECISION", "bf16")
    env.setdefault("TRAINING_DYNAMO_BACKEND", "no")

    return env
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def find_accelerate_config() -> Optional[str]:
    """Locate Accelerate's default config file (``default_config.yaml``).

    The following cache roots are checked in order:

    1. ``$HF_HOME`` (with ``~`` expanded), if set;
    2. ``$XDG_CACHE_HOME/huggingface``, if set. Accelerate itself uses this
       location when ``HF_HOME`` is unset.
    3. ``~/.cache/huggingface``.

    Returns:
        The path of the first ``<root>/accelerate/default_config.yaml`` that
        exists, as a string, or ``None`` if none of them exists.
    """
    roots = []

    hf_home = os.environ.get("HF_HOME", "")
    if hf_home:
        roots.append(Path(hf_home).expanduser())

    xdg_cache_home = os.environ.get("XDG_CACHE_HOME", "")
    if xdg_cache_home:
        roots.append(Path(xdg_cache_home).expanduser() / "huggingface")

    # Fallback to default cache location
    roots.append(Path.home() / ".cache" / "huggingface")

    for root in roots:
        config_path = root / "accelerate" / "default_config.yaml"
        if config_path.exists():
            return str(config_path)

    return None
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _config_env_from_file(config_file: str) -> dict:
    """Map a local config file to its ``CONFIG_BACKEND``/``CONFIG_PATH`` variables.

    ``CONFIG_PATH`` is the file path without its extension. A file that sits
    in the current directory is made absolute; other paths are kept as
    given. Suffixes other than ``.json``/``.toml``/``.env`` yield an empty
    dict.
    """
    config_path = Path(config_file)
    backend = {".json": "json", ".toml": "toml", ".env": "env"}.get(config_path.suffix)
    if backend is None:
        return {}

    stem_path = config_path.with_suffix("")
    if config_path.parent == Path("."):
        # If config is in current directory, use absolute path
        stem_path = Path.cwd() / stem_path
    return {"CONFIG_BACKEND": backend, "CONFIG_PATH": str(stem_path)}


def run_training(example: Optional[str] = None, env: Optional[str] = None) -> int:
    """Launch the packaged ``train.py`` through ``accelerate launch``.

    The configuration source is chosen in this order:

    * *example*: a bundled example, validated against ``list_examples()``.
      Its environment is built by ``setup_environment_from_example()``.
    * *env*: ``ENV`` is set to this value on top of the current environment.
    * neither: a local config file found by ``find_config_file()``
      determines ``CONFIG_BACKEND`` and ``CONFIG_PATH``.

    If an Accelerate config file is found, the launcher uses it. Otherwise
    flags are built from ``MIXED_PRECISION``, ``TRAINING_NUM_PROCESSES``,
    ``TRAINING_NUM_MACHINES`` and ``TRAINING_DYNAMO_BACKEND``.
    ``ACCELERATE_EXTRA_ARGS`` is split using shell quoting rules (as the
    shell would do in train.sh) and inserted before the script path.

    Returns:
        The training process's exit code, 130 if interrupted with Ctrl-C, or
        1 on any setup or launch error.
    """
    import shlex

    training_env = os.environ.copy()

    # Check for config file if no example/env specified
    if not example and not env:
        config_file = find_config_file()
        if not config_file:
            print(
                "Error: No config file found in current directory or config/ subdirectory."
            )
            print("Expected: config.json, config.toml, or config.env")
            print("Or use: simpletuner train example=<example_name>")
            return 1

        print(f"Using config file: {config_file}")
        training_env.update(_config_env_from_file(config_file))

    if example:
        # Validate example exists
        available_examples = list_examples()
        if example not in available_examples:
            print(f"Error: Example '{example}' not found.")
            print(f"Available examples: {', '.join(available_examples)}")
            return 1

        try:
            training_env = setup_environment_from_example(example)
        except ValueError as e:
            # e.g. an example directory without any config file
            print(f"Error: {e}")
            return 1
        print(f"Using example: {example}")
    elif env:
        training_env["ENV"] = env
        print(f"Using environment: {env}")

    # Find simpletuner train.py
    import simpletuner

    train_py = Path(simpletuner.__file__).parent / "train.py"
    if not train_py.exists():
        print(f"Error: train.py not found at {train_py}")
        return 1

    # Setup accelerate command (script path is appended last, below)
    accelerate_config = find_accelerate_config()
    if accelerate_config:
        print(f"Using Accelerate config file: {accelerate_config}")
        cmd = ["accelerate", "launch", f"--config_file={accelerate_config}"]
    else:
        print("Accelerate config file not found. Using environment variables.")
        cmd = [
            "accelerate",
            "launch",
            f"--mixed_precision={training_env.get('MIXED_PRECISION', 'bf16')}",
            f"--num_processes={training_env.get('TRAINING_NUM_PROCESSES', '1')}",
            f"--num_machines={training_env.get('TRAINING_NUM_MACHINES', '1')}",
            f"--dynamo_backend={training_env.get('TRAINING_DYNAMO_BACKEND', 'no')}",
        ]

    # Extra accelerate args go before the train.py script. shlex.split keeps
    # quoted values containing spaces together, unlike str.split.
    cmd += shlex.split(training_env.get("ACCELERATE_EXTRA_ARGS", ""))
    cmd.append(str(train_py))

    print(f"Running: {' '.join(cmd)}")

    # Run the training
    try:
        result = subprocess.run(cmd, env=training_env)
        return result.returncode
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
        return 130
    except FileNotFoundError as e:
        print(f"Error running training: {e}")
        print("Make sure Accelerate is installed and the 'accelerate' command is on PATH.")
        return 1
    except Exception as e:
        print(f"Error running training: {e}")
        return 1
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def cmd_train(args) -> int:
    """Handle the ``train`` command.

    The ``--example``/``--env`` flags take precedence. Positional
    ``example=<name>`` / ``env=<path>`` arguments only fill a value that is
    still unset. Repeated keys, unknown keys and arguments without ``=`` are
    reported as warnings and ignored.

    Returns:
        The exit code from ``run_training()``.
    """
    example = getattr(args, "example", None)
    env = getattr(args, "env", None)

    # Parse key=value arguments
    for arg in getattr(args, "args", None) or []:
        key, sep, value = arg.partition("=")
        if not sep:
            print(f"Warning: Invalid argument format '{arg}'. Expected key=value")
        elif key == "example":
            if example:
                print(f"Warning: Ignoring '{arg}'; example is already set to '{example}'")
            else:
                example = value
        elif key == "env":
            if env:
                print(f"Warning: Ignoring '{arg}'; env is already set to '{env}'")
            else:
                env = value
        else:
            print(f"Warning: Unknown argument '{key}={value}'")

    if example and env:
        # argparse keeps --example/--env mutually exclusive, but key=value
        # arguments bypass that check; run_training uses the example first.
        print(
            f"Warning: Both example and env were given; "
            f"using example '{example}' and ignoring env '{env}'"
        )

    return run_training(example=example, env=env)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def cmd_examples(args) -> int:
    """Handle the ``examples`` command (``list`` or ``copy``).

    Returns:
        0 on success. 1 when no action was given, the action is unknown, no
        examples are found, the example name is missing, or copying fails.
    """
    action = getattr(args, "action", None)

    if action is None:
        # `simpletuner examples` without a sub-action: explain the usage
        # instead of reporting "Unknown examples action: None".
        print(
            "Error: No examples action given. Use 'simpletuner examples list' "
            "or 'simpletuner examples copy <name> [dest]'."
        )
        return 1

    if action == "list":
        examples = list_examples()
        if not examples:
            print("No examples found.")
            return 1

        print("Available examples:")
        for example in examples:
            print(f" {example}")
        return 0

    if action == "copy":
        name = getattr(args, "name", None)
        if not name:
            print("Error: Example name required for copy action.")
            return 1

        success = copy_example(name, getattr(args, "dest", None))
        return 0 if success else 1

    print(f"Unknown examples action: {action}")
    return 1
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def cmd_configure(args) -> int:
    """Handle the ``configure`` command by running the interactive wizard.

    If *output_file* already exists, it is handed to the wizard as its only
    argument so an existing config can be edited. Otherwise the wizard starts
    with no arguments. ``sys.argv`` is swapped for the wizard's benefit and
    always restored afterwards.

    Returns:
        0 on success, 130 on Ctrl-C, 1 if the wizard module cannot be
        imported or the wizard raises.
    """
    output_file = getattr(args, "output_file", "config.json")

    try:
        from simpletuner.configure import main as configure_main
    except ImportError as e:
        print(f"Error importing configuration module: {e}")
        return 1

    saved_argv = list(sys.argv)

    # Only an existing file is passed on (for editing); a new config starts fresh.
    if Path(output_file).exists():
        sys.argv = ["configure.py", output_file]
        print(f"Loading existing configuration from: {output_file}")
    else:
        sys.argv = ["configure.py"]
        print(f"Creating new configuration. Will save to: {output_file}")

    try:
        configure_main()
    except KeyboardInterrupt:
        print("\nConfiguration cancelled by user.")
        return 130
    except Exception as e:
        print(f"Error running configuration wizard: {e}")
        return 1
    finally:
        sys.argv = saved_argv
    return 0
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def cmd_server(args) -> int:
    """Handle the ``server`` command: serve the web app with uvicorn.

    Creates ``static/css``, ``static/js``, ``templates`` and ``configs``
    under the current directory before starting. With ``--reload`` the app
    is handed to uvicorn as the import string
    ``simpletuner.service_worker:app``, because uvicorn's reloader has to
    re-import the application and refuses to reload an app object.

    Returns:
        0 after the server exits normally, 130 on Ctrl-C, or 1 if the server
        dependencies cannot be imported or startup fails.
    """
    host = getattr(args, "host", "0.0.0.0")
    port = getattr(args, "port", 8001)
    reload = getattr(args, "reload", False)

    print(f"Starting SimpleTuner server on {host}:{port}")

    try:
        import uvicorn
        from simpletuner.service_worker import app

        # Create necessary directories
        for directory in ("static/css", "static/js", "templates", "configs"):
            os.makedirs(directory, exist_ok=True)

        # uvicorn only supports reload when given an "module:attr" import string.
        app_target = "simpletuner.service_worker:app" if reload else app

        # Run the server
        uvicorn.run(app_target, host=host, port=port, reload=reload, log_level="info")
        return 0
    except KeyboardInterrupt:
        print("\nServer stopped by user.")
        return 130
    except ImportError as e:
        print(f"Error importing server dependencies: {e}")
        print("Make sure FastAPI and uvicorn are installed.")
        return 1
    except Exception as e:
        print(f"Error starting server: {e}")
        return 1
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def get_version() -> str:
    """Return the SimpleTuner version string, or ``"unknown"``.

    ``simpletuner.__version__`` is preferred. When the package does not
    define it (or cannot be imported), the installed distribution's metadata
    is consulted via ``importlib.metadata``.
    """
    try:
        import simpletuner

        version = getattr(simpletuner, "__version__", None)
        if version:
            return str(version)
    except Exception:
        # Best effort only. Unlike the former bare ``except:``, this lets
        # KeyboardInterrupt/SystemExit propagate.
        pass

    try:
        from importlib import metadata

        return metadata.version("simpletuner")
    except Exception:
        # PackageNotFoundError when not installed as a distribution.
        return "unknown"
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def main(argv: Optional[List[str]] = None) -> int:
    """Main CLI entry point.

    Builds the ``simpletuner`` argument parser with the ``train``,
    ``examples``, ``configure`` and ``server`` sub-commands, parses *argv*
    and dispatches to the selected command handler.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default, used by the console entry point) parses ``sys.argv[1:]``.

    Returns:
        The handler's exit code, or 1 (after printing help) when no command
        was given.
    """
    parser = argparse.ArgumentParser(
        prog="simpletuner",
        description="SimpleTuner - Fine-tune diffusion models with ease",
    )
    parser.add_argument(
        "--version", "-v", action="version", version=f"SimpleTuner {get_version()}"
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Train command
    train_parser = subparsers.add_parser(
        "train",
        help="Run training",
        description="Run training with automatic config detection or examples",
        epilog="""
Examples:
simpletuner train # Use config.json in current directory
simpletuner train --example sd3.peft-lora # Use example configuration
simpletuner train example=sd3.peft-lora # Alternative syntax for examples
simpletuner train --env custom-path # Use custom environment path
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    train_group = train_parser.add_mutually_exclusive_group()
    train_group.add_argument(
        "--example", "-e", help="Use example configuration (e.g., sd3.peft-lora)"
    )
    train_group.add_argument("--env", help="Use custom environment path")
    # Add support for positional arguments like example=value
    train_parser.add_argument(
        "args",
        nargs="*",
        help="Additional arguments in key=value format (e.g., example=sd3.peft-lora)",
    )
    train_parser.set_defaults(func=cmd_train)

    # Examples command
    examples_parser = subparsers.add_parser("examples", help="Manage examples")
    examples_subparsers = examples_parser.add_subparsers(
        dest="action", help="Examples actions"
    )

    # examples list (takes no options of its own)
    examples_subparsers.add_parser("list", help="List available examples")

    # examples copy
    copy_parser = examples_subparsers.add_parser(
        "copy", help="Copy example to local directory"
    )
    copy_parser.add_argument("name", help="Example name to copy")
    copy_parser.add_argument(
        "dest", nargs="?", help="Destination directory (default: current)"
    )

    examples_parser.set_defaults(func=cmd_examples)

    # Configure command
    configure_parser = subparsers.add_parser(
        "configure",
        help="Interactive configuration wizard",
        description="Run the interactive configuration wizard to create training configs",
    )
    configure_parser.add_argument(
        "output_file",
        nargs="?",
        default="config.json",
        help="Output configuration file (default: config.json)",
    )
    configure_parser.set_defaults(func=cmd_configure)

    # Server command
    server_parser = subparsers.add_parser(
        "server",
        help="Start SimpleTuner web server",
        description="Start the SimpleTuner web server for training management",
    )
    server_parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Host to bind the server to (default: 0.0.0.0)",
    )
    server_parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port to bind the server to (default: 8001)",
    )
    server_parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable auto-reload for development",
    )
    server_parser.set_defaults(func=cmd_server)

    # Parse args and run command
    args = parser.parse_args(argv)

    if not hasattr(args, "func"):
        parser.print_help()
        return 1

    return args.func(args)
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
if __name__ == "__main__":
    # Exit with the command handler's status code when run as a script.
    raise SystemExit(main())
|