zipline_polygon_bundle 0.1.8__py3-none-any.whl → 0.2.0.dev1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,27 +1,23 @@
- from .config import PolygonConfig
+ from .config import PolygonConfig, PARTITION_COLUMN_NAME, to_partition_key

- from typing import Iterator, Tuple
+ from typing import Iterator, Tuple, Union

  import pyarrow as pa
- from pyarrow import dataset as pa_ds
- from pyarrow import compute as pa_compute
- from pyarrow import compute as pc
- from pyarrow import parquet as pa_parquet
- from pyarrow import csv as pa_csv
- from pyarrow import fs as pa_fs
+ import pyarrow.dataset as pa_ds
+ import pyarrow.compute as pa_compute
+ import pyarrow.csv as pa_csv
+ import pyarrow.fs as pa_fs

  from fsspec.implementations.arrow import ArrowFSWrapper

+ import os
  import datetime
- import pandas_market_calendars
+ import shutil
+
  import numpy as np
  import pandas as pd
-
  import pandas_ta as ta

- # from concurrent.futures import ThreadPoolExecutor
- # from concurrent.futures import ProcessPoolExecutor
-

  def trades_schema(raw: bool = False) -> pa.Schema:
      # There is some problem reading the timestamps as timestamps so we have to read as integer then change the schema.
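Note on the import changes: the pyarrow submodule aliases switch from the `from pyarrow import dataset as pa_ds` spelling to `import pyarrow.dataset as pa_ds`; both bind the same module object, so this is a stylistic cleanup, while the `pc` and `pa_parquet` aliases and the `pandas_market_calendars` dependency are no longer needed in the new version. A quick check (not part of the package):

    import pyarrow.dataset as pa_ds_new
    from pyarrow import dataset as pa_ds_old

    assert pa_ds_new is pa_ds_old  # both names refer to the pyarrow.dataset module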
@@ -36,22 +32,22 @@ def trades_schema(raw: bool = False) -> pa.Schema:
      price_type = pa.float64()

      return pa.schema(
-         [
-             pa.field("ticker", pa.string(), nullable=False),
-             pa.field("conditions", pa.string(), nullable=False),
-             pa.field("correction", pa.string(), nullable=False),
-             pa.field("exchange", pa.int8(), nullable=False),
-             pa.field("id", pa.string(), nullable=False),
-             pa.field("participant_timestamp", timestamp_type, nullable=False),
-             pa.field("price", price_type, nullable=False),
-             pa.field("sequence_number", pa.int64(), nullable=False),
-             pa.field("sip_timestamp", timestamp_type, nullable=False),
-             pa.field("size", pa.int64(), nullable=False),
-             pa.field("tape", pa.int8(), nullable=False),
-             pa.field("trf_id", pa.int64(), nullable=False),
-             pa.field("trf_timestamp", timestamp_type, nullable=False),
-         ]
-     )
+         [
+             pa.field("ticker", pa.string(), nullable=False),
+             pa.field("conditions", pa.string(), nullable=False),
+             pa.field("correction", pa.string(), nullable=False),
+             pa.field("exchange", pa.int8(), nullable=False),
+             pa.field("id", pa.string(), nullable=False),
+             pa.field("participant_timestamp", timestamp_type, nullable=False),
+             pa.field("price", price_type, nullable=False),
+             pa.field("sequence_number", pa.int64(), nullable=False),
+             pa.field("sip_timestamp", timestamp_type, nullable=False),
+             pa.field("size", pa.int64(), nullable=False),
+             pa.field("tape", pa.int8(), nullable=False),
+             pa.field("trf_id", pa.int64(), nullable=False),
+             pa.field("trf_timestamp", timestamp_type, nullable=False),
+         ]
+     )


  def trades_dataset(config: PolygonConfig) -> pa_ds.Dataset:
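The comment above explains why `trades_schema(raw=True)` types the timestamp columns as int64: the Arrow CSV reader can't parse nanosecond epoch integers directly into timestamps, so a table is read raw and then cast, which is exactly what `cast_trades` below does. A minimal sketch of that pattern (the path is hypothetical, not from the package):

    import pyarrow.csv as pa_csv

    raw = pa_csv.read_csv(
        "trades/2025/01/2025-01-06.csv.gz",  # hypothetical flatfile path
        convert_options=pa_csv.ConvertOptions(column_types=trades_schema(raw=True)),
    )
    trades = raw.cast(trades_schema())  # int64 nanoseconds -> timestamp("ns", tz="UTC")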
@@ -68,20 +64,26 @@ def trades_dataset(config: PolygonConfig) -> pa_ds.Dataset:
          fsspec.glob(os.path.join(config.trades_dir, config.csv_paths_pattern))
      )

-     return pa_ds.FileSystemDataset.from_paths(paths,
-                                               format=pa_ds.CsvFileFormat(),
-                                               schema=trades_schema(raw=True),
-                                               filesystem=config.filesystem)
+     return pa_ds.FileSystemDataset.from_paths(
+         paths,
+         format=pa_ds.CsvFileFormat(),
+         schema=trades_schema(raw=True),
+         filesystem=config.filesystem,
+     )


- def cast_strings_to_list(string_array, separator=",", default="0", value_type=pa.uint8()):
+ def cast_strings_to_list(
+     string_array, separator=",", default="0", value_type=pa.uint8()
+ ):
      """Cast a PyArrow StringArray of comma-separated numbers to a ListArray of values."""

      # Create a mask to identify empty strings
      is_empty = pa_compute.equal(pa_compute.utf8_trim_whitespace(string_array), "")

      # Use replace_with_mask to replace empty strings with the default ("0")
-     filled_column = pa_compute.replace_with_mask(string_array, is_empty, pa.scalar(default))
+     filled_column = pa_compute.replace_with_mask(
+         string_array, is_empty, pa.scalar(default)
+     )

      # Split the strings by comma
      split_array = pa_compute.split_pattern(filled_column, pattern=separator)
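A quick sketch of what `cast_strings_to_list` is for (illustrative values, not package data): the `conditions` column arrives as comma-separated codes, and empty strings fall back to the `"0"` default before the split and cast, so it is expected to yield roughly:

    conditions = pa.array(["12,37", "", "4,7"])
    cast_strings_to_list(conditions)
    # -> [[12, 37], [0], [4, 7]] as list<uint8>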
@@ -94,8 +96,10 @@ def cast_strings_to_list(string_array, separator=",", default="0", value_type=pa

  def cast_trades(trades):
      trades = trades.cast(trades_schema())
-     condition_values = cast_strings_to_list(trades.column("conditions").combine_chunks())
-     return trades.append_column('condition_values', condition_values)
+     condition_values = cast_strings_to_list(
+         trades.column("conditions").combine_chunks()
+     )
+     return trades.append_column("condition_values", condition_values)


  def date_to_path(date, ext=".csv.gz"):
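For reference, `date_to_path` maps a session date onto the flatfile layout used throughout this module; for example:

    date_to_path(datetime.date(2025, 1, 6))                   # "2025/01/2025-01-06.csv.gz"
    date_to_path(datetime.date(2025, 1, 6), ext=".parquet")   # "2025/01/2025-01-06.parquet"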
@@ -103,72 +107,96 @@ def date_to_path(date, ext=".csv.gz"):
      return date.strftime("%Y/%m/%Y-%m-%d") + ext


- def convert_to_custom_aggs_file(config: PolygonConfig,
-                                 overwrite: bool,
-                                 timestamp: pd.Timestamp,
-                                 start_session: pd.Timestamp,
-                                 end_session: pd.Timestamp):
-     date = timestamp.to_pydatetime().date()
-     aggs_date_path = date_to_path(date, ext=".parquet")
-     aggs_path = f"{config.custom_aggs_dir}/{aggs_date_path}"
-     # aggs_by_ticker_path = f"{config.custom_aggs_by_ticker_dir}/{aggs_date_path}"
-     fsspec = ArrowFSWrapper(config.filesystem)
-     if fsspec.exists(aggs_path) or fsspec.exists(aggs_by_ticker_path):
-         if overwrite:
-             if fsspec.exists(aggs_path):
-                 config.filesystem.delete_file(aggs_path)
-             if fsspec.exists(aggs_by_ticker_path):
-                 config.filesystem.delete_file(aggs_by_ticker_path)
-         else:
-             if fsspec.exists(aggs_path):
-                 print(f"SKIPPING: {date=} File exists {aggs_path=}")
-             if fsspec.exists(aggs_by_ticker_path):
-                 print(f"SKIPPING: {date=} File exists {aggs_by_ticker_path=}")
-             return
-     fsspec.mkdir(fsspec._parent(aggs_path))
-     fsspec.mkdir(fsspec._parent(aggs_by_ticker_path))
-     trades_path = f"{config.trades_dir}/{date_to_path(date)}"
-     if not fsspec.exists(trades_path):
-         print(f"ERROR: Trades file missing. Skipping {date=}. {trades_path=}")
-         return
-     print(f"{trades_path=}")
-     format = pa_ds.CsvFileFormat()
-     trades_ds = pa_ds.FileSystemDataset.from_paths([trades_path], format=format, schema=trades_schema(raw=True), filesystem=config.filesystem)
-     fragments = trades_ds.get_fragments()
-     fragment = next(fragments)
-     try:
-         next(fragments)
-         print("ERROR: More than one fragment for {path=}")
-     except StopIteration:
-         pass
-     trades = fragment.to_table(schema=trades_ds.schema)
-     trades = trades.cast(trades_schema())
-     min_timestamp = pa.compute.min(trades.column('sip_timestamp')).as_py()
-     max_timestamp = pa.compute.max(trades.column('sip_timestamp')).as_py()
-     if min_timestamp < start_session:
-         print(f"ERROR: {min_timestamp=} < {start_session=}")
-     if max_timestamp >= end_session:
-         print(f"ERROR: {max_timestamp=} >= {end_session=}")
-     trades_df = trades.to_pandas()
-     trades_df["window_start"] = trades_df["sip_timestamp"].dt.floor(aggregate_timedelta)
-     aggs_df = trades_df.groupby(["ticker", "window_start"]).agg(
-         open=('price', 'first'),
-         high=('price', 'max'),
-         low=('price', 'min'),
-         close=('price', 'last'),
-         volume=('size', 'sum'),
-     )
-     aggs_df['transactions'] = trades_df.groupby(["ticker", "window_start"]).size()
-     aggs_df.reset_index(inplace=True)
-     aggs_table = pa.Table.from_pandas(aggs_df).select(['ticker', 'volume', 'open', 'close', 'high', 'low', 'window_start', 'transactions'])
-     aggs_table = aggs_table.sort_by([('ticker', 'ascending'), ('window_start', 'ascending')])
-     print(f"{aggs_by_ticker_path=}")
-     pa_parquet.write_table(table=aggs_table,
-                            where=aggs_by_ticker_path, filesystem=to_config.filesystem)
-     aggs_table = aggs_table.sort_by([('window_start', 'ascending'), ('ticker', 'ascending')])
-     print(f"{aggs_path=}")
-     pa_parquet.write_table(table=aggs_table,
-                            where=aggs_path, filesystem=to_config.filesystem)
+ # def convert_to_custom_aggs_file(
+ #     config: PolygonConfig,
+ #     overwrite: bool,
+ #     timestamp: pd.Timestamp,
+ #     start_session: pd.Timestamp,
+ #     end_session: pd.Timestamp,
+ # ):
+ #     date = timestamp.to_pydatetime().date()
+ #     aggs_date_path = date_to_path(date, ext=".parquet")
+ #     aggs_path = f"{config.custom_aggs_dir}/{aggs_date_path}"
+ #     # aggs_by_ticker_path = f"{config.custom_aggs_by_ticker_dir}/{aggs_date_path}"
+ #     fsspec = ArrowFSWrapper(config.filesystem)
+ #     if fsspec.exists(aggs_path) or fsspec.exists(aggs_by_ticker_path):
+ #         if overwrite:
+ #             if fsspec.exists(aggs_path):
+ #                 config.filesystem.delete_file(aggs_path)
+ #             if fsspec.exists(aggs_by_ticker_path):
+ #                 config.filesystem.delete_file(aggs_by_ticker_path)
+ #         else:
+ #             if fsspec.exists(aggs_path):
+ #                 print(f"SKIPPING: {date=} File exists {aggs_path=}")
+ #             if fsspec.exists(aggs_by_ticker_path):
+ #                 print(f"SKIPPING: {date=} File exists {aggs_by_ticker_path=}")
+ #             return
+ #     fsspec.mkdir(fsspec._parent(aggs_path))
+ #     fsspec.mkdir(fsspec._parent(aggs_by_ticker_path))
+ #     trades_path = f"{config.trades_dir}/{date_to_path(date)}"
+ #     if not fsspec.exists(trades_path):
+ #         print(f"ERROR: Trades file missing. Skipping {date=}. {trades_path=}")
+ #         return
+ #     print(f"{trades_path=}")
+ #     format = pa_ds.CsvFileFormat()
+ #     trades_ds = pa_ds.FileSystemDataset.from_paths(
+ #         [trades_path],
+ #         format=format,
+ #         schema=trades_schema(raw=True),
+ #         filesystem=config.filesystem,
+ #     )
+ #     fragments = trades_ds.get_fragments()
+ #     fragment = next(fragments)
+ #     try:
+ #         next(fragments)
+ #         print("ERROR: More than one fragment for {path=}")
+ #     except StopIteration:
+ #         pass
+ #     trades = fragment.to_table(schema=trades_ds.schema)
+ #     trades = trades.cast(trades_schema())
+ #     min_timestamp = pa.compute.min(trades.column("sip_timestamp")).as_py()
+ #     max_timestamp = pa.compute.max(trades.column("sip_timestamp")).as_py()
+ #     if min_timestamp < start_session:
+ #         print(f"ERROR: {min_timestamp=} < {start_session=}")
+ #     if max_timestamp >= end_session:
+ #         print(f"ERROR: {max_timestamp=} >= {end_session=}")
+ #     trades_df = trades.to_pandas()
+ #     trades_df["window_start"] = trades_df["sip_timestamp"].dt.floor(aggregate_timedelta)
+ #     aggs_df = trades_df.groupby(["ticker", "window_start"]).agg(
+ #         open=("price", "first"),
+ #         high=("price", "max"),
+ #         low=("price", "min"),
+ #         close=("price", "last"),
+ #         volume=("size", "sum"),
+ #     )
+ #     aggs_df["transactions"] = trades_df.groupby(["ticker", "window_start"]).size()
+ #     aggs_df.reset_index(inplace=True)
+ #     aggs_table = pa.Table.from_pandas(aggs_df).select(
+ #         [
+ #             "ticker",
+ #             "volume",
+ #             "open",
+ #             "close",
+ #             "high",
+ #             "low",
+ #             "window_start",
+ #             "transactions",
+ #         ]
+ #     )
+ #     aggs_table = aggs_table.sort_by(
+ #         [("ticker", "ascending"), ("window_start", "ascending")]
+ #     )
+ #     print(f"{aggs_by_ticker_path=}")
+ #     pa_parquet.write_table(
+ #         table=aggs_table, where=aggs_by_ticker_path, filesystem=to_config.filesystem
+ #     )
+ #     aggs_table = aggs_table.sort_by(
+ #         [("window_start", "ascending"), ("ticker", "ascending")]
+ #     )
+ #     print(f"{aggs_path=}")
+ #     pa_parquet.write_table(
+ #         table=aggs_table, where=aggs_path, filesystem=to_config.filesystem
+ #     )


  # def convert_to_custom_aggs(config: PolygonConfig,
@@ -232,11 +260,11 @@ def convert_to_custom_aggs_file(config: PolygonConfig,
  #     aggs_table = aggs_table.sort_by([('ticker', 'ascending'), ('window_start', 'ascending')])
  #     print(f"{aggs_by_ticker_path=}")
  #     pa_parquet.write_table(table=aggs_table,
- #         where=aggs_by_ticker_path, filesystem=to_config.filesystem)
+ #         where=aggs_by_ticker_path, filesystem=to_config.filesystem)
  #     aggs_table = aggs_table.sort_by([('window_start', 'ascending'), ('ticker', 'ascending')])
  #     print(f"{aggs_path=}")
  #     pa_parquet.write_table(table=aggs_table,
- #         where=aggs_path, filesystem=to_config.filesystem)
+ #         where=aggs_path, filesystem=to_config.filesystem)
  #     pa_ds.write_dataset(
  #         generate_batches_from_tables(tables),
  #         schema=schema,
@@ -291,25 +319,28 @@ def custom_aggs_schema(raw: bool = False) -> pa.Schema:
      timestamp_type = pa.int64() if raw else pa.timestamp("ns", tz="UTC")
      price_type = pa.float64()
      return pa.schema(
-         [
-             pa.field("ticker", pa.string(), nullable=False),
-             pa.field("volume", pa.int64(), nullable=False),
-             pa.field("open", price_type, nullable=False),
-             pa.field("close", price_type, nullable=False),
-             pa.field("high", price_type, nullable=False),
-             pa.field("low", price_type, nullable=False),
-             pa.field("window_start", timestamp_type, nullable=False),
-             pa.field("transactions", pa.int64(), nullable=False),
-             pa.field("date", pa.date32(), nullable=False),
-             pa.field("year", pa.uint16(), nullable=False),
-             pa.field("month", pa.uint8(), nullable=False),
-         ]
-     )
+         [
+             pa.field("ticker", pa.string(), nullable=False),
+             pa.field("volume", pa.int64(), nullable=False),
+             pa.field("open", price_type, nullable=False),
+             pa.field("close", price_type, nullable=False),
+             pa.field("high", price_type, nullable=False),
+             pa.field("low", price_type, nullable=False),
+             pa.field("window_start", timestamp_type, nullable=False),
+             pa.field("transactions", pa.int64(), nullable=False),
+             pa.field("date", pa.date32(), nullable=False),
+             pa.field("year", pa.uint16(), nullable=False),
+             pa.field("month", pa.uint8(), nullable=False),
+         ]
+     )


  def custom_aggs_partitioning() -> pa.Schema:
      return pa_ds.partitioning(
-         pa.schema([('year', pa.uint16()), ('month', pa.uint8()), ('date', pa.date32())]), flavor="hive"
+         pa.schema(
+             [("year", pa.uint16()), ("month", pa.uint8()), ("date", pa.date32())]
+         ),
+         flavor="hive",
      )


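With this hive-flavored year/month/date partitioning, each session's aggregates land in a directory path derived from the partition columns, roughly like this (illustrative layout; exact file names depend on the writer):

    # custom_aggs_dir/year=2025/month=1/date=2025-01-06/part-0.parquet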
@@ -317,25 +348,31 @@ def get_custom_aggs_dates(config: PolygonConfig) -> set[datetime.date]:
      file_info = config.filesystem.get_file_info(config.custom_aggs_dir)
      if file_info.type == pa_fs.FileType.NotFound:
          return set()
-     aggs_ds = pa_ds.dataset(config.custom_aggs_dir,
-                             format="parquet",
-                             schema=custom_aggs_schema(),
-                             partitioning=custom_aggs_partitioning())
-     return set([pa_ds.get_partition_keys(fragment.partition_expression).get("date") for fragment in aggs_ds.get_fragments()])
+     aggs_ds = pa_ds.dataset(
+         config.custom_aggs_dir,
+         format="parquet",
+         schema=custom_aggs_schema(),
+         partitioning=custom_aggs_partitioning(),
+     )
+     return set(
+         [
+             pa_ds.get_partition_keys(fragment.partition_expression).get("date")
+             for fragment in aggs_ds.get_fragments()
+         ]
+     )


  def generate_csv_trades_tables(
      config: PolygonConfig, overwrite: bool = False
- ) -> Tuple[datetime.date, Iterator[pa.Table]]:
+ ) -> Iterator[Tuple[datetime.date, pa.Table]]:
      """Generator for trades tables from flatfile CSVs."""
      custom_aggs_dates = set()
      if not overwrite:
          custom_aggs_dates = get_custom_aggs_dates(config)
-     # Use pandas_market_calendars so we can get extended hours.
-     # NYSE and NASDAQ have extended hours but XNYS does not.
-     calendar = pandas_market_calendars.get_calendar(config.calendar_name)
-     schedule = calendar.schedule(start_date=config.start_timestamp, end_date=config.end_timestamp, start="pre", end="post")
-     for timestamp, session in schedule.iterrows():
+     schedule = config.calendar.trading_index(
+         start=config.start_timestamp, end=config.end_timestamp, period="1D"
+     )
+     for timestamp in schedule:
          date = timestamp.to_pydatetime().date()
          if date in custom_aggs_dates:
              continue
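The schedule source changes here: instead of a `pandas_market_calendars` schedule DataFrame iterated with `iterrows()`, the new code asks the config's calendar for a daily trading index, which is a plain `pd.DatetimeIndex` of session timestamps (the later `assert type(schedule) is pd.DatetimeIndex` relies on this). Assuming an exchange_calendars-style calendar object on the config, the call looks like:

    schedule = config.calendar.trading_index(start="2025-01-02", end="2025-01-10", period="1D")
    for timestamp in schedule:  # one timestamp per trading session
        session_date = timestamp.to_pydatetime().date()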
@@ -359,77 +396,98 @@ def generate_csv_trades_tables(
          del trades


- def trades_to_custom_aggs(config: PolygonConfig, date: datetime.date, table: pa.Table, include_trf: bool = False) -> pa.Table:
+ def trades_to_custom_aggs(
+     config: PolygonConfig,
+     date: datetime.date,
+     table: pa.Table,
+     include_trf: bool = False,
+ ) -> pa.Table:
      print(f"{datetime.datetime.now()=} {date=} {pa.default_memory_pool()=}")
      # print(f"{resource.getrusage(resource.RUSAGE_SELF).ru_maxrss=}")
      table = table.filter(pa_compute.greater(table["size"], 0))
      table = table.filter(pa_compute.equal(table["correction"], "0"))
      if not include_trf:
          table = table.filter(pa_compute.not_equal(table["exchange"], 4))
-     table = table.append_column("price_total", pa_compute.multiply(table["price"], table["size"]))
-     table = table.append_column("window_start",
-                                 pa_compute.floor_temporal(table["sip_timestamp"],
-                                                           multiple=config.agg_timedelta.seconds, unit="second"))
-     # TODO: Calculate VWAP.
-     table = table.group_by(["ticker", "window_start"], use_threads=False).aggregate([
-         ('price', 'first'),
-         ('price', 'max'),
-         ('price', 'min'),
-         ('price', 'last'),
-         ('price_total', 'sum'),
-         ('size', 'sum'),
-         ([], "count_all")
-     ])
-     table = table.rename_columns({
-         'price_first': 'open',
-         'price_max': 'high',
-         'price_min': 'low',
-         'price_last': 'close',
-         'size_sum': 'volume',
-         'price_total_sum': 'total',
-         'count_all': 'transactions'})
-     table = table.append_column("vwap", pa_compute.divide(table['total'], table['volume']))
+     table = table.append_column(
+         "price_total", pa_compute.multiply(table["price"], table["size"])
+     )
+     table = table.append_column(
+         "window_start",
+         pa_compute.floor_temporal(
+             table["sip_timestamp"], multiple=config.agg_timedelta.seconds, unit="second"
+         ),
+     )
+     table = table.group_by(["ticker", "window_start"], use_threads=False).aggregate(
+         [
+             ("price", "first"),
+             ("price", "max"),
+             ("price", "min"),
+             ("price", "last"),
+             ("price_total", "sum"),
+             ("size", "sum"),
+             ([], "count_all"),
+         ]
+     )
+     table = table.rename_columns(
+         {
+             "price_first": "open",
+             "price_max": "high",
+             "price_min": "low",
+             "price_last": "close",
+             "size_sum": "volume",
+             "price_total_sum": "total",
+             "count_all": "transactions",
+         }
+     )
+     table = table.append_column(
+         "vwap", pa_compute.divide(table["total"], table["volume"])
+     )
      # table.append_column('date', pa.array([date] * len(table), type=pa.date32()))
      # table.append_column('year', pa.array([date.year] * len(table), type=pa.uint16()))
      # table.append_column('month', pa.array([date.month] * len(table), type=pa.uint8()))
-     table = table.append_column('date', pa.array(np.full(len(table), date)))
-     table = table.append_column('year', pa.array(np.full(len(table), date.year), type=pa.uint16()))
-     table = table.append_column('month', pa.array(np.full(len(table), date.month), type=pa.uint8()))
-     table = table.sort_by([('window_start', 'ascending'), ('ticker', 'ascending')])
+     table = table.append_column("date", pa.array(np.full(len(table), date)))
+     table = table.append_column(
+         "year", pa.array(np.full(len(table), date.year), type=pa.uint16())
+     )
+     table = table.append_column(
+         "month", pa.array(np.full(len(table), date.month), type=pa.uint8())
+     )
+     table = table.sort_by([("window_start", "ascending"), ("ticker", "ascending")])
      return table


- def generate_custom_agg_batches_from_tables(config: PolygonConfig) -> pa.RecordBatch:
-     for date, trades_table in generate_csv_trades_tables(config):
-         for batch in trades_to_custom_aggs(config, date, trades_table).to_batches():
-             yield batch
-         del trades_table
+ # def generate_custom_agg_batches_from_tables(config: PolygonConfig) -> pa.RecordBatch:
+ #     for date, trades_table in generate_csv_trades_tables(config):
+ #         for batch in trades_to_custom_aggs(config, date, trades_table).to_batches():
+ #             yield batch
+ #         del trades_table


- def generate_custom_agg_tables(config: PolygonConfig) -> pa.Table:
-     for date, trades_table in generate_csv_trades_tables(config):
-         yield trades_to_custom_aggs(config, date, trades_table)
+ # def generate_custom_agg_tables(config: PolygonConfig) -> pa.Table:
+ #     for date, trades_table in generate_csv_trades_tables(config):
+ #         yield trades_to_custom_aggs(config, date, trades_table)


- def configure_write_custom_aggs_to_dataset(config: PolygonConfig):
-     def write_custom_aggs_to_dataset(args: Tuple[datetime.date, pa.Table]):
-         date, table = args
-         pa_ds.write_dataset(
-             trades_to_custom_aggs(config, date, table),
-             filesystem=config.filesystem,
-             base_dir=config.custom_aggs_dir,
-             partitioning=custom_aggs_partitioning(),
-             format="parquet",
-             existing_data_behavior="overwrite_or_ignore",
-         )
-     return write_custom_aggs_to_dataset
+ # def configure_write_custom_aggs_to_dataset(config: PolygonConfig):
+ #     def write_custom_aggs_to_dataset(args: Tuple[datetime.date, pa.Table]):
+ #         date, table = args
+ #         pa_ds.write_dataset(
+ #             trades_to_custom_aggs(config, date, table),
+ #             filesystem=config.filesystem,
+ #             base_dir=config.custom_aggs_dir,
+ #             partitioning=custom_aggs_partitioning(),
+ #             format="parquet",
+ #             existing_data_behavior="overwrite_or_ignore",
+ #         )
+
+ #     return write_custom_aggs_to_dataset


  def file_visitor(written_file):
      print(f"{written_file.path=}")


- def convert_all_to_custom_aggs(
+ def convert_trades_to_custom_aggs(
      config: PolygonConfig, overwrite: bool = False
  ) -> str:
      if overwrite:
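The heart of `trades_to_custom_aggs` is the hash aggregation above; a tiny self-contained sketch of the same pattern on made-up data (not package code; `use_threads=False` keeps the aggregation order-dependent so "first"/"last" reflect the original row order):

    import pyarrow as pa

    toy = pa.table(
        {
            "ticker": ["XYZ", "XYZ", "XYZ"],
            "window_start": [0, 0, 0],
            "price": [10.0, 10.4, 10.2],
            "size": [100, 50, 25],
        }
    )
    aggs = toy.group_by(["ticker", "window_start"], use_threads=False).aggregate(
        [("price", "first"), ("price", "max"), ("price", "min"),
         ("price", "last"), ("size", "sum"), ([], "count_all")]
    )
    # -> columns price_first, price_max, price_min, price_last, size_sum, count_all,
    #    which the function then renames to open/high/low/close/volume/transactions.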
@@ -438,7 +496,7 @@ def convert_all_to_custom_aggs(
      # MAX_FILES_OPEN = 8
      # MIN_ROWS_PER_GROUP = 100_000

-     print(f"{config.custom_aggs_dir=}")
+     print(f"{config.aggs_dir=}")

      # pa.set_memory_pool()

@@ -460,7 +518,7 @@ def convert_all_to_custom_aggs(
          aggs_table,
          # schema=custom_aggs_schema(),
          filesystem=config.filesystem,
-         base_dir=config.custom_aggs_dir,
+         base_dir=config.aggs_dir,
          partitioning=custom_aggs_partitioning(),
          format="parquet",
          existing_data_behavior="overwrite_or_ignore",
@@ -477,8 +535,8 @@ def convert_all_to_custom_aggs(
      #     generate_csv_trades_tables(config),
      # )

-     print(f"Generated aggregates to {config.custom_aggs_dir=}")
-     return config.custom_aggs_dir
+     print(f"Generated aggregates to {config.aggs_dir=}")
+     return config.aggs_dir


  # https://github.com/twopirllc/pandas-ta/issues/731#issuecomment-1766786952
@@ -500,6 +558,146 @@ def convert_all_to_custom_aggs(
  #     mfi = 100 - 100 / (1 + mf_avg_gain / (mf_avg_loss + epsilon))
  #     return mfi

+
+ # def generate_custom_agg_tables(
+ #     config: PolygonConfig,
+ # ) -> Tuple[pa.Schema, Iterator[pa.Table]]:
+ #     """zipline does bundle ingestion one ticker at a time."""
+
+ #     # Polygon Aggregate flatfile timestamps are in nanoseconds (like trades), not milliseconds as the docs say.
+ #     # I make the timestamp timezone-aware because that's how Unix timestamps work and it may help avoid mistakes.
+ #     timestamp_type = pa.timestamp("ns", tz="UTC")
+
+ #     # But we can't use the timestamp type in the schema here because it's not supported by the CSV reader.
+ #     # So we'll use int64 and cast it after reading the CSV file.
+ #     # https://github.com/apache/arrow/issues/44030
+
+ #     # strptime(3) (used by CSV reader for timestamps in ConvertOptions.timestamp_parsers) supports Unix timestamps (%s) and milliseconds (%f) but not nanoseconds.
+ #     # https://www.geeksforgeeks.org/how-to-use-strptime-with-milliseconds-in-python/
+ #     # Actually that's the wrong strptime (it's Python's). C++ strptime(3) doesn't even support %f.
+ #     # https://github.com/apache/arrow/issues/39839#issuecomment-1915981816
+ #     # Also I don't think you can use those in a format string without a separator.
+
+ #     # Polygon price scale is 4 decimal places (i.e. hundredths of a penny), but we'll use 10 because we have precision to spare.
+ #     # price_type = pa.decimal128(precision=38, scale=10)
+ #     # 64bit float a little overkill but avoids any plausible truncation error.
+ #     price_type = pa.float64()
+
+ #     custom_aggs_schema = pa.schema(
+ #         [
+ #             pa.field("ticker", pa.string(), nullable=False),
+ #             pa.field("volume", pa.int64(), nullable=False),
+ #             pa.field("open", price_type, nullable=False),
+ #             pa.field("close", price_type, nullable=False),
+ #             pa.field("high", price_type, nullable=False),
+ #             pa.field("low", price_type, nullable=False),
+ #             pa.field("window_start", timestamp_type, nullable=False),
+ #             pa.field("transactions", pa.int64(), nullable=False),
+ #             pa.field(PARTITION_COLUMN_NAME, pa.string(), nullable=False),
+ #         ]
+ #     )
+
+ #     # TODO: Use generator like os.walk for paths.
+ #     return (
+ #         custom_aggs_schema,
+ #         generate_tables_from_custom_aggs(
+ #             paths=config.csv_paths(),
+ #             schema=custom_aggs_schema,
+ #             start_timestamp=config.start_timestamp,
+ #             limit_timestamp=config.end_timestamp + pd.to_timedelta(1, unit="day"),
+ #         ),
+ #     )
+
+ # def get_custom_aggs_dates(config: PolygonConfig) -> set[datetime.date]:
+ #     file_info = config.filesystem.get_file_info(config.custom_aggs_dir)
+ #     if file_info.type == pa_fs.FileType.NotFound:
+ #         return set()
+ #     aggs_ds = pa_ds.dataset(
+ #         config.custom_aggs_dir,
+ #         format="parquet",
+ #         schema=custom_aggs_schema(),
+ #         partitioning=custom_aggs_partitioning(),
+ #     )
+ #     return set(
+ #         [
+ #             pa_ds.get_partition_keys(fragment.partition_expression).get("date")
+ #             for fragment in aggs_ds.get_fragments()
+ #         ]
+ #     )
+
+
+ def generate_batches_from_custom_aggs_ds(
+     aggs_ds: pa_ds.Dataset, schedule: pd.DatetimeIndex
+ ) -> Iterator[pa.RecordBatch]:
+     for timestamp in schedule:
+         date = timestamp.to_pydatetime().date()
+         date_filter_expr = (
+             (pa_compute.field("year") == date.year)
+             & (pa_compute.field("month") == date.month)
+             & (pa_compute.field("date") == date)
+         )
+         for batch in aggs_ds.to_batches(filter=date_filter_expr):
+             # TODO: Check that these rows are within range for this file's date (not just the whole session).
+             # And if we're doing that (figuring date for each file), we can just skip reading the file.
+             # Might able to do a single comparison using compute.days_between.
+             # https://arrow.apache.org/docs/python/generated/pyarrow.compute.days_between.html
+             batch = batch.append_column(
+                 PARTITION_COLUMN_NAME,
+                 pa.array(
+                     [
+                         to_partition_key(ticker)
+                         for ticker in batch.column("ticker").to_pylist()
+                     ]
+                 ),
+             )
+             yield batch
+
+
+ def scatter_custom_aggs_to_by_ticker(
+     config: PolygonConfig,
+     overwrite: bool = False,
+ ) -> str:
+     file_info = config.filesystem.get_file_info(config.custom_aggs_dir)
+     if file_info.type == pa_fs.FileType.NotFound:
+         raise FileNotFoundError(f"{config.custom_aggs_dir=} not found.")
+
+     by_ticker_aggs_arrow_dir = config.by_ticker_aggs_arrow_dir
+     if os.path.exists(by_ticker_aggs_arrow_dir):
+         if overwrite:
+             print(f"Removing {by_ticker_aggs_arrow_dir=}")
+             shutil.rmtree(by_ticker_aggs_arrow_dir)
+         else:
+             print(f"Found existing {by_ticker_aggs_arrow_dir=}")
+             return by_ticker_aggs_arrow_dir
+
+     aggs_ds = pa_ds.dataset(
+         config.custom_aggs_dir,
+         format="parquet",
+         schema=custom_aggs_schema(),
+         partitioning=custom_aggs_partitioning(),
+     )
+     schedule = config.calendar.trading_index(
+         start=config.start_timestamp, end=config.end_timestamp, period="1D"
+     )
+     assert type(schedule) is pd.DatetimeIndex
+     partitioning = pa_ds.partitioning(
+         pa.schema([(PARTITION_COLUMN_NAME, pa.string())]), flavor="hive"
+     )
+     schema = aggs_ds.schema
+     schema = schema.append(pa.field(PARTITION_COLUMN_NAME, pa.string(), nullable=False))
+
+     pa_ds.write_dataset(
+         generate_batches_from_custom_aggs_ds(aggs_ds, schedule),
+         schema=schema,
+         base_dir=by_ticker_aggs_arrow_dir,
+         partitioning=partitioning,
+         format="parquet",
+         existing_data_behavior="overwrite_or_ignore",
+     )
+     print(f"Scattered custom aggregates by ticker to {by_ticker_aggs_arrow_dir=}")
+     return by_ticker_aggs_arrow_dir
+
+
  def calculate_mfi(typical_price: pd.Series, money_flow: pd.Series, period: int):
      mf_sign = np.where(typical_price > np.roll(typical_price, shift=1), 1, -1)
      signed_mf = money_flow * mf_sign
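The new `scatter_custom_aggs_to_by_ticker` re-reads the day-partitioned aggregates and rewrites them hive-partitioned by a ticker-derived key. `to_partition_key` is imported from `.config` and its definition is not part of this diff; the mechanism is simply appending a string column to each record batch and letting `write_dataset` split on it, along these lines (toy data, and the bucketing and column name shown are placeholders, not the package's actual key function):

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"ticker": ["AAPL", "MSFT"], "close": [190.0, 420.0]})
    keys = pa.array([t[0] for t in batch.column("ticker").to_pylist()])  # placeholder for to_partition_key
    batch = batch.append_column("part", keys)
    # write_dataset then produces directories like part=A/... and part=M/... under by_ticker_aggs_arrow_dir.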
@@ -508,8 +706,14 @@ def calculate_mfi(typical_price: pd.Series, money_flow: pd.Series, period: int):
      positive_mf = np.maximum(signed_mf, 0)
      negative_mf = np.maximum(-signed_mf, 0)

-     mf_avg_gain = np.convolve(positive_mf, np.ones(period), mode='full')[:len(positive_mf)] / period
-     mf_avg_loss = np.convolve(negative_mf, np.ones(period), mode='full')[:len(negative_mf)] / period
+     mf_avg_gain = (
+         np.convolve(positive_mf, np.ones(period), mode="full")[: len(positive_mf)]
+         / period
+     )
+     mf_avg_loss = (
+         np.convolve(negative_mf, np.ones(period), mode="full")[: len(negative_mf)]
+         / period
+     )

      epsilon = 1e-10  # Small epsilon value to avoid division by zero
      mfi = 100 - (100 / (1 + mf_avg_gain / (mf_avg_loss + epsilon)))
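The convolution expression (unchanged here, only reformatted) is a trailing running sum over `period` samples divided by `period`, i.e. a causal moving average whose first few values are only partially filled. A tiny worked example (illustrative only):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    np.convolve(x, np.ones(2), mode="full")[: len(x)] / 2
    # -> array([0.5, 1.5, 2.5, 3.5])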
@@ -523,7 +727,16 @@ def calculate_mfi(typical_price: pd.Series, money_flow: pd.Series, period: int):
  # Results affected by values outside range
  # https://github.com/twopirllc/pandas-ta/issues/535

- def calculate_stoch(high: pd.Series, low: pd.Series, close: pd.Series, k: int = 14, d: int = 3, smooth_k: int = 3, mamode:str = "sma"):
+
+ def calculate_stoch(
+     high: pd.Series,
+     low: pd.Series,
+     close: pd.Series,
+     k: int = 14,
+     d: int = 3,
+     smooth_k: int = 3,
+     mamode: str = "sma",
+ ):
      """Indicator: Stochastic Oscillator (STOCH)"""
      lowest_low = low.rolling(k).min()
      highest_high = high.rolling(k).max()
@@ -531,8 +744,14 @@ def calculate_stoch(high: pd.Series, low: pd.Series, close: pd.Series, k: int =
      stoch = 100 * (close - lowest_low)
      stoch /= ta.utils.non_zero_range(highest_high, lowest_low)

-     stoch_k = ta.overlap.ma(mamode, stoch.loc[stoch.first_valid_index():,], length=smooth_k)
-     stoch_d = ta.overlap.ma(mamode, stoch_k.loc[stoch_k.first_valid_index():,], length=d) if stoch_k is not None else None
+     stoch_k = ta.overlap.ma(
+         mamode, stoch.loc[stoch.first_valid_index() :,], length=smooth_k
+     )
+     stoch_d = (
+         ta.overlap.ma(mamode, stoch_k.loc[stoch_k.first_valid_index() :,], length=d)
+         if stoch_k is not None
+         else None
+     )
      # Histogram
      stoch_h = stoch_k - stoch_d if stoch_d is not None else None

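In formula terms, the raw oscillator computed above is %K_raw = 100 * (close - min(low, k)) / (max(high, k) - min(low, k)); `stoch_k` is a `smooth_k`-period moving average of that, `stoch_d` is a `d`-period moving average of `stoch_k`, and the histogram is their difference.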
@@ -540,12 +759,12 @@ def calculate_stoch(high: pd.Series, low: pd.Series, close: pd.Series, k: int =


  def compute_per_ticker_signals(df: pd.DataFrame, period: int = 14) -> pd.DataFrame:
-     df = df.set_index('window_start').sort_index()
-     session_index = pd.date_range(start=df.index[0],
-                                   end=df.index[-1],
-                                   freq=pd.Timedelta(seconds=60))
+     df = df.set_index("window_start").sort_index()
+     session_index = pd.date_range(
+         start=df.index[0], end=df.index[-1], freq=pd.Timedelta(seconds=60)
+     )
      df = df.reindex(session_index)
-     df.index.rename('window_start', inplace=True)
+     df.index.rename("window_start", inplace=True)

      # df["minute_of_day"] = (df.index.hour * 60) + df.index.minute
      # df["day_of_week"] = df.index.day_of_week
@@ -580,12 +799,12 @@ def compute_per_ticker_signals(df: pd.DataFrame, period: int = 14) -> pd.DataFra
      df["ret1bar"] = close.div(price_open).sub(1)

      for t in range(2, period):
-         df[f'ret{t}bar'] = close.div(price_open.shift(t-1)).sub(1)
+         df[f"ret{t}bar"] = close.div(price_open.shift(t - 1)).sub(1)

      # Average True Range (ATR)
-     true_range = pd.concat([high.sub(low),
-                             high.sub(next_close).abs(),
-                             low.sub(next_close).abs()], axis=1).max(1)
+     true_range = pd.concat(
+         [high.sub(low), high.sub(next_close).abs(), low.sub(next_close).abs()], axis=1
+     ).max(1)
      # Normalized ATR (NATR) or Average of Normalized TR.
      # Choice of NATR operations ordering discussion: https://www.macroption.com/normalized-atr/
      # He doesn't talk about VWAP but I think that is a better normalizing price for a bar.
@@ -610,16 +829,16 @@ def compute_per_ticker_signals(df: pd.DataFrame, period: int = 14) -> pd.DataFra
      # df['CCI'] = (tp - df['SMA']) / (0.015 * df['MAD'])
      # df['cci_ta'] = ta.cci(high=high, low=low, close=close, length=period)

-     df['taCCI'] = ta.cci(high=high, low=low, close=close, length=period)
+     df["taCCI"] = ta.cci(high=high, low=low, close=close, length=period)

      # https://gist.github.com/quantra-go-algo/1b37bfb74d69148f0dfbdb5a2c7bdb25
      # https://medium.com/@huzaifazahoor654/how-to-calculate-cci-in-python-a-step-by-step-guide-9a3f61698be6
      sma = pd.Series(ta.sma(vwap, length=period))
      mad = pd.Series(ta.mad(vwap, length=period))
-     df['CCI'] = (vwap - sma) / (0.015 * mad)
+     df["CCI"] = (vwap - sma) / (0.015 * mad)

      # df['MFI'] = calculate_mfi(high=high, low=low, close=close, volume=volume, period=period)
-     df['MFI'] = calculate_mfi(typical_price=vwap, money_flow=total, period=period)
+     df["MFI"] = calculate_mfi(typical_price=vwap, money_flow=total, period=period)

      # We use Stochastic (rather than MACD because we need a ticker independent indicator.
      # IOW a percentage price oscillator (PPO) rather than absolute price oscillator (APO).
@@ -633,49 +852,59 @@ def compute_per_ticker_signals(df: pd.DataFrame, period: int = 14) -> pd.DataFra
      return df


- def iterate_all_aggs_tables(config: PolygonConfig, valid_tickers: pa.Array, start_session: str = "pre", end_session: str = "market_open"):
-     calendar = pandas_market_calendars.get_calendar(config.calendar_name)
-     schedule = calendar.schedule(start_date=config.start_date,
-                                  end_date=config.end_date,
-                                  start="pre",
-                                  end="post")
-     for date, sessions in schedule.iterrows():
-         # print(f"{date=} {sessions=}")
-         start_dt = sessions[start_session]
-         end_dt = sessions[end_session]
-         # print(f"{date=} {start_dt=} {end_dt=}")
-         aggs_ds = pa_ds.dataset(config.custom_aggs_dir,
-                                 format="parquet",
-                                 schema=custom_aggs_schema(),
-                                 partitioning=custom_aggs_partitioning())
-         date_filter_expr = ((pc.field('year') == date.year)
-                             & (pc.field('month') == date.month)
-                             & (pc.field('date') == date.to_pydatetime().date()))
+ def iterate_all_aggs_tables(
+     config: PolygonConfig,
+     valid_tickers: pa.Array,
+ ):
+     schedule = config.calendar.trading_index(
+         start=config.start_timestamp, end=config.end_timestamp, period="1D"
+     )
+     for timestamp in schedule:
+         date = timestamp.to_pydatetime().date()
+         aggs_ds = pa_ds.dataset(
+             config.custom_aggs_dir,
+             format="parquet",
+             schema=custom_aggs_schema(),
+             partitioning=custom_aggs_partitioning(),
+         )
+         date_filter_expr = (
+             (pa_compute.field("year") == date.year)
+             & (pa_compute.field("month") == date.month)
+             & (pa_compute.field("date") == date)
+         )
          # print(f"{date_filter_expr=}")
          for fragment in aggs_ds.get_fragments(filter=date_filter_expr):
-             session_filter = ((pc.field('window_start') >= start_dt)
-                               & (pc.field('window_start') < end_dt)
-                               & pc.is_in(pc.field('ticker'), valid_tickers)
-                               )
+             session_filter = (
+                 (pa_compute.field("window_start") >= start_dt)
+                 & (pa_compute.field("window_start") < end_dt)
+                 & pa_compute.is_in(pa_compute.field("ticker"), valid_tickers)
+             )
              # Sorting table doesn't seem to avoid needing to sort the df. Maybe use_threads=False on to_pandas would help?
              # table = fragment.to_table(filter=session_filter).sort_by([('ticker', 'ascending'), ('window_start', 'descending')])
              table = fragment.to_table(filter=session_filter)
              if table.num_rows > 0:
-                 metadata = dict(table.schema.metadata) if table.schema.metadata else dict()
-                 metadata["date"] = date.date().isoformat()
+                 metadata = (
+                     dict(table.schema.metadata) if table.schema.metadata else dict()
+                 )
+                 metadata["date"] = date.isoformat()
                  table = table.replace_schema_metadata(metadata)
                  yield table


- def iterate_all_aggs_with_signals(config: PolygonConfig):
-     for table in iterate_all_aggs_tables(config):
-         df = table.to_pandas()
-         df = df.groupby("ticker").apply(compute_per_ticker_signals, include_groups=False)
-         yield pa.Table.from_pandas(df)
+ # def iterate_all_aggs_with_signals(config: PolygonConfig):
+ #     for table in iterate_all_aggs_tables(config):
+ #         df = table.to_pandas()
+ #         df = df.groupby("ticker").apply(
+ #             compute_per_ticker_signals, include_groups=False
+ #         )
+ #         yield pa.Table.from_pandas(df)


  def compute_signals_for_all_custom_aggs(
-     from_config: PolygonConfig, to_config: PolygonConfig, valid_tickers: pa.Array, overwrite: bool = False
+     from_config: PolygonConfig,
+     to_config: PolygonConfig,
+     valid_tickers: pa.Array,
+     overwrite: bool = False,
  ) -> str:
      if overwrite:
          print("WARNING: overwrite not implemented/ignored.")
@@ -684,17 +913,25 @@ def compute_signals_for_all_custom_aggs(

      for aggs_table in iterate_all_aggs_tables(from_config, valid_tickers):
          metadata = aggs_table.schema.metadata
-         date = datetime.date.fromisoformat(metadata[b'date'].decode('utf-8'))
+         date = datetime.date.fromisoformat(metadata[b"date"].decode("utf-8"))
          print(f"{date=}")
          df = aggs_table.to_pandas()
-         df = df.groupby("ticker").apply(compute_per_ticker_signals, include_groups=False)
+         df = df.groupby("ticker").apply(
+             compute_per_ticker_signals, include_groups=False
+         )
          table = pa.Table.from_pandas(df)
          if table.num_rows > 0:
              table = table.replace_schema_metadata(metadata)
-             table = table.append_column('date', pa.array(np.full(len(table), date)))
-             table = table.append_column('year', pa.array(np.full(len(table), date.year), type=pa.uint16()))
-             table = table.append_column('month', pa.array(np.full(len(table), date.month), type=pa.uint8()))
-             table = table.sort_by([('ticker', 'ascending'), ('window_start', 'ascending')])
+             table = table.append_column("date", pa.array(np.full(len(table), date)))
+             table = table.append_column(
+                 "year", pa.array(np.full(len(table), date.year), type=pa.uint16())
+             )
+             table = table.append_column(
+                 "month", pa.array(np.full(len(table), date.month), type=pa.uint8())
+             )
+             table = table.sort_by(
+                 [("ticker", "ascending"), ("window_start", "ascending")]
+             )
          pa_ds.write_dataset(
              table,
              filesystem=to_config.filesystem,