@sogni-ai/sogni-client 4.0.0-alpha.21 → 4.0.0-alpha.22

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/CHANGELOG.md +16 -0
  2. package/README.md +26 -15
  3. package/dist/Account/index.d.ts +15 -15
  4. package/dist/Account/index.js +15 -15
  5. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js +0 -4
  6. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js.map +1 -1
  7. package/dist/ApiClient/WebSocketClient/events.d.ts +10 -0
  8. package/dist/ApiClient/WebSocketClient/index.js +12 -2
  9. package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
  10. package/dist/ApiClient/index.js +1 -1
  11. package/dist/ApiClient/index.js.map +1 -1
  12. package/dist/Projects/Job.d.ts +12 -3
  13. package/dist/Projects/Job.js +50 -16
  14. package/dist/Projects/Job.js.map +1 -1
  15. package/dist/Projects/Project.d.ts +1 -0
  16. package/dist/Projects/Project.js +10 -3
  17. package/dist/Projects/Project.js.map +1 -1
  18. package/dist/Projects/createJobRequestMessage.js +105 -12
  19. package/dist/Projects/createJobRequestMessage.js.map +1 -1
  20. package/dist/Projects/index.d.ts +74 -5
  21. package/dist/Projects/index.js +337 -33
  22. package/dist/Projects/index.js.map +1 -1
  23. package/dist/Projects/types/events.d.ts +5 -1
  24. package/dist/Projects/types/index.d.ts +113 -28
  25. package/dist/Projects/types/index.js +8 -0
  26. package/dist/Projects/types/index.js.map +1 -1
  27. package/dist/Projects/utils.d.ts +19 -1
  28. package/dist/Projects/utils.js +68 -0
  29. package/dist/Projects/utils.js.map +1 -1
  30. package/dist/index.d.ts +2 -2
  31. package/dist/index.js.map +1 -1
  32. package/dist/lib/AuthManager/TokenAuthManager.js +0 -2
  33. package/dist/lib/AuthManager/TokenAuthManager.js.map +1 -1
  34. package/package.json +1 -1
  35. package/src/Account/index.ts +15 -15
  36. package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/index.ts +0 -4
  37. package/src/ApiClient/WebSocketClient/events.ts +11 -0
  38. package/src/ApiClient/WebSocketClient/index.ts +12 -2
  39. package/src/ApiClient/index.ts +1 -1
  40. package/src/Projects/Job.ts +50 -16
  41. package/src/Projects/Project.ts +12 -6
  42. package/src/Projects/createJobRequestMessage.ts +143 -33
  43. package/src/Projects/index.ts +351 -33
  44. package/src/Projects/types/events.ts +6 -0
  45. package/src/Projects/types/index.ts +141 -30
  46. package/src/Projects/utils.ts +66 -1
  47. package/src/index.ts +16 -4
  48. package/src/lib/AuthManager/TokenAuthManager.ts +0 -2
@@ -19,7 +19,7 @@ export const enhancementDefaults = {
19
19
  startingImageStrength: 0.5,
20
20
  steps: 5,
21
21
  guidance: 1,
22
- numberOfImages: 1,
22
+ numberOfMedia: 1,
23
23
  numberOfPreviews: 0
24
24
  };
25
25
 
@@ -178,10 +178,21 @@ class Job extends DataEntity<JobData, JobEventMap> {
178
178
  return this.data.error;
179
179
  }
180
180
 
