ato-2.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,188 @@
+ import datetime
+ from typing import Optional
+
+ from sqlalchemy import create_engine, func
+ from sqlalchemy.orm import sessionmaker, Session as SessionType
+
+ from ato.db_routers import BaseLogger, BaseFinder
+ from ato.db_routers.sql.schema import Base, Project, Experiment, Metric, Artifact, Fingerprint
+
+
+ class SQLLogger(BaseLogger):
+     registry: set[str] = set()
+
+     def __init__(self, config):
+         super().__init__(config)
+         db_path = self.config.experiment.sql.db_path
+         self.engine = create_engine(db_path)
+         registry = self.__class__.registry
+         if db_path not in registry:
+             Base.metadata.create_all(self.engine)  # Base.metadata is SQLAlchemy's table registry
+             registry.add(db_path)
+         self.session = sessionmaker(bind=self.engine)()
+         self._current_run_id = None
+
+     def get_current_run(self):
+         return self.session.get(Experiment, self.get_current_run_id())
+
+     def get_current_run_id(self):
+         return self._current_run_id
+
+     def get_or_create_project(self):
+         project = self.session.query(Project).filter_by(name=self.config.experiment.project_name).first()
+         if not project:
+             project = Project(name=self.config.experiment.project_name)
+             self.session.add(project)
+             self.session.commit()
+         return project
+
+     def update_status(self, status):
+         run = self.get_current_run()
+         if run:
+             run.status = status
+             self.session.commit()
+
+     def run(self, tags=None):
+         project = self.get_or_create_project()
+         structural_hash = self.config.get_structural_hash()
+         run = Experiment(
+             project_id=project.id,
+             config=self.config.to_dict(),  # recursively convert ADict to dict for JSON serialization
+             structural_hash=structural_hash,
+             tags=tags or []
+         )
+         self.session.add(run)
+         self.session.commit()
+         self._current_run_id = run.id
+         return run.id
+
+     def log_metric(self, key, value, step):
+         metric = Metric(
+             run_id=self.get_current_run_id(),
+             key=key,
+             value=value,
+             step=step
+         )
+         self.session.add(metric)
+         self.session.commit()
+
+     def log_artifact(self, run_id, file_path, data_type, metadata=None):
+         artifact = Artifact(
+             run_id=run_id,
+             path=file_path,
+             data_type=data_type,
+             data_info=metadata  # the column is named data_info in the schema
+         )
+         self.session.add(artifact)
+         self.session.commit()
+
+     def finish(self, status='completed'):
+         run = self.get_current_run()
+         if run:
+             run.status = status
+             run.end_time = datetime.datetime.now(datetime.timezone.utc)
+             self.session.commit()
+
+
+ class SQLFinder(BaseFinder):
+     def __init__(self, config):
+         super().__init__(config)
+         self.engine = create_engine(self.config.experiment.sql.db_path)
+         self.session_maker = sessionmaker(bind=self.engine)
+
+     def _get_session(self) -> SessionType:
+         return self.session_maker()
+
+     def find_project(self, project_name) -> Optional[Project]:
+         with self._get_session() as session:
+             return session.query(Project).filter_by(name=project_name).first()
+
+     def find_run(self, run_id: int) -> Optional[Experiment]:
+         with self._get_session() as session:
+             return session.get(Experiment, run_id)
+
+     def get_runs_in_project(self, project_name) -> list[Experiment]:
+         project = self.find_project(project_name)
+         if not project:
+             return []
+         else:
+             with self._get_session() as session:
+                 return session.query(Experiment).filter_by(project_id=project.id).all()
+
+     def find_similar_runs(self, run_id: int) -> list[Experiment]:
+         with self._get_session() as session:
+             base_run = session.query(Experiment.structural_hash).filter_by(id=run_id).first()
+             if not base_run:
+                 return []
+             else:
+                 target_hash = base_run.structural_hash
+                 return session.query(Experiment).filter(
+                     Experiment.structural_hash == target_hash,
+                     Experiment.id != run_id
+                 ).all()
+
+     def find_similar_runs_by_trace(
+         self,
+         run_id: int,
+         trace_id: str,
+         trace_type: str = 'static'
+     ) -> list[Experiment]:
+         with self._get_session() as session:
+             target_fingerprint = session.query(Fingerprint.fingerprint).filter(
+                 Fingerprint.run_id == run_id,
+                 Fingerprint.trace_id == trace_id,
+                 Fingerprint.trace_type == trace_type
+             ).scalar()
+             if not target_fingerprint:
+                 return []
+             query = session.query(Fingerprint.run_id).filter(
+                 Fingerprint.trace_id == trace_id,
+                 Fingerprint.trace_type == trace_type,
+                 Fingerprint.fingerprint == target_fingerprint,
+                 Fingerprint.run_id != run_id
+             ).distinct()
+             return session.query(Experiment).filter(Experiment.id.in_(query)).all()
+
+     def find_best_run(self, project_name: str, metric_key: str, mode: str = 'max') -> Experiment | dict:
+         project = self.find_project(project_name)
+         if project:
+             with self._get_session() as session:
+                 order_by_col = Metric.value.desc() if mode == 'max' else Metric.value.asc()
+                 best_run = session.query(Experiment, Metric.value).join(
+                     Metric,
+                     Experiment.id == Metric.run_id
+                 ).filter(
+                     Experiment.project_id == project.id,
+                     Metric.key == metric_key
+                 ).order_by(
+                     order_by_col
+                 ).first()
+                 if best_run:
+                     return best_run[0]
+                 else:
+                     return {'error': 'Run not found'}
+         return {'error': 'Project not found'}
+
+     def get_trace_statistics(self, project_name: str, trace_id: str) -> dict:
+         project = self.find_project(project_name)
+         if not project:
+             return {'error': 'Project not found'}
+         else:
+             with self._get_session() as session:
+                 run_ids_in_project = session.query(Experiment.id).filter_by(project_id=project.id)
+                 static_count = session.query(func.count(Fingerprint.fingerprint.distinct())).filter(
+                     Fingerprint.run_id.in_(run_ids_in_project),
+                     Fingerprint.trace_id == trace_id,
+                     Fingerprint.trace_type == 'static'
+                 ).scalar()
+                 runtime_count = session.query(func.count(Fingerprint.fingerprint.distinct())).filter(
+                     Fingerprint.run_id.in_(run_ids_in_project),
+                     Fingerprint.trace_id == trace_id,
+                     Fingerprint.trace_type == 'runtime'  # the trace type written by @runtime_trace
+                 ).scalar()
+                 return {
+                     'project_name': project_name,
+                     'trace_id': trace_id,
+                     'static_trace_versions': static_count,
+                     'runtime_trace_versions': runtime_count
+                 }
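For orientation, here is a minimal usage sketch of the two classes above. It is not part of the package: the module path `ato.db_routers.sql.main` is an assumption (the file header for this hunk is missing), and the config is built as an `ato.adict.ADict`, which this sketch assumes exposes the `get_structural_hash()` and `to_dict()` methods that `SQLLogger.run()` calls.

from ato.adict import ADict
from ato.db_routers.sql.main import SQLLogger, SQLFinder  # module path assumed

config = ADict(
    experiment=ADict(
        project_name='demo',
        sql=ADict(db_path='sqlite:///experiments.db'),  # any SQLAlchemy URL
    )
)

logger = SQLLogger(config)
run_id = logger.run(tags=['baseline'])        # creates the Project row on first use
logger.log_metric('val/acc', 0.91, step=100)  # one Metric row per call
logger.finish()                               # sets status='completed' and end_time

finder = SQLFinder(config)
best = finder.find_best_run('demo', 'val/acc', mode='max')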
@@ -0,0 +1,83 @@
+ import datetime
+
+ from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, JSON, Text
+ from sqlalchemy.orm import declarative_base, relationship
+
+ Base = declarative_base()
+
+
+ # a single project
+ class Project(Base):
+     __tablename__ = 'projects'
+
+     id = Column(Integer, primary_key=True)
+     name = Column(String, unique=True, nullable=False, index=True)
+     description = Column(Text)
+     created_at = Column(DateTime, default=datetime.datetime.utcnow)
+
+     runs = relationship('Experiment', back_populates='project')  # 1:N relationship to experiments
+
+
+ # a single experiment run
+ class Experiment(Base):
+     __tablename__ = 'experiments'
+
+     id = Column(Integer, primary_key=True)
+     project_id = Column(Integer, ForeignKey('projects.id'), nullable=False, index=True)
+     config = Column(JSON, nullable=False)
+
+     structural_hash = Column(String(64), index=True)
+
+     tags = Column(JSON, default=list)  # callable default; avoids one shared mutable list
+     status = Column(String, default='running', index=True)  # current status of the experiment
+     start_time = Column(DateTime, default=datetime.datetime.utcnow)
+     end_time = Column(DateTime)
+
+     project = relationship('Project', back_populates='runs')
+     metrics = relationship('Metric', back_populates='run')
+     artifacts = relationship('Artifact', back_populates='run')
+     fingerprints = relationship('Fingerprint', back_populates='run')
+
+
+ # a single metric value
+ class Metric(Base):
+     __tablename__ = 'metrics'
+
+     id = Column(Integer, primary_key=True)
+     run_id = Column(Integer, ForeignKey('experiments.id'), nullable=False, index=True)
+
+     key = Column(String, nullable=False, index=True)  # metric name
+     value = Column(Float, nullable=False)  # metric value
+     step = Column(Integer, nullable=False)  # step at which the value was logged
+     # lambda so the timestamp is evaluated per row, not once at import time
+     timestamp = Column(DateTime, default=lambda: datetime.datetime.now(datetime.timezone.utc))
+
+     run = relationship('Experiment', back_populates='metrics')
+
+
+ # a single artifact
+ class Artifact(Base):
+     __tablename__ = 'artifacts'
+
+     id = Column(Integer, primary_key=True)
+     run_id = Column(Integer, ForeignKey('experiments.id'), nullable=False, index=True)
+
+     path = Column(String, nullable=False)
+     data_type = Column(String)
+     data_info = Column(JSON)
+
+     run = relationship('Experiment', back_populates='artifacts')
+
+
+ # a single fingerprint
+ class Fingerprint(Base):
+     __tablename__ = 'fingerprints'
+
+     id = Column(Integer, primary_key=True)
+     run_id = Column(Integer, ForeignKey('experiments.id'), nullable=False, index=True)
+     trace_id = Column(String, nullable=False, index=True)
+
+     trace_type = Column(String(32), nullable=False, index=True)  # 'static' or 'runtime'
+     fingerprint = Column(String(128), nullable=False, index=True)  # hash value
+
+     run = relationship('Experiment', back_populates='fingerprints')
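A short sketch of how the schema hangs together, using an in-memory SQLite database; everything here follows directly from the models above, with illustrative values.

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from ato.db_routers.sql.schema import Base, Project, Experiment

engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)  # emits CREATE TABLE for all five models
session = sessionmaker(bind=engine)()

project = Project(name='demo')
session.add(project)
session.commit()

run = Experiment(project_id=project.id, config={'lr': 1e-3})
session.add(run)
session.commit()

print(project.runs)  # -> [run], via the back_populates 1:N relationship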
File without changes
ato/hyperopt/base.py ADDED
@@ -0,0 +1,144 @@
+ import math
+ import uuid
+ from copy import deepcopy as dcp
+ from itertools import product
+
+ import numpy as np
+ import torch.distributed as dist
+
+ from ato.adict import ADict
+
+
+ class HyperOpt:
+     def __init__(self, scope, search_spaces, tracker=None, mode='max'):
+         if mode not in ('min', 'max'):
+             raise ValueError('mode must be either "min" or "max".')
+         self.scope = scope
+         self.search_spaces = search_spaces
+         self.config = scope.config.clone()
+         self.tracker = tracker
+         self.mode = mode
+         self.config.__hyperopt_id__ = self.get_hyperopt_id()
+
+     @classmethod
+     def get_hyperopt_id(cls):
+         return str(uuid.uuid4())
+
+     def main(self, func):
+         raise NotImplementedError()
+
+
+ class DistributedMixIn:
+     def __init__(self, rank=0, world_size=1, backend='pytorch'):
+         self.rank = rank
+         self.world_size = world_size
+         self.backend = backend
+
+     @property
+     def is_root(self):
+         return self.rank == 0
+
+     def broadcast_object_from_root(self, obj):
+         if self.backend == 'pytorch':
+             obj = [obj]
+             dist.broadcast_object_list(obj)
+             obj = obj[0]
+         else:
+             raise ValueError(f'Unsupported backend: {self.backend}')
+         return obj
+
+     def all_gather_object(self, obj):
+         if self.backend == 'pytorch':
+             gathered_objects = [None for _ in range(self.world_size)]
+             dist.all_gather_object(gathered_objects, obj)
+         else:
+             raise ValueError(f'Unsupported backend: {self.backend}')
+         return gathered_objects
+
+     def destroy(self):
+         if self.backend == 'pytorch':
+             if dist.is_initialized():
+                 dist.destroy_process_group()
+         else:
+             raise ValueError(f'Unsupported backend: {self.backend}')
+
+     def get_hyperopt_id(self):
+         return self.broadcast_object_from_root(str(uuid.uuid4()))
+
+
+ class DistributedHyperOpt(DistributedMixIn, HyperOpt):
+     def __init__(
+         self,
+         scope,
+         search_spaces,
+         tracker=None,
+         mode='max',
+         rank=0,
+         world_size=1,
+         backend='pytorch'
+     ):
+         # Initialize the distributed state first: HyperOpt.__init__ calls
+         # get_hyperopt_id(), which resolves to DistributedMixIn's version
+         # and needs self.backend to already be set.
+         DistributedMixIn.__init__(self, rank, world_size, backend)
+         HyperOpt.__init__(self, scope, search_spaces, tracker, mode)
+
+
+ class GridSpaceMixIn:
+     @classmethod
+     def prepare_distributions(cls, base_config, search_spaces):
+         sampling_spaces = ADict()
+         for param_name, search_space in search_spaces.items():
+             if 'param_type' not in search_space:
+                 raise KeyError(f'param_type for parameter {param_name} is not defined at search_spaces.')
+             param_type = search_space['param_type'].upper()
+             if param_type == 'INTEGER':
+                 start, stop = search_space.param_range
+                 space_type = search_space.get('space_type', 'LINEAR')
+                 if space_type == 'LINEAR':
+                     optim_space = np.linspace(
+                         start=start,
+                         stop=stop,
+                         num=search_space.num_samples,
+                         dtype=np.int64
+                     ).tolist()
+                 elif space_type == 'LOG':
+                     base = search_space.get('base', 2)
+                     optim_space = np.logspace(
+                         start=math.log(start, base),
+                         stop=math.log(stop, base),
+                         num=search_space.num_samples,
+                         dtype=np.int64,
+                         base=base
+                     ).tolist()
+                 else:
+                     raise ValueError(f'Invalid space_type: {space_type}')
+             elif param_type == 'FLOAT':
+                 start, stop = search_space.param_range
+                 space_type = search_space.get('space_type', 'LINEAR')
+                 if space_type == 'LINEAR':
+                     optim_space = np.linspace(
+                         start=start,
+                         stop=stop,
+                         num=search_space.num_samples,
+                         dtype=np.float32
+                     ).tolist()
+                 elif space_type == 'LOG':
+                     base = search_space.get('base', 10)
+                     optim_space = np.logspace(
+                         start=math.log(start, base),
+                         stop=math.log(stop, base),
+                         num=search_space.num_samples,
+                         dtype=np.float32,
+                         base=base
+                     ).tolist()
+                 else:
+                     raise ValueError(f'Invalid space_type: {space_type}')
+             elif param_type == 'CATEGORY':
+                 optim_space = search_space.categories
+             else:
+                 raise ValueError(f'Unknown param_type for parameter {param_name}; {param_type}')
+             sampling_spaces[param_name] = optim_space
+         grid_space = [ADict(zip(sampling_spaces.keys(), values)) for values in product(*sampling_spaces.values())]
+         distributions = [
+             dcp(base_config).update(**partial_config)
+             for partial_config in grid_space
+         ]
+         return distributions
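To make the expected `search_spaces` shape concrete, here is an illustrative definition (parameter names and values are invented; the field names follow `prepare_distributions` above). The method expands it into the full Cartesian product, one merged config per grid point.

from ato.adict import ADict

search_spaces = ADict(
    lr=ADict(param_type='FLOAT', param_range=(1e-4, 1e-1),
             space_type='LOG', num_samples=4, base=10),
    hidden_dim=ADict(param_type='INTEGER', param_range=(64, 512),
                     space_type='LINEAR', num_samples=3),
    activation=ADict(param_type='CATEGORY', categories=['relu', 'gelu']),
)
# GridSpaceMixIn.prepare_distributions(base_config, search_spaces) yields
# 4 * 3 * 2 = 24 configs, each a deep copy of base_config updated with one
# combination; the LOG float space here samples [1e-4, 1e-3, 1e-2, 1e-1].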
@@ -0,0 +1,103 @@
+ import math
+ from copy import deepcopy as dcp
+ from itertools import chain
+
+ from ato.adict import ADict
+ from ato.hyperopt.base import HyperOpt, DistributedMixIn, GridSpaceMixIn
+
+
+ class HyperBand(HyperOpt, GridSpaceMixIn):
+     def __init__(self, scope, search_spaces, halving_rate, num_min_samples, tracker=None, mode='max'):
+         if halving_rate <= 0 or halving_rate >= 1:
+             raise ValueError(f'halving_rate must be greater than 0.0 but less than 1.0, but got {halving_rate}.')
+         if num_min_samples < 1:
+             raise ValueError(f'num_min_samples must be greater than or equal to 1, but got {num_min_samples}.')
+         super().__init__(scope, search_spaces, tracker, mode)
+         self.halving_rate = halving_rate
+         self.num_min_samples = num_min_samples
+         self.distributions = self.prepare_distributions(self.config, self.search_spaces)
+
+     @classmethod
+     def prepare_distributions(cls, base_config, search_spaces):
+         distributions = super().prepare_distributions(base_config, search_spaces)
+         for distribution in distributions:
+             distribution.__num_halved__ = 0
+         return distributions
+
+     def main(self, func):
+         def launch(*args, **kwargs):
+             logs = []
+             distributions = self.distributions
+             while len(distributions) >= self.num_min_samples:
+                 results = self.estimate(func, distributions, *args, **kwargs)
+                 results.sort(key=lambda item: item.__metric__, reverse=self.mode == 'max')
+                 logs.append(results)
+                 distributions = []
+                 for config in results[:int(len(results)*self.halving_rate)]:
+                     config.__num_halved__ += 1
+                     distributions.append(config)
+             last_config = logs[-1][0]
+             metric = self.estimate_single_run(func, last_config, *args, **kwargs)
+             best_config = dcp(last_config)
+             best_config.__metric__ = metric
+             logs.append([best_config])
+             return ADict(config=best_config, metric=metric, logs=logs)
+         return launch
+
+     def estimate(self, estimator, distributions, *args, **kwargs):
+         results = []
+         for config in distributions:
+             config = dcp(config)
+             self.scope.config = config
+             metric = self.estimate_single_run(estimator, config, *args, **kwargs)
+             config.__metric__ = metric
+             results.append(config)
+         return results
+
+     def estimate_single_run(self, estimator, config, *args, **kwargs):
+         self.scope.config = config
+         return self.scope(estimator)(*args, **kwargs)
+
+     def num_generations(self):
+         max_size = len(self.distributions)
+         min_size = self.num_min_samples
+         return math.ceil(math.log(min_size/max_size)/math.log(self.halving_rate))
+
+     def compute_optimized_initial_training_steps(self, max_steps):
+         num_generations = self.num_generations()
+         min_steps = max_steps*math.pow(self.halving_rate, num_generations)
+         return [
+             *(math.ceil(min_steps/math.pow(self.halving_rate, index)) for index in range(1, num_generations)),
+             max_steps
+         ]
+
+
+ class DistributedHyperBand(DistributedMixIn, HyperBand, GridSpaceMixIn):
+     def __init__(
+         self,
+         scope,
+         search_spaces,
+         halving_rate,
+         num_min_samples,
+         tracker=None,
+         mode='max',
+         rank=0,
+         world_size=1,
+         backend='pytorch'
+     ):
+         DistributedMixIn.__init__(self, rank, world_size, backend)
+         HyperBand.__init__(self, scope, search_spaces, halving_rate, num_min_samples, tracker, mode)
+
+     def estimate(self, estimator, distributions, *args, **kwargs):
+         batch_size = math.ceil(len(distributions)/self.world_size)
+         distributions = distributions[self.rank*batch_size:(self.rank+1)*batch_size]
+         results = super().estimate(estimator, distributions, *args, **kwargs)
+         results = list(chain(*self.all_gather_object(results)))
+         return results
+
+     def estimate_single_run(self, estimator, config, *args, **kwargs):
+         # run one config on this rank, then tear down the process group
+         result = super().estimate_single_run(estimator, config, *args, **kwargs)
+         self.destroy()
+         return result
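The halving arithmetic above is easy to verify by hand. A standalone check, assuming 24 starting grid points, halving_rate=0.5, and num_min_samples=3 (values carried over from the earlier search-space sketch):

import math

# num_generations(): ceil(log(3/24) / log(0.5)) = ceil(3.0) = 3
print(math.ceil(math.log(3 / 24) / math.log(0.5)))  # -> 3

# compute_optimized_initial_training_steps(10_000) with those settings:
# min_steps = 10_000 * 0.5**3 = 1250.0, so the returned schedule is
print([math.ceil(1250 / 0.5 ** k) for k in range(1, 3)] + [10_000])
# -> [2500, 5000, 10000]

Each generation trains the surviving configs on a growing step budget and keeps the top halving_rate fraction, so the schedule doubles per generation until it reaches the full budget.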
ato/parser.py ADDED
@@ -0,0 +1,103 @@
+ def parse_command(command):
+     tokens = []
+     i = 0
+     length = len(command)
+     while i < length:
+         while i < length and command[i].isspace():
+             i += 1
+         if i >= length:
+             break
+         start = i
+         while i < length and not command[i].isspace() and command[i] != '=':
+             i += 1
+         if i < length and command[i] == '=':
+             key = command[start:i]
+             i += 1
+             if i < length:
+                 value, i = parse_value(command, i)
+             else:
+                 value = ''
+             tokens.append(f'{key}={value}')
+         else:
+             token_start = start
+             while i < length and not command[i].isspace():
+                 i += 1
+             tokens.append(command[token_start:i])
+     return tokens
+
+
+ def parse_value(command, i):
+     if i < len(command):
+         if command[i] == '%':
+             return parse_backtick_string(command, i)
+         elif command[i] in ['[', '(', '{']:
+             return parse_bracketed_value(command, i)
+         else:
+             start = i
+             while i < len(command) and not command[i].isspace():
+                 i += 1
+             return command[start:i], i
+     return '', i
+
+
+ def parse_backtick_string(command, i):
+     # Despite the name, the delimiter is '%'; '%%' inside the string
+     # yields a literal '%' rather than closing it.
+     assert command[i] == '%'
+     i += 1
+     value = ['%']
+     length = len(command)
+     while i < length:
+         c = command[i]
+         if c == '\\' and i+1 < length:
+             value.append(c)
+             value.append(command[i+1])
+             i += 2
+         elif c == '%':
+             value.append(c)
+             i += 1
+             if i < length and command[i] == '%':
+                 # escaped delimiter ('%%'): keep scanning
+                 value.append(command[i])
+                 i += 1
+             else:
+                 # closing delimiter
+                 break
+         else:
+             value.append(c)
+             i += 1
+     return ''.join(value), i
+
+
+ def parse_bracketed_value(command, i):
+     brackets = {'[': ']', '{': '}', '(': ')'}
+     opening_bracket = command[i]
+     closing_bracket = brackets[opening_bracket]
+     value = [opening_bracket]
+     i += 1
+     length = len(command)
+     stack = [closing_bracket]
+     while i < length and stack:
+         c = command[i]
+         if c == '%':
+             backtick_value, i = parse_backtick_string(command, i)
+             value.append(backtick_value)
+         elif c == '\\' and i+1 < length:
+             value.append(c)
+             value.append(command[i+1])
+             i += 2
+         elif c in brackets:
+             stack.append(brackets[c])
+             value.append(c)
+             i += 1
+         elif c == stack[-1]:
+             stack.pop()
+             value.append(c)
+             i += 1
+         else:
+             value.append(c)
+             i += 1
+     return ''.join(value), i
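A quick demonstration of the tokenizer (inputs are illustrative): key=value pairs stay single tokens even when the value contains spaces inside '%'-delimited strings or brackets, and '%%' escapes a literal percent sign.

print(parse_command('train model.lr=0.1 tags=[a, b] name=%an example%'))
# -> ['train', 'model.lr=0.1', 'tags=[a, b]', 'name=%an example%']

print(parse_command('msg=%100%% done%'))
# -> ['msg=%100%% done%']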