ipex_llm-2.2.0b20250107-py3-none-win_amd64.whl → ipex_llm-2.2.0b20250109-py3-none-win_amd64.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (68)
  1. ipex_llm/libs/bloom-api.dll +0 -0
  2. ipex_llm/libs/bloom.dll +0 -0
  3. ipex_llm/libs/gptneox-api.dll +0 -0
  4. ipex_llm/libs/gptneox.dll +0 -0
  5. ipex_llm/libs/libbloom_avx.dll +0 -0
  6. ipex_llm/libs/libbloom_vnni.dll +0 -0
  7. ipex_llm/libs/libgptneox_avx.dll +0 -0
  8. ipex_llm/libs/libgptneox_vnni.dll +0 -0
  9. ipex_llm/libs/libllama_avx.dll +0 -0
  10. ipex_llm/libs/libllama_vnni.dll +0 -0
  11. ipex_llm/libs/libstarcoder_avx.dll +0 -0
  12. ipex_llm/libs/libstarcoder_vnni.dll +0 -0
  13. ipex_llm/libs/llama-api.dll +0 -0
  14. ipex_llm/libs/llama.dll +0 -0
  15. ipex_llm/libs/main-bloom.exe +0 -0
  16. ipex_llm/libs/main-gptneox.exe +0 -0
  17. ipex_llm/libs/main-llama.exe +0 -0
  18. ipex_llm/libs/main-starcoder.exe +0 -0
  19. ipex_llm/libs/pipeline.dll +0 -0
  20. ipex_llm/libs/quantize-bloom.exe +0 -0
  21. ipex_llm/libs/quantize-bloom_vnni.exe +0 -0
  22. ipex_llm/libs/quantize-gptneox.exe +0 -0
  23. ipex_llm/libs/quantize-gptneox_vnni.exe +0 -0
  24. ipex_llm/libs/quantize-llama.exe +0 -0
  25. ipex_llm/libs/quantize-llama_vnni.exe +0 -0
  26. ipex_llm/libs/quantize-starcoder.exe +0 -0
  27. ipex_llm/libs/quantize-starcoder_vnni.exe +0 -0
  28. ipex_llm/libs/starcoder-api.dll +0 -0
  29. ipex_llm/libs/starcoder.dll +0 -0
  30. ipex_llm/transformers/convert.py +20 -50
  31. ipex_llm/transformers/loader.py +1 -1
  32. ipex_llm/transformers/low_bit_linear.py +10 -25
  33. ipex_llm/transformers/model.py +0 -7
  34. ipex_llm/transformers/models/baichuan.py +7 -36
  35. ipex_llm/transformers/models/bert.py +2 -13
  36. ipex_llm/transformers/models/chatglm2.py +8 -31
  37. ipex_llm/transformers/models/chatglm4.py +9 -4
  38. ipex_llm/transformers/models/chatglm4v.py +2 -1
  39. ipex_llm/transformers/models/common.py +3 -1
  40. ipex_llm/transformers/models/glm.py +4 -2
  41. ipex_llm/transformers/models/internlm.py +6 -3
  42. ipex_llm/transformers/models/llama.py +2 -2
  43. ipex_llm/transformers/models/minicpm.py +3 -2
  44. ipex_llm/transformers/models/minicpm3.py +3 -1
  45. ipex_llm/transformers/models/minicpmv.py +1 -0
  46. ipex_llm/transformers/models/mistral.py +1 -1
  47. ipex_llm/transformers/models/mllama.py +1 -1
  48. ipex_llm/transformers/models/phi3.py +6 -2
  49. ipex_llm/transformers/models/qwen.py +4 -2
  50. ipex_llm/transformers/models/qwen2.py +4 -3
  51. ipex_llm/transformers/models/qwen2_moe.py +4 -2
  52. ipex_llm/transformers/models/qwen2_vl.py +3 -1
  53. ipex_llm/transformers/models/stablelm.py +3 -1
  54. ipex_llm/transformers/models/starcoder2.py +3 -1
  55. ipex_llm/transformers/models/utils.py +10 -19
  56. ipex_llm/transformers/models/yuan.py +2 -1
  57. ipex_llm/transformers/speculative.py +2 -14
  58. ipex_llm/transformers/utils.py +2 -14
  59. ipex_llm/transformers/xpu_ops.py +25 -19
  60. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/METADATA +20 -20
  61. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/RECORD +67 -68
  62. ipex_llm/transformers/models/gptj.py +0 -441
  63. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250109.data}/scripts/ipex-llm-init.bat +0 -0
  64. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250109.data}/scripts/llm-chat.ps1 +0 -0
  65. {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250109.data}/scripts/llm-cli.ps1 +0 -0
  66. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/WHEEL +0 -0
  67. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/entry_points.txt +0 -0
  68. {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250109.dist-info}/top_level.txt +0 -0
@@ -41,35 +41,35 @@ ipex_llm/langchain/llms/transformerspipelinellm.py,sha256=vm522YPPwWxxAPVvQBtxRf
  ipex_llm/langchain/vllm/__init__.py,sha256=T-EbRT6GJ_8RCu-iLmSzcftOimXSPQf2d5X72AUAy2Y,874
  ipex_llm/langchain/vllm/vllm.py,sha256=6dxc-ZISZQrJilEa_HA827l75Dv9rcHpY_G6FdJ8BVs,7793
  ipex_llm/libs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- ipex_llm/libs/bloom-api.dll,sha256=aa3qzp0VqG1RzqCWC8uPfyzCsf3xC0qqgyj4WxDoOsM,36352
