@sogni-ai/sogni-client 0.3.3 → 0.4.0-aplha.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. package/README.md +3 -2
  2. package/dist/Account/CurrentAccount.d.ts +12 -2
  3. package/dist/Account/CurrentAccount.js.map +1 -1
  4. package/dist/Account/index.d.ts +130 -7
  5. package/dist/Account/index.js +135 -7
  6. package/dist/Account/index.js.map +1 -1
  7. package/dist/ApiClient/WebSocketClient/events.d.ts +7 -1
  8. package/dist/ApiClient/WebSocketClient/index.d.ts +1 -1
  9. package/dist/ApiClient/WebSocketClient/index.js +14 -5
  10. package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
  11. package/dist/ApiClient/WebSocketClient/messages.d.ts +2 -0
  12. package/dist/Projects/Job.d.ts +25 -1
  13. package/dist/Projects/Job.js +78 -1
  14. package/dist/Projects/Job.js.map +1 -1
  15. package/dist/Projects/Project.d.ts +21 -3
  16. package/dist/Projects/Project.js +110 -2
  17. package/dist/Projects/Project.js.map +1 -1
  18. package/dist/Projects/createJobRequestMessage.d.ts +1 -61
  19. package/dist/Projects/createJobRequestMessage.js +5 -1
  20. package/dist/Projects/createJobRequestMessage.js.map +1 -1
  21. package/dist/Projects/index.d.ts +22 -3
  22. package/dist/Projects/index.js +82 -14
  23. package/dist/Projects/index.js.map +1 -1
  24. package/dist/Projects/types/RawProject.d.ts +87 -0
  25. package/dist/Projects/types/RawProject.js +3 -0
  26. package/dist/Projects/types/RawProject.js.map +1 -0
  27. package/dist/Projects/types/events.d.ts +2 -0
  28. package/dist/Projects/types/index.d.ts +4 -0
  29. package/dist/lib/DataEntity.d.ts +1 -0
  30. package/dist/lib/DataEntity.js +2 -0
  31. package/dist/lib/DataEntity.js.map +1 -1
  32. package/dist/lib/base64.js +8 -6
  33. package/dist/lib/base64.js.map +1 -1
  34. package/dist/types/ErrorData.d.ts +1 -0
  35. package/dist/version.d.ts +1 -1
  36. package/dist/version.js +1 -1
  37. package/dist/version.js.map +1 -1
  38. package/package.json +1 -1
  39. package/src/Account/CurrentAccount.ts +11 -1
  40. package/src/Account/index.ts +137 -9
  41. package/src/ApiClient/WebSocketClient/events.ts +5 -1
  42. package/src/ApiClient/WebSocketClient/index.ts +15 -6
  43. package/src/ApiClient/WebSocketClient/messages.ts +2 -0
  44. package/src/Projects/Job.ts +90 -1
  45. package/src/Projects/Project.ts +134 -5
  46. package/src/Projects/createJobRequestMessage.ts +5 -1
  47. package/src/Projects/index.ts +87 -16
  48. package/src/Projects/types/RawProject.ts +121 -0
  49. package/src/Projects/types/events.ts +2 -0
  50. package/src/Projects/types/index.ts +4 -0
  51. package/src/lib/DataEntity.ts +3 -0
  52. package/src/lib/base64.ts +8 -4
  53. package/src/types/ErrorData.ts +1 -0
  54. package/src/version.ts +1 -1
@@ -14,9 +14,28 @@ import { ApiError, ApiReponse } from '../ApiClient';
14
14
  import { EstimationResponse } from './types/EstimationResponse';
15
15
  import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
16
16
  import getUUID from '../lib/getUUID';
17
+ import { RawProject } from './types/RawProject';
18
+ import ErrorData from '../types/ErrorData';
17
19
 
18
20
  const GARBAGE_COLLECT_TIMEOUT = 10000;
19
21
 
22
/**
 * Map a non-numeric job error code received from the server to a numeric SDK error code.
 *
 * Known codes map to 5001-5005; any other string maps to the generic 5000.
 *
 * @param code - String error code as sent by the server.
 * @returns Numeric code used as `ErrorData.code`.
 */
