tostorchconnector 1.0.4.tar.gz → 1.0.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tostorchconnector might be problematic. See the registry's advisory for this release for more details.

Files changed (23)
  1. {tostorchconnector-1.0.4/tostorchconnector.egg-info → tostorchconnector-1.0.5}/PKG-INFO +2 -1
  2. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/pyproject.toml +2 -1
  3. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tests/test_tos_dataset.py +27 -10
  4. tostorchconnector-1.0.4/tests/test_tosnativeclient.py → tostorchconnector-1.0.5/tests/test_tosclient.py +6 -10
  5. tostorchconnector-1.0.5/tests/test_tosrawclient.py +102 -0
  6. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/__init__.py +2 -1
  7. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_checkpoint.py +8 -14
  8. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_client.py +27 -17
  9. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_common.py +10 -5
  10. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_iterable_dataset.py +24 -17
  11. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_map_dataset.py +24 -18
  12. tostorchconnector-1.0.5/tostorchconnector/tos_object_reader.py +386 -0
  13. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_object_writer.py +29 -15
  14. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5/tostorchconnector.egg-info}/PKG-INFO +2 -1
  15. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/SOURCES.txt +2 -1
  16. tostorchconnector-1.0.4/tostorchconnector/tos_object_reader.py +0 -198
  17. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/LICENSE +0 -0
  18. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/README.md +0 -0
  19. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/setup.cfg +0 -0
  20. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector/tos_object_meta.py +0 -0
  21. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/dependency_links.txt +0 -0
  22. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/requires.txt +0 -0
  23. {tostorchconnector-1.0.4 → tostorchconnector-1.0.5}/tostorchconnector.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tostorchconnector
3
- Version: 1.0.4
3
+ Version: 1.0.5
4
4
  Summary: TOS connector integration for PyTorch
5
5
  Author-email: xiangshijian <xiangshijian@bytedance.com>
6
6
  Classifier: Development Status :: 4 - Beta
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.9
10
10
  Classifier: Programming Language :: Python :: 3.10
11
11
  Classifier: Programming Language :: Python :: 3.11
12
12
  Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
13
14
  Classifier: License :: OSI Approved :: Apache Software License
14
15
  Classifier: Operating System :: OS Independent
15
16
  Classifier: Topic :: Utilities
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "tostorchconnector"
8
- version = "1.0.4"
8
+ version = "1.0.5"
9
9
  description = "TOS connector integration for PyTorch"
10
10
  authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
11
11
  requires-python = ">=3.8,<3.13"
@@ -18,6 +18,7 @@ classifiers = [
18
18
  "Programming Language :: Python :: 3.10",
19
19
  "Programming Language :: Python :: 3.11",
20
20
  "Programming Language :: Python :: 3.12",
21
+ "Programming Language :: Python :: 3.13",
21
22
  "License :: OSI Approved :: Apache Software License",
22
23
  "Operating System :: OS Independent",
23
24
  "Topic :: Utilities"
@@ -2,7 +2,9 @@ import os
2
2
  import unittest
3
3
 
4
4
  from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
5
- from tostorchconnector.tos_client import CredentialProvider, TosLogConfig
5
+ from tostorchconnector.tos_client import CredentialProvider, ReaderType
6
+
7
+ USE_NATIVE_CLIENT = True
6
8
 
7
9
 
8
10
  class TestTosDataSet(unittest.TestCase):
@@ -14,7 +16,8 @@ class TestTosDataSet(unittest.TestCase):
14
16
  sk = os.getenv('TOS_SECRET_KEY')
15
17
  bucket = 'tos-pytorch-connector'
16
18
  datasets = TosMapDataset.from_urls(iter([f'tos://{bucket}/key1', f'tos://{bucket}/key2', f'{bucket}/key3']),
17
- region=region, endpoint=endpoint, cred=CredentialProvider(ak, sk))
19
+ region=region, endpoint=endpoint, cred=CredentialProvider(ak, sk),
20
+ use_native_client=USE_NATIVE_CLIENT)
18
21
 
19
22
  for i in range(len(datasets)):
20
23
  print(datasets[i].bucket, datasets[i].key)
@@ -26,7 +29,8 @@ class TestTosDataSet(unittest.TestCase):
26
29
  sk = os.getenv('TOS_SECRET_KEY')
27
30
  bucket = 'tos-pytorch-connector'
28
31
  datasets = TosMapDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
