@drift-labs/sdk 2.96.0-beta.0 → 2.96.0-beta.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,152 @@
1
+ // import WebSocket from 'ws';
2
+ import { logProviderCallback, EventType, LogProvider } from './types';
3
+ import { EventEmitter } from 'events';
4
+
5
// Choose the WebSocket constructor: the native one when a browser `window`
// exposes it, otherwise fall back to the `ws` package (Node).
const WebSocketImpl: typeof WebSocket =
	typeof window !== 'undefined' && window.WebSocket
		? window.WebSocket
		: require('ws');

// Expected spacing between server heartbeats; also sizes the watchdog timer.
const EVENT_SERVER_HEARTBEAT_INTERVAL_MS = 5000;
// How many heartbeat intervals may pass without a heartbeat before reconnecting.
const ALLOWED_MISSED_HEARTBEATS = 3;
15
+
16
/**
 * LogProvider that receives Drift events from an events-server websocket.
 *
 * Each event received on a subscribed channel is wrapped between synthetic
 * `invoke [1]` / `success` lines for program dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH
 * and handed to the same `logProviderCallback` the other providers use
 * (presumably so downstream log parsing can treat it like an RPC log — confirm).
 *
 * A watchdog timer tracks server heartbeats. If the last heartbeat is older than
 * `EVENT_SERVER_HEARTBEAT_INTERVAL_MS * ALLOWED_MISSED_HEARTBEATS`, the provider
 * closes the socket, emits `'reconnect'` (with the attempt count) on
 * `eventEmitter`, and subscribes again. `unsubscribe(true)` stops this behavior.
 */
export class EventsServerLogProvider implements LogProvider {
	private ws?: WebSocket;
	private callback?: logProviderCallback;
	/** True only while `unsubscribe` is tearing the socket down. */
	private isUnsubscribing = false;
	/** Set by `unsubscribe(true)`; suppresses automatic reconnects until the next `subscribe`. */
	private externalUnsubscribe = false;
	/** `Date.now()` at which the last heartbeat message was received. */
	private lastHeartbeat = 0;
	private timeoutId?: NodeJS.Timeout;
	private reconnectAttempts = 0;
	eventEmitter?: EventEmitter;

	public constructor(
		private readonly url: string,
		private readonly eventTypes: EventType[],
		private readonly userAccount?: string
	) {
		this.eventEmitter = new EventEmitter();
	}

	public isSubscribed(): boolean {
		return this.ws !== undefined;
	}

	/**
	 * Opens the websocket, subscribes to every configured event channel once
	 * connected, and arms the heartbeat watchdog. No-op if already subscribed.
	 *
	 * @param callback receives each event as (txSig, slot, logs, undefined, txSigIndex)
	 * @returns always true
	 */
	public async subscribe(callback: logProviderCallback): Promise<boolean> {
		if (this.ws !== undefined) {
			return true;
		}

		// A fresh subscription re-enables heartbeat monitoring that an earlier
		// unsubscribe (internal or external) switched off.
		this.isUnsubscribing = false;
		this.externalUnsubscribe = false;
		this.callback = callback;

		// Keep a local handle so listeners only act for the socket they were attached to.
		const ws = new WebSocketImpl(this.url);
		this.ws = ws;

		ws.addEventListener('open', () => {
			if (this.ws !== ws) {
				return;
			}
			for (const channel of this.eventTypes) {
				const subscribeMessage = {
					type: 'subscribe',
					channel: channel,
					...(this.userAccount ? { user: this.userAccount } : {}),
				};
				ws.send(JSON.stringify(subscribeMessage));
			}
			this.reconnectAttempts = 0;
		});

		ws.addEventListener('message', (data) => {
			// Late traffic from a socket we already closed/replaced must not re-arm
			// the watchdog or deliver events.
			if (this.ws !== ws) {
				return;
			}
			try {
				if (!this.isUnsubscribing) {
					// Any message shows the socket is alive: push the watchdog back.
					this.setTimeout();
					if (this.reconnectAttempts > 0) {
						console.log(
							'eventsServerLogProvider: Resetting reconnect attempts to 0'
						);
					}
					this.reconnectAttempts = 0;
				}

				const parsedData = JSON.parse(data.data.toString());
				if (parsedData.channel === 'heartbeat') {
					this.lastHeartbeat = Date.now();
					return;
				}
				// Payloads carrying a `message` field are not events; skip them.
				if (parsedData.message !== undefined) {
					return;
				}
				const event = JSON.parse(parsedData.data);
				callback(
					event.txSig,
					event.slot,
					[
						'Program dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH invoke [1]',
						event.rawLog,
						'Program dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH success',
					],
					undefined,
					event.txSigIndex
				);
			} catch (error) {
				console.error('Error parsing message:', error);
			}
		});

		ws.addEventListener('close', () => {
			console.log('eventsServerLogProvider: WebSocket closed');
		});

		ws.addEventListener('error', (error) => {
			console.error('eventsServerLogProvider: WebSocket error:', error);
		});

		this.setTimeout();

		return true;
	}

