simpletuner 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. simpletuner/__init__.py +3 -0
  2. simpletuner/cli.py +626 -0
  3. simpletuner/configure.py +5649 -0
  4. simpletuner/examples/README.md +53 -0
  5. simpletuner/examples/auraflow.peft-controlnet-lora/config.json +46 -0
  6. simpletuner/examples/auraflow.peft-lora/config.json +55 -0
  7. simpletuner/examples/bria.lycoris-lokr/config.json +58 -0
  8. simpletuner/examples/bria.lycoris-lokr/lycoris_config.json +20 -0
  9. simpletuner/examples/cosmos2image.lycoris-lokr/config.json +43 -0
  10. simpletuner/examples/cosmos2image.lycoris-lokr/lycoris_config.json +20 -0
  11. simpletuner/examples/flux.peft-controlnet-lora/config.json +47 -0
  12. simpletuner/examples/flux.peft-lora/config.json +46 -0
  13. simpletuner/examples/flux.peft-lora+TREAD/config.json +55 -0
  14. simpletuner/examples/hidream.peft-controlnet-lora/config.env +1 -0
  15. simpletuner/examples/hidream.peft-controlnet-lora/config.json +53 -0
  16. simpletuner/examples/hidream.peft-lora/config.env +1 -0
  17. simpletuner/examples/hidream.peft-lora/config.json +70 -0
  18. simpletuner/examples/kontext.peft-lora/config.env +1 -0
  19. simpletuner/examples/kontext.peft-lora/config.json +50 -0
  20. simpletuner/examples/lumina2.peft-lora/config.json +41 -0
  21. simpletuner/examples/omnigen.lycoris-lokr/config.env +4 -0
  22. simpletuner/examples/omnigen.lycoris-lokr/config.json +61 -0
  23. simpletuner/examples/omnigen.lycoris-lokr/lycoris_config.json +22 -0
  24. simpletuner/examples/pixart.lycoris-lokr/config.json +44 -0
  25. simpletuner/examples/pixart.lycoris-lokr/lycoris_config.json +22 -0
  26. simpletuner/examples/pixart.peft-controlnet-lora/config.json +45 -0
  27. simpletuner/examples/qwen_image.peft-lora/config.json +57 -0
  28. simpletuner/examples/sana.lycoris-lokr/config.json +63 -0
  29. simpletuner/examples/sana.lycoris-lokr/lycoris_config.json +22 -0
  30. simpletuner/examples/sd3.peft-controlnet-lora/config.json +46 -0
  31. simpletuner/examples/sd3.peft-lora/config.json +48 -0
  32. simpletuner/examples/sdxl.lycoris-lokr/config.json +39 -0
  33. simpletuner/examples/sdxl.lycoris-lokr/lycoris_config.json +18 -0
  34. simpletuner/examples/sdxl.peft-controlnet-lora/config.json +42 -0
  35. simpletuner/examples/wan-1.3b.peft-lora+TREAD/config.json +69 -0
  36. simpletuner/helpers/caching/memory.py +13 -0
  37. simpletuner/helpers/caching/text_embeds.py +512 -0
  38. simpletuner/helpers/caching/vae.py +1391 -0
  39. simpletuner/helpers/configuration/cmd_args.py +2869 -0
  40. simpletuner/helpers/configuration/env_file.py +195 -0
  41. simpletuner/helpers/configuration/json_file.py +68 -0
  42. simpletuner/helpers/configuration/loader.py +65 -0
  43. simpletuner/helpers/configuration/toml_file.py +76 -0
  44. simpletuner/helpers/data_backend/aws.py +451 -0
  45. simpletuner/helpers/data_backend/base.py +151 -0
  46. simpletuner/helpers/data_backend/csv_url_list.py +338 -0
  47. simpletuner/helpers/data_backend/factory.py +2221 -0
  48. simpletuner/helpers/data_backend/huggingface.py +651 -0
  49. simpletuner/helpers/data_backend/local.py +347 -0
  50. simpletuner/helpers/data_generation/__init__.py +2 -0
  51. simpletuner/helpers/data_generation/conditioning.py +715 -0
  52. simpletuner/helpers/data_generation/sample_generator.py +1186 -0
  53. simpletuner/helpers/distillation/common.py +153 -0
  54. simpletuner/helpers/distillation/dcm/discriminator/hunyuan_video.py +171 -0
  55. simpletuner/helpers/distillation/dcm/discriminator/wan.py +502 -0
  56. simpletuner/helpers/distillation/dcm/distiller.py +393 -0
  57. simpletuner/helpers/distillation/dcm/loss.py +98 -0
  58. simpletuner/helpers/distillation/dcm/solver.py +679 -0
  59. simpletuner/helpers/distillation/dmd/distiller.py +377 -0
  60. simpletuner/helpers/distillation/factory.py +339 -0
  61. simpletuner/helpers/distillation/lcm/__init__.py +418 -0
  62. simpletuner/helpers/image_manipulation/batched_training_samples.py +398 -0
  63. simpletuner/helpers/image_manipulation/brightness.py +35 -0
  64. simpletuner/helpers/image_manipulation/cropping.py +309 -0
  65. simpletuner/helpers/image_manipulation/load.py +216 -0
  66. simpletuner/helpers/image_manipulation/training_sample.py +1071 -0
  67. simpletuner/helpers/legacy/pipeline.py +1202 -0
  68. simpletuner/helpers/log_format.py +133 -0
  69. simpletuner/helpers/metadata/__init__.py +0 -0
  70. simpletuner/helpers/metadata/backends/base.py +1165 -0
  71. simpletuner/helpers/metadata/backends/discovery.py +299 -0
  72. simpletuner/helpers/metadata/backends/huggingface.py +825 -0
  73. simpletuner/helpers/metadata/backends/parquet.py +660 -0
  74. simpletuner/helpers/metadata/utils/__init__.py +2 -0
  75. simpletuner/helpers/metadata/utils/duplicator.py +209 -0
  76. simpletuner/helpers/models/all.py +58 -0
  77. simpletuner/helpers/models/auraflow/controlnet.py +567 -0
  78. simpletuner/helpers/models/auraflow/model.py +459 -0
  79. simpletuner/helpers/models/auraflow/pipeline.py +1570 -0
  80. simpletuner/helpers/models/auraflow/pipeline_controlnet.py +1157 -0
  81. simpletuner/helpers/models/auraflow/transformer.py +782 -0
  82. simpletuner/helpers/models/common.py +1800 -0
  83. simpletuner/helpers/models/cosmos/model.py +335 -0
  84. simpletuner/helpers/models/cosmos/pipeline.py +708 -0
  85. simpletuner/helpers/models/cosmos/transformer.py +703 -0
  86. simpletuner/helpers/models/deepfloyd/model.py +219 -0
  87. simpletuner/helpers/models/flux/__init__.py +188 -0
  88. simpletuner/helpers/models/flux/attention.py +639 -0
  89. simpletuner/helpers/models/flux/model.py +1083 -0
  90. simpletuner/helpers/models/flux/pipeline.py +2761 -0
  91. simpletuner/helpers/models/flux/pipeline_controlnet.py +2117 -0
  92. simpletuner/helpers/models/flux/transformer.py +1192 -0
  93. simpletuner/helpers/models/hidream/controlnet.py +1699 -0
  94. simpletuner/helpers/models/hidream/model.py +697 -0
  95. simpletuner/helpers/models/hidream/pipeline.py +1933 -0
  96. simpletuner/helpers/models/hidream/schedule.py +802 -0
  97. simpletuner/helpers/models/hidream/transformer.py +1618 -0
  98. simpletuner/helpers/models/kolors/controlnet.py +1009 -0
  99. simpletuner/helpers/models/kolors/model.py +261 -0
  100. simpletuner/helpers/models/kolors/pipeline.py +2211 -0
  101. simpletuner/helpers/models/kolors/pipeline_controlnet.py +1555 -0
  102. simpletuner/helpers/models/ltxvideo/__init__.py +211 -0
  103. simpletuner/helpers/models/ltxvideo/model.py +309 -0
  104. simpletuner/helpers/models/lumina2/model.py +303 -0
  105. simpletuner/helpers/models/omnigen/collator.py +61 -0
  106. simpletuner/helpers/models/omnigen/model.py +173 -0
  107. simpletuner/helpers/models/pixart/controlnet.py +345 -0
  108. simpletuner/helpers/models/pixart/model.py +451 -0
  109. simpletuner/helpers/models/pixart/model_card_templates.py +50 -0
  110. simpletuner/helpers/models/pixart/pipeline.py +3116 -0
  111. simpletuner/helpers/models/pixart/transformer.py +477 -0
  112. simpletuner/helpers/models/qwen_image/model.py +351 -0
  113. simpletuner/helpers/models/sana/__init__.py +0 -0
  114. simpletuner/helpers/models/sana/model.py +195 -0
  115. simpletuner/helpers/models/sana/pipeline.py +964 -0
  116. simpletuner/helpers/models/sana/transformer.py +545 -0
  117. simpletuner/helpers/models/sd1x/model.py +309 -0
  118. simpletuner/helpers/models/sd1x/pipeline.py +4146 -0
  119. simpletuner/helpers/models/sd3/__init__.py +0 -0
  120. simpletuner/helpers/models/sd3/controlnet.py +1410 -0
  121. simpletuner/helpers/models/sd3/expanded.py +740 -0
  122. simpletuner/helpers/models/sd3/model.py +593 -0
  123. simpletuner/helpers/models/sd3/pipeline.py +3023 -0
  124. simpletuner/helpers/models/sd3/transformer.py +478 -0
  125. simpletuner/helpers/models/sdxl/controlnet.py +29 -0
  126. simpletuner/helpers/models/sdxl/model.py +347 -0
  127. simpletuner/helpers/models/sdxl/pipeline.py +5138 -0
  128. simpletuner/helpers/models/wan/__init__.py +94 -0
  129. simpletuner/helpers/models/wan/model.py +303 -0
  130. simpletuner/helpers/models/wan/pipeline.py +653 -0
  131. simpletuner/helpers/models/wan/transformer.py +698 -0
  132. simpletuner/helpers/multiaspect/dataset.py +90 -0
  133. simpletuner/helpers/multiaspect/image.py +327 -0
  134. simpletuner/helpers/multiaspect/sampler.py +889 -0
  135. simpletuner/helpers/multiaspect/state.py +67 -0
  136. simpletuner/helpers/multiaspect/video.py +25 -0
  137. simpletuner/helpers/prompt_expander/__init__.py +276 -0
  138. simpletuner/helpers/prompts.py +678 -0
  139. simpletuner/helpers/publishing/huggingface.py +279 -0
  140. simpletuner/helpers/publishing/metadata.py +593 -0
  141. simpletuner/helpers/training/__init__.py +181 -0
  142. simpletuner/helpers/training/adapter.py +142 -0
  143. simpletuner/helpers/training/collate.py +661 -0
  144. simpletuner/helpers/training/custom_schedule.py +1606 -0
  145. simpletuner/helpers/training/deepspeed.py +134 -0
  146. simpletuner/helpers/training/default_settings/__init__.py +15 -0
  147. simpletuner/helpers/training/default_settings/safety_check.py +163 -0
  148. simpletuner/helpers/training/diffusers_overrides.py +406 -0
  149. simpletuner/helpers/training/diffusion_model.py +4 -0
  150. simpletuner/helpers/training/ema.py +473 -0
  151. simpletuner/helpers/training/error_handling.py +31 -0
  152. simpletuner/helpers/training/evaluation.py +64 -0
  153. simpletuner/helpers/training/exceptions.py +2 -0
  154. simpletuner/helpers/training/gradient_checkpointing_interval.py +42 -0
  155. simpletuner/helpers/training/min_snr_gamma.py +47 -0
  156. simpletuner/helpers/training/model_freeze.py +180 -0
  157. simpletuner/helpers/training/multi_process.py +20 -0
  158. simpletuner/helpers/training/optimizer_param.py +827 -0
  159. simpletuner/helpers/training/optimizers/adamw_bfloat16/__init__.py +167 -0
  160. simpletuner/helpers/training/optimizers/adamw_bfloat16/stochastic/__init__.py +124 -0
  161. simpletuner/helpers/training/optimizers/adamw_schedulefree/__init__.py +149 -0
  162. simpletuner/helpers/training/optimizers/soap/__init__.py +479 -0
  163. simpletuner/helpers/training/peft_init.py +25 -0
  164. simpletuner/helpers/training/quantisation/__init__.py +371 -0
  165. simpletuner/helpers/training/quantisation/peft_workarounds.py +423 -0
  166. simpletuner/helpers/training/quantisation/quanto_workarounds.py +114 -0
  167. simpletuner/helpers/training/quantisation/torchao_workarounds.py +146 -0
  168. simpletuner/helpers/training/save_hooks.py +408 -0
  169. simpletuner/helpers/training/state_tracker.py +657 -0
  170. simpletuner/helpers/training/trainer.py +2718 -0
  171. simpletuner/helpers/training/tread.py +162 -0
  172. simpletuner/helpers/training/validation.py +2635 -0
  173. simpletuner/helpers/training/wrappers.py +32 -0
  174. simpletuner/helpers/webhooks/config.py +51 -0
  175. simpletuner/helpers/webhooks/handler.py +247 -0
  176. simpletuner/helpers/webhooks/mixin.py +31 -0
  177. simpletuner/inference.py +160 -0
  178. simpletuner/inference_comparison.py +163 -0
  179. simpletuner/service_worker.py +110 -0
  180. simpletuner/simpletuner_sdk/api_state.py +87 -0
  181. simpletuner/simpletuner_sdk/configuration.py +183 -0
  182. simpletuner/simpletuner_sdk/interface.py +409 -0
  183. simpletuner/simpletuner_sdk/thread_keeper/__init__.py +67 -0
  184. simpletuner/simpletuner_sdk/training_host.py +75 -0
  185. simpletuner/train.py +92 -0
  186. simpletuner-2.2.0.dist-info/METADATA +274 -0
  187. simpletuner-2.2.0.dist-info/RECORD +191 -0
  188. simpletuner-2.2.0.dist-info/WHEEL +5 -0
  189. simpletuner-2.2.0.dist-info/entry_points.txt +5 -0
  190. simpletuner-2.2.0.dist-info/licenses/LICENSE +208 -0
  191. simpletuner-2.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3 @@
