architect-py 5.0.0b1__py3-none-any.whl → 5.0.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. architect_py/__init__.py +10 -3
  2. architect_py/async_client.py +291 -174
  3. architect_py/client_interface.py +19 -18
  4. architect_py/common_types/order_dir.py +12 -6
  5. architect_py/graphql_client/__init__.py +2 -0
  6. architect_py/graphql_client/enums.py +5 -0
  7. architect_py/grpc/__init__.py +25 -7
  8. architect_py/grpc/client.py +13 -5
  9. architect_py/grpc/models/Accounts/AccountsRequest.py +4 -1
  10. architect_py/grpc/models/Algo/AlgoOrder.py +114 -0
  11. architect_py/grpc/models/Algo/{ModifyAlgoOrderRequestForTwapAlgo.py → AlgoOrderRequest.py} +11 -10
  12. architect_py/grpc/models/Algo/AlgoOrdersRequest.py +72 -0
  13. architect_py/grpc/models/Algo/AlgoOrdersResponse.py +27 -0
  14. architect_py/grpc/models/Algo/CreateAlgoOrderRequest.py +56 -0
  15. architect_py/grpc/models/Algo/PauseAlgoRequest.py +42 -0
  16. architect_py/grpc/models/Algo/PauseAlgoResponse.py +20 -0
  17. architect_py/grpc/models/Algo/StartAlgoRequest.py +42 -0
  18. architect_py/grpc/models/Algo/StartAlgoResponse.py +20 -0
  19. architect_py/grpc/models/Algo/StopAlgoRequest.py +42 -0
  20. architect_py/grpc/models/Algo/StopAlgoResponse.py +20 -0
  21. architect_py/grpc/models/Boss/DepositsRequest.py +40 -0
  22. architect_py/grpc/models/Boss/DepositsResponse.py +27 -0
  23. architect_py/grpc/models/Boss/RqdAccountStatisticsRequest.py +42 -0
  24. architect_py/grpc/models/Boss/RqdAccountStatisticsResponse.py +25 -0
  25. architect_py/grpc/models/Boss/StatementUrlRequest.py +40 -0
  26. architect_py/grpc/models/Boss/StatementUrlResponse.py +23 -0
  27. architect_py/grpc/models/Boss/StatementsRequest.py +40 -0
  28. architect_py/grpc/models/Boss/StatementsResponse.py +27 -0
  29. architect_py/grpc/models/Boss/WithdrawalsRequest.py +40 -0
  30. architect_py/grpc/models/Boss/WithdrawalsResponse.py +27 -0
  31. architect_py/grpc/models/Boss/__init__.py +2 -0
  32. architect_py/grpc/models/Folio/HistoricalFillsRequest.py +4 -1
  33. architect_py/grpc/models/Marketdata/L1BookSnapshot.py +16 -2
  34. architect_py/grpc/models/Oms/Cancel.py +67 -19
  35. architect_py/grpc/models/Oms/Order.py +4 -11
  36. architect_py/grpc/models/Oms/PlaceOrderRequest.py +13 -20
  37. architect_py/grpc/models/OptionsMarketdata/OptionsChain.py +30 -0
  38. architect_py/grpc/models/OptionsMarketdata/OptionsChainGreeks.py +30 -0
  39. architect_py/grpc/models/OptionsMarketdata/OptionsChainGreeksRequest.py +47 -0
  40. architect_py/grpc/models/OptionsMarketdata/OptionsChainRequest.py +45 -0
  41. architect_py/grpc/models/OptionsMarketdata/OptionsExpirations.py +29 -0
  42. architect_py/grpc/models/OptionsMarketdata/OptionsExpirationsRequest.py +42 -0
  43. architect_py/grpc/models/OptionsMarketdata/__init__.py +2 -0
  44. architect_py/grpc/models/Symbology/ExecutionInfoRequest.py +47 -0
  45. architect_py/grpc/models/Symbology/ExecutionInfoResponse.py +27 -0
  46. architect_py/grpc/models/definitions.py +457 -790
  47. architect_py/grpc/resolve_endpoint.py +4 -1
  48. architect_py/internal_utils/__init__.py +0 -0
  49. architect_py/internal_utils/no_pandas.py +3 -0
  50. architect_py/tests/conftest.py +11 -6
  51. architect_py/tests/test_marketdata.py +19 -19
  52. architect_py/tests/test_orderflow.py +31 -28
  53. {architect_py-5.0.0b1.dist-info → architect_py-5.0.0b3.dist-info}/METADATA +2 -3
  54. {architect_py-5.0.0b1.dist-info → architect_py-5.0.0b3.dist-info}/RECORD +72 -42
  55. {architect_py-5.0.0b1.dist-info → architect_py-5.0.0b3.dist-info}/WHEEL +1 -1
  56. examples/book_subscription.py +2 -3
  57. examples/candles.py +3 -3
  58. examples/common.py +29 -20
  59. examples/external_cpty.py +4 -4
  60. examples/funding_rate_mean_reversion_algo.py +14 -20
  61. examples/order_sending.py +32 -33
  62. examples/stream_l1_marketdata.py +2 -2
  63. examples/stream_l2_marketdata.py +1 -3
  64. examples/trades.py +2 -2
  65. examples/tutorial_async.py +9 -7
  66. examples/tutorial_sync.py +5 -5
  67. scripts/generate_functions_md.py +3 -1
  68. scripts/generate_sync_interface.py +30 -11
  69. scripts/postprocess_grpc.py +21 -11
  70. scripts/preprocess_grpc_schema.py +174 -113
  71. architect_py/grpc/models/Algo/AlgoOrderForTwapAlgo.py +0 -61
  72. architect_py/grpc/models/Algo/CreateAlgoOrderRequestForTwapAlgo.py +0 -59
  73. {architect_py-5.0.0b1.dist-info → architect_py-5.0.0b3.dist-info}/licenses/LICENSE +0 -0
  74. {architect_py-5.0.0b1.dist-info → architect_py-5.0.0b3.dist-info}/top_level.txt +0 -0
