PraisonAI 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (393)
  1. praisonai/__init__.py +54 -0
  2. praisonai/__main__.py +15 -0
  3. praisonai/acp/__init__.py +54 -0
  4. praisonai/acp/config.py +159 -0
  5. praisonai/acp/server.py +587 -0
  6. praisonai/acp/session.py +219 -0
  7. praisonai/adapters/__init__.py +50 -0
  8. praisonai/adapters/readers.py +395 -0
  9. praisonai/adapters/rerankers.py +315 -0
  10. praisonai/adapters/retrievers.py +394 -0
  11. praisonai/adapters/vector_stores.py +409 -0
  12. praisonai/agent_scheduler.py +337 -0
  13. praisonai/agents_generator.py +903 -0
  14. praisonai/api/call.py +292 -0
  15. praisonai/auto.py +1197 -0
  16. praisonai/capabilities/__init__.py +275 -0
  17. praisonai/capabilities/a2a.py +140 -0
  18. praisonai/capabilities/assistants.py +283 -0
  19. praisonai/capabilities/audio.py +320 -0
  20. praisonai/capabilities/batches.py +469 -0
  21. praisonai/capabilities/completions.py +336 -0
  22. praisonai/capabilities/container_files.py +155 -0
  23. praisonai/capabilities/containers.py +93 -0
  24. praisonai/capabilities/embeddings.py +158 -0
  25. praisonai/capabilities/files.py +467 -0
  26. praisonai/capabilities/fine_tuning.py +293 -0
  27. praisonai/capabilities/guardrails.py +182 -0
  28. praisonai/capabilities/images.py +330 -0
  29. praisonai/capabilities/mcp.py +190 -0
  30. praisonai/capabilities/messages.py +270 -0
  31. praisonai/capabilities/moderations.py +154 -0
  32. praisonai/capabilities/ocr.py +217 -0
  33. praisonai/capabilities/passthrough.py +204 -0
  34. praisonai/capabilities/rag.py +207 -0
  35. praisonai/capabilities/realtime.py +160 -0
  36. praisonai/capabilities/rerank.py +165 -0
  37. praisonai/capabilities/responses.py +266 -0
  38. praisonai/capabilities/search.py +109 -0
  39. praisonai/capabilities/skills.py +133 -0
  40. praisonai/capabilities/vector_store_files.py +334 -0
  41. praisonai/capabilities/vector_stores.py +304 -0
  42. praisonai/capabilities/videos.py +141 -0
  43. praisonai/chainlit_ui.py +304 -0
  44. praisonai/chat/__init__.py +106 -0
  45. praisonai/chat/app.py +125 -0
  46. praisonai/cli/__init__.py +26 -0
  47. praisonai/cli/app.py +213 -0
  48. praisonai/cli/commands/__init__.py +75 -0
  49. praisonai/cli/commands/acp.py +70 -0
  50. praisonai/cli/commands/completion.py +333 -0
  51. praisonai/cli/commands/config.py +166 -0
  52. praisonai/cli/commands/debug.py +142 -0
  53. praisonai/cli/commands/diag.py +55 -0
  54. praisonai/cli/commands/doctor.py +166 -0
  55. praisonai/cli/commands/environment.py +179 -0
  56. praisonai/cli/commands/lsp.py +112 -0
  57. praisonai/cli/commands/mcp.py +210 -0
  58. praisonai/cli/commands/profile.py +457 -0
  59. praisonai/cli/commands/run.py +228 -0
  60. praisonai/cli/commands/schedule.py +150 -0
  61. praisonai/cli/commands/serve.py +97 -0
  62. praisonai/cli/commands/session.py +212 -0
  63. praisonai/cli/commands/traces.py +145 -0
  64. praisonai/cli/commands/version.py +101 -0
  65. praisonai/cli/configuration/__init__.py +18 -0
  66. praisonai/cli/configuration/loader.py +353 -0
  67. praisonai/cli/configuration/paths.py +114 -0
  68. praisonai/cli/configuration/schema.py +164 -0
  69. praisonai/cli/features/__init__.py +268 -0
  70. praisonai/cli/features/acp.py +236 -0
  71. praisonai/cli/features/action_orchestrator.py +546 -0
  72. praisonai/cli/features/agent_scheduler.py +773 -0
  73. praisonai/cli/features/agent_tools.py +474 -0
  74. praisonai/cli/features/agents.py +375 -0
  75. praisonai/cli/features/at_mentions.py +471 -0
  76. praisonai/cli/features/auto_memory.py +182 -0
  77. praisonai/cli/features/autonomy_mode.py +490 -0
  78. praisonai/cli/features/background.py +356 -0
  79. praisonai/cli/features/base.py +168 -0
  80. praisonai/cli/features/capabilities.py +1326 -0
  81. praisonai/cli/features/checkpoints.py +338 -0
  82. praisonai/cli/features/code_intelligence.py +652 -0
  83. praisonai/cli/features/compaction.py +294 -0
  84. praisonai/cli/features/compare.py +534 -0
  85. praisonai/cli/features/cost_tracker.py +514 -0
  86. praisonai/cli/features/debug.py +810 -0
  87. praisonai/cli/features/deploy.py +517 -0
  88. praisonai/cli/features/diag.py +289 -0
  89. praisonai/cli/features/doctor/__init__.py +63 -0
  90. praisonai/cli/features/doctor/checks/__init__.py +24 -0
  91. praisonai/cli/features/doctor/checks/acp_checks.py +240 -0
  92. praisonai/cli/features/doctor/checks/config_checks.py +366 -0
  93. praisonai/cli/features/doctor/checks/db_checks.py +366 -0
  94. praisonai/cli/features/doctor/checks/env_checks.py +543 -0
  95. praisonai/cli/features/doctor/checks/lsp_checks.py +199 -0
  96. praisonai/cli/features/doctor/checks/mcp_checks.py +349 -0
  97. praisonai/cli/features/doctor/checks/memory_checks.py +268 -0
  98. praisonai/cli/features/doctor/checks/network_checks.py +251 -0
  99. praisonai/cli/features/doctor/checks/obs_checks.py +328 -0
  100. praisonai/cli/features/doctor/checks/performance_checks.py +235 -0
  101. praisonai/cli/features/doctor/checks/permissions_checks.py +259 -0
  102. praisonai/cli/features/doctor/checks/selftest_checks.py +322 -0
  103. praisonai/cli/features/doctor/checks/serve_checks.py +426 -0
  104. praisonai/cli/features/doctor/checks/skills_checks.py +231 -0
  105. praisonai/cli/features/doctor/checks/tools_checks.py +371 -0
  106. praisonai/cli/features/doctor/engine.py +266 -0
  107. praisonai/cli/features/doctor/formatters.py +310 -0
  108. praisonai/cli/features/doctor/handler.py +397 -0
  109. praisonai/cli/features/doctor/models.py +264 -0
  110. praisonai/cli/features/doctor/registry.py +239 -0
  111. praisonai/cli/features/endpoints.py +1019 -0
  112. praisonai/cli/features/eval.py +560 -0
  113. praisonai/cli/features/external_agents.py +231 -0
  114. praisonai/cli/features/fast_context.py +410 -0
  115. praisonai/cli/features/flow_display.py +566 -0
  116. praisonai/cli/features/git_integration.py +651 -0
  117. praisonai/cli/features/guardrail.py +171 -0
  118. praisonai/cli/features/handoff.py +185 -0
  119. praisonai/cli/features/hooks.py +583 -0
  120. praisonai/cli/features/image.py +384 -0
  121. praisonai/cli/features/interactive_runtime.py +585 -0
  122. praisonai/cli/features/interactive_tools.py +380 -0
  123. praisonai/cli/features/interactive_tui.py +603 -0
  124. praisonai/cli/features/jobs.py +632 -0
  125. praisonai/cli/features/knowledge.py +531 -0
  126. praisonai/cli/features/lite.py +244 -0
  127. praisonai/cli/features/lsp_cli.py +225 -0
  128. praisonai/cli/features/mcp.py +169 -0
  129. praisonai/cli/features/message_queue.py +587 -0
  130. praisonai/cli/features/metrics.py +211 -0
  131. praisonai/cli/features/n8n.py +673 -0
  132. praisonai/cli/features/observability.py +293 -0
  133. praisonai/cli/features/ollama.py +361 -0
  134. praisonai/cli/features/output_style.py +273 -0
  135. praisonai/cli/features/package.py +631 -0
  136. praisonai/cli/features/performance.py +308 -0
  137. praisonai/cli/features/persistence.py +636 -0
  138. praisonai/cli/features/profile.py +226 -0
  139. praisonai/cli/features/profiler/__init__.py +81 -0
  140. praisonai/cli/features/profiler/core.py +558 -0
  141. praisonai/cli/features/profiler/optimizations.py +652 -0
  142. praisonai/cli/features/profiler/suite.py +386 -0
  143. praisonai/cli/features/profiling.py +350 -0
  144. praisonai/cli/features/queue/__init__.py +73 -0
  145. praisonai/cli/features/queue/manager.py +395 -0
  146. praisonai/cli/features/queue/models.py +286 -0
  147. praisonai/cli/features/queue/persistence.py +564 -0
  148. praisonai/cli/features/queue/scheduler.py +484 -0
  149. praisonai/cli/features/queue/worker.py +372 -0
  150. praisonai/cli/features/recipe.py +1723 -0
  151. praisonai/cli/features/recipes.py +449 -0
  152. praisonai/cli/features/registry.py +229 -0
  153. praisonai/cli/features/repo_map.py +860 -0
  154. praisonai/cli/features/router.py +466 -0
  155. praisonai/cli/features/sandbox_executor.py +515 -0
  156. praisonai/cli/features/serve.py +829 -0
  157. praisonai/cli/features/session.py +222 -0
  158. praisonai/cli/features/skills.py +856 -0
  159. praisonai/cli/features/slash_commands.py +650 -0
  160. praisonai/cli/features/telemetry.py +179 -0
  161. praisonai/cli/features/templates.py +1384 -0
  162. praisonai/cli/features/thinking.py +305 -0
  163. praisonai/cli/features/todo.py +334 -0
  164. praisonai/cli/features/tools.py +680 -0
  165. praisonai/cli/features/tui/__init__.py +83 -0
  166. praisonai/cli/features/tui/app.py +580 -0
  167. praisonai/cli/features/tui/cli.py +566 -0
  168. praisonai/cli/features/tui/debug.py +511 -0
  169. praisonai/cli/features/tui/events.py +99 -0
  170. praisonai/cli/features/tui/mock_provider.py +328 -0
  171. praisonai/cli/features/tui/orchestrator.py +652 -0
  172. praisonai/cli/features/tui/screens/__init__.py +50 -0
  173. praisonai/cli/features/tui/screens/main.py +245 -0
  174. praisonai/cli/features/tui/screens/queue.py +174 -0
  175. praisonai/cli/features/tui/screens/session.py +124 -0
  176. praisonai/cli/features/tui/screens/settings.py +148 -0
  177. praisonai/cli/features/tui/widgets/__init__.py +56 -0
  178. praisonai/cli/features/tui/widgets/chat.py +261 -0
  179. praisonai/cli/features/tui/widgets/composer.py +224 -0
  180. praisonai/cli/features/tui/widgets/queue_panel.py +200 -0
  181. praisonai/cli/features/tui/widgets/status.py +167 -0
  182. praisonai/cli/features/tui/widgets/tool_panel.py +248 -0
  183. praisonai/cli/features/workflow.py +720 -0
  184. praisonai/cli/legacy.py +236 -0
  185. praisonai/cli/main.py +5559 -0
  186. praisonai/cli/schedule_cli.py +54 -0
  187. praisonai/cli/state/__init__.py +31 -0
  188. praisonai/cli/state/identifiers.py +161 -0
  189. praisonai/cli/state/sessions.py +313 -0
  190. praisonai/code/__init__.py +93 -0
  191. praisonai/code/agent_tools.py +344 -0
  192. praisonai/code/diff/__init__.py +21 -0
  193. praisonai/code/diff/diff_strategy.py +432 -0
  194. praisonai/code/tools/__init__.py +27 -0
  195. praisonai/code/tools/apply_diff.py +221 -0
  196. praisonai/code/tools/execute_command.py +275 -0
  197. praisonai/code/tools/list_files.py +274 -0
  198. praisonai/code/tools/read_file.py +206 -0
  199. praisonai/code/tools/search_replace.py +248 -0
  200. praisonai/code/tools/write_file.py +217 -0
  201. praisonai/code/utils/__init__.py +46 -0
  202. praisonai/code/utils/file_utils.py +307 -0
  203. praisonai/code/utils/ignore_utils.py +308 -0
  204. praisonai/code/utils/text_utils.py +276 -0
  205. praisonai/db/__init__.py +64 -0
  206. praisonai/db/adapter.py +531 -0
  207. praisonai/deploy/__init__.py +62 -0
  208. praisonai/deploy/api.py +231 -0
  209. praisonai/deploy/docker.py +454 -0
  210. praisonai/deploy/doctor.py +367 -0
  211. praisonai/deploy/main.py +327 -0
  212. praisonai/deploy/models.py +179 -0
  213. praisonai/deploy/providers/__init__.py +33 -0
  214. praisonai/deploy/providers/aws.py +331 -0
  215. praisonai/deploy/providers/azure.py +358 -0
  216. praisonai/deploy/providers/base.py +101 -0
  217. praisonai/deploy/providers/gcp.py +314 -0
  218. praisonai/deploy/schema.py +208 -0
  219. praisonai/deploy.py +185 -0
  220. praisonai/endpoints/__init__.py +53 -0
  221. praisonai/endpoints/a2u_server.py +410 -0
  222. praisonai/endpoints/discovery.py +165 -0
  223. praisonai/endpoints/providers/__init__.py +28 -0
  224. praisonai/endpoints/providers/a2a.py +253 -0
  225. praisonai/endpoints/providers/a2u.py +208 -0
  226. praisonai/endpoints/providers/agents_api.py +171 -0
  227. praisonai/endpoints/providers/base.py +231 -0
  228. praisonai/endpoints/providers/mcp.py +263 -0
  229. praisonai/endpoints/providers/recipe.py +206 -0
  230. praisonai/endpoints/providers/tools_mcp.py +150 -0
  231. praisonai/endpoints/registry.py +131 -0
  232. praisonai/endpoints/server.py +161 -0
  233. praisonai/inbuilt_tools/__init__.py +24 -0
  234. praisonai/inbuilt_tools/autogen_tools.py +117 -0
  235. praisonai/inc/__init__.py +2 -0
  236. praisonai/inc/config.py +96 -0
  237. praisonai/inc/models.py +155 -0
  238. praisonai/integrations/__init__.py +56 -0
  239. praisonai/integrations/base.py +303 -0
  240. praisonai/integrations/claude_code.py +270 -0
  241. praisonai/integrations/codex_cli.py +255 -0
  242. praisonai/integrations/cursor_cli.py +195 -0
  243. praisonai/integrations/gemini_cli.py +222 -0
  244. praisonai/jobs/__init__.py +67 -0
  245. praisonai/jobs/executor.py +425 -0
  246. praisonai/jobs/models.py +230 -0
  247. praisonai/jobs/router.py +314 -0
  248. praisonai/jobs/server.py +186 -0
  249. praisonai/jobs/store.py +203 -0
  250. praisonai/llm/__init__.py +66 -0
  251. praisonai/llm/registry.py +382 -0
  252. praisonai/mcp_server/__init__.py +152 -0
  253. praisonai/mcp_server/adapters/__init__.py +74 -0
  254. praisonai/mcp_server/adapters/agents.py +128 -0
  255. praisonai/mcp_server/adapters/capabilities.py +168 -0
  256. praisonai/mcp_server/adapters/cli_tools.py +568 -0
  257. praisonai/mcp_server/adapters/extended_capabilities.py +462 -0
  258. praisonai/mcp_server/adapters/knowledge.py +93 -0
  259. praisonai/mcp_server/adapters/memory.py +104 -0
  260. praisonai/mcp_server/adapters/prompts.py +306 -0
  261. praisonai/mcp_server/adapters/resources.py +124 -0
  262. praisonai/mcp_server/adapters/tools_bridge.py +280 -0
  263. praisonai/mcp_server/auth/__init__.py +48 -0
  264. praisonai/mcp_server/auth/api_key.py +291 -0
  265. praisonai/mcp_server/auth/oauth.py +460 -0
  266. praisonai/mcp_server/auth/oidc.py +289 -0
  267. praisonai/mcp_server/auth/scopes.py +260 -0
  268. praisonai/mcp_server/cli.py +852 -0
  269. praisonai/mcp_server/elicitation.py +445 -0
  270. praisonai/mcp_server/icons.py +302 -0
  271. praisonai/mcp_server/recipe_adapter.py +573 -0
  272. praisonai/mcp_server/recipe_cli.py +824 -0
  273. praisonai/mcp_server/registry.py +703 -0
  274. praisonai/mcp_server/sampling.py +422 -0
  275. praisonai/mcp_server/server.py +490 -0
  276. praisonai/mcp_server/tasks.py +443 -0
  277. praisonai/mcp_server/transports/__init__.py +18 -0
  278. praisonai/mcp_server/transports/http_stream.py +376 -0
  279. praisonai/mcp_server/transports/stdio.py +132 -0
  280. praisonai/persistence/__init__.py +84 -0
  281. praisonai/persistence/config.py +238 -0
  282. praisonai/persistence/conversation/__init__.py +25 -0
  283. praisonai/persistence/conversation/async_mysql.py +427 -0
  284. praisonai/persistence/conversation/async_postgres.py +410 -0
  285. praisonai/persistence/conversation/async_sqlite.py +371 -0
  286. praisonai/persistence/conversation/base.py +151 -0
  287. praisonai/persistence/conversation/json_store.py +250 -0
  288. praisonai/persistence/conversation/mysql.py +387 -0
  289. praisonai/persistence/conversation/postgres.py +401 -0
  290. praisonai/persistence/conversation/singlestore.py +240 -0
  291. praisonai/persistence/conversation/sqlite.py +341 -0
  292. praisonai/persistence/conversation/supabase.py +203 -0
  293. praisonai/persistence/conversation/surrealdb.py +287 -0
  294. praisonai/persistence/factory.py +301 -0
  295. praisonai/persistence/hooks/__init__.py +18 -0
  296. praisonai/persistence/hooks/agent_hooks.py +297 -0
  297. praisonai/persistence/knowledge/__init__.py +26 -0
  298. praisonai/persistence/knowledge/base.py +144 -0
  299. praisonai/persistence/knowledge/cassandra.py +232 -0
  300. praisonai/persistence/knowledge/chroma.py +295 -0
  301. praisonai/persistence/knowledge/clickhouse.py +242 -0
  302. praisonai/persistence/knowledge/cosmosdb_vector.py +438 -0
  303. praisonai/persistence/knowledge/couchbase.py +286 -0
  304. praisonai/persistence/knowledge/lancedb.py +216 -0
  305. praisonai/persistence/knowledge/langchain_adapter.py +291 -0
  306. praisonai/persistence/knowledge/lightrag_adapter.py +212 -0
  307. praisonai/persistence/knowledge/llamaindex_adapter.py +256 -0
  308. praisonai/persistence/knowledge/milvus.py +277 -0
  309. praisonai/persistence/knowledge/mongodb_vector.py +306 -0
  310. praisonai/persistence/knowledge/pgvector.py +335 -0
  311. praisonai/persistence/knowledge/pinecone.py +253 -0
  312. praisonai/persistence/knowledge/qdrant.py +301 -0
  313. praisonai/persistence/knowledge/redis_vector.py +291 -0
  314. praisonai/persistence/knowledge/singlestore_vector.py +299 -0
  315. praisonai/persistence/knowledge/surrealdb_vector.py +309 -0
  316. praisonai/persistence/knowledge/upstash_vector.py +266 -0
  317. praisonai/persistence/knowledge/weaviate.py +223 -0
  318. praisonai/persistence/migrations/__init__.py +10 -0
  319. praisonai/persistence/migrations/manager.py +251 -0
  320. praisonai/persistence/orchestrator.py +406 -0
  321. praisonai/persistence/state/__init__.py +21 -0
  322. praisonai/persistence/state/async_mongodb.py +200 -0
  323. praisonai/persistence/state/base.py +107 -0
  324. praisonai/persistence/state/dynamodb.py +226 -0
  325. praisonai/persistence/state/firestore.py +175 -0
  326. praisonai/persistence/state/gcs.py +155 -0
  327. praisonai/persistence/state/memory.py +245 -0
  328. praisonai/persistence/state/mongodb.py +158 -0
  329. praisonai/persistence/state/redis.py +190 -0
  330. praisonai/persistence/state/upstash.py +144 -0
  331. praisonai/persistence/tests/__init__.py +3 -0
  332. praisonai/persistence/tests/test_all_backends.py +633 -0
  333. praisonai/profiler.py +1214 -0
  334. praisonai/recipe/__init__.py +134 -0
  335. praisonai/recipe/bridge.py +278 -0
  336. praisonai/recipe/core.py +893 -0
  337. praisonai/recipe/exceptions.py +54 -0
  338. praisonai/recipe/history.py +402 -0
  339. praisonai/recipe/models.py +266 -0
  340. praisonai/recipe/operations.py +440 -0
  341. praisonai/recipe/policy.py +422 -0
  342. praisonai/recipe/registry.py +849 -0
  343. praisonai/recipe/runtime.py +214 -0
  344. praisonai/recipe/security.py +711 -0
  345. praisonai/recipe/serve.py +859 -0
  346. praisonai/recipe/server.py +613 -0
  347. praisonai/scheduler/__init__.py +45 -0
  348. praisonai/scheduler/agent_scheduler.py +552 -0
  349. praisonai/scheduler/base.py +124 -0
  350. praisonai/scheduler/daemon_manager.py +225 -0
  351. praisonai/scheduler/state_manager.py +155 -0
  352. praisonai/scheduler/yaml_loader.py +193 -0
  353. praisonai/scheduler.py +194 -0
  354. praisonai/setup/__init__.py +1 -0
  355. praisonai/setup/build.py +21 -0
  356. praisonai/setup/post_install.py +23 -0
  357. praisonai/setup/setup_conda_env.py +25 -0
  358. praisonai/setup.py +16 -0
  359. praisonai/templates/__init__.py +116 -0
  360. praisonai/templates/cache.py +364 -0
  361. praisonai/templates/dependency_checker.py +358 -0
  362. praisonai/templates/discovery.py +391 -0
  363. praisonai/templates/loader.py +564 -0
  364. praisonai/templates/registry.py +511 -0
  365. praisonai/templates/resolver.py +206 -0
  366. praisonai/templates/security.py +327 -0
  367. praisonai/templates/tool_override.py +498 -0
  368. praisonai/templates/tools_doctor.py +256 -0
  369. praisonai/test.py +105 -0
  370. praisonai/train.py +562 -0
  371. praisonai/train_vision.py +306 -0
  372. praisonai/ui/agents.py +824 -0
  373. praisonai/ui/callbacks.py +57 -0
  374. praisonai/ui/chainlit_compat.py +246 -0
  375. praisonai/ui/chat.py +532 -0
  376. praisonai/ui/code.py +717 -0
  377. praisonai/ui/colab.py +474 -0
  378. praisonai/ui/colab_chainlit.py +81 -0
  379. praisonai/ui/components/aicoder.py +284 -0
  380. praisonai/ui/context.py +283 -0
  381. praisonai/ui/database_config.py +56 -0
  382. praisonai/ui/db.py +294 -0
  383. praisonai/ui/realtime.py +488 -0
  384. praisonai/ui/realtimeclient/__init__.py +756 -0
  385. praisonai/ui/realtimeclient/tools.py +242 -0
  386. praisonai/ui/sql_alchemy.py +710 -0
  387. praisonai/upload_vision.py +140 -0
  388. praisonai/version.py +1 -0
  389. praisonai-3.0.0.dist-info/METADATA +3493 -0
  390. praisonai-3.0.0.dist-info/RECORD +393 -0
  391. praisonai-3.0.0.dist-info/WHEEL +5 -0
  392. praisonai-3.0.0.dist-info/entry_points.txt +4 -0
  393. praisonai-3.0.0.dist-info/top_level.txt +1 -0
