@sogni-ai/sogni-client 4.0.0-alpha.3 → 4.0.0-alpha.30

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. package/CHANGELOG.md +213 -0
  2. package/README.md +279 -28
  3. package/dist/Account/index.d.ts +18 -16
  4. package/dist/Account/index.js +31 -20
  5. package/dist/Account/index.js.map +1 -1
  6. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.d.ts +66 -0
  7. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js +332 -0
  8. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js.map +1 -0
  9. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.d.ts +28 -0
  10. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js +203 -0
  11. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js.map +1 -0
  12. package/dist/ApiClient/WebSocketClient/events.d.ts +11 -0
  13. package/dist/ApiClient/WebSocketClient/index.d.ts +2 -2
  14. package/dist/ApiClient/WebSocketClient/index.js +13 -3
  15. package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
  16. package/dist/ApiClient/WebSocketClient/types.d.ts +13 -0
  17. package/dist/ApiClient/index.d.ts +4 -4
  18. package/dist/ApiClient/index.js +23 -4
  19. package/dist/ApiClient/index.js.map +1 -1
  20. package/dist/Projects/Job.d.ts +24 -4
  21. package/dist/Projects/Job.js +58 -16
  22. package/dist/Projects/Job.js.map +1 -1
  23. package/dist/Projects/Project.d.ts +8 -0
  24. package/dist/Projects/Project.js +27 -6
  25. package/dist/Projects/Project.js.map +1 -1
  26. package/dist/Projects/createJobRequestMessage.js +109 -15
  27. package/dist/Projects/createJobRequestMessage.js.map +1 -1
  28. package/dist/Projects/index.d.ts +110 -11
  29. package/dist/Projects/index.js +412 -42
  30. package/dist/Projects/index.js.map +1 -1
  31. package/dist/Projects/types/EstimationResponse.d.ts +2 -0
  32. package/dist/Projects/types/SamplerParams.d.ts +13 -0
  33. package/dist/Projects/types/SamplerParams.js +26 -0
  34. package/dist/Projects/types/SamplerParams.js.map +1 -0
  35. package/dist/Projects/types/SchedulerParams.d.ts +14 -0
  36. package/dist/Projects/types/SchedulerParams.js +24 -0
  37. package/dist/Projects/types/SchedulerParams.js.map +1 -0
  38. package/dist/Projects/types/events.d.ts +5 -1
  39. package/dist/Projects/types/index.d.ts +150 -39
  40. package/dist/Projects/types/index.js +13 -0
  41. package/dist/Projects/types/index.js.map +1 -1
  42. package/dist/Projects/utils.d.ts +19 -1
  43. package/dist/Projects/utils.js +68 -0
  44. package/dist/Projects/utils.js.map +1 -1
  45. package/dist/index.d.ts +12 -4
  46. package/dist/index.js +12 -4
  47. package/dist/index.js.map +1 -1
  48. package/dist/lib/AuthManager/TokenAuthManager.js +0 -2
  49. package/dist/lib/AuthManager/TokenAuthManager.js.map +1 -1
  50. package/dist/lib/DataEntity.js +4 -2
  51. package/dist/lib/DataEntity.js.map +1 -1
  52. package/dist/lib/validation.d.ts +7 -0
  53. package/dist/lib/validation.js +36 -0
  54. package/dist/lib/validation.js.map +1 -1
  55. package/package.json +4 -4
  56. package/src/Account/index.ts +30 -19
  57. package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.ts +426 -0
  58. package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/index.ts +237 -0
  59. package/src/ApiClient/WebSocketClient/events.ts +13 -0
  60. package/src/ApiClient/WebSocketClient/index.ts +15 -5
  61. package/src/ApiClient/WebSocketClient/types.ts +16 -0
  62. package/src/ApiClient/index.ts +30 -8
  63. package/src/Projects/Job.ts +64 -16
  64. package/src/Projects/Project.ts +29 -9
  65. package/src/Projects/createJobRequestMessage.ts +155 -36
  66. package/src/Projects/index.ts +437 -46
  67. package/src/Projects/types/EstimationResponse.ts +2 -0
  68. package/src/Projects/types/SamplerParams.ts +24 -0
  69. package/src/Projects/types/SchedulerParams.ts +22 -0
  70. package/src/Projects/types/events.ts +6 -0
  71. package/src/Projects/types/index.ts +181 -47
  72. package/src/Projects/utils.ts +66 -1
  73. package/src/index.ts +38 -11
  74. package/src/lib/AuthManager/TokenAuthManager.ts +0 -2
  75. package/src/lib/DataEntity.ts +4 -2
  76. package/src/lib/validation.ts +41 -0
