@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
package/src/app.js ADDED
@@ -0,0 +1,1007 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ import fs from 'fs'
5
+ import path from 'path'
6
+ import { fileURLToPath } from 'url'
7
+ import { spawn } from 'child_process'
8
+
9
+ import { copyTpl } from './copy-tpl.js'
10
+ import { runPrompts } from './prompt-adapter.js'
11
+ import ConfigManager from './lib/config-manager.js'
12
+ import PromptRunner from './lib/prompt-runner.js'
13
+ import TemplateManager from './lib/template-manager.js'
14
+ import DeploymentConfigResolver from './lib/deployment-config-resolver.js'
15
+ import CommentGenerator from './lib/comment-generator.js'
16
+ import ConfigurationManager from './lib/configuration-manager.js'
17
+ import RegistryLoader from './lib/registry-loader.js'
18
+ import { resolvePrefixedEnvVars } from './lib/engine-prefix-resolver.js'
19
+ import ejs from 'ejs'
20
+
// ES modules expose no __filename/__dirname globals, so both are derived
// from this module's own URL.
const __filename = fileURLToPath(import.meta.url)
const __dirname = path.dirname(__filename)

// Package root: the parent of the directory that contains this module.
const GENERATOR_ROOT = path.join(__dirname, '..')
// Directory holding the project templates.
const TEMPLATE_DIR = path.resolve(GENERATOR_ROOT, 'templates')
// Directory holding the library modules under src/lib.
const LIB_DIR = path.resolve(GENERATOR_ROOT, 'src', 'lib')
26
+
/**
 * Main application entry point.
 * Orchestrates the ML Container Creator generation workflow,
 * replicating the original generator lifecycle phases:
 * initializing → prompting → writing → end
 *
 * Failures while loading or validating the configuration are reported on
 * stdout and end the run early without throwing. Errors raised later
 * (for example by writeProject) propagate to the caller.
 *
 * @param {string|undefined} projectName - Name for the generated project (from positional argument)
 * @param {object} options - Parsed CLI options from commander
 * @returns {Promise<void>}
 */
export async function run(projectName, options) {
    // --- Phase: Initializing ---
    // Convert commander's camelCase options to kebab-case for ConfigManager compatibility
    // (ConfigManager expects kebab-case format for option keys)
    const kebabOptions = _toKebabCaseOptions(options)

    // Build a lightweight adapter that satisfies ConfigManager's generator interface.
    // NOTE(review): the adapter is populated below but is not handed to any
    // consumer inside this function — confirm whether it is still needed.
    const generatorAdapter = _createGeneratorAdapter(projectName, kebabOptions)
    const args = projectName ? [projectName] : []

    const configManager = new ConfigManager({ options: kebabOptions, args })

    let baseConfig
    try {
        baseConfig = await configManager.loadConfiguration()
    } catch (error) {
        console.log(`⚠️ ${error.message}`)
        return
    }

    const errors = configManager.validateConfiguration()
    if (errors.length > 0) {
        // Only the first validation error is reported.
        console.log(`⚠️ ${errors[0]}`)
        return
    }

    // Initialize registry system (falls back to null / {} when it cannot be loaded)
    const { registryConfigManager, tritonBackends } = await _bootstrapRegistrySystem(kebabOptions)

    // Attach registry info to the adapter
    generatorAdapter.registryConfigManager = registryConfigManager
    generatorAdapter.configManager = configManager
    generatorAdapter.baseConfig = baseConfig

    // --- Phase: Prompting ---
    let answers
    if (configManager.shouldSkipPrompts()) {
        console.log('\n🚀 Skipping prompts - using configuration from other sources')
        answers = configManager.getFinalConfiguration()
        _resolveNonInteractiveModelSource(answers)
    } else {
        const promptRunner = new PromptRunner({
            configManager,
            options: kebabOptions,
            registryConfigManager,
            baseConfig
        })
        const promptAnswers = await promptRunner.run()
        answers = configManager.getFinalConfiguration(promptAnswers)
    }

    // Ensure template variables have defaults and enrich with registry data
    await _ensureTemplateVariables(answers, registryConfigManager)

    // --- Phase: Writing ---
    const destDir = path.resolve(answers.destinationDir)
    fs.mkdirSync(destDir, { recursive: true })

    await writeProject(TEMPLATE_DIR, destDir, answers, registryConfigManager, tritonBackends, configManager)

    // --- Phase: End ---
    await postGenerate(destDir, answers, tritonBackends)

    console.log('\n✅ Project generated successfully!')
    console.log(` 📁 ${destDir}`)
}

/**
 * Loads the registry system and the Triton backends catalog, printing a short
 * status summary. Any failure is reported on stdout and replaced by defaults,
 * so this function never rejects.
 *
 * @param {object} kebabOptions - CLI options with kebab-case keys
 * @returns {Promise<{registryConfigManager: object|null, tritonBackends: object}>}
 *   The loaded manager and catalog, or `null` / `{}` when initialization failed
 */