	/**
	 * Cancels the watchdog and closes the socket (if any).
	 *
	 * @param external pass true when the caller wants the provider to stay down;
	 *   this blocks automatic reconnects until `subscribe` is called again
	 * @returns always true
	 */
	public async unsubscribe(external = false): Promise<boolean> {
		this.isUnsubscribing = true;
		this.externalUnsubscribe = external;
		this.clearWatchdog();

		const ws = this.ws;
		this.ws = undefined;
		if (ws !== undefined) {
			ws.close();
		}

		// Teardown is complete. Leaving this flag set would permanently disable the
		// watchdog for every later subscription.
		this.isUnsubscribing = false;
		return true;
	}

	/** Cancels the pending watchdog timer, if one is armed. */
	private clearWatchdog(): void {
		if (this.timeoutId !== undefined) {
			clearTimeout(this.timeoutId);
			this.timeoutId = undefined;
		}
	}

	/**
	 * (Re)arms the heartbeat watchdog, replacing any pending timer.
	 *
	 * When it fires (after two heartbeat intervals), it does nothing if an
	 * unsubscribe is in progress or was requested externally. If heartbeats are
	 * still fresh, it re-arms itself. Otherwise it tears the socket down, bumps
	 * `reconnectAttempts`, emits `'reconnect'`, and subscribes again. It skips the
	 * resubscribe if a `'reconnect'` listener asked the provider to stay down.
	 */
	private setTimeout(): void {
		this.clearWatchdog();
		this.timeoutId = setTimeout(async () => {
			this.timeoutId = undefined;
			if (this.isUnsubscribing || this.externalUnsubscribe) {
				// If we are in the process of unsubscribing, do not attempt to resubscribe
				return;
			}

			const timeSinceLastHeartbeat = Date.now() - this.lastHeartbeat;
			if (
				timeSinceLastHeartbeat <=
				EVENT_SERVER_HEARTBEAT_INTERVAL_MS * ALLOWED_MISSED_HEARTBEATS
			) {
				// Still healthy: keep watching so a connection that goes silent later is caught.
				this.setTimeout();
				return;
			}

			console.log(
				`eventServerLogProvider: No heartbeat in ${timeSinceLastHeartbeat}ms, resubscribing on attempt ${
					this.reconnectAttempts + 1
				}`
			);
			await this.unsubscribe();
			this.reconnectAttempts++;
			this.eventEmitter.emit('reconnect', this.reconnectAttempts);
			// A 'reconnect' listener may have called unsubscribe(true) (e.g. to fail over
			// to another provider); opening a new socket then would leak it.
			if (this.externalUnsubscribe) {
				return;
			}
			await this.subscribe(this.callback);
		}, EVENT_SERVER_HEARTBEAT_INTERVAL_MS * 2);
	}
}
@@ -60,7 +60,7 @@ export class PollingLogProvider implements LogProvider {
60
60
  const { mostRecentTx, transactionLogs } = response;
61
61
 
62
62
  for (const { txSig, slot, logs } of transactionLogs) {
63
- callback(txSig, slot, logs, response.mostRecentBlockTime);
63
+ callback(txSig, slot, logs, response.mostRecentBlockTime, undefined);
64
64
  }
65
65
 
66
66
  this.mostRecentSeenTx = mostRecentTx;
@@ -56,7 +56,11 @@ export const DefaultEventSubscriptionOptions: EventSubscriptionOptions = {
56
56
  commitment: 'confirmed',
57
57
  maxTx: 4096,
58
58
  logProviderConfig: {
59
- type: 'websocket',
59
+ type: 'events-server',
60
+ url: 'wss://events.drift.trade/ws',
61
+ maxReconnectAttempts: 5,
62
+ fallbackFrequency: 1000,
63
+ fallbackBatchSize: 100,
60
64
  },
61
65
  };
62
66
 
@@ -126,7 +130,8 @@ export type logProviderCallback = (
126
130
  txSig: TransactionSignature,
127
131
  slot: number,
128
132
  logs: string[],
129
- mostRecentBlockTime: number | undefined
133
+ mostRecentBlockTime: number | undefined,
134
+ txSigIndex: number | undefined
130
135
  ) => void;
131
136
 
132
137
  export interface LogProvider {
@@ -139,20 +144,38 @@ export interface LogProvider {
139
144
  eventEmitter?: EventEmitter;
140
145
  }
141
146
 
142
/** Options shared by the streaming log providers (websocket and events-server). */
export type StreamingLogProviderConfig = {
	/** Max number of times to try reconnecting before failing over to the fallback (polling) provider. */
	maxReconnectAttempts?: number;
	/** Polling `frequency` handed to the fallback PollingLogProviderConfig. */
	fallbackFrequency?: number;
	/** Polling `batchSize` handed to the fallback PollingLogProviderConfig. */
	fallbackBatchSize?: number;
};

/** Streams transaction logs over a websocket subscription. */
export type WebSocketLogProviderConfig = StreamingLogProviderConfig & {
	type: 'websocket';
	/** Max time (ms) without log data before resubscribing. */
	resubTimeoutMs?: number;
};

/** Periodically polls for new transactions. */
export type PollingLogProviderConfig = {
	type: 'polling';
	/** How often to poll for new events. */
	frequency: number;
	/** Max number of events to fetch per poll. */
	batchSize?: number;
};

/** Receives events from a Drift events server. */
export type EventsServerLogProviderConfig = StreamingLogProviderConfig & {
	type: 'events-server';
	/** URL of the events server. */
	url: string;
};

export type LogProviderConfig =
	| WebSocketLogProviderConfig
	| PollingLogProviderConfig
	| EventsServerLogProviderConfig;

/** Discriminant of LogProviderConfig: 'websocket' | 'polling' | 'events-server'. */
export type LogProviderType = LogProviderConfig['type'];
@@ -64,7 +64,7 @@ export class WebSocketLogProvider implements LogProvider {
64
64
  if (logs.err !== null) {
65
65
  return;
66
66
  }
67
- callback(logs.signature, ctx.slot, logs.logs, undefined);
67
+ callback(logs.signature, ctx.slot, logs.logs, undefined, undefined);
68
68
  },
69
69
  this.commitment
70
70
  );
@@ -106,9 +106,9 @@ export class WebSocketLogProvider implements LogProvider {
106
106
 
107
107
  if (this.receivingData) {
108
108
  console.log(
109
- `No log data in ${this.resubTimeoutMs}ms, resubscribing on attempt ${
110
- this.reconnectAttempts + 1
111
- }`
109
+ `webSocketLogProvider: No log data in ${
110
+ this.resubTimeoutMs
111
+ }ms, resubscribing on attempt ${this.reconnectAttempts + 1}`
112
112
  );
113
113
  await this.unsubscribe();
114
114
  this.receivingData = false;
@@ -608,6 +608,56 @@
608
608
  }
609
609
  ]
610
610
  },
