tostorchconnector 1.0.6__tar.gz → 1.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tostorchconnector might be problematic. Click here for more details.
- {tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.0.8}/PKG-INFO +2 -2
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/pyproject.toml +2 -2
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tos_dataset.py +54 -14
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tosclient.py +35 -12
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tosrawclient.py +15 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_client.py +98 -36
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_common.py +16 -5
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_iterable_dataset.py +2 -2
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_map_dataset.py +6 -6
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_reader.py +1 -1
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8/tostorchconnector.egg-info}/PKG-INFO +2 -2
- tostorchconnector-1.0.8/tostorchconnector.egg-info/requires.txt +3 -0
- tostorchconnector-1.0.6/tostorchconnector.egg-info/requires.txt +0 -3
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/LICENSE +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/README.md +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/setup.cfg +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/__init__.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_checkpoint.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_meta.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_writer.py +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/SOURCES.txt +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/dependency_links.txt +0 -0
- {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tostorchconnector
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.8
|
|
4
4
|
Summary: TOS connector integration for PyTorch
|
|
5
5
|
Author-email: xiangshijian <xiangshijian@bytedance.com>
|
|
6
6
|
Classifier: Development Status :: 4 - Beta
|
|
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
|
|
|
19
19
|
License-File: LICENSE
|
|
20
20
|
Requires-Dist: torch>=2.0
|
|
21
21
|
Requires-Dist: tos>=2.8.0
|
|
22
|
-
Requires-Dist: tosnativeclient>=1.0.
|
|
22
|
+
Requires-Dist: tosnativeclient>=1.0.6
|
|
23
23
|
Dynamic: license-file
|
|
24
24
|
|
|
25
25
|
# TOS Connector for pytorch
|
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
|
|
6
6
|
[project]
|
|
7
7
|
name = "tostorchconnector"
|
|
8
|
-
version = "1.0.
|
|
8
|
+
version = "1.0.8"
|
|
9
9
|
description = "TOS connector integration for PyTorch"
|
|
10
10
|
authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
|
|
11
11
|
requires-python = ">=3.8,<3.14"
|
|
@@ -26,7 +26,7 @@ classifiers = [
|
|
|
26
26
|
dependencies = [
|
|
27
27
|
"torch >= 2.0",
|
|
28
28
|
"tos>=2.8.0",
|
|
29
|
-
"tosnativeclient >= 1.0.
|
|
29
|
+
"tosnativeclient >= 1.0.6"
|
|
30
30
|
]
|
|
31
31
|
|
|
32
32
|
[tool.setuptools.packages]
|
|
@@ -1,10 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import pickle
|
|
2
3
|
import unittest
|
|
3
4
|
|
|
5
|
+
from torch.utils.data import DataLoader
|
|
6
|
+
|
|
4
7
|
from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
|
|
5
8
|
from tostorchconnector.tos_client import CredentialProvider, ReaderType
|
|
6
9
|
|
|
7
10
|
USE_NATIVE_CLIENT = True
|
|
11
|
+
READER_TYPE = ReaderType.SEQUENTIAL
|
|
8
12
|
|
|
9
13
|
|
|
10
14
|
class TestTosDataSet(unittest.TestCase):
|
|
@@ -22,23 +26,50 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
22
26
|
for i in range(len(datasets)):
|
|
23
27
|
print(datasets[i].bucket, datasets[i].key)
|
|
24
28
|
|
|
29
|
+
def test_pickle(self):
|
|
30
|
+
region = os.getenv('TOS_REGION')
|
|
31
|
+
endpoint = os.getenv('TOS_ENDPOINT')
|
|
32
|
+
ak = os.getenv('TOS_ACCESS_KEY')
|
|
33
|
+
sk = os.getenv('TOS_SECRET_KEY')
|
|
34
|
+
bucket = 'tos-pytorch-connector'
|
|
35
|
+
datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
|
|
36
|
+
endpoint=endpoint, cred=CredentialProvider(ak, sk),
|
|
37
|
+
use_native_client=USE_NATIVE_CLIENT)
|
|
38
|
+
pickled_datasets = pickle.dumps(datasets)
|
|
39
|
+
assert isinstance(pickled_datasets, bytes)
|
|
40
|
+
unpickled_datasets = pickle.loads(pickled_datasets)
|
|
41
|
+
i = 0
|
|
42
|
+
for dataset in unpickled_datasets:
|
|
43
|
+
print(dataset.bucket, dataset.key)
|
|
44
|
+
i += 1
|
|
45
|
+
print(i)
|
|
46
|
+
|
|
47
|
+
pickled = pickle.dumps(unpickled_datasets._data_set)
|
|
48
|
+
assert isinstance(pickled, bytes)
|
|
49
|
+
|
|
25
50
|
def test_from_prefix(self):
|
|
26
51
|
region = os.getenv('TOS_REGION')
|
|
27
52
|
endpoint = os.getenv('TOS_ENDPOINT')
|
|
28
53
|
ak = os.getenv('TOS_ACCESS_KEY')
|
|
29
54
|
sk = os.getenv('TOS_SECRET_KEY')
|
|
30
55
|
bucket = 'tos-pytorch-connector'
|
|
31
|
-
datasets = TosMapDataset.from_prefix(f'tos://{bucket}
|
|
56
|
+
datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
|
|
32
57
|
endpoint=endpoint, cred=CredentialProvider(ak, sk),
|
|
33
58
|
use_native_client=USE_NATIVE_CLIENT)
|
|
59
|
+
|
|
60
|
+
count = 0
|
|
34
61
|
for i in range(len(datasets)):
|
|
35
|
-
|
|
36
|
-
print(
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
62
|
+
dataset = datasets[i]
|
|
63
|
+
print(dataset.bucket, dataset.key)
|
|
64
|
+
count += 1
|
|
65
|
+
dataset.close()
|
|
66
|
+
# if i == 1:
|
|
67
|
+
# item = datasets[i]
|
|
68
|
+
# data = item.read(100)
|
|
69
|
+
# print(data)
|
|
70
|
+
# print(len(data))
|
|
71
|
+
|
|
72
|
+
print(count)
|
|
42
73
|
|
|
43
74
|
def test_from_prefix_iter(self):
|
|
44
75
|
region = os.getenv('TOS_REGION')
|
|
@@ -46,17 +77,23 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
46
77
|
ak = os.getenv('TOS_ACCESS_KEY')
|
|
47
78
|
sk = os.getenv('TOS_SECRET_KEY')
|
|
48
79
|
bucket = 'tos-pytorch-connector'
|
|
49
|
-
datasets = TosIterableDataset.from_prefix(f'tos://{bucket}
|
|
80
|
+
datasets = TosIterableDataset.from_prefix(f'tos://{bucket}', region=region,
|
|
50
81
|
endpoint=endpoint, cred=CredentialProvider(ak, sk),
|
|
51
|
-
use_native_client=USE_NATIVE_CLIENT)
|
|
82
|
+
use_native_client=USE_NATIVE_CLIENT, reader_type=ReaderType.RANGED)
|
|
52
83
|
i = 0
|
|
53
84
|
for dataset in datasets:
|
|
54
85
|
print(dataset.bucket, dataset.key)
|
|
55
|
-
if
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
86
|
+
# if dataset.key == 'tosutil':
|
|
87
|
+
# with open('logs/tosutil', 'wb') as f:
|
|
88
|
+
# while 1:
|
|
89
|
+
# chunk = dataset.read(8192)
|
|
90
|
+
# print(len(chunk))
|
|
91
|
+
# if not chunk:
|
|
92
|
+
# break
|
|
93
|
+
# f.write(chunk)
|
|
94
|
+
dataset.close()
|
|
59
95
|
i += 1
|
|
96
|
+
print(i)
|
|
60
97
|
|
|
61
98
|
def test_checkpoint(self):
|
|
62
99
|
region = os.getenv('TOS_REGION')
|
|
@@ -72,6 +109,9 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
72
109
|
writer.write(b'hello world')
|
|
73
110
|
writer.write(b'hi world')
|
|
74
111
|
|
|
112
|
+
with checkpoint.reader(url) as reader:
|
|
113
|
+
print(reader.read())
|
|
114
|
+
|
|
75
115
|
with checkpoint.reader(url) as reader:
|
|
76
116
|
data = reader.read(5)
|
|
77
117
|
print(data)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import pickle
|
|
2
3
|
import unittest
|
|
3
4
|
import uuid
|
|
4
5
|
|
|
@@ -6,6 +7,21 @@ from tosnativeclient import TosClient, TosException
|
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class TestTosClient(unittest.TestCase):
|
|
10
|
+
def test_pickle(self):
|
|
11
|
+
region = os.getenv('TOS_REGION')
|
|
12
|
+
endpoint = os.getenv('TOS_ENDPOINT')
|
|
13
|
+
ak = os.getenv('TOS_ACCESS_KEY')
|
|
14
|
+
sk = os.getenv('TOS_SECRET_KEY')
|
|
15
|
+
bucket = 'tos-pytorch-connector'
|
|
16
|
+
tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
|
|
17
|
+
file_name_prefix='app.log')
|
|
18
|
+
pickled_tos_client = pickle.dumps(tos_client)
|
|
19
|
+
assert isinstance(pickled_tos_client, bytes)
|
|
20
|
+
unpickled_tos_client = pickle.loads(pickled_tos_client)
|
|
21
|
+
|
|
22
|
+
self._test_list_objects(bucket, unpickled_tos_client)
|
|
23
|
+
self._test_write_read_object(bucket, unpickled_tos_client)
|
|
24
|
+
|
|
9
25
|
def test_list_objects(self):
|
|
10
26
|
region = os.getenv('TOS_REGION')
|
|
11
27
|
endpoint = os.getenv('TOS_ENDPOINT')
|
|
@@ -14,7 +30,20 @@ class TestTosClient(unittest.TestCase):
|
|
|
14
30
|
bucket = 'tos-pytorch-connector'
|
|
15
31
|
tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
|
|
16
32
|
file_name_prefix='app.log')
|
|
33
|
+
self._test_list_objects(bucket, tos_client)
|
|
17
34
|
|
|
35
|
+
def test_write_read_object(self):
|
|
36
|
+
region = os.getenv('TOS_REGION')
|
|
37
|
+
endpoint = os.getenv('TOS_ENDPOINT')
|
|
38
|
+
ak = os.getenv('TOS_ACCESS_KEY')
|
|
39
|
+
sk = os.getenv('TOS_SECRET_KEY')
|
|
40
|
+
bucket = 'tos-pytorch-connector'
|
|
41
|
+
tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
|
|
42
|
+
file_name_prefix='app.log')
|
|
43
|
+
|
|
44
|
+
self._test_write_read_object(bucket, tos_client)
|
|
45
|
+
|
|
46
|
+
def _test_list_objects(self, bucket, tos_client):
|
|
18
47
|
list_stream = tos_client.list_objects(bucket, '', max_keys=1000)
|
|
19
48
|
count = 0
|
|
20
49
|
try:
|
|
@@ -22,23 +51,15 @@ class TestTosClient(unittest.TestCase):
|
|
|
22
51
|
for content in objects.contents:
|
|
23
52
|
count += 1
|
|
24
53
|
print(content.key, content.size)
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
54
|
+
output = tos_client.head_object(bucket, content.key)
|
|
55
|
+
assert output.etag == content.etag
|
|
56
|
+
assert output.size == content.size
|
|
28
57
|
|
|
29
58
|
print(count)
|
|
30
59
|
except TosException as e:
|
|
31
60
|
print(e.args[0].message)
|
|
32
61
|
|
|
33
|
-
def
|
|
34
|
-
region = os.getenv('TOS_REGION')
|
|
35
|
-
endpoint = os.getenv('TOS_ENDPOINT')
|
|
36
|
-
ak = os.getenv('TOS_ACCESS_KEY')
|
|
37
|
-
sk = os.getenv('TOS_SECRET_KEY')
|
|
38
|
-
bucket = 'tos-pytorch-connector'
|
|
39
|
-
tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
|
|
40
|
-
file_name_prefix='app.log')
|
|
41
|
-
|
|
62
|
+
def _test_write_read_object(self, bucket, tos_client):
|
|
42
63
|
key = str(uuid.uuid4())
|
|
43
64
|
read_stream = tos_client.get_object(bucket, key, '', 1)
|
|
44
65
|
|
|
@@ -59,6 +80,8 @@ class TestTosClient(unittest.TestCase):
|
|
|
59
80
|
write_stream.close()
|
|
60
81
|
|
|
61
82
|
output = tos_client.head_object(bucket, key)
|
|
83
|
+
print(output.etag, output.size)
|
|
84
|
+
|
|
62
85
|
read_stream = tos_client.get_object(bucket, key, output.etag, output.size)
|
|
63
86
|
try:
|
|
64
87
|
offset = 0
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import io
|
|
1
2
|
import os
|
|
2
3
|
import unittest
|
|
3
4
|
import uuid
|
|
@@ -47,6 +48,20 @@ class TestTosRawClient(unittest.TestCase):
|
|
|
47
48
|
rdata = goutput.read_all()
|
|
48
49
|
assert rdata == data
|
|
49
50
|
|
|
51
|
+
goutput: GetObjectOutput = tos_raw_client.get_object(GetObjectInput(bucket, key))
|
|
52
|
+
assert goutput.status_code == 200
|
|
53
|
+
assert goutput.etag == poutput.etag
|
|
54
|
+
|
|
55
|
+
rdata = io.BytesIO()
|
|
56
|
+
while 1:
|
|
57
|
+
chunk = goutput.read()
|
|
58
|
+
if not chunk:
|
|
59
|
+
break
|
|
60
|
+
rdata.write(chunk)
|
|
61
|
+
|
|
62
|
+
rdata.seek(0)
|
|
63
|
+
assert rdata.read() == data
|
|
64
|
+
|
|
50
65
|
doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
|
|
51
66
|
assert doutput.status_code == 204
|
|
52
67
|
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import enum
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
+
import gc
|
|
5
|
+
|
|
4
6
|
from functools import partial
|
|
5
|
-
from typing import Optional, List, Tuple
|
|
7
|
+
from typing import Optional, List, Tuple, Any
|
|
6
8
|
|
|
7
9
|
import tos
|
|
8
10
|
|
|
@@ -15,6 +17,44 @@ from .tos_object_writer import PutObjectStream
|
|
|
15
17
|
|
|
16
18
|
log = logging.getLogger(__name__)
|
|
17
19
|
|
|
20
|
+
import threading
|
|
21
|
+
import weakref
|
|
22
|
+
import traceback
|
|
23
|
+
|
|
24
|
+
_client_lock = threading.Lock()
|
|
25
|
+
_client_map = weakref.WeakSet()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _before_fork():
|
|
29
|
+
with _client_lock:
|
|
30
|
+
clients = list(_client_map)
|
|
31
|
+
|
|
32
|
+
if not clients or len(clients) == 0:
|
|
33
|
+
return
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
for client in clients:
|
|
37
|
+
client._inner_client = None
|
|
38
|
+
|
|
39
|
+
_reset_client_map()
|
|
40
|
+
gc.collect()
|
|
41
|
+
except Exception as e:
|
|
42
|
+
log.warning(f'failed to clean up native clients, {str(e)}')
|
|
43
|
+
traceback.print_exc()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _after_fork_in_child():
|
|
47
|
+
_reset_client_map()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _reset_client_map():
|
|
51
|
+
global _client_map
|
|
52
|
+
with _client_lock:
|
|
53
|
+
_client_map = weakref.WeakSet()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
os.register_at_fork(before=_before_fork, after_in_child=_after_fork_in_child)
|
|
57
|
+
|
|
18
58
|
|
|
19
59
|
class ReaderType(enum.Enum):
|
|
20
60
|
SEQUENTIAL = 'Sequential'
|
|
@@ -79,40 +119,59 @@ class TosClient(object):
|
|
|
79
119
|
def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
|
|
80
120
|
client_conf: Optional[TosClientConfig] = None, log_conf: Optional[TosLogConfig] = None,
|
|
81
121
|
use_native_client: bool = True):
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
self.
|
|
122
|
+
self._region = region
|
|
123
|
+
self._endpoint = endpoint
|
|
124
|
+
self._cred = CredentialProvider('', '') if cred is None else cred
|
|
125
|
+
self._client_conf = TosClientConfig() if client_conf is None else client_conf
|
|
126
|
+
self._log_conf = TosLogConfig() if log_conf is None else log_conf
|
|
127
|
+
self._part_size = self._client_conf.part_size
|
|
86
128
|
self._use_native_client = use_native_client
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
129
|
+
self._inner_client = None
|
|
130
|
+
self._client_pid = None
|
|
131
|
+
|
|
132
|
+
@property
|
|
133
|
+
def _client(self) -> Any:
|
|
134
|
+
if self._client_pid is None or self._client_pid != os.getpid() or self._inner_client is None:
|
|
135
|
+
with _client_lock:
|
|
136
|
+
if self._use_native_client:
|
|
137
|
+
directives = ''
|
|
138
|
+
directory = ''
|
|
139
|
+
file_name_prefix = ''
|
|
140
|
+
if self._log_conf.log_dir and self._log_conf.log_file_name:
|
|
141
|
+
if self._log_conf.log_level:
|
|
142
|
+
if self._log_conf.log_level == logging.DEBUG:
|
|
143
|
+
directives = 'debug'
|
|
144
|
+
elif self._log_conf.log_level == logging.INFO:
|
|
145
|
+
directives = 'info'
|
|
146
|
+
elif self._log_conf.log_level == logging.WARN:
|
|
147
|
+
directives = 'warn'
|
|
148
|
+
elif self._log_conf.log_level == logging.ERROR:
|
|
149
|
+
directives = 'error'
|
|
150
|
+
else:
|
|
151
|
+
directives = 'info'
|
|
152
|
+
directory = self._log_conf.log_dir
|
|
153
|
+
file_name_prefix = self._log_conf.log_file_name
|
|
154
|
+
self._inner_client = tosnativeclient.TosClient(self._region, self._endpoint, self._cred.ak,
|
|
155
|
+
self._cred.sk,
|
|
156
|
+
self._client_conf.part_size,
|
|
157
|
+
self._client_conf.max_retry_count,
|
|
158
|
+
directives=directives,
|
|
159
|
+
directory=directory,
|
|
160
|
+
file_name_prefix=file_name_prefix)
|
|
161
|
+
else:
|
|
162
|
+
self._inner_client = tos.TosClientV2(self._cred.ak, self._cred.sk, endpoint=self._endpoint,
|
|
163
|
+
region=self._region,
|
|
164
|
+
max_retry_count=self._client_conf.max_retry_count)
|
|
165
|
+
if self._log_conf.log_dir and self._log_conf.log_file_name:
|
|
166
|
+
file_path = os.path.join(self._log_conf.log_dir, self._log_conf.log_file_name)
|
|
167
|
+
log_level = self._log_conf.log_level if self._log_conf.log_level else logging.INFO
|
|
168
|
+
tos.set_logger(file_path=file_path, level=log_level)
|
|
169
|
+
|
|
170
|
+
self._client_pid = os.getpid()
|
|
171
|
+
_client_map.add(self)
|
|
172
|
+
|
|
173
|
+
assert self._inner_client is not None
|
|
174
|
+
return self._inner_client
|
|
116
175
|
|
|
117
176
|
@property
|
|
118
177
|
def use_native_client(self) -> bool:
|
|
@@ -155,12 +214,15 @@ class TosClient(object):
|
|
|
155
214
|
return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
|
|
156
215
|
|
|
157
216
|
def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
|
|
158
|
-
delimiter: Optional[str] = None
|
|
217
|
+
delimiter: Optional[str] = None,
|
|
218
|
+
continuation_token: Optional[str] = None) -> tosnativeclient.ListStream:
|
|
159
219
|
log.debug(f'gen_list_stream tos://{bucket}/{prefix}')
|
|
160
220
|
|
|
161
221
|
if isinstance(self._client, tosnativeclient.TosClient):
|
|
162
222
|
delimiter = delimiter if delimiter is not None else ''
|
|
163
|
-
|
|
223
|
+
continuation_token = continuation_token if continuation_token is not None else ''
|
|
224
|
+
return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter,
|
|
225
|
+
continuation_token=continuation_token)
|
|
164
226
|
raise NotImplementedError()
|
|
165
227
|
|
|
166
228
|
def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
|
|
@@ -8,19 +8,28 @@ from .tos_object_meta import TosObjectMeta
|
|
|
8
8
|
log = logging.getLogger(__name__)
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
class TosObjectIterable(object):
|
|
12
|
+
def __init__(self, bucket: str, prefix: str, client: TosClient):
|
|
13
|
+
self._bucket = bucket
|
|
14
|
+
self._prefix = prefix
|
|
15
|
+
self._client = client
|
|
16
|
+
|
|
17
|
+
def __iter__(self) -> Iterator[TosObjectMeta]:
|
|
18
|
+
return iter(TosObjectIterator(self._bucket, self._prefix, self._client))
|
|
19
|
+
|
|
20
|
+
|
|
11
21
|
class TosObjectIterator(object):
|
|
12
22
|
def __init__(self, bucket: str, prefix: str, client: TosClient):
|
|
13
23
|
self._bucket = bucket
|
|
14
24
|
self._prefix = prefix
|
|
15
25
|
self._client = client
|
|
16
26
|
self._delimiter: Optional[str] = None
|
|
17
|
-
self.
|
|
27
|
+
self._continuation_token: Optional[str] = None
|
|
18
28
|
|
|
29
|
+
self._list_stream = None
|
|
19
30
|
self._object_metas = None
|
|
20
31
|
self._index = 0
|
|
21
|
-
|
|
22
32
|
self._is_truncated = True
|
|
23
|
-
self._continuation_token = None
|
|
24
33
|
|
|
25
34
|
def close(self) -> None:
|
|
26
35
|
if self._list_stream is not None:
|
|
@@ -33,13 +42,15 @@ class TosObjectIterator(object):
|
|
|
33
42
|
if self._client.use_native_client:
|
|
34
43
|
if self._list_stream is None:
|
|
35
44
|
self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
|
|
36
|
-
delimiter=self._delimiter
|
|
45
|
+
delimiter=self._delimiter,
|
|
46
|
+
continuation_token=self._continuation_token)
|
|
37
47
|
|
|
38
48
|
if self._object_metas is None or self._index >= len(self._object_metas):
|
|
39
49
|
self._object_metas = []
|
|
40
50
|
self._index = 0
|
|
41
51
|
while 1:
|
|
42
52
|
objects = next(self._list_stream)
|
|
53
|
+
self._continuation_token = self._list_stream.current_continuation_token()
|
|
43
54
|
for content in objects.contents:
|
|
44
55
|
self._object_metas.append(
|
|
45
56
|
TosObjectMeta(content.bucket, content.key, content.size, content.etag))
|
|
@@ -101,4 +112,4 @@ def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iter
|
|
|
101
112
|
|
|
102
113
|
def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
|
|
103
114
|
bucket, prefix = parse_tos_url(prefix)
|
|
104
|
-
return iter(
|
|
115
|
+
return iter(TosObjectIterable(bucket, prefix, client))
|
{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_iterable_dataset.py
RENAMED
|
@@ -16,7 +16,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
16
16
|
def __init__(self, region: str,
|
|
17
17
|
gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
|
|
18
18
|
endpoint: Optional[str] = None,
|
|
19
|
-
|
|
19
|
+
transform: Callable[[TosObjectReader], Any] = default_trans,
|
|
20
20
|
cred: Optional[CredentialProvider] = None,
|
|
21
21
|
client_conf: Optional[TosClientConfig] = None,
|
|
22
22
|
log_conf: Optional[TosLogConfig] = None,
|
|
@@ -27,7 +27,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
27
27
|
self._gen_dataset = gen_dataset
|
|
28
28
|
self._region = region
|
|
29
29
|
self._endpoint = endpoint
|
|
30
|
-
self._trans =
|
|
30
|
+
self._trans = transform
|
|
31
31
|
self._cred = cred
|
|
32
32
|
self._client_conf = client_conf
|
|
33
33
|
self._log_conf = log_conf
|
|
@@ -17,7 +17,7 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
17
17
|
def __init__(self, region: str,
|
|
18
18
|
gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
|
|
19
19
|
endpoint: Optional[str] = None,
|
|
20
|
-
|
|
20
|
+
transform: Callable[[TosObjectReader], Any] = default_trans,
|
|
21
21
|
cred: Optional[CredentialProvider] = None,
|
|
22
22
|
client_conf: Optional[TosClientConfig] = None,
|
|
23
23
|
log_conf: Optional[TosLogConfig] = None,
|
|
@@ -27,7 +27,7 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
27
27
|
self._gen_dataset = gen_dataset
|
|
28
28
|
self._region = region
|
|
29
29
|
self._endpoint = endpoint
|
|
30
|
-
self._trans =
|
|
30
|
+
self._trans = transform
|
|
31
31
|
self._cred = cred
|
|
32
32
|
self._client_conf = client_conf
|
|
33
33
|
self._log_conf = log_conf
|
|
@@ -40,7 +40,7 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
40
40
|
|
|
41
41
|
@classmethod
|
|
42
42
|
def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
|
|
43
|
-
|
|
43
|
+
transform: Callable[[TosObjectReader], Any] = default_trans,
|
|
44
44
|
cred: Optional[CredentialProvider] = None,
|
|
45
45
|
client_conf: Optional[TosClientConfig] = None,
|
|
46
46
|
log_conf: Optional[TosLogConfig] = None,
|
|
@@ -48,12 +48,12 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
48
48
|
reader_type: Optional[ReaderType] = None,
|
|
49
49
|
buffer_size: Optional[int] = None):
|
|
50
50
|
log.info(f'building {cls.__name__} from_urls')
|
|
51
|
-
return cls(region, partial(gen_dataset_from_urls, urls), endpoint,
|
|
51
|
+
return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
|
|
52
52
|
use_native_client, reader_type, buffer_size)
|
|
53
53
|
|
|
54
54
|
@classmethod
|
|
55
55
|
def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
|
|
56
|
-
|
|
56
|
+
transform: Callable[[TosObjectReader], Any] = default_trans,
|
|
57
57
|
cred: Optional[CredentialProvider] = None,
|
|
58
58
|
client_conf: Optional[TosClientConfig] = None,
|
|
59
59
|
log_conf: Optional[TosLogConfig] = None,
|
|
@@ -61,7 +61,7 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
61
61
|
reader_type: Optional[ReaderType] = None,
|
|
62
62
|
buffer_size: Optional[int] = None):
|
|
63
63
|
log.info(f'building {cls.__name__} from_prefix')
|
|
64
|
-
return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint,
|
|
64
|
+
return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
|
|
65
65
|
use_native_client, reader_type, buffer_size)
|
|
66
66
|
|
|
67
67
|
def __getitem__(self, i: int) -> Any:
|
|
@@ -85,7 +85,7 @@ class TosObjectStream(object):
|
|
|
85
85
|
def close(self) -> None:
|
|
86
86
|
if self._sequential_object_stream and isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
|
|
87
87
|
self._sequential_object_stream.close()
|
|
88
|
-
if self._random_object_stream:
|
|
88
|
+
if self._random_object_stream and isinstance(self._random_object_stream, tosnativeclient.ReadStream):
|
|
89
89
|
self._random_object_stream.close()
|
|
90
90
|
|
|
91
91
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tostorchconnector
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.8
|
|
4
4
|
Summary: TOS connector integration for PyTorch
|
|
5
5
|
Author-email: xiangshijian <xiangshijian@bytedance.com>
|
|
6
6
|
Classifier: Development Status :: 4 - Beta
|
|
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
|
|
|
19
19
|
License-File: LICENSE
|
|
20
20
|
Requires-Dist: torch>=2.0
|
|
21
21
|
Requires-Dist: tos>=2.8.0
|
|
22
|
-
Requires-Dist: tosnativeclient>=1.0.
|
|
22
|
+
Requires-Dist: tosnativeclient>=1.0.6
|
|
23
23
|
Dynamic: license-file
|
|
24
24
|
|
|
25
25
|
# TOS Connector for pytorch
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/top_level.txt
RENAMED
|
File without changes
|