@sogni-ai/sogni-client 4.0.0-alpha.5 → 4.0.0-alpha.51

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. package/CHANGELOG.md +357 -0
  2. package/README.md +295 -58
  3. package/dist/Account/index.d.ts +18 -16
  4. package/dist/Account/index.js +42 -21
  5. package/dist/Account/index.js.map +1 -1
  6. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.d.ts +66 -0
  7. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js +332 -0
  8. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.js.map +1 -0
  9. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.d.ts +28 -0
  10. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js +203 -0
  11. package/dist/ApiClient/WebSocketClient/BrowserWebSocketClient/index.js.map +1 -0
  12. package/dist/ApiClient/WebSocketClient/events.d.ts +12 -0
  13. package/dist/ApiClient/WebSocketClient/index.d.ts +2 -2
  14. package/dist/ApiClient/WebSocketClient/index.js +13 -3
  15. package/dist/ApiClient/WebSocketClient/index.js.map +1 -1
  16. package/dist/ApiClient/WebSocketClient/types.d.ts +13 -0
  17. package/dist/ApiClient/index.d.ts +4 -4
  18. package/dist/ApiClient/index.js +23 -4
  19. package/dist/ApiClient/index.js.map +1 -1
  20. package/dist/Projects/Job.d.ts +44 -4
  21. package/dist/Projects/Job.js +83 -16
  22. package/dist/Projects/Job.js.map +1 -1
  23. package/dist/Projects/Project.d.ts +18 -0
  24. package/dist/Projects/Project.js +38 -10
  25. package/dist/Projects/Project.js.map +1 -1
  26. package/dist/Projects/createJobRequestMessage.d.ts +2 -1
  27. package/dist/Projects/createJobRequestMessage.js +173 -14
  28. package/dist/Projects/createJobRequestMessage.js.map +1 -1
  29. package/dist/Projects/index.d.ts +114 -11
  30. package/dist/Projects/index.js +504 -47
  31. package/dist/Projects/index.js.map +1 -1
  32. package/dist/Projects/types/ComfySamplerParams.d.ts +0 -0
  33. package/dist/Projects/types/ComfySamplerParams.js +2 -0
  34. package/dist/Projects/types/ComfySamplerParams.js.map +1 -0
  35. package/dist/Projects/types/EstimationResponse.d.ts +2 -0
  36. package/dist/Projects/types/ModelOptions.d.ts +31 -0
  37. package/dist/Projects/types/ModelOptions.js +56 -0
  38. package/dist/Projects/types/ModelOptions.js.map +1 -0
  39. package/dist/Projects/types/ModelTiersRaw.d.ts +67 -0
  40. package/dist/Projects/types/ModelTiersRaw.js +15 -0
  41. package/dist/Projects/types/ModelTiersRaw.js.map +1 -0
  42. package/dist/Projects/types/events.d.ts +5 -1
  43. package/dist/Projects/types/index.d.ts +219 -42
  44. package/dist/Projects/types/index.js +8 -0
  45. package/dist/Projects/types/index.js.map +1 -1
  46. package/dist/Projects/utils/index.d.ts +20 -0
  47. package/dist/Projects/utils/index.js +91 -0
  48. package/dist/Projects/utils/index.js.map +1 -0
  49. package/dist/Projects/utils/samplers.d.ts +6 -0
  50. package/dist/Projects/utils/samplers.js +39 -0
  51. package/dist/Projects/utils/samplers.js.map +1 -0
  52. package/dist/Projects/utils/scheduler.d.ts +6 -0
  53. package/dist/Projects/utils/scheduler.js +30 -0
  54. package/dist/Projects/utils/scheduler.js.map +1 -0
  55. package/dist/index.d.ts +11 -3
  56. package/dist/index.js +8 -3
  57. package/dist/index.js.map +1 -1
  58. package/dist/lib/AuthManager/TokenAuthManager.js +0 -2
  59. package/dist/lib/AuthManager/TokenAuthManager.js.map +1 -1
  60. package/dist/lib/DataEntity.js +4 -2
  61. package/dist/lib/DataEntity.js.map +1 -1
  62. package/dist/lib/RestClient.js +15 -2
  63. package/dist/lib/RestClient.js.map +1 -1
  64. package/dist/lib/{utils.js → utils/index.js} +1 -1
  65. package/dist/lib/utils/index.js.map +1 -0
  66. package/dist/lib/validation.d.ts +31 -2
  67. package/dist/lib/validation.js +80 -13
  68. package/dist/lib/validation.js.map +1 -1
  69. package/package.json +4 -4
  70. package/src/Account/index.ts +39 -20
  71. package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/ChannelCoordinator.ts +426 -0
  72. package/src/ApiClient/WebSocketClient/BrowserWebSocketClient/index.ts +237 -0
  73. package/src/ApiClient/WebSocketClient/events.ts +14 -0
  74. package/src/ApiClient/WebSocketClient/index.ts +15 -5
  75. package/src/ApiClient/WebSocketClient/types.ts +16 -0
  76. package/src/ApiClient/index.ts +30 -8
  77. package/src/Projects/Job.ts +97 -16
  78. package/src/Projects/Project.ts +46 -13
  79. package/src/Projects/createJobRequestMessage.ts +239 -34
  80. package/src/Projects/index.ts +533 -51
  81. package/src/Projects/types/ComfySamplerParams.ts +0 -0
  82. package/src/Projects/types/EstimationResponse.ts +2 -0
  83. package/src/Projects/types/ModelOptions.ts +92 -0
  84. package/src/Projects/types/ModelTiersRaw.ts +86 -0
  85. package/src/Projects/types/events.ts +6 -0
  86. package/src/Projects/types/index.ts +253 -45
  87. package/src/Projects/utils/index.ts +90 -0
  88. package/src/Projects/utils/samplers.ts +36 -0
  89. package/src/Projects/utils/scheduler.ts +27 -0
  90. package/src/index.ts +36 -9
  91. package/src/lib/AuthManager/TokenAuthManager.ts +0 -2
  92. package/src/lib/DataEntity.ts +4 -2
  93. package/src/lib/RestClient.ts +16 -2
  94. package/src/lib/validation.ts +90 -17
  95. package/dist/Projects/types/SamplerParams.d.ts +0 -15
  96. package/dist/Projects/types/SamplerParams.js +0 -21
  97. package/dist/Projects/types/SamplerParams.js.map +0 -1
  98. package/dist/Projects/types/SchedulerParams.d.ts +0 -13
  99. package/dist/Projects/types/SchedulerParams.js +0 -19
  100. package/dist/Projects/types/SchedulerParams.js.map +0 -1
  101. package/dist/Projects/utils.d.ts +0 -2
  102. package/dist/Projects/utils.js +0 -14
  103. package/dist/Projects/utils.js.map +0 -1
  104. package/dist/lib/utils.js.map +0 -1
  105. package/src/Projects/types/SamplerParams.ts +0 -19
  106. package/src/Projects/types/SchedulerParams.ts +0 -17
  107. package/src/Projects/utils.ts +0 -12
  108. /package/dist/lib/{utils.d.ts → utils/index.d.ts} +0 -0
  109. /package/src/lib/{utils.ts → utils/index.ts} +0 -0