1
"""SimpleTuner - fine-tuning toolkit for diffusion models.

Ships model integrations and example configurations for Stable Diffusion 1.x,
SDXL, SD3, Flux, PixArt, Sana, HiDream, Wan and others.
"""

# Package version string; the CLI reports it via ``simpletuner --version``.
__version__ = "2.2.0"
simpletuner/cli.py ADDED
@@ -0,0 +1,626 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ SimpleTuner CLI - Command-line interface for SimpleTuner
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import sys
10
+ import subprocess
11
+ import shutil
12
+ from pathlib import Path
13
+ from typing import List, Optional
14
+
15
+
16
def find_config_file() -> Optional[str]:
    """Locate a training config relative to the current working directory.

    Looks for ``config.json``, ``config.toml`` and ``config.env`` (in that
    order), first in the cwd and then inside a ``config/`` sub-directory.

    Returns:
        The bare filename for a cwd hit, ``str(Path("config") / name)`` for a
        sub-directory hit, or ``None`` when nothing is found.
    """
    candidate_names = ("config.json", "config.toml", "config.env")

    local_hit = next((name for name in candidate_names if os.path.exists(name)), None)
    if local_hit is not None:
        return local_hit

    sub_dir = Path("config")
    if not sub_dir.exists():
        return None

    return next(
        (str(sub_dir / name) for name in candidate_names if (sub_dir / name).exists()),
        None,
    )
