findata-api 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,383 @@
1
+ import asyncio
2
+ import datetime
3
+ import os
4
+ import threading
5
+ import time
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import py_misc_utils.alog as alog
10
+ import py_misc_utils.assert_checks as tas
11
+ import py_misc_utils.core_utils as pycu
12
+ import py_misc_utils.date_utils as pyd
13
+ import py_misc_utils.fin_wrap as pyfw
14
+ import py_misc_utils.pd_utils as pyp
15
+ import py_misc_utils.throttle as throttle
16
+ import py_misc_utils.utils as pyu
17
+
18
+ from . import api_base
19
+ from . import api_types
20
+ from . import utils as ut
21
+
22
try:
    # Note: They have a new Python package at https://github.com/alpacahq/alpaca-py
    import alpaca_trade_api as alpaca
except ImportError:
    # Without the Alpaca SDK, MODULE_NAME is None and neither add_api_options()
    # nor create_api() is defined by this module.
    MODULE_NAME = None
else:
    MODULE_NAME = 'ALPACA'

    def add_api_options(parser):
        """Register the Alpaca command line options (key, secret, base URL) on `parser`."""
        options = (
            ('--alpaca_key', 'The Alpaca API key'),
            ('--alpaca_secret', 'The Alpaca API secret'),
            ('--alpaca_url', 'The Alpaca API base URL'),
        )
        for flag, help_text in options:
            parser.add_argument(flag, type=str, help=help_text)

    def create_api(args):
        """Build an Alpaca API object from the parsed command line arguments."""
        return API(api_key=args.alpaca_key,
                   api_secret=args.alpaca_secret,
                   api_url=args.alpaca_url,
                   api_rate=args.api_rate)
42
+
43
+
44
# Unit-name mapping handed to ut.map_data_step() when building the Alpaca bar
# timeframe in fetch_data().
_DATA_STEPS = dict(
    min='Min',
    minute='Min',
    d='D',
    day='D',
)
# Maximum number of orders requested per list_orders() page.
_FETCH_ORDERS_MAX = 500
51
+
52
+
53
def _get_config(key, secret, url):
    """Resolve the Alpaca credentials and endpoint.

    Explicit arguments win. Missing ones are read from the APCA_API_KEY_ID,
    APCA_API_SECRET_KEY and APCA_API_BASE_URL environment variables. An empty or
    missing URL falls back to the paper trading endpoint.

    Returns:
      The (key, secret, url) tuple.
    """
    if key is None:
        key = pyu.getenv('APCA_API_KEY_ID')
    if secret is None:
        secret = pyu.getenv('APCA_API_SECRET_KEY')
    if url is None:
        url = pyu.getenv('APCA_API_BASE_URL')
    if not url:
        url = 'https://paper-api.alpaca.markets'

    # Never write credentials to the logs: the key shows only a short suffix
    # (useful to tell accounts apart) and the secret is fully masked.
    alog.debug0(f'Alpaca API created with: key={_mask_credential(key)} '
                f'secret={_mask_credential(secret, visible=0)} url={url}')

    return key, secret, url


