@sogni-ai/sogni-client 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. package/README.md +1 -1
  2. package/dist/version.d.ts +1 -1
  3. package/dist/version.js +1 -1
  4. package/package.json +5 -3
  5. package/src/Account/CurrentAccount.ts +101 -0
  6. package/src/Account/index.ts +243 -0
  7. package/src/Account/types.ts +90 -0
  8. package/src/ApiClient/WebSocketClient/ErrorCode.ts +15 -0
  9. package/src/ApiClient/WebSocketClient/events.ts +94 -0
  10. package/src/ApiClient/WebSocketClient/index.ts +203 -0
  11. package/src/ApiClient/WebSocketClient/messages.ts +7 -0
  12. package/src/ApiClient/WebSocketClient/types.ts +1 -0
  13. package/src/ApiClient/events.ts +20 -0
  14. package/src/ApiClient/index.ts +124 -0
  15. package/src/ApiGroup.ts +25 -0
  16. package/src/Projects/Job.ts +124 -0
  17. package/src/Projects/Project.ts +185 -0
  18. package/src/Projects/createJobRequestMessage.ts +99 -0
  19. package/src/Projects/index.ts +340 -0
  20. package/src/Projects/models.json +8906 -0
  21. package/src/Projects/types/EstimationResponse.ts +45 -0
  22. package/src/Projects/types/events.ts +78 -0
  23. package/src/Projects/types/index.ts +146 -0
  24. package/src/Stats/index.ts +15 -0
  25. package/src/Stats/types.ts +34 -0
  26. package/src/events.ts +5 -0
  27. package/src/index.ts +120 -0
  28. package/src/lib/DataEntity.ts +38 -0
  29. package/src/lib/DefaultLogger.ts +47 -0
  30. package/src/lib/EIP712Helper.ts +57 -0
  31. package/src/lib/RestClient.ts +76 -0
  32. package/src/lib/TypedEventEmitter.ts +66 -0
  33. package/src/lib/base64.ts +9 -0
  34. package/src/lib/getUUID.ts +8 -0
  35. package/src/lib/isNodejs.ts +4 -0
  36. package/src/types/ErrorData.ts +6 -0
  37. package/src/types/json.ts +5 -0
  38. package/src/version.ts +1 -0
@@ -0,0 +1,99 @@
1
+ import { ProjectParams } from './types';
2
/**
 * Returns a fresh, fully populated job request skeleton.
 *
 * The Mac worker can't process the data if some of the fields are missing, so every job request
 * is built on top of this template. Each call creates new objects and arrays, so callers may
 * mutate the result without affecting later calls.
 */
function getTemplate() {
  // The single default key frame; property order is kept as-is because it is the order in which
  // the keys are serialized.
  const defaultKeyFrame = {
    stepsIsEnabled: true,
    siRotation: 0,
    siDragOffsetIsEnabled: true,
    strength: 0.5,
    siZoomScaleIsEnabled: true,
    isEnabled: true,
    processing: 'CPU, GPU',
    useLastImageAsGuideImageInAnimation: true,
    guidanceScaleIsEnabled: true,
    siImageBackgroundColor: 'black',
    cnDragOffset: [0, 0],
    scheduler: 'DPM Solver Multistep (DPM-Solver++)',
    timeStepSpacing: 'Linear',
    steps: 20,
    cnRotation: 0,
    guidanceScale: 7.5,
    siZoomScale: 1,
    modelID: '',
    cnRotationIsEnabled: true,
    negativePrompt: '',
    startingImageZoomPanIsOn: false,
    seed: undefined,
    siRotationIsEnabled: true,
    cnImageBackgroundColor: 'clear',
    strengthIsEnabled: true,
    siDragOffset: [0, 0],
    useLastImageAsCNImageInAnimation: false,
    positivePrompt: '',
    controlNetZoomPanIsOn: false,
    cnZoomScaleIsEnabled: true,
    currentControlNets: null,
    stylePrompt: '',
    cnDragOffsetIsEnabled: true,
    frameIndex: 0,
    startingImage: null,
    cnZoomScale: 1
  };

  return {
    selectedUpscalingModel: 'OFF',
    cnVideoFramesSketch: [],
    cnVideoFramesSegmentedSubject: [],
    cnVideoFramesFace: [],
    doCanvasBlending: false,
    animationIsOn: false,
    cnVideoFramesBoth: [],
    cnVideoFramesDepth: [],
    keyFrames: [defaultKeyFrame],
    previews: 5,
    frameRate: 24,
    generatedVideoSeconds: 10,
    canvasIsOn: false,
    cnVideoFrames: [],
    disableSafety: false,
    cnVideoFramesSegmentedBackground: [],
    cnVideoFramesSegmented: [],
    numberOfImages: 1,
    cnVideoFramesPose: [],
    jobID: '',
    siVideoFrames: []
  };
}
67
+
68
/**
 * Builds the raw `jobRequest` message for a project.
 *
 * The message starts from `getTemplate()` so that every field the workers expect is present,
 * then overlays the caller's parameters on the template and on its single key frame.
 *
 * @param id - project id; sent as `jobID`.
 * @param params - project parameters supplied by the caller.
 * @returns the request object to send as `jobRequest`.
 */
