@sogni-ai/sogni-client 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/dist/version.d.ts +1 -1
- package/dist/version.js +1 -1
- package/package.json +5 -3
- package/src/Account/CurrentAccount.ts +101 -0
- package/src/Account/index.ts +243 -0
- package/src/Account/types.ts +90 -0
- package/src/ApiClient/WebSocketClient/ErrorCode.ts +15 -0
- package/src/ApiClient/WebSocketClient/events.ts +94 -0
- package/src/ApiClient/WebSocketClient/index.ts +203 -0
- package/src/ApiClient/WebSocketClient/messages.ts +7 -0
- package/src/ApiClient/WebSocketClient/types.ts +1 -0
- package/src/ApiClient/events.ts +20 -0
- package/src/ApiClient/index.ts +124 -0
- package/src/ApiGroup.ts +25 -0
- package/src/Projects/Job.ts +124 -0
- package/src/Projects/Project.ts +185 -0
- package/src/Projects/createJobRequestMessage.ts +99 -0
- package/src/Projects/index.ts +340 -0
- package/src/Projects/models.json +8906 -0
- package/src/Projects/types/EstimationResponse.ts +45 -0
- package/src/Projects/types/events.ts +78 -0
- package/src/Projects/types/index.ts +146 -0
- package/src/Stats/index.ts +15 -0
- package/src/Stats/types.ts +34 -0
- package/src/events.ts +5 -0
- package/src/index.ts +120 -0
- package/src/lib/DataEntity.ts +38 -0
- package/src/lib/DefaultLogger.ts +47 -0
- package/src/lib/EIP712Helper.ts +57 -0
- package/src/lib/RestClient.ts +76 -0
- package/src/lib/TypedEventEmitter.ts +66 -0
- package/src/lib/base64.ts +9 -0
- package/src/lib/getUUID.ts +8 -0
- package/src/lib/isNodejs.ts +4 -0
- package/src/types/ErrorData.ts +6 -0
- package/src/types/json.ts +5 -0
- package/src/version.ts +1 -0
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import { ProjectParams } from './types';
|
|
2
|
+
/**
 * Fallback for `startingImageStrength` when the caller does not supply a usable numeric value.
 */
const DEFAULT_STARTING_IMAGE_STRENGTH = 0.5;

/**
 * Returns the baseline job request.
 *
 * The Mac worker can't process the data if some of the fields are missing, so every request starts
 * from this complete template and only overrides the fields the caller controls. A new object
 * (including new nested arrays) is built on every call, so results never share state.
 */
function getTemplate() {
  return {
    selectedUpscalingModel: 'OFF',
    cnVideoFramesSketch: [],
    cnVideoFramesSegmentedSubject: [],
    cnVideoFramesFace: [],
    doCanvasBlending: false,
    animationIsOn: false,
    cnVideoFramesBoth: [],
    cnVideoFramesDepth: [],
    keyFrames: [
      {
        stepsIsEnabled: true,
        siRotation: 0,
        siDragOffsetIsEnabled: true,
        strength: 0.5,
        siZoomScaleIsEnabled: true,
        isEnabled: true,
        processing: 'CPU, GPU',
        useLastImageAsGuideImageInAnimation: true,
        guidanceScaleIsEnabled: true,
        siImageBackgroundColor: 'black',
        cnDragOffset: [0, 0],
        scheduler: 'DPM Solver Multistep (DPM-Solver++)',
        timeStepSpacing: 'Linear',
        steps: 20,
        cnRotation: 0,
        guidanceScale: 7.5,
        siZoomScale: 1,
        modelID: '',
        cnRotationIsEnabled: true,
        negativePrompt: '',
        startingImageZoomPanIsOn: false,
        seed: undefined,
        siRotationIsEnabled: true,
        cnImageBackgroundColor: 'clear',
        strengthIsEnabled: true,
        siDragOffset: [0, 0],
        useLastImageAsCNImageInAnimation: false,
        positivePrompt: '',
        controlNetZoomPanIsOn: false,
        cnZoomScaleIsEnabled: true,
        currentControlNets: null,
        stylePrompt: '',
        cnDragOffsetIsEnabled: true,
        frameIndex: 0,
        startingImage: null,
        cnZoomScale: 1
      }
    ],
    previews: 5,
    frameRate: 24,
    generatedVideoSeconds: 10,
    canvasIsOn: false,
    cnVideoFrames: [],
    disableSafety: false,
    cnVideoFramesSegmentedBackground: [],
    cnVideoFramesSegmented: [],
    numberOfImages: 1,
    cnVideoFramesPose: [],
    jobID: '',
    siVideoFrames: []
  };
}