- ipex_llm/libs/bloom.dll,sha256=kVhdgqqPvusN-aFfVCBnSZSGUWRw9dv9ICv-klmF0UE,506880
- ipex_llm/libs/gptneox-api.dll,sha256=bjp6ZiVvmT6aWrDiqgV43pYeifLnGDlNnRwi3r1rTFo,24576
- ipex_llm/libs/gptneox.dll,sha256=xKfJ9oNMoUXzHH43vXrap9wbm-KgKY-AhARrVvU_iLg,567296
- ipex_llm/libs/libbloom_avx.dll,sha256=VGPb6z7Jh4pgStkfnzqEBW-mnzLTETVUsc20ym5c1eI,535040
- ipex_llm/libs/libbloom_vnni.dll,sha256=_ibtsNyO3G9KV30mKJky49eKMm_8ZVgKPA0FdjjIvB0,506880
- ipex_llm/libs/libgptneox_avx.dll,sha256=onI5sXdEsCAPL3JAIGH5r4jwPbevYi6bw15CmPyZL5s,595456
- ipex_llm/libs/libgptneox_vnni.dll,sha256=6BvdfvGnikm78KlPcdGq6IiAecPkPwm6i38PotgB9FQ,567808
- ipex_llm/libs/libllama_avx.dll,sha256=TRqBjWOmR7dCQtEq9c5l0wshBPM2e7klfca8eV-f79Q,589824
- ipex_llm/libs/libllama_vnni.dll,sha256=F9va1YfNVe9XGfilbio3mEdbbGtmpfAR-9T8vWEqulI,561664
- ipex_llm/libs/libstarcoder_avx.dll,sha256=IvM2V18nf_g9tsaLz8W30qgchg496V89PpCoKMHnQPY,626688
- ipex_llm/libs/libstarcoder_vnni.dll,sha256=X9H9Tyy3DdQg1LIvc8ILP33ilj4amEUfwF_E9KPVWvE,598528
- ipex_llm/libs/llama-api.dll,sha256=-XYElLfNrWEP63PStFxXX9wM0kP6PGYQJH6oePg5u-I,25600
- ipex_llm/libs/llama.dll,sha256=zLW5de-ASaVBoPQbRYUMK6L6FoPwmt6II4bGFVeHE-M,561152
- ipex_llm/libs/main-bloom.exe,sha256=cNaMadt2EpUgCu55XjQW0Vi1UUBQ5eA9CnfXZV6OGjc,103424
- ipex_llm/libs/main-gptneox.exe,sha256=54va_bkWTOw0Gf_EMEF2MaqNsKEugEsxTxUszTAxY7Q,98816
- ipex_llm/libs/main-llama.exe,sha256=banfh_rJMhBR4FERotQ6FBOKDEhz-KF49Zu6E86Wlso,99840
- ipex_llm/libs/main-starcoder.exe,sha256=L-SIpoGn0vA8CD37MMm44i56CTWKXQjnC6WWZE1okPw,157696
- ipex_llm/libs/pipeline.dll,sha256=rOzOT0lY3RMQlIDHCvaAx_bjAvQD8BBZNyqi3BssG2U,72704
- ipex_llm/libs/quantize-bloom.exe,sha256=GvY5jFYGrK_wFa1RyAurX0vD_UAg7FCD9Ls1zAbYYKc,126464
- ipex_llm/libs/quantize-bloom_vnni.exe,sha256=j3i2Q_ymqlnrPBUCskU3gqrwV198NkfLxmlz0KOsqoI,127488
- ipex_llm/libs/quantize-gptneox.exe,sha256=he-NK5wy8u_t3S26Qb4Euenf0E6mw8O6lsNOcX-m2ts,104448
- ipex_llm/libs/quantize-gptneox_vnni.exe,sha256=Oz7iHHL01QZYJpCBsPeYXhBCRBJ9kFklYZ2UhEw77rk,104960
- ipex_llm/libs/quantize-llama.exe,sha256=Ns-98bA07AeTPaN7v0zs3cOb3PAvUmj7R_Xvsn-1bTM,109568
- ipex_llm/libs/quantize-llama_vnni.exe,sha256=t4wts7FmUT4n5_Ii6wqtrVMv73pf3pOMMfrziRZAI5U,110592
- ipex_llm/libs/quantize-starcoder.exe,sha256=xYLEIQ4gtU23ae7cN_1hZrxqDrSiQdCNu3EZRbffNLs,127488
- ipex_llm/libs/quantize-starcoder_vnni.exe,sha256=5gmwDZTTqM9fFuj9UihVyA91TmcoBq95xxAofaBi4sM,128512
- ipex_llm/libs/starcoder-api.dll,sha256=UhlNrAN44aZwkbjdmVFr9E2tV3iuQ8jR1OvEgiQzVXk,21504
- ipex_llm/libs/starcoder.dll,sha256=1jGXZ-nQo3m1kMXeOVwd-fc68p7kjPGfzAu4ifngrlE,598016
+ ipex_llm/libs/bloom-api.dll,sha256=vLC0Hy_vsNY1cTKoFMJ7g9iPzKV5X_YL8U9N3k8FHFk,36352
+ ipex_llm/libs/bloom.dll,sha256=hVQn4GiRqyD8P-_XbuQD_inkB74jxR8MutecrzrIQWQ,506880
+ ipex_llm/libs/gptneox-api.dll,sha256=R7YnJz6n4TCdPq5NYWuSaanj-hGrYnzZfJ9P4z033Ws,24576
+ ipex_llm/libs/gptneox.dll,sha256=Iyi18Z9buCMvglEGOPHYlpL4kSa4CADaXUMXJF4Bhqs,567296
+ ipex_llm/libs/libbloom_avx.dll,sha256=8NZmJGuvsGISZDOO1wp6jtfh6xWGmLfFhS9RfGl8WEk,535040
+ ipex_llm/libs/libbloom_vnni.dll,sha256=mTGMJzQiM7Mj0Qdrz7Xs7YWWypXijWum_LmSBR5Q2uE,506880
+ ipex_llm/libs/libgptneox_avx.dll,sha256=617wPDx7BAimc3oOSF2X0jh9h6j4mdBFRBt-1Qkmz1A,595456
+ ipex_llm/libs/libgptneox_vnni.dll,sha256=DmCL8HfUA8DbNQiZHDB77fpadSqnHSaxHpWRCw_FTDA,567808
+ ipex_llm/libs/libllama_avx.dll,sha256=JHRfoFLp_B-3J5B1oa2FnKIIjKK6ZJaJrTfJXLgrGIA,589824
+ ipex_llm/libs/libllama_vnni.dll,sha256=WxH8QfmfpuZ9rZJP2lcoXX5Zp9lonun-hhw26O40w78,561664
+ ipex_llm/libs/libstarcoder_avx.dll,sha256=9Dqf2uWGA8XlGAjOqgMQiC-iNJYBhOhNiFk0GZTfwiQ,626688
+ ipex_llm/libs/libstarcoder_vnni.dll,sha256=ob3olLuRXCGvpNuBYImrt92aejvhDsZpi-m_r75lIUY,598528
+ ipex_llm/libs/llama-api.dll,sha256=1OMZVvZun0kQBd1EEQZUvSRJXHfaJgo9WYka4tAamGA,25600
+ ipex_llm/libs/llama.dll,sha256=ZLuhZQAjPc76oJ-PO_yUlZi24qlzIBs8OXnR9-2LTGU,561152
+ ipex_llm/libs/main-bloom.exe,sha256=T17qQytgapUQAmXSkH38RO21Ynb9NrOSiD3erxr4kJI,103424
+ ipex_llm/libs/main-gptneox.exe,sha256=qskuY8ijrPRcBzTtGb8gB1PAoSvE4wkJe2GgiAAWS-4,98816
+ ipex_llm/libs/main-llama.exe,sha256=t-mQ7luxLky6boib8wSRecL81DAdKF3HsLZ-oN4Agm4,99840
+ ipex_llm/libs/main-starcoder.exe,sha256=WQ2L7L2XdmHxzt6L3uaHuWwTJVL9xg683y4_7avVVgk,157696
+ ipex_llm/libs/pipeline.dll,sha256=_Kx68vo1DLJc57qF8aGEhdiCFbcAsm8xYQWeJHogVuM,72704
+ ipex_llm/libs/quantize-bloom.exe,sha256=CTJhn-jYm4V99Fcv2R2mHkcYMfzQOU8IiqMKfqRK59A,126464
+ ipex_llm/libs/quantize-bloom_vnni.exe,sha256=aauMlSe4e05rKB61u4PQdckBNWQzrISe5_HXkRCTCfI,127488
+ ipex_llm/libs/quantize-gptneox.exe,sha256=QP-fDnaF6FElKal_kRh_N5nB-WUQKah8vG33lJQ3g9c,104448
+ ipex_llm/libs/quantize-gptneox_vnni.exe,sha256=PUgZZTq1a36tbtsgu-jnK0MHlsFQ-zPtAsAQ7PLczJU,104960
+ ipex_llm/libs/quantize-llama.exe,sha256=uagQGbAgP8ripvyS0CREOQH0JhwHRi5dWeo_Q3criNc,109568
+ ipex_llm/libs/quantize-llama_vnni.exe,sha256=fB7-E0PGSliaZDfu9bCl1hdcoD7xIMoRHiE6fFTJ75g,110592
+ ipex_llm/libs/quantize-starcoder.exe,sha256=6OMrGkHL87jw18EVzlEj9NJpjTl7b3_7YOF_YRC24sI,127488
+ ipex_llm/libs/quantize-starcoder_vnni.exe,sha256=i6SaJyFmS_S2jjl9Cpa_5S-RyHo1_ABD0gx2maxTdfc,128512
+ ipex_llm/libs/starcoder-api.dll,sha256=CgnbCueyY_M3QN-lskMQBPPVrfuKgy2Qt0aVV6wXPgQ,21504
+ ipex_llm/libs/starcoder.dll,sha256=bc_dnjXNNVaXQlVdGp-a6IMZO3J2Ux9hwEQH7xoEaek,598016
  ipex_llm/llamaindex/__init__.py,sha256=T-EbRT6GJ_8RCu-iLmSzcftOimXSPQf2d5X72AUAy2Y,874
  ipex_llm/llamaindex/llms/__init__.py,sha256=KP1lEdGqDuxPoxL1ZSH25Pm2kKMPJBWUTLR0ckSLMIU,1139
  ipex_llm/llamaindex/llms/bigdlllm.py,sha256=FQBzq1KOjfc6uofTXAha3O7TqpJkNfOFepXQmOVlbnI,26314
