@sogni-ai/sogni-client 0.3.2 → 0.4.0-aplha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. package/README.md +3 -2
  2. package/dist/Account/index.d.ts +127 -3
  3. package/dist/Account/index.js +126 -2
  4. package/dist/Account/index.js.map +1 -1
  5. package/dist/Projects/Job.d.ts +5 -0
  6. package/dist/Projects/Job.js +6 -0
  7. package/dist/Projects/Job.js.map +1 -1
  8. package/dist/Projects/index.d.ts +6 -6
  9. package/dist/Projects/index.js +12 -2
  10. package/dist/Projects/index.js.map +1 -1
  11. package/dist/Projects/types/events.d.ts +2 -0
  12. package/dist/version.d.ts +1 -1
  13. package/dist/version.js +1 -1
  14. package/dist/version.js.map +1 -1
  15. package/package.json +5 -3
  16. package/src/Account/CurrentAccount.ts +101 -0
  17. package/src/Account/index.ts +367 -0
  18. package/src/Account/types.ts +90 -0
  19. package/src/ApiClient/WebSocketClient/ErrorCode.ts +15 -0
  20. package/src/ApiClient/WebSocketClient/events.ts +94 -0
  21. package/src/ApiClient/WebSocketClient/index.ts +203 -0
  22. package/src/ApiClient/WebSocketClient/messages.ts +7 -0
  23. package/src/ApiClient/WebSocketClient/types.ts +1 -0
  24. package/src/ApiClient/events.ts +20 -0
  25. package/src/ApiClient/index.ts +124 -0
  26. package/src/ApiGroup.ts +25 -0
  27. package/src/Projects/Job.ts +132 -0
  28. package/src/Projects/Project.ts +185 -0
  29. package/src/Projects/createJobRequestMessage.ts +99 -0
  30. package/src/Projects/index.ts +350 -0
  31. package/src/Projects/models.json +8906 -0
  32. package/src/Projects/types/EstimationResponse.ts +45 -0
  33. package/src/Projects/types/events.ts +80 -0
  34. package/src/Projects/types/index.ts +146 -0
  35. package/src/Stats/index.ts +15 -0
  36. package/src/Stats/types.ts +34 -0
  37. package/src/events.ts +5 -0
  38. package/src/index.ts +120 -0
  39. package/src/lib/DataEntity.ts +38 -0
  40. package/src/lib/DefaultLogger.ts +47 -0
  41. package/src/lib/EIP712Helper.ts +57 -0
  42. package/src/lib/RestClient.ts +76 -0
  43. package/src/lib/TypedEventEmitter.ts +66 -0
  44. package/src/lib/base64.ts +9 -0
  45. package/src/lib/getUUID.ts +8 -0
  46. package/src/lib/isNodejs.ts +4 -0
  47. package/src/types/ErrorData.ts +6 -0
  48. package/src/types/json.ts +5 -0
  49. package/src/version.ts +1 -0
@@ -0,0 +1,99 @@
1
+ import { ProjectParams } from './types';
2
// The Mac worker can't process a job request if any of these fields are missing, so every request
// is built on top of this complete default template.
/**
 * Create a complete default job-request template.
 *
 * Every call builds brand-new objects and arrays, so the result can be spread or modified by the
 * caller without leaking state into later requests.
 */
