tostorchconnector 1.0.0__tar.gz → 1.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tostorchconnector might be problematic. Click here for more details.

Files changed (24)
  1. {tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info → tostorchconnector-1.0.1}/PKG-INFO +1 -1
  2. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/pyproject.toml +2 -2
  3. tostorchconnector-1.0.1/tostorchconnector/__init__.py +15 -0
  4. tostorchconnector-1.0.1/tostorchconnector/tos_checkpoint.py +39 -0
  5. tostorchconnector-1.0.1/tostorchconnector/tos_client.py +169 -0
  6. tostorchconnector-1.0.1/tostorchconnector/tos_common.py +99 -0
  7. tostorchconnector-1.0.1/tostorchconnector/tos_error.py +11 -0
  8. tostorchconnector-1.0.1/tostorchconnector/tos_iterable_dataset.py +99 -0
  9. tostorchconnector-1.0.1/tostorchconnector/tos_map_dataset.py +78 -0
  10. tostorchconnector-1.0.1/tostorchconnector/tos_object_meta.py +25 -0
  11. tostorchconnector-1.0.1/tostorchconnector/tos_object_reader.py +170 -0
  12. tostorchconnector-1.0.1/tostorchconnector/tos_object_writer.py +83 -0
  13. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1/tostorchconnector.egg-info}/PKG-INFO +1 -1
  14. tostorchconnector-1.0.1/tostorchconnector.egg-info/SOURCES.txt +20 -0
  15. tostorchconnector-1.0.1/tostorchconnector.egg-info/top_level.txt +1 -0
  16. tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info/SOURCES.txt +0 -10
  17. tostorchconnector-1.0.0/tostorchconnector/tostorchconnector.egg-info/top_level.txt +0 -1
  18. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/LICENSE +0 -0
  19. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/README.md +0 -0
  20. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/setup.cfg +0 -0
  21. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/tests/test_tos_dataset.py +0 -0
  22. {tostorchconnector-1.0.0 → tostorchconnector-1.0.1}/tests/test_tosnativeclient.py +0 -0
  23. {tostorchconnector-1.0.0/tostorchconnector → tostorchconnector-1.0.1}/tostorchconnector.egg-info/dependency_links.txt +0 -0
  24. {tostorchconnector-1.0.0/tostorchconnector → tostorchconnector-1.0.1}/tostorchconnector.egg-info/requires.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tostorchconnector
3
- Version: 1.0.0
3
+ Version: 1.0.1
4
4
  Summary: TOS connector integration for PyTorch
5
5
  Author-email: xiangshijian <xiangshijian@bytedance.com>
6
6
  Classifier: Development Status :: 4 - Beta
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "tostorchconnector"
8
- version = "1.0.0"
8
+ version = "1.0.1"
9
9
  description = "TOS connector integration for PyTorch"
10
10
  authors = [{ name = "xiangshijian", email = "xiangshijian@bytedance.com" }]
11
11
  requires-python = ">=3.8,<3.13"
@@ -29,4 +29,4 @@ dependencies = [
29
29
  ]
30
30
 
31
31
  [tool.setuptools.packages]
32
- find = { where = ["tostorchconnector"] }
32
+ find = { where = ["."], include = ["tostorchconnector"] }
@@ -0,0 +1,15 @@
1
+ from .tos_object_reader import TosObjectReader
2
+ from .tos_object_writer import TosObjectWriter
3
+ from .tos_iterable_dataset import TosIterableDataset
4
+ from .tos_map_dataset import TosMapDataset
5
+ from .tos_checkpoint import TosCheckpoint
6
+ from .tos_error import TosError
7
+
8
+ __all__ = [
9
+ 'TosObjectReader',
10
+ 'TosObjectWriter',
11
+ 'TosIterableDataset',
12
+ 'TosMapDataset',
13
+ 'TosCheckpoint',
14
+ 'TosError',
15
+ ]
@@ -0,0 +1,39 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ from . import TosObjectReader, TosObjectWriter
5
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
6
+ from .tos_common import parse_tos_url
7
+
8
+ log = logging.getLogger(__name__)
9
+
10
+
11
class TosCheckpoint(object):
    """Factory of reader/writer streams for loading and saving checkpoints in TOS.

    The underlying TosClient is built lazily on first use, so constructing a
    TosCheckpoint is cheap and performs no network I/O.
    """

    def __init__(self, region: str,
                 endpoint: Optional[str] = None,
                 cred: Optional[CredentialProvider] = None,
                 client_conf: Optional[TosClientConfig] = None,
                 log_conf: Optional[TosLogConfig] = None, use_native_client=True):
        # Client is created lazily in _get_tos_client().
        self._client = None
        self._native_client = None
        self._region = region
        self._endpoint = endpoint
        self._cred = cred
        self._client_conf = client_conf
        self._log_conf = log_conf
        self._use_native_client = use_native_client

    def reader(self, url: str) -> TosObjectReader:
        """Return a readable stream for the object at `url` (tos://bucket/key)."""
        bucket, key = parse_tos_url(url)
        return self._get_tos_client().get_object(bucket, key)

    def writer(self, url: str) -> TosObjectWriter:
        """Return a writable stream for the object at `url` (tos://bucket/key)."""
        bucket, key = parse_tos_url(url)
        return self._get_tos_client().put_object(bucket, key)

    def _get_tos_client(self):
        # Build and cache the TOS client on first use.
        if self._client is None:
            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
                                     self._use_native_client)
            # Bug fix: the message previously said 'TosIterableDataset'
            # (copy-paste from the dataset class), which misattributes the log.
            log.info('TosCheckpoint init tos client succeed')
        return self._client
