kabukit 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kabukit/jquants/client.py CHANGED
@@ -6,13 +6,13 @@ from enum import StrEnum
6
6
  from typing import TYPE_CHECKING
7
7
 
8
8
  import polars as pl
9
- from httpx import AsyncClient
10
9
  from polars import DataFrame
11
10
 
12
- from kabukit.config import load_dotenv, set_key
13
- from kabukit.params import get_params
11
+ from kabukit.core.client import Client
12
+ from kabukit.utils.config import load_dotenv, set_key
13
+ from kabukit.utils.params import get_params
14
14
 
15
- from . import statements
15
+ from . import info, prices, statements
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from collections.abc import AsyncIterator
@@ -32,7 +32,7 @@ class AuthKey(StrEnum):
32
32
  ID_TOKEN = "JQUANTS_ID_TOKEN" # noqa: S105
33
33
 
34
34
 
35
- class JQuantsClient:
35
+ class JQuantsClient(Client):
36
36
  """J-Quants APIと対話するためのクライアント。
37
37
 
38
38
  API認証トークン(リフレッシュトークンおよびIDトークン)を管理し、
@@ -43,10 +43,8 @@ class JQuantsClient:
43
43
  client: APIリクエストを行うための `AsyncClient` インスタンス。
44
44
  """
45
45
 
46
- client: AsyncClient
47
-
48
46
  def __init__(self, id_token: str | None = None) -> None:
49
- self.client = AsyncClient(base_url=BASE_URL)
47
+ super().__init__(BASE_URL)
50
48
  self.set_id_token(id_token)
51
49
 
52
50
  def set_id_token(self, id_token: str | None = None) -> None:
@@ -64,43 +62,6 @@ class JQuantsClient:
64
62
  if id_token:
65
63
  self.client.headers["Authorization"] = f"Bearer {id_token}"
66
64
 
67
- async def aclose(self) -> None:
68
- """HTTPクライアントを閉じる。"""
69
- await self.client.aclose()
70
-
71
- async def __aenter__(self) -> Self:
72
- return self
73
-
74
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # pyright: ignore[reportMissingParameterType, reportUnknownParameterType] # noqa: ANN001
75
- await self.aclose()
76
-
77
- async def auth(
78
- self,
79
- mailaddress: str,
80
- password: str,
81
- *,
82
- save: bool = False,
83
- ) -> Self:
84
- """認証を行い、トークンを保存する。
85
-
86
- Args:
87
- mailaddress (str): J-Quantsに登録したメールアドレス。
88
- password (str): J-Quantsのパスワード。
89
- save (bool, optional): トークンを環境変数に保存するかどうか。
90
-
91
- Raises:
92
- HTTPStatusError: APIリクエストが失敗した場合。
93
- """
94
- refresh_token = await self.get_refresh_token(mailaddress, password)
95
- id_token = await self.get_id_token(refresh_token)
96
-
97
- if save:
98
- set_key(AuthKey.REFRESH_TOKEN, refresh_token)
99
- set_key(AuthKey.ID_TOKEN, id_token)
100
-
101
- self.set_id_token(id_token)
102
- return self
103
-
104
65
  async def post(self, url: str, json: Any | None = None) -> Any:
105
66
  """指定されたURLにPOSTリクエストを送信する。
106
67
 
@@ -118,55 +79,54 @@ class JQuantsClient:
118
79
  resp.raise_for_status()
119
80
  return resp.json()
120
81
 
121
- async def get_refresh_token(self, mailaddress: str, password: str) -> str:
122
- """APIから新しいリフレッシュトークンを取得する。
82
+ async def get(self, url: str, params: QueryParamTypes | None = None) -> Any:
83
+ """指定されたURLにGETリクエストを送信する。
123
84
 
124
85
  Args:
125
- mailaddress (str): ユーザーのメールアドレス。
126
- password (str): ユーザーのパスワード。
86
+ url (str): GETリクエストのURLパス。
87
+ params (QueryParamTypes | None, optional): リクエストのクエリパラメータ。
127
88
 
128
89
  Returns:
129
- 新しいリフレッシュトークン。
90
+ APIからのJSONレスポンス。
130
91
 
131
92
  Raises:
132
93
  HTTPStatusError: APIリクエストが失敗した場合。
133
94
  """
