optuna-integration 3.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. optuna-integration-3.2.0/LICENSE +21 -0
  2. optuna-integration-3.2.0/MANIFEST.in +1 -0
  3. optuna-integration-3.2.0/PKG-INFO +112 -0
  4. optuna-integration-3.2.0/README.md +63 -0
  5. optuna-integration-3.2.0/optuna_integration/__init__.py +0 -0
  6. optuna-integration-3.2.0/optuna_integration/_imports.py +125 -0
  7. optuna-integration-3.2.0/optuna_integration/allennlp/__init__.py +6 -0
  8. optuna-integration-3.2.0/optuna_integration/allennlp/_dump_best_config.py +61 -0
  9. optuna-integration-3.2.0/optuna_integration/allennlp/_environment.py +12 -0
  10. optuna-integration-3.2.0/optuna_integration/allennlp/_executor.py +234 -0
  11. optuna-integration-3.2.0/optuna_integration/allennlp/_pruner.py +218 -0
  12. optuna-integration-3.2.0/optuna_integration/allennlp/_variables.py +72 -0
  13. optuna-integration-3.2.0/optuna_integration/catalyst.py +30 -0
  14. optuna-integration-3.2.0/optuna_integration/chainer.py +99 -0
  15. optuna-integration-3.2.0/optuna_integration/chainermn.py +329 -0
  16. optuna-integration-3.2.0/optuna_integration/keras.py +76 -0
  17. optuna-integration-3.2.0/optuna_integration/skorch.py +48 -0
  18. optuna-integration-3.2.0/optuna_integration/tensorflow.py +84 -0
  19. optuna-integration-3.2.0/optuna_integration/tfkeras.py +62 -0
  20. optuna-integration-3.2.0/optuna_integration/version.py +1 -0
  21. optuna-integration-3.2.0/optuna_integration.egg-info/PKG-INFO +112 -0
  22. optuna-integration-3.2.0/optuna_integration.egg-info/SOURCES.txt +33 -0
  23. optuna-integration-3.2.0/optuna_integration.egg-info/dependency_links.txt +1 -0
  24. optuna-integration-3.2.0/optuna_integration.egg-info/requires.txt +31 -0
  25. optuna-integration-3.2.0/optuna_integration.egg-info/top_level.txt +1 -0
  26. optuna-integration-3.2.0/pyproject.toml +99 -0
  27. optuna-integration-3.2.0/setup.cfg +27 -0
  28. optuna-integration-3.2.0/tests/test_catalyst.py +19 -0
  29. optuna-integration-3.2.0/tests/test_chainer.py +141 -0
  30. optuna-integration-3.2.0/tests/test_chainermn.py +427 -0
  31. optuna-integration-3.2.0/tests/test_keras.py +67 -0
  32. optuna-integration-3.2.0/tests/test_skorch.py +54 -0
  33. optuna-integration-3.2.0/tests/test_tensorflow.py +66 -0
  34. optuna-integration-3.2.0/tests/test_tfkeras.py +67 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2018 Preferred Networks, Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ include LICENSE