@@ -0,0 +1,169 @@
1
+ import io
2
+ import logging
3
+ import os
4
+ from functools import partial
5
+ from typing import Optional, List, Tuple
6
+
7
+ import tos
8
+ from tos.models2 import GetObjectOutput, PutObjectOutput
9
+
10
+ import tosnativeclient
11
+
12
+ from . import TosObjectReader, TosObjectWriter
13
+ from .tos_object_meta import TosObjectMeta
14
+ from .tos_object_writer import PutObjectStream
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
class CredentialProvider(object):
    """Immutable holder of a static access-key / secret-key credential pair."""

    def __init__(self, ak: str, sk: str):
        self._ak, self._sk = ak, sk

    @property
    def ak(self) -> str:
        """Access key id."""
        return self._ak

    @property
    def sk(self) -> str:
        """Secret access key."""
        return self._sk
+ return self._sk
31
+
32
+
33
class TosClientConfig(object):
    """Client tunables: multipart part size, retry count, prefetch parallelism."""

    def __init__(self, part_size: int = 8 * 1024 * 1024,
                 max_retry_count: int = 3, shared_prefetch_tasks: int = 20):
        self._part_size = part_size
        self._max_retry_count = max_retry_count
        self._shared_prefetch_tasks = shared_prefetch_tasks

    @property
    def part_size(self) -> int:
        """Size in bytes of each upload/download part (default 8 MiB)."""
        return self._part_size

    @property
    def max_retry_count(self) -> int:
        """Maximum number of retries per request."""
        return self._max_retry_count

    @property
    def shared_prefetch_tasks(self) -> int:
        """Number of prefetch tasks shared across readers."""
        return self._shared_prefetch_tasks
51
+
52
+
53
class TosLogConfig(object):
    """File-logging settings forwarded to the underlying TOS client."""

    def __init__(self, log_dir: str = '',
                 log_file_name: str = '', log_level: Optional[int] = logging.INFO):
        self._log_dir = log_dir
        self._log_file_name = log_file_name
        self._log_level = log_level

    @property
    def log_level(self) -> Optional[int]:
        """Python logging level (e.g. logging.INFO); may be None."""
        return self._log_level

    @property
    def log_dir(self) -> str:
        """Directory for log files; empty string disables file logging."""
        return self._log_dir

    @property
    def log_file_name(self) -> str:
        """Log file name (prefix); empty string disables file logging."""
        return self._log_file_name