praisonai/train.py ADDED
@@ -0,0 +1,562 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ This script finetunes a model using Unsloth's fast training framework.
5
+ It supports both ShareGPT and Alpaca‑style datasets by converting raw conversation
6
+ data into plain-text prompts using a chat template, then pre‑tokenizing the prompts.
7
+ Extra debug logging is added to help trace the root cause of errors.
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import yaml
13
+ import torch
14
+ import shutil
15
+ import subprocess
16
+ from transformers import TextStreamer
17
+ from unsloth import FastLanguageModel, is_bfloat16_supported
18
+ from trl import SFTTrainer
19
+ from transformers import TrainingArguments
20
+ from datasets import load_dataset, concatenate_datasets
21
+ from psutil import virtual_memory
22
+ from unsloth.chat_templates import standardize_sharegpt, get_chat_template
23
+ from functools import partial
24
+
25
+ #####################################
26
+ # Step 1: Formatting Raw Conversations
27
+ #####################################
28
def formatting_prompts_func(examples, tokenizer):
    """
    Convert a batch of raw examples into plain-text prompts.

    If the batch has a "conversations" column it is treated as ShareGPT-style
    data; otherwise it is assumed to be Alpaca-style with "instruction",
    "input" and "output" columns. Missing or shorter Alpaca columns are padded
    with empty strings (plain zip used to silently drop rows on mismatch).

    Args:
        examples: Batched mapping of column name -> list of values, as passed
            by ``datasets.Dataset.map(batched=True)``.
        tokenizer: A chat-template-capable tokenizer exposing
            ``apply_chat_template``.

    Returns:
        dict: ``{"text": [...]}`` with one formatted prompt per example.
    """
    from itertools import zip_longest

    print("DEBUG: formatting_prompts_func() received batch with keys:", list(examples.keys()))

    def _render(convo, label):
        # Render one conversation to a single string; never let a template
        # failure abort the whole map() call — emit "" and keep going.
        try:
            formatted = tokenizer.apply_chat_template(
                convo,
                tokenize=False,  # Return a plain string
                add_generation_prompt=False,
            )
        except Exception as e:
            print(f"ERROR in apply_chat_template ({label}): {e}")
            return ""
        # Flatten list if necessary (some tokenizers return a list).
        if isinstance(formatted, list):
            formatted = formatted[0] if len(formatted) == 1 else "\n".join(formatted)
        return formatted

    texts = []
    # Check if the example has a "conversations" field.
    if "conversations" in examples:
        texts = [_render(convo, "conversations") for convo in examples["conversations"]]
    else:
        # Assume Alpaca format: use "instruction", "input", and "output" keys.
        instructions = examples.get("instruction", [])
        inputs_list = examples.get("input", [])
        outputs_list = examples.get("output", [])
        # zip_longest pads short or missing columns with "" instead of
        # truncating rows, honoring the "replace with empty string" contract.
        for ins, inp, out in zip_longest(instructions, inputs_list, outputs_list, fillvalue=""):
            inp = inp or ""
            # Create a conversation-like structure.
            convo = [
                {"role": "user", "content": ins + (f"\nInput: {inp}" if inp.strip() != "" else "")},
                {"role": "assistant", "content": out},
            ]
            texts.append(_render(convo, "alpaca"))
    if texts:
        print("DEBUG: Raw texts sample (first 200 chars):", texts[0][:200])
    return {"text": texts}