@@ -87,27 +87,27 @@ ipex_llm/serving/fastchat/tgi_api_protocol.py,sha256=brT3k3-V0NJrU4fRqUwWjC0O3iO
  ipex_llm/serving/fastchat/tgi_api_server.py,sha256=agNTAEiZPSuj3dEdIdYKwkoY0cXOUDX06DiM9VP2knQ,24418
  ipex_llm/serving/fastchat/vllm_worker.py,sha256=ZLz2Q9GxJO6r_LOiP6epgCRjBGk-K4EB1SNEWSJp5DA,11091
  ipex_llm/transformers/__init__.py,sha256=l4KkMkLe-pRC7b_kj6LCfeifgE-Uo33_Av_FwN9HnFA,1074
- ipex_llm/transformers/convert.py,sha256=pFm6VlU84u_Llr2sp6-gRrEYDeNgIk2QPukolq4IE1s,99947
+ ipex_llm/transformers/convert.py,sha256=umI137wqV2d4itS0AJQoZcygeWBATpSJSDJ805cZ-SY,98499
  ipex_llm/transformers/convert_ipex.py,sha256=iKXo0n8fVFTOA2fNYYrByMFK0dovL-kLd2sVDk88AlQ,14334
  ipex_llm/transformers/embedding.py,sha256=bdgk59DvD4ZZyxRzewXOR7g56nThgO6uhIwk8QL7f-s,9299
  ipex_llm/transformers/kv.py,sha256=k4TU18LlA-Sbq9WNNQnfuzu3RSFBwFhmaV3BcGN5bAo,19191
  ipex_llm/transformers/lisa.py,sha256=F5WxbtXQ7RdKulj83h_2DnEIgKiKGZf7zvOmg6QBl2s,3289
- ipex_llm/transformers/loader.py,sha256=cOgX93xOC-4dt01GTJ5wyd7PjZ8S43r4mctkR2YxVuw,6893
+ ipex_llm/transformers/loader.py,sha256=AwjV5RpI2t2bedlv7ZhLm8cfd-QJZm5hny-XyjIvdnk,6876
  ipex_llm/transformers/lookup.py,sha256=b6OlZ9OV10R9qeWw8mVryVpDxszkjwLkldvi7GPMJY8,19614
- ipex_llm/transformers/low_bit_linear.py,sha256=nKraUvZJ7UdXP29HSE4CJPIVxmN-TvG8dpT4gpleuyQ,41688
- ipex_llm/transformers/model.py,sha256=KcRjkauGg48BYrUBoUZaVMpg7Piuz5JrfIpVZd3EIjs,41105
+ ipex_llm/transformers/low_bit_linear.py,sha256=Obdd08D9dvuroS_6XWo4DXO_DrNRsbAqjz-mQAHmfxY,40845
+ ipex_llm/transformers/model.py,sha256=fj7LBjrWtWwDJJYXnWiXsLGS4ayqqHfnh0p51dSDssE,40908
  ipex_llm/transformers/modelling_bigdl.py,sha256=7JpNVMuyq_OmtNUaMFMXdxPWZp2q0QHC02QeA-VTPOw,6709
  ipex_llm/transformers/npu_model.py,sha256=YW02GeVz-9ZGqxAeSz0AOvciS-17bo9eK5ZOBrICwSQ,39508
  ipex_llm/transformers/patches.py,sha256=halPWm__ORh2fRFSIFPiCNg3LQBfrRkTPtmtRpBJCZQ,1286
  ipex_llm/transformers/pipeline_parallel.py,sha256=uNZpOXljNmdoEYnP8U-VFiN4dRZb2piQbIf2bG9LQnE,49051
  ipex_llm/transformers/qlora.py,sha256=jtPGsvWFjbTUGzDBCdfftnCis_0nJQNRpACSwXUbbGU,14943
  ipex_llm/transformers/relora.py,sha256=-dYzUV0P-IhO2jFdnzN9-v_sFzJpRj3ZwN9eCJzOoCw,16567