function createJobRequestMessage(id: string, params: ProjectParams) {
  const template = getTemplate();
  const defaultKeyFrame = template.keyFrames[0];
  const hasStartingImage = !!params.startingImage;
  // `startingImageStrength` may be omitted or arrive as a numeric string. Fall back to 0.5 only
  // when it is missing or not a finite number; an explicit 0 is a valid value and is kept
  // (a plain `|| 0.5` would silently turn it into 0.5).
  const parsedStrength = Number(params.startingImageStrength ?? 0.5);
  const startingImageStrength = Number.isFinite(parsedStrength) ? parsedStrength : 0.5;

  return {
    ...template,
    keyFrames: [
      {
        ...defaultKeyFrame,
        // Omitted params fall back to the template value instead of `undefined`: an explicit
        // `undefined` overrides the default and, assuming the message is JSON-serialized, the key
        // is dropped entirely - the very situation the template exists to prevent.
        scheduler: params.scheduler ?? defaultKeyFrame.scheduler,
        steps: params.steps ?? defaultKeyFrame.steps,
        guidanceScale: params.guidance ?? defaultKeyFrame.guidanceScale,
        modelID: params.modelId ?? defaultKeyFrame.modelID,
        negativePrompt: params.negativePrompt ?? defaultKeyFrame.negativePrompt,
        seed: params.seed,
        positivePrompt: params.positivePrompt ?? defaultKeyFrame.positivePrompt,
        stylePrompt: params.stylePrompt ?? defaultKeyFrame.stylePrompt,
        hasStartingImage,
        strengthIsEnabled: hasStartingImage,
        // The key frame's `strength` is sent as the complement of `startingImageStrength`, and
        // is left undefined when there is no starting image.
        strength: hasStartingImage ? 1 - startingImageStrength : undefined
      }
    ],
    previews: params.numberOfPreviews || 0,
    numberOfImages: params.numberOfImages ?? template.numberOfImages,
    jobID: id,
    disableSafety: !!params.disableNSFWFilter
  };
}

/** Type of the message object produced by `createJobRequestMessage`. */
export type JobRequestRaw = ReturnType<typeof createJobRequestMessage>;

export default createJobRequestMessage;
@@ -0,0 +1,340 @@
1
+ import ApiGroup, { ApiConfig } from '../ApiGroup';
2
+ import models from './models.json';
3
+ import { AvailableModel, EstimateRequest, ImageUrlParams, ProjectParams } from './types';
4
+ import {
5
+ JobErrorData,
6
+ JobProgressData,
7
+ JobResultData,
8
+ JobStateData,
9
+ SocketEventMap
10
+ } from '../ApiClient/WebSocketClient/events';
11
+ import Project from './Project';
12
+ import createJobRequestMessage from './createJobRequestMessage';
13
+ import { ApiError, ApiReponse } from '../ApiClient';
14
+ import { EstimationResponse } from './types/EstimationResponse';
15
+ import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
16
+ import getUUID from '../lib/getUUID';
17
+
18
// Delay in milliseconds (10 s) before a completed or failed project is removed from the list of
// tracked projects.
const GARBAGE_COLLECT_TIMEOUT = 10 * 1000;
19
+
20
/**
 * Projects API group: sends project (image-generation) requests to the network and keeps the
 * `Project` / `Job` instances it created in sync with the server.
 *
 * Raw socket events (`swarmModels`, `jobState`, `jobProgress`, `jobResult`, `jobError`) are
 * re-emitted on this instance as `availableModels`, `project` and `job` events. This instance
 * also listens to its own `project` / `job` events to update the projects it tracks.
 */