function mapErrorCodes(code: string): number {
  switch (code) {
    case 'serverRestarting':
      return 5001;
    case 'workerDisconnected':
      return 5002;
    case 'jobTimedOut':
      return 5003;
    case 'artistCanceled':
      return 5004;
    // NOTE(review): this label used to be the single garbled character 'ç', which no server
    // code can plausibly match (making 5005 unreachable). 'workerCancelled' is assumed to be
    // the intended code — confirm against the server's error codes.
    case 'workerCancelled':
      return 5005;
    default:
      return 5000;
  }
}
38
+
20
39
  class ProjectsApi extends ApiGroup<ProjectApiEvents> {
21
40
  private _availableModels: AvailableModel[] = [];
22
41
  private projects: Project[] = [];
@@ -28,6 +47,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
28
47
  constructor(config: ApiConfig) {
29
48
  super(config);
30
49
  // Listen to server events and emit them as project and job events
50
+ this.client.socket.on('changeNetwork', this.handleChangeNetwork.bind(this));
31
51
  this.client.socket.on('swarmModels', this.handleSwarmModels.bind(this));
32
52
  this.client.socket.on('jobState', this.handleJobState.bind(this));
33
53
  this.client.socket.on('jobProgress', this.handleJobProgress.bind(this));
@@ -40,6 +60,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
40
60
  this.on('job', this.handleJobEvent.bind(this));
41
61
  }
42
62
 
63
/**
 * Handler for the socket `changeNetwork` event: drops the cached model list and emits
 * `availableModels` with the (now empty) cached array.
 */
private handleChangeNetwork(): void {
  // Reset first so listeners receive the very array instance that is now cached.
  this._availableModels = [];
  this.emit('availableModels', this._availableModels);
}
67
+
43
68
  private handleSwarmModels(data: SocketEventMap['swarmModels']) {
44
69
  const modelIndex = models.reduce((acc: Record<string, any>, model) => {
45
70
  acc[model.modelId] = model;
@@ -66,10 +91,20 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
66
91
  this.emit('project', { type: 'completed', projectId: data.jobID });
67
92
  return;
68
93
  case 'initiatingModel':
69
- this.emit('job', { type: 'initiating', projectId: data.jobID, jobId: data.imgID });
94
+ this.emit('job', {
95
+ type: 'initiating',
96
+ projectId: data.jobID,
97
+ jobId: data.imgID,
98
+ workerName: data.workerName
99
+ });
70
100
  return;
71
101
  case 'jobStarted': {
72
- this.emit('job', { type: 'started', projectId: data.jobID, jobId: data.imgID });
102
+ this.emit('job', {
103
+ type: 'started',
104
+ projectId: data.jobID,
105
+ jobId: data.imgID,
106
+ workerName: data.workerName
107
+ });
73
108
  return;
74
109
  }
75
110
  }
@@ -127,14 +162,25 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
127
162
  }
128
163
 
129
164
  private handleJobError(data: JobErrorData) {
165
+ const errorCode = Number(data.error);
166
+ let error: ErrorData;
167
+ if (!isNaN(errorCode)) {
168
+ error = {
169
+ code: errorCode,
170
+ message: data.error_message
171
+ };
172
+ } else {
173
+ error = {
174
+ code: mapErrorCodes(data.error as string),
175
+ originalCode: data.error?.toString(),
176
+ message: data.error_message
177
+ };
178
+ }
130
179
  if (!data.imgID) {
131
180
  this.emit('project', {
132
181
  type: 'error',
133
182
  projectId: data.jobID,
134
- error: {
135
- code: Number(data.error),
136
- message: data.error_message
137
- }
183
+ error
138
184
  });
139
185
  return;
140
186
  }
@@ -142,14 +188,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
142
188
  type: 'error',
143
189
  projectId: data.jobID,
144
190
  jobId: data.imgID,
145
- error: {
146
- code: Number(data.error),
147
- message: data.error_message
148
- }
191
+ error: error
149
192
  });
150
193
  }
151
194
 
152
- handleProjectEvent(event: ProjectEvent) {
195
+ private handleProjectEvent(event: ProjectEvent) {
153
196
  let project = this.projects.find((p) => p.id === event.projectId);
154
197
  if (!project) {
155
198
  return;
@@ -172,14 +215,18 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
172
215
  error: event.error
173
216
  });
174
217
  }
175
- if (project.status === 'completed' || project.status === 'failed') {
218
+ if (project.finished) {
219
+ // Sync project data with the server and remove it from the list after some time
220
+ project._syncToServer().catch((e) => {
221
+ this.client.logger.error(e);
222
+ });
176
223
  setTimeout(() => {
177
224
  this.projects = this.projects.filter((p) => p.id !== event.projectId);
178
225
  }, GARBAGE_COLLECT_TIMEOUT);
179
226
  }
180
227
  }
181
228
 
