fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. fusion_bench/constants/__init__.py +5 -1
  2. fusion_bench/constants/runtime.py +111 -7
  3. fusion_bench/dataset/gsm8k.py +6 -2
  4. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  5. fusion_bench/method/__init__.py +10 -2
  6. fusion_bench/method/base_algorithm.py +29 -19
  7. fusion_bench/method/classification/image_classification_finetune.py +1 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  10. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  11. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  12. fusion_bench/metrics/model_kinship/utility.py +184 -0
  13. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  14. fusion_bench/metrics/nyuv2/depth.py +30 -0
  15. fusion_bench/metrics/nyuv2/loss.py +40 -0
  16. fusion_bench/metrics/nyuv2/noise.py +24 -0
  17. fusion_bench/metrics/nyuv2/normal.py +34 -1
  18. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  19. fusion_bench/mixins/clip_classification.py +30 -2
  20. fusion_bench/mixins/lightning_fabric.py +46 -5
  21. fusion_bench/mixins/rich_live.py +76 -0
  22. fusion_bench/modelpool/base_pool.py +86 -5
  23. fusion_bench/models/masks/mask_model.py +8 -2
  24. fusion_bench/models/open_clip/modeling.py +7 -0
  25. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  26. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  27. fusion_bench/scripts/cli.py +14 -0
  28. fusion_bench/scripts/webui.py +250 -17
  29. fusion_bench/utils/__init__.py +14 -0
  30. fusion_bench/utils/data.py +100 -9
  31. fusion_bench/utils/devices.py +3 -1
  32. fusion_bench/utils/fabric.py +185 -4
  33. fusion_bench/utils/instantiate_utils.py +29 -18
  34. fusion_bench/utils/json.py +6 -0
  35. fusion_bench/utils/misc.py +16 -0
  36. fusion_bench/utils/rich_utils.py +123 -6
  37. fusion_bench/utils/validation.py +197 -0
  38. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
  39. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
  40. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  41. fusion_bench_config/llama_full_finetune.yaml +4 -16
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  43. fusion_bench_config/nyuv2_config.yaml +4 -13
  44. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  45. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  46. fusion_bench/utils/auto.py +0 -31
  47. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,6 @@
1
1
  defaults:
2
- - hydra: default
3
- - fabric: llama_fsdp
4
- - path: default
5
- # --- Model, Method, Task ---
6
- - method: lm_finetune/fullfinetune_sft.yaml
7
- - modelpool: CausalLMPool/llama_alpaca_cleaned.yaml
8
- - taskpool: dummy
2
+ - fabric_model_fusion
3
+ - override fabric: llama_fsdp
4
+ - override method: lm_finetune/fullfinetune_sft.yaml
5
+ - override modelpool: CausalLMPool/llama_alpaca_cleaned.yaml
9
6
  - _self_
10
- _target_: fusion_bench.programs.FabricModelFusionProgram
11
- _recursive_: false
12
- fast_dev_run: false # Run a single batch of data to test the model or method
13
- # Run the script without actually running the experiment, use with `print_config=true`.
14
- # You can also use `--cfg` or `-c` to show the configuration instead of running.
15
- dry_run: false
16
- print_config: true # Print the configuration to the console
17
- report_save_path: null # path to save the result report
18
- print_function_call: true # set to false if you don't want to print the details of instantiate calls
@@ -6,7 +6,7 @@ defaults:
6
6
  - clip-vit-base-patch32_eurosat
7
7
  - clip-vit-base-patch32_resisc45
8
8
  - clip-vit-base-patch32_gtsrb
9
- # `corrption` can be one of:
9
+ # `corruption` can be one of:
10
10
  # contrast, gaussian_noise, impulse_noise, jpeg_compression, motion_blur, pixelate, spatter
11
11
  corruption: ${corruption}
12
12
  # The following datasets are used for test-time adaptation
@@ -1,17 +1,8 @@
1
1
  defaults:
2
- - hydra: default
3
- - fabric: auto
4
- - path: default
5
- # --- Model, Method, Task ---
6
- - method: simple_average
7
- - modelpool: nyuv2_modelpool
8
- - taskpool: nyuv2_taskpool
2
+ - fabric_model_fusion
3
+ - override method: simple_average
4
+ - override modelpool: nyuv2_modelpool
5
+ - override taskpool: nyuv2_taskpool
9
6
  - _self_
10
- _target_: fusion_bench.programs.FabricModelFusionProgram
11
- _recursive_: false
12
- fast_dev_run: false # Run a single batch of data to test the model or method
13
- use_lightning: true # Use the fabric to run the experiment
14
- print_config: true # Print the configuration to the console
15
- save_report: false # path to save the result report
16
7
  trainer:
17
8
  devices: 1
@@ -1,6 +1,6 @@
1
1
  type: clip_vit_classification
2
2
  name: clip-vit-robustness_clean
3
- # corrption can be one of:
3
+ # corruption can be one of:
4
4
  # contrast, gaussian_noise, impulse_noise, jpeg_compression, motion_blur, pixelate, spatter
5
5
  corruption: ${corruption}
6
6
  dataset_type: huggingface_image_classification
@@ -1,6 +1,6 @@
1
1
  type: clip_vit_classification
2
2
  name: clip-vit-robustness_clean
3
- # corrption can be one of:
3
+ # corruption can be one of:
4
4
  # contrast, gaussian_noise, impulse_noise, jpeg_compression, motion_blur, pixelate, spatter
5
5
  corruption: ${corruption}
6
6
  dataset_type: huggingface_image_classification
@@ -1,31 +0,0 @@
1
- from omegaconf import DictConfig
2
-
3
- from fusion_bench.utils import import_object
4
-
5
-
6
- class BaseFactoryClass:
7
- _registry = {}
8
-
9
- @classmethod
10
- def from_config(cls, config: DictConfig):
11
- name = config.name
12
- if name not in cls._registry:
13
- raise ValueError(
14
- f"Unknown name: {name}, available names: {cls._registry.keys()}. "
15
- f"You can register a new item using `{cls.__name__}.register()` method."
16
- )
17
-
18
- item_cls = cls._registry[name]
19
- if isinstance(item_cls, str):
20
- if item_cls.startswith("."):
21
- item_cls = f"{cls.__module__}.{item_cls[1:]}"
22
- item_cls = import_object(item_cls)
23
- return item_cls(config)
24
-
25
- @classmethod
26
- def register(cls, name: str, item_cls):
27
- cls._registry[name] = item_cls
28
-
29
- @classmethod
30
- def available_items(cls):
31
- return list(cls._registry.keys())