@sogni-ai/sogni-client 4.0.0-alpha.3 → 4.0.0-alpha.30
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +213 -0
- package/README.md +279 -28
- package/dist/Account/index.d.ts +18 -16
- package/dist/Account/index.js +31 -20
- package/dist/Account/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.d.ts +66 -0
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js +332 -0
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js.map +1 -0
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.d.ts +28 -0
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js +203 -0
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js.map +1 -0
- package/dist/ApiClient/WebSocketClient/events.d.ts +11 -0
- package/dist/ApiClient/WebSocketClient/index.d.ts +2 -2
- package/dist/ApiClient/WebSocketClient/index.js +13 -3
- package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/types.d.ts +13 -0
- package/dist/ApiClient/index.d.ts +4 -4
- package/dist/ApiClient/index.js +23 -4
- package/dist/ApiClient/index.js.map +1 -1
- package/dist/Projects/Job.d.ts +24 -4
- package/dist/Projects/Job.js +58 -16
- package/dist/Projects/Job.js.map +1 -1
- package/dist/Projects/Project.d.ts +8 -0
- package/dist/Projects/Project.js +27 -6
- package/dist/Projects/Project.js.map +1 -1
- package/dist/Projects/createJobRequestMessage.js +109 -15
- package/dist/Projects/createJobRequestMessage.js.map +1 -1
- package/dist/Projects/index.d.ts +110 -11
- package/dist/Projects/index.js +412 -42
- package/dist/Projects/index.js.map +1 -1
- package/dist/Projects/types/EstimationResponse.d.ts +2 -0
- package/dist/Projects/types/SamplerParams.d.ts +13 -0
- package/dist/Projects/types/SamplerParams.js +26 -0
- package/dist/Projects/types/SamplerParams.js.map +1 -0
- package/dist/Projects/types/SchedulerParams.d.ts +14 -0
- package/dist/Projects/types/SchedulerParams.js +24 -0
- package/dist/Projects/types/SchedulerParams.js.map +1 -0
- package/dist/Projects/types/events.d.ts +5 -1
- package/dist/Projects/types/index.d.ts +150 -39
- package/dist/Projects/types/index.js +13 -0
- package/dist/Projects/types/index.js.map +1 -1
- package/dist/Projects/utils.d.ts +19 -1
- package/dist/Projects/utils.js +68 -0
- package/dist/Projects/utils.js.map +1 -1
- package/dist/index.d.ts +12 -4
- package/dist/index.js +12 -4
- package/dist/index.js.map +1 -1
- package/dist/lib/AuthManager/TokenAuthManager.js +0 -2
- package/dist/lib/AuthManager/TokenAuthManager.js.map +1 -1
- package/dist/lib/DataEntity.js +4 -2
- package/dist/lib/DataEntity.js.map +1 -1
- package/dist/lib/validation.d.ts +7 -0
- package/dist/lib/validation.js +36 -0
- package/dist/lib/validation.js.map +1 -1
- package/package.json +4 -4
- package/src/Account/index.ts +30 -19
- package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.ts +426 -0
- package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/index.ts +237 -0
- package/src/ApiClient/WebSocketClient/events.ts +13 -0
- package/src/ApiClient/WebSocketClient/index.ts +15 -5
- package/src/ApiClient/WebSocketClient/types.ts +16 -0
- package/src/ApiClient/index.ts +30 -8
- package/src/Projects/Job.ts +64 -16
- package/src/Projects/Project.ts +29 -9
- package/src/Projects/createJobRequestMessage.ts +155 -36
- package/src/Projects/index.ts +437 -46
- package/src/Projects/types/EstimationResponse.ts +2 -0
- package/src/Projects/types/SamplerParams.ts +24 -0
- package/src/Projects/types/SchedulerParams.ts +22 -0
- package/src/Projects/types/events.ts +6 -0
- package/src/Projects/types/index.ts +181 -47
- package/src/Projects/utils.ts +66 -1
- package/src/index.ts +38 -11
- package/src/lib/AuthManager/TokenAuthManager.ts +0 -2
- package/src/lib/DataEntity.ts +4 -2
- package/src/lib/validation.ts +41 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { MessageType, SocketMessageMap } from './messages';
|
|
2
2
|
import { SocketEventMap } from './events';
|
|
3
3
|
import RestClient from '../../lib/RestClient';
|
|
4
|
-
import { SupernetType } from './types';
|
|
4
|
+
import { IWebSocketClient, SupernetType } from './types';
|
|
5
5
|
import WebSocket, { CloseEvent, ErrorEvent, MessageEvent } from 'isomorphic-ws';
|
|
6
6
|
import { base64Decode, base64Encode } from '../../lib/base64';
|
|
7
7
|
import isNodejs from '../../lib/isNodejs';
|
|
@@ -13,7 +13,7 @@ const PROTOCOL_VERSION = '3.0.0';
|
|
|
13
13
|
|
|
14
14
|
const PING_INTERVAL = 15000;
|
|
15
15
|
|
|
16
|
-
class WebSocketClient extends RestClient<SocketEventMap> {
|
|
16
|
+
class WebSocketClient extends RestClient<SocketEventMap> implements IWebSocketClient {
|
|
17
17
|
appId: string;
|
|
18
18
|
baseUrl: string;
|
|
19
19
|
private socket: WebSocket | null = null;
|
|
@@ -86,7 +86,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
86
86
|
socket.onmessage = null;
|
|
87
87
|
socket.onopen = null;
|
|
88
88
|
this.stopPing();
|
|
89
|
-
socket.close();
|
|
89
|
+
socket.close(1000, 'Client disconnected');
|
|
90
90
|
}
|
|
91
91
|
|
|
92
92
|
private startPing(socket: WebSocket) {
|
|
@@ -148,9 +148,16 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
148
148
|
}
|
|
149
149
|
|
|
150
150
|
private handleClose(e: CloseEvent) {
|
|
151
|
-
|
|
151
|
+
const socket = e.target;
|
|
152
|
+
socket.onerror = null;
|
|
153
|
+
socket.onmessage = null;
|
|
154
|
+
socket.onopen = null;
|
|
155
|
+
if (socket === this.socket || !this.socket) {
|
|
152
156
|
this._logger.info('WebSocket disconnected, cleanup', e);
|
|
153
|
-
this.
|
|
157
|
+
if (socket === this.socket) {
|
|
158
|
+
this.stopPing();
|
|
159
|
+
this.socket = null;
|
|
160
|
+
}
|
|
154
161
|
this.emit('disconnected', {
|
|
155
162
|
code: e.code,
|
|
156
163
|
reason: e.reason
|
|
@@ -193,6 +200,9 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
193
200
|
}
|
|
194
201
|
|
|
195
202
|
async send<T extends MessageType>(messageType: T, data: SocketMessageMap[T]) {
|
|
203
|
+
if (!this.isConnected) {
|
|
204
|
+
await this.connect();
|
|
205
|
+
}
|
|
196
206
|
await this.waitForConnection();
|
|
197
207
|
this._logger.debug('WebSocket send:', messageType, data);
|
|
198
208
|
this.socket!.send(
|
|
@@ -1 +1,17 @@
|
|
|
1
|
+
import { MessageType, SocketMessageMap } from './messages';
|
|
2
|
+
import RestClient from '../../lib/RestClient';
|
|
3
|
+
import { SocketEventMap } from './events';
|
|
4
|
+
|
|
1
5
|
export type SupernetType = 'relaxed' | 'fast';
|
|
6
|
+
|
|
7
|
+
/**
 * Common surface of the socket clients that can be used interchangeably.
 * It builds on `RestClient<SocketEventMap>`, so implementers also expose the
 * REST client members and emit the events described by `SocketEventMap`.
 */
export interface IWebSocketClient extends RestClient<SocketEventMap> {
  /** Application identifier associated with this client. */
  appId: string;
  /** Base URL this client was configured with. */
  baseUrl: string;
  /** `true` while the client considers itself connected. */
  isConnected: boolean;
  /** Currently selected network ('relaxed' or 'fast'). */
  supernetType: SupernetType;

  /** Establishes the connection; the promise settles once the attempt finishes. */
  connect(): Promise<void>;
  /** Drops the connection. */
  disconnect(): void;
  /**
   * Sends a typed message.
   * @param messageType - Key of `SocketMessageMap` identifying the message.
   * @param data - Payload whose shape is dictated by `messageType`.
   */
  send<T extends MessageType>(messageType: T, data: SocketMessageMap[T]): Promise<void>;
  /**
   * Requests a switch to another network.
   * @param supernetType - Network to switch to.
   * @returns A promise resolving to a `SupernetType` value.
   */
  switchNetwork(supernetType: SupernetType): Promise<SupernetType>;
}
|
package/src/ApiClient/index.ts
CHANGED
|
@@ -3,12 +3,14 @@ import WebSocketClient from './WebSocketClient';
|
|
|
3
3
|
import TypedEventEmitter from '../lib/TypedEventEmitter';
|
|
4
4
|
import { ApiClientEvents } from './events';
|
|
5
5
|
import { ServerConnectData, ServerDisconnectData } from './WebSocketClient/events';
|
|
6
|
-
import { isNotRecoverable } from './WebSocketClient/ErrorCode';
|
|
6
|
+
import { ErrorCode, isNotRecoverable } from './WebSocketClient/ErrorCode';
|
|
7
7
|
import { JSONValue } from '../types/json';
|
|
8
|
-
import { SupernetType } from './WebSocketClient/types';
|
|
8
|
+
import { IWebSocketClient, SupernetType } from './WebSocketClient/types';
|
|
9
9
|
import { Logger } from '../lib/DefaultLogger';
|
|
10
10
|
import CookieAuthManager from '../lib/AuthManager/CookieAuthManager';
|
|
11
11
|
import { AuthManager, TokenAuthManager } from '../lib/AuthManager';
|
|
12
|
+
import isNodejs from '../lib/isNodejs';
|
|
13
|
+
import BrowserWebSocketClient from './WebSocketClient/BrowserWebSocketClient';
|
|
12
14
|
|
|
13
15
|
const WS_RECONNECT_ATTEMPTS = 5;
|
|
14
16
|
|
|
@@ -42,13 +44,14 @@ export interface ApiClientOptions {
|
|
|
42
44
|
logger: Logger;
|
|
43
45
|
authType: 'token' | 'cookies';
|
|
44
46
|
disableSocket?: boolean;
|
|
47
|
+
multiInstance?: boolean;
|
|
45
48
|
}
|
|
46
49
|
|
|
47
50
|
class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
48
51
|
readonly appId: string;
|
|
49
52
|
readonly logger: Logger;
|
|
50
53
|
private _rest: RestClient;
|
|
51
|
-
private _socket:
|
|
54
|
+
private _socket: IWebSocketClient;
|
|
52
55
|
private _auth: AuthManager;
|
|
53
56
|
private _reconnectAttempts = WS_RECONNECT_ATTEMPTS;
|
|
54
57
|
private _disableSocket: boolean = false;
|
|
@@ -60,7 +63,8 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
60
63
|
networkType,
|
|
61
64
|
authType,
|
|
62
65
|
logger,
|
|
63
|
-
disableSocket = false
|
|
66
|
+
disableSocket = false,
|
|
67
|
+
multiInstance = false
|
|
64
68
|
}: ApiClientOptions) {
|
|
65
69
|
super();
|
|
66
70
|
this.appId = appId;
|
|
@@ -68,7 +72,13 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
68
72
|
this._auth =
|
|
69
73
|
authType === 'token' ? new TokenAuthManager(baseUrl, logger) : new CookieAuthManager(logger);
|
|
70
74
|
this._rest = new RestClient(baseUrl, this._auth, logger);
|
|
71
|
-
|
|
75
|
+
const supportMultiInstance = !isNodejs && this._auth instanceof CookieAuthManager;
|
|
76
|
+
if (supportMultiInstance && multiInstance) {
|
|
77
|
+
// Use coordinated WebSocket client to share single connection between tabs
|
|
78
|
+
this._socket = new BrowserWebSocketClient(socketUrl, this._auth, appId, networkType, logger);
|
|
79
|
+
} else {
|
|
80
|
+
this._socket = new WebSocketClient(socketUrl, this._auth, appId, networkType, logger);
|
|
81
|
+
}
|
|
72
82
|
this._disableSocket = disableSocket;
|
|
73
83
|
this._auth.on('updated', this.handleAuthUpdated.bind(this));
|
|
74
84
|
this._socket.on('connected', this.handleSocketConnect.bind(this));
|
|
@@ -83,7 +93,7 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
83
93
|
return this._auth;
|
|
84
94
|
}
|
|
85
95
|
|
|
86
|
-
get socket():
|
|
96
|
+
/**
 * Socket client held by this API client, exposed through the
 * `IWebSocketClient` contract rather than a concrete class.
 */
get socket(): IWebSocketClient {
  const socketClient = this._socket;
  return socketClient;
}
|
|
89
99
|
|
|
@@ -101,14 +111,26 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
101
111
|
}
|
|
102
112
|
|
|
103
113
|
handleSocketDisconnect(data: ServerDisconnectData) {
|
|
114
|
+
// If user is not authenticated, we don't need to reconnect
|
|
115
|
+
if (!this.auth.isAuthenticated || data.code === 1000) {
|
|
116
|
+
this.emit('disconnected', data);
|
|
117
|
+
return;
|
|
118
|
+
}
|
|
104
119
|
if (!data.code || isNotRecoverable(data.code)) {
|
|
120
|
+
// If this is browser, another tab is probably claiming the connection, so we don't need to reconnect
|
|
121
|
+
if (
|
|
122
|
+
this._socket instanceof BrowserWebSocketClient &&
|
|
123
|
+
data.code === ErrorCode.SWITCH_CONNECTION
|
|
124
|
+
) {
|
|
125
|
+
this.logger.debug('Switching network connection, not reconnecting');
|
|
126
|
+
return;
|
|
127
|
+
}
|
|
105
128
|
this.auth.clear();
|
|
106
129
|
this.emit('disconnected', data);
|
|
107
130
|
this.logger.error('Not recoverable socket error', data);
|
|
108
131
|
return;
|
|
109
132
|
}
|
|
110
133
|
if (this._reconnectAttempts <= 0) {
|
|
111
|
-
this.auth.clear();
|
|
112
134
|
this.emit('disconnected', data);
|
|
113
135
|
this._reconnectAttempts = WS_RECONNECT_ATTEMPTS;
|
|
114
136
|
return;
|
|
@@ -122,7 +144,7 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
122
144
|
if (this.socket.isConnected) {
|
|
123
145
|
this.socket.disconnect();
|
|
124
146
|
}
|
|
125
|
-
} else if (!this._disableSocket) {
|
|
147
|
+
} else if (!this._disableSocket && !this.socket.isConnected) {
|
|
126
148
|
this.socket.connect();
|
|
127
149
|
}
|
|
128
150
|
}
|
package/src/Projects/Job.ts
CHANGED
|
@@ -19,7 +19,7 @@ export const enhancementDefaults = {
|
|
|
19
19
|
startingImageStrength: 0.5,
|
|
20
20
|
steps: 5,
|
|
21
21
|
guidance: 1,
|
|
22
|
-
|
|
22
|
+
numberOfMedia: 1,
|
|
23
23
|
numberOfPreviews: 0
|
|
24
24
|
};
|
|
25
25
|
|
|
@@ -61,6 +61,11 @@ export interface JobData {
|
|
|
61
61
|
positivePrompt?: string;
|
|
62
62
|
negativePrompt?: string;
|
|
63
63
|
jobIndex?: number;
|
|
64
|
+
/**
|
|
65
|
+
* Estimated time remaining in seconds (for long-running jobs like video generation).
|
|
66
|
+
* Updated by ComfyUI workers during inference.
|
|
67
|
+
*/
|
|
68
|
+
etaSeconds?: number;
|
|
64
69
|
}
|
|
65
70
|
|
|
66
71
|
export interface JobEventMap extends EntityEvents {
|
|
@@ -178,10 +183,21 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
178
183
|
return this.data.error;
|
|
179
184
|
}
|
|
180
185
|
|
|
181
|
-
|
|
186
|
+
/**
 * Tells whether a result media file can be downloaded for this job.
 * That is the case only when the job status is 'completed' and the job
 * was not flagged by the NSFW filter.
 */
get hasResultMedia() {
  // An unfinished job never has a downloadable result.
  if (this.status !== 'completed') {
    return false;
  }
  return !this.isNSFW;
}
|
|
184
193
|
|
|
194
|
+
/**
 * Kind of media this job produces. The decision is based on the model ID of
 * the parent project: 'video' when the API reports it as a video model,
 * 'image' otherwise.
 */
get type(): 'image' | 'video' {
  const { modelId } = this._project.params;
  if (this._api.isVideoModelId(modelId)) {
    return 'video';
  }
  return 'image';
}
|
|
200
|
+
|
|
185
201
|
get enhancedImage() {
|
|
186
202
|
if (!this._enhancementProject) {
|
|
187
203
|
return null;
|
|
@@ -199,17 +215,27 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
199
215
|
|
|
200
216
|
/**
 * Requests a signed result URL for this job from the API and stores it on the job.
 * IMPORTANT: URL expires after 30 minutes, so make sure to download the result as soon as possible.
 * Video jobs are resolved through the media download endpoint, image jobs through the
 * image download endpoint.
 * @returns The signed URL of the job result.
 * @throws Error when the job status is not 'completed'.
 */
async getResultUrl(): Promise<string> {
  if (this.data.status !== 'completed') {
    throw new Error('Job is not completed yet');
  }
  // The two endpoints name the job identifier differently (`id` vs `imageId`).
  const signedUrl: string =
    this.type === 'video'
      ? await this._api.mediaDownloadUrl({
          jobId: this.projectId,
          id: this.id,
          type: 'complete'
        })
      : await this._api.downloadUrl({
          jobId: this.projectId,
          imageId: this.id,
          type: 'complete'
        });
  this._update({ resultUrl: signedUrl });
  return signedUrl;
}
|
|
@@ -230,6 +256,15 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
230
256
|
return this.data.workerName;
|
|
231
257
|
}
|
|
232
258
|
|
|
259
|
+
/**
 * Remaining time estimate, in seconds, for long-running jobs (e.g., video generation).
 * Only available for ComfyUI-based workers during inference.
 * `undefined` when no ETA has been stored on the job data.
 */
get etaSeconds() {
  const { etaSeconds } = this.data;
  return etaSeconds;
}
|
|
267
|
+
|
|
233
268
|
/**
|
|
234
269
|
* Syncs the job data with the data received from the REST API.
|
|
235
270
|
* @internal
|
|
@@ -247,11 +282,19 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
247
282
|
}
|
|
248
283
|
if (!this.data.resultUrl && delta.status === 'completed' && !data.triggeredNSFWFilter) {
|
|
249
284
|
try {
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
285
|
+
if (this.type === 'video') {
|
|
286
|
+
delta.resultUrl = await this._api.mediaDownloadUrl({
|
|
287
|
+
jobId: this.projectId,
|
|
288
|
+
id: this.id,
|
|
289
|
+
type: 'complete'
|
|
290
|
+
});
|
|
291
|
+
} else {
|
|
292
|
+
delta.resultUrl = await this._api.downloadUrl({
|
|
293
|
+
jobId: this.projectId,
|
|
294
|
+
imageId: this.id,
|
|
295
|
+
type: 'complete'
|
|
296
|
+
});
|
|
297
|
+
}
|
|
255
298
|
} catch (error) {
|
|
256
299
|
this._logger.error(error);
|
|
257
300
|
}
|
|
@@ -276,8 +319,8 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
276
319
|
}
|
|
277
320
|
|
|
278
321
|
async getResultData() {
|
|
279
|
-
if (!this.
|
|
280
|
-
throw new Error('No result
|
|
322
|
+
if (!this.hasResultMedia) {
|
|
323
|
+
throw new Error('No result media available');
|
|
281
324
|
}
|
|
282
325
|
const url = await this.getResultUrl();
|
|
283
326
|
const response = await fetch(url);
|
|
@@ -297,6 +340,10 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
297
340
|
strength: EnhancementStrength,
|
|
298
341
|
overrides: { positivePrompt?: string; stylePrompt?: string; tokenType?: TokenType } = {}
|
|
299
342
|
) {
|
|
343
|
+
const parentProjectParams = this._project.params;
|
|
344
|
+
if (parentProjectParams.type !== 'image') {
|
|
345
|
+
throw new Error('Enhancement is only available for images');
|
|
346
|
+
}
|
|
300
347
|
if (this.status !== 'completed') {
|
|
301
348
|
throw new Error('Job is not completed yet');
|
|
302
349
|
}
|
|
@@ -309,6 +356,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
309
356
|
}
|
|
310
357
|
const imageData = await this.getResultData();
|
|
311
358
|
const project = await this._api.create({
|
|
359
|
+
type: 'image',
|
|
312
360
|
...enhancementDefaults,
|
|
313
361
|
positivePrompt: overrides.positivePrompt || this._project.params.positivePrompt,
|
|
314
362
|
stylePrompt: overrides.stylePrompt || this._project.params.stylePrompt,
|
|
@@ -316,7 +364,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
316
364
|
seed: this.seed || this._project.params.seed,
|
|
317
365
|
startingImage: imageData,
|
|
318
366
|
startingImageStrength: 1 - getEnhacementStrength(strength),
|
|
319
|
-
sizePreset:
|
|
367
|
+
sizePreset: parentProjectParams.sizePreset
|
|
320
368
|
});
|
|
321
369
|
this._enhancementProject = project;
|
|
322
370
|
this._enhancementProject.on('updated', this.handleEnhancementUpdate);
|
package/src/Projects/Project.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import Job, { JobData } from './Job';
|
|
2
2
|
import DataEntity, { EntityEvents } from '../lib/DataEntity';
|
|
3
|
-
import { ProjectParams } from './types';
|
|
3
|
+
import { isImageParams, ProjectParams } from './types';
|
|
4
4
|
import cloneDeep from 'lodash/cloneDeep';
|
|
5
5
|
import ErrorData from '../types/ErrorData';
|
|
6
6
|
import getUUID from '../lib/getUUID';
|
|
@@ -93,6 +93,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
93
93
|
return this.data.params;
|
|
94
94
|
}
|
|
95
95
|
|
|
96
|
+
/** Project type, read from the `type` field of the project params. */
get type() {
  const { type } = this.params;
  return type;
}
|
|
99
|
+
|
|
96
100
|
/** Current project status as stored in the project data. */
get status() {
  const { status } = this.data;
  return status;
}
|
|
@@ -110,10 +114,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
110
114
|
*/
|
|
111
115
|
get progress() {
  // Worker can reduce the number of steps in the job, so the step count reported by the
  // first job takes precedence over the requested number of steps.
  const stepsPerJob = this.jobs.length
    ? (this.jobs[0].stepCount ?? 1)
    : (this.data.params.steps ?? 0);
  const jobCount = this.data.params.numberOfMedia;
  const totalSteps = stepsPerJob * jobCount;
  // No jobs yet and no `steps` in params (or no media count) gives a zero or NaN
  // denominator. Report 0% instead of leaking NaN/Infinity to callers.
  if (!(totalSteps > 0)) {
    return 0;
  }
  const stepsDone = this._jobs.reduce((acc, job) => acc + job.step, 0);
  return Math.round((stepsDone / totalSteps) * 100);
}
|
|
118
122
|
|
|
119
123
|
get queuePosition() {
|
|
@@ -192,7 +196,7 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
192
196
|
this._timeout = null;
|
|
193
197
|
}
|
|
194
198
|
if (keys.includes('status') || keys.includes('jobs')) {
|
|
195
|
-
const allJobsStarted = this.jobs.length >= this.params.
|
|
199
|
+
const allJobsStarted = this.jobs.length >= this.params.numberOfMedia;
|
|
196
200
|
const allJobsDone = this.jobs.every((job) => job.finished);
|
|
197
201
|
if (this.data.status === 'completed' && allJobsStarted && allJobsDone) {
|
|
198
202
|
return this.emit('completed', this.resultUrls);
|
|
@@ -203,6 +207,16 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
203
207
|
}
|
|
204
208
|
}
|
|
205
209
|
|
|
210
|
+
/**
 * Marks the project as freshly updated so that the timeout check, which compares
 * `lastUpdated` against the current time, does not fire.
 * Intended for socket events that show the project is still active
 * (e.g., jobETA events during long-running video generation).
 * @internal
 */
_keepAlive() {
  const now = new Date();
  this.lastUpdated = now;
}
|
|
219
|
+
|
|
206
220
|
/**
|
|
207
221
|
* This is internal method to add a job to the project. Do not call this directly.
|
|
208
222
|
* @internal
|
|
@@ -231,7 +245,11 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
231
245
|
private _checkForTimeout() {
|
|
232
246
|
if (this.lastUpdated.getTime() + PROJECT_TIMEOUT < Date.now()) {
|
|
233
247
|
this._syncToServer().catch((error) => {
|
|
234
|
-
|
|
248
|
+
// 404 errors are expected when project is still initializing and not yet available via REST API
|
|
249
|
+
// Only log non-404 errors to avoid confusing users
|
|
250
|
+
if (error.status !== 404) {
|
|
251
|
+
this._logger.error(error);
|
|
252
|
+
}
|
|
235
253
|
this._failedSyncAttempts++;
|
|
236
254
|
if (this._failedSyncAttempts >= MAX_FAILED_SYNC_ATTEMPTS) {
|
|
237
255
|
this._logger.error(
|
|
@@ -298,11 +316,13 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
298
316
|
const delta: Partial<ProjectData> = {
|
|
299
317
|
params: {
|
|
300
318
|
...this.data.params,
|
|
301
|
-
|
|
302
|
-
steps: data.stepCount
|
|
303
|
-
numberOfPreviews: data.previewCount
|
|
319
|
+
numberOfMedia: data.imageCount,
|
|
320
|
+
steps: data.stepCount
|
|
304
321
|
}
|
|
305
322
|
};
|
|
323
|
+
if (delta.params && isImageParams(delta.params)) {
|
|
324
|
+
delta.params.numberOfPreviews = data.previewCount;
|
|
325
|
+
}
|
|
306
326
|
if (PROJECT_STATUS_MAP[data.status]) {
|
|
307
327
|
delta.status = PROJECT_STATUS_MAP[data.status];
|
|
308
328
|
}
|
|
@@ -1,6 +1,54 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import {
|
|
2
|
+
ImageProjectParams,
|
|
3
|
+
isImageParams,
|
|
4
|
+
isVideoParams,
|
|
5
|
+
ProjectParams,
|
|
6
|
+
VideoProjectParams
|
|
7
|
+
} from './types';
|
|
2
8
|
import { ControlNetParams, ControlNetParamsRaw } from './types/ControlNetParams';
|
|
3
|
-
import {
|
|
9
|
+
import {
|
|
10
|
+
validateNumber,
|
|
11
|
+
validateCustomImageSize,
|
|
12
|
+
validateSampler,
|
|
13
|
+
validateScheduler,
|
|
14
|
+
validateVideoSize
|
|
15
|
+
} from '../lib/validation';
|
|
16
|
+
import { getVideoWorkflowType, isVideoModel, VIDEO_WORKFLOW_ASSETS } from './utils';
|
|
17
|
+
import { ApiError } from '../ApiClient';
|
|
18
|
+
|
|
19
|
+
/**
 * Checks the assets present on `params` against the requirements registered for
 * the workflow of the selected model.
 *
 * Nothing is checked when no workflow type is resolved for the model or when the
 * workflow has no entry in `VIDEO_WORKFLOW_ASSETS`.
 *
 * @param params - Video project parameters to inspect.
 * @throws ApiError (status 400) when a required asset is missing or a forbidden
 *   asset is present.
 */
function validateVideoWorkflowAssets(params: VideoProjectParams): void {
  const workflowType = getVideoWorkflowType(params.modelId);
  const requirements = workflowType ? VIDEO_WORKFLOW_ASSETS[workflowType] : undefined;
  if (!requirements) {
    return;
  }

  // All violations are reported with the same error envelope.
  const violation = (message: string) =>
    new ApiError(400, {
      status: 'error',
      errorCode: 0,
      message
    });

  for (const [asset, requirement] of Object.entries(requirements)) {
    const assetKey = asset as keyof VideoProjectParams;
    const provided = Boolean(params[assetKey]);

    if (!provided && requirement === 'required') {
      throw violation(`${workflowType} workflow requires ${assetKey}. Please provide this asset.`);
    }

    if (provided && requirement === 'forbidden') {
      throw violation(
        `${workflowType} workflow does not support ${assetKey}. Please remove this asset.`
      );
    }
  }
}
|
|
51
|
+
|
|
4
52
|
// Mac worker can't process the data if some of the fields are missing, so we need to provide a default template
|
|
5
53
|
function getTemplate() {
|
|
6
54
|
return {
|
|
@@ -25,8 +73,8 @@ function getTemplate() {
|
|
|
25
73
|
guidanceScaleIsEnabled: true,
|
|
26
74
|
siImageBackgroundColor: 'black',
|
|
27
75
|
cnDragOffset: [0, 0],
|
|
28
|
-
scheduler:
|
|
29
|
-
timeStepSpacing:
|
|
76
|
+
scheduler: null,
|
|
77
|
+
timeStepSpacing: null,
|
|
30
78
|
steps: 20,
|
|
31
79
|
cnRotation: 0,
|
|
32
80
|
guidanceScale: 7.5,
|
|
@@ -112,49 +160,120 @@ function getControlNet(params: ControlNetParams): ControlNetParamsRaw[] {
|
|
|
112
160
|
return [cn];
|
|
113
161
|
}
|
|
114
162
|
|
|
163
|
+
/**
 * Builds the key frame for an image project from the common key frame.
 *
 * @param inputKeyframe - Key frame already populated with the common params; not mutated.
 * @param params - Image project parameters.
 * @returns A new key frame object with the image-specific fields applied.
 */
function applyImageParams(inputKeyframe: Record<string, any>, params: ImageProjectParams) {
  const keyFrame: Record<string, any> = {
    ...inputKeyframe,
    // The request field names differ from the public param names: `sampler` is sent as
    // `scheduler` and `scheduler` as `timeStepSpacing`.
    // NOTE(review): mapping looks intentional - confirm against the worker protocol.
    scheduler: validateSampler(params.sampler),
    timeStepSpacing: validateScheduler(params.scheduler),
    sizePreset: params.sizePreset,
    hasContextImage1: !!params.contextImages?.[0],
    hasContextImage2: !!params.contextImages?.[1],
    hasContextImage3: !!params.contextImages?.[2]
  };

  if (params.startingImage) {
    // Fall back to 0.5 only when the strength is missing or not a number.
    // The previous `Number(x) || 0.5` also replaced an explicit 0 with 0.5.
    const requestedStrength =
      params.startingImageStrength == null ? NaN : Number(params.startingImageStrength);
    const startingImageStrength = Number.isFinite(requestedStrength) ? requestedStrength : 0.5;
    keyFrame.hasStartingImage = true;
    keyFrame.strengthIsEnabled = true;
    // The request carries the inverse of the public strength value.
    keyFrame.strength = 1 - startingImageStrength;
  }

  if (params.controlNet) {
    keyFrame.currentControlNetsJob = getControlNet(params.controlNet);
  }
  if (params.sizePreset === 'custom') {
    // Explicit dimensions are only sent (and validated) for the custom preset.
    keyFrame.width = validateCustomImageSize(params.width);
    keyFrame.height = validateCustomImageSize(params.height);
  }
  return keyFrame;
}
|
|
189
|
+
|
|
190
|
+
/**
 * Builds the key frame for a video project from the common key frame.
 *
 * @param inputKeyframe - Key frame already populated with the common params; not mutated.
 * @param params - Video project parameters.
 * @returns A new key frame object with the video-specific fields applied.
 * @throws ApiError (status 400) when the model is not a video model; errors raised by
 *   the workflow asset validation and by the size validation propagate as well.
 */
function applyVideoParams(inputKeyframe: Record<string, any>, params: VideoProjectParams) {
  if (!isVideoModel(params.modelId)) {
    throw new ApiError(400, {
      status: 'error',
      errorCode: 0,
      message: 'Video generation is only supported for video models.'
    });
  }
  validateVideoWorkflowAssets(params);

  const keyFrame: Record<string, any> = { ...inputKeyframe };

  // Each provided reference asset is announced through its `has*` flag.
  const referenceFlags = [
    ['referenceImage', 'hasReferenceImage'],
    ['referenceImageEnd', 'hasReferenceImageEnd'],
    ['referenceAudio', 'hasReferenceAudio'],
    ['referenceVideo', 'hasReferenceVideo']
  ] as const;
  for (const [assetName, flagName] of referenceFlags) {
    if (params[assetName]) {
      keyFrame[flagName] = true;
    }
  }

  // Video generation parameters are copied only when explicitly set.
  for (const name of ['frames', 'fps', 'shift'] as const) {
    if (params[name] !== undefined) {
      keyFrame[name] = params[name];
    }
  }

  // Dimensions are validated and applied only when both are given
  // (the original note mentions a 480px minimum for Wan 2.2 models).
  if (params.width && params.height) {
    keyFrame.width = validateVideoSize(params.width, 'width');
    keyFrame.height = validateVideoSize(params.height, 'height');
  }

  return keyFrame;
}
|
|
232
|
+
|
|
115
233
|
/**
 * Assembles the job request payload for a project.
 *
 * The payload starts from the default template, receives a single key frame built
 * from the common params plus the image- or video-specific ones, and is completed
 * with the request-level fields.
 *
 * @param id - Identifier sent as `jobID`.
 * @param params - Project parameters; `params.type` selects the image or video path.
 * @returns The job request object.
 * @throws ApiError (status 400) when `params.type` is neither 'image' nor 'video'.
 */
function createJobRequestMessage(id: string, params: ProjectParams) {
  const template = getTemplate();

  // Params shared by every project type, layered over the template key frame.
  const commonKeyFrame: Record<string, any> = {
    ...template.keyFrames[0],
    steps: params.steps,
    guidanceScale: params.guidance,
    modelID: params.modelId,
    negativePrompt: params.negativePrompt,
    seed: params.seed,
    positivePrompt: params.positivePrompt,
    stylePrompt: params.stylePrompt
  };

  let keyFrame: Record<string, any>;
  if (params.type === 'image') {
    keyFrame = applyImageParams(commonKeyFrame, params);
  } else if (params.type === 'video') {
    keyFrame = applyVideoParams(commonKeyFrame, params);
  } else {
    throw new ApiError(400, {
      status: 'error',
      errorCode: 0,
      message: 'Invalid project type. Must be "image" or "video".'
    });
  }

  // Previews are requested for image projects only.
  const previews = isImageParams(params) ? params.numberOfPreviews || 0 : 0;
  const defaultOutputFormat = isVideoParams(params) ? 'mp4' : 'png';

  const jobRequest: Record<string, any> = {
    ...template,
    keyFrames: [keyFrame],
    previews,
    numberOfImages: params.numberOfMedia || 1,
    jobID: id,
    disableSafety: !!params.disableNSFWFilter,
    tokenType: params.tokenType,
    outputFormat: params.outputFormat || defaultOutputFormat
  };

  // `network` is attached only when the caller chose one.
  if (params.network) {
    jobRequest.network = params.network;
  }

  return jobRequest;
}
|
|
160
279
|
|