koleo-cli 0.2.137.17__py3-none-any.whl → 0.2.137.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of koleo-cli might be problematic. Click here for more details.

koleo/args.py ADDED
@@ -0,0 +1,279 @@
1
+ from argparse import ArgumentParser
2
+ from datetime import datetime
3
+ from asyncio import run
4
+ from inspect import isawaitable
5
+
6
+ from .api import KoleoAPI
7
+ from .cli import CLI
8
+ from .storage import DEFAULT_CONFIG_PATH, Storage
9
+ from .utils import RemainderString, parse_datetime
10
+
11
+
12
def _add_station_argument(parser: ArgumentParser, dest: str = "station") -> None:
    """Register an optional multi-word station-name positional on *parser*.

    Words are collected with ``nargs="*"`` and handed to ``RemainderString``.
    The default is ``None`` so ``main`` can fall back to the favourite station.
    """
    parser.add_argument(
        dest,
        help="The station name",
        default=None,
        nargs="*",
        action=RemainderString,
    )


def _add_train_arguments(parser: ArgumentParser) -> None:
    """Register the ``brand`` and multi-word ``name`` positionals that identify a train."""
    parser.add_argument("brand", help="The brand name", type=str)
    parser.add_argument("name", help="The train name", nargs="+", action=RemainderString)


def _add_date_argument(parser: ArgumentParser, help_text: str, default: datetime) -> None:
    """Register ``-d/--date``, parsed by ``parse_datetime`` and defaulting to *default*."""
    parser.add_argument("-d", "--date", help=help_text, type=parse_datetime, default=default)