29
- endpoint=endpoint, cred=CredentialProvider(ak, sk))
32
+ endpoint=endpoint, cred=CredentialProvider(ak, sk),
33
+ use_native_client=USE_NATIVE_CLIENT)
30
34
  for i in range(len(datasets)):
31
35
  item = datasets[i]
32
36
  print(item.bucket, item.key)
@@ -43,7 +47,8 @@ class TestTosDataSet(unittest.TestCase):
43
47
  sk = os.getenv('TOS_SECRET_KEY')
44
48
  bucket = 'tos-pytorch-connector'
45
49
  datasets = TosIterableDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
46
- endpoint=endpoint, cred=CredentialProvider(ak, sk))
50
+ endpoint=endpoint, cred=CredentialProvider(ak, sk),
51
+ use_native_client=USE_NATIVE_CLIENT)
47
52
  i = 0
48
53
  for dataset in datasets:
49
54
  print(dataset.bucket, dataset.key)
@@ -59,15 +64,27 @@ class TestTosDataSet(unittest.TestCase):
59
64
  ak = os.getenv('TOS_ACCESS_KEY')
60
65
  sk = os.getenv('TOS_SECRET_KEY')
61
66
  bucket = 'tos-pytorch-connector'
62
- checkpoint = TosCheckpoint(region, endpoint, cred=CredentialProvider(ak, sk), use_native_client=True)
67
+ checkpoint = TosCheckpoint(region, endpoint, cred=CredentialProvider(ak, sk),
68
+ use_native_client=USE_NATIVE_CLIENT)
63
69
  url = f'tos://{bucket}/key1'
70
+ print('test sequential')
64
71
  with checkpoint.writer(url) as writer:
65
72
  writer.write(b'hello world')
66
73
  writer.write(b'hi world')
67
74
 
68
- reader = checkpoint.reader(url)
69
- print(reader.read())
70
-
75
+ with checkpoint.reader(url) as reader:
76
+ data = reader.read(5)
77
+ print(data)
78
+ print(reader.read())
79
+ reader.seek(0)
80
+ data = reader.read(5)
81
+ print(data)
71
82
 
72
- if __name__ == '__main__':
73
- unittest.main()
83
+ print('test ranged')
84
+ with checkpoint.reader(url, reader_type=ReaderType.RANGED) as reader:
85
+ data = reader.read(5)
86
+ print(data)
87
+ print(reader.read())
88
+ reader.seek(0)
89
+ data = reader.read(5)
90
+ print(data)
@@ -5,7 +5,7 @@ import uuid
5
5
  from tosnativeclient import TosClient, TosException
6
6
 
7
7
 
8
- class TestTosNativeClient(unittest.TestCase):
8
+ class TestTosClient(unittest.TestCase):
9
9
  def test_list_objects(self):
10
10
  region = os.getenv('TOS_REGION')
11
11
  endpoint = os.getenv('TOS_ENDPOINT')
@@ -15,16 +15,16 @@ class TestTosNativeClient(unittest.TestCase):
15
15
  tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
16
16
  file_name_prefix='app.log')
17
17
 
18
- list_stream = tos_client.list_objects(bucket, 'prefix/', max_keys=1000)
18
+ list_stream = tos_client.list_objects(bucket, '', max_keys=1000)
19
19
  count = 0
20
20
  try:
21
21
  for objects in list_stream:
22
- count += 1
23
22
  for content in objects.contents:
23
+ count += 1
24
24
  print(content.key, content.size)
25
- output = tos_client.head_object(bucket, content.key)
26
- assert output.etag == content.etag
27
- assert output.size == content.size
25
+ # output = tos_client.head_object(bucket, content.key)
26
+ # assert output.etag == content.etag
27
+ # assert output.size == content.size
28
28
 
29
29
  print(count)
30
30
  except TosException as e:
@@ -70,7 +70,3 @@ class TestTosNativeClient(unittest.TestCase):
70
70
  print(chunk)
71
71
  except TosException as e:
72
72
  print(e.args[0].status_code)
