akshare-cli 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,111 @@
1
+ """Tests for the in-memory TTL result cache."""
2
+
3
+ import time
4
+
5
+ import pandas as pd
6
+ import pytest
7
+
8
+ from akshare_cli.core.cache import (
9
+ DEFAULT_TTL,
10
+ NO_CACHE,
11
+ ResultCache,
12
+ _result_cache,
13
+ get_ttl,
14
+ )
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # get_ttl()
19
+ # ---------------------------------------------------------------------------
20
+
21
class TestGetTtl:
    """Checks on the TTL value that ``get_ttl`` reports for a function name."""

    def test_no_cache_function_returns_zero(self):
        """Both listed function names must resolve to a TTL of exactly 0."""
        uncached_names = ("stock_zh_a_spot_em", "forex_spot_em")
        assert all(get_ttl(func_name) == 0 for func_name in uncached_names)

    def test_custom_ttl(self):
        """``futures_hist_table_em`` must report a TTL of 86400."""
        expected_ttl = 86400
        assert get_ttl("futures_hist_table_em") == expected_ttl

    def test_default_ttl(self):
        """A name with no specific entry must report ``DEFAULT_TTL``."""
        reported = get_ttl("some_unknown_func")
        assert reported == DEFAULT_TTL
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # ResultCache
35
+ # ---------------------------------------------------------------------------
36
+
37
class TestResultCache:
    """Exercise a ``ResultCache`` instance that is created afresh per test."""

    def setup_method(self):
        # A new instance per test keeps entries and hit/miss counters isolated.
        self.cache = ResultCache()

    def test_miss_returns_none(self):
        """Looking up a key that was never stored yields None and one miss."""
        looked_up = self.cache.get("foo", {"a": 1})
        assert looked_up is None
        assert self.cache.misses == 1

    def test_put_then_hit(self):
        """A stored DataFrame is returned for the same name and kwargs."""
        frame = pd.DataFrame({"x": [1, 2, 3]})
        call_kwargs = {"a": 1}
        self.cache.put("foo", call_kwargs, frame, ttl=60)

        # Look up with an equal but distinct dict, as the original test does.
        cached = self.cache.get("foo", {"a": 1})
        assert cached is not None
        assert list(cached["x"]) == [1, 2, 3]
        assert self.cache.hits == 1

    def test_different_kwargs_miss(self):
        """The same function name with different kwargs must not hit."""
        self.cache.put("foo", {"a": 1}, "data1", ttl=60)
        other_lookup = self.cache.get("foo", {"a": 2})
        assert other_lookup is None

    def test_ttl_expiry(self):
        """An entry stored with ttl=1 is gone after sleeping 1.1 seconds."""
        self.cache.put("foo", {}, "data", ttl=1)
        fresh = self.cache.get("foo", {})
        assert fresh == "data"

        time.sleep(1.1)  # wait just past the one-second TTL
        stale = self.cache.get("foo", {})
        assert stale is None

    def test_disabled_cache_never_hits(self):
        """With ``enabled`` set to False, a put followed by a get yields None."""
        self.cache.enabled = False
        self.cache.put("foo", {}, "data", ttl=60)
        outcome = self.cache.get("foo", {})
        assert outcome is None

    def test_zero_ttl_not_stored(self):
        """A put with ttl=0 must not make the value retrievable."""
        self.cache.put("foo", {}, "data", ttl=0)
        outcome = self.cache.get("foo", {})
        assert outcome is None

    def test_clear(self):
        """``clear`` drops the entries and resets the counters."""
        for key, value in (("a", 1), ("b", 2)):
            self.cache.put(key, {}, value, ttl=60)
        self.cache.get("a", {})  # registers a hit before the clear

        self.cache.clear()

        after_clear = self.cache.get("a", {})
        assert after_clear is None
        assert self.cache.hits == 0
        assert self.cache.misses == 1  # only the lookup made after clear()

    def test_stats(self):
        """One entry, one hit and one miss are reported as a 50.0% hit rate."""
        self.cache.put("x", {}, "d", ttl=60)
        self.cache.get("x", {})  # hit
        self.cache.get("y", {})  # miss

        report = self.cache.stats()
        assert report["entries"] == 1
        assert report["hits"] == 1
        assert report["misses"] == 1
        assert report["hit_rate"] == "50.0%"

    def test_stats_no_requests(self):
        """Without any lookups the hit rate is reported as "N/A"."""
        report = self.cache.stats()
        assert report["hit_rate"] == "N/A"
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # Global singleton reset safety
98
+ # ---------------------------------------------------------------------------
99
+
100
class TestGlobalCache:
    """Round-trip check against the module-level ``_result_cache`` object."""

    @staticmethod
    def _reset_global_cache():
        """Clear the shared cache and switch it back on."""
        _result_cache.clear()
        _result_cache.enabled = True

    def setup_method(self):
        # Start from a clean, enabled cache.
        self._reset_global_cache()

    def teardown_method(self):
        # Restore the same clean, enabled state once the test has run.
        self._reset_global_cache()

    def test_global_put_get(self):
        """A value put into the shared cache is returned for an equal key."""
        payload = [1, 2]
        _result_cache.put("test_fn", {"k": "v"}, payload, ttl=60)
        fetched = _result_cache.get("test_fn", {"k": "v"})
        assert fetched == [1, 2]