611
+ {
612
+ "name": "placeSwiftTakerOrder",
613
+ "accounts": [
614
+ {
615
+ "name": "state",
616
+ "isMut": false,
617
+ "isSigner": false
618
+ },
619
+ {
620
+ "name": "user",
621
+ "isMut": true,
622
+ "isSigner": false
623
+ },
624
+ {
625
+ "name": "userStats",
626
+ "isMut": true,
627
+ "isSigner": false
628
+ },
629
+ {
630
+ "name": "authority",
631
+ "isMut": false,
632
+ "isSigner": true
633
+ },
634
+ {
635
+ "name": "ixSysvar",
636
+ "isMut": false,
637
+ "isSigner": false,
638
+ "docs": [
639
+ "the supplied Sysvar could be anything else.",
640
+ "The Instruction Sysvar has not been implemented",
641
+ "in the Anchor framework yet, so this is the safe approach."
642
+ ]
643
+ }
644
+ ],
645
+ "args": [
646
+ {
647
+ "name": "takerOrderParamsMessageBytes",
648
+ "type": "bytes"
649
+ },
650
+ {
651
+ "name": "signature",
652
+ "type": {
653
+ "array": [
654
+ "u8",
655
+ 64
656
+ ]
657
+ }
658
+ }
659
+ ]
660
+ },
611
661
  {
612
662
  "name": "placeSpotOrder",
613
663
  "accounts": [
@@ -8050,6 +8100,40 @@
8050
8100
  ]
8051
8101
  }