79
+
80
+ #####################################
81
+ # Step 2: Tokenizing the Prompts
82
+ #####################################
83
def tokenize_function(examples, hf_tokenizer, max_length):
    """
    Tokenize a batch of text prompts with padding and truncation enabled.

    Args:
        examples: Batched mapping with a "text" column. Entries may be strings
            or lists of strings (lists are joined into one string).
        hf_tokenizer: The underlying Hugging Face tokenizer (callable).
        max_length: Pad/truncate every prompt to exactly this many tokens.

    Returns:
        dict: Tokenizer output (e.g. input_ids, attention_mask) converted to
        plain Python lists so ``datasets`` can serialize it.
    """
    # Flatten list-valued entries into single strings.
    flat_texts = [
        (t[0] if len(t) == 1 else " ".join(t)) if isinstance(t, list) else t
        for t in examples["text"]
    ]
    print("DEBUG: Tokenizing a batch of size:", len(flat_texts))
    tokenized = hf_tokenizer(
        flat_texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Convert tensors to plain lists for Arrow serialization.
    tokenized = {key: value.tolist() for key, value in tokenized.items()}
    if tokenized and flat_texts:
        # Guard the debug peek: an empty batch previously raised IndexError
        # on list(tokenized.keys())[0] / tokenized[key][0].
        sample_key = next(iter(tokenized))
        print("DEBUG: Tokenized sample (first 10 tokens of", sample_key, "):", tokenized[sample_key][0][:10])
    return tokenized
104
+
105
+ #####################################
106
+ # Main Training Class
107
+ #####################################
108
+ class TrainModel:
109
+ def __init__(self, config_path="config.yaml"):
110
+ self.load_config(config_path)
111
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
+ self.model = None
113
+ self.hf_tokenizer = None # The underlying HF tokenizer
114
+ self.chat_tokenizer = None # Chat wrapper for formatting
115
+
116
+ def load_config(self, path):
117
+ with open(path, "r") as file:
118
+ self.config = yaml.safe_load(file)
119
+ print("DEBUG: Loaded config:", self.config)
120
+
121
+ def print_system_info(self):
122
+ print("DEBUG: PyTorch version:", torch.__version__)
123
+ print("DEBUG: CUDA version:", torch.version.cuda)
124
+ if torch.cuda.is_available():
125
+ print("DEBUG: CUDA Device Capability:", torch.cuda.get_device_capability())
126
+ else:
127
+ print("DEBUG: CUDA is not available")
128
+ print("DEBUG: Python Version:", sys.version)
129
+ print("DEBUG: Python Path:", sys.executable)
130
+
131
+ def check_gpu(self):
132
+ gpu_stats = torch.cuda.get_device_properties(0)
133
+ print(f"DEBUG: GPU = {gpu_stats.name}. Max memory = {round(gpu_stats.total_memory/(1024**3),3)} GB.")
134
+
135
+ def check_ram(self):
136
+ ram_gb = virtual_memory().total / 1e9
137
+ print(f"DEBUG: Your runtime has {ram_gb:.1f} gigabytes of available RAM")
138
+ if ram_gb < 20:
139
+ print("DEBUG: Not using a high-RAM runtime")
140
+ else:
141
+ print("DEBUG: You are using a high-RAM runtime!")
142
+
143
    def prepare_model(self):
        """Load the base model in 4-bit, attach the chat template, add LoRA.

        Side effects: sets ``self.model`` (PEFT-wrapped), ``self.hf_tokenizer``
        (the plain HF tokenizer) and ``self.chat_tokenizer`` (a llama-3.1
        chat-template wrapper around the same tokenizer).
        """
        print("DEBUG: Preparing model and tokenizer...")
        self.model, original_tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config["model_name"],
            max_seq_length=self.config["max_seq_length"],
            dtype=None,  # let Unsloth pick the best dtype for the hardware
            load_in_4bit=self.config["load_in_4bit"],
        )
        print("DEBUG: Model and original tokenizer loaded.")
        # Many base models ship without a pad token; reuse EOS so padding works.
        if original_tokenizer.pad_token is None:
            original_tokenizer.pad_token = original_tokenizer.eos_token
        original_tokenizer.model_max_length = self.config["max_seq_length"]
        self.chat_tokenizer = get_chat_template(original_tokenizer, chat_template="llama-3.1")
        self.hf_tokenizer = original_tokenizer
        print("DEBUG: Chat tokenizer created; HF tokenizer saved.")
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=16,  # LoRA rank
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient checkpointing
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )
        print("DEBUG: LoRA adapters added.")
