@sogni-ai/sogni-client 4.0.0-alpha.20 → 4.0.0-alpha.22
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +23 -0
- package/README.md +26 -15
- package/dist/Account/index.d.ts +15 -15
- package/dist/Account/index.js +15 -15
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js +2 -2
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js +0 -4
- package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/events.d.ts +10 -0
- package/dist/ApiClient/WebSocketClient/index.js +12 -2
- package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
- package/dist/ApiClient/index.js +1 -1
- package/dist/ApiClient/index.js.map +1 -1
- package/dist/Projects/Job.d.ts +12 -3
- package/dist/Projects/Job.js +50 -16
- package/dist/Projects/Job.js.map +1 -1
- package/dist/Projects/Project.d.ts +1 -0
- package/dist/Projects/Project.js +10 -3
- package/dist/Projects/Project.js.map +1 -1
- package/dist/Projects/createJobRequestMessage.js +105 -12
- package/dist/Projects/createJobRequestMessage.js.map +1 -1
- package/dist/Projects/index.d.ts +74 -5
- package/dist/Projects/index.js +337 -33
- package/dist/Projects/index.js.map +1 -1
- package/dist/Projects/types/events.d.ts +5 -1
- package/dist/Projects/types/index.d.ts +113 -28
- package/dist/Projects/types/index.js +8 -0
- package/dist/Projects/types/index.js.map +1 -1
- package/dist/Projects/utils.d.ts +19 -1
- package/dist/Projects/utils.js +68 -0
- package/dist/Projects/utils.js.map +1 -1
- package/dist/index.d.ts +2 -2
- package/dist/index.js.map +1 -1
- package/dist/lib/AuthManager/TokenAuthManager.js +0 -2
- package/dist/lib/AuthManager/TokenAuthManager.js.map +1 -1
- package/package.json +1 -1
- package/src/Account/index.ts +15 -15
- package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.ts +2 -2
- package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/index.ts +0 -4
- package/src/ApiClient/WebSocketClient/events.ts +11 -0
- package/src/ApiClient/WebSocketClient/index.ts +12 -2
- package/src/ApiClient/index.ts +1 -1
- package/src/Projects/Job.ts +50 -16
- package/src/Projects/Project.ts +12 -6
- package/src/Projects/createJobRequestMessage.ts +143 -33
- package/src/Projects/index.ts +351 -33
- package/src/Projects/types/events.ts +6 -0
- package/src/Projects/types/index.ts +141 -30
- package/src/Projects/utils.ts +66 -1
- package/src/index.ts +16 -4
- package/src/lib/AuthManager/TokenAuthManager.ts +0 -2
package/src/Projects/Job.ts
CHANGED
|
@@ -19,7 +19,7 @@ export const enhancementDefaults = {
|
|
|
19
19
|
startingImageStrength: 0.5,
|
|
20
20
|
steps: 5,
|
|
21
21
|
guidance: 1,
|
|
22
|
-
|
|
22
|
+
numberOfMedia: 1,
|
|
23
23
|
numberOfPreviews: 0
|
|
24
24
|
};
|
|
25
25
|
|
|
@@ -178,10 +178,21 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
178
178
|
return this.data.error;
|
|
179
179
|
}
|
|
180
180
|
|
|
181
|
-
|
|
181
|
+
/**
|
|
182
|
+
* Whether this job has a result media file available for download.
|
|
183
|
+
* Returns true if completed and not NSFW filtered.
|
|
184
|
+
*/
|
|
185
|
+
get hasResultMedia() {
|
|
182
186
|
return this.status === 'completed' && !this.isNSFW;
|
|
183
187
|
}
|
|
184
188
|
|
|
189
|
+
/**
|
|
190
|
+
* Whether this job produces video output (based on the model used)
|
|
191
|
+
*/
|
|
192
|
+
get type(): 'image' | 'video' {
|
|
193
|
+
return this._api.isVideoModelId(this._project.params.modelId) ? 'video' : 'image';
|
|
194
|
+
}
|
|
195
|
+
|
|
185
196
|
get enhancedImage() {
|
|
186
197
|
if (!this._enhancementProject) {
|
|
187
198
|
return null;
|
|
@@ -199,17 +210,27 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
199
210
|
|
|
200
211
|
/**
|
|
201
212
|
* Get the result URL of the job. This method will make a request to the API to get signed URL.
|
|
202
|
-
* IMPORTANT: URL expires after 30 minutes, so make sure to download the
|
|
213
|
+
* IMPORTANT: URL expires after 30 minutes, so make sure to download the result as soon as possible.
|
|
214
|
+
* For video jobs, this returns a video URL. For image jobs, this returns an image URL.
|
|
203
215
|
*/
|
|
204
216
|
async getResultUrl(): Promise<string> {
|
|
205
217
|
if (this.data.status !== 'completed') {
|
|
206
218
|
throw new Error('Job is not completed yet');
|
|
207
219
|
}
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
220
|
+
let url: string;
|
|
221
|
+
if (this.type === 'video') {
|
|
222
|
+
url = await this._api.mediaDownloadUrl({
|
|
223
|
+
jobId: this.projectId,
|
|
224
|
+
id: this.id,
|
|
225
|
+
type: 'complete'
|
|
226
|
+
});
|
|
227
|
+
} else {
|
|
228
|
+
url = await this._api.downloadUrl({
|
|
229
|
+
jobId: this.projectId,
|
|
230
|
+
imageId: this.id,
|
|
231
|
+
type: 'complete'
|
|
232
|
+
});
|
|
233
|
+
}
|
|
213
234
|
this._update({ resultUrl: url });
|
|
214
235
|
return url;
|
|
215
236
|
}
|
|
@@ -247,11 +268,19 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
247
268
|
}
|
|
248
269
|
if (!this.data.resultUrl && delta.status === 'completed' && !data.triggeredNSFWFilter) {
|
|
249
270
|
try {
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
271
|
+
if (this.type === 'video') {
|
|
272
|
+
delta.resultUrl = await this._api.mediaDownloadUrl({
|
|
273
|
+
jobId: this.projectId,
|
|
274
|
+
id: this.id,
|
|
275
|
+
type: 'complete'
|
|
276
|
+
});
|
|
277
|
+
} else {
|
|
278
|
+
delta.resultUrl = await this._api.downloadUrl({
|
|
279
|
+
jobId: this.projectId,
|
|
280
|
+
imageId: this.id,
|
|
281
|
+
type: 'complete'
|
|
282
|
+
});
|
|
283
|
+
}
|
|
255
284
|
} catch (error) {
|
|
256
285
|
this._logger.error(error);
|
|
257
286
|
}
|
|
@@ -276,8 +305,8 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
276
305
|
}
|
|
277
306
|
|
|
278
307
|
async getResultData() {
|
|
279
|
-
if (!this.
|
|
280
|
-
throw new Error('No result
|
|
308
|
+
if (!this.hasResultMedia) {
|
|
309
|
+
throw new Error('No result media available');
|
|
281
310
|
}
|
|
282
311
|
const url = await this.getResultUrl();
|
|
283
312
|
const response = await fetch(url);
|
|
@@ -297,6 +326,10 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
297
326
|
strength: EnhancementStrength,
|
|
298
327
|
overrides: { positivePrompt?: string; stylePrompt?: string; tokenType?: TokenType } = {}
|
|
299
328
|
) {
|
|
329
|
+
const parentProjectParams = this._project.params;
|
|
330
|
+
if (parentProjectParams.type !== 'image') {
|
|
331
|
+
throw new Error('Enhancement is only available for images');
|
|
332
|
+
}
|
|
300
333
|
if (this.status !== 'completed') {
|
|
301
334
|
throw new Error('Job is not completed yet');
|
|
302
335
|
}
|
|
@@ -309,6 +342,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
309
342
|
}
|
|
310
343
|
const imageData = await this.getResultData();
|
|
311
344
|
const project = await this._api.create({
|
|
345
|
+
type: 'image',
|
|
312
346
|
...enhancementDefaults,
|
|
313
347
|
positivePrompt: overrides.positivePrompt || this._project.params.positivePrompt,
|
|
314
348
|
stylePrompt: overrides.stylePrompt || this._project.params.stylePrompt,
|
|
@@ -316,7 +350,7 @@ class Job extends DataEntity<JobData, JobEventMap> {
|
|
|
316
350
|
seed: this.seed || this._project.params.seed,
|
|
317
351
|
startingImage: imageData,
|
|
318
352
|
startingImageStrength: 1 - getEnhacementStrength(strength),
|
|
319
|
-
sizePreset:
|
|
353
|
+
sizePreset: parentProjectParams.sizePreset
|
|
320
354
|
});
|
|
321
355
|
this._enhancementProject = project;
|
|
322
356
|
this._enhancementProject.on('updated', this.handleEnhancementUpdate);
|
package/src/Projects/Project.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import Job, { JobData } from './Job';
|
|
2
2
|
import DataEntity, { EntityEvents } from '../lib/DataEntity';
|
|
3
|
-
import { ProjectParams } from './types';
|
|
3
|
+
import { isImageParams, ProjectParams } from './types';
|
|
4
4
|
import cloneDeep from 'lodash/cloneDeep';
|
|
5
5
|
import ErrorData from '../types/ErrorData';
|
|
6
6
|
import getUUID from '../lib/getUUID';
|
|
@@ -93,6 +93,10 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
93
93
|
return this.data.params;
|
|
94
94
|
}
|
|
95
95
|
|
|
96
|
+
get type() {
|
|
97
|
+
return this.params.type;
|
|
98
|
+
}
|
|
99
|
+
|
|
96
100
|
get status() {
|
|
97
101
|
return this.data.status;
|
|
98
102
|
}
|
|
@@ -111,7 +115,7 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
111
115
|
get progress() {
|
|
112
116
|
// Worker can reduce the number of steps in the job, so we need to calculate the progress based on the actual number of steps
|
|
113
117
|
const stepsPerJob = this.jobs.length ? this.jobs[0].stepCount : this.data.params.steps;
|
|
114
|
-
const jobCount = this.data.params.
|
|
118
|
+
const jobCount = this.data.params.numberOfMedia;
|
|
115
119
|
const stepsDone = this._jobs.reduce((acc, job) => acc + job.step, 0);
|
|
116
120
|
return Math.round((stepsDone / (stepsPerJob * jobCount)) * 100);
|
|
117
121
|
}
|
|
@@ -192,7 +196,7 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
192
196
|
this._timeout = null;
|
|
193
197
|
}
|
|
194
198
|
if (keys.includes('status') || keys.includes('jobs')) {
|
|
195
|
-
const allJobsStarted = this.jobs.length >= this.params.
|
|
199
|
+
const allJobsStarted = this.jobs.length >= this.params.numberOfMedia;
|
|
196
200
|
const allJobsDone = this.jobs.every((job) => job.finished);
|
|
197
201
|
if (this.data.status === 'completed' && allJobsStarted && allJobsDone) {
|
|
198
202
|
return this.emit('completed', this.resultUrls);
|
|
@@ -298,11 +302,13 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
298
302
|
const delta: Partial<ProjectData> = {
|
|
299
303
|
params: {
|
|
300
304
|
...this.data.params,
|
|
301
|
-
|
|
302
|
-
steps: data.stepCount
|
|
303
|
-
numberOfPreviews: data.previewCount
|
|
305
|
+
numberOfMedia: data.imageCount,
|
|
306
|
+
steps: data.stepCount
|
|
304
307
|
}
|
|
305
308
|
};
|
|
309
|
+
if (delta.params && isImageParams(delta.params)) {
|
|
310
|
+
delta.params.numberOfPreviews = data.previewCount;
|
|
311
|
+
}
|
|
306
312
|
if (PROJECT_STATUS_MAP[data.status]) {
|
|
307
313
|
delta.status = PROJECT_STATUS_MAP[data.status];
|
|
308
314
|
}
|
package/src/Projects/createJobRequestMessage.ts
CHANGED
|
@@ -1,4 +1,10 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import {
|
|
2
|
+
ImageProjectParams,
|
|
3
|
+
isImageParams,
|
|
4
|
+
isVideoParams,
|
|
5
|
+
ProjectParams,
|
|
6
|
+
VideoProjectParams
|
|
7
|
+
} from './types';
|
|
2
8
|
import { ControlNetParams, ControlNetParamsRaw } from './types/ControlNetParams';
|
|
3
9
|
import {
|
|
4
10
|
validateNumber,
|
|
@@ -6,6 +12,41 @@ import {
|
|
|
6
12
|
validateSampler,
|
|
7
13
|
validateScheduler
|
|
8
14
|
} from '../lib/validation';
|
|
15
|
+
import { getVideoWorkflowType, isVideoModel, VIDEO_WORKFLOW_ASSETS } from './utils';
|
|
16
|
+
import { ApiError } from '../ApiClient';
|
|
17
|
+
|
|
18
|
+
/**
|
|
19
|
+
* Validate that the provided assets match the workflow requirements.
|
|
20
|
+
* Throws an error if required assets are missing or forbidden assets are provided.
|
|
21
|
+
*/
|
|
22
|
+
function validateVideoWorkflowAssets(params: VideoProjectParams): void {
|
|
23
|
+
const workflowType = getVideoWorkflowType(params.modelId);
|
|
24
|
+
if (!workflowType) return;
|
|
25
|
+
|
|
26
|
+
const requirements = VIDEO_WORKFLOW_ASSETS[workflowType];
|
|
27
|
+
if (!requirements) return;
|
|
28
|
+
// Check for missing required assets
|
|
29
|
+
for (const [asset, requirement] of Object.entries(requirements)) {
|
|
30
|
+
const assetKey = asset as keyof VideoProjectParams;
|
|
31
|
+
const hasAsset = !!params[assetKey];
|
|
32
|
+
|
|
33
|
+
if (requirement === 'required' && !hasAsset) {
|
|
34
|
+
throw new ApiError(400, {
|
|
35
|
+
status: 'error',
|
|
36
|
+
errorCode: 0,
|
|
37
|
+
message: `${workflowType} workflow requires ${assetKey}. Please provide this asset.`
|
|
38
|
+
});
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
if (requirement === 'forbidden' && hasAsset) {
|
|
42
|
+
throw new ApiError(400, {
|
|
43
|
+
status: 'error',
|
|
44
|
+
errorCode: 0,
|
|
45
|
+
message: `${workflowType} workflow does not support ${assetKey}. Please remove this asset.`
|
|
46
|
+
});
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
}
|
|
9
50
|
|
|
10
51
|
// Mac worker can't process the data if some of the fields are missing, so we need to provide a default template
|
|
11
52
|
function getTemplate() {
|
|
@@ -118,49 +159,118 @@ function getControlNet(params: ControlNetParams): ControlNetParamsRaw[] {
|
|
|
118
159
|
return [cn];
|
|
119
160
|
}
|
|
120
161
|
|
|
162
|
+
function applyImageParams(inputKeyframe: Record<string, any>, params: ImageProjectParams) {
|
|
163
|
+
const keyFrame: Record<string, any> = {
|
|
164
|
+
...inputKeyframe,
|
|
165
|
+
scheduler: validateSampler(params.sampler),
|
|
166
|
+
timeStepSpacing: validateScheduler(params.scheduler),
|
|
167
|
+
sizePreset: params.sizePreset,
|
|
168
|
+
hasContextImage1: !!params.contextImages?.[0],
|
|
169
|
+
hasContextImage2: !!params.contextImages?.[1]
|
|
170
|
+
};
|
|
171
|
+
|
|
172
|
+
if (params.startingImage) {
|
|
173
|
+
keyFrame.hasStartingImage = true;
|
|
174
|
+
keyFrame.strengthIsEnabled = true;
|
|
175
|
+
keyFrame.strength = 1 - (Number(params.startingImageStrength) || 0.5);
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
if (params.controlNet) {
|
|
179
|
+
keyFrame.currentControlNetsJob = getControlNet(params.controlNet);
|
|
180
|
+
}
|
|
181
|
+
if (params.sizePreset === 'custom') {
|
|
182
|
+
keyFrame.width = validateCustomImageSize(params.width);
|
|
183
|
+
keyFrame.height = validateCustomImageSize(params.height);
|
|
184
|
+
}
|
|
185
|
+
return keyFrame;
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
function applyVideoParams(inputKeyframe: Record<string, any>, params: VideoProjectParams) {
|
|
189
|
+
if (!isVideoModel(params.modelId)) {
|
|
190
|
+
throw new ApiError(400, {
|
|
191
|
+
status: 'error',
|
|
192
|
+
errorCode: 0,
|
|
193
|
+
message: 'Video generation is only supported for video models.'
|
|
194
|
+
});
|
|
195
|
+
}
|
|
196
|
+
validateVideoWorkflowAssets(params);
|
|
197
|
+
const keyFrame: Record<string, any> = { ...inputKeyframe };
|
|
198
|
+
if (params.referenceImage) {
|
|
199
|
+
keyFrame.hasReferenceImage = true;
|
|
200
|
+
}
|
|
201
|
+
if (params.referenceImageEnd) {
|
|
202
|
+
keyFrame.hasReferenceImageEnd = true;
|
|
203
|
+
}
|
|
204
|
+
if (params.referenceAudio) {
|
|
205
|
+
keyFrame.hasReferenceAudio = true;
|
|
206
|
+
}
|
|
207
|
+
if (params.referenceVideo) {
|
|
208
|
+
keyFrame.hasReferenceVideo = true;
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
// Video generation parameters
|
|
212
|
+
if (params.frames !== undefined) {
|
|
213
|
+
keyFrame.frames = params.frames;
|
|
214
|
+
}
|
|
215
|
+
if (params.fps !== undefined) {
|
|
216
|
+
keyFrame.fps = params.fps;
|
|
217
|
+
}
|
|
218
|
+
if (params.shift !== undefined) {
|
|
219
|
+
keyFrame.shift = params.shift;
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
if (params.width && params.height) {
|
|
223
|
+
keyFrame.width = params.width;
|
|
224
|
+
keyFrame.height = params.height;
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
return keyFrame;
|
|
228
|
+
}
|
|
229
|
+
|
|
121
230
|
function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
122
231
|
const template = getTemplate();
|
|
232
|
+
// Base keyFrame with common params
|
|
233
|
+
let keyFrame: Record<string, any> = {
|
|
234
|
+
...template.keyFrames[0],
|
|
235
|
+
steps: params.steps,
|
|
236
|
+
guidanceScale: params.guidance,
|
|
237
|
+
modelID: params.modelId,
|
|
238
|
+
negativePrompt: params.negativePrompt,
|
|
239
|
+
seed: params.seed,
|
|
240
|
+
positivePrompt: params.positivePrompt,
|
|
241
|
+
stylePrompt: params.stylePrompt
|
|
242
|
+
};
|
|
243
|
+
|
|
244
|
+
switch (params.type) {
|
|
245
|
+
case 'image':
|
|
246
|
+
keyFrame = applyImageParams(keyFrame, params);
|
|
247
|
+
break;
|
|
248
|
+
case 'video':
|
|
249
|
+
keyFrame = applyVideoParams(keyFrame, params);
|
|
250
|
+
break;
|
|
251
|
+
default:
|
|
252
|
+
throw new ApiError(400, {
|
|
253
|
+
status: 'error',
|
|
254
|
+
errorCode: 0,
|
|
255
|
+
message: 'Invalid project type. Must be "image" or "video".'
|
|
256
|
+
});
|
|
257
|
+
}
|
|
258
|
+
|
|
123
259
|
const jobRequest: Record<string, any> = {
|
|
124
260
|
...template,
|
|
125
|
-
keyFrames: [
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
scheduler: validateSampler(params.sampler),
|
|
129
|
-
timeStepSpacing: validateScheduler(params.scheduler),
|
|
130
|
-
steps: params.steps,
|
|
131
|
-
guidanceScale: params.guidance,
|
|
132
|
-
modelID: params.modelId,
|
|
133
|
-
negativePrompt: params.negativePrompt,
|
|
134
|
-
seed: params.seed,
|
|
135
|
-
positivePrompt: params.positivePrompt,
|
|
136
|
-
stylePrompt: params.stylePrompt,
|
|
137
|
-
hasStartingImage: !!params.startingImage,
|
|
138
|
-
hasContextImage1: !!params.contextImages?.[0],
|
|
139
|
-
hasContextImage2: !!params.contextImages?.[1],
|
|
140
|
-
strengthIsEnabled: !!params.startingImage,
|
|
141
|
-
strength: !!params.startingImage
|
|
142
|
-
? 1 - (Number(params.startingImageStrength) || 0.5)
|
|
143
|
-
: undefined,
|
|
144
|
-
sizePreset: params.sizePreset
|
|
145
|
-
}
|
|
146
|
-
],
|
|
147
|
-
previews: params.numberOfPreviews || 0,
|
|
148
|
-
numberOfImages: params.numberOfImages,
|
|
261
|
+
keyFrames: [keyFrame],
|
|
262
|
+
previews: isImageParams(params) ? params.numberOfPreviews || 0 : 0,
|
|
263
|
+
numberOfImages: params.numberOfMedia || 1,
|
|
149
264
|
jobID: id,
|
|
150
265
|
disableSafety: !!params.disableNSFWFilter,
|
|
151
266
|
tokenType: params.tokenType,
|
|
152
|
-
outputFormat: params.outputFormat || 'png'
|
|
267
|
+
outputFormat: params.outputFormat || (isVideoParams(params) ? 'mp4' : 'png')
|
|
153
268
|
};
|
|
269
|
+
|
|
154
270
|
if (params.network) {
|
|
155
271
|
jobRequest.network = params.network;
|
|
156
272
|
}
|
|
157
|
-
|
|
158
|
-
jobRequest.keyFrames[0].currentControlNetsJob = getControlNet(params.controlNet);
|
|
159
|
-
}
|
|
160
|
-
if (params.sizePreset === 'custom') {
|
|
161
|
-
jobRequest.keyFrames[0].width = validateCustomImageSize(params.width);
|
|
162
|
-
jobRequest.keyFrames[0].height = validateCustomImageSize(params.height);
|
|
163
|
-
}
|
|
273
|
+
|
|
164
274
|
return jobRequest;
|
|
165
275
|
}
|
|
166
276
|
|