doc2lora 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
doc2lora/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """doc2lora: A library for fine-tuning LLMs using LoRA by using a folder of documents as input."""
2
+
3
+ from .core import convert, convert_from_r2
4
+
5
# Single source of truth for the package version; the CLI imports it for
# ``--version``. NOTE(review): the original note says pyproject reads it
# as well — confirm, and keep the literal ``__version__ = "..."`` form so
# regex-based version readers continue to match it.
__version__ = "1.0.0"

# Public API exported by ``from doc2lora import *``.
__all__ = [
    "convert",
    "convert_from_r2",
    "__version__",
]
doc2lora/cli.py ADDED
@@ -0,0 +1,523 @@
1
+ """Command-line interface for doc2lora."""
2
+
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ import click
9
+
10
+ from . import __version__
11
+ from .core import convert, convert_from_r2
12
+
13
# Configure root logging for CLI runs: INFO level, timestamped records.
# Commands that accept --verbose lower the root level to DEBUG at runtime.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
17
+
18
+
19
@click.group()
@click.version_option(version=__version__)
def cli():
    """doc2lora: Convert documents to LoRA adapters for LLM fine-tuning."""
    # The group callback does no work of its own; subcommands register on it
    # with @cli.command(). The docstring doubles as the group's --help text.
24
+
25
+
26
# `convert` subcommand: train a LoRA adapter from a local folder of documents.
#
# The command name is pinned explicitly. Without `name=`, click < 8.2 derives
# "convert-cmd" from the function name, while click >= 8.2 strips the `_cmd`
# suffix and yields "convert". The function itself is not called `convert`
# because that would shadow the library function imported from `.core`.
#
# Count-like options use IntRange(min=1). For example, `--max-steps 0` would
# otherwise suppress epochs (see the num_epochs argument below) and leave
# zero training steps.
#
# Failures are reported as a one-line error and the command aborts. The
# original exception is chained, and its traceback is logged at DEBUG, which
# --verbose enables.
@cli.command(name="convert")
@click.argument(
    "documents_path", type=click.Path(exists=True, file_okay=False, dir_okay=True)
)
@click.option(
    "--output",
    "-o",
    default="lora_adapter.json",
    help="Output path for the LoRA adapter",
)
@click.option(
    "--model",
    "-m",
    default="microsoft/DialoGPT-small",
    help="Base model name for fine-tuning",
)
@click.option(
    "--max-length",
    default=512,
    type=click.IntRange(min=1),
    help="Maximum sequence length for tokenization",
)
@click.option(
    "--batch-size", default=4, type=click.IntRange(min=1), help="Training batch size"
)
@click.option(
    "--epochs",
    default=3,
    type=click.IntRange(min=1),
    help="Number of training epochs",
)
@click.option(
    "--max-steps",
    default=None,
    type=click.IntRange(min=1),
    help="Maximum number of training steps (overrides epochs if set)",
)
@click.option("--learning-rate", default=5e-4, help="Learning rate for training")
@click.option(
    "--lora-r",
    default=8,
    help="LoRA rank parameter (Cloudflare Workers AI supports up to 32)",
)
@click.option("--lora-alpha", default=16, help="LoRA alpha parameter")
@click.option("--lora-dropout", default=0.1, help="LoRA dropout rate")
@click.option(
    "--gradient-accumulation-steps",
    default=1,
    type=click.IntRange(min=1),
    help="Accumulate grads to emulate a larger batch on low-memory machines",
)
@click.option(
    "--gradient-checkpointing/--no-gradient-checkpointing",
    default=True,
    help="Trade compute for memory (helps on low-RAM machines)",
)
@click.option(
    "--load-in-4bit",
    is_flag=True,
    default=False,
    help="Use 4-bit QLoRA (requires bitsandbytes + CUDA)",
)
@click.option(
    "--device",
    default=None,
    type=click.Choice(["cuda", "mps", "cpu", "auto"], case_sensitive=False),
    help="Device to use for training (auto-detects by default)",
)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def convert_cmd(
    documents_path: str,
    output: str,
    model: str,
    max_length: int,
    batch_size: int,
    epochs: int,
    max_steps: Optional[int],
    learning_rate: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    gradient_accumulation_steps: int,
    gradient_checkpointing: bool,
    load_in_4bit: bool,
    device: Optional[str],
    verbose: bool,
):
    """Convert a folder of documents to LoRA adapter format."""
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    click.echo(f"Converting documents from: {documents_path}")
    click.echo(f"Output will be saved to: {output}")

    try:
        adapter_path = convert(
            documents_path=documents_path,
            output_path=output,
            model_name=model,
            max_length=max_length,
            batch_size=batch_size,
            # --max-steps takes precedence: epochs are not forwarded when set.
            num_epochs=epochs if max_steps is None else None,
            max_steps=max_steps,
            learning_rate=learning_rate,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            load_in_4bit=load_in_4bit,
            # "auto" is forwarded as None, the same as omitting --device.
            device=None if device == "auto" else device,
        )
    except Exception as e:
        # Keep the full traceback reachable: it is emitted only with --verbose.
        logging.getLogger(__name__).debug("convert failed", exc_info=True)
        click.echo(f"āŒ Error: {e}", err=True)
        raise click.Abort() from e

    click.echo(f"āœ… LoRA adapter successfully created at: {adapter_path}")
