deltacat 0.1.18b4__py3-none-any.whl → 0.1.18b7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,121 @@
1
+ import logging
2
+ from ray import cloudpickle
3
+ from collections import defaultdict
4
+ import time
5
+ from deltacat.io.object_store import IObjectStore
6
+ from typing import Any, List
7
+ from deltacat import logs
8
+ import uuid
9
+ import socket
10
+ from pymemcache.client.base import Client
11
+ from pymemcache.client.retrying import RetryingClient
12
+ from pymemcache.exceptions import MemcacheUnexpectedCloseError
13
+
14
+ logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
15
+
16
+
17
class MemcachedObjectStore(IObjectStore):
    """
    An implementation of object store that uses Memcached.

    Object references are strings of the form "<uuid>_<ip>" so that reads
    can be routed back to the memcached node that stored the value.
    """

    def __init__(self, port=11212) -> None:
        # Cache of memcached clients keyed by node IP address.
        self.client_cache = {}
        # Lazily resolved IP of the current host (see _get_current_ip).
        self.current_ip = None
        # Separator between the uuid and ip parts of a ref.
        self.SEPARATOR = "_"
        # Port the memcached server listens on.
        self.port = port
        super().__init__()

    def put_many(self, objects: List[object], *args, **kwargs) -> List[Any]:
        """
        Serialize and write many objects to the local memcached node.

        Returns an ordered list of refs corresponding to the input objects.
        Raises RuntimeError if any key could not be written.
        """
        payload = {}
        result = []
        current_ip = self._get_current_ip()
        for obj in objects:
            serialized = cloudpickle.dumps(obj)
            uid = uuid.uuid4()
            ref = self._create_ref(uid, current_ip)
            payload[str(uid)] = serialized
            result.append(ref)

        client = self._get_client_by_ip(current_ip)
        # set_many returns the list of keys that failed to be written;
        # a non-empty (truthy) result therefore indicates a partial failure.
        failed_keys = client.set_many(payload, noreply=False)
        if failed_keys:
            raise RuntimeError("Unable to write few keys to cache")

        return result

    def put(self, obj: object, *args, **kwargs) -> Any:
        """
        Serialize and write a single object to the local memcached node.

        Returns the ref of the stored object, or raises RuntimeError when
        the write was not acknowledged.
        """
        serialized = cloudpickle.dumps(obj)
        uid = uuid.uuid4()
        current_ip = self._get_current_ip()
        ref = self._create_ref(uid, current_ip)
        client = self._get_client_by_ip(current_ip)

        if client.set(str(uid), serialized):
            return ref
        else:
            raise RuntimeError("Unable to write to cache")

    def get_many(self, refs: List[Any], *args, **kwargs) -> List[object]:
        """
        Fetch and deserialize many objects, querying each owning node once.

        Note: results are grouped per node and may not preserve input order.
        Raises AssertionError if any key is missing from the cache.
        """
        result = []
        uid_per_ip = defaultdict(lambda: [])

        start = time.monotonic()
        # Group uids by the node that owns them so each node is hit once.
        for ref in refs:
            uid, ip = ref.split(self.SEPARATOR)
            uid_per_ip[ip].append(uid)

        for (ip, uids) in uid_per_ip.items():
            client = self._get_client_by_ip(ip)
            cache_result = client.get_many(uids)
            # get_many only returns keys it found, so a length mismatch
            # means some keys were missing (evicted or never written).
            assert len(cache_result) == len(
                uids
            ), f"Not all values were returned from cache as {len(cache_result)} != {len(uids)}"

            values = cache_result.values()
            total_bytes = 0

            deserialize_start = time.monotonic()
            for serialized in values:
                deserialized = cloudpickle.loads(serialized)
                total_bytes += len(serialized)
                result.append(deserialized)

            deserialize_end = time.monotonic()
            logger.debug(
                f"The time taken to deserialize {total_bytes} bytes is: {deserialize_end - deserialize_start}",
            )

        end = time.monotonic()

        logger.info(f"The total time taken to read all objects is: {end - start}")
        return result

    def get(self, ref: Any, *args, **kwargs) -> object:
        """Fetch and deserialize a single object from its owning node."""
        uid, ip = ref.split(self.SEPARATOR)
        client = self._get_client_by_ip(ip)
        serialized = client.get(uid)
        return cloudpickle.loads(serialized)

    def _create_ref(self, uid, ip) -> str:
        # A ref encodes both the key and the node that owns it.
        return f"{uid}{self.SEPARATOR}{ip}"

    def _get_client_by_ip(self, ip_address: str):
        """Return a cached retrying memcached client for the given node IP."""
        if ip_address in self.client_cache:
            return self.client_cache[ip_address]

        base_client = Client((ip_address, self.port))
        # Retry transient connection closes a few times before failing.
        client = RetryingClient(
            base_client,
            attempts=3,
            retry_delay=0.01,
            retry_for=[MemcacheUnexpectedCloseError],
        )

        self.client_cache[ip_address] = client
        return client

    def _get_current_ip(self) -> str:
        """Resolve (and memoize) the IP address of the current host."""
        if self.current_ip is None:
            self.current_ip = socket.gethostbyname(socket.gethostname())

        return self.current_ip