8052
8102
  },
8103
+ {
8104
+ "name": "SwiftOrderParamsMessage",
8105
+ "type": {
8106
+ "kind": "struct",
8107
+ "fields": [
8108
+ {
8109
+ "name": "swiftOrderParams",
8110
+ "type": {
8111
+ "vec": {
8112
+ "defined": "OrderParams"
8113
+ }
8114
+ }
8115
+ },
8116
+ {
8117
+ "name": "marketIndex",
8118
+ "type": "u16"
8119
+ },
8120
+ {
8121
+ "name": "marketType",
8122
+ "type": {
8123
+ "defined": "MarketType"
8124
+ }
8125
+ },
8126
+ {
8127
+ "name": "expectedOrderId",
8128
+ "type": "i32"
8129
+ },
8130
+ {
8131
+ "name": "slot",
8132
+ "type": "u64"
8133
+ }
8134
+ ]
8135
+ }
8136
+ },
8053
8137
  {
8054
8138
  "name": "ModifyOrderParams",
8055
8139
  "type": {
@@ -12958,6 +13042,26 @@
12958
13042
  "code": 6284,
12959
13043
  "name": "InvalidPredictionMarketOrder",
12960
13044
  "msg": "Invalid prediction market order"
13045
+ },
13046
+ {
13047
+ "code": 6285,
13048
+ "name": "InvalidVerificationIxIndex",
13049
+ "msg": "Ed25519 Ix must be before place and make swift order ix"
13050
+ },
13051
+ {
13052
+ "code": 6286,
13053
+ "name": "SigVerificationFailed",
13054
+ "msg": "Swift taker message verification failed"
13055
+ },
13056
+ {
13057
+ "code": 6287,
13058
+ "name": "MismatchedSwiftOrderParamsMarketIndex",
13059
+ "msg": "Market index mismatched b/w taker and maker swift order params"
13060
+ },
13061
+ {
13062
+ "code": 6288,
13063
+ "name": "SwiftOrderSequenceError",
13064
+ "msg": "Swift order message must be ordered 0th order is market and rest are triggers"
12961
13065
  }
12962
13066
  ],