134
+
135
+
136
# `scan` subcommand: list the documents the parser would pick up in a folder,
# then print a size / file-type summary and a rough training-time estimate
# computed by create_training_summary.
@cli.command()
@click.argument(
    "documents_path", type=click.Path(exists=True, file_okay=False, dir_okay=True)
)
@click.option(
    "--device",
    default=None,
    type=click.Choice(["cuda", "mps", "cpu", "auto"], case_sensitive=False),
    help="Device to assume for the training-time estimate",
)
def scan(documents_path: str, device: Optional[str]):
    """Scan a directory for supported document files."""
    from .parsers import DocumentParser
    from .utils import create_training_summary

    found = DocumentParser().parse_directory(documents_path)

    click.echo(f"Found {len(found)} supported documents:")
    for entry in found:
        kib = entry["size"] / 1024
        click.echo(f" šŸ“„ {entry['filename']} ({entry['extension']}, {kib:.1f} KB)")

    # Nothing to summarise for an empty folder.
    if not found:
        return

    # "auto" is forwarded as None, the same as omitting --device.
    estimate_device = None if device == "auto" else device
    summary = create_training_summary(found, device=estimate_device)

    click.echo(
        f"\nTotal size: {summary['total_size_formatted']} "
        f"across {len(summary['file_types'])} file type(s)"
    )
    click.echo(
        "Estimated training time (~small model): "
        f"{summary['estimated_training_time']}"
    )
    click.echo(
        "Note: rough estimate; 7B-class models are ~20-40x slower "
        "(QLoRA on CUDA recovers much of that)."
    )