async function _bootstrapRegistrySystem(kebabOptions) {
    try {
        // Env var validation is on unless explicitly disabled; the other flags are opt-in.
        const validateEnvVars = kebabOptions['validate-env-vars'] !== false
        const validateWithDocker = kebabOptions['validate-with-docker'] === true
        const offline = kebabOptions['offline'] === true

        let effectiveValidateWithDocker = validateWithDocker
        if (validateWithDocker && !validateEnvVars) {
            console.log('\n⚠️ Warning: --validate-with-docker requires --validate-env-vars to be enabled')
            console.log(' Docker validation will be disabled')
            effectiveValidateWithDocker = false
        }

        const registryConfigManager = new ConfigurationManager({
            validateEnvVars,
            validateWithDocker: effectiveValidateWithDocker,
            offline,
            hfTimeout: 5000
        })

        await registryConfigManager.loadRegistries()

        const registryLoader = new RegistryLoader()
        const tritonBackends = await registryLoader.loadTritonBackends()

        console.log('\n📚 Registry System Initialized')
        console.log(' • Framework Registry: Loaded')
        console.log(' • Model Registry: Loaded')
        console.log(' • Instance Accelerator Mapping: Loaded')

        if (validateEnvVars) {
            console.log(' • Environment Variable Validation: Enabled')
            if (effectiveValidateWithDocker) {
                console.log(' • Docker Introspection Validation: Enabled (experimental)')
            }
        } else {
            console.log(' • Environment Variable Validation: Disabled')
        }

        if (offline) {
            console.log(' • HuggingFace API: Offline mode')
        }

        return { registryConfigManager, tritonBackends }
    } catch (error) {
        console.log('\n⚠️ Registry system initialization failed, using defaults')
        console.log(` Error: ${error.message}`)
        return { registryConfigManager: null, tritonBackends: {} }
    }
}

/**
 * Applies the model-source rules used when prompts are skipped: infers
 * `modelSource` from the model name prefix, downgrades the unsupported
 * JumpStart Private Hub source to HuggingFace, and prints guidance for
 * registry models. Mutates `answers` in place.
 *
 * @param {object} answers - Final configuration produced without prompting
 */
function _resolveNonInteractiveModelSource(answers) {
    const sourceByPrefix = [
        ['s3://', 's3'],
        ['jumpstart://', 'jumpstart'],
        ['jumpstart-hub://', 'jumpstart-hub'],
        ['registry://', 'registry']
    ]

    // Infer modelSource from model name prefix if not set.
    // The type check keeps a non-string model name from throwing on startsWith.
    const modelName = answers.modelName
    if (!answers.modelSource && typeof modelName === 'string') {
        const match = sourceByPrefix.find(([prefix]) => modelName.startsWith(prefix))
        if (match) {
            answers.modelSource = match[1]
            // For S3 models the model name doubles as the artifact location.
            if (answers.modelSource === 's3' && !answers.artifactUri) {
                answers.artifactUri = modelName
            }
        }
    }

    // Warn about unsupported model sources
    if (answers.modelSource === 'jumpstart-hub') {
        console.log('\n ⚠️ JumpStart Private Hub models are not yet fully supported.')
        console.log(' The generated project will not be able to download model artifacts at runtime.')
        console.log(' This feature is tracked for a future release.')
        console.log(' Falling back to HuggingFace source.\n')
        answers.modelSource = 'huggingface'
        delete answers.artifactUri
    }

    // Note about registry model requirements
    if (answers.modelSource === 'registry') {
        console.log('\n ℹ️ Registry model: the container will resolve the artifact URI at startup')
        console.log(' via DescribeModelPackage. Ensure the model package has a valid')
        console.log(' InferenceSpecification with ModelDataUrl or S3DataSource.')
        console.log(' If your model package lacks an InferenceSpecification, use the S3 path')
        console.log(' directly instead: --model-name="s3://bucket/path/model.tar.gz"\n')
    }
}
185
+
/**
 * Writes the project files from templates to the destination directory.
 * Replicates the writing() phase of the original generator.
 *
 * The whole template tree is copied first (minus ignored patterns); files that
 * do not belong to the resolved architecture are then removed and the
 * architecture-specific files are generated.
 *
 * @param {string} templateDir - Path to the template directory
 * @param {string} destDir - Path to the destination directory
 * @param {object} answers - Merged configuration answers
 * @param {object|null} registryConfigManager - Registry configuration manager (or null)
 * @param {object} tritonBackends - Triton backends catalog
 * @param {object|null} configManager - ConfigManager used to validate required parameters (or null to skip)
 * @returns {Promise<void>}
 * @throws {Error} When configManager reports missing required parameters
 */