71
+
72
+
73
class TosClient(object):
    """Facade hiding which backend is in use: the native `tosnativeclient.TosClient`
    (default) or the pure-python `tos.TosClientV2` SDK.

    The two backends expose different streaming APIs, so each method branches on
    `isinstance(self._client, tosnativeclient.TosClient)` and adapts both shapes
    to what TosObjectReader / TosObjectWriter expect.
    """

    def __init__(self, region: str, endpoint: Optional[str] = None, cred: Optional[CredentialProvider] = None,
                 client_conf: Optional[TosClientConfig] = None, log_conf: Optional[TosLogConfig] = None,
                 use_native_client: bool = True):
        # Fall back to anonymous credentials / default tunables when not supplied.
        cred = CredentialProvider('', '') if cred is None else cred
        client_conf = TosClientConfig() if client_conf is None else client_conf
        log_conf = TosLogConfig() if log_conf is None else log_conf
        self._part_size = client_conf.part_size
        self._use_native_client = use_native_client
        if use_native_client:
            # Map the python logging level onto the native client's textual
            # directives; file logging activates only when both a directory
            # and a file name are configured.
            directives = ''
            directory = ''
            file_name_prefix = ''
            if log_conf.log_dir and log_conf.log_file_name:
                if log_conf.log_level:
                    if log_conf.log_level == logging.DEBUG:
                        directives = 'debug'
                    elif log_conf.log_level == logging.INFO:
                        directives = 'info'
                    elif log_conf.log_level == logging.WARN:
                        directives = 'warn'
                    elif log_conf.log_level == logging.ERROR:
                        directives = 'error'
                    # NOTE(review): a non-standard level leaves `directives`
                    # empty — confirm the native client treats '' as disabled.
                else:
                    directives = 'info'
                directory = log_conf.log_dir
                file_name_prefix = log_conf.log_file_name
            self._client = tosnativeclient.TosClient(region, endpoint, cred.ak, cred.sk, client_conf.part_size,
                                                     client_conf.max_retry_count, directives=directives,
                                                     directory=directory,
                                                     file_name_prefix=file_name_prefix)
        else:
            self._client = tos.TosClientV2(cred.ak, cred.sk, endpoint=endpoint, region=region,
                                           max_retry_count=client_conf.max_retry_count)
            if log_conf.log_dir and log_conf.log_file_name:
                file_path = os.path.join(log_conf.log_dir, log_conf.log_file_name)
                log_level = log_conf.log_level if log_conf.log_level else logging.INFO
                tos.set_logger(file_path=file_path, level=log_level)

    @property
    def use_native_client(self):
        # True when the native tosnativeclient backend is in use.
        return self._use_native_client

    def get_object(self, bucket: str, key: str, etag: Optional[str] = None,
                   size: Optional[int] = None) -> TosObjectReader:
        """Return a lazy reader for tos://bucket/key.

        When etag and size are already known (e.g. from a listing) the HEAD
        request is skipped; otherwise metadata is fetched lazily via head_object.
        """
        log.debug(f'get_object tos://{bucket}/{key}')
        if size is None or etag is None:
            get_object_meta = partial(self.head_object, bucket, key)
        else:
            get_object_meta = lambda: TosObjectMeta(bucket, key, size, etag)
        if isinstance(self._client, tosnativeclient.TosClient):
            get_object_stream = partial(self._client.get_object, bucket, key)
        else:
            # Adapt TosClientV2.get_object to the (etag, size) callable shape.
            # NOTE(review): the '' and `et` positional args are presumably
            # version_id and if_match — verify against the tos SDK signature.
            def get_object_stream(et: str, _: int) -> GetObjectOutput:
                return self._client.get_object(bucket, key, '', et)

        return TosObjectReader(bucket, key, get_object_meta, get_object_stream)

    def put_object(self, bucket: str, key: str, storage_class: Optional[str] = None) -> TosObjectWriter:
        """Return a writable stream that uploads to tos://bucket/key on close()."""
        log.debug(f'put_object tos://{bucket}/{key}')

        if isinstance(self._client, tosnativeclient.TosClient):
            storage_class = storage_class if storage_class is not None else ''
            put_object_stream = self._client.put_object(bucket, key, storage_class)
        else:
            # SDK path: buffer everything in memory, upload once on close().
            def put_object(content: io.BytesIO) -> PutObjectOutput:
                return self._client.put_object(bucket, key, storage_class=storage_class, content=content)

            put_object_stream = PutObjectStream(put_object)

        return TosObjectWriter(bucket, key, put_object_stream)

    def head_object(self, bucket: str, key: str) -> TosObjectMeta:
        """Fetch object metadata (size, etag) via a HEAD request."""
        log.debug(f'head_object tos://{bucket}/{key}')
        if isinstance(self._client, tosnativeclient.TosClient):
            resp = self._client.head_object(bucket, key)
            return TosObjectMeta(resp.bucket, resp.key, resp.size, resp.etag)
        resp = self._client.head_object(bucket, key)
        return TosObjectMeta(bucket, key, resp.content_length, resp.etag)

    def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
                        delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
        """Open a native listing stream; supported only by the native backend."""
        if isinstance(self._client, tosnativeclient.TosClient):
            delimiter = delimiter if delimiter is not None else ''
            return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
        raise NotImplementedError()

    def list_objects(self, bucket: str, prefix: str, max_keys: int = 1000,
                     continuation_token: Optional[str] = None, delimiter: Optional[str] = None) -> Tuple[
        List[TosObjectMeta], bool, Optional[str]]:
        """Return one page of listing results: (metas, is_truncated, next_token).

        NOTE(review): `list_objects_type2` is a TosClientV2 API; callers on the
        native-client path are expected to use gen_list_stream instead — confirm.
        """
        log.debug(f'list_objects tos://{bucket}/{prefix}')
        resp = self._client.list_objects_type2(bucket, prefix, max_keys=max_keys, continuation_token=continuation_token,
                                               delimiter=delimiter)
        object_metas = []
        for obj in resp.contents:
            object_metas.append(TosObjectMeta(bucket, obj.key, obj.size, obj.etag))
        return object_metas, resp.is_truncated, resp.next_continuation_token
