ato-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ato has been flagged as possibly problematic.
- ato/__init__.py +1 -0
- ato/adict.py +582 -0
- ato/db_routers/__init__.py +8 -0
- ato/db_routers/sql/__init__.py +0 -0
- ato/db_routers/sql/manager.py +188 -0
- ato/db_routers/sql/schema.py +83 -0
- ato/hyperopt/__init__.py +0 -0
- ato/hyperopt/base.py +144 -0
- ato/hyperopt/hyperband.py +103 -0
- ato/parser.py +103 -0
- ato/scope.py +491 -0
- ato/utils.py +55 -0
- ato/xyz.py +234 -0
- ato-2.0.0.dist-info/METADATA +1181 -0
- ato-2.0.0.dist-info/RECORD +18 -0
- ato-2.0.0.dist-info/WHEEL +5 -0
- ato-2.0.0.dist-info/licenses/LICENSE +21 -0
- ato-2.0.0.dist-info/top_level.txt +1 -0
ato/db_routers/sql/manager.py
ADDED

@@ -0,0 +1,188 @@

import datetime
from typing import Optional

from sqlalchemy import create_engine, func
from sqlalchemy.orm import sessionmaker, Session as SessionType

from ato.db_routers import BaseLogger, BaseFinder
from ato.db_routers.sql.schema import Base, Project, Experiment, Metric, Artifact, Fingerprint


class SQLLogger(BaseLogger):
    registry: set[str] = set()

    def __init__(self, config):
        super().__init__(config)
        db_path = self.config.experiment.sql.db_path
        self.engine = create_engine(db_path)
        registry = self.__class__.registry
        if db_path not in registry:
            Base.metadata.create_all(self.engine)  # Base.metadata is SQLAlchemy's internal attribute
            registry.add(db_path)
        self.session = sessionmaker(bind=self.engine)()
        self._current_run_id = None

    def get_current_run(self):
        return self.session.get(Experiment, self.get_current_run_id())

    def get_current_run_id(self):
        return self._current_run_id

    def get_or_create_project(self):
        project = self.session.query(Project).filter_by(name=self.config.experiment.project_name).first()
        if not project:
            project = Project(name=self.config.experiment.project_name)
            self.session.add(project)
            self.session.commit()
        return project

    def update_status(self, status):
        run = self.get_current_run()
        if run:
            run.status = status
            self.session.commit()

    def run(self, tags=None):
        project = self.get_or_create_project()
        structural_hash = self.config.get_structural_hash()
        run = Experiment(
            project_id=project.id,
            config=self.config.to_dict(),  # Recursively convert ADict to dict for JSON serialization
            structural_hash=structural_hash,
            tags=tags or []
        )
        self.session.add(run)
        self.session.commit()
        self._current_run_id = run.id
        return run.id

    def log_metric(self, key, value, step):
        metric = Metric(
            run_id=self.get_current_run_id(),
            key=key,
            value=value,
            step=step
        )
        self.session.add(metric)
        self.session.commit()

    def log_artifact(self, run_id, file_path, data_type, metadata=None):
        artifact = Artifact(
            run_id=run_id,
            path=file_path,
            data_type=data_type,
            data_info=metadata  # Column name is data_info in schema
        )
        self.session.add(artifact)
        self.session.commit()

    def finish(self, status='completed'):
        run = self.get_current_run()
        if run:
            run.status = status
            run.end_time = datetime.datetime.now(datetime.timezone.utc)
            self.session.commit()


class SQLFinder(BaseFinder):
    def __init__(self, config):
        super().__init__(config)
        self.engine = create_engine(self.config.experiment.sql.db_path)
        self.session_maker = sessionmaker(bind=self.engine)

    def _get_session(self) -> SessionType:
        return self.session_maker()

    def find_project(self, project_name) -> Optional[Project]:
        with self._get_session() as session:
            return session.query(Project).filter_by(name=project_name).first()

    def find_run(self, run_id: int) -> Optional[Experiment]:
        with self._get_session() as session:
            return session.get(Experiment, run_id)

    def get_runs_in_project(self, project_name) -> list[Experiment]:
        project = self.find_project(project_name)
        if not project:
            return []
        else:
            with self._get_session() as session:
                return session.query(Experiment).filter_by(project_id=project.id).all()

    def find_similar_runs(self, run_id: int) -> list[Experiment]:
        with self._get_session() as session:
            base_run = session.query(Experiment.structural_hash).filter_by(id=run_id).first()
            if not base_run:
                return []
            else:
                target_hash = base_run.structural_hash
                return session.query(Experiment).filter(
                    Experiment.structural_hash == target_hash,
                    Experiment.id != run_id
                ).all()

    def find_similar_runs_by_trace(
        self,
        run_id: int,
        trace_id: str,
        trace_type: str = 'static'
    ) -> list[Experiment]:
        with self._get_session() as session:
            target_fingerprint = session.query(Fingerprint.fingerprint).filter(
                Fingerprint.run_id == run_id,
                Fingerprint.trace_id == trace_id,
                Fingerprint.trace_type == trace_type
            ).scalar()
            if not target_fingerprint:
                return []
            query = session.query(Fingerprint.run_id).filter(
                Fingerprint.trace_id == trace_id,
                Fingerprint.trace_type == trace_type,
                Fingerprint.fingerprint == target_fingerprint,
                Fingerprint.run_id != run_id
            ).distinct()
            return session.query(Experiment).filter(Experiment.id.in_(query)).all()

    def find_best_run(self, project_name: str, metric_key: str, mode: str = 'max') -> Experiment | dict:
        project = self.find_project(project_name)
        if project:
            with self._get_session() as session:
                order_by_col = Metric.value.desc() if mode == 'max' else Metric.value.asc()
                best_run = session.query(Experiment, Metric.value).join(
                    Metric,
                    Experiment.id == Metric.run_id
                ).filter(
                    Experiment.project_id == project.id,
                    Metric.key == metric_key
                ).order_by(
                    order_by_col
                ).first()
                if best_run:
                    return best_run[0]
                else:
                    return {'error': 'Run not found'}
        return {'error': 'Project not found'}

    def get_trace_statistics(self, project_name: str, trace_id: str) -> dict:
        project = self.find_project(project_name)
        if not project:
            return {'error': 'Project not found'}
        else:
            with self._get_session() as session:
                run_ids_in_project = session.query(Experiment.id).filter_by(project_id=project.id)
                static_count = session.query(func.count(Fingerprint.fingerprint.distinct())).filter(
                    Fingerprint.run_id.in_(run_ids_in_project),
                    Fingerprint.trace_id == trace_id,
                    Fingerprint.trace_type == 'static'
                ).scalar()
                runtime_count = session.query(func.count(Fingerprint.fingerprint.distinct())).filter(
                    Fingerprint.run_id.in_(run_ids_in_project),
                    Fingerprint.trace_id == trace_id,
                    Fingerprint.trace_type == 'runtime'  # trace type recorded by @runtime_trace
                ).scalar()
                return {
                    'project_name': project_name,
                    'trace_id': trace_id,
                    'static_trace_versions': static_count,
                    'runtime_trace_versions': runtime_count
                }
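
For reference, a minimal usage sketch of the two classes above. It assumes that ADict (ato/adict.py, not shown in this hunk) supports attribute-style nested construction and provides the to_dict() and get_structural_hash() helpers that SQLLogger.run() relies on, and that BaseLogger/BaseFinder simply store the config they receive; treat the constructor call below as illustrative rather than confirmed.

    from ato.adict import ADict
    from ato.db_routers.sql.manager import SQLLogger, SQLFinder

    # Assumed ADict construction style; SQLLogger/SQLFinder only read
    # experiment.sql.db_path and experiment.project_name from the config.
    config = ADict(
        experiment=ADict(
            project_name='demo',
            sql=ADict(db_path='sqlite:///experiments.db'),
        ),
        lr=1e-3,
    )

    logger = SQLLogger(config)
    run_id = logger.run(tags=['baseline'])            # creates the project row on first use
    logger.log_metric('accuracy', 0.91, step=100)
    logger.log_artifact(run_id, 'checkpoints/model.pt', 'checkpoint')
    logger.finish()                                   # marks the run completed with an end_time

    finder = SQLFinder(config)
    best = finder.find_best_run('demo', metric_key='accuracy', mode='max')
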
ato/db_routers/sql/schema.py
ADDED

@@ -0,0 +1,83 @@

import datetime

from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, JSON, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()


# single project
class Project(Base):
    __tablename__ = 'projects'

    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False, index=True)
    description = Column(Text)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)

    runs = relationship('Experiment', back_populates='project')  # makes 1:N relationship


# single experiment
class Experiment(Base):
    __tablename__ = 'experiments'

    id = Column(Integer, primary_key=True)
    project_id = Column(Integer, ForeignKey('projects.id'), nullable=False, index=True)
    config = Column(JSON, nullable=False)

    structural_hash = Column(String(64), index=True)

    tags = Column(JSON, default=[])
    status = Column(String, default='running', index=True)  # current status of experiment
    start_time = Column(DateTime, default=datetime.datetime.utcnow)
    end_time = Column(DateTime)

    project = relationship('Project', back_populates='runs')
    metrics = relationship('Metric', back_populates='run')
    artifacts = relationship('Artifact', back_populates='run')
    fingerprints = relationship('Fingerprint', back_populates='run')


# single metric
class Metric(Base):
    __tablename__ = 'metrics'

    id = Column(Integer, primary_key=True)
    run_id = Column(Integer, ForeignKey('experiments.id'), nullable=False, index=True)

    key = Column(String, nullable=False, index=True)  # name of metric
    value = Column(Float, nullable=False)  # value of metric
    step = Column(Integer, nullable=False)  # logged step
    timestamp = Column(DateTime, default=datetime.datetime.now(datetime.timezone.utc))

    run = relationship('Experiment', back_populates='metrics')


# single artifact
class Artifact(Base):
    __tablename__ = 'artifacts'

    id = Column(Integer, primary_key=True)
    run_id = Column(Integer, ForeignKey('experiments.id'), nullable=False, index=True)

    path = Column(String, nullable=False)
    data_type = Column(String)
    data_info = Column(JSON)

    run = relationship('Experiment', back_populates='artifacts')


# single fingerprint
class Fingerprint(Base):
    __tablename__ = 'fingerprints'

    id = Column(Integer, primary_key=True)
    run_id = Column(Integer, ForeignKey('experiments.id'), nullable=False, index=True)
    trace_id = Column(String, nullable=False, index=True)

    trace_type = Column(String(32), nullable=False, index=True)  # 'static' or 'runtime'
    fingerprint = Column(String(128), nullable=False, index=True)  # hash value

    run = relationship('Experiment', back_populates='fingerprints')
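
The schema can be exercised on its own against an in-memory SQLite engine; the short sketch below uses only standard SQLAlchemy calls and the models defined above.

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    from ato.db_routers.sql.schema import Base, Project, Experiment

    engine = create_engine('sqlite:///:memory:')
    Base.metadata.create_all(engine)               # creates all five tables

    with sessionmaker(bind=engine)() as session:
        project = Project(name='demo')
        session.add(project)
        session.commit()

        run = Experiment(project_id=project.id, config={'lr': 1e-3}, tags=['baseline'])
        session.add(run)
        session.commit()

        assert run.project.name == 'demo'          # back_populates wires the 1:N relationship
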
ato/hyperopt/__init__.py
ADDED
File without changes

ato/hyperopt/base.py
ADDED
@@ -0,0 +1,144 @@

import math
import uuid
from copy import deepcopy as dcp
from itertools import product

import numpy as np
import torch.distributed as dist

from ato.adict import ADict


class HyperOpt:
    def __init__(self, scope, search_spaces, tracker=None, mode='max'):
        if mode not in ('min', 'max'):
            raise ValueError('mode must be either "min" or "max".')
        self.scope = scope
        self.search_spaces = search_spaces
        self.config = scope.config.clone()
        self.tracker = tracker
        self.mode = mode
        self.config.__hyperopt_id__ = self.get_hyperopt_id()

    @classmethod
    def get_hyperopt_id(cls):
        return str(uuid.uuid4())

    def main(self, func):
        raise NotImplementedError()


class DistributedMixIn:
    def __init__(self, rank=0, world_size=1, backend='pytorch'):
        self.rank = rank
        self.world_size = world_size
        self.backend = backend

    @property
    def is_root(self):
        return self.rank == 0

    def broadcast_object_from_root(self, obj):
        if self.backend == 'pytorch':
            obj = [obj]
            dist.broadcast_object_list(obj)
            obj = obj[0]
        else:
            raise ValueError(f'Unsupported backend: {self.backend}')
        return obj

    def all_gather_object(self, obj):
        if self.backend == 'pytorch':
            gathered_objects = [None for _ in range(self.world_size)]
            dist.all_gather_object(gathered_objects, obj)
        else:
            raise ValueError(f'Unsupported backend: {self.backend}')
        return gathered_objects

    def destroy(self):
        if self.backend == 'pytorch':
            if dist.is_initialized():
                dist.destroy_process_group()
        else:
            raise ValueError(f'Unsupported backend: {self.backend}')

    def get_hyperopt_id(self):
        return self.broadcast_object_from_root(str(uuid.uuid4()))


class DistributedHyperOpt(DistributedMixIn, HyperOpt):
    def __init__(
        self,
        scope,
        search_spaces,
        tracker=None,
        mode='max',
        rank=0,
        world_size=1,
        backend='pytorch'
    ):
        HyperOpt.__init__(self, scope, search_spaces, tracker, mode)
        DistributedMixIn.__init__(self, rank, world_size, backend)


class GridSpaceMixIn:
    @classmethod
    def prepare_distributions(cls, base_config, search_spaces):
        sampling_spaces = ADict()
        for param_name, search_space in search_spaces.items():
            if 'param_type' not in search_space:
                raise KeyError(f'param_type for parameter {param_name} is not defined at search_spaces.')
            param_type = search_space['param_type'].upper()
            if param_type == 'INTEGER':
                start, stop = search_space.param_range
                space_type = search_space.get('space_type', 'LINEAR')
                if space_type == 'LINEAR':
                    optim_space = np.linspace(
                        start=start,
                        stop=stop,
                        num=search_space.num_samples,
                        dtype=np.int64
                    ).tolist()
                elif space_type == 'LOG':
                    base = search_space.get('base', 2)
                    optim_space = np.logspace(
                        start=math.log(start, base),
                        stop=math.log(stop, base),
                        num=search_space.num_samples,
                        dtype=np.int64,
                        base=base
                    ).tolist()
                else:
                    raise ValueError(f'Invalid space_type: {space_type}')
            elif param_type == 'FLOAT':
                start, stop = search_space.param_range
                space_type = search_space.get('space_type', 'LINEAR')
                if space_type == 'LINEAR':
                    optim_space = np.linspace(
                        start=start,
                        stop=stop,
                        num=search_space.num_samples,
                        dtype=np.float32
                    ).tolist()
                elif space_type == 'LOG':
                    base = search_space.get('base', 10)
                    optim_space = np.logspace(
                        start=math.log(start, base),
                        stop=math.log(stop, base),
                        num=search_space.num_samples,
                        dtype=np.float32,
                        base=base
                    ).tolist()
                else:
                    raise ValueError(f'Invalid space_type: {space_type}')
            elif param_type == 'CATEGORY':
                optim_space = search_space.categories
            else:
                raise ValueError(f'Unknown param_type for parameter {param_name}; {param_type}')
            sampling_spaces[param_name] = optim_space
        grid_space = [ADict(zip(sampling_spaces.keys(), values)) for values in product(*sampling_spaces.values())]
        distributions = [
            dcp(base_config).update(**partial_config)
            for index, partial_config in enumerate(grid_space)
        ]
        return distributions
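
A sketch of what a search-space definition for GridSpaceMixIn.prepare_distributions might look like, based on the keys the method reads (param_type, param_range, num_samples, space_type, categories). It assumes ADict behaves like an attribute-accessible dict whose update() returns the updated instance, which the list comprehension above relies on; the parameter values are made up.

    from ato.adict import ADict
    from ato.hyperopt.base import GridSpaceMixIn

    base_config = ADict(lr=1e-3, batch_size=32, optimizer='sgd')
    search_spaces = ADict(
        lr=ADict(param_type='FLOAT', param_range=(1e-4, 1e-1), num_samples=4, space_type='LOG'),
        batch_size=ADict(param_type='INTEGER', param_range=(16, 128), num_samples=3),
        optimizer=ADict(param_type='CATEGORY', categories=['sgd', 'adam']),
    )

    # 4 lr values x 3 batch sizes x 2 optimizers -> 24 candidate configs,
    # each a deep copy of base_config with one grid point applied on top.
    distributions = GridSpaceMixIn.prepare_distributions(base_config, search_spaces)

Note that space_type is compared verbatim against 'LINEAR' and 'LOG', so it must be given in upper case.
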
ato/hyperopt/hyperband.py
ADDED

@@ -0,0 +1,103 @@

import math
from copy import deepcopy as dcp
from itertools import chain

from ato.adict import ADict
from ato.hyperopt.base import HyperOpt, DistributedMixIn, GridSpaceMixIn


class HyperBand(HyperOpt, GridSpaceMixIn):
    def __init__(self, scope, search_spaces, halving_rate, num_min_samples, tracker=None, mode='max'):
        if halving_rate <= 0 or halving_rate >= 1:
            raise ValueError(f'halving_rate must be greater than 0.0 but less than 1.0, but got {halving_rate}.')
        if num_min_samples < 1:
            raise ValueError(f'num_min_samples must be greater than or equal to 1, but got {num_min_samples}.')
        super().__init__(scope, search_spaces, tracker, mode)
        self.halving_rate = halving_rate
        self.num_min_samples = num_min_samples
        self.distributions = self.prepare_distributions(self.config, self.search_spaces)

    @classmethod
    def prepare_distributions(cls, base_config, search_spaces):
        distributions = super().prepare_distributions(base_config, search_spaces)
        for distribution in distributions:
            distribution.__num_halved__ = 0
        return distributions

    def main(self, func):
        def launch(*args, **kwargs):
            logs = []
            distributions = self.distributions
            while len(distributions) >= self.num_min_samples:
                results = self.estimate(func, distributions, *args, **kwargs)
                results.sort(key=lambda item: item.__metric__, reverse=self.mode == 'max')
                logs.append(results)
                distributions = []
                for config in results[:int(len(results)*self.halving_rate)]:
                    config.__num_halved__ += 1
                    distributions.append(config)
            last_config = logs[-1][0]
            metric = self.estimate_single_run(func, last_config, *args, **kwargs)
            best_config = dcp(last_config)
            best_config.__metric__ = metric
            logs.append([best_config])
            return ADict(config=best_config, metric=metric, logs=logs)
        return launch

    def estimate(self, estimator, distributions, *args, **kwargs):
        results = []
        for config in distributions:
            config = dcp(config)
            self.scope.config = config
            metric = self.estimate_single_run(estimator, config, *args, **kwargs)
            config.__metric__ = metric
            results.append(config)
        return results

    def estimate_single_run(self, estimator, config, *args, **kwargs):
        self.scope.config = config
        return self.scope(estimator)(*args, **kwargs)

    def num_generations(self):
        max_size = len(self.distributions)
        min_size = self.num_min_samples
        return math.ceil(math.log(min_size/max_size)/math.log(self.halving_rate))

    def compute_optimized_initial_training_steps(self, max_steps):
        num_generations = self.num_generations()
        min_steps = max_steps*math.pow(self.halving_rate, num_generations)
        return [
            *(math.ceil(min_steps/math.pow(self.halving_rate, index)) for index in range(1, num_generations)),
            max_steps
        ]


class DistributedHyperBand(DistributedMixIn, HyperBand, GridSpaceMixIn):
    def __init__(
        self,
        scope,
        search_spaces,
        halving_rate,
        num_min_samples,
        tracker=None,
        mode='max',
        rank=0,
        world_size=1,
        backend='pytorch'
    ):
        DistributedMixIn.__init__(self, rank, world_size, backend)
        HyperBand.__init__(self, scope, search_spaces, halving_rate, num_min_samples, tracker, mode)

    def estimate(self, estimator, distributions, *args, **kwargs):
        batch_size = math.ceil(len(distributions)/self.world_size)
        distributions = distributions[self.rank*batch_size:(self.rank+1)*batch_size]
        results = super().estimate(estimator, distributions, *args, **kwargs)
        results = list(chain(*self.all_gather_object(results)))
        return results

    def estimate_single_run(self, estimator, config, *args, **kwargs):
        result = super().estimate(estimator, config, *args, **kwargs)
        self.destroy()
        return result
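
A worked example of the budgeting arithmetic in num_generations() and compute_optimized_initial_training_steps(), using illustrative numbers (16 grid candidates, halving_rate 0.5, num_min_samples 3, max_steps 1000); the snippet just replays the formulas above outside the class.

    import math

    halving_rate, num_min_samples = 0.5, 3
    max_size = 16          # e.g. 16 grid candidates from prepare_distributions
    max_steps = 1000       # training budget for the final, fully-trained survivor

    num_generations = math.ceil(math.log(num_min_samples / max_size) / math.log(halving_rate))  # -> 3
    min_steps = max_steps * math.pow(halving_rate, num_generations)                             # -> 125.0
    budgets = [
        *(math.ceil(min_steps / math.pow(halving_rate, i)) for i in range(1, num_generations)),
        max_steps,
    ]                                                                                           # -> [250, 500, 1000]

With halving_rate 0.5 the successive entries double, so each surviving generation trains with twice the budget of the previous one and only the last survivor receives the full max_steps.
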
ato/parser.py
ADDED

@@ -0,0 +1,103 @@

def parse_command(command):
    tokens = []
    i = 0
    length = len(command)
    while i < length:
        while i < length and command[i].isspace():
            i += 1
        if i >= length:
            break
        start = i
        while i < length and not command[i].isspace() and command[i] != '=':
            i += 1
        if i < length and command[i] == '=':
            key = command[start:i]
            i += 1
            if i < length:
                value, i = parse_value(command, i)
            else:
                value = ''
            tokens.append(f'{key}={value}')
        else:
            token_start = start
            while i < length and not command[i].isspace():
                i += 1
            tokens.append(command[token_start:i])
    return tokens


def parse_value(command, i):
    if i < len(command):
        if command[i] == '%':
            return parse_backtick_string(command, i)
        elif command[i] in ['[', '(', '{']:
            return parse_bracketed_value(command, i)
        else:
            start = i
            while i < len(command) and not command[i].isspace():
                i += 1
            return command[start:i], i
    return '', i


def parse_backtick_string(command, i):
    assert command[i] == '%'
    i += 1
    value = ['%']
    length = len(command)
    nesting_level = 1
    while i < length:
        c = command[i]
        if c == '\\' and i+1 < length:
            value.append(c)
            value.append(command[i+1])
            i += 2
        elif c == '%':
            value.append(c)
            i += 1
            if i < length and command[i] == '%':
                value.append(command[i])
                i += 1
            else:
                nesting_level -= 1
                if nesting_level == 0:
                    break
        elif c == '%':
            value.append(c)
            nesting_level += 1
            i += 1
        else:
            value.append(c)
            i += 1
    return ''.join(value), i


def parse_bracketed_value(command, i):
    brackets = {'[': ']', '{': '}', '(': ')'}
    opening_bracket = command[i]
    closing_bracket = brackets[opening_bracket]
    value = [opening_bracket]
    i += 1
    length = len(command)
    stack = [closing_bracket]
    while i < length and stack:
        c = command[i]
        if c == '%':
            backtick_value, i = parse_backtick_string(command, i)
            value.append(backtick_value)
        elif c == '\\' and i+1 < length:
            value.append(c)
            value.append(command[i+1])
            i += 2
        elif c in brackets:
            stack.append(brackets[c])
            value.append(c)
            i += 1
        elif c == stack[-1]:
            stack.pop()
            value.append(c)
            i += 1
        else:
            value.append(c)
            i += 1
    return ''.join(value), i
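
An example of the tokenizer's behaviour: tokens split on whitespace, but a value wrapped in %...% or in brackets keeps its internal spaces and stays attached to its key.

    from ato.parser import parse_command

    tokens = parse_command('train lr=0.1 name=%my run% tags=[a, b]')
    assert tokens == ['train', 'lr=0.1', 'name=%my run%', 'tags=[a, b]']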