def main():
    """Run the ``koleo`` command line.

    Builds the argument parser, loads ``Storage`` from ``--config``, wires a
    ``KoleoAPI`` client and the storage into ``CLI`` and dispatches to the
    selected view. Every subcommand sets ``func`` (the view to run) and
    ``pass_`` (names of parsed arguments forwarded to it as keyword arguments).
    Without a subcommand, departures for the favourite station are shown, or
    the help text when no favourite station is saved. Afterwards the storage
    is saved when ``storage.dirty`` is set.
    """
    cli = CLI()
    # Evaluated once and shared as the default of every ``--date`` option.
    now = datetime.now()

    parser = ArgumentParser("koleo", description="Koleo CLI")
    parser.add_argument("-c", "--config", help="Custom config path.", default=DEFAULT_CONFIG_PATH)
    parser.add_argument("--nocolor", help="Disable color output and formatting", action="store_true", default=False)
    subparsers = parser.add_subparsers(title="actions", required=False)  # type: ignore

    departures = subparsers.add_parser(
        "departures", aliases=["d", "dep", "odjazdy", "o"], help="Allows you to list station departures"
    )
    _add_station_argument(departures)
    _add_date_argument(departures, "the departure date", now)
    departures.add_argument("-s", "--save", help="save the station as your default one", action="store_true")
    departures.set_defaults(func=cli.full_departures_view, pass_=["station", "date"])

    arrivals = subparsers.add_parser(
        "arrivals", aliases=["a", "arr", "przyjazdy", "p"], help="Allows you to list station arrivals"
    )
    _add_station_argument(arrivals)
    _add_date_argument(arrivals, "the arrival date", now)
    arrivals.add_argument("-s", "--save", help="save the station as your default one", action="store_true")
    arrivals.set_defaults(func=cli.full_arrivals_view, pass_=["station", "date"])

    all_trains = subparsers.add_parser(
        "all", aliases=["w", "wszystkie", "all_trains", "pociagi"], help="Allows you to list all station trains"
    )
    _add_station_argument(all_trains)
    _add_date_argument(all_trains, "the date", now)
    all_trains.add_argument("-s", "--save", help="save the station as your default one", action="store_true")
    all_trains.set_defaults(func=cli.all_trains_view, pass_=["station", "date"])

    train_route = subparsers.add_parser(
        "trainroute",
        aliases=["r", "tr", "t", "poc", "pociąg"],
        help="Allows you to check the train's route",
    )
    _add_train_arguments(train_route)
    _add_date_argument(train_route, "the date", now)
    train_route.add_argument(
        "-c",
        "--closest",
        help="ignores date, fetches closest date from the train calendar",
        action="store_true",
        default=False,
    )
    train_route.add_argument(
        "-s", "--show_stations", help="limit the result to A->B", action="extend", nargs=2, type=str, default=None
    )
    train_route.set_defaults(func=cli.train_info_view, pass_=["brand", "name", "date", "closest", "show_stations"])

    train_calendar = subparsers.add_parser(
        "traincalendar",
        aliases=["kursowanie", "tc", "k"],
        help="Allows you to check what days the train runs on",
    )
    _add_train_arguments(train_calendar)
    train_calendar.set_defaults(func=cli.train_calendar_view, pass_=["brand", "name"])

    train_detail = subparsers.add_parser(
        "traindetail",
        aliases=["td", "tid", "id", "idpoc"],
        help="Allows you to show the train's route given its koleo ID",
    )
    train_detail.add_argument(
        "-s", "--show_stations", help="limit the result to A->B", action="extend", nargs=2, type=str, default=None
    )
    train_detail.add_argument("train_id", help="The koleo ID", type=int)
    train_detail.set_defaults(func=cli.train_detail_view, pass_=["train_id", "show_stations"])

    stations = subparsers.add_parser(
        "stations", aliases=["s", "find", "f", "stacje", "ls", "q"], help="Allows you to find stations by their name"
    )
    _add_station_argument(stations, "query")
    stations.add_argument(
        "-t",
        "--type",
        help="filter results by type[rail, bus, group]",
        type=str,
        default=None,
    )
    stations.add_argument(
        "-c",
        "--country",
        help="filter results by country code[pl, de, ...]",
        type=str,
        default=None,
    )
    stations.set_defaults(func=cli.find_station_view, pass_=["query", "type", "country"])

    connections = subparsers.add_parser(
        "connections",
        aliases=["do", "z", "szukaj", "path"],
        help="Allows you to search for connections from a to b",
    )
    connections.add_argument("start", help="The starting station", type=str)
    connections.add_argument("end", help="The end station", type=str)
    _add_date_argument(connections, "the date", now)
    connections.add_argument(
        "-b", "--brands", help="Brands to include", action="extend", nargs="+", type=str, default=[]
    )
    connections.add_argument(
        "-n",
        "--direct",
        help="whether or not the result should only include direct trains",
        action="store_true",
        default=False,
    )
    connections.add_argument(
        "-p",
        "--include_prices",
        help="whether or not the result should include the price",
        action="store_true",
        default=False,
    )
    connections.add_argument(
        "--only_purchasable",
        help="whether or not the result should include only purchasable connections",
        action="store_true",
        default=False,
    )
    connections.add_argument(
        "-l",
        "--length",
        help="fetch at least n connections",
        type=int,
        default=1,
    )
    connections.set_defaults(
        func=cli.connections_view,
        pass_=["start", "end", "brands", "date", "direct", "include_prices", "only_purchasable", "length"],
    )

    train_passenger_stats = subparsers.add_parser(
        "trainstats",
        aliases=["ts", "tp", "miejsca", "frekwencja"],
        help="Allows you to check seat allocation info for a train.",
    )
    _add_train_arguments(train_passenger_stats)
    _add_date_argument(train_passenger_stats, "the date", now)
    train_passenger_stats.add_argument(
        "-s", "--stations", help="A->B", action="extend", nargs=2, type=str, default=None
    )
    train_passenger_stats.add_argument(
        "-t", "--type", help="limit the result to seats of a given type", type=str, required=False
    )
    train_passenger_stats.set_defaults(
        func=cli.train_passenger_stats_view, pass_=["brand", "name", "date", "stations", "type"]
    )

    train_connection_stats = subparsers.add_parser(
        "trainconnectionstats",
        aliases=["tcs"],
        help="Allows you to check the seat allocations on the train connection given its koleo ID",
    )
    train_connection_stats.add_argument(
        "-t", "--type", help="limit the result to seats of a given type", type=str, required=False
    )
    train_connection_stats.add_argument("connection_id", help="The koleo ID", type=int)
    train_connection_stats.set_defaults(func=cli.train_connection_stats_view, pass_=["connection_id", "type"])

    aliases = subparsers.add_parser("aliases", help="Save quick aliases for station names!")
    aliases.set_defaults(func=cli.alias_list_view)
    aliases_subparser = aliases.add_subparsers()
    aliases_add = aliases_subparser.add_parser("add", aliases=["a"], help="add an alias")
    aliases_add.add_argument("alias", help="The alias")
    _add_station_argument(aliases_add)
    aliases_add.set_defaults(func=cli.alias_add_view, pass_=["alias", "station"])

    aliases_remove = aliases_subparser.add_parser("remove", aliases=["r", "rm"], help="remove an alias")
    aliases_remove.add_argument("alias", help="The alias")
    aliases_remove.set_defaults(func=cli.alias_remove_view, pass_=["alias"])

    args = parser.parse_args()

    storage = Storage.load(path=args.config)
    client = KoleoAPI()

    async def run_view(func, *view_args, **view_kwargs):
        """Call *func* (awaiting its result if it is awaitable), then close the client."""
        try:
            res = func(*view_args, **view_kwargs)
            if isawaitable(res):
                await res
        except Exception:
            # Don't leak the HTTP client when a view fails. SystemExit raised by
            # error_and_exit() is a BaseException, so it is not caught here; that
            # path already closes the client before exiting.
            await client.close()
            raise
        await client.close()

    cli.client, cli.storage = client, storage
    cli.console.no_color = args.nocolor
    cli.no_color = args.nocolor
    if hasattr(args, "station") and args.station is None:
        args.station = storage.favourite_station
    elif hasattr(args, "station") and getattr(args, "save", False):
        storage.favourite_station = args.station

    if not hasattr(args, "func"):
        if storage.favourite_station:
            run(run_view(cli.full_departures_view, storage.favourite_station, datetime.now()))
        else:
            parser.print_help()
    else:
        forwarded = getattr(args, "pass_", [])
        run(run_view(args.func, **{k: v for k, v in vars(args).items() if k in forwarded}))

    if storage.dirty:
        storage.save()