181
- get hasResultImage() {
181
+ /**
182
+ * Whether this job has a result media file available for download.
183
+ * Returns true if completed and not NSFW filtered.
184
+ */
185
+ get hasResultMedia() {
182
186
  return this.status === 'completed' && !this.isNSFW;
183
187
  }
184
188
 
189
+ /**
190
+ * Whether this job produces video output (based on the model used)
191
+ */
192
+ get type(): 'image' | 'video' {
193
+ return this._api.isVideoModelId(this._project.params.modelId) ? 'video' : 'image';
194
+ }
195
+
185
196
  get enhancedImage() {
186
197
  if (!this._enhancementProject) {
187
198
  return null;
@@ -199,17 +210,27 @@ class Job extends DataEntity<JobData, JobEventMap> {
199
210
 
200
211
  /**
201
212
  * Get the result URL of the job. This method will make a request to the API to get signed URL.
202
- * IMPORTANT: URL expires after 30 minutes, so make sure to download the image as soon as possible.
213
+ * IMPORTANT: URL expires after 30 minutes, so make sure to download the result as soon as possible.
214
+ * For video jobs, this returns a video URL. For image jobs, this returns an image URL.
203
215
  */
204
216
  async getResultUrl(): Promise<string> {
205
217
  if (this.data.status !== 'completed') {
206
218
  throw new Error('Job is not completed yet');
207
219
  }
208
- const url = await this._api.downloadUrl({
209
- jobId: this.projectId,
210
- imageId: this.id,
211
- type: 'complete'
212
- });
220
+ let url: string;
221
+ if (this.type === 'video') {
222
+ url = await this._api.mediaDownloadUrl({
223
+ jobId: this.projectId,
224
+ id: this.id,
225
+ type: 'complete'
226
+ });
227
+ } else {
228
+ url = await this._api.downloadUrl({
229
+ jobId: this.projectId,
230
+ imageId: this.id,
231
+ type: 'complete'
232
+ });
233
+ }
213
234
  this._update({ resultUrl: url });
214
235
  return url;
215
236
  }
@@ -247,11 +268,19 @@ class Job extends DataEntity<JobData, JobEventMap> {
247
268
  }
248
269
  if (!this.data.resultUrl && delta.status === 'completed' && !data.triggeredNSFWFilter) {
249
270
  try {
250
- delta.resultUrl = await this._api.downloadUrl({
251
- jobId: this.projectId,
252
- imageId: this.id,
253
- type: 'complete'
254
- });
271
+ if (this.type === 'video') {
272
+ delta.resultUrl = await this._api.mediaDownloadUrl({
273
+ jobId: this.projectId,
274
+ id: this.id,
275
+ type: 'complete'
276
+ });
277
+ } else {
278
+ delta.resultUrl = await this._api.downloadUrl({
279
+ jobId: this.projectId,
280
+ imageId: this.id,
281
+ type: 'complete'
282
+ });
283
+ }
255
284
  } catch (error) {
256
285
  this._logger.error(error);
257
286
  }
@@ -276,8 +305,8 @@ class Job extends DataEntity<JobData, JobEventMap> {
276
305
  }
277
306
 
278
307
  async getResultData() {
279
- if (!this.hasResultImage) {
280
- throw new Error('No result image available');
308
+ if (!this.hasResultMedia) {
309
+ throw new Error('No result media available');
281
310
  }
282
311
  const url = await this.getResultUrl();
283
312
  const response = await fetch(url);
@@ -297,6 +326,10 @@ class Job extends DataEntity<JobData, JobEventMap> {
297
326
  strength: EnhancementStrength,
298
327
  overrides: { positivePrompt?: string; stylePrompt?: string; tokenType?: TokenType } = {}
299
328
  ) {
329
+ const parentProjectParams = this._project.params;
330
+ if (parentProjectParams.type !== 'image') {
331
+ throw new Error('Enhancement is only available for images');
332
+ }
300
333
  if (this.status !== 'completed') {
301
334
  throw new Error('Job is not completed yet');
302
335
  }
@@ -309,6 +342,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
309
342
  }
310
343
  const imageData = await this.getResultData();
311
344
  const project = await this._api.create({
345
+ type: 'image',
312
346
  ...enhancementDefaults,
313
347
  positivePrompt: overrides.positivePrompt || this._project.params.positivePrompt,
314
348
  stylePrompt: overrides.stylePrompt || this._project.params.stylePrompt,
@@ -316,7 +350,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
316
350
  seed: this.seed || this._project.params.seed,
317
351
  startingImage: imageData,
318
352
  startingImageStrength: 1 - getEnhacementStrength(strength),
319
- sizePreset: this._project.params.sizePreset
353
+ sizePreset: parentProjectParams.sizePreset
320
354
  });
321
355
  this._enhancementProject = project;
322
356
  this._enhancementProject.on('updated', this.handleEnhancementUpdate);
@@ -1,6 +1,6 @@
1
1
  import Job, { JobData } from './Job';
2
2
  import DataEntity, { EntityEvents } from '../lib/DataEntity';
3
- import { ProjectParams } from './types';
3
+ import { isImageParams, ProjectParams } from './types';
4
4
  import cloneDeep from 'lodash/cloneDeep';
5
5
  import ErrorData from '../types/ErrorData';
6
6
  import getUUID from '../lib/getUUID';
@@ -93,6 +93,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
93
93
  return this.data.params;
94
94
  }
95
95
 
96
+ get type() {
97
+ return this.params.type;
98
+ }
99
+
96
100
  get status() {
97
101
  return this.data.status;
98
102
  }
@@ -111,7 +115,7 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
111
115
  get progress() {
112
116
  // Worker can reduce the number of steps in the job, so we need to calculate the progress based on the actual number of steps
113
117
  const stepsPerJob = this.jobs.length ? this.jobs[0].stepCount : this.data.params.steps;
114
- const jobCount = this.data.params.numberOfImages;
118
+ const jobCount = this.data.params.numberOfMedia;
115
119
  const stepsDone = this._jobs.reduce((acc, job) => acc + job.step, 0);
116
120
  return Math.round((stepsDone / (stepsPerJob * jobCount)) * 100);
117
121
  }
@@ -192,7 +196,7 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
192
196
  this._timeout = null;
193
197
  }
194
198
  if (keys.includes('status') || keys.includes('jobs')) {
195
- const allJobsStarted = this.jobs.length >= this.params.numberOfImages;
199
+ const allJobsStarted = this.jobs.length >= this.params.numberOfMedia;
196
200
  const allJobsDone = this.jobs.every((job) => job.finished);
197
201
  if (this.data.status === 'completed' && allJobsStarted && allJobsDone) {
198
202
  return this.emit('completed', this.resultUrls);
@@ -298,11 +302,13 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
298
302
  const delta: Partial<ProjectData> = {
299
303
  params: {
300
304
  ...this.data.params,
301
- numberOfImages: data.imageCount,
302
- steps: data.stepCount,
303
- numberOfPreviews: data.previewCount
305
+ numberOfMedia: data.imageCount,
306
+ steps: data.stepCount
304
307
  }
305
308
  };
309
+ if (delta.params && isImageParams(delta.params)) {
310
+ delta.params.numberOfPreviews = data.previewCount;
311
+ }
306
312
  if (PROJECT_STATUS_MAP[data.status]) {
307
313
  delta.status = PROJECT_STATUS_MAP[data.status];
308
314
  }
@@ -1,4 +1,10 @@
1
- import { ProjectParams } from './types';
1
+ import {
2
+ ImageProjectParams,
3
+ isImageParams,
4
+ isVideoParams,
5
+ ProjectParams,
6
+ VideoProjectParams
7
+ } from './types';
2
8
  import { ControlNetParams, ControlNetParamsRaw } from './types/ControlNetParams';
3
9
  import {
4
10
  validateNumber,
@@ -6,6 +12,41 @@ import {
6
12
  validateSampler,
7
13
  validateScheduler
8
14
  } from '../lib/validation';
15
+ import { getVideoWorkflowType, isVideoModel, VIDEO_WORKFLOW_ASSETS } from './utils';
16
+ import { ApiError } from '../ApiClient';
17
+
18
+ /**
19
+ * Validate that the provided assets match the workflow requirements.
20
+ * Throws an error if required assets are missing or forbidden assets are provided.
21
+ */
22
+ function validateVideoWorkflowAssets(params: VideoProjectParams): void {
23
+ const workflowType = getVideoWorkflowType(params.modelId);
24
+ if (!workflowType) return;
25
+
26
+ const requirements = VIDEO_WORKFLOW_ASSETS[workflowType];
27
+ if (!requirements) return;
28
+ // Check for missing required assets
29
+ for (const [asset, requirement] of Object.entries(requirements)) {
30
+ const assetKey = asset as keyof VideoProjectParams;
31
+ const hasAsset = !!params[assetKey];
32
+
33
+ if (requirement === 'required' && !hasAsset) {
34
+ throw new ApiError(400, {
35
+ status: 'error',
36
+ errorCode: 0,
37
+ message: `${workflowType} workflow requires ${assetKey}. Please provide this asset.`
38
+ });
39
+ }
40
+
41
+ if (requirement === 'forbidden' && hasAsset) {
42
+ throw new ApiError(400, {
43
+ status: 'error',
44
+ errorCode: 0,
45
+ message: `${workflowType} workflow does not support ${assetKey}. Please remove this asset.`
46
+ });
47
+ }
48
+ }
49
+ }
9
50
 
10
51
  // Mac worker can't process the data if some of the fields are missing, so we need to provide a default template
11
52
  function getTemplate() {
@@ -118,49 +159,118 @@ function getControlNet(params: ControlNetParams): ControlNetParamsRaw[] {
118
159
  return [cn];
119
160
  }
120
161
 
162
+ function applyImageParams(inputKeyframe: Record<string, any>, params: ImageProjectParams) {
163
+ const keyFrame: Record<string, any> = {
164
+ ...inputKeyframe,
165
+ scheduler: validateSampler(params.sampler),
166
+ timeStepSpacing: validateScheduler(params.scheduler),
167
+ sizePreset: params.sizePreset,
168
+ hasContextImage1: !!params.contextImages?.[0],
169
+ hasContextImage2: !!params.contextImages?.[1]
170
+ };
171
+
172
+ if (params.startingImage) {
173
+ keyFrame.hasStartingImage = true;
174
+ keyFrame.strengthIsEnabled = true;
175
+ keyFrame.strength = 1 - (Number(params.startingImageStrength) || 0.5);
176
+ }
177
+
178
+ if (params.controlNet) {
179
+ keyFrame.currentControlNetsJob = getControlNet(params.controlNet);
180
+ }
181
+ if (params.sizePreset === 'custom') {
182
+ keyFrame.width = validateCustomImageSize(params.width);
183
+ keyFrame.height = validateCustomImageSize(params.height);
184
+ }
185
+ return keyFrame;
186
+ }
187
+
188
+ function applyVideoParams(inputKeyframe: Record<string, any>, params: VideoProjectParams) {
189
+ if (!isVideoModel(params.modelId)) {
190
+ throw new ApiError(400, {
191
+ status: 'error',
192
+ errorCode: 0,
193
+ message: 'Video generation is only supported for video models.'
194
+ });
195
+ }
196
+ validateVideoWorkflowAssets(params);
197
+ const keyFrame: Record<string, any> = { ...inputKeyframe };
198
+ if (params.referenceImage) {
199
+ keyFrame.hasReferenceImage = true;
200
+ }
201
+ if (params.referenceImageEnd) {
202
+ keyFrame.hasReferenceImageEnd = true;
203
+ }
204
+ if (params.referenceAudio) {
205
+ keyFrame.hasReferenceAudio = true;
206
+ }
207
+ if (params.referenceVideo) {
208
+ keyFrame.hasReferenceVideo = true;
209
+ }
210
+
211
+ // Video generation parameters
212
+ if (params.frames !== undefined) {
213
+ keyFrame.frames = params.frames;
214
+ }
215
+ if (params.fps !== undefined) {
216
+ keyFrame.fps = params.fps;
217
+ }
218
+ if (params.shift !== undefined) {
219
+ keyFrame.shift = params.shift;
220
+ }
221
+
222
+ if (params.width && params.height) {
223
+ keyFrame.width = params.width;
224
+ keyFrame.height = params.height;
225
+ }
226
+
227
+ return keyFrame;
228
+ }
229
+
121
230
  function createJobRequestMessage(id: string, params: ProjectParams) {
122
231
  const template = getTemplate();
232
+ // Base keyFrame with common params
233
+ let keyFrame: Record<string, any> = {
234
+ ...template.keyFrames[0],
235
+ steps: params.steps,
236
+ guidanceScale: params.guidance,
237
+ modelID: params.modelId,
238
+ negativePrompt: params.negativePrompt,
239
+ seed: params.seed,
240
+ positivePrompt: params.positivePrompt,
241
+ stylePrompt: params.stylePrompt
242
+ };
243
+
244
+ switch (params.type) {
245
+ case 'image':
246
+ keyFrame = applyImageParams(keyFrame, params);
247
+ break;
248
+ case 'video':
249
+ keyFrame = applyVideoParams(keyFrame, params);
250
+ break;
251
+ default:
252
+ throw new ApiError(400, {
253
+ status: 'error',
254
+ errorCode: 0,
255
+ message: 'Invalid project type. Must be "image" or "video".'
256
+ });
257
+ }
258
+
123
259
  const jobRequest: Record<string, any> = {
124
260
  ...template,
125
- keyFrames: [
126
- {
127
- ...template.keyFrames[0],
128
- scheduler: validateSampler(params.sampler),
129
- timeStepSpacing: validateScheduler(params.scheduler),
130
- steps: params.steps,
131
- guidanceScale: params.guidance,
132
- modelID: params.modelId,
133
- negativePrompt: params.negativePrompt,
134
- seed: params.seed,
135
- positivePrompt: params.positivePrompt,
136
- stylePrompt: params.stylePrompt,
137
- hasStartingImage: !!params.startingImage,
138
- hasContextImage1: !!params.contextImages?.[0],
139
- hasContextImage2: !!params.contextImages?.[1],
140
- strengthIsEnabled: !!params.startingImage,
141
- strength: !!params.startingImage
142
- ? 1 - (Number(params.startingImageStrength) || 0.5)
143
- : undefined,
144
- sizePreset: params.sizePreset
145
- }
146
- ],
147
- previews: params.numberOfPreviews || 0,
148
- numberOfImages: params.numberOfImages,
261
+ keyFrames: [keyFrame],
262
+ previews: isImageParams(params) ? params.numberOfPreviews || 0 : 0,
263
+ numberOfImages: params.numberOfMedia || 1,
149
264
  jobID: id,
150
265
  disableSafety: !!params.disableNSFWFilter,
151
266
  tokenType: params.tokenType,
152
- outputFormat: params.outputFormat || 'png'
267
+ outputFormat: params.outputFormat || (isVideoParams(params) ? 'mp4' : 'png')
153
268
  };
269
+
154
270
  if (params.network) {
155
271
  jobRequest.network = params.network;
156
272
  }
157
- if (params.controlNet) {
158
- jobRequest.keyFrames[0].currentControlNetsJob = getControlNet(params.controlNet);
159
- }
160
- if (params.sizePreset === 'custom') {
161
- jobRequest.keyFrames[0].width = validateCustomImageSize(params.width);
162
- jobRequest.keyFrames[0].height = validateCustomImageSize(params.height);
163
- }
273
+
164
274
  return jobRequest;
165
275
  }
166
276