@@ -1,7 +1,7 @@
1
1
  import { MessageType, SocketMessageMap } from './messages';
2
2
  import { SocketEventMap } from './events';
3
3
  import RestClient from '../../lib/RestClient';
4
- import { SupernetType } from './types';
4
+ import { IWebSocketClient, SupernetType } from './types';
5
5
  import WebSocket, { CloseEvent, ErrorEvent, MessageEvent } from 'isomorphic-ws';
6
6
  import { base64Decode, base64Encode } from '../../lib/base64';
7
7
  import isNodejs from '../../lib/isNodejs';
@@ -13,7 +13,7 @@ const PROTOCOL_VERSION = '3.0.0';
13
13
 
14
14
  const PING_INTERVAL = 15000;
15
15
 
16
- class WebSocketClient extends RestClient<SocketEventMap> {
16
+ class WebSocketClient extends RestClient<SocketEventMap> implements IWebSocketClient {
17
17
  appId: string;
18
18
  baseUrl: string;
19
19
  private socket: WebSocket | null = null;
@@ -86,7 +86,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
86
86
  socket.onmessage = null;
87
87
  socket.onopen = null;
88
88
  this.stopPing();
89
- socket.close();
89
+ socket.close(1000, 'Client disconnected');
90
90
  }
91
91
 
92
92
  private startPing(socket: WebSocket) {
@@ -148,9 +148,16 @@ class WebSocketClient extends RestClient<SocketEventMap> {
148
148
  }
149
149
 
150
150
  private handleClose(e: CloseEvent) {
151
- if (e.target === this.socket) {
151
+ const socket = e.target;
152
+ socket.onerror = null;
153
+ socket.onmessage = null;
154
+ socket.onopen = null;
155
+ if (socket === this.socket || !this.socket) {
152
156
  this._logger.info('WebSocket disconnected, cleanup', e);
153
- this.disconnect();
157
+ if (socket === this.socket) {
158
+ this.stopPing();
159
+ this.socket = null;
160
+ }
154
161
  this.emit('disconnected', {
155
162
  code: e.code,
156
163
  reason: e.reason
@@ -193,6 +200,9 @@ class WebSocketClient extends RestClient<SocketEventMap> {
193
200
  }
194
201
 
195
202
  async send<T extends MessageType>(messageType: T, data: SocketMessageMap[T]) {
203
+ if (!this.isConnected) {
204
+ await this.connect();
205
+ }
196
206
  await this.waitForConnection();
197
207
  this._logger.debug('WebSocket send:', messageType, data);
198
208
  this.socket!.send(
@@ -1 +1,17 @@
1
+ import { MessageType, SocketMessageMap } from './messages';
2
+ import RestClient from '../../lib/RestClient';
3
+ import { SocketEventMap } from './events';
4
+
1
5
  export type SupernetType = 'relaxed' | 'fast';
6
+
7
+ export interface IWebSocketClient extends RestClient<SocketEventMap> {
8
+ appId: string;
9
+ baseUrl: string;
10
+ isConnected: boolean;
11
+ supernetType: SupernetType;
12
+
13
+ connect(): Promise<void>;
14
+ disconnect(): void;
15
+ send<T extends MessageType>(messageType: T, data: SocketMessageMap[T]): Promise<void>;
16
+ switchNetwork(supernetType: SupernetType): Promise<SupernetType>;
17
+ }
@@ -3,12 +3,14 @@ import WebSocketClient from './WebSocketClient';
3
3
  import TypedEventEmitter from '../lib/TypedEventEmitter';
4
4
  import { ApiClientEvents } from './events';
5
5
  import { ServerConnectData, ServerDisconnectData } from './WebSocketClient/events';
6
- import { isNotRecoverable } from './WebSocketClient/ErrorCode';
6
+ import { ErrorCode, isNotRecoverable } from './WebSocketClient/ErrorCode';
7
7
  import { JSONValue } from '../types/json';
8
- import { SupernetType } from './WebSocketClient/types';
8
+ import { IWebSocketClient, SupernetType } from './WebSocketClient/types';
9
9
  import { Logger } from '../lib/DefaultLogger';
10
10
  import CookieAuthManager from '../lib/AuthManager/CookieAuthManager';
11
11
  import { AuthManager, TokenAuthManager } from '../lib/AuthManager';
12
+ import isNodejs from '../lib/isNodejs';
13
+ import BrowserWebSocketClient from './WebSocketClient/BrowserWebSocketClient';
12
14
 
13
15
  const WS_RECONNECT_ATTEMPTS = 5;
14
16
 
@@ -42,13 +44,14 @@ export interface ApiClientOptions {
42
44
  logger: Logger;
43
45
  authType: 'token' | 'cookies';
44
46
  disableSocket?: boolean;
47
+ multiInstance?: boolean;
45
48
  }
46
49
 
47
50
  class ApiClient extends TypedEventEmitter<ApiClientEvents> {
48
51
  readonly appId: string;
49
52
  readonly logger: Logger;
50
53
  private _rest: RestClient;
51
- private _socket: WebSocketClient;
54
+ private _socket: IWebSocketClient;
52
55
  private _auth: AuthManager;
53
56
  private _reconnectAttempts = WS_RECONNECT_ATTEMPTS;
54
57
  private _disableSocket: boolean = false;
@@ -60,7 +63,8 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
60
63
  networkType,
61
64
  authType,
62
65
  logger,
63
- disableSocket = false
66
+ disableSocket = false,
67
+ multiInstance = false
64
68
  }: ApiClientOptions) {
65
69
  super();
66
70
  this.appId = appId;
@@ -68,7 +72,13 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
68
72
  this._auth =
69
73
  authType === 'token' ? new TokenAuthManager(baseUrl, logger) : new CookieAuthManager(logger);
70
74
  this._rest = new RestClient(baseUrl, this._auth, logger);
71
- this._socket = new WebSocketClient(socketUrl, this._auth, appId, networkType, logger);
75
+ const supportMultiInstance = !isNodejs && this._auth instanceof CookieAuthManager;
76
+ if (supportMultiInstance && multiInstance) {
77
+ // Use coordinated WebSocket client to share single connection between tabs
78
+ this._socket = new BrowserWebSocketClient(socketUrl, this._auth, appId, networkType, logger);
79
+ } else {
80
+ this._socket = new WebSocketClient(socketUrl, this._auth, appId, networkType, logger);
81
+ }
72
82
  this._disableSocket = disableSocket;
73
83
  this._auth.on('updated', this.handleAuthUpdated.bind(this));
74
84
  this._socket.on('connected', this.handleSocketConnect.bind(this));
@@ -83,7 +93,7 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
83
93
  return this._auth;
84
94
  }
85
95
 
86
- get socket(): WebSocketClient {
96
+ get socket(): IWebSocketClient {
87
97
  return this._socket;
88
98
  }
89
99
 
@@ -101,14 +111,26 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
101
111
  }
102
112
 
103
113
  handleSocketDisconnect(data: ServerDisconnectData) {
114
+ // If user is not authenticated, we don't need to reconnect
115
+ if (!this.auth.isAuthenticated || data.code === 1000) {
116
+ this.emit('disconnected', data);
117
+ return;
118
+ }
104
119
  if (!data.code || isNotRecoverable(data.code)) {
120
+ // If this is browser, another tab is probably claiming the connection, so we don't need to reconnect
121
+ if (
122
+ this._socket instanceof BrowserWebSocketClient &&
123
+ data.code === ErrorCode.SWITCH_CONNECTION
124
+ ) {
125
+ this.logger.debug('Switching network connection, not reconnecting');
126
+ return;
127
+ }
105
128
  this.auth.clear();
106
129
  this.emit('disconnected', data);
107
130
  this.logger.error('Not recoverable socket error', data);
108
131
  return;
109
132
  }
110
133
  if (this._reconnectAttempts <= 0) {
111
- this.auth.clear();
112
134
  this.emit('disconnected', data);
113
135
  this._reconnectAttempts = WS_RECONNECT_ATTEMPTS;
114
136
  return;
@@ -122,7 +144,7 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
122
144
  if (this.socket.isConnected) {
123
145
  this.socket.disconnect();
124
146
  }
125
- } else if (!this._disableSocket) {
147
+ } else if (!this._disableSocket && !this.socket.isConnected) {
126
148
  this.socket.connect();
127
149
  }
128
150
  }
@@ -19,7 +19,7 @@ export const enhancementDefaults = {
19
19
  startingImageStrength: 0.5,
20
20
  steps: 5,
21
21
  guidance: 1,
22
- numberOfImages: 1,
22
+ numberOfMedia: 1,
23
23
  numberOfPreviews: 0
24
24
  };
25
25
 
@@ -61,6 +61,11 @@ export interface JobData {
61
61
  positivePrompt?: string;
62
62
  negativePrompt?: string;
63
63
  jobIndex?: number;
64
+ /**
65
+ * Estimated time remaining in seconds (for long-running jobs like video generation).
66
+ * Updated by ComfyUI workers during inference.
67
+ */
68
+ etaSeconds?: number;
64
69
  }
65
70
 
66
71
  export interface JobEventMap extends EntityEvents {
@@ -178,10 +183,21 @@ class Job extends DataEntity<JobData, JobEventMap> {
178
183
  return this.data.error;
179
184
  }
180
185
 
181
- get hasResultImage() {
186
+ /**
187
+ * Whether this job has a result media file available for download.
188
+ * Returns true if completed and not NSFW filtered.
189
+ */
190
+ get hasResultMedia() {
182
191
  return this.status === 'completed' && !this.isNSFW;
183
192
  }
184
193
 
194
+ /**
195
+ * Whether this job produces video output (based on the model used)
196
+ */
197
+ get type(): 'image' | 'video' {
198
+ return this._api.isVideoModelId(this._project.params.modelId) ? 'video' : 'image';
199
+ }
200
+
185
201
  get enhancedImage() {
186
202
  if (!this._enhancementProject) {
187
203
  return null;
@@ -199,17 +215,27 @@ class Job extends DataEntity<JobData, JobEventMap> {
199
215
 
200
216
  /**
201
217
  * Get the result URL of the job. This method will make a request to the API to get signed URL.
202
- * IMPORTANT: URL expires after 30 minutes, so make sure to download the image as soon as possible.
218
+ * IMPORTANT: URL expires after 30 minutes, so make sure to download the result as soon as possible.
219
+ * For video jobs, this returns a video URL. For image jobs, this returns an image URL.
203
220
  */
204
221
  async getResultUrl(): Promise<string> {
205
222
  if (this.data.status !== 'completed') {
206
223
  throw new Error('Job is not completed yet');
207
224
  }
208
- const url = await this._api.downloadUrl({
209
- jobId: this.projectId,
210
- imageId: this.id,
211
- type: 'complete'
212
- });
225
+ let url: string;
226
+ if (this.type === 'video') {
227
+ url = await this._api.mediaDownloadUrl({
228
+ jobId: this.projectId,
229
+ id: this.id,
230
+ type: 'complete'
231
+ });
232
+ } else {
233
+ url = await this._api.downloadUrl({
234
+ jobId: this.projectId,
235
+ imageId: this.id,
236
+ type: 'complete'
237
+ });
238
+ }
213
239
  this._update({ resultUrl: url });
214
240
  return url;
215
241
  }
@@ -230,6 +256,15 @@ class Job extends DataEntity<JobData, JobEventMap> {
230
256
  return this.data.workerName;
231
257
  }
232
258
 
259
+ /**
260
+ * Estimated time remaining in seconds for long-running jobs (e.g., video generation).
261
+ * Only available for ComfyUI-based workers during inference.
262
+ * Returns undefined if no ETA has been received.
263
+ */
264
+ get etaSeconds() {
265
+ return this.data.etaSeconds;
266
+ }
267
+
233
268
  /**
234
269
  * Syncs the job data with the data received from the REST API.
235
270
  * @internal
@@ -247,11 +282,19 @@ class Job extends DataEntity<JobData, JobEventMap> {
247
282
  }
248
283
  if (!this.data.resultUrl && delta.status === 'completed' && !data.triggeredNSFWFilter) {
249
284
  try {
250
- delta.resultUrl = await this._api.downloadUrl({
251
- jobId: this.projectId,
252
- imageId: this.id,
253
- type: 'complete'
254
- });
285
+ if (this.type === 'video') {
286
+ delta.resultUrl = await this._api.mediaDownloadUrl({
287
+ jobId: this.projectId,
288
+ id: this.id,
289
+ type: 'complete'
290
+ });
291
+ } else {
292
+ delta.resultUrl = await this._api.downloadUrl({
293
+ jobId: this.projectId,
294
+ imageId: this.id,
295
+ type: 'complete'
296
+ });
297
+ }
255
298
  } catch (error) {
256
299
  this._logger.error(error);
257
300
  }
@@ -276,8 +319,8 @@ class Job extends DataEntity<JobData, JobEventMap> {
276
319
  }
277
320
 
278
321
  async getResultData() {
279
- if (!this.hasResultImage) {
280
- throw new Error('No result image available');
322
+ if (!this.hasResultMedia) {
323
+ throw new Error('No result media available');
281
324
  }
282
325
  const url = await this.getResultUrl();
283
326
  const response = await fetch(url);
@@ -297,6 +340,10 @@ class Job extends DataEntity<JobData, JobEventMap> {
297
340
  strength: EnhancementStrength,
298
341
  overrides: { positivePrompt?: string; stylePrompt?: string; tokenType?: TokenType } = {}
299
342
  ) {
343
+ const parentProjectParams = this._project.params;
344
+ if (parentProjectParams.type !== 'image') {
345
+ throw new Error('Enhancement is only available for images');
346
+ }
300
347
  if (this.status !== 'completed') {
301
348
  throw new Error('Job is not completed yet');
302
349
  }
@@ -309,6 +356,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
309
356
  }
310
357
  const imageData = await this.getResultData();
311
358
  const project = await this._api.create({
359
+ type: 'image',
312
360
  ...enhancementDefaults,
313
361
  positivePrompt: overrides.positivePrompt || this._project.params.positivePrompt,
314
362
  stylePrompt: overrides.stylePrompt || this._project.params.stylePrompt,
@@ -316,7 +364,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
316
364
  seed: this.seed || this._project.params.seed,
317
365
  startingImage: imageData,
318
366
  startingImageStrength: 1 - getEnhacementStrength(strength),
319
- sizePreset: this._project.params.sizePreset
367
+ sizePreset: parentProjectParams.sizePreset
320
368
  });
321
369
  this._enhancementProject = project;
322
370
  this._enhancementProject.on('updated', this.handleEnhancementUpdate);
@@ -1,6 +1,6 @@
1
1
  import Job, { JobData } from './Job';
2
2
  import DataEntity, { EntityEvents } from '../lib/DataEntity';
3
- import { ProjectParams } from './types';
3
+ import { isImageParams, ProjectParams } from './types';
4
4
  import cloneDeep from 'lodash/cloneDeep';
5
5
  import ErrorData from '../types/ErrorData';
6
6
  import getUUID from '../lib/getUUID';
@@ -93,6 +93,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
93
93
  return this.data.params;
94
94
  }
95
95
 
96
+ get type() {
97
+ return this.params.type;
98
+ }
99
+
96
100
  get status() {
97
101
  return this.data.status;
98
102
  }
@@ -110,10 +114,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
110
114
  */
111
115
  get progress() {
112
116
  // Worker can reduce the number of steps in the job, so we need to calculate the progress based on the actual number of steps
113
- const stepsPerJob = this.jobs.length ? this.jobs[0].stepCount : this.data.params.steps;
114
- const jobCount = this.data.params.numberOfImages;
117
+ const stepsPerJob = this.jobs.length ? this.jobs[0].stepCount : (this.data.params.steps ?? 0);
118
+ const jobCount = this.data.params.numberOfMedia;
115
119
  const stepsDone = this._jobs.reduce((acc, job) => acc + job.step, 0);
116
- return Math.round((stepsDone / (stepsPerJob * jobCount)) * 100);
120
+ return Math.round((stepsDone / ((stepsPerJob ?? 1) * jobCount)) * 100);
117
121
  }
118
122
 
119
123
  get queuePosition() {
@@ -192,7 +196,7 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
192
196
  this._timeout = null;
193
197
  }
194
198
  if (keys.includes('status') || keys.includes('jobs')) {
195
- const allJobsStarted = this.jobs.length >= this.params.numberOfImages;
199
+ const allJobsStarted = this.jobs.length >= this.params.numberOfMedia;
196
200
  const allJobsDone = this.jobs.every((job) => job.finished);
197
201
  if (this.data.status === 'completed' && allJobsStarted && allJobsDone) {
198
202
  return this.emit('completed', this.resultUrls);
@@ -203,6 +207,16 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
203
207
  }
204
208
  }
205
209
 
210
+ /**
211
+ * Refresh the lastUpdated timestamp to prevent timeout.
212
+ * Used when receiving socket events that indicate the project is still active
213
+ * (e.g., jobETA events during long-running video generation).
214
+ * @internal
215
+ */
216
+ _keepAlive() {
217
+ this.lastUpdated = new Date();
218
+ }
219
+
206
220
  /**
207
221
  * This is internal method to add a job to the project. Do not call this directly.
208
222
  * @internal
@@ -231,7 +245,11 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
231
245
  private _checkForTimeout() {
232
246
  if (this.lastUpdated.getTime() + PROJECT_TIMEOUT < Date.now()) {
233
247
  this._syncToServer().catch((error) => {
234
- this._logger.error(error);
248
+ // 404 errors are expected when project is still initializing and not yet available via REST API
249
+ // Only log non-404 errors to avoid confusing users
250
+ if (error.status !== 404) {
251
+ this._logger.error(error);
252
+ }
235
253
  this._failedSyncAttempts++;
236
254
  if (this._failedSyncAttempts >= MAX_FAILED_SYNC_ATTEMPTS) {
237
255
  this._logger.error(
@@ -298,11 +316,13 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
298
316
  const delta: Partial<ProjectData> = {
299
317
  params: {
300
318
  ...this.data.params,
301
- numberOfImages: data.imageCount,
302
- steps: data.stepCount,
303
- numberOfPreviews: data.previewCount
319
+ numberOfMedia: data.imageCount,
320
+ steps: data.stepCount
304
321
  }
305
322
  };
323
+ if (delta.params && isImageParams(delta.params)) {
324
+ delta.params.numberOfPreviews = data.previewCount;
325
+ }
306
326
  if (PROJECT_STATUS_MAP[data.status]) {
307
327
  delta.status = PROJECT_STATUS_MAP[data.status];
308
328
  }
@@ -1,6 +1,54 @@
1
- import { ProjectParams } from './types';
1
+ import {
2
+ ImageProjectParams,
3
+ isImageParams,
4
+ isVideoParams,
5
+ ProjectParams,
6
+ VideoProjectParams
7
+ } from './types';
2
8
  import { ControlNetParams, ControlNetParamsRaw } from './types/ControlNetParams';
3
- import { validateNumber, validateCustomImageSize } from '../lib/validation';
9
+ import {
10
+ validateNumber,
11
+ validateCustomImageSize,
12
+ validateSampler,
13
+ validateScheduler,
14
+ validateVideoSize
15
+ } from '../lib/validation';
16
+ import { getVideoWorkflowType, isVideoModel, VIDEO_WORKFLOW_ASSETS } from './utils';
17
+ import { ApiError } from '../ApiClient';
18
+
19
+ /**
20
+ * Validate that the provided assets match the workflow requirements.
21
+ * Throws an error if required assets are missing or forbidden assets are provided.
22
+ */
23
+ function validateVideoWorkflowAssets(params: VideoProjectParams): void {
24
+ const workflowType = getVideoWorkflowType(params.modelId);
25
+ if (!workflowType) return;
26
+
27
+ const requirements = VIDEO_WORKFLOW_ASSETS[workflowType];
28
+ if (!requirements) return;
29
+ // Check for missing required assets
30
+ for (const [asset, requirement] of Object.entries(requirements)) {
31
+ const assetKey = asset as keyof VideoProjectParams;
32
+ const hasAsset = !!params[assetKey];
33
+
34
+ if (requirement === 'required' && !hasAsset) {
35
+ throw new ApiError(400, {
36
+ status: 'error',
37
+ errorCode: 0,
38
+ message: `${workflowType} workflow requires ${assetKey}. Please provide this asset.`
39
+ });
40
+ }
41
+
42
+ if (requirement === 'forbidden' && hasAsset) {
43
+ throw new ApiError(400, {
44
+ status: 'error',
45
+ errorCode: 0,
46
+ message: `${workflowType} workflow does not support ${assetKey}. Please remove this asset.`
47
+ });
48
+ }
49
+ }
50
+ }
51
+
4
52
  // Mac worker can't process the data if some of the fields are missing, so we need to provide a default template
5
53
  function getTemplate() {
6
54
  return {
@@ -25,8 +73,8 @@ function getTemplate() {
25
73
  guidanceScaleIsEnabled: true,
26
74
  siImageBackgroundColor: 'black',
27
75
  cnDragOffset: [0, 0],
28
- scheduler: 'DPM Solver Multistep (DPM-Solver++)',
29
- timeStepSpacing: 'Linear',
76
+ scheduler: null,
77
+ timeStepSpacing: null,
30
78
  steps: 20,
31
79
  cnRotation: 0,
32
80
  guidanceScale: 7.5,
@@ -112,49 +160,120 @@ function getControlNet(params: ControlNetParams): ControlNetParamsRaw[] {
112
160
  return [cn];
113
161
  }
114
162
 
163
+ function applyImageParams(inputKeyframe: Record<string, any>, params: ImageProjectParams) {
164
+ const keyFrame: Record<string, any> = {
165
+ ...inputKeyframe,
166
+ scheduler: validateSampler(params.sampler),
167
+ timeStepSpacing: validateScheduler(params.scheduler),
168
+ sizePreset: params.sizePreset,
169
+ hasContextImage1: !!params.contextImages?.[0],
170
+ hasContextImage2: !!params.contextImages?.[1],
171
+ hasContextImage3: !!params.contextImages?.[2]
172
+ };
173
+
174
+ if (params.startingImage) {
175
+ keyFrame.hasStartingImage = true;
176
+ keyFrame.strengthIsEnabled = true;
177
+ keyFrame.strength = 1 - (Number(params.startingImageStrength) || 0.5);
178
+ }
179
+
180
+ if (params.controlNet) {
181
+ keyFrame.currentControlNetsJob = getControlNet(params.controlNet);
182
+ }
183
+ if (params.sizePreset === 'custom') {
184
+ keyFrame.width = validateCustomImageSize(params.width);
185
+ keyFrame.height = validateCustomImageSize(params.height);
186
+ }
187
+ return keyFrame;
188
+ }
189
+
190
+ function applyVideoParams(inputKeyframe: Record<string, any>, params: VideoProjectParams) {
191
+ if (!isVideoModel(params.modelId)) {
192
+ throw new ApiError(400, {
193
+ status: 'error',
194
+ errorCode: 0,
195
+ message: 'Video generation is only supported for video models.'
196
+ });
197
+ }
198
+ validateVideoWorkflowAssets(params);
199
+ const keyFrame: Record<string, any> = { ...inputKeyframe };
200
+ if (params.referenceImage) {
201
+ keyFrame.hasReferenceImage = true;
202
+ }
203
+ if (params.referenceImageEnd) {
204
+ keyFrame.hasReferenceImageEnd = true;
205
+ }
206
+ if (params.referenceAudio) {
207
+ keyFrame.hasReferenceAudio = true;
208
+ }
209
+ if (params.referenceVideo) {
210
+ keyFrame.hasReferenceVideo = true;
211
+ }
212
+
213
+ // Video generation parameters
214
+ if (params.frames !== undefined) {
215
+ keyFrame.frames = params.frames;
216
+ }
217
+ if (params.fps !== undefined) {
218
+ keyFrame.fps = params.fps;
219
+ }
220
+ if (params.shift !== undefined) {
221
+ keyFrame.shift = params.shift;
222
+ }
223
+
224
+ // Validate and set video dimensions (minimum 480px for Wan 2.2 models)
225
+ if (params.width && params.height) {
226
+ keyFrame.width = validateVideoSize(params.width, 'width');
227
+ keyFrame.height = validateVideoSize(params.height, 'height');
228
+ }
229
+
230
+ return keyFrame;
231
+ }
232
+
115
233
  function createJobRequestMessage(id: string, params: ProjectParams) {
116
234
  const template = getTemplate();
235
+ // Base keyFrame with common params
236
+ let keyFrame: Record<string, any> = {
237
+ ...template.keyFrames[0],
238
+ steps: params.steps,
239
+ guidanceScale: params.guidance,
240
+ modelID: params.modelId,
241
+ negativePrompt: params.negativePrompt,
242
+ seed: params.seed,
243
+ positivePrompt: params.positivePrompt,
244
+ stylePrompt: params.stylePrompt
245
+ };
246
+
247
+ switch (params.type) {
248
+ case 'image':
249
+ keyFrame = applyImageParams(keyFrame, params);
250
+ break;
251
+ case 'video':
252
+ keyFrame = applyVideoParams(keyFrame, params);
253
+ break;
254
+ default:
255
+ throw new ApiError(400, {
256
+ status: 'error',
257
+ errorCode: 0,
258
+ message: 'Invalid project type. Must be "image" or "video".'
259
+ });
260
+ }
261
+
117
262
  const jobRequest: Record<string, any> = {
118
263
  ...template,
119
- keyFrames: [
120
- {
121
- ...template.keyFrames[0],
122
- scheduler: params.scheduler || null,
123
- timeStepSpacing: params.timeStepSpacing || null,
124
- steps: params.steps,
125
- guidanceScale: params.guidance,
126
- modelID: params.modelId,
127
- negativePrompt: params.negativePrompt,
128
- seed: params.seed,
129
- positivePrompt: params.positivePrompt,
130
- stylePrompt: params.stylePrompt,
131
- hasStartingImage: !!params.startingImage,
132
- hasContextImage1: !!params.contextImages?.[0],
133
- hasContextImage2: !!params.contextImages?.[1],
134
- strengthIsEnabled: !!params.startingImage,
135
- strength: !!params.startingImage
136
- ? 1 - (Number(params.startingImageStrength) || 0.5)
137
- : undefined,
138
- sizePreset: params.sizePreset
139
- }
140
- ],
141
- previews: params.numberOfPreviews || 0,
142
- numberOfImages: params.numberOfImages,
264
+ keyFrames: [keyFrame],
265
+ previews: isImageParams(params) ? params.numberOfPreviews || 0 : 0,
266
+ numberOfImages: params.numberOfMedia || 1,
143
267
  jobID: id,
144
268
  disableSafety: !!params.disableNSFWFilter,
145
269
  tokenType: params.tokenType,
146
- outputFormat: params.outputFormat || 'png'
270
+ outputFormat: params.outputFormat || (isVideoParams(params) ? 'mp4' : 'png')
147
271
  };
272
+
148
273
  if (params.network) {
149
274
  jobRequest.network = params.network;
150
275
  }
151
- if (params.controlNet) {
152
- jobRequest.keyFrames[0].currentControlNetsJob = getControlNet(params.controlNet);
153
- }
154
- if (params.sizePreset === 'custom') {
155
- jobRequest.keyFrames[0].width = validateCustomImageSize(params.width);
156
- jobRequest.keyFrames[0].height = validateCustomImageSize(params.height);
157
- }
276
+
158
277
  return jobRequest;
159
278
  }
160
279