/**
 * Converts the caller's starting-image strength to a number.
 *
 * Falls back to {@link DEFAULT_STARTING_IMAGE_STRENGTH} only when the value is absent, empty or not
 * a finite number. An explicit `0` is kept; the previous `Number(x) || 0.5` silently replaced it.
 */
function resolveStartingImageStrength(value: unknown): number {
  if (value === undefined || value === null || value === '') {
    return DEFAULT_STARTING_IMAGE_STRENGTH;
  }
  const parsed = Number(value);
  return Number.isFinite(parsed) ? parsed : DEFAULT_STARTING_IMAGE_STRENGTH;
}

/**
 * Builds the raw `jobRequest` message for a new project.
 *
 * @param id - project id, sent as `jobID`.
 * @param params - user-supplied project parameters, merged over {@link getTemplate}.
 * @returns the full request. Prompt fields the caller leaves `undefined` keep the template's `''`
 *   so they are still present after JSON serialization. When `params.startingImage` is set,
 *   `strength` is `1 - startingImageStrength` (default strength 0.5); otherwise it is `undefined`.
 */
function createJobRequestMessage(id: string, params: ProjectParams) {
  const template = getTemplate();
  const hasStartingImage = !!params.startingImage;
  return {
    ...template,
    keyFrames: [
      {
        ...template.keyFrames[0],
        scheduler: params.scheduler,
        steps: params.steps,
        guidanceScale: params.guidance,
        modelID: params.modelId,
        // `??` keeps the template default instead of clobbering it with `undefined`, which
        // JSON.stringify would drop from the message entirely.
        negativePrompt: params.negativePrompt ?? '',
        seed: params.seed,
        positivePrompt: params.positivePrompt ?? '',
        stylePrompt: params.stylePrompt ?? '',
        hasStartingImage,
        strengthIsEnabled: hasStartingImage,
        strength: hasStartingImage
          ? 1 - resolveStartingImageStrength(params.startingImageStrength)
          : undefined
      }
    ],
    previews: params.numberOfPreviews || 0,
    numberOfImages: params.numberOfImages,
    jobID: id,
    disableSafety: !!params.disableNSFWFilter
  };
}

export type JobRequestRaw = ReturnType<typeof createJobRequestMessage>;

export default createJobRequestMessage;
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
import ApiGroup, { ApiConfig } from '../ApiGroup';
|
|
2
|
+
import models from './models.json';
|
|
3
|
+
import { AvailableModel, EstimateRequest, ImageUrlParams, ProjectParams } from './types';
|
|
4
|
+
import {
|
|
5
|
+
JobErrorData,
|
|
6
|
+
JobProgressData,
|
|
7
|
+
JobResultData,
|
|
8
|
+
JobStateData,
|
|
9
|
+
SocketEventMap
|
|
10
|
+
} from '../ApiClient/WebSocketClient/events';
|
|
11
|
+
import Project from './Project';
|
|
12
|
+
import createJobRequestMessage from './createJobRequestMessage';
|
|
13
|
+
import { ApiError, ApiReponse } from '../ApiClient';
|
|
14
|
+
import { EstimationResponse } from './types/EstimationResponse';
|
|
15
|
+
import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
|
|
16
|
+
import getUUID from '../lib/getUUID';
|
|
17
|
+
|
|
18
|
+
/** Delay (ms) before a completed or failed project is removed from the tracked project list. */
const GARBAGE_COLLECT_TIMEOUT = 10000;