@@ -4,12 +4,18 @@ import {
4
4
  EnhancementStrength,
5
5
  EstimateRequest,
6
6
  ImageUrlParams,
7
+ MediaUrlParams,
8
+ CostEstimation,
7
9
  ProjectParams,
8
10
  SizePreset,
9
- SupportedModel
11
+ SupportedModel,
12
+ ImageProjectParams,
13
+ VideoProjectParams,
14
+ VideoEstimateRequest
10
15
  } from './types';
11
16
  import {
12
17
  JobErrorData,
18
+ JobETAData,
13
19
  JobProgressData,
14
20
  JobResultData,
15
21
  JobStateData,
@@ -26,13 +32,50 @@ import ErrorData from '../types/ErrorData';
26
32
  import { SupernetType } from '../ApiClient/WebSocketClient/types';
27
33
  import Cache from '../lib/Cache';
28
34
  import { enhancementDefaults } from './Job';
29
- import { getEnhacementStrength } from './utils';
35
+ import {
36
+ getEnhacementStrength,
37
+ getVideoWorkflowType,
38
+ isVideoModel,
39
+ VIDEO_WORKFLOW_ASSETS
40
+ } from './utils';
30
41
  import { TokenType } from '../types/token';
42
+ import { getMaxContextImages, validateSampler } from '../lib/validation';
43
+ import ModelTiersRaw, { isComfyImageTier, isImageTier, isVideoTier } from './types/ModelTiersRaw';
44
+ import { mapComfyImageTier, mapImageTier, mapVideoTier, ModelOptions } from './types/ModelOptions';
31
45
 
32
46
  const sizePresetCache = new Cache<SizePreset[]>(10 * 60 * 1000);
33
47
  const GARBAGE_COLLECT_TIMEOUT = 30000;
34
48
  const MODELS_REFRESH_INTERVAL = 1000 * 60 * 60 * 24; // 24 hours
35
49
 
50
/**
 * Best-effort MIME type lookup for an upload payload.
 * `File`/`Blob` objects carry a `type` string; Node.js `Buffer`s do not.
 * Returns `undefined` when the payload is not a Blob or its type is empty.
 */
function getFileContentType(file: File | Buffer | Blob): string | undefined {
  if (!(file instanceof Blob) || !('type' in file) || !file.type) {
    return undefined;
  }
  return file.type;
}
61
+
62
/**
 * Normalize an upload payload into a value `fetch` accepts as a request body.
 * When a global `Buffer` exists and the input is a Buffer, its bytes are copied into a
 * standalone ArrayBuffer and wrapped in a `Blob`. Every other input — including all inputs
 * in runtimes without `Buffer`, such as browsers — is passed through unchanged.
 */
function toFetchBody(file: File | Buffer | Blob): BodyInit {
  const hasBuffer = typeof Buffer !== 'undefined';
  if (!hasBuffer || !Buffer.isBuffer(file)) {
    return file as BodyInit;
  }
  // Copy only this Buffer's window: a Buffer may be a view into a larger shared ArrayBuffer.
  const bytes = file.buffer.slice(file.byteOffset, file.byteOffset + file.byteLength);
  return new Blob([bytes as ArrayBuffer]);
}
78
+
36
79
  function mapErrorCodes(code: string): number {
37
80
  switch (code) {
38
81
  case 'serverRestarting':
@@ -57,11 +100,32 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
57
100
  data: null,
58
101
  updatedAt: new Date(0)
59
102
  };
103
+ private _modelTiers: {
104
+ data: ModelTiersRaw;
105
+ updatedAt: Date;
106
+ } = {
107
+ data: {},
108
+ updatedAt: new Date(0)
109
+ };
60
110
 
61
111
  get availableModels() {
62
112
  return this._availableModels;
63
113
  }
64
114
 
115
+ /**
116
+ * Check if a model produces video output using the cached models list.
117
+ * Uses the `media` property from the models API when available,
118
+ * falls back to model ID prefix check if models aren't loaded yet.
119
+ */
120
+ isVideoModelId(modelId: string): boolean {
121
+ const model = this._supportedModels.data?.find((m) => m.id === modelId);
122
+ if (model) {
123
+ return model.media === 'video';
124
+ }
125
+ // Fallback to prefix check if models not loaded
126
+ return isVideoModel(modelId);
127
+ }
128
+
65
129
  constructor(config: ApiConfig) {
66
130
  super(config);
67
131
  // Listen to server events and emit them as project and job events
@@ -69,8 +133,13 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
69
133
  this.client.socket.on('swarmModels', this.handleSwarmModels.bind(this));
70
134
  this.client.socket.on('jobState', this.handleJobState.bind(this));
71
135
  this.client.socket.on('jobProgress', this.handleJobProgress.bind(this));
136
+ this.client.socket.on('jobETA', this.handleJobETA.bind(this));
72
137
  this.client.socket.on('jobError', this.handleJobError.bind(this));
73
- this.client.socket.on('jobResult', this.handleJobResult.bind(this));
138
+ this.client.socket.on('jobResult', (data: any) => {
139
+ this.handleJobResult(data).catch((err) => {
140
+ this.client.logger.error('Error in handleJobResult:', err);
141
+ });
142
+ });
74
143
  // Listen to the server disconnect event
75
144
  this.client.on('disconnected', this.handleServerDisconnected.bind(this));
76
145
  // Listen to project and job events and update project and job instances
@@ -78,6 +147,17 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
78
147
  this.on('job', this.handleJobEvent.bind(this));
79
148
  }
80
149
 
150
+ /**
151
+ * Retrieves a list of projects created and tracked by this SogniClient instance.
152
+ *
153
+ * Note: When a project is finished, it will be removed from this list after 30 seconds
154
+ *
155
+ * @return {Array} A copy of the array containing the tracked projects.
156
+ */
157
+ get trackedProjects() {
158
+ return this.projects.slice(0);
159
+ }
160
+
81
161
  private handleChangeNetwork() {
82
162
  this._availableModels = [];
83
163
  this.emit('availableModels', this._availableModels);
@@ -97,7 +177,8 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
97
177
  this._availableModels = Object.entries(data).map(([id, workerCount]) => ({
98
178
  id,
99
179
  name: modelIndex[id]?.name || id.replace(/-/g, ' '),
100
- workerCount
180
+ workerCount,
181
+ media: modelIndex[id]?.media || 'image'
101
182
  }));
102
183
  this.emit('availableModels', this._availableModels);
103
184
  }
@@ -165,20 +246,60 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
165
246
  }
166
247
  }