176
+
177
+
178
# `formats` subcommand: print a static table of supported inputs, then the
# source-code extensions taken from DocumentParser.CODE_EXTENSIONS.
# NOTE(review): the table below is maintained by hand, separately from the
# parser — keep the two in sync when adding formats.
@cli.command()
def formats():
    """List supported document formats."""
    from .parsers import DocumentParser

    click.echo("Supported document formats:")

    # (label, description) rows; labels are left-aligned in a 15-char column.
    rows = (
        (".md/.rst", "Markdown / reStructuredText"),
        (".txt", "Text files"),
        (".pdf", "PDF documents"),
        (".html", "HTML files"),
        (".docx", "Word documents"),
        (".pptx", "PowerPoint slides (text + notes)"),
        (".odt/.ods", "OpenDocument text / spreadsheet"),
        (".rtf", "Rich Text Format"),
        (".epub", "EPUB e-books"),
        (".csv", "CSV files"),
        (".json", "JSON files"),
        (".ipynb", "Jupyter notebooks (markdown + code)"),
        (".yaml/.yml", "YAML files"),
        (".xml", "XML files"),
        (".tex", "LaTeX files"),
        ("audio", "Speech-to-text (.wav, .mp3, .m4a, .flac, .aac, .ogg, ...)"),
        ("source code", "Read as plaintext (.py, .js, .rs, .kt, .c/.cpp, .go, .dart, ...)"),
        (".zip", "ZIP archives containing supported documents"),
        (".tar", "TAR archives containing supported documents"),
        (".tar.gz/.tgz", "Gzip-compressed TAR archives"),
        (".tar.bz2/.tbz2", "Bzip2-compressed TAR archives"),
        (".tar.xz/.txz", "XZ-compressed TAR archives"),
        (".7z", "7-Zip archives"),
        (".gz/.bz2/.xz", "Single-file compressed documents"),
    )
    for label, blurb in rows:
        click.echo(f" {label:<15} {blurb}")

    code_ext_list = ", ".join(sorted(DocumentParser.CODE_EXTENSIONS))
    click.echo(f"\nSource-code extensions read as plaintext: {code_ext_list}")
    click.echo("\nNote: Archive formats (.zip, .tar, .7z, etc.) will extract and parse")
    click.echo(" any supported document files they contain.")
225
+
226
+
227
def _r2_troubleshooting_tips(message: str) -> "list[str]":
    """Return hint lines to print after a failed R2 conversion.

    The failure is classified by substring matches on the exception message,
    checked in priority order. An unrecognised message yields an empty list,
    in which case no tips header is printed.
    """
    if "No files found" in message:
        return [
            " • Check that your bucket contains files",
            " • Verify the folder prefix (if specified) is correct",
            " • Ensure files are in supported formats (.md, .txt, .pdf, etc.)",
        ]
    if "Bucket" in message and "does not exist" in message:
        return [
            " • Check the bucket name is correct",
            " • Verify the bucket exists in your R2 account",
            " • Ensure your credentials have access to this bucket",
        ]
    if "endpoint" in message.lower():
        return [
            " • R2 endpoint format: https://your-account-id.r2.cloudflarestorage.com",
            " • Do NOT include the bucket name in the endpoint URL",
            " • Get your endpoint from Cloudflare dashboard > R2 > Manage R2 API tokens",
        ]
    return []