/**
 * Projects API group.
 *
 * Translates raw socket events (`swarmModels`, `jobState`, `jobProgress`, `jobError`, `jobResult`)
 * into `availableModels` / `project` / `job` events on this emitter, and applies the `project` /
 * `job` events to the tracked `Project` and `Job` instances.
 */
class ProjectsApi extends ApiGroup<ProjectApiEvents> {
  private _availableModels: AvailableModel[] = [];
  private projects: Project[] = [];

  /** Models last reported by the network; reset to an empty list when the server disconnects. */
  get availableModels() {
    return this._availableModels;
  }

  constructor(config: ApiConfig) {
    super(config);
    // Listen to server events and emit them as project and job events
    this.client.socket.on('swarmModels', this.handleSwarmModels.bind(this));
    this.client.socket.on('jobState', this.handleJobState.bind(this));
    this.client.socket.on('jobProgress', this.handleJobProgress.bind(this));
    this.client.socket.on('jobError', this.handleJobError.bind(this));
    this.client.socket.on('jobResult', this.handleJobResult.bind(this));
    // Listen to server disconnect event
    this.client.on('disconnected', this.handleServerDisconnected.bind(this));
    // Listen to project and job events and update project and job instances
    this.on('project', this.handleProjectEvent.bind(this));
    this.on('job', this.handleJobEvent.bind(this));
  }

  /**
   * Rebuilds the available model list from the `id -> workerCount` map sent by the network.
   * Display names come from the bundled `models.json`; ids not found there fall back to the id
   * with dashes replaced by spaces.
   */
  private handleSwarmModels(data: SocketEventMap['swarmModels']) {
    const modelIndex = models.reduce((acc: Record<string, any>, model) => {
      acc[model.modelId] = model;
      return acc;
    }, {});
    this._availableModels = Object.entries(data).map(([id, workerCount]) => ({
      id,
      // `?.`: the network may announce a model this package version does not know about;
      // indexing it directly threw a TypeError and aborted the whole update.
      name: modelIndex[id]?.modelShortName || id.replace(/-/g, ' '),
      workerCount
    }));
    this.emit('availableModels', this._availableModels);
  }

  /** Maps `jobState` socket messages to project-level (`queued`, `completed`) or job-level events. */
  private handleJobState(data: JobStateData) {
    switch (data.type) {
      case 'queued':
        this.emit('project', {
          type: 'queued',
          projectId: data.jobID,
          queuePosition: data.queuePosition
        });
        return;
      case 'jobCompleted':
        this.emit('project', { type: 'completed', projectId: data.jobID });
        return;
      case 'initiatingModel':
        this.emit('job', { type: 'initiating', projectId: data.jobID, jobId: data.imgID });
        return;
      case 'jobStarted':
        this.emit('job', { type: 'started', projectId: data.jobID, jobId: data.imgID });
        return;
    }
  }

  /**
   * Emits a `progress` job event and, when the message says a preview image exists, a follow-up
   * `preview` event once its download URL has been fetched.
   */
  private handleJobProgress(data: JobProgressData) {
    this.emit('job', {
      type: 'progress',
      projectId: data.jobID,
      jobId: data.imgID,
      step: data.step,
      stepCount: data.stepCount
    });

    if (data.hasImage) {
      this.downloadUrl({
        jobId: data.jobID,
        imageId: data.imgID,
        type: 'preview'
      }).then(
        (url) => {
          this.emit('job', {
            type: 'preview',
            projectId: data.jobID,
            jobId: data.imgID,
            url
          });
        },
        () => {
          // Previews are best-effort. Previously a failed URL lookup became an unhandled promise
          // rejection (fatal by default in current Node.js); skipping this preview is sufficient.
        }
      );
    }
  }

