flytekitplugins-optuna 1.15.0__tar.gz

@@ -0,0 +1,28 @@
+ Metadata-Version: 2.2
+ Name: flytekitplugins-optuna
+ Version: 1.15.0
+ Summary: Optuna plugin for flytekit
+ Author: flyteorg
+ Author-email: admin@flyte.org
+ License: apache2
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Topic :: Scientific/Engineering
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.9
+ Requires-Dist: flytekit>=1.15.0
+ Requires-Dist: optuna<5.0.0,>=4.0.0
+ Requires-Dist: typing-extensions<5.0,>=4.10
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
@@ -0,0 +1,188 @@
+ # Fully Parallelized Wrapper Around Optuna Using Flyte
+
+ ## Overview
+
+ This is a guide to a fully parallelized Flyte plugin for Optuna. The wrapper leverages Flyte's scalable, distributed workflow orchestration to run Optuna hyperparameter-optimization trials concurrently and efficiently.
+
+ ![Timeline](timeline.png)
+
+ ## Features
+
+ - **Ease of Use**: This plugin requires no external data storage or experiment tracking.
+ - **Parallelized Trial Execution**: Enables concurrent execution of Optuna trials, dramatically speeding up optimization tasks.
+ - **Scalability**: Leverages Flyte's ability to scale horizontally to handle large-scale hyperparameter tuning jobs.
+ - **Flexible Integration**: Compatible with various machine learning frameworks and training pipelines.
+
+ ## Installation
+
+ - Install `flytekit`
+ - Install `flytekitplugins.optuna` (see the example below)
+
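+ For example, with pip (the distribution is published on PyPI as `flytekitplugins-optuna`, per the package metadata):
+
+ ```bash
+ pip install flytekit flytekitplugins-optuna
+ ```
+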
+ ## Getting Started
+
+ ### Prerequisites
+
+ - A Flyte deployment configured and running.
+ - Python 3.9 or later.
+ - Familiarity with Flyte and asynchronous programming.
+
+ ### Define the Objective Function
+
+ The objective function defines the problem to be optimized. It takes the hyperparameters to be tuned as inputs and returns a value to minimize or maximize.
+
+ ```python
+ import math
+
+ import flytekit as fl
+
+ image = fl.ImageSpec(packages=["flytekitplugins.optuna"])
+
+
+ @fl.task(container_image=image)
+ async def objective(x: float, y: int, z: int, power: int) -> float:
+     return math.log((((x - 5) ** 2) + (y + 4) ** 4 + (3 * z - 3) ** 2)) ** power
+ ```
+
+ ### Configure the Flyte Workflow
+
+ The Flyte workflow orchestrates the parallel execution of Optuna trials. Below is an example:
+
+ ```python
+ import flytekit as fl
+
+ from flytekitplugins.optuna import Optimizer, suggest
+
+
+ @fl.eager(container_image=image)
+ async def train(concurrency: int, n_trials: int) -> float:
+     optimizer = Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials)
+
+     await optimizer(
+         x=suggest.float(low=-10, high=10),
+         y=suggest.integer(low=-10, high=10),
+         z=suggest.category([-5, 0, 3, 6, 9]),
+         power=2,
+     )
+
+     return optimizer.study.best_value
+ ```
+
+ ### Register and Execute the Workflow
+
+ Submit the workflow to Flyte for execution:
+
+ ```bash
+ pyflyte register files .
+ pyflyte run --name train
+ ```
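+
+ For quick local iteration, the eager workflow can also be executed directly with `asyncio.run`, as the plugin's own tests do. A minimal sketch, assuming the `train` workflow defined above:
+
+ ```python
+ import asyncio
+
+ # run a small study locally: 2 trials at a time, 10 trials in total
+ loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+ print(loss)
+ ```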
+
+ ### Monitor Progress
+
+ You can monitor the progress of the trials via the Flyte Console. Each trial runs as a separate task, and the results are aggregated by the Optuna wrapper.
+
+ You may access the underlying `optuna.Study` via `optimizer.study`.
+
+ Therefore, with `plotly` installed, you may create Flyte Decks of the study like so:
+
+ ```python
+ import optuna
+ import plotly
+
+ # render the trial timeline as an HTML deck
+ fig = optuna.visualization.plot_timeline(optimizer.study)
+ fl.Deck("timeline", plotly.io.to_html(fig))
+ ```
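+
+ Any other figure from `optuna.visualization` (for example, `plot_param_importances`) can be rendered into a Deck the same way.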
+
+ ## Advanced Configuration
+
+ ### Custom Dictionary Inputs
+
+ Suggestions may also be defined inside nested dictionaries:
+
+ ```python
+ import flytekit as fl
+ import optuna
+
+ from flytekitplugins.optuna import Optimizer, suggest
+
+ image = fl.ImageSpec(packages=["flytekitplugins.optuna"])
+
+
+ @fl.task(container_image=image)
+ async def objective(params: dict[str, int | float | str]) -> float:
+     ...
+
+
+ @fl.eager(container_image=image)
+ async def train(concurrency: int, n_trials: int):
+     study = optuna.create_study(direction="maximize")
+
+     optimizer = Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials, study=study)
+
+     params = {
+         "lambda": suggest.float(1e-8, 1.0, log=True),
+         "alpha": suggest.float(1e-8, 1.0, log=True),
+         "subsample": suggest.float(0.2, 1.0),
+         "colsample_bytree": suggest.float(0.2, 1.0),
+         "max_depth": suggest.integer(3, 9, step=2),
+         "objective": "binary:logistic",
+         "tree_method": "exact",
+         "booster": "dart",
+     }
+
+     await optimizer(params=params)
+ ```
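+
+ Nested suggestions are registered with Optuna under dot-separated parameter names derived from their dictionary path (see `process` in `optimizer.py`), so the first entry above is recorded as the parameter `params.lambda`.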
+
+ ### Custom Callbacks
+
+ In some cases, you may need to define the suggestions programmatically. This may be done by writing a callback that accepts an `optuna.Trial` as its first argument and decorating it with `optimize`:
+
+ ```python
+ import flytekit as fl
+ import optuna
+
+ from flytekitplugins.optuna import optimize
+
+ image = fl.ImageSpec(packages=["flytekitplugins.optuna"])
+
+
+ @fl.task(container_image=image)
+ async def objective(params: dict[str, int | float | str]) -> float:
+     ...
+
+
+ @optimize
+ def optimizer(trial: optuna.Trial, verbosity: int, tree_method: str):
+     params = {
+         "verbosity": verbosity,
+         "tree_method": tree_method,
+         "objective": "binary:logistic",
+         # defines the booster; gblinear is for linear functions
+         "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]),
+         # sampling ratio of columns for each tree
+         "colsample_bytree": trial.suggest_float("colsample_bytree", 0.2, 1.0),
+     }
+
+     if params["booster"] in ["gbtree", "dart"]:
+         # maximum depth of the tree; signifies the complexity of the tree
+         params["max_depth"] = trial.suggest_int("max_depth", 3, 9, step=2)
+
+     if params["booster"] == "dart":
+         params["sample_type"] = trial.suggest_categorical("sample_type", ["uniform", "weighted"])
+         params["normalize_type"] = trial.suggest_categorical("normalize_type", ["tree", "forest"])
+
+     return objective(params)
+
+
+ @fl.eager(container_image=image)
+ async def train(concurrency: int, n_trials: int):
+     optimizer.concurrency = concurrency
+     optimizer.n_trials = n_trials
+     # the decorated optimizer's study may be replaced before it is awaited
+     optimizer.study = optuna.create_study(direction="maximize")
+
+     await optimizer(verbosity=0, tree_method="exact")
+ ```
+
+ ## Troubleshooting
+
+ - **Resource constraints**: Ensure sufficient compute resources are allocated for the specified `concurrency`.
+ - **Flyte errors**: Refer to the Flyte logs and documentation to debug workflow execution issues.
@@ -0,0 +1,3 @@
+ from .optimizer import Optimizer, optimize, suggest
+
+ __all__ = ["Optimizer", "optimize", "suggest"]
@@ -0,0 +1,212 @@
+ import asyncio
+ import inspect
+ from copy import copy, deepcopy
+ from dataclasses import dataclass
+ from types import SimpleNamespace
+ from typing import Any, Awaitable, Callable, Optional, Union
+
+ from typing_extensions import Concatenate, ParamSpec
+
+ import optuna
+ from flytekit.core.python_function_task import AsyncPythonFunctionTask
+ from flytekit.exceptions.eager import EagerException
+
+
+ class Suggestion: ...
+
+
+ class Number(Suggestion):
+     def __post_init__(self):
+         if self.low >= self.high:
+             raise ValueError("low must be less than high")
+
+         if self.step is not None and self.step > (self.high - self.low):
+             raise ValueError("step must be less than the range of the suggestion")
+
+
+ @dataclass
+ class Float(Number):
+     low: float
+     high: float
+     step: Optional[float] = None
+     log: bool = False
+
+
+ @dataclass
+ class Integer(Number):
+     low: int
+     high: int
+     step: int = 1
+     log: bool = False
+
+
+ @dataclass
+ class Category(Suggestion):
+     choices: list[Any]
+
+
+ # user-facing namespace of suggestion factories: suggest.float, suggest.integer, suggest.category
+ suggest = SimpleNamespace(float=Float, integer=Integer, category=Category)
+
+ P = ParamSpec("P")
+
+ Result = Union[float, tuple[float, ...]]
+
+ CallbackType = Callable[Concatenate[optuna.Trial, P], Union[Awaitable[Result], Result]]
+
+
+ @dataclass
+ class Optimizer:
+     """
+     Optimizer is a class that allows for the distributed optimization of a flytekit Task using Optuna.
+
+     Args:
+         objective: The objective function to be optimized. This can be an AsyncPythonFunctionTask or a callable.
+         concurrency: The number of trials to run concurrently.
+         n_trials: The number of trials to run in total.
+         study: The study to use for optimization. If None, a new study will be created.
+         delay: The delay in seconds between starting each trial. Default is 0.
+     """
+
+     objective: Union[CallbackType, AsyncPythonFunctionTask]
+     concurrency: int
+     n_trials: int
+     study: Optional[optuna.Study] = None
+     delay: int = 0
+
+     @property
+     def is_imperative(self) -> bool:
+         return isinstance(self.objective, AsyncPythonFunctionTask)
+
+     def __post_init__(self):
+         if self.study is None:
+             self.study = optuna.create_study()
+
+         if (not isinstance(self.concurrency, int)) or (not self.concurrency > 0):
+             raise ValueError("concurrency must be an integer greater than 0")
+
+         if (not isinstance(self.n_trials, int)) or (not self.n_trials > 0):
+             raise ValueError("n_trials must be an integer greater than 0")
+
+         if not isinstance(self.study, optuna.Study):
+             raise ValueError("study must be an optuna.Study")
+
+         if not isinstance(self.delay, int) or (not self.delay >= 0):
+             raise ValueError("delay must be an integer greater than or equal to 0")
+
+         if self.is_imperative:
+             signature = inspect.signature(self.objective.task_function)
+
+             if signature.return_annotation is float:
+                 if len(self.study.directions) != 1:
+                     raise ValueError("the study must have a single objective if the objective returns a single float")
+
+             elif hasattr(signature.return_annotation, "__args__"):
+                 args = signature.return_annotation.__args__
+                 if len(args) != len(self.study.directions):
+                     raise ValueError("objective must return the same number of values as the study has directions")
+
+                 if not all(arg is float for arg in args):
+                     raise ValueError("objective function must return a float or tuple of floats")
+
+             else:
+                 raise ValueError("objective function must return a float or tuple of floats")
+
+         else:
+             if not callable(self.objective):
+                 raise ValueError("objective must be a callable or an AsyncPythonFunctionTask")
+
+             signature = inspect.signature(self.objective)
+
+             if "trial" not in signature.parameters:
+                 raise ValueError(
+                     "objective function must have a parameter called 'trial' if not an AsyncPythonFunctionTask"
+                 )
+
+     async def __call__(self, **inputs: P.kwargs):
+         """
+         Asynchronously executes the objective function remotely.
+
+         Parameters:
+             **inputs: inputs to the objective function
+         """
+
+         # create a semaphore to manage concurrency
+         semaphore = asyncio.Semaphore(self.concurrency)
+
+         # create a list of async trials
+         trials = [self.spawn(semaphore, deepcopy(inputs)) for _ in range(self.n_trials)]
+
+         # await all trials to complete
+         await asyncio.gather(*trials)
+
+     async def spawn(self, semaphore: asyncio.Semaphore, inputs: dict[str, Any]):
+         async with semaphore:
+             await asyncio.sleep(self.delay)
+
+             # ask the study for a new trial
+             trial: optuna.Trial = self.study.ask()
+
+             try:
+                 result: Union[float, tuple[float, ...]]
+
+                 # schedule the trial
+                 if self.is_imperative:
+                     result = await self.objective(**process(trial, inputs))
+                 else:
+                     out = self.objective(trial=trial, **inputs)
+                     result = out if not inspect.isawaitable(out) else await out
+
+                 # tell the study the result
+                 self.study.tell(trial, result, state=optuna.trial.TrialState.COMPLETE)
+
+             # if the trial fails, tell the study
+             except EagerException:
+                 self.study.tell(trial, state=optuna.trial.TrialState.FAIL)
+
+
+ def optimize(
+     objective: Optional[Union[CallbackType, AsyncPythonFunctionTask]] = None,
+     concurrency: int = 1,
+     n_trials: int = 1,
+     study: Optional[optuna.Study] = None,
+ ):
+     if objective is not None:
+         if callable(objective) or isinstance(objective, AsyncPythonFunctionTask):
+             return Optimizer(
+                 objective=objective,
+                 concurrency=concurrency,
+                 n_trials=n_trials,
+                 study=study,
+             )
+         else:
+             raise ValueError("This decorator must be called with a callable or a flyte Task")
+     else:
+
+         def decorator(objective):
+             return Optimizer(objective=objective, concurrency=concurrency, n_trials=n_trials, study=study)
+
+         return decorator
+
+
+ def process(trial: optuna.Trial, inputs: dict[str, Any], root: Optional[list[str]] = None) -> dict[str, Any]:
+     # recursively replace Suggestion placeholders with concrete values drawn from the trial
+     if root is None:
+         root = []
+
+     suggesters = {
+         Float: trial.suggest_float,
+         Integer: trial.suggest_int,
+         Category: trial.suggest_categorical,
+     }
+
+     for key, value in inputs.items():
+         path = copy(root) + [key]
+
+         if isinstance(value, Suggestion):
+             suggester = suggesters[type(value)]
+             # parameter names are the dot-joined path into the (possibly nested) inputs
+             inputs[key] = suggester(name=".".join(path), **vars(value))
+
+         elif isinstance(value, dict):
+             inputs[key] = process(trial=trial, inputs=value, root=path)
+
+     return inputs
@@ -0,0 +1,28 @@
+ Metadata-Version: 2.2
+ Name: flytekitplugins-optuna
+ Version: 1.15.0
+ Summary: Optuna plugin for flytekit
+ Author: flyteorg
+ Author-email: admin@flyte.org
+ License: apache2
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Topic :: Scientific/Engineering
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.9
+ Requires-Dist: flytekit>=1.15.0
+ Requires-Dist: optuna<5.0.0,>=4.0.0
+ Requires-Dist: typing-extensions<5.0,>=4.10
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: license
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
@@ -0,0 +1,15 @@
+ README.md
+ setup.py
+ flytekitplugins/optuna/__init__.py
+ flytekitplugins/optuna/optimizer.py
+ flytekitplugins_optuna.egg-info/PKG-INFO
+ flytekitplugins_optuna.egg-info/SOURCES.txt
+ flytekitplugins_optuna.egg-info/dependency_links.txt
+ flytekitplugins_optuna.egg-info/entry_points.txt
+ flytekitplugins_optuna.egg-info/namespace_packages.txt
+ flytekitplugins_optuna.egg-info/requires.txt
+ flytekitplugins_optuna.egg-info/top_level.txt
+ tests/test_callback.py
+ tests/test_decorator.py
+ tests/test_imperative.py
+ tests/test_validation.py
@@ -0,0 +1,2 @@
+ [flytekit.plugins]
+ optuna = flytekitplugins.optuna
@@ -0,0 +1,3 @@
+ flytekit>=1.15.0
+ optuna<5.0.0,>=4.0.0
+ typing-extensions<5.0,>=4.10
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
@@ -0,0 +1,35 @@
+ from setuptools import setup
+
+ PLUGIN_NAME = "optuna"
+
+ microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
+
+ plugin_requires = ["flytekit>=1.15.0", "optuna>=4.0.0,<5.0.0", "typing-extensions>=4.10,<5.0"]
+
+ __version__ = "1.15.0"
+
+ setup(
+     name=microlib_name,
+     version=__version__,
+     author="flyteorg",
+     author_email="admin@flyte.org",
+     description="Optuna plugin for flytekit",
+     namespace_packages=["flytekitplugins"],
+     packages=[f"flytekitplugins.{PLUGIN_NAME}"],
+     install_requires=plugin_requires,
+     license="apache2",
+     python_requires=">=3.9",
+     classifiers=[
+         "Intended Audience :: Science/Research",
+         "Intended Audience :: Developers",
+         "License :: OSI Approved :: Apache Software License",
+         "Programming Language :: Python :: 3.9",
+         "Programming Language :: Python :: 3.10",
+         "Topic :: Scientific/Engineering",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence",
+         "Topic :: Software Development",
+         "Topic :: Software Development :: Libraries",
+         "Topic :: Software Development :: Libraries :: Python Modules",
+     ],
+     entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
+ )
@@ -0,0 +1,91 @@
+ import asyncio
+ from typing import Union
+
+ import optuna
+
+ import flytekit as fl
+ from flytekitplugins.optuna import Optimizer
+
+
+ def test_callback():
+     @fl.task
+     async def objective(letter: str, number: Union[float, int], other: str, fixed: str) -> float:
+         loss = len(letter) + number + len(other) + len(fixed)
+         return float(loss)
+
+     def callback(trial: optuna.Trial, fixed: str):
+         letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"])
+
+         if letter == "A":
+             number = trial.suggest_int("number_A", 1, 10)
+         elif letter == "B":
+             number = trial.suggest_float("number_B", 10.0, 20.0)
+         else:
+             number = 10
+
+         other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"])
+
+         return objective(letter, number, other, fixed)
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         study = optuna.create_study(direction="maximize")
+
+         optimizer = Optimizer(callback, concurrency=concurrency, n_trials=n_trials, study=study)
+
+         await optimizer(fixed="hello!")
+
+         return float(optimizer.study.best_value)
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
+
+
+ def test_async_callback():
+     @fl.task
+     async def objective(letter: str, number: Union[float, int], other: str, fixed: str) -> float:
+         loss = len(letter) + number + len(other) + len(fixed)
+         return float(loss)
+
+     async def callback(trial: optuna.Trial, fixed: str):
+         letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"])
+
+         if letter == "A":
+             number = trial.suggest_int("number_A", 1, 10)
+         elif letter == "B":
+             number = trial.suggest_float("number_B", 10.0, 20.0)
+         else:
+             number = 10
+
+         other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"])
+
+         return await objective(letter, number, other, fixed)
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         study = optuna.create_study(direction="maximize")
+
+         optimizer = Optimizer(callback, concurrency=concurrency, n_trials=n_trials, study=study)
+
+         await optimizer(fixed="hello!")
+
+         return float(optimizer.study.best_value)
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
@@ -0,0 +1,114 @@
+ import asyncio
+ from typing import Union
+
+ import optuna
+
+ import flytekit as fl
+ from flytekitplugins.optuna import optimize, suggest
+
+
+ def test_local_exec():
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         @optimize(concurrency=concurrency, n_trials=n_trials)
+         @fl.task
+         async def optimizer(x: float, y: int, z: int, power: int) -> float:
+             return (((x - 5) ** 2) + (y + 4) ** 4 + (3 * z - 3) ** 2) ** power
+
+         await optimizer(
+             x=suggest.float(low=-10, high=10),
+             y=suggest.integer(low=-10, high=10),
+             z=suggest.category([-5, 0, 3, 6, 9]),
+             power=2,
+         )
+
+         return optimizer.study.best_value
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
+
+
+ def test_callback():
+     @fl.task
+     async def objective(letter: str, number: Union[float, int], other: str, fixed: str) -> float:
+         loss = len(letter) + number + len(other) + len(fixed)
+         return float(loss)
+
+     @optimize(n_trials=10, concurrency=2)
+     def optimizer(trial: optuna.Trial, fixed: str):
+         letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"])
+
+         if letter == "A":
+             number = trial.suggest_int("number_A", 1, 10)
+         elif letter == "B":
+             number = trial.suggest_float("number_B", 10.0, 20.0)
+         else:
+             number = 10
+
+         other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"])
+
+         return objective(letter, number, other, fixed)
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         await optimizer(fixed="hello!")
+
+         return float(optimizer.study.best_value)
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
+
+
+ def test_unparameterized_callback():
+     @fl.task
+     async def objective(letter: str, number: Union[float, int], other: str, fixed: str) -> float:
+         loss = len(letter) + number + len(other) + len(fixed)
+         return float(loss)
+
+     @optimize
+     def optimizer(trial: optuna.Trial, fixed: str):
+         letter = trial.suggest_categorical("booster", ["A", "B", "BLAH"])
+
+         if letter == "A":
+             number = trial.suggest_int("number_A", 1, 10)
+         elif letter == "B":
+             number = trial.suggest_float("number_B", 10.0, 20.0)
+         else:
+             number = 10
+
+         other = trial.suggest_categorical("other", ["Something", "another word", "a phrase"])
+
+         return objective(letter, number, other, fixed)
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         optimizer.n_trials = n_trials
+         optimizer.concurrency = concurrency
+
+         await optimizer(fixed="hello!")
+
+         return float(optimizer.study.best_value)
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
@@ -0,0 +1,97 @@
+ import asyncio
+ from typing import Union
+
+ import optuna
+
+ import flytekit as fl
+ from flytekitplugins.optuna import Optimizer, suggest
+
+
+ def test_local_exec():
+     @fl.task
+     async def objective(x: float, y: int, z: int, power: int) -> float:
+         return (((x - 5) ** 2) + (y + 4) ** 4 + (3 * z - 3) ** 2) ** power
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         optimizer = Optimizer(objective, concurrency=concurrency, n_trials=n_trials)
+
+         await optimizer(
+             x=suggest.float(low=-10, high=10),
+             y=suggest.integer(low=-10, high=10),
+             z=suggest.category([-5, 0, 3, 6, 9]),
+             power=2,
+         )
+
+         return optimizer.study.best_value
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
+
+
+ def test_tuple_out():
+     @fl.task
+     async def objective(x: float, y: int, z: int, power: int) -> tuple[float, float]:
+         y0 = (((x - 5) ** 2) + (y + 4) ** 4 + (3 * z - 3) ** 2) ** power
+         y1 = (((x - 2) ** 4) + (y + 1) ** 2 + (4 * z - 1))
+
+         return y0, y1
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int):
+         optimizer = Optimizer(
+             objective=objective,
+             concurrency=concurrency,
+             n_trials=n_trials,
+             study=optuna.create_study(directions=["maximize", "maximize"]),
+         )
+
+         await optimizer(
+             x=suggest.float(low=-10, high=10),
+             y=suggest.integer(low=-10, high=10),
+             z=suggest.category([-5, 0, 3, 6, 9]),
+             power=2,
+         )
+
+     asyncio.run(train(concurrency=2, n_trials=10))
+
+
+ def test_bundled_local_exec():
+     @fl.task
+     async def objective(suggestions: dict[str, Union[int, float]], z: int, power: int) -> float:
+         # building out a large set of typed inputs is exhausting, so we can just use a dict
+         x, y = suggestions["x"], suggestions["y"]
+
+         return (((x - 5) ** 2) + (y + 4) ** 4) ** power
+
+     @fl.eager
+     async def train(concurrency: int, n_trials: int) -> float:
+         optimizer = Optimizer(objective, concurrency=concurrency, n_trials=n_trials)
+
+         suggestions = {
+             "x": suggest.float(low=-10, high=10),
+             "y": suggest.integer(low=-10, high=10),
+         }
+
+         await optimizer(
+             suggestions=suggestions,
+             z=suggest.category([-5, 0, 3, 6, 9]),
+             power=2,
+         )
+
+         return optimizer.study.best_value
+
+     loss = asyncio.run(train(concurrency=2, n_trials=10))
+
+     assert isinstance(loss, float)
@@ -0,0 +1,85 @@
+ import flytekit as fl
+ import pytest
+
+ from flytekitplugins.optuna import Optimizer
+
+
+ @fl.task
+ async def objective(x: float, y: int, z: int, power: int) -> float:
+     return 1.0
+
+
+ def test_concurrency():
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=-1, n_trials=10)
+
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=0, n_trials=10)
+
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency="abc", n_trials=10)
+
+
+ def test_n_trials():
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=3, n_trials=-10)
+
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=3, n_trials=0)
+
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=3, n_trials="abc")
+
+
+ def test_study():
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=3, n_trials=10, study="abc")
+
+
+ def test_delay():
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=3, n_trials=10, delay=-1)
+
+     with pytest.raises(ValueError):
+         Optimizer(objective, concurrency=3, n_trials=10, delay="abc")
+
+
+ def test_objective():
+     @fl.workflow
+     def workflow(x: int, y: int, z: int, power: int) -> float:
+         return 1.0
+
+     with pytest.raises(ValueError):
+         Optimizer(workflow, concurrency=3, n_trials=10)
+
+     @fl.task
+     async def mistyped_objective(x: int, y: int, z: int, power: int) -> str:
+         return "abc"
+
+     with pytest.raises(ValueError):
+         Optimizer(mistyped_objective, concurrency=3, n_trials=10)
+
+     @fl.task
+     def synchronous_objective(x: int, y: int, z: int, power: int) -> float:
+         return 1.0
+
+     with pytest.raises(ValueError):
+         Optimizer(synchronous_objective, concurrency=3, n_trials=10)
+
+
+ def test_callback():
+     def callback(value: float):
+         return 1.0
+
+     with pytest.raises(ValueError):
+         Optimizer(callback, concurrency=3, n_trials=10)