12963
13067
  "metadata": {
@@ -7,10 +7,19 @@ import {
7
7
  AMM_RESERVE_PRECISION,
8
8
  MAX_PREDICTION_PRICE,
9
9
  BASE_PRECISION,
10
+ MARGIN_PRECISION,
11
+ PRICE_PRECISION,
12
+ QUOTE_PRECISION,
10
13
  } from '../constants/numericConstants';
11
14
  import { BN } from '@coral-xyz/anchor';
12
15
  import { OraclePriceData } from '../oracles/types';
13
- import { PerpMarketAccount, PerpPosition } from '..';
16
+ import {
17
+ calculateMarketMarginRatio,
18
+ calculateScaledInitialAssetWeight,
19
+ DriftClient,
20
+ PerpMarketAccount,
21
+ PerpPosition,
22
+ } from '..';
14
23
  import { isVariant } from '../types';
15
24
  import { assert } from '../assert/assert';
16
25
 
@@ -194,3 +203,130 @@ export function calculatePerpLiabilityValue(
194
203
  return baseAssetAmount.abs().mul(oraclePrice).div(BASE_PRECISION);
195
204
  }
196
205
  }
206
+
207
/**
 * Calculates the initial margin needed to open a perp trade of `baseSize` in
 * `targetMarketIndex`, in quote precision.
 *
 * The size is treated as a scalar. Trade direction and the user's open positions
 * are ignored, so this does not tell you whether the trade would actually be
 * risk-increasing and consume extra collateral.
 *
 * @param driftClient client used to read the perp market and its oracle data
 * @param targetMarketIndex perp market index of the trade
 * @param baseSize trade size in base precision
 * @param userMaxMarginRatio optional custom margin ratio, forwarded to calculateMarketMarginRatio
 * @returns initial margin required, in quote precision
 */
export function calculateMarginUSDCRequiredForTrade(
	driftClient: DriftClient,
	targetMarketIndex: number,
	baseSize: BN,
	userMaxMarginRatio?: number
): BN {
	const market = driftClient.getPerpMarketAccount(targetMarketIndex);
	const { price: oraclePrice } = driftClient.getOracleDataForPerpMarket(
		market.marketIndex
	);
	const isPredictionMarket = isVariant(market.contractType, 'prediction');

	const liabilityValue = calculatePerpLiabilityValue(
		baseSize,
		oraclePrice,
		isPredictionMarket
	);
	const initialMarginRatio = new BN(
		calculateMarketMarginRatio(
			market,
			baseSize.abs(),
			'Initial',
			userMaxMarginRatio
		)
	);

	// ratio is expressed in MARGIN_PRECISION; scale back down to quote precision
	return initialMarginRatio.mul(liabilityValue).div(MARGIN_PRECISION);
}
243
+
244
/**
 * Calculates how much of a given spot collateral must be deposited to cover the
 * initial margin of a perp trade.
 *
 * This is calculateMarginUSDCRequiredForTrade, with the resulting quote amount
 * converted into the collateral's units at its oracle price and grossed up by the
 * collateral's scaled initial asset weight.
 *
 * @param driftClient client used to read market and oracle data
 * @param targetMarketIndex perp market index of the trade
 * @param baseSize trade size in base precision
 * @param collateralIndex spot market index of the collateral to deposit
 * @param userMaxMarginRatio optional custom margin ratio for the user
 * @returns collateral required, in the precision of the collateral spot market (rounded down)
 * @throws Error if the collateral's oracle price or scaled initial asset weight is
 *   not positive, since no amount of that collateral can cover the margin
 */