@@ -0,0 +1,99 @@
1
+ import logging
2
+ from typing import Union, Iterator, Tuple, Optional
3
+
4
+ from . import TosObjectReader
5
+ from .tos_client import TosClient
6
+ from .tos_object_meta import TosObjectMeta
7
+
8
+ log = logging.getLogger(__name__)
9
+
10
+
11
class TosObjectIterator(object):
    """Iterates TosObjectMeta entries under bucket/prefix, paging lazily.

    Uses the native client's list stream when available; otherwise falls back
    to the SDK's continuation-token pagination.
    """

    def __init__(self, bucket: str, prefix: str, client: TosClient):
        self._bucket = bucket
        self._prefix = prefix
        self._client = client
        self._delimiter: Optional[str] = None
        self._list_stream = None

        # Current page of results and the cursor into it.
        self._object_metas = None
        self._index = 0

        # Pagination state for the non-native (SDK) path.
        self._is_truncated = True
        self._continuation_token = None

    def close(self):
        """Release the native list stream, if one was opened."""
        if self._list_stream is not None:
            self._list_stream.close()

    def __iter__(self) -> Iterator[TosObjectMeta]:
        return self

    def __next__(self):
        if self._client.use_native_client:
            if self._list_stream is None:
                self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
                                                                 delimiter=self._delimiter)

            # Refill until an unread entry exists. A loop (not a single fetch)
            # is needed because a page may be empty; next() on the stream
            # raises StopIteration when the listing is exhausted.
            while self._object_metas is None or self._index >= len(self._object_metas):
                objects = next(self._list_stream)
                self._object_metas = [
                    TosObjectMeta(content.bucket, content.key, content.size, content.etag)
                    for content in objects.contents
                ]
                # Bug fix: the cursor must be rewound for the new page.
                # Previously `_index` kept the previous page's offset, so the
                # second and later pages were read from the wrong position
                # (skipping entries or raising IndexError).
                self._index = 0

            object_meta = self._object_metas[self._index]
            self._index += 1
            return object_meta

        while self._object_metas is None or self._index >= len(self._object_metas):
            if not self._is_truncated:
                raise StopIteration
            self._object_metas, self._is_truncated, self._continuation_token = self._client.list_objects(
                self._bucket,
                self._prefix,
                max_keys=1000,
                continuation_token=self._continuation_token,
                delimiter=self._delimiter)
            self._index = 0

        object_meta = self._object_metas[self._index]
        self._index += 1
        return object_meta
62
+
63
+
64
def parse_tos_url(url: str) -> Tuple[str, str]:
    """Split a 'tos://bucket/key' (or bare 'bucket/key') URL into (bucket, key).

    The key part may be empty when the URL names only a bucket.

    Raises:
        ValueError: if the URL or the bucket component is empty.
    """
    if not url:
        raise ValueError('url is empty')

    scheme = 'tos://'
    if url.startswith(scheme):
        url = url[len(scheme):]

    if not url:
        raise ValueError('bucket is empty')

    # Everything before the first '/' is the bucket; the rest is the key.
    bucket, _, prefix = url.partition('/')

    if not bucket:
        raise ValueError('bucket is empty')
    return bucket, prefix
85
+
86
+
87
def default_trans(obj: TosObjectReader) -> TosObjectReader:
    """Identity transform: hand the raw TosObjectReader through unchanged."""
    return obj
89
+
90
+
91
def gen_dataset_from_urls(urls: Union[str, Iterator[str]], _: TosClient) -> Iterator[TosObjectMeta]:
    """Yield one TosObjectMeta per URL; a single string counts as one URL.

    URL parsing happens eagerly so malformed URLs fail at call time, while
    the metas themselves are produced lazily.
    """
    if isinstance(urls, str):
        urls = [urls]
    pairs = [parse_tos_url(url) for url in urls]
    return (TosObjectMeta(bucket, key) for bucket, key in pairs)
95
+
96
+
97
def gen_dataset_from_prefix(prefix: str, client: TosClient) -> Iterator[TosObjectMeta]:
    """List every object under a tos:// prefix as an iterator of TosObjectMeta."""
    bucket, key_prefix = parse_tos_url(prefix)
    return iter(TosObjectIterator(bucket, key_prefix, client))
@@ -0,0 +1,11 @@
1
+ from typing import Optional
2
+
3
+
4
class TosError(Exception):
    """Base error for the TOS torch connector.

    Besides the human-readable message, it carries the optional service-side
    diagnostics: HTTP status code, error code (`ec`) and request id.
    """

    def __init__(self, message: str, status_code: Optional[int] = None, ec: Optional[str] = None,
                 request_id: Optional[str] = None):
        super().__init__(message)
        # Keep the diagnostics as plain attributes so callers can inspect them.
        self.message = message
        self.status_code = status_code
        self.ec = ec
        self.request_id = request_id