182
- handleJobEvent(event: JobEvent) {
229
+ private handleJobEvent(event: JobEvent) {
183
230
  let project = this.projects.find((p) => p.id === event.projectId);
184
231
  if (!project) {
185
232
  return;
@@ -188,6 +235,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
188
235
  if (!job) {
189
236
  job = project._addJob({
190
237
  id: event.jobId,
238
+ projectId: event.projectId,
191
239
  status: 'pending',
192
240
  step: 0,
193
241
  stepCount: project.params.steps
@@ -195,10 +243,10 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
195
243
  }
196
244
  switch (event.type) {
197
245
  case 'initiating':
198
- job._update({ status: 'initiating' });
246
+ job._update({ status: 'initiating', workerName: event.workerName });
199
247
  break;
200
248
  case 'started':
201
- job._update({ status: 'processing' });
249
+ job._update({ status: 'processing', workerName: event.workerName });
202
250
  break;
203
251
  case 'progress':
204
252
  job._update({
@@ -268,7 +316,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
268
316
  * @param data
269
317
  */
270
318
  async create(data: ProjectParams): Promise<Project> {
271
- const project = new Project({ ...data });
319
+ const project = new Project({ ...data }, { api: this, logger: this.client.logger });
272
320
  if (data.startingImage) {
273
321
  await this.uploadGuideImage(project.id, data.startingImage);
274
322
  }
@@ -278,6 +326,19 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
278
326
  return project;
279
327
  }
280
328
 
329
/**
 * Get a project by id. The server returns project data only once the project is
 * completed or failed; while the project is still processing, the request throws a 404 error.
 * @internal
 * @param projectId - Project id. It is URL-encoded so it is always sent as a single path segment.
 */
async get(projectId: string) {
  // Encode the id so characters such as '/', '?' or '#' cannot change the requested path.
  const { data } = await this.client.rest.get<ApiReponse<RawProject>>(
    `/v1/projects/${encodeURIComponent(projectId)}`
  );
  return data;
}
341
+
281
342
  private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
282
343
  const imageId = getUUID();
283
344
  const presignedUrl = await this.uploadUrl({
@@ -320,6 +381,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
320
381
  };
321
382
  }
322
383
 
384
+ /**
385
+ * Get upload URL for image
386
+ * @internal
387
+ * @param params
388
+ */
323
389
  async uploadUrl(params: ImageUrlParams) {
324
390
  const r = await this.client.rest.get<ApiReponse<{ uploadUrl: string }>>(
325
391
  `/v1/image/uploadUrl`,
@@ -328,6 +394,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
328
394
  return r.data.uploadUrl;
329
395
  }
330
396
 
397
+ /**
398
+ * Get download URL for image
399
+ * @internal
400
+ * @param params
401
+ */
331
402
  async downloadUrl(params: ImageUrlParams) {
332
403
  const r = await this.client.rest.get<ApiReponse<{ downloadUrl: string }>>(
333
404
  `/v1/image/downloadUrl`,
@@ -0,0 +1,121 @@
1
+ import { SupernetType } from '../../ApiClient/WebSocketClient/types';
2
+
3
/**
 * Project record as returned by the REST endpoint `GET /v1/projects/:id`.
 */
export interface RawProject {
  id: string;
  SID: number;
  /** Account that created the project. */
  artist: Account;
  model: ProjectModel;
  imageCount: number;
  stepCount: number;
  previewCount: number;
  hasGuideImage: boolean;
  denoiseStrength: string;
  costEstimate: CostEstimate;
  costActual: CostActual;
  // NOTE(review): timestamps are plain numbers — presumably epoch-based; confirm the unit.
  createTime: number;
  updateTime: number;
  endTime: number;
  status: RawProjectStatus;
  reason: 'allJobsCompleted' | 'artistCanceled';
  network: SupernetType;
  txId: string;
  workerJobs: RawJob[];
  completedWorkerJobs: RawJob[];
}

/** Project status values as reported by the server in `RawProject.status`. */
type RawProjectStatus =
  | 'pending'
  | 'active'
  | 'assigned'
  | 'progress'
  | 'errored'
  | 'completed'
  | 'cancelled';

/** Account reference used for both the project artist and the job worker; every field is optional. */
export interface Account {
  id?: string;
  clientSID?: number;
  address?: string;
  addressSID?: number;
  name?: string;
  username?: string;
}

/** Single worker job belonging to a `RawProject`. */
export interface RawJob {
  id: string;
  SID: string;
  imgID?: string;
  /** Account of the worker that processed the job. */
  worker: Account;
  createTime: number;
  /** `null` until the job has started. NOTE(review): inferred from the nullable type — confirm. */
  startTime: number | null;
  updateTime: number;
  endTime: number;
  status: WorkerJobStatus;
  reason: WorkerJobReason;
  performedSteps: number;
  triggeredNSFWFilter: boolean;
  seedUsed: number;
  costActual: CostActual;
  network: SupernetType;
  txId?: string;
}

/**
 * Cost breakdown. Amounts are typed as strings — presumably decimal numbers serialized by the
 * server (TODO confirm).
 */
export interface CostActual {
  costInRenderSec: string;
  costInUSD: string;
  costInToken: string;
  calculatedStepCount?: number;
}

/** Reason values a worker job can report in `RawJob.reason`. */
export type WorkerJobReason =
  | 'artistCanceled'
  | 'artistDisconnected'
  | 'genfailure'
  | 'imgUploadFailure'
  | 'jobCompleted'
  | 'jobTimedOut'
  | 'workerDisconnected'
  | 'workerReconnected';

/** Status values a worker job can report in `RawJob.status`. */
export type WorkerJobStatus =
  | 'created'
  | 'queued'
  | 'assigned'
  | 'initiatingModel'
  | 'jobStarted'
  | 'jobProgress'
  | 'jobCompleted'
  | 'jobError';

/** Cost estimate attached to a project: the pricing rate plus the computed quote. */
export interface CostEstimate {
  rate: Rate;
  quote: Quote;
}

/** Quoted costs for a single job and for the whole project. */
export interface Quote {
  model: QuoteModel;
  job: CostActual;
  project: CostActual;
}

/** Per-model pricing parameters used in a quote; values are typed as strings. */
export interface QuoteModel {
  weight: string;
  secPerStep: string;
  secPerPreview: string;
  secForCN: string;
}

/** Pricing rate used for a cost estimate. */
export interface Rate {
  costPerBaseHQRenderInUSD: string;
  // NOTE(review): the 'Marke' spelling presumably mirrors the server payload key; do not rename
  // it without a matching server change.
  tokenMarkePriceUSD: string;
  costPerRenderSecUSD: string;
  costPerRenderSecToken: string;
  network: SupernetType;
  networkCostMultiplier: string;
}

/** Model reference attached to a project. */
export interface ProjectModel {
  id: string;
  SID: number;
  name: string;
}
@@ -28,10 +28,12 @@ export interface JobEventBase {
28
28
 
29
29
  export interface JobInitiating extends JobEventBase {
30
30
  type: 'initiating';
31
+ workerName: string;
31
32
  }
32
33
 
33
34
  export interface JobStarted extends JobEventBase {
34
35
  type: 'started';
36
+ workerName: string;
35
37
  }
36
38
 
37
39
  export interface JobProgress extends JobEventBase {
@@ -68,6 +68,10 @@ export interface ProjectParams {
68
68
  * Guidance scale. For most Stable Diffusion models, optimal value is 7.5
69
69
  */
70
70
  guidance: number;
71
+ /**
72
+ * Override current network type. Default value can be read from `client.account.currentAccount.network`
73
+ */
74
+ network?: SupernetType;
71
75
  /**
72
76
  * Disable NSFW filter for Project. Default is false, meaning NSFW filter is enabled.
73
77
  * If image triggers NSFW filter, it will not be available for download.
@@ -11,6 +11,8 @@ export interface EntityEvents {
11
11
  abstract class DataEntity<D, E extends EntityEvents = EntityEvents> extends TypedEventEmitter<E> {
12
12
  protected data: D;
13
13
 
14
+ protected lastUpdated: Date = new Date();
15
+
14
16
  constructor(data: D) {
15
17
  super();
16
18
  this.data = data;
@@ -24,6 +26,7 @@ abstract class DataEntity<D, E extends EntityEvents = EntityEvents> extends Type
24
26
  //@ts-ignore
25
27
  const changedKeys = Object.keys(delta).filter((key) => this.data[key] !== delta[key]);
26
28
  this.data = { ...this.data, ...delta };
29
+ this.lastUpdated = new Date();
27
30
  this.emit('updated', changedKeys);
28
31
  }
29
32
 
package/src/lib/base64.ts CHANGED
@@ -1,9 +1,13 @@
1
- import isNodejs from './isNodejs';
2
-
3
1
  export function base64Encode(str: string): string {
4
- return isNodejs ? Buffer.from(str).toString('base64') : btoa(str);
2
+ const encoder = new TextEncoder();
3
+ const uint8Array = encoder.encode(str);
4
+ const binaryString = String.fromCharCode(...uint8Array);
5
+ return btoa(binaryString);
5
6
  }
6
7
 
7
8
  export function base64Decode(str: string): string {
8
- return isNodejs ? Buffer.from(str, 'base64').toString() : atob(str);
9
+ const binaryString = atob(str);
10
+ const binaryArray = Uint8Array.from(binaryString, (char) => char.charCodeAt(0));
11
+ const decoder = new TextDecoder();
12
+ return decoder.decode(binaryArray);
9
13
  }
@@ -1,5 +1,6 @@
1
1
  interface ErrorData {
2
2
  code: number;
3
+ originalCode?: string;
3
4
  message: string;
4
5
  }
5
6
 
package/src/version.ts CHANGED
@@ -1 +1 @@
1
- export const LIB_VERSION = "0.3.3";
1
/**
 * SDK version string.
 * NOTE(review): "aplha" matches the published package version (0.4.0-aplha.10). It presumably
 * mirrors package.json, so keep the two in sync rather than correcting the spelling here.
 */
export const LIB_VERSION = "0.4.0-aplha.10";