def _mask_credential(value, visible=4):
    """Return `value` with all but its last `visible` characters replaced by '*'.

    At most half of the characters are ever shown. Empty/None values are returned
    unchanged so that "not configured" remains visible in the logs.
    """
    if not value:
        return value
    text = str(value)
    shown = min(max(visible, 0), len(text) // 2)
    return '*' * (len(text) - shown) + text[len(text) - shown:]
66
+
67
+
68
def _get_df_from_bars(bars, dtype=None):
    """Build a DataFrame out of Alpaca bar entities.

    Each row is a copy of the bar's raw payload. In each row the 'S' key is
    renamed to the Tradzy standard 'symbol', and the 't' time string is converted
    to an EPOCH timestamp. When `dtype` is given, every column except 't' and
    'symbol' is cast to it.
    """
    def to_row(bar):
        raw = dict(bar.__dict__.get('_raw'))
        raw['symbol'] = raw.pop('S')
        raw['t'] = pyd.parse_date(raw['t']).timestamp()
        return raw

    df = pd.DataFrame([to_row(bar) for bar in bars])
    if dtype is not None:
        for col in pyp.get_df_columns(df, discards={'t', 'symbol'}):
            df[col] = df[col].astype(dtype)

    return df
87
+
88
+
89
def _maybe_date(dstr):
    """Parse `dstr` (stringified first) with pyd.parse_date(); None passes through."""
    if dstr is None:
        return None
    return pyd.parse_date(str(dstr))
91
+
92
+
93
def _marshal_order(o):
    """Convert an Alpaca order entity into an api_types.Order.

    Numeric fields are cast to float via pycu.cast(); timestamps go through
    _maybe_date(), so missing values stay None.
    """
    def as_float(value):
        return pycu.cast(value, float)

    return api_types.Order(
        id=o.id,
        symbol=o.symbol,
        side=o.side,
        type=o.type,
        status=o.status,
        quantity=as_float(o.qty),
        limit=as_float(o.limit_price),
        stop=as_float(o.stop_price),
        created=_maybe_date(o.created_at),
        filled=_maybe_date(o.filled_at),
        filled_quantity=as_float(o.filled_qty),
        filled_avg_price=as_float(o.filled_avg_price),
    )
106
+
107
+
108
def _marshal_position(p):
    """Convert an Alpaca position entity into an api_types.Position."""
    quantity = pycu.cast(p.qty, float)
    value = pycu.cast(p.market_value, float)
    return api_types.Position(symbol=p.symbol, quantity=quantity, value=value)
112
+
113
+
114
def _marshal_account(a):
    """Convert an Alpaca account entity into an api_types.Account."""
    buying_power = pycu.cast(a.buying_power, float)
    return api_types.Account(id=a.account_number, buying_power=buying_power)
117
+
118
+
119
def _get_stream_ts(v):
    """Turn a stream timestamp with `seconds`/`nanoseconds` fields into float seconds."""
    whole, frac_ns = v.seconds, v.nanoseconds
    return whole + frac_ns * 1e-9
121
+
122
+
123
def _marshal_stream_trade(t):
    """Convert a raw stream trade message into an api_types.StreamTrade."""
    fields = dict(symbol=t['S'], quantity=t['s'], price=t['p'])
    return api_types.StreamTrade(timestamp=_get_stream_ts(t['t']), **fields)
128
+
129
+
130
def _marshal_stream_quote(q):
    """Convert a raw stream quote message into an api_types.StreamQuote."""
    fields = dict(symbol=q['S'],
                  bid_size=q['bs'], bid_price=q['bp'],
                  ask_size=q['as'], ask_price=q['ap'])
    return api_types.StreamQuote(timestamp=_get_stream_ts(q['t']), **fields)
137
+
138
+
139
def _marshal_stream_bar(b):
    """Convert a raw stream bar message into an api_types.StreamBar."""
    ohlcv = dict(open=b['o'], high=b['h'], low=b['l'], close=b['c'], volume=b['v'])
    return api_types.StreamBar(timestamp=_get_stream_ts(b['t']), symbol=b['S'], **ohlcv)
147
+
148
+
149
class Stream:
    """Alpaca real time data stream, run on a background daemon thread.

    Raw stream messages are dispatched on their 'T' field ('q', 't', 'b') to the
    user handlers registered under the 'quotes', 'trades' and 'bars' keys.
    """

    def __init__(self, api_key, api_secret,
                 data_stream_url='https://stream.data.alpaca.markets',
                 data_feed='sip'):
        self._conn = alpaca.stream.Stream(
            api_key,
            api_secret,
            data_stream_url=data_stream_url,
            raw_data=True,
            data_feed=data_feed)

        self._handlers = dict()
        self._symbols = None
        self._bar_symbols = None
        self._lock = threading.Lock()
        # Create the loop here rather than inside the thread. This way stop() always
        # has a valid loop to schedule onto, even if it is called before the thread
        # gets to run.
        self._thread_loop = asyncio.new_event_loop()
        self._stopping = False

        self._stream_thread = threading.Thread(target=self._stream_thread_fn, daemon=True)
        self._stream_thread.start()

    async def _stream_handler(self, d):
        # Take a single snapshot of the mapping: register() may replace it from
        # another thread.
        handlers = self._handlers

        kind = d.get('T')
        if kind == 'q':
            handler, marshal = handlers.get('quotes'), _marshal_stream_quote
        elif kind == 't':
            handler, marshal = handlers.get('trades'), _marshal_stream_trade
        elif kind == 'b':
            handler, marshal = handlers.get('bars'), _marshal_stream_bar
        else:
            return

        if handler is not None:
            handler(marshal(d))

    def _stream_thread_fn(self):
        asyncio.set_event_loop(self._thread_loop)
        try:
            self._conn.run()
        except Exception as e:
            alog.exception(e, exmsg='Exception while running the stream thread loop')

        alog.info('Stream thread exiting run loop')

    def stop(self):
        """Ask the stream to close and wait for its thread. Only the first call acts."""
        if not self._stopping:
            alog.debug0('Stopping Alpaca stream')
            self._stopping = True
            asyncio.run_coroutine_threadsafe(self._conn.stop_ws(), self._thread_loop)
            self._stream_thread.join()

    def register(self, symbols, handlers):
        """Replace the current subscriptions and handlers.

        Args:
          symbols: Symbols to subscribe trades and quotes for (empty/None drops
            all subscriptions). Bars are subscribed as well when a 'bars' handler
            is given.
          handlers: Dict mapping 'quotes'/'trades'/'bars' to callables.
        """
        with self._lock:
            if self._symbols:
                self._conn.unsubscribe_trades(*self._symbols)
                self._conn.unsubscribe_quotes(*self._symbols)
                self._symbols = None
            if self._bar_symbols:
                self._conn.unsubscribe_bars(*self._bar_symbols)
                self._bar_symbols = None

            self._handlers = handlers

            if symbols:
                self._conn.subscribe_trades(self._stream_handler, *symbols)
                self._conn.subscribe_quotes(self._stream_handler, *symbols)
                self._symbols = list(symbols)
                # Without this subscription a 'bars' handler could never fire.
                if handlers.get('bars') is not None:
                    self._conn.subscribe_bars(self._stream_handler, *symbols)
                    self._bar_symbols = list(symbols)
218
+
219
+
220
+
221
class API(api_base.TradeAPI):
    """Alpaca trading and market data API.

    Every REST call goes through a throttle configured from `api_rate` (calls per
    minute, default 200).
    """

    def __init__(self, api_key=None, api_secret=None, api_url=None, api_rate=None,
                 symbols_per_step=20, data_stream_url='https://stream.data.alpaca.markets',
                 data_feed='sip'):
        super().__init__(name='Alpaca', supports_streaming=True)
        self._api_key, self._api_secret, self._api_url = _get_config(api_key, api_secret, api_url)
        self._api = alpaca.REST(self._api_key, self._api_secret, self._api_url)
        # api_rate is in calls per minute; dividing by 60 gives calls per second.
        self._api_throttle = throttle.Throttle(
            (200 if api_rate is None else api_rate) / 60.0)
        # Number of symbols requested per get_bars() call in fetch_data().
        self._symbols_per_step = symbols_per_step
        self._data_stream_url = data_stream_url
        self._data_feed = data_feed
        self._stream = None

    def register_stream_handlers(self, symbols, handlers):
        """Replace any running real time stream with a new one for `symbols`.

        Args:
          symbols: Symbols to stream. Empty/None just stops streaming.
          handlers: Dict mapping 'quotes', 'trades' and/or 'bars' to callables
            receiving the marshaled stream records.
        """
        if self._stream is not None:
            alog.debug1('Stopping previous real time stream')
            self._stream.stop()

        if symbols:
            alog.debug1(f'Registering Streaming: handlers={tuple(handlers.keys())}\tsymbols={symbols}')

            stream = Stream(self._api_key, self._api_secret,
                            data_stream_url=self._data_stream_url,
                            data_feed=self._data_feed)
            # Stores the stream as self._stream, with stream.stop() as its
            # finalizer function.
            pyfw.fin_wrap(self, '_stream', stream, finfn=stream.stop)
            self._stream.register(symbols, handlers)

            alog.debug1('Registration done!')
        else:
            pyfw.fin_wrap(self, '_stream', None)

    def get_account(self):
        """Return the Alpaca account as an api_types.Account."""
        with self._api_throttle.trigger():
            account = self._api.get_account()

        return _marshal_account(account)

    def get_market_hours(self, dt):
        """Return the (open, close) datetimes of the trading session on `dt`'s day.

        The day is taken in the New York market timezone, and the returned
        datetimes are in that timezone. Returns (None, None) when Alpaca reports
        no calendar entry for the day (presumably a non-trading day).
        """
        dtz = dt.astimezone(pyd.ny_market_timezone())
        dts = dtz.strftime('%Y-%m-%d')
        with self._api_throttle.trigger():
            calendar = self._api.get_calendar(start=dts, end=dts)
        if not calendar:
            return None, None

        day = calendar[0]
        market_open = dtz.replace(hour=day.open.hour, minute=day.open.minute,
                                  second=0, microsecond=0)
        market_close = dtz.replace(hour=day.close.hour, minute=day.close.minute,
                                   second=0, microsecond=0)

        return market_open, market_close

    def submit_order(self, symbol, quantity, side, type='market', limit=None, stop=None):
        """Submit an order and return it as an api_types.Order."""
        with self._api_throttle.trigger():
            order = self._api.submit_order(symbol, qty=quantity, side=side, type=type,
                                           limit_price=limit, stop_price=stop)

        return _marshal_order(order)

    def get_order(self, oid):
        """Return the order with id `oid` as an api_types.Order."""
        with self._api_throttle.trigger():
            order = self._api.get_order(oid)

        return _marshal_order(order)

    def _fetch_orders(self, limit=None, status='all', start_date=None, end_date=None):
        """Fetch one page of orders (at most `limit`, default _FETCH_ORDERS_MAX).

        The dates are sent as ISO strings through Alpaca's `after`/`until`.
        """
        after = start_date.isoformat() if start_date is not None else None
        until = end_date.isoformat() if end_date is not None else None
        with self._api_throttle.trigger():
            orders = self._api.list_orders(limit=limit or _FETCH_ORDERS_MAX,
                                           status=status,
                                           after=after,
                                           until=until)

        return [_marshal_order(o) for o in orders]

    def _dedup_timefilter_orders(self, orders, start_date=None, end_date=None):
        """Keep orders created within [start_date, end_date], one per id (last
        seen wins), sorted by creation time."""
        od = dict()
        for order in orders:
            if start_date is not None and order.created < start_date:
                continue
            if end_date is not None and order.created > end_date:
                continue
            od[order.id] = order

        return sorted(od.values(), key=lambda x: x.created)

    def list_orders(self, limit=None, status='all', start_date=None, end_date=None):
        """List the orders created within [start_date, end_date].

        Args:
          limit: If given, only the most recent `limit` orders are returned.
          status: Alpaca order status filter ('all' by default).
          start_date: Range start. Defaults to midnight of `end_date`'s day.
          end_date: Range end. Defaults to pyd.now().

        Returns:
          The matching api_types.Order objects, deduplicated and sorted by
          creation time (oldest first).
        """
        if end_date is None:
            end_date = pyd.now()
        if start_date is None:
            start_date = end_date.replace(hour=0, minute=0, second=0, microsecond=0)

        edate = end_date
        orders = []
        while True:
            xorders = self._fetch_orders(limit=_FETCH_ORDERS_MAX, status=status,
                                         start_date=start_date, end_date=edate)
            orders.extend(xorders)
            if len(xorders) < _FETCH_ORDERS_MAX:
                break
            # A full page may have more orders behind it. Move the window end back
            # to just past the oldest order seen: the 1s overlap avoids missing
            # orders sharing that second, and the duplicates are removed below.
            next_edate = min(o.created for o in orders) + datetime.timedelta(seconds=1)
            # Stop when the window cannot shrink any further. Otherwise a page full
            # of orders created within the same second would be re-fetched forever.
            if next_edate <= start_date or next_edate >= edate:
                break
            edate = next_edate

        orders = self._dedup_timefilter_orders(orders,
                                               start_date=start_date,
                                               end_date=end_date)

        return orders if limit is None else orders[-limit:]

    def cancel_order(self, oid):
        """Cancel the order with id `oid`."""
        with self._api_throttle.trigger():
            self._api.cancel_order(oid)

    def list_positions(self):
        """Return the open positions as api_types.Position objects."""
        with self._api_throttle.trigger():
            positions = self._api.list_positions()

        return [_marshal_position(p) for p in positions]

    def _break_date_range(self, start_date, end_date, data_step):
        """Split [start_date, end_date] into windows for get_bars().

        Windows are `limit` days long for daily (or coarser) steps, and `limit`
        minutes long otherwise. `limit` comes from APCA_LIMIT (default 5000).

        Returns:
          The (windows, limit) tuple.
        """
        # At Alpaca level API we need to use an hard limit, and break time range.
        limit = pyu.getenv('APCA_LIMIT', dtype=int, defval=5000)

        dstep = ut.get_data_step_delta(data_step)
        if dstep >= datetime.timedelta(days=1):
            range_step = limit * datetime.timedelta(days=1)
        else:
            range_step = limit * datetime.timedelta(minutes=1)
        # NOTE(review): fetch_data() passes `limit` to get_bars() for a batch of
        # symbols. If the SDK applies it to the total across symbols, windows sized
        # for `limit` bars of ONE symbol may be truncated -- confirm.

        tsteps = ut.break_period_in_dates_list(start_date, end_date, range_step)

        return tsteps, limit

    def fetch_data(self, symbols, start_date, end_date, data_step='5Min', dtype=None):
        """Fetch bars for `symbols` between `start_date` and `end_date`.

        The range is split by _break_date_range(), and symbols are requested in
        batches of `symbols_per_step`.

        Returns:
          A DataFrame built by _get_df_from_bars() and filtered through
          ut.purge_fetched_data(), or None when nothing was fetched.
        """
        tsteps, limit = self._break_date_range(start_date, end_date, data_step)
        # Loop invariant: the Alpaca timeframe for the requested step.
        timeframe = ut.map_data_step(data_step, _DATA_STEPS)

        dfs = []
        for tstart, tend in tsteps:
            start = tstart.isoformat()
            end = tend.isoformat()

            alog.debug0(f'Fetch: start={start}\tend={end}')
            for srange in range(0, len(symbols), self._symbols_per_step):
                step_symbols = symbols[srange: srange + self._symbols_per_step]
                with self._api_throttle.trigger():
                    bars = self._api.get_bars(step_symbols, timeframe,
                                              limit=limit,
                                              start=start,
                                              end=end)
                bsdf = _get_df_from_bars(bars, dtype=dtype)
                if not bsdf.empty:
                    dfs.append(bsdf)

        df = pd.concat(dfs, ignore_index=True) if dfs else None
        if df is not None:
            df = ut.purge_fetched_data(df, start_date, end_date, data_step)

        alog.debug0(f'Fetched {len(df) if df is not None else 0} records')

        return df
@@ -0,0 +1,146 @@
1
+ import datetime
2
+ import dateutil
3
+ import io
4
+ import os
5
+ import requests
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import py_misc_utils.alog as alog
10
+ import py_misc_utils.assert_checks as tas
11
+ import py_misc_utils.date_utils as pyd
12
+ import py_misc_utils.throttle as throttle
13
+ import py_misc_utils.utils as pyu
14
+
15
+ from . import api_base
16
+ from . import utils as ut
17
+
18
+
19
MODULE_NAME = 'ALPHA_VANTAGE'


def add_api_options(parser):
    """Register the Alpha Vantage command line options on `parser`."""
    parser.add_argument(
        '--alpha_vantage_key',
        type=str,
        help='The Alpha Vantage API key',
    )


def create_api(args):
    """Build an Alpha Vantage API object from the parsed command line arguments."""
    key, rate = args.alpha_vantage_key, args.api_rate
    return API(api_key=key, api_rate=rate)
28
+
29
+
30
# Alpha Vantage query endpoint.
_AV_QUERY_URL = 'https://www.alphavantage.co/query'
# Accepted names for the timestamp column of the CSV responses.
_TIME_COLUMNS = set(('time', 'timestamp'))
# Columns a CSV response must carry to be treated as price data.
_RESP_COLUMNS = set(('open', 'high', 'low', 'close', 'volume'))
33
+
34
+
35
def _issue_request(func, **kwargs):
    """Issue an Alpha Vantage query that requests CSV output.

    Args:
      func: The Alpha Vantage `function` parameter (e.g. 'TIME_SERIES_INTRADAY').
      kwargs: Extra query parameters. `timeout` and `api_key` are consumed here;
        everything else is sent as-is. `timeout` defaults to the
        FINDATA_TIMEOUT environment variable, else 90.

    Returns:
      A (csv_text, columns, time_column) tuple when the response carries the
      OHLCV columns, None otherwise.

    Raises:
      Whatever tas.check_eq() raises on a non-200 HTTP status, or tas.check()
      when the CSV lacks a time column.
    """
    if 'timeout' in kwargs:
        timeout = kwargs.pop('timeout')
    else:
        # Same lookup form used elsewhere in the package. dtype=float makes sure
        # requests gets a number, not a raw environment string.
        timeout = pyu.getenv('FINDATA_TIMEOUT', dtype=float, defval=90)
    api_key = kwargs.pop('api_key', None)
    params = dict(apikey=api_key, function=func, datatype='csv')
    params.update(kwargs)

    resp = requests.get(_AV_QUERY_URL, params=params, timeout=timeout)

    tas.check_eq(resp.status_code, 200, msg=f'Request error {resp.status_code}:\n{resp.text}')

    cols = ut.csv_parse_columns(resp.text)
    scols = set(cols)
    if not all(c in scols for c in _RESP_COLUMNS):
        # No OHLCV columns: presumably an error or notice payload rather than data.
        # Log it instead of dropping it silently; callers treat None as "no data".
        alog.info(f'Unexpected Alpha Vantage response for {func} '
                  f'{kwargs.get("symbol")}: {resp.text[:256]}')
        return None

    time_columns = tuple(scols & _TIME_COLUMNS)
    tas.check(time_columns, msg=f'Missing {_TIME_COLUMNS} column in response data: {cols}')

    return resp.text, cols, time_columns[0]
52
+
53
+
54
def _parse_datetime(s):
    """Parse an Alpha Vantage (New York local) timestamp into a UTC np.datetime64[ms].

    Naive values are localized with the America/New_York rules, so the offset is
    -04:00 under daylight saving time and -05:00 otherwise. A fixed -0400 offset
    would put every winter timestamp one hour off. Values that already carry a
    UTC offset keep it.
    """
    # Import the submodules explicitly: `import dateutil` alone does not guarantee
    # that `dateutil.parser` / `dateutil.tz` are loaded.
    import dateutil.parser
    import dateutil.tz

    dt = dateutil.parser.parse(str(s))
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=dateutil.tz.gettz('America/New_York'))
    # Convert to naive UTC first: numpy deprecates tz-aware datetime inputs.
    utc = dt.astimezone(datetime.timezone.utc).replace(tzinfo=None)

    return np.datetime64(utc, 'ms')
