aiauto-client 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aiauto/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ from .core import AIAutoController, TrialController, CallbackTopNArtifact, StudyWrapper
2
+ from .serialization import SourceCodeSerializer, create_study_with_source_serialization
3
+
4
+ __version__ = "0.1.0"
5
+
6
+ __all__ = [
7
+ 'AIAutoController',
8
+ 'TrialController',
9
+ 'CallbackTopNArtifact',
10
+ 'StudyWrapper',
11
+ 'SourceCodeSerializer',
12
+ 'create_study_with_source_serialization',
13
+ ]
14
+
15
+ # Optuna 호환성을 위한 간편 함수
16
def create_study(
    objective=None,
    study_name='aiauto_study',
    direction='minimize',
    **kwargs
):
    """
    Optuna-compatible ``create_study`` convenience function.

    With an ``objective`` the call is forwarded to
    ``AIAutoController.create_study_with_serialization`` and its result is
    returned. Without one, a plain Optuna study bound to the controller's
    storage is created and returned (the legacy path).

    Args:
        objective: Objective function to serialize, or None for a plain study.
        study_name: Name of the study.
        direction: Optimization direction passed through unchanged.
        **kwargs: Extra keyword arguments forwarded to whichever factory runs.

    Usage:
        study = aiauto.create_study(
            objective=my_objective,
            study_name='my_study',
            direction='maximize'
        )
        study.optimize(n_trials=100)
    """
    controller = AIAutoController()

    if objective is None:
        # No objective given: hand back a regular optuna study (legacy usage).
        import optuna
        return optuna.create_study(
            study_name=study_name,
            direction=direction,
            storage=controller.get_storage(),
            **kwargs
        )

    return controller.create_study_with_serialization(
        objective=objective,
        study_name=study_name,
        direction=direction,
        **kwargs
    )
aiauto/core.py ADDED
@@ -0,0 +1,257 @@
1
+ from os import makedirs, environ
2
+ import tempfile
3
+ from typing import Union, Callable, Dict, Any, Optional
4
+ import optuna
5
+
6
+ from .serialization import create_study_with_source_serialization, SourceCodeSerializer
7
+
8
+
9
class AIAutoController:
    """
    Singleton that owns the Optuna storage, the artifact store and a temp dir.

    ``__new__`` always returns the same instance, and ``__init__`` performs its
    set-up only on the first construction (guarded by the class-level ``_init``
    flag), so repeated ``AIAutoController()`` calls share one configuration.
    """

    def __new__(cls, *args, **kwargs):
        # singleton pattern: hand back the instance created by the first call
        if hasattr(cls, "_instance"):
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        klass = type(self)
        if hasattr(klass, "_init"):
            # singleton pattern: set-up already ran on an earlier construction
            return

        # TODO: token authentication (the value is read but not used yet)
        token = environ.get('AIAUTO_TOKEN')

        # Storage backend is chosen by the AIAUTO_MODE environment variable.
        if environ.get('AIAUTO_MODE', 'single_gpu') == "distributed":
            # RDBStorage is used to support DDP/FSDP pruning callbacks.
            self.storage = optuna.storages.RDBStorage(
                url="sqlite:///optuna.db",
                engine_kwargs={"connect_args": {"timeout": 10}}
            )
        else:
            # Default: GrpcStorageProxy (single GPU and similar set-ups).
            self.storage = optuna.storages.GrpcStorageProxy(host="localhost", port=13000)

        # Artifact storage.
        # TODO: switch to S3 or another mounted path later.
        artifact_dir = './artifacts'
        makedirs(artifact_dir, exist_ok=True)
        self.artifact_store = optuna.artifacts.FileSystemArtifactStore(artifact_dir)

        # Temporary directory used for saving models.
        self.tmp_dir = tempfile.mkdtemp(prefix='ai_auto_tmp_')

        # singleton pattern: mark set-up as done for every later construction
        klass._init = True

    def get_storage(self):
        """Return the Optuna storage selected during initialisation."""
        return self.storage

    def get_artifact_store(self) -> Union[
        optuna.artifacts.FileSystemArtifactStore,
        optuna.artifacts.Boto3ArtifactStore,
        optuna.artifacts.GCSArtifactStore,
    ]:
        """Return the artifact store created during initialisation."""
        return self.artifact_store

    def get_artifact_tmp_dir(self):
        """Return the path of the temporary directory created for model files."""
        return self.tmp_dir

    def create_study_with_serialization(
        self,
        objective: Callable,
        study_name: str,
        direction: str = 'minimize',
        sampler: Optional[optuna.samplers.BaseSampler] = None,
        pruner: Optional[optuna.pruners.BasePruner] = None,
        **optuna_kwargs
    ) -> 'StudyWrapper':
        """
        Create a study whose objective is shipped as serialized source code.

        Args:
            objective: Objective function used for HPO.
            study_name: Name of the study.
            direction: Optimization direction ('minimize' or 'maximize').
            sampler: Optuna sampler; only its class name is recorded
                ('TPESampler' when omitted).
            pruner: Optuna pruner; only its class name is recorded
                (None when omitted).
            **optuna_kwargs: Extra arguments stored alongside the study config.

        Returns:
            A StudyWrapper (Optuna Study compatible) holding the serialized
            objective, the processed config, and this controller's storage
            and artifact store.
        """
        sampler_name = sampler.__class__.__name__ if sampler else 'TPESampler'
        pruner_name = pruner.__class__.__name__ if pruner else None

        # Serialize the objective's source code together with the config.
        serialized_objective, processed_config = create_study_with_source_serialization(
            objective,
            {
                'study_name': study_name,
                'direction': direction,
                'sampler': sampler_name,
                'pruner': pruner_name,
            },
            **optuna_kwargs
        )

        # Build the wrapper; the actual run happens when optimize() is called.
        return StudyWrapper(
            serialized_objective=serialized_objective,
            study_config=processed_config,
            storage=self.storage,
            artifact_store=self.artifact_store
        )