- ipex_llm/transformers/speculative.py,sha256=Zf1nQb5GXpJQrUHBTL-H4RUBfdv3lGhfehzudHimhYk,64109
+ ipex_llm/transformers/speculative.py,sha256=0XNLgc9dGswJHVPrXo4iM7pPxkWwfFfJMECcivJSnIc,63368
  ipex_llm/transformers/streamer.py,sha256=RrVlLblzCOtABRUpaMXAyaMnCGgLUtAi_YesLumRbww,4842
  ipex_llm/transformers/training_patch.py,sha256=oxMkUtqyvqJiprw6dE3skkYfD1HOmUlH9N0hBkbn0G0,10799
- ipex_llm/transformers/utils.py,sha256=fXLIlr9hoBr27p3w3xzczZGPk2cCTIRbUKBkiVCGYbc,16889
+ ipex_llm/transformers/utils.py,sha256=9IRSqfDokf8QFW9T47R--i3RL1E-_O31HO7IJf7H6pg,16748
  ipex_llm/transformers/xpu_customize_fwd.py,sha256=wFpIhs5F6tkNs8gBOrLxWdhLzO3EDHovVkERPIAoAvg,7611
- ipex_llm/transformers/xpu_ops.py,sha256=H46-69pMRQhekbAEoDfNacCInLWycMHDqrgMGLvFYfI,4362
+ ipex_llm/transformers/xpu_ops.py,sha256=vw4cUwvqUqDr45d-WMIkCpM2oiHfjN-VjF0bjMSF4kY,4830
  ipex_llm/transformers/awq/__init__.py,sha256=Du5gu3-eeAkeDO_dEMBTzrDBA66DSN3uL3-rn8WGXQw,875
  ipex_llm/transformers/awq/act.py,sha256=YwomJzOOKwkKtzGrm4L4kwBstBLO1Z8SK4CKi8PSYVQ,2172
  ipex_llm/transformers/awq/awq.py,sha256=cGyRQJWwAEJtOtdSbsBoQ33KX_Ie0pv5OJHC0ACEELE,8861
@@ -137,46 +137,45 @@ ipex_llm/transformers/gguf/models/model_implement/yuan2/configuration_yuan.py,sh
  ipex_llm/transformers/gguf/models/model_implement/yuan2/yuan_hf_model.py,sha256=_AOGMV65XHxgTxIib7lgs49InopcecTzRwgtYR8NTUg,51084
  ipex_llm/transformers/models/__init__.py,sha256=tp2DcVkKg1-QvdYk7DY7rZvQWCDQ4ZjU8NAQ7Fclrpg,584
  ipex_llm/transformers/models/aquila.py,sha256=VZb5Drpo_fTxwcExZ397LygnsNPX2sVbie9_JeFudZI,5252
- ipex_llm/transformers/models/baichuan.py,sha256=oJCAEENSG8oQhJ-QPN2SiapARjAGdOM6nEbyCcYOMCo,19334
- ipex_llm/transformers/models/bert.py,sha256=bJNic2pt1kph0kBwdK5MRGyWupFfx2Ts0V3D1L-5kWo,6085
+ ipex_llm/transformers/models/baichuan.py,sha256=cAQLmVG-3R8CSTGTcDy2JOOzVe-Ej8AXjIEIjvZBGlo,18376
+ ipex_llm/transformers/models/bert.py,sha256=0Mm9jkvkzBxtc_z_GE1TcZoPz-HOg2Z2973ZEWgSwJk,5601
  ipex_llm/transformers/models/bloom.py,sha256=PxfzyYT-nFn3K5rZhTQjmcEjUUzAhUFzxIN4kzRlCuc,8103
  ipex_llm/transformers/models/chatglm.py,sha256=UHai1t2AUtGmF765_eHF8LUMVQzp_oCBx8TJB21WrHk,12597
- ipex_llm/transformers/models/chatglm2.py,sha256=SGCABJdYQLW0zDarEoWrEQLuWlbq9iQhYU8ZeR1-ptQ,15957
- ipex_llm/transformers/models/chatglm4.py,sha256=AAhAFFDDas5DBQPfh2Mwl7a2v7taKf6xphoeeNNFaBI,16593
- ipex_llm/transformers/models/chatglm4v.py,sha256=YRfuf9g1E0MQ_7wbHAOMvadFnO-j3LqI_k1SaRkDs0M,14055
- ipex_llm/transformers/models/common.py,sha256=4obQMGF02FCiXrHnFle9Fsx7C33b1FDt37qJJ4YgxRc,11578
+ ipex_llm/transformers/models/chatglm2.py,sha256=KyAIX7zGVQDQuwwM3QMBNWZbTeMHEzKUIgAryT0voHc,14933
+ ipex_llm/transformers/models/chatglm4.py,sha256=QvUehdaCePB3MNHyWg3dneDxmjtBdxYeKUyQUVcsgfM,16886
+ ipex_llm/transformers/models/chatglm4v.py,sha256=L6y45M_wjS2_HqchmCUxRlQZUNuSNCGOiynAQrGh918,14124
+ ipex_llm/transformers/models/common.py,sha256=Q3IEfGqvxoHyfIIF5s8qHmOJBBP3b2jyVAXk8C3b1Pg,11636
  ipex_llm/transformers/models/decilm.py,sha256=P-PBuDPf07GvKggLwJx_wPwIn6esN3rX8ai2JxRuZmE,5246
  ipex_llm/transformers/models/gemma.py,sha256=_E3Yw8Y45xyNVeLqyVKcpr8kjuICtETeL82cJ-bWJuU,9424
  ipex_llm/transformers/models/gemma2.py,sha256=2WZuv-FLzJyTJFaYxOuzJt47QE64M0lHnzAiO5T6ozI,8049