  /**
   * Emits the `completed` job event for a finished image, including its download URL when the
   * image passed the NSFW check and was not canceled. If the download URL cannot be obtained, a
   * job `error` event is emitted instead so the job still reaches a terminal state.
   */
  private async handleJobResult(data: JobResultData) {
    const project = this.projects.find((p) => p.id === data.jobID);
    const passNSFWCheck = !data.triggeredNSFWFilter || !project || project.params.disableNSFWFilter;
    let downloadUrl: string | null = null;
    // If NSFW filter is triggered, image will be only available for download if user explicitly
    // disabled the filter for this project
    if (passNSFWCheck && !data.userCanceled) {
      try {
        downloadUrl = await this.downloadUrl({
          jobId: data.jobID,
          imageId: data.imgID,
          type: 'complete'
        });
      } catch (error: unknown) {
        // This handler is a socket listener whose promise nobody awaits: without the catch the
        // rejection was unhandled and the job never left its current state.
        this.emit('job', {
          type: 'error',
          projectId: data.jobID,
          jobId: data.imgID,
          error: {
            code: 0,
            message: error instanceof Error ? error.message : 'Failed to get result download URL'
          }
        });
        return;
      }
    }

    this.emit('job', {
      type: 'completed',
      projectId: data.jobID,
      jobId: data.imgID,
      steps: data.performedStepCount,
      seed: Number(data.lastSeed),
      resultUrl: downloadUrl,
      isNSFW: data.triggeredNSFWFilter,
      userCanceled: data.userCanceled
    });
  }

  /** Emits a project-level `error` when the message has no image id, otherwise a job-level `error`. */
  private handleJobError(data: JobErrorData) {
    const error = {
      code: Number(data.error),
      message: data.error_message
    };
    if (!data.imgID) {
      this.emit('project', { type: 'error', projectId: data.jobID, error });
      return;
    }
    this.emit('job', { type: 'error', projectId: data.jobID, jobId: data.imgID, error });
  }

  /**
   * Applies a project event to the tracked project (ignored for unknown ids) and schedules the
   * project for removal once it reaches `completed` or `failed`.
   */
  handleProjectEvent(event: ProjectEvent) {
    const project = this.projects.find((p) => p.id === event.projectId);
    if (!project) {
      return;
    }
    switch (event.type) {
      case 'queued':
        project._update({
          status: 'queued',
          queuePosition: event.queuePosition
        });
        break;
      case 'completed':
        project._update({
          status: 'completed'
        });
        break;
      case 'error':
        project._update({
          status: 'failed',
          error: event.error
        });
        break;
    }
    if (project.status === 'completed' || project.status === 'failed') {
      this.scheduleRemoval(event.projectId);
    }
  }

  /**
   * Applies a job event to the matching job of a tracked project, creating the job in `pending`
   * state on first sight. Events for unknown projects are ignored.
   */
  handleJobEvent(event: JobEvent) {
    const project = this.projects.find((p) => p.id === event.projectId);
    if (!project) {
      return;
    }
    let job = project.job(event.jobId);
    if (!job) {
      job = project._addJob({
        id: event.jobId,
        status: 'pending',
        step: 0,
        stepCount: project.params.steps
      });
    }
    switch (event.type) {
      case 'initiating':
        job._update({ status: 'initiating' });
        break;
      case 'started':
        job._update({ status: 'processing' });
        break;
      case 'progress':
        job._update({
          status: 'processing',
          step: event.step,
          stepCount: event.stepCount
        });
        if (project.status !== 'processing') {
          project._update({ status: 'processing' });
        }
        break;
      case 'preview':
        job._update({ previewUrl: event.url });
        break;
      case 'completed':
        job._update({
          status: event.userCanceled ? 'canceled' : 'completed',
          step: event.steps,
          seed: event.seed,
          resultUrl: event.resultUrl,
          isNSFW: event.isNSFW,
          userCanceled: event.userCanceled
        });
        break;
      case 'error':
        job._update({ status: 'failed', error: event.error });
        break;
    }
  }

  /**
   * Clears the model list and fails every project that was still in flight.
   * Projects that had already completed or failed keep their final state; newly failed projects
   * are scheduled for removal like any other failed project (previously they were never removed).
   */
  private handleServerDisconnected() {
    this._availableModels = [];
    this.emit('availableModels', this._availableModels);
    this.projects.forEach((p) => {
      if (p.status === 'completed' || p.status === 'failed') {
        return;
      }
      p._update({ status: 'failed', error: { code: 0, message: 'Server disconnected' } });
      this.scheduleRemoval(p.id);
    });
  }