koleo/cli/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .aliases import Aliases
2
+ from .station_board import StationBoard
3
+ from .connections import Connections
4
+ from .train_info import TrainInfo
5
+ from .seats import Seats
6
+ from .stations import Stations
7
+
8
+
9
class CLI(Aliases, StationBoard, Connections, Seats, Stations):
    """Single object exposing every koleo view, assembled from the view mixins.

    Train-info views are reachable through ``Seats``, which subclasses ``TrainInfo``.
    """
koleo/cli/aliases.py ADDED
@@ -0,0 +1,15 @@
1
+ from .base import BaseCli
2
+
3
+
4
class Aliases(BaseCli):
    """Views for listing, adding and removing station-name aliases."""

    def alias_list_view(self):
        """Print a header, then one ``<index>. alias → station`` line per alias (index starts at 0)."""
        self.print(f"[bold][green]alias[/green] → [red]station[/red][/bold]:")
        rows = self.storage.aliases.items()
        for position, (alias_name, target) in enumerate(rows):
            self.print(f"{position}. [bold][green]{alias_name}[/green] → [red]{target}[/red][/bold]")

    async def alias_add_view(self, alias: str, station: str):
        """Resolve *station* via ``get_station`` and store its ``name_slug`` under *alias*."""
        resolved = await self.get_station(station)
        self.storage.add_alias(alias, resolved["name_slug"])

    def alias_remove_view(self, alias: str):
        """Remove *alias* from storage."""
        self.storage.remove_alias(alias)