- ipex_llm/transformers/models/glm.py,sha256=gHYgfn20jPRL-ElXy-rUqMh6_LQcc5x7DEXSZuRA4E0,7094
+ ipex_llm/transformers/models/glm.py,sha256=lmeEWd_W2O638VzVW4Gm6cJre5XZcg_QBmPs8NWqXsM,7202
  ipex_llm/transformers/models/gpt2.py,sha256=YSaNgK1uLCFDuIFqnKO0Mi-AsOZsYav-7pNf_NpKGdM,3445
  ipex_llm/transformers/models/gptbigcode.py,sha256=cP1_qGWoa43R2WacAMblShjku4QupcCZiLaPPAoOUs4,9101
- ipex_llm/transformers/models/gptj.py,sha256=TTIx461X2nOcIkrAcZhEf7d7mjJ3yvEC9KLVc1-hrpc,17973
  ipex_llm/transformers/models/gptneox.py,sha256=loRh1x_5S6BCeOr_s5xr-N_1SQHL3Y5IiUBAEyoMUqQ,6172
- ipex_llm/transformers/models/internlm.py,sha256=ZbIUMDwNRcrCeduXfbA_uq1AUEWawEt6CJRvQl3LkAg,17832
+ ipex_llm/transformers/models/internlm.py,sha256=OifyiobRligleyZLpLBSe44A6Sq0uMG-8-NOcRCcT4Q,18080
  ipex_llm/transformers/models/internvl.py,sha256=Vx0vENIEQLX2M6P398mw5TOhpks0U8xf8rtRQvy94go,8154
- ipex_llm/transformers/models/llama.py,sha256=ozwtdQ0MbanJEtW4LBFGxqs_QAq82EonhL2dL6tGyw0,8567
- ipex_llm/transformers/models/minicpm.py,sha256=ib2rJTN7Tf7znBCtVrtXsF-_Uuk2aA7KVg02xzatLiI,10103
- ipex_llm/transformers/models/minicpm3.py,sha256=FhNS6mi2rg7dSdF_QQGrao3g9EC6XLn1MTKd-kd0wF0,9191
- ipex_llm/transformers/models/minicpmv.py,sha256=ZV4s48WNIyRoEkvENnlmopnx3ojZANBer0LI6bRtxrY,9826
- ipex_llm/transformers/models/mistral.py,sha256=rE1GWQxXvF6aG-buPHDR13zeynDZEDIubPF4PiVhZbM,7451
- ipex_llm/transformers/models/mllama.py,sha256=ogpLmmN_OwcFUyjYB-oDC-l3uw8urFvUEc5edkjWHAk,10939
+ ipex_llm/transformers/models/llama.py,sha256=NzpyQve_RC9ez1W-jWPLGZ80k_S1I5Rx5saAzCsDIoI,8558
+ ipex_llm/transformers/models/minicpm.py,sha256=eaPNVNrep0_xGoELhZd886ff0ceoKqB6cusdAhd52eE,10145
+ ipex_llm/transformers/models/minicpm3.py,sha256=11cYl8KM2hoIJNMAOZMxiwCu6dMhup9ric_OEn8-VrQ,9363
+ ipex_llm/transformers/models/minicpmv.py,sha256=PP05b5iTnrMpiseCn8iJcxKJDnfq7WqXp9Mrch0kKZ0,9876
+ ipex_llm/transformers/models/mistral.py,sha256=uVhkdXaq15v1P3QY0emVsA7SxUbAWChHEEXYN-drjpQ,7449
+ ipex_llm/transformers/models/mllama.py,sha256=ZyRq9DTKsvk1AlRbr-z6ngjS3Sr_7YuGZ6-Yr1MBBAM,10937
  ipex_llm/transformers/models/mpt.py,sha256=z02NwHogJZVh-Mk4sYoIzR90SFIKhoNN_-ifsD907TQ,9540
  ipex_llm/transformers/models/phi.py,sha256=E6qz4EEuHIVGvaPo-wtLC5lz3iyMqTbAE_cRlcjQRKI,6670
- ipex_llm/transformers/models/phi3.py,sha256=jkiadJ85ToHpymY5GOM6orWlnx6LKN8_-v1MUcfGWPg,15159
+ ipex_llm/transformers/models/phi3.py,sha256=Fo6PlZ24Gdm7eeeZOTMm1Bfh3U6P4rvq7-_2FHvp0vE,15503
  ipex_llm/transformers/models/phixtral.py,sha256=MDTMghcu7qAmZmRcUGqXXDXhSU3y_N59HRIXmlcjp5g,4890
- ipex_llm/transformers/models/qwen.py,sha256=XIJ_bLzediBURWU-OOS3H6WBIGXQue6jDdUHJsAabwY,19391
- ipex_llm/transformers/models/qwen2.py,sha256=b49HO4GSudwGJ3n6uHVno1oo3DgRt3jOjtQnLOB3cdY,25530
- ipex_llm/transformers/models/qwen2_moe.py,sha256=EA_OYxYAEgrvi7VpDW192AJXG9Fwe2aBtOAZPkOAJk4,19350
- ipex_llm/transformers/models/qwen2_vl.py,sha256=jIm4yZSd751BkRqgj3wR1QBkDIh-TMCLAMM8SZ8n6Qo,13419
+ ipex_llm/transformers/models/qwen.py,sha256=A3WiVCzA7NLkcjp4zhFkZvKZzZWZlg0WFuVV_556TAI,19543
+ ipex_llm/transformers/models/qwen2.py,sha256=JLaY9ZT7A22oO0G8K-nvjvKQDaIrKA5o-jEHvk_y3eI,25604
+ ipex_llm/transformers/models/qwen2_moe.py,sha256=a0gYo-ngf8SxaEnBdZUJDnPS6Mkn_poDd8xqhx50icI,19516
+ ipex_llm/transformers/models/qwen2_vl.py,sha256=NrhxlaPj7W-HUBmKc3CSTwZy1lkoZ9qDaxM4GvE0kHs,13583
  ipex_llm/transformers/models/qwen_vl.py,sha256=j7Nzzz2Qvynu9yrCXmoEfERjw43hXof5TbXIs7Ms-oY,17105
  ipex_llm/transformers/models/rwkv4.py,sha256=H4KMtxN0JA2ZTXnonHpsUUJ5xULemo-D1Jzl0ri_UY8,6123
  ipex_llm/transformers/models/rwkv5.py,sha256=OkRNj1pCAZg1z2Fw-I0DEnxLEdZyPeRSQ6msrkxLOCs,10710
  ipex_llm/transformers/models/sd.py,sha256=VvHV5u-0k2MgHu3NL9113hPj7DgfxqctuKzEEeNfRDU,5981