class ProjectsApi extends ApiGroup<ProjectApiEvents> {
  private _availableModels: AvailableModel[] = [];
  // Projects created through `create()` that still receive updates. Completed or failed projects
  // are dropped GARBAGE_COLLECT_TIMEOUT ms later.
  private projects: Project[] = [];

  /**
   * Models most recently reported by the network. Empty until the first `swarmModels` event and
   * again after the server disconnects.
   */
  get availableModels() {
    return this._availableModels;
  }

  constructor(config: ApiConfig) {
    super(config);
    // Listen to server events and emit them as project and job events
    this.client.socket.on('swarmModels', this.handleSwarmModels.bind(this));
    this.client.socket.on('jobState', this.handleJobState.bind(this));
    this.client.socket.on('jobProgress', this.handleJobProgress.bind(this));
    this.client.socket.on('jobError', this.handleJobError.bind(this));
    this.client.socket.on('jobResult', this.handleJobResult.bind(this));
    // Listen to server disconnect event
    this.client.on('disconnected', this.handleServerDisconnected.bind(this));
    // Listen to project and job events and update project and job instances
    this.on('project', this.handleProjectEvent.bind(this));
    this.on('job', this.handleJobEvent.bind(this));
  }

  /**
   * Rebuilds the available-model list from a `swarmModels` payload (model id -> worker count),
   * taking display names from the bundled `models.json` where possible.
   */
  private handleSwarmModels(data: SocketEventMap['swarmModels']) {
    const modelIndex = models.reduce((acc: Record<string, any>, model) => {
      acc[model.modelId] = model;
      return acc;
    }, {});
    this._availableModels = Object.entries(data).map(([id, workerCount]) => ({
      id,
      // A model reported by the network may be missing from the bundled models.json; fall back to
      // a humanized id instead of throwing on the undefined index entry.
      name: modelIndex[id]?.modelShortName || id.replace(/-/g, ' '),
      workerCount
    }));
    this.emit('availableModels', this._availableModels);
  }

  /**
   * Re-emits a `jobState` socket event as a `project` or `job` event. State types not listed
   * below emit nothing.
   */
  private handleJobState(data: JobStateData) {
    switch (data.type) {
      case 'queued':
        this.emit('project', {
          type: 'queued',
          projectId: data.jobID,
          queuePosition: data.queuePosition
        });
        return;
      case 'jobCompleted':
        this.emit('project', { type: 'completed', projectId: data.jobID });
        return;
      case 'initiatingModel':
        this.emit('job', { type: 'initiating', projectId: data.jobID, jobId: data.imgID });
        return;
      case 'jobStarted':
        this.emit('job', { type: 'started', projectId: data.jobID, jobId: data.imgID });
        return;
    }
  }

  /**
   * Re-emits step progress as a `job` `progress` event. When the worker attached a preview image,
   * it also emits a `job` `preview` event with its download URL once that URL is obtained.
   */
  private async handleJobProgress(data: JobProgressData) {
    this.emit('job', {
      type: 'progress',
      projectId: data.jobID,
      jobId: data.imgID,
      step: data.step,
      stepCount: data.stepCount
    });

    if (data.hasImage) {
      // Not awaited: the preview URL lookup must not delay progress handling.
      this.downloadUrl({
        jobId: data.jobID,
        imageId: data.imgID,
        type: 'preview'
      })
        .then((url) => {
          this.emit('job', {
            type: 'preview',
            projectId: data.jobID,
            jobId: data.imgID,
            url
          });
        })
        .catch(() => {
          // Previews are best-effort. Without this handler a failed lookup would become an
          // unhandled promise rejection, which terminates Node.js (v15+) by default.
        });
    }
  }