export async function writeProject(templateDir, destDir, answers, registryConfigManager = null, tritonBackends = {}, configManager = null) {
    // Validate required parameters via ConfigManager
    if (configManager) {
        const requiredParamErrors = configManager.validateRequiredParameters(answers)
        if (requiredParamErrors.length > 0) {
            console.log('\n❌ Required Parameter Validation Failed:')
            requiredParamErrors.forEach(error => {
                console.log(` • ${error}`)
            })
            console.log('\nPlease provide the missing required parameters and try again.')
            throw new Error('Required parameters are missing. Cannot proceed with file generation.')
        }
    }

    // Validate environment variables if registry system is available
    if (registryConfigManager && (answers.frameworkVersion || answers.architecture === 'triton')) {
        await _validateEnvironmentVariables(answers, registryConfigManager)
    }

    // Validate template configuration
    const templateManager = new TemplateManager(answers)
    templateManager.validate()

    // Generate comments for templates
    const commentGenerator = new CommentGenerator()
    const comments = commentGenerator.generateDockerfileComments(answers)

    // Prepare ordered environment variables, then append model env vars and
    // engine-prefixed server env vars after them (in that order).
    const orderedEnvVars = _getOrderedEnvVars(answers.envVars || {})
    const modelEnvVars = answers.modelEnvVars || {}
    const serverEnvVars = answers.serverEnvVars || {}
    const engine = answers.modelServer || answers.backend || ''

    for (const [key, value] of Object.entries(modelEnvVars)) {
        orderedEnvVars.push({ key, value })
    }

    const prefixedServerEnvVars = resolvePrefixedEnvVars(engine, serverEnvVars)
    for (const [key, value] of Object.entries(prefixedServerEnvVars)) {
        orderedEnvVars.push({ key, value })
    }

    // Prepare template variables
    const templateVars = {
        ...answers,
        comments,
        orderedEnvVars,
        serverEnvVars: prefixedServerEnvVars
    }

    // Build ignore patterns
    const ignorePatterns = []

    if (answers.deploymentTarget !== 'hyperpod-eks') {
        ignorePatterns.push('**/hyperpod/**')
    }

    // Resolve architecture: explicit value first, then the deployment config,
    // then a framework-based fallback.
    const resolver = new DeploymentConfigResolver()
    let architecture = answers.architecture

    if (!architecture) {
        const fallbackArchitecture = answers.framework === 'transformers' ? 'transformers' : 'http'
        if (answers.deploymentConfig) {
            try {
                architecture = resolver.decompose(answers.deploymentConfig).architecture
            } catch {
                // A deployment config that cannot be decomposed is not fatal here.
                architecture = fallbackArchitecture
            }
        } else {
            architecture = fallbackArchitecture
        }
    }

    // Exclude sample_model when not needed
    if (!answers.includeSampleModel || architecture === 'transformers' || architecture === 'diffusors') {
        ignorePatterns.push('**/sample_model/**')
    }

    // Always exclude triton and diffusors source directories
    ignorePatterns.push('**/triton/**')
    ignorePatterns.push('**/diffusors/**')

    // For triton and diffusors, exclude the default Dockerfile
    if (architecture === 'triton' || architecture === 'diffusors') {
        ignorePatterns.push('**/Dockerfile')
    }

    // Copy all templates with EJS rendering
    copyTpl(templateDir, destDir, templateVars, ignorePatterns)

    // Shared groups of generated files that are removed per architecture.
    const serveScriptFiles = [
        'code/chat_template.jinja',
        'code/serve',
        'code/serving.properties',
        'code/start_server.sh'
    ]
    const pythonServerFiles = [
        'code/model_handler.py',
        'code/serve.py',
        'code/start_server.py',
        'nginx-predictors.conf'
    ]
    const flaskFiles = [
        'code/flask/wsgi.py',
        'code/flask/gunicorn_config.py'
    ]
    const removeFromDest = (relativePaths) => {
        for (const relativePath of relativePaths) {
            _unlinkIfExists(path.join(destDir, relativePath))
        }
    }

    // Architecture-specific file routing (delete files that don't belong)
    switch (architecture) {
    case 'http':
        removeFromDest(serveScriptFiles)
        if (answers.modelServer !== 'flask' && answers.backend !== 'flask') {
            removeFromDest(flaskFiles)
        }
        break

    case 'transformers':
        removeFromDest(pythonServerFiles)
        removeFromDest(flaskFiles)
        break

    case 'triton':
        removeFromDest(pythonServerFiles)
        removeFromDest(flaskFiles)
        removeFromDest(serveScriptFiles)

        // Generate Triton-specific files
        _generateTritonFiles(templateDir, destDir, templateVars, answers, tritonBackends)
        break

    case 'diffusors':
        removeFromDest(pythonServerFiles)
        removeFromDest(flaskFiles)
        removeFromDest(['code/chat_template.jinja', 'code/serving.properties'])

        // Copy diffusors-specific templates (these replace code/serve and code/start_server.sh)
        _renderTemplate(path.join(templateDir, 'diffusors/Dockerfile'), path.join(destDir, 'Dockerfile'), templateVars)
        _renderTemplate(path.join(templateDir, 'diffusors/serve'), path.join(destDir, 'code/serve'), templateVars)
        _renderTemplate(path.join(templateDir, 'diffusors/start_server.sh'), path.join(destDir, 'code/start_server.sh'), templateVars)
        _copyFile(path.join(templateDir, 'diffusors/patch_image_api.py'), path.join(destDir, 'code/patch_image_api.py'))
        break

    default:
        // Fallback to HTTP behavior
        removeFromDest(serveScriptFiles)
    }

    // nginx-tensorrt.conf: only needed for TensorRT-LLM
    if (answers.modelServer !== 'tensorrt-llm' && answers.backend !== 'tensorrt-llm') {
        _unlinkIfExists(path.join(destDir, 'nginx-tensorrt.conf'))
    }

    // nginx-diffusors.conf: only needed for diffusors architecture.
    // Uses the resolved architecture so the file survives when the architecture
    // was derived from deploymentConfig rather than set explicitly.
    if (architecture !== 'diffusors') {
        _unlinkIfExists(path.join(destDir, 'nginx-diffusors.conf'))
    }

    // Copy PROJECT_README.md as README.md (overwriting the template README)
    _renderTemplate(path.join(templateDir, 'PROJECT_README.md'), path.join(destDir, 'README.md'), templateVars)

    // Copy do/lib/ Node.js modules (plain copy, no EJS)
    const doLibDir = path.join(destDir, 'do', 'lib')
    fs.mkdirSync(doLibDir, { recursive: true })
    for (const moduleFile of ['manifest-cli.js', 'asset-manager.js', 'bootstrap-config.js']) {
        _copyFile(path.join(LIB_DIR, moduleFile), path.join(doLibDir, moduleFile))
    }
}
371
+
/**
 * Post-generation tasks: set permissions and run sample model training.
 * Replicates the end() phase of the original generator.
 *
 * Sample training runs only when `answers.includeSampleModel` is set and the
 * architecture ships a sample model: it is skipped for `transformers` and
 * `diffusors` (writeProject never copies `sample_model` for those), and for
 * `triton` unless the selected backend declares `supportsSampleModel`.
 *
 * @param {string} destDir - Path to the generated project directory
 * @param {object} answers - Merged configuration answers
 * @param {object} tritonBackends - Triton backends catalog
 * @returns {Promise<void>}
 */
