doc2lora 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc2lora/__init__.py +7 -0
- doc2lora/cli.py +523 -0
- doc2lora/core.py +289 -0
- doc2lora/deploy.py +230 -0
- doc2lora/lora_trainer.py +605 -0
- doc2lora/parsers.py +881 -0
- doc2lora/utils.py +432 -0
- doc2lora-1.0.0.dist-info/METADATA +603 -0
- doc2lora-1.0.0.dist-info/RECORD +13 -0
- doc2lora-1.0.0.dist-info/WHEEL +5 -0
- doc2lora-1.0.0.dist-info/entry_points.txt +2 -0
- doc2lora-1.0.0.dist-info/licenses/LICENSE +21 -0
- doc2lora-1.0.0.dist-info/top_level.txt +1 -0
doc2lora/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""doc2lora: A library for fine-tuning LLMs using LoRA by using a folder of documents as input."""
|
|
2
|
+
|
|
3
|
+
from .core import convert, convert_from_r2
|
|
4
|
+
|
|
5
|
+
# Canonical package version. Keep this the only place the version string is
# written: pyproject and the CLI's --version option both read it from here.
__version__ = "1.0.0"

# Public API of the package: the two conversion entry points re-exported
# from doc2lora.core, plus the version string.
__all__ = [
    "convert",
    "convert_from_r2",
    "__version__",
]
|
doc2lora/cli.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
"""Command-line interface for doc2lora."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
|
|
10
|
+
from . import __version__
|
|
11
|
+
from .core import convert, convert_from_r2
|
|
12
|
+
|
|
13
|
+
# Root-logger format applied when the CLI runs.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


@click.group()
@click.version_option(version=__version__)
def cli():
    """doc2lora: Convert documents to LoRA adapters for LLM fine-tuning."""
    # Configure logging when the CLI actually runs instead of at import time.
    # Importing doc2lora.cli (e.g. from tests or another tool) then no longer
    # reconfigures the host application's root logger.
    # Click invokes this group callback before any subcommand, so each
    # command's --verbose flag can still raise the root level to DEBUG
    # afterwards. basicConfig is a no-op if the root logger already has
    # handlers.
    logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Developer notes. They are kept out of the docstring because click shows the
