optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic; see the package's registry page for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,435 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from pathlib import Path
27
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
28
+
29
+ import rebel
30
+ import torch
31
+ from rebel.compile_context import CompileContext
32
+ from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
33
+ from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
34
+ from transformers.modeling_utils import no_init_weights
35
+
36
+ from ....configuration_utils import RBLNCompileConfig
37
+ from ....modeling import RBLNModel
38
+ from ....utils.runtime_utils import RBLNPytorchRuntime
39
+ from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
40
+ from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
41
+ from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
42
+
43
+
44
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed only by static type checkers: these names appear in string
    # annotations (e.g. "PreTrainedModel", "AutoTokenizer") later in the module.
    from transformers import (
        AutoFeatureExtractor,
        AutoProcessor,
        AutoTokenizer,
        PretrainedConfig,
        PreTrainedModel,
    )
48
+
49
+
50
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
    """Encoder runtime for the RBLN-compiled Time Series Transformer.

    Input preprocessing runs on the host through the original HuggingFace
    ``TimeSeriesTransformerModel.create_network_inputs``. Only the first
    ``context_length`` steps of the resulting network inputs are sent to the
    compiled encoder runtime.
    """

    mandatory_members = ["main_input_name"]

    def __init__(
        self,
        runtime: rebel.Runtime,
        model: TimeSeriesTransformerModel,
        **kwargs: Any,
    ) -> None:
        super().__init__(runtime, **kwargs)
        # Host-side HuggingFace model, used here for its preprocessing helper
        # (`create_network_inputs`) and its `config.context_length`.
        self._origin_model = model

    def forward(
        self,
        past_values: torch.Tensor,
        past_time_features: torch.Tensor,
        static_categorical_features: Optional[torch.Tensor] = None,
        static_real_features: Optional[torch.Tensor] = None,
        past_observed_mask: Optional[torch.Tensor] = None,
        future_values: Optional[torch.Tensor] = None,
        future_time_features: Optional[torch.Tensor] = None,
    ):
        """Preprocess the inputs and run the compiled encoder.

        Returns:
            Seq2SeqTSModelOutput: only ``loc``, ``scale`` and ``static_features``
            are populated. The value returned by the encoder runtime is not
            forwarded.
        """
        hf_model = self._origin_model
        network_inputs, loc, scale, static_feat = hf_model.create_network_inputs(
            past_values=past_values,
            past_time_features=past_time_features,
            past_observed_mask=past_observed_mask,
            static_categorical_features=static_categorical_features,
            static_real_features=static_real_features,
            future_values=future_values,
            future_time_features=future_time_features,
        )

        # The encoder graph is compiled for a sequence length of `context_length`.
        context_length = hf_model.config.context_length
        encoder_inputs = network_inputs[:, :context_length, ...]

        # The runtime's result is deliberately discarded: the encoder's attention
        # key/value caches are updated in device DRAM in place.
        super().forward(inputs_embeds=encoder_inputs)

        return Seq2SeqTSModelOutput(loc=loc, scale=scale, static_features=static_feat)