73
-
74
-
75
- if __name__ == '__main__':
76
- unittest.main()
@@ -0,0 +1,102 @@
1
+ import os
2
+ import unittest
3
+ import uuid
4
+
5
+ import tos
6
+ from tos.exceptions import TosServerError
7
+
8
+
9
+ class TestTosRawClient(unittest.TestCase):
10
+ def test_put_get(self):
11
+ from tosnativeclient import TosRawClient, HeadObjectInput, TosException, DeleteObjectInput, \
12
+ PutObjectFromBufferInput, \
13
+ GetObjectInput, GetObjectOutput
14
+
15
+ region = os.getenv('TOS_REGION')
16
+ endpoint = os.getenv('TOS_ENDPOINT')
17
+ ak = os.getenv('TOS_ACCESS_KEY')
18
+ sk = os.getenv('TOS_SECRET_KEY')
19
+ bucket = 'tos-pytorch-connector'
20
+ key = str(uuid.uuid4())
21
+ tos_raw_client = TosRawClient(region, endpoint, ak, sk)
22
+
23
+ doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
24
+ assert doutput.status_code == 204
25
+
26
+ try:
27
+ tos_raw_client.head_object(HeadObjectInput(bucket, key))
28
+ assert False
29
+ except TosException as e:
30
+ assert e.args[0].status_code == 404
31
+ assert len(e.args[0].request_id) > 0
32
+
33
+ data = str(uuid.uuid4()).encode('utf-8')
34
+ input = PutObjectFromBufferInput(bucket, key, data)
35
+ poutput = tos_raw_client.put_object_from_buffer(input)
36
+ assert poutput.status_code == 200
37
+ assert len(poutput.etag) > 0
38
+
39
+ houtput = tos_raw_client.head_object(HeadObjectInput(bucket, key))
40
+ assert houtput.status_code == 200
41
+ assert houtput.etag == poutput.etag
42
+ assert houtput.content_length == len(data)
43
+
44
+ goutput: GetObjectOutput = tos_raw_client.get_object(GetObjectInput(bucket, key))
45
+ assert goutput.status_code == 200
46
+ assert goutput.etag == poutput.etag
47
+ rdata = goutput.read_all()
48
+ assert rdata == data
49
+
50
+ doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
51
+ assert doutput.status_code == 204
52
+
53
+ try:
54
+ tos_raw_client.get_object(GetObjectInput(bucket, key))
55
+ assert False
56
+ except TosException as e:
57
+ assert e.args[0].status_code == 404
58
+ assert len(e.args[0].request_id) > 0
59
+
60
+ def test_put_get_old(self):
61
+ region = os.getenv('TOS_REGION')
62
+ endpoint = os.getenv('TOS_ENDPOINT')
63
+ ak = os.getenv('TOS_ACCESS_KEY')
64
+ sk = os.getenv('TOS_SECRET_KEY')
65
+ bucket = 'tos-pytorch-connector'
66
+ key = str(uuid.uuid4())
67
+ tos_client = tos.TosClientV2(ak, sk, endpoint=endpoint, region=region)
68
+ doutput = tos_client.delete_object(bucket, key)
69
+ assert doutput.status_code == 204
70
+
71
+ try:
72
+ tos_client.head_object(bucket, key)
73
+ assert False
74
+ except TosServerError as e:
75
+ assert e.status_code == 404
76
+ assert len(e.request_id) > 0
77
+
78
+ data = str(uuid.uuid4()).encode('utf-8')
79
+ poutput = tos_client.put_object(bucket, key, content=data)
80
+ assert poutput.status_code == 200
81
+ assert len(poutput.etag) > 0
82
+
83
+ houtput = tos_client.head_object(bucket, key)
84
+ assert houtput.status_code == 200
85
+ assert houtput.etag == poutput.etag
86
+ assert houtput.content_length == len(data)
87
+
88
+ goutput = tos_client.get_object(bucket, key)
89
+ assert goutput.status_code == 200
90
+ assert goutput.etag == poutput.etag
91
+ rdata = goutput.read()
92
+ assert rdata == data
93
+
94
+ doutput = tos_client.delete_object(bucket, key)
95
+ assert doutput.status_code == 204
96
+
97
+ try:
98
+ tos_client.get_object(bucket, key)
99
+ assert False
100
+ except TosServerError as e:
101
+ assert e.status_code == 404
102
+ assert len(e.request_id) > 0
@@ -1,4 +1,4 @@
1
- from .tos_object_reader import TosObjectReader
1
+ from .tos_object_reader import TosObjectReader, SequentialTosObjectReader
2
2
  from .tos_object_writer import TosObjectWriter
3
3
  from .tos_iterable_dataset import TosIterableDataset