104
+
105
+
106
class TrialController:
    """Wraps an Optuna trial and mirrors log messages into its user attributes."""

    def __init__(self, trial: optuna.trial.Trial):
        self.trial = trial
        self.logger = optuna.logging.get_logger("optuna")
        # Every message passed to log(), in call order.
        self.logs = []

    def get_trial(self) -> optuna.trial.Trial:
        """Return the wrapped trial."""
        return self.trial

    def log(self, value: str):
        """
        Record ``value`` on the trial and emit it through the optuna logger.

        The optuna dashboard has no log viewer, so the full numbered history is
        rewritten into the 'logs' user attribute on every call.
        """
        self.logs.append(value)

        numbered = []
        for position, entry in enumerate(self.logs, start=1):
            numbered.append(f"[{position:05d}] {entry}")
        self.trial.set_user_attr('logs', ' '.join(numbered))

        # Emit the message together with the trial number.
        self.logger.info(f'\ntrial_number: {self.trial.number}, {value}')
121
+
122
+
123
+ # 용량 제한으로 상위 N개의 trial artifact 만 유지
124
class CallbackTopNArtifact:
    """
    Study callback that keeps artifacts only for the best ``n_keep`` trials.

    After each trial, COMPLETE trials carrying the artifact user attribute are
    ranked by value (descending for maximize, ascending for minimize) and the
    artifacts of every trial past the first ``n_keep`` are removed.

    Args:
        artifact_store: Store the artifacts are removed from.
        artifact_attr_name: Trial user-attribute key that holds the artifact id.
        n_keep: Number of best trials whose artifacts are kept.
    """

    def __init__(
        self,
        artifact_store: Union[
            optuna.artifacts.FileSystemArtifactStore,
            optuna.artifacts.Boto3ArtifactStore,
            optuna.artifacts.GCSArtifactStore,
        ],
        artifact_attr_name: str = 'artifact_id',
        n_keep: int = 5,
    ):
        self.artifact_store = artifact_store
        self.check_attr_name = artifact_attr_name
        self.n_keep = n_keep
        # Artifact ids this instance has already deleted. Used to avoid
        # re-deleting when the user attribute could not be cleared afterwards.
        # This memory is per instance only; it is not shared between processes.
        self._removed_artifact_ids = set()

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
        # Only COMPLETE trials that carry the artifact attribute are ranked.
        finished_with_artifacts = [
            t for t in study.trials
            if t.state == optuna.trial.TrialState.COMPLETE and self.check_attr_name in t.user_attrs
        ]

        # Best first: descending for maximize, ascending for minimize.
        reverse_sort = study.direction == optuna.study.StudyDirection.MAXIMIZE
        finished_with_artifacts.sort(key=lambda t: t.value, reverse=reverse_sort)

        # Drop the artifacts of every trial ranked after the first n_keep.
        for old_trial in finished_with_artifacts[self.n_keep:]:
            artifact_id = old_trial.user_attrs.get(self.check_attr_name)
            if not artifact_id:
                continue

            try:
                if artifact_id in self._removed_artifact_ids:
                    # Already deleted by this instance on an earlier call.
                    continue
                self.artifact_store.remove(artifact_id)
                self._removed_artifact_ids.add(artifact_id)
            except Exception as e:
                print(f"Warning: Failed to remove artifact {artifact_id}: {e}")
                continue

            # Clearing the attribute is best-effort and kept apart from the
            # removal above, so a storage-side refusal is not misreported as an
            # artifact-removal failure.
            # NOTE(review): storages may reject attribute writes on finished
            # trials - confirm against the optuna version in use.
            try:
                study._storage.set_trial_user_attr(old_trial._trial_id, self.check_attr_name, None)
            except Exception as e:
                optuna.logging.get_logger("optuna").debug(
                    "Could not clear user attr %r on trial %s: %s",
                    self.check_attr_name, old_trial.number, e,
                )