@@ -0,0 +1,99 @@
1
+ import logging
2
+ from functools import partial
3
+ from typing import Iterator, Any, Optional, Callable, Union
4
+
5
+ import torch
6
+
7
+ from . import TosObjectReader
8
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
9
+ from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
10
+ from .tos_object_meta import TosObjectMeta
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
class TosIterableDataset(torch.utils.data.IterableDataset):
    """Iterable-style PyTorch dataset streaming objects from TOS.

    Optionally shards the object stream across distributed ranks and
    dataloader workers; each yielded item is the user transform applied to a
    TosObjectReader.
    """

    def __init__(self, region: str,
                 gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
                 endpoint: Optional[str] = None,
                 trans: Callable[[TosObjectReader], Any] = default_trans,
                 cred: Optional[CredentialProvider] = None,
                 client_conf: Optional[TosClientConfig] = None,
                 log_conf: Optional[TosLogConfig] = None,
                 sharding: bool = False, use_native_client=True):
        # TosClient is created lazily (per worker process) in _get_tos_client().
        self._client: Optional[TosClient] = None
        self._gen_dataset = gen_dataset
        self._region = region
        self._endpoint = endpoint
        self._trans = trans
        self._cred = cred
        self._client_conf = client_conf
        self._log_conf = log_conf
        self._sharding = sharding
        self._use_native_client = use_native_client
        # Capture distributed rank/world size once, at construction time.
        if torch.distributed.is_initialized():
            self._rank = torch.distributed.get_rank()
            self._world_size = torch.distributed.get_world_size()
        else:
            self._rank = 0
            self._world_size = 1

    @classmethod
    def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
                  transform: Callable[[TosObjectReader], Any] = default_trans,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
                  log_conf: Optional[TosLogConfig] = None,
                  sharding: bool = False, use_native_client=True):
        """Build a dataset from one URL or an iterable of tos:// URLs."""
        log.info(f'building {cls.__name__} from_urls')
        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
                   sharding, use_native_client)

    @classmethod
    def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
                    transform: Callable[[TosObjectReader], Any] = default_trans,
                    cred: Optional[CredentialProvider] = None,
                    client_conf: Optional[TosClientConfig] = None,
                    log_conf: Optional[TosLogConfig] = None,
                    sharding: bool = False, use_native_client=True):
        """Build a dataset from every object under a tos:// prefix."""
        log.info(f'building {cls.__name__} from_prefix')
        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
                   sharding, use_native_client)

    def __iter__(self) -> Iterator[Any]:
        worker_id = 0
        num_workers = 1
        if self._sharding:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                worker_id = worker_info.id
                num_workers = worker_info.num_workers

        if not self._sharding or (self._world_size == 1 and num_workers == 1):
            return map(
                self._trans_tos_object,
                self._gen_dataset(self._get_tos_client()),
            )

        # Shard first by distributed rank, then by dataloader worker.
        part_dataset = (
            obj
            for idx, obj in enumerate(self._gen_dataset(self._get_tos_client()))
            if idx % self._world_size == self._rank
        )

        part_dataset = (
            obj
            for idx, obj in enumerate(part_dataset)
            if idx % num_workers == worker_id
        )
        return map(self._trans_tos_object, part_dataset)

    def _get_tos_client(self):
        # Build and cache the TOS client on first use.
        if self._client is None:
            # Bug fix: `use_native_client` was previously not forwarded here,
            # so the dataset always used the default (native) backend and
            # silently ignored the constructor argument.
            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
                                     self._use_native_client)
            log.info('TosIterableDataset init tos client succeed')
        return self._client

    def _trans_tos_object(self, object_meta: TosObjectMeta) -> Any:
        # Open the object (reusing etag/size from the listing) and transform it.
        obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
        return self._trans(obj)