4
4
  from .tos_map_dataset import TosMapDataset
@@ -7,6 +7,7 @@ from tosnativeclient import TosException, TosError
7
7
 
8
8
  __all__ = [
9
9
  'TosObjectReader',
10
+ 'SequentialTosObjectReader',
10
11
  'TosObjectWriter',
11
12
  'TosIterableDataset',
12
13
  'TosMapDataset',
@@ -2,7 +2,7 @@ import logging
2
2
  from typing import Optional
3
3
 
4
4
  from . import TosObjectReader, TosObjectWriter
5
- from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
5
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
6
6
  from .tos_common import parse_tos_url
7
7
 
8
8
  log = logging.getLogger(__name__)
@@ -14,26 +14,20 @@ class TosCheckpoint(object):
14
14
  cred: Optional[CredentialProvider] = None,
15
15
  client_conf: Optional[TosClientConfig] = None,
16
16
  log_conf: Optional[TosLogConfig] = None, use_native_client=True):
17
- self._client = None
18
- self._native_client = None
19
17
  self._region = region
20
18
  self._endpoint = endpoint
21
19
  self._cred = cred
22
20
  self._client_conf = client_conf
23
21
  self._log_conf = log_conf
24
- self._use_native_client = use_native_client
22
+ self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
23
+ use_native_client)
24
+ log.info('TosCheckpoint init tos client succeed')
25
25
 
26
- def reader(self, url: str) -> TosObjectReader:
26
+ def reader(self, url: str, reader_type: Optional[ReaderType] = None,
27
+ buffer_size: Optional[int] = None) -> TosObjectReader:
27
28
  bucket, key = parse_tos_url(url)
28
- return self._get_tos_client().get_object(bucket, key)
29
+ return self._client.get_object(bucket, key, reader_type=reader_type, buffer_size=buffer_size)
29
30
 
30
31
  def writer(self, url: str) -> TosObjectWriter:
31
32
  bucket, key = parse_tos_url(url)
32
- return self._get_tos_client().put_object(bucket, key)
33
-
34
- def _get_tos_client(self):
35
- if self._client is None:
36
- self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
37
- self._use_native_client)
38
- log.info('TosIterableDataset init tos client succeed')
39
- return self._client
33
+ return self._client.put_object(bucket, key)
@@ -1,21 +1,26 @@
1
- import io
1
+ import enum
2
2
  import logging
3
3
  import os
4
4
  from functools import partial
5
5
  from typing import Optional, List, Tuple
6
6
 
7
7
  import tos
8
- from tos.models2 import GetObjectOutput, PutObjectOutput
9
8
 
10
9
  import tosnativeclient
11
10
 
12
- from . import TosObjectReader, TosObjectWriter
11
+ from . import SequentialTosObjectReader, TosObjectWriter, TosObjectReader
13
12
  from .tos_object_meta import TosObjectMeta
13
+ from .tos_object_reader import TosObjectStream, RangedTosObjectReader
14
14
  from .tos_object_writer import PutObjectStream
15
15
 
16
16
  log = logging.getLogger(__name__)
17
17
 
18
18
 
19
+ class ReaderType(enum.Enum):
20
+ SEQUENTIAL = 'Sequential'
21
+ RANGED = 'Ranged'
22
+
23
+
19
24
  class CredentialProvider(object):
20
25
  def __init__(self, ak: str, sk: str):
21
26
  self._ak = ak
@@ -110,48 +115,49 @@ class TosClient(object):
110
115
  tos.set_logger(file_path=file_path, level=log_level)
111
116
 
112
117
  @property
113
- def use_native_client(self):
118
+ def use_native_client(self) -> bool:
114
119
  return self._use_native_client
115
120
 
116
121
  def get_object(self, bucket: str, key: str, etag: Optional[str] = None,
117
- size: Optional[int] = None) -> TosObjectReader:
122
+ size: Optional[int] = None, reader_type: Optional[ReaderType] = None,
123
+ buffer_size: Optional[int] = None) -> TosObjectReader:
118
124
  log.debug(f'get_object tos://{bucket}/{key}')
125
+
119
126
  if size is None or etag is None:
120
127
  get_object_meta = partial(self.head_object, bucket, key)
121
128
  else:
122
129
  get_object_meta = lambda: TosObjectMeta(bucket, key, size, etag)