@@ -0,0 +1,51 @@
1
+ from typing import List, Any
2
+
3
+
4
class IObjectStore:
    """
    An object store interface.
    """

    def setup(self, *args, **kwargs) -> Any:
        """
        Sets up everything needed to run the object store.
        """
        ...

    def put(self, obj: object, *args, **kwargs) -> Any:
        """
        Put a single object into the object store. Depending
        on the implementation, this method can be sync or async.
        """
        return self.put_many([obj])[0]

    def put_many(self, objects: List[object], *args, **kwargs) -> List[Any]:
        """
        Put many objects into the object store. It would return an ordered list
        of object references corresponding to each object in the input.
        """
        ...

    def get(self, ref: Any, *args, **kwargs) -> object:
        """
        Get a single object from an object store.
        """
        return self.get_many([ref])[0]

    def get_many(self, refs: List[Any], *args, **kwargs) -> List[object]:
        """
        Get a list of objects from the object store. Use this method to
        avoid multiple get calls. Note that depending on implementation it may
        or may not return ordered results.
        """
        ...

    def clear(self, *args, **kwargs) -> bool:
        """
        Clears the object store and all the associated data in it.
        """
        ...
@@ -0,0 +1,23 @@
1
+ import ray
2
+ from ray import cloudpickle
3
+ from deltacat.io.object_store import IObjectStore
4
+ from typing import Any, List
5
+
6
+
7
class RayPlasmaObjectStore(IObjectStore):
    """
    An implementation of object store that uses Ray plasma object store.
    """

    def put_many(self, objects: List[object], *args, **kwargs) -> List[Any]:
        # Store each object in plasma, then pickle the resulting ObjectRef
        # so it can be handed around as an opaque byte string.
        return [cloudpickle.dumps(ray.put(entry)) for entry in objects]

    def get_many(self, refs: List[Any], *args, **kwargs) -> List[object]:
        # Recover the ObjectRefs from their pickled form, then fetch all
        # of them from plasma in a single batched call.
        object_refs = []
        for pickled_ref in refs:
            object_refs.append(cloudpickle.loads(pickled_ref))
        return ray.get(object_refs)
