@sogni-ai/sogni-client 1.0.2 → 2.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +24 -0
- package/README.md +126 -0
- package/dist/Account/CurrentAccount.d.ts +6 -8
- package/dist/Account/CurrentAccount.js +9 -29
- package/dist/Account/CurrentAccount.js.map +1 -1
- package/dist/Account/index.d.ts +10 -6
- package/dist/Account/index.js +25 -12
- package/dist/Account/index.js.map +1 -1
- package/dist/Account/types.d.ts +2 -0
- package/dist/ApiClient/WebSocketClient/index.d.ts +4 -4
- package/dist/ApiClient/WebSocketClient/index.js +33 -49
- package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/messages.d.ts +7 -0
- package/dist/ApiClient/index.d.ts +4 -10
- package/dist/ApiClient/index.js +26 -17
- package/dist/ApiClient/index.js.map +1 -1
- package/dist/Projects/Project.d.ts +4 -0
- package/dist/Projects/Project.js +8 -0
- package/dist/Projects/Project.js.map +1 -1
- package/dist/Projects/createJobRequestMessage.js +12 -1
- package/dist/Projects/createJobRequestMessage.js.map +1 -1
- package/dist/Projects/index.d.ts +40 -2
- package/dist/Projects/index.js +146 -14
- package/dist/Projects/index.js.map +1 -1
- package/dist/Projects/types/index.d.ts +40 -0
- package/dist/lib/AuthManager.d.ts +51 -0
- package/dist/lib/AuthManager.js +157 -0
- package/dist/lib/AuthManager.js.map +1 -0
- package/dist/lib/Cache.d.ts +9 -0
- package/dist/lib/Cache.js +30 -0
- package/dist/lib/Cache.js.map +1 -0
- package/dist/lib/RestClient.d.ts +4 -7
- package/dist/lib/RestClient.js +7 -7
- package/dist/lib/RestClient.js.map +1 -1
- package/dist/lib/utils.d.ts +8 -0
- package/dist/lib/utils.js +20 -0
- package/dist/lib/utils.js.map +1 -0
- package/dist/version.d.ts +1 -1
- package/dist/version.js +1 -1
- package/dist/version.js.map +1 -1
- package/package.json +1 -1
- package/src/Account/CurrentAccount.ts +14 -36
- package/src/Account/index.ts +23 -11
- package/src/Account/types.ts +2 -0
- package/src/ApiClient/WebSocketClient/index.ts +16 -25
- package/src/ApiClient/WebSocketClient/messages.ts +8 -0
- package/src/ApiClient/index.ts +19 -27
- package/src/Projects/Project.ts +7 -0
- package/src/Projects/createJobRequestMessage.ts +14 -1
- package/src/Projects/index.ts +150 -8
- package/src/Projects/types/index.ts +42 -0
- package/src/lib/AuthManager.ts +172 -0
- package/src/lib/Cache.ts +36 -0
- package/src/lib/RestClient.ts +8 -13
- package/src/lib/utils.ts +17 -0
- package/src/version.ts +1 -1
- package/dist/Projects/models.json +0 -8906
- package/src/Projects/models.json +0 -8906
package/src/Account/index.ts
CHANGED
|
@@ -14,6 +14,7 @@ import { Wallet, pbkdf2, toUtf8Bytes, Signature, parseEther } from 'ethers';
|
|
|
14
14
|
import { ApiError, ApiReponse } from '../ApiClient';
|
|
15
15
|
import CurrentAccount from './CurrentAccount';
|
|
16
16
|
import { SupernetType } from '../ApiClient/WebSocketClient/types';
|
|
17
|
+
import { AuthUpdatedEvent, Tokens } from '../lib/AuthManager';
|
|
17
18
|
|
|
18
19
|
/**
|
|
19
20
|
* Account API methods that let you interact with the user's account.
|
|
@@ -34,6 +35,7 @@ class AccountApi extends ApiGroup {
|
|
|
34
35
|
this.client.socket.on('balanceUpdate', this.handleBalanceUpdate.bind(this));
|
|
35
36
|
this.client.on('connected', this.handleServerConnected.bind(this));
|
|
36
37
|
this.client.on('disconnected', this.handleServerDisconnected.bind(this));
|
|
38
|
+
this.client.auth.on('updated', this.handleAuthUpdated.bind(this));
|
|
37
39
|
}
|
|
38
40
|
|
|
39
41
|
private handleBalanceUpdate(data: BalanceData) {
|
|
@@ -51,6 +53,14 @@ class AccountApi extends ApiGroup {
|
|
|
51
53
|
this.currentAccount._clear();
|
|
52
54
|
}
|
|
53
55
|
|
|
56
|
+
private handleAuthUpdated({ refreshToken, token, walletAddress }: AuthUpdatedEvent) {
|
|
57
|
+
if (!refreshToken) {
|
|
58
|
+
this.currentAccount._clear();
|
|
59
|
+
} else {
|
|
60
|
+
this.currentAccount._update({ walletAddress, token, refreshToken });
|
|
61
|
+
}
|
|
62
|
+
}
|
|
63
|
+
|
|
54
64
|
private async getNonce(walletAddress: string): Promise<string> {
|
|
55
65
|
const res = await this.client.rest.post<ApiReponse<Nonce>>('/v1/account/nonce', {
|
|
56
66
|
walletAddress
|
|
@@ -111,40 +121,42 @@ class AccountApi extends ApiGroup {
|
|
|
111
121
|
referralCode,
|
|
112
122
|
signature
|
|
113
123
|
});
|
|
114
|
-
this.setToken(username, res.data.token);
|
|
124
|
+
await this.setToken(username, { refreshToken: res.data.refreshToken, token: res.data.token });
|
|
115
125
|
return res.data;
|
|
116
126
|
}
|
|
117
127
|
|
|
118
128
|
/**
|
|
119
|
-
* Restore session with username and
|
|
129
|
+
* Restore session with username and refresh token.
|
|
120
130
|
*
|
|
121
131
|
* You can save access token that you get from the login method and restore the session with this method.
|
|
122
132
|
*
|
|
123
133
|
* @example Store access token to local storage
|
|
124
134
|
* ```typescript
|
|
125
|
-
* const { username, token } = await client.account.login('username', 'password');
|
|
135
|
+
* const { username, token, refreshToken } = await client.account.login('username', 'password');
|
|
126
136
|
* localStorage.setItem('sogni-username', username);
|
|
127
137
|
* localStorage.setItem('sogni-token', token);
|
|
138
|
+
* localStorage.setItem('sogni-refresh-token', refreshToken);
|
|
128
139
|
* ```
|
|
129
140
|
*
|
|
130
141
|
* @example Restore session from local storage
|
|
131
142
|
* ```typescript
|
|
132
143
|
* const username = localStorage.getItem('sogni-username');
|
|
133
144
|
* const token = localStorage.getItem('sogni-token');
|
|
134
|
-
*
|
|
135
|
-
*
|
|
145
|
+
* const refreshToken = localStorage.getItem('sogni-refresh-token');
|
|
146
|
+
* if (username && refreshToken) {
|
|
147
|
+
* client.account.setToken(username, {token, refreshToken});
|
|
136
148
|
* console.log('Session restored');
|
|
137
149
|
* }
|
|
138
150
|
* ```
|
|
139
151
|
*
|
|
140
152
|
* @param username
|
|
141
|
-
* @param token
|
|
153
|
+
* @param tokens - Refresh token, access token pair { refreshToken: string, token: string }
|
|
142
154
|
*/
|
|
143
|
-
setToken(username: string,
|
|
144
|
-
this.client.authenticate(
|
|
155
|
+
async setToken(username: string, tokens: Tokens): Promise<void> {
|
|
156
|
+
await this.client.authenticate(tokens);
|
|
145
157
|
this.currentAccount._update({
|
|
146
|
-
|
|
147
|
-
|
|
158
|
+
username,
|
|
159
|
+
walletAddress: this.client.auth.walletAddress
|
|
148
160
|
});
|
|
149
161
|
}
|
|
150
162
|
|
|
@@ -171,7 +183,7 @@ class AccountApi extends ApiGroup {
|
|
|
171
183
|
walletAddress: wallet.address,
|
|
172
184
|
signature
|
|
173
185
|
});
|
|
174
|
-
this.setToken(username, res.data.token);
|
|
186
|
+
await this.setToken(username, { refreshToken: res.data.refreshToken, token: res.data.token });
|
|
175
187
|
return res.data;
|
|
176
188
|
}
|
|
177
189
|
|
package/src/ApiClient/WebSocketClient/index.ts
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
import { MessageType, SocketMessageMap } from './messages';
|
|
2
2
|
import { SocketEventMap } from './events';
|
|
3
|
-
import RestClient
|
|
3
|
+
import RestClient from '../../lib/RestClient';
|
|
4
4
|
import { SupernetType } from './types';
|
|
5
5
|
import WebSocket, { CloseEvent, ErrorEvent, MessageEvent } from 'isomorphic-ws';
|
|
6
6
|
import { base64Decode, base64Encode } from '../../lib/base64';
|
|
7
7
|
import isNodejs from '../../lib/isNodejs';
|
|
8
|
-
import Cookie from 'js-cookie';
|
|
9
8
|
import { LIB_VERSION } from '../../version';
|
|
10
9
|
import { Logger } from '../../lib/DefaultLogger';
|
|
10
|
+
import AuthManager from '../../lib/AuthManager';
|
|
11
|
+
|
|
12
|
+
const PROTOCOL_VERSION = '0.4.3';
|
|
11
13
|
|
|
12
14
|
const PING_INTERVAL = 15000;
|
|
13
15
|
|
|
@@ -18,34 +20,23 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
18
20
|
private _supernetType: SupernetType;
|
|
19
21
|
private _pingInterval: NodeJS.Timeout | null = null;
|
|
20
22
|
|
|
21
|
-
constructor(
|
|
23
|
+
constructor(
|
|
24
|
+
baseUrl: string,
|
|
25
|
+
auth: AuthManager,
|
|
26
|
+
appId: string,
|
|
27
|
+
supernetType: SupernetType,
|
|
28
|
+
logger: Logger
|
|
29
|
+
) {
|
|
22
30
|
const _baseUrl = new URL(baseUrl);
|
|
23
31
|
if (_baseUrl.protocol === 'wss:') {
|
|
24
32
|
_baseUrl.protocol = 'https:';
|
|
25
33
|
}
|
|
26
|
-
super(_baseUrl.toString(), logger);
|
|
34
|
+
super(_baseUrl.toString(), auth, logger);
|
|
27
35
|
this.appId = appId;
|
|
28
36
|
this.baseUrl = _baseUrl.toString();
|
|
29
37
|
this._supernetType = supernetType;
|
|
30
38
|
}
|
|
31
39
|
|
|
32
|
-
set auth(auth: AuthData | null) {
|
|
33
|
-
//In browser, set the cookie
|
|
34
|
-
if (!isNodejs) {
|
|
35
|
-
if (auth) {
|
|
36
|
-
Cookie.set('authorization', auth.token, {
|
|
37
|
-
domain: '.sogni.ai',
|
|
38
|
-
expires: 1
|
|
39
|
-
});
|
|
40
|
-
} else {
|
|
41
|
-
Cookie.remove('authorization', {
|
|
42
|
-
domain: '.sogni.ai'
|
|
43
|
-
});
|
|
44
|
-
}
|
|
45
|
-
}
|
|
46
|
-
this._auth = auth;
|
|
47
|
-
}
|
|
48
|
-
|
|
49
40
|
get supernetType(): SupernetType {
|
|
50
41
|
return this._supernetType;
|
|
51
42
|
}
|
|
@@ -54,11 +45,11 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
54
45
|
return !!this.socket;
|
|
55
46
|
}
|
|
56
47
|
|
|
57
|
-
connect() {
|
|
48
|
+
async connect() {
|
|
58
49
|
if (this.socket) {
|
|
59
50
|
this.disconnect();
|
|
60
51
|
}
|
|
61
|
-
const userAgent = `Sogni/${
|
|
52
|
+
const userAgent = `Sogni/${PROTOCOL_VERSION} (sogni-client) ${LIB_VERSION}`;
|
|
62
53
|
const url = new URL(this.baseUrl);
|
|
63
54
|
url.protocol = 'wss:';
|
|
64
55
|
url.searchParams.set('appId', this.appId);
|
|
@@ -71,7 +62,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
71
62
|
if (isNodejs) {
|
|
72
63
|
params = {
|
|
73
64
|
headers: {
|
|
74
|
-
Authorization: this.
|
|
65
|
+
Authorization: await this.auth.getToken(),
|
|
75
66
|
'User-Agent': userAgent
|
|
76
67
|
}
|
|
77
68
|
};
|
|
@@ -114,7 +105,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
114
105
|
}
|
|
115
106
|
|
|
116
107
|
switchNetwork(supernetType: SupernetType): Promise<SupernetType> {
|
|
117
|
-
return new Promise<SupernetType>(async (resolve
|
|
108
|
+
return new Promise<SupernetType>(async (resolve) => {
|
|
118
109
|
this.once('changeNetwork', ({ network }) => {
|
|
119
110
|
this._supernetType = network;
|
|
120
111
|
resolve(network);
|
|
package/src/ApiClient/WebSocketClient/messages.ts
CHANGED
|
@@ -1,8 +1,16 @@
|
|
|
1
1
|
import { JobRequestRaw } from '../../Projects/createJobRequestMessage';
|
|
2
2
|
import { SupernetType } from './types';
|
|
3
3
|
|
|
4
|
+
export interface JobErrorMessage {
|
|
5
|
+
jobID: string;
|
|
6
|
+
error: 'artistCanceled';
|
|
7
|
+
error_message: 'artistCanceled';
|
|
8
|
+
isFromWorker: false;
|
|
9
|
+
}
|
|
10
|
+
|
|
4
11
|
export interface SocketMessageMap {
|
|
5
12
|
jobRequest: JobRequestRaw;
|
|
13
|
+
jobError: JobErrorMessage;
|
|
6
14
|
changeNetwork: SupernetType;
|
|
7
15
|
}
|
|
8
16
|
|
package/src/ApiClient/index.ts
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import RestClient from '../lib/RestClient';
|
|
2
2
|
import WebSocketClient from './WebSocketClient';
|
|
3
|
-
import { jwtDecode } from 'jwt-decode';
|
|
4
3
|
import TypedEventEmitter from '../lib/TypedEventEmitter';
|
|
5
4
|
import { ApiClientEvents } from './events';
|
|
6
5
|
import { ServerConnectData, ServerDisconnectData } from './WebSocketClient/events';
|
|
@@ -8,6 +7,7 @@ import { isNotRecoverable } from './WebSocketClient/ErrorCode';
|
|
|
8
7
|
import { JSONValue } from '../types/json';
|
|
9
8
|
import { SupernetType } from './WebSocketClient/types';
|
|
10
9
|
import { Logger } from '../lib/DefaultLogger';
|
|
10
|
+
import AuthManager, { Tokens } from '../lib/AuthManager';
|
|
11
11
|
|
|
12
12
|
const WS_RECONNECT_ATTEMPTS = 5;
|
|
13
13
|
|
|
@@ -33,21 +33,12 @@ export class ApiError extends Error {
|
|
|
33
33
|
}
|
|
34
34
|
}
|
|
35
35
|
|
|
36
|
-
/**
|
|
37
|
-
* @inline
|
|
38
|
-
*/
|
|
39
|
-
interface AuthData {
|
|
40
|
-
token: string;
|
|
41
|
-
walletAddress: string;
|
|
42
|
-
expiresAt: Date;
|
|
43
|
-
}
|
|
44
|
-
|
|
45
36
|
class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
46
37
|
readonly appId: string;
|
|
47
38
|
readonly logger: Logger;
|
|
48
39
|
private _rest: RestClient;
|
|
49
40
|
private _socket: WebSocketClient;
|
|
50
|
-
private _auth:
|
|
41
|
+
private _auth: AuthManager;
|
|
51
42
|
private _reconnectAttempts = WS_RECONNECT_ATTEMPTS;
|
|
52
43
|
|
|
53
44
|
constructor(
|
|
@@ -60,18 +51,21 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
60
51
|
super();
|
|
61
52
|
this.appId = appId;
|
|
62
53
|
this.logger = logger;
|
|
63
|
-
this.
|
|
64
|
-
this.
|
|
54
|
+
this._auth = new AuthManager(baseUrl, logger);
|
|
55
|
+
this._rest = new RestClient(baseUrl, this._auth, logger);
|
|
56
|
+
this._socket = new WebSocketClient(socketUrl, this._auth, appId, networkType, logger);
|
|
57
|
+
|
|
58
|
+
this._auth.on('refreshFailed', this.handleRefreshFailed.bind(this));
|
|
65
59
|
this._socket.on('connected', this.handleSocketConnect.bind(this));
|
|
66
60
|
this._socket.on('disconnected', this.handleSocketDisconnect.bind(this));
|
|
67
61
|
}
|
|
68
62
|
|
|
69
63
|
get isAuthenticated(): boolean {
|
|
70
|
-
return
|
|
64
|
+
return this.auth.isAuthenticated;
|
|
71
65
|
}
|
|
72
66
|
|
|
73
|
-
get auth():
|
|
74
|
-
return this._auth
|
|
67
|
+
get auth(): AuthManager {
|
|
68
|
+
return this._auth;
|
|
75
69
|
}
|
|
76
70
|
|
|
77
71
|
get socket(): WebSocketClient {
|
|
@@ -82,20 +76,13 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
82
76
|
return this._rest;
|
|
83
77
|
}
|
|
84
78
|
|
|
85
|
-
authenticate(
|
|
86
|
-
|
|
87
|
-
this.
|
|
88
|
-
token,
|
|
89
|
-
walletAddress: decoded.addr,
|
|
90
|
-
expiresAt: new Date(decoded.exp * 1000)
|
|
91
|
-
};
|
|
92
|
-
this.rest.auth = { token };
|
|
93
|
-
this.socket.auth = { token };
|
|
94
|
-
this.socket.connect();
|
|
79
|
+
async authenticate(tokens: Tokens) {
|
|
80
|
+
await this.auth.setTokens(tokens);
|
|
81
|
+
await this.socket.connect();
|
|
95
82
|
}
|
|
96
83
|
|
|
97
84
|
removeAuth() {
|
|
98
|
-
this.
|
|
85
|
+
this.auth.clear();
|
|
99
86
|
this.socket.disconnect();
|
|
100
87
|
}
|
|
101
88
|
|
|
@@ -112,6 +99,7 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
112
99
|
return;
|
|
113
100
|
}
|
|
114
101
|
if (this._reconnectAttempts <= 0) {
|
|
102
|
+
this.removeAuth();
|
|
115
103
|
this.emit('disconnected', data);
|
|
116
104
|
this._reconnectAttempts = WS_RECONNECT_ATTEMPTS;
|
|
117
105
|
return;
|
|
@@ -119,6 +107,10 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
119
107
|
this._reconnectAttempts--;
|
|
120
108
|
setTimeout(() => this.socket.connect(), 1000);
|
|
121
109
|
}
|
|
110
|
+
|
|
111
|
+
handleRefreshFailed() {
|
|
112
|
+
this.removeAuth();
|
|
113
|
+
}
|
|
122
114
|
}
|
|
123
115
|
|
|
124
116
|
export default ApiClient;
|
package/src/Projects/Project.ts
CHANGED
|
@@ -164,6 +164,13 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
164
164
|
});
|
|
165
165
|
}
|
|
166
166
|
|
|
167
|
+
/**
|
|
168
|
+
* Cancel the project. This will cancel all jobs in the project.
|
|
169
|
+
*/
|
|
170
|
+
async cancel() {
|
|
171
|
+
await this._api.cancel(this.id);
|
|
172
|
+
}
|
|
173
|
+
|
|
167
174
|
/**
|
|
168
175
|
* Find a job by id
|
|
169
176
|
* @param id
|
|
package/src/Projects/createJobRequestMessage.ts
CHANGED
|
@@ -65,6 +65,14 @@ function getTemplate() {
|
|
|
65
65
|
};
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
+
function validateSize(value: any): number {
|
|
69
|
+
const size = Number(value);
|
|
70
|
+
if (isNaN(size) || size < 256 || size > 2048) {
|
|
71
|
+
throw new Error('Width and height must be numbers between 256 and 2048');
|
|
72
|
+
}
|
|
73
|
+
return size;
|
|
74
|
+
}
|
|
75
|
+
|
|
68
76
|
function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
69
77
|
const template = getTemplate();
|
|
70
78
|
const jobRequest: Record<string, any> = {
|
|
@@ -84,7 +92,8 @@ function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
|
84
92
|
strengthIsEnabled: !!params.startingImage,
|
|
85
93
|
strength: !!params.startingImage
|
|
86
94
|
? 1 - (Number(params.startingImageStrength) || 0.5)
|
|
87
|
-
: undefined
|
|
95
|
+
: undefined,
|
|
96
|
+
sizePreset: params.sizePreset
|
|
88
97
|
}
|
|
89
98
|
],
|
|
90
99
|
previews: params.numberOfPreviews || 0,
|
|
@@ -95,6 +104,10 @@ function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
|
95
104
|
if (params.network) {
|
|
96
105
|
jobRequest.network = params.network;
|
|
97
106
|
}
|
|
107
|
+
if (params.sizePreset === 'custom') {
|
|
108
|
+
jobRequest.keyFrames[0].width = validateSize(params.width);
|
|
109
|
+
jobRequest.keyFrames[0].height = validateSize(params.height);
|
|
110
|
+
}
|
|
98
111
|
return jobRequest;
|
|
99
112
|
}
|
|
100
113
|
|
package/src/Projects/index.ts
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
import ApiGroup, { ApiConfig } from '../ApiGroup';
|
|
2
|
-
import
|
|
3
|
-
|
|
2
|
+
import {
|
|
3
|
+
AvailableModel,
|
|
4
|
+
EstimateRequest,
|
|
5
|
+
ImageUrlParams,
|
|
6
|
+
ProjectParams,
|
|
7
|
+
SizePreset,
|
|
8
|
+
SupportedModel
|
|
9
|
+
} from './types';
|
|
4
10
|
import {
|
|
5
11
|
JobErrorData,
|
|
6
12
|
JobProgressData,
|
|
@@ -16,8 +22,12 @@ import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
|
|
|
16
22
|
import getUUID from '../lib/getUUID';
|
|
17
23
|
import { RawProject } from './types/RawProject';
|
|
18
24
|
import ErrorData from '../types/ErrorData';
|
|
25
|
+
import { SupernetType } from '../ApiClient/WebSocketClient/types';
|
|
26
|
+
import Cache from '../lib/Cache';
|
|
19
27
|
|
|
28
|
+
const sizePresetCache = new Cache<SizePreset[]>(10 * 60 * 1000);
|
|
20
29
|
const GARBAGE_COLLECT_TIMEOUT = 10000;
|
|
30
|
+
const MODELS_REFRESH_INTERVAL = 1000 * 60 * 60 * 24; // 24 hours
|
|
21
31
|
|
|
22
32
|
function mapErrorCodes(code: string): number {
|
|
23
33
|
switch (code) {
|
|
@@ -39,6 +49,10 @@ function mapErrorCodes(code: string): number {
|
|
|
39
49
|
class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
40
50
|
private _availableModels: AvailableModel[] = [];
|
|
41
51
|
private projects: Project[] = [];
|
|
52
|
+
private _supportedModels: { data: SupportedModel[] | null; updatedAt: Date } = {
|
|
53
|
+
data: null,
|
|
54
|
+
updatedAt: new Date(0)
|
|
55
|
+
};
|
|
42
56
|
|
|
43
57
|
get availableModels() {
|
|
44
58
|
return this._availableModels;
|
|
@@ -65,14 +79,20 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
65
79
|
this.emit('availableModels', this._availableModels);
|
|
66
80
|
}
|
|
67
81
|
|
|
68
|
-
private handleSwarmModels(data: SocketEventMap['swarmModels']) {
|
|
69
|
-
|
|
70
|
-
|
|
82
|
+
private async handleSwarmModels(data: SocketEventMap['swarmModels']) {
|
|
83
|
+
let models: SupportedModel[] = [];
|
|
84
|
+
try {
|
|
85
|
+
models = await this.getSupportedModels();
|
|
86
|
+
} catch (e) {
|
|
87
|
+
this.client.logger.error(e);
|
|
88
|
+
}
|
|
89
|
+
const modelIndex = models.reduce((acc: Record<string, SupportedModel>, model) => {
|
|
90
|
+
acc[model.id] = model;
|
|
71
91
|
return acc;
|
|
72
92
|
}, {});
|
|
73
93
|
this._availableModels = Object.entries(data).map(([id, workerCount]) => ({
|
|
74
94
|
id,
|
|
75
|
-
name: modelIndex[id]?.
|
|
95
|
+
name: modelIndex[id]?.name || id.replace(/-/g, ' '),
|
|
76
96
|
workerCount
|
|
77
97
|
}));
|
|
78
98
|
this.emit('availableModels', this._availableModels);
|
|
@@ -339,6 +359,38 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
339
359
|
return data.project;
|
|
340
360
|
}
|
|
341
361
|
|
|
362
|
+
/**
|
|
363
|
+
* Cancel project by id. This will cancel all jobs in the project and mark project as canceled.
|
|
364
|
+
* Client may still receive job events for the canceled jobs as it takes some time, but they will
|
|
365
|
+
* be ignored
|
|
366
|
+
* @param projectId
|
|
367
|
+
**/
|
|
368
|
+
async cancel(projectId: string) {
|
|
369
|
+
await this.client.socket.send('jobError', {
|
|
370
|
+
jobID: projectId,
|
|
371
|
+
error: 'artistCanceled',
|
|
372
|
+
error_message: 'artistCanceled',
|
|
373
|
+
isFromWorker: false
|
|
374
|
+
});
|
|
375
|
+
const project = this.projects.find((p) => p.id === projectId);
|
|
376
|
+
if (!project) {
|
|
377
|
+
return;
|
|
378
|
+
}
|
|
379
|
+
// Remove project from the list to stop tracking it
|
|
380
|
+
this.projects = this.projects.filter((p) => p.id !== projectId);
|
|
381
|
+
|
|
382
|
+
// Cancel all jobs in the project
|
|
383
|
+
project.jobs.forEach((job) => {
|
|
384
|
+
if (!job.finished) {
|
|
385
|
+
job._update({ status: 'canceled' });
|
|
386
|
+
}
|
|
387
|
+
});
|
|
388
|
+
// If project is still in processing, mark it as canceled
|
|
389
|
+
if (!project.finished) {
|
|
390
|
+
project._update({ status: 'canceled' });
|
|
391
|
+
}
|
|
392
|
+
}
|
|
393
|
+
|
|
342
394
|
private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
|
|
343
395
|
const imageId = getUUID();
|
|
344
396
|
const presignedUrl = await this.uploadUrl({
|
|
@@ -370,10 +422,32 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
370
422
|
stepCount,
|
|
371
423
|
previewCount,
|
|
372
424
|
cnEnabled,
|
|
373
|
-
startingImageStrength
|
|
425
|
+
startingImageStrength,
|
|
426
|
+
width,
|
|
427
|
+
height,
|
|
428
|
+
sizePreset
|
|
374
429
|
}: EstimateRequest) {
|
|
430
|
+
const pathParams = [
|
|
431
|
+
network,
|
|
432
|
+
model,
|
|
433
|
+
imageCount,
|
|
434
|
+
stepCount,
|
|
435
|
+
previewCount,
|
|
436
|
+
cnEnabled ? 1 : 0,
|
|
437
|
+
startingImageStrength ? 1 - startingImageStrength : 0
|
|
438
|
+
];
|
|
439
|
+
if (sizePreset) {
|
|
440
|
+
const presets = await this.getSizePresets(network, model);
|
|
441
|
+
const preset = presets.find((p) => p.id === sizePreset);
|
|
442
|
+
if (!preset) {
|
|
443
|
+
throw new Error('Invalid size preset');
|
|
444
|
+
}
|
|
445
|
+
pathParams.push(preset.width, preset.height);
|
|
446
|
+
} else if (width && height) {
|
|
447
|
+
pathParams.push(width, height);
|
|
448
|
+
}
|
|
375
449
|
const r = await this.client.socket.get<EstimationResponse>(
|
|
376
|
-
`/api/v1/job/estimate/${
|
|
450
|
+
`/api/v1/job/estimate/${pathParams.join('/')}`
|
|
377
451
|
);
|
|
378
452
|
return {
|
|
379
453
|
token: r.quote.project.costInToken,
|
|
@@ -406,6 +480,74 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
406
480
|
);
|
|
407
481
|
return r.data.downloadUrl;
|
|
408
482
|
}
|
|
483
|
+
|
|
484
|
+
async getSupportedModels(forceRefresh = false) {
|
|
485
|
+
if (
|
|
486
|
+
this._supportedModels.data &&
|
|
487
|
+
!forceRefresh &&
|
|
488
|
+
Date.now() - this._supportedModels.updatedAt.getTime() < MODELS_REFRESH_INTERVAL
|
|
489
|
+
) {
|
|
490
|
+
return this._supportedModels.data;
|
|
491
|
+
}
|
|
492
|
+
const models = await this.client.socket.get<SupportedModel[]>(`/api/v1/models/list`);
|
|
493
|
+
this._supportedModels = { data: models, updatedAt: new Date() };
|
|
494
|
+
return models;
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
/**
|
|
498
|
+
* Get supported size presets for the model and network. Size presets are cached for 10 minutes.
|
|
499
|
+
*
|
|
500
|
+
* @example
|
|
501
|
+
* ```ts
|
|
502
|
+
* const presets = await client.projects.getSizePresets('fast', 'flux1-schnell-fp8');
|
|
503
|
+
* console.log(presets);
|
|
504
|
+
* ```
|
|
505
|
+
*
|
|
506
|
+
* @param network - 'fast' or 'relaxed'
|
|
507
|
+
* @param modelId - model id (e.g. 'flux1-schnell-fp8')
|
|
508
|
+
* @param forceRefresh - force refresh cache
|
|
509
|
+
* @returns {Promise<{
|
|
510
|
+
* label: string;
|
|
511
|
+
* id: string;
|
|
512
|
+
* width: number;
|
|
513
|
+
* height: number;
|
|
514
|
+
* ratio: string;
|
|
515
|
+
* aspect: string;
|
|
516
|
+
* }[]>}
|
|
517
|
+
*/
|
|
518
|
+
async getSizePresets(network: SupernetType, modelId: string, forceRefresh = false) {
|
|
519
|
+
const key = `${network}-${modelId}`;
|
|
520
|
+
const cached = sizePresetCache.read(key);
|
|
521
|
+
if (cached && !forceRefresh) {
|
|
522
|
+
return cached;
|
|
523
|
+
}
|
|
524
|
+
const data = await this.client.socket.get<SizePreset[]>(
|
|
525
|
+
`/api/v1/size-presets/network/${network}/model/${modelId}`
|
|
526
|
+
);
|
|
527
|
+
sizePresetCache.write(key, data);
|
|
528
|
+
return data;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
/**
|
|
532
|
+
* Get available models and their worker counts. Normally, you would get list once you connect
|
|
533
|
+
* to the server, but you can also call this method to get the list of available models manually.
|
|
534
|
+
* @param network
|
|
535
|
+
*/
|
|
536
|
+
async getAvailableModels(network: SupernetType): Promise<AvailableModel[]> {
|
|
537
|
+
const workersByModelSid = await this.client.socket.get<Record<string, number>>(
|
|
538
|
+
`/api/v1/status/network/${network}/models`
|
|
539
|
+
);
|
|
540
|
+
const supportedModels = await this.getSupportedModels();
|
|
541
|
+
return Object.entries(workersByModelSid).map(([sid, workerCount]) => {
|
|
542
|
+
const SID = Number(sid);
|
|
543
|
+
const model = supportedModels.find((m) => m.SID === SID);
|
|
544
|
+
return {
|
|
545
|
+
id: model?.id || sid,
|
|
546
|
+
name: model?.name || sid.replace(/-/g, ' '),
|
|
547
|
+
workerCount
|
|
548
|
+
};
|
|
549
|
+
});
|
|
550
|
+
}
|
|
409
551
|
}
|
|
410
552
|
|
|
411
553
|
export default ProjectsApi;
|
|
package/src/Projects/types/index.ts
CHANGED
|
@@ -1,11 +1,26 @@
|
|
|
1
1
|
import { SupernetType } from '../../ApiClient/WebSocketClient/types';
|
|
2
2
|
|
|
3
|
+
export interface SupportedModel {
|
|
4
|
+
id: string;
|
|
5
|
+
name: string;
|
|
6
|
+
SID: number;
|
|
7
|
+
}
|
|
8
|
+
|
|
3
9
|
export interface AvailableModel {
|
|
4
10
|
id: string;
|
|
5
11
|
name: string;
|
|
6
12
|
workerCount: number;
|
|
7
13
|
}
|
|
8
14
|
|
|
15
|
+
export interface SizePreset {
|
|
16
|
+
label: string;
|
|
17
|
+
id: string;
|
|
18
|
+
width: number;
|
|
19
|
+
height: number;
|
|
20
|
+
ratio: string;
|
|
21
|
+
aspect: string;
|
|
22
|
+
}
|
|
23
|
+
|
|
9
24
|
export interface AiModel {
|
|
10
25
|
isSD3: boolean;
|
|
11
26
|
modelShortName: string;
|
|
@@ -108,6 +123,19 @@ export interface ProjectParams {
|
|
|
108
123
|
* Time step spacing method
|
|
109
124
|
*/
|
|
110
125
|
timeStepSpacing?: TimeStepSpacing;
|
|
126
|
+
/**
|
|
127
|
+
* Size preset ID to use. You can query available size presets
|
|
128
|
+
* from `client.projects.sizePresets(network, modelId)`
|
|
129
|
+
*/
|
|
130
|
+
sizePreset?: 'custom' | string;
|
|
131
|
+
/**
|
|
132
|
+
* Output image width. Only used if `sizePreset` is "custom"
|
|
133
|
+
*/
|
|
134
|
+
width?: number;
|
|
135
|
+
/**
|
|
136
|
+
* Output image height. Only used if `sizePreset` is "custom"
|
|
137
|
+
*/
|
|
138
|
+
height?: number;
|
|
111
139
|
}
|
|
112
140
|
|
|
113
141
|
export type ImageUrlParams = {
|
|
@@ -147,4 +175,18 @@ export interface EstimateRequest {
|
|
|
147
175
|
* How strong effect of starting image should be. From 0 to 1, default 0.5
|
|
148
176
|
*/
|
|
149
177
|
startingImageStrength?: number;
|
|
178
|
+
/**
|
|
179
|
+
* Size preset ID
|
|
180
|
+
*/
|
|
181
|
+
sizePreset?: string;
|
|
182
|
+
/**
|
|
183
|
+
* Size preset image width, if not using size preset
|
|
184
|
+
* @internal
|
|
185
|
+
*/
|
|
186
|
+
width?: number;
|
|
187
|
+
/**
|
|
188
|
+
* Size preset image height, if not using size preset
|
|
189
|
+
* @internal
|
|
190
|
+
*/
|
|
191
|
+
height?: number;
|
|
150
192
|
}
|