export async function postGenerate(destDir, answers, tritonBackends = {}) {
    // Set executable permissions on shell scripts
    _setExecutablePermissions(destDir)

    if (!answers.includeSampleModel) {
        return
    }

    const architecture = answers.architecture
    const tritonWithoutSample =
        architecture === 'triton' && !tritonBackends[answers.backend]?.supportsSampleModel
    const skipSampleTraining =
        architecture === 'transformers' ||
        architecture === 'diffusors' ||
        tritonWithoutSample

    if (!skipSampleTraining) {
        await _runSampleModelTraining(destDir)
    }
}
393
+
394
+ // --- Private helpers ---
395
+
/**
 * Converts commander's camelCase options to kebab-case keys.
 * ConfigManager expects kebab-case keys (e.g., 'skip-prompts', 'deployment-config');
 * commander converts --skip-prompts to skipPrompts.
 *
 * Every uppercase ASCII letter is replaced by a hyphen plus its lowercase form;
 * values are passed through untouched. A missing options object (`undefined`
 * or `null`) yields an empty object instead of throwing.
 *
 * @param {object|null|undefined} options - Commander options object (camelCase keys)
 * @returns {object} Options with kebab-case keys
 */
function _toKebabCaseOptions(options) {
    return Object.fromEntries(
        Object.entries(options ?? {}).map(([key, value]) => [
            key.replace(/([A-Z])/g, '-$1').toLowerCase(),
            value
        ])
    )
}
413
+
/**
 * Builds a minimal stand-in for the generator object expected by
 * ConfigManager and PromptRunner.
 *
 * The adapter tracks a destination root (initially the current working
 * directory), exposes the CLI options and positional args, carries
 * placeholders for registry/config state, and forwards `prompt` to runPrompts.
 *
 * @param {string|undefined} projectName - Positional project name argument
 * @param {object} options - Commander options object
 * @returns {object} Generator-like adapter
 */
function _createGeneratorAdapter(projectName, options) {
    // Destination root, captured by the two path methods below.
    let currentRoot = process.cwd()

    return {
        options,
        args: projectName ? [projectName] : [],

        // With no segments: the root itself; otherwise a path joined under it.
        destinationPath(...segments) {
            return segments.length === 0
                ? currentRoot
                : path.join(currentRoot, ...segments)
        },

        // Getter/setter: updates the root when an argument is given and
        // always returns the current root.
        destinationRoot(newRoot) {
            if (newRoot !== undefined) {
                currentRoot = path.resolve(newRoot)
            }
            return currentRoot
        },

        registryConfigManager: null,
        configManager: null,
        baseConfig: {},

        async prompt(prompts) {
            return runPrompts(prompts)
        }
    }
}
449
+
/**
 * Fills in every template variable the EJS templates may reference so that
 * rendering never trips over an undefined name, then enriches the answers
 * with registry-derived data (merged env vars, Triton base image,
 * HuggingFace / Model Registry chat template, framework metadata).
 *
 * Mutates `answers` in place.
 *
 * @param {object} answers - Answers object to fill defaults into
 * @param {object|null} registryConfigManager - Registry configuration manager (or null)
 */