@@ -0,0 +1,114 @@
1
+ import logging
2
+ from ray import cloudpickle
3
+ import time
4
+ from deltacat.io.object_store import IObjectStore
5
+ from typing import Any, List
6
+ from deltacat import logs
7
+ import uuid
8
+ import socket
9
+ import redis
10
+ from collections import defaultdict
11
+
12
+
13
+ logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
14
+
15
+
16
class RedisObjectStore(IObjectStore):
    """
    An implementation of object store that uses Redis in memory DB.

    Object references are strings of the form "<uuid>_<ip>" so that reads
    can be routed back to the Redis node that stored the value.
    """

    def __init__(self, port=7777) -> None:
        # Cache of Redis clients keyed by node IP address.
        self.client_cache = {}
        # Lazily resolved IP of the current host (see _get_current_ip).
        self.current_ip = None
        # Separator between the uuid and ip parts of a ref.
        self.SEPARATOR = "_"
        # Port the Redis server listens on (previously hard-coded to 7777).
        self.port = port
        super().__init__()

    def put(self, obj: object, *args, **kwargs) -> Any:
        """
        Serialize and write a single object to the local Redis node.

        Returns the ref of the stored object, or raises RuntimeError when
        the write was not acknowledged.
        """
        serialized = cloudpickle.dumps(obj)
        uid = uuid.uuid4()
        current_ip = self._get_current_ip()
        ref = self._create_ref(uid, current_ip)

        client = self._get_client_by_ip(current_ip)
        if client.set(str(uid), serialized):
            return ref
        else:
            raise RuntimeError(f"Unable to write {ref} to cache")

    def put_many(self, objects: List[object], *args, **kwargs) -> List[Any]:
        """
        Serialize and write many objects to the local Redis node in one call.

        Returns an ordered list of refs corresponding to the input objects.
        Raises RuntimeError if the batched write was not acknowledged.
        """
        payload = {}
        result = []
        current_ip = self._get_current_ip()
        for obj in objects:
            serialized = cloudpickle.dumps(obj)
            uid = uuid.uuid4()
            ref = self._create_ref(uid, current_ip)
            payload[str(uid)] = serialized
            result.append(ref)

        client = self._get_client_by_ip(current_ip)

        if client.mset(payload):
            return result
        else:
            raise RuntimeError("Unable to update cache")

    def get_many(self, refs: List[Any], *args, **kwargs) -> List[object]:
        """
        Fetch and deserialize many objects, querying each owning node once.

        Note: results are grouped per node and may not preserve input order.
        Raises AssertionError if any key is missing from the cache.
        """
        result = []
        uid_per_ip = defaultdict(lambda: [])

        start = time.monotonic()
        # Group uids by the node that owns them so each node is hit once.
        for ref in refs:
            uid, ip = ref.split(self.SEPARATOR)
            uid_per_ip[ip].append(uid)

        for (ip, uids) in uid_per_ip.items():
            client = self._get_client_by_ip(ip)
            cache_result = client.mget(uids)
            # Redis MGET returns a list of the same length as the request,
            # with a None placeholder for every missing key — so comparing
            # lengths can never detect a miss. Check for None entries instead.
            assert all(
                value is not None for value in cache_result
            ), "Not all values were returned from cache"

            total_bytes = 0

            deserialize_start = time.monotonic()
            for serialized in cache_result:
                deserialized = cloudpickle.loads(serialized)
                total_bytes += len(serialized)
                result.append(deserialized)

            deserialize_end = time.monotonic()
            logger.debug(
                f"The time taken to deserialize {total_bytes} bytes is: {deserialize_end - deserialize_start}",
            )

        end = time.monotonic()

        logger.info(f"The total time taken to read all objects is: {end - start}")

        return result

    def get(self, ref: Any, *args, **kwargs) -> object:
        """Fetch and deserialize a single object from its owning node."""
        uid, ip = ref.split(self.SEPARATOR)
        client = self._get_client_by_ip(ip)
        serialized = client.get(uid)
        return cloudpickle.loads(serialized)

    def _get_client_by_ip(self, ip_address: str):
        """Return a cached Redis client for the given node IP."""
        if ip_address in self.client_cache:
            return self.client_cache[ip_address]

        base_client = redis.Redis(ip_address, self.port)

        self.client_cache[ip_address] = base_client
        return base_client

    def _get_current_ip(self) -> str:
        """Resolve (and memoize) the IP address of the current host."""
        if self.current_ip is None:
            self.current_ip = socket.gethostbyname(socket.gethostname())

        return self.current_ip

    def _create_ref(self, uid, ip) -> str:
        # A ref encodes both the key and the node that owns it.
        return f"{uid}{self.SEPARATOR}{ip}"
