tostorchconnector 1.0.7.tar.gz → 1.0.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tostorchconnector might be problematic.
- {tostorchconnector-1.0.7/tostorchconnector.egg-info → tostorchconnector-1.0.9}/PKG-INFO +2 -2
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/pyproject.toml +2 -2
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tests/test_tos_dataset.py +51 -14
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tests/test_tosclient.py +3 -3
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_client.py +101 -36
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_common.py +16 -5
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_iterable_dataset.py +2 -2
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_map_dataset.py +6 -6
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_object_reader.py +1 -1
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9/tostorchconnector.egg-info}/PKG-INFO +2 -2
- tostorchconnector-1.0.9/tostorchconnector.egg-info/requires.txt +3 -0
- tostorchconnector-1.0.7/tostorchconnector.egg-info/requires.txt +0 -3
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/LICENSE +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/README.md +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/setup.cfg +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tests/test_tosrawclient.py +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/__init__.py +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_checkpoint.py +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_object_meta.py +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_object_writer.py +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector.egg-info/SOURCES.txt +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector.egg-info/dependency_links.txt +0 -0
- {tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector.egg-info/top_level.txt +0 -0
{tostorchconnector-1.0.7/tostorchconnector.egg-info → tostorchconnector-1.0.9}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tostorchconnector
-Version: 1.0.7
+Version: 1.0.9
 Summary: TOS connector integration for PyTorch
 Author-email: xiangshijian <xiangshijian@bytedance.com>
 Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0
 Requires-Dist: tos>=2.8.0
-Requires-Dist: tosnativeclient>=1.0.
+Requires-Dist: tosnativeclient>=1.0.6
 Dynamic: license-file
 
 # TOS Connector for pytorch
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/pyproject.toml
RENAMED
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tostorchconnector"
-version = "1.0.7"
+version = "1.0.9"
 description = "TOS connector integration for PyTorch"
 authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
 requires-python = ">=3.8,<3.14"
@@ -26,7 +26,7 @@ classifiers = [
 dependencies = [
     "torch >= 2.0",
     "tos>=2.8.0",
-    "tosnativeclient >= 1.0.
+    "tosnativeclient >= 1.0.6"
 ]
 
 [tool.setuptools.packages]
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tests/test_tos_dataset.py
RENAMED
@@ -1,4 +1,5 @@
 import os
+import pickle
 import unittest
 
 from torch.utils.data import DataLoader
@@ -7,6 +8,7 @@ from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
 from tostorchconnector.tos_client import CredentialProvider, ReaderType
 
 USE_NATIVE_CLIENT = True
+READER_TYPE = ReaderType.SEQUENTIAL
 
 
 class TestTosDataSet(unittest.TestCase):
@@ -24,24 +26,50 @@ class TestTosDataSet(unittest.TestCase):
         for i in range(len(datasets)):
             print(datasets[i].bucket, datasets[i].key)
 
+    def test_pickle(self):
+        region = os.getenv('TOS_REGION')
+        endpoint = os.getenv('TOS_ENDPOINT')
+        ak = os.getenv('TOS_ACCESS_KEY')
+        sk = os.getenv('TOS_SECRET_KEY')
+        bucket = 'tos-pytorch-connector'
+        datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
+                                             endpoint=endpoint, cred=CredentialProvider(ak, sk),
+                                             use_native_client=USE_NATIVE_CLIENT)
+        pickled_datasets = pickle.dumps(datasets)
+        assert isinstance(pickled_datasets, bytes)
+        unpickled_datasets = pickle.loads(pickled_datasets)
+        i = 0
+        for dataset in unpickled_datasets:
+            print(dataset.bucket, dataset.key)
+            i += 1
+        print(i)
+
+        pickled = pickle.dumps(unpickled_datasets._data_set)
+        assert isinstance(pickled, bytes)
+
     def test_from_prefix(self):
         region = os.getenv('TOS_REGION')
         endpoint = os.getenv('TOS_ENDPOINT')
         ak = os.getenv('TOS_ACCESS_KEY')
         sk = os.getenv('TOS_SECRET_KEY')
         bucket = 'tos-pytorch-connector'
-        datasets = TosMapDataset.from_prefix(f'tos://{bucket}
+        datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
                                              endpoint=endpoint, cred=CredentialProvider(ak, sk),
                                              use_native_client=USE_NATIVE_CLIENT)
 
+        count = 0
         for i in range(len(datasets)):
-
-            print(
-
-
-
-
-
+            dataset = datasets[i]
+            print(dataset.bucket, dataset.key)
+            count += 1
+            dataset.close()
+            # if i == 1:
+            #     item = datasets[i]
+            #     data = item.read(100)
+            #     print(data)
+            #     print(len(data))
+
+        print(count)
 
     def test_from_prefix_iter(self):
         region = os.getenv('TOS_REGION')
@@ -49,17 +77,23 @@ class TestTosDataSet(unittest.TestCase):
         ak = os.getenv('TOS_ACCESS_KEY')
         sk = os.getenv('TOS_SECRET_KEY')
         bucket = 'tos-pytorch-connector'
-        datasets = TosIterableDataset.from_prefix(f'tos://{bucket}
+        datasets = TosIterableDataset.from_prefix(f'tos://{bucket}', region=region,
                                                   endpoint=endpoint, cred=CredentialProvider(ak, sk),
-                                                  use_native_client=USE_NATIVE_CLIENT)
+                                                  use_native_client=USE_NATIVE_CLIENT, reader_type=ReaderType.RANGED)
         i = 0
         for dataset in datasets:
             print(dataset.bucket, dataset.key)
-            if
-
-
-
+            # if dataset.key == 'tosutil':
+            #     with open('logs/tosutil', 'wb') as f:
+            #         while 1:
+            #             chunk = dataset.read(8192)
+            #             print(len(chunk))
+            #             if not chunk:
+            #                 break
+            #             f.write(chunk)
+            dataset.close()
             i += 1
+        print(i)
 
     def test_checkpoint(self):
         region = os.getenv('TOS_REGION')
@@ -75,6 +109,9 @@ class TestTosDataSet(unittest.TestCase):
             writer.write(b'hello world')
             writer.write(b'hi world')
 
+        with checkpoint.reader(url) as reader:
+            print(reader.read())
+
         with checkpoint.reader(url) as reader:
             data = reader.read(5)
             print(data)
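The new test_pickle case checks that a TosMapDataset built with from_prefix survives a pickle round trip, which is what torch.utils.data.DataLoader relies on when it hands the dataset to worker processes. A minimal sketch of that multi-worker usage, assuming a placeholder bucket and prefix and the same credential environment variables the tests read:

import os

from torch.utils.data import DataLoader

from tostorchconnector import TosMapDataset
from tostorchconnector.tos_client import CredentialProvider


def object_size(obj):
    # module-level function so it stays picklable for DataLoader workers
    return len(obj.read())


dataset = TosMapDataset.from_prefix(
    'tos://my-bucket/train/',  # placeholder prefix
    region=os.getenv('TOS_REGION'),
    endpoint=os.getenv('TOS_ENDPOINT'),
    cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
    transform=object_size,
)

# Each worker unpickles the dataset and lazily rebuilds its own TOS client.
loader = DataLoader(dataset, batch_size=4, num_workers=2)
for batch in loader:
    print(batch)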
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tests/test_tosclient.py
RENAMED
@@ -51,9 +51,9 @@ class TestTosClient(unittest.TestCase):
             for content in objects.contents:
                 count += 1
                 print(content.key, content.size)
-
-
-
+                output = tos_client.head_object(bucket, content.key)
+                assert output.etag == content.etag
+                assert output.size == content.size
 
             print(count)
         except TosException as e:
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_client.py
RENAMED
@@ -1,8 +1,10 @@
 import enum
 import logging
 import os
+import gc
+
 from functools import partial
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Any
 
 import tos
 
@@ -15,6 +17,44 @@ from .tos_object_writer import PutObjectStream
 
 log = logging.getLogger(__name__)
 
+import threading
+import weakref
+import traceback
+
+_client_lock = threading.Lock()
+_client_map = weakref.WeakSet()
+
+
+def _before_fork():
+    with _client_lock:
+        clients = list(_client_map)
+
+    if not clients or len(clients) == 0:
+        return
+
+    try:
+        for client in clients:
+            client._inner_client = None
+
+        _reset_client_map()
+        gc.collect()
+    except Exception as e:
+        log.warning(f'failed to clean up native clients, {str(e)}')
+        traceback.print_exc()
+
+
+def _after_fork_in_child():
+    _reset_client_map()
+
+
+def _reset_client_map():
+    global _client_map
+    with _client_lock:
+        _client_map = weakref.WeakSet()
+
+
+os.register_at_fork(before=_before_fork, after_in_child=_after_fork_in_child)
+
 
 class ReaderType(enum.Enum):
     SEQUENTIAL = 'Sequential'
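The block above tracks every client wrapper in a process-wide weakref.WeakSet and registers os.register_at_fork hooks: before a fork the parent drops each instance's _inner_client and clears the registry, and a freshly forked child resets the registry again, so no child process inherits a live native client handle. A self-contained sketch of that pattern, with a toy Handle class standing in for the native client (Unix-only, an illustration rather than the connector's code):

import os
import threading
import weakref


class Handle:
    # stand-in for a wrapper around a non-fork-safe native resource
    def __init__(self):
        self.inner = object()


_lock = threading.Lock()
_live = weakref.WeakSet()


def _drop_before_fork():
    with _lock:
        for handle in list(_live):
            handle.inner = None  # released in the parent, before the fork happens


os.register_at_fork(before=_drop_before_fork)

h = Handle()
_live.add(h)

if os.fork() == 0:
    print('child sees', h.inner)  # the child never sees the old native handle
    os._exit(0)
os.wait()

The parent loses nothing permanently: as in the connector, the handle is simply rebuilt lazily the next time it is needed.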
@@ -79,40 +119,62 @@ class TosClient(object):
     def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None, log_conf: Optional[TosLogConfig] = None,
                  use_native_client: bool = True):
-
-
-
-        self.
+        self._region = region
+        self._endpoint = endpoint
+        self._cred = CredentialProvider('', '') if cred is None else cred
+        self._client_conf = TosClientConfig() if client_conf is None else client_conf
+        self._log_conf = TosLogConfig() if log_conf is None else log_conf
+        self._part_size = self._client_conf.part_size
         self._use_native_client = use_native_client
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self._inner_client = None
+        self._client_pid = None
+
+    @property
+    def _client(self) -> Any:
+        if self._client_pid is None or self._client_pid != os.getpid() or self._inner_client is None:
+            with _client_lock:
+                if self._use_native_client:
+                    directives = ''
+                    directory = ''
+                    file_name_prefix = ''
+                    if self._log_conf.log_dir and self._log_conf.log_file_name:
+                        if self._log_conf.log_level:
+                            if self._log_conf.log_level == logging.DEBUG:
+                                directives = 'debug'
+                            elif self._log_conf.log_level == logging.INFO:
+                                directives = 'info'
+                            elif self._log_conf.log_level == logging.WARN:
+                                directives = 'warn'
+                            elif self._log_conf.log_level == logging.ERROR:
+                                directives = 'error'
+                            else:
+                                directives = 'info'
+                        directory = self._log_conf.log_dir
+                        file_name_prefix = self._log_conf.log_file_name
+                        # reset log_conf to avoid panic
+                        self._log_conf = TosLogConfig()
+
+                    self._inner_client = tosnativeclient.TosClient(self._region, self._endpoint, self._cred.ak,
+                                                                   self._cred.sk,
+                                                                   self._client_conf.part_size,
+                                                                   self._client_conf.max_retry_count,
+                                                                   directives=directives,
+                                                                   directory=directory,
+                                                                   file_name_prefix=file_name_prefix)
+                else:
+                    self._inner_client = tos.TosClientV2(self._cred.ak, self._cred.sk, endpoint=self._endpoint,
+                                                         region=self._region,
+                                                         max_retry_count=self._client_conf.max_retry_count)
+                    if self._log_conf.log_dir and self._log_conf.log_file_name:
+                        file_path = os.path.join(self._log_conf.log_dir, self._log_conf.log_file_name)
+                        log_level = self._log_conf.log_level if self._log_conf.log_level else logging.INFO
+                        tos.set_logger(file_path=file_path, level=log_level)
+
+                self._client_pid = os.getpid()
+                _client_map.add(self)
+
+        assert self._inner_client is not None
+        return self._inner_client
 
     @property
     def use_native_client(self) -> bool:
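With that change the client handle is no longer built in __init__: _client is a lazy, pid-guarded property that constructs tosnativeclient.TosClient (or tos.TosClientV2) on first use and rebuilds it whenever os.getpid() differs from the process that created it, which is what lets a pickled or forked copy of the dataset work inside DataLoader workers. The guard, reduced to a sketch (object() stands in for the real client constructor):

import os


class ForkAwareResource:
    def __init__(self):
        self._inner = None
        self._pid = None

    @property
    def inner(self):
        # build on first use, and rebuild in any process whose pid differs,
        # e.g. a freshly forked DataLoader worker
        if self._inner is None or self._pid != os.getpid():
            self._inner = object()  # placeholder for the real client constructor
            self._pid = os.getpid()
        return self._inner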
@@ -155,12 +217,15 @@
         return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
 
     def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
-                        delimiter: Optional[str] = None
+                        delimiter: Optional[str] = None,
+                        continuation_token: Optional[str] = None) -> tosnativeclient.ListStream:
         log.debug(f'gen_list_stream tos://{bucket}/{prefix}')
 
         if isinstance(self._client, tosnativeclient.TosClient):
             delimiter = delimiter if delimiter is not None else ''
-
+            continuation_token = continuation_token if continuation_token is not None else ''
+            return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter,
+                                             continuation_token=continuation_token)
         raise NotImplementedError()
 
     def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
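gen_list_stream now threads a continuation_token through to the native list_objects call, so a listing can resume from a saved position instead of restarting at the beginning of the prefix. Roughly how that is used, reusing the method names above (a hedged sketch; error handling and end-of-listing checks omitted):

def list_one_page(client, bucket, prefix, saved_token=None):
    # client is a TosClient backed by the native implementation
    stream = client.gen_list_stream(bucket, prefix, max_keys=1000,
                                    continuation_token=saved_token)
    page = next(stream)                          # one batch of object summaries
    token = stream.current_continuation_token()  # resume point for the next call
    return [content.key for content in page.contents], token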
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_common.py
RENAMED
@@ -8,19 +8,28 @@ from .tos_object_meta import TosObjectMeta
 log = logging.getLogger(__name__)
 
 
+class TosObjectIterable(object):
+    def __init__(self, bucket: str, prefix: str, client: TosClient):
+        self._bucket = bucket
+        self._prefix = prefix
+        self._client = client
+
+    def __iter__(self) -> Iterator[TosObjectMeta]:
+        return iter(TosObjectIterator(self._bucket, self._prefix, self._client))
+
+
 class TosObjectIterator(object):
     def __init__(self, bucket: str, prefix: str, client: TosClient):
         self._bucket = bucket
         self._prefix = prefix
         self._client = client
         self._delimiter: Optional[str] = None
-        self.
+        self._continuation_token: Optional[str] = None
 
+        self._list_stream = None
         self._object_metas = None
         self._index = 0
-
         self._is_truncated = True
-        self._continuation_token = None
 
     def close(self) -> None:
         if self._list_stream is not None:
@@ -33,13 +42,15 @@
         if self._client.use_native_client:
             if self._list_stream is None:
                 self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
-                                                                 delimiter=self._delimiter
+                                                                 delimiter=self._delimiter,
+                                                                 continuation_token=self._continuation_token)
 
             if self._object_metas is None or self._index >= len(self._object_metas):
                 self._object_metas = []
                 self._index = 0
                 while 1:
                     objects = next(self._list_stream)
+                    self._continuation_token = self._list_stream.current_continuation_token()
                     for content in objects.contents:
                         self._object_metas.append(
                             TosObjectMeta(content.bucket, content.key, content.size, content.etag))
@@ -101,4 +112,4 @@ def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iterator[TosObjectMeta]:
 
 def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
     bucket, prefix = parse_tos_url(prefix)
-    return iter(
+    return iter(TosObjectIterable(bucket, prefix, client))
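gen_dataset_from_prefix now returns a TosObjectIterable instead of a bare iterator, so the same dataset description can be walked more than once (each pass gets a fresh TosObjectIterator), and the iterator records the stream's continuation token so a re-created list stream resumes where the previous one stopped. The iterable-versus-iterator distinction in miniature:

class CountIterable:
    # an iterable hands out a new iterator on every __iter__ call
    def __init__(self, n):
        self._n = n

    def __iter__(self):
        return iter(range(self._n))


items = CountIterable(3)
print(list(items))  # [0, 1, 2]
print(list(items))  # [0, 1, 2] again; a bare iterator would already be exhausted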
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_iterable_dataset.py
RENAMED
@@ -16,7 +16,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
     def __init__(self, region: str,
                  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
                  endpoint: Optional[str] = None,
-
+                 transform: Callable[[TosObjectReader], Any] = default_trans,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
                  log_conf: Optional[TosLogConfig] = None,
@@ -27,7 +27,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
         self._gen_dataset = gen_dataset
         self._region = region
         self._endpoint = endpoint
-        self._trans =
+        self._trans = transform
         self._cred = cred
         self._client_conf = client_conf
         self._log_conf = log_conf
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_map_dataset.py
RENAMED
@@ -17,7 +17,7 @@ class TosMapDataset(torch.utils.data.Dataset):
     def __init__(self, region: str,
                  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
                  endpoint: Optional[str] = None,
-
+                 transform: Callable[[TosObjectReader], Any] = default_trans,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
                  log_conf: Optional[TosLogConfig] = None,
@@ -27,7 +27,7 @@ class TosMapDataset(torch.utils.data.Dataset):
         self._gen_dataset = gen_dataset
         self._region = region
         self._endpoint = endpoint
-        self._trans =
+        self._trans = transform
         self._cred = cred
         self._client_conf = client_conf
         self._log_conf = log_conf
@@ -40,7 +40,7 @@ class TosMapDataset(torch.utils.data.Dataset):
 
     @classmethod
     def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
-
+                  transform: Callable[[TosObjectReader], Any] = default_trans,
                   cred: Optional[CredentialProvider] = None,
                   client_conf: Optional[TosClientConfig] = None,
                   log_conf: Optional[TosLogConfig] = None,
@@ -48,12 +48,12 @@ class TosMapDataset(torch.utils.data.Dataset):
                   reader_type: Optional[ReaderType] = None,
                   buffer_size: Optional[int] = None):
         log.info(f'building {cls.__name__} from_urls')
-        return cls(region, partial(gen_dataset_from_urls, urls), endpoint,
+        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
                    use_native_client, reader_type, buffer_size)
 
     @classmethod
     def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
-
+                    transform: Callable[[TosObjectReader], Any] = default_trans,
                     cred: Optional[CredentialProvider] = None,
                     client_conf: Optional[TosClientConfig] = None,
                     log_conf: Optional[TosLogConfig] = None,
@@ -61,7 +61,7 @@ class TosMapDataset(torch.utils.data.Dataset):
                     reader_type: Optional[ReaderType] = None,
                     buffer_size: Optional[int] = None):
         log.info(f'building {cls.__name__} from_prefix')
-        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint,
+        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
                    use_native_client, reader_type, buffer_size)
 
     def __getitem__(self, i: int) -> Any:
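from_urls and from_prefix now accept a transform callable (defaulting to default_trans) and pass it through to the constructor alongside cred, client_conf and log_conf, so each object handed to the training loop can be converted on the fly. A hedged usage sketch with placeholder URLs, taking region and credentials from the same environment variables as the tests:

import os

from tostorchconnector import TosMapDataset
from tostorchconnector.tos_client import CredentialProvider


def first_kilobyte(obj):
    # obj is a TosObjectReader; return its key plus the first 1 KiB of data
    return obj.key, obj.read(1024)


dataset = TosMapDataset.from_urls(
    iter(['tos://my-bucket/a.bin', 'tos://my-bucket/b.bin']),  # placeholder URLs
    region=os.getenv('TOS_REGION'),
    endpoint=os.getenv('TOS_ENDPOINT'),
    cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
    transform=first_kilobyte,
)
print(dataset[0])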
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9}/tostorchconnector/tos_object_reader.py
RENAMED
@@ -85,7 +85,7 @@ class TosObjectStream(object):
     def close(self) -> None:
         if self._sequential_object_stream and isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
             self._sequential_object_stream.close()
-        if self._random_object_stream:
+        if self._random_object_stream and isinstance(self._random_object_stream, tosnativeclient.ReadStream):
             self._random_object_stream.close()
 
 
{tostorchconnector-1.0.7 → tostorchconnector-1.0.9/tostorchconnector.egg-info}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tostorchconnector
-Version: 1.0.7
+Version: 1.0.9
 Summary: TOS connector integration for PyTorch
 Author-email: xiangshijian <xiangshijian@bytedance.com>
 Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0
 Requires-Dist: tos>=2.8.0
-Requires-Dist: tosnativeclient>=1.0.
+Requires-Dist: tosnativeclient>=1.0.6
 Dynamic: license-file
 
 # TOS Connector for pytorch
All remaining renamed files (LICENSE, README.md, setup.cfg, tests/test_tosrawclient.py, and the other entries listed above with +0 -0) contain no changes.