export function calculateCollateralDepositRequiredForTrade(
	driftClient: DriftClient,
	targetMarketIndex: number,
	baseSize: BN,
	collateralIndex: number,
	userMaxMarginRatio?: number
): BN {
	const marginRequiredUsdc = calculateMarginUSDCRequiredForTrade(
		driftClient,
		targetMarketIndex,
		baseSize,
		userMaxMarginRatio
	);

	const collateralMarket = driftClient.getSpotMarketAccount(collateralIndex);

	const collateralOracleData =
		driftClient.getOracleDataForSpotMarket(collateralIndex);

	if (collateralOracleData.price.lten(0)) {
		throw new Error(
			`calculateCollateralDepositRequiredForTrade: spot market ${collateralIndex} has non-positive oracle price ${collateralOracleData.price.toString()}`
		);
	}

	const scaledAssetWeight = calculateScaledInitialAssetWeight(
		collateralMarket,
		collateralOracleData.price
	);

	// Without this guard, a zero-weight collateral fails deep inside BN division
	// with an opaque assertion.
	if (scaledAssetWeight.lten(0)) {
		throw new Error(
			`calculateCollateralDepositRequiredForTrade: spot market ${collateralIndex} has non-positive scaled initial asset weight ${scaledAssetWeight.toString()} and cannot cover margin`
		);
	}

	// Amount to deposit = (marginRequiredUsdc / priceOfAsset) / assetWeight,
	// e.g. $100 required / $10,000 price / 0.5 weight = 0.02 units.
	// All multiplications happen before the divisions to keep integer precision.
	const baseAmountRequired = driftClient
		.convertToSpotPrecision(collateralIndex, marginRequiredUsdc)
		.mul(PRICE_PRECISION) // adjust for division by oracle price
		.mul(SPOT_MARKET_WEIGHT_PRECISION) // adjust for division by scaled asset weight
		.div(collateralOracleData.price)
		.div(scaledAssetWeight)
		.div(QUOTE_PRECISION); // adjust for marginRequiredUsdc value's QUOTE_PRECISION

	// TODO : Round by step size?

	return baseAmountRequired;
}
286
+
287
/**
 * Values a deposit of spot collateral for initial-margin purposes.
 *
 * The value is the oracle value of the deposit, in QUOTE_PRECISION, weighted by
 * the collateral's scaled initial asset weight.
 *
 * @param driftClient client used to read the spot market and its oracle data
 * @param collateralIndex spot market index of the collateral
 * @param baseSize deposit amount, in units of 10^decimals of the spot market
 * @returns weighted collateral value, in QUOTE_PRECISION
 */
export function calculateCollateralValueOfDeposit(
	driftClient: DriftClient,
	collateralIndex: number,
	baseSize: BN
): BN {
	const spotMarket = driftClient.getSpotMarketAccount(collateralIndex);
	const { price } = driftClient.getOracleDataForSpotMarket(collateralIndex);
	const assetWeight = calculateScaledInitialAssetWeight(spotMarket, price);

	const tokenPrecision = new BN(10).pow(new BN(spotMarket.decimals));

	// price * amount carries PRICE_PRECISION * 10^decimals; rescale to QUOTE_PRECISION
	const notionalValue = price
		.mul(baseSize)
		.mul(QUOTE_PRECISION)
		.div(PRICE_PRECISION)
		.div(tokenPrecision);

	return notionalValue.mul(assetWeight).div(SPOT_MARKET_WEIGHT_PRECISION);
}
315
+
316
/**
 * Estimates a liquidation price from free collateral and its sensitivity to price:
 *
 *   liqPrice = oraclePrice - freeCollateral * QUOTE_PRECISION / freeCollateralDelta
 *
 * A positive `freeCollateralDelta` yields a price below `oraclePrice`; a negative
 * one yields a price above it.
 *
 * @param freeCollateral current free collateral, in quote precision
 * @param freeCollateralDelta change in free collateral per unit price move, scaled by QUOTE_PRECISION
 * @param oraclePrice current oracle price
 * @returns the liquidation price in the oracle price's precision, or -1 when there
 *   is none (zero delta, or the computed price would be negative)
 */
export function calculateLiquidationPrice(
	freeCollateral: BN,
	freeCollateralDelta: BN,
	oraclePrice: BN
): BN {
	// With zero delta this linear model has no price at which free collateral
	// reaches zero. Return the "no liquidation price" sentinel instead of
	// throwing on BN division by zero.
	if (freeCollateralDelta.isZero()) {
		return new BN(-1);
	}

	const liqPriceDelta = freeCollateral
		.mul(QUOTE_PRECISION)
		.div(freeCollateralDelta);

	const liqPrice = oraclePrice.sub(liqPriceDelta);

	if (liqPrice.lt(ZERO)) {
		return new BN(-1);
	}

	return liqPrice;
}
@@ -10,9 +10,11 @@ import { BN } from '@coral-xyz/anchor';
10
10
/**
 * Builds params for a LIMIT order.
 *
 * Copies `params`, sets `orderType` to `OrderType.LIMIT`, and normalizes the
 * result through getOrderParams.
 */
