@sogni-ai/sogni-client 1.0.2 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/README.md +126 -0
- package/dist/Account/CurrentAccount.d.ts +4 -5
- package/dist/Account/CurrentAccount.js +9 -29
- package/dist/Account/CurrentAccount.js.map +1 -1
- package/dist/Account/index.d.ts +5 -3
- package/dist/Account/index.js +20 -9
- package/dist/Account/index.js.map +1 -1
- package/dist/Account/types.d.ts +2 -0
- package/dist/ApiClient/WebSocketClient/index.d.ts +4 -4
- package/dist/ApiClient/WebSocketClient/index.js +33 -49
- package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
- package/dist/ApiClient/WebSocketClient/messages.d.ts +7 -0
- package/dist/ApiClient/index.d.ts +4 -10
- package/dist/ApiClient/index.js +26 -17
- package/dist/ApiClient/index.js.map +1 -1
- package/dist/Projects/Project.d.ts +4 -0
- package/dist/Projects/Project.js +8 -0
- package/dist/Projects/Project.js.map +1 -1
- package/dist/Projects/createJobRequestMessage.js +12 -1
- package/dist/Projects/createJobRequestMessage.js.map +1 -1
- package/dist/Projects/index.d.ts +40 -2
- package/dist/Projects/index.js +146 -14
- package/dist/Projects/index.js.map +1 -1
- package/dist/Projects/types/index.d.ts +40 -0
- package/dist/lib/AuthManager.d.ts +51 -0
- package/dist/lib/AuthManager.js +157 -0
- package/dist/lib/AuthManager.js.map +1 -0
- package/dist/lib/Cache.d.ts +9 -0
- package/dist/lib/Cache.js +30 -0
- package/dist/lib/Cache.js.map +1 -0
- package/dist/lib/RestClient.d.ts +4 -7
- package/dist/lib/RestClient.js +7 -7
- package/dist/lib/RestClient.js.map +1 -1
- package/dist/lib/utils.d.ts +8 -0
- package/dist/lib/utils.js +20 -0
- package/dist/lib/utils.js.map +1 -0
- package/dist/version.d.ts +1 -1
- package/dist/version.js +1 -1
- package/dist/version.js.map +1 -1
- package/package.json +1 -1
- package/src/Account/CurrentAccount.ts +12 -33
- package/src/Account/index.ts +18 -8
- package/src/Account/types.ts +2 -0
- package/src/ApiClient/WebSocketClient/index.ts +16 -25
- package/src/ApiClient/WebSocketClient/messages.ts +8 -0
- package/src/ApiClient/index.ts +19 -27
- package/src/Projects/Project.ts +7 -0
- package/src/Projects/createJobRequestMessage.ts +14 -1
- package/src/Projects/index.ts +150 -8
- package/src/Projects/types/index.ts +42 -0
- package/src/lib/AuthManager.ts +172 -0
- package/src/lib/Cache.ts +36 -0
- package/src/lib/RestClient.ts +8 -13
- package/src/lib/utils.ts +17 -0
- package/src/version.ts +1 -1
- package/dist/Projects/models.json +0 -8906
- package/src/Projects/models.json +0 -8906
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
import { MessageType, SocketMessageMap } from './messages';
|
|
2
2
|
import { SocketEventMap } from './events';
|
|
3
|
-
import RestClient
|
|
3
|
+
import RestClient from '../../lib/RestClient';
|
|
4
4
|
import { SupernetType } from './types';
|
|
5
5
|
import WebSocket, { CloseEvent, ErrorEvent, MessageEvent } from 'isomorphic-ws';
|
|
6
6
|
import { base64Decode, base64Encode } from '../../lib/base64';
|
|
7
7
|
import isNodejs from '../../lib/isNodejs';
|
|
8
|
-
import Cookie from 'js-cookie';
|
|
9
8
|
import { LIB_VERSION } from '../../version';
|
|
10
9
|
import { Logger } from '../../lib/DefaultLogger';
|
|
10
|
+
import AuthManager from '../../lib/AuthManager';
|
|
11
|
+
|
|
12
|
+
const PROTOCOL_VERSION = '0.4.3';
|
|
11
13
|
|
|
12
14
|
const PING_INTERVAL = 15000;
|
|
13
15
|
|
|
@@ -18,34 +20,23 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
18
20
|
private _supernetType: SupernetType;
|
|
19
21
|
private _pingInterval: NodeJS.Timeout | null = null;
|
|
20
22
|
|
|
21
|
-
constructor(
|
|
23
|
+
constructor(
|
|
24
|
+
baseUrl: string,
|
|
25
|
+
auth: AuthManager,
|
|
26
|
+
appId: string,
|
|
27
|
+
supernetType: SupernetType,
|
|
28
|
+
logger: Logger
|
|
29
|
+
) {
|
|
22
30
|
const _baseUrl = new URL(baseUrl);
|
|
23
31
|
if (_baseUrl.protocol === 'wss:') {
|
|
24
32
|
_baseUrl.protocol = 'https:';
|
|
25
33
|
}
|
|
26
|
-
super(_baseUrl.toString(), logger);
|
|
34
|
+
super(_baseUrl.toString(), auth, logger);
|
|
27
35
|
this.appId = appId;
|
|
28
36
|
this.baseUrl = _baseUrl.toString();
|
|
29
37
|
this._supernetType = supernetType;
|
|
30
38
|
}
|
|
31
39
|
|
|
32
|
-
set auth(auth: AuthData | null) {
|
|
33
|
-
//In browser, set the cookie
|
|
34
|
-
if (!isNodejs) {
|
|
35
|
-
if (auth) {
|
|
36
|
-
Cookie.set('authorization', auth.token, {
|
|
37
|
-
domain: '.sogni.ai',
|
|
38
|
-
expires: 1
|
|
39
|
-
});
|
|
40
|
-
} else {
|
|
41
|
-
Cookie.remove('authorization', {
|
|
42
|
-
domain: '.sogni.ai'
|
|
43
|
-
});
|
|
44
|
-
}
|
|
45
|
-
}
|
|
46
|
-
this._auth = auth;
|
|
47
|
-
}
|
|
48
|
-
|
|
49
40
|
get supernetType(): SupernetType {
|
|
50
41
|
return this._supernetType;
|
|
51
42
|
}
|
|
@@ -54,11 +45,11 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
54
45
|
return !!this.socket;
|
|
55
46
|
}
|
|
56
47
|
|
|
57
|
-
connect() {
|
|
48
|
+
async connect() {
|
|
58
49
|
if (this.socket) {
|
|
59
50
|
this.disconnect();
|
|
60
51
|
}
|
|
61
|
-
const userAgent = `Sogni/${
|
|
52
|
+
const userAgent = `Sogni/${PROTOCOL_VERSION} (sogni-client) ${LIB_VERSION}`;
|
|
62
53
|
const url = new URL(this.baseUrl);
|
|
63
54
|
url.protocol = 'wss:';
|
|
64
55
|
url.searchParams.set('appId', this.appId);
|
|
@@ -71,7 +62,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
71
62
|
if (isNodejs) {
|
|
72
63
|
params = {
|
|
73
64
|
headers: {
|
|
74
|
-
Authorization: this.
|
|
65
|
+
Authorization: await this.auth.getToken(),
|
|
75
66
|
'User-Agent': userAgent
|
|
76
67
|
}
|
|
77
68
|
};
|
|
@@ -114,7 +105,7 @@ class WebSocketClient extends RestClient<SocketEventMap> {
|
|
|
114
105
|
}
|
|
115
106
|
|
|
116
107
|
switchNetwork(supernetType: SupernetType): Promise<SupernetType> {
|
|
117
|
-
return new Promise<SupernetType>(async (resolve
|
|
108
|
+
return new Promise<SupernetType>(async (resolve) => {
|
|
118
109
|
this.once('changeNetwork', ({ network }) => {
|
|
119
110
|
this._supernetType = network;
|
|
120
111
|
resolve(network);
|
|
@@ -1,8 +1,16 @@
|
|
|
1
1
|
import { JobRequestRaw } from '../../Projects/createJobRequestMessage';
|
|
2
2
|
import { SupernetType } from './types';
|
|
3
3
|
|
|
4
|
+
export interface JobErrorMessage {
|
|
5
|
+
jobID: string;
|
|
6
|
+
error: 'artistCanceled';
|
|
7
|
+
error_message: 'artistCanceled';
|
|
8
|
+
isFromWorker: false;
|
|
9
|
+
}
|
|
10
|
+
|
|
4
11
|
export interface SocketMessageMap {
|
|
5
12
|
jobRequest: JobRequestRaw;
|
|
13
|
+
jobError: JobErrorMessage;
|
|
6
14
|
changeNetwork: SupernetType;
|
|
7
15
|
}
|
|
8
16
|
|
package/src/ApiClient/index.ts
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import RestClient from '../lib/RestClient';
|
|
2
2
|
import WebSocketClient from './WebSocketClient';
|
|
3
|
-
import { jwtDecode } from 'jwt-decode';
|
|
4
3
|
import TypedEventEmitter from '../lib/TypedEventEmitter';
|
|
5
4
|
import { ApiClientEvents } from './events';
|
|
6
5
|
import { ServerConnectData, ServerDisconnectData } from './WebSocketClient/events';
|
|
@@ -8,6 +7,7 @@ import { isNotRecoverable } from './WebSocketClient/ErrorCode';
|
|
|
8
7
|
import { JSONValue } from '../types/json';
|
|
9
8
|
import { SupernetType } from './WebSocketClient/types';
|
|
10
9
|
import { Logger } from '../lib/DefaultLogger';
|
|
10
|
+
import AuthManager, { Tokens } from '../lib/AuthManager';
|
|
11
11
|
|
|
12
12
|
const WS_RECONNECT_ATTEMPTS = 5;
|
|
13
13
|
|
|
@@ -33,21 +33,12 @@ export class ApiError extends Error {
|
|
|
33
33
|
}
|
|
34
34
|
}
|
|
35
35
|
|
|
36
|
-
/**
|
|
37
|
-
* @inline
|
|
38
|
-
*/
|
|
39
|
-
interface AuthData {
|
|
40
|
-
token: string;
|
|
41
|
-
walletAddress: string;
|
|
42
|
-
expiresAt: Date;
|
|
43
|
-
}
|
|
44
|
-
|
|
45
36
|
class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
46
37
|
readonly appId: string;
|
|
47
38
|
readonly logger: Logger;
|
|
48
39
|
private _rest: RestClient;
|
|
49
40
|
private _socket: WebSocketClient;
|
|
50
|
-
private _auth:
|
|
41
|
+
private _auth: AuthManager;
|
|
51
42
|
private _reconnectAttempts = WS_RECONNECT_ATTEMPTS;
|
|
52
43
|
|
|
53
44
|
constructor(
|
|
@@ -60,18 +51,21 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
60
51
|
super();
|
|
61
52
|
this.appId = appId;
|
|
62
53
|
this.logger = logger;
|
|
63
|
-
this.
|
|
64
|
-
this.
|
|
54
|
+
this._auth = new AuthManager(baseUrl, logger);
|
|
55
|
+
this._rest = new RestClient(baseUrl, this._auth, logger);
|
|
56
|
+
this._socket = new WebSocketClient(socketUrl, this._auth, appId, networkType, logger);
|
|
57
|
+
|
|
58
|
+
this._auth.on('refreshFailed', this.handleRefreshFailed.bind(this));
|
|
65
59
|
this._socket.on('connected', this.handleSocketConnect.bind(this));
|
|
66
60
|
this._socket.on('disconnected', this.handleSocketDisconnect.bind(this));
|
|
67
61
|
}
|
|
68
62
|
|
|
69
63
|
get isAuthenticated(): boolean {
|
|
70
|
-
return
|
|
64
|
+
return this.auth.isAuthenticated;
|
|
71
65
|
}
|
|
72
66
|
|
|
73
|
-
get auth():
|
|
74
|
-
return this._auth
|
|
67
|
+
get auth(): AuthManager {
|
|
68
|
+
return this._auth;
|
|
75
69
|
}
|
|
76
70
|
|
|
77
71
|
get socket(): WebSocketClient {
|
|
@@ -82,20 +76,13 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
82
76
|
return this._rest;
|
|
83
77
|
}
|
|
84
78
|
|
|
85
|
-
authenticate(
|
|
86
|
-
|
|
87
|
-
this.
|
|
88
|
-
token,
|
|
89
|
-
walletAddress: decoded.addr,
|
|
90
|
-
expiresAt: new Date(decoded.exp * 1000)
|
|
91
|
-
};
|
|
92
|
-
this.rest.auth = { token };
|
|
93
|
-
this.socket.auth = { token };
|
|
94
|
-
this.socket.connect();
|
|
79
|
+
async authenticate(tokens: Tokens) {
|
|
80
|
+
await this.auth.setTokens(tokens);
|
|
81
|
+
await this.socket.connect();
|
|
95
82
|
}
|
|
96
83
|
|
|
97
84
|
removeAuth() {
|
|
98
|
-
this.
|
|
85
|
+
this.auth.clear();
|
|
99
86
|
this.socket.disconnect();
|
|
100
87
|
}
|
|
101
88
|
|
|
@@ -112,6 +99,7 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
112
99
|
return;
|
|
113
100
|
}
|
|
114
101
|
if (this._reconnectAttempts <= 0) {
|
|
102
|
+
this.removeAuth();
|
|
115
103
|
this.emit('disconnected', data);
|
|
116
104
|
this._reconnectAttempts = WS_RECONNECT_ATTEMPTS;
|
|
117
105
|
return;
|
|
@@ -119,6 +107,10 @@ class ApiClient extends TypedEventEmitter<ApiClientEvents> {
|
|
|
119
107
|
this._reconnectAttempts--;
|
|
120
108
|
setTimeout(() => this.socket.connect(), 1000);
|
|
121
109
|
}
|
|
110
|
+
|
|
111
|
+
handleRefreshFailed() {
|
|
112
|
+
this.removeAuth();
|
|
113
|
+
}
|
|
122
114
|
}
|
|
123
115
|
|
|
124
116
|
export default ApiClient;
|
package/src/Projects/Project.ts
CHANGED
|
@@ -164,6 +164,13 @@ class Project extends DataEntity<ProjectData, ProjectEventMap> {
|
|
|
164
164
|
});
|
|
165
165
|
}
|
|
166
166
|
|
|
167
|
+
/**
|
|
168
|
+
* Cancel the project. This will cancel all jobs in the project.
|
|
169
|
+
*/
|
|
170
|
+
async cancel() {
|
|
171
|
+
await this._api.cancel(this.id);
|
|
172
|
+
}
|
|
173
|
+
|
|
167
174
|
/**
|
|
168
175
|
* Find a job by id
|
|
169
176
|
* @param id
|
|
@@ -65,6 +65,14 @@ function getTemplate() {
|
|
|
65
65
|
};
|
|
66
66
|
}
|
|
67
67
|
|
|
68
|
+
function validateSize(value: any): number {
|
|
69
|
+
const size = Number(value);
|
|
70
|
+
if (isNaN(size) || size < 256 || size > 2048) {
|
|
71
|
+
throw new Error('Width and height must be numbers between 256 and 2048');
|
|
72
|
+
}
|
|
73
|
+
return size;
|
|
74
|
+
}
|
|
75
|
+
|
|
68
76
|
function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
69
77
|
const template = getTemplate();
|
|
70
78
|
const jobRequest: Record<string, any> = {
|
|
@@ -84,7 +92,8 @@ function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
|
84
92
|
strengthIsEnabled: !!params.startingImage,
|
|
85
93
|
strength: !!params.startingImage
|
|
86
94
|
? 1 - (Number(params.startingImageStrength) || 0.5)
|
|
87
|
-
: undefined
|
|
95
|
+
: undefined,
|
|
96
|
+
sizePreset: params.sizePreset
|
|
88
97
|
}
|
|
89
98
|
],
|
|
90
99
|
previews: params.numberOfPreviews || 0,
|
|
@@ -95,6 +104,10 @@ function createJobRequestMessage(id: string, params: ProjectParams) {
|
|
|
95
104
|
if (params.network) {
|
|
96
105
|
jobRequest.network = params.network;
|
|
97
106
|
}
|
|
107
|
+
if (params.sizePreset === 'custom') {
|
|
108
|
+
jobRequest.keyFrames[0].width = validateSize(params.width);
|
|
109
|
+
jobRequest.keyFrames[0].height = validateSize(params.height);
|
|
110
|
+
}
|
|
98
111
|
return jobRequest;
|
|
99
112
|
}
|
|
100
113
|
|
package/src/Projects/index.ts
CHANGED
|
@@ -1,6 +1,12 @@
|
|
|
1
1
|
import ApiGroup, { ApiConfig } from '../ApiGroup';
|
|
2
|
-
import
|
|
3
|
-
|
|
2
|
+
import {
|
|
3
|
+
AvailableModel,
|
|
4
|
+
EstimateRequest,
|
|
5
|
+
ImageUrlParams,
|
|
6
|
+
ProjectParams,
|
|
7
|
+
SizePreset,
|
|
8
|
+
SupportedModel
|
|
9
|
+
} from './types';
|
|
4
10
|
import {
|
|
5
11
|
JobErrorData,
|
|
6
12
|
JobProgressData,
|
|
@@ -16,8 +22,12 @@ import { JobEvent, ProjectApiEvents, ProjectEvent } from './types/events';
|
|
|
16
22
|
import getUUID from '../lib/getUUID';
|
|
17
23
|
import { RawProject } from './types/RawProject';
|
|
18
24
|
import ErrorData from '../types/ErrorData';
|
|
25
|
+
import { SupernetType } from '../ApiClient/WebSocketClient/types';
|
|
26
|
+
import Cache from '../lib/Cache';
|
|
19
27
|
|
|
28
|
+
const sizePresetCache = new Cache<SizePreset[]>(10 * 60 * 1000);
|
|
20
29
|
const GARBAGE_COLLECT_TIMEOUT = 10000;
|
|
30
|
+
const MODELS_REFRESH_INTERVAL = 1000 * 60 * 60 * 24; // 24 hours
|
|
21
31
|
|
|
22
32
|
function mapErrorCodes(code: string): number {
|
|
23
33
|
switch (code) {
|
|
@@ -39,6 +49,10 @@ function mapErrorCodes(code: string): number {
|
|
|
39
49
|
class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
40
50
|
private _availableModels: AvailableModel[] = [];
|
|
41
51
|
private projects: Project[] = [];
|
|
52
|
+
private _supportedModels: { data: SupportedModel[] | null; updatedAt: Date } = {
|
|
53
|
+
data: null,
|
|
54
|
+
updatedAt: new Date(0)
|
|
55
|
+
};
|
|
42
56
|
|
|
43
57
|
get availableModels() {
|
|
44
58
|
return this._availableModels;
|
|
@@ -65,14 +79,20 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
65
79
|
this.emit('availableModels', this._availableModels);
|
|
66
80
|
}
|
|
67
81
|
|
|
68
|
-
private handleSwarmModels(data: SocketEventMap['swarmModels']) {
|
|
69
|
-
|
|
70
|
-
|
|
82
|
+
private async handleSwarmModels(data: SocketEventMap['swarmModels']) {
|
|
83
|
+
let models: SupportedModel[] = [];
|
|
84
|
+
try {
|
|
85
|
+
models = await this.getSupportedModels();
|
|
86
|
+
} catch (e) {
|
|
87
|
+
this.client.logger.error(e);
|
|
88
|
+
}
|
|
89
|
+
const modelIndex = models.reduce((acc: Record<string, SupportedModel>, model) => {
|
|
90
|
+
acc[model.id] = model;
|
|
71
91
|
return acc;
|
|
72
92
|
}, {});
|
|
73
93
|
this._availableModels = Object.entries(data).map(([id, workerCount]) => ({
|
|
74
94
|
id,
|
|
75
|
-
name: modelIndex[id]?.
|
|
95
|
+
name: modelIndex[id]?.name || id.replace(/-/g, ' '),
|
|
76
96
|
workerCount
|
|
77
97
|
}));
|
|
78
98
|
this.emit('availableModels', this._availableModels);
|
|
@@ -339,6 +359,38 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
339
359
|
return data.project;
|
|
340
360
|
}
|
|
341
361
|
|
|
362
|
+
/**
|
|
363
|
+
* Cancel a project by id. This cancels all jobs in the project and marks the project as canceled.
|
|
364
|
+
* The client may still receive job events for the canceled jobs, since cancellation takes some time, but they will
|
|
365
|
+
* be ignored
|
|
366
|
+
* @param projectId
|
|
367
|
+
**/
|
|
368
|
+
async cancel(projectId: string) {
|
|
369
|
+
await this.client.socket.send('jobError', {
|
|
370
|
+
jobID: projectId,
|
|
371
|
+
error: 'artistCanceled',
|
|
372
|
+
error_message: 'artistCanceled',
|
|
373
|
+
isFromWorker: false
|
|
374
|
+
});
|
|
375
|
+
const project = this.projects.find((p) => p.id === projectId);
|
|
376
|
+
if (!project) {
|
|
377
|
+
return;
|
|
378
|
+
}
|
|
379
|
+
// Remove project from the list to stop tracking it
|
|
380
|
+
this.projects = this.projects.filter((p) => p.id !== projectId);
|
|
381
|
+
|
|
382
|
+
// Cancel all jobs in the project
|
|
383
|
+
project.jobs.forEach((job) => {
|
|
384
|
+
if (!job.finished) {
|
|
385
|
+
job._update({ status: 'canceled' });
|
|
386
|
+
}
|
|
387
|
+
});
|
|
388
|
+
// If project is still in processing, mark it as canceled
|
|
389
|
+
if (!project.finished) {
|
|
390
|
+
project._update({ status: 'canceled' });
|
|
391
|
+
}
|
|
392
|
+
}
|
|
393
|
+
|
|
342
394
|
private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
|
|
343
395
|
const imageId = getUUID();
|
|
344
396
|
const presignedUrl = await this.uploadUrl({
|
|
@@ -370,10 +422,32 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
370
422
|
stepCount,
|
|
371
423
|
previewCount,
|
|
372
424
|
cnEnabled,
|
|
373
|
-
startingImageStrength
|
|
425
|
+
startingImageStrength,
|
|
426
|
+
width,
|
|
427
|
+
height,
|
|
428
|
+
sizePreset
|
|
374
429
|
}: EstimateRequest) {
|
|
430
|
+
const pathParams = [
|
|
431
|
+
network,
|
|
432
|
+
model,
|
|
433
|
+
imageCount,
|
|
434
|
+
stepCount,
|
|
435
|
+
previewCount,
|
|
436
|
+
cnEnabled ? 1 : 0,
|
|
437
|
+
startingImageStrength ? 1 - startingImageStrength : 0
|
|
438
|
+
];
|
|
439
|
+
if (sizePreset) {
|
|
440
|
+
const presets = await this.getSizePresets(network, model);
|
|
441
|
+
const preset = presets.find((p) => p.id === sizePreset);
|
|
442
|
+
if (!preset) {
|
|
443
|
+
throw new Error('Invalid size preset');
|
|
444
|
+
}
|
|
445
|
+
pathParams.push(preset.width, preset.height);
|
|
446
|
+
} else if (width && height) {
|
|
447
|
+
pathParams.push(width, height);
|
|
448
|
+
}
|
|
375
449
|
const r = await this.client.socket.get<EstimationResponse>(
|
|
376
|
-
`/api/v1/job/estimate/${
|
|
450
|
+
`/api/v1/job/estimate/${pathParams.join('/')}`
|
|
377
451
|
);
|
|
378
452
|
return {
|
|
379
453
|
token: r.quote.project.costInToken,
|
|
@@ -406,6 +480,74 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
|
|
|
406
480
|
);
|
|
407
481
|
return r.data.downloadUrl;
|
|
408
482
|
}
|
|
483
|
+
|
|
484
|
+
async getSupportedModels(forceRefresh = false) {
|
|
485
|
+
if (
|
|
486
|
+
this._supportedModels.data &&
|
|
487
|
+
!forceRefresh &&
|
|
488
|
+
Date.now() - this._supportedModels.updatedAt.getTime() < MODELS_REFRESH_INTERVAL
|
|
489
|
+
) {
|
|
490
|
+
return this._supportedModels.data;
|
|
491
|
+
}
|
|
492
|
+
const models = await this.client.socket.get<SupportedModel[]>(`/api/v1/models/list`);
|
|
493
|
+
this._supportedModels = { data: models, updatedAt: new Date() };
|
|
494
|
+
return models;
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
/**
|
|
498
|
+
* Get supported size presets for the model and network. Size presets are cached for 10 minutes.
|
|
499
|
+
*
|
|
500
|
+
* @example
|
|
501
|
+
* ```ts
|
|
502
|
+
* const presets = await client.projects.getSizePresets('fast', 'flux1-schnell-fp8');
|
|
503
|
+
* console.log(presets);
|
|
504
|
+
* ```
|
|
505
|
+
*
|
|
506
|
+
* @param network - 'fast' or 'relaxed'
|
|
507
|
+
* @param modelId - model id (e.g. 'flux1-schnell-fp8')
|
|
508
|
+
* @param forceRefresh - force refresh cache
|
|
509
|
+
* @returns {Promise<{
|
|
510
|
+
* label: string;
|
|
511
|
+
* id: string;
|
|
512
|
+
* width: number;
|
|
513
|
+
* height: number;
|
|
514
|
+
* ratio: string;
|
|
515
|
+
* aspect: string;
|
|
516
|
+
* }[]>}
|
|
517
|
+
*/
|
|
518
|
+
async getSizePresets(network: SupernetType, modelId: string, forceRefresh = false) {
|
|
519
|
+
const key = `${network}-${modelId}`;
|
|
520
|
+
const cached = sizePresetCache.read(key);
|
|
521
|
+
if (cached && !forceRefresh) {
|
|
522
|
+
return cached;
|
|
523
|
+
}
|
|
524
|
+
const data = await this.client.socket.get<SizePreset[]>(
|
|
525
|
+
`/api/v1/size-presets/network/${network}/model/${modelId}`
|
|
526
|
+
);
|
|
527
|
+
sizePresetCache.write(key, data);
|
|
528
|
+
return data;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
/**
|
|
532
|
+
* Get available models and their worker counts. Normally, you would get the list once you connect
|
|
533
|
+
* to the server, but you can also call this method to get the list of available models manually.
|
|
534
|
+
* @param network
|
|
535
|
+
*/
|
|
536
|
+
async getAvailableModels(network: SupernetType): Promise<AvailableModel[]> {
|
|
537
|
+
const workersByModelSid = await this.client.socket.get<Record<string, number>>(
|
|
538
|
+
`/api/v1/status/network/${network}/models`
|
|
539
|
+
);
|
|
540
|
+
const supportedModels = await this.getSupportedModels();
|
|
541
|
+
return Object.entries(workersByModelSid).map(([sid, workerCount]) => {
|
|
542
|
+
const SID = Number(sid);
|
|
543
|
+
const model = supportedModels.find((m) => m.SID === SID);
|
|
544
|
+
return {
|
|
545
|
+
id: model?.id || sid,
|
|
546
|
+
name: model?.name || sid.replace(/-/g, ' '),
|
|
547
|
+
workerCount
|
|
548
|
+
};
|
|
549
|
+
});
|
|
550
|
+
}
|
|
409
551
|
}
|
|
410
552
|
|
|
411
553
|
export default ProjectsApi;
|
|
@@ -1,11 +1,26 @@
|
|
|
1
1
|
import { SupernetType } from '../../ApiClient/WebSocketClient/types';
|
|
2
2
|
|
|
3
|
+
export interface SupportedModel {
|
|
4
|
+
id: string;
|
|
5
|
+
name: string;
|
|
6
|
+
SID: number;
|
|
7
|
+
}
|
|
8
|
+
|
|
3
9
|
export interface AvailableModel {
|
|
4
10
|
id: string;
|
|
5
11
|
name: string;
|
|
6
12
|
workerCount: number;
|
|
7
13
|
}
|
|
8
14
|
|
|
15
|
+
export interface SizePreset {
|
|
16
|
+
label: string;
|
|
17
|
+
id: string;
|
|
18
|
+
width: number;
|
|
19
|
+
height: number;
|
|
20
|
+
ratio: string;
|
|
21
|
+
aspect: string;
|
|
22
|
+
}
|
|
23
|
+
|
|
9
24
|
export interface AiModel {
|
|
10
25
|
isSD3: boolean;
|
|
11
26
|
modelShortName: string;
|
|
@@ -108,6 +123,19 @@ export interface ProjectParams {
|
|
|
108
123
|
* Time step spacing method
|
|
109
124
|
*/
|
|
110
125
|
timeStepSpacing?: TimeStepSpacing;
|
|
126
|
+
/**
|
|
127
|
+
* Size preset ID to use. You can query available size presets
|
|
128
|
+
* from `client.projects.getSizePresets(network, modelId)`
|
|
129
|
+
*/
|
|
130
|
+
sizePreset?: 'custom' | string;
|
|
131
|
+
/**
|
|
132
|
+
* Output image width. Only used if `sizePreset` is "custom"
|
|
133
|
+
*/
|
|
134
|
+
width?: number;
|
|
135
|
+
/**
|
|
136
|
+
* Output image height. Only used if `sizePreset` is "custom"
|
|
137
|
+
*/
|
|
138
|
+
height?: number;
|
|
111
139
|
}
|
|
112
140
|
|
|
113
141
|
export type ImageUrlParams = {
|
|
@@ -147,4 +175,18 @@ export interface EstimateRequest {
|
|
|
147
175
|
* How strong effect of starting image should be. From 0 to 1, default 0.5
|
|
148
176
|
*/
|
|
149
177
|
startingImageStrength?: number;
|
|
178
|
+
/**
|
|
179
|
+
* Size preset ID
|
|
180
|
+
*/
|
|
181
|
+
sizePreset?: string;
|
|
182
|
+
/**
|
|
183
|
+
* Image width, used together with `height` when `sizePreset` is not provided
|
|
184
|
+
* @internal
|
|
185
|
+
*/
|
|
186
|
+
width?: number;
|
|
187
|
+
/**
|
|
188
|
+
* Image height, used together with `width` when `sizePreset` is not provided
|
|
189
|
+
* @internal
|
|
190
|
+
*/
|
|
191
|
+
height?: number;
|
|
150
192
|
}
|