@@ -0,0 +1,44 @@
1
+ import logging
2
+ from ray import cloudpickle
3
+ import time
4
+ from deltacat.io.object_store import IObjectStore
5
+ from typing import Any, List
6
+ from deltacat import logs
7
+ import uuid
8
+ from deltacat.aws import s3u as s3_utils
9
+
10
+ logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
11
+
12
+
13
class S3ObjectStore(IObjectStore):
    """
    An implementation of object store that uses S3.
    """

    def __init__(self, bucket_prefix: str) -> None:
        # Bucket under which all serialized objects are stored.
        self.bucket = bucket_prefix
        super().__init__()

    def put_many(self, objects: List[object], *args, **kwargs) -> List[Any]:
        """Upload each serialized object under a fresh uuid key; return the keys."""
        refs = []
        for entry in objects:
            key = uuid.uuid4()
            s3_utils.upload(f"s3://{self.bucket}/{key}", cloudpickle.dumps(entry))
            refs.append(key)

        return refs

    def get_many(self, refs: List[Any], *args, **kwargs) -> List[object]:
        """Download and deserialize every referenced object, in input order."""
        fetched = []
        start = time.monotonic()
        for key in refs:
            response = s3_utils.download(f"s3://{self.bucket}/{key}")
            payload = response["Body"].read()
            fetched.append(cloudpickle.loads(payload))
        end = time.monotonic()

        logger.info(f"The total time taken to read all objects is: {end - start}")
        return fetched
@@ -17,6 +17,10 @@ class TestFitInputDeltas(unittest.TestCase):
17
17
 
18
18
  super().setUpClass()
19
19
 
20
    @classmethod
    def tearDownClass(cls) -> None:
        # Stop the module patcher so later test modules import the real
        # modules again. NOTE(review): cls.module_patcher is presumably
        # started in setUpClass — not visible in this hunk, confirm.
        cls.module_patcher.stop()
23
+
20
24
  def test_sanity(self):
21
25
  from deltacat.compute.compactor.utils import io
22
26
 
File without changes
@@ -0,0 +1,86 @@
1
+ import unittest
2
+ from unittest import mock
3
+
4
+
5
class TestFileObjectStore(unittest.TestCase):
    """Unit tests for FileObjectStore with `open`, ray and os mocked out."""

    TEST_VALUE = "test-value"

    @classmethod
    def setUpClass(cls):
        # Replace ray and os in sys.modules so importing FileObjectStore
        # does not require either to be importable for real.
        cls.ray_mock = mock.MagicMock()
        cls.os_mock = mock.MagicMock()
        cls.module_patcher = mock.patch.dict(
            "sys.modules", {"ray": cls.ray_mock, "os": cls.os_mock}
        )
        cls.module_patcher.start()

        super().setUpClass()

    @classmethod
    def tearDownClass(cls) -> None:
        # Undo the sys.modules patch installed in setUpClass.
        cls.module_patcher.stop()

    @mock.patch(
        "deltacat.io.file_object_store.open",
        new_callable=mock.mock_open,
        read_data="data",
    )
    def test_put_many_sanity(self, mocked_open):
        from deltacat.io.file_object_store import FileObjectStore

        store = FileObjectStore(dir_path="")
        self.ray_mock.cloudpickle.dumps.return_value = self.TEST_VALUE

        refs = store.put_many(["a", "b"])

        # One file is opened per stored object.
        self.assertEqual(len(refs), 2)
        self.assertEqual(mocked_open.call_count, 2)

    @mock.patch(
        "deltacat.io.file_object_store.open",
        new_callable=mock.mock_open,
        read_data="data",
    )
    def test_get_many_sanity(self, mocked_open):
        from deltacat.io.file_object_store import FileObjectStore

        store = FileObjectStore(dir_path="")
        self.ray_mock.cloudpickle.loads.return_value = self.TEST_VALUE

        values = store.get_many(["test", "test"])

        # One file is opened per requested ref.
        self.assertEqual(len(values), 2)
        self.assertEqual(mocked_open.call_count, 2)

    @mock.patch(
        "deltacat.io.file_object_store.open",
        new_callable=mock.mock_open,
        read_data="data",
    )
    def test_get_sanity(self, mocked_open):
        from deltacat.io.file_object_store import FileObjectStore

        store = FileObjectStore(dir_path="")
        self.ray_mock.cloudpickle.loads.return_value = self.TEST_VALUE

        value = store.get("test")

        self.assertEqual(value, self.TEST_VALUE)
        self.assertEqual(mocked_open.call_count, 1)

    @mock.patch(
        "deltacat.io.file_object_store.open",
        new_callable=mock.mock_open,
        read_data="data",
    )
    def test_put_sanity(self, mocked_open):
        from deltacat.io.file_object_store import FileObjectStore

        store = FileObjectStore(dir_path="")
        self.ray_mock.cloudpickle.dumps.return_value = self.TEST_VALUE

        ref = store.put("test")

        self.assertIsNotNone(ref)
        self.assertEqual(mocked_open.call_count, 1)