171
+
172
+ def process_dataset(self, dataset_info):
173
+ dataset_name = dataset_info["name"]
174
+ split_type = dataset_info.get("split_type", "train")
175
+ print(f"DEBUG: Loading dataset '{dataset_name}' split '{split_type}'...")
176
+ dataset = load_dataset(dataset_name, split=split_type)
177
+ print("DEBUG: Dataset columns:", dataset.column_names)
178
+ if "conversations" in dataset.column_names:
179
+ print("DEBUG: Standardizing dataset (ShareGPT style)...")
180
+ dataset = standardize_sharegpt(dataset)
181
+ else:
182
+ print("DEBUG: Dataset does not have 'conversations'; assuming Alpaca format.")
183
+ print("DEBUG: Applying formatting function to dataset...")
184
+ format_func = partial(formatting_prompts_func, tokenizer=self.chat_tokenizer)
185
+ dataset = dataset.map(format_func, batched=True, remove_columns=dataset.column_names)
186
+ sample = dataset[0]
187
+ print("DEBUG: Sample processed example keys:", list(sample.keys()))
188
+ if "text" in sample:
189
+ print("DEBUG: Sample processed 'text' type:", type(sample["text"]))
190
+ print("DEBUG: Sample processed 'text' content (first 200 chars):", sample["text"][:200])
191
+ else:
192
+ print("DEBUG: Processed sample does not contain 'text'.")
193
+ return dataset
194
+
195
+ def tokenize_dataset(self, dataset):
196
+ print("DEBUG: Tokenizing the entire dataset...")
197
+ tokenized_dataset = dataset.map(
198
+ lambda examples: tokenize_function(examples, self.hf_tokenizer, self.config["max_seq_length"]),
199
+ batched=True
200
+ )
201
+ tokenized_dataset = tokenized_dataset.remove_columns(["text"])
202
+ print("DEBUG: Tokenized dataset sample keys:", tokenized_dataset[0].keys())
203
+ return tokenized_dataset
204
+
205
+ def load_datasets(self):
206
+ datasets = []
207
+ for dataset_info in self.config["dataset"]:
208
+ print("DEBUG: Processing dataset info:", dataset_info)
209
+ datasets.append(self.process_dataset(dataset_info))
210
+ combined = concatenate_datasets(datasets)
211
+ print("DEBUG: Combined dataset has", len(combined), "examples.")
212
+ return combined
213
+
214
    def train_model(self):
        """Fine-tune the prepared model on the configured datasets.

        Loads and tokenizes the data, builds ``TrainingArguments`` from config
        (with defaults), masks loss to assistant responses only, trains, then
        saves the LoRA model and tokenizer to the local "lora_model" directory.
        W&B reporting is enabled only when PRAISON_WANDB is set.
        """
        print("DEBUG: Starting training...")
        raw_dataset = self.load_datasets()
        tokenized_dataset = self.tokenize_dataset(raw_dataset)
        print("DEBUG: Dataset tokenization complete.")
        # Build the training arguments parameters dynamically
        ta_params = {
            "per_device_train_batch_size": self.config.get("per_device_train_batch_size", 2),
            "gradient_accumulation_steps": self.config.get("gradient_accumulation_steps", 2),
            "warmup_steps": self.config.get("warmup_steps", 50),
            "max_steps": self.config.get("max_steps", 2800),
            "learning_rate": self.config.get("learning_rate", 2e-4),
            # Pick fp16 vs bf16 from hardware support unless overridden in config.
            "fp16": self.config.get("fp16", not is_bfloat16_supported()),
            "bf16": self.config.get("bf16", is_bfloat16_supported()),
            "logging_steps": self.config.get("logging_steps", 15),
            "optim": self.config.get("optim", "adamw_8bit"),
            "weight_decay": self.config.get("weight_decay", 0.01),
            "lr_scheduler_type": self.config.get("lr_scheduler_type", "linear"),
            "seed": self.config.get("seed", 3407),
            "output_dir": self.config.get("output_dir", "outputs"),
            # Report to W&B only when explicitly enabled via env var.
            "report_to": "none" if not os.getenv("PRAISON_WANDB") else "wandb",
            "remove_unused_columns": self.config.get("remove_unused_columns", False)
        }
        if os.getenv("PRAISON_WANDB"):
            ta_params["save_steps"] = self.config.get("save_steps", 100)
            ta_params["run_name"] = os.getenv("PRAISON_WANDB_RUN_NAME", "praisonai-train")

        training_args = TrainingArguments(**ta_params)
        # Since the dataset is pre-tokenized, we supply a dummy dataset_text_field.
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.hf_tokenizer,
            train_dataset=tokenized_dataset,
            dataset_text_field="input_ids",  # Dummy field since data is numeric
            max_seq_length=self.config["max_seq_length"],
            dataset_num_proc=1,  # Use a single process to avoid pickling issues
            packing=False,
            args=training_args,
        )
        from unsloth.chat_templates import train_on_responses_only
        # Mask loss on everything except assistant turns (llama-3.1 headers).
        trainer = train_on_responses_only(
            trainer,
            instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
            response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
        )
        print("DEBUG: Beginning trainer.train() ...")
        trainer.train()
        print("DEBUG: Training complete. Saving model and tokenizer locally...")
        self.model.save_pretrained("lora_model")
        self.hf_tokenizer.save_pretrained("lora_model")
        print("DEBUG: Saved model and tokenizer to 'lora_model'.")