92
+
93
+
94
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
    """Decoder runtime for the RBLN-compiled Time Series Transformer.

    Each call runs one step of the compiled decoder graph. The graph's
    ``block_tables`` input is always supplied as a single zero entry.
    """

    mandatory_members = ["main_input_name"]

    def forward(
        self,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """Run one decoder step.

        Returns:
            RBLNSeq2SeqTSDecoderOutput: every runtime output except the last is
            exposed as ``params``. The last output is exposed as
            ``last_hidden_states``.
        """
        # Matches the (1, 1) int16 `block_tables` input declared for the decoder graph.
        zero_block_table = torch.zeros(1, 1, dtype=torch.int16)
        runtime_outputs = super().forward(inputs_embeds, attention_mask, cache_position, zero_block_table)

        head_params = runtime_outputs[:-1]
        hidden = runtime_outputs[-1]
        return RBLNSeq2SeqTSDecoderOutput(params=head_params, last_hidden_states=hidden)
110
+
111
+
112
+ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
113
+ """
114
+ The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
115
+ This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
116
+
117
+ A class to convert and run pre-trained transformer-based `TimeSeriesTransformerForPrediction` models on RBLN devices.
118
+ It implements the methods to convert a pre-trained transformers `TimeSeriesTransformerForPrediction` model into a RBLN transformer model by:
119
+
120
+ - transferring the checkpoint weights of the original into an optimized RBLN graph,
121
+ - compiling the resulting graph using the RBLN Compiler.
122
+ """
123
+
124
+ auto_model_class = None
125
+ main_input_name = "inputs_embeds"
126
+
127
def __post_init__(self, **kwargs):
    """Finish setting up the model once the base class has loaded the runtimes.

    This method does three things:
    - Caches the compile-time batch size, decoder length and sample count.
    - Rebuilds a host-side HuggingFace ``TimeSeriesTransformerForPrediction``
      without weight initialization, then restores its embedder from
      ``torch_artifacts.pth``.
    - Wraps the two compiled runtimes (index 0 = encoder, index 1 = decoder).
    """
    super().__post_init__(**kwargs)

    compile_settings = self.rbln_config
    self.batch_size = compile_settings.batch_size
    self.dec_max_seq_len = compile_settings.dec_max_seq_len
    self.num_parallel_samples = compile_settings.num_parallel_samples

    # Skip random weight init; only the embedder weights are restored below.
    with no_init_weights():
        self._origin_model = TimeSeriesTransformerForPrediction._from_config(self.config)

    artifact_path = self.model_save_dir / self.subfolder / "torch_artifacts.pth"
    # NOTE(review): weights_only=False unpickles arbitrary objects, so loading
    # an untrusted model directory can execute code. The file written by
    # save_torch_artifacts holds only a state_dict under "embedder", so
    # weights_only=True should be sufficient; confirm before switching.
    saved_artifacts = torch.load(artifact_path, weights_only=False)
    self._origin_model.model.embedder.load_state_dict(saved_artifacts["embedder"])

    self.encoder = RBLNRuntimeEncoder(
        runtime=self.model[0],
        main_input_name="inputs_embeds",
        model=self._origin_model.model,
    )
    self.decoder = RBLNRuntimeDecoder(
        runtime=self.model[1],
        main_input_name="inputs_embeds",
    )
146
+
147
def __getattr__(self, __name: str) -> Any:
    """Fall back to ``TimeSeriesTransformerForPrediction`` for unknown attributes.

    Python calls this only when normal attribute lookup fails.

    - Callables on the HuggingFace class that declare a ``self`` parameter are
      returned as wrappers that call them with this RBLN model as ``self``.
    - Any other attribute found on the HuggingFace class is returned unchanged.
      Previously such attributes fell off the end of the function and became
      ``None``.

    Raises:
        AttributeError: if ``TimeSeriesTransformerForPrediction`` has no such attribute.
    """

    def redirect(func):
        return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

    val = getattr(TimeSeriesTransformerForPrediction, __name)
    if callable(val):
        try:
            takes_self = "self" in inspect.signature(val).parameters
        except (TypeError, ValueError):
            # Some callables (e.g. certain builtins) have no introspectable signature.
            takes_self = False
        if takes_self:
            return redirect(val)
    return val
154
+
155
@classmethod
def _wrap_model_if_needed(
    cls, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
):
    """Wrap ``model`` into the compile-ready ``TimeSeriesTransformersWrapper``.

    The wrapper's ``encoder`` and ``decoder`` attributes are the modules that
    ``get_compiled_model`` compiles.

    Args:
        model: the HuggingFace model to wrap (presumably a
            ``TimeSeriesTransformerForPrediction``).
        rbln_config: supplies ``num_parallel_samples`` to the wrapper.

    Returns:
        TimeSeriesTransformersWrapper: the wrapped model.
    """
    # First parameter renamed from `self` to `cls`: this is a classmethod.
    return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
160
+
161
@classmethod
@torch.inference_mode()
def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
    """Compile the encoder and decoder graphs of ``model``.

    Both graphs are compiled under one ``CompileContext`` with weight sharing
    disabled. Every dummy input whose name contains ``"key_value_states"`` is
    marked as a static address. The encoder's cross-attention KV dummies are
    passed as ``static_tensors`` when the decoder's dummy inputs are built,
    presumably so that both graphs bind the same buffers.

    Returns:
        dict: ``{"encoder": <compiled encoder>, "decoder": <compiled decoder>}``.
    """
    wrapper = cls._wrap_model_if_needed(model, rbln_config)
    encoder_cfg, decoder_cfg = rbln_config.compile_cfgs[0], rbln_config.compile_cfgs[1]

    compile_ctx = CompileContext(use_weight_sharing=False)

    encoder_dummies = encoder_cfg.get_dummy_inputs(fill=0)

    # Encoder side: collect the cross-attention KV tensors and pin their addresses.
    shared_kv_tensors = {}
    for (input_name, _, _), dummy in zip(encoder_cfg.input_info, encoder_dummies):
        if "key_value_states" not in input_name:
            continue
        shared_kv_tensors[input_name] = dummy
        compile_ctx.mark_static_address(dummy)

    decoder_dummies = decoder_cfg.get_dummy_inputs(fill=0, static_tensors=shared_kv_tensors)

    # Decoder side: pin every KV-state input (cross_key_value_states and self_key_value_states_*).
    for (input_name, _, _), dummy in zip(decoder_cfg.input_info, decoder_dummies):
        if "key_value_states" in input_name:
            compile_ctx.mark_static_address(dummy)

    # The decoder is compiled before the encoder.
    decoder_graph = cls.compile(
        wrapper.decoder,
        decoder_cfg,
        create_runtimes=rbln_config.create_runtimes,
        device=rbln_config.device,
        example_inputs=decoder_dummies,
        compile_context=compile_ctx,
    )
    encoder_graph = cls.compile(
        wrapper.encoder,
        encoder_cfg,
        create_runtimes=rbln_config.create_runtimes,
        device=rbln_config.device,
        example_inputs=encoder_dummies,
        compile_context=compile_ctx,
    )

    return {"encoder": encoder_graph, "decoder": decoder_graph}
205
+
206
@classmethod
def save_torch_artifacts(
    cls,
    model: "PreTrainedModel",
    save_dir_path: Path,
    subfolder: str,
    rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
):
    """Save the weights that must stay on the host (CPU) rather than the RBLN device.

    Only ``model.model.embedder``'s ``state_dict`` is stored, under the
    ``"embedder"`` key of ``<save_dir_path>/<subfolder>/torch_artifacts.pth``.
    ``__post_init__`` loads it back into the host-side model.
    """
    artifact_file = save_dir_path / subfolder / "torch_artifacts.pth"
    embedder_weights = model.model.embedder.state_dict()
    torch.save({"embedder": embedder_weights}, artifact_file)
220
+
221
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
    model: Optional["PreTrainedModel"] = None,
    model_config: Optional["PretrainedConfig"] = None,
    rbln_config: Optional[RBLNTimeSeriesTransformerForPredictionConfig] = None,
) -> RBLNTimeSeriesTransformerForPredictionConfig:
    """Fill in derived settings and register the encoder/decoder compile configs.

    - ``num_parallel_samples`` falls back to ``model_config.num_parallel_samples``.
    - ``dec_max_seq_len`` defaults to ``prediction_length`` rounded up to a
      multiple of 64.
    - Two compile configs, ``"encoder"`` and ``"decoder"``, are set. They
      describe the static input shapes of each graph.

    Returns:
        The same ``rbln_config`` object, updated in place.
    """
    rbln_config.num_parallel_samples = rbln_config.num_parallel_samples or model_config.num_parallel_samples

    if rbln_config.dec_max_seq_len is None:
        predict_length = model_config.prediction_length
        # Round up to the next multiple of 64.
        rbln_config.dec_max_seq_len = (
            predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)
        )

    batch_size = rbln_config.batch_size
    num_samples = rbln_config.num_parallel_samples
    dec_heads = model_config.decoder_attention_heads
    # Per-head width of the decoder attentions (self and cross). This must be
    # derived from the decoder head count. The self-attention KV shape
    # previously used `encoder_attention_heads`, which breaks configs whose
    # encoder and decoder head counts differ.
    dec_head_dim = model_config.d_model // dec_heads
    kv_entries = model_config.decoder_layers * 2  # key + value per decoder layer

    def cross_kv_input():
        # A fresh entry each call, so the encoder and decoder configs never share list objects.
        return (
            "cross_key_value_states",
            [kv_entries, batch_size, dec_heads, model_config.context_length, dec_head_dim],
            "float32",
        )

    # model input info
    enc_input_info = [
        ("inputs_embeds", [batch_size, model_config.context_length, model_config.feature_size], "float32"),
        cross_kv_input(),
    ]

    dec_input_info = [
        ("inputs_embeds", [batch_size * num_samples, 1, model_config.feature_size], "float32"),
        ("attention_mask", [1, rbln_config.dec_max_seq_len], "float32"),
        ("cache_position", [], "int32"),
        ("block_tables", [1, 1], "int16"),
        cross_kv_input(),
    ]
    dec_input_info.extend(
        (
            f"self_key_value_states_{i}",
            [1, dec_heads * num_samples * batch_size, rbln_config.dec_max_seq_len, dec_head_dim],
            "float32",
        )
        for i in range(kv_entries)
    )

    enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
    dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

    rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
    return rbln_config