- ipex_llm/transformers/models/stablelm.py,sha256=RGQCYuQhYqtZ1j3RZkYi0_QvCRnUgUIPYxfBcLnElzg,6885
- ipex_llm/transformers/models/starcoder2.py,sha256=4P3mhRYf2Kreb1ESjrQGfy1puLMmZXgV35zf-Tksvao,6462
- ipex_llm/transformers/models/utils.py,sha256=Qbz7UkYSbsM5bodH2445O0-JF50Mu3UEwW0j2ZNxHSU,15997
- ipex_llm/transformers/models/yuan.py,sha256=1jRPebwAK2ENbyYokOmb4LSVo-szucWiygz9zTv-scs,7656
+ ipex_llm/transformers/models/stablelm.py,sha256=fj-XtOnR6kggnFUQTMPCOOzolkPztN06WAv8QW-XRnI,7054
+ ipex_llm/transformers/models/starcoder2.py,sha256=ONKvD7JCkRM0DI-R56x28QFBJ7CjD5hOZBQ_3WfOcNk,6626
+ ipex_llm/transformers/models/utils.py,sha256=ihbWS5kQK2KHDVPkMhgjik3nM8B2fWf-E-z4BWNUstk,15568
+ ipex_llm/transformers/models/yuan.py,sha256=JYAn_ZaSGK0NBJLEIxCACfAq084a66GFJkdd5NbpmMA,7732
  ipex_llm/transformers/npu_models/__init__.py,sha256=ulEUGLjaP48LCrVeury3UxLjXxKzRi0UpSG4bYu-7f8,585
  ipex_llm/transformers/npu_models/baichuan.py,sha256=fJtd7fBrttySghRUgfZTAdxLjsSNC-XL08HISsXigLE,4685
  ipex_llm/transformers/npu_models/baichuan_mp.py,sha256=tHhO-0v5z6IhxsfzAPYWXVbLrV_4z89DIb4JjE3207M,45026
@@ -244,11 +243,11 @@ ipex_llm/vllm/xpu/engine/__init__.py,sha256=pY_CpyuZd72fr6s32ejeKHKFW0K4vUU2rzZj
  ipex_llm/vllm/xpu/engine/engine.py,sha256=k4-D27WS_Gk3mA--w3HWAjPjb4Aiu043MVPi0ZoAUBc,5984
  ipex_llm/vllm/xpu/entrypoints/openai/api_server.py,sha256=GshTZFB8e4PWvqckfbmTOU6b0oLkNn7A-vzLuG9--j8,21544
  ipex_llm/vllm/xpu/entrypoints/openai/cli_args.py,sha256=2rENA2ucynMaIjiZBEh2ez1o5vR32GaP514t39CD7KM,8676