function getTemplate() {
  // Single default keyframe; property order is kept stable because it is serialized as-is.
  const defaultKeyFrame = {
    stepsIsEnabled: true,
    siRotation: 0,
    siDragOffsetIsEnabled: true,
    strength: 0.5,
    siZoomScaleIsEnabled: true,
    isEnabled: true,
    processing: 'CPU, GPU',
    useLastImageAsGuideImageInAnimation: true,
    guidanceScaleIsEnabled: true,
    siImageBackgroundColor: 'black',
    cnDragOffset: [0, 0],
    scheduler: 'DPM Solver Multistep (DPM-Solver++)',
    timeStepSpacing: 'Linear',
    steps: 20,
    cnRotation: 0,
    guidanceScale: 7.5,
    siZoomScale: 1,
    modelID: '',
    cnRotationIsEnabled: true,
    negativePrompt: '',
    startingImageZoomPanIsOn: false,
    seed: undefined,
    siRotationIsEnabled: true,
    cnImageBackgroundColor: 'clear',
    strengthIsEnabled: true,
    siDragOffset: [0, 0],
    useLastImageAsCNImageInAnimation: false,
    positivePrompt: '',
    controlNetZoomPanIsOn: false,
    cnZoomScaleIsEnabled: true,
    currentControlNets: null,
    stylePrompt: '',
    cnDragOffsetIsEnabled: true,
    frameIndex: 0,
    startingImage: null,
    cnZoomScale: 1
  };

  return {
    selectedUpscalingModel: 'OFF',
    cnVideoFramesSketch: [],
    cnVideoFramesSegmentedSubject: [],
    cnVideoFramesFace: [],
    doCanvasBlending: false,
    animationIsOn: false,
    cnVideoFramesBoth: [],
    cnVideoFramesDepth: [],
    keyFrames: [defaultKeyFrame],
    previews: 5,
    frameRate: 24,
    generatedVideoSeconds: 10,
    canvasIsOn: false,
    cnVideoFrames: [],
    disableSafety: false,
    cnVideoFramesSegmentedBackground: [],
    cnVideoFramesSegmented: [],
    numberOfImages: 1,
    cnVideoFramesPose: [],
    jobID: '',
    siVideoFrames: []
  };
}
67
+
68
/** Starting-image strength used when the caller does not provide a usable numeric value. */
const DEFAULT_STARTING_IMAGE_STRENGTH = 0.5;

/**
 * Normalize the user-supplied starting-image strength.
 *
 * Unlike a `|| 0.5` fallback, an explicit strength of `0` (or `'0'`) is preserved. Only missing,
 * empty or non-finite values fall back to the default.
 *
 * @param value - raw `startingImageStrength` from the project params
 * @returns a finite number
 */
function resolveStartingImageStrength(value: unknown): number {
  if (value === undefined || value === null || value === '') {
    return DEFAULT_STARTING_IMAGE_STRENGTH;
  }
  const strength = Number(value);
  return Number.isFinite(strength) ? strength : DEFAULT_STARTING_IMAGE_STRENGTH;
}

/**
 * Build the raw `jobRequest` message for a project.
 *
 * Starts from `getTemplate()`, because the worker needs every field present. It then overrides the
 * first keyframe and the top-level fields with values taken from the caller's project params.
 *
 * @param id - project ID, sent as `jobID`
 * @param params - project parameters supplied by the caller
 * @returns the complete request message object
 */
function createJobRequestMessage(id: string, params: ProjectParams) {
  const template = getTemplate();
  const hasStartingImage = !!params.startingImage;
  return {
    ...template,
    keyFrames: [
      {
        ...template.keyFrames[0],
        scheduler: params.scheduler,
        steps: params.steps,
        guidanceScale: params.guidance,
        modelID: params.modelId,
        negativePrompt: params.negativePrompt,
        seed: params.seed,
        positivePrompt: params.positivePrompt,
        stylePrompt: params.stylePrompt,
        hasStartingImage,
        strengthIsEnabled: hasStartingImage,
        // The request carries the complement of the user-facing starting-image strength.
        // Without a starting image, the template's strength is explicitly cleared.
        strength: hasStartingImage
          ? 1 - resolveStartingImageStrength(params.startingImageStrength)
          : undefined
      }
    ],
    previews: params.numberOfPreviews || 0,
    numberOfImages: params.numberOfImages,
    jobID: id,
    disableSafety: !!params.disableNSFWFilter
  };
}

/** Shape of the raw `jobRequest` message produced by `createJobRequestMessage`. */
export type JobRequestRaw = ReturnType<typeof createJobRequestMessage>;