57
+
58
+
59
def _data_issue_request(func, **kwargs):
    """Run an Alpha Vantage request and load its CSV payload into a DataFrame.

    `dtype` (default np.float32) is consumed here and applied to the OHLCV
    columns. The remaining kwargs go to _issue_request(). In the result, the
    columns are renamed to o/h/l/c/v, 't' holds EPOCH timestamps, and a 'symbol'
    column is added when a symbol was requested. Returns None when the request
    yields no data.
    """
    dtype = kwargs.pop('dtype', np.float32)
    symbol = kwargs.get('symbol')

    rresp = _issue_request(func, **kwargs)
    if rresp is None:
        return None

    data, cols, time_col = rresp
    column_types = dict.fromkeys(_RESP_COLUMNS, dtype)
    df = pd.read_csv(io.StringIO(data),
                     dtype=column_types,
                     parse_dates=[time_col] if time_col else True)

    renames = {'open': 'o', 'close': 'c', 'low': 'l', 'high': 'h', 'volume': 'v',
               time_col: 't'}
    df = df.rename(columns=renames)
    if 't' in df:
        epochs = [pyd.np_datetime_to_epoch(_parse_datetime(ts)) for ts in df['t']]
        df['t'] = epochs
    if symbol:
        df['symbol'] = [symbol] * len(df)

    alog.debug0(f'Fetched {len(df)} rows from AV for {symbol}')

    return df