32
+
33
+
34
def get_examples_dir() -> Path:
    """Return ``<installed simpletuner package>/examples``.

    The directory is resolved from ``simpletuner.__file__``; it is not
    checked for existence here.
    """
    import simpletuner

    package_root = Path(simpletuner.__file__).parent
    return package_root.joinpath("examples")
41
+
42
+
43
def list_examples() -> List[str]:
    """Return the sorted names of the example directories.

    Only sub-directories of the examples dir count; plain files (such as
    README.md) and dot-directories are skipped. A missing examples directory
    yields an empty list.
    """
    root = get_examples_dir()
    if not root.exists():
        return []

    names = [
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    ]
    names.sort()
    return names
55
+
56
+
57
def find_referenced_files(config_path: Path) -> List[str]:
    """Return the names of ``examples/`` JSON files that a config.json points at.

    Only these top-level keys are inspected, in this order:
    data_backend_config, validation_prompt_library, controlnet_config,
    reference_config, lycoris_config. A value qualifies when it is a string
    that contains ``"examples/"`` and ends in ``".json"``. Only its final path
    component is returned (e.g. ``multidatabackend.json`` from
    ``config/examples/multidatabackend.json``).

    Args:
        config_path: Path of the config.json to read (decoded as UTF-8).

    Returns:
        The referenced filenames in key order. An empty list is returned when
        the file is missing, unreadable, undecodable, not valid JSON, or not a
        JSON object; a warning is printed in that case.
    """
    fields_to_check = (
        "data_backend_config",
        "validation_prompt_library",
        "controlnet_config",
        "reference_config",
        "lycoris_config",
    )

    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers missing files, permission errors, directories, ...
        print(f"Warning: Could not parse config file {config_path}: {e}")
        return []

    if not isinstance(config, dict):
        # e.g. a top-level list: indexing it by key name would raise TypeError.
        print(
            f"Warning: Could not parse config file {config_path}: "
            "top-level JSON value is not an object"
        )
        return []

    referenced_files = []
    for field in fields_to_check:
        value = config.get(field)
        if isinstance(value, str) and "examples/" in value and value.endswith(".json"):
            referenced_files.append(Path(value).name)

    return referenced_files