265
+
266
    def inference(self, instruction, input_text):
        """Generate and print a short response for (instruction, input_text).

        Switches the model into Unsloth's fast inference mode, renders a
        single-turn chat prompt and prints the decoded generation.
        NOTE(review): tensors are moved to "cuda" unconditionally — assumes a
        CUDA device is available.
        """
        FastLanguageModel.for_inference(self.model)
        messages = [{"role": "user", "content": f"{instruction}\n\nInput: {input_text}"}]
        inputs = self.hf_tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # append the assistant header so the model answers
            return_tensors="pt"
        ).to("cuda")
        outputs = self.model.generate(
            input_ids=inputs,
            max_new_tokens=64,
            use_cache=True,
            temperature=1.5,
            min_p=0.1
        )
        print("DEBUG: Inference output:", self.hf_tokenizer.batch_decode(outputs))
283
+
284
def load_model(self):
    """Reload the fine-tuned model and tokenizer from ``output_dir``.

    Returns:
        tuple: ``(model, tokenizer)`` from unsloth's
        ``FastLanguageModel.from_pretrained``.
    """
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=self.config["output_dir"],
        # Use the same context length the model was trained with instead of a
        # hard-coded 2048 (train_model reads config["max_seq_length"]); fall
        # back to 2048 when the key is absent.
        max_seq_length=self.config.get("max_seq_length", 2048),
        dtype=None,  # let unsloth auto-detect the dtype
        load_in_4bit=self.config["load_in_4bit"],
    )
    return model, tokenizer
