tostorchconnector 1.0.6__tar.gz → 1.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tostorchconnector might be problematic.

Files changed (23)
  1. {tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.1.2}/PKG-INFO +2 -2
  2. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/pyproject.toml +2 -2
  3. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tos_dataset.py +52 -14
  4. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tosclient.py +34 -14
  5. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_checkpoint.py +4 -4
  6. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_client.py +79 -38
  7. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_common.py +31 -13
  8. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_iterable_dataset.py +16 -16
  9. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_map_dataset.py +18 -18
  10. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_reader.py +1 -1
  11. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2/tostorchconnector.egg-info}/PKG-INFO +2 -2
  12. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/SOURCES.txt +0 -1
  13. tostorchconnector-1.1.2/tostorchconnector.egg-info/requires.txt +3 -0
  14. tostorchconnector-1.0.6/tests/test_tosrawclient.py +0 -102
  15. tostorchconnector-1.0.6/tostorchconnector.egg-info/requires.txt +0 -3
  16. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/LICENSE +0 -0
  17. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/README.md +0 -0
  18. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/setup.cfg +0 -0
  19. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/__init__.py +0 -0
  20. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_meta.py +0 -0
  21. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_writer.py +0 -0
  22. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/dependency_links.txt +0 -0
  23. {tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/top_level.txt +0 -0

{tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.1.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tostorchconnector
- Version: 1.0.6
+ Version: 1.1.2
  Summary: TOS connector integration for PyTorch
  Author-email: xiangshijian <xiangshijian@bytedance.com>
  Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: torch>=2.0
  Requires-Dist: tos>=2.8.0
- Requires-Dist: tosnativeclient>=1.0.2
+ Requires-Dist: tosnativeclient>=1.1.1
  Dynamic: license-file

  # TOS Connector for pytorch

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "tostorchconnector"
- version = "1.0.6"
+ version = "1.1.2"
  description = "TOS connector integration for PyTorch"
  authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
  requires-python = ">=3.8,<3.14"
@@ -26,7 +26,7 @@ classifiers = [
  dependencies = [
  "torch >= 2.0",
  "tos>=2.8.0",
- "tosnativeclient >= 1.0.2"
+ "tosnativeclient >= 1.1.1"
  ]

  [tool.setuptools.packages]

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tos_dataset.py
@@ -1,10 +1,12 @@
  import os
+ import pickle
  import unittest

  from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
  from tostorchconnector.tos_client import CredentialProvider, ReaderType

  USE_NATIVE_CLIENT = True
+ READER_TYPE = ReaderType.SEQUENTIAL


  class TestTosDataSet(unittest.TestCase):
@@ -22,23 +24,50 @@ class TestTosDataSet(unittest.TestCase):
  for i in range(len(datasets)):
  print(datasets[i].bucket, datasets[i].key)

+ def test_pickle(self):
+ region = os.getenv('TOS_REGION')
+ endpoint = os.getenv('TOS_ENDPOINT')
+ ak = os.getenv('TOS_ACCESS_KEY')
+ sk = os.getenv('TOS_SECRET_KEY')
+ bucket = 'tos-pytorch-connector'
+ datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
+ endpoint=endpoint, cred=CredentialProvider(ak, sk),
+ use_native_client=USE_NATIVE_CLIENT)
+ pickled_datasets = pickle.dumps(datasets)
+ assert isinstance(pickled_datasets, bytes)
+ unpickled_datasets = pickle.loads(pickled_datasets)
+ i = 0
+ for dataset in unpickled_datasets:
+ print(dataset.bucket, dataset.key)
+ i += 1
+ print(i)
+
+ pickled = pickle.dumps(unpickled_datasets._data_set)
+ assert isinstance(pickled, bytes)
+
  def test_from_prefix(self):
  region = os.getenv('TOS_REGION')
  endpoint = os.getenv('TOS_ENDPOINT')
  ak = os.getenv('TOS_ACCESS_KEY')
  sk = os.getenv('TOS_SECRET_KEY')
  bucket = 'tos-pytorch-connector'
- datasets = TosMapDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
+ datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
  endpoint=endpoint, cred=CredentialProvider(ak, sk),
  use_native_client=USE_NATIVE_CLIENT)
+
+ count = 0
  for i in range(len(datasets)):
- item = datasets[i]
- print(item.bucket, item.key)
- if i == 1:
- item = datasets[i]
- data = item.read(100)
- print(data)
- print(len(data))
+ dataset = datasets[i]
+ print(dataset.bucket, dataset.key)
+ count += 1
+ dataset.close()
+ # if i == 1:
+ # item = datasets[i]
+ # data = item.read(100)
+ # print(data)
+ # print(len(data))
+
+ print(count)

  def test_from_prefix_iter(self):
  region = os.getenv('TOS_REGION')
@@ -46,17 +75,23 @@ class TestTosDataSet(unittest.TestCase):
  ak = os.getenv('TOS_ACCESS_KEY')
  sk = os.getenv('TOS_SECRET_KEY')
  bucket = 'tos-pytorch-connector'
- datasets = TosIterableDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
+ datasets = TosIterableDataset.from_prefix(f'tos://{bucket}', region=region,
  endpoint=endpoint, cred=CredentialProvider(ak, sk),
- use_native_client=USE_NATIVE_CLIENT)
+ use_native_client=USE_NATIVE_CLIENT, reader_type=ReaderType.RANGED)
  i = 0
  for dataset in datasets:
  print(dataset.bucket, dataset.key)
- if i == 1:
- data = dataset.read(100)
- print(data)
- print(len(data))
+ # if dataset.key == 'tosutil':
+ # with open('logs/tosutil', 'wb') as f:
+ # while 1:
+ # chunk = dataset.read(8192)
+ # print(len(chunk))
+ # if not chunk:
+ # break
+ # f.write(chunk)
+ dataset.close()
  i += 1
+ print(i)

  def test_checkpoint(self):
  region = os.getenv('TOS_REGION')
@@ -72,6 +107,9 @@ class TestTosDataSet(unittest.TestCase):
  writer.write(b'hello world')
  writer.write(b'hi world')

+ with checkpoint.reader(url) as reader:
+ print(reader.read())
+
  with checkpoint.reader(url) as reader:
  data = reader.read(5)
  print(data)
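
A note on the new test_pickle case above: torch.utils.data.DataLoader pickles the dataset when it runs with worker processes, so the dataset and the client it holds must survive a dumps/loads round trip. The sketch below pulls that thread together using only arguments visible in this diff; the bucket name and environment variables come from the tests, and the to_len transform is a hypothetical example, not part of the package:

    import os
    from torch.utils.data import DataLoader
    from tostorchconnector import TosMapDataset
    from tostorchconnector.tos_client import CredentialProvider, ReaderType

    # Hypothetical transform: return each object's key and its payload size.
    def to_len(reader):
        try:
            return reader.key, len(reader.read())
        finally:
            reader.close()

    dataset = TosMapDataset.from_prefix(
        'tos://tos-pytorch-connector',       # bucket used by the tests above
        region=os.getenv('TOS_REGION'),
        endpoint=os.getenv('TOS_ENDPOINT'),
        cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
        transform=to_len,                    # keyword is `transform` as of this release
        reader_type=ReaderType.SEQUENTIAL,
        enable_crc=True,                     # new flag in 1.1.2
    )

    # Worker processes receive a pickled copy of the dataset; the inner client is
    # rebuilt lazily per process (see the fork handling in tos_client.py below).
    loader = DataLoader(dataset, num_workers=2, batch_size=None)
    for key, size in loader:
        print(key, size)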

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tests/test_tosclient.py
@@ -1,4 +1,5 @@
  import os
+ import pickle
  import unittest
  import uuid

@@ -6,15 +7,40 @@ from tosnativeclient import TosClient, TosException


  class TestTosClient(unittest.TestCase):
+ def test_pickle(self):
+ region = os.getenv('TOS_REGION')
+ endpoint = os.getenv('TOS_ENDPOINT')
+ ak = os.getenv('TOS_ACCESS_KEY')
+ sk = os.getenv('TOS_SECRET_KEY')
+ bucket = 'tos-pytorch-connector'
+ tos_client = TosClient(region, endpoint, ak, sk)
+ pickled_tos_client = pickle.dumps(tos_client)
+ assert isinstance(pickled_tos_client, bytes)
+ unpickled_tos_client = pickle.loads(pickled_tos_client)
+
+ self._test_list_objects(bucket, unpickled_tos_client)
+ self._test_write_read_object(bucket, unpickled_tos_client)
+
  def test_list_objects(self):
  region = os.getenv('TOS_REGION')
  endpoint = os.getenv('TOS_ENDPOINT')
  ak = os.getenv('TOS_ACCESS_KEY')
  sk = os.getenv('TOS_SECRET_KEY')
  bucket = 'tos-pytorch-connector'
- tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
- file_name_prefix='app.log')
+ tos_client = TosClient(region, endpoint, ak, sk)
+ self._test_list_objects(bucket, tos_client)

+ def test_write_read_object(self):
+ region = os.getenv('TOS_REGION')
+ endpoint = os.getenv('TOS_ENDPOINT')
+ ak = os.getenv('TOS_ACCESS_KEY')
+ sk = os.getenv('TOS_SECRET_KEY')
+ bucket = 'tos-pytorch-connector'
+ tos_client = TosClient(region, endpoint, ak, sk)
+
+ self._test_write_read_object(bucket, tos_client)
+
+ def _test_list_objects(self, bucket, tos_client):
  list_stream = tos_client.list_objects(bucket, '', max_keys=1000)
  count = 0
  try:
@@ -22,23 +48,15 @@ class TestTosClient(unittest.TestCase):
  for content in objects.contents:
  count += 1
  print(content.key, content.size)
- # output = tos_client.head_object(bucket, content.key)
- # assert output.etag == content.etag
- # assert output.size == content.size
+ output = tos_client.head_object(bucket, content.key)
+ assert output.etag == content.etag
+ assert output.size == content.size

  print(count)
  except TosException as e:
  print(e.args[0].message)

- def test_write_read_object(self):
- region = os.getenv('TOS_REGION')
- endpoint = os.getenv('TOS_ENDPOINT')
- ak = os.getenv('TOS_ACCESS_KEY')
- sk = os.getenv('TOS_SECRET_KEY')
- bucket = 'tos-pytorch-connector'
- tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
- file_name_prefix='app.log')
-
+ def _test_write_read_object(self, bucket, tos_client):
  key = str(uuid.uuid4())
  read_stream = tos_client.get_object(bucket, key, '', 1)

@@ -59,6 +77,8 @@ class TestTosClient(unittest.TestCase):
  write_stream.close()

  output = tos_client.head_object(bucket, key)
+ print(output.etag, output.size)
+
  read_stream = tos_client.get_object(bucket, key, output.etag, output.size)
  try:
  offset = 0

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_checkpoint.py
@@ -13,14 +13,14 @@ class TosCheckpoint(object):
  endpoint: Optional[str] = None,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+ use_native_client=True,
+ enable_crc=True):
  self._region = region
  self._endpoint = endpoint
  self._cred = cred
  self._client_conf = client_conf
- self._log_conf = log_conf
- self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
- use_native_client)
+ self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, use_native_client,
+ enable_crc)
  log.info('TosCheckpoint init tos client succeed')

  def reader(self, url: str, reader_type: Optional[ReaderType] = None,
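
For orientation, a small usage sketch of the checkpoint API after this change. The reader(url) context manager and the writer.write calls are taken from test_checkpoint above; the checkpoint.writer(url) call and the object URL are assumptions inferred from that test rather than confirmed API:

    import os
    from tostorchconnector import TosCheckpoint
    from tostorchconnector.tos_client import CredentialProvider

    checkpoint = TosCheckpoint(
        region=os.getenv('TOS_REGION'),
        endpoint=os.getenv('TOS_ENDPOINT'),
        cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
        enable_crc=True,   # new argument; the log_conf argument is gone
    )

    url = 'tos://tos-pytorch-connector/checkpoints/model.pt'   # placeholder object URL

    # Assumed writer counterpart to reader(), inferred from the writer.write()
    # calls in test_checkpoint above.
    with checkpoint.writer(url) as writer:
        writer.write(b'hello world')

    with checkpoint.reader(url) as reader:
        print(reader.read())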

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_client.py
@@ -1,8 +1,10 @@
  import enum
  import logging
  import os
+ import gc
+
  from functools import partial
- from typing import Optional, List, Tuple
+ from typing import Optional, List, Tuple, Any

  import tos

@@ -15,6 +17,44 @@ from .tos_object_writer import PutObjectStream

  log = logging.getLogger(__name__)

+ import threading
+ import weakref
+ import traceback
+
+ _client_lock = threading.Lock()
+ _client_map = weakref.WeakSet()
+
+
+ def _before_fork():
+ with _client_lock:
+ clients = list(_client_map)
+
+ if not clients or len(clients) == 0:
+ return
+
+ try:
+ for client in clients:
+ client._inner_client = None
+
+ _reset_client_map()
+ gc.collect()
+ except Exception as e:
+ log.warning(f'failed to clean up native clients, {str(e)}')
+ traceback.print_exc()
+
+
+ def _after_fork_in_child():
+ _reset_client_map()
+
+
+ def _reset_client_map():
+ global _client_map
+ with _client_lock:
+ _client_map = weakref.WeakSet()
+
+
+ os.register_at_fork(before=_before_fork, after_in_child=_after_fork_in_child)
+

  class ReaderType(enum.Enum):
  SEQUENTIAL = 'Sequential'
@@ -77,42 +117,38 @@ class TosLogConfig(object):

  class TosClient(object):
  def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
- client_conf: Optional[TosClientConfig] = None, log_conf: Optional[TosLogConfig] = None,
- use_native_client: bool = True):
- cred = CredentialProvider('', '') if cred is None else cred
- client_conf = TosClientConfig() if client_conf is None else client_conf
- log_conf = TosLogConfig() if log_conf is None else log_conf
- self._part_size = client_conf.part_size
+ client_conf: Optional[TosClientConfig] = None, use_native_client: bool = True,
+ enable_crc: bool = True):
+ self._region = region
+ self._endpoint = endpoint
+ self._cred = CredentialProvider('', '') if cred is None else cred
+ self._client_conf = TosClientConfig() if client_conf is None else client_conf
+ self._part_size = self._client_conf.part_size
  self._use_native_client = use_native_client
- if use_native_client:
- directives = ''
- directory = ''
- file_name_prefix = ''
- if log_conf.log_dir and log_conf.log_file_name:
- if log_conf.log_level:
- if log_conf.log_level == logging.DEBUG:
- directives = 'debug'
- elif log_conf.log_level == logging.INFO:
- directives = 'info'
- elif log_conf.log_level == logging.WARN:
- directives = 'warn'
- elif log_conf.log_level == logging.ERROR:
- directives = 'error'
- else:
- directives = 'info'
- directory = log_conf.log_dir
- file_name_prefix = log_conf.log_file_name
- self._client = tosnativeclient.TosClient(region, endpoint, cred.ak, cred.sk, client_conf.part_size,
- client_conf.max_retry_count, directives=directives,
- directory=directory,
- file_name_prefix=file_name_prefix)
- else:
- self._client = tos.TosClientV2(cred.ak, cred.sk, endpoint=endpoint, region=region,
- max_retry_count=client_conf.max_retry_count)
- if log_conf.log_dir and log_conf.log_file_name:
- file_path = os.path.join(log_conf.log_dir, log_conf.log_file_name)
- log_level = log_conf.log_level if log_conf.log_level else logging.INFO
- tos.set_logger(file_path=file_path, level=log_level)
+ self._inner_client = None
+ self._client_pid = None
+ self._enable_crc = enable_crc
+
+ @property
+ def _client(self) -> Any:
+ if self._client_pid is None or self._client_pid != os.getpid() or self._inner_client is None:
+ with _client_lock:
+ if self._use_native_client:
+ self._inner_client = tosnativeclient.TosClient(self._region, self._endpoint, self._cred.ak,
+ self._cred.sk,
+ self._client_conf.part_size,
+ self._client_conf.max_retry_count,
+ enable_crc=self._enable_crc)
+ else:
+ self._inner_client = tos.TosClientV2(self._cred.ak, self._cred.sk, endpoint=self._endpoint,
+ region=self._region,
+ max_retry_count=self._client_conf.max_retry_count,
+ enable_crc=self._enable_crc)
+ self._client_pid = os.getpid()
+ _client_map.add(self)
+
+ assert self._inner_client is not None
+ return self._inner_client

  @property
  def use_native_client(self) -> bool:
@@ -155,12 +191,17 @@ class TosClient(object):
  return TosObjectMeta(bucket, key, resp.content_length, resp.etag)

  def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
- delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
+ delimiter: Optional[str] = None,
+ continuation_token: Optional[str] = None,
+ list_background_buffer_count: int = 1) -> tosnativeclient.ListStream:
  log.debug(f'gen_list_stream tos://{bucket}/{prefix}')

  if isinstance(self._client, tosnativeclient.TosClient):
  delimiter = delimiter if delimiter is not None else ''
- return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
+ continuation_token = continuation_token if continuation_token is not None else ''
+ return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter,
+ continuation_token=continuation_token,
+ list_background_buffer_count=list_background_buffer_count)
  raise NotImplementedError()

  def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
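
The key change above is that TosClient no longer builds the underlying SDK client in __init__. A _client property creates it on first use, remembers the creating PID, and rebuilds it whenever the property is read from a different process, while an os.register_at_fork hook drops every registered instance before the parent forks. The sketch below shows the same pattern in isolation, with illustrative names and a plain object standing in for the native client; it is not the connector's actual code:

    import os
    import threading
    import weakref

    _lock = threading.Lock()
    _owners = weakref.WeakSet()   # objects whose native handle must not cross a fork

    class LazyHandleOwner:
        """Owns a handle that is rebuilt lazily, once per process."""

        def __init__(self):
            self._inner = None
            self._pid = None

        @property
        def handle(self):
            # Rebuild if the handle has never been created, or if we are now in a
            # different process than the one that created it (i.e. after a fork).
            if self._inner is None or self._pid != os.getpid():
                with _lock:
                    self._inner = object()   # stand-in for the real native client
                    self._pid = os.getpid()
                    _owners.add(self)
            return self._inner

    def _before_fork():
        # Drop handles in the parent so children never inherit live native state
        # (background threads, sockets, locks) that would be unusable after fork.
        with _lock:
            for owner in list(_owners):
                owner._inner = None

    os.register_at_fork(before=_before_fork)

This also lines up with the new pickle tests: the wrapper keeps only plain configuration (region, endpoint, credentials, part size, retry count, enable_crc), so an unpickled copy or a forked child can simply rebuild its own native handle on first use.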

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_common.py
@@ -1,6 +1,7 @@
  import logging
  from typing import Union, Iterator, Tuple, Optional

+ from tosnativeclient import TosObject
  from . import TosObjectReader
  from .tos_client import TosClient
  from .tos_object_meta import TosObjectMeta
@@ -8,19 +9,30 @@ from .tos_object_meta import TosObjectMeta
  log = logging.getLogger(__name__)


- class TosObjectIterator(object):
+ class TosObjectIterable(object):
  def __init__(self, bucket: str, prefix: str, client: TosClient):
  self._bucket = bucket
  self._prefix = prefix
+ self._list_background_buffer_count = 3
+ self._client = client
+
+ def __iter__(self) -> Iterator[TosObjectMeta]:
+ return iter(TosObjectIterator(self._bucket, self._prefix, self._list_background_buffer_count, self._client))
+
+
+ class TosObjectIterator(object):
+ def __init__(self, bucket: str, prefix: str, list_background_buffer_count: int, client: TosClient):
+ self._bucket = bucket
+ self._prefix = prefix
+ self._list_background_buffer_count = list_background_buffer_count
  self._client = client
  self._delimiter: Optional[str] = None
- self._list_stream = None
+ self._continuation_token: Optional[str] = None

+ self._list_stream = None
  self._object_metas = None
  self._index = 0
-
  self._is_truncated = True
- self._continuation_token = None

  def close(self) -> None:
  if self._list_stream is not None:
@@ -33,22 +45,28 @@ class TosObjectIterator(object):
  if self._client.use_native_client:
  if self._list_stream is None:
  self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
- delimiter=self._delimiter)
+ delimiter=self._delimiter,
+ continuation_token=self._continuation_token,
+ list_background_buffer_count=self._list_background_buffer_count)

  if self._object_metas is None or self._index >= len(self._object_metas):
- self._object_metas = []
+ self._object_metas = None
  self._index = 0
  while 1:
- objects = next(self._list_stream)
- for content in objects.contents:
- self._object_metas.append(
- TosObjectMeta(content.bucket, content.key, content.size, content.etag))
- if len(self._object_metas) > 0:
+ try:
+ objects = next(self._list_stream)
+ except:
+ self.close()
+ raise
+ self._continuation_token = self._list_stream.current_continuation_token()
+ self._object_metas = objects.contents
+ if self._object_metas is not None and len(self._object_metas) > 0:
  break

  object_meta = self._object_metas[self._index]
  self._index += 1
- return object_meta
+ # this is very critical
+ return TosObjectMeta(self._bucket, object_meta.key, object_meta.size, object_meta.etag)

  while self._object_metas is None or self._index >= len(self._object_metas):
  if not self._is_truncated:
@@ -101,4 +119,4 @@ def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iter

  def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
  bucket, prefix = parse_tos_url(prefix)
- return iter(TosObjectIterator(bucket, prefix, client))
+ return iter(TosObjectIterable(bucket, prefix, client))
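
Two things changed in the listing path above: iteration is now wrapped in a TosObjectIterable, so each __iter__ call gets a fresh iterator, and the iterator records the stream's continuation token so that a re-created list stream resumes where the previous one stopped instead of starting over. Below is a generic sketch of that continuation-token pattern with an injected page function standing in for the real list stream; all names are illustrative:

    from typing import Callable, Iterator, List, Optional, Tuple

    # One "page" call returns the keys for a page plus the continuation token for
    # the next page (None when the listing is exhausted).
    PageFn = Callable[[Optional[str]], Tuple[List[str], Optional[str]]]

    class ResumablePrefixIterator:
        def __init__(self, fetch_page: PageFn):
            self._fetch_page = fetch_page
            self._token: Optional[str] = None   # continuation token of the next page
            self._page: List[str] = []
            self._index = 0
            self._done = False

        def __iter__(self) -> Iterator[str]:
            return self

        def __next__(self) -> str:
            while self._index >= len(self._page):
                if self._done:
                    raise StopIteration
                self._page, self._token = self._fetch_page(self._token)
                self._index = 0
                self._done = self._token is None
            key = self._page[self._index]
            self._index += 1
            return key

    # Example: three pages keyed by a numeric continuation token.
    def fake_pager(token):
        pages = {None: (['a', 'b'], '1'), '1': (['c', 'd'], '2'), '2': (['e'], None)}
        return pages[token]

    print(list(ResumablePrefixIterator(fake_pager)))   # ['a', 'b', 'c', 'd', 'e']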

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_iterable_dataset.py
@@ -4,8 +4,9 @@ from typing import Iterator, Any, Optional, Callable, Union

  import torch

+ from tosnativeclient import TosObject
  from . import TosObjectReader
- from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, ReaderType
  from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
  from .tos_object_meta import TosObjectMeta

@@ -16,21 +17,20 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
  def __init__(self, region: str,
  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
  endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None,
  sharding: bool = False,
  use_native_client: bool = True,
  reader_type: Optional[ReaderType] = None,
- buffer_size: Optional[int] = None):
+ buffer_size: Optional[int] = None,
+ enable_crc: bool = True):
  self._gen_dataset = gen_dataset
  self._region = region
  self._endpoint = endpoint
- self._trans = trans
+ self._trans = transform
  self._cred = cred
  self._client_conf = client_conf
- self._log_conf = log_conf
  self._sharding = sharding
  if torch.distributed.is_initialized():
  self._rank = torch.distributed.get_rank()
@@ -40,8 +40,8 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
  self._world_size = 1
  self._reader_type = reader_type
  self._buffer_size = buffer_size
- self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
- use_native_client)
+ self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf,
+ use_native_client, enable_crc)
  log.info('TosIterableDataset init tos client succeed')

  @classmethod
@@ -49,28 +49,28 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
  transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None,
  sharding: bool = False,
  use_native_client: bool = True,
  reader_type: Optional[ReaderType] = None,
- buffer_size: Optional[int] = None):
+ buffer_size: Optional[int] = None,
+ enable_crc: bool = True):
  log.info(f'building {cls.__name__} from_urls')
- return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
- sharding, use_native_client, reader_type, buffer_size)
+ return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf,
+ sharding, use_native_client, reader_type, buffer_size, enable_crc)

  @classmethod
  def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
  transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None,
  sharding: bool = False,
  use_native_client: bool = True,
  reader_type: Optional[ReaderType] = None,
- buffer_size: Optional[int] = None):
+ buffer_size: Optional[int] = None,
+ enable_crc: bool = True):
  log.info(f'building {cls.__name__} from_prefix')
- return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
- sharding, use_native_client, reader_type, buffer_size)
+ return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf,
+ sharding, use_native_client, reader_type, buffer_size, enable_crc)

  def __iter__(self) -> Iterator[Any]:
  worker_id = 0
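
Taken together with the renamed transform keyword and the new enable_crc flag, building an iterable dataset now looks roughly like the sketch below. It is assembled from the from_prefix signature above and the tests earlier in this diff; the head_bytes transform is a hypothetical example, not part of the package:

    import os
    from tostorchconnector import TosIterableDataset
    from tostorchconnector.tos_client import CredentialProvider, ReaderType

    # Hypothetical transform: return each object's key and its first 1 KiB.
    def head_bytes(reader):
        try:
            return reader.key, reader.read(1024)
        finally:
            reader.close()

    dataset = TosIterableDataset.from_prefix(
        'tos://tos-pytorch-connector',       # bucket used by the tests above
        region=os.getenv('TOS_REGION'),
        endpoint=os.getenv('TOS_ENDPOINT'),
        cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
        transform=head_bytes,
        reader_type=ReaderType.RANGED,       # as exercised in test_from_prefix_iter
        enable_crc=True,
    )

    for key, chunk in dataset:
        print(key, len(chunk))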

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_map_dataset.py
@@ -4,8 +4,9 @@ from typing import Any, Callable, Iterator, Optional, List, Union

  import torch

+ from tosnativeclient import TosObject
  from . import TosObjectReader
- from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, ReaderType
  from .tos_common import default_trans, gen_dataset_from_prefix, \
  gen_dataset_from_urls
  from .tos_object_meta import TosObjectMeta
@@ -17,52 +18,51 @@ class TosMapDataset(torch.utils.data.Dataset):
  def __init__(self, region: str,
  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
  endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None,
  use_native_client: bool = True,
  reader_type: Optional[ReaderType] = None,
- buffer_size: Optional[int] = None):
+ buffer_size: Optional[int] = None,
+ enable_crc: bool = True):
  self._gen_dataset = gen_dataset
  self._region = region
  self._endpoint = endpoint
- self._trans = trans
+ self._trans = transform
  self._cred = cred
  self._client_conf = client_conf
- self._log_conf = log_conf
  self._dataset: Optional[List[TosObjectMeta]] = None
  self._reader_type = reader_type
  self._buffer_size = buffer_size
- self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
- use_native_client)
+ self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf,
+ use_native_client, enable_crc)
  log.info('TosMapDataset init tos client succeed')

  @classmethod
  def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None,
  use_native_client: bool = True,
  reader_type: Optional[ReaderType] = None,
- buffer_size: Optional[int] = None):
+ buffer_size: Optional[int] = None,
+ enable_crc: bool = True):
  log.info(f'building {cls.__name__} from_urls')
- return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
- use_native_client, reader_type, buffer_size)
+ return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf,
+ use_native_client, reader_type, buffer_size, enable_crc)

  @classmethod
  def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
- log_conf: Optional[TosLogConfig] = None,
  use_native_client: bool = True,
  reader_type: Optional[ReaderType] = None,
- buffer_size: Optional[int] = None):
+ buffer_size: Optional[int] = None,
+ enable_crc: bool = True):
  log.info(f'building {cls.__name__} from_prefix')
- return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
- use_native_client, reader_type, buffer_size)
+ return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf,
+ use_native_client, reader_type, buffer_size, enable_crc)

  def __getitem__(self, i: int) -> Any:
  return self._trans_tos_object(i)

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector/tos_object_reader.py
@@ -85,7 +85,7 @@ class TosObjectStream(object):
  def close(self) -> None:
  if self._sequential_object_stream and isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
  self._sequential_object_stream.close()
- if self._random_object_stream:
+ if self._random_object_stream and isinstance(self._random_object_stream, tosnativeclient.ReadStream):
  self._random_object_stream.close()


{tostorchconnector-1.0.6 → tostorchconnector-1.1.2/tostorchconnector.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tostorchconnector
- Version: 1.0.6
+ Version: 1.1.2
  Summary: TOS connector integration for PyTorch
  Author-email: xiangshijian <xiangshijian@bytedance.com>
  Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: torch>=2.0
  Requires-Dist: tos>=2.8.0
- Requires-Dist: tosnativeclient>=1.0.2
+ Requires-Dist: tosnativeclient>=1.1.1
  Dynamic: license-file

  # TOS Connector for pytorch

{tostorchconnector-1.0.6 → tostorchconnector-1.1.2}/tostorchconnector.egg-info/SOURCES.txt
@@ -3,7 +3,6 @@ README.md
  pyproject.toml
  tests/test_tos_dataset.py
  tests/test_tosclient.py
- tests/test_tosrawclient.py
  tostorchconnector/__init__.py
  tostorchconnector/tos_checkpoint.py
  tostorchconnector/tos_client.py

tostorchconnector-1.1.2/tostorchconnector.egg-info/requires.txt
@@ -0,0 +1,3 @@
+ torch>=2.0
+ tos>=2.8.0
+ tosnativeclient>=1.1.1

tostorchconnector-1.0.6/tests/test_tosrawclient.py
@@ -1,102 +0,0 @@
- import os
- import unittest
- import uuid
-
- import tos
- from tos.exceptions import TosServerError
-
-
- class TestTosRawClient(unittest.TestCase):
- def test_put_get(self):
- from tosnativeclient import TosRawClient, HeadObjectInput, TosException, DeleteObjectInput, \
- PutObjectFromBufferInput, \
- GetObjectInput, GetObjectOutput
-
- region = os.getenv('TOS_REGION')
- endpoint = os.getenv('TOS_ENDPOINT')
- ak = os.getenv('TOS_ACCESS_KEY')
- sk = os.getenv('TOS_SECRET_KEY')
- bucket = 'tos-pytorch-connector'
- key = str(uuid.uuid4())
- tos_raw_client = TosRawClient(region, endpoint, ak, sk)
-
- doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
- assert doutput.status_code == 204
-
- try:
- tos_raw_client.head_object(HeadObjectInput(bucket, key))
- assert False
- except TosException as e:
- assert e.args[0].status_code == 404
- assert len(e.args[0].request_id) > 0
-
- data = str(uuid.uuid4()).encode('utf-8')
- input = PutObjectFromBufferInput(bucket, key, data)
- poutput = tos_raw_client.put_object_from_buffer(input)
- assert poutput.status_code == 200
- assert len(poutput.etag) > 0
-
- houtput = tos_raw_client.head_object(HeadObjectInput(bucket, key))
- assert houtput.status_code == 200
- assert houtput.etag == poutput.etag
- assert houtput.content_length == len(data)
-
- goutput: GetObjectOutput = tos_raw_client.get_object(GetObjectInput(bucket, key))
- assert goutput.status_code == 200
- assert goutput.etag == poutput.etag
- rdata = goutput.read_all()
- assert rdata == data
-
- doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
- assert doutput.status_code == 204
-
- try:
- tos_raw_client.get_object(GetObjectInput(bucket, key))
- assert False
- except TosException as e:
- assert e.args[0].status_code == 404
- assert len(e.args[0].request_id) > 0
-
- def test_put_get_old(self):
- region = os.getenv('TOS_REGION')
- endpoint = os.getenv('TOS_ENDPOINT')
- ak = os.getenv('TOS_ACCESS_KEY')
- sk = os.getenv('TOS_SECRET_KEY')
- bucket = 'tos-pytorch-connector'
- key = str(uuid.uuid4())
- tos_client = tos.TosClientV2(ak, sk, endpoint=endpoint, region=region)
- doutput = tos_client.delete_object(bucket, key)
- assert doutput.status_code == 204
-
- try:
- tos_client.head_object(bucket, key)
- assert False
- except TosServerError as e:
- assert e.status_code == 404
- assert len(e.request_id) > 0
-
- data = str(uuid.uuid4()).encode('utf-8')
- poutput = tos_client.put_object(bucket, key, content=data)
- assert poutput.status_code == 200
- assert len(poutput.etag) > 0
-
- houtput = tos_client.head_object(bucket, key)
- assert houtput.status_code == 200
- assert houtput.etag == poutput.etag
- assert houtput.content_length == len(data)
-
- goutput = tos_client.get_object(bucket, key)
- assert goutput.status_code == 200
- assert goutput.etag == poutput.etag
- rdata = goutput.read()
- assert rdata == data
-
- doutput = tos_client.delete_object(bucket, key)
- assert doutput.status_code == 204
-
- try:
- tos_client.get_object(bucket, key)
- assert False
- except TosServerError as e:
- assert e.status_code == 404
- assert len(e.request_id) > 0

tostorchconnector-1.0.6/tostorchconnector.egg-info/requires.txt
@@ -1,3 +0,0 @@
- torch>=2.0
- tos>=2.8.0
- tosnativeclient>=1.0.2