167
248
 
249
+ private async handleJobETA(data: JobETAData) {
250
+ this.emit('job', {
251
+ type: 'jobETA',
252
+ projectId: data.jobID,
253
+ jobId: data.imgID || '',
254
+ etaSeconds: data.etaSeconds
255
+ });
256
+ }
257
+
168
258
  private async handleJobResult(data: JobResultData) {
169
259
  const project = this.projects.find((p) => p.id === data.jobID);
170
260
  const passNSFWCheck = !data.triggeredNSFWFilter || !project || project.params.disableNSFWFilter;
171
- let downloadUrl = null;
172
- // If NSFW filter is triggered, image will be only available for download if user explicitly
173
- // disabled the filter for this project
174
- if (passNSFWCheck && !data.userCanceled) {
175
- downloadUrl = await this.downloadUrl({
176
- jobId: data.jobID,
177
- imageId: data.imgID,
178
- type: 'complete'
179
- });
261
+ let downloadUrl = data.resultUrl || null; // Use resultUrl from event if provided
262
+
263
+ // If no resultUrl provided and NSFW check passes, generate download URL
264
+ if (!downloadUrl && passNSFWCheck && !data.userCanceled) {
265
+ // Use media endpoint for video models, image endpoint for image models
266
+ const isVideo = project && this.isVideoModelId(project.params.modelId);
267
+ try {
268
+ if (isVideo) {
269
+ downloadUrl = await this.mediaDownloadUrl({
270
+ jobId: data.jobID,
271
+ id: data.imgID,
272
+ type: 'complete'
273
+ });
274
+ } else {
275
+ downloadUrl = await this.downloadUrl({
276
+ jobId: data.jobID,
277
+ imageId: data.imgID,
278
+ type: 'complete'
279
+ });
280
+ }
281
+ } catch (error: any) {
282
+ this.client.logger.error('Failed to generate download URL for job result');
283
+ this.client.logger.error(error);
284
+ }
180
285
  }
181
286
 
287
+ // Update the job directly with the result URL to prevent duplicate API calls
288
+ if (project) {
289
+ const job = project.job(data.imgID);
290
+ if (job) {
291
+ job._update({
292
+ status: data.userCanceled ? 'canceled' : 'completed',
293
+ step: data.performedStepCount,
294
+ seed: Number(data.lastSeed),
295
+ resultUrl: downloadUrl,
296
+ isNSFW: data.triggeredNSFWFilter,
297
+ userCanceled: data.userCanceled
298
+ });
299
+ }
300
+ }
301
+
302
+ // Emit job completion event with the generated download URL
182
303
  this.emit('job', {
183
304
  type: 'completed',
184
305
  projectId: data.jobID,
@@ -248,7 +369,11 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
248
369
  if (project.finished) {
249
370
  // Sync project data with the server and remove it from the list after some time
250
371
  project._syncToServer().catch((e) => {
251
- this.client.logger.error(e);
372
+ // 404 errors are expected when project is still initializing
373
+ // Only log non-404 errors to avoid confusing users
374
+ if (e.status !== 404) {
375
+ this.client.logger.error(e);
376
+ }
252
377
  });
253
378
  setTimeout(() => {
254
379
  this.projects = this.projects.filter((p) => !p.finished);
@@ -268,7 +393,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
268
393
  projectId: event.projectId,
269
394
  status: 'pending',
270
395
  step: 0,
271
- stepCount: project.params.steps
396
+ stepCount: project.params.steps ?? 0
272
397
  });
273
398
  }
274
399
  switch (event.type) {
@@ -295,7 +420,7 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
295
420
  case 'progress':
296
421
  job._update({
297
422
  status: 'processing',
298
- // Jus in case event comes out of order
423
+ // Just in case event comes out of order
299
424
  step: Math.max(event.step, job.step),
300
425
  stepCount: event.stepCount
301
426
  });
@@ -303,6 +428,24 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
303
428
  project._update({ status: 'processing' });
304
429
  }
305
430
  break;
431
+ case 'jobETA': {
432
+ // ETA updates keep the project alive (refreshes lastUpdated) and store the ETA value.
433
+ // This is critical for long-running jobs like video generation that can take several
434
+ // minutes and may not send frequent progress updates.
435
+ // We always call _keepAlive() to ensure lastUpdated is refreshed, preventing premature timeouts.
436
+ project._keepAlive();
437
+
438
+ const newEta = new Date(Date.now() + event.etaSeconds * 1000);
439
+ if (job.eta?.getTime() !== newEta?.getTime()) {
440
+ job._update({ eta: newEta });
441
+ const maxEta = project.jobs.reduce((max, j) => Math.max(max, j.eta?.getTime() || 0), 0);
442
+ const projectETA = maxEta ? new Date(maxEta) : undefined;
443
+ if (project.eta?.getTime() !== projectETA?.getTime()) {
444
+ project._update({ eta: projectETA });
445
+ }
446
+ }
447
+ break;
448
+ }
306
449
  case 'preview':
307
450
  job._update({ previewUrl: event.url });
308
451
  break;
@@ -319,6 +462,17 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
319
462
  }
320
463
  case 'error':
321
464
  job._update({ status: 'failed', error: event.error });
465
+ // Check if project should also fail when a job fails
466
+ // For video jobs (single image) or when all jobs have failed, propagate to project
467
+ const allJobsStarted = project.jobs.length >= project.params.numberOfMedia;
468
+ const allJobsFailed = allJobsStarted && project.jobs.every((j) => j.status === 'failed');
469
+ const isSingleJobProject = project.params.numberOfMedia === 1;
470
+ if (isSingleJobProject || allJobsFailed) {
471
+ project._update({
472
+ status: 'failed',
473
+ error: event.error
474
+ });
475
+ }
322
476
  break;
323
477
  }
324
478
  }
@@ -341,53 +495,97 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
341
495
  return Promise.resolve(this._availableModels);
342
496
  }