koleo/cli/base.py ADDED
@@ -0,0 +1,103 @@
1
+ from koleo.api import KoleoAPI
2
+ from koleo.storage import Storage
3
+ from koleo.api.types import ExtendedStationInfo, TrainOnStationInfo, TrainStop
4
+ from koleo.utils import koleo_time_to_dt, name_to_slug, convert_platform_number
5
+
6
+ from rich.console import Console
7
+ import re
8
+
9
+
10
class BaseCli:
    """Shared state and rendering helpers for the koleo CLI views.

    Holds the ``KoleoAPI`` client, the persistent ``Storage`` and a rich
    ``Console``. Text passed to :meth:`print` uses rich markup.
    """

    def __init__(
        self,
        no_color: bool = False,
        client: KoleoAPI | None = None,
        storage: Storage | None = None,
    ) -> None:
        self._client = client
        self._storage = storage
        self.no_color = no_color
        self.console = Console(color_system="standard", no_color=no_color, highlight=False)

    def print(self, text: str, *args, **kwargs):
        """Print rich-markup *text*; whitespace-only text prints nothing.

        In ``no_color`` mode every ``[...]`` span is removed with a regex (this
        also removes literal bracketed text) and the plain result is written with
        the builtin ``print``; extra arguments are ignored in that mode.
        """
        if not text.strip():
            return
        if self.no_color:
            print(re.sub(r"\[[^\]]*\]", "", text))
        else:
            self.console.print(text, *args, **kwargs)

    async def error_and_exit(self, text: str, *args, **kwargs):
        """Print *text* as a bold red error, close the API client and exit.

        Raises:
            SystemExit: always, with status code 2.
        """
        # sys.exit is always available; the builtin ``exit`` is injected by the
        # ``site`` module and is missing under ``python -S`` or frozen builds.
        import sys

        self.print(f"[bold red]{text}[/bold red]", *args, **kwargs)
        await self.client.close()
        sys.exit(2)

    @property
    def client(self) -> KoleoAPI:
        """The API client; raises ``ValueError`` if it has not been set."""
        if self._client is None:
            raise ValueError("Client not set!")
        return self._client

    @client.setter
    def client(self, client: KoleoAPI):
        self._client = client

    @property
    def storage(self) -> Storage:
        """The persistent storage; raises ``ValueError`` if it has not been set."""
        if self._storage is None:
            raise ValueError("Storage not set!")
        return self._storage

    @storage.setter
    def storage(self, storage: Storage):
        self._storage = storage

    async def trains_on_station_table(
        self, trains: list[TrainOnStationInfo], type: int = 1, show_connection_id: bool | None = None
    ):
        """Print one line per train of a station board.

        Args:
            trains: board entries to print.
            type: ``1`` shows each train's ``departure`` time in green; any other
                value shows its ``arrival`` time in yellow.
            show_connection_id: prefix each line with the ``train_id`` of the first
                entry of ``train["stations"]``; ``None`` uses
                ``storage.show_connection_id``.

        Trains without a time for the requested direction are skipped.
        """
        if show_connection_id is None:
            show_connection_id = self.storage.show_connection_id
        brands = await self.get_brands()
        for train in trains:
            if type == 1:
                time, color = train["departure"], "green"
            else:
                time, color = train["arrival"], "yellow"
            if not time:
                # Previously an ``assert`` (stripped under -O); skip the row instead.
                continue
            brand = next((b for b in brands if b["id"] == train["brand_id"]), {}).get("logo_text")
            first_station = train["stations"][0]
            tid = f"{first_station['train_id']} " if show_connection_id else ""
            # NOTE(review): time[11:16] assumes an ISO-like "YYYY-MM-DDTHH:MM..." string — confirm.
            self.print(
                f"{tid}[bold {color}]{time[11:16]}[/bold {color}] [red]{brand}[/red] {train['train_full_name']}[purple] {first_station['name']} {self.format_position(train['platform'], train['track'])}[/purple]"
            )

    def train_route_table(self, stops: list[TrainStop]):
        """Print one line per stop: distance from the first stop, arrival - departure, station, platform.

        Distances are divided by 1000 and shown as km; an empty *stops* prints nothing.
        """
        if not stops:
            return
        first_distance = stops[0]["distance"]
        for stop in stops:
            arr = koleo_time_to_dt(stop["arrival"])
            dep = koleo_time_to_dt(stop["departure"])
            distance = stop["distance"] - first_distance
            self.print(
                f"[white underline]{distance / 1000:^5.1f}km[/white underline] [bold green]{arr.strftime('%H:%M')}[/bold green] - [bold red]{dep.strftime('%H:%M')}[/bold red] [purple]{stop['station_display_name']} {self.format_position(stop['platform'])} [/purple]"
            )

    def format_position(self, platform: str, track: str | None = None):
        """Return the display text ``platform`` or ``platform/track``.

        Unless ``storage.use_roman_numerals`` is set, *platform* is passed through
        ``convert_platform_number`` first. A missing platform renders as "".
        """
        if self.storage.use_roman_numerals:
            res = "" if platform is None else str(platform)
        else:
            res = str(convert_platform_number(platform) or "")
        if track is not None and track != "":
            res += f"/{track}"
        return res

    async def get_station(self, station: str) -> ExtendedStationInfo:
        """Resolve a station name or saved alias to station info, using the storage cache.

        Exits via :meth:`error_and_exit` when the API reports the station as not found.
        """
        if station in self.storage.aliases:
            slug = self.storage.aliases[station]
        else:
            slug = name_to_slug(station)
        try:
            return self.storage.get_cache(f"st-{slug}") or self.storage.set_cache(
                f"st-{slug}", await self.client.get_station_by_slug(slug)
            )
        except self.client.errors.KoleoNotFound:
            await self.error_and_exit(f"Station not found: [underline]{station}[/underline]")

    async def get_brands(self):
        """Return the brand list, fetching it from the API on a cache miss."""
        return self.storage.get_cache("brands") or self.storage.set_cache("brands", await self.client.get_brands())

    async def get_station_by_id(self, id: int):
        """Return station info for a numeric station *id*, cached under ``st-<id>``."""
        key = f"st-{id}"
        return self.storage.get_cache(key) or self.storage.set_cache(key, await self.client.get_station_by_id(id))