88
+
89
+
90
def copy_example(example_name: str, dest: Optional[str] = None) -> bool:
    """Copy a bundled example into *dest* (default: the current directory).

    A directory example is copied to ``<dest>/<example_name>``; a file example
    is copied to ``<dest>/<file name>``. An existing destination is never
    overwritten.

    For directory examples, every ``examples/...json`` file referenced by the
    copied config.json is resolved in one of two ways:

    * taken from the top-level examples directory and copied next to the
      config, or
    * recognised as already present in the copied directory.

    Only resolved files are rewritten in config.json to their bare local
    filename. References that cannot be resolved are warned about and left
    unchanged, so the config never points at a file that doesn't exist.

    Returns:
        True on success; False when the example is unknown, the destination
        already exists, or copying fails.
    """
    examples_dir = get_examples_dir()
    example_path = examples_dir / example_name

    if not example_path.exists():
        print(f"Error: Example '{example_name}' not found.")
        print(f"Available examples: {', '.join(list_examples())}")
        return False

    dest_path = Path("." if dest is None else dest)

    try:
        if not example_path.is_dir():
            # Standalone example file.
            dest_file = dest_path / example_path.name
            if dest_file.exists():
                print(f"Error: Destination '{dest_file}' already exists.")
                return False
            shutil.copy2(example_path, dest_file)
            print(f"Copied example file '{example_name}' to '{dest_file}'")
            return True

        dest_example = dest_path / example_name
        if dest_example.exists():
            print(f"Error: Destination '{dest_example}' already exists.")
            return False
        shutil.copytree(example_path, dest_example)
        print(f"Copied example directory '{example_name}' to '{dest_example}'")

        config_json = dest_example / "config.json"
        if config_json.exists():
            referenced_files = find_referenced_files(config_json)
            if referenced_files:
                print(f"Found {len(referenced_files)} referenced file(s) to copy...")

                resolved = []
                for ref_file in referenced_files:
                    source_file = examples_dir / ref_file
                    dest_file = dest_example / ref_file
                    if source_file.exists():
                        shutil.copy2(source_file, dest_file)
                        print(f" Copied referenced file: {ref_file}")
                        resolved.append(ref_file)
                    elif dest_file.exists():
                        # The file lives inside the example directory itself and
                        # was already brought over by copytree above.
                        print(f" Referenced file already present: {ref_file}")
                        resolved.append(ref_file)
                    else:
                        print(f" Warning: Referenced file not found: {ref_file}")

                # Only redirect config paths to files that actually exist locally.
                if resolved:
                    update_config_paths(config_json, resolved)

        return True
    except Exception as e:
        print(f"Error copying example: {e}")
        return False