- ipex_llm-2.2.0b20250107.data/scripts/ipex-llm-init.bat,sha256=HPtCYuDYwEatq7dAwOvdfVcHYCpAVdbj75K1qh0vQek,2578
- ipex_llm-2.2.0b20250107.data/scripts/llm-chat.ps1,sha256=6qrs-hGVAV8IKh7Jx8nq_XrnZcjd7qGU5wndArM7Yag,2769
- ipex_llm-2.2.0b20250107.data/scripts/llm-cli.ps1,sha256=3qBtTLs_EjYDnM8YyCpJhzLnGCKTEGssu9UNqfkjVXs,3009
- ipex_llm-2.2.0b20250107.dist-info/METADATA,sha256=rPJCuVvUndZ0XZBTZzlQEPi1y_W0fpmQTEGmogyRzRw,12705
- ipex_llm-2.2.0b20250107.dist-info/WHEEL,sha256=6iYPr8vTHsyDK75jr9X0V3I9wPSVmtwr_8fdATBciGk,98
- ipex_llm-2.2.0b20250107.dist-info/entry_points.txt,sha256=TiUyBB2MRmfF3ko-pyAEzqeBCRnyhu27bNOAsWPp3e8,61
- ipex_llm-2.2.0b20250107.dist-info/top_level.txt,sha256=CGCMHM-SyqUabU4h8RqJ2KTYckQUO3LvIWwmUQ6Qbzw,9
- ipex_llm-2.2.0b20250107.dist-info/RECORD,,
+ ipex_llm-2.2.0b20250109.data/scripts/ipex-llm-init.bat,sha256=HPtCYuDYwEatq7dAwOvdfVcHYCpAVdbj75K1qh0vQek,2578
+ ipex_llm-2.2.0b20250109.data/scripts/llm-chat.ps1,sha256=6qrs-hGVAV8IKh7Jx8nq_XrnZcjd7qGU5wndArM7Yag,2769
+ ipex_llm-2.2.0b20250109.data/scripts/llm-cli.ps1,sha256=3qBtTLs_EjYDnM8YyCpJhzLnGCKTEGssu9UNqfkjVXs,3009
+ ipex_llm-2.2.0b20250109.dist-info/METADATA,sha256=gPslIWSw_X5E5ULhQa8rOHeRo_UeBDXCAyPjBSPB-nU,12705
+ ipex_llm-2.2.0b20250109.dist-info/WHEEL,sha256=6iYPr8vTHsyDK75jr9X0V3I9wPSVmtwr_8fdATBciGk,98
+ ipex_llm-2.2.0b20250109.dist-info/entry_points.txt,sha256=TiUyBB2MRmfF3ko-pyAEzqeBCRnyhu27bNOAsWPp3e8,61
+ ipex_llm-2.2.0b20250109.dist-info/top_level.txt,sha256=CGCMHM-SyqUabU4h8RqJ2KTYckQUO3LvIWwmUQ6Qbzw,9
+ ipex_llm-2.2.0b20250109.dist-info/RECORD,,
@@ -1,441 +0,0 @@
- #
- # Copyright 2016 The BigDL Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # This file is adapted from
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
- #
-
- import torch
- from typing import Optional, Tuple, Union
- from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
-     apply_rotary_pos_emb, append_kv_cache, apply_ipex_rotate_every_two
- from transformers.utils.import_utils import is_torch_fx_proxy
- from transformers.modeling_outputs import BaseModelOutputWithPast
- from transformers.models.gptj.modeling_gptj import GPTJModel
- from ipex_llm.utils.common import invalidInputError
-
- import os
-
- KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-
-
- def _get_embed_positions(self, position_ids):
-     embed_positions = self.embed_positions
-     if embed_positions.device != position_ids.device:
-         embed_positions = embed_positions.to(position_ids.device)
-         self.embed_positions = embed_positions
-     return embed_positions.repeat(position_ids.shape[0], 1, 1)
-
-
- def _attn(
-     self,
-     query,
-     key,
-     value,
-     attention_mask=None,
-     head_mask=None,
- ):
-     # compute causal mask from causal mask buffer
-     query_length, key_length = query.size(-2), key.size(-2)
-     causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
-
-     # Keep the attention weights computation in fp32 to avoid overflow issues
-     query = query.to(torch.float32)
-     key = key.to(torch.float32)
-
-     attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-     mask_value = torch.finfo(attn_weights.dtype).min
-     # Need to be a tensor, otherwise we get error:
-     # `RuntimeError: expected scalar type float but found double`.
-     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-     mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-     attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-     attn_weights = attn_weights / self.scale_attn
-
-     if attention_mask is not None:
-         # Apply the attention mask
-         attn_weights = attn_weights + attention_mask
-
-     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-     attn_weights = attn_weights.to(value.dtype)
-     attn_weights = self.attn_dropout(attn_weights)
-
-     # Mask heads if we want to
-     if head_mask is not None:
-         attn_weights = attn_weights * head_mask
-
-     attn_output = torch.matmul(attn_weights, value)
-
-     return attn_output, attn_weights
-
-
- def gptj_attention_forward(
-     self,
-     hidden_states: torch.FloatTensor,
-     layer_past: Optional[Tuple[torch.Tensor]] = None,
-     attention_mask: Optional[torch.FloatTensor] = None,
-     position_ids: Optional[torch.LongTensor] = None,
-     head_mask: Optional[torch.FloatTensor] = None,
-     use_cache: Optional[bool] = False,
-     rotary_emb: Optional[Tuple]=None,
-     output_attentions: Optional[bool] = False,
- ) -> Union[
-     Tuple[torch.Tensor, Tuple[torch.Tensor]],
-     Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
- ]:
-     query = self.q_proj(hidden_states)
-     key = self.k_proj(hidden_states)
-     value = self.v_proj(hidden_states)
-
-     query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
-     key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
-     value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
-
-     sin, cos = rotary_emb
-     use_fuse_rope = hidden_states.device.type == "xpu" and not self.training
-
-     if self.rotary_dim is not None:
-         k_rot = key[:, :, :, : self.rotary_dim]
-         q_rot = query[:, :, :, : self.rotary_dim]
-
-         if use_fuse_rope:
-             apply_ipex_rotate_every_two(q_rot, k_rot, cos, sin)
-         else:
-             k_pass = key[:, :, :, self.rotary_dim:]
-             q_pass = query[:, :, :, self.rotary_dim:]
-             q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids, "gptj")
-             key = torch.cat([k_rot, k_pass], dim=-1)
-             query = torch.cat([q_rot, q_pass], dim=-1)
-     else:
-         if use_fuse_rope:
-             apply_ipex_rotate_every_two(query, key, cos, sin)
-         else:
-             query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")
-
-     batch_size, q_len, _ = hidden_states.size()
-
-     key = key.permute(0, 2, 1, 3).contiguous()
-     query = query.permute(0, 2, 1, 3).contiguous()
-
-     kv_seq_len = key.size(-2)
-     device = hidden_states.device
-
-     if layer_past is not None:
-         kv_seq_len += layer_past[0].size(2)
-
-     if layer_past is not None:
-         cache_k = layer_past[0]
-         cache_v = layer_past[1]
-         past_length = cache_k.size(2)
-         if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
-             new_cache_k, new_cache_v = extend_kv_cache(batch_size,
-                                                        self.num_attention_heads,
-                                                        self.head_dim,
-                                                        past_length,
-                                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                        dtype=cache_v.dtype,
-                                                        device=device)
-             new_cache_k[:] = cache_k
-             new_cache_v[:] = cache_v
-             cache_k = new_cache_k
-             cache_v = new_cache_v
-         key, value = append_kv_cache(cache_k, cache_v, key, value)
-
-     elif use_cache:
-         key_cache, value_cache = init_kv_cache(batch_size,
-                                                self.num_attention_heads,
-                                                self.head_dim,
-                                                kv_seq_len,
-                                                kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                dtype=value.dtype,
-                                                device=device)
-         key_cache[:] = key
-         value_cache[:] = value
-         key = key_cache
-         value = value_cache
-
-     if use_cache is True:
-         present = (key, value)
-     else:
-         present = None
-
-     # compute self-attention: V x Softmax(QK^T)
-     attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
-     attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
-     attn_output = self.out_proj(attn_output)
-     attn_output = self.resid_dropout(attn_output)
-
-     outputs = (attn_output, present)
-     if output_attentions:
-         outputs += (attn_weights,)
-
-     return outputs # a, present, (attentions)
-
-
- def gptj_block_forward(
-     self,
-     hidden_states: Optional[torch.FloatTensor],
-     layer_past: Optional[Tuple[torch.Tensor]] = None,
-     attention_mask: Optional[torch.FloatTensor] = None,
-     position_ids: Optional[torch.LongTensor] = None,
-     head_mask: Optional[torch.FloatTensor] = None,
-     use_cache: Optional[bool] = False,
-     rotary_emb: Optional[Tuple]=None,
-     output_attentions: Optional[bool] = False,
- ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-     residual = hidden_states
-     hidden_states = self.ln_1(hidden_states)
-     attn_outputs = self.attn(
-         hidden_states=hidden_states,
-         layer_past=layer_past,
-         attention_mask=attention_mask,
-         position_ids=position_ids,
-         head_mask=head_mask,
-         use_cache=use_cache,
-         rotary_emb=rotary_emb,
-         output_attentions=output_attentions,
-     )
-     attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
-     outputs = attn_outputs[1:]
-
-     feed_forward_hidden_states = self.mlp(hidden_states)
-     hidden_states = attn_output + feed_forward_hidden_states + residual
-
-     if use_cache:
-         outputs = (hidden_states,) + outputs
-     else:
-         outputs = (hidden_states,) + outputs[1:]
-
-     return outputs # hidden_states, present, (attentions)
-
-
- def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
-     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
-     sinusoid_inp = torch.einsum("i , j -> i j",
-                                 torch.arange(num_pos, dtype=torch.float), inv_freq).float()
-     return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
-
-
- old_init = GPTJModel.__init__
-
-
- def gptj_model_new_init(self, config):
-     old_init(self, config)
-     embed_dim = config.hidden_size
-     rotary_dim = config.rotary_dim
-     pos_embd_dim = rotary_dim or embed_dim
-     max_positions = config.max_position_embeddings
-     self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
-
-
- def get_new_embed_positions(position_ids, prev_embed_positions):
-     embed_positions = prev_embed_positions
-     if embed_positions.device != position_ids.device:
-         embed_positions = embed_positions.to(position_ids.device)
-         prev_embed_positions = embed_positions
-     return embed_positions.repeat(position_ids.shape[0], 1, 1), prev_embed_positions
-
-
- def gptj_model_forward(
-     self,
-     input_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-     attention_mask: Optional[torch.FloatTensor] = None,
-     token_type_ids: Optional[torch.LongTensor] = None,
-     position_ids: Optional[torch.LongTensor] = None,
-     head_mask: Optional[torch.FloatTensor] = None,
-     inputs_embeds: Optional[torch.FloatTensor] = None,
-     use_cache: Optional[bool] = None,
-     output_attentions: Optional[bool] = None,
-     output_hidden_states: Optional[bool] = None,
-     return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
-     output_attentions = output_attentions if output_attentions is not None \
-         else self.config.output_attentions
-     output_hidden_states = (
-         output_hidden_states if output_hidden_states is not None
-         else self.config.output_hidden_states
-     )
-     use_cache = use_cache if use_cache is not None else self.config.use_cache
-     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-     if input_ids is not None and inputs_embeds is not None:
-         invalidInputError(False,
-                           "You cannot specify both input_ids and inputs_embeds at the same time")
-     elif input_ids is not None:
-         self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
-         input_shape = input_ids.size()
-         input_ids = input_ids.view(-1, input_shape[-1])
-         batch_size = input_ids.shape[0]
-     elif inputs_embeds is not None:
-         input_shape = inputs_embeds.size()[:-1]
-         batch_size = inputs_embeds.shape[0]
-     else:
-         invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
-
-     device = input_ids.device if input_ids is not None else inputs_embeds.device
-
-     if token_type_ids is not None:
-         token_type_ids = token_type_ids.view(-1, input_shape[-1])
-
-     if past_key_values is None:
-         past_length = 0
-         past_key_values = tuple([None] * len(self.h))
-     else:
-         past_length = past_key_values[0][0].size(-2)
-
-     if position_ids is None:
-         position_ids = torch.arange(past_length, input_shape[-1] + past_length,
-                                     dtype=torch.long, device=device)
-         position_ids = position_ids.unsqueeze(0)
-
-     # Attention mask.
-     if attention_mask is not None:
-         if batch_size <= 0:
-             invalidInputError(False, "batch_size has to be defined and > 0")
-         attention_mask = attention_mask.view(batch_size, -1)
-         # We create a 3D attention mask from a 2D tensor mask.
-         # Sizes are [batch_size, 1, 1, to_seq_length]
-         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-         # this attention mask is more simple than the triangular masking of causal attention
-         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-         attention_mask = attention_mask[:, None, None, :]
-
-         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-         # masked positions, this operation will create a tensor which is 0.0 for
-         # positions we want to attend and the dtype's smallest value for masked positions.
-         # Since we are adding it to the raw scores before the softmax, this is
-         # effectively the same as removing these entirely.
-         attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
-         attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
-     # Prepare head mask if needed
-     # 1.0 in head_mask indicate we keep the head
-     # attention_probs has shape bsz x num_attention_heads x N x N
-     # head_mask has shape n_layer x batch x num_attention_heads x N x N
-     head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-     if inputs_embeds is None:
-         inputs_embeds = self.wte(input_ids)
-
-     hidden_states = inputs_embeds
-
-     if token_type_ids is not None:
-         token_type_embeds = self.wte(token_type_ids)
-         hidden_states = hidden_states + token_type_embeds
-
-     hidden_states = self.drop(hidden_states)
-
-     output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-
-     if self.gradient_checkpointing and self.training:
-         if use_cache:
-             logger.warning_once(
-                 "`use_cache=True` is incompatible with gradient checkpointing."
-                 "Setting `use_cache=False`..."
-             )
-             use_cache = False
-
-     presents = () if use_cache else None
-     all_self_attentions = () if output_attentions else None
-     all_hidden_states = () if output_hidden_states else None
-
-     # Repeat cos sin here, call only once for each token.
-     # If put this to attension forward, it will generate too many times.
-     if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
-         # The logic to conditionally copy to GPU could not be traced, so we do this
-         # every time in the torch.fx case
-         embed_positions = get_embed_positions(self.embed_positions, position_ids)
-     else:
-         embed_positions, self.embed_positions = get_new_embed_positions(position_ids,
-                                                                         self.embed_positions)
-
-     repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
-     sincos = torch.gather(embed_positions, 1, repeated_position_ids)
-     sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
-     sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
-     cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
-
-     for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-         # Model parallel
-         if self.model_parallel:
-             torch.cuda.set_device(hidden_states.device)
-             # Ensure layer_past is on same device as hidden_states (might not be correct)
-             if layer_past is not None:
-                 layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
-             # Ensure that attention_mask is always on the same device as hidden_states
-             if attention_mask is not None:
-                 attention_mask = attention_mask.to(hidden_states.device)
-             if isinstance(head_mask, torch.Tensor):
-                 head_mask = head_mask.to(hidden_states.device)
-         if output_hidden_states:
-             all_hidden_states = all_hidden_states + (hidden_states,)
-
-         if self.gradient_checkpointing and self.training:
-             outputs = self._gradient_checkpointing_func(
-                 block.__call__,
-                 hidden_states,
-                 None,
-                 attention_mask,
-                 position_ids,
-                 head_mask[i],
-                 use_cache,
-                 output_attentions,
-             )
-         else:
-             outputs = block(
-                 hidden_states=hidden_states,
-                 layer_past=layer_past,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 head_mask=head_mask[i],
-                 use_cache=use_cache,
-                 rotary_emb=(sin, cos),
-                 output_attentions=output_attentions,
-             )
-
-         hidden_states = outputs[0]
-         if use_cache is True:
-             presents = presents + (outputs[1],)
-
-         if output_attentions:
-             all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-
-         # Model Parallel: If it's the last layer for that device, put things on the next device
-         if self.model_parallel:
-             for k, v in self.device_map.items():
-                 if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                     hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
-     hidden_states = self.ln_f(hidden_states)
-
-     hidden_states = hidden_states.view(output_shape)
-     # Add last hidden state
-     if output_hidden_states:
-         all_hidden_states = all_hidden_states + (hidden_states,)
-
-     if not return_dict:
-         return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
-                      if v is not None)
-
-     return BaseModelOutputWithPast(
-         last_hidden_state=hidden_states,
-         past_key_values=presents,
-         hidden_states=all_hidden_states,
-         attentions=all_self_attentions,
-     )