# docstring verbatim as the command's --help text.
# - When --max-steps is given, num_epochs is passed to convert() as None, so
#   the step limit takes precedence over --epochs.
# - "--device auto" is passed as None, the same as omitting the option.
# - Any failure from convert() is reported as a single "Error:" line on
#   stderr and then converted into click.Abort, chained to the original
#   exception. The traceback is logged at DEBUG level, so --verbose still
#   shows it.
@cli.command()
@click.argument(
    "documents_path", type=click.Path(exists=True, file_okay=False, dir_okay=True)
)
@click.option(
    "--output",
    "-o",
    default="lora_adapter.json",
    help="Output path for the LoRA adapter",
)
@click.option(
    "--model",
    "-m",
    default="microsoft/DialoGPT-small",
    help="Base model name for fine-tuning",
)
@click.option(
    "--max-length", default=512, help="Maximum sequence length for tokenization"
)
@click.option("--batch-size", default=4, help="Training batch size")
@click.option("--epochs", default=3, help="Number of training epochs")
@click.option(
    "--max-steps",
    default=None,
    type=int,
    help="Maximum number of training steps (overrides epochs if set)",
)
@click.option("--learning-rate", default=5e-4, help="Learning rate for training")
@click.option(
    "--lora-r",
    default=8,
    help="LoRA rank parameter (Cloudflare Workers AI supports up to 32)",
)
@click.option("--lora-alpha", default=16, help="LoRA alpha parameter")
@click.option("--lora-dropout", default=0.1, help="LoRA dropout rate")
@click.option(
    "--gradient-accumulation-steps",
    default=1,
    type=int,
    help="Accumulate grads to emulate a larger batch on low-memory machines",
)
@click.option(
    "--gradient-checkpointing/--no-gradient-checkpointing",
    default=True,
    help="Trade compute for memory (helps on low-RAM machines)",
)
@click.option(
    "--load-in-4bit",
    is_flag=True,
    default=False,
    help="Use 4-bit QLoRA (requires bitsandbytes + CUDA)",
)
@click.option(
    "--device",
    default=None,
    type=click.Choice(["cuda", "mps", "cpu", "auto"], case_sensitive=False),
    help="Device to use for training (auto-detects by default)",
)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def convert_cmd(
    documents_path: str,
    output: str,
    model: str,
    max_length: int,
    batch_size: int,
    epochs: int,
    max_steps: Optional[int],
    learning_rate: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    gradient_accumulation_steps: int,
    gradient_checkpointing: bool,
    load_in_4bit: bool,
    device: Optional[str],
    verbose: bool,
):
    """Convert a folder of documents to LoRA adapter format."""
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        click.echo(f"Converting documents from: {documents_path}")
        click.echo(f"Output will be saved to: {output}")

        adapter_path = convert(
            documents_path=documents_path,
            output_path=output,
            model_name=model,
            max_length=max_length,
            batch_size=batch_size,
            num_epochs=epochs if max_steps is None else None,
            max_steps=max_steps,
            learning_rate=learning_rate,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            load_in_4bit=load_in_4bit,
            device=None if device == "auto" else device,
        )

        click.echo(f"✅ LoRA adapter successfully created at: {adapter_path}")

    except Exception as e:
        # Keep the full traceback reachable via --verbose; the user-facing
        # message stays a single line.
        logging.getLogger(__name__).debug("Conversion failed", exc_info=True)
        click.echo(f"❌ Error: {e}", err=True)
        raise click.Abort() from e
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Lists every document that DocumentParser.parse_directory() returns for
# DOCUMENTS_PATH, one line per file with its extension and size in KB.
# If any documents were found, it also prints the aggregate size, the number
# of file types and a rough training-time estimate from
# create_training_summary(). "--device auto" is passed as None, the same as
# omitting the option.
@cli.command()
@click.argument(
    "documents_path", type=click.Path(exists=True, file_okay=False, dir_okay=True)
)
@click.option(
    "--device",
    default=None,
    type=click.Choice(["cuda", "mps", "cpu", "auto"], case_sensitive=False),
    help="Device to assume for the training-time estimate",
)
def scan(documents_path: str, device: Optional[str]):
    """Scan a directory for supported document files."""
    from .parsers import DocumentParser
    from .utils import create_training_summary

    parser = DocumentParser()
    documents = parser.parse_directory(documents_path)

    click.echo(f"Found {len(documents)} supported documents:")

    for doc in documents:
        size_kb = doc["size"] / 1024
        # The prefix was a mis-encoded emoji that printed as garbage; it is
        # restored to the intended document glyph.
        click.echo(f" 📄 {doc['filename']} ({doc['extension']}, {size_kb:.1f} KB)")

    if documents:
        summary = create_training_summary(
            documents, device=None if device == "auto" else device
        )
        click.echo(
            f"\nTotal size: {summary['total_size_formatted']} "
            f"across {len(summary['file_types'])} file type(s)"
        )
        click.echo(
            f"Estimated training time (~small model): "
            f"{summary['estimated_training_time']}"
        )
        click.echo(
            "Note: rough estimate; 7B-class models are ~20-40x slower "
            "(QLoRA on CUDA recovers much of that)."
        )
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@cli.command()
def formats():
    """List supported document formats."""
    from .parsers import DocumentParser

    click.echo("Supported document formats:")

    # (label, description) rows, printed as a two-column table whose label
    # column is left-aligned to 15 characters.
    table = (
        (".md/.rst", "Markdown / reStructuredText"),
        (".txt", "Text files"),
        (".pdf", "PDF documents"),
        (".html", "HTML files"),
        (".docx", "Word documents"),
        (".pptx", "PowerPoint slides (text + notes)"),
        (".odt/.ods", "OpenDocument text / spreadsheet"),
        (".rtf", "Rich Text Format"),
        (".epub", "EPUB e-books"),
        (".csv", "CSV files"),
        (".json", "JSON files"),
        (".ipynb", "Jupyter notebooks (markdown + code)"),
        (".yaml/.yml", "YAML files"),
        (".xml", "XML files"),
        (".tex", "LaTeX files"),
        (
            "audio",
            "Speech-to-text (.wav, .mp3, .m4a, .flac, .aac, .ogg, ...)",
        ),
        (
            "source code",
            "Read as plaintext (.py, .js, .rs, .kt, .c/.cpp, .go, .dart, ...)",
        ),
        (".zip", "ZIP archives containing supported documents"),
        (".tar", "TAR archives containing supported documents"),
        (".tar.gz/.tgz", "Gzip-compressed TAR archives"),
        (".tar.bz2/.tbz2", "Bzip2-compressed TAR archives"),
        (".tar.xz/.txz", "XZ-compressed TAR archives"),
        (".7z", "7-Zip archives"),
        (".gz/.bz2/.xz", "Single-file compressed documents"),
    )
    for label, blurb in table:
        click.echo(" {:<15} {}".format(label, blurb))

    # The full list of plaintext source extensions comes from the parser
    # itself, so it stays in sync with what parse_directory accepts.
    plaintext_exts = sorted(DocumentParser.CODE_EXTENSIONS)
    click.echo(
        "\nSource-code extensions read as plaintext: " + ", ".join(plaintext_exts)
    )
    click.echo("\nNote: Archive formats (.zip, .tar, .7z, etc.) will extract and parse")
    click.echo(" any supported document files they contain.")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _r2_troubleshooting_tips(message: str) -> list:
    """Return hint lines for a failed R2 conversion, chosen from its message.

    The error is classified by substring matching on the exception text,
    which is the only signal available at this point. An unrecognised
    message yields an empty list, meaning no hints are printed.
    """
    if "No files found" in message:
        return [
            " • Check that your bucket contains files",
            " • Verify the folder prefix (if specified) is correct",
            " • Ensure files are in supported formats (.md, .txt, .pdf, etc.)",
        ]
    if "Bucket" in message and "does not exist" in message:
        return [
            " • Check the bucket name is correct",
            " • Verify the bucket exists in your R2 account",
            " • Ensure your credentials have access to this bucket",
        ]
    if "endpoint" in message.lower():
        return [
            " • R2 endpoint format: https://your-account-id.r2.cloudflarestorage.com",
            " • Do NOT include the bucket name in the endpoint URL",
            " • Get your endpoint from Cloudflare dashboard > R2 > Manage R2 API tokens",
        ]
    return []