160
+
161
+
162
class StudyWrapper:
    """
    Wrapper exposing an Optuna-Study-like interface.

    It holds a source-code-serialized objective together with its study
    configuration. In this implementation ``optimize`` deserializes the
    objective and runs a local Optuna study; the original notes describe a
    gRPC backend as the intended transport.
    """

    def __init__(
        self,
        serialized_objective: Dict[str, Any],
        study_config: Dict[str, Any],
        storage,
        artifact_store
    ):
        self.serialized_objective = serialized_objective
        self.study_config = study_config
        self.storage = storage
        self.artifact_store = artifact_store
        # Local study, populated by optimize(); used for local testing.
        self._local_study = None

    def optimize(
        self,
        n_trials: int = 100,
        n_jobs: int = 1,
        callbacks: Optional[list] = None,
        **kwargs
    ):
        """
        Run the HPO optimization.

        The objective is rebuilt from its serialized source and optimized in a
        local study created (or loaded, if it already exists) on ``self.storage``.
        Only 'study_name' and 'direction' are read from ``study_config`` here.

        Args:
            n_trials: Number of trials passed to the study.
            n_jobs: Number of parallel jobs passed to the study.
            callbacks: Optional study callbacks; an empty list is used if falsy.
            **kwargs: Extra keyword arguments forwarded to ``Study.optimize``.

        Raises:
            Exception: Any failure is printed and then re-raised unchanged.
        """
        config = self.study_config
        print("🚀 Starting HPO optimization with source code serialization...")
        print(f"📊 Study: {config['study_name']}")
        print(f"🎯 Direction: {config['direction']}")
        print(f"🔢 Trials: {n_trials}")

        try:
            # Rebuild the objective function from its serialized source.
            objective_func = SourceCodeSerializer.deserialize_objective(
                self.serialized_objective
            )
            print("✅ Objective function deserialized successfully")

            # Create the local study (stand-in for the gRPC round trip).
            self._local_study = optuna.create_study(
                study_name=self.study_config['study_name'],
                direction=self.study_config['direction'],
                storage=self.storage,
                load_if_exists=True
            )

            # Run the optimization.
            active_callbacks = callbacks if callbacks else []
            self._local_study.optimize(
                objective_func,
                n_trials=n_trials,
                n_jobs=n_jobs,
                callbacks=active_callbacks,
                **kwargs
            )

            print(f"🎉 Optimization completed! Best value: {self.best_value}")

        except Exception as e:
            print(f"❌ Optimization failed: {e}")
            raise

    @property
    def best_trial(self):
        """Best trial of the local study, or None before optimize() has run."""
        return self._local_study.best_trial if self._local_study else None

    @property
    def best_value(self):
        """Best objective value of the local study, or None before optimize()."""
        return self._local_study.best_value if self._local_study else None

    @property
    def best_params(self):
        """Best hyperparameters of the local study, or None before optimize()."""
        return self._local_study.best_params if self._local_study else None

    @property
    def trials(self):
        """All trials of the local study, or an empty list before optimize()."""
        return self._local_study.trials if self._local_study else []
