intellif-aihub 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of intellif-aihub might be problematic.
- aihub/__init__.py +1 -1
- aihub/models/artifact.py +16 -4
- aihub/models/dataset_management.py +91 -0
- aihub/services/artifact.py +16 -30
- aihub/services/dataset_management.py +260 -42
- aihub/utils/di.py +337 -0
- aihub/utils/download.py +3 -15
- aihub/utils/http.py +6 -0
- {intellif_aihub-0.1.13.dist-info → intellif_aihub-0.1.15.dist-info}/METADATA +1 -1
- {intellif_aihub-0.1.13.dist-info → intellif_aihub-0.1.15.dist-info}/RECORD +13 -12
- {intellif_aihub-0.1.13.dist-info → intellif_aihub-0.1.15.dist-info}/WHEEL +0 -0
- {intellif_aihub-0.1.13.dist-info → intellif_aihub-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {intellif_aihub-0.1.13.dist-info → intellif_aihub-0.1.15.dist-info}/top_level.txt +0 -0
aihub/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.13"
+__version__ = "0.1.15"
aihub/models/artifact.py
CHANGED
@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
 
 class ArtifactType(str, Enum):
     """制品类型枚举:"dataset"-数据集类型;"model"-模型类型;"metrics"-指标类型;"log"-日志类型;"checkpoint"-检查点类型;"image"-图像类型;"prediction"-预测结果类型;"other"-其他类型"""
+
     dataset = "dataset"  # 数据集类型
     model = "model"  # 模型类型
     metrics = "metrics"  # 指标类型
@@ -26,24 +27,29 @@ class ArtifactType(str, Enum):
 
 class CreateArtifactsReq(BaseModel):
     """创建制品请求"""
+
     entity_id: str = Field(alias="entity_id", description="实体ID,通常是运行ID,用于关联制品与特定运行")
-    entity_type: ArtifactType = Field(
-
+    entity_type: ArtifactType = Field(
+        default=ArtifactType.other, alias="entity_type", description="制品类型,指定制品的类型,默认为other"
+    )
     src_path: str = Field(alias="src_path", description="源路径,制品在系统中的路径标识")
-    is_dir: bool = Field(
-
+    is_dir: bool = Field(
+        default=False, alias="is_dir", description="是否为目录,True表示制品是一个目录,False表示是单个文件"
+    )
 
     model_config = {"use_enum_values": True}
 
 
 class CreateArtifactsResponseData(BaseModel):
     """创建制品响应数据"""
+
     id: int = Field(description="制品ID")
     s3_path: str = Field(alias="s3_path", description="S3存储路径")
 
 
 class CreateArtifactsResponseModel(BaseModel):
     """创建制品响应模型"""
+
     code: int = Field(description="响应码,0表示成功")
     msg: str = Field(default="", description="响应消息")
     data: Optional[CreateArtifactsResponseData] = Field(default=None, description="响应数据")
@@ -51,6 +57,7 @@ class CreateArtifactsResponseModel(BaseModel):
 
 class CreateEvalReq(BaseModel):
     """创建评估请求"""
+
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
     dataset_version_id: int = Field(alias="dataset_version_id", description="数据集版本ID")
     prediction_artifact_path: str = Field(alias="prediction_artifact_path", description="预测结果制品路径")
@@ -62,6 +69,7 @@ class CreateEvalReq(BaseModel):
 
 class ArtifactResp(BaseModel):
     """制品响应模型,表示一个制品的详细信息"""
+
     id: int = Field(description="制品ID")
     entity_type: str = Field(alias="entity_type", description="实体类型")
     entity_id: str = Field(alias="entity_id", description="实体ID")
@@ -72,6 +80,7 @@ class ArtifactResp(BaseModel):
 
 class ArtifactRespData(BaseModel):
     """制品分页数据"""
+
     total: int = Field(description="总记录数")
     page_size: int = Field(alias="page_size", description="每页大小")
     page_num: int = Field(alias="page_num", description="页码")
@@ -80,6 +89,7 @@ class ArtifactRespData(BaseModel):
 
 class ArtifactRespModel(BaseModel):
     """获取制品响应模型"""
+
     code: int = Field(description="响应码,0表示成功")
     msg: str = Field(default="", description="响应消息")
     data: ArtifactRespData = Field(description="响应数据")
@@ -91,8 +101,10 @@ InfinityPageSize = 10000 * 100
 
 class StsResp(BaseModel):
     """STS 临时凭证"""
+
     access_key_id: Optional[str] = Field(default=None, alias="access_key_id", description="访问密钥ID")
     secret_access_key: Optional[str] = Field(default=None, alias="secret_access_key", description="秘密访问密钥")
     session_token: Optional[str] = Field(default=None, alias="session_token", description="会话令牌")
     expiration: Optional[int] = Field(default=None, alias="expiration", description="过期时间")
     endpoint: Optional[str] = Field(default=None, alias="endpoint", description="端点URL")
+    bucket: Optional[str] = Field(default=None, alias="bucket", description="存储桶名称")
aihub/models/dataset_management.py
CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
 
 class DatasetVersionStatus(IntEnum):
     """数据集版本状态:1-等待中;2-运行中;3-成功;4-失败;5-加载meta;6-构建index"""
+
     Waiting = 1  # 等待中
     Running = 2  # 运行中
     Success = 3  # 成功
@@ -18,6 +19,7 @@ class DatasetVersionStatus(IntEnum):
 
 class UploadType(IntEnum):
     """上传类型:1-本地上传;3-服务器路径上传;4-Labelfree;5-数据接入"""
+
     LOCAL = 1  # 本地上传
     SERVER_PATH = 3  # 服务器路径上传
     LABELFREE = 4  # Labelfree
@@ -26,6 +28,7 @@ class UploadType(IntEnum):
 
 class CreateDatasetRequest(BaseModel):
     """创建数据集请求"""
+
     name: str = Field(description="数据集名称")
     description: str = Field(description="数据集描述")
     tags: List[int] = Field(description="标签ID列表,通过标签管理系统查询")
@@ -37,11 +40,13 @@ class CreateDatasetRequest(BaseModel):
 
 class CreateDatasetResponse(BaseModel):
     """创建数据集返回"""
+
     id: int = Field(alias="id", description="数据集ID")
 
 
 class DatasetVersionBase(BaseModel):
     """数据集版本概要"""
+
     id: int = Field(description="版本ID")
     version: int = Field(description="版本号")
     status: DatasetVersionStatus = Field(description="版本状态")
@@ -53,6 +58,7 @@ class DatasetVersionBase(BaseModel):
 
 class DatasetDetail(BaseModel):
     """数据集详情"""
+
     id: int = Field(description="数据集 ID")
     name: str = Field(description="名称")
     description: str = Field(description="描述")
@@ -69,6 +75,7 @@ class DatasetDetail(BaseModel):
 
 class ExtInfo(BaseModel):
     """扩展信息"""
+
     rec_file_path: Optional[str] = Field(None, alias="rec_file_path", description="rec文件路径")
     idx_file_path: Optional[str] = Field(None, alias="idx_file_path", description="idx文件路径")
     json_file_path: Optional[str] = Field(None, alias="json_file_path", description="json文件路径")
@@ -77,6 +84,7 @@ class ExtInfo(BaseModel):
 
 class CreateDatasetVersionRequest(BaseModel):
     """创建版本请求"""
+
     upload_path: str = Field(alias="upload_path", description="上传路径")
     description: Optional[str] = Field(None, description="版本描述")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -91,11 +99,13 @@ class CreateDatasetVersionRequest(BaseModel):
 
 class CreateDatasetVersionResponse(BaseModel):
     """创建版本返回"""
+
     id: int = Field(alias="id", description="版本ID")
 
 
 class UploadDatasetVersionRequest(BaseModel):
     """上传数据集版本请求"""
+
     upload_path: str = Field(alias="upload_path", description="上传目录")
     upload_type: UploadType = Field(alias="upload_type", description="上传类型")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -107,11 +117,13 @@ class UploadDatasetVersionRequest(BaseModel):
 
 class UploadDatasetVersionResponse(BaseModel):
     """上传数据集版本返回"""
+
     id: int = Field(alias="id", description="版本ID")
 
 
 class DatasetVersionDetail(BaseModel):
     """数据集版本详情"""
+
     id: int = Field(description="版本ID")
     version: int = Field(description="版本号")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -133,5 +145,84 @@ class DatasetVersionDetail(BaseModel):
 
 class FileUploadData(BaseModel):
     """文件上传数据"""
+
     path: str = Field(description="路径")
     url: str = Field(description="URL")
+
+
+class ListDatasetReq(BaseModel):
+    """列表查询数据集请求(使用 dataset_management v2)"""
+    page_size: int = Field(20, alias="page_size", description="每页大小,默认20")
+    page_num: int = Field(1, alias="page_num", description="页码,从1开始")
+    name: Optional[str] = Field(None, description="数据集名称筛选")
+    tags: Optional[str] = Field(None, description="标签筛选")
+    create_by: Optional[int] = Field(None, alias="create_by", description="创建人筛选")
+    scope: Optional[str] = Field("all", description="范围筛选:created|shared|all")
+
+
+class ListDatasetItem(BaseModel):
+    """列表数据集项"""
+    id: int = Field(description="数据集ID")
+    name: str = Field(description="数据集名称")
+    description: str = Field(description="数据集描述")
+    cover_img: str = Field(alias="cover_img", description="封面图片")
+    created_at: int = Field(alias="created_at", description="创建时间戳")
+    updated_at: int = Field(alias="update_at", description="更新时间戳")
+    user_id: int = Field(alias="user_id", description="创建人ID")
+    username: str = Field(description="创建人用户名")
+    tags: Optional[List[int]] = Field(None, description="标签列表")
+    access_user_ids: Optional[List[int]] = Field(None, alias="access_user_ids", description="有访问权限的用户ID列表")
+    is_private: bool = Field(alias="is_private", description="是否私有")
+
+
+class ListDatasetResp(BaseModel):
+    """列表查询数据集响应"""
+    total: int = Field(description="总数")
+    page_size: int = Field(alias="page_size", description="每页大小")
+    page_num: int = Field(alias="page_num", description="当前页码")
+    data: List[ListDatasetItem] = Field(description="数据集列表")
+
+
+class ListDatasetVersionReq(BaseModel):
+    """列表查询数据集版本请求(使用 dataset_management v2)"""
+    page_size: int = Field(10000000, alias="page_size", description="每页大小,默认10000000")
+    page_num: int = Field(1, alias="page_num", description="页码,从1开始")
+    dataset_id: Optional[int] = Field(None, alias="dataset_id", description="数据集ID筛选")
+    dataset_version_ids: Optional[str] = Field(None, alias="dataset_version_ids", description="数据集版本ID列表,逗号分隔")
+
+
+class ListDatasetVersionItem(BaseModel):
+    """列表数据集版本项"""
+    id: int = Field(description="版本ID")
+    version: int = Field(description="版本号")
+    dataset_id: int = Field(alias="dataset_id", description="数据集ID")
+    upload_path: str = Field(alias="upload_path", description="上传路径")
+    upload_type: int = Field(alias="upload_type", description="上传类型")
+    parent_version_id: Optional[int] = Field(None, alias="parent_version_id", description="父版本ID")
+    description: Optional[str] = Field(None, description="版本描述")
+    status: int = Field(description="版本状态")
+    message: str = Field(description="状态信息")
+    created_at: int = Field(alias="created_at", description="创建时间戳")
+    user_id: int = Field(alias="user_id", description="创建人ID")
+    data_size: int = Field(alias="data_size", description="数据大小")
+    data_count: int = Field(alias="data_count", description="数据条数")
+    username: str = Field(description="创建人用户名")
+    dataset_name: str = Field(alias="dataset_name", description="数据集名称")
+
+
+class ListDatasetVersionResp(BaseModel):
+    """列表查询数据集版本响应"""
+    total: int = Field(description="总数")
+    page_size: int = Field(alias="page_size", description="每页大小")
+    page_num: int = Field(alias="page_num", description="当前页码")
+    data: List[ListDatasetVersionItem] = Field(description="数据集版本列表")
+
+
+class CreateDatasetVersionByDataIngestReqV2(BaseModel):
+    """通过数据集成创建数据集版本请求"""
+
+    description: Optional[str] = Field(None, description="描述")
+    dataset_id: int = Field(..., description="数据集ID")
+    s3_object_sheet: str = Field(..., description="S3对象表")
+    object_cnt: Optional[int] = Field(None, description="对象数量")
+    data_size: Optional[int] = Field(None, description="数据大小")
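The new list/query models above are plain Pydantic v2 request objects; the service layer serializes them with `model_dump(by_alias=True, exclude_none=True)` before issuing the GET request (see `_Dataset.list_datasets` further down in this diff). A minimal sketch of that serialization, with made-up filter values:

from aihub.models.dataset_management import ListDatasetReq

# Example-only filter values; unset optional fields are dropped from the query string.
req = ListDatasetReq(page_size=20, page_num=2, name="face")
params = req.model_dump(by_alias=True, exclude_none=True)
# params -> {"page_size": 20, "page_num": 2, "name": "face", "scope": "all"}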
aihub/services/artifact.py
CHANGED
@@ -98,9 +98,7 @@ class ArtifactService:
         """
         return self._artifact.get_sts()
 
-    def get_by_run_id(
-        self, run_id: str, artifact_path: Optional[str] = None
-    ) -> List[ArtifactResp]:
+    def get_by_run_id(self, run_id: str, artifact_path: Optional[str] = None) -> List[ArtifactResp]:
         """根据运行ID获取制品列表
 
         Args:
@@ -116,11 +114,11 @@ class ArtifactService:
         return self._artifact.get_by_run_id(run_id, artifact_path)
 
     def create_artifact(
-
-
-
-
-
+        self,
+        local_path: str,
+        artifact_path: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifact_type: ArtifactType = ArtifactType.other,
     ) -> None:
         """创建单个文件制品并上传
 
@@ -171,11 +169,11 @@ class ArtifactService:
         return
 
     def create_artifacts(
-
-
-
-
-
+        self,
+        local_dir: str,
+        artifact_path: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifact_type: ArtifactType = ArtifactType.other,
     ) -> None:
         """创建目录制品并上传
 
@@ -223,9 +221,7 @@ class ArtifactService:
         logger.info(f"log artifact done: {artifact_path}")
         return
 
-    def download_artifacts(
-        self, run_id: str, artifact_path: Optional[str], local_dir: str
-    ) -> None:
+    def download_artifacts(self, run_id: str, artifact_path: Optional[str], local_dir: str) -> None:
         """下载制品
 
         Args:
@@ -252,9 +248,7 @@ class ArtifactService:
         if artifact_item.is_dir:
             download_dir_from_s3(self.s3_client, bucket, object_name, local_dir)
         else:
-            self.s3_client.fget_object(
-                bucket, object_name, str(Path(local_dir) / artifact_item.src_path)
-            )
+            self.s3_client.fget_object(bucket, object_name, str(Path(local_dir) / artifact_item.src_path))
 
         logger.info(f"download artifact done: {artifact_path}")
         return
@@ -311,9 +305,7 @@ class _Artifact:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return
 
-    def get_by_run_id(
-        self, run_id: str, artifact_path: Optional[str]
-    ) -> List[ArtifactResp]:
+    def get_by_run_id(self, run_id: str, artifact_path: Optional[str]) -> List[ArtifactResp]:
         """根据运行ID获取制品列表
 
         Args:
@@ -326,18 +318,12 @@ class _Artifact:
         Raises:
             APIError: 当API调用失败时抛出
         """
-        resp = self._http.get(
-            f"{_Base}/artifacts?entity_id={run_id}&page_num=1&page_size={InfinityPageSize}"
-        )
+        resp = self._http.get(f"{_Base}/artifacts?entity_id={run_id}&page_num=1&page_size={InfinityPageSize}")
         wrapper = APIWrapper[ArtifactRespData].model_validate(resp.json())
         if wrapper.code != 0:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         if artifact_path:
-            return [
-                artifact
-                for artifact in wrapper.data.data
-                if artifact.src_path == artifact_path
-            ]
+            return [artifact for artifact in wrapper.data.data if artifact.src_path == artifact_path]
         else:
             return wrapper.data.data
 
aihub/services/dataset_management.py
CHANGED
@@ -7,6 +7,8 @@
 - **创建数据集及其版本**(支持本地上传和服务器现有文件两种方式)
 - **上传文件到对象存储**(大文件自动分片)
 - **查询数据集/数据集版本详情**
+- **列表查询和搜索数据集**(支持分页和筛选)
+- **列表查询数据集版本**(支持按数据集ID筛选和分页)
 - **按版本名称或ID下载数据集文件**
 """
 
@@ -15,7 +17,6 @@ from __future__ import annotations
 import mimetypes
 import os
 import pathlib
-import tempfile
 import time
 import uuid
 
@@ -23,6 +24,7 @@ import httpx
 from loguru import logger
 
 from ..exceptions import APIError
+from ..models.artifact import StsResp
 from ..models.common import APIWrapper
 from ..models.dataset_management import (
     CreateDatasetRequest,
@@ -34,9 +36,16 @@ from ..models.dataset_management import (
     DatasetVersionDetail,
     UploadDatasetVersionResponse,
     FileUploadData,
+    ListDatasetReq,
+    ListDatasetResp,
+    ListDatasetVersionReq,
+    ListDatasetVersionResp,
+    CreateDatasetVersionByDataIngestReqV2,
+    UploadType,
 )
 from ..models.dataset_management import DatasetVersionStatus
-from ..utils.
+from ..utils.di import SimpleS3Client, DataUploader
+from ..utils.download import dataset_download
 
 _BASE = "/dataset-mng/api/v2"
 
@@ -132,20 +141,29 @@ class DatasetManagementService:
     def dataset(self) -> _Dataset:
         return self._dataset
 
+    def _get_sts(self) -> StsResp:
+        return self.dataset_version.get_sts()
+
     @property
     def dataset_version(self) -> _DatasetVersion:
         return self._dataset_version
 
+    def upload_by_data_ingest(
+        self,
+        req: CreateDatasetVersionByDataIngestReqV2,
+    ) -> CreateDatasetVersionResponse:
+        return self.dataset_version.upload_by_data_ingest(req)
+
     def create_dataset_and_version(
-
-
-
-
-
-
-
-
-
+        self,
+        *,
+        dataset_name: str,
+        dataset_description: str = "",
+        is_local_upload: bool,
+        local_file_path: str | None = None,
+        server_file_path: str | None = None,
+        version_description: str = "",
+        timeout: int = 1_800,
    ) -> tuple[int, int, str]:
         """创建数据集及其版本,并等待版本状态变为 *Success*。
 
@@ -163,17 +181,51 @@
 
         Returns:
             tuple[int, int, str]: 一个三元组,包含:[数据集 ID,数据集版本 ID, 数据集版本标签(格式为 <dataset_name>/V<version_number>)]
+
+        Raises:
+            ValueError: 当参数不满足要求时
+            APIError: 当后端返回错误时
+            TimeoutError: 当等待超时时
         """
+        # 参数校验
+        self._validate_create_params(is_local_upload, local_file_path, server_file_path)
+
+        # 创建数据集
+        dataset_id = self._create_dataset(dataset_name, dataset_description)
+        logger.info(f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据")
+
+        # 创建数据集版本
+        version_id = self._create_dataset_version(
+            dataset_id=dataset_id,
+            is_local_upload=is_local_upload,
+            local_file_path=local_file_path,
+            server_file_path=server_file_path,
+            version_description=version_description,
+        )
+
+        # 获取版本标签
+        version_tag = self._get_version_tag(dataset_id, version_id)
+        logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
+
+        # 轮询等待版本状态变为成功
+        self._wait_for_version_success(version_id, timeout)
+
+        return dataset_id, version_id, version_tag
+
+    def _validate_create_params(
+        self, is_local_upload: bool, local_file_path: str | None, server_file_path: str | None
+    ) -> None:
+        """验证创建数据集和版本所需的参数"""
         if is_local_upload:
             if not local_file_path:
                 raise ValueError("is_local_upload=True 时必须提供 local_file_path")
-            upload_type = 1
         else:
             if not server_file_path:
                 raise ValueError("is_local_upload=False 时必须提供 server_file_path")
-            upload_type = 3
 
-
+    def _create_dataset(self, dataset_name: str, dataset_description: str) -> int:
+        """创建数据集"""
+        return self._dataset.create(
             CreateDatasetRequest(
                 name=dataset_name,
                 description=dataset_description,
@@ -184,39 +236,96 @@ class DatasetManagementService:
                 access_user_ids=None,
             )
         )
-        logger.info(
-            f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据"
-        )
 
+    def _create_dataset_version(
+        self,
+        dataset_id: int,
+        is_local_upload: bool,
+        local_file_path: str | None,
+        server_file_path: str | None,
+        version_description: str,
+    ) -> int:
+        """根据上传类型创建数据集版本"""
         if is_local_upload:
-
-            local_file_path = pathlib.Path(local_file_path)
-            if local_file_path.is_dir():
-                # 把文件夹打包为一个 zip
-                temp_zip_path = (
-                    pathlib.Path(tempfile.mkdtemp()) / f" {uuid.uuid4().hex}.zip"
-                )
-                zip_dir(local_file_path, temp_zip_path)
-                upload_data = self._upload.upload_file(temp_zip_path)
-                os.remove(temp_zip_path)
-            else:
-                upload_data = self._upload.upload_file(local_file_path)
-
-            upload_path = upload_data.path
+            return self._create_local_dataset_version(dataset_id, local_file_path, version_description)
         else:
-
-
+            return self._create_server_dataset_version(dataset_id, server_file_path, version_description)
+
+    def _create_local_dataset_version(
+        self, dataset_id: int, local_file_path: str | None, version_description: str
+    ) -> int:
+        """创建本地文件数据集版本"""
+        if pathlib.Path(local_file_path).is_dir():
+            return self._create_local_dir_dataset_version(dataset_id, local_file_path)
+        elif pathlib.Path(local_file_path).is_file():
+            return self._create_local_file_dataset_version(dataset_id, local_file_path, version_description)
+        else:
+            raise ValueError(f"本地路径既不是文件也不是目录: {local_file_path}")
 
-
+    def _create_local_dir_dataset_version(self, dataset_id: int, local_file_path: str) -> int:
+        """处理本地目录上传"""
+        sts = self._get_sts()
+        s3_client = SimpleS3Client(
+            sts.endpoint, sts.access_key_id, sts.secret_access_key, session_token=sts.session_token
+        )
+        uid = uuid.uuid4().hex
+        s3_target = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}"
+        s3_csv_path = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}.csv"
+        s3_status_path = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}.json"
+
+        # 创建上传器并执行
+        uploader = DataUploader(
+            task_id=dataset_id,
+            local_path=str(local_file_path),
+            s3_target=s3_target,
+            csv_path=s3_csv_path,
+            status_path=s3_status_path,
+            num_workers=40,
+        )
+
+        upload_stats = uploader.run(s3_client)
+        req = CreateDatasetVersionByDataIngestReqV2(
+            description=f"sdk 上传",
+            dataset_id=dataset_id,
+            s3_object_sheet=s3_csv_path,
+            object_cnt=upload_stats.uploaded_count,
+            data_size=upload_stats.uploaded_size,
+        )
+        return self.upload_by_data_ingest(req).id
+
+    def _create_local_file_dataset_version(
+        self, dataset_id: int, local_file_path: str, version_description: str
+    ) -> int:
+        """处理本地文件上传"""
+        upload_data = self._upload.upload_file(local_file_path)
+        upload_path = upload_data.path
+        logger.info(f"文件上传成功:{local_file_path}")
+        return self._dataset_version.upload(
             UploadDatasetVersionRequest(
                 upload_path=upload_path,
-                upload_type=
+                upload_type=UploadType.LOCAL,  # 本地上传类型
                 dataset_id=dataset_id,
                 description=version_description,
                 parent_version_id=0,
             )
         )
 
+    def _create_server_dataset_version(
+        self, dataset_id: int, server_file_path: str | None, version_description: str
+    ) -> int:
+        """创建服务器文件数据集版本"""
+        return self._dataset_version.upload(
+            UploadDatasetVersionRequest(
+                upload_path=server_file_path,
+                upload_type=UploadType.SERVER_PATH,  # 服务器文件上传类型
+                dataset_id=dataset_id,
+                description=version_description,
+                parent_version_id=0,
+            )
+        )
+
+    def _get_version_tag(self, dataset_id: int, version_id: int) -> str:
+        """获取版本标签"""
         detail = self._dataset.get(dataset_id)
         ver_num = next(
             (v.version for v in detail.versions if v.id == version_id),
@@ -225,9 +334,10 @@ class DatasetManagementService:
         if ver_num is None:
             ver_num = 1
 
-
-        logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
+        return f"{detail.name}/V{ver_num}"
 
+    def _wait_for_version_success(self, version_id: int, timeout: int) -> None:
+        """轮询等待版本状态变为成功"""
         start_ts = time.time()
         poll_interval = 10
 
@@ -249,8 +359,6 @@ class DatasetManagementService:
             logger.debug(f"已等待 {elapsed:.0f}s,继续轮询…")
             time.sleep(poll_interval)
 
-        return dataset_id, version_id, version_tag
-
     def run_download(self, dataset_version_name: str, local_dir: str, worker: int = 4) -> None:
         """根据数据集版本名称下载对应的数据集文件。
 
@@ -270,6 +378,66 @@ class DatasetManagementService:
             raise APIError("parquet_index_path 为空")
         dataset_download(detail.parquet_index_path, local_dir, worker)
 
+    def list_datasets(
+        self,
+        *,
+        page_size: int = 20,
+        page_num: int = 1,
+        name: str | None = None,
+        tags: str | None = None,
+        create_by: int | None = None,
+        scope: str = "all"
+    ) -> ListDatasetResp:
+        """列表查询数据集
+
+        Args:
+            page_size: 每页大小,默认20
+            page_num: 页码,从1开始,默认1
+            name: 数据集名称筛选,可选
+            tags: 标签筛选,可选
+            create_by: 创建人筛选,可选
+            scope: 范围筛选:created|shared|all,默认all
+
+        Returns:
+            ListDatasetResp: 数据集列表响应,包含分页信息和数据集列表
+        """
+        payload = ListDatasetReq(
+            page_size=page_size,
+            page_num=page_num,
+            name=name,
+            tags=tags,
+            create_by=create_by,
+            scope=scope
+        )
+        return self._dataset.list_datasets(payload)
+
+    def list_dataset_versions(
+        self,
+        *,
+        page_size: int = 10000000,
+        page_num: int = 1,
+        dataset_id: int | None = None,
+        dataset_version_ids: str | None = None
+    ) -> ListDatasetVersionResp:
+        """列表查询数据集版本
+
+        Args:
+            page_size: 每页大小,默认10000000
+            page_num: 页码,从1开始,默认1
+            dataset_id: 数据集ID筛选,可选
+            dataset_version_ids: 数据集版本ID列表,逗号分隔,可选
+
+        Returns:
+            ListDatasetVersionResp: 数据集版本列表响应,包含分页信息和数据集版本列表
+        """
+        payload = ListDatasetVersionReq(
+            page_size=page_size,
+            page_num=page_num,
+            dataset_id=dataset_id,
+            dataset_version_ids=dataset_version_ids
+        )
+        return self._dataset_version.list_dataset_versions(payload)
+
 
 class _Dataset:
     def __init__(self, http: httpx.Client):
@@ -292,6 +460,15 @@ class _Dataset:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return wrapper.data
 
+    def list_datasets(self, payload: ListDatasetReq) -> ListDatasetResp:
+        """列表查询数据集"""
+        params = payload.model_dump(by_alias=True, exclude_none=True)
+        resp = self._http.get(f"{_BASE}/datasets", params=params)
+        wrapper = APIWrapper[ListDatasetResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
 
 class _DatasetVersion:
     def __init__(self, http: httpx.Client):
@@ -325,14 +502,55 @@ class _DatasetVersion:
         return wrapper.data
 
     def get_by_name(self, version_name: str) -> DatasetVersionDetail:
-        resp = self._http.get(
-            f"{_BASE}/dataset-versions-detail", params={"name": version_name}
-        )
+        resp = self._http.get(f"{_BASE}/dataset-versions-detail", params={"name": version_name})
         wrapper = APIWrapper[DatasetVersionDetail].model_validate(resp.json())
         if wrapper.code != 0:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return wrapper.data
 
+    def list_dataset_versions(self, payload: ListDatasetVersionReq) -> ListDatasetVersionResp:
+        """列表查询数据集版本"""
+        params = payload.model_dump(by_alias=True, exclude_none=True)
+        resp = self._http.get(f"{_BASE}/dataset-versions", params=params)
+        wrapper = APIWrapper[ListDatasetVersionResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
+    def get_sts(self) -> StsResp:
+        """获取STS临时凭证
+
+        获取用于访问S3存储的临时凭证。
+
+        Returns:
+            StsResp: STS临时凭证信息
+
+        Raises:
+            APIError: 当API调用失败时抛出
+        """
+        resp = self._http.get(f"{_BASE}/dataset-versions/get-sts")
+        logger.info(f"get sts: {resp.text}")
+        wrapper = APIWrapper[StsResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
+    def upload_by_data_ingest(self, req: CreateDatasetVersionByDataIngestReqV2) -> CreateDatasetVersionResponse:
+        """上传数据集版本(数据集导入)
+        Args:
+            req
+
+        """
+        resp = self._http.post(
+            f"{_BASE}/dataset-versions/data-ingest",
+            json=req.model_dump(),
+        )
+        logger.debug(f"upload_by_data_ingest: {resp.text}")
+        wrapper = APIWrapper[CreateDatasetVersionResponse].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
 
 class _Upload:
     def __init__(self, http: httpx.Client):
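Taken together, the service changes add list/search entry points alongside the existing create/download flows. A sketch of how the new keyword-only methods might be called; here `svc` stands for an already-constructed DatasetManagementService (how the SDK client exposes it is not shown in this diff), and the filter values are examples only:

# svc: DatasetManagementService, assumed to exist; construction is outside this diff.
page = svc.list_datasets(page_size=20, page_num=1, name="face")
print(page.total, [d.name for d in page.data])

versions = svc.list_dataset_versions(dataset_id=page.data[0].id)
for v in versions.data:
    print(v.id, v.version, v.status, v.data_count)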
aihub/utils/di.py
ADDED
@@ -0,0 +1,337 @@
+import argparse
+import csv
+import hashlib
+import json
+import os
+import queue
+import sys
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, Any, Tuple
+
+import minio
+from loguru import logger
+
+
+class UploadStatus:
+    """上传状态类"""
+
+    def __init__(self):
+        self.uploaded_count = 0
+        self.uploaded_size = 0
+        self.updated_at = int(time.time() * 1000)
+
+    def update(self, count: int, size: int):
+        self.uploaded_count += count
+        self.uploaded_size += size
+        self.updated_at = int(time.time() * 1000)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "uploaded_count": self.uploaded_count,
+            "uploaded_size": self.uploaded_size,
+            "updated_at": self.updated_at,
+        }
+
+
+class SimpleS3Client:
+    """简化的S3客户端"""
+
+    def __init__(self, endpoint: str, access_key: str, secret_key: str, session_token: str):
+        self.client = minio.Minio(
+            endpoint, access_key=access_key, secret_key=secret_key, secure=False, session_token=session_token
+        )
+
+    def upload_file(self, local_path: str, bucket: str, object_name: str) -> Tuple[str, int]:
+        """上传文件并返回哈希和大小"""
+        file_size = os.path.getsize(local_path)
+
+        # 计算文件哈希
+        sha256_hash = hashlib.sha256()
+        with open(local_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256_hash.update(chunk)
+
+        file_hash = sha256_hash.hexdigest()
+
+        # 上传文件
+        with open(local_path, "rb") as f:
+            self.client.put_object(bucket, object_name, f, file_size)
+
+        return file_hash, file_size
+
+    def upload_json(self, data: Dict[str, Any], bucket: str, object_name: str):
+        """上传JSON数据"""
+        json_str = json.dumps(data)
+        json_bytes = json_str.encode("utf-8")
+
+        from io import BytesIO
+
+        self.client.put_object(
+            bucket, object_name, BytesIO(json_bytes), len(json_bytes), content_type="application/json"
+        )
+
+
+class DataUploader:
+    """数据上传器"""
+
+    def __init__(
+        self, task_id: int, local_path: str, s3_target: str, csv_path: str, status_path: str, num_workers: int = 10
+    ):
+        self.task_id = task_id
+        self.local_path = local_path
+        self.num_workers = num_workers
+
+        # 解析S3路径
+        self.target_bucket, self.target_prefix = self._parse_s3_path(s3_target)
+        self.csv_bucket, self.csv_key = self._parse_s3_path(csv_path)
+        self.status_bucket, self.status_key = self._parse_s3_path(status_path)
+
+        # 创建工作目录
+        self.work_dir = Path.home() / ".di_workspace" / str(task_id)
+        self.work_dir.mkdir(parents=True, exist_ok=True)
+        self.csv_file = self.work_dir / "upload_records.csv"
+
+        # CSV记录队列
+        self.csv_queue = queue.Queue()
+        self.processed_files = set()
+        self.total_files = 0
+
+    def _parse_s3_path(self, s3_path: str) -> Tuple[str, str]:
+        """解析S3路径"""
+        if s3_path.startswith("s3://"):
+            parts = s3_path[5:].split("/", 1)
+            bucket = parts[0]
+            key = parts[1] if len(parts) > 1 else ""
+            return bucket, key
+        return "", ""
+
+    def _collect_files(self) -> list:
+        """收集需要上传的文件"""
+        files = []
+
+        if os.path.isfile(self.local_path):
+            files.append(self.local_path)
+            self.total_files += 1
+        else:
+            for root, _, filenames in os.walk(self.local_path):
+                for filename in filenames:
+                    file_path = os.path.join(root, filename)
+                    if not os.path.islink(file_path):  # 跳过符号链接
+                        files.append(file_path)
+                        self.total_files += 1
+
+        # 过滤已处理的文件
+        base_path = os.path.dirname(self.local_path) if os.path.isfile(self.local_path) else self.local_path
+        unprocessed_files = []
+
+        for file_path in files:
+            rel_path = os.path.relpath(file_path, base_path)
+            if rel_path not in self.processed_files:
+                unprocessed_files.append(file_path)
+
+        return unprocessed_files
+
+    def _csv_writer_worker(self):
+        """CSV写入工作器"""
+        # 初始化CSV文件
+        uploaded_count = 0
+        file_exists = os.path.exists(self.csv_file)
+
+        with open(self.csv_file, "a", newline="", encoding="utf-8") as f:
+            writer = csv.writer(f)
+            if not file_exists:
+                writer.writerow(["local_path", "sha256", "s3path", "file_size"])
+
+            while True:
+                try:
+                    record = self.csv_queue.get(timeout=1)
+                    if record is None:  # 结束信号
+                        break
+
+                    writer.writerow(
+                        [record["local_path"], record["file_hash"], record["s3_path"], str(record["file_size"])]
+                    )
+
+                    f.flush()  # 确保数据写入磁盘
+                    self.csv_queue.task_done()
+                    uploaded_count += 1
+                    # 每上传100个文件,打印进度
+                    if uploaded_count % 1000 == 0:
+                        logger.info(f"已上传 {uploaded_count} 个文件")
+
+                except queue.Empty:
+                    continue
+                except Exception as e:
+                    logger.error(f"Failed to write CSV record: {e}")
+                    self.csv_queue.task_done()
+
+    def _upload_worker(self, s3_client: SimpleS3Client, file_queue: queue.Queue, base_path: str):
+        """上传工作器"""
+        while True:
+            try:
+                file_path = file_queue.get(timeout=1)
+                if file_path is None:  # 结束信号
+                    break
+
+                try:
+                    # 计算相对路径和S3对象名
+                    rel_path = os.path.relpath(file_path, base_path)
+
+                    object_name = os.path.join(self.target_prefix, rel_path).replace("\\", "/")
+
+                    # 上传文件
+                    file_hash, file_size = s3_client.upload_file(file_path, self.target_bucket, object_name)
+
+                    # 将记录放入CSV队列
+                    s3_path = f"s3://{self.target_bucket}/{object_name}"
+                    record = {
+                        "local_path": os.path.join("/", rel_path),
+                        "file_hash": file_hash,
+                        "s3_path": s3_path,
+                        "file_size": file_size,
+                    }
+                    self.csv_queue.put(record)
+
+                    logger.debug(f"Uploaded: {rel_path}")
+
+                except Exception as e:
+                    logger.error(f"Failed to upload {file_path}: {e}")
+                finally:
+                    file_queue.task_done()
+
+            except queue.Empty:
+                break
+
+    def _calculate_final_stats(self) -> UploadStatus:
+        """从CSV文件计算最终统计信息"""
+        stats = UploadStatus()
+        if not os.path.exists(self.csv_file):
+            return stats
+
+        total_count = 0
+        total_size = 0
+
+        try:
+            with open(self.csv_file, "r", encoding="utf-8") as f:
+                reader = csv.DictReader(f)
+                for row in reader:
+                    total_count += 1
+                    total_size += int(row["file_size"])
+        except Exception as e:
+            logger.error(f"Failed to calculate stats: {e}")
+
+        stats.update(total_count, total_size)
+
+        return stats
+
+    def run(self, s3_client: SimpleS3Client) -> UploadStatus:
+        """执行上传任务"""
+        # 收集文件
+        files = self._collect_files()
+        if not files:
+            logger.info("No files to upload")
+            return UploadStatus()
+
+        logger.info(f"Found {len(files)} files to upload")
+
+        # 准备文件队列
+        file_queue = queue.Queue()
+        for file_path in files:
+            file_queue.put(file_path)
+
+        base_path = os.path.dirname(self.local_path) if os.path.isfile(self.local_path) else self.local_path
+
+        # 启动CSV写入线程
+        csv_thread = threading.Thread(target=self._csv_writer_worker)
+        csv_thread.daemon = True
+        csv_thread.start()
+
+        try:
+            # 启动上传工作器
+            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+                futures = []
+                for i in range(self.num_workers):
+                    future = executor.submit(self._upload_worker, s3_client, file_queue, base_path)
+                    futures.append(future)
+
+                # 等待所有任务完成
+                for future in as_completed(futures):
+                    future.result()
+
+            # 等待CSV队列处理完成
+            self.csv_queue.join()
+
+            # 发送结束信号给CSV写入线程
+            self.csv_queue.put(None)
+            csv_thread.join()
+
+            # 上传记录文件到S3
+            if os.path.exists(self.csv_file):
+                s3_client.upload_file(str(self.csv_file), self.csv_bucket, self.csv_key)
+                logger.info("Upload records saved to S3")
+
+            # 计算并上传最终统计信息
+            stats = self._calculate_final_stats()
+            s3_client.upload_json(stats.to_dict(), self.status_bucket, self.status_key)
+            logger.info(f"Upload completed: {stats.uploaded_count} files, {stats.uploaded_size} bytes")
+
+        finally:
+            # 清理工作目录
+            try:
+                import shutil
+
+                shutil.rmtree(self.work_dir)
+            except Exception as e:
+                logger.warning(f"Failed to cleanup workspace: {e}")
+        return stats
+
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="简化的数据摄取工具")
+    parser.add_argument("-e", "--endpoint", default="192.168.13.160:9008", help="S3端点")
+    parser.add_argument("-ak", "--access-key", default="admin2024", help="访问密钥")
+    parser.add_argument("-sk", "--secret-key", default="root@23452024", help="秘密密钥")
+    parser.add_argument("-t", "--target", default="s3://testbucket/test_ok11", help="目标S3路径")
+    parser.add_argument("-l", "--local", default="./test_data", help="本地路径")
+    parser.add_argument("-o", "--object-sheet", default="s3://testbucket/records/123.csv", help="记录文件S3路径")
+    parser.add_argument("-s", "--status", default="s3://testbucket/status/123.json", help="状态文件S3路径")
+    parser.add_argument("-i", "--task-id", type=int, default=123, help="任务ID")
+    parser.add_argument("-n", "--num-workers", type=int, default=10, help="工作线程数")
+
+    args = parser.parse_args()
+
+    # 检查本地路径
+    if not os.path.exists(args.local):
+        logger.error(f"Local path does not exist: {args.local}")
+        sys.exit(1)
+
+    logger.info(f"Starting upload: {args.local} -> {args.target}")
+
+    try:
+        # 创建S3客户端
+        s3_client = SimpleS3Client(args.endpoint, args.access_key, args.secret_key)
+
+        # 创建上传器并执行
+        uploader = DataUploader(
+            task_id=args.task_id,
+            local_path=args.local,
+            s3_target=args.target,
+            csv_path=args.object_sheet,
+            status_path=args.status,
+            num_workers=args.num_workers,
+        )
+
+        uploader.run(s3_client)
+        logger.info("Upload completed successfully")
+
+    except Exception as e:
+        logger.error(f"Upload failed: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
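aihub/utils/di.py is the multi-threaded ingest helper that `_create_local_dir_dataset_version` drives: it walks a local directory, uploads each file via MinIO, writes a per-file manifest CSV (local_path, sha256, s3path, file_size) plus a final status JSON to S3, and returns the totals. A standalone sketch of the same flow; the endpoint, credentials, bucket and paths below are placeholders, not values from the release:

from aihub.utils.di import SimpleS3Client, DataUploader

# Placeholder STS-style credentials; in the SDK these come from the get-sts endpoint.
client = SimpleS3Client("minio.example.local:9000", "ACCESS_KEY", "SECRET_KEY", session_token="TOKEN")

uploader = DataUploader(
    task_id=42,                                                # the dataset id doubles as the task id
    local_path="/data/my_dataset",                             # directory walked by _collect_files()
    s3_target="s3://bucket/dataset_workspace/42/abc",          # objects are uploaded under this prefix
    csv_path="s3://bucket/dataset_workspace/42/abc.csv",       # per-file manifest written here
    status_path="s3://bucket/dataset_workspace/42/abc.json",   # final UploadStatus JSON written here
    num_workers=8,
)
stats = uploader.run(client)
print(stats.uploaded_count, stats.uploaded_size)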
aihub/utils/download.py
CHANGED
@@ -1,15 +1,14 @@
 from __future__ import annotations
 
-import concurrent.futures
 import os
 import tempfile
 import zipfile
 from typing import List, TypedDict
 
 import pyarrow.parquet as pq
-from tqdm import tqdm
+from tqdm.contrib.concurrent import thread_map
 
-from .http import http_download_file
+from .http import http_download_file, http_download_file_wrapper
 from .s3 import s3_to_url
 
 
@@ -59,18 +58,7 @@ def dataset_download(index_url: str, local_dir: str, worker: int = 4) -> None:
     if worker < 1:
         worker = 1
 
-    with (
-        tqdm(total=len(files), desc="Downloading dataset") as bar,
-        concurrent.futures.ThreadPoolExecutor(max_workers=worker) as pool,
-    ):
-
-        def _one(flocal: str, furl: str):
-            http_download_file(furl, flocal)
-            bar.update()
-
-        futures = [pool.submit(_one, p, u) for p, u in files]
-        for fut in concurrent.futures.as_completed(futures):
-            fut.result()
+    thread_map(http_download_file_wrapper, files, max_workers=worker)
 
 
 def zip_dir(dir_path: str, zip_path: str):
aihub/utils/http.py
CHANGED
@@ -5,6 +5,12 @@ import os
 import requests
 
 
+def http_download_file_wrapper(item):
+    """Wrapper function"""
+    dst_path, url = item
+    return http_download_file(url, dst_path)
+
+
 def http_download_file(url: str, dst_path: str, chunk: int = 1 << 16) -> None:
     os.makedirs(os.path.dirname(dst_path), exist_ok=True)
     with requests.get(url, timeout=None, stream=True) as r:
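The wrapper exists so that dataset_download can hand its (local_path, url) pairs straight to tqdm's thread_map, which replaces the hand-rolled ThreadPoolExecutor and progress-bar loop removed from download.py above. A tiny sketch with example URLs:

from tqdm.contrib.concurrent import thread_map
from aihub.utils.http import http_download_file_wrapper

# Each item is (destination path, source URL), matching the tuples built in dataset_download().
files = [
    ("/tmp/ds/img_0001.jpg", "https://example.com/presigned/img_0001.jpg"),
    ("/tmp/ds/img_0002.jpg", "https://example.com/presigned/img_0002.jpg"),
]
thread_map(http_download_file_wrapper, files, max_workers=4)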
{intellif_aihub-0.1.13.dist-info → intellif_aihub-0.1.15.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
-aihub/__init__.py,sha256=
+aihub/__init__.py,sha256=qb0TalpSt1CbprnFyeLUKqgrqNtmnk9IoQQ7umAoXVY,23
 aihub/client.py,sha256=nVELjkyVOG6DKJjurYn59fCoT5JsSayUweiH7bvKcAo,5547
 aihub/exceptions.py,sha256=l2cMAvipTqQOio3o11fXsCCSCevbuK4PTsxofkobFjk,500
 aihub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/models/artifact.py,sha256=
+aihub/models/artifact.py,sha256=F-r7DJY9A09yIQJqWol6gLRu6y7NGjRa6-BxkMEluxU,4655
 aihub/models/common.py,sha256=qmabc2LkAdQJXIcpT1P35zxd0Lc8yDYdD4ame1iF4Bs,241
 aihub/models/data_warehouse.py,sha256=zXvWwg7ySoFJMdqQ_1UMTNEKDMhu1hDHlWdBAXdizBk,3905
-aihub/models/dataset_management.py,sha256=
+aihub/models/dataset_management.py,sha256=4DuQ0zM7jv73SJiqvieHLtn2Y-T6FIFV9r7bgzyCtDo,10790
 aihub/models/document_center.py,sha256=od9bzx6krAS6ktIA-ChxeqGcch0v2wsS1flY2vuHXBc,1340
 aihub/models/eval.py,sha256=4Gon4Sg4dOkyCx3KH2mO5ip3AhrBwrPC0UZA447HeoQ,910
 aihub/models/labelfree.py,sha256=nljprYO6ECuctTVbHqriQ73N5EEyYURhBrnU28Ngfvc,1589
@@ -17,9 +17,9 @@ aihub/models/task_center.py,sha256=HE21Q4Uj-vt9LHezHnqBYgdinhrh4iJPOq8VXbSMllU,5
 aihub/models/user_system.py,sha256=0L_pBkWL9v1tv_mclOyRgCyWibuuj_XU3mPoe2v48vQ,12216
 aihub/models/workflow_center.py,sha256=4xtI1WZ38ceXJ8gwDBj-QNjOiRlLO_8kGiQybdudJPY,20121
 aihub/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/services/artifact.py,sha256=
+aihub/services/artifact.py,sha256=lfOrgOT2AlH1w-75NLcQGOhVWdhmJcWD1gESPpUzqUw,11257
 aihub/services/data_warehouse.py,sha256=awvlJdggo8ph6sXweXXVp4GLRuUSD46LoD0QQksXRts,2964
-aihub/services/dataset_management.py,sha256=
+aihub/services/dataset_management.py,sha256=R7mFsJ1dNOI_p5yNj_rQdLolRC0UKEN4WejE7uOjVlE,21379
 aihub/services/document_center.py,sha256=dG67Ji-DOnzL2t-4x4gVfMt9fbSj_IjVHCLw5R-VTkQ,1813
 aihub/services/eval.py,sha256=V1nBISIyYWg9JJO24xzy4-kit9NsaCYp1EWIX_fgJkQ,2128
 aihub/services/labelfree.py,sha256=xua62UWhVXTxJjHRyy86waaormnJjmpQwepcARBy_h0,1450
@@ -32,11 +32,12 @@ aihub/services/task_center.py,sha256=rVQG7q2_GN0501w5KHsOOlSVFX9ovpRMGX5hskCqggw
 aihub/services/user_system.py,sha256=IqWL4bnsKyyzuGT5l6adnw0qNXlH9PSo1-C_pFyOSzA,18868
 aihub/services/workflow_center.py,sha256=caKxOlba0J1s1RUK6RUm1ndJSwAcZXEakRanu3sGKPU,17468
 aihub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/utils/
-aihub/utils/
+aihub/utils/di.py,sha256=vFUzno5WbRKu6-pj8Hnz9IqT7xb9UDZQ4qpOFH1YAtM,11812
+aihub/utils/download.py,sha256=ZZVbcC-PnN3PumV7ZiJ_-srkt4HPPovu2F6Faa2RrPE,1830
+aihub/utils/http.py,sha256=AmfHHNjptuuSFx2T1twWCnerR_hLN_gd0lUs8z36ERA,547
 aihub/utils/s3.py,sha256=ISIBP-XdBPkURpXnN56ZnIWokOOg2SRUh_qvxJk-G1Q,2187
-intellif_aihub-0.1.
-intellif_aihub-0.1.
-intellif_aihub-0.1.
-intellif_aihub-0.1.
-intellif_aihub-0.1.
+intellif_aihub-0.1.15.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+intellif_aihub-0.1.15.dist-info/METADATA,sha256=Hz8Z3sB06pNTJF8lygzDU37da2bCgXCrzJ1-CRAlN7Y,2949
+intellif_aihub-0.1.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+intellif_aihub-0.1.15.dist-info/top_level.txt,sha256=vIvTtSIN73xv46BpYM-ctVGnyOiUQ9EWP_6ngvdIlvw,6
+intellif_aihub-0.1.15.dist-info/RECORD,,
File without changes
|
|
File without changes
|
|
File without changes
|