293
+
294
def save_model_merged(self):
    """Merge LoRA weights into the base model and push to the Hugging Face Hub."""
    repo_name = self.config["hf_model_name"]
    # Clear out any stale local directory with the same name before pushing.
    if os.path.exists(repo_name):
        shutil.rmtree(repo_name)
    self.model.push_to_hub_merged(
        repo_name,
        self.hf_tokenizer,
        save_method="merged_16bit",
        token=os.getenv("HF_TOKEN"),
    )
303
+
304
def push_model_gguf(self):
    """Quantize the model to GGUF and upload it to the Hugging Face Hub."""
    hub_token = os.getenv("HF_TOKEN")
    quant = self.config["quantization_method"]
    self.model.push_to_hub_gguf(
        self.config["hf_model_name"],
        self.hf_tokenizer,
        quantization_method=quant,
        token=hub_token,
    )
311
+
312
def save_model_gguf(self):
    """Save a GGUF-quantized copy of the model locally.

    Honors the configured ``quantization_method`` (previously hard-coded to
    ``q4_k_m``, inconsistent with ``push_model_gguf``); falls back to
    ``q4_k_m`` when the key is absent.
    """
    method = self.config.get("quantization_method", "q4_k_m")
    # The config value may be a list (push_to_hub_gguf accepts several
    # methods); save_pretrained_gguf expects one, so take the first.
    if isinstance(method, (list, tuple)):
        method = method[0] if method else "q4_k_m"
    self.model.save_pretrained_gguf(
        self.config["hf_model_name"],
        self.hf_tokenizer,
        quantization_method=method
    )