@@ -0,0 +1,138 @@
1
+ """
2
+ Source Code Serialization Module
3
+
4
+ 이 모듈은 Python 버전 간 호환성을 위해 CloudPickle 대신
5
+ inspect.getsource를 사용한 소스코드 직렬화 방식을 제공합니다.
6
+ """
7
+
8
+ import inspect
9
+ import types
10
+ from typing import Callable, Dict, Any, Tuple
11
+
12
+
13
class SourceCodeSerializer:
    """Serializes an objective function as source-code text and restores it."""

    @staticmethod
    def _extract_import_statements(module_source: str) -> list:
        """
        Return every import statement found in ``module_source``, in source order.

        The source is parsed with ``ast`` so that multi-line imports are kept
        whole and text inside strings is ignored. If the source cannot be
        parsed, a plain line scan (the previous behaviour) is used instead.
        """
        import ast

        try:
            tree = ast.parse(module_source)
        except SyntaxError:
            stripped_lines = (line.strip() for line in module_source.split('\n'))
            return [
                line for line in stripped_lines
                if line.startswith(('import ', 'from '))
            ]

        import_nodes = sorted(
            (
                node for node in ast.walk(tree)
                if isinstance(node, (ast.Import, ast.ImportFrom))
            ),
            key=lambda node: (node.lineno, node.col_offset),
        )

        statements = []
        for node in import_nodes:
            segment = ast.get_source_segment(module_source, node)
            if segment:
                statements.append(segment)
        return statements

    @staticmethod
    def serialize_objective(objective_func: Callable) -> Dict[str, Any]:
        """
        Serialize an objective function as source code.

        Args:
            objective_func: Objective function to serialize.

        Returns:
            Dictionary with the serialized data:
            - source_code: Dedented source text of the function
            - func_name: Name of the function
            - dependencies: Import statements collected from the defining module
            - serialization_method: Always 'source_code'

        Raises:
            RuntimeError: If the source cannot be retrieved or the module file
                cannot be read.
        """
        import textwrap
        import tokenize

        try:
            # Dedent so that functions defined inside another scope (indented
            # source) can later be exec'd at top level.
            source_code = textwrap.dedent(inspect.getsource(objective_func))
            func_name = objective_func.__name__

            # Collect import statements from the module defining the function.
            module = inspect.getmodule(objective_func)
            module_path = getattr(module, '__file__', None)
            dependencies = []

            if module_path:
                # tokenize.open honours the file's declared source encoding
                # instead of relying on the platform default.
                with tokenize.open(module_path) as f:
                    module_source = f.read()

                # Same substring filter as before: statements containing any of
                # these fragments are not recorded as dependencies.
                skipped_fragments = ('client', '__', 'relative')
                dependencies = [
                    statement
                    for statement in SourceCodeSerializer._extract_import_statements(module_source)
                    if not any(fragment in statement for fragment in skipped_fragments)
                ]

            return {
                'source_code': source_code,
                'func_name': func_name,
                'dependencies': dependencies,
                'serialization_method': 'source_code'
            }

        except Exception as e:
            raise RuntimeError(f"Failed to serialize objective function: {e}") from e

    @staticmethod
    def deserialize_objective(serialized_data: Dict[str, Any]) -> Callable:
        """
        Restore an objective function from serialized data.

        Args:
            serialized_data: Data produced by serialize_objective.

        Returns:
            The restored objective function.

        Raises:
            RuntimeError: If the source fails to execute, the function name is
                missing from the resulting namespace, or it is not callable.
        """
        import textwrap

        try:
            # Dedent here as well so payloads carrying indented source still load.
            source_code = textwrap.dedent(serialized_data['source_code'])
            func_name = serialized_data['func_name']
            dependencies = serialized_data.get('dependencies', [])

            # SECURITY NOTE(review): exec() below runs whatever code the payload
            # contains; only pass serialized data from a trusted origin.
            exec_namespace = {'__builtins__': __builtins__}

            # Run the recorded import statements first.
            for dep in dependencies:
                try:
                    exec(dep, exec_namespace)
                except Exception as import_error:
                    # A failed import is only reported; loading continues.
                    print(f"Warning: Failed to import dependency '{dep}': {import_error}")

            # Execute the function source inside the prepared namespace.
            exec(source_code, exec_namespace)

            if func_name not in exec_namespace:
                raise NameError(f"Function '{func_name}' not found in executed namespace")

            objective_func = exec_namespace[func_name]

            if not callable(objective_func):
                raise TypeError(f"'{func_name}' is not callable")

            return objective_func

        except Exception as e:
            raise RuntimeError(f"Failed to deserialize objective function: {e}") from e