343
497
  return new Promise((resolve, reject) => {
498
+ let settled = false;
344
499
  const timeoutId = setTimeout(() => {
345
- reject(new Error('Timeout waiting for models'));
500
+ if (!settled) {
501
+ settled = true;
502
+ this.off('availableModels', handler);
503
+ reject(new Error('Timeout waiting for models'));
504
+ }
346
505
  }, timeout);
347
- this.once('availableModels', (models) => {
348
- clearTimeout(timeoutId);
349
- if (models.length) {
506
+
507
+ const handler = (models: AvailableModel[]) => {
508
+ // Only resolve when we get a non-empty models list
509
+ // Empty arrays may be emitted during disconnects/reconnects
510
+ if (models.length && !settled) {
511
+ settled = true;
512
+ clearTimeout(timeoutId);
513
+ this.off('availableModels', handler);
350
514
  resolve(models);
351
- } else {
352
- reject(new Error('No models available'));
353
515
  }
354
- });
516
+ };
517
+
518
+ this.on('availableModels', handler);
355
519
  });
356
520
  }
357
521
 
358
522
  /**
359
523
  * Send new project request to the network. Returns project instance which can be used to track
360
- * progress and get resulting images.
524
+ * progress and get resulting images or videos.
361
525
  * @param data
362
526
  */