@@ -0,0 +1,158 @@
1
+ import unittest
2
+ from unittest import mock
3
+
4
+
5
@mock.patch("deltacat.io.memcached_object_store.cloudpickle")
@mock.patch("deltacat.io.memcached_object_store.socket")
class TestMemcachedObjectStore(unittest.TestCase):
    """
    Unit tests for MemcachedObjectStore with all network collaborators mocked.

    Mock argument ordering: patch decorators inject mocks bottom-up, so each
    test receives (mock_retrying_client, mock_client) from its own method
    decorators first, followed by (mock_socket, mock_cloudpickle) from the
    class-level decorators above.
    """

    TEST_VALUE = "test-value"

    def setUp(self):
        # Import inside setUp so the class-level patches are active
        # before the module under test is loaded.
        from deltacat.io.memcached_object_store import MemcachedObjectStore

        self.object_store = MemcachedObjectStore()

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_put_many_sanity(
        self,
        mock_retrying_client,
        mock_client,
        mock_socket,
        mock_cloudpickle,
    ):
        """put_many returns one "<uid>_<ip>" ref per object on success."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_retrying_client.return_value = mock_client.return_value
        # set_many returning an empty list means no keys failed.
        mock_client.return_value.set_many.return_value = []

        result = self.object_store.put_many(["a", "b"])

        self.assertEqual(2, len(result))
        # Refs embed the separator between uid and ip.
        self.assertRegex(result[0], ".*_.*")
        self.assertEqual(1, mock_client.return_value.set_many.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_put_many_when_cache_fails(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """put_many raises RuntimeError when set_many reports failed keys."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_retrying_client.return_value = mock_client.return_value
        # A non-empty return from set_many is the list of failed keys.
        mock_client.return_value.set_many.return_value = ["abcd"]

        with self.assertRaises(RuntimeError):
            self.object_store.put_many(["a", "b"])

        self.assertEqual(1, mock_client.return_value.set_many.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_get_many_sanity(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """get_many deserializes every value returned by the cache."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_client.return_value.get_many.return_value = {
            "key1": "value1",
            "key2": "value2",
        }
        mock_retrying_client.return_value = mock_client.return_value

        result = self.object_store.get_many(["test_ip", "test_ip"])

        self.assertEqual(2, len(result))
        # Both refs map to the same ip, so the node is queried only once.
        self.assertEqual(1, mock_client.return_value.get_many.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_get_many_when_cache_expired(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """get_many asserts when fewer values come back than were requested."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        # Only one of the two requested keys is present.
        mock_client.return_value.get_many.return_value = {"key1": "value1"}
        mock_retrying_client.return_value = mock_client.return_value

        with self.assertRaises(AssertionError):
            self.object_store.get_many(["test_ip", "test_ip"])

        self.assertEqual(1, mock_client.return_value.get_many.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_get_sanity(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """get returns the deserialized value for a single ref."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_client.return_value.get.return_value = self.TEST_VALUE
        mock_retrying_client.return_value = mock_client.return_value

        result = self.object_store.get("test_ip")

        self.assertEqual(self.TEST_VALUE, result)
        self.assertEqual(1, mock_client.return_value.get.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_get_when_cache_fails(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """get propagates errors raised by the memcached client."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_client.return_value.get.side_effect = RuntimeError()
        mock_retrying_client.return_value = mock_client.return_value

        with self.assertRaises(RuntimeError):
            self.object_store.get("test_ip")

        self.assertEqual(1, mock_client.return_value.get.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_put_sanity(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """put returns a ref when the cache acknowledges the write."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_retrying_client.return_value = mock_client.return_value
        mock_client.return_value.set.return_value = True

        result = self.object_store.put("test")

        self.assertIsNotNone(result)
        self.assertEqual(1, mock_client.return_value.set.call_count)

    @mock.patch("deltacat.io.memcached_object_store.Client")
    @mock.patch("deltacat.io.memcached_object_store.RetryingClient")
    def test_put_when_cache_fails(
        self, mock_retrying_client, mock_client, mock_socket, mock_cloudpickle
    ):
        """put raises RuntimeError when the cache rejects the write."""
        mock_cloudpickle.dumps.return_value = self.TEST_VALUE
        mock_cloudpickle.loads.return_value = self.TEST_VALUE
        mock_socket.gethostbyname.return_value = "0.0.0.0"
        mock_socket.gethostname.return_value = "test-host"
        mock_retrying_client.return_value = mock_client.return_value
        mock_client.return_value.set.return_value = False

        with self.assertRaises(RuntimeError):
            self.object_store.put("test_ip")

        self.assertEqual(1, mock_client.return_value.set.call_count)
@@ -0,0 +1,54 @@
1
+ import unittest
2
+ from unittest import mock
3
+
4
+
5
class TestRayPlasmaObjectStore(unittest.TestCase):
    """Unit tests for RayPlasmaObjectStore with ray and cloudpickle mocked."""

    TEST_VALUE = "test-value"

    @classmethod
    def setUpClass(cls):
        from deltacat.io.ray_plasma_object_store import RayPlasmaObjectStore

        # A single store instance is shared by every test in this class.
        cls.object_store = RayPlasmaObjectStore()

        super().setUpClass()

    @mock.patch("deltacat.io.ray_plasma_object_store.ray")
    @mock.patch("deltacat.io.ray_plasma_object_store.cloudpickle")
    def test_put_many_sanity(self, pickle_mock, ray_mock):
        ray_mock.put.return_value = "c"
        pickle_mock.dumps.return_value = self.TEST_VALUE

        refs = self.object_store.put_many(["a", "b"])

        self.assertEqual(len(refs), 2)

    @mock.patch("deltacat.io.ray_plasma_object_store.ray")
    @mock.patch("deltacat.io.ray_plasma_object_store.cloudpickle")
    def test_get_many_sanity(self, pickle_mock, ray_mock):
        ray_mock.get.return_value = ["a", "b"]
        pickle_mock.loads.return_value = self.TEST_VALUE

        values = self.object_store.get_many(["test", "test"])

        self.assertEqual(len(values), 2)

    @mock.patch("deltacat.io.ray_plasma_object_store.ray")
    @mock.patch("deltacat.io.ray_plasma_object_store.cloudpickle")
    def test_get_sanity(self, pickle_mock, ray_mock):
        ray_mock.get.return_value = [self.TEST_VALUE]
        pickle_mock.loads.return_value = self.TEST_VALUE

        value = self.object_store.get("test")

        self.assertEqual(value, self.TEST_VALUE)

    @mock.patch("deltacat.io.ray_plasma_object_store.ray")
    @mock.patch("deltacat.io.ray_plasma_object_store.cloudpickle")
    def test_put_sanity(self, pickle_mock, ray_mock):
        ray_mock.put.return_value = "c"
        pickle_mock.dumps.return_value = self.TEST_VALUE

        ref = self.object_store.put("test")

        self.assertEqual(ref, self.TEST_VALUE)