150
+
151
+
152
def update_config_paths(config_path: Path, referenced_files: List[str]) -> None:
    """Point a config.json's example references at locally copied files.

    The top-level keys data_backend_config, validation_prompt_library,
    controlnet_config, reference_config and lycoris_config are checked. A
    value is replaced by its bare filename when both of these hold:

    * it is a string containing ``"examples/"``, and
    * its final path component is listed in *referenced_files*.

    The file is rewritten (UTF-8, ``indent=4``) only when at least one value
    changed. Read, decode, parse and write failures are reported as a warning
    rather than raised. Configs whose top level is not a JSON object are left
    untouched.

    Args:
        config_path: The config.json to update in place.
        referenced_files: Filenames that are known to exist next to the config.
    """
    fields_to_check = (
        "data_backend_config",
        "validation_prompt_library",
        "controlnet_config",
        "reference_config",
        "lycoris_config",
    )
    local_names = set(referenced_files)

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        if not isinstance(config, dict):
            # Nothing to rewrite; indexing a list by key name would raise TypeError.
            return

        updated = False
        for field in fields_to_check:
            value = config.get(field)
            if not (isinstance(value, str) and "examples/" in value):
                continue
            filename = Path(value).name
            if filename in local_names:
                config[field] = filename
                updated = True
                print(f" Updated {field}: {value} -> {filename}")

        if updated:
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=4)
            print(" Updated config.json with local file paths")

    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not update config paths: {e}")
186
+
187
+
188
def setup_environment_from_example(example_name: str) -> dict:
    """Build the subprocess environment for training with a bundled example.

    Starts from a copy of ``os.environ`` and sets the following:

    * ``ENV=examples/<example_name>``.
    * ``CONFIG_BACKEND`` / ``CONFIG_PATH``. For a directory example, the first
      of config.json / config.toml / config.env present selects the backend
      and the path is ``<example>/config``. For a standalone .json / .toml
      example file, the path is the file without its suffix.
    * ``TQDM_NCOLS``, ``TQDM_LEAVE``, ``TRAINING_NUM_PROCESSES``,
      ``TRAINING_NUM_MACHINES``, ``MIXED_PRECISION`` and
      ``TRAINING_DYNAMO_BACKEND``, set only when not already defined.
    * ``TOKENIZERS_PARALLELISM=false`` and ``PLATFORM``.
    * On macOS (``Darwin``), ``MIXED_PRECISION`` is forced to ``"no"``,
      overriding any existing value.

    Raises:
        ValueError: The example does not exist, a directory example has no
            config file, or a standalone example has an unsupported suffix.
    """
    import platform

    env = os.environ.copy()
    env["ENV"] = f"examples/{example_name}"

    example_path = get_examples_dir() / example_name

    if example_path.is_dir():
        for filename, backend in (
            ("config.json", "json"),
            ("config.toml", "toml"),
            ("config.env", "env"),
        ):
            if (example_path / filename).exists():
                env["CONFIG_BACKEND"] = backend
                env["CONFIG_PATH"] = str(example_path / "config")
                break
        else:
            raise ValueError(f"No config file found in example {example_name}")
    elif example_path.is_file():
        backend = {".json": "json", ".toml": "toml"}.get(example_path.suffix)
        if backend is None:
            raise ValueError(f"Unsupported config file type: {example_path.suffix}")
        env["CONFIG_BACKEND"] = backend
        env["CONFIG_PATH"] = str(example_path.with_suffix(""))
    else:
        raise ValueError(f"Example {example_name} not found")

    # Defaults mirroring train.sh.
    env.setdefault("TQDM_NCOLS", "125")
    env.setdefault("TQDM_LEAVE", "false")
    env["TOKENIZERS_PARALLELISM"] = "false"

    # platform.system() is portable; os.uname() does not exist on Windows.
    # On POSIX it reports the same value as os.uname().sysname.
    system_name = platform.system()
    env["PLATFORM"] = system_name
    if system_name == "Darwin":
        env["MIXED_PRECISION"] = "no"

    env.setdefault("TRAINING_NUM_PROCESSES", "1")
    env.setdefault("TRAINING_NUM_MACHINES", "1")
    env.setdefault("MIXED_PRECISION", "bf16")
    env.setdefault("TRAINING_DYNAMO_BACKEND", "no")

    return env
