crfm-helm 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +134 -31
  2. crfm_helm-0.5.0.dist-info/RECORD +642 -0
  3. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapter_spec.py +31 -3
  5. helm/benchmark/adaptation/adapters/adapter.py +2 -2
  6. helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
  7. helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
  8. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
  9. helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
  10. helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
  11. helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
  12. helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
  13. helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
  14. helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
  15. helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
  16. helm/benchmark/adaptation/common_adapter_specs.py +376 -0
  17. helm/benchmark/adaptation/request_state.py +6 -1
  18. helm/benchmark/adaptation/scenario_state.py +6 -2
  19. helm/benchmark/annotation/annotator.py +43 -0
  20. helm/benchmark/annotation/annotator_factory.py +61 -0
  21. helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
  22. helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
  23. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
  24. helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
  25. helm/benchmark/annotation/test_annotator_factory.py +26 -0
  26. helm/benchmark/annotation/test_dummy_annotator.py +44 -0
  27. helm/benchmark/annotation_executor.py +124 -0
  28. helm/benchmark/augmentations/data_augmenter.py +0 -2
  29. helm/benchmark/augmentations/gender_perturbation.py +1 -1
  30. helm/benchmark/augmentations/perturbation.py +8 -2
  31. helm/benchmark/augmentations/perturbation_description.py +1 -1
  32. helm/benchmark/augmentations/suffix_perturbation.py +29 -0
  33. helm/benchmark/augmentations/test_perturbation.py +11 -7
  34. helm/benchmark/augmentations/translate_perturbation.py +30 -0
  35. helm/benchmark/config_registry.py +7 -1
  36. helm/benchmark/executor.py +46 -16
  37. helm/benchmark/huggingface_registration.py +20 -7
  38. helm/benchmark/metrics/basic_metrics.py +169 -664
  39. helm/benchmark/metrics/bbq_metrics.py +3 -4
  40. helm/benchmark/metrics/bias_metrics.py +6 -6
  41. helm/benchmark/metrics/classification_metrics.py +11 -8
  42. helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
  43. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  44. helm/benchmark/metrics/code_metrics_helper.py +0 -2
  45. helm/benchmark/metrics/common_metric_specs.py +167 -0
  46. helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
  47. helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
  48. helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
  49. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
  50. helm/benchmark/metrics/disinformation_metrics.py +4 -110
  51. helm/benchmark/metrics/dry_run_metrics.py +2 -2
  52. helm/benchmark/metrics/efficiency_metrics.py +206 -0
  53. helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
  54. helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
  55. helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
  56. helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
  57. helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
  58. helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
  59. helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
  60. helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
  61. helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
  62. helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
  63. helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
  64. helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
  65. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
  66. helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
  67. helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
  68. helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
  69. helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
  70. helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
  71. helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
  72. helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
  73. helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
  74. helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
  75. helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
  76. helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
  77. helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
  78. helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
  79. helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
  80. helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
  81. helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
  82. helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
  83. helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
  84. helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
  85. helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
  86. helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
  87. helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
  88. helm/benchmark/metrics/language_modeling_metrics.py +99 -0
  89. helm/benchmark/metrics/machine_translation_metrics.py +89 -0
  90. helm/benchmark/metrics/metric.py +93 -172
  91. helm/benchmark/metrics/metric_name.py +0 -1
  92. helm/benchmark/metrics/metric_service.py +16 -0
  93. helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
  94. helm/benchmark/metrics/ranking_metrics.py +2 -2
  95. helm/benchmark/metrics/reference_metric.py +148 -0
  96. helm/benchmark/metrics/summac/model_summac.py +0 -2
  97. helm/benchmark/metrics/summarization_metrics.py +2 -2
  98. helm/benchmark/metrics/test_classification_metrics.py +8 -5
  99. helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
  100. helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
  101. helm/benchmark/metrics/test_metric.py +2 -2
  102. helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
  103. helm/benchmark/metrics/toxicity_metrics.py +1 -1
  104. helm/benchmark/metrics/toxicity_utils.py +23 -0
  105. helm/benchmark/metrics/unitxt_metrics.py +81 -0
  106. helm/benchmark/metrics/vision_language/__init__.py +0 -0
  107. helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
  108. helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
  109. helm/benchmark/metrics/vision_language/image_utils.py +100 -0
  110. helm/benchmark/model_deployment_registry.py +74 -0
  111. helm/benchmark/model_metadata_registry.py +36 -0
  112. helm/benchmark/multi_gpu_runner.py +133 -0
  113. helm/benchmark/presentation/create_plots.py +8 -7
  114. helm/benchmark/presentation/run_display.py +26 -10
  115. helm/benchmark/presentation/schema.py +15 -40
  116. helm/benchmark/presentation/summarize.py +119 -79
  117. helm/benchmark/presentation/table.py +8 -8
  118. helm/benchmark/presentation/test_contamination.py +2 -2
  119. helm/benchmark/presentation/test_run_entry.py +1 -2
  120. helm/benchmark/presentation/test_summarize.py +3 -3
  121. helm/benchmark/run.py +54 -26
  122. helm/benchmark/run_expander.py +214 -16
  123. helm/benchmark/run_spec.py +93 -0
  124. helm/benchmark/run_spec_factory.py +162 -0
  125. helm/benchmark/run_specs/__init__.py +0 -0
  126. helm/benchmark/run_specs/classic_run_specs.py +1510 -0
  127. helm/benchmark/run_specs/cleva_run_specs.py +277 -0
  128. helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
  129. helm/benchmark/run_specs/heim_run_specs.py +623 -0
  130. helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
  131. helm/benchmark/run_specs/lite_run_specs.py +307 -0
  132. helm/benchmark/run_specs/simple_run_specs.py +104 -0
  133. helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
  134. helm/benchmark/run_specs/vlm_run_specs.py +501 -0
  135. helm/benchmark/runner.py +51 -57
  136. helm/benchmark/runner_config_registry.py +21 -0
  137. helm/benchmark/scenarios/bbq_scenario.py +1 -1
  138. helm/benchmark/scenarios/bold_scenario.py +2 -2
  139. helm/benchmark/scenarios/code_scenario.py +1 -0
  140. helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
  141. helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
  142. helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
  143. helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
  144. helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
  145. helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
  146. helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
  147. helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
  148. helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
  149. helm/benchmark/scenarios/image_generation/__init__.py +0 -0
  150. helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
  151. helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
  152. helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
  153. helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
  154. helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
  155. helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
  156. helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
  157. helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
  158. helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
  159. helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
  160. helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
  161. helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
  162. helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
  163. helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
  164. helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
  165. helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
  166. helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
  167. helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
  168. helm/benchmark/scenarios/imdb_scenario.py +0 -1
  169. helm/benchmark/scenarios/live_qa_scenario.py +94 -0
  170. helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
  171. helm/benchmark/scenarios/math_scenario.py +19 -2
  172. helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
  173. helm/benchmark/scenarios/numeracy_scenario.py +1 -1
  174. helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
  175. helm/benchmark/scenarios/scenario.py +4 -0
  176. helm/benchmark/scenarios/simple_scenarios.py +122 -1
  177. helm/benchmark/scenarios/test_math_scenario.py +6 -0
  178. helm/benchmark/scenarios/test_scenario.py +6 -3
  179. helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
  180. helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
  181. helm/benchmark/scenarios/unitxt_scenario.py +56 -0
  182. helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
  183. helm/benchmark/scenarios/vicuna_scenario.py +1 -1
  184. helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
  185. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
  186. helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
  187. helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
  188. helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
  189. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
  190. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
  191. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
  192. helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
  193. helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
  194. helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
  195. helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
  196. helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
  197. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
  198. helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
  199. helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
  200. helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
  201. helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
  202. helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
  203. helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
  204. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
  205. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -2
  206. helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
  207. helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
  208. helm/benchmark/server.py +24 -1
  209. helm/benchmark/slurm_runner.py +70 -49
  210. helm/benchmark/static/benchmarking.js +1 -1
  211. helm/benchmark/static/schema_classic.yaml +258 -1066
  212. helm/benchmark/static/schema_instruction_following.yaml +210 -0
  213. helm/benchmark/static/schema_lite.yaml +2 -227
  214. helm/benchmark/static/schema_mmlu.yaml +1507 -0
  215. helm/benchmark/static/schema_unitxt.yaml +428 -0
  216. helm/benchmark/static/schema_vlm.yaml +576 -0
  217. helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
  218. helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
  219. helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
  220. helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
  221. helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
  222. helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
  223. helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
  224. helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
  225. helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
  226. helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
  227. helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
  228. helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
  229. helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
  230. helm/benchmark/static_build/assets/index-d839df55.js +9 -0
  231. helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
  232. helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
  233. helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
  234. helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
  235. helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
  236. helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
  237. helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
  238. helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
  239. helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
  240. helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
  241. helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
  242. helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
  243. helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
  244. helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
  245. helm/benchmark/static_build/config.js +4 -0
  246. helm/benchmark/static_build/index.html +20 -0
  247. helm/benchmark/test_data_preprocessor.py +3 -3
  248. helm/benchmark/test_model_deployment_definition.py +14 -16
  249. helm/benchmark/test_run_expander.py +1 -1
  250. helm/benchmark/window_services/ai21_window_service.py +22 -33
  251. helm/benchmark/window_services/cohere_window_service.py +1 -63
  252. helm/benchmark/window_services/default_window_service.py +2 -44
  253. helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
  254. helm/benchmark/window_services/ice_window_service.py +0 -34
  255. helm/benchmark/window_services/image_generation/__init__.py +0 -0
  256. helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
  257. helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
  258. helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
  259. helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
  260. helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
  261. helm/benchmark/window_services/local_window_service.py +21 -4
  262. helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
  263. helm/benchmark/window_services/test_bloom_window_service.py +2 -1
  264. helm/benchmark/window_services/test_cohere_window_service.py +2 -1
  265. helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
  266. helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
  267. helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
  268. helm/benchmark/window_services/test_gptj_window_service.py +3 -2
  269. helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
  270. helm/benchmark/window_services/test_ice_window_service.py +2 -1
  271. helm/benchmark/window_services/test_openai_window_service.py +2 -1
  272. helm/benchmark/window_services/test_opt_window_service.py +3 -2
  273. helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
  274. helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
  275. helm/benchmark/window_services/test_t511b_window_service.py +2 -1
  276. helm/benchmark/window_services/test_ul2_window_service.py +2 -1
  277. helm/benchmark/window_services/test_utils.py +3 -2
  278. helm/benchmark/window_services/test_yalm_window_service.py +2 -1
  279. helm/benchmark/window_services/window_service.py +42 -0
  280. helm/benchmark/window_services/window_service_factory.py +4 -1
  281. helm/benchmark/window_services/yalm_window_service.py +0 -27
  282. helm/clients/__init__.py +0 -0
  283. helm/{proxy/clients → clients}/ai21_client.py +3 -9
  284. helm/clients/aleph_alpha_client.py +112 -0
  285. helm/{proxy/clients → clients}/anthropic_client.py +203 -18
  286. helm/{proxy/clients → clients}/auto_client.py +59 -31
  287. helm/clients/bedrock_client.py +128 -0
  288. helm/clients/bedrock_utils.py +72 -0
  289. helm/{proxy/clients → clients}/client.py +65 -7
  290. helm/clients/clip_score_client.py +49 -0
  291. helm/clients/clip_scorers/__init__.py +0 -0
  292. helm/clients/clip_scorers/base_clip_scorer.py +18 -0
  293. helm/clients/clip_scorers/clip_scorer.py +50 -0
  294. helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
  295. helm/{proxy/clients → clients}/cohere_client.py +4 -11
  296. helm/clients/gcs_client.py +82 -0
  297. helm/{proxy/clients → clients}/google_client.py +5 -5
  298. helm/clients/google_translate_client.py +35 -0
  299. helm/{proxy/clients → clients}/http_model_client.py +5 -7
  300. helm/{proxy/clients → clients}/huggingface_client.py +43 -64
  301. helm/clients/image_generation/__init__.py +0 -0
  302. helm/clients/image_generation/adobe_vision_client.py +78 -0
  303. helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
  304. helm/clients/image_generation/cogview2/__init__.py +0 -0
  305. helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
  306. helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
  307. helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
  308. helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
  309. helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
  310. helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
  311. helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
  312. helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
  313. helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
  314. helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
  315. helm/clients/image_generation/cogview2_client.py +191 -0
  316. helm/clients/image_generation/dalle2_client.py +192 -0
  317. helm/clients/image_generation/dalle3_client.py +108 -0
  318. helm/clients/image_generation/dalle_mini/__init__.py +3 -0
  319. helm/clients/image_generation/dalle_mini/data.py +442 -0
  320. helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
  321. helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
  322. helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
  323. helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
  324. helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
  325. helm/clients/image_generation/dalle_mini/model/text.py +251 -0
  326. helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
  327. helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
  328. helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
  329. helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
  330. helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
  331. helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
  332. helm/clients/image_generation/dalle_mini_client.py +190 -0
  333. helm/clients/image_generation/deep_floyd_client.py +78 -0
  334. helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
  335. helm/clients/image_generation/image_generation_client_utils.py +9 -0
  336. helm/clients/image_generation/lexica_client.py +86 -0
  337. helm/clients/image_generation/mindalle/__init__.py +0 -0
  338. helm/clients/image_generation/mindalle/models/__init__.py +216 -0
  339. helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
  340. helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
  341. helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
  342. helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
  343. helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
  344. helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
  345. helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
  346. helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
  347. helm/clients/image_generation/mindalle/utils/config.py +129 -0
  348. helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
  349. helm/clients/image_generation/mindalle/utils/utils.py +89 -0
  350. helm/clients/image_generation/mindalle_client.py +115 -0
  351. helm/clients/image_generation/nudity_check_client.py +64 -0
  352. helm/clients/image_generation/together_image_generation_client.py +111 -0
  353. helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
  354. helm/{proxy/clients → clients}/megatron_client.py +5 -5
  355. helm/clients/mistral_client.py +134 -0
  356. helm/clients/moderation_api_client.py +109 -0
  357. helm/clients/open_lm_client.py +43 -0
  358. helm/clients/openai_client.py +302 -0
  359. helm/{proxy/clients → clients}/palmyra_client.py +6 -8
  360. helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
  361. helm/clients/simple_client.py +64 -0
  362. helm/{proxy/clients → clients}/test_auto_client.py +13 -15
  363. helm/clients/test_client.py +100 -0
  364. helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
  365. helm/clients/test_simple_client.py +19 -0
  366. helm/{proxy/clients → clients}/test_together_client.py +20 -8
  367. helm/{proxy/clients → clients}/together_client.py +12 -72
  368. helm/clients/vertexai_client.py +391 -0
  369. helm/clients/vision_language/__init__.py +0 -0
  370. helm/clients/vision_language/huggingface_vlm_client.py +104 -0
  371. helm/{proxy/clients → clients}/vision_language/idefics_client.py +53 -48
  372. helm/clients/vision_language/open_flamingo/__init__.py +2 -0
  373. helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
  374. helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
  375. helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
  376. helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
  377. helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
  378. helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
  379. helm/clients/vision_language/open_flamingo_client.py +155 -0
  380. helm/clients/vision_language/qwen_vlm_client.py +171 -0
  381. helm/clients/vllm_client.py +46 -0
  382. helm/common/cache.py +16 -4
  383. helm/common/cache_backend_config.py +47 -0
  384. helm/common/clip_score_request.py +41 -0
  385. helm/common/file_caches/__init__.py +0 -0
  386. helm/common/file_caches/file_cache.py +16 -0
  387. helm/common/file_caches/local_file_cache.py +61 -0
  388. helm/common/file_caches/test_local_file_cache.py +25 -0
  389. helm/common/file_upload_request.py +27 -0
  390. helm/common/general.py +1 -1
  391. helm/common/image_generation_parameters.py +25 -0
  392. helm/common/images_utils.py +24 -1
  393. helm/common/key_value_store.py +35 -4
  394. helm/common/media_object.py +13 -0
  395. helm/common/moderations_api_request.py +71 -0
  396. helm/common/mongo_key_value_store.py +3 -3
  397. helm/common/multimodal_request_utils.py +31 -0
  398. helm/common/nudity_check_request.py +29 -0
  399. helm/common/request.py +15 -17
  400. helm/common/test_general.py +6 -0
  401. helm/common/tokenization_request.py +1 -1
  402. helm/config/model_deployments.yaml +1069 -546
  403. helm/config/model_metadata.yaml +753 -31
  404. helm/config/tokenizer_configs.yaml +142 -43
  405. helm/proxy/accounts.py +31 -4
  406. helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
  407. helm/proxy/critique/model_critique_client.py +8 -6
  408. helm/proxy/example_queries.py +29 -17
  409. helm/proxy/server.py +70 -5
  410. helm/proxy/services/remote_service.py +31 -0
  411. helm/proxy/services/server_service.py +96 -16
  412. helm/proxy/services/service.py +30 -0
  413. helm/proxy/services/test_remote_service.py +4 -3
  414. helm/proxy/services/test_service.py +0 -12
  415. helm/proxy/test_accounts.py +32 -0
  416. helm/proxy/token_counters/auto_token_counter.py +37 -37
  417. helm/proxy/token_counters/test_auto_token_counter.py +164 -0
  418. helm/proxy/token_counters/token_counter.py +3 -5
  419. helm/tokenizers/__init__.py +0 -0
  420. helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
  421. helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
  422. helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
  423. helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
  424. helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
  425. helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
  426. helm/tokenizers/simple_tokenizer.py +33 -0
  427. helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
  428. helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
  429. helm/tokenizers/test_simple_tokenizer.py +33 -0
  430. helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
  431. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
  432. helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
  433. helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
  434. helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
  435. crfm_helm-0.4.0.dist-info/RECORD +0 -397
  436. helm/benchmark/run_specs.py +0 -2762
  437. helm/benchmark/test_model_properties.py +0 -1570
  438. helm/benchmark/vlm_run_specs.py +0 -97
  439. helm/benchmark/window_services/flan_t5_window_service.py +0 -29
  440. helm/benchmark/window_services/gpt2_window_service.py +0 -32
  441. helm/benchmark/window_services/huggingface_window_service.py +0 -60
  442. helm/benchmark/window_services/t0pp_window_service.py +0 -35
  443. helm/benchmark/window_services/t511b_window_service.py +0 -30
  444. helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
  445. helm/benchmark/window_services/ul2_window_service.py +0 -30
  446. helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
  447. helm/common/cache_utils.py +0 -14
  448. helm/proxy/clients/aleph_alpha_client.py +0 -95
  449. helm/proxy/clients/goose_ai_client.py +0 -99
  450. helm/proxy/clients/microsoft_client.py +0 -180
  451. helm/proxy/clients/openai_client.py +0 -206
  452. helm/proxy/clients/simple_client.py +0 -60
  453. helm/proxy/clients/test_client.py +0 -49
  454. helm/proxy/clients/vertexai_client.py +0 -115
  455. helm/proxy/token_counters/ai21_token_counter.py +0 -20
  456. helm/proxy/token_counters/cohere_token_counter.py +0 -13
  457. helm/proxy/token_counters/free_token_counter.py +0 -12
  458. helm/proxy/token_counters/gooseai_token_counter.py +0 -24
  459. helm/proxy/token_counters/openai_token_counter.py +0 -22
  460. helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
  461. helm/proxy/token_counters/test_openai_token_counter.py +0 -81
  462. helm/proxy/tokenizers/simple_tokenizer.py +0 -32
  463. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
  464. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
  465. {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
  466. /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
  467. /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
  468. /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
  469. /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
  470. /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
  471. /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
  472. /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
  473. /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
  474. /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
  475. /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
  476. /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
  477. /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
  478. /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
  479. /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
  480. /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
  481. /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
  482. /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
@@ -0,0 +1,269 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : itersr_model.py
4
+ @Time : 2021/10/02 01:36:32
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ """
8
+
9
+ # here put the import lib
10
+ import math
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ from helm.common.optional_dependencies import handle_module_not_found_error
15
+
16
+ try:
17
+ from deepspeed.runtime.activation_checkpointing.checkpointing import get_cuda_rng_tracker
18
+ from SwissArmyTransformer.model.base_model import BaseModel, BaseMixin
19
+ from SwissArmyTransformer.mpu.utils import sqrt
20
+ from SwissArmyTransformer.model.transformer import split_tensor_along_last_dim
21
+ from SwissArmyTransformer.ops.local_attention_function import f_similar, f_weighting
22
+ except ModuleNotFoundError as e:
23
+ handle_module_not_found_error(e, ["heim"])
24
+
25
+
26
class PositionEmbeddingMixin(BaseMixin):
    """Adds a second table of learned position embeddings for the extra
    (super-resolution) part of the sequence, alongside the base
    transformer's own position-embedding table.
    """

    def __init__(
        self, additional_sequence_length, hidden_size, init_method_std=0.02, reinit_slice=slice(512, 512 + 400)
    ):
        # reinit_slice selects which rows of the *base* position-embedding
        # table seed this new table in reinit() below.
        super(PositionEmbeddingMixin, self).__init__()
        self.reinit_slice = reinit_slice
        self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

    def reinit(self, parent_model=None):
        """Initialize the new embeddings by tiling the base embeddings.

        Treats both tables as square 2-D grids (edge = sqrt(length)) and
        copies the old grid into every (new_edge/old_edge)^2 tile of the new
        grid; the new edge must be an integer multiple of the old edge.
        """
        old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
        old_len, hidden_size = old_weights.shape
        assert hidden_size == self.position_embeddings.weight.shape[-1]
        old_edge, new_edge = sqrt(old_len), sqrt(self.position_embeddings.weight.shape[-2])
        assert new_edge % old_edge == 0
        # Broadcast-copy: view the new table as
        # (tiles_y, old_edge, tiles_x, old_edge, hidden) and fill each tile
        # with the old grid via broadcasting.
        self.position_embeddings.weight.data.view(
            new_edge // old_edge, old_edge, new_edge // old_edge, old_edge, hidden_size
        ).copy_(old_weights.view(1, old_edge, 1, old_edge, hidden_size))
44
+
45
+
46
class ItersrModel(BaseModel):
    """Iterative super-resolution transformer (cuda2d-style): extends a base
    model with extra position embeddings for the enlarged sequence and a
    two-level sparse attention over text + image tokens.
    """

    def __init__(self, args, transformer=None):
        super().__init__(args, transformer=transformer)
        self.original_sequence_length = args.max_sequence_length
        # Extra positions introduced by growing the sequence to new_sequence_length.
        additional_seqlen = args.new_sequence_length - args.max_sequence_length
        self.add_mixin("extra_position_embedding", PositionEmbeddingMixin(additional_seqlen, args.hidden_size))
        # self.add_mixin('attention_plus', AttentionMixin(
        #     num_layers=args.num_layers,
        #     hidden_size=args.hidden_size
        # ))
        # layout partitions the sequence: layout[0] = text length, the rest is
        # the image grid (default "16,3616" — TODO confirm parsed form).
        self.layout = args.layout
        # [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] 4095 {layout[2]}
        self.kernel_size = args.kernel_size
        self.kernel_size2 = args.kernel_size2
        self.log_attention_weights = None

    def position_embedding_forward(self, position_ids, **kw_args):
        """Embed text positions with the base table and image positions with
        the extra table, then concatenate along the sequence dimension."""
        position = position_ids[..., : self.layout[0]]
        # Extra positions are offset so they index from 0 in the new table.
        position_plus = position_ids[..., self.layout[0] :] - self.original_sequence_length
        position_embeddings = torch.cat(
            (
                self.transformer.position_embeddings(position),
                self.get_mixin("extra_position_embedding").position_embeddings(position_plus),
            ),
            dim=-2,
        )
        return position_embeddings

    def attention_forward(self, hidden_states, mask, layer_id=None, log_attention_weights=None, **kw_args):
        """Replace the layer's dense attention with two-level sparse attention."""
        attn_module = self.transformer.layers[layer_id].attention
        # base model qkv (text tokens: first layout[0] positions)
        mixed_raw_layer = attn_module.query_key_value(hidden_states)
        q0, k0, v0 = split_tensor_along_last_dim(mixed_raw_layer[:, : self.layout[0]], 3)
        # cuda2d model qkv (image tokens: remaining positions)
        q1, k1, v1 = split_tensor_along_last_dim(mixed_raw_layer[:, self.layout[0] :], 3)

        # Apply attention dropout only in training mode.
        dropout_fn = attn_module.attention_dropout if self.training else None

        # cuda2d attention
        context_layer = sparse_attention_2d_text(
            q0,
            k0,
            v0,
            q1,
            k1,
            v1,
            mask,
            n_head=attn_module.num_attention_heads_per_partition,
            text_len=self.layout[0],
            kernel_size=self.kernel_size,
            attention_dropout=dropout_fn,
            log_attention_weights=log_attention_weights,
        )

        output = attn_module.dense(context_layer)

        return output

    def final_forward(self, logits, **kwargs):
        """Project hidden states onto the first 20000 word-embedding rows
        (tied output projection) and upcast to float32."""
        logits_parallel = logits
        logits_parallel = torch.nn.functional.linear(
            logits_parallel, self.transformer.word_embeddings.weight[:20000]
        ).float()
        # logits_parallel = torch.nn.functional.linear(logits_parallel, self.transformer.word_embeddings.weight[:20000])
        return logits_parallel

    # def disable_untrainable_params(self):
    #     self.transformer.requires_grad_(False)

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register this model's command-line options on `parser`."""
        group = parser.add_argument_group("Cuda2dModel", "cuda2d model configurations")
        group.add_argument("--kernel-size", type=int, default=5)
        group.add_argument("--kernel-size2", type=int, default=5)
        group.add_argument("--layout", type=str, default="16,3616")
        group.add_argument("--new-sequence-length", type=int, default=4096)
        return parser
123
+
124
+
125
def sparse_attention_2d_text(
    q0,
    k0,
    v0,
    q1,
    k1,
    v1,
    attention_mask,
    n_head,
    text_len,
    kernel_size=9,
    attention_dropout=None,
    log_attention_weights=None,
    **kwargs,
):
    """Two-level sparse attention over text (level 0) + image (level 1) tokens.

    Level 0 attends densely over the masked text; level 1 attends locally
    over the image grid (via f_similar/f_weighting) and cross-attends densely
    onto the text tokens.

    q0, k0, v0: [batch_size, 16, hidden_size]
    q1, k1, v1: [batch_size, 3600, hidden_size]
    n_head: int
    attention_mask: [batch_size, 16]
    NOTE(review): `text_len` and **kwargs are accepted but unused here.
    """
    b, s0, h0 = q0.shape
    b, s1, h1 = q1.shape
    # h: per-head dim for level 0; l1: image grid edge (s1 must be a perfect square).
    h, l1 = h0 // n_head, sqrt(s1)
    assert attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"

    # Split heads: q0/v0 -> [b, n_head, s0, h]; k0T pre-transposed for matmul.
    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)

    # standard attention for level 0
    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)

    # Masking: masked positions get a large negative score before softmax.
    attention_scores = torch.mul(attention_scores, attention_mask) - 10000.0 * (1.0 - attention_mask)

    attention_probs0 = F.softmax(attention_scores, dim=-1)

    # local attention for level 1: reshape to [b*n_head, head_dim, l1, l1] so
    # the windowed (kernel_size) similarity can run on the 2-D grid.
    q1 = (
        (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1 // n_head))
        .contiguous()
        .view(b * n_head, h1 // n_head, l1, l1)
    )
    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b * n_head, h1 // n_head, l1, l1)
    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b * n_head, h1 // n_head, l1, l1)
    scores_1_to_1 = f_similar(q1, k1, kernel_size * 2 - 1, kernel_size, False)

    # cross attention: every image token attends to all text tokens.
    scores_1_to_0 = torch.matmul(q1.view(b, n_head, h, s1).transpose(-1, -2), k0T)
    if log_attention_weights is not None:
        scores_1_to_0 += log_attention_weights
    scores_1_to_0 = torch.mul(scores_1_to_0, attention_mask) - 10000.0 * (1.0 - attention_mask)
    # Joint softmax over the concatenated [text scores | local image scores]
    # for each image token, so probabilities are normalized across both parts.
    scores_1 = torch.cat(
        (scores_1_to_0.view(b * n_head, s1, s0), scores_1_to_1.view(b * n_head, -1, scores_1_to_1.shape[3])), dim=-1
    )
    attention_probs1 = F.softmax(scores_1, dim=-1)

    if attention_dropout is not None:
        # Fork the model-parallel RNG so dropout stays reproducible under
        # activation checkpointing.
        with get_cuda_rng_tracker().fork():
            attention_probs1 = attention_dropout(attention_probs1)

    # weighting for level 0
    context0 = torch.matmul(attention_probs0, v0)  # [b, n_head, s0, h]
    # weighting for level 1: local-window part of the joint distribution
    probs_1_to_1 = attention_probs1[:, :, -scores_1_to_1.shape[3] :].view_as(scores_1_to_1)
    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size * 2 - 1, kernel_size, False)

    context1 = context1_to_1.view(b, n_head, h, l1**2)
    # weighting for cross attention: image->text part of the joint distribution
    probs_1_to_0 = attention_probs1[:, :, : scores_1_to_0.shape[3]].view(b, n_head, -1, scores_1_to_0.shape[3])

    context1_to_0 = torch.matmul(probs_1_to_0, v0)
    context1 = context1.transpose(-1, -2) + context1_to_0

    # Merge heads and concatenate text + image contexts back into one sequence.
    output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0 + s1, h0)

    return output