106
+
107
+
108
def create_study_with_source_serialization(
    objective: Callable,
    study_config: Dict[str, Any],
    **optuna_kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Prepare a study that uses source-code serialization.

    Args:
        objective: Objective function used for HPO.
        study_config: Study settings (study_name, direction, sampler, pruner).
        **optuna_kwargs: Extra arguments, stored under 'optuna_kwargs'.

    Returns:
        Tuple[serialized_objective, study_config]
        - serialized_objective: Serialized objective function data
        - study_config: Normalized study settings with defaults filled in
    """
    # Serialize the objective function.
    serialized_objective = SourceCodeSerializer.serialize_objective(objective)

    # Normalize the study settings, filling in a default for each missing key.
    fallbacks = (
        ('study_name', 'unnamed_study'),
        ('direction', 'minimize'),
        ('sampler', 'TPESampler'),
        ('pruner', None),
    )
    processed_config = {
        key: study_config.get(key, fallback) for key, fallback in fallbacks
    }
    processed_config['optuna_kwargs'] = optuna_kwargs

    return serialized_objective, processed_config
@@ -0,0 +1,74 @@
1
+ Metadata-Version: 2.1
2
+ Name: aiauto-client
3
+ Version: 0.1.0
4
+ Summary: AI Auto HPO (Hyperparameter Optimization) Client Library
5
+ Author-email: AIAuto Team <ainode@zeroone.ai>
6
+ Project-URL: Homepage, https://aiauto.cloude.ainode.ai
7
+ Project-URL: Repository, https://aiauto.cloude.ainode.ai
8
+ Project-URL: Documentation, https://aiauto.cloude.ainode.ai
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.8
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
+ Requires-Python: >=3.8
22
+ Description-Content-Type: text/markdown
23
+ Requires-Dist: optuna>=3.0.0
24
+ Requires-Dist: grpcio>=1.50.0
25
+ Requires-Dist: grpcio-tools>=1.50.0
26
+ Requires-Dist: protobuf>=4.0.0
27
+
28
+ # AIAuto - Hyperparameter Optimization Client Library
29
+
30
+ AIAuto는 Kubernetes 기반의 분산 HPO(Hyperparameter Optimization) 시스템을 위한 클라이언트 라이브러리입니다.
31
+ 사용자 python lib <-> Next.js 서버 사이 gRPC 통신 담당
32
+
33
+ ## lib build
34
+ - pypi build, upload 종속성 다운로드 `pip install build twine`
35
+ - build lib `python -m build --wheel --sdist`
36
+ - `aiauto_client-0.1.0-py3-none-any.whl` 생성
37
+ - `aiauto_client-0.1.0.tar.gz` 생성
38
+ - `aiauto_client.egg-info` 생성
39
+ - `twine upload --repository testpypi dist/*`
40
+ - `twine upload dist/*`
41
+ - upload 시 pypi token 을 입력하라고 나옴, pypi 로그인 계정 설정가면 있다
42
+
43
+ ## 설치
44
+ - `pip install aiauto-client`
45
+
46
+ ## 빠른 시작
47
+ ```python
48
+ import aiauto
+ import optuna
49
+
50
+ # 컨트롤러 초기화
51
+ ac = aiauto.AIAutoController()
52
+
53
+ # Objective 함수 정의
54
+ def objective(trial):
55
+ tc = aiauto.TrialController(trial)
56
+
57
+ # 하이퍼파라미터 샘플링
58
+ lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
59
+
60
+ # 모델 학습 및 평가 로직
61
+ # ...
62
+ tc.log(f'full dataset: train {len(dataset)}, test {len(dataset_test)}, batch_size {batch_size}')
63
+
64
+ return accuracy
65
+
66
+ # Study 생성 및 최적화 실행
67
+ study = optuna.create_study(
68
+ study_name='my_optimization',
69
+ storage=ac.get_storage(),
70
+ direction='maximize'
71
+ )
72
+
73
+ study.optimize(objective, n_trials=100)
74
+ ```
@@ -0,0 +1,7 @@
1
+ aiauto/__init__.py,sha256=VvEM3L0NZGrHi3kHV_gSRf8X2baqLDPOSbArgd6LpaI,1353
2
+ aiauto/core.py,sha256=GKCF24GA25QCu8n2q3YXnff4Sb3Dfx1yKFvE7QZ8108,9182
3
+ aiauto/serialization.py,sha256=6Rb5k01hx7uXaLt1XmUrmn1KzMjxsYinzi4fjglc3jw,5137
4
+ aiauto_client-0.1.0.dist-info/METADATA,sha256=inimyQ0HuHH8mWkYUOP1FpXez506z1i8_oi0UVBSwsE,2510
5
+ aiauto_client-0.1.0.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
6
+ aiauto_client-0.1.0.dist-info/top_level.txt,sha256=Sk2ctO9_Bf_tAPwq1x6Vfl6OuL29XzwMTO4F_KG6oJE,7
7
+ aiauto_client-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.3.2)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ aiauto