@@ -0,0 +1,78 @@
1
+ import logging
2
+ from functools import partial
3
+ from typing import Any, Callable, Iterator, Optional, List, Union
4
+
5
+ import torch
6
+
7
+ from . import TosObjectReader
8
+ from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
9
+ from .tos_common import default_trans, gen_dataset_from_prefix, \
10
+ gen_dataset_from_urls
11
+ from .tos_object_meta import TosObjectMeta
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
class TosMapDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over TOS objects.

    The object listing is materialized lazily on first index access; each item
    is the user transform applied to a TosObjectReader.
    """

    def __init__(self, region: str,
                 gen_dataset: Callable[[TosClient], Iterator[TosObjectMeta]],
                 endpoint: Optional[str] = None,
                 trans: Callable[[TosObjectReader], Any] = default_trans,
                 cred: Optional[CredentialProvider] = None,
                 client_conf: Optional[TosClientConfig] = None,
                 log_conf: Optional[TosLogConfig] = None, use_native_client=True):
        # Both the client and the listing are built lazily.
        self._client: Optional[TosClient] = None
        self._dataset: Optional[List[TosObjectMeta]] = None
        self._gen_dataset = gen_dataset
        self._region = region
        self._endpoint = endpoint
        self._trans = trans
        self._cred = cred
        self._client_conf = client_conf
        self._log_conf = log_conf
        self._use_native_client = use_native_client

    @classmethod
    def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
                  trans: Callable[[TosObjectReader], Any] = default_trans,
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
                  log_conf: Optional[TosLogConfig] = None, use_native_client=True):
        """Build a dataset from one URL or an iterable of tos:// URLs."""
        log.info(f'building {cls.__name__} from_urls')
        return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
                   use_native_client)

    @classmethod
    def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
                    trans: Callable[[TosObjectReader], Any] = default_trans,
                    cred: Optional[CredentialProvider] = None,
                    client_conf: Optional[TosClientConfig] = None,
                    log_conf: Optional[TosLogConfig] = None, use_native_client=True):
        """Build a dataset from every object under a tos:// prefix."""
        log.info(f'building {cls.__name__} from_prefix')
        return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
                   use_native_client)

    def __getitem__(self, i: int) -> Any:
        return self._trans_tos_object(i)

    def __len__(self) -> int:
        return len(self._data_set)

    def _get_tos_client(self) -> TosClient:
        # Build and cache the TOS client on first use.
        if self._client is not None:
            return self._client
        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
                                 self._use_native_client)
        log.info('TosMapDataset init tos client succeed')
        return self._client

    @property
    def _data_set(self):
        # Materialize the full listing exactly once.
        if self._dataset is None:
            self._dataset = list(self._gen_dataset(self._get_tos_client()))
        return self._dataset

    def _trans_tos_object(self, i: int) -> Any:
        # Open the i-th object (reusing etag/size from the listing) and transform it.
        meta = self._data_set[i]
        reader = self._get_tos_client().get_object(meta.bucket, meta.key, meta.etag, meta.size)
        return self._trans(reader)
@@ -0,0 +1,25 @@
1
+ from typing import Optional
2
+
3
+
4
class TosObjectMeta(object):
    """Lightweight value object describing one TOS object.

    `size` and `etag` are optional: they are known after a listing or HEAD
    request, but may be absent when only the URL is known.
    """

    def __init__(self, bucket: str, key: str, size: Optional[int] = None, etag: Optional[str] = None):
        self._bucket = bucket
        self._key = key
        self._size = size
        self._etag = etag

    @property
    def bucket(self) -> str:
        """Bucket name."""
        return self._bucket

    @property
    def key(self) -> str:
        """Object key within the bucket."""
        return self._key

    @property
    def size(self) -> Optional[int]:
        """Object size in bytes, when known."""
        return self._size

    @property
    def etag(self) -> Optional[str]:
        """Object ETag, when known."""
        return self._etag