86
+
87
+
88
def _enumerate_months(start_date, end_date):
    """Yield (month, year) for every calendar month from `start_date`'s month
    through `end_date`'s month, inclusive. Yields nothing when start is after end."""
    first = start_date.year * 12 + (start_date.month - 1)
    last = end_date.year * 12 + (end_date.month - 1)
    for index in range(first, last + 1):
        year, month0 = divmod(index, 12)
        yield month0 + 1, year
102
+
103
+
104
class API(api_base.API):
    """Alpha Vantage market data API (intraday time series).

    https://www.alphavantage.co/documentation/#time-series-data
    """

    def __init__(self, api_key=None, api_rate=None):
        super().__init__(name='AlphaVantage')
        self._api_key = api_key or pyu.getenv('ALPHA_VANTAGE_KEY')
        # Default budget of 5 calls per minute, converted to calls per second.
        self._api_throttle = throttle.Throttle(
            (5 if api_rate is None else api_rate) / 60.0)

    def _get_tsi_data(self, symbols, data_step='5Min', month=None, dtype=None):
        """Fetch TIME_SERIES_INTRADAY data, one request per symbol.

        Args:
          symbols: The symbols to fetch.
          data_step: The bar interval (e.g. '5Min'), sent lower-cased.
          month: Optional 'YYYY-MM' month to fetch (None means latest).
          dtype: Optional dtype for the OHLCV columns. When None, the default of
            _data_issue_request() is used.

        Returns:
          The list of non-empty per-symbol DataFrames. Missing data is logged and
          skipped.
        """
        # Only forward dtype when set, so the request helper keeps its own default.
        extra = dict() if dtype is None else dict(dtype=dtype)
        dfs = []
        for symbol in symbols:
            alog.debug0(f'Fetching data for {symbol} with {data_step} interval for month {month or "LATEST"}')

            with self._api_throttle.trigger():
                df = _data_issue_request('TIME_SERIES_INTRADAY',
                                         api_key=self._api_key,
                                         symbol=symbol,
                                         interval=data_step.lower(),
                                         month=month,
                                         outputsize='full',
                                         **extra)

            if df is None or df.empty:
                alog.info(f'Missing data for "{symbol}" with {data_step} interval for month {month or "LATEST"}')
            else:
                dfs.append(df)

        return dfs

    def fetch_data(self, symbols, start_date, end_date, data_step='5Min', dtype=None):
        """Fetch intraday data for `symbols`, month by month, over the date range.

        Returns:
          The concatenated data filtered through ut.purge_fetched_data(), or None
          when no data was returned for any symbol/month.
        """
        alog.debug0(f'Fetch: start={start_date}\tend={end_date}')

        dfs = []
        for month, year in _enumerate_months(start_date, end_date):
            dfs.extend(self._get_tsi_data(symbols,
                                          data_step=data_step,
                                          month=f'{year}-{month:02d}',
                                          dtype=dtype))

        if not dfs:
            # pd.concat() rejects an empty list; report "no data" as None instead.
            return None

        df = pd.concat(dfs, ignore_index=True)

        return ut.purge_fetched_data(df, start_date, end_date, data_step)
