fsspec 2024.5.0__py3-none-any.whl → 2024.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. fsspec/_version.py +2 -2
  2. fsspec/caching.py +3 -2
  3. fsspec/compression.py +1 -1
  4. fsspec/generic.py +3 -0
  5. fsspec/implementations/cached.py +6 -16
  6. fsspec/implementations/dirfs.py +2 -0
  7. fsspec/implementations/github.py +12 -0
  8. fsspec/implementations/http.py +2 -1
  9. fsspec/implementations/reference.py +9 -0
  10. fsspec/implementations/smb.py +10 -0
  11. fsspec/json.py +121 -0
  12. fsspec/registry.py +24 -18
  13. fsspec/spec.py +119 -33
  14. fsspec/utils.py +1 -1
  15. {fsspec-2024.5.0.dist-info → fsspec-2024.6.1.dist-info}/METADATA +10 -5
  16. fsspec-2024.6.1.dist-info/RECORD +55 -0
  17. {fsspec-2024.5.0.dist-info → fsspec-2024.6.1.dist-info}/WHEEL +1 -1
  18. fsspec/implementations/tests/__init__.py +0 -0
  19. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_file_listing.yaml +0 -112
  20. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_mkdir.yaml +0 -582
  21. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_read_pyarrow_non_partitioned.yaml +0 -873
  22. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_read_range.yaml +0 -458
  23. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_read_range_chunked.yaml +0 -1355
  24. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_write_and_read.yaml +0 -795
  25. fsspec/implementations/tests/cassettes/test_dbfs/test_dbfs_write_pyarrow_non_partitioned.yaml +0 -613
  26. fsspec/implementations/tests/conftest.py +0 -39
  27. fsspec/implementations/tests/local/__init__.py +0 -0
  28. fsspec/implementations/tests/local/local_fixtures.py +0 -18
  29. fsspec/implementations/tests/local/local_test.py +0 -14
  30. fsspec/implementations/tests/memory/__init__.py +0 -0
  31. fsspec/implementations/tests/memory/memory_fixtures.py +0 -27
  32. fsspec/implementations/tests/memory/memory_test.py +0 -14
  33. fsspec/implementations/tests/out.zip +0 -0
  34. fsspec/implementations/tests/test_archive.py +0 -382
  35. fsspec/implementations/tests/test_arrow.py +0 -259
  36. fsspec/implementations/tests/test_cached.py +0 -1306
  37. fsspec/implementations/tests/test_common.py +0 -35
  38. fsspec/implementations/tests/test_dask.py +0 -29
  39. fsspec/implementations/tests/test_data.py +0 -20
  40. fsspec/implementations/tests/test_dbfs.py +0 -268
  41. fsspec/implementations/tests/test_dirfs.py +0 -588
  42. fsspec/implementations/tests/test_ftp.py +0 -178
  43. fsspec/implementations/tests/test_git.py +0 -76
  44. fsspec/implementations/tests/test_http.py +0 -577
  45. fsspec/implementations/tests/test_jupyter.py +0 -57
  46. fsspec/implementations/tests/test_libarchive.py +0 -33
  47. fsspec/implementations/tests/test_local.py +0 -1285
  48. fsspec/implementations/tests/test_memory.py +0 -382
  49. fsspec/implementations/tests/test_reference.py +0 -720
  50. fsspec/implementations/tests/test_sftp.py +0 -233
  51. fsspec/implementations/tests/test_smb.py +0 -139
  52. fsspec/implementations/tests/test_tar.py +0 -243
  53. fsspec/implementations/tests/test_webhdfs.py +0 -197
  54. fsspec/implementations/tests/test_zip.py +0 -134
  55. fsspec/tests/__init__.py +0 -0
  56. fsspec/tests/conftest.py +0 -188
  57. fsspec/tests/data/listing.html +0 -1
  58. fsspec/tests/test_api.py +0 -498
  59. fsspec/tests/test_async.py +0 -230
  60. fsspec/tests/test_caches.py +0 -255
  61. fsspec/tests/test_callbacks.py +0 -89
  62. fsspec/tests/test_compression.py +0 -164
  63. fsspec/tests/test_config.py +0 -129
  64. fsspec/tests/test_core.py +0 -466
  65. fsspec/tests/test_downstream.py +0 -40
  66. fsspec/tests/test_file.py +0 -200
  67. fsspec/tests/test_fuse.py +0 -147
  68. fsspec/tests/test_generic.py +0 -90
  69. fsspec/tests/test_gui.py +0 -23
  70. fsspec/tests/test_mapping.py +0 -228
  71. fsspec/tests/test_parquet.py +0 -140
  72. fsspec/tests/test_registry.py +0 -134
  73. fsspec/tests/test_spec.py +0 -1167
  74. fsspec/tests/test_utils.py +0 -478
  75. fsspec-2024.5.0.dist-info/RECORD +0 -111
  76. {fsspec-2024.5.0.dist-info → fsspec-2024.6.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,230 +0,0 @@
1
- import asyncio
2
- import inspect
3
- import io
4
- import os
5
- import time
6
-
7
- import pytest
8
-
9
- import fsspec
10
- import fsspec.asyn
11
- from fsspec.asyn import _run_coros_in_chunks
12
-
13
-
14
- def test_sync_methods():
15
- inst = fsspec.asyn.AsyncFileSystem()
16
- assert inspect.iscoroutinefunction(inst._info)
17
- assert hasattr(inst, "info")
18
- assert inst.info.__qualname__ == "AsyncFileSystem._info"
19
- assert not inspect.iscoroutinefunction(inst.info)
20
-
21
-
22
- def test_when_sync_methods_are_disabled():
23
- class TestFS(fsspec.asyn.AsyncFileSystem):
24
- mirror_sync_methods = False
25
-
26
- inst = TestFS()
27
- assert inspect.iscoroutinefunction(inst._info)
28
- assert not inspect.iscoroutinefunction(inst.info)
29
- assert inst.info.__qualname__ == "AbstractFileSystem.info"
30
-
31
-
32
- def test_interrupt():
33
- loop = fsspec.asyn.get_loop()
34
-
35
- async def f():
36
- await asyncio.sleep(1000000)
37
- return True
38
-
39
- fut = asyncio.run_coroutine_threadsafe(f(), loop)
40
- time.sleep(0.01) # task launches
41
- out = fsspec.asyn._dump_running_tasks(with_task=True)
42
- task = out[0]["task"]
43
- assert task.done() and fut.done()
44
- assert isinstance(fut.exception(), fsspec.asyn.FSSpecCoroutineCancel)
45
-
46
-
47
- class _DummyAsyncKlass:
48
- def __init__(self):
49
- self.loop = fsspec.asyn.get_loop()
50
-
51
- async def _dummy_async_func(self):
52
- # Sleep 1 second function to test timeout
53
- await asyncio.sleep(1)
54
- return True
55
-
56
- async def _bad_multiple_sync(self):
57
- fsspec.asyn.sync_wrapper(_DummyAsyncKlass._dummy_async_func)(self)
58
- return True
59
-
60
- dummy_func = fsspec.asyn.sync_wrapper(_dummy_async_func)
61
- bad_multiple_sync_func = fsspec.asyn.sync_wrapper(_bad_multiple_sync)
62
-
63
-
64
- def test_sync_wrapper_timeout_on_less_than_expected_wait_time_not_finish_function():
65
- test_obj = _DummyAsyncKlass()
66
- with pytest.raises(fsspec.FSTimeoutError):
67
- test_obj.dummy_func(timeout=0.1)
68
-
69
-
70
- def test_sync_wrapper_timeout_on_more_than_expected_wait_time_will_finish_function():
71
- test_obj = _DummyAsyncKlass()
72
- assert test_obj.dummy_func(timeout=5)
73
-
74
-
75
- def test_sync_wrapper_timeout_none_will_wait_func_finished():
76
- test_obj = _DummyAsyncKlass()
77
- assert test_obj.dummy_func(timeout=None)
78
-
79
-
80
- def test_sync_wrapper_treat_timeout_0_as_none():
81
- test_obj = _DummyAsyncKlass()
82
- assert test_obj.dummy_func(timeout=0)
83
-
84
-
85
- def test_sync_wrapper_bad_multiple_sync():
86
- test_obj = _DummyAsyncKlass()
87
- with pytest.raises(NotImplementedError):
88
- test_obj.bad_multiple_sync_func(timeout=5)
89
-
90
-
91
- def test_run_coros_in_chunks(monkeypatch):
92
- total_running = 0
93
-
94
- async def runner():
95
- nonlocal total_running
96
-
97
- total_running += 1
98
- await asyncio.sleep(0)
99
- if total_running > 4:
100
- raise ValueError("More than 4 coroutines are running together")
101
- total_running -= 1
102
- return 1
103
-
104
- async def main(**kwargs):
105
- nonlocal total_running
106
-
107
- total_running = 0
108
- coros = [runner() for _ in range(32)]
109
- results = await _run_coros_in_chunks(coros, **kwargs)
110
- for result in results:
111
- if isinstance(result, Exception):
112
- raise result
113
- return results
114
-
115
- assert sum(asyncio.run(main(batch_size=4))) == 32
116
-
117
- with pytest.raises(ValueError):
118
- asyncio.run(main(batch_size=5))
119
-
120
- with pytest.raises(ValueError):
121
- asyncio.run(main(batch_size=-1))
122
-
123
- assert sum(asyncio.run(main(batch_size=4))) == 32
124
-
125
- monkeypatch.setitem(fsspec.config.conf, "gather_batch_size", 5)
126
- with pytest.raises(ValueError):
127
- asyncio.run(main())
128
- assert sum(asyncio.run(main(batch_size=4))) == 32 # override
129
-
130
- monkeypatch.setitem(fsspec.config.conf, "gather_batch_size", 4)
131
- assert sum(asyncio.run(main())) == 32 # override
132
-
133
-
134
- @pytest.mark.skipif(os.name != "nt", reason="only for windows")
135
- def test_windows_policy():
136
- from asyncio.windows_events import SelectorEventLoop
137
-
138
- loop = fsspec.asyn.get_loop()
139
- policy = asyncio.get_event_loop_policy()
140
-
141
- # Ensure that the created loop always uses selector policy
142
- assert isinstance(loop, SelectorEventLoop)
143
-
144
- # Ensure that the global policy is not changed and it is
145
- # set to the default one. This is important since the
146
- # get_loop() method will temporarily override the policy
147
- # with the one which uses selectors on windows, so this
148
- # check ensures that we are restoring the old policy back
149
- # after our change.
150
- assert isinstance(policy, asyncio.DefaultEventLoopPolicy)
151
-
152
-
153
- def test_running_async():
154
- assert not fsspec.asyn.running_async()
155
-
156
- async def go():
157
- assert fsspec.asyn.running_async()
158
-
159
- asyncio.run(go())
160
-
161
-
162
- class DummyAsyncFS(fsspec.asyn.AsyncFileSystem):
163
- _file_class = fsspec.asyn.AbstractAsyncStreamedFile
164
-
165
- async def _info(self, path, **kwargs):
166
- return {"name": "misc/foo.txt", "type": "file", "size": 100}
167
-
168
- async def open_async(
169
- self,
170
- path,
171
- mode="rb",
172
- block_size=None,
173
- autocommit=True,
174
- cache_options=None,
175
- **kwargs,
176
- ):
177
- return DummyAsyncStreamedFile(
178
- self,
179
- path,
180
- mode,
181
- block_size,
182
- autocommit,
183
- cache_options=cache_options,
184
- **kwargs,
185
- )
186
-
187
-
188
- class DummyAsyncStreamedFile(fsspec.asyn.AbstractAsyncStreamedFile):
189
- def __init__(self, fs, path, mode, block_size, autocommit, **kwargs):
190
- super().__init__(fs, path, mode, block_size, autocommit, **kwargs)
191
- self.temp_buffer = io.BytesIO(b"foo-bar" * 20)
192
-
193
- async def _fetch_range(self, start, end):
194
- return self.temp_buffer.read(end - start)
195
-
196
- async def _initiate_upload(self):
197
- # Reinitialize for new uploads.
198
- self.temp_buffer = io.BytesIO()
199
-
200
- async def _upload_chunk(self, final=False):
201
- self.temp_buffer.write(self.buffer.getbuffer())
202
-
203
- async def get_data(self):
204
- return self.temp_buffer.getbuffer().tobytes()
205
-
206
- async def get_data(self):
207
- return self.temp_buffer.getbuffer().tobytes()
208
-
209
-
210
- @pytest.mark.asyncio
211
- async def test_async_streamed_file_write():
212
- test_fs = DummyAsyncFS()
213
- streamed_file = await test_fs.open_async("misc/foo.txt", mode="wb")
214
- inp_data = "foo-bar".encode("utf8") * streamed_file.blocksize * 2
215
- await streamed_file.write(inp_data)
216
- assert streamed_file.loc == len(inp_data)
217
- await streamed_file.close()
218
- out_data = await streamed_file.get_data()
219
- assert out_data.count(b"foo-bar") == streamed_file.blocksize * 2
220
-
221
-
222
- @pytest.mark.asyncio
223
- async def test_async_streamed_file_read():
224
- test_fs = DummyAsyncFS()
225
- streamed_file = await test_fs.open_async("misc/foo.txt", mode="rb")
226
- assert (
227
- await streamed_file.read(7 * 3) + await streamed_file.read(7 * 18)
228
- == b"foo-bar" * 20
229
- )
230
- await streamed_file.close()
@@ -1,255 +0,0 @@
1
- import pickle
2
- import string
3
-
4
- import pytest
5
-
6
- from fsspec.caching import (
7
- BlockCache,
8
- FirstChunkCache,
9
- ReadAheadCache,
10
- caches,
11
- register_cache,
12
- )
13
- from fsspec.implementations.cached import WholeFileCacheFileSystem
14
-
15
-
16
- def test_cache_getitem(Cache_imp):
17
- cacher = Cache_imp(4, letters_fetcher, len(string.ascii_letters))
18
- assert cacher._fetch(0, 4) == b"abcd"
19
- assert cacher._fetch(None, 4) == b"abcd"
20
- assert cacher._fetch(2, 4) == b"cd"
21
-
22
-
23
- def test_block_cache_lru():
24
- # BlockCache is a cache that stores blocks of data and uses LRU to evict
25
- block_size = 4
26
- cache = BlockCache(
27
- block_size, letters_fetcher, len(string.ascii_letters), maxblocks=2
28
- )
29
- # miss
30
- cache._fetch(0, 2)
31
- assert cache.cache_info().misses == 1
32
- assert cache.cache_info().currsize == 1
33
- assert cache.total_requested_bytes == block_size * cache.miss_count
34
- assert cache.size == 52
35
-
36
- # hit
37
- cache._fetch(0, 2)
38
- assert cache.cache_info().misses == 1
39
- assert cache.cache_info().currsize == 1
40
- assert cache.total_requested_bytes == block_size * cache.miss_count
41
-
42
- # hit
43
- cache._fetch(0, 2)
44
- assert cache.cache_info().misses == 1
45
- assert cache.cache_info().currsize == 1
46
- # this works as a counter since all the reads are from the cache
47
- assert cache.hit_count == 3
48
- assert cache.miss_count == 1
49
- # so far only 4 bytes have been read using range requests
50
- assert cache.total_requested_bytes == block_size * cache.miss_count
51
-
52
- # miss
53
- cache._fetch(4, 6)
54
- assert cache.cache_info().misses == 2
55
- assert cache.cache_info().currsize == 2
56
- assert cache.total_requested_bytes == block_size * cache.miss_count
57
-
58
- # miss & evict
59
- cache._fetch(12, 13)
60
- assert cache.cache_info().misses == 3
61
- assert cache.cache_info().currsize == 2
62
- assert cache.hit_count == 5
63
- assert cache.miss_count == 3
64
- assert cache.total_requested_bytes == block_size * cache.miss_count
65
-
66
-
67
- def test_first_cache():
68
- """
69
- FirstChunkCache is a cache that only caches the first chunk of data
70
- when some of that first block is requested.
71
- """
72
- block_size = 5
73
- cache = FirstChunkCache(block_size, letters_fetcher, len(string.ascii_letters))
74
- assert cache.cache is None
75
- assert cache._fetch(12, 15) == letters_fetcher(12, 15)
76
- assert cache.miss_count == 1
77
- assert cache.hit_count == 0
78
- assert cache.cache is None
79
- total_requested_bytes = 15 - 12
80
- assert cache.total_requested_bytes == total_requested_bytes
81
-
82
- # because we overlap with the cache range, it will be cached
83
- assert cache._fetch(3, 10) == letters_fetcher(3, 10)
84
- assert cache.miss_count == 2
85
- assert cache.hit_count == 0
86
- # we'll read the first 5 and then the rest
87
- total_requested_bytes += block_size + 5
88
- assert cache.total_requested_bytes == total_requested_bytes
89
-
90
- # partial hit again
91
- assert cache._fetch(3, 10) == letters_fetcher(3, 10)
92
- assert cache.miss_count == 2
93
- assert cache.hit_count == 1
94
- # we have the first 5 bytes cached
95
- total_requested_bytes += 10 - 5
96
- assert cache.total_requested_bytes == total_requested_bytes
97
-
98
- assert cache.cache == letters_fetcher(0, 5)
99
- assert cache._fetch(0, 4) == letters_fetcher(0, 4)
100
- assert cache.hit_count == 2
101
- assert cache.miss_count == 2
102
- assert cache.total_requested_bytes == 18
103
-
104
-
105
- def test_readahead_cache():
106
- """
107
- ReadAheadCache is a cache that reads ahead of the requested range.
108
- If the access pattern is not sequential it will be very inefficient.
109
- """
110
- block_size = 5
111
- cache = ReadAheadCache(block_size, letters_fetcher, len(string.ascii_letters))
112
- assert cache._fetch(12, 15) == letters_fetcher(12, 15)
113
- assert cache.miss_count == 1
114
- assert cache.hit_count == 0
115
- total_requested_bytes = 15 - 12 + block_size
116
- assert cache.total_requested_bytes == total_requested_bytes
117
-
118
- assert cache._fetch(3, 10) == letters_fetcher(3, 10)
119
- assert cache.miss_count == 2
120
- assert cache.hit_count == 0
121
- assert len(cache.cache) == 12
122
- total_requested_bytes += (10 - 3) + block_size
123
- assert cache.total_requested_bytes == total_requested_bytes
124
-
125
- # caache hit again
126
- assert cache._fetch(3, 10) == letters_fetcher(3, 10)
127
- assert cache.miss_count == 2
128
- assert cache.hit_count == 1
129
- assert len(cache.cache) == 12
130
- assert cache.total_requested_bytes == total_requested_bytes
131
- assert cache.cache == letters_fetcher(3, 15)
132
-
133
- # cache miss
134
- assert cache._fetch(0, 4) == letters_fetcher(0, 4)
135
- assert cache.hit_count == 1
136
- assert cache.miss_count == 3
137
- assert len(cache.cache) == 9
138
- total_requested_bytes += (4 - 0) + block_size
139
- assert cache.total_requested_bytes == total_requested_bytes
140
-
141
-
142
- def _fetcher(start, end):
143
- return b"0" * (end - start)
144
-
145
-
146
- def letters_fetcher(start, end):
147
- return string.ascii_letters[start:end].encode()
148
-
149
-
150
- not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
151
-
152
-
153
- @pytest.fixture(params=not_parts_caches.values(), ids=list(not_parts_caches))
154
- def Cache_imp(request):
155
- return request.param
156
-
157
-
158
- def test_cache_empty_file(Cache_imp):
159
- blocksize = 5
160
- size = 0
161
- cache = Cache_imp(blocksize, _fetcher, size)
162
- assert cache._fetch(0, 0) == b""
163
-
164
-
165
- def test_cache_pickleable(Cache_imp):
166
- blocksize = 5
167
- size = 100
168
- cache = Cache_imp(blocksize, _fetcher, size)
169
- cache._fetch(0, 5) # fill in cache
170
- unpickled = pickle.loads(pickle.dumps(cache))
171
- assert isinstance(unpickled, Cache_imp)
172
- assert unpickled.blocksize == blocksize
173
- assert unpickled.size == size
174
- assert unpickled._fetch(0, 10) == b"0" * 10
175
-
176
-
177
- @pytest.mark.parametrize(
178
- "size_requests",
179
- [[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],
180
- )
181
- @pytest.mark.parametrize("blocksize", [1, 10, 52, 100])
182
- def test_cache_basic(Cache_imp, blocksize, size_requests):
183
- cache = Cache_imp(blocksize, letters_fetcher, len(string.ascii_letters))
184
-
185
- for start, end in size_requests:
186
- result = cache._fetch(start, end)
187
- expected = string.ascii_letters[start:end].encode()
188
- assert result == expected
189
-
190
-
191
- @pytest.mark.parametrize("strict", [True, False])
192
- @pytest.mark.parametrize("sort", [True, False])
193
- def test_known(sort, strict):
194
- parts = {(10, 20): b"1" * 10, (20, 30): b"2" * 10, (0, 10): b"0" * 10}
195
- if sort:
196
- parts = dict(sorted(parts.items()))
197
- c = caches["parts"](None, None, 100, parts, strict=strict)
198
- assert (0, 30) in c.data # got consolidated
199
- assert c._fetch(5, 15) == b"0" * 5 + b"1" * 5
200
- assert c._fetch(15, 25) == b"1" * 5 + b"2" * 5
201
- if strict:
202
- # Over-read will raise error
203
- with pytest.raises(ValueError):
204
- # tries to call None fetcher
205
- c._fetch(25, 35)
206
- else:
207
- # Over-read will be zero-padded
208
- assert c._fetch(25, 35) == b"2" * 5 + b"\x00" * 5
209
-
210
-
211
- def test_background(server, monkeypatch):
212
- import threading
213
- import time
214
-
215
- import fsspec
216
-
217
- head = {"head_ok": "true", "head_give_length": "true"}
218
- urla = server + "/index/realfile"
219
- h = fsspec.filesystem("http", headers=head)
220
- thread_ids = {threading.current_thread().ident}
221
- f = h.open(urla, block_size=5, cache_type="background")
222
- orig = f.cache._fetch_block
223
-
224
- def wrapped(*a, **kw):
225
- thread_ids.add(threading.current_thread().ident)
226
- return orig(*a, **kw)
227
-
228
- f.cache._fetch_block = wrapped
229
- assert len(thread_ids) == 1
230
- f.read(1)
231
- time.sleep(0.1) # second block is loading
232
- assert len(thread_ids) == 2
233
-
234
-
235
- def test_register_cache():
236
- # just test that we have them populated and fail to re-add again unless overload
237
- with pytest.raises(ValueError):
238
- register_cache(BlockCache)
239
- register_cache(BlockCache, clobber=True)
240
-
241
-
242
- def test_cache_kwargs(mocker):
243
- # test that kwargs are passed to the underlying filesystem after cache commit
244
-
245
- fs = WholeFileCacheFileSystem(target_protocol="memory")
246
- fs.touch("test")
247
- fs.fs.put = mocker.MagicMock()
248
-
249
- with fs.open("test", "wb", overwrite=True) as file_handle:
250
- file_handle.write(b"foo")
251
-
252
- # We don't care about the first parameter, just retrieve its expected value.
253
- # It is a random location that cannot be predicted.
254
- # The important thing is the 'overwrite' kwarg
255
- fs.fs.put.assert_called_with(fs.fs.put.call_args[0][0], "/test", overwrite=True)
@@ -1,89 +0,0 @@
1
- import pytest
2
-
3
- from fsspec.callbacks import Callback, TqdmCallback
4
-
5
-
6
- def test_callbacks():
7
- empty_callback = Callback()
8
- assert empty_callback.call("something", somearg=None) is None
9
-
10
- hooks = {"something": lambda *_, arg=None: arg + 2}
11
- simple_callback = Callback(hooks=hooks)
12
- assert simple_callback.call("something", arg=2) == 4
13
-
14
- hooks = {"something": lambda *_, arg1=None, arg2=None: arg1 + arg2}
15
- multi_arg_callback = Callback(hooks=hooks)
16
- assert multi_arg_callback.call("something", arg1=2, arg2=2) == 4
17
-
18
-
19
- def test_callbacks_as_callback():
20
- empty_callback = Callback.as_callback(None)
21
- assert empty_callback.call("something", arg="somearg") is None
22
- assert Callback.as_callback(None) is Callback.as_callback(None)
23
-
24
- hooks = {"something": lambda *_, arg=None: arg + 2}
25
- real_callback = Callback.as_callback(Callback(hooks=hooks))
26
- assert real_callback.call("something", arg=2) == 4
27
-
28
-
29
- def test_callbacks_as_context_manager(mocker):
30
- spy_close = mocker.spy(Callback, "close")
31
-
32
- with Callback() as cb:
33
- assert isinstance(cb, Callback)
34
-
35
- spy_close.assert_called_once()
36
-
37
-
38
- def test_callbacks_branched():
39
- callback = Callback()
40
-
41
- branch = callback.branched("path_1", "path_2")
42
-
43
- assert branch is not callback
44
- assert isinstance(branch, Callback)
45
-
46
-
47
- @pytest.mark.asyncio
48
- async def test_callbacks_branch_coro(mocker):
49
- async_fn = mocker.AsyncMock(return_value=10)
50
- callback = Callback()
51
- wrapped_fn = callback.branch_coro(async_fn)
52
- spy = mocker.spy(callback, "branched")
53
-
54
- assert await wrapped_fn("path_1", "path_2", key="value") == 10
55
-
56
- spy.assert_called_once_with("path_1", "path_2", key="value")
57
- async_fn.assert_called_once_with(
58
- "path_1", "path_2", callback=spy.spy_return, key="value"
59
- )
60
-
61
-
62
- def test_callbacks_wrap():
63
- events = []
64
-
65
- class TestCallback(Callback):
66
- def relative_update(self, inc=1):
67
- events.append(inc)
68
-
69
- callback = TestCallback()
70
- for _ in callback.wrap(range(10)):
71
- ...
72
-
73
- assert events == [1] * 10
74
-
75
-
76
- @pytest.mark.parametrize("tqdm_kwargs", [{}, {"desc": "A custom desc"}])
77
- def test_tqdm_callback(tqdm_kwargs, mocker):
78
- pytest.importorskip("tqdm")
79
- callback = TqdmCallback(tqdm_kwargs=tqdm_kwargs)
80
- mocker.patch.object(callback, "_tqdm_cls")
81
- callback.set_size(10)
82
- for _ in callback.wrap(range(10)):
83
- ...
84
-
85
- assert callback.tqdm.update.call_count == 11
86
- if not tqdm_kwargs:
87
- callback._tqdm_cls.assert_called_with(total=10)
88
- else:
89
- callback._tqdm_cls.assert_called_with(total=10, **tqdm_kwargs)