@@ -0,0 +1,170 @@
1
+ import io
2
+ import logging
3
+ from functools import cached_property
4
+ from os import SEEK_SET, SEEK_CUR, SEEK_END
5
+ from typing import Optional, Callable, Iterator
6
+
7
+ from tos.models2 import GetObjectOutput
8
+ from tosnativeclient.tosnativeclient import ReadStream
9
+
10
+ from .tos_object_meta import TosObjectMeta
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
class TosObjectReader(io.BufferedIOBase):
    """Read-only, seekable file-like view over a single TOS object.

    Every byte downloaded so far is accumulated (from offset 0) in an
    in-memory BytesIO buffer; seeking or reading past the buffered region
    downloads more data.  Works with both backends: the native client's
    ReadStream (offset-based reads) and the SDK's GetObjectOutput (an
    iterable of chunks).

    NOTE(review): the `ReadStream | GetObjectOutput` annotations are
    evaluated at import time and require Python >= 3.10, while the package
    metadata declares requires-python >= 3.8 — confirm the intended minimum
    Python version.
    """

    def __init__(self, bucket: str, key: str,
                 get_object_meta: Optional[Callable[[], TosObjectMeta]],
                 get_object_stream: Callable[[str, int], ReadStream | GetObjectOutput]):
        # get_object_meta: lazy provider of (size, etag); invoked at most once
        # through the cached `_object_meta` property.
        # get_object_stream: called with (etag, size) to open the data stream.
        if not bucket:
            raise ValueError('bucket is empty')
        self._bucket = bucket
        self._key = key
        self._get_object_meta = get_object_meta
        self._get_object_stream = get_object_stream
        self._object_stream: Optional[ReadStream | GetObjectOutput] = None
        # Download position within the remote stream (native path reads by offset).
        self._object_stream_offset = 0
        # Total object size; learned from metadata or when the stream is drained.
        self._total_size: Optional[int] = None
        # Logical read position exposed through tell()/seek().
        self._read_offset = 0
        # Accumulates all downloaded bytes, starting at object offset 0.
        self._buffer = io.BytesIO()

    @property
    def bucket(self) -> str:
        return self._bucket

    @property
    def key(self) -> str:
        return self._key

    def read(self, size: Optional[int] = None) -> Optional[bytes]:
        """Read up to `size` bytes; None or a negative size reads to EOF."""
        if self._is_read_to_end():
            return b''

        self._trigger_prefetch()
        current_read_offset = self._read_offset
        if size is None or size < 0:
            # means read all
            self._buffer.seek(0, SEEK_END)
            if isinstance(self._object_stream, ReadStream):
                try:
                    chunk_size = 1 * 1024 * 1024
                    while 1:
                        chunk = self._object_stream.read(self._object_stream_offset, chunk_size)
                        if not chunk:
                            break
                        self._object_stream_offset += len(chunk)
                        self._buffer.write(chunk)
                finally:
                    # The stream is fully drained (or failed) — release it either way.
                    self._object_stream.close()
            else:
                # SDK path: GetObjectOutput iterates the body in chunks.
                for chunk in self._object_stream:
                    self._buffer.write(chunk)

            # The buffer now holds the whole object, so its length is the size.
            self._total_size = self._buffer.tell()
        else:
            # seek() downloads enough data to cover the requested range.
            self.seek(size, SEEK_CUR)

        self._buffer.seek(current_read_offset)
        data = self._buffer.read(size)
        self._read_offset = self._buffer.tell()
        return data

    def readinto(self, buf) -> Optional[int]:
        """Fill `buf` from the current position; return the number of bytes copied."""
        size = len(buf)
        if self._is_read_to_end() or size == 0:
            return 0

        self._trigger_prefetch()
        current_read_offset = self._read_offset
        # Ensure the buffer covers the requested range, then copy out of it.
        self.seek(size, SEEK_CUR)
        self._buffer.seek(current_read_offset)
        readed = self._buffer.readinto(buf)
        self._read_offset = self._buffer.tell()
        return readed

    def seek(self, offset: int, whence: int = SEEK_SET) -> int:
        """Move the read position; downloads data as needed for forward seeks."""
        if whence == SEEK_END:
            if offset >= 0:
                # At or beyond EOF: clamp to the total size.
                self._read_offset = self._get_total_size()
                return self._read_offset
            # offset is negative
            offset += self._get_total_size()
        elif whence == SEEK_CUR:
            if self._is_read_to_end() and offset >= 0:
                return self._read_offset
            offset += self._read_offset
        elif whence == SEEK_SET:
            pass
        else:
            raise ValueError('invalid whence')

        if offset < 0:
            raise ValueError(f'invalid seek offset {offset}')

        # Download up to the target offset if it lies past the buffered data.
        if offset > self._buffer_size():
            self._prefetch_to_offset(offset)

        self._read_offset = self._buffer.seek(offset)
        return self._read_offset

    def tell(self) -> int:
        return self._read_offset

    def readable(self):
        return True

    def writable(self):
        return False

    def seekable(self):
        return True

    @cached_property
    def _object_meta(self) -> TosObjectMeta:
        # Fetched at most once; may issue a HEAD request.
        return self._get_object_meta()

    def _trigger_prefetch(self) -> None:
        # Open the remote data stream on first use.
        if self._object_stream is None:
            object_meta = self._object_meta
            self._object_stream = self._get_object_stream(object_meta.etag, object_meta.size)
            self._object_stream_offset = 0

    def _is_read_to_end(self) -> bool:
        # Only definitive once the total size is known.
        if self._total_size is None:
            return False
        return self._read_offset == self._total_size

    def _get_total_size(self) -> int:
        if self._total_size is None:
            self._total_size = self._object_meta.size
        return self._total_size

    def _prefetch_to_offset(self, offset: int) -> None:
        # Download until the buffer covers `offset` or the stream ends.
        self._trigger_prefetch()
        size = self._buffer.seek(0, SEEK_END)
        if isinstance(self._object_stream, ReadStream):
            try:
                chunk_size = 1 * 1024 * 1024
                while offset > size:
                    chunk = self._object_stream.read(self._object_stream_offset, chunk_size)
                    if not chunk:
                        break
                    size += self._buffer.write(chunk)
                    self._object_stream_offset += len(chunk)
                # NOTE(review): this marks _total_size even when the loop stopped
                # because `offset` was reached (not EOF) — confirm ReadStream
                # semantics make this correct for mid-object seeks.
                self._total_size = self._buffer.tell()
            finally:
                self._object_stream.close()
        else:
            try:
                while offset > size:
                    size += self._buffer.write(next(self._object_stream))
            except StopIteration:
                # Stream exhausted: the buffer now holds the whole object.
                self._total_size = self._buffer.tell()

    def _buffer_size(self) -> int:
        # Length of the buffered data, preserving the current buffer position.
        cur_pos = self._buffer.tell()
        self._buffer.seek(0, SEEK_END)
        buffer_size = self._buffer.tell()
        self._buffer.seek(cur_pos)
        return buffer_size