# Credential resolution order:
#   1. explicit options, or the envvars click reads for them;
#   2. after --env-file has been loaded, R2_* and then legacy AWS_*
#      environment variables.
# A missing key pair or endpoint aborts before any conversion starts.
# Conversion failures print a single "Error:" line, plus troubleshooting
# hints for recognised messages, and then raise click.Abort chained to the
# original exception.
@cli.command()
@click.argument("bucket_name")
@click.option(
    "--folder-prefix",
    "-f",
    help="Optional folder prefix within the bucket",
)
@click.option(
    "--output",
    "-o",
    default="lora_adapter.json",
    help="Output path for the LoRA adapter",
)
@click.option(
    "--model",
    "-m",
    default="microsoft/DialoGPT-small",
    help="Base model name for fine-tuning",
)
@click.option(
    "--max-length", default=512, help="Maximum sequence length for tokenization"
)
@click.option("--batch-size", default=4, help="Training batch size")
@click.option("--epochs", default=3, help="Number of training epochs")
@click.option(
    "--max-steps",
    default=None,
    type=int,
    help="Maximum number of training steps (overrides epochs if set)",
)
@click.option("--learning-rate", default=5e-4, help="Learning rate for training")
@click.option(
    "--lora-r",
    default=8,
    help="LoRA rank parameter (Cloudflare Workers AI supports up to 32)",
)
@click.option("--lora-alpha", default=16, help="LoRA alpha parameter")
@click.option("--lora-dropout", default=0.1, help="LoRA dropout rate")
@click.option(
    "--gradient-accumulation-steps",
    default=1,
    type=int,
    help="Accumulate grads to emulate a larger batch on low-memory machines",
)
@click.option(
    "--gradient-checkpointing/--no-gradient-checkpointing",
    default=True,
    help="Trade compute for memory (helps on low-RAM machines)",
)
@click.option(
    "--load-in-4bit",
    is_flag=True,
    default=False,
    help="Use 4-bit QLoRA (requires bitsandbytes + CUDA)",
)
@click.option(
    "--device",
    default=None,
    type=click.Choice(["cuda", "mps", "cpu", "auto"], case_sensitive=False),
    help="Device to use for training (auto-detects by default)",
)
@click.option(
    "--r2-access-key-id",
    envvar="R2_ACCESS_KEY_ID",
    help="R2 access key ID (can also be set via R2_ACCESS_KEY_ID env var)",
)
@click.option(
    "--r2-secret-access-key",
    envvar="R2_SECRET_ACCESS_KEY",
    help="R2 secret access key (can also be set via R2_SECRET_ACCESS_KEY env var)",
)
@click.option(
    "--endpoint-url",
    envvar="R2_ENDPOINT_URL",
    help="R2 endpoint URL (can also be set via R2_ENDPOINT_URL env var)",
)
@click.option(
    "--region-name",
    default="auto",
    help="Region name (default: auto for R2)",
)
@click.option(
    "--env-file",
    help="Path to .env file containing R2 credentials",
)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def convert_r2(
    bucket_name: str,
    folder_prefix: Optional[str],
    output: str,
    model: str,
    max_length: int,
    batch_size: int,
    epochs: int,
    max_steps: Optional[int],
    learning_rate: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    gradient_accumulation_steps: int,
    gradient_checkpointing: bool,
    load_in_4bit: bool,
    device: Optional[str],
    r2_access_key_id: Optional[str],
    r2_secret_access_key: Optional[str],
    endpoint_url: Optional[str],
    region_name: str,
    env_file: Optional[str],
    verbose: bool,
):
    """Convert documents from an R2 bucket to LoRA adapter format."""
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Load .env file if provided
    if env_file:
        from .utils import load_env_file

        try:
            load_env_file(env_file)
            click.echo(f"Loaded credentials from: {env_file}")
        except FileNotFoundError as e:
            click.echo(f"❌ Error: {e}", err=True)
            raise click.Abort() from e

    # If credentials were not provided directly, fall back to the environment,
    # which may just have been populated from the .env file.
    if not r2_access_key_id:
        r2_access_key_id = os.getenv("R2_ACCESS_KEY_ID") or os.getenv(
            "AWS_ACCESS_KEY_ID"
        )
    if not r2_secret_access_key:
        r2_secret_access_key = os.getenv("R2_SECRET_ACCESS_KEY") or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
    if not endpoint_url:
        endpoint_url = os.getenv("R2_ENDPOINT_URL")

    # Validate required credentials
    if not r2_access_key_id or not r2_secret_access_key:
        click.echo(
            "❌ Error: R2 credentials are required. Provide them via:\n"
            " --r2-access-key-id and --r2-secret-access-key options, or\n"
            " R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY environment variables, or\n"
            " AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables (legacy), or\n"
            " --env-file option pointing to a .env file",
            err=True,
        )
        raise click.Abort()

    if not endpoint_url:
        click.echo(
            "❌ Error: R2 endpoint URL is required. Provide it via:\n"
            " --endpoint-url option, or\n"
            " R2_ENDPOINT_URL environment variable, or\n"
            " --env-file option pointing to a .env file\n"
            " Example: https://your-account.r2.cloudflarestorage.com",
            err=True,
        )
        raise click.Abort()

    try:
        click.echo(f"Converting documents from R2 bucket: {bucket_name}")
        if folder_prefix:
            click.echo(f"Folder prefix: {folder_prefix}")
        click.echo(f"Output will be saved to: {output}")

        adapter_path = convert_from_r2(
            bucket_name=bucket_name,
            folder_prefix=folder_prefix,
            output_path=output,
            model_name=model,
            max_length=max_length,
            batch_size=batch_size,
            num_epochs=epochs if max_steps is None else None,
            max_steps=max_steps,
            learning_rate=learning_rate,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            load_in_4bit=load_in_4bit,
            device=None if device == "auto" else device,
            aws_access_key_id=r2_access_key_id,
            aws_secret_access_key=r2_secret_access_key,
            endpoint_url=endpoint_url,
            region_name=region_name,
            env_file=env_file,
        )

        click.echo(f"✅ LoRA adapter successfully created at: {adapter_path}")

    except Exception as e:
        message = str(e)
        # The full traceback stays available via --verbose.
        logging.getLogger(__name__).debug("R2 conversion failed", exc_info=True)
        click.echo(f"❌ Error: {message}", err=True)
        tips = _r2_troubleshooting_tips(message)
        if tips:
            click.echo("\n💡 Troubleshooting tips:", err=True)
            for tip in tips:
                click.echo(tip, err=True)
        raise click.Abort() from e
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# Thin CLI wrapper around doc2lora.deploy.deploy_adapter. It lower-cases
# --backend before passing it on. --account-id and --api-token may come from
# CLOUDFLARE_ACCOUNT_ID / CLOUDFLARE_API_TOKEN through click's envvar
# support. Any failure is reported as a single "Error:" line on stderr and
# re-raised as click.Abort chained to the original exception; the traceback
# is logged at DEBUG level (visible with --verbose).
@cli.command()
@click.argument("adapter_path", type=click.Path(exists=True))
@click.argument("finetune_name")
@click.option(
    "--cf-model",
    default=None,
    help="Lora-capable base model endpoint (derived from model_type if omitted)",
)
@click.option(
    "--backend",
    type=click.Choice(["wrangler", "rest"], case_sensitive=False),
    default="wrangler",
    help="Upload via the wrangler CLI or the Cloudflare REST API",
)
@click.option(
    "--account-id",
    envvar="CLOUDFLARE_ACCOUNT_ID",
    help="Cloudflare account id (REST backend; or CLOUDFLARE_ACCOUNT_ID)",
)
@click.option(
    "--api-token",
    envvar="CLOUDFLARE_API_TOKEN",
    help="Cloudflare API token (REST backend; or CLOUDFLARE_API_TOKEN)",
)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def deploy(
    adapter_path: str,
    finetune_name: str,
    cf_model: Optional[str],
    backend: str,
    account_id: Optional[str],
    api_token: Optional[str],
    verbose: bool,
):
    """Upload a trained adapter to Cloudflare Workers AI."""
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    from .deploy import deploy_adapter

    try:
        result = deploy_adapter(
            adapter_path=adapter_path,
            finetune_name=finetune_name,
            cf_model=cf_model,
            backend=backend.lower(),
            account_id=account_id,
            api_token=api_token,
        )
        click.echo(f"✅ Deployed: {result}")
        click.echo(
            f' Reference it at inference with the lora param: "{finetune_name}"'
        )
    except Exception as e:
        logging.getLogger(__name__).debug("Deployment failed", exc_info=True)
        click.echo(f"❌ Error: {e}", err=True)
        raise click.Abort() from e
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def main():
    """Console-script entry point: hand control to the ``cli`` click group."""
    # Called with no arguments, so click runs in its default standalone mode:
    # it parses sys.argv and exits the process itself.
    cli()


if __name__ == "__main__":
    # Allows running this module directly as well as via the console script.
    main()
|