  /**
   * Emits a `job` `completed` event for a finished image. The event includes the result download
   * URL unless the image was canceled or blocked by the NSFW filter. If the URL lookup itself
   * fails, a `job` `error` event is emitted instead.
   */
  private async handleJobResult(data: JobResultData) {
    const project = this.projects.find((p) => p.id === data.jobID);
    // If NSFW filter is triggered, image will be only available for download if user explicitly
    // disabled the filter for this project
    const passNSFWCheck = !data.triggeredNSFWFilter || !project || project.params.disableNSFWFilter;
    let downloadUrl = null;
    if (passNSFWCheck && !data.userCanceled) {
      try {
        downloadUrl = await this.downloadUrl({
          jobId: data.jobID,
          imageId: data.imgID,
          type: 'complete'
        });
      } catch (e: unknown) {
        // Report the failure on the job rather than rejecting. Event emitters typically ignore a
        // listener's returned promise, so a rejection would go unhandled and the job would never
        // leave its in-progress state.
        const reason = e instanceof Error ? e.message : String(e);
        this.emit('job', {
          type: 'error',
          projectId: data.jobID,
          jobId: data.imgID,
          error: {
            code: 0,
            message: `Failed to get result URL: ${reason}`
          }
        });
        return;
      }
    }

    this.emit('job', {
      type: 'completed',
      projectId: data.jobID,
      jobId: data.imgID,
      steps: data.performedStepCount,
      seed: Number(data.lastSeed),
      resultUrl: downloadUrl,
      isNSFW: data.triggeredNSFWFilter,
      userCanceled: data.userCanceled
    });
  }

  /**
   * Re-emits a `jobError` socket event. Without an image id the error is emitted as a `project`
   * error; otherwise it is emitted as a `job` error.
   */
  private handleJobError(data: JobErrorData) {
    const error = {
      code: Number(data.error),
      message: data.error_message
    };
    if (!data.imgID) {
      this.emit('project', { type: 'error', projectId: data.jobID, error });
      return;
    }
    this.emit('job', { type: 'error', projectId: data.jobID, jobId: data.imgID, error });
  }

  /**
   * Applies a `project` event to the matching tracked project. Once the project is completed or
   * failed, its removal is scheduled. Events for untracked project ids are ignored.
   */
  handleProjectEvent(event: ProjectEvent) {
    const project = this.projects.find((p) => p.id === event.projectId);
    if (!project) {
      return;
    }
    switch (event.type) {
      case 'queued':
        project._update({
          status: 'queued',
          queuePosition: event.queuePosition
        });
        break;
      case 'completed':
        project._update({
          status: 'completed'
        });
        break;
      case 'error':
        project._update({
          status: 'failed',
          error: event.error
        });
        break;
    }
    if (project.status === 'completed' || project.status === 'failed') {
      this.scheduleProjectRemoval(event.projectId);
    }
  }

  /**
   * Applies a `job` event to the matching job of a tracked project. The job is created in the
   * `pending` state the first time it is referenced. Events for untracked project ids are ignored.
   */
  handleJobEvent(event: JobEvent) {
    const project = this.projects.find((p) => p.id === event.projectId);
    if (!project) {
      return;
    }
    let job = project.job(event.jobId);
    if (!job) {
      job = project._addJob({
        id: event.jobId,
        status: 'pending',
        step: 0,
        stepCount: project.params.steps
      });
    }
    switch (event.type) {
      case 'initiating':
        job._update({ status: 'initiating' });
        break;
      case 'started':
        job._update({ status: 'processing' });
        break;
      case 'progress':
        job._update({
          status: 'processing',
          step: event.step,
          stepCount: event.stepCount
        });
        if (project.status !== 'processing') {
          project._update({ status: 'processing' });
        }
        break;
      case 'preview':
        job._update({ previewUrl: event.url });
        break;
      case 'completed': {
        job._update({
          status: event.userCanceled ? 'canceled' : 'completed',
          step: event.steps,
          seed: event.seed,
          resultUrl: event.resultUrl,
          isNSFW: event.isNSFW,
          userCanceled: event.userCanceled
        });
        break;
      }
      case 'error':
        job._update({ status: 'failed', error: event.error });
        break;
    }
  }