146
+
@@ -0,0 +1,49 @@
1
+ import py_misc_utils.state as pyst
2
+
3
+ from . import order_tracker
4
+
5
+
6
class API(pyst.StateBase):
    """Common base class of the data API adapters.

    Attributes:
      name: Display name of the API.
      supports_streaming: Whether the adapter offers real time streaming.
      supports_trading: Whether the adapter offers order management.
    """

    def __init__(self, name=None, supports_streaming=False, supports_trading=False):
        super().__init__()
        self.name, self.supports_streaming, self.supports_trading = (
            name, supports_streaming, supports_trading)

    def close(self):
        """Release held resources. The base class holds none."""
        return None

    def range_supported(self, start_date, end_date, data_step):
        """Tell whether data for this range and step can be served.

        The base class accepts everything; subclasses presumably override this
        to restrict.
        """
        return True
19
+
20
+
21
class TradeAPI(API):
    """Base class for trading capable APIs.

    Owns an order_tracker.OrderTracker and exposes its scheduler.
    """

    def __init__(self, scheduler=None, refresh_time=None, **kwargs):
        super().__init__(supports_trading=True, **kwargs)

        # Persist refresh_time in the object state, so _set_state() can rebuild
        # the tracker with it.
        self._store_state(__class__, refresh_time=refresh_time)

        self.tracker = order_tracker.OrderTracker(self,
                                                  scheduler=scheduler,
                                                  refresh_time=refresh_time)
        self.scheduler = self.tracker.scheduler

    def _get_state(self, state):
        # The tracker and scheduler are not saved; _set_state() rebuilds them.
        cstate = API._get_state(self, state)
        cstate.pop('tracker')
        cstate.pop('scheduler')

        return cstate

    def _set_state(self, state):
        # _get_state() drops 'scheduler', so it is present only when the restorer
        # injects one. Otherwise fall back to the same default as __init__ (None).
        scheduler = state.pop('scheduler', None)
        refresh_time = self._load_state(__class__, state, 'refresh_time')

        API._set_state(self, state)
        self.tracker = order_tracker.OrderTracker(self,
                                                  scheduler=scheduler,
                                                  refresh_time=refresh_time)
        self.scheduler = self.tracker.scheduler