# `convert-r2` subcommand: train a LoRA adapter from documents stored in an
# R2 bucket.
#
# Credentials and endpoint are resolved in this order:
#   1. explicit options;
#   2. the R2_* environment variables (read by click at parse time);
#   3. after an optional --env-file is loaded, an os.getenv re-check,
#      with AWS_* names accepted as a legacy fallback for the keys.
#
# Count-like options use IntRange(min=1), as in `convert`. Failures print the
# error plus pattern-specific tips, abort with the original exception chained,
# and log the traceback at DEBUG (shown with --verbose).
@cli.command()
@click.argument("bucket_name")
@click.option(
    "--folder-prefix",
    "-f",
    help="Optional folder prefix within the bucket",
)
@click.option(
    "--output",
    "-o",
    default="lora_adapter.json",
    help="Output path for the LoRA adapter",
)
@click.option(
    "--model",
    "-m",
    default="microsoft/DialoGPT-small",
    help="Base model name for fine-tuning",
)
@click.option(
    "--max-length",
    default=512,
    type=click.IntRange(min=1),
    help="Maximum sequence length for tokenization",
)
@click.option(
    "--batch-size", default=4, type=click.IntRange(min=1), help="Training batch size"
)
@click.option(
    "--epochs",
    default=3,
    type=click.IntRange(min=1),
    help="Number of training epochs",
)
@click.option(
    "--max-steps",
    default=None,
    type=click.IntRange(min=1),
    help="Maximum number of training steps (overrides epochs if set)",
)
@click.option("--learning-rate", default=5e-4, help="Learning rate for training")
@click.option(
    "--lora-r",
    default=8,
    help="LoRA rank parameter (Cloudflare Workers AI supports up to 32)",
)
@click.option("--lora-alpha", default=16, help="LoRA alpha parameter")
@click.option("--lora-dropout", default=0.1, help="LoRA dropout rate")
@click.option(
    "--gradient-accumulation-steps",
    default=1,
    type=click.IntRange(min=1),
    help="Accumulate grads to emulate a larger batch on low-memory machines",
)
@click.option(
    "--gradient-checkpointing/--no-gradient-checkpointing",
    default=True,
    help="Trade compute for memory (helps on low-RAM machines)",
)
@click.option(
    "--load-in-4bit",
    is_flag=True,
    default=False,
    help="Use 4-bit QLoRA (requires bitsandbytes + CUDA)",
)
@click.option(
    "--device",
    default=None,
    type=click.Choice(["cuda", "mps", "cpu", "auto"], case_sensitive=False),
    help="Device to use for training (auto-detects by default)",
)
@click.option(
    "--r2-access-key-id",
    envvar="R2_ACCESS_KEY_ID",
    help="R2 access key ID (can also be set via R2_ACCESS_KEY_ID env var)",
)
@click.option(
    "--r2-secret-access-key",
    envvar="R2_SECRET_ACCESS_KEY",
    help="R2 secret access key (can also be set via R2_SECRET_ACCESS_KEY env var)",
)
@click.option(
    "--endpoint-url",
    envvar="R2_ENDPOINT_URL",
    help="R2 endpoint URL (can also be set via R2_ENDPOINT_URL env var)",
)
@click.option(
    "--region-name",
    default="auto",
    help="Region name (default: auto for R2)",
)
@click.option(
    "--env-file",
    help="Path to .env file containing R2 credentials",
)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def convert_r2(
    bucket_name: str,
    folder_prefix: Optional[str],
    output: str,
    model: str,
    max_length: int,
    batch_size: int,
    epochs: int,
    max_steps: Optional[int],
    learning_rate: float,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    gradient_accumulation_steps: int,
    gradient_checkpointing: bool,
    load_in_4bit: bool,
    device: Optional[str],
    r2_access_key_id: Optional[str],
    r2_secret_access_key: Optional[str],
    endpoint_url: Optional[str],
    region_name: str,
    env_file: Optional[str],
    verbose: bool,
):
    """Convert documents from an R2 bucket to LoRA adapter format."""
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Load .env file if provided. click resolved the envvar= options before
    # this runs, so its values are only seen via the os.getenv fallbacks
    # below (assumes load_env_file populates os.environ — TODO confirm).
    if env_file:
        from .utils import load_env_file

        try:
            load_env_file(env_file)
        except FileNotFoundError as e:
            click.echo(f"āŒ Error: {e}", err=True)
            raise click.Abort() from e
        click.echo(f"Loaded credentials from: {env_file}")

    # If credentials not provided directly, try to get from environment
    # (which may have been loaded from .env file); AWS_* names are legacy.
    if not r2_access_key_id:
        r2_access_key_id = os.getenv("R2_ACCESS_KEY_ID") or os.getenv(
            "AWS_ACCESS_KEY_ID"
        )
    if not r2_secret_access_key:
        r2_secret_access_key = os.getenv("R2_SECRET_ACCESS_KEY") or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
    if not endpoint_url:
        endpoint_url = os.getenv("R2_ENDPOINT_URL")

    # Validate required credentials before doing any network work.
    if not r2_access_key_id or not r2_secret_access_key:
        click.echo(
            "āŒ Error: R2 credentials are required. Provide them via:\n"
            " --r2-access-key-id and --r2-secret-access-key options, or\n"
            " R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY environment variables, or\n"
            " AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables (legacy), or\n"
            " --env-file option pointing to a .env file",
            err=True,
        )
        raise click.Abort()

    if not endpoint_url:
        click.echo(
            "āŒ Error: R2 endpoint URL is required. Provide it via:\n"
            " --endpoint-url option, or\n"
            " R2_ENDPOINT_URL environment variable, or\n"
            " --env-file option pointing to a .env file\n"
            " Example: https://your-account.r2.cloudflarestorage.com",
            err=True,
        )
        raise click.Abort()

    click.echo(f"Converting documents from R2 bucket: {bucket_name}")
    if folder_prefix:
        click.echo(f"Folder prefix: {folder_prefix}")
    click.echo(f"Output will be saved to: {output}")

    try:
        adapter_path = convert_from_r2(
            bucket_name=bucket_name,
            folder_prefix=folder_prefix,
            output_path=output,
            model_name=model,
            max_length=max_length,
            batch_size=batch_size,
            # --max-steps takes precedence: epochs are not forwarded when set.
            num_epochs=epochs if max_steps is None else None,
            max_steps=max_steps,
            learning_rate=learning_rate,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_checkpointing=gradient_checkpointing,
            load_in_4bit=load_in_4bit,
            # "auto" is forwarded as None, the same as omitting --device.
            device=None if device == "auto" else device,
            aws_access_key_id=r2_access_key_id,
            aws_secret_access_key=r2_secret_access_key,
            endpoint_url=endpoint_url,
            region_name=region_name,
            env_file=env_file,
        )
    except Exception as e:
        # Keep the full traceback reachable: it is emitted only with --verbose.
        logging.getLogger(__name__).debug("R2 conversion failed", exc_info=True)
        click.echo(f"āŒ Error: {e}", err=True)
        tips = _r2_troubleshooting_tips(str(e))
        if tips:
            click.echo("\nšŸ’” Troubleshooting tips:", err=True)
            for tip in tips:
                click.echo(tip, err=True)
        raise click.Abort() from e

    click.echo(f"āœ… LoRA adapter successfully created at: {adapter_path}")
