fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (75)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +12 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/clip_finetune.py +6 -4
  7. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  8. fusion_bench/method/dop/__init__.py +1 -0
  9. fusion_bench/method/dop/dop.py +366 -0
  10. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  11. fusion_bench/method/dop/utils.py +73 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  15. fusion_bench/mixins/__init__.py +2 -0
  16. fusion_bench/mixins/pyinstrument.py +174 -0
  17. fusion_bench/mixins/simple_profiler.py +106 -23
  18. fusion_bench/modelpool/__init__.py +2 -0
  19. fusion_bench/modelpool/base_pool.py +77 -14
  20. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  21. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  22. fusion_bench/models/__init__.py +35 -9
  23. fusion_bench/optim/__init__.py +40 -2
  24. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  25. fusion_bench/optim/muon.py +339 -0
  26. fusion_bench/programs/__init__.py +2 -0
  27. fusion_bench/programs/fabric_fusion_program.py +2 -2
  28. fusion_bench/programs/fusion_program.py +271 -0
  29. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  30. fusion_bench/utils/__init__.py +167 -21
  31. fusion_bench/utils/lazy_imports.py +91 -12
  32. fusion_bench/utils/lazy_state_dict.py +55 -5
  33. fusion_bench/utils/misc.py +104 -13
  34. fusion_bench/utils/packages.py +4 -0
  35. fusion_bench/utils/path.py +7 -0
  36. fusion_bench/utils/pylogger.py +6 -0
  37. fusion_bench/utils/rich_utils.py +1 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  39. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
  40. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
  41. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  42. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  43. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  44. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  45. fusion_bench_config/method/dop/dop.yaml +30 -0
  46. fusion_bench_config/method/dummy.yaml +6 -0
  47. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  48. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  49. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  50. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  51. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  52. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
  53. fusion_bench_config/method/model_recombination.yaml +8 -0
  54. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  55. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  56. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  57. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  58. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  59. fusion_bench_config/method/simple_average.yaml +9 -0
  60. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  61. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  62. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  63. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  64. fusion_bench_config/method/ties_merging.yaml +3 -0
  65. fusion_bench_config/model_fusion.yaml +45 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  72. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
  73. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
  74. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
  75. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -5,46 +5,156 @@
 # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
 # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
 # flake8: noqa: F401
-from . import (
-    constants,
-    dataset,
-    method,
-    metrics,
-    mixins,
-    modelpool,
-    models,
-    optim,
-    programs,
-    taskpool,
-    tasks,
-    utils,
-)
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import constants, metrics, optim, tasks
 from .constants import RuntimeConstants
-from .method import BaseAlgorithm, BaseModelFusionAlgorithm
-from .mixins import auto_register_config
-from .modelpool import BaseModelPool
-from .models import (
-    create_default_model_card,
-    load_model_card_template,
-    save_pretrained_with_remote_code,
-    separate_io,
-)
-from .programs import BaseHydraProgram
-from .taskpool import BaseTaskPool
-from .utils import (
-    BoolStateDictType,
-    LazyStateDict,
-    StateDictType,
-    TorchModelType,
-    cache_with_joblib,
-    get_rankzero_logger,
-    import_object,
-    instantiate,
-    parse_dtype,
-    print_parameters,
-    seed_everything_by_time,
-    set_default_cache_dir,
-    set_print_function_call,
-    set_print_function_call_permeanent,
-    timeit_context,
-)
+from .method import _available_algorithms
+
+_extra_objects = {
+    "RuntimeConstants": RuntimeConstants,
+    "constants": constants,
+    "metrics": metrics,
+    "optim": optim,
+    "tasks": tasks,
+}
+_import_structure = {
+    "dataset": ["CLIPDataset"],
+    "method": _available_algorithms,
+    "mixins": [
+        "CLIPClassificationMixin",
+        "FabricTrainingMixin",
+        "HydraConfigMixin",
+        "LightningFabricMixin",
+        "OpenCLIPClassificationMixin",
+        "PyinstrumentProfilerMixin",
+        "SimpleProfilerMixin",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
+    "modelpool": [
+        "AutoModelPool",
+        "BaseModelPool",
+        "CausalLMBackbonePool",
+        "CausalLMPool",
+        "CLIPVisionModelPool",
+        "GPT2ForSequenceClassificationPool",
+        "HuggingFaceGPT2ClassificationPool",
+        "NYUv2ModelPool",
+        "OpenCLIPVisionModelPool",
+        "PeftModelForSeq2SeqLMPool",
+        "ResNetForImageClassificationPool",
+        "Seq2SeqLMPool",
+        "SequenceClassificationModelPool",
+    ],
+    "models": [
+        "create_default_model_card",
+        "load_model_card_template",
+        "save_pretrained_with_remote_code",
+        "separate_load",
+        "separate_save",
+    ],
+    "programs": ["BaseHydraProgram", "FabricModelFusionProgram"],
+    "taskpool": [
+        "BaseTaskPool",
+        "CLIPVisionModelTaskPool",
+        "DummyTaskPool",
+        "GPT2TextClassificationTaskPool",
+        "LMEvalHarnessTaskPool",
+        "OpenCLIPVisionModelTaskPool",
+        "NYUv2TaskPool",
+    ],
+    "utils": [
+        "ArithmeticStateDict",
+        "BoolStateDictType",
+        "LazyStateDict",
+        "StateDictType",
+        "TorchModelType",
+        "cache_with_joblib",
+        "get_rankzero_logger",
+        "import_object",
+        "instantiate",
+        "parse_dtype",
+        "print_parameters",
+        "seed_everything_by_time",
+        "set_default_cache_dir",
+        "set_print_function_call",
+        "set_print_function_call_permeanent",
+        "timeit_context",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .dataset import CLIPDataset
+    from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .mixins import (
+        CLIPClassificationMixin,
+        FabricTrainingMixin,
+        HydraConfigMixin,
+        LightningFabricMixin,
+        OpenCLIPClassificationMixin,
+        PyinstrumentProfilerMixin,
+        SimpleProfilerMixin,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
+    from .modelpool import (
+        AutoModelPool,
+        BaseModelPool,
+        CausalLMBackbonePool,
+        CausalLMPool,
+        CLIPVisionModelPool,
+        GPT2ForSequenceClassificationPool,
+        HuggingFaceGPT2ClassificationPool,
+        NYUv2ModelPool,
+        OpenCLIPVisionModelPool,
+        PeftModelForSeq2SeqLMPool,
+        ResNetForImageClassificationPool,
+        Seq2SeqLMPool,
+        SequenceClassificationModelPool,
+    )
+    from .models import (
+        create_default_model_card,
+        load_model_card_template,
+        save_pretrained_with_remote_code,
+        separate_load,
+        separate_save,
+    )
+    from .programs import BaseHydraProgram, FabricModelFusionProgram
+    from .taskpool import (
+        BaseTaskPool,
+        CLIPVisionModelTaskPool,
+        DummyTaskPool,
+        GPT2TextClassificationTaskPool,
+        LMEvalHarnessTaskPool,
+        NYUv2TaskPool,
+        OpenCLIPVisionModelTaskPool,
+    )
+    from .utils import (
+        ArithmeticStateDict,
+        BoolStateDictType,
+        LazyStateDict,
+        StateDictType,
+        TorchModelType,
+        cache_with_joblib,
+        get_rankzero_logger,
+        import_object,
+        instantiate,
+        parse_dtype,
+        print_parameters,
+        seed_everything_by_time,
+        set_default_cache_dir,
+        set_print_function_call,
+        set_print_function_call_permeanent,
+        timeit_context,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
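The new top-level `__init__.py` defers every heavy submodule import until a name is first accessed, while the `if TYPE_CHECKING:` branch keeps the eager imports visible to static type checkers and IDEs; at runtime the module object is swapped for a lazy proxy via `sys.modules[__name__] = LazyImporter(...)`. Below is a minimal sketch of this PEP 562-style pattern; it approximates, but is not, fusion_bench's actual `LazyImporter` from `fusion_bench/utils/lazy_imports.py`:

```python
# Minimal sketch of the lazy-import pattern used above (approximation only).
import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Resolve exported names on first access instead of at import time."""

    def __init__(self, name, import_structure, extra_objects=None):
        super().__init__(name)
        # Invert {submodule: [names]} into {name: submodule} for O(1) lookup.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }
        self._extra_objects = extra_objects or {}
        self.__all__ = list(self._name_to_module) + list(self._extra_objects)

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, i.e. on first access.
        if name in self._extra_objects:
            return self._extra_objects[name]
        if name in self._name_to_module:
            submodule = importlib.import_module(
                f"{self.__name__}.{self._name_to_module[name]}"
            )
            value = getattr(submodule, name)
            setattr(self, name, value)  # cache so later accesses skip __getattr__
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
```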
fusion_bench/dataset/__init__.py CHANGED
@@ -1,16 +1,20 @@
 # flake8: noqa F401
-from datasets import load_dataset
-from omegaconf import DictConfig, open_dict
+import sys
+from typing import TYPE_CHECKING
 
-from fusion_bench.utils import instantiate
+from omegaconf import DictConfig, open_dict
 
-from .clip_dataset import CLIPDataset
+from fusion_bench.utils.lazy_imports import LazyImporter
 
 
 def load_dataset_from_config(dataset_config: DictConfig):
     """
     Load the dataset from the configuration.
     """
+    from datasets import load_dataset
+
+    from fusion_bench.utils import instantiate
+
     assert hasattr(dataset_config, "type"), "Dataset type not specified"
     if dataset_config.type == "instantiate":
         return instantiate(dataset_config.object)
@@ -27,3 +31,22 @@ def load_dataset_from_config(dataset_config: DictConfig):
         return dataset
     else:
         raise ValueError(f"Unknown dataset type: {dataset_config.type}")
+
+
+_extra_objects = {
+    "load_dataset_from_config": load_dataset_from_config,
+}
+_import_structure = {
+    "clip_dataset": ["CLIPDataset"],
+}
+
+if TYPE_CHECKING:
+    from .clip_dataset import CLIPDataset
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
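Moving `load_dataset` and `instantiate` into the function body means importing `fusion_bench.dataset` no longer pulls in HuggingFace `datasets` eagerly; the cost is paid only on the first call. A hypothetical config exercising the `instantiate` branch shown above, assuming fusion_bench's `instantiate` resolves Hydra-style `_target_` specs (the target class and its arguments are illustrative, not a config shipped with the package):

```python
from omegaconf import DictConfig

from fusion_bench.dataset import load_dataset_from_config

# Hypothetical config for the "instantiate" branch; _target_ is illustrative.
cfg = DictConfig(
    {
        "type": "instantiate",
        "object": {
            "_target_": "torchvision.datasets.CIFAR10",
            "root": "./data",
            "train": True,
            "download": True,
        },
    }
)
dataset = load_dataset_from_config(cfg)  # builds the dataset on demand
```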
fusion_bench/dataset/clip_dataset.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch.utils.data import Dataset
-from transformers import CLIPProcessor, ProcessorMixin
+from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
 
@@ -60,7 +60,7 @@ class CLIPDataset(torch.utils.data.Dataset):
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
         image = item["image"]
         if self.processor is not None:
-            if isinstance(self.processor, ProcessorMixin):
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
                 # Apply the processor to the image to get the input tensor
                 inputs = self.processor(images=[image], return_tensors="pt")[
                     "pixel_values"
fusion_bench/method/__init__.py CHANGED
@@ -2,6 +2,7 @@
 import sys
 from typing import TYPE_CHECKING
 
+from fusion_bench.utils import join_lists
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
@@ -12,6 +13,8 @@ _import_structure = {
     "classification": [
         "ImageClassificationFineTuningForCLIP",
         "ContinualImageClassificationFineTuningForCLIP",
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
     ],
     "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
     # analysis
@@ -67,6 +70,7 @@ _import_structure = {
         "IsotropicMergingInCommonSubspace",
     ],
     "opcm": ["OPCMForCLIP"],
+    "dop": ["ContinualDOPForCLIP"],
     "gossip": [
         "CLIPLayerWiseGossipAlgorithm",
         "CLIPTaskWiseGossipAlgorithm",
@@ -131,7 +135,10 @@
         "ProgressivePruningForMixtral",
     ],
 }
-
+_available_algorithms = join_lists(list(_import_structure.values()))
+_extra_objects = {
+    "_available_algorithms": _available_algorithms,
+}
 
 if TYPE_CHECKING:
     from .ada_svd import AdaSVDMergingForCLIPVisionModel
@@ -141,6 +148,8 @@ if TYPE_CHECKING:
     from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
         ImageClassificationFineTuningForCLIP,
     )
     from .concrete_subspace import (
@@ -204,6 +213,7 @@ if TYPE_CHECKING:
     from .model_recombination import ModelRecombinationAlgorithm
     from .model_stock import ModelStock
     from .opcm import OPCMForCLIP
+    from .dop import ContinualDOPForCLIP
     from .pruning import (
         MagnitudeDiffPruningAlgorithm,
         MagnitudePruningForLlama,
@@ -252,4 +262,5 @@ else:
         __name__,
         globals()["__file__"],
         _import_structure,
+        extra_objects=_extra_objects,
     )
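`_available_algorithms` flattens the per-submodule export lists into a single flat list of algorithm class names, which the top-level `__init__.py` then re-exports under the `"method"` key of its own `_import_structure`. A minimal stand-in for `join_lists`, assuming a list-of-lists input:

```python
# Minimal equivalent of fusion_bench.utils.join_lists (assumed behavior).
def join_lists(lists):
    return [item for sublist in lists for item in sublist]


_import_structure = {"opcm": ["OPCMForCLIP"], "dop": ["ContinualDOPForCLIP"]}
_available_algorithms = join_lists(list(_import_structure.values()))
# -> ["OPCMForCLIP", "ContinualDOPForCLIP"]
```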
fusion_bench/method/classification/__init__.py CHANGED
@@ -1,3 +1,28 @@
 # flake8: noqa F401
-from .clip_finetune import ImageClassificationFineTuningForCLIP
-from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+_import_structure = {
+    "clip_finetune": ["ImageClassificationFineTuningForCLIP"],
+    "continual_clip_finetune": ["ContinualImageClassificationFineTuningForCLIP"],
+    "image_classification_finetune": [
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .clip_finetune import ImageClassificationFineTuningForCLIP
+    from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+    from .image_classification_finetune import (
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
fusion_bench/method/classification/clip_finetune.py CHANGED
@@ -5,8 +5,8 @@ Fine-tune CLIP-ViT-B/32:
 
 ```bash
 fusion_bench \
-    method=clip_finetune \
-    modelpool=clip-vit-base-patch32_mtl \
+    method=classification/clip_finetune \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_mtl \
     taskpool=dummy
 ```
 
@@ -15,12 +15,14 @@ Fine-tune CLIP-ViT-L/14 on eight GPUs with a per-device per-task batch size of 2
 ```bash
 fusion_bench \
     fabric.devices=8 \
-    method=clip_finetune \
+    method=classification/clip_finetune \
     method.batch_size=2 \
-    modelpool=clip-vit-base-patch32_mtl \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_mtl \
     modelpool.models.0.path=openai/clip-vit-large-patch14 \
     taskpool=dummy
 ```
+
+See `examples/clip_finetune` for more details.
 """
 
 import os
fusion_bench/method/classification/image_classification_finetune.py ADDED
@@ -0,0 +1,214 @@
+import os
+from typing import Optional
+
+import lightning as L
+import lightning.pytorch.callbacks as pl_callbacks
+import torch
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning_utilities.core.rank_zero import rank_zero_only
+from lit_learn.lit_modules import ERM_LitModule
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+from torchmetrics.classification import Accuracy
+
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseModelPool,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool import ResNetForImageClassificationPool
+from fusion_bench.tasks.clip_classification import get_num_classes
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ImageClassificationFineTuning(BaseAlgorithm):
+    def __init__(
+        self,
+        max_epochs: Optional[int],
+        max_steps: Optional[int],
+        label_smoothing: float,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        assert (max_epochs is None) or (
+            max_steps is None or max_steps < 0
+        ), "Only one of max_epochs or max_steps should be set."
+        self.training_interval = "epoch" if max_epochs is not None else "step"
+        if self.training_interval == "epoch":
+            self.max_steps = -1
+        log.info(f"Training interval: {self.training_interval}")
+        log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
+
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        assert (
+            len(modelpool.train_dataset_names) == 1
+        ), "Exactly one training dataset is required."
+        self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
+        num_classes = get_num_classes(dataset_name)
+        train_dataset = modelpool.load_train_dataset(dataset_name)
+        train_dataset = CLIPDataset(
+            train_dataset, processor=modelpool.load_processor(stage="train")
+        )
+        train_loader = self.get_dataloader(train_dataset, stage="train")
+        if modelpool.has_val_dataset:
+            val_dataset = modelpool.load_val_dataset(dataset_name)
+            val_dataset = CLIPDataset(
+                val_dataset, processor=modelpool.load_processor(stage="val")
+            )
+            val_loader = self.get_dataloader(val_dataset, stage="val")
+
+        # configure optimizer
+        optimizer = instantiate(self.optimizer, params=model.parameters())
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
+            optimizer = {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler,
+                    "interval": self.training_interval,
+                    "frequency": 1,
+                },
+            }
+        log.info(f"optimizer:\n{optimizer}")
+
+        lit_module = ERM_LitModule(
+            model,
+            optimizer,
+            objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
+            metrics={
+                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
+            },
+        )
+
+        log_dir = (
+            self._program.path.log_dir
+            if self._program is not None
+            else "outputs/lightning_logs"
+        )
+        trainer = L.Trainer(
+            max_epochs=self.max_epochs,
+            max_steps=self.max_steps,
+            accelerator="auto",
+            devices="auto",
+            callbacks=[
+                pl_callbacks.LearningRateMonitor(logging_interval="step"),
+                pl_callbacks.DeviceStatsMonitor(),
+            ],
+            logger=TensorBoardLogger(
+                save_dir=log_dir,
+                name="",
+            ),
+            fast_dev_run=RuntimeConstants.debug,
+        )
+
+        trainer.fit(
+            lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
+        )
+        model = lit_module.model
+        if rank_zero_only.rank == 0:
+            log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
+            modelpool.save_model(
+                model,
+                path=os.path.join(
+                    trainer.log_dir if trainer.log_dir is not None else log_dir,
+                    "raw_checkpoints",
+                    "final",
+                ),
+            )
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+
+@auto_register_config
+class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        assert (
+            modelpool.has_val_dataset or modelpool.has_test_dataset
+        ), "No validation or test dataset found in the model pool."
+
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        if modelpool.has_test_dataset:
+            assert (
+                len(modelpool.test_dataset_names) == 1
+            ), "Exactly one test dataset is required."
+            self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
+            dataset = modelpool.load_test_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        else:
+            assert (
+                len(modelpool.val_dataset_names) == 1
+            ), "Exactly one validation dataset is required."
+            self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
+            dataset = modelpool.load_val_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        num_classes = get_num_classes(dataset_name)
+
+        test_loader = self.get_dataloader(dataset, stage="test")
+
+        if self.checkpoint_path is None:
+            lit_module = ERM_LitModule(
+                model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+        else:
+            lit_module = ERM_LitModule.load_from_checkpoint(
+                checkpoint_path=self.checkpoint_path,
+                model=model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+
+        trainer = L.Trainer(
+            devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
+        )
+
+        test_metrics = trainer.test(lit_module, dataloaders=test_loader)
+        log.info(f"Test metrics: {test_metrics}")
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
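The new trainer builds its optimizer and scheduler from Hydra-style `DictConfig`s via `instantiate`, then wraps them in Lightning's optimizer/scheduler dict. A hypothetical config pair matching what `run` expects, assuming `instantiate` resolves `_target_` specs; the targets and hyperparameters are illustrative, not the shipped `image_classification_finetune.yaml` defaults:

```python
import torch
from omegaconf import DictConfig

from fusion_bench import instantiate

# Hypothetical optimizer/lr_scheduler configs (illustrative values only).
optimizer_cfg = DictConfig(
    {"_target_": "torch.optim.AdamW", "lr": 1e-3, "weight_decay": 1e-4}
)
scheduler_cfg = DictConfig(
    {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 10}
)

model = torch.nn.Linear(8, 2)  # stand-in for the ResNet loaded from the pool
optimizer = instantiate(optimizer_cfg, params=model.parameters())
lr_scheduler = instantiate(scheduler_cfg, optimizer=optimizer)
```

Note that, as shipped, `run` assigns `val_loader` only inside `if modelpool.has_val_dataset:` but passes it to `trainer.fit` unconditionally, so a model pool without a validation split would raise `UnboundLocalError`.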
fusion_bench/method/dop/__init__.py ADDED
@@ -0,0 +1 @@
+from .dop import ContinualDOPForCLIP