134
- json_data = {"mailaddress": mailaddress, "password": password}
135
- data = await self.post("/token/auth_user", json=json_data)
136
- return data["refreshToken"]
95
+ resp = await self.client.get(url, params=params)
96
+ resp.raise_for_status()
97
+ return resp.json()
137
98
 
138
- async def get_id_token(self, refresh_token: str) -> str:
139
- """APIから新しいIDトークンを取得する。
99
+ async def auth(
100
+ self,
101
+ mailaddress: str,
102
+ password: str,
103
+ *,
104
+ save: bool = False,
105
+ ) -> Self:
106
+ """認証を行い、トークンを保存する。
140
107
 
141
108
  Args:
142
- refresh_token (str): 使用するリフレッシュトークン。
143
-
144
- Returns:
145
- 新しいIDトークン。
109
+ mailaddress (str): J-Quantsに登録したメールアドレス。
110
+ password (str): J-Quantsのパスワード。
111
+ save (bool, optional): トークンを環境変数に保存するかどうか。
146
112
 
147
113
  Raises:
148
114
  HTTPStatusError: APIリクエストが失敗した場合。
149
115
  """
116
+ json_data = {"mailaddress": mailaddress, "password": password}
117
+ data = await self.post("/token/auth_user", json=json_data)
118
+ refresh_token = data["refreshToken"]
119
+
150
120
  url = f"/token/auth_refresh?refreshtoken={refresh_token}"
151
121
  data = await self.post(url)
152
- return data["idToken"]
153
-
154
- async def get(self, url: str, params: QueryParamTypes | None = None) -> Any:
155
- """指定されたURLにGETリクエストを送信する。
156
-
157
- Args:
158
- url (str): GETリクエストのURLパス。
159
- params (QueryParamTypes | None, optional): リクエストのクエリパラメータ。
122
+ id_token = data["idToken"]
160
123
 
161
- Returns:
162
- APIからのJSONレスポンス。
124
+ if save:
125
+ set_key(AuthKey.REFRESH_TOKEN, refresh_token)
126
+ set_key(AuthKey.ID_TOKEN, id_token)
163
127
 
164
- Raises:
165
- HTTPStatusError: APIリクエストが失敗した場合。
166
- """
167
- resp = await self.client.get(url, params=params)
168
- resp.raise_for_status()
169
- return resp.json()
128
+ self.set_id_token(id_token)
129
+ return self
170
130
 
171
131
  async def get_info(
172
132
  self,
@@ -176,8 +136,8 @@ class JQuantsClient:
176
136
  """銘柄情報を取得する。
177
137
 
178
138
  Args:
179
- code (str | None, optional): 情報を取得する銘柄のコード。
180
- date (str | datetime.date | None, optional): 情報を取得する日付。
139
+ code (str, optional): 情報を取得する銘柄のコード。
140
+ date (str | datetime.date, optional): 情報を取得する日付。
181
141
 
182
142
  Returns:
183
143
  銘柄情報を含むPolars DataFrame。
@@ -188,12 +148,7 @@ class JQuantsClient:
188
148
  params = get_params(code=code, date=date)
189
149
  url = "/listed/info"
190
150
  data = await self.get(url, params)
191
- df = DataFrame(data["info"])
192
-
193
- return df.with_columns(
194
- pl.col("Date").str.to_date("%Y-%m-%d"),
195
- pl.col("^.*CodeName$", "ScaleCategory").cast(pl.Categorical),
196
- ).drop("^.+Code$", "CompanyNameEnglish")
151
+ return DataFrame(data["info"]).pipe(info.clean)
197
152
 