457
+
458
+
459
# `deploy` subcommand: upload a trained adapter to Cloudflare Workers AI,
# either through the wrangler CLI or the REST API.
#
# Failures print a one-line error and abort with the original exception
# chained. The traceback is logged at DEBUG, which --verbose enables; the
# original swallowed it even in verbose mode.
@cli.command()
@click.argument("adapter_path", type=click.Path(exists=True))
@click.argument("finetune_name")
@click.option(
    "--cf-model",
    default=None,
    help="Lora-capable base model endpoint (derived from model_type if omitted)",
)
@click.option(
    "--backend",
    type=click.Choice(["wrangler", "rest"], case_sensitive=False),
    default="wrangler",
    help="Upload via the wrangler CLI or the Cloudflare REST API",
)
@click.option(
    "--account-id",
    envvar="CLOUDFLARE_ACCOUNT_ID",
    help="Cloudflare account id (REST backend; or CLOUDFLARE_ACCOUNT_ID)",
)
@click.option(
    "--api-token",
    envvar="CLOUDFLARE_API_TOKEN",
    help="Cloudflare API token (REST backend; or CLOUDFLARE_API_TOKEN)",
)
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def deploy(
    adapter_path: str,
    finetune_name: str,
    cf_model: Optional[str],
    backend: str,
    account_id: Optional[str],
    api_token: Optional[str],
    verbose: bool,
):
    """Upload a trained adapter to Cloudflare Workers AI."""
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Imported lazily so other commands do not pay for the deploy module.
    from .deploy import deploy_adapter

    try:
        result = deploy_adapter(
            adapter_path=adapter_path,
            finetune_name=finetune_name,
            cf_model=cf_model,
            backend=backend.lower(),
            account_id=account_id,
            api_token=api_token,
        )
    except Exception as e:
        # Keep the full traceback reachable: it is emitted only with --verbose.
        logging.getLogger(__name__).debug("Deployment failed", exc_info=True)
        click.echo(f"āŒ Error: {e}", err=True)
        raise click.Abort() from e

    click.echo(f"āœ… Deployed: {result}")
    click.echo(
        f' Reference it at inference with the lora param: "{finetune_name}"'
    )
515
+
516
+
517
def main():
    """Main entry point for the CLI."""
    # Delegate straight to the click group, which parses sys.argv itself.
    return cli()


if __name__ == "__main__":
    main()