363
527
  async create(data: ProjectParams): Promise<Project> {
364
528
  const project = new Project({ ...data }, { api: this, logger: this.client.logger });
529
+ const modelOptions = await this.getModelOptions(data.modelId);
530
+ const request = createJobRequestMessage(project.id, data, modelOptions);
531
+
532
+ switch (data.type) {
533
+ case 'image':
534
+ await this._processImageAssets(project, data);
535
+ break;
536
+ case 'video':
537
+ await this._processVideoAssets(project, data);
538
+ break;
539
+ }
540
+ await this.client.socket.send('jobRequest', request);
541
+ this.projects.push(project);
542
+ return project;
543
+ }
544
+
545
+ private async _processImageAssets(project: Project, data: ImageProjectParams) {
546
+ //Guide image
365
547
  if (data.startingImage && data.startingImage !== true) {
366
548
  await this.uploadGuideImage(project.id, data.startingImage);
367
549
  }
550
+
551
+ // ControlNet image
368
552
  if (data.controlNet?.image && data.controlNet.image !== true) {
369
553
  await this.uploadCNImage(project.id, data.controlNet.image);
370
554
  }
555
+
556
+ // Context images (Flux.2 Dev supports up to 6; Qwen Image Edit Plus supports up to 3; Flux Kontext supports up to 2)
371
557
  if (data.contextImages?.length) {
372
- if (data.contextImages.length > 2) {
558
+ const maxContextImages = getMaxContextImages(data.modelId);
559
+ if (data.contextImages.length > maxContextImages) {
373
560
  throw new ApiError(500, {
374
561
  status: 'error',
375
562
  errorCode: 0,
376
- message: `Up to 2 context images are supported`
563
+ message: `Up to ${maxContextImages} context images are supported for this model`
377
564
  });
378
565
  }
379
566
  await Promise.all(
380
567
  data.contextImages.map((image, index) => {
381
568
  if (image && image !== true) {
382
- return this.uploadContextImage(project.id, index as 0 | 1, image);
569
+ return this.uploadContextImage(project.id, index as 0 | 1 | 2 | 3 | 4 | 5, image);
383
570
  }
384
571
  })
385
572
  );
386
573
  }
387
- const request = createJobRequestMessage(project.id, data);
388
- await this.client.socket.send('jobRequest', request);
389
- this.projects.push(project);
390
- return project;
574
+ }
575
+
576
+ private async _processVideoAssets(project: Project, data: VideoProjectParams) {
577
+ if (data?.referenceImage && data.referenceImage !== true) {
578
+ await this.uploadReferenceImage(project.id, data.referenceImage);
579
+ }
580
+ if (data?.referenceImageEnd && data.referenceImageEnd !== true) {
581
+ await this.uploadReferenceImageEnd(project.id, data.referenceImageEnd);
582
+ }
583
+ if (data?.referenceAudio && data.referenceAudio !== true) {
584
+ await this.uploadReferenceAudio(project.id, data.referenceAudio);
585
+ }
586
+ if (data?.referenceVideo && data.referenceVideo !== true) {
587
+ await this.uploadReferenceVideo(project.id, data.referenceVideo);
588
+ }
391
589
  }
392
590
 
393
591
  /**
@@ -422,7 +620,6 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
422
620
  }
423
621
  // Remove project from the list to stop tracking it
424
622
  this.projects = this.projects.filter((p) => p.id !== projectId);
425
-
426
623
  // Cancel all jobs in the project
427
624
  project.jobs.forEach((job) => {
428
625
  if (!job.finished) {
@@ -438,13 +635,13 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
438
635
  private async uploadGuideImage(projectId: string, file: File | Buffer | Blob) {
439
636
  const imageId = getUUID();
440
637
  const presignedUrl = await this.uploadUrl({
441
- imageId: imageId,
638
+ imageId,
442
639
  jobId: projectId,
443
640
  type: 'startingImage'
444
641
  });
445
642
  const res = await fetch(presignedUrl, {
446
643
  method: 'PUT',
447
- body: file
644
+ body: toFetchBody(file)
448
645
  });
449
646
  if (!res.ok) {
450
647
  throw new ApiError(res.status, {
@@ -459,13 +656,13 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
459
656
  private async uploadCNImage(projectId: string, file: File | Buffer | Blob) {
460
657
  const imageId = getUUID();
461
658
  const presignedUrl = await this.uploadUrl({
462
- imageId: imageId,
659
+ imageId,
463
660
  jobId: projectId,
464
661
  type: 'cnImage'
465
662
  });
466
663
  const res = await fetch(presignedUrl, {
467
664
  method: 'PUT',
468
- body: file
665
+ body: toFetchBody(file)
469
666
  });
470
667
  if (!res.ok) {
471
668
  throw new ApiError(res.status, {
@@ -477,17 +674,22 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
477
674
  return imageId;
478
675
  }
479
676
 
480
- private async uploadContextImage(projectId: string, index: 0 | 1, file: File | Buffer | Blob) {
677
+ private async uploadContextImage(
678
+ projectId: string,
679
+ index: 0 | 1 | 2 | 3 | 4 | 5,
680
+ file: File | Buffer | Blob
681
+ ) {
481
682
  const imageId = getUUID();
482
- const imageIndex = (index + 1) as 1 | 2;
683
+ const imageIndex = (index + 1) as 1 | 2 | 3 | 4 | 5 | 6;
483
684
  const presignedUrl = await this.uploadUrl({
484
685
  imageId,
485
686
  jobId: projectId,
486
687
  type: `contextImage${imageIndex}`
487
688
  });
689
+ const body = toFetchBody(file);
488
690
  const res = await fetch(presignedUrl, {
489
691
  method: 'PUT',
490
- body: file
692
+ body
491
693
  });
492
694
  if (!res.ok) {
493
695
  throw new ApiError(res.status, {
@@ -499,8 +701,124 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
499
701
  return imageId;
500
702
  }
501
703
 
704
+ // ============================================
705
+ // VIDEO WORKFLOW UPLOADS (WAN 2.2)
706
+ // ============================================
707
+
708
+ /**
709
+ * Upload reference image for WAN video workflows
710
+ * @internal
711
+ */
712
+ private async uploadReferenceImage(projectId: string, file: File | Buffer | Blob) {
713
+ const imageId = getUUID();
714
+ const presignedUrl = await this.uploadUrl({
715
+ imageId,
716
+ jobId: projectId,
717
+ type: 'referenceImage'
718
+ });
719
+ const res = await fetch(presignedUrl, {
720
+ method: 'PUT',
721
+ body: toFetchBody(file)
722
+ });
723
+ if (!res.ok) {
724
+ throw new ApiError(res.status, {
725
+ status: 'error',
726
+ errorCode: 0,
727
+ message: 'Failed to upload reference image'
728
+ });
729
+ }
730
+ return imageId;
731
+ }
732
+
733
+ /**
734
+ * Upload reference image end for i2v interpolation
735
+ * @internal
736
+ */
737
+ private async uploadReferenceImageEnd(projectId: string, file: File | Buffer | Blob) {
738
+ const imageId = getUUID();
739
+ const presignedUrl = await this.uploadUrl({
740
+ imageId,
741
+ jobId: projectId,
742
+ type: 'referenceImageEnd'
743
+ });
744
+ const res = await fetch(presignedUrl, {
745
+ method: 'PUT',
746
+ body: toFetchBody(file)
747
+ });
748
+ if (!res.ok) {
749
+ throw new ApiError(res.status, {
750
+ status: 'error',
751
+ errorCode: 0,
752
+ message: 'Failed to upload reference image end'
753
+ });
754
+ }
755
+ return imageId;
756
+ }
757
+
758
+ /**
759
+ * Upload reference audio for s2v workflows
760
+ * Supported formats: mp3, m4a, wav
761
+ * @internal
762
+ */
763
+ private async uploadReferenceAudio(projectId: string, file: File | Buffer | Blob) {
764
+ const contentType = getFileContentType(file);
765
+ const presignedUrl = await this.mediaUploadUrl({
766
+ jobId: projectId,
767
+ type: 'referenceAudio'
768
+ });
769
+ const headers: Record<string, string> = {};
770
+ if (contentType) {
771
+ headers['Content-Type'] = contentType;
772
+ }
773
+ const res = await fetch(presignedUrl, {
774
+ method: 'PUT',
775
+ body: toFetchBody(file),
776
+ headers
777
+ });
778
+ if (!res.ok) {
779
+ throw new ApiError(res.status, {
780
+ status: 'error',
781
+ errorCode: 0,
782
+ message: 'Failed to upload reference audio'
783
+ });
784
+ }
785
+ }
786
+
502
787
  /**
503
- * Estimate project cost
788
+ * Upload reference video for animate workflows
789
+ * Supported formats: mp4, mov
790
+ * @internal
791
+ */
792
+ private async uploadReferenceVideo(projectId: string, file: File | Buffer | Blob) {
793
+ const contentType = getFileContentType(file);
794
+ const presignedUrl = await this.mediaUploadUrl({
795
+ jobId: projectId,
796
+ type: 'referenceVideo'
797
+ });
798
+ const headers: Record<string, string> = {};
799
+ if (contentType) {
800
+ headers['Content-Type'] = contentType;
801
+ }
802
+ const res = await fetch(presignedUrl, {
803
+ method: 'PUT',
804
+ body: toFetchBody(file),
805
+ headers
806
+ });
807
+ if (!res.ok) {
808
+ throw new ApiError(res.status, {
809
+ status: 'error',
810
+ errorCode: 0,
811
+ message: 'Failed to upload reference video'
812
+ });
813
+ }
814
+ }
815
+
816
+ // ============================================
817
+ // COST ESTIMATION
818
+ // ============================================
819
+
820
+ /**
821
+ * Estimate image project cost
504
822
  */
505
823
  async estimateCost({
506
824
  network,
@@ -515,12 +833,13 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
515
833
  height,
516
834
  sizePreset,
517
835
  guidance,
518
- scheduler,
836
+ sampler,
519
837
  contextImages
520
- }: EstimateRequest) {
838
+ }: EstimateRequest): Promise<CostEstimation> {
521
839
  let apiVersion = 2;
840
+ const modelOptions = await this.getModelOptions(model);
522
841
  const pathParams = [
523
- tokenType || 'sogni',
842
+ tokenType || 'spark',
524
843
  network,
525
844
  model,
526
845
  imageCount,
@@ -541,10 +860,10 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
541
860
  } else {
542
861
  pathParams.push(0, 0);
543
862
  }
544
- if (scheduler) {
863
+ if (sampler) {
545
864
  apiVersion = 3;
546
865
  pathParams.push(guidance || 0);
547
- pathParams.push(scheduler || '');
866
+ pathParams.push(validateSampler(sampler, modelOptions)!);
548
867
  pathParams.push(contextImages || 0);
549
868
  }
550
869
  const r = await this.client.socket.get<EstimationResponse>(
@@ -552,11 +871,18 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
552
871
  );
553
872
  return {
554
873
  token: r.quote.project.costInToken,
555
- usd: r.quote.project.costInUSD
874
+ usd: r.quote.project.costInUSD,
875
+ spark: r.quote.project.costInSpark,
876
+ sogni: r.quote.project.costInSogni
556
877
  };
557
878
  }
558
879
 
559
- async estimateEnhancementCost(strength: EnhancementStrength, tokenType: TokenType = 'sogni') {
880
+ /**
881
+ * Estimate image enhancement cost
882
+ * @param strength
883
+ * @param tokenType
884
+ */
885
+ async estimateEnhancementCost(strength: EnhancementStrength, tokenType: TokenType = 'spark') {
560
886
  return this.estimateCost({
561
887
  network: enhancementDefaults.network,
562
888
  tokenType,
@@ -569,10 +895,53 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
569
895
  });
570
896
  }
571
897
 
898
+ /**
899
+ * Estimates the cost of generating a video based on the provided parameters.
900
+ *
901
+ * @param {VideoEstimateRequest} params - The parameters required for video cost estimation. This includes:
902
+ * - tokenType: The type of token to be used for generation.
903
+ * - model: The model to be used for video generation.
904
+ * - width: The width of the video in pixels.
905
+ * - height: The height of the video in pixels.
906
+ * - frames: The total number of frames in the video.
907
+ * - fps: The frames per second for the video.
908
+ * - steps: Number of steps.
909
+ * @return {Promise<Object>} Returns an object containing the estimated costs for the video in different units:
910
+ * - token: Cost in tokens.
911
+ * - usd: Cost in USD.
912
+ * - spark: Cost in Spark.
913
+ * - sogni: Cost in Sogni.
914
+ */
915
+ async estimateVideoCost(params: VideoEstimateRequest) {
916
+ const pathParams = [
917
+ params.tokenType,
918
+ params.model,
919
+ params.width,
920
+ params.height,
921
+ params.frames ? params.frames : params.duration * 16 + 1,
922
+ params.fps,
923
+ params.steps,
924
+ params.numberOfMedia
925
+ ];
926
+ const path = pathParams.map((p) => encodeURIComponent(p)).join('/');
927
+ const r = await this.client.socket.get<EstimationResponse>(
928
+ `/api/v1/job-video/estimate/${path}`
929
+ );
930
+ return {
931
+ token: r.quote.project.costInToken,
932
+ usd: r.quote.project.costInUSD,
933
+ spark: r.quote.project.costInSpark,
934
+ sogni: r.quote.project.costInSogni
935
+ };
936
+ }
937
+
938
+ // ============================================
939
+ // URL HELPERS
940
+ // ============================================
941
+
572
942
  /**
573
943
  * Get upload URL for image
574
944
  * @internal
575
- * @param params
576
945
  */
577
946
  async uploadUrl(params: ImageUrlParams) {
578
947
  const r = await this.client.rest.get<ApiResponse<{ uploadUrl: string }>>(
@@ -585,16 +954,49 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
585
954
  /**
586
955
  * Get download URL for image
587
956
  * @internal
588
- * @param params
589
957
  */
590
958
  async downloadUrl(params: ImageUrlParams) {
591
959
  const r = await this.client.rest.get<ApiResponse<{ downloadUrl: string }>>(
592
960
  `/v1/image/downloadUrl`,
593
961
  params
594
962
  );
963
+ if (!r?.data?.downloadUrl) {
964
+ throw new Error(`API returned no downloadUrl: ${JSON.stringify(r)}`);
965
+ }
595
966
  return r.data.downloadUrl;
596
967
  }
597
968
 
969
/**
 * Get upload URL for media (video/audio)
 * @internal
 * @param params - Media URL parameters, forwarded unchanged to `GET /v1/media/uploadUrl`.
 * @returns The upload URL returned by the API.
 * @throws {Error} If the API response does not contain an `uploadUrl`.
 */
async mediaUploadUrl(params: MediaUrlParams) {
  const r = await this.client.rest.get<ApiResponse<{ uploadUrl: string }>>(
    `/v1/media/uploadUrl`,
    params
  );
  // Fail loudly, matching mediaDownloadUrl/downloadUrl, instead of returning undefined.
  // An undefined URL would otherwise only surface later as a confusing upload failure.
  if (!r?.data?.uploadUrl) {
    throw new Error(`API returned no uploadUrl: ${JSON.stringify(r)}`);
  }
  return r.data.uploadUrl;
}
980
+
981
/**
 * Resolves a download URL for a media file (video/audio) via `GET /v1/media/downloadUrl`.
 * @internal
 * @param params - Media URL parameters, forwarded unchanged as the request query.
 * @returns The download URL reported by the API.
 * @throws {Error} When the API response carries no `downloadUrl`.
 */
async mediaDownloadUrl(params: MediaUrlParams) {
  const response = await this.client.rest.get<ApiResponse<{ downloadUrl: string }>>(
    `/v1/media/downloadUrl`,
    params
  );
  const url = response?.data?.downloadUrl;
  if (url) {
    return url;
  }
  throw new Error(`API returned no downloadUrl: ${JSON.stringify(response)}`);
}
995
+
996
+ // ============================================
997
+ // MODEL/PRESET HELPERS
998
+ // ============================================
999
+
598
1000
  async getSupportedModels(forceRefresh = false) {
599
1001
  if (
600
1002
  this._supportedModels.data &&
@@ -608,12 +1010,25 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
608
1010
  return models;
609
1011
  }
610
1012
 
1013
/**
 * Returns model tier definitions from `/api/v2/models/tiers`.
 * A cached copy is reused while it is younger than `MODELS_REFRESH_INTERVAL`, unless
 * `forceRefresh` is set.
 * @param forceRefresh - Bypass the cache and always fetch from the server.
 */
private async _getModelTiers(forceRefresh = false) {
  const cache = this._modelTiers;
  if (!forceRefresh && cache.data) {
    const ageMs = Date.now() - cache.updatedAt.getTime();
    if (ageMs < MODELS_REFRESH_INTERVAL) {
      return cache.data;
    }
  }
  const freshTiers = await this.client.socket.get<ModelTiersRaw>(`/api/v2/models/tiers`);
  this._modelTiers = { data: freshTiers, updatedAt: new Date() };
  return freshTiers;
}
1025
+
611
1026
  /**
612
1027
  * Get supported size presets for the model and network. Size presets are cached for 10 minutes.
613
1028
  *
614
1029
  * @example
615
1030
  * ```ts
616
- * const presets = await client.projects.getSizePresets('fast', 'flux1-schnell-fp8');
1031
+ * const presets = await sogni.projects.getSizePresets('fast', 'flux1-schnell-fp8');
617
1032
  * console.log(presets);
618
1033
  * ```
619
1034
  *
@@ -642,6 +1057,49 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
642
1057
  return data;
643
1058
  }
644
1059
 
1060
/**
 * Retrieves the video asset configuration for a given video model identifier.
 * Validates that the provided model ID corresponds to a video model and, if so, returns the
 * asset rules for the model's workflow type.
 *
 * @example Returned object for a model that implements image to video workflow:
 * ```json
 * {
 *   "workflowType": "i2v",
 *   "assets": {
 *     "referenceImage": "required",
 *     "referenceImageEnd": "optional",
 *     "referenceAudio": "forbidden",
 *     "referenceVideo": "forbidden"
 *   }
 * }
 * ```
 *
 * @param modelId - The identifier of the video model to retrieve the configuration for.
 * @returns An object with `workflowType` and `assets`. `assets` maps each asset field to
 * `required`, `forbidden` or `optional`. If no workflow type can be determined for the model,
 * the method resolves to `{ workflowType: null }` (with no `assets` key), not to `null`.
 * @throws {ApiError} With status 400 if the provided model ID is not a video model.
 */
async getVideoAssetConfig(modelId: string) {
  if (!this.isVideoModelId(modelId)) {
    throw new ApiError(400, {
      status: 'error',
      errorCode: 0,
      message: `Model ${modelId} is not a video model`
    });
  }
  const workflow = getVideoWorkflowType(modelId);
  if (!workflow) {
    // No workflow rules known for this model: signal it with a null workflowType.
    return {
      workflowType: null
    };
  }
  return {
    workflowType: workflow,
    assets: VIDEO_WORKFLOW_ASSETS[workflow]
  };
}
1102
+
645
1103
  /**
646
1104
  * Get available models and their worker counts. Normally, you would get list once you connect
647
1105
  * to the server, but you can also call this method to get the list of available models manually.
@@ -658,10 +1116,34 @@ class ProjectsApi extends ApiGroup<ProjectApiEvents> {
658
1116
  return {
659
1117
  id: model?.id || sid,
660
1118
  name: model?.name || sid.replace(/-/g, ' '),
661
- workerCount
1119
+ workerCount,
1120
+ media: model?.media || 'image'
662
1121
  };
663
1122
  });
664
1123
  }
1124
+
1125
/**
 * Resolves the generation options for a supported model by looking up the model's tier.
 *
 * The supported-model list and the tier definitions come from helper methods. Those helpers
 * may serve cached data (NOTE(review): confirm that getSupportedModels caches), so repeated
 * calls are expected to be cheap.
 *
 * @param modelId - ID of the model, as listed by `getSupportedModels()`.
 * @returns Model options mapped from the model's image, video or Comfy image tier.
 * @throws {Error} If the model is not supported, its tier cannot be found, or the tier has an
 * unrecognized shape.
 */
async getModelOptions(modelId: string): Promise<ModelOptions> {
  // The two lookups are independent of each other, so fetch them concurrently.
  const [models, tiers] = await Promise.all([this.getSupportedModels(), this._getModelTiers()]);
  const model = models.find((m) => m.id === modelId);
  if (!model) {
    throw new Error(`Model ${modelId} not supported`);
  }
  const tier = tiers[model.tier];
  if (!tier) {
    throw new Error(`Unable to find model tier "${model.tier}" please contact support`);
  }
  if (isImageTier(tier)) {
    return mapImageTier(tier);
  }
  if (isVideoTier(tier)) {
    return mapVideoTier(tier);
  }
  if (isComfyImageTier(tier)) {
    return mapComfyImageTier(tier);
  }
  throw new Error(`Unsupported model tier "${model.tier}"`);
}
665
1147
  }
666
1148
 
667
1149
  export default ProjectsApi;