hebra 0.1.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hebra/__init__.py +51 -0
- hebra/bin/.gitkeep +1 -0
- hebra/bin/core +0 -0
- hebra/lib/__init__.py +0 -0
- hebra/lib/service_connection.py +137 -0
- hebra/lib/utils.py +66 -0
- hebra/proto/common/v1/common.proto +50 -0
- hebra/proto/common/v1/common_pb2.py +43 -0
- hebra/proto/common/v1/common_pb2.pyi +45 -0
- hebra/proto/core/collector/v1/collector_service.proto +23 -0
- hebra/proto/core/collector/v1/collector_service_pb2.py +42 -0
- hebra/proto/core/collector/v1/collector_service_pb2.pyi +21 -0
- hebra/proto/core/collector/v1/collector_service_pb2_grpc.py +101 -0
- hebra/proto/core/health/v1/health_service.proto +92 -0
- hebra/proto/core/lifecycle/__init__.py +1 -0
- hebra/proto/core/lifecycle/v1/__init__.py +1 -0
- hebra/proto/core/lifecycle/v1/lifecycle_service.proto +24 -0
- hebra/proto/core/lifecycle/v1/lifecycle_service_pb2.py +41 -0
- hebra/proto/core/lifecycle/v1/lifecycle_service_pb2.pyi +17 -0
- hebra/proto/core/lifecycle/v1/lifecycle_service_pb2_grpc.py +102 -0
- hebra/proto/record/v1/record.proto +180 -0
- hebra/proto/record/v1/record_pb2.py +67 -0
- hebra/proto/record/v1/record_pb2.pyi +206 -0
- hebra/run/__init__.py +4 -0
- hebra/run/main.py +211 -0
- hebra/run/state.py +10 -0
- hebra/sdk.py +78 -0
- hebra-0.1.0.dist-info/METADATA +8 -0
- hebra-0.1.0.dist-info/RECORD +30 -0
- hebra-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from google.protobuf import struct_pb2 as _struct_pb2
|
|
2
|
+
from google.protobuf import any_pb2 as _any_pb2
|
|
3
|
+
from google.protobuf.internal import containers as _containers
|
|
4
|
+
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
|
5
|
+
from google.protobuf import descriptor as _descriptor
|
|
6
|
+
from google.protobuf import message as _message
|
|
7
|
+
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
|
8
|
+
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
|
9
|
+
|
|
10
|
+
DESCRIPTOR: _descriptor.FileDescriptor
|
|
11
|
+
|
|
12
|
+
class Record(_message.Message):
    """Type stub for the generated ``Record`` message.

    Declares a ``message_type`` enum field and a ``payload`` field typed as
    ``google.protobuf.Any``.
    """

    __slots__ = ("message_type", "payload")

    class RecordType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        """Enum wrapper for the values accepted by ``message_type``."""

        __slots__ = ()
        RECORD_UNKNOWN: _ClassVar[Record.RecordType]
        RECORD_SETUP: _ClassVar[Record.RecordType]
        RECORD_TEARDOWN: _ClassVar[Record.RecordType]
        RECORD_RUNTIME: _ClassVar[Record.RecordType]
        RECORD_COLUMN: _ClassVar[Record.RecordType]
        RECORD_MEDIA: _ClassVar[Record.RecordType]
        RECORD_SCALAR: _ClassVar[Record.RecordType]
        RECORD_LOG: _ClassVar[Record.RecordType]

    # Enum values are also declared at message scope.
    RECORD_UNKNOWN: Record.RecordType
    RECORD_SETUP: Record.RecordType
    RECORD_TEARDOWN: Record.RecordType
    RECORD_RUNTIME: Record.RecordType
    RECORD_COLUMN: Record.RecordType
    RECORD_MEDIA: Record.RecordType
    RECORD_SCALAR: Record.RecordType
    RECORD_LOG: Record.RecordType

    # Field-number constants, one per field.
    MESSAGE_TYPE_FIELD_NUMBER: _ClassVar[int]
    PAYLOAD_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    message_type: Record.RecordType
    payload: _any_pb2.Any

    def __init__(
        self,
        message_type: _Optional[_Union[Record.RecordType, str]] = ...,
        payload: _Optional[_Union[_any_pb2.Any, _Mapping]] = ...,
    ) -> None: ...
