intellif-aihub 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

aihub/__init__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.1.13"
+__version__ = "0.1.15"
aihub/models/artifact.py CHANGED
@@ -14,6 +14,7 @@ from pydantic import BaseModel, Field
 
 class ArtifactType(str, Enum):
     """制品类型枚举:"dataset"-数据集类型;"model"-模型类型;"metrics"-指标类型;"log"-日志类型;"checkpoint"-检查点类型;"image"-图像类型;"prediction"-预测结果类型;"other"-其他类型"""
+
     dataset = "dataset"  # 数据集类型
     model = "model"  # 模型类型
     metrics = "metrics"  # 指标类型
@@ -26,24 +27,29 @@ class ArtifactType(str, Enum):
 
 class CreateArtifactsReq(BaseModel):
     """创建制品请求"""
+
     entity_id: str = Field(alias="entity_id", description="实体ID,通常是运行ID,用于关联制品与特定运行")
-    entity_type: ArtifactType = Field(default=ArtifactType.other, alias="entity_type",
-                                      description="制品类型,指定制品的类型,默认为other")
+    entity_type: ArtifactType = Field(
+        default=ArtifactType.other, alias="entity_type", description="制品类型,指定制品的类型,默认为other"
+    )
     src_path: str = Field(alias="src_path", description="源路径,制品在系统中的路径标识")
-    is_dir: bool = Field(default=False, alias="is_dir",
-                         description="是否为目录,True表示制品是一个目录,False表示是单个文件")
+    is_dir: bool = Field(
+        default=False, alias="is_dir", description="是否为目录,True表示制品是一个目录,False表示是单个文件"
+    )
 
     model_config = {"use_enum_values": True}
 
 
 class CreateArtifactsResponseData(BaseModel):
     """创建制品响应数据"""
+
     id: int = Field(description="制品ID")
     s3_path: str = Field(alias="s3_path", description="S3存储路径")
 
 
 class CreateArtifactsResponseModel(BaseModel):
     """创建制品响应模型"""
+
     code: int = Field(description="响应码,0表示成功")
     msg: str = Field(default="", description="响应消息")
     data: Optional[CreateArtifactsResponseData] = Field(default=None, description="响应数据")
@@ -51,6 +57,7 @@ class CreateArtifactsResponseModel(BaseModel):
 
 class CreateEvalReq(BaseModel):
     """创建评估请求"""
+
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
     dataset_version_id: int = Field(alias="dataset_version_id", description="数据集版本ID")
     prediction_artifact_path: str = Field(alias="prediction_artifact_path", description="预测结果制品路径")
@@ -62,6 +69,7 @@ class CreateEvalReq(BaseModel):
 
 class ArtifactResp(BaseModel):
     """制品响应模型,表示一个制品的详细信息"""
+
     id: int = Field(description="制品ID")
     entity_type: str = Field(alias="entity_type", description="实体类型")
     entity_id: str = Field(alias="entity_id", description="实体ID")
@@ -72,6 +80,7 @@
 
 class ArtifactRespData(BaseModel):
     """制品分页数据"""
+
     total: int = Field(description="总记录数")
     page_size: int = Field(alias="page_size", description="每页大小")
     page_num: int = Field(alias="page_num", description="页码")
@@ -80,6 +89,7 @@ class ArtifactRespData(BaseModel):
 
 class ArtifactRespModel(BaseModel):
     """获取制品响应模型"""
+
     code: int = Field(description="响应码,0表示成功")
     msg: str = Field(default="", description="响应消息")
     data: ArtifactRespData = Field(description="响应数据")
@@ -91,8 +101,10 @@ InfinityPageSize = 10000 * 100
 
 class StsResp(BaseModel):
     """STS 临时凭证"""
+
     access_key_id: Optional[str] = Field(default=None, alias="access_key_id", description="访问密钥ID")
     secret_access_key: Optional[str] = Field(default=None, alias="secret_access_key", description="秘密访问密钥")
     session_token: Optional[str] = Field(default=None, alias="session_token", description="会话令牌")
     expiration: Optional[int] = Field(default=None, alias="expiration", description="过期时间")
     endpoint: Optional[str] = Field(default=None, alias="endpoint", description="端点URL")
+    bucket: Optional[str] = Field(default=None, alias="bucket", description="存储桶名称")
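
For illustration only (editor's sketch, not content from either wheel): the new bucket field on StsResp means a caller can build an S3 client purely from the STS response, which is how the new dataset-management code consumes it. A minimal sketch, assuming sts is an already-fetched StsResp and using the same minio client library the new aihub/utils/di.py relies on:

    import minio

    client = minio.Minio(
        sts.endpoint,
        access_key=sts.access_key_id,
        secret_key=sts.secret_access_key,
        session_token=sts.session_token,
        secure=False,  # matches the setting used in aihub/utils/di.py
    )
    # list objects in the bucket that came back with the credentials
    for obj in client.list_objects(sts.bucket, prefix="dataset_workspace/", recursive=True):
        print(obj.object_name)
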
aihub/models/dataset_management.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
 
 class DatasetVersionStatus(IntEnum):
     """数据集版本状态:1-等待中;2-运行中;3-成功;4-失败;5-加载meta;6-构建index"""
+
     Waiting = 1  # 等待中
     Running = 2  # 运行中
     Success = 3  # 成功
@@ -18,6 +19,7 @@ class DatasetVersionStatus(IntEnum):
 
 class UploadType(IntEnum):
     """上传类型:1-本地上传;3-服务器路径上传;4-Labelfree;5-数据接入"""
+
     LOCAL = 1  # 本地上传
     SERVER_PATH = 3  # 服务器路径上传
     LABELFREE = 4  # Labelfree
@@ -26,6 +28,7 @@ class UploadType(IntEnum):
 
 class CreateDatasetRequest(BaseModel):
     """创建数据集请求"""
+
     name: str = Field(description="数据集名称")
     description: str = Field(description="数据集描述")
     tags: List[int] = Field(description="标签ID列表,通过标签管理系统查询")
@@ -37,11 +40,13 @@ class CreateDatasetRequest(BaseModel):
 
 class CreateDatasetResponse(BaseModel):
     """创建数据集返回"""
+
     id: int = Field(alias="id", description="数据集ID")
 
 
 class DatasetVersionBase(BaseModel):
     """数据集版本概要"""
+
     id: int = Field(description="版本ID")
     version: int = Field(description="版本号")
     status: DatasetVersionStatus = Field(description="版本状态")
@@ -53,6 +58,7 @@ class DatasetVersionBase(BaseModel):
 
 class DatasetDetail(BaseModel):
     """数据集详情"""
+
     id: int = Field(description="数据集 ID")
     name: str = Field(description="名称")
     description: str = Field(description="描述")
@@ -69,6 +75,7 @@ class DatasetDetail(BaseModel):
 
 class ExtInfo(BaseModel):
     """扩展信息"""
+
     rec_file_path: Optional[str] = Field(None, alias="rec_file_path", description="rec文件路径")
     idx_file_path: Optional[str] = Field(None, alias="idx_file_path", description="idx文件路径")
     json_file_path: Optional[str] = Field(None, alias="json_file_path", description="json文件路径")
@@ -77,6 +84,7 @@ class ExtInfo(BaseModel):
 
 class CreateDatasetVersionRequest(BaseModel):
     """创建版本请求"""
+
     upload_path: str = Field(alias="upload_path", description="上传路径")
     description: Optional[str] = Field(None, description="版本描述")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -91,11 +99,13 @@ class CreateDatasetVersionRequest(BaseModel):
 
 class CreateDatasetVersionResponse(BaseModel):
     """创建版本返回"""
+
     id: int = Field(alias="id", description="版本ID")
 
 
 class UploadDatasetVersionRequest(BaseModel):
     """上传数据集版本请求"""
+
     upload_path: str = Field(alias="upload_path", description="上传目录")
     upload_type: UploadType = Field(alias="upload_type", description="上传类型")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -107,11 +117,13 @@ class UploadDatasetVersionRequest(BaseModel):
 
 class UploadDatasetVersionResponse(BaseModel):
     """上传数据集版本返回"""
+
     id: int = Field(alias="id", description="版本ID")
 
 
 class DatasetVersionDetail(BaseModel):
     """数据集版本详情"""
+
     id: int = Field(description="版本ID")
     version: int = Field(description="版本号")
     dataset_id: int = Field(alias="dataset_id", description="数据集ID")
@@ -133,5 +145,84 @@ class DatasetVersionDetail(BaseModel):
 
 class FileUploadData(BaseModel):
     """文件上传数据"""
+
     path: str = Field(description="路径")
     url: str = Field(description="URL")
+
+
+class ListDatasetReq(BaseModel):
+    """列表查询数据集请求(使用 dataset_management v2)"""
+    page_size: int = Field(20, alias="page_size", description="每页大小,默认20")
+    page_num: int = Field(1, alias="page_num", description="页码,从1开始")
+    name: Optional[str] = Field(None, description="数据集名称筛选")
+    tags: Optional[str] = Field(None, description="标签筛选")
+    create_by: Optional[int] = Field(None, alias="create_by", description="创建人筛选")
+    scope: Optional[str] = Field("all", description="范围筛选:created|shared|all")
+
+
+class ListDatasetItem(BaseModel):
+    """列表数据集项"""
+    id: int = Field(description="数据集ID")
+    name: str = Field(description="数据集名称")
+    description: str = Field(description="数据集描述")
+    cover_img: str = Field(alias="cover_img", description="封面图片")
+    created_at: int = Field(alias="created_at", description="创建时间戳")
+    updated_at: int = Field(alias="update_at", description="更新时间戳")
+    user_id: int = Field(alias="user_id", description="创建人ID")
+    username: str = Field(description="创建人用户名")
+    tags: Optional[List[int]] = Field(None, description="标签列表")
+    access_user_ids: Optional[List[int]] = Field(None, alias="access_user_ids", description="有访问权限的用户ID列表")
+    is_private: bool = Field(alias="is_private", description="是否私有")
+
+
+class ListDatasetResp(BaseModel):
+    """列表查询数据集响应"""
+    total: int = Field(description="总数")
+    page_size: int = Field(alias="page_size", description="每页大小")
+    page_num: int = Field(alias="page_num", description="当前页码")
+    data: List[ListDatasetItem] = Field(description="数据集列表")
+
+
+class ListDatasetVersionReq(BaseModel):
+    """列表查询数据集版本请求(使用 dataset_management v2)"""
+    page_size: int = Field(10000000, alias="page_size", description="每页大小,默认10000000")
+    page_num: int = Field(1, alias="page_num", description="页码,从1开始")
+    dataset_id: Optional[int] = Field(None, alias="dataset_id", description="数据集ID筛选")
+    dataset_version_ids: Optional[str] = Field(None, alias="dataset_version_ids", description="数据集版本ID列表,逗号分隔")
+
+
+class ListDatasetVersionItem(BaseModel):
+    """列表数据集版本项"""
+    id: int = Field(description="版本ID")
+    version: int = Field(description="版本号")
+    dataset_id: int = Field(alias="dataset_id", description="数据集ID")
+    upload_path: str = Field(alias="upload_path", description="上传路径")
+    upload_type: int = Field(alias="upload_type", description="上传类型")
+    parent_version_id: Optional[int] = Field(None, alias="parent_version_id", description="父版本ID")
+    description: Optional[str] = Field(None, description="版本描述")
+    status: int = Field(description="版本状态")
+    message: str = Field(description="状态信息")
+    created_at: int = Field(alias="created_at", description="创建时间戳")
+    user_id: int = Field(alias="user_id", description="创建人ID")
+    data_size: int = Field(alias="data_size", description="数据大小")
+    data_count: int = Field(alias="data_count", description="数据条数")
+    username: str = Field(description="创建人用户名")
+    dataset_name: str = Field(alias="dataset_name", description="数据集名称")
+
+
+class ListDatasetVersionResp(BaseModel):
+    """列表查询数据集版本响应"""
+    total: int = Field(description="总数")
+    page_size: int = Field(alias="page_size", description="每页大小")
+    page_num: int = Field(alias="page_num", description="当前页码")
+    data: List[ListDatasetVersionItem] = Field(description="数据集版本列表")
+
+
+class CreateDatasetVersionByDataIngestReqV2(BaseModel):
+    """通过数据集成创建数据集版本请求"""
+
+    description: Optional[str] = Field(None, description="描述")
+    dataset_id: int = Field(..., description="数据集ID")
+    s3_object_sheet: str = Field(..., description="S3对象表")
+    object_cnt: Optional[int] = Field(None, description="对象数量")
+    data_size: Optional[int] = Field(None, description="数据大小")
aihub/services/artifact.py CHANGED
@@ -98,9 +98,7 @@ class ArtifactService:
         """
         return self._artifact.get_sts()
 
-    def get_by_run_id(
-        self, run_id: str, artifact_path: Optional[str] = None
-    ) -> List[ArtifactResp]:
+    def get_by_run_id(self, run_id: str, artifact_path: Optional[str] = None) -> List[ArtifactResp]:
         """根据运行ID获取制品列表
 
         Args:
@@ -116,11 +114,11 @@ class ArtifactService:
         return self._artifact.get_by_run_id(run_id, artifact_path)
 
     def create_artifact(
-            self,
-            local_path: str,
-            artifact_path: Optional[str] = None,
-            run_id: Optional[str] = None,
-            artifact_type: ArtifactType = ArtifactType.other,
+        self,
+        local_path: str,
+        artifact_path: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifact_type: ArtifactType = ArtifactType.other,
     ) -> None:
         """创建单个文件制品并上传
 
@@ -171,11 +169,11 @@ class ArtifactService:
         return
 
     def create_artifacts(
-            self,
-            local_dir: str,
-            artifact_path: Optional[str] = None,
-            run_id: Optional[str] = None,
-            artifact_type: ArtifactType = ArtifactType.other,
+        self,
+        local_dir: str,
+        artifact_path: Optional[str] = None,
+        run_id: Optional[str] = None,
+        artifact_type: ArtifactType = ArtifactType.other,
     ) -> None:
         """创建目录制品并上传
 
@@ -223,9 +221,7 @@ class ArtifactService:
         logger.info(f"log artifact done: {artifact_path}")
         return
 
-    def download_artifacts(
-        self, run_id: str, artifact_path: Optional[str], local_dir: str
-    ) -> None:
+    def download_artifacts(self, run_id: str, artifact_path: Optional[str], local_dir: str) -> None:
         """下载制品
 
         Args:
@@ -252,9 +248,7 @@ class ArtifactService:
             if artifact_item.is_dir:
                 download_dir_from_s3(self.s3_client, bucket, object_name, local_dir)
             else:
-                self.s3_client.fget_object(
-                    bucket, object_name, str(Path(local_dir) / artifact_item.src_path)
-                )
+                self.s3_client.fget_object(bucket, object_name, str(Path(local_dir) / artifact_item.src_path))
 
         logger.info(f"download artifact done: {artifact_path}")
         return
@@ -311,9 +305,7 @@ class _Artifact:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return
 
-    def get_by_run_id(
-        self, run_id: str, artifact_path: Optional[str]
-    ) -> List[ArtifactResp]:
+    def get_by_run_id(self, run_id: str, artifact_path: Optional[str]) -> List[ArtifactResp]:
         """根据运行ID获取制品列表
 
         Args:
@@ -326,18 +318,12 @@ class _Artifact:
         Raises:
             APIError: 当API调用失败时抛出
         """
-        resp = self._http.get(
-            f"{_Base}/artifacts?entity_id={run_id}&page_num=1&page_size={InfinityPageSize}"
-        )
+        resp = self._http.get(f"{_Base}/artifacts?entity_id={run_id}&page_num=1&page_size={InfinityPageSize}")
         wrapper = APIWrapper[ArtifactRespData].model_validate(resp.json())
         if wrapper.code != 0:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         if artifact_path:
-            return [
-                artifact
-                for artifact in wrapper.data.data
-                if artifact.src_path == artifact_path
-            ]
+            return [artifact for artifact in wrapper.data.data if artifact.src_path == artifact_path]
         else:
             return wrapper.data.data
 
aihub/services/dataset_management.py CHANGED
@@ -7,6 +7,8 @@
 - **创建数据集及其版本**(支持本地上传和服务器现有文件两种方式)
 - **上传文件到对象存储**(大文件自动分片)
 - **查询数据集/数据集版本详情**
+- **列表查询和搜索数据集**(支持分页和筛选)
+- **列表查询数据集版本**(支持按数据集ID筛选和分页)
 - **按版本名称或ID下载数据集文件**
 """
 
@@ -15,7 +17,6 @@ from __future__ import annotations
 import mimetypes
 import os
 import pathlib
-import tempfile
 import time
 import uuid
 
@@ -23,6 +24,7 @@ import httpx
 from loguru import logger
 
 from ..exceptions import APIError
+from ..models.artifact import StsResp
 from ..models.common import APIWrapper
 from ..models.dataset_management import (
     CreateDatasetRequest,
@@ -34,9 +36,16 @@ from ..models.dataset_management import (
     DatasetVersionDetail,
     UploadDatasetVersionResponse,
     FileUploadData,
+    ListDatasetReq,
+    ListDatasetResp,
+    ListDatasetVersionReq,
+    ListDatasetVersionResp,
+    CreateDatasetVersionByDataIngestReqV2,
+    UploadType,
 )
 from ..models.dataset_management import DatasetVersionStatus
-from ..utils.download import dataset_download, zip_dir
+from ..utils.di import SimpleS3Client, DataUploader
+from ..utils.download import dataset_download
 
 _BASE = "/dataset-mng/api/v2"
 
@@ -132,20 +141,29 @@ class DatasetManagementService:
     def dataset(self) -> _Dataset:
         return self._dataset
 
+    def _get_sts(self) -> StsResp:
+        return self.dataset_version.get_sts()
+
     @property
     def dataset_version(self) -> _DatasetVersion:
         return self._dataset_version
 
+    def upload_by_data_ingest(
+        self,
+        req: CreateDatasetVersionByDataIngestReqV2,
+    ) -> CreateDatasetVersionResponse:
+        return self.dataset_version.upload_by_data_ingest(req)
+
     def create_dataset_and_version(
-            self,
-            *,
-            dataset_name: str,
-            dataset_description: str = "",
-            is_local_upload: bool,
-            local_file_path: str | None = None,
-            server_file_path: str | None = None,
-            version_description: str = "",
-            timeout: int = 1_800,
+        self,
+        *,
+        dataset_name: str,
+        dataset_description: str = "",
+        is_local_upload: bool,
+        local_file_path: str | None = None,
+        server_file_path: str | None = None,
+        version_description: str = "",
+        timeout: int = 1_800,
     ) -> tuple[int, int, str]:
         """创建数据集及其版本,并等待版本状态变为 *Success*。
 
@@ -163,17 +181,51 @@ class DatasetManagementService:
 
         Returns:
             tuple[int, int, str]: 一个三元组,包含:[数据集 ID,数据集版本 ID, 数据集版本标签(格式为 <dataset_name>/V<version_number>)]
+
+        Raises:
+            ValueError: 当参数不满足要求时
+            APIError: 当后端返回错误时
+            TimeoutError: 当等待超时时
         """
+        # 参数校验
+        self._validate_create_params(is_local_upload, local_file_path, server_file_path)
+
+        # 创建数据集
+        dataset_id = self._create_dataset(dataset_name, dataset_description)
+        logger.info(f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据")
+
+        # 创建数据集版本
+        version_id = self._create_dataset_version(
+            dataset_id=dataset_id,
+            is_local_upload=is_local_upload,
+            local_file_path=local_file_path,
+            server_file_path=server_file_path,
+            version_description=version_description,
+        )
+
+        # 获取版本标签
+        version_tag = self._get_version_tag(dataset_id, version_id)
+        logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
+
+        # 轮询等待版本状态变为成功
+        self._wait_for_version_success(version_id, timeout)
+
+        return dataset_id, version_id, version_tag
+
+    def _validate_create_params(
+        self, is_local_upload: bool, local_file_path: str | None, server_file_path: str | None
+    ) -> None:
+        """验证创建数据集和版本所需的参数"""
         if is_local_upload:
             if not local_file_path:
                 raise ValueError("is_local_upload=True 时必须提供 local_file_path")
-            upload_type = 1
         else:
             if not server_file_path:
                 raise ValueError("is_local_upload=False 时必须提供 server_file_path")
-            upload_type = 3
 
-        dataset_id = self._dataset.create(
+    def _create_dataset(self, dataset_name: str, dataset_description: str) -> int:
+        """创建数据集"""
+        return self._dataset.create(
             CreateDatasetRequest(
                 name=dataset_name,
                 description=dataset_description,
@@ -184,39 +236,96 @@ class DatasetManagementService:
                 access_user_ids=None,
             )
         )
-        logger.info(
-            f"创建数据集成功,名称为 {dataset_name} ,开始准备创建版本、上传数据"
-        )
 
+    def _create_dataset_version(
+        self,
+        dataset_id: int,
+        is_local_upload: bool,
+        local_file_path: str | None,
+        server_file_path: str | None,
+        version_description: str,
+    ) -> int:
+        """根据上传类型创建数据集版本"""
         if is_local_upload:
-            # 上传文件,检查是文件夹还是zip
-            local_file_path = pathlib.Path(local_file_path)
-            if local_file_path.is_dir():
-                # 把文件夹打包为一个 zip
-                temp_zip_path = (
-                    pathlib.Path(tempfile.mkdtemp()) / f" {uuid.uuid4().hex}.zip"
-                )
-                zip_dir(local_file_path, temp_zip_path)
-                upload_data = self._upload.upload_file(temp_zip_path)
-                os.remove(temp_zip_path)
-            else:
-                upload_data = self._upload.upload_file(local_file_path)
-
-            upload_path = upload_data.path
+            return self._create_local_dataset_version(dataset_id, local_file_path, version_description)
         else:
-            upload_path = server_file_path
-        logger.info(f"文件上传成功:{local_file_path}")
+            return self._create_server_dataset_version(dataset_id, server_file_path, version_description)
+
+    def _create_local_dataset_version(
+        self, dataset_id: int, local_file_path: str | None, version_description: str
+    ) -> int:
+        """创建本地文件数据集版本"""
+        if pathlib.Path(local_file_path).is_dir():
+            return self._create_local_dir_dataset_version(dataset_id, local_file_path)
+        elif pathlib.Path(local_file_path).is_file():
+            return self._create_local_file_dataset_version(dataset_id, local_file_path, version_description)
+        else:
+            raise ValueError(f"本地路径既不是文件也不是目录: {local_file_path}")
 
-        version_id = self._dataset_version.upload(
+    def _create_local_dir_dataset_version(self, dataset_id: int, local_file_path: str) -> int:
+        """处理本地目录上传"""
+        sts = self._get_sts()
+        s3_client = SimpleS3Client(
+            sts.endpoint, sts.access_key_id, sts.secret_access_key, session_token=sts.session_token
+        )
+        uid = uuid.uuid4().hex
+        s3_target = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}"
+        s3_csv_path = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}.csv"
+        s3_status_path = f"s3://{sts.bucket}/dataset_workspace/{dataset_id}/{uid}.json"
+
+        # 创建上传器并执行
+        uploader = DataUploader(
+            task_id=dataset_id,
+            local_path=str(local_file_path),
+            s3_target=s3_target,
+            csv_path=s3_csv_path,
+            status_path=s3_status_path,
+            num_workers=40,
+        )
+
+        upload_stats = uploader.run(s3_client)
+        req = CreateDatasetVersionByDataIngestReqV2(
+            description=f"sdk 上传",
+            dataset_id=dataset_id,
+            s3_object_sheet=s3_csv_path,
+            object_cnt=upload_stats.uploaded_count,
+            data_size=upload_stats.uploaded_size,
+        )
+        return self.upload_by_data_ingest(req).id
+
+    def _create_local_file_dataset_version(
+        self, dataset_id: int, local_file_path: str, version_description: str
+    ) -> int:
+        """处理本地文件上传"""
+        upload_data = self._upload.upload_file(local_file_path)
+        upload_path = upload_data.path
+        logger.info(f"文件上传成功:{local_file_path}")
+        return self._dataset_version.upload(
             UploadDatasetVersionRequest(
                 upload_path=upload_path,
-                upload_type=upload_type,
+                upload_type=UploadType.LOCAL,  # 本地上传类型
                 dataset_id=dataset_id,
                 description=version_description,
                 parent_version_id=0,
             )
         )
 
+    def _create_server_dataset_version(
+        self, dataset_id: int, server_file_path: str | None, version_description: str
+    ) -> int:
+        """创建服务器文件数据集版本"""
+        return self._dataset_version.upload(
+            UploadDatasetVersionRequest(
+                upload_path=server_file_path,
+                upload_type=UploadType.SERVER_PATH,  # 服务器文件上传类型
+                dataset_id=dataset_id,
+                description=version_description,
+                parent_version_id=0,
+            )
+        )
+
+    def _get_version_tag(self, dataset_id: int, version_id: int) -> str:
+        """获取版本标签"""
         detail = self._dataset.get(dataset_id)
         ver_num = next(
             (v.version for v in detail.versions if v.id == version_id),
@@ -225,9 +334,10 @@ class DatasetManagementService:
         if ver_num is None:
             ver_num = 1
 
-        version_tag = f"{detail.name}/V{ver_num}"
-        logger.info(f"数据集版本创建成功,名称为 {version_tag},开始轮询状态…")
+        return f"{detail.name}/V{ver_num}"
 
+    def _wait_for_version_success(self, version_id: int, timeout: int) -> None:
+        """轮询等待版本状态变为成功"""
         start_ts = time.time()
         poll_interval = 10
 
@@ -249,8 +359,6 @@ class DatasetManagementService:
             logger.debug(f"已等待 {elapsed:.0f}s,继续轮询…")
             time.sleep(poll_interval)
 
-        return dataset_id, version_id, version_tag
-
     def run_download(self, dataset_version_name: str, local_dir: str, worker: int = 4) -> None:
         """根据数据集版本名称下载对应的数据集文件。
 
@@ -270,6 +378,66 @@ class DatasetManagementService:
             raise APIError("parquet_index_path 为空")
         dataset_download(detail.parquet_index_path, local_dir, worker)
 
+    def list_datasets(
+        self,
+        *,
+        page_size: int = 20,
+        page_num: int = 1,
+        name: str | None = None,
+        tags: str | None = None,
+        create_by: int | None = None,
+        scope: str = "all"
+    ) -> ListDatasetResp:
+        """列表查询数据集
+
+        Args:
+            page_size: 每页大小,默认20
+            page_num: 页码,从1开始,默认1
+            name: 数据集名称筛选,可选
+            tags: 标签筛选,可选
+            create_by: 创建人筛选,可选
+            scope: 范围筛选:created|shared|all,默认all
+
+        Returns:
+            ListDatasetResp: 数据集列表响应,包含分页信息和数据集列表
+        """
+        payload = ListDatasetReq(
+            page_size=page_size,
+            page_num=page_num,
+            name=name,
+            tags=tags,
+            create_by=create_by,
+            scope=scope
+        )
+        return self._dataset.list_datasets(payload)
+
+    def list_dataset_versions(
+        self,
+        *,
+        page_size: int = 10000000,
+        page_num: int = 1,
+        dataset_id: int | None = None,
+        dataset_version_ids: str | None = None
+    ) -> ListDatasetVersionResp:
+        """列表查询数据集版本
+
+        Args:
+            page_size: 每页大小,默认10000000
+            page_num: 页码,从1开始,默认1
+            dataset_id: 数据集ID筛选,可选
+            dataset_version_ids: 数据集版本ID列表,逗号分隔,可选
+
+        Returns:
+            ListDatasetVersionResp: 数据集版本列表响应,包含分页信息和数据集版本列表
+        """
+        payload = ListDatasetVersionReq(
+            page_size=page_size,
+            page_num=page_num,
+            dataset_id=dataset_id,
+            dataset_version_ids=dataset_version_ids
+        )
+        return self._dataset_version.list_dataset_versions(payload)
+
 
 class _Dataset:
     def __init__(self, http: httpx.Client):
@@ -292,6 +460,15 @@ class _Dataset:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return wrapper.data
 
+    def list_datasets(self, payload: ListDatasetReq) -> ListDatasetResp:
+        """列表查询数据集"""
+        params = payload.model_dump(by_alias=True, exclude_none=True)
+        resp = self._http.get(f"{_BASE}/datasets", params=params)
+        wrapper = APIWrapper[ListDatasetResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
 
 class _DatasetVersion:
     def __init__(self, http: httpx.Client):
@@ -325,14 +502,55 @@ class _DatasetVersion:
         return wrapper.data
 
     def get_by_name(self, version_name: str) -> DatasetVersionDetail:
-        resp = self._http.get(
-            f"{_BASE}/dataset-versions-detail", params={"name": version_name}
-        )
+        resp = self._http.get(f"{_BASE}/dataset-versions-detail", params={"name": version_name})
         wrapper = APIWrapper[DatasetVersionDetail].model_validate(resp.json())
         if wrapper.code != 0:
             raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
         return wrapper.data
 
+    def list_dataset_versions(self, payload: ListDatasetVersionReq) -> ListDatasetVersionResp:
+        """列表查询数据集版本"""
+        params = payload.model_dump(by_alias=True, exclude_none=True)
+        resp = self._http.get(f"{_BASE}/dataset-versions", params=params)
+        wrapper = APIWrapper[ListDatasetVersionResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
+    def get_sts(self) -> StsResp:
+        """获取STS临时凭证
+
+        获取用于访问S3存储的临时凭证。
+
+        Returns:
+            StsResp: STS临时凭证信息
+
+        Raises:
+            APIError: 当API调用失败时抛出
+        """
+        resp = self._http.get(f"{_BASE}/dataset-versions/get-sts")
+        logger.info(f"get sts: {resp.text}")
+        wrapper = APIWrapper[StsResp].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
+    def upload_by_data_ingest(self, req: CreateDatasetVersionByDataIngestReqV2) -> CreateDatasetVersionResponse:
+        """上传数据集版本(数据集导入)
+        Args:
+            req
+
+        """
+        resp = self._http.post(
+            f"{_BASE}/dataset-versions/data-ingest",
+            json=req.model_dump(),
+        )
+        logger.debug(f"upload_by_data_ingest: {resp.text}")
+        wrapper = APIWrapper[CreateDatasetVersionResponse].model_validate(resp.json())
+        if wrapper.code != 0:
+            raise APIError(f"backend code {wrapper.code}: {wrapper.msg}")
+        return wrapper.data
+
 
 class _Upload:
     def __init__(self, http: httpx.Client):
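
For illustration only (editor's sketch, not content from either wheel): the listing helpers added to DatasetManagementService above can be driven roughly as follows; how the service instance is obtained from the SDK client is an assumption here and is not shown in this diff.

    # hypothetical handle; the real attribute name on the SDK client may differ
    dm = client.dataset_management
    page = dm.list_datasets(page_size=20, page_num=1, name="faces", scope="created")
    for ds in page.data:
        print(ds.id, ds.name, ds.username)
    # fetch every version of the first matching dataset in one call
    versions = dm.list_dataset_versions(dataset_id=page.data[0].id)
    print(versions.total)
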
aihub/utils/di.py ADDED
@@ -0,0 +1,337 @@
+import argparse
+import csv
+import hashlib
+import json
+import os
+import queue
+import sys
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, Any, Tuple
+
+import minio
+from loguru import logger
+
+
+class UploadStatus:
+    """上传状态类"""
+
+    def __init__(self):
+        self.uploaded_count = 0
+        self.uploaded_size = 0
+        self.updated_at = int(time.time() * 1000)
+
+    def update(self, count: int, size: int):
+        self.uploaded_count += count
+        self.uploaded_size += size
+        self.updated_at = int(time.time() * 1000)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "uploaded_count": self.uploaded_count,
+            "uploaded_size": self.uploaded_size,
+            "updated_at": self.updated_at,
+        }
+
+
+class SimpleS3Client:
+    """简化的S3客户端"""
+
+    def __init__(self, endpoint: str, access_key: str, secret_key: str, session_token: str):
+        self.client = minio.Minio(
+            endpoint, access_key=access_key, secret_key=secret_key, secure=False, session_token=session_token
+        )
+
+    def upload_file(self, local_path: str, bucket: str, object_name: str) -> Tuple[str, int]:
+        """上传文件并返回哈希和大小"""
+        file_size = os.path.getsize(local_path)
+
+        # 计算文件哈希
+        sha256_hash = hashlib.sha256()
+        with open(local_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256_hash.update(chunk)
+
+        file_hash = sha256_hash.hexdigest()
+
+        # 上传文件
+        with open(local_path, "rb") as f:
+            self.client.put_object(bucket, object_name, f, file_size)
+
+        return file_hash, file_size
+
+    def upload_json(self, data: Dict[str, Any], bucket: str, object_name: str):
+        """上传JSON数据"""
+        json_str = json.dumps(data)
+        json_bytes = json_str.encode("utf-8")
+
+        from io import BytesIO
+
+        self.client.put_object(
+            bucket, object_name, BytesIO(json_bytes), len(json_bytes), content_type="application/json"
+        )
+
+
+class DataUploader:
+    """数据上传器"""
+
+    def __init__(
+        self, task_id: int, local_path: str, s3_target: str, csv_path: str, status_path: str, num_workers: int = 10
+    ):
+        self.task_id = task_id
+        self.local_path = local_path
+        self.num_workers = num_workers
+
+        # 解析S3路径
+        self.target_bucket, self.target_prefix = self._parse_s3_path(s3_target)
+        self.csv_bucket, self.csv_key = self._parse_s3_path(csv_path)
+        self.status_bucket, self.status_key = self._parse_s3_path(status_path)
+
+        # 创建工作目录
+        self.work_dir = Path.home() / ".di_workspace" / str(task_id)
+        self.work_dir.mkdir(parents=True, exist_ok=True)
+        self.csv_file = self.work_dir / "upload_records.csv"
+
+        # CSV记录队列
+        self.csv_queue = queue.Queue()
+        self.processed_files = set()
+        self.total_files = 0
+
+    def _parse_s3_path(self, s3_path: str) -> Tuple[str, str]:
+        """解析S3路径"""
+        if s3_path.startswith("s3://"):
+            parts = s3_path[5:].split("/", 1)
+            bucket = parts[0]
+            key = parts[1] if len(parts) > 1 else ""
+            return bucket, key
+        return "", ""
+
+    def _collect_files(self) -> list:
+        """收集需要上传的文件"""
+        files = []
+
+        if os.path.isfile(self.local_path):
+            files.append(self.local_path)
+            self.total_files += 1
+        else:
+            for root, _, filenames in os.walk(self.local_path):
+                for filename in filenames:
+                    file_path = os.path.join(root, filename)
+                    if not os.path.islink(file_path):  # 跳过符号链接
+                        files.append(file_path)
+                        self.total_files += 1
+
+        # 过滤已处理的文件
+        base_path = os.path.dirname(self.local_path) if os.path.isfile(self.local_path) else self.local_path
+        unprocessed_files = []
+
+        for file_path in files:
+            rel_path = os.path.relpath(file_path, base_path)
+            if rel_path not in self.processed_files:
+                unprocessed_files.append(file_path)
+
+        return unprocessed_files
+
+    def _csv_writer_worker(self):
+        """CSV写入工作器"""
+        # 初始化CSV文件
+        uploaded_count = 0
+        file_exists = os.path.exists(self.csv_file)
+
+        with open(self.csv_file, "a", newline="", encoding="utf-8") as f:
+            writer = csv.writer(f)
+            if not file_exists:
+                writer.writerow(["local_path", "sha256", "s3path", "file_size"])
+
+            while True:
+                try:
+                    record = self.csv_queue.get(timeout=1)
+                    if record is None:  # 结束信号
+                        break
+
+                    writer.writerow(
+                        [record["local_path"], record["file_hash"], record["s3_path"], str(record["file_size"])]
+                    )
+
+                    f.flush()  # 确保数据写入磁盘
+                    self.csv_queue.task_done()
+                    uploaded_count += 1
+                    # 每上传100个文件,打印进度
+                    if uploaded_count % 1000 == 0:
+                        logger.info(f"已上传 {uploaded_count} 个文件")
+
+                except queue.Empty:
+                    continue
+                except Exception as e:
+                    logger.error(f"Failed to write CSV record: {e}")
+                    self.csv_queue.task_done()
+
+    def _upload_worker(self, s3_client: SimpleS3Client, file_queue: queue.Queue, base_path: str):
+        """上传工作器"""
+        while True:
+            try:
+                file_path = file_queue.get(timeout=1)
+                if file_path is None:  # 结束信号
+                    break
+
+                try:
+                    # 计算相对路径和S3对象名
+                    rel_path = os.path.relpath(file_path, base_path)
+
+                    object_name = os.path.join(self.target_prefix, rel_path).replace("\\", "/")
+
+                    # 上传文件
+                    file_hash, file_size = s3_client.upload_file(file_path, self.target_bucket, object_name)
+
+                    # 将记录放入CSV队列
+                    s3_path = f"s3://{self.target_bucket}/{object_name}"
+                    record = {
+                        "local_path": os.path.join("/", rel_path),
+                        "file_hash": file_hash,
+                        "s3_path": s3_path,
+                        "file_size": file_size,
+                    }
+                    self.csv_queue.put(record)
+
+                    logger.debug(f"Uploaded: {rel_path}")
+
+                except Exception as e:
+                    logger.error(f"Failed to upload {file_path}: {e}")
+                finally:
+                    file_queue.task_done()
+
+            except queue.Empty:
+                break
+
206
+ break
207
+
208
+ def _calculate_final_stats(self) -> UploadStatus:
209
+ """从CSV文件计算最终统计信息"""
210
+ stats = UploadStatus()
211
+ if not os.path.exists(self.csv_file):
212
+ return stats
213
+
214
+ total_count = 0
215
+ total_size = 0
216
+
217
+ try:
218
+ with open(self.csv_file, "r", encoding="utf-8") as f:
219
+ reader = csv.DictReader(f)
220
+ for row in reader:
221
+ total_count += 1
222
+ total_size += int(row["file_size"])
223
+ except Exception as e:
224
+ logger.error(f"Failed to calculate stats: {e}")
225
+
226
+ stats.update(total_count, total_size)
227
+
228
+ return stats
229
+
230
+ def run(self, s3_client: SimpleS3Client) -> UploadStatus:
231
+ """执行上传任务"""
232
+ # 收集文件
233
+ files = self._collect_files()
234
+ if not files:
235
+ logger.info("No files to upload")
236
+ return UploadStatus()
237
+
238
+ logger.info(f"Found {len(files)} files to upload")
239
+
240
+ # 准备文件队列
241
+ file_queue = queue.Queue()
242
+ for file_path in files:
243
+ file_queue.put(file_path)
244
+
245
+ base_path = os.path.dirname(self.local_path) if os.path.isfile(self.local_path) else self.local_path
246
+
247
+ # 启动CSV写入线程
248
+ csv_thread = threading.Thread(target=self._csv_writer_worker)
249
+ csv_thread.daemon = True
250
+ csv_thread.start()
251
+
252
+ try:
253
+ # 启动上传工作器
254
+ with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
255
+ futures = []
256
+ for i in range(self.num_workers):
257
+ future = executor.submit(self._upload_worker, s3_client, file_queue, base_path)
258
+ futures.append(future)
259
+
260
+ # 等待所有任务完成
261
+ for future in as_completed(futures):
262
+ future.result()
263
+
264
+ # 等待CSV队列处理完成
265
+ self.csv_queue.join()
266
+
267
+ # 发送结束信号给CSV写入线程
268
+ self.csv_queue.put(None)
269
+ csv_thread.join()
270
+
271
+ # 上传记录文件到S3
272
+ if os.path.exists(self.csv_file):
273
+ s3_client.upload_file(str(self.csv_file), self.csv_bucket, self.csv_key)
274
+ logger.info("Upload records saved to S3")
275
+
276
+ # 计算并上传最终统计信息
277
+ stats = self._calculate_final_stats()
278
+ s3_client.upload_json(stats.to_dict(), self.status_bucket, self.status_key)
279
+ logger.info(f"Upload completed: {stats.uploaded_count} files, {stats.uploaded_size} bytes")
280
+
281
+ finally:
282
+ # 清理工作目录
283
+ try:
284
+ import shutil
285
+
286
+ shutil.rmtree(self.work_dir)
287
+ except Exception as e:
288
+ logger.warning(f"Failed to cleanup workspace: {e}")
289
+ return stats
290
+
291
+
292
+ def main():
293
+ """主函数"""
294
+ parser = argparse.ArgumentParser(description="简化的数据摄取工具")
295
+ parser.add_argument("-e", "--endpoint", default="192.168.13.160:9008", help="S3端点")
296
+ parser.add_argument("-ak", "--access-key", default="admin2024", help="访问密钥")
297
+ parser.add_argument("-sk", "--secret-key", default="root@23452024", help="秘密密钥")
298
+ parser.add_argument("-t", "--target", default="s3://testbucket/test_ok11", help="目标S3路径")
299
+ parser.add_argument("-l", "--local", default="./test_data", help="本地路径")
300
+ parser.add_argument("-o", "--object-sheet", default="s3://testbucket/records/123.csv", help="记录文件S3路径")
301
+ parser.add_argument("-s", "--status", default="s3://testbucket/status/123.json", help="状态文件S3路径")
302
+ parser.add_argument("-i", "--task-id", type=int, default=123, help="任务ID")
303
+ parser.add_argument("-n", "--num-workers", type=int, default=10, help="工作线程数")
304
+
305
+ args = parser.parse_args()
306
+
307
+ # 检查本地路径
308
+ if not os.path.exists(args.local):
309
+ logger.error(f"Local path does not exist: {args.local}")
310
+ sys.exit(1)
311
+
312
+ logger.info(f"Starting upload: {args.local} -> {args.target}")
313
+
314
+ try:
315
+ # 创建S3客户端
316
+ s3_client = SimpleS3Client(args.endpoint, args.access_key, args.secret_key)
317
+
318
+ # 创建上传器并执行
319
+ uploader = DataUploader(
320
+ task_id=args.task_id,
321
+ local_path=args.local,
322
+ s3_target=args.target,
323
+ csv_path=args.object_sheet,
324
+ status_path=args.status,
325
+ num_workers=args.num_workers,
326
+ )
327
+
328
+ uploader.run(s3_client)
329
+ logger.info("Upload completed successfully")
330
+
331
+ except Exception as e:
332
+ logger.error(f"Upload failed: {e}")
333
+ sys.exit(1)
334
+
335
+
336
+ if __name__ == "__main__":
337
+ main()
aihub/utils/download.py CHANGED
@@ -1,15 +1,14 @@
 from __future__ import annotations
 
-import concurrent.futures
 import os
 import tempfile
 import zipfile
 from typing import List, TypedDict
 
 import pyarrow.parquet as pq
-from tqdm import tqdm
+from tqdm.contrib.concurrent import thread_map
 
-from .http import http_download_file
+from .http import http_download_file, http_download_file_wrapper
 from .s3 import s3_to_url
 
 
@@ -59,18 +58,7 @@ def dataset_download(index_url: str, local_dir: str, worker: int = 4) -> None:
     if worker < 1:
         worker = 1
 
-    with (
-        tqdm(total=len(files), desc="Downloading dataset") as bar,
-        concurrent.futures.ThreadPoolExecutor(max_workers=worker) as pool,
-    ):
-
-        def _one(flocal: str, furl: str):
-            http_download_file(furl, flocal)
-            bar.update()
-
-        futures = [pool.submit(_one, p, u) for p, u in files]
-        for fut in concurrent.futures.as_completed(futures):
-            fut.result()
+    thread_map(http_download_file_wrapper, files, max_workers=worker)
 
 
 def zip_dir(dir_path: str, zip_path: str):
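
For illustration only (editor's sketch, not content from either wheel): dataset_download now hands its (local_path, url) pairs straight to tqdm's thread_map, and http_download_file_wrapper exists only to unpack each pair and call http_download_file(url, local_path). A standalone equivalent of the new code path:

    from tqdm.contrib.concurrent import thread_map
    from aihub.utils.http import http_download_file_wrapper

    # (destination path, source URL) pairs, the tuple order the wrapper expects
    files = [
        ("./out/part-0.parquet", "https://example.com/part-0.parquet"),
        ("./out/part-1.parquet", "https://example.com/part-1.parquet"),
    ]
    thread_map(http_download_file_wrapper, files, max_workers=4)
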
aihub/utils/http.py CHANGED
@@ -5,6 +5,12 @@ import os
 import requests
 
 
+def http_download_file_wrapper(item):
+    """Wrapper function"""
+    dst_path, url = item
+    return http_download_file(url, dst_path)
+
+
 def http_download_file(url: str, dst_path: str, chunk: int = 1 << 16) -> None:
     os.makedirs(os.path.dirname(dst_path), exist_ok=True)
     with requests.get(url, timeout=None, stream=True) as r:
intellif_aihub-0.1.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: intellif-aihub
-Version: 0.1.13
+Version: 0.1.15
 Summary: Intellif AI-hub SDK.
 Author-email: Platform Team <aihub@example.com>
 License-Expression: Apache-2.0
intellif_aihub-0.1.15.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
-aihub/__init__.py,sha256=khDKUuWafURKVs5EAZkpOMiUHI2-V7axlqrWLPUpuZo,23
+aihub/__init__.py,sha256=qb0TalpSt1CbprnFyeLUKqgrqNtmnk9IoQQ7umAoXVY,23
 aihub/client.py,sha256=nVELjkyVOG6DKJjurYn59fCoT5JsSayUweiH7bvKcAo,5547
 aihub/exceptions.py,sha256=l2cMAvipTqQOio3o11fXsCCSCevbuK4PTsxofkobFjk,500
 aihub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/models/artifact.py,sha256=4xIWV3xfZuZWcCdmZGEZ5k_rvV4oc5C_3gapw5O-2vQ,4586
+aihub/models/artifact.py,sha256=F-r7DJY9A09yIQJqWol6gLRu6y7NGjRa6-BxkMEluxU,4655
 aihub/models/common.py,sha256=qmabc2LkAdQJXIcpT1P35zxd0Lc8yDYdD4ame1iF4Bs,241
 aihub/models/data_warehouse.py,sha256=zXvWwg7ySoFJMdqQ_1UMTNEKDMhu1hDHlWdBAXdizBk,3905
-aihub/models/dataset_management.py,sha256=HmmOW0bA2byQNVavrDG0K5L6pAeGrTFv-ap9pvBWgds,6511
+aihub/models/dataset_management.py,sha256=4DuQ0zM7jv73SJiqvieHLtn2Y-T6FIFV9r7bgzyCtDo,10790
 aihub/models/document_center.py,sha256=od9bzx6krAS6ktIA-ChxeqGcch0v2wsS1flY2vuHXBc,1340
 aihub/models/eval.py,sha256=4Gon4Sg4dOkyCx3KH2mO5ip3AhrBwrPC0UZA447HeoQ,910
 aihub/models/labelfree.py,sha256=nljprYO6ECuctTVbHqriQ73N5EEyYURhBrnU28Ngfvc,1589
@@ -17,9 +17,9 @@ aihub/models/task_center.py,sha256=HE21Q4Uj-vt9LHezHnqBYgdinhrh4iJPOq8VXbSMllU,5
 aihub/models/user_system.py,sha256=0L_pBkWL9v1tv_mclOyRgCyWibuuj_XU3mPoe2v48vQ,12216
 aihub/models/workflow_center.py,sha256=4xtI1WZ38ceXJ8gwDBj-QNjOiRlLO_8kGiQybdudJPY,20121
 aihub/services/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/services/artifact.py,sha256=PtCGhYFpFK_hppQr7A0bdXvceXzmYSwEpRj-PE2rIcQ,11473
+aihub/services/artifact.py,sha256=lfOrgOT2AlH1w-75NLcQGOhVWdhmJcWD1gESPpUzqUw,11257
 aihub/services/data_warehouse.py,sha256=awvlJdggo8ph6sXweXXVp4GLRuUSD46LoD0QQksXRts,2964
-aihub/services/dataset_management.py,sha256=KB-NZpcQwixOpEkQ1xI0gKzvWA3A6ay6aQzFQoUfXXU,12847
+aihub/services/dataset_management.py,sha256=R7mFsJ1dNOI_p5yNj_rQdLolRC0UKEN4WejE7uOjVlE,21379
 aihub/services/document_center.py,sha256=dG67Ji-DOnzL2t-4x4gVfMt9fbSj_IjVHCLw5R-VTkQ,1813
 aihub/services/eval.py,sha256=V1nBISIyYWg9JJO24xzy4-kit9NsaCYp1EWIX_fgJkQ,2128
 aihub/services/labelfree.py,sha256=xua62UWhVXTxJjHRyy86waaormnJjmpQwepcARBy_h0,1450
@@ -32,11 +32,12 @@ aihub/services/task_center.py,sha256=rVQG7q2_GN0501w5KHsOOlSVFX9ovpRMGX5hskCqggw
 aihub/services/user_system.py,sha256=IqWL4bnsKyyzuGT5l6adnw0qNXlH9PSo1-C_pFyOSzA,18868
 aihub/services/workflow_center.py,sha256=caKxOlba0J1s1RUK6RUm1ndJSwAcZXEakRanu3sGKPU,17468
 aihub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-aihub/utils/download.py,sha256=yC3SoL5uE68pMB-IsNz233wj-gFrHB7D7ALzQA5JkFM,2155
-aihub/utils/http.py,sha256=SvEWB4BxvwaHYqMVE4B0Go3OWGAD4xyQnUXDZ16yOSo,410
+aihub/utils/di.py,sha256=vFUzno5WbRKu6-pj8Hnz9IqT7xb9UDZQ4qpOFH1YAtM,11812
+aihub/utils/download.py,sha256=ZZVbcC-PnN3PumV7ZiJ_-srkt4HPPovu2F6Faa2RrPE,1830
+aihub/utils/http.py,sha256=AmfHHNjptuuSFx2T1twWCnerR_hLN_gd0lUs8z36ERA,547
 aihub/utils/s3.py,sha256=ISIBP-XdBPkURpXnN56ZnIWokOOg2SRUh_qvxJk-G1Q,2187
-intellif_aihub-0.1.13.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-intellif_aihub-0.1.13.dist-info/METADATA,sha256=aPEJoZsFEbMuS6ii-et5VxI1uI9OcjIpvqv6xZpEdyo,2949
-intellif_aihub-0.1.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-intellif_aihub-0.1.13.dist-info/top_level.txt,sha256=vIvTtSIN73xv46BpYM-ctVGnyOiUQ9EWP_6ngvdIlvw,6
-intellif_aihub-0.1.13.dist-info/RECORD,,
+intellif_aihub-0.1.15.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+intellif_aihub-0.1.15.dist-info/METADATA,sha256=Hz8Z3sB06pNTJF8lygzDU37da2bCgXCrzJ1-CRAlN7Y,2949
+intellif_aihub-0.1.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+intellif_aihub-0.1.15.dist-info/top_level.txt,sha256=vIvTtSIN73xv46BpYM-ctVGnyOiUQ9EWP_6ngvdIlvw,6
+intellif_aihub-0.1.15.dist-info/RECORD,,