tostorchconnector 1.0.6__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tostorchconnector might be problematic.
- {tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.1.2}/PKG-INFO +2 -2
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/pyproject.toml +2 -2
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tos_dataset.py +52 -14
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tosclient.py +34 -14
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_checkpoint.py +4 -4
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_client.py +79 -38
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_common.py +31 -13
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_iterable_dataset.py +16 -16
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_map_dataset.py +18 -18
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_reader.py +1 -1
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2/tostorchconnector.egg-info}/PKG-INFO +2 -2
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/SOURCES.txt +0 -1
- tostorchconnector-1.1.2/tostorchconnector.egg-info/requires.txt +3 -0
- tostorchconnector-1.0.6/tests/test_tosrawclient.py +0 -102
- tostorchconnector-1.0.6/tostorchconnector.egg-info/requires.txt +0 -3
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/LICENSE +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/README.md +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/setup.cfg +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/__init__.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_meta.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_writer.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/dependency_links.txt +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/top_level.txt +0 -0
{tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.1.2}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tostorchconnector
-Version: 1.0.6
+Version: 1.1.2
 Summary: TOS connector integration for PyTorch
 Author-email: xiangshijian <xiangshijian@bytedance.com>
 Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0
 Requires-Dist: tos>=2.8.0
-Requires-Dist: tosnativeclient>=1.
+Requires-Dist: tosnativeclient>=1.1.1
 Dynamic: license-file
 
 # TOS Connector for pytorch
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/pyproject.toml
RENAMED

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tostorchconnector"
-version = "1.0.6"
+version = "1.1.2"
 description = "TOS connector integration for PyTorch"
 authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
 requires-python = ">=3.8,<3.14"
@@ -26,7 +26,7 @@ classifiers = [
 dependencies = [
     "torch >= 2.0",
     "tos>=2.8.0",
-    "tosnativeclient >= 1.
+    "tosnativeclient >= 1.1.1"
 ]
 
 [tool.setuptools.packages]
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tos_dataset.py
RENAMED

@@ -1,10 +1,12 @@
 import os
+import pickle
 import unittest
 
 from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
 from tostorchconnector.tos_client import CredentialProvider, ReaderType
 
 USE_NATIVE_CLIENT = True
+READER_TYPE = ReaderType.SEQUENTIAL
 
 
 class TestTosDataSet(unittest.TestCase):
@@ -22,23 +24,50 @@ class TestTosDataSet(unittest.TestCase):
         for i in range(len(datasets)):
             print(datasets[i].bucket, datasets[i].key)
 
+    def test_pickle(self):
+        region = os.getenv('TOS_REGION')
+        endpoint = os.getenv('TOS_ENDPOINT')
+        ak = os.getenv('TOS_ACCESS_KEY')
+        sk = os.getenv('TOS_SECRET_KEY')
+        bucket = 'tos-pytorch-connector'
+        datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
+                                             endpoint=endpoint, cred=CredentialProvider(ak, sk),
+                                             use_native_client=USE_NATIVE_CLIENT)
+        pickled_datasets = pickle.dumps(datasets)
+        assert isinstance(pickled_datasets, bytes)
+        unpickled_datasets = pickle.loads(pickled_datasets)
+        i = 0
+        for dataset in unpickled_datasets:
+            print(dataset.bucket, dataset.key)
+            i += 1
+        print(i)
+
+        pickled = pickle.dumps(unpickled_datasets._data_set)
+        assert isinstance(pickled, bytes)
+
     def test_from_prefix(self):
         region = os.getenv('TOS_REGION')
         endpoint = os.getenv('TOS_ENDPOINT')
         ak = os.getenv('TOS_ACCESS_KEY')
         sk = os.getenv('TOS_SECRET_KEY')
         bucket = 'tos-pytorch-connector'
-        datasets = TosMapDataset.from_prefix(f'tos://{bucket}
+        datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
                                              endpoint=endpoint, cred=CredentialProvider(ak, sk),
                                              use_native_client=USE_NATIVE_CLIENT)
+
+        count = 0
         for i in range(len(datasets)):
-
-            print(
-
-
-
-
-
+            dataset = datasets[i]
+            print(dataset.bucket, dataset.key)
+            count += 1
+            dataset.close()
+            # if i == 1:
+            #     item = datasets[i]
+            #     data = item.read(100)
+            #     print(data)
+            #     print(len(data))
+
+        print(count)
 
     def test_from_prefix_iter(self):
         region = os.getenv('TOS_REGION')
@@ -46,17 +75,23 @@ class TestTosDataSet(unittest.TestCase):
         ak = os.getenv('TOS_ACCESS_KEY')
         sk = os.getenv('TOS_SECRET_KEY')
         bucket = 'tos-pytorch-connector'
-        datasets = TosIterableDataset.from_prefix(f'tos://{bucket}
+        datasets = TosIterableDataset.from_prefix(f'tos://{bucket}', region=region,
                                                   endpoint=endpoint, cred=CredentialProvider(ak, sk),
-                                                  use_native_client=USE_NATIVE_CLIENT)
+                                                  use_native_client=USE_NATIVE_CLIENT, reader_type=ReaderType.RANGED)
         i = 0
         for dataset in datasets:
             print(dataset.bucket, dataset.key)
-            if
-
-
-
+            # if dataset.key == 'tosutil':
+            #     with open('logs/tosutil', 'wb') as f:
+            #         while 1:
+            #             chunk = dataset.read(8192)
+            #             print(len(chunk))
+            #             if not chunk:
+            #                 break
+            #             f.write(chunk)
+            dataset.close()
             i += 1
+        print(i)
 
     def test_checkpoint(self):
         region = os.getenv('TOS_REGION')
@@ -72,6 +107,9 @@ class TestTosDataSet(unittest.TestCase):
         writer.write(b'hello world')
         writer.write(b'hi world')
 
+        with checkpoint.reader(url) as reader:
+            print(reader.read())
+
         with checkpoint.reader(url) as reader:
             data = reader.read(5)
             print(data)
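The new test_pickle case above exists because a TosMapDataset has to survive pickling when a PyTorch DataLoader spawns worker processes. A minimal usage sketch along those lines, assuming the from_prefix signature shown in this diff and credentials in the same environment variables the tests read; the key_and_size transform is illustrative, not part of the package:

    import os

    from torch.utils.data import DataLoader
    from tostorchconnector import TosMapDataset
    from tostorchconnector.tos_client import CredentialProvider, ReaderType


    def key_and_size(obj):
        # obj is the same kind of reader the tests iterate: it exposes .key and .read().
        return obj.key, len(obj.read())


    # Built the same way the updated tests build it.
    dataset = TosMapDataset.from_prefix(
        'tos://tos-pytorch-connector',
        region=os.getenv('TOS_REGION'),
        endpoint=os.getenv('TOS_ENDPOINT'),
        cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
        transform=key_and_size,
        reader_type=ReaderType.SEQUENTIAL,
    )

    # num_workers > 0 means the dataset (and the TosClient inside it) is pickled
    # into each worker process -- exactly what test_pickle asserts.
    loader = DataLoader(dataset, batch_size=4, num_workers=2)
    for keys, sizes in loader:
        print(keys, sizes)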
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tosclient.py
RENAMED

@@ -1,4 +1,5 @@
 import os
+import pickle
 import unittest
 import uuid
 
@@ -6,15 +7,40 @@ from tosnativeclient import TosClient, TosException
 
 
 class TestTosClient(unittest.TestCase):
+    def test_pickle(self):
+        region = os.getenv('TOS_REGION')
+        endpoint = os.getenv('TOS_ENDPOINT')
+        ak = os.getenv('TOS_ACCESS_KEY')
+        sk = os.getenv('TOS_SECRET_KEY')
+        bucket = 'tos-pytorch-connector'
+        tos_client = TosClient(region, endpoint, ak, sk)
+        pickled_tos_client = pickle.dumps(tos_client)
+        assert isinstance(pickled_tos_client, bytes)
+        unpickled_tos_client = pickle.loads(pickled_tos_client)
+
+        self._test_list_objects(bucket, unpickled_tos_client)
+        self._test_write_read_object(bucket, unpickled_tos_client)
+
     def test_list_objects(self):
         region = os.getenv('TOS_REGION')
         endpoint = os.getenv('TOS_ENDPOINT')
         ak = os.getenv('TOS_ACCESS_KEY')
         sk = os.getenv('TOS_SECRET_KEY')
         bucket = 'tos-pytorch-connector'
-        tos_client = TosClient(region, endpoint, ak, sk
-
+        tos_client = TosClient(region, endpoint, ak, sk)
+        self._test_list_objects(bucket, tos_client)
 
+    def test_write_read_object(self):
+        region = os.getenv('TOS_REGION')
+        endpoint = os.getenv('TOS_ENDPOINT')
+        ak = os.getenv('TOS_ACCESS_KEY')
+        sk = os.getenv('TOS_SECRET_KEY')
+        bucket = 'tos-pytorch-connector'
+        tos_client = TosClient(region, endpoint, ak, sk)
+
+        self._test_write_read_object(bucket, tos_client)
+
+    def _test_list_objects(self, bucket, tos_client):
         list_stream = tos_client.list_objects(bucket, '', max_keys=1000)
         count = 0
         try:
@@ -22,23 +48,15 @@ class TestTosClient(unittest.TestCase):
                 for content in objects.contents:
                     count += 1
                     print(content.key, content.size)
-
-
-
+                    output = tos_client.head_object(bucket, content.key)
+                    assert output.etag == content.etag
+                    assert output.size == content.size
 
             print(count)
         except TosException as e:
             print(e.args[0].message)
 
-    def
-        region = os.getenv('TOS_REGION')
-        endpoint = os.getenv('TOS_ENDPOINT')
-        ak = os.getenv('TOS_ACCESS_KEY')
-        sk = os.getenv('TOS_SECRET_KEY')
-        bucket = 'tos-pytorch-connector'
-        tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
-                               file_name_prefix='app.log')
-
+    def _test_write_read_object(self, bucket, tos_client):
         key = str(uuid.uuid4())
         read_stream = tos_client.get_object(bucket, key, '', 1)
 
@@ -59,6 +77,8 @@ class TestTosClient(unittest.TestCase):
         write_stream.close()
 
         output = tos_client.head_object(bucket, key)
+        print(output.etag, output.size)
+
         read_stream = tos_client.get_object(bucket, key, output.etag, output.size)
         try:
             offset = 0
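For context, the reorganized client tests factor the listing and write/read round trips into helpers so they run both against a freshly built TosClient and against one that has been through pickle. A condensed sketch of that round trip, assuming the list stream is iterable in the same way tos_common.py consumes it with next():

    import os
    import pickle

    from tosnativeclient import TosClient

    client = TosClient(os.getenv('TOS_REGION'), os.getenv('TOS_ENDPOINT'),
                       os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY'))

    # The native client is expected to survive a pickle round trip (test_pickle).
    client = pickle.loads(pickle.dumps(client))

    bucket = 'tos-pytorch-connector'
    for page in client.list_objects(bucket, '', max_keys=1000):
        for obj in page.contents:
            # head_object should agree with the listing metadata.
            head = client.head_object(bucket, obj.key)
            assert head.etag == obj.etag and head.size == obj.size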
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_checkpoint.py
RENAMED

@@ -13,14 +13,14 @@ class TosCheckpoint(object):
                  endpoint: Optional[str] = None,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
-
+                 use_native_client=True,
+                 enable_crc=True):
         self._region = region
         self._endpoint = endpoint
         self._cred = cred
         self._client_conf = client_conf
-        self.
-
-                                 use_native_client)
+        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, use_native_client,
+                                 enable_crc)
         log.info('TosCheckpoint init tos client succeed')
 
     def reader(self, url: str, reader_type: Optional[ReaderType] = None,
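TosCheckpoint now forwards use_native_client and enable_crc straight into the shared TosClient. A save/load sketch, assuming the constructor takes region first like the datasets and that writer(url) is the context-manager counterpart of the reader(url) shown above; only the reader.read() and writer.write() calls exercised by the tests are relied on:

    import io
    import os

    import torch
    from tostorchconnector import TosCheckpoint
    from tostorchconnector.tos_client import CredentialProvider

    checkpoint = TosCheckpoint(os.getenv('TOS_REGION'),
                               endpoint=os.getenv('TOS_ENDPOINT'),
                               cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'),
                                                       os.getenv('TOS_SECRET_KEY')),
                               enable_crc=True)

    url = 'tos://tos-pytorch-connector/checkpoints/model.pt'
    model = torch.nn.Linear(4, 2)

    # Serialize locally, then stream the bytes to TOS.
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    with checkpoint.writer(url) as writer:
        writer.write(buf.getvalue())

    # Read the object back and restore the weights.
    with checkpoint.reader(url) as reader:
        model.load_state_dict(torch.load(io.BytesIO(reader.read())))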
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_client.py
RENAMED

@@ -1,8 +1,10 @@
 import enum
 import logging
 import os
+import gc
+
 from functools import partial
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Any
 
 import tos
 
@@ -15,6 +17,44 @@ from .tos_object_writer import PutObjectStream
 
 log = logging.getLogger(__name__)
 
+import threading
+import weakref
+import traceback
+
+_client_lock = threading.Lock()
+_client_map = weakref.WeakSet()
+
+
+def _before_fork():
+    with _client_lock:
+        clients = list(_client_map)
+
+    if not clients or len(clients) == 0:
+        return
+
+    try:
+        for client in clients:
+            client._inner_client = None
+
+        _reset_client_map()
+        gc.collect()
+    except Exception as e:
+        log.warning(f'failed to clean up native clients, {str(e)}')
+        traceback.print_exc()
+
+
+def _after_fork_in_child():
+    _reset_client_map()
+
+
+def _reset_client_map():
+    global _client_map
+    with _client_lock:
+        _client_map = weakref.WeakSet()
+
+
+os.register_at_fork(before=_before_fork, after_in_child=_after_fork_in_child)
+
 
 class ReaderType(enum.Enum):
     SEQUENTIAL = 'Sequential'
@@ -77,42 +117,38 @@ class TosLogConfig(object):
 
 class TosClient(object):
     def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
-                 client_conf: Optional[TosClientConfig] = None,
-
-
-
-
-        self.
+                 client_conf: Optional[TosClientConfig] = None, use_native_client: bool = True,
+                 enable_crc: bool = True):
+        self._region = region
+        self._endpoint = endpoint
+        self._cred = CredentialProvider('', '') if cred is None else cred
+        self._client_conf = TosClientConfig() if client_conf is None else client_conf
+        self._part_size = self._client_conf.part_size
         self._use_native_client = use_native_client
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                 max_retry_count=client_conf.max_retry_count)
-        if log_conf.log_dir and log_conf.log_file_name:
-            file_path = os.path.join(log_conf.log_dir, log_conf.log_file_name)
-            log_level = log_conf.log_level if log_conf.log_level else logging.INFO
-            tos.set_logger(file_path=file_path, level=log_level)
+        self._inner_client = None
+        self._client_pid = None
+        self._enable_crc = enable_crc
+
+    @property
+    def _client(self) -> Any:
+        if self._client_pid is None or self._client_pid != os.getpid() or self._inner_client is None:
+            with _client_lock:
+                if self._use_native_client:
+                    self._inner_client = tosnativeclient.TosClient(self._region, self._endpoint, self._cred.ak,
+                                                                   self._cred.sk,
+                                                                   self._client_conf.part_size,
+                                                                   self._client_conf.max_retry_count,
+                                                                   enable_crc=self._enable_crc)
+                else:
+                    self._inner_client = tos.TosClientV2(self._cred.ak, self._cred.sk, endpoint=self._endpoint,
+                                                         region=self._region,
+                                                         max_retry_count=self._client_conf.max_retry_count,
+                                                         enable_crc=self._enable_crc)
+                self._client_pid = os.getpid()
+                _client_map.add(self)
+
+        assert self._inner_client is not None
+        return self._inner_client
 
     @property
     def use_native_client(self) -> bool:
@@ -155,12 +191,17 @@ class TosClient(object):
         return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
 
     def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
-                        delimiter: Optional[str] = None
+                        delimiter: Optional[str] = None,
+                        continuation_token: Optional[str] = None,
+                        list_background_buffer_count: int = 1) -> tosnativeclient.ListStream:
         log.debug(f'gen_list_stream tos://{bucket}/{prefix}')
 
         if isinstance(self._client, tosnativeclient.TosClient):
            delimiter = delimiter if delimiter is not None else ''
-
+            continuation_token = continuation_token if continuation_token is not None else ''
+            return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter,
+                                             continuation_token=continuation_token,
+                                             list_background_buffer_count=list_background_buffer_count)
        raise NotImplementedError()
 
     def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
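The largest change in tos_client.py is fork safety: the native client is no longer created eagerly but lazily through the _client property, keyed on os.getpid(), and os.register_at_fork clears every live handle before a fork so child processes rebuild their own. A stripped-down sketch of the same pattern for a generic resource; ExpensiveHandle and LazyHolder are hypothetical stand-ins, not part of this package:

    import os
    import threading
    import weakref

    _lock = threading.Lock()
    _holders = weakref.WeakSet()


    class ExpensiveHandle:
        # Stand-in for a native client that must not cross a fork.
        def __init__(self) -> None:
            self.pid = os.getpid()


    class LazyHolder:
        def __init__(self) -> None:
            self._inner = None
            self._pid = None

        @property
        def client(self) -> ExpensiveHandle:
            # Rebuild when unset or when we are running in a different process.
            if self._inner is None or self._pid != os.getpid():
                with _lock:
                    self._inner = ExpensiveHandle()
                    self._pid = os.getpid()
                    _holders.add(self)
            return self._inner


    def _before_fork() -> None:
        # Drop handles in the parent so children never inherit them; the real
        # code additionally resets its registry and runs gc.collect() here.
        with _lock:
            for holder in list(_holders):
                holder._inner = None


    os.register_at_fork(before=_before_fork)

    if __name__ == '__main__':
        holder = LazyHolder()
        print('parent', holder.client.pid)
        if os.fork() == 0:                       # POSIX only; spawned workers behave similarly
            print('child', holder.client.pid)    # rebuilt lazily in the child
            os._exit(0)
        os.wait()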
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_common.py
RENAMED

@@ -1,6 +1,7 @@
 import logging
 from typing import Union, Iterator, Tuple, Optional
 
+from tosnativeclient import TosObject
 from . import TosObjectReader
 from .tos_client import TosClient
 from .tos_object_meta import TosObjectMeta
@@ -8,19 +9,30 @@ from .tos_object_meta import TosObjectMeta
 log = logging.getLogger(__name__)
 
 
-class
+class TosObjectIterable(object):
     def __init__(self, bucket: str, prefix: str, client: TosClient):
         self._bucket = bucket
         self._prefix = prefix
+        self._list_background_buffer_count = 3
+        self._client = client
+
+    def __iter__(self) -> Iterator[TosObjectMeta]:
+        return iter(TosObjectIterator(self._bucket, self._prefix, self._list_background_buffer_count, self._client))
+
+
+class TosObjectIterator(object):
+    def __init__(self, bucket: str, prefix: str, list_background_buffer_count: int, client: TosClient):
+        self._bucket = bucket
+        self._prefix = prefix
+        self._list_background_buffer_count = list_background_buffer_count
         self._client = client
         self._delimiter: Optional[str] = None
-        self.
+        self._continuation_token: Optional[str] = None
 
+        self._list_stream = None
         self._object_metas = None
         self._index = 0
-
         self._is_truncated = True
-        self._continuation_token = None
 
     def close(self) -> None:
         if self._list_stream is not None:
@@ -33,22 +45,28 @@ class TosObjectIterator(object):
         if self._client.use_native_client:
             if self._list_stream is None:
                 self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
-                                                                 delimiter=self._delimiter
+                                                                 delimiter=self._delimiter,
+                                                                 continuation_token=self._continuation_token,
+                                                                 list_background_buffer_count=self._list_background_buffer_count)
 
             if self._object_metas is None or self._index >= len(self._object_metas):
-                self._object_metas =
+                self._object_metas = None
                 self._index = 0
                 while 1:
-
-
-
-
-
+                    try:
+                        objects = next(self._list_stream)
+                    except:
+                        self.close()
+                        raise
+                    self._continuation_token = self._list_stream.current_continuation_token()
+                    self._object_metas = objects.contents
+                    if self._object_metas is not None and len(self._object_metas) > 0:
                         break
 
             object_meta = self._object_metas[self._index]
             self._index += 1
-
+            # this is very critical
+            return TosObjectMeta(self._bucket, object_meta.key, object_meta.size, object_meta.etag)
 
         while self._object_metas is None or self._index >= len(self._object_metas):
             if not self._is_truncated:
@@ -101,4 +119,4 @@ def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iter
 
 def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
     bucket, prefix = parse_tos_url(prefix)
-    return iter(
+    return iter(TosObjectIterable(bucket, prefix, client))
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_iterable_dataset.py
RENAMED

@@ -4,8 +4,9 @@ from typing import Iterator, Any, Optional, Callable, Union
 
 import torch
 
+from tosnativeclient import TosObject
 from . import TosObjectReader
-from .tos_client import CredentialProvider, TosClientConfig, TosClient,
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, ReaderType
 from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
 from .tos_object_meta import TosObjectMeta
 
@@ -16,21 +17,20 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
     def __init__(self, region: str,
                  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
                  endpoint: Optional[str] = None,
-
+                 transform: Callable[[TosObjectReader], Any] = default_trans,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
-                 log_conf: Optional[TosLogConfig] = None,
                  sharding: bool = False,
                  use_native_client: bool = True,
                  reader_type: Optional[ReaderType] = None,
-                 buffer_size: Optional[int] = None
+                 buffer_size: Optional[int] = None,
+                 enable_crc: bool = True):
         self._gen_dataset = gen_dataset
         self._region = region
         self._endpoint = endpoint
-        self._trans =
+        self._trans = transform
         self._cred = cred
         self._client_conf = client_conf
-        self._log_conf = log_conf
         self._sharding = sharding
         if torch.distributed.is_initialized():
             self._rank = torch.distributed.get_rank()
@@ -40,8 +40,8 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
             self._world_size = 1
         self._reader_type = reader_type
         self._buffer_size = buffer_size
-        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf,
-                                 use_native_client)
+        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf,
+                                 use_native_client, enable_crc)
         log.info('TosIterableDataset init tos client succeed')
 
     @classmethod
@@ -49,28 +49,28 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
                   transform: Callable[[TosObjectReader], Any] = default_trans,
                   cred: Optional[CredentialProvider] = None,
                   client_conf: Optional[TosClientConfig] = None,
-                  log_conf: Optional[TosLogConfig] = None,
                   sharding: bool = False,
                   use_native_client: bool = True,
                   reader_type: Optional[ReaderType] = None,
-                  buffer_size: Optional[int] = None
+                  buffer_size: Optional[int] = None,
+                  enable_crc: bool = True):
         log.info(f'building {cls.__name__} from_urls')
-        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf,
-                   sharding, use_native_client, reader_type, buffer_size)
+        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf,
+                   sharding, use_native_client, reader_type, buffer_size, enable_crc)
 
     @classmethod
     def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
                     transform: Callable[[TosObjectReader], Any] = default_trans,
                     cred: Optional[CredentialProvider] = None,
                     client_conf: Optional[TosClientConfig] = None,
-                    log_conf: Optional[TosLogConfig] = None,
                     sharding: bool = False,
                     use_native_client: bool = True,
                     reader_type: Optional[ReaderType] = None,
-                    buffer_size: Optional[int] = None
+                    buffer_size: Optional[int] = None,
+                    enable_crc: bool = True):
         log.info(f'building {cls.__name__} from_prefix')
-        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf,
-                   sharding, use_native_client, reader_type, buffer_size)
+        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf,
+                   sharding, use_native_client, reader_type, buffer_size, enable_crc)
 
     def __iter__(self) -> Iterator[Any]:
         worker_id = 0
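TosIterableDataset keeps the same lazy listing but now threads enable_crc through to the client, and with sharding=True the constructor records the distributed rank and world size so __iter__ can split objects across ranks and DataLoader workers. A usage sketch, assuming the from_prefix signature in this diff; the to_bytes transform is illustrative and simply makes items plain bytes so they cross worker process boundaries cleanly:

    import os

    from torch.utils.data import DataLoader
    from tostorchconnector import TosIterableDataset
    from tostorchconnector.tos_client import CredentialProvider, ReaderType


    def to_bytes(obj):
        # obj exposes .read() like the readers iterated in the tests.
        return obj.read()


    dataset = TosIterableDataset.from_prefix(
        'tos://tos-pytorch-connector/train/',
        region=os.getenv('TOS_REGION'),
        endpoint=os.getenv('TOS_ENDPOINT'),
        cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
        transform=to_bytes,
        sharding=True,                      # split the listing across ranks/workers
        reader_type=ReaderType.RANGED,      # same mode the updated iterable test uses
        enable_crc=True,
    )

    # Objects are listed lazily inside each worker; nothing is materialized up front.
    loader = DataLoader(dataset, batch_size=None, num_workers=2)
    for data in loader:
        print(len(data))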
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_map_dataset.py
RENAMED

@@ -4,8 +4,9 @@ from typing import Any, Callable, Iterator, Optional, List, Union
 
 import torch
 
+from tosnativeclient import TosObject
 from . import TosObjectReader
-from .tos_client import CredentialProvider, TosClientConfig, TosClient,
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, ReaderType
 from .tos_common import default_trans, gen_dataset_from_prefix, \
     gen_dataset_from_urls
 from .tos_object_meta import TosObjectMeta
@@ -17,52 +18,51 @@ class TosMapDataset(torch.utils.data.Dataset):
     def __init__(self, region: str,
                  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
                  endpoint: Optional[str] = None,
-
+                 transform: Callable[[TosObjectReader], Any] = default_trans,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
-                 log_conf: Optional[TosLogConfig] = None,
                  use_native_client: bool = True,
                  reader_type: Optional[ReaderType] = None,
-                 buffer_size: Optional[int] = None
+                 buffer_size: Optional[int] = None,
+                 enable_crc: bool = True):
         self._gen_dataset = gen_dataset
         self._region = region
         self._endpoint = endpoint
-        self._trans =
+        self._trans = transform
         self._cred = cred
         self._client_conf = client_conf
-        self._log_conf = log_conf
         self._dataset: Optional[List[TosObjectMeta]] = None
         self._reader_type = reader_type
         self._buffer_size = buffer_size
-        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf,
-                                 use_native_client)
+        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf,
+                                 use_native_client, enable_crc)
         log.info('TosMapDataset init tos client succeed')
 
     @classmethod
     def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
-
+                  transform: Callable[[TosObjectReader], Any] = default_trans,
                   cred: Optional[CredentialProvider] = None,
                   client_conf: Optional[TosClientConfig] = None,
-                  log_conf: Optional[TosLogConfig] = None,
                   use_native_client: bool = True,
                   reader_type: Optional[ReaderType] = None,
-                  buffer_size: Optional[int] = None
+                  buffer_size: Optional[int] = None,
+                  enable_crc: bool = True):
         log.info(f'building {cls.__name__} from_urls')
-        return cls(region, partial(gen_dataset_from_urls, urls), endpoint,
-                   use_native_client, reader_type, buffer_size)
+        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf,
+                   use_native_client, reader_type, buffer_size, enable_crc)
 
     @classmethod
     def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
-
+                    transform: Callable[[TosObjectReader], Any] = default_trans,
                     cred: Optional[CredentialProvider] = None,
                     client_conf: Optional[TosClientConfig] = None,
-                    log_conf: Optional[TosLogConfig] = None,
                     use_native_client: bool = True,
                     reader_type: Optional[ReaderType] = None,
-                    buffer_size: Optional[int] = None
+                    buffer_size: Optional[int] = None,
+                    enable_crc: bool = True):
         log.info(f'building {cls.__name__} from_prefix')
-        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint,
-                   use_native_client, reader_type, buffer_size)
+        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf,
+                   use_native_client, reader_type, buffer_size, enable_crc)
 
     def __getitem__(self, i: int) -> Any:
         return self._trans_tos_object(i)
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_reader.py
RENAMED

@@ -85,7 +85,7 @@ class TosObjectStream(object):
     def close(self) -> None:
         if self._sequential_object_stream and isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
             self._sequential_object_stream.close()
-        if self._random_object_stream:
+        if self._random_object_stream and isinstance(self._random_object_stream, tosnativeclient.ReadStream):
             self._random_object_stream.close()
 
 
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2/tostorchconnector.egg-info}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tostorchconnector
-Version: 1.0.6
+Version: 1.1.2
 Summary: TOS connector integration for PyTorch
 Author-email: xiangshijian <xiangshijian@bytedance.com>
 Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch>=2.0
 Requires-Dist: tos>=2.8.0
-Requires-Dist: tosnativeclient>=1.
+Requires-Dist: tosnativeclient>=1.1.1
 Dynamic: license-file
 
 # TOS Connector for pytorch
tostorchconnector-1.0.6/tests/test_tosrawclient.py
DELETED

@@ -1,102 +0,0 @@
-import os
-import unittest
-import uuid
-
-import tos
-from tos.exceptions import TosServerError
-
-
-class TestTosRawClient(unittest.TestCase):
-    def test_put_get(self):
-        from tosnativeclient import TosRawClient, HeadObjectInput, TosException, DeleteObjectInput, \
-            PutObjectFromBufferInput, \
-            GetObjectInput, GetObjectOutput
-
-        region = os.getenv('TOS_REGION')
-        endpoint = os.getenv('TOS_ENDPOINT')
-        ak = os.getenv('TOS_ACCESS_KEY')
-        sk = os.getenv('TOS_SECRET_KEY')
-        bucket = 'tos-pytorch-connector'
-        key = str(uuid.uuid4())
-        tos_raw_client = TosRawClient(region, endpoint, ak, sk)
-
-        doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
-        assert doutput.status_code == 204
-
-        try:
-            tos_raw_client.head_object(HeadObjectInput(bucket, key))
-            assert False
-        except TosException as e:
-            assert e.args[0].status_code == 404
-            assert len(e.args[0].request_id) > 0
-
-        data = str(uuid.uuid4()).encode('utf-8')
-        input = PutObjectFromBufferInput(bucket, key, data)
-        poutput = tos_raw_client.put_object_from_buffer(input)
-        assert poutput.status_code == 200
-        assert len(poutput.etag) > 0
-
-        houtput = tos_raw_client.head_object(HeadObjectInput(bucket, key))
-        assert houtput.status_code == 200
-        assert houtput.etag == poutput.etag
-        assert houtput.content_length == len(data)
-
-        goutput: GetObjectOutput = tos_raw_client.get_object(GetObjectInput(bucket, key))
-        assert goutput.status_code == 200
-        assert goutput.etag == poutput.etag
-        rdata = goutput.read_all()
-        assert rdata == data
-
-        doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
-        assert doutput.status_code == 204
-
-        try:
-            tos_raw_client.get_object(GetObjectInput(bucket, key))
-            assert False
-        except TosException as e:
-            assert e.args[0].status_code == 404
-            assert len(e.args[0].request_id) > 0
-
-    def test_put_get_old(self):
-        region = os.getenv('TOS_REGION')
-        endpoint = os.getenv('TOS_ENDPOINT')
-        ak = os.getenv('TOS_ACCESS_KEY')
-        sk = os.getenv('TOS_SECRET_KEY')
-        bucket = 'tos-pytorch-connector'
-        key = str(uuid.uuid4())
-        tos_client = tos.TosClientV2(ak, sk, endpoint=endpoint, region=region)
-        doutput = tos_client.delete_object(bucket, key)
-        assert doutput.status_code == 204
-
-        try:
-            tos_client.head_object(bucket, key)
-            assert False
-        except TosServerError as e:
-            assert e.status_code == 404
-            assert len(e.request_id) > 0
-
-        data = str(uuid.uuid4()).encode('utf-8')
-        poutput = tos_client.put_object(bucket, key, content=data)
-        assert poutput.status_code == 200
-        assert len(poutput.etag) > 0
-
-        houtput = tos_client.head_object(bucket, key)
-        assert houtput.status_code == 200
-        assert houtput.etag == poutput.etag
-        assert houtput.content_length == len(data)
-
-        goutput = tos_client.get_object(bucket, key)
-        assert goutput.status_code == 200
-        assert goutput.etag == poutput.etag
-        rdata = goutput.read()
-        assert rdata == data
-
-        doutput = tos_client.delete_object(bucket, key)
-        assert doutput.status_code == 204
-
-        try:
-            tos_client.get_object(bucket, key)
-            assert False
-        except TosServerError as e:
-            assert e.status_code == 404
-            assert len(e.request_id) > 0
{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/LICENSE
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/README.md
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/setup.cfg
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/__init__.py
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_meta.py
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_writer.py
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/dependency_links.txt
RENAMED
File without changes

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/top_level.txt
RENAMED
File without changes