tostorchconnector 1.0.6__tar.gz → 1.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of tostorchconnector has been flagged as potentially problematic.

Files changed (23)
  1. {tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.0.8}/PKG-INFO +2 -2
  2. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/pyproject.toml +2 -2
  3. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tos_dataset.py +54 -14
  4. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tosclient.py +35 -12
  5. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tosrawclient.py +15 -0
  6. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_client.py +98 -36
  7. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_common.py +16 -5
  8. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_iterable_dataset.py +2 -2
  9. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_map_dataset.py +6 -6
  10. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_reader.py +1 -1
  11. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8/tostorchconnector.egg-info}/PKG-INFO +2 -2
  12. tostorchconnector-1.0.8/tostorchconnector.egg-info/requires.txt +3 -0
  13. tostorchconnector-1.0.6/tostorchconnector.egg-info/requires.txt +0 -3
  14. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/LICENSE +0 -0
  15. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/README.md +0 -0
  16. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/setup.cfg +0 -0
  17. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/__init__.py +0 -0
  18. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_checkpoint.py +0 -0
  19. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_meta.py +0 -0
  20. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_writer.py +0 -0
  21. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/SOURCES.txt +0 -0
  22. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/dependency_links.txt +0 -0
  23. {tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector.egg-info/top_level.txt +0 -0

{tostorchconnector-1.0.6/tostorchconnector.egg-info → tostorchconnector-1.0.8}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tostorchconnector
- Version: 1.0.6
+ Version: 1.0.8
  Summary: TOS connector integration for PyTorch
  Author-email: xiangshijian <xiangshijian@bytedance.com>
  Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: torch>=2.0
  Requires-Dist: tos>=2.8.0
- Requires-Dist: tosnativeclient>=1.0.2
+ Requires-Dist: tosnativeclient>=1.0.6
  Dynamic: license-file

  # TOS Connector for pytorch

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/pyproject.toml

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "tostorchconnector"
- version = "1.0.6"
+ version = "1.0.8"
  description = "TOS connector integration for PyTorch"
  authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
  requires-python = ">=3.8,<3.14"
@@ -26,7 +26,7 @@ classifiers = [
  dependencies = [
  "torch >= 2.0",
  "tos>=2.8.0",
- "tosnativeclient >= 1.0.2"
+ "tosnativeclient >= 1.0.6"
  ]

  [tool.setuptools.packages]

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tos_dataset.py

@@ -1,10 +1,14 @@
  import os
+ import pickle
  import unittest

+ from torch.utils.data import DataLoader
+
  from tostorchconnector import TosMapDataset, TosIterableDataset, TosCheckpoint
  from tostorchconnector.tos_client import CredentialProvider, ReaderType

  USE_NATIVE_CLIENT = True
+ READER_TYPE = ReaderType.SEQUENTIAL


  class TestTosDataSet(unittest.TestCase):
@@ -22,23 +26,50 @@ class TestTosDataSet(unittest.TestCase):
  for i in range(len(datasets)):
  print(datasets[i].bucket, datasets[i].key)

+ def test_pickle(self):
+ region = os.getenv('TOS_REGION')
+ endpoint = os.getenv('TOS_ENDPOINT')
+ ak = os.getenv('TOS_ACCESS_KEY')
+ sk = os.getenv('TOS_SECRET_KEY')
+ bucket = 'tos-pytorch-connector'
+ datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
+ endpoint=endpoint, cred=CredentialProvider(ak, sk),
+ use_native_client=USE_NATIVE_CLIENT)
+ pickled_datasets = pickle.dumps(datasets)
+ assert isinstance(pickled_datasets, bytes)
+ unpickled_datasets = pickle.loads(pickled_datasets)
+ i = 0
+ for dataset in unpickled_datasets:
+ print(dataset.bucket, dataset.key)
+ i += 1
+ print(i)
+
+ pickled = pickle.dumps(unpickled_datasets._data_set)
+ assert isinstance(pickled, bytes)
+
  def test_from_prefix(self):
  region = os.getenv('TOS_REGION')
  endpoint = os.getenv('TOS_ENDPOINT')
  ak = os.getenv('TOS_ACCESS_KEY')
  sk = os.getenv('TOS_SECRET_KEY')
  bucket = 'tos-pytorch-connector'
- datasets = TosMapDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
+ datasets = TosMapDataset.from_prefix(f'tos://{bucket}', region=region,
  endpoint=endpoint, cred=CredentialProvider(ak, sk),
  use_native_client=USE_NATIVE_CLIENT)
+
+ count = 0
  for i in range(len(datasets)):
- item = datasets[i]
- print(item.bucket, item.key)
- if i == 1:
- item = datasets[i]
- data = item.read(100)
- print(data)
- print(len(data))
+ dataset = datasets[i]
+ print(dataset.bucket, dataset.key)
+ count += 1
+ dataset.close()
+ # if i == 1:
+ # item = datasets[i]
+ # data = item.read(100)
+ # print(data)
+ # print(len(data))
+
+ print(count)

  def test_from_prefix_iter(self):
  region = os.getenv('TOS_REGION')
@@ -46,17 +77,23 @@ class TestTosDataSet(unittest.TestCase):
  ak = os.getenv('TOS_ACCESS_KEY')
  sk = os.getenv('TOS_SECRET_KEY')
  bucket = 'tos-pytorch-connector'
- datasets = TosIterableDataset.from_prefix(f'tos://{bucket}/prefix', region=region,
+ datasets = TosIterableDataset.from_prefix(f'tos://{bucket}', region=region,
  endpoint=endpoint, cred=CredentialProvider(ak, sk),
- use_native_client=USE_NATIVE_CLIENT)
+ use_native_client=USE_NATIVE_CLIENT, reader_type=ReaderType.RANGED)
  i = 0
  for dataset in datasets:
  print(dataset.bucket, dataset.key)
- if i == 1:
- data = dataset.read(100)
- print(data)
- print(len(data))
+ # if dataset.key == 'tosutil':
+ # with open('logs/tosutil', 'wb') as f:
+ # while 1:
+ # chunk = dataset.read(8192)
+ # print(len(chunk))
+ # if not chunk:
+ # break
+ # f.write(chunk)
+ dataset.close()
  i += 1
+ print(i)

  def test_checkpoint(self):
  region = os.getenv('TOS_REGION')
@@ -72,6 +109,9 @@ class TestTosDataSet(unittest.TestCase):
  writer.write(b'hello world')
  writer.write(b'hi world')

+ with checkpoint.reader(url) as reader:
+ print(reader.read())
+
  with checkpoint.reader(url) as reader:
  data = reader.read(5)
  print(data)
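
The new test_pickle test and the torch.utils.data.DataLoader import point at the main theme of this release: datasets, and the TosClient they hold, are now picklable, which is what PyTorch needs in order to ship a dataset to DataLoader worker processes. The following usage sketch is not taken from the package; the bucket, region and credential environment variables are placeholders, and it simply combines the from_prefix keywords exercised in the tests above with the transform keyword introduced in this version.

import os

from torch.utils.data import DataLoader

from tostorchconnector import TosMapDataset
from tostorchconnector.tos_client import CredentialProvider, ReaderType


def object_key(reader):
    # reader is a TosObjectReader; returning only the key keeps the example cheap
    return reader.key


dataset = TosMapDataset.from_prefix(
    'tos://tos-pytorch-connector',  # placeholder bucket
    region=os.getenv('TOS_REGION'),
    endpoint=os.getenv('TOS_ENDPOINT'),
    cred=CredentialProvider(os.getenv('TOS_ACCESS_KEY'), os.getenv('TOS_SECRET_KEY')),
    transform=object_key,
    reader_type=ReaderType.SEQUENTIAL,
)

# Each worker unpickles the dataset and lazily re-creates its own TOS client.
loader = DataLoader(dataset, num_workers=2, batch_size=None)
for key in loader:
    print(key)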

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tosclient.py

@@ -1,4 +1,5 @@
  import os
+ import pickle
  import unittest
  import uuid

@@ -6,6 +7,21 @@ from tosnativeclient import TosClient, TosException


  class TestTosClient(unittest.TestCase):
+ def test_pickle(self):
+ region = os.getenv('TOS_REGION')
+ endpoint = os.getenv('TOS_ENDPOINT')
+ ak = os.getenv('TOS_ACCESS_KEY')
+ sk = os.getenv('TOS_SECRET_KEY')
+ bucket = 'tos-pytorch-connector'
+ tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
+ file_name_prefix='app.log')
+ pickled_tos_client = pickle.dumps(tos_client)
+ assert isinstance(pickled_tos_client, bytes)
+ unpickled_tos_client = pickle.loads(pickled_tos_client)
+
+ self._test_list_objects(bucket, unpickled_tos_client)
+ self._test_write_read_object(bucket, unpickled_tos_client)
+
  def test_list_objects(self):
  region = os.getenv('TOS_REGION')
  endpoint = os.getenv('TOS_ENDPOINT')
@@ -14,7 +30,20 @@ class TestTosClient(unittest.TestCase):
  bucket = 'tos-pytorch-connector'
  tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
  file_name_prefix='app.log')
+ self._test_list_objects(bucket, tos_client)

+ def test_write_read_object(self):
+ region = os.getenv('TOS_REGION')
+ endpoint = os.getenv('TOS_ENDPOINT')
+ ak = os.getenv('TOS_ACCESS_KEY')
+ sk = os.getenv('TOS_SECRET_KEY')
+ bucket = 'tos-pytorch-connector'
+ tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
+ file_name_prefix='app.log')
+
+ self._test_write_read_object(bucket, tos_client)
+
+ def _test_list_objects(self, bucket, tos_client):
  list_stream = tos_client.list_objects(bucket, '', max_keys=1000)
  count = 0
  try:
@@ -22,23 +51,15 @@ class TestTosClient(unittest.TestCase):
  for content in objects.contents:
  count += 1
  print(content.key, content.size)
- # output = tos_client.head_object(bucket, content.key)
- # assert output.etag == content.etag
- # assert output.size == content.size
+ output = tos_client.head_object(bucket, content.key)
+ assert output.etag == content.etag
+ assert output.size == content.size

  print(count)
  except TosException as e:
  print(e.args[0].message)

- def test_write_read_object(self):
- region = os.getenv('TOS_REGION')
- endpoint = os.getenv('TOS_ENDPOINT')
- ak = os.getenv('TOS_ACCESS_KEY')
- sk = os.getenv('TOS_SECRET_KEY')
- bucket = 'tos-pytorch-connector'
- tos_client = TosClient(region, endpoint, ak, sk, directives='info', directory='logs',
- file_name_prefix='app.log')
-
+ def _test_write_read_object(self, bucket, tos_client):
  key = str(uuid.uuid4())
  read_stream = tos_client.get_object(bucket, key, '', 1)

@@ -59,6 +80,8 @@ class TestTosClient(unittest.TestCase):
  write_stream.close()

  output = tos_client.head_object(bucket, key)
+ print(output.etag, output.size)
+
  read_stream = tos_client.get_object(bucket, key, output.etag, output.size)
  try:
  offset = 0

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tests/test_tosrawclient.py

@@ -1,3 +1,4 @@
+ import io
  import os
  import unittest
  import uuid
@@ -47,6 +48,20 @@ class TestTosRawClient(unittest.TestCase):
  rdata = goutput.read_all()
  assert rdata == data

+ goutput: GetObjectOutput = tos_raw_client.get_object(GetObjectInput(bucket, key))
+ assert goutput.status_code == 200
+ assert goutput.etag == poutput.etag
+
+ rdata = io.BytesIO()
+ while 1:
+ chunk = goutput.read()
+ if not chunk:
+ break
+ rdata.write(chunk)
+
+ rdata.seek(0)
+ assert rdata.read() == data
+
  doutput = tos_raw_client.delete_object(DeleteObjectInput(bucket, key))
  assert doutput.status_code == 204


{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_client.py

@@ -1,8 +1,10 @@
  import enum
  import logging
  import os
+ import gc
+
  from functools import partial
- from typing import Optional, List, Tuple
+ from typing import Optional, List, Tuple, Any

  import tos

@@ -15,6 +17,44 @@ from .tos_object_writer import PutObjectStream

  log = logging.getLogger(__name__)

+ import threading
+ import weakref
+ import traceback
+
+ _client_lock = threading.Lock()
+ _client_map = weakref.WeakSet()
+
+
+ def _before_fork():
+ with _client_lock:
+ clients = list(_client_map)
+
+ if not clients or len(clients) == 0:
+ return
+
+ try:
+ for client in clients:
+ client._inner_client = None
+
+ _reset_client_map()
+ gc.collect()
+ except Exception as e:
+ log.warning(f'failed to clean up native clients, {str(e)}')
+ traceback.print_exc()
+
+
+ def _after_fork_in_child():
+ _reset_client_map()
+
+
+ def _reset_client_map():
+ global _client_map
+ with _client_lock:
+ _client_map = weakref.WeakSet()
+
+
+ os.register_at_fork(before=_before_fork, after_in_child=_after_fork_in_child)
+

  class ReaderType(enum.Enum):
  SEQUENTIAL = 'Sequential'
@@ -79,40 +119,59 @@ class TosClient(object):
  def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None, log_conf: Optional[TosLogConfig] = None,
  use_native_client: bool = True):
- cred = CredentialProvider('', '') if cred is None else cred
- client_conf = TosClientConfig() if client_conf is None else client_conf
- log_conf = TosLogConfig() if log_conf is None else log_conf
- self._part_size = client_conf.part_size
+ self._region = region
+ self._endpoint = endpoint
+ self._cred = CredentialProvider('', '') if cred is None else cred
+ self._client_conf = TosClientConfig() if client_conf is None else client_conf
+ self._log_conf = TosLogConfig() if log_conf is None else log_conf
+ self._part_size = self._client_conf.part_size
  self._use_native_client = use_native_client
- if use_native_client:
- directives = ''
- directory = ''
- file_name_prefix = ''
- if log_conf.log_dir and log_conf.log_file_name:
- if log_conf.log_level:
- if log_conf.log_level == logging.DEBUG:
- directives = 'debug'
- elif log_conf.log_level == logging.INFO:
- directives = 'info'
- elif log_conf.log_level == logging.WARN:
- directives = 'warn'
- elif log_conf.log_level == logging.ERROR:
- directives = 'error'
- else:
- directives = 'info'
- directory = log_conf.log_dir
- file_name_prefix = log_conf.log_file_name
- self._client = tosnativeclient.TosClient(region, endpoint, cred.ak, cred.sk, client_conf.part_size,
- client_conf.max_retry_count, directives=directives,
- directory=directory,
- file_name_prefix=file_name_prefix)
- else:
- self._client = tos.TosClientV2(cred.ak, cred.sk, endpoint=endpoint, region=region,
- max_retry_count=client_conf.max_retry_count)
- if log_conf.log_dir and log_conf.log_file_name:
- file_path = os.path.join(log_conf.log_dir, log_conf.log_file_name)
- log_level = log_conf.log_level if log_conf.log_level else logging.INFO
- tos.set_logger(file_path=file_path, level=log_level)
+ self._inner_client = None
+ self._client_pid = None
+
+ @property
+ def _client(self) -> Any:
+ if self._client_pid is None or self._client_pid != os.getpid() or self._inner_client is None:
+ with _client_lock:
+ if self._use_native_client:
+ directives = ''
+ directory = ''
+ file_name_prefix = ''
+ if self._log_conf.log_dir and self._log_conf.log_file_name:
+ if self._log_conf.log_level:
+ if self._log_conf.log_level == logging.DEBUG:
+ directives = 'debug'
+ elif self._log_conf.log_level == logging.INFO:
+ directives = 'info'
+ elif self._log_conf.log_level == logging.WARN:
+ directives = 'warn'
+ elif self._log_conf.log_level == logging.ERROR:
+ directives = 'error'
+ else:
+ directives = 'info'
+ directory = self._log_conf.log_dir
+ file_name_prefix = self._log_conf.log_file_name
+ self._inner_client = tosnativeclient.TosClient(self._region, self._endpoint, self._cred.ak,
+ self._cred.sk,
+ self._client_conf.part_size,
+ self._client_conf.max_retry_count,
+ directives=directives,
+ directory=directory,
+ file_name_prefix=file_name_prefix)
+ else:
+ self._inner_client = tos.TosClientV2(self._cred.ak, self._cred.sk, endpoint=self._endpoint,
+ region=self._region,
+ max_retry_count=self._client_conf.max_retry_count)
+ if self._log_conf.log_dir and self._log_conf.log_file_name:
+ file_path = os.path.join(self._log_conf.log_dir, self._log_conf.log_file_name)
+ log_level = self._log_conf.log_level if self._log_conf.log_level else logging.INFO
+ tos.set_logger(file_path=file_path, level=log_level)
+
+ self._client_pid = os.getpid()
+ _client_map.add(self)
+
+ assert self._inner_client is not None
+ return self._inner_client

  @property
  def use_native_client(self) -> bool:
@@ -155,12 +214,15 @@ class TosClient(object):
  return TosObjectMeta(bucket, key, resp.content_length, resp.etag)

  def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
- delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
+ delimiter: Optional[str] = None,
+ continuation_token: Optional[str] = None) -> tosnativeclient.ListStream:
  log.debug(f'gen_list_stream tos://{bucket}/{prefix}')

  if isinstance(self._client, tosnativeclient.TosClient):
  delimiter = delimiter if delimiter is not None else ''
- return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
+ continuation_token = continuation_token if continuation_token is not None else ''
+ return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter,
+ continuation_token=continuation_token)
  raise NotImplementedError()

  def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
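
The biggest behavioural change is in TosClient itself: instead of building the native (tosnativeclient) or pure-Python (tos.TosClientV2) client eagerly in __init__, the instance now only stores its configuration, and the _client property builds the inner client lazily, keyed to the current process id; a module-level os.register_at_fork hook drops every live inner client before a fork so that forked workers rebuild their own. The snippet below is a stripped-down, generic sketch of that pattern, not the package's code; make_handle stands in for any fork-unsafe native constructor.

import os
import threading
import weakref

_lock = threading.Lock()
_instances = weakref.WeakSet()


def make_handle():
    # Placeholder for a fork-unsafe native resource, e.g. a native client constructor.
    return object()


class LazyClient:
    def __init__(self):
        self._inner = None
        self._pid = None

    @property
    def client(self):
        # Rebuild the handle if it has never been built or if we are in a new process.
        if self._inner is None or self._pid != os.getpid():
            with _lock:
                self._inner = make_handle()
                self._pid = os.getpid()
                _instances.add(self)
        return self._inner


def _before_fork():
    # Drop native handles in the parent so the child never inherits them.
    with _lock:
        for inst in list(_instances):
            inst._inner = None


def _after_fork_in_child():
    # Start the child with an empty registry; handles are rebuilt on first use.
    with _lock:
        _instances.clear()


os.register_at_fork(before=_before_fork, after_in_child=_after_fork_in_child)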

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_common.py

@@ -8,19 +8,28 @@ from .tos_object_meta import TosObjectMeta
  log = logging.getLogger(__name__)


+ class TosObjectIterable(object):
+ def __init__(self, bucket: str, prefix: str, client: TosClient):
+ self._bucket = bucket
+ self._prefix = prefix
+ self._client = client
+
+ def __iter__(self) -> Iterator[TosObjectMeta]:
+ return iter(TosObjectIterator(self._bucket, self._prefix, self._client))
+
+
  class TosObjectIterator(object):
  def __init__(self, bucket: str, prefix: str, client: TosClient):
  self._bucket = bucket
  self._prefix = prefix
  self._client = client
  self._delimiter: Optional[str] = None
- self._list_stream = None
+ self._continuation_token: Optional[str] = None

+ self._list_stream = None
  self._object_metas = None
  self._index = 0
-
  self._is_truncated = True
- self._continuation_token = None

  def close(self) -> None:
  if self._list_stream is not None:
@@ -33,13 +42,15 @@ class TosObjectIterator(object):
  if self._client.use_native_client:
  if self._list_stream is None:
  self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
- delimiter=self._delimiter)
+ delimiter=self._delimiter,
+ continuation_token=self._continuation_token)

  if self._object_metas is None or self._index >= len(self._object_metas):
  self._object_metas = []
  self._index = 0
  while 1:
  objects = next(self._list_stream)
+ self._continuation_token = self._list_stream.current_continuation_token()
  for content in objects.contents:
  self._object_metas.append(
  TosObjectMeta(content.bucket, content.key, content.size, content.etag))
@@ -101,4 +112,4 @@ def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iter

  def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
  bucket, prefix = parse_tos_url(prefix)
- return iter(TosObjectIterator(bucket, prefix, client))
+ return iter(TosObjectIterable(bucket, prefix, client))
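
Two things change in tos_common.py: the iterator now remembers the listing's continuation token, so that if the list stream has to be re-created, listing resumes from where it stopped rather than from the beginning, and gen_dataset_from_prefix now goes through a small TosObjectIterable wrapper whose __iter__ builds a new TosObjectIterator on every pass. The generic snippet below (not package code) illustrates the iterable-versus-iterator distinction that wrapper is built on.

class CountingIterable:
    """Re-iterable: hands out a brand-new iterator per pass, like TosObjectIterable."""

    def __init__(self, n):
        self._n = n

    def __iter__(self):
        return iter(range(self._n))


items = CountingIterable(3)
assert list(items) == [0, 1, 2]
assert list(items) == [0, 1, 2]   # a second pass still works

it = iter(range(3))               # a bare iterator, by contrast, is single-use
assert list(it) == [0, 1, 2]
assert list(it) == []             # already exhausted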

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_iterable_dataset.py

@@ -16,7 +16,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
  def __init__(self, region: str,
  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
  endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
  log_conf: Optional[TosLogConfig] = None,
@@ -27,7 +27,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
  self._gen_dataset = gen_dataset
  self._region = region
  self._endpoint = endpoint
- self._trans = trans
+ self._trans = transform
  self._cred = cred
  self._client_conf = client_conf
  self._log_conf = log_conf

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_map_dataset.py

@@ -17,7 +17,7 @@ class TosMapDataset(torch.utils.data.Dataset):
  def __init__(self, region: str,
  gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
  endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
  log_conf: Optional[TosLogConfig] = None,
@@ -27,7 +27,7 @@ class TosMapDataset(torch.utils.data.Dataset):
  self._gen_dataset = gen_dataset
  self._region = region
  self._endpoint = endpoint
- self._trans = trans
+ self._trans = transform
  self._cred = cred
  self._client_conf = client_conf
  self._log_conf = log_conf
@@ -40,7 +40,7 @@ class TosMapDataset(torch.utils.data.Dataset):

  @classmethod
  def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
  log_conf: Optional[TosLogConfig] = None,
@@ -48,12 +48,12 @@ class TosMapDataset(torch.utils.data.Dataset):
  reader_type: Optional[ReaderType] = None,
  buffer_size: Optional[int] = None):
  log.info(f'building {cls.__name__} from_urls')
- return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
+ return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
  use_native_client, reader_type, buffer_size)

  @classmethod
  def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
- trans: Callable[[TosObjectReader], Any] = default_trans,
+ transform: Callable[[TosObjectReader], Any] = default_trans,
  cred: Optional[CredentialProvider] = None,
  client_conf: Optional[TosClientConfig] = None,
  log_conf: Optional[TosLogConfig] = None,
@@ -61,7 +61,7 @@ class TosMapDataset(torch.utils.data.Dataset):
  reader_type: Optional[ReaderType] = None,
  buffer_size: Optional[int] = None):
  log.info(f'building {cls.__name__} from_prefix')
- return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
+ return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
  use_native_client, reader_type, buffer_size)

  def __getitem__(self, i: int) -> Any:
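
Both dataset classes rename the trans keyword to transform (the stored attribute stays _trans), so existing call sites that pass trans=... need updating. A hedged migration sketch follows, with placeholder URL, region and credentials; per the type hints above, the callable still receives a TosObjectReader.

from tostorchconnector import TosMapDataset
from tostorchconnector.tos_client import CredentialProvider


def first_kilobyte(reader):
    # reader is a TosObjectReader; read(n) is exercised in the tests above
    return reader.key, reader.read(1024)


# 1.0.6 and earlier: TosMapDataset.from_urls(url, region=..., trans=first_kilobyte)
# 1.0.8 and later:
dataset = TosMapDataset.from_urls(
    'tos://tos-pytorch-connector/some-object',  # placeholder object URL
    region='my-region',                         # placeholder region
    cred=CredentialProvider('ak', 'sk'),        # placeholder credentials
    transform=first_kilobyte,
)
key, head = dataset[0]  # indexing performs a real request, so valid credentials are needed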

{tostorchconnector-1.0.6 → tostorchconnector-1.0.8}/tostorchconnector/tos_object_reader.py

@@ -85,7 +85,7 @@ class TosObjectStream(object):
  def close(self) -> None:
  if self._sequential_object_stream and isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
  self._sequential_object_stream.close()
- if self._random_object_stream:
+ if self._random_object_stream and isinstance(self._random_object_stream, tosnativeclient.ReadStream):
  self._random_object_stream.close()


{tostorchconnector-1.0.6 → tostorchconnector-1.0.8/tostorchconnector.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tostorchconnector
- Version: 1.0.6
+ Version: 1.0.8
  Summary: TOS connector integration for PyTorch
  Author-email: xiangshijian <xiangshijian@bytedance.com>
  Classifier: Development Status :: 4 - Beta
@@ -19,7 +19,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: torch>=2.0
  Requires-Dist: tos>=2.8.0
- Requires-Dist: tosnativeclient>=1.0.2
+ Requires-Dist: tosnativeclient>=1.0.6
  Dynamic: license-file

  # TOS Connector for pytorch

tostorchconnector-1.0.8/tostorchconnector.egg-info/requires.txt (new)

@@ -0,0 +1,3 @@
+ torch>=2.0
+ tos>=2.8.0
+ tosnativeclient>=1.0.6

tostorchconnector-1.0.6/tostorchconnector.egg-info/requires.txt (removed)

@@ -1,3 +0,0 @@
- torch>=2.0
- tos>=2.8.0
- tosnativeclient>=1.0.2