49
+
@@ -0,0 +1,131 @@
1
+ import argparse
2
+ import collections
3
+ import copy
4
+ import importlib
5
+ import os
6
+ import threading
7
+
8
+ import py_misc_utils.alog as alog
9
+ import py_misc_utils.cleanups as cleanups
10
+ import py_misc_utils.dyn_modules as pydm
11
+ import py_misc_utils.global_namespace as gns
12
+ import py_misc_utils.module_utils as pymu
13
+ import py_misc_utils.utils as pyu
14
+
15
+
16
def _detect_apis():
    """Discover the sibling `*_api` modules and order them by preference.

    The preference comes from the comma separated FINDATA_API_ORDER environment
    variable (default 'finnhub,yfinance,polygon,alpha_vantage,alpaca'). Listed
    modules come first, in listed order. Unlisted ones follow, in the order
    module_names() returned them.

    Returns:
      The (loader, ordered_module_names) tuple.
    """
    parent, _ = pymu.split_module_name(__name__)

    apis = pydm.DynLoader(modname=parent, postfix='_api')
    module_names = apis.module_names()

    preferred = [name.strip() for name in os.getenv(
        'FINDATA_API_ORDER',
        'finnhub,yfinance,polygon,alpha_vantage,alpaca').split(',')]
    # Rank by position in the preference list. Unlisted modules get a rank past
    # every listed one, however long the list is relative to module_names.
    rank = {name: i for i, name in enumerate(preferred)}
    unlisted = len(preferred)

    # sorted() is stable, so unlisted modules keep their discovery order.
    ordered_modules = sorted(module_names, key=lambda x: rank.get(x, unlisted))

    return apis, tuple(ordered_modules)