async function _ensureTemplateVariables(answers, registryConfigManager = null) {
    const fallbackValues = {
        chatTemplate: null,
        chatTemplateSource: null,
        hfToken: null,
        ngcApiKey: null,
        envVars: {},
        inferenceAmiVersion: null,
        accelerator: null,
        frameworkVersion: null,
        validationLevel: 'unknown',
        configSources: [],
        recommendedInstanceTypes: [],
        roleArn: null,
        deploymentConfig: '',
        architecture: null,
        backend: null,
        engine: null,
        codebuildComputeType: null,
        codebuildProjectName: null,
        modelName: null,
        modelFormat: null,
        includeSampleModel: false,
        includeTesting: true,
        testTypes: [],
        buildTimestamp: new Date().toISOString(),
        buildTarget: 'codebuild',
        deploymentTarget: 'managed-inference',
        hyperPodCluster: null,
        hyperPodNamespace: 'default',
        hyperPodReplicas: 1,
        fsxVolumeHandle: null,
        baseImage: null,
        modelSource: 'huggingface',
        artifactUri: '',
        modelLoadStrategy: 'runtime'
    }

    // Only strictly-undefined entries are filled; null and other falsy values are kept.
    for (const [name, fallback] of Object.entries(fallbackValues)) {
        if (answers[name] === undefined) {
            answers[name] = fallback
        }
    }

    // Backward compatibility: derive framework / modelServer from architecture / backend.
    if (!answers.framework && answers.architecture) {
        answers.framework = answers.architecture
    }
    if (!answers.modelServer && answers.backend) {
        answers.modelServer = answers.backend
    }

    // Testing is always included; pick the default test types when none were given.
    answers.includeTesting = true
    const hasTestTypes = answers.testTypes && answers.testTypes.length !== 0
    if (!hasTestTypes) {
        const isTransformers =
            answers.architecture === 'transformers' || answers.framework === 'transformers'
        answers.testTypes = isTransformers
            ? ['hosted-model-endpoint']
            : ['local-model-cli', 'local-model-server', 'hosted-model-endpoint']
    }

    // Merge catalog env vars into answers.envVars with correct precedence
    await _mergeEnvVarsWithPrecedence(answers, registryConfigManager)

    // Triton without an explicit base image: look it up in the framework
    // registry under the deployment-config key, newest version first.
    if (answers.architecture === 'triton' && !answers.baseImage) {
        const registryKey = answers.deploymentConfig
        const tritonEntry = registryKey && registryConfigManager?.frameworkRegistry
            ? registryConfigManager.frameworkRegistry[registryKey]
            : undefined

        if (tritonEntry) {
            // Descending, numeric-aware ordering puts the highest version at index 0.
            const [newestVersion] = Object.keys(tritonEntry).sort((a, b) =>
                b.localeCompare(a, undefined, { numeric: true })
            )
            if (newestVersion !== undefined) {
                const newestConfig = tritonEntry[newestVersion]
                if (newestConfig.baseImage) {
                    answers.baseImage = newestConfig.baseImage
                }
                if (newestConfig.inferenceAmiVersion && !answers.inferenceAmiVersion) {
                    answers.inferenceAmiVersion = newestConfig.inferenceAmiVersion
                }
                if (newestConfig.accelerator) {
                    answers.accelerator = newestConfig.accelerator
                }
            }
        }

        // Final fallback: hardcoded default Triton base image
        if (!answers.baseImage) {
            answers.baseImage = 'nvcr.io/nvidia/tritonserver:24.08-py3'
        }
    }

    // The remaining enrichment applies to transformer models only.
    if (answers.framework !== 'transformers' || !answers.modelName || !registryConfigManager) {
        return
    }

    try {
        // HuggingFace data supplies a chat template when none is set yet.
        const hfData = await registryConfigManager._fetchHuggingFaceData(answers.modelName)
        if (hfData?.chatTemplate && !answers.chatTemplate) {
            answers.chatTemplate = hfData.chatTemplate
            answers.chatTemplateSource = 'HuggingFace_Hub_API'
        }

        // A Model Registry chat template overrides whatever is set at this point.
        if (registryConfigManager.modelRegistry) {
            const registryModel = _findModelConfig(answers.modelName, registryConfigManager)
            if (registryModel?.chatTemplate) {
                answers.chatTemplate = registryModel.chatTemplate
                answers.chatTemplateSource = 'Model_Registry'
            }
        }

        // Framework-level metadata (non-envVar fields) for the chosen version.
        if (answers.frameworkVersion && registryConfigManager.frameworkRegistry) {
            const versionEntry =
                registryConfigManager.frameworkRegistry[answers.framework]?.[answers.frameworkVersion]
            if (versionEntry) {
                if (versionEntry.inferenceAmiVersion && !answers.inferenceAmiVersion) {
                    answers.inferenceAmiVersion = versionEntry.inferenceAmiVersion
                }
                if (versionEntry.accelerator) {
                    answers.accelerator = versionEntry.accelerator
                }
            }
        }
    } catch {
        // Silently continue - defaults are already set
    }
}
593
+
/**
 * Orders environment variables by priority category for template rendering.
 *
 * The priority of a variable name is resolved in two steps:
 *   1. an exact match of the name against the priority table;
 *   2. otherwise, the first table entry (in table order) whose key occurs
 *      as a substring of the name, e.g. `PYTHONPATH` matches `PATH`.
 * Names that match nothing get the `default` priority. Variables with the
 * same priority are ordered by name using `localeCompare`.
 *
 * @param {object} envVars - Environment variables map (null/undefined is treated as empty)
 * @returns {Array<{key: string, value: string}>} Ordered array
 */