198
153
  async def iter_pages(
199
154
  self,
@@ -233,11 +188,16 @@ class JQuantsClient:
233
188
  ) -> DataFrame:
234
189
  """日々の株価四本値を取得する。
235
190
 
191
+ 株価は分割・併合を考慮した調整済み株価(小数点第2位四捨五入)と調整前の株価を取得できる。
192
+
236
193
  Args:
237
- code: 株価を取得する銘柄のコード。
238
- date: 株価を取得する特定の日付。`from_`または`to`とは併用不可。
239
- from_: 取得期間の開始日。`date`とは併用不可。
240
- to: 取得期間の終了日。`date`とは併用不可。
194
+ code (str, optional): 株価を取得する銘柄のコード。
195
+ date (str | datetime.date, optional): 株価を取得する日付。
196
+ `from_`または`to`とは併用不可。
197
+ from_ (str | datetime.date, optional): 取得期間の開始日。
198
+ `date`とは併用不可。
199
+ to (str | datetime.date, optional): 取得期間の終了日。
200
+ `date`とは併用不可。
241
201
 
242
202
  Returns:
243
203
  日々の株価四本値を含むPolars DataFrame。
@@ -250,7 +210,7 @@ class JQuantsClient:
250
210
  return await self.get_latest_available_prices()
251
211
 
252
212
  if date and (from_ or to):
253
- msg = "Cannot specify both date and from/to parameters."
213
+ msg = "datefrom/toの両方を指定することはできません。"
254
214
  raise ValueError(msg)
255
215
 
256
216
  params = get_params(code=code, date=date, from_=from_, to=to)
@@ -264,16 +224,13 @@ class JQuantsClient:
264
224
  if df.is_empty():
265
225
  return df
266
226
 
267
- return df.with_columns(
268
- pl.col("Date").str.to_date("%Y-%m-%d"),
269
- pl.col("^.*Limit$").cast(pl.Int8).cast(pl.Boolean),
270
- )
227
+ return prices.clean(df)
271
228
 
272
- async def get_latest_available_prices(self) -> DataFrame:
229
+ async def get_latest_available_prices(self, num_days: int = 30) -> DataFrame:
273
230
  """直近利用可能な日付の株価を取得する。"""
274
231
  today = datetime.date.today() # noqa: DTZ011
275
232
 
276
- for days in range(30):
233
+ for days in range(num_days):
277
234
  date = today - datetime.timedelta(days)
278
235
  df = await self.get_prices(date=date)
279
236
 
@@ -287,18 +244,23 @@ class JQuantsClient:
287
244
  code: str | None = None,
288
245
  date: str | datetime.date | None = None,
289
246
  ) -> DataFrame:
290
- """財務情報を取得する。
247
+ """四半期毎の決算短信サマリーおよび業績・配当の修正に関する開示情報を取得する。
291
248
 
292
249
  Args:
293
- code: 財務情報を取得する銘柄のコード。
294
- date: 財務情報を取得する日付。
250
+ code (str, optional): 財務情報を取得する銘柄のコード。
251
+ date (str | datetime.date, optional): 財務情報を取得する日付。
295
252
 
296
253
  Returns:
297
- 財務情報を含むPolars DataFrame。
254
+ 財務情報を含むDataFrame。
298
255
 
299
256
  Raises:
257
+ ValueError: `code`と`date`が両方とも指定されない場合。
300
258
  HTTPStatusError: APIリクエストが失敗した場合。
301
259
  """
260
+ if not code and not date:
261
+ msg = "codeまたはdateのどちらかを指定する必要があります。"
262
+ raise ValueError(msg)
263
+
302
264
  params = get_params(code=code, date=date)
303
265
  url = "/fins/statements"
304
266
  name = "statements"
@@ -1,47 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, Any, Protocol
3
+ from typing import TYPE_CHECKING
5
4
 
6
- import polars as pl
7
-
8
- from kabukit.concurrent import collect_fn
5
+ from kabukit.utils import concurrent
9
6
 