31
+
32
+
33
# Loader of the discovered API modules, and their names in preference order.
_APIS, _API_NAMES = _detect_apis()
# Parsed command line arguments, installed by setup_api().
_ARGS = None


def setup_api(args):
    """Install the parsed command line `args` used by create_api() and select_api()."""
    global _ARGS

    _ARGS = args
39
+
40
+
41
def add_api_options(parser):
    """Register --api and --api_rate, plus the options of every API module."""
    parser.add_argument('--api', type=str,
                        choices=_API_NAMES,
                        help='The API to use')
    parser.add_argument('--api_rate', type=float,
                        default=pyu.getenv('API_RATE', dtype=float),
                        help='The maximum number of API calls per minute')

    for mod in _APIS.modules():
        # Modules may define this hook conditionally (e.g. the Alpaca module only
        # does so when alpaca_trade_api imports), so skip modules without one.
        add_options = getattr(mod, 'add_api_options', None)
        if add_options is not None:
            add_options(parser)
51
+
52
+
53
class _ApiCache:
    """Thread safe cache of API instances, keyed by API name.

    clear() is registered with py_misc_utils.cleanups at construction. It drops
    every cached instance and calls its close().
    """

    def __init__(self):
        self._cache = dict()
        self._lock = threading.Lock()
        # Keep the id returned by the cleanup registration.
        self._cid = cleanups.register(self.clear)

    def get(self, mod, name, args):
        """Return the cached API for `name`; on a miss, build it with mod.create_api(args)."""
        with self._lock:
            api = self._cache.get(name)
            if api is not None:
                return api
            api = self._cache[name] = mod.create_api(args)

        return api

    def clear(self):
        """Empty the cache, then close the evicted APIs outside the lock."""
        with self._lock:
            evicted, self._cache = tuple(self._cache.values()), dict()

        for api in evicted:
            api.close()