  /** Drops the project with `projectId` from the tracked list after {@link GARBAGE_COLLECT_TIMEOUT} ms. */
  private scheduleRemoval(projectId: string) {
    setTimeout(() => {
      this.projects = this.projects.filter((p) => p.id !== projectId);
    }, GARBAGE_COLLECT_TIMEOUT);
  }

  /**
   * Wait for available models to be received from the network. Useful for scripts that need to
   * run after the models are loaded.
   * @param timeout - timeout in milliseconds until the promise is rejected
   * @returns the current model list immediately if one is already known, otherwise the next one received
   * @throws rejects with `Timeout waiting for models` on timeout, or `No models available` if the
   *   next update is an empty list
   */
  waitForModels(timeout = 10000): Promise<AvailableModel[]> {
    if (this._availableModels.length) {
      return Promise.resolve(this._availableModels);
    }
    return new Promise((resolve, reject) => {
      const timeoutId = setTimeout(() => {
        reject(new Error('Timeout waiting for models'));
      }, timeout);
      this.once('availableModels', (received) => {
        clearTimeout(timeoutId);
        if (received.length) {
          resolve(received);
        } else {
          reject(new Error('No models available'));
        }
      });
    });
  }

  /**
   * Send new project request to the network. Returns project instance which can be used to track
   * progress and get resulting images.
   * @param data - project parameters; `startingImage`, if set, is uploaded before the request is sent
   */
  async create(data: ProjectParams): Promise<Project> {
    const project = new Project({ ...data });
    if (data.startingImage) {
      await this.uploadGuideImage(project.id, data.startingImage);
    }
    const request = createJobRequestMessage(project.id, data);
    await this.client.socket.send('jobRequest', request);
    this.projects.push(project);
    return project;
  }

  /**
   * Uploads the starting image for `projectId` via an upload URL obtained from the REST API.
   * @returns the generated image id
   * @throws ApiError when the upload responds with a non-2xx status
   */
  private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
    const imageId = getUUID();
    const presignedUrl = await this.uploadUrl({
      imageId,
      jobId: projectId,
      type: 'startingImage'
    });
    const res = await fetch(presignedUrl, {
      method: 'PUT',
      body: file
    });
    if (!res.ok) {
      throw new ApiError(res.status, {
        status: 'error',
        errorCode: 0,
        message: 'Failed to upload guide image'
      });
    }
    return imageId;
  }

  /**
   * Estimate project cost.
   * `startingImageStrength` is sent as `1 - startingImageStrength`, or `0` when it is not set (or 0).
   * @returns the project cost in tokens and in USD
   */
  async estimateCost({
    network,
    model,
    imageCount,
    stepCount,
    previewCount,
    cnEnabled,
    startingImageStrength
  }: EstimateRequest) {
    const r = await this.client.socket.get<EstimationResponse>(
      `/api/v1/job/estimate/${network}/${model}/${imageCount}/${stepCount}/${previewCount}/${cnEnabled ? 1 : 0}/${startingImageStrength ? 1 - startingImageStrength : 0}`
    );
    return {
      token: r.quote.project.costInToken,
      usd: r.quote.project.costInUSD
    };
  }

  /** Requests an upload URL for an image from `/v1/image/uploadUrl`. */
  async uploadUrl(params: ImageUrlParams) {
    const r = await this.client.rest.get<ApiReponse<{ uploadUrl: string }>>(
      `/v1/image/uploadUrl`,
      params
    );
    return r.data.uploadUrl;
  }

  /** Requests a download URL for an image from `/v1/image/downloadUrl`. */
  async downloadUrl(params: ImageUrlParams) {
    const r = await this.client.rest.get<ApiReponse<{ downloadUrl: string }>>(
      `/v1/image/downloadUrl`,
      params
    );
    return r.data.downloadUrl;
  }
}

export default ProjectsApi;
|