export default createJobRequestMessage;
@@ -0,0 +1,350 @@
1
+ import ApiGroup, { ApiConfig } from '../ApiGroup';
2
+ import models from './models.json';
3
+ import { AvailableModel, EstimateRequest, ImageUrlParams, ProjectParams } from './types';
4
+ import {
5
+ JobErrorData,
6
+ JobProgressData,
7
+ JobResultData,
8
+ JobStateData,
9
+ SocketEventMap
10
+ } from '../ApiClient/WebSocketClient/events';
11
+ import Project from './Project';
12
+ import createJobRequestMessage from './createJobRequestMessage';
13
+ import { ApiError, ApiReponse } from '../ApiClient';
14
+ import { EstimationResponse } from './types/EstimationResponse';
15
+ import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
16
+ import getUUID from '../lib/getUUID';
17
+
18
/** Delay in milliseconds before a completed or failed project stops being tracked. */
const GARBAGE_COLLECT_TIMEOUT = 10 * 1000;
19
+
20
/**
 * Projects API group.
 *
 * This class:
 * - submits projects to the network over the socket;
 * - translates raw socket messages into public `project` / `job` events;
 * - keeps the locally tracked `Project` instances, and their jobs, in sync with those events.
 */
class ProjectsApi extends ApiGroup<ProjectApiEvents> {
  private _availableModels: AvailableModel[] = [];
  /** Projects created through this instance whose events are still being tracked. */
  private projects: Project[] = [];
  /** Bundled model metadata keyed by model ID, built once rather than on every models update. */
  private readonly modelIndex = new Map(models.map((model) => [model.modelId, model] as const));

  /** Models currently reported by the network (empty until the first `swarmModels` message). */
  get availableModels() {
    return this._availableModels;
  }

  constructor(config: ApiConfig) {
    super(config);
    // Listen to server events and emit them as project and job events
    this.client.socket.on('swarmModels', this.handleSwarmModels.bind(this));
    this.client.socket.on('jobState', this.handleJobState.bind(this));
    this.client.socket.on('jobProgress', this.handleJobProgress.bind(this));
    this.client.socket.on('jobError', this.handleJobError.bind(this));
    this.client.socket.on('jobResult', this.handleJobResult.bind(this));
    // Listen to server disconnect event
    this.client.on('disconnected', this.handleServerDisconnected.bind(this));
    // Listen to project and job events and update project and job instances
    this.on('project', this.handleProjectEvent.bind(this));
    this.on('job', this.handleJobEvent.bind(this));
  }

  /**
   * Replace the available-model list with the network's model -> worker-count map and emit
   * `availableModels`.
   */
  private handleSwarmModels(data: SocketEventMap['swarmModels']) {
    this._availableModels = Object.entries(data).map(([id, workerCount]) => ({
      id,
      // The network may announce models that are not in the bundled list. In that case the
      // display name is derived from the ID instead of crashing on a missing entry.
      name: this.modelIndex.get(id)?.modelShortName || id.replace(/-/g, ' '),
      workerCount
    }));
    this.emit('availableModels', this._availableModels);
  }

  /** Translate a `jobState` socket message into a `project` or `job` event. */
  private handleJobState(data: JobStateData) {
    switch (data.type) {
      case 'queued':
        this.emit('project', {
          type: 'queued',
          projectId: data.jobID,
          queuePosition: data.queuePosition
        });
        return;
      case 'jobCompleted':
        this.emit('project', { type: 'completed', projectId: data.jobID });
        return;
      case 'initiatingModel':
        this.emit('job', {
          type: 'initiating',
          projectId: data.jobID,
          jobId: data.imgID,
          workerName: data.workerName
        });
        return;
      case 'jobStarted': {
        this.emit('job', {
          type: 'started',
          projectId: data.jobID,
          jobId: data.imgID,
          workerName: data.workerName
        });
        return;
      }
    }
  }

  /**
   * Emit a `progress` job event and, when the message says a preview image exists, fetch its URL
   * and emit a `preview` job event.
   */
  private handleJobProgress(data: JobProgressData) {
    this.emit('job', {
      type: 'progress',
      projectId: data.jobID,
      jobId: data.imgID,
      step: data.step,
      stepCount: data.stepCount
    });

    if (data.hasImage) {
      this.downloadUrl({
        jobId: data.jobID,
        imageId: data.imgID,
        type: 'preview'
      }).then(
        (url) => {
          this.emit('job', {
            type: 'preview',
            projectId: data.jobID,
            jobId: data.imgID,
            url
          });
        },
        () => {
          // Previews are best-effort. A failed URL lookup is dropped here instead of surfacing
          // as an unhandled promise rejection; the result URL is requested separately in
          // handleJobResult.
        }
      );
    }
  }