76
+
77
+
78
# Global holder of the _ApiCache. fork_init=True presumably gives forked
# children a fresh cache instead of the parent's instances.
_API_CACHE = gns.Var(f'{__name__}.API_CACHE',
                     defval=lambda: _ApiCache(),
                     fork_init=True)


def _api_cache():
    """Return the _ApiCache instance held by the _API_CACHE global variable."""
    cache = gns.get(_API_CACHE)
    return cache
84
+
85
+
86
def _merged_args(sargs, nargs):
    """Return `sargs` overridden by the `nargs` mapping.

    With an empty/None `nargs`, `sargs` itself is returned. Otherwise a shallow
    copy carrying the overrides is returned, and `sargs` is left untouched.
    """
    if not nargs:
        return sargs

    merged = copy.copy(sargs)
    for key, value in nargs.items():
        setattr(merged, key, value)

    return merged
95
+
96
+
97
def create_api(name=None, create=False, args=None):
    """Return an API object by module name.

    Args:
      name: The API module name. Defaults to `--api` when set, otherwise to the
        most preferred detected API.
      create: When True, build a fresh instance from the setup_api() arguments
        overridden by `args` (a dict of attribute overrides). When False, return
        the shared cached instance, built from the setup_api() arguments; `args`
        is not used in that case.

    Raises:
      RuntimeError (through alog.xraise) for an unknown API name.
    """
    name = name if name is not None else (_ARGS.api or _API_NAMES[0])
    mod = _APIS.get(name)
    if mod is None:
        alog.xraise(RuntimeError, f'Invalid API name: {name}')

    api = (mod.create_api(_merged_args(_ARGS, args)) if create
           else _api_cache().get(mod, name, _ARGS))

    alog.debug0(f'Using {api.name} API')

    return api
112
+
113
+
114
def select_api(start_date, end_date, data_step):
    """Return the first API whose range_supported() accepts the request.

    The `--api` choice (if any) is tried first, then the remaining APIs in
    preference order.

    Raises:
      RuntimeError (through alog.xraise) when no API qualifies.
    """
    preferred = _ARGS.api
    if preferred:
        api = create_api(name=preferred)
        if api.range_supported(start_date, end_date, data_step):
            return api

    for kind in _API_NAMES:
        if kind == preferred:
            continue
        api = create_api(name=kind)
        if api.range_supported(start_date, end_date, data_step):
            return api

    alog.xraise(RuntimeError, f'Unable to select valid API: start={start_date}\t'
                f'end={end_date}\tstep={data_step}')
131
+