10
7
  from .client import JQuantsClient
11
8
  from .info import get_codes
12
9
 
13
10
  if TYPE_CHECKING:
14
- from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterable
11
+ from collections.abc import Iterable
15
12
 
16
13
  from polars import DataFrame
17
14
 
18
- class Progress(Protocol):
19
- def __call__[T](
20
- self,
21
- aiterable: AsyncIterable[T],
22
- total: int | None = None,
23
- *args: Any,
24
- **kwargs: Any,
25
- ) -> AsyncIterator[T]: ...
26
-
27
-
28
- type Callback = Callable[[DataFrame], DataFrame | None]
29
-
30
-
31
- @dataclass
32
- class Stream:
33
- """JQuantsから各種データを銘柄コードごとにストリーム形式で取得する。"""
34
-
35
- resource: str
36
- codes: list[str]
37
- max_concurrency: int | None = None
38
-
39
- async def __aiter__(self) -> AsyncIterator[DataFrame]:
40
- async with JQuantsClient() as client:
41
- fn = getattr(client, f"get_{self.resource}")
42
-
43
- async for df in collect_fn(fn, self.codes, self.max_concurrency):
44
- yield df
15
+ from kabukit.utils.concurrent import Callback, Progress
45
16
 
46
17
 
47
18
  async def fetch(
@@ -70,16 +41,14 @@ async def fetch(
70
41
  DataFrame:
71
42
  すべての銘柄の財務情報を含む単一のDataFrame。
72
43
  """
73
- codes = list(codes)
74
- stream = Stream(resource, codes, max_concurrency)
75
-
76
- if progress:
77
- stream = progress(aiter(stream), total=len(codes))
78
-
79
- if callback:
80
- stream = (x if (r := callback(x)) is None else r async for x in stream)
81
-
82
- return pl.concat([df async for df in stream if not df.is_empty()])
44
+ return await concurrent.fetch(
45
+ JQuantsClient,
46
+ resource,
47
+ codes,
48
+ max_concurrency=max_concurrency,
49
+ progress=progress,
50
+ callback=callback,
51
+ )
83
52
 
84
53
 
85
54
  async def fetch_all(
kabukit/jquants/info.py CHANGED
@@ -1,8 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import polars as pl
4
+ from polars import DataFrame
4
5
 
5
- from .client import JQuantsClient
6
+
7
+ def clean(df: DataFrame) -> DataFrame:
8
+ return df.with_columns(
9
+ pl.col("Date").str.to_date("%Y-%m-%d"),
10
+ pl.col("^.*CodeName$", "ScaleCategory").cast(pl.Categorical),
11
+ ).drop("^.+Code$", "CompanyNameEnglish")
6
12
 
7
13
 
8
14
  async def get_codes() -> list[str]:
@@ -10,6 +16,8 @@ async def get_codes() -> list[str]:
10
16
 
11
17
  市場「TOKYO PRO MARKET」と業種「その他」を除外した銘柄を対象とする。
12
18
  """
19
+ from .client import JQuantsClient
20
+
13
21
  async with JQuantsClient() as client:
14
22
  info = await client.get_info()
15
23
 
kabukit/jquants/prices.py CHANGED
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import polars as pl
6
+
7
+ if TYPE_CHECKING:
8
+ from polars import DataFrame
9
+
10
+
11
+ def clean(df: DataFrame) -> DataFrame:
12
+ return df.select(
13
+ pl.col("Date").str.to_date("%Y-%m-%d"),
14
+ "Code",
15
+ Open=pl.col("AdjustmentOpen"),
16
+ High=pl.col("AdjustmentHigh"),
17
+ Low=pl.col("AdjustmentLow"),
18
+ Close=pl.col("AdjustmentClose"),
19
+ UpperLimit=pl.col("UpperLimit").cast(pl.Int8).cast(pl.Boolean),
20
+ LowerLimit=pl.col("LowerLimit").cast(pl.Int8).cast(pl.Boolean),
21
+ Volume=pl.col("AdjustmentVolume"),
22
+ TurnoverValue=pl.col("TurnoverValue"),
23
+ AdjustmentFactor=pl.col("AdjustmentFactor"),
24
+ RawOpen=pl.col("Open"),
25
+ RawHigh=pl.col("High"),
26
+ RawLow=pl.col("Low"),
27
+ RawClose=pl.col("Close"),
28
+ RawVolume=pl.col("Volume"),
29
+ )
kabukit/jquants/schema.py CHANGED
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
9
9
 
10
10
  class BaseColumns(Enum):
11
11
  @classmethod
12
- def rename(cls, df: DataFrame, *, strict: bool = True) -> DataFrame:
12
+ def rename(cls, df: DataFrame, *, strict: bool = False) -> DataFrame:
13
13
  """DataFrameの列名を日本語から英語に変換する。"""
14
14
  return df.rename({x.name: x.value for x in cls}, strict=strict)
15
15
 
@@ -37,11 +37,11 @@ class PriceColumns(BaseColumns):
37
37
  Volume = "出来高"
38
38
  TurnoverValue = "売買代金"
39
39
  AdjustmentFactor = "調整係数"
40
- AdjustmentOpen = "調整済始値"
41
- AdjustmentHigh = "調整済高値"
42
- AdjustmentLow = "調整済安値"
43
- AdjustmentClose = "調整済終値"
44
- AdjustmentVolume = "調整済取引高"
40
+ RawOpen = "調整前始値"
41
+ RawHigh = "調整前高値"
42
+ RawLow = "調整前安値"
43
+ RawClose = "調整前終値"
44
+ RawVolume = "調整前出来高"
45
45
 
46
46
 
47
47
  class StatementColumns(BaseColumns):
@@ -129,10 +129,13 @@ class StatementColumns(BaseColumns):
129
129
  ChangesInAccountingEstimates = "会計上の見積りの変更"
130
130
  RetrospectiveRestatement = "修正再表示"
131
131
 
132
- NumberOfIssuedAndOutstandingSharesAtTheEndOfFiscalYearIncludingTreasuryStock = "期末発行済株式数"
133
- NumberOfTreasuryStockAtTheEndOfFiscalYear = "期末自己株式数"
132
+ # NumberOfIssuedAndOutstandingSharesAtTheEndOfFiscalYearIncludingTreasuryStock
133
+ NumberOfShares = "期末発行済株式数"
134
+ # NumberOfTreasuryStockAtTheEndOfFiscalYear
135
+ NumberOfTreasuryStock = "期末自己株式数"
134
136
  AverageNumberOfShares = "期中平均株式数"
135
137
 
138
+ """
136
139
  NonConsolidatedNetSales = "売上高_非連結"
137
140
  NonConsolidatedOperatingProfit = "営業利益_非連結"
138
141
  NonConsolidatedOrdinaryProfit = "経常利益_非連結"
@@ -167,3 +170,11 @@ class StatementColumns(BaseColumns):
167
170
  NextYearForecastNonConsolidatedOrdinaryProfit = "経常利益_予想_翌事業年度期末_非連結"
168
171
  NextYearForecastNonConsolidatedProfit = "当期純利益_予想_翌事業年度期末_非連結"
169
172
  NextYearForecastNonConsolidatedEarningsPerShare = "一株あたり当期純利益_予想_翌事業年度期末_非連結"
173
+ """
174
+
175
+
176
+ def rename(df: DataFrame, *, strict: bool = False) -> DataFrame:
177
+ """DataFrameの列名を日本語から英語に変換する。"""
178
+ for enum in (InfoColumns, PriceColumns, StatementColumns):
179
+ df = enum.rename(df, strict=strict)
180
+ return df
@@ -1,7 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import datetime
4
+ from functools import cache
3
5
  from typing import TYPE_CHECKING, Any, Protocol
4
6
 
7
+ import holidays
5
8
  import polars as pl
6
9
 
7
10
  if TYPE_CHECKING:
@@ -21,9 +24,15 @@ if TYPE_CHECKING:
21
24
 
22
25
  def clean(df: DataFrame) -> DataFrame:
23
26
  return (
24
- df.select(pl.exclude(r"^.*\(REIT\)$"))
27
+ df.select(pl.exclude(r"^.*\(REIT\)|.*NonConsolidated.*$"))
25
28
  .rename(
26
- {"DisclosedDate": "Date", "DisclosedTime": "Time", "LocalCode": "Code"},
29
+ {
30
+ "DisclosedDate": "Date",
31
+ "DisclosedTime": "Time",
32
+ "LocalCode": "Code",
33
+ "NumberOfIssuedAndOutstandingSharesAtTheEndOfFiscalYearIncludingTreasuryStock": "NumberOfShares", # noqa: E501
34
+ "NumberOfTreasuryStockAtTheEndOfFiscalYear": "NumberOfTreasuryStock",
35
+ },
27
36
  )
28
37
  .with_columns(
29
38
  pl.col("^.*Date$").str.to_date("%Y-%m-%d", strict=False),
@@ -67,3 +76,27 @@ def _cast_bool(df: DataFrame) -> DataFrame:
67
76
  .alias(col)
68
77
  for col in columns
69
78
  )
79
+
80
+
81
+ @cache
82
+ def get_holidays(year: int | None = None, n: int = 10) -> list[datetime.date]:
83
+ """指定した過去年数の日本の祝日を取得する。"""
84
+ if year is None:
85
+ year = datetime.datetime.now().year # noqa: DTZ005
86
+
87
+ dates = holidays.country_holidays("JP", years=range(year - n, year + 1))
88
+ return sorted(dates.keys())
89
+
90
+
91
+ def update_effective_date(df: DataFrame, year: int | None = None) -> DataFrame:
92
+ """開示日が休日や15時以降の場合、翌営業日に更新する。"""
93
+ holidays = get_holidays(year=year)
94
+
95
+ cond = pl.col("Time").is_null() | (pl.col("Time") > datetime.time(15, 0))
96
+
97
+ return df.with_columns(
98
+ pl.when(cond)
99
+ .then(pl.col("Date").dt.add_business_days(1, holidays=holidays))
100
+ .otherwise(pl.col("Date"))
101
+ .alias("Date"),
102
+ )
File without changes
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Protocol
6
+
7
+ import polars as pl
8
+
9
+ if TYPE_CHECKING:
10
+ from collections.abc import (
11
+ AsyncIterable,
12
+ AsyncIterator,
13
+ Awaitable,
14
+ Callable,
15
+ Iterable,
16
+ )
17
+ from typing import Any
18
+
19
+ from marimo._plugins.stateless.status import progress_bar
20
+ from polars import DataFrame
21
+ from tqdm.asyncio import tqdm
22
+
23
+ from kabukit.core.client import Client
24
+
25
+ class _Progress(Protocol):
26
+ def __call__[T](
27
+ self,
28
+ aiterable: AsyncIterable[T],
29
+ total: int | None = None,
30
+ *args: Any,
31
+ **kwargs: Any,
32
+ ) -> AsyncIterator[T]: ...
33
+
34
+
35
+ MAX_CONCURRENCY = 12
36
+
37
+
38
+ async def collect[R](
39
+ awaitables: Iterable[Awaitable[R]],
40
+ /,
41
+ max_concurrency: int | None = None,
42
+ ) -> AsyncIterator[R]:
43
+ max_concurrency = max_concurrency or MAX_CONCURRENCY
44
+ semaphore = asyncio.Semaphore(max_concurrency)
45
+
46
+ async def run(awaitable: Awaitable[R]) -> R:
47
+ async with semaphore:
48
+ return await awaitable
49
+
50
+ futures = (run(awaitable) for awaitable in awaitables)
51
+
52
+ async for future in asyncio.as_completed(futures):
53
+ yield await future
54
+
55
+
56
+ async def collect_fn[T, R](
57
+ function: Callable[[T], Awaitable[R]],
58
+ args: Iterable[T],
59
+ /,
60
+ max_concurrency: int | None = None,
61
+ ) -> AsyncIterator[R]:
62
+ max_concurrency = max_concurrency or MAX_CONCURRENCY
63
+ awaitables = (function(arg) for arg in args)
64
+
65
+ async for item in collect(awaitables, max_concurrency=max_concurrency):
66
+ yield item
67
+
68
+
69
+ async def concat(
70
+ awaitables: Iterable[Awaitable[DataFrame]],
71
+ /,
72
+ max_concurrency: int | None = None,
73
+ ) -> DataFrame:
74
+ dfs = collect(awaitables, max_concurrency=max_concurrency)
75
+ dfs = [df async for df in dfs]
76
+ return pl.concat(df for df in dfs if not df.is_empty())
77
+
78
+
79
+ async def concat_fn[T](
80
+ function: Callable[[T], Awaitable[DataFrame]],
81
+ args: Iterable[T],
82
+ /,
83
+ max_concurrency: int | None = None,
84
+ ) -> DataFrame:
85
+ dfs = collect_fn(function, args, max_concurrency=max_concurrency)
86
+ dfs = [df async for df in dfs]
87
+ return pl.concat(df for df in dfs if not df.is_empty())
88
+
89
+
90
+ type Callback = Callable[[DataFrame], DataFrame | None]
91
+ type Progress = type[progress_bar[Any] | tqdm[Any]] | _Progress
92
+
93
+
94
+ @dataclass
95
+ class Stream:
96
+ cls: type[Client]
97
+ resource: str
98
+ args: list[Any]
99
+ max_concurrency: int | None = None
100
+
101
+ async def __aiter__(self) -> AsyncIterator[DataFrame]:
102
+ async with self.cls() as client:
103
+ fn = getattr(client, f"get_{self.resource}")
104
+
105
+ async for df in collect_fn(fn, self.args, self.max_concurrency):
106
+ yield df
107
+
108
+
109
+ async def fetch(
110
+ cls: type[Client],
111
+ resource: str,
112
+ args: Iterable[Any],
113
+ /,
114
+ max_concurrency: int | None = None,
115
+ progress: Progress | None = None,
116
+ callback: Callback | None = None,
117
+ ) -> DataFrame:
118
+ """各種データを取得し、単一のDataFrameにまとめて返す。
119
+
120
+ Args:
121
+ cls (type[Client]): 使用するClientクラス。
122
+ JQuantsClientやEdinetClientなど、Clientを継承したクラス
123
+ resource (str): 取得するデータの種類。Clientのメソッド名から"get_"を
124
+ 除いたものを指定する。
125
+ args (Iterable[Any]): 取得対象の引数のリスト。
126
+ max_concurrency (int | None, optional): 同時に実行するリクエストの最大数。
127
+ 指定しないときはデフォルト値が使用される。
128
+ progress (Progress | None, optional): 進捗表示のための関数。
129
+ tqdm, marimoなどのライブラリを使用できる。
130
+ 指定しないときは進捗表示は行われない。
131
+ callback (Callback | None, optional): 各DataFrameに対して適用する
132
+ コールバック関数。指定しないときはそのままのDataFrameが使用される。
133
+
134
+ Returns:
135
+ DataFrame:
136
+ すべての情報を含む単一のDataFrame。
137
+ """
138
+ args = list(args)
139
+ stream = Stream(cls, resource, args, max_concurrency)
140
+
141
+ if progress:
142
+ stream = progress(aiter(stream), total=len(args))
143
+
144
+ if callback:
145
+ stream = (x if (r := callback(x)) is None else r async for x in stream)
146
+
147
+ dfs = [df async for df in stream if not df.is_empty()]
148
+ return pl.concat(dfs) if dfs else pl.DataFrame()