318
+
319
def prepare_modelfile_content(self):
    """Build the text of an Ollama ``Modelfile`` for the trained model.

    Chooses a chat TEMPLATE and stop tokens by matching known model-family
    keywords (llama, qwen, mistral, phi, deepseek, llava) against the
    configured ``model_name`` (first match wins, in mapping order); unknown
    families get a llama-style fallback template.

    Fixes vs. previous revision: the llava stop token ``ASSSISTANT:`` (three
    S's) could never match real output and is corrected to ``ASSISTANT:``;
    empty stop tokens are filtered out so no invalid bare ``PARAMETER stop``
    lines are emitted; the llava template, previously a byte-for-byte copy of
    the qwen template, is now shared instead of duplicated.

    Returns:
        str: Full Modelfile content (FROM, TEMPLATE, optional SYSTEM and
        num_ctx lines, and one ``PARAMETER stop`` line per stop token).
    """
    output_model = self.config["hf_model_name"]
    model_name = self.config["model_name"].lower()
    # Mapping from model name keywords to their default TEMPLATE and stop tokens (and optional SYSTEM/num_ctx)
    mapping = {
        "llama": {
            "template": """<|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
{{ if .System }}{{ .System }}
{{- end }}
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
You are a helpful assistant with tool calling capabilities.
{{- end }}<|eot_id|>
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
{{- if and $.Tools $last }}
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
{{ range $.Tools }}
{{- . }}
{{ end }}
{{ .Content }}<|eot_id|>
{{- else }}
{{ .Content }}<|eot_id|>
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
{{- if .ToolCalls }}
{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
{{- else }}
{{ .Content }}
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}
{{- end }}
{{- end }}""",
            "stop_tokens": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
        },
        "qwen": {
            "template": """{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
{{- else if .Messages }}
{{- if or .System .Tools }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
{{- else }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}""",
            "system": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
            "num_ctx": 32768,
            "stop_tokens": ["<|endoftext|>"]
        },
        "mistral": {
            "template": "[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
            "stop_tokens": ["[INST]", "[/INST]"]
        },
        "phi": {
            "template": """{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
<|im_start|>{{ .Role }}<|im_sep|>
{{ .Content }}{{ if not $last }}<|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_end|>
<|im_start|>assistant<|im_sep|>
{{ end }}
{{- end }}""",
            "stop_tokens": ["<|im_start|>", "<|im_end|>", "<|im_sep|>"]
        },
        "deepseek": {
            "template": """{{- if .System }}{{ .System }}{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1}}
{{- if eq .Role "user" }}
{{ .Content }}
{{- else if eq .Role "assistant" }}
{{ .Content }}{{- if not $last }}
{{- end }}
{{- end }}
{{- if and $last (ne .Role "assistant") }}
{{ end }}
{{- end }}""",
            # NOTE(review): these stop tokens are empty placeholders — the real
            # deepseek special tokens appear to have been lost; confirm and
            # restore them. Empty entries are filtered out below.
            "stop_tokens": ["", "", "", ""]
        },
    }
    # llava shares the qwen chat template verbatim (it was previously a full
    # duplicate); only its stop tokens differ, and no SYSTEM/num_ctx is set.
    mapping["llava"] = {
        "template": mapping["qwen"]["template"],
        "stop_tokens": ["</s>", "USER:", "ASSISTANT:"]
    }
    # Select mapping by checking if any key is in the model_name.
    chosen = None
    for key, settings in mapping.items():
        if key in model_name:
            chosen = settings
            break
    if chosen is None:
        # Fallback default
        chosen = {
            "template": """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>""",
            "stop_tokens": ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
        }
    # Build the stop parameter lines, skipping empty tokens — a bare
    # "PARAMETER stop" line is not valid Modelfile syntax.
    stop_params = "\n".join(
        f"PARAMETER stop {token}" for token in chosen["stop_tokens"] if token
    )
    # Optionally include a SYSTEM line and num_ctx if defined in the mapping.
    system_line = ""
    if "system" in chosen:
        system_line = f"SYSTEM {chosen['system']}\n"
    num_ctx_line = ""
    if "num_ctx" in chosen:
        num_ctx_line = f"PARAMETER num_ctx {chosen['num_ctx']}\n"
    # Assemble and return the modelfile content.
    return f"""FROM {output_model}
TEMPLATE \"\"\"{chosen['template']}\"\"\"
{system_line}{num_ctx_line}{stop_params}
"""
523
+
524
def create_and_push_ollama_model(self):
    """Write a Modelfile, then create and push the model to Ollama.

    The Ollama server must be started in the background: ``subprocess.run``
    would block forever here (``ollama serve`` never exits), so the
    create/push steps below would never execute. ``Popen`` fixes that.
    """
    import time  # local import: only needed for the server warm-up wait

    modelfile_content = self.prepare_modelfile_content()
    with open("Modelfile", "w") as file:
        file.write(modelfile_content)
    # Launch the server in the background; if one is already running this
    # child exits immediately and the existing server is used.
    server = subprocess.Popen(["ollama", "serve"])
    time.sleep(5)  # give the server a moment to start accepting requests
    tag = f"{self.config['ollama_model']}:{self.config['model_parameters']}"
    try:
        subprocess.run(["ollama", "create", tag, "-f", "Modelfile"])
        subprocess.run(["ollama", "push", tag])
    finally:
        # Stop the background server we spawned so the process can exit.
        server.terminate()