308
+
309
@classmethod
def _create_runtimes(
    cls,
    compiled_models: List[rebel.RBLNCompiledModel],
    rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
) -> List[rebel.Runtime]:
    """Build one PyTorch-facing ``rebel.Runtime`` per compiled graph.

    ``compiled_models[0]`` runs on ``device_map["encoder"]`` and
    ``compiled_models[1]`` runs on ``device_map["decoder"]``. Every runtime
    shares the config's profiler and timeout settings.
    """
    graph_names = ("encoder", "decoder")
    missing = [name for name in graph_names if name not in rbln_config.device_map]
    if missing:
        cls._raise_missing_compiled_file_error(["encoder", "decoder"])

    return [
        rebel.Runtime(
            compiled_models[index],
            tensor_type="pt",
            device=rbln_config.device_map[name],
            activate_profiler=rbln_config.activate_profiler,
            timeout=rbln_config.timeout,
        )
        for index, name in enumerate(graph_names)
    ]
334
+
335
def validate_batch_size(self, **kwargs):
    """Check that every supplied input matches the compiled batch size.

    Args:
        **kwargs: Named inputs to check. ``None`` values are skipped; any
            other value must expose ``.shape``, whose first entry is compared
            against ``self.batch_size``.

    Raises:
        RuntimeError: For the first input (in keyword order) whose leading
            dimension differs from ``self.batch_size``.
    """
    for name, tensor in kwargs.items():
        if tensor is None:
            continue
        actual = tensor.shape[0]
        expected = self.batch_size
        if not actual != expected:
            continue
        raise RuntimeError(
            f"Batch size mismatch in '{name}': Expected {expected}, but got {actual}. \n"
            f"Tensor shape: {tensor.shape} \n\n"
            "Note: `batch_size` is set at compile time. \n"
            "To change it, pass `export=True` along with `rbln_batch_size` when calling `from_pretrained()` to trigger recompilation."
        )