@@ -0,0 +1,142 @@
1
+ from asyncio import gather
2
+
3
+ from .base import BaseCli
4
+ from .utils import format_price
5
+
6
+ from datetime import datetime, timedelta
7
+ from koleo.api.types import ConnectionDetail
8
+ from koleo.utils import koleo_time_to_dt
9
+
10
+
11
def _format_duration(seconds: int) -> str:
    """Format *seconds* as ``<h>h<m>m``; seconds below a whole minute are dropped."""
    hours, remainder = divmod(seconds, 3600)
    return f"{hours}h{remainder // 60}m"


def _brand_logo(brands: list, brand_id) -> str | None:
    """Return ``logo_text`` of the first brand whose ``id`` equals *brand_id*, else ``None``."""
    return next((b for b in brands if b["id"] == brand_id), {}).get("logo_text")


class Connections(BaseCli):
    """View that searches for and prints connections between two stations."""

    async def connections_view(
        self,
        start: str,
        end: str,
        date: datetime,
        brands: list[str],
        direct: bool,
        include_prices: bool,
        only_purchasable: bool,
        length: int = 1,
    ):
        """Print connections from *start* to *end*, searching from *date*.

        Args:
            start, end: station names or aliases, resolved with ``get_station``.
            date: the search start time.
            brands: brand names or logo texts (case-insensitive) to restrict the
                search to; empty means all brands. Exits if none match.
            direct: forwarded to ``get_connections``.
            include_prices: fetch and show a price per connection.
            only_purchasable: forwarded to ``get_connections``.
            length: keep fetching pages until at least this many connections were
                collected or a page comes back empty.
        """
        start_station, end_station, api_brands = await gather(
            self.get_station(start), self.get_station(end), self.get_brands()
        )
        wanted = [b.lower().strip() for b in brands]
        if not wanted:
            connection_brands = {b["name"]: b["id"] for b in api_brands}
        else:
            connection_brands = {
                b["name"]: b["id"]
                for b in api_brands
                if b["name"].lower().strip() in wanted or b["logo_text"].lower().strip() in wanted
            }
            if not connection_brands:
                await self.error_and_exit(f'No brands match: [underline]{", ".join(wanted)}[/underline]')

        results = await self._fetch_connections(
            start_station,
            end_station,
            list(connection_brands.values()),
            date,
            direct,
            only_purchasable,
            length,
        )

        if include_prices:
            prices = await gather(*(self.client.get_price(c["id"]) for c in results))
            price_dict = dict(zip((c["id"] for c in results), prices))
        else:
            price_dict = {}

        link = (
            f"https://koleo.pl/rozklad-pkp/{start_station['name_slug']}/{end_station['name_slug']}"
            + f"/{date.strftime('%d-%m-%Y_%H:%M')}"
            + f"/{'all' if not direct else 'direct'}/{'-'.join(connection_brands.keys()) if wanted else 'all'}"
        )
        parts = [
            f"[bold blue][link={link}]{start_station['name']} → {end_station['name']} at {date.strftime('%H:%M %d-%m')}[/link][/bold blue]"
        ]

        for conn in results:
            arr = koleo_time_to_dt(conn["arrival"])
            dep = koleo_time_to_dt(conn["departure"])
            # total_seconds(), not .seconds: .seconds silently drops whole days.
            travel_time = int((arr - dep).total_seconds())
            date_part = f"{arr.strftime('%d-%m')} " if arr.date() != date.date() else ""
            price = price_dict.get(conn["id"])
            price_str = f" [bold red]{format_price(price)}[/bold red]" if price else ""
            parts.append(
                f"[bold green][link=https://koleo.pl/travel-options/{conn['id']}]{date_part}{dep.strftime('%H:%M')} - {arr.strftime('%H:%M')}[/bold green] {_format_duration(travel_time)} {conn['distance']}km{price_str}:[/link]"
            )

            if len(conn["trains"]) == 1:
                train = conn["trains"][0]
                brand = _brand_logo(api_brands, train["brand_id"])
                fs, fs_station = await self._train_stop(train, "start_station_id", start_station)
                ls, ls_station = await self._train_stop(train, "end_station_id", end_station)
                parts[-1] += (
                    f" [red]{brand}[/red] {train['train_full_name']}[purple] {fs_station['name']} {self.format_position(fs['platform'], fs['track'])}[/purple] - [purple]{ls_station['name']} {self.format_position(ls['platform'], ls['track'])}[/purple]"
                )
                parts.extend(f" [bold red]- {c}[/bold red]" for c in conn["constriction_info"])
                continue

            parts.extend(f" [bold red]- {c}[/bold red]" for c in conn["constriction_info"])
            previous_arrival: datetime | None = None
            for train in conn["trains"]:
                brand = _brand_logo(api_brands, train["brand_id"])

                fs, fs_station = await self._train_stop(train, "start_station_id", start_station)
                fs_dep = koleo_time_to_dt(fs["departure"])
                fs_info = f"[bold green]{fs_dep.strftime('%H:%M')} [/bold green][purple]{fs_station['name']} {self.format_position(fs['platform'], fs['track'])}[/purple]"

                ls, ls_station = await self._train_stop(train, "end_station_id", end_station)
                ls_arr = koleo_time_to_dt(ls["arrival"])
                ls_info = f"[bold green]{ls_arr.strftime('%H:%M')} [/bold green][purple]{ls_station['name']} {self.format_position(ls['platform'], ls['track'])}[/purple]"

                if previous_arrival is not None:
                    transfer = int((fs_dep - previous_arrival).total_seconds())
                    if transfer:  # a zero-minute transfer gets no line, as before
                        parts.append(f" {_format_duration(transfer)} at [purple]{fs_station['name']}[/purple]")
                previous_arrival = ls_arr
                parts.append(f" [red]{brand}[/red] {train['train_full_name']} {fs_info} - {ls_info}")
        self.print("\n".join(parts))

    async def _fetch_connections(
        self, start_station, end_station, brand_ids: list, date: datetime, direct: bool, only_purchasable: bool, length: int
    ) -> list[ConnectionDetail]:
        """Collect connection pages until at least *length* are gathered or a page is empty."""
        results: list[ConnectionDetail] = []
        fetch_date = date
        while len(results) < length:
            page = await self.client.get_connections(
                start_station["name_slug"],
                end_station["name_slug"],
                brand_ids,
                fetch_date,
                direct,
                only_purchasable,
            )
            if not page:
                break
            # NOTE(review): the next page starts 30 min + 1 s after the last departure.
            # The offset is empirical (the original comment was just "wtf") — confirm.
            fetch_date = koleo_time_to_dt(page[-1]["departure"]) + timedelta(seconds=(30 * 60) + 1)
            results.extend(page)
        return results

    async def _train_stop(self, train, station_key: str, known_station):
        """Return ``(stop, station)`` for the stop of *train* at ``train[station_key]``.

        *known_station* is reused when it is that stop's station; otherwise the
        station is looked up with ``get_station_by_id``.
        """
        stop = next((s for s in train["stops"] if s["station_id"] == train[station_key]), {})
        if stop["station_id"] == known_station["id"]:
            return stop, known_station
        return stop, await self.get_station_by_id(stop["station_id"])
