hydra-ax-sweeper 1.3.0.dev0__tar.gz → 1.4.0.dev4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_ax_sweeper-1.4.0.dev4/MANIFEST.in +3 -0
- hydra_ax_sweeper-1.4.0.dev4/PKG-INFO +35 -0
- hydra_ax_sweeper-1.4.0.dev4/hydra_ax_sweeper.egg-info/PKG-INFO +35 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_ax_sweeper.egg-info/SOURCES.txt +3 -1
- hydra_ax_sweeper-1.4.0.dev4/hydra_ax_sweeper.egg-info/requires.txt +3 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/__init__.py +1 -1
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/_core.py +150 -79
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/_earlystopper.py +2 -4
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/config.py +4 -6
- hydra_ax_sweeper-1.4.0.dev4/hydra_plugins/hydra_ax_sweeper/py.typed +0 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/setup.py +8 -8
- hydra_ax_sweeper-1.4.0.dev4/tests/test_ax_sweeper_plugin.py +461 -0
- hydra-ax-sweeper-1.3.0.dev0/MANIFEST.in +0 -3
- hydra-ax-sweeper-1.3.0.dev0/PKG-INFO +0 -21
- hydra-ax-sweeper-1.3.0.dev0/hydra_ax_sweeper.egg-info/PKG-INFO +0 -21
- hydra-ax-sweeper-1.3.0.dev0/hydra_ax_sweeper.egg-info/requires.txt +0 -4
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/README.md +0 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_ax_sweeper.egg-info/dependency_links.txt +0 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_ax_sweeper.egg-info/top_level.txt +0 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/ax_sweeper.py +0 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/pyproject.toml +0 -0
- {hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/setup.cfg +0 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: hydra-ax-sweeper
|
|
3
|
+
Version: 1.4.0.dev4
|
|
4
|
+
Summary: Hydra Ax Sweeper plugin
|
|
5
|
+
Home-page: https://github.com/facebookresearch/hydra/
|
|
6
|
+
Author: Omry Yadan, Shagun Sodhani
|
|
7
|
+
Author-email: omry@fb.com, sshagunsodhani@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
13
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
14
|
+
Classifier: Operating System :: MacOS
|
|
15
|
+
Classifier: Development Status :: 4 - Beta
|
|
16
|
+
Requires-Python: >=3.11,<3.15
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
Requires-Dist: hydra-core>=1.1.0.dev7
|
|
19
|
+
Requires-Dist: ax-platform<1.3.0,>=1.2.4
|
|
20
|
+
Requires-Dist: torch>=2.2
|
|
21
|
+
Dynamic: author
|
|
22
|
+
Dynamic: author-email
|
|
23
|
+
Dynamic: classifier
|
|
24
|
+
Dynamic: description
|
|
25
|
+
Dynamic: description-content-type
|
|
26
|
+
Dynamic: home-page
|
|
27
|
+
Dynamic: license
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
31
|
+
|
|
32
|
+
# Hydra Ax Sweeper
|
|
33
|
+
Provides a [`Ax Sweeper`](https://ax.dev/) based Hydra Sweeper supporting parallel execution.
|
|
34
|
+
|
|
35
|
+
See [website](https://hydra.cc/docs/plugins/ax_sweeper/) for more information
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: hydra-ax-sweeper
|
|
3
|
+
Version: 1.4.0.dev4
|
|
4
|
+
Summary: Hydra Ax Sweeper plugin
|
|
5
|
+
Home-page: https://github.com/facebookresearch/hydra/
|
|
6
|
+
Author: Omry Yadan, Shagun Sodhani
|
|
7
|
+
Author-email: omry@fb.com, sshagunsodhani@gmail.com
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
13
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
14
|
+
Classifier: Operating System :: MacOS
|
|
15
|
+
Classifier: Development Status :: 4 - Beta
|
|
16
|
+
Requires-Python: >=3.11,<3.15
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
Requires-Dist: hydra-core>=1.1.0.dev7
|
|
19
|
+
Requires-Dist: ax-platform<1.3.0,>=1.2.4
|
|
20
|
+
Requires-Dist: torch>=2.2
|
|
21
|
+
Dynamic: author
|
|
22
|
+
Dynamic: author-email
|
|
23
|
+
Dynamic: classifier
|
|
24
|
+
Dynamic: description
|
|
25
|
+
Dynamic: description-content-type
|
|
26
|
+
Dynamic: home-page
|
|
27
|
+
Dynamic: license
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
31
|
+
|
|
32
|
+
# Hydra Ax Sweeper
|
|
33
|
+
Provides a [`Ax Sweeper`](https://ax.dev/) based Hydra Sweeper supporting parallel execution.
|
|
34
|
+
|
|
35
|
+
See [website](https://hydra.cc/docs/plugins/ax_sweeper/) for more information
|
{hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_ax_sweeper.egg-info/SOURCES.txt
RENAMED
|
@@ -11,4 +11,6 @@ hydra_plugins/hydra_ax_sweeper/__init__.py
|
|
|
11
11
|
hydra_plugins/hydra_ax_sweeper/_core.py
|
|
12
12
|
hydra_plugins/hydra_ax_sweeper/_earlystopper.py
|
|
13
13
|
hydra_plugins/hydra_ax_sweeper/ax_sweeper.py
|
|
14
|
-
hydra_plugins/hydra_ax_sweeper/config.py
|
|
14
|
+
hydra_plugins/hydra_ax_sweeper/config.py
|
|
15
|
+
hydra_plugins/hydra_ax_sweeper/py.typed
|
|
16
|
+
tests/test_ax_sweeper_plugin.py
|
{hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/_core.py
RENAMED
|
@@ -1,11 +1,23 @@
|
|
|
1
1
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
2
2
|
import logging
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import
|
|
5
|
-
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Dict,
|
|
7
|
+
Iterable,
|
|
8
|
+
List,
|
|
9
|
+
Literal,
|
|
10
|
+
Mapping,
|
|
11
|
+
Optional,
|
|
12
|
+
Tuple,
|
|
13
|
+
Union,
|
|
14
|
+
cast,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from ax.api.client import Client # type: ignore
|
|
18
|
+
from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig # type: ignore
|
|
6
19
|
from ax.core import types as ax_types # type: ignore
|
|
7
|
-
from ax.exceptions.core import SearchSpaceExhausted # type: ignore
|
|
8
|
-
from ax.service.ax_client import AxClient # type: ignore
|
|
20
|
+
from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError # type: ignore
|
|
9
21
|
from hydra.core.override_parser.overrides_parser import OverridesParser
|
|
10
22
|
from hydra.core.override_parser.types import IntervalSweep, Override, Transformer
|
|
11
23
|
from hydra.core.plugins import Plugins
|
|
@@ -19,6 +31,10 @@ from .config import AxConfig, ClientConfig, ExperimentConfig
|
|
|
19
31
|
|
|
20
32
|
log = logging.getLogger(__name__)
|
|
21
33
|
|
|
34
|
+
AxRangeParameterType = Literal["float", "int"]
|
|
35
|
+
AxChoiceParameterType = Literal["float", "int", "str", "bool"]
|
|
36
|
+
AxParameterConfig = Union[RangeParameterConfig, ChoiceParameterConfig]
|
|
37
|
+
|
|
22
38
|
|
|
23
39
|
@dataclass
|
|
24
40
|
class Trial:
|
|
@@ -34,7 +50,7 @@ class TrialBatch:
|
|
|
34
50
|
|
|
35
51
|
def encoder_parameters_into_string(parameters: List[Dict[str, Any]]) -> str:
|
|
36
52
|
"""Convert a list of params into a string"""
|
|
37
|
-
mandatory_keys =
|
|
53
|
+
mandatory_keys = {"name", "type", "bounds", "values", "value"}
|
|
38
54
|
parameter_log_string = ""
|
|
39
55
|
for parameter in parameters:
|
|
40
56
|
parameter_log_string += "\n"
|
|
@@ -52,7 +68,9 @@ def encoder_parameters_into_string(parameters: List[Dict[str, Any]]) -> str:
|
|
|
52
68
|
return parameter_log_string
|
|
53
69
|
|
|
54
70
|
|
|
55
|
-
def map_params_to_arg_list(
|
|
71
|
+
def map_params_to_arg_list(
|
|
72
|
+
params: Mapping[str, Union[str, float, int, bool]],
|
|
73
|
+
) -> List[str]:
|
|
56
74
|
"""Method to map a dictionary of params to a list of string arguments"""
|
|
57
75
|
arg_list = []
|
|
58
76
|
for key in params:
|
|
@@ -60,45 +78,76 @@ def map_params_to_arg_list(params: Mapping[str, Union[str, float, int]]) -> List
|
|
|
60
78
|
return arg_list
|
|
61
79
|
|
|
62
80
|
|
|
81
|
+
def get_ax_choice_parameter_type(values: Iterable[Any]) -> AxChoiceParameterType:
|
|
82
|
+
value_types = {type(value) for value in values}
|
|
83
|
+
if value_types == {bool}:
|
|
84
|
+
return "bool"
|
|
85
|
+
if value_types <= {int}:
|
|
86
|
+
return "int"
|
|
87
|
+
if value_types <= {int, float}:
|
|
88
|
+
return "float"
|
|
89
|
+
if value_types == {str}:
|
|
90
|
+
return "str"
|
|
91
|
+
raise ValueError(f"Unsupported mixed Ax parameter value types: {value_types}")
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def create_ax_parameter_config(param: Dict[Any, Any]) -> AxParameterConfig:
|
|
95
|
+
name = param["name"]
|
|
96
|
+
if param["type"] == "range":
|
|
97
|
+
bounds = param["bounds"]
|
|
98
|
+
range_parameter_type: AxRangeParameterType = (
|
|
99
|
+
"int" if all(type(bound) is int for bound in bounds) else "float"
|
|
100
|
+
)
|
|
101
|
+
return RangeParameterConfig(
|
|
102
|
+
name=name,
|
|
103
|
+
bounds=cast(Tuple[float, float], tuple(bounds)),
|
|
104
|
+
parameter_type=range_parameter_type,
|
|
105
|
+
scaling="log" if param.get("log_scale") else None,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
if param["type"] == "choice":
|
|
109
|
+
values = param["values"]
|
|
110
|
+
choice_parameter_type = get_ax_choice_parameter_type(values)
|
|
111
|
+
is_ordered = param.get("is_ordered", choice_parameter_type != "str")
|
|
112
|
+
elif param["type"] == "fixed":
|
|
113
|
+
values = [param["value"]]
|
|
114
|
+
choice_parameter_type = get_ax_choice_parameter_type(values)
|
|
115
|
+
is_ordered = None
|
|
116
|
+
else:
|
|
117
|
+
raise ValueError(f"Unexpected Ax parameter type: {param['type']}")
|
|
118
|
+
|
|
119
|
+
if choice_parameter_type == "float":
|
|
120
|
+
values = [float(value) for value in values]
|
|
121
|
+
return ChoiceParameterConfig(
|
|
122
|
+
name=name,
|
|
123
|
+
values=cast(Any, values),
|
|
124
|
+
parameter_type=choice_parameter_type,
|
|
125
|
+
is_ordered=is_ordered,
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
|
|
63
129
|
def get_one_batch_of_trials(
|
|
64
|
-
ax_client:
|
|
65
|
-
parallelism: Tuple[int, int],
|
|
66
|
-
num_trials_so_far: int,
|
|
130
|
+
ax_client: Client,
|
|
67
131
|
num_max_trials_to_do: int,
|
|
68
132
|
) -> TrialBatch:
|
|
69
133
|
"""Returns a TrialBatch that contains a list of trials that can be
|
|
70
134
|
run in parallel. TrialBatch also flags if the search space is exhausted."""
|
|
71
135
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if max_parallelism_setting == -1:
|
|
77
|
-
# Special case, we can group all the trials into one batch
|
|
78
|
-
max_parallelism_setting = num_trials - num_trials_so_far
|
|
79
|
-
|
|
80
|
-
if num_trials == -1:
|
|
81
|
-
# This is a special case where we can run as many trials in parallel as we want.
|
|
82
|
-
# Given that num_trials is also -1, we can run all the trials in parallel.
|
|
83
|
-
max_parallelism_setting = num_max_trials_to_do
|
|
84
|
-
|
|
85
|
-
list_of_trials = []
|
|
86
|
-
for _ in range(max_parallelism_setting):
|
|
87
|
-
try:
|
|
88
|
-
parameters, trial_index = ax_client.get_next_trial()
|
|
89
|
-
list_of_trials.append(
|
|
90
|
-
Trial(
|
|
91
|
-
overrides=map_params_to_arg_list(params=parameters),
|
|
92
|
-
trial_index=trial_index,
|
|
93
|
-
)
|
|
94
|
-
)
|
|
95
|
-
except SearchSpaceExhausted:
|
|
96
|
-
is_search_space_exhausted = True
|
|
97
|
-
break
|
|
136
|
+
try:
|
|
137
|
+
trials = ax_client.get_next_trials(max_trials=num_max_trials_to_do)
|
|
138
|
+
except SearchSpaceExhausted:
|
|
139
|
+
return TrialBatch(list_of_trials=[], is_search_space_exhausted=True)
|
|
98
140
|
|
|
141
|
+
list_of_trials = [
|
|
142
|
+
Trial(
|
|
143
|
+
overrides=map_params_to_arg_list(params=parameters),
|
|
144
|
+
trial_index=trial_index,
|
|
145
|
+
)
|
|
146
|
+
for trial_index, parameters in trials.items()
|
|
147
|
+
]
|
|
99
148
|
return TrialBatch(
|
|
100
149
|
list_of_trials=list_of_trials,
|
|
101
|
-
is_search_space_exhausted=
|
|
150
|
+
is_search_space_exhausted=len(list_of_trials) == 0,
|
|
102
151
|
)
|
|
103
152
|
|
|
104
153
|
|
|
@@ -146,56 +195,40 @@ class CoreAxSweeper(Sweeper):
|
|
|
146
195
|
ax_client = self.setup_ax_client(arguments)
|
|
147
196
|
|
|
148
197
|
num_trials_left = self.max_trials
|
|
149
|
-
max_parallelism = ax_client.get_max_parallelism()
|
|
150
|
-
current_parallelism_index = 0
|
|
151
|
-
# Index to track the parallelism value we are using right now.
|
|
152
198
|
is_search_space_exhausted = False
|
|
153
199
|
# Ax throws an exception if the search space is exhausted. We catch
|
|
154
200
|
# the exception and set the flag to True
|
|
155
201
|
|
|
156
202
|
best_parameters = {}
|
|
157
203
|
while num_trials_left > 0 and not is_search_space_exhausted:
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
) and num_trials_left > 0:
|
|
164
|
-
trial_batch = get_one_batch_of_trials(
|
|
165
|
-
ax_client=ax_client,
|
|
166
|
-
parallelism=current_parallelism,
|
|
167
|
-
num_trials_so_far=num_trials_so_far,
|
|
168
|
-
num_max_trials_to_do=num_trials_left,
|
|
169
|
-
)
|
|
170
|
-
|
|
171
|
-
list_of_trials_to_launch = trial_batch.list_of_trials[:num_trials_left]
|
|
172
|
-
is_search_space_exhausted = trial_batch.is_search_space_exhausted
|
|
204
|
+
num_trials_to_request = min(num_trials_left, self.max_batch_size or 5)
|
|
205
|
+
trial_batch = get_one_batch_of_trials(
|
|
206
|
+
ax_client=ax_client,
|
|
207
|
+
num_max_trials_to_do=num_trials_to_request,
|
|
208
|
+
)
|
|
173
209
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
len(list_of_trials_to_launch)
|
|
177
|
-
)
|
|
178
|
-
)
|
|
210
|
+
list_of_trials_to_launch = trial_batch.list_of_trials[:num_trials_left]
|
|
211
|
+
is_search_space_exhausted = trial_batch.is_search_space_exhausted
|
|
179
212
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
213
|
+
log.info(
|
|
214
|
+
"AxSweeper is launching {} jobs".format(len(list_of_trials_to_launch))
|
|
215
|
+
)
|
|
183
216
|
|
|
184
|
-
|
|
185
|
-
|
|
217
|
+
self.sweep_over_batches(
|
|
218
|
+
ax_client=ax_client, list_of_trials=list_of_trials_to_launch
|
|
219
|
+
)
|
|
186
220
|
|
|
187
|
-
|
|
188
|
-
metric = predictions[0][ax_client.objective_name]
|
|
221
|
+
num_trials_left -= len(list_of_trials_to_launch)
|
|
189
222
|
|
|
223
|
+
best_point = self.get_best_point(ax_client)
|
|
224
|
+
if best_point is not None:
|
|
225
|
+
best_parameters, metric = best_point
|
|
190
226
|
if self.early_stopper.should_stop(metric, best_parameters):
|
|
191
|
-
num_trials_left = -1
|
|
192
227
|
break
|
|
193
228
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
current_parallelism_index += 1
|
|
229
|
+
if is_search_space_exhausted:
|
|
230
|
+
log.info("Ax has exhausted the search space")
|
|
231
|
+
break
|
|
199
232
|
|
|
200
233
|
results_to_serialize = {"optimizer": "ax", "ax": best_parameters}
|
|
201
234
|
OmegaConf.save(
|
|
@@ -205,7 +238,7 @@ class CoreAxSweeper(Sweeper):
|
|
|
205
238
|
log.info("Best parameters: " + str(best_parameters))
|
|
206
239
|
|
|
207
240
|
def sweep_over_batches(
|
|
208
|
-
self, ax_client:
|
|
241
|
+
self, ax_client: Client, list_of_trials: List[Trial]
|
|
209
242
|
) -> None:
|
|
210
243
|
assert self.launcher is not None
|
|
211
244
|
assert self.job_idx is not None
|
|
@@ -234,11 +267,30 @@ class CoreAxSweeper(Sweeper):
|
|
|
234
267
|
val = (val, None) # specify unknown noise
|
|
235
268
|
else:
|
|
236
269
|
val = (val, 0) # specify no noise
|
|
270
|
+
if isinstance(val, tuple):
|
|
271
|
+
val = {self.experiment.objective_name: val}
|
|
237
272
|
ax_client.complete_trial(
|
|
238
273
|
trial_index=batch[idx].trial_index, raw_data=val
|
|
239
274
|
)
|
|
240
275
|
|
|
241
|
-
def
|
|
276
|
+
def get_best_point(
|
|
277
|
+
self, ax_client: Client
|
|
278
|
+
) -> Optional[Tuple[Mapping[str, Any], float]]:
|
|
279
|
+
try:
|
|
280
|
+
best_parameters, metrics, _, _ = ax_client.get_best_parameterization(
|
|
281
|
+
use_model_predictions=False
|
|
282
|
+
)
|
|
283
|
+
except (AssertionError, UnsupportedError):
|
|
284
|
+
return None
|
|
285
|
+
|
|
286
|
+
metric = metrics.get(self.experiment.objective_name)
|
|
287
|
+
if metric is None:
|
|
288
|
+
return None
|
|
289
|
+
if isinstance(metric, tuple):
|
|
290
|
+
metric = metric[0]
|
|
291
|
+
return best_parameters, float(metric)
|
|
292
|
+
|
|
293
|
+
def setup_ax_client(self, arguments: List[str]) -> Client:
|
|
242
294
|
"""Method to setup the Ax Client"""
|
|
243
295
|
parameters: List[Dict[Any, Any]] = []
|
|
244
296
|
for key, value in self.ax_params.items():
|
|
@@ -249,7 +301,6 @@ class CoreAxSweeper(Sweeper):
|
|
|
249
301
|
if not (all(isinstance(x, int) for x in bounds)):
|
|
250
302
|
# Type mismatch. Promote all to float
|
|
251
303
|
param["bounds"] = [float(x) for x in bounds]
|
|
252
|
-
|
|
253
304
|
parameters.append(param)
|
|
254
305
|
parameters[-1]["name"] = key
|
|
255
306
|
commandline_params = self.parse_commandline_args(arguments)
|
|
@@ -265,11 +316,27 @@ class CoreAxSweeper(Sweeper):
|
|
|
265
316
|
log.info(
|
|
266
317
|
f"AxSweeper is optimizing the following parameters: {encoder_parameters_into_string(parameters)}"
|
|
267
318
|
)
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
319
|
+
|
|
320
|
+
if not self.ax_client_config.verbose_logging:
|
|
321
|
+
logging.getLogger("ax.api.client").setLevel(logging.WARNING)
|
|
322
|
+
ax_client = Client(random_seed=self.ax_client_config.random_seed)
|
|
323
|
+
ax_client.configure_experiment(
|
|
324
|
+
parameters=[create_ax_parameter_config(param) for param in parameters],
|
|
325
|
+
parameter_constraints=self.experiment.parameter_constraints,
|
|
326
|
+
name=self.experiment.name,
|
|
327
|
+
)
|
|
328
|
+
if self.experiment.status_quo is not None:
|
|
329
|
+
ax_client.attach_baseline(
|
|
330
|
+
parameters=cast(Mapping[str, Any], self.experiment.status_quo)
|
|
331
|
+
)
|
|
332
|
+
ax_client.configure_optimization(
|
|
333
|
+
objective=(
|
|
334
|
+
f"-{self.experiment.objective_name}"
|
|
335
|
+
if self.experiment.minimize
|
|
336
|
+
else self.experiment.objective_name
|
|
337
|
+
),
|
|
338
|
+
outcome_constraints=self.experiment.outcome_constraints,
|
|
271
339
|
)
|
|
272
|
-
ax_client.create_experiment(parameters=parameters, **self.experiment)
|
|
273
340
|
|
|
274
341
|
return ax_client
|
|
275
342
|
|
|
@@ -288,8 +355,12 @@ class CoreAxSweeper(Sweeper):
|
|
|
288
355
|
param = create_choice_param_from_range_override(override)
|
|
289
356
|
elif override.is_interval_sweep():
|
|
290
357
|
param = create_range_param_using_interval_override(override)
|
|
358
|
+
else:
|
|
359
|
+
raise ValueError(f"Unsupported sweep override: {override}")
|
|
291
360
|
elif not override.is_hydra_override():
|
|
292
361
|
param = create_fixed_param_from_element_override(override)
|
|
362
|
+
else:
|
|
363
|
+
continue
|
|
293
364
|
parameters.append(param)
|
|
294
365
|
|
|
295
366
|
return parameters
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
2
2
|
import logging
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
|
-
from ax import ParameterType # type: ignore
|
|
3
|
+
from typing import Any, Mapping, Optional
|
|
6
4
|
|
|
7
5
|
log = logging.getLogger(__name__)
|
|
8
6
|
|
|
@@ -24,7 +22,7 @@ class EarlyStopper:
|
|
|
24
22
|
self.current_epochs_without_improvement = 0
|
|
25
23
|
|
|
26
24
|
def should_stop(
|
|
27
|
-
self, potential_best_value: float, best_parameters:
|
|
25
|
+
self, potential_best_value: float, best_parameters: Mapping[str, Any]
|
|
28
26
|
) -> bool:
|
|
29
27
|
"""Check if the optimization process should be stopped."""
|
|
30
28
|
is_improving = True
|
{hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_plugins/hydra_ax_sweeper/config.py
RENAMED
|
@@ -34,7 +34,6 @@ class ExperimentConfig:
|
|
|
34
34
|
|
|
35
35
|
@dataclass
|
|
36
36
|
class ClientConfig:
|
|
37
|
-
|
|
38
37
|
verbose_logging: bool = False
|
|
39
38
|
|
|
40
39
|
# set random seed here to make Ax results reproducible
|
|
@@ -43,12 +42,11 @@ class ClientConfig:
|
|
|
43
42
|
|
|
44
43
|
@dataclass
|
|
45
44
|
class AxConfig:
|
|
46
|
-
|
|
47
45
|
# max_trials is application-specific. Tune it for your use case
|
|
48
46
|
max_trials: int = 10
|
|
49
|
-
early_stop: EarlyStopConfig = EarlyStopConfig
|
|
50
|
-
experiment: ExperimentConfig = ExperimentConfig
|
|
51
|
-
client: ClientConfig = ClientConfig
|
|
47
|
+
early_stop: EarlyStopConfig = field(default_factory=EarlyStopConfig)
|
|
48
|
+
experiment: ExperimentConfig = field(default_factory=ExperimentConfig)
|
|
49
|
+
client: ClientConfig = field(default_factory=ClientConfig)
|
|
52
50
|
params: Dict[str, Any] = field(default_factory=dict)
|
|
53
51
|
# is_noisy = True indicates measurements have unknown uncertainty
|
|
54
52
|
# is_noisy = False indicates measurements have an uncertainty of zero
|
|
@@ -60,7 +58,7 @@ class AxSweeperConf:
|
|
|
60
58
|
_target_: str = "hydra_plugins.hydra_ax_sweeper.ax_sweeper.AxSweeper"
|
|
61
59
|
# Maximum number of trials to run in parallel
|
|
62
60
|
max_batch_size: Optional[int] = None
|
|
63
|
-
ax_config: AxConfig = AxConfig
|
|
61
|
+
ax_config: AxConfig = field(default_factory=AxConfig)
|
|
64
62
|
|
|
65
63
|
|
|
66
64
|
ConfigStore.instance().store(
|
|
File without changes
|
|
@@ -11,25 +11,25 @@ setup(
|
|
|
11
11
|
author="Omry Yadan, Shagun Sodhani",
|
|
12
12
|
author_email="omry@fb.com, sshagunsodhani@gmail.com",
|
|
13
13
|
description="Hydra Ax Sweeper plugin",
|
|
14
|
+
license="MIT",
|
|
14
15
|
long_description=(Path(__file__).parent / "README.md").read_text(),
|
|
15
16
|
long_description_content_type="text/markdown",
|
|
16
17
|
url="https://github.com/facebookresearch/hydra/",
|
|
17
18
|
packages=find_namespace_packages(include=["hydra_plugins.*"]),
|
|
18
19
|
classifiers=[
|
|
19
|
-
"
|
|
20
|
-
"Programming Language :: Python :: 3.
|
|
21
|
-
"Programming Language :: Python :: 3.
|
|
22
|
-
"Programming Language :: Python :: 3.
|
|
23
|
-
"Programming Language :: Python :: 3.10",
|
|
20
|
+
"Programming Language :: Python :: 3.11",
|
|
21
|
+
"Programming Language :: Python :: 3.12",
|
|
22
|
+
"Programming Language :: Python :: 3.13",
|
|
23
|
+
"Programming Language :: Python :: 3.14",
|
|
24
24
|
"Operating System :: POSIX :: Linux",
|
|
25
25
|
"Operating System :: MacOS",
|
|
26
26
|
"Development Status :: 4 - Beta",
|
|
27
27
|
],
|
|
28
|
+
python_requires=">=3.11,<3.15",
|
|
28
29
|
install_requires=[
|
|
29
30
|
"hydra-core>=1.1.0.dev7",
|
|
30
|
-
"ax-platform>=
|
|
31
|
-
"torch",
|
|
32
|
-
"gpytorch<=1.8.1", # avoid deprecation warnings. This can probably be removed when ax-platform is unpinned.
|
|
31
|
+
"ax-platform>=1.2.4,<1.3.0",
|
|
32
|
+
"torch>=2.2",
|
|
33
33
|
],
|
|
34
34
|
include_package_data=True,
|
|
35
35
|
)
|
|
@@ -0,0 +1,461 @@
|
|
|
1
|
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, List, Tuple
|
|
6
|
+
|
|
7
|
+
from hydra.core.plugins import Plugins
|
|
8
|
+
from hydra.plugins.sweeper import Sweeper
|
|
9
|
+
from hydra.test_utils.test_utils import (
|
|
10
|
+
TSweepRunner,
|
|
11
|
+
chdir_plugin_root,
|
|
12
|
+
run_python_script,
|
|
13
|
+
)
|
|
14
|
+
from omegaconf import DictConfig, OmegaConf
|
|
15
|
+
from pytest import mark, raises
|
|
16
|
+
|
|
17
|
+
from hydra_plugins.hydra_ax_sweeper.ax_sweeper import AxSweeper
|
|
18
|
+
|
|
19
|
+
chdir_plugin_root()
|
|
20
|
+
|
|
21
|
+
WARNING_FILTERS = [
|
|
22
|
+
# 2026-05-15: linear_operator 0.6.1 imports torch.jit.script via
|
|
23
|
+
# Ax -> Botorch -> GPyTorch on Python 3.14 with Torch 2.12.
|
|
24
|
+
# Remove when the current Ax stack no longer emits it under -Werror.
|
|
25
|
+
"ignore:`torch.jit.script` is deprecated:DeprecationWarning",
|
|
26
|
+
# 2026-05-15: same linear_operator import path as above, but Torch emits this
|
|
27
|
+
# alternate wording in some import paths. Remove with the filter above.
|
|
28
|
+
"ignore:`torch.jit.script` is not supported:DeprecationWarning",
|
|
29
|
+
# 2026-05-15: Ax 1.2.4 uses asyncio.iscoroutinefunction in retry helpers.
|
|
30
|
+
# Remove when Ax no longer emits it on Python 3.14 under -Werror.
|
|
31
|
+
"ignore:'asyncio.iscoroutinefunction' is deprecated:DeprecationWarning",
|
|
32
|
+
# 2026-05-15: Ax 1.2.4's JSON storage registry imports this moved shim.
|
|
33
|
+
# Remove when importing ax.api.Client no longer touches the shim.
|
|
34
|
+
"ignore:ax.service.utils.orchestrator_options has been moved:DeprecationWarning",
|
|
35
|
+
# 2026-05-15: Ax 1.2.4's overview analysis imports this moved shim.
|
|
36
|
+
# Remove when importing ax.api.Client no longer touches the shim.
|
|
37
|
+
"ignore:ax.service.orchestrator has been moved:DeprecationWarning",
|
|
38
|
+
# 2026-05-15: Botorch/linear_operator can add Cholesky jitter while fitting
|
|
39
|
+
# the Ax surrogate for log-scale tests. Remove when the current Ax stack no
|
|
40
|
+
# longer emits it under -Werror for the plugin's log-scale sweep.
|
|
41
|
+
"ignore:A not p.d., added jitter:Warning",
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
pytestmark = [mark.filterwarnings(warning_filter) for warning_filter in WARNING_FILTERS]
|
|
45
|
+
PYTHON_WARNING_FILTERS = [f"-W{warning_filter}" for warning_filter in WARNING_FILTERS]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def run_ax_python_script(cmd: List[str]) -> Tuple[str, str]:
|
|
49
|
+
return run_python_script(PYTHON_WARNING_FILTERS + cmd)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_discovery() -> None:
|
|
53
|
+
"""
|
|
54
|
+
Tests that this plugin can be discovered via the plugins subsystem when looking for Sweeper
|
|
55
|
+
:return:
|
|
56
|
+
"""
|
|
57
|
+
assert AxSweeper.__name__ in [
|
|
58
|
+
x.__name__ for x in Plugins.instance().discover(Sweeper)
|
|
59
|
+
]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def quadratic(cfg: DictConfig) -> Any:
|
|
63
|
+
return 100 * (cfg.quadratic.x**2) + 1 * cfg.quadratic.y
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@mark.parametrize(
|
|
67
|
+
"n,expected",
|
|
68
|
+
[
|
|
69
|
+
(None, [[1, 2, 3, 4, 5]]),
|
|
70
|
+
(1, [[1], [2], [3], [4], [5]]),
|
|
71
|
+
(2, [[1, 2], [3, 4], [5]]),
|
|
72
|
+
(5, [[1, 2, 3, 4, 5]]),
|
|
73
|
+
(6, [[1, 2, 3, 4, 5]]),
|
|
74
|
+
],
|
|
75
|
+
)
|
|
76
|
+
def test_chunk_method_for_valid_inputs(n: int, expected: List[List[int]]) -> None:
|
|
77
|
+
from hydra_plugins.hydra_ax_sweeper._core import CoreAxSweeper
|
|
78
|
+
|
|
79
|
+
chunk_func = CoreAxSweeper.chunks
|
|
80
|
+
batch = [1, 2, 3, 4, 5]
|
|
81
|
+
out = list(chunk_func(batch, n))
|
|
82
|
+
assert out == expected
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@mark.parametrize("n", [-1, -11, 0])
|
|
86
|
+
def test_chunk_method_for_invalid_inputs(n: int) -> None:
|
|
87
|
+
from hydra_plugins.hydra_ax_sweeper._core import CoreAxSweeper
|
|
88
|
+
|
|
89
|
+
chunk_func = CoreAxSweeper.chunks
|
|
90
|
+
batch = [1, 2, 3, 4, 5]
|
|
91
|
+
with raises(ValueError):
|
|
92
|
+
list(chunk_func(batch, n))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def test_jobs_dirs(hydra_sweep_runner: TSweepRunner) -> None:
|
|
96
|
+
# Verify that the spawned jobs are not overstepping the directories of one another.
|
|
97
|
+
sweep = hydra_sweep_runner(
|
|
98
|
+
calling_file="tests/test_ax_sweeper_plugin.py",
|
|
99
|
+
calling_module=None,
|
|
100
|
+
task_function=quadratic,
|
|
101
|
+
config_path="config",
|
|
102
|
+
config_name="config.yaml",
|
|
103
|
+
overrides=[
|
|
104
|
+
"hydra/launcher=basic",
|
|
105
|
+
"hydra.sweeper.ax_config.max_trials=6",
|
|
106
|
+
"hydra.sweeper.ax_config.early_stop.max_epochs_without_improvement=100",
|
|
107
|
+
"hydra.sweeper.max_batch_size=2",
|
|
108
|
+
"params=basic",
|
|
109
|
+
],
|
|
110
|
+
)
|
|
111
|
+
with sweep:
|
|
112
|
+
assert isinstance(sweep.temp_dir, str)
|
|
113
|
+
dirs = [
|
|
114
|
+
x
|
|
115
|
+
for x in os.listdir(sweep.temp_dir)
|
|
116
|
+
if os.path.isdir(os.path.join(sweep.temp_dir, x))
|
|
117
|
+
]
|
|
118
|
+
assert len(dirs) == 6 # and a total of 6 unique output directories
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@mark.parametrize("test_conf", ["basic", "logscale"])
|
|
122
|
+
def test_jobs_configured_via_config(
|
|
123
|
+
hydra_sweep_runner: TSweepRunner, test_conf: str
|
|
124
|
+
) -> None:
|
|
125
|
+
sweep = hydra_sweep_runner(
|
|
126
|
+
calling_file="tests/test_ax_sweeper_plugin.py",
|
|
127
|
+
calling_module=None,
|
|
128
|
+
task_function=quadratic,
|
|
129
|
+
config_path="config",
|
|
130
|
+
config_name="config.yaml",
|
|
131
|
+
overrides=["hydra/launcher=basic", f"params={test_conf}"],
|
|
132
|
+
)
|
|
133
|
+
with sweep:
|
|
134
|
+
assert sweep.returns is None
|
|
135
|
+
returns = OmegaConf.load(f"{sweep.temp_dir}/optimization_results.yaml")
|
|
136
|
+
assert isinstance(returns, DictConfig)
|
|
137
|
+
assert returns["optimizer"] == "ax"
|
|
138
|
+
assert len(returns) == 2
|
|
139
|
+
best_parameters = returns.ax
|
|
140
|
+
assert math.isclose(best_parameters["quadratic.x"], 0.0, abs_tol=1e-4)
|
|
141
|
+
expected_y = 0.0 if test_conf == "basic" else -1.0
|
|
142
|
+
assert math.isclose(best_parameters["quadratic.y"], expected_y, abs_tol=1e-4)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@mark.parametrize(
|
|
146
|
+
"test_conf, override, expected_x",
|
|
147
|
+
[
|
|
148
|
+
("basic", "int(interval(1, 5))", 1.0),
|
|
149
|
+
],
|
|
150
|
+
)
|
|
151
|
+
def test_jobs_configured_via_cmd(
|
|
152
|
+
hydra_sweep_runner: TSweepRunner, test_conf: str, override: str, expected_x: float
|
|
153
|
+
) -> None:
|
|
154
|
+
sweep = hydra_sweep_runner(
|
|
155
|
+
calling_file="tests/test_ax_sweeper_plugin.py",
|
|
156
|
+
calling_module=None,
|
|
157
|
+
task_function=quadratic,
|
|
158
|
+
config_path="config",
|
|
159
|
+
config_name="config.yaml",
|
|
160
|
+
overrides=[
|
|
161
|
+
"hydra/launcher=basic",
|
|
162
|
+
"hydra.sweeper.ax_config.max_trials=5",
|
|
163
|
+
f"quadratic.x={override}",
|
|
164
|
+
"quadratic.y=-2",
|
|
165
|
+
f"params={test_conf}",
|
|
166
|
+
],
|
|
167
|
+
)
|
|
168
|
+
with sweep:
|
|
169
|
+
assert sweep.returns is None
|
|
170
|
+
returns = OmegaConf.load(f"{sweep.temp_dir}/optimization_results.yaml")
|
|
171
|
+
assert isinstance(returns, DictConfig)
|
|
172
|
+
assert returns["optimizer"] == "ax"
|
|
173
|
+
assert len(returns) == 2
|
|
174
|
+
best_parameters = returns.ax
|
|
175
|
+
assert math.isclose(best_parameters["quadratic.x"], expected_x, abs_tol=1e-4)
|
|
176
|
+
assert math.isclose(best_parameters["quadratic.y"], -2.0, abs_tol=1e-4)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def test_jobs_configured_via_cmd_and_config(hydra_sweep_runner: TSweepRunner) -> None:
|
|
180
|
+
sweep = hydra_sweep_runner(
|
|
181
|
+
calling_file="tests/test_ax_sweeper_plugin.py",
|
|
182
|
+
calling_module=None,
|
|
183
|
+
task_function=quadratic,
|
|
184
|
+
config_path="config",
|
|
185
|
+
config_name="config.yaml",
|
|
186
|
+
overrides=[
|
|
187
|
+
"hydra/launcher=basic",
|
|
188
|
+
"hydra.sweeper.ax_config.max_trials=2",
|
|
189
|
+
"quadratic.x=int(interval(-5, -2))",
|
|
190
|
+
"params=basic",
|
|
191
|
+
],
|
|
192
|
+
)
|
|
193
|
+
with sweep:
|
|
194
|
+
assert sweep.returns is None
|
|
195
|
+
returns = OmegaConf.load(f"{sweep.temp_dir}/optimization_results.yaml")
|
|
196
|
+
assert isinstance(returns, DictConfig)
|
|
197
|
+
assert returns["optimizer"] == "ax"
|
|
198
|
+
assert len(returns) == 2
|
|
199
|
+
best_parameters = returns.ax
|
|
200
|
+
assert math.isclose(best_parameters["quadratic.x"], -3.0, abs_tol=1e-4)
|
|
201
|
+
assert math.isclose(best_parameters["quadratic.y"], 0.0, abs_tol=1e-4)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_command_line_int_interval_optimizes_integer_parabola(
|
|
205
|
+
hydra_sweep_runner: TSweepRunner,
|
|
206
|
+
) -> None:
|
|
207
|
+
sweep = hydra_sweep_runner(
|
|
208
|
+
calling_file="tests/test_ax_sweeper_plugin.py",
|
|
209
|
+
calling_module=None,
|
|
210
|
+
task_function=quadratic,
|
|
211
|
+
config_path="config",
|
|
212
|
+
config_name="config.yaml",
|
|
213
|
+
overrides=[
|
|
214
|
+
"hydra/launcher=basic",
|
|
215
|
+
"hydra.sweeper.ax_config.max_trials=12",
|
|
216
|
+
"quadratic.x=int(interval(-5, 5))",
|
|
217
|
+
"quadratic.y=0",
|
|
218
|
+
],
|
|
219
|
+
)
|
|
220
|
+
with sweep:
|
|
221
|
+
assert sweep.returns is None
|
|
222
|
+
returns = OmegaConf.load(f"{sweep.temp_dir}/optimization_results.yaml")
|
|
223
|
+
assert isinstance(returns, DictConfig)
|
|
224
|
+
best_parameters = returns.ax
|
|
225
|
+
assert isinstance(best_parameters["quadratic.x"], int)
|
|
226
|
+
assert best_parameters["quadratic.x"] == 0
|
|
227
|
+
assert best_parameters["quadratic.y"] == 0
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def test_configuration_set_via_cmd_and_default_config(
|
|
231
|
+
hydra_sweep_runner: TSweepRunner,
|
|
232
|
+
) -> None:
|
|
233
|
+
sweep = hydra_sweep_runner(
|
|
234
|
+
calling_file="tests/test_ax_sweeper_plugin.py",
|
|
235
|
+
calling_module=None,
|
|
236
|
+
task_function=quadratic,
|
|
237
|
+
config_path="config",
|
|
238
|
+
config_name="config.yaml",
|
|
239
|
+
overrides=[
|
|
240
|
+
"hydra/launcher=basic",
|
|
241
|
+
"hydra.sweeper.ax_config.max_trials=2",
|
|
242
|
+
"hydra.sweeper.ax_config.early_stop.max_epochs_without_improvement=2",
|
|
243
|
+
"quadratic=basic",
|
|
244
|
+
"quadratic.x=interval(-5, -2)",
|
|
245
|
+
"quadratic.y=interval(-1, 1)",
|
|
246
|
+
],
|
|
247
|
+
)
|
|
248
|
+
with sweep:
|
|
249
|
+
assert sweep.returns is None
|
|
250
|
+
returns = OmegaConf.load(f"{sweep.temp_dir}/optimization_results.yaml")
|
|
251
|
+
assert isinstance(returns, DictConfig)
|
|
252
|
+
best_parameters = returns.ax
|
|
253
|
+
assert "quadratic.x" in best_parameters
|
|
254
|
+
assert "quadratic.y" in best_parameters
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@mark.parametrize(
|
|
258
|
+
"cmd_arg, expected_str",
|
|
259
|
+
[
|
|
260
|
+
("polynomial.y=choice(-1, 0, 1)", "polynomial.y: choice=[-1, 0, 1]"),
|
|
261
|
+
("polynomial.y=range(-1, 2)", "polynomial.y: choice=[-1, 0, 1]"),
|
|
262
|
+
("polynomial.y=range(-1, 3, 1)", "polynomial.y: choice=[-1, 0, 1, 2]"),
|
|
263
|
+
(
|
|
264
|
+
"polynomial.y=range(-1, 2, 0.5)",
|
|
265
|
+
"polynomial.y: choice=[-1.0, -0.5, 0.0, 0.5, 1.0, 1.5]",
|
|
266
|
+
),
|
|
267
|
+
("polynomial.y=int(interval(-1, 2))", "polynomial.y: range=[-1, 2]"),
|
|
268
|
+
("polynomial.y=interval(-1, 2)", "polynomial.y: range=[-1.0, 2.0]"),
|
|
269
|
+
(
|
|
270
|
+
"polynomial.y=tag(log, interval(0.00001, 1))",
|
|
271
|
+
"polynomial.y: range=[1e-05, 1.0], log_scale = True",
|
|
272
|
+
),
|
|
273
|
+
("polynomial.y=2", "polynomial.y: fixed=2"),
|
|
274
|
+
("polynomial.y=2.0", "polynomial.y: fixed=2.0"),
|
|
275
|
+
],
|
|
276
|
+
)
|
|
277
|
+
def test_ax_logging(cmd_arg: str, expected_str: str) -> None:
|
|
278
|
+
from hydra_plugins.hydra_ax_sweeper._core import (
|
|
279
|
+
CoreAxSweeper,
|
|
280
|
+
encoder_parameters_into_string,
|
|
281
|
+
)
|
|
282
|
+
from hydra_plugins.hydra_ax_sweeper.config import AxConfig
|
|
283
|
+
|
|
284
|
+
sweeper = CoreAxSweeper(AxConfig(), max_batch_size=None)
|
|
285
|
+
parameters = sweeper.parse_commandline_args(
|
|
286
|
+
[
|
|
287
|
+
"polynomial.x=interval(-5, -2)",
|
|
288
|
+
"polynomial.z=10",
|
|
289
|
+
cmd_arg,
|
|
290
|
+
]
|
|
291
|
+
)
|
|
292
|
+
result = encoder_parameters_into_string(parameters)
|
|
293
|
+
assert "polynomial.x: range=[-5.0, -2.0]" in result
|
|
294
|
+
assert expected_str in result
|
|
295
|
+
assert "polynomial.z: fixed=10" in result
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def test_command_line_log_interval_configures_ax_log_range() -> None:
|
|
299
|
+
from ax.api.configs import RangeParameterConfig # type: ignore
|
|
300
|
+
|
|
301
|
+
from hydra_plugins.hydra_ax_sweeper._core import (
|
|
302
|
+
CoreAxSweeper,
|
|
303
|
+
create_ax_parameter_config,
|
|
304
|
+
)
|
|
305
|
+
from hydra_plugins.hydra_ax_sweeper.config import AxConfig
|
|
306
|
+
|
|
307
|
+
sweeper = CoreAxSweeper(AxConfig(), max_batch_size=None)
|
|
308
|
+
(parameter,) = sweeper.parse_commandline_args(
|
|
309
|
+
["polynomial.y=tag(log, interval(0.00001, 1))"]
|
|
310
|
+
)
|
|
311
|
+
ax_parameter = create_ax_parameter_config(parameter)
|
|
312
|
+
assert isinstance(ax_parameter, RangeParameterConfig)
|
|
313
|
+
assert ax_parameter.name == "polynomial.y"
|
|
314
|
+
assert ax_parameter.parameter_type == "float"
|
|
315
|
+
assert ax_parameter.bounds == (1e-05, 1.0)
|
|
316
|
+
assert ax_parameter.scaling == "log"
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def test_ax_logging_from_hydra_app(tmpdir: Path) -> None:
|
|
320
|
+
cmd = [
|
|
321
|
+
"tests/apps/polynomial.py",
|
|
322
|
+
"-m",
|
|
323
|
+
f'hydra.run.dir="{str(tmpdir)}"',
|
|
324
|
+
"hydra.job.chdir=True",
|
|
325
|
+
"polynomial.x=interval(-5, -2)",
|
|
326
|
+
"polynomial.z=10",
|
|
327
|
+
"hydra.sweeper.ax_config.max_trials=2",
|
|
328
|
+
"polynomial.y=int(interval(-1, 2))",
|
|
329
|
+
]
|
|
330
|
+
result, _ = run_ax_python_script(cmd)
|
|
331
|
+
assert "polynomial.x: range=[-5.0, -2.0]" in result
|
|
332
|
+
assert "polynomial.y: range=[-1, 2]" in result
|
|
333
|
+
assert "polynomial.z: fixed=10" in result
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@mark.parametrize(
|
|
337
|
+
"cmd_args",
|
|
338
|
+
[
|
|
339
|
+
["polynomial.y=choice(-1, 0, 1)", "polynomial.x=range(2,4)"],
|
|
340
|
+
["polynomial.y=1", "polynomial.x=range(2,4)"],
|
|
341
|
+
],
|
|
342
|
+
)
|
|
343
|
+
def test_search_space_exhausted_exception(tmpdir: Path, cmd_args: List[str]) -> None:
|
|
344
|
+
cmd = [
|
|
345
|
+
"tests/apps/polynomial.py",
|
|
346
|
+
"-m",
|
|
347
|
+
f'hydra.run.dir="{str(tmpdir)}"',
|
|
348
|
+
"hydra.job.chdir=True",
|
|
349
|
+
"hydra.sweeper.ax_config.max_trials=2",
|
|
350
|
+
] + cmd_args
|
|
351
|
+
run_ax_python_script(cmd)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
@mark.parametrize(
|
|
355
|
+
"cmd_args",
|
|
356
|
+
[
|
|
357
|
+
["polynomial.y=choice(-1, 0, 1)", "polynomial.x=range(2,4)"],
|
|
358
|
+
["polynomial.y=1", "polynomial.x=range(2,4)"],
|
|
359
|
+
],
|
|
360
|
+
)
|
|
361
|
+
def test_search_space_with_constraint_metric(tmpdir: Path, cmd_args: List[str]) -> None:
|
|
362
|
+
# test that outcome_constraints experiment parameter `outcome_constraints`
|
|
363
|
+
# works correctly, and that the ax_sweeper supports outputting a dictionary
|
|
364
|
+
# from the evaluation function so that multiple metrics can be supported.
|
|
365
|
+
cmd = [
|
|
366
|
+
"tests/apps/polynomial_with_constraint.py",
|
|
367
|
+
"-m",
|
|
368
|
+
f'hydra.run.dir="{str(tmpdir)}"',
|
|
369
|
+
"hydra.job.chdir=True",
|
|
370
|
+
"hydra.sweeper.ax_config.max_trials=2",
|
|
371
|
+
] + cmd_args
|
|
372
|
+
results, _ = run_ax_python_script(cmd)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@mark.parametrize(
|
|
376
|
+
"cmd_arg, serialized_encoding, best_coefficients, best_value",
|
|
377
|
+
[
|
|
378
|
+
(
|
|
379
|
+
"polynomial.coefficients=[-1, 0, 1],[2, 3, 4],[5, 6, 7]",
|
|
380
|
+
"choice=['[-1,0,1]', '[2,3,4]', '[5,6,7]']",
|
|
381
|
+
"'[-1,0,1]'",
|
|
382
|
+
101.0,
|
|
383
|
+
),
|
|
384
|
+
(
|
|
385
|
+
"polynomial.coefficients=choice([8, 12, 11],[-1, -1, 1000], [-2, 4, 7])",
|
|
386
|
+
"choice=['[8,12,11]', '[-1,-1,1000]', '[-2,4,7]']",
|
|
387
|
+
"'[-2,4,7]'",
|
|
388
|
+
447,
|
|
389
|
+
),
|
|
390
|
+
],
|
|
391
|
+
)
|
|
392
|
+
def test_jobs_using_choice_between_lists(
|
|
393
|
+
tmpdir: Path,
|
|
394
|
+
cmd_arg: str,
|
|
395
|
+
serialized_encoding: str,
|
|
396
|
+
best_coefficients: str,
|
|
397
|
+
best_value: float,
|
|
398
|
+
) -> None:
|
|
399
|
+
cmd = [
|
|
400
|
+
"tests/apps/polynomial_with_list_coefficients.py",
|
|
401
|
+
"-m",
|
|
402
|
+
f'hydra.run.dir="{str(tmpdir)}"',
|
|
403
|
+
"hydra.job.chdir=True",
|
|
404
|
+
"hydra.sweeper.ax_config.max_trials=3",
|
|
405
|
+
] + [cmd_arg]
|
|
406
|
+
result, _ = run_ax_python_script(cmd)
|
|
407
|
+
assert f"polynomial.coefficients: {serialized_encoding}" in result
|
|
408
|
+
assert f"'polynomial.coefficients': {best_coefficients}" in result
|
|
409
|
+
assert f"New best value: {best_value}" in result
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
@mark.parametrize(
|
|
413
|
+
"cmd_arg, serialized_encoding, best_coefficients, best_value",
|
|
414
|
+
[
|
|
415
|
+
(
|
|
416
|
+
"+polynomial.coefficients=choice({x:-1, y:0, z:1},{x:2, y:3, z:4},{x:5, y:6, z:7})",
|
|
417
|
+
"choice=['{x:-1,y:0,z:1}', '{x:2,y:3,z:4}', '{x:5,y:6,z:7}']",
|
|
418
|
+
"'{x:-1,y:0,z:1}'",
|
|
419
|
+
101.0,
|
|
420
|
+
),
|
|
421
|
+
(
|
|
422
|
+
"+polynomial.coefficients=choice({x:8, y:12, z:11},{x:-1, y:-1, z:1000}, {x:-2, y:4, z:7})",
|
|
423
|
+
"choice=['{x:8,y:12,z:11}', '{x:-1,y:-1,z:1000}', '{x:-2,y:4,z:7}']",
|
|
424
|
+
"'{x:-2,y:4,z:7}'}",
|
|
425
|
+
447,
|
|
426
|
+
),
|
|
427
|
+
],
|
|
428
|
+
)
|
|
429
|
+
def test_jobs_using_choice_between_dicts(
|
|
430
|
+
tmpdir: Path,
|
|
431
|
+
cmd_arg: str,
|
|
432
|
+
serialized_encoding: str,
|
|
433
|
+
best_coefficients: str,
|
|
434
|
+
best_value: float,
|
|
435
|
+
) -> None:
|
|
436
|
+
cmd = [
|
|
437
|
+
"tests/apps/polynomial_with_dict_coefficients.py",
|
|
438
|
+
"-m",
|
|
439
|
+
f'hydra.run.dir="{str(tmpdir)}"',
|
|
440
|
+
"hydra.job.chdir=True",
|
|
441
|
+
"hydra.sweeper.ax_config.max_trials=3",
|
|
442
|
+
] + [cmd_arg]
|
|
443
|
+
result, _ = run_ax_python_script(cmd)
|
|
444
|
+
assert f"polynomial.coefficients: {serialized_encoding}" in result
|
|
445
|
+
assert f"'+polynomial.coefficients': {best_coefficients}" in result
|
|
446
|
+
assert f"New best value: {best_value}" in result
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def test_example_app(tmpdir: Path) -> None:
|
|
450
|
+
cmd = [
|
|
451
|
+
"example/banana.py",
|
|
452
|
+
"-m",
|
|
453
|
+
f'hydra.run.dir="{str(tmpdir)}"',
|
|
454
|
+
"hydra.job.chdir=True",
|
|
455
|
+
"banana.x=int(interval(-5, 5))",
|
|
456
|
+
"banana.y=interval(-5, 10.1)",
|
|
457
|
+
"hydra.sweeper.ax_config.max_trials=2",
|
|
458
|
+
]
|
|
459
|
+
result, _ = run_ax_python_script(cmd)
|
|
460
|
+
assert "banana.x: range=[-5, 5]" in result
|
|
461
|
+
assert "banana.y: range=[-5.0, 10.1]" in result
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
2
|
-
Name: hydra-ax-sweeper
|
|
3
|
-
Version: 1.3.0.dev0
|
|
4
|
-
Summary: Hydra Ax Sweeper plugin
|
|
5
|
-
Home-page: https://github.com/facebookresearch/hydra/
|
|
6
|
-
Author: Omry Yadan, Shagun Sodhani
|
|
7
|
-
Author-email: omry@fb.com, sshagunsodhani@gmail.com
|
|
8
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
-
Classifier: Programming Language :: Python :: 3.7
|
|
10
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
-
Classifier: Operating System :: POSIX :: Linux
|
|
14
|
-
Classifier: Operating System :: MacOS
|
|
15
|
-
Classifier: Development Status :: 4 - Beta
|
|
16
|
-
Description-Content-Type: text/markdown
|
|
17
|
-
|
|
18
|
-
# Hydra Ax Sweeper
|
|
19
|
-
Provides a [`Ax Sweeper`](https://ax.dev/) based Hydra Sweeper supporting parallel execution.
|
|
20
|
-
|
|
21
|
-
See [website](https://hydra.cc/docs/plugins/ax_sweeper/) for more information
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
2
|
-
Name: hydra-ax-sweeper
|
|
3
|
-
Version: 1.3.0.dev0
|
|
4
|
-
Summary: Hydra Ax Sweeper plugin
|
|
5
|
-
Home-page: https://github.com/facebookresearch/hydra/
|
|
6
|
-
Author: Omry Yadan, Shagun Sodhani
|
|
7
|
-
Author-email: omry@fb.com, sshagunsodhani@gmail.com
|
|
8
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
-
Classifier: Programming Language :: Python :: 3.7
|
|
10
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
-
Classifier: Operating System :: POSIX :: Linux
|
|
14
|
-
Classifier: Operating System :: MacOS
|
|
15
|
-
Classifier: Development Status :: 4 - Beta
|
|
16
|
-
Description-Content-Type: text/markdown
|
|
17
|
-
|
|
18
|
-
# Hydra Ax Sweeper
|
|
19
|
-
Provides a [`Ax Sweeper`](https://ax.dev/) based Hydra Sweeper supporting parallel execution.
|
|
20
|
-
|
|
21
|
-
See [website](https://hydra.cc/docs/plugins/ax_sweeper/) for more information
|
|
File without changes
|
|
File without changes
|
{hydra-ax-sweeper-1.3.0.dev0 → hydra_ax_sweeper-1.4.0.dev4}/hydra_ax_sweeper.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|