@@ -0,0 +1,83 @@
1
+ import io
2
+ import logging
3
+ import threading
4
+ from typing import Callable, Optional, Any
5
+
6
+ from tos.models2 import PutObjectOutput
7
+ from tosnativeclient.tosnativeclient import WriteStream
8
+
9
+ log = logging.getLogger(__name__)
10
+
11
+
12
class PutObjectStream(object):
    """Accumulates writes in memory and uploads them as one object on close().

    Adapter giving the SDK's single-shot put_object the same write/close
    interface as the native client's WriteStream.
    """

    def __init__(self, put_object: Callable[[io.BytesIO], PutObjectOutput]):
        self._put_object = put_object
        self._buffer = io.BytesIO()

    def write(self, data):
        """Append `data` to the in-memory buffer."""
        self._buffer.write(data)

    def close(self):
        """Rewind the buffer and upload its full contents in a single request."""
        self._buffer.seek(0)
        self._put_object(self._buffer)
23
+
24
+
25
class TosObjectWriter(io.BufferedIOBase):
    """Write-only file-like stream that uploads to tos://bucket/key on close().

    NOTE(review): the `WriteStream | PutObjectStream` annotation is evaluated
    at import time and requires Python >= 3.10, but the package declares
    requires-python >= 3.8 — confirm the intended minimum Python version.
    """

    def __init__(self, bucket: str, key: str, put_object_stream: WriteStream | PutObjectStream):
        if not bucket:
            raise ValueError('bucket is empty')
        self._bucket = bucket
        self._key = key
        self._put_object_stream = put_object_stream
        # Number of bytes written so far; reported by tell().
        self._write_offset = 0
        self._closed = False
        # close() can be reached from both user code and __exit__; make it idempotent.
        self._lock = threading.Lock()

    @property
    def bucket(self) -> str:
        return self._bucket

    @property
    def key(self) -> str:
        return self._key

    def __enter__(self):
        self._write_offset = 0
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            # Deliberately do NOT close the stream when the with-block failed:
            # closing would finalize the upload and commit a partial object.
            try:
                log.info(f'Exception occurred before closing stream: {exc_type.__name__}: {exc_val}')
            except Exception:
                # Bug fix: this was a bare `except:` (which also swallows
                # KeyboardInterrupt/SystemExit); logging must never mask the
                # original error, but only ordinary exceptions are suppressed.
                pass
        else:
            self.close()

    def write(self, data) -> int:
        """Forward `data` to the underlying stream; return the byte count accepted."""
        self._put_object_stream.write(data)
        self._write_offset += len(data)
        return len(data)

    def close(self) -> None:
        """Finalize the upload. Idempotent and safe to call from multiple threads."""
        with self._lock:
            if self._closed:
                return
            self._closed = True
            self._put_object_stream.close()

    def tell(self) -> int:
        return self._write_offset

    def flush(self) -> None:
        # Nothing to flush: data is either buffered until close() (SDK path)
        # or already handed to the native stream.
        pass

    def readable(self):
        return False

    def writable(self):
        return True

    def seekable(self):
        return False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tostorchconnector
3
- Version: 1.0.0
3
+ Version: 1.0.1
4
4
  Summary: TOS connector integration for PyTorch
5
5
  Author-email: xiangshijian <xiangshijian@bytedance.com>
6
6
  Classifier: Development Status :: 4 - Beta
@@ -0,0 +1,20 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ tests/test_tos_dataset.py
5
+ tests/test_tosnativeclient.py
6
+ tostorchconnector/__init__.py
7
+ tostorchconnector/tos_checkpoint.py
8
+ tostorchconnector/tos_client.py
9
+ tostorchconnector/tos_common.py
10
+ tostorchconnector/tos_error.py
11
+ tostorchconnector/tos_iterable_dataset.py
12
+ tostorchconnector/tos_map_dataset.py
13
+ tostorchconnector/tos_object_meta.py
14
+ tostorchconnector/tos_object_reader.py
15
+ tostorchconnector/tos_object_writer.py
16
+ tostorchconnector.egg-info/PKG-INFO
17
+ tostorchconnector.egg-info/SOURCES.txt
18
+ tostorchconnector.egg-info/dependency_links.txt
19
+ tostorchconnector.egg-info/requires.txt
20
+ tostorchconnector.egg-info/top_level.txt
@@ -0,0 +1 @@
1
+ tostorchconnector
@@ -1,10 +0,0 @@
1
- LICENSE
2
- README.md
3
- pyproject.toml
4
- tests/test_tos_dataset.py
5
- tests/test_tosnativeclient.py
6
- tostorchconnector/tostorchconnector.egg-info/PKG-INFO
7
- tostorchconnector/tostorchconnector.egg-info/SOURCES.txt
8
- tostorchconnector/tostorchconnector.egg-info/dependency_links.txt
9
- tostorchconnector/tostorchconnector.egg-info/requires.txt
10
- tostorchconnector/tostorchconnector.egg-info/top_level.txt