123
- if isinstance(self._client, tosnativeclient.TosClient):
124
- get_object_stream = partial(self._client.get_object, bucket, key)
125
- else:
126
- def get_object_stream(et: str, _: int) -> GetObjectOutput:
127
- return self._client.get_object(bucket, key, '', et)
128
130
 
129
- return TosObjectReader(bucket, key, get_object_meta, get_object_stream)
131
+ object_stream = TosObjectStream(bucket, key, get_object_meta, self._client)
132
+ if reader_type is not None and reader_type == ReaderType.RANGED:
133
+ return RangedTosObjectReader(bucket, key, object_stream, buffer_size)
134
+ return SequentialTosObjectReader(bucket, key, object_stream)
130
135
 
131
136
  def put_object(self, bucket: str, key: str, storage_class: Optional[str] = None) -> TosObjectWriter:
132
137
  log.debug(f'put_object tos://{bucket}/{key}')
133
138
 
134
139
  if isinstance(self._client, tosnativeclient.TosClient):
135
- storage_class = storage_class if storage_class is not None else ''
136
- put_object_stream = self._client.put_object(bucket, key, storage_class)
140
+ put_object_stream = self._client.put_object(bucket, key, storage_class=storage_class)
137
141
  else:
138
- def put_object(content: io.BytesIO) -> PutObjectOutput:
139
- return self._client.put_object(bucket, key, storage_class=storage_class, content=content)
140
-
141
- put_object_stream = PutObjectStream(put_object)
142
+ put_object_stream = PutObjectStream(
143
+ lambda content: self._client.put_object(bucket, key, storage_class=storage_class, content=content))
142
144
 
143
145
  return TosObjectWriter(bucket, key, put_object_stream)
144
146
 
145
147
  def head_object(self, bucket: str, key: str) -> TosObjectMeta:
146
148
  log.debug(f'head_object tos://{bucket}/{key}')
149
+
147
150
  if isinstance(self._client, tosnativeclient.TosClient):
148
151
  resp = self._client.head_object(bucket, key)
149
152
  return TosObjectMeta(resp.bucket, resp.key, resp.size, resp.etag)
153
+
150
154
  resp = self._client.head_object(bucket, key)
151
155
  return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
152
156
 
153
157
  def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
154
158
  delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
159
+ log.debug(f'gen_list_stream tos://{bucket}/{prefix}')
160
+
155
161
  if isinstance(self._client, tosnativeclient.TosClient):
156
162
  delimiter = delimiter if delimiter is not None else ''
157
163
  return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
@@ -161,6 +167,10 @@ class TosClient(object):
161
167
  continuation_token: Optional[str] = None, delimiter: Optional[str] = None) -> Tuple[
162
168
  List[TosObjectMeta], bool, Optional[str]]:
163
169
  log.debug(f'list_objects tos://{bucket}/{prefix}')
170
+
171
+ if isinstance(self._client, tosnativeclient.TosClient):
172
+ raise NotImplementedError()
173
+
164
174
  resp = self._client.list_objects_type2(bucket, prefix, max_keys=max_keys, continuation_token=continuation_token,
165
175
  delimiter=delimiter)
166
176
  object_metas = []
@@ -22,14 +22,14 @@ class TosObjectIterator(object):
22
22
  self._is_truncated = True
23
23
  self._continuation_token = None
24
24
 
25
- def close(self):
25
+ def close(self) -> None:
26
26
  if self._list_stream is not None:
27
27
  self._list_stream.close()
28
28
 
29
29
  def __iter__(self) -> Iterator[TosObjectMeta]:
30
30
  return self
31
31
 
32
- def __next__(self):
32
+ def __next__(self) -> TosObjectMeta:
33
33
  if self._client.use_native_client:
34
34
  if self._list_stream is None:
35
35
  self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
@@ -37,9 +37,14 @@ class TosObjectIterator(object):
37
37
 
38
38
  if self._object_metas is None or self._index >= len(self._object_metas):
39
39
  self._object_metas = []
40
- objects = next(self._list_stream)
41
- for content in objects.contents:
42
- self._object_metas.append(TosObjectMeta(content.bucket, content.key, content.size, content.etag))
40
+ self._index = 0
41
+ while 1:
42
+ objects = next(self._list_stream)
43
+ for content in objects.contents:
44
+ self._object_metas.append(
45
+ TosObjectMeta(content.bucket, content.key, content.size, content.etag))
46
+ if len(self._object_metas) > 0:
47
+ break
43
48
 