@@ -0,0 +1,112 @@
1
+ Metadata-Version: 2.1
2
+ Name: optuna-integration
3
+ Version: 3.2.0
4
+ Summary: Integration libraries of Optuna.
5
+ License: MIT License
6
+
7
+ Copyright (c) 2018 Preferred Networks, Inc.
8
+
9
+ Permission is hereby granted, free of charge, to any person obtaining a copy
10
+ of this software and associated documentation files (the "Software"), to deal
11
+ in the Software without restriction, including without limitation the rights
12
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13
+ copies of the Software, and to permit persons to whom the Software is
14
+ furnished to do so, subject to the following conditions:
15
+
16
+ The above copyright notice and this permission notice shall be included in all
17
+ copies or substantial portions of the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25
+ SOFTWARE.
26
+
27
+ Classifier: Development Status :: 5 - Production/Stable
28
+ Classifier: Intended Audience :: Science/Research
29
+ Classifier: Intended Audience :: Developers
30
+ Classifier: License :: OSI Approved :: MIT License
31
+ Classifier: Programming Language :: Python :: 3
32
+ Classifier: Programming Language :: Python :: 3.7
33
+ Classifier: Programming Language :: Python :: 3.8
34
+ Classifier: Programming Language :: Python :: 3.9
35
+ Classifier: Programming Language :: Python :: 3.10
36
+ Classifier: Programming Language :: Python :: 3 :: Only
37
+ Classifier: Topic :: Scientific/Engineering
38
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
39
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
40
+ Classifier: Topic :: Software Development
41
+ Classifier: Topic :: Software Development :: Libraries
42
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
43
+ Description-Content-Type: text/markdown
44
+ Provides-Extra: test
45
+ Provides-Extra: checking
46
+ Provides-Extra: document
47
+ Provides-Extra: all
48
+ License-File: LICENSE
49
+
50
+ # Optuna-Integration
51
+
52
+ [![Python](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue)](https://www.python.org)
53
+ [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/optuna/optuna-integration)
54
+ [![Codecov](https://codecov.io/gh/optuna/optuna-integration/branch/master/graph/badge.svg)](https://codecov.io/gh/optuna/optuna-integration/branch/master)
55
+ <!-- [![pypi](https://img.shields.io/pypi/v/optuna.svg)](https://pypi.python.org/pypi/optuna-integration) -->
56
+ <!-- [![conda](https://img.shields.io/conda/vn/conda-forge/optuna.svg)](https://anaconda.org/conda-forge/optuna-integration) -->
57
+ [![Read the Docs](https://readthedocs.org/projects/optuna-integration/badge/?version=stable)](https://optuna-integration.readthedocs.io/en/stable/)
58
+
59
+ [**Docs**](https://optuna-integration.readthedocs.io/en/stable/)
60
+
61
+ *Optuna-Integration* is an integration module of [Optuna](https://github.com/optuna/optuna).
62
+ This package allows us to use Optuna, an automatic hyperparameter optimization software framework,
63
+ integrated with many useful tools like PyTorch, sklearn, TensorFlow, etc.
64
+
65
+ ## Integrations
66
+
67
+ Optuna-Integration API reference is [here](https://optuna-integration.readthedocs.io/en/stable/reference/index.html).
68
+
69
+ * [AllenNLP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#allennlp) ([example](https://github.com/optuna/optuna-examples/tree/main/allennlp))
70
+ * [Catalyst](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catalyst) ([example](https://github.com/optuna/optuna-examples/blob/main/pytorch/catalyst_simple.py))
71
+ * [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py))
72
+ * [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py))
73
+ * [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras))
74
+ * [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
75
+ * [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
76
+
77
+ ## Installation
78
+
79
+ Optuna-Integration is available at [the Python Package Index](https://pypi.org/project/optuna-integration/) and on [Anaconda Cloud](https://anaconda.org/conda-forge/optuna-integration).
80
+
81
+ ```bash
82
+ # PyPI
83
+ $ pip install optuna-integration
84
+ ```
85
+
86
+ ```bash
87
+ # Anaconda Cloud
88
+ $ conda install -c conda-forge optuna-integration
89
+ ```
90
+
91
+ Optuna-Integration supports from Python 3.7 to Python 3.10.
92
+
93
+ We also provide Optuna docker images on [DockerHub](https://hub.docker.com/r/optuna/optuna).
94
+
95
+ ## Communication
96
+
97
+ - [GitHub Discussions] for questions.
98
+ - [GitHub Issues] for bug reports and feature requests.
99
+
100
+ [GitHub Discussions]: https://github.com/optuna/optuna-integration/discussions
101
+ [GitHub issues]: https://github.com/optuna/optuna-integration/issues
102
+
103
+ ## Contribution
104
+
105
+ Any contributions to Optuna-Integration are more than welcome!
106
+
107
+ For general guidelines on how to contribute to the project, take a look at [CONTRIBUTING.md](./CONTRIBUTING.md).
108
+
109
+ ## Reference
110
+
111
+ Takuya Akiba, Shotaro Sano, Toshihiko Yanase, Takeru Ohta, and Masanori Koyama. 2019.
112
+ Optuna: A Next-generation Hyperparameter Optimization Framework. In KDD ([arXiv](https://arxiv.org/abs/1907.10902)).
@@ -0,0 +1,63 @@
1
+ # Optuna-Integration
2
+
3
+ [![Python](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue)](https://www.python.org)
4
+ [![GitHub license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/optuna/optuna-integration)
5
+ [![Codecov](https://codecov.io/gh/optuna/optuna-integration/branch/master/graph/badge.svg)](https://codecov.io/gh/optuna/optuna-integration/branch/master)
6
+ <!-- [![pypi](https://img.shields.io/pypi/v/optuna.svg)](https://pypi.python.org/pypi/optuna-integration) -->
7
+ <!-- [![conda](https://img.shields.io/conda/vn/conda-forge/optuna.svg)](https://anaconda.org/conda-forge/optuna-integration) -->
8
+ [![Read the Docs](https://readthedocs.org/projects/optuna-integration/badge/?version=stable)](https://optuna-integration.readthedocs.io/en/stable/)
9
+
10
+ [**Docs**](https://optuna-integration.readthedocs.io/en/stable/)
11
+
12
+ *Optuna-Integration* is an integration module of [Optuna](https://github.com/optuna/optuna).
13
+ This package allows us to use Optuna, an automatic hyperparameter optimization software framework,
14
+ integrated with many useful tools like PyTorch, sklearn, TensorFlow, etc.
15
+
16
+ ## Integrations
17
+
18
+ Optuna-Integration API reference is [here](https://optuna-integration.readthedocs.io/en/stable/reference/index.html).
19
+
20
+ * [AllenNLP](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#allennlp) ([example](https://github.com/optuna/optuna-examples/tree/main/allennlp))
21
+ * [Catalyst](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#catalyst) ([example](https://github.com/optuna/optuna-examples/blob/main/pytorch/catalyst_simple.py))
22
+ * [Chainer](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainer) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainer_integration.py))
23
+ * [ChainerMN](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#chainermn) ([example](https://github.com/optuna/optuna-examples/tree/main/chainer/chainermn_simple.py))
24
+ * [Keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#keras) ([example](https://github.com/optuna/optuna-examples/tree/main/keras))
25
+ * [skorch](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#skorch) ([example](https://github.com/optuna/optuna-examples/tree/main/pytorch/skorch_simple.py))
26
+ * [tf.keras](https://optuna-integration.readthedocs.io/en/stable/reference/index.html#tensorflow) ([example](https://github.com/optuna/optuna-examples/tree/main/tfkeras/tfkeras_integration.py))
27
+
28
+ ## Installation
29
+
30
+ Optuna-Integration is available at [the Python Package Index](https://pypi.org/project/optuna-integration/) and on [Anaconda Cloud](https://anaconda.org/conda-forge/optuna-integration).
31
+
32
+ ```bash
33
+ # PyPI
34
+ $ pip install optuna-integration
35
+ ```
36
+
37
+ ```bash
38
+ # Anaconda Cloud
39
+ $ conda install -c conda-forge optuna-integration
40
+ ```
41
+
42
+ Optuna-Integration supports from Python 3.7 to Python 3.10.
43
+
44
+ We also provide Optuna docker images on [DockerHub](https://hub.docker.com/r/optuna/optuna).
45
+
46
+ ## Communication
47
+
48
+ - [GitHub Discussions] for questions.
49
+ - [GitHub Issues] for bug reports and feature requests.
50
+
51
+ [GitHub Discussions]: https://github.com/optuna/optuna-integration/discussions
52
+ [GitHub issues]: https://github.com/optuna/optuna-integration/issues
53
+
54
+ ## Contribution
55
+
56
+ Any contributions to Optuna-Integration are more than welcome!
57
+
58
+ For general guidelines on how to contribute to the project, take a look at [CONTRIBUTING.md](./CONTRIBUTING.md).
59
+
60
+ ## Reference
61
+
62
+ Takuya Akiba, Shotaro Sano, Toshihiko Yanase, Takeru Ohta, and Masanori Koyama. 2019.
63
+ Optuna: A Next-generation Hyperparameter Optimization Framework. In KDD ([arXiv](https://arxiv.org/abs/1907.10902)).
@@ -0,0 +1,125 @@
1
+ import importlib
2
+ import types
3
+ from types import TracebackType
4
+ from typing import Any
5
+ from typing import Optional
6
+ from typing import Tuple
7
+ from typing import Type
8
+
9
+
10
+ class _DeferredImportExceptionContextManager:
11
+ """Context manager to defer exceptions from imports.
12
+
13
+ Catches :exc:`ImportError` and :exc:`SyntaxError`.
14
+ If any exception is caught, this class raises an :exc:`ImportError` when being checked.
15
+
16
+ """
17
+
18
+ def __init__(self) -> None:
19
+ self._deferred: Optional[Tuple[Exception, str]] = None
20
+
21
+ def __enter__(self) -> "_DeferredImportExceptionContextManager":
22
+ """Enter the context manager.
23
+
24
+ Returns:
25
+ Itself.
26
+
27
+ """
28
+ return self
29
+
30
+ def __exit__(
31
+ self,
32
+ exc_type: Optional[Type[Exception]],
33
+ exc_value: Optional[Exception],
34
+ traceback: Optional[TracebackType],
35
+ ) -> Optional[bool]:
36
+ """Exit the context manager.
37
+
38
+ Args:
39
+ exc_type:
40
+ Raised exception type. :obj:`None` if nothing is raised.
41
+ exc_value:
42
+ Raised exception object. :obj:`None` if nothing is raised.
43
+ traceback:
44
+ Associated traceback. :obj:`None` if nothing is raised.
45
+
46
+ Returns:
47
+ :obj:`None` if nothing is deferred, otherwise :obj:`True`.
48
+ :obj:`True` will suppress any exceptions avoiding them from propagating.
49
+
50
+ """
51
+ if isinstance(exc_value, (ImportError, SyntaxError)):
52
+ if isinstance(exc_value, ImportError):
53
+ message = (
54
+ "Tried to import '{}' but failed. Please make sure that the package is "
55
+ "installed correctly to use this feature. Actual error: {}."
56
+ ).format(exc_value.name, exc_value)
57
+ elif isinstance(exc_value, SyntaxError):
58
+ message = (
59
+ "Tried to import a package but failed due to a syntax error in {}. Please "
60
+ "make sure that the Python version is correct to use this feature. Actual "
61
+ "error: {}."
62
+ ).format(exc_value.filename, exc_value)
63
+ else:
64
+ assert False
65
+
66
+ self._deferred = (exc_value, message)
67
+ return True
68
+ return None
69
+
70
+ def is_successful(self) -> bool:
71
+ """Return whether the context manager has caught any exceptions.
72
+
73
+ Returns:
74
+ :obj:`True` if no exceptions are caught, :obj:`False` otherwise.
75
+
76
+ """
77
+ return self._deferred is None
78
+
79
+ def check(self) -> None:
80
+ """Check whether the context manager has caught any exceptions.
81
+
82
+ Raises:
83
+ :exc:`ImportError`:
84
+ If any exception was caught from the caught exception.
85
+
86
+ """
87
+ if self._deferred is not None:
88
+ exc_value, message = self._deferred
89
+ raise ImportError(message) from exc_value
90
+
91
+
92
def try_import() -> _DeferredImportExceptionContextManager:
    """Create a context manager that can wrap imports of optional packages to defer exceptions.

    Returns:
        Deferred import context manager.

    """
    manager = _DeferredImportExceptionContextManager()
    return manager
100
+
101
+
102
+ class _LazyImport(types.ModuleType):
103
+ """Module wrapper for lazy import.
104
+
105
+ This class wraps the specified modules and lazily imports them only when accessed.
106
+ Otherwise, `import optuna-integration` is slowed down by importing all submodules and
107
+ dependencies even if not required.
108
+ Within this project's usage, importlib override this module's attribute on the first
109
+ access and the imported submodule is directly accessed from the second access.
110
+
111
+ Args:
112
+ name: Name of module to apply lazy import.
113
+ """
114
+
115
+ def __init__(self, name: str) -> None:
116
+ super().__init__(name)
117
+ self._name = name
118
+
119
+ def _load(self) -> types.ModuleType:
120
+ module = importlib.import_module(self._name)
121
+ self.__dict__.update(module.__dict__)
122
+ return module
123
+
124
+ def __getattr__(self, item: str) -> Any:
125
+ return getattr(self._load(), item)
@@ -0,0 +1,6 @@
1
+ from optuna_integration.allennlp._dump_best_config import dump_best_config
2
+ from optuna_integration.allennlp._executor import AllenNLPExecutor
3
+ from optuna_integration.allennlp._pruner import AllenNLPPruningCallback
4
+
5
+
6
+ __all__ = ["dump_best_config", "AllenNLPExecutor", "AllenNLPPruningCallback"]
@@ -0,0 +1,61 @@
1
+ import json
2
+
3
+ import optuna
4
+
5
+ from optuna_integration._imports import try_import
6
+ from optuna_integration.allennlp._environment import _environment_variables
7
+
8
+
9
+ with try_import() as _imports:
10
+ import _jsonnet
11
+
12
+
13
def dump_best_config(input_config_file: str, output_config_file: str, study: optuna.Study) -> None:
    """Save JSON config file with environment variables and best performing hyperparameters.

    Args:
        input_config_file:
            Input Jsonnet config file used with
            :class:`~optuna_integration.AllenNLPExecutor`.
        output_config_file:
            Output JSON config file.
        study:
            Instance of :class:`~optuna.study.Study`.
            Note that :func:`~optuna.study.Study.optimize` must have been called.

    """
    _imports.check()

    # Start from the current environment variables, then overlay the study's
    # best hyperparameters (stringified). Overwriting means best_params is
    # prioritized over identically named environment variables.
    ext_vars = _environment_variables()
    ext_vars.update({key: str(value) for key, value in study.best_params.items()})

    best_config = json.loads(_jsonnet.evaluate_file(input_config_file, ext_vars=ext_vars))

    # Strip the `optuna_pruner` callback: it only works with Optuna, and the
    # dumped configuration should be usable directly with `allennlp train`.
    trainer_config = best_config["trainer"]
    if "callbacks" in trainer_config:
        kept_callbacks = [
            callback
            for callback in trainer_config["callbacks"]
            if callback["type"] != "optuna_pruner"
        ]
        if kept_callbacks:
            trainer_config["callbacks"] = kept_callbacks
        else:
            trainer_config.pop("callbacks")

    with open(output_config_file, "w") as f:
        json.dump(best_config, f, indent=4)
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+
6
+ def _is_encodable(value: str) -> bool:
7
+ # https://github.com/allenai/allennlp/blob/master/allennlp/common/params.py#L77-L85
8
+ return (value == "") or (value.encode("utf-8", "ignore") != b"")
9
+
10
+
11
+ def _environment_variables() -> dict[str, str]:
12
+ return {key: value for key, value in os.environ.items() if _is_encodable(value)}
@@ -0,0 +1,234 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from typing import Any
6
+ import warnings
7
+
8
+ import optuna
9
+ from optuna import TrialPruned
10
+ from optuna._experimental import experimental_class
11
+
12
+ from optuna_integration._imports import try_import
13
+ from optuna_integration.allennlp._environment import _environment_variables
14
+ from optuna_integration.allennlp._variables import _VariableManager
15
+ from optuna_integration.allennlp._variables import OPTUNA_ALLENNLP_DISTRIBUTED_FLAG
16
+
17
+
18
+ with try_import() as _imports:
19
+ import allennlp
20
+ import allennlp.commands
21
+ import allennlp.common.cached_transformers
22
+ import allennlp.common.util
23
+
24
+ # TrainerCallback is conditionally imported because allennlp may be unavailable in
25
+ # the environment that builds the documentation.
26
+ if _imports.is_successful():
27
+ import _jsonnet
28
+ import psutil
29
+ from torch.multiprocessing.spawn import ProcessRaisedException
30
+
31
+
32
def _fetch_pruner_config(trial: optuna.Trial) -> dict[str, Any]:
    """Reconstruct the constructor arguments of the study's pruner.

    The returned mapping mirrors the private attributes of the supported
    pruner classes, one entry per constructor keyword argument.

    Args:
        trial: Trial whose study's pruner is inspected.

    Returns:
        Mapping from constructor argument name to its current value. Empty
        for :class:`~optuna.pruners.NopPruner`.

    Raises:
        ValueError: If the study uses an unsupported pruner type.

    """
    pruner = trial.study.pruner

    if isinstance(pruner, optuna.pruners.NopPruner):
        return {}

    # Each supported class paired with the names of its constructor kwargs.
    # The isinstance checks run in the same order as the original dispatch
    # chain so that more specific classes are matched first.
    per_class_attrs = [
        (
            optuna.pruners.HyperbandPruner,
            ("min_resource", "max_resource", "reduction_factor"),
        ),
        (
            optuna.pruners.MedianPruner,
            ("n_startup_trials", "n_warmup_steps", "interval_steps"),
        ),
        (
            optuna.pruners.PercentilePruner,
            ("percentile", "n_startup_trials", "n_warmup_steps", "interval_steps"),
        ),
        (
            optuna.pruners.SuccessiveHalvingPruner,
            ("min_resource", "reduction_factor", "min_early_stopping_rate"),
        ),
        (
            optuna.pruners.ThresholdPruner,
            ("lower", "upper", "n_warmup_steps", "interval_steps"),
        ),
    ]

    for pruner_cls, attr_names in per_class_attrs:
        if isinstance(pruner, pruner_cls):
            # Each public kwarg is mirrored by a same-named private attribute.
            return {name: getattr(pruner, "_" + name) for name in attr_names}

    raise ValueError("Unsupported pruner is specified: {}".format(type(pruner)))
68
+
69
+
70
@experimental_class("1.4.0")
class AllenNLPExecutor:
    """AllenNLP extension to use optuna with Jsonnet config file.

    See the examples of `objective function <https://github.com/optuna/optuna-examples/tree/
    main/allennlp/allennlp_jsonnet.py>`_.

    You can also see the tutorial of our AllenNLP integration on
    `AllenNLP Guide <https://guide.allennlp.org/hyperparameter-optimization>`_.

    .. note::
        From Optuna v2.1.0, users have to cast their parameters by using methods in Jsonnet.
        Call ``std.parseInt`` for integer, or ``std.parseJson`` for floating point.
        Please see the `example configuration <https://github.com/optuna/optuna-examples/tree/main/
        allennlp/classifier.jsonnet>`_.

    .. note::
        In :class:`~optuna_integration.AllenNLPExecutor`,
        you can pass parameters to AllenNLP by either defining a search space using
        Optuna suggest methods or setting environment variables just like AllenNLP CLI.
        If a value is set in both a search space in Optuna and the environment variables,
        the executor will use the value specified in the search space in Optuna.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation
            of the objective function.
        config_file:
            Config file for AllenNLP.
            Hyperparameters should be masked with ``std.extVar``.
            Please refer to `the config example <https://github.com/allenai/allentune/blob/
            master/examples/classifier.jsonnet>`_.
        serialization_dir:
            A path which model weights and logs are saved.
        metrics:
            An evaluation metric. `GradientDescentTrainer.train() <https://docs.allennlp.org/
            main/api/training/gradient_descent_trainer/#train>`_ of AllenNLP
            returns a dictionary containing metrics after training.
            :class:`~optuna_integration.AllenNLPExecutor` accesses the dictionary
            by the key ``metrics`` you specify and use it as an objective value.
        force:
            If :obj:`True`, an executor overwrites the output directory if it exists.
        file_friendly_logging:
            If :obj:`True`, tqdm status is printed on separate lines and slows tqdm refresh rate.
        include_package:
            Additional packages to include.
            For more information, please see
            `AllenNLP documentation <https://docs.allennlp.org/master/api/commands/train/>`_.

    """

    def __init__(
        self,
        trial: optuna.Trial,
        config_file: str,
        serialization_dir: str,
        metrics: str = "best_validation_accuracy",
        *,
        include_package: str | list[str] | None = None,
        force: bool = False,
        file_friendly_logging: bool = False,
    ):
        _imports.check()

        # Snapshot the trial's suggested parameters; they are merged into the
        # Jsonnet external variables when the config is evaluated in `run()`.
        self._params = trial.params
        self._config_file = config_file
        self._serialization_dir = serialization_dir
        self._metrics = metrics
        self._force = force
        self._file_friendly_logging = file_friendly_logging

        # Normalize `include_package` to a list, then always append our own
        # package so AllenNLP can resolve the `optuna_pruner` callback type.
        if include_package is None:
            include_package = []
        if isinstance(include_package, str):
            include_package = [include_package]

        self._include_package = include_package + ["optuna_integration.allennlp"]

        # Extract a storage URL so that the training process can reconnect to
        # the same study. Non-RDB storages (e.g. in-memory) yield an empty URL.
        storage = trial.study._storage

        if isinstance(storage, optuna.storages.RDBStorage):
            url = storage.url

        elif isinstance(storage, optuna.storages._CachedStorage):
            assert isinstance(storage._backend, optuna.storages.RDBStorage)
            url = storage._backend.url

        else:
            url = ""

        # Publish study/trial metadata keyed by the parent process id —
        # presumably so the pruning callback running inside a spawned training
        # process can look it up; confirm against `_VariableManager`.
        target_pid = psutil.Process().ppid()
        variable_manager = _VariableManager(target_pid)

        pruner_kwargs = _fetch_pruner_config(trial)
        variable_manager.set_value("study_name", trial.study.study_name)
        variable_manager.set_value("trial_id", trial._trial_id)
        variable_manager.set_value("storage_name", url)
        variable_manager.set_value("monitor", metrics)

        if trial.study.pruner is not None:
            variable_manager.set_value("pruner_class", type(trial.study.pruner).__name__)
            variable_manager.set_value("pruner_kwargs", pruner_kwargs)

    def _build_params(self) -> dict[str, Any]:
        """Create a dict of params for AllenNLP.

        _build_params is based on allentune's ``train_func``.
        For more detail, please refer to
        https://github.com/allenai/allentune/blob/master/allentune/modules/allennlp_runner.py#L34-L65

        """
        # Environment variables first, then trial params: a parameter suggested
        # by Optuna overrides an identically named environment variable.
        params = _environment_variables()
        params.update({key: str(value) for key, value in self._params.items()})
        return json.loads(_jsonnet.evaluate_file(self._config_file, ext_vars=params))

    def _set_environment_variables(self) -> None:
        """Re-export the encodable environment variables onto ``os.environ``."""
        # NOTE(review): `key is None` can never hold for os.environ keys, which
        # are always strings; the guard looks vestigial — confirm before removal.
        for key, value in _environment_variables().items():
            if key is None:
                continue
            os.environ[key] = value

    def run(self) -> float:
        """Train a model using AllenNLP.

        Returns:
            The value stored under the configured ``metrics`` key in the
            ``metrics.json`` file AllenNLP writes to the serialization
            directory.

        Raises:
            TrialPruned: If a spawned training worker signaled pruning.
        """
        # Import user packages (plus our own) so their registered components
        # are visible to AllenNLP when the config is instantiated.
        for package_name in self._include_package:
            allennlp.common.util.import_module_and_submodules(package_name)

        # Without the following lines, the transformer model construction only takes place in the
        # first trial (which would consume some random numbers), and the cached model will be used
        # in trials afterwards (which would not consume random numbers), leading to inconsistent
        # results between single trial and multiple trials. To make results reproducible in
        # multiple trials, we clear the cache before each trial.
        # TODO(MagiaSN) When AllenNLP has introduced a better API to do this, one should remove
        # these lines and use the new API instead. For example, use the `_clear_caches()` method
        # which will be in the next AllenNLP release after 2.4.0.
        allennlp.common.cached_transformers._model_cache.clear()
        allennlp.common.cached_transformers._tokenizer_cache.clear()

        self._set_environment_variables()
        params = allennlp.common.params.Params(self._build_params())

        # For distributed training, mark via an environment flag that a run is
        # in flight; a pre-existing flag suggests another run may be active.
        if "distributed" in params:
            if OPTUNA_ALLENNLP_DISTRIBUTED_FLAG in os.environ:
                warnings.warn(
                    "Other process may already exists."
                    " If you have trouble, please unset the environment"
                    " variable `OPTUNA_ALLENNLP_USE_DISTRIBUTED`"
                    " and try it again."
                )

            os.environ[OPTUNA_ALLENNLP_DISTRIBUTED_FLAG] = "1"

        try:
            allennlp.commands.train.train_model(
                params=params,
                serialization_dir=self._serialization_dir,
                file_friendly_logging=self._file_friendly_logging,
                force=self._force,
                include_package=self._include_package,
            )
        except ProcessRaisedException as e:
            # Worker exceptions from torch's spawn arrive wrapped in
            # ProcessRaisedException; translate a pruning signal back into
            # TrialPruned for Optuna.
            # NOTE(review): a ProcessRaisedException whose message does not
            # mention "raise TrialPruned()" is silently swallowed here, and
            # execution falls through to reading metrics.json — confirm this
            # is intended.
            if "raise TrialPruned()" in str(e):
                raise TrialPruned()

        # NOTE(review): the file handle is never closed explicitly; a `with`
        # block would be safer.
        metrics = json.load(open(os.path.join(self._serialization_dir, "metrics.json")))
        return metrics[self._metrics]