@elaraai/east-py-datascience 0.0.2-beta.31 → 0.0.2-beta.33
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -13,6 +13,33 @@
|
|
|
13
13
|
*/
|
|
14
14
|
import { StructType, VariantType, OptionType, IntegerType, FloatType, BlobType, ArrayType, NullType, BooleanType, FunctionType, StringType } from "@elaraai/east";
|
|
15
15
|
export { VectorType, MatrixType } from "../types.js";
|
|
16
|
+
/**
|
|
17
|
+
* Return embedding mode for Decision Transformer.
|
|
18
|
+
*/
|
|
19
|
+
export declare const ReturnEmbeddingType: VariantType<{
|
|
20
|
+
/** Single return value for entire sequence */
|
|
21
|
+
global: NullType;
|
|
22
|
+
/** Return-to-go at each timestep */
|
|
23
|
+
per_timestep: NullType;
|
|
24
|
+
}>;
|
|
25
|
+
/**
|
|
26
|
+
* Per-head output configuration for multi_head_mixed.
|
|
27
|
+
*/
|
|
28
|
+
export declare const HeadConfigType: StructType<{
|
|
29
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
30
|
+
head_type: VariantType<{
|
|
31
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
32
|
+
binary: NullType;
|
|
33
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
34
|
+
multiclass: StructType<{
|
|
35
|
+
n_classes: IntegerType;
|
|
36
|
+
}>;
|
|
37
|
+
}>;
|
|
38
|
+
/** Optional class weights for this head */
|
|
39
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
40
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
41
|
+
conditional_on: OptionType<IntegerType>;
|
|
42
|
+
}>;
|
|
16
43
|
/**
|
|
17
44
|
* Lightning output mode - determines loss function and output activation.
|
|
18
45
|
*/
|
|
@@ -40,6 +67,31 @@ export declare const LightningOutputType: VariantType<{
|
|
|
40
67
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
41
68
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
42
69
|
}>;
|
|
70
|
+
/**
|
|
71
|
+
* Mixed output types per head.
|
|
72
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
73
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
74
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
75
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
76
|
+
*/
|
|
77
|
+
multi_head_mixed: StructType<{
|
|
78
|
+
/** Array of head configurations */
|
|
79
|
+
heads: ArrayType<StructType<{
|
|
80
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
81
|
+
head_type: VariantType<{
|
|
82
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
83
|
+
binary: NullType;
|
|
84
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
85
|
+
multiclass: StructType<{
|
|
86
|
+
n_classes: IntegerType;
|
|
87
|
+
}>;
|
|
88
|
+
}>;
|
|
89
|
+
/** Optional class weights for this head */
|
|
90
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
91
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
92
|
+
conditional_on: OptionType<IntegerType>;
|
|
93
|
+
}>>;
|
|
94
|
+
}>;
|
|
43
95
|
}>;
|
|
44
96
|
/**
|
|
45
97
|
* Cell type for sequential architectures.
|
|
@@ -122,6 +174,36 @@ export declare const LightningArchitectureType: VariantType<{
|
|
|
122
174
|
/** Optional condition dimension for conditional generation */
|
|
123
175
|
condition_dim: OptionType<IntegerType>;
|
|
124
176
|
}>;
|
|
177
|
+
/**
|
|
178
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
179
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
180
|
+
* Predicts actions conditioned on desired return and state history.
|
|
181
|
+
*/
|
|
182
|
+
decision_transformer: StructType<{
|
|
183
|
+
/** Sequence length (timesteps) */
|
|
184
|
+
sequence_length: IntegerType;
|
|
185
|
+
/** State dimension per timestep */
|
|
186
|
+
state_dim: IntegerType;
|
|
187
|
+
/** Action dimension per timestep */
|
|
188
|
+
action_dim: IntegerType;
|
|
189
|
+
/** Model dimension (transformer hidden size) */
|
|
190
|
+
d_model: IntegerType;
|
|
191
|
+
/** Number of attention heads */
|
|
192
|
+
n_attention_heads: IntegerType;
|
|
193
|
+
/** Number of transformer layers */
|
|
194
|
+
n_layers: IntegerType;
|
|
195
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
196
|
+
d_ff: OptionType<IntegerType>;
|
|
197
|
+
/** Dropout rate */
|
|
198
|
+
dropout: OptionType<FloatType>;
|
|
199
|
+
/** Whether return is per-timestep or global */
|
|
200
|
+
return_embedding: VariantType<{
|
|
201
|
+
/** Single return value for entire sequence */
|
|
202
|
+
global: NullType;
|
|
203
|
+
/** Return-to-go at each timestep */
|
|
204
|
+
per_timestep: NullType;
|
|
205
|
+
}>;
|
|
206
|
+
}>;
|
|
125
207
|
}>;
|
|
126
208
|
/**
|
|
127
209
|
* Epoch callback function type: (epoch, train_loss, val_loss) -> void
|
|
@@ -203,6 +285,36 @@ export declare const LightningConfigType: StructType<{
|
|
|
203
285
|
/** Optional condition dimension for conditional generation */
|
|
204
286
|
condition_dim: OptionType<IntegerType>;
|
|
205
287
|
}>;
|
|
288
|
+
/**
|
|
289
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
290
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
291
|
+
* Predicts actions conditioned on desired return and state history.
|
|
292
|
+
*/
|
|
293
|
+
decision_transformer: StructType<{
|
|
294
|
+
/** Sequence length (timesteps) */
|
|
295
|
+
sequence_length: IntegerType;
|
|
296
|
+
/** State dimension per timestep */
|
|
297
|
+
state_dim: IntegerType;
|
|
298
|
+
/** Action dimension per timestep */
|
|
299
|
+
action_dim: IntegerType;
|
|
300
|
+
/** Model dimension (transformer hidden size) */
|
|
301
|
+
d_model: IntegerType;
|
|
302
|
+
/** Number of attention heads */
|
|
303
|
+
n_attention_heads: IntegerType;
|
|
304
|
+
/** Number of transformer layers */
|
|
305
|
+
n_layers: IntegerType;
|
|
306
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
307
|
+
d_ff: OptionType<IntegerType>;
|
|
308
|
+
/** Dropout rate */
|
|
309
|
+
dropout: OptionType<FloatType>;
|
|
310
|
+
/** Whether return is per-timestep or global */
|
|
311
|
+
return_embedding: VariantType<{
|
|
312
|
+
/** Single return value for entire sequence */
|
|
313
|
+
global: NullType;
|
|
314
|
+
/** Return-to-go at each timestep */
|
|
315
|
+
per_timestep: NullType;
|
|
316
|
+
}>;
|
|
317
|
+
}>;
|
|
206
318
|
}>;
|
|
207
319
|
/** Output mode (determines loss function) */
|
|
208
320
|
output: VariantType<{
|
|
@@ -229,6 +341,31 @@ export declare const LightningConfigType: StructType<{
|
|
|
229
341
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
230
342
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
231
343
|
}>;
|
|
344
|
+
/**
|
|
345
|
+
* Mixed output types per head.
|
|
346
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
347
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
348
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
349
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
350
|
+
*/
|
|
351
|
+
multi_head_mixed: StructType<{
|
|
352
|
+
/** Array of head configurations */
|
|
353
|
+
heads: ArrayType<StructType<{
|
|
354
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
355
|
+
head_type: VariantType<{
|
|
356
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
357
|
+
binary: NullType;
|
|
358
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
359
|
+
multiclass: StructType<{
|
|
360
|
+
n_classes: IntegerType;
|
|
361
|
+
}>;
|
|
362
|
+
}>;
|
|
363
|
+
/** Optional class weights for this head */
|
|
364
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
365
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
366
|
+
conditional_on: OptionType<IntegerType>;
|
|
367
|
+
}>>;
|
|
368
|
+
}>;
|
|
232
369
|
}>;
|
|
233
370
|
/** Learning rate (default: 1e-3) */
|
|
234
371
|
learning_rate: OptionType<FloatType>;
|
|
@@ -401,6 +538,36 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
|
|
|
401
538
|
/** Optional condition dimension for conditional generation */
|
|
402
539
|
condition_dim: OptionType<IntegerType>;
|
|
403
540
|
}>;
|
|
541
|
+
/**
|
|
542
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
543
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
544
|
+
* Predicts actions conditioned on desired return and state history.
|
|
545
|
+
*/
|
|
546
|
+
decision_transformer: StructType<{
|
|
547
|
+
/** Sequence length (timesteps) */
|
|
548
|
+
sequence_length: IntegerType;
|
|
549
|
+
/** State dimension per timestep */
|
|
550
|
+
state_dim: IntegerType;
|
|
551
|
+
/** Action dimension per timestep */
|
|
552
|
+
action_dim: IntegerType;
|
|
553
|
+
/** Model dimension (transformer hidden size) */
|
|
554
|
+
d_model: IntegerType;
|
|
555
|
+
/** Number of attention heads */
|
|
556
|
+
n_attention_heads: IntegerType;
|
|
557
|
+
/** Number of transformer layers */
|
|
558
|
+
n_layers: IntegerType;
|
|
559
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
560
|
+
d_ff: OptionType<IntegerType>;
|
|
561
|
+
/** Dropout rate */
|
|
562
|
+
dropout: OptionType<FloatType>;
|
|
563
|
+
/** Whether return is per-timestep or global */
|
|
564
|
+
return_embedding: VariantType<{
|
|
565
|
+
/** Single return value for entire sequence */
|
|
566
|
+
global: NullType;
|
|
567
|
+
/** Return-to-go at each timestep */
|
|
568
|
+
per_timestep: NullType;
|
|
569
|
+
}>;
|
|
570
|
+
}>;
|
|
404
571
|
}>;
|
|
405
572
|
/** Output mode (determines loss function) */
|
|
406
573
|
output: VariantType<{
|
|
@@ -427,6 +594,31 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
|
|
|
427
594
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
428
595
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
429
596
|
}>;
|
|
597
|
+
/**
|
|
598
|
+
* Mixed output types per head.
|
|
599
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
600
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
601
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
602
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
603
|
+
*/
|
|
604
|
+
multi_head_mixed: StructType<{
|
|
605
|
+
/** Array of head configurations */
|
|
606
|
+
heads: ArrayType<StructType<{
|
|
607
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
608
|
+
head_type: VariantType<{
|
|
609
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
610
|
+
binary: NullType;
|
|
611
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
612
|
+
multiclass: StructType<{
|
|
613
|
+
n_classes: IntegerType;
|
|
614
|
+
}>;
|
|
615
|
+
}>;
|
|
616
|
+
/** Optional class weights for this head */
|
|
617
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
618
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
619
|
+
conditional_on: OptionType<IntegerType>;
|
|
620
|
+
}>>;
|
|
621
|
+
}>;
|
|
430
622
|
}>;
|
|
431
623
|
/** Learning rate (default: 1e-3) */
|
|
432
624
|
learning_rate: OptionType<FloatType>;
|
|
@@ -577,35 +769,97 @@ export declare const lightning_decode_conditional: import("@elaraai/east").Platf
|
|
|
577
769
|
}>;
|
|
578
770
|
}>, ArrayType<ArrayType<FloatType>>, ArrayType<ArrayType<FloatType>>], ArrayType<ArrayType<FloatType>>>;
|
|
579
771
|
/**
|
|
580
|
-
*
|
|
772
|
+
* Configuration for autoregressive sequence generation.
|
|
581
773
|
*/
|
|
582
|
-
export declare const
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
774
|
+
export declare const LightningGenerateConfigType: StructType<{
|
|
775
|
+
/** Number of steps to generate */
|
|
776
|
+
n_steps: IntegerType;
|
|
777
|
+
/** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
|
|
778
|
+
temperature: FloatType;
|
|
779
|
+
/** If true, return probabilities. If false, return samples. */
|
|
780
|
+
return_probs: BooleanType;
|
|
781
|
+
}>;
|
|
782
|
+
/**
|
|
783
|
+
* Generate sequence autoregressively from a sequential model.
|
|
784
|
+
*
|
|
785
|
+
* Shapes:
|
|
786
|
+
* - prefix: (n_prefix_steps, n_channels) - partial history to continue from, can be empty []
|
|
787
|
+
* - condition: (1, condition_dim) - conditioning features, or none
|
|
788
|
+
* - returns: (n_steps, n_channels) - generated timesteps only (not including prefix)
|
|
789
|
+
*
|
|
790
|
+
* @param model - Trained sequential model blob
|
|
791
|
+
* @param prefix - Partial history to continue from
|
|
792
|
+
* @param condition - Optional conditioning features
|
|
793
|
+
* @param config - Generation configuration
|
|
794
|
+
* @returns Generated sequence matrix
|
|
795
|
+
*/
|
|
796
|
+
export declare const lightning_generate_sequence: import("@elaraai/east").PlatformDefinition<[VariantType<{
|
|
797
|
+
lightning: StructType<{
|
|
798
|
+
/** Serialized model data (state_dict + hparams) */
|
|
799
|
+
data: BlobType;
|
|
800
|
+
/** Input dimension */
|
|
801
|
+
n_features: IntegerType;
|
|
802
|
+
/** Output dimension */
|
|
803
|
+
output_dim: IntegerType;
|
|
804
|
+
/** Architecture type */
|
|
805
|
+
architecture_type: StringType;
|
|
806
|
+
/** Output type */
|
|
807
|
+
output_type: StringType;
|
|
808
|
+
/** Latent dimension (autoencoder only) */
|
|
809
|
+
latent_dim: OptionType<IntegerType>;
|
|
607
810
|
}>;
|
|
608
|
-
|
|
811
|
+
}>, ArrayType<ArrayType<FloatType>>, OptionType<ArrayType<ArrayType<FloatType>>>, StructType<{
|
|
812
|
+
/** Number of steps to generate */
|
|
813
|
+
n_steps: IntegerType;
|
|
814
|
+
/** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
|
|
815
|
+
temperature: FloatType;
|
|
816
|
+
/** If true, return probabilities. If false, return samples. */
|
|
817
|
+
return_probs: BooleanType;
|
|
818
|
+
}>], ArrayType<ArrayType<FloatType>>>;
|
|
819
|
+
/**
|
|
820
|
+
* Configuration for Decision Transformer trajectory generation.
|
|
821
|
+
*/
|
|
822
|
+
export declare const TrajectoryGenerateConfigType: StructType<{
|
|
823
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
824
|
+
temperature: FloatType;
|
|
825
|
+
/** Whether to return probabilities or samples */
|
|
826
|
+
return_probs: BooleanType;
|
|
827
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
828
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
829
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
830
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
831
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
832
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
833
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
834
|
+
head_type: VariantType<{
|
|
835
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
836
|
+
binary: NullType;
|
|
837
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
838
|
+
multiclass: StructType<{
|
|
839
|
+
n_classes: IntegerType;
|
|
840
|
+
}>;
|
|
841
|
+
}>;
|
|
842
|
+
/** Optional class weights for this head */
|
|
843
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
844
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
845
|
+
conditional_on: OptionType<IntegerType>;
|
|
846
|
+
}>>>;
|
|
847
|
+
}>;
|
|
848
|
+
/**
|
|
849
|
+
* Train with trajectory data for return-conditioned sequence generation.
|
|
850
|
+
*
|
|
851
|
+
* Use with decision_transformer architecture.
|
|
852
|
+
*
|
|
853
|
+
* @param returns - Return per sample (n_samples,) - actual outcome achieved
|
|
854
|
+
* @param states - State matrices: n_samples × (seq_len, state_dim)
|
|
855
|
+
* @param actions - Action matrices: n_samples × (seq_len, action_dim)
|
|
856
|
+
* @param masks - Temporal masks: n_samples × (seq_len,) - valid timesteps
|
|
857
|
+
* @param config - Training configuration with decision_transformer architecture
|
|
858
|
+
* @returns Training result with model blob and metrics
|
|
859
|
+
*/
|
|
860
|
+
export declare const lightning_train_trajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
|
|
861
|
+
/** Model architecture */
|
|
862
|
+
architecture: VariantType<{
|
|
609
863
|
/** Simple MLP: input → hidden → output */
|
|
610
864
|
mlp: StructType<{
|
|
611
865
|
/** Hidden layer sizes */
|
|
@@ -676,111 +930,498 @@ export declare const LightningTypes: {
|
|
|
676
930
|
/** Optional condition dimension for conditional generation */
|
|
677
931
|
condition_dim: OptionType<IntegerType>;
|
|
678
932
|
}>;
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
/**
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
sequence_length: IntegerType;
|
|
708
|
-
/** Conv layer channel sizes */
|
|
709
|
-
conv_channels: ArrayType<IntegerType>;
|
|
710
|
-
/** Kernel size for convolutions (must be odd) */
|
|
711
|
-
kernel_size: IntegerType;
|
|
712
|
-
/** Latent dimension after flattening */
|
|
713
|
-
latent_dim: IntegerType;
|
|
714
|
-
/** Optional condition dimension for conditional generation */
|
|
715
|
-
condition_dim: OptionType<IntegerType>;
|
|
716
|
-
}>;
|
|
717
|
-
/** Sequential: LSTM/GRU autoencoder for long-range dependencies */
|
|
718
|
-
sequential: StructType<{
|
|
719
|
-
/** Number of channels (e.g., additive types) */
|
|
720
|
-
n_channels: IntegerType;
|
|
721
|
-
/** Sequence length (e.g., days) */
|
|
722
|
-
sequence_length: IntegerType;
|
|
723
|
-
/** RNN hidden size */
|
|
724
|
-
hidden_size: IntegerType;
|
|
725
|
-
/** Number of RNN layers */
|
|
726
|
-
n_layers: IntegerType;
|
|
727
|
-
/** Cell type: lstm or gru */
|
|
728
|
-
cell_type: VariantType<{
|
|
729
|
-
lstm: NullType;
|
|
730
|
-
gru: NullType;
|
|
731
|
-
}>;
|
|
732
|
-
/** Latent dimension (from final hidden state) */
|
|
733
|
-
latent_dim: IntegerType;
|
|
734
|
-
/** Bidirectional encoder (decoder is always unidirectional) */
|
|
735
|
-
bidirectional: BooleanType;
|
|
736
|
-
/** Optional condition dimension for conditional generation */
|
|
737
|
-
condition_dim: OptionType<IntegerType>;
|
|
738
|
-
}>;
|
|
739
|
-
/** Transformer: attention-based autoencoder for complex patterns */
|
|
740
|
-
transformer: StructType<{
|
|
741
|
-
/** Number of channels (e.g., additive types) */
|
|
742
|
-
n_channels: IntegerType;
|
|
743
|
-
/** Sequence length (e.g., days) */
|
|
744
|
-
sequence_length: IntegerType;
|
|
745
|
-
/** Model dimension */
|
|
746
|
-
d_model: IntegerType;
|
|
747
|
-
/** Number of attention heads (must divide d_model evenly) */
|
|
748
|
-
n_attention_heads: IntegerType;
|
|
749
|
-
/** Number of transformer layers */
|
|
750
|
-
n_layers: IntegerType;
|
|
751
|
-
/** Feedforward dimension (default: 4 * d_model) */
|
|
752
|
-
d_ff: OptionType<IntegerType>;
|
|
753
|
-
/** Latent dimension (mean pooled output) */
|
|
754
|
-
latent_dim: IntegerType;
|
|
755
|
-
/** Optional condition dimension for conditional generation */
|
|
756
|
-
condition_dim: OptionType<IntegerType>;
|
|
933
|
+
/**
|
|
934
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
935
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
936
|
+
* Predicts actions conditioned on desired return and state history.
|
|
937
|
+
*/
|
|
938
|
+
decision_transformer: StructType<{
|
|
939
|
+
/** Sequence length (timesteps) */
|
|
940
|
+
sequence_length: IntegerType;
|
|
941
|
+
/** State dimension per timestep */
|
|
942
|
+
state_dim: IntegerType;
|
|
943
|
+
/** Action dimension per timestep */
|
|
944
|
+
action_dim: IntegerType;
|
|
945
|
+
/** Model dimension (transformer hidden size) */
|
|
946
|
+
d_model: IntegerType;
|
|
947
|
+
/** Number of attention heads */
|
|
948
|
+
n_attention_heads: IntegerType;
|
|
949
|
+
/** Number of transformer layers */
|
|
950
|
+
n_layers: IntegerType;
|
|
951
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
952
|
+
d_ff: OptionType<IntegerType>;
|
|
953
|
+
/** Dropout rate */
|
|
954
|
+
dropout: OptionType<FloatType>;
|
|
955
|
+
/** Whether return is per-timestep or global */
|
|
956
|
+
return_embedding: VariantType<{
|
|
957
|
+
/** Single return value for entire sequence */
|
|
958
|
+
global: NullType;
|
|
959
|
+
/** Return-to-go at each timestep */
|
|
960
|
+
per_timestep: NullType;
|
|
757
961
|
}>;
|
|
758
962
|
}>;
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
963
|
+
}>;
|
|
964
|
+
/** Output mode (determines loss function) */
|
|
965
|
+
output: VariantType<{
|
|
966
|
+
/** Regression: MSE loss, no activation */
|
|
967
|
+
regression: NullType;
|
|
968
|
+
/** Binary: BCE loss, sigmoid activation */
|
|
969
|
+
binary: StructType<{
|
|
970
|
+
/** Optional per-position pos_weights for class imbalance [output_dim] */
|
|
971
|
+
pos_weight: OptionType<ArrayType<FloatType>>;
|
|
972
|
+
}>;
|
|
973
|
+
/** Multiclass: CrossEntropy loss, softmax activation */
|
|
974
|
+
multiclass: StructType<{
|
|
975
|
+
/** Number of classes */
|
|
976
|
+
n_classes: IntegerType;
|
|
977
|
+
/** Optional per-class weights */
|
|
978
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
979
|
+
}>;
|
|
980
|
+
/** Multi-head categorical: N independent CrossEntropy heads */
|
|
981
|
+
multi_head: StructType<{
|
|
982
|
+
/** Number of heads (e.g., 84 time slots) */
|
|
983
|
+
n_heads: IntegerType;
|
|
984
|
+
/** Classes per head (e.g., 4 bins) */
|
|
985
|
+
n_classes_per_head: IntegerType;
|
|
986
|
+
/** Optional class weights matrix (n_heads, n_classes) */
|
|
987
|
+
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
988
|
+
}>;
|
|
989
|
+
/**
|
|
990
|
+
* Mixed output types per head.
|
|
991
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
992
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
993
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
994
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
995
|
+
*/
|
|
996
|
+
multi_head_mixed: StructType<{
|
|
997
|
+
/** Array of head configurations */
|
|
998
|
+
heads: ArrayType<StructType<{
|
|
999
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1000
|
+
head_type: VariantType<{
|
|
1001
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1002
|
+
binary: NullType;
|
|
1003
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1004
|
+
multiclass: StructType<{
|
|
1005
|
+
n_classes: IntegerType;
|
|
1006
|
+
}>;
|
|
1007
|
+
}>;
|
|
1008
|
+
/** Optional class weights for this head */
|
|
773
1009
|
class_weights: OptionType<ArrayType<FloatType>>;
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
1010
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1011
|
+
conditional_on: OptionType<IntegerType>;
|
|
1012
|
+
}>>;
|
|
1013
|
+
}>;
|
|
1014
|
+
}>;
|
|
1015
|
+
/** Learning rate (default: 1e-3) */
|
|
1016
|
+
learning_rate: OptionType<FloatType>;
|
|
1017
|
+
/** Maximum epochs (default: 100) */
|
|
1018
|
+
max_epochs: OptionType<IntegerType>;
|
|
1019
|
+
/** Early stopping patience (default: 10) */
|
|
1020
|
+
patience: OptionType<IntegerType>;
|
|
1021
|
+
/** Batch size (default: 32) */
|
|
1022
|
+
batch_size: OptionType<IntegerType>;
|
|
1023
|
+
/** Dropout rate (default: 0.1) */
|
|
1024
|
+
dropout: OptionType<FloatType>;
|
|
1025
|
+
/** Gradient clipping value (default: 1.0) */
|
|
1026
|
+
gradient_clip: OptionType<FloatType>;
|
|
1027
|
+
/** L2 regularization weight decay (default: 0) */
|
|
1028
|
+
weight_decay: OptionType<FloatType>;
|
|
1029
|
+
/** Random seed for reproducibility */
|
|
1030
|
+
random_state: OptionType<IntegerType>;
|
|
1031
|
+
/** Optional callback called each epoch */
|
|
1032
|
+
epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
|
|
1033
|
+
}>], StructType<{
|
|
1034
|
+
/** Trained model blob */
|
|
1035
|
+
model: VariantType<{
|
|
1036
|
+
lightning: StructType<{
|
|
1037
|
+
/** Serialized model data (state_dict + hparams) */
|
|
1038
|
+
data: BlobType;
|
|
1039
|
+
/** Input dimension */
|
|
1040
|
+
n_features: IntegerType;
|
|
1041
|
+
/** Output dimension */
|
|
1042
|
+
output_dim: IntegerType;
|
|
1043
|
+
/** Architecture type */
|
|
1044
|
+
architecture_type: StringType;
|
|
1045
|
+
/** Output type */
|
|
1046
|
+
output_type: StringType;
|
|
1047
|
+
/** Latent dimension (autoencoder only) */
|
|
1048
|
+
latent_dim: OptionType<IntegerType>;
|
|
1049
|
+
}>;
|
|
1050
|
+
}>;
|
|
1051
|
+
/** Final training loss */
|
|
1052
|
+
train_loss: FloatType;
|
|
1053
|
+
/** Final validation loss */
|
|
1054
|
+
val_loss: FloatType;
|
|
1055
|
+
/** Best epoch (for early stopping) */
|
|
1056
|
+
best_epoch: IntegerType;
|
|
1057
|
+
}>>;
|
|
1058
|
+
/**
|
|
1059
|
+
* Generate action sequences autoregressively from trajectory model.
|
|
1060
|
+
*
|
|
1061
|
+
* Use with models trained via trainTrajectory.
|
|
1062
|
+
*
|
|
1063
|
+
* @param model - Trained model from trainTrajectory
|
|
1064
|
+
* @param states - State matrices: n_samples × (seq_len, state_dim)
|
|
1065
|
+
* @param target_returns - Target returns: (n_samples,)
|
|
1066
|
+
* @param config - Generation configuration
|
|
1067
|
+
* @returns Generated actions: n_samples × (seq_len, action_dim)
|
|
1068
|
+
*/
|
|
1069
|
+
export declare const lightning_generate_trajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
|
|
1070
|
+
lightning: StructType<{
|
|
1071
|
+
/** Serialized model data (state_dict + hparams) */
|
|
1072
|
+
data: BlobType;
|
|
1073
|
+
/** Input dimension */
|
|
1074
|
+
n_features: IntegerType;
|
|
1075
|
+
/** Output dimension */
|
|
1076
|
+
output_dim: IntegerType;
|
|
1077
|
+
/** Architecture type */
|
|
1078
|
+
architecture_type: StringType;
|
|
1079
|
+
/** Output type */
|
|
1080
|
+
output_type: StringType;
|
|
1081
|
+
/** Latent dimension (autoencoder only) */
|
|
1082
|
+
latent_dim: OptionType<IntegerType>;
|
|
1083
|
+
}>;
|
|
1084
|
+
}>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
|
|
1085
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
1086
|
+
temperature: FloatType;
|
|
1087
|
+
/** Whether to return probabilities or samples */
|
|
1088
|
+
return_probs: BooleanType;
|
|
1089
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
1090
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1091
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
1092
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
1093
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
1094
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
1095
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1096
|
+
head_type: VariantType<{
|
|
1097
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1098
|
+
binary: NullType;
|
|
1099
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1100
|
+
multiclass: StructType<{
|
|
1101
|
+
n_classes: IntegerType;
|
|
1102
|
+
}>;
|
|
1103
|
+
}>;
|
|
1104
|
+
/** Optional class weights for this head */
|
|
1105
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1106
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1107
|
+
conditional_on: OptionType<IntegerType>;
|
|
1108
|
+
}>>>;
|
|
1109
|
+
}>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
|
|
1110
|
+
/**
|
|
1111
|
+
* Lightning types namespace.
|
|
1112
|
+
*/
|
|
1113
|
+
export declare const LightningTypes: {
|
|
1114
|
+
readonly OutputType: VariantType<{
|
|
1115
|
+
/** Regression: MSE loss, no activation */
|
|
1116
|
+
regression: NullType;
|
|
1117
|
+
/** Binary: BCE loss, sigmoid activation */
|
|
1118
|
+
binary: StructType<{
|
|
1119
|
+
/** Optional per-position pos_weights for class imbalance [output_dim] */
|
|
1120
|
+
pos_weight: OptionType<ArrayType<FloatType>>;
|
|
1121
|
+
}>;
|
|
1122
|
+
/** Multiclass: CrossEntropy loss, softmax activation */
|
|
1123
|
+
multiclass: StructType<{
|
|
1124
|
+
/** Number of classes */
|
|
1125
|
+
n_classes: IntegerType;
|
|
1126
|
+
/** Optional per-class weights */
|
|
1127
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1128
|
+
}>;
|
|
1129
|
+
/** Multi-head categorical: N independent CrossEntropy heads */
|
|
1130
|
+
multi_head: StructType<{
|
|
1131
|
+
/** Number of heads (e.g., 84 time slots) */
|
|
1132
|
+
n_heads: IntegerType;
|
|
1133
|
+
/** Classes per head (e.g., 4 bins) */
|
|
1134
|
+
n_classes_per_head: IntegerType;
|
|
1135
|
+
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1136
|
+
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1137
|
+
}>;
|
|
1138
|
+
/**
|
|
1139
|
+
* Mixed output types per head.
|
|
1140
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
1141
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
1142
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
1143
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
1144
|
+
*/
|
|
1145
|
+
multi_head_mixed: StructType<{
|
|
1146
|
+
/** Array of head configurations */
|
|
1147
|
+
heads: ArrayType<StructType<{
|
|
1148
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1149
|
+
head_type: VariantType<{
|
|
1150
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1151
|
+
binary: NullType;
|
|
1152
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1153
|
+
multiclass: StructType<{
|
|
1154
|
+
n_classes: IntegerType;
|
|
1155
|
+
}>;
|
|
1156
|
+
}>;
|
|
1157
|
+
/** Optional class weights for this head */
|
|
1158
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1159
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1160
|
+
conditional_on: OptionType<IntegerType>;
|
|
1161
|
+
}>>;
|
|
1162
|
+
}>;
|
|
1163
|
+
}>;
|
|
1164
|
+
readonly ArchitectureType: VariantType<{
|
|
1165
|
+
/** Simple MLP: input → hidden → output */
|
|
1166
|
+
mlp: StructType<{
|
|
1167
|
+
/** Hidden layer sizes */
|
|
1168
|
+
hidden_layers: ArrayType<IntegerType>;
|
|
1169
|
+
}>;
|
|
1170
|
+
/** Autoencoder: input → encoder → latent → decoder → output */
|
|
1171
|
+
autoencoder: StructType<{
|
|
1172
|
+
/** Encoder hidden layer sizes */
|
|
1173
|
+
encoder_layers: ArrayType<IntegerType>;
|
|
1174
|
+
/** Latent dimension (bottleneck) */
|
|
1175
|
+
latent_dim: IntegerType;
|
|
1176
|
+
/** Decoder hidden layer sizes */
|
|
1177
|
+
decoder_layers: ArrayType<IntegerType>;
|
|
1178
|
+
}>;
|
|
1179
|
+
/** Conv1D: 1D convolutional autoencoder for temporal patterns */
|
|
1180
|
+
conv1d: StructType<{
|
|
1181
|
+
/** Number of channels (e.g., additive types) */
|
|
1182
|
+
n_channels: IntegerType;
|
|
1183
|
+
/** Sequence length (e.g., days) */
|
|
1184
|
+
sequence_length: IntegerType;
|
|
1185
|
+
/** Conv layer channel sizes */
|
|
1186
|
+
conv_channels: ArrayType<IntegerType>;
|
|
1187
|
+
/** Kernel size for convolutions (must be odd) */
|
|
1188
|
+
kernel_size: IntegerType;
|
|
1189
|
+
/** Latent dimension after flattening */
|
|
1190
|
+
latent_dim: IntegerType;
|
|
1191
|
+
/** Optional condition dimension for conditional generation */
|
|
1192
|
+
condition_dim: OptionType<IntegerType>;
|
|
1193
|
+
}>;
|
|
1194
|
+
/** Sequential: LSTM/GRU autoencoder for long-range dependencies */
|
|
1195
|
+
sequential: StructType<{
|
|
1196
|
+
/** Number of channels (e.g., additive types) */
|
|
1197
|
+
n_channels: IntegerType;
|
|
1198
|
+
/** Sequence length (e.g., days) */
|
|
1199
|
+
sequence_length: IntegerType;
|
|
1200
|
+
/** RNN hidden size */
|
|
1201
|
+
hidden_size: IntegerType;
|
|
1202
|
+
/** Number of RNN layers */
|
|
1203
|
+
n_layers: IntegerType;
|
|
1204
|
+
/** Cell type: lstm or gru */
|
|
1205
|
+
cell_type: VariantType<{
|
|
1206
|
+
lstm: NullType;
|
|
1207
|
+
gru: NullType;
|
|
1208
|
+
}>;
|
|
1209
|
+
/** Latent dimension (from final hidden state) */
|
|
1210
|
+
latent_dim: IntegerType;
|
|
1211
|
+
/** Bidirectional encoder (decoder is always unidirectional) */
|
|
1212
|
+
bidirectional: BooleanType;
|
|
1213
|
+
/** Optional condition dimension for conditional generation */
|
|
1214
|
+
condition_dim: OptionType<IntegerType>;
|
|
1215
|
+
}>;
|
|
1216
|
+
/** Transformer: attention-based autoencoder for complex patterns */
|
|
1217
|
+
transformer: StructType<{
|
|
1218
|
+
/** Number of channels (e.g., additive types) */
|
|
1219
|
+
n_channels: IntegerType;
|
|
1220
|
+
/** Sequence length (e.g., days) */
|
|
1221
|
+
sequence_length: IntegerType;
|
|
1222
|
+
/** Model dimension */
|
|
1223
|
+
d_model: IntegerType;
|
|
1224
|
+
/** Number of attention heads (must divide d_model evenly) */
|
|
1225
|
+
n_attention_heads: IntegerType;
|
|
1226
|
+
/** Number of transformer layers */
|
|
1227
|
+
n_layers: IntegerType;
|
|
1228
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1229
|
+
d_ff: OptionType<IntegerType>;
|
|
1230
|
+
/** Latent dimension (mean pooled output) */
|
|
1231
|
+
latent_dim: IntegerType;
|
|
1232
|
+
/** Optional condition dimension for conditional generation */
|
|
1233
|
+
condition_dim: OptionType<IntegerType>;
|
|
1234
|
+
}>;
|
|
1235
|
+
/**
|
|
1236
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
1237
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
1238
|
+
* Predicts actions conditioned on desired return and state history.
|
|
1239
|
+
*/
|
|
1240
|
+
decision_transformer: StructType<{
|
|
1241
|
+
/** Sequence length (timesteps) */
|
|
1242
|
+
sequence_length: IntegerType;
|
|
1243
|
+
/** State dimension per timestep */
|
|
1244
|
+
state_dim: IntegerType;
|
|
1245
|
+
/** Action dimension per timestep */
|
|
1246
|
+
action_dim: IntegerType;
|
|
1247
|
+
/** Model dimension (transformer hidden size) */
|
|
1248
|
+
d_model: IntegerType;
|
|
1249
|
+
/** Number of attention heads */
|
|
1250
|
+
n_attention_heads: IntegerType;
|
|
1251
|
+
/** Number of transformer layers */
|
|
1252
|
+
n_layers: IntegerType;
|
|
1253
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1254
|
+
d_ff: OptionType<IntegerType>;
|
|
1255
|
+
/** Dropout rate */
|
|
1256
|
+
dropout: OptionType<FloatType>;
|
|
1257
|
+
/** Whether return is per-timestep or global */
|
|
1258
|
+
return_embedding: VariantType<{
|
|
1259
|
+
/** Single return value for entire sequence */
|
|
1260
|
+
global: NullType;
|
|
1261
|
+
/** Return-to-go at each timestep */
|
|
1262
|
+
per_timestep: NullType;
|
|
1263
|
+
}>;
|
|
1264
|
+
}>;
|
|
1265
|
+
}>;
|
|
1266
|
+
readonly CellType: VariantType<{
|
|
1267
|
+
lstm: NullType;
|
|
1268
|
+
gru: NullType;
|
|
1269
|
+
}>;
|
|
1270
|
+
readonly EpochCallbackType: FunctionType<[IntegerType, FloatType, FloatType], NullType>;
|
|
1271
|
+
readonly ConfigType: StructType<{
|
|
1272
|
+
/** Model architecture */
|
|
1273
|
+
architecture: VariantType<{
|
|
1274
|
+
/** Simple MLP: input → hidden → output */
|
|
1275
|
+
mlp: StructType<{
|
|
1276
|
+
/** Hidden layer sizes */
|
|
1277
|
+
hidden_layers: ArrayType<IntegerType>;
|
|
1278
|
+
}>;
|
|
1279
|
+
/** Autoencoder: input → encoder → latent → decoder → output */
|
|
1280
|
+
autoencoder: StructType<{
|
|
1281
|
+
/** Encoder hidden layer sizes */
|
|
1282
|
+
encoder_layers: ArrayType<IntegerType>;
|
|
1283
|
+
/** Latent dimension (bottleneck) */
|
|
1284
|
+
latent_dim: IntegerType;
|
|
1285
|
+
/** Decoder hidden layer sizes */
|
|
1286
|
+
decoder_layers: ArrayType<IntegerType>;
|
|
1287
|
+
}>;
|
|
1288
|
+
/** Conv1D: 1D convolutional autoencoder for temporal patterns */
|
|
1289
|
+
conv1d: StructType<{
|
|
1290
|
+
/** Number of channels (e.g., additive types) */
|
|
1291
|
+
n_channels: IntegerType;
|
|
1292
|
+
/** Sequence length (e.g., days) */
|
|
1293
|
+
sequence_length: IntegerType;
|
|
1294
|
+
/** Conv layer channel sizes */
|
|
1295
|
+
conv_channels: ArrayType<IntegerType>;
|
|
1296
|
+
/** Kernel size for convolutions (must be odd) */
|
|
1297
|
+
kernel_size: IntegerType;
|
|
1298
|
+
/** Latent dimension after flattening */
|
|
1299
|
+
latent_dim: IntegerType;
|
|
1300
|
+
/** Optional condition dimension for conditional generation */
|
|
1301
|
+
condition_dim: OptionType<IntegerType>;
|
|
1302
|
+
}>;
|
|
1303
|
+
/** Sequential: LSTM/GRU autoencoder for long-range dependencies */
|
|
1304
|
+
sequential: StructType<{
|
|
1305
|
+
/** Number of channels (e.g., additive types) */
|
|
1306
|
+
n_channels: IntegerType;
|
|
1307
|
+
/** Sequence length (e.g., days) */
|
|
1308
|
+
sequence_length: IntegerType;
|
|
1309
|
+
/** RNN hidden size */
|
|
1310
|
+
hidden_size: IntegerType;
|
|
1311
|
+
/** Number of RNN layers */
|
|
1312
|
+
n_layers: IntegerType;
|
|
1313
|
+
/** Cell type: lstm or gru */
|
|
1314
|
+
cell_type: VariantType<{
|
|
1315
|
+
lstm: NullType;
|
|
1316
|
+
gru: NullType;
|
|
1317
|
+
}>;
|
|
1318
|
+
/** Latent dimension (from final hidden state) */
|
|
1319
|
+
latent_dim: IntegerType;
|
|
1320
|
+
/** Bidirectional encoder (decoder is always unidirectional) */
|
|
1321
|
+
bidirectional: BooleanType;
|
|
1322
|
+
/** Optional condition dimension for conditional generation */
|
|
1323
|
+
condition_dim: OptionType<IntegerType>;
|
|
1324
|
+
}>;
|
|
1325
|
+
/** Transformer: attention-based autoencoder for complex patterns */
|
|
1326
|
+
transformer: StructType<{
|
|
1327
|
+
/** Number of channels (e.g., additive types) */
|
|
1328
|
+
n_channels: IntegerType;
|
|
1329
|
+
/** Sequence length (e.g., days) */
|
|
1330
|
+
sequence_length: IntegerType;
|
|
1331
|
+
/** Model dimension */
|
|
1332
|
+
d_model: IntegerType;
|
|
1333
|
+
/** Number of attention heads (must divide d_model evenly) */
|
|
1334
|
+
n_attention_heads: IntegerType;
|
|
1335
|
+
/** Number of transformer layers */
|
|
1336
|
+
n_layers: IntegerType;
|
|
1337
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1338
|
+
d_ff: OptionType<IntegerType>;
|
|
1339
|
+
/** Latent dimension (mean pooled output) */
|
|
1340
|
+
latent_dim: IntegerType;
|
|
1341
|
+
/** Optional condition dimension for conditional generation */
|
|
1342
|
+
condition_dim: OptionType<IntegerType>;
|
|
1343
|
+
}>;
|
|
1344
|
+
/**
|
|
1345
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
1346
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
1347
|
+
* Predicts actions conditioned on desired return and state history.
|
|
1348
|
+
*/
|
|
1349
|
+
decision_transformer: StructType<{
|
|
1350
|
+
/** Sequence length (timesteps) */
|
|
1351
|
+
sequence_length: IntegerType;
|
|
1352
|
+
/** State dimension per timestep */
|
|
1353
|
+
state_dim: IntegerType;
|
|
1354
|
+
/** Action dimension per timestep */
|
|
1355
|
+
action_dim: IntegerType;
|
|
1356
|
+
/** Model dimension (transformer hidden size) */
|
|
1357
|
+
d_model: IntegerType;
|
|
1358
|
+
/** Number of attention heads */
|
|
1359
|
+
n_attention_heads: IntegerType;
|
|
1360
|
+
/** Number of transformer layers */
|
|
1361
|
+
n_layers: IntegerType;
|
|
1362
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1363
|
+
d_ff: OptionType<IntegerType>;
|
|
1364
|
+
/** Dropout rate */
|
|
1365
|
+
dropout: OptionType<FloatType>;
|
|
1366
|
+
/** Whether return is per-timestep or global */
|
|
1367
|
+
return_embedding: VariantType<{
|
|
1368
|
+
/** Single return value for entire sequence */
|
|
1369
|
+
global: NullType;
|
|
1370
|
+
/** Return-to-go at each timestep */
|
|
1371
|
+
per_timestep: NullType;
|
|
1372
|
+
}>;
|
|
1373
|
+
}>;
|
|
1374
|
+
}>;
|
|
1375
|
+
/** Output mode (determines loss function) */
|
|
1376
|
+
output: VariantType<{
|
|
1377
|
+
/** Regression: MSE loss, no activation */
|
|
1378
|
+
regression: NullType;
|
|
1379
|
+
/** Binary: BCE loss, sigmoid activation */
|
|
1380
|
+
binary: StructType<{
|
|
1381
|
+
/** Optional per-position pos_weights for class imbalance [output_dim] */
|
|
1382
|
+
pos_weight: OptionType<ArrayType<FloatType>>;
|
|
1383
|
+
}>;
|
|
1384
|
+
/** Multiclass: CrossEntropy loss, softmax activation */
|
|
1385
|
+
multiclass: StructType<{
|
|
1386
|
+
/** Number of classes */
|
|
1387
|
+
n_classes: IntegerType;
|
|
1388
|
+
/** Optional per-class weights */
|
|
1389
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1390
|
+
}>;
|
|
1391
|
+
/** Multi-head categorical: N independent CrossEntropy heads */
|
|
1392
|
+
multi_head: StructType<{
|
|
1393
|
+
/** Number of heads (e.g., 84 time slots) */
|
|
1394
|
+
n_heads: IntegerType;
|
|
1395
|
+
/** Classes per head (e.g., 4 bins) */
|
|
1396
|
+
n_classes_per_head: IntegerType;
|
|
1397
|
+
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1398
|
+
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1399
|
+
}>;
|
|
1400
|
+
/**
|
|
1401
|
+
* Mixed output types per head.
|
|
1402
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
1403
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
1404
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
1405
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
1406
|
+
*/
|
|
1407
|
+
multi_head_mixed: StructType<{
|
|
1408
|
+
/** Array of head configurations */
|
|
1409
|
+
heads: ArrayType<StructType<{
|
|
1410
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1411
|
+
head_type: VariantType<{
|
|
1412
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1413
|
+
binary: NullType;
|
|
1414
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1415
|
+
multiclass: StructType<{
|
|
1416
|
+
n_classes: IntegerType;
|
|
1417
|
+
}>;
|
|
1418
|
+
}>;
|
|
1419
|
+
/** Optional class weights for this head */
|
|
1420
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1421
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1422
|
+
conditional_on: OptionType<IntegerType>;
|
|
1423
|
+
}>>;
|
|
1424
|
+
}>;
|
|
784
1425
|
}>;
|
|
785
1426
|
/** Learning rate (default: 1e-3) */
|
|
786
1427
|
learning_rate: OptionType<FloatType>;
|
|
@@ -854,6 +1495,61 @@ export declare const LightningTypes: {
|
|
|
854
1495
|
/** Group index per sample: [n_samples] */
|
|
855
1496
|
sample_groups: ArrayType<IntegerType>;
|
|
856
1497
|
}>;
|
|
1498
|
+
readonly GenerateConfigType: StructType<{
|
|
1499
|
+
/** Number of steps to generate */
|
|
1500
|
+
n_steps: IntegerType;
|
|
1501
|
+
/** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
|
|
1502
|
+
temperature: FloatType;
|
|
1503
|
+
/** If true, return probabilities. If false, return samples. */
|
|
1504
|
+
return_probs: BooleanType;
|
|
1505
|
+
}>;
|
|
1506
|
+
readonly ReturnEmbeddingType: VariantType<{
|
|
1507
|
+
/** Single return value for entire sequence */
|
|
1508
|
+
global: NullType;
|
|
1509
|
+
/** Return-to-go at each timestep */
|
|
1510
|
+
per_timestep: NullType;
|
|
1511
|
+
}>;
|
|
1512
|
+
readonly HeadConfigType: StructType<{
|
|
1513
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1514
|
+
head_type: VariantType<{
|
|
1515
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1516
|
+
binary: NullType;
|
|
1517
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1518
|
+
multiclass: StructType<{
|
|
1519
|
+
n_classes: IntegerType;
|
|
1520
|
+
}>;
|
|
1521
|
+
}>;
|
|
1522
|
+
/** Optional class weights for this head */
|
|
1523
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1524
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1525
|
+
conditional_on: OptionType<IntegerType>;
|
|
1526
|
+
}>;
|
|
1527
|
+
readonly TrajectoryGenerateConfigType: StructType<{
|
|
1528
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
1529
|
+
temperature: FloatType;
|
|
1530
|
+
/** Whether to return probabilities or samples */
|
|
1531
|
+
return_probs: BooleanType;
|
|
1532
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
1533
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1534
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
1535
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
1536
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
1537
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
1538
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1539
|
+
head_type: VariantType<{
|
|
1540
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1541
|
+
binary: NullType;
|
|
1542
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1543
|
+
multiclass: StructType<{
|
|
1544
|
+
n_classes: IntegerType;
|
|
1545
|
+
}>;
|
|
1546
|
+
}>;
|
|
1547
|
+
/** Optional class weights for this head */
|
|
1548
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1549
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1550
|
+
conditional_on: OptionType<IntegerType>;
|
|
1551
|
+
}>>>;
|
|
1552
|
+
}>;
|
|
857
1553
|
};
|
|
858
1554
|
/**
|
|
859
1555
|
* Lightning platform functions namespace.
|
|
@@ -959,6 +1655,36 @@ export declare const Lightning: {
|
|
|
959
1655
|
/** Optional condition dimension for conditional generation */
|
|
960
1656
|
condition_dim: OptionType<IntegerType>;
|
|
961
1657
|
}>;
|
|
1658
|
+
/**
|
|
1659
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
1660
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
1661
|
+
* Predicts actions conditioned on desired return and state history.
|
|
1662
|
+
*/
|
|
1663
|
+
decision_transformer: StructType<{
|
|
1664
|
+
/** Sequence length (timesteps) */
|
|
1665
|
+
sequence_length: IntegerType;
|
|
1666
|
+
/** State dimension per timestep */
|
|
1667
|
+
state_dim: IntegerType;
|
|
1668
|
+
/** Action dimension per timestep */
|
|
1669
|
+
action_dim: IntegerType;
|
|
1670
|
+
/** Model dimension (transformer hidden size) */
|
|
1671
|
+
d_model: IntegerType;
|
|
1672
|
+
/** Number of attention heads */
|
|
1673
|
+
n_attention_heads: IntegerType;
|
|
1674
|
+
/** Number of transformer layers */
|
|
1675
|
+
n_layers: IntegerType;
|
|
1676
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1677
|
+
d_ff: OptionType<IntegerType>;
|
|
1678
|
+
/** Dropout rate */
|
|
1679
|
+
dropout: OptionType<FloatType>;
|
|
1680
|
+
/** Whether return is per-timestep or global */
|
|
1681
|
+
return_embedding: VariantType<{
|
|
1682
|
+
/** Single return value for entire sequence */
|
|
1683
|
+
global: NullType;
|
|
1684
|
+
/** Return-to-go at each timestep */
|
|
1685
|
+
per_timestep: NullType;
|
|
1686
|
+
}>;
|
|
1687
|
+
}>;
|
|
962
1688
|
}>;
|
|
963
1689
|
/** Output mode (determines loss function) */
|
|
964
1690
|
output: VariantType<{
|
|
@@ -985,6 +1711,31 @@ export declare const Lightning: {
|
|
|
985
1711
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
986
1712
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
987
1713
|
}>;
|
|
1714
|
+
/**
|
|
1715
|
+
* Mixed output types per head.
|
|
1716
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
1717
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
1718
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
1719
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
1720
|
+
*/
|
|
1721
|
+
multi_head_mixed: StructType<{
|
|
1722
|
+
/** Array of head configurations */
|
|
1723
|
+
heads: ArrayType<StructType<{
|
|
1724
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1725
|
+
head_type: VariantType<{
|
|
1726
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1727
|
+
binary: NullType;
|
|
1728
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1729
|
+
multiclass: StructType<{
|
|
1730
|
+
n_classes: IntegerType;
|
|
1731
|
+
}>;
|
|
1732
|
+
}>;
|
|
1733
|
+
/** Optional class weights for this head */
|
|
1734
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1735
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1736
|
+
conditional_on: OptionType<IntegerType>;
|
|
1737
|
+
}>>;
|
|
1738
|
+
}>;
|
|
988
1739
|
}>;
|
|
989
1740
|
/** Learning rate (default: 1e-3) */
|
|
990
1741
|
learning_rate: OptionType<FloatType>;
|
|
@@ -1125,6 +1876,308 @@ export declare const Lightning: {
|
|
|
1125
1876
|
latent_dim: OptionType<IntegerType>;
|
|
1126
1877
|
}>;
|
|
1127
1878
|
}>, ArrayType<ArrayType<FloatType>>, ArrayType<ArrayType<FloatType>>], ArrayType<ArrayType<FloatType>>>;
|
|
1879
|
+
/**
|
|
1880
|
+
* Generate sequence autoregressively.
|
|
1881
|
+
*
|
|
1882
|
+
* Generates a sequence from a trained sequential model, optionally
|
|
1883
|
+
* continuing from a prefix and conditioned on input features.
|
|
1884
|
+
*/
|
|
1885
|
+
readonly generateSequence: import("@elaraai/east").PlatformDefinition<[VariantType<{
|
|
1886
|
+
lightning: StructType<{
|
|
1887
|
+
/** Serialized model data (state_dict + hparams) */
|
|
1888
|
+
data: BlobType;
|
|
1889
|
+
/** Input dimension */
|
|
1890
|
+
n_features: IntegerType;
|
|
1891
|
+
/** Output dimension */
|
|
1892
|
+
output_dim: IntegerType;
|
|
1893
|
+
/** Architecture type */
|
|
1894
|
+
architecture_type: StringType;
|
|
1895
|
+
/** Output type */
|
|
1896
|
+
output_type: StringType;
|
|
1897
|
+
/** Latent dimension (autoencoder only) */
|
|
1898
|
+
latent_dim: OptionType<IntegerType>;
|
|
1899
|
+
}>;
|
|
1900
|
+
}>, ArrayType<ArrayType<FloatType>>, OptionType<ArrayType<ArrayType<FloatType>>>, StructType<{
|
|
1901
|
+
/** Number of steps to generate */
|
|
1902
|
+
n_steps: IntegerType;
|
|
1903
|
+
/** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
|
|
1904
|
+
temperature: FloatType;
|
|
1905
|
+
/** If true, return probabilities. If false, return samples. */
|
|
1906
|
+
return_probs: BooleanType;
|
|
1907
|
+
}>], ArrayType<ArrayType<FloatType>>>;
|
|
1908
|
+
/**
|
|
1909
|
+
* Train a Decision Transformer with trajectory data.
|
|
1910
|
+
*
|
|
1911
|
+
* Trains a return-conditioned sequence generation model that learns
|
|
1912
|
+
* to predict actions given states and desired returns.
|
|
1913
|
+
*
|
|
1914
|
+
* @example
|
|
1915
|
+
* ```typescript
|
|
1916
|
+
* const result = Lightning.trainTrajectory(
|
|
1917
|
+
* returns, states, actions, masks,
|
|
1918
|
+
* {
|
|
1919
|
+
* architecture: variant("decision_transformer", {
|
|
1920
|
+
* sequence_length: 14n,
|
|
1921
|
+
* state_dim: 8n,
|
|
1922
|
+
* action_dim: 11n,
|
|
1923
|
+
* d_model: 64n,
|
|
1924
|
+
* n_attention_heads: 4n,
|
|
1925
|
+
* n_layers: 3n,
|
|
1926
|
+
* d_ff: variant("none", null),
|
|
1927
|
+
* dropout: variant("some", 0.1),
|
|
1928
|
+
* return_embedding: variant("global", null),
|
|
1929
|
+
* }),
|
|
1930
|
+
* output: variant("multi_head_mixed", { heads: [...] }),
|
|
1931
|
+
* ...
|
|
1932
|
+
* }
|
|
1933
|
+
* );
|
|
1934
|
+
* ```
|
|
1935
|
+
*/
|
|
1936
|
+
readonly trainTrajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
|
|
1937
|
+
/** Model architecture */
|
|
1938
|
+
architecture: VariantType<{
|
|
1939
|
+
/** Simple MLP: input → hidden → output */
|
|
1940
|
+
mlp: StructType<{
|
|
1941
|
+
/** Hidden layer sizes */
|
|
1942
|
+
hidden_layers: ArrayType<IntegerType>;
|
|
1943
|
+
}>;
|
|
1944
|
+
/** Autoencoder: input → encoder → latent → decoder → output */
|
|
1945
|
+
autoencoder: StructType<{
|
|
1946
|
+
/** Encoder hidden layer sizes */
|
|
1947
|
+
encoder_layers: ArrayType<IntegerType>;
|
|
1948
|
+
/** Latent dimension (bottleneck) */
|
|
1949
|
+
latent_dim: IntegerType;
|
|
1950
|
+
/** Decoder hidden layer sizes */
|
|
1951
|
+
decoder_layers: ArrayType<IntegerType>;
|
|
1952
|
+
}>;
|
|
1953
|
+
/** Conv1D: 1D convolutional autoencoder for temporal patterns */
|
|
1954
|
+
conv1d: StructType<{
|
|
1955
|
+
/** Number of channels (e.g., additive types) */
|
|
1956
|
+
n_channels: IntegerType;
|
|
1957
|
+
/** Sequence length (e.g., days) */
|
|
1958
|
+
sequence_length: IntegerType;
|
|
1959
|
+
/** Conv layer channel sizes */
|
|
1960
|
+
conv_channels: ArrayType<IntegerType>;
|
|
1961
|
+
/** Kernel size for convolutions (must be odd) */
|
|
1962
|
+
kernel_size: IntegerType;
|
|
1963
|
+
/** Latent dimension after flattening */
|
|
1964
|
+
latent_dim: IntegerType;
|
|
1965
|
+
/** Optional condition dimension for conditional generation */
|
|
1966
|
+
condition_dim: OptionType<IntegerType>;
|
|
1967
|
+
}>;
|
|
1968
|
+
/** Sequential: LSTM/GRU autoencoder for long-range dependencies */
|
|
1969
|
+
sequential: StructType<{
|
|
1970
|
+
/** Number of channels (e.g., additive types) */
|
|
1971
|
+
n_channels: IntegerType;
|
|
1972
|
+
/** Sequence length (e.g., days) */
|
|
1973
|
+
sequence_length: IntegerType;
|
|
1974
|
+
/** RNN hidden size */
|
|
1975
|
+
hidden_size: IntegerType;
|
|
1976
|
+
/** Number of RNN layers */
|
|
1977
|
+
n_layers: IntegerType;
|
|
1978
|
+
/** Cell type: lstm or gru */
|
|
1979
|
+
cell_type: VariantType<{
|
|
1980
|
+
lstm: NullType;
|
|
1981
|
+
gru: NullType;
|
|
1982
|
+
}>;
|
|
1983
|
+
/** Latent dimension (from final hidden state) */
|
|
1984
|
+
latent_dim: IntegerType;
|
|
1985
|
+
/** Bidirectional encoder (decoder is always unidirectional) */
|
|
1986
|
+
bidirectional: BooleanType;
|
|
1987
|
+
/** Optional condition dimension for conditional generation */
|
|
1988
|
+
condition_dim: OptionType<IntegerType>;
|
|
1989
|
+
}>;
|
|
1990
|
+
/** Transformer: attention-based autoencoder for complex patterns */
|
|
1991
|
+
transformer: StructType<{
|
|
1992
|
+
/** Number of channels (e.g., additive types) */
|
|
1993
|
+
n_channels: IntegerType;
|
|
1994
|
+
/** Sequence length (e.g., days) */
|
|
1995
|
+
sequence_length: IntegerType;
|
|
1996
|
+
/** Model dimension */
|
|
1997
|
+
d_model: IntegerType;
|
|
1998
|
+
/** Number of attention heads (must divide d_model evenly) */
|
|
1999
|
+
n_attention_heads: IntegerType;
|
|
2000
|
+
/** Number of transformer layers */
|
|
2001
|
+
n_layers: IntegerType;
|
|
2002
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2003
|
+
d_ff: OptionType<IntegerType>;
|
|
2004
|
+
/** Latent dimension (mean pooled output) */
|
|
2005
|
+
latent_dim: IntegerType;
|
|
2006
|
+
/** Optional condition dimension for conditional generation */
|
|
2007
|
+
condition_dim: OptionType<IntegerType>;
|
|
2008
|
+
}>;
|
|
2009
|
+
/**
|
|
2010
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
2011
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
2012
|
+
* Predicts actions conditioned on desired return and state history.
|
|
2013
|
+
*/
|
|
2014
|
+
decision_transformer: StructType<{
|
|
2015
|
+
/** Sequence length (timesteps) */
|
|
2016
|
+
sequence_length: IntegerType;
|
|
2017
|
+
/** State dimension per timestep */
|
|
2018
|
+
state_dim: IntegerType;
|
|
2019
|
+
/** Action dimension per timestep */
|
|
2020
|
+
action_dim: IntegerType;
|
|
2021
|
+
/** Model dimension (transformer hidden size) */
|
|
2022
|
+
d_model: IntegerType;
|
|
2023
|
+
/** Number of attention heads */
|
|
2024
|
+
n_attention_heads: IntegerType;
|
|
2025
|
+
/** Number of transformer layers */
|
|
2026
|
+
n_layers: IntegerType;
|
|
2027
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2028
|
+
d_ff: OptionType<IntegerType>;
|
|
2029
|
+
/** Dropout rate */
|
|
2030
|
+
dropout: OptionType<FloatType>;
|
|
2031
|
+
/** Whether return is per-timestep or global */
|
|
2032
|
+
return_embedding: VariantType<{
|
|
2033
|
+
/** Single return value for entire sequence */
|
|
2034
|
+
global: NullType;
|
|
2035
|
+
/** Return-to-go at each timestep */
|
|
2036
|
+
per_timestep: NullType;
|
|
2037
|
+
}>;
|
|
2038
|
+
}>;
|
|
2039
|
+
}>;
|
|
2040
|
+
/** Output mode (determines loss function) */
|
|
2041
|
+
output: VariantType<{
|
|
2042
|
+
/** Regression: MSE loss, no activation */
|
|
2043
|
+
regression: NullType;
|
|
2044
|
+
/** Binary: BCE loss, sigmoid activation */
|
|
2045
|
+
binary: StructType<{
|
|
2046
|
+
/** Optional per-position pos_weights for class imbalance [output_dim] */
|
|
2047
|
+
pos_weight: OptionType<ArrayType<FloatType>>;
|
|
2048
|
+
}>;
|
|
2049
|
+
/** Multiclass: CrossEntropy loss, softmax activation */
|
|
2050
|
+
multiclass: StructType<{
|
|
2051
|
+
/** Number of classes */
|
|
2052
|
+
n_classes: IntegerType;
|
|
2053
|
+
/** Optional per-class weights */
|
|
2054
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2055
|
+
}>;
|
|
2056
|
+
/** Multi-head categorical: N independent CrossEntropy heads */
|
|
2057
|
+
multi_head: StructType<{
|
|
2058
|
+
/** Number of heads (e.g., 84 time slots) */
|
|
2059
|
+
n_heads: IntegerType;
|
|
2060
|
+
/** Classes per head (e.g., 4 bins) */
|
|
2061
|
+
n_classes_per_head: IntegerType;
|
|
2062
|
+
/** Optional class weights matrix (n_heads, n_classes) */
|
|
2063
|
+
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2064
|
+
}>;
|
|
2065
|
+
/**
|
|
2066
|
+
* Mixed output types per head.
|
|
2067
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
2068
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
2069
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
2070
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
2071
|
+
*/
|
|
2072
|
+
multi_head_mixed: StructType<{
|
|
2073
|
+
/** Array of head configurations */
|
|
2074
|
+
heads: ArrayType<StructType<{
|
|
2075
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2076
|
+
head_type: VariantType<{
|
|
2077
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2078
|
+
binary: NullType;
|
|
2079
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2080
|
+
multiclass: StructType<{
|
|
2081
|
+
n_classes: IntegerType;
|
|
2082
|
+
}>;
|
|
2083
|
+
}>;
|
|
2084
|
+
/** Optional class weights for this head */
|
|
2085
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2086
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2087
|
+
conditional_on: OptionType<IntegerType>;
|
|
2088
|
+
}>>;
|
|
2089
|
+
}>;
|
|
2090
|
+
}>;
|
|
2091
|
+
/** Learning rate (default: 1e-3) */
|
|
2092
|
+
learning_rate: OptionType<FloatType>;
|
|
2093
|
+
/** Maximum epochs (default: 100) */
|
|
2094
|
+
max_epochs: OptionType<IntegerType>;
|
|
2095
|
+
/** Early stopping patience (default: 10) */
|
|
2096
|
+
patience: OptionType<IntegerType>;
|
|
2097
|
+
/** Batch size (default: 32) */
|
|
2098
|
+
batch_size: OptionType<IntegerType>;
|
|
2099
|
+
/** Dropout rate (default: 0.1) */
|
|
2100
|
+
dropout: OptionType<FloatType>;
|
|
2101
|
+
/** Gradient clipping value (default: 1.0) */
|
|
2102
|
+
gradient_clip: OptionType<FloatType>;
|
|
2103
|
+
/** L2 regularization weight decay (default: 0) */
|
|
2104
|
+
weight_decay: OptionType<FloatType>;
|
|
2105
|
+
/** Random seed for reproducibility */
|
|
2106
|
+
random_state: OptionType<IntegerType>;
|
|
2107
|
+
/** Optional callback called each epoch */
|
|
2108
|
+
epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
|
|
2109
|
+
}>], StructType<{
|
|
2110
|
+
/** Trained model blob */
|
|
2111
|
+
model: VariantType<{
|
|
2112
|
+
lightning: StructType<{
|
|
2113
|
+
/** Serialized model data (state_dict + hparams) */
|
|
2114
|
+
data: BlobType;
|
|
2115
|
+
/** Input dimension */
|
|
2116
|
+
n_features: IntegerType;
|
|
2117
|
+
/** Output dimension */
|
|
2118
|
+
output_dim: IntegerType;
|
|
2119
|
+
/** Architecture type */
|
|
2120
|
+
architecture_type: StringType;
|
|
2121
|
+
/** Output type */
|
|
2122
|
+
output_type: StringType;
|
|
2123
|
+
/** Latent dimension (autoencoder only) */
|
|
2124
|
+
latent_dim: OptionType<IntegerType>;
|
|
2125
|
+
}>;
|
|
2126
|
+
}>;
|
|
2127
|
+
/** Final training loss */
|
|
2128
|
+
train_loss: FloatType;
|
|
2129
|
+
/** Final validation loss */
|
|
2130
|
+
val_loss: FloatType;
|
|
2131
|
+
/** Best epoch (for early stopping) */
|
|
2132
|
+
best_epoch: IntegerType;
|
|
2133
|
+
}>>;
|
|
2134
|
+
/**
|
|
2135
|
+
* Generate action sequences from a Decision Transformer.
|
|
2136
|
+
*
|
|
2137
|
+
* Autoregressively generates actions conditioned on target returns
|
|
2138
|
+
* and state sequences.
|
|
2139
|
+
*/
|
|
2140
|
+
readonly generateTrajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
|
|
2141
|
+
lightning: StructType<{
|
|
2142
|
+
/** Serialized model data (state_dict + hparams) */
|
|
2143
|
+
data: BlobType;
|
|
2144
|
+
/** Input dimension */
|
|
2145
|
+
n_features: IntegerType;
|
|
2146
|
+
/** Output dimension */
|
|
2147
|
+
output_dim: IntegerType;
|
|
2148
|
+
/** Architecture type */
|
|
2149
|
+
architecture_type: StringType;
|
|
2150
|
+
/** Output type */
|
|
2151
|
+
output_type: StringType;
|
|
2152
|
+
/** Latent dimension (autoencoder only) */
|
|
2153
|
+
latent_dim: OptionType<IntegerType>;
|
|
2154
|
+
}>;
|
|
2155
|
+
}>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
|
|
2156
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
2157
|
+
temperature: FloatType;
|
|
2158
|
+
/** Whether to return probabilities or samples */
|
|
2159
|
+
return_probs: BooleanType;
|
|
2160
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
2161
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2162
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
2163
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
2164
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
2165
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
2166
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2167
|
+
head_type: VariantType<{
|
|
2168
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2169
|
+
binary: NullType;
|
|
2170
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2171
|
+
multiclass: StructType<{
|
|
2172
|
+
n_classes: IntegerType;
|
|
2173
|
+
}>;
|
|
2174
|
+
}>;
|
|
2175
|
+
/** Optional class weights for this head */
|
|
2176
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2177
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2178
|
+
conditional_on: OptionType<IntegerType>;
|
|
2179
|
+
}>>>;
|
|
2180
|
+
}>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
|
|
1128
2181
|
/**
|
|
1129
2182
|
* Type definitions for Lightning functions.
|
|
1130
2183
|
*/
|
|
@@ -1153,6 +2206,31 @@ export declare const Lightning: {
|
|
|
1153
2206
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1154
2207
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1155
2208
|
}>;
|
|
2209
|
+
/**
|
|
2210
|
+
* Mixed output types per head.
|
|
2211
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
2212
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
2213
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
2214
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
2215
|
+
*/
|
|
2216
|
+
multi_head_mixed: StructType<{
|
|
2217
|
+
/** Array of head configurations */
|
|
2218
|
+
heads: ArrayType<StructType<{
|
|
2219
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2220
|
+
head_type: VariantType<{
|
|
2221
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2222
|
+
binary: NullType;
|
|
2223
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2224
|
+
multiclass: StructType<{
|
|
2225
|
+
n_classes: IntegerType;
|
|
2226
|
+
}>;
|
|
2227
|
+
}>;
|
|
2228
|
+
/** Optional class weights for this head */
|
|
2229
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2230
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2231
|
+
conditional_on: OptionType<IntegerType>;
|
|
2232
|
+
}>>;
|
|
2233
|
+
}>;
|
|
1156
2234
|
}>;
|
|
1157
2235
|
readonly ArchitectureType: VariantType<{
|
|
1158
2236
|
/** Simple MLP: input → hidden → output */
|
|
@@ -1225,6 +2303,36 @@ export declare const Lightning: {
|
|
|
1225
2303
|
/** Optional condition dimension for conditional generation */
|
|
1226
2304
|
condition_dim: OptionType<IntegerType>;
|
|
1227
2305
|
}>;
|
|
2306
|
+
/**
|
|
2307
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
2308
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
2309
|
+
* Predicts actions conditioned on desired return and state history.
|
|
2310
|
+
*/
|
|
2311
|
+
decision_transformer: StructType<{
|
|
2312
|
+
/** Sequence length (timesteps) */
|
|
2313
|
+
sequence_length: IntegerType;
|
|
2314
|
+
/** State dimension per timestep */
|
|
2315
|
+
state_dim: IntegerType;
|
|
2316
|
+
/** Action dimension per timestep */
|
|
2317
|
+
action_dim: IntegerType;
|
|
2318
|
+
/** Model dimension (transformer hidden size) */
|
|
2319
|
+
d_model: IntegerType;
|
|
2320
|
+
/** Number of attention heads */
|
|
2321
|
+
n_attention_heads: IntegerType;
|
|
2322
|
+
/** Number of transformer layers */
|
|
2323
|
+
n_layers: IntegerType;
|
|
2324
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2325
|
+
d_ff: OptionType<IntegerType>;
|
|
2326
|
+
/** Dropout rate */
|
|
2327
|
+
dropout: OptionType<FloatType>;
|
|
2328
|
+
/** Whether return is per-timestep or global */
|
|
2329
|
+
return_embedding: VariantType<{
|
|
2330
|
+
/** Single return value for entire sequence */
|
|
2331
|
+
global: NullType;
|
|
2332
|
+
/** Return-to-go at each timestep */
|
|
2333
|
+
per_timestep: NullType;
|
|
2334
|
+
}>;
|
|
2335
|
+
}>;
|
|
1228
2336
|
}>;
|
|
1229
2337
|
readonly CellType: VariantType<{
|
|
1230
2338
|
lstm: NullType;
|
|
@@ -1304,6 +2412,36 @@ export declare const Lightning: {
|
|
|
1304
2412
|
/** Optional condition dimension for conditional generation */
|
|
1305
2413
|
condition_dim: OptionType<IntegerType>;
|
|
1306
2414
|
}>;
|
|
2415
|
+
/**
|
|
2416
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
2417
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
2418
|
+
* Predicts actions conditioned on desired return and state history.
|
|
2419
|
+
*/
|
|
2420
|
+
decision_transformer: StructType<{
|
|
2421
|
+
/** Sequence length (timesteps) */
|
|
2422
|
+
sequence_length: IntegerType;
|
|
2423
|
+
/** State dimension per timestep */
|
|
2424
|
+
state_dim: IntegerType;
|
|
2425
|
+
/** Action dimension per timestep */
|
|
2426
|
+
action_dim: IntegerType;
|
|
2427
|
+
/** Model dimension (transformer hidden size) */
|
|
2428
|
+
d_model: IntegerType;
|
|
2429
|
+
/** Number of attention heads */
|
|
2430
|
+
n_attention_heads: IntegerType;
|
|
2431
|
+
/** Number of transformer layers */
|
|
2432
|
+
n_layers: IntegerType;
|
|
2433
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2434
|
+
d_ff: OptionType<IntegerType>;
|
|
2435
|
+
/** Dropout rate */
|
|
2436
|
+
dropout: OptionType<FloatType>;
|
|
2437
|
+
/** Whether return is per-timestep or global */
|
|
2438
|
+
return_embedding: VariantType<{
|
|
2439
|
+
/** Single return value for entire sequence */
|
|
2440
|
+
global: NullType;
|
|
2441
|
+
/** Return-to-go at each timestep */
|
|
2442
|
+
per_timestep: NullType;
|
|
2443
|
+
}>;
|
|
2444
|
+
}>;
|
|
1307
2445
|
}>;
|
|
1308
2446
|
/** Output mode (determines loss function) */
|
|
1309
2447
|
output: VariantType<{
|
|
@@ -1330,6 +2468,31 @@ export declare const Lightning: {
|
|
|
1330
2468
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1331
2469
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1332
2470
|
}>;
|
|
2471
|
+
/**
|
|
2472
|
+
* Mixed output types per head.
|
|
2473
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
2474
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
2475
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
2476
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
2477
|
+
*/
|
|
2478
|
+
multi_head_mixed: StructType<{
|
|
2479
|
+
/** Array of head configurations */
|
|
2480
|
+
heads: ArrayType<StructType<{
|
|
2481
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2482
|
+
head_type: VariantType<{
|
|
2483
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2484
|
+
binary: NullType;
|
|
2485
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2486
|
+
multiclass: StructType<{
|
|
2487
|
+
n_classes: IntegerType;
|
|
2488
|
+
}>;
|
|
2489
|
+
}>;
|
|
2490
|
+
/** Optional class weights for this head */
|
|
2491
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2492
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2493
|
+
conditional_on: OptionType<IntegerType>;
|
|
2494
|
+
}>>;
|
|
2495
|
+
}>;
|
|
1333
2496
|
}>;
|
|
1334
2497
|
/** Learning rate (default: 1e-3) */
|
|
1335
2498
|
learning_rate: OptionType<FloatType>;
|
|
@@ -1403,6 +2566,61 @@ export declare const Lightning: {
|
|
|
1403
2566
|
/** Group index per sample: [n_samples] */
|
|
1404
2567
|
sample_groups: ArrayType<IntegerType>;
|
|
1405
2568
|
}>;
|
|
2569
|
+
readonly GenerateConfigType: StructType<{
|
|
2570
|
+
/** Number of steps to generate */
|
|
2571
|
+
n_steps: IntegerType;
|
|
2572
|
+
/** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
|
|
2573
|
+
temperature: FloatType;
|
|
2574
|
+
/** If true, return probabilities. If false, return samples. */
|
|
2575
|
+
return_probs: BooleanType;
|
|
2576
|
+
}>;
|
|
2577
|
+
readonly ReturnEmbeddingType: VariantType<{
|
|
2578
|
+
/** Single return value for entire sequence */
|
|
2579
|
+
global: NullType;
|
|
2580
|
+
/** Return-to-go at each timestep */
|
|
2581
|
+
per_timestep: NullType;
|
|
2582
|
+
}>;
|
|
2583
|
+
readonly HeadConfigType: StructType<{
|
|
2584
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2585
|
+
head_type: VariantType<{
|
|
2586
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2587
|
+
binary: NullType;
|
|
2588
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2589
|
+
multiclass: StructType<{
|
|
2590
|
+
n_classes: IntegerType;
|
|
2591
|
+
}>;
|
|
2592
|
+
}>;
|
|
2593
|
+
/** Optional class weights for this head */
|
|
2594
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2595
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2596
|
+
conditional_on: OptionType<IntegerType>;
|
|
2597
|
+
}>;
|
|
2598
|
+
readonly TrajectoryGenerateConfigType: StructType<{
|
|
2599
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
2600
|
+
temperature: FloatType;
|
|
2601
|
+
/** Whether to return probabilities or samples */
|
|
2602
|
+
return_probs: BooleanType;
|
|
2603
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
2604
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2605
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
2606
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
2607
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
2608
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
2609
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2610
|
+
head_type: VariantType<{
|
|
2611
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2612
|
+
binary: NullType;
|
|
2613
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2614
|
+
multiclass: StructType<{
|
|
2615
|
+
n_classes: IntegerType;
|
|
2616
|
+
}>;
|
|
2617
|
+
}>;
|
|
2618
|
+
/** Optional class weights for this head */
|
|
2619
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2620
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2621
|
+
conditional_on: OptionType<IntegerType>;
|
|
2622
|
+
}>>>;
|
|
2623
|
+
}>;
|
|
1406
2624
|
};
|
|
1407
2625
|
};
|
|
1408
2626
|
//# sourceMappingURL=lightning.d.ts.map
|