@@ -7,20 +7,19 @@ from decimal import Decimal
7
7
  from typing import AsyncIterator, Optional
8
8
 
9
9
  from architect_py.async_client import AsyncClient
10
+ from architect_py.common_types.order_dir import OrderDir
11
+ from architect_py.common_types.tradable_product import TradableProduct
10
12
  from architect_py.graphql_client.exceptions import GraphQLClientHttpError
11
- from architect_py.grpc_client.definitions import TimeInForceEnum
12
- from architect_py.grpc_client.Marketdata.TickerRequest import TickerRequest
13
- from architect_py.grpc_client.Oms.PlaceOrderRequest import PlaceOrderRequestType
14
- from architect_py.grpc_client.Orderflow.Orderflow import (
13
+ from architect_py.grpc.models.definitions import OrderType, TimeInForceEnum
14
+ from architect_py.grpc.models.Orderflow.Orderflow import (
15
15
  TaggedOrderAck,
16
16
  TaggedOrderOut,
17
17
  TaggedOrderReject,
18
18
  )
19
- from architect_py.grpc_client.Orderflow.OrderflowRequest import (
19
+ from architect_py.grpc.models.Orderflow.OrderflowRequest import (
20
20
  OrderflowRequest,
21
21
  PlaceOrder,
22
22
  )
23
- from architect_py.scalars import OrderDir, TradableProduct
24
23
 
25
24
  from .common import connect_async_client
26
25
 
@@ -52,9 +51,8 @@ class OrderflowRequester:
52
51
 
53
52
 
54
53
  async def update_marketdata(c: AsyncClient):
55
- ticker_request = TickerRequest(symbol=tradable_product)
56
- s = c.grpc_client.subscribe(ticker_request)
57
- async for ticker in s:
54
+ while True:
55
+ ticker = await c.get_ticker(tradable_product, venue)
58
56
  if ticker.funding_rate:
59
57
  global current_funding_rate
60
58
  global target_position
@@ -76,13 +74,14 @@ async def update_marketdata(c: AsyncClient):
76
74
  if ticker.ask_price:
77
75
  global best_ask_price
78
76
  best_ask_price = ticker.ask_price
77
+ await asyncio.sleep(1)
79
78
 
80
79
 
81
80
  async def subscribe_and_print_orderflow(
82
81
  c: AsyncClient, orderflow_requester: OrderflowRequester
83
82
  ):
84
83
  try:
85
- stream = c.grpc_client.subscribe_orderflow_stream(orderflow_requester)
84
+ stream = c.orderflow(orderflow_requester)
86
85
  """
87
86
  subscribe_orderflow_stream is a duplex_stream meaning that it is a stream that can be read from and written to.
88
87
  This is a stream that will be used to send orders to the Architect and receive order updates from the Architect.
@@ -126,7 +125,7 @@ async def step_to_target_position(
126
125
  execution_venue=None,
127
126
  limit_price=best_ask_price,
128
127
  time_in_force=TimeInForceEnum.DAY,
129
- place_order_request_type=PlaceOrderRequestType.LIMIT,
128
+ order_type=OrderType.LIMIT,
130
129
  )
131
130
 
132
131
  elif current_position > target_position:
@@ -141,7 +140,7 @@ async def step_to_target_position(
141
140
  execution_venue=None,
142
141
  limit_price=best_bid_price,
143
142
  time_in_force=TimeInForceEnum.DAY,
144
- place_order_request_type=PlaceOrderRequestType.LIMIT,
143
+ order_type=OrderType.LIMIT,
145
144
  )
146
145
 
147
146
  if order is not None:
@@ -159,14 +158,9 @@ async def print_info(c: AsyncClient):
159
158
  )
160
159
  pos = Decimal(0)
161
160
  for account in account_summaries:
162
- for balance in account.balances:
163
- if balance.product is None:
164
- name = "UNKNOWN NAME"
165
- else:
166
- name = balance.product
167
- print(f"balance for {name}: {balance.balance}")
168
- if name and balance.balance is not None:
169
- pos += balance.balance
161
+ for name, balance in account.balances.items():
162
+ print(f"balance for {name}: {balance}")
163
+ pos += balance
170
164
  global current_position
171
165
  current_position = pos
172
166
  print("---")
examples/order_sending.py CHANGED
@@ -1,31 +1,20 @@
1
1
  import asyncio
2
2
  import logging
3
+ from datetime import datetime, timedelta, timezone
3
4
  from decimal import Decimal
4
5
 
5
6
  from architect_py.async_client import AsyncClient
6
- from architect_py.graphql_client.enums import OrderType, TimeInForce
7
- from architect_py.scalars import OrderDir, TradableProduct
7
+ from architect_py.common_types.order_dir import OrderDir
8
+ from architect_py.common_types.tradable_product import TradableProduct
9
+ from architect_py.grpc.models.definitions import GoodTilDate, OrderType, TimeInForceEnum
10
+ from examples.common import connect_async_client
8
11
 
9
12
  LOGGER = logging.getLogger(__name__)
10
13
 
11
- api_key = None
12
- api_secret = None
13
- HOST = None
14
- ACCOUNT = None
15
14
 
16
-
17
- if api_key is None or api_secret is None or HOST is None or ACCOUNT is None:
18
- raise ValueError(
19
- "Please set the api_key, api_secret, HOST, and ACCOUNT variables before running this script"
20
- )
21
-
22
-
23
- client = AsyncClient(host=HOST, api_key=api_key, api_secret=api_secret)
24
-
25
-
26
- async def search_symbol() -> tuple[str, TradableProduct]:
15
+ async def search_symbol(c: AsyncClient) -> tuple[str, TradableProduct]:
27
16
  venue = "CME"
28
- markets = await client.search_symbols(
17
+ markets = await c.search_symbols(
29
18
  search_string="ES",
30
19
  execution_venue=venue,
31
20
  )
@@ -33,27 +22,32 @@ async def search_symbol() -> tuple[str, TradableProduct]:
33
22
  return venue, market
34
23
 
35
24
 
36
- async def test_send_order():
37
- venue, symbol = await search_symbol()
25
+ async def test_send_order(client: AsyncClient, account: str):
26
+ venue, symbol = await search_symbol(client)
38
27
 
39
28
  snapshot = await client.get_market_snapshot(symbol=symbol, venue=venue)
40
29
  if snapshot is None:
41
30
  return ValueError(f"Market snapshot for {symbol} is None")
42
31
 
43
- if snapshot.ask_price is None or snapshot.bid_price is None:
32
+ if snapshot.best_ask is None or snapshot.best_bid is None:
44
33
  return ValueError(f"Market snapshot for {symbol} is None")
45
34
 
46
- order = await client.send_limit_order(
35
+ best_bid_price, best_bid_quantity = snapshot.best_bid
36
+
37
+ d = datetime.now(tz=timezone.utc) + timedelta(days=1)
38
+ gtd = GoodTilDate(d)
39
+
40
+ order = await client.place_limit_order(
47
41
  symbol=symbol,
48
42
  odir=OrderDir.BUY,
49
- quantity=Decimal(1),
43
+ quantity=best_bid_quantity,
50
44
  order_type=OrderType.LIMIT,
51
45
  execution_venue="CME",
52
46
  post_only=True,
53
- limit_price=snapshot.bid_price
54
- - (snapshot.ask_price - snapshot.bid_price) * Decimal(10),
55
- account=ACCOUNT,
56
- time_in_force=TimeInForce.IOC,
47
+ limit_price=best_bid_price
48
+ - (snapshot.best_ask[0] - best_bid_price) * Decimal(10),
49
+ account=account,
50
+ time_in_force=gtd,
57
51
  )
58
52
  logging.critical(f"ORDER TEST: {order}")
59
53
 
@@ -64,12 +58,12 @@ async def test_send_order():
64
58
  return cancel
65
59
 
66
60
 
67
- async def test_cancel_all_orders():
61
+ async def test_cancel_all_orders(client: AsyncClient):
68
62
  await client.cancel_all_orders()
69
63
 
70
64
 
71
- async def test_send_market_pro_order():
72
- venue, symbol = await search_symbol()
65
+ async def test_send_market_pro_order(client: AsyncClient, account: str):
66
+ venue, symbol = await search_symbol(client)
73
67
  print(symbol)
74
68
 
75
69
  await client.send_market_pro_order(
@@ -77,13 +71,18 @@ async def test_send_market_pro_order():
77
71
  execution_venue=venue,
78
72
  odir=OrderDir.BUY,
79
73
  quantity=Decimal(1),
80
- account=ACCOUNT,
81
- time_in_force=TimeInForce.IOC,
74
+ account=account,
75
+ time_in_force=TimeInForceEnum.IOC,
82
76
  )
83
77
 
84
78
 
85
79
  async def main():
86
- await test_send_market_pro_order()
80
+ client: AsyncClient = await connect_async_client()
81
+ accounts = await client.list_accounts()
82
+ account: str = accounts[0].account.name
83
+
84
+ await test_send_order(client, account)
85
+ await test_send_market_pro_order(client, account)
87
86
 
88
87
 
89
88
  if __name__ == "__main__":
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
 
3
3
  from architect_py.async_client import AsyncClient
4
- from architect_py.scalars import TradableProduct
4
+ from architect_py.common_types.tradable_product import TradableProduct
5
5
 
6
6
  from .common import connect_async_client
7
7
 
@@ -9,7 +9,7 @@ from .common import connect_async_client
9
9
  async def main():
10
10
  c: AsyncClient = await connect_async_client()
11
11
 
12
- async for snap in c.subscribe_l1_book_stream(
12
+ async for snap in c.stream_l1_book_snapshots(
13
13
  symbols=[TradableProduct("ES 20250620 CME Future/USD")],
14
14
  venue="CME",
15
15
  ):
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
 
3
3
  from architect_py.async_client import AsyncClient
4
- from architect_py.scalars import TradableProduct
4
+ from architect_py.common_types.tradable_product import TradableProduct
5
5
 
6
6
  from .common import connect_async_client
7
7
 
@@ -30,8 +30,6 @@ async def print_l2_book(c: AsyncClient, symbol: TradableProduct, venue: str):
30
30
 
31
31
  async def main():
32
32
  c: AsyncClient = await connect_async_client()
33
- endpoint = "app.architect.co" # one example of alternative can be "binance.marketdata.architect.co"
34
- await c.grpc_client.change_channel(endpoint)
35
33
  market_symbol = TradableProduct("ES 20250620 CME Future/USD")
36
34
  venue = "CME"
37
35
  await print_l2_book(c, market_symbol, venue=venue)
examples/trades.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import asyncio
2
2
 
3
3
  from architect_py.async_client import AsyncClient
4
+ from architect_py.common_types.tradable_product import TradableProduct
4
5
  from architect_py.graphql_client.exceptions import GraphQLClientHttpError
5
- from architect_py.scalars import TradableProduct
6
6
 
7
7
  from .common import connect_async_client
8
8
 
@@ -11,7 +11,7 @@ async def main():
11
11
  c: AsyncClient = await connect_async_client()
12
12
  market_id = TradableProduct("BTC Crypto", "USD")
13
13
  try:
14
- async for trade in c.subscribe_trades_stream(market_id, venue="COINBASE"):
14
+ async for trade in c.stream_trades(market_id, venue="COINBASE"):
15
15
  print(trade)
16
16
  except GraphQLClientHttpError as e:
17
17
  print(e.status_code)
@@ -2,8 +2,8 @@ import asyncio
2
2
  from decimal import Decimal
3
3
 
4
4
  from architect_py.async_client import OrderDir
5
+ from architect_py.common_types.tradable_product import TradableProduct
5
6
  from architect_py.graphql_client.enums import OrderStatus
6
- from architect_py.scalars import TradableProduct
7
7
  from examples.common import connect_async_client
8
8
 
9
9
 
@@ -19,8 +19,8 @@ async def main():
19
19
  print()
20
20
  print(f"Market snapshot for {market}")
21
21
  market_snapshot = await c.get_market_snapshot(symbol=market, venue=execution_venue)
22
- print(f"Best bid: {market_snapshot.bid_price}")
23
- print(f"Best ask: {market_snapshot.ask_price}")
22
+ print(f"Best bid: {market_snapshot.best_bid}")
23
+ print(f"Best ask: {market_snapshot.best_ask}")
24
24
 
25
25
  # List your FCM accounts
26
26
  print()
@@ -32,10 +32,12 @@ async def main():
32
32
  account_id = accounts[0].account.id
33
33
 
34
34
  # Place a limit order $100 below the best bid
35
- best_bid = market_snapshot.bid_price
36
- if best_bid is None:
35
+ if (best_bid := market_snapshot.best_bid) is not None:
36
+ bid_px = best_bid[0]
37
+ else:
37
38
  raise ValueError("No bid price available")
38
- limit_price = best_bid - Decimal(100)
39
+
40
+ limit_price = bid_px - Decimal(100)
39
41
  quantity = Decimal(1)
40
42
  account = accounts[0]
41
43
  order = None
@@ -46,7 +48,7 @@ async def main():
46
48
  )
47
49
  == "y"
48
50
  ):
49
- order = await c.send_limit_order(
51
+ order = await c.place_limit_order(
50
52
  symbol=market,
51
53
  execution_venue=execution_venue,
52
54
  odir=OrderDir.BUY,
examples/tutorial_sync.py CHANGED
@@ -2,8 +2,8 @@ import pprint
2
2
  import time
3
3
  from decimal import Decimal
4
4
 
5
+ from architect_py.common_types.order_dir import OrderDir
5
6
  from architect_py.graphql_client.enums import OrderStatus
6
- from architect_py.scalars import OrderDir
7
7
  from architect_py.utils.nearest_tick import TickRoundMethod
8
8
 
9
9
  from .common import confirm, connect_client, print_book, print_open_orders
@@ -56,10 +56,10 @@ orders = c.get_open_orders()
56
56
  print_open_orders(orders)
57
57
 
58
58
  # Place a limit order 20% below the best bid
59
- best_bid = market_snapshot.bid_price
59
+ best_bid = market_snapshot.best_bid
60
60
  assert best_bid is not None
61
- limit_price = best_bid * Decimal(0.8)
62
- quantity = Decimal(1)
61
+ best_bid_price, best_bid_quantity = best_bid
62
+ limit_price = best_bid_price * Decimal(0.8)
63
63
  account = accounts[0]
64
64
  order = None
65
65
 
@@ -70,7 +70,7 @@ if confirm(
70
70
  symbol=symbol,
71
71
  execution_venue=venue,
72
72
  odir=OrderDir.BUY,
73
- quantity=quantity,
73
+ quantity=best_bid_quantity,
74
74
  limit_price=limit_price,
75
75
  account=account.account.name,
76
76
  price_round_method=TickRoundMethod.ROUND,
@@ -79,13 +79,15 @@ def get_asyncclient_methods(filename):
79
79
  decorator_name = (
80
80
  decorator.func.attr
81
81
  if isinstance(decorator.func, ast.Attribute)
82
- else decorator.func.id
82
+ else getattr(decorator.func, "id", "")
83
83
  )
84
84
  else:
85
85
  decorator_name = (
86
86
  decorator.attr
87
87
  if isinstance(decorator, ast.Attribute)
88
88
  else decorator.id
89
+ if isinstance(decorator, ast.Name)
90
+ else ""
89
91
  )
90
92
  if decorator_name == "overload":
91
93
  is_fn_overload = True
@@ -1,9 +1,10 @@
1
+ import argparse
1
2
  import collections.abc
2
3
  import inspect
3
4
  import types
4
5
  from decimal import Decimal
5
6
  from enum import Enum
6
- from typing import Any, Sequence, Union, get_args, get_origin
7
+ from typing import Annotated, Any, Sequence, Union, get_args, get_origin
7
8
 
8
9
  from architect_py.async_client import AsyncClient
9
10
  from architect_py.graphql_client.base_model import UnsetType
@@ -19,6 +20,12 @@ def format_type_hint_with_generics(type_hint) -> str:
19
20
 
20
21
  origin = get_origin(type_hint)
21
22
 
23
+ if origin is Annotated:
24
+ annotated_args = get_args(type_hint)
25
+ if annotated_args:
26
+ return format_type_hint_with_generics(annotated_args[0])
27
+ return "Any"
28
+
22
29
  # Handle `|` unions (Python 3.10+)
23
30
  if isinstance(
24
31
  type_hint, types.UnionType
@@ -88,9 +95,9 @@ def autogenerate_protocol(cls, protocol_name: str) -> str:
88
95
  Returns:
89
96
  A string representing the Protocol definition.
90
97
  """
91
- methods = {}
92
- method_decorators = {}
93
- attributes = {}
98
+ methods: dict[str, inspect.Signature] = {}
99
+ method_decorators: dict[str, list[str]] = {}
100
+ attributes: dict[str, Any] = {}
94
101
 
95
102
  # Inspect class members
96
103
  for name, member in inspect.getmembers(cls):
@@ -139,7 +146,7 @@ def autogenerate_protocol(cls, protocol_name: str) -> str:
139
146
  "from typing import Any, Union",
140
147
  "from .graphql_client import *",
141
148
  "from .async_client import *",
142
- "\n",
149
+ "from .grpc.models.definitions import *",
143
150
  f"class {protocol_name}:",
144
151
  ]
145
152
 
@@ -151,14 +158,15 @@ def autogenerate_protocol(cls, protocol_name: str) -> str:
151
158
 
152
159
  # Add methods
153
160
  for name, signature in methods.items():
154
- if (
155
- name.startswith("subscribe")
156
- or name.startswith("stream")
157
- or name.startswith("unsubscribe")
158
- or name == "connect"
161
+ if any(
162
+ keyword in name
163
+ for keyword in ("subscribe", "stream", "unsubscribe", "connect")
159
164
  ):
160
165
  continue
161
166
 
167
+ if "async" in str(signature).lower():
168
+ continue
169
+
162
170
  # for decorators like @staticmethod and @classmethod
163
171
  if name in method_decorators:
164
172
  for deco in method_decorators[name]:
@@ -204,4 +212,15 @@ def autogenerate_protocol(cls, protocol_name: str) -> str:
204
212
 
205
213
 
206
214
  if __name__ == "__main__":
207
- print(autogenerate_protocol(AsyncClient, "ClientProtocol"))
215
+ parser = argparse.ArgumentParser(description="Process gRPC service definitions")
216
+ parser.add_argument(
217
+ "--file_path",
218
+ type=str,
219
+ default="architect_py/grpc_client",
220
+ help="Path to the Python folder with the gRPC service definitions",
221
+ )
222
+ args = parser.parse_args()
223
+ protocol: str = autogenerate_protocol(AsyncClient, "ClientProtocol")
224
+
225
+ with open(args.file_path, "w") as f:
226
+ f.write(protocol)
@@ -207,6 +207,9 @@ def create_tagged_subtypes_for_variant_types(content: str, json_data: dict) -> s
207
207
  if tag_field is None:
208
208
  return content
209
209
 
210
+ if "oneOf" not in json_data:
211
+ return content
212
+
210
213
  # Build a mapping from base type name to its variants.
211
214
  tag_field_map: dict[str, List[Tuple[str, str]]] = defaultdict(list)
212
215
  for p in json_data["oneOf"]:
@@ -319,30 +322,35 @@ def fix_lines(content: str) -> str:
319
322
  return content
320
323
 
321
324
 
322
- def add_post_processing_to_loosened_types(content: str, json_data: dict) -> str:
325
+ def add_post_processing_to_unflattened_types(content: str, json_data: dict) -> str:
323
326
  """
324
327
  Adds a __post_init__ method to the flattened types to enforce field requirements.
325
328
  """
326
- enum_tag = json_data.get("enum_tag")
327
- if enum_tag is None:
329
+ enum_variant_to_other_required_keys: dict[str, List[str]] = json_data.get(
330
+ "enum_variant_to_other_required_keys", {}
331
+ )
332
+ if len(enum_variant_to_other_required_keys) == 0:
328
333
  return content
329
334
 
335
+ enum_tag = json_data[
336
+ "tag_field"
337
+ ] # should not be empty if enum_variant_to_other_required_keys is not empty
338
+
330
339
  class_title = json_data["title"]
331
340
 
332
341
  properties = json_data["properties"]
333
- enum_tag_to_other_required_keys: dict[str, List[str]] = json_data[
334
- "enum_tag_to_other_required_keys"
335
- ]
336
342
 
337
343
  lines = content.splitlines(keepends=True)
338
344
  # Append __post_init__ method at the end
339
345
  lines.append("\n def __post_init__(self):\n")
340
346
 
341
- common_keys = set.intersection(*map(set, enum_tag_to_other_required_keys.values()))
342
- union_keys = set.union(*map(set, enum_tag_to_other_required_keys.values()))
347
+ common_keys = set.intersection(
348
+ *map(set, enum_variant_to_other_required_keys.values())
349
+ )
350
+ union_keys = set.union(*map(set, enum_variant_to_other_required_keys.values()))
343
351
 
344
352
  for i, (enum_value, required_keys) in enumerate(
345
- enum_tag_to_other_required_keys.items()
353
+ enum_variant_to_other_required_keys.items()
346
354
  ):
347
355
  conditional = "if" if i == 0 else "elif"
348
356
  title = properties[enum_tag]["title"]
@@ -429,7 +437,9 @@ def generate_stub(content: str, json_data: dict) -> str:
429
437
  """
430
438
 
431
439
  # If this is a Request file, append additional gRPC info.
432
- if json_data.get("tag_field") is not None:
440
+ if (json_data.get("tag_field") is not None) and (
441
+ json_data.get("enum_variant_to_other_required_keys") is None
442
+ ):
433
443
  service = json_data["service"]
434
444
  rpc_method = json_data["rpc_method"]
435
445
  response_type = json_data["response_type"]
@@ -556,7 +566,7 @@ def part_1(py_file_path: str, json_data: dict) -> None:
556
566
  content = create_tagged_subtypes_for_variant_types(content, json_data)
557
567
  content = fix_lines(content)
558
568
  if not py_file_path.endswith("definitions.py"):
559
- content = add_post_processing_to_loosened_types(content, json_data)
569
+ content = add_post_processing_to_unflattened_types(content, json_data)
560
570
 
561
571
  content = fix_enum_member_names(content, json_data)
562
572