tostorchconnector 1.0.0__tar.gz → 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tostorchconnector might be problematic.
- {tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info → tostorchconnector-1.0.1}/PKG-INFO +1 -1
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/pyproject.toml +2 -2
- tostorchconnector-1.0.1/tostorchconnector/__init__.py +15 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_checkpoint.py +39 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_client.py +169 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_common.py +99 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_error.py +11 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_iterable_dataset.py +99 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_map_dataset.py +78 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_object_meta.py +25 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_object_reader.py +170 -0
- tostorchconnector-1.0.1/tostorchconnector/tos_object_writer.py +83 -0
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1/tostorchconnector.egg-info}/PKG-INFO +1 -1
- tostorchconnector-1.0.1/tostorchconnector.egg-info/SOURCES.txt +20 -0
- tostorchconnector-1.0.1/tostorchconnector.egg-info/top_level.txt +1 -0
- tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info/SOURCES.txt +0 -10
- tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info/top_level.txt +0 -1
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/LICENSE +0 -0
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/README.md +0 -0
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/setup.cfg +0 -0
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/tests/test_tos_dataset.py +0 -0
- {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/tests/test_tosnativeclient.py +0 -0
- {tostorchconnector-1.0.0/tostorchconnector → tostorchconnector-1.0.1}/tostorchconnector.egg-info/dependency_links.txt +0 -0
- {tostorchconnector-1.0.0/tostorchconnector → tostorchconnector-1.0.1}/tostorchconnector.egg-info/requires.txt +0 -0
{tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/pyproject.toml:

```diff
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tostorchconnector"
-version = "1.0.0"
+version = "1.0.1"
 description = "TOS connector integration for PyTorch"
 authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
 requires-python = ">=3.8,<3.13"
@@ -29,4 +29,4 @@ dependencies = [
 ]
 
 [tool.setuptools.packages]
-find = { where = ["tostorchconnector"] }
+find = { where = ["."], include = ["tostorchconnector"] }
```
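The `[tool.setuptools.packages]` change is the substance of this release: 1.0.0 told setuptools to look for packages *inside* the `tostorchconnector` directory, which holds only the stray egg-info (the removed SOURCES.txt at the bottom of this diff lists no Python modules at all), so the built distribution shipped an empty package; 1.0.1 searches the project root and explicitly includes the package. Since the TOML `find` table maps onto setuptools' `find_packages()`, a rough way to see the difference from an unpacked sdist (expected output inferred from the layouts in this diff, not verified):

```python
from setuptools import find_packages

# 1.0.0: search inside the package directory itself; nothing is found
# there, so the built distribution contained no Python modules.
print(find_packages(where="tostorchconnector"))
# expected: []

# 1.0.1: search the project root and keep only the real package.
print(find_packages(where=".", include=["tostorchconnector"]))
# expected: ['tostorchconnector']
```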
tostorchconnector-1.0.1/tostorchconnector/__init__.py (new file):

```diff
@@ -0,0 +1,15 @@
+from .tos_object_reader import TosObjectReader
+from .tos_object_writer import TosObjectWriter
+from .tos_iterable_dataset import TosIterableDataset
+from .tos_map_dataset import TosMapDataset
+from .tos_checkpoint import TosCheckpoint
+from .tos_error import TosError
+
+__all__ = [
+    'TosObjectReader',
+    'TosObjectWriter',
+    'TosIterableDataset',
+    'TosMapDataset',
+    'TosCheckpoint',
+    'TosError',
+]
```
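Because the 1.0.0 artifact effectively shipped an empty package, a one-line import check against 1.0.1 is a reasonable smoke test; these are exactly the names `__all__` promises:

```python
# Smoke test: all six public names exported by __init__.py should import.
from tostorchconnector import (
    TosObjectReader, TosObjectWriter,
    TosIterableDataset, TosMapDataset,
    TosCheckpoint, TosError,
)

print(TosCheckpoint.__module__)  # tostorchconnector.tos_checkpoint
```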
tostorchconnector-1.0.1/tostorchconnector/tos_checkpoint.py (new file):

```diff
@@ -0,0 +1,39 @@
+import logging
+from typing import Optional
+
+from . import TosObjectReader, TosObjectWriter
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
+from .tos_common import parse_tos_url
+
+log = logging.getLogger(__name__)
+
+
+class TosCheckpoint(object):
+    def __init__(self, region: str,
+                 endpoint: Optional[str] = None,
+                 cred: Optional[CredentialProvider] = None,
+                 client_conf: Optional[TosClientConfig] = None,
+                 log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+        self._client = None
+        self._native_client = None
+        self._region = region
+        self._endpoint = endpoint
+        self._cred = cred
+        self._client_conf = client_conf
+        self._log_conf = log_conf
+        self._use_native_client = use_native_client
+
+    def reader(self, url: str) -> TosObjectReader:
+        bucket, key = parse_tos_url(url)
+        return self._get_tos_client().get_object(bucket, key)
+
+    def writer(self, url: str) -> TosObjectWriter:
+        bucket, key = parse_tos_url(url)
+        return self._get_tos_client().put_object(bucket, key)
+
+    def _get_tos_client(self):
+        if self._client is None:
+            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
+                                     self._use_native_client)
+            log.info('TosIterableDataset init tos client succeed')
+        return self._client
```
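`reader()` and `writer()` return file-like objects (`TosObjectReader`/`TosObjectWriter`, shown later in this diff), so the usual `torch.save`/`torch.load` round trip should work against them. A sketch, with the region, credentials, bucket, and key all placeholders:

```python
import torch
import torch.nn as nn

from tostorchconnector import TosCheckpoint
from tostorchconnector.tos_client import CredentialProvider

# All names below are placeholders, not real endpoints or credentials.
ckpt = TosCheckpoint(region='my-region',
                     cred=CredentialProvider('my-ak', 'my-sk'))

model = nn.Linear(8, 2)

# TosObjectWriter is a context manager; a clean exit commits the upload.
with ckpt.writer('tos://my-bucket/ckpt/model.pt') as f:
    torch.save(model.state_dict(), f)

# TosObjectReader is seekable, which torch.load requires.
state = torch.load(ckpt.reader('tos://my-bucket/ckpt/model.pt'))
model.load_state_dict(state)
```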
tostorchconnector-1.0.1/tostorchconnector/tos_client.py (new file):

```diff
@@ -0,0 +1,169 @@
+import io
+import logging
+import os
+from functools import partial
+from typing import Optional, List, Tuple
+
+import tos
+from tos.models2 import GetObjectOutput, PutObjectOutput
+
+import tosnativeclient
+
+from . import TosObjectReader, TosObjectWriter
+from .tos_object_meta import TosObjectMeta
+from .tos_object_writer import PutObjectStream
+
+log = logging.getLogger(__name__)
+
+
+class CredentialProvider(object):
+    def __init__(self, ak: str, sk: str):
+        self._ak = ak
+        self._sk = sk
+
+    @property
+    def ak(self) -> str:
+        return self._ak
+
+    @property
+    def sk(self) -> str:
+        return self._sk
+
+
+class TosClientConfig(object):
+    def __init__(self, part_size: int = 8 * 1024 * 1024,
+                 max_retry_count: int = 3, shared_prefetch_tasks: int = 20):
+        self._part_size = part_size
+        self._max_retry_count = max_retry_count
+        self._shared_prefetch_tasks = shared_prefetch_tasks
+
+    @property
+    def part_size(self) -> int:
+        return self._part_size
+
+    @property
+    def max_retry_count(self) -> int:
+        return self._max_retry_count
+
+    @property
+    def shared_prefetch_tasks(self) -> int:
+        return self._shared_prefetch_tasks
+
+
+class TosLogConfig(object):
+    def __init__(self, log_dir: str = '',
+                 log_file_name: str = '', log_level: Optional[int] = logging.INFO):
+        self._log_dir = log_dir
+        self._log_file_name = log_file_name
+        self._log_level = log_level
+
+    @property
+    def log_level(self) -> Optional[int]:
+        return self._log_level
+
+    @property
+    def log_dir(self) -> str:
+        return self._log_dir
+
+    @property
+    def log_file_name(self) -> str:
+        return self._log_file_name
+
+
+class TosClient(object):
+    def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
+                 client_conf: Optional[TosClientConfig] = None, log_conf: Optional[TosLogConfig] = None,
+                 use_native_client: bool = True):
+        cred = CredentialProvider('', '') if cred is None else cred
+        client_conf = TosClientConfig() if client_conf is None else client_conf
+        log_conf = TosLogConfig() if log_conf is None else log_conf
+        self._part_size = client_conf.part_size
+        self._use_native_client = use_native_client
+        if use_native_client:
+            directives = ''
+            directory = ''
+            file_name_prefix = ''
+            if log_conf.log_dir and log_conf.log_file_name:
+                if log_conf.log_level:
+                    if log_conf.log_level == logging.DEBUG:
+                        directives = 'debug'
+                    elif log_conf.log_level == logging.INFO:
+                        directives = 'info'
+                    elif log_conf.log_level == logging.WARN:
+                        directives = 'warn'
+                    elif log_conf.log_level == logging.ERROR:
+                        directives = 'error'
+                    else:
+                        directives = 'info'
+                directory = log_conf.log_dir
+                file_name_prefix = log_conf.log_file_name
+            self._client = tosnativeclient.TosClient(region, endpoint, cred.ak, cred.sk, client_conf.part_size,
+                                                     client_conf.max_retry_count, directives=directives,
+                                                     directory=directory,
+                                                     file_name_prefix=file_name_prefix)
+        else:
+            self._client = tos.TosClientV2(cred.ak, cred.sk, endpoint=endpoint, region=region,
+                                           max_retry_count=client_conf.max_retry_count)
+            if log_conf.log_dir and log_conf.log_file_name:
+                file_path = os.path.join(log_conf.log_dir, log_conf.log_file_name)
+                log_level = log_conf.log_level if log_conf.log_level else logging.INFO
+                tos.set_logger(file_path=file_path, level=log_level)
+
+    @property
+    def use_native_client(self):
+        return self._use_native_client
+
+    def get_object(self, bucket: str, key: str, etag: Optional[str] = None,
+                   size: Optional[int] = None) -> TosObjectReader:
+        log.debug(f'get_object tos://{bucket}/{key}')
+        if size is None or etag is None:
+            get_object_meta = partial(self.head_object, bucket, key)
+        else:
+            get_object_meta = lambda: TosObjectMeta(bucket, key, size, etag)
+        if isinstance(self._client, tosnativeclient.TosClient):
+            get_object_stream = partial(self._client.get_object, bucket, key)
+        else:
+            def get_object_stream(et: str, _: int) -> GetObjectOutput:
+                return self._client.get_object(bucket, key, '', et)
+
+        return TosObjectReader(bucket, key, get_object_meta, get_object_stream)
+
+    def put_object(self, bucket: str, key: str, storage_class: Optional[str] = None) -> TosObjectWriter:
+        log.debug(f'put_object tos://{bucket}/{key}')
+
+        if isinstance(self._client, tosnativeclient.TosClient):
+            storage_class = storage_class if storage_class is not None else ''
+            put_object_stream = self._client.put_object(bucket, key, storage_class)
+        else:
+            def put_object(content: io.BytesIO) -> PutObjectOutput:
+                return self._client.put_object(bucket, key, storage_class=storage_class, content=content)
+
+            put_object_stream = PutObjectStream(put_object)
+
+        return TosObjectWriter(bucket, key, put_object_stream)
+
+    def head_object(self, bucket: str, key: str) -> TosObjectMeta:
+        log.debug(f'head_object tos://{bucket}/{key}')
+        if isinstance(self._client, tosnativeclient.TosClient):
+            resp = self._client.head_object(bucket, key)
+            return TosObjectMeta(resp.bucket, resp.key, resp.size, resp.etag)
+        resp = self._client.head_object(bucket, key)
+        return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
+
+    def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
+                        delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
+        if isinstance(self._client, tosnativeclient.TosClient):
+            delimiter = delimiter if delimiter is not None else ''
+            return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
+        raise NotImplementedError()
+
+    def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
+                     continuation_token: Optional[str] = None, delimiter: Optional[str] = None) -> Tuple[
+        List[TosObjectMeta], bool, Optional[str]]:
+        log.debug(f'list_objects tos://{bucket}/{prefix}')
+        resp = self._client.list_objects_type2(bucket, prefix, max_keys=max_keys, continuation_token=continuation_token,
+                                               delimiter=delimiter)
+        object_metas = []
+        for obj in resp.contents:
+            object_metas.append(TosObjectMeta(bucket, obj.key, obj.size, obj.etag))
+        return object_metas, resp.is_truncated, resp.next_continuation_token
```
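`TosClient` is the switchboard: `use_native_client=True` routes calls through `tosnativeclient`, otherwise it falls back to the pure-Python `tos.TosClientV2` (where `gen_list_stream` is not implemented). A usage sketch with placeholder names throughout:

```python
from tostorchconnector.tos_client import (
    CredentialProvider, TosClient, TosClientConfig,
)

# Placeholders; endpoint may be omitted to use the region's default.
client = TosClient(
    region='my-region',
    cred=CredentialProvider('my-ak', 'my-sk'),
    client_conf=TosClientConfig(part_size=8 * 1024 * 1024, max_retry_count=3),
    use_native_client=True,
)

# head_object resolves size and etag up front...
meta = client.head_object('my-bucket', 'data/sample.bin')

# ...and passing them to get_object lets the lazy reader skip the extra
# HEAD request it would otherwise issue on first read.
reader = client.get_object('my-bucket', 'data/sample.bin', meta.etag, meta.size)
payload = reader.read()
```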
tostorchconnector-1.0.1/tostorchconnector/tos_common.py (new file):

```diff
@@ -0,0 +1,99 @@
+import logging
+from typing import Union, Iterator, Tuple, Optional
+
+from . import TosObjectReader
+from .tos_client import TosClient
+from .tos_object_meta import TosObjectMeta
+
+log = logging.getLogger(__name__)
+
+
+class TosObjectIterator(object):
+    def __init__(self, bucket: str, prefix: str, client: TosClient):
+        self._bucket = bucket
+        self._prefix = prefix
+        self._client = client
+        self._delimiter: Optional[str] = None
+        self._list_stream = None
+
+        self._object_metas = None
+        self._index = 0
+
+        self._is_truncated = True
+        self._continuation_token = None
+
+    def close(self):
+        if self._list_stream is not None:
+            self._list_stream.close()
+
+    def __iter__(self) -> Iterator[TosObjectMeta]:
+        return self
+
+    def __next__(self):
+        if self._client.use_native_client:
+            if self._list_stream is None:
+                self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
+                                                                 delimiter=self._delimiter)
+
+            if self._object_metas is None or self._index >= len(self._object_metas):
+                self._object_metas = []
+                objects = next(self._list_stream)
+                for content in objects.contents:
+                    self._object_metas.append(TosObjectMeta(content.bucket, content.key, content.size, content.etag))
+
+            object_meta = self._object_metas[self._index]
+            self._index += 1
+            return object_meta
+
+        while self._object_metas is None or self._index >= len(self._object_metas):
+            if not self._is_truncated:
+                raise StopIteration
+            self._object_metas, self._is_truncated, self._continuation_token = self._client.list_objects(
+                self._bucket,
+                self._prefix,
+                max_keys=1000,
+                continuation_token=self._continuation_token,
+                delimiter=self._delimiter)
+            self._index = 0
+
+        object_meta = self._object_metas[self._index]
+        self._index += 1
+        return object_meta
+
+
+def parse_tos_url(url: str) -> Tuple[str, str]:
+    if not url:
+        raise ValueError('url is empty')
+
+    if url.startswith('tos://'):
+        url = url[len('tos://'):]
+
+    if not url:
+        raise ValueError('bucket is empty')
+
+    url = url.split('/', maxsplit=1)
+    if len(url) == 1:
+        bucket = url[0]
+        prefix = ''
+    else:
+        bucket = url[0]
+        prefix = url[1]
+
+    if not bucket:
+        raise ValueError('bucket is empty')
+    return bucket, prefix
+
+
+def default_trans(obj: TosObjectReader) -> TosObjectReader:
+    return obj
+
+
+def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iterator[TosObjectMeta]:
+    if isinstance(urls, str):
+        urls = [urls]
+    return (TosObjectMeta(bucket, key) for bucket, key in [parse_tos_url(url) for url in urls])
+
+
+def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
+    bucket, prefix = parse_tos_url(prefix)
+    return iter(TosObjectIterator(bucket, prefix, client))
```
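`parse_tos_url` defines the URL convention used throughout the package: the `tos://` scheme is optional, the first path segment is the bucket, and everything after the first `/` is returned verbatim as the key or prefix. Expected behavior, per the code above:

```python
from tostorchconnector.tos_common import parse_tos_url

assert parse_tos_url('tos://my-bucket/train/0001.jpg') == ('my-bucket', 'train/0001.jpg')
assert parse_tos_url('tos://my-bucket') == ('my-bucket', '')
assert parse_tos_url('my-bucket/a/b') == ('my-bucket', 'a/b')  # scheme is optional

try:
    parse_tos_url('tos://')
except ValueError as e:
    print(e)  # 'bucket is empty'
```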
tostorchconnector-1.0.1/tostorchconnector/tos_error.py (new file):

```diff
@@ -0,0 +1,11 @@
+from typing import Optional
+
+
+class TosError(Exception):
+    def __init__(self, message: str, status_code: Optional[int] = None, ec: Optional[str] = None,
+                 request_id: Optional[str] = None):
+        super().__init__(message)
+        self.message = message
+        self.status_code = status_code
+        self.ec = ec
+        self.request_id = request_id
```
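`TosError` just bundles TOS diagnostics with the message; as far as this diff shows, nothing raises it yet, but catching it would look like this (all field values illustrative):

```python
from tostorchconnector import TosError

try:
    raise TosError('object not found', status_code=404,
                   ec='NoSuchKey', request_id='req-0000')  # illustrative values
except TosError as err:
    print(err.message, err.status_code, err.ec, err.request_id)
```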
tostorchconnector-1.0.1/tostorchconnector/tos_iterable_dataset.py (new file):

```diff
@@ -0,0 +1,99 @@
+import logging
+from functools import partial
+from typing import Iterator, Any, Optional, Callable, Union
+
+import torch
+
+from . import TosObjectReader
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
+from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
+from .tos_object_meta import TosObjectMeta
+
+log = logging.getLogger(__name__)
+
+
+class TosIterableDataset(torch.utils.data.IterableDataset):
+    def __init__(self, region: str,
+                 gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
+                 endpoint: Optional[str] = None,
+                 trans: Callable[[TosObjectReader], Any] = default_trans,
+                 cred: Optional[CredentialProvider] = None,
+                 client_conf: Optional[TosClientConfig] = None,
+                 log_conf: Optional[TosLogConfig] = None,
+                 sharding: bool = False, use_native_client=True):
+        self._client: Optional[TosClient] = None
+        self._gen_dataset = gen_dataset
+        self._region = region
+        self._endpoint = endpoint
+        self._trans = trans
+        self._cred = cred
+        self._client_conf = client_conf
+        self._log_conf = log_conf
+        self._sharding = sharding
+        self._use_native_client = use_native_client
+        if torch.distributed.is_initialized():
+            self._rank = torch.distributed.get_rank()
+            self._world_size = torch.distributed.get_world_size()
+        else:
+            self._rank = 0
+            self._world_size = 1
+
+    @classmethod
+    def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
+                  transform: Callable[[TosObjectReader], Any] = default_trans,
+                  cred: Optional[CredentialProvider] = None,
+                  client_conf: Optional[TosClientConfig] = None,
+                  log_conf: Optional[TosLogConfig] = None,
+                  sharding: bool = False, use_native_client=True):
+        log.info(f'building {cls.__name__} from_urls')
+        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
+                   sharding, use_native_client)
+
+    @classmethod
+    def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
+                    transform: Callable[[TosObjectReader], Any] = default_trans,
+                    cred: Optional[CredentialProvider] = None,
+                    client_conf: Optional[TosClientConfig] = None,
+                    log_conf: Optional[TosLogConfig] = None,
+                    sharding: bool = False, use_native_client=True):
+        log.info(f'building {cls.__name__} from_prefix')
+        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
+                   sharding, use_native_client)
+
+    def __iter__(self) -> Iterator[Any]:
+        worker_id = 0
+        num_workers = 1
+        if self._sharding:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                worker_id = worker_info.id
+                num_workers = worker_info.num_workers
+
+        if not self._sharding or (self._world_size == 1 and num_workers == 1):
+            return map(
+                self._trans_tos_object,
+                self._gen_dataset(self._get_tos_client()),
+            )
+
+        part_dataset = (
+            obj
+            for idx, obj in enumerate(self._gen_dataset(self._get_tos_client()))
+            if idx % self._world_size == self._rank
+        )
+
+        part_dataset = (
+            obj
+            for idx, obj in enumerate(part_dataset)
+            if idx % num_workers == worker_id
+        )
+        return map(self._trans_tos_object, part_dataset)
+
+    def _get_tos_client(self):
+        if self._client is None:
+            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf)
+            log.info('TosIterableDataset init tos client succeed')
+        return self._client
+
+    def _trans_tos_object(self, object_meta: TosObjectMeta) -> Any:
+        obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
+        return self._trans(obj)
+```
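Putting the iterable dataset to work: `from_prefix` lists objects lazily via `TosObjectIterator`, and with `sharding=True` the `__iter__` above deals objects round-robin first across distributed ranks, then across `DataLoader` workers. A sketch with placeholder bucket and region:

```python
from torch.utils.data import DataLoader

from tostorchconnector import TosIterableDataset

def to_pair(obj):
    # obj is a TosObjectReader; stream the whole object into memory.
    return obj.key, obj.read()

dataset = TosIterableDataset.from_prefix(
    'tos://my-bucket/train/',   # placeholder prefix
    region='my-region',         # placeholder
    transform=to_pair,
    sharding=True,              # shard objects across ranks and workers
)

# batch_size=None yields one transformed object at a time.
loader = DataLoader(dataset, batch_size=None, num_workers=2)
for key, payload in loader:
    ...
```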
tostorchconnector-1.0.1/tostorchconnector/tos_map_dataset.py (new file):

```diff
@@ -0,0 +1,78 @@
+import logging
+from functools import partial
+from typing import Any, Callable, Iterator, Optional, List, Union
+
+import torch
+
+from . import TosObjectReader
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
+from .tos_common import default_trans, gen_dataset_from_prefix, \
+    gen_dataset_from_urls
+from .tos_object_meta import TosObjectMeta
+
+log = logging.getLogger(__name__)
+
+
+class TosMapDataset(torch.utils.data.Dataset):
+    def __init__(self, region: str,
+                 gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
+                 endpoint: Optional[str] = None,
+                 trans: Callable[[TosObjectReader], Any] = default_trans,
+                 cred: Optional[CredentialProvider] = None,
+                 client_conf: Optional[TosClientConfig] = None,
+                 log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+        self._client: Optional[TosClient] = None
+        self._gen_dataset = gen_dataset
+        self._region = region
+        self._endpoint = endpoint
+        self._trans = trans
+        self._cred = cred
+        self._client_conf = client_conf
+        self._log_conf = log_conf
+        self._use_native_client = use_native_client
+        self._dataset: Optional[List[TosObjectMeta]] = None
+
+    @classmethod
+    def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
+                  trans: Callable[[TosObjectReader], Any] = default_trans,
+                  cred: Optional[CredentialProvider] = None,
+                  client_conf: Optional[TosClientConfig] = None,
+                  log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+        log.info(f'building {cls.__name__} from_urls')
+        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
+                   use_native_client)
+
+    @classmethod
+    def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
+                    trans: Callable[[TosObjectReader], Any] = default_trans,
+                    cred: Optional[CredentialProvider] = None,
+                    client_conf: Optional[TosClientConfig] = None,
+                    log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+        log.info(f'building {cls.__name__} from_prefix')
+        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
+                   use_native_client)
+
+    def __getitem__(self, i: int) -> Any:
+        return self._trans_tos_object(i)
+
+    def __len__(self) -> int:
+        return len(self._data_set)
+
+    def _get_tos_client(self) -> TosClient:
+        if self._client is None:
+            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
+                                     self._use_native_client)
+            log.info('TosMapDataset init tos client succeed')
+        return self._client
+
+    @property
+    def _data_set(self):
+        if self._dataset is None:
+            self._dataset = list(self._gen_dataset(self._get_tos_client()))
+        assert self._dataset is not None
+        return self._dataset
+
+    def _trans_tos_object(self, i: int) -> Any:
+        object_meta = self._data_set[i]
+        obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
+        return self._trans(obj)
```
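The map-style counterpart materializes the full object listing on first use (the `_data_set` property), which is what makes `len()` and random access possible; note its factory methods take `trans` where the iterable dataset's take `transform`. A placeholder sketch:

```python
from torch.utils.data import DataLoader

from tostorchconnector import TosMapDataset

def to_pair(obj):
    return obj.key, obj.read()

dataset = TosMapDataset.from_urls(
    ['tos://my-bucket/train/0001.jpg',   # placeholder URLs
     'tos://my-bucket/train/0002.jpg'],
    region='my-region',                  # placeholder
    trans=to_pair,
)

print(len(dataset))       # triggers the lazy listing/materialization
key, payload = dataset[0]

# Random access means the DataLoader can shuffle.
loader = DataLoader(dataset, batch_size=2, shuffle=True,
                    collate_fn=lambda batch: batch)
```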
tostorchconnector-1.0.1/tostorchconnector/tos_object_meta.py (new file):

```diff
@@ -0,0 +1,25 @@
+from typing import Optional
+
+
+class TosObjectMeta(object):
+    def __init__(self, bucket: str, key: str, size: Optional[int] = None, etag: Optional[str] = None):
+        self._bucket = bucket
+        self._key = key
+        self._size = size
+        self._etag = etag
+
+    @property
+    def bucket(self) -> str:
+        return self._bucket
+
+    @property
+    def key(self) -> str:
+        return self._key
+
+    @property
+    def size(self) -> Optional[int]:
+        return self._size
+
+    @property
+    def etag(self) -> Optional[str]:
+        return self._etag
```
tostorchconnector-1.0.1/tostorchconnector/tos_object_reader.py (new file):

```diff
@@ -0,0 +1,170 @@
+import io
+import logging
+from functools import cached_property
+from os import SEEK_SET, SEEK_CUR, SEEK_END
+from typing import Optional, Callable, Iterator
+
+from tos.models2 import GetObjectOutput
+from tosnativeclient.tosnativeclient import ReadStream
+
+from .tos_object_meta import TosObjectMeta
+
+log = logging.getLogger(__name__)
+
+
+class TosObjectReader(io.BufferedIOBase):
+
+    def __init__(self, bucket: str, key: str,
+                 get_object_meta: Optional[Callable[[], TosObjectMeta]],
+                 get_object_stream: Callable[[str, int], ReadStream | GetObjectOutput]):
+        if not bucket:
+            raise ValueError('bucket is empty')
+        self._bucket = bucket
+        self._key = key
+        self._get_object_meta = get_object_meta
+        self._get_object_stream = get_object_stream
+        self._object_stream: Optional[ReadStream | GetObjectOutput] = None
+        self._object_stream_offset = 0
+        self._total_size: Optional[int] = None
+        self._read_offset = 0
+        self._buffer = io.BytesIO()
+
+    @property
+    def bucket(self) -> str:
+        return self._bucket
+
+    @property
+    def key(self) -> str:
+        return self._key
+
+    def read(self, size: Optional[int] = None) -> Optional[bytes]:
+        if self._is_read_to_end():
+            return b''
+
+        self._trigger_prefetch()
+        current_read_offset = self._read_offset
+        if size is None or size < 0:
+            # means read all
+            self._buffer.seek(0, SEEK_END)
+            if isinstance(self._object_stream, ReadStream):
+                try:
+                    chunk_size = 1 * 1024 * 1024
+                    while 1:
+                        chunk = self._object_stream.read(self._object_stream_offset, chunk_size)
+                        if not chunk:
+                            break
+                        self._object_stream_offset += len(chunk)
+                        self._buffer.write(chunk)
+                finally:
+                    self._object_stream.close()
+            else:
+                for chunk in self._object_stream:
+                    self._buffer.write(chunk)
+
+            self._total_size = self._buffer.tell()
+        else:
+            self.seek(size, SEEK_CUR)
+
+        self._buffer.seek(current_read_offset)
+        data = self._buffer.read(size)
+        self._read_offset = self._buffer.tell()
+        return data
+
+    def readinto(self, buf) -> Optional[int]:
+        size = len(buf)
+        if self._is_read_to_end() or size == 0:
+            return 0
+
+        self._trigger_prefetch()
+        current_read_offset = self._read_offset
+        self.seek(size, SEEK_CUR)
+        self._buffer.seek(current_read_offset)
+        readed = self._buffer.readinto(buf)
+        self._read_offset = self._buffer.tell()
+        return readed
+
+    def seek(self, offset: int, whence: int = SEEK_SET) -> int:
+        if whence == SEEK_END:
+            if offset >= 0:
+                self._read_offset = self._get_total_size()
+                return self._read_offset
+            # offset is negative
+            offset += self._get_total_size()
+        elif whence == SEEK_CUR:
+            if self._is_read_to_end() and offset >= 0:
+                return self._read_offset
+            offset += self._read_offset
+        elif whence == SEEK_SET:
+            pass
+        else:
+            raise ValueError('invalid whence')
+
+        if offset < 0:
+            raise ValueError(f'invalid seek offset {offset}')
+
+        if offset > self._buffer_size():
+            self._prefetch_to_offset(offset)
+
+        self._read_offset = self._buffer.seek(offset)
+        return self._read_offset
+
+    def tell(self) -> int:
+        return self._read_offset
+
+    def readable(self):
+        return True
+
+    def writable(self):
+        return False
+
+    def seekable(self):
+        return True
+
+    @cached_property
+    def _object_meta(self) -> TosObjectMeta:
+        return self._get_object_meta()
+
+    def _trigger_prefetch(self) -> None:
+        if self._object_stream is None:
+            object_meta = self._object_meta
+            self._object_stream = self._get_object_stream(object_meta.etag, object_meta.size)
+            self._object_stream_offset = 0
+
+    def _is_read_to_end(self) -> bool:
+        if self._total_size is None:
+            return False
+        return self._read_offset == self._total_size
+
+    def _get_total_size(self) -> int:
+        if self._total_size is None:
+            self._total_size = self._object_meta.size
+        return self._total_size
+
+    def _prefetch_to_offset(self, offset: int) -> None:
+        self._trigger_prefetch()
+        size = self._buffer.seek(0, SEEK_END)
+        if isinstance(self._object_stream, ReadStream):
+            try:
+                chunk_size = 1 * 1024 * 1024
+                while offset > size:
+                    chunk = self._object_stream.read(self._object_stream_offset, chunk_size)
+                    if not chunk:
+                        break
+                    size += self._buffer.write(chunk)
+                    self._object_stream_offset += len(chunk)
+                self._total_size = self._buffer.tell()
+            finally:
+                self._object_stream.close()
+        else:
+            try:
+                while offset > size:
+                    size += self._buffer.write(next(self._object_stream))
+            except StopIteration:
+                self._total_size = self._buffer.tell()
+
+    def _buffer_size(self) -> int:
+        cur_pos = self._buffer.tell()
+        self._buffer.seek(0, SEEK_END)
+        buffer_size = self._buffer.tell()
+        self._buffer.seek(cur_pos)
+        return buffer_size
```
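The reader caches everything it has fetched in an in-memory `io.BytesIO`, so backward seeks are served locally while forward seeks past the buffered range trigger `_prefetch_to_offset`. A small behavioral sketch, reusing the `reader` from the `TosClient` example above:

```python
head = reader.read(1024)   # fetches and buffers only what is needed
print(reader.tell())       # 1024

reader.seek(0)             # backward seek: served from the local buffer
first16 = reader.read(16)  # no new remote fetch for this range
print(reader.tell())       # 16
```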
tostorchconnector-1.0.1/tostorchconnector/tos_object_writer.py (new file):

```diff
@@ -0,0 +1,83 @@
+import io
+import logging
+import threading
+from typing import Callable, Optional, Any
+
+from tos.models2 import PutObjectOutput
+from tosnativeclient.tosnativeclient import WriteStream
+
+log = logging.getLogger(__name__)
+
+
+class PutObjectStream(object):
+    def __init__(self, put_object: Callable[[io.BytesIO], PutObjectOutput]):
+        self._put_object = put_object
+        self._buffer = io.BytesIO()
+
+    def write(self, data):
+        self._buffer.write(data)
+
+    def close(self):
+        self._buffer.seek(0)
+        _ = self._put_object(self._buffer)
+
+
+class TosObjectWriter(io.BufferedIOBase):
+
+    def __init__(self, bucket: str, key: str, put_object_stream: WriteStream | PutObjectStream):
+        if not bucket:
+            raise ValueError('bucket is empty')
+        self._bucket = bucket
+        self._key = key
+        self._put_object_stream = put_object_stream
+        self._write_offset = 0
+        self._closed = False
+        self._lock = threading.Lock()
+
+    @property
+    def bucket(self) -> str:
+        return self._bucket
+
+    @property
+    def key(self) -> str:
+        return self._key
+
+    def __enter__(self):
+        self._write_offset = 0
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            try:
+                log.info(f'Exception occurred before closing stream: {exc_type.__name__}: {exc_val}')
+            except:
+                pass
+        else:
+            self.close()
+
+    def write(self, data) -> int:
+        self._put_object_stream.write(data)
+        self._write_offset += len(data)
+        return len(data)
+
+    def close(self) -> None:
+        with self._lock:
+            if self._closed:
+                return
+            self._closed = True
+            self._put_object_stream.close()
+
+    def tell(self) -> int:
+        return self._write_offset
+
+    def flush(self) -> None:
+        pass
+
+    def readable(self):
+        return False
+
+    def writable(self):
+        return True
+
+    def seekable(self):
+        return False
```
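On the write side, `PutObjectStream` simply accumulates into a `BytesIO` and issues the actual `put_object` call on `close()`, and `__exit__` only closes on a clean exit, so a failed `with` block never commits a partial object. A sketch, reusing the placeholder `client` from the `TosClient` example:

```python
writer = client.put_object('my-bucket', 'out/result.bin')  # placeholders

with writer as f:
    f.write(b'hello ')
    f.write(b'world')
    print(f.tell())   # 11 bytes written so far

# Leaving the block without an exception called close(), which is what
# actually uploads the object; close() is idempotent under a lock.
```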
tostorchconnector-1.0.1/tostorchconnector.egg-info/SOURCES.txt (new file):

```diff
@@ -0,0 +1,20 @@
+LICENSE
+README.md
+pyproject.toml
+tests/test_tos_dataset.py
+tests/test_tosnativeclient.py
+tostorchconnector/__init__.py
+tostorchconnector/tos_checkpoint.py
+tostorchconnector/tos_client.py
+tostorchconnector/tos_common.py
+tostorchconnector/tos_error.py
+tostorchconnector/tos_iterable_dataset.py
+tostorchconnector/tos_map_dataset.py
+tostorchconnector/tos_object_meta.py
+tostorchconnector/tos_object_reader.py
+tostorchconnector/tos_object_writer.py
+tostorchconnector.egg-info/PKG-INFO
+tostorchconnector.egg-info/SOURCES.txt
+tostorchconnector.egg-info/dependency_links.txt
+tostorchconnector.egg-info/requires.txt
+tostorchconnector.egg-info/top_level.txt
```
tostorchconnector-1.0.1/tostorchconnector.egg-info/top_level.txt (new file):

```diff
@@ -0,0 +1 @@
+tostorchconnector
```
tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info/SOURCES.txt (removed):

```diff
@@ -1,10 +0,0 @@
-LICENSE
-README.md
-pyproject.toml
-tests/test_tos_dataset.py
-tests/test_tosnativeclient.py
-tostorchconnector/tostorchconnector.egg-info/PKG-INFO
-tostorchconnector/tostorchconnector.egg-info/SOURCES.txt
-tostorchconnector/tostorchconnector.egg-info/dependency_links.txt
-tostorchconnector/tostorchconnector.egg-info/requires.txt
-tostorchconnector/tostorchconnector.egg-info/top_level.txt
```
tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info/top_level.txt (removed):

```diff
@@ -1 +0,0 @@
-
```