254
+
255
+
256
def find_accelerate_config() -> Optional[str]:
    """Return the path of accelerate's ``default_config.yaml``, if one exists.

    ``$HF_HOME/accelerate`` is tried first (when HF_HOME is set and
    non-empty). The fallback is ``~/.cache/huggingface/accelerate``. Returns
    ``None`` when neither location holds the file.
    """

    def _search_roots():
        # Lazily yielded so Path.home() is only consulted when actually needed.
        hf_home = os.environ.get("HF_HOME", "")
        if hf_home:
            yield Path(hf_home)
        yield Path.home() / ".cache" / "huggingface"

    for root in _search_roots():
        candidate = root / "accelerate" / "default_config.yaml"
        if candidate.exists():
            return str(candidate)

    return None
272
+
273
+
274
def run_training(example: Optional[str] = None, env: Optional[str] = None) -> int:
    """Launch SimpleTuner's ``train.py`` via ``accelerate launch``.

    The configuration source is chosen in this order:

    * ``example``: a bundled example. The environment comes from
      ``setup_environment_from_example``.
    * ``env``: exported as ``ENV`` on top of the current environment.
    * Neither: a config file found by ``find_config_file``. Its extension
      selects ``CONFIG_BACKEND``, and its path without the extension becomes
      ``CONFIG_PATH`` (anchored to the cwd for a bare filename).

    If an accelerate config file is found, it is passed via ``--config_file``.
    Otherwise the launch flags come from MIXED_PRECISION,
    TRAINING_NUM_PROCESSES, TRAINING_NUM_MACHINES and TRAINING_DYNAMO_BACKEND.
    ACCELERATE_EXTRA_ARGS is whitespace-split and inserted before the script
    path.

    Returns:
        The training subprocess's return code. Returns 1 on setup errors or
        launch failure, and 130 when interrupted with Ctrl-C.
    """
    training_env = os.environ.copy()

    if not example and not env:
        config_file = find_config_file()
        if not config_file:
            print(
                "Error: No config file found in current directory or config/ subdirectory."
            )
            print("Expected: config.json, config.toml, or config.env")
            print("Or use: simpletuner train example=<example_name>")
            return 1

        print(f"Using config file: {config_file}")

        config_path = Path(config_file)
        backend = {".json": "json", ".toml": "toml", ".env": "env"}.get(
            config_path.suffix
        )
        if backend is not None:
            training_env["CONFIG_BACKEND"] = backend
            config_stem = config_path.with_suffix("")
            # A config in the cwd itself is handed over as an absolute path.
            if config_path.parent == Path("."):
                config_stem = Path.cwd() / config_stem
            training_env["CONFIG_PATH"] = str(config_stem)

    if example:
        available_examples = list_examples()
        if example not in available_examples:
            print(f"Error: Example '{example}' not found.")
            print(f"Available examples: {', '.join(available_examples)}")
            return 1

        try:
            training_env = setup_environment_from_example(example)
        except ValueError as e:
            # e.g. the example directory contains no config file.
            print(f"Error: {e}")
            return 1
        print(f"Using example: {example}")
    elif env:
        training_env["ENV"] = env
        print(f"Using environment: {env}")

    import simpletuner

    simpletuner_dir = Path(simpletuner.__file__).parent
    train_py = simpletuner_dir / "train.py"

    if not train_py.exists():
        print(f"Error: train.py not found at {train_py}")
        return 1

    accelerate_config = find_accelerate_config()

    if accelerate_config:
        print(f"Using Accelerate config file: {accelerate_config}")
        cmd = [
            "accelerate",
            "launch",
            f"--config_file={accelerate_config}",
            str(train_py),
        ]
    else:
        print("Accelerate config file not found. Using environment variables.")
        cmd = [
            "accelerate",
            "launch",
            f"--mixed_precision={training_env.get('MIXED_PRECISION', 'bf16')}",
            f"--num_processes={training_env.get('TRAINING_NUM_PROCESSES', '1')}",
            f"--num_machines={training_env.get('TRAINING_NUM_MACHINES', '1')}",
            f"--dynamo_backend={training_env.get('TRAINING_DYNAMO_BACKEND', 'no')}",
            str(train_py),
        ]

    accelerate_extra_args = training_env.get("ACCELERATE_EXTRA_ARGS", "")
    if accelerate_extra_args:
        # Plain whitespace splitting, like an unquoted $VAR in train.sh;
        # the extra args must precede the script path.
        extra_args = accelerate_extra_args.split()
        cmd = cmd[:-1] + extra_args + cmd[-1:]

    print(f"Running: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd, env=training_env)
        return result.returncode
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
        return 130
    except Exception as e:
        print(f"Error running training: {e}")
        return 1
386
+
387
+
388
def cmd_train(args) -> int:
    """Handle ``simpletuner train``.

    ``--example`` / ``--env`` take precedence. Positional ``example=<name>``
    and ``env=<path>`` arguments fill in whichever is still unset. A
    positional value that would conflict with one already set is ignored with
    an explicit warning. Other keys, and tokens without ``=``, are warned
    about and ignored.

    Returns:
        The exit code from ``run_training``.
    """
    example = getattr(args, "example", None)
    env = getattr(args, "env", None)

    for arg in getattr(args, "args", None) or []:
        key, sep, value = arg.partition("=")
        if not sep:
            print(f"Warning: Invalid argument format '{arg}'. Expected key=value")
            continue

        if key == "example":
            if example:
                # Known key, but already provided; don't misreport it as unknown.
                print(f"Warning: Ignoring '{arg}'; example already set to '{example}'")
            else:
                example = value
        elif key == "env":
            if env:
                print(f"Warning: Ignoring '{arg}'; env already set to '{env}'")
            else:
                env = value
        else:
            print(f"Warning: Unknown argument '{key}={value}'")

    return run_training(example=example, env=env)
408
+
409
+
410
def cmd_examples(args) -> int:
    """Handle ``simpletuner examples list`` and ``simpletuner examples copy``.

    The ``list`` action prints the example names and returns 0, or returns 1
    when there are none. The ``copy`` action copies ``args.name`` into
    ``args.dest`` and returns 0 on success. A missing sub-action (plain
    ``simpletuner examples``, where argparse leaves ``action`` as None) or an
    unknown one prints an error and returns 1.
    """
    action = getattr(args, "action", None)

    if action is None:
        print(
            "Error: No examples action given. Use 'simpletuner examples list' "
            "or 'simpletuner examples copy <name> [dest]'."
        )
        return 1

    if action == "list":
        examples = list_examples()
        if not examples:
            print("No examples found.")
            return 1

        print("Available examples:")
        for example in examples:
            print(f" {example}")
        return 0

    if action == "copy":
        name = getattr(args, "name", None)
        if not name:
            print("Error: Example name required for copy action.")
            return 1

        success = copy_example(name, getattr(args, "dest", None))
        return 0 if success else 1

    print(f"Unknown examples action: {action}")
    return 1
434
+
435
+
436
def cmd_configure(args) -> int:
    """Handle ``simpletuner configure`` by running the interactive wizard.

    ``sys.argv`` is temporarily swapped so that ``simpletuner.configure.main``
    sees one of two values:

    * ``["configure.py", <output_file>]`` when that file already exists and
      should be edited;
    * ``["configure.py"]`` for a fresh configuration.

    The original ``sys.argv`` is always restored afterwards.

    NOTE(review): for a new file, *output_file* is only printed, not passed
    to the wizard. Whether it is actually saved there depends on
    ``simpletuner.configure``; verify.

    Returns:
        0 on success, 130 on Ctrl-C, 1 if the wizard module cannot be
        imported or the wizard raises.
    """
    target_file = getattr(args, "output_file", "config.json")

    try:
        from simpletuner.configure import main as configure_main
    except ImportError as err:
        print(f"Error importing configuration module: {err}")
        return 1

    saved_argv = list(sys.argv)

    if Path(target_file).exists():
        sys.argv = ["configure.py", target_file]
        print(f"Loading existing configuration from: {target_file}")
    else:
        sys.argv = ["configure.py"]
        print(f"Creating new configuration. Will save to: {target_file}")

    try:
        configure_main()
    except KeyboardInterrupt:
        print("\nConfiguration cancelled by user.")
        return 130
    except Exception as err:
        print(f"Error running configuration wizard: {err}")
        return 1
    else:
        return 0
    finally:
        sys.argv = saved_argv
473
+
474
+
475
def cmd_server(args) -> int:
    """Handle ``simpletuner server``: serve ``simpletuner.service_worker:app``.

    Before starting uvicorn, the static/css, static/js, templates and configs
    directories are created under the current working directory.

    Returns:
        0 when uvicorn returns normally, 130 on Ctrl-C, and 1 when uvicorn or
        the service module cannot be imported, or on any other startup error.
    """
    host = getattr(args, "host", "0.0.0.0")
    port = getattr(args, "port", 8001)
    reload = getattr(args, "reload", False)

    print(f"Starting SimpleTuner server on {host}:{port}")

    try:
        import uvicorn
        from simpletuner.service_worker import app

        for directory in ("static/css", "static/js", "templates", "configs"):
            os.makedirs(directory, exist_ok=True)

        # uvicorn can only auto-reload when it imports the application itself
        # from a "module:attribute" string; handed an app object together with
        # reload=True it warns and exits instead of serving.
        app_target = "simpletuner.service_worker:app" if reload else app
        uvicorn.run(app_target, host=host, port=port, reload=reload, log_level="info")
        return 0
    except KeyboardInterrupt:
        print("\nServer stopped by user.")
        return 130
    except ImportError as e:
        print(f"Error importing server dependencies: {e}")
        print("Make sure FastAPI and uvicorn are installed.")
        return 1
    except Exception as e:
        print(f"Error starting server: {e}")
        return 1
506
+
507
+
508
def get_version() -> str:
    """Return ``simpletuner.__version__``, or ``"unknown"`` if unavailable.

    This is a best-effort lookup used to build the ``--version`` string. Any
    ordinary error while importing the package yields ``"unknown"``.
    ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed (the old
    bare ``except:`` caught them).
    """
    try:
        import simpletuner
    except Exception:
        return "unknown"

    return getattr(simpletuner, "__version__", "unknown")
516
+
517
+
518
def main(argv: Optional[List[str]] = None) -> int:
    """Entry point for the ``simpletuner`` command.

    Builds the argparse tree (``train``, ``examples list|copy``,
    ``configure``, ``server``) and parses *argv*. When *argv* is None, argparse
    reads ``sys.argv[1:]``, so existing ``main()`` callers behave as before.
    The parsed arguments are dispatched to the ``cmd_*`` handler bound via
    ``set_defaults(func=...)``.

    Returns:
        The handler's exit code, or 1 (after printing help) when no command
        is given.
    """
    parser = argparse.ArgumentParser(
        prog="simpletuner",
        description="SimpleTuner - Fine-tune diffusion models with ease",
    )
    parser.add_argument(
        "--version", "-v", action="version", version=f"SimpleTuner {get_version()}"
    )

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # train
    train_parser = subparsers.add_parser(
        "train",
        help="Run training",
        description="Run training with automatic config detection or examples",
        epilog="""
Examples:
  simpletuner train                              # Use config.json in current directory
  simpletuner train --example sd3.peft-lora      # Use example configuration
  simpletuner train example=sd3.peft-lora        # Alternative syntax for examples
  simpletuner train --env custom-path            # Use custom environment path
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    train_group = train_parser.add_mutually_exclusive_group()
    train_group.add_argument(
        "--example", "-e", help="Use example configuration (e.g., sd3.peft-lora)"
    )
    train_group.add_argument("--env", help="Use custom environment path")
    # Positional key=value arguments such as example=sd3.peft-lora.
    train_parser.add_argument(
        "args",
        nargs="*",
        help="Additional arguments in key=value format (e.g., example=sd3.peft-lora)",
    )
    train_parser.set_defaults(func=cmd_train)

    # examples list / examples copy
    examples_parser = subparsers.add_parser("examples", help="Manage examples")
    examples_subparsers = examples_parser.add_subparsers(
        dest="action", help="Examples actions"
    )
    examples_subparsers.add_parser("list", help="List available examples")
    copy_parser = examples_subparsers.add_parser(
        "copy", help="Copy example to local directory"
    )
    copy_parser.add_argument("name", help="Example name to copy")
    copy_parser.add_argument(
        "dest", nargs="?", help="Destination directory (default: current)"
    )
    examples_parser.set_defaults(func=cmd_examples)

    # configure
    configure_parser = subparsers.add_parser(
        "configure",
        help="Interactive configuration wizard",
        description="Run the interactive configuration wizard to create training configs",
    )
    configure_parser.add_argument(
        "output_file",
        nargs="?",
        default="config.json",
        help="Output configuration file (default: config.json)",
    )
    configure_parser.set_defaults(func=cmd_configure)

    # server
    server_parser = subparsers.add_parser(
        "server",
        help="Start SimpleTuner web server",
        description="Start the SimpleTuner web server for training management",
    )
    server_parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Host to bind the server to (default: 0.0.0.0)",
    )
    server_parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port to bind the server to (default: 8001)",
    )
    server_parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable auto-reload for development",
    )
    server_parser.set_defaults(func=cmd_server)

    args = parser.parse_args(argv)

    if not hasattr(args, "func"):
        parser.print_help()
        return 1

    return args.func(args)
623
+
624
+
625
if __name__ == "__main__":
    # Run the CLI and propagate main()'s return value as the process exit status.
    raise SystemExit(main())