  /**
   * Emit the terminal `completed` job event, with a result URL when the image may be downloaded.
   * If the URL request fails, a job `error` event is emitted instead.
   */
  private async handleJobResult(data: JobResultData) {
    const project = this.projects.find((p) => p.id === data.jobID);
    const passNSFWCheck = !data.triggeredNSFWFilter || !project || project.params.disableNSFWFilter;
    let downloadUrl: string | null = null;
    // If NSFW filter is triggered, image will be only available for download if user explicitly
    // disabled the filter for this project
    if (passNSFWCheck && !data.userCanceled) {
      try {
        downloadUrl = await this.downloadUrl({
          jobId: data.jobID,
          imageId: data.imgID,
          type: 'complete'
        });
      } catch (error: unknown) {
        // Without this, the job would never receive a terminal event and the rejection would be
        // unhandled.
        this.emit('job', {
          type: 'error',
          projectId: data.jobID,
          jobId: data.imgID,
          error: {
            code: 0,
            message: error instanceof Error ? error.message : 'Failed to get result image URL'
          }
        });
        return;
      }
    }

    this.emit('job', {
      type: 'completed',
      projectId: data.jobID,
      jobId: data.imgID,
      steps: data.performedStepCount,
      seed: Number(data.lastSeed),
      resultUrl: downloadUrl,
      isNSFW: data.triggeredNSFWFilter,
      userCanceled: data.userCanceled
    });
  }

  /**
   * Translate a `jobError` socket message: without an image ID it is a project-level error,
   * otherwise it is an error for that job.
   */
  private handleJobError(data: JobErrorData) {
    const error = {
      code: Number(data.error),
      message: data.error_message
    };
    if (!data.imgID) {
      this.emit('project', { type: 'error', projectId: data.jobID, error });
      return;
    }
    this.emit('job', { type: 'error', projectId: data.jobID, jobId: data.imgID, error });
  }

  /**
   * Apply a `project` event to the tracked Project instance. Once the project reaches a terminal
   * state, it is scheduled for removal.
   */
  private handleProjectEvent(event: ProjectEvent) {
    const project = this.projects.find((p) => p.id === event.projectId);
    if (!project) {
      return;
    }
    switch (event.type) {
      case 'queued':
        project._update({
          status: 'queued',
          queuePosition: event.queuePosition
        });
        break;
      case 'completed':
        project._update({
          status: 'completed'
        });
        break;
      case 'error':
        project._update({
          status: 'failed',
          error: event.error
        });
    }
    if (project.status === 'completed' || project.status === 'failed') {
      this.scheduleGarbageCollection(project.id);
    }
  }

  /**
   * Apply a `job` event to the matching Job of a tracked project. If the job is not known yet,
   * it is created first.
   */
  private handleJobEvent(event: JobEvent) {
    const project = this.projects.find((p) => p.id === event.projectId);
    if (!project) {
      return;
    }
    let job = project.job(event.jobId);
    if (!job) {
      job = project._addJob({
        id: event.jobId,
        status: 'pending',
        step: 0,
        stepCount: project.params.steps
      });
    }
    switch (event.type) {
      case 'initiating':
        job._update({ status: 'initiating' });
        break;
      case 'started':
        job._update({ status: 'processing' });
        break;
      case 'progress':
        job._update({
          status: 'processing',
          step: event.step,
          stepCount: event.stepCount
        });
        if (project.status !== 'processing') {
          project._update({ status: 'processing' });
        }
        break;
      case 'preview':
        job._update({ previewUrl: event.url });
        break;
      case 'completed': {
        job._update({
          status: event.userCanceled ? 'canceled' : 'completed',
          step: event.steps,
          seed: event.seed,
          resultUrl: event.resultUrl,
          isNSFW: event.isNSFW,
          userCanceled: event.userCanceled
        });
        break;
      }
      case 'error':
        job._update({ status: 'failed', error: event.error });
        break;
    }
  }

  /**
   * On disconnect, clear the model list and fail every project that has not already finished.
   * Projects that already completed or failed keep their status; they are already scheduled for
   * removal.
   */
  private handleServerDisconnected() {
    this._availableModels = [];
    this.emit('availableModels', this._availableModels);
    this.projects.forEach((p) => {
      if (p.status === 'completed' || p.status === 'failed') {
        return;
      }
      p._update({ status: 'failed', error: { code: 0, message: 'Server disconnected' } });
      // Failed-by-disconnect projects would otherwise be tracked forever.
      this.scheduleGarbageCollection(p.id);
    });
  }