  /**
   * Clears the model list and fails every tracked project when the server connection drops.
   */
  private handleServerDisconnected() {
    this._availableModels = [];
    this.emit('availableModels', this._availableModels);
    this.projects.forEach((p) => {
      p._update({ status: 'failed', error: { code: 0, message: 'Server disconnected' } });
      // Failing a project here bypasses handleProjectEvent, which is the only other place that
      // schedules removal; without this the failed project would stay tracked forever.
      this.scheduleProjectRemoval(p.id);
    });
  }

  /** Stops tracking the project with `projectId` after GARBAGE_COLLECT_TIMEOUT ms. */
  private scheduleProjectRemoval(projectId: string) {
    setTimeout(() => {
      this.projects = this.projects.filter((p) => p.id !== projectId);
    }, GARBAGE_COLLECT_TIMEOUT);
  }

  /**
   * Wait for available models to be received from the network. Useful for scripts that need to
   * run after the models are loaded.
   * @param timeout - timeout in milliseconds until the promise is rejected
   * @returns the available models; resolves immediately if they are already known.
   * @throws Error when the timeout elapses, or when the next `availableModels` event carries an
   * empty list (e.g. emitted on server disconnect).
   */
  waitForModels(timeout = 10000): Promise<AvailableModel[]> {
    if (this._availableModels.length) {
      return Promise.resolve(this._availableModels);
    }
    return new Promise((resolve, reject) => {
      const timeoutId = setTimeout(() => {
        reject(new Error('Timeout waiting for models'));
      }, timeout);
      this.once('availableModels', (available) => {
        clearTimeout(timeoutId);
        if (available.length) {
          resolve(available);
        } else {
          reject(new Error('No models available'));
        }
      });
    });
  }

  /**
   * Send new project request to the network. Returns project instance which can be used to track
   * progress and get resulting images.
   * @param data - project parameters; a `startingImage`, when present, is uploaded before the
   * request is sent.
   */
  async create(data: ProjectParams): Promise<Project> {
    const project = new Project({ ...data });
    if (data.startingImage) {
      await this.uploadGuideImage(project.id, data.startingImage);
    }
    const request = createJobRequestMessage(project.id, data);
    await this.client.socket.send('jobRequest', request);
    this.projects.push(project);
    return project;
  }

  /**
   * Uploads a starting (guide) image for a project through a presigned upload URL.
   * @returns the generated image id.
   * @throws ApiError when the upload request does not succeed.
   */
  private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
    const imageId = getUUID();
    const presignedUrl = await this.uploadUrl({
      imageId: imageId,
      jobId: projectId,
      type: 'startingImage'
    });
    const res = await fetch(presignedUrl, {
      method: 'PUT',
      body: file
    });
    if (!res.ok) {
      throw new ApiError(res.status, {
        status: 'error',
        errorCode: 0,
        message: 'Failed to upload guide image'
      });
    }
    return imageId;
  }

  /**
   * Estimate project cost
   * @returns the quoted project cost in tokens and in USD.
   */
  async estimateCost({
    network,
    model,
    imageCount,
    stepCount,
    previewCount,
    cnEnabled,
    startingImageStrength
  }: EstimateRequest) {
    const r = await this.client.socket.get<EstimationResponse>(
      `/api/v1/job/estimate/${network}/${model}/${imageCount}/${stepCount}/${previewCount}/${cnEnabled ? 1 : 0}/${startingImageStrength ? 1 - startingImageStrength : 0}`
    );
    return {
      token: r.quote.project.costInToken,
      usd: r.quote.project.costInUSD
    };
  }

  /** Requests a presigned upload URL for an image. */
  async uploadUrl(params: ImageUrlParams) {
    const r = await this.client.rest.get<ApiReponse<{ uploadUrl: string }>>(
      `/v1/image/uploadUrl`,
      params
    );
    return r.data.uploadUrl;
  }

  /** Requests a download URL for an image (preview or complete result). */
  async downloadUrl(params: ImageUrlParams) {
    const r = await this.client.rest.get<ApiReponse<{ downloadUrl: string }>>(
      `/v1/image/downloadUrl`,
      params
    );
    return r.data.downloadUrl;
  }
}

export default ProjectsApi;