44
49
  object_meta = self._object_metas[self._index]
45
50
  self._index += 1
@@ -5,7 +5,7 @@ from typing import Iterator, Any, Optional, Callable, Union
5
5
  import torch
6
6
 
7
7
  from . import TosObjectReader
8
- from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
8
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
9
9
  from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
10
10
  from .tos_object_meta import TosObjectMeta
11
11
 
@@ -20,8 +20,10 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
20
20
  cred: Optional[CredentialProvider] = None,
21
21
  client_conf: Optional[TosClientConfig] = None,
22
22
  log_conf: Optional[TosLogConfig] = None,
23
- sharding: bool = False, use_native_client=True):
24
- self._client: Optional[TosClient] = None
23
+ sharding: bool = False,
24
+ use_native_client: bool = True,
25
+ reader_type: Optional[ReaderType] = None,
26
+ buffer_size: Optional[int] = None):
25
27
  self._gen_dataset = gen_dataset
26
28
  self._region = region
27
29
  self._endpoint = endpoint
@@ -30,13 +32,17 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
30
32
  self._client_conf = client_conf
31
33
  self._log_conf = log_conf
32
34
  self._sharding = sharding
33
- self._use_native_client = use_native_client
34
35
  if torch.distributed.is_initialized():
35
36
  self._rank = torch.distributed.get_rank()
36
37
  self._world_size = torch.distributed.get_world_size()
37
38
  else:
38
39
  self._rank = 0
39
40
  self._world_size = 1
41
+ self._reader_type = reader_type
42
+ self._buffer_size = buffer_size
43
+ self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
44
+ use_native_client)
45
+ log.info('TosIterableDataset init tos client succeed')
40
46
 
41
47
  @classmethod
42
48
  def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
@@ -44,10 +50,13 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
44
50
  cred: Optional[CredentialProvider] = None,
45
51
  client_conf: Optional[TosClientConfig] = None,
46
52
  log_conf: Optional[TosLogConfig] = None,
47
- sharding: bool = False, use_native_client=True):
53
+ sharding: bool = False,
54
+ use_native_client: bool = True,
55
+ reader_type: Optional[ReaderType] = None,
56
+ buffer_size: Optional[int] = None):
48
57
  log.info(f'building {cls.__name__} from_urls')
49
58
  return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
50
- sharding, use_native_client)
59
+ sharding, use_native_client, reader_type, buffer_size)
51
60
 
52
61
  @classmethod
53
62
  def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
@@ -55,10 +64,13 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
55
64
  cred: Optional[CredentialProvider] = None,
56
65
  client_conf: Optional[TosClientConfig] = None,
57
66
  log_conf: Optional[TosLogConfig] = None,
58
- sharding: bool = False, use_native_client=True):
67
+ sharding: bool = False,
68
+ use_native_client: bool = True,
69
+ reader_type: Optional[ReaderType] = None,
70
+ buffer_size: Optional[int] = None):
59
71
  log.info(f'building {cls.__name__} from_prefix')
60
72
  return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
61
- sharding, use_native_client)
73
+ sharding, use_native_client, reader_type, buffer_size)
62
74
 
63
75
  def __iter__(self) -> Iterator[Any]:
64
76
  worker_id = 0
@@ -72,12 +84,12 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
72
84
  if not self._sharding or (self._world_size == 1 and num_workers == 1):
73
85
  return map(
74
86
  self._trans_tos_object,
75
- self._gen_dataset(self._get_tos_client()),
87
+ self._gen_dataset(self._client),
76
88
  )
77
89
 
78
90
  part_dataset = (
79
91
  obj
80
- for idx, obj in enumerate(self._gen_dataset(self._get_tos_client()))
92
+ for idx, obj in enumerate(self._gen_dataset(self._client))
81
93
  if idx % self._world_size == self._rank
82
94
  )
83
95
 
@@ -88,12 +100,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
88
100
  )
89
101
  return map(self._trans_tos_object, part_dataset)
90
102
 
91
- def _get_tos_client(self):
92
- if self._client is None:
93
- self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf)
94
- log.info('TosIterableDataset init tos client succeed')
95
- return self._client
96
-
97
103
  def _trans_tos_object(self, object_meta: TosObjectMeta) -> Any:
98
- obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
104
+ obj = self._client.get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size,
105
+ reader_type=self._reader_type, buffer_size=self._buffer_size)
99
106
  return self._trans(obj)
