aigroup-workflow 2.2.0 → 2.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude/commands/fix-build.md +10 -5
- package/.claude/commands/init-project.md +13 -8
- package/.claude/commands/plan.md +15 -8
- package/.claude/commands/review.md +12 -6
- package/.claude/commands/tdd.md +11 -5
- package/.claude/commands/workflow-start.md +20 -11
- package/.claude/settings.json +28 -0
- package/.codex/agents/architect.toml +207 -0
- package/.codex/agents/build-error-resolver.toml +110 -0
- package/.codex/agents/code-reviewer.toml +233 -0
- package/.codex/agents/doc-updater.toml +103 -0
- package/.codex/agents/e2e-runner.toml +103 -0
- package/.codex/agents/get-current-datetime.toml +23 -0
- package/.codex/agents/init-architect.toml +181 -0
- package/.codex/agents/planner.toml +208 -0
- package/.codex/agents/refactor-cleaner.toml +81 -0
- package/.codex/agents/rust-reviewer.toml +90 -0
- package/.codex/agents/security-reviewer.toml +104 -0
- package/.codex/agents/tdd-guide.toml +87 -0
- package/AGENTS.md +2 -2
- package/CLAUDE.md +23 -1
- package/LICENSE +20 -20
- package/README.md +333 -333
- package/agents/a11y-architect.md +141 -141
- package/agents/architect.md +211 -211
- package/agents/build-error-resolver.md +114 -114
- package/agents/chief-of-staff.md +151 -151
- package/agents/code-architect.md +71 -71
- package/agents/code-explorer.md +69 -69
- package/agents/code-reviewer.md +237 -237
- package/agents/code-simplifier.md +47 -47
- package/agents/comment-analyzer.md +45 -45
- package/agents/conversation-analyzer.md +52 -52
- package/agents/cpp-build-resolver.md +90 -90
- package/agents/cpp-reviewer.md +72 -72
- package/agents/csharp-reviewer.md +101 -101
- package/agents/dart-build-resolver.md +201 -201
- package/agents/database-reviewer.md +91 -91
- package/agents/doc-updater.md +107 -107
- package/agents/docs-lookup.md +68 -68
- package/agents/e2e-runner.md +107 -107
- package/agents/flutter-reviewer.md +243 -243
- package/agents/gan-evaluator.md +209 -209
- package/agents/gan-generator.md +131 -131
- package/agents/gan-planner.md +99 -99
- package/agents/get-current-datetime.md +26 -26
- package/agents/go-build-resolver.md +94 -94
- package/agents/go-reviewer.md +76 -76
- package/agents/harness-optimizer.md +35 -35
- package/agents/healthcare-reviewer.md +83 -83
- package/agents/java-build-resolver.md +153 -153
- package/agents/java-reviewer.md +92 -92
- package/agents/kotlin-build-resolver.md +118 -118
- package/agents/kotlin-reviewer.md +159 -159
- package/agents/loop-operator.md +36 -36
- package/agents/opensource-forker.md +198 -198
- package/agents/opensource-packager.md +249 -249
- package/agents/opensource-sanitizer.md +188 -188
- package/agents/performance-optimizer.md +446 -446
- package/agents/planner.md +212 -212
- package/agents/pr-test-analyzer.md +45 -45
- package/agents/python-reviewer.md +98 -98
- package/agents/pytorch-build-resolver.md +120 -120
- package/agents/refactor-cleaner.md +85 -85
- package/agents/rust-build-resolver.md +148 -148
- package/agents/rust-reviewer.md +94 -94
- package/agents/security-reviewer.md +108 -108
- package/agents/seo-specialist.md +59 -59
- package/agents/silent-failure-hunter.md +50 -50
- package/agents/tdd-guide.md +91 -91
- package/agents/type-design-analyzer.md +41 -41
- package/agents/typescript-reviewer.md +112 -112
- package/cli/commands/update.mjs +1 -1
- package/cli/utils/scaffold.mjs +53 -0
- package/docs/rules/agents.md +166 -50
- package/docs/rules/cpp/coding-style.md +44 -44
- package/docs/rules/cpp/hooks.md +39 -39
- package/docs/rules/cpp/patterns.md +51 -51
- package/docs/rules/cpp/security.md +51 -51
- package/docs/rules/cpp/testing.md +44 -44
- package/docs/rules/csharp/coding-style.md +72 -72
- package/docs/rules/csharp/hooks.md +25 -25
- package/docs/rules/csharp/patterns.md +50 -50
- package/docs/rules/csharp/security.md +58 -58
- package/docs/rules/csharp/testing.md +46 -46
- package/docs/rules/dart/coding-style.md +159 -159
- package/docs/rules/dart/hooks.md +66 -66
- package/docs/rules/dart/patterns.md +261 -261
- package/docs/rules/dart/security.md +135 -135
- package/docs/rules/dart/testing.md +215 -215
- package/docs/rules/golang/coding-style.md +32 -32
- package/docs/rules/golang/hooks.md +17 -17
- package/docs/rules/golang/patterns.md +45 -45
- package/docs/rules/golang/security.md +34 -34
- package/docs/rules/golang/testing.md +31 -31
- package/docs/rules/java/coding-style.md +114 -114
- package/docs/rules/java/hooks.md +18 -18
- package/docs/rules/java/patterns.md +146 -146
- package/docs/rules/java/security.md +100 -100
- package/docs/rules/java/testing.md +131 -131
- package/docs/rules/kotlin/coding-style.md +86 -86
- package/docs/rules/kotlin/hooks.md +17 -17
- package/docs/rules/kotlin/patterns.md +146 -146
- package/docs/rules/kotlin/security.md +82 -82
- package/docs/rules/kotlin/testing.md +128 -128
- package/docs/rules/perl/coding-style.md +46 -46
- package/docs/rules/perl/hooks.md +22 -22
- package/docs/rules/perl/patterns.md +76 -76
- package/docs/rules/perl/security.md +69 -69
- package/docs/rules/perl/testing.md +54 -54
- package/docs/rules/php/coding-style.md +40 -40
- package/docs/rules/php/hooks.md +24 -24
- package/docs/rules/php/patterns.md +33 -33
- package/docs/rules/php/security.md +37 -37
- package/docs/rules/php/testing.md +39 -39
- package/docs/rules/python/coding-style.md +42 -42
- package/docs/rules/python/hooks.md +19 -19
- package/docs/rules/python/patterns.md +39 -39
- package/docs/rules/python/security.md +30 -30
- package/docs/rules/python/testing.md +38 -38
- package/docs/rules/rust/coding-style.md +151 -151
- package/docs/rules/rust/hooks.md +16 -16
- package/docs/rules/rust/patterns.md +168 -168
- package/docs/rules/rust/security.md +141 -141
- package/docs/rules/rust/testing.md +154 -154
- package/docs/rules/swift/coding-style.md +47 -47
- package/docs/rules/swift/hooks.md +20 -20
- package/docs/rules/swift/patterns.md +66 -66
- package/docs/rules/swift/security.md +33 -33
- package/docs/rules/swift/testing.md +45 -45
- package/docs/rules/typescript/coding-style.md +199 -199
- package/docs/rules/typescript/hooks.md +22 -22
- package/docs/rules/typescript/patterns.md +52 -52
- package/docs/rules/typescript/security.md +28 -28
- package/docs/rules/typescript/testing.md +18 -18
- package/docs/rules/web/coding-style.md +96 -96
- package/docs/rules/web/design-quality.md +62 -62
- package/docs/rules/web/hooks.md +120 -120
- package/docs/rules/web/patterns.md +79 -79
- package/docs/rules/web/performance.md +64 -64
- package/docs/rules/web/security.md +57 -57
- package/docs/rules/web/testing.md +55 -55
- package/docs/templates/README.md +36 -36
- package/docs/templates/ai-project-final.md +124 -124
- package/docs/templates/ai-project.md +105 -105
- package/docs/templates/api.md +157 -157
- package/docs/templates/bug.md +62 -62
- package/docs/templates/code-review.md +87 -87
- package/docs/templates/generic.md +116 -116
- package/docs/templates/implementation-plan.md +1 -1
- package/docs/templates/meeting.md +68 -68
- package/docs/templates/prd.md +98 -98
- package/docs/templates/ui.md +134 -134
- package/docs/workflow-pipeline.md +11 -10
- package/package.json +40 -39
- package/scripts/hooks/checks/orchestration-artifacts.cjs +28 -23
- package/scripts/hooks/checks/workflow-state.cjs +4 -5
- package/scripts/orchestration/lib/orchestrator.cjs +344 -117
- package/scripts/orchestration/lib/validate.cjs +145 -0
- package/scripts/orchestration/session.cjs +88 -44
- package/skills/SUPERPOWERS-LICENSE +21 -21
- package/skills/ai-ml/fine-tuning-expert/SKILL.md +162 -162
- package/skills/ai-ml/fine-tuning-expert/references/dataset-preparation.md +540 -540
- package/skills/ai-ml/fine-tuning-expert/references/deployment-optimization.md +673 -673
- package/skills/ai-ml/fine-tuning-expert/references/evaluation-metrics.md +597 -597
- package/skills/ai-ml/fine-tuning-expert/references/hyperparameter-tuning.md +565 -565
- package/skills/ai-ml/fine-tuning-expert/references/lora-peft.md +347 -347
- package/skills/ai-ml/ml-pipeline/SKILL.md +159 -159
- package/skills/ai-ml/ml-pipeline/references/experiment-tracking.md +833 -833
- package/skills/ai-ml/ml-pipeline/references/feature-engineering.md +631 -631
- package/skills/ai-ml/ml-pipeline/references/model-validation.md +978 -978
- package/skills/ai-ml/ml-pipeline/references/pipeline-orchestration.md +907 -907
- package/skills/ai-ml/ml-pipeline/references/training-pipelines.md +782 -782
- package/skills/ai-ml/rag-architect/SKILL.md +194 -194
- package/skills/ai-ml/rag-architect/references/chunking-strategies.md +878 -878
- package/skills/ai-ml/rag-architect/references/embedding-models.md +561 -561
- package/skills/ai-ml/rag-architect/references/rag-evaluation.md +833 -833
- package/skills/ai-ml/rag-architect/references/retrieval-optimization.md +795 -795
- package/skills/ai-ml/rag-architect/references/vector-databases.md +589 -589
- package/skills/ai-ml/spark-engineer/SKILL.md +148 -148
- package/skills/ai-ml/spark-engineer/references/partitioning-caching.md +543 -543
- package/skills/ai-ml/spark-engineer/references/performance-tuning.md +544 -544
- package/skills/ai-ml/spark-engineer/references/rdd-operations.md +599 -599
- package/skills/ai-ml/spark-engineer/references/spark-sql-dataframes.md +474 -474
- package/skills/ai-ml/spark-engineer/references/streaming-patterns.md +786 -786
- package/skills/backend/api-designer/SKILL.md +217 -217
- package/skills/backend/api-designer/references/error-handling.md +541 -541
- package/skills/backend/api-designer/references/openapi.md +824 -824
- package/skills/backend/api-designer/references/pagination.md +494 -494
- package/skills/backend/api-designer/references/rest-patterns.md +335 -335
- package/skills/backend/api-designer/references/versioning.md +391 -391
- package/skills/backend/architecture-designer/SKILL.md +117 -117
- package/skills/backend/architecture-designer/references/adr-template.md +116 -116
- package/skills/backend/architecture-designer/references/architecture-patterns.md +111 -111
- package/skills/backend/architecture-designer/references/database-selection.md +102 -102
- package/skills/backend/architecture-designer/references/nfr-checklist.md +112 -112
- package/skills/backend/architecture-designer/references/system-design.md +100 -100
- package/skills/backend/code-documenter/SKILL.md +147 -147
- package/skills/backend/code-documenter/references/api-docs-fastapi-django.md +166 -166
- package/skills/backend/code-documenter/references/api-docs-nestjs-express.md +220 -220
- package/skills/backend/code-documenter/references/coverage-reports.md +125 -125
- package/skills/backend/code-documenter/references/documentation-systems.md +333 -333
- package/skills/backend/code-documenter/references/interactive-api-docs.md +531 -531
- package/skills/backend/code-documenter/references/python-docstrings.md +121 -121
- package/skills/backend/code-documenter/references/typescript-jsdoc.md +145 -145
- package/skills/backend/code-documenter/references/user-guides-tutorials.md +530 -530
- package/skills/backend/debugging-wizard/SKILL.md +105 -105
- package/skills/backend/debugging-wizard/references/common-patterns.md +132 -132
- package/skills/backend/debugging-wizard/references/debugging-tools.md +140 -140
- package/skills/backend/debugging-wizard/references/quick-fixes.md +177 -177
- package/skills/backend/debugging-wizard/references/strategies.md +142 -142
- package/skills/backend/debugging-wizard/references/systematic-debugging.md +367 -367
- package/skills/backend/feature-forge/SKILL.md +98 -98
- package/skills/backend/feature-forge/references/acceptance-criteria.md +104 -104
- package/skills/backend/feature-forge/references/ears-syntax.md +99 -99
- package/skills/backend/feature-forge/references/interview-questions.md +150 -150
- package/skills/backend/feature-forge/references/pre-discovery-subagents.md +54 -54
- package/skills/backend/feature-forge/references/specification-template.md +103 -103
- package/skills/backend/fullstack-guardian/SKILL.md +105 -105
- package/skills/backend/fullstack-guardian/references/api-design-standards.md +307 -307
- package/skills/backend/fullstack-guardian/references/architecture-decisions.md +350 -350
- package/skills/backend/fullstack-guardian/references/backend-patterns.md +237 -237
- package/skills/backend/fullstack-guardian/references/common-patterns.md +134 -134
- package/skills/backend/fullstack-guardian/references/deliverables-checklist.md +354 -354
- package/skills/backend/fullstack-guardian/references/design-template.md +91 -91
- package/skills/backend/fullstack-guardian/references/error-handling.md +135 -135
- package/skills/backend/fullstack-guardian/references/frontend-patterns.md +340 -340
- package/skills/backend/fullstack-guardian/references/integration-patterns.md +333 -333
- package/skills/backend/fullstack-guardian/references/security-checklist.md +106 -106
- package/skills/backend/graphql-architect/SKILL.md +146 -146
- package/skills/backend/graphql-architect/references/federation.md +418 -418
- package/skills/backend/graphql-architect/references/migration-from-rest.md +1141 -1141
- package/skills/backend/graphql-architect/references/resolvers.md +425 -425
- package/skills/backend/graphql-architect/references/schema-design.md +393 -393
- package/skills/backend/graphql-architect/references/security.md +569 -569
- package/skills/backend/graphql-architect/references/subscriptions.md +510 -510
- package/skills/backend/legacy-modernizer/SKILL.md +137 -137
- package/skills/backend/legacy-modernizer/references/legacy-testing.md +381 -381
- package/skills/backend/legacy-modernizer/references/migration-strategies.md +423 -423
- package/skills/backend/legacy-modernizer/references/refactoring-patterns.md +395 -395
- package/skills/backend/legacy-modernizer/references/strangler-fig-pattern.md +281 -281
- package/skills/backend/legacy-modernizer/references/system-assessment.md +487 -487
- package/skills/backend/microservices-architect/SKILL.md +164 -164
- package/skills/backend/microservices-architect/references/communication.md +499 -499
- package/skills/backend/microservices-architect/references/data.md +721 -721
- package/skills/backend/microservices-architect/references/decomposition.md +344 -344
- package/skills/backend/microservices-architect/references/observability.md +805 -805
- package/skills/backend/microservices-architect/references/patterns.md +603 -603
- package/skills/database/database-optimizer/SKILL.md +147 -147
- package/skills/database/database-optimizer/references/index-strategies.md +331 -331
- package/skills/database/database-optimizer/references/monitoring-analysis.md +501 -501
- package/skills/database/database-optimizer/references/mysql-tuning.md +452 -452
- package/skills/database/database-optimizer/references/postgresql-tuning.md +413 -413
- package/skills/database/database-optimizer/references/query-optimization.md +251 -251
- package/skills/database/postgres-pro/SKILL.md +152 -152
- package/skills/database/postgres-pro/references/extensions.md +404 -404
- package/skills/database/postgres-pro/references/jsonb.md +321 -321
- package/skills/database/postgres-pro/references/maintenance.md +481 -481
- package/skills/database/postgres-pro/references/performance.md +265 -265
- package/skills/database/postgres-pro/references/replication.md +446 -446
- package/skills/database/sql-pro/SKILL.md +129 -129
- package/skills/database/sql-pro/references/database-design.md +402 -402
- package/skills/database/sql-pro/references/dialect-differences.md +419 -419
- package/skills/database/sql-pro/references/optimization.md +384 -384
- package/skills/database/sql-pro/references/query-patterns.md +285 -285
- package/skills/database/sql-pro/references/window-functions.md +328 -328
- package/skills/dotnet/csharp-developer/SKILL.md +125 -125
- package/skills/dotnet/csharp-developer/references/aspnet-core.md +394 -394
- package/skills/dotnet/csharp-developer/references/blazor.md +553 -553
- package/skills/dotnet/csharp-developer/references/entity-framework.md +409 -409
- package/skills/dotnet/csharp-developer/references/modern-csharp.md +248 -248
- package/skills/dotnet/csharp-developer/references/performance.md +498 -498
- package/skills/dotnet/dotnet-core-expert/SKILL.md +138 -138
- package/skills/dotnet/dotnet-core-expert/references/authentication.md +546 -546
- package/skills/dotnet/dotnet-core-expert/references/clean-architecture.md +455 -455
- package/skills/dotnet/dotnet-core-expert/references/cloud-native.md +548 -548
- package/skills/dotnet/dotnet-core-expert/references/entity-framework.md +440 -440
- package/skills/dotnet/dotnet-core-expert/references/minimal-apis.md +319 -319
- package/skills/frontend/angular-architect/SKILL.md +152 -152
- package/skills/frontend/angular-architect/references/components.md +297 -297
- package/skills/frontend/angular-architect/references/ngrx.md +401 -401
- package/skills/frontend/angular-architect/references/routing.md +361 -361
- package/skills/frontend/angular-architect/references/rxjs.md +319 -319
- package/skills/frontend/angular-architect/references/testing.md +405 -405
- package/skills/frontend/design-commands/design.md +91 -91
- package/skills/frontend/design-commands/handoff.md +97 -97
- package/skills/frontend/design-commands/prototype.md +120 -120
- package/skills/frontend/design-commands/spec.md +160 -160
- package/skills/frontend/design-commands/style.md +78 -78
- package/skills/frontend/flutter-expert/SKILL.md +138 -138
- package/skills/frontend/flutter-expert/references/bloc-state.md +259 -259
- package/skills/frontend/flutter-expert/references/gorouter-navigation.md +119 -119
- package/skills/frontend/flutter-expert/references/performance.md +99 -99
- package/skills/frontend/flutter-expert/references/project-structure.md +118 -118
- package/skills/frontend/flutter-expert/references/riverpod-state.md +130 -130
- package/skills/frontend/flutter-expert/references/widget-patterns.md +123 -123
- package/skills/frontend/nextjs-developer/SKILL.md +143 -143
- package/skills/frontend/nextjs-developer/references/app-router.md +311 -311
- package/skills/frontend/nextjs-developer/references/data-fetching.md +482 -482
- package/skills/frontend/nextjs-developer/references/deployment.md +545 -545
- package/skills/frontend/nextjs-developer/references/server-actions.md +462 -462
- package/skills/frontend/nextjs-developer/references/server-components.md +384 -384
- package/skills/frontend/react-expert/SKILL.md +149 -149
- package/skills/frontend/react-expert/references/hooks-patterns.md +162 -162
- package/skills/frontend/react-expert/references/migration-class-to-modern.md +1119 -1119
- package/skills/frontend/react-expert/references/performance.md +168 -168
- package/skills/frontend/react-expert/references/react-19-features.md +174 -174
- package/skills/frontend/react-expert/references/server-components.md +143 -143
- package/skills/frontend/react-expert/references/state-management.md +171 -171
- package/skills/frontend/react-expert/references/testing-react.md +174 -174
- package/skills/frontend/react-native-expert/SKILL.md +185 -185
- package/skills/frontend/react-native-expert/references/expo-router.md +187 -187
- package/skills/frontend/react-native-expert/references/list-optimization.md +204 -204
- package/skills/frontend/react-native-expert/references/platform-handling.md +188 -188
- package/skills/frontend/react-native-expert/references/project-structure.md +171 -171
- package/skills/frontend/react-native-expert/references/storage-hooks.md +173 -173
- package/skills/frontend/senior-frontend/SKILL.md +477 -477
- package/skills/frontend/senior-frontend/references/frontend_best_practices.md +806 -806
- package/skills/frontend/senior-frontend/references/nextjs_optimization_guide.md +724 -724
- package/skills/frontend/senior-frontend/references/react_patterns.md +746 -746
- package/skills/frontend/senior-frontend/scripts/bundle_analyzer.py +407 -407
- package/skills/frontend/senior-frontend/scripts/component_generator.py +329 -329
- package/skills/frontend/senior-frontend/scripts/frontend_scaffolder.py +1005 -1005
- package/skills/frontend/ui-ux-pro-max/SKILL.md +386 -386
- package/skills/frontend/ui-ux-pro-max/data/charts.csv +26 -26
- package/skills/frontend/ui-ux-pro-max/data/colors.csv +97 -97
- package/skills/frontend/ui-ux-pro-max/data/icons.csv +101 -101
- package/skills/frontend/ui-ux-pro-max/data/landing.csv +31 -31
- package/skills/frontend/ui-ux-pro-max/data/products.csv +96 -96
- package/skills/frontend/ui-ux-pro-max/data/react-performance.csv +45 -45
- package/skills/frontend/ui-ux-pro-max/data/stacks/astro.csv +54 -54
- package/skills/frontend/ui-ux-pro-max/data/stacks/flutter.csv +53 -53
- package/skills/frontend/ui-ux-pro-max/data/stacks/html-tailwind.csv +56 -56
- package/skills/frontend/ui-ux-pro-max/data/stacks/jetpack-compose.csv +53 -53
- package/skills/frontend/ui-ux-pro-max/data/stacks/nextjs.csv +53 -53
- package/skills/frontend/ui-ux-pro-max/data/stacks/nuxt-ui.csv +51 -51
- package/skills/frontend/ui-ux-pro-max/data/stacks/nuxtjs.csv +59 -59
- package/skills/frontend/ui-ux-pro-max/data/stacks/react-native.csv +52 -52
- package/skills/frontend/ui-ux-pro-max/data/stacks/react.csv +54 -54
- package/skills/frontend/ui-ux-pro-max/data/stacks/shadcn.csv +61 -61
- package/skills/frontend/ui-ux-pro-max/data/stacks/svelte.csv +54 -54
- package/skills/frontend/ui-ux-pro-max/data/stacks/swiftui.csv +51 -51
- package/skills/frontend/ui-ux-pro-max/data/stacks/vue.csv +50 -50
- package/skills/frontend/ui-ux-pro-max/data/styles.csv +68 -68
- package/skills/frontend/ui-ux-pro-max/data/typography.csv +57 -57
- package/skills/frontend/ui-ux-pro-max/data/ui-reasoning.csv +101 -101
- package/skills/frontend/ui-ux-pro-max/data/ux-guidelines.csv +99 -99
- package/skills/frontend/ui-ux-pro-max/data/web-interface.csv +31 -31
- package/skills/frontend/ui-ux-pro-max/scripts/core.py +253 -253
- package/skills/frontend/ui-ux-pro-max/scripts/design_system.py +1067 -1067
- package/skills/frontend/ui-ux-pro-max/scripts/search.py +114 -114
- package/skills/frontend/vue-expert/SKILL.md +98 -98
- package/skills/frontend/vue-expert/references/build-tooling.md +480 -480
- package/skills/frontend/vue-expert/references/components.md +448 -448
- package/skills/frontend/vue-expert/references/composition-api.md +299 -299
- package/skills/frontend/vue-expert/references/mobile-hybrid.md +636 -636
- package/skills/frontend/vue-expert/references/nuxt.md +669 -669
- package/skills/frontend/vue-expert/references/state-management.md +449 -449
- package/skills/frontend/vue-expert/references/typescript.md +584 -584
- package/skills/frontend/vue-expert-js/SKILL.md +167 -167
- package/skills/frontend/vue-expert-js/references/component-architecture.md +219 -219
- package/skills/frontend/vue-expert-js/references/composables-patterns.md +183 -183
- package/skills/frontend/vue-expert-js/references/jsdoc-typing.md +535 -535
- package/skills/frontend/vue-expert-js/references/state-management.md +249 -249
- package/skills/frontend/vue-expert-js/references/testing-patterns.md +237 -237
- package/skills/go-rust-cpp/cpp-pro/SKILL.md +115 -115
- package/skills/go-rust-cpp/cpp-pro/references/build-tooling.md +440 -440
- package/skills/go-rust-cpp/cpp-pro/references/concurrency.md +437 -437
- package/skills/go-rust-cpp/cpp-pro/references/memory-performance.md +397 -397
- package/skills/go-rust-cpp/cpp-pro/references/modern-cpp.md +304 -304
- package/skills/go-rust-cpp/cpp-pro/references/templates.md +357 -357
- package/skills/go-rust-cpp/golang-pro/SKILL.md +122 -122
- package/skills/go-rust-cpp/golang-pro/references/concurrency.md +329 -329
- package/skills/go-rust-cpp/golang-pro/references/generics.md +442 -442
- package/skills/go-rust-cpp/golang-pro/references/interfaces.md +432 -432
- package/skills/go-rust-cpp/golang-pro/references/project-structure.md +477 -477
- package/skills/go-rust-cpp/golang-pro/references/testing.md +451 -451
- package/skills/go-rust-cpp/rust-engineer/SKILL.md +167 -167
- package/skills/go-rust-cpp/rust-engineer/references/async.md +458 -458
- package/skills/go-rust-cpp/rust-engineer/references/error-handling.md +334 -334
- package/skills/go-rust-cpp/rust-engineer/references/ownership.md +278 -278
- package/skills/go-rust-cpp/rust-engineer/references/testing.md +470 -470
- package/skills/go-rust-cpp/rust-engineer/references/traits.md +413 -413
- package/skills/infra/cli-developer/SKILL.md +113 -113
- package/skills/infra/cli-developer/references/design-patterns.md +221 -221
- package/skills/infra/cli-developer/references/go-cli.md +540 -540
- package/skills/infra/cli-developer/references/node-cli.md +383 -383
- package/skills/infra/cli-developer/references/python-cli.md +422 -422
- package/skills/infra/cli-developer/references/ux-patterns.md +448 -448
- package/skills/infra/cloud-architect/SKILL.md +216 -216
- package/skills/infra/cloud-architect/references/aws.md +394 -394
- package/skills/infra/cloud-architect/references/azure.md +562 -562
- package/skills/infra/cloud-architect/references/cost.md +582 -582
- package/skills/infra/cloud-architect/references/gcp.md +633 -633
- package/skills/infra/cloud-architect/references/multi-cloud.md +483 -483
- package/skills/infra/devops-engineer/SKILL.md +144 -144
- package/skills/infra/devops-engineer/references/deployment-strategies.md +241 -241
- package/skills/infra/devops-engineer/references/docker-patterns.md +113 -113
- package/skills/infra/devops-engineer/references/github-actions.md +139 -139
- package/skills/infra/devops-engineer/references/incident-response.md +331 -331
- package/skills/infra/devops-engineer/references/kubernetes.md +154 -154
- package/skills/infra/devops-engineer/references/platform-engineering.md +417 -417
- package/skills/infra/devops-engineer/references/release-automation.md +527 -527
- package/skills/infra/devops-engineer/references/terraform-iac.md +141 -141
- package/skills/infra/kubernetes-specialist/SKILL.md +241 -241
- package/skills/infra/kubernetes-specialist/references/configuration.md +452 -452
- package/skills/infra/kubernetes-specialist/references/cost-optimization.md +458 -458
- package/skills/infra/kubernetes-specialist/references/custom-operators.md +563 -563
- package/skills/infra/kubernetes-specialist/references/gitops.md +530 -530
- package/skills/infra/kubernetes-specialist/references/helm-charts.md +912 -912
- package/skills/infra/kubernetes-specialist/references/multi-cluster.md +507 -507
- package/skills/infra/kubernetes-specialist/references/networking.md +447 -447
- package/skills/infra/kubernetes-specialist/references/service-mesh.md +459 -459
- package/skills/infra/kubernetes-specialist/references/storage.md +535 -535
- package/skills/infra/kubernetes-specialist/references/troubleshooting.md +414 -414
- package/skills/infra/kubernetes-specialist/references/workloads.md +377 -377
- package/skills/infra/mcp-developer/SKILL.md +143 -143
- package/skills/infra/mcp-developer/references/protocol.md +244 -244
- package/skills/infra/mcp-developer/references/python-sdk.md +367 -367
- package/skills/infra/mcp-developer/references/resources.md +554 -554
- package/skills/infra/mcp-developer/references/tools.md +480 -480
- package/skills/infra/mcp-developer/references/typescript-sdk.md +350 -350
- package/skills/infra/monitoring-expert/SKILL.md +176 -176
- package/skills/infra/monitoring-expert/references/alerting-rules.md +141 -141
- package/skills/infra/monitoring-expert/references/application-profiling.md +331 -331
- package/skills/infra/monitoring-expert/references/capacity-planning.md +344 -344
- package/skills/infra/monitoring-expert/references/dashboards.md +126 -126
- package/skills/infra/monitoring-expert/references/opentelemetry.md +123 -123
- package/skills/infra/monitoring-expert/references/performance-testing.md +269 -269
- package/skills/infra/monitoring-expert/references/prometheus-metrics.md +136 -136
- package/skills/infra/monitoring-expert/references/structured-logging.md +142 -142
- package/skills/infra/sre-engineer/SKILL.md +181 -181
- package/skills/infra/sre-engineer/references/automation-toil.md +492 -492
- package/skills/infra/sre-engineer/references/error-budget-policy.md +334 -334
- package/skills/infra/sre-engineer/references/incident-chaos.md +576 -576
- package/skills/infra/sre-engineer/references/monitoring-alerting.md +424 -424
- package/skills/infra/sre-engineer/references/slo-sli-management.md +238 -238
- package/skills/infra/terraform-engineer/SKILL.md +143 -143
- package/skills/infra/terraform-engineer/references/best-practices.md +583 -583
- package/skills/infra/terraform-engineer/references/module-patterns.md +297 -297
- package/skills/infra/terraform-engineer/references/providers.md +452 -452
- package/skills/infra/terraform-engineer/references/state-management.md +371 -371
- package/skills/infra/terraform-engineer/references/testing.md +486 -486
- package/skills/infra/websocket-engineer/SKILL.md +168 -168
- package/skills/infra/websocket-engineer/references/alternatives.md +391 -391
- package/skills/infra/websocket-engineer/references/patterns.md +400 -400
- package/skills/infra/websocket-engineer/references/protocol.md +195 -195
- package/skills/infra/websocket-engineer/references/scaling.md +333 -333
- package/skills/infra/websocket-engineer/references/security.md +474 -474
- package/skills/java/java-architect/SKILL.md +132 -132
- package/skills/java/java-architect/references/jpa-optimization.md +393 -393
- package/skills/java/java-architect/references/reactive-webflux.md +356 -356
- package/skills/java/java-architect/references/spring-boot-setup.md +269 -269
- package/skills/java/java-architect/references/spring-security.md +445 -445
- package/skills/java/java-architect/references/testing-patterns.md +500 -500
- package/skills/java/kotlin-specialist/SKILL.md +147 -147
- package/skills/java/kotlin-specialist/references/android-compose.md +419 -419
- package/skills/java/kotlin-specialist/references/coroutines-flow.md +276 -276
- package/skills/java/kotlin-specialist/references/dsl-idioms.md +421 -421
- package/skills/java/kotlin-specialist/references/ktor-server.md +426 -426
- package/skills/java/kotlin-specialist/references/multiplatform-kmp.md +380 -380
- package/skills/java/spring-boot-engineer/SKILL.md +195 -195
- package/skills/java/spring-boot-engineer/references/cloud.md +498 -498
- package/skills/java/spring-boot-engineer/references/data.md +381 -381
- package/skills/java/spring-boot-engineer/references/security.md +459 -459
- package/skills/java/spring-boot-engineer/references/testing.md +545 -545
- package/skills/java/spring-boot-engineer/references/web.md +295 -295
- package/skills/javascript/javascript-pro/SKILL.md +132 -132
- package/skills/javascript/javascript-pro/references/async-patterns.md +334 -334
- package/skills/javascript/javascript-pro/references/browser-apis.md +398 -398
- package/skills/javascript/javascript-pro/references/modern-syntax.md +272 -272
- package/skills/javascript/javascript-pro/references/modules.md +357 -357
- package/skills/javascript/javascript-pro/references/node-essentials.md +471 -471
- package/skills/javascript/nestjs-expert/SKILL.md +206 -206
- package/skills/javascript/nestjs-expert/references/authentication.md +166 -166
- package/skills/javascript/nestjs-expert/references/controllers-routing.md +111 -111
- package/skills/javascript/nestjs-expert/references/dtos-validation.md +153 -153
- package/skills/javascript/nestjs-expert/references/migration-from-express.md +1237 -1237
- package/skills/javascript/nestjs-expert/references/services-di.md +140 -140
- package/skills/javascript/nestjs-expert/references/testing-patterns.md +186 -186
- package/skills/javascript/typescript-pro/SKILL.md +145 -145
- package/skills/javascript/typescript-pro/references/advanced-types.md +259 -259
- package/skills/javascript/typescript-pro/references/configuration.md +445 -445
- package/skills/javascript/typescript-pro/references/patterns.md +484 -484
- package/skills/javascript/typescript-pro/references/type-guards.md +352 -352
- package/skills/javascript/typescript-pro/references/utility-types.md +329 -329
- package/skills/php/laravel-specialist/SKILL.md +262 -262
- package/skills/php/laravel-specialist/references/eloquent.md +351 -351
- package/skills/php/laravel-specialist/references/livewire.md +512 -512
- package/skills/php/laravel-specialist/references/queues.md +423 -423
- package/skills/php/laravel-specialist/references/routing.md +362 -362
- package/skills/php/laravel-specialist/references/testing.md +522 -522
- package/skills/php/php-pro/SKILL.md +206 -206
- package/skills/php/php-pro/references/async-patterns.md +412 -412
- package/skills/php/php-pro/references/laravel-patterns.md +377 -377
- package/skills/php/php-pro/references/modern-php-features.md +323 -323
- package/skills/php/php-pro/references/symfony-patterns.md +466 -466
- package/skills/php/php-pro/references/testing-quality.md +466 -466
- package/skills/product/competitive-analysis/SKILL.md +257 -257
- package/skills/product/meeting-notes/SKILL.md +266 -266
- package/skills/product/prd-template/SKILL.md +150 -150
- package/skills/product/stakeholder-update/SKILL.md +225 -225
- package/skills/product/user-research-synthesis/SKILL.md +235 -235
- package/skills/python/django-expert/SKILL.md +162 -162
- package/skills/python/django-expert/references/authentication.md +145 -145
- package/skills/python/django-expert/references/drf-serializers.md +148 -148
- package/skills/python/django-expert/references/models-orm.md +151 -151
- package/skills/python/django-expert/references/testing-django.md +204 -204
- package/skills/python/django-expert/references/viewsets-views.md +153 -153
- package/skills/python/fastapi-expert/SKILL.md +185 -185
- package/skills/python/fastapi-expert/references/async-sqlalchemy.md +146 -146
- package/skills/python/fastapi-expert/references/authentication.md +159 -159
- package/skills/python/fastapi-expert/references/endpoints-routing.md +142 -142
- package/skills/python/fastapi-expert/references/migration-from-django.md +996 -996
- package/skills/python/fastapi-expert/references/pydantic-v2.md +135 -135
- package/skills/python/fastapi-expert/references/testing-async.md +159 -159
- package/skills/python/pandas-pro/SKILL.md +178 -178
- package/skills/python/pandas-pro/references/aggregation-groupby.md +545 -545
- package/skills/python/pandas-pro/references/data-cleaning.md +500 -500
- package/skills/python/pandas-pro/references/dataframe-operations.md +420 -420
- package/skills/python/pandas-pro/references/merging-joining.md +596 -596
- package/skills/python/pandas-pro/references/performance-optimization.md +597 -597
- package/skills/python/python-pro/SKILL.md +177 -177
- package/skills/python/python-pro/references/async-patterns.md +356 -356
- package/skills/python/python-pro/references/packaging.md +460 -460
- package/skills/python/python-pro/references/standard-library.md +378 -378
- package/skills/python/python-pro/references/testing.md +404 -404
- package/skills/python/python-pro/references/type-system.md +290 -290
- package/skills/quality/chaos-engineer/SKILL.md +182 -182
- package/skills/quality/chaos-engineer/references/chaos-tools.md +511 -511
- package/skills/quality/chaos-engineer/references/experiment-design.md +229 -229
- package/skills/quality/chaos-engineer/references/game-days.md +434 -434
- package/skills/quality/chaos-engineer/references/infrastructure-chaos.md +348 -348
- package/skills/quality/chaos-engineer/references/kubernetes-chaos.md +432 -432
- package/skills/quality/code-reviewer/SKILL.md +119 -119
- package/skills/quality/code-reviewer/references/common-issues.md +142 -142
- package/skills/quality/code-reviewer/references/feedback-examples.md +144 -144
- package/skills/quality/code-reviewer/references/receiving-feedback.md +238 -238
- package/skills/quality/code-reviewer/references/report-template.md +109 -109
- package/skills/quality/code-reviewer/references/review-checklist.md +88 -88
- package/skills/quality/code-reviewer/references/spec-compliance-review.md +258 -258
- package/skills/quality/playwright-expert/SKILL.md +169 -169
- package/skills/quality/playwright-expert/references/api-mocking.md +140 -140
- package/skills/quality/playwright-expert/references/configuration.md +155 -155
- package/skills/quality/playwright-expert/references/debugging-flaky.md +150 -150
- package/skills/quality/playwright-expert/references/page-object-model.md +152 -152
- package/skills/quality/playwright-expert/references/selectors-locators.md +119 -119
- package/skills/quality/secure-code-guardian/SKILL.md +191 -191
- package/skills/quality/secure-code-guardian/references/authentication.md +136 -136
- package/skills/quality/secure-code-guardian/references/input-validation.md +146 -146
- package/skills/quality/secure-code-guardian/references/owasp-prevention.md +135 -135
- package/skills/quality/secure-code-guardian/references/security-headers.md +133 -133
- package/skills/quality/secure-code-guardian/references/xss-csrf.md +157 -157
- package/skills/quality/security-reviewer/SKILL.md +103 -103
- package/skills/quality/security-reviewer/references/infrastructure-security.md +268 -268
- package/skills/quality/security-reviewer/references/penetration-testing.md +268 -268
- package/skills/quality/security-reviewer/references/report-template.md +170 -170
- package/skills/quality/security-reviewer/references/sast-tools.md +117 -117
- package/skills/quality/security-reviewer/references/secret-scanning.md +125 -125
- package/skills/quality/security-reviewer/references/vulnerability-patterns.md +152 -152
- package/skills/quality/senior-qa/README.md +196 -196
- package/skills/quality/senior-qa/SKILL.md +399 -399
- package/skills/quality/senior-qa/references/qa_best_practices.md +964 -964
- package/skills/quality/senior-qa/references/test_automation_patterns.md +1009 -1009
- package/skills/quality/senior-qa/references/testing_strategies.md +649 -649
- package/skills/quality/senior-qa/scripts/coverage_analyzer.py +836 -836
- package/skills/quality/senior-qa/scripts/e2e_test_scaffolder.py +820 -820
- package/skills/quality/senior-qa/scripts/test_suite_generator.py +605 -605
- package/skills/quality/tdd-guide/HOW_TO_USE.md +313 -313
- package/skills/quality/tdd-guide/README.md +680 -680
- package/skills/quality/tdd-guide/SKILL.md +122 -122
- package/skills/quality/tdd-guide/assets/expected_output.json +77 -77
- package/skills/quality/tdd-guide/assets/sample_input_python.json +39 -39
- package/skills/quality/tdd-guide/assets/sample_input_typescript.json +36 -36
- package/skills/quality/tdd-guide/references/ci-integration.md +195 -195
- package/skills/quality/tdd-guide/references/framework-guide.md +206 -206
- package/skills/quality/tdd-guide/references/tdd-best-practices.md +128 -128
- package/skills/quality/tdd-guide/scripts/coverage_analyzer.py +434 -434
- package/skills/quality/tdd-guide/scripts/fixture_generator.py +440 -440
- package/skills/quality/tdd-guide/scripts/format_detector.py +384 -384
- package/skills/quality/tdd-guide/scripts/framework_adapter.py +428 -428
- package/skills/quality/tdd-guide/scripts/metrics_calculator.py +456 -456
- package/skills/quality/tdd-guide/scripts/output_formatter.py +354 -354
- package/skills/quality/tdd-guide/scripts/tdd_workflow.py +474 -474
- package/skills/quality/tdd-guide/scripts/test_generator.py +438 -438
- package/skills/quality/test-master/SKILL.md +94 -94
- package/skills/quality/test-master/references/automation-frameworks.md +294 -294
- package/skills/quality/test-master/references/e2e-testing.md +128 -128
- package/skills/quality/test-master/references/integration-testing.md +120 -120
- package/skills/quality/test-master/references/performance-testing.md +118 -118
- package/skills/quality/test-master/references/qa-methodology.md +247 -247
- package/skills/quality/test-master/references/security-testing.md +127 -127
- package/skills/quality/test-master/references/tdd-iron-laws.md +174 -174
- package/skills/quality/test-master/references/test-reports.md +104 -104
- package/skills/quality/test-master/references/testing-anti-patterns.md +231 -231
- package/skills/quality/test-master/references/unit-testing.md +113 -113
- package/skills/ruby/rails-expert/SKILL.md +154 -154
- package/skills/ruby/rails-expert/references/active-record.md +244 -244
- package/skills/ruby/rails-expert/references/api-development.md +401 -401
- package/skills/ruby/rails-expert/references/background-jobs.md +272 -272
- package/skills/ruby/rails-expert/references/hotwire-turbo.md +228 -228
- package/skills/ruby/rails-expert/references/rspec-testing.md +367 -367
- package/skills/swift/swift-expert/SKILL.md +163 -163
- package/skills/swift/swift-expert/references/async-concurrency.md +360 -360
- package/skills/swift/swift-expert/references/memory-performance.md +377 -377
- package/skills/swift/swift-expert/references/protocol-oriented.md +354 -354
- package/skills/swift/swift-expert/references/swiftui-patterns.md +291 -291
- package/skills/swift/swift-expert/references/testing-patterns.md +399 -399
- package/skills/workflow/brainstorming/SKILL.md +164 -164
- package/skills/workflow/brainstorming/scripts/frame-template.html +214 -214
- package/skills/workflow/brainstorming/scripts/helper.js +88 -88
- package/skills/workflow/brainstorming/scripts/server.cjs +354 -354
- package/skills/workflow/brainstorming/scripts/start-server.sh +148 -148
- package/skills/workflow/brainstorming/scripts/stop-server.sh +56 -56
- package/skills/workflow/brainstorming/spec-document-reviewer-prompt.md +49 -49
- package/skills/workflow/brainstorming/visual-companion.md +287 -287
- package/skills/workflow/documentation/SKILL.md +45 -45
- package/skills/workflow/entropy-management/SKILL.md +115 -115
- package/skills/workflow/executing-plans/SKILL.md +70 -70
- package/skills/workflow/finishing-a-development-branch/SKILL.md +200 -200
- package/skills/workflow/receiving-code-review/SKILL.md +213 -213
- package/skills/workflow/requesting-code-review/SKILL.md +105 -105
- package/skills/workflow/requesting-code-review/code-reviewer.md +146 -146
- package/skills/workflow/requirement-engineering/SKILL.md +111 -111
- package/skills/workflow/systematic-debugging/CREATION-LOG.md +119 -119
- package/skills/workflow/systematic-debugging/SKILL.md +296 -296
- package/skills/workflow/systematic-debugging/condition-based-waiting-example.ts +158 -158
- package/skills/workflow/systematic-debugging/condition-based-waiting.md +115 -115
- package/skills/workflow/systematic-debugging/defense-in-depth.md +122 -122
- package/skills/workflow/systematic-debugging/find-polluter.sh +63 -63
- package/skills/workflow/systematic-debugging/root-cause-tracing.md +169 -169
- package/skills/workflow/systematic-debugging/test-academic.md +14 -14
- package/skills/workflow/systematic-debugging/test-pressure-1.md +58 -58
- package/skills/workflow/systematic-debugging/test-pressure-2.md +68 -68
- package/skills/workflow/systematic-debugging/test-pressure-3.md +69 -69
- package/skills/workflow/using-git-worktrees/SKILL.md +218 -218
- package/skills/workflow/verification-before-completion/SKILL.md +139 -139
- package/skills/workflow/writing-plans/SKILL.md +151 -151
- package/skills/workflow/writing-plans/plan-document-reviewer-prompt.md +49 -49
- package/skills/workflow/writing-skills/SKILL.md +655 -655
- package/skills/workflow/writing-skills/anthropic-best-practices.md +1150 -1150
- package/skills/workflow/writing-skills/examples/CLAUDE_MD_TESTING.md +189 -189
- package/skills/workflow/writing-skills/persuasion-principles.md +187 -187
- package/skills/workflow/writing-skills/render-graphs.js +168 -168
- package/skills/workflow/writing-skills/testing-skills-with-subagents.md +384 -384
|
@@ -1,782 +1,782 @@
|
|
|
1
|
-
# Training Pipelines
|
|
2
|
-
|
|
3
|
-
---
|
|
4
|
-
|
|
5
|
-
## Overview
|
|
6
|
-
|
|
7
|
-
Training pipelines orchestrate the end-to-end model training process including data loading, distributed training, hyperparameter optimization, and artifact management. Production pipelines require reproducibility, scalability, and proper resource management.
|
|
8
|
-
|
|
9
|
-
## When to Use This Reference
|
|
10
|
-
|
|
11
|
-
- Setting up distributed training with PyTorch/TensorFlow
|
|
12
|
-
- Implementing hyperparameter tuning (Optuna, Ray Tune)
|
|
13
|
-
- Managing GPU/TPU resources for training
|
|
14
|
-
- Building reproducible training environments
|
|
15
|
-
- Creating checkpointing and fault-tolerant training
|
|
16
|
-
|
|
17
|
-
## When NOT to Use
|
|
18
|
-
|
|
19
|
-
- Quick model prototyping (use notebooks)
|
|
20
|
-
- Small models that fit in memory on single GPU
|
|
21
|
-
- One-off experiments without production requirements
|
|
22
|
-
|
|
23
|
-
---
|
|
24
|
-
|
|
25
|
-
## PyTorch Training Pipeline
|
|
26
|
-
|
|
27
|
-
### Complete Training Script
|
|
28
|
-
|
|
29
|
-
```python
|
|
30
|
-
import torch
|
|
31
|
-
import torch.nn as nn
|
|
32
|
-
from torch.utils.data import DataLoader, Dataset
|
|
33
|
-
from torch.optim import AdamW
|
|
34
|
-
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
35
|
-
import logging
|
|
36
|
-
from pathlib import Path
|
|
37
|
-
from dataclasses import dataclass
|
|
38
|
-
from typing import Optional
|
|
39
|
-
import json
|
|
40
|
-
|
|
41
|
-
logger = logging.getLogger(__name__)
|
|
42
|
-
|
|
43
|
-
@dataclass
|
|
44
|
-
class TrainingConfig:
|
|
45
|
-
"""Training hyperparameters and settings."""
|
|
46
|
-
model_name: str
|
|
47
|
-
batch_size: int = 32
|
|
48
|
-
learning_rate: float = 1e-4
|
|
49
|
-
weight_decay: float = 0.01
|
|
50
|
-
epochs: int = 10
|
|
51
|
-
warmup_steps: int = 100
|
|
52
|
-
max_grad_norm: float = 1.0
|
|
53
|
-
seed: int = 42
|
|
54
|
-
checkpoint_dir: str = "./checkpoints"
|
|
55
|
-
log_every_n_steps: int = 100
|
|
56
|
-
eval_every_n_steps: int = 500
|
|
57
|
-
save_every_n_steps: int = 1000
|
|
58
|
-
mixed_precision: bool = True
|
|
59
|
-
gradient_accumulation_steps: int = 1
|
|
60
|
-
|
|
61
|
-
def to_dict(self) -> dict:
|
|
62
|
-
return {k: v for k, v in self.__dict__.items()}
|
|
63
|
-
|
|
64
|
-
@classmethod
|
|
65
|
-
def from_dict(cls, d: dict) -> "TrainingConfig":
|
|
66
|
-
return cls(**d)
|
|
67
|
-
|
|
68
|
-
class Trainer:
|
|
69
|
-
"""Production-grade PyTorch trainer."""
|
|
70
|
-
|
|
71
|
-
def __init__(
|
|
72
|
-
self,
|
|
73
|
-
model: nn.Module,
|
|
74
|
-
config: TrainingConfig,
|
|
75
|
-
train_dataloader: DataLoader,
|
|
76
|
-
eval_dataloader: Optional[DataLoader] = None,
|
|
77
|
-
experiment_tracker=None,
|
|
78
|
-
):
|
|
79
|
-
self.model = model
|
|
80
|
-
self.config = config
|
|
81
|
-
self.train_dataloader = train_dataloader
|
|
82
|
-
self.eval_dataloader = eval_dataloader
|
|
83
|
-
self.tracker = experiment_tracker
|
|
84
|
-
|
|
85
|
-
self._setup_device()
|
|
86
|
-
self._setup_training()
|
|
87
|
-
self._setup_checkpointing()
|
|
88
|
-
|
|
89
|
-
def _setup_device(self) -> None:
|
|
90
|
-
"""Configure device and move model."""
|
|
91
|
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
92
|
-
self.model = self.model.to(self.device)
|
|
93
|
-
|
|
94
|
-
if self.config.mixed_precision and self.device.type == "cuda":
|
|
95
|
-
self.scaler = torch.amp.GradScaler("cuda")
|
|
96
|
-
else:
|
|
97
|
-
self.scaler = None
|
|
98
|
-
|
|
99
|
-
logger.info(f"Training on device: {self.device}")
|
|
100
|
-
|
|
101
|
-
def _setup_training(self) -> None:
|
|
102
|
-
"""Initialize optimizer and scheduler."""
|
|
103
|
-
self.optimizer = AdamW(
|
|
104
|
-
self.model.parameters(),
|
|
105
|
-
lr=self.config.learning_rate,
|
|
106
|
-
weight_decay=self.config.weight_decay,
|
|
107
|
-
)
|
|
108
|
-
|
|
109
|
-
total_steps = len(self.train_dataloader) * self.config.epochs
|
|
110
|
-
self.scheduler = CosineAnnealingLR(
|
|
111
|
-
self.optimizer,
|
|
112
|
-
T_max=total_steps,
|
|
113
|
-
eta_min=self.config.learning_rate * 0.01,
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
self.global_step = 0
|
|
117
|
-
self.best_eval_loss = float("inf")
|
|
118
|
-
|
|
119
|
-
def _setup_checkpointing(self) -> None:
|
|
120
|
-
"""Create checkpoint directory."""
|
|
121
|
-
self.checkpoint_dir = Path(self.config.checkpoint_dir)
|
|
122
|
-
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
123
|
-
|
|
124
|
-
def _set_seed(self) -> None:
|
|
125
|
-
"""Set random seeds for reproducibility."""
|
|
126
|
-
import random
|
|
127
|
-
import numpy as np
|
|
128
|
-
|
|
129
|
-
torch.manual_seed(self.config.seed)
|
|
130
|
-
torch.cuda.manual_seed_all(self.config.seed)
|
|
131
|
-
np.random.seed(self.config.seed)
|
|
132
|
-
random.seed(self.config.seed)
|
|
133
|
-
torch.backends.cudnn.deterministic = True
|
|
134
|
-
|
|
135
|
-
def train(self) -> dict:
|
|
136
|
-
"""Run training loop."""
|
|
137
|
-
self._set_seed()
|
|
138
|
-
self.model.train()
|
|
139
|
-
|
|
140
|
-
metrics_history = []
|
|
141
|
-
|
|
142
|
-
for epoch in range(self.config.epochs):
|
|
143
|
-
epoch_loss = 0.0
|
|
144
|
-
num_batches = 0
|
|
145
|
-
|
|
146
|
-
for batch_idx, batch in enumerate(self.train_dataloader):
|
|
147
|
-
loss = self._training_step(batch)
|
|
148
|
-
epoch_loss += loss
|
|
149
|
-
num_batches += 1
|
|
150
|
-
|
|
151
|
-
if self.global_step % self.config.log_every_n_steps == 0:
|
|
152
|
-
self._log_metrics({
|
|
153
|
-
"train/loss": loss,
|
|
154
|
-
"train/lr": self.scheduler.get_last_lr()[0],
|
|
155
|
-
"train/epoch": epoch,
|
|
156
|
-
})
|
|
157
|
-
|
|
158
|
-
if (
|
|
159
|
-
self.eval_dataloader
|
|
160
|
-
and self.global_step % self.config.eval_every_n_steps == 0
|
|
161
|
-
):
|
|
162
|
-
eval_metrics = self.evaluate()
|
|
163
|
-
self._log_metrics(eval_metrics)
|
|
164
|
-
|
|
165
|
-
if eval_metrics["eval/loss"] < self.best_eval_loss:
|
|
166
|
-
self.best_eval_loss = eval_metrics["eval/loss"]
|
|
167
|
-
self.save_checkpoint("best")
|
|
168
|
-
|
|
169
|
-
if self.global_step % self.config.save_every_n_steps == 0:
|
|
170
|
-
self.save_checkpoint(f"step_{self.global_step}")
|
|
171
|
-
|
|
172
|
-
avg_epoch_loss = epoch_loss / num_batches
|
|
173
|
-
logger.info(f"Epoch {epoch}: avg_loss={avg_epoch_loss:.4f}")
|
|
174
|
-
metrics_history.append({"epoch": epoch, "loss": avg_epoch_loss})
|
|
175
|
-
|
|
176
|
-
self.save_checkpoint("final")
|
|
177
|
-
|
|
178
|
-
return {
|
|
179
|
-
"best_eval_loss": self.best_eval_loss,
|
|
180
|
-
"final_train_loss": avg_epoch_loss,
|
|
181
|
-
"total_steps": self.global_step,
|
|
182
|
-
"metrics_history": metrics_history,
|
|
183
|
-
}
|
|
184
|
-
|
|
185
|
-
def _training_step(self, batch: dict) -> float:
|
|
186
|
-
"""Execute single training step."""
|
|
187
|
-
batch = {k: v.to(self.device) for k, v in batch.items()}
|
|
188
|
-
|
|
189
|
-
if self.scaler:
|
|
190
|
-
with torch.amp.autocast("cuda"):
|
|
191
|
-
outputs = self.model(**batch)
|
|
192
|
-
loss = outputs.loss / self.config.gradient_accumulation_steps
|
|
193
|
-
self.scaler.scale(loss).backward()
|
|
194
|
-
else:
|
|
195
|
-
outputs = self.model(**batch)
|
|
196
|
-
loss = outputs.loss / self.config.gradient_accumulation_steps
|
|
197
|
-
loss.backward()
|
|
198
|
-
|
|
199
|
-
if (self.global_step + 1) % self.config.gradient_accumulation_steps == 0:
|
|
200
|
-
if self.scaler:
|
|
201
|
-
self.scaler.unscale_(self.optimizer)
|
|
202
|
-
|
|
203
|
-
torch.nn.utils.clip_grad_norm_(
|
|
204
|
-
self.model.parameters(),
|
|
205
|
-
self.config.max_grad_norm,
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
if self.scaler:
|
|
209
|
-
self.scaler.step(self.optimizer)
|
|
210
|
-
self.scaler.update()
|
|
211
|
-
else:
|
|
212
|
-
self.optimizer.step()
|
|
213
|
-
|
|
214
|
-
self.scheduler.step()
|
|
215
|
-
self.optimizer.zero_grad()
|
|
216
|
-
|
|
217
|
-
self.global_step += 1
|
|
218
|
-
return loss.item() * self.config.gradient_accumulation_steps
|
|
219
|
-
|
|
220
|
-
@torch.no_grad()
|
|
221
|
-
def evaluate(self) -> dict:
|
|
222
|
-
"""Run evaluation loop."""
|
|
223
|
-
self.model.eval()
|
|
224
|
-
total_loss = 0.0
|
|
225
|
-
num_batches = 0
|
|
226
|
-
|
|
227
|
-
for batch in self.eval_dataloader:
|
|
228
|
-
batch = {k: v.to(self.device) for k, v in batch.items()}
|
|
229
|
-
|
|
230
|
-
if self.scaler:
|
|
231
|
-
with torch.amp.autocast("cuda"):
|
|
232
|
-
outputs = self.model(**batch)
|
|
233
|
-
else:
|
|
234
|
-
outputs = self.model(**batch)
|
|
235
|
-
|
|
236
|
-
total_loss += outputs.loss.item()
|
|
237
|
-
num_batches += 1
|
|
238
|
-
|
|
239
|
-
self.model.train()
|
|
240
|
-
|
|
241
|
-
return {
|
|
242
|
-
"eval/loss": total_loss / num_batches,
|
|
243
|
-
"eval/step": self.global_step,
|
|
244
|
-
}
|
|
245
|
-
|
|
246
|
-
def save_checkpoint(self, name: str) -> Path:
|
|
247
|
-
"""Save model checkpoint."""
|
|
248
|
-
checkpoint_path = self.checkpoint_dir / name
|
|
249
|
-
|
|
250
|
-
torch.save({
|
|
251
|
-
"model_state_dict": self.model.state_dict(),
|
|
252
|
-
"optimizer_state_dict": self.optimizer.state_dict(),
|
|
253
|
-
"scheduler_state_dict": self.scheduler.state_dict(),
|
|
254
|
-
"global_step": self.global_step,
|
|
255
|
-
"best_eval_loss": self.best_eval_loss,
|
|
256
|
-
"config": self.config.to_dict(),
|
|
257
|
-
}, checkpoint_path / "checkpoint.pt")
|
|
258
|
-
|
|
259
|
-
# Save config separately for easy loading
|
|
260
|
-
with open(checkpoint_path / "config.json", "w") as f:
|
|
261
|
-
json.dump(self.config.to_dict(), f, indent=2)
|
|
262
|
-
|
|
263
|
-
logger.info(f"Saved checkpoint: {checkpoint_path}")
|
|
264
|
-
return checkpoint_path
|
|
265
|
-
|
|
266
|
-
def load_checkpoint(self, checkpoint_path: Path) -> None:
|
|
267
|
-
"""Load model checkpoint."""
|
|
268
|
-
checkpoint = torch.load(checkpoint_path / "checkpoint.pt", map_location=self.device)
|
|
269
|
-
|
|
270
|
-
self.model.load_state_dict(checkpoint["model_state_dict"])
|
|
271
|
-
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
272
|
-
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
|
273
|
-
self.global_step = checkpoint["global_step"]
|
|
274
|
-
self.best_eval_loss = checkpoint["best_eval_loss"]
|
|
275
|
-
|
|
276
|
-
logger.info(f"Loaded checkpoint from step {self.global_step}")
|
|
277
|
-
|
|
278
|
-
def _log_metrics(self, metrics: dict) -> None:
|
|
279
|
-
"""Log metrics to tracker and console."""
|
|
280
|
-
if self.tracker:
|
|
281
|
-
self.tracker.log_metrics(metrics, step=self.global_step)
|
|
282
|
-
|
|
283
|
-
logger.info(f"Step {self.global_step}: {metrics}")
|
|
284
|
-
```
|
|
285
|
-
|
|
286
|
-
---
|
|
287
|
-
|
|
288
|
-
## Distributed Training
|
|
289
|
-
|
|
290
|
-
### PyTorch Distributed Data Parallel
|
|
291
|
-
|
|
292
|
-
```python
|
|
293
|
-
import torch
|
|
294
|
-
import torch.distributed as dist
|
|
295
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
296
|
-
from torch.utils.data.distributed import DistributedSampler
|
|
297
|
-
import os
|
|
298
|
-
|
|
299
|
-
def setup_distributed() -> tuple[int, int, int]:
|
|
300
|
-
"""Initialize distributed training environment."""
|
|
301
|
-
if "RANK" in os.environ:
|
|
302
|
-
rank = int(os.environ["RANK"])
|
|
303
|
-
local_rank = int(os.environ["LOCAL_RANK"])
|
|
304
|
-
world_size = int(os.environ["WORLD_SIZE"])
|
|
305
|
-
else:
|
|
306
|
-
rank = 0
|
|
307
|
-
local_rank = 0
|
|
308
|
-
world_size = 1
|
|
309
|
-
|
|
310
|
-
if world_size > 1:
|
|
311
|
-
dist.init_process_group(
|
|
312
|
-
backend="nccl",
|
|
313
|
-
init_method="env://",
|
|
314
|
-
world_size=world_size,
|
|
315
|
-
rank=rank,
|
|
316
|
-
)
|
|
317
|
-
torch.cuda.set_device(local_rank)
|
|
318
|
-
|
|
319
|
-
return rank, local_rank, world_size
|
|
320
|
-
|
|
321
|
-
def cleanup_distributed() -> None:
|
|
322
|
-
"""Cleanup distributed training."""
|
|
323
|
-
if dist.is_initialized():
|
|
324
|
-
dist.destroy_process_group()
|
|
325
|
-
|
|
326
|
-
class DistributedTrainer(Trainer):
|
|
327
|
-
"""Trainer with DDP support."""
|
|
328
|
-
|
|
329
|
-
def __init__(self, *args, **kwargs):
|
|
330
|
-
self.rank, self.local_rank, self.world_size = setup_distributed()
|
|
331
|
-
super().__init__(*args, **kwargs)
|
|
332
|
-
|
|
333
|
-
def _setup_device(self) -> None:
|
|
334
|
-
"""Configure device for distributed training."""
|
|
335
|
-
if self.world_size > 1:
|
|
336
|
-
self.device = torch.device(f"cuda:{self.local_rank}")
|
|
337
|
-
self.model = self.model.to(self.device)
|
|
338
|
-
self.model = DDP(
|
|
339
|
-
self.model,
|
|
340
|
-
device_ids=[self.local_rank],
|
|
341
|
-
output_device=self.local_rank,
|
|
342
|
-
find_unused_parameters=False,
|
|
343
|
-
)
|
|
344
|
-
else:
|
|
345
|
-
super()._setup_device()
|
|
346
|
-
|
|
347
|
-
if self.config.mixed_precision and self.device.type == "cuda":
|
|
348
|
-
self.scaler = torch.amp.GradScaler("cuda")
|
|
349
|
-
else:
|
|
350
|
-
self.scaler = None
|
|
351
|
-
|
|
352
|
-
def save_checkpoint(self, name: str) -> Path:
|
|
353
|
-
"""Only save on rank 0."""
|
|
354
|
-
if self.rank == 0:
|
|
355
|
-
return super().save_checkpoint(name)
|
|
356
|
-
return None
|
|
357
|
-
|
|
358
|
-
def _log_metrics(self, metrics: dict) -> None:
|
|
359
|
-
"""Only log on rank 0."""
|
|
360
|
-
if self.rank == 0:
|
|
361
|
-
super()._log_metrics(metrics)
|
|
362
|
-
|
|
363
|
-
def create_distributed_dataloader(
|
|
364
|
-
dataset: Dataset,
|
|
365
|
-
batch_size: int,
|
|
366
|
-
world_size: int,
|
|
367
|
-
rank: int,
|
|
368
|
-
shuffle: bool = True,
|
|
369
|
-
) -> DataLoader:
|
|
370
|
-
"""Create DataLoader with distributed sampler."""
|
|
371
|
-
sampler = DistributedSampler(
|
|
372
|
-
dataset,
|
|
373
|
-
num_replicas=world_size,
|
|
374
|
-
rank=rank,
|
|
375
|
-
shuffle=shuffle,
|
|
376
|
-
)
|
|
377
|
-
|
|
378
|
-
return DataLoader(
|
|
379
|
-
dataset,
|
|
380
|
-
batch_size=batch_size,
|
|
381
|
-
sampler=sampler,
|
|
382
|
-
num_workers=4,
|
|
383
|
-
pin_memory=True,
|
|
384
|
-
drop_last=True,
|
|
385
|
-
)
|
|
386
|
-
```
|
|
387
|
-
|
|
388
|
-
### Launch Script
|
|
389
|
-
|
|
390
|
-
```bash
|
|
391
|
-
#!/bin/bash
|
|
392
|
-
# launch_distributed.sh
|
|
393
|
-
|
|
394
|
-
NUM_GPUS=4
|
|
395
|
-
MASTER_PORT=29500
|
|
396
|
-
|
|
397
|
-
torchrun \
|
|
398
|
-
--nproc_per_node=$NUM_GPUS \
|
|
399
|
-
--master_port=$MASTER_PORT \
|
|
400
|
-
train.py \
|
|
401
|
-
--config config/training_config.yaml
|
|
402
|
-
```
|
|
403
|
-
|
|
404
|
-
---
|
|
405
|
-
|
|
406
|
-
## Hyperparameter Tuning
|
|
407
|
-
|
|
408
|
-
### Optuna Integration
|
|
409
|
-
|
|
410
|
-
```python
|
|
411
|
-
import optuna
|
|
412
|
-
from optuna.trial import Trial
|
|
413
|
-
from optuna.integration import PyTorchLightningPruningCallback
|
|
414
|
-
import mlflow
|
|
415
|
-
|
|
416
|
-
def create_objective(
|
|
417
|
-
train_dataset: Dataset,
|
|
418
|
-
eval_dataset: Dataset,
|
|
419
|
-
model_class: type,
|
|
420
|
-
) -> callable:
|
|
421
|
-
"""Create Optuna objective function."""
|
|
422
|
-
|
|
423
|
-
def objective(trial: Trial) -> float:
|
|
424
|
-
# Sample hyperparameters
|
|
425
|
-
config = TrainingConfig(
|
|
426
|
-
model_name="tuned_model",
|
|
427
|
-
learning_rate=trial.suggest_float("lr", 1e-5, 1e-3, log=True),
|
|
428
|
-
batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
|
|
429
|
-
weight_decay=trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True),
|
|
430
|
-
epochs=trial.suggest_int("epochs", 3, 10),
|
|
431
|
-
warmup_steps=trial.suggest_int("warmup_steps", 0, 500),
|
|
432
|
-
)
|
|
433
|
-
|
|
434
|
-
# Create data loaders
|
|
435
|
-
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
|
|
436
|
-
eval_loader = DataLoader(eval_dataset, batch_size=config.batch_size)
|
|
437
|
-
|
|
438
|
-
# Create model
|
|
439
|
-
model = model_class(
|
|
440
|
-
hidden_size=trial.suggest_categorical("hidden_size", [128, 256, 512]),
|
|
441
|
-
num_layers=trial.suggest_int("num_layers", 2, 6),
|
|
442
|
-
dropout=trial.suggest_float("dropout", 0.1, 0.5),
|
|
443
|
-
)
|
|
444
|
-
|
|
445
|
-
# Train
|
|
446
|
-
trainer = Trainer(
|
|
447
|
-
model=model,
|
|
448
|
-
config=config,
|
|
449
|
-
train_dataloader=train_loader,
|
|
450
|
-
eval_dataloader=eval_loader,
|
|
451
|
-
)
|
|
452
|
-
|
|
453
|
-
# Report intermediate values for pruning
|
|
454
|
-
for epoch in range(config.epochs):
|
|
455
|
-
trainer.train_epoch()
|
|
456
|
-
eval_loss = trainer.evaluate()["eval/loss"]
|
|
457
|
-
|
|
458
|
-
trial.report(eval_loss, epoch)
|
|
459
|
-
|
|
460
|
-
if trial.should_prune():
|
|
461
|
-
raise optuna.TrialPruned()
|
|
462
|
-
|
|
463
|
-
return trainer.best_eval_loss
|
|
464
|
-
|
|
465
|
-
return objective
|
|
466
|
-
|
|
467
|
-
def run_hyperparameter_search(
|
|
468
|
-
train_dataset: Dataset,
|
|
469
|
-
eval_dataset: Dataset,
|
|
470
|
-
model_class: type,
|
|
471
|
-
n_trials: int = 100,
|
|
472
|
-
study_name: str = "hpo_study",
|
|
473
|
-
) -> optuna.Study:
|
|
474
|
-
"""Run hyperparameter optimization with Optuna."""
|
|
475
|
-
|
|
476
|
-
# Create study with pruning
|
|
477
|
-
pruner = optuna.pruners.MedianPruner(
|
|
478
|
-
n_startup_trials=5,
|
|
479
|
-
n_warmup_steps=3,
|
|
480
|
-
interval_steps=1,
|
|
481
|
-
)
|
|
482
|
-
|
|
483
|
-
study = optuna.create_study(
|
|
484
|
-
study_name=study_name,
|
|
485
|
-
direction="minimize",
|
|
486
|
-
pruner=pruner,
|
|
487
|
-
storage=f"sqlite:///{study_name}.db",
|
|
488
|
-
load_if_exists=True,
|
|
489
|
-
)
|
|
490
|
-
|
|
491
|
-
objective = create_objective(train_dataset, eval_dataset, model_class)
|
|
492
|
-
|
|
493
|
-
study.optimize(
|
|
494
|
-
objective,
|
|
495
|
-
n_trials=n_trials,
|
|
496
|
-
timeout=3600 * 12, # 12 hours
|
|
497
|
-
n_jobs=1, # Sequential for GPU
|
|
498
|
-
show_progress_bar=True,
|
|
499
|
-
)
|
|
500
|
-
|
|
501
|
-
# Log best results
|
|
502
|
-
logger.info(f"Best trial: {study.best_trial.params}")
|
|
503
|
-
logger.info(f"Best value: {study.best_value}")
|
|
504
|
-
|
|
505
|
-
return study
|
|
506
|
-
```
|
|
507
|
-
|
|
508
|
-
### Ray Tune Integration
|
|
509
|
-
|
|
510
|
-
```python
|
|
511
|
-
from ray import tune
|
|
512
|
-
from ray.tune.schedulers import ASHAScheduler
|
|
513
|
-
from ray.tune.search.optuna import OptunaSearch
|
|
514
|
-
from ray.air import RunConfig, CheckpointConfig
|
|
515
|
-
|
|
516
|
-
def train_fn(config: dict) -> None:
|
|
517
|
-
"""Training function for Ray Tune."""
|
|
518
|
-
from ray.train import report, get_checkpoint
|
|
519
|
-
|
|
520
|
-
training_config = TrainingConfig(
|
|
521
|
-
model_name="ray_tune_model",
|
|
522
|
-
learning_rate=config["lr"],
|
|
523
|
-
batch_size=config["batch_size"],
|
|
524
|
-
weight_decay=config["weight_decay"],
|
|
525
|
-
epochs=config["epochs"],
|
|
526
|
-
)
|
|
527
|
-
|
|
528
|
-
# Build model and dataloaders
|
|
529
|
-
model = build_model(config["hidden_size"], config["num_layers"])
|
|
530
|
-
train_loader, eval_loader = build_dataloaders(config["batch_size"])
|
|
531
|
-
|
|
532
|
-
trainer = Trainer(
|
|
533
|
-
model=model,
|
|
534
|
-
config=training_config,
|
|
535
|
-
train_dataloader=train_loader,
|
|
536
|
-
eval_dataloader=eval_loader,
|
|
537
|
-
)
|
|
538
|
-
|
|
539
|
-
# Resume from checkpoint if available
|
|
540
|
-
checkpoint = get_checkpoint()
|
|
541
|
-
if checkpoint:
|
|
542
|
-
with checkpoint.as_directory() as checkpoint_dir:
|
|
543
|
-
trainer.load_checkpoint(Path(checkpoint_dir))
|
|
544
|
-
|
|
545
|
-
for epoch in range(training_config.epochs):
|
|
546
|
-
trainer.train_epoch()
|
|
547
|
-
metrics = trainer.evaluate()
|
|
548
|
-
|
|
549
|
-
# Report metrics to Ray Tune
|
|
550
|
-
report(
|
|
551
|
-
{"loss": metrics["eval/loss"], "epoch": epoch},
|
|
552
|
-
checkpoint=Checkpoint.from_directory(trainer.checkpoint_dir),
|
|
553
|
-
)
|
|
554
|
-
|
|
555
|
-
def run_ray_tune(num_samples: int = 50) -> tune.ResultGrid:
|
|
556
|
-
"""Run hyperparameter search with Ray Tune."""
|
|
557
|
-
|
|
558
|
-
search_space = {
|
|
559
|
-
"lr": tune.loguniform(1e-5, 1e-3),
|
|
560
|
-
"batch_size": tune.choice([16, 32, 64]),
|
|
561
|
-
"weight_decay": tune.loguniform(1e-5, 1e-2),
|
|
562
|
-
"hidden_size": tune.choice([128, 256, 512]),
|
|
563
|
-
"num_layers": tune.randint(2, 7),
|
|
564
|
-
"epochs": 10,
|
|
565
|
-
}
|
|
566
|
-
|
|
567
|
-
scheduler = ASHAScheduler(
|
|
568
|
-
metric="loss",
|
|
569
|
-
mode="min",
|
|
570
|
-
max_t=10,
|
|
571
|
-
grace_period=2,
|
|
572
|
-
reduction_factor=3,
|
|
573
|
-
)
|
|
574
|
-
|
|
575
|
-
tuner = tune.Tuner(
|
|
576
|
-
tune.with_resources(train_fn, {"gpu": 1}),
|
|
577
|
-
param_space=search_space,
|
|
578
|
-
tune_config=tune.TuneConfig(
|
|
579
|
-
num_samples=num_samples,
|
|
580
|
-
scheduler=scheduler,
|
|
581
|
-
search_alg=OptunaSearch(),
|
|
582
|
-
),
|
|
583
|
-
run_config=RunConfig(
|
|
584
|
-
name="hpo_experiment",
|
|
585
|
-
checkpoint_config=CheckpointConfig(
|
|
586
|
-
num_to_keep=3,
|
|
587
|
-
checkpoint_frequency=1,
|
|
588
|
-
),
|
|
589
|
-
),
|
|
590
|
-
)
|
|
591
|
-
|
|
592
|
-
results = tuner.fit()
|
|
593
|
-
best_result = results.get_best_result("loss", "min")
|
|
594
|
-
|
|
595
|
-
logger.info(f"Best config: {best_result.config}")
|
|
596
|
-
logger.info(f"Best loss: {best_result.metrics['loss']}")
|
|
597
|
-
|
|
598
|
-
return results
|
|
599
|
-
```
|
|
600
|
-
|
|
601
|
-
---
|
|
602
|
-
|
|
603
|
-
## Resource Management
|
|
604
|
-
|
|
605
|
-
### GPU Memory Optimization
|
|
606
|
-
|
|
607
|
-
```python
|
|
608
|
-
import torch
|
|
609
|
-
from contextlib import contextmanager
|
|
610
|
-
|
|
611
|
-
@contextmanager
|
|
612
|
-
def gpu_memory_manager():
|
|
613
|
-
"""Context manager for GPU memory cleanup."""
|
|
614
|
-
try:
|
|
615
|
-
yield
|
|
616
|
-
finally:
|
|
617
|
-
torch.cuda.empty_cache()
|
|
618
|
-
torch.cuda.synchronize()
|
|
619
|
-
|
|
620
|
-
def get_gpu_memory_usage() -> dict:
|
|
621
|
-
"""Get current GPU memory statistics."""
|
|
622
|
-
if not torch.cuda.is_available():
|
|
623
|
-
return {"available": False}
|
|
624
|
-
|
|
625
|
-
return {
|
|
626
|
-
"allocated": torch.cuda.memory_allocated() / 1e9,
|
|
627
|
-
"reserved": torch.cuda.memory_reserved() / 1e9,
|
|
628
|
-
"max_allocated": torch.cuda.max_memory_allocated() / 1e9,
|
|
629
|
-
}
|
|
630
|
-
|
|
631
|
-
class GradientCheckpointing:
|
|
632
|
-
"""Enable gradient checkpointing for memory efficiency."""
|
|
633
|
-
|
|
634
|
-
@staticmethod
|
|
635
|
-
def enable(model: nn.Module, checkpoint_layers: list[str] = None) -> None:
|
|
636
|
-
"""Enable gradient checkpointing on specified layers."""
|
|
637
|
-
if hasattr(model, "gradient_checkpointing_enable"):
|
|
638
|
-
model.gradient_checkpointing_enable()
|
|
639
|
-
return
|
|
640
|
-
|
|
641
|
-
# Manual checkpointing for custom models
|
|
642
|
-
from torch.utils.checkpoint import checkpoint
|
|
643
|
-
|
|
644
|
-
def create_custom_forward(module):
|
|
645
|
-
def custom_forward(*inputs):
|
|
646
|
-
return checkpoint(module._original_forward, *inputs, use_reentrant=False)
|
|
647
|
-
return custom_forward
|
|
648
|
-
|
|
649
|
-
for name, module in model.named_modules():
|
|
650
|
-
if checkpoint_layers and name not in checkpoint_layers:
|
|
651
|
-
continue
|
|
652
|
-
if hasattr(module, "forward"):
|
|
653
|
-
module._original_forward = module.forward
|
|
654
|
-
module.forward = create_custom_forward(module)
|
|
655
|
-
```
|
|
656
|
-
|
|
657
|
-
### Batch Size Finder
|
|
658
|
-
|
|
659
|
-
```python
|
|
660
|
-
def find_optimal_batch_size(
|
|
661
|
-
model: nn.Module,
|
|
662
|
-
sample_batch: dict,
|
|
663
|
-
device: torch.device,
|
|
664
|
-
min_batch_size: int = 1,
|
|
665
|
-
max_batch_size: int = 256,
|
|
666
|
-
) -> int:
|
|
667
|
-
"""Find maximum batch size that fits in GPU memory."""
|
|
668
|
-
|
|
669
|
-
model = model.to(device)
|
|
670
|
-
optimal_batch_size = min_batch_size
|
|
671
|
-
|
|
672
|
-
for batch_size in [2**i for i in range(int(np.log2(max_batch_size)) + 1)]:
|
|
673
|
-
if batch_size < min_batch_size:
|
|
674
|
-
continue
|
|
675
|
-
|
|
676
|
-
try:
|
|
677
|
-
# Create batch of target size
|
|
678
|
-
batch = {
|
|
679
|
-
k: v.repeat(batch_size // v.size(0) + 1, *[1] * (v.dim() - 1))[:batch_size]
|
|
680
|
-
for k, v in sample_batch.items()
|
|
681
|
-
}
|
|
682
|
-
batch = {k: v.to(device) for k, v in batch.items()}
|
|
683
|
-
|
|
684
|
-
# Forward pass
|
|
685
|
-
with torch.amp.autocast("cuda"):
|
|
686
|
-
outputs = model(**batch)
|
|
687
|
-
loss = outputs.loss
|
|
688
|
-
|
|
689
|
-
# Backward pass
|
|
690
|
-
loss.backward()
|
|
691
|
-
model.zero_grad()
|
|
692
|
-
|
|
693
|
-
torch.cuda.empty_cache()
|
|
694
|
-
optimal_batch_size = batch_size
|
|
695
|
-
|
|
696
|
-
except RuntimeError as e:
|
|
697
|
-
if "out of memory" in str(e):
|
|
698
|
-
torch.cuda.empty_cache()
|
|
699
|
-
break
|
|
700
|
-
raise
|
|
701
|
-
|
|
702
|
-
logger.info(f"Optimal batch size: {optimal_batch_size}")
|
|
703
|
-
return optimal_batch_size
|
|
704
|
-
```
|
|
705
|
-
|
|
706
|
-
---
|
|
707
|
-
|
|
708
|
-
## Best Practices
|
|
709
|
-
|
|
710
|
-
### Training Configuration Management
|
|
711
|
-
|
|
712
|
-
```yaml
|
|
713
|
-
# config/training_config.yaml
|
|
714
|
-
model:
|
|
715
|
-
name: transformer
|
|
716
|
-
hidden_size: 512
|
|
717
|
-
num_layers: 6
|
|
718
|
-
dropout: 0.1
|
|
719
|
-
|
|
720
|
-
training:
|
|
721
|
-
batch_size: 32
|
|
722
|
-
learning_rate: 1e-4
|
|
723
|
-
weight_decay: 0.01
|
|
724
|
-
epochs: 10
|
|
725
|
-
mixed_precision: true
|
|
726
|
-
gradient_accumulation_steps: 4
|
|
727
|
-
|
|
728
|
-
distributed:
|
|
729
|
-
enabled: true
|
|
730
|
-
backend: nccl
|
|
731
|
-
|
|
732
|
-
checkpointing:
|
|
733
|
-
save_every_n_steps: 1000
|
|
734
|
-
keep_n_checkpoints: 3
|
|
735
|
-
|
|
736
|
-
logging:
|
|
737
|
-
log_every_n_steps: 100
|
|
738
|
-
eval_every_n_steps: 500
|
|
739
|
-
```
|
|
740
|
-
|
|
741
|
-
### Reproducibility Checklist
|
|
742
|
-
|
|
743
|
-
```python
|
|
744
|
-
def ensure_reproducibility(seed: int) -> None:
|
|
745
|
-
"""Set all random seeds for reproducibility."""
|
|
746
|
-
import random
|
|
747
|
-
import numpy as np
|
|
748
|
-
import os
|
|
749
|
-
|
|
750
|
-
# Python
|
|
751
|
-
random.seed(seed)
|
|
752
|
-
|
|
753
|
-
# NumPy
|
|
754
|
-
np.random.seed(seed)
|
|
755
|
-
|
|
756
|
-
# PyTorch
|
|
757
|
-
torch.manual_seed(seed)
|
|
758
|
-
torch.cuda.manual_seed_all(seed)
|
|
759
|
-
|
|
760
|
-
# CUDA
|
|
761
|
-
torch.backends.cudnn.deterministic = True
|
|
762
|
-
torch.backends.cudnn.benchmark = False
|
|
763
|
-
|
|
764
|
-
# Environment
|
|
765
|
-
os.environ["PYTHONHASHSEED"] = str(seed)
|
|
766
|
-
|
|
767
|
-
logger.info(f"Set all random seeds to {seed}")
|
|
768
|
-
```
|
|
769
|
-
|
|
770
|
-
---
|
|
771
|
-
|
|
772
|
-
## Related References
|
|
773
|
-
|
|
774
|
-
- `feature-engineering.md` - Feature preparation for training
|
|
775
|
-
- `experiment-tracking.md` - Logging training metrics
|
|
776
|
-
- `pipeline-orchestration.md` - Orchestrating training pipelines
|
|
777
|
-
- `model-validation.md` - Validating trained models
|
|
778
|
-
|
|
779
|
-
## Cross-Reference Skills
|
|
780
|
-
|
|
781
|
-
- **DevOps Engineer** - CI/CD for training pipelines
|
|
782
|
-
- **Kubernetes Specialist** - K8s-based training infrastructure
|
|
1
|
+
# Training Pipelines
|
|
2
|
+
|
|
3
|
+
---
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
Training pipelines orchestrate the end-to-end model training process including data loading, distributed training, hyperparameter optimization, and artifact management. Production pipelines require reproducibility, scalability, and proper resource management.
|
|
8
|
+
|
|
9
|
+
## When to Use This Reference
|
|
10
|
+
|
|
11
|
+
- Setting up distributed training with PyTorch/TensorFlow
|
|
12
|
+
- Implementing hyperparameter tuning (Optuna, Ray Tune)
|
|
13
|
+
- Managing GPU/TPU resources for training
|
|
14
|
+
- Building reproducible training environments
|
|
15
|
+
- Creating checkpointing and fault-tolerant training
|
|
16
|
+
|
|
17
|
+
## When NOT to Use
|
|
18
|
+
|
|
19
|
+
- Quick model prototyping (use notebooks)
|
|
20
|
+
- Small models that fit in memory on single GPU
|
|
21
|
+
- One-off experiments without production requirements
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## PyTorch Training Pipeline
|
|
26
|
+
|
|
27
|
+
### Complete Training Script
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
import torch
|
|
31
|
+
import torch.nn as nn
|
|
32
|
+
from torch.utils.data import DataLoader, Dataset
|
|
33
|
+
from torch.optim import AdamW
|
|
34
|
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
35
|
+
import logging
|
|
36
|
+
from pathlib import Path
|
|
37
|
+
from dataclasses import dataclass
|
|
38
|
+
from typing import Optional
|
|
39
|
+
import json
|
|
40
|
+
|
|
41
|
+
logger = logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
@dataclass
|
|
44
|
+
class TrainingConfig:
|
|
45
|
+
"""Training hyperparameters and settings."""
|
|
46
|
+
model_name: str
|
|
47
|
+
batch_size: int = 32
|
|
48
|
+
learning_rate: float = 1e-4
|
|
49
|
+
weight_decay: float = 0.01
|
|
50
|
+
epochs: int = 10
|
|
51
|
+
warmup_steps: int = 100
|
|
52
|
+
max_grad_norm: float = 1.0
|
|
53
|
+
seed: int = 42
|
|
54
|
+
checkpoint_dir: str = "./checkpoints"
|
|
55
|
+
log_every_n_steps: int = 100
|
|
56
|
+
eval_every_n_steps: int = 500
|
|
57
|
+
save_every_n_steps: int = 1000
|
|
58
|
+
mixed_precision: bool = True
|
|
59
|
+
gradient_accumulation_steps: int = 1
|
|
60
|
+
|
|
61
|
+
def to_dict(self) -> dict:
|
|
62
|
+
return {k: v for k, v in self.__dict__.items()}
|
|
63
|
+
|
|
64
|
+
@classmethod
|
|
65
|
+
def from_dict(cls, d: dict) -> "TrainingConfig":
|
|
66
|
+
return cls(**d)
|
|
67
|
+
|
|
68
|
+
class Trainer:
|
|
69
|
+
"""Production-grade PyTorch trainer."""
|
|
70
|
+
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
model: nn.Module,
|
|
74
|
+
config: TrainingConfig,
|
|
75
|
+
train_dataloader: DataLoader,
|
|
76
|
+
eval_dataloader: Optional[DataLoader] = None,
|
|
77
|
+
experiment_tracker=None,
|
|
78
|
+
):
|
|
79
|
+
self.model = model
|
|
80
|
+
self.config = config
|
|
81
|
+
self.train_dataloader = train_dataloader
|
|
82
|
+
self.eval_dataloader = eval_dataloader
|
|
83
|
+
self.tracker = experiment_tracker
|
|
84
|
+
|
|
85
|
+
self._setup_device()
|
|
86
|
+
self._setup_training()
|
|
87
|
+
self._setup_checkpointing()
|
|
88
|
+
|
|
89
|
+
def _setup_device(self) -> None:
|
|
90
|
+
"""Configure device and move model."""
|
|
91
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
92
|
+
self.model = self.model.to(self.device)
|
|
93
|
+
|
|
94
|
+
if self.config.mixed_precision and self.device.type == "cuda":
|
|
95
|
+
self.scaler = torch.amp.GradScaler("cuda")
|
|
96
|
+
else:
|
|
97
|
+
self.scaler = None
|
|
98
|
+
|
|
99
|
+
logger.info(f"Training on device: {self.device}")
|
|
100
|
+
|
|
101
|
+
def _setup_training(self) -> None:
|
|
102
|
+
"""Initialize optimizer and scheduler."""
|
|
103
|
+
self.optimizer = AdamW(
|
|
104
|
+
self.model.parameters(),
|
|
105
|
+
lr=self.config.learning_rate,
|
|
106
|
+
weight_decay=self.config.weight_decay,
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
total_steps = len(self.train_dataloader) * self.config.epochs
|
|
110
|
+
self.scheduler = CosineAnnealingLR(
|
|
111
|
+
self.optimizer,
|
|
112
|
+
T_max=total_steps,
|
|
113
|
+
eta_min=self.config.learning_rate * 0.01,
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
self.global_step = 0
|
|
117
|
+
self.best_eval_loss = float("inf")
|
|
118
|
+
|
|
119
|
+
def _setup_checkpointing(self) -> None:
|
|
120
|
+
"""Create checkpoint directory."""
|
|
121
|
+
self.checkpoint_dir = Path(self.config.checkpoint_dir)
|
|
122
|
+
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
|
123
|
+
|
|
124
|
+
def _set_seed(self) -> None:
|
|
125
|
+
"""Set random seeds for reproducibility."""
|
|
126
|
+
import random
|
|
127
|
+
import numpy as np
|
|
128
|
+
|
|
129
|
+
torch.manual_seed(self.config.seed)
|
|
130
|
+
torch.cuda.manual_seed_all(self.config.seed)
|
|
131
|
+
np.random.seed(self.config.seed)
|
|
132
|
+
random.seed(self.config.seed)
|
|
133
|
+
torch.backends.cudnn.deterministic = True
|
|
134
|
+
|
|
135
|
+
def train(self) -> dict:
|
|
136
|
+
"""Run training loop."""
|
|
137
|
+
self._set_seed()
|
|
138
|
+
self.model.train()
|
|
139
|
+
|
|
140
|
+
metrics_history = []
|
|
141
|
+
|
|
142
|
+
for epoch in range(self.config.epochs):
|
|
143
|
+
epoch_loss = 0.0
|
|
144
|
+
num_batches = 0
|
|
145
|
+
|
|
146
|
+
for batch_idx, batch in enumerate(self.train_dataloader):
|
|
147
|
+
loss = self._training_step(batch)
|
|
148
|
+
epoch_loss += loss
|
|
149
|
+
num_batches += 1
|
|
150
|
+
|
|
151
|
+
if self.global_step % self.config.log_every_n_steps == 0:
|
|
152
|
+
self._log_metrics({
|
|
153
|
+
"train/loss": loss,
|
|
154
|
+
"train/lr": self.scheduler.get_last_lr()[0],
|
|
155
|
+
"train/epoch": epoch,
|
|
156
|
+
})
|
|
157
|
+
|
|
158
|
+
if (
|
|
159
|
+
self.eval_dataloader
|
|
160
|
+
and self.global_step % self.config.eval_every_n_steps == 0
|
|
161
|
+
):
|
|
162
|
+
eval_metrics = self.evaluate()
|
|
163
|
+
self._log_metrics(eval_metrics)
|
|
164
|
+
|
|
165
|
+
if eval_metrics["eval/loss"] < self.best_eval_loss:
|
|
166
|
+
self.best_eval_loss = eval_metrics["eval/loss"]
|
|
167
|
+
self.save_checkpoint("best")
|
|
168
|
+
|
|
169
|
+
if self.global_step % self.config.save_every_n_steps == 0:
|
|
170
|
+
self.save_checkpoint(f"step_{self.global_step}")
|
|
171
|
+
|
|
172
|
+
avg_epoch_loss = epoch_loss / num_batches
|
|
173
|
+
logger.info(f"Epoch {epoch}: avg_loss={avg_epoch_loss:.4f}")
|
|
174
|
+
metrics_history.append({"epoch": epoch, "loss": avg_epoch_loss})
|
|
175
|
+
|
|
176
|
+
self.save_checkpoint("final")
|
|
177
|
+
|
|
178
|
+
return {
|
|
179
|
+
"best_eval_loss": self.best_eval_loss,
|
|
180
|
+
"final_train_loss": avg_epoch_loss,
|
|
181
|
+
"total_steps": self.global_step,
|
|
182
|
+
"metrics_history": metrics_history,
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
def _training_step(self, batch: dict) -> float:
|
|
186
|
+
"""Execute single training step."""
|
|
187
|
+
batch = {k: v.to(self.device) for k, v in batch.items()}
|
|
188
|
+
|
|
189
|
+
if self.scaler:
|
|
190
|
+
with torch.amp.autocast("cuda"):
|
|
191
|
+
outputs = self.model(**batch)
|
|
192
|
+
loss = outputs.loss / self.config.gradient_accumulation_steps
|
|
193
|
+
self.scaler.scale(loss).backward()
|
|
194
|
+
else:
|
|
195
|
+
outputs = self.model(**batch)
|
|
196
|
+
loss = outputs.loss / self.config.gradient_accumulation_steps
|
|
197
|
+
loss.backward()
|
|
198
|
+
|
|
199
|
+
if (self.global_step + 1) % self.config.gradient_accumulation_steps == 0:
|
|
200
|
+
if self.scaler:
|
|
201
|
+
self.scaler.unscale_(self.optimizer)
|
|
202
|
+
|
|
203
|
+
torch.nn.utils.clip_grad_norm_(
|
|
204
|
+
self.model.parameters(),
|
|
205
|
+
self.config.max_grad_norm,
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
if self.scaler:
|
|
209
|
+
self.scaler.step(self.optimizer)
|
|
210
|
+
self.scaler.update()
|
|
211
|
+
else:
|
|
212
|
+
self.optimizer.step()
|
|
213
|
+
|
|
214
|
+
self.scheduler.step()
|
|
215
|
+
self.optimizer.zero_grad()
|
|
216
|
+
|
|
217
|
+
self.global_step += 1
|
|
218
|
+
return loss.item() * self.config.gradient_accumulation_steps
|
|
219
|
+
|
|
220
|
+
@torch.no_grad()
|
|
221
|
+
def evaluate(self) -> dict:
|
|
222
|
+
"""Run evaluation loop."""
|
|
223
|
+
self.model.eval()
|
|
224
|
+
total_loss = 0.0
|
|
225
|
+
num_batches = 0
|
|
226
|
+
|
|
227
|
+
for batch in self.eval_dataloader:
|
|
228
|
+
batch = {k: v.to(self.device) for k, v in batch.items()}
|
|
229
|
+
|
|
230
|
+
if self.scaler:
|
|
231
|
+
with torch.amp.autocast("cuda"):
|
|
232
|
+
outputs = self.model(**batch)
|
|
233
|
+
else:
|
|
234
|
+
outputs = self.model(**batch)
|
|
235
|
+
|
|
236
|
+
total_loss += outputs.loss.item()
|
|
237
|
+
num_batches += 1
|
|
238
|
+
|
|
239
|
+
self.model.train()
|
|
240
|
+
|
|
241
|
+
return {
|
|
242
|
+
"eval/loss": total_loss / num_batches,
|
|
243
|
+
"eval/step": self.global_step,
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
def save_checkpoint(self, name: str) -> Path:
|
|
247
|
+
"""Save model checkpoint."""
|
|
248
|
+
checkpoint_path = self.checkpoint_dir / name
|
|
249
|
+
|
|
250
|
+
torch.save({
|
|
251
|
+
"model_state_dict": self.model.state_dict(),
|
|
252
|
+
"optimizer_state_dict": self.optimizer.state_dict(),
|
|
253
|
+
"scheduler_state_dict": self.scheduler.state_dict(),
|
|
254
|
+
"global_step": self.global_step,
|
|
255
|
+
"best_eval_loss": self.best_eval_loss,
|
|
256
|
+
"config": self.config.to_dict(),
|
|
257
|
+
}, checkpoint_path / "checkpoint.pt")
|
|
258
|
+
|
|
259
|
+
# Save config separately for easy loading
|
|
260
|
+
with open(checkpoint_path / "config.json", "w") as f:
|
|
261
|
+
json.dump(self.config.to_dict(), f, indent=2)
|
|
262
|
+
|
|
263
|
+
logger.info(f"Saved checkpoint: {checkpoint_path}")
|
|
264
|
+
return checkpoint_path
|
|
265
|
+
|
|
266
|
+
def load_checkpoint(self, checkpoint_path: Path) -> None:
|
|
267
|
+
"""Load model checkpoint."""
|
|
268
|
+
checkpoint = torch.load(checkpoint_path / "checkpoint.pt", map_location=self.device)
|
|
269
|
+
|
|
270
|
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
|
271
|
+
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
272
|
+
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
|
273
|
+
self.global_step = checkpoint["global_step"]
|
|
274
|
+
self.best_eval_loss = checkpoint["best_eval_loss"]
|
|
275
|
+
|
|
276
|
+
logger.info(f"Loaded checkpoint from step {self.global_step}")
|
|
277
|
+
|
|
278
|
+
def _log_metrics(self, metrics: dict) -> None:
|
|
279
|
+
"""Log metrics to tracker and console."""
|
|
280
|
+
if self.tracker:
|
|
281
|
+
self.tracker.log_metrics(metrics, step=self.global_step)
|
|
282
|
+
|
|
283
|
+
logger.info(f"Step {self.global_step}: {metrics}")
|
|
284
|
+
```
|
|
285
|
+
|
|
286
|
+
---
|
|
287
|
+
|
|
288
|
+
## Distributed Training
|
|
289
|
+
|
|
290
|
+
### PyTorch Distributed Data Parallel
|
|
291
|
+
|
|
292
|
+
```python
|
|
293
|
+
import torch
|
|
294
|
+
import torch.distributed as dist
|
|
295
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
296
|
+
from torch.utils.data.distributed import DistributedSampler
|
|
297
|
+
import os
|
|
298
|
+
|
|
299
|
+
def setup_distributed() -> tuple[int, int, int]:
|
|
300
|
+
"""Initialize distributed training environment."""
|
|
301
|
+
if "RANK" in os.environ:
|
|
302
|
+
rank = int(os.environ["RANK"])
|
|
303
|
+
local_rank = int(os.environ["LOCAL_RANK"])
|
|
304
|
+
world_size = int(os.environ["WORLD_SIZE"])
|
|
305
|
+
else:
|
|
306
|
+
rank = 0
|
|
307
|
+
local_rank = 0
|
|
308
|
+
world_size = 1
|
|
309
|
+
|
|
310
|
+
if world_size > 1:
|
|
311
|
+
dist.init_process_group(
|
|
312
|
+
backend="nccl",
|
|
313
|
+
init_method="env://",
|
|
314
|
+
world_size=world_size,
|
|
315
|
+
rank=rank,
|
|
316
|
+
)
|
|
317
|
+
torch.cuda.set_device(local_rank)
|
|
318
|
+
|
|
319
|
+
return rank, local_rank, world_size
|
|
320
|
+
|
|
321
|
+
def cleanup_distributed() -> None:
|
|
322
|
+
"""Cleanup distributed training."""
|
|
323
|
+
if dist.is_initialized():
|
|
324
|
+
dist.destroy_process_group()
|
|
325
|
+
|
|
326
|
+
class DistributedTrainer(Trainer):
|
|
327
|
+
"""Trainer with DDP support."""
|
|
328
|
+
|
|
329
|
+
def __init__(self, *args, **kwargs):
|
|
330
|
+
self.rank, self.local_rank, self.world_size = setup_distributed()
|
|
331
|
+
super().__init__(*args, **kwargs)
|
|
332
|
+
|
|
333
|
+
def _setup_device(self) -> None:
|
|
334
|
+
"""Configure device for distributed training."""
|
|
335
|
+
if self.world_size > 1:
|
|
336
|
+
self.device = torch.device(f"cuda:{self.local_rank}")
|
|
337
|
+
self.model = self.model.to(self.device)
|
|
338
|
+
self.model = DDP(
|
|
339
|
+
self.model,
|
|
340
|
+
device_ids=[self.local_rank],
|
|
341
|
+
output_device=self.local_rank,
|
|
342
|
+
find_unused_parameters=False,
|
|
343
|
+
)
|
|
344
|
+
else:
|
|
345
|
+
super()._setup_device()
|
|
346
|
+
|
|
347
|
+
if self.config.mixed_precision and self.device.type == "cuda":
|
|
348
|
+
self.scaler = torch.amp.GradScaler("cuda")
|
|
349
|
+
else:
|
|
350
|
+
self.scaler = None
|
|
351
|
+
|
|
352
|
+
def save_checkpoint(self, name: str) -> Path:
|
|
353
|
+
"""Only save on rank 0."""
|
|
354
|
+
if self.rank == 0:
|
|
355
|
+
return super().save_checkpoint(name)
|
|
356
|
+
return None
|
|
357
|
+
|
|
358
|
+
def _log_metrics(self, metrics: dict) -> None:
|
|
359
|
+
"""Only log on rank 0."""
|
|
360
|
+
if self.rank == 0:
|
|
361
|
+
super()._log_metrics(metrics)
|
|
362
|
+
|
|
363
|
+
def create_distributed_dataloader(
|
|
364
|
+
dataset: Dataset,
|
|
365
|
+
batch_size: int,
|
|
366
|
+
world_size: int,
|
|
367
|
+
rank: int,
|
|
368
|
+
shuffle: bool = True,
|
|
369
|
+
) -> DataLoader:
|
|
370
|
+
"""Create DataLoader with distributed sampler."""
|
|
371
|
+
sampler = DistributedSampler(
|
|
372
|
+
dataset,
|
|
373
|
+
num_replicas=world_size,
|
|
374
|
+
rank=rank,
|
|
375
|
+
shuffle=shuffle,
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
return DataLoader(
|
|
379
|
+
dataset,
|
|
380
|
+
batch_size=batch_size,
|
|
381
|
+
sampler=sampler,
|
|
382
|
+
num_workers=4,
|
|
383
|
+
pin_memory=True,
|
|
384
|
+
drop_last=True,
|
|
385
|
+
)
|
|
386
|
+
```
|
|
387
|
+
|
|
388
|
+
### Launch Script
|
|
389
|
+
|
|
390
|
+
```bash
|
|
391
|
+
#!/bin/bash
|
|
392
|
+
# launch_distributed.sh
|
|
393
|
+
|
|
394
|
+
NUM_GPUS=4
|
|
395
|
+
MASTER_PORT=29500
|
|
396
|
+
|
|
397
|
+
torchrun \
|
|
398
|
+
--nproc_per_node=$NUM_GPUS \
|
|
399
|
+
--master_port=$MASTER_PORT \
|
|
400
|
+
train.py \
|
|
401
|
+
--config config/training_config.yaml
|
|
402
|
+
```
|
|
403
|
+
|
|
404
|
+
---
|
|
405
|
+
|
|
406
|
+
## Hyperparameter Tuning
|
|
407
|
+
|
|
408
|
+
### Optuna Integration
|
|
409
|
+
|
|
410
|
+
```python
|
|
411
|
+
import optuna
|
|
412
|
+
from optuna.trial import Trial
|
|
413
|
+
from optuna.integration import PyTorchLightningPruningCallback
|
|
414
|
+
import mlflow
|
|
415
|
+
|
|
416
|
+
def create_objective(
|
|
417
|
+
train_dataset: Dataset,
|
|
418
|
+
eval_dataset: Dataset,
|
|
419
|
+
model_class: type,
|
|
420
|
+
) -> callable:
|
|
421
|
+
"""Create Optuna objective function."""
|
|
422
|
+
|
|
423
|
+
def objective(trial: Trial) -> float:
|
|
424
|
+
# Sample hyperparameters
|
|
425
|
+
config = TrainingConfig(
|
|
426
|
+
model_name="tuned_model",
|
|
427
|
+
learning_rate=trial.suggest_float("lr", 1e-5, 1e-3, log=True),
|
|
428
|
+
batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
|
|
429
|
+
weight_decay=trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True),
|
|
430
|
+
epochs=trial.suggest_int("epochs", 3, 10),
|
|
431
|
+
warmup_steps=trial.suggest_int("warmup_steps", 0, 500),
|
|
432
|
+
)
|
|
433
|
+
|
|
434
|
+
# Create data loaders
|
|
435
|
+
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
|
|
436
|
+
eval_loader = DataLoader(eval_dataset, batch_size=config.batch_size)
|
|
437
|
+
|
|
438
|
+
# Create model
|
|
439
|
+
model = model_class(
|
|
440
|
+
hidden_size=trial.suggest_categorical("hidden_size", [128, 256, 512]),
|
|
441
|
+
num_layers=trial.suggest_int("num_layers", 2, 6),
|
|
442
|
+
dropout=trial.suggest_float("dropout", 0.1, 0.5),
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
# Train
|
|
446
|
+
trainer = Trainer(
|
|
447
|
+
model=model,
|
|
448
|
+
config=config,
|
|
449
|
+
train_dataloader=train_loader,
|
|
450
|
+
eval_dataloader=eval_loader,
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
# Report intermediate values for pruning
|
|
454
|
+
for epoch in range(config.epochs):
|
|
455
|
+
trainer.train_epoch()
|
|
456
|
+
eval_loss = trainer.evaluate()["eval/loss"]
|
|
457
|
+
|
|
458
|
+
trial.report(eval_loss, epoch)
|
|
459
|
+
|
|
460
|
+
if trial.should_prune():
|
|
461
|
+
raise optuna.TrialPruned()
|
|
462
|
+
|
|
463
|
+
return trainer.best_eval_loss
|
|
464
|
+
|
|
465
|
+
return objective
|
|
466
|
+
|
|
467
|
+
def run_hyperparameter_search(
|
|
468
|
+
train_dataset: Dataset,
|
|
469
|
+
eval_dataset: Dataset,
|
|
470
|
+
model_class: type,
|
|
471
|
+
n_trials: int = 100,
|
|
472
|
+
study_name: str = "hpo_study",
|
|
473
|
+
) -> optuna.Study:
|
|
474
|
+
"""Run hyperparameter optimization with Optuna."""
|
|
475
|
+
|
|
476
|
+
# Create study with pruning
|
|
477
|
+
pruner = optuna.pruners.MedianPruner(
|
|
478
|
+
n_startup_trials=5,
|
|
479
|
+
n_warmup_steps=3,
|
|
480
|
+
interval_steps=1,
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
study = optuna.create_study(
|
|
484
|
+
study_name=study_name,
|
|
485
|
+
direction="minimize",
|
|
486
|
+
pruner=pruner,
|
|
487
|
+
storage=f"sqlite:///{study_name}.db",
|
|
488
|
+
load_if_exists=True,
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
objective = create_objective(train_dataset, eval_dataset, model_class)
|
|
492
|
+
|
|
493
|
+
study.optimize(
|
|
494
|
+
objective,
|
|
495
|
+
n_trials=n_trials,
|
|
496
|
+
timeout=3600 * 12, # 12 hours
|
|
497
|
+
n_jobs=1, # Sequential for GPU
|
|
498
|
+
show_progress_bar=True,
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
# Log best results
|
|
502
|
+
logger.info(f"Best trial: {study.best_trial.params}")
|
|
503
|
+
logger.info(f"Best value: {study.best_value}")
|
|
504
|
+
|
|
505
|
+
return study
|
|
506
|
+
```
|
|
507
|
+
|
|
508
|
+
### Ray Tune Integration
|
|
509
|
+
|
|
510
|
+
```python
|
|
511
|
+
from ray import tune
|
|
512
|
+
from ray.tune.schedulers import ASHAScheduler
|
|
513
|
+
from ray.tune.search.optuna import OptunaSearch
|
|
514
|
+
from ray.air import RunConfig, CheckpointConfig
|
|
515
|
+
|
|
516
|
+
def train_fn(config: dict) -> None:
|
|
517
|
+
"""Training function for Ray Tune."""
|
|
518
|
+
from ray.train import report, get_checkpoint
|
|
519
|
+
|
|
520
|
+
training_config = TrainingConfig(
|
|
521
|
+
model_name="ray_tune_model",
|
|
522
|
+
learning_rate=config["lr"],
|
|
523
|
+
batch_size=config["batch_size"],
|
|
524
|
+
weight_decay=config["weight_decay"],
|
|
525
|
+
epochs=config["epochs"],
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
# Build model and dataloaders
|
|
529
|
+
model = build_model(config["hidden_size"], config["num_layers"])
|
|
530
|
+
train_loader, eval_loader = build_dataloaders(config["batch_size"])
|
|
531
|
+
|
|
532
|
+
trainer = Trainer(
|
|
533
|
+
model=model,
|
|
534
|
+
config=training_config,
|
|
535
|
+
train_dataloader=train_loader,
|
|
536
|
+
eval_dataloader=eval_loader,
|
|
537
|
+
)
|
|
538
|
+
|
|
539
|
+
# Resume from checkpoint if available
|
|
540
|
+
checkpoint = get_checkpoint()
|
|
541
|
+
if checkpoint:
|
|
542
|
+
with checkpoint.as_directory() as checkpoint_dir:
|
|
543
|
+
trainer.load_checkpoint(Path(checkpoint_dir))
|
|
544
|
+
|
|
545
|
+
for epoch in range(training_config.epochs):
|
|
546
|
+
trainer.train_epoch()
|
|
547
|
+
metrics = trainer.evaluate()
|
|
548
|
+
|
|
549
|
+
# Report metrics to Ray Tune
|
|
550
|
+
report(
|
|
551
|
+
{"loss": metrics["eval/loss"], "epoch": epoch},
|
|
552
|
+
checkpoint=Checkpoint.from_directory(trainer.checkpoint_dir),
|
|
553
|
+
)
|
|
554
|
+
|
|
555
|
+
def run_ray_tune(num_samples: int = 50) -> tune.ResultGrid:
|
|
556
|
+
"""Run hyperparameter search with Ray Tune."""
|
|
557
|
+
|
|
558
|
+
search_space = {
|
|
559
|
+
"lr": tune.loguniform(1e-5, 1e-3),
|
|
560
|
+
"batch_size": tune.choice([16, 32, 64]),
|
|
561
|
+
"weight_decay": tune.loguniform(1e-5, 1e-2),
|
|
562
|
+
"hidden_size": tune.choice([128, 256, 512]),
|
|
563
|
+
"num_layers": tune.randint(2, 7),
|
|
564
|
+
"epochs": 10,
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
scheduler = ASHAScheduler(
|
|
568
|
+
metric="loss",
|
|
569
|
+
mode="min",
|
|
570
|
+
max_t=10,
|
|
571
|
+
grace_period=2,
|
|
572
|
+
reduction_factor=3,
|
|
573
|
+
)
|
|
574
|
+
|
|
575
|
+
tuner = tune.Tuner(
|
|
576
|
+
tune.with_resources(train_fn, {"gpu": 1}),
|
|
577
|
+
param_space=search_space,
|
|
578
|
+
tune_config=tune.TuneConfig(
|
|
579
|
+
num_samples=num_samples,
|
|
580
|
+
scheduler=scheduler,
|
|
581
|
+
search_alg=OptunaSearch(),
|
|
582
|
+
),
|
|
583
|
+
run_config=RunConfig(
|
|
584
|
+
name="hpo_experiment",
|
|
585
|
+
checkpoint_config=CheckpointConfig(
|
|
586
|
+
num_to_keep=3,
|
|
587
|
+
checkpoint_frequency=1,
|
|
588
|
+
),
|
|
589
|
+
),
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
results = tuner.fit()
|
|
593
|
+
best_result = results.get_best_result("loss", "min")
|
|
594
|
+
|
|
595
|
+
logger.info(f"Best config: {best_result.config}")
|
|
596
|
+
logger.info(f"Best loss: {best_result.metrics['loss']}")
|
|
597
|
+
|
|
598
|
+
return results
|
|
599
|
+
```
|
|
600
|
+
|
|
601
|
+
---
|
|
602
|
+
|
|
603
|
+
## Resource Management
|
|
604
|
+
|
|
605
|
+
### GPU Memory Optimization
|
|
606
|
+
|
|
607
|
+
```python
|
|
608
|
+
import torch
|
|
609
|
+
from contextlib import contextmanager
|
|
610
|
+
|
|
611
|
+
@contextmanager
|
|
612
|
+
def gpu_memory_manager():
|
|
613
|
+
"""Context manager for GPU memory cleanup."""
|
|
614
|
+
try:
|
|
615
|
+
yield
|
|
616
|
+
finally:
|
|
617
|
+
torch.cuda.empty_cache()
|
|
618
|
+
torch.cuda.synchronize()
|
|
619
|
+
|
|
620
|
+
def get_gpu_memory_usage() -> dict:
|
|
621
|
+
"""Get current GPU memory statistics."""
|
|
622
|
+
if not torch.cuda.is_available():
|
|
623
|
+
return {"available": False}
|
|
624
|
+
|
|
625
|
+
return {
|
|
626
|
+
"allocated": torch.cuda.memory_allocated() / 1e9,
|
|
627
|
+
"reserved": torch.cuda.memory_reserved() / 1e9,
|
|
628
|
+
"max_allocated": torch.cuda.max_memory_allocated() / 1e9,
|
|
629
|
+
}
|
|
630
|
+
|
|
631
|
+
class GradientCheckpointing:
|
|
632
|
+
"""Enable gradient checkpointing for memory efficiency."""
|
|
633
|
+
|
|
634
|
+
@staticmethod
|
|
635
|
+
def enable(model: nn.Module, checkpoint_layers: list[str] = None) -> None:
|
|
636
|
+
"""Enable gradient checkpointing on specified layers."""
|
|
637
|
+
if hasattr(model, "gradient_checkpointing_enable"):
|
|
638
|
+
model.gradient_checkpointing_enable()
|
|
639
|
+
return
|
|
640
|
+
|
|
641
|
+
# Manual checkpointing for custom models
|
|
642
|
+
from torch.utils.checkpoint import checkpoint
|
|
643
|
+
|
|
644
|
+
def create_custom_forward(module):
|
|
645
|
+
def custom_forward(*inputs):
|
|
646
|
+
return checkpoint(module._original_forward, *inputs, use_reentrant=False)
|
|
647
|
+
return custom_forward
|
|
648
|
+
|
|
649
|
+
for name, module in model.named_modules():
|
|
650
|
+
if checkpoint_layers and name not in checkpoint_layers:
|
|
651
|
+
continue
|
|
652
|
+
if hasattr(module, "forward"):
|
|
653
|
+
module._original_forward = module.forward
|
|
654
|
+
module.forward = create_custom_forward(module)
|
|
655
|
+
```
|
|
656
|
+
|
|
657
|
+
### Batch Size Finder
|
|
658
|
+
|
|
659
|
+
```python
|
|
660
|
+
def find_optimal_batch_size(
    model: nn.Module,
    sample_batch: dict,
    device: torch.device,
    min_batch_size: int = 1,
    max_batch_size: int = 256,
) -> int:
    """Find the largest power-of-two batch size that fits in GPU memory.

    Each candidate (1, 2, 4, ... up to ``max_batch_size``, skipping values
    below ``min_batch_size``) is built by tiling ``sample_batch`` along
    dim 0 and is then run through one autocast forward pass and one
    backward pass. The search stops at the first candidate that raises an
    out-of-memory error.

    Args:
        model: Module called as ``model(**batch)``; its output must expose
            a ``.loss`` attribute. It is moved to ``device``.
        sample_batch: Mapping from keyword name to tensor. Every tensor is
            tiled along dim 0 to reach the candidate batch size.
        device: Device the model and the generated batches are moved to.
        min_batch_size: Candidates smaller than this are not tried.
        max_batch_size: Upper bound for the candidates.

    Returns:
        The largest candidate that completed forward and backward, or
        ``min_batch_size`` if no candidate did.

    Raises:
        RuntimeError: Any runtime error whose message does not contain
            "out of memory" is propagated unchanged.
    """
    model = model.to(device)
    optimal_batch_size = min_batch_size

    # Integer doubling instead of a float log2: exact for any bound and
    # simply yields no candidates when max_batch_size < 1.
    candidates = []
    size = 1
    while size <= max_batch_size:
        if size >= min_batch_size:
            candidates.append(size)
        size *= 2

    any_fit = False
    for batch_size in candidates:
        batch = outputs = loss = None
        hit_oom = False

        try:
            # Create batch of target size: tile along dim 0 (ceil division
            # gives just enough repeats), trim, then move to the device.
            batch = {
                k: v.repeat(-(-batch_size // v.size(0)), *[1] * (v.dim() - 1))[
                    :batch_size
                ].to(device)
                for k, v in sample_batch.items()
            }

            # Forward pass
            with torch.amp.autocast("cuda"):
                outputs = model(**batch)
                loss = outputs.loss

            # Backward pass
            loss.backward()
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            # Only record the failure here. While the exception is alive it
            # references the frames that hold the failed tensors, so memory
            # cannot be reclaimed until the except clause has ended.
            hit_oom = True

        # Recovery runs outside the except clause, on both paths: drop our
        # own references, clear (possibly partial) gradients, then release
        # the cached blocks.
        del batch, outputs, loss
        model.zero_grad()
        torch.cuda.empty_cache()

        if hit_oom:
            break
        optimal_batch_size = batch_size
        any_fit = True

    if not any_fit:
        logger.warning(
            f"No candidate batch size fit in memory; returning min_batch_size={min_batch_size}"
        )
    logger.info(f"Optimal batch size: {optimal_batch_size}")
    return optimal_batch_size
|
|
704
|
+
```
|
|
705
|
+
|
|
706
|
+
---
|
|
707
|
+
|
|
708
|
+
## Best Practices
|
|
709
|
+
|
|
710
|
+
### Training Configuration Management
|
|
711
|
+
|
|
712
|
+
```yaml
|
|
713
|
+
# config/training_config.yaml
|
|
714
|
+
model:
|
|
715
|
+
name: transformer
|
|
716
|
+
hidden_size: 512
|
|
717
|
+
num_layers: 6
|
|
718
|
+
dropout: 0.1
|
|
719
|
+
|
|
720
|
+
training:
|
|
721
|
+
batch_size: 32
|
|
722
|
+
learning_rate: 1.0e-4  # keep the decimal point: PyYAML (YAML 1.1) loads a bare 1e-4 as a string
|
|
723
|
+
weight_decay: 0.01
|
|
724
|
+
epochs: 10
|
|
725
|
+
mixed_precision: true
|
|
726
|
+
gradient_accumulation_steps: 4
|
|
727
|
+
|
|
728
|
+
distributed:
|
|
729
|
+
enabled: true
|
|
730
|
+
backend: nccl
|
|
731
|
+
|
|
732
|
+
checkpointing:
|
|
733
|
+
save_every_n_steps: 1000
|
|
734
|
+
keep_n_checkpoints: 3
|
|
735
|
+
|
|
736
|
+
logging:
|
|
737
|
+
log_every_n_steps: 100
|
|
738
|
+
eval_every_n_steps: 500
|
|
739
|
+
```
|
|
740
|
+
|
|
741
|
+
### Reproducibility Checklist
|
|
742
|
+
|
|
743
|
+
```python
|
|
744
|
+
def ensure_reproducibility(seed: int) -> None:
    """Seed every random number generator and request deterministic kernels.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA
    devices), disables cuDNN autotuning, asks PyTorch to prefer
    deterministic algorithms, and exports ``PYTHONHASHSEED``.

    Args:
        seed: Value used for every generator.

    Note:
        ``PYTHONHASHSEED`` is read when an interpreter starts, so setting it
        here only affects child processes (for example spawned workers).
        To fix hash randomisation for the current process, set it in the
        environment before launching Python.
    """
    import os
    import random

    import numpy as np

    # Python
    random.seed(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # The cuDNN flags only cover cuDNN. cuBLAS needs a fixed workspace to be
    # deterministic; respect a value the user has already chosen.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    # warn_only=True: ops without a deterministic implementation emit a
    # warning instead of raising, so existing training code keeps running.
    torch.use_deterministic_algorithms(True, warn_only=True)

    # Environment (affects child processes only, see Note above)
    os.environ["PYTHONHASHSEED"] = str(seed)

    logger.info(f"Set all random seeds to {seed}")
|
|
768
|
+
```
|
|
769
|
+
|
|
770
|
+
---
|
|
771
|
+
|
|
772
|
+
## Related References
|
|
773
|
+
|
|
774
|
+
- `feature-engineering.md` - Feature preparation for training
|
|
775
|
+
- `experiment-tracking.md` - Logging training metrics
|
|
776
|
+
- `pipeline-orchestration.md` - Orchestrating training pipelines
|
|
777
|
+
- `model-validation.md` - Validating trained models
|
|
778
|
+
|
|
779
|
+
## Cross-Reference Skills
|
|
780
|
+
|
|
781
|
+
- **DevOps Engineer** - CI/CD for training pipelines
|
|
782
|
+
- **Kubernetes Specialist** - K8s-based training infrastructure
|