export function getLimitOrderParams(
	params: Omit<OptionalOrderParams, 'orderType'> & { price: BN }
): OptionalOrderParams {
	const limitParams = Object.assign({}, params, {
		orderType: OrderType.LIMIT,
	});
	return getOrderParams(limitParams);
}
17
19
 
18
20
  export function getTriggerMarketOrderParams(
@@ -21,9 +23,11 @@ export function getTriggerMarketOrderParams(
21
23
  triggerPrice: BN;
22
24
  }
23
25
  ): OptionalOrderParams {
24
- return Object.assign({}, params, {
25
- orderType: OrderType.TRIGGER_MARKET,
26
- });
26
+ return getOrderParams(
27
+ Object.assign({}, params, {
28
+ orderType: OrderType.TRIGGER_MARKET,
29
+ })
30
+ );
27
31
  }
28
32
 
29
33
  export function getTriggerLimitOrderParams(
@@ -33,17 +37,21 @@ export function getTriggerLimitOrderParams(
33
37
  price: BN;
34
38
  }
35
39
  ): OptionalOrderParams {
36
- return Object.assign({}, params, {
37
- orderType: OrderType.TRIGGER_LIMIT,
38
- });
40
+ return getOrderParams(
41
+ Object.assign({}, params, {
42
+ orderType: OrderType.TRIGGER_LIMIT,
43
+ })
44
+ );
39
45
  }
40
46
 
41
47
/**
 * Builds params for a MARKET order.
 *
 * Copies `params`, sets `orderType` to `OrderType.MARKET`, and normalizes the
 * result through getOrderParams.
 */
export function getMarketOrderParams(
	params: Omit<OptionalOrderParams, 'orderType'>
): OptionalOrderParams {
	const marketParams = Object.assign({}, params, {
		orderType: OrderType.MARKET,
	});
	return getOrderParams(marketParams);
}
48
56
 
49
57
  /**
package/src/types.ts CHANGED
@@ -1,4 +1,9 @@
1
- import { PublicKey, Transaction, VersionedTransaction } from '@solana/web3.js';
1
+ import {
2
+ Keypair,
3
+ PublicKey,
4
+ Transaction,
5
+ VersionedTransaction,
6
+ } from '@solana/web3.js';
2
7
  import { BN, ZERO } from '.';
3
8
 
4
9
  // Utility type which lets you denote record with values of type A mapped to a record with the same keys but values of type B
@@ -1051,6 +1056,14 @@ export const DefaultOrderParams: OrderParams = {
1051
1056
  auctionEndPrice: null,
1052
1057
  };
1053
1058
 
1059
/**
 * Client-side counterpart of the `SwiftOrderParamsMessage` struct in the program
 * IDL. Fields are listed in the IDL's order.
 */