202
+
203
+
204
def sparse_attention_2d_notext(
    q0,
    k0,
    v0,
    q1,
    k1,
    v1,
    attention_mask,
    n_head,
    text_len,
    kernel_size=9,
    attention_dropout=None,
    log_attention_weights=None,
    **kwargs,
):
    """Variant of sparse_attention_2d_text WITHOUT image->text cross attention.

    Level 0 attends densely over the masked text; level 1 attends only
    locally over the image grid (f_similar/f_weighting).

    q0, k0, v0: [batch_size, 16, hidden_size]
    q1, k1, v1: [batch_size, 3600, hidden_size]
    n_head: int
    attention_mask: 4-D mask whose last dim equals s0 (enforced below).
    NOTE(review): `text_len`, `log_attention_weights`, and **kwargs are
    accepted for signature parity but unused in this variant.
    """
    b, s0, h0 = q0.shape
    b, s1, h1 = q1.shape
    # h: per-head dim for level 0; l1: image grid edge (s1 must be a perfect square).
    h, l1 = h0 // n_head, sqrt(s1)
    assert len(attention_mask.shape) == 4 and attention_mask.shape[-1] == s0, f"Mask Shape: {attention_mask.shape}"

    # Split heads: q0/v0 -> [b, n_head, s0, h]; k0T pre-transposed for matmul.
    q0 = q0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    v0 = v0.reshape(b, s0, n_head, h).permute(0, 2, 1, 3)
    k0T = k0.reshape(b, s0, n_head, h).permute(0, 2, 3, 1)

    # standard attention for level 0
    attention_scores = torch.matmul(q0 / math.sqrt(q0.shape[-1]), k0T)

    # Masking: masked positions get a large negative score before softmax.
    attention_scores = torch.mul(attention_scores, attention_mask) - 10000.0 * (1.0 - attention_mask)

    attention_probs0 = F.softmax(attention_scores, dim=-1)

    # local attention for level 1: reshape to [b*n_head, head_dim, l1, l1] so
    # the windowed (kernel_size) similarity can run on the 2-D grid.
    q1 = (
        (q1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1) / math.sqrt(h1 // n_head))
        .contiguous()
        .view(b * n_head, h1 // n_head, l1, l1)
    )
    k1 = k1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b * n_head, h1 // n_head, l1, l1)
    v1 = v1.view(b, s1, n_head, h1 // n_head).permute(0, 2, 3, 1).contiguous().view(b * n_head, h1 // n_head, l1, l1)
    scores_1_to_1 = f_similar(q1, k1, kernel_size * 2 - 1, kernel_size, False)

    # Softmax only over the local window — no text scores in this variant.
    attention_probs1 = F.softmax(scores_1_to_1, dim=-1)

    if attention_dropout is not None:
        # Fork the model-parallel RNG so dropout stays reproducible under
        # activation checkpointing.
        with get_cuda_rng_tracker().fork():
            attention_probs1 = attention_dropout(attention_probs1)

    # weighting for level 0
    context0 = torch.matmul(attention_probs0, v0)  # [b, n_head, s0, h]
    # weighting for level 1 (local window only)
    probs_1_to_1 = attention_probs1
    context1_to_1 = f_weighting(v1, probs_1_to_1.contiguous(), kernel_size * 2 - 1, kernel_size, False)

    context1 = context1_to_1.view(b, n_head, h, l1**2)
    # no cross attention in this variant — just restore [b, n_head, s1, h]
    context1 = context1.transpose(-1, -2)

    # Merge heads and concatenate text + image contexts back into one sequence.
    output = torch.cat((context0, context1), dim=2).transpose(1, 2).reshape(b, s0 + s1, h0)

    return output
@@ -0,0 +1,120 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : itersr_sampling.py
4
+ @Time : 2022/03/03 14:24:28
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ """
8
+
9
+ # here put the import lib
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from icetk import icetk as tokenizer
13
+
14
+
15
def top_k_logits_(logits, top_k=0, filter_value=-float("Inf")):
    """In-place top-k filtering: set every logit below the k-th largest
    value (along the last dimension) to `filter_value`.

    Args:
        logits: tensor of logits; modified in place.
        top_k: number of entries to keep per row; 0 (the default) disables
            filtering, following the common top-k sampling convention.
        filter_value: value assigned to filtered-out entries (default -inf).

    Returns:
        The same `logits` tensor, for call-chaining convenience.
    """
    if top_k <= 0:
        # Guard: with the default top_k=0, torch.topk(logits, 0) returns an
        # empty tensor and the [..., -1] index below would raise. Treat 0 as
        # "no filtering" instead of crashing.
        return logits
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits[indices_to_remove] = filter_value
    return logits
19
+
20
+
21
class IterativeEntfilterStrategy:
    """Sampling strategy for iterative refinement: temperature-scale the
    logits, mask out invalid vocabulary slices, keep only the top-k
    candidates, and sample one token per position.
    """

    def __init__(self, invalid_slices=None, temperature=1.0, topk=10):
        """
        Args:
            invalid_slices: slices of the vocabulary whose logits are forced
                to -inf before sampling; None (default) means no slices.
            temperature: default softmax temperature used by forward().
            topk: number of candidates kept by top-k filtering.
        """
        # Use a None sentinel instead of a mutable default argument ([]),
        # which would be shared by every instance constructed without the
        # argument.
        self.invalid_slices = [] if invalid_slices is None else invalid_slices
        self.temperature = temperature
        self.topk = topk

    def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
        """Sample one token per position from top-k-filtered logits.

        In the iterative strategy, logits are of shape
        [batch_size, seq_length, vocab_size].
        NOTE(review): `entfilter`, `filter_topk`, and `temperature2` are
        accepted for interface compatibility but unused here.
        """
        if temperature is None:
            temperature = self.temperature

        logits = logits.float() / temperature
        # Forbid sampling from invalid vocabulary regions.
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -float("Inf")

        # Top-k filter in place, then sample one token per position from the
        # renormalized distribution.
        top_k_logits_(logits, self.topk)
        probs = F.softmax(logits, dim=-1)
        pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
        pred.squeeze_(-1)

        # `tokens` is only used as a shape check here; the freshly sampled
        # tokens replace it wholesale.
        assert tokens.shape[1] == pred.shape[1]
        tokens = pred
        return tokens
54
+
55
+
56
def filling_sequence_itersr(
    model,
    seq0,
    seq1,
    warmup_steps=3,
    block_hw=(4, 4),
    strategy=IterativeEntfilterStrategy(topk=10),
):
    """
    seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
        4095 {layout[2]} final_token.
    Attention:
        The sampling temperatures change over time; temporarily hard-coded here.
        The temperature in the strategy is not used.

    NOTE(review): the default `strategy` is a single shared instance,
    evaluated once at definition time — every caller that omits the argument
    gets the same object. `warmup_steps` and `block_hw`'s `ll`/`rr` are
    currently unused (num_steps is hard-coded to 1).
    """
    assert hasattr(model, "layout")
    layout = model.layout

    device = seq0.device
    # concat and pad sequences: left-pad seq0 with zeros up to layout[0]
    batch_size = seq0.shape[0]
    n_pad = layout[0] - seq0.shape[1]
    assert n_pad >= 0, "You should truncate long input before filling."
    seq = torch.cat(
        (torch.tensor([0] * n_pad, device=device, dtype=seq0.dtype).unsqueeze(0).expand(batch_size, n_pad), seq0, seq1),
        dim=1,
    )  # [b, layout[-1]+1]
    assert seq.shape[1] == layout[-1]

    # build initial tokens, attention_mask, and position_ids
    tokens = seq.clone()
    # Mask covers only the text part; padded positions are masked out.
    attention_mask = torch.ones(layout[0]).to(device)
    attention_mask[:n_pad] = 0
    attention_mask = attention_mask.unsqueeze(0).type_as(next(model.parameters()))  # if fp16
    # Text positions use the base table (0..); image positions start at 1024
    # to index the extra position-embedding table.
    position_ids = torch.cat(
        (
            torch.zeros(n_pad, dtype=torch.long),
            torch.arange(0, layout[0] - n_pad),
            torch.arange(1024, 1024 + layout[1] - layout[0]),
        )
    ).to(device)
    # All-zero attention bias over the text part (no boost applied here).
    log_attention_weights = torch.zeros(layout[0], device=device).type_as(next(model.parameters()))
    log_attention_weights[n_pad : layout[0]] = 0.0
    log_attention_weights = log_attention_weights.unsqueeze(0)

    # prepare for iteration: only positions holding the <start_of_image>
    # placeholder may be (re)sampled; everything else stays fixed.
    unfixed = tokens == tokenizer["<start_of_image>"]
    ll, rr = block_hw
    # edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
    num_steps = 1
    # iterative refining

    # unfixed[..., -(layout[-1] - layout[-2]):].view(
    #     batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False

    ret = []
    # ret.append(tokens[:, layout[-2]:-1].clone())
    for step_cnt in range(1, num_steps + 1):
        logits, *_dump = model(tokens, position_ids, attention_mask, log_attention_weights=log_attention_weights)
        real_temp = 1.0
        new_tokens = strategy.forward(logits, tokens, real_temp)
        # Only overwrite the unfixed (placeholder) positions.
        tokens[unfixed] = new_tokens[unfixed]

    # Return only the image part of the sequence.
    ret.append(tokens[:, layout[-2] :].clone())
    return torch.cat(ret, dim=0)
@@ -0,0 +1,42 @@
1
+ # -*- encoding: utf-8 -*-
2
+ """
3
+ @File : sr_group.py
4
+ @Time : 2022/04/02 01:17:21
5
+ @Author : Ming Ding
6
+ @Contact : dm18@mails.tsinghua.edu.cn
7
+ """
8
+
9
+ # here put the import lib
10
+ from .direct_sr import DirectSuperResolution
11
+ from .iterative_sr import IterativeSuperResolution
12
+
13
+ from helm.common.optional_dependencies import handle_module_not_found_error
14
+
15
+
16
class SRGroup:
    """Two-stage super-resolution pipeline: a direct SR pass followed by
    iterative refinement, with the refinement stage sharing the direct
    stage's transformer weights.
    """

    def __init__(
        self,
        args,
        home_path=None,
    ):
        # SwissArmyTransformer is shipped via the optional "heim" extras.
        try:
            from SwissArmyTransformer.resources import auto_create
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        direct_path = auto_create("cogview2-dsr", path=home_path)
        refine_path = auto_create("cogview2-itersr", path=home_path)
        self.dsr = DirectSuperResolution(args, direct_path)
        # Reuse the direct stage's transformer in the iterative stage.
        self.itersr = IterativeSuperResolution(args, refine_path, shared_transformer=self.dsr.model.transformer)

    def sr_base(self, img_tokens, txt_tokens):
        """Run both SR stages on a batch of 400 image tokens per sample and
        return the refined tokens for the batch."""
        assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
        num_samples = img_tokens.shape[0]
        if len(txt_tokens.shape) == 1:
            # A single text-token row is broadcast across the whole batch.
            txt_tokens = txt_tokens.unsqueeze(0).expand(num_samples, txt_tokens.shape[-1])
        upsampled = self.dsr(txt_tokens, img_tokens)
        refined = self.itersr(txt_tokens, upsampled[:, -3600:].clone())
        return refined[-num_samples:]
@@ -0,0 +1,191 @@
1
+ import os
2
+ import argparse
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+ from icetk import icetk as tokenizer
8
+ from torchvision.utils import save_image
9
+
10
+ from helm.common.cache import CacheConfig, Cache
11
+ from helm.common.file_caches.file_cache import FileCache
12
+ from helm.common.hierarchical_logger import hlog, htrack_block
13
+ from helm.common.optional_dependencies import handle_module_not_found_error
14
+ from helm.common.request import Request, RequestResult, GeneratedOutput, wrap_request_time
15
+ from helm.common.tokenization_request import (
16
+ DecodeRequest,
17
+ DecodeRequestResult,
18
+ TokenizationRequest,
19
+ TokenizationRequestResult,
20
+ )
21
+ from helm.clients.client import Client, CachingClient
22
+ from helm.clients.image_generation.cogview2.coglm_strategy import CoglmStrategy
23
+ from .image_generation_client_utils import get_single_image_multimedia_object
24
+
25
+
26
class CogView2Client(Client):
    """Client for text-to-image generation with CogView2.

    https://github.com/THUDM/CogView2

    Model weights are downloaded lazily on first use under the file-cache
    location; generation results (image file paths) are cached in
    `self._cache`.
    """

    # Maximum number of encoded prompt tokens kept before truncation.
    MAX_SEQ_LEN: int = 95
    MODEL_URL: str = "https://nlp.stanford.edu/projects/vhelm/cogview2/sharefs.zip"

    def __init__(self, cache_config: CacheConfig, file_cache: FileCache):
        self._cache = Cache(cache_config)
        self._file_cache: FileCache = file_cache

        # All model state is lazily initialized by _get_model().
        self._args: Optional[argparse.Namespace] = None
        self._strategy: Optional[CoglmStrategy] = None
        self._model = None
        self._srg = None

    def _get_model(self) -> None:
        """Download (if needed) and load the CogView2 model components.

        Idempotent for the heavy part: model loading runs only while
        `self._model` is still None.
        """
        try:
            from SwissArmyTransformer import get_args
            from helm.clients.image_generation.cogview2.coglm_utils import (
                get_recipe,
                InferenceModel,
            )
            from helm.clients.image_generation.cogview2.sr_pipeline import SRGroup
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        tokenizer.add_special_tokens(["<start_of_image>", "<start_of_english>", "<start_of_chinese>"])

        model_local_path: str = f"{self._file_cache._location}/cogview2"  # type: ignore
        # SwissArmyTransformer resolves pretrained checkpoints via SAT_HOME.
        os.environ["SAT_HOME"] = f"{model_local_path}/sharefs/cogview-new"

        # Download the model if not yet
        # NOTE(review): relies on mkdir/wget/unzip being on PATH; errors from
        # os.system are not checked.
        if not os.path.exists(model_local_path):
            os.system(f"mkdir -p {model_local_path}")
            os.system(f"wget {self.MODEL_URL} -P {model_local_path}")
            os.system(f"unzip {model_local_path}/sharefs.zip -d {model_local_path}")

        if self._model is None:
            # Set up args: SwissArmyTransformer defaults merged with the recipe.
            args = get_args("--mode inference --fp16".split())
            self._args = argparse.Namespace(**vars(args), **get_recipe("none"))
            self._args.img_size = 160
            self._args.only_first_stage = False
            self._args.inverse_prompt = False
            self._args.batch_size = 1
            self._args.max_inference_batch_size = 1

            # Load the model components
            self._model, self._args = InferenceModel.from_pretrained(self._args, "coglm")
            # Forbid sampling any token id beyond the image-token vocabulary.
            invalid_slices = [slice(tokenizer.num_image_tokens, None)]
            self._strategy = CoglmStrategy(
                invalid_slices,
                temperature=getattr(self._args, "temp_all_gen"),
                top_k=getattr(self._args, "topk_gen"),
                top_k_cluster=getattr(self._args, "temp_cluster_gen"),
            )
            self._srg = SRGroup(self._args)  # type: ignore

    def _model_inference(self, prompt) -> torch.Tensor:
        """Generate one image tensor for `prompt`.

        Pipeline: encode prompt -> autoregressively fill 400 coarse image
        tokens -> super-resolve to 3600 tokens -> decode and upscale to
        480x480.
        """
        try:
            from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
            from helm.clients.image_generation.cogview2.coglm_utils import get_masks_and_position_ids_coglm
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        with torch.no_grad():
            text = getattr(self._args, "query_template").format(prompt)
            seq = tokenizer.encode(text)
            if len(seq) > self.MAX_SEQ_LEN:
                # Keep the head of the prompt plus the two trailing tokens
                # (presumably special tokens — TODO confirm).
                seq = seq[: self.MAX_SEQ_LEN - 2] + seq[-2:]
            txt_len = len(seq) - 1
            device = getattr(self._args, "device")
            # Append 400 placeholder (-1) positions to be filled with image tokens.
            seq = torch.tensor(seq + [-1] * 400, device=device)
            # calibrate text length: add an attention bias toward the text prefix.
            log_attention_weights = torch.zeros(
                len(seq), len(seq), device=device, dtype=torch.half if getattr(self._args, "fp16") else torch.float32
            )
            log_attention_weights[:, :txt_len] = getattr(self._args, "attn_plus")
            # generation, split into micro-batches of at most mbz samples
            mbz = getattr(self._args, "max_inference_batch_size")
            batch_size = getattr(self._args, "batch_size")
            assert batch_size < mbz or batch_size % mbz == 0
            get_func = partial(get_masks_and_position_ids_coglm, context_length=txt_len)
            output_list = []
            for tim in range(max(batch_size // mbz, 1)):
                setattr(self._strategy, "start_pos", txt_len + 1)
                coarse_samples = filling_sequence(
                    self._model,
                    seq.clone(),
                    batch_size=min(batch_size, mbz),
                    strategy=self._strategy,
                    log_attention_weights=log_attention_weights,
                    get_masks_and_position_ids=get_func,
                )[0]
                output_list.append(coarse_samples)

            output_tokens = torch.cat(output_list, dim=0)
            images = []
            # Super-resolve the coarse 400 tokens conditioned on the text tokens.
            iter_tokens = getattr(self._srg, "sr_base")(output_tokens[:, -400:], seq[:txt_len])
            # NOTE(review): this loop shadows `seq`, and the function returns
            # images[0] (the FIRST decoded image) although the trailing
            # comment says "only the last image" — confirm intent. With
            # batch_size == 1 both are equivalent.
            for seq in iter_tokens:
                decoded_img = tokenizer.decode(image_ids=seq[-3600:])
                decoded_img = torch.nn.functional.interpolate(decoded_img, size=(480, 480))
                images.append(decoded_img)  # only the last image (target)
            return images[0]

    def make_request(self, request: Request) -> RequestResult:
        """Generate `request.num_completions` images for the prompt and return
        them as multimodal completions, caching the saved file locations."""
        raw_request = {
            "prompt": request.prompt,
        }

        try:

            def do_it() -> Dict[str, Any]:
                # Executed only on a cache miss: load the model, render the
                # images, and persist them to the file cache.
                prompt: str = request.prompt

                with htrack_block(f"Generating images for prompt: {prompt}"):
                    self._get_model()

                    images: List[torch.Tensor] = []
                    for _ in range(request.num_completions):
                        output = self._model_inference(**raw_request).cpu()  # (1, 3, 480, 480)
                        images.append(output)

                    assert (
                        len(images) == request.num_completions
                    ), f"Expected {request.num_completions} images, but got {len(images)}"

                    result: Dict = {"file_locations": []}
                    for image in images:
                        # Write out the image to a file and save the path
                        file_location: str = self._file_cache.generate_unique_new_file_path()  # type: ignore
                        save_image(image, file_location, normalize=True)
                        hlog(f"Image saved at {file_location}.")
                        result["file_locations"].append(file_location)
                    return result

            # Include the model name and number of completions in the cache key
            cache_key = CachingClient.make_cache_key(
                {"model": request.model_engine, "n": request.num_completions, **raw_request}, request
            )
            results, cached = self._cache.get(cache_key, wrap_request_time(do_it))
        except RuntimeError as e:
            error: str = f"CogView2Client error: {e}"
            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])

        # One GeneratedOutput per image; the image is attached as multimodal
        # content rather than text.
        completions: List[GeneratedOutput] = [
            GeneratedOutput(
                text="", logprob=0, tokens=[], multimodal_content=get_single_image_multimedia_object(location)
            )
            for location in results["file_locations"]
        ]
        return RequestResult(
            success=True,
            cached=cached,
            request_time=results["request_time"],
            completions=completions,
            embedding=[],
        )

    def tokenize(self, request: TokenizationRequest) -> TokenizationRequestResult:
        raise NotImplementedError("This client does not support tokenizing.")

    def decode(self, request: DecodeRequest) -> DecodeRequestResult:
        raise NotImplementedError("This client does not support decoding.")