intellif-aihub 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of intellif-aihub might be problematic.
- aihub/__init__.py +1 -1
- aihub/models/artifact.py +16 -4
- aihub/models/dataset_management.py +23 -0
- aihub/services/artifact.py +16 -30
- aihub/services/dataset_management.py +176 -42
- aihub/utils/di.py +337 -0
- aihub/utils/download.py +3 -15
- aihub/utils/http.py +6 -0
- {intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/METADATA +1 -1
- {intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/RECORD +13 -12
- {intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/WHEEL +0 -0
- {intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/licenses/LICENSE +0 -0
- {intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/top_level.txt +0 -0
aihub/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.14"
+__version__ = "0.1.15"
aihub/models/artifact.py
CHANGED
@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
 
 class ArtifactType(str, Enum):
     """制品类型枚举:"dataset"-数据集类型;"model"-模型类型;"metrics"-指标类型;"log"-日志类型;"checkpoint"-检查点类型;"image"-图像类型;"prediction"-预测结果类型;"other"-其他类型"""
+
     dataset = "dataset"  # 数据集类型
     model = "model"  # 模型类型
     metrics = "metrics"  # 指标类型
@@ -26,24 +27,29 @@ class ArtifactType(str, Enum):
 
 class CreateArtifactsReq(BaseModel):
     """创建制品请求"""
+
     entity_id: str = Field(alias="entity_id", description="实体ID,通常是运行ID,用于关联制品与特定运行")
-    entity_type: ArtifactType = Field(
-
+    entity_type: ArtifactType = Field(
+        default=ArtifactType.other, alias="entity_type", description="制品类型,指定制品的类型,默认为other"
+    )
     src_path: str = Field(alias="src_path", description="源路径,制品在系统中的路径标识")
-    is_dir: bool = Field(
-
+    is_dir: bool = Field(
+        default=False, alias="is_dir", description="是否为目录,True表示制品是一个目录,False表示是单个文件"
+    )
 
     model_config = {"use_enum_values": True}
 
 
 class CreateArtifactsResponseData(BaseModel):
     """创建制品响应数据"""
+
     id: int = Field(description="制品ID")
     s3_path: str = Field(alias="s3_path", description="S3存储路径")
 
 
 class CreateArtifactsResponseModel(BaseModel):
     """创建制品响应模型"""
+
     code: int = Field(description="响应码,0表示成功")
     msg: str = Field(default="", description="响应消息")
     data: Optional[CreateArtifactsResponseData] = Field(default=None, description="响应数据")
@@ -51,6 +57,7 @@ class CreateArtifactsResponseModel(BaseModel):
 
 class CreateEvalReq(BaseModel):
     """创建评估请求"""
+
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
     dataset_version_id: int = Field(alias="dataset_version_id", description="数据集版本ID")
     prediction_artifact_path: str = Field(alias="prediction_artifact_path", description="预测结果制品路径")
@@ -62,6 +69,7 @@ class CreateEvalReq(BaseModel):
 
 class ArtifactResp(BaseModel):
     """制品响应模型,表示一个制品的详细信息"""
+
     id: int = Field(description="制品ID")
     entity_type: str = Field(alias="entity_type", description="实体类型")
     entity_id: str = Field(alias="entity_id", description="实体ID")
@@ -72,6 +80,7 @@ class ArtifactResp(BaseModel):
 
 class ArtifactRespData(BaseModel):
     """制品分页数据"""
+
     total: int = Field(description="总记录数")
     page_size: int = Field(alias="page_size", description="每页大小")
     page_num: int = Field(alias="page_num", description="页码")
@@ -80,6 +89,7 @@ class ArtifactRespData(BaseModel):
 
 class ArtifactRespModel(BaseModel):
     """获取制品响应模型"""
+
     code: int = Field(description="响应码,0表示成功")
     msg: str = Field(default="", description="响应消息")
     data: ArtifactRespData = Field(description="响应数据")
@@ -91,8 +101,10 @@ InfinityPageSize = 10000 * 100
 
 class StsResp(BaseModel):
     """STS 临时凭证"""
+
     access_key_id: Optional[str] = Field(default=None, alias="access_key_id", description="访问密钥ID")
     secret_access_key: Optional[str] = Field(default=None, alias="secret_access_key", description="秘密访问密钥")
     session_token: Optional[str] = Field(default=None, alias="session_token", description="会话令牌")
     expiration: Optional[int] = Field(default=None, alias="expiration", description="过期时间")
     endpoint: Optional[str] = Field(default=None, alias="endpoint", description="端点URL")
+    bucket: Optional[str] = Field(default=None, alias="bucket", description="存储桶名称")
aihub/models/dataset_management.py
CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
 
 class DatasetVersionStatus(IntEnum):
     """数据集版本状态:1-等待中;2-运行中;3-成功;4-失败;5-加载meta;6-构建index"""
+
     Waiting = 1  # 等待中
     Running = 2  # 运行中
     Success = 3  # 成功
@@ -18,6 +19,7 @@ class DatasetVersionStatus(IntEnum):
 
 class UploadType(IntEnum):
     """上传类型:1-本地上传;3-服务器路径上传;4-Labelfree;5-数据接入"""
+
     LOCAL = 1  # 本地上传
     SERVER_PATH = 3  # 服务器路径上传
     LABELFREE = 4  # Labelfree
@@ -26,6 +28,7 @@ class UploadType(IntEnum):
 
 class CreateDatasetRequest(BaseModel):
     """创建数据集请求"""
+
     name: str = Field(description="数据集名称")
     description: str = Field(description="数据集描述")
     tags: List[int] = Field(description="标签ID列表,通过标签管理系统查询")
@@ -37,11 +40,13 @@ class CreateDatasetRequest(BaseModel):
 
 class CreateDatasetResponse(BaseModel):
     """创建数据集返回"""
+
     id: int = Field(alias="id", description="数据集ID")
 
 
 class DatasetVersionBase(BaseModel):
     """数据集版本概要"""
+
     id: int = Field(description="版本ID")
     version: int = Field(description="版本号")
     status: DatasetVersionStatus = Field(description="版本状态")
@@ -53,6 +58,7 @@ class DatasetVersionBase(BaseModel):
 
 class DatasetDetail(BaseModel):
     """数据集详情"""
+
     id: int = Field(description="数据集 ID")
     name: str = Field(description="名称")
     description: str = Field(description="描述")
@@ -69,6 +75,7 @@ class DatasetDetail(BaseModel):
 
 class ExtInfo(BaseModel):
     """扩展信息"""
+
     rec_file_path: Optional[str] = Field(None, alias="rec_file_path", description="rec文件路径")
     idx_file_path: Optional[str] = Field(None, alias="idx_file_path", description="idx文件路径")
     json_file_path: Optional[str] = Field(None, alias="json_file_path", description="json文件路径")
@@ -77,6 +84,7 @@ class ExtInfo(BaseModel):
 
 class CreateDatasetVersionRequest(BaseModel):
     """创建版本请求"""
+
     upload_path: str = Field(alias="upload_path", description="上传路径")
     description: Optional[str] = Field(None, description="版本描述")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -91,11 +99,13 @@ class CreateDatasetVersionRequest(BaseModel):
 
 class CreateDatasetVersionResponse(BaseModel):
     """创建版本返回"""
+
     id: int = Field(alias="id", description="版本ID")
 
 
 class UploadDatasetVersionRequest(BaseModel):
     """上传数据集版本请求"""
+
     upload_path: str = Field(alias="upload_path", description="上传目录")
     upload_type: UploadType = Field(alias="upload_type", description="上传类型")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -107,11 +117,13 @@ class UploadDatasetVersionRequest(BaseModel):
 
 class UploadDatasetVersionResponse(BaseModel):
     """上传数据集版本返回"""
+
     id: int = Field(alias="id", description="版本ID")
 
 
 class DatasetVersionDetail(BaseModel):
     """数据集版本详情"""
+
     id: int = Field(description="版本ID")
     version: int = Field(description="版本号")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -133,6 +145,7 @@ class DatasetVersionDetail(BaseModel):
 
 class FileUploadData(BaseModel):
     """文件上传数据"""
+
     path: str = Field(description="路径")
     url: str = Field(description="URL")
 
@@ -203,3 +216,13 @@ class ListDatasetVersionResp(BaseModel):
     page_size: int = Field(alias="page_size", description="每页大小")
     page_num: int = Field(alias="page_num", description="当前页码")
     data: List[ListDatasetVersionItem] = Field(description="数据集版本列表")
+
+
+class CreateDatasetVersionByDataIngestReqV2(BaseModel):
+    """通过数据集成创建数据集版本请求"""
+
+    description: Optional[str] = Field(None, description="描述")
+    dataset_id: int = Field(..., description="数据集ID")
+    s3_object_sheet: str = Field(..., description="S3对象表")
+    object_cnt: Optional[int] = Field(None, description="对象数量")
+    data_size: Optional[int] = Field(None, description="数据大小")
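The new CreateDatasetVersionByDataIngestReqV2 model is the payload the SDK now POSTs to the data-ingest endpoint (see aihub/services/dataset_management.py below). A minimal sketch of building and serializing it; the field values here are placeholders, not taken from the release:

from aihub.models.dataset_management import CreateDatasetVersionByDataIngestReqV2

# dataset_id and s3_object_sheet are required; the remaining fields are optional.
req = CreateDatasetVersionByDataIngestReqV2(
    description="sdk upload",                                    # optional free-form note
    dataset_id=42,                                               # placeholder dataset ID
    s3_object_sheet="s3://bucket/dataset_workspace/42/abc.csv",  # CSV manifest written by DataUploader
    object_cnt=1200,                                             # placeholder object count
    data_size=734_003_200,                                       # placeholder total bytes
)

# The service layer sends this as JSON to POST /dataset-mng/api/v2/dataset-versions/data-ingest.
print(req.model_dump())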
aihub/services/artifact.py
CHANGED
@@ -98,9 +98,7 @@ class ArtifactService:
         """
         return self._artifact.get_sts()
 
-    def get_by_run_id(
-        self, run_id: str, artifact_path: Optional[str] = None
-    ) -> List[ArtifactResp]:
+    def get_by_run_id(self, run_id: str, artifact_path: Optional[str] = None) -> List[ArtifactResp]:
         """根据运行ID获取制品列表
 
         Args:
@@ -116,11 +114,11 @@ class ArtifactService:
         return self._artifact.get_by_run_id(run_id, artifact_path)
 
     def create_artifact(
-
-
-
-
-
+        self,
+        local_path: str,
+        artifact_path: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifact_type: ArtifactType = ArtifactType.other,
     ) -> None:
         """创建单个文件制品并上传
 
@@ -171,11 +169,11 @@ class ArtifactService:
         return
 
     def create_artifacts(
-
-
-
-
-
+        self,
+        local_dir: str,
+        artifact_path: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifact_type: ArtifactType = ArtifactType.other,
     ) -> None:
         """创建目录制品并上传
 
@@ -223,9 +221,7 @@ class ArtifactService:
         logger.info(f"log artifact done: {artifact_path}")
         return
 
-    def download_artifacts(
-        self, run_id: str, artifact_path: Optional[str], local_dir: str
-    ) -> None:
+    def download_artifacts(self, run_id: str, artifact_path: Optional[str], local_dir: str) -> None:
         """下载制品
 
         Args:
@@ -252,9 +248,7 @@ class ArtifactService:
         if artifact_item.is_dir:
             download_dir_from_s3(self.s3_client, bucket, object_name, local_dir)
         else:
-            self.s3_client.fget_object(
-                bucket, object_name, str(Path(local_dir) / artifact_item.src_path)
-            )
+            self.s3_client.fget_object(bucket, object_name, str(Path(local_dir) / artifact_item.src_path))
 
         logger.info(f"download artifact done: {artifact_path}")
         return
@@ -311,9 +305,7 @@ class _Artifact:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return
 
-    def get_by_run_id(
-        self, run_id: str, artifact_path: Optional[str]
-    ) -> List[ArtifactResp]:
+    def get_by_run_id(self, run_id: str, artifact_path: Optional[str]) -> List[ArtifactResp]:
         """根据运行ID获取制品列表
 
         Args:
@@ -326,18 +318,12 @@ class _Artifact:
         Raises:
             APIError: 当API调用失败时抛出
         """
-        resp = self._http.get(
-            f"{_Base}/artifacts?entity_id={run_id}&page_num=1&page_size={InfinityPageSize}"
-        )
+        resp = self._http.get(f"{_Base}/artifacts?entity_id={run_id}&page_num=1&page_size={InfinityPageSize}")
         wrapper = APIWrapper[ArtifactRespData].model_validate(resp.json())
         if wrapper.code != 0:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         if artifact_path:
-            return [
-                artifact
-                for artifact in wrapper.data.data
-                if artifact.src_path == artifact_path
-            ]
+            return [artifact for artifact in wrapper.data.data if artifact.src_path == artifact_path]
         else:
             return wrapper.data.data
 
aihub/services/dataset_management.py
CHANGED
@@ -17,7 +17,6 @@ from __future__ import annotations
 import mimetypes
 import os
 import pathlib
-import tempfile
 import time
 import uuid
 
@@ -25,6 +24,7 @@ import httpx
 from loguru import logger
 
 from ..exceptions import APIError
+from ..models.artifact import StsResp
 from ..models.common import APIWrapper
 from ..models.dataset_management import (
     CreateDatasetRequest,
@@ -40,9 +40,12 @@ from ..models.dataset_management import (
     ListDatasetResp,
     ListDatasetVersionReq,
     ListDatasetVersionResp,
+    CreateDatasetVersionByDataIngestReqV2,
+    UploadType,
 )
 from ..models.dataset_management import DatasetVersionStatus
-from ..utils.
+from ..utils.di import SimpleS3Client, DataUploader
+from ..utils.download import dataset_download
 
 _BASE = "/dataset-mng/api/v2"
 
@@ -138,20 +141,29 @@ class DatasetManagementService:
     def dataset(self) -> _Dataset:
         return self._dataset
 
+    def _get_sts(self) -> StsResp:
+        return self.dataset_version.get_sts()
+
     @property
     def dataset_version(self) -> _DatasetVersion:
         return self._dataset_version
 
+    def upload_by_data_ingest(
+        self,
+        req: CreateDatasetVersionByDataIngestReqV2,
+    ) -> CreateDatasetVersionResponse:
+        return self.dataset_version.upload_by_data_ingest(req)
+
     def create_dataset_and_version(
-
-
-
-
-
-
-
-
-
+        self,
+        *,
+        dataset_name: str,
+        dataset_description: str = "",
+        is_local_upload: bool,
+        local_file_path: str | None = None,
+        server_file_path: str | None = None,
+        version_description: str = "",
+        timeout: int = 1_800,
     ) -> tuple[int, int, str]:
         """创建数据集及其版本,并等待版本状态变为 *Success*。
 
@@ -169,17 +181,51 @@ class DatasetManagementService:
 
         Returns:
             tuple[int, int, str]: 一个三元组,包含:[数据集 ID,数据集版本 ID, 数据集版本标签(格式为 <dataset_name>/V<version_number>)]
+
+        Raises:
+            ValueError: 当参数不满足要求时
+            APIError: 当后端返回错误时
+            TimeoutError: 当等待超时时
         """
+        # 参数校验
+        self._validate_create_params(is_local_upload, local_file_path, server_file_path)
+
+        # 创建数据集
+        dataset_id = self._create_dataset(dataset_name, dataset_description)
+        logger.info(f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据")
+
+        # 创建数据集版本
+        version_id = self._create_dataset_version(
+            dataset_id=dataset_id,
+            is_local_upload=is_local_upload,
+            local_file_path=local_file_path,
+            server_file_path=server_file_path,
+            version_description=version_description,
+        )
+
+        # 获取版本标签
+        version_tag = self._get_version_tag(dataset_id, version_id)
+        logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
+
+        # 轮询等待版本状态变为成功
+        self._wait_for_version_success(version_id, timeout)
+
+        return dataset_id, version_id, version_tag
+
+    def _validate_create_params(
+        self, is_local_upload: bool, local_file_path: str | None, server_file_path: str | None
+    ) -> None:
+        """验证创建数据集和版本所需的参数"""
         if is_local_upload:
             if not local_file_path:
                 raise ValueError("is_local_upload=True 时必须提供 local_file_path")
-            upload_type = 1
         else:
            if not server_file_path:
                raise ValueError("is_local_upload=False 时必须提供 server_file_path")
-            upload_type = 3
 
-
+    def _create_dataset(self, dataset_name: str, dataset_description: str) -> int:
+        """创建数据集"""
+        return self._dataset.create(
             CreateDatasetRequest(
                 name=dataset_name,
                 description=dataset_description,
@@ -190,39 +236,96 @@ class DatasetManagementService:
                 access_user_ids=None,
             )
         )
-        logger.info(
-            f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据"
-        )
 
+    def _create_dataset_version(
+        self,
+        dataset_id: int,
+        is_local_upload: bool,
+        local_file_path: str | None,
+        server_file_path: str | None,
+        version_description: str,
+    ) -> int:
+        """根据上传类型创建数据集版本"""
         if is_local_upload:
-
-            local_file_path = pathlib.Path(local_file_path)
-            if local_file_path.is_dir():
-                # 把文件夹打包为一个 zip
-                temp_zip_path = (
-                    pathlib.Path(tempfile.mkdtemp()) / f" {uuid.uuid4().hex}.zip"
-                )
-                zip_dir(local_file_path, temp_zip_path)
-                upload_data = self._upload.upload_file(temp_zip_path)
-                os.remove(temp_zip_path)
-            else:
-                upload_data = self._upload.upload_file(local_file_path)
-
-            upload_path = upload_data.path
+            return self._create_local_dataset_version(dataset_id, local_file_path, version_description)
         else:
-
-
+            return self._create_server_dataset_version(dataset_id, server_file_path, version_description)
+
+    def _create_local_dataset_version(
+        self, dataset_id: int, local_file_path: str | None, version_description: str
+    ) -> int:
+        """创建本地文件数据集版本"""
+        if pathlib.Path(local_file_path).is_dir():
+            return self._create_local_dir_dataset_version(dataset_id, local_file_path)
+        elif pathlib.Path(local_file_path).is_file():
+            return self._create_local_file_dataset_version(dataset_id, local_file_path, version_description)
+        else:
+            raise ValueError(f"本地路径既不是文件也不是目录: {local_file_path}")
+
+    def _create_local_dir_dataset_version(self, dataset_id: int, local_file_path: str) -> int:
+        """处理本地目录上传"""
+        sts = self._get_sts()
+        s3_client = SimpleS3Client(
+            sts.endpoint, sts.access_key_id, sts.secret_access_key, session_token=sts.session_token
+        )
+        uid = uuid.uuid4().hex
+        s3_target = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}"
+        s3_csv_path = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}.csv"
+        s3_status_path = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}.json"
+
+        # 创建上传器并执行
+        uploader = DataUploader(
+            task_id=dataset_id,
+            local_path=str(local_file_path),
+            s3_target=s3_target,
+            csv_path=s3_csv_path,
+            status_path=s3_status_path,
+            num_workers=40,
+        )
 
-
+        upload_stats = uploader.run(s3_client)
+        req = CreateDatasetVersionByDataIngestReqV2(
+            description=f"sdk 上传",
+            dataset_id=dataset_id,
+            s3_object_sheet=s3_csv_path,
+            object_cnt=upload_stats.uploaded_count,
+            data_size=upload_stats.uploaded_size,
+        )
+        return self.upload_by_data_ingest(req).id
+
+    def _create_local_file_dataset_version(
+        self, dataset_id: int, local_file_path: str, version_description: str
+    ) -> int:
+        """处理本地文件上传"""
+        upload_data = self._upload.upload_file(local_file_path)
+        upload_path = upload_data.path
+        logger.info(f"文件上传成功:{local_file_path}")
+        return self._dataset_version.upload(
             UploadDatasetVersionRequest(
                 upload_path=upload_path,
-                upload_type=
+                upload_type=UploadType.LOCAL,  # 本地上传类型
+                dataset_id=dataset_id,
+                description=version_description,
+                parent_version_id=0,
+            )
+        )
+
+    def _create_server_dataset_version(
+        self, dataset_id: int, server_file_path: str | None, version_description: str
+    ) -> int:
+        """创建服务器文件数据集版本"""
+        return self._dataset_version.upload(
+            UploadDatasetVersionRequest(
+                upload_path=server_file_path,
+                upload_type=UploadType.SERVER_PATH,  # 服务器文件上传类型
                 dataset_id=dataset_id,
                 description=version_description,
                 parent_version_id=0,
             )
         )
 
+    def _get_version_tag(self, dataset_id: int, version_id: int) -> str:
+        """获取版本标签"""
         detail = self._dataset.get(dataset_id)
         ver_num = next(
             (v.version for v in detail.versions if v.id == version_id),
@@ -231,9 +334,10 @@ class DatasetManagementService:
         if ver_num is None:
             ver_num = 1
 
-
-        logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
+        return f"{detail.name}/V{ver_num}"
 
+    def _wait_for_version_success(self, version_id: int, timeout: int) -> None:
+        """轮询等待版本状态变为成功"""
         start_ts = time.time()
         poll_interval = 10
 
@@ -255,8 +359,6 @@ class DatasetManagementService:
             logger.debug(f"已等待 {elapsed:.0f}s,继续轮询…")
             time.sleep(poll_interval)
 
-        return dataset_id, version_id, version_tag
-
     def run_download(self, dataset_version_name: str, local_dir: str, worker: int = 4) -> None:
         """根据数据集版本名称下载对应的数据集文件。
 
@@ -400,9 +502,7 @@ class _DatasetVersion:
         return wrapper.data
 
     def get_by_name(self, version_name: str) -> DatasetVersionDetail:
-        resp = self._http.get(
-            f"{_BASE}/dataset-versions-detail", params={"name": version_name}
-        )
+        resp = self._http.get(f"{_BASE}/dataset-versions-detail", params={"name": version_name})
         wrapper = APIWrapper[DatasetVersionDetail].model_validate(resp.json())
         if wrapper.code != 0:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
@@ -417,6 +517,40 @@ class _DatasetVersion:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return wrapper.data
 
+    def get_sts(self) -> StsResp:
+        """获取STS临时凭证
+
+        获取用于访问S3存储的临时凭证。
+
+        Returns:
+            StsResp: STS临时凭证信息
+
+        Raises:
+            APIError: 当API调用失败时抛出
+        """
+        resp = self._http.get(f"{_BASE}/dataset-versions/get-sts")
+        logger.info(f"get sts: {resp.text}")
+        wrapper = APIWrapper[StsResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
+    def upload_by_data_ingest(self, req: CreateDatasetVersionByDataIngestReqV2) -> CreateDatasetVersionResponse:
+        """上传数据集版本(数据集导入)
+        Args:
+            req
+
+        """
+        resp = self._http.post(
+            f"{_BASE}/dataset-versions/data-ingest",
+            json=req.model_dump(),
+        )
+        logger.debug(f"upload_by_data_ingest: {resp.text}")
+        wrapper = APIWrapper[CreateDatasetVersionResponse].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
 
 class _Upload:
     def __init__(self, http: httpx.Client):
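Taken together, these service changes replace the old zip-and-upload path for local directories with an STS-credential + MinIO ingest flow and split create_dataset_and_version into small helpers. A rough usage sketch, assuming svc is an already-constructed DatasetManagementService (the client wiring is not part of this diff):

# Assumption: `svc` is a DatasetManagementService obtained from the SDK client.
dataset_id, version_id, version_tag = svc.create_dataset_and_version(
    dataset_name="demo-ds",
    dataset_description="demo dataset",
    is_local_upload=True,              # arguments are keyword-only as of this release
    local_file_path="/data/images",    # a directory now goes through the data-ingest path
    version_description="first version",
    timeout=1_800,                     # seconds to wait for the version to reach Success
)
print(version_tag)  # e.g. "demo-ds/V1"

# Lower-level entry points added in this release:
#   svc.dataset_version.get_sts()   -> GET  /dataset-mng/api/v2/dataset-versions/get-sts
#   svc.upload_by_data_ingest(req)  -> POST /dataset-mng/api/v2/dataset-versions/data-ingest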
aihub/utils/di.py
ADDED
@@ -0,0 +1,337 @@
+import argparse
+import csv
+import hashlib
+import json
+import os
+import queue
+import sys
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, Any, Tuple
+
+import minio
+from loguru import logger
+
+
+class UploadStatus:
+    """上传状态类"""
+
+    def __init__(self):
+        self.uploaded_count = 0
+        self.uploaded_size = 0
+        self.updated_at = int(time.time() * 1000)
+
+    def update(self, count: int, size: int):
+        self.uploaded_count += count
+        self.uploaded_size += size
+        self.updated_at = int(time.time() * 1000)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "uploaded_count": self.uploaded_count,
+            "uploaded_size": self.uploaded_size,
+            "updated_at": self.updated_at,
+        }
+
+
+class SimpleS3Client:
+    """简化的S3客户端"""
+
+    def __init__(self, endpoint: str, access_key: str, secret_key: str, session_token: str):
+        self.client = minio.Minio(
+            endpoint, access_key=access_key, secret_key=secret_key, secure=False, session_token=session_token
+        )
+
+    def upload_file(self, local_path: str, bucket: str, object_name: str) -> Tuple[str, int]:
+        """上传文件并返回哈希和大小"""
+        file_size = os.path.getsize(local_path)
+
+        # 计算文件哈希
+        sha256_hash = hashlib.sha256()
+        with open(local_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256_hash.update(chunk)
+
+        file_hash = sha256_hash.hexdigest()
+
+        # 上传文件
+        with open(local_path, "rb") as f:
+            self.client.put_object(bucket, object_name, f, file_size)
+
+        return file_hash, file_size
+
+    def upload_json(self, data: Dict[str, Any], bucket: str, object_name: str):
+        """上传JSON数据"""
+        json_str = json.dumps(data)
+        json_bytes = json_str.encode("utf-8")
+
+        from io import BytesIO
+
+        self.client.put_object(
+            bucket, object_name, BytesIO(json_bytes), len(json_bytes), content_type="application/json"
+        )
+
+
+class DataUploader:
+    """数据上传器"""
+
+    def __init__(
+        self, task_id: int, local_path: str, s3_target: str, csv_path: str, status_path: str, num_workers: int = 10
+    ):
+        self.task_id = task_id
+        self.local_path = local_path
+        self.num_workers = num_workers
+
+        # 解析S3路径
+        self.target_bucket, self.target_prefix = self._parse_s3_path(s3_target)
+        self.csv_bucket, self.csv_key = self._parse_s3_path(csv_path)
+        self.status_bucket, self.status_key = self._parse_s3_path(status_path)
+
+        # 创建工作目录
+        self.work_dir = Path.home() / ".di_workspace" / str(task_id)
+        self.work_dir.mkdir(parents=True, exist_ok=True)
+        self.csv_file = self.work_dir / "upload_records.csv"
+
+        # CSV记录队列
+        self.csv_queue = queue.Queue()
+        self.processed_files = set()
+        self.total_files = 0
+
+    def _parse_s3_path(self, s3_path: str) -> Tuple[str, str]:
+        """解析S3路径"""
+        if s3_path.startswith("s3://"):
+            parts = s3_path[5:].split("/", 1)
+            bucket = parts[0]
+            key = parts[1] if len(parts) > 1 else ""
+            return bucket, key
+        return "", ""
+
+    def _collect_files(self) -> list:
+        """收集需要上传的文件"""
+        files = []
+
+        if os.path.isfile(self.local_path):
+            files.append(self.local_path)
+            self.total_files += 1
+        else:
+            for root, _, filenames in os.walk(self.local_path):
+                for filename in filenames:
+                    file_path = os.path.join(root, filename)
+                    if not os.path.islink(file_path):  # 跳过符号链接
+                        files.append(file_path)
+                        self.total_files += 1
+
+        # 过滤已处理的文件
+        base_path = os.path.dirname(self.local_path) if os.path.isfile(self.local_path) else self.local_path
+        unprocessed_files = []
+
+        for file_path in files:
+            rel_path = os.path.relpath(file_path, base_path)
+            if rel_path not in self.processed_files:
+                unprocessed_files.append(file_path)
+
+        return unprocessed_files
+
+    def _csv_writer_worker(self):
+        """CSV写入工作器"""
+        # 初始化CSV文件
+        uploaded_count = 0
+        file_exists = os.path.exists(self.csv_file)
+
+        with open(self.csv_file, "a", newline="", encoding="utf-8") as f:
+            writer = csv.writer(f)
+            if not file_exists:
+                writer.writerow(["local_path", "sha256", "s3path", "file_size"])
+
+            while True:
+                try:
+                    record = self.csv_queue.get(timeout=1)
+                    if record is None:  # 结束信号
+                        break
+
+                    writer.writerow(
+                        [record["local_path"], record["file_hash"], record["s3_path"], str(record["file_size"])]
+                    )
+
+                    f.flush()  # 确保数据写入磁盘
+                    self.csv_queue.task_done()
+                    uploaded_count += 1
+                    # 每上传100个文件,打印进度
+                    if uploaded_count % 1000 == 0:
+                        logger.info(f"已上传 {uploaded_count} 个文件")
+
+                except queue.Empty:
+                    continue
+                except Exception as e:
+                    logger.error(f"Failed to write CSV record: {e}")
+                    self.csv_queue.task_done()
+
+    def _upload_worker(self, s3_client: SimpleS3Client, file_queue: queue.Queue, base_path: str):
+        """上传工作器"""
+        while True:
+            try:
+                file_path = file_queue.get(timeout=1)
+                if file_path is None:  # 结束信号
+                    break
+
+                try:
+                    # 计算相对路径和S3对象名
+                    rel_path = os.path.relpath(file_path, base_path)
+
+                    object_name = os.path.join(self.target_prefix, rel_path).replace("\\", "/")
+
+                    # 上传文件
+                    file_hash, file_size = s3_client.upload_file(file_path, self.target_bucket, object_name)
+
+                    # 将记录放入CSV队列
+                    s3_path = f"s3://{self.target_bucket}/{object_name}"
+                    record = {
+                        "local_path": os.path.join("/", rel_path),
+                        "file_hash": file_hash,
+                        "s3_path": s3_path,
+                        "file_size": file_size,
+                    }
+                    self.csv_queue.put(record)
+
+                    logger.debug(f"Uploaded: {rel_path}")
+
+                except Exception as e:
+                    logger.error(f"Failed to upload {file_path}: {e}")
+                finally:
+                    file_queue.task_done()
+
+            except queue.Empty:
+                break
+
+    def _calculate_final_stats(self) -> UploadStatus:
+        """从CSV文件计算最终统计信息"""
+        stats = UploadStatus()
+        if not os.path.exists(self.csv_file):
+            return stats
+
+        total_count = 0
+        total_size = 0
+
+        try:
+            with open(self.csv_file, "r", encoding="utf-8") as f:
+                reader = csv.DictReader(f)
+                for row in reader:
+                    total_count += 1
+                    total_size += int(row["file_size"])
+        except Exception as e:
+            logger.error(f"Failed to calculate stats: {e}")
+
+        stats.update(total_count, total_size)
+
+        return stats
+
+    def run(self, s3_client: SimpleS3Client) -> UploadStatus:
+        """执行上传任务"""
+        # 收集文件
+        files = self._collect_files()
+        if not files:
+            logger.info("No files to upload")
+            return UploadStatus()
+
+        logger.info(f"Found {len(files)} files to upload")
+
+        # 准备文件队列
+        file_queue = queue.Queue()
+        for file_path in files:
+            file_queue.put(file_path)
+
+        base_path = os.path.dirname(self.local_path) if os.path.isfile(self.local_path) else self.local_path
+
+        # 启动CSV写入线程
+        csv_thread = threading.Thread(target=self._csv_writer_worker)
+        csv_thread.daemon = True
+        csv_thread.start()
+
+        try:
+            # 启动上传工作器
+            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+                futures = []
+                for i in range(self.num_workers):
+                    future = executor.submit(self._upload_worker, s3_client, file_queue, base_path)
+                    futures.append(future)
+
+                # 等待所有任务完成
+                for future in as_completed(futures):
+                    future.result()
+
+            # 等待CSV队列处理完成
+            self.csv_queue.join()
+
+            # 发送结束信号给CSV写入线程
+            self.csv_queue.put(None)
+            csv_thread.join()
+
+            # 上传记录文件到S3
+            if os.path.exists(self.csv_file):
+                s3_client.upload_file(str(self.csv_file), self.csv_bucket, self.csv_key)
+                logger.info("Upload records saved to S3")
+
+            # 计算并上传最终统计信息
+            stats = self._calculate_final_stats()
+            s3_client.upload_json(stats.to_dict(), self.status_bucket, self.status_key)
+            logger.info(f"Upload completed: {stats.uploaded_count} files, {stats.uploaded_size} bytes")
+
+        finally:
+            # 清理工作目录
+            try:
+                import shutil
+
+                shutil.rmtree(self.work_dir)
+            except Exception as e:
+                logger.warning(f"Failed to cleanup workspace: {e}")
+        return stats
+
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="简化的数据摄取工具")
+    parser.add_argument("-e", "--endpoint", default="192.168.13.160:9008", help="S3端点")
+    parser.add_argument("-ak", "--access-key", default="admin2024", help="访问密钥")
+    parser.add_argument("-sk", "--secret-key", default="root@23452024", help="秘密密钥")
+    parser.add_argument("-t", "--target", default="s3://testbucket/test_ok11", help="目标S3路径")
+    parser.add_argument("-l", "--local", default="./test_data", help="本地路径")
+    parser.add_argument("-o", "--object-sheet", default="s3://testbucket/records/123.csv", help="记录文件S3路径")
+    parser.add_argument("-s", "--status", default="s3://testbucket/status/123.json", help="状态文件S3路径")
+    parser.add_argument("-i", "--task-id", type=int, default=123, help="任务ID")
+    parser.add_argument("-n", "--num-workers", type=int, default=10, help="工作线程数")
+
+    args = parser.parse_args()
+
+    # 检查本地路径
+    if not os.path.exists(args.local):
+        logger.error(f"Local path does not exist: {args.local}")
+        sys.exit(1)
+
+    logger.info(f"Starting upload: {args.local} -> {args.target}")
+
+    try:
+        # 创建S3客户端
+        s3_client = SimpleS3Client(args.endpoint, args.access_key, args.secret_key)
+
+        # 创建上传器并执行
+        uploader = DataUploader(
+            task_id=args.task_id,
+            local_path=args.local,
+            s3_target=args.target,
+            csv_path=args.object_sheet,
+            status_path=args.status,
+            num_workers=args.num_workers,
+        )
+
+        uploader.run(s3_client)
+        logger.info("Upload completed successfully")
+
+    except Exception as e:
+        logger.error(f"Upload failed: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
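DataUploader writes an upload_records.csv manifest (columns local_path, sha256, s3path, file_size) alongside the uploaded objects, plus a small status JSON with uploaded_count, uploaded_size and updated_at. A short sketch that reads those artifacts back, assuming they have already been fetched from S3 to local files:

import csv
import json

# Assumption: the manifest and status files produced by DataUploader were downloaded locally.
with open("upload_records.csv", encoding="utf-8") as f:
    rows = list(csv.DictReader(f))  # columns: local_path, sha256, s3path, file_size

total_bytes = sum(int(row["file_size"]) for row in rows)
print(f"{len(rows)} objects, {total_bytes} bytes in the manifest")

with open("status.json", encoding="utf-8") as f:
    status = json.load(f)  # keys: uploaded_count, uploaded_size, updated_at
print(status["uploaded_count"], status["uploaded_size"])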
aihub/utils/download.py
CHANGED
@@ -1,15 +1,14 @@
 from __future__ import annotations
 
-import concurrent.futures
 import os
 import tempfile
 import zipfile
 from typing import List, TypedDict
 
 import pyarrow.parquet as pq
-from tqdm import tqdm
+from tqdm.contrib.concurrent import thread_map
 
-from .http import http_download_file
+from .http import http_download_file, http_download_file_wrapper
 from .s3 import s3_to_url
 
 
@@ -59,18 +58,7 @@ def dataset_download(index_url: str, local_dir: str, worker: int = 4) -> None:
     if worker < 1:
         worker = 1
 
-    with (
-        tqdm(total=len(files), desc="Downloading dataset") as bar,
-        concurrent.futures.ThreadPoolExecutor(max_workers=worker) as pool,
-    ):
-
-        def _one(flocal: str, furl: str):
-            http_download_file(furl, flocal)
-            bar.update()
-
-        futures = [pool.submit(_one, p, u) for p, u in files]
-        for fut in concurrent.futures.as_completed(futures):
-            fut.result()
+    thread_map(http_download_file_wrapper, files, max_workers=worker)
 
 
 def zip_dir(dir_path: str, zip_path: str):
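dataset_download now hands its (local_path, url) pairs straight to tqdm's thread_map instead of managing a ThreadPoolExecutor and progress bar by hand; http_download_file_wrapper (added in aihub/utils/http.py below) unpacks each pair. A small sketch of the same pattern in isolation, with placeholder paths and URLs:

from tqdm.contrib.concurrent import thread_map

from aihub.utils.http import http_download_file_wrapper

# Each item is a (destination_path, url) tuple, matching what dataset_download builds.
files = [
    ("/tmp/ds/a.jpg", "https://example.com/a.jpg"),  # placeholder
    ("/tmp/ds/b.jpg", "https://example.com/b.jpg"),  # placeholder
]

# thread_map runs the wrapper across a thread pool and shows a progress bar.
thread_map(http_download_file_wrapper, files, max_workers=4)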
aihub/utils/http.py
CHANGED
@@ -5,6 +5,12 @@ import os
 import requests
 
 
+def http_download_file_wrapper(item):
+    """Wrapper function"""
+    dst_path, url = item
+    return http_download_file(url, dst_path)
+
+
 def http_download_file(url: str, dst_path: str, chunk: int = 1 << 16) -> None:
     os.makedirs(os.path.dirname(dst_path), exist_ok=True)
     with requests.get(url, timeout=None, stream=True) as r:
{intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
-aihub/__init__.py,sha256=
+aihub/__init__.py,sha256=qb0TalpSt1CbprnFyeLUKqgrqNtmnk9IoQQ7umAoXVY,23
 aihub/client.py,sha256=nVELjkyVOG6DKJjurYn59fCoT5JsSayUweiH7bvKcAo,5547
 aihub/exceptions.py,sha256=l2cMAvipTqQOio3o11fXsCCSCevbuK4PTsxofkobFjk,500
 aihub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/models/artifact.py,sha256=
+aihub/models/artifact.py,sha256=F-r7DJY9A09yIQJqWol6gLRu6y7NGjRa6-BxkMEluxU,4655
 aihub/models/common.py,sha256=qmabc2LkAdQJXIcpT1P35zxd0Lc8yDYdD4ame1iF4Bs,241
 aihub/models/data_warehouse.py,sha256=zXvWwg7ySoFJMdqQ_1UMTNEKDMhu1hDHlWdBAXdizBk,3905
-aihub/models/dataset_management.py,sha256=
+aihub/models/dataset_management.py,sha256=4DuQ0zM7jv73SJiqvieHLtn2Y-T6FIFV9r7bgzyCtDo,10790
 aihub/models/document_center.py,sha256=od9bzx6krAS6ktIA-ChxeqGcch0v2wsS1flY2vuHXBc,1340
 aihub/models/eval.py,sha256=4Gon4Sg4dOkyCx3KH2mO5ip3AhrBwrPC0UZA447HeoQ,910
 aihub/models/labelfree.py,sha256=nljprYO6ECuctTVbHqriQ73N5EEyYURhBrnU28Ngfvc,1589
@@ -17,9 +17,9 @@ aihub/models/task_center.py,sha256=HE21Q4Uj-vt9LHezHnqBYgdinhrh4iJPOq8VXbSMllU,5
 aihub/models/user_system.py,sha256=0L_pBkWL9v1tv_mclOyRgCyWibuuj_XU3mPoe2v48vQ,12216
 aihub/models/workflow_center.py,sha256=4xtI1WZ38ceXJ8gwDBj-QNjOiRlLO_8kGiQybdudJPY,20121
 aihub/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/services/artifact.py,sha256=
+aihub/services/artifact.py,sha256=lfOrgOT2AlH1w-75NLcQGOhVWdhmJcWD1gESPpUzqUw,11257
 aihub/services/data_warehouse.py,sha256=awvlJdggo8ph6sXweXXVp4GLRuUSD46LoD0QQksXRts,2964
-aihub/services/dataset_management.py,sha256=
+aihub/services/dataset_management.py,sha256=R7mFsJ1dNOI_p5yNj_rQdLolRC0UKEN4WejE7uOjVlE,21379
 aihub/services/document_center.py,sha256=dG67Ji-DOnzL2t-4x4gVfMt9fbSj_IjVHCLw5R-VTkQ,1813
 aihub/services/eval.py,sha256=V1nBISIyYWg9JJO24xzy4-kit9NsaCYp1EWIX_fgJkQ,2128
 aihub/services/labelfree.py,sha256=xua62UWhVXTxJjHRyy86waaormnJjmpQwepcARBy_h0,1450
@@ -32,11 +32,12 @@ aihub/services/task_center.py,sha256=rVQG7q2_GN0501w5KHsOOlSVFX9ovpRMGX5hskCqggw
 aihub/services/user_system.py,sha256=IqWL4bnsKyyzuGT5l6adnw0qNXlH9PSo1-C_pFyOSzA,18868
 aihub/services/workflow_center.py,sha256=caKxOlba0J1s1RUK6RUm1ndJSwAcZXEakRanu3sGKPU,17468
 aihub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/utils/
-aihub/utils/
+aihub/utils/di.py,sha256=vFUzno5WbRKu6-pj8Hnz9IqT7xb9UDZQ4qpOFH1YAtM,11812
+aihub/utils/download.py,sha256=ZZVbcC-PnN3PumV7ZiJ_-srkt4HPPovu2F6Faa2RrPE,1830
+aihub/utils/http.py,sha256=AmfHHNjptuuSFx2T1twWCnerR_hLN_gd0lUs8z36ERA,547
 aihub/utils/s3.py,sha256=ISIBP-XdBPkURpXnN56ZnIWokOOg2SRUh_qvxJk-G1Q,2187
-intellif_aihub-0.1.
-intellif_aihub-0.1.
-intellif_aihub-0.1.
-intellif_aihub-0.1.
-intellif_aihub-0.1.
+intellif_aihub-0.1.15.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+intellif_aihub-0.1.15.dist-info/METADATA,sha256=Hz8Z3sB06pNTJF8lygzDU37da2bCgXCrzJ1-CRAlN7Y,2949
+intellif_aihub-0.1.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+intellif_aihub-0.1.15.dist-info/top_level.txt,sha256=vIvTtSIN73xv46BpYM-ctVGnyOiUQ9EWP_6ngvdIlvw,6
+intellif_aihub-0.1.15.dist-info/RECORD,,
{intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/WHEEL
File without changes
{intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/licenses/LICENSE
File without changes
{intellif_aihub-0.1.14.dist-info → intellif_aihub-0.1.15.dist-info}/top_level.txt
File without changes