344
+
345
@torch.no_grad()
def generate(
    self,
    past_values: torch.Tensor,
    past_time_features: torch.Tensor,
    future_time_features: torch.Tensor,
    past_observed_mask: Optional[torch.Tensor] = None,
    static_categorical_features: Optional[torch.Tensor] = None,
    static_real_features: Optional[torch.Tensor] = None,
    **kwargs,
) -> SampleTSPredictionOutput:
    """
    Sample future trajectories with the RBLN-compiled Time Series Transformer.

    The encoder runtime is run once. Each series is then repeated
    ``self.num_parallel_samples`` times, and the decoder runtime is stepped
    ``config.prediction_length`` times. Each step feeds only the newest
    position and draws one sample from the output distribution.

    Args:
        past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)):
            Past values of the time series, used as context for the forecast.
        past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)):
            Time features aligned with ``past_values``.
        future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)):
            Time features for the prediction window.
        past_observed_mask (torch.BoolTensor, optional): Marks which entries of
            ``past_values`` were observed versus missing.
        static_categorical_features (torch.LongTensor, optional): Static
            categorical features, forwarded to the encoder.
        static_real_features (torch.FloatTensor, optional): Static real-valued
            features, forwarded to the encoder.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        SampleTSPredictionOutput whose ``sequences`` is reshaped to
        ``(-1, num_parallel_samples, prediction_length) + target_shape``.

    Raises:
        RuntimeError: From ``validate_batch_size`` when a named tensor argument's
            leading dimension differs from the compiled batch size.
    """
    # Batch-check exactly the named tensor arguments, in signature order.
    # Tensors passed through **kwargs are not checked.
    named_inputs = {
        "past_values": past_values,
        "past_time_features": past_time_features,
        "future_time_features": future_time_features,
        "past_observed_mask": past_observed_mask,
        "static_categorical_features": static_categorical_features,
        "static_real_features": static_real_features,
    }
    self.validate_batch_size(
        **{name: value for name, value in named_inputs.items() if isinstance(value, torch.Tensor)}
    )

    enc_out = self.encoder(
        static_categorical_features=static_categorical_features,
        static_real_features=static_real_features,
        past_time_features=past_time_features,
        past_values=past_values,
        past_observed_mask=past_observed_mask,
        future_time_features=future_time_features,
    )

    n_samples = self.num_parallel_samples
    # Repeat every series n_samples times along dim 0 so each copy follows
    # its own sampled trajectory.
    loc = enc_out.loc.repeat_interleave(repeats=n_samples, dim=0)
    scale = enc_out.scale.repeat_interleave(repeats=n_samples, dim=0)

    # History kept in normalized space; each new sample is appended per step.
    history = (past_values.repeat_interleave(repeats=n_samples, dim=0) - loc) / scale

    horizon_len = future_time_features.shape[1]
    static_per_step = enc_out.static_features.unsqueeze(1).expand(-1, horizon_len, -1)
    step_features = torch.cat((static_per_step, future_time_features), dim=-1).repeat_interleave(
        repeats=n_samples, dim=0
    )

    origin = self._origin_model
    attn_mask = torch.zeros(1, self.dec_max_seq_len)
    samples = []
    for step in range(self.config.prediction_length):
        lagged = origin.model.get_lagged_subsequences(
            sequence=history,
            subsequences_length=step + 1,
            shift=1,
        )
        lagged = lagged.reshape(lagged.shape[0], lagged.shape[1], -1)
        decoder_input = torch.cat((lagged, step_features[:, : step + 1]), dim=-1)

        # Unmask the current position and feed only the newest time step.
        # Earlier steps are presumably held in the decoder's compiled KV
        # cache, addressed via cache_position (TODO confirm).
        attn_mask[:, step] = 1
        decoder_out = self.decoder(
            inputs_embeds=decoder_input[:, -1:].contiguous(),
            attention_mask=attn_mask,
            cache_position=torch.tensor(step, dtype=torch.int32),
        )

        distr = origin.output_distribution(decoder_out.params, loc=loc, scale=scale)
        sample = distr.sample()

        history = torch.cat((history, (sample - loc) / scale), dim=1)
        samples.append(sample)

    sequences = torch.cat(samples, dim=1).reshape(
        (-1, n_samples, self.config.prediction_length) + origin.target_shape,
    )
    return SampleTSPredictionOutput(sequences=sequences)