export type SwiftOrderParamsMessage = {
	/**
	 * Orders carried by the message (a vec of OrderParams in the IDL). Per the
	 * program's SwiftOrderSequenceError, the 0th order is expected to be the
	 * market order and the rest triggers.
	 */
	swiftOrderParams: OptionalOrderParams[];
	/** Market the orders target (u16 in the IDL). */
	marketIndex: number;
	marketType: MarketType;
	/** i32 in the IDL. */
	expectedOrderId: number;
	/** u64 in the IDL. */
	slot: BN;
};
1066
+
1054
1067
  export type MakerInfo = {
1055
1068
  maker: PublicKey;
1056
1069
  makerStats: PublicKey;
@@ -1097,6 +1110,7 @@ export interface IWallet {
1097
1110
  signTransaction(tx: Transaction): Promise<Transaction>;
1098
1111
  signAllTransactions(txs: Transaction[]): Promise<Transaction[]>;
1099
1112
  publicKey: PublicKey;
1113
+ payer: Keypair;
1100
1114
  }
1101
1115
  export interface IVersionedWallet {
1102
1116
  signVersionedTransaction(
@@ -1106,6 +1120,7 @@ export interface IVersionedWallet {
1106
1120
  txs: VersionedTransaction[]
1107
1121
  ): Promise<VersionedTransaction[]>;
1108
1122
  publicKey: PublicKey;
1123
+ payer: Keypair;
1109
1124
  }
1110
1125
 
1111
1126
  export type FeeStructure = {
package/src/user.ts CHANGED
@@ -78,6 +78,8 @@ import {
78
78
  import { calculateMarketOpenBidAsk } from './math/amm';
79
79
  import {
80
80
  calculateBaseAssetValueWithOracle,
81
+ calculateCollateralDepositRequiredForTrade,
82
+ calculateMarginUSDCRequiredForTrade,
81
83
  calculateWorstCaseBaseAssetAmount,
82
84
  } from './math/margin';
83
85
  import { OraclePriceData } from './oracles/types';
@@ -2274,6 +2276,7 @@ export class User {
2274
2276
  * @param estimatedEntryPrice
2275
2277
  * @param marginCategory // allow Initial to be passed in if we are trying to calculate price for DLP de-risking
2276
2278
  * @param includeOpenOrders
2279
+ * @param offsetCollateral // allows calculating the liquidation price after this offset collateral is added to the user's account (e.g. : what will the liquidation price be for this position AFTER I deposit $x worth of collateral)
2277
2280
  * @returns Precision : PRICE_PRECISION
2278
2281
  */
2279
2282
  public liquidationPrice(
@@ -2281,7 +2284,8 @@ export class User {
2281
2284
  positionBaseSizeChange: BN = ZERO,
2282
2285
  estimatedEntryPrice: BN = ZERO,
2283
2286
  marginCategory: MarginCategory = 'Maintenance',
2284
- includeOpenOrders = false
2287
+ includeOpenOrders = false,
2288
+ offsetCollateral = ZERO
2285
2289
  ): BN {
2286
2290
  const totalCollateral = this.getTotalCollateral(marginCategory);
2287
2291
  const marginRequirement = this.getMarginRequirement(
@@ -2290,7 +2294,10 @@ export class User {
2290
2294
  false,
2291
2295
  includeOpenOrders
2292
2296
  );
2293
- let freeCollateral = BN.max(ZERO, totalCollateral.sub(marginRequirement));
2297
+ let freeCollateral = BN.max(
2298
+ ZERO,
2299
+ totalCollateral.sub(marginRequirement)
2300
+ ).add(offsetCollateral);
2294
2301
 
2295
2302
  const oracle =
2296
2303
  this.driftClient.getPerpMarketAccount(marketIndex).amm.oracle;
@@ -2593,6 +2600,32 @@ export class User {
2593
2600
  );
2594
2601
  }
2595
2602
 
2603
/**
 * Initial margin, in quote precision, needed to open a trade of `baseSize` in perp
 * market `targetMarketIndex`. This user's account `maxMarginRatio` is used as the
 * custom margin ratio. Trade direction and existing positions are not considered.
 */
public getMarginUSDCRequiredForTrade(
	targetMarketIndex: number,
	baseSize: BN
): BN {
	const { maxMarginRatio } = this.getUserAccount();
	return calculateMarginUSDCRequiredForTrade(
		this.driftClient,
		targetMarketIndex,
		baseSize,
		maxMarginRatio
	);
}
2614
+
2615
/**
 * Amount of spot collateral `collateralIndex`, in that market's precision, that
 * must be deposited to cover the initial margin of a trade of `baseSize` in perp
 * market `targetMarketIndex`. This user's account `maxMarginRatio` is used as the
 * custom margin ratio.
 */
public getCollateralDepositRequiredForTrade(
	targetMarketIndex: number,
	baseSize: BN,
	collateralIndex: number
): BN {
	const { maxMarginRatio } = this.getUserAccount();
	return calculateCollateralDepositRequiredForTrade(
		this.driftClient,
		targetMarketIndex,
		baseSize,
		collateralIndex,
		maxMarginRatio
	);
}
2628
+
2596
2629
  /**
2597
2630
  * Get the maximum trade size for a given market, taking into account the user's current leverage, positions, collateral, etc.
2598
2631
  *