kabukit 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kabukit/__init__.py +4 -1
- kabukit/analysis/__init__.py +0 -0
- kabukit/analysis/indicators.py +0 -0
- kabukit/analysis/preprocess.py +0 -0
- kabukit/analysis/screener.py +0 -0
- kabukit/analysis/visualization.py +57 -0
- kabukit/cli/auth.py +2 -2
- kabukit/core/__init__.py +0 -0
- kabukit/core/base.py +45 -0
- kabukit/core/client.py +25 -0
- kabukit/core/info.py +12 -0
- kabukit/core/prices.py +30 -0
- kabukit/core/statements.py +7 -0
- kabukit/edinet/__init__.py +3 -0
- kabukit/edinet/client.py +14 -12
- kabukit/edinet/concurrent.py +153 -0
- kabukit/edinet/doc.py +32 -0
- kabukit/jquants/__init__.py +1 -1
- kabukit/jquants/client.py +63 -101
- kabukit/jquants/{stream.py → concurrent.py} +12 -43
- kabukit/jquants/info.py +9 -1
- kabukit/jquants/prices.py +29 -0
- kabukit/jquants/schema.py +19 -8
- kabukit/jquants/statements.py +35 -2
- kabukit/utils/__init__.py +0 -0
- kabukit/utils/concurrent.py +148 -0
- kabukit-0.2.1.dist-info/METADATA +67 -0
- kabukit-0.2.1.dist-info/RECORD +35 -0
- {kabukit-0.1.1.dist-info → kabukit-0.2.1.dist-info}/WHEEL +1 -1
- kabukit/concurrent.py +0 -40
- kabukit-0.1.1.dist-info/METADATA +0 -30
- kabukit-0.1.1.dist-info/RECORD +0 -21
- /kabukit/{config.py → utils/config.py} +0 -0
- /kabukit/{params.py → utils/params.py} +0 -0
- {kabukit-0.1.1.dist-info → kabukit-0.2.1.dist-info}/entry_points.txt +0 -0
kabukit/jquants/client.py
CHANGED
@@ -6,13 +6,13 @@ from enum import StrEnum
|
|
6
6
|
from typing import TYPE_CHECKING
|
7
7
|
|
8
8
|
import polars as pl
|
9
|
-
from httpx import AsyncClient
|
10
9
|
from polars import DataFrame
|
11
10
|
|
12
|
-
from kabukit.
|
13
|
-
from kabukit.
|
11
|
+
from kabukit.core.client import Client
|
12
|
+
from kabukit.utils.config import load_dotenv, set_key
|
13
|
+
from kabukit.utils.params import get_params
|
14
14
|
|
15
|
-
from . import statements
|
15
|
+
from . import info, prices, statements
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
from collections.abc import AsyncIterator
|
@@ -32,7 +32,7 @@ class AuthKey(StrEnum):
|
|
32
32
|
ID_TOKEN = "JQUANTS_ID_TOKEN" # noqa: S105
|
33
33
|
|
34
34
|
|
35
|
-
class JQuantsClient:
|
35
|
+
class JQuantsClient(Client):
|
36
36
|
"""J-Quants APIと対話するためのクライアント。
|
37
37
|
|
38
38
|
API認証トークン(リフレッシュトークンおよびIDトークン)を管理し、
|
@@ -43,10 +43,8 @@ class JQuantsClient:
|
|
43
43
|
client: APIリクエストを行うための `AsyncClient` インスタンス。
|
44
44
|
"""
|
45
45
|
|
46
|
-
client: AsyncClient
|
47
|
-
|
48
46
|
def __init__(self, id_token: str | None = None) -> None:
|
49
|
-
|
47
|
+
super().__init__(BASE_URL)
|
50
48
|
self.set_id_token(id_token)
|
51
49
|
|
52
50
|
def set_id_token(self, id_token: str | None = None) -> None:
|
@@ -64,43 +62,6 @@ class JQuantsClient:
|
|
64
62
|
if id_token:
|
65
63
|
self.client.headers["Authorization"] = f"Bearer {id_token}"
|
66
64
|
|
67
|
-
async def aclose(self) -> None:
|
68
|
-
"""HTTPクライアントを閉じる。"""
|
69
|
-
await self.client.aclose()
|
70
|
-
|
71
|
-
async def __aenter__(self) -> Self:
|
72
|
-
return self
|
73
|
-
|
74
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # pyright: ignore[reportMissingParameterType, reportUnknownParameterType] # noqa: ANN001
|
75
|
-
await self.aclose()
|
76
|
-
|
77
|
-
async def auth(
|
78
|
-
self,
|
79
|
-
mailaddress: str,
|
80
|
-
password: str,
|
81
|
-
*,
|
82
|
-
save: bool = False,
|
83
|
-
) -> Self:
|
84
|
-
"""認証を行い、トークンを保存する。
|
85
|
-
|
86
|
-
Args:
|
87
|
-
mailaddress (str): J-Quantsに登録したメールアドレス。
|
88
|
-
password (str): J-Quantsのパスワード。
|
89
|
-
save (bool, optional): トークンを環境変数に保存するかどうか。
|
90
|
-
|
91
|
-
Raises:
|
92
|
-
HTTPStatusError: APIリクエストが失敗した場合。
|
93
|
-
"""
|
94
|
-
refresh_token = await self.get_refresh_token(mailaddress, password)
|
95
|
-
id_token = await self.get_id_token(refresh_token)
|
96
|
-
|
97
|
-
if save:
|
98
|
-
set_key(AuthKey.REFRESH_TOKEN, refresh_token)
|
99
|
-
set_key(AuthKey.ID_TOKEN, id_token)
|
100
|
-
|
101
|
-
self.set_id_token(id_token)
|
102
|
-
return self
|
103
|
-
|
104
65
|
async def post(self, url: str, json: Any | None = None) -> Any:
|
105
66
|
"""指定されたURLにPOSTリクエストを送信する。
|
106
67
|
|
@@ -118,55 +79,54 @@ class JQuantsClient:
|
|
118
79
|
resp.raise_for_status()
|
119
80
|
return resp.json()
|
120
81
|
|
121
|
-
async def
|
122
|
-
"""
|
82
|
+
async def get(self, url: str, params: QueryParamTypes | None = None) -> Any:
|
83
|
+
"""指定されたURLにGETリクエストを送信する。
|
123
84
|
|
124
85
|
Args:
|
125
|
-
|
126
|
-
|
86
|
+
url (str): GETリクエストのURLパス。
|
87
|
+
params (QueryParamTypes | None, optional): リクエストのクエリパラメータ。
|
127
88
|
|
128
89
|
Returns:
|
129
|
-
|
90
|
+
APIからのJSONレスポンス。
|
130
91
|
|
131
92
|
Raises:
|
132
93
|
HTTPStatusError: APIリクエストが失敗した場合。
|
133
94
|
"""
|
134
|
-
|
135
|
-
|
136
|
-
return
|
95
|
+
resp = await self.client.get(url, params=params)
|
96
|
+
resp.raise_for_status()
|
97
|
+
return resp.json()
|
137
98
|
|
138
|
-
async def
|
139
|
-
|
99
|
+
async def auth(
|
100
|
+
self,
|
101
|
+
mailaddress: str,
|
102
|
+
password: str,
|
103
|
+
*,
|
104
|
+
save: bool = False,
|
105
|
+
) -> Self:
|
106
|
+
"""認証を行い、トークンを保存する。
|
140
107
|
|
141
108
|
Args:
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
新しいIDトークン。
|
109
|
+
mailaddress (str): J-Quantsに登録したメールアドレス。
|
110
|
+
password (str): J-Quantsのパスワード。
|
111
|
+
save (bool, optional): トークンを環境変数に保存するかどうか。
|
146
112
|
|
147
113
|
Raises:
|
148
114
|
HTTPStatusError: APIリクエストが失敗した場合。
|
149
115
|
"""
|
116
|
+
json_data = {"mailaddress": mailaddress, "password": password}
|
117
|
+
data = await self.post("/token/auth_user", json=json_data)
|
118
|
+
refresh_token = data["refreshToken"]
|
119
|
+
|
150
120
|
url = f"/token/auth_refresh?refreshtoken={refresh_token}"
|
151
121
|
data = await self.post(url)
|
152
|
-
|
153
|
-
|
154
|
-
async def get(self, url: str, params: QueryParamTypes | None = None) -> Any:
|
155
|
-
"""指定されたURLにGETリクエストを送信する。
|
156
|
-
|
157
|
-
Args:
|
158
|
-
url (str): GETリクエストのURLパス。
|
159
|
-
params (QueryParamTypes | None, optional): リクエストのクエリパラメータ。
|
122
|
+
id_token = data["idToken"]
|
160
123
|
|
161
|
-
|
162
|
-
|
124
|
+
if save:
|
125
|
+
set_key(AuthKey.REFRESH_TOKEN, refresh_token)
|
126
|
+
set_key(AuthKey.ID_TOKEN, id_token)
|
163
127
|
|
164
|
-
|
165
|
-
|
166
|
-
"""
|
167
|
-
resp = await self.client.get(url, params=params)
|
168
|
-
resp.raise_for_status()
|
169
|
-
return resp.json()
|
128
|
+
self.set_id_token(id_token)
|
129
|
+
return self
|
170
130
|
|
171
131
|
async def get_info(
|
172
132
|
self,
|
@@ -176,11 +136,11 @@ class JQuantsClient:
|
|
176
136
|
"""銘柄情報を取得する。
|
177
137
|
|
178
138
|
Args:
|
179
|
-
code (str
|
180
|
-
date (str | datetime.date
|
139
|
+
code (str, optional): 情報を取得する銘柄のコード。
|
140
|
+
date (str | datetime.date, optional): 情報を取得する日付。
|
181
141
|
|
182
142
|
Returns:
|
183
|
-
銘柄情報を含む
|
143
|
+
銘柄情報を含むDataFrame。
|
184
144
|
|
185
145
|
Raises:
|
186
146
|
HTTPStatusError: APIリクエストが失敗した場合。
|
@@ -188,12 +148,7 @@ class JQuantsClient:
|
|
188
148
|
params = get_params(code=code, date=date)
|
189
149
|
url = "/listed/info"
|
190
150
|
data = await self.get(url, params)
|
191
|
-
|
192
|
-
|
193
|
-
return df.with_columns(
|
194
|
-
pl.col("Date").str.to_date("%Y-%m-%d"),
|
195
|
-
pl.col("^.*CodeName$", "ScaleCategory").cast(pl.Categorical),
|
196
|
-
).drop("^.+Code$", "CompanyNameEnglish")
|
151
|
+
return DataFrame(data["info"]).pipe(info.clean)
|
197
152
|
|
198
153
|
async def iter_pages(
|
199
154
|
self,
|
@@ -209,7 +164,7 @@ class JQuantsClient:
|
|
209
164
|
name (str): アイテムのリストを含むJSONレスポンスのキー。
|
210
165
|
|
211
166
|
Yields:
|
212
|
-
データの各ページに対応する
|
167
|
+
データの各ページに対応するDataFrame。
|
213
168
|
|
214
169
|
Raises:
|
215
170
|
HTTPStatusError: APIリクエストが失敗した場合。
|
@@ -233,14 +188,19 @@ class JQuantsClient:
|
|
233
188
|
) -> DataFrame:
|
234
189
|
"""日々の株価四本値を取得する。
|
235
190
|
|
191
|
+
株価は分割・併合を考慮した調整済み株価(小数点第2位四捨五入)と調整前の株価を取得できる。
|
192
|
+
|
236
193
|
Args:
|
237
|
-
code: 株価を取得する銘柄のコード。
|
238
|
-
date:
|
239
|
-
|
240
|
-
|
194
|
+
code (str, optional): 株価を取得する銘柄のコード。
|
195
|
+
date (str | datetime.date, optional): 株価を取得する日付。
|
196
|
+
`from_`または`to`とは併用不可。
|
197
|
+
from_ (str | datetime.date, optional): 取得期間の開始日。
|
198
|
+
`date`とは併用不可。
|
199
|
+
to (str | datetime.date, optional): 取得期間の終了日。
|
200
|
+
`date`とは併用不可。
|
241
201
|
|
242
202
|
Returns:
|
243
|
-
日々の株価四本値を含む
|
203
|
+
日々の株価四本値を含むDataFrame。
|
244
204
|
|
245
205
|
Raises:
|
246
206
|
ValueError: `date`と`from_`/`to`の両方が指定された場合。
|
@@ -250,7 +210,7 @@ class JQuantsClient:
|
|
250
210
|
return await self.get_latest_available_prices()
|
251
211
|
|
252
212
|
if date and (from_ or to):
|
253
|
-
msg = "
|
213
|
+
msg = "dateとfrom/toの両方を指定することはできません。"
|
254
214
|
raise ValueError(msg)
|
255
215
|
|
256
216
|
params = get_params(code=code, date=date, from_=from_, to=to)
|
@@ -264,16 +224,13 @@ class JQuantsClient:
|
|
264
224
|
if df.is_empty():
|
265
225
|
return df
|
266
226
|
|
267
|
-
return
|
268
|
-
pl.col("Date").str.to_date("%Y-%m-%d"),
|
269
|
-
pl.col("^.*Limit$").cast(pl.Int8).cast(pl.Boolean),
|
270
|
-
)
|
227
|
+
return prices.clean(df)
|
271
228
|
|
272
|
-
async def get_latest_available_prices(self) -> DataFrame:
|
229
|
+
async def get_latest_available_prices(self, num_days: int = 30) -> DataFrame:
|
273
230
|
"""直近利用可能な日付の株価を取得する。"""
|
274
231
|
today = datetime.date.today() # noqa: DTZ011
|
275
232
|
|
276
|
-
for days in range(
|
233
|
+
for days in range(num_days):
|
277
234
|
date = today - datetime.timedelta(days)
|
278
235
|
df = await self.get_prices(date=date)
|
279
236
|
|
@@ -287,18 +244,23 @@ class JQuantsClient:
|
|
287
244
|
code: str | None = None,
|
288
245
|
date: str | datetime.date | None = None,
|
289
246
|
) -> DataFrame:
|
290
|
-
"""
|
247
|
+
"""四半期毎の決算短信サマリーおよび業績・配当の修正に関する開示情報を取得する。
|
291
248
|
|
292
249
|
Args:
|
293
|
-
code: 財務情報を取得する銘柄のコード。
|
294
|
-
date: 財務情報を取得する日付。
|
250
|
+
code (str, optional): 財務情報を取得する銘柄のコード。
|
251
|
+
date (str | datetime.date, optional): 財務情報を取得する日付。
|
295
252
|
|
296
253
|
Returns:
|
297
|
-
財務情報を含む
|
254
|
+
財務情報を含むDataFrame。
|
298
255
|
|
299
256
|
Raises:
|
257
|
+
ValueError: `code`と`date`が両方とも指定されない場合。
|
300
258
|
HTTPStatusError: APIリクエストが失敗した場合。
|
301
259
|
"""
|
260
|
+
if not code and not date:
|
261
|
+
msg = "codeまたはdateのどちらかを指定する必要があります。"
|
262
|
+
raise ValueError(msg)
|
263
|
+
|
302
264
|
params = get_params(code=code, date=date)
|
303
265
|
url = "/fins/statements"
|
304
266
|
name = "statements"
|
@@ -1,47 +1,18 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
4
|
-
from typing import TYPE_CHECKING, Any, Protocol
|
3
|
+
from typing import TYPE_CHECKING
|
5
4
|
|
6
|
-
|
7
|
-
|
8
|
-
from kabukit.concurrent import collect_fn
|
5
|
+
from kabukit.utils import concurrent
|
9
6
|
|
10
7
|
from .client import JQuantsClient
|
11
8
|
from .info import get_codes
|
12
9
|
|
13
10
|
if TYPE_CHECKING:
|
14
|
-
from collections.abc import
|
11
|
+
from collections.abc import Iterable
|
15
12
|
|
16
13
|
from polars import DataFrame
|
17
14
|
|
18
|
-
|
19
|
-
def __call__[T](
|
20
|
-
self,
|
21
|
-
aiterable: AsyncIterable[T],
|
22
|
-
total: int | None = None,
|
23
|
-
*args: Any,
|
24
|
-
**kwargs: Any,
|
25
|
-
) -> AsyncIterator[T]: ...
|
26
|
-
|
27
|
-
|
28
|
-
type Callback = Callable[[DataFrame], DataFrame | None]
|
29
|
-
|
30
|
-
|
31
|
-
@dataclass
|
32
|
-
class Stream:
|
33
|
-
"""JQuantsから各種データを銘柄コードごとにストリーム形式で取得する。"""
|
34
|
-
|
35
|
-
resource: str
|
36
|
-
codes: list[str]
|
37
|
-
max_concurrency: int | None = None
|
38
|
-
|
39
|
-
async def __aiter__(self) -> AsyncIterator[DataFrame]:
|
40
|
-
async with JQuantsClient() as client:
|
41
|
-
fn = getattr(client, f"get_{self.resource}")
|
42
|
-
|
43
|
-
async for df in collect_fn(fn, self.codes, self.max_concurrency):
|
44
|
-
yield df
|
15
|
+
from kabukit.utils.concurrent import Callback, Progress
|
45
16
|
|
46
17
|
|
47
18
|
async def fetch(
|
@@ -70,16 +41,14 @@ async def fetch(
|
|
70
41
|
DataFrame:
|
71
42
|
すべての銘柄の財務情報を含む単一のDataFrame。
|
72
43
|
"""
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
return pl.concat([df async for df in stream if not df.is_empty()])
|
44
|
+
return await concurrent.fetch(
|
45
|
+
JQuantsClient,
|
46
|
+
resource,
|
47
|
+
codes,
|
48
|
+
max_concurrency=max_concurrency,
|
49
|
+
progress=progress,
|
50
|
+
callback=callback,
|
51
|
+
)
|
83
52
|
|
84
53
|
|
85
54
|
async def fetch_all(
|
kabukit/jquants/info.py
CHANGED
@@ -1,8 +1,14 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import polars as pl
|
4
|
+
from polars import DataFrame
|
4
5
|
|
5
|
-
|
6
|
+
|
7
|
+
def clean(df: DataFrame) -> DataFrame:
    """Normalize a raw J-Quants listed-info DataFrame.

    Parses ``Date`` from ``YYYY-MM-DD`` strings, casts the ``*CodeName``
    columns and ``ScaleCategory`` to categoricals, and then drops the columns
    selected by the ``"^.+Code$"`` pattern together with ``CompanyNameEnglish``.
    """
    parsed_date = pl.col("Date").str.to_date("%Y-%m-%d")
    categoricals = pl.col("^.*CodeName$", "ScaleCategory").cast(pl.Categorical)

    converted = df.with_columns(parsed_date, categoricals)
    return converted.drop("^.+Code$", "CompanyNameEnglish")
|
6
12
|
|
7
13
|
|
8
14
|
async def get_codes() -> list[str]:
|
@@ -10,6 +16,8 @@ async def get_codes() -> list[str]:
|
|
10
16
|
|
11
17
|
市場「TOKYO PRO MARKET」と業種「その他」を除外した銘柄を対象とする。
|
12
18
|
"""
|
19
|
+
from .client import JQuantsClient
|
20
|
+
|
13
21
|
async with JQuantsClient() as client:
|
14
22
|
info = await client.get_info()
|
15
23
|
|
kabukit/jquants/prices.py
CHANGED
@@ -0,0 +1,29 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
import polars as pl
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from polars import DataFrame
|
9
|
+
|
10
|
+
|
11
|
+
def clean(df: DataFrame) -> DataFrame:
    """Reshape a raw J-Quants daily-quotes DataFrame.

    The split/merge-adjusted ``Adjustment*`` prices and volume become the
    primary ``Open``/``High``/``Low``/``Close``/``Volume`` columns. The
    unadjusted values are kept under a ``Raw`` prefix. ``Date`` is parsed from
    ``YYYY-MM-DD`` strings, and the ``UpperLimit``/``LowerLimit`` flags are
    turned into booleans via an ``Int8`` cast. Any column not listed here is
    dropped.
    """

    def as_flag(name: str) -> pl.Expr:
        # Cast through Int8 first, then to Boolean.
        return pl.col(name).cast(pl.Int8).cast(pl.Boolean)

    adjusted_prices = {
        key: pl.col(f"Adjustment{key}") for key in ("Open", "High", "Low", "Close")
    }
    raw_values = {
        f"Raw{key}": pl.col(key) for key in ("Open", "High", "Low", "Close", "Volume")
    }

    # Keyword order fixes the output column order.
    return df.select(
        pl.col("Date").str.to_date("%Y-%m-%d"),
        "Code",
        **adjusted_prices,
        UpperLimit=as_flag("UpperLimit"),
        LowerLimit=as_flag("LowerLimit"),
        Volume=pl.col("AdjustmentVolume"),
        TurnoverValue=pl.col("TurnoverValue"),
        AdjustmentFactor=pl.col("AdjustmentFactor"),
        **raw_values,
    )
|
kabukit/jquants/schema.py
CHANGED
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
|
9
9
|
|
10
10
|
class BaseColumns(Enum):
|
11
11
|
@classmethod
|
12
|
-
def rename(cls, df: DataFrame, *, strict: bool =
|
12
|
+
def rename(cls, df: DataFrame, *, strict: bool = False) -> DataFrame:
|
13
13
|
"""DataFrameの列名を日本語から英語に変換する。"""
|
14
14
|
return df.rename({x.name: x.value for x in cls}, strict=strict)
|
15
15
|
|
@@ -37,11 +37,11 @@ class PriceColumns(BaseColumns):
|
|
37
37
|
Volume = "出来高"
|
38
38
|
TurnoverValue = "売買代金"
|
39
39
|
AdjustmentFactor = "調整係数"
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
40
|
+
RawOpen = "調整前始値"
|
41
|
+
RawHigh = "調整前高値"
|
42
|
+
RawLow = "調整前安値"
|
43
|
+
RawClose = "調整前終値"
|
44
|
+
RawVolume = "調整前出来高"
|
45
45
|
|
46
46
|
|
47
47
|
class StatementColumns(BaseColumns):
|
@@ -129,10 +129,13 @@ class StatementColumns(BaseColumns):
|
|
129
129
|
ChangesInAccountingEstimates = "会計上の見積りの変更"
|
130
130
|
RetrospectiveRestatement = "修正再表示"
|
131
131
|
|
132
|
-
NumberOfIssuedAndOutstandingSharesAtTheEndOfFiscalYearIncludingTreasuryStock
|
133
|
-
|
132
|
+
# NumberOfIssuedAndOutstandingSharesAtTheEndOfFiscalYearIncludingTreasuryStock
|
133
|
+
NumberOfShares = "期末発行済株式数"
|
134
|
+
# NumberOfTreasuryStockAtTheEndOfFiscalYear
|
135
|
+
NumberOfTreasuryStock = "期末自己株式数"
|
134
136
|
AverageNumberOfShares = "期中平均株式数"
|
135
137
|
|
138
|
+
"""
|
136
139
|
NonConsolidatedNetSales = "売上高_非連結"
|
137
140
|
NonConsolidatedOperatingProfit = "営業利益_非連結"
|
138
141
|
NonConsolidatedOrdinaryProfit = "経常利益_非連結"
|
@@ -167,3 +170,11 @@ class StatementColumns(BaseColumns):
|
|
167
170
|
NextYearForecastNonConsolidatedOrdinaryProfit = "経常利益_予想_翌事業年度期末_非連結"
|
168
171
|
NextYearForecastNonConsolidatedProfit = "当期純利益_予想_翌事業年度期末_非連結"
|
169
172
|
NextYearForecastNonConsolidatedEarningsPerShare = "一株あたり当期純利益_予想_翌事業年度期末_非連結"
|
173
|
+
"""
|
174
|
+
|
175
|
+
|
176
|
+
def rename(df: DataFrame, *, strict: bool = False) -> DataFrame:
    """Translate Japanese column names in ``df`` into their English names.

    The mappings of ``InfoColumns``, ``PriceColumns`` and ``StatementColumns``
    are applied in that order. Each is passed through ``BaseColumns.rename``
    with the given ``strict`` flag.
    """
    # NOTE(review): the flag is forwarded to every per-enum rename. With
    # strict=True, polars may reject mappings whose source columns are absent
    # from df, so confirm that strict=True is usable for a single data type.
    renamed = df
    for columns in (InfoColumns, PriceColumns, StatementColumns):
        renamed = columns.rename(renamed, strict=strict)
    return renamed
|
kabukit/jquants/statements.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import datetime
|
4
|
+
from functools import cache
|
3
5
|
from typing import TYPE_CHECKING, Any, Protocol
|
4
6
|
|
7
|
+
import holidays
|
5
8
|
import polars as pl
|
6
9
|
|
7
10
|
if TYPE_CHECKING:
|
@@ -21,9 +24,15 @@ if TYPE_CHECKING:
|
|
21
24
|
|
22
25
|
def clean(df: DataFrame) -> DataFrame:
|
23
26
|
return (
|
24
|
-
df.select(pl.exclude(r"^.*\(REIT\)
|
27
|
+
df.select(pl.exclude(r"^.*\(REIT\)|.*NonConsolidated.*$"))
|
25
28
|
.rename(
|
26
|
-
{
|
29
|
+
{
|
30
|
+
"DisclosedDate": "Date",
|
31
|
+
"DisclosedTime": "Time",
|
32
|
+
"LocalCode": "Code",
|
33
|
+
"NumberOfIssuedAndOutstandingSharesAtTheEndOfFiscalYearIncludingTreasuryStock": "NumberOfShares", # noqa: E501
|
34
|
+
"NumberOfTreasuryStockAtTheEndOfFiscalYear": "NumberOfTreasuryStock",
|
35
|
+
},
|
27
36
|
)
|
28
37
|
.with_columns(
|
29
38
|
pl.col("^.*Date$").str.to_date("%Y-%m-%d", strict=False),
|
@@ -67,3 +76,27 @@ def _cast_bool(df: DataFrame) -> DataFrame:
|
|
67
76
|
.alias(col)
|
68
77
|
for col in columns
|
69
78
|
)
|
79
|
+
|
80
|
+
|
81
|
+
def get_holidays(year: int | None = None, n: int = 10) -> list[datetime.date]:
    """Return Japanese public holidays for the years ``year - n`` to ``year``.

    Args:
        year: Last calendar year to include. Defaults to the current year,
            re-resolved on every call.
        n: Number of preceding years to include as well.

    Returns:
        The holiday dates in ascending order, as a new list the caller may
        freely modify.
    """
    if year is None:
        # Resolve "now" outside the cache. Otherwise a long-running process
        # keeps serving the window computed for the year of its first call.
        year = datetime.datetime.now().year  # noqa: DTZ005

    # Copy so callers cannot mutate the cached value.
    return list(_cached_holidays(year, n))


@cache
def _cached_holidays(year: int, n: int) -> tuple[datetime.date, ...]:
    """Compute the sorted JP holidays for ``year - n``..``year`` once per key."""
    dates = holidays.country_holidays("JP", years=range(year - n, year + 1))
    return tuple(sorted(dates.keys()))
|
89
|
+
|
90
|
+
|
91
|
+
def update_effective_date(df: DataFrame, year: int | None = None) -> DataFrame:
    """Move each disclosure ``Date`` to the business day it becomes effective.

    Rules:
        - A ``Date`` that falls on a weekend or a Japanese holiday moves to
          the next business day.
        - A ``Date`` on a business day whose ``Time`` is null or strictly
          later than 15:00 moves to the following business day.
        - All other rows keep their ``Date``.

    Args:
        df: Statements data with a date-typed ``Date`` column and a
            time-typed ``Time`` column (presumably as produced by ``clean``;
            confirm).
        year: Forwarded to ``get_holidays`` to choose the holiday window.

    Returns:
        ``df`` with ``Date`` replaced by the effective date.
    """
    # Named to avoid shadowing the imported `holidays` module.
    holiday_dates = get_holidays(year=year)

    date = pl.col("Date")
    # Roll non-business days forward first. add_business_days defaults to
    # roll="raise", and pl.when evaluates the `then` branch for every row, so
    # a weekend/holiday start date would otherwise raise.
    business_date = date.dt.add_business_days(0, holidays=holiday_dates, roll="forward")
    next_business_date = business_date.dt.add_business_days(1, holidays=holiday_dates)

    # NOTE(review): strictly after 15:00, so a disclosure at exactly 15:00
    # keeps its date. Confirm this is the intended cutoff.
    after_close = pl.col("Time").is_null() | (pl.col("Time") > datetime.time(15, 0))
    # Only a disclosure made on a business day needs the extra day. One made
    # on a non-business day has already been rolled to the next session.
    on_business_day = business_date == date

    return df.with_columns(
        pl.when(on_business_day & after_close)
        .then(next_business_date)
        .otherwise(business_date)
        .alias("Date"),
    )
|
File without changes
|
@@ -0,0 +1,148 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from typing import TYPE_CHECKING, Any, Protocol
|
6
|
+
|
7
|
+
import polars as pl
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from collections.abc import (
|
11
|
+
AsyncIterable,
|
12
|
+
AsyncIterator,
|
13
|
+
Awaitable,
|
14
|
+
Callable,
|
15
|
+
Iterable,
|
16
|
+
)
|
17
|
+
from typing import Any
|
18
|
+
|
19
|
+
from marimo._plugins.stateless.status import progress_bar
|
20
|
+
from polars import DataFrame
|
21
|
+
from tqdm.asyncio import tqdm
|
22
|
+
|
23
|
+
from kabukit.core.client import Client
|
24
|
+
|
25
|
+
class _Progress(Protocol):
    """Structural type for a progress wrapper around an async iterable.

    ``fetch`` calls it as ``progress(aiterable, total=...)`` and then iterates
    the returned async iterator in place of the original one.
    """

    def __call__[T](
        self,
        aiterable: AsyncIterable[T],
        total: int | None = None,  # number of items expected, if known
        *args: Any,
        **kwargs: Any,
    ) -> AsyncIterator[T]: ...
|
33
|
+
|
34
|
+
|
35
|
+
# Default cap on concurrently awaited calls. collect/collect_fn use it when
# max_concurrency is None (or otherwise falsy).
MAX_CONCURRENCY = 12
|
36
|
+
|
37
|
+
|
38
|
+
async def collect[R](
|
39
|
+
awaitables: Iterable[Awaitable[R]],
|
40
|
+
/,
|
41
|
+
max_concurrency: int | None = None,
|
42
|
+
) -> AsyncIterator[R]:
|
43
|
+
max_concurrency = max_concurrency or MAX_CONCURRENCY
|
44
|
+
semaphore = asyncio.Semaphore(max_concurrency)
|
45
|
+
|
46
|
+
async def run(awaitable: Awaitable[R]) -> R:
|
47
|
+
async with semaphore:
|
48
|
+
return await awaitable
|
49
|
+
|
50
|
+
futures = (run(awaitable) for awaitable in awaitables)
|
51
|
+
|
52
|
+
async for future in asyncio.as_completed(futures):
|
53
|
+
yield await future
|
54
|
+
|
55
|
+
|
56
|
+
async def collect_fn[T, R](
|
57
|
+
function: Callable[[T], Awaitable[R]],
|
58
|
+
args: Iterable[T],
|
59
|
+
/,
|
60
|
+
max_concurrency: int | None = None,
|
61
|
+
) -> AsyncIterator[R]:
|
62
|
+
max_concurrency = max_concurrency or MAX_CONCURRENCY
|
63
|
+
awaitables = (function(arg) for arg in args)
|
64
|
+
|
65
|
+
async for item in collect(awaitables, max_concurrency=max_concurrency):
|
66
|
+
yield item
|
67
|
+
|
68
|
+
|
69
|
+
async def concat(
    awaitables: Iterable[Awaitable[DataFrame]],
    /,
    max_concurrency: int | None = None,
) -> DataFrame:
    """Await DataFrame-producing awaitables concurrently and stack the results.

    Empty frames are skipped. Rows appear in completion order, not input
    order.

    Args:
        awaitables: Awaitables that each produce a DataFrame.
        max_concurrency: Forwarded to ``collect``.

    Returns:
        The vertically concatenated non-empty frames, or an empty
        ``DataFrame`` when there were no inputs or every result was empty.
        ``pl.concat`` raises on an empty sequence, hence the guard; this
        matches ``fetch``.
    """
    frames = [
        df
        async for df in collect(awaitables, max_concurrency=max_concurrency)
        if not df.is_empty()
    ]
    return pl.concat(frames) if frames else pl.DataFrame()
|
77
|
+
|
78
|
+
|
79
|
+
async def concat_fn[T](
|
80
|
+
function: Callable[[T], Awaitable[DataFrame]],
|
81
|
+
args: Iterable[T],
|
82
|
+
/,
|
83
|
+
max_concurrency: int | None = None,
|
84
|
+
) -> DataFrame:
|
85
|
+
dfs = collect_fn(function, args, max_concurrency=max_concurrency)
|
86
|
+
dfs = [df async for df in dfs]
|
87
|
+
return pl.concat(df for df in dfs if not df.is_empty())
|
88
|
+
|
89
|
+
|
90
|
+
# Per-frame hook used by fetch: return a DataFrame to replace the fetched one,
# or None to keep the fetched frame unchanged.
type Callback = Callable[[DataFrame], DataFrame | None]
# Progress wrapper accepted by fetch: a tqdm / marimo progress_bar class, or
# any callable matching the _Progress protocol.
type Progress = type[progress_bar[Any] | tqdm[Any]] | _Progress
|
92
|
+
|
93
|
+
|
94
|
+
@dataclass
class Stream:
    """Async iterable over per-argument DataFrames fetched through a client.

    Iterating opens a new ``cls()`` client as an async context manager, looks
    up its ``get_<resource>`` method, and yields that method's result for
    every element of ``args`` in completion order. At most
    ``max_concurrency`` calls are in flight at once.
    """

    cls: type[Client]  # client class to instantiate for the whole iteration
    resource: str  # suffix of the client method name after "get_"
    args: list[Any]  # one method call is made per element
    max_concurrency: int | None = None  # None -> the module default limit

    async def __aiter__(self) -> AsyncIterator[DataFrame]:
        async with self.cls() as client:
            fetcher = getattr(client, f"get_{self.resource}")
            frames = collect_fn(fetcher, self.args, self.max_concurrency)
            async for frame in frames:
                yield frame
|
107
|
+
|
108
|
+
|
109
|
+
async def fetch(
    cls: type[Client],
    resource: str,
    args: Iterable[Any],
    /,
    max_concurrency: int | None = None,
    progress: Progress | None = None,
    callback: Callback | None = None,
) -> DataFrame:
    """Fetch one kind of data for many arguments and return a single DataFrame.

    Args:
        cls: Client class to instantiate, i.e. a ``Client`` subclass such as
            the J-Quants or EDINET client.
        resource: Kind of data to fetch: the client method name without its
            ``get_`` prefix.
        args: Arguments to fetch. Each one is passed to ``get_<resource>``.
        max_concurrency: Maximum number of requests in flight. ``None`` uses
            the module default.
        progress: Optional progress wrapper, for example a tqdm or marimo
            progress-bar class. When omitted, no progress is shown.
        callback: Optional function applied to each fetched DataFrame before
            the emptiness check. If it returns a DataFrame, that replaces the
            fetched one. If it returns ``None``, the fetched frame is kept.

    Returns:
        All non-empty frames concatenated, or an empty ``DataFrame`` when
        nothing non-empty was fetched.
    """
    targets = list(args)
    frames: AsyncIterable[DataFrame] = Stream(cls, resource, targets, max_concurrency)

    if progress:
        frames = progress(aiter(frames), total=len(targets))

    collected: list[DataFrame] = []
    async for frame in frames:
        if callback:
            replacement = callback(frame)
            if replacement is not None:
                frame = replacement
        if not frame.is_empty():
            collected.append(frame)

    if not collected:
        return pl.DataFrame()
    return pl.concat(collected)
|