function _getOrderedEnvVars(envVars) {
    const priorities = {
        'LD_LIBRARY_PATH': 1,
        'PATH': 1,
        'CUDA_HOME': 1,
        'CUDA_PATH': 1,
        'CUDA_VISIBLE_DEVICES': 2,
        'NVIDIA_VISIBLE_DEVICES': 2,
        'NVIDIA_DRIVER_CAPABILITIES': 2,
        'VLLM': 3,
        'TENSORRT': 3,
        'SGLANG': 3,
        'TRANSFORMERS': 3,
        'MAX': 4,
        'BATCH': 4,
        'WORKER': 4,
        'THREAD': 4,
        'default': 5
    }

    // Substring patterns, in table order; the `default` sentinel is not a pattern.
    const patterns = Object.entries(priorities).filter(([pattern]) => pattern !== 'default')

    function getPriority(key) {
        // Own-property check: a bare `priorities[key]` lookup would also resolve
        // inherited members such as `toString` or `constructor`, returning a
        // function instead of a number and poisoning the sort with NaN.
        if (Object.prototype.hasOwnProperty.call(priorities, key)) {
            return priorities[key]
        }
        const match = patterns.find(([pattern]) => key.includes(pattern))
        return match ? match[1] : priorities.default
    }

    // Resolve each priority once up front instead of on every comparison.
    return Object.entries(envVars || {})
        .map(([key, value]) => ({ key, value, priority: getPriority(key) }))
        .sort((a, b) => (a.priority - b.priority) || a.key.localeCompare(b.key))
        .map(({ key, value }) => ({ key, value }))
}
641
+
/**
 * Validates the registry-declared environment variables of the selected
 * framework configuration and prints errors and warnings to the console.
 *
 * The framework configuration is looked up as follows:
 *   - for Triton (`answers.architecture === 'triton'`), under the
 *     `answers.deploymentConfig` key (e.g. 'triton-fil'), using the latest
 *     version (numeric-aware descending sort, the same rule
 *     `_mergeEnvVarsWithPrecedence` applies);
 *   - otherwise (or when the Triton lookup finds nothing), under
 *     `frameworkRegistry[answers.framework][answers.frameworkVersion]`.
 * When no configuration or no `envVars` is found, nothing is validated.
 *
 * @param {object} answers - Configuration answers
 * @param {object} registryConfigManager - Registry configuration manager
 * @returns {Promise<void>}
 * @throws {Error} When the validation result contains at least one error
 */
async function _validateEnvironmentVariables(answers, registryConfigManager) {
    const frameworkRegistry = registryConfigManager?.frameworkRegistry
    if (!frameworkRegistry) {
        return // No registry available, nothing to validate against
    }

    let frameworkConfig
    if (answers.architecture === 'triton' && answers.deploymentConfig) {
        const tritonEntry = frameworkRegistry[answers.deploymentConfig]
        if (tritonEntry) {
            // Object key order is insertion order, not version order: sort so
            // that the newest version is picked deterministically.
            const versions = Object.keys(tritonEntry).sort((a, b) =>
                b.localeCompare(a, undefined, { numeric: true })
            )
            if (versions.length > 0) {
                frameworkConfig = tritonEntry[versions[0]]
            }
        }
    }
    if (!frameworkConfig) {
        frameworkConfig = frameworkRegistry[answers.framework]?.[answers.frameworkVersion]
    }

    if (!frameworkConfig || !frameworkConfig.envVars) {
        return // No env vars to validate
    }

    console.log('\n🔍 Validating environment variables...')

    const validationResult = registryConfigManager.validateEnvironmentVariables(
        frameworkConfig.envVars,
        frameworkConfig
    ) || {}

    // Normalise the optional result lists once so the checks below stay simple.
    const errors = validationResult.errors || []
    const warnings = validationResult.warnings || []
    const strategiesUsed = validationResult.strategiesUsed || []

    if (errors.length > 0) {
        console.log('\n❌ Environment Variable Validation Errors:')
        errors.forEach(error => {
            console.log(` • ${error.key}: ${error.message}`)
        })
    }

    if (warnings.length > 0) {
        console.log('\n⚠️ Environment Variable Validation Warnings:')
        warnings.forEach(warning => {
            console.log(` • ${warning.key ? `${warning.key}: ` : ''}${warning.message}`)
        })
    }

    if (strategiesUsed.length > 0) {
        console.log(`\n✅ Validation methods used: ${strategiesUsed.join(', ')}`)
    }

    if (errors.length === 0 && warnings.length === 0) {
        console.log(' ✅ All environment variables validated successfully')
    }

    // Validation errors are fatal regardless of how the generator was invoked.
    if (errors.length > 0) {
        throw new Error('Environment variable validation failed. Please fix the errors and try again.')
    }
}
708
+
/**
 * Builds `answers.envVars` by layering every catalog source, later layers
 * overriding earlier ones:
 *   1. catalog defaults   - `envVars` of the resolved framework config
 *   2. framework profile  - `profiles[answers.frameworkProfile].envVars` of that config
 *   3. model entry        - `envVars` of the resolved model config
 *   4. model profile      - `profiles[answers.modelProfile].envVars` of that config
 *   5. CLI overrides      - whatever `answers.envVars` held on entry
 *
 * The framework config is looked up under `answers.framework` (falling back
 * to `answers.deploymentConfig`); when the requested version is absent, the
 * highest version by numeric-aware comparison is used. Does nothing when no
 * registry manager is supplied.
 *
 * @param {object} answers - Configuration answers (mutated: `envVars` is replaced)
 * @param {object|null} registryConfigManager - Registry configuration manager
 */