@@ -5,7 +5,7 @@ from typing import Any, Callable, Iterator, Optional, List, Union
5
5
  import torch
6
6
 
7
7
  from . import TosObjectReader
8
- from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
8
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
9
9
  from .tos_common import default_trans, gen_dataset_from_prefix, \
10
10
  gen_dataset_from_urls
11
11
  from .tos_object_meta import TosObjectMeta
@@ -20,8 +20,10 @@ class TosMapDataset(torch.utils.data.Dataset):
20
20
  trans: Callable[[TosObjectReader], Any] = default_trans,
21
21
  cred: Optional[CredentialProvider] = None,
22
22
  client_conf: Optional[TosClientConfig] = None,
23
- log_conf: Optional[TosLogConfig] = None, use_native_client=True):
24
- self._client: Optional[TosClient] = None
23
+ log_conf: Optional[TosLogConfig] = None,
24
+ use_native_client: bool = True,
25
+ reader_type: Optional[ReaderType] = None,
26
+ buffer_size: Optional[int] = None):
25
27
  self._gen_dataset = gen_dataset
26
28
  self._region = region
27
29
  self._endpoint = endpoint
@@ -29,28 +31,38 @@ class TosMapDataset(torch.utils.data.Dataset):
29
31
  self._cred = cred
30
32
  self._client_conf = client_conf
31
33
  self._log_conf = log_conf
32
- self._use_native_client = use_native_client
33
34
  self._dataset: Optional[List[TosObjectMeta]] = None
35
+ self._reader_type = reader_type
36
+ self._buffer_size = buffer_size
37
+ self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
38
+ use_native_client)
39
+ log.info('TosMapDataset init tos client succeed')
34
40
 
35
41
  @classmethod
36
42
  def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
37
43
  trans: Callable[[TosObjectReader], Any] = default_trans,
38
44
  cred: Optional[CredentialProvider] = None,
39
45
  client_conf: Optional[TosClientConfig] = None,
40
- log_conf: Optional[TosLogConfig] = None, use_native_client=True):
46
+ log_conf: Optional[TosLogConfig] = None,
47
+ use_native_client: bool = True,
48
+ reader_type: Optional[ReaderType] = None,
49
+ buffer_size: Optional[int] = None):
41
50
  log.info(f'building {cls.__name__} from_urls')
42
51
  return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
43
- use_native_client)
52
+ use_native_client, reader_type, buffer_size)
44
53
 
45
54
  @classmethod
46
55
  def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
47
56
  trans: Callable[[TosObjectReader], Any] = default_trans,
48
57
  cred: Optional[CredentialProvider] = None,
49
58
  client_conf: Optional[TosClientConfig] = None,
50
- log_conf: Optional[TosLogConfig] = None, use_native_client=True):
59
+ log_conf: Optional[TosLogConfig] = None,
60
+ use_native_client: bool = True,
61
+ reader_type: Optional[ReaderType] = None,
62
+ buffer_size: Optional[int] = None):
51
63
  log.info(f'building {cls.__name__} from_prefix')
52
64
  return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
53
- use_native_client)
65
+ use_native_client, reader_type, buffer_size)
54
66
 
55
67
  def __getitem__(self, i: int) -> Any:
56
68
  return self._trans_tos_object(i)
@@ -58,21 +70,15 @@ class TosMapDataset(torch.utils.data.Dataset):
58
70
  def __len__(self) -> int:
59
71
  return len(self._data_set)
60
72
 
61
- def _get_tos_client(self) -> TosClient:
62
- if self._client is None:
63
- self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
64
- self._use_native_client)
65
- log.info('TosMapDataset init tos client succeed')
66
- return self._client
67
-
68
73
  @property
69
- def _data_set(self):
74
+ def _data_set(self) -> List[TosObjectMeta]:
70
75
  if self._dataset is None:
71
- self._dataset = list(self._gen_dataset(self._get_tos_client()))
76
+ self._dataset = list(self._gen_dataset(self._client))
72
77
  assert self._dataset is not None
73
78
  return self._dataset
74
79
 
75
80
  def _trans_tos_object(self, i: int) -> Any:
76
81
  object_meta = self._data_set[i]
77
- obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
82
+ obj = self._client.get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size,
83
+ reader_type=self._reader_type, buffer_size=self._buffer_size)
78
84
  return self._trans(obj)