|
|
37
|
+
|
|
38
|
+
class SetupRecord(_message.Message):
    """Type stub for the generated ``SetupRecord`` message."""

    __slots__ = (
        "name",
        "workspace",
        "visibility",
        "experiment_name",
        "experiment_description",
        "experiment_tags",
        "start_time",
    )

    class Visibility(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        """Enum wrapper for the values accepted by ``visibility``."""

        __slots__ = ()
        VISIBILITY_UNKNOWN: _ClassVar[SetupRecord.Visibility]
        VISIBILITY_PRIVATE: _ClassVar[SetupRecord.Visibility]
        VISIBILITY_PUBLIC: _ClassVar[SetupRecord.Visibility]

    # Enum values are also declared at message scope.
    VISIBILITY_UNKNOWN: SetupRecord.Visibility
    VISIBILITY_PRIVATE: SetupRecord.Visibility
    VISIBILITY_PUBLIC: SetupRecord.Visibility

    # Field-number constants, one per field.
    NAME_FIELD_NUMBER: _ClassVar[int]
    WORKSPACE_FIELD_NUMBER: _ClassVar[int]
    VISIBILITY_FIELD_NUMBER: _ClassVar[int]
    EXPERIMENT_NAME_FIELD_NUMBER: _ClassVar[int]
    EXPERIMENT_DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
    EXPERIMENT_TAGS_FIELD_NUMBER: _ClassVar[int]
    START_TIME_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    name: str
    workspace: str
    visibility: SetupRecord.Visibility
    experiment_name: str
    experiment_description: str
    experiment_tags: _containers.RepeatedScalarFieldContainer[str]
    start_time: str

    def __init__(
        self,
        name: _Optional[str] = ...,
        workspace: _Optional[str] = ...,
        visibility: _Optional[_Union[SetupRecord.Visibility, str]] = ...,
        experiment_name: _Optional[str] = ...,
        experiment_description: _Optional[str] = ...,
        experiment_tags: _Optional[_Iterable[str]] = ...,
        start_time: _Optional[str] = ...,
    ) -> None: ...
|
|
63
|
+
|
|
64
|
+
class TeardownRecord(_message.Message):
    """Type stub for the generated ``TeardownRecord`` message."""

    __slots__ = ("error_message", "end_time")

    # Field-number constants, one per field.
    ERROR_MESSAGE_FIELD_NUMBER: _ClassVar[int]
    END_TIME_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    error_message: str
    end_time: str

    def __init__(
        self,
        error_message: _Optional[str] = ...,
        end_time: _Optional[str] = ...,
    ) -> None: ...
|
|
71
|
+
|
|
72
|
+
class RuntimeRecord(_message.Message):
    """Type stub for the generated ``RuntimeRecord`` message (four filename fields)."""

    __slots__ = (
        "conda_filename",
        "pip_filename",
        "config_filename",
        "metadata_filename",
    )

    # Field-number constants, one per field.
    CONDA_FILENAME_FIELD_NUMBER: _ClassVar[int]
    PIP_FILENAME_FIELD_NUMBER: _ClassVar[int]
    CONFIG_FILENAME_FIELD_NUMBER: _ClassVar[int]
    METADATA_FILENAME_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    conda_filename: str
    pip_filename: str
    config_filename: str
    metadata_filename: str

    def __init__(
        self,
        conda_filename: _Optional[str] = ...,
        pip_filename: _Optional[str] = ...,
        config_filename: _Optional[str] = ...,
        metadata_filename: _Optional[str] = ...,
    ) -> None: ...
|
|
83
|
+
|
|
84
|
+
class Range(_message.Message):
    """Type stub for the generated ``Range`` message (integer ``minval``/``maxval``)."""

    __slots__ = ("minval", "maxval")

    # Field-number constants, one per field.
    MINVAL_FIELD_NUMBER: _ClassVar[int]
    MAXVAL_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    minval: int
    maxval: int

    def __init__(
        self,
        minval: _Optional[int] = ...,
        maxval: _Optional[int] = ...,
    ) -> None: ...
|
|
91
|
+
|
|
92
|
+
class ColumnRecord(_message.Message):
    """Type stub for the generated ``ColumnRecord`` message.

    Declares column, section, chart and metric fields together with the
    three nested enums used by ``column_class``, ``column_type`` and
    ``section_type``.
    """

    __slots__ = (
        "column_key",
        "column_name",
        "column_class",
        "column_type",
        "column_error",
        "section_name",
        "section_type",
        "chart_name",
        "chart_index",
        "chart_y_range",
        "metric_name",
        "metric_color",
    )

    class ColumnClass(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        """Enum wrapper for the values accepted by ``column_class``."""

        __slots__ = ()
        COL_CLASS_CUSTOM: _ClassVar[ColumnRecord.ColumnClass]
        COL_CLASS_SYSTEM: _ClassVar[ColumnRecord.ColumnClass]

    # ColumnClass values are also declared at message scope.
    COL_CLASS_CUSTOM: ColumnRecord.ColumnClass
    COL_CLASS_SYSTEM: ColumnRecord.ColumnClass

    class ColumnType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        """Enum wrapper for the values accepted by ``column_type``."""

        __slots__ = ()
        COL_UNKNOWN: _ClassVar[ColumnRecord.ColumnType]
        COL_FLOAT: _ClassVar[ColumnRecord.ColumnType]
        COL_IMAGE: _ClassVar[ColumnRecord.ColumnType]
        COL_AUDIO: _ClassVar[ColumnRecord.ColumnType]
        COL_TEXT: _ClassVar[ColumnRecord.ColumnType]
        COL_OBJECT3D: _ClassVar[ColumnRecord.ColumnType]
        COL_MOLECULE: _ClassVar[ColumnRecord.ColumnType]
        COL_ECHARTS: _ClassVar[ColumnRecord.ColumnType]

    # ColumnType values are also declared at message scope.
    COL_UNKNOWN: ColumnRecord.ColumnType
    COL_FLOAT: ColumnRecord.ColumnType
    COL_IMAGE: ColumnRecord.ColumnType
    COL_AUDIO: ColumnRecord.ColumnType
    COL_TEXT: ColumnRecord.ColumnType
    COL_OBJECT3D: ColumnRecord.ColumnType
    COL_MOLECULE: ColumnRecord.ColumnType
    COL_ECHARTS: ColumnRecord.ColumnType

    class SectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        """Enum wrapper for the values accepted by ``section_type``."""

        __slots__ = ()
        SEC_PUBLIC: _ClassVar[ColumnRecord.SectionType]
        SEC_SYSTEM: _ClassVar[ColumnRecord.SectionType]
        SEC_CUSTOM: _ClassVar[ColumnRecord.SectionType]
        SEC_PINNED: _ClassVar[ColumnRecord.SectionType]
        SEC_HIDDEN: _ClassVar[ColumnRecord.SectionType]

    # SectionType values are also declared at message scope.
    SEC_PUBLIC: ColumnRecord.SectionType
    SEC_SYSTEM: ColumnRecord.SectionType
    SEC_CUSTOM: ColumnRecord.SectionType
    SEC_PINNED: ColumnRecord.SectionType
    SEC_HIDDEN: ColumnRecord.SectionType

    # Field-number constants, one per field.
    COLUMN_KEY_FIELD_NUMBER: _ClassVar[int]
    COLUMN_NAME_FIELD_NUMBER: _ClassVar[int]
    COLUMN_CLASS_FIELD_NUMBER: _ClassVar[int]
    COLUMN_TYPE_FIELD_NUMBER: _ClassVar[int]
    COLUMN_ERROR_FIELD_NUMBER: _ClassVar[int]
    SECTION_NAME_FIELD_NUMBER: _ClassVar[int]
    SECTION_TYPE_FIELD_NUMBER: _ClassVar[int]
    CHART_NAME_FIELD_NUMBER: _ClassVar[int]
    CHART_INDEX_FIELD_NUMBER: _ClassVar[int]
    CHART_Y_RANGE_FIELD_NUMBER: _ClassVar[int]
    METRIC_NAME_FIELD_NUMBER: _ClassVar[int]
    METRIC_COLOR_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    column_key: str
    column_name: str
    column_class: ColumnRecord.ColumnClass
    column_type: ColumnRecord.ColumnType
    column_error: _struct_pb2.Struct
    section_name: str
    section_type: ColumnRecord.SectionType
    chart_name: str
    chart_index: str
    chart_y_range: Range
    metric_name: str
    metric_color: _containers.RepeatedScalarFieldContainer[str]

    def __init__(
        self,
        column_key: _Optional[str] = ...,
        column_name: _Optional[str] = ...,
        column_class: _Optional[_Union[ColumnRecord.ColumnClass, str]] = ...,
        column_type: _Optional[_Union[ColumnRecord.ColumnType, str]] = ...,
        column_error: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...,
        section_name: _Optional[str] = ...,
        section_type: _Optional[_Union[ColumnRecord.SectionType, str]] = ...,
        chart_name: _Optional[str] = ...,
        chart_index: _Optional[str] = ...,
        chart_y_range: _Optional[_Union[Range, _Mapping]] = ...,
        metric_name: _Optional[str] = ...,
        metric_color: _Optional[_Iterable[str]] = ...,
    ) -> None: ...
|
|
155
|
+
|
|
156
|
+
class MediaRecord(_message.Message):
    """Type stub for the generated ``MediaRecord`` message.

    ``data`` is a repeated string field and ``more`` is a repeated
    ``google.protobuf.Struct`` field; the remaining fields are strings.
    """

    __slots__ = (
        "index",
        "epoch",
        "create_time",
        "key",
        "key_encoded",
        "kid",
        "data",
        "more",
    )

    # Field-number constants, one per field.
    INDEX_FIELD_NUMBER: _ClassVar[int]
    EPOCH_FIELD_NUMBER: _ClassVar[int]
    CREATE_TIME_FIELD_NUMBER: _ClassVar[int]
    KEY_FIELD_NUMBER: _ClassVar[int]
    KEY_ENCODED_FIELD_NUMBER: _ClassVar[int]
    KID_FIELD_NUMBER: _ClassVar[int]
    DATA_FIELD_NUMBER: _ClassVar[int]
    MORE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    index: str
    epoch: str
    create_time: str
    key: str
    key_encoded: str
    kid: str
    data: _containers.RepeatedScalarFieldContainer[str]
    more: _containers.RepeatedCompositeFieldContainer[_struct_pb2.Struct]

    def __init__(
        self,
        index: _Optional[str] = ...,
        epoch: _Optional[str] = ...,
        create_time: _Optional[str] = ...,
        key: _Optional[str] = ...,
        key_encoded: _Optional[str] = ...,
        kid: _Optional[str] = ...,
        data: _Optional[_Iterable[str]] = ...,
        more: _Optional[_Iterable[_Union[_struct_pb2.Struct, _Mapping]]] = ...,
    ) -> None: ...
|
|
175
|
+
|
|
176
|
+
class ScalarRecord(_message.Message):
    """Type stub for the generated ``ScalarRecord`` message.

    ``data`` is a repeated float field; the remaining fields are strings.
    """

    __slots__ = ("index", "epoch", "create_time", "key", "data")

    # Field-number constants, one per field.
    INDEX_FIELD_NUMBER: _ClassVar[int]
    EPOCH_FIELD_NUMBER: _ClassVar[int]
    CREATE_TIME_FIELD_NUMBER: _ClassVar[int]
    KEY_FIELD_NUMBER: _ClassVar[int]
    DATA_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    index: str
    epoch: str
    create_time: str
    key: str
    data: _containers.RepeatedScalarFieldContainer[float]

    def __init__(
        self,
        index: _Optional[str] = ...,
        epoch: _Optional[str] = ...,
        create_time: _Optional[str] = ...,
        key: _Optional[str] = ...,
        data: _Optional[_Iterable[float]] = ...,
    ) -> None: ...
|
|
189
|
+
|
|
190
|
+
class LogRecord(_message.Message):
    """Type stub for the generated ``LogRecord`` message."""

    __slots__ = ("epoch", "level", "message")

    class LogType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        """Enum wrapper for the values accepted by ``level``."""

        __slots__ = ()
        LOG_INFO: _ClassVar[LogRecord.LogType]
        LOG_WARN: _ClassVar[LogRecord.LogType]
        LOG_ERROR: _ClassVar[LogRecord.LogType]

    # Enum values are also declared at message scope.
    LOG_INFO: LogRecord.LogType
    LOG_WARN: LogRecord.LogType
    LOG_ERROR: LogRecord.LogType

    # Field-number constants, one per field.
    EPOCH_FIELD_NUMBER: _ClassVar[int]
    LEVEL_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    epoch: str
    level: LogRecord.LogType
    message: str

    def __init__(
        self,
        epoch: _Optional[str] = ...,
        level: _Optional[_Union[LogRecord.LogType, str]] = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
hebra/run/__init__.py
ADDED
hebra/run/main.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import atexit
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import traceback
|
|
7
|
+
from typing import Any, Dict, Optional
|
|
8
|
+
|
|
9
|
+
from .state import HebraRunState
|
|
10
|
+
from ..lib.service_connection import ServiceConnection
|
|
11
|
+
from ..lib.utils import kvs_from_mapping
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HebraRun:
    """A Hebra run instance, similar to SwanLabRun.

    Only one HebraRun can be running in a process at a time; the active
    instance is tracked through the module-level ``_current_run`` and is
    retrievable with ``get_run()``.
    """

    def __init__(
        self,
        project: str,
        logdir: str,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the run and make it the process-wide current run.

        Args:
            project: Project name.
            logdir: Log storage directory; it is made absolute and created
                if it does not exist.
            config: Optional configuration parameters.

        Raises:
            RuntimeError: If another run is already started.
        """
        global _current_run

        if self.is_started():
            raise RuntimeError("HebraRun已经初始化,请先调用finish()结束当前运行")

        # Resolve the directory once so the stored path stays valid even if
        # the working directory changes later.
        logdir = os.path.abspath(logdir)

        # Save configuration.
        self._project = project
        self._logdir = logdir
        self._config = config or {}
        self._state = HebraRunState.RUNNING
        self._step = 0

        # Create the log directory.
        os.makedirs(logdir, exist_ok=True)
        db_path = os.path.join(logdir, "db")

        # Start the service connection.
        self._conn = ServiceConnection(db_path)
        self._conn.start()

        # Register the exit/exception callbacks.
        self._register_sys_callback()

        # Publish this instance as the global current run.
        _current_run = self

        print(f"[hebra] 实验已初始化: project={project}, logdir={logdir}")

    def _register_sys_callback(self) -> None:
        """Install the excepthook and atexit handlers for this run."""
        self._orig_excepthook = sys.excepthook
        sys.excepthook = self._except_handler
        atexit.register(self._clean_handler)

    def _unregister_sys_callback(self) -> None:
        """Remove the excepthook and atexit handlers installed by this run."""
        # Only restore the previous hook if ours is still the installed one;
        # otherwise we would clobber a hook someone installed after us.
        # Bound methods compare equal when they wrap the same function/instance.
        if sys.excepthook == self._except_handler:
            sys.excepthook = self._orig_excepthook
        atexit.unregister(self._clean_handler)

    def _clean_handler(self) -> None:
        """Cleanup function run at normal interpreter exit (atexit)."""
        if self._state == HebraRunState.RUNNING:
            print("[hebra] 程序退出,自动关闭实验...")
            self.finish()

    def _except_handler(self, exc_type, exc_val, exc_tb) -> None:
        """Handler run on an uncaught exception (sys.excepthook)."""
        # Build the error traceback text.
        error_msg = "".join(traceback.format_exception(exc_type, exc_val, exc_tb))

        # Mark the run as crashed if it is still running.
        if self._state == HebraRunState.RUNNING:
            print("[hebra] 检测到异常,标记实验为CRASHED...")
            self.finish(state=HebraRunState.CRASHED, error=error_msg)

        # Delegate to the original excepthook so the error is still printed.
        self._orig_excepthook(exc_type, exc_val, exc_tb)

    @staticmethod
    def is_started() -> bool:
        """Return True if a current run exists."""
        return get_run() is not None

    @staticmethod
    def get_state() -> HebraRunState:
        """Return the current run's state, or NOT_STARTED when there is none."""
        run = get_run()
        return run._state if run else HebraRunState.NOT_STARTED

    @property
    def running(self) -> bool:
        """Whether the run is currently running."""
        return self._state == HebraRunState.RUNNING

    @property
    def success(self) -> bool:
        """Whether the run finished successfully."""
        return self._state == HebraRunState.SUCCESS

    @property
    def crashed(self) -> bool:
        """Whether the run finished abnormally."""
        return self._state == HebraRunState.CRASHED

    @property
    def config(self) -> Dict[str, Any]:
        """The configuration mapping given at initialization (or an empty dict)."""
        return self._config

    @property
    def project(self) -> str:
        """The project name."""
        return self._project

    def log(
        self,
        data: Dict[str, Any],
        step: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Record data.

        Args:
            data: Key-value data.
            step: Optional step number. When omitted, or when it is not a
                non-negative integer, the internal auto-incrementing step
                is used instead (with a warning in the invalid case).

        Returns:
            A dict with ``step``, ``success`` and ``message`` keys.

        Raises:
            RuntimeError: If the run is no longer running.
            TypeError: If ``data`` is not a dict.
        """
        if self._state != HebraRunState.RUNNING:
            raise RuntimeError("实验已结束,无法继续记录数据")

        if not isinstance(data, dict):
            raise TypeError(f"data必须是dict类型,但收到了{type(data)}")

        # Validate an explicit step. bool is a subclass of int, so it is
        # rejected explicitly rather than silently treated as 0/1.
        if step is not None:
            is_valid = (
                isinstance(step, int)
                and not isinstance(step, bool)
                and step >= 0
            )
            if not is_valid:
                print(f"[hebra] 警告: step必须是非负整数,忽略传入的step={step}")
                step = None

        # Fall back to the auto-incrementing counter.
        if step is None:
            step = self._step
            self._step += 1

        # Convert the data and upload it.
        # NOTE(review): step is only echoed in the result below; it is not
        # passed to upload() here — confirm that is intended.
        kv_list = kvs_from_mapping(data)
        resp = self._conn.upload(kv_list)

        return {"step": step, "success": resp.success, "message": resp.message}

    def finish(
        self,
        state: HebraRunState = HebraRunState.SUCCESS,
        error: Optional[str] = None,
    ) -> None:
        """Finish the run.

        A repeated call on an already finished run only prints a warning.

        Args:
            state: Final state of the run; must not be RUNNING or NOT_STARTED.
            error: Error text (only meaningful when state is CRASHED).

        Raises:
            ValueError: If ``state`` is RUNNING or NOT_STARTED.
        """
        global _current_run

        if self._state != HebraRunState.RUNNING:
            print("[hebra] 警告: 实验已结束,忽略重复的finish调用")
            return

        # A non-terminal state would shut the connection down while leaving
        # the run looking alive, so refuse it before touching anything.
        if state in (HebraRunState.RUNNING, HebraRunState.NOT_STARTED):
            raise ValueError(f"finish()的state必须是结束状态,但收到了{state}")

        # Update the state.
        self._state = state

        # Unregister the exit/exception callbacks.
        self._unregister_sys_callback()

        # Close the service connection. The global instance is cleared even
        # if shutdown fails; otherwise a finished run would stay registered
        # and no new run could ever be started in this process.
        exit_code = 0 if state == HebraRunState.SUCCESS else 1
        try:
            self._conn.shutdown(exit_code=exit_code)
        finally:
            _current_run = None

        status = "SUCCESS" if state == HebraRunState.SUCCESS else "CRASHED"
        print(f"[hebra] 实验已结束: status={status}")

        if error:
            # Show at most 200 characters; add an ellipsis only when truncated.
            suffix = "..." if len(error) > 200 else ""
            print(f"[hebra] 错误信息: {error[:200]}{suffix}")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# 全局运行实例
|
|
202
|
+
_current_run: Optional[HebraRun] = None
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def get_run() -> Optional[HebraRun]:
    """Return the current run instance.

    Returns:
        The current HebraRun instance, or None when none has been initialized.
    """
    active_run = _current_run
    return active_run
|
hebra/run/state.py
ADDED
hebra/sdk.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
|
|
6
|
+
from .run import HebraRun, HebraRunState, get_run
|
|
7
|
+
from .lib.utils import should_call_after_init
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def init(
    project: Optional[str] = None,
    logdir: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    reinit: bool = False,
) -> HebraRun:
    """Initialize a hebra experiment.

    If a run is already started, the behaviour depends on ``reinit``:
    with ``reinit=True`` the current run is finished and a new one is
    created; otherwise a warning is printed and the existing run is
    returned unchanged (the other arguments are ignored).

    Args:
        project: Project name; defaults to the current directory's name.
        logdir: Log storage directory; defaults to ``hebralog`` under the
            current working directory.
        config: Optional configuration parameters.
        reinit: Whether to re-initialize (finishes the current run first).

    Returns:
        The HebraRun instance.
    """
    active_run = get_run()

    # Handle repeated initialization.
    if HebraRun.is_started():
        if not reinit:
            print("[hebra] 警告: 实验已初始化,返回当前实例。使用reinit=True可重新初始化")
            return active_run
        print("[hebra] reinit=True,先结束当前实验...")
        active_run.finish()

    # Fill in defaults derived from the current working directory.
    if project is None:
        project = os.path.basename(os.getcwd())
    if logdir is None:
        logdir = os.path.join(os.getcwd(), "hebralog")

    new_run = HebraRun(project=project, logdir=logdir, config=config)
    return new_run
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@should_call_after_init("必须先调用hebra.init()才能使用log()")
def log(
    data: Dict[str, Any],
    step: Optional[int] = None,
) -> Dict[str, Any]:
    """Record data to the current run.

    Args:
        data: Key-value data.
        step: Optional step number.

    Returns:
        The result returned by the current run's ``log``.
    """
    current = get_run()
    result = current.log(data, step)
    return result
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@should_call_after_init("必须先调用hebra.init()才能使用finish()")
def finish(
    state: HebraRunState = HebraRunState.SUCCESS,
    error: Optional[str] = None,
) -> None:
    """Finish the current run.

    Args:
        state: Final state of the run.
        error: Error text.
    """
    current = get_run()
    current.finish(state, error)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
hebra/__init__.py,sha256=bhjc8wLMPfQoJifJC3s7tpNmUOPYWYxK0WfAHPeLxiM,1417
|
|
2
|
+
hebra/sdk.py,sha256=1vY1e8ld8BOB6F-MIldgvZvazM7WKRecSEJK508tbQ4,2008
|
|
3
|
+
hebra/bin/.gitkeep,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
4
|
+
hebra/bin/core,sha256=XiW48NNL2Ji1H5-BICfXkyNHcaZEMssM1FUgYh68tJs,19130402
|
|
5
|
+
hebra/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
+
hebra/lib/service_connection.py,sha256=lFLQrRaAq4LcPatqQ4o8HHz1-Pccu31uk2Qc-K9wrFY,4153
|
|
7
|
+
hebra/lib/utils.py,sha256=DRjfgyGPNDJQToE3nnXCflHzrySTNNDW6vTmOZ0I8Hg,1989
|
|
8
|
+
hebra/proto/common/v1/common.proto,sha256=0XbSmzf7nSPWs_WEpIuaIeIIkt_nhByZ4HPAbSW0PjI,1779
|
|
9
|
+
hebra/proto/common/v1/common_pb2.py,sha256=HyXsiIkleR3D0r6uu1d6cialfoV9aJsQIn6OW10D7EY,2418
|
|
10
|
+
hebra/proto/common/v1/common_pb2.pyi,sha256=AjOIy5H1Au_Mf4AP-h-vQZNi4MWxajbavnT3Ht1slpY,2160
|
|
11
|
+
hebra/proto/core/collector/v1/collector_service.proto,sha256=Wfzr6jJwgWsLpslIciTcjb8yHGn4orzzt6wmpqlkjkk,527
|
|
12
|
+
hebra/proto/core/collector/v1/collector_service_pb2.py,sha256=IgNCKDFqLcV8uCuk_n2oYDLtG1d-KrREX5Tn9tt16BU,2173
|
|
13
|
+
hebra/proto/core/collector/v1/collector_service_pb2.pyi,sha256=q90Gkan7M81nYN6QkLx2SM6gcBi50hZ42Ud414pIhd8,864
|
|
14
|
+
hebra/proto/core/collector/v1/collector_service_pb2_grpc.py,sha256=Vr8SDsNtWKOTy8KG_vq1iM3vm9qrRAgR1NQAacBBN2w,3933
|
|
15
|
+
hebra/proto/core/health/v1/health_service.proto,sha256=FktnAFiFXGJ3oFEaFOKjAt1aj_iZhTymhRtrIBpBPSg,3673
|
|
16
|
+
hebra/proto/core/lifecycle/__init__.py,sha256=XI9bOZA3U0hzPHXHsLoU4L9v9crQim_mGF7C9ow_-JE,20
|
|
17
|
+
hebra/proto/core/lifecycle/v1/__init__.py,sha256=OCjuYKmtOF2CuS2N9yAEYg8UEdRooxa2Z2T_0V0yQko,23
|
|
18
|
+
hebra/proto/core/lifecycle/v1/lifecycle_service.proto,sha256=pOfYHQjrpbAyl2DqE0tK-UvU2xupjkibsG-wdQKcmRg,694
|
|
19
|
+
hebra/proto/core/lifecycle/v1/lifecycle_service_pb2.py,sha256=uQWFX9wdJIFy30TCOYFWjgAI1MNNuRGEvZzOzoiXa28,1927
|
|
20
|
+
hebra/proto/core/lifecycle/v1/lifecycle_service_pb2.pyi,sha256=n_Jwf54hwYixTxh8f4lh3dQUqEM36MgaQo9qMpF9eIE,598
|
|
21
|
+
hebra/proto/core/lifecycle/v1/lifecycle_service_pb2_grpc.py,sha256=3Se2gWf7ULoNC3K2kE6-rHfQxfN4e6EwpMGhNBmaDrI,3895
|
|
22
|
+
hebra/proto/record/v1/record.proto,sha256=qjKPt3sBD2sQrnRhAfGVRVZCTYD8_LYELMlACiZ6XZs,5563
|
|
23
|
+
hebra/proto/record/v1/record_pb2.py,sha256=8QzCloeb6saQHlYPfDGu_WGg_3BjDXYAjNucVTlAcfs,6869
|
|
24
|
+
hebra/proto/record/v1/record_pb2.pyi,sha256=pc5gp4INaRRaPxMsq7jJviQdx4-RLsfBwOLhM19eWuE,10317
|
|
25
|
+
hebra/run/__init__.py,sha256=dnIZJpbwYBL5sYeb_6zzzn_VVHdaN3pz7jbIoNy0QVQ,121
|
|
26
|
+
hebra/run/main.py,sha256=8cxItiktdrXvtUH3WBmNHWBw_hWG5_-1EIfH10W5mYI,5998
|
|
27
|
+
hebra/run/state.py,sha256=u4mAjHx9Lrtticfy5GSOUJeEtZBfKfc8bmKQRdok7G4,243
|
|
28
|
+
hebra-0.1.0.dist-info/METADATA,sha256=_0QL_71kt0UhZdMPgHMHjQIx7UsMa8YNNYEjVJQ4Tb8,211
|
|
29
|
+
hebra-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
30
|
+
hebra-0.1.0.dist-info/RECORD,,
|