koleo/cli/seats.py ADDED
@@ -0,0 +1,103 @@
1
+ from .train_info import TrainInfo
2
+ from datetime import datetime
3
+ from asyncio import gather
4
+
5
+ from koleo.api import SeatsAvailabilityResponse, SeatState
6
+
7
+ from koleo.utils import BRAND_SEAT_TYPE_MAPPING, koleo_time_to_dt
8
+ from .utils import CLASS_COLOR_MAP
9
+
10
+
11
class Seats(TrainInfo):
    """Views that show seat availability statistics for a train."""

    @staticmethod
    def _styled(text: str, style: str) -> str:
        """Wrap *text* in rich markup for *style*; return it unchanged when *style* is empty.

        An empty style would otherwise produce ``[]...[/]``. Rich prints ``[]``
        literally and rejects the bare ``[/]``, which has nothing to close.
        """
        style = style.strip()
        return f"[{style}]{text}[/{style}]" if style else text

    async def train_passenger_stats_view(
        self,
        brand: str,
        name: str,
        date: datetime,
        stations: tuple[str, str] | None = None,
        type: str | None = None,
    ):
        """Print seat statistics for train *brand* *name* running on *date*.

        Args:
            brand, name: identify the train (looked up with ``get_train_calendars``).
            date: the day the train runs.
            stations: optional ``(from, to)`` station names limiting the segment;
                defaults to the train's first and last stop.
            type: optional seat type (numeric id or name); ``None`` means all.
        """
        train_calendars = await self.get_train_calendars(brand, name)
        day = date.strftime("%Y-%m-%d")
        if not (train_id := train_calendars[0]["date_train_map"].get(day)):
            await self.error_and_exit(f"This train doesn't run on the selected date: [underline]{day}[/underline]")
        train_details = await self.client.get_train(train_id)
        if train_details["train"]["brand_id"] not in BRAND_SEAT_TYPE_MAPPING:
            await self.error_and_exit(f"Brand [underline]{brand}[/underline] is not supported.")

        stops = train_details["stops"]
        train_stops_slugs = [s["station_slug"] for s in stops]
        train_stops_by_slug = {s["station_slug"]: s for s in stops}
        if stations:
            first_station, last_station = [
                s["name_slug"] for s in await gather(*(self.get_station(st) for st in stations))
            ]
            for slug in (first_station, last_station):
                if slug not in train_stops_by_slug:
                    await self.error_and_exit(
                        f"Train [underline]{name}[/underline] doesn't stop at [underline]{slug}[/underline]"
                    )
        else:
            first_station, last_station = train_stops_slugs[0], train_stops_slugs[-1]

        connections = await self.client.get_connections(
            first_station,
            last_station,
            brand_ids=[train_details["train"]["brand_id"]],
            direct=True,
            date=koleo_time_to_dt(train_stops_by_slug[first_station]["departure"], base_date=date),
        )
        wanted_train_id = train_details["train"]["id"]
        # Default of None: a bare next() inside a coroutine would surface as RuntimeError.
        connection = next(
            (c for c in connections if c["trains"] and c["trains"][0]["train_id"] == wanted_train_id),
            None,
        )
        if connection is None:
            await self.error_and_exit(
                f"Train [underline]{name}[/underline] has no matching connection on the selected date."
            )
        connection_train = connection["trains"][0]
        if connection_train["brand_id"] not in BRAND_SEAT_TYPE_MAPPING:
            await self.error_and_exit(f"Brand [underline]{connection_train['brand_id']}[/underline] is not supported.")
        await self.show_train_header(
            train_details, train_stops_by_slug[first_station], train_stops_by_slug[last_station]
        )
        await self.train_seat_info(connection["id"], type, connection_train["brand_id"], connection_train["train_nr"])

    async def train_connection_stats_view(self, connection_id: int, type: str | None):
        """Print seat statistics for the first train of koleo connection *connection_id*."""
        connection = await self.client.get_connection(connection_id)
        train = connection["trains"][0]
        if train["brand_id"] not in BRAND_SEAT_TYPE_MAPPING:
            await self.error_and_exit(f'Brand [underline]{train["brand_id"]}[/underline] is not supported.')
        train_details = await self.client.get_train(train["train_id"])
        first_stop = next(
            (s for s in train_details["stops"] if s["station_id"] == connection["start_station_id"]), None
        )
        last_stop = next(
            (s for s in train_details["stops"] if s["station_id"] == connection["end_station_id"]), None
        )
        if first_stop is None or last_stop is None:
            await self.error_and_exit("Couldn't match the connection's stations to the train's stops.")
        await self.show_train_header(train_details, first_stop, last_stop)
        await self.train_seat_info(connection_id, type, train["brand_id"], train["train_nr"])

    async def train_seat_info(self, connection_id: int, type: str | None, brand_id: int, train_nr: int):
        """Print free/reserved/blocked seat counts per seat type of a connection.

        Args:
            connection_id: koleo connection ID queried for availability.
            type: seat type as a numeric id or as its name in
                ``BRAND_SEAT_TYPE_MAPPING[brand_id]``; ``None`` means every type.
            brand_id: brand whose seat-type mapping is used.
            train_nr: train number passed to ``get_seats_availability``.
        """
        seat_name_map = BRAND_SEAT_TYPE_MAPPING[brand_id]
        if type is None:
            types = list(seat_name_map)
        elif type.isnumeric() and int(type) in seat_name_map:
            types = [int(type)]
        elif (type_id := {v: k for k, v in seat_name_map.items()}.get(type)) is not None:
            types = [type_id]
        else:
            await self.error_and_exit(f"Invalid seat type [underline]{type}[/underline].")
            return  # unreachable: error_and_exit raises SystemExit

        res: dict[int, SeatsAvailabilityResponse] = {}
        for seat_type in types:
            res[seat_type] = await self.client.get_seats_availability(connection_id, train_nr, seat_type)

        for seat_type, result in res.items():
            counters: dict[SeatState, int] = {"FREE": 0, "RESERVED": 0, "BLOCKED": 0}
            for seat in result["seats"]:
                counters[seat["state"]] = counters.get(seat["state"], 0) + 1
            seat_name = seat_name_map[seat_type]
            color = CLASS_COLOR_MAP.get(seat_name, "")
            total = sum(counters.values())
            # Guard: a seat type with no seats would otherwise divide by zero.
            free_pct = counters["FREE"] / total * 100 if total else 0.0

            self.print(self._styled(f"{seat_name}: ", f"bold {color}"))
            free = self._styled(f"{counters['FREE']}/{total}, ~{free_pct:.1f}%", color)
            self.print(f" Free: {free}")
            reserved = self._styled(f"{counters['RESERVED']}", color)
            self.print(f" Reserved: {reserved}")
            blocked = self._styled(f"{counters['BLOCKED']}", f"underline {color}")
            self.print(f" Blocked: {blocked}")