tostorchconnector 1.0.3__tar.gz → 1.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tostorchconnector might be problematic. Click here for more details.
- {tostorchconnector-1.0.3/tostorchconnector.egg-info → tostorchconnector-1.0.5}/PKG-INFO +2 -1
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/pyproject.toml +2 -1
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tests/test_tos_dataset.py +27 -10
- tostorchconnector-1.0.3/tests/test_tosnativeclient.py → tostorchconnector-1.0.5/tests/test_tosclient.py +6 -10
- tostorchconnector-1.0.5/tests/test_tosrawclient.py +102 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/__init__.py +2 -1
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_checkpoint.py +8 -14
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_client.py +27 -17
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_common.py +10 -5
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_iterable_dataset.py +24 -17
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_map_dataset.py +24 -18
- tostorchconnector-1.0.5/tostorchconnector/tos_object_reader.py +386 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_object_writer.py +29 -15
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5/tostorchconnector.egg-info}/PKG-INFO +2 -1
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/SOURCES.txt +2 -1
- tostorchconnector-1.0.3/tostorchconnector/tos_object_reader.py +0 -188
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/LICENSE +0 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/README.md +0 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/setup.cfg +0 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_object_meta.py +0 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/dependency_links.txt +0 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/requires.txt +0 -0
- {tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tostorchconnector
|
|
3
|
-
Version: 1.0.3
|
|
3
|
+
Version: 1.0.5
|
|
4
4
|
Summary: TOS connector integration for PyTorch
|
|
5
5
|
Author-email: xiangshijian <xiangshijian@bytedance.com>
|
|
6
6
|
Classifier: Development Status :: 4 - Beta
|
|
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.10
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
14
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
14
15
|
Classifier: Operating System :: OS Independent
|
|
15
16
|
Classifier: Topic :: Utilities
|
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
|
|
6
6
|
[project]
|
|
7
7
|
name = "tostorchconnector"
|
|
8
|
-
version = "1.0.3"
|
|
8
|
+
version = "1.0.5"
|
|
9
9
|
description = "TOS connector integration for PyTorch"
|
|
10
10
|
authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
|
|
11
11
|
requires-python = ">=3.8,<3.13"
|
|
@@ -18,6 +18,7 @@ classifiers = [
|
|
|
18
18
|
"Programming Language :: Python :: 3.10",
|
|
19
19
|
"Programming Language :: Python :: 3.11",
|
|
20
20
|
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Programming Language :: Python :: 3.13",
|
|
21
22
|
"License :: OSI Approved :: Apache Software License",
|
|
22
23
|
"Operating System :: OS Independent",
|
|
23
24
|
"Topic :: Utilities"
|
|
@@ -2,7 +2,9 @@ import os
|
|
|
2
2
|
import unittest
|
|
3
3
|
|
|
4
4
|
from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
|
|
5
|
-
from tostorchconnector.tos_client import CredentialProvider,
|
|
5
|
+
from tostorchconnector.tos_client import CredentialProvider, ReaderType
|
|
6
|
+
|
|
7
|
+
USE_NATIVE_CLIENT = True
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class TestTosDataSet(unittest.TestCase):
|
|
@@ -14,7 +16,8 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
14
16
|
sk = os.getenv('TOS_SECRET_KEY')
|
|
15
17
|
bucket = 'tos-pytorch-connector'
|
|
16
18
|
datasets = TosMapDataset.from_urls(iter([f'tos://{bucket}/key1', f'tos://{bucket}/key2', f'{bucket}/key3']),
|
|
17
|
-
region=region, endpoint=endpoint, cred=CredentialProvider(ak, sk)
|
|
19
|
+
region=region, endpoint=endpoint, cred=CredentialProvider(ak, sk),
|
|
20
|
+
use_native_client=USE_NATIVE_CLIENT)
|
|
18
21
|
|
|
19
22
|
for i in range(len(datasets)):
|
|
20
23
|
print(datasets[i].bucket, datasets[i].key)
|
|
@@ -26,7 +29,8 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
26
29
|
sk = os.getenv('TOS_SECRET_KEY')
|
|
27
30
|
bucket = 'tos-pytorch-connector'
|
|
28
31
|
datasets = TosMapDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
|
|
29
|
-
endpoint=endpoint, cred=CredentialProvider(ak, sk)
|
|
32
|
+
endpoint=endpoint, cred=CredentialProvider(ak, sk),
|
|
33
|
+
use_native_client=USE_NATIVE_CLIENT)
|
|
30
34
|
for i in range(len(datasets)):
|
|
31
35
|
item = datasets[i]
|
|
32
36
|
print(item.bucket, item.key)
|
|
@@ -43,7 +47,8 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
43
47
|
sk = os.getenv('TOS_SECRET_KEY')
|
|
44
48
|
bucket = 'tos-pytorch-connector'
|
|
45
49
|
datasets = TosIterableDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
|
|
46
|
-
endpoint=endpoint, cred=CredentialProvider(ak, sk)
|
|
50
|
+
endpoint=endpoint, cred=CredentialProvider(ak, sk),
|
|
51
|
+
use_native_client=USE_NATIVE_CLIENT)
|
|
47
52
|
i = 0
|
|
48
53
|
for dataset in datasets:
|
|
49
54
|
print(dataset.bucket, dataset.key)
|
|
@@ -59,15 +64,27 @@ class TestTosDataSet(unittest.TestCase):
|
|
|
59
64
|
ak = os.getenv('TOS_ACCESS_KEY')
|
|
60
65
|
sk = os.getenv('TOS_SECRET_KEY')
|
|
61
66
|
bucket = 'tos-pytorch-connector'
|
|
62
|
-
checkpoint = TosCheckpoint(region, endpoint, cred=CredentialProvider(ak, sk),
|
|
67
|
+
checkpoint = TosCheckpoint(region, endpoint, cred=CredentialProvider(ak, sk),
|
|
68
|
+
use_native_client=USE_NATIVE_CLIENT)
|
|
63
69
|
url = f'tos://{bucket}/key1'
|
|
70
|
+
print('test sequential')
|
|
64
71
|
with checkpoint.writer(url) as writer:
|
|
65
72
|
writer.write(b'hello world')
|
|
66
73
|
writer.write(b'hi world')
|
|
67
74
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
75
|
+
with checkpoint.reader(url) as reader:
|
|
76
|
+
data = reader.read(5)
|
|
77
|
+
print(data)
|
|
78
|
+
print(reader.read())
|
|
79
|
+
reader.seek(0)
|
|
80
|
+
data = reader.read(5)
|
|
81
|
+
print(data)
|
|
71
82
|
|
|
72
|
-
|
|
73
|
-
|
|
83
|
+
print('test ranged')
|
|
84
|
+
with checkpoint.reader(url, reader_type=ReaderType.RANGED) as reader:
|
|
85
|
+
data = reader.read(5)
|
|
86
|
+
print(data)
|
|
87
|
+
print(reader.read())
|
|
88
|
+
reader.seek(0)
|
|
89
|
+
data = reader.read(5)
|
|
90
|
+
print(data)
|
|
@@ -5,7 +5,7 @@ import uuid
|
|
|
5
5
|
from tosnativeclient import TosClient, TosException
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
class TestTosNativeClient(unittest.TestCase):
|
|
8
|
+
class TestTosClient(unittest.TestCase):
|
|
9
9
|
def test_list_objects(self):
|
|
10
10
|
region = os.getenv('TOS_REGION')
|
|
11
11
|
endpoint = os.getenv('TOS_ENDPOINT')
|
|
@@ -15,16 +15,16 @@ class TestTosNativeClient(unittest.TestCase):
|
|
|
15
15
|
tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
|
|
16
16
|
file_name_prefix='app.log')
|
|
17
17
|
|
|
18
|
-
list_stream = tos_client.list_objects(bucket, '
|
|
18
|
+
list_stream = tos_client.list_objects(bucket, '', max_keys=1000)
|
|
19
19
|
count = 0
|
|
20
20
|
try:
|
|
21
21
|
for objects in list_stream:
|
|
22
|
-
count += 1
|
|
23
22
|
for content in objects.contents:
|
|
23
|
+
count += 1
|
|
24
24
|
print(content.key, content.size)
|
|
25
|
-
output = tos_client.head_object(bucket, content.key)
|
|
26
|
-
assert output.etag == content.etag
|
|
27
|
-
assert output.size == content.size
|
|
25
|
+
# output = tos_client.head_object(bucket, content.key)
|
|
26
|
+
# assert output.etag == content.etag
|
|
27
|
+
# assert output.size == content.size
|
|
28
28
|
|
|
29
29
|
print(count)
|
|
30
30
|
except TosException as e:
|
|
@@ -70,7 +70,3 @@ class TestTosNativeClient(unittest.TestCase):
|
|
|
70
70
|
print(chunk)
|
|
71
71
|
except TosException as e:
|
|
72
72
|
print(e.args[0].status_code)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
if __name__ == '__main__':
|
|
76
|
-
unittest.main()
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import unittest
|
|
3
|
+
import uuid
|
|
4
|
+
|
|
5
|
+
import tos
|
|
6
|
+
from tos.exceptions import TosServerError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestTosRawClient(unittest.TestCase):
|
|
10
|
+
def test_put_get(self):
|
|
11
|
+
from tosnativeclient import TosRawClient, HeadObjectInput, TosException, DeleteObjectInput, \
|
|
12
|
+
PutObjectFromBufferInput, \
|
|
13
|
+
GetObjectInput, GetObjectOutput
|
|
14
|
+
|
|
15
|
+
region = os.getenv('TOS_REGION')
|
|
16
|
+
endpoint = os.getenv('TOS_ENDPOINT')
|
|
17
|
+
ak = os.getenv('TOS_ACCESS_KEY')
|
|
18
|
+
sk = os.getenv('TOS_SECRET_KEY')
|
|
19
|
+
bucket = 'tos-pytorch-connector'
|
|
20
|
+
key = str(uuid.uuid4())
|
|
21
|
+
tos_raw_client = TosRawClient(region, endpoint, ak, sk)
|
|
22
|
+
|
|
23
|
+
doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
|
|
24
|
+
assert doutput.status_code == 204
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
tos_raw_client.head_object(HeadObjectInput(bucket, key))
|
|
28
|
+
assert False
|
|
29
|
+
except TosException as e:
|
|
30
|
+
assert e.args[0].status_code == 404
|
|
31
|
+
assert len(e.args[0].request_id) > 0
|
|
32
|
+
|
|
33
|
+
data = str(uuid.uuid4()).encode('utf-8')
|
|
34
|
+
input = PutObjectFromBufferInput(bucket, key, data)
|
|
35
|
+
poutput = tos_raw_client.put_object_from_buffer(input)
|
|
36
|
+
assert poutput.status_code == 200
|
|
37
|
+
assert len(poutput.etag) > 0
|
|
38
|
+
|
|
39
|
+
houtput = tos_raw_client.head_object(HeadObjectInput(bucket, key))
|
|
40
|
+
assert houtput.status_code == 200
|
|
41
|
+
assert houtput.etag == poutput.etag
|
|
42
|
+
assert houtput.content_length == len(data)
|
|
43
|
+
|
|
44
|
+
goutput: GetObjectOutput = tos_raw_client.get_object(GetObjectInput(bucket, key))
|
|
45
|
+
assert goutput.status_code == 200
|
|
46
|
+
assert goutput.etag == poutput.etag
|
|
47
|
+
rdata = goutput.read_all()
|
|
48
|
+
assert rdata == data
|
|
49
|
+
|
|
50
|
+
doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
|
|
51
|
+
assert doutput.status_code == 204
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
tos_raw_client.get_object(GetObjectInput(bucket, key))
|
|
55
|
+
assert False
|
|
56
|
+
except TosException as e:
|
|
57
|
+
assert e.args[0].status_code == 404
|
|
58
|
+
assert len(e.args[0].request_id) > 0
|
|
59
|
+
|
|
60
|
+
def test_put_get_old(self):
|
|
61
|
+
region = os.getenv('TOS_REGION')
|
|
62
|
+
endpoint = os.getenv('TOS_ENDPOINT')
|
|
63
|
+
ak = os.getenv('TOS_ACCESS_KEY')
|
|
64
|
+
sk = os.getenv('TOS_SECRET_KEY')
|
|
65
|
+
bucket = 'tos-pytorch-connector'
|
|
66
|
+
key = str(uuid.uuid4())
|
|
67
|
+
tos_client = tos.TosClientV2(ak, sk, endpoint=endpoint, region=region)
|
|
68
|
+
doutput = tos_client.delete_object(bucket, key)
|
|
69
|
+
assert doutput.status_code == 204
|
|
70
|
+
|
|
71
|
+
try:
|
|
72
|
+
tos_client.head_object(bucket, key)
|
|
73
|
+
assert False
|
|
74
|
+
except TosServerError as e:
|
|
75
|
+
assert e.status_code == 404
|
|
76
|
+
assert len(e.request_id) > 0
|
|
77
|
+
|
|
78
|
+
data = str(uuid.uuid4()).encode('utf-8')
|
|
79
|
+
poutput = tos_client.put_object(bucket, key, content=data)
|
|
80
|
+
assert poutput.status_code == 200
|
|
81
|
+
assert len(poutput.etag) > 0
|
|
82
|
+
|
|
83
|
+
houtput = tos_client.head_object(bucket, key)
|
|
84
|
+
assert houtput.status_code == 200
|
|
85
|
+
assert houtput.etag == poutput.etag
|
|
86
|
+
assert houtput.content_length == len(data)
|
|
87
|
+
|
|
88
|
+
goutput = tos_client.get_object(bucket, key)
|
|
89
|
+
assert goutput.status_code == 200
|
|
90
|
+
assert goutput.etag == poutput.etag
|
|
91
|
+
rdata = goutput.read()
|
|
92
|
+
assert rdata == data
|
|
93
|
+
|
|
94
|
+
doutput = tos_client.delete_object(bucket, key)
|
|
95
|
+
assert doutput.status_code == 204
|
|
96
|
+
|
|
97
|
+
try:
|
|
98
|
+
tos_client.get_object(bucket, key)
|
|
99
|
+
assert False
|
|
100
|
+
except TosServerError as e:
|
|
101
|
+
assert e.status_code == 404
|
|
102
|
+
assert len(e.request_id) > 0
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from .tos_object_reader import TosObjectReader
|
|
1
|
+
from .tos_object_reader import TosObjectReader, SequentialTosObjectReader
|
|
2
2
|
from .tos_object_writer import TosObjectWriter
|
|
3
3
|
from .tos_iterable_dataset import TosIterableDataset
|
|
4
4
|
from .tos_map_dataset import TosMapDataset
|
|
@@ -7,6 +7,7 @@ from tosnativeclient import TosException, TosError
|
|
|
7
7
|
|
|
8
8
|
__all__ = [
|
|
9
9
|
'TosObjectReader',
|
|
10
|
+
'SequentialTosObjectReader',
|
|
10
11
|
'TosObjectWriter',
|
|
11
12
|
'TosIterableDataset',
|
|
12
13
|
'TosMapDataset',
|
|
@@ -2,7 +2,7 @@ import logging
|
|
|
2
2
|
from typing import Optional
|
|
3
3
|
|
|
4
4
|
from . import TosObjectReader, TosObjectWriter
|
|
5
|
-
from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
|
|
5
|
+
from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
|
|
6
6
|
from .tos_common import parse_tos_url
|
|
7
7
|
|
|
8
8
|
log = logging.getLogger(__name__)
|
|
@@ -14,26 +14,20 @@ class TosCheckpoint(object):
|
|
|
14
14
|
cred: Optional[CredentialProvider] = None,
|
|
15
15
|
client_conf: Optional[TosClientConfig] = None,
|
|
16
16
|
log_conf: Optional[TosLogConfig] = None, use_native_client=True):
|
|
17
|
-
self._client = None
|
|
18
|
-
self._native_client = None
|
|
19
17
|
self._region = region
|
|
20
18
|
self._endpoint = endpoint
|
|
21
19
|
self._cred = cred
|
|
22
20
|
self._client_conf = client_conf
|
|
23
21
|
self._log_conf = log_conf
|
|
24
|
-
self.
|
|
22
|
+
self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
|
|
23
|
+
use_native_client)
|
|
24
|
+
log.info('TosCheckpoint init tos client succeed')
|
|
25
25
|
|
|
26
|
-
def reader(self, url: str
|
|
26
|
+
def reader(self, url: str, reader_type: Optional[ReaderType] = None,
|
|
27
|
+
buffer_size: Optional[int] = None) -> TosObjectReader:
|
|
27
28
|
bucket, key = parse_tos_url(url)
|
|
28
|
-
return self.
|
|
29
|
+
return self._client.get_object(bucket, key, reader_type=reader_type, buffer_size=buffer_size)
|
|
29
30
|
|
|
30
31
|
def writer(self, url: str) -> TosObjectWriter:
|
|
31
32
|
bucket, key = parse_tos_url(url)
|
|
32
|
-
return self.
|
|
33
|
-
|
|
34
|
-
def _get_tos_client(self):
|
|
35
|
-
if self._client is None:
|
|
36
|
-
self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
|
|
37
|
-
self._use_native_client)
|
|
38
|
-
log.info('TosIterableDataset init tos client succeed')
|
|
39
|
-
return self._client
|
|
33
|
+
return self._client.put_object(bucket, key)
|
|
@@ -1,21 +1,26 @@
|
|
|
1
|
-
import
|
|
1
|
+
import enum
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
from functools import partial
|
|
5
5
|
from typing import Optional, List, Tuple
|
|
6
6
|
|
|
7
7
|
import tos
|
|
8
|
-
from tos.models2 import GetObjectOutput, PutObjectOutput
|
|
9
8
|
|
|
10
9
|
import tosnativeclient
|
|
11
10
|
|
|
12
|
-
from . import
|
|
11
|
+
from . import SequentialTosObjectReader, TosObjectWriter, TosObjectReader
|
|
13
12
|
from .tos_object_meta import TosObjectMeta
|
|
13
|
+
from .tos_object_reader import TosObjectStream, RangedTosObjectReader
|
|
14
14
|
from .tos_object_writer import PutObjectStream
|
|
15
15
|
|
|
16
16
|
log = logging.getLogger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
class ReaderType(enum.Enum):
|
|
20
|
+
SEQUENTIAL = 'Sequential'
|
|
21
|
+
RANGED = 'Ranged'
|
|
22
|
+
|
|
23
|
+
|
|
19
24
|
class CredentialProvider(object):
|
|
20
25
|
def __init__(self, ak: str, sk: str):
|
|
21
26
|
self._ak = ak
|
|
@@ -110,48 +115,49 @@ class TosClient(object):
|
|
|
110
115
|
tos.set_logger(file_path=file_path, level=log_level)
|
|
111
116
|
|
|
112
117
|
@property
|
|
113
|
-
def use_native_client(self):
|
|
118
|
+
def use_native_client(self) -> bool:
|
|
114
119
|
return self._use_native_client
|
|
115
120
|
|
|
116
121
|
def get_object(self, bucket: str, key: str, etag: Optional[str] = None,
|
|
117
|
-
size: Optional[int] = None
|
|
122
|
+
size: Optional[int] = None, reader_type: Optional[ReaderType] = None,
|
|
123
|
+
buffer_size: Optional[int] = None) -> TosObjectReader:
|
|
118
124
|
log.debug(f'get_object tos://{bucket}/{key}')
|
|
125
|
+
|
|
119
126
|
if size is None or etag is None:
|
|
120
127
|
get_object_meta = partial(self.head_object, bucket, key)
|
|
121
128
|
else:
|
|
122
129
|
get_object_meta = lambda: TosObjectMeta(bucket, key, size, etag)
|
|
123
|
-
if isinstance(self._client, tosnativeclient.TosClient):
|
|
124
|
-
get_object_stream = partial(self._client.get_object, bucket, key)
|
|
125
|
-
else:
|
|
126
|
-
def get_object_stream(et: str, _: int) -> GetObjectOutput:
|
|
127
|
-
return self._client.get_object(bucket, key, '', et)
|
|
128
130
|
|
|
129
|
-
|
|
131
|
+
object_stream = TosObjectStream(bucket, key, get_object_meta, self._client)
|
|
132
|
+
if reader_type is not None and reader_type == ReaderType.RANGED:
|
|
133
|
+
return RangedTosObjectReader(bucket, key, object_stream, buffer_size)
|
|
134
|
+
return SequentialTosObjectReader(bucket, key, object_stream)
|
|
130
135
|
|
|
131
136
|
def put_object(self, bucket: str, key: str, storage_class: Optional[str] = None) -> TosObjectWriter:
|
|
132
137
|
log.debug(f'put_object tos://{bucket}/{key}')
|
|
133
138
|
|
|
134
139
|
if isinstance(self._client, tosnativeclient.TosClient):
|
|
135
|
-
|
|
136
|
-
put_object_stream = self._client.put_object(bucket, key, storage_class)
|
|
140
|
+
put_object_stream = self._client.put_object(bucket, key, storage_class=storage_class)
|
|
137
141
|
else:
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
put_object_stream = PutObjectStream(put_object)
|
|
142
|
+
put_object_stream = PutObjectStream(
|
|
143
|
+
lambda content: self._client.put_object(bucket, key, storage_class=storage_class, content=content))
|
|
142
144
|
|
|
143
145
|
return TosObjectWriter(bucket, key, put_object_stream)
|
|
144
146
|
|
|
145
147
|
def head_object(self, bucket: str, key: str) -> TosObjectMeta:
|
|
146
148
|
log.debug(f'head_object tos://{bucket}/{key}')
|
|
149
|
+
|
|
147
150
|
if isinstance(self._client, tosnativeclient.TosClient):
|
|
148
151
|
resp = self._client.head_object(bucket, key)
|
|
149
152
|
return TosObjectMeta(resp.bucket, resp.key, resp.size, resp.etag)
|
|
153
|
+
|
|
150
154
|
resp = self._client.head_object(bucket, key)
|
|
151
155
|
return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
|
|
152
156
|
|
|
153
157
|
def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
|
|
154
158
|
delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
|
|
159
|
+
log.debug(f'gen_list_stream tos://{bucket}/{prefix}')
|
|
160
|
+
|
|
155
161
|
if isinstance(self._client, tosnativeclient.TosClient):
|
|
156
162
|
delimiter = delimiter if delimiter is not None else ''
|
|
157
163
|
return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
|
|
@@ -161,6 +167,10 @@ class TosClient(object):
|
|
|
161
167
|
continuation_token: Optional[str] = None, delimiter: Optional[str] = None) -> Tuple[
|
|
162
168
|
List[TosObjectMeta], bool, Optional[str]]:
|
|
163
169
|
log.debug(f'list_objects tos://{bucket}/{prefix}')
|
|
170
|
+
|
|
171
|
+
if isinstance(self._client, tosnativeclient.TosClient):
|
|
172
|
+
raise NotImplementedError()
|
|
173
|
+
|
|
164
174
|
resp = self._client.list_objects_type2(bucket, prefix, max_keys=max_keys, continuation_token=continuation_token,
|
|
165
175
|
delimiter=delimiter)
|
|
166
176
|
object_metas = []
|
|
@@ -22,14 +22,14 @@ class TosObjectIterator(object):
|
|
|
22
22
|
self._is_truncated = True
|
|
23
23
|
self._continuation_token = None
|
|
24
24
|
|
|
25
|
-
def close(self):
|
|
25
|
+
def close(self) -> None:
|
|
26
26
|
if self._list_stream is not None:
|
|
27
27
|
self._list_stream.close()
|
|
28
28
|
|
|
29
29
|
def __iter__(self) -> Iterator[TosObjectMeta]:
|
|
30
30
|
return self
|
|
31
31
|
|
|
32
|
-
def __next__(self):
|
|
32
|
+
def __next__(self) -> TosObjectMeta:
|
|
33
33
|
if self._client.use_native_client:
|
|
34
34
|
if self._list_stream is None:
|
|
35
35
|
self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
|
|
@@ -37,9 +37,14 @@ class TosObjectIterator(object):
|
|
|
37
37
|
|
|
38
38
|
if self._object_metas is None or self._index >= len(self._object_metas):
|
|
39
39
|
self._object_metas = []
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
40
|
+
self._index = 0
|
|
41
|
+
while 1:
|
|
42
|
+
objects = next(self._list_stream)
|
|
43
|
+
for content in objects.contents:
|
|
44
|
+
self._object_metas.append(
|
|
45
|
+
TosObjectMeta(content.bucket, content.key, content.size, content.etag))
|
|
46
|
+
if len(self._object_metas) > 0:
|
|
47
|
+
break
|
|
43
48
|
|
|
44
49
|
object_meta = self._object_metas[self._index]
|
|
45
50
|
self._index += 1
|
{tostorchconnector-1.0.3 → tostorchconnector-1.0.5}/tostorchconnector/tos_iterable_dataset.py
RENAMED
|
@@ -5,7 +5,7 @@ from typing import Iterator, Any, Optional, Callable, Union
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
7
|
from . import TosObjectReader
|
|
8
|
-
from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
|
|
8
|
+
from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
|
|
9
9
|
from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
|
|
10
10
|
from .tos_object_meta import TosObjectMeta
|
|
11
11
|
|
|
@@ -20,8 +20,10 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
20
20
|
cred: Optional[CredentialProvider] = None,
|
|
21
21
|
client_conf: Optional[TosClientConfig] = None,
|
|
22
22
|
log_conf: Optional[TosLogConfig] = None,
|
|
23
|
-
sharding: bool = False,
|
|
24
|
-
|
|
23
|
+
sharding: bool = False,
|
|
24
|
+
use_native_client: bool = True,
|
|
25
|
+
reader_type: Optional[ReaderType] = None,
|
|
26
|
+
buffer_size: Optional[int] = None):
|
|
25
27
|
self._gen_dataset = gen_dataset
|
|
26
28
|
self._region = region
|
|
27
29
|
self._endpoint = endpoint
|
|
@@ -30,13 +32,17 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
30
32
|
self._client_conf = client_conf
|
|
31
33
|
self._log_conf = log_conf
|
|
32
34
|
self._sharding = sharding
|
|
33
|
-
self._use_native_client = use_native_client
|
|
34
35
|
if torch.distributed.is_initialized():
|
|
35
36
|
self._rank = torch.distributed.get_rank()
|
|
36
37
|
self._world_size = torch.distributed.get_world_size()
|
|
37
38
|
else:
|
|
38
39
|
self._rank = 0
|
|
39
40
|
self._world_size = 1
|
|
41
|
+
self._reader_type = reader_type
|
|
42
|
+
self._buffer_size = buffer_size
|
|
43
|
+
self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
|
|
44
|
+
use_native_client)
|
|
45
|
+
log.info('TosIterableDataset init tos client succeed')
|
|
40
46
|
|
|
41
47
|
@classmethod
|
|
42
48
|
def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
|
|
@@ -44,10 +50,13 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
44
50
|
cred: Optional[CredentialProvider] = None,
|
|
45
51
|
client_conf: Optional[TosClientConfig] = None,
|
|
46
52
|
log_conf: Optional[TosLogConfig] = None,
|
|
47
|
-
sharding: bool = False,
|
|
53
|
+
sharding: bool = False,
|
|
54
|
+
use_native_client: bool = True,
|
|
55
|
+
reader_type: Optional[ReaderType] = None,
|
|
56
|
+
buffer_size: Optional[int] = None):
|
|
48
57
|
log.info(f'building {cls.__name__} from_urls')
|
|
49
58
|
return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
|
|
50
|
-
sharding, use_native_client)
|
|
59
|
+
sharding, use_native_client, reader_type, buffer_size)
|
|
51
60
|
|
|
52
61
|
@classmethod
|
|
53
62
|
def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
|
|
@@ -55,10 +64,13 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
55
64
|
cred: Optional[CredentialProvider] = None,
|
|
56
65
|
client_conf: Optional[TosClientConfig] = None,
|
|
57
66
|
log_conf: Optional[TosLogConfig] = None,
|
|
58
|
-
sharding: bool = False,
|
|
67
|
+
sharding: bool = False,
|
|
68
|
+
use_native_client: bool = True,
|
|
69
|
+
reader_type: Optional[ReaderType] = None,
|
|
70
|
+
buffer_size: Optional[int] = None):
|
|
59
71
|
log.info(f'building {cls.__name__} from_prefix')
|
|
60
72
|
return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
|
|
61
|
-
sharding, use_native_client)
|
|
73
|
+
sharding, use_native_client, reader_type, buffer_size)
|
|
62
74
|
|
|
63
75
|
def __iter__(self) -> Iterator[Any]:
|
|
64
76
|
worker_id = 0
|
|
@@ -72,12 +84,12 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
72
84
|
if not self._sharding or (self._world_size == 1 and num_workers == 1):
|
|
73
85
|
return map(
|
|
74
86
|
self._trans_tos_object,
|
|
75
|
-
self._gen_dataset(self.
|
|
87
|
+
self._gen_dataset(self._client),
|
|
76
88
|
)
|
|
77
89
|
|
|
78
90
|
part_dataset = (
|
|
79
91
|
obj
|
|
80
|
-
for idx, obj in enumerate(self._gen_dataset(self.
|
|
92
|
+
for idx, obj in enumerate(self._gen_dataset(self._client))
|
|
81
93
|
if idx % self._world_size == self._rank
|
|
82
94
|
)
|
|
83
95
|
|
|
@@ -88,12 +100,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
|
|
|
88
100
|
)
|
|
89
101
|
return map(self._trans_tos_object, part_dataset)
|
|
90
102
|
|
|
91
|
-
def _get_tos_client(self):
|
|
92
|
-
if self._client is None:
|
|
93
|
-
self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf)
|
|
94
|
-
log.info('TosIterableDataset init tos client succeed')
|
|
95
|
-
return self._client
|
|
96
|
-
|
|
97
103
|
def _trans_tos_object(self, object_meta: TosObjectMeta) -> Any:
|
|
98
|
-
obj = self.
|
|
104
|
+
obj = self._client.get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size,
|
|
105
|
+
reader_type=self._reader_type, buffer_size=self._buffer_size)
|
|
99
106
|
return self._trans(obj)
|
|
@@ -5,7 +5,7 @@ from typing import Any, Callable, Iterator, Optional, List, Union
|
|
|
5
5
|
import torch
|
|
6
6
|
|
|
7
7
|
from . import TosObjectReader
|
|
8
|
-
from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
|
|
8
|
+
from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
|
|
9
9
|
from .tos_common import default_trans, gen_dataset_from_prefix, \
|
|
10
10
|
gen_dataset_from_urls
|
|
11
11
|
from .tos_object_meta import TosObjectMeta
|
|
@@ -20,8 +20,10 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
20
20
|
trans: Callable[[TosObjectReader], Any] = default_trans,
|
|
21
21
|
cred: Optional[CredentialProvider] = None,
|
|
22
22
|
client_conf: Optional[TosClientConfig] = None,
|
|
23
|
-
log_conf: Optional[TosLogConfig] = None,
|
|
24
|
-
|
|
23
|
+
log_conf: Optional[TosLogConfig] = None,
|
|
24
|
+
use_native_client: bool = True,
|
|
25
|
+
reader_type: Optional[ReaderType] = None,
|
|
26
|
+
buffer_size: Optional[int] = None):
|
|
25
27
|
self._gen_dataset = gen_dataset
|
|
26
28
|
self._region = region
|
|
27
29
|
self._endpoint = endpoint
|
|
@@ -29,28 +31,38 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
29
31
|
self._cred = cred
|
|
30
32
|
self._client_conf = client_conf
|
|
31
33
|
self._log_conf = log_conf
|
|
32
|
-
self._use_native_client = use_native_client
|
|
33
34
|
self._dataset: Optional[List[TosObjectMeta]] = None
|
|
35
|
+
self._reader_type = reader_type
|
|
36
|
+
self._buffer_size = buffer_size
|
|
37
|
+
self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
|
|
38
|
+
use_native_client)
|
|
39
|
+
log.info('TosMapDataset init tos client succeed')
|
|
34
40
|
|
|
35
41
|
@classmethod
|
|
36
42
|
def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
|
|
37
43
|
trans: Callable[[TosObjectReader], Any] = default_trans,
|
|
38
44
|
cred: Optional[CredentialProvider] = None,
|
|
39
45
|
client_conf: Optional[TosClientConfig] = None,
|
|
40
|
-
log_conf: Optional[TosLogConfig] = None,
|
|
46
|
+
log_conf: Optional[TosLogConfig] = None,
|
|
47
|
+
use_native_client: bool = True,
|
|
48
|
+
reader_type: Optional[ReaderType] = None,
|
|
49
|
+
buffer_size: Optional[int] = None):
|
|
41
50
|
log.info(f'building {cls.__name__} from_urls')
|
|
42
51
|
return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
|
|
43
|
-
use_native_client)
|
|
52
|
+
use_native_client, reader_type, buffer_size)
|
|
44
53
|
|
|
45
54
|
@classmethod
|
|
46
55
|
def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
|
|
47
56
|
trans: Callable[[TosObjectReader], Any] = default_trans,
|
|
48
57
|
cred: Optional[CredentialProvider] = None,
|
|
49
58
|
client_conf: Optional[TosClientConfig] = None,
|
|
50
|
-
log_conf: Optional[TosLogConfig] = None,
|
|
59
|
+
log_conf: Optional[TosLogConfig] = None,
|
|
60
|
+
use_native_client: bool = True,
|
|
61
|
+
reader_type: Optional[ReaderType] = None,
|
|
62
|
+
buffer_size: Optional[int] = None):
|
|
51
63
|
log.info(f'building {cls.__name__} from_prefix')
|
|
52
64
|
return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
|
|
53
|
-
use_native_client)
|
|
65
|
+
use_native_client, reader_type, buffer_size)
|
|
54
66
|
|
|
55
67
|
def __getitem__(self, i: int) -> Any:
|
|
56
68
|
return self._trans_tos_object(i)
|
|
@@ -58,21 +70,15 @@ class TosMapDataset(torch.utils.data.Dataset):
|
|
|
58
70
|
def __len__(self) -> int:
|
|
59
71
|
return len(self._data_set)
|
|
60
72
|
|
|
61
|
-
def _get_tos_client(self) -> TosClient:
|
|
62
|
-
if self._client is None:
|
|
63
|
-
self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
|
|
64
|
-
self._use_native_client)
|
|
65
|
-
log.info('TosMapDataset init tos client succeed')
|
|
66
|
-
return self._client
|
|
67
|
-
|
|
68
73
|
@property
|
|
69
|
-
def _data_set(self):
|
|
74
|
+
def _data_set(self) -> List[TosObjectMeta]:
|
|
70
75
|
if self._dataset is None:
|
|
71
|
-
self._dataset = list(self._gen_dataset(self.
|
|
76
|
+
self._dataset = list(self._gen_dataset(self._client))
|
|
72
77
|
assert self._dataset is not None
|
|
73
78
|
return self._dataset
|
|
74
79
|
|
|
75
80
|
def _trans_tos_object(self, i: int) -> Any:
|
|
76
81
|
object_meta = self._data_set[i]
|
|
77
|
-
obj = self.
|
|
82
|
+
obj = self._client.get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size,
|
|
83
|
+
reader_type=self._reader_type, buffer_size=self._buffer_size)
|
|
78
84
|
return self._trans(obj)
|