async function _mergeEnvVarsWithPrecedence(answers, registryConfigManager) {
    if (!registryConfigManager) {
        return
    }

    // Snapshot the user-supplied values first; they are applied last.
    const userOverrides = { ...answers.envVars }

    // Returns the envVars of the named profile on a config, or an empty layer.
    const profileLayer = (config, profileName) => {
        if (profileName && config?.profiles) {
            const selected = config.profiles[profileName]
            if (selected?.envVars) {
                return selected.envVars
            }
        }
        return {}
    }

    // --- Framework config: requested version, else the newest one ---
    const registry = registryConfigManager.frameworkRegistry
    const lookupKey = answers.framework || answers.deploymentConfig
    const knownVersions = (lookupKey && registry) ? registry[lookupKey] : undefined

    let frameworkConfig = null
    if (knownVersions) {
        const requested = answers.frameworkVersion
        if (requested && knownVersions[requested]) {
            frameworkConfig = knownVersions[requested]
        } else {
            const newestFirst = Object.keys(knownVersions).sort((a, b) =>
                b.localeCompare(a, undefined, { numeric: true })
            )
            if (newestFirst.length > 0) {
                frameworkConfig = knownVersions[newestFirst[0]]
            }
        }
    }

    // --- Model config: exact or pattern match ---
    const modelConfig = (answers.modelName && registryConfigManager.modelRegistry)
        ? _findModelConfig(answers.modelName, registryConfigManager)
        : null

    // Lowest precedence first; each layer overrides the ones before it.
    const layers = [
        frameworkConfig?.envVars || {},
        profileLayer(frameworkConfig, answers.frameworkProfile),
        modelConfig?.envVars || {},
        profileLayer(modelConfig, answers.modelProfile),
        userOverrides
    ]

    answers.envVars = layers.reduce((merged, layer) => ({ ...merged, ...layer }), {})
}
790
+
/**
 * Finds model configuration by exact match or glob-pattern match.
 *
 * An exact (own-property) key match wins. Otherwise registry keys that
 * contain `*` are tried in registry order as glob patterns, where `*`
 * matches any run of characters and every other character is literal;
 * the first pattern matching the whole model name is returned.
 *
 * @param {string} modelName - Model ID to look up
 * @param {object} registryConfigManager - Registry configuration manager
 * @returns {object|null} Model configuration or null
 */
