compressed-tensors 0.10.3a20250716__tar.gz → 0.10.3a20250724__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. {compressed_tensors-0.10.3a20250716/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250724}/PKG-INFO +1 -1
  2. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +12 -6
  3. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/forward.py +68 -5
  4. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/initialize.py +35 -2
  5. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/quant_args.py +31 -8
  6. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/quant_scheme.py +41 -0
  7. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/utils/helpers.py +11 -2
  8. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/base.py +3 -4
  9. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/__init__.py +1 -0
  10. compressed_tensors-0.10.3a20250724/src/compressed_tensors/utils/match.py +191 -0
  11. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/version.py +1 -1
  12. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  13. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
  14. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_examples/test_bitmask_compression_ipynb.py +3 -1
  15. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_forward.py +50 -0
  16. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_initialize.py +13 -3
  17. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_quant_args.py +2 -1
  18. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_utils/test_helpers.py +28 -1
  19. compressed_tensors-0.10.3a20250724/tests/test_utils/test_match.py +426 -0
  20. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/.gitkeep +0 -0
  21. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/actions/test/action.yml +0 -0
  22. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/scripts/step-status +0 -0
  23. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/build-test.yml +0 -0
  24. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/build.yml +0 -0
  25. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/report.yml +0 -0
  26. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/test-check.yaml +0 -0
  27. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/test.yml +0 -0
  28. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/trigger-all.yml +0 -0
  29. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/upload.yml +0 -0
  30. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.gitignore +0 -0
  31. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/LICENSE +0 -0
  32. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/Makefile +0 -0
  33. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/README.md +0 -0
  34. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  35. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/bit_packing/int4_config.json +0 -0
  36. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/bitmask_compression.ipynb +0 -0
  37. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  38. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  39. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/example_quant_config.json +0 -0
  40. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  41. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/quantize_and_pack_int4.ipynb +0 -0
  42. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/pyproject.toml +0 -0
  43. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/setup.cfg +0 -0
  44. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/setup.py +0 -0
  45. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/__init__.py +0 -0
  46. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/README.md +0 -0
  47. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/__init__.py +0 -0
  48. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/base.py +0 -0
  49. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/__init__.py +0 -0
  50. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/base.py +0 -0
  51. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/helpers.py +0 -0
  52. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  53. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  54. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  55. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  56. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  57. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  58. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  59. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  60. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  61. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  62. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  63. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  64. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  65. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/__init__.py +0 -0
  66. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/base.py +0 -0
  67. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/dense.py +0 -0
  68. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  69. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  70. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/linear/__init__.py +0 -0
  71. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  72. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/__init__.py +0 -0
  73. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  74. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  75. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  76. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  77. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/quant_config.py +0 -0
  78. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  79. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/registry/__init__.py +0 -0
  80. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/registry/registry.py +0 -0
  81. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/__init__.py +0 -0
  82. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/apply.py +0 -0
  83. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  84. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  85. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  86. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  87. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/transform_args.py +0 -0
  88. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/transform_config.py +0 -0
  89. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  90. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  91. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  92. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  93. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  94. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/helpers.py +0 -0
  95. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/internal.py +0 -0
  96. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/offload.py +0 -0
  97. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/permutations_24.py +0 -0
  98. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/permute.py +0 -0
  99. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  100. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  101. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  102. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/requires.txt +0 -0
  103. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  104. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/__init__.py +0 -0
  105. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/conftest.py +0 -0
  106. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/__init__.py +0 -0
  107. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/model_compressors/__init__.py +0 -0
  108. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  109. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  110. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  111. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  112. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  113. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  114. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  115. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  116. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  117. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  118. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  119. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_configs/__init__.py +0 -0
  120. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_configs/test_base.py +0 -0
  121. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_linear/__init__.py +0 -0
  122. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_linear/test_compressed_linear.py +0 -0
  123. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/__init__.py +0 -0
  124. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/__init__.py +0 -0
  125. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/conftest.py +0 -0
  126. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  127. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  128. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  129. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  130. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  131. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_configs/__init__.py +0 -0
  132. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  133. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  134. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_quant_config.py +0 -0
  135. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_quant_scheme.py +0 -0
  136. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_registry.py +0 -0
  137. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/conftest.py +0 -0
  138. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/factory/test_correctness.py +0 -0
  139. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/factory/test_memory.py +0 -0
  140. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/test_transform_args.py +0 -0
  141. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/test_transform_config.py +0 -0
  142. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/test_transform_scheme.py +0 -0
  143. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/utils/test_hadamard.py +0 -0
  144. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/__init__.py +0 -0
  145. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/test_helpers.py +0 -0
  146. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/test_offload.py +0 -0
  147. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/test_safetensors_load.py +0 -0
  148. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/testing_utils.py +0 -0
  149. {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.10.3a20250716
+ Version: 0.10.3a20250724
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -400,7 +400,10 @@ class ModelCompressor:
 
  # in the future, support compression on same device
  with align_module_device(module, execution_device=exec_device):
- state_dict = module.state_dict(prefix=f"{prefix}.")
+ state_dict = {
+ f"{prefix}.{name}": param
+ for name, param in module.named_parameters(recurse=False)
+ }
 
  # quantization first
  if prefix in module_to_scheme:
@@ -421,7 +424,7 @@
 
  # remove any existing parameters
  offload_device = get_offloaded_device(module)
- for name, _ in list(module.named_parameters()):
+ for name, _ in list(module.named_parameters(recurse=False)):
  delete_offload_parameter(module, name)
 
  # replace with compressed parameters
@@ -458,7 +461,10 @@
  if prefix in module_to_scheme or prefix in sparse_compression_targets:
  # in the future, support decompression on same device
  with align_module_device(module, execution_device="cpu"):
- state_dict = module.state_dict(prefix=f"{prefix}.")
+ state_dict = {
+ f"{prefix}.{name}": param
+ for name, param in module.named_parameters(recurse=False)
+ }
 
  # sparsity first
  if prefix in sparse_compression_targets:
@@ -483,7 +489,7 @@
  # remove any existing parameters
  exec_device = get_execution_device(module)
  offload_device = get_offloaded_device(module)
- for name, _ in list(module.named_parameters()):
+ for name, _ in list(module.named_parameters(recurse=False)):
  delete_offload_parameter(module, name)
 
  # replace with decompressed parameters
@@ -754,8 +760,8 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
  fix_fsdp_module_name(name): module.quantization_scheme
  for name, module in model.named_modules()
  if (
- hasattr(module, "quantization_scheme") and
- module.quantization_scheme.weights is not None
+ hasattr(module, "quantization_scheme")
+ and module.quantization_scheme.weights is not None
  )
  }
 
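The hunks above replace `module.state_dict(prefix=...)` with a dict built from `named_parameters(recurse=False)`, so compression and decompression only gather (and later delete) a module's own parameters, never those of its children. A minimal sketch of that distinction; the `Parent`/`child` modules here are illustrative only, not from the package:

```python
import torch

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4, 4))
        self.child = torch.nn.Linear(4, 4)  # owns its own weight/bias

parent = Parent()

# state_dict() walks the whole subtree, so the child's tensors are included
print(sorted(parent.state_dict(prefix="layer.").keys()))
# ['layer.child.bias', 'layer.child.weight', 'layer.weight']

# named_parameters(recurse=False) yields only the module's direct parameters
print([f"layer.{name}" for name, _ in parent.named_parameters(recurse=False)])
# ['layer.weight']
```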
src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,11 +111,18 @@ def dequantize(
  elif scale.ndim == 2:
  if scale.shape[1] == 1:
  args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
- else:
+ # Scale height matches input or is 1 -> group quantization across columns
+ #
+ # Example 1: scale.shape[0] == 1
+ # x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+ #
+ # Example 2: scale.shape[0] == x_q.shape[0]
+ # x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+ elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
  group_size = int(x_q.shape[1] / scale.shape[1])
- args = QuantizationArgs(
- strategy=QuantizationStrategy.GROUP, group_size=group_size
- )
+ args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+ else:
+ args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
  else:
  raise ValueError(
  f"Could not infer a quantization strategy from scale with {scale.ndim} "
@@ -189,7 +196,63 @@ def _process_quantization(
  q_min, q_max = calculate_range(args, x.device)
  group_size = args.group_size
 
- if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+ # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+ if args.strategy == QuantizationStrategy.BLOCK:
+ original_shape = x.shape
+ rows, cols = x.shape[-2], x.shape[-1]
+ block_height, block_width = args.block_structure
+
+ # Ensure exact division (tensor dimensions must be divisible by block size)
+ if rows % block_height != 0:
+ raise ValueError(
+ f"Tensor height {rows} is not divisible by block_height {block_height}. "
+ f"Block quantization requires exact division."
+ )
+ if cols % block_width != 0:
+ raise ValueError(
+ f"Tensor width {cols} is not divisible by block_width {block_width}. "
+ f"Block quantization requires exact division."
+ )
+
+ # reshape into blocks and transpose to make each block contiguous
+ num_rows_blocks = rows // block_height
+ num_cols_blocks = cols // block_width
+ x_blocks = x.reshape(
+ num_rows_blocks,
+ block_height,
+ num_cols_blocks,
+ block_width,
+ ).transpose(1, 2)
+
+ # expand scale/zero_point for blocks
+ sb = scale.unsqueeze(-1).unsqueeze(-1)
+ zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+ if do_quantize:
+ # quantize blocks
+ x_blocks = _quantize(
+ x=x_blocks,
+ scale=sb,
+ zero_point=zb,
+ q_min=q_min,
+ q_max=q_max,
+ args=args,
+ dtype=dtype,
+ global_scale=global_scale,
+ )
+ if do_dequantize:
+ # dequantize blocks
+ x_blocks = _dequantize(
+ x_q=x_blocks,
+ scale=sb,
+ zero_point=zb,
+ global_scale=global_scale,
+ )
+ # restore original shape
+ output = x_blocks.transpose(1, 2).reshape(original_shape)
+ elif args.strategy in (
+ QuantizationStrategy.GROUP,
+ QuantizationStrategy.TENSOR_GROUP,
+ ):
  n_dims = x.shape
  if len(n_dims) > 2:
  x = x.squeeze(0)
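The new BLOCK branch tiles the last two dimensions into `block_height x block_width` blocks and broadcasts one scale per block. A standalone sketch of the same reshape/transpose bookkeeping, with illustrative shapes and a plain `torch.round` standing in for the package's quantize kernel:

```python
import torch

x = torch.randn(256, 512)                            # e.g. a weight matrix
block_h, block_w = 128, 128
scale = torch.rand(256 // block_h, 512 // block_w)   # one scale per block: (2, 4)

# tile into blocks: (num_row_blocks, num_col_blocks, block_h, block_w)
blocks = x.reshape(256 // block_h, block_h, 512 // block_w, block_w).transpose(1, 2)

# unsqueeze the scale so each (block_h, block_w) tile divides by its own scalar
x_q = torch.round(blocks / scale.unsqueeze(-1).unsqueeze(-1))

# undo the tiling to recover the original layout
restored = x_q.transpose(1, 2).reshape(256, 512)
print(blocks.shape, restored.shape)  # torch.Size([2, 4, 128, 128]) torch.Size([256, 512])
```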
src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -15,6 +15,7 @@
 
  import logging
  import math
+ import warnings
  from enum import Enum
  from typing import List, Optional
 
@@ -172,14 +173,41 @@ def _initialize_scale_zero_point(
 
  if base_name == "weight" and weight_shape is not None:
  if quantization_args.strategy == QuantizationStrategy.CHANNEL:
- # (output_channels, 1)
+ # (output_channels, 1) - only for weights
  expected_shape = (weight_shape[0], 1)
  elif quantization_args.strategy in (
  QuantizationStrategy.TENSOR_GROUP,
  QuantizationStrategy.GROUP,
  ):
+ # GROUP/TENSOR_GROUP for both weights and activations
  num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
  expected_shape = (weight_shape[0], max(num_groups, 1))
+ elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+ # For block quantization, scale shape should match number of blocks - only for weights
+ if quantization_args.block_structure is None:
+ raise ValueError("Block quantization requires block_structure to be specified")
+ block_height, block_width = quantization_args.block_structure
+ rows, cols = weight_shape[-2], weight_shape[-1]
+ num_rows_blocks = math.ceil(rows / block_height)
+ num_cols_blocks = math.ceil(cols / block_width)
+
+ # Warn if dimensions don't divide evenly
+ if rows % block_height != 0 or cols % block_width != 0:
+ warnings.warn(
+ f"Block quantization: tensor shape {weight_shape} does not divide evenly "
+ f"by block structure {quantization_args.block_structure}. "
+ f"Some blocks will be incomplete which may affect quantization quality.",
+ UserWarning
+ )
+
+ expected_shape = (num_rows_blocks, num_cols_blocks)
+ elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+ warnings.warn(
+ f"BLOCK quantization not supported for {base_name} activations. "
+ f"Falling back to tensor-level quantization.",
+ UserWarning
+ )
+ expected_shape = 1
 
  # 3. Identify quantization scale and zp dtype
  scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
@@ -189,7 +217,12 @@ def _initialize_scale_zero_point(
 
  else:
  # TODO: consider erroring out in the future as if the dtype if not one of these,
  # there is likely bug
- if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
+ if scale_dtype not in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ torch.float64,
+ ]:
  scale_dtype = torch.float16
  zp_dtype = quantization_args.pytorch_dtype()
src/compressed_tensors/quantization/quant_args.py
@@ -14,7 +14,7 @@
 
  import warnings
  from enum import Enum
- from typing import Any, Dict, Optional, Union
+ from typing import Any, Dict, List, Optional, Union
 
  import torch
  from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
  :param symmetric: whether or not quantization scale is symmetric about zero-point
  :param strategy: string id determining the scope of scale/zero-point to apply
  :param group_size: group length to use for the group strategy
- :param block_structure: 2d block structure to use for the block strategy, must be
- of the format "2x4", "8x16", etc.
+ :param block_structure: 2d block structure to use for the block strategy; must be
+ a list of two ints [rows, cols] like [128, 128].
  :param dynamic: set True to perform dynamic quantization - values will not be
  calibrated during calibration phase, instead during inference new quantization
  ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
  symmetric: bool = True
  group_size: Optional[int] = None
  strategy: Optional[QuantizationStrategy] = None
- block_structure: Optional[str] = None
+ block_structure: Optional[List[int]] = None
  dynamic: Union[DynamicType, bool] = False
  actorder: Union[ActivationOrdering, bool, None] = None
  observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
  return value
 
+ @field_validator("block_structure", mode="before")
+ def validate_block_structure(cls, value) -> Optional[List[int]]:
+ if value is None:
+ return value
+ # For backward compatibility, allow string format "2x4", "8x16", etc.
+ if isinstance(value, str):
+ try:
+ return [int(x) for x in value.split("x")]
+ except Exception:
+ raise ValueError(
+ f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+ )
+ if isinstance(value, (list, tuple)):
+ if len(value) != 2 or not all(isinstance(v, int) for v in value):
+ raise ValueError(
+ f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+ )
+ return list(value)
+ raise ValueError(
+ f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+ )
+
  @field_validator("strategy", mode="before")
  def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
  if isinstance(value, str):
@@ -277,14 +299,15 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
  # infer observer w.r.t. dynamic
  if dynamic:
- if strategy not in (
+ supported_strategies = (
  QuantizationStrategy.TOKEN,
  QuantizationStrategy.TENSOR,
  QuantizationStrategy.TENSOR_GROUP,
- ):
+ QuantizationStrategy.GROUP,
+ )
+ if strategy not in supported_strategies:
  raise ValueError(
- f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
- "must be used for dynamic quantization",
+ f"One of {supported_strategies} must be used for dynamic quantization"
  )
 
  if (
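With the `validate_block_structure` validator above, `block_structure` accepts either the legacy "RxC" string or a two-int list, both normalizing to a list. A hedged usage sketch, assuming the usual public re-exports from `compressed_tensors.quantization` and mirroring the FP8 block settings used elsewhere in this release:

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

# legacy string form is still accepted and normalized to [128, 128]
legacy = QuantizationArgs(
    num_bits=8,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.BLOCK,
    block_structure="128x128",
)

# new canonical form: a list of two ints [rows, cols]
canonical = QuantizationArgs(
    num_bits=8,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.BLOCK,
    block_structure=[128, 128],
)

assert legacy.block_structure == canonical.block_structure == [128, 128]
```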
src/compressed_tensors/quantization/quant_scheme.py
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ import warnings
  from copy import deepcopy
  from typing import Any, Dict, List, Optional
 
@@ -52,6 +53,7 @@ class QuantizationScheme(BaseModel):
  def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
  inputs = model.input_activations
  outputs = model.output_activations
+ weights = model.weights
 
  if inputs is not None:
  if inputs.actorder is not None:
@@ -61,6 +63,21 @@ class QuantizationScheme(BaseModel):
  if outputs.actorder is not None:
  raise ValueError("Cannot apply actorder to output activations")
 
+ if (
+ inputs and weights
+ and weights.strategy == QuantizationStrategy.GROUP
+ and inputs.strategy == QuantizationStrategy.GROUP
+ and weights.group_size != inputs.group_size
+ ):
+ warnings.warn(
+ "Using GROUP strategy for both weights and input_activations "
+ f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
+ "may complicate fused kernel implementations. Consider using "
+ "TENSOR_GROUP strategy for both or matching group sizes.",
+ UserWarning,
+ stacklevel=2
+ )
+
  return model
 
 
@@ -243,6 +260,29 @@ FP8_DYNAMIC = dict(
  ),
  )
 
+ # Block‐wise FP8 (deepseekv3-style quantization):
+ # static 128x128 per‐block weights and
+ # dynamic per‐token‐group activations
+ FP8_BLOCK = dict(
+ weights=QuantizationArgs(
+ num_bits=8,
+ type=QuantizationType.FLOAT,
+ strategy=QuantizationStrategy.BLOCK,
+ symmetric=True,
+ dynamic=False,
+ block_structure=[128, 128],
+ ),
+ input_activations=QuantizationArgs(
+ num_bits=8,
+ type=QuantizationType.FLOAT,
+ strategy=QuantizationStrategy.GROUP,
+ symmetric=True,
+ dynamic=True,
+ observer=None,
+ group_size=128,
+ ),
+ )
+
  PRESET_SCHEMES = {
  # Unquantized (no-op)
  "UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +297,7 @@ PRESET_SCHEMES = {
  # Float weight and activation schemes
  "FP8": FP8,
  "FP8_DYNAMIC": FP8_DYNAMIC,
+ "FP8_BLOCK": FP8_BLOCK,
  "NVFP4A16": NVFP4A16,
  "NVFP4": NVFP4,
  }
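The new FP8_BLOCK preset can be resolved by name like the existing presets. A sketch using the existing `preset_name_to_scheme` helper; the `targets` value is illustrative:

```python
from compressed_tensors.quantization import preset_name_to_scheme

# "FP8_BLOCK": static 128x128 FP8 weight blocks plus dynamic
# per-token-group (group_size=128) FP8 input activations
scheme = preset_name_to_scheme("FP8_BLOCK", targets=["Linear"])
print(scheme.weights.block_structure)       # [128, 128]
print(scheme.input_activations.group_size)  # 128
```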
src/compressed_tensors/quantization/utils/helpers.py
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
  reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
  elif args.strategy == QuantizationStrategy.TENSOR:
  reduce_dims = None
- elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+ elif args.strategy in (
+ QuantizationStrategy.TENSOR_GROUP,
+ QuantizationStrategy.GROUP,
+ ):
  if len(value.shape) > 2:
  value = value.squeeze(0)
 
@@ -187,9 +190,15 @@
  ),
  )
  else:
+ supported_strategies = (
+ QuantizationStrategy.TOKEN,
+ QuantizationStrategy.TENSOR,
+ QuantizationStrategy.TENSOR_GROUP,
+ QuantizationStrategy.GROUP,
+ )
  raise ValueError(
  "Dynamic quantization is only supported for ",
- f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
+ f"{supported_strategies}",
  )
 
  if not reduce_dims:
src/compressed_tensors/transform/factory/base.py
@@ -18,7 +18,6 @@ from typing import Optional
  import torch
  import torch.nn.utils.parametrize as P
  from compressed_tensors import InternalModule
- from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
  from compressed_tensors.registry.registry import RegistryMixin, T
  from compressed_tensors.transform import (
  TransformArgs,
@@ -29,6 +28,7 @@ from compressed_tensors.utils import (
  align_module_device,
  delete_offload_module,
  has_offloaded_params,
+ match_named_modules,
  patch_attr,
  register_offload_module,
  update_offload_parameter,
@@ -87,9 +87,8 @@ class TransformFactory(RegistryMixin, ABC):
  :param model: module to apply transforms to
  """
  for arg in self.scheme.apply:
- for name, module in list(model.named_modules()):
- if is_target(name, module, arg.targets, arg.ignore):
- self._apply_to_module(module, arg)
+ for _, module in match_named_modules(model, arg.targets, arg.ignore):
+ self._apply_to_module(module, arg)
 
  def _apply_to_module(self, module: Module, args: TransformArgs):
  """
src/compressed_tensors/utils/__init__.py
@@ -15,6 +15,7 @@
 
  from .helpers import *
  from .internal import *
+ from .match import *
  from .offload import *
  from .permutations_24 import *
  from .permute import *
src/compressed_tensors/utils/match.py (new file)
@@ -0,0 +1,191 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import re
+ from collections.abc import Generator
+ from typing import Iterable, Tuple
+
+ import torch
+
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ __all__ = [
+ "match_named_modules",
+ "match_named_parameters",
+ "match_modules_set",
+ "is_match",
+ "match_name",
+ "match_class",
+ ]
+
+
+ def match_named_modules(
+ model: torch.nn.Module,
+ targets: Iterable[str],
+ ignore: Iterable[str] = tuple(),
+ warn_on_fail: bool = False,
+ ) -> Generator[Tuple[str, torch.nn.Module]]:
+ """
+ Yields names and modules which match `targets` but do not match `ignore`.
+ Values are returned in order of `model.named_modules()`
+
+ :param model: model containing submodules to match against
+ :param targets: target strings, potentially containing "re:" prefixes
+ :param ignore: targets to ignore, potentially containing "re:" prefixes
+ :param warn_on_fail: if True, warns if any targets do not match any modules in model
+ :return: generator of module names and modules
+ """
+ unmatched_targets = set(targets)
+ for name, module in model.named_modules():
+ for target in targets:
+ if is_match(name, module, target):
+ unmatched_targets -= {target}
+
+ if not any(is_match(name, module, ign) for ign in ignore):
+ yield name, module
+
+ if warn_on_fail:
+ for target in unmatched_targets:
+ _LOGGER.warning(
+ f"Could not match `{target}` in instance of {model.__class__.__name__}"
+ )
+
+
+ def match_named_parameters(
+ model: torch.nn.Module,
+ targets: Iterable[str],
+ ignore: Iterable[str] = tuple(),
+ warn_on_fail: bool = False,
+ ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
+ """
+ Yields parameters which match `targets` but do not match `ignore`.
+ Values are returned in order of `model.named_modules()`
+
+ :param model: model containing params to match against
+ :param targets: target strings, potentially containing "re:" prefixes
+ :param ignore: targets to ignore, potentially containing "re:" prefixes
+ :param warn_on_fail: if True, warns if any targets do not match any params in model
+ :return: generator of fully-qualified param names, parent modules, and params
+ """
+ unmatched_targets = set(targets)
+ for module_name, module in model.named_modules():
+ for param_name, param in module.named_parameters(recurse=False):
+ param_fqn = f"{module_name}.{param_name}"
+ for target in targets:
+ if match_name(param_fqn, target):
+ unmatched_targets -= {target}
+
+ if not any(match_name(param_fqn, ign) for ign in ignore):
+ yield param_fqn, module, param
+
+ if warn_on_fail:
+ for target in unmatched_targets:
+ _LOGGER.warning(
+ f"Could not match `{target}` in instance of {model.__class__.__name__}"
+ )
+
+
+ def match_modules_set(
+ model: torch.nn.Module,
+ targets: Iterable[str],
+ ignore: Iterable[str] = tuple(),
+ ) -> Generator[Iterable[torch.nn.Module]]:
+ """
+ Yields modules grouped with the same order and size as `targets`.
+ Values are returned in order of `model.named_modules()`
+
+ For example, the following targets would yield module belonging to the following layers:
+ ```python3
+ match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
+ (
+ `model.layers.0.self_attn.q_proj`,
+ `model.layers.0.self_attn.k_proj`,
+ `model.layers.0.self_attn.v_proj`,
+ ),
+ (
+ `model.layers.1.self_attn.q_proj`,
+ `model.layers.1.self_attn.k_proj`,
+ `model.layers.1.self_attn.v_proj`,
+ ),
+ ...
+ (
+ `model.layers.32.self_attn.q_proj`,
+ `model.layers.32.self_attn.k_proj`,
+ `model.layers.32.self_attn.v_proj`,
+ ),
+ )
+ ```
+
+ This can be used to match layers to their corresponding downstream counterparts.
+ For example, matching layer norms to their subsequent linear layers
+ ```python3
+ for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
+ fuse_norm_linears(norm, [q, k, v])
+
+ :param model: model containing modules to match against
+ :param targets: target strings, potentially containing "re:" prefixes
+ :param ignore: targets to ignore, potentially containing "re:" prefixes
+ """
+ matches = dict.fromkeys(targets, None)
+ for name, module in model.named_modules():
+ # match until we get a full set
+ for target in targets:
+ if is_match(name, module, target) and not any(
+ is_match(name, module, ign) for ign in ignore
+ ):
+ if matches[target] is not None:
+ raise ValueError(f"Matched a {target} twice before completing set")
+ matches[target] = module
+
+ # once we have a full set, yield and reset
+ if targets and all((matches[target] is not None for target in targets)):
+ yield [matches[target] for target in targets]  # ensure correct ordering
+ matches = dict.fromkeys(targets, None)
+
+ # check that none are left over
+ unmatched_keys = [match for match, value in matches.items() if value is not None]
+ if len(unmatched_keys):
+ raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
+
+
+ def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+ """
+ Returns true if either module name or module parent classes match against target
+ """
+ return match_name(name, target) or match_class(module, target)
+
+
+ def match_name(name: str, target: str) -> bool:
+ """
+ Returns true if target string begins with "re:" and
+ regex matches or if target string exactly matches name
+ """
+ if target.startswith("re:"):
+ return re.match(target.removeprefix("re:"), name) is not None
+ else:
+ return target == name
+
+
+ def match_class(module: torch.nn.Module, target: str) -> bool:
+ """
+ Returns true if any torch parent class names match the target string exactly
+ """
+ # will never match against a regex pattern since `:` is not allowed in class names
+ return any(
+ issubclass(cls, torch.nn.Module) and cls.__name__ == target
+ for cls in module.__class__.__mro__
+ )
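A short usage sketch for the new matching utilities exported from `compressed_tensors.utils`; the toy model and target strings are illustrative only:

```python
import torch
from compressed_tensors.utils import match_named_modules, match_name

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

# plain strings match module names exactly, "re:" prefixes are treated as regex,
# and class names ("Linear") match against the module's type
for name, module in match_named_modules(model, targets=["Linear"], ignore=["2"]):
    print(name, type(module).__name__)  # "0 Linear" (submodule "2" is ignored)

print(match_name("model.layers.0.mlp", "re:.*\\.mlp"))  # True
```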
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.10.3.a20250716'
+ __version__ = version = '0.10.3.a20250724'
  __version_tuple__ = version_tuple = (0, 10, 3)
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.10.3a20250716
+ Version: 0.10.3a20250724
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
src/compressed_tensors.egg-info/SOURCES.txt
@@ -88,6 +88,7 @@ src/compressed_tensors/transform/utils/matrix.py
  src/compressed_tensors/utils/__init__.py
  src/compressed_tensors/utils/helpers.py
  src/compressed_tensors/utils/internal.py
+ src/compressed_tensors/utils/match.py
  src/compressed_tensors/utils/offload.py
  src/compressed_tensors/utils/permutations_24.py
  src/compressed_tensors/utils/permute.py
@@ -141,6 +142,7 @@ tests/test_transform/factory/test_memory.py
  tests/test_transform/utils/test_hadamard.py
  tests/test_utils/__init__.py
  tests/test_utils/test_helpers.py
+ tests/test_utils/test_match.py
  tests/test_utils/test_offload.py
  tests/test_utils/test_safetensors_load.py
  utils/copyright.py
tests/test_examples/test_bitmask_compression_ipynb.py
@@ -12,8 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import nbformat
  import pytest
+
+
+ nbformat = pytest.importorskip("nbformat")
  from nbconvert.preprocessors import ExecutePreprocessor
 