@@ -0,0 +1,439 @@
1
+ """
2
+ Unit tests for AKShare CLI harness core modules.
3
+
4
+ Uses synthetic data and mocks — no network calls required.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import tempfile
10
+
11
+ import pandas as pd
12
+ import pytest
13
+
14
+ from akshare_cli.core.registry import (
15
+ _auto_fill_date_params,
16
+ call_function,
17
+ get_all_functions,
18
+ get_domains,
19
+ get_function,
20
+ get_function_info,
21
+ list_functions_by_domain,
22
+ parse_param_value,
23
+ search_functions,
24
+ )
25
+ from akshare_cli.core.session import Session
26
+ from akshare_cli.core.export import (
27
+ auto_export,
28
+ export_csv,
29
+ export_json,
30
+ export_excel,
31
+ export_markdown,
32
+ format_csv,
33
+ format_json,
34
+ format_table,
35
+ )
36
+ from akshare_cli.utils.formatting import (
37
+ display_result,
38
+ truncate_string,
39
+ )
40
+
41
+
42
+ # ─── Fixtures ────────────────────────────────────────────────────────────────
43
+
44
@pytest.fixture
def sample_df():
    """Return a three-row DataFrame with date, open, close and volume columns."""
    columns = {
        "date": ["2024-01-01", "2024-01-02", "2024-01-03"],
        "open": [10.0, 10.5, 11.0],
        "close": [10.5, 11.0, 10.8],
        "volume": [1000, 1500, 1200],
    }
    return pd.DataFrame(columns)
53
+
54
+
55
@pytest.fixture
def large_df():
    """Return a 200-row DataFrame (``idx`` and ``value``) for truncation tests."""
    row_count = 200
    indices = list(range(row_count))
    scaled = [float(i) * 1.1 for i in indices]
    return pd.DataFrame({"idx": indices, "value": scaled})
62
+
63
+
64
@pytest.fixture
def tmp_dir():
    """Yield the path of a temporary directory that is removed afterwards."""
    scratch = tempfile.TemporaryDirectory()
    with scratch as dir_path:
        yield dir_path
69
+
70
+
71
@pytest.fixture
def session():
    """Yield a new ``Session`` whose ``session_dir`` is a temporary directory."""
    scratch = tempfile.TemporaryDirectory()
    with scratch as dir_path:
        # The directory stays alive until the test using the session finishes.
        yield Session(session_dir=dir_path)
76
+
77
+
78
+ # ─── Registry Tests ─────────────────────────────────────────────────────────
79
+
80
+
81
class TestRegistry:
    """Tests for the function-registry helpers: lookup, search and parsing."""

    def test_get_all_functions(self):
        """The registry is a dict holding more than 500 entries."""
        funcs = get_all_functions()
        assert isinstance(funcs, dict)
        assert len(funcs) > 500, f"Expected 500+ functions, got {len(funcs)}"

    def test_get_function_existing(self):
        """A known name resolves to a callable."""
        resolved = get_function("stock_zh_a_hist")
        assert resolved is not None
        assert callable(resolved)

    def test_get_function_missing(self):
        """An unknown name resolves to None."""
        resolved = get_function("totally_nonexistent_function_xyz")
        assert resolved is None

    def test_search_functions(self):
        """Every hit for "stock_zh" contains that text, ignoring case."""
        needle = "stock_zh"
        matches = search_functions(needle)
        assert len(matches) > 0
        assert all(needle in name.lower() for name in matches)

    def test_search_functions_case_insensitive(self):
        """Lower-case and upper-case queries return equal results."""
        from_lower = search_functions("stock_zh")
        from_upper = search_functions("STOCK_ZH")
        assert from_lower == from_upper

    def test_search_functions_no_match(self):
        """A query that matches nothing returns an empty list."""
        matches = search_functions("zzzznonexistent999")
        assert matches == []

    def test_list_functions_by_domain(self):
        """Every function listed for "stock" starts with "stock_"."""
        listed = list_functions_by_domain("stock")
        assert len(listed) > 0
        assert all(name.startswith("stock_") for name in listed)

    def test_list_functions_by_domain_empty(self):
        """An unknown domain returns an empty list."""
        listed = list_functions_by_domain("zzzznonexistent")
        assert listed == []

    def test_get_function_info(self):
        """Info for a known function carries its name, params and docstring."""
        target = "stock_zh_a_hist"
        info = get_function_info(target)
        assert info is not None
        assert info["name"] == "stock_zh_a_hist"

        params = info["params"]
        assert isinstance(params, list)
        assert len(params) > 0
        assert isinstance(info["docstring"], str)

        # Check known parameters
        param_names = [param["name"] for param in params]
        for expected in ("symbol", "period"):
            assert expected in param_names

    def test_get_function_info_missing(self):
        """Info for an unknown function is None."""
        info = get_function_info("totally_nonexistent_xyz")
        assert info is None

    def test_get_domains(self):
        """The domain list has more than 10 entries, including three known ones."""
        domains = get_domains()
        assert isinstance(domains, list)
        assert len(domains) > 10
        for expected in ("stock", "fund", "futures"):
            assert expected in domains

    def test_parse_param_value_string(self):
        """Plain text is returned unchanged."""
        parsed = parse_param_value("hello", {})
        assert parsed == "hello"

    def test_parse_param_value_int(self):
        """Digit-only text is parsed to an equal integer value."""
        parsed = parse_param_value("42", {})
        assert parsed == 42

    def test_parse_param_value_float(self):
        """Decimal text is parsed to an equal float value."""
        parsed = parse_param_value("3.14", {})
        assert parsed == 3.14

    def test_parse_param_value_bool_true(self):
        """Each accepted spelling of "true" parses to the True singleton."""
        for raw in ("true", "True", "yes"):
            assert parse_param_value(raw, {}) is True

    def test_parse_param_value_bool_false(self):
        """Each accepted spelling of "false" parses to the False singleton."""
        for raw in ("false", "False", "no"):
            assert parse_param_value(raw, {}) is False

    def test_parse_param_value_none(self):
        """Each accepted spelling of "none" parses to None."""
        for raw in ("none", "None", "null"):
            assert parse_param_value(raw, {}) is None

    def test_call_function_not_found(self):
        """Calling an unknown function raises ValueError matching "not found"."""
        missing = "totally_nonexistent_function_xyz"
        with pytest.raises(ValueError, match="not found"):
            call_function(missing, {})
171
+
172
+
173
+ # ─── Session Tests ───────────────────────────────────────────────────────────
174
+
175
+
176
class TestSession:
    """Tests for ``Session``: call recording, preferences and persistence."""

    def test_session_init(self, session):
        """A new session has no result, no function and an empty history."""
        assert session.last_result is None
        assert session.last_function is None
        assert session.history == []
        assert isinstance(session.preferences, dict)

    def test_record_call(self, session, sample_df):
        """Recording a call updates the last-call fields and the history."""
        session.record_call("test_func", {"key": "val"}, sample_df)

        assert session.last_function == "test_func"
        assert session.last_result is not None
        assert len(session.history) == 1
        assert session.history[0]["function"] == "test_func"
        assert session.history[0]["rows"] == 3

    def test_last_result(self, session, sample_df):
        """``last_result`` equals the DataFrame that was recorded."""
        session.record_call("func1", {}, sample_df)
        stored = session.last_result
        assert stored.equals(sample_df)

    def test_history_ordering(self, session, sample_df):
        """History entries are kept in the order the calls were recorded."""
        for func_name in ("func1", "func2", "func3"):
            session.record_call(func_name, {}, sample_df)

        assert len(session.history) == 3
        assert session.history[0]["function"] == "func1"
        assert session.history[2]["function"] == "func3"

    def test_preferences_get_set(self, session):
        """A preference that was set can be read back."""
        session.set_preference("max_rows", 100)
        stored = session.get_preference("max_rows")
        assert stored == 100

    def test_preferences_default(self, session):
        """An unset preference falls back to the supplied default."""
        fallback = session.get_preference("nonexistent", "default")
        assert fallback == "default"

    def test_save_history(self, session, sample_df, tmp_dir):
        """Saved history is a JSON file containing the recorded call."""
        session.record_call("test_func", {"a": "b"}, sample_df)
        target = os.path.join(tmp_dir, "test_history.json")

        saved = session.save_history(target)
        assert os.path.exists(saved)

        with open(saved, encoding="utf-8") as handle:
            payload = json.load(handle)
        assert "history" in payload
        assert len(payload["history"]) == 1
        assert payload["history"][0]["function"] == "test_func"

    def test_clear(self, session, sample_df):
        """``clear`` resets the last-call fields and empties the history."""
        session.record_call("func1", {}, sample_df)
        session.clear()

        assert session.last_result is None
        assert session.last_function is None
        assert session.history == []

    def test_summary(self, session, sample_df):
        """The summary reports call count, last function and result shape."""
        session.record_call("func1", {}, sample_df)
        overview = session.summary()

        assert overview["total_calls"] == 1
        assert overview["last_function"] == "func1"
        assert "3x4" in overview["last_result_shape"]
235
+
236
+
237
+ # ─── Export Tests ────────────────────────────────────────────────────────────
238
+
239
+
240
class TestExport:
    """Tests for the file exporters and the in-memory formatters."""

    def test_export_csv(self, sample_df, tmp_dir):
        """The CSV export reloads with three rows and the original columns."""
        target = os.path.join(tmp_dir, "test.csv")
        written = export_csv(sample_df, target)
        assert os.path.exists(written)

        reloaded = pd.read_csv(written)
        assert len(reloaded) == 3
        assert list(reloaded.columns) == ["date", "open", "close", "volume"]

    def test_export_json(self, sample_df, tmp_dir):
        """The JSON export is a list with one item per row."""
        target = os.path.join(tmp_dir, "test.json")
        written = export_json(sample_df, target)
        assert os.path.exists(written)

        with open(written, encoding="utf-8") as handle:
            records = json.load(handle)
        assert isinstance(records, list)
        assert len(records) == 3

    def test_export_excel(self, sample_df, tmp_dir):
        """The Excel export reloads with three rows."""
        target = os.path.join(tmp_dir, "test.xlsx")
        written = export_excel(sample_df, target)
        assert os.path.exists(written)

        reloaded = pd.read_excel(written)
        assert len(reloaded) == 3

    def test_export_markdown(self, sample_df, tmp_dir):
        """The Markdown export contains a column name and a pipe character."""
        target = os.path.join(tmp_dir, "test.md")
        written = export_markdown(sample_df, target)
        assert os.path.exists(written)

        with open(written, encoding="utf-8") as handle:
            text = handle.read()
        assert "date" in text
        assert "|" in text

    def test_format_table(self, sample_df):
        """The table text shows a column name, a date and a price value."""
        rendered = format_table(sample_df)
        for fragment in ("date", "2024-01-01", "10.0"):
            assert fragment in rendered

    def test_format_table_truncated(self, large_df):
        """A truncated table still states the total row count."""
        rendered = format_table(large_df, max_rows=20)
        assert "200 rows total" in rendered

    def test_format_json(self, sample_df):
        """The JSON text decodes to total_rows, columns and data."""
        decoded = json.loads(format_json(sample_df))
        assert decoded["total_rows"] == 3
        assert "columns" in decoded
        assert "data" in decoded

    def test_format_json_limited(self, large_df):
        """``max_rows`` caps the number of records under "data"."""
        decoded = json.loads(format_json(large_df, max_rows=5))
        assert len(decoded["data"]) == 5

    def test_format_csv(self, sample_df):
        """The CSV text has a header line followed by three data lines."""
        rendered = format_csv(sample_df)
        lines = rendered.strip().split("\n")
        assert len(lines) == 4  # header + 3 rows
        header = lines[0]
        assert "date" in header

    def test_auto_export_csv(self, sample_df, tmp_dir):
        """A ``.csv`` path produces a file on disk."""
        target = os.path.join(tmp_dir, "auto.csv")
        written = auto_export(sample_df, target)
        assert os.path.exists(written)

    def test_auto_export_json(self, sample_df, tmp_dir):
        """A ``.json`` path produces a file on disk."""
        target = os.path.join(tmp_dir, "auto.json")
        written = auto_export(sample_df, target)
        assert os.path.exists(written)

    def test_auto_export_excel(self, sample_df, tmp_dir):
        """A ``.xlsx`` path produces a file on disk."""
        target = os.path.join(tmp_dir, "auto.xlsx")
        written = auto_export(sample_df, target)
        assert os.path.exists(written)

    def test_auto_export_unsupported(self, sample_df, tmp_dir):
        """An unrecognised extension raises ValueError matching "Unsupported"."""
        target = os.path.join(tmp_dir, "auto.xyz")
        with pytest.raises(ValueError, match="Unsupported"):
            auto_export(sample_df, target)
321
+
322
+
323
+ # ─── Formatting Tests ────────────────────────────────────────────────────────
324
+
325
+
326
class TestFormatting:
    """Tests for ``display_result`` and ``truncate_string``."""

    def test_display_result_table(self, sample_df):
        """Table output shows a column name and a date value."""
        rendered = display_result(sample_df, output_format="table")
        for fragment in ("date", "2024-01-01"):
            assert fragment in rendered

    def test_display_result_json(self, sample_df):
        """JSON output decodes to a mapping reporting three total rows."""
        rendered = display_result(sample_df, output_format="json")
        decoded = json.loads(rendered)
        assert decoded["total_rows"] == 3

    def test_display_result_csv(self, sample_df):
        """CSV output contains the comma-joined header line."""
        rendered = display_result(sample_df, output_format="csv")
        assert "date,open,close,volume" in rendered

    def test_display_result_with_limit(self, large_df):
        """Row-limited table output still states the total row count."""
        rendered = display_result(large_df, output_format="table", max_rows=10)
        assert "200 rows total" in rendered

    def test_truncate_string_short(self):
        """Text shorter than the limit is returned unchanged."""
        shortened = truncate_string("hello", 10)
        assert shortened == "hello"

    def test_truncate_string_long(self):
        """Over-long text is cut to the limit and ends with an ellipsis."""
        shortened = truncate_string("a" * 100, 20)
        assert len(shortened) == 20
        assert shortened.endswith("...")

    def test_truncate_string_exact(self):
        """Text exactly at the limit is returned unchanged."""
        shortened = truncate_string("hello", 5)
        assert shortened == "hello"
355
+
356
+
357
+ # ─── Auto-fill Date Tests ────────────────────────────────────────────────────
358
+
359
+
360
class TestAutoFillDate:
    """Tests for _auto_fill_date_params.

    Each test builds parameter descriptors with ``_make_param``, passes them
    with a ``kwargs`` dict to ``_auto_fill_date_params`` and then inspects
    ``kwargs``; the return value of the function is never used.
    """

    @staticmethod
    def _today():
        """Return the current local date formatted as ``YYYYMMDD``."""
        from datetime import date

        return date.today().strftime("%Y%m%d")

    def _make_param(self, name, default=None, has_default=True):
        """Build a parameter descriptor dict.

        Args:
            name: Value stored under ``"name"``.
            default: Value stored under ``"default"``. The key is omitted
                when this is ``None`` or when ``has_default`` is false.
            has_default: Value stored under ``"has_default"``.

        Returns:
            The descriptor dict.
        """
        p = {"name": name, "has_default": has_default}
        if has_default and default is not None:
            p["default"] = default
        return p

    def _fill(self, params, kwargs):
        """Run the auto-fill and return the acceptable "today" strings.

        The date is sampled both before and after the call so that a test
        run which straddles midnight does not fail spuriously: whichever day
        the function observed is one of the two sampled values.
        """
        before = self._today()
        _auto_fill_date_params(params, kwargs)
        return {before, self._today()}

    def test_fills_date_param_with_today(self):
        """A ``date`` param with a YYYYMMDD default is filled with today."""
        params = [self._make_param("date", "20251126")]
        kwargs = {}
        acceptable = self._fill(params, kwargs)
        assert kwargs["date"] in acceptable

    def test_fills_start_date_param(self):
        """A ``start_date`` param with a YYYYMMDD default is filled with today."""
        params = [self._make_param("start_date", "20240101")]
        kwargs = {}
        acceptable = self._fill(params, kwargs)
        assert kwargs["start_date"] in acceptable

    def test_fills_end_date_param(self):
        """An ``end_date`` param with a YYYYMMDD default is filled with today."""
        params = [self._make_param("end_date", "20240301")]
        kwargs = {}
        acceptable = self._fill(params, kwargs)
        assert kwargs["end_date"] in acceptable

    def test_skips_user_provided_date(self):
        """A value already present in ``kwargs`` is left untouched."""
        params = [self._make_param("date", "20251126")]
        kwargs = {"date": "20260101"}
        _auto_fill_date_params(params, kwargs)
        assert kwargs["date"] == "20260101"  # unchanged

    def test_skips_non_date_param(self):
        """A param named ``symbol`` is not added to ``kwargs``."""
        params = [self._make_param("symbol", "000001")]
        kwargs = {}
        _auto_fill_date_params(params, kwargs)
        assert "symbol" not in kwargs

    def test_skips_non_yyyymmdd_default(self):
        """A dashed default such as ``2024-01-01`` is not auto-filled."""
        params = [self._make_param("date", "2024-01-01")]
        kwargs = {}
        _auto_fill_date_params(params, kwargs)
        assert "date" not in kwargs

    def test_skips_no_default(self):
        """A param whose ``has_default`` is False is not auto-filled."""
        params = [{"name": "date", "has_default": False}]
        kwargs = {}
        _auto_fill_date_params(params, kwargs)
        assert "date" not in kwargs

    def test_skips_none_default(self):
        """A param built with ``default=None`` is not auto-filled."""
        # _make_param omits the "default" key entirely for None, so the
        # descriptor has has_default=True but no "default" entry.
        params = [self._make_param("date", None)]
        kwargs = {}
        _auto_fill_date_params(params, kwargs)
        assert "date" not in kwargs

    def test_skips_int_default(self):
        """An integer default (not a string) is not auto-filled."""
        params = [self._make_param("date", 20240101)]
        kwargs = {}
        _auto_fill_date_params(params, kwargs)
        assert "date" not in kwargs

    def test_multiple_date_params(self):
        """Only the date param missing from ``kwargs`` is auto-filled."""
        params = [
            self._make_param("start_date", "20240101"),
            self._make_param("end_date", "20240301"),
            self._make_param("symbol", "000001"),
        ]
        kwargs = {"start_date": "20260101"}  # user provided start_date only
        acceptable = self._fill(params, kwargs)
        assert kwargs["start_date"] == "20260101"  # user value kept
        assert kwargs["end_date"] in acceptable  # auto-filled
        assert "symbol" not in kwargs  # not a date param