function _findModelConfig(modelName, registryConfigManager) {
    const registry = registryConfigManager?.modelRegistry
    if (!registry) return null

    // Exact match first. Restrict to own properties so that names such as
    // "constructor" do not resolve to Object.prototype members.
    if (Object.prototype.hasOwnProperty.call(registry, modelName) && registry[modelName]) {
        return registry[modelName]
    }

    for (const [pattern, config] of Object.entries(registry)) {
        if (!pattern.includes('*')) continue

        // Escape regex metacharacters in the literal parts: model IDs routinely
        // contain `.` (which would match any character) and may contain `+` or
        // parentheses (which would mis-match or throw a SyntaxError).
        const source = pattern
            .split('*')
            .map(literal => literal.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'))
            .join('.*')

        if (new RegExp(`^${source}$`).test(modelName)) {
            return config
        }
    }

    return null
}
817
+
/**
 * Writes the Triton project files: the Dockerfile plus a model repository
 * laid out as `model_repository/<modelName>/{config.pbtxt, 1/.gitkeep}`.
 * For the `python` backend it additionally renders `1/model.py` into the
 * model repository and `triton/requirements.txt` into the destination.
 *
 * @param {string} templateDir - Template source directory
 * @param {string} destDir - Destination directory
 * @param {object} templateVars - Template variables for EJS
 * @param {object} answers - Configuration answers (`modelName` defaults to 'model')
 * @param {object} tritonBackends - Triton backends catalog (not read here)
 */
function _generateTritonFiles(templateDir, destDir, templateVars, answers, tritonBackends) {
    // Renders one template from the `triton` template folder to `target`.
    const renderTriton = (templateName, target) => {
        _renderTemplate(path.join(templateDir, 'triton', templateName), target, templateVars)
    }

    const repoDir = path.join(destDir, `model_repository/${answers.modelName || 'model'}`)
    const versionDir = path.join(repoDir, '1')

    // Dockerfile goes to the project root
    renderTriton('Dockerfile', path.join(destDir, 'Dockerfile'))

    // Model repository skeleton: <repo>/1 holds the model artifacts
    fs.mkdirSync(versionDir, { recursive: true })
    renderTriton('config.pbtxt', path.join(repoDir, 'config.pbtxt'))
    fs.writeFileSync(path.join(versionDir, '.gitkeep'), '# Placeholder for model artifacts\n')

    if (answers.backend !== 'python') {
        return
    }

    // Python backend ships its own model implementation and dependencies
    renderTriton('model.py', path.join(versionDir, 'model.py'))
    renderTriton('requirements.txt', path.join(destDir, 'triton/requirements.txt'))
}
869
+
/**
 * Reads an EJS template, renders it with the given variables and writes the
 * result to `dest`. The destination's parent directories are created first.
 * The source path is passed to EJS as the `filename` option.
 *
 * @param {string} src - Source template file path
 * @param {string} dest - Destination file path
 * @param {object} vars - Template variables
 */
function _renderTemplate(src, dest, vars) {
    const outputDir = path.dirname(dest)
    fs.mkdirSync(outputDir, { recursive: true })

    const template = fs.readFileSync(src, 'utf8')
    const renderOptions = { filename: src }
    fs.writeFileSync(dest, ejs.render(template, vars, renderOptions))
}
883
+
/**
 * Copies a file verbatim (no EJS rendering), creating the destination's
 * parent directories first.
 *
 * @param {string} src - Source file path
 * @param {string} dest - Destination file path
 */
function _copyFile(src, dest) {
    const targetDir = path.dirname(dest)
    fs.mkdirSync(targetDir, { recursive: true })

    fs.copyFileSync(src, dest)
}
894
+
/**
 * Best-effort file removal: deletes `filePath` when it exists and never
 * throws, whether the file is absent or the removal itself fails.
 *
 * @param {string} filePath - Path to the file to remove
 */
function _unlinkIfExists(filePath) {
    try {
        const present = fs.existsSync(filePath)
        if (!present) {
            return
        }
        fs.unlinkSync(filePath)
    } catch (ignored) {
        // Removal is optional; failures are deliberately ignored
    }
}
909
+
/**
 * Adds the 0o755 permission bits to each known `do/*` script that exists in
 * the generated project. Missing scripts are skipped and per-file failures
 * are ignored.
 *
 * @param {string} destDir - Path to the generated project directory
 */
function _setExecutablePermissions(destDir) {
    const scriptNames = [
        'config', 'build', 'push', 'deploy', 'run', 'test',
        'logs', 'clean', 'submit', 'register', 'ci', 'manifest'
    ]

    for (const name of scriptNames) {
        const target = path.join(destDir, `do/${name}`)
        try {
            if (!fs.existsSync(target)) {
                continue
            }
            // OR-in the bits so any permissions already granted are kept
            const currentMode = fs.statSync(target).mode
            fs.chmodSync(target, currentMode | 0o755)
        } catch (error) {
            // Silently continue if chmod fails (e.g., on Windows)
        }
    }
}
944
+
/**
 * Runs the sample model training script (`sample_model/train_abalone.py`)
 * in the generated project, installing `requirements.txt` first when the
 * file is present.
 * Non-fatal: if the script is missing it is skipped, and if installation or
 * training fails the user is only warned and told how to run it manually.
 *
 * @param {string} destDir - Path to the generated project directory
 * @returns {Promise<void>}
 */
async function _runSampleModelTraining(destDir) {
    const trainingScriptName = 'train_abalone.py'
    const sampleModelDir = path.join(destDir, 'sample_model')
    const trainingScript = path.join(sampleModelDir, trainingScriptName)
    const requirementsFile = path.join(destDir, 'requirements.txt')

    console.log('\n🤖 Training sample model...')
    console.log('This will generate the model file needed for Docker build.')

    try {
        if (!fs.existsSync(trainingScript)) {
            console.log('⚠️ Training script not found, skipping model training')
            return
        }

        // Install dependencies through `python -m pip` so they land in the
        // same interpreter that runs the training script below; a bare `pip`
        // on PATH may belong to a different Python installation.
        if (fs.existsSync(requirementsFile)) {
            console.log('📦 Installing dependencies from requirements.txt...')
            await _spawnAsync(
                'python',
                ['-m', 'pip', 'install', '-q', '-r', requirementsFile],
                { cwd: destDir }
            )
        }

        // Run training script from its own directory
        await _spawnAsync('python', [trainingScriptName], { cwd: sampleModelDir })
        console.log('✅ Sample model training completed successfully!')
        console.log(`📁 Model file saved in: ${sampleModelDir}`)
    } catch (error) {
        // Rejections are not guaranteed to be Error instances
        const reason = error && error.message ? error.message : String(error)
        console.log('⚠️ Error during sample model training:', reason)
        console.log(`Please run manually: python sample_model/${trainingScriptName}`)
    }
}
981
+
/**
 * Spawns a child process and returns a promise.
 * Resolves on exit code 0. Rejects when the process cannot be spawned
 * (with the spawn error), exits with a non-zero code, or is terminated by
 * a signal. The child always inherits the parent's stdio; a `stdio` value
 * in `opts` is overridden.
 *
 * @param {string} command - Command to run
 * @param {string[]} args - Command arguments
 * @param {object} opts - spawn options
 * @returns {Promise<void>}
 */
function _spawnAsync(command, args, opts = {}) {
    return new Promise((resolve, reject) => {
        const proc = spawn(command, args, { ...opts, stdio: 'inherit' })

        // Emitted when the process could not be spawned (e.g. ENOENT)
        proc.on('error', reject)

        proc.on('close', (code, signal) => {
            if (code === 0) {
                resolve()
                return
            }
            // A process killed by a signal has a null exit code; report the
            // signal rather than "exited with code null".
            const reason = signal
                ? `was terminated by signal ${signal}`
                : `exited with code ${code}`
            reject(new Error(`${command} ${reason}`))
        })
    })
}