531
+
532
def run(self):
    """Execute the configured pipeline: system checks, training, exports.

    Each stage is gated by a config flag defaulting to enabled. Flags may be
    YAML booleans (``True``) or strings (``"true"``); the previous code
    called ``.lower()`` directly and raised ``AttributeError`` on booleans,
    so values are coerced to ``str`` before comparing.
    """
    def _enabled(key):
        # True / "true" / "True" all count as enabled; missing key => enabled.
        return str(self.config.get(key, "true")).lower() == "true"

    self.print_system_info()
    self.check_gpu()
    self.check_ram()
    if _enabled("train"):
        self.prepare_model()
        self.train_model()
    if _enabled("huggingface_save"):
        self.save_model_merged()
    if _enabled("huggingface_save_gguf"):
        self.push_model_gguf()
    if _enabled("ollama_save"):
        self.create_and_push_ollama_model()
545
+
546
def main():
    """CLI entry point: parse arguments and dispatch the requested command."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="PraisonAI Training Script")
    arg_parser.add_argument("command", choices=["train"], help="Command to execute")
    arg_parser.add_argument("--config", default="config.yaml", help="Path to configuration file")
    arg_parser.add_argument("--model", type=str, help="Model name")
    arg_parser.add_argument("--hf", type=str, help="Hugging Face model name")
    arg_parser.add_argument("--ollama", type=str, help="Ollama model name")
    arg_parser.add_argument("--dataset", type=str, help="Dataset name for training")
    parsed = arg_parser.parse_args()

    # NOTE(review): --model/--hf/--ollama/--dataset are accepted but not
    # forwarded to TrainModel here; only --config is consumed. Confirm whether
    # they should override config values.
    if parsed.command == "train":
        TrainModel(config_path=parsed.config).run()


if __name__ == "__main__":
    main()