  /** Stop tracking a project after GARBAGE_COLLECT_TIMEOUT milliseconds. */
  private scheduleGarbageCollection(projectId: string) {
    setTimeout(() => {
      this.projects = this.projects.filter((p) => p.id !== projectId);
    }, GARBAGE_COLLECT_TIMEOUT);
  }

  /**
   * Wait for available models to be received from the network. Useful for scripts that need to
   * run after the models are loaded.
   * @param timeout - timeout in milliseconds until the promise is rejected
   * @returns the available models; rejects on timeout or when an empty model list is received
   */
  waitForModels(timeout = 10000): Promise<AvailableModel[]> {
    if (this._availableModels.length) {
      return Promise.resolve(this._availableModels);
    }
    return new Promise((resolve, reject) => {
      const timeoutId = setTimeout(() => {
        reject(new Error('Timeout waiting for models'));
      }, timeout);
      this.once('availableModels', (models) => {
        clearTimeout(timeoutId);
        if (models.length) {
          resolve(models);
        } else {
          reject(new Error('No models available'));
        }
      });
    });
  }

  /**
   * Send new project request to the network. Returns project instance which can be used to track
   * progress and get resulting images.
   * @param data - project parameters; a `startingImage`, if present, is uploaded first
   * @returns the tracked project instance
   * @throws ApiError if the starting image upload fails; any error from sending the request is
   *   rethrown after the project is untracked
   */
  async create(data: ProjectParams): Promise<Project> {
    const project = new Project({ ...data });
    if (data.startingImage) {
      await this.uploadGuideImage(project.id, data.startingImage);
    }
    const request = createJobRequestMessage(project.id, data);
    // Track the project before sending, so server events that may arrive before send() resolves
    // are not dropped by handleProjectEvent / handleJobEvent.
    this.projects.push(project);
    try {
      await this.client.socket.send('jobRequest', request);
    } catch (error) {
      this.projects = this.projects.filter((p) => p !== project);
      throw error;
    }
    return project;
  }

  /**
   * Upload a starting image to a presigned URL obtained from the REST API.
   * @returns the generated image ID
   * @throws ApiError when the upload response is not OK
   */
  private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
    const imageId = getUUID();
    const presignedUrl = await this.uploadUrl({
      imageId: imageId,
      jobId: projectId,
      type: 'startingImage'
    });
    const res = await fetch(presignedUrl, {
      method: 'PUT',
      body: file
    });
    if (!res.ok) {
      throw new ApiError(res.status, {
        status: 'error',
        errorCode: 0,
        message: 'Failed to upload guide image'
      });
    }
    return imageId;
  }

  /**
   * Estimate project cost.
   * @returns the quoted project cost in tokens and in USD
   */
  async estimateCost({
    network,
    model,
    imageCount,
    stepCount,
    previewCount,
    cnEnabled,
    startingImageStrength
  }: EstimateRequest) {
    // NOTE(review): a strength of 0 is sent the same as "no starting image" (0); confirm this is
    // what the estimate endpoint expects.
    const r = await this.client.socket.get<EstimationResponse>(
      `/api/v1/job/estimate/${network}/${model}/${imageCount}/${stepCount}/${previewCount}/${cnEnabled ? 1 : 0}/${startingImageStrength ? 1 - startingImageStrength : 0}`
    );
    return {
      token: r.quote.project.costInToken,
      usd: r.quote.project.costInUSD
    };
  }

  /** Request a presigned upload URL for an image from the REST API. */
  private async uploadUrl(params: ImageUrlParams) {
    const r = await this.client.rest.get<ApiReponse<{ uploadUrl: string }>>(
      `/v1/image/uploadUrl`,
      params
    );
    return r.data.uploadUrl;
  }

  /** Request a download URL for an image from the REST API. */
  private async downloadUrl(params: ImageUrlParams) {
    const r = await this.client.rest.get<ApiReponse<{ downloadUrl: string }>>(
      `/v1/image/downloadUrl`,
      params
    );
    return r.data.downloadUrl;
  }
}

export default ProjectsApi;