tostorchconnector 1.0.3 → 1.0.5 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry has flagged this version of tostorchconnector.

--- a/tostorchconnector/__init__.py
+++ b/tostorchconnector/__init__.py
@@ -1,4 +1,4 @@
-from .tos_object_reader import TosObjectReader
+from .tos_object_reader import TosObjectReader, SequentialTosObjectReader
 from .tos_object_writer import TosObjectWriter
 from .tos_iterable_dataset import TosIterableDataset
 from .tos_map_dataset import TosMapDataset
@@ -7,6 +7,7 @@ from tosnativeclient import TosException, TosError
 
 __all__ = [
     'TosObjectReader',
+    'SequentialTosObjectReader',
     'TosObjectWriter',
     'TosIterableDataset',
     'TosMapDataset',
--- a/tostorchconnector/tos_checkpoint.py
+++ b/tostorchconnector/tos_checkpoint.py
@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 from . import TosObjectReader, TosObjectWriter
-from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
 from .tos_common import parse_tos_url
 
 log = logging.getLogger(__name__)
@@ -14,26 +14,20 @@ class TosCheckpoint(object):
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
                  log_conf: Optional[TosLogConfig] = None, use_native_client=True):
-        self._client = None
-        self._native_client = None
         self._region = region
         self._endpoint = endpoint
         self._cred = cred
         self._client_conf = client_conf
         self._log_conf = log_conf
-        self._use_native_client = use_native_client
+        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
+                                 use_native_client)
+        log.info('TosCheckpoint init tos client succeed')
 
-    def reader(self, url: str) -> TosObjectReader:
+    def reader(self, url: str, reader_type: Optional[ReaderType] = None,
+               buffer_size: Optional[int] = None) -> TosObjectReader:
         bucket, key = parse_tos_url(url)
-        return self._get_tos_client().get_object(bucket, key)
+        return self._client.get_object(bucket, key, reader_type=reader_type, buffer_size=buffer_size)
 
     def writer(self, url: str) -> TosObjectWriter:
         bucket, key = parse_tos_url(url)
-        return self._get_tos_client().put_object(bucket, key)
-
-    def _get_tos_client(self):
-        if self._client is None:
-            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
-                                     self._use_native_client)
-            log.info('TosIterableDataset init tos client succeed')
-        return self._client
+        return self._client.put_object(bucket, key)
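The rewritten TosCheckpoint builds its TosClient eagerly in __init__ (dropping the lazy _get_tos_client helper, whose log line had been copy-pasted from TosIterableDataset) and threads the new reader_type and buffer_size options through reader(). A minimal usage sketch under assumed values; the region, endpoint, and tos:// URL below are placeholders, not part of the package:

import torch
import torch.nn as nn

from tostorchconnector.tos_checkpoint import TosCheckpoint
from tostorchconnector.tos_client import ReaderType

model = nn.Linear(4, 2)
ckpt = TosCheckpoint(region='cn-beijing', endpoint='https://tos-cn-beijing.volces.com')

# The writer buffers locally and uploads once the stream is closed.
with ckpt.writer('tos://my-bucket/ckpt/model.pt') as writer:
    torch.save(model.state_dict(), writer)

# New in 1.0.5: choose the ranged reader and size its internal buffer.
with ckpt.reader('tos://my-bucket/ckpt/model.pt',
                 reader_type=ReaderType.RANGED,
                 buffer_size=4 * 1024 * 1024) as reader:
    model.load_state_dict(torch.load(reader))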
--- a/tostorchconnector/tos_client.py
+++ b/tostorchconnector/tos_client.py
@@ -1,21 +1,26 @@
-import io
+import enum
 import logging
 import os
 from functools import partial
 from typing import Optional, List, Tuple
 
 import tos
-from tos.models2 import GetObjectOutput, PutObjectOutput
 
 import tosnativeclient
 
-from . import TosObjectReader, TosObjectWriter
+from . import SequentialTosObjectReader, TosObjectWriter, TosObjectReader
 from .tos_object_meta import TosObjectMeta
+from .tos_object_reader import TosObjectStream, RangedTosObjectReader
 from .tos_object_writer import PutObjectStream
 
 log = logging.getLogger(__name__)
 
 
+class ReaderType(enum.Enum):
+    SEQUENTIAL = 'Sequential'
+    RANGED = 'Ranged'
+
+
 class CredentialProvider(object):
     def __init__(self, ak: str, sk: str):
         self._ak = ak
@@ -110,48 +115,49 @@ class TosClient(object):
         tos.set_logger(file_path=file_path, level=log_level)
 
     @property
-    def use_native_client(self):
+    def use_native_client(self) -> bool:
         return self._use_native_client
 
     def get_object(self, bucket: str, key: str, etag: Optional[str] = None,
-                   size: Optional[int] = None) -> TosObjectReader:
+                   size: Optional[int] = None, reader_type: Optional[ReaderType] = None,
+                   buffer_size: Optional[int] = None) -> TosObjectReader:
         log.debug(f'get_object tos://{bucket}/{key}')
+
         if size is None or etag is None:
             get_object_meta = partial(self.head_object, bucket, key)
         else:
             get_object_meta = lambda: TosObjectMeta(bucket, key, size, etag)
-        if isinstance(self._client, tosnativeclient.TosClient):
-            get_object_stream = partial(self._client.get_object, bucket, key)
-        else:
-            def get_object_stream(et: str, _: int) -> GetObjectOutput:
-                return self._client.get_object(bucket, key, '', et)
 
-        return TosObjectReader(bucket, key, get_object_meta, get_object_stream)
+        object_stream = TosObjectStream(bucket, key, get_object_meta, self._client)
+        if reader_type is not None and reader_type == ReaderType.RANGED:
+            return RangedTosObjectReader(bucket, key, object_stream, buffer_size)
+        return SequentialTosObjectReader(bucket, key, object_stream)
 
     def put_object(self, bucket: str, key: str, storage_class: Optional[str] = None) -> TosObjectWriter:
         log.debug(f'put_object tos://{bucket}/{key}')
 
         if isinstance(self._client, tosnativeclient.TosClient):
-            storage_class = storage_class if storage_class is not None else ''
-            put_object_stream = self._client.put_object(bucket, key, storage_class)
+            put_object_stream = self._client.put_object(bucket, key, storage_class=storage_class)
         else:
-            def put_object(content: io.BytesIO) -> PutObjectOutput:
-                return self._client.put_object(bucket, key, storage_class=storage_class, content=content)
-
-            put_object_stream = PutObjectStream(put_object)
+            put_object_stream = PutObjectStream(
+                lambda content: self._client.put_object(bucket, key, storage_class=storage_class, content=content))
 
         return TosObjectWriter(bucket, key, put_object_stream)
 
     def head_object(self, bucket: str, key: str) -> TosObjectMeta:
         log.debug(f'head_object tos://{bucket}/{key}')
+
         if isinstance(self._client, tosnativeclient.TosClient):
            resp = self._client.head_object(bucket, key)
            return TosObjectMeta(resp.bucket, resp.key, resp.size, resp.etag)
+
         resp = self._client.head_object(bucket, key)
         return TosObjectMeta(bucket, key, resp.content_length, resp.etag)
 
     def gen_list_stream(self, bucket: str, prefix: str, max_keys: int = 1000,
                         delimiter: Optional[str] = None) -> tosnativeclient.ListStream:
+        log.debug(f'gen_list_stream tos://{bucket}/{prefix}')
+
         if isinstance(self._client, tosnativeclient.TosClient):
             delimiter = delimiter if delimiter is not None else ''
             return self._client.list_objects(bucket, prefix, max_keys=max_keys, delimiter=delimiter)
@@ -161,6 +167,10 @@ class TosClient(object):
                      continuation_token: Optional[str] = None, delimiter: Optional[str] = None) -> Tuple[
         List[TosObjectMeta], bool, Optional[str]]:
         log.debug(f'list_objects tos://{bucket}/{prefix}')
+
+        if isinstance(self._client, tosnativeclient.TosClient):
+            raise NotImplementedError()
+
         resp = self._client.list_objects_type2(bucket, prefix, max_keys=max_keys, continuation_token=continuation_token,
                                                delimiter=delimiter)
         object_metas = []
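get_object is now a small factory: it wraps whichever underlying client is active in a TosObjectStream and selects the reader class from reader_type, defaulting to the sequential reader. A sketch of the dispatch from the caller's side; the positional constructor arguments mirror the TosClient(...) call sites elsewhere in this diff, and the region/bucket values are placeholders:

from tostorchconnector.tos_client import TosClient, ReaderType

client = TosClient('cn-beijing', 'https://tos-cn-beijing.volces.com', None, None, None, True)

seq = client.get_object('my-bucket', 'data/sample.bin')  # -> SequentialTosObjectReader
rng = client.get_object('my-bucket', 'data/sample.bin',
                        reader_type=ReaderType.RANGED)   # -> RangedTosObjectReader
print(type(seq).__name__, type(rng).__name__)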
--- a/tostorchconnector/tos_common.py
+++ b/tostorchconnector/tos_common.py
@@ -22,14 +22,14 @@ class TosObjectIterator(object):
         self._is_truncated = True
         self._continuation_token = None
 
-    def close(self):
+    def close(self) -> None:
         if self._list_stream is not None:
             self._list_stream.close()
 
     def __iter__(self) -> Iterator[TosObjectMeta]:
         return self
 
-    def __next__(self):
+    def __next__(self) -> TosObjectMeta:
         if self._client.use_native_client:
             if self._list_stream is None:
                 self._list_stream = self._client.gen_list_stream(self._bucket, self._prefix, max_keys=1000,
@@ -37,9 +37,14 @@
 
             if self._object_metas is None or self._index >= len(self._object_metas):
                 self._object_metas = []
-                objects = next(self._list_stream)
-                for content in objects.contents:
-                    self._object_metas.append(TosObjectMeta(content.bucket, content.key, content.size, content.etag))
+                self._index = 0
+                while 1:
+                    objects = next(self._list_stream)
+                    for content in objects.contents:
+                        self._object_metas.append(
+                            TosObjectMeta(content.bucket, content.key, content.size, content.etag))
+                    if len(self._object_metas) > 0:
+                        break
 
             object_meta = self._object_metas[self._index]
             self._index += 1
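The new loop in __next__ fixes listings that return an empty page (for example a page holding only common prefixes): instead of indexing into an empty _object_metas, the iterator resets _index and keeps pulling pages until one yields objects, letting the stream's own StopIteration end the iteration. A standalone sketch of the same pattern, with illustrative names that are not from the package:

def next_nonempty_batch(pages):
    # Drain pages until one contributes items; StopIteration from the
    # underlying iterator propagates and terminates the outer loop.
    batch = []
    while not batch:
        batch.extend(next(pages))
    return batch

pages = iter([[], [], ['a', 'b'], []])
print(next_nonempty_batch(pages))  # ['a', 'b']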
--- a/tostorchconnector/tos_iterable_dataset.py
+++ b/tostorchconnector/tos_iterable_dataset.py
@@ -5,7 +5,7 @@ from typing import Iterator, Any, Optional, Callable, Union
 import torch
 
 from . import TosObjectReader
-from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
 from .tos_common import default_trans, gen_dataset_from_urls, gen_dataset_from_prefix
 from .tos_object_meta import TosObjectMeta
 
@@ -20,8 +20,10 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
                 cred: Optional[CredentialProvider] = None,
                 client_conf: Optional[TosClientConfig] = None,
                 log_conf: Optional[TosLogConfig] = None,
-                sharding: bool = False, use_native_client=True):
-        self._client: Optional[TosClient] = None
+                sharding: bool = False,
+                use_native_client: bool = True,
+                reader_type: Optional[ReaderType] = None,
+                buffer_size: Optional[int] = None):
         self._gen_dataset = gen_dataset
         self._region = region
         self._endpoint = endpoint
@@ -30,13 +32,17 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
         self._client_conf = client_conf
         self._log_conf = log_conf
         self._sharding = sharding
-        self._use_native_client = use_native_client
         if torch.distributed.is_initialized():
             self._rank = torch.distributed.get_rank()
             self._world_size = torch.distributed.get_world_size()
         else:
             self._rank = 0
             self._world_size = 1
+        self._reader_type = reader_type
+        self._buffer_size = buffer_size
+        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
+                                 use_native_client)
+        log.info('TosIterableDataset init tos client succeed')
 
     @classmethod
     def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
@@ -44,10 +50,13 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
                  cred: Optional[CredentialProvider] = None,
                  client_conf: Optional[TosClientConfig] = None,
                  log_conf: Optional[TosLogConfig] = None,
-                 sharding: bool = False, use_native_client=True):
+                 sharding: bool = False,
+                 use_native_client: bool = True,
+                 reader_type: Optional[ReaderType] = None,
+                 buffer_size: Optional[int] = None):
         log.info(f'building {cls.__name__} from_urls')
         return cls(region, partial(gen_dataset_from_urls, urls), endpoint, transform, cred, client_conf, log_conf,
-                   sharding, use_native_client)
+                   sharding, use_native_client, reader_type, buffer_size)
 
     @classmethod
     def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
@@ -55,10 +64,13 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
                    cred: Optional[CredentialProvider] = None,
                    client_conf: Optional[TosClientConfig] = None,
                    log_conf: Optional[TosLogConfig] = None,
-                   sharding: bool = False, use_native_client=True):
+                   sharding: bool = False,
+                   use_native_client: bool = True,
+                   reader_type: Optional[ReaderType] = None,
+                   buffer_size: Optional[int] = None):
         log.info(f'building {cls.__name__} from_prefix')
         return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, transform, cred, client_conf, log_conf,
-                   sharding, use_native_client)
+                   sharding, use_native_client, reader_type, buffer_size)
 
     def __iter__(self) -> Iterator[Any]:
         worker_id = 0
@@ -72,12 +84,12 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
         if not self._sharding or (self._world_size == 1 and num_workers == 1):
             return map(
                 self._trans_tos_object,
-                self._gen_dataset(self._get_tos_client()),
+                self._gen_dataset(self._client),
             )
 
         part_dataset = (
             obj
-            for idx, obj in enumerate(self._gen_dataset(self._get_tos_client()))
+            for idx, obj in enumerate(self._gen_dataset(self._client))
             if idx % self._world_size == self._rank
         )
 
@@ -88,12 +100,7 @@ class TosIterableDataset(torch.utils.data.IterableDataset):
         )
         return map(self._trans_tos_object, part_dataset)
 
-    def _get_tos_client(self):
-        if self._client is None:
-            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf)
-            log.info('TosIterableDataset init tos client succeed')
-        return self._client
-
     def _trans_tos_object(self, object_meta: TosObjectMeta) -> Any:
-        obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
+        obj = self._client.get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size,
+                                      reader_type=self._reader_type, buffer_size=self._buffer_size)
         return self._trans(obj)
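With the client now constructed eagerly and reader_type/buffer_size stored on the dataset, every object fetched in _trans_tos_object inherits the configured read strategy. A usage sketch with placeholder region and prefix:

from torch.utils.data import DataLoader
from tostorchconnector import TosIterableDataset
from tostorchconnector.tos_client import ReaderType

dataset = TosIterableDataset.from_prefix(
    'tos://my-bucket/train/', region='cn-beijing',
    sharding=True,  # split objects across ranks and DataLoader workers
    reader_type=ReaderType.SEQUENTIAL)
loader = DataLoader(dataset, batch_size=None, num_workers=2)
for sample in loader:
    pass  # each sample is whatever the dataset's transform returns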
--- a/tostorchconnector/tos_map_dataset.py
+++ b/tostorchconnector/tos_map_dataset.py
@@ -5,7 +5,7 @@ from typing import Any, Callable, Iterator, Optional, List, Union
 import torch
 
 from . import TosObjectReader
-from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig
+from .tos_client import CredentialProvider, TosClientConfig, TosClient, TosLogConfig, ReaderType
 from .tos_common import default_trans, gen_dataset_from_prefix, \
     gen_dataset_from_urls
 from .tos_object_meta import TosObjectMeta
@@ -20,8 +20,10 @@ class TosMapDataset(torch.utils.data.Dataset):
                 trans: Callable[[TosObjectReader], Any] = default_trans,
                 cred: Optional[CredentialProvider] = None,
                 client_conf: Optional[TosClientConfig] = None,
-                log_conf: Optional[TosLogConfig] = None, use_native_client=True):
-        self._client: Optional[TosClient] = None
+                log_conf: Optional[TosLogConfig] = None,
+                use_native_client: bool = True,
+                reader_type: Optional[ReaderType] = None,
+                buffer_size: Optional[int] = None):
         self._gen_dataset = gen_dataset
         self._region = region
         self._endpoint = endpoint
@@ -29,28 +31,38 @@ class TosMapDataset(torch.utils.data.Dataset):
         self._cred = cred
         self._client_conf = client_conf
         self._log_conf = log_conf
-        self._use_native_client = use_native_client
         self._dataset: Optional[List[TosObjectMeta]] = None
+        self._reader_type = reader_type
+        self._buffer_size = buffer_size
+        self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
+                                 use_native_client)
+        log.info('TosMapDataset init tos client succeed')
 
     @classmethod
     def from_urls(cls, urls: Union[str, Iterator[str]], *, region: str, endpoint: Optional[str] = None,
                   trans: Callable[[TosObjectReader], Any] = default_trans,
                   cred: Optional[CredentialProvider] = None,
                   client_conf: Optional[TosClientConfig] = None,
-                  log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+                  log_conf: Optional[TosLogConfig] = None,
+                  use_native_client: bool = True,
+                  reader_type: Optional[ReaderType] = None,
+                  buffer_size: Optional[int] = None):
         log.info(f'building {cls.__name__} from_urls')
         return cls(region, partial(gen_dataset_from_urls, urls), endpoint, trans, cred, client_conf, log_conf,
-                   use_native_client)
+                   use_native_client, reader_type, buffer_size)
 
     @classmethod
     def from_prefix(cls, prefix: str, *, region: str, endpoint: Optional[str] = None,
                     trans: Callable[[TosObjectReader], Any] = default_trans,
                     cred: Optional[CredentialProvider] = None,
                     client_conf: Optional[TosClientConfig] = None,
-                    log_conf: Optional[TosLogConfig] = None, use_native_client=True):
+                    log_conf: Optional[TosLogConfig] = None,
+                    use_native_client: bool = True,
+                    reader_type: Optional[ReaderType] = None,
+                    buffer_size: Optional[int] = None):
         log.info(f'building {cls.__name__} from_prefix')
         return cls(region, partial(gen_dataset_from_prefix, prefix), endpoint, trans, cred, client_conf, log_conf,
-                   use_native_client)
+                   use_native_client, reader_type, buffer_size)
 
     def __getitem__(self, i: int) -> Any:
         return self._trans_tos_object(i)
@@ -58,21 +70,15 @@ class TosMapDataset(torch.utils.data.Dataset):
     def __len__(self) -> int:
         return len(self._data_set)
 
-    def _get_tos_client(self) -> TosClient:
-        if self._client is None:
-            self._client = TosClient(self._region, self._endpoint, self._cred, self._client_conf, self._log_conf,
-                                     self._use_native_client)
-            log.info('TosMapDataset init tos client succeed')
-        return self._client
-
     @property
-    def _data_set(self):
+    def _data_set(self) -> List[TosObjectMeta]:
         if self._dataset is None:
-            self._dataset = list(self._gen_dataset(self._get_tos_client()))
+            self._dataset = list(self._gen_dataset(self._client))
         assert self._dataset is not None
         return self._dataset
 
     def _trans_tos_object(self, i: int) -> Any:
         object_meta = self._data_set[i]
-        obj = self._get_tos_client().get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size)
+        obj = self._client.get_object(object_meta.bucket, object_meta.key, object_meta.etag, object_meta.size,
+                                      reader_type=self._reader_type, buffer_size=self._buffer_size)
         return self._trans(obj)
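TosMapDataset gets the same treatment: an eager client plus per-item reader_type and buffer_size. A sketch with hypothetical URLs; trans receives one TosObjectReader per indexed object:

from tostorchconnector import TosMapDataset
from tostorchconnector.tos_client import ReaderType

dataset = TosMapDataset.from_urls(
    ['tos://my-bucket/a.bin', 'tos://my-bucket/b.bin'],
    region='cn-beijing',
    trans=lambda reader: reader.read(),
    reader_type=ReaderType.RANGED,
    buffer_size=0)  # 0 disables the ranged reader's internal buffer
print(len(dataset), len(dataset[0]))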
--- a/tostorchconnector/tos_object_reader.py
+++ b/tostorchconnector/tos_object_reader.py
@@ -1,33 +1,105 @@
 import io
 import logging
-from functools import cached_property
+import threading
+from abc import ABC, abstractmethod
+from functools import partial, cached_property
 from os import SEEK_SET, SEEK_CUR, SEEK_END
 from typing import Optional, Callable, Any
 
-from tos.models2 import GetObjectOutput
-from tosnativeclient.tosnativeclient import ReadStream
-
+import tosnativeclient
 from .tos_object_meta import TosObjectMeta
 
 log = logging.getLogger(__name__)
 
+DEFAULT_CHUNK_SIZE = 1 * 1024 * 1024
+DEFAULT_BUFFER_SIZE = 8 * 1024 * 1024
 
-class TosObjectReader(io.BufferedIOBase):
 
-    def __init__(self, bucket: str, key: str,
-                 get_object_meta: Optional[Callable[[], TosObjectMeta]],
-                 get_object_stream: Callable[[str, int], Any]):
+class TosObjectStream(object):
+    def __init__(self, bucket: str, key: str, get_object_meta: Optional[Callable[[], TosObjectMeta]],
+                 client: Any):
+        self._bucket = bucket
+        self._key = key
+        self._get_object_meta = get_object_meta
+        self._client = client
+        self._sequential_object_stream = None
+        self._sequential_object_stream_offset = 0
+        self._random_object_stream = None
+
+    def sequential_read(self, chunk_size) -> Optional[bytes]:
+        self._trigger_prefetch()
+        assert self._sequential_object_stream is not None
+        if chunk_size <= 0:
+            chunk_size = DEFAULT_CHUNK_SIZE
+        if isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
+            chunk = self._sequential_object_stream.read(self._sequential_object_stream_offset, chunk_size)
+            if chunk:
+                self._sequential_object_stream_offset += len(chunk)
+        else:
+            chunk = self._sequential_object_stream.read(chunk_size)
+        return chunk
+
+    def random_read(self, read_start, read_end, chunk_size, callback: Callable[[bytes], None]) -> None:
+        if chunk_size <= 0:
+            chunk_size = DEFAULT_CHUNK_SIZE
+        if isinstance(self._client, tosnativeclient.TosClient):
+            if self._random_object_stream is None:
+                object_meta = self.object_meta
+                self._random_object_stream = self._client.get_object(self._bucket, self._key, object_meta.etag,
+                                                                     object_meta.size)
+
+            offset = read_start
+            while 1:
+                if offset >= read_end:
+                    break
+                length = chunk_size if offset + chunk_size <= read_end else read_end - offset
+                chunk = self._random_object_stream.read(offset, length)
+                if not chunk:
+                    break
+                callback(chunk)
+                offset += len(chunk)
+            return
+
+        object_meta = self.object_meta
+        output = self._client.get_object(self._bucket, self._key, if_match=object_meta.etag,
+                                         range=f'bytes={read_start}-{read_end - read_start + 1}')
+        while 1:
+            chunk = output.read(chunk_size)
+            if not chunk:
+                break
+            callback(chunk)
+
+    def _trigger_prefetch(self) -> None:
+        if self._sequential_object_stream is None:
+            object_meta = self.object_meta
+            if isinstance(self._client, tosnativeclient.TosClient):
+                get_object_stream = partial(self._client.get_object, self._bucket, self._key)
+            else:
+                get_object_stream = lambda et, sz: self._client.get_object(self._bucket, self._key, '', et)
+            self._sequential_object_stream = get_object_stream(object_meta.etag, object_meta.size)
+
+    @cached_property
+    def object_meta(self) -> TosObjectMeta:
+        return self._get_object_meta()
+
+    def close(self) -> None:
+        if self._sequential_object_stream and isinstance(self._sequential_object_stream, tosnativeclient.ReadStream):
+            self._sequential_object_stream.close()
+        if self._random_object_stream:
+            self._random_object_stream.close()
+
+
+class TosObjectReader(ABC, io.BufferedIOBase):
+    def __init__(self, bucket: str, key: str, object_stream: TosObjectStream):
         if not bucket:
             raise ValueError('bucket is empty')
         self._bucket = bucket
         self._key = key
-        self._get_object_meta = get_object_meta
-        self._get_object_stream = get_object_stream
-        self._object_stream: Optional[Any] = None
-        self._object_stream_offset = 0
+        self._object_stream = object_stream
         self._total_size: Optional[int] = None
         self._read_offset = 0
-        self._buffer = io.BytesIO()
+        self._closed = False
+        self._lock = threading.Lock()
 
     @property
     def bucket(self) -> str:
@@ -37,46 +109,86 @@ class TosObjectReader(io.BufferedIOBase):
     def key(self) -> str:
         return self._key
 
+    @property
+    def closed(self) -> bool:
+        return self._closed
+
+    def close(self) -> None:
+        if self._closed:
+            return
+        with self._lock:
+            if not self._closed:
+                self._closed = True
+                self._object_stream.close()
+
     def __enter__(self):
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         if exc_type is not None:
             try:
                 log.info(f'Exception occurred before closing stream: {exc_type.__name__}: {exc_val}')
             except:
                 pass
         else:
-            if self._object_stream is not None:
-                self._object_stream.close()
+            self.close()
+
+    @abstractmethod
+    def read(self, size: Optional[int] = None) -> bytes:
+        pass
+
+    @abstractmethod
+    def readinto(self, buf) -> int:
+        pass
+
+    @abstractmethod
+    def seek(self, offset: int, whence: int = SEEK_SET) -> int:
+        pass
+
+    def tell(self) -> int:
+        return self._read_offset
+
+    def readable(self) -> bool:
+        return self.closed
+
+    def writable(self) -> bool:
+        return False
 
-    def read(self, size: Optional[int] = None) -> Optional[bytes]:
+    def seekable(self) -> bool:
+        return True
+
+    def _is_read_to_end(self) -> bool:
+        if self._total_size is None:
+            return False
+        return self._read_offset == self._total_size
+
+    def _get_total_size(self) -> int:
+        if self._total_size is None:
+            self._total_size = self._object_stream.object_meta.size
+        return self._total_size
+
+
+class SequentialTosObjectReader(TosObjectReader):
+
+    def __init__(self, bucket: str, key: str, object_stream: TosObjectStream):
+        super().__init__(bucket, key, object_stream)
+        self._buffer = io.BytesIO()
+
+    def read(self, size: Optional[int] = None) -> bytes:
         if self._is_read_to_end():
             return b''
 
-        self._trigger_prefetch()
         current_read_offset = self._read_offset
         if size is None or size < 0:
+            if self.closed:
+                raise RuntimeError('read on closed GetObjectOutput')
             # means read all
             self._buffer.seek(0, SEEK_END)
-            if isinstance(self._object_stream, ReadStream):
-                try:
-                    chunk_size = 1 * 1024 * 1024
-                    while 1:
-                        chunk = self._object_stream.read(self._object_stream_offset, chunk_size)
-                        if not chunk:
-                            self._object_stream.close()
-                            break
-                        self._object_stream_offset += len(chunk)
-                        self._buffer.write(chunk)
-                except:
-                    self._object_stream.close()
-                    raise
-
-            elif isinstance(self._object_stream, GetObjectOutput):
-                for chunk in self._object_stream:
-                    self._buffer.write(chunk)
-
+            while 1:
+                chunk = self._object_stream.sequential_read(DEFAULT_CHUNK_SIZE)
+                if not chunk:
+                    break
+                self._buffer.write(chunk)
             self._total_size = self._buffer.tell()
         else:
             self.seek(size, SEEK_CUR)
@@ -86,12 +198,10 @@ class TosObjectReader(io.BufferedIOBase):
         self._read_offset = self._buffer.tell()
         return data
 
-    def readinto(self, buf) -> Optional[int]:
+    def readinto(self, buf) -> int:
         size = len(buf)
         if self._is_read_to_end() or size == 0:
             return 0
-
-        self._trigger_prefetch()
         current_read_offset = self._read_offset
         self.seek(size, SEEK_CUR)
         self._buffer.seek(current_read_offset)
@@ -113,7 +223,7 @@
         elif whence == SEEK_SET:
             pass
         else:
-            raise ValueError('invalid whence')
+            raise ValueError('invalid whence, must be passed SEEK_CUR, SEEK_SET, or SEEK_END')
 
         if offset < 0:
             raise ValueError(f'invalid seek offset {offset}')
@@ -121,68 +231,156 @@
         if offset > self._buffer_size():
             self._prefetch_to_offset(offset)
 
+        if self._total_size is not None:
+            offset = min(offset, self._total_size)
+
         self._read_offset = self._buffer.seek(offset)
         return self._read_offset
 
-    def tell(self) -> int:
+    def _prefetch_to_offset(self, offset: int) -> None:
+        if self.closed:
+            raise RuntimeError('read on closed GetObjectOutput')
+        size = self._buffer.seek(0, SEEK_END)
+        while offset > size:
+            chunk = self._object_stream.sequential_read(DEFAULT_CHUNK_SIZE)
+            if not chunk:
+                break
+            size += self._buffer.write(chunk)
+
+        self._total_size = self._buffer.tell()
+
+    def _buffer_size(self) -> int:
+        cur_pos = self._buffer.tell()
+        self._buffer.seek(0, SEEK_END)
+        buffer_size = self._buffer.tell()
+        self._buffer.seek(cur_pos)
+        return buffer_size
+
+
+class RangedTosObjectReader(TosObjectReader):
+    def __init__(self, bucket: str, key: str,
+                 object_stream: TosObjectStream, buffer_size: Optional[int] = None):
+        super().__init__(bucket, key, object_stream)
+
+        if buffer_size is None:
+            self._buffer_size = DEFAULT_BUFFER_SIZE
+            self._enable_buffering = True
+        else:
+            self._buffer_size = buffer_size
+            self._enable_buffering = buffer_size > 0
+
+        self._buffer = bytearray(self._buffer_size) if self._enable_buffering else None
+        self._buffer_view = memoryview(self._buffer) if self._buffer else None
+        self._buffer_start = 0
+        self._buffer_end = 0
+
+    def read(self, size: Optional[int] = None) -> bytes:
+        if self._is_read_to_end():
+            return b''
+
+        read_start = self._read_offset
+        if size is None or size < 0:
+            read_end = self._get_total_size()
+        else:
+            read_end = min(read_start + size, self._get_total_size())
+
+        if read_start >= read_end:
+            return b''
+
+        view = memoryview(bytearray(read_end - read_start))
+        self._read_into_view(view, read_start, read_end)
+        return view.tobytes()
+
+    def readinto(self, buf) -> int:
+        size = len(buf)
+        if self._is_read_to_end() or size == 0:
+            return 0
+
+        try:
+            view = memoryview(buf)
+            if view.readonly:
+                raise TypeError(f'argument must be a writable bytes-like object, not {type(buf).__name__}')
+        except TypeError:
+            raise TypeError(f'argument must be a writable bytes-like object, not {type(buf).__name__}')
+
+        read_start = self._read_offset
+        read_end = min(read_start + size, self._get_total_size())
+        if read_start >= read_end:
+            return 0
+
+        return self._read_into_view(view, read_start, read_end)
+
+    def seek(self, offset: int, whence: int = SEEK_SET) -> int:
+        if whence == SEEK_END:
+            if offset >= 0:
+                self._read_offset = self._get_total_size()
+                return self._read_offset
+            # offset is negative
+            offset += self._get_total_size()
+        elif whence == SEEK_CUR:
+            if self._is_read_to_end() and offset >= 0:
+                return self._read_offset
+            offset += self._read_offset
+        elif whence == SEEK_SET:
+            pass
+        else:
+            raise ValueError('invalid whence, must be passed SEEK_CUR, SEEK_SET, or SEEK_END')
+
+        if offset < 0:
+            raise ValueError(f'invalid seek offset {offset}')
+
+        self._read_offset = min(offset, self._get_total_size())
         return self._read_offset
 
-    def readable(self):
-        return True
+    def _read_into_view(self, view: memoryview, read_start: int, read_end: int) -> int:
+        readed = 0
+        if self._buffer_start <= read_start < self._buffer_end <= read_end:
+            readed_once = self._read_from_buffer(view, read_start, self._buffer_end)
+            read_start = self._buffer_end
+            view = view[readed_once:]
+            readed += readed_once
 
-    def writable(self):
-        return False
+        if read_end - read_start >= self._buffer_size or not self._enable_buffering:
+            readed += self._read_directly(view, read_start, read_end)
+        else:
+            readed += self._read_from_buffer(view, read_start, read_end)
 
-    def seekable(self):
-        return True
+        self._read_offset += readed
+        return readed
 
-    @cached_property
-    def _object_meta(self) -> TosObjectMeta:
-        return self._get_object_meta()
+    def _read_directly(self, view: memoryview, read_start: int, read_end: int) -> int:
+        readed = 0
 
-    def _trigger_prefetch(self) -> None:
-        if self._object_stream is None:
-            object_meta = self._object_meta
-            self._object_stream = self._get_object_stream(object_meta.etag, object_meta.size)
-            self._object_stream_offset = 0
+        def callback(data: bytes):
+            nonlocal readed
+            view[readed: readed + len(data)] = data
+            readed += len(data)
 
-    def _is_read_to_end(self) -> bool:
-        if self._total_size is None:
-            return False
-        return self._read_offset == self._total_size
+        self._object_stream.random_read(read_start, read_end, DEFAULT_CHUNK_SIZE, callback)
+        return readed
 
-    def _get_total_size(self) -> int:
-        if self._total_size is None:
-            self._total_size = self._object_meta.size
-        return self._total_size
+    def _read_from_buffer(self, view: memoryview, read_start: int, read_end: int) -> int:
+        if read_start < self._buffer_start or read_end > self._buffer_end:
+            self._load_buffer(read_start)
 
-    def _prefetch_to_offset(self, offset: int) -> None:
-        self._trigger_prefetch()
-        size = self._buffer.seek(0, SEEK_END)
-        if isinstance(self._object_stream, ReadStream):
-            try:
-                chunk_size = 1 * 1024 * 1024
-                while offset > size:
-                    chunk = self._object_stream.read(self._object_stream_offset, chunk_size)
-                    if not chunk:
-                        self._object_stream.close()
-                        break
-                    size += self._buffer.write(chunk)
-                    self._object_stream_offset += len(chunk)
-                self._total_size = self._buffer.tell()
-            except:
-                self._object_stream.close()
-                raise
-        elif isinstance(self._object_stream, GetObjectOutput):
-            try:
-                while offset > size:
-                    size += self._buffer.write(next(self._object_stream))
-            except StopIteration:
-                self._total_size = self._buffer.tell()
+        buffer_offset = read_start - self._buffer_start
+        readed = read_end - read_start
+        assert self._buffer is not None
+        view[:readed] = self._buffer[buffer_offset:buffer_offset + readed]
+        return readed
 
-    def _buffer_size(self) -> int:
-        cur_pos = self._buffer.tell()
-        self._buffer.seek(0, SEEK_END)
-        buffer_size = self._buffer.tell()
-        self._buffer.seek(cur_pos)
-        return buffer_size
+    def _load_buffer(self, read_start: int) -> None:
+        read_end = min(read_start + self._buffer_size, self._get_total_size())
+        assert self._buffer_view is not None
+
+        readed = 0
+
+        def callback(data: bytes):
+            nonlocal readed
+            self._buffer_view[readed: readed + len(data)] = data
+            readed += len(data)
+
+        self._object_stream.random_read(read_start, read_end, DEFAULT_CHUNK_SIZE, callback)
+
+        self._buffer_start = read_start
+        self._buffer_end = read_start + readed
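The monolithic TosObjectReader, which accumulated everything it had ever read in a BytesIO, is now an abstract base with two strategies: SequentialTosObjectReader keeps the accumulate-in-memory behavior, while RangedTosObjectReader serves read/seek/readinto through bounded range requests (TosObjectStream.random_read), reusing an internal buffer of DEFAULT_BUFFER_SIZE (8 MiB) unless buffer_size overrides it (0 disables buffering). A sketch of why this matters for sparse access, using the checkpoint facade with placeholder names:

from os import SEEK_END

from tostorchconnector.tos_checkpoint import TosCheckpoint
from tostorchconnector.tos_client import ReaderType

ckpt = TosCheckpoint(region='cn-beijing')

# Sequential (default): seeking near the end forces everything before that
# offset into the in-memory buffer first.
with ckpt.reader('tos://my-bucket/big.bin') as r:
    r.seek(-16, SEEK_END)
    tail = r.read(16)

# Ranged: the same access pattern issues one small range request instead.
with ckpt.reader('tos://my-bucket/big.bin', reader_type=ReaderType.RANGED) as r:
    r.seek(-16, SEEK_END)
    tail = r.read(16)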
--- a/tostorchconnector/tos_object_writer.py
+++ b/tostorchconnector/tos_object_writer.py
@@ -1,7 +1,7 @@
 import io
 import logging
 import threading
-from typing import Callable, Any
+from typing import Callable, Any, Union
 
 from tos.models2 import PutObjectOutput
 
@@ -12,11 +12,16 @@ class PutObjectStream(object):
     def __init__(self, put_object: Callable[[io.BytesIO], PutObjectOutput]):
         self._put_object = put_object
         self._buffer = io.BytesIO()
+        self._closed = False
 
-    def write(self, data):
+    def write(self, data) -> int:
+        if self._closed:
+            raise RuntimeError('write on closed PutObjectStream')
         self._buffer.write(data)
+        return len(data)
 
     def close(self):
+        self._closed = True
         self._buffer.seek(0)
         _ = self._put_object(self._buffer)
 
@@ -41,11 +46,15 @@ class TosObjectWriter(io.BufferedIOBase):
     def key(self) -> str:
         return self._key
 
+    @property
+    def closed(self) -> bool:
+        return self._closed
+
     def __enter__(self):
         self._write_offset = 0
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         if exc_type is not None:
             try:
                 log.info(f'Exception occurred before closing stream: {exc_type.__name__}: {exc_val}')
@@ -54,17 +63,22 @@ class TosObjectWriter(io.BufferedIOBase):
         else:
             self.close()
 
-    def write(self, data) -> int:
-        self._put_object_stream.write(data)
-        self._write_offset += len(data)
-        return len(data)
+    def write(self, data: Union[bytes, memoryview]) -> int:
+        if isinstance(data, memoryview):
+            data = data.tobytes()
+        written = self._put_object_stream.write(data)
+        assert written == len(data)
+        self._write_offset += written
+        return written
 
     def close(self) -> None:
+        if self._closed:
+            return
+
         with self._lock:
-            if self._closed:
-                return
-            self._closed = True
-            self._put_object_stream.close()
+            if not self._closed:
+                self._closed = True
+                self._put_object_stream.close()
 
     def tell(self) -> int:
         return self._write_offset
@@ -72,11 +86,11 @@ class TosObjectWriter(io.BufferedIOBase):
     def flush(self) -> None:
         pass
 
-    def readable(self):
+    def readable(self) -> bool:
         return False
 
-    def writable(self):
-        return True
+    def writable(self) -> bool:
+        return not self.closed
 
-    def seekable(self):
+    def seekable(self) -> bool:
         return False
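On the write path, PutObjectStream now rejects writes after close and reports the byte count, TosObjectWriter.close() is guarded both before and under the lock so repeated closes are harmless, and writable() turns false once closed. A short sketch via the checkpoint facade; the destination URL is a placeholder:

from tostorchconnector.tos_checkpoint import TosCheckpoint

ckpt = TosCheckpoint(region='cn-beijing')
with ckpt.writer('tos://my-bucket/out/blob.bin') as w:
    assert w.writable()
    w.write(b'hello ')
    w.write(memoryview(b'world'))  # memoryviews are copied to bytes as of 1.0.5
# __exit__ closed the stream and triggered the single upload; extra closes are no-ops.
w.close()
assert not w.writable()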
--- a/tostorchconnector-1.0.3.dist-info/METADATA
+++ b/tostorchconnector-1.0.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tostorchconnector
-Version: 1.0.3
+Version: 1.0.5
 Summary: TOS connector integration for PyTorch
 Author-email: xiangshijian <xiangshijian@bytedance.com>
 Classifier: Development Status :: 4 - Beta
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Classifier: Topic :: Utilities
--- /dev/null
+++ b/tostorchconnector-1.0.5.dist-info/RECORD
@@ -0,0 +1,14 @@
+tostorchconnector/__init__.py,sha256=GASYEQEKI6nWayop8F8Ln-k7NY4jpmmV9qSy_-gY5Tk,508
+tostorchconnector/tos_checkpoint.py,sha256=xbBWHeESuI6mml5VPPFwUwkE5EZYSRGYeeXTNYhaRQY,1400
+tostorchconnector/tos_client.py,sha256=a8JkE41tsa-77RaCPWjq6mIY0n2AuW6fWkY2Ya1T4r8,7364
+tostorchconnector/tos_common.py,sha256=1TRHJdCQjlc_mSOI5j2VoxXGH-Xbamc9bCd0SZyB0L8,3372
+tostorchconnector/tos_iterable_dataset.py,sha256=EMLNzsH4E1qUFcN6S9bhGBbA4bRAAcGY9o5Z7sHYFIk,4796
+tostorchconnector/tos_map_dataset.py,sha256=4fSLR8iPA8jXcMD94teEQv2h6K3MG35QGC119_SWRds,3913
+tostorchconnector/tos_object_meta.py,sha256=YrEQikioD5v0C_VcoudvTn0apUcxqxmSNQ4dDIip1Zc,562
+tostorchconnector/tos_object_reader.py,sha256=a0KCw70hokWi00MEssjvaGaL-WxM040_lxtNQ8l9XEU,13517
+tostorchconnector/tos_object_writer.py,sha256=GfQUHZ3RLmaJzqOtov4EwgTl6eI3IafhJ2xnFK9cZ-k,2429
+tostorchconnector-1.0.5.dist-info/licenses/LICENSE,sha256=WBJyZyF8q8ZxohuLttIbv7HNLBRA0NnnRWWgfDlDZBE,11361
+tostorchconnector-1.0.5.dist-info/METADATA,sha256=vhpr7HXeZ4vPT9yNnL1uk6KGIYephGJsdgf6CxHlUCA,932
+tostorchconnector-1.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tostorchconnector-1.0.5.dist-info/top_level.txt,sha256=w9WvBP6KEi5Dourf86IdBbpWe-fad84S4SBftSg7H-k,18
+tostorchconnector-1.0.5.dist-info/RECORD,,
--- a/tostorchconnector-1.0.3.dist-info/RECORD
+++ /dev/null
@@ -1,14 +0,0 @@
-tostorchconnector/__init__.py,sha256=RX4fIC9N-GR5THq9bnfzPPf4XIpTQy4OGfdhpmZz2UU,448
-tostorchconnector/tos_checkpoint.py,sha256=SHrPTZ_HPKG3OvIthj2AXPwyKQTTUxEBYnaNNEZrSj0,1497
-tostorchconnector/tos_client.py,sha256=8artQiMoObamN7gGYVdjsSfoPTwmBVsLTt5Y898ohCc,7165
-tostorchconnector/tos_common.py,sha256=oxxxzpA7Adn3hwNkpB9fHCueGGmZEc2f9tFqtEgPn9M,3167
-tostorchconnector/tos_iterable_dataset.py,sha256=a28XPfZJwvWJJeI3qhHXJAAhKyg4U1DHavUI3VsQorY,4312
-tostorchconnector/tos_map_dataset.py,sha256=o8Fq8O5aG2LS8NA9eFtEfBbkisWwHg-dItO8UAqu08o,3471
-tostorchconnector/tos_object_meta.py,sha256=YrEQikioD5v0C_VcoudvTn0apUcxqxmSNQ4dDIip1Zc,562
-tostorchconnector/tos_object_reader.py,sha256=kmHOBS812pcbEMu5sVE_9k6uy9ALQxYAoQtbepxL9SU,6248
-tostorchconnector/tos_object_writer.py,sha256=-I2coKCuUITL_LHf5we3cZBQ6rdn0bQL6NwpF39uDuU,1949
-tostorchconnector-1.0.3.dist-info/licenses/LICENSE,sha256=WBJyZyF8q8ZxohuLttIbv7HNLBRA0NnnRWWgfDlDZBE,11361
-tostorchconnector-1.0.3.dist-info/METADATA,sha256=hxM0l68z0o2oB1Pk7rXUuT-MEOgNv2XGlG0Q_0xRUmo,881
-tostorchconnector-1.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tostorchconnector-1.0.3.dist-info/top_level.txt,sha256=w9WvBP6KEi5Dourf86IdBbpWe-fad84S4